In [None]:
import torch
import torch.nn as nn

import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
import os

# Assumes the notebook sits one level below the project root, which holds the `src/` and
# `data/` directories used via relative paths below. Guarded so re-running this cell does
# not keep climbing the directory tree (TODO confirm layout if the notebook is moved).
if not (os.path.isdir("src") and os.path.isdir("data")):
    os.chdir("../")

In [None]:
from transformers import AutoTokenizer, AutoConfig, AutoModelForMaskedLM, DataCollatorForLanguageModeling, Trainer,  TrainingArguments
from transformers import BertModel, BertConfig

import datasets
import evaluate

# 1. Load Data

In [None]:
# Raw dataset; the path is relative to the working directory (the project root).
data = pd.read_csv(filepath_or_buffer='./data/data.csv')

In [None]:
# Peek at the first row to check the column layout.
data.iloc[:1]

In [None]:
#data.drop(columns=['interaction'], inplace=True)

In [None]:
# Insert a space between characters (residues), presumably so the tokenizer treats each
# amino acid as its own token. Existing spaces are stripped first, so re-running this cell
# is idempotent instead of padding the sequences with ever more spaces.
for seq_col in ('antigen', 'TCR'):
    data[seq_col] = data[seq_col].apply(lambda seq: ' '.join(seq.replace(' ', '')))

In [None]:
# Show only the first rows rather than dumping the whole frame into the notebook.
data.head()

# 2. Tokenize Data

In [None]:
# Encoder configuration for short paired amino-acid sequences.
# NOTE(review): vocab_size=25 presumably means 20 amino acids + special tokens; it should
# equal len(tokenizer) for the `src/antigen` tokenizer — TODO confirm.
# NOTE(review): in this notebook the config is only handed to AutoTokenizer; TCRModel() is
# constructed without it — confirm which config src.model actually uses.
BERT_CONFIG = BertConfig(
    # vocabulary / input geometry
    vocab_size=25,
    max_position_embeddings=64,   # must be >= the tokenizer max length (64) used below
    type_vocab_size=2,            # two segments: antigen and TCR
    # transformer body
    hidden_size=512,
    num_hidden_layers=8,
    num_attention_heads=8,
    intermediate_size=2048,
    # classification head
    num_labels=2,
)

In [None]:
# Short alias for BERT_CONFIG, passed to AutoTokenizer.from_pretrained in the next cell.
config = BERT_CONFIG

In [None]:
# Load the tokenizer saved under src/antigen; `config` only lets AutoTokenizer pick the
# tokenizer class, the vocabulary itself comes from the saved files.
tokenizer = AutoTokenizer.from_pretrained("src/antigen", config=config)
# Tie the tokenizer's length limit to the model's position-embedding size (64) instead of
# repeating the magic number, so the two cannot drift apart.
tokenizer.model_max_length = config.max_position_embeddings


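Each example is tokenized as an (antigen, TCR) sequence pair: `[CLS] antigen [SEP] TCR [EOS]` (the exact special tokens come from the `src/antigen` tokenizer's pair template).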

# 3. Model

In [None]:
import pandas as pd
from torch.utils.data import Dataset
from datasets import Dataset as HFDataset  # Importing Hugging Face Dataset as HFDataset to avoid confusion with PyTorch Dataset
from transformers import PreTrainedTokenizerBase  # Assuming you're using a tokenizer from the transformers library

file = './data/data.csv'

# Re-load the raw CSV so this cell does not depend on the in-place edits made in earlier cells.
data = pd.read_csv(file)

# The two sequence columns, in the order they are fed to the tokenizer as a pair.
sequence_columns = ['antigen', 'TCR']
missing = set(sequence_columns) - set(data.columns)
assert not missing, f"data.csv is missing expected columns: {missing}"

# Insert a space between characters so each residue is tokenized separately.
for seq_col in sequence_columns:
    data[seq_col] = data[seq_col].apply(lambda seq: ' '.join(list(seq)))

# Put into a Hugging Face dataset; seed the split (same seed as ModelTrainer's default)
# so train/test membership is reproducible across runs.
dataset = HFDataset.from_pandas(data)
dataset = dataset.train_test_split(test_size=0.2, seed=2023)

max_len = 64  # must not exceed BERT_CONFIG.max_position_embeddings
column_names = data.columns.tolist()


def tokenize_function(examples):
    """Tokenize a batch of (antigen, TCR) pairs into fixed-length ``max_len`` encodings.

    ``datasets.map`` stores results as Arrow lists, so no ``return_tensors`` is needed here.
    """
    return tokenizer(
        examples[sequence_columns[0]],
        examples[sequence_columns[1]],
        max_length=max_len,
        padding='max_length',
        truncation=True,
    )


tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=sequence_columns,  # raw strings are no longer needed once tokenized
    desc="Running tokenizer on dataset",
)

In [None]:
# for j in range(10):
#     print(len(tokenized_datasets['train']['input_ids'][j]), len(tokenized_datasets['train']['attention_mask'][j]), (tokenized_datasets['train']['interaction'][j]))

In [None]:
# from src.classifier import ModelTrainer
import sys
import numpy as np
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

from sklearn.model_selection import KFold

from src.model import FocalLoss, TCRModel # model, loss function

# """
#     package versions:
#         torch: 2.1.1+cu121
#         transformers: 4.35.2
#         sklearn: 1.3.0
#         logging: 0.5.1.2
# """


# key reference: 
#               https://github.com/aws-samples/amazon-sagemaker-protein-classification/blob/main/code/train.py
#               https://medium.com/analytics-vidhya/bert-pre-training-fine-tuning-eb574be614f6
class ModelTrainer(nn.Module):
    """Train and evaluate a ``TCRModel`` with K-fold cross-validation.

    Args:
        train: unused; kept so existing ``ModelTrainer(train=...)`` calls keep working.
        seed: seed for torch's RNGs and for the K-fold shuffle.
        lr: Adam learning rate.
        epochs: number of training epochs per fold.
        log_interval: stored on the instance but not used by any method here.

    The wrapped model is called as ``model(input_ids=..., attention_mask=...)``; its output
    and the batch's ``interaction`` labels are passed to ``FocalLoss(input=..., target=...)``.

    NOTE: ``train`` below shadows ``nn.Module.train(mode)``, so never call ``.train(False)``
    or ``.eval()`` on the trainer itself; switch modes on ``self.model`` instead.
    """

    def __init__(self, train=True, seed=2023, lr=2e-5, epochs=1000, log_interval=10):
        super().__init__()
        self.seed = seed
        self.epochs = epochs
        self.lr = lr
        self.log_interval = log_interval
        self.device = torch.device(
            "cuda" if torch.cuda.is_available()
            else ("mps" if torch.backends.mps.is_available() else "cpu")
        )
        self.model = TCRModel().to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.loss_func = FocalLoss(gamma=3, alpha=0.25, no_agg=False, size_average=True)

    def _reset_model(self):
        """Replace ``self.model`` and ``self.optimizer`` with freshly initialised ones."""
        self.model = TCRModel().to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def _unpack_batch(self, batch):
        """Move a collated batch to ``self.device``; returns (input_ids, attention_mask, labels)."""
        return (
            batch['input_ids'].to(self.device),       # token ids
            batch['attention_mask'].to(self.device),  # 1 for real tokens, 0 for padding
            batch['interaction'].to(self.device),     # classification target
        )

    def _train_one_epoch(self, loader):
        """Run one optimisation pass of ``self.model`` over ``loader``.

        Returns:
            Mean per-batch training loss for this epoch only.
        """
        self.model.train()
        batch_losses = []
        for batch in loader:
            input_ids, input_mask, labels = self._unpack_batch(batch)
            outputs = self.model(input_ids=input_ids, attention_mask=input_mask)
            loss = self.loss_func(input=outputs, target=labels)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            batch_losses.append(loss.item())
        return np.mean(batch_losses)

    def validate(self, val_loader, model, loss_func):
        """Return the mean per-batch loss of ``model`` on ``val_loader`` (no gradients).

        Leaves ``model`` in eval mode.
        """
        model.eval()
        batch_losses = []
        with torch.no_grad():
            for batch in val_loader:
                input_ids, input_mask, labels = self._unpack_batch(batch)
                outputs = model(input_ids=input_ids, attention_mask=input_mask)
                batch_losses.append(loss_func(input=outputs, target=labels).item())
        return np.mean(batch_losses)

    def test(self, test_loader, model, loss_func):
        """Evaluate ``model`` on the whole test set, print loss/accuracy and return them.

        Predictions are the argmax over dim 1 of the model output, i.e. this assumes the
        model returns per-class scores of shape (batch, num_classes) — TODO confirm.

        Returns:
            ``(mean per-batch loss, accuracy as a fraction in [0, 1])``; accuracy is NaN
            when ``test_loader.dataset`` is empty.
        """
        model.eval()
        batch_losses = []
        n_correct = 0
        with torch.no_grad():
            for batch in test_loader:
                input_ids, input_mask, labels = self._unpack_batch(batch)
                outputs = model(input_ids=input_ids, attention_mask=input_mask)
                batch_losses.append(loss_func(input=outputs, target=labels).item())
                # torch.max(..., dim=1) returns a (values, indices) pair; the class is the index.
                predictions = outputs.argmax(dim=1)
                n_correct += (predictions == labels).sum().item()

        n_samples = len(test_loader.dataset)
        mean_loss = np.mean(batch_losses)
        accuracy = n_correct / n_samples if n_samples else float('nan')
        print('\nTest set: loss: {:.4f}, Accuracy: {:.0f}%\n'.format(mean_loss, 100. * accuracy))
        return mean_loss, accuracy

    def train(self, train_loader, test_loader, fold=3, batch_size=32):
        """Cross-validate on ``train_loader.dataset``, evaluating on ``test_loader`` per fold.

        For each of ``fold`` K-fold splits, a fresh model/optimizer is created, trained for
        ``self.epochs`` epochs on the fold's training indices and validated on its held-out
        indices after every epoch; the fold's model is then evaluated on ``test_loader``.
        Afterwards ``self.model`` holds the model from the last fold.

        Args:
            train_loader: DataLoader whose ``dataset`` is split into folds; its
                ``collate_fn`` is reused for the per-fold loaders.
            test_loader: DataLoader over the held-out test set.
            fold: number of K-fold splits (KFold requires at least 2).
            batch_size: batch size of the per-fold loaders.
        """
        print(f"Training on: {self.device}")

        # Seed before each fold's model is built so weight init and sampling are reproducible.
        torch.manual_seed(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.seed)

        dataset = train_loader.dataset
        # Reuse the caller's collate function rather than relying on a global `collate_fn`.
        collate = train_loader.collate_fn
        pin_memory = self.device.type == "cuda"  # pinning only benefits CUDA transfers

        indices = list(range(len(dataset)))
        kf = KFold(n_splits=fold, shuffle=True, random_state=self.seed)

        for cv_index, (train_indices, valid_indices) in enumerate(kf.split(indices)):
            print(f"Fold {cv_index + 1}/{fold}")
            # Fresh weights per fold; otherwise later folds validate on samples the model
            # already trained on in earlier folds and the validation loss is meaningless.
            self._reset_model()

            fold_train_loader = DataLoader(dataset, batch_size=batch_size,
                                           sampler=SubsetRandomSampler(train_indices),
                                           collate_fn=collate, pin_memory=pin_memory)
            fold_val_loader = DataLoader(dataset, batch_size=batch_size,
                                         sampler=SubsetRandomSampler(valid_indices),
                                         collate_fn=collate, pin_memory=pin_memory)

            for epoch in range(1, self.epochs + 1):
                train_loss_avg = self._train_one_epoch(fold_train_loader)
                val_loss_avg = self.validate(val_loader=fold_val_loader, model=self.model,
                                             loss_func=self.loss_func)
                print("At end of Epoch: {}/{}, Training Loss: {:.4f},  Validation Loss: {:.4f}".format(
                    epoch, self.epochs, train_loss_avg, val_loss_avg))

            print('Testing the model...')
            self.test(test_loader=test_loader, model=self.model, loss_func=self.loss_func)

        print('Finished training & testing the model.')


In [15]:
def collate_fn(batch):
    """Merge a list of per-example dicts into one dict of stacked tensors.

    Every value of every example is converted with ``torch.tensor`` and the values for each
    key are stacked along a new leading (batch) dimension. Keys are taken from the first
    example, in its order.
    """
    collated = {}
    for field in batch[0].keys():
        per_example = [torch.tensor(example[field]) for example in batch]
        collated[field] = torch.stack(per_example)
    return collated

batch_size = 512
# Pinned memory only helps host->GPU copies; skip it (and its warning) on CPU/MPS.
use_pinned_memory = torch.cuda.is_available()

# ModelTrainer.train builds its per-fold loaders from train_loader.dataset.
train_loader = torch.utils.data.DataLoader(tokenized_datasets['train'], batch_size=batch_size,
                                           shuffle=True, collate_fn=collate_fn,
                                           pin_memory=use_pinned_memory)
# Evaluation does not need shuffling; a fixed order keeps the test metrics deterministic.
test_loader = torch.utils.data.DataLoader(tokenized_datasets['test'], batch_size=batch_size,
                                          shuffle=False, collate_fn=collate_fn,
                                          pin_memory=use_pinned_memory)

# model
Model = ModelTrainer(epochs=10, lr=1e-3)
Model.train(train_loader=train_loader, test_loader=test_loader, batch_size=batch_size)