## Import

In [None]:
import jax
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import src
from frozendict import frozendict
from jax import grad, jit
from jax import numpy as jnp
from jax import value_and_grad, vmap
from jax.config import config
from jax.experimental import optimizers
from jraph import GraphNetwork, GraphsTuple
from src import lnn, utils
from src.io import *
from src.lnn import accelerationFull
from src.md import displacement, prediction, shift
from src.models import forward_pass, initialize_mlp
from src.nve import NVEState, NVEStates
from shadow.plot import *
from shadow.font import *
set_font_family(family="Arial")

# JAX defaults to 32-bit floats; enable float64 for everything below.
config.update("jax_enable_x64", True)

# Register frozendict as a pytree so the frozendict attributes returned by the
# graph-network update functions can flow through jit/vmap/grad.
# Registering an already-registered type (e.g. when this cell is re-run)
# raises ValueError; ignore only that case instead of swallowing every error.
try:
    jax.tree_util.register_pytree_node(
        frozendict,
        flatten_func=lambda s: (tuple(s.values()), tuple(s.keys())),
        unflatten_func=lambda k, xs: frozendict(zip(k, xs)))
except ValueError:
    pass

In [None]:
# Draw (and display) one sample from NumPy's global RNG; this advances the
# global random state that getRV draws from.
np.random.rand()

## Config

In [None]:
n = 2  # getRV builds a chain of 2*n links
LENGTH = 2
link_size = LENGTH / n


def rot(x, t):
    """Rotate 3-D row vectors ``x`` by ``t`` degrees about the z axis.

    Args:
        x: array of shape (k, 3), or a single 3-vector.
        t: rotation angle in degrees.

    Returns:
        The rotated vectors, same shape as ``x``.
    """
    angle = t * jnp.pi / 180
    c, s = jnp.cos(angle), jnp.sin(angle)
    rotation = jnp.array([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ])
    return jnp.dot(rotation, x.T).T


def getRV(n, dim=2, t=0):
    """Generate a random chain of ``2*n`` unit-length links starting at the origin.

    Each link points roughly straight down (-y), tilted by a random angle in
    [-15, 15) degrees drawn from NumPy's global RNG. The whole chain is then
    rotated by ``t`` degrees about z.

    Args:
        n: half the number of links; the chain has ``2*n + 1`` nodes.
        dim: 2 to return only x/y coordinates, otherwise 3-D coordinates are
            returned. (This argument used to be silently overwritten with 2.)
        t: global rotation in degrees.

    Returns:
        ``(R, V)``: node positions of shape (2*n + 1, dim) and zero velocities
        of the same shape.
    """
    points = [jnp.array([[0.0, 0, 0]])]
    for _ in range(2 * n):
        step = jnp.array([[0.0, -1.0, 0.0]])
        step = rot(step, (np.random.rand() - 0.5) * 30)
        points.append(points[-1] + step)

    R = rot(jnp.vstack(points), t)

    if dim == 2:
        R = R[:, :2]
    return R, 0 * R

#     R1 = link_size * \
#         jnp.vstack([jnp.zeros(n+1), -jnp.arange(0, n+1), jnp.zeros(n+1)]).T
#     R1 = rot(R1, 45)

#     R2 = link_size * \
#         jnp.vstack([jnp.zeros(n) + 0, -n+1 + jnp.arange(0, n), jnp.zeros(n)]).T
#     R2 = np.array(rot(R2, -45))
#     R2[:, 0] += LENGTH*np.sqrt(2)

#     R = jnp.vstack((R1, R2))
#     R = rot(R, 0)

R, V = getRV(n, t=45)  # initial chain configuration (velocities are zero)

N, dim = R.shape  # number of nodes, spatial dimension

# Keep the initial state for later reference.
R0 = R
V0 = V

mass = 1.0      # per-link mass used by the analytic link model
length = 1.0    # link length used for the rod moment of inertia
_g = 10.0       # gravitational acceleration (potential is mass * _g * y)
dt = 1.0e-5     # integration time step passed to `prediction`
stride = 100    # passed to `prediction` (presumably steps between saved frames)
runs = 100      # passed to `prediction` (scaled by 10 for the rollouts)

# Chain topology: node i is linked to node i + 1.
sends = list(range(len(R) - 1))
recs = list(range(1, len(R)))

key = jax.random.PRNGKey(0)

nspecies = 1


def OHE(x):
    """One-hot encode integer species indices into ``nspecies`` classes."""
    return jax.nn.one_hot(x, nspecies)


# One species row per directed edge; the graphs below list every link in both
# directions (sends + recs). Species indices are 0-based: with nspecies == 1
# the only valid index is 0. Index 1 is out of range for one_hot and encodes
# as an all-zero vector.
species = OHE(jnp.zeros(len(sends + recs), dtype=jnp.int32))


In [None]:
# def makechainfig():
#     label = ["A", "B", "C"]
#     for ind in range(3):

#         n = 2**(3-ind)
#         R, V = getRV(n, t=45)

#         R = R + jnp.array([[4.0*ind, 0.0], ])

#         sends = [i for i in range(len(R)-1)]
#         recs = [i for i in range(1, len(R))]

#         key = jax.random.PRNGKey(0)

#         nspecies = 1

#         def OHE(x):
#             return jax.nn.one_hot(x, nspecies)

#         species = OHE(jnp.ones(len(sends+recs)))


#         for (i,j) in zip(sends, recs):
#             plt.plot([R[i, 0], R[j, 0]], [R[i, 1], R[j, 1]], color="#729fcf", lw=4)

#         # for (i,j) in zip(sends, recs):
#         #     plt.plot([R[i, 0], R[j, 0]], [R[i, 1], R[j, 1]], color="black", lw=1)

#         for r in R:
#             plt.scatter(r[0], r[1], zorder=2, color="#999999", ec="white", s=120)

#         # for r in R:
#         #     plt.scatter(r[0], r[1], zorder=2, color="white", s=20)

#         plt.scatter(R[0, 0], R[0, 1], zorder=2, color="brown", s=150, marker="+")
#         plt.text(R[0, 0]+1, R[0, 1], label[ind])

#     plt.axis("square")
#     plt.axis("off")
#     plt.savefig(f"notebooks/chains.png", dpi=1000)

    
# makechainfig()

In [None]:
def get_I(m, L):
    """Moment of inertia of a uniform rod (mass ``m``, length ``L``) about its centre."""
    return (m / 12) * L ** 2


def get_omega(r, v):
    """Angular velocity ω = (r × v) / |r|² of a point at ``r`` moving with ``v``.

    The norm is taken over the last axis, so batches of vectors are supported.
    """
    r_sq = jnp.sum(jnp.square(r), axis=-1, keepdims=True)
    return jnp.cross(r, v) / r_sq


# Length of the angular-velocity vector returned by get_omega for the current
# dimension (a single component for 2-D positions); sizes the "ke_a" MLP input.
dimω,  = get_omega(R[0], V[0]).shape

key = jax.random.PRNGKey(0)

# Fixed anchor points for the holonomic constraints. endpoint1 (the origin)
# anchors the first node in the active `hconstraints`; endpoint2 is not used by
# the active `constraints`, only by alternative constraint definitions.
if dim == 2:
    endpoint1 = np.array([[0.0, 0]])
    endpoint2 = np.array([[LENGTH*np.sqrt(2), 0]])
else:
    endpoint1 = np.array([[0.0, 0, 0.0]])
    endpoint2 = np.array([[LENGTH*np.sqrt(2), 0, 0.0]])


In [None]:
# def holo_con_wrap(r1, r2, sp, params, l=1.0):
#     return jnp.square(r1-r2).sum() - l**2

 
# v_holo_con = vmap(holo_con_wrap, in_axes=(0, 0, 0, None))

 
# def update_edge_fn2(edges, sent_attributes, received_attributes, globals_):
#     H = v_holo_con(sent_attributes["position"],
#                    received_attributes["position"], edges["species"], edges["params"])
#     return frozendict({"hcon": H})

# cnet = GraphNetwork(update_edge_fn2, None)


# def hconstraints2(R, V, params):
#     graph = GraphsTuple(
#         nodes={
#             "position": R,
#             "velocity": V,
#             "acceleration": 0*V,
#         },
#         edges={
#             "params": params,
#             "species": species,
#         },
#         senders=jnp.array(sends+recs),
#         receivers=jnp.array(recs+sends),
#         globals=None,
#         n_node=jnp.array([len(R)]),
#         n_edge=None
#     )
#     return cnet(graph).edges["hcon"]


# @jit
# def constraints2(x, v, params):
#     return jax.jacobian(lambda x: hconstraints2(x.reshape(-1, dim), v.reshape(-1, dim), params), 0)(x)


# constraints2(R.flatten(), V, params).shape

In [None]:
# def hconstraints1(R, l=jnp.array([1.0])):
#     if len(l) != len(R):
#         l = l[0]*jnp.ones((2))
#     out = jnp.square(jnp.vstack([R[:1], endpoint2]) -
#                      jnp.vstack([endpoint1, R[-1:]])).sum(axis=1) - l**2
#     return out

# @ jit
# def constraints1(x, v, params):
#     return jax.jacobian(lambda x: hconstraints1(x.reshape(-1, dim)), 0)(x)

# constraints1(R.flatten(), V, params).shape

In [None]:
# def constraints(R, V, params):
#     return jnp.vstack((constraints1(R, V, params), constraints2(R, V, params)))

# constraints(R.flatten(), V, params).shape

In [None]:
# A two-ended variant of hconstraints (also pinning the last node to
# endpoint2) used to be defined here; it was immediately shadowed by the
# definition below and therefore never used.
def hconstraints(R, l=jnp.array([1.0])):
    """Holonomic link-length residuals for a chain anchored at ``endpoint1``.

    Entry i is |R[i] - R[i-1]|² - l[i]², where R[-1] is taken to be
    ``endpoint1``. The residuals vanish when every link, including the anchor
    link, has length l.

    NOTE(review): getRV places R[0] exactly at the origin, which equals
    endpoint1, so the anchor residual is -l² and its gradient is zero there.
    Confirm that this is the intended anchoring.

    Args:
        R: node positions, shape (N, dim).
        l: link lengths; if ``len(l) != N``, ``l[0]`` is used for every link.

    Returns:
        Residual array of shape (N,).
    """
    if len(l) != len(R):
        l = l[0] * jnp.ones((len(R)))
    previous = jnp.vstack([endpoint1, R[:-1]])
    return jnp.square(R - previous).sum(axis=1) - l**2


@jit
def jac_constraints(x, v, params):
    """Hessian of the constraint residuals w.r.t. the flattened positions ``x``.

    ``v`` and ``params`` are unused; they presumably keep the signature that
    lnn.accelerationFull expects.
    """
    return jax.hessian(lambda x: hconstraints(x.reshape(-1, dim)), 0)(x)


@jit
def constraints(x, v, params):
    """Jacobian of the constraint residuals w.r.t. the flattened positions ``x``.

    ``v`` and ``params`` are unused (shared signature with jac_constraints).
    """
    return jax.jacobian(lambda x: hconstraints(x.reshape(-1, dim)), 0)(x)

## Lag

In [None]:

def lag_link_wrap(r1, r2, v1, v2, sp, params):
    """Evaluate ``lag_link`` for one edge from its two endpoint states.

    The endpoint positions and velocities are concatenated into the flat
    vectors that ``lag_link`` splits back in half.
    """
    return lag_link(jnp.hstack((r1, r2)), jnp.hstack((v1, v2)), sp, params)


def identity(x):
    """Return ``x`` unchanged (linear activation for ``forward_pass``)."""
    return x


def lag_link(R, V, sp, params):
    """Kinetic and potential energy of a single link between two nodes.

    Args:
        R: concatenated endpoint positions [r1, r2] (split in half).
        V: concatenated endpoint velocities [v1, v2].
        sp: species encoding of the edge.
        params: ``None`` for the analytic rigid-rod model, otherwise a dict of
            MLP parameters with keys "ke_a_emb", "ke_l_emb", "pe_emb", "ke_l",
            "ke_a" and "pe".

    Returns:
        Tuple ``(T, V_)`` of kinetic and potential energy of the link.
    """
    r1, r2 = jnp.split(R, 2)
    v1, v2 = jnp.split(V, 2)
    centre = (r1 + r2) / 2          # midpoint position
    centre_vel = (v1 + v2) / 2      # midpoint velocity
    rel_pos = r1 - r2
    spin = jnp.cross(rel_pos, v1 - v2)

    if params is None:
        # Analytic rod: translational KE + 0.5 * I * |r x v|^2 (the rotational
        # KE when |r1 - r2| == 1), gravitational PE along coordinate index 1.
        T = (0.5 * mass * jnp.square(centre_vel).sum()
             + 0.5 * get_I(mass, length) * jnp.square(spin).sum())
        V_ = mass * _g * centre[1]
        return T, V_

    def embed(name):
        """Linear species embedding with the MLP stored under ``name``."""
        return forward_pass(params[name], sp, activation_fn=identity)

    ang_in = jnp.hstack((embed("ke_a_emb"), spin))
    lin_in = jnp.hstack((embed("ke_l_emb"), centre_vel))
    pot_in = jnp.hstack((embed("pe_emb"), centre, rel_pos))

    # Squared MLP outputs keep the learned kinetic terms non-negative.
    T = jnp.square(forward_pass(params["ke_l"], lin_in)).sum()
    T += jnp.square(forward_pass(params["ke_a"], ang_in)).sum()
    V_ = forward_pass(params["pe"], pot_in).sum()
    return T, V_


# Batched over edges; the parameter dict is shared by all edges.
v_lag_link = vmap(lag_link_wrap, in_axes=(0, 0, 0, 0, 0, None))


In [None]:
def update_edge_fn(edges, sent_attributes, received_attributes, globals_):
    """Edge update: (T, V) of every link, computed with ``v_lag_link``."""
    energies = v_lag_link(
        sent_attributes["position"], received_attributes["position"],
        sent_attributes["velocity"], received_attributes["velocity"],
        edges["species"], edges["params"],
    )
    return frozendict({"T_V": energies})


def update_node_fn(nodes, sent_attributes, received_attributes, globals_):
    """Node update: energy terms from the aggregated incoming (T, V).

    Returns the Lagrangian T - V, the Hamiltonian T + V, and the separate
    kinetic and potential parts.
    """
    T, V = received_attributes["T_V"]
    return frozendict({"lag": T - V, "ham": T + V, "ke": T, "pe": V})

In [None]:
# Energy graph network: edges compute per-link (T, V); nodes convert the
# received sums into Lagrangian/Hamiltonian/kinetic/potential terms.
net = GraphNetwork(update_edge_fn, update_node_fn)

In [None]:
ss = 5                   # width of the species embeddings
sdim = species.shape[1]  # width of the species one-hot vectors
print(sdim, dim, dimω)

# Give each MLP its own PRNG subkey. Passing one key to every initialize_mlp
# call starts all same-shaped networks (e.g. the four sdim -> ss embeddings)
# from identical weights, assuming initialize_mlp is deterministic in its key.
subkeys = jax.random.split(key, 8)

params = {
        "ke_a_emb": initialize_mlp([sdim, ss], subkeys[0]),
        "ke_l_emb": initialize_mlp([sdim, ss], subkeys[1]),
        "pe_emb": initialize_mlp([sdim, ss], subkeys[2]),
        "d_emb": initialize_mlp([sdim, ss], subkeys[3]),

        "ke_a": initialize_mlp([ss+dimω, 10, 1], subkeys[4]),
        "ke_l": initialize_mlp([ss+dim, 10, 1], subkeys[5]),
        "pe": initialize_mlp([ss+2*dim, 10, 1], subkeys[6]),
        # "d_ang": initialize_mlp([1, 10, 1], key),
        "d_lin": initialize_mlp([ss, 10, 1], subkeys[7]),
}

In [None]:
# params, _ = loadfile("chain_model_trained_free.pkl")
# print(params.keys())

In [None]:

# Example graph for the initial state. Every link appears in both directions
# (senders = sends + recs, receivers = recs + sends), and each directed edge
# carries the model parameters and its species encoding. The functions below
# build their own graphs; this global is only for inspection.
graph = GraphsTuple(
    nodes={
        "position": R,
        "velocity": V,
        "acceleration": 0*V,
    },
    edges={
        "params": params,
        "species": species,
    },
    senders=jnp.array(sends+recs),
    receivers=jnp.array(recs+sends),
    globals=None,
    n_node=jnp.array([len(R)]),
    n_edge=None
)


In [None]:
@jit
def Lgraph(R, V, params):
    """Total Lagrangian of the chain, evaluated by the graph network.

    Args:
        R, V: node positions and velocities, shape (N, dim).
        params: MLP parameters, or ``None`` for the analytic link model.
    """
    node_features = {
        "position": R,
        "velocity": V,
        "acceleration": 0 * V,
    }
    edge_features = {
        "params": params,
        "species": species,
    }
    chain = GraphsTuple(
        nodes=node_features,
        edges=edge_features,
        senders=jnp.array(sends + recs),
        receivers=jnp.array(recs + sends),
        globals=None,
        n_node=jnp.array([len(R)]),
        n_edge=None,
    )
    return net(chain).nodes["lag"].sum()


def Lactual(R0, V0, params):
    """Ground-truth Lagrangian (analytic link model); ``params`` is ignored."""
    return Lgraph(R0, V0, None)


Lgraph(R0, V0, params), Lactual(R0, V0, None)

In [None]:
CMAP = ['RdPu']  # alternatives: 'Blues','BuGn','BuPu','GnBu','OrRd','PuBu','PuRd','Purples','RdPu','ocean_r'

from mpl_toolkits.axes_grid1 import make_axes_locatable


def _style_matrix_axes(ax, m):
    """Hide tick labels and marks on an m x m matrix image, keeping cell-aligned ticks."""
    # Major ticks at cell centres, minor ticks at cell borders.
    ax.set_xticks(np.arange(0, m, 1))
    ax.set_yticks(np.arange(0, m, 1))
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_xticks(np.arange(-.5, m, 1), minor=True)
    ax.set_yticks(np.arange(-.5, m, 1), minor=True)
    ax.tick_params(axis='x', colors='w', which="both", width=0)
    ax.tick_params(axis='y', colors='w', which="both", width=0)


def makeMplot(M, threshold=0.0, label=None, cmap="RdPu"):
    """Plot a square matrix next to its sparsity pattern.

    Left panel: ``M`` with rows flipped, plus a colorbar labelled
    "Percentage error". Right panel: mask of entries with ``|M| > threshold``.

    Args:
        M: square 2-D array.
        threshold: magnitude above which an entry counts as non-zero.
        label: panel labels forwarded to ``panel``.
        cmap: matplotlib colormap name (defaults to the previous "RdPu").
    """
    fig, axs = panel(1, 2, dpi=1000, hshift=0.08, label=label)
    m = len(M)

    plt.sca(axs[0])
    image = plt.imshow(M[::-1], cmap=cmap)
    _style_matrix_axes(axs[0], m)
    # append_axes creates the colorbar axes beside the image and adds it to the
    # figure; plt.colorbar takes the colormap and norm from the image. (The
    # previous mpl.colorbar.ColorbarBase call used a name `mpl` that this file
    # never imports, and relied on a version-dependent positional signature.)
    cax = make_axes_locatable(axs[0]).append_axes("right", size="5%", pad=0.05)
    plt.colorbar(image, cax=cax, orientation="vertical", label="Percentage error")

    plt.sca(axs[1])
    plt.imshow(jnp.abs(M[::-1]) > threshold, cmap=cmap)
    _style_matrix_axes(axs[1], m)
        
# makeMplot(M)
# makeMplot(M_actual)

In [None]:
# Euler-Lagrange parts for the ground-truth (analytic) Lagrangian; element [0]
# is treated as the mass matrix below.
parts_actual = jit(lnn.EL_parts(N, dim,
                         lagrangian=Lactual,
                         constraints=constraints,
                         )
            )

M_actual = parts_actual(R, V, params)[0]
M_actual

# Same quantity for the learned graph Lagrangian.
parts = jit(lnn.EL_parts(N, dim,
                         lagrangian=Lgraph,
                         constraints=constraints,
                         )
            )

M = parts(R, V, params)[0]

# Normalise by the (0, 0) entry, presumably so the matrices are compared up to
# the overall scale of the Lagrangian, which the EL equations do not fix.
M_normalised = M / M[0, 0]
M_actual_normalised = M_actual / M_actual[0, 0]

# Element-wise percentage error. The epsilon is added after taking the
# magnitude so it always moves the denominator away from zero; inside abs()
# it could cancel a tiny negative entry and divide by zero.
error = np.abs(M_normalised - M_actual_normalised) / (np.abs(M_actual_normalised) + 1.0e-10) * 100

ms_error = np.square(M_normalised - M_actual_normalised).mean()

print(ms_error)

M.shape

# makeMplot(error, label=["a", "b"])
# savefig("notebooks/mass_p_error.png", dpi=600)

In [None]:

def drag_link_wrap(r1, r2, v1, v2, sp, params):
    """Evaluate ``drag_link`` for one edge from its two endpoint states."""
    R = jnp.hstack((r1, r2))
    V = jnp.hstack((v1, v2))
    return drag_link(R, V, sp, params)


# `identity` (the linear activation) is the helper defined next to lag_link;
# a duplicate definition that used to live here was removed.


def drag_link(R, V, sp, params):
    """Linear drag on a single link, proportional to its midpoint velocity.

    Args:
        R: concatenated endpoint positions [r1, r2] (unused by both models).
        V: concatenated endpoint velocities [v1, v2].
        sp: species encoding of the edge.
        params: ``None`` for the ground-truth drag ``-0.01 * v_mid``; otherwise
            a dict with "d_emb" and "d_lin" MLPs that predict the drag
            coefficient from the species.

    Returns:
        Drag vector; for the MLP model this assumes "d_lin" outputs a single
        coefficient that broadcasts over the velocity components.
    """
    v1, v2 = jnp.split(V, 2)
    vel = (v1 + v2) / 2
    if params is None:
        return -0.01 * vel
    sp_ = forward_pass(params["d_emb"], sp, activation_fn=identity)
    return forward_pass(params["d_lin"], sp_) * vel


# Batched over edges; the parameter dict is shared by all edges.
v_drag_link = vmap(drag_link_wrap, in_axes=(0, 0, 0, 0, 0, None))


In [None]:
def update_edge_fn_drag(edges, sent_attributes, received_attributes, globals_):
    """Edge update: drag vector of every link, computed with ``v_drag_link``."""
    drag = v_drag_link(
        sent_attributes["position"], received_attributes["position"],
        sent_attributes["velocity"], received_attributes["velocity"],
        edges["species"], edges["params"],
    )
    return frozendict({"drag": drag})


def update_node_fn_drag(nodes, sent_attributes, received_attributes, globals_):
    """Node update: half of the summed incoming drag.

    NOTE(review): the factor 1/2 presumably shares each link's drag between
    its two end nodes; please confirm.
    """
    return frozendict({"drag": received_attributes["drag"] / 2})

In [None]:
net2 = GraphNetwork(update_edge_fn_drag, update_node_fn_drag)


@jit
def DRAG(R, V, params):
    """Non-conservative drag on all nodes, returned as a column vector.

    The per-node drag vectors are flattened row by row into shape (-1, 1),
    presumably (N * dim, 1) for the generalized-force input of accelerationFull.
    """
    chain = GraphsTuple(
        nodes=dict(position=R, velocity=V, acceleration=0 * V),
        edges=dict(params=params, species=species),
        senders=jnp.array(sends + recs),
        receivers=jnp.array(recs + sends),
        globals=None,
        n_node=jnp.array([len(R)]),
        n_edge=None,
    )
    node_drag = net2(chain).nodes["drag"]
    return node_drag.reshape(-1, 1)


DRAG(R, V, None)

In [None]:
# Inspect which MLPs the parameter dict contains.
params.keys()

In [None]:
# Constrained Euler-Lagrange accelerations from the graph Lagrangian, with the
# graph drag model as the non-conservative force.
acceleration_fn_graph = jit(lnn.accelerationFull(N, dim,
                                                 lagrangian=Lgraph,
                                                 non_conservative_forces=DRAG,
                                                 constraints=constraints,
                                                 # external_force=None,
                                                 jac_constraints=jac_constraints,
                                                 )
                           )

acceleration_fn_graph(R, V, params)


@ jit
def force_fn_graph(R, V, params, mass=None):
    """Accelerations, optionally multiplied per node by ``mass`` (length N)."""
    if mass is None:
        return acceleration_fn_graph(R, V, params)
    else:
        return acceleration_fn_graph(R, V, params) * mass.reshape(-1, 1)


masses = mass*jnp.ones(len(R0))

# Compare model (current params) and ground-truth (params=None) accelerations.
# The Trainer `t` is only created further down, so `t.params` raised NameError
# on a fresh run; use `params`, which holds the trained values once training
# has finished and `params = t.params` has run.
acceleration_fn_graph(R0, V0, params), acceleration_fn_graph(R0, V0, None)



In [None]:
# Same definition as in the previous cell (re-declared here).
@ jit
def force_fn_graph(R, V, params, mass=None):
    if mass is None:
        return acceleration_fn_graph(R, V, params)
    else:
        return acceleration_fn_graph(R, V, params) * mass.reshape(-1, 1)


acceleration_fn_graph(R0, V0, params)

# Unit masses for the integrator; overrides `mass * ones` from the previous
# cell (identical while mass == 1.0).
masses = jnp.ones(len(R0))

dtt = dt

# Ground-truth trajectory generator: params=None selects the analytic
# Lagrangian and drag (assuming `prediction` forwards it to force_fn_graph).
@ jit
def forward_sim_graph(R, V):
    return prediction(R,  V, None, force_fn_graph, shift, dtt, masses,
                      dR_max=1.0e-5, stride=stride, runs=10*runs)


# try:
#     states_graph, _ = loadfile("chain_data_4_gen.pkl")
# except:

# Generate the training trajectory from the initial state and save it as a
# pickle plus an OVITO data file.
states_graph = forward_sim_graph(R0, V0)
savefile("chain_data_4_drag.pkl", states_graph)
save_ovito("chain_graph_4_drag.data", NVEStates(states_graph), length=10.0,
       insert_origin=True)


# @ jit
# def forward_sim_graph_model(R, V):
#     return prediction(R,  V, params, force_fn_graph, shift, dt, masses,
#                       dR_max=1.0e-5, stride=stride, runs=runs)




In [None]:

Rs, Vs, Fs = utils.States().fromlist(NVEStates(states_graph)).get_array()

# Reshape the flat trajectory arrays to (frames, N, dim).
Rs, Vs, Fs = (arr.reshape(-1, N, dim) for arr in (Rs, Vs, Fs))

# Shuffle frames with a random permutation drawn from NumPy's global RNG.
mask = np.random.choice(len(Rs), len(Rs), replace=False)
allRs, allVs, allFs = Rs[mask], Vs[mask], Fs[mask]

# 75 % / 25 % train/test split of the shuffled frames.
Ntr = int(0.75*len(Rs))
Nts = len(Rs) - Ntr

Rs, Vs, Fs = allRs[:Ntr], allVs[:Ntr], allFs[:Ntr]
Rst, Vst, Fst = allRs[Ntr:], allVs[Ntr:], allFs[Ntr:]

################################################
################### ML Model ###################
################################################

R, V = Rs[0], Vs[0]

# Batched accelerations: vmap over frames (R, V), with shared params.
v_acceleration_fn_graph = vmap(acceleration_fn_graph, in_axes=(0, 0, None))

v_acceleration_fn_graph(Rs, Vs, params)

################################################
################## ML Training #################
################################################

error_fn = "L2error"
LOSS = getattr(src.models, error_fn)  # loss looked up by name in src.models


@ jit
def loss_fn(params, Rs, Vs, Fs):
    """Loss between the predicted accelerations for a batch of states and targets ``Fs``."""
    return LOSS(v_acceleration_fn_graph(Rs, Vs, params), Fs)


loss_fn(params, Rst, Vst, Fst)


@ jit
def gloss(*args):
    """Return ``(loss, grad w.r.t. params)`` of ``loss_fn``; args as for loss_fn."""
    loss_and_grad = value_and_grad(loss_fn)
    return loss_and_grad(*args)


lr = 1.0e-3
# Adam optimizer as an (init, update, get_params) triple.
opt_init, opt_update_, get_params = optimizers.adam(lr)


@ jit
def update(i, opt_state, params, loss__, *data):
    """Run one optimizer step on a batch.

    Gradients are clipped element-wise to [-k, k] before the Adam update.

    Args:
        i: optimizer step index.
        opt_state: current optimizer state.
        params: parameters at which the gradient is evaluated.
        loss__: unused placeholder (kept for the ``step`` calling convention).
        *data: batch arrays forwarded to ``loss_fn`` (Rs, Vs, Fs).

    Returns:
        ``(new opt_state, new params, loss at the old params)``.
    """
    value, grads_ = gloss(params, *data)
    k = 1.0  # element-wise gradient clipping bound
    clipped = jax.tree_util.tree_map(
        lambda x: jnp.clip(x, a_min=-k, a_max=k), grads_)
    opt_state = opt_update(i, clipped, opt_state)
    return opt_state, get_params(opt_state), value


@ jit
def opt_update(i, grads_, opt_state):
    """Adam update after replacing NaN/inf gradient entries with finite numbers."""
    # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
    grads_ = jax.tree_util.tree_map(jnp.nan_to_num, grads_)
    return opt_update_(i, grads_, opt_state)


@ jit
def step(i, ps, *args):
    """Call ``update(i, *ps, *args)``, where ``ps`` is ``(opt_state, params, loss)``."""
    call_args = (i, *ps, *args)
    return update(*call_args)


def batching(*args, size=None):
    """Split arrays into equal-sized mini-batches along their first axis.

    All ``args`` are assumed to share the same length L. When ``size`` is
    given, two batch counts are tried (about L / size, and one fewer), and
    the one whose equal-sized batches cover more samples wins (ties go to
    fewer, larger batches). The trailing ``L - nbatches * size`` samples are
    dropped so every batch can be stacked into a single array.

    Args:
        *args: arrays/sequences indexed along their first axis.
        size: target batch size, or ``None`` for one batch holding everything.

    Returns:
        A list with one array per input, each of shape (nbatches, size, ...).
    """
    L = len(args[0])
    if size is not None:
        nbatches1 = int((L - 0.5) // size) + 1
        nbatches2 = max(1, nbatches1 - 1)
        size1 = int(L / nbatches1)
        size2 = int(L / nbatches2)
        if size1 * nbatches1 > size2 * nbatches2:
            size, nbatches = size1, nbatches1
        else:
            size, nbatches = size2, nbatches2
    else:
        nbatches = 1
        size = L

    return [jnp.array([arg[i * size:(i + 1) * size] for i in range(nbatches)])
            for arg in args]


batch_size = 10
# Pre-batched training data (Trainer.training re-batches on its own).
bRs, bVs, bFs = batching(Rs, Vs, Fs,
                         size=min(len(Rs), batch_size))



In [None]:
# Display the configured batch size.
batch_size

In [None]:
print(f"training ...")

# Fresh optimizer state for the current (untrained) params.
opt_state = opt_init(params)


class Trainer():
    """Minimal mini-batch training loop with periodic train/test loss logging.

    Args:
        step: callable ``step(i, (opt_state, params, loss), *batch)`` returning
            ``(opt_state, params, loss)``.
        loss_fn: callable ``loss_fn(params, Rs, Vs, Fs)``.
        opt_state: initial optimizer state.
        traindata: tuple ``(Rs, Vs, Fs)`` of training arrays.
        testdata: tuple ``(Rst, Vst, Fst)`` of test arrays.
        params: initial parameters.
    """

    def __init__(self, step, loss_fn, opt_state, traindata, testdata, params):
        self.step = step
        self.loss_fn = loss_fn
        self.traindata = traindata
        self.testdata = testdata
        self.params = params
        self.opt_state = opt_state
        self.epoch = 0
        self.optimizer_step = -1   # incremented before every optimizer step
        self.larray = []           # logged training losses
        self.ltarray = []          # logged test losses
        self.last_loss = 1000      # lowest logged training loss so far
        self.saveat = 1
        # Initial losses on the data this trainer was given (previously these
        # read the module-level Rs/Vs/Fs and Rst/Vst/Fst globals instead).
        Rs, Vs, Fs = traindata
        Rst, Vst, Fst = testdata
        self.larray += [loss_fn(params, Rs, Vs, Fs)]
        self.ltarray += [loss_fn(params, Rst, Vst, Fst)]

    def print_loss(self, iter_):
        """Print the most recent train/test losses, prefixed by ``iter_``."""
        print(f"{iter_} train={self.larray[-1]}, test={self.ltarray[-1]}")

    def training(self, epochs, batch_size=1, lr=1.0e-3, saveat=1):
        """Train for ``epochs`` passes over the training data.

        Args:
            epochs: number of epochs.
            batch_size: mini-batch size (capped at the training-set size).
            lr: ignored; the learning rate is fixed when the optimizer is built.
            saveat: log losses and store params/opt_state every ``saveat`` epochs.
        """
        self.saveat = saveat
        Rs, Vs, Fs = self.traindata
        Rst, Vst, Fst = self.testdata
        bRs, bVs, bFs = batching(Rs, Vs, Fs,
                                 size=min(len(Rs), batch_size))
        params = self.params
        opt_state = self.opt_state
        for epoch in range(epochs):
            for data in zip(bRs, bVs, bFs):
                self.optimizer_step += 1
                # Use the step function this trainer was built with, not a global.
                opt_state, params, _ = self.step(
                    self.optimizer_step, (opt_state, params, 0), *data)

            if epoch % self.saveat == 0:
                self.larray += [self.loss_fn(params, Rs, Vs, Fs)]
                self.ltarray += [self.loss_fn(params, Rst, Vst, Fst)]
                self.print_loss(f"Epoch {epoch}/{epochs}, ")
                if self.last_loss > self.larray[-1]:
                    self.last_loss = self.larray[-1]
                self.params = params
                self.opt_state = opt_state

        self.params = params
        self.opt_state = opt_state


In [None]:
# Trainer over the shuffled train/test split with the fresh optimizer state.
t = Trainer(step, loss_fn, opt_state, (Rs, Vs, Fs),
            (Rst, Vst, Fst), params)


In [None]:
# 10000 epochs with mini-batches of 10; log and checkpoint every 10 epochs.
t.training(10000, batch_size=10, saveat=10)

## Forward 

In [None]:
params = t.params  # trained parameters

RUNS = 10*runs  # rollout length passed to `prediction` as `runs`


def _rollout(R, V, model_params):
    """Integrate the dynamics from (R, V); ``model_params=None`` means ground truth."""
    return prediction(R, V, model_params, force_fn_graph, shift, dt, masses,
                      dR_max=1.0e-5, stride=stride, runs=RUNS)


@ jit
def forward_sim_graph_model(R, V):
    """Roll out the learned model (trained ``params``) from state (R, V)."""
    return _rollout(R, V, params)


@ jit
def forward_sim_long(R, V):
    """Roll out the ground-truth dynamics (``params=None``) from state (R, V)."""
    return _rollout(R, V, None)


In [None]:
# Rs = []
# Vs = []
# for _ in range(20):
#     R, V = getRV()
#     Rs += [R[None, :, :]]
#     Vs += [V[None, :, :]]
    
# Rs, Vs = jnp.vstack(Rs), jnp.vstack(Vs)
# Rs.shape

In [None]:
# states_long = vmap(forward_sim_long)(Rs, Vs)

In [None]:
# model_states = vmap(forward_sim_graph_model)(Rs, Vs)

In [None]:
# Display the current chain-size parameter.
n

In [None]:
# RS = []
# VS = []
# for i in range(5):
#     r, v = getRV(n, t=45)
#     RS += [r]
#     VS += [v]

# RS = jnp.array(RS)
# VS = jnp.array(VS)
# RS.shape

In [None]:
# vmap(forward_sim_long, 0, 0)(RS, VS)
# NOTE(review): overridden by `n = 2` a few cells later, before n is next used
# outside commented-out code.
n = 50

In [None]:
# DATA1 = {}
# DATA2 = {}
# for i in range(1):
#     DATA1[i] = []
#     DATA2[i] = []
#     file1 = f"chain_exp/chain_n{n}_i{i}_free_new.data"
#     file2 = f"chain_exp/chain_n{n}_i{i}_model_free_new.data"
#     for j in range(100):
#         data1 = np.loadtxt(file1, skiprows=104*j+2, max_rows=102)
#         data2 = np.loadtxt(file2, skiprows=104*j+2, max_rows=102)
#         DATA1[i] += [data1]
#         DATA2[i] += [data2]
        
        
        

In [None]:
# states_long_ = []
# model_states_ = []
# for data1, data2 in zip(DATA1[0], DATA2[0]):
#     pos, vel = data1[:, 2:4], data1[:, 4:6]
#     states_long_ += [NVEState(pos, vel, 0*vel, masses, 0.0)]
#     pos, vel = data2[:, 2:4], data2[:, 4:6]
#     model_states_ += [NVEState(pos, vel, 0*vel, masses, 0.0)]
    

In [None]:
# model_states = [jax.tree_multimap(lambda *a: jnp.stack(a), *model_states_)]
# states_long = [jax.tree_multimap(lambda *a: jnp.stack(a), *states_long_)]


In [None]:
# Chain-size parameter for the evaluation rollouts and output file names below
# (getRV builds 2*n links).
n = 2

In [None]:
states_long = []
model_states = []

seed = 0

# Reseed both RNGs so the evaluation chains are reproducible.
np.random.seed(seed)
key = jax.random.PRNGKey(seed)

for i in range(20):
    print(i)
    R, V = getRV(n, t=45)

    truth = forward_sim_long(R, V)
    states_long.append(truth)
    save_ovito(f"chain_exp/chain_n{n}_i{i}_free_drag.data", NVEStates(truth), length=10.0,
               insert_origin=True)

    predicted = forward_sim_graph_model(R, V)
    model_states.append(predicted)
    save_ovito(f"chain_exp/chain_n{n}_i{i}_model_free_drag.data", NVEStates(predicted), length=10.0,
               insert_origin=True)
    
    

## Save results and evaluate

In [None]:
# Persist the model and ground-truth rollouts together with the simulation
# settings that produced them.
savefile(f"chain_exp/n{n}_states_drag.pkl", {"model":model_states, "states":states_long}, 
         metadata=dict(dt=dt, stride=stride, runs=10*runs, samples=len(states_long)))

In [None]:
@jit
def Energy(R, V, params):
    """Totals ``[Lagrangian, Hamiltonian, kinetic, potential]`` for one state.

    ``params=None`` evaluates the analytic link model.
    """
    chain = GraphsTuple(
        nodes=dict(position=R, velocity=V, acceleration=0 * V),
        edges=dict(params=params, species=species),
        senders=jnp.array(sends + recs),
        receivers=jnp.array(recs + sends),
        globals=None,
        n_node=jnp.array([len(R)]),
        n_edge=None,
    )
    nodes = net(chain).nodes
    return jnp.array([nodes[name].sum() for name in ("lag", "ham", "ke", "pe")])

In [None]:
# Energy evaluated for every frame of a trajectory (shared params).
v_Energy = vmap(Energy, in_axes=(0, 0, None))

In [None]:
# Shape check: one row of four energy terms per frame.
v_Energy(states_long[0].position, states_long[0].velocity, None).shape

In [None]:
# Inspect the positions of the first ground-truth rollout.
states_long[0].position

In [None]:
# E_pred = v_Energy(model_states[0].position, model_states[0].velocity, None)
# E = v_Energy(states_long[0].position, states_long[0].velocity, None)

# plt.plot(E[:, 0]*0, "--")

# plt.plot(E, label=["L", "H", "KE", "PE"], lw=6)

# plt.plot(E_pred, label=["L_pred", "H_pred", "KE_pred", "PE_pred"], ls="--", color="k")

# plt.ylabel("Energy")
# plt.xlabel("Time")
# plt.legend(ncol=2, loc=2, bbox_to_anchor=(1,1))

In [None]:
def std_plot(y_, semilog=True, dt=1.0):
    """Plot the geometric mean of several curves with a ±2σ band in log space.

    Args:
        y_: stack of equal-length positive curves; statistics are taken over
            axis 0.
        semilog: use a logarithmic y axis.
        dt: x-axis spacing between consecutive points.
    """
    log_y = jnp.log(jnp.array(y_))
    mean_ = log_y.mean(axis=0)
    std_ = log_y.std(axis=0)

    x = jnp.array(range(len(mean_))) * dt
    draw = plt.semilogy if semilog else plt.plot
    draw(x, jnp.exp(mean_))
    plt.fill_between(x, jnp.exp(mean_ - 2*std_), jnp.exp(mean_ + 2*std_), alpha=0.5)

In [None]:
def norm(a):
    """Euclidean norm of each entry along the first axis of ``a``.

    All trailing axes are flattened, so an array of shape (T, N, dim) yields
    one norm per frame, shape (T,).
    """
    squared = jnp.square(a)
    return jnp.sqrt(squared.reshape(len(squared), -1).sum(axis=1))


def RelErr(ya, yp):
    """Relative error |ya - yp| / (|ya| + |yp|) per first-axis entry (NaN if both are zero)."""
    difference = norm(ya - yp)
    scale = norm(ya) + norm(yp)
    return difference / scale


def AbsErr(ya, yp):
    """Absolute error |ya - yp| per first-axis entry."""
    return norm(ya - yp)


In [None]:
def getall_state_H():
    """Relative and absolute Hamiltonian errors of every model rollout.

    The Hamiltonian (column 1 of ``v_Energy``, evaluated with the analytic
    model) along each model trajectory is compared with the one along the
    matching ground-truth trajectory.

    Returns:
        ``(relative errors, absolute errors)`` stacked over samples.
    """
    rel, absolute = [], []
    for model_state, state in zip(model_states, states_long):
        H_pred = v_Energy(model_state.position, model_state.velocity, None)[:, 1]
        H_true = v_Energy(state.position, state.velocity, None)[:, 1]
        rel.append(RelErr(H_pred, H_true))
        absolute.append(AbsErr(H_pred, H_true))
    return jnp.array(rel), jnp.array(absolute)


In [None]:
# Display the rollout settings used for the error plots below.
stride, dt, RUNS, # H_error.shape

In [None]:
# # {"model":model_states, "states":states_long}, 
#          # metadata=dict(dt=dt, stride=stride, runs=100*runs, samples=len(states_long))
# n = 100
# data = loadfile(f"chain_exp/n{n}_states_new.pkl")[0]
# # data = loadfile(f"chain_exp/n{n}_states.pkl")[0]
# model_states = data['model']
# states_long = data['states']

In [None]:
# Per-sample relative and absolute Hamiltonian errors over time.
RH_error, AH_error = getall_state_H()

In [None]:
# Display the chain-size parameter used in the output file names.
n

In [None]:
# Relative Hamiltonian error vs. simulated time. The x spacing is stride*dt,
# the time between saved frames (assuming `prediction` saves every `stride`
# steps), so the axis shows time rather than a step index.
plt.figure()  # fresh figure so earlier plots are not drawn into this one
std_plot(RH_error, dt=stride*dt)
plt.ylabel("Energy error")
plt.xlabel("Time")

plt.savefig(f"chain_exp/n{n}_RH_error_drag.png", dpi=600)
savefile(f"chain_exp/n{n}_RH_error_drag.pkl", RH_error)


In [None]:
# Absolute Hamiltonian error vs. simulated time (x spacing stride*dt).
plt.figure()  # fresh figure so earlier plots are not drawn into this one
std_plot(AH_error, dt=stride*dt)
plt.ylabel("Energy error")
plt.xlabel("Time")

plt.savefig(f"chain_exp/n{n}_AH_error_drag.png", dpi=600)
savefile(f"chain_exp/n{n}_AH_error_drag.pkl", AH_error)


In [None]:
def getall_state_Z():
    """Relative and absolute position (rollout) errors of every model trajectory.

    Returns:
        ``(relative errors, absolute errors)`` stacked over samples, one value
        per frame.
    """
    pairs = list(zip(model_states, states_long))
    rel = [RelErr(pred.position, true.position) for pred, true in pairs]
    absolute = [AbsErr(pred.position, true.position) for pred, true in pairs]
    return jnp.array(rel), jnp.array(absolute)


In [None]:
# Per-sample relative and absolute position errors over time.
RZ_error, AZ_error = getall_state_Z()


In [None]:
# Relative rollout (position) error vs. simulated time (x spacing stride*dt).
plt.figure()  # fresh figure so earlier plots are not drawn into this one
std_plot(RZ_error, dt=stride*dt)
plt.ylabel("Rollout error")
plt.xlabel("Time")

plt.savefig(f"chain_exp/n{n}_RZ_error_drag.png", dpi=600)
savefile(f"chain_exp/n{n}_RZ_error_drag.pkl", RZ_error)

In [None]:
# Absolute rollout (position) error vs. simulated time (x spacing stride*dt).
plt.figure()  # fresh figure so earlier plots are not drawn into this one
std_plot(AZ_error, dt=stride*dt)
plt.ylabel("Rollout error")
plt.xlabel("Time")

plt.savefig(f"chain_exp/n{n}_AZ_error_drag.png", dpi=600)
savefile(f"chain_exp/n{n}_AZ_error_drag.pkl", AZ_error)

In [None]:
# Inspect the per-edge species encodings.
species

In [None]:
def drag_(v):
    """Learned drag for velocity ``v``, using the first edge's species.

    Mirrors the MLP branch of drag_link with the trained global ``params``.
    """
    embedding = forward_pass(params["d_emb"], species[0], activation_fn=identity)
    coefficient = forward_pass(params["d_lin"], embedding)
    return coefficient * v

In [None]:
# Learned drag as a function of velocity on [-1, 1).
v_ = jnp.arange(-1, 1, 0.001)
dr = vmap(drag_, 0)(v_)

plt.plot(v_, dr)
# Ground-truth per-link drag used to generate the data (drag_link with
# params=None) is -0.01 * v; uncomment to overlay it:
# plt.plot(v_, -0.01*v_)



In [None]:
# make_plot(states_long)
# make_plot(model_states)

In [None]:
# @jit
# def potE(R, params=None):
#     r = (R[1:, :] + R[:-1, :]) / 2
#     if params is None:
#         return (mass * _g * r[:, 1]).sum()
#     else:
#         return vmap(forward_pass, in_axes=(None, 0))(params, r).sum()


# potE(R)


# @jit
# def kinE_l(R, V, params=None):
#     v = (V[1:, :] + V[:-1, :]) / 2
#     if params is None:
#         kin1 = 0.5 * mass * jnp.square(v).sum()
#     else:
#         kin1 = jnp.square(
#             vmap(forward_pass, in_axes=(None, 0))(params, v)).sum()
#     return kin1


# kinE_l(R, V)


# def kinE_a(R, V, params=None):
#     r_ = (R[1:, :] - R[:-1, :])
#     v_ = (V[1:, :] - V[:-1, :])
#     w = vmap(get_omega)(r_, v_)
#     if params is None:
#         I = get_I(mass, length)
#         kin2 = 0.5 * I * jnp.square(w).sum()
#     else:
#         kin2 = jnp.square(
#             vmap(forward_pass, in_axes=(None, 0))(params, w)).sum()
#     return kin2


# kinE_a(R, V)


# @jit
# def kinE_(R, V, params=None):
#     return kinE_l(R, V, params=params["l"]) + kinE_a(R, V, params=params["a"])


# @jit
# def kinE(R, V, params=None):
#     if params is None:
#         params = {"l": None, "a": None}
#     return kinE_(R, V, params=params)
#     # return 0.5 * mass * jnp.square(V).sum()

# def make_plot(states):
#     PE = []
#     KE_l = []
#     KE_a = []
#     KE = []
#     time = []
#     dt = []
#     time_ = 0.0
#     for state in NVEStates(states):
#         PE.append(potE(state.position))
#         KE.append(kinE(state.position, state.velocity))
#         KE_l.append(kinE_l(state.position, state.velocity))
#         KE_a.append(kinE_a(state.position, state.velocity))
#         time.append(state.time)
#         dt.append(state.time - time_)
#         time_ = state.time

#     KE_l = jnp.array(KE_l)
#     KE_a = jnp.array(KE_a)
#     PE = jnp.array(PE)
#     KE = jnp.array(KE)

#     PE = PE - PE[0]
#     KE = KE - KE[0]
#     KE_l = KE_l - KE_l[0]
#     KE_a = KE_a - KE_a[0]
#     TE = KE + PE

#     fix, ax = plt.subplots(1, 3, figsize=(18, 6), sharex=True)

#     plt.sca(ax[0])
#     plt.plot(time, 0*TE, "-k", label="x-axis", alpha=0.5)
#     plt.plot(time, PE, label="PE")
#     plt.plot(time, KE, label="KE")
#     plt.plot(time, TE, label="TE")
#     plt.plot(time, KE_l, "--", label="KE_l")
#     plt.plot(time, KE_a, "--", label="KE_a")
#     plt.legend()

#     plt.sca(ax[1])
#     plt.plot(time, TE)
#     plt.ylabel("Hamiltonian")

#     plt.sca(ax[2])
#     plt.semilogy(time, dt)
#     plt.ylabel("dt")

#     plt.show()
