In [None]:
import sys
sys.path.append('../../EquiScore')
import pickle
from train import *
from torch import distributed as dist
import torch.multiprocessing as mp
from dist_utils import *
import utils
import random
from parsing import parse_train_args
# parse_train_args() calls parser.parse_args(), which reads sys.argv. Inside a
# notebook sys.argv typically holds the kernel launcher's own arguments, which
# argparse rejects. Hide everything after argv[0] for the duration of the call so
# the parser falls back to its defaults (no need to patch parse_train_args).
_kernel_argv = sys.argv
sys.argv = sys.argv[:1]
try:
    args = parse_train_args()
finally:
    sys.argv = _kernel_argv
seed_torch(seed = args.seed)
args.ngpu = 1  # this notebook runs single-GPU inference only

In [None]:
'''
Utility functions: attention extraction from DGL graphs and weight-normalisation helpers.

'''
import seaborn as sns
import torch.nn.functional as F
import matplotlib.pyplot as plt
from dgl.nn.functional import edge_softmax
def MeanStd(data):
    """Return ``data`` rescaled so that its entries sum to 1.

    NOTE(review): despite the name, no mean/std standardisation happens. The
    values are min-max scaled and then immediately mapped back with the same
    minimum and range, so (up to the 1e-11 epsilon) that round trip is the
    identity and the result is effectively ``data / data.sum()``.

    Args:
        data: numeric array-like supporting elementwise arithmetic (e.g. a
            numpy array); a plain Python list fails at ``data - lo``.

    Returns:
        Array shaped like ``data``, divided by its own sum.
    """
    hi = max(data)
    lo = min(data)
    span = hi - lo
    # epsilon keeps the division finite when every value is identical
    unit_interval = (data - lo) / (span + 1e-11)
    restored = unit_interval * span + lo
    return restored / restored.sum()
def MaxMin(atom_weights):
    """Min-max scale ``atom_weights`` into [0, 1].

    Args:
        atom_weights: 1-D sequence or array of numbers (list, numpy array, or
            a list of numpy scalars).

    Returns:
        numpy array of the same length. If every weight is identical (including
        a single weight), there is no range to scale by, so all zeros are
        returned instead of NaN. An empty input returns an empty array.
    """
    import numpy as np

    weights = np.asarray(atom_weights)
    if weights.size == 0:
        return weights.astype(float)
    min_value = weights.min()
    span = weights.max() - min_value
    if span == 0:
        # avoid 0/0 -> NaN, which would poison downstream colour mapping
        return np.zeros(weights.shape)
    return (weights - min_value) / span
def plotAttDist(pred_g):
    """Plot the edge attention of ``pred_g`` averaged over heads, one point per edge.

    The raw logits in ``pred_g.edata['score']`` are normalised with DGL's
    ``edge_softmax``, the trailing axis is squeezed and axis 1 is averaged.
    The result is drawn on the current matplotlib axes.

    NOTE(review): assumes ``edata['score']`` is shaped (num_edges, num_heads, 1)
    — TODO confirm against the model.

    Returns:
        The matplotlib Axes that was drawn on.
    """
    score = edge_softmax(graph = pred_g,logits = pred_g.edata['score'])
    atts = score.squeeze(-1).mean(1)
    # matplotlib cannot consume CUDA tensors or tensors that require grad
    atts = atts.detach().cpu().numpy()
    plt.plot(atts)
    ax = plt.gca()
    ax.set_xlabel('edge id')
    ax.set_ylabel('mean attention over heads')
    return ax
def copyAtt(pred_g, atts=None):
    """Scatter per-edge attention into a dense (num_nodes, num_nodes) matrix.

    Args:
        pred_g: DGL-style graph exposing ``num_nodes()``, ``edges()``,
            ``edge_ids(u, v)`` and (when ``atts`` is None) ``edata['score']``.
        atts: optional per-edge attention indexed by edge id (tensor or
            array-like). When omitted, it is computed the same way as in
            ``plotAttDist``: ``edge_softmax`` of ``edata['score']``, squeezed
            and averaged over axis 1. (The original body read an undefined
            ``atts``.)

    Returns:
        numpy array where ``[u, v]`` holds the attention of edge u->v; node
        pairs without an edge stay 0.
    """
    import numpy as np
    import torch

    if atts is None:
        score = edge_softmax(graph=pred_g, logits=pred_g.edata['score'])
        atts = score.squeeze(-1).mean(1)
    if torch.is_tensor(atts):
        atts = atts.detach().cpu().numpy()
    atts = np.asarray(atts)

    num_nodes = pred_g.num_nodes()
    atts_copy = np.zeros(shape=(num_nodes, num_nodes))
    u, v = pred_g.edges()
    eids = pred_g.edge_ids(u, v)
    # numpy fancy indexing needs host arrays, not (possibly CUDA) tensors
    u, v, eids = (t.cpu().numpy() for t in (u, v, eids))
    atts_copy[u, v] = atts[eids]
    return atts_copy
def getAttnFP(pred_g,fp,n1):
    """Split edge attention of ``pred_g`` into atom-pair edges and all other edges.

    Args:
        pred_g: DGL graph whose ``edata['score']`` holds raw attention logits.
        fp: sequence of index pairs. Column 0 is used directly as a node id;
            column 1 is offset by ``n1`` (presumably ligand nodes come first and
            column 1 indexes pocket atoms — confirm against the dataset).
        n1: offset added to column 1 of ``fp``.

    Returns:
        (attnFP, attnFP_res):
        - attnFP: edge-softmax attention (logits clamped to [-5, 5]) for the pair
          edges in both directions, col0->col1 pairs first, then col1->col0.
        - attnFP_res: attention for every other edge, in ascending edge-id order.
    """
    pairs = np.array(fp)
    first = [int(a) for a in pairs[:, 0]]
    second = [int(b) + n1 for b in pairs[:, 1]]
    # both directions of every pair
    eids = pred_g.edge_ids(first + second, second + first)

    # Complement via a boolean mask: deterministic (sorted) order, stays on the
    # graph's device, and is always an integer index even when it is empty.
    rest_mask = torch.ones(pred_g.num_edges(), dtype=torch.bool, device=eids.device)
    rest_mask[eids] = False
    eid_res = rest_mask.nonzero(as_tuple=True)[0]

    score = edge_softmax(graph = pred_g,logits = pred_g.edata['score'].clamp(-5,5))
    attnFP = score[eids]
    attnFP_res = score[eid_res]
    return attnFP,attnFP_res


In [None]:
# select keys to compute 
def getDataLoader(args,key_type= 'active',nums = 10000):
    """Build a batch-size-1 test DataLoader over a shuffled subset of sample keys.

    Args:
        args: parsed options. Reads ``test_keys`` (either a path to a pickled
            key list, or the key list itself), ``data_path``, ``debug`` and ``ngpu``.
        key_type: 'active' keeps keys whose file name contains '_active';
            'decoy' keeps the others; any other value keeps both.
        nums: maximum number of keys kept after shuffling.

    Returns:
        (test_sampler, test_dataloader, test_keys). test_sampler is None unless
        ``ngpu > 1``. The loader uses ``shuffle=False``, so without a sampler it
        yields samples in ``test_keys`` order (``getResult`` relies on this).
    """
    import os

    if isinstance(args.test_keys, (str, bytes, os.PathLike)):
        # SECURITY: pickle.load can execute arbitrary code; only load trusted key files.
        with open(args.test_keys, 'rb') as fp:
            keys = pickle.load(fp)
    else:
        keys = args.test_keys
    test_actives = [i for i in keys if '_active'  in i.split('/')[-1] ]
    test_decoys = [i for i in keys if '_active' not in i.split('/')[-1]]
    all_pocket =  test_actives + test_decoys
    # keep this shuffle order: with a fixed seed it determines which keys are picked
    random.shuffle(all_pocket)
    random.shuffle(test_actives)
    random.shuffle(test_decoys)

    print(f'actives num : {len(test_actives)} decoys num : {len(test_decoys)}')
    if key_type== 'active':
        test_keys = test_actives[:nums]
    elif key_type== 'decoy' :
        test_keys = test_decoys[:nums]
    else:
        test_keys = all_pocket[:nums]
    test_dataset = ESDataset(test_keys,args, args.data_path,args.debug)
    test_sampler = SequentialDistributedSampler(test_dataset,1) if args.ngpu > 1 else None
    test_dataloader = DataLoaderX(test_dataset, 1, sampler=test_sampler,prefetch_factor = 4,\
        shuffle=False, num_workers = 1, collate_fn=test_dataset.collate,pin_memory=True)
    return test_sampler,test_dataloader,test_keys

In [None]:
def getModel(args):
    """Build an EquiScore model for single-GPU inference and load its checkpoint.

    Mutates ``args``:
    - sets ``N_atom_features`` to 39 when ``FP`` is set, else 28;
    - forces ``ngpu`` and ``batch_size`` to 1;
    - sets ``device`` and ``local_rank`` to 'cuda:0'.
    Also sets CUDA_VISIBLE_DEVICES to the GPU returned by ``get_available_gpu``.

    Args:
        args: parsed options. Also reads ``model`` and ``save_model`` (the
            checkpoint handed to ``utils.initialize_model``).

    Returns:
        The initialised model.

    Raises:
        ValueError: if ``args.model`` is not 'EquiScore'.
    """
    import os

    args.N_atom_features = 39 if args.FP else 28
    args.ngpu = 1
    args.batch_size = 1
    if args.ngpu > 0:
        # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
        # initialised yet in this process — confirm nothing earlier touched the GPU.
        cmd = get_available_gpu(num_gpu=args.ngpu, min_memory=8000, sample=3, nitro_restriction=False, verbose=True)
        # drop a single trailing comma from the returned device list
        os.environ['CUDA_VISIBLE_DEVICES'] = cmd[:-1] if cmd[-1] == ',' else cmd
        print(cmd)
    if args.model != 'EquiScore':
        # previously this fell through to None and crashed on .parameters()
        raise ValueError(f"Unsupported model {args.model!r}; only 'EquiScore' is available")
    model = EquiScore(args)
    print ('number of parameters : ', sum(p.numel() for p in model.parameters() if p.requires_grad))
    args.device = 'cuda:0'
    args.local_rank = 'cuda:0'
    model, _opt_dict, _epoch_start = utils.initialize_model(model, args.device,args,args.save_model)
    return model

In [None]:
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem.Draw import SimilarityMaps
import  rdkit.Chem.AllChem as AllChem
IPythonConsole.ipython_useSVG = True
import copy
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import SVG
from IPython.display import display
import matplotlib
import matplotlib.cm as cm
def InchMap(mol):
    """Map between RDKit atom indices and InChI canonical atom order.

    Parses the ``N:`` layer of the molecule's InChI AuxInfo. That layer lists
    the 1-based original atom numbers in canonical order. If several ``N:``
    layers are present, the last one wins (same as before).

    Args:
        mol: RDKit molecule.

    Returns:
        (inchi_to_mol, mol_to_inchi): ``inchi_to_mol[k]`` is the 0-based RDKit
        index of canonical atom k; ``mol_to_inchi`` is the inverse mapping.

    Raises:
        ValueError: if the AuxInfo has no ``N:`` layer (e.g. InChI generation
            failed and returned an empty AuxInfo).
    """
    _, aux_info = Chem.MolToInchiAndAuxInfo(mol)
    pos = None
    for layer in aux_info.split('/'):
        # startswith is safe on empty segments, unlike layer[0]
        if layer.startswith('N:'):
            pos = layer[2:].split(',')
    if pos is None:
        raise ValueError(f'no N: layer in InChI AuxInfo {aux_info!r}')
    inchi_to_mol = {i: int(j) - 1 for i, j in enumerate(pos)}
    mol_to_inchi = {m: i for i, m in inchi_to_mol.items()}
    return inchi_to_mol,mol_to_inchi
def mol_with_atom_weight( mol ,weights,mol_to_inchi):
    """Attach each atom's weight, rounded to 3 decimals, as the string property "atomNote_Score".

    Args:
        mol: RDKit molecule; its atoms are modified in place.
        weights: indexable weights, looked up via ``mol_to_inchi``.
        mol_to_inchi: maps an atom index of ``mol`` to its position in ``weights``.

    Returns:
        The same ``mol`` object.
    """
    for atom in mol.GetAtoms():
        weight = weights[mol_to_inchi[atom.GetIdx()]]
        atom.SetProp("atomNote_Score", str(round(float(weight), 3)))
    return mol

def drawmol(mol,atom_weights,highlightAtoms = [31,27,26,2,1,0],flag = ''):
    """Render ``mol`` as an SVG with atoms coloured by min-max scaled weights.

    Args:
        mol: RDKit molecule whose atom order matches the leading entries of
            ``atom_weights``.
        atom_weights: per-node weights. Entries beyond the redrawn molecule's
            atom count are written into the legend as "virtual node" weights.
        highlightAtoms: atom indices, in the redrawn canonical-SMILES molecule,
            to highlight. A falsy value highlights every atom. The default list
            was picked for ptp1b_23485; other ligands need their own indices.
            The list is never mutated.
        flag: text prepended to the legend.

    Returns:
        (drawn_mol, scaled_weights, svg_text). Each atom of ``drawn_mol`` that
        existed before PrepareMolForDrawing carries its scaled weight as the
        string property "atomNote_Score".
    """
    import numpy as np

    # InChI canonical numbering is the index shared by the input atom order and
    # the atom order of the canonical-SMILES copy drawn below.
    inchi_to_mol, _ = InchMap(mol)
    weight_inchi = []
    for i in range(len(atom_weights)):
        try:
            weight_inchi.append(atom_weights[inchi_to_mol[i]])
        except (KeyError, IndexError):
            # no InChI atom i (e.g. virtual nodes): keep the positional weight
            weight_inchi.append(atom_weights[i])

    mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
    # Canonical SMILES generally reorders atoms, so the mapping must be rebuilt for
    # the redrawn molecule; reusing the input's mapping paired atoms with the
    # weights of whichever input atom happened to share their index.
    _, mol_to_inchi = InchMap(mol)
    weight_inchi = MaxMin(np.asarray(weight_inchi))

    n_atoms = mol.GetNumAtoms()
    note = flag + '  virtual node weight:'
    for i in range(n_atoms, len(weight_inchi)):
        note += f' {i - n_atoms + 1} : {round(weight_inchi[i],3)} '

    mol = mol_with_atom_weight( mol ,weight_inchi,mol_to_inchi)

    # Weights are in [0, 1], so with this norm only the white-to-red half of 'bwr'
    # is used. Alternative ranges tried: minmax (vmin=-1.0, vmax=1.28);
    # meanstd (vmin=-0.2, vmax=1.0).
    norm = matplotlib.colors.Normalize(vmin=-1, vmax=1)
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    cmap = plt.get_cmap('bwr')
    plt_colors = cm.ScalarMappable(norm=norm, cmap=cmap)
    atom_colors = {idx: plt_colors.to_rgba(float(mol.GetAtomWithIdx(idx).GetProp("atomNote_Score")))
                   for idx in range(n_atoms)}

    rdDepictor.Compute2DCoords(mol)
    drawer = rdMolDraw2D.MolDraw2DSVG(1200, 800)
    drawer.SetFontSize(120)

    mol = rdMolDraw2D.PrepareMolForDrawing(mol)
    # atom count is taken after PrepareMolForDrawing, which may add atoms
    drawer.DrawMolecule(mol, highlightAtoms=range(mol.GetNumAtoms()) if not highlightAtoms else highlightAtoms,
                        highlightBonds=[],
                        highlightBondColors=atom_colors,
                        highlightAtomColors=atom_colors, legend=note)
    drawer.FinishDrawing()
    svg = drawer.GetDrawingText().replace('svg:', '')
    return (mol, weight_inchi, svg)

In [None]:


def save_PNG_smilarymap(model,h,mol,pred_g,logit,key,save_dir):
    """Save an RDKit similarity-map PNG of the model's per-atom weights for one sample.

    Args:
        model: network exposing ``weight_and_sum.atom_weighting`` (an iterable
            of modules applied in sequence to ``h``).
        h: node representations fed through those modules to get node weights.
        mol: ligand RDKit molecule.
        pred_g: graph whose ``ndata['V']`` sums to the number of ligand nodes.
        logit: raw class scores; the softmax probability of class 1 (first
            row — presumably a single-sample batch, confirm) goes into the
            legend and the file name.
        key: sample path. Its parent directory name becomes a sub-folder of
            ``save_dir``, and its base name prefixes the PNG file name.
        save_dir: root output directory (created if missing).
    """
    import os

    prob = F.softmax(logit,dim = 1)[:,1].data.cpu().numpy()[0]
    ligand_num = int(pred_g.ndata['V'].sum().data.cpu().numpy())
    for module in model.weight_and_sum.atom_weighting:
        h = module(h)
    mol, _, _ = drawmol(mol, h.flatten().data.cpu().numpy()[:ligand_num],flag = ' prob: ' + str(round(prob,4)))

    # The scaled weight lives in "atomNote_Score" (plain "atomNote" is never set).
    # Atoms added by PrepareMolForDrawing do not carry it, so they get 0.
    weights = [float(atom.GetProp("atomNote_Score")) if atom.HasProp("atomNote_Score") else 0.0
               for atom in mol.GetAtoms()]
    fig = SimilarityMaps.GetSimilarityMapFromWeights(mol,weights)
    # parent directory name of key; also valid when key has no directory part
    save_path = os.path.join(save_dir, os.path.basename(os.path.dirname(key)))
    os.makedirs(save_path, exist_ok=True)
    fig.savefig(os.path.join(save_path, os.path.basename(key) + '_{}.png'.format(round(prob,4))),dpi=660,format='png',bbox_inches='tight')
    # avoid accumulating open figures when called repeatedly
    plt.close(fig)


In [None]:

import pickle 
from equiscore_utils import *
from tqdm import tqdm
import dataset_utils
from cairosvg import svg2png
def save_PNG(model,h,mol,pred_g,logit,key,save_dir):
    """Save the ``drawmol`` SVG of the model's per-atom weights as a 1200-dpi PNG.

    Args:
        model: network exposing ``weight_and_sum.atom_weighting`` (an iterable
            of modules applied in sequence to ``h``).
        h: node representations fed through those modules to get node weights.
        mol: ligand RDKit molecule.
        pred_g: graph whose ``ndata['V']`` sums to the number of ligand nodes.
        logit: raw class scores; the softmax probability of class 1 (first
            row — presumably a single-sample batch, confirm) goes into the
            legend and the file name.
        key: sample path. Its parent directory name becomes a sub-folder of
            ``save_dir``, and its base name prefixes the PNG file name.
        save_dir: root output directory (created if missing).
    """
    import os

    prob = F.softmax(logit,dim = 1)[:,1].data.cpu().numpy()[0]
    ligand_num = int(pred_g.ndata['V'].sum().data.cpu().numpy())
    for module in model.weight_and_sum.atom_weighting:
        h = module(h)
    _, _, svg = drawmol(mol, h.flatten().data.cpu().numpy()[:ligand_num],flag = ' prob: ' + str(round(prob,4)))
    # parent directory name of key; the old os.path.join(*[]) raised TypeError
    # for keys without a '/'
    save_path = os.path.join(save_dir, os.path.basename(os.path.dirname(key)))
    os.makedirs(save_path, exist_ok=True)
    svg2png(bytestring=svg,dpi=1200,write_to=os.path.join(save_path, os.path.basename(key) + '_{}.png'.format(round(prob,4))))
def getResult(model,test_sampler,test_dataloader,test_keys,save_dir):
    """Run the model over ``test_dataloader`` and collect attention on atom-pair edges.

    For every sample:
    - a PNG of per-atom weights is saved under ``save_dir``;
    - if the sample file lists atom pairs, the attention on those pair edges
      and on all remaining edges (see ``getAttnFP``) is collected.

    NOTE(review): reads the module-level ``args`` (``FP``,
    ``virtual_aromatic_atom``, ``local_rank``) rather than a parameter.
    ``test_sampler`` is accepted but unused.

    Returns:
        (all_attns, all_attns_res): lists of numpy arrays, one entry per sample
        that had at least one atom pair.
    """
    all_attns = []
    all_attns_res = []
    n_with_fp = 0
    model.eval()
    pbar = tqdm(test_dataloader)
    with torch.no_grad():
        for i_batch, (g, full_g, _Y) in enumerate(pbar):
            # Valid only because the loader yields samples in test_keys order
            # (batch size 1, shuffle=False, no distributed sampler).
            key = test_keys[i_batch]
            # SECURITY: pickle.load can execute arbitrary code; only open trusted files.
            with open(key, 'rb') as f:
                m1, _m2, atompairs, _types = pickle.load(f)
            # n1 is the offset for pocket atom ids in getAttnFP; it changes when
            # virtual aromatic atoms are appended to the ligand.
            n1, d1, adj1 = dataset_utils.get_mol_info(m1)
            if args.virtual_aromatic_atom:
                H1 = get_atom_graphformer_feature(m1, FP=args.FP)
                adj1, H1, d1, n1 = dataset_utils.add_atom_to_mol(m1, adj1, H1, d1, n1)

            g = g.to(args.local_rank)
            full_g = full_g.to(args.local_rank)
            h, pred_g, pred_full_g, logit = model.getAtt(g, full_g)
            save_PNG(model, h, m1, pred_g, logit, key, save_dir=save_dir)
            if len(atompairs) > 0:
                # pass pred_full_g instead to include distance and IFP edges
                attns, attns_res = getAttnFP(pred_g, atompairs, n1)
                all_attns.append(attns.data.cpu().numpy())
                all_attns_res.append(attns_res.data.cpu().numpy())
                n_with_fp += 1
            pbar.set_description('fp num : {}-- ratio : {}/{}'.format(n_with_fp, i_batch + 1, len(test_dataloader)))
    return all_attns, all_attns_res


In [None]:
'''
Atom indices to highlight when plotting, per target:
ptp1b_23485 : 31,27,26,2,1,0
ptp1b_23484 : 31,5,6,7,32,
'''


In [None]:

import glob
# Score every leadopt pocket file for target 23485 and save attention figures.
save_dir = './figs_dpi660_norm_style_ptp1b_local_color_23485/'
# glob returns files in arbitrary, filesystem-dependent order; sort it so the
# seeded shuffle in getDataLoader selects the same keys on every machine.
test_keys = sorted(k for k in glob.glob('./leadopt_pocket/*') if '23485' in k)
args.test_keys = test_keys
args.test = True
args.save_model = './EquiScore/workdir/official_weight/save_best_f1_model_LeadOpt.pt'
test_sampler,test_dataloader,test_keys = getDataLoader(args,key_type= 'all',nums = 10000)
model = getModel(args)
active_attns,all_attns_res= getResult(model,test_sampler,test_dataloader,test_keys,save_dir = save_dir)