# Utilize pyCHARMM *user energy* to implement a machine learned quantum mechanical potential to calculate energy and forces

## Our objectives with this tutorial will be to:
> 1. Illustrate how to use the pyCHARMM implementation of CHARMM User Energy functionality
> 2. Integrate the torchANI ML-QM potentials to provide direct accesss to QM quality energy and forces in pyCHARMM
> 3. Compare the torchANI-2x energy function for butane to that in CGENFF

## import needed python/pyCHARMM functionality

In [None]:
import os
import sys
import numpy as np
import pandas as pd

import pycharmm
import pycharmm.generate as gen
import pycharmm.ic as ic
import pycharmm.coor as coor
import pycharmm.energy as energy
import pycharmm.dynamics as dyn
import pycharmm.nbonds as nbonds
import pycharmm.minimize as minimize
import pycharmm.crystal as crystal
import pycharmm.select as select
import pycharmm.image as image
import pycharmm.psf as psf
import pycharmm.param as param
import pycharmm.read as read
import pycharmm.write as write
import pycharmm.settings as settings
import pycharmm.cons_harm as cons_harm
import pycharmm.cons_fix as cons_fix
import pycharmm.shake as shake
import pycharmm.scalar as scalar

# include torch and torchani and set-up model
# To begin with, let's first import the modules we will use:
import torch
import torchani

## Define some functions to interface with the torchANI QM potential models

### **Note** torchANI models only support some atoms, depending on model, see: https://github.com/aiqm/torchani

In [None]:
def SetupTorchANI():
    ###############################################################################
    # Let's now manually specify the device we want TorchANI to run:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using {} for ANI energy and force calculations'.format(device))

    ###############################################################################
    # Let's now load the built-in ANI-2x models. The builtinANI2x model is an ensemble of 8 networks
    # that was trained on the ANI-2x dataset. The target level of theory is wB97X/6-31G(d). It predicts
    # energies on HCNOFSCl elements exclusively it shouldn’t be used with other atom types.
    #
    # The ``periodic_table_index`` arguments tells TorchANI to use element index
    # in periodic table to index species. If not specified, you need to use
    # 0, 1, 2, 3, ... to index species
    model = torchani.models.ANI2x(periodic_table_index=True).to(device)
    return (device,model)

def iupac_2_number(iupac):
    from mendeleev import element
    allowed = ['H','C','N','O','F','S','CL']
    number = []
    for i in iupac:
         if i[0:1] == 'CL': number.append(element(i[0:1]).atomic_number)
         elif i[0] in allowed: number.append(element(i[0]).atomic_number)
         else:
             print('Element not supported by ANI2X models: atom {}'.format(i))
             exit()
    return number

def get_EnergyDeriv(coor,species):
    # Return the enrgy and derivative of the energy from ANI model
    energy = model((species,coor)).energies
    deriv = torch.autograd.grad(energy.sum(), coor)[0]
    return (energy.item()*627.5,deriv.squeeze()*627.5) # hartree/A -> kcal/mol/A

def ANI_EDX(natoms,
         x_pos, y_pos, z_pos,
         dx, dy, dz):
    coor = []
    for i in range(natoms): coor.append([x_pos[i],y_pos[i],z_pos[i]])
    ener,deriv = get_EnergyDeriv(torch.tensor([np.asarray(coor,dtype=np.float32)],
                                  requires_grad=True, device=device),species)
    for i in range(natoms):
        dx[i] = deriv[i,0] # Note, dx is the gradient of the potential
        dy[i] = deriv[i,1]
        dz[i] = deriv[i,2]
    return ener


## Set-up topology/parameter files, generate butane molecule

In [None]:
# template for f/y restraints
Fcons = '1 c1 1 c2 1 c3 1 c4'
read.rtf('toppar/top_all36_cgenff.rtf')
bl =settings.set_bomb_level(-2)
wl =settings.set_warn_level(-2)
read.prm('toppar/par_all36_cgenff.prm')
settings.set_bomb_level(bl)
settings.set_warn_level(wl)
pycharmm.lingo.charmm_script('bomlev 0')
read.sequence_string('BUTA')  # butane
gen.new_segment(seg_name='BUTA',
                setup_ic=True)
ic.prm_fill(replace_all=True)
ic.seed(1,'C1',1,'C2',1,'C3')  
ic.build()

device,model = SetupTorchANI()
# Get data array for atomic numbers
species = torch.tensor([iupac_2_number(pycharmm.psf.get_atype())],device=device)
print(species)

pycharmm.NonBondedScript(**{'cutnb': 16,
                  'ctofnb': 14,
                  'ctonnb': 12,
                  'atom': True,
                  'vatom': True,
                  'eps': 1,
                  'fswitch': True,
                  'vfswitch': True,
                  'cdie': True})


pycharmm.charmm_script('skipe incl all excl user cdih')

e_func = pycharmm.EnergyFunc(ANI_EDX)

# set up phi/psi grid to apply restraints and
# compute energy
F = np.linspace(-180,180,36)
fmap = {'F':F,
        'eC':[],
        'eQ':[]}

## Loop over dihedral space to construct $\phi$-dependent energy surface

In [None]:
for iphi,f in enumerate(F):
    print(f)
    # turn off noise
    settings.set_verbosity(0)
    settings.set_warn_level(-5)
    # Need to use stream here because no api for cons dihe
    cons = 'cons dihe {} force {} min {:4.2f}'.format(Fcons,500,f)
    pycharmm.lingo.charmm_script(cons)
    minimize.run_abnr(**{'nstep': 1000,
                         'tolenr': 1e-6,
                         'tolgrd': 1e-3})
    pycharmm.lingo.charmm_script('cons cldh')
    pycharmm.lingo.charmm_script('skipe none')
    settings.set_verbosity(5)
    energy.show()
    settings.set_verbosity(0)
    fmap['eC'].append(pycharmm.lingo.get_energy_value('ENER')-pycharmm.lingo.get_energy_value('USER'))
    fmap['eQ'].append(pycharmm.lingo.get_energy_value('USER'))
    pycharmm.lingo.charmm_script('skipe all excl user cdih')

fmap['eC'] = np.asarray(fmap['eC'])
fmap['eC'] = (fmap['eC']-np.min(fmap['eC']))
fmap['eQ'] = np.asarray(fmap['eQ'])
fmap['eQ'] = (fmap['eQ']-np.min(fmap['eQ']))

## Finally plot results

In [None]:
# Plot the results
import matplotlib.pyplot as plt
fig,ax = plt.subplots()
ax.plot(fmap['F'],fmap['eQ'])
ax.plot(fmap['F'],fmap['eC'])
plt.show()