## Imports

In [None]:
import jax
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import sys
sys.path.append("..")
import src
from frozendict import frozendict
from jax import grad, jit
from jax import numpy as jnp
from jax import value_and_grad, vmap
from jax.config import config
from jax.experimental import optimizers
from jraph import GraphNetwork, GraphsTuple
from src import lnn, utils
from src.io import *
from src.lnn import accelerationFull
from src.md import displacement, prediction, shift
from src.models import forward_pass, initialize_mlp
from src.nve import NVEState, NVEStates
from shadow.plot import *
from shadow.font import *

config.update("jax_enable_x64", True)

# Make frozendict (the container returned by the graph update functions) a JAX
# pytree: flatten to (values, keys) and rebuild by zipping keys back onto values.
try:
    jax.tree_util.register_pytree_node(
        frozendict,
        flatten_func=lambda s: (tuple(s.values()), tuple(s.keys())),
        unflatten_func=lambda k, xs: frozendict(zip(k, xs)),
    )
except ValueError:
    # JAX refuses a second registration of the same type (e.g. when this cell
    # is re-run in the same kernel). Only that case is ignored; any other
    # error (or a KeyboardInterrupt) now propagates instead of being swallowed.
    pass

## Config

In [None]:
# n links per side of the chain; every link spans LENGTH / n.
n, LENGTH = 10, 1
link_size = LENGTH / n

def move_by_link(R, n=1, x=1, y=1, ls=1, t=0):
    """Translate the rows of ``R`` by ``n`` links of length ``ls`` at angle ``t`` (degrees).

    ``x`` and ``y`` scale the two in-plane components of the step (e.g. -1 mirrors).
    The offset is a 1x3 row ``[x*dx, y*dy, 0]``, so ``R`` is expected to have 3 columns.
    """
    angle = t/180*jnp.pi
    step = jnp.array([[x * (jnp.cos(angle) * ls), y * (jnp.sin(angle) * ls), 0.0]])
    return R + n * step
        
def rot(x, t):
    """Rotate each row of ``x`` (3-D points) about the z axis by ``t`` degrees."""
    angle = t * jnp.pi / 180
    c, s = jnp.cos(angle), jnp.sin(angle)
    rotation = jnp.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
    # Rotate column vectors, then transpose back to one point per row.
    return (rotation @ x.T).T


sss = LENGTH*0.3  # NOTE(review): not referenced anywhere else in this notebook


def _straight_segment(count):
    """Return ``count`` 3-D points spaced ``link_size`` apart, from the origin along -y."""
    flat = jnp.zeros(count)
    return link_size * jnp.vstack([flat, -jnp.arange(0, count), flat]).T


# Side 1: n + 1 nodes (n links) from the origin, rotated by 45 degrees.
R_ = _straight_segment(n + 1)
R1 = rot(R_, 45)

# Sides 2-4: n-node segments, each rotated a further 90 degrees, attached to the
# previous side's last node and pushed one link along that side's direction.
R_ = _straight_segment(n)
R2 = move_by_link(rot(R_, 45+90) + R1[-1:, :], ls=link_size, t=45)
R3 = move_by_link(rot(R_, 45+90+90) + R2[-1:, :], ls=link_size, t=45+90)
R4 = rot(R_, 45+90+90+90) + R3[-1:, :]
# Side 4 keeps only n - 1 nodes: it stops one link short of the origin, and the
# (N-1 -> 0) edge in `sends`/`recs` closes the loop onto R1[0].
R4 = move_by_link(R4[:-1], ls=link_size, t=45+90+90)

# Stack the four sides and shift so that the first node sits at (0, -1).
R = jnp.vstack((R1, R2, R3, R4)) + jnp.array([[0.0, -1.0, 0.0], ])

# Work in the x-y plane: the z column built above is all zeros.
dim = 2

if dim == 2:
    R = R[:, :2]

N, dim = R.shape

# Start from rest. To seed an initial kick instead:
# V = np.array(V)
# V[1, 0] = 1.0
# V[-1, 0] = -1.0
V = 0*R

# Initial state used by the rollouts below.
R0, V0 = R, V

# Physical parameters and integration settings.
mass = 1.0
# NOTE(review): the rod length fed to get_I is 1.0, while each link spans link_size.
length = 1.0
_g = 10.0
dt = 1.0e-4
stride = 100
runs = 100

# Directed edge list: consecutive links i -> i+1, the loop-closing link N-1 -> 0,
# and the n -> 3n cross link.
sends = list(range(N-1)) + [N-1, n]
recs = list(range(1, N)) + [0, 3*n]

key = jax.random.PRNGKey(0)

nspecies = 1

def OHE(x):
    """One-hot encode the labels ``x`` into ``nspecies`` classes."""
    return jax.nn.one_hot(x, num_classes=nspecies)

# One species vector per directed edge (sends + recs).
# NOTE(review): with nspecies == 1 the label 1 is outside the single class, so
# one_hot yields all-zero rows. Kept as-is because the trained parameters loaded
# below presumably expect this encoding — confirm.
species = OHE(jnp.ones(len(sends) + len(recs)))

print(R.shape)


In [None]:
def make_fixed_node(ax, shift=0.0, theta=0.0, scale=1.0, *args, **kwargs):
    """Draw a hatched triangular "fixed support" glyph on ``ax``.

    The glyph is built in local coordinates, scaled by ``scale``, rotated by
    ``theta`` degrees about z, and translated by ``shift``. Extra keyword
    arguments go to every ``ax.plot`` call. ``*args`` is accepted but unused.
    """
    triangle = jnp.array([[0, 0, 0],
                          [1, 1, 0],
                          [1, -1, 0.0],
                          ])
    hatch = []
    for i in range(11):
        base = -1 + 0.2 * i
        hatch.extend(([1.0, base, 0.0], [1.2, base + 0.2, 0.0]))
    hatch = jnp.array(hatch)

    triangle = rot(triangle * scale, theta) + shift
    hatch = rot(hatch * scale, theta) + shift

    # Outline: consecutive vertices, then the closing edge from the first vertex to the last.
    last = len(triangle) - 1
    outline = [(k, k + 1) for k in range(last)] + [(0, last)]
    for a, b in outline:
        ax.plot([triangle[a, 0], triangle[b, 0]], [triangle[a, 1], triangle[b, 1]], **kwargs)

    # Hatch marks: rows (2k, 2k+1) are the ends of one short segment.
    for start, end in zip(hatch[0::2], hatch[1::2]):
        ax.plot([start[0], end[0]], [start[1], end[1]], **kwargs)


In [None]:
def makechainfig(R, R1):
    """Overlay an actual and a predicted chain configuration on a new figure.

    Args:
        R: actual node positions, indexed like the module-level ``sends``/``recs``.
        R1: predicted node positions with the same layout as ``R``.

    Returns:
        ``(fig, ax)`` of the newly created figure, so callers can save or close it.
    """
    # NOTE(review): dpi=1000 on a 12x12 inch canvas is a very large figure;
    # plt.savefig(..., dpi=...) controls the resolution of the saved file.
    fig, ax = plt.subplots(figsize=(12, 12), dpi=1000)

    # Every link: thick blue for the actual chain, thin black on top for the prediction.
    for i, j in zip(sends, recs):
        plt.plot([R[i, 0], R[j, 0]], [R[i, 1], R[j, 1]], color="#729fcf", lw=4)
        plt.plot([R1[i, 0], R1[j, 0]], [R1[i, 1], R1[j, 1]], zorder=3, color="k", lw=1)

    # Redraw the n -> 3n cross link in purple so it stands out.
    i = n
    j = 3*n
    plt.plot([R[i, 0], R[j, 0]], [R[i, 1], R[j, 1]], color="#aa33ff", lw=4)
    plt.plot([R1[i, 0], R1[j, 0]], [R1[i, 1], R1[j, 1]], zorder=3, color="k", lw=1)

    plt.scatter(R[:, 0], R[:, 1], zorder=2, color="#999999", ec="white", s=120, label="Actual")
    plt.scatter(R1[:, 0], R1[:, 1], zorder=4, color="r", ec="white", s=20, label="Predicted")
    # Mark the two anchored nodes (0 and 2n).
    plt.scatter(R[0, 0], R[0, 1], zorder=2, color="brown", s=150, marker="+")
    plt.scatter(R[2*n, 0], R[2*n, 1], zorder=2, color="brown", s=150, marker="+")

    # Support glyphs at the anchor points (0, -1) and (sqrt(2), -1).
    make_fixed_node(ax, theta=180, shift=jnp.array([[0.0, -1, 0.0], ]), color="lightblue", scale=0.1)
    make_fixed_node(ax, theta=0, shift=jnp.array([[jnp.sqrt(2), -1, 0.0], ]),
                    color="lightblue", scale=0.1)

    plt.text(R[0, 0]-0.1, R[0, 1]+0.7, "E")

    ax.set_aspect(1)
    plt.axis("off")
    return fig, ax
    


In [None]:
set_font_size(50)

# Draw predicted vs actual chains at the first state and at every 1000th state.
# NOTE(review): model_states / states_long are only created in later cells.
# NOTE(review): every snapshot is saved to the same path, so only the last one remains on disk.
ind = 0
for ind, (pred_state, act_state) in enumerate(
        zip(NVEStates(model_states), NVEStates(states_long)), start=1):
    if ind % 1000 == 0 or ind == 1:
        R1 = pred_state.position
        R2 = act_state.position
        makechainfig(R2, R1)
        # plt.legend()
        # plt.text(R1[0, 0]-0.1, R1[0, 1]+0.5, "Time = {:.1f}".format(ind*0.001))
        plt.savefig("notebooks/T2.png", dpi=600)

In [None]:
def get_I(m, L):
    """Moment of inertia m * L**2 / 12 of a uniform rod of mass m and length L about its centre."""
    return (m / 12) * L**2


def get_omega(r, v):
    """Angular velocity omega = (r x v) / |r|^2 of a point at ``r`` moving with ``v``.

    The squared norm keeps a trailing axis of size 1 so it broadcasts against the cross product.
    """
    squared_norm = jnp.sum(jnp.square(r), axis=-1, keepdims=True)
    return jnp.cross(r, v) / squared_norm


# Length of the angular-velocity vector get_omega returns for the current `dim`.
(dimω,) = get_omega(R[0], V[0]).shape

key = jax.random.PRNGKey(0)

# Anchor points (0, -1) and (sqrt(2), -1); the z coordinate is dropped when dim == 2.
if dim == 2:
    endpoint1 = np.array([[0.0, -1.0]])
    endpoint2 = jnp.array([[jnp.sqrt(2), -1.0, 0.0]])[:, :2]
else:
    endpoint1 = np.array([[0.0, -1.0, 0.0]])
    endpoint2 = jnp.array([[jnp.sqrt(2), -1.0, 0.0]])

endpoint2

In [None]:
# def holo_con_wrap(r1, r2, sp, params, l=1.0):
#     return jnp.square(r1-r2).sum() - l**2

 
# v_holo_con = vmap(holo_con_wrap, in_axes=(0, 0, 0, None))

 
# def update_edge_fn2(edges, sent_attributes, received_attributes, globals_):
#     H = v_holo_con(sent_attributes["position"],
#                    received_attributes["position"], edges["species"], edges["params"])
#     return frozendict({"hcon": H})

# cnet = GraphNetwork(update_edge_fn2, None)


# def hconstraints2(R, V, params):
#     graph = GraphsTuple(
#         nodes={
#             "position": R,
#             "velocity": V,
#             "acceleration": 0*V,
#         },
#         edges={
#             "params": params,
#             "species": species,
#         },
#         senders=jnp.array(sends+recs),
#         receivers=jnp.array(recs+sends),
#         globals=None,
#         n_node=jnp.array([len(R)]),
#         n_edge=None
#     )
#     return cnet(graph).edges["hcon"]


# @jit
# def constraints2(x, v, params):
#     return jax.jacobian(lambda x: hconstraints2(x.reshape(-1, dim), v.reshape(-1, dim), params), 0)(x)


# constraints2(R.flatten(), V, params).shape

In [None]:
# def hconstraints1(R, l=jnp.array([1.0])):
#     if len(l) != len(R):
#         l = l[0]*jnp.ones((2))
#     out = jnp.square(jnp.vstack([R[:1], endpoint2]) -
#                      jnp.vstack([endpoint1, R[-1:]])).sum(axis=1) - l**2
#     return out

# @ jit
# def constraints1(x, v, params):
#     return jax.jacobian(lambda x: hconstraints1(x.reshape(-1, dim)), 0)(x)

# constraints1(R.flatten(), V, params).shape

In [None]:
# def constraints(R, V, params):
#     return jnp.vstack((constraints1(R, V, params), constraints2(R, V, params)))

# constraints(R.flatten(), V, params).shape

In [None]:
# def hconstraints(R, l=jnp.array([1.0])):
#     if len(l) != len(R):
#         l = l[0]*jnp.ones((len(R)+1))
#     out = jnp.square(jnp.vstack([R, endpoint2]) -
#                      jnp.vstack([endpoint1, R])).sum(axis=1) - l**2
#     return out

# def hconstraints(R, l=jnp.array([1.0])):
#     if len(l) != len(R):
#         l = l[0]*jnp.ones((len(R)+2))
#     out = jnp.square(jnp.vstack([R, R[1:2], R[:1]]) -
#                      jnp.vstack([endpoint1, R, R[-2:-1]])).sum(axis=1) - l**2
#     return out

# @ jit
# def constraints(x, v, params):
#     return jax.jacobian(lambda x: hconstraints(x.reshape(-1, dim)), 0)(x)

In [None]:
def hconstraints(R, l=jnp.array([1.0])):
    """Holonomic constraint residuals |a_k - b_k|^2 - l_k^2 for the closed chain.

    Pairs A = [R, R[0], R[2n], R[n]] with B = [endpoint1, R, endpoint2, R[3n]],
    giving len(R) + 3 rows:
        0         node 0 to endpoint1
        1..N-1    node i to node i-1 (chain links)
        N         node 0 to node N-1 (closes the loop)
        N+1       node 2n to endpoint2
        N+2       node n to node 3n (cross link)

    Args:
        R: node positions, one row per node.
        l: target lengths. Give one value per constraint (len(R) + 3 entries), or
           any other length, in which case l[0] is used for every constraint.
           (Previously the check compared against len(R), which dropped valid
           per-constraint lengths and broke on an l of length len(R).)

    NOTE(review): `constraints` below uses only the Jacobian, so ``l`` does not
    affect it. In the initial configuration R[0] coincides with endpoint1 and
    R[2n] with endpoint2, so those rows' gradients vanish there. Confirm that
    these anchors are meant to be this weak (the output files are named "*_free").
    """
    n_constraints = len(R) + 3
    if len(l) != n_constraints:
        l = l[0]*jnp.ones(n_constraints)
    lhs = jnp.vstack([R, R[:1], R[2*n:2*n+1], R[n:n+1]])
    rhs = jnp.vstack([endpoint1, R, endpoint2, R[3*n:3*n+1]])
    return jnp.square(lhs - rhs).sum(axis=1) - l**2

@jit
def constraints(x, v, params):
    """Jacobian of ``hconstraints`` with respect to the flattened positions ``x``.

    ``v`` and ``params`` are unused. They presumably exist to match the signature
    that lnn.EL_parts / lnn.accelerationFull expect.
    """
    def residuals(flat):
        return hconstraints(flat.reshape(-1, dim))

    return jax.jacobian(residuals, argnums=0)(x)

## Lagrangian

In [None]:

def lag_link_wrap(r1, r2, v1, v2, sp, params):
    """Evaluate ``lag_link`` for one edge from its endpoint positions and velocities."""
    return lag_link(jnp.hstack((r1, r2)), jnp.hstack((v1, v2)), sp, params)


def identity(x):
    """Return ``x`` unchanged. Used as the (linear) activation for species embeddings."""
    return x


def lag_link(R, V, sp, params):
    """Kinetic and potential energy ``(T, V_)`` of a single link.

    ``R`` and ``V`` are the concatenated endpoint states [r1, r2] and [v1, v2]
    (see lag_link_wrap) and are split in half. ``sp`` is the edge's species vector.

    With ``params is None`` the analytic model is used:
        T  = 0.5*mass*|v_mid|^2 + 0.5*get_I(mass, length)*|w|^2
        V_ = mass*_g*y_mid
    with w = (r1 - r2) x (v1 - v2). Otherwise T and V_ come from the MLPs in
    ``params`` ("ke_l", "ke_a", "pe"), each fed a species embedding
    ("ke_l_emb", "ke_a_emb", "pe_emd") concatenated with the link features.
    """
    r1, r2 = jnp.split(R, 2)
    v1, v2 = jnp.split(V, 2)
    mid_pos = (r1 + r2) / 2
    mid_vel = (v1 + v2) / 2
    dr = r1 - r2
    dv = v1 - v2
    # NOTE(review): this is r x v, not (r x v)/|r|^2 as in get_omega.
    w = jnp.cross(dr, dv)

    if params is None:
        T = 0.5 * mass * jnp.square(mid_vel).sum() + \
            0.5 * get_I(mass, length) * jnp.square(w).sum()
        V_ = mass * _g * mid_pos[1]  # height = y component (index 1)
        return T, V_

    def embed(name):
        # Linear species embedding (identity activation).
        return forward_pass(params[name], sp, activation_fn=identity)

    ang_in = jnp.hstack((embed("ke_a_emb"), w))
    lin_in = jnp.hstack((embed("ke_l_emb"), mid_vel))
    pot_in = jnp.hstack((embed("pe_emd"), mid_pos, dr))

    T = jnp.square(forward_pass(params["ke_l"], lin_in)).sum() + \
        jnp.square(forward_pass(params["ke_a"], ang_in)).sum()
    V_ = forward_pass(params["pe"], pot_in).sum()
    return T, V_


# Batched over edges: endpoint states and species vary per edge, params are shared.
v_lag_link = vmap(lag_link_wrap, in_axes=(0, 0, 0, 0, 0, None))


In [None]:
def update_edge_fn(edges, sent_attributes, received_attributes, globals_):
    """jraph edge update: per-edge ``(T, V)`` from the features at both ends of each edge."""
    src, dst = sent_attributes, received_attributes
    energies = v_lag_link(src["position"], dst["position"],
                          src["velocity"], dst["velocity"],
                          edges["species"], edges["params"])
    return frozendict({"T_V": energies})

def update_node_fn(nodes, sent_attributes, received_attributes, globals_):
    """jraph node update: turn the incoming edges' aggregated (T, V) into L, H, KE and PE.

    Presumably summed by jraph's default aggregation (segment_sum); verify if that setting changes.
    """
    kinetic, potential = received_attributes["T_V"]
    return frozendict({
        "lag": kinetic - potential,
        "ham": kinetic + potential,
        "ke": kinetic,
        "pe": potential,
    })

In [None]:
# Edge step computes per-link (T, V); node step combines them into L, H, KE, PE.
net = GraphNetwork(update_edge_fn=update_edge_fn, update_node_fn=update_node_fn)

In [None]:
# Trained edge-model parameters. loadfile's result is unpacked into two items; the second is unused.
# NOTE(review): .pkl files are presumably pickles, so only load trusted checkpoints.
params_graph, _ = loadfile(
    "chain_model_trained_free.pkl"
)
print(params_graph.keys())

In [None]:

# Example graph for the initial state; every link appears in both directions.
graph = GraphsTuple(
    n_node=jnp.array([len(R)]),
    n_edge=None,
    globals=None,
    senders=jnp.array(sends + recs),
    receivers=jnp.array(recs + sends),
    nodes={
        "position": R,
        "velocity": V,
        "acceleration": 0*V,
    },
    edges={
        "params": params_graph,
        "species": species,
    },
)


In [None]:
@jit
def Lgraph(R, V, params):
    """Total Lagrangian of the chain from the graph network.

    Builds a GraphsTuple over the fixed edge list (each link listed in both
    directions), runs ``net`` and sums the per-node "lag" outputs.
    ``params=None`` selects the analytic per-link model in lag_link.
    """
    node_features = {
        "position": R,
        "velocity": V,
        "acceleration": 0*V,
    }
    edge_features = {
        "params": params,
        "species": species,
    }
    chain_graph = GraphsTuple(
        nodes=node_features,
        edges=edge_features,
        senders=jnp.array(sends + recs),
        receivers=jnp.array(recs + sends),
        globals=None,
        n_node=jnp.array([len(R)]),
        n_edge=None,
    )
    return net(chain_graph).nodes["lag"].sum()


def Lactual(R0, V0, params_graph):
    """Analytic (ground-truth) Lagrangian: ``Lgraph`` evaluated with ``params=None``.

    ``params_graph`` is accepted so the signature matches Lgraph's, but it is ignored.
    """
    del params_graph  # intentionally unused
    return Lgraph(R0, V0, None)


Lgraph(R0, V0, params_graph), Lactual(R0, V0, None)

In [None]:
# Euler-Lagrange building blocks for the learned Lagrangian with the chain constraints.
parts = jit(lnn.EL_parts(N, dim, lagrangian=Lgraph, constraints=constraints))

# First returned component (presumably the mass matrix; confirm against lnn.EL_parts).
M = parts(R, V, params_graph)[0]

In [None]:
# Same decomposition for the analytic Lagrangian (Lactual ignores its params argument).
parts_actual = jit(lnn.EL_parts(N, dim, lagrangian=Lactual, constraints=constraints))
M_actual = parts_actual(R, V, params_graph)[0]

# Rebuild the learned-model parts so this cell is self-contained.
parts = jit(lnn.EL_parts(N, dim, lagrangian=Lgraph, constraints=constraints))
M = parts(R, V, params_graph)[0]

# Scale each matrix by its mean absolute entry before comparing.
M_normalised = M / jnp.abs(M).mean()
M_actual_normalised = M_actual / jnp.abs(M_actual).mean()

# Element-wise percentage error. NOTE(review): any zero entry of
# M_actual_normalised makes this inf (or nan for 0/0) at that position.
error = np.abs(M_normalised - M_actual_normalised) / np.abs(M_actual_normalised) * 100
ms_error = np.square(M_normalised - M_actual_normalised).mean()
print(ms_error)

# NOTE(review): makeMplot is defined in a later cell; run that cell first.
makeMplot(error, label=["g", "h"])

In [None]:
CMAP = ['RdPu'] #['Blues','BuGn','BuPu','GnBu','OrRd','PuBu','PuRd','Purples','RdPu','ocean_r']

from mpl_toolkits.axes_grid1 import make_axes_locatable


def _style_matrix_axes(ax, m):
    """Hide tick labels on ``ax`` and draw white grid lines between the cells of an m x m image."""
    # Major ticks on cell centres, labels blanked.
    ax.set_xticks(np.arange(0, m, 1))
    ax.set_yticks(np.arange(0, m, 1))
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    # Minor ticks on cell borders carry the grid lines.
    ax.set_xticks(np.arange(-.5, m, 1), minor=True)
    ax.set_yticks(np.arange(-.5, m, 1), minor=True)
    ax.grid(which='minor', color='w', linestyle='-', linewidth=2)

    ax.tick_params(axis='x', colors='w', which="both", width=0)
    ax.tick_params(axis='y', colors='w', which="both", width=0)


def makeMplot(M, threshold=0.0, label=None):
    """Show a square matrix ``M`` as a heat map with a colour bar, next to its sparsity pattern.

    Left panel: ``M[::-1]`` (rows reversed). Right panel: ``|M[::-1]| > threshold``.
    ``label`` is forwarded to ``panel`` (presumably the panel letters).
    """
    fig, axs = panel(1, 2, dpi=1000, hshift=0.08, label=label)
    cmap = "RdPu"
    m = len(M)

    plt.sca(axs[0])
    ax = plt.gca()
    image = plt.imshow(M[::-1], cmap=cmap)
    _style_matrix_axes(ax, m)

    # Slim colour-bar axes on the right, driven by the image itself so it shares the
    # image's colormap and normalisation. This avoids mpl.colorbar.ColorbarBase:
    # matplotlib is imported here as `matplotlib`, not `mpl`.
    divider = make_axes_locatable(ax)
    ax_cb = divider.append_axes("right", size="5%", pad=0.05)
    plt.gcf().colorbar(image, cax=ax_cb)

    plt.sca(axs[1])
    ax = plt.gca()
    plt.imshow(jnp.abs(M[::-1]) > threshold, cmap=cmap)
    _style_matrix_axes(ax, m)

    plt.show()

# makeMplot(M)
# makeMplot(M_actual)
    
# makeMplot(M)
# makeMplot(M_actual)

In [None]:
# Constrained accelerations from the learned Lagrangian.
acceleration_fn_graph = jit(lnn.accelerationFull(N, dim, lagrangian=Lgraph, constraints=constraints))

acceleration_fn_graph(R, V, params_graph)


@jit
def force_fn_graph(R, V, params, mass=None):
    """Accelerations, or per-node forces (acceleration * mass) when ``mass`` is given."""
    acc = acceleration_fn_graph(R, V, params)
    if mass is None:
        return acc
    return acc * mass.reshape(-1, 1)


acceleration_fn_graph(R0, V0, params_graph)
masses = mass*jnp.ones(len(R0))



In [None]:
jnp.sum(acceleration_fn_graph(R0, V0, None), axis=0)  # summed over nodes, analytic model

## Forward simulation

In [None]:
params = params_graph


@jit
def forward_sim_long(R, V):
    """Roll out the reference dynamics (analytic model, params=None) from ``(R, V)``."""
    return prediction(R, V, None, force_fn_graph, shift, dt / 10, masses,
                      dR_max=1.0e10, stride=stride, runs=10 * runs)


states_long = forward_sim_long(R0, V0)
save_ovito("chain_exp/chain_T2_free.data", NVEStates(states_long),
           length=10.0, insert_origin=True)


In [None]:
@jit
def forward_sim_graph_model(R, V):
    """Roll out the learned dynamics (module-level ``params``) from ``(R, V)``."""
    return prediction(R, V, params, force_fn_graph, shift, dt / 10, masses,
                      dR_max=1.0e10, stride=stride, runs=10 * runs)


model_states = forward_sim_graph_model(R0, V0)
save_ovito("chain_exp/chain_T2_model_free.data", NVEStates(model_states),
           length=10.0, insert_origin=True)

In [None]:
# savefile(f"chain_exp/T2_states.pkl", {"model":model_states, "states":states_long}, 
#          metadata=dict(dt=dt/10, stride=stride, runs=10*runs, samples=1))

## END

In [None]:
# Reload previously saved rollouts; this replaces the states computed above.
data = loadfile("chain_exp/T2_states.pkl")[0]
model_states, states_long = data["model"], data["states"]

In [None]:
@jit
def Energy(R, V, params):
    """Return ``[L, H, KE, PE]`` for one configuration, each summed over all nodes.

    Same graph layout as Lgraph; ``params=None`` selects the analytic per-link model.
    """
    chain_graph = GraphsTuple(
        n_node=jnp.array([len(R)]),
        n_edge=None,
        globals=None,
        senders=jnp.array(sends + recs),
        receivers=jnp.array(recs + sends),
        nodes={
            "position": R,
            "velocity": V,
            "acceleration": 0*V,
        },
        edges={
            "params": params,
            "species": species,
        },
    )
    node_out = net(chain_graph).nodes
    return jnp.array([node_out[name].sum() for name in ("lag", "ham", "ke", "pe")])

In [None]:
v_Energy = vmap(Energy, (0, 0, None))  # batched over states; params shared

In [None]:
# Both rollouts are scored with the analytic energy model (params=None).
E_pred = v_Energy(model_states.position, model_states.velocity, None)
E = v_Energy(states_long.position, states_long.velocity, None)

plt.plot(E[:, 0]*0, "--")  # dashed zero reference line

# Columns are [L, H, KE, PE]: reference as thick lines, prediction as dashed black lines.
plt.plot(E, label=["L", "H", "KE", "PE"], lw=6)
plt.plot(E_pred, label=["L_pred", "H_pred", "KE_pred", "PE_pred"], ls="--", color="k")

plt.xlabel("Time")
plt.ylabel("Energy")
plt.legend(ncol=2, loc=2, bbox_to_anchor=(1,1))

In [None]:
def std_plot(y_, semilog=True, dt=1.0):
    """Plot the geometric mean of ``y_`` over its first axis, with a +/- 2 sigma band in log space.

    The x values are sample index * ``dt``; ``semilog`` selects a log-scaled y axis.
    """
    log_y = jnp.log(jnp.array(y_))
    mean_ = log_y.mean(axis=0)
    std_ = log_y.std(axis=0)

    x = jnp.array(range(len(mean_)))*dt
    draw = plt.semilogy if semilog else plt.plot
    draw(x, jnp.exp(mean_))
    plt.fill_between(x, jnp.exp(mean_ - 2*std_), jnp.exp(mean_ + 2*std_), alpha=0.5)

In [None]:
def norm(a):
    """Per-row Euclidean norm: flatten all axes after the first, return shape ``(len(a),)``."""
    squared = jnp.square(a)
    return jnp.sqrt(squared.reshape(len(squared), -1).sum(axis=1))


def RelErr(ya, yp):
    """Per-row relative error |ya - yp| / (|ya| + |yp|)."""
    return norm(ya - yp) / (norm(ya) + norm(yp))


def AbsErr(ya, yp):
    """Per-row absolute error |ya - yp|."""
    return norm(ya - yp)

In [None]:
def getall_state_H(model_states,states_long):
    """Relative and absolute Hamiltonian errors for each (predicted, reference) rollout pair.

    Energies use the analytic model (params=None); column 1 of v_Energy is H.
    """
    rel_errors, abs_errors = [], []
    for pred, ref in zip(model_states, states_long):
        H_pred = v_Energy(pred.position, pred.velocity, None)[:, 1]
        H_ref = v_Energy(ref.position, ref.velocity, None)[:, 1]
        rel_errors.append(RelErr(H_pred, H_ref))
        abs_errors.append(AbsErr(H_pred, H_ref))
    return jnp.array(rel_errors), jnp.array(abs_errors)


In [None]:
RH_error, AH_error = getall_state_H(model_states=[model_states], states_long=[states_long])  # single rollout pair

In [None]:

# x values = saved-state index * stride * dt/10, presumably the simulated time.
# NOTE(review): the axis label says "Time step" although the values look like times; confirm wording.
std_plot(RH_error, dt=stride * dt / 10)
plt.xlabel("Time step")
plt.ylabel("Energy error")

plt.savefig("chain_exp/T2_RH_error_new.png", dpi=600)
savefile("chain_exp/T2_RH_error_new.pkl", RH_error)

In [None]:

# Absolute Hamiltonian error; same x scaling as the relative-error plot.
std_plot(AH_error, dt=stride * dt / 10)
plt.xlabel("Time step")
plt.ylabel("Energy error")

plt.savefig("chain_exp/T2_AH_error_new.png", dpi=600)
savefile("chain_exp/T2_AH_error_new.pkl", AH_error)

In [None]:
def getall_state_Z(model_states, states_long):
    """Relative and absolute position (rollout) errors for each (predicted, reference) pair."""
    rel_errors, abs_errors = [], []
    for pred, ref in zip(model_states, states_long):
        rel_errors.append(RelErr(pred.position, ref.position))
        abs_errors.append(AbsErr(pred.position, ref.position))
    return jnp.array(rel_errors), jnp.array(abs_errors)

In [None]:
RZ_error, AZ_error = getall_state_Z(model_states=[model_states], states_long=[states_long])  # single rollout pair


In [None]:

# Relative rollout (position) error over time.
std_plot(RZ_error, dt=stride * dt / 10)
plt.xlabel("Time step")
plt.ylabel("Rollout error")

plt.savefig("chain_exp/T2_RZ_error_new.png", dpi=600)

savefile("chain_exp/T2_RZ_error_new.pkl", RZ_error)

In [None]:

# Absolute rollout (position) error over time.
std_plot(AZ_error, dt=stride * dt / 10)
plt.xlabel("Time step")
plt.ylabel("Rollout error")

plt.savefig("chain_exp/T2_AZ_error_new.png", dpi=600)

savefile("chain_exp/T2_AZ_error_new.pkl", AZ_error)

In [None]:
# NOTE(review): make_plot is defined in a later cell; run that cell first.
make_plot(states=states_long)
make_plot(states=model_states)

In [None]:

@jit
def potE(R, params=None):
    """Potential energy summed over midpoints of consecutive rows of ``R``.

    Analytic form: mass * _g * y_mid per midpoint. With ``params``, an MLP is applied to each midpoint.
    NOTE(review): pairing only consecutive rows skips the closing (N-1, 0) link and
    the (n, 3n) cross link that the graph model includes.
    """
    midpoints = (R[1:, :] + R[:-1, :]) / 2
    if params is not None:
        return vmap(forward_pass, in_axes=(None, 0))(params, midpoints).sum()
    return (mass * _g * midpoints[:, 1]).sum()


potE(R)


@jit
def kinE_l(R, V, params=None):
    """Linear kinetic energy from midpoint velocities of consecutive rows (``R`` is unused).

    Analytic form: 0.5 * mass * sum |v_mid|^2. With ``params``: sum of squared MLP outputs.
    """
    mid_vel = (V[1:, :] + V[:-1, :]) / 2
    if params is not None:
        return jnp.square(vmap(forward_pass, in_axes=(None, 0))(params, mid_vel)).sum()
    return 0.5 * mass * jnp.square(mid_vel).sum()


kinE_l(R, V)


def kinE_a(R, V, params=None):
    """Rotational kinetic energy of the links between consecutive rows.

    omega comes from get_omega on each link's relative position and velocity.
    Analytic form: 0.5 * get_I(mass, length) * sum |omega|^2; with ``params``, the
    sum of squared MLP outputs.
    """
    dr = (R[1:, :] - R[:-1, :])
    dv = (V[1:, :] - V[:-1, :])
    omega = vmap(get_omega)(dr, dv)
    if params is not None:
        return jnp.square(vmap(forward_pass, in_axes=(None, 0))(params, omega)).sum()
    return 0.5 * get_I(mass, length) * jnp.square(omega).sum()


kinE_a(R, V)


@jit
def kinE_(R, V, params=None):
    """Total kinetic energy: linear part (kinE_l) plus rotational part (kinE_a).

    ``params`` is a dict with keys "l" and "a", each None (analytic) or MLP params.
    ``params=None``, the advertised default, now means both parts are analytic;
    previously it failed on ``None["l"]``.
    """
    if params is None:
        params = {"l": None, "a": None}
    return kinE_l(R, V, params=params["l"]) + kinE_a(R, V, params=params["a"])


@jit
def kinE(R, V, params=None):
    """Total kinetic energy; ``params=None`` uses the analytic form for both parts."""
    split_params = {"l": None, "a": None} if params is None else params
    return kinE_(R, V, params=split_params)

def make_plot(states):
    """Plot energy components, total energy and step size along a trajectory.

    Every energy is shown relative to its value in the first state. Panels:
    energies (with a zero line), total energy KE + PE, and the time difference
    between consecutive states on a log scale.
    """
    pe, ke, ke_l, ke_a, time, dt = [], [], [], [], [], []
    previous_time = 0.0
    for state in NVEStates(states):
        pos, vel = state.position, state.velocity
        pe.append(potE(pos))
        ke.append(kinE(pos, vel))
        ke_l.append(kinE_l(pos, vel))
        ke_a.append(kinE_a(pos, vel))
        time.append(state.time)
        dt.append(state.time - previous_time)
        previous_time = state.time

    def relative(values):
        # Shift a series so that its first entry is zero.
        arr = jnp.array(values)
        return arr - arr[0]

    PE = relative(pe)
    KE = relative(ke)
    KE_l = relative(ke_l)
    KE_a = relative(ke_a)
    TE = KE + PE

    fig, ax = plt.subplots(1, 3, figsize=(18, 6), sharex=True)

    plt.sca(ax[0])
    plt.plot(time, 0*TE, "-k", label="x-axis", alpha=0.5)
    for series, name in ((PE, "PE"), (KE, "KE"), (TE, "TE")):
        plt.plot(time, series, label=name)
    for series, name in ((KE_l, "KE_l"), (KE_a, "KE_a")):
        plt.plot(time, series, "--", label=name)
    plt.legend()

    plt.sca(ax[1])
    plt.plot(time, TE)
    plt.ylabel("Hamiltonian")

    plt.sca(ax[2])
    plt.semilogy(time, dt)
    plt.ylabel("dt")

    plt.show()
