# Training the Model
In this notebook I will train the derived model (AssemblyModel) to predict protein-ligand affinity. I choose to do this in a notebook because results are easier to track and manipulate.

In [1]:
# imports, solves custom package importing by appending cwd to system paths
import os, sys
parent_dir = os.path.abspath('..')
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import torch
from torch import nn
from src.models.assembly_model import AssemblyModel

In [2]:
import multiprocessing

K_CORES = multiprocessing.cpu_count()
print("Running threads on {} cpu cores.".format(K_CORES))

Running threads on 16 cpu cores.


In [3]:
# use gpu if it's available
DEVICE = 'cpu'
if torch.cuda.is_available():
    print('Detected GPU {}, training will run on it.'.format(torch.cuda.get_device_name(0)))
    DEVICE = 'cuda'
else:
    print('No GPUs available, will run on cpu.')

Detected GPU NVIDIA GeForce RTX 3080, training will run on it.


In [4]:
# glboal vars controlling training logic
GEN_DATA = False  # will not generate the training data again on disk
# directories to save the generated data in and how we divide the data
train_save_dir = '../data/generated/train'
valid_save_dir = '../data/generated/valid'
train_ratio = 0.8
chunk_size = 30  # we divide data into chunks to save on disk
total_pairs_n = 3000  # total number of pairs

## Preparation
We prepare the training data and functions in this section.

### Training Data

In [5]:
# We choose to have a global preprocess here to save memory.
# The train/test datasets we are gonna later will simply query it.
from src.preprocess.preprocess import TrainPreprocessor

if GEN_DATA:
    train_processor = TrainPreprocessor()

In [6]:
# as the datasets takes a long time to generate (because of the loops in voxelization)
# we cache them on disk in separate batches.
# we accelerate with multi threads
from tqdm.notebook import tqdm
from multiprocessing import Process
from src.gen_dataset import gen_dataset

def gen_dataset_threaded(pairs, cache_dir, rot_aug=True, batch_size=2, chunk_size=30, cache_on_disk=True, k_core=4):
    with tqdm(total=len(pairs)) as pbar:
        for i in range(0, len(pairs), chunk_size * k_core):
            ts = []
            for ti in range(k_core):
                start = i + chunk_size * ti
                if start > len(pairs):
                    break
                end = start + chunk_size
                end = end if end < len(pairs) else len(pairs)
                ts.append(Process(target=gen_dataset, 
                                  args=(pairs[start:end], cache_dir, 
                                        start//chunk_size, train_processor, 
                                        rot_aug, batch_size, cache_on_disk)))
                ts[-1].start()
            for t in ts:
                t.join()
            pbar.write('Processed {} pairs.'.format(i + chunk_size * k_core))
            pbar.update(chunk_size * k_core)

Only run this cell if you want to generate the data again. It will take a long time on cpu!

---

In [7]:
import random

# generate train
if __name__=='__main__':
    if GEN_DATA:
        # generate 80% train and 20% validation
        pairs = list(train_processor.gt_pairs.items())
        pairs_train = random.sample(pairs, int(len(pairs) * train_ratio))
        pairs_valid = [x for x in pairs if x not in pairs_train]
        
        gen_dataset_threaded(pairs_train, train_save_dir, k_core=K_CORES-1)
        gen_dataset_threaded(pairs_valid, valid_save_dir, k_core=K_CORES-1)

---

### Training Functions

In [8]:
# for training loss, we combine l1 and l2 with weights
def train_loss(pred, target):
    eps = 1e-5
    pred = pred.to('cpu')
    target = target.reshape(-1, 1).expand(-1, pred.shape[1])
    loss = -target * torch.log(pred + eps) - (1 - target) * torch.log(1 - pred + eps)
    loss = torch.mean(loss, dim=1)
    loss = torch.sum(loss)
    return loss

In [9]:
def valid_loss(pred, target):
    with torch.no_grad():
        eps = 1e-5
        pred = pred.to('cpu')
        loss = -torch.mean(target * torch.log(pred + eps) + (1 - target) * torch.log(1 - pred + eps))
        return loss

In [10]:
# the training function
def train_one_epoch(epoch_index, batch_rp_size, tb_writer, loader, model, optimizer, loss_fn):
    running_loss = 0.0
    last_loss = 0.0
    
    for batch, X in enumerate(loader):
        x1 = X[0]  # grid
        x2 = X[1]  # embeds
        target = X[2]  # labels
        
        optimizer.zero_grad()
        
        # Compute prediction and loss
        pred = model(x1, x2)
        loss = loss_fn(pred.float(), target.float())
        
        # Back Propagate
        loss.backward()
        optimizer.step()
        
        # Gather data and report to tensorboard
        running_loss += loss.item()
        if batch % batch_rp_size == batch_rp_size - 1:
            last_loss = running_loss / batch_rp_size # loss per batch
            tb_x = epoch_index * len(loader) + batch + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.0
        
    return last_loss

## Training
With everything set, we can start to train the model now.

In [11]:
from src.models.assembly_model import AssemblyModel

# training hyperparams
model_save_dir = '../models'
tb_save_dir = '../models/runs'

TOTAL_EPOCHS = 60  # the total epochs we use to go through the entire dataset
epoch_number = 0
best_vloss = 1_000_000
batch_size = 32
learning_rate = 0.001
early_stop_t = 0.1  # threshold for eTOTAL_EPOCHStopping

# create model
model = AssemblyModel(device=DEVICE)

# optimisation
loss_fn = train_loss
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=4e-5)

In [12]:
# logging
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('{}/train_{}'.format(tb_save_dir, timestamp))

### Training Loops

In [13]:
# as training data is too large, we try to dynamically load it in chunks from disk as we train
from src.gen_dataset import ProteinLigandDataset


with tqdm(total=TOTAL_EPOCHS * 10) as pbar:
    for i in range(TOTAL_EPOCHS):

        tds_indices = list(range(80))
        vds_indices = list(range(20))
        random.shuffle(tds_indices)
        random.shuffle(vds_indices)

        for ds in range(0, 20, 2):
            tds = ds * 4
            # load generated data from disk
            train_set = ProteinLigandDataset([], None, 0, rot_aug=True)
            valid_set = ProteinLigandDataset([], None, 0, rot_aug=True)

            for i in range(tds, tds + 8):
                s = torch.load('{}/{}.data'.format(train_save_dir, tds_indices[i]))
                train_set.concat(s)

            for i in range(ds, ds + 2):
                s = torch.load('{}/{}.data'.format(valid_save_dir, vds_indices[i]))
                valid_set.concat(s)

            train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=K_CORES//2)
            valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=True, num_workers=K_CORES//2)

            pbar.write('EPOCH {}:'.format(epoch_number + 1))

            # Make sure gradient tracking is on, and do a pass over the data
            model.train(True)
            avg_loss = train_one_epoch(epoch_number, 4, writer, train_loader, model, optimizer, loss_fn)


            # We don't need gradients on to do reporting
            model.train(False)

            running_vloss = 0.0
            for i, vdata in enumerate(valid_loader):
                vx1, vx2, vlabels = vdata
                voutputs = model(vx1, vx2)
                vloss = valid_loss(voutputs, vlabels)
                running_vloss += vloss

            avg_vloss = running_vloss / (i + 1)
            pbar.write('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

            # Log the running loss averaged per batch
            # for both training and validation
            writer.add_scalars('Training vs. Validation Loss',
                            { 'Training' : avg_loss, 'Validation' : avg_vloss },
                            epoch_number + 1)
            writer.flush()

            # Track best performance, and save the model's state
            if avg_vloss < best_vloss:
                best_vloss = avg_vloss
                model_path = '{}/model_{}_{}'.format(model_save_dir, timestamp, epoch_number)
                torch.save(model.state_dict(), model_path)

            epoch_number += 1
            pbar.update(1)

  0%|          | 0/600 [00:00<?, ?it/s]

EPOCH 1:
LOSS train 23.828877449035645 valid 0.7345786271577405
EPOCH 2:
LOSS train 25.104795932769775 valid 1.2149989417820226
EPOCH 3:
LOSS train 24.717485904693604 valid 0.7343309702023758
EPOCH 4:
LOSS train 26.88937282562256 valid 0.7317704448323591
EPOCH 5:
LOSS train 26.84062623977661 valid 0.7121472005261845
EPOCH 6:
LOSS train 26.15716314315796 valid 0.7581125077826769
EPOCH 7:
LOSS train 26.850772857666016 valid 0.8837772757607318
EPOCH 8:
LOSS train 27.981881618499756 valid 1.078766767382019
EPOCH 9:
LOSS train 27.311137199401855 valid 0.7792108246650352
EPOCH 10:
LOSS train 23.091619968414307 valid 0.7864697836312861
EPOCH 11:
LOSS train 24.98884391784668 valid 0.8572819390791765
EPOCH 12:
LOSS train 24.19709825515747 valid 0.7719017570202037
EPOCH 13:
LOSS train 30.18438959121704 valid 0.8641032175451229
EPOCH 14:
LOSS train 27.71218490600586 valid 0.8412620658434745
EPOCH 15:
LOSS train 25.20663356781006 valid 0.817432746189867
EPOCH 16:
LOSS train 24.960410594940186 vali

KeyboardInterrupt: 

### Example Predicting Pairs

In [6]:
from src.gen_dataset import ProteinLigandDataset

model_path = '../models/model_20220228_223914_20'
model = AssemblyModel(device=DEVICE)
model.load_state_dict(torch.load(model_path))

train_set = ProteinLigandDataset([], None, 0, rot_aug=True)
s = torch.load('{}/{}.data'.format(train_save_dir, 1))
train_set.concat(s)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True, num_workers=K_CORES//2)

preds = []
labels = []

for i, X in enumerate(train_loader):
    x1, x2, label = X
    model.train(False)
    pred = model(x1, x2)
    preds.append(pred)
    labels.append(label)

In [7]:
print(preds[:5])

[tensor([[0.5165],
        [0.5179],
        [0.5050],
        [0.5156]], device='cuda:0', grad_fn=<SigmoidBackward0>), tensor([[0.5239],
        [0.5128],
        [0.5833],
        [0.5146]], device='cuda:0', grad_fn=<SigmoidBackward0>), tensor([[0.5136],
        [0.5146],
        [0.5169],
        [0.5085]], device='cuda:0', grad_fn=<SigmoidBackward0>), tensor([[0.8928],
        [0.6253],
        [0.5193],
        [0.5146]], device='cuda:0', grad_fn=<SigmoidBackward0>), tensor([[0.5193],
        [0.5083],
        [0.5141],
        [0.6421]], device='cuda:0', grad_fn=<SigmoidBackward0>)]


In [8]:
print(labels[:5])

[tensor([0., 0., 0., 1.], dtype=torch.float64), tensor([0., 1., 0., 1.], dtype=torch.float64), tensor([1., 1., 0., 1.], dtype=torch.float64), tensor([0., 1., 0., 1.], dtype=torch.float64), tensor([0., 1., 1., 0.], dtype=torch.float64)]
