In [None]:
################################################
################## IMPORT ######################
################################################

import json
import sys
import os
from datetime import datetime
from functools import partial, wraps
from statistics import mode  # NOTE(review): not used in the visible cells; looks like an IDE auto-import

import fire
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit, random, value_and_grad, vmap
from jax.experimental import optimizers
from jax_md import space
from pyexpat import model  # NOTE(review): not used in the visible cells; looks like an IDE auto-import
from shadow.plot import *
import time

# Put the repository root on the import path *before* importing project
# packages (psystems, src) that are presumably resolved relative to it.
MAINPATH = ".."  # nopep8
sys.path.append(MAINPATH)  # nopep8

# from psystems.nsprings import (chain, edge_order, get_connections,
#                                get_fully_connected_senders_and_receivers,
#                                get_fully_edge_order, get_init)
from psystems.nbody import (get_fully_connected_senders_and_receivers,
                            get_fully_edge_order, get_init_conf)

import jraph
import src
from jax.config import config
# from src.graph import *
from src.graph_interpretability import *
from src.md import *
from src.models import MSE, initialize_mlp
from src.nve import NVEStates, nve
from src.utils import *
from src.hamiltonian import *

# Double precision everywhere; raise as soon as any operation produces a NaN.
config.update("jax_enable_x64", True)
config.update("jax_debug_nans", True)

def namestr(obj, namespace):
    """Return every name in ``namespace`` that is bound to exactly ``obj``.

    Matching is by identity (``is``), not equality.
    """
    return [name for name, value in namespace.items() if value is obj]


def pprint(*args, namespace=globals()):
    """Print each argument as ``name: value`` using its name in ``namespace``.

    Because lookup is by identity, shared objects such as small ints or None
    resolve to whichever matching name comes first in ``namespace``. Arguments
    not bound to any name (e.g. expressions) are printed as ``<unnamed>``
    instead of raising IndexError.
    """
    for arg in args:
        names = namestr(arg, namespace)
        label = names[0] if names else "<unnamed>"
        print(f"{label}: {arg}")

def rcParams_Set():
    """Use regular (non-bold) weight for text and axis labels.

    Note: the next cell redefines ``rcParams_Set`` with the full plotting style.
    """
    for key in ('font.weight', "axes.labelweight"):
        plt.rcParams[key] = 'normal'

rcParams_Set()

In [None]:
def rcParams_Set():
    """Apply the notebook's figure style to matplotlib's global rcParams.

    The style sets thick axes and lines, CMU Serif at size 20 with large labels,
    and inward major/minor ticks on all four sides without labels on the top/right.
    """
    style = [
        ('axes.linewidth', 2.5),
        ('lines.linewidth', 2.5),
        ("font.family", 'CMU Serif'),
        ('font.weight', 'normal'),
        ("axes.labelweight", "normal"),
        ('axes.unicode_minus', False),
        ('font.size', 20),
        ('axes.labelsize', 'large'),
        ('xtick.labelsize', 'large'),
        ('ytick.labelsize', 'large'),
        ('xtick.direction', 'in'),
        ('ytick.direction', 'in'),
        ('xtick.top', 'True'),
        ('ytick.right', 'True'),
        ('ytick.labelright', 'False'),
        ('xtick.labeltop', 'False'),
        ('xtick.major.size', 10.0),
        ('xtick.minor.size', 5.0),
        ('ytick.major.size', 10.0),
        ('ytick.minor.size', 5.0),
        ('xtick.major.width', 2.5),
        ('ytick.major.width', 2.5),
        ('xtick.minor.visible', 'True'),
        ('ytick.minor.visible', 'True'),
        ('xtick.minor.width', 2.5),
        ('ytick.minor.width', 2.5),
    ]
    for key, value in style:
        plt.rcParams[key] = value

rcParams_Set()

In [None]:
# ---- system & integration ---------------------------------------------------
N, dim = 4, 3            # particles / spatial dims (re-read from the dataset in the CONFIG cell)
dt = 1.0e-3              # time step used to build the integration time grid
stride = 100             # integrator steps between saved frames
runs = 100               # saved frames per simulated trajectory
maxtraj = 100
ifdrag = 0               # 0: no drag term, 1: drag term
seed = 42                # seeds numpy and the JAX PRNG key

# ---- model / data selection -------------------------------------------------
useN = 4                 # system size of the trained model that is loaded
trainm = 1               # part of the output-file suffix
withdata = None
datapoints = 100         # appended to `randfilename`
grid = False
rname = 0
ifDataEfficiency = 0     # 1: data-efficiency run (output under ../data-efficiency)
if_noisy_data = 0        # 1: output under ../noisy_data

# ---- output / plotting ------------------------------------------------------
saveovito = 1            # export OVITO .data files for the first experiment
semilog = 1
plotthings = False
redo = 0                 # truthy: fileexist() always reports False

In [None]:
# def main(N=4, dim=3, dt=1.0e-3, stride=100, useN=4, withdata=None, datapoints=100, grid=False, ifdrag=0, seed=42, rname=0, saveovito=1, trainm=1, runs=100, semilog=1, maxtraj=100, plotthings=False, redo=0, ifDataEfficiency = 0, if_noisy_data=0):
if ifDataEfficiency == 1:
    # NOTE(review): under a Jupyter kernel sys.argv[1] is typically the kernel's
    # "-f" flag, so this branch presumably expects script-style invocation — confirm.
    data_points = int(sys.argv[1])
    batch_size = int(data_points/100)

print("Configs: ")
pprint(dt, ifdrag, namespace=locals())

# System and model tags that appear in every output path.
PSYS = "{}-body".format(N)
TAG = "hgnn"

# Root directory for this run's outputs.
out_dir = ("../data-efficiency" if ifDataEfficiency == 1
           else "../noisy_data" if if_noisy_data == 1
           else "../results")

# Timestamped run name (not referenced in the visible cells).
randfilename = "{}_{}".format(datetime.now().strftime("%m-%d-%Y_%H-%M-%S"), datapoints)

def _filename(name, tag=TAG, trained=None):
    """Resolve ``name`` to an output path, create its directory, print and return it.

    The suffix ``_{ifdrag}`` (when ``tag == "data"``) or ``_{ifdrag}_{trainm}``
    (otherwise) is inserted before the extension.

    - Dataset files go to ``../results/{PSYS}-data/2_test/``.
    - Everything else goes to ``{out_dir}/{system}-{tag}/{rstring}/``. Here
      ``system`` is ``PSYS``, or ``{trained}-<PSYS suffix>`` when ``trained``
      is given.
    - ``rstring`` is ``"2_{data_points}"`` in data-efficiency mode, else ``"0"``.

    Any ``//`` in the final path is collapsed to ``/``.
    """
    suffix = f"_{ifdrag}." if tag == "data" else f"_{ifdrag}_{trainm}."
    stem, _, ext = name.rpartition(".")
    name = stem + suffix + ext

    if ifDataEfficiency == 1:
        rstring = "2_" + str(data_points)
    else:
        rstring = "2" if tag == "data" else "0"

    if tag == "data":
        filename_prefix = f"../results/{PSYS}-{tag}/2_test/"
    else:
        system = PSYS if trained is None else f"{trained}-{PSYS.split('-')[1]}"
        filename_prefix = f"{out_dir}/{system}-{tag}/{rstring}/"

    file = f"{filename_prefix}/{name}"
    os.makedirs(os.path.dirname(file), exist_ok=True)
    filename = file.replace("//", "/")
    print("===", filename, "===")
    return filename

def OUT(f):
    """Wrap an I/O function so its first argument is resolved through ``_filename``.

    The wrapper accepts extra keyword arguments ``tag`` (default: the current
    ``TAG``) and ``trained``, which are forwarded to ``_filename``. All other
    arguments are passed through to ``f``.
    """
    @wraps(f)
    def func(file, *args, tag=TAG, trained=None, **kwargs):
        resolved = _filename(file, tag=tag, trained=trained)
        return f(resolved, *args, **kwargs)
    return func

def _fileexist(f):
    if redo:
        return False
    else:
        return os.path.isfile(f)

# Path-resolving versions of the project's model/file I/O helpers: their first
# argument is a bare file name that OUT routes through _filename.
loadmodel, savemodel = OUT(src.models.loadmodel), OUT(src.models.savemodel)

loadfile, savefile = OUT(src.io.loadfile), OUT(src.io.savefile)
save_ovito = OUT(src.io.save_ovito)
fileexist = OUT(_fileexist)



In [None]:
################################################
################## CONFIG ######################
################################################
np.random.seed(seed)
key = random.PRNGKey(seed)

# Ground-truth dataset. loadfile(...)[0] is the stored sequence of states; each
# state unpacks into (z, zdot), with z = positions stacked on momenta.
dataset_states = loadfile("model_states.pkl", tag="data")[0]
z_out, zdot_out = dataset_states[0]
xout, pout = jnp.split(z_out, 2, axis=1)

R = xout[0]
V = pout[0]

print(f"Total number of training data points: {len(dataset_states)}x{z_out.shape}")

# Take the system size from the data itself, then build the interaction graph
# from that size. Building it from the hard-coded config N before this point
# could give senders/receivers that do not match the loaded system.
N, dim = xout.shape[-2:]
senders, receivers = get_fully_connected_senders_and_receivers(N)
eorder = get_fully_edge_order(N)

species = jnp.zeros(N, dtype=int)
masses = jnp.ones(N)
masses = jnp.ones(N)



In [None]:
################################################
################## SYSTEM ######################
################################################

# def pot_energy_orig(x):
#     dr = jnp.square(x[senders, :] - x[receivers, :]).sum(axis=1)
#     return jax.vmap(partial(src.hamiltonian.SPRING, stiffness=1.0, length=1.0))(dr).sum()

def pot_energy_orig(x):
    """Ground-truth potential energy of configuration ``x`` (one row per particle).

    Evaluates ``lnn.GRAVITATIONAL`` (with ``Gc=1``) on the distance of every
    (sender, receiver) edge, sums the results and halves the total.

    NOTE(review): the halving presumably compensates for each pair appearing in
    both directions in the fully connected edge list — confirm.
    NOTE(review): ``lnn`` is not imported in the visible cells; presumably it
    comes in through one of the ``from src... import *`` lines — confirm.
    """
    separation = x[senders, :] - x[receivers, :]
    distance = jnp.sqrt(jnp.square(separation).sum(axis=1))
    pair_energy = vmap(partial(lnn.GRAVITATIONAL, Gc=1))
    return pair_energy(distance).sum() / 2


# Kinetic energy with the per-particle masses bound in.
kin_energy = partial(src.hamiltonian._T, mass=masses)


def Hactual(x, p, params):
    """Ground-truth Hamiltonian ``kin_energy(p) + pot_energy_orig(x)``; ``params`` is ignored."""
    return kin_energy(p) + pot_energy_orig(x)

# def phi(x):
#     X = jnp.vstack([x[:1, :]*0, x])
#     return jnp.square(X[:-1, :] - X[1:, :]).sum(axis=1) - 1.0

# constraints = get_constraints(N, dim, phi)

def external_force(x, v, params):
    """Constant external force: -1.0 on component 1 of particle 1, zero elsewhere.

    Returns the force flattened to a column vector of length ``R.size``.
    The functional ``.at[...].set`` update replaces ``jax.ops.index_update``,
    which was deprecated and later removed from JAX.

    NOTE(review): the shape comes from the global ``R`` rather than from ``x``.
    This function is not referenced in the visible notebook.
    """
    F = jnp.zeros_like(R).at[1, 1].set(-1.0)
    return F.reshape(-1, 1)

# Ground-truth drag term, selected by the `ifdrag` config flag.
if ifdrag == 0:
    print("Drag: 0.0")

    def drag(x, p, params):
        """No drag."""
        return 0.0
elif ifdrag == 1:
    print("Drag: -0.1*v")

    def drag(x, p, params):
        """Linear drag ``-0.1 * p`` returned as a column vector."""
        return -0.1*p.reshape(-1, 1)
else:
    # Fail fast. Otherwise `drag` would be undefined, or silently left over from
    # an earlier execution, when get_zdot_lambda is called below.
    raise ValueError(f"ifdrag must be 0 or 1, got {ifdrag!r}")

# Ground-truth equations of motion from Hactual and the selected drag (no constraints).
zdot, lamda_force = get_zdot_lambda(
    N, dim, constraints=None, drag=drag, hamiltonian=Hactual)


def zdot_func(z, t, params):
    """ODE right-hand side: split the stacked state ``z`` into (x, p) halves and
    evaluate ``zdot``. ``t`` is unused."""
    return zdot(*jnp.split(z, 2), params)


def z0(x, p):
    """Stack positions on top of momenta into one state array."""
    return jnp.vstack((x, p))

def get_forward_sim(params=None, zdot_func=None, runs=10):
    """Return ``fn(R, V)``, which integrates ``zdot_func`` from (R, V) and keeps
    every ``stride``-th frame, giving ``runs`` frames.

    The time grid is ``jnp.linspace(0, runs*stride*dt, runs*stride)``. linspace
    includes the end point, so the actual spacing is
    ``runs*stride*dt / (runs*stride - 1)``, slightly larger than ``dt``.
    ``stride`` and ``dt`` are read from the notebook globals each time ``fn``
    runs.
    """
    def fn(R, V):
        n_points = runs*stride
        times = jnp.linspace(0.0, n_points*dt, n_points)
        trajectory = ode.odeint(zdot_func, z0(R, V), times, params)
        return trajectory[::stride]
    return fn


# Ground-truth simulator over maxtraj*runs saved frames (built but not run here).
sim_orig = get_forward_sim(
    params=None, zdot_func=zdot_func, runs=maxtraj*runs)
# z_out = sim_orig(R, V)

# if fileexist("gt_trajectories.pkl"):
#     print("Loading from saved.")
#     full_traj, metadata = loadfile("gt_trajectories.pkl")
#     full_traj = NVEStates(full_traj)
#     if metadata["key"] != f"maxtraj={maxtraj}, runs={runs}":
#         print("Metadata doesnot match.")
#         full_traj = NVEStates(simGT())
# else:
#     full_traj = NVEStates(simGT())


In [None]:
################################################
################### ML Model ###################
################################################

def H_energy_fn(params, graph):
    """Learned total energy of ``graph``.

    Sums the kinetic terms, the edge potential terms and the node potential
    terms returned by ``cal_graph`` (called with ``useT=True``).
    """
    _, pe_edges, pe_nodes, kinetic, _ = cal_graph(params, graph, eorder=eorder, useT=True)
    return kinetic.sum() + pe_edges.sum() + pe_nodes.sum()
    

# Module-level graph of the initial configuration (R, V, species) on the edge
# list built in the CONFIG cell. energy_fn and its variants build their own
# graphs and do not use this one.
state_graph = jraph.GraphsTuple(
    nodes={"position": R, "velocity": V, "type": species},
    edges={},
    senders=senders,
    receivers=receivers,
    n_node=jnp.array([N]),
    n_edge=jnp.array([senders.shape[0]]),
    globals={},
)

def energy_fn(species):
    """Build ``apply(R, V, params)``, which evaluates the learned Hamiltonian.

    A template graph is built once from the current global ``R``/``V``, the
    given ``species`` and the global edge list. Each call puts the supplied
    positions/velocities into a *new* node dict on a copy of the template
    (``GraphsTuple._replace``). The shared template is never mutated, so no
    traced values are stored in Python state that outlives the call.
    """
    # senders, receivers = [np.array(i)
    #                       for i in Spring_connections(R.shape[0])]
    template = jraph.GraphsTuple(nodes={
        "position": R,
        "velocity": V,
        "type": species
    },
        edges={},
        senders=senders,
        receivers=receivers,
        n_node=jnp.array([R.shape[0]]),
        n_edge=jnp.array([senders.shape[0]]),
        globals={})

    def apply(R, V, params):
        nodes = {**template.nodes, "position": R, "velocity": V}
        return H_energy_fn(params, template._replace(nodes=nodes))
    return apply

apply_fn = energy_fn(species)
# Energy batched over a leading frame axis of positions and velocities, with
# shared parameters. The previous in_axes=(None, 0) had two entries for the
# three-argument apply_fn, which vmap rejects when called.
v_apply_fn = vmap(apply_fn, in_axes=(0, 0, None))

def Hmodel(x, v, params):
    """Learned Hamiltonian at (x, v), using the ``"H"`` entry of ``params``."""
    return apply_fn(x, v, params["H"])

def nndrag(v, params):
    """Learned drag ``-|MLP(v)| * v``.

    The MLP is ``models.forward_pass`` with SquarePlus activation, applied to
    the flattened ``v``.
    """
    magnitude = jnp.abs(models.forward_pass(params, v.reshape(-1), activation_fn=models.SquarePlus))
    return -magnitude * v

# Drag term used by the learned model, selected by the `ifdrag` config flag.
if ifdrag == 0:
    print("Drag: 0.0")

    def drag(x, v, params):
        """No drag."""
        return 0.0
elif ifdrag == 1:
    print("Drag: nn")

    def drag(x, v, params):
        """Learned drag: ``nndrag`` mapped over the flattened ``v``, with
        ``params["drag"]``, returned as a column vector."""
        return vmap(nndrag, in_axes=(0, None))(v.reshape(-1), params["drag"]).reshape(-1, 1)
else:
    # Fail fast instead of silently reusing the ground-truth `drag` from the
    # SYSTEM cell.
    raise ValueError(f"ifdrag must be 0 or 1, got {ifdrag!r}")

# Learned equations of motion from Hmodel and the model's drag (no constraints).
zdot_model, lamda_force_model = get_zdot_lambda(
    N, dim, constraints=None, drag=drag, hamiltonian=Hmodel)


def zdot_model_func(z, t, params):
    """ODE right-hand side for the learned model: split ``z`` into (x, p) and
    evaluate ``zdot_model``. ``t`` is unused."""
    return zdot_model(*jnp.split(z, 2), params)

# Trained parameters saved for system size `useN`, and a simulator that rolls
# out the learned dynamics for `runs` saved frames.
params = loadfile("trained_model.dil", trained=useN)[0]

sim_model = get_forward_sim(params=params, zdot_func=zdot_model_func, runs=runs)


In [None]:
def KE_PE_H_fn(params, graph):
    """Return the learned model's totals ``(KE, PE, H)`` for ``graph``.

    Uses ``cal_graph`` with ``useonlyedge=True, useT=True``. PE is the edge
    terms plus the node terms, and H = KE + PE.
    """
    _, PEij, PEi, KE, _ = cal_graph(params, graph, eorder=eorder, useonlyedge=True, useT=True)
    ke_total = KE.sum()
    pe_edges = PEij.sum()
    pe_nodes = PEi.sum()
    return ke_total, pe_edges + pe_nodes, ke_total + pe_edges + pe_nodes

def energy_fn_kph(species):
    """Build ``apply_kph(R, V, params) -> (KE, PE, H)`` for the learned model.

    The graph uses the edge list from ``get_fully_connected_senders_and_receivers(N)``.
    A template is built once from the current global ``R``/``V``. Each call
    puts the supplied positions/velocities into a *new* node dict on a copy of
    the template (``GraphsTuple._replace``). The template itself is never
    mutated inside a traced function.
    """
    senders, receivers = get_fully_connected_senders_and_receivers(N)
    # eorder = get_fully_edge_order(N)

    template = jraph.GraphsTuple(nodes={
        "position": R,
        "velocity": V,
        "type": species
    },
        edges={},
        senders=senders,
        receivers=receivers,
        n_node=jnp.array([R.shape[0]]),
        n_edge=jnp.array([senders.shape[0]]),
        globals={})

    def apply_kph(R, V, params):
        nodes = {**template.nodes, "position": R, "velocity": V}
        return KE_PE_H_fn(params, template._replace(nodes=nodes))
    return apply_kph

apply_fn_kph = energy_fn_kph(species)


def KE_PE_H_model(x, v, params):
    """``(KE, PE, H)`` of the learned model at (x, v), using ``params["H"]``."""
    hamiltonian_params = params["H"]
    return apply_fn_kph(x, v, hamiltonian_params)



In [None]:
def nodewise_KE_PE_H_fn(params, graph):
    """Return the learned model's unreduced terms ``(PEij, PEi, KE, drij)`` for ``graph``.

    Uses ``cal_graph`` with ``useonlyedge=True, useT=True``.
    """
    outputs = cal_graph(params, graph, eorder=eorder, useonlyedge=True, useT=True)
    _, PEij, PEi, KE, drij = outputs
    return PEij, PEi, KE, drij

def energy_fn_node_kph(species):
    """Build ``apply_kph(R, V, params) -> (PEij, PEi, KE, drij)`` for the learned model.

    The graph uses the edge list from ``get_fully_connected_senders_and_receivers(N)``.
    A template is built once from the current global ``R``/``V``. Each call
    puts the supplied positions/velocities into a *new* node dict on a copy of
    the template (``GraphsTuple._replace``) instead of mutating the template.
    """
    senders, receivers = get_fully_connected_senders_and_receivers(N)

    template = jraph.GraphsTuple(nodes={
        "position": R,
        "velocity": V,
        "type": species
    },
        edges={},
        senders=senders,
        receivers=receivers,
        n_node=jnp.array([R.shape[0]]),
        n_edge=jnp.array([senders.shape[0]]),
        globals={})

    def apply_kph(R, V, params):
        nodes = {**template.nodes, "position": R, "velocity": V}
        return nodewise_KE_PE_H_fn(params, template._replace(nodes=nodes))
    return apply_kph

apply_fn_node_kph = energy_fn_node_kph(species)


def node_KE_PE_H_model(x, v, params):
    """Unreduced learned terms ``(PEij, PEi, KE, drij)`` at (x, v), using ``params["H"]``."""
    hamiltonian_params = params["H"]
    return apply_fn_node_kph(x, v, hamiltonian_params)


In [None]:
# z_model_out = sim_model(R, V)

################################################
############## forward simulation ##############
################################################

def norm(a):
    """Euclidean norm of each entry along axis 0; every entry is flattened first."""
    squared = jnp.square(a)
    return jnp.sqrt(squared.reshape(len(squared), -1).sum(axis=1))

def RelErr(ya, yp):
    """Symmetric relative error ``|ya - yp| / (|ya| + |yp|)`` per entry along axis 0.

    Entries where both inputs are zero get 0 instead of 0/0 = NaN. A NaN would
    raise here because the notebook enables ``jax_debug_nans``.
    """
    numerator = norm(ya-yp)
    denominator = norm(ya) + norm(yp)
    # The denominator is 0 only if both entries are zero, so the numerator is 0 too.
    safe_denominator = jnp.where(denominator == 0, 1.0, denominator)
    return numerator / safe_denominator

def Err(ya, yp):
    """Signed error ``ya - yp``."""
    return ya-yp

def AbsErr(*args):
    """Element-wise absolute error ``|ya - yp|``."""
    return jnp.abs(Err(*args))

def caH_energy_fn(lag=None, params=None):
    """Return ``fn(states)``, a per-frame energy table with columns [PE, KE, H, KE+PE].

    ``lag`` is the Hamiltonian callable ``(x, p, params) -> scalar``; the name
    is kept for compatibility. KE is ``kin_energy`` of ``states.velocity``, and
    PE is H - KE.
    """
    def fn(states):
        kinetic = vmap(kin_energy)(states.velocity)
        total = vmap(lag, in_axes=(0, 0, None))(states.position, states.velocity, params)
        potential = total - kinetic
        return jnp.array([potential, kinetic, total, kinetic + potential]).T
    return fn

# Energy tables under the ground-truth Hamiltonian and under the learned one.
Es_fn = caH_energy_fn(Hactual)
Es_pred_fn = caH_energy_fn(Hmodel, params)

def net_force_fn(force=None, params=None):
    """Return ``fn(states)``, which evaluates ``force`` (a zdot function) frame by
    frame and keeps the second half along axis 1, i.e. the momentum derivative."""
    def fn(states):
        batched = vmap(force, in_axes=(0, 0, None))
        zdot_out = batched(states.position, states.velocity, params)
        return jnp.split(zdot_out, 2, axis=1)[1]
    return fn

# Net force along a trajectory under the ground-truth and the learned dynamics.
net_force_orig_fn = net_force_fn(zdot)
net_force_model_fn = net_force_fn(zdot_model, params)

# Containers for per-experiment results.
# NOTE(review): these lists are not appended to in the visible cells, and `nexp`
# is later replaced by a dict loaded from a pickle file.
nexp = {
    "z_pred": [],
    "z_actual": [],
    "Zerr": [],
    "Herr": [],
    "E": [],
    "Perr": [],
}

# (actual_traj, pred_traj) pairs, appended further down.
trajectories = []

# Ground-truth simulator with the same number of saved frames (`runs`) as sim_model.
sim_orig2 = get_forward_sim(params=None, zdot_func=zdot_func, runs=runs)
t = 0.0  # accumulated wall-clock time of the learned-model rollout
ind = 0  # index of the experiment simulated below (only one here)
print(f"Simulating trajectory {ind}/{maxtraj} ...")

# Initial condition for experiment `ind`: frame ind*69 of the first stored
# dataset trajectory.
# NOTE(review): the 69-frame offset between experiments is a magic number — confirm intent.
z_out, _ = dataset_states[0]
xout, pout = jnp.split(z_out, 2, axis=1)

R = xout[ind*69]
V = pout[ind*69]

# Ground-truth rollout, then the ground-truth zdot along it. The second half of
# zdot (the momentum derivative) is kept as the force.
z_actual_out = sim_orig2(R, V)  # full_traj[start_:stop_]
x_act_out, p_act_out = jnp.split(z_actual_out, 2, axis=1)
zdot_act_out = jax.vmap(zdot, in_axes=(0, 0, None))(
    x_act_out, p_act_out, None)
_, force_act_out = jnp.split(zdot_act_out, 2, axis=1)
# Pack the rollout into a States container for plotting/export.
# NOTE(review): mass has length x_act_out.shape[0] (the number of frames), not
# the number of particles — confirm this is what States expects.
my_state = States()
my_state.position = x_act_out
my_state.velocity = p_act_out
my_state.force = force_act_out
my_state.mass = jnp.ones(x_act_out.shape[0])
actual_traj = my_state

# Learned-model rollout from the same initial condition, timed with wall-clock time.
start = time.time()
z_pred_out = sim_model(R, V)
x_pred_out, p_pred_out = jnp.split(z_pred_out, 2, axis=1)
zdot_pred_out = jax.vmap(zdot_model, in_axes=(
    0, 0, None))(x_pred_out, p_pred_out, params)
_, force_pred_out = jnp.split(zdot_pred_out, 2, axis=1)
my_state_pred = States()
my_state_pred.position = x_pred_out
my_state_pred.velocity = p_pred_out
my_state_pred.force = force_pred_out
my_state_pred.mass = jnp.ones(x_pred_out.shape[0])
pred_traj = my_state_pred
# JAX dispatches work asynchronously. Wait for the last result before reading
# the clock, otherwise the measured time can cover little more than dispatch.
force_pred_out.block_until_ready()
end = time.time()
t += end - start

# Export the first experiment's predicted and ground-truth trajectories as OVITO files.
if saveovito and ind < 1:
    save_ovito(f"pred_{ind}.data",
               [frame for frame in NVEStates(pred_traj)], lattice="")
    save_ovito(f"actual_{ind}.data",
               [frame for frame in NVEStates(actual_traj)], lattice="")

trajectories += [(actual_traj, pred_traj)]

# Plot only the predicted trajectory (the loop over both is commented out below).
# NOTE(review): this rebinds `key`, which previously held the JAX PRNG key.
key='pred'
traj=pred_traj

# for key, traj in {"actual": actual_traj, "pred": pred_traj}.items():

print(f"plotting energy ({key})...")

# Net force along `traj` under the ground-truth and the learned dynamics.
net_force_orig = net_force_orig_fn(traj)
net_force_model = net_force_model_fn(traj)

# Net force along `traj`. Panel 0 shows the force summed over all particles;
# panels 1..N show one particle each. Thick translucent lines are the ground
# truth, dashed black lines the learned model.
fig, axs = panel(1+R.shape[0], 1, figsize=(20,
                                        R.shape[0]*5), hshift=0.1, vs=0.35)
for i, ax in zip(range(R.shape[0]+1), axs):
    if i == 0:
        # The plotted curve is a sum over particles (axis=1); the text used to say "Averaged".
        ax.text(0.6, 0.8, "Summed over all particles",
                transform=ax.transAxes, color="k")
        ax.plot(net_force_orig.sum(axis=1), lw=6, label=[
                r"$F_x$", r"$F_y$", r"$F_z$"][:R.shape[1]], alpha=0.5)
        ax.plot(net_force_model.sum(axis=1), "--", color="k")
        ax.plot([], "--", c="k", label="Predicted")
    else:
        ax.text(0.6, 0.8, f"For particle {i}",
                transform=ax.transAxes, color="k")
        ax.plot(net_force_orig[:, i-1, :], lw=6, label=[r"$F_x$",
                r"$F_y$", r"$F_z$"][:R.shape[1]], alpha=0.5)
        ax.plot(net_force_model[:, i-1, :], "--", color="k")
        ax.plot([], "--", c="k", label="Predicted")

    ax.legend(loc=2, bbox_to_anchor=(1, 1),
            labelcolor="markerfacecolor")
    ax.set_ylabel("Net force")
    ax.set_xlabel("Time step")
    # Name the simulated system (e.g. "4-body") rather than "Spring".
    ax.set_title(f"{PSYS} Exp {ind}")
# , dpi=500)
plt.show()
# plt.savefig(_filename(f"net_force_Exp_{ind}_{key}.png"))


# Energies under the ground-truth Hamiltonian along the ground-truth (Es) and
# predicted (Eshat) trajectories, with Eshat shifted to start at Es[0].
# NOTE(review): Es is recomputed just below, and H/Hhat are not used in the
# visible notebook; they are kept in case other code relies on them.
Es = Es_fn(actual_traj)
Eshat = Es_fn(pred_traj)
Eshat = Eshat - Eshat[0] + Es[0]
H = Es[:, -1]
Hhat = Eshat[:, -1]


# True (Es) vs learned (Es_pred) energy decomposition along `traj`. The learned
# curves are offset so both agree at the first frame.
Es = Es_fn(traj)
Es_pred = Es_pred_fn(traj)
Es_pred = Es_pred - Es_pred[0] + Es[0]

# Columns come from caH_energy_fn: [PE, KE, H, KE+PE]. The third column is the
# Hamiltonian, so it is labelled "H" (it was mislabelled "L").
energy_labels = ["PE", "KE", "H", "TE"]

fig, axs = panel(1, 1, figsize=(20, 5))
axs[0].plot(Es, label=energy_labels, lw=6, alpha=0.5)
axs[0].plot(Es_pred, "--", label=energy_labels)
plt.legend(bbox_to_anchor=(1, 1), loc=2)
axs[0].set_facecolor("w")

xlabel("Time step", ax=axs[0])
ylabel("Energy", ax=axs[0])

# Name the simulated system (e.g. "4-body") rather than "Spring".
title = f"(HGNN) {PSYS} Exp {ind}"
plt.title(title)
# plt.savefig(_filename(title.replace(
#     " ", "-")+f"_{key}.png"))  # , dpi=500)
plt.show()


In [None]:
# Unreduced learned terms (edge PE, node PE, KE, drij) at every frame of the
# predicted trajectory.
states = pred_traj
vij, vi, KE, drij = vmap(node_KE_PE_H_model, in_axes=(0, 0, None))(states.position, states.velocity, params)
drij = drij.reshape(runs,-1)

# Es = Es_fn(traj)
# Eshat = Es_pred_fn(traj)
# Eshat = Eshat - Eshat[0] + Es[0]

# Reference pair distances along the same trajectory (`traj` is pred_traj here).
dr = traj.position[:,senders,:]-traj.position[:,receivers,:]
drij_ac = jnp.sqrt(jnp.square(dr).sum(axis=2))
# NOTE(review): 0.5*(r-1)^2 is a unit spring potential, but this notebook's
# ground truth (pot_energy_orig) is gravitational. It looks copied from a spring
# notebook — confirm before trusting the vij comparison.
vij_actual = 0.5*1*(jnp.square(drij_ac-1))
# Per-particle kinetic energy 0.5*|p|^2 (unit mass assumed).
KE_actual = 0.5*1*(jnp.square(pred_traj.velocity).sum(axis=2))

# Shift the learned pair energies so they match the reference at the first frame.
vij = vij - vij[0] + vij_actual[0]

In [None]:
# Learned per-particle energies along the predicted trajectory, each shifted to
# zero at the first frame. Left: potential vs. height change. Right: kinetic vs. speed.
# Create this figure's own axes; the cell used to draw into whatever `axs` was
# left over (a one-panel figure, so axs[1] failed).
fig, axs = panel(1, 2)
states = pred_traj
vij, vi, KE, drij = vmap(node_KE_PE_H_model, in_axes=(0, 0, None))(states.position, states.velocity, params)
eij = jnp.sqrt(1e-10 + jnp.square(drij).sum(axis=1))
vi = vi-vi[0]
KE = KE-KE[0]

pos_y = states.position[:,:,1]
pos_y = pos_y-pos_y[0]

# Loop over the particles actually present, last first. This was a hard-coded
# 5, one more than this 4-body system has.
n_particles = pos_y.shape[1]
for i in reversed(range(n_particles)):
    axs[0].plot(pos_y[:,i], vi[:,i],'.',alpha=0.4, label=f"particle {i+1}")

# Reference line V = 10*h.
# NOTE(review): presumably the pendulum gravity reference (m=1, g=10); it is
# not expected to match a gravitational N-body potential — confirm it is still wanted.
axs[0].plot(pos_y.flatten(), 10*pos_y.flatten(),'-',c='k')

axs[0].set_xlabel("$h_i$")
axs[0].set_ylabel("$V_i$")
axs[0].legend()

v_magnitude = jnp.sqrt(jnp.square(states.velocity).sum(axis=2))

for i in reversed(range(n_particles)):
    axs[1].plot(v_magnitude[:,i], KE[:,i],'.', label=f"particle {i+1}",alpha=0.4)

# Reference curve KE = |v|^2 / 2 (unit mass).
axs[1].plot(v_magnitude.flatten().sort(), 0.5*(v_magnitude.flatten()**2).sort(),'-',c='k')
axs[1].set_xlabel("$v_i$")
axs[1].set_ylabel("KE$_{i}$")
axs[1].legend()

In [None]:
fig, axs = panel(1,2)
# NOTE(review): this cell used `pred_traj[0]` and unpacked five values, but
# node_KE_PE_H_model returns four (PEij, PEi, KE, drij). vmapping over a single
# frame would also map over particles instead of frames. The whole predicted
# trajectory is used here, which makes this cell match the previous one —
# confirm the intended input.
states = pred_traj

vij, vi, KE, drij = vmap(node_KE_PE_H_model, in_axes=(0, 0, None))(states.position, states.velocity, params)
eij = jnp.sqrt(1e-10 + jnp.square(drij).sum(axis=1))
vi = vi-vi[0]
KE = KE-KE[0]

pos_y = states.position[:,:,1]
pos_y = pos_y-pos_y[0]

# Loop over the particles actually present (was a hard-coded 5 for a 4-body system).
n_particles = pos_y.shape[1]
for i in reversed(range(n_particles)):
    axs[0].plot(pos_y[:,i], vi[:,i],'.',alpha=0.4, label=f"particle {i+1}")

# Reference line V = 10*h (presumably the pendulum gravity reference — confirm).
axs[0].plot(pos_y.flatten(), 10*pos_y.flatten(),'-',c='k')

axs[0].set_xlabel("$h_i$")
axs[0].set_ylabel("$V_i$")
axs[0].legend()

v_magnitude = jnp.sqrt(jnp.square(states.velocity).sum(axis=2))

for i in reversed(range(n_particles)):
    axs[1].plot(v_magnitude[:,i], KE[:,i],'.', label=f"particle {i+1}",alpha=0.4)

# Reference curve KE = |v|^2 / 2 (unit mass).
axs[1].plot(v_magnitude.flatten().sort(), 0.5*(v_magnitude.flatten()**2).sort(),'-',c='k')
axs[1].set_xlabel("$v_i$")
axs[1].set_ylabel("KE$_{i}$")
axs[1].legend()

In [None]:
fig, axs = panel(1, 5, figsize=(5*5,5*1))

# Parity plots: reference vs. learned kinetic energy and pair potential energy.
# NOTE(review): depends on execution order. Earlier cells rebind `KE` (shifted
# by KE[0]) and `vij` (raw values that overwrite the offset-aligned vij from the
# cell that defines vij_actual), so what is plotted depends on which cells ran
# last — confirm.
axs[0].plot(KE_actual,KE,'.') # PE, KE, KE-PE, H
axs[1].plot(vij_actual,vij,'.')


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score, mean_squared_error
from math import sqrt
import matplotlib.ticker as plticker
from matplotlib import cm
from matplotlib.colors import ListedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.ticker import MaxNLocator
from shadow.plot import *

def make_den_plot_ps(ax, rng, den_scale_bool, den_scale, ytest, ytest_pred,tt_s=0,label='label'):
    """Plot ``ytest_pred`` against ``ytest`` on ``ax``, either as a 2-D histogram or a scatter.

    Parameters
    ----------
    ax : matplotlib Axes to draw into (it also becomes the current axes).
    rng : ``(lo, hi)`` histogram range applied to both axes, or None for the data range.
    den_scale_bool : if true, clip bin counts at ``mean + den_scale * std``.
    den_scale : number of standard deviations used for that clip.
    ytest, ytest_pred : arrays, flattened and used as x and y.
    tt_s : if truthy, draw a plain scatter instead of the histogram.
    label : text of an invisible legend entry.

    Returns
    -------
    ``(ax, image)`` in histogram mode, ``None`` in scatter mode.
    """
    plt.sca(ax)
    # Invisible handle whose only purpose is to carry `label` into the legend.
    ax.plot([], [],'k',ls='none', mew=0,label=label)

    if tt_s:
        ax.plot(ytest, ytest_pred, 'ob',mec='k')
        xlabel('$x$')
        ylabel(r'$\dot{x}$')
        return None

    n_bins = 70
    # `is None`: `rng == None` is ambiguous when rng is a NumPy array.
    hist_range = None if rng is None else np.array([rng, rng])
    im_data, xedges, yedges = np.histogram2d(
        ytest.ravel(), ytest_pred.ravel(), bins=(n_bins, n_bins), density=0, range=hist_range)

    if den_scale_bool:
        # Clip dense bins so that sparse regions remain visible.
        cap = im_data.mean()+den_scale*im_data.std()
        im_data[im_data > cap] = cap
    im_data = im_data.astype(int)

    # Reversed gist_heat colormap. plt.get_cmap replaces matplotlib.cm.get_cmap,
    # which was deprecated in 3.7 and removed in 3.9.
    base_cmap = plt.get_cmap('gist_heat', 256)
    mycm = ListedColormap(base_cmap(np.linspace(0, 1, 256))[::-1,:])
    cb = ax.imshow(im_data.T, extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], cmap = mycm)
    xlabel('$x$')
    ylabel(r'$\dot{x}$')
    legend_on(loc=1)
    return ax,cb



In [None]:
import pickle

# Saved evaluation results of the 4-body HGNN run. Element [0] is the dict-like
# record of per-experiment arrays used by the cells below.
# NOTE: pickle.load can execute arbitrary code — only open result files you trust.
filename = '../results/4-body-hgnn/0/error_parameter_0_1.pkl'
with open(filename, 'rb') as fh:
    nexp = pickle.load(fh)[0]

In [None]:
# Convert the stored per-experiment lists to arrays once, rather than once per use.
# The indexing below treats them as (experiment, frame, particle, component) —
# inferred from usage; confirm.
z_pred_all = jnp.array(nexp['z_pred'])
z_act_all = jnp.array(nexp['z_actual'])
v_pred_all = jnp.array(nexp['v_pred'])
v_act_all = jnp.array(nexp['v_actual'])
n_particles = z_pred_all.shape[2]

# x/y components of positions and velocities, flattened to (samples, particle).
zx = z_pred_all[:,:,:,0].reshape(-1,n_particles)
zy = z_pred_all[:,:,:,1].reshape(-1,n_particles)
vx = v_pred_all[:,:,:,0].reshape(-1,n_particles)
vy = v_pred_all[:,:,:,1].reshape(-1,n_particles)

zx_a = z_act_all[:,:,:,0].reshape(-1,n_particles)
zy_a = z_act_all[:,:,:,1].reshape(-1,n_particles)
vx_a = v_act_all[:,:,:,0].reshape(-1,n_particles)
vy_a = v_act_all[:,:,:,1].reshape(-1,n_particles)
fig, axs = panel(1,2)

# (q_x, q_y) path of particle p_id, predicted (left) vs ground truth (right):
# first all experiments concatenated, then experiments 0 and 1 highlighted.
p_id = 3
axs[0].plot(z_pred_all[:,:,p_id,0].reshape(-1,1),z_pred_all[:,:,p_id,1].reshape(-1,1),alpha=0.6)
axs[1].plot(z_act_all[:,:,p_id,0].reshape(-1,1),z_act_all[:,:,p_id,1].reshape(-1,1),alpha=0.6)

for exp_id in (0, 1):
    axs[0].plot(z_pred_all[exp_id,:,p_id,0].reshape(-1,1),z_pred_all[exp_id,:,p_id,1].reshape(-1,1),alpha=0.9)
    axs[1].plot(z_act_all[exp_id,:,p_id,0].reshape(-1,1),z_act_all[exp_id,:,p_id,1].reshape(-1,1),alpha=0.9)


In [None]:
# Convert the stored results once. This loop used to re-convert the whole
# list twice per iteration, 200 times in total.
z_pred_all = jnp.array(nexp['z_pred'])
z_act_all = jnp.array(nexp['z_actual'])
v_pred_all = jnp.array(nexp['v_pred'])
v_act_all = jnp.array(nexp['v_actual'])
n_particles = z_pred_all.shape[2]

zx = z_pred_all[:,:,:,0].reshape(-1,n_particles)
zy = z_pred_all[:,:,:,1].reshape(-1,n_particles)
vx = v_pred_all[:,:,:,0].reshape(-1,n_particles)
vy = v_pred_all[:,:,:,1].reshape(-1,n_particles)

zx_a = z_act_all[:,:,:,0].reshape(-1,n_particles)
zy_a = z_act_all[:,:,:,1].reshape(-1,n_particles)
vx_a = v_act_all[:,:,:,0].reshape(-1,n_particles)
vy_a = v_act_all[:,:,:,1].reshape(-1,n_particles)
fig, axs = panel(1,2)

# Every experiment's (q_x, q_y) path of particle p_id in one colour:
# predicted (left) vs ground truth (right). At most 100 experiments are drawn.
p_id = 3
n_show = min(100, z_pred_all.shape[0])
for i in range(n_show):
    axs[0].plot(z_pred_all[i,:,p_id,0].reshape(-1,1),z_pred_all[i,:,p_id,1].reshape(-1,1), color = 'b', alpha=0.9)
    axs[1].plot(z_act_all[i,:,p_id,0].reshape(-1,1),z_act_all[i,:,p_id,1].reshape(-1,1), color ='b', alpha=0.9)


In [None]:
# Convert the stored results once.
z_pred_all = jnp.array(nexp['z_pred'])
z_act_all = jnp.array(nexp['z_actual'])
v_pred_all = jnp.array(nexp['v_pred'])
v_act_all = jnp.array(nexp['v_actual'])
n_particles = z_pred_all.shape[2]

zx = z_pred_all[:,:,:,0].reshape(-1,n_particles)
zy = z_pred_all[:,:,:,1].reshape(-1,n_particles)
vx = v_pred_all[:,:,:,0].reshape(-1,n_particles)
vy = v_pred_all[:,:,:,1].reshape(-1,n_particles)

zx_a = z_act_all[:,:,:,0].reshape(-1,n_particles)
zy_a = z_act_all[:,:,:,1].reshape(-1,n_particles)
vx_a = v_act_all[:,:,:,0].reshape(-1,n_particles)
vy_a = v_act_all[:,:,:,1].reshape(-1,n_particles)
fig, axs = panel(1,5)
p_id = 3

# Predicted paths of every particle for a few sample experiments. Ids beyond
# the number stored are skipped: with 100 experiments (ids 0..99), id 100 is
# out of range, and depending on the JAX version it raises or silently clamps
# to the last experiment.
sample_ids = [k for k in (0, 10, 50, 100) if k < z_pred_all.shape[0]]
for i in range(n_particles):
    for k in sample_ids:
        axs[0].plot(z_pred_all[k,:,i,0],z_pred_all[k,:,i,1])


In [None]:
# Convert the stored results once instead of on every loop iteration.
z_pred_all = jnp.array(nexp['z_pred'])
z_act_all = jnp.array(nexp['z_actual'])
v_pred_all = jnp.array(nexp['v_pred'])
v_act_all = jnp.array(nexp['v_actual'])
n_particles = z_pred_all.shape[2]

zx = z_pred_all[:,:,:,0].reshape(-1,n_particles)
zy = z_pred_all[:,:,:,1].reshape(-1,n_particles)
vx = v_pred_all[:,:,:,0].reshape(-1,n_particles)
vy = v_pred_all[:,:,:,1].reshape(-1,n_particles)

zx_a = z_act_all[:,:,:,0].reshape(-1,n_particles)
zy_a = z_act_all[:,:,:,1].reshape(-1,n_particles)
vx_a = v_act_all[:,:,:,0].reshape(-1,n_particles)
vy_a = v_act_all[:,:,:,1].reshape(-1,n_particles)

# Five-panel summary figure. Panels 0-1: (q_x, q_y) paths of particle p_id for
# every experiment, predicted (left) vs ground truth (right).
fig, axs = panel(1,5)
p_id = 3
n_show = min(100, z_pred_all.shape[0])
for i in range(n_show):
    axs[0].plot(z_pred_all[i,:,p_id,0].reshape(-1,1),z_pred_all[i,:,p_id,1].reshape(-1,1), color = 'b', alpha=0.9)
    axs[1].plot(z_act_all[i,:,p_id,0].reshape(-1,1),z_act_all[i,:,p_id,1].reshape(-1,1), color ='b', alpha=0.9)

# Per-frame energy tables saved with the results, four columns each.
Es = jnp.array(nexp['Es']).reshape(-1,4)
Eshat = jnp.array(nexp['Eshat']).reshape(-1,4)

# PE, KE, KE-PE, H
# NOTE(review): caH_energy_fn in this notebook produces [PE, KE, H, KE+PE];
# confirm the saved column order. Only columns 0 (PE) and 1 (KE) are used here.

# Panel 2: kinetic-energy parity plot, with a legend that shows only R^2 and RMSE.
# round() works whether r2_score returns a NumPy scalar or a plain Python float
# (a float has no .round method).
r2_ke = round(r2_score(Es[:,1],Eshat[:,1]), 1)
rmse_ke = jnp.array(sqrt(mean_squared_error(Es[:,1],Eshat[:,1]))).round(1)
axs[2].plot(Es[:,1],Eshat[:,1],'.',label=f"$R^2$= {r2_ke} \nRMSE= {rmse_ke}")
leg = axs[2].legend(handlelength=0, handletextpad=0, fancybox=True)
# matplotlib >= 3.7 calls this `legend_handles`; the old `legendHandles` was removed in 3.9.
handles = getattr(leg, "legend_handles", None)
if handles is None:
    handles = leg.legendHandles
for item in handles:
    item.set_visible(False)


# Panel 3: potential-energy parity plot.
r2_pe = round(r2_score(Es[:,0],Eshat[:,0]), 1)
rmse_pe = jnp.array(sqrt(mean_squared_error(Es[:,0],Eshat[:,0]))).round(1)
axs[3].plot(Es[:,0],Eshat[:,0],'.',label=f"$R^2$= {r2_pe} \nRMSE= {rmse_pe}")
leg = axs[3].legend(handlelength=0, handletextpad=0, fancybox=True)
handles = getattr(leg, "legend_handles", None)
if handles is None:
    handles = leg.legendHandles
for item in handles:
    item.set_visible(False)


# Panel 4: force parity plot for particle p_id (x and y components), taken from
# the single trajectory simulated earlier in this notebook.
axs[4].plot(net_force_orig[:,p_id,0],net_force_model[:,p_id,0],'s',label='$F_x$',alpha=0.5)
axs[4].plot(net_force_orig[:,p_id,1],net_force_model[:,p_id,1],'^',label='$F_y$',alpha=0.5)
axs[4].legend()

# Axis labels of the five-panel figure: (q_x, q_y) paths in panels 0-1, then
# KE, PE and force parity plots (actual on x, predicted on y).
axs[0].set_xlabel('$q_x$')
axs[1].set_xlabel('$q_x$')
axs[2].set_xlabel('$KE_{ac}$')
axs[3].set_xlabel('$PE_{ac}$')
axs[4].set_xlabel('$F_{ac}$')

axs[0].set_ylabel('$q_y$')
axs[1].set_ylabel('$q_y$')
axs[2].set_ylabel('$KE_{pr}$')
axs[3].set_ylabel('$PE_{pr}$')
axs[4].set_ylabel('$F_{pr}$')
# plt.savefig('../results/fig2-5spring.png',dpi=100)

In [None]:
import pickle

# Saved evaluation results of a 5-pendulum HGNN run (per the path). Element [0]
# is the dict-like record of per-experiment arrays.
# NOTE: pickle.load can execute arbitrary code — only open result files you trust.
filename = "../results/5-Pendulum-3HGNN/0/error_parameter_0_0.pkl"
with open(filename, 'rb') as fh:
    nexp = pickle.load(fh)[0]

# NOTE(review): these are 5-pendulum results, not the 4-body system simulated
# earlier in this notebook; the force panel of this figure still uses the 4-body data.
z_pred_all = jnp.array(nexp['z_pred'])
z_act_all = jnp.array(nexp['z_actual'])
v_pred_all = jnp.array(nexp['v_pred'])
v_act_all = jnp.array(nexp['v_actual'])
n_particles = z_pred_all.shape[2]

zx = z_pred_all[:,:,:,0].reshape(-1,n_particles)
zy = z_pred_all[:,:,:,1].reshape(-1,n_particles)
vx = v_pred_all[:,:,:,0].reshape(-1,n_particles)
vy = v_pred_all[:,:,:,1].reshape(-1,n_particles)

zx_a = z_act_all[:,:,:,0].reshape(-1,n_particles)
zy_a = z_act_all[:,:,:,1].reshape(-1,n_particles)
vx_a = v_act_all[:,:,:,0].reshape(-1,n_particles)
vy_a = v_act_all[:,:,:,1].reshape(-1,n_particles)

fig, axs = panel(1, 5, figsize=(5*5,5*1))

# Panels 0-1: density of particle p_id's x-position against its y-velocity,
# scaled by 1/5 and 1/20 to fit the [-1, 1] histogram range.
# NOTE(review): the axis labels read X vs. dX/dt while y-velocity is plotted — confirm.
p_id = 4
make_den_plot_ps(axs[0],[-1,1],True, 3, np.array(zx[:,p_id])/5,np.array(vy[:,p_id])/20,label=f"Predicted",tt_s=0)
make_den_plot_ps(axs[1],[-1,1],True, 3, np.array(zx_a[:,p_id])/5,np.array(vy_a[:,p_id])/20,label=f"Actual",tt_s=0)

# Per-frame energy tables saved with the results, four columns each.
Es = jnp.array(nexp['Es']).reshape(-1,4)
Eshat = jnp.array(nexp['Eshat']).reshape(-1,4)

# PE, KE, KE-PE, H  (NOTE(review): confirm the saved column order; columns 0 and 1 are used)

# Panel 2: kinetic-energy parity plot. round() works for NumPy scalars and plain floats alike.
r2_ke = round(r2_score(Es[:,1],Eshat[:,1]), 1)
rmse_ke = jnp.array(sqrt(mean_squared_error(Es[:,1],Eshat[:,1]))).round(1)
axs[2].plot(Es[:,1],Eshat[:,1],'.',label=f"$R^2$: {r2_ke} \nRMSE : {rmse_ke}")
leg = axs[2].legend(handlelength=0, handletextpad=0, fancybox=True)
# matplotlib >= 3.7 calls this `legend_handles`; the old `legendHandles` was removed in 3.9.
handles = getattr(leg, "legend_handles", None)
if handles is None:
    handles = leg.legendHandles
for item in handles:
    item.set_visible(False)


# Panel 3: potential-energy parity plot.
r2_pe = round(r2_score(Es[:,0],Eshat[:,0]), 1)
rmse_pe = jnp.array(sqrt(mean_squared_error(Es[:,0],Eshat[:,0]))).round(1)
axs[3].plot(Es[:,0],Eshat[:,0],'.',label=f"$R^2$: {r2_pe} \nRMSE : {rmse_pe}")
leg = axs[3].legend(handlelength=0, handletextpad=0, fancybox=True)
handles = getattr(leg, "legend_handles", None)
if handles is None:
    handles = leg.legendHandles
for item in handles:
    item.set_visible(False)


# Panel 4: force parity plot for particle p_id.
# NOTE(review): net_force_orig/net_force_model come from the 4-body simulation
# earlier in this notebook, while p_id was chosen for the 5-pendulum data. Skip
# the panel rather than index past the last particle.
if p_id < net_force_orig.shape[1]:
    axs[4].plot(net_force_orig[:,p_id,0],net_force_model[:,p_id,0],'s',label='$F_x$',alpha=0.5)
    axs[4].plot(net_force_orig[:,p_id,1],net_force_model[:,p_id,1],'^',label='$F_y$',alpha=0.5)
    axs[4].legend()
else:
    print(f"Skipping force panel: particle {p_id} not present in net_force_orig "
          f"(shape {net_force_orig.shape})")


# Axis labels and tick spacing of the five-panel pendulum figure.
axs[0].set_xlabel('$X$')
axs[0].set_ylabel(r'$\dot{X}$')

axs[1].set_xlabel('$X$')
axs[1].set_ylabel(r'$\dot{X}$')

# Each Axis gets its own Locator. A locator stores a reference to the Axis it is
# attached to, so one instance shared by the x and y axes follows whichever was
# set last.
axs[2].set_xlabel('$KE_{ac}$')
axs[2].set_ylabel('$KE_{pr}$')
axs[2].xaxis.set_major_locator(plticker.MultipleLocator(base=50))
axs[2].yaxis.set_major_locator(plticker.MultipleLocator(base=50))

axs[3].set_xlabel('$PE_{ac}$')
axs[3].set_ylabel('$PE_{pr}$')
axs[3].yaxis.set_label_coords(-0.1, 0.5)
axs[3].xaxis.set_major_locator(plticker.MultipleLocator(base=50))
axs[3].yaxis.set_major_locator(plticker.MultipleLocator(base=50))

axs[4].set_xlabel('$F_{ac}$')
axs[4].set_ylabel('$F_{pr}$')
axs[4].xaxis.set_major_locator(plticker.MultipleLocator(base=1))
axs[4].yaxis.set_major_locator(plticker.MultipleLocator(base=1))
# plt.savefig('../results/fig2-5pendulum.png',dpi=100)
