### First training stage: Train CADTI

In [None]:
import pandas as pd
import torch
import os
from torch.utils.data import random_split
from torch.utils.data import Dataset,DataLoader
from utils.Func import extract_esm_feature,seq2fasta,extract_cp_feature,filter_invalid_smiles
from dgl.data.utils import load_graphs
from utils.TAVC_dataset import TAVC_Dataset_Train
from utils.collator import Collator_TAVC_Train
from utils.TAVC_trainer import TAVC_Trainer
from utils.scheduler import PolynomialDecayLR
from torch.optim import Adam
from torch.nn import  BCEWithLogitsLoss
from utils.model.KPGT_v2 import *
from utils.model.DeepAVC import *
from utils.featurizer import Vocab, N_ATOM_TYPES, N_BOND_TYPES,VIRTUAL_ATOM_FEATURE_PLACEHOLDER, VIRTUAL_BOND_FEATURE_PLACEHOLDER

In [40]:
# Configuration shared by the cells of the first training stage (CADTI).
args = {
    # model / optimisation settings
    'config': 'base',
    'd_fps': 512,
    'd_mds': 200,
    'dropout': 0,
    'weight_decay': 1e-6,
    'n_tasks': 1,
    'lr': 1e-4,
    # pretrained weights and feature directories
    'kpgt_model_path': '../pretrained_model/KPGT/KPGT.pth',
    'cp_feature_dir': '../data/DeepTAVC/CPI_dataset/demo/cp_feature',
    'pro_feature_dir': '../data/DeepTAVC/CPI_dataset/demo/pro_feature',
    'CADTI_model_path': '../pretrained_model/DeepAVC/CADTI.pt',
    # training loop
    'n_epochs': 20,
    'device': 'cuda:0',
    'random_seed': 42,
    'batch_size': 32,
    # train / validation / test proportions
    'train_ratio': 0.8,
    'val_ratio': 0.1,
    'test_ratio': 0.1,
    # not read directly in this notebook; presumably consumed via `args` by the trainer -- TODO confirm
    'MLP_layer_num': 2,
    'MLP_hidden_dim': 256,
}
# Vocabulary object built from the atom / bond type counts; its `vocab_size`
# is passed to the KPGT model constructor further down.
vocab = Vocab(
    N_ATOM_TYPES,
    N_BOND_TYPES,
)

In [3]:
# Load the pickled CPI dataset from disk
cpi_pickle_path = '../data/T_AVC/cpi_data/cpi_dataset.pkl'
CPI_dataset = pd.read_pickle(cpi_pickle_path)

In [None]:
# Compute the initial compound features with RDKit (only needed when they
# have not been generated yet).
smiles_list = CPI_dataset['SMILES'].to_list()

# Separate the SMILES strings into valid and invalid ones.
# NOTE(review): features are extracted for `valid_smiles` only, while the
# dataset cell later passes every SMILES of the table -- confirm that the two
# stay aligned when `invalid_smiles` is not empty.
filtered = filter_invalid_smiles(smiles_list)
valid_smiles, invalid_smiles = filtered

extract_cp_feature(
    smiles_list=valid_smiles,
    output_dir=args['cp_feature_dir'],
    num_workers=32,
)

In [None]:
# Extract protein initial feature by ESM-2 (if necessary)
pro_seq_list = list(CPI_dataset['sequence'].unique())

# Transform protein sequences into the fasta format; the extraction call
# below reads 'target_seq.fasta' from the same directory.
seq2fasta(seq_list=pro_seq_list,
          save_dir=args['pro_feature_dir'])

extract_esm_feature(
    model_location='../pretrained_model/ESM/esm2_t33_650M_UR50D.pt',
    fasta_file=os.path.join(args['pro_feature_dir'], 'target_seq.fasta'),
    output_dir=args['pro_feature_dir'],
    toks_per_batch=10000,
    repr_layers=[-1],
    include=['per_tok'],
    # Use the configured device rather than a second hard-coded literal, so
    # changing args['device'] moves the whole stage to the same GPU.
    device=args['device'],
    truncation_seq_length=1024)

In [7]:
# One identifier ('Target_1', 'Target_2', ...) per unique protein sequence,
# numbered in the order returned by `unique()`.
target_id_list = [
    f'Target_{position}'
    for position, _ in enumerate(CPI_dataset['sequence'].unique(), start=1)
]

In [8]:
# Map each unique protein sequence to its target identifier
seq2id_dict = {
    sequence: target_id
    for sequence, target_id in zip(CPI_dataset['sequence'].unique(), target_id_list)
}

In [9]:
# Store the target identifier of every row in a new 'target_idx' column
sequence_column = CPI_dataset['sequence']
CPI_dataset['target_idx'] = sequence_column.map(seq2id_dict)

In [10]:
### Load the compound initial features saved in the compound feature directory
cp_feature_dir = args['cp_feature_dir']
graphs_path = os.path.join(cp_feature_dir, 'cp_graphs.pkl')
fps_path = os.path.join(cp_feature_dir, 'cp_fps.pt')
mds_path = os.path.join(cp_feature_dir, 'cp_mds.pt')

graphs, label_dict = load_graphs(graphs_path)
fps = torch.load(fps_path)
mds = torch.load(mds_path)

In [11]:
# The three compound feature containers must hold the same number of entries
n_graphs, n_fps, n_mds = len(graphs), len(fps), len(mds)
assert n_graphs == n_fps == n_mds

In [12]:
### Load the protein initial features onto the configured device
esm_feature_path = os.path.join(args['pro_feature_dir'], 'esm_feature.pt')
pro_feature_dict = torch.load(esm_feature_path, map_location=args['device'])

In [13]:
# Build the training dataset from the CPI table and the loaded features.
# Note that this rebinds `CPI_dataset` from the loaded table to the dataset
# object, so the table columns are read before the assignment.
cpi_smiles = CPI_dataset['SMILES'].to_list()
cpi_target_ids = CPI_dataset['target_idx'].to_list()
cpi_labels = CPI_dataset['label'].to_list()

CPI_dataset = TAVC_Dataset_Train(
    smiles_list=cpi_smiles,
    target_seq_list=cpi_target_ids,
    target_feature_dict=pro_feature_dict,
    label_list=cpi_labels,
    graphs=graphs,
    fps=fps,
    mds=mds,
)

In [None]:
### data split
train_ratio = args['train_ratio']
val_ratio = args['val_ratio']
dataset_size = len(CPI_dataset)
train_size = int(train_ratio * dataset_size)
val_size = int(val_ratio * dataset_size)
# The test split takes the remainder so the three sizes always sum to dataset_size
test_size = dataset_size - train_size - val_size

# Seed from the configuration instead of a separate hard-coded 42, so that
# changing args['random_seed'] actually changes the split.
torch.manual_seed(args['random_seed'])
train_dataset, val_dataset, test_dataset = random_split(CPI_dataset, [train_size, val_size, test_size])
print(f'Train size:{len(train_dataset)}\nValidation size:{len(val_dataset)}\nTest size:{len(test_dataset)}')

### build dataloader
config = config_dict[args['config']]
collator = Collator_TAVC_Train(config['path_length'])

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset,
                          batch_size=args['batch_size'],
                          shuffle=True,
                          drop_last=False,
                          collate_fn=collator)
val_loader = DataLoader(val_dataset,
                        batch_size=args['batch_size'],
                        shuffle=False,
                        drop_last=False,
                        collate_fn=collator)
test_loader = DataLoader(test_dataset,
                         batch_size=args['batch_size'],
                         shuffle=False,
                         drop_last=False,
                         collate_fn=collator)


In [15]:
# Instantiate the KPGT predictor from the selected config entry and the
# notebook-level args, then move it to the configured device.
kpgt_kwargs = dict(
    d_node_feats=config['d_node_feats'],
    d_edge_feats=config['d_edge_feats'],
    d_g_feats=config['d_g_feats'],
    d_fp_feats=args['d_fps'],
    d_md_feats=args['d_mds'],
    d_hpath_ratio=config['d_hpath_ratio'],
    n_mol_layers=config['n_mol_layers'],
    path_length=config['path_length'],
    n_heads=config['n_heads'],
    n_ffn_dense_layers=config['n_ffn_dense_layers'],
    input_drop=0,
    attn_drop=args['dropout'],
    feat_drop=args['dropout'],
    n_node_types=vocab.vocab_size,
)
kpgt_model = LiGhTPredictor(**kpgt_kwargs).to(args['device'])

In [16]:
# Load pre-trained weights of the KPGT model
kpgt_checkpoint = torch.load(args['kpgt_model_path'], map_location=args['device'])
# Remove every 'module.' occurrence from the checkpoint keys before loading
kpgt_state = {
    name.replace('module.', ''): tensor
    for name, tensor in kpgt_checkpoint.items()
}
kpgt_model.load_state_dict(kpgt_state)

# Delete unused structure
for unused_head in ('md_predictor', 'fp_predictor', 'node_predictor'):
    delattr(kpgt_model, unused_head)

In [17]:
### Model initialization: wrap the KPGT model in the CADTI fine-tuning model
cadti_kwargs = dict(
    d_model=256,
    n_heads=8,
    num_layers=1,
    kpgt_model=kpgt_model,
    smiles_dim=768,
    protein_dim=1280,
    kpgt_features_dim=2304,
    mlp_hidden_dim=256,
    num_classes=1,
    dropout=0,
    return_attn=True,
)
CADTI_model = CADTI_Finetune(**cadti_kwargs).to(args['device'])

In [None]:
# Report the number of trainable parameters, in millions
n_trainable = sum(
    param.numel()
    for param in CADTI_model.parameters()
    if param.requires_grad
)
print("model have {}M parameters in total that require gradients".format(n_trainable / 1e6))

In [19]:
# Optimizer over all CADTI parameters
optimizer = Adam(
    CADTI_model.parameters(),
    lr=args['lr'],
    weight_decay=args['weight_decay'],
)

# The schedule is expressed in optimizer updates: one update per batch for
# n_epochs epochs, with the first 1% (integer division) used as warm-up.
total_updates = args['n_epochs'] * len(train_loader)
lr_scheduler = PolynomialDecayLR(
    optimizer,
    warmup_updates=total_updates // 100,
    tot_updates=total_updates,
    lr=args['lr'],
    end_lr=1e-5,
    power=1,
)

loss_fn = BCEWithLogitsLoss(reduction='mean')

In [20]:
# Bundle the configuration, optimizer, schedule and loss into the trainer
trainer_kwargs = dict(
    args=args,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    loss_fn=loss_fn,
    device=args['device'],
    model_name='CADTI',
)
trainer = TAVC_Trainer(**trainer_kwargs)

In [None]:
# Train CADTI on the training loader with the validation loader passed
# alongside; the value returned by `fit` is kept for inspection.
performance_res_df = trainer.fit(
    model=CADTI_model,
    train_loader=train_loader,
    val_loader=val_loader,
)

### Second training stage: Train DeepTAVC

In [25]:
# Configuration shared by the cells of the second training stage (DeepTAVC).
args = {
    'config': 'base',
    'd_fps': 512,
    'd_mds': 200,
    'dropout': 0,
    'weight_decay': 1e-6,
    'n_tasks': 1,
    'lr': 1e-4,
    'kpgt_model_path': '../pretrained_model/KPGT/KPGT.pth',
    'cp_feature_dir': '../data/DeepTAVC/TAVC_dataset/cp_feature',
    'pro_feature_dir': '../data/DeepTAVC/TAVC_dataset/pro_feature',
    # Required by the later cell that restores the stage-one CADTI weights
    # into DeepTAVC; without this key that cell raises KeyError.
    'CADTI_model_path': '../pretrained_model/DeepAVC/CADTI.pt',
    'n_epochs': 20,
    'device': 'cuda:3',
    'random_seed': 42,
    'batch_size': 32,
    'train_ratio': 0.8,
    'val_ratio': 0.1,
    'test_ratio': 0.1,
    'MLP_layer_num': 2,
    'MLP_hidden_dim': 256,
}

In [23]:
# Read the DeepTAVC dataset table from CSV
tavc_csv_path = '../data/DeepTAVC/TAVC_dataset/DeepTAVC_dataset.csv'
TAVC_data = pd.read_csv(tavc_csv_path)

In [None]:
# Compute the initial compound features with RDKit (only needed when they
# have not been generated yet).
smiles_list = TAVC_data['canonical_smiles'].to_list()

# Separate the SMILES strings into valid and invalid ones.
# NOTE(review): features are extracted for `valid_smiles` only, while the
# dataset cell later passes every SMILES of the table -- confirm that the two
# stay aligned when `invalid_smiles` is not empty.
filtered = filter_invalid_smiles(smiles_list)
valid_smiles, invalid_smiles = filtered

extract_cp_feature(
    smiles_list=valid_smiles,
    output_dir=args['cp_feature_dir'],
    num_workers=32,
)

In [None]:
# Extract protein initial feature by ESM-2 (if necessary)
pro_seq_list = list(TAVC_data['sequence'].unique())

# Transform protein sequences into the fasta format; the extraction call
# below reads 'target_seq.fasta' from the same directory.
seq2fasta(seq_list=pro_seq_list,
          save_dir=args['pro_feature_dir'])

extract_esm_feature(
    model_location='../pretrained_model/ESM/esm2_t33_650M_UR50D.pt',
    fasta_file=os.path.join(args['pro_feature_dir'], 'target_seq.fasta'),
    output_dir=args['pro_feature_dir'],
    toks_per_batch=10000,
    repr_layers=[-1],
    include=['per_tok'],
    # Use the configured device rather than a second hard-coded literal, so
    # changing args['device'] moves the whole stage to the same GPU.
    device=args['device'],
    truncation_seq_length=1024)

In [30]:
# Give every unique protein sequence an identifier ('Target_1', 'Target_2',
# ...) in the order returned by `unique()`, then store it per row.
unique_sequences = TAVC_data['sequence'].unique()
target_id_list = [f'Target_{rank}' for rank in range(1, len(unique_sequences) + 1)]
seq2id_dict = {
    sequence: target_id
    for sequence, target_id in zip(unique_sequences, target_id_list)
}
TAVC_data['target_idx'] = TAVC_data['sequence'].map(seq2id_dict)

In [32]:
### Load the compound initial features saved in the compound feature directory
cp_feature_dir = args['cp_feature_dir']
graphs_path = os.path.join(cp_feature_dir, 'cp_graphs.pkl')
fps_path = os.path.join(cp_feature_dir, 'cp_fps.pt')
mds_path = os.path.join(cp_feature_dir, 'cp_mds.pt')

graphs, label_dict = load_graphs(graphs_path)
fps = torch.load(fps_path)
mds = torch.load(mds_path)

In [33]:
# The three compound feature containers must hold the same number of entries
n_graphs, n_fps, n_mds = len(graphs), len(fps), len(mds)
assert n_graphs == n_fps == n_mds

In [34]:
### Load the protein initial features onto the configured device
esm_feature_path = os.path.join(args['pro_feature_dir'], 'esm_feature.pt')
pro_feature_dict = torch.load(esm_feature_path, map_location=args['device'])

In [35]:
# Build the training dataset from the TAVC table and the loaded features
tavc_smiles = TAVC_data['canonical_smiles'].to_list()
tavc_target_ids = TAVC_data['target_idx'].to_list()
tavc_labels = TAVC_data['avd_label'].to_list()

TAVC_dataset = TAVC_Dataset_Train(
    smiles_list=tavc_smiles,
    target_seq_list=tavc_target_ids,
    target_feature_dict=pro_feature_dict,
    label_list=tavc_labels,
    graphs=graphs,
    fps=fps,
    mds=mds,
)

In [None]:
### data split
train_ratio = args['train_ratio']
val_ratio = args['val_ratio']
dataset_size = len(TAVC_dataset)
train_size = int(train_ratio * dataset_size)
val_size = int(val_ratio * dataset_size)
# The test split takes the remainder so the three sizes always sum to dataset_size
test_size = dataset_size - train_size - val_size

# Seed from the configuration instead of a separate hard-coded 42, so that
# changing args['random_seed'] actually changes the split.
torch.manual_seed(args['random_seed'])
train_dataset, val_dataset, test_dataset = random_split(TAVC_dataset, [train_size, val_size, test_size])
print(f'Train size:{len(train_dataset)}\nValidation size:{len(val_dataset)}\nTest size:{len(test_dataset)}')

### build dataloader
config = config_dict[args['config']]
collator = Collator_TAVC_Train(config['path_length'])

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset,
                          batch_size=args['batch_size'],
                          shuffle=True,
                          drop_last=False,
                          collate_fn=collator)
val_loader = DataLoader(val_dataset,
                        batch_size=args['batch_size'],
                        shuffle=False,
                        drop_last=False,
                        collate_fn=collator)
test_loader = DataLoader(test_dataset,
                         batch_size=args['batch_size'],
                         shuffle=False,
                         drop_last=False,
                         collate_fn=collator)


In [None]:
### Model initialization for the second stage
# NOTE(review): `kpgt_model` is the object created in the first stage and is
# passed in again here; presumably both models then reference the same
# encoder instance -- confirm this sharing is intended.
deeptavc_kwargs = dict(
    d_model=256,
    n_heads=8,
    num_layers=1,
    kpgt_model=kpgt_model,
    smiles_dim=768,
    protein_dim=1280,
    kpgt_features_dim=2304,
    mlp_hidden_dim=256,
    num_classes=1,
    dropout=0,
    return_attn=True,
)
DeepTAVC = CADTI_Finetune(**deeptavc_kwargs).to(args['device'])

# Report the number of trainable parameters, in millions
n_trainable = sum(
    param.numel()
    for param in DeepTAVC.parameters()
    if param.requires_grad
)
print("model have {}M parameters in total that require gradients".format(n_trainable / 1e6))

In [None]:
# Initialise DeepTAVC from the checkpoint stored at args['CADTI_model_path']
cadti_state = torch.load(args['CADTI_model_path'], map_location=args['device'])
DeepTAVC.load_state_dict(cadti_state)

In [45]:
# Build a fresh optimizer, LR schedule and loss for the second stage.
# The objects left over from the first stage were constructed from
# CADTI_model.parameters() and the first-stage loader length, so reusing them
# would never hand DeepTAVC's own parameters to an optimizer and would
# continue a schedule that was sized for (and consumed by) another run.
optimizer = Adam(DeepTAVC.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])

# One scheduler update per batch for n_epochs epochs; the first 1% (integer
# division) of the updates is warm-up, matching the first-stage setup.
total_updates = args['n_epochs'] * len(train_loader)
lr_scheduler = PolynomialDecayLR(optimizer,
                                 warmup_updates=total_updates // 100,
                                 tot_updates=total_updates,
                                 lr=args['lr'],
                                 end_lr=1e-5,
                                 power=1)

loss_fn = BCEWithLogitsLoss(reduction='mean')

trainer = TAVC_Trainer(args=args,
                       optimizer=optimizer,
                       lr_scheduler=lr_scheduler,
                       loss_fn=loss_fn,
                       device=args['device'],
                       model_name='DeepTAVC')

In [None]:
# Train DeepTAVC on the training loader with the validation loader passed
# alongside; the value returned by `fit` is kept for inspection.
performance_res_df = trainer.fit(
    model=DeepTAVC,
    train_loader=train_loader,
    val_loader=val_loader,
)