In [1]:
# %pip uninstall torch torchvision torchaudio transformers mamba-ssm tqdm pandas scikit-learn numpy -y
# %pip cache purge
# %pip install peft torch transformers mamba-ssm tqdm pandas scikit-learn numpy


In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix

In [2]:
splice_df = pd.read_csv('../../splice_sites_full_centered_balanced_correct_V3.csv')
splice_df.head()

Unnamed: 0,id,coord,kind,transcript,strand,chrom,start,end,sequence,win_start,win_end,is_truncated,motif_len,motif
0,NC_050096.1_35318_donor,35318,0,XM_020544715.3,+,NC_050096.1,35319,35320,GGGCCCGGCTGGGCCTCAGCGGGGTCGTCGAGATGGAGATGGGGAG...,35118,35520,False,2,GT
1,NC_050096.1_34607_acceptor,34607,1,XM_020544715.3,+,NC_050096.1,34605,34606,TCCGGTGATTAATTTGTCCTTATACCTTTACAACAAAAATTCACTA...,34404,34806,False,2,TG
2,NC_050096.1_36174_donor,36174,0,XM_020544715.3,+,NC_050096.1,36175,36176,ATAATATGTTCATTATATCACAACACTCTTTTCTTATGGAGTCGTG...,35974,36376,False,2,GT
3,NC_050096.1_36037_acceptor,36037,1,XM_020544715.3,+,NC_050096.1,36035,36036,GCACAAAACTAACTAAAGGAATCATTCTGATAGATAACACTATAAA...,35834,36236,False,2,AG
4,NC_050096.1_36504_donor,36504,0,XM_020544715.3,+,NC_050096.1,36505,36506,TGTCATTTCCTTACCTCATTGAATCATTTCCGATGCTTCTTCTCTG...,36304,36706,False,2,GT


In [3]:
# Split data with stratification
train_df, temp_df = train_test_split(splice_df, test_size=0.3, stratify=splice_df['kind'], random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, stratify=temp_df['kind'], random_state=42)

## Full Pipeline

In [41]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")

PyTorch version: 2.4.0+cu121
CUDA available: True
CUDA version: 12.1


In [42]:
# Cell 1: Imports and Setup
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, AutoConfig
from mamba_ssm import Mamba
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from sklearn.metrics import classification_report, accuracy_score
import os

# Disable tokenizer parallelism to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [55]:
# Cell 2: Dataset Class (Updated)
class DNASequenceDataset(Dataset):
    def __init__(self, sequences, labels, max_length=512):
        self.sequences = sequences
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.max_length = max_length
        
        # Nucleotide to index mapping (ACGT)
        self.nuc_to_idx = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
        
        # Preprocess sequences
        self.seq_encodings = []
        self.attention_masks = []
        
        for seq in sequences:
            # Convert sequence to indices
            encoding = [self.nuc_to_idx.get(nuc, 0) for nuc in seq[:max_length]]
            encoding += [0] * (max_length - len(encoding))  # Pad with A
            self.seq_encodings.append(torch.tensor(encoding, dtype=torch.long))
            
            # Create attention mask (1 for actual sequence, 0 for padding)
            attention_mask = [1] * min(len(seq), max_length)
            attention_mask += [0] * (max_length - len(attention_mask))
            self.attention_masks.append(torch.tensor(attention_mask, dtype=torch.long))
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return {
            'sequence': self.seq_encodings[idx],
            'attention_mask': self.attention_masks[idx],
            'label': self.labels[idx]
        }

In [54]:
# Cell 3: Mamba Model (Updated)
class MambaModel(nn.Module):
    def __init__(self, d_model=128, d_state=16, d_conv=4, expand=2, num_mamba_layers=4, n_classes=3, max_length=512):
        super().__init__()
        self.nuc_embedding = nn.Embedding(6, d_model)
        self.pos_embedding = nn.Embedding(max_length, d_model)
        
        nn.init.xavier_uniform_(self.nuc_embedding.weight)
        nn.init.xavier_uniform_(self.pos_embedding.weight)
        
        # Configure Mamba layers with explicit settings
        self.mamba = nn.Sequential(*[
            nn.Sequential(
                nn.LayerNorm(d_model),
                Mamba(
                    d_model=d_model,
                    d_state=d_state,
                    d_conv=d_conv,
                    expand=expand,
                    use_fast_path=True,
                    dt_rank="auto",  # Add this to avoid potential issues
                ),
                nn.Dropout(0.1)
            ) for _ in range(num_mamba_layers)
        ])
        
        self.attention_pool = nn.Sequential(
            nn.Linear(d_model, 1),
            nn.Softmax(dim=1)
        )
        
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Dropout(0.1),
            nn.Linear(d_model, n_classes)
        )
    
    def forward(self, sequence):
        pos_ids = torch.arange(sequence.size(1), device=sequence.device).unsqueeze(0).expand(sequence.size(0), -1)
        x = self.nuc_embedding(sequence) + self.pos_embedding(pos_ids)
        x = self.mamba(x)
        weights = self.attention_pool(x)
        x = (x * weights).sum(dim=1)
        logits = self.classifier(x)
        return logits
    

# Cell 4: DNABert2 Model (Simplified BERT)
class DNABert2Model(nn.Module):
    def __init__(self, n_classes=3, max_length=512):
        super().__init__()
        # Use standard BERT model instead of DNABERT-2
        self.bert = AutoModel.from_pretrained("bert-base-uncased")
        
        # Add a projection layer to match DNA sequence vocabulary
        self.dna_projection = nn.Linear(4, 768)  # 4 for ACGT
        
        # Freeze BERT parameters
        for param in self.bert.parameters():
            param.requires_grad = False
            
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(768, n_classes)
        
    def forward(self, input_ids, attention_mask):
        # Project DNA sequence embeddings to BERT dimension
        batch_size, seq_len = input_ids.shape
        dna_embeddings = torch.zeros(batch_size, seq_len, 4, device=input_ids.device)
        
        # One-hot encode the DNA sequence
        for i in range(4):  # ACGT
            dna_embeddings[:, :, i] = (input_ids == i).float()
            
        # Project to BERT dimension
        bert_input = self.dna_projection(dna_embeddings)
        
        # Get BERT outputs
        outputs = self.bert(
            inputs_embeds=bert_input,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        
        # Use the last hidden state's [CLS] token
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits

In [57]:
# Cell 5: Ensemble Model (Updated)
class EnsembleModel(nn.Module):
    def __init__(self, mamba_model, bert_model, n_classes=3):
        super().__init__()
        self.mamba_model = mamba_model
        self.bert_model = bert_model
        self.ensemble_weight = nn.Parameter(torch.tensor([0.5]))
        
    def forward(self, sequence, attention_mask, labels=None):
        # Get predictions from both models
        mamba_logits = self.mamba_model(sequence)
        bert_logits = self.bert_model(sequence, attention_mask)
        
        # Combine predictions using learned weight
        weight = torch.sigmoid(self.ensemble_weight)
        combined_logits = weight * mamba_logits + (1 - weight) * bert_logits
        
        if labels is not None:
            loss = nn.CrossEntropyLoss(label_smoothing=0.1)(combined_logits, labels)
            return loss, combined_logits
        return None, combined_logits

In [58]:
# Cell 6: Model and Data Preparation (Updated)
def prepare_model_and_data(train_df, test_df, sample_size=5000, max_length=512):
    # Prepare datasets
    train_sample = train_df.groupby('kind', group_keys=False).apply(
        lambda x: x.sample(n=min(len(x), sample_size // 3))
    ).reset_index(drop=True)
    
    test_sample = test_df.groupby('kind', group_keys=False).apply(
        lambda x: x.sample(n=min(len(x), sample_size // 15))
    ).reset_index(drop=True)
    
    train_dataset = DNASequenceDataset(
        sequences=train_sample['sequence'].tolist(),
        labels=train_sample['kind'].tolist(),
        max_length=max_length
    )
    
    test_dataset = DNASequenceDataset(
        sequences=test_sample['sequence'].tolist(),
        labels=test_sample['kind'].tolist(),
        max_length=max_length
    )
    
    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=8)
    
    # Initialize models
    mamba_model = MambaModel(
        d_model=128,
        d_state=16,
        d_conv=4,
        expand=1,
        num_mamba_layers=4,
        n_classes=3
    ).to(device)
    
    bert_model = DNABert2Model(n_classes=3).to(device)
    
    # Create ensemble model
    ensemble_model = EnsembleModel(mamba_model, bert_model).to(device)
    
    return ensemble_model, train_dataloader, test_dataloader, device

In [61]:
# Cell 7: Training Function (Updated)
def train_model(model, train_dataloader, val_dataloader, device, 
                num_epochs=20, patience=3, learning_rate=1e-4):
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2)
    
    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None
    
    history = {
        'epoch': [], 'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [], 'learning_rate': []
    }
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        total_train_loss = 0
        train_correct = 0
        train_total = 0
        
        for batch in tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
            sequence = batch['sequence'].to(device)  # Use 'sequence' instead of 'input_ids'
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            optimizer.zero_grad()
            loss, logits = model(sequence, attention_mask, labels)  # Pass attention_mask instead of input_ids
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            total_train_loss += loss.item()
            predictions = torch.argmax(logits, dim=1)
            train_correct += (predictions == labels).sum().item()
            train_total += labels.size(0)
        
        avg_train_loss = total_train_loss / len(train_dataloader)
        train_accuracy = train_correct / train_total
        
        # Validation phase
        model.eval()
        total_val_loss = 0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for batch in val_dataloader:
                sequence = batch['sequence'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)
                
                loss, logits = model(sequence, attention_mask, labels)
                total_val_loss += loss.item()
                
                predictions = torch.argmax(logits, dim=1)
                val_correct += (predictions == labels).sum().item()
                val_total += labels.size(0)
        
        avg_val_loss = total_val_loss / len(val_dataloader)
        val_accuracy = val_correct / val_total
        
        # Update learning rate
        scheduler.step(avg_val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        
        # Print metrics
        print(f"Epoch {epoch + 1}/{num_epochs}:")
        print(f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.4f}")
        print(f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.4f}")
        print(f"Learning Rate: {current_lr}")
        
        # Update history
        history['epoch'].append(epoch + 1)
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['train_acc'].append(train_accuracy)
        history['val_acc'].append(val_accuracy)
        history['learning_rate'].append(current_lr)
        
        # Early stopping
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_state = model.state_dict()
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered after {epoch + 1} epochs")
                break
    
    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    
    # Convert history to DataFrame
    history_df = pd.DataFrame(history)
    return history_df

In [62]:
# Cell 8: Train the Ensemble Model
model, train_dataloader, test_dataloader, device = prepare_model_and_data(train_df, test_df)
history = train_model(model, train_dataloader, test_dataloader, device)

  train_sample = train_df.groupby('kind', group_keys=False).apply(
  test_sample = test_df.groupby('kind', group_keys=False).apply(


Epoch 1/20:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 1/20:
Train Loss: 1.0123, Train Acc: 0.4790
Val Loss: 0.9082, Val Acc: 0.5976
Learning Rate: 0.0001


Epoch 2/20:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 2/20:
Train Loss: 0.7817, Train Acc: 0.7027
Val Loss: 0.6396, Val Acc: 0.8048
Learning Rate: 0.0001


Epoch 3/20:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 3/20:
Train Loss: 0.5924, Train Acc: 0.8385
Val Loss: 0.5922, Val Acc: 0.8549
Learning Rate: 0.0001


Epoch 4/20:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 4/20:
Train Loss: 0.5448, Train Acc: 0.8659
Val Loss: 0.5792, Val Acc: 0.8549
Learning Rate: 0.0001


Epoch 5/20:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 5/20:
Train Loss: 0.4953, Train Acc: 0.8962
Val Loss: 0.6940, Val Acc: 0.7998
Learning Rate: 0.0001


Epoch 6/20:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 6/20:
Train Loss: 0.4675, Train Acc: 0.9120
Val Loss: 0.5755, Val Acc: 0.8559
Learning Rate: 0.0001


Epoch 7/20:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 7/20:
Train Loss: 0.4382, Train Acc: 0.9276
Val Loss: 0.5889, Val Acc: 0.8639
Learning Rate: 0.0001


Epoch 8/20:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 8/20:
Train Loss: 0.4172, Train Acc: 0.9374
Val Loss: 0.5846, Val Acc: 0.8619
Learning Rate: 0.0001


Epoch 9/20:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 9/20:
Train Loss: 0.3929, Train Acc: 0.9484
Val Loss: 0.5784, Val Acc: 0.8659
Learning Rate: 1e-05
Early stopping triggered after 9 epochs


In [63]:
model.ensemble_weight

Parameter containing:
tensor([0.5253], device='cuda:0', requires_grad=True)