# Training the Model
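In this notebook I train the derived model (`AssemblyModel`) to predict protein-ligand affinity. The task is framed as binary classification: each example pairs a protein with a ligand, labelled 1 for a ground-truth pair and 0 for a randomly sampled negative. I use a notebook because results are easier to track and manipulate here.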

In [None]:
# imports; append the parent directory to sys.path so the local `src` package can be imported
import os, sys
parent_dir = os.path.abspath('..')
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import torch
from torch import nn
from src.models.assembly_model import AssemblyModel

In [None]:
# Select the compute device once; later cells read DEVICE.
if not torch.cuda.is_available():
    DEVICE = 'cpu'
    print('No GPUs available, will run on cpu.')
else:
    print('Detected GPU {}, training will run on it.'.format(torch.cuda.get_device_name(0)))
    DEVICE = 'cuda'

## Preparation
We prepare the training data and functions in this section.

### Training Data

In [None]:
# A single, global preprocessor is built here to save memory: the train/validation
# datasets created later do not hold their own copies, they read from this object
# (its proteins, ligands, centroids and gt_pairs, and its generate_embeddings method).
from src.preprocess.preprocess import TrainPreprocessor

train_processor = TrainPreprocessor()

In [None]:
# define a custom dataset format for our data
import random
import pandas as pd
from typing import NamedTuple
from torch.utils.data import Dataset
from src.preprocess.bind_grid import BindGrid
from src.global_vars import AUGMENT_ROTATION


# One training example: the protein-ligand grid produced by BindGrid, the ligand's
# embedding from train_processor.generate_embeddings, and a label
# (1.0 = ground-truth pair, 0.0 = sampled negative).
DataItem = NamedTuple('DataItem', [('grid', torch.Tensor), ('embed', torch.Tensor), ('label', float)])

class ProteinLigandDataset(Dataset):
    """Protein-ligand binding examples built from ground-truth pairs.

    For every (protein_id, ligand_id) pair the dataset holds ``batch_size`` ligands:
    the true ligand (label 1.0) followed by ``batch_size - 1`` ligands drawn from the
    other ground-truth ligands (label 0.0). Each ligand is combined with the protein
    through ``BindGrid``. With ``rot_aug`` the single grid is replaced by the list
    returned by ``BindGrid.rotation_augment()``, and every grid in that list shares
    the same ligand embedding.

    All source data is read from the global ``train_processor``.
    """

    def __init__(self, pairs, batch_size, rot_aug=True):
        self.pairs = pairs  # a list of (protein_id, ligand_id) pairs from querying pair.csv
        self.batch_size = batch_size  # ligands per protein: 1 true + (batch_size - 1) negatives
        self.rot_aug = rot_aug  # augment data by random grid rotation
        self.data = []
        self.gen_data()

    @staticmethod
    def _sample_negative_lid(true_lid, candidate_lids):
        """Return a ligand id drawn uniformly from ``candidate_lids``, excluding ``true_lid``.

        Raises:
            ValueError: if no candidate differs from ``true_lid``.
        """
        choices = [lid for lid in candidate_lids if lid != true_lid]
        if not choices:
            raise ValueError('no negative ligand available for ligand {}'.format(true_lid))
        return random.choice(choices)

    def gen_data(self):
        """Append DataItem(grid, embed, label) entries for every pair in ``self.pairs``."""
        # Distinct ground-truth ligand ids, computed once and sorted so that negative
        # sampling is reproducible under a fixed `random` seed.
        # NOTE(review): assumes gt_pairs values use the same string-id format as the
        # keys of train_processor.ligands -- confirm.
        candidate_lids = sorted(set(train_processor.gt_pairs.values()))

        smiles = []        # one SMILES per (protein, ligand) combination
        grids = []         # one entry per (possibly rotated) grid
        labels = []        # label for the matching entry of `grids`
        smiles_index = []  # for each entry of `grids`, the index of its SMILES in `smiles`
        for pid, raw_lid in self.pairs:
            # normalise the id to the str(int) form used to index train_processor.ligands
            true_lid = str(int(raw_lid))
            for k in range(self.batch_size):
                is_positive = k == 0
                # Negatives are always checked against the TRUE ligand, never against the
                # previously drawn negative, so a positive can't be relabelled as 0.0.
                lid = true_lid if is_positive else self._sample_negative_lid(true_lid, candidate_lids)
                ligand = train_processor.ligands[lid]
                grid = BindGrid(train_processor.proteins[pid], ligand, train_processor.centroids[pid])
                smiles.append(ligand.smiles)
                views = grid.rotation_augment() if self.rot_aug else [grid.grid]
                for view in views:
                    grids.append(view)
                    labels.append(1.0 if is_positive else 0.0)
                    # Record the owning SMILES explicitly so grids and embeddings stay
                    # aligned however many views rotation_augment() returns (or none).
                    smiles_index.append(len(smiles) - 1)

        embeds = train_processor.generate_embeddings(smiles)
        self.data.extend(
            DataItem(g, embeds[s], y) for g, s, y in zip(grids, smiles_index, labels)
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return item.grid, item.embed, item.label
        

In [None]:
# Generating the datasets takes a long time (because of the loops in voxelization),
# so we build them in chunks and cache each chunk on disk.
from tqdm.notebook import tqdm

def gen_dataset(pairs, cache_dir, rot_aug=True, batch_size=2, chunk_size=30, cache_on_disk=True):
    """Build ProteinLigandDataset chunks for ``pairs`` and optionally cache them on disk.

    ``pairs`` is split into consecutive chunks of ``chunk_size`` pairs, and each chunk
    becomes one ProteinLigandDataset.

    Args:
        pairs: sequence of (protein_id, ligand_id) pairs.
        cache_dir: directory for cached chunks; created if missing. A chunk starting
            at index ``i`` of ``pairs`` is written to ``<cache_dir>/<i>.data``.
        rot_aug: forwarded to ProteinLigandDataset.
        batch_size: forwarded to ProteinLigandDataset (ligands per protein).
        chunk_size: number of pairs per chunk; must be >= 1.
        cache_on_disk: if True, each chunk is saved with ``torch.save`` and not kept
            in memory. If False, nothing is written and the chunks are returned.

    Returns:
        The list of written file paths when ``cache_on_disk`` is True, otherwise the
        list of in-memory datasets.

    Raises:
        ValueError: if ``chunk_size`` < 1.
    """
    if chunk_size < 1:
        raise ValueError('chunk_size must be >= 1, got {}'.format(chunk_size))
    if cache_on_disk:
        os.makedirs(cache_dir, exist_ok=True)

    results = []
    with tqdm(total=len(pairs)) as pbar:
        for start in range(0, len(pairs), chunk_size):
            chunk = pairs[start:start + chunk_size]  # slicing already clamps at len(pairs)
            dataset = ProteinLigandDataset(chunk, batch_size, rot_aug)
            if cache_on_disk:
                # NOTE: this pickles the whole Dataset object. On recent PyTorch versions,
                # loading it back needs torch.load(path, weights_only=False).
                path = os.path.join(cache_dir, '{}.data'.format(start))
                torch.save(dataset, path)
                results.append(path)
            else:
                results.append(dataset)
            pbar.write('Processed {} pairs.'.format(len(chunk)))
            pbar.update(len(chunk))
    return results

In [None]:
# generate 80% train and 20% validation
# The split is seeded so that re-running this cell (e.g. after a kernel restart)
# reproduces the split the cached chunks were generated from, provided they were
# generated with the same seed.
SPLIT_SEED = 0
TRAIN_FRACTION = 0.8

train_save_dir = '../data/generated/train'
valid_save_dir = '../data/generated/valid'

pairs = list(train_processor.gt_pairs.items())
split_rng = random.Random(SPLIT_SEED)  # private RNG: does not disturb the global `random` state
pairs_train = split_rng.sample(pairs, int(len(pairs) * TRAIN_FRACTION))
train_lookup = set(pairs_train)  # O(1) membership instead of scanning the train list per pair
pairs_valid = [p for p in pairs if p not in train_lookup]  # keeps the original pair order

Only run the next two cells if you want to regenerate the cached data.

---

In [None]:
# Build and cache the training chunks (slow: voxelizes every pair).
gen_dataset(pairs=pairs_train, cache_dir=train_save_dir);

In [None]:
# Build and cache the validation chunks (slow: voxelizes every pair).
gen_dataset(pairs=pairs_valid, cache_dir=valid_save_dir);

---

In [None]:
# prepare the data from the pair ids given, returns a dataloader
# a pair id is just the index of that pair in pairs.csv
def prepare_data(batch_size, pair_ids):
    """Build a DataLoader over the pairs selected by ``pair_ids``.

    TODO: not implemented yet. Still to decide: rebuild ProteinLigandDataset from
    ``pair_ids``, or load the chunks cached on disk by ``gen_dataset``.

    Args:
        batch_size: presumably the DataLoader batch size -- TODO confirm.
        pair_ids: indices of the selected pairs in pairs.csv.

    Raises:
        NotImplementedError: always, until the loading strategy is implemented.
    """
    raise NotImplementedError('prepare_data is not implemented yet')

### Training Functions

In [None]:
# The train loss compares every output column of `pred` (presumably the 1024-d
# model output -- TODO confirm) against the example's single label.
def train_loss(pred, target, loss_fn):
    """Apply ``loss_fn`` after broadcasting one label per example across ``pred``.

    Args:
        pred: model output whose first dimension is the batch.
        target: labels with the same shape as ``pred``, or one label per example
            (e.g. shape ``(batch,)`` as produced by default DataLoader collation,
            or ``(batch, 1)``). Any dtype/device; it is converted to ``pred``'s.
        loss_fn: callable ``loss_fn(pred, target)``.

    Returns:
        Whatever ``loss_fn`` returns (typically a scalar tensor).
    """
    target = torch.as_tensor(target, dtype=pred.dtype, device=pred.device)
    if target.shape != pred.shape:
        # one label per example -> (batch, 1, ..., 1), then broadcast across pred
        target = target.reshape(pred.shape[0], *([1] * (pred.dim() - 1))).expand_as(pred)
    return loss_fn(pred, target)

In [None]:
# the training function
def train(loader, model, loss_fn, optimizer, lapse, device=None):
    """Train ``model`` for one pass over ``loader``.

    Args:
        loader: iterable of batches. Each batch is either ``(grid, embed, label)``,
            which is what a DataLoader over ProteinLigandDataset yields with default
            collation, or ``((grid, embed), label)``.
        model: module called as ``model(grid, embed)``.
        loss_fn: loss callable forwarded to ``train_loss``.
        optimizer: optimizer over the model's parameters.
        lapse: print the current loss every ``lapse`` batches; must be >= 1.
        device: device the inputs are moved to before the forward pass. Defaults to
            the device of the model's first parameter. If the model has no
            parameters, the inputs are left where they are.

    Returns:
        The mean batch loss over the pass, or ``nan`` if ``loader`` was empty.

    Raises:
        ValueError: if ``lapse`` < 1.
    """
    if lapse < 1:
        raise ValueError('lapse must be >= 1, got {}'.format(lapse))
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = None if first_param is None else first_param.device
    try:
        size = len(loader.dataset)  # total number of samples, for the progress line
    except (AttributeError, TypeError):
        size = '?'

    model.train()  # make sure training-mode behaviour (dropout, batch-norm updates) is on
    total_loss, n_batches, seen = 0.0, 0, 0
    for batch_idx, batch in enumerate(loader):
        *inputs, y = batch
        if len(inputs) == 1:  # ((grid, embed), label) layout
            inputs = list(inputs[0])
        x1, x2 = inputs
        if device is not None:
            x1, x2 = x1.to(device), x2.to(device)

        # Compute prediction and loss. Default collation turns float labels into a
        # float64 tensor of shape (batch,); match pred's dtype/device and make it (batch, 1).
        pred = model(x1, x2)
        y = y.to(device=pred.device, dtype=pred.dtype)
        if y.dim() == 1:
            y = y.unsqueeze(1)
        loss = train_loss(pred, y, loss_fn)

        # Back Propagate
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        n_batches += 1
        seen += len(y)  # samples processed so far, including this batch
        if batch_idx % lapse == 0:
            print(f"loss: {batch_loss:>7f}  [{seen:>5}/{size:>5}]")

    return total_loss / n_batches if n_batches else float('nan')