In [2]:
################## IMPORT ######################
import json
from datetime import datetime
from functools import partial, wraps

import fire
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit, random, value_and_grad, vmap
# from jax.experimental import optimizers
from jax.example_libraries import optimizers
from jax_md import space
from shadow.plot import *
from sklearn.metrics import r2_score
from psystems.nsprings import (chain, edge_order, get_connections,
                               get_fully_connected_senders_and_receivers,
                               get_fully_edge_order)
# from statistics import mode
# from sympy import LM
# from torch import batch_norm_gather_stats_with_counts
import sys
MAINPATH = ".."  # nopep8
sys.path.append(MAINPATH)  # nopep8
import jraph
import src
from jax.config import config
# from src import fgn, lnn
from src.graph import *
# from src.lnn import acceleration, accelerationFull, accelerationTV
from src.md import *
from src.models import MSE, initialize_mlp, GaussianNLL, initialize_mlp_gamma, forward_pass_gamma
from src.nve import NVEStates, nve, BrownianStates
from src.utils import *

# config.update("jax_enable_x64", True)
# config.update("jax_debug_nans", True)
# jax.config.update('jax_platform_name', 'gpu')

def namestr(obj, namespace):
    return [name for name in namespace if namespace[name] is obj]


def pprint(*args, namespace=globals()):
    for arg in args:
        print(f"{namestr(arg, namespace)[0]}: {arg}")

f32 = jnp.float32
f64 = jnp.float64

In [3]:
N = 5  # number of particles
dim = 2  # dimensions
runs = 1
kT = 1 #1.380649e-23*T  # boltzmann constant*temperature
# spring_constant = 1.0
# length_constant = 1.0
# nconfig=100
seed=42
dt = 1.0e-3 # time step*stride 
lr=1e-4
batch_size=20
epochs = 10000
# node_type = jnp.array([0,0,0,0,0])
masses = jnp.ones(N)
species = jnp.zeros(N, dtype=int).reshape(-1,1)
# gamma = jnp.ones(jnp.unique(species).shape)  # damping constant

rname=True
withdata = None

print("Configs: ")
pprint(N, epochs, seed, rname, dt, lr, batch_size, namespace=locals())

randfilename = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")

PSYS = f"a-{N}-non-linear-Spring-data-brownian_EM"
TAG = f"1NN"
out_dir = f"../results"

def _filename(name, tag=TAG):
    rstring = randfilename if (rname and (tag != "data")) else (
        "0" if (tag == "data") or (withdata == None) else f"0_{withdata}")
    filename_prefix = f"{out_dir}/{PSYS}-{tag}/{rstring}/"
    file = f"{filename_prefix}/{name}"
    os.makedirs(os.path.dirname(file), exist_ok=True)
    filename = f"{filename_prefix}/{name}".replace("//", "/")
    print("===", filename, "===")
    return filename

def displacement(a, b):
    return a - b

def shift(R, dR):
    return R+dR

def OUT(f):
    @wraps(f)
    def func(file, *args, tag=TAG, **kwargs):
        return f(_filename(file, tag=tag), *args, **kwargs)
    return func

loadmodel = OUT(src.models.loadmodel)
savemodel = OUT(src.models.savemodel)

loadfile = OUT(src.io.loadfile)
savefile = OUT(src.io.savefile)
save_ovito = OUT(src.io.save_ovito)


Configs: 
N: 5
epochs: 10000
seed: 42
rname: True
dt: 0.001
lr: 0.0001
batch_size: 20


In [4]:
################################################
################## CONFIG ######################
################################################
np.random.seed(seed)
key = random.PRNGKey(seed)

try:
    dataset_states = loadfile(f"model_states_brownian.pkl", tag="data")[0]
except:
    raise Exception("Generate dataset first.")

model_states = dataset_states[0]

print(f"Total number of data points: {len(dataset_states)}x{model_states.position.shape[0]}")

Rs = States_Brow().fromlist(dataset_states).get_array()

Rs_in = Rs[:,:99,:,:]
Rs_out = Rs[:,1:100,:,:]


=== ../results/a-5-non-linear-Spring-data-brownian_EM-data/0/model_states_brownian.pkl ===
Total number of data points: 100x100


In [5]:
################################################
################### ML Model ###################
################################################
# print("Creating Chain")
x, _, senders, receivers = chain(N)

hidden = 16
nhidden = 2

def get_layers(in_, out_):
    return [in_] + [hidden]*nhidden + [out_]

def mlp(in_, out_, key, **kwargs):
    return initialize_mlp(get_layers(in_, out_), key, **kwargs)

params = {"F_pos": mlp(N*dim, N*dim, key)}

def acceleration_node(x, params, **kwargs):
    n,dim = x.shape
    inp = x.flatten() #jnp.hstack([x.flatten(),v.flatten()])
    out = forward_pass(params, inp)
    return out.reshape(-1,dim)

def _force_fn():    
    def apply(R, params):
        return acceleration_node(R, params)
    return apply

apply_fn = _force_fn()

def next_step_fn_model(x, params): return apply_fn(x, params['F_pos'])
v_next_step_fn_model = vmap(next_step_fn_model, in_axes=(0, None))
v_v_next_step_fn_model = vmap(v_next_step_fn_model, in_axes=(0, None))

In [6]:
next_step_fn_model(x, params)

Array([[-0.3587485 , -5.5219827 ],
       [-0.16316819, -2.8957014 ],
       [ 2.7990637 , -0.8542066 ],
       [ 0.00590687, -0.661735  ],
       [-1.1160023 , -2.2723567 ]], dtype=float32)

In [7]:
@jit
def loss_fn(params, Rs, Rs_1_ac):
    Rs_1_pred = v_v_next_step_fn_model(Rs, params)
    return MSE(Rs_1_pred, Rs_1_ac)

def gloss(*args):
    return value_and_grad(loss_fn)(*args)

def update(i, opt_state, params, loss__, *data):
    """ Compute the gradient for a batch and update the parameters """
    value, grads_ = gloss(params, *data)
    opt_state = opt_update(i, grads_, opt_state)
    return opt_state, get_params(opt_state), value

@jit
def step(i, ps, *args):
    return update(i, *ps, *args)

opt_init, opt_update_, get_params = optimizers.adam(lr)

@jit
def opt_update(i, grads_, opt_state):
    grads_ = jax.tree_map(jnp.nan_to_num, grads_)
    grads_ = jax.tree_map(partial(jnp.clip, a_min=-1000.0, a_max=1000.0), grads_)
    return opt_update_(i, grads_, opt_state)

def batching(*args, size=None):
    L = len(args[0])
    if size != None:
        nbatches1 = int((L - 0.5) // size) + 1
        nbatches2 = max(1, nbatches1 - 1)
        size1 = int(L/nbatches1)
        size2 = int(L/nbatches2)
        if size1*nbatches1 > size2*nbatches2:
            size = size1
            nbatches = nbatches1
        else:
            size = size2
            nbatches = nbatches2
    else:
        nbatches = 1
        size = L
    
    newargs = []
    for arg in args:
        newargs += [jnp.array([arg[i*size:(i+1)*size]
                                for i in range(nbatches)])]
    return newargs

bRs_in, bRs_out = batching(Rs_in, Rs_out, size=min(len(Rs_in), batch_size))

print(f"training ...")

opt_state = opt_init(params)
epoch = 0
optimizer_step = -1
larray = []
ltarray = []
last_loss = 1000

for epoch in range(epochs):
    l = 0.0
    count = 0
    for data in zip(bRs_in, bRs_out):
        optimizer_step += 1
        opt_state, params, l_ = step(
            optimizer_step, (opt_state, params, 0), *data)
        l += l_
        count+=1
    # print("epoch,countttttt: ", epoch,count)
    # opt_state, params, l_ = step(optimizer_step, (opt_state, params, 0), Rs, Vs, Fs)
    l = l/count
    larray += [l]
    # ltarray += [loss_fn(params, bRs_in, bVs_in, bRs_out)]
    if epoch % 10 == 0:
        print(f"Epoch: {epoch}/{epochs} Loss (MSE):  train={larray[-1]}")#, test={ltarray[-1]}")
    if epoch % 100 == 0:
        metadata = {
            "savedat": epoch,
            # "mpass": mpass,
            }
        savefile(f"fgnode_trained_model.dil",
                    params, metadata=metadata)
        # savefile(f"loss_array.dil", (larray, ltarray), metadata=metadata)
        savefile(f"loss_array.dil", larray, metadata=metadata)
        if last_loss > larray[-1]:
            last_loss = larray[-1]
            savefile(f"fgnode_trained_model_low.dil",
                        params, metadata=metadata)
        fig, axs = panel(1, 1)
        # plt.semilogy(larray, label="Training")
        plt.plot(larray, label="Training")
        # plt.semilogy(ltarray, label="Test")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(_filename(f"training_loss.png"))

fig, axs = panel(1, 1)
# plt.semilogy(larray, label="Training")
plt.plot(larray, label="Training")
# plt.semilogy(ltarray, label="Test")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.savefig(_filename(f"training_loss.png"))

params = get_params(opt_state)
savefile(f"fgnode_trained_model.dil", params, metadata=metadata)
# savefile(f"loss_array.dil", (larray, ltarray), metadata=metadata)

if last_loss > larray[-1]:
    last_loss = larray[-1]
    savefile(f"fgnode_trained_model_low.dil", params, metadata=metadata)


training ...
Epoch: 0/10000 Loss (MSE):  train=12.50381851196289
=== ../results/a-5-non-linear-Spring-data-brownian_EM-1NN/05-24-2023_22-09-39/fgnode_trained_model.dil ===
=== ../results/a-5-non-linear-Spring-data-brownian_EM-1NN/05-24-2023_22-09-39/loss_array.dil ===
=== ../results/a-5-non-linear-Spring-data-brownian_EM-1NN/05-24-2023_22-09-39/fgnode_trained_model_low.dil ===
=== ../results/a-5-non-linear-Spring-data-brownian_EM-1NN/05-24-2023_22-09-39/training_loss.png ===
Epoch: 10/10000 Loss (MSE):  train=11.287025451660156
Epoch: 20/10000 Loss (MSE):  train=10.20813274383545
Epoch: 30/10000 Loss (MSE):  train=9.252677917480469
Epoch: 40/10000 Loss (MSE):  train=8.402363777160645
Epoch: 50/10000 Loss (MSE):  train=7.642260551452637
Epoch: 60/10000 Loss (MSE):  train=6.960115432739258
Epoch: 70/10000 Loss (MSE):  train=6.345608711242676
Epoch: 80/10000 Loss (MSE):  train=5.789948463439941
Epoch: 90/10000 Loss (MSE):  train=5.285636901855469
Epoch: 100/10000 Loss (MSE):  train=4.8262

  fig = plt.figure(figsize=figsize, **kwargs)


Epoch: 2010/10000 Loss (MSE):  train=0.1037348061800003
Epoch: 2020/10000 Loss (MSE):  train=0.10242026299238205
Epoch: 2030/10000 Loss (MSE):  train=0.1011204868555069
Epoch: 2040/10000 Loss (MSE):  train=0.09983598440885544
Epoch: 2050/10000 Loss (MSE):  train=0.09856736660003662
Epoch: 2060/10000 Loss (MSE):  train=0.09731513261795044
Epoch: 2070/10000 Loss (MSE):  train=0.09607967734336853
Epoch: 2080/10000 Loss (MSE):  train=0.09486141800880432
Epoch: 2090/10000 Loss (MSE):  train=0.0936608538031578
Epoch: 2100/10000 Loss (MSE):  train=0.09247809648513794
=== ../results/a-5-non-linear-Spring-data-brownian_EM-1NN/05-24-2023_22-09-39/fgnode_trained_model.dil ===
=== ../results/a-5-non-linear-Spring-data-brownian_EM-1NN/05-24-2023_22-09-39/loss_array.dil ===
=== ../results/a-5-non-linear-Spring-data-brownian_EM-1NN/05-24-2023_22-09-39/fgnode_trained_model_low.dil ===
=== ../results/a-5-non-linear-Spring-data-brownian_EM-1NN/05-24-2023_22-09-39/training_loss.png ===
Epoch: 2110/10000 

In [7]:
# params, _ = loadfile(f"fgnode_trained_model_low.dil", verbose=True)

In [8]:
rname=False

# PSYS = f"a-{N}-Spring-data-brownian_EM"
# TAG = f"1NN"
# out_dir = f"../results"

def _filename(name, tag=TAG):
    rstring = randfilename if (rname and (tag != "data")) else (
        "0" if (tag == "data") or (withdata == None) else f"0_{withdata}")
    filename_prefix = f"{out_dir}/{PSYS}-{tag}/{rstring}/"
    file = f"{filename_prefix}/{name}"
    os.makedirs(os.path.dirname(file), exist_ok=True)
    filename = f"{filename_prefix}/{name}".replace("//", "/")
    print("===", filename, "===")
    return filename

def OUT(f):
    @wraps(f)
    def func(file, *args, tag=TAG, **kwargs):
        return f(_filename(file, tag=tag), *args, **kwargs)
    return func

loadmodel = OUT(src.models.loadmodel)
savemodel = OUT(src.models.savemodel)

loadfile = OUT(src.io.loadfile)
savefile = OUT(src.io.savefile)
save_ovito = OUT(src.io.save_ovito)


In [10]:
params, _ = loadfile(f"fgnode_trained_model_low.dil", verbose=True)

=== ../results/a-5-Spring-data-brownian_EM-1NN/0/fgnode_trained_model_low.dil ===
Loading ../results/a-5-Spring-data-brownian_EM-1NN/0/fgnode_trained_model_low.dil


In [8]:
# %matplotlib inline
# import matplotlib.pyplot as plt

spring_constant = 1.0
length_constant = 1.0
gamma_orig = jnp.ones(jnp.unique(species).shape)
stride = 1
runs=100

def SPRING(x, stiffness=1.0, length=1.0):
    x_ = jnp.linalg.norm(x, keepdims=True)
    return 0.5*stiffness*(x_ - length)**4

def pot_energy_orig(x):
    dr = x[senders, :] - x[receivers, :]
    return vmap(partial(SPRING, stiffness=spring_constant, length=length_constant))(dr).sum()

def force_fn_orig(R, params):
    return -grad(pot_energy_orig)(R)

def get_forward_sim(params = None, force_fn = None, gamma = None, runs=10):
        @jit
        def fn(R,key):
            return predition_brow(R, params, force_fn, shift, dt, kT, masses, gamma = gamma, stride=stride, runs=runs, key=key)
        return fn


sim_orig = get_forward_sim(params=None,force_fn=force_fn_orig, gamma=gamma_orig,runs=runs)

# model
def get_forward_sim_model(params = None, next_step_fn = None, runs=10, stride=1):
        next_step_fn = lambda R: next_step_fn_model(R, params)
        @jit
        def solve_dynamics(R_init):
            step = jit(lambda i, R: next_step_fn(R))
            def f(R):
                y = jax.lax.fori_loop(0, stride, step, R)
                return y, y
            
            def func(R, i): return f(R)
            @jit
            def scan(R0):
                return jax.lax.scan(func, R0, jnp.array(range(runs)))
            
            final_state, traj = scan(R_init)
            return traj
        return solve_dynamics

sim_model = get_forward_sim_model(params = params, next_step_fn = next_step_fn_model, runs=runs)

In [9]:
def write_xyz_traj(Filepath,Name,R):
    '''Writes ovito xyz file'''
    f=open(Filepath,'w')
    f.write(str(R.shape[1])+"\n")
    f.write(Name)
    for i in range(R.shape[0]): #R.shape[0]
        for j in range(R.shape[1]):
            f.write("\n"+str(species[j])+"\t"+str(R[i,j,0])+"\t"+str(R[i,j,1])+"\t"+str(R[i,j,2]))
        f.write("\n"+str(R.shape[1]))
        f.write("\n")

In [10]:
import time
# plotthings = True
rng_key = random.PRNGKey(0)
maxtraj = 100
np.random.seed(seed)
key = random.PRNGKey(seed)

nexp = {
        "dz_actual": [],
        "dz_pred": [],
        "z_actual": [],
        "z_pred": [],
        "_gamma": [],
        "simulation_time":[],
        }

trajectories = []
for ind in range(maxtraj):
    print(f"Simulating trajectory {ind}/{maxtraj} ...")
    R, _ = chain(N)[:2]
    for rand in range(10):
        rng_key, subkey = random.split(rng_key)
        actual_traj = sim_orig(R,(ind+13)*subkey)
        rng_key, subkey = random.split(rng_key)
        
        start = time.time()
        pred_pos = sim_model(R)
        end = time.time()
        nexp["simulation_time"] += [end-start]
        
        nexp["dz_actual"] += [actual_traj.position-R]
        nexp["dz_pred"] += [pred_pos-R]

        nexp["z_actual"] += [actual_traj.position]
        nexp["z_pred"] += [pred_pos]
        
        if save_ovito:
            if ind<1 and rand<1:
                save_ovito(f"actual_{ind}_{rand}.xyz", [state for state in BrownianStates(actual_traj)], lattice="")
                write_xyz_traj(_filename(f"pred_{ind}_{rand}.xyz"),'spring_ddnn',pred_pos)
                
        # trajectories += [(actual_traj.position, pred_pos)]
        # if ind%10==0:
        #     savefile("trajectories.pkl", trajectories)


def KL_divergence(sigma0,mu0,sigma1,mu1, eps=1e-8):
    return jnp.log(sigma1/sigma0) + (jnp.square(sigma0)+jnp.square(mu0-mu1))/(2*jnp.square(sigma1)) - 0.5

def get_kld(d_actual, d_pred):
    mu0 = jnp.mean(d_actual, axis=(0,2,3))
    std0 = jnp.std(d_actual, axis=(0,2,3))
    mu1 = jnp.mean(d_pred, axis=(0,2,3))
    std1 = jnp.std(d_pred, axis=(0,2,3))
    kld = []
    for i in range(len(std0)):
        kld.append(KL_divergence(std0[i],mu0[i],std1[i],mu1[i]))
    return jnp.array(kld)

def get_std_rmse(d_actual, d_pred):
    std0 = jnp.std(d_actual, axis=(0,2,3))
    std1 = jnp.std(d_pred, axis=(0,2,3))
    return jnp.sqrt(jnp.square(std0 - std1))

def get_dist_by_var(actual, pred, zeta):
    disp = displacement(actual, pred)
    dist_matrix = jnp.sqrt(jnp.square(disp).sum(-1))
    dist_mean = jnp.mean(dist_matrix, axis=(0,2))
    dist_by_zeta = dist_mean/zeta
    return dist_by_zeta

nexp2 = {
        "kld": [],
        "std_rmse": [],
        }

nexp2["kld"] = jnp.array(get_kld(jnp.array(nexp["dz_actual"]),jnp.array(nexp["dz_pred"])))

nexp2["std_rmse"] = jnp.array(get_std_rmse(jnp.array(nexp["dz_actual"]),jnp.array(nexp["dz_pred"])))

# nexp2["dist_by_var"] = jnp.array(get_dist_by_var(jnp.array(nexp['z_actual']), jnp.array(nexp['z_pred']),1/(jnp.array(nexp['_gamma'])[0][0])))

savefile(f"error_paramete_plot_a_b_c.pkl", nexp2)
# savefile("trajectories.pkl", trajectories)
savefile(f"error_parameter.pkl", nexp)
# savefile("trajectories.pkl", trajectories)


Simulating trajectory 0/100 ...
=== ../results/a-5-non-linear-Spring-data-brownian_EM-1NN/05-24-2023_22-09-39/actual_0_0.xyz ===
Saving ovito file: ../results/a-5-non-linear-Spring-data-brownian_EM-1NN/05-24-2023_22-09-39/actual_0_0.xyz
=== ../results/a-5-non-linear-Spring-data-brownian_EM-1NN/05-24-2023_22-09-39/pred_0_0.xyz ===
Simulating trajectory 1/100 ...
Simulating trajectory 2/100 ...
Simulating trajectory 3/100 ...
Simulating trajectory 4/100 ...
Simulating trajectory 5/100 ...
Simulating trajectory 6/100 ...
Simulating trajectory 7/100 ...
Simulating trajectory 8/100 ...
Simulating trajectory 9/100 ...
Simulating trajectory 10/100 ...
Simulating trajectory 11/100 ...
Simulating trajectory 12/100 ...
Simulating trajectory 13/100 ...
Simulating trajectory 14/100 ...
Simulating trajectory 15/100 ...
Simulating trajectory 16/100 ...
Simulating trajectory 17/100 ...
Simulating trajectory 18/100 ...
Simulating trajectory 19/100 ...
Simulating trajectory 20/100 ...
Simulating trajec