In [4]:
import sys
sys.path.append("../..")
from src.datamodule import ConformerDataModule
import torch
import pandas as pd
import numpy as np
import py3Dmol

In [5]:
# Build one data module per benchmark; the plotting cells index into their
# "test" datasets.
qm9 = ConformerDataModule(
    "qm9",
    batch_size_train=10,
    num_workers=0,
)
geom = ConformerDataModule(
    "geom",
    batch_size_train=2,
    num_workers=0,
    only_lowest_energy_conformers=True,
)

In [6]:
def make_html(view, fname="input.html"):
    """Write ``view`` to a standalone HTML page that also fills in a PNG snapshot.

    The page consists of an ``<img id="img_A">`` element, two ``<script>`` tags
    loading 3Dmol.js from 3Dmol.org, and the markup returned by
    ``view._make_html()``. Every occurrence of the closing ``});</script>``
    sequence in that page is prefixed with JavaScript that assigns the viewer's
    ``pngURI()`` to the ``img_A`` element.

    Args:
        view: Object exposing ``_make_html()`` (returning an HTML string) and a
            ``uniqueid`` attribute, e.g. a ``py3Dmol.view``.
        fname: Destination path; an existing file is overwritten.
    """
    header = (
        '<img id="img_A">'
        '<script src="https://3Dmol.org/build/3Dmol-min.js"></script>'
        '<script src="https://3Dmol.org/build/3Dmol.ui-min.js"></script>'
    )

    # Generate the markup *before* reading ``uniqueid`` (same order as the
    # original expression) in case the call (re)assigns the id.
    page = header + view._make_html()
    viewer_id = str(view.uniqueid)

    closing = "});\n</script>"
    snapshot_js = (
        "var png = viewer_" + viewer_id + ".pngURI();\n"
        'document.getElementById("img_A").src = png;\n'
    )
    page = page.replace(closing, snapshot_js + closing)

    # Explicit UTF-8 so the output does not depend on the platform's default
    # text encoding.
    with open(fname, "w", encoding="utf-8") as f:
        f.write(page)

In [7]:
# Position passed to addLabel for the RMSD annotation in each panel.
x = 1
y = -4
dataset = "geom"

# One panel per sample, laid out as (rows, columns).
grid = (2, 5)
num_panels = grid[0] * grid[1]

for ckpt in ["kreedXL", "stiefelFM_OT_filter"]:
    # The sample file depends only on the checkpoint, so load it once here
    # instead of once per test molecule.
    # NOTE(review): torch.load unpickles the file; only load trusted sample files.
    ckpt_samples = torch.load(f"../../samples/{dataset}/{ckpt}.pt")

    for test_set_idx in [0, 11, 229, 1704]:
        samples = ckpt_samples[test_set_idx]

        M_true = geom.datasets["test"][test_set_idx]
        M_pred = M_true.clone()

        view = py3Dmol.view(width=(400 * grid[1]), height=(400 * grid[0]), viewergrid=grid)

        # Panels are filled in order of increasing coordinate RMSE, so ks[0]
        # is the best sample.
        ks = np.argsort([samples[k]["coord_rmse"] for k in range(num_panels)])
        min_k = ks[0]
        for i, k in enumerate(ks):
            r, c = divmod(i, grid[1])
            M_pred.coords = samples[k]["coords"]
            if samples[k]["rmsd_under_pt25"] or k == min_k:
                # Overlay the reference (cyan) on the prediction and annotate
                # the panel with its RMSD.
                view = M_pred.show(view=view, viewer=(r, c))
                view = M_true.show(view=view, viewer=(r, c), color="cyan")
                view.addLabel(
                    f"RMSD: {samples[k]['coord_rmse']:.3f}",
                    {
                        "fontSize": 24,
                        "fontColor": "black",
                        "backgroundOpacity": 0.0,
                        "position": {"x": x, "y": y, "z": 0},
                    },
                    viewer=(r, c),
                )
                if samples[k]["rmsd_under_pt25"]:
                    view.setBackgroundColor("#cffbcf", viewer=(r, c))
            else:
                # Remaining samples: prediction only, no overlay or label.
                M_pred.show(view=view, viewer=(r, c))
            # Set last, so this background replaces the one set above when
            # validity is below 1.0.
            if samples[k]["validity"] < 1.0:
                view.setBackgroundColor("#feefb2", viewer=(r, c))
        view.zoom(0.7)

        view.show()
        make_html(view, f"{ckpt}_{test_set_idx}.html")


In [8]:
# Position passed to addLabel for the RMSD annotation in each panel.
x = 1
y = -3
dataset = "qm9"

# One panel per sample, laid out as (rows, columns).
grid = (2, 5)
num_panels = grid[0] * grid[1]

for ckpt in ["kreedXL", "stiefelFM"]:
    # The sample file depends only on the checkpoint, so load it once here
    # instead of once per test molecule.
    # NOTE(review): torch.load unpickles the file; only load trusted sample files.
    ckpt_samples = torch.load(f"../../samples/{dataset}/{ckpt}.pt")

    for test_set_idx in [0, 342, 401, 171]:
        samples = ckpt_samples[test_set_idx]

        M_true = qm9.datasets["test"][test_set_idx]
        M_pred = M_true.clone()

        view = py3Dmol.view(width=(400 * grid[1]), height=(400 * grid[0]), viewergrid=grid)

        # Panels are filled in order of increasing coordinate RMSE, so ks[0]
        # is the best sample.
        ks = np.argsort([samples[k]["coord_rmse"] for k in range(num_panels)])
        min_k = ks[0]
        for i, k in enumerate(ks):
            r, c = divmod(i, grid[1])
            M_pred.coords = samples[k]["coords"]
            if samples[k]["rmsd_under_pt25"] or k == min_k:
                # Overlay the reference (cyan) on the prediction and annotate
                # the panel with its RMSD.
                view = M_pred.show(view=view, viewer=(r, c))
                view = M_true.show(view=view, viewer=(r, c), color="cyan")
                view.addLabel(
                    f"RMSD: {samples[k]['coord_rmse']:.3f}",
                    {
                        "fontSize": 24,
                        "fontColor": "black",
                        "backgroundOpacity": 0.0,
                        "position": {"x": x, "y": y, "z": 0},
                    },
                    viewer=(r, c),
                )
                if samples[k]["rmsd_under_pt25"]:
                    view.setBackgroundColor("#cffbcf", viewer=(r, c))
            else:
                # Remaining samples: prediction only, no overlay or label.
                M_pred.show(view=view, viewer=(r, c))
            # Set last, so this background replaces the one set above when
            # validity is below 1.0.
            if samples[k]["validity"] < 1.0:
                view.setBackgroundColor("#feefb2", viewer=(r, c))
        view.zoom(1.0)

        view.show()
        make_html(view, f"qm9_{ckpt}_{test_set_idx}.html")
