## Imports

In [None]:
import jax
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import src
from frozendict import frozendict
from jax import grad, jit
from jax import numpy as jnp
from jax import value_and_grad, vmap
from jax.config import config
from jax.experimental import optimizers
from jraph import GraphNetwork, GraphsTuple
from src import lnn, utils
from src.io import *
from src.lnn import accelerationFull
from src.md import displacement, prediction, shift
from src.models import forward_pass, initialize_mlp
from src.nve import NVEState, NVEStates
from shadow.plot import *
from shadow.font import *
set_font_family(family="Arial")

# Make JAX use 64-bit floats/ints by default.
config.update("jax_enable_x64", True)

# Register frozendict as a JAX pytree: its values are the leaves and its keys
# are carried as static aux data, so JAX transformations can traverse it.
# Re-running this cell registers the type a second time, which JAX rejects
# with a ValueError; only that case is ignored so real errors still surface.
try:
    jax.tree_util.register_pytree_node(
        frozendict,
        flatten_func=lambda s: (tuple(s.values()), tuple(s.keys())),
        unflatten_func=lambda k, xs: frozendict(zip(k, xs)),
    )
except ValueError:
    pass

## Config

In [None]:
# Chain geometry: `n` sets the chain size and link_size = LENGTH / n.
n = 8
LENGTH = 2
link_size = LENGTH / n


def rot(x, t):
    """Rotate 3-component vector(s) `x` about the z-axis by `t` degrees.

    `x` is either a single 3-vector or a (k, 3) array of row vectors; the
    result has the same shape as `x`.
    """
    radians = t * jnp.pi / 180
    c, s = jnp.cos(radians), jnp.sin(radians)
    rotation = jnp.array([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ])
    # Rotate column vectors, then transpose back to the row layout of `x`.
    return (rotation @ x.T).T

def getRV(dim=2):
    """Build a random chain of 2*n points and matching zero velocities.

    Starting at the origin, each next point is placed one unit away from the
    previous one: the step is (0, -1, 0) rotated about the z-axis by
    ``np.random.randn() * 45`` degrees, drawn from NumPy's global RNG.
    The chain length comes from the module-level ``n``.

    NOTE(review): steps are unit length; ``link_size`` is not applied here.
    Confirm this is intended.

    Parameters
    ----------
    dim : int
        Number of leading coordinate columns (x, y, z order) to keep.
        The default of 2 drops z; 3 returns the full coordinates.
        (Previously this argument was overwritten and always treated as 2.)

    Returns
    -------
    R : jax array of shape (2*n, dim)
        Point positions.
    V : jax array shaped like R
        All-zero velocities.
    """
    points = [jnp.array([[0.0, 0, 0]])]
    for _ in range(2*n - 1):
        step = rot(jnp.array([[0.0, -1.0, 0.0]]), np.random.randn()*45)
        # append (not `+=`): `list += array` splices the array's rows in.
        points.append(points[-1] + step)

    R = jnp.vstack(points)[:, :dim]
    return R, 0*R

#     R1 = link_size * \
#         jnp.vstack([jnp.zeros(n+1), -jnp.arange(0, n+1), jnp.zeros(n+1)]).T
#     R1 = rot(R1, 45)

#     R2 = link_size * \
#         jnp.vstack([jnp.zeros(n) + 0, -n+1 + jnp.arange(0, n), jnp.zeros(n)]).T
#     R2 = np.array(rot(R2, -45))
#     R2[:, 0] += LENGTH*np.sqrt(2)

#     R = jnp.vstack((R1, R2))
#     R = rot(R, 0)

# Initial state: positions R and (zero) velocities V from getRV().
R, V = getRV()
R0, V0 = R, V
N, dim = R.shape  # number of points, spatial dimension


# Scalar simulation parameters.
mass, length, _g = 1.0, 1.0, 10.0
dt = 1.0e-5
stride = 100
runs = 100


# Chain connectivity: edge i connects point i (sender) to point i + 1 (receiver).
sends = list(range(len(R) - 1))
recs = list(range(1, len(R)))

key = jax.random.PRNGKey(0)  # fixed seed for reproducible JAX randomness

nspecies = 1  # number of one-hot classes used by OHE


def OHE(x):
    """One-hot encode integer labels `x` into `nspecies` classes.

    ``nspecies`` is read at call time. Labels outside ``[0, nspecies)``
    map to all-zero rows (``jax.nn.one_hot`` behaviour).
    """
    num_classes = nspecies
    return jax.nn.one_hot(x, num_classes)

# NOTE(review): every label here is 1, but with nspecies == 1 the only valid
# class is 0. jax.nn.one_hot returns all-zero rows for out-of-range labels,
# so `species` is all zeros. Its length is len(sends) + len(recs), i.e. twice
# the number of edges rather than the number of points. Confirm both are
# intended.
species = OHE(jnp.ones(len(sends) + len(recs)))


R.shape

## End of configuration

In [None]:
def std_plot(y_, semilog=True, dt=1.0, name="label"):
    """Plot the geometric mean of `y_` over its first axis with a 2-sigma band.

    Statistics are computed on log(y_), so entries are expected to be
    positive. The layout is presumably (runs, frames) -- TODO confirm.
    The curve is exp(mean(log y_)); the shaded band spans
    exp(mean - 2*std) .. exp(mean + 2*std).

    Parameters
    ----------
    y_ : array-like
        Values to summarise; reduced along axis 0.
    semilog : bool
        Use ``plt.semilogy`` (log y axis) when True, ``plt.plot`` otherwise.
    dt : float
        x-axis spacing: point i is drawn at ``i * dt``.
    name : str
        Legend label of the mean curve.
    """
    log_y = jnp.log(jnp.array(y_))
    center = log_y.mean(axis=0)
    spread = 2 * log_y.std(axis=0)

    t = jnp.array(range(len(center))) * dt
    draw = plt.semilogy if semilog else plt.plot
    draw(t, jnp.exp(center), label=name)
    plt.fill_between(t, jnp.exp(center - spread), jnp.exp(center + spread), alpha=0.5)

In [None]:
def norm(a):
    """Euclidean norm of each entry along the first axis of `a`.

    All trailing axes are flattened, so a (k, ...) array gives k norms and a
    1-D array gives element-wise absolute values.
    """
    squared = jnp.square(a)
    flat = squared.reshape(len(squared), -1)
    return jnp.sqrt(jnp.sum(flat, axis=1))

def RelErr(ya, yp):
    """Symmetric relative error per leading entry: |ya - yp| / (|ya| + |yp|).

    Values lie in [0, 1]; entries where both inputs are zero give nan (0/0).
    """
    numerator = norm(ya - yp)
    denominator = norm(ya) + norm(yp)
    return numerator / denominator


In [None]:
n=2
data, metadata = loadfile(f"chain_exp/n{n}_states.pkl")
dt = metadata["dt"]
stride = metadata["stride"]

xlims = [0, 0.1]
which_error = "R"

# Number of saved frames kept per trajectory (frames are stride*dt apart).
if which_error=="R":
    npoints = len(jnp.arange(0, 1.0, stride*dt))
else:
    npoints = len(jnp.arange(0, xlims[1], stride*dt))

# Chain sizes to compare; the figure gets one column per size.
SIZES = [2, 4, 8, 100]

Hs = {}
Zs = {}
NAME = {"syst": "LGNN", "syst_gns": "GNS"}
for syst in ["syst", "syst_gns"]:
    Hs[syst] = {}
    Zs[syst] = {}
    for n in SIZES:
        H_error, _ = loadfile(f"chain_exp/n{n}_{which_error}H_error{syst[4:]}_new.pkl")
        Z_error, _ = loadfile(f"chain_exp/n{n}_{which_error}Z_error{syst[4:]}_new.pkl")
        # Tolerate a None entry: the plotting loop below skips None panels,
        # so don't fail here by slicing None.
        Hs[syst][n] = None if H_error is None else H_error[:, :npoints]
        Zs[syst][n] = None if Z_error is None else Z_error[:, :npoints]


# Top row: energy error; bottom row: rollout error; one column per chain size.
# A fixed 3-column grid has only 6 panels, too few for 2 * len(SIZES) plots.
fig, axs = plt.subplots(nrows=2, ncols=len(SIZES),
                        sharey='row', figsize=(6*len(SIZES), 6*2), dpi=100)

for syst in ["syst", "syst_gns"]:
    for col, k in enumerate(Hs[syst]):
        H_error = Hs[syst][k]
        ax = axs[0, col]
        plt.sca(ax)
        if H_error is not None:
            std_plot(H_error, dt=stride*dt)
            plt.title(f"{2*k} links")
        if which_error=="R":
            plt.xlim(xlims)

    for col, k in enumerate(Zs[syst]):
        Z_error = Zs[syst][k]
        ax = axs[1, col]
        plt.sca(ax)
        if Z_error is not None:
            std_plot(Z_error, dt=stride*dt, name=NAME[syst])
        if which_error=="R":
            plt.xlim(xlims)
        plt.ylim([1.0e-10, 1.0e10])
plt.sca(axs[1, 0])
plt.legend(bbox_to_anchor=(1.2, -0.2), loc=2, ncol=2)

axs[0,0].set_ylabel("Energy error")
axs[1,0].set_ylabel("Rollout error")
axs[1,1].set_xlabel("Time")
savefig(f"notebooks/Energy_rollout_{which_error}_{xlims[1]}_error.png", dpi=600)

In [None]:
# Load the T1 states file; only the second return value (`_`), which provides
# "dt" and "stride", is used below.
H_error, _ = loadfile(f"chain_exp/T1_states.pkl")
# Per-system time step. GNS data is presumably saved at 10x the LGNN step --
# TODO confirm against the generating script.
dt  = dict(syst = _["dt"], syst_gns=10*_["dt"])
stride  = _["stride"]
    
xlims = [-0.01, 1.0]
which_error = "R"

Hs = {}
Zs = {}
NAME = {"syst": "LGNN", "syst_gns": "GNS"}

for syst in ["syst", "syst_gns"]:
    Hs[syst] = {}
    Zs[syst] = {}
    # NOTE: this loop rebinds the module-level `n` (final value 3); later
    # cells build file names from `n`.
    for n in [1, 2, 3]:
        H_error, _ = loadfile(f"chain_exp/T{n}_H_error{syst[4:]}.pkl")
        Z_error, _ = loadfile(f"chain_exp/T{n}_Z_error{syst[4:]}.pkl")
        
        # Number of frames to keep, using this system's own time step.
        if which_error=="R":
            npoints = len(jnp.arange(0, 1.0, stride*dt[syst]))
        else:
            npoints = len(jnp.arange(0, xlims[1], stride*dt[syst]))
        
        # Store (dt, stride, errors) so the plot loop uses each system's own x scale.
        Hs[syst][n] = dt[syst], stride, H_error[:, :npoints]
        Zs[syst][n] = dt[syst], stride, Z_error[:, :npoints]
    


# 2 x 3 grid: row 0 holds the three energy-error panels, row 1 the three
# rollout-error panels (filled in row-major order via iter(axs.flatten())).
fig, axs = plt.subplots(nrows=2, ncols=3,
                                sharey='row', figsize=(6*3, 6*2), dpi=100)

LABEL = {1:"D", 2:"E", 3:"F"}

for syst in ["syst", "syst_gns"]:
    ax_ = iter(axs.flatten())

    for k in Hs[syst]:
        # NOTE: this unpacking rebinds the module-level `dt` (previously the
        # per-system dict) and `stride` to scalars. Later cells compute
        # `stride*dt` with whatever values are left here, so keep this order.
        dt, stride, H_error = Hs[syst][k]
        ax = next(ax_)
        plt.sca(ax)
        std_plot(H_error, dt=stride*dt, name=NAME[syst])
        plt.title(f"{LABEL[k]}")
        plt.xlim(xlims)
        # print(k, syst, H_error)
        
    for k in Zs[syst]:
        dt, stride, Z_error = Zs[syst][k]
        ax = next(ax_)
        plt.sca(ax)
        std_plot(Z_error, dt=stride*dt, name=NAME[syst])
        plt.xlim(xlims)
        
plt.sca(axs[1, 0])
plt.legend(bbox_to_anchor=(0.8, -0.2), loc=2, ncol=2)


axs[0,0].set_ylabel("Energy error")
axs[1,0].set_ylabel("Rollout error")
# Shared x label placed in axes coordinates of the bottom-middle panel.
plt.text(-0.15, -0.2, "Time", transform=axs[1,1].transAxes)
savefig(f"notebooks/Energy_rollout_error_T_{xlims[1]}_.png", dpi=600)

In [None]:
# Energy error for chain size `n`. `n`, `stride` and `dt` are whatever the
# previous cells left behind. The x axis is frame index * stride * dt, i.e.
# time, matching the "Time" label used in the figures above.
H_error, _ = loadfile(f"chain_exp/n{n}_H_error.pkl")
std_plot(H_error, dt=stride*dt)
plt.ylabel("Energy error")
plt.xlabel("Time")


In [None]:
# Rollout (trajectory) error for chain size `n`. Z_error is the rollout error,
# as in the "Rollout error" rows of the figures above, not the energy error.
# The x axis is frame index * stride * dt, i.e. time.
Z_error, _ = loadfile(f"chain_exp/n{n}_Z_error.pkl")
std_plot(Z_error, dt=stride*dt)
plt.ylabel("Rollout error")
plt.xlabel("Time")
