# Model Training Notebook
- This notebook shows how to reproduce the model's results or to fine-tune the model.
- It covers interface preprocessing, data loading, and model training.

In [1]:
import os
import tqdm
import numpy as np
import pandas as pd
import random 
import copy

import torch
from torch import nn
from torch.utils.data import DataLoader

from unimol_tools import utils
from unimol_tools.data import DataHub
from unimol_tools.models import UniMolModel

from src.dataset import PPIInhibitorDataset, process_interface
from src.model import PPIInhibitorModel
from src.utils import train, predict, performance_evaluation, batch_collate_fn

2024-10-29 10:03:38 | unimol_tools/weights/weighthub.py | 17 | INFO | Uni-Mol Tools | Weights will be downloaded to default directory: /data/dongok/anaconda3/envs/unimol/lib/python3.9/site-packages/unimol_tools/weights


In [2]:
# Enumerate GPUs in PCI bus order and expose only GPU index 1 to this process.
# NOTE(review): CUDA reads these variables when it is first initialised. `torch` and
# `unimol_tools` are already imported above, so this relies on neither having touched
# CUDA at import time -- consider moving this cell above the imports to be safe.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

In [3]:
# Run on the (visible) GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Seed every RNG the run depends on: Python's `random`, NumPy's global RNG,
# PyTorch's CPU generator, and the generators of all CUDA devices (in that order).
seed = 2022
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)

In [5]:
#config
# Training hyperparameters:
#   batch_size    -- mini-batch size shared by the train/valid/test loaders
#   learning_rate -- Adam learning rate
#   num_epochs    -- number of passes over the training split
batch_size, learning_rate, num_epochs = 64, 0.0005, 70

In [6]:
# Preprocess the raw PPI interface information and cache it as a CSV.
# NOTE(review): `to_csv` is called with its default index=True, so the frame's index is
# written as the first column -- presumably the reader in src.dataset expects this; confirm.
processed_interface = process_interface('data/ppi_interface.csv')
processed_interface.to_csv('data/processed_interface.csv')

In [7]:
#dataloader

#S1 (Unseen Interaction): Interactions between PPIs and compounds are not seen during training, while both PPIs and compounds are seen individually.
#S2 (Unseen Compound): Compounds are not seen during training; PPIs are seen.
#S3 (Unseen PPI): PPIs are not seen during training; compounds are seen.
#S4 (Unseen Both): Both PPIs and compounds are not seen during training.

eval_setting = 'S1'
fold = 1

# Build the three splits of the selected setting/fold, in train -> valid -> test order.
train_dataset, valid_dataset, test_dataset = (
    PPIInhibitorDataset(f'data/folds/{eval_setting}/{split}_fold{fold}.csv', device)
    for split in ('train', 'valid', 'test')
)

# Shared loader options; only the training split is shuffled.
loader_kwargs = dict(batch_size=batch_size, drop_last=False, collate_fn=batch_collate_fn)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_dataloader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [8]:
#model
# Compound branch: a Uni-Mol backbone initialised from the checkpoint at `unimol_path`.
# (Per the cell output, `UniMolModel()` already loads its bundled default weights;
# the explicit call below then loads the checkpoint from `unimol_path`.)
compound_model = UniMolModel()
unimol_path = 'src/weights/mol_pre_all_h_220816.pt'
compound_model.load_pretrained_weights(unimol_path)

model = PPIInhibitorModel(compound_model).to(device)

# Fine-tuning scheme: every parameter is trainable except those whose name contains
# 'compound_model', which are frozen -- unless the name also contains
# 'encoder.layers.14', in which case it stays trainable.
for param_name, param in model.named_parameters():
    in_compound_backbone = 'compound_model' in param_name
    in_tuned_layer = 'encoder.layers.14' in param_name
    param.requires_grad = in_tuned_layer or not in_compound_backbone

criterion = nn.BCEWithLogitsLoss()
# All parameters are handed to Adam; frozen ones never get a gradient and are skipped.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

2024-10-29 10:04:00 | unimol_tools/models/unimol.py | 120 | INFO | Uni-Mol Tools | Loading pretrained weights from /data/dongok/anaconda3/envs/unimol/lib/python3.9/site-packages/unimol_tools/weights/mol_pre_all_h_220816.pt
2024-10-29 10:04:01 | unimol_tools/models/unimol.py | 120 | INFO | Uni-Mol Tools | Loading pretrained weights from src/weights/mol_pre_all_h_220816.pt


In [None]:
#training loop
# Train for `num_epochs` epochs, evaluating on the validation split after each one.
# A deep copy of the model is kept whenever validation AUC strictly exceeds the best
# seen so far (starting from 0), so `best_model` holds the earliest best epoch.
best_model, best_auc, best_epoch = None, 0, 0

# Per-epoch histories for later inspection/plotting.
tr_aucs, tr_auprs, tr_losses = [], [], []
val_aucs, val_auprs = [], []

for epoch in range(num_epochs):
    print('Epoch {}'.format(epoch))

    # One optimisation pass over the training split; metrics use the returned labels/preds.
    label, pred, tr_loss = train(model, train_dataloader, optimizer, criterion, device)
    current_auc, current_aupr = performance_evaluation(label, pred)
    for history, value in ((tr_aucs, current_auc), (tr_auprs, current_aupr), (tr_losses, tr_loss)):
        history.append(value)
    print('Train AUC:\t{}'.format(current_auc))
    print('Train AUPR:\t{}'.format(current_aupr))

    # Score the validation split with the current weights.
    label, pred = predict(model, valid_dataloader, device)
    current_auc, current_aupr = performance_evaluation(label, pred)
    val_aucs.append(current_auc)
    val_auprs.append(current_aupr)
    print('Val AUC:\t{}'.format(current_auc))
    print('Val AUPR:\t{}'.format(current_aupr))

    # Strict `>`: ties (and NaN scores, since NaN > x is False) never replace the best.
    if current_auc > best_auc:
        best_model = copy.deepcopy(model)
        best_auc, best_epoch = current_auc, epoch
        print('AUC is improved at epoch {}\tbest AUC: {}'.format(best_epoch, best_auc))

In [None]:
#save the best model
# `best_model` is only assigned when validation AUC strictly exceeds the initial
# best_auc of 0. If that never happened (e.g. the metric was NaN every epoch),
# it is still None. Fail with a clear message instead of an AttributeError on None.
if best_model is None:
    raise RuntimeError(
        'No best model was selected during training: validation AUC never exceeded '
        'the initial best_auc, so there is nothing to save.'
    )
# NOTE: the path does not include eval_setting/fold, so runs for different
# settings/folds overwrite the same file.
torch.save(best_model.state_dict(), 'best_model.model')