In [1]:
# --- Importing and defining some functions ----
import torch
import py3Dmol
import numpy as np

from typing import Optional
from torch import tensor
from e3nn import o3
from torch_scatter import scatter_mean

from oa_reactdiff.model import LEFTNet

# Run all tensor maths in double precision so numerical checks are tighter.
default_float = torch.double  # torch.double is the same dtype object as torch.float64
torch.set_default_dtype(default_float)


def remove_mean_batch(
    x: tensor,
    indices: Optional[tensor] = None
) -> tensor:
    """Subtract the mean of x along dim 0, either globally or per batch.

    Args:
        x (tensor): input tensor; rows along dim 0 are averaged.
        indices (Optional[tensor], optional): batch index of every row of x,
            used to group rows for ``scatter_mean``. When None, all rows are
            treated as one batch. Defaults to None.

    Returns:
        tensor: tensor of the same shape as x whose rows have zero mean
        within each batch.
    """
    # Use identity, not `==`: comparing a tensor with `==` is meant to be
    # elementwise, so `is None` is the only unambiguous "not given" check.
    if indices is None:
        return x - torch.mean(x, dim=0)
    # scatter_mean averages the rows that share each batch index; indexing the
    # result with `indices` broadcasts each batch's mean back onto its rows.
    mean = scatter_mean(x, indices, dim=0)
    return x - mean[indices]


def draw_in_3dmol(
    mol: str,
    fmt: str = "xyz",
    width: int = 1024,
    height: int = 576,
) -> py3Dmol.view:
    """Render a molecule in an interactive py3Dmol viewer.

    The molecule is drawn as sticks plus small spheres, and the camera is
    zoomed to fit it.

    Args:
        mol (str): str content of the molecule (e.g. the text of an .xyz file).
        fmt (str, optional): format of ``mol``, passed to ``addModel``.
            Defaults to "xyz".
        width (int, optional): viewer width in pixels. Defaults to 1024.
        height (int, optional): viewer height in pixels. Defaults to 576.

    Returns:
        py3Dmol.view: output viewer; make it the last expression of a cell
        to display it.
    """
    # Pass the size by keyword: py3Dmol.view's documented signature lists
    # `query` as its first positional parameter, so positional ints are
    # fragile. NOTE(review): confirm against the installed py3Dmol version.
    viewer = py3Dmol.view(width=width, height=height)
    viewer.addModel(mol, fmt)
    viewer.setStyle({'stick': {}, "sphere": {"radius": 0.36}})
    viewer.zoomTo()
    return viewer


def assemble_xyz(z: list, pos: tensor) -> str:
    """Assemble chemical elements and 3D coordinates into an XYZ-format string.

    The output has the atom count on the first line, an empty comment line,
    and then one tab-separated ``element x y z`` line per atom.

    Args:
        z (list): chemical element of each atom, written verbatim as the
            first column.
        pos (tensor): 3D coordinates, one row per entry of ``z``. It may
            require grad or live off the CPU; it is detached and copied to
            the CPU before formatting.

    Returns:
        str: xyz string.

    Raises:
        ValueError: if ``z`` and ``pos`` have different lengths. Without this
            check ``zip`` would silently drop the extra entries while the
            header still claimed ``len(z)`` atoms.
    """
    natoms = len(z)
    if natoms != len(pos):
        raise ValueError(
            f"assemble_xyz: got {natoms} elements but {len(pos)} coordinate rows"
        )
    # .numpy() refuses tensors that require grad or are not on the CPU (the
    # usual state of model outputs), so detach and move the tensor first.
    coords = pos.detach().cpu().numpy()
    lines = [
        f"{_z}\t" + "\t".join(str(x) for x in row) + "\n"
        for _z, row in zip(z, coords)
    ]
    return f"{natoms}\n\n" + "".join(lines)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load the generated transition-state structure (sample 0) as raw xyz text.
file_path = "demo/example-3/generated/gen_0_ts.xyz"
with open(file_path, mode='r') as f:
    xyz = f.read()

# Build the interactive viewer; as the cell's last expression it is rendered inline.
view = draw_in_3dmol(xyz, fmt="xyz")
view

<py3Dmol.view at 0x7f4eb43be7d0>

In [5]:
# Load the ground-truth transition-state structure (sample 0) for side-by-side comparison.
file_path = "demo/example-3/ground_truth/sample_0_ts.xyz"
with open(file_path, mode='r') as f:
    xyz = f.read()

# Build the interactive viewer; as the cell's last expression it is rendered inline.
view = draw_in_3dmol(xyz, fmt="xyz")
view

<py3Dmol.view at 0x7f4eb43bc640>