In [None]:
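# Visualization of conformers sampled from a pretrained MCF model:
# sample (or reuse saved samples), align each sample to its ground truth, then render both in 3D.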

In [None]:

# sampling conformers from pretrained model
# skip this cell if you have already sampled conformers

import os
import pandas as pd
import numpy as np
from omegaconf import OmegaConf

import torch
import lightning.pytorch as pl

from models.mcf import MCF
from builders.builders import build_dataloader


task_config = OmegaConf.load("configs/vis_qm9.yaml")
ckpt_path = task_config.resume_from_path

# Load the raw checkpoint on CPU so it opens regardless of the device it was saved
# from (consistent with `map_location="cpu"` used for the model just below).
# NOTE(review): torch >= 2.6 defaults to `weights_only=True`; if loading fails on the
# non-tensor config stored under "opt", pass `weights_only=False` here — only for
# checkpoints you trust, since that unpickles arbitrary objects.
checkpoint = torch.load(ckpt_path, map_location="cpu")
task_config_checkpoint = checkpoint["opt"]
task_config_checkpoint.device = "cpu"

mcf = MCF.load_from_checkpoint(ckpt_path, map_location="cpu")
mcf.eval()
print("pretrained MCF model loaded")

# Merge the config stored in the checkpoint with the local task config; with
# OmegaConf.merge the later argument (configs/vis_qm9.yaml) takes precedence.
task_config = OmegaConf.merge(task_config_checkpoint, task_config)
mcf.online_sample = task_config.model_config.params["online_sample"]
mcf.online_evaluation = task_config.model_config.params["online_evaluation"]
mcf.sampling_fn = task_config.model_config.params.sampling_config.sampling_fn
mcf.num_timesteps_ddim = task_config.model_config.params.sampling_config.num_timesteps_ddim

# build data config
data_module = build_dataloader(task_config.data_config)

# Fill in derived model-config fields.
# NOTE(review): `mcf` was already constructed from the checkpoint above, so these
# assignments only change `task_config`, not the loaded model — confirm whether e.g.
# `viz_dir` must also be set on `mcf` for samples to be written to artifacts/viz.
task_config.model_config.params["data_type"] = task_config.data_config.data_type
task_config.model_config.params.architecture_config.params.signal_num_channels = (
    task_config.model_config.params.input_signal_num_channels
)
task_config.model_config.params.architecture_config.params.proj_dim = (
    128  # We need to assign a random value here, this gets updated inside the model
)
task_config.model_config.params.architecture_config.params.coord_num_channels = (
    task_config.model_config.params.input_coord_num_channels
)
task_config.model_config.params["viz_dir"] = os.path.join("artifacts", "viz")

# "auto" uses a GPU when one is available and falls back to CPU otherwise;
# hard-coding "gpu" makes this cell crash on CPU-only machines.
trainer = pl.Trainer(
    accelerator="auto",
    num_nodes=1,
    num_sanity_val_steps=0,
    check_val_every_n_epoch=task_config.eval_freq,
    logger=None,
    precision=task_config.precision,
    max_steps=0,
)

# Single validation pass over the validation split; presumably this is where the
# model writes the sampled conformers that the next cell loads — verify.
trainer.validate(
    mcf,
    dataloaders=[
        data_module.val_dataloader(),
    ],
)

In [None]:
# load sampled conformers
# NOTE: pickle.load can execute arbitrary code — only open sample files you trust
# (e.g. ones produced by the sampling cell above).

import pickle
sample_path = "artifacts/viz/samples_epoch_0.pkl" # change this to path of sampled conformers
with open(sample_path, 'rb') as f:
    conformer_dict = pickle.load(f)
gts_list = conformer_dict["ground_truth"]
samples_list = conformer_dict["model_samples"]
smiles_list = conformer_dict["smiles"]
print("Number of GT:", len(gts_list))
# Report the length of the samples list (previously re-printed the GT count by mistake).
print("Number of samples:", len(samples_list))

In [None]:
# align sampled conformers to ground truth

import torch
import numpy as np
from rdkit import Chem
from rdkit.Chem import rdMolAlign

def align_mols(samples, gt):
    """Pick the sample closest to ``gt`` and return a copy rigidly aligned onto it.

    Every candidate is compared with the ground truth after removing hydrogens,
    using ``rdMolAlign.GetBestAlignmentTransform``. The candidate with the lowest
    RMSD is copied, and the 4x4 transform found for it is applied to every atom
    of the copy (hydrogens included).

    Args:
        samples: non-empty sequence of RDKit molecules, each with a conformer.
        gt: RDKit molecule holding the ground-truth conformer.

    Returns:
        (aligned_sample, best_idx): a new molecule whose default conformer has been
        transformed onto ``gt``, and the index of the chosen sample in ``samples``.
        The molecules in ``samples`` are not modified.

    Raises:
        ValueError: if ``samples`` is empty or no sample yields a finite RMSD.
    """
    if len(samples) == 0:
        raise ValueError("align_mols: `samples` is empty, nothing to align")

    gt_ = Chem.RemoveHs(gt)

    best_rmsd, best_t_mat, best_idx = np.inf, None, 0
    for i, sample in enumerate(samples):
        rmsd, t_mat, _ = rdMolAlign.GetBestAlignmentTransform(Chem.RemoveHs(sample), gt_)
        if rmsd < best_rmsd:
            best_rmsd, best_t_mat, best_idx = rmsd, t_mat, i

    if best_t_mat is None:
        raise ValueError("align_mols: no sample could be aligned to the ground truth")

    # Work on a copy: the caller reuses the same `samples` list for several ground
    # truths, so aligning in place would overwrite previously returned results.
    aligned = Chem.Mol(samples[best_idx])
    conf = aligned.GetConformer()

    # Apply the homogeneous transform ([x, y, z, 1] @ T^T) in float64 to avoid
    # losing coordinate precision.
    t_mat = np.asarray(best_t_mat, dtype=np.float64)
    pos = np.asarray(conf.GetPositions(), dtype=np.float64)
    pos_h = np.hstack([pos, np.ones((len(pos), 1))])
    new_pos = (pos_h @ t_mat.T)[:, :3]
    for k, xyz in enumerate(new_pos):
        conf.SetAtomPosition(k, xyz.tolist())

    return aligned, best_idx

aligned_gts_list = []
aligned_samples_list = []

# For each molecule, pair every ground-truth conformer with the sampled conformer
# that align_mols picks for it (already transformed onto that ground truth).
for mol_idx, gts in enumerate(gts_list):
    pair_gts, pair_samples = [], []
    for gt_mol in gts:
        best_sample, _ = align_mols(samples_list[mol_idx], gt_mol)
        pair_gts.append(gt_mol)
        pair_samples.append(best_sample)
    aligned_gts_list.append(pair_gts)
    aligned_samples_list.append(pair_samples)

In [None]:
# helper function to visualize conformers

import py3Dmol
from ipywidgets import interact, fixed, IntSlider
import ipywidgets

def show_mol(mol, view, grid):
    """Replace the contents of one cell of a py3Dmol viewer grid with ``mol``.

    The molecule is serialised to a mol block, loaded as an 'sdf' model into the
    ``grid`` cell, drawn in stick style, and the camera is zoomed to fit it.
    Returns the same ``view`` so calls can be chained.
    """
    molblock = Chem.MolToMolBlock(mol)
    target = dict(viewer=grid)
    view.removeAllModels(**target)
    view.addModel(molblock, 'sdf', **target)
    view.setStyle({'model': 0}, {'stick': {}}, **target)
    view.zoomTo(**target)
    return view

def view_single(mol, width=600, height=600):
    """Create a standalone 1x1 py3Dmol viewer and draw ``mol`` in its only cell."""
    single_cell = (0, 0)
    viewer = py3Dmol.view(width=width, height=height, linked=False, viewergrid=(1, 1))
    show_mol(mol, viewer, grid=single_cell)
    return viewer

def MolTo3DView(mol, size=(400, 300), style="stick", surface=False, opacity=0.5, confId=0):
    """Draw molecule in 3D

    Args:
    ----
        mol: either a sequence of RDKit molecules, in which case the entry at index
             ``confId`` is drawn (this is how the cells below call it, e.g. ``[mol]``),
             or a single RDKit ``Mol``, in which case its conformer ``confId`` is drawn.
        size: tuple(int, int), canvas (width, height)
        style: str, 3Dmol.js drawing style: 'line', 'stick', 'sphere' or 'cartoon'
               ('carton' is still accepted as an alias for 'cartoon')
        surface: bool, also display the solvent-accessible surface (SAS)
        opacity: float, opacity of surface, range 0.0-1.0
        confId: int, index into ``mol`` (sequence) or conformer id (single Mol)
    Return:
    ----
        viewer: py3Dmol.view, a class for constructing embedded 3Dmol.js views in ipython notebooks.
    Raises:
    ----
        ValueError: if ``style`` is not one of the supported styles.
    """
    if style == 'carton':  # legacy misspelling kept for backward compatibility
        style = 'cartoon'
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if style not in ('line', 'stick', 'sphere', 'cartoon'):
        raise ValueError(
            f"unsupported style {style!r}; expected 'line', 'stick', 'sphere' or 'cartoon'"
        )

    if isinstance(mol, Chem.Mol):
        mblock = Chem.MolToMolBlock(mol, confId=confId)
    else:
        mblock = Chem.MolToMolBlock(mol[confId])

    viewer = py3Dmol.view(width=size[0], height=size[1])
    viewer.addModel(mblock, 'sdf')  # same format string as show_mol
    viewer.setStyle({style: {}})
    if surface:
        viewer.addSurface(py3Dmol.SAS, {'opacity': opacity})
    viewer.zoomTo()
    return viewer

def conf_viewer(idx, mol, **kwargs):
    """Build a MolTo3DView for entry ``idx`` of ``mol`` and display it right away.

    Extra keyword arguments are forwarded to MolTo3DView; the return value is
    whatever the viewer's ``show()`` returns. Presumably meant as an ipywidgets
    ``interact`` callback (``idx`` driven by a slider) — confirm at the call site.
    """
    viewer = MolTo3DView(mol, confId=idx, **kwargs)
    return viewer.show()

In [None]:
# Show the first ground-truth conformer, then the sampled conformer aligned to it.
# Uncomment `viewer.png()` inside the loop to export each view as a PNG.
for mols_to_show in (aligned_gts_list, aligned_samples_list):
    viewer = MolTo3DView([mols_to_show[0][0]], size=(400, 300), style='stick')
    viewer.show()
    # viewer.png() # output to png