In [None]:
#run in terminal
# !pip uninstall -y triton
# !pip uninstall -y torch torchvision torchaudio
# !pip uninstall -y transformers
# !pip cache purge
# !pip install torch==2.2.0
# !pip install transformers

Found existing installation: triton 2.2.0
Uninstalling triton-2.2.0:
  Successfully uninstalled triton-2.2.0
Found existing installation: torch 2.2.0
Uninstalling torch-2.2.0:
  Successfully uninstalled torch-2.2.0
[0mFound existing installation: transformers 4.51.3
Uninstalling transformers-4.51.3:
  Successfully uninstalled transformers-4.51.3
Files removed: 84
Collecting torch==2.2.0
  Downloading torch-2.2.0-cp312-cp312-manylinux1_x86_64.whl.metadata (25 kB)
Collecting triton==2.2.0 (from torch==2.2.0)
  Downloading triton-2.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading torch-2.2.0-cp312-cp312-manylinux1_x86_64.whl (755.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m755.4/755.4 MB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading triton-2.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (167.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m167.9/167.9 

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
from datasets import Dataset
import torch

## Load Splice Dataset

In [2]:
splice_df = pd.read_csv('splice_dataset_final.csv')
splice_df.head()

Unnamed: 0,id,coord,kind,transcript,strand,chrom,sequence,motif
0,NC_050096.1_35318_donor,35318,0,ID=exon-XM_020544715.3-1,+,NC_050096.1,GGGCCCGGCTGGGCCTCAGCGGGGTCGTCGAGATGGAGATGGGGAG...,GT
1,NC_050096.1_36174_donor,36174,0,ID=exon-XM_020544715.3-2,+,NC_050096.1,ATAATATGTTCATTATATCACAACACTCTTTTCTTATGGAGTCGTG...,GT
2,NC_050096.1_36504_donor,36504,0,ID=exon-XM_020544715.3-3,+,NC_050096.1,TGTCATTTCCTTACCTCATTGAATCATTTCCGATGCTTCTTCTCTG...,GT
3,NC_050096.1_36713_donor,36713,0,ID=exon-XM_020544715.3-4,+,NC_050096.1,TTCATGGTAGTCATTGGAACCTGCTAGATTGTACACTTGACAATAA...,GT
4,NC_050096.1_37004_donor,37004,0,ID=exon-XM_020544715.3-5,+,NC_050096.1,CAACTTTTTCCTTTCAGATTTCCAGTACAGTCCTCGCTATTGCTGT...,GT


In [3]:
splice_df["kind"].value_counts()

kind
2    1018230
0     509115
1     509115
Name: count, dtype: int64

In [28]:
#@title K-mer Tokenization

def seq2kmer(seq, k=6):
    """Convert DNA sequence to k-mer tokens"""
    return " ".join([seq[i:i+k] for i in range(len(seq)-k+1)])

# Assuming your DataFrame is loaded as df
splice_df['kmer_6'] = splice_df['sequence'].apply(lambda x: seq2kmer(x, k=3))

# Split data with stratification
train_df, temp_df = train_test_split(splice_df, test_size=0.3, stratify=splice_df['kind'], random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, stratify=temp_df['kind'], random_state=42)

## Full Pipeline

In [29]:
#@title Imports and Setup
# Cell 1: Imports and Setup
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    BertConfig
)
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import classification_report, accuracy_score
from tqdm import tqdm
import pandas as pd

In [30]:
#@title Model Class
# Cell 2: Model Class (Simplified)
class DNABERTClassifier(torch.nn.Module):
    def __init__(self, model_name, num_labels, class_weights):
        super().__init__()
        self.config = BertConfig.from_pretrained(
            model_name,
            num_labels=num_labels
        )
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            config=self.config
        )
        self.loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        
        if labels is not None:
            loss = self.loss_fn(outputs.logits, labels)
            return loss, outputs.logits
        return None, outputs.logits

For Dataset Class, Data Preparation, and Training Function parts: Use the FULL DATASET version for full data. For most testing, use the SMALL DATASET version for faster runs (uses sample of data and another optimizations).

In [7]:
#@title Dataset Class [FULL DATASET]
# [OLD VERSION] - FULL DATASET
# Cell 3: Dataset Class
class SpliceSiteDataset(Dataset):
    def __init__(self, sequences, labels, tokenizer):
        self.sequences = sequences
        self.labels = labels
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.sequences[idx],
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )
        
        item = {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(self.labels[idx])
        }
        return item

    def __len__(self):
        return len(self.labels)

In [31]:
#@title Dataset Class [SMALL DATASET]
# Cell 3: Dataset Class (Optimized)
class SpliceSiteDataset(Dataset):
    def __init__(self, sequences, labels, tokenizer, max_length=256):  # Reduced from 512
        self.sequences = sequences
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.sequences[idx],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        
        item = {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(self.labels[idx])
        }
        return item

    def __len__(self):
        return len(self.labels)

In [32]:
#@title Model Initialization
# Cell 4: Model Initialization (Modified)
# Model name
model_name = "zhihan1996/DNABERT-2-117M"

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Calculate class weights
class_counts = train_df['kind'].value_counts().sort_index()
class_weights = torch.tensor(
    [class_counts.sum() / c for c in class_counts],
    dtype=torch.float32
)

# Initialize model with specific configuration to avoid flash attention
config = BertConfig.from_pretrained(
    model_name,
    num_labels=3,
    attention_probs_dropout_prob=0.1,
    hidden_dropout_prob=0.1,
    use_cache=False
)

# Initialize model
model = DNABERTClassifier(
    model_name=model_name,
    num_labels=3,
    class_weights=class_weights
)

# Move model to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
model = model.to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at zhihan1996/DNABERT-2-117M and are newly initialized: ['bert.embeddings.position_embeddings.weight', 'bert.encoder.layer.0.attention.self.key.bias', 'bert.encoder.layer.0.attention.self.key.weight', 'bert.encoder.layer.0.attention.self.query.bias', 'bert.encoder.layer.0.attention.self.query.weight', 'bert.encoder.layer.0.attention.self.value.bias', 'bert.encoder.layer.0.attention.self.value.weight', 'bert.encoder.layer.0.intermediate.dense.bias', 'bert.encoder.layer.0.intermediate.dense.weight', 'bert.encoder.layer.0.output.LayerNorm.bias', 'bert.encoder.layer.0.output.LayerNorm.weight', 'bert.encoder.layer.0.output.dense.bias', 'bert.encoder.layer.0.output.dense.weight', 'bert.encoder.layer.1.attention.self.key.bias', 'bert.encoder.layer.1.attention.self.key.weight', 'bert.encoder.layer.1.attention.self.query.bias', 'bert.encoder.layer.1.attention.self.query.weight', 'bert.encoder.layer.1.at

Using device: cuda


In [9]:
#@title Data Preparation [FULL DATASET]
# Cell 5: Data Preparation
# Create datasets
train_dataset = SpliceSiteDataset(
    sequences=train_df['sequence'].tolist(),
    labels=train_df['kind'].tolist(),
    tokenizer=tokenizer
)

test_dataset = SpliceSiteDataset(
    sequences=test_df['sequence'].tolist(),
    labels=test_df['kind'].tolist(),
    tokenizer=tokenizer
)

# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=16)

In [33]:
#@title Data Preparation [SMALL DATASET]

# Key optimizations made while maintaining stability:
# Reduced sequence length from 512 to 128 tokens
# Added data sampling (5000 training samples, 1000 test samples)
# Increased batch size from 16 to 32
# Added parallel data loading with num_workers=2
# Added early stopping with patience=2
# Added progress bars with loss updates
# Added validation accuracy tracking
# Improved progress monitoring and metrics display

# Cell 5: Data Preparation (Optimized with sampling)
# Sample the data for faster training
train_sample_size = min(50000, len(train_df))  # Adjust this number as needed
test_sample_size = min(10000, len(test_df))    # Adjust this number as needed

train_df_sample = train_df.sample(n=train_sample_size, random_state=42)
test_df_sample = test_df.sample(n=test_sample_size, random_state=42)

# Create datasets with reduced sequence length
train_dataset = SpliceSiteDataset(
    sequences=train_df_sample['sequence'].tolist(),
    labels=train_df_sample['kind'].tolist(),
    tokenizer=tokenizer
)

test_dataset = SpliceSiteDataset(
    sequences=test_df_sample['sequence'].tolist(),
    labels=test_df_sample['kind'].tolist(),
    tokenizer=tokenizer
)

# Create dataloaders with optimal batch size
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=32,  # Increased from 16
    shuffle=True,
    num_workers=2   # Added parallel data loading
)

test_dataloader = DataLoader(
    test_dataset, 
    batch_size=32,  # Increased from 16
    shuffle=False,
    num_workers=2   # Added parallel data loading
)

In [10]:
#@title Training Function [FULL DATASET]
# Cell 6: Training Function
def train_model(model, train_dataloader, val_dataloader, device, num_epochs=3):
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        total_train_loss = 0
        
        for batch in tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            batch = {k: v.to(device) for k, v in batch.items()}
            
            optimizer.zero_grad()
            loss, _ = model(**batch)
            
            loss.backward()
            optimizer.step()
            
            total_train_loss += loss.item()

        # Validation phase
        model.eval()
        val_preds = []
        val_labels = []
        total_val_loss = 0
        
        with torch.no_grad():
            for batch in val_dataloader:
                batch = {k: v.to(device) for k, v in batch.items()}
                loss, logits = model(**batch)
                
                total_val_loss += loss.item()
                predictions = torch.argmax(logits, dim=-1)
                val_preds.extend(predictions.cpu().numpy())
                val_labels.extend(batch['labels'].cpu().numpy())
        
        # Print metrics
        print(f'\nEpoch {epoch+1}:')
        print(f'Average training loss: {total_train_loss/len(train_dataloader):.4f}')
        print(f'Average validation loss: {total_val_loss/len(val_dataloader):.4f}')
        print('\nValidation Classification Report:')
        print(classification_report(val_labels, val_preds))
        print("-" * 60)

        # Save best model
        if total_val_loss < best_val_loss:
            best_val_loss = total_val_loss
            torch.save(model.state_dict(), 'best_splice_site_model.pt')
            print("Saved new best model!")

In [36]:
#@title Training Function [SMALL DATASET]
# MODIFIED VERSION - FASTER, SMALLER TRAIN
# Cell 6: Training Function (Optimized)
def train_model(model, train_dataloader, val_dataloader, device, 
                num_epochs=5,           # Increased from 3
                patience=2,             # Added early stopping
                learning_rate=2e-5):
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    best_val_loss = float('inf')
    patience_counter = 0
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        total_train_loss = 0
        
        progress_bar = tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{num_epochs}')
        for batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}
            
            optimizer.zero_grad()
            loss, _ = model(**batch)
            
            loss.backward()
            optimizer.step()
            
            total_train_loss += loss.item()
            progress_bar.set_postfix({'train_loss': f'{loss.item():.4f}'})

        # Validation phase
        model.eval()
        val_preds = []
        val_labels = []
        total_val_loss = 0
        
        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc='Validation'):
                batch = {k: v.to(device) for k, v in batch.items()}
                loss, logits = model(**batch)
                
                total_val_loss += loss.item()
                predictions = torch.argmax(logits, dim=-1)
                val_preds.extend(predictions.cpu().numpy())
                val_labels.extend(batch['labels'].cpu().numpy())
        
        # Calculate average losses and accuracy
        avg_train_loss = total_train_loss / len(train_dataloader)
        avg_val_loss = total_val_loss / len(val_dataloader)
        val_accuracy = accuracy_score(val_labels, val_preds)
        
        # Print metrics
        print(f'\nEpoch {epoch+1}:')
        print(f'Average training loss: {avg_train_loss:.4f}')
        print(f'Average validation loss: {avg_val_loss:.4f}')
        print(f'Validation accuracy: {val_accuracy:.4f}')
        print('\nValidation Classification Report:')
        print(classification_report(val_labels, val_preds))
        print("-" * 60)

        # Save best model and early stopping
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), 'best_splice_site_model.pt')
            print("Saved new best model!")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered after epoch {epoch+1}")
                break

In [37]:
#@title Training
# Cell 7: Training
train_model(model, train_dataloader, test_dataloader, device)

Epoch 1/5:   0%|          | 0/1563 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 1/5: 100%|██████████| 1563/1563 [07:22<00:00,  3.53it/s, train_loss=0.6903]
Validation:   0%|          | 0/313 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKEN


Epoch 1:
Average training loss: 0.8683
Average validation loss: 0.8360
Validation accuracy: 0.5927

Validation Classification Report:
              precision    recall  f1-score   support

           0       0.32      0.09      0.15      2475
           1       0.40      0.82      0.53      2464
           2       0.88      0.73      0.80      5061

    accuracy                           0.59     10000
   macro avg       0.53      0.55      0.49     10000
weighted avg       0.62      0.59      0.57     10000

------------------------------------------------------------
Saved new best model!


Epoch 2/5:   0%|          | 0/1563 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 2/5: 100%|██████████| 1563/1563 [07:23<00:00,  3.53it/s, train_loss=0.8140]
Validation:   0%|          | 0/313 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKEN


Epoch 2:
Average training loss: 0.8260
Average validation loss: 0.8219
Validation accuracy: 0.6337

Validation Classification Report:
              precision    recall  f1-score   support

           0       0.43      0.81      0.56      2475
           1       0.47      0.05      0.09      2464
           2       0.83      0.83      0.83      5061

    accuracy                           0.63     10000
   macro avg       0.58      0.56      0.49     10000
weighted avg       0.64      0.63      0.58     10000

------------------------------------------------------------
Saved new best model!


Epoch 3/5:   0%|          | 0/1563 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 3/5: 100%|██████████| 1563/1563 [07:24<00:00,  3.52it/s, train_loss=0.9297]
Validation:   0%|          | 0/313 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKEN


Epoch 3:
Average training loss: 0.7838
Average validation loss: 0.7504
Validation accuracy: 0.6722

Validation Classification Report:
              precision    recall  f1-score   support

           0       0.46      0.74      0.57      2475
           1       0.59      0.50      0.54      2464
           2       0.92      0.72      0.81      5061

    accuracy                           0.67     10000
   macro avg       0.66      0.65      0.64     10000
weighted avg       0.73      0.67      0.68     10000

------------------------------------------------------------
Saved new best model!


Epoch 4/5:   0%|          | 0/1563 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 4/5: 100%|██████████| 1563/1563 [07:23<00:00,  3.53it/s, train_loss=0.5156]
Validation:   0%|          | 0/313 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKEN


Epoch 4:
Average training loss: 0.5800
Average validation loss: 0.4806
Validation accuracy: 0.8101

Validation Classification Report:
              precision    recall  f1-score   support

           0       0.70      0.84      0.76      2475
           1       0.77      0.82      0.79      2464
           2       0.91      0.79      0.85      5061

    accuracy                           0.81     10000
   macro avg       0.79      0.82      0.80     10000
weighted avg       0.82      0.81      0.81     10000

------------------------------------------------------------
Saved new best model!


Epoch 5/5:   0%|          | 0/1563 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Epoch 5/5: 100%|██████████| 1563/1563 [07:23<00:00,  3.53it/s, train_loss=0.2358]
Validation:   0%|          | 0/313 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKEN


Epoch 5:
Average training loss: 0.4251
Average validation loss: 0.5158
Validation accuracy: 0.7869

Validation Classification Report:
              precision    recall  f1-score   support

           0       0.67      0.87      0.76      2475
           1       0.72      0.87      0.79      2464
           2       0.93      0.71      0.80      5061

    accuracy                           0.79     10000
   macro avg       0.78      0.81      0.78     10000
weighted avg       0.82      0.79      0.79     10000

------------------------------------------------------------





In [38]:
#@title Test Evaluation
# Cell 8: Final Evaluation
# Load best model
model.load_state_dict(torch.load('best_splice_site_model.pt'))
model.eval()

# Evaluate on test set
test_preds = []
test_labels_all = []
with torch.no_grad():
    for batch in test_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        _, logits = model(**batch)
        predictions = torch.argmax(logits, dim=-1)
        test_preds.extend(predictions.cpu().numpy())
        test_labels_all.extend(batch['labels'].cpu().numpy())

print("\nFinal Test Set Classification Report:")
print(classification_report(test_labels_all, test_preds))

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)



Final Test Set Classification Report:
              precision    recall  f1-score   support

           0       0.70      0.84      0.76      2475
           1       0.77      0.82      0.79      2464
           2       0.91      0.79      0.85      5061

    accuracy                           0.81     10000
   macro avg       0.79      0.82      0.80     10000
weighted avg       0.82      0.81      0.81     10000

