In [1]:
import sys
sys.path.append('../')
from src.datamodule import ConformerDatamodule
data = ConformerDatamodule(
    dataset='qm9',
    seed=100,
    batch_size=1,
    split_ratio=(0.8, 0.1, 0.1),
    num_workers=0,
    distributed=False,
    tol=-1.0,
)

dataset = data.datasets['train']

Dataset: 133471 conformations from 133471 molecules.


In [2]:

import py3Dmol
from src import chem
# i = 70 (train)
# i = 532 (val)
start = 10* 56
import torch
idx = 21153


arrow_len = 6.5
radius = 0.25
mid = 0.9
font_size = 80
M = dataset[idx]
M = M.replace(coords=M.coords * torch.tensor([-1.0, 1.0, 1.0]))
view = py3Dmol.view(width=1500, height=1500)
view.addModel(M.xyzfile(), 'xyz')
view.setStyle({'sphere':{'scale': 0.5}})
view.addArrow({
                      'start': {'x':0.0, 'y':0.0, 'z':0.0},
                      'end': {'x':arrow_len, 'y':0.0, 'z':0.0},
                      'radius': radius,
                      'color':'red',
                      'mid': mid,
                      })
view.addArrow({
                      'start': {'x':0.0, 'y':0.0, 'z':0.0},
                      'end': {'x':0.0, 'y':arrow_len, 'z':0.0},
                      'radius': radius,
                      'color':'green',
                        'mid': mid,
                      })
view.addArrow({
                        'start': {'x':0.0, 'y':0.0, 'z':0.0},
                        'end': {'x':0.0, 'y':0.0, 'z':arrow_len},
                        'radius': radius,
                        'color':'blue',
                        'mid': mid,
                        })
view.addLabel('x', {'position': {'x':arrow_len, 'y':0.3, 'z':0.0}, 'backgroundOpacity':0.0, 'fontColor':'red', 'fontSize': font_size})
view.addLabel('y', {'position': {'x':0.6, 'y':arrow_len, 'z':0.0}, 'backgroundOpacity':0.0, 'fontColor':'green', 'fontSize': font_size})  
view.addLabel('z', {'position': {'x':0.0, 'y':0.15, 'z':arrow_len}, 'backgroundOpacity':0.0, 'fontColor':'blue', 'fontSize': font_size})
view.rotate(30, {'x': 1, 'y': -1, 'z': -.2})
view.zoomTo()
view.translate(-80, -50)
view.show()

In [3]:
view.render()

t = view.js()
js = t.startjs + t.endjs

with open("parent.html", "w") as f:
    f.write(js)

In [4]:
import torch
import torch.linalg as LA
def rotated_to_principal_axes(coords, masses, return_moments=True):
    coords, m = coords.double(), masses.double()

    # Subtract CoM
    com = torch.sum(m * coords, dim=0) / m.sum()
    coords = coords - com

    # Compute planar dyadic
    dyadic = m.unsqueeze(-1) * coords.unsqueeze(-1) * coords.unsqueeze(-2)  # (N 1 1) * (N 3 1) * (N 1 3)
    dyadic = dyadic.sum(dim=0)  # (3 3)

    # Diagonalize in double precision
    moments, V = LA.eigh(dyadic)  # (3) (3 3)

    # Sort eigenvalues in descending order
    moments = torch.flip(moments, dims=[-1])
    V = torch.flip(V, dims=[-1])

    # Sanity check
    Q = V @ torch.diag_embed(moments) @ V.mT
    err = (dyadic - Q).abs().max().item()
    if err > 1e-5:
        print(f"WARNING: numerical instability in diagonalizing planar dyadic (error={err})")

    coords = (coords @ V).float()
    return (coords, (moments.float(), V)) if return_moments else coords




In [5]:

import py3Dmol
from src import chem

import torch
idx = 21153


example = dataset[idx]
original_coords = example.coords.clone()

atom_idx = 8
new_masses = example.masses.clone()
new_masses[atom_idx] = (13.003355 - 12) * 50 + 12

coords, moments = rotated_to_principal_axes(example.coords, new_masses, return_moments=True)
moments, transform = moments
transform *= torch.tensor([-1.0, 1.0, 1.0]).double()
transform = LA.inv(transform)

back_masses = example.masses.clone()
back_masses[atom_idx] = 12
com = torch.sum(coords * back_masses, dim=0) / back_masses.sum()
com

arrow_len = 6.5
radius = 0.25
# radius = 0.10
# radius = 0.01
mid = 0.9
font_size = 80
M = example.replace(coords=coords)

view = py3Dmol.view(width=1500, height=1500)
view.addModel(M.xyzfile(), 'xyz')
view.setStyle({'sphere':{'scale': 0.5}})
view.setStyle({'index': atom_idx}, {'sphere':{'scale': 0.5, 'color':'yellow'}})



x, y, z = (transform[0] * arrow_len) + com
x = x.item()
y = y.item()
z = z.item()
view.addArrow({
  'start': {'x':com[0].item(), 'y':com[1].item(), 'z':com[2].item()},
  'end': {'x':x, 'y':y, 'z':z},
  'radius': radius,
  'color':'red',
  'mid': mid,
  })
x, y, z = (transform[1] * arrow_len) + com
x = x.item()
y = y.item()
z = z.item()
view.addArrow({
  'start': {'x':com[0].item(), 'y':com[1].item(), 'z':com[2].item()},
  'end': {'x':x, 'y':y, 'z':z},
  'radius': radius,
  'color':'green',
  'mid': mid,
  })
x, y, z = (transform[2] * arrow_len) + com
x = x.item()
y = y.item()
z = z.item()
view.addArrow({
  'start': {'x':com[0].item(), 'y':com[1].item(), 'z':com[2].item()},
  'end': {'x':x, 'y':y, 'z':z},
  'radius': radius,
  'color':'blue',
  'mid': mid,
  })
arr_color = 'yellow'
view.addArrow({
                      'start': {'x':0.0, 'y':0.0, 'z':0.0},
                      'end': {'x':arrow_len, 'y':0.0, 'z':0.0},
                      'radius': radius,
                      'color':arr_color,
                      'mid': mid,
                      })
view.addArrow({
                      'start': {'x':0.0, 'y':0.0, 'z':0.0},
                      'end': {'x':0.0, 'y':arrow_len, 'z':0.0},
                      'radius': radius,
                      'color':arr_color,
                        'mid': mid,
                      })
view.addArrow({
                        'start': {'x':0.0, 'y':0.0, 'z':0.0},
                        'end': {'x':0.0, 'y':0.0, 'z':arrow_len},
                        'radius': radius,
                        'color':arr_color,
                        'mid': mid,
                        })

arr_color = '#8B8000'
view.addLabel('x', {'position': {'x':arrow_len-0.7, 'y':-2.0, 'z':0.0}, 'backgroundOpacity':0.0, 'fontColor':'red', 'fontSize': font_size})
view.addLabel('y', {'position': {'x':1.7, 'y':arrow_len-1.0, 'z':0.0}, 'backgroundOpacity':0.0, 'fontColor':'green', 'fontSize': font_size})  
view.addLabel('z', {'position': {'x':-0.5, 'y':0.0, 'z':arrow_len}, 'backgroundOpacity':0.0, 'fontColor':'blue', 'fontSize': font_size})
view.addLabel("x'", {'position': {'x':arrow_len, 'y':0.3, 'z':0.0}, 'backgroundOpacity':0.0, 'fontColor':arr_color, 'fontSize': font_size})
view.addLabel("y'", {'position': {'x':0.5, 'y':arrow_len, 'z':0.0}, 'backgroundOpacity':0.0, 'fontColor':arr_color, 'fontSize': font_size})  
view.addLabel("z'", {'position': {'x':0.0, 'y':0.15, 'z':arrow_len}, 'backgroundOpacity':0.0, 'fontColor':arr_color, 'fontSize': font_size})
view.rotate(30, {'x': 1, 'y': -1, 'z': -.2})
view.zoomTo()
view.translate(-80, -50)
view.show()

In [6]:
view.render()

t = view.js()
js = t.startjs + t.endjs

with open("iso1.html", "w") as f:
    f.write(js)

In [7]:

import py3Dmol
from src import chem

import torch


example = dataset[idx]
original_coords = example.coords.clone()

atom_idx = 0
new_masses = example.masses.clone()
new_masses[atom_idx] = (13.003355 - 12) * 50 + 12

coords, moments = rotated_to_principal_axes(example.coords, new_masses, return_moments=True)
moments, transform = moments
transform *= torch.tensor([-1.0, 1.0, 1.0]).double()
transform = LA.inv(transform)

back_masses = example.masses.clone()
back_masses[atom_idx] = 12
com = torch.sum(coords * back_masses, dim=0) / back_masses.sum()
com

arrow_len = 6.5
radius = 0.25
# radius = 0.10
# radius = 0.01
mid = 0.9
font_size = 80
M = example.replace(coords=coords)

view = py3Dmol.view(width=1500, height=1500)
view.addModel(M.xyzfile(), 'xyz')
view.setStyle({'sphere':{'scale': 0.5}})
view.setStyle({'index': atom_idx}, {'sphere':{'scale': 0.5, 'color':'yellow'}})



x, y, z = (transform[0] * arrow_len) + com
x = x.item()
y = y.item()
z = z.item()
view.addArrow({
  'start': {'x':com[0].item(), 'y':com[1].item(), 'z':com[2].item()},
  'end': {'x':x, 'y':y, 'z':z},
  'radius': radius,
  'color':'red',
  'mid': mid,
  })
x, y, z = (transform[1] * arrow_len) + com
x = x.item()
y = y.item()
z = z.item()
view.addArrow({
  'start': {'x':com[0].item(), 'y':com[1].item(), 'z':com[2].item()},
  'end': {'x':x, 'y':y, 'z':z},
  'radius': radius,
  'color':'green',
  'mid': mid,
  })
x, y, z = (transform[2] * arrow_len) + com
x = x.item()
y = y.item()
z = z.item()
view.addArrow({
  'start': {'x':com[0].item(), 'y':com[1].item(), 'z':com[2].item()},
  'end': {'x':x, 'y':y, 'z':z},
  'radius': radius,
  'color':'blue',
  'mid': mid,
  })
arr_color = 'yellow'
view.addArrow({
                      'start': {'x':0.0, 'y':0.0, 'z':0.0},
                      'end': {'x':arrow_len, 'y':0.0, 'z':0.0},
                      'radius': radius,
                      'color':arr_color,
                      'mid': mid,
                      })
view.addArrow({
                      'start': {'x':0.0, 'y':0.0, 'z':0.0},
                      'end': {'x':0.0, 'y':arrow_len, 'z':0.0},
                      'radius': radius,
                      'color':arr_color,
                        'mid': mid,
                      })
view.addArrow({
                        'start': {'x':0.0, 'y':0.0, 'z':0.0},
                        'end': {'x':0.0, 'y':0.0, 'z':arrow_len},
                        'radius': radius,
                        'color':arr_color,
                        'mid': mid,
                        })

arr_color = '#8B8000'
view.addLabel('x', {'position': {'x':arrow_len-0.7, 'y':-1.0, 'z':0.0}, 'backgroundOpacity':0.0, 'fontColor':'red', 'fontSize': font_size})
view.addLabel('y', {'position': {'x':1.9, 'y':arrow_len-.5, 'z':0.0}, 'backgroundOpacity':0.0, 'fontColor':'green', 'fontSize': font_size})  
view.addLabel('z', {'position': {'x':0.5, 'y':0.0, 'z':arrow_len}, 'backgroundOpacity':0.0, 'fontColor':'blue', 'fontSize': font_size})
view.addLabel("x'", {'position': {'x':arrow_len, 'y':0.6, 'z':0.0}, 'backgroundOpacity':0.0, 'fontColor':arr_color, 'fontSize': font_size})
view.addLabel("y'", {'position': {'x':0.5, 'y':arrow_len, 'z':0.0}, 'backgroundOpacity':0.0, 'fontColor':arr_color, 'fontSize': font_size})  
view.addLabel("z'", {'position': {'x':0.0, 'y':0.15, 'z':arrow_len}, 'backgroundOpacity':0.0, 'fontColor':arr_color, 'fontSize': font_size})
view.rotate(30, {'x': 1, 'y': -1, 'z': -.2})
view.zoomTo()
view.translate(-80, -50)
view.show()

In [8]:
view.render()

t = view.js()
js = t.startjs + t.endjs

with open("iso2.html", "w") as f:
    f.write(js)

In [50]:
from src.diffusion.lit import LitEquivariantDDPM
from src.experimental.train import TrainEquivariantDDPMConfig # necessary

model = LitEquivariantDDPM.load_from_checkpoint('../final_checkpoints/qm9_p10.ckpt').to('cuda:0')
model

  rank_zero_warn(


LitEquivariantDDPM(
  (edm): EquivariantDDPM(
    (embed_timestep): PositionalEmbedding()
    (dynamics): EquivariantDynamics(
      (embed_atom): Embedding(90, 32)
      (proj_h): Linear(in_features=171, out_features=256, bias=True)
      (proj_cond): Sequential(
        (0): Linear(in_features=171, out_features=128, bias=True)
        (1): Activation()
        (2): Linear(in_features=128, out_features=128, bias=True)
        (3): Activation()
      )
      (egnn): ModuleList(
        (0): EquivariantBlock(
          (norm_h): LayerNorm(
            (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
            (proj_ada): Linear(in_features=128, out_features=512, bias=True)
          )
          (norm_h_agg): LayerNorm(
            (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
            (proj_ada): Linear(in_features=128, out_features=512, bias=True)
          )
          (edge_mlp): Sequential(
            (0): Linear(in_features=520, out_features=256,

In [51]:
T = 1000
render_every_n_frames = 1
keep_frames = set(list(reversed(range(-1, T + 1, render_every_n_frames))))
out = model.ema.ema_model.sample(M, keep_frames=keep_frames)
out.show()

                                                             

AttributeError: 'tuple' object has no attribute 'show'

In [52]:
M_pred, frames = out

In [53]:
M_pred.show()

In [54]:
from src.metrics import evaluate_prediction

result = evaluate_prediction(M, M_pred)
result

{'unsigned_coords_rmse': 8.346683699755886e-08,
 'moments_rmse': 0.0,
 'coord_rmse': 0.03220527246594429,
 'heavy_coord_rmse': 0.019756125286221504,
 'transform': tensor([[-1.,  0.,  0.],
         [ 0., -1.,  0.],
         [ 0.,  0., -1.]]),
 'correctness': 1.0}

In [56]:
M_pred.transform(result['transform'].to('cuda:0')).show()

In [57]:
M.show()

In [60]:
M.coords

tensor([[ 2.7195e+00, -6.5621e-01,  7.3354e-04],
        [ 2.4692e+00, -1.2432e+00, -8.8618e-01],
        [ 3.7811e+00, -4.1580e-01, -9.1488e-04],
        [ 2.4713e+00, -1.2401e+00,  8.9027e-01],
        [ 2.0433e+00,  5.8222e-01, -5.5274e-04],
        [ 6.8018e-01,  5.0124e-01, -1.4912e-04],
        [-2.2585e-01, -5.3060e-01, -1.6048e-04],
        [-1.5244e+00,  4.7140e-02,  1.0193e-04],
        [-1.3558e+00,  1.4051e+00,  2.7222e-04],
        [-1.7024e-02,  1.6723e+00, -1.5440e-04],
        [ 4.1781e-01,  2.5778e+00,  1.0608e-03],
        [-2.0839e+00,  2.1900e+00,  2.7729e-04],
        [-2.8083e+00, -7.1663e-01,  1.5328e-04],
        [-2.8877e+00, -1.3498e+00,  8.8555e-01],
        [-3.6471e+00, -2.6501e-02, -1.2912e-03],
        [-2.8864e+00, -1.3520e+00, -8.8378e-01],
        [ 5.6852e-02, -1.8819e+00, -3.5771e-04],
        [-7.8704e-01, -2.3456e+00,  2.7596e-04]], device='cuda:0')

In [83]:
from src.visualize import html_render_trajectory
transform = result['transform'].to('cuda:0') * torch.tensor([1, -1, -1], device='cuda:0')
frames = {step: m.replace(coords=m.coords.to('cpu') @ transform.to('cpu')) for step, m in frames.items()}
frames = [frames[step] for step in list(reversed(range(-1, T + 1, render_every_n_frames))) + [-1]]

In [85]:
frames[-1].show()

In [None]:
view2 = py3Dmol.view(width=1500, height=1500)
trajfile = ""
for M in frames:
    trajfile += M.xyzfile()
trajfile = frames[-1].xyzfile() + trajfile  # prepend the last frame so zoomTo works
view2.addModelsAsFrames(trajfile, 'xyz')
view2.setStyle({'sphere':{'scale': 0.5}})
view2.addArrow({
                      'start': {'x':0.0, 'y':0.0, 'z':0.0},
                      'end': {'x':arrow_len, 'y':0.0, 'z':0.0},
                      'radius': radius,
                      'color':'red',
                      'mid': mid,
                      })
view2.addArrow({
                      'start': {'x':0.0, 'y':0.0, 'z':0.0},
                      'end': {'x':0.0, 'y':arrow_len, 'z':0.0},
                      'radius': radius,
                      'color':'green',
                        'mid': mid,
                      })
view2.addArrow({
                        'start': {'x':0.0, 'y':0.0, 'z':0.0},
                        'end': {'x':0.0, 'y':0.0, 'z':arrow_len},
                        'radius': radius,
                        'color':'blue',
                        'mid': mid,
                        })
view2.addLabel('x', {'position': {'x':arrow_len, 'y':0.3, 'z':0.0}, 'backgroundOpacity':0.0, 'fontColor':'red', 'fontSize': font_size})
view2.addLabel('y', {'position': {'x':0.6, 'y':arrow_len, 'z':0.0}, 'backgroundOpacity':0.0, 'fontColor':'green', 'fontSize': font_size})  
view2.addLabel('z', {'position': {'x':0.0, 'y':0.15, 'z':arrow_len}, 'backgroundOpacity':0.0, 'fontColor':'blue', 'fontSize': font_size})
view2.rotate(30, {'x': 1, 'y': -1, 'z': -.2})
view2.animate({"loop": "forward", "interval": 10})
view2.zoomTo()
view2.translate(-80, -50)
view2.show()

In [100]:
view2.render()

t = view2.js()
js = t.startjs + t.endjs
print(js)

<div id="3dmolviewer_UNIQUEID"  style="position: relative; width: 1500px; height: 1500px">
        <tt>jupyter labextension install jupyterlab_3dmol</tt></p>
        </div>
<script>

var loadScriptAsync = function(uri){
  return new Promise((resolve, reject) => {
    //this is to ignore the existence of requirejs amd
    var savedexports, savedmodule;
    if (typeof exports !== 'undefined') savedexports = exports;
    else exports = {}
    if (typeof module !== 'undefined') savedmodule = module;
    else module = {}

    var tag = document.createElement('script');
    tag.src = uri;
    tag.async = true;
    tag.onload = () => {
        exports = savedexports;
        module = savedmodule;
        resolve();
    };
  var firstScriptTag = document.getElementsByTagName('script')[0];
  firstScriptTag.parentNode.insertBefore(tag, firstScriptTag);
});
};

if(typeof $3Dmolpromise === 'undefined') {
$3Dmolpromise = null;
  $3Dmolpromise = loadScriptAsync('https://cdn.jsdelivr.net/npm/3dmol@la