In [None]:
import py3Dmol
from rdkit import Chem
from copy import deepcopy
import pickle


def visualize_mol(mol, size=(300, 300), surface=False, opacity=0.5, style=None):
    """Draw a molecule in 3D with py3Dmol.

    Args:
    ----
        mol: rdMol, molecule to show; it is passed to the viewer as a MOL block
             produced by ``Chem.MolToMolBlock``.
        size: tuple(int, int), canvas (width, height).
        surface: bool, if True overlay a solvent-accessible surface (SAS).
        opacity: float, opacity of the surface, range 0.0-1.0.
        style: str or None, drawing style. None (default) draws sticks plus
               small spheres (radius 0.35); otherwise one of
               'line', 'stick', 'sphere', 'cartoon'.
    Return:
    ----
        viewer: py3Dmol.view, a class for constructing embedded 3Dmol.js views
                in IPython notebooks.
    Raises:
    ----
        ValueError: if ``style`` is neither None nor a supported style name.
    """
    supported_styles = ('line', 'stick', 'sphere', 'cartoon')
    # Validate before touching the viewer so a bad call has no side effects.
    if style is None:
        style_spec = {'stick': {}, 'sphere': {'radius': 0.35}}
    elif style in supported_styles:
        style_spec = {style: {}}
    else:
        raise ValueError(
            f"style must be None or one of {supported_styles}, got {style!r}")

    mblock = Chem.MolToMolBlock(mol)
    viewer = py3Dmol.view(width=size[0], height=size[1])
    viewer.addModel(mblock, 'mol')
    viewer.setStyle(style_spec)
    if surface:
        viewer.addSurface(py3Dmol.SAS, {'opacity': opacity})
    viewer.zoomTo()
    return viewer

def set_rdmol_positions(rdkit_mol, pos):
    """Return a deep copy of ``rdkit_mol`` whose conformer 0 is moved to ``pos``.

    The input molecule itself is never modified.

    Args:
        rdkit_mol:  An `rdkit.Chem.rdchem.Mol` object used as the template.
        pos: (N_atoms, 3) coordinates, one row per atom.
    """
    # set_rdmol_positions_ edits its argument in place and hands it back.
    return set_rdmol_positions_(deepcopy(rdkit_mol), pos)


def set_rdmol_positions_(mol, pos):
    """Overwrite, in place, the coordinates of conformer 0 of ``mol``.

    Args:
        mol:  An `rdkit.Chem.rdchem.Mol` object with a conformer of id 0;
              it is modified in place.
        pos: (N_atoms, 3) array-like (e.g. numpy array or torch tensor); row
             ``i`` becomes the position of atom ``i``. Rows must support
             ``.tolist()``.

    Returns:
        The same ``mol`` object, to allow chaining.

    Raises:
        ValueError: if the number of rows in ``pos`` differs from the number of
            atoms in ``mol`` (a shorter ``pos`` would otherwise leave trailing
            atoms at stale coordinates without any error).
    """
    n_atoms = mol.GetNumAtoms()
    if pos.shape[0] != n_atoms:
        raise ValueError(
            f"pos has {pos.shape[0]} rows but mol has {n_atoms} atoms")
    conformer = mol.GetConformer(0)  # fetch once instead of once per atom
    for atom_idx in range(n_atoms):
        conformer.SetAtomPosition(atom_idx, pos[atom_idx].tolist())
    return mol

In [11]:

import pickle
import numpy as np
from tqdm.auto import tqdm
from rdkit.Chem import AllChem
import os



# Path to the pickled generated samples, or to a trajectory saved by
# `test.py --save_traj`.
generated = "checkpoints/qm9_500steps/samples/samples_all.pkl"

# Number of leading generated conformers `generate_conformer` looks at.
num_confs = 1




def generate_conformer(rdmol, confs, n_confs=None):
    """Return a copy of ``rdmol`` posed with one conformer taken from ``confs``.

    The first ``n_confs`` conformers (in stored order) are considered, and the
    last of those is the one applied; with the default ``num_confs = 1`` this
    is simply ``confs[0]``.

    Args:
        rdmol: rdkit.Chem.Mol template; it is not modified.
        confs: array-like indexable along axis 0, one (N_atoms, 3) coordinate
            block per conformer (e.g. numpy array or torch tensor).
        n_confs: int or None, number of leading conformers to consider.
            None (default) uses the module-level ``num_confs`` setting.

    Returns:
        rdkit.Chem.Mol: copy of ``rdmol`` whose conformer 0 holds
        ``confs[min(n_confs, len(confs)) - 1]``.

    Raises:
        ValueError: if no conformer is selected (empty ``confs`` or
            ``n_confs < 1``).
    """
    if n_confs is None:
        n_confs = num_confs
    n_selected = min(int(n_confs), int(confs.shape[0]))
    if n_selected < 1:
        raise ValueError(
            f"no conformer selected: {confs.shape[0]} available, n_confs={n_confs}")
    # Only the last selected conformer is returned, so build just that copy.
    return set_rdmol_positions(rdmol, confs[n_selected - 1])


viewers = []           # py3Dmol views of generated conformers (non-trajectory samples)
original_viewers = []  # matching views of the reference (data.pos) conformers
confs_all = []         # one RDKit molecule per generated conformer / trajectory step
gens_prop = []         # NOTE(review): never filled in this cell; kept in case later cells use it


def _traj_pdb_path(smiles):
    """Return the trajectory PDB path, placed next to the samples pickle.

    SMILES may contain '/' or '\\' (bond stereo), which would otherwise be
    interpreted as directory separators, so both are replaced by '_'.
    """
    safe_name = str(smiles).replace('/', '_').replace('\\', '_')
    return os.path.join(os.path.dirname(generated), f'{safe_name}_traj.pdb')


def _trajectory_mol(template, frames):
    """Return a copy of ``template`` holding one conformer per trajectory frame.

    ``frames[step][atom]`` is an xyz row supporting ``.tolist()``. Conformer
    ids are assigned 0..len(frames)-1 in step order; ``template`` is untouched.
    """
    traj = Chem.Mol(template)
    traj.RemoveAllConformers()
    n_atoms = traj.GetNumAtoms()
    for frame in frames:
        conformer = Chem.Conformer(n_atoms)
        for atom_idx in range(frame.shape[0]):
            conformer.SetAtomPosition(atom_idx, frame[atom_idx].tolist())
        traj.AddConformer(conformer, assignId=True)
    return traj


# NOTE: pickle.load can execute arbitrary code; only load sample files you trust.
with open(generated, 'rb') as f:
    gens = pickle.load(f)

for data in tqdm(gens):
    data.num_nodes = data.rdmol.GetNumAtoms()
    rdmol = data.rdmol
    # Reference geometry on a copy, so data.rdmol keeps its own conformer.
    mol_orginal = set_rdmol_positions(rdmol, data.pos)

    if len(data.pos_gen.shape) == 3:  # trajectory: (n_steps, n_atoms, 3)
        n_steps = data.pos_gen.shape[0]
        data.pos_prop = data.pos_gen.reshape(n_steps, -1, data.num_nodes, 3)
        # First sample of every step, in step order.
        frames = [step_confs[0] for step_confs in data.pos_prop]
        # Build conformers directly from the frames on a copy of the molecule:
        # no random embedding, and data.rdmol is not mutated.
        mol = _trajectory_mol(rdmol, frames)
        # One independent snapshot per step, not n_steps references to one object.
        confs_all.extend(set_rdmol_positions(rdmol, frame) for frame in frames)
        traj_path = _traj_pdb_path(data.smiles)
        writer = Chem.PDBWriter(traj_path)
        try:
            writer.write(mol)
        finally:
            writer.close()
        print("save to", traj_path)
    else:
        data.pos_prop = data.pos_gen.reshape(-1, data.num_nodes, 3)
        confs = data.pos_prop
        mol = generate_conformer(rdmol, confs)
        confs_all.append(mol)
        viewers.append(visualize_mol(mol))
        original_viewers.append(visualize_mol(mol_orginal))

print(generated)
# Side-by-side display of the first five reference / generated pairs.
for num, (original_view, generated_view) in enumerate(zip(original_viewers[:5], viewers[:5])):
    print(f"Ground true {num}: ")
    original_view.show()
    print(f"generated {num}:")
    generated_view.show()


  0%|          | 0/200 [00:00<?, ?it/s]

checkpoints/qm9_500steps/samples/samples_all.pkl
Ground true 0: 


generated 0:


Ground true 1: 


generated 1:


Ground true 2: 


generated 2:


Ground true 3: 


generated 3:


Ground true 4: 


generated 4:
