In [None]:
## Imports
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib
from matplotlib import animation
import matplotlib.pyplot as plt
import os
from sklearn import decomposition
import scipy.interpolate
import sys
# Make the parent directory importable (for `action_angle_networks`). Guarded so
# that re-running this cell does not keep appending duplicate ".." entries.
if ".." not in sys.path:
    sys.path.append("..")

# Render matplotlib animations inline as interactive JavaScript players
# (same effect as matplotlib.rc("animation", html="jshtml")).
matplotlib.rcParams["animation.html"] = "jshtml"
# Style names intended for plt.style.context(...); 'science'/'ieee' presumably
# come from the SciencePlots package — TODO confirm. Not used by the cells shown here.
PLT_STYLE_CONTEXT = ['science', 'ieee', 'grid']

In [None]:
%load_ext autoreload

In [None]:
# @title Source Imports
%autoreload 2
from action_angle_networks.simulation import double_pendulum_simulation
from action_angle_networks.configs.double_pendulum import default
from action_angle_networks import analysis

In [None]:
# Build the default double-pendulum config and draw simulation parameters from
# its ranges with a fixed PRNG key (0), so the draw is reproducible.
config = default.get_config()
simulation_parameters = double_pendulum_simulation.sample_simulation_parameters(
    config.simulation_parameter_ranges.to_dict(),
    num_trajectories=config.num_trajectories,
    rng=jax.random.PRNGKey(0),
)

In [None]:
# Overwrite selected sampled parameters with fixed values, applied in this order:
# theta1_init = pi/2, theta2_init = 0, m2 = 0.1.
for _param_name, _param_value in (
    ("theta1_init", jnp.pi / 2),
    ("theta2_init", 0),
    ("m2", 0.1),
):
    simulation_parameters[_param_name] = _param_value

In [None]:
# Time grid from 0 (inclusive) up to 5 (exclusive) with step 0.01, and the
# canonical coordinates generated on it from the (overridden) parameters.
# NOTE(review): `positions`/`momentums` are not read by any later cell shown here.
times = jnp.arange(start=0, stop=5, step=0.01)
positions, momentums = double_pendulum_simulation.generate_canonical_coordinates(
    times,
    simulation_parameters,
)

## Visualizing

In [None]:
# @title Location of Pretrained Model
# Alternative model: config_name = "euler_update_flow"
config_name = "action_angle_flow"
setup = "setup_1"
# Root directory containing trained runs. Set DOUBLE_PENDULUM_WORKDIR_ROOT to
# point elsewhere instead of editing this cell; the default is the original
# author's local checkout.
WORKDIR_ROOT = os.environ.get(
    "DOUBLE_PENDULUM_WORKDIR_ROOT",
    "/Users/ameyad/Documents/google-research/tmp/double_pendulum",
)
# Built from `setup` and `config_name` (previously `config_name` was ignored and
# "action_angle_flow" was hardcoded). The trailing "" keeps the trailing slash.
workdir = os.path.join(WORKDIR_ROOT, setup, config_name, "")

In [None]:
# Load the trained run from `workdir`. This rebinds `config`, replacing the
# default config built earlier with default.get_config(); `simulation_parameters`
# is NOT re-derived from the loaded config.
# NOTE(review): presumably the loaded `config` is the one the run was trained with —
# confirm it agrees with the `simulation_parameters` used for plotting below.
# `scaler`, `state` and `aux` are not read by any later cell shown here.
config, scaler, state, aux = analysis.load_from_workdir(workdir)

In [None]:
# Train trajectories for this run (jump=1), animated. The plot_coordinates result
# is the cell's last expression, so Jupyter displays it.
# NOTE(review): plotted with the notebook's `simulation_parameters` (default config
# plus overrides), not parameters taken from `workdir` — confirm they match.
true_positions, true_momentums = analysis.get_train_trajectories(
    workdir,
    jump=1,
)
double_pendulum_simulation.plot_coordinates(
    true_positions,
    true_momentums,
    simulation_parameters,
    title="Train Trajectories",
)

In [None]:
# Test trajectories for this run (jump=1), animated; displayed as the cell's last
# expression. Note this rebinds the same `true_*` names used for the train plot.
# NOTE(review): plotted with the notebook's `simulation_parameters` — confirm they
# match the run in `workdir`.
true_positions, true_momentums = analysis.get_test_trajectories(
    workdir,
    jump=1,
)
double_pendulum_simulation.plot_coordinates(
    true_positions,
    true_momentums,
    simulation_parameters,
    title="Test Trajectories",
)

In [None]:
# Recursive multi-step predictions at a single jump size. `jump` is passed to
# the predictor and interpolated into the plot title; the animation is the
# cell's displayed output.
jump = 1
predicted_positions, predicted_momentums = (
    analysis.get_recursive_multi_step_predicted_trajectories(workdir, jump=jump)
)
double_pendulum_simulation.plot_coordinates(
    predicted_positions,
    predicted_momentums,
    simulation_parameters,
    title=f"Predicted Trajectories\n Jump {jump}",
)


## Saving

In [None]:
# Save the train-trajectory animation as a GIF. The output directory is created
# first: anim.save fails if the target directory does not exist yet.
output_dir = f"../notebook_outputs/double_pendulum/{setup}"
os.makedirs(output_dir, exist_ok=True)
train_positions, train_momentums = analysis.get_train_trajectories(workdir, jump=1)
anim = double_pendulum_simulation.plot_coordinates(
    train_positions,
    train_momentums,
    simulation_parameters,
    title="Train Trajectories",
)
anim.save(os.path.join(output_dir, "train_trajectories.gif"))

In [None]:
# Save the test-trajectory animation as a GIF. The output directory is created
# first so this cell also works on a fresh checkout (anim.save does not create it).
output_dir = f"../notebook_outputs/double_pendulum/{setup}"
os.makedirs(output_dir, exist_ok=True)
test_positions, test_momentums = analysis.get_test_trajectories(workdir, jump=1)
anim = double_pendulum_simulation.plot_coordinates(
    test_positions,
    test_momentums,
    simulation_parameters,
    title="Test Trajectories",
)
anim.save(os.path.join(output_dir, "test_trajectories.gif"))

In [None]:
# Save one GIF of recursive multi-step predictions per jump size. The output
# directory is created once, before the (expensive) loop, because anim.save
# fails if it does not exist.
output_dir = f"../notebook_outputs/double_pendulum/{setup}"
os.makedirs(output_dir, exist_ok=True)
for jump in [1, 2, 5, 10, 20, 50]:
    predicted_positions, predicted_momentums = (
        analysis.get_recursive_multi_step_predicted_trajectories(workdir, jump=jump)
    )
    anim = double_pendulum_simulation.plot_coordinates(
        predicted_positions,
        predicted_momentums,
        simulation_parameters,
        title=f"Predicted Trajectories\n Jump {jump}",
    )
    anim.save(
        os.path.join(
            output_dir,
            f"jump_{jump}_recursive_multi_step_predicted_trajectories.gif",
        )
    )
    # NOTE(review): figures are not closed here; if plot_coordinates creates
    # pyplot figures, consider closing them after saving to avoid accumulation.