In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.amp import autocast, GradScaler

from transformers import T5ForConditionalGeneration, T5Config, get_linear_schedule_with_warmup
from sklearn.model_selection import KFold, train_test_split

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import re
import gc
import json
from datetime import datetime
from tqdm.auto import tqdm

# Input/output locations used by main().
# DATA_PATH: CSV file loaded with pandas; train_with_cv reads its
# 'Seq_no_plus' (model input) and 'Seq' (training target) columns.
DATA_PATH = "data/dataset_sequences_renamed64plus_filtered_dropna_hope5.csv"
# MODEL_PATH: base directory under which main() creates one
# 'training_run_<timestamp>' folder per run.
MODEL_PATH = "models/peptide_model"

def clear_cuda_memory():
    """Run the garbage collector and release PyTorch's cached CUDA memory.

    The peak-memory statistics are reset as well, but only when a CUDA
    device is actually available.
    """
    gc.collect()
    torch.cuda.empty_cache()
    if not torch.cuda.is_available():
        return
    torch.cuda.reset_peak_memory_stats()


class FocalLossWithSmoothing(nn.Module):
    """
    Focal loss with label smoothing for handling class imbalance.

    Args:
        gamma: Focusing exponent applied to ``(1 - p)``; larger values
            down-weight classes the model already predicts confidently.
        smoothing: Probability mass moved from the true class and spread
            evenly over the remaining classes.
        ignore_index: Target value whose positions are excluded from the loss.
    """
    def __init__(self, gamma=2.0, smoothing=0.035, ignore_index=-100):
        super().__init__()
        self.gamma = gamma
        self.smoothing = smoothing
        self.ignore_index = ignore_index

    def forward(self, input, target):
        """
        Compute the mean focal loss over all non-ignored positions.

        Args:
            input: Logits whose last dimension is the class dimension; any
                leading dimensions are flattened.
            target: Integer class indices with one entry per flattened row
                of ``input``.

        Returns:
            Scalar tensor. When every target equals ``ignore_index`` the
            result is a zero that is still attached to the autograd graph
            (instead of NaN from averaging an empty tensor).
        """
        # reshape (not view) so non-contiguous logits are accepted as well.
        logits = input.reshape(-1, input.size(-1))
        flat_target = target.reshape(-1)

        keep = flat_target != self.ignore_index
        logits = logits[keep]
        flat_target = flat_target[keep]

        if flat_target.numel() == 0:
            # Nothing to score: return a finite, differentiable zero.
            return logits.sum() * 0.0

        n_class = logits.size(1)

        one_hot = torch.zeros_like(logits).scatter(1, flat_target.unsqueeze(1), 1)
        # max(..., 1) avoids a 0/0 when there is only a single class.
        one_hot = (one_hot * (1 - self.smoothing)
                   + (1 - one_hot) * self.smoothing / max(n_class - 1, 1))

        log_prb = F.log_softmax(logits, dim=1)
        prb = torch.exp(log_prb)

        focal_loss = -(torch.pow(1 - prb, self.gamma)) * one_hot * log_prb

        return focal_loss.sum(dim=1).mean()


class NoisyT5(nn.Module):
    """
    T5 model with added Gaussian noise to improve robustness.

    The wrapped model's shared embedding layer is applied to ``input_ids``
    and zero-mean Gaussian noise with standard deviation ``noise_std`` is
    added to the resulting embeddings. Noise is a training-time regulariser:
    it is only injected while the module is in training mode, so evaluation
    passes are deterministic.

    Args:
        t5_model: Model exposing ``shared`` (embedding layer), ``generate``
            and a forward call accepting ``inputs_embeds``.
        noise_std: Standard deviation of the embedding noise.
    """
    def __init__(self, t5_model, noise_std=0.1):
        super().__init__()
        self.t5 = t5_model
        self.noise_std = noise_std

    def forward(self, input_ids=None, attention_mask=None, labels=None, inputs_embeds=None, return_dict=True, **kwargs):
        """
        Run the wrapped model on (optionally noised) input embeddings.

        ``inputs_embeds`` takes precedence over ``input_ids`` when both are
        given. Remaining keyword arguments are forwarded unchanged.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("Either input_ids or inputs_embeds must be provided")
            inputs_embeds = self.t5.shared(input_ids)

        # Perturb only while training; model.eval() yields clean embeddings.
        if self.training and self.noise_std != 0:
            inputs_embeds = inputs_embeds + torch.randn_like(inputs_embeds) * self.noise_std

        return self.t5(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=return_dict,
            **kwargs
        )

    def generate(self, *args, **kwargs):
        """Delegate directly to the wrapped model's ``generate`` (no noise is added)."""
        return self.t5.generate(*args, **kwargs)

In [None]:
class TrainingHistory:
    """Collects per-epoch losses and learning rates; can plot and export them"""

    def __init__(self):
        # One list per tracked series; entries are appended once per epoch.
        self.fold = None
        self.epochs = []
        self.train_losses = []
        self.val_losses = []
        self.learning_rates = []

    def add_epoch(self, epoch, train_loss, val_loss, lr):
        """Record the metrics of one finished epoch."""
        recorded = (
            (self.epochs, epoch),
            (self.train_losses, train_loss),
            (self.val_losses, val_loss),
            (self.learning_rates, lr),
        )
        for series, value in recorded:
            series.append(value)

    def set_fold(self, fold):
        """Remember which fold this history belongs to (used in titles and exports)."""
        self.fold = fold

    def plot_losses(self, save_path=None):
        """Draw loss curves and the learning-rate curve.

        The figure is written to ``save_path`` when one is given, otherwise
        it is shown interactively. All open figures are closed first and the
        new figure is closed afterwards.
        """
        plt.close('all')
        _, axes = plt.subplots(2, 1, figsize=(10, 12))
        loss_ax, lr_ax = axes
        title_prefix = "" if self.fold is None else f"Fold {self.fold}: "

        loss_curves = (
            (self.train_losses, 'Training Loss'),
            (self.val_losses, 'Validation Loss'),
        )
        for values, curve_label in loss_curves:
            loss_ax.plot(self.epochs, values, label=curve_label)
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')
        loss_ax.set_title(f'{title_prefix}Training and Validation Loss')
        loss_ax.legend()
        loss_ax.grid(True)

        lr_ax.plot(self.epochs, self.learning_rates, label='Learning Rate')
        lr_ax.set_xlabel('Epoch')
        lr_ax.set_ylabel('Learning Rate')
        lr_ax.set_title(f'{title_prefix}Learning Rate Schedule')
        lr_ax.grid(True)

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path)
        else:
            plt.show()
        plt.close()

    def save_history(self, path):
        """Write the fold id and all recorded series to ``path`` as JSON."""
        payload = dict(
            fold=self.fold,
            epochs=self.epochs,
            train_losses=self.train_losses,
            val_losses=self.val_losses,
            learning_rates=self.learning_rates,
        )
        with open(path, 'w') as handle:
            json.dump(payload, handle)


class RegularTokenizer:
    """
    Character-level tokenizer for protein sequences, including the extra
    symbols used to mark phosphorylation sites
    """
    def __init__(self):
        self.chars = set('ACDEFGHIKLMNPQRSTUVWXY+-?abcdefghijklmnopqrstuvwxyz[]!@#$%^&*<>,.')
        self.pad_token = "<pad>"
        self.unk_token = "<unk>"
        self.eos_token = "</s>"
        self.bos_token = "<s>"

        # Ids 0-3 go to the special tokens; the characters follow in sorted order.
        vocabulary = [self.pad_token, self.unk_token, self.eos_token, self.bos_token]
        vocabulary += sorted(self.chars)
        self.token_to_id = {token: index for index, token in enumerate(vocabulary)}
        self.id_to_token = {index: token for token, index in self.token_to_id.items()}

        self.pad_token_id = self.token_to_id[self.pad_token]
        self.eos_token_id = self.token_to_id[self.eos_token]
        self.bos_token_id = self.token_to_id[self.bos_token]
        self.vocab_size = len(self.token_to_id)

    def encode(self, text):
        """Return ``[bos, <one id per character>, eos]``; unknown characters map to ``<unk>``."""
        unk_id = self.token_to_id[self.unk_token]
        body = [self.token_to_id.get(symbol, unk_id) for symbol in text]
        return [self.bos_token_id, *body, self.eos_token_id]

    def decode(self, ids):
        """Turn ids (list or tensor) back into text, dropping pad/bos/eos ids."""
        if torch.is_tensor(ids):
            ids = ids.cpu().tolist()
        skipped = (self.pad_token_id, self.bos_token_id, self.eos_token_id)
        pieces = (self.id_to_token[token_id] for token_id in ids if token_id not in skipped)
        return ''.join(pieces)

    def __call__(self, texts, max_length=136, padding='max_length', truncation=True, return_tensors=None):
        """Encode one string or a sequence of strings.

        Over-long encodings are cut to ``max_length`` with the last kept
        position replaced by eos (when ``truncation`` is true); shorter ones
        are right-padded when ``padding == 'max_length'``. Returns a list of
        id lists, or, for ``return_tensors == "pt"``, a dict holding
        ``input_ids`` and ``attention_mask`` tensors.
        """
        batch = [texts] if isinstance(texts, str) else texts
        pad_id = self.pad_token_id

        batch_ids = []
        for text in batch:
            ids = self.encode(text)
            if truncation and len(ids) > max_length:
                ids = ids[:max_length - 1] + [self.eos_token_id]
            if padding == 'max_length':
                # A non-positive count multiplies to an empty list, i.e. no padding.
                ids = ids + [pad_id] * (max_length - len(ids))
            batch_ids.append(ids)

        if return_tensors != "pt":
            return batch_ids

        input_ids = torch.tensor(batch_ids)
        mask_rows = [[int(token_id != pad_id) for token_id in row] for row in batch_ids]
        return {
            "input_ids": input_ids,
            "attention_mask": torch.tensor(mask_rows)
        }

In [None]:
class PeptideDataset(Dataset):
    """
    Dataset class for peptide sequences with and without phosphorylation markers.

    Item ``idx`` pairs ``sequences_no_plus[idx]`` (model input) with
    ``sequences_with_plus[idx]`` (training target); both are tokenized to
    ``max_length``.

    Args:
        sequences_no_plus: Indexable collection of input sequences.
        sequences_with_plus: Indexable collection of target sequences,
            aligned with ``sequences_no_plus``.
        tokenizer: Callable returning a dict with ``input_ids`` and
            ``attention_mask`` tensors of shape ``(1, max_length)`` when
            called with ``return_tensors="pt"``.
        max_length: Length every sequence is padded/truncated to.
    """
    def __init__(self, sequences_no_plus, sequences_with_plus, tokenizer, max_length=136):
        self.sequences_no_plus = sequences_no_plus
        self.sequences_with_plus = sequences_with_plus
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sequences_no_plus)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` as 1-D tensors."""
        sequence_no_plus = self.sequences_no_plus[idx]
        sequence_with_plus = self.sequences_with_plus[idx]

        source_encoding = self.tokenizer(
            sequence_no_plus,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )
        target_encoding = self.tokenizer(
            sequence_with_plus,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )
        # squeeze(0) removes only the batch dimension added by the tokenizer;
        # a bare squeeze() would also collapse the sequence axis when
        # max_length == 1 and break batching.
        # NOTE(review): labels keep the tokenizer's own padding ids, so padded
        # positions are scored by any loss that only ignores -100 — confirm
        # this is intended.
        return {
            'input_ids': source_encoding['input_ids'].squeeze(0),
            'attention_mask': source_encoding['attention_mask'].squeeze(0),
            'labels': target_encoding['input_ids'].squeeze(0)
        }


def get_optimal_batch_size(model, dataset, device, start_size=2):
    """
    Determine the optimal batch size that fits in GPU memory.

    Starting at ``start_size``, one batch is pushed through a forward pass
    under autocast and the batch size is doubled until a ``RuntimeError``
    (e.g. CUDA out of memory) is raised.

    Args:
        model: Callable accepting ``input_ids``, ``attention_mask`` and ``labels``.
        dataset: Dataset whose items are dicts with those three keys.
        device: Device the probe batch is moved to.
        start_size: First batch size to try.

    Returns:
        The largest batch size that completed a forward pass. If even
        ``start_size`` fails, ``start_size // 2`` is returned. The result is
        capped at ``len(dataset)``: once a whole-dataset batch succeeds there
        is nothing larger to probe (without this cap the loop never ends,
        because a DataLoader silently yields a shorter batch).
    """
    clear_cuda_memory()
    dataset_size = len(dataset)
    batch_size = start_size
    while True:
        try:
            loader = DataLoader(dataset,
                              batch_size=batch_size,
                              num_workers=4,
                              pin_memory=True)
            batch = next(iter(loader))
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            with autocast(device_type='cuda'):
                outputs = model(input_ids=input_ids,
                              attention_mask=attention_mask,
                              labels=labels)
            del outputs, input_ids, attention_mask, labels
            clear_cuda_memory()
        except RuntimeError:
            # The failing size is batch_size, so the last good one is half of it.
            clear_cuda_memory()
            return batch_size // 2

        if batch_size >= dataset_size:
            # The probe already contained every sample; doubling cannot fail.
            return min(batch_size, dataset_size)
        batch_size *= 2


def check_predictions(model, val_loader, tokenizer, device, num_samples=2):
    """
    Print generated vs. reference sequences for the first validation batch.

    The model is switched to eval mode, the first ``num_samples`` items of
    the first batch are decoded with beam search, and each prediction is
    printed next to its decoded label.
    """
    model.eval()
    clear_cuda_memory()

    generation_options = dict(
        max_length=136,
        num_beams=5,
        length_penalty=1.0,
        early_stopping=True,
        min_length=111,
    )

    with torch.no_grad(), autocast(device_type='cuda'):
        first_batch = next(iter(val_loader))
        sample_ids = first_batch['input_ids'][:num_samples].to(device)
        sample_mask = first_batch['attention_mask'][:num_samples].to(device)
        references = first_batch['labels'][:num_samples]

        generated = model.generate(
            input_ids=sample_ids,
            attention_mask=sample_mask,
            **generation_options
        )

        for generated_ids, reference_ids in zip(generated, references):
            # Decode both sides before printing either one.
            predicted_text = tokenizer.decode(generated_ids.cpu())
            reference_text = tokenizer.decode(reference_ids)
            print(f"\nPredicted: {predicted_text}")
            print(f"Actual: {reference_text}")

        del generated, sample_ids, sample_mask
        clear_cuda_memory()

In [None]:
def _mean_validation_loss(model, val_loader, criterion, device):
    """Return the criterion loss averaged over the batches of ``val_loader`` (eval mode, no grad)."""
    model.eval()
    total_val_loss = 0

    with torch.no_grad(), autocast(device_type='cuda'):
        for batch in tqdm(val_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                return_dict=True
            )

            logits = outputs.logits.view(-1, model.t5.config.vocab_size)
            loss = criterion(logits, labels.view(-1))
            total_val_loss += loss.item()

            del outputs, loss, logits, input_ids, attention_mask, labels
            clear_cuda_memory()

    return total_val_loss / len(val_loader)


def _save_model_and_tokenizer(model, tokenizer, directory, fold_suffix):
    """Write the wrapped T5 weights and the pickled tokenizer into ``directory``."""
    os.makedirs(directory, exist_ok=True)
    model.t5.save_pretrained(os.path.join(directory, f'model{fold_suffix}'))
    torch.save(tokenizer, os.path.join(directory, f'tokenizer{fold_suffix}.pkl'))


def train_model(model, train_loader, val_loader, tokenizer, device, fold=None, num_epochs=80, save_path=None):
    """
    Train the model with mixed precision, gradient accumulation, and early stopping.

    Every batch is passed through the model ``noise_iterations`` times; the
    scaled gradients of ``accumulation_steps`` consecutive passes are summed
    before one clipped optimizer step. After each epoch the validation loss
    is computed, sample predictions are printed and, when ``save_path`` is
    given, a per-epoch checkpoint plus the best checkpoint so far are written.

    Args:
        model: Wrapper exposing ``.t5`` (with ``config.vocab_size`` and
            ``save_pretrained``) and ``generate``.
        train_loader: Loader yielding dicts with ``input_ids``,
            ``attention_mask`` and ``labels``.
        val_loader: Loader of the same format used for validation.
        tokenizer: Tokenizer used to decode sample predictions; pickled next
            to each checkpoint.
        device: Device the batches are moved to.
        fold: Optional fold index used in file names and plot titles.
        num_epochs: Maximum number of epochs.
        save_path: Directory for checkpoints; nothing is written when falsy.

    Returns:
        Tuple ``(history, best_val_loss)``.
    """
    patience = 3
    accumulation_steps = 2
    max_grad_norm = 0.5
    noise_iterations = 2

    optimizer = AdamW(model.parameters(), lr=4e-5, weight_decay=0.01)
    # One optimizer step is taken per `accumulation_steps` forward passes, so
    # the schedule length must count optimizer steps, not forward passes.
    steps_per_epoch = (len(train_loader) * noise_iterations) // accumulation_steps
    num_training_steps = max(1, steps_per_epoch * num_epochs)
    num_warmup_steps = int(num_training_steps * 0.07)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
    # NOTE(review): the criterion ignores only -100; if labels are padded with
    # the tokenizer's pad id, padded positions count toward the loss — confirm.
    criterion = FocalLossWithSmoothing(gamma=2.0, smoothing=0.035)
    scaler = GradScaler()

    history = TrainingHistory()
    if fold is not None:
        history.set_fold(fold)

    best_val_loss = float('inf')
    no_improvement = 0

    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0
        # Start every epoch from clean gradients; inside the epoch they are
        # only cleared right after an optimizer step so that the passes of
        # one accumulation window really add up.
        optimizer.zero_grad()

        for i, batch in enumerate(tqdm(train_loader)):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            for noise_iter in range(noise_iterations):
                with autocast(device_type='cuda'):
                    outputs = model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels,
                        return_dict=True
                    )

                    logits = outputs.logits.view(-1, model.t5.config.vocab_size)
                    # Divide so the accumulated gradient is the window average.
                    loss = criterion(logits, labels.view(-1)) / accumulation_steps

                scaler.scale(loss).backward()
                total_train_loss += loss.item() * accumulation_steps

                del outputs, loss, logits
                clear_cuda_memory()

                if (i * noise_iterations + noise_iter + 1) % accumulation_steps == 0:
                    # Unscale before clipping so the threshold applies to true gradients.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()
                    optimizer.zero_grad()

            del input_ids, attention_mask, labels
            clear_cuda_memory()

        avg_val_loss = _mean_validation_loss(model, val_loader, criterion, device)
        avg_train_loss = total_train_loss / (len(train_loader) * noise_iterations)
        current_lr = scheduler.get_last_lr()[0]

        history.add_epoch(epoch + 1, avg_train_loss, avg_val_loss, current_lr)

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"Average training loss: {avg_train_loss:.4f}")
        print(f"Average validation loss: {avg_val_loss:.4f}")
        print(f"Learning rate: {current_lr:.2e}")

        try:
            check_predictions(model, val_loader, tokenizer, device)
        except Exception as e:
            # Kept from the original: a failure while generating sample
            # predictions ends training before this epoch is checkpointed.
            print(f"Error in predictions: {e}")
            clear_cuda_memory()
            break

        # Track the best loss regardless of whether checkpoints are written,
        # so the returned value and early stopping also work without save_path.
        improved = avg_val_loss < best_val_loss
        if improved:
            best_val_loss = avg_val_loss
            no_improvement = 0
        else:
            no_improvement += 1

        if save_path:
            fold_suffix = f"_fold{fold}" if fold is not None else ""
            epoch_dir = os.path.join(save_path, f'epoch_{epoch+1}')

            _save_model_and_tokenizer(model, tokenizer, epoch_dir, fold_suffix)
            history.plot_losses(os.path.join(epoch_dir, f'loss_plot{fold_suffix}.png'))
            history.save_history(os.path.join(epoch_dir, f'training_history{fold_suffix}.json'))

            if improved:
                best_dir = os.path.join(save_path, 'best_model')
                _save_model_and_tokenizer(model, tokenizer, best_dir, fold_suffix)
                print(f"Saved best model with loss: {best_val_loss:.4f}")

        if no_improvement >= patience:
            print("Early stopping triggered")
            break

        clear_cuda_memory()

    return history, best_val_loss

In [None]:
def train_with_cv(df, n_splits=5, base_path=None, start_fold=0, noise_std=0.1,
                  batch_size=32, num_epochs=80):
    """
    Train model with k-fold cross validation

    A fresh, randomly initialised T5 (t5-large architecture, vocabulary
    resized to the character tokenizer) is trained for every fold.

    Args:
        df: DataFrame with sequences data; columns 'Seq_no_plus' (input)
            and 'Seq' (target) are read
        n_splits: Number of folds
        base_path: Path to save models; when None nothing is written to disk
        start_fold: First fold to process (for resuming training)
        noise_std: Standard deviation of noise to add
        batch_size: Batch size used for the training and validation loaders
        num_epochs: Maximum number of epochs per fold

    Returns:
        Tuple ``(fold_results, best_fold)``: a list with one dict per
        processed fold ('fold', 'history', 'val_loss') and the index of the
        fold with the lowest validation loss (None if no fold was run).
    """
    clear_cuda_memory()
    print(f"Initializing training from fold {start_fold}/{n_splits}")
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    fold_results = []
    best_fold = None
    best_loss = float('inf')
    folds = list(kfold.split(df))

    for fold_idx in range(start_fold, n_splits):
        clear_cuda_memory()
        train_idx, val_idx = folds[fold_idx]
        print(f"\n{'='*50}")
        print(f"Processing Fold {fold_idx}/{n_splits-1}")
        print(f"{'='*50}")

        # Without a base path, train_model receives save_path=None and skips saving.
        fold_dir = None
        if base_path is not None:
            fold_dir = os.path.join(base_path, f'fold_{fold_idx}')
            os.makedirs(fold_dir, exist_ok=True)

        train_sequences_no_plus = df['Seq_no_plus'].values[train_idx]
        train_sequences_with_plus = df['Seq'].values[train_idx]
        val_sequences_no_plus = df['Seq_no_plus'].values[val_idx]
        val_sequences_with_plus = df['Seq'].values[val_idx]

        # Initialize tokenizer and model
        tokenizer = RegularTokenizer()
        config = T5Config.from_pretrained('t5-large')
        config.vocab_size = tokenizer.vocab_size
        # The pretrained config carries the ids of the original T5 vocabulary.
        # Align them with the custom tokenizer so generate() stops on this
        # tokenizer's "</s>" instead of on whatever token shares the old id.
        config.pad_token_id = tokenizer.pad_token_id
        config.eos_token_id = tokenizer.eos_token_id
        config.decoder_start_token_id = tokenizer.pad_token_id

        # Set dropout for regularization
        # NOTE(review): confirm which of these attributes the T5
        # implementation actually reads.
        config.dropout_rate = 0.1
        config.attention_dropout = 0.1
        config.activation_dropout = 0.1
        config.classifier_dropout = 0.1

        base_model = T5ForConditionalGeneration(config)
        model = NoisyT5(base_model, noise_std=noise_std)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)

        # Create datasets and data loaders
        train_dataset = PeptideDataset(train_sequences_no_plus, train_sequences_with_plus, tokenizer)
        val_dataset = PeptideDataset(val_sequences_no_plus, val_sequences_with_plus, tokenizer)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=4,
                                 pin_memory=True, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4,
                               pin_memory=True)

        # Train the model
        history, val_loss = train_model(model, train_loader, val_loader, tokenizer,
                                       device, fold=fold_idx, num_epochs=num_epochs,
                                       save_path=fold_dir)

        fold_results.append({
            'fold': fold_idx,
            'history': history,
            'val_loss': val_loss
        })

        if val_loss < best_loss:
            best_loss = val_loss
            best_fold = fold_idx

        # Clean up memory. base_model must be dropped too: it is the same
        # network that `model` wraps and would otherwise keep it alive.
        del model, base_model, train_loader, val_loader, train_dataset, val_dataset
        clear_cuda_memory()

    print(f"\nBest performing fold: {best_fold} with validation loss: {best_loss:.4f}")
    return fold_results, best_fold


def main():
    """
    Entry point: read the dataset, create a run directory and launch
    5-fold cross-validated training
    """
    clear_cuda_memory()

    print("Loading dataset from:", DATA_PATH)
    df = pd.read_csv(DATA_PATH)
    print(f"Dataset loaded successfully: {len(df)} rows")

    # Every run gets its own timestamped output folder.
    run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    results_dir = os.path.join(MODEL_PATH, f'training_run_{run_stamp}')
    os.makedirs(results_dir, exist_ok=True)
    print(f"Results will be saved in: {results_dir}")

    # Standard deviation of the embedding noise used during training.
    noise_std = 0.1
    print(f"Training with noise standard deviation: {noise_std}")

    # The returned fold results are not used further here.
    train_with_cv(
        df,
        n_splits=5,
        base_path=results_dir,
        start_fold=0,
        noise_std=noise_std,
    )

    print("\nTraining completed!")
    print(f"Results saved in {results_dir}")
    clear_cuda_memory()


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()