# Training the Model
In this notebook I train the derived model (`AssemblyModel`) to predict protein-ligand affinity. I use a notebook because the results are easier to track and manipulate here.

In [1]:
# Make the project's own packages (`src.*`) importable. The project root is taken to be
# the PARENT of the current working directory (presumably the notebook lives one level
# below it), not the cwd itself; it is appended to sys.path only once.
import os
import sys

parent_dir = os.path.abspath(os.pardir)
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import torch
from torch import nn
from src.models.assembly_model import AssemblyModel

In [2]:
# Pick the compute device: the CUDA GPU when one is visible, otherwise the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
if DEVICE == 'cuda':
    print('Detected GPU {}, training will run on it.'.format(torch.cuda.get_device_name(0)))
else:
    print('No GPUs available, will run on cpu.')

No GPUs available, will run on cpu.


## Preparation
We prepare the training data and functions in this section.

### Training Data

In [3]:
# Build one global preprocessor here to save memory: the train/test datasets created
# later query this shared instance rather than each doing their own preprocessing.
from src.preprocess.preprocess import TrainPreprocessor

train_processor = TrainPreprocessor()

In [21]:
# define a custom dataset format for our data
import random
import pandas as pd
from typing import NamedTuple
from torch.utils.data import Dataset
from src.preprocess.bind_grid import BindGrid


class TrainDataItem(NamedTuple):
    """One training sample as stored by ProteinLigandDataset."""
    grid: torch.Tensor   # a BindGrid grid (an element of rotation_augment() when augmentation is on)
    embed: torch.Tensor  # embedding of the ligand's SMILES from train_processor.generate_embeddings
    label: float         # 1.0 for a true protein-ligand pair, 0.0 for a generated false pair

class ProteinLigandDataset(Dataset):
    """Protein-ligand grids labelled as true (1.0) or generated false (0.0) pairs.

    For each ``(pid, lid)`` in ``pairs`` this builds ``batch_size`` samples: the true pair,
    followed by ``batch_size - 1`` false pairs in which protein ``pid`` is combined with a
    randomly drawn ligand other than ``lid``. Each sample is turned into a ``BindGrid``;
    with ``rot_aug`` every grid returned by ``BindGrid.rotation_augment()`` becomes its own
    item, all sharing the sample's label.

    Proteins, ligands, centroids, ground-truth pairs and SMILES embeddings all come from
    the module-level ``train_processor``.

    Args:
        pairs: iterable of ``(protein id, ligand id)`` pairs, e.g. items of
            ``train_processor.gt_pairs``.
        batch_size: samples generated per pair (1 true + ``batch_size - 1`` false).
            NOTE: this is not the DataLoader batch size.
        rot_aug: if True, augment data by random grid rotation.

    Items are returned as ``(grid, embed, label)`` tuples.
    """

    def __init__(self, pairs, batch_size, rot_aug=True):
        self.pairs = pairs  # a list of pairs from querying pair.csv
        self.batch_size = batch_size
        self.rot_aug = rot_aug  # augment data by random grid rotation
        self.data = []
        self.gen_data()

    def gen_data(self):
        """Populate ``self.data`` with the true and false pairs derived from ``self.pairs``.

        Raises:
            ValueError: if false pairs are requested for a protein but no ligand other than
                its true one is available to draw from (sampling could never terminate).
        """
        # Ligands eligible as false partners: ids 1..len(ligands) that occur in some
        # ground-truth pair (assumes ligand ids are the strings "1".."N" — TODO confirm).
        # Built once, instead of scanning gt_pairs.values() on every random draw.
        gt_lids = set(train_processor.gt_pairs.values())
        candidate_lids = [
            str(k) for k in range(1, len(train_processor.ligands) + 1) if str(k) in gt_lids
        ]
        candidate_set = set(candidate_lids)

        smiles, grids, labels = [], [], []
        for pid, lid in self.pairs:
            true_lid = str(int(lid))  # normalise to the string key used by train_processor.ligands
            n_choices = len(candidate_lids) - (1 if true_lid in candidate_set else 0)
            if self.batch_size > 1 and n_choices < 1:
                raise ValueError(
                    f'cannot build false pairs for protein {pid}: '
                    f'no ligand other than {true_lid} is available'
                )

            for i in range(self.batch_size):
                if i == 0:
                    sample_lid, label = true_lid, 1.0
                else:
                    # Always exclude the TRUE ligand (not just the previously drawn one),
                    # so a false pair can never be the true pair labelled 0.0.
                    sample_lid, label = self._draw_false_lid(candidate_lids, true_lid), 0.0
                ligand = train_processor.ligands[sample_lid]
                grid = BindGrid(train_processor.proteins[pid], ligand, train_processor.centroids[pid])
                augmented = grid.rotation_augment() if self.rot_aug else [grid.grid]
                for g in augmented:
                    grids.append(g)
                    smiles.append(ligand.smiles)
                    labels.append(label)

        if not labels:  # nothing to embed (empty pairs or batch_size < 1)
            return
        embeds = train_processor.generate_embeddings(smiles)
        for i, label in enumerate(labels):
            self.data.append(TrainDataItem(grids[i], embeds[i], label))

    @staticmethod
    def _draw_false_lid(candidate_lids, true_lid):
        """Draw a ligand id uniformly from ``candidate_lids`` other than ``true_lid``.

        The caller must ensure at least one candidate differs from ``true_lid``.
        """
        while True:
            lid = random.choice(candidate_lids)
            if lid != true_lid:
                return lid

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return item.grid, item.embed, item.label
        

In [25]:
# Small smoke-test dataset: the first 5 ground-truth pairs, 2 samples per pair
# (1 true + 1 generated false pair), with rotation augmentation enabled.
N_DEBUG_PAIRS = 5
SAMPLES_PER_PAIR = 2
debug_pairs = list(train_processor.gt_pairs.items())[:N_DEBUG_PAIRS]
train_dataset = ProteinLigandDataset(debug_pairs, SAMPLES_PER_PAIR, rot_aug=True)

In [26]:
# number of items: one per (pair, sample, augmented grid); shown as the cell's output
len(train_dataset)

40


In [None]:
# prepare the data from the pair ids given, returns a dataloader
# a pair id is just the index of that pair in pairs.csv
def prepare_data(batch_size, pair_ids):
    """Build a DataLoader over the pairs selected by ``pair_ids``.

    TODO: not implemented yet — this placeholder keeps the notebook runnable end to end.

    Args:
        batch_size: batch size (TODO: decide whether this is the DataLoader batch size or
            ProteinLigandDataset's samples-per-pair ``batch_size``).
        pair_ids: indices of the pairs in pairs.csv.

    Raises:
        NotImplementedError: always, until the body is written.
    """
    raise NotImplementedError('prepare_data is not implemented yet')

### Training Functions

In [2]:
# Loss for one batch: each sample's label is broadcast across every element of the
# model's output row before loss_fn is applied.
def train_loss(pred, target, loss_fn):
    """Apply ``loss_fn`` between ``pred`` and ``target`` broadcast to ``pred``'s shape.

    Args:
        pred: model output of shape ``(batch, D)`` (presumably D = 1024 — TODO confirm).
        target: labels of shape ``(batch,)`` (what default DataLoader collation of float
            labels produces) or ``(batch, 1)``.
        loss_fn: a loss taking two same-shaped tensors, e.g. ``nn.BCELoss`` or ``nn.MSELoss``.

    Returns:
        Whatever ``loss_fn`` returns (a scalar tensor for the default reductions).
    """
    if target.dim() == 1:
        # expand() cannot create a new trailing dim, so give 1-D labels an explicit column
        target = target.unsqueeze(-1)
    # Match pred's dtype/device: collated Python floats arrive as float64 on the CPU,
    # which some losses (e.g. BCELoss) reject against float32 or GPU predictions.
    target = target.to(device=pred.device, dtype=pred.dtype)
    return loss_fn(pred, target.expand(-1, pred.size(1)))

In [3]:
# the training function: one epoch over the loader
def train(loader, model, loss_fn, optimizer, lapse):
    """Run one training epoch over ``loader``, printing progress every ``lapse`` batches.

    Args:
        loader: iterable of batches, typically a ``torch.utils.data.DataLoader``. Each batch
            is either ``(grid, embed, label)`` (what ``ProteinLigandDataset`` yields) or
            ``((x1, x2), y)``.
        model: module called as ``model(x1, x2)``.
        loss_fn: loss applied through ``train_loss``.
        optimizer: optimizer over ``model``'s parameters.
        lapse: print the current loss every ``lapse`` batches; 0/None disables printing.
    """
    model.train()  # enable training-mode behavior of dropout / batch-norm layers
    # Move each batch to wherever the model's parameters live (e.g. the GPU chosen as DEVICE).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None
    dataset = getattr(loader, 'dataset', None)
    size = len(dataset) if dataset is not None else '?'

    seen = 0
    for batch, sample in enumerate(loader):
        if len(sample) == 3:
            x1, x2, y = sample
        else:
            (x1, x2), y = sample
        if device is not None:
            x1, x2, y = x1.to(device), x2.to(device), y.to(device)

        # Compute prediction and loss
        pred = model(x1, x2)
        loss = train_loss(pred, y, loss_fn)

        # Back Propagate
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        seen += len(y)  # count samples, not input tensors
        if lapse and batch % lapse == 0:
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5}]")