# Training the Model
In this notebook I will train the derived model (AssemblyModel) to predict protein-ligand affinity. I choose to do this in a notebook because results are easier to track and manipulate.

In [1]:
# imports, solves custom package importing by appending cwd to system paths
import os, sys
parent_dir = os.path.abspath('..')
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import torch
from torch import nn
from src.models.assembly_model import AssemblyModel

In [2]:
import multiprocessing

K_CORES = multiprocessing.cpu_count()
print("Running threads on {} cpu cores.".format(K_CORES))

Running threads on 16 cpu cores.


In [3]:
# use gpu if it's available
DEVICE = 'cpu'
if torch.cuda.is_available():
    print('Detected GPU {}, training will run on it.'.format(torch.cuda.get_device_name(0)))
    DEVICE = 'cuda'
else:
    print('No GPUs available, will run on cpu.')

Detected GPU NVIDIA GeForce RTX 3080, training will run on it.


## Preparation
We prepare the training data and functions in this section.

### Training Data

In [4]:
# We choose to have a global preprocess here to save memory.
# The train/test datasets we are gonna later will simply query it.
from src.preprocess.preprocess import TrainPreprocessor

train_processor = TrainPreprocessor()

In [5]:
# as the datasets takes a long time to generate (because of the triple loops in voxelization)
# we cache them on disk in separate batches.
# we accelerate with multi threads
from tqdm.notebook import tqdm
from multiprocessing import Process
from src.gen_dataset import gen_dataset

def gen_dataset_threaded(pairs, cache_dir, rot_aug=True, batch_size=2, chunk_size=30, cache_on_disk=True, k_core=4):
    with tqdm(total=len(pairs)) as pbar:
        for i in range(0, len(pairs), chunk_size * k_core):
            ts = []
            for ti in range(k_core):
                start = i + chunk_size * ti
                if start > len(pairs):
                    break
                end = start + chunk_size
                end = end if end < len(pairs) else len(pairs)
                ts.append(Process(target=gen_dataset, 
                                  args=(pairs[start:end], cache_dir, 
                                        start//chunk_size, train_processor, 
                                        rot_aug, batch_size, cache_on_disk)))
                ts[-1].start()
            for t in ts:
                t.join()
            pbar.write('Processed {} pairs.'.format(i + chunk_size * k_core))
            pbar.update(chunk_size * k_core)

In [6]:
# generate 80% train and 20% validation
import random

train_save_dir = '../data/generated/train'
valid_save_dir = '../data/generated/valid'

pairs = list(train_processor.gt_pairs.items())
pairs_train = random.sample(pairs, int(len(pairs) * 0.8))
pairs_valid = [x for x in pairs if x not in pairs_train]

Only run these two cells if you want to generate the data again.

---

In [7]:
# generate train
if __name__=='__main__':
    gen_dataset_threaded(pairs_train, train_save_dir, k_core=K_CORES-1)

  0%|          | 0/2394 [00:00<?, ?it/s]

Processed 450 pairs.
Processed 900 pairs.
Processed 1350 pairs.
Processed 1800 pairs.
Processed 2250 pairs.
Processed 2700 pairs.


In [8]:
# generate valid
if __name__=='__main__':
    gen_dataset_threaded(pairs_valid, valid_save_dir, k_core=K_CORES-1)

  0%|          | 0/599 [00:00<?, ?it/s]

Processed 450 pairs.
Processed 900 pairs.


---

In [None]:
# prepare the data from the pair ids given, returns a dataloader
# a pair id is just the index of that pair in pairs.csv
def prepare_data(batch_size, pair_ids):
    

### Training Functions

In [None]:
# the train loss accumulates loss from each 1024 vector
def train_loss(pred, target, loss_fn):
    target = target.expand(-1, pred.size()[1])
    loss = loss_fn(pred, target)
    return loss

In [None]:
# the training function
def train(loader, model, loss_fn, optimizer, lapse):
    for batch, (X, y) in enumerate(loader):
        x1 = X[0]
        x2 = X[1]
        # Compute prediction and loss
        pred = model(x1, x2)
        loss = train_loss(pred, target, loss_fn)
        
        # Back Propagate
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch % lapse == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")