# Calculating and Plotting Electrostatic Potential

This notebook shows how to calculate the electrostatic potential (and electric field) around a molecular cluster. It requires a custom-built version of Psi4, available at https://github.com/JoshRackers/psi4. It also shows how to use the marching cubes algorithm to generate a density isosurface and plot the potential on that surface.

In [None]:
import sys

# Directory holding the custom psi4 build. It is appended to the import path
# only when absent, so re-running this cell does not add duplicate entries.
psi4_build_lib = "/home/jracker/codes/psi4/build/stage/lib"
if psi4_build_lib not in sys.path:
    sys.path.append(psi4_build_lib)
import psi4

In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os
import math
import numpy as np

import torch
import torch_geometric

sys.path.append(os.path.dirname(os.path.abspath('')))
from utils import get_iso_permuted_dataset
from utils import flatten_list

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
# (Assign torch.device("cpu") unconditionally to force CPU execution.)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Conversion factor: one Hartree expressed in kcal/mol.
ha2kcalmol = 627.5094740631

In [None]:
# Step 1: load the dataset.
# Density files passed as the h_iso / o_iso arguments below
# (presumably isolated-atom reference densities -- confirm against utils).
hhh = "../data/h_s_only_def2-universal-jfit-decontract_density.out"
ooo = "../data/o_s_only_def2-universal-jfit-decontract_density.out"

# Pickled structures that the dataset is built from.
fours = "../tests/test_data_generation/testdata_w4.pkl"
w04_dataset = get_iso_permuted_dataset(fours, o_iso=ooo, h_iso=hhh)

# One structure per batch, kept in the stored order.
data_loader = torch_geometric.data.DataLoader(
    w04_dataset[:],
    batch_size=1,
    shuffle=False,
)

In [None]:
# now get model
from e3nn.nn.models.gate_points_2101 import Network
from e3nn import o3

# Hyperparameters of the e3nn network. They must describe the same
# architecture as the checkpoint loaded below, otherwise load_state_dict
# will reject the stored weights.
model_kwargs = {
    "irreps_in": "2x 0e",
    # Hidden features: for l = 0, 1, 2, 3 use 125, 40, 25, 15 copies of
    # each parity (odd and even).
    "irreps_hidden": [(mul, (l, p)) for l, mul in enumerate([125, 40, 25, 15]) for p in [-1, 1]],
    # Alternative, scalar-only hidden layer:
    # "irreps_hidden": "100x0e + 100x0o",
    # Output multiplicities per l (12, 5, 4, 2, 1 for l = 0..4); the same
    # multiplicities appear in the Rs list used when the density is evaluated.
    "irreps_out": "12x0e + 5x1o + 4x2e + 2x3o + 1x4e",
    "irreps_node_attr": None,
    "irreps_edge_attr": o3.Irreps.spherical_harmonics(3),
    "layers": 3,
    "max_radius": 3.5,
    "number_of_basis": 10,
    "radial_layers": 1,
    "radial_neurons": 128,
    "num_neighbors": 12.2298,
    "num_nodes": 24,
    # Keep the per-node outputs rather than reducing them
    # (presumably one coefficient vector per atom -- confirm against e3nn).
    "reduce_output": False,
}

model = Network(**model_kwargs)
model.to(device)

# map_location remaps the stored tensors onto the active device, so a
# checkpoint that was saved on a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load('pretrained_model.pt', map_location=device))

In [None]:
# generate density on a grid for the first structure
from utils import generate_grid, gau2grid_density_kdtree

# Work on the first structure of the dataset.
data = data_loader.dataset[0]

# 0/1 mask in the dtype of data.y: 0 where the target entry is exactly zero,
# 1 everywhere else. It is built before `data` is moved to the device.
mask = (data.y != 0).to(data.y.dtype).detach()

# Model prediction with the entries that are zero in the target zeroed out.
prediction = model(data.to(device))
y_ml = prediction * mask.to(device)

# Grid around the structure (spacing and buffer are passed straight to the
# helper; presumably in the same length unit as the coordinates -- confirm).
grid = generate_grid(data, spacing=0.1, buffer=2.0)
x, y, z, vol, x_spacing, y_spacing, z_spacing = grid

# (multiplicity, l) pairs handed to the density evaluation.
Rs = [(12, 0), (5, 1), (4, 2), (2, 3), (1, 4)]
target_density, ml_density = gau2grid_density_kdtree(
    x.flatten(), y.flatten(), z.flatten(), data, y_ml, Rs
)
print("done!")

In [None]:
## Now compute isosurface with marching cubes

from skimage import measure

# The flat density arrays were evaluated at x.flatten(), y.flatten(),
# z.flatten(), so giving them the grid's own shape puts every value back on
# its grid point. Deriving the edge length from a cube root of the array
# length would only work for a grid with the same number of points on every
# axis.
grid_shape = x.shape
t_density = target_density.reshape(grid_shape)
m_density = ml_density.reshape(grid_shape)

# Density value at which the isosurface is extracted.
level = 0.002
voxel = (x_spacing, y_spacing, z_spacing)
t_verts, t_faces, t_normals, t_values = measure.marching_cubes(
    t_density, spacing=voxel, level=level, step_size=3
)
ml_verts, ml_faces, ml_normals, ml_values = measure.marching_cubes(
    m_density, spacing=voxel, level=level, step_size=3
)

# marching_cubes reports vertices relative to the first grid point (already
# multiplied by the spacing), so translate them by that point's coordinates
# to obtain real-space positions.
xyz_min = np.array([x[0, 0, 0], y[0, 0, 0], z[0, 0, 0]])
t_verts = t_verts + xyz_min
ml_verts = ml_verts + xyz_min
print(t_verts[:, 0].shape)

In [None]:
# now compute the electrostatic potential and field at the points on the isosurface

from utils import compute_potential_field
# Split the target-isosurface vertices into flat x / y / z coordinate arrays
# and evaluate potential and field there. Going by the result names, the
# first pair is for the target density and the second for the ML density
# (confirm against utils.compute_potential_field).
vx, vy, vz = (t_verts[:, axis].flatten() for axis in range(3))
t_isopot, t_isofield, m_isopot, m_isofield = compute_potential_field(vx, vy, vz, data, y_ml, Rs)

In [None]:
import plotly
from plotly.subplots import make_subplots
import plotly.graph_objects as go

#plotly.offline.init_notebook_mode(connected=True)

# One row of three 3D panels: ML potential, target potential, difference.
rows, cols = 1, 3
specs = [[{'is_3d': True} for _ in range(cols)] for _ in range(rows)]

# FigureWidget apparently works faster with numpy arrays
fig = go.FigureWidget(
    make_subplots(
        rows=rows,
        cols=cols,
        specs=specs,
        subplot_titles=('ML Electrostatic Potential', 'Target Electrostatic Potential', 'Difference'),
    )
)

# All three panels colour the same surface (the target-density isosurface,
# where the potentials were evaluated); only the intensity values differ.
surface = dict(
    x=t_verts[:, 0].flatten(),
    y=t_verts[:, 1].flatten(),
    z=t_verts[:, 2].flatten(),
    i=t_faces[:, 0].flatten(),
    j=t_faces[:, 1].flatten(),
    k=t_faces[:, 2].flatten(),
    colorscale="RdBu",
    cmin=-0.075,
    cmax=0.075,
    opacity=0.5,
    showlegend=False,
)

# The first two meshes hide their colour bar; the difference mesh keeps the
# default, so a single colour bar is drawn.
traces = [
    go.Mesh3d(intensity=m_isopot.flatten(), name='ML Density', showscale=False, **surface),
    go.Mesh3d(intensity=t_isopot.flatten(), name='Target Density', showscale=False, **surface),
    go.Mesh3d(intensity=(m_isopot.flatten() - t_isopot.flatten()), name='Difference', **surface),
]
for panel, mesh in enumerate(traces, start=1):
    fig.add_trace(mesh, row=1, col=panel)

# Black markers at the positions stored in data.pos_orig, repeated in every panel.
atom_xyz = data.pos_orig.cpu().numpy()
geom = go.Scatter3d(
    x=atom_xyz[:, 0],
    y=atom_xyz[:, 1],
    z=atom_xyz[:, 2],
    mode='markers',
    marker=dict(size=5, color='Black', opacity=1.0),
)
for panel in range(1, cols + 1):
    fig.add_trace(geom, row=1, col=panel)

# Amount trimmed from each end of the plotted axis range.
modbuf = 0.0


def _hidden_axis(grid):
    """Return settings for an invisible axis spanning the grid's first to last point."""
    return dict(
        visible=False,
        showgrid=False,
        zeroline=False,
        range=[grid[0, 0, 0] + modbuf, grid[-1, -1, -1] - modbuf],
    )


def _scene():
    """Return a fresh scene layout with all three axes hidden."""
    return dict(xaxis=_hidden_axis(x), yaxis=_hidden_axis(y), zaxis=_hidden_axis(z))


fig.update_layout(
    showlegend=False,
    scene=_scene(),
    scene2=_scene(),
    scene3=_scene(),
    width=1000,
    height=500,
)

fig.show()