In [82]:
%load_ext autoreload
%autoreload 2

from ff_energy.utils import pickle_output, read_from_pickle
import pandas as pd
import scikit_posthocs as sp
import seaborn as sns
import matplotlib.pyplot as plt
from ff_energy.plot import plot_energy_MSE, plot_ff_fit

from ff_energy.structure import atom_key_pairs
from ff_energy.potential import LJ, akp_indx
from ff_energy.ff import FF
from ff_energy.ff_fit import LJ_bound, load_ff, fit_func, fit_repeat
from ff_energy.data import pairs_data


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Data
## Ab initio data

In [3]:
# Ab initio reference datasets, one per level of theory / basis.
# read_from_pickle is advanced once with next(), so only the first object
# stored in each file is kept.
# NOTE(review): unpickling can execute arbitrary code — only load files that
# were generated locally / come from a trusted source.
# A list comprehension (not a generator) is used so that an empty pickle still
# surfaces as StopIteration, exactly like a bare next() call.
(
    pc_pbe0_dz_d4,
    pc_pbe0_dz,
    pc_hf_dz,
    pc_pbe0tz,
    pnolccsd_pvtzdf,
) = [
    next(read_from_pickle(pickle_path))
    for pickle_path in (
        "data_pc_pbe0_d4.obj.pkl",
        "data_pc_pbe0.obj.pkl",
        "data_pc_hfdz.obj.pkl",
        "data_pc_pbe0tz.obj.pkl",
        "data_pc_pno-lccsd-pvtzdf.obj.pkl",
    )
]

In [12]:
# Inspect the monomer-energy table of the pc_pbe0_dz dataset (the recorded
# output shows columns M_ENERGY, KEY, n_monomers indexed by configuration key).
# pandas truncates the display; use .head() for a shorter preview.
pc_pbe0_dz.monomer_df

Unnamed: 0,M_ENERGY,KEY,n_monomers
test467,-1527.154217,test467,20
test263,-1527.156016,test263,20
test35,-1527.143158,test35,20
test223,-1527.147556,test223,20
test254,-1527.135568,test254,20
...,...,...,...
test100,-1527.141372,test100,20
test304,-1527.150985,test304,20
test327,-1527.154084,test327,20
test185,-1527.144999,test185,20


# FFs

An FF object should contain all the data needed to fit a potential.
Building all the JAX arrays is slow, so each FF should be created once and then reused for subsequent fits.

Some fits also use Coulomb distances/charges; these differ for each electrostatic model.

In [113]:
# Build the FF once (creating the JAX arrays is slow) for the water-cluster
# structures, from the data_pc_pbe0 pickle, with a harmonic internal term and
# the "ECOL" electrostatics model. load_ff returns a pair: the FF and its data.
test_ff, data = load_ff(
    "test",
    "water_cluster",
    pk="data_pc_pbe0.obj.pkl",
    intern="harmonic",
    elec="ECOL",
)

Pickled structures/dists found:  ['water_cluster.pkl.pkl', 'water_dimer.pkl.pkl', 'water_dimer.pkl', 'structures.pkl', 'dists.pkl', 'water_cluster.pkl']
Loading pickled distances/structure: water_cluster
loading data from pickle: data_pc_pbe0.obj.pkl


In [114]:
# Switch the electrostatics model from the "ECOL" used at load time to "ELEC",
# then recompute the fit targets. Presumably the targets depend on `elec`, so
# set_targets() must be called after the assignment — keep this order.
# NOTE(review): this mutates test_ff in place; re-running the load cell resets
# it to "ECOL".
test_ff.elec = "ELEC"
test_ff.set_targets()
test_ff

FF: LJ water_cluster ELEC harmonicjax_coloumb: False

In [115]:
# 4-parameter LJ fit against the current targets. No initial guess is passed
# (None), and no bounds either, so fit_func picks its own starting point and
# uses its default bounds.
# NOTE(review): the recorded initial guess looks random; if it is, seed the
# RNG fit_func uses so this fit is reproducible. In the recorded run every
# parameter finished on a bound ([0.25, 0.25, 0.5, 0.0001]), which suggests
# the bounds or model need revisiting.
# The recorded output shows the optimizer result twice (printed inside
# fit_func, then echoed as the cell value). Assigning it keeps the result for
# later cells and avoids the duplicate display.
lj_fit = fit_func(test_ff, None)

Optimizing LJ parameters...
function: LJ
bounds: ((0.25, 2.5), (0.25, 2.5), (0.0001, 0.5), (0.0001, 0.5))
maxfev: 10000
initial guess: [1.6170418400175541, 0.9351073612728495, 0.47240351913685147, 0.03593968041434267]


  res = minimize(


final_loss_fn:  129.8406219482422
       message: Optimization terminated successfully.
       success: True
        status: 0
           fun: 129.8406219482422
             x: [ 2.500e-01  2.500e-01  5.000e-01  1.000e-04]
           nit: 50
          nfev: 79
 final_simplex: (array([[ 2.500e-01,  2.500e-01,  5.000e-01,  1.000e-04],
                       [ 2.500e-01,  2.500e-01,  5.000e-01,  1.000e-04],
                       ...,
                       [ 2.500e-01,  2.500e-01,  5.000e-01,  1.000e-04],
                       [ 2.500e-01,  2.500e-01,  5.000e-01,  1.000e-04]]), array([ 1.298e+02,  1.298e+02,  1.298e+02,  1.298e+02,
                        1.298e+02]))


       message: Optimization terminated successfully.
       success: True
        status: 0
           fun: 129.8406219482422
             x: [ 2.500e-01  2.500e-01  5.000e-01  1.000e-04]
           nit: 50
          nfev: 79
 final_simplex: (array([[ 2.500e-01,  2.500e-01,  5.000e-01,  1.000e-04],
                       [ 2.500e-01,  2.500e-01,  5.000e-01,  1.000e-04],
                       ...,
                       [ 2.500e-01,  2.500e-01,  5.000e-01,  1.000e-04],
                       [ 2.500e-01,  2.500e-01,  5.000e-01,  1.000e-04]]), array([ 1.298e+02,  1.298e+02,  1.298e+02,  1.298e+02,
                        1.298e+02]))

In [87]:
# Load precomputed JAX Coulomb data (presumably for the kMDCM charge model,
# judging by the file name). Only the first pickled object is kept. It is
# attached to the FF by init_jax_col in the next cell.
# NOTE(review): unpickling can execute arbitrary code — only load trusted files.
jax_data_kmdcm = next(read_from_pickle("jax_data_kmdcm.pkl"))

In [88]:
# Attach the JAX Coulomb data to the FF (mutates test_ff in place). In the
# recorded outputs, the FF repr changes from "jax_coloumb: False" to
# "jax_coloumb: True" after this call.
test_ff.init_jax_col(jax_data_kmdcm)

In [89]:
# NOTE(review): dead cell containing only commented-out code; delete it or restore the call.
# test_ff.cluster_labels

In [90]:
# Check the FF state: the repr should now report jax_coloumb: True.
test_ff

FF: LJ water_cluster ELEC harmonicjax_coloumb: True

In [102]:
# Joint LJ + electrostatics fit with the 'lj_ecol' loss, using the JAX Coulomb
# data attached to test_ff. Parameter layout (from the bounds): two sigma
# values, two epsilon values, then a fifth parameter bounded to (1.0, 1.5).
# NOTE(review): the fifth parameter's meaning is defined by the 'lj_ecol'
# loss inside fit_func — document it here once confirmed.
# NOTE(review): this cell ran as In[102], before the FF was rebuilt in In[113],
# so the recorded result may reflect different kernel state; re-run top to
# bottom to confirm. In the recorded run the last three parameters converged
# onto their upper bounds (0.5, 0.5, 1.5), so the bounds may be too tight.
x0 = [2.1801901325157633, 0.9689078330468776, 0.3951433427330229, 0.26583564356898853, 1.34]
sig_bound = (0.25, 2.5)
ep_bound = (0.0001, 0.5)
lj_ecol_extra_bound = (1.0, 1.5)
# Plain tuple of (low, high) pairs; the original extra parentheses did not
# create nested tuples and only obscured that.
bounds = (sig_bound, sig_bound, ep_bound, ep_bound, lj_ecol_extra_bound)

# The recorded output shows the optimizer result twice (printed inside
# fit_func, then echoed as the cell value). Assigning it keeps the result for
# later cells and avoids the duplicate display.
lj_ecol_fit = fit_func(test_ff, x0, bounds=bounds, loss='lj_ecol')

Optimizing LJ parameters...
function: LJ
bounds: ((0.25, 2.5), (0.25, 2.5), (0.0001, 0.5), (0.0001, 0.5), (1.0, 1.5))
maxfev: 10000
initial guess: [2.1801901325157633, 0.9689078330468776, 0.3951433427330229, 0.26583564356898853, 1.34]


  res = minimize(


final_loss_fn:  783.3369140625
       message: Optimization terminated successfully.
       success: True
        status: 0
           fun: 783.3369140625
             x: [ 1.116e+00  7.808e-01  5.000e-01  5.000e-01  1.500e+00]
           nit: 402
          nfev: 683
 final_simplex: (array([[ 1.116e+00,  7.808e-01, ...,  5.000e-01,
                         1.500e+00],
                       [ 1.116e+00,  7.808e-01, ...,  5.000e-01,
                         1.500e+00],
                       ...,
                       [ 1.116e+00,  7.808e-01, ...,  5.000e-01,
                         1.500e+00],
                       [ 1.116e+00,  7.808e-01, ...,  5.000e-01,
                         1.500e+00]]), array([ 7.833e+02,  7.833e+02,  7.833e+02,  7.833e+02,
                        7.833e+02,  7.833e+02]))


       message: Optimization terminated successfully.
       success: True
        status: 0
           fun: 783.3369140625
             x: [ 1.116e+00  7.808e-01  5.000e-01  5.000e-01  1.500e+00]
           nit: 402
          nfev: 683
 final_simplex: (array([[ 1.116e+00,  7.808e-01, ...,  5.000e-01,
                         1.500e+00],
                       [ 1.116e+00,  7.808e-01, ...,  5.000e-01,
                         1.500e+00],
                       ...,
                       [ 1.116e+00,  7.808e-01, ...,  5.000e-01,
                         1.500e+00],
                       [ 1.116e+00,  7.808e-01, ...,  5.000e-01,
                         1.500e+00]]), array([ 7.833e+02,  7.833e+02,  7.833e+02,  7.833e+02,
                        7.833e+02,  7.833e+02]))

In [107]:
# Show the intE column of the FF's data (presumably the reference interaction
# energies) as a one-column DataFrame, for comparison with test_ff.targets.
test_ff.data.loc[:, ["intE"]]

Unnamed: 0,intE
test0,-88.423171
test1,-53.572861
test2,-74.628101
test3,-59.373728
test4,-74.612477
...,...
test495,-76.747446
test496,-58.484874
test497,-58.307079
test498,-64.986095


In [105]:
# NOTE(review): commented-out call left over from experimentation; delete it or restore it.
# test_ff.eval_lj_coulomb(x0)
# Fit targets as a JAX Array. In the recorded output they differ from
# test_ff.data["intE"] (e.g. -87.73 vs -88.42 for the first entry), presumably
# because set_targets() adjusts intE for the selected electrostatics model —
# confirm in FF.set_targets.
test_ff.targets

Array([ -87.73192 ,  -53.635742,  -74.65536 ,  -59.582268,  -74.86736 ,
        -63.04028 ,  -60.319916,  -53.79419 ,  -61.938053,  -63.166183,
        -52.11618 ,  -51.378582,  -67.03404 ,  -68.89161 ,  -75.26452 ,
        -66.317085,  -64.54564 ,  -60.093925,  -76.63827 ,  -53.382748,
        -62.73625 ,  -59.81628 ,  -43.87575 ,  -38.67543 ,  -57.742355,
        -72.39073 ,  -52.655025,  -36.681713,  -26.06485 ,  -41.864223,
        -56.548836,  -52.899216,  -55.62058 ,  -75.0999  ,  -58.183712,
        -71.26341 ,  -56.36466 ,  -65.41849 ,  -34.3824  ,  -48.865692,
        -37.059254,  -57.55826 ,  -42.26478 ,  -66.15077 ,  -35.782288,
        -43.95934 ,  -42.953323,  -36.88685 ,  -63.937653,  -61.80994 ,
        -64.10442 ,  -74.30323 ,  -52.068268,  -56.535233,  -62.20413 ,
        -54.383987,  -65.781815,  -48.28325 ,  -65.54897 ,  -54.550266,
        -92.84306 ,  -52.51835 ,  -59.42012 ,  -32.81041 ,  -37.48269 ,
       -101.47126 ,  -68.13017 ,  -52.969017,  -74.65725 ,  -45.