# Plotting Electron Densities

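This notebook shows how to plot machine-learned electron densities. Using an example structure and a pretrained model, it plots the reference and predicted densities side by side, plots their difference, and reports the relative density error. Use it as a template for your own structure(s) and model(s).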

In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os
import math
import numpy as np

import torch
import torch_geometric

from torch_cluster import radius_graph
from torch_scatter import scatter

#sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath('')))
from utils import get_iso_permuted_dataset, get_iso_dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.set_default_dtype(torch.float32)

In [None]:
# First, load the example dataset.

# Hydrogen / oxygen reference files passed to the dataset builder as
# `h_iso` / `o_iso` (names suggest s-only isolated-atom densities in the
# def2-universal-jfit-decontract basis -- TODO confirm).
hhh = "../data/h_s_only_def2-universal-jfit-decontract_density.out"
ooo = "../data/o_s_only_def2-universal-jfit-decontract_density.out"

# Example structures shipped with the test-data generation suite.
fours = "../tests/test_data_generation/testdata_w4.pkl"
w04_dataset = get_iso_permuted_dataset(fours, o_iso=ooo, h_iso=hhh)

# `torch_geometric.data.DataLoader` is deprecated in favour of
# `torch_geometric.loader.DataLoader`; fall back to the old location only on
# PyG releases that predate the `loader` package.
try:
    from torch_geometric.loader import DataLoader
except ImportError:
    DataLoader = torch_geometric.data.DataLoader

# One structure per batch, in file order, so `data_loader.dataset[i]` is the
# i-th structure of the file.
data_loader = DataLoader(w04_dataset, batch_size=1, shuffle=False)


In [None]:
# Now build the model and load the pretrained weights.
from e3nn.nn.models.gate_points_2101 import Network
from e3nn import o3

# These hyperparameters must describe the same architecture the checkpoint was
# trained with; `load_state_dict` (strict by default) rejects missing,
# unexpected, or differently-shaped parameters.
model_kwargs = {
    "irreps_in": "2x 0e",
    # Hidden features: multiplicities 125/40/25/15 for l = 0..3, each at both parities.
    "irreps_hidden": [(mul, (l, p)) for l, mul in enumerate([125, 40, 25, 15]) for p in [-1, 1]],
    # Output coefficients for l = 0..4; the plotting cell's `Rs` mirrors this layout.
    "irreps_out": "12x0e + 5x1o + 4x2e + 2x3o + 1x4e",
    "irreps_node_attr": None,
    "irreps_edge_attr": o3.Irreps.spherical_harmonics(3),
    "layers": 3,
    "max_radius": 3.5,
    "number_of_basis": 10,
    "radial_layers": 1,
    "radial_neurons": 128,
    "num_neighbors": 12.2298,
    "num_nodes": 24,
    # Keep per-node outputs instead of reducing over the graph.
    "reduce_output": False,
}

model = Network(**model_kwargs)
model.to(device)
# This notebook only runs inference.
model.eval()

# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU also loads on a CPU-only machine.
model.load_state_dict(torch.load('pretrained_model.pt', map_location=device))

In [None]:
# Plotting imports (plotly) and the grid / density helpers from utils

import plotly
from plotly.subplots import make_subplots
import plotly.graph_objects as go

from utils import generate_grid, gau2grid_density_kdtree

In [None]:
# Plot the reference and ML-predicted densities for the first structure.

num = 0
data = data_loader.dataset[num]

# Build the evaluation grid around the structure. The flattened coordinates are
# reused for density evaluation and for both Volume traces below.
x, y, z, vol, x_spacing, y_spacing, z_spacing = generate_grid(data, spacing=0.15, buffer=1.0)
grid_x, grid_y, grid_z = x.flatten(), y.flatten(), z.flatten()

# Keep predicted coefficients only where the target coefficient is non-zero.
mask = torch.where(data.y == 0, torch.zeros_like(data.y), torch.ones_like(data.y)).detach()
# Inference only: no autograd graph is needed, which also saves memory.
with torch.no_grad():
    y_ml = model(data.to(device)) * mask.to(device)

# Evaluate both densities on the grid. `Rs` lists (multiplicity, l) pairs and
# mirrors the model's irreps_out "12x0e + 5x1o + 4x2e + 2x3o + 1x4e".
Rs = [(12, 0), (5, 1), (4, 2), (2, 3), (1, 4)]
target_density, ml_density = gau2grid_density_kdtree(grid_x, grid_y, grid_z, data, y_ml, Rs)

# Two side-by-side 3D panels: reference (left) and ML prediction (right).
rows = 1
cols = 2
specs = [[{'is_3d': True} for i in range(cols)]
         for j in range(rows)]
fig = go.FigureWidget(make_subplots(rows=rows, cols=cols, specs=specs, vertical_spacing=0.0))

# Atom markers. The size/colour pattern repeats every 3 atoms (presumably
# O, H, H per water molecule -- TODO confirm the atom ordering). The per-atom
# lists are generated for the actual atom count instead of a fixed 64
# repetitions, which would be too short for structures with more atoms.
points = data.pos_orig
xyz = points.cpu().numpy()
xs, ys, zs = xyz[:, 0], xyz[:, 1], xyz[:, 2]
n_atoms = len(xyz)
marker = dict(
    size=[[30, 15, 15][i % 3] for i in range(n_atoms)],
    color=[["red", "black", "black"][i % 3] for i in range(n_atoms)],
    opacity=1.0,
)
geom = go.Scatter3d(x=xs, y=ys, z=zs, mode='markers', marker=marker)
for col in range(1, cols + 1):
    fig.add_trace(geom, row=1, col=col)

# One semi-transparent isosurface stack per panel.
for col, density in enumerate((target_density, ml_density), start=1):
    fig.add_trace(go.Volume(
        x=grid_x,
        y=grid_y,
        z=grid_z,
        value=density.flatten(),
        isomax=0.05,
        colorscale='BuGn',
        opacity=0.1,  # needs to be small to see through all surfaces
        surface_count=12,  # needs to be a large number for good volume rendering
        showscale=False,
    ), row=1, col=col)

fig.update_layout(showlegend=False, width=1000, height=500)

fig.show()

In [None]:
# Now plot the pointwise difference (ML minus reference) on the same grid.

fig = go.FigureWidget()

# Atom markers. The size/colour pattern repeats every 3 atoms (presumably
# O, H, H per water molecule -- TODO confirm) and is generated for the actual
# atom count instead of a fixed 64 repetitions.
points = data.pos_orig
xyz = points.cpu().numpy()
xs, ys, zs = xyz[:, 0], xyz[:, 1], xyz[:, 2]
n_atoms = len(xyz)
geom = go.Scatter3d(x=xs, y=ys, z=zs, mode='markers', marker=dict(
    size=[[30, 15, 15][i % 3] for i in range(n_atoms)],
    color=[["red", "black", "black"][i % 3] for i in range(n_atoms)],
    opacity=1.0,
))
fig.add_trace(geom)

fig.add_trace(go.Volume(
    x=x.flatten(),
    y=y.flatten(),
    z=z.flatten(),
    value=ml_density.flatten() - target_density.flatten(),
    # Symmetric iso range so over- and under-prediction are rendered alike.
    isomin=-0.05,
    isomax=0.05,
    opacity=0.1,  # needs to be small to see through all surfaces
    surface_count=12,  # needs to be a large number for good volume rendering
))

fig.update_layout(width=600, height=600)

fig.show()

In [None]:
# Relative L1 density error, in percent:
#     ep = 100 * sum(|rho_ml - rho_ref|) / sum(rho_ref)
# Both sums run over the same grid points, so a constant voxel volume (regular
# grid assumed) cancels between numerator and denominator.
total_abs_diff = np.sum(np.abs(ml_density - target_density))
total_reference = np.sum(target_density)
ep = 100 * total_abs_diff / total_reference
print("Density Difference Error (%)", ep)