$\textbf{Libraries}$

In [1]:
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import pickle
from sklearn.utils import shuffle
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

$\textbf{Data Processing}$

In [2]:
data = pd.read_csv("raw.csv")

# Class balancing: keep every positive row (hit == 1) and draw an equally sized,
# seeded random subset of the negative rows (hit == 0).
label_1_data = data.loc[data['hit'].eq(1)]
label_0_data = data.loc[data['hit'].eq(0)]
balanced_label_0_data = label_0_data.sample(
    n=len(label_1_data),
    random_state=42,
)

# Positives first, then the sampled negatives (row order is shuffled later, before splitting).
data = pd.concat((label_1_data, balanced_label_0_data))

$\textit{Mapping the alleles to their corresponding pseudo sequences}$

In [3]:
# Load the allele -> pseudo sequence table (whitespace-separated, no header row).
pseudo_seq = pd.read_csv(
    'MHC_pseudo.dat',
    sep=r"\s+",
    header=None,
    names=['allele', 'Pseudo_Sequence'],
)

# Normalizing the allele names: HLA-[gene][allele_group]:[protein_code] -> HLA-[gene]*[allele_group]:[protein_code]
def regularize(allele):
    """Insert the '*' separator into an HLA allele name.

    'HLA-A02:01' -> 'HLA-A*02:01'. Names not starting with 'HLA', and names that
    already contain '*', are returned unchanged, so the function is idempotent
    (previously 'HLA-A*02:01' became 'HLA-A**02:01').

    NOTE(review): the '*' is inserted after the first 5 characters, which assumes a
    single-letter gene name (HLA-A / HLA-B / HLA-C ...).
    """
    if allele[:3] == 'HLA' and '*' not in allele:
        return allele[:5] + '*' + allele[5:]
    return allele
pseudo_seq['allele'] = pseudo_seq['allele'].apply(regularize)

# Mapping the alleles to the corresponding pseudo sequences.
# Build the lookup in one pass instead of a nested Python loop over
# (alleles x pseudo_seq rows). keep='first' reproduces the previous
# first-match-wins behaviour when an allele name occurs more than once.
_pds_first = pseudo_seq.drop_duplicates(subset='allele', keep='first')
_pds_lookup = dict(zip(_pds_first['allele'], _pds_first['Pseudo_Sequence']))

alleles = data['allele'].unique()

# allele -> pseudo sequence dictionary, restricted to alleles that occur in `data`.
allele2pds = {allele: _pds_lookup[allele] for allele in alleles if allele in _pds_lookup}

# Fail early and list every unmapped allele, rather than a bare KeyError on the
# first one raised from inside .apply() below.
_missing_alleles = [allele for allele in alleles if allele not in allele2pds]
if _missing_alleles:
    raise KeyError(
        f"No pseudo sequence found for {len(_missing_alleles)} allele(s), e.g. {list(_missing_alleles[:10])}"
    )

def allele2pds_fn(allele):
    """Return the pseudo sequence of `allele` (KeyError if it has none)."""
    return allele2pds[allele]

data['pds'] = data['allele'].apply(allele2pds_fn)

$\textit{Saving and reloading the processed data}$

In [4]:
# Checkpoint the processed DataFrame to disk.
with open('data.pkl', mode='wb') as pkl_file:
    pickle.dump(data, pkl_file)

In [5]:
# Reload the checkpoint. pickle.load can execute arbitrary code, so only load a
# data.pkl produced by this notebook.
with open('data.pkl', mode='rb') as pkl_file:
    data = pickle.load(pkl_file)

$\textbf{Model}$

In [16]:
class MHCBindingModel(nn.Module):
    """Binary MHC-peptide binding classifier on top of a pretrained protein language model.

    The MHC pseudo sequence and the peptide are embedded separately by the same
    base model (attention-mask-weighted mean pooling of the last hidden state),
    the two embeddings are concatenated, and an MLP head outputs a binding
    probability in (0, 1).

    Args:
        model_name: Hugging Face model id of the base encoder.
        embedding_dim: must equal the base model's hidden size (the default 1024
            matches the default ProtBert model).
        hidden_dims: widths of the hidden layers of the classification head.
        dropout_rate: dropout probability in the classification head.
        fine_tune: "qlora" (4-bit base + LoRA), "lora" (full-precision base + LoRA)
            or "no_lora" (all base-model weights trainable).
        protbert_format: if True, each sequence has whitespace removed, is
            upper-cased, has the rare residues U/Z/O/B mapped to X, and gets its
            residues separated by single spaces before tokenization. This is the
            input format the Rostlab ProtBert tokenizer expects: its vocabulary
            holds single residues only, so an unspaced sequence is one word that
            WordPiece cannot split.

    Raises:
        ValueError: if `fine_tune` is not one of the three supported modes.
    """

    def __init__(self,
                 model_name='Rostlab/prot_bert',
                 embedding_dim=1024,
                 hidden_dims=(512, 256),
                 dropout_rate=0.3,
                 fine_tune="qlora",
                 protbert_format=True):
        super().__init__()
        # Validate up front: an unknown mode used to fall through every branch and
        # only fail later with an AttributeError on the missing self.base_model.
        if fine_tune not in {"qlora", "lora", "no_lora"}:
            raise ValueError(
                f"fine_tune must be 'qlora', 'lora' or 'no_lora', got {fine_tune!r}"
            )
        self.protbert_format = protbert_format

        if fine_tune == "qlora":
            # 4-bit NF4 quantization with double quantization; compute in float16.
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )
            self.base_model = AutoModel.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto"
            )
            self.base_model = prepare_model_for_kbit_training(self.base_model)
        elif fine_tune == "lora":
            self.base_model = AutoModel.from_pretrained(model_name)
        else:  # "no_lora": every base-model weight is trained
            self.base_model = AutoModel.from_pretrained(model_name)
            for param in self.base_model.parameters():
                param.requires_grad = True

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        if fine_tune in {"lora", "qlora"}:
            # LoRA adapters on the query/key/value projections.
            lora_config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=['query', 'value', 'key'],
                lora_dropout=0.1,
                bias="none",
                task_type="FEATURE_EXTRACTION"
            )
            self.base_model = get_peft_model(self.base_model, lora_config)

        # Classification head: [Linear -> BatchNorm1d -> ReLU -> Dropout] per hidden
        # width, then Linear -> Sigmoid producing one probability per pair.
        classifier_layers = []
        prev_dim = embedding_dim * 2  # MHC and peptide embeddings are concatenated
        for dim in hidden_dims:
            classifier_layers.extend([
                nn.Linear(prev_dim, dim),
                nn.BatchNorm1d(dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            ])
            prev_dim = dim
        classifier_layers.append(nn.Linear(prev_dim, 1))
        classifier_layers.append(nn.Sigmoid())
        self.classifier = nn.Sequential(*classifier_layers)

    def _format_sequences(self, sequence):
        """Turn a str or an iterable of str into the list of strings given to the tokenizer."""
        sequences = [sequence] if isinstance(sequence, str) else list(sequence)
        if not self.protbert_format:
            return sequences
        rare_to_x = str.maketrans("UZOB", "XXXX")
        # Strip any existing whitespace first so pre-spaced input is not double-spaced.
        return [" ".join("".join(seq.split()).upper().translate(rare_to_x)) for seq in sequences]

    def embed_sequence(self, sequence):
        """Embed one sequence or a batch of sequences.

        Returns:
            Tensor of shape (batch, hidden_size); a single sequence gives (1, hidden_size).
        """
        inputs = self.tokenizer(
            self._format_sequences(sequence),
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.base_model.device)

        hidden = self.base_model(**inputs).last_hidden_state            # (batch, seq_len, hidden)
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)  # (batch, seq_len, 1)
        # Average only over real tokens so padding does not dilute shorter sequences.
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1)
        # No squeeze(): a batch of one must stay 2-D for torch.cat and BatchNorm1d.
        return summed / counts

    def forward(self, mhc_sequence, peptide_sequence):
        """Return binding probabilities of shape (batch, 1) for paired MHC/peptide inputs."""
        mhc_embedding = self.embed_sequence(mhc_sequence)
        peptide_embedding = self.embed_sequence(peptide_sequence)
        combined_embedding = torch.cat([mhc_embedding, peptide_embedding], dim=-1)
        # Match the head's dtype in case the (possibly quantized) base model emits half precision.
        head_dtype = next(self.classifier.parameters()).dtype
        return self.classifier(combined_embedding.to(head_dtype))

$\textit{Dataset Class}$

In [7]:
class MHCPeptideDataset(Dataset):
    """Torch Dataset of (MHC pseudo sequence, peptide, binding label) triples.

    Each item is a dict with keys 'mhc_sequences' (str), 'peptide_sequences' (str)
    and 'labels' (0-d float32 tensor).
    """

    def __init__(self, data, mhc_col='pds', peptide_col='peptide', label_col='hit'):
        """
        Initialize dataset from a DataFrame.

        Args:
            data: DataFrame holding the three columns named below.
            mhc_col: column with MHC pseudo sequences (default 'pds').
            peptide_col: column with peptide sequences (default 'peptide').
            label_col: column with 0/1 binding labels (default 'hit').

        Raises:
            KeyError: if any of the three columns is missing from `data`.
        """
        missing = [col for col in (mhc_col, peptide_col, label_col) if col not in data.columns]
        if missing:
            raise KeyError(f"MHCPeptideDataset: missing column(s) {missing}")
        self.mhc_sequences = data[mhc_col].tolist()
        self.peptide_sequences = data[peptide_col].tolist()
        self.labels = torch.tensor(data[label_col].values, dtype=torch.float32)

    def __len__(self):
        return len(self.mhc_sequences)

    def __getitem__(self, idx):
        return {
            'mhc_sequences': self.mhc_sequences[idx],
            'peptide_sequences': self.peptide_sequences[idx],
            'labels': self.labels[idx]
        }

$\textit{Training}$

In [8]:
# Training configuration.
# torch.cuda.is_available must be *called*: the bare function object is always
# truthy, which selected 'cuda' even on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 64
lr = 1e-4
epochs = 3

In [9]:
# Seeded shuffle, then a positional split: the first p of the rows train, the rest validate.
shuffled_data = shuffle(data, random_state=42)

p = 0.8  # fraction of rows used for training

train_data = MHCPeptideDataset(shuffled_data.iloc[:int(p * len(shuffled_data))])
val_data = MHCPeptideDataset(shuffled_data.iloc[int(p * len(shuffled_data)):])

In [10]:

def _run_epoch(model, loader, criterion, device, optimizer=None, desc=""):
    """Run one pass over `loader`; train when `optimizer` is given, otherwise evaluate.

    Returns:
        (mean per-batch loss, accuracy in percent over the samples seen)
    """
    training = optimizer is not None
    model.train(training)
    total_loss, n_correct, n_seen = 0.0, 0, 0

    # Gradients are only tracked during the training pass.
    with torch.set_grad_enabled(training):
        for batch in tqdm(loader, desc=desc, total=len(loader)):
            labels = batch['labels'].to(device)

            # reshape(-1) rather than squeeze(): a batch of one must stay shape (1,)
            # to match `labels`, otherwise BCELoss rejects the size mismatch.
            predictions = model(batch['mhc_sequences'], batch['peptide_sequences']).reshape(-1).to(device)
            loss = criterion(predictions, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            n_correct += ((predictions > 0.5).int() == labels).sum().item()
            n_seen += labels.numel()

    # loss.item() is already a per-batch mean, so average over batches, not samples.
    avg_loss = total_loss / max(len(loader), 1)
    accuracy = 100.0 * n_correct / max(n_seen, 1)
    return avg_loss, accuracy


def train(
    model,
    train_dataset,
    val_dataset,
    batch_size=batch_size,
    learning_rate=lr,
    epochs=epochs,
    device=device,
    num_workers=4,
):
    """
    Train an MHC binding prediction model (full fine-tuning, LoRA or QLoRA).

    Args:
    - model (MHCBindingModel): initialized MHC binding model
    - train_dataset (MHCPeptideDataset): training dataset
    - val_dataset (MHCPeptideDataset or None): validation dataset; validation is skipped when None
    - batch_size (int): batch size for both data loaders
    - learning_rate (float): Adam learning rate
    - epochs (int): number of training epochs
    - device (str or torch.device, optional): device to train on; None picks CUDA if available
    - num_workers (int): DataLoader worker processes

    Returns:
    - the trained model
    - history (dict): per-epoch 'train_loss' and 'val_loss' (mean per-batch BCE loss);
      'val_loss' stays empty when val_dataset is None
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # BatchNorm1d in the head cannot train on a single sample, so drop the last
    # training batch only when it would contain exactly one sample.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=len(train_dataset) % batch_size == 1,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    ) if val_dataset is not None else None

    # Only hand trainable parameters to the optimizer (frozen ones never get gradients).
    optimizer = torch.optim.Adam(
        [param for param in model.parameters() if param.requires_grad], lr=learning_rate
    )
    criterion = torch.nn.BCELoss()

    history = {
        'train_loss': [],
        'val_loss': []
    }

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")

        avg_train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer=optimizer, desc="Training"
        )
        print(f"  Training Loss: {avg_train_loss:.4f}   Training Accuracy: {train_acc:.4f}%")
        history['train_loss'].append(avg_train_loss)

        if val_loader is None:
            continue

        avg_val_loss, val_acc = _run_epoch(
            model, val_loader, criterion, device, desc="Validation"
        )
        history['val_loss'].append(avg_val_loss)
        print(f"  Validation Loss: {avg_val_loss:.4f}   Validation Accuracy: {val_acc:.4f}%")

    return model, history

In [None]:
# Baseline: no LoRA adapters; every base-model weight is trainable.
base_model = MHCBindingModel(fine_tune="no_lora")
_, history_base = train(
    model=base_model,
    train_dataset=train_data,
    val_dataset=val_data,
    batch_size=batch_size,
    learning_rate=lr,
    epochs=epochs,
    device=device,
)

In [None]:
# LoRA: adapters on a full-precision base model.
lora_model = MHCBindingModel(fine_tune="lora")
_, history_lora = train(
    model=lora_model,
    train_dataset=train_data,
    val_dataset=val_data,
    batch_size=batch_size,
    learning_rate=lr,
    epochs=epochs,
    device=device,
)

In [15]:
# Free GPU memory from the finished runs before loading the QLoRA model.
# empty_cache() can only release cached blocks that are no longer referenced, so
# first drop the notebook's references to the earlier models ('_' holds the model
# returned by the last train() call); the training histories are kept.
import gc

for _name in ('base_model', 'lora_model', '_'):
    globals().pop(_name, None)
gc.collect()
torch.cuda.empty_cache()

In [None]:
# QLoRA: LoRA adapters on a 4-bit quantized base model.
qlora_model = MHCBindingModel(fine_tune="qlora")
_, history_qlora = train(
    model=qlora_model,
    train_dataset=train_data,
    val_dataset=val_data,
    batch_size=batch_size,
    learning_rate=lr,
    epochs=epochs,
    device=device,
)