## Load raw generation results

In [1]:
import os
import pickle

# Move to the parent of the kernel's starting working directory (presumably
# the project root) so the relative "./data" and "./result" paths resolve.
# The root is computed only once per kernel: without this guard, re-running
# the cell would climb one more directory each time, since cwd has changed.
if "_PROJECT_ROOT" not in globals():
    _PROJECT_ROOT = os.path.dirname(os.path.abspath(''))
path = _PROJECT_ROOT
os.chdir(path)

data_path = os.path.join(path, "data", "MOSES2_test_mol.pkl")
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open(data_path, 'rb') as data_file:
    test_data = pickle.load(data_file)
raw_path = "./result/with_guide/"

In [2]:
# Parse index_map.txt: each non-blank line is "<key>:<value>", both ints.
# The resulting dict maps a generation-result index to an index into test_data.
ligand_index_dict = {}

with open("./data/index_map.txt", 'r') as ligand_index_dict_file:
    for line in ligand_index_dict_file:
        stripped = line.strip()
        if not stripped:
            continue  # tolerate blank lines, e.g. a trailing newline
        items = stripped.split(":")
        ligand_index_dict[int(items[0])] = int(items[1])

### Use their modified reconstruction algorithm (`utils.reconstruct`)

In [7]:
from utils import reconstruct
from utils import transforms
import pdb
import torch
from rdkit import Chem
import numpy as np

# Generation-result indices (parsed from "result_<idx>.pt") to reconstruct.
select_idxs = [0, 1]
all_mols = {}        # data_idx -> list of reconstructed RDKit molecules
all_data_mols = {}   # NOTE(review): not filled in this cell
all_sims = {}        # NOTE(review): not filled in this cell
for f in os.listdir(raw_path):
    if not f.startswith("result") or not f.endswith(".pt"):
        continue
    data_idx = int(f.split("_")[-1].split(".")[0])
    if data_idx not in select_idxs:
        continue
    print(f)
    # Dedicated name so the project-root `path` from the first cell survives.
    result_path = os.path.join(raw_path, f)

    # Given a path, torch.load opens and closes the file itself.
    data = torch.load(result_path)

    mols = []
    num_complete = 0  # reconstructions whose SMILES is a single fragment (no ".")
    for i in range(len(data['pred_ligand_v_traj'])):
        # Index -1 takes the last frame of sample i's trajectory.
        pred_v = data['pred_ligand_v_traj'][i][-1]
        pred_pos = data['pred_ligand_pos_traj'][i][-1]
        pred_atom_type = transforms.get_atomic_number_from_index(pred_v, mode="add_aromatic")
        pred_aromatic = transforms.is_aromatic_from_index(pred_v, mode="add_aromatic")

        try:
            mol = reconstruct.reconstruct_from_generated(pred_pos, pred_atom_type, pred_aromatic,
                basic_mode=True)
        except Exception as exc:  # unlike a bare except, lets Ctrl-C stop the loop
            print("error: %s" % exc)
            continue
        smiles = Chem.MolToSmiles(mol)
        if "." not in smiles:
            num_complete += 1
        mols.append(mol)

    # The ratio is over successful reconstructions only (failures were skipped).
    if mols:
        print("complete percentage: %.4f" % (num_complete / len(mols)))
    else:
        print("complete percentage: n/a (no molecule reconstructed)")
    all_mols[data_idx] = mols
    print("get %d data" % (data_idx))

result_1.pt
result_1.pt


[20:55:45] Explicit valence for atom # 16 C, 5, is greater than permitted


error
complete percentage: 1.0000
get 1 data
result_0.pt
result_0.pt
complete percentage: 1.0000
get 0 data


### Visualization

In [12]:
import py3Dmol
from rdkit import Chem

# Show one generated molecule as cyan-carbon sticks inside a translucent
# py3Dmol SAS surface.
test_data_idx = 0
sample_idx = 3
sample_mol = all_mols[test_data_idx][sample_idx]
mblock = Chem.MolToMolBlock(sample_mol)

stick_style = {'stick': {'colorscheme': 'cyanCarbon'}}
surface_style = {'opacity': 0.5, 'probeRadius': 1.4}
view = py3Dmol.view(data=mblock, style=stick_style)
view.addSurface(py3Dmol.SAS, surface_style)

view.show()

In [15]:
import py3Dmol
from rdkit import Chem
import pdb

# Overlay a generated sample (red, model 0) on the reference ligand taken from
# test_data (gray, model 1). ligand_index_dict translates the generation index
# into the matching test_data index.
NUM_INDEX = 0
NUM_SAMPLE_INDEX = 2
mols = all_mols[NUM_INDEX]
idx = ligand_index_dict[NUM_INDEX]

mblock1 = Chem.MolToMolBlock(mols[NUM_SAMPLE_INDEX])
mblock2 = Chem.MolToMolBlock(test_data[idx])

view = py3Dmol.view()
for mblock, color in ((mblock1, 'red'), (mblock2, 'gray')):
    view.addModel(mblock, 'sdf')
    # model -1 addresses the model that was just added.
    view.setStyle({'model': -1}, {'stick': {'color': color}})

# Translucent VDW surface drawn around the reference ligand (model 1) only.
view.addSurface(py3Dmol.VDW, {'probeRadius': 1.4, 'opacity': 0.4, 'color': 'yellow'}, {'model': 1})

view.zoomTo()
view.show()

## Analyze intermediate molecules

In [16]:
import os
from utils import reconstruct
from utils import transforms
from utils.covalent_graph import connect_covalent_graph
import pdb
import torch
from rdkit import Chem
import numpy as np
import torch.nn.functional as F

ptb = Chem.GetPeriodicTable()

raw_path = "./result/with_guide/"
# Trajectory frames to snapshot: every 10th frame plus frame 999.
# NOTE(review): hard-codes a 1000-frame trajectory -- confirm against the sampler.
steps = list(range(0, 1000, 10)) + [999]
all_data_steps = {step: {} for step in steps}  # step -> data_idx -> sample idx -> tuple
error_rates = {step: 0 for step in steps}      # NOTE(review): never updated in this cell
total_nums = {step: 0 for step in steps}       # step -> number of samples processed
data_num = 0
for f in os.listdir(raw_path):
    # Same "result_<idx>.pt" filter and index parsing as the reconstruction cell.
    if not f.startswith("result") or not f.endswith(".pt"):
        continue
    data_idx = int(f.split("_")[-1].split(".")[0])
    if data_idx != 0:
        continue
    # Dedicated name so the project-root `path` from the first cell survives;
    # given a path, torch.load opens and closes the file itself.
    result_path = os.path.join(raw_path, f)
    data = torch.load(result_path)
    data_num += 1
    for step in steps:
        pos_type_data = {}
        for i in range(len(data['pred_ligand_v_traj'])):
            pred_v = data['pred_ligand_v_traj'][i][step]
            pred_pos = data['pred_ligand_pos_traj'][i][step]
            # Center the frame on its mean atom position.
            pred_pos = pred_pos - np.mean(pred_pos, axis=0)
            pred_atom_type = transforms.get_atomic_number_from_index(pred_v, mode="add_aromatic")
            pred_atom_type_text = [ptb.GetElementSymbol(atom_type) for atom_type in pred_atom_type]
            # NOTE(review): 15 presumably equals the number of "add_aromatic" atom classes -- confirm.
            one_hot_pred_v = F.one_hot(torch.tensor(pred_v), num_classes=15)
            pred_edges = connect_covalent_graph(torch.tensor(pred_pos), one_hot_pred_v).numpy()

            total_nums[step] += 1
            pos_type_data[i] = (pred_pos, pred_atom_type, pred_atom_type_text, pred_edges)

        all_data_steps[step][data_idx] = pos_type_data
    if data_num > 2:
        break

In [19]:
import plotly.graph_objects as go
import numpy as np
import pandas as pd

NUM_DIMS = 3  # NOTE(review): not used below
NUM_SAMPLE_INDEX = 0
# First generation index stored in the step-0 snapshot.
NUM_INDEX = next(iter(all_data_steps[0]))

fig = go.Figure()
fig.update_layout(autosize=False, width=800, height=800)

# Two traces per diffusion step, in order: bonds (lines), then atoms (markers).
# All start hidden; the slider below shows one pair at a time.
diffusion_steps = list(all_data_steps)
for step in diffusion_steps:
    atom_pos, atomic_num, atom_type, edges = all_data_steps[step][NUM_INDEX][NUM_SAMPLE_INDEX]
    trace_name = "𝜈 = " + str(step)

    # Each bond is a 2-point segment; None breaks the polyline between bonds.
    edge_x, edge_y, edge_z = [], [], []
    for node0, node1 in zip(edges[0, :], edges[1, :]):
        x0, y0, z0 = atom_pos[node0]
        x1, y1, z1 = atom_pos[node1]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        edge_z += [z0, z1, None]

    fig.add_trace(go.Scatter3d(
        visible=False,
        x=edge_x, y=edge_y, z=edge_z,
        line=dict(width=4, color='#888'),
        hoverinfo='none',
        legendgroup=trace_name,
        mode='lines',
        name=trace_name,
    ))

    inter_data = pd.DataFrame({'x': atom_pos[:, 0], 'y': atom_pos[:, 1], 'z': atom_pos[:, 2], 'atomic_num': atomic_num, 'atom_type': atom_type})
    fig.add_trace(go.Scatter3d(
        visible=False,
        name=trace_name,
        legendgroup=trace_name,
        x=inter_data['x'],
        y=inter_data['y'],
        z=inter_data['z'],
        mode='markers+text',
        customdata=inter_data['atom_type'],
        text=inter_data['atom_type'],
        textposition='middle center',
        marker=dict(
            size=8,
            color=inter_data['atomic_num'],  # color atoms by atomic number
            colorscale='Tealrose',
            opacity=1,
        ),
    ))

# Slider entry k reveals traces 2k (bonds) and 2k+1 (atoms). Named
# `slider_steps` so the diffusion-step list `steps` is not overwritten.
slider_steps = []
for k, diffusion_step in enumerate(diffusion_steps):
    visible = [False] * len(fig.data)
    visible[2 * k] = True
    visible[2 * k + 1] = True
    slider_steps.append(dict(
        method="restyle",
        args=["visible", visible],
        label=str(diffusion_step),
    ))

# `active` only marks a slider position; it does not run that entry's restyle
# on first render, so the matching trace pair must be made visible up front.
active_index = min(10, len(slider_steps) - 1)
fig.data[2 * active_index].visible = True
fig.data[2 * active_index + 1].visible = True

sliders = [dict(
    active=active_index,
    currentvalue={"prefix": "Step: "},
    pad={"t": 50},
    steps=slider_steps,
)]

fig.update_layout(
    sliders=sliders,
    scene=dict(
        xaxis=dict(range=(-10, 10), dtick=1),
        yaxis=dict(range=(-10, 10), dtick=1),
        zaxis=dict(range=(-10, 10), dtick=1),
    ),
)

fig.update_scenes(aspectmode='cube')
fig.show()

In [21]:
## visualize intermediate molecules
from rdkit import Chem
import rdkit 
import nglview as nv

index = 0
step = 800  # intermediate generated molecules at a specific step
sample_index = 0
# Tuple stored by the intermediate-analysis cell:
# (centered positions, atomic numbers, element symbols, edge index).
# Named `sample_data` so the `data` loaded from the result file is not overwritten.
sample_data = all_data_steps[step][index][sample_index]
pred_pos = sample_data[0]
pred_atom_text = sample_data[2]
pred_edges = sample_data[3]

# Build the molecule atom by atom; the conformer holds one 3D position per atom.
rdmol = Chem.RWMol()
conformer = Chem.Conformer(len(pred_pos))
for atom_pos, atom_type in zip(pred_pos, pred_atom_text):
    atom_idx = rdmol.AddAtom(Chem.Atom(atom_type))
    conformer.SetAtomPosition(atom_idx, atom_pos)

# The edge index may list a bond in both directions; add each atom pair once.
# Bond orders are not known here, hence UNSPECIFIED.
for edge1, edge2 in zip(pred_edges[0], pred_edges[1]):
    begin, end = int(edge1), int(edge2)
    if rdmol.GetBondBetweenAtoms(begin, end) is not None:
        continue
    rdmol.AddBond(begin, end, order=rdkit.Chem.rdchem.BondType.UNSPECIFIED)

rdmol.AddConformer(conformer)

view = nv.show_rdkit(rdmol)
view

NGLWidget()