# Training the Model
In this notebook I will train the derived model (AssemblyModel) to predict protein-ligand affinity. I choose to do this in a notebook because results are easier to track and manipulate.

In [8]:
# imports, solves custom package importing by appending cwd to system paths
import os, sys
parent_dir = os.path.abspath('..')
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import torch
from torch import nn
from src.models.assembly_model import AssemblyModel

In [9]:
import multiprocessing

K_CORES = multiprocessing.cpu_count()
print("Running threads on {} cpu cores.".format(K_CORES))

Running threads on 16 cpu cores.


In [10]:
# use gpu if it's available
DEVICE = 'cpu'
if torch.cuda.is_available():
    print('Detected GPU {}, training will run on it.'.format(torch.cuda.get_device_name(0)))
    DEVICE = 'cuda'
else:
    print('No GPUs available, will run on cpu.')

No GPUs available, will run on cpu.


In [18]:
# glboal vars controlling training logic
GEN_DATA = False  # will not generate the training data again on disk
# directories to save the generated data in and how we divide the data
train_save_dir = '../data/generated/train'
valid_save_dir = '../data/generated/valid'
train_ratio = 0.8
chunk_size = 30  # we divide data into chunks to save on disk
total_pairs_n = 3000  # total number of pairs

## Preparation
We prepare the training data and functions in this section.

### Training Data

In [12]:
# We choose to have a global preprocess here to save memory.
# The train/test datasets we are gonna later will simply query it.
from src.preprocess.preprocess import TrainPreprocessor

if GEN_DATA:
    train_processor = TrainPreprocessor()

In [13]:
# as the datasets takes a long time to generate (because of the loops in voxelization)
# we cache them on disk in separate batches.
# we accelerate with multi threads
from tqdm.notebook import tqdm
from multiprocessing import Process
from src.gen_dataset import gen_dataset

def gen_dataset_threaded(pairs, cache_dir, rot_aug=True, batch_size=2, chunk_size=30, cache_on_disk=True, k_core=4):
    with tqdm(total=len(pairs)) as pbar:
        for i in range(0, len(pairs), chunk_size * k_core):
            ts = []
            for ti in range(k_core):
                start = i + chunk_size * ti
                if start > len(pairs):
                    break
                end = start + chunk_size
                end = end if end < len(pairs) else len(pairs)
                ts.append(Process(target=gen_dataset, 
                                  args=(pairs[start:end], cache_dir, 
                                        start//chunk_size, train_processor, 
                                        rot_aug, batch_size, cache_on_disk)))
                ts[-1].start()
            for t in ts:
                t.join()
            pbar.write('Processed {} pairs.'.format(i + chunk_size * k_core))
            pbar.update(chunk_size * k_core)

Only run this cell if you want to generate the data again. It will take a long time on cpu!

---

In [14]:
import random

# generate train
if __name__=='__main__':
    if GEN_DATA:
        # generate 80% train and 20% validation
        pairs = list(train_processor.gt_pairs.items())
        pairs_train = random.sample(pairs, int(len(pairs) * train_ratio))
        pairs_valid = [x for x in pairs if x not in pairs_train]
        
        gen_dataset_threaded(pairs_train, train_save_dir, k_core=K_CORES-1)
        gen_dataset_threaded(pairs_valid, valid_save_dir, k_core=K_CORES-1)

---

In [20]:
# load generated data from disk
from src.gen_dataset import ProteinLigandDataset

train_set = ProteinLigandDataset()
valid_set = ProteinLigandDataset()

for i in range(total_pairs_n * train_ratio // chunk_size + 1):
    s = torch.load('{}/{}.data'.format(train_save_dir, i))
    train_set.concat(s)

for j in range((total_pairs_n - valid_ratio) // chunk_size + 1):
    s = torch.load('{}/{}.data'.format(valid_save_dir, j))
    valid_set.concat(s)

TypeError: __init__() missing 3 required positional arguments: 'pairs', 'train_processor', and 'batch_size'

### Training Functions

In [None]:
# the train loss accumulates loss from each 1024 vector
def train_loss(pred, target, loss_fn):
    target = target.expand(-1, pred.shape[1])
    loss = loss_fn(pred, target)
    return loss

In [None]:
def valid_loss(pred, target, loss_fn):
    pred = torch.mean(pred, pred.shape[1])
    loss = loss_fn(pred, target)
    return loss

In [None]:
# the training function
def train_one_epoch(epoch_index, batch_size, tb_writer, loader, model, loss_fn, optimizer):
    running_loss = 0.
    last_loss = 0.
    
    for batch, X in enumerate(loader):
        x1 = X[0]  # grid
        x2 = X[1]  # embeds
        target = X[2]  # labels
        # Compute prediction and loss
        pred = model(x1, x2)
        loss = train_loss(pred, target, loss_fn)
        
        # Back Propagate
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Gather data and report to tensorboard
        running_loss += loss.item()
        if i % batch_size == batch_size - 1:
            last_loss = running_loss / batch_size # loss per batch
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
        
    return last_loss

## Training
With everything set, we can start to train the model now.

In [None]:
from src.models.assembly_model import AssemblyModel

# training hyperparams
model_save_dir = '../models'
tb_save_dir = '../models/runs'

EPOCHS = 50
best_loss = 1_000_000
batch_size = 4
learning_rate = 0.001
early_stop_t = 0.1  # threshold for ealy stopping

# create model
model = AssemblyModel(DEVICE)

# utilities
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# create dataloaders
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=K_CORES//2)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=True, num_workers=K_CORES//2)

In [None]:
# finally, let's train!
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('{}/train_{}'.format(tb_save_dir, timestamp))

epoch_number = 0

with tqdm(total=EPOCHS) as pbar:
    for epoch in range(EPOCHS):
        pbar.write('EPOCH {}:'.format(epoch_number + 1))

        # Make sure gradient tracking is on, and do a pass over the data
        model.train(True)
        avg_loss = train_one_epoch(epoch_number, batch_size, writer, train_loader, model, loss_fn, optimizer)

        # We don't need gradients on to do reporting
        model.train(False)

        running_vloss = 0.0
        for i, vdata in enumerate(valid_loader):
            vx1, vx2, vlabels = vdata
            voutputs = model(vx1, vx2)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

        avg_vloss = running_vloss / (i + 1)
        pbar.write('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

        # Log the running loss averaged per batch
        # for both training and validation
        writer.add_scalars('Training vs. Validation Loss',
                        { 'Training' : avg_loss, 'Validation' : avg_vloss },
                        epoch_number + 1)
        writer.flush()

        # Track best performance, and save the model's state
        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            model_path = '{}/model_{}_{}'.format(model_save_dir, timestamp, epoch_number)
            torch.save(model.state_dict(), model_path)

        epoch_number += 1
        pbar.update(epoch_number)