In [1]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import wandb
import omegaconf
import copy
import os, sys, pathlib
import pickle
import scipy

from torch_geometric.loader.dataloader import Collater


project = "EquilibriumEquiFormer"

from src.deq2ff.plotting.style import set_seaborn_style, set_style_after

# parent folder of the plot
plotfolder = "/ssd/gen/equilibrium-forcefields/src/deq2ff/plotting/"
plotfolder = os.path.join(plotfolder, "plots")

chemical_symbols = [
    "_",
    "H",
    "He",
    "Li",
    "Be",
    "B",
    "C",
    "N",
    "O",
    "F",
    "Ne",
    "Na",
    "Mg",
    "Al",
    "Si",
    "P",
    "S",
    "Cl",
    "Ar",
    "K",
]

In [2]:
import py3Dmol

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import wandb
import omegaconf
import copy
import os, sys, pathlib
import pickle
import scipy

from torch_geometric.loader.dataloader import Collater

In [3]:
import equiformer.datasets.pyg.md_all as md_all

In [4]:
def get_dataset(target="ethanol", dname="md17"):
    print(f"Loading dataset...")
    (
        train_dataset,
        val_dataset,
        test_dataset,
        test_dataset_full,
        all_dataset,
    ) = md_all.get_md_datasets(
        root="datasets",
        dataset_arg=target,
        dname=dname,
        train_size=1000,
        val_size=50,
        test_patch_size=None,  # influences data splitting
        test_patch_size_select=None,  # doesn't influence data splitting
        seed=0,
        order="consecutive_all",
    )
    return all_dataset

In [5]:
# assert that dataset is consecutive
# samples = Collater(follow_batch=None, exclude_keys=None)(
#     [all_dataset[i] for i in range(10)]
# )

In [6]:

def plot_model_py3d(
    idx, dataset, show_forces=True, next=False, savefig=False, target="ethanol",
    force_length = 0.05,
    force_width = 0.1,
    ):
    """Display a 3D visualization of the molecule with forces.

    Args:
        idx (_type_): _description_
        dataset (_type_): _description_
        show_forces (bool, optional): Display forces. Defaults to True.
        next (bool, optional): Overlay a plot of the sample with idx+1. Defaults to False.
    """
    collate = Collater(None, None)

    # get the data
    # data = collate([dataset[_idx] for _idx in idx])
    data = collate([dataset[idx]])
    # positions = data.pos
    # forces = data.dy
    # e = data.y
    positions = data.pos.tolist()
    forces = data.dy.tolist()
    e = data.y.tolist()

    z = data.z
    # print('z', z)
    atoms = [chemical_symbols[int(_z)] for _z in z]
    # print('atoms', atoms)
    
    # Create the 3D visualization
    view = py3Dmol.view(width=800, height=600)

    def add_mol_to_view(_view, _atoms, _positions, _forces):
        # Generate XYZ string for py3Dmol
        xyz_str = f"{len(_atoms)}\n\n"
        for atom, pos in zip(_atoms, _positions):
            xyz_str += f"{atom} {pos[0]} {pos[1]} {pos[2]}\n"

        _view.addModel(xyz_str, "xyz")

        if show_forces:
            for pos, force in zip(_positions, _forces):
                # print('pos', pos)
                # print('force', force)
                start = {"x": pos[0], "y": pos[1], "z": pos[2]}
                end = {
                    "x": pos[0] + force[0] * force_length,
                    "y": pos[1] + force[1] * force_length,
                    "z": pos[2] + force[2] * force_length,
                }
                _view.addArrow({"start": start, "end": end, "radius": force_width, "color": "orange"})

    add_mol_to_view(view, atoms, positions, forces)
    
    if next:
        idx = idx - 1
        data = collate([dataset[idx]])
        pos2 = data.pos.tolist()
        force2 = data.dy.tolist()
        e2 = data.y.tolist()
        z2 = data.z
        atoms2 = [chemical_symbols[int(_z)] for _z in z2]

        add_mol_to_view(view, atoms2, pos2, force2)

    style = {"stick": {"radius": 0.1}, "sphere": {"scale": 0.2}}
    view.setStyle({"model": -1}, style, viewer=None)
    
    # viewer.setBackgroundColor('#FFFFFF')  

    view.setStyle({
        # 'stick': {'radius': 0.1, 'color': "FFFFD1"}, 
        'stick': {'radius': 0.08, 'color': "#FFFFD1"}, # FFFFD1
        'sphere': {'scale': 0.20, 'colorscheme': {'C': '#4171bb', 'O': '#f4919f', 'N': '#FABF50', 'H': '#40a597'}}
    })
    
    view.zoomTo()
    view.show()
    
    if savefig:
        view.render_image()
        view.png()
        fname = f"mol_{target}_{idx}{'_forces' if show_forces else ''}"
        
        with open(fname + '.html', 'w') as f:
            f.write(view.render())
    
        from selenium import webdriver
        from webdriver_manager.chrome import ChromeDriverManager

        # Set up Selenium to take a screenshot of the HTML view
        browser = webdriver.Chrome(ChromeDriverManager().install())
        browser.get(fname + '.html')

        # Set the window size to capture the desired resolution
        browser.set_window_size(800, 800)

        # Take a screenshot
        browser.save_screenshot(fname + '.png')

        browser.quit()

        print(f" Saved as {fname}.png")


In [7]:
dataset = get_dataset()
print(f"size train dataset: {len(dataset)}")

Loading dataset...
Found ['datasets/md17/ethanol/raw/md17_ethanol.npz'] files, skipping download
Found ['datasets/md17/ethanol/processed/md17-ethanol.pt'] files, skipping process
Dataset size: 555092
Dataset: Using consecutive order
size train dataset: 555092


In [8]:
# show first sample
plot_model_py3d(0, dataset, show_forces=True, force_length=0.09, force_width=0.06)

In [9]:
# show first sample
plot_model_py3d(0, dataset, show_forces=False)

In [10]:
dataset_aspirin = get_dataset(target="aspirin")
print(f"size train dataset: {len(dataset)}")
plot_model_py3d(0, dataset_aspirin, show_forces=False)

Loading dataset...
Found ['datasets/md17/aspirin/raw/md17_aspirin.npz'] files, skipping download
Found ['datasets/md17/aspirin/processed/md17-aspirin.pt'] files, skipping process
Dataset size: 211762
Dataset: Using consecutive order
size train dataset: 555092


In [11]:
# from rdkit import Chem
# from rdkit.Chem import AllChem, rdMolDescriptors
# import plotly.graph_objects as go

# def display_molecule(smiles, width=500, height=400):
#     mol = Chem.MolFromSmiles(smiles)
#     mol = Chem.AddHs(mol)
#     AllChem.EmbedMolecule(mol)
#     AllChem.MMFFOptimizeMolecule(mol)
    
#     viewer = py3Dmol.view(width=width, height=height)

#     mol_block = Chem.MolToMolBlock(mol)
#     viewer.addModel(mol_block, "mol")
#     # light gray F0F2F6. # white FFFFFF
#     viewer.setBackgroundColor('#FFFFFF')  

#     # viewer.setStyle({'stick': {}})
#     # viewer.setStyle({'model': -1}, {"cartoon": {'color': 'spectrum'}})
#     # cscheme = {'C': 'gray', 'O': 'red', 'N': 'blue', 'H': 'white'}
#     cscheme = "cyanToBlue"
#     viewer.setStyle({
#         # 'stick': {'radius': 0.1, 'color': "FFFFD1"}, 
#         'stick': {'radius': 0.08, 'color': "#FFFFD1"}, # FFFFD1
#         'sphere': {'scale': 0.20, 'colorscheme': {'C': '#4171bb', 'O': '#f4919f', 'N': '#FABF50', 'H': '#40a597'}}
#     })
#     # viewer.setStyle({'stick': {'radius': 0.2}, 'sphere': {'scale': 0.25}})
#     # viewer.addSurface(py3Dmol.VDW, {'opacity': 0.5, 'color': 'white'})

#     # viewer.setViewStyle({'style': 'outline'})  # Adds an outline style for better focus
#     viewer.zoomTo()
#     # viewer.animate({'duration': 1000, 'zoom': 1.5})  # Smooth zoom animation

#     return viewer

In [18]:



from rdkit import Chem
from rdkit.Chem import AllChem, rdMolDescriptors

def display_molecule(smiles, width=500, height=400):
    mol = Chem.MolFromSmiles(smiles)
    mol = Chem.AddHs(mol)
    AllChem.EmbedMolecule(mol)
    AllChem.MMFFOptimizeMolecule(mol)
    
    viewer = py3Dmol.view(width=width, height=height)

    mol_block = Chem.MolToMolBlock(mol)
    viewer.addModel(mol_block, "mol")
    # light gray F0F2F6. # white FFFFFF
    viewer.setBackgroundColor('#FFFFFF')  

    # viewer.setStyle({'stick': {}})
    # viewer.setStyle({'model': -1}, {"cartoon": {'color': 'spectrum'}})
    # cscheme = {'C': 'gray', 'O': 'red', 'N': 'blue', 'H': 'white'}
    viewer.setStyle({
        # 'stick': {'radius': 0.1, 'color': "FFFFD1"}, 
        'stick': {'radius': 0.08, 'color': "#FFFFD1"}, # FFFFD1
        'sphere': {'scale': 0.20, 'colorscheme': {'C': '#4171bb', 'O': '#f4919f', 'N': '#FABF50', 'H': '#40a597'}}
    })
    # viewer.setStyle({'stick': {'radius': 0.2}, 'sphere': {'scale': 0.25}})
    # viewer.addSurface(py3Dmol.VDW, {'opacity': 0.5, 'color': 'white'})

    # viewer.setViewStyle({'style': 'outline'})  # Adds an outline style for better focus
    viewer.zoomTo()
    # viewer.animate({'duration': 1000, 'zoom': 1.5})  # Smooth zoom animation

    fname = f"mol_{smiles}"
    output_file = fname + '.html'
    
    # viewer.show()
    
    # viewer.render_image()
    # viewer.png()
    
    # Use `viewer.getHTML()` to get the HTML/JavaScript string
    html_content = viewer.getHTML()
    
    print(type(html_content))

    # Save the HTML content to a file
    with open(output_file, 'w') as f:
        f.write(html_content)

    
    print(f" Saved as {fname}.html")

In [19]:
smiles = "O=S(=O)(O)CCN1CCN(CCO)CC1"
display_molecule(smiles)

<class 'py3Dmol.view'>


TypeError: write() argument must be str, not view