In [5]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix

In [2]:
# The 'Unnamed: 0' column in the raw file looks like a saved DataFrame index;
# read it back as the index rather than as a spurious feature column.
splice_df = pd.read_csv('splice_sites_full_centered_balanced_correct_V3.csv', index_col=0)
assert {'sequence', 'kind'} <= set(splice_df.columns), "missing expected columns"
splice_df.head()

Unnamed: 0,id,coord,kind,transcript,strand,chrom,start,end,sequence,win_start,win_end,is_truncated,motif_len,motif
0,NC_050096.1_35318_donor,35318,0,XM_020544715.3,+,NC_050096.1,35319,35320,GGGCCCGGCTGGGCCTCAGCGGGGTCGTCGAGATGGAGATGGGGAG...,35118,35520,False,2,GT
1,NC_050096.1_34607_acceptor,34607,1,XM_020544715.3,+,NC_050096.1,34605,34606,TCCGGTGATTAATTTGTCCTTATACCTTTACAACAAAAATTCACTA...,34404,34806,False,2,TG
2,NC_050096.1_36174_donor,36174,0,XM_020544715.3,+,NC_050096.1,36175,36176,ATAATATGTTCATTATATCACAACACTCTTTTCTTATGGAGTCGTG...,35974,36376,False,2,GT
3,NC_050096.1_36037_acceptor,36037,1,XM_020544715.3,+,NC_050096.1,36035,36036,GCACAAAACTAACTAAAGGAATCATTCTGATAGATAACACTATAAA...,35834,36236,False,2,AG
4,NC_050096.1_36504_donor,36504,0,XM_020544715.3,+,NC_050096.1,36505,36506,TGTCATTTCCTTACCTCATTGAATCATTTCCGATGCTTCTTCTCTG...,36304,36706,False,2,GT


In [3]:
# Stratified 70/15/15 split on the `kind` label: hold out 30%, then halve the
# hold-out into validation and test. Fixed seeds keep the split reproducible.
train_df, temp_df = train_test_split(
    splice_df,
    random_state=42,
    test_size=0.3,
    stratify=splice_df['kind'],
)
val_df, test_df = train_test_split(
    temp_df,
    random_state=42,
    test_size=0.5,
    stratify=temp_df['kind'],
)

## Full pipeline: frozen DNABERT-2 embeddings + nucleotide embeddings → Mamba classifier

In [3]:
# Cell 1: Imports and Setup
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from mamba_ssm import Mamba
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import pandas as pd
from sklearn.metrics import classification_report, accuracy_score
import numpy as np

ImportError: /nas/ucb/kishorechidambaram/anaconda3/envs/mamba_env/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so: undefined symbol: iJIT_NotifyEvent

In [None]:
# Cell 2: Dataset Class
class DNASequenceAndEmbeddingDataset(Dataset):
    """Pairs frozen BERT token embeddings with a per-nucleotide index encoding.

    Every BERT forward pass happens once, in ``__init__``; items are then served
    from CPU tensors, so training never touches the BERT model again.

    Args:
        sequences: list of DNA strings.
        labels: list of integer class labels, positionally aligned with ``sequences``.
        tokenizer: tokenizer matching ``bert_model``; called with
            ``padding="max_length"`` and ``truncation=True``.
        bert_model: frozen encoder; element 0 of its output is used as the
            per-token hidden states.
        device: device ``bert_model`` lives on.
        max_length: number of tokens kept from the tokenizer AND number of
            nucleotides kept in the character encoding (both padded/truncated).

    Each item is a dict with ``embeddings`` (max_length, hidden),
    ``sequence`` (max_length,) long, and ``labels`` (scalar long).

    NOTE(review): the nucleotide encoding covers only the first ``max_length``
    bases. For long windows centred on a splice site, the site itself may fall
    outside it. Token positions (BPE) and nucleotide positions are also not
    aligned, even though the model concatenates them position-wise.
    """

    def __init__(self, sequences, labels, tokenizer, bert_model, device, max_length=128):
        self.sequences = sequences
        self.labels = labels
        self.tokenizer = tokenizer
        self.bert_model = bert_model
        self.max_length = max_length
        self.device = device

        # Nucleotide -> index; any unrecognised character maps to 'N'.
        self.nuc_to_idx = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N': 4, 'PAD': 5}

        self.embeddings = []
        self.sequence_encodings = []
        self.bert_model.eval()

        print("Computing embeddings and sequence encodings...")
        with torch.no_grad():
            for seq in tqdm(sequences):
                inputs = tokenizer(
                    seq,
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length,
                    return_tensors="pt"
                )
                inputs = {k: v.to(device) for k, v in inputs.items()}
                outputs = self.bert_model(**inputs)
                # Index 0 is the last hidden state for a Hugging Face ModelOutput
                # and for models that return a plain tuple (e.g. DNABERT-2's
                # remote code), where `.last_hidden_state` does not exist.
                hidden = outputs[0].cpu()

                self.embeddings.append(hidden.squeeze(0))
                self.sequence_encodings.append(self._encode_sequence(seq))

    def _encode_sequence(self, seq):
        """Map the first ``max_length`` bases to indices, right-padded with PAD.

        Each base is upper-cased individually, so soft-masked (lower-case)
        input is not collapsed to 'N' and the length is preserved.
        """
        unknown = self.nuc_to_idx['N']
        indices = [self.nuc_to_idx.get(base.upper(), unknown) for base in seq[:self.max_length]]
        # Pad with the dedicated PAD index (previously padding mapped to 'N').
        indices += [self.nuc_to_idx['PAD']] * (self.max_length - len(indices))
        return torch.tensor(indices, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            'embeddings': self.embeddings[idx],
            'sequence': self.sequence_encodings[idx],
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }

In [None]:
# Cell 3: Model Architecture
class DNABertMambaWithSequenceClassifier(nn.Module):
    """Fuse BERT token embeddings with learned nucleotide embeddings, run a
    Mamba block over the result, mean-pool over positions, and classify.

    Args:
        d_model: width of the incoming BERT embeddings and of the Mamba block.
        n_classes: number of output classes.
        d_state, d_conv, expand: forwarded unchanged to ``mamba_ssm.Mamba``.
    """

    def __init__(self, d_model=768, n_classes=3, d_state=16, d_conv=4, expand=2):
        super().__init__()

        # 6 nucleotide tokens (A, C, G, T, N, PAD) -> 32-dim learned embedding.
        self.nuc_embedding = nn.Embedding(6, 32)

        # Project the concatenated [BERT | nucleotide] features back to d_model.
        self.combine_layer = nn.Sequential(
            nn.Linear(d_model + 32, d_model),
            nn.LayerNorm(d_model),
            nn.ReLU()
        )

        # use_fast_path=True selects mamba_ssm's fused-kernel path.
        # NOTE(review): an earlier comment tied this flag to PyTorch 1.12.0
        # compatibility; that rationale is unverified.
        self.mamba = Mamba(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            use_fast_path=True
        )

        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Dropout(0.1),
            nn.Linear(d_model, n_classes)
        )

        # Re-initialise only the layers defined here. Mamba initialises its own
        # projections deliberately (notably dt_proj's bias); applying the
        # generic init to it would zero that bias and overwrite its weights.
        for module in (self.nuc_embedding, self.combine_layer, self.classifier):
            module.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) weights and zero biases for Linear/Embedding; identity LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, embeddings, sequence, labels=None, attention_mask=None):
        """Classify a batch.

        Args:
            embeddings: float tensor whose last dim is ``d_model``; presumably
                (batch, seq_len, d_model) BERT hidden states.
            sequence: long tensor of nucleotide indices in [0, 6), same leading
                dims as ``embeddings`` minus the last.
            labels: optional class-index tensor; when given, cross-entropy loss is computed.
            attention_mask: optional (batch, seq_len) tensor, nonzero at positions
                to include in pooling. If omitted, all positions are averaged.

        Returns:
            ``(loss, logits)``; ``loss`` is None when ``labels`` is None.
        """
        seq_embeddings = self.nuc_embedding(sequence)

        combined = torch.cat([embeddings, seq_embeddings], dim=-1)
        x = self.combine_layer(combined)

        x = self.mamba(x)

        if attention_mask is None:
            pooled = x.mean(dim=1)
        else:
            weights = attention_mask.to(x.dtype).unsqueeze(-1)
            # clamp guards against division by zero for an all-padding row.
            pooled = (x * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)

        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return loss, logits

In [None]:
# Cell 4: Data Preparation and Model Initialization
def _stratified_sample(df, total, random_state=None):
    """Draw up to ``total`` rows, split evenly across the classes in ``df['kind']``.

    Each class contributes at most ``total // n_classes`` rows (fewer if the
    class is smaller). Shuffling once and taking the head of each group samples
    each group without replacement, without needing ``groupby.apply``.
    """
    n_classes = df['kind'].nunique()
    if n_classes == 0:
        return df.reset_index(drop=True)
    per_class = total // n_classes
    return (
        df.sample(frac=1, random_state=random_state)
        .groupby('kind', sort=False)
        .head(per_class)
        .reset_index(drop=True)
    )


def prepare_model_and_data(train_df, test_df, sample_size=5000, random_state=42,
                           batch_size=32, num_workers=2):
    """Subsample the splits, embed them with frozen DNABERT-2, and build the model.

    Args:
        train_df: training rows with ``sequence`` and integer ``kind`` columns.
        test_df: evaluation rows with the same columns. The loader built from it
            is what the training loop validates on, so pass the validation split.
        sample_size: cap on training rows; the evaluation set is capped at
            ``sample_size // 5``.
        random_state: seed for subsampling (None leaves it unseeded).
        batch_size, num_workers: DataLoader settings.

    Returns:
        ``(model, train_dataloader, test_dataloader, device)``.
    """
    train_sample_size = min(sample_size, len(train_df))
    test_sample_size = min(sample_size // 5, len(test_df))

    train_df_sample = _stratified_sample(train_df, train_sample_size, random_state)
    test_df_sample = _stratified_sample(test_df, test_sample_size, random_state)

    model_name = "zhihan1996/DNABERT-2-117M"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    bert_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # BERT is a frozen feature extractor; only the downstream classifier trains.
    for param in bert_model.parameters():
        param.requires_grad = False
    bert_model = bert_model.to(device)

    def _make_dataset(frame):
        # Embeds every row of `frame` once, up front.
        return DNASequenceAndEmbeddingDataset(
            sequences=frame['sequence'].tolist(),
            labels=frame['kind'].tolist(),
            tokenizer=tokenizer,
            bert_model=bert_model,
            device=device
        )

    train_dataset = _make_dataset(train_df_sample)
    test_dataset = _make_dataset(test_df_sample)

    # Pinned host memory only helps host->GPU copies; it has no benefit on CPU.
    pin_memory = device.type == 'cuda'
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    model = DNABertMambaWithSequenceClassifier(
        d_model=768,  # DNABERT-2 hidden size
        n_classes=3,
        d_state=16,
        d_conv=4,
        expand=2
    )
    model = model.to(device)

    return model, train_dataloader, test_dataloader, device

In [None]:
# Cell 5: Training Function
def _run_epoch(model, dataloader, device, optimizer=None, desc=''):
    """Run one pass over ``dataloader``; trains when ``optimizer`` is given.

    Returns ``(mean per-sample loss, predictions, labels)``.
    Raises ValueError if the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, n_seen = 0.0, 0
    preds, targets = [], []

    progress = tqdm(dataloader, desc=desc)
    with torch.set_grad_enabled(training):
        for batch in progress:
            embeddings = batch['embeddings'].to(device)
            sequence = batch['sequence'].to(device)
            labels = batch['labels'].to(device)

            if training:
                optimizer.zero_grad()
            loss, logits = model(embeddings, sequence, labels)
            if training:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                progress.set_postfix({'loss': f'{loss.item():.4f}'})

            # Weight by batch size so a short final batch doesn't skew the mean.
            batch_n = labels.size(0)
            total_loss += loss.item() * batch_n
            n_seen += batch_n
            preds.extend(torch.argmax(logits, dim=-1).cpu().numpy())
            targets.extend(labels.cpu().numpy())

    if n_seen == 0:
        raise ValueError(f"{desc}: dataloader yielded no samples")
    return total_loss / n_seen, preds, targets


def train_model(model, train_dataloader, val_dataloader, device,
                num_epochs=10,
                patience=3,
                learning_rate=1e-4,
                checkpoint_path='best_mamba_model.pt',
                restore_best=True):
    """Train with AdamW, halve the LR on plateau, and early-stop on validation loss.

    Batches must be dicts with ``embeddings``, ``sequence`` and ``labels``; the
    model must return ``(loss, logits)`` when given labels.

    Args:
        num_epochs: maximum number of epochs.
        patience: epochs without validation-loss improvement before stopping.
        learning_rate: initial AdamW learning rate.
        checkpoint_path: where the best checkpoint (model + optimizer state) is saved.
        restore_best: if True, load the best-epoch weights back into ``model``
            before returning, instead of leaving the last-epoch weights.

    Returns:
        DataFrame with one row per completed epoch (epoch, train/val loss and accuracy).
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    # `verbose=` is deprecated on PyTorch schedulers; LR drops are printed below instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2
    )

    best_val_loss = float('inf')
    best_state = None
    patience_counter = 0
    training_history = []

    for epoch in range(num_epochs):
        avg_train_loss, train_preds, train_labels = _run_epoch(
            model, train_dataloader, device, optimizer,
            desc=f'Epoch {epoch+1}/{num_epochs} - Training'
        )
        avg_val_loss, val_preds, val_labels = _run_epoch(
            model, val_dataloader, device, desc='Validation'
        )

        train_acc = accuracy_score(train_labels, train_preds)
        val_acc = accuracy_score(val_labels, val_preds)

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(avg_val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f'Reducing learning rate: {lr_before:.2e} -> {lr_after:.2e}')

        print(f'\nEpoch {epoch+1}:')
        print(f'Training Loss: {avg_train_loss:.4f} | Training Accuracy: {train_acc:.4f}')
        print(f'Validation Loss: {avg_val_loss:.4f} | Validation Accuracy: {val_acc:.4f}')
        print('\nValidation Classification Report:')
        print(classification_report(val_labels, val_preds))

        training_history.append({
            'epoch': epoch+1,
            'train_loss': avg_train_loss,
            'train_acc': train_acc,
            'val_loss': avg_val_loss,
            'val_acc': val_acc
        })

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            # Keep a CPU copy so the best weights can be restored without re-reading the file.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_val_loss,
            }, checkpoint_path)
            print("Saved new best model!")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered after epoch {epoch+1}")
                break

    if restore_best and best_state is not None:
        model.load_state_dict(best_state)

    return pd.DataFrame(training_history)

In [None]:
# Cell 6: Run the Pipeline
# Seed torch so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(42)

# Early stopping and LR scheduling select on this split, so it must be the
# validation split. test_df stays untouched for a final, unbiased evaluation.
model, train_dataloader, val_dataloader, device = prepare_model_and_data(
    train_df=train_df,
    test_df=val_df,
    sample_size=5000
)

history_df = train_model(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    device=device,
    num_epochs=10,
    patience=3,
    learning_rate=1e-4
)

In [None]:
# Plot training history against the 1-based epoch numbers stored in history_df
# (plotting against the row index would start the "Epoch" axis at 0).
import matplotlib.pyplot as plt

epochs = history_df['epoch']
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

loss_ax.plot(epochs, history_df['train_loss'], label='Train Loss')
loss_ax.plot(epochs, history_df['val_loss'], label='Val Loss')
loss_ax.set(title='Loss Over Time', xlabel='Epoch', ylabel='Loss')
loss_ax.legend()

acc_ax.plot(epochs, history_df['train_acc'], label='Train Acc')
acc_ax.plot(epochs, history_df['val_acc'], label='Val Acc')
acc_ax.set(title='Accuracy Over Time', xlabel='Epoch', ylabel='Accuracy')
acc_ax.legend()

fig.tight_layout()
plt.show()