In [None]:
import jax
import jax.numpy as jnp
import numpy as np # get rid of this eventually

from matplotlib import pyplot as plt


import argparse
from jax import jit
from jax.experimental.ode import odeint
from functools import partial # reduces arguments to function by making some subset implicit

#ML functions
from jax.experimental import stax
from jax.experimental import optimizers

import os, sys, time
import pickle as pkl
from tqdm.notebook import tqdm
from copy import deepcopy as copy

from jax.experimental.ode import odeint

In [None]:
from jax.experimental.stax import Dense, Softplus, Tanh, elementwise, Relu


@jit
def sigmoid(x):
    """Logistic function: 1 / (1 + exp(-x))."""
    return 1/(1+jnp.exp(-x))


@jit
def swish(x):
    """Swish activation: x / (1 + exp(-x)), i.e. x scaled by its own sigmoid."""
    return x/(1+jnp.exp(-x))


@jit
def relu3(x):
    """Cubed ReLU: values below zero are clipped to zero, then cubed."""
    return jnp.clip(x, 0.0, float('inf'))**3


# stax layer wrappers so the custom activations can be used inside stax.serial.
Swish, Relu3 = elementwise(swish), elementwise(relu3)

def extended_mlp(args):
    """Assemble a fully connected stax network from a config object.

    Reads four attributes from ``args``:
      * ``act``        -- key selecting a pair of activations; the first is
                          applied after the first Dense layer, the second after
                          every following hidden Dense layer.
      * ``hidden_dim`` -- width of every hidden Dense layer.
      * ``layers``     -- number of hidden Dense layers.
      * ``output_dim`` -- width of the final Dense layer (no activation after it).

    Returns the result of ``stax.serial`` applied to the layer list.
    Raises ``KeyError`` when ``args.act`` is not one of the known keys.
    """
    activation_pairs = {
        'softplus': [Softplus, Softplus],
        'swish': [Swish, Swish],
        'tanh': [Tanh, Tanh],
        'tanh_relu': [Tanh, Relu],
        'soft_relu': [Softplus, Relu],
        'relu_relu': [Relu, Relu],
        'relu_relu3': [Relu, Relu3],
        'relu3_relu': [Relu3, Relu],
        'relu_tanh': [Relu, Tanh],
    }
    first_act, later_act = activation_pairs[args.act]
    width = args.hidden_dim
    out_width = args.output_dim
    n_hidden = args.layers

    # First hidden layer uses the first activation of the pair ...
    stack = [Dense(width), first_act]
    # ... every remaining hidden layer uses the second one.
    for _ in range(n_hidden - 1):
        stack += [Dense(width), later_act]
    # Linear read-out.
    stack.append(Dense(out_width))

    return stax.serial(*stack)

In [None]:
class ObjectView(object):
    """Attribute-style view over a dict: ``ObjectView({'a': 1}).a == 1``."""

    def __init__(self, d):
        # The dict itself becomes the instance namespace (it is not copied),
        # so attribute writes and writes to `d` are visible through both.
        self.__dict__ = d

@jax.jit
def qdotdot(q, q_t, conditionals):
    """Closed-form acceleration used as ground truth.

    Args:
        q: position; accepted for a uniform signature but not used.
        q_t: velocity.
        conditionals: the quantity ``g`` that multiplies the velocity-dependent factor.

    Returns:
        ``(q_t, q_tt)`` where ``q_tt = g * (1 - q_t**2)**(5/2) / (1 + 2*q_t**2)``.
    """
    speed_sq = q_t**2
    accel = conditionals * (1 - speed_sq)**(5./2) / (1 + 2 * speed_sq)
    return q_t, accel



@jax.jit
def ofunc(y, t=None):
    """Time derivative of the flat ODE state used for ground-truth integration.

    ``y`` is read with stride 3: entries 0, 3, 6, ... are positions,
    1, 4, 7, ... velocities and 2, 5, 8, ... the conditioning value ``g``.
    The result uses the same interleaving: (velocity, acceleration, 0) per
    triple, so ``g`` stays constant along a trajectory. ``t`` is ignored.
    """
    positions, velocities, forces = y[::3], y[1::3], y[2::3]

    velocities, accelerations = qdotdot(positions, velocities, forces)
    derivs = jnp.stack([velocities, accelerations, jnp.zeros_like(forces)])
    # Transpose before flattening to restore the (q, q_t, g) interleaving.
    return derivs.T.ravel()

def hamiltonian_eom(hamiltonian, state, conditionals, t=None):
    """Equations of motion derived from a scalar Hamiltonian function.

    Args:
        hamiltonian: callable ``H(q, p, conditionals)`` returning a scalar.
        state: array split in half into position ``q`` and velocity ``q_t``.
        conditionals: conditioning values; divided by 10 before being passed to ``H``.
        t: unused.

    Returns:
        Concatenation of ``dH/dp`` (returned in the velocity slot) and the
        acceleration ``q_tt`` obtained from ``p_t = -dH/dq``.
    """
    position, velocity = jnp.split(state, 2)

    # Generalized velocity -> canonical momentum.
    momentum = velocity/(1.0-velocity**2)**(3/2.0)

    scaled_cond = conditionals / 10.0  # normalize the conditioning input
    p_t = -jax.grad(hamiltonian, 0)(position, momentum, scaled_cond)

    # Momentum rate -> acceleration, using the velocity taken from `state`.
    q_tt = p_t*(1-velocity**2)**(5.0/2)/(1.0+2*velocity**2)

    # The velocity that is returned comes from the Hamiltonian itself and is
    # evaluated only after q_tt, which must use the input velocity.
    dH_dp = jax.grad(hamiltonian, 1)(position, momentum, scaled_cond)
    return jnp.concatenate([dH_dp, q_tt])

# Wrap a parametric model so it can be used as the scalar function H(q, p, conditionals).
def learned_dynamics(params, nn_forward_fn):
    """Bind ``params`` to a network and return it as a jitted scalar function.

    The returned ``dynamics(q, p, conditionals)`` concatenates its three inputs
    into one vector, evaluates ``nn_forward_fn(params, vector)`` and squeezes
    the trailing axis of the output (which must therefore have size 1).
    """
    def dynamics(q, p, conditionals):
        net_input = jnp.concatenate([q, p, conditionals])
        net_output = nn_forward_fn(params, net_input)
        return jnp.squeeze(net_output, axis=-1)

    return jit(dynamics)

In [None]:
#function for the generation of the dataset

@partial(jax.jit, static_argnums=(1, 2), backend='cpu')
def gen_data(seed, batch, num):
    """Integrate ground-truth trajectories from random initial conditions.

    Args:
        seed: integer seed for the random initial conditions.
        batch: number of independent trajectories (static under jit).
        num: number of time points, evenly spaced on [0, 1] (static under jit).

    Returns:
        A tuple ``(states, conditionals, accelerations)`` with one row per
        (time point, trajectory) pair:
          * states         -- shape (num*batch, 2), columns (q, q_t)
          * conditionals   -- shape (num*batch, 1), the constant ``g``
          * accelerations  -- shape (num*batch, 1), ``q_tt`` from ``qdotdot``
    """
    # Derive one independent subkey per sampled quantity. Integer arithmetic on
    # a key (e.g. `rng + 1`) is not a supported way to obtain new keys.
    q_key, v_key, g_key = jax.random.split(jax.random.PRNGKey(seed), 3)
    q0 = jax.random.uniform(q_key, (batch,), minval=-10, maxval=10)
    qt0 = jax.random.uniform(v_key, (batch,), minval=-0.99, maxval=0.99)
    g = jax.random.normal(g_key, (batch,))*10

    # Interleave as (q, q_t, g) per trajectory -- the layout ofunc expects.
    y0 = jnp.stack([q0, qt0, g]).T.ravel()

    yt = odeint(ofunc, y0, jnp.linspace(0, 1, num=num), mxsteps=300)

    # Undo the interleaving; each result has shape (num, batch).
    qall = yt[:, ::3]
    qtall = yt[:, 1::3]
    gall = yt[:, 2::3]

    states = jnp.stack([qall, qtall]).reshape(2, -1).T
    conditionals = gall.reshape(1, -1).T
    accelerations = qdotdot(qall, qtall, gall)[1].reshape(1, -1).T
    return states, conditionals, accelerations

@partial(jax.jit, static_argnums=(1,))
def gen_data_batch(seed, batch):
    """Sample a batch of random states with their ground-truth derivatives.

    Args:
        seed: integer seed; every sampled quantity depends on it.
        batch: number of samples (static under jit).

    Returns:
        A tuple ``(states, conditionals, targets)``:
          * states       -- shape (batch, 2), columns (q, q_t)
          * conditionals -- shape (batch, 1), the value ``g``
          * targets      -- shape (batch, 2), columns (q_t, q_tt) from ``qdotdot``
    """
    # One independent subkey per quantity. Previously the velocities were drawn
    # from a fixed PRNGKey(1), so they were identical for every seed.
    q_key, v_key, g_key = jax.random.split(jax.random.PRNGKey(seed), 3)
    q0 = jax.random.uniform(q_key, (batch,), minval=-10, maxval=10)
    # tanh squashes the normal samples into (-1, 1); the 0.99999 factor keeps
    # |q_t| strictly below 1 even where tanh saturates.
    qt0 = jnp.tanh(jax.random.normal(v_key, (batch,))*2)*0.99999
    g = jax.random.normal(g_key, (batch,))*10

    states = jnp.stack([q0, qt0]).reshape(2, -1).T
    conditionals = g.reshape(1, -1).T
    targets = jnp.stack(qdotdot(q0, qt0, g)).T
    return states, conditionals, targets

In [None]:
# Sanity check: draw a tiny batch (seed 0, five samples) and display the
# (states, conditionals, targets) tuple as the cell output.
cstate, cconditionals, ctarget = gen_data_batch(0, 5)
cstate, cconditionals, ctarget

In [None]:
# Run configuration, exposed as attributes (args.act, args.hidden_dim, ...).
# extended_mlp reads `act`, `hidden_dim`, `output_dim` and `layers`; `seed`
# is used for the PRNG key below.
args = ObjectView(dict(
    dataset_size=200,
    fps=10,
    samples=100,
    num_epochs=80000,
    seed=0,
    loss='l1',
    act='softplus',
    hidden_dim=500,
    output_dim=1,
    layers=4,
    n_updates=1,
    lr=0.001,
    lr2=2e-05,
    dt=0.1,
    model='gln',
    batch_size=68,
    l2reg=5.7e-07,
))
rng = jax.random.PRNGKey(args.seed)

In [None]:
# Build the network and draw its initial weights.
init_random_params, nn_forward_fn = extended_mlp(args)

# Seed from the config instead of a hard-coded 0, and split the key so the
# one used for initialisation is never reused afterwards.
rng, init_key = jax.random.split(jax.random.PRNGKey(args.seed))

# Input shape (-1, 3): learned_dynamics concatenates q, p and the conditional
# into one vector (assumes each contributes a single entry, as in this notebook).
_, init_params = init_random_params(init_key, (-1, 3))

In [None]:
# --- Bookkeeping (not referenced by the other visible cells) ---
iteration = 0
best_small_loss = np.inf
train_losses = []
test_losses = []

# --- Length of the run: epochs, and optimiser steps per epoch ---
total_epochs = 100
minibatch_per = 3000

# --- Learning-rate range: the schedule moves between lr/final_div_factor and lr ---
lr = 1e-3 #1e-3
final_div_factor = 1e4

# One-cycle style learning-rate schedule.
@jax.jit
def OneCycleLR(pct):
    """Learning rate as a function of the fraction ``pct`` in [0, 1].

    ``pct`` is first remapped onto [0.3, 1] so the schedule starts partway up
    the cycle. The rate then follows ``1 - (cos(2*pi*pct) + 1)/2`` scaled
    between ``lr/final_div_factor`` and ``lr`` (both read from module scope).
    """
    skip = 0.3 #0.2  -- portion of the cycle skipped at the start
    phase = pct * (1-skip) + skip
    peak = lr
    floor = lr/final_div_factor

    ramp = 1.0 - (jnp.cos(2 * jnp.pi * phase) + 1)/2

    return floor + (peak - floor)*ramp
    

# Adam, with the schedule above supplying the step size for each update index.
opt_init, opt_update, get_params = optimizers.adam(OneCycleLR)

# Start from the freshly initialised weights.
# (To continue from a previous best instead: opt_state = opt_init(best_params))
opt_state = opt_init(init_params)

In [None]:
# Preview the learning-rate schedule. The x-axis is the same fraction that the
# training loop passes to the optimiser, not the sample index.
schedule_pct = jnp.linspace(0, 1, num=200)

lr_fig, lr_ax = plt.subplots()
lr_ax.plot(schedule_pct, OneCycleLR(schedule_pct))
lr_ax.set_yscale('log')
lr_ax.set_xlabel('fraction of training completed')
lr_ax.set_ylabel('learning rate')
lr_ax.set_title('lr schedule');

In [None]:
@jax.jit
def loss(params, cstate, cconditionals, ctarget):
    """Velocity-weighted L1 error of the learned dynamics on one batch.

    Args:
        params: network parameters.
        cstate: batch of states; column 1 is the velocity.
        cconditionals: batch of conditioning values.
        ctarget: batch of target derivatives, compared element-wise with the
            first two outputs of ``hamiltonian_eom``.

    Returns:
        Scalar: weighted sum of absolute errors, rescaled by
        ``len(preds) / sum(weights)``.
    """
    learned_h = learned_dynamics(params, nn_forward_fn)
    batched_eom = jax.vmap(partial(hamiltonian_eom, learned_h), (0, 0), 0)
    preds = batched_eom(cstate, cconditionals)[:, [0, 1]]

    abs_err = jnp.abs(preds - ctarget)

    # Samples whose |velocity| approaches 1 get a larger weight.
    velocity = cstate[:, [1]]
    weights = (1 + 1/jnp.sqrt(1.0-velocity**2))

    return jnp.sum(abs_err * weights)*len(preds)/jnp.sum(weights)

@jax.jit
def update_derivative(i, opt_state, cstate, cconditionals, ctarget):
    """Take one optimiser step on a batch.

    Args:
        i: value handed to ``opt_update`` as the step index.
        opt_state: current optimiser state.
        cstate, cconditionals, ctarget: the batch, as expected by ``loss``.

    Returns:
        ``(new_opt_state, params)`` where ``params`` are the parameters read
        from the incoming ``opt_state``, i.e. from *before* this step.
    """
    params = get_params(opt_state)

    def per_sample_loss(p, state, cond, target):
        # Differentiate the batch loss divided by the batch size.
        return loss(p, state, cond, target)/len(cstate)

    grads = jax.grad(per_sample_loss, 0)(params, cstate, cconditionals, ctarget)
    return opt_update(i, grads, opt_state), params


In [None]:
filename = 'best_sr_params_hamiltonian_true_canonical.pkl'
batch = 512  # samples per optimiser step

# Best epoch seen in this run. Both names must exist before the loop: the
# comparison below reads best_loss, and the save step reads best_params.
best_loss = np.inf
best_params = None

try:
    for epoch in tqdm(range(total_epochs)):
        epoch_loss = 0.0
        num_samples = 0
        # One fresh synthetic dataset per epoch (seeded by the epoch index),
        # sliced into minibatches below.
        ocstate, occonditionals, octarget = gen_data_batch(epoch, minibatch_per*batch)

        for minibatch in range(minibatch_per):
            # Progress through the whole run in [0, 1); drives the LR schedule.
            fraction = (epoch + minibatch/minibatch_per)/total_epochs
            s = np.s_[minibatch*batch:(minibatch+1)*batch]

            cstate, cconditionals, ctarget = ocstate[s], occonditionals[s], octarget[s]
            opt_state, params = update_derivative(fraction, opt_state, cstate, cconditionals, ctarget)
            rng += 10  # advances the notebook-level key; not consumed in this cell

            # `params` are the pre-update parameters returned by update_derivative.
            cur_loss = loss(params, cstate, cconditionals, ctarget)

            epoch_loss += cur_loss
            num_samples += len(cstate)

        closs = epoch_loss/num_samples
        print('epoch={} lr={} loss={}'.format(
            epoch, OneCycleLR(fraction), closs)
            )
        if closs < best_loss:
            best_loss = closs
            # Copy to host so the snapshot is independent of later updates.
            best_params = [[copy(jax.device_get(l2)) for l2 in l1] if len(l1) > 0 else () for l1 in params]

except KeyboardInterrupt:
    # Allow stopping early by hand; the best snapshot is still saved below.
    print('Training interrupted; saving the best parameters found so far.')
finally:
    # Save on normal completion as well as on interruption, so that the
    # evaluation cell always finds a checkpoint. Skipped when no epoch finished.
    if best_params is not None:
        with open(filename, 'wb') as checkpoint_file:
            pkl.dump({'params': best_params, 'description': 'q and g are divided by 10. hidden=500. act=Softplus'},
                     checkpoint_file)

## Evaluation of the model

In [None]:
# Reload the checkpoint written by the training cell.
# NOTE: pickle can execute arbitrary code on load -- only open files you created.
with open('best_sr_params_hamiltonian_true_canonical.pkl', 'rb') as checkpoint_file:
    best_params = pkl.load(checkpoint_file)['params']

# Round-trip through the optimiser so `params` has the structure get_params returns.
opt_state = opt_init(best_params)
params = get_params(opt_state)

In [None]:
# Compare one ground-truth trajectory with the trajectory obtained by
# integrating the learned dynamics from the same initial condition.
n_steps = 50
t_eval = jnp.linspace(0, 1, num=n_steps)  # same grid gen_data integrates on
trajectory_seed = 4

# Learned vector field, batched over a leading axis.
runner = jax.jit(jax.vmap(
    partial(
        hamiltonian_eom,
        learned_dynamics(params, nn_forward_fn)), (0, 0), 0))

@jax.jit
def odefunc_learned(y, t):
    """ODE right-hand side for y = (q, q_t, g); g is held constant."""
    return jnp.concatenate((runner(y[None, :2], y[None, [2]])[0], jnp.zeros(1)))

# A single trajectory (batch of 1), so rows of cstate are consecutive time points.
cstate, cconditionals, _ = gen_data(trajectory_seed, 1, n_steps)

yt_learned = odeint(
    odefunc_learned,
    jnp.concatenate([cstate[0], cconditionals[0]]),
    t_eval,
    mxsteps=100)

fig, ax = plt.subplots(1, 1, figsize=(4, 4))
# Plot against the actual time values so the 'Time' label is accurate.
ax.plot(t_eval, cstate[:, 1], label='Truth')
ax.plot(t_eval, yt_learned[:, 1], label='Learned')
ax.legend()
ax.set_xlabel('Time')
ax.set_ylabel('Velocity of particle/Speed of light')
ax.set_ylim(-1, 1)
ax.set_title("Hamiltonian NN - Special Relativity")
fig.tight_layout()
#fig.savefig('sr_hnn_true.png', dpi=150)