In [None]:
import pandas as pd
import torch
import os
from torch.utils.data import random_split
from torch.utils.data import Dataset,DataLoader
from utils.Func import extract_cp_feature
from dgl.data.utils import load_graphs
from utils.PAVC_dataset import PAVC_Dataset_Train
from utils.collator import Collator_PAVC_Train
from utils.PAVC_trainer import PAVC_Trainer
from utils.scheduler import PolynomialDecayLR
from torch.optim import Adam
from torch.nn import  BCEWithLogitsLoss
from utils.model.KPGT import *
from utils.featurizer import Vocab, N_ATOM_TYPES, N_BOND_TYPES,VIRTUAL_ATOM_FEATURE_PLACEHOLDER, VIRTUAL_BOND_FEATURE_PLACEHOLDER

In [28]:
# Run configuration: every tunable value and path used by the cells below.
args = dict(
    # --- model ---
    config='base',            # key looked up in config_dict
    d_fps=512,                # passed to the model as d_fp_feats
    d_mds=200,                # passed to the model as d_md_feats
    dropout=0,                # used for attn_drop, feat_drop and the MLP head dropout
    # --- optimisation ---
    weight_decay=1e-6,
    n_tasks=1,                # output width of the prediction head
    lr=3e-5,
    # --- paths ---
    model_path='../pretrained_model/KPGT/KPGT.pth',
    cp_feature_dir='../data/DeepPAVC/cp_feature',
    # --- training loop ---
    n_epochs=10,
    device='cuda:0',
    random_seed=42,
    batch_size=32,
    # --- data split (the test split receives whatever remains after train/val) ---
    train_ratio=0.8,
    val_ratio=0.1,
    test_ratio=0.1,
    # --- MLP head ---
    MLP_layer_num=2,
    MLP_hidden_dim=256,
)
# Vocabulary built from the atom/bond type counts exported by utils.featurizer;
# its `vocab_size` is later handed to the model as `n_node_types`.
vocab = Vocab(N_ATOM_TYPES, N_BOND_TYPES)

In [3]:
# Load the DeepPAVC dataset table; later cells read its 'smiles' and 'label' columns.
DeepPAVC_dataset = pd.read_csv('../data/DeepPAVC/DeepPAVC_dataset.csv')

In [5]:
# Plain Python list of SMILES strings, in CSV row order; used for feature
# extraction and for building the dataset.
smiles_list = DeepPAVC_dataset['smiles'].to_list()

In [None]:
# Extract initial compound features (by RDKit) for every SMILES string.
# The output directory is taken from args['cp_feature_dir'] -- the same setting
# the feature-loading cell reads from -- so the writer and the reader cannot
# drift apart when the path is changed in the config.
extract_cp_feature(smiles_list=smiles_list,
                   output_dir=args['cp_feature_dir'],
                   num_workers=32)

In [29]:
# Load the pre-computed molecular features (graphs, fingerprints, descriptors)
# from the configured feature directory.
cp_feature_dir = args['cp_feature_dir']
graphs, label_dict = load_graphs(os.path.join(cp_feature_dir, 'cp_graphs.pkl'))
fps = torch.load(os.path.join(cp_feature_dir, 'cp_fps.pt'))
mds = torch.load(os.path.join(cp_feature_dir, 'cp_mds.pt'))

In [30]:
# smiles_list, graphs, fps and mds are passed side by side to PAVC_Dataset_Train,
# so a length mismatch would pair molecules with the wrong features. Check all
# four (the original check skipped smiles_list) and report the counts on failure.
n_smiles, n_graphs, n_fps, n_mds = len(smiles_list), len(graphs), len(fps), len(mds)
assert n_smiles == n_graphs == n_fps == n_mds, (
    f'Feature count mismatch: smiles={n_smiles}, graphs={n_graphs}, '
    f'fps={n_fps}, mds={n_mds}'
)

In [8]:
# Build dataset: bundle SMILES, graphs, fingerprints, descriptors and labels.
label_list = DeepPAVC_dataset['label'].to_list()
PAVC_dataset = PAVC_Dataset_Train(
    smiles_list=smiles_list,
    graphs=graphs,
    ecfps=fps,
    mds=mds,
    label_list=label_list,
)

In [None]:
### data split
# Train and validation sizes come from the configured ratios; the test split
# takes the remainder, so the three sizes always sum to the dataset size.
train_ratio = args['train_ratio']
val_ratio = args['val_ratio']
dataset_size = len(PAVC_dataset)
train_size = int(train_ratio * dataset_size)
val_size = int(val_ratio * dataset_size)
test_size = dataset_size - train_size - val_size

# Seed the global torch RNG from the config rather than a hard-coded literal,
# so args['random_seed'] actually controls the split (default value is still 42).
torch.manual_seed(args['random_seed'])
train_dataset, val_dataset, test_dataset = random_split(PAVC_dataset, [train_size, val_size, test_size])
print(f'Train size:{len(train_dataset)}\nValidation size:{len(val_dataset)}\nTest size:{len(test_dataset)}')

### build dataloader
config = config_dict[args['config']]
collator = Collator_PAVC_Train(config['path_length'])

# Options common to all three loaders; only the training loader shuffles.
loader_kwargs = dict(
    batch_size=args['batch_size'],
    drop_last=False,
    collate_fn=collator,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)


In [None]:
###  model initialization
# Architecture hyper-parameters are copied verbatim from the selected config
# entry; feature widths and dropout rates come from args.
arch_keys = (
    'd_node_feats', 'd_edge_feats', 'd_g_feats', 'd_hpath_ratio',
    'n_mol_layers', 'path_length', 'n_heads', 'n_ffn_dense_layers',
)
DeepPAVC = LiGhTPredictor(
    **{key: config[key] for key in arch_keys},
    d_fp_feats=args['d_fps'],
    d_md_feats=args['d_mds'],
    input_drop=0,
    attn_drop=args['dropout'],
    feat_drop=args['dropout'],
    n_node_types=vocab.vocab_size,
)
# Move the freshly built model to the training device.
DeepPAVC = DeepPAVC.to(args['device'])

In [12]:
# add classification head
def get_predictor(d_input_feats, n_tasks, n_layers, predictor_drop, device, d_hidden_feats=None):
    """Build the prediction head: a single linear layer or an MLP.

    Args:
        d_input_feats: Width of the input feature vector.
        n_tasks: Number of outputs produced by the final linear layer.
        n_layers: Number of linear layers. 1 gives a bare ``nn.Linear``;
            larger values give Linear -> Dropout -> GELU blocks followed by a
            final Linear.
        predictor_drop: Dropout probability used inside the MLP (ignored when
            ``n_layers == 1``).
        device: Device the finished head is moved to.
        d_hidden_feats: Hidden width of the MLP; required when ``n_layers > 1``.

    Returns:
        The initialised head (``nn.Linear`` or ``nn.Sequential``) on ``device``.

    Raises:
        ValueError: If ``n_layers < 1``, or if ``n_layers > 1`` and
            ``d_hidden_feats`` is not given.
    """
    # Previously n_layers <= 0 silently produced a 2-layer MLP; reject it instead.
    if n_layers < 1:
        raise ValueError(f'n_layers must be >= 1, got {n_layers}')

    if n_layers == 1:
        predictor = nn.Linear(d_input_feats, n_tasks)
    else:
        # Fail early with a clear message rather than inside nn.Linear.
        if d_hidden_feats is None:
            raise ValueError('d_hidden_feats must be provided when n_layers > 1')
        layers = [
            nn.Linear(d_input_feats, d_hidden_feats),
            nn.Dropout(predictor_drop),
            nn.GELU(),
        ]
        # n_layers counts the input and output linears, hence n_layers - 2 hidden blocks.
        for _ in range(n_layers - 2):
            layers += [
                nn.Linear(d_hidden_feats, d_hidden_feats),
                nn.Dropout(predictor_drop),
                nn.GELU(),
            ]
        layers.append(nn.Linear(d_hidden_feats, n_tasks))
        predictor = nn.Sequential(*layers)

    # Run the project's init_params on the head and each of its sub-modules.
    predictor.apply(init_params)
    return predictor.to(device)

In [None]:
### load pretrained weights
# Checkpoint keys may carry a leading 'module.' (typically added when the model
# was saved from a DataParallel-style wrapper). Strip it only when it is a
# prefix: str.replace would also mangle keys that merely contain the substring,
# e.g. '...submodule.weight'.
wrapper_prefix = 'module.'
pretrained_state = torch.load(args['model_path'], map_location=args['device'])
pretrained_state = {
    (key[len(wrapper_prefix):] if key.startswith(wrapper_prefix) else key): value
    for key, value in pretrained_state.items()
}
DeepPAVC.load_state_dict(pretrained_state)

In [None]:
# Optional: freeze KPGT's parameters and only train the MLP head
# for param in DeepPAVC.parameters():
#     param.requires_grad = False

In [15]:
# delete unused blocks: these sub-modules are not needed from here on
for unused_block in ('md_predictor', 'fp_predictor', 'node_predictor'):
    delattr(DeepPAVC, unused_block)

In [17]:
# add MLP for classification
# The head input is three times the graph feature width
# (NOTE(review): presumably three concatenated readouts -- confirm in the model).
head_kwargs = dict(
    d_input_feats=config['d_g_feats'] * 3,
    n_tasks=args['n_tasks'],
    n_layers=args['MLP_layer_num'],
    predictor_drop=args['dropout'],
    device=args['device'],
    d_hidden_feats=args['MLP_hidden_dim'],
)
DeepPAVC.predictor = get_predictor(**head_kwargs)

In [None]:
# print trainable parameters size (in millions of scalar values)
n_trainable = sum(param.numel() for param in DeepPAVC.parameters() if param.requires_grad)
print("Model has {:.4f}M parameters that require gradient updates".format(n_trainable / 1e6))

In [24]:
# Optimise only the parameters that still require gradients.
trainable_params = (p for p in DeepPAVC.parameters() if p.requires_grad)
optimizer = Adam(trainable_params, lr=args['lr'], weight_decay=args['weight_decay'])

# Total number of updates is epochs x batches per epoch; warm-up is 1% of that.
total_updates = args['n_epochs'] * len(train_loader)
lr_scheduler = PolynomialDecayLR(
    optimizer,
    warmup_updates=total_updates // 100,
    tot_updates=total_updates,
    lr=args['lr'],
    end_lr=1e-6,
    power=1,
)

loss_fn = BCEWithLogitsLoss(reduction='mean')

In [26]:
### DeepPAVC Trainer initialization
# The trainer receives the whole args dict plus the optimizer, LR scheduler and
# loss function built above. Which args keys it reads is not visible in this
# notebook -- see utils.PAVC_trainer.
trainer = PAVC_Trainer(args, 
                  optimizer, 
                  lr_scheduler, 
                  loss_fn,
                 device=args['device'],
                 model_name='DeepPAVC')

In [None]:
# train DeepPAVC model
# Only the train and validation loaders are passed here; test_loader is not used
# by this call. NOTE(review): the return value is presumably a table of
# performance metrics -- confirm in PAVC_Trainer.fit.
performance_res_df = trainer.fit(model=DeepPAVC,
                      train_loader=train_loader,
                      val_loader=val_loader)