In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import math
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
import warnings

# Silence UserWarning and FutureWarning for the whole session to keep cell output short.
# NOTE(review): these filters match on category only (no module restriction), so they
# hide such warnings from every library — consider narrowing them while debugging.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

In [4]:
class LearnablePositionalEncoding(nn.Module):
    """Trainable additive positional embedding.

    Holds a single ``(1, seq_len, embed_dim)`` parameter, initialised from a
    standard normal distribution, that is added to the input. The leading
    singleton dimension lets the table broadcast over the batch dimension.
    """

    def __init__(self, seq_len, embed_dim):
        super().__init__()
        initial_table = torch.randn(1, seq_len, embed_dim)
        self.pe = nn.Parameter(initial_table)

    def forward(self, x):
        """Return ``x`` with the learned positional table added element-wise."""
        encoded = torch.add(x, self.pe)
        return encoded


class ResidualCNNBlock(nn.Module):
    """1-D residual block: conv -> BN -> ReLU -> conv -> BN, plus a skip path.

    The skip path is a 1x1 convolution when the channel count changes and an
    identity mapping otherwise. The sum of both paths goes through ReLU and
    then dropout.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dropout):
        super().__init__()
        # With stride 1, padding of kernel_size // 2 preserves length for odd kernels.
        same_pad = kernel_size // 2
        main_path = [
            nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=same_pad),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
            nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, padding=same_pad),
            nn.BatchNorm1d(out_channels),
        ]
        self.conv = nn.Sequential(*main_path)

        if in_channels != out_channels:
            # Project the input so it can be summed with the main path.
            self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the main path, add the skip path, then ReLU and dropout."""
        features = self.conv(x)
        residual = self.shortcut(x)
        activated = self.relu(features + residual)
        return self.dropout(activated)


class StableG4Predictor(nn.Module):
    """CNN + Transformer binary classifier over one-hot DNA plus one scalar feature.

    Pipeline: three residual CNN blocks (with two stride-2 max-pools) ->
    learnable positional encoding -> pre-norm Transformer encoder -> add a
    projected accessibility feature to every position -> flatten -> MLP head
    that emits one logit per sample.

    Args:
        seq_length: Length of the input sequences. The two max-pools reduce it
            to ``seq_length // 4`` positions, which fixes the size of the
            positional table and of the classifier input.
        embed_dim: Channel width after the CNN and Transformer model dimension.
        num_heads: Number of attention heads.
        dropout: Dropout probability used throughout the model.
        num_layers: Number of Transformer encoder layers.
        dim_feedforward: Hidden width of each encoder layer's feed-forward part.

    Raises:
        ValueError: If ``seq_length`` is too short to survive the two pools.
    """

    def __init__(self, seq_length=201, embed_dim=256, num_heads=8, dropout=0.3,
                 num_layers=6, dim_feedforward=2048):
        super().__init__()

        reduced_len = seq_length // 4  # After two max pools
        if reduced_len < 1:
            raise ValueError(
                f"seq_length must be at least 4 to survive two max-pools, got {seq_length}"
            )
        # Plain int (not a buffer), so it does not appear in the state_dict.
        self.reduced_len = reduced_len

        # ------------------------
        # Deep CNN Feature Extractor
        # ------------------------
        self.cnn = nn.Sequential(
            ResidualCNNBlock(4, 64, kernel_size=9, dropout=dropout),
            nn.MaxPool1d(2),
            ResidualCNNBlock(64, 128, kernel_size=7, dropout=dropout),
            nn.MaxPool1d(2),
            ResidualCNNBlock(128, embed_dim, kernel_size=5, dropout=dropout),
            nn.Dropout(dropout)
        )

        # ------------------------
        # Learnable Positional Encoding
        # ------------------------
        self.pos_encoder = LearnablePositionalEncoding(seq_len=reduced_len, embed_dim=embed_dim)

        # ------------------------
        # Transformer Encoder
        # ------------------------
        encoder_layer = TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            norm_first=True  # Pre-norm layout
        )
        self.transformer = TransformerEncoder(encoder_layer, num_layers=num_layers)

        # ------------------------
        # Accessibility Integration
        # ------------------------
        self.access_proj = nn.Sequential(
            nn.Linear(1, embed_dim),
            nn.LayerNorm(embed_dim)
        )

        # ------------------------
        # Deep Classifier Head
        # ------------------------
        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim * reduced_len),
            nn.Linear(embed_dim * reduced_len, 2048),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(2048, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, 1)
        )

    def forward(self, seq_onehot, access):
        """Return one raw logit per sample.

        Args:
            seq_onehot: Float tensor of shape ``(batch, 4, seq_length)``.
            access: Float tensor of shape ``(batch, 1)``; a 1-D ``(batch,)``
                tensor is also accepted and treated as ``(batch, 1)``.

        Returns:
            Tensor of shape ``(batch,)`` holding logits (no sigmoid applied).

        Raises:
            ValueError: If the sequence length after the CNN does not match the
                length the model was built for.
        """
        # CNN
        x = self.cnn(seq_onehot)  # (batch, embed_dim, seq_len/4)
        x = x.permute(0, 2, 1)    # (batch, seq_len/4, embed_dim)

        # Fail early with a clear message instead of an opaque broadcast error
        # (or a silent broadcast) inside the positional encoding.
        if x.size(1) != self.reduced_len:
            raise ValueError(
                f"Expected {self.reduced_len} positions after the CNN, got {x.size(1)} "
                f"(input length {seq_onehot.size(-1)})"
            )

        # Add positional encoding
        x = self.pos_encoder(x)

        # Transformer
        x = self.transformer(x)

        # nn.Linear(1, embed_dim) needs a trailing feature dimension of size 1.
        if access.dim() == 1:
            access = access.unsqueeze(1)

        # Broadcast accessibility over all positions and add
        access_feat = self.access_proj(access)
        x = x + access_feat.unsqueeze(1)

        # Flatten and classify
        x = x.flatten(1)
        return self.classifier(x).squeeze(1)

# ---------------------------
# 3. Data Preprocessing
# ---------------------------
def dna_to_onehot(seqs):
    """One-hot encode equal-length DNA strings into a float32 array.

    Channels are ordered A, C, G, T. Upper- and lower-case bases are treated
    alike; every other character (including ``N``) becomes an all-zero column.

    Args:
        seqs: Sequence (list, tuple or NumPy array) of DNA strings that all
            share the length of the first entry.

    Returns:
        ``np.ndarray`` of shape ``(len(seqs), 4, seq_length)`` and dtype
        ``float32``. An empty input yields an array of shape ``(0, 4, 0)``.

    Raises:
        ValueError: If any sequence differs in length from the first one.
    """
    num_seqs = len(seqs)
    if num_seqs == 0:
        return np.zeros((0, 4, 0), dtype=np.float32)

    # Byte value -> channel lookup table; bytes not listed stay all-zero.
    lookup = np.zeros((256, 4), dtype=np.float32)
    for channel, base in enumerate("ACGT"):
        lookup[ord(base), channel] = 1.0
        lookup[ord(base.lower()), channel] = 1.0

    seq_length = len(seqs[0])
    onehot = np.zeros((num_seqs, 4, seq_length), dtype=np.float32)
    for i, seq in enumerate(seqs):
        if not isinstance(seq, str):
            # Accept any iterable of single characters, as the per-base loop did.
            seq = "".join(seq)
        if len(seq) != seq_length:
            raise ValueError(
                f"Sequence {i} has length {len(seq)}, expected {seq_length}"
            )
        # Non-ASCII characters are replaced by '?', which maps to all zeros.
        codes = np.frombuffer(seq.encode("ascii", errors="replace"), dtype=np.uint8)
        onehot[i] = lookup[codes].T
    return onehot

# ---------------------------
# 4. Load and Prepare Data
# ---------------------------
# Expects columns "sequence" (DNA string), "is_open" (accessibility feature)
# and "label" (binary target) in training_data.csv.
data = pd.read_csv("training_data.csv")
sequences = data["sequence"].values
access = data["is_open"].values
labels = data["label"].values

# One-hot encoding
X = dna_to_onehot(sequences)
X = torch.tensor(X, dtype=torch.float32)
access = torch.tensor(access, dtype=torch.float32).unsqueeze(1)  # (N, 1) for nn.Linear(1, ...)
y = torch.tensor(labels, dtype=torch.float32)

# Train/test split with stratification: `stratify` keeps the class ratio the
# same in both splits (the raw NumPy labels are passed so scikit-learn can
# count classes directly).
X_train, X_test, a_train, a_test, y_train, y_test = train_test_split(
    X, access, y, test_size=0.2, random_state=42, shuffle=True, stratify=labels
)

# Create DataLoaders. Pinned host memory only helps host-to-GPU copies, so it
# is requested only when a CUDA device is present.
use_pinned_memory = torch.cuda.is_available()

train_dataset = torch.utils.data.TensorDataset(X_train, a_train, y_train)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True, pin_memory=use_pinned_memory
)

val_dataset = torch.utils.data.TensorDataset(X_test, a_test, y_test)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=64, shuffle=False, pin_memory=use_pinned_memory
)



# ---------------------------
# 5. Training and Validation
# ---------------------------
def _make_grad_scaler(enabled):
    """Build a gradient scaler, preferring ``torch.amp.GradScaler`` when it exists."""
    if hasattr(torch.amp, "GradScaler"):
        return torch.amp.GradScaler("cuda", enabled=enabled)
    return torch.cuda.amp.GradScaler(enabled=enabled)


def train_model(num_epochs=10, early_stop_patience=5, checkpoint_path='weights.pt'):
    """Train ``StableG4Predictor`` on the module-level loaders with early stopping.

    Reads the module-level ``train_loader`` and ``val_loader`` and calls the
    module-level ``validate`` after every epoch. The weights with the best
    validation accuracy are saved to ``checkpoint_path``.

    Args:
        num_epochs: Maximum number of epochs. Also sizes the one-cycle learning
            rate schedule so that it spans exactly the planned training run.
        early_stop_patience: Stop after this many consecutive epochs without an
            improvement in validation accuracy.
        checkpoint_path: File the best ``state_dict`` is written to.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision is only meaningful (and only requested) on a CUDA device.
    use_amp = device.type == "cuda"
    model = StableG4Predictor().to(device)

    steps_per_epoch = len(train_loader)
    if steps_per_epoch == 0:
        raise ValueError("train_loader is empty; nothing to train on")

    # Optimizer & Scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # OneCycleLR is defined per optimizer step, so it is sized for
    # num_epochs * steps_per_epoch steps and stepped once per batch below.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=1e-3, steps_per_epoch=steps_per_epoch, epochs=num_epochs
    )
    # Loss scaling keeps fp16 gradients from underflowing; it is a no-op when disabled.
    scaler = _make_grad_scaler(use_amp)

    # Unweighted binary cross-entropy on logits (no class re-weighting is applied).
    criterion = nn.BCEWithLogitsLoss()

    # -inf guarantees the first epoch always produces a checkpoint.
    best_val_acc = float("-inf")
    no_improve = 0

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for batch_X, batch_a, batch_y in train_loader:
            batch_X, batch_a, batch_y = batch_X.to(device), batch_a.to(device), batch_y.to(device)

            optimizer.zero_grad()

            # Mixed precision for faster training
            with torch.amp.autocast('cuda', enabled=use_amp):
                outputs = model(batch_X, batch_a)
                loss = criterion(outputs, batch_y)

            scaler.scale(loss).backward()

            # Gradient clipping must see the real (unscaled) gradients, so it runs
            # after backward() and after unscale_().
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            scale_before = scaler.get_scale()
            scaler.step(optimizer)
            scaler.update()
            # The scaler lowers its scale when it skipped the optimizer step because
            # of inf/NaN gradients; only advance the schedule after a real step.
            if scaler.get_scale() >= scale_before:
                scheduler.step()

            train_loss += loss.item()

        # Validate after each epoch
        val_loss, val_acc, val_auc, val_f1 = validate(model, val_loader, criterion, device)

        # Report before the early-stopping check so the final epoch is logged too.
        print(f"Epoch {epoch + 1}: "
              f"Train Loss: {train_loss / steps_per_epoch:.4f}, "
              f"Val Loss: {val_loss:.4f}, "
              f"Val Acc: {val_acc:.4f}, "
              f"Val AUC: {val_auc:.4f}, "
              f"Val F1: {val_f1:.4f}")

        # Early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            no_improve = 0
            torch.save(model.state_dict(), checkpoint_path)
        else:
            no_improve += 1
            if no_improve >= early_stop_patience:
                print(f"Early stopping at epoch {epoch + 1}")
                break


# ---------------------------
# 6. Validation and Evaluation
# ---------------------------
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` and return loss and classification metrics.

    Args:
        model: Module called as ``model(batch_X, batch_a)`` and returning logits.
        loader: Iterable of ``(batch_X, batch_a, batch_y)`` tensor triples.
        criterion: Loss called as ``criterion(logits, batch_y)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(mean_batch_loss, accuracy, auc, f1)``. The loss is the mean of
        the per-batch losses. Accuracy and F1 use a 0.5 threshold on the sigmoid
        of the logits. ``auc`` is ``nan`` when the labels contain a single
        class, because ROC AUC is undefined in that case.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    prob_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch_X, batch_a, batch_y in loader:
            batch_X, batch_a, batch_y = batch_X.to(device), batch_a.to(device), batch_y.to(device)

            outputs = model(batch_X, batch_a)
            total_loss += criterion(outputs, batch_y).item()
            num_batches += 1

            # Keep whole-batch arrays and concatenate once at the end.
            prob_chunks.append(torch.sigmoid(outputs).cpu().numpy())
            label_chunks.append(batch_y.cpu().numpy())

    if num_batches == 0:
        raise ValueError("validate() received an empty loader")

    all_preds = np.concatenate(prob_chunks)
    all_labels = np.concatenate(label_chunks)
    hard_preds = all_preds > 0.5

    # Calculate metrics
    acc = accuracy_score(all_labels, hard_preds)
    f1 = f1_score(all_labels, hard_preds)
    if np.unique(all_labels).size < 2:
        # roc_auc_score raises ValueError for single-class labels; report nan
        # instead of aborting the training run.
        auc = float("nan")
    else:
        auc = roc_auc_score(all_labels, all_preds)

    return total_loss / num_batches, acc, auc, f1



In [5]:
# Run the training loop; the best-accuracy weights are saved to 'weights.pt'.
train_model()

RuntimeError: CUDA error: the launch timed out and was terminated
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
