In [1]:
import dgl
from collections import defaultdict
from dgl.nn.pytorch.glob import AvgPooling
from dgllife.model import load_pretrained
from dgllife.model.model_zoo import *
from dgllife.utils import mol_to_bigraph, PretrainAtomFeaturizer, PretrainBondFeaturizer
import numpy as np
import pandas as pd
import pickle
from rdkit import Chem
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

def collate(gs):
    """Merge a list of DGL graphs into a single batched graph.

    Passed as ``collate_fn`` to the ``DataLoader`` in this notebook so that
    every mini-batch arrives as one batched graph.

    Args:
        gs: the graphs making up one mini-batch; forwarded unchanged to
            ``dgl.batch``.

    Returns:
        Whatever ``dgl.batch(gs)`` returns (the batched graph).
    """
    batched = dgl.batch(gs)
    return batched


Using backend: pytorch


In [262]:
# Pretrained GIN encoder used to embed the molecules. The original author's
# note lists the variants that were considered: contextpred, infomax,
# edgepred, masking.
model = load_pretrained('gin_supervised_infomax')
model = model.to('cpu')   # inference is done on CPU throughout this notebook
model.eval()              # switch to evaluation mode before inference

# SMILES input for all drugs.
# NOTE(review): pickle.load executes arbitrary code from the file -- only use
# with a pickle produced by a trusted pipeline.
# NOTE(review): `b` is presumably a pandas Series of SMILES indexed by cid;
# confirm against the code that wrote this file.
smiles_pickle_path = '/tf/notebooks/code_for_pub/smiles_files/smiles_drugcomb_BY_cid_duplicated.pickle'
with open(smiles_pickle_path, 'rb') as smiles_file:
    b = pickle.load(smiles_file)


Downloading gin_supervised_infomax_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_infomax.pth...
Pretrained model loaded


In [3]:
# Convert every entry of `b` into a DGL bigraph carrying the atom/bond
# features expected by the pretrained GIN models.
# NOTE(review): iterating `b` yields the SMILES strings only if `b` is a
# pandas Series -- confirm. (The former `reset_index(drop=True)` copy was
# dropped: iteration never looks at the index, so it had no effect.)

# The featurizers are loop-invariant, so build them once instead of per molecule.
atom_featurizer = PretrainAtomFeaturizer()
bond_featurizer = PretrainBondFeaturizer()

graphs = []
skipped = []  # (position in b, offending entry) for everything that was dropped
for pos, smi in enumerate(b):
    try:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            # RDKit could not parse the SMILES string.
            skipped.append((pos, smi))
            continue
        g = mol_to_bigraph(mol, add_self_loop=True,
                           node_featurizer=atom_featurizer,
                           edge_featurizer=bond_featurizer,
                           canonical_atom_order=True)
        graphs.append(g)
    except Exception:
        # Best-effort conversion: keep going on bad entries, but no longer
        # swallow KeyboardInterrupt/SystemExit as the bare `except:` did.
        skipped.append((pos, smi))
        continue

if skipped:
    # Dropped entries make `graphs` shorter than `b`, so the embeddings can no
    # longer be paired row-by-row with `b.index` further down.
    print(f"WARNING: skipped {len(skipped)} of {len(b)} entries that could not "
          f"be converted to graphs (first few: {skipped[:5]})")



In [4]:
# Embed every molecule: run the pretrained model on each batch of graphs and
# average the resulting node representations per graph.
readout = AvgPooling()
data_loader = DataLoader(graphs, batch_size=256, collate_fn=collate, shuffle=False)

# Feature names stored on the graphs by the pretrain featurizers; they are
# popped in this order and handed to the model as lists.
node_feature_keys = ('atomic_number', 'chirality_type')
edge_feature_keys = ('bond_type', 'bond_direction_type')

batch_embeddings = []
for bg in data_loader:
    bg = bg.to('cpu')
    nfeats = [bg.ndata.pop(key).to('cpu') for key in node_feature_keys]
    efeats = [bg.edata.pop(key).to('cpu') for key in edge_feature_keys]
    with torch.no_grad():  # inference only, no gradients needed
        node_repr = model(bg, nfeats, efeats)
    batch_embeddings.append(readout(bg, node_repr))

# shuffle=False above keeps the rows in the same order as `graphs`.
mol_emb = torch.cat(batch_embeddings, dim=0).detach().cpu().numpy()

In [5]:
# Build the fingerprint table: one row per molecule, first labelled by the
# index of `b` and then re-indexed by the `id` column of the drugs CSV.

# Fail early with a clear message if molecules were dropped while building
# the graphs; otherwise pandas raises a generic shape error here.
if len(mol_emb) != len(b):
    raise ValueError(
        f"Got {len(mol_emb)} embeddings for {len(b)} input entries; some "
        "entries were skipped during graph construction, so the rows can no "
        "longer be aligned with b.index."
    )
fps_infomax_new = pd.DataFrame(data=mol_emb, index=b.index)

drugs_name = '/tf/notebooks/code_for_pub/smiles_files/drugcomb_drugs_export_OCT2020.csv'
drugs = pd.read_csv(drugs_name, names=['dname','id', 'smiles', 'cid'], header=0) # oct2020 version

# cid -> id lookup (for duplicated cids the last row wins, as before).
# A plain dict replaces the former defaultdict(list): Series.map honours
# __missing__ on dict subclasses, so unknown cids were silently mapped to an
# empty list instead of NaN.
mapping = dict(zip(drugs['cid'], drugs['id']))
fps_infomax_new['id'] = fps_infomax_new.index
fps_infomax_new['id'] = fps_infomax_new['id'].map(lambda cid: mapping.get(cid, np.nan))

n_unmapped = int(fps_infomax_new['id'].isna().sum())
if n_unmapped:
    print(f"WARNING: {n_unmapped} row(s) have a cid with no matching id in {drugs_name}")

fps_infomax_new = fps_infomax_new.set_index('id', drop=True)

In [6]:
# Preview the first five rows of the fingerprint table as a sanity check.
fps_infomax_new.head(n=5)

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,290,291,292,293,294,295,296,297,298,299
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,0.006908,-0.110677,0.089927,-0.088475,-0.087732,0.018758,0.084044,0.065477,0.110099,-0.00996,...,-0.167624,-0.065075,0.09484,-0.018449,-0.033862,0.02081,0.04761,0.178198,-0.484106,0.021591
2,0.16054,0.0044,0.064088,0.266052,-0.008104,0.019542,0.044715,0.033699,-0.173925,-0.038794,...,0.118521,-0.068513,-0.060505,-0.02203,0.346666,0.015235,-0.055782,0.122431,-0.086857,0.307728
3,-0.085681,-0.081969,-0.025518,0.049404,0.106336,-0.008085,0.042325,-0.011998,0.04468,-0.055106,...,0.090689,0.238188,-0.052238,-0.175773,0.095863,0.014043,0.00566,-0.019315,-0.24578,0.207331
4,-0.06181,-0.077405,0.055247,-0.013646,0.02292,-0.004234,-0.033726,-0.056457,0.014356,-0.23992,...,0.154744,0.107573,-0.013802,0.014862,0.010022,0.016165,0.027415,-0.078507,0.084592,-0.015352
5,0.063767,-0.020197,0.055716,-0.009538,0.142679,0.013067,-0.067377,0.076436,0.026296,-0.041145,...,0.056123,0.050382,0.049863,-0.059274,0.169545,0.017217,0.044571,-0.065558,0.035818,-0.272809


In [7]:
# Persist the fingerprint table so that later steps can load it from disk.
fingerprint_pickle_path = '/tf/notebooks/code_for_pub/fp_files/fps_infomax_new.pickle'
with open(fingerprint_pickle_path, 'wb') as fingerprint_file:
    pickle.dump(fps_infomax_new, fingerprint_file)