In [None]:
import matplotlib.pyplot as plt
import numpy as np
import sys
sys.path.append("../..")
from src.datamodule import ConformerDataModule
import torch
torch.set_printoptions(sci_mode=False)

from tqdm import trange
import tqdm

from src.experimental.lit import LitGenConfig, LitGen

# Load the baseline checkpoint and keep its EMA model as `gen`.
ckpt_path = "../../ckpt/qm9/stiefelFM.ckpt"
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
lit = LitGen.load_from_checkpoint(ckpt_path, map_location=device)
# The model used for sampling is the one stored under `lit.ema.ema_model`.
gen = lit.ema.ema_model
# Cast to float64; the sampling inputs prepared later in this notebook are
# cast to float64 as well.
gen = gen.to(torch.float64)
# The trailing `;` suppresses the cell's output repr.
gen.eval();

Using device: cuda


In [None]:
# Load the OT checkpoint and keep its EMA model as `gen_OT`.
# NOTE: this cell rebinds the notebook-level `ckpt_path`, `device` and `lit`,
# so after it runs `lit` refers to the OT checkpoint, not the baseline one.
ckpt_path = "../../ckpt/qm9/stiefelFM_OT.ckpt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
lit = LitGen.load_from_checkpoint(ckpt_path, map_location=device)
gen_OT = lit.ema.ema_model
# Same float64 cast as applied to the baseline model.
gen_OT = gen_OT.to(torch.float64)
# The trailing `;` suppresses the cell's output repr.
gen_OT.eval();

Using device: cuda


In [12]:
def show_traj(traj, idx=0, viewer=None, view=None):
    """Add one molecule's trajectory to a py3Dmol view as an animation.

    NOTE: a later cell in this notebook defines another ``show_traj`` with
    different animation options; whichever cell ran last is the one in effect.

    Args:
        traj: Sequence of batched snapshots. ``traj[t][idx]`` must expose
            ``.xyzfile``, ``.coords`` and ``.check_zero_com(coords)``.
        idx: Index of the molecule to pick out of every snapshot.
        viewer: Forwarded as ``viewer=`` to ``addModelsAsFrames`` and
            ``setStyle`` (e.g. a ``(row, col)`` cell of a viewer grid).
            ``animate`` and ``zoomTo`` are called without it.
        view: Existing view to draw into. When ``None`` a new 400x400
            ``py3Dmol.view`` is created.

    Returns:
        The view that was drawn into.

    Raises:
        AssertionError: If ``check_zero_com`` returns a falsy value for any
            frame.
    """
    # The final snapshot is placed in front of the real trajectory.
    # NOTE(review): presumably so the viewer derives its initial geometry
    # from the end state -- confirm.
    pieces = [traj[-1][idx].xyzfile]
    for snapshot in traj:
        mol = snapshot[idx]
        pieces.append(mol.xyzfile)
        assert (mol.check_zero_com(mol.coords))

    # Repeat the last frame ten more times so the animation lingers on it.
    pieces.extend([mol.xyzfile] * 10)
    trajfile = "".join(pieces)

    if view is None:
        import py3Dmol
        view = py3Dmol.view(width=400, height=400)

    view.addModelsAsFrames(trajfile, "xyz", viewer=viewer)
    stick_and_sphere = {"stick": {"radius": .2}, "sphere": {"scale": .2}}
    view.setStyle({"model": -1}, stick_and_sphere, viewer=viewer)

    # Forward-looping animation, 100 ms between frames, no repetition limit.
    animation = {'loop': "forward", 'interval': 100}
    view.animate(animation)
    view.zoomTo()
    return view

In [4]:
# Build the "qm9" data module and take its "test" split for the examples below.
dm = ConformerDataModule("qm9", batch_size_train=10, num_workers=0)
dset = dm.datasets["test"]

In [5]:
# Display the first entry of the test split.
dset[0].show()

<py3Dmol.view at 0x7f61e0106d00>

In [17]:
# Build a batch of K copies of the first test molecule and give each copy its
# own starting coordinates; both models are later sampled from this same start.
import torch_geometric as pyg
K = 10
M = dset[0]
Ms = [M for _ in range(K)]

M_batched = pyg.data.Batch.from_data_list(Ms)
# Per-atom graph index, built from the batch's `sizes`.
# NOTE(review): assumes `M_batched.sizes` holds the atom count of each copy -- confirm.
M_batched.batch = torch.repeat_interleave(torch.arange(len(Ms)), M_batched.sizes)

from src.models.flow import sample_ONB_0
coords = []
for i in range(M_batched.batch_size):
    mol = M_batched[i]
    col4 = mol.masses_normalized.sqrt().cpu().numpy()
    # NOTE(review): `sample_ONB_0` presumably draws a random initial frame with
    # `col4` as its fixed fourth column; no seed is set in this notebook, so
    # results would differ between runs -- confirm.
    ONB = sample_ONB_0(mol.ONB_4col.shape, col4=col4)
    # `.to(mol.coords)` matches the dtype and device of the molecule's coordinates.
    ONB = torch.from_numpy(ONB).to(mol.coords)
    coords.append(mol.from_ONB_4col(ONB))
M_batched.coords = torch.cat(coords, dim=0)

# Keep the starting coordinates so both models can be sampled from the same start.
start = M_batched.coords

M_batched = M_batched.to(device)
start = start.to(device)

# Match the float64 precision the models were cast to.
M_batched.coords = M_batched.coords.to(torch.float64)
start = start.to(torch.float64)

In [18]:
# Sample with the baseline model from the shared start. Clones are passed,
# presumably so sampling cannot modify `M_batched` / `start` -- confirm.
out = gen.sample(M_batched.clone(), pbar=True, return_trajectory=True, start=start.clone())

                                                           

In [19]:
# Sample with the OT model from the same batch and the same starting coordinates.
out_OT = gen_OT.sample(M_batched.clone(), pbar=True, return_trajectory=True, start=start.clone())

                                                           

In [None]:
# Animate the same sample (index `idx`) from both runs for comparison.
# `out[1]` / `out_OT[1]` are passed as the trajectory argument of `show_traj`.
idx = 1
show_traj(out[1], idx=idx).show()
show_traj(out_OT[1], idx=idx).show()

In [68]:
def show(M, view=None, viewer=None, color=None, opacity=None):
    """Add a single (unbatched) molecule to a py3Dmol view, drawn as spheres.

    Args:
        M: Molecule object exposing ``.batched`` (must be falsy) and
            ``.xyzfile``.
        view: Existing view to draw into. When ``None`` a new 400x400
            ``py3Dmol.view`` is created.
        viewer: Forwarded as ``viewer=`` to ``addModel`` and ``setStyle``
            (e.g. a ``(row, col)`` cell of a viewer grid).
        color: Optional sphere color. Left unset when ``None``.
        opacity: Sphere opacity. ``None`` means fully opaque (1.0).

    Returns:
        The view that was drawn into. No zoom is applied here.

    Raises:
        AssertionError: If ``M.batched`` is truthy.
    """
    assert not M.batched

    if view is None:
        import py3Dmol

        view = py3Dmol.view(width=400, height=400)

    view.addModel(M.xyzfile, "xyz", viewer=viewer)

    # Only a sphere style is applied; no stick style is set.
    sphere_style = {"scale": .2, 'opacity': 1.0 if opacity is None else opacity}
    if color is not None:
        sphere_style["color"] = color

    # Selector {"model": -1} targets the most recently added model.
    view.setStyle({"model": -1}, {"sphere": sphere_style}, viewer=viewer)
    return view

import py3Dmol

# Static overlay figure: a grid of viewers with one column per sample.
# Row 0 draws the trajectories taken from `out[1]` (baseline run) and
# row 1 those taken from `out_OT[1]` (OT run).
res = 200
grid = (2, 5)
view = py3Dmol.view(width=(res * grid[1]), height=(res * grid[0]), viewergrid=grid)

# One sample per grid column, so the sample count follows the grid width.
for idx in range(grid[1]):
    col = idx
    M_traj_base = [Mb[idx] for Mb in out[1]]
    M_traj_OT = [Mb[idx] for Mb in out_OT[1]]

    for row, M_traj in enumerate((M_traj_base, M_traj_OT)):
        # Opacity ramps from 0.3 on the first frame to 0.8 on the last one.
        # Sizing the ramp by the trajectory length keeps it aligned for any
        # number of sampling steps, where a fixed frame count would silently
        # truncate or mis-scale it.
        opacities = np.linspace(0.3, 0.8, len(M_traj))
        # The loop variable is named `frame` so that the notebook-level `M`
        # (the source molecule) is not rebound by this cell.
        for c, (frame, opacity) in enumerate(zip(M_traj, opacities)):
            # Draw every fifth frame only, to keep the overlay readable.
            if c % 5 == 0:
                view = show(frame, view=view, opacity=opacity, viewer=(row, col))
        # Draw the final frame on top using the molecule's own `show` method.
        view = M_traj[-1].show(view=view, viewer=(row, col))

view.zoom(0.9)
view

<py3Dmol.view at 0x7f617c3a7ac0>

In [69]:
def make_html(view, fname="input.html"):
    """Write the view to an HTML file that also captures a PNG of the render.

    The page starts with an ``<img id="img_A">`` element and the 3Dmol script
    tags, followed by ``view._make_html()``. Every ``viewer_<uniqueid>.render();``
    call in the page is extended so that the viewer's ``pngURI()`` is assigned
    to that image element after rendering.

    NOTE(review): if the generated HTML contains no such render call (possibly
    the case for grid views -- confirm), the replacement is a no-op and the
    image stays empty.

    NOTE: a later cell in this notebook defines another ``make_html`` without
    the PNG capture; whichever cell ran last is the one in effect.

    Args:
        view: py3Dmol view exposing ``_make_html()`` and ``uniqueid``.
        fname: Path of the HTML file to write (overwritten if it exists).
    """
    header = '<img id="img_A"><script src="https://3Dmol.org/build/3Dmol-min.js"></script><script src="https://3Dmol.org/build/3Dmol.ui-min.js"></script>'
    page = header + view._make_html()

    uid = view.uniqueid
    render_call = 'viewer_{0}.render();'.format(uid)
    render_then_capture = 'viewer_{0}.render();\nvar png = viewer_{0}.pngURI();\ndocument.getElementById("img_A").src = png;'.format(uid)
    page = page.replace(render_call, render_then_capture)

    with open(fname, "w") as handle:
        handle.write(page)

# Export the current `view` to an HTML file.
# NOTE(review): a later cell in this notebook writes to the same file name,
# so whichever of the two runs last determines the file's content.
make_html(view, "qm9paths.html")

In [8]:

def show_traj(traj, idx=0, viewer=None, view=None):
    """Add one molecule's trajectory to a py3Dmol view as a one-shot animation.

    NOTE: this redefines the ``show_traj`` from an earlier cell of this
    notebook. The two differ only in the animation options (here: a single
    repetition with a 10 ms frame interval).

    Args:
        traj: Sequence of batched snapshots. ``traj[t][idx]`` must expose
            ``.xyzfile``, ``.coords`` and ``.check_zero_com(coords)``.
        idx: Index of the molecule to pick out of every snapshot.
        viewer: Forwarded as ``viewer=`` to ``addModelsAsFrames`` and
            ``setStyle`` (e.g. a ``(row, col)`` cell of a viewer grid).
            ``animate`` and ``zoomTo`` are called without it.
        view: Existing view to draw into. When ``None`` a new 400x400
            ``py3Dmol.view`` is created.

    Returns:
        The view that was drawn into.

    Raises:
        AssertionError: If ``check_zero_com`` returns a falsy value for any
            frame.
    """
    # The final snapshot is placed in front of the real trajectory.
    # NOTE(review): presumably so the viewer derives its initial geometry
    # from the end state -- confirm.
    pieces = [traj[-1][idx].xyzfile]
    for snapshot in traj:
        mol = snapshot[idx]
        pieces.append(mol.xyzfile)
        assert (mol.check_zero_com(mol.coords))

    # Repeat the last frame ten more times so the animation ends on a pause.
    pieces.extend([mol.xyzfile] * 10)
    trajfile = "".join(pieces)

    if view is None:
        import py3Dmol
        view = py3Dmol.view(width=400, height=400)

    view.addModelsAsFrames(trajfile, "xyz", viewer=viewer)
    stick_and_sphere = {"stick": {"radius": .2}, "sphere": {"scale": .2}}
    view.setStyle({"model": -1}, stick_and_sphere, viewer=viewer)

    # Play forward once ("reps": 1) with 10 ms between frames, then stop.
    animation = {'loop': "forward", "reps": 1, 'interval': 10}
    view.animate(animation)
    view.zoomTo()
    return view


# Comparison grid: one row per test molecule, the reference structure in
# column 0 and one animated trajectory per checkpoint in the remaining columns.
# NOTE(review): `ckpts` is not defined anywhere in the visible notebook, so this
# cell fails on a fresh kernel unless it is defined elsewhere -- confirm.
grid = (3, len(ckpts) + 1)

import py3Dmol
res = 100
view = py3Dmol.view(width=(res * grid[1]), height=(res * grid[0]), viewergrid=grid)

dataset = "qm9"

for row in range(3):
    # Column 0: the dataset entry itself.
    view = dset[row].show(view=view, viewer=(row, 0))

    for col_, ckpt in enumerate(ckpts):
        col = col_ + 1
        # NOTE(review): this rebinds the notebook-level `out` produced by the
        # earlier sampling cell. `torch.load` unpickles the file, so only load
        # files from a trusted source.
        out = torch.load(f"../figure_data/{dataset}_{ckpt}.pt")

        # Index of the k-th sample for this row.
        # NOTE(review): assumes the saved batch holds 10 samples per molecule -- confirm.
        k = 0
        k_idx = row * 10 + k

        view = show_traj(out[1], idx=k_idx, view=view, viewer=(row, col))

def make_html(view, fname="input.html"):
    """Write the HTML produced by ``view._make_html()`` to ``fname``.

    NOTE: this redefines the ``make_html`` from an earlier cell of this
    notebook. That earlier variant additionally prepends an ``<img>`` element
    plus the 3Dmol script tags and captures a PNG of the render; this one
    writes the view's HTML unchanged.

    Args:
        view: py3Dmol view exposing ``_make_html()``.
        fname: Path of the HTML file to write (overwritten if it exists).
    """
    # Build the markup before opening the file, so a failure in
    # `_make_html()` does not leave a truncated file behind.
    markup = view._make_html()

    with open(fname, "w") as handle:
        handle.write(markup)

# Export the comparison grid to an HTML file.
# NOTE(review): an earlier cell in this notebook writes to the same file name,
# so this call overwrites that export.
make_html(view, "qm9paths.html")