In [None]:
%load_ext autoreload
%autoreload 2

import optax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import matplotlib.pyplot as plt


from spring_utils import get_observations, get_zs, get_A
from spring_gradients import marginal_likelihood
from training_spring import fit
from constants import RAND_KEY

### Constants

In [None]:
# Scalar parameters handed to get_A / get_observations (delta_t, m, k, z).
delta_t = 0.1
m, k, z = 3.0, 20.0, 0.5

# Vector/matrix parameters handed to get_observations, fit and marginal_likelihood.
# mu0 is built from integer literals, so it carries an integer dtype.
mu0 = jnp.array([3, 1])
V0 = 0.0001 * jnp.eye(2)
trans_noise = 0.01 * jnp.eye(2)
obs_noise = 0.5 * jnp.eye(2)

# Fixed key used for the synthetic-data cell.
JAX_KEY = jrandom.PRNGKey(2)

### Generate and inspect synthetic data

In [None]:
num_steps, N = 500, 4

# Simulate with the fixed key so the plot below is reproducible; both returned
# arrays are kept (presumably N sequences of length num_steps — see get_observations).
zs, xs = get_observations(
    k, delta_t, m, z,
    mu0, V0, trans_noise, obs_noise,
    num_steps, N,
    key=JAX_KEY,
)

In [None]:
# Overlay one coordinate of the first sequence from each array returned by
# get_observations so the two can be compared by eye.
# NOTE(review): indexing [:, 0, 0] assumes a (step, sequence, dim) layout — TODO confirm.
fig, ax = plt.subplots(figsize=(15, 6))
ax.plot(zs[:, 0, 0], label="zs[:, 0, 0]")
ax.plot(xs[:, 0, 0], label="xs[:, 0, 0]")
ax.set(xlabel="step", ylabel="value", title="First sequence, first coordinate")
ax.legend()
plt.show()

### Learn the dynamics of the mass-spring system

In [None]:
N = 10000
num_steps = 3
NUM_TRAIN_STEPS = 1500

# Fresh training dataset; only the second returned array is used.
_, xs = get_observations(k, delta_t, m, z, mu0, V0, trans_noise, obs_noise, num_steps, N)

# False -> reparameterized gradients, True -> likelihood-ratio gradients.
LR_ESTIMATOR = False

# Start from the identity; params carries a leading axis of length 1.
A_init = jnp.eye(2)
params = jnp.array([A_init])

# optax.adam already emits descent updates (-lr * direction); scale(-1.0) flips
# them, so applying the updates moves params *up* the gradient, i.e. fit
# presumably maximizes its objective (compare with the `goal` line plotted below).
# fit receives only `optimizer` and `params`, so it creates its own optimizer
# state; no separate `optimizer.init` call is needed here.
optimizer = optax.chain(
    optax.adam(learning_rate=0.0006),
    optax.scale(-1.0)
)

print(f"True value of A: {get_A(k, delta_t, m, z)}\n")

learned_params, losses, gradients = fit(
    params=params,
    optimizer=optimizer,
    training_steps=NUM_TRAIN_STEPS,
    mu0=mu0, V0=V0,
    trans_noise=trans_noise,
    obs_noise=obs_noise, xs=xs,
    num_steps=num_steps,
    N=N,
    lr_estimator=LR_ESTIMATOR,
)

In [None]:
# Tag the saved objectives with the estimator that produced them, using the same
# "lr"/"rp" prefixes as the gradient files below, so an RP run does not get
# saved under (and overwrite) the LR filename.
estimator_tag = "lr" if LR_ESTIMATOR else "rp"
np.save(f"{estimator_tag}_training_objectives.npy", losses)

# Reference level: marginal_likelihood evaluated at the true A with fresh noise.
epsilons = jrandom.normal(key=jrandom.PRNGKey(4), shape=(num_steps, N, 2))
goal = marginal_likelihood(get_A(k, delta_t, m, z), mu0, V0, trans_noise, obs_noise, epsilons, xs)

fig, ax = plt.subplots(figsize=(10, 6))
ax.axhline(goal, color='red', linestyle='dashed', label="marginal likelihood at true A")
ax.plot(losses, label=f"training objective ({estimator_tag.upper()})")
ax.set(xlabel="training step", ylabel="objective")
ax.legend()
plt.show()

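### Gradient variance analysis

Compare the variance of reparameterized (RP) and likelihood-ratio (LR) gradient estimates across independently generated datasets.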

In [None]:
NUM_BATCHES = 20
NUM_TRAIN_STEPS = 1500
num_samples = 10000

# One slice per independent dataset ("batch"): the gradient recorded at each
# training step, per estimator. num_steps, params and optimizer come from the
# training section.
all_rp_gradients = np.zeros((NUM_BATCHES, NUM_TRAIN_STEPS, 2, 2))
all_lr_gradients = np.zeros((NUM_BATCHES, NUM_TRAIN_STEPS, 2, 2))

key = RAND_KEY

for i in range(NUM_BATCHES):
    # Separate keys for generating the data and for the estimators' own
    # sampling. Handing one key to both can make the estimator's noise coincide
    # with the noise that produced xs; a JAX key must not be reused for
    # independent random draws.
    key, data_key, fit_key = jrandom.split(key, 3)
    _, xs = get_observations(
        k, delta_t, m, z, mu0, V0, trans_noise, obs_noise, num_steps, num_samples, key=data_key
    )

    # Both estimators see the same data and the same fit_key, so (assuming fit
    # consumes the key the same way in both modes — TODO confirm) differences
    # between them come from the estimator itself.
    fit_kwargs = dict(
        params=params,
        optimizer=optimizer,
        training_steps=NUM_TRAIN_STEPS,
        mu0=mu0, V0=V0,
        trans_noise=trans_noise,
        obs_noise=obs_noise, xs=xs,
        num_steps=num_steps,
        N=num_samples,
        key=fit_key,
    )

    # Reparameterized gradients
    _, _, rp_gradients = fit(**fit_kwargs, lr_estimator=False)

    # Likelihood-ratio gradients
    _, _, lr_gradients = fit(**fit_kwargs, lr_estimator=True)

    all_rp_gradients[i] = rp_gradients
    all_lr_gradients[i] = lr_gradients

np.save(f'rp_gradient_batches_{num_samples}_samples.npy', all_rp_gradients)
np.save(f'lr_gradient_batches_{num_samples}_samples.npy', all_lr_gradients)

In [None]:
# Flatten each 2x2 gradient into 4 components (component 0 is A[0,0]); then,
# per estimator, take the variance across batches at every training step for
# component 0 only and average it over the training steps.
flat_shape = (NUM_BATCHES, NUM_TRAIN_STEPS, -1)

all_lr_gradients = np.reshape(all_lr_gradients, flat_shape)
lr_grad_var = np.mean(np.var(all_lr_gradients, axis=0)[:, 0])
print(lr_grad_var)

all_rp_gradients = np.reshape(all_rp_gradients, flat_shape)
rp_grad_var = np.mean(np.var(all_rp_gradients, axis=0)[:, 0])
print(rp_grad_var)

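### Consistency check

Track how far apart the likelihood-ratio (LR) and reparameterized (RP) gradient estimates are as the number of samples grows.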

In [None]:
NUM_TRAIN_STEPS = 1500
# One sample size per decade: 10, 100, ..., 1_000_000 (exact integers). This
# matches the log-scaled x-axis of the plot below; np.linspace(1e1, 1e6, 6)
# would place five of the six sizes above 2e5.
samples = 10 ** np.arange(1, 7)

A_init = jnp.eye(2)
params = jnp.array([A_init])

# adam emits descent updates; scale(-1.0) flips them to ascent. fit builds its
# own optimizer state from `optimizer` and `params`.
optimizer = optax.chain(
    optax.adam(learning_rate=0.01),
    optax.scale(-1.0)
)

# For each sample size and training step: |LR gradient - RP gradient| for each
# of the 4 entries of A (row-major flattening of the 2x2 gradient).
A_diffs = np.zeros((samples.shape[0], NUM_TRAIN_STEPS, 4))

for i, num_samples in enumerate(samples):
    # z must be passed positionally after m, as in every other call; omitting
    # it shifts mu0, V0, ... into the wrong parameters.
    _, xs = get_observations(
        k, delta_t, m, z, mu0, V0, trans_noise, obs_noise, num_steps, num_samples
    )

    fit_kwargs = dict(
        params=params,
        optimizer=optimizer,
        training_steps=NUM_TRAIN_STEPS,
        mu0=mu0, V0=V0,
        trans_noise=trans_noise,
        obs_noise=obs_noise, xs=xs,
        num_steps=num_steps,
        N=num_samples,
    )

    # Reparameterized gradients
    _, _, rp_gradients = fit(**fit_kwargs, lr_estimator=False)

    # Likelihood-ratio gradients
    _, _, lr_gradients = fit(**fit_kwargs, lr_estimator=True)

    A_diffs[i] = np.abs(lr_gradients - rp_gradients).reshape(NUM_TRAIN_STEPS, -1)

# Persist the result: the plotting cell reads it back from A_diffs.npy.
np.save("A_diffs.npy", A_diffs)


In [None]:
# Mean (over training steps) |LR - RP| gradient gap per entry of A, versus the
# number of samples. The array is read from A_diffs.npy and must have been
# computed on the same `samples` grid.
A_diffs_saved = np.load("A_diffs.npy")
assert A_diffs_saved.shape[0] == len(samples), (
    "A_diffs.npy does not match the current `samples` grid; re-run the consistency check"
)

# Index j follows the row-major flattening of the 2x2 gradient matrix.
component_styles = [
    ("A[0,0]", "cornflowerblue"),
    ("A[0,1]", "orange"),
    ("A[1,0]", "green"),
    ("A[1,1]", "purple"),
]

fig, ax = plt.subplots(figsize=(10, 10))
for j, (label, color) in enumerate(component_styles):
    ax.plot(samples, A_diffs_saved[:, :, j].mean(axis=1), color=color, label=label)
ax.plot(samples, A_diffs_saved.mean(axis=(1, 2)), color="pink", label="Average")
ax.set(
    xscale="log",
    yscale="log",
    xlabel="number of samples",
    ylabel="mean |LR gradient - RP gradient|",
    title="LR vs RP gradient gap",
)
ax.legend()
plt.show()