In [14]:
import pandas as pd
import torch
import os
from tqdm import tqdm
from torch.utils.data import Dataset,DataLoader
from utils.Func import extract_cp_feature
from dgl.data.utils import load_graphs
from utils.PAVC_dataset import PAVC_Dataset_Infer
from utils.collator import Collator_PAVC_Infer
from utils.model.KPGT import *
from utils.featurizer import Vocab, N_ATOM_TYPES, N_BOND_TYPES,VIRTUAL_ATOM_FEATURE_PLACEHOLDER, VIRTUAL_BOND_FEATURE_PLACEHOLDER

In [29]:
# Notebook configuration.
# DEEPAVC_ROOT (env var) relocates every path below; its default reproduces the
# original checkout location. DEEPAVC_DEVICE (env var) overrides the inference
# device, which otherwise is 'cuda:3' when that GPU exists and the CPU elsewhere.
DEEPAVC_ROOT = os.environ.get(
    'DEEPAVC_ROOT', '/home2/kangboming/kangboming/workspace2/AVC_paper/github')
DEEPAVC_DEVICE = os.environ.get(
    'DEEPAVC_DEVICE', 'cuda:3' if torch.cuda.device_count() > 3 else 'cpu')

args = {
    'config': 'base',   # key into config_dict; model_infer reads its 'path_length'
    'd_fps': 512,       # fingerprint length (the printed fp_proj takes in_features=512)
    'd_mds': 200,       # descriptor length (the printed md_proj takes in_features=200)
    'dropout': 0,
    'model_path': os.path.join(DEEPAVC_ROOT, 'pretrained_model', 'DeepAVC', 'DeepPAVC.pt'),
    'cp_feature_dir': os.path.join(DEEPAVC_ROOT, 'data', 'DeepPAVC_inference', 'cp_feature'),
    'device': DEEPAVC_DEVICE,
    'n_tasks': 1,       # model_infer squeezes the single output column
}

In [15]:
# Instantiate Vocab with the featurizer's atom-type and bond-type counts.
# NOTE(review): `vocab` is not referenced by any cell shown in this notebook;
# confirm it is still needed before removing it.
vocab = Vocab(N_ATOM_TYPES, N_BOND_TYPES)

In [12]:
# Load the pre-trained DeepPAVC model. The checkpoint is a whole pickled
# nn.Module (the next cell displays it), not a state_dict.
# NOTE: torch.load unpickles arbitrary objects - only load trusted checkpoints.
import inspect

# torch>=2.6 defaults to weights_only=True, which cannot rebuild a full pickled
# module; opt out explicitly, but only on torch versions that know the keyword.
_load_kwargs = {'map_location': args['device']}
if 'weights_only' in inspect.signature(torch.load).parameters:
    _load_kwargs['weights_only'] = False
DeepPAVC = torch.load(args['model_path'], **_load_kwargs)

In [13]:
# Display the loaded module tree as a sanity check: the input projections should
# match args (fp_proj in_features=512 == d_fps, md_proj in_features=200 == d_mds).
DeepPAVC

LiGhTPredictor(
  (node_emb): AtomEmbedding(
    (in_proj): Linear(in_features=137, out_features=768, bias=True)
    (virtual_atom_emb): Embedding(1, 768)
    (input_dropout): Dropout(p=0, inplace=False)
  )
  (edge_emb): BondEmbedding(
    (in_proj): Linear(in_features=14, out_features=768, bias=True)
    (virutal_bond_emb): Embedding(1, 768)
    (input_dropout): Dropout(p=0, inplace=False)
  )
  (triplet_emb): TripletEmbedding(
    (in_proj): MLP(
      (dense_layer_list): ModuleList()
      (in_proj): Linear(in_features=1536, out_features=768, bias=True)
      (out_proj): Linear(in_features=768, out_features=768, bias=True)
      (act): GELU()
    )
    (fp_proj): MLP(
      (dense_layer_list): ModuleList()
      (in_proj): Linear(in_features=512, out_features=768, bias=True)
      (out_proj): Linear(in_features=768, out_features=768, bias=True)
      (act): GELU()
    )
    (md_proj): MLP(
      (dense_layer_list): ModuleList()
      (in_proj): Linear(in_features=200, out_features=

In [16]:
def model_infer(model, df, feature_dir, device, batch_size=32, config_name=None):
    """Score every compound in ``df`` with a DeepPAVC model.

    Loads the pre-computed compound features from ``feature_dir``
    (``cp_graphs.pkl``, ``cp_fps.pt``, ``cp_mds.pt``), batches them with
    ``Collator_PAVC_Infer`` and runs ``model.forward_tune`` without gradient
    tracking.

    Args:
        model: loaded model exposing ``forward_tune(g, ecfp, md)``. It is put
            into eval mode as a side effect.
        df: DataFrame with a ``'SMILES'`` column.
        feature_dir: directory holding the pre-computed feature files.
        device: torch device (string or ``torch.device``) for inference.
        batch_size: compounds per inference batch (default 32).
        config_name: key into ``config_dict`` used to look up
            ``path_length``. When ``None``, falls back to the notebook-level
            ``args['config']``.

    Returns:
        DataFrame with columns ``'SMILES'`` (as yielded by the loader) and
        ``'pred_score'`` (sigmoid of the model output). It is empty when there
        is nothing to score.

    Raises:
        KeyError: if ``df`` has no ``'SMILES'`` column.
    """
    # NOTE: torch.load unpickles data - only point feature_dir at trusted files.
    graphs, _ = load_graphs(os.path.join(feature_dir, 'cp_graphs.pkl'))
    fps = torch.load(os.path.join(feature_dir, 'cp_fps.pt'))
    mds = torch.load(os.path.join(feature_dir, 'cp_mds.pt'))

    smiles_list = df['SMILES'].to_list()

    # NOTE(review): features are handed over alongside the SMILES list, so they are
    # presumably paired by position and must have been extracted from this same df
    # in the same order - confirm against PAVC_Dataset_Infer.
    PAVC_ds = PAVC_Dataset_Infer(smiles_list=smiles_list,
                                 graphs=graphs,
                                 ecfps=fps,
                                 mds=mds)

    # Build the DataLoader; the collator needs the config's path_length.
    if config_name is None:
        config_name = args['config']  # fall back to the notebook-level setting
    config = config_dict[config_name]
    collator = Collator_PAVC_Infer(config['path_length'])

    # No shuffling and no dropped tail: every compound is scored once, in order.
    PAVC_loader = DataLoader(PAVC_ds,
                             batch_size=batch_size,
                             shuffle=False,
                             drop_last=False,
                             collate_fn=collator)

    # Run inference. eval() switches layers such as dropout to inference
    # behaviour; no_grad() avoids building an autograd graph that is never used.
    model.eval()

    pred_scores_all = []
    smiles_list_all = []

    with torch.no_grad():
        for smiles, g, ecfp, md in tqdm(PAVC_loader, total=len(PAVC_loader)):
            ecfp = ecfp.to(device)
            md = md.to(device)
            g = g.to(device)
            predictions = model.forward_tune(g, ecfp, md)

            # Assumes predictions is (batch, n_tasks) with n_tasks == 1. squeeze(1)
            # rather than squeeze() keeps a batch of one as a 1-D tensor.
            pred_scores = torch.sigmoid(predictions.squeeze(1))
            pred_scores_all.append(pred_scores.cpu())
            smiles_list_all.extend(smiles)

    # torch.cat([]) raises, so an empty input yields an empty result frame instead.
    if pred_scores_all:
        all_pred_scores = torch.cat(pred_scores_all).numpy()
    else:
        all_pred_scores = torch.empty(0).numpy()

    return pd.DataFrame({'SMILES': smiles_list_all, 'pred_score': all_pred_scores})

In [26]:
# Demo input CSV. It must provide the 'SMILES' column that the feature-extraction
# and inference cells read. Set DEEPPAVC_DEMO_CSV to point at a different file.
DEMO_CSV_PATH = os.environ.get(
    'DEEPPAVC_DEMO_CSV',
    '/home2/kangboming/kangboming/workspace2/AVC_paper/github/data/DeepPAVC_inference/DeepPAVC_inference_demo.csv',
)
demo_df = pd.read_csv(DEMO_CSV_PATH)
assert 'SMILES' in demo_df.columns, f"{DEMO_CSV_PATH} has no 'SMILES' column"

In [30]:
# Extract the initial compound features (graphs, fingerprints, molecular
# descriptors) with RDKit into args['cp_feature_dir'], where model_infer reads them.
# The worker pool is capped at the machine's CPU count: every worker process starts
# separately (see the per-worker "Using backend" lines in the log), so
# oversubscribing only adds start-up cost.
extract_cp_feature(smiles_list=demo_df['SMILES'].to_list(),
                   output_dir=args['cp_feature_dir'],
                   num_workers=min(32, os.cpu_count() or 1))

extracting graphs


[Parallel(n_jobs=32)]: Using backend LokyBackend with 32 concurrent workers.
Using backend: pytorchUsing backend: pytorch

Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorchUsing backend: pytorch

Using backend: pytorch
Using backend: pytorch
Using backend: pytorchUsing backend: pytorch

Using backend: pytorch
Using backend: pytorchUsing backend: pytorch

Using backend: pytorch
Using backend: pytorch
Using backend: pytorch


extracting fingerprints
extracting molecular descriptors


[Parallel(n_jobs=32)]: Done  20 out of  20 | elapsed:    2.9s remaining:    0.0s
[Parallel(n_jobs=32)]: Done  20 out of  20 | elapsed:    2.9s finished


'Done!'

In [31]:
# Score the demo compounds from the features extracted in the previous cell.
# Arguments follow model_infer's (model, df, feature_dir, device) order.
res_df = model_infer(
    DeepPAVC,
    demo_df,
    args['cp_feature_dir'],
    args['device'],
)

100%|██████████| 1/1 [00:01<00:00,  1.09s/it]


In [33]:
# Preview the first few predictions (SMILES alongside sigmoid score).
res_df.head(n=5)

Unnamed: 0,SMILES,pred_score
0,CCN(CC)C(C)CN1C2=CC=CC=C2SC2=CC=CC=C12,6.4e-05
1,COCCC[C@H](NC(=O)[C@@H]1CCCN1C(=O)[C@@H](CC1=C...,0.154223
2,NCC1=CNC(=S)N1[C@H]1CCC2=C(C1)C=C(F)C=C2F,1.4e-05
3,CC(C)[C@H]1C2=C(CC[C@@]1(CCN(C)CCCC1=NC3=CC=CC...,5.6e-05
4,[H][C@@](O[C@H]1O[C@H](CO)[C@@H](O)[C@H](O)[C@...,0.010274
