In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from functools import partial
import os
import pickle as pkl
from collections.abc import MutableMapping
import time
from typing import Any, Callable, Iterable, Mapping, Optional, Union
import json

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
import tqdm
from IPython.display import HTML

os.environ["CUDA_VISIBLE_DEVICES"] = "7"

import jax
from jax import vmap, lax
import jax.numpy as jnp
from jax.example_libraries import optimizers

import flax
from flax import linen as nn
import optax
from frozendict import frozendict

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMC, SVI, Trace_ELBO

import jraph
from jraph._src import graph as gn_graph
from jraph._src import utils

print(f'Jax: CPUs={jax.local_device_count("cpu")} - GPUs={jax.local_device_count("gpu")}')

  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)


Jax: CPUs=1 - GPUs=1


In [3]:
from hgnn.noisify import add_noise_and_truncate
from hgnn.model import *
from hgnn.hamiltonian import *
from hgnn.training import *
from hgnn.simulating import *

In [4]:
prefix = 'nbody-n4'
noise_scale = 0.01
truncate_decimal = 2

Zs_train = jnp.load(f'./data/{prefix}/Zs_train.npy')
Zs_dot_train = jnp.load(f'./data/{prefix}/Zs_dot_train.npy')
Zs_test = jnp.load(f'./data/{prefix}/Zs_test.npy')
Zs_dot_test = jnp.load(f'./data/{prefix}/Zs_dot_test.npy')

Zs_train, Zs_dot_train = add_noise_and_truncate(Zs_train, Zs_dot_train, 
                                                scale=noise_scale, 
                                                decimals=truncate_decimal)
# Zs_test, Zs_dot_test = add_noise_and_truncate(Zs_test, Zs_dot_test, scale=0.05, decimals=1)

N2, dim = Zs_train.shape[-2:]
N = N2 // 2
species = jnp.zeros(N, dtype=int)
masses = jnp.ones(N)

Zs = Zs_train.reshape(-1, N2, dim)
Zs_dot = Zs_dot_train.reshape(-1, N2, dim)

Zst = Zs_test.reshape(-1, N2, dim)
Zst_dot = Zs_dot_test.reshape(-1, N2, dim)

print('N2:', N2)
print('dim:', dim)
print('Zs.shape:', Zs.shape)
print('Zs_dot.shape:', Zs_dot.shape)
print('Zst.shape:', Zst.shape)
print('Zst_dot.shape:', Zst_dot.shape)
print('Zs_test.shape:', Zs_test.shape)
print()

with open(f'./data/{prefix}/param.json', 'r') as f:
    d = json.load(f)
    stride = d['stride']
    dt = d['dt']
    lr = d['lr']
    batch_size = d['batch_size']
    epochs = d['epochs']
    
print('stride:', stride)
print('dt:', dt)
print('lr:', lr)
print('batch_size:', batch_size)
print('epochs:', epochs)

N2: 8
dim: 3
Zs.shape: (7500, 8, 3)
Zs_dot.shape: (7500, 8, 3)
Zst.shape: (2500, 8, 3)
Zst_dot.shape: (2500, 8, 3)
Zs_test.shape: (5, 500, 8, 3)

stride: 100
dt: 0.001
lr: 0.01
batch_size: 100
epochs: 2000


In [5]:
def get_fully_connected_senders_and_receivers(num_particles: int, self_edges: bool = False):
    """Returns senders and receivers for fully connected particles."""
    particle_indices = np.arange(num_particles)
    senders, receivers = np.meshgrid(particle_indices, particle_indices)
    senders, receivers = senders.flatten(), receivers.flatten()
    if not self_edges:
        mask = senders != receivers
        senders, receivers = senders[mask], receivers[mask]
    return senders, receivers

def get_fully_edge_order(N):
    out = []
    for j in range(N):
        for i in range(N):
            if i == j:
                pass
            else:
                if j > i:
                    out += [i*(N-1) + j-1]
                else:
                    out += [i*(N-1) + j]
    return np.array(out)

senders, receivers = get_fully_connected_senders_and_receivers(N)
eorder = get_fully_edge_order(N)

In [6]:
key = jax.random.PRNGKey(42)

Ef = 1  # eij dim
Oh = 1

Eei = 5
Nei = 5

hidden = 5
nhidden = 2

In [7]:
R, V = jnp.split(Zs[0], 2, axis=0)

apply_fn = energy_fn(senders=senders, receivers=receivers, species=species, R=R, V=V, eorder=eorder, dropout_rate=0.)
Hmodel = generate_Hmodel(apply_fn)

zdot_model, lamda_force_model = get_zdot_lambda(
    N, dim, hamiltonian=Hmodel, drag=None, 
    constraints=None, 
    external_force=None)


v_zdot_model = vmap(zdot_model, in_axes=(0, 0, None))

In [8]:
def initialize_mlp_prob(name, sizes, affine=[False], scale=1.0):
    """ Initialize the weights of all layers of a linear layer network """
    # Initialize a single layer with Gaussian weights -  helper function
    if len(affine) != len(sizes):
        affine = [affine[0]]*len(sizes)
    affine[-1] = True

    def initialize_layer(name_, m, n, affine=True, scale=1e-2):
        ws = numpyro.sample(f'{name_}_W', dist.Normal(scale=scale), sample_shape=(n, m))
#         bs = numpyro.sample(f'{name_}_b', dist.Normal(scale=(0. if affine else scale)), sample_shape=(n,))
        bs = numpyro.sample(f'{name_}_b', dist.Normal(scale=scale), sample_shape=(n,))
        return ws, bs
        
    return [initialize_layer(f'{name}_{i}', m, n, affine=a, scale=scale) 
            for i, (m, n, a) in enumerate(zip(sizes[:-1], sizes[1:], affine))]


def mlp_prob(name, in_, out_, hidden, nhidden, **kwargs):
    return initialize_mlp_prob(name, get_layers(in_, out_, hidden, nhidden), **kwargs)


def generate_HGNN_params_prob(Oh, Nei, Ef, Eei, dim, hidden, nhidden):
    
    fneke_params = initialize_mlp_prob('fneke', [Oh, Nei])
    fne_params = initialize_mlp_prob('fne', [Oh, Nei])

    fb_params = mlp_prob('fb', Ef, Eei, hidden, nhidden)
    fv_params = mlp_prob('fv', Nei + Eei, Nei, hidden, nhidden)
    fe_params = mlp_prob('fe', Nei, Eei, hidden, nhidden)

    ff1_params = mlp_prob('ff1', Eei, 1, hidden, nhidden)
    ff2_params = mlp_prob('ff2', Nei, 1, hidden, nhidden)
    ff3_params = mlp_prob('ff3', dim + Nei, 1, hidden, nhidden)
    ke_params = initialize_mlp_prob('ke', [1 + Nei, 10, 10, 1], affine=[True])

    Hparams = dict(
        fb=fb_params,
        fv=fv_params,
        fe=fe_params,
        ff1=ff1_params,
        ff2=ff2_params,
        ff3=ff3_params,
        fne=fne_params,
        fneke=fneke_params,
        ke=ke_params
    )
    
    return {"H": Hparams}

In [9]:
def model(Rs_, Vs_, Zs_dot_=None):
    
    p_ = generate_HGNN_params_prob(Oh, Nei, Ef, Eei, dim, hidden, nhidden)
    
    with numpyro.plate("data", len(Rs_), dim=-1, subsample_size=batch_size) as ind:
        R = Rs_[ind]
        V = Vs_[ind]
        Zdot = None if (Zs_dot_ is None) else Zs_dot_[ind]
        pred = v_zdot_model(R, V, p_)
    numpyro.sample('zs_dot', dist.Normal(pred, 1e-4), obs=Zdot)
        
#     pred = v_zdot_model(Rs_, Vs_, p_)
#     obs = numpyro.sample('zs_dot', dist.Normal(pred, scale=1e-4), obs=Zs_dot_)  # originally 1e-4
        
    return pred

In [10]:
Rs, Vs = jnp.split(Zs, 2, axis=1)
Rst, Vst = jnp.split(Zst, 2, axis=1)

# bRs, bVs, bZs_dot = batching(Rs, Vs, Zs_dot, size=min(len(Rs), batch_size))
# batches = list(zip(bRs, bVs, bZs_dot))
# batches = jnp.concatenate([bRs, bZs_dot], axis=2)

In [11]:
# Start from this source of randomness. We will split keys for subsequent operations.
rng_key = random.PRNGKey(42)
rng_key, rng_key_ = random.split(rng_key)

guide = numpyro.infer.autoguide.AutoNormal(model, init_scale=0.1)
optimizer = numpyro.optim.ClippedAdam(step_size=1e-3, clip_norm=1e-7)
svi = SVI(model, guide=guide, optim=optimizer, loss=Trace_ELBO())
svi_result = svi.run(rng_key_, 200000, Rs, Vs, Zs_dot, stable_update=True)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200000/200000 [13:41<00:00, 243.50it/s, init loss: 111360771044147200.0000, avg. loss [190001-200000]: 351681856.0000]


In [12]:
#51825428.0000

In [13]:
def construct_param(samples, i):
    
    d = dict()
    p_names = set([n.split('_')[0] for n in samples.keys()])
    
    for n in p_names:
        k = 0
        l = []
        while f'{n}_{k}_W' in samples:
            ws = samples[f'{n}_{k}_W'][i]
            bs = samples[f'{n}_{k}_b'][i]
            l.append((ws, bs))
            k += 1
        d[n] = l
        
    return {'H': d}

In [None]:
num_samples = 100

all_traj = []

test_count = Zs_test.shape[0]
runs = Zs_test.shape[1]

sim_model = get_forward_sim_noparam(
    zdot_model=zdot_model, runs=runs, stride=stride, dt=dt, tol=1e-5)

predictive = numpyro.infer.Predictive(guide, params=svi_result.params, num_samples=num_samples)
samples = predictive(rng_key, Rs, Vs, Zs_dot)

pbar = tqdm.tnrange(test_count * num_samples)

for idx in range(test_count):
    
    pbar.set_description(f'Test {idx+1} of {test_count}')

    z_actual_out = Zs_test[idx]
    x_act_out, p_act_out = jnp.split(z_actual_out, 2, axis=1)

    Zs_init = Zs_test[idx:idx+1, 0:1]

    with jax.default_device(jax.devices('cpu')[0]):
        trajectories = {
            'pred_pos': [],
            'pred_vel': [],
            'actual_pos': jnp.array(x_act_out),
            'actual_vel': jnp.array(p_act_out),
        }

    for i in range(num_samples):

        st = construct_param(samples, i)

        Zs_noisy = add_noise_and_truncate(Zs_init, Zs_init, scale=noise_scale)[0].squeeze((0, 1))
        R_noisy = Zs_noisy[:N]
        V_noisy = Zs_noisy[N:]

#         R_noisy = Zs_init.squeeze((0, 1))[:N]
#         V_noisy = Zs_init.squeeze((0, 1))[N:]

        z_pred_out = sim_model(R_noisy, V_noisy, st)
        x_pred_out, p_pred_out = jnp.split(z_pred_out, 2, axis=1)

        with jax.default_device(jax.devices('cpu')[0]):
            trajectories['pred_pos'].append(x_pred_out)
            trajectories['pred_vel'].append(p_pred_out)
            
        pbar.update()
        
    with jax.default_device(jax.devices('cpu')[0]):
        trajectories['pred_pos'] = jnp.array(trajectories['pred_pos'])
        trajectories['pred_vel'] = jnp.array(trajectories['pred_vel'])
        trajectories['pred_pos_avg'] = jnp.mean(trajectories['pred_pos'], axis=0)
        trajectories['pred_vel_avg'] = jnp.mean(trajectories['pred_vel'], axis=0)
    
    all_traj.append(trajectories)

  pbar = tqdm.tnrange(test_count * num_samples)


  0%|          | 0/500 [00:00<?, ?it/s]

In [None]:
with open('./results/nbody-n4-vi.pkl', 'wb+') as f:
    pkl.dump(all_traj, f)

In [None]:
idx = 0

trajectories = all_traj[idx]

r = trajectories['actual_pos']
for i in range(r.shape[1]):
    plt.plot(r[:,i,0], r[:,i,1], '-', color=f'C{i}', alpha=0.2)
plt.plot(r[-1,:,0], r[-1,:,1], 'o', color='black', alpha=0.1)

r = trajectories['pred_pos'][15]
for i in range(r.shape[1]):
    plt.plot(r[:,i,0], r[:,i,1], '-', color=f'C{i}', alpha=1.)
plt.plot(r[-1,:,0], r[-1,:,1], 'o', color='black', alpha=1.)

plt.gca().set_aspect('equal', adjustable='box')
plt.show()

In [None]:
i_err = 1

r = trajectories['actual_pos']
for i in range(r.shape[1]):
    plt.plot(r[:,i,0], r[:,i,1], '-', color=f'C{i}', alpha=0.2)
plt.plot(r[-1,:,0], r[-1,:,1], 'o', color='black', alpha=0.1)

r = trajectories['pred_pos_avg']
for i in range(r.shape[1]):
    if i != i_err:
        plt.plot(r[:,i,0], r[:,i,1], '-', color=f'C{i}')
plt.plot(r[-1,:,0], r[-1,:,1], 'o', color='black')

r = trajectories['pred_pos']
for t in range(r.shape[1]):
    plt.plot(r[:,t,i_err,0], r[:,t,i_err,1], '.', color=f'C{i_err}', markerfacecolor=None, alpha=1. / num_samples)

plt.gca().set_aspect('equal', adjustable='box')
plt.show()

In [None]:
fig, ax = plt.subplots()

# r = trajectories['actual_pos']
# traj_actual = [ax.plot(r[:,i,0], r[:,i,1], '-', color=f'C{i}', alpha=0.1)[0] for i in range(r.shape[1])]  
# ball_actual, = ax.plot(r[-1,:,0], r[-1,:,1], 'o', color='black', alpha=0.1)

r = trajectories['pred_pos_avg']
traj_pred = [ax.plot(r[:,i,0], r[:,i,1], '-', color=f'C{i}', alpha=1.)[0] for i in range(r.shape[1])]  
ball_pred, = ax.plot(r[-1,:,0], r[-1,:,1], 'o', color='black', alpha=0.5, zorder=5.)

r = trajectories['pred_pos']
point_cloud = [ax.plot(r[:,0,i,0], r[:,0,i,1], 'o', color=f'C{i}', 
                       alpha=2. / num_samples, zorder=4., markerfacecolor=None)[0] 
               for i in range(r.shape[2])]  

def gather():
#     return traj_actual + [ball_actual] + traj_pred + [ball_pred]
    return point_cloud + traj_pred + [ball_pred]

def init():
    ax.set_aspect('equal', adjustable='box')
    return gather()

def update(frame):
    
#     r = trajectories['actual_pos']
#     for i in range(r.shape[1]):
#         traj_actual[i].set_data(r[:frame,i,0], r[:frame,i,1])
#     ball_actual.set_data(r[frame,:,0], r[frame,:,1])

    r = trajectories['pred_pos_avg']
    for i in range(r.shape[1]):
        traj_pred[i].set_data([], [])
    ball_pred.set_data(r[frame,:,0], r[frame,:,1])
    
    r = trajectories['pred_pos']
    for i in range(r.shape[2]):
        point_cloud[i].set_data(r[:,frame,i,0], r[:,frame,i,1])
    
    return gather()

ani = FuncAnimation(fig, update, frames=tqdm.tnrange(runs), init_func=init, blit=True, interval=10)
plt.close(fig)
HTML(ani.to_jshtml())

In [None]:
writergif = PillowWriter(fps=30, bitrate=300) 
ani.save(f'./demo/nbody-n4-vi-{idx}.mp4', writer=writergif)