## Import

In [None]:
import jax
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import sys
sys.path.append("..")
import src
from frozendict import frozendict
from jax import grad, jit
from jax import numpy as jnp
from jax import value_and_grad, vmap
from jax.config import config
from jax.experimental import optimizers
from jraph import GraphNetwork, GraphsTuple
from src import lnn, utils
from src.io import *
from src.lnn import accelerationFull
from src.md import displacement, prediction, shift
from src.models import forward_pass, initialize_mlp
from src.nve import NVEState, NVEStates

from shadow.plot import *
from shadow.font import *

import importlib
importlib.reload(src)

# Enable 64-bit floats in JAX for everything below.
config.update("jax_enable_x64", True)

# Register frozendict as a pytree so the frozendicts returned by the jraph update
# functions can pass through jit/vmap. Re-running this cell registers the type a
# second time, which JAX rejects with ValueError; only that case is tolerated so
# any other registration failure is not silently hidden.
try:
    jax.tree_util.register_pytree_node(
        frozendict,
        flatten_func=lambda s: (tuple(s.values()), tuple(s.keys())),
        unflatten_func=lambda k, xs: frozendict(zip(k, xs)),
    )
except ValueError:
    pass

## Config

In [None]:
# Chain discretisation: n links per side of total length LENGTH.
n, LENGTH = 10, 1
link_size = LENGTH / n

def move_by_link(R, n=1, x=1, y=1, ls=1, t=0):
    """Translate every row of ``R`` by ``n`` steps of length ``ls`` at angle ``t`` degrees.

    ``x`` and ``y`` scale the x/y components of the step. The step has a zero third
    component, so ``R`` must broadcast against a 3-column row.
    """
    angle = t / 180 * jnp.pi
    step = jnp.array([[x * (jnp.cos(angle) * ls), y * (jnp.sin(angle) * ls), 0.0]])
    return R + n * step
        
def rot(x, t):
    """Rotate the row vectors of ``x`` (k, 3) by ``t`` degrees about the z axis."""
    theta = t * jnp.pi / 180
    c, s = jnp.cos(theta), jnp.sin(theta)
    rotation = jnp.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
    return jnp.dot(rotation, x.T).T


# NOTE(review): ``sss`` is not used anywhere in the visible notebook.
sss = LENGTH * 0.3

# Side 1: n+1 points hanging straight down from the origin, rotated by 45 degrees.
R_ = link_size * jnp.vstack([jnp.zeros(n + 1), -jnp.arange(0, n + 1), jnp.zeros(n + 1)]).T
R1 = rot(R_, 45)

# Sides 2-4 reuse one straight n-point segment. Each side is rotated a further
# 90 degrees, anchored at the previous side's last point and pushed one link
# forward along its own direction.
R_ = link_size * jnp.vstack([jnp.zeros(n), -jnp.arange(0, n), jnp.zeros(n)]).T
R2 = move_by_link(rot(R_, 135) + R1[-1:, :], ls=link_size, t=45)
R3 = move_by_link(rot(R_, 225) + R2[-1:, :], ls=link_size, t=135)
# The last side drops its final point, which would land back on the start R1[0].
R4 = move_by_link((rot(R_, 315) + R3[-1:, :])[:-1], ls=link_size, t=225)

# Stack the closed loop and shift it down by one unit in y.
R = jnp.vstack((R1, R2, R3, R4)) + jnp.array([[0.0, -1.0, 0.0]])

# Planar simulation: keep only the x/y columns.
dim = 2
if dim == 2:
    R = R[:, :2]

# N nodes; dim is re-derived from the array itself.
N, dim = R.shape

# Start from rest and remember the initial state.
V = 0 * R
R0, V0 = R, V

# Physical constants.
mass, length, _g = 1.0, 1.0, 10.0
# Integrator settings passed to ``prediction``.
dt, stride, runs = 1.0e-4, 100, 100



# Sequential chain edges i -> i+1, plus two extra edges: (N -> 0) and (n -> 3n).
# NOTE(review): node ids run 0..N-1, so sender ``N`` is out of range. JAX gathers
# clamp it to N-1, but scatter/segment aggregation presumably drops it. It likely
# was meant to be ``N - 1`` (closing the loop). Confirm before changing, because
# the trained parameters loaded later may depend on the current graph.
sends = list(range(N - 1)) + [N, n]
recs = list(range(1, N)) + [0, 3 * n]

key = jax.random.PRNGKey(0)

nspecies = 1


def OHE(x):
    """One-hot encode labels ``x`` into ``nspecies`` classes (out-of-range -> zeros)."""
    return jax.nn.one_hot(x, num_classes=nspecies)

# One species feature row per directed edge (sends+recs plus the reverse direction).
# NOTE(review): with nspecies == 1, class index 1 is out of range, so one_hot yields
# all zeros here. Index 0 would give ones. It is left as-is because trained
# parameters may rely on it.
species = OHE(jnp.ones(len(sends) + len(recs)))

print(R.shape)



In [None]:
def get_I(m, L):
    """Return ``m * L**2 / 12`` (the uniform-rod-about-its-centre moment of inertia)."""
    inertia = (m / 12) * L**2
    return inertia


def get_omega(r, v):
    """Angular velocity ``(r x v) / |r|**2`` of velocity ``v`` about the origin of ``r``.

    The squared norm keeps a trailing axis of length 1 so it broadcasts against
    the cross product.
    """
    squared_norm = jnp.square(r).sum(axis=-1, keepdims=True)
    return jnp.cross(r, v) / squared_norm


# Number of angular-velocity components returned by get_omega for one node.
(dimω,) = get_omega(R[0], V[0]).shape

key = jax.random.PRNGKey(0)

# Fixed reference points for the constraint functions (zero z in 3-D).
if dim != 2:
    endpoint1 = np.array([[0.0, 0, 0.0]])
    endpoint2 = np.array([[LENGTH * np.sqrt(2), 0, 0.0]])
else:
    endpoint1 = np.array([[0.0, 0]])
    endpoint2 = np.array([[LENGTH * np.sqrt(2), 0]])


In [None]:
ss = 5

# Weights for the learned branch of lag_link. The three *_emb entries are species
# embeddings (applied with an identity activation); ke_a / ke_l / pe are the
# angular-KE, linear-KE and PE heads.
# NOTE(review): every MLP is initialised from the same PRNG key; presumably layers
# of equal shape therefore start identical. Split the key if that is unintended.
params_graph = dict(
    ke_a_emb=initialize_mlp([1, ss], key),
    ke_l_emb=initialize_mlp([1, ss], key),
    pe_emd=initialize_mlp([1, ss], key),
    ke_a=initialize_mlp([ss + dimω, 10, 1], key),
    ke_l=initialize_mlp([ss + dim, 10, 1], key),
    pe=initialize_mlp([ss + 2 * dim, 10, 1], key),
)

In [None]:
def hconstraints(R, l=jnp.array([1.0])):
    """Holonomic link-length constraints for a chain anchored at ``endpoint1``.

    Residual i is ``|R[i] - P[i]|**2 - l[i]**2``, where ``P`` is ``R`` shifted down
    by one row with ``endpoint1`` prepended. The first residual therefore ties R[0]
    to the anchor, and each later one ties R[i] to R[i-1].

    Args:
        R: (N, dim) node positions.
        l: link lengths; if ``len(l) != len(R)`` its first entry is used for every
            link.

    Returns:
        (N,) residuals, zero when every constraint is satisfied.
    """
    if len(l) != len(R):
        l = l[0] * jnp.ones((len(R)))
    previous = jnp.vstack([endpoint1, R[:-1]])
    return jnp.square(R - previous).sum(axis=1) - l**2

@jit
def constraints(x, v, params):
    """Jacobian of ``hconstraints`` with respect to the flattened positions ``x``.

    ``x`` is reshaped to (-1, dim) before evaluation. ``v`` and ``params`` are
    accepted but unused.
    """
    def _residuals(flat):
        return hconstraints(flat.reshape(-1, dim))

    return jax.jacobian(_residuals)(x)


# Sanity check: shape of the constraint Jacobian at the initial state.
constraints(R.flatten(), V, params_graph).shape

## Lagrangian

In [None]:

def lag_link_wrap(r1, r2, v1, v2, sp, params):
    """Concatenate an edge's two endpoint states and evaluate ``lag_link`` on them."""
    return lag_link(jnp.hstack((r1, r2)), jnp.hstack((v1, v2)), sp, params)


def identity(x):
    """Return ``x`` unchanged; used as a linear activation for embeddings."""
    result = x
    return result


def lag_link(R, V, sp, params):
    """Kinetic and potential energy contribution of a single link (edge).

    Args:
        R: endpoint positions ``[r1, r2]`` concatenated; split in half below.
        V: endpoint velocities ``[v1, v2]`` concatenated.
        sp: species feature for this edge (used only by the learned branch).
        params: ``None`` for the analytic rigid-link model, otherwise a dict with
            keys "ke_a_emb", "ke_l_emb", "pe_emd", "ke_a", "ke_l", "pe".

    Returns:
        ``(T, V_)``: the link's kinetic and potential energy as scalars.
    """
    # Shared pre-processing for both branches.
    r1, r2 = jnp.split(R, 2)
    v1, v2 = jnp.split(V, 2)
    pos = (r1 + r2) / 2  # link centre
    vel = (v1 + v2) / 2  # centre velocity
    r_ = r1 - r2
    v_ = v1 - v2
    # NOTE(review): plain r x v, without the 1/|r|^2 factor used by get_omega.
    w = jnp.cross(r_, v_)

    if params is None:
        # Analytic model: translational + rotational KE, gravity acting on the
        # centre's y coordinate.
        # NOTE(review): inertia uses the global ``length`` rather than
        # ``link_size``; confirm this is intended.
        T = 0.5 * mass * jnp.square(vel).sum() + \
            0.5 * get_I(mass, length) * jnp.square(w).sum()
        V_ = mass * _g * pos[1]
        return T, V_

    # Learned model: prepend a linear species embedding to each physical feature,
    # then map through small MLP heads. KE terms are squared so they stay >= 0.
    w_ = jnp.hstack((forward_pass(params["ke_a_emb"], sp, activation_fn=identity), w))
    vel_ = jnp.hstack((forward_pass(params["ke_l_emb"], sp, activation_fn=identity), vel))
    pos_ = jnp.hstack((forward_pass(params["pe_emd"], sp, activation_fn=identity), pos, r_))

    T = jnp.square(forward_pass(params["ke_l"], vel_)).sum()
    T += jnp.square(forward_pass(params["ke_a"], w_)).sum()
    V_ = forward_pass(params["pe"], pos_).sum()
    return T, V_

# Batch lag_link_wrap over edges: map the five per-edge arrays, share params.
v_lag_link = vmap(lag_link_wrap, in_axes=(0,) * 5 + (None,))


In [None]:
def update_edge_fn(edges, sent_attributes, received_attributes, globals_):
    """jraph edge update: compute per-edge ``(T, V)`` with ``v_lag_link``.

    ``edges`` supplies "species" and "params"; the sender/receiver node attributes
    supply "position" and "velocity". ``globals_`` is unused.
    """
    energies = v_lag_link(
        sent_attributes["position"],
        received_attributes["position"],
        sent_attributes["velocity"],
        received_attributes["velocity"],
        edges["species"],
        edges["params"],
    )
    return frozendict({"T_V": energies})

def update_node_fn(nodes, sent_attributes, received_attributes, globals_):
    """jraph node update: turn the aggregated incoming ``(T, V)`` into energy terms.

    Returns a frozendict with "lag" = T - V, "ham" = T + V, "ke" = T, "pe" = V.
    The incoming ``nodes`` and the other arguments are not used.
    """
    ke, pe = received_attributes["T_V"]
    return frozendict({"lag": ke - pe, "ham": ke + pe, "ke": ke, "pe": pe})

In [None]:
# Edges produce per-link (T, V); nodes turn the aggregated sums into L/H/KE/PE.
net = GraphNetwork(update_edge_fn=update_edge_fn, update_node_fn=update_node_fn)

In [None]:
# Example graph at the initial state; every edge is listed in both directions.
graph = GraphsTuple(
    nodes=dict(position=R, velocity=V, acceleration=0 * V),
    edges=dict(params=params_graph, species=species),
    senders=jnp.array(sends + recs),
    receivers=jnp.array(recs + sends),
    globals=None,
    n_node=jnp.array([len(R)]),
    n_edge=None,
)


In [None]:
@jit
def Lgraph(R, V, params):
    """Total Lagrangian ``sum(T - V)`` of the chain at state (R, V).

    ``params=None`` selects the analytic link energies; otherwise ``params`` are
    the per-edge network weights.
    """
    g_in = GraphsTuple(
        nodes=dict(position=R, velocity=V, acceleration=0 * V),
        edges=dict(params=params, species=species),
        senders=jnp.array(sends + recs),
        receivers=jnp.array(recs + sends),
        globals=None,
        n_node=jnp.array([len(R)]),
        n_edge=None,
    )
    return net(g_in).nodes["lag"].sum()

def Lactual(R0, V0, params_graph):
    """Analytic Lagrangian: ``Lgraph`` evaluated with ``params=None``.

    The third argument is ignored; it only mirrors ``Lgraph``'s signature.
    """
    analytic_params = None
    return Lgraph(R0, V0, analytic_params)


# Compare learned (untrained) and analytic Lagrangians at the initial state.
(Lgraph(R0, V0, params_graph), Lactual(R0, V0, None))

In [None]:
importlib.reload(lnn)

# Constrained accelerations from the graph Lagrangian. The optional
# non_conservative_forces / external_force arguments are not passed.
acceleration_fn_graph = jit(
    lnn.accelerationFull(N, dim, lagrangian=Lgraph, constraints=constraints)
)

# Quick check: analytic (None) and learned-parameter accelerations.
(acceleration_fn_graph(R, V, None), acceleration_fn_graph(R, V, params_graph))

In [None]:
@jit
def force_fn_graph(R, V, params, mass=None):
    """Graph-model acceleration, multiplied per node by ``mass`` when one is given."""
    acc = acceleration_fn_graph(R, V, params)
    if mass is not None:
        acc = acc * mass.reshape(-1, 1)
    return acc


# Learned-parameter accelerations at the initial state.
acceleration_fn_graph(R0, V0, params_graph)

# Per-node masses, all equal to the global ``mass``.
masses = jnp.ones(len(R0)) * mass

@jit
def forward_sim_graph(R, V):
    """Roll out the analytic (params=None) dynamics from (R, V) with ``runs=100 * runs``."""
    total_runs = 100 * runs
    return prediction(R, V, None, force_fn_graph, shift, dt, masses,
                      runs=total_runs, stride=stride, dR_max=1.0e-5)



In [None]:
@jit
def Energy(R, V, params):
    """Return ``[L, H, KE, PE]``, each summed over nodes, for state (R, V).

    ``params=None`` uses the analytic link energies; otherwise ``params`` are the
    per-edge network weights.
    """
    g_in = GraphsTuple(
        nodes=dict(position=R, velocity=V, acceleration=0 * V),
        edges=dict(params=params, species=species),
        senders=jnp.array(sends + recs),
        receivers=jnp.array(recs + sends),
        globals=None,
        n_node=jnp.array([len(R)]),
        n_edge=None,
    )
    node_out = net(g_in).nodes
    return jnp.array([node_out[name].sum() for name in ("lag", "ham", "ke", "pe")])

In [None]:
# Energy over a batch of states (R, V) sharing one params.
v_Energy = vmap(Energy, in_axes=(0, 0) + (None,))

In [None]:
import jraph
from src import fgn

importlib.reload(fgn)

# Directed edge list: every link appears in both directions.
senders = jnp.array(sends + recs)
receivers = jnp.array(recs + sends)

# fgn MLP sizes: edgesize/nodesize are input widths, ee/ne the edge/node
# embedding widths, hidden_dim the hidden layers of each MLP.
hidden_dim = [128]
edgesize, nodesize = 4, 4
ee, ne = 8, 8

# NOTE(review): all MLPs are initialised from the same PRNG key.
Lparams = {
    "ee_params": initialize_mlp([edgesize, ee], key),
    "ne_params": initialize_mlp([nodesize, ne], key),
    "e_params": initialize_mlp([ee + 2 * ne, *hidden_dim, ee], key),
    "n_params": initialize_mlp([2 * ee + ne, *hidden_dim, ne], key),
    "g_params": initialize_mlp([ne, *hidden_dim, 1], key),
    "acc_params": initialize_mlp([ne, *hidden_dim, dim], key),
}

# Edge "type" feature as a column vector, one row per directed edge.
species = jnp.array(species).reshape(-1, 1)

def dist(*args):
    """Forward all arguments to ``displacement`` and return its result unchanged."""
    return displacement(*args)

def omega(R1, R2, V1, V2):
    """Row-wise cross product of ``R1 - R2`` with ``V1 - V2``.

    Unlike ``get_omega`` there is no division by the squared separation.
    """
    batched_cross = vmap(jnp.cross, in_axes=(0, 0))
    return batched_cross(R1 - R2, V1 - V2)

# Edge features (displacement and relative "spin") at the initial state.
rij = vmap(dist, in_axes=(0, 0))(R[senders], R[receivers])
wij = omega(R[senders], R[receivers], V[senders], V[receivers])

state_graph = jraph.GraphsTuple(
    nodes={"position": R, "velocity": V},
    edges={"rij": rij, "wij": wij, "type": species},
    senders=senders,
    receivers=receivers,
    n_node=jnp.array([N]),
    n_edge=jnp.array([senders.shape[0]]),
    globals={},
)

def acceleration_fn(params, graph):
    """Predict node accelerations for ``graph`` with fgn (one message pass, leaky ReLU)."""
    return fgn.cal_acceleration(params, graph, act=jax.nn.leaky_relu, mpass=1)

def acc_fn(species):
    """Build ``apply(R, V, params)`` that predicts accelerations with the fgn model.

    ``species`` (one entry per directed edge) is reshaped to a column and captured
    by the returned closure, together with the module-level ``senders`` and
    ``receivers``.

    A fresh GraphsTuple is built on every call instead of mutating a shared one.
    This keeps traced values from escaping into closure state when ``apply`` is
    jitted.
    """
    edge_type = species.reshape(-1, 1)

    def apply(R, V, params):
        """Return predicted accelerations for positions ``R`` and velocities ``V``."""
        R_s, R_r = R[senders], R[receivers]
        V_s, V_r = V[senders], V[receivers]
        rij = vmap(dist, in_axes=(0, 0))(R_s, R_r).reshape(-1, R.shape[1])
        wij = omega(R_s, R_r, V_s, V_r).reshape(-1, 1)
        graph = jraph.GraphsTuple(
            nodes={"position": R, "velocity": V},
            edges={"rij": rij, "wij": wij, "type": edge_type},
            senders=senders,
            receivers=receivers,
            n_node=jnp.array([R.shape[0]]),
            n_edge=jnp.array([senders.shape[0]]),
            globals={},
        )
        return acceleration_fn(params, graph)

    return apply

apply_fn = jit(acc_fn(species))
# Batch over states (R, V) with shared params. ``apply`` takes three positional
# arguments, so in_axes needs three entries. The previous (None, 0) made vmap
# reject every call.
v_apply_fn = vmap(apply_fn, in_axes=(0, 0, None))

def acceleration_fn_model(x, v, params):
    """Learned-model accelerations for state (x, v), using the ``params["L"]`` subtree."""
    model_params = params["L"]
    return apply_fn(x, v, model_params)

params = dict(L=Lparams)

# Smoke test of the untrained model at the initial state.
print(acceleration_fn_model(R, V, params))

# A learned drag term (passed to lnn.accelerationFull as ``non_conservative_forces``)
# was sketched at this point but is not enabled.

# Batch over states (R, V) with shared params.
v_acceleration_fn_model = vmap(acceleration_fn_model, in_axes=(0, 0, None))


In [None]:
# params_model = initialize_mlp([len(R)*dim*2, 128, 128, 1], key)

# def Lmodel(R, V, params):
#     return forward_pass(params, jnp.hstack((R.flatten(), V.flatten()))).sum()

# acceleration_fn_model = jit(lnn.accelerationFull(N, dim,
#                                                  lagrangian=Lmodel,
#                                                  # non_conservative_forces=None,
#                                                  constraints=constraints,
#                                                  # external_force=None,
#                                                  )
#                             )

@jit
def force_fn_model(R, V, params, mass=None):
    """Learned-model acceleration, multiplied per node by ``mass`` when one is given."""
    acc = acceleration_fn_model(R, V, params)
    if mass is not None:
        acc = acc * mass.reshape(-1, 1)
    return acc

@jit
def forward_sim_model(R, V):
    """Roll out the learned model from (R, V) for ``runs`` runs.

    The module-level ``params`` is read when the function is first traced.
    """
    return prediction(R, V, params, force_fn_model, shift, dt, masses,
                      runs=runs, stride=stride, dR_max=1.0e-5)


In [None]:
# Replace the freshly initialised params with trained ones from disk.
# NOTE(review): the .pkl extension suggests loadfile unpickles; only load trusted files.
params, _ = loadfile("chain_model_trained_free_gns.pkl")

## Forward 

In [None]:
# params, _ = loadfile("chain_model_trained.pkl")

In [None]:
RUNS = 10 * runs


@jit
def forward_sim_model(R, V):
    """Roll out the learned model (module-level ``params``) from (R, V) for ``RUNS`` runs."""
    return prediction(R, V, params, force_fn_model, shift, dt, masses,
                      runs=RUNS, stride=stride, dR_max=1.0e-5)


@jit
def forward_sim_long(R, V):
    """Reference rollout with the analytic dynamics (params=None) for ``RUNS`` runs."""
    return prediction(R, V, None, force_fn_graph, shift, dt, masses,
                      runs=RUNS, stride=stride, dR_max=1.0e-5)


# Analytic reference rollout from the initial state, written out for OVITO.
states_long = forward_sim_long(R0, V0)
save_ovito("chain_long_T2_gns.data", NVEStates(states_long),
           insert_origin=True, length=10.0)

# Learned-model rollout from the same initial state.
model_states = forward_sim_model(R0, V0)
save_ovito("chain_long_model_T2_gns.data", NVEStates(model_states),
           insert_origin=True, length=10.0)


In [None]:
# def getRV(n, dim=2, t=0):
#     R = [jnp.array([[0.0, 0, 0],])]
#     for _ in range(2*n):
#         point = jnp.array([[0.0, -1.0, 0.0], ])
#         point = rot(point, (np.random.rand()-0.5)*30)
#         R += R[-1] + point

#     R = jnp.vstack(R)
#     R = rot(R, t)

#     dim = 2

#     if dim == 2:
#         R = R[:, :2]
#     return R, 0*R

In [None]:
# states_long = []
# model_states = []

# seed = 0

# np.random.seed(seed)
# key = jax.random.PRNGKey(seed)

# for i in range(20):
#     print(i, n)
#     R, V = getRV(n, t=45)
#     print(R.shape)

#     states_long += [forward_sim_long(R, V)]
#     save_ovito(f"chain_long_gns_n{n}_i{i}.data", NVEStates(states_long[i]), length=10.0,
#                insert_origin=True)


#     model_states += [forward_sim_model(R, V)]
#     save_ovito(f"chain_long_model_gns_n{n}_i{i}.data", NVEStates(model_states[i]), length=10.0,
#                insert_origin=True)

In [None]:
# Persist both trajectories together with the integration settings.
savefile(
    f"chain_exp/T2_states_gns.pkl",
    {"model": model_states, "states": states_long},
    metadata=dict(dt=dt, stride=stride, runs=RUNS, samples=1),
)

In [None]:
def std_plot(y_, semilog=True, dt=1.0):
    """Plot the geometric mean over axis 0 of ``y_`` with a +/-2 sigma band in log space.

    Args:
        y_: array-like, presumably (samples, T), of positive values.
        semilog: use a logarithmic y axis when True.
        dt: spacing between consecutive x values.

    NOTE(review): exact zeros (e.g. a rollout error of 0 at t = 0) give log = -inf
    and a NaN band at that point.
    """
    log_y = jnp.log(jnp.array(y_))
    mean_ = log_y.mean(axis=0)
    std_ = log_y.std(axis=0)

    y = jnp.exp(mean_)
    low_b = jnp.exp(mean_ - 2 * std_)
    up_b = jnp.exp(mean_ + 2 * std_)

    x = jnp.array(range(len(mean_))) * dt
    print(x, y)
    draw = plt.semilogy if semilog else plt.plot
    draw(x, y)
    plt.fill_between(x, low_b, up_b, alpha=0.5)


In [None]:
def norm(a):
    """Euclidean norm of each slice along axis 0 of ``a`` (other axes flattened)."""
    rows = jnp.square(a).reshape(len(a), -1)
    return jnp.sqrt(rows.sum(axis=1))


def RelErr(ya, yp):
    """Symmetric relative error ``|ya - yp| / (|ya| + |yp|)`` per axis-0 slice."""
    diff = norm(ya - yp)
    scale = norm(ya) + norm(yp)
    return diff / scale


In [None]:
def getall_state_H(model_states, states_long):
    """Relative energy error over time for each (model, reference) trajectory pair.

    Both trajectories are scored with ``v_Energy(..., None)``, so the same energy
    function is used for both. Returns the per-pair errors stacked into an array.
    """
    def _energies(traj):
        return v_Energy(traj.position, traj.velocity, None)

    return jnp.array([RelErr(_energies(m), _energies(s))
                      for m, s in zip(model_states, states_long)])


In [None]:
# Total simulated time of the RUNS-long rollouts; stride*dt is the x spacing used
# for the error plots below.
stride*dt*RUNS

In [None]:
# Relative energy error of the learned rollout against the analytic rollout.
H_error = getall_state_H([model_states], [states_long])

# Start a fresh figure so this plot does not draw onto whatever figure is current
# (e.g. when the notebook is run as a script with a non-inline backend).
plt.figure()
std_plot(H_error, dt=stride*dt)
plt.ylabel("Energy error")
# NOTE(review): the x axis is in time units (index * stride*dt), not steps.
plt.xlabel("Time step")
plt.savefig(f"chain_exp/T2_H_error_gns.png", dpi=600)

savefile(f"chain_exp/T2_H_error_gns.pkl", H_error)


In [None]:
def getall_state_Z(model_states, states_long):
    """Relative position (rollout) error over time for each trajectory pair."""
    return jnp.array([RelErr(m.position, s.position)
                      for m, s in zip(model_states, states_long)])

In [None]:
# Display the integration settings used for the error plots.
dt, stride, RUNS

In [None]:
# Relative position error of the learned rollout against the analytic rollout.
Z_error = getall_state_Z([model_states], [states_long])

# Start a fresh figure; otherwise, outside an inline backend, this curve would be
# drawn on top of the energy-error figure and saved together with it.
plt.figure()
std_plot(Z_error, dt=stride*dt)
plt.ylabel("Rollout error")
# NOTE(review): the x axis is in time units (index * stride*dt), not steps.
plt.xlabel("Time step")

plt.savefig(f"chain_exp/T2_Z_error_gns.png", dpi=600)

savefile(f"chain_exp/T2_Z_error_gns.pkl", Z_error)

## END

In [None]:
# NOTE(review): make_plot is defined in the following cell; run that cell first,
# otherwise a top-to-bottom run raises NameError here.
make_plot(states_long)
make_plot(model_states)

In [None]:
@jit
def potE(R, params=None):
    """Potential energy evaluated at the midpoints of consecutive nodes.

    With ``params=None`` this is ``sum(mass * _g * y)`` over the midpoints;
    otherwise ``forward_pass(params, midpoint)`` is summed over them.
    """
    midpoints = (R[1:, :] + R[:-1, :]) / 2
    if params is not None:
        return vmap(forward_pass, in_axes=(None, 0))(params, midpoints).sum()
    return (mass * _g * midpoints[:, 1]).sum()


potE(R)


@jit
def kinE_l(R, V, params=None):
    """Translational kinetic energy from midpoint velocities (``R`` is unused).

    With ``params`` given, the squared MLP output per midpoint is summed instead.
    """
    mid_vel = (V[1:, :] + V[:-1, :]) / 2
    if params is not None:
        return jnp.square(vmap(forward_pass, in_axes=(None, 0))(params, mid_vel)).sum()
    return 0.5 * mass * jnp.square(mid_vel).sum()


kinE_l(R, V)


def kinE_a(R, V, params=None):
    """Rotational kinetic energy from each consecutive pair's ``get_omega``.

    Analytic form: ``0.5 * get_I(mass, length) * sum(omega**2)``; with ``params``
    the squared MLP output per pair is summed instead.
    """
    rel_pos = R[1:, :] - R[:-1, :]
    rel_vel = V[1:, :] - V[:-1, :]
    omega_ = vmap(get_omega)(rel_pos, rel_vel)
    if params is not None:
        return jnp.square(vmap(forward_pass, in_axes=(None, 0))(params, omega_)).sum()
    return 0.5 * get_I(mass, length) * jnp.square(omega_).sum()


kinE_a(R, V)


@jit
def kinE_(R, V, params=None):
    """Linear plus angular kinetic energy of the chain.

    Args:
        R, V: node positions and velocities.
        params: dict with keys "l" (for kinE_l) and "a" (for kinE_a), each a
            parameter tree or None. ``None`` selects the analytic form of both
            terms; previously the default raised on ``None["l"]``.
    """
    if params is None:
        params = {"l": None, "a": None}
    return kinE_l(R, V, params=params["l"]) + kinE_a(R, V, params=params["a"])


@jit
def kinE(R, V, params=None):
    """Total kinetic energy; ``params`` is a {"l": ..., "a": ...} dict or None (analytic)."""
    resolved = {"l": None, "a": None} if params is None else params
    return kinE_(R, V, params=resolved)

def make_plot(states):
    """Plot analytic energy components, their sum, and the time step along a trajectory.

    Every energy series is shifted so it starts at 0. Left panel: PE, KE, TE = KE + PE,
    and the linear/angular KE parts. Middle panel: TE. Right panel: the interval
    between consecutive snapshots on a log axis.
    """
    snapshots = list(NVEStates(states))
    pos_vel = [(s.position, s.velocity) for s in snapshots]

    PE = jnp.array([potE(r) for r, _ in pos_vel])
    KE = jnp.array([kinE(r, v) for r, v in pos_vel])
    KE_l = jnp.array([kinE_l(r, v) for r, v in pos_vel])
    KE_a = jnp.array([kinE_a(r, v) for r, v in pos_vel])

    time = [s.time for s in snapshots]
    previous = [0.0] + time[:-1]
    dt = [t - t_prev for t, t_prev in zip(time, previous)]

    # Shift each series so it starts at zero.
    PE, KE, KE_l, KE_a = (series - series[0] for series in (PE, KE, KE_l, KE_a))
    TE = KE + PE

    _, ax = plt.subplots(1, 3, figsize=(18, 6), sharex=True)

    plt.sca(ax[0])
    plt.plot(time, 0 * TE, "-k", label="x-axis", alpha=0.5)
    plt.plot(time, PE, label="PE")
    plt.plot(time, KE, label="KE")
    plt.plot(time, TE, label="TE")
    plt.plot(time, KE_l, "--", label="KE_l")
    plt.plot(time, KE_a, "--", label="KE_a")
    plt.legend()

    plt.sca(ax[1])
    plt.plot(time, TE)
    plt.ylabel("Hamiltonian")

    plt.sca(ax[2])
    plt.semilogy(time, dt)
    plt.ylabel("dt")

    plt.show()
