# Count how many input SMILES cannot be processed by GeoMol

In [34]:
import sys
print (sys.version)  # Should be 3.8, otherwise brew install python@3.8

## Make sure to have the correct package versions :
# Should be torch-geometric 1.6.3. and pytorch to 1.7.0, otherwise:
# python3.8  -m pip install --upgrade --force-reinstall gast==0.3.3 grpcio~=1.32.0 typing-extensions~=3.7.4 h5py~=2.10.0 numpy~=1.19.2 six~=1.15.0 torch==1.7.0 torchvision==0.8.1 torchaudio==0.7.0  torch-scatter==2.0.7 torch-sparse==0.6.9 torch-cluster==1.5.9 torch-spline-conv==1.2.1 torch-geometric==1.6.3 -f https://data.pyg.org/whl/torch-1.7.0+cpu.html
# (from https://github.com/pytorch/pytorch/issues/47354)

# Install rdkit:
# !python3.8 -m pip install rdkit-pypi
# !python3.8 -m pip install --upgrade --force-reinstall numpy==1.19.2 pot==0.7.0


import torch 
print(torch.__version__)
import torch_geometric as tg
print(tg.__version__)


import numpy as np
print(np.__version__)


import ot
print(ot.__version__)

import rdkit
print(rdkit.__version__)

import os
import json
from rdkit import Chem, Geometry
from rdkit.Chem import AllChem
import pickle
import pandas as pd
from tqdm import tqdm
import random
import yaml

from model.model import GeoMol
from model.featurization import featurize_mol_from_smiles
from torch_geometric.data import Batch
from model.inference import construct_conformers

import random

import pickle

#!python3.8 -m pip install ipywidgets py3Dmol
from ipywidgets import interact, fixed, IntSlider
import ipywidgets
import py3Dmol

3.8.12 (default, Oct 13 2021, 06:42:42) 
[Clang 13.0.0 (clang-1300.0.29.3)]
1.7.0
1.6.3
1.19.2
0.7.0
2021.09.4


In [4]:
# Render a link that toggles the visibility of stderr output areas.
# The script uses jQuery selectors on `div.output_stderr`, so it presumably only
# works in the classic Notebook UI -- TODO confirm for JupyterLab.
# On page load `code_toggle_err` runs once; since the flag starts as false, that
# first call *shows* stderr, and every later click on the link flips visibility.
from IPython.display import HTML

_TOGGLE_STDERR_SNIPPET = '''<script>
code_show_err=false; 
function code_toggle_err() {
 if (code_show_err){
 $('div.output_stderr').hide();
 } else {
 $('div.output_stderr').show();
 }
 code_show_err = !code_show_err
} 
$( document ).ready(code_toggle_err);
</script>
To toggle on/off output_stderr, click <a href="javascript:code_toggle_err()">here</a>.'''
HTML(_TOGGLE_STDERR_SNIPPET)

In [None]:
# Count how many QM9 SMILES GeoMol's featurizer accepts.
# featurize_mol_from_smiles returns a falsy value for SMILES it cannot handle.
with open('summary_qm9.json') as f:  # `with` closes the file (the bare open() leaked it)
    data = json.load(f)
print(len(data))

worked = 0
processed = 0
for smi in data:
    processed += 1
    tg_data = featurize_mol_from_smiles(smi, dataset='qm9')
    if tg_data:
        worked += 1

print(processed, worked)
        


In [18]:
# Same code as in visualize_confs.ipynb

def show_mol(mol, view, grid):
    """Draw ``mol`` as sticks in cell ``grid`` of a py3Dmol grid ``view``.

    Any models already present in that cell are removed first. The molblock is
    built before touching the view, so a conversion error leaves the view as-is.
    Returns the same ``view`` object so the call can be chained or displayed.
    """
    mol_block = Chem.MolToMolBlock(mol)
    cell = dict(viewer=grid)
    view.removeAllModels(**cell)
    view.addModel(mol_block, 'sdf', **cell)
    view.setStyle({'model': 0}, {'stick': {}}, **cell)
    view.zoomTo(**cell)
    return view

def view_single(mol):
    """Create a 600x600 py3Dmol view with a 1x1 grid and draw ``mol`` in its only cell."""
    single = py3Dmol.view(width=600, height=600, linked=False, viewergrid=(1, 1))
    show_mol(mol, single, grid=(0, 0))
    return single

def MolTo3DView(mol, size=(600, 600), style="stick", surface=False, opacity=0.5, confId=0):
    """Draw one molecule / conformer in 3D with py3Dmol.

    Args:
    ----
        mol: either a sequence of RDKit molecules (e.g. the per-conformer list stored in
             ``conformer_dict``), in which case ``mol[confId]`` is drawn; or a single
             RDKit ``Mol``, in which case its conformer with id ``confId`` is drawn.
        size: tuple(int, int), canvas size (width, height)
        style: str, drawing style: 'line', 'stick', 'sphere' or 'cartoon'
               ('carton' is still accepted as a legacy alias for 'cartoon')
        surface: bool, also display the solvent-accessible surface (SAS)
        opacity: float, opacity of the surface, range 0.0-1.0
        confId: int, index into ``mol`` (sequence) or conformer id (single Mol)
    Return:
    ----
        viewer: py3Dmol.view, a class for constructing embedded 3Dmol.js views in ipython notebooks.
    """
    # 'carton' was a typo for the 3Dmol.js 'cartoon' style; keep accepting it.
    if style == 'carton':
        style = 'cartoon'
    assert style in ('line', 'stick', 'sphere', 'cartoon'), f'unsupported style: {style!r}'

    if isinstance(mol, Chem.Mol):
        mblock = Chem.MolToMolBlock(mol, confId=confId)
    else:
        mblock = Chem.MolToMolBlock(mol[confId])

    viewer = py3Dmol.view(width=size[0], height=size[1])
    viewer.addModel(mblock, 'mol')
    viewer.setStyle({style: {}})
    if surface:
        viewer.addSurface(py3Dmol.SAS, {'opacity': opacity})
    viewer.zoomTo()
    return viewer

def conf_viewer(idx, mol):
    """``ipywidgets.interact`` callback: draw entry ``idx`` of ``mol`` and show the viewer.

    In this notebook ``mol`` is the per-conformer molecule list passed via ``fixed(...)``.
    """
    viewer = MolTo3DView(mol, confId=idx)
    return viewer.show()

## Code adapted from `generate_confs.py`: we ask GeoMol to generate one conformer for each input SMILES and count how often it fails.

## NOTE: to reproduce the numbers in our paper, one additionally needs to run `scripts/clean_smiles.py` (which accounts for inconsistent molecules in the ground-truth dataset) and to sample twice as many conformers as there are ground-truth conformers for each molecule. Without these steps the code below still works, but the final metric values will be worse.

In [37]:

def _load_geomol(trained_model_dir):
    """Build GeoMol from ``model_parameters.yml`` in ``trained_model_dir`` and load ``best_model.pt``.

    Weights are mapped to CPU; the model is returned in eval mode.
    """
    with open(os.path.join(trained_model_dir, 'model_parameters.yml')) as f:
        model_parameters = yaml.full_load(f)
    model = GeoMol(**model_parameters)

    state_dict = torch.load(os.path.join(trained_model_dir, 'best_model.pt'),
                            map_location=torch.device('cpu'))
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model


def _predict_conformers(smi, tg_data, model, n_model_confs=1, mmff=False):
    """Run GeoMol on one featurized SMILES; return one RDKit mol (with Hs) per predicted conformer."""
    batch = Batch.from_data_list([tg_data])
    # Inference only: no need for autograd bookkeeping.
    with torch.no_grad():
        model(batch, inference=True, n_model_confs=n_model_confs)
        # presumably reads state cached on `model` by the forward call above -- keep this order
        model_coords = construct_conformers(batch, model)

    n_atoms = tg_data.x.size(0)
    mols = []
    # model_coords appears to be (n_atoms, n_confs, 3): each split along dim 1 is one conformer.
    for x in model_coords.split(1, dim=1):
        coords = x.squeeze(1).double().cpu().detach().numpy()
        conf = Chem.Conformer(n_atoms)
        for i in range(n_atoms):
            conf.SetAtomPosition(i, Geometry.Point3D(coords[i, 0], coords[i, 1], coords[i, 2]))
        mol = Chem.AddHs(Chem.MolFromSmiles(smi))
        mol.AddConformer(conf, assignId=True)

        if mmff:
            try:
                AllChem.MMFFOptimizeMoleculeConfs(mol, mmffVariant='MMFF94s')
            except Exception:
                # Best effort: keep the raw GeoMol geometry if MMFF setup/optimization fails.
                pass
        mols.append(mol)
    return mols


def process_dataset(dataset='qm9', n_smiles=10000, seed=None, mmff=False, n_model_confs=1):
    """Generate GeoMol conformers for a random sample of SMILES from ``summary_<dataset>.json``.

    Args:
        dataset: selects the model in ``trained_models/<dataset>/`` and the file
            ``summary_<dataset>.json`` whose keys are the candidate SMILES.
        n_smiles: how many SMILES to process after shuffling (``None`` = all of them).
        seed: seed for the shuffle; ``None`` uses the global ``random`` state (previous behaviour).
        mmff: if True, additionally relax each conformer with MMFF94s (best effort).
        n_model_confs: number of conformers requested from GeoMol per SMILES.

    Returns:
        dict mapping each SMILES that produced at least one conformer to a list of RDKit
        mols, one per predicted conformer. SMILES absent from the dict failed to featurize
        (or yielded no conformers), so the failure count is the number of sampled SMILES
        minus ``len(result)``.

    Side effect: prints the number of entries in the summary JSON.
    """
    model = _load_geomol(os.path.join('trained_models', dataset))

    with open('summary_' + dataset + '.json') as f:
        summary = json.load(f)
    print(len(summary))

    all_dataset_smiles = list(summary)
    # A seeded private RNG makes the sample reproducible without touching global state.
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(all_dataset_smiles)

    conformer_dict = {}
    for smi in all_dataset_smiles[:n_smiles]:
        # create data object (skip SMILES the featurizer can't handle)
        tg_data = featurize_mol_from_smiles(smi, dataset=dataset)
        if not tg_data:
            continue
        mols = _predict_conformers(smi, tg_data, model, n_model_confs=n_model_confs, mmff=mmff)
        if mols:
            conformer_dict[smi] = mols
    return conformer_dict

# Run the QM9 sample (process_dataset also prints the size of summary_qm9.json).
conformer_dict = process_dataset(dataset='qm9')


In [38]:
# Report successes AND failures -- the failure count is what this notebook is about.
# n_sampled must match the sample size process_dataset used (10000); this assumes
# summary_qm9.json has at least that many SMILES.
n_sampled = 10000
n_ok = len(conformer_dict)
print(f'Using {n_sampled} random SMILES from QM9: {n_ok} produced conformers, '
      f'{n_sampled - n_ok} failed.')


Using 10000 random SMILES from QM9, we get   num SMILES with successfully generated conformers  9938


In [40]:
# Run the DRUGS sample; this overwrites the QM9 `conformer_dict` from above.
conformer_dict = process_dataset('drugs')

# n_sampled must match the sample size process_dataset used (10000); this assumes
# summary_drugs.json has at least that many SMILES.
n_sampled = 10000
n_ok = len(conformer_dict)
print(f'Using {n_sampled} random SMILES from DRUGS: {n_ok} produced conformers, '
      f'{n_sampled - n_ok} failed.')


304466
Using 10000 random SMILES from DRUGS, we get   num SMILES with successfully generated conformers  9999


In [41]:
# Show the generated conformer(s) for the first three SMILES in `conformer_dict`
# (the DRUGS run from the previous cell). Dict insertion order follows the random
# shuffle in process_dataset, so these are random picks. `interact` displays each
# widget as it is created, so the loop renders one slider+viewer per SMILES; slicing
# means fewer than three entries no longer raises IndexError.
for smi in list(conformer_dict)[:3]:
    mols = conformer_dict[smi]
    interact(conf_viewer, idx=ipywidgets.IntSlider(min=0, max=len(mols) - 1, step=1), mol=fixed(mols))


interactive(children=(IntSlider(value=0, description='idx', max=0), Output()), _dom_classes=('widget-interact'…

interactive(children=(IntSlider(value=0, description='idx', max=0), Output()), _dom_classes=('widget-interact'…

interactive(children=(IntSlider(value=0, description='idx', max=0), Output()), _dom_classes=('widget-interact'…