In [None]:
import sys

import numpy as np
import torch
import matplotlib.pyplot as plt
import pandas as pd
import py3Dmol
import tqdm

# Make `src.*` importable from this notebook's location.
sys.path.append("../..")
from src.datamodule import ConformerDataModule

# Plain (non-scientific) float formatting for printed arrays and tensors.
np.set_printoptions(suppress=True)
torch.set_printoptions(sci_mode=False)

dm = ConformerDataModule(
    "qm9",
    batch_size_train=512,
    num_workers=0,
)
# Alternative dataset:
# dm = ConformerDataModule("geom", batch_size_train=512, num_workers=0)

dset = dm.datasets["train"]

dset[0]

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


Molecule(coords=[19, 3], atoms=[19, 1], id=120661142, sizes=[1], moments=[1, 3], num_nodes=19)

In [2]:
def make_html(view, fname="input.html"):
    """Write ``view`` to a standalone HTML page that also shows a PNG snapshot.

    The page pulls 3Dmol.js and its UI add-on from 3Dmol.org and embeds the
    viewer markup from ``view._make_html()``. That is a private py3Dmol method,
    so this may break across py3Dmol versions. Every
    ``viewer_<uniqueid>.render();`` call is patched so that, right after
    rendering, the canvas is exported via ``pngURI()`` into the
    ``<img id="img_A">`` placed at the top of the page.

    Args:
        view: py3Dmol view to export.
        fname: Output path; an existing file is overwritten.

    Warns:
        UserWarning: if no ``viewer_<uniqueid>.render();`` call is found in the
            generated markup, so the snapshot hook could not be installed.
    """
    import warnings

    uid = view.uniqueid
    render_call = f"viewer_{uid}.render();"
    snapshot_js = (
        render_call
        + f"\nvar png = viewer_{uid}.pngURI();"
        + '\ndocument.getElementById("img_A").src = png;'
    )

    viewer_html = view._make_html()
    if render_call not in viewer_html:
        # The page still shows the interactive viewer, just without the PNG.
        warnings.warn(
            f"make_html: {render_call!r} not found in py3Dmol output; "
            f"{fname} will not contain a PNG snapshot."
        )

    page = (
        '<img id="img_A">'
        '<script src="https://3Dmol.org/build/3Dmol-min.js"></script>'
        '<script src="https://3Dmol.org/build/3Dmol.ui-min.js"></script>'
        + viewer_html.replace(render_call, snapshot_js)
    )

    # Explicit encoding so the output does not depend on the platform default.
    with open(fname, "w", encoding="utf-8") as f:
        f.write(page)

In [3]:
# Molecule used for every figure below: training sample 20729, mirrored by
# negating its x coordinates (the [-1, 1, 1] factor broadcasts over each xyz row).
M = dset[20729].clone()
M.coords = torch.tensor([-1.0, 1.0, 1.0]) * M.coords

# Axis-arrow geometry read by `finish`: arrow length, shaft radius, and `mid`
# (forwarded as-is to 3Dmol's addArrow).
arrow_len, radius, mid = 5.0, 0.1, 0.95
def add(view, M, style="mol", opacity=None, *, reuse_view=False, width=1500, height=1500):
    """Render molecule ``M`` into a py3Dmol view using a preset style.

    NOTE: by default ``view`` is NOT drawn into. A fresh ``width`` x ``height``
    view is created instead, and callers in this notebook rely on that: they
    call ``add`` on a view that ``finish`` already decorated and expect a clean
    canvas. Pass ``reuse_view=True`` to add the model to ``view`` itself.

    Args:
        view: py3Dmol view; only used when ``reuse_view`` is True.
        M: Object whose ``xyzfile`` attribute holds the model in XYZ format.
        style: ``"mol"`` (sticks + small spheres), ``"cloud"`` (small spheres
            only), or a 3Dmol style dict that is passed through unchanged.
        opacity: Optional opacity applied to the preset's stick/sphere styles;
            ignored when ``style`` is a dict.
        reuse_view: Draw into ``view`` instead of creating a new one.
        width, height: Size of the newly created view, in pixels.

    Returns:
        The view the model was added to.

    Raises:
        ValueError: if ``style`` is a string other than ``"mol"`` or ``"cloud"``.
    """
    # Resolve the style first so an invalid preset fails before any side effects.
    if style == "mol":
        style = {"stick": {"radius": 0.2}, "sphere": {"scale": 0.2}}
        if opacity is not None:
            style["stick"]["opacity"] = opacity
            style["sphere"]["opacity"] = opacity
    elif style == "cloud":
        style = {"sphere": {"scale": 0.2}}
        if opacity is not None:
            style["sphere"]["opacity"] = opacity
    elif isinstance(style, str):
        raise ValueError(f"unknown style preset {style!r}; expected 'mol', 'cloud' or a dict")

    if not reuse_view:
        view = py3Dmol.view(width=width, height=height)
    view.addModel(M.xyzfile, "xyz")
    # model -1 targets the model that was just added.
    view.setStyle({"model": -1}, style)
    return view

def finish(view, xshift=0):
    """Decorate ``view`` with coordinate-axis arrows and apply the shared camera pose.

    Draws three arrows from the origin, each of length ``arrow_len``: red along
    x, green along y, blue along z. Shaft radius and ``mid`` come from the
    notebook-level ``radius`` and ``mid``. The view is then rotated, zoomed and
    translated so every figure is framed the same way.

    Args:
        view: py3Dmol view to decorate.
        xshift: x offset passed to ``view.translate`` (the y offset is fixed at -50).

    Returns:
        The same ``view`` object, for chaining.
    """
    for axis, color in (("x", "red"), ("y", "green"), ("z", "blue")):
        tip = {"x": 0.0, "y": 0.0, "z": 0.0}
        tip[axis] = arrow_len
        view.addArrow({
            "start": {"x": 0.0, "y": 0.0, "z": 0.0},
            "end": tip,
            "radius": radius,
            "color": color,
            "mid": mid,
        })

    # Fixed camera pose shared by all figures (view.zoomTo() would auto-fit instead).
    view.rotate(30, {"x": 1, "y": -1, "z": -0.2})
    view.zoom(0.62)
    view.translate(xshift, -50)
    return view

# Ball-and-stick rendering of M with axes; also exported to mol.html.
view = finish(add(py3Dmol.view(width=1500, height=1500), M, style="mol"))
make_html(view, "mol.html")
view

<py3Dmol.view at 0x7fa14a765790>

In [4]:
import random

from src.models.flow import sample_ONB_0

# Seed every RNG in play so the sampled start point (and the figures) are reproducible.
seed = 11
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)

# Target ONB_1: the molecule's 4-column ONB. Start ONB_0: drawn by sample_ONB_0,
# which receives the target's shape and its last column (presumably random,
# hence the seeding above).
ONB_1 = M.ONB_4col.numpy()
ONB_0 = torch.tensor(sample_ONB_0(ONB_1.shape, ONB_1[:, -1].reshape(-1, 1)))
ONB_1 = torch.tensor(ONB_1)

# Molecule rebuilt from the start ONB, rendered as a point cloud.
M_0 = M.clone()
M_0.coords = M.from_ONB_4col(ONB_0)

view = finish(add(py3Dmol.view(width=1500, height=1500), M_0, style="cloud"), xshift=120)
make_html(view, "cloud.html")
view

<py3Dmol.view at 0x7fa210511550>

In [5]:
from src.stiefel_log import Stiefel_Log_alg, Stiefel_Exp

# Path from the random start ONB_0 to the target ONB_1:
#   ONB(t) = Stiefel_Exp(ONB_0, t * log),  log = Stiefel_Log_alg(ONB_0, ONB_1)
# (presumably the Stiefel log/exp maps, so t=0 -> ONB_0 and t=1 -> ONB_1; TODO confirm).
log = Stiefel_Log_alg(ONB_0.numpy(), ONB_1.numpy())
ts = np.linspace(0.3, 0.7, 101, endpoint=True)

# Opaque anchor frame at the midpoint ("cloud" preset with no opacity), drawn in a new view.
t_anchor = 0.5
M_t = M.clone()
M_t.coords = M.from_ONB_4col(torch.from_numpy(Stiefel_Exp(ONB_0, log * t_anchor)))
view = add(py3Dmol.view(width=1500, height=1500), M_t, style="cloud")

# Semi-transparent frames along the path for every t in ts.
frame_style = {"sphere": {"scale": 0.2, "opacity": 0.5}}
for t in ts:
    M_t = M.clone()
    M_t.coords = M.from_ONB_4col(torch.from_numpy(Stiefel_Exp(ONB_0, log * t)))
    view.addModel(M_t.xyzfile, "xyz")
    # model -1 targets the model just added.
    view.setStyle({"model": -1}, frame_style)

view = finish(view, xshift=120)
make_html(view, "cloud_traj.html")
view

<py3Dmol.view at 0x7fa143df6e80>

In [None]:
from src.stiefel_log import Stiefel_Log_alg, Stiefel_Exp, OT_permutation_reflection

ts = np.linspace(0.3, 0.7, 101, endpoint=True)

# Re-pair the random start with the target before interpolating.
# OT_permutation_reflection sees the atom types plus the first three ONB columns
# of start and target; presumably it permutes/reflects ONB_0 to match ONB_1
# (see src.stiefel_log). The positional 20 and 500 are its tuning arguments;
# TODO: pass them by name. The 4th column is copied from ONB_1 unchanged.
ONB_0_3col = OT_permutation_reflection(
    M.atoms.squeeze(-1).numpy(), ONB_0[:, :3].numpy(), ONB_1[:, :3].numpy(), 20, 500
)
ONB_0_OT = torch.cat([torch.from_numpy(ONB_0_3col), ONB_1[:, 3].reshape(-1, 1)], dim=1)

# Figure 1: the aligned starting cloud on its own.
M_0_OT = M.clone()
M_0_OT.coords = M.from_ONB_4col(ONB_0_OT)
view = finish(add(py3Dmol.view(width=1500, height=1500), M_0_OT, style="cloud"))
make_html(view, "cloud_OT3.html")
view.show()

# Figure 2: path from the aligned start towards ONB_1, in a new view
# (it does not include figure 1's model or arrows).
log = Stiefel_Log_alg(ONB_0_OT.numpy(), ONB_1.numpy())

# Opaque anchor frame at the midpoint.
t_anchor = 0.5
M_t = M.clone()
M_t.coords = M.from_ONB_4col(torch.from_numpy(Stiefel_Exp(ONB_0_OT, log * t_anchor)))
view = add(py3Dmol.view(width=1500, height=1500), M_t, style="cloud")

# Semi-transparent frames for every t in ts.
frame_style = {"sphere": {"scale": 0.2, "opacity": 0.5}}
for t in ts:
    M_t = M.clone()
    M_t.coords = M.from_ONB_4col(torch.from_numpy(Stiefel_Exp(ONB_0_OT, log * t)))
    view.addModel(M_t.xyzfile, "xyz")
    view.setStyle({"model": -1}, frame_style)

view = finish(view)
make_html(view, "cloud_traj_OT3.html")
view

<py3Dmol.view at 0x7f913668a6a0>

In [None]:
# Uncomment to cache the OT-aligned start ONB_0_OT; the next cell loads it from "ONB_0_OT.pt".
# torch.save(ONB_0_OT, "ONB_0_OT.pt")

In [6]:
import os

from src.stiefel_log import Stiefel_Log_alg, Stiefel_Exp, OT_permutation_reflection

ts = np.linspace(0.25, 0.75, 101, endpoint=True)

# Reuse the saved OT-aligned start so figures match across kernels. If the
# cache is missing (e.g. fresh checkout), recompute it with the same
# OT_permutation_reflection call as the OT cell above, then save it.
# NOTE: a recomputed result may differ from an older cache if that call is stochastic.
ot_cache = "ONB_0_OT.pt"
if os.path.isfile(ot_cache):
    # torch.load unpickles; only load caches this notebook wrote.
    ONB_0_OT = torch.load(ot_cache)
else:
    ONB_0_3col = OT_permutation_reflection(
        M.atoms.squeeze(-1).numpy(), ONB_0[:, :3].numpy(), ONB_1[:, :3].numpy(), 20, 500
    )
    ONB_0_OT = torch.cat([torch.from_numpy(ONB_0_3col), ONB_1[:, 3].reshape(-1, 1)], dim=1)
    torch.save(ONB_0_OT, ot_cache)

# Figure 1: the aligned starting cloud on its own.
M_0_OT = M.clone()
M_0_OT.coords = M.from_ONB_4col(ONB_0_OT)
view = finish(add(py3Dmol.view(width=1500, height=1500), M_0_OT, style="cloud"), xshift=120)
make_html(view, "cloud_OT3.html")
view.show()

# Figure 2: path from the aligned start towards ONB_1, in a new view
# (it does not include figure 1's model or arrows).
log = Stiefel_Log_alg(ONB_0_OT.numpy(), ONB_1.numpy())

# Opaque anchor frame at the midpoint.
t_anchor = 0.5
M_t = M.clone()
M_t.coords = M.from_ONB_4col(torch.from_numpy(Stiefel_Exp(ONB_0_OT, log * t_anchor)))
view = add(py3Dmol.view(width=1500, height=1500), M_t, style="cloud")

# Semi-transparent frames for every t in ts.
frame_style = {"sphere": {"scale": 0.2, "opacity": 0.5}}
for t in ts:
    M_t = M.clone()
    M_t.coords = M.from_ONB_4col(torch.from_numpy(Stiefel_Exp(ONB_0_OT, log * t)))
    view.addModel(M_t.xyzfile, "xyz")
    view.setStyle({"model": -1}, frame_style)

view = finish(view, xshift=120)
make_html(view, "cloud_traj_OT3.html")
view

<py3Dmol.view at 0x7fa14546cc70>

In [43]:
# Values of M.moments (squeezed to a flat list), shown as strings with 3 decimal places.
mo = M.moments.squeeze().tolist()
tuple(f"{m:.3f}" for m in mo)


('378.254', '163.007', '9.747')

In [41]:
# Chemical formula of the selected molecule M (element symbols with counts).
M.formula

'C_6 N_1 O_2 H_9'