In [None]:
%load_ext autoreload
%autoreload 2

import optax
import numpy as np
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt

from pendulum_rk4 import get_pendulum_sequence, get_cos_sin_states
from training import fit
from polynomial_model import predict, get_zs
from constants import RAND_KEY

### Generate double pendulum states with RK4

In [None]:
# Rollout settings handed to the RK4 simulator.
dt = 1e-2
num_steps = 2000
N = 100

# Initial 4-vector: pi/4, pi/3, then two zeros
# (presumably two angles followed by two angular velocities -- confirm in pendulum_rk4).
start_state = jnp.array([jnp.pi / 4, jnp.pi / 3, 0.0, 0.0])

# Scaled 4x4 identity matrices used as the initial, transition and observation covariances.
V0 = 1e-5 * jnp.eye(4)
trans_noise = 1e-3 * jnp.eye(4)
obs_noise = 1e-3 * jnp.eye(4)

# zs / xs are the two sequences returned by the simulator (presumably latent states and
# noisy observations -- confirm); cos_sin_zs is the cos/sin re-encoding of zs.
zs, xs = get_pendulum_sequence(
    start_state, V0, trans_noise, obs_noise, num_steps, N, dt
)
cos_sin_zs = get_cos_sin_states(zs)

In [None]:
# Line plot of columns 2 and 3 of the cos/sin states over the first 2000 entries of
# axis 0, taking index 0 on axis 1 (presumably the first sampled trajectory -- confirm).
plt.figure(figsize=(15, 6))
# plt.plot(zs[:, 0, 1], label="Theta2")
for column, label in ((2, "Cos Theta2"), (3, "Sin Theta2")):
    plt.plot(cos_sin_zs[:2000, 0, column], label=label)
plt.legend()
plt.show()

In [None]:
# Lag-1 scatter: each theta2 feature at one step against the same feature at the next step.
#
# Column layout used elsewhere in this notebook: index 2 = cos(theta2), index 3 = sin(theta2).
# The previous version read column 4 for "Sin theta2" and column 3 for "Cos theta2".
# NOTE(review): if the cos/sin state has only 4 columns, index 4 is out of range and JAX
# silently clamps it to 3, so both scatters showed the same data -- confirm the width
# returned by get_cos_sin_states.
plt.figure(figsize=(15, 6))
plt.scatter(cos_sin_zs[:-1, :, 3], cos_sin_zs[1:, :, 3], label="Sin theta2")
plt.scatter(cos_sin_zs[:-1, :, 2], cos_sin_zs[1:, :, 2], label="Cos theta2")
plt.legend()
plt.show()

### Learn parameters with regression

In [None]:
# --- Simulation and training configuration -------------------------------------------
dt = 0.01
V0 = jnp.eye(4) * 0.0001
trans_noise = jnp.eye(4) * 0.001
obs_noise = jnp.eye(4) * 0.1

num_steps = 10
N = 10000
NUM_TRAINING_STEPS = 4000
LR_ESTIMATOR = False

# --- Training data ---------------------------------------------------------------------
# The simulator is started from the angle-space vector; it gets its own name so that
# `start_state` below is only ever bound to the cos/sin-space start.
start_angles = jnp.array([jnp.pi / 4, jnp.pi / 3, 0., 0.])
zs, xs = get_pendulum_sequence(start_angles, V0, trans_noise, obs_noise, num_steps, N, dt)
cos_sin_zs = get_cos_sin_states(zs)
cos_sin_xs = get_cos_sin_states(xs)

# Start state in cos/sin space: [cos(a0), sin(a0), cos(a1), sin(a1)] for the first two
# entries of start_angles. Later cells read `start_state` in this form.
start_state = jnp.array([
    jnp.cos(start_angles[0]),
    jnp.sin(start_angles[0]),
    jnp.cos(start_angles[1]),
    jnp.sin(start_angles[1]),
])

# --- Model parameters ------------------------------------------------------------------
num_features = 4

start_weights = jnp.zeros((num_features,))
# start_weights = jnp.array([0.2725405, 0.33645087, 0.22920588, 0.15438452])
# Wrapped in an outer list, so params has shape (1, num_features).
params = jnp.array([start_weights])

# --- Optimiser -------------------------------------------------------------------------
# Adam followed by a sign flip of its updates (presumably so that `fit` ascends its
# objective rather than descending it -- confirm against training.fit).
optimizer = optax.chain(
    optax.adam(learning_rate=0.0005),
    optax.scale(-1.0)
)

# The bare `optimizer.init(params)` call that used to sit here was removed: optax's
# `init` is a pure function that returns the optimiser state, and the result was thrown
# away, so the call had no effect. NOTE(review): the state is presumably created inside
# `fit` -- confirm.

# --- Fit -------------------------------------------------------------------------------
learned_params, training_objectives, gradients = fit(
    params=params,
    optimizer=optimizer,
    training_steps=NUM_TRAINING_STEPS,
    num_features=num_features,
    start_state=start_state,
    V0=V0,
    trans_noise=trans_noise,
    obs_noise=obs_noise,
    xs=cos_sin_xs,
    # xs=cos_sin_zs,
    num_steps=num_steps,
    N=N,
    lr_estimator=LR_ESTIMATOR,
)

In [None]:
# Standard-normal draws, one 4-vector per (step, sample), passed to get_zs.
epsilons = jrandom.normal(key=RAND_KEY, shape=(num_steps, N, 4))

# Model applied to the ground-truth states (all but the last entry along axis 0).
learned_zs = predict(learned_params[0], cos_sin_zs[:-1])
# State sequence produced by get_zs from the learned parameters, start state and noise.
pred_zs = get_zs(learned_params[0], start_state, V0, trans_noise, epsilons)

# arctan2 takes (y, x) = (sin, cos). start_state is laid out as [cos, sin, cos, sin], so
# column 1 is the sine and column 0 the cosine; the previous argument order (0, 1)
# returned pi/2 - theta instead of theta.
# NOTE(review): assumes pred_zs keeps the same column layout as start_state -- confirm.
pred_angles = jnp.arctan2(pred_zs[:, :, 1], pred_zs[:, :, 0])

# Scatter of column 2 at one step against column 3 at the next step, for the simulated
# states and for the model's state sequence.
plt.figure(figsize=(10, 6))
plt.scatter(cos_sin_zs[:-1, :, 2], cos_sin_zs[1:, :, 3], label="Ground truth")
plt.scatter(pred_zs[:-1, :, 2], pred_zs[1:, :, 3], label="State sequence prediction")
# plt.plot(pred_angles.mean(axis=1))
# plt.plot(zs[:, :, 0].mean(axis=1))
plt.xlabel("Cos theta2 (step t)")
plt.ylabel("Sin theta2 (step t+1)")
plt.legend()
plt.show()