# Toy problem: Learning LJ potential with three atoms

This notebook showcases the usage of PiNN with a toy problem: learning a Lennard-Jones
potential from a hand-generated dataset.  
It serves as a basic test and a demonstration of the PiNN workflow.

In [None]:
%matplotlib inline

In [None]:
import os 
import numpy as np
import matplotlib.pyplot as plt
from ase import Atoms
from ase.calculators.lj import LennardJones
# An empty CUDA_VISIBLE_DEVICES exposes no GPU to CUDA applications, so the toy
# model presumably runs on the CPU. It is set here, before TensorFlow is
# imported in a later cell.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

## Reference data

In [None]:
def three_body_sample(atoms, a, r):
    """Place three atoms for one point (a, r) of the potential energy surface.

    Atom 1 sits at the origin, atom 2 is fixed at (0, 2, 0), and atom 3 lies
    in the y-z plane at distance ``r`` from the origin, rotated by ``a``
    degrees away from the +y axis.

    Args:
        atoms: object exposing ``set_positions``; it is modified in place.
        a: angle in degrees.
        r: distance of the third atom from the origin.

    Returns:
        The same ``atoms`` object, after its positions were updated.
    """
    theta = a * np.pi / 180  # degrees -> radians
    third_atom = [0, r*np.cos(theta), r*np.sin(theta)]
    atoms.set_positions([[0, 0, 0],
                         [0, 2, 0],
                         third_atom])
    return atoms

In [None]:
# Three hydrogen atoms evaluated with ASE's Lennard-Jones calculator
# (constructed without arguments). The same Atoms object is reused and
# repositioned for every PES point below.
atoms = Atoms('H3', calculator=LennardJones())

# PES grid: `a` is the angle in degrees, `r` the distance of the third atom
# from the first one (see three_body_sample).
na, nr = 50, 50
arange = np.linspace(30,180,na)
rrange = np.linspace(1,3,nr)

# Truth
# indexing='ij' gives agrid/rgrid the shape (na, nr) with
# agrid[i, j] == arange[i] and rgrid[i, j] == rrange[j], i.e. the same layout
# as egrid[i, j] below. The default 'xy' indexing returns (nr, na) grids, which
# only went unnoticed because na == nr and made pcolormesh draw the energies
# transposed with respect to the axes.
agrid, rgrid = np.meshgrid(arange, rrange, indexing='ij')
egrid = np.zeros([na, nr])
for i in range(na):
    for j in range(nr):
        atoms = three_body_sample(atoms, arange[i], rrange[j])
        egrid[i,j] = atoms.get_potential_energy()

# Samples
# Grid points are drawn (with replacement) from a seeded generator so that the
# training set is the same after every "Restart & Run All".
sample_rng = np.random.RandomState(0)
nsample = 50
asample, rsample = [], []
distsample = []
data = {'e_data':[], 'f_data':[], 'atoms':[], 'coord':[]}
for i in range(nsample):
    a, r = sample_rng.choice(arange), sample_rng.choice(rrange)
    atoms = three_body_sample(atoms, a, r)
    # Keep only the non-zero entries of the distance matrix, i.e. drop the
    # diagonal; each pair therefore appears twice.
    dist = atoms.get_all_distances()
    dist = dist[np.nonzero(dist)]
    data['e_data'].append(atoms.get_potential_energy())
    data['f_data'].append(atoms.get_forces())
    data['coord'].append(atoms.get_positions())
    data['atoms'].append(atoms.numbers)
    asample.append(a)
    rsample.append(r)
    distsample.append(dist)

In [None]:
# Reference PES as a colour map, with the randomly drawn training samples
# overlaid as red crosses.
truth_mesh = plt.pcolormesh(agrid, rgrid, egrid)
plt.plot(asample, rsample, 'rx')
plt.colorbar(truth_mesh)

## Dataset from numpy arrays

In [None]:
# Stack each list of per-sample records into a single numpy array per key.
data = {key: np.array(values) for key, values in data.items()}


def dataset():
    """Build the dataset splits from the in-memory numpy arrays.

    `load_numpy_dataset` is only looked up when this function is called, so
    the import in a later cell is sufficient as long as it runs before training.
    """
    return load_numpy_dataset(data)


def train():
    """Input function for training: shuffled, repeated, batches of 50."""
    # Presumably a tf.data.Dataset -- TODO confirm against load_numpy_dataset.
    train_split = dataset()['train']
    return train_split.shuffle(100).repeat().batch(50)


def test():
    """Input function for evaluation: repeated, batches of 10."""
    test_split = dataset()['test']
    return test_split.repeat().batch(10)

## Training

In [None]:
import tensorflow as tf
from pinn.models import potential_model
from pinn.networks import pinn_network
from pinn.datasets.numpy import load_numpy_dataset
from pinn.calculator import PiNN_calc

## Model specification

In [None]:
# Settings of the network selected by 'network' below.
# NOTE(review): 'rc' is presumably the interaction cutoff radius and
# 'atom_types': [1] presumably lists the atomic numbers present (hydrogen
# only) -- confirm against the PiNN documentation.
network_params = {
    'pre_level': 0,
    'ii_nodes': [8,8],
    'pi_nodes': [8,8],
    'pp_nodes': [8,8],
    'en_nodes': [8,8],
    'depth': 4,
    'rc': 3.0,
    'atom_types': [1],
    'atomic_dress': {},
}

# Training settings. The validation cells multiply the reference energies by
# 10 before comparing them with predictions, matching 'en_scale' here.
training_params = {
    'en_scale': 10,
    'train_force': True,
    'force_ratio': 1,
    'learning_rate': 3e-4,
    'regularization': 'clip',
}

params = {
    'model_dir': '/tmp/toy_models/LJ_three_body',
    'network': 'pinn_network',
    'netparam': network_params,
    'train': training_params,
}
model = potential_model(params)

In [None]:
#%rm -r /tmp/toy_models/LJ_three_body/ # To trash the model
# Train with the `train` input function and evaluate with `test`, using 10
# evaluation batches per evaluation.
# max_steps is documented as an integer number of steps, so the limit is given
# as 2000 rather than the float 2e3.
# NOTE(review): an existing checkpoint in the model directory is presumably
# picked up again, hence the optional clean-up line above.
train_spec = tf.estimator.TrainSpec(input_fn=train, max_steps=2000)
eval_spec = tf.estimator.EvalSpec(input_fn=test, steps=10)
tf.estimator.train_and_evaluate(model, train_spec, eval_spec)

## Validate the results
### PES analysis

In [None]:
# Same three-atom system as the reference, now evaluated with the trained
# model through the PiNN calculator.
atoms = Atoms('H3', calculator=PiNN_calc(model))

# epred[i, j] holds the prediction for angle arange[i] and distance rrange[j].
epred = np.zeros([na, nr])
for i, a in enumerate(arange):
    for j, r in enumerate(rrange):
        atoms = three_body_sample(atoms, a, r)
        epred[i,j] = atoms.get_potential_energy()

In [None]:
# First figure: the PES predicted by the network.
pred_mesh = plt.pcolormesh(agrid, rgrid, epred)
plt.colorbar(pred_mesh)
plt.title('NN predicted PES')

# Second figure: reference minus prediction, with the training samples
# overlaid. The reference is multiplied by 10, the same value as 'en_scale'
# in the model parameters.
plt.figure()
pes_error = 10*egrid - epred
error_mesh = plt.pcolormesh(agrid, rgrid, pes_error)
plt.plot(asample, rsample, 'rx')
plt.title('NN Prediction error and sampled points')
plt.colorbar(error_mesh)

### Pairwise potential analysis

In [None]:
# Two-atom systems: atoms1 is evaluated with the trained model, atoms2 with
# the Lennard-Jones reference calculator.
atoms1 = Atoms('H2', calculator=PiNN_calc(model))
atoms2 = Atoms('H2', calculator=LennardJones())

# Pair distances scanned in this analysis.
nr2 = 50
rrange2 = np.linspace(1,1.5,nr2)
epred = np.zeros(nr2)
etrue = np.zeros(nr2)

# Iterate over the pairwise grid (nr2 / rrange2). The three-body grid
# (nr / rrange) spans 1..3, so using it here would evaluate energies at
# distances that do not match the rrange2 axis they are plotted against.
for i in range(nr2):
    pos = [[0, 0, 0],
           [rrange2[i], 0, 0]]
    atoms1.set_positions(pos)
    atoms2.set_positions(pos)
    epred[i] = atoms1.get_potential_energy()
    etrue[i] = atoms2.get_potential_energy()

In [None]:
# Two stacked panels with a 3:1 height ratio: energy curves on top, histogram
# of the sampled interatomic distances below.
panel_heights = {'height_ratios':[3, 1]}
f, (ax1, ax2) = plt.subplots(2, 1, gridspec_kw=panel_heights)

# Prediction versus reference; the reference is multiplied by 10, the same
# value as 'en_scale' in the model parameters.
ax1.plot(rrange2, epred)
ax1.plot(rrange2, etrue*10, '--')
ax1.legend(['Prediction', 'Truth'], loc=4)

# Distances seen in the training samples, restricted to the scanned interval.
sampled_distances = np.concatenate(distsample, 0)
_ = ax2.hist(sampled_distances, 10, range=(1,1.5))