In [2]:
import sys
sys.path.append("../..")
import numpy as np
np.set_printoptions(suppress=True)
import torch
torch.set_printoptions(sci_mode=False)
import matplotlib.pyplot as plt
import pandas as pd
import py3Dmol

from src.datamodule import ConformerDataModule
import tqdm

# Dataset to visualise; the notebook has been used with "qm9" and "geom".
DATASET_NAME = "qm9"

dm = ConformerDataModule(DATASET_NAME, batch_size_train=512, num_workers=0)

# Only the training split is used below.
dset = dm.datasets["train"]

# Display one example item as a sanity check of the loaded data.
dset[0]

Molecule(coords=[19, 3], atoms=[19, 1], id=120661142, sizes=[1], moments=[1, 3], num_nodes=19)

In [3]:
def make_html(view, fname="input.html"):
    """Write a standalone HTML page for a py3Dmol view that also snapshots it as a PNG.

    The page loads 3Dmol.js (and its UI add-on) from the 3Dmol.org CDN, embeds the
    viewer markup returned by ``view._make_html()``, and injects JavaScript right
    after the viewer's ``render()`` call. That JavaScript copies ``pngURI()`` into an
    ``<img id="img_A">`` element placed at the top of the page.

    Parameters
    ----------
    view : py3Dmol.view
        Viewer to export. Only ``view.uniqueid`` and the private
        ``view._make_html()`` are used.
    fname : str or path-like, default "input.html"
        Output path; an existing file is overwritten.

    Returns
    -------
    The ``fname`` that was written.

    Warns
    -----
    UserWarning
        If the viewer markup does not contain ``viewer_<uniqueid>.render();``. In that
        case the PNG hook cannot be injected and the ``<img>`` stays empty.
    """
    import warnings

    header = (
        '<img id="img_A">'
        '<script src="https://3Dmol.org/build/3Dmol-min.js"></script>'
        '<script src="https://3Dmol.org/build/3Dmol.ui-min.js"></script>'
    )
    body = view._make_html()

    # Anchor the snapshot code after render() so the PNG captures the drawn scene.
    render_call = 'viewer_{0}.render();'.format(view.uniqueid)
    if render_call not in body:
        # _make_html is private py3Dmol API; don't silently produce a page without the PNG.
        warnings.warn(
            "make_html: could not find {0!r} in the viewer HTML; "
            "the PNG snapshot will not be generated.".format(render_call)
        )
    snapshot = render_call + '\nvar png = viewer_{0}.pngURI();\ndocument.getElementById("img_A").src = png;'.format(view.uniqueid)
    html = header + body.replace(render_call, snapshot)

    # Explicit UTF-8 so non-ASCII markup does not depend on the platform's default encoding.
    with open(fname, "w", encoding="utf-8") as f:
        f.write(html)
    return fname

In [48]:
# Molecule shown in every figure below: training-set item 20729, cloned before its
# coordinates are reassigned, then reflected x -> -x (the [-1, 1, 1] factor
# scales only the first coordinate column).
M = dset[20729].clone()
M.coords = M.coords * torch.tensor([-1.0, 1.0, 1.0])

# Axis-arrow geometry read by finish(): arrow length, shaft radius, and the value
# forwarded as each arrow spec's 'mid' entry.
arrow_len, radius, mid = 5.0, 0.1, 0.95
def add(view, M, style="mol", opacity=None):
    """Add molecule ``M`` to a py3Dmol viewer and style it.

    Parameters
    ----------
    view : py3Dmol.view or None
        Viewer to add the model to. If ``None``, a new 1500x1500 viewer is created.
        (Previously the passed viewer was always discarded and replaced.)
    M : object
        Must expose ``xyzfile``, which is loaded with the "xyz" format parser.
    style : {"mol", "cloud"} or dict, default "mol"
        "mol": sticks (radius 0.2) plus spheres (scale 0.2).
        "cloud": spheres only (scale 0.2), optionally with ``opacity``.
        Any other value is passed to ``setStyle`` unchanged.
    opacity : float, optional
        Sphere opacity; only applied when ``style == "cloud"``.

    Returns
    -------
    The viewer the model was added to.
    """
    if view is None:
        view = py3Dmol.view(width=1500, height=1500)
    view.addModel(M.xyzfile, "xyz")

    if style == "mol":
        style = {"stick": {"radius": 0.2}, "sphere": {"scale": 0.2}}
    elif style == "cloud":
        style = {"sphere": {"scale": 0.2}}
        if opacity is not None:
            style["sphere"]["opacity"] = opacity

    # {"model": -1} selects the most recently added model, so models already in the
    # viewer keep their own styles.
    view.setStyle({"model": -1}, style)
    return view

def finish(view, xshift=0, yshift=-50, length=None, shaft_radius=None, head_mid=None):
    """Draw x/y/z axis arrows from the origin and apply the fixed camera setup.

    The arrows are drawn x = red, y = green, z = blue. The camera is then set with
    ``rotate(30, {'x': 1, 'y': -1, 'z': -.2})``, ``zoom(0.45)`` and
    ``translate(xshift, yshift)``.

    Parameters
    ----------
    view : py3Dmol.view
        Viewer to decorate; it is modified in place and returned.
    xshift, yshift : float
        Passed to ``view.translate``. The ``yshift`` default of -50 matches the
        previously hard-coded value.
    length, shaft_radius, head_mid : float, optional
        Arrow length, shaft radius and the arrow spec's 'mid' value. When ``None``,
        they fall back to the notebook globals ``arrow_len``, ``radius`` and ``mid``.

    Returns
    -------
    The same ``view``.
    """
    if length is None:
        length = arrow_len
    if shaft_radius is None:
        shaft_radius = radius
    if head_mid is None:
        head_mid = mid

    for axis, color in (('x', 'red'), ('y', 'green'), ('z', 'blue')):
        end = {'x': 0.0, 'y': 0.0, 'z': 0.0}
        end[axis] = length
        view.addArrow({
            'start': {'x': 0.0, 'y': 0.0, 'z': 0.0},
            'end': end,
            'radius': shaft_radius,
            'color': color,
            'mid': head_mid,
        })

    view.rotate(30, {'x': 1, 'y': -1, 'z': -.2})
    view.zoom(0.45)
    view.translate(xshift, yshift)
    return view

In [49]:
from src.models.flow import sample_ONB_0

# Draw N_SAMPLES independent ONB_0 samples for the same molecule M and export each
# resulting coordinate set as a sphere "cloud" to cloud{i}.html.
N_SAMPLES = 10

# NOTE(review): the RNG used inside sample_ONB_0 is not visible here. Seeding both
# numpy's and torch's global generators makes re-runs reproducible if it draws
# from either of them — confirm.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Loop-invariant inputs to the sampler, computed once instead of every iteration.
ONB_1_np = M.ONB_4col.numpy()
ONB_1_last_col = ONB_1_np[:, -1].reshape(-1, 1)
# Tensor form kept under the original name for any later cell that uses it;
# the loop itself does not need it.
ONB_1 = torch.tensor(ONB_1_np)

for i in tqdm.tqdm(range(N_SAMPLES), desc="clouds"):
    ONB_0 = torch.tensor(sample_ONB_0(ONB_1_np.shape, ONB_1_last_col))

    M_0 = M.clone()
    M_0.coords = M.from_ONB_4col(ONB_0)

    view = py3Dmol.view(width=1500, height=1500)
    view = add(view, M_0, style="cloud")
    view = finish(view, xshift=0)
    make_html(view, f"cloud{i}.html")