In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import math
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
import warnings

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

In [None]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch-first sequence.

    Even feature columns hold ``sin(pos * freq)`` and odd feature columns hold
    ``cos(pos * freq)``, with ``freq = exp(-log(10000) * 2i / d_model)``.
    The table is precomputed once for ``max_len`` positions and stored as the
    buffer ``pe`` of shape ``(1, max_len, d_model)``.

    Args:
        d_model: Size of the feature (last) dimension. Odd values are
            supported; the final column is then a sine column.
        dropout: Dropout probability applied after the encodings are added.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            (-math.log(10000.0) / d_model)
        )
        # (max_len, ceil(d_model / 2)) table of phase angles.
        angles = position * div_term

        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns,
        # so trim the angles to the number of odd-indexed columns.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe)`` for ``x`` of shape (batch, seq_len, d_model).

        ``seq_len`` must not exceed the ``max_len`` the module was built with.
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)



class G4Predictor(nn.Module):
    """CNN + Transformer binary classifier over one-hot DNA plus one scalar feature.

    Pipeline: two Conv1d/BatchNorm/ReLU/MaxPool(2)/Dropout stages reduce the
    sequence length by a factor of four, the result is position-encoded and
    passed through a Transformer encoder, a projection of the scalar ``access``
    feature is added to every position, and the flattened tokens are mapped
    to a single logit per sample.

    Args:
        seq_length: Length of the input sequences. The classifier input width
            is ``embed_dim * (seq_length // 4)``, so inputs to ``forward`` must
            have exactly this length.
        embed_dim: Channel count after the CNN and Transformer model width.
        num_heads: Number of attention heads in each encoder layer.
        dropout: Dropout probability used throughout the network.
        dim_feedforward: Hidden width of each encoder layer's feed-forward
            block (previously fixed at 256).
        num_layers: Number of stacked encoder layers (previously fixed at 3).

    Raises:
        ValueError: If ``seq_length`` is too short to survive the two
            pooling stages (``seq_length // 4 == 0``).
    """

    def __init__(self, seq_length=201, embed_dim=64, num_heads=4, dropout=0.3,
                 dim_feedforward=256, num_layers=3):
        super().__init__()

        pooled_length = seq_length // 4  # two MaxPool1d(2) stages, each flooring
        if pooled_length < 1:
            raise ValueError(
                f"seq_length must be at least 4 to survive two pooling stages, got {seq_length}"
            )

        # Submodules are created in a fixed order so that seeded initialisation
        # and state_dict keys stay compatible with existing checkpoints.
        self.cnn = nn.Sequential(
            nn.Conv1d(4, 32, kernel_size=9, padding=4),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(dropout),

            nn.Conv1d(32, embed_dim, kernel_size=5, padding=2),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(dropout)
        )

        self.pos_encoder = PositionalEncoding(embed_dim, dropout)
        encoder_layer = TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Lifts the single scalar feature to the model width.
        self.access_proj = nn.Sequential(
            nn.Linear(1, embed_dim),
            nn.LayerNorm(embed_dim)
        )

        self.classifier = nn.Sequential(
            nn.Linear(embed_dim * pooled_length, 1024),
            nn.LayerNorm(1024),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(1024, 1)
        )

    def forward(self, seq_onehot, access):
        """Compute one logit per sample.

        Args:
            seq_onehot: Tensor of shape (batch, 4, seq_length).
            access: Tensor of shape (batch, 1).

        Returns:
            Tensor of shape (batch,) holding raw (pre-sigmoid) logits.
        """
        x = self.cnn(seq_onehot)      # (batch, embed_dim, seq_length // 4)
        x = x.permute(0, 2, 1)        # (batch, seq_length // 4, embed_dim)

        x = self.pos_encoder(x)
        x = self.transformer(x)

        access_feat = self.access_proj(access)   # (batch, embed_dim)
        x = x + access_feat.unsqueeze(1)         # broadcast over all positions

        x = x.flatten(1)
        return self.classifier(x).squeeze(1)


def dna_to_onehot(seqs):
    """One-hot encode DNA sequences into a channels-first float32 array.

    Channels are ordered A, C, G, T. Any other symbol (including ``N`` and
    lowercase letters, since matching is case-sensitive) becomes an all-zero
    column.

    Args:
        seqs: Sized iterable of sequences (e.g. a list or NumPy array of
            strings). Sequences may differ in length; shorter ones are
            right-padded with all-zero columns up to the longest sequence.

    Returns:
        ``np.ndarray`` of shape ``(len(seqs), 4, max_len)`` and dtype
        ``float32``. An empty ``seqs`` yields an array of shape ``(0, 4, 0)``.
    """
    n_seqs = len(seqs)
    if n_seqs == 0:
        return np.zeros((0, 4, 0), dtype=np.float32)

    channel_of = {"A": 0, "C": 1, "G": 2, "T": 3}
    # Size by the longest sequence so a longer-than-first sequence cannot
    # index past the end of the array.
    width = max(len(seq) for seq in seqs)
    onehot = np.zeros((n_seqs, 4, width), dtype=np.float32)

    for i, seq in enumerate(seqs):
        # -1 marks symbols without a channel; they stay all-zero.
        channels = np.fromiter(
            (channel_of.get(base, -1) for base in seq),
            dtype=np.intp,
            count=len(seq),
        )
        positions = np.flatnonzero(channels >= 0)
        # One vectorised write per sequence instead of one per base.
        onehot[i, channels[positions], positions] = 1.0
    return onehot


# --- Load the raw table and pull out the three columns used for training ---
data = pd.read_csv("training_data.csv")
sequences, access, labels = (
    data[column].values for column in ("sequence", "is_open", "label")
)

# --- Convert to float32 tensors ---
# X: one-hot sequences; access: scalar feature as a column vector; y: targets.
X = torch.tensor(dna_to_onehot(sequences), dtype=torch.float32)
access = torch.tensor(access, dtype=torch.float32)[:, None]
y = torch.tensor(labels, dtype=torch.float32)

# --- Stratified 80/20 split with a fixed seed ---
X_train, X_test, a_train, a_test, y_train, y_test = train_test_split(
    X,
    access,
    y,
    test_size=0.2,
    stratify=y,
    random_state=42,
)

# --- Wrap both splits in loaders; only the training loader shuffles ---
_loader_options = {"batch_size": 64, "pin_memory": True}

train_dataset = torch.utils.data.TensorDataset(X_train, a_train, y_train)
val_dataset = torch.utils.data.TensorDataset(X_test, a_test, y_test)

train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **_loader_options)



def train_model(num_epochs=10, early_stop_patience=5):
    """Train a ``G4Predictor`` and checkpoint the best epoch to ``weights.pt``.

    Reads the module-level ``train_loader``, ``val_loader`` and ``y_train``.
    After every epoch the model is evaluated with ``validate``; whenever the
    validation accuracy improves, the weights are saved to ``'weights.pt'``.
    Training stops early once accuracy has failed to improve for
    ``early_stop_patience`` consecutive epochs.

    Args:
        num_epochs: Maximum number of epochs. Also sizes the one-cycle
            learning-rate schedule so that the schedule and the loop agree.
        early_stop_patience: Number of consecutive non-improving epochs
            tolerated before stopping.

    Raises:
        ValueError: If ``y_train`` contains no positive labels (the positive
            class weight would be undefined).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision only makes sense on CUDA; on CPU run in plain float32.
    use_amp = device.type == "cuda"
    model = G4Predictor().to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # OneCycleLR is a per-batch schedule: its length must equal the number of
    # optimizer steps actually taken (num_epochs * batches per epoch).
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-3,
        steps_per_epoch=len(train_loader),
        epochs=num_epochs,
    )
    # Loss scaling prevents fp16 gradient underflow; a no-op when disabled.
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    # Weight positives by the negative/positive ratio to offset class imbalance.
    n_pos = y_train.sum()
    if n_pos == 0:
        raise ValueError("y_train contains no positive labels; cannot compute pos_weight")
    pos_weight = ((len(y_train) - n_pos) / n_pos).reshape(1).to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    best_val_acc = 0
    no_improve = 0

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for batch_X, batch_a, batch_y in train_loader:
            batch_X, batch_a, batch_y = batch_X.to(device), batch_a.to(device), batch_y.to(device)

            optimizer.zero_grad()

            with torch.amp.autocast(device.type, enabled=use_amp):
                outputs = model(batch_X, batch_a)
                loss = criterion(outputs, batch_y)

            scaler.scale(loss).backward()
            # Clip AFTER backward (so gradients exist) and after unscaling
            # (so the threshold applies to the true gradient magnitudes).
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()  # once per batch, matching steps_per_epoch above

            train_loss += loss.item()

        val_loss, val_acc, val_auc, val_f1 = validate(model, val_loader, criterion, device)

        # Report before the early-stop check so the final epoch is never silent.
        print(f"Epoch {epoch + 1}: "
              f"Train Loss: {train_loss / len(train_loader):.4f}, "
              f"Val Loss: {val_loss:.4f}, "
              f"Val Acc: {val_acc:.4f}, "
              f"Val AUC: {val_auc:.4f}, "
              f"Val F1: {val_f1:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            no_improve = 0
            torch.save(model.state_dict(), 'weights.pt')
        else:
            no_improve += 1
            if no_improve >= early_stop_patience:
                print(f"Early stopping at epoch {epoch + 1}")
                break


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on every batch of ``loader`` without gradients.

    The model is switched to eval mode and left in it. Logits are passed
    through a sigmoid and thresholded at 0.5 for the accuracy and F1 metrics;
    the ROC AUC uses the raw probabilities.

    Returns:
        Tuple ``(mean_batch_loss, accuracy, roc_auc, f1)``.
    """
    model.eval()
    loss_sum = 0
    probabilities = []
    targets = []

    with torch.no_grad():
        for seq_batch, access_batch, label_batch in loader:
            seq_batch = seq_batch.to(device)
            access_batch = access_batch.to(device)
            label_batch = label_batch.to(device)

            logits = model(seq_batch, access_batch)
            loss_sum += criterion(logits, label_batch).item()

            probabilities.extend(torch.sigmoid(logits).cpu().numpy())
            targets.extend(label_batch.cpu().numpy())

    probabilities = np.array(probabilities)
    targets = np.array(targets)
    # Hard class decisions shared by the two threshold-based metrics.
    predicted_positive = probabilities > 0.5

    acc = accuracy_score(targets, predicted_positive)
    auc = roc_auc_score(targets, probabilities)
    f1 = f1_score(targets, predicted_positive)

    mean_loss = loss_sum / len(loader)
    return mean_loss, acc, auc, f1


In [None]:
# Run training: per-epoch metrics are printed below and the best-accuracy
# weights are written to 'weights.pt'.
train_model()

Epoch 1: Train Loss: 0.5357, Val Loss: 0.5489, Val Acc: 0.7186, Val AUC: 0.8223, Val F1: 0.7485
Epoch 2: Train Loss: 0.4922, Val Loss: 0.5033, Val Acc: 0.7449, Val AUC: 0.8446, Val F1: 0.7651
Epoch 3: Train Loss: 0.4796, Val Loss: 0.4825, Val Acc: 0.7561, Val AUC: 0.8572, Val F1: 0.7739
Epoch 4: Train Loss: 0.4726, Val Loss: 0.4844, Val Acc: 0.7522, Val AUC: 0.8616, Val F1: 0.7763
Epoch 5: Train Loss: 0.4668, Val Loss: 0.4613, Val Acc: 0.7705, Val AUC: 0.8700, Val F1: 0.7862
Epoch 6: Train Loss: 0.4613, Val Loss: 0.4686, Val Acc: 0.7686, Val AUC: 0.8719, Val F1: 0.7876
Epoch 7: Train Loss: 0.4560, Val Loss: 0.4547, Val Acc: 0.7746, Val AUC: 0.8794, Val F1: 0.7938
Epoch 8: Train Loss: 0.4517, Val Loss: 0.4567, Val Acc: 0.7760, Val AUC: 0.8800, Val F1: 0.7949
Epoch 9: Train Loss: 0.4473, Val Loss: 0.4742, Val Acc: 0.7610, Val AUC: 0.8825, Val F1: 0.7901
Epoch 10: Train Loss: 0.4441, Val Loss: 0.4714, Val Acc: 0.7698, Val AUC: 0.8849, Val F1: 0.7953
