# Mutation Type Classification with a Transformer Model

This notebook trains a Transformer-based model to classify mutation types from DNA sequences. 

**Key Features:**
- **Data Balancing:** Implements **SMOTE** (Synthetic Minority Over-sampling Technique) to address class imbalance by generating synthetic data for minority classes.
- **Model:** Utilizes a Transformer Encoder architecture.
- **Training:** Employs features like Focal Loss, Adam optimizer, mixed-precision training (FP16), and a learning rate scheduler with warmup.

## 1. Setup and Configuration

In [None]:
# Model and optimisation settings read by the cells below.
hyperparameters = {
    "learning_rate": 1e-4,  # learning rate handed to the optimizer
    "batch_size": 256,
    "embed_dim": 256,
    "num_heads": 4,
    "num_layers": 2,
    "dropout": 0.1,
    "ff_dim": 1024,  # feed-forward width inside each encoder layer
    "epochs": 300,  # upper bound; the training loop may stop early
    "num_warmup_steps": 1000,  # scheduler warmup; the scheduler is stepped once per batch
    "weight_decay": 1e-4,
    "k-mers": 3,  # k-mer length used when tokenizing sequences
    "max_len": 199  # number of positions in the positional-encoding table
}

# Run metadata: used to build file paths and to select code paths below.
info = {
    "dataset_size": "data",  # CSV file stem under MyDrive/dataset/
    "precision": "FP16",  # checked by validation() to decide whether to use autocast
    "dir_name": "Mutation Model",  # checkpoint folder under MyDrive/
    "run": "Fifteenth Run (SMOTE)",  # embedded in checkpoint file names
    "loss": {
        "type": "Focal Loss",
        "gamma": 2  # focusing parameter passed to FocalLoss
    },
    "optimizer": "Adam",
    "is_pre_training": False  # when True, training resumes from pretrained_model.pth
}

In [None]:
# Mount Google Drive; the dataset is read from, and checkpoints are written to,
# paths under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

## 2. Load and Inspect Data

In [None]:
import pandas as pd

# The CSV file name is derived from info['dataset_size'].
data_path = f"/content/drive/MyDrive/dataset/{info['dataset_size']}.csv"

# The cells below read the 'sequence' and 'label' columns of this frame.
data = pd.read_csv(data_path)
print(f"Dataset shape: {data.shape}")

In [None]:
# Row count per label before any resampling; the trailing expression is the
# cell's displayed output.
print("Original class distribution:")
data['label'].value_counts()

In [None]:
# Separate the raw sequences (model input) from the class labels (target).
x = data['sequence']
y = data['label']

## 3. Pre-processing and Vocabulary Creation

In [None]:
import pickle

def get_codon(seq, k=hyperparameters['k-mers']):
    """Return every overlapping window of length ``k`` found in ``seq``.

    Windows are taken with stride 1, so a sequence of length ``n`` yields
    ``n - k + 1`` pieces, and an empty list when ``n < k``. The default ``k``
    is read from ``hyperparameters`` once, when this function is defined.
    """
    window_count = len(seq) - k + 1
    pieces = []
    for start in range(window_count):
        pieces.append(seq[start:start + k])
    return pieces

# Map each distinct lower-cased k-mer to a consecutive integer id, numbered in
# order of first appearance across the dataset.
vocab = {}

for sequence in data['sequence']:
    for kmer in get_codon(sequence.lower()):
        # len(vocab) is evaluated before insertion, so a new k-mer receives
        # the next free id and a known k-mer is left untouched.
        vocab.setdefault(kmer, len(vocab))

# Reserve one more id for the padding token.
vocab.setdefault('<pad>', len(vocab))

# Persist the mapping so the same ids can be reused outside this notebook.
with open('vocab.pkl', 'wb') as vocab_file:
    pickle.dump(vocab, vocab_file)

def get_tensor(text):
    """Convert a sequence into the list of vocabulary ids of its k-mers.

    Each k-mer is lower-cased before lookup. A k-mer that is not in ``vocab``
    raises ``KeyError``.
    """
    token_ids = []
    for kmer in get_codon(text):
        token_ids.append(vocab[kmer.lower()])
    return token_ids

print(f"Vocabulary size: {len(vocab)}")

## 4. Data Balancing with SMOTE

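To handle the imbalanced dataset, we will use the SMOTE algorithm, which is provided by the `imbalanced-learn` library. SMOTE operates on fixed-length numerical feature vectors, so we must first tokenize every sequence and pad all of them to a uniform length before applying it. Keep in mind that SMOTE creates synthetic samples by interpolating between neighbouring feature vectors; because our features are categorical token IDs, an interpolated value does not necessarily correspond to a meaningful k-mer.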

In [None]:
!pip install imbalanced-learn -q

In [None]:
import numpy as np
from imblearn.over_sampling import SMOTE
from torch.nn.utils.rnn import pad_sequence
import torch

from sklearn.neighbors import NearestNeighbors

# NOTE(review): SMOTE is fitted on the full dataset in this cell and the
# train/test split happens only afterwards, so synthetic samples (and the real
# samples they were interpolated from) can end up on both sides of the split.
# Consider splitting first and oversampling the training portion only.
# NOTE(review): SMOTE interpolates between feature vectors, but the features
# here are vocabulary ids, which carry no numeric ordering. Confirm that the
# interpolated ids are acceptable, or consider an oversampler designed for
# categorical features.

# 1. Tokenize all sequences and pad them to a uniform length
print("Tokenizing all sequences...")
# An explicit integer dtype keeps a sequence that yields no k-mers from
# becoming a float tensor.
tokenized_sequences = [torch.tensor(get_tensor(seq), dtype=torch.long) for seq in x]

# Pad sequences. The padding value should correspond to your '<pad>' token.
X_padded = pad_sequence(tokenized_sequences, batch_first=True, padding_value=vocab['<pad>'])
print(f"Shape of padded data before SMOTE: {X_padded.shape}")

# 2. Apply SMOTE
# We will oversample the minority classes (3 and 4) to have 150,000 samples each.
sampling_strategy = {
    3: 150000,
    4: 150000
}

# The `n_jobs` argument of SMOTE is deprecated since imbalanced-learn 0.10 and
# absent from later releases. Parallelism is requested through the
# nearest-neighbours estimator instead. n_neighbors is 6 rather than 5 because
# each sample is returned as its own nearest neighbour and SMOTE discards that
# first hit, which matches the default of 5 real neighbours.
smote = SMOTE(
    sampling_strategy=sampling_strategy,
    random_state=42,
    k_neighbors=NearestNeighbors(n_neighbors=6, n_jobs=-1),
)

print("\nApplying SMOTE... This can take a few minutes.")
# SMOTE needs a 2D array. The padded batch is already (num_sequences, length),
# so it only has to be handed over as NumPy data.
X_reshaped = X_padded.numpy()
y_array = np.asarray(y)
X_resampled_flat, y_resampled = smote.fit_resample(X_reshaped, y_array)

# Reshape X back to its original sequence format
X_resampled = X_resampled_flat.reshape(X_resampled_flat.shape[0], X_padded.shape[1])
print(f"Shape of data after SMOTE: {X_resampled.shape}")

# 3. Check the new class distribution
print("\nNew class distribution after SMOTE:")
print(pd.Series(y_resampled).value_counts())

## 5. Create PyTorch Dataset and DataLoaders

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, random_split

# We use a new Dataset class that accepts pre-tokenized data
class PreTokenizedDataset(Dataset):
    """Dataset over already-tokenized sequences and their class labels.

    Args:
        x: Token ids, one row per sample (list, NumPy array, pandas object or
            tensor).
        y: Class labels, one per sample (same accepted container types).

    Raises:
        ValueError: If ``x`` and ``y`` do not contain the same number of
            samples.

    Both inputs are stored as ``torch.long`` tensors in ``x_frame`` and
    ``y_frame``.
    """

    def __init__(self, x, y):
        features = self._to_long_tensor(x)
        labels = self._to_long_tensor(y)
        if len(features) != len(labels):
            raise ValueError(
                f"x and y must have the same number of samples, "
                f"got {len(features)} and {len(labels)}"
            )
        self.x_frame = features
        self.y_frame = labels

    @staticmethod
    def _to_long_tensor(values):
        """Copy ``values`` into a new ``torch.long`` tensor."""
        if isinstance(values, torch.Tensor):
            return values.detach().clone().to(dtype=torch.long)
        # Going through NumPy avoids torch.tensor() walking a pandas Series
        # element by element, which is slow and fails when the Series index
        # does not start at 0.
        return torch.tensor(np.asarray(values), dtype=torch.long)

    def __len__(self):
        return len(self.x_frame)

    def __getitem__(self, index):
        return self.x_frame[index], self.y_frame[index]

In [None]:
# Use the resampled data from SMOTE to create the dataset
dataset = PreTokenizedDataset(X_resampled, y_resampled)

# NOTE(review): the split is drawn from the resampled data, so the held-out
# part contains synthetic samples as well as real ones. Metrics computed on it
# may not reflect performance on real sequences.

# Split the balanced dataset into training and testing sets
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# A fixed seed makes the split reproducible. Without it, re-running the
# notebook (for example to resume from a checkpoint) draws a different split,
# and samples held out earlier can end up in the training set.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=split_generator
)
print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

In [None]:
# Both loaders share the batch size; only the training loader shuffles.
loader_kwargs = {"batch_size": hyperparameters['batch_size']}

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

test_loader = DataLoader(test_dataset, **loader_kwargs)

## 6. Model Definition

In [None]:
import math

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    Args:
        embed_dim: Size of the last dimension of the embeddings. Odd values
            are supported.
        max_len: Number of positions for which encodings are precomputed.

    The table is stored in the buffer ``pe`` with shape
    ``(1, max_len, embed_dim)``.
    """

    def __init__(self, embed_dim, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2, dtype=torch.float)
            * (-math.log(10000.0) / embed_dim)
        )
        angles = position * div_term

        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(angles)
        # For an odd embed_dim there is one fewer cosine column than sine
        # column, so the angle table is trimmed to fit.
        pe[:, 1::2] = torch.cos(angles[:, : embed_dim // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Raises:
            ValueError: If ``x`` has more positions than were precomputed.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"Sequence length {seq_len} exceeds the positional-encoding "
                f"table size {self.pe.size(1)}; increase max_len."
            )
        # `pe` is a registered buffer, so it already follows the module across
        # devices and needs no per-call transfer.
        return x + self.pe[:, :seq_len, :]

class Transformer(nn.Module):
    """Transformer-encoder sequence classifier.

    Token ids are embedded, combined with positional encodings, passed through
    a stack of encoder layers, averaged over the sequence dimension and
    projected to class logits.

    Args:
        embed_dim: Embedding / model width.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of encoder layers.
        ff_dim: Hidden width of each layer's feed-forward sub-block.
        dropout: Dropout probability inside the encoder layers.
        vocab_size: Number of rows in the embedding table.
        max_len: Number of positions covered by the positional encoding.
        num_classes: Number of output logits (defaults to 5, the value that
            was previously fixed in the code).
    """

    def __init__(self, embed_dim=512, num_heads=8, num_layers=6, ff_dim=2048, dropout=0.1, vocab_size=10000, max_len=5000, num_classes=5):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_dim)
        self.position_encoding = PositionalEncoding(embed_dim=embed_dim, max_len=max_len)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_layers
        )

        self.y_labels_out = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for token ids ``x``."""
        x = self.embeddings(x)
        x = self.position_encoding(x)
        # NOTE(review): no padding mask is passed to the encoder and the mean
        # below runs over every position, so padded positions (if any exist in
        # the input) take part in both attention and pooling. Confirm that all
        # sequences have the same length.
        x = self.encoder(x)
        x = x.mean(dim=1)
        y_label_out = self.y_labels_out(x)
        return y_label_out

In [None]:
# Collect the architecture settings from the notebook configuration; the
# vocabulary size comes from the vocabulary built earlier.
model_config = {
    name: hyperparameters[name]
    for name in ("embed_dim", "num_heads", "num_layers", "ff_dim", "dropout", "max_len")
}
model = Transformer(vocab_size=len(vocab), **model_config)

In [None]:
# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
print(f"Model moved to device: {device}")

In [None]:
# Count only parameters that will receive gradient updates.
num_params = 0
for param in model.parameters():
    if param.requires_grad:
        num_params += param.numel()
print(f"Total trainable parameters: {num_params}")

## 7. Loss Function, Optimizer, and Scheduler

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss: cross-entropy scaled by ``(1 - p_t) ** gamma``.

    ``p_t`` is the predicted probability of the target class, so well
    classified samples contribute less to the loss as ``gamma`` grows.

    Args:
        gamma: Focusing exponent; ``0`` reduces the loss to cross-entropy.
        weight: Optional per-class weights (tensor or sequence of floats),
            applied to the cross-entropy term.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            per-sample losses unreduced.
    """

    def __init__(self, gamma=2, weight=None, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        if weight is not None:
            weight = torch.as_tensor(weight, dtype=torch.float)
        # Registering the weights as a buffer lets them follow the module
        # through .to()/.cuda(); non-persistent keeps state_dict() unchanged.
        self.register_buffer('weight', weight, persistent=False)
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the loss for logits ``inputs`` (N, C) and class ids ``targets`` (N,)."""
        # A single log-softmax supplies both the cross-entropy term and p_t.
        log_probs = nn.functional.log_softmax(inputs, dim=-1)
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)

        weight = self.weight
        if weight is not None:
            weight = weight.to(device=log_probs.device, dtype=log_probs.dtype)
        ce_loss = nn.functional.nll_loss(log_probs, targets, weight=weight, reduction='none')

        # 1 - p_t written as -expm1(log p_t), which stays accurate when p_t is
        # close to 1; the clamp guards against a tiny negative rounding error.
        one_minus_pt = (-torch.expm1(log_pt)).clamp(min=0)
        focal_loss = one_minus_pt ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

In [None]:
from transformers import get_linear_schedule_with_warmup

# Since we balanced the dataset with SMOTE, we no longer need to pass class weights to the loss function.
ce = FocalLoss(gamma=info['loss']['gamma'])
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparameters['learning_rate'], weight_decay=hyperparameters["weight_decay"])
scaler = torch.cuda.amp.GradScaler()

start_epoch = 1
if info['is_pre_training']:
    # map_location lets a checkpoint written on a GPU runtime be loaded on
    # whatever device this session is using.
    checkpoint = torch.load(
        f"/content/drive/MyDrive/{info['dir_name']}/pretrained_model.pth",
        map_location=device,
    )
    start_epoch = checkpoint['epoch'] + 1
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print(f"Resuming training from epoch {start_epoch}")

num_training_steps = len(train_loader) * hyperparameters['epochs']
num_warmup_steps = hyperparameters["num_warmup_steps"]

# When resuming, position the schedule at the step where the previous run
# stopped instead of repeating the warmup from step 0. The scheduler is
# stepped once per batch, so completed steps = completed epochs * batches.
completed_steps = (start_epoch - 1) * len(train_loader)
if completed_steps > 0:
    # A scheduler created with last_epoch != -1 reads the base learning rate
    # from 'initial_lr'; supply it if the loaded optimizer state lacks it.
    for group in optimizer.param_groups:
        group.setdefault('initial_lr', hyperparameters['learning_rate'])

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
    last_epoch=completed_steps - 1
)

## 8. Training and Validation Loops

In [None]:
def train16(model, loader, ce, optimizer, scaler, scheduler):
    """Run one training epoch with FP16 autocast and gradient scaling.

    Args:
        model: Network to train; it is switched to training mode.
        loader: Iterable of ``(inputs, targets)`` batches.
        ce: Loss callable taking ``(logits, targets)``.
        optimizer: Optimizer updating ``model``'s parameters.
        scaler: Gradient scaler used for the backward pass and optimizer step.
        scheduler: Learning-rate scheduler, advanced once per applied
            optimizer step.

    Returns:
        Tuple ``(mean loss per sample, accuracy)`` over the samples seen.
    """
    model.train()
    # Batches are moved to wherever the model's parameters live, so the
    # function does not depend on a module-level device variable.
    device = next(model.parameters()).device
    running_loss, correct, total = 0.0, 0, 0

    for x, y in loader:
        optimizer.zero_grad()
        x = x.to(device)
        y = y.to(device)

        with torch.cuda.amp.autocast(dtype=torch.float16):
            output = model(x)
            loss = ce(output, y)

        scale_before = scaler.get_scale()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # The scaler lowers its scale exactly when it skipped the optimizer
        # step because of inf/NaN gradients; the schedule should not advance
        # for a step that was never applied.
        if scaler.get_scale() >= scale_before:
            scheduler.step()

        batch_size = len(x)
        running_loss += loss.item() * batch_size
        correct += (torch.argmax(output, dim=1) == y).sum().item()
        total += batch_size

    # Loss and accuracy are both normalised by the samples actually processed.
    return (running_loss / total, correct / total)

In [None]:
def validation(model, loader, ce):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    The forward pass runs under FP16 autocast when ``info['precision']`` is
    ``'FP16'`` and in default precision otherwise.

    Returns:
        Tuple ``(summed loss / len(loader.dataset), accuracy over seen samples)``.
    """
    model.eval()
    loss_sum = 0.0
    hits = 0
    seen = 0

    def _forward(batch_x, batch_y):
        # One forward pass plus loss, shared by both precision branches.
        logits = model(batch_x)
        return logits, ce(logits, batch_y)

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)

            if info['precision'] != 'FP16':
                output, loss = _forward(x, y)
            else:
                with torch.cuda.amp.autocast(dtype=torch.float16):
                    output, loss = _forward(x, y)

            batch_len = len(x)
            loss_sum += loss.item() * batch_len
            hits += (torch.argmax(output, dim=1) == y).sum().item()
            seen += batch_len

    accuracy = hits / seen
    return (loss_sum / len(loader.dataset), accuracy)

## 9. Main Training Execution

In [None]:
import os
from tqdm.auto import tqdm

patience = 10
best_val_loss = float('inf')
best_epoch = None  # epoch whose weights are stored at best_model_path
counter = 0

train_loss_history = []
train_acc_history = []
val_loss_history = []
val_acc_history = []

save_dir = f"/content/drive/MyDrive/{info['dir_name']}"
os.makedirs(save_dir, exist_ok=True)
best_model_path = f"{save_dir}/model_{info['run']}_best.pth"

epochs = hyperparameters['epochs']

# NOTE(review): early stopping and best-model selection below are driven by
# test_loader, i.e. the same data the final evaluation cells report on.
# Consider holding out a separate validation split.
for epoch in tqdm(range(start_epoch, epochs + 1), desc="Training Epochs"):
    current_train_loss, current_train_acc = train16(
        model, train_loader, ce, optimizer, scaler, scheduler
    )

    current_val_loss, current_val_acc = validation(
        model, test_loader, ce
    )

    train_loss_history.append(current_train_loss)
    train_acc_history.append(current_train_acc)
    val_loss_history.append(current_val_loss)
    val_acc_history.append(current_val_acc)

    print(f"Epoch {epoch}/{epochs}: Train Loss={current_train_loss:.4f}, Train Acc={current_train_acc:.4f} | Val Loss={current_val_loss:.4f}, Val Acc={current_val_acc:.4f}")

    # Periodic full checkpoint (model + optimizer + curves) for resuming.
    if epoch % 10 == 0:
        checkpoint_path = f"{save_dir}/model_{info['run']}_epoch_{epoch}.pth"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_losses': train_loss_history,
            'val_losses': val_loss_history,
            'train_acc': train_acc_history,
            'val_acc': val_acc_history
        }, checkpoint_path)
        print(f"Model checkpoint saved at {checkpoint_path}")

    if current_val_loss < best_val_loss:
        best_val_loss = current_val_loss
        best_epoch = epoch
        counter = 0
        # Keep the best weights on disk: by the time early stopping fires the
        # in-memory model is `patience` epochs past its best validation loss.
        torch.save({
            'epoch': epoch,
            'val_loss': current_val_loss,
            'model_state_dict': model.state_dict()
        }, best_model_path)
    else:
        counter += 1
        print(f"No improvement in validation loss. Counter: {counter}/{patience}")
        if counter >= patience:
            print("Early stopping triggered!")
            break

# Restore the best weights so the evaluation cells that follow score the model
# that early stopping selected rather than the last one trained. best_epoch is
# only set when this run wrote the file, so a stale file from an earlier run
# is never loaded.
if best_epoch is not None:
    best_checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(best_checkpoint['model_state_dict'])
    print(f"Restored best model from epoch {best_epoch} (Val Loss={best_val_loss:.4f})")

## 10. Evaluation and Visualization

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 5))

# One entry per panel: subplot position, the two curves with their legend
# labels, then the title and y-axis label.
panels = (
    (1, train_loss_history, 'Train Loss', val_loss_history, 'Validation Loss',
     'Training vs Validation Loss', 'Loss'),
    (2, train_acc_history, 'Train Accuracy', val_acc_history, 'Validation Accuracy',
     'Training vs Validation Accuracy', 'Accuracy'),
)

for position, train_curve, train_label, val_curve, val_label, title, y_label in panels:
    plt.subplot(1, 2, position)
    plt.plot(train_curve, label=train_label)
    plt.plot(val_curve, label=val_label)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()

In [None]:
def get_predictions_and_labels(model, loader):
    """Collect true labels and predicted classes for every batch in ``loader``.

    The model is put in eval mode and run without gradient tracking; the
    predicted class is the index of the largest logit.

    Returns:
        Tuple ``(true_labels, predicted_labels)`` of flat Python lists.
    """
    model.eval()
    truths, guesses = [], []
    with torch.no_grad():
        for batch_x, batch_y in tqdm(loader, desc="Getting Predictions"):
            logits = model(batch_x.to(device))
            predicted = torch.max(logits, 1).indices
            truths.extend(batch_y.cpu().numpy())
            guesses.extend(predicted.cpu().numpy())
    return (truths, guesses)

In [None]:
from sklearn.metrics import classification_report

# Predictions on the held-out loader feed both this report and the confusion
# matrix cell.
y_true, y_pred = get_predictions_and_labels(model, test_loader)

rule = "=" * 60
report_text = classification_report(y_true, y_pred)

print("\n" + rule)
print("Classification Report Summary")
print(rule)
print(report_text)
print(rule + "\n")

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# Draw the confusion matrix of the predictions collected above on a square
# figure.
fig, ax = plt.subplots(figsize=(8, 8))

cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(ax=ax, xticks_rotation=45)

plt.title("Confusion Matrix")
plt.show()