In [1]:
import os

# Move two directories up from where the kernel started (presumably the project root, so
# `molexplain` and the relative data paths used below resolve). The starting directory is
# remembered the first time this cell runs, so re-running it does not keep climbing the tree.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, '..', '..'))

In [2]:
import numpy as np
import pandas as pd
import torch

from rdkit.Chem import MolFromSmiles, MolFromInchi, MolToInchi
from rdkit.Chem.Draw import IPythonConsole
from IPython.display import SVG

from molexplain.utils import MODELS_PATH, PROCESSED_DATA_PATH, DEVICE
from molexplain.vis import molecule_importance

Using backend: pytorch


In [3]:
from molexplain.net import MPNNPredictor

# Feature sizes must match those the CYP3A4_noHs.pt checkpoint was trained with.
model = MPNNPredictor(node_in_feats=46,
                      edge_in_feats=10,
                      global_feats=4,
                      n_tasks=1).to(DEVICE)
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load(os.path.join(MODELS_PATH,
                                              "CYP3A4_noHs.pt"),
                                 map_location=DEVICE))
# The model is only used for inference/explanation below, so disable any
# training-time behaviour (dropout, batch-norm statistic updates) it may have.
model.eval()

df = pd.read_csv('../cyp/CYP3A4.csv', header=0, sep=';')
df = df.loc[df['Class'] == 'Active']  ## Filter inactives
smiles = df['SMILES'].to_numpy()


def _smiles_to_roundtrip_inchi(sm):
    """Convert a SMILES string to InChI, keeping it only if RDKit can parse it back.

    Parameters
    ----------
    sm : str
        SMILES string of one molecule.

    Returns
    -------
    str or None
        The InChI string, or None if the SMILES cannot be parsed, the InChI is
        empty, or the InChI does not parse back into a molecule.
    """
    mol = MolFromSmiles(sm)
    if mol is None:  # unparsable SMILES
        return None
    inchi = MolToInchi(mol)
    # Later cells rebuild molecules with MolFromInchi, so only keep InChIs that round-trip.
    if not inchi or MolFromInchi(inchi) is None:
        return None
    return inchi


inchis = []
invalid_idx = []  # positions in `smiles` (not df index labels) that failed conversion

for idx, sm in enumerate(smiles):
    try:
        inchi = _smiles_to_roundtrip_inchi(sm)
    except Exception:
        # Best-effort: any conversion error marks this molecule invalid, but
        # KeyboardInterrupt/SystemExit still propagate (unlike a bare `except:`).
        inchi = None
    if inchi is None:
        invalid_idx.append(idx)
    else:
        inchis.append(inchi)


inchis = np.array(inchis)

In [4]:
# Sanity check: print the InChI string of one converted molecule (position 10 of `inchis`).
idx = 10
example_inchi = inchis[idx]
print(example_inchi)

InChI=1S/C20H21F3N2O/c1-15-5-7-16(8-6-15)13-19(26)25-11-9-24(10-12-25)18-4-2-3-17(14-18)20(21,22)23/h2-8,14H,9-13H2,1H3


In [None]:
from tqdm.auto import tqdm

IMG_DIR = 'imgs_cyp_v2_noHs'
# Rendering every molecule is slow (~1.3 s each, over an hour in total). Set to True to
# resume an interrupted run without re-rendering SVGs that already exist on disk.
SKIP_EXISTING = False
os.makedirs(IMG_DIR, exist_ok=True)

# SVG file names are positions in `inchis`, i.e. only molecules that passed the
# round-trip check -- they are not row labels of the original CSV.
for idx, inchi in enumerate(tqdm(inchis, desc='rendering')):
    svg_path = os.path.join(IMG_DIR, f'{idx}.svg')
    if SKIP_EXISTING and os.path.exists(svg_path):
        continue
    mol = MolFromInchi(inchi)
    # Only the SVG is written out; global_importance is kept bound for interactive inspection.
    svg, _, _, _, global_importance = molecule_importance(mol,
                                                          model,
                                                          task=0,
                                                          vis_factor=5,
                                                          addHs=False)

    # Write-only access is enough; explicit UTF-8 avoids platform-dependent default encodings.
    with open(svg_path, 'w', encoding='utf-8') as handle:
        handle.write(svg)

  4%|▍         | 163/3626 [03:12<1:13:59,  1.28s/it]

In [None]:
# Persist the validated InChIs. Position i here matches `{i}.svg` from the rendering loop.
# This cell only needs `inchis`, so it can be run without waiting for rendering to finish.
# Saved next to the source CSV ('../cyp/CYP3A4.csv') instead of an absolute, user-specific path.
# NOTE(review): the original target was /home/jose/cyp/inchis.npy -- confirm that this
# relative path resolves to the same directory on the author's machine.
INCHIS_PATH = os.path.join('..', 'cyp', 'inchis.npy')
np.save(INCHIS_PATH, arr=inchis)