In [2]:
import os

# Step up one directory from the current working directory so that later
# relative paths resolve from the parent folder.
# NOTE(review): presumably this is what makes the `src` package importable
# from the notebook — confirm against the repository layout.
os.chdir('..')

In [3]:
import transformers

transformers.__version__

'4.35.2'

In [None]:
import

In [5]:
from src.model import BERT_CONFIG, FocalLoss, TCRModel

In [None]:
import logging

import torch
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data.sampler import SubsetRandomSampler


class ModelTrainer(TCRModel):
    """Train and evaluate a ``TCRModel`` classifier.

    ref: https://github.com/EnthusiasticTeslim/PianoGen/blob/master/hw1.py

    The trainer owns three collaborators, all placed on ``self.device``
    (CUDA if available, else Apple MPS, else CPU):

    * ``self.model``         -- the network being optimised,
    * ``self.clf_loss_func`` -- the classification loss,
    * ``self.optimizer``     -- AdamW over ``self.model.parameters()``.

    NOTE(review): the class still derives from ``TCRModel`` to keep its
    declared interface, but every method works through the separate
    ``self.model`` instance, so the inherited network is never used.
    Composition alone (no base class) is probably what was intended.

    NOTE(review): ``train`` below shadows ``torch.nn.Module.train`` (assuming
    ``TCRModel`` is an ``nn.Module``), so calling ``.eval()`` or
    ``.train(bool)`` on the *trainer* itself will not toggle the mode; use
    ``trainer.model.eval()`` / ``trainer.model.train()`` instead.
    """

    def __init__(self, args, train=False, seed=2023, lr=2e-5, epochs=1000,
                 log_interval=200, verbose=True, model_dir='model_save'):
        """Build the model, loss and optimizer.

        Args:
            args: Optional namespace-like object. When it carries ``lr`` or
                ``seed`` attributes they take precedence over the keyword
                arguments of the same name; ``None`` is accepted.
            train: Currently unused; kept for interface compatibility.
            seed: Seed for the torch random number generators.
            lr: Learning rate handed to AdamW.
            epochs: Number of passes over the training data.
            log_interval: Log the training loss every this many batches;
                ``0`` disables progress logging.
            verbose: Currently unused; kept for interface compatibility.
            model_dir: Currently unused; kept for interface compatibility.
        """
        # Sub-modules can only be attached to an nn.Module after its own
        # initialiser has run.
        # NOTE(review): assumes TCRModel() needs no constructor arguments — confirm.
        super().__init__()

        self.seed = getattr(args, 'seed', seed)  # seed for random number generator
        self.lr = getattr(args, 'lr', lr)
        self.epochs = epochs  # number of epochs to train
        self.log_interval = log_interval  # how many batches to wait before logging training status
        self.device = torch.device(
            "cuda" if torch.cuda.is_available()
            else ("mps" if torch.backends.mps.is_available() else "cpu")
        )

        # The model must exist (and sit on its final device) before the
        # optimizer captures its parameters.
        self.model = TCRModel().to(self.device)  # takes input_ids, attention_mask, classification
        # NOTE(review): the original referenced an undefined `clf_loss_func`;
        # cross-entropy is used here. Swap in FocalLoss if that was the intent.
        self.clf_loss_func = nn.CrossEntropyLoss().to(self.device)
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr)

    def train(self, train_loader):
        """Optimise ``self.model`` for ``self.epochs`` passes over ``train_loader``.

        Args:
            train_loader: Iterable of batches with ``len()`` and a ``dataset``
                attribute; each batch is a mapping holding the tensors
                ``'input_ids'``, ``'attention_mask'`` and ``'targets'``.
        """
        logger = logging.getLogger(__name__)

        # set the seed for generating random numbers
        torch.manual_seed(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.seed)

        num_batches = len(train_loader)
        num_samples = len(train_loader.dataset)

        # Epochs are numbered 1..epochs so exactly `self.epochs` passes run.
        for epoch in range(1, self.epochs + 1):
            self.model.train()
            for step, batch in enumerate(train_loader):
                input_ids = batch['input_ids'].to(self.device)
                input_mask = batch['attention_mask'].to(self.device)
                labels = batch['targets'].to(self.device)

                self.optimizer.zero_grad()
                outputs = self.model(
                    input_ids=input_ids,        # amino acid index numbers
                    attention_mask=input_mask,  # attention mask (1 for non-padding token and 0 for padding)
                    classification=True,        # True for classification task
                )
                loss = self.clf_loss_func(input=outputs, target=labels)

                loss.backward()
                self.optimizer.step()

                if self.log_interval and step % self.log_interval == 0:
                    logger.info(
                        "Train Epoch: {} [{}/{} ({:.0f}%)] Training Loss: {:.6f}".format(
                            epoch,
                            step * len(input_ids),
                            num_samples,
                            100.0 * step / num_batches,
                            loss.item(),
                        )
                    )

    def test(self, test_loader):
        """Evaluate ``self.model`` on ``test_loader`` and print a summary.

        Args:
            test_loader: Iterable of batches with a ``dataset`` attribute;
                batches have the same keys as in ``train``.

        Returns:
            ``(mean_loss, accuracy)``: the mean of the per-batch losses
            (``nan`` when the loader yields no batch) and the accuracy in
            percent over ``len(test_loader.dataset)`` samples.
        """
        self.model.eval()
        batch_losses = []
        correct_predictions = 0

        with torch.no_grad():
            for batch in test_loader:
                input_ids = batch['input_ids'].to(self.device)
                input_mask = batch['attention_mask'].to(self.device)
                labels = batch['targets'].to(self.device)

                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=input_mask,
                    classification=True,
                )
                loss = self.clf_loss_func(input=outputs, target=labels)

                # torch.max(..., dim=1) yields (values, indices); only the
                # indices are class predictions, so take argmax directly.
                # Assumes class scores lie along dim 1, as the original did.
                predictions = outputs.argmax(dim=1)
                correct_predictions += (predictions == labels).sum().item()
                # Keep plain floats so averaging works for any device.
                batch_losses.append(loss.item())

        num_samples = len(test_loader.dataset)
        mean_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
        accuracy = 100.0 * correct_predictions / num_samples if num_samples else 0.0

        print('\nTest set: loss: {:.4f}, Accuracy: {:.0f}%\n'.format(mean_loss, accuracy))
        return mean_loss, accuracy
