# In [3]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt


def random_initial_state(key, batch_size):
    """Sample ``batch_size`` initial states: a random normalized angle and a zero.

    One PRNG subkey is split off per row, and each row's angle is drawn
    uniformly from [-160 deg, 160 deg] (in radians). The angle is wrapped
    into [-pi, pi) and divided by pi, so the first column lies within
    roughly [-8/9, 8/9]. The second column is all zeros.

    Args:
        key: JAX PRNG key used to derive one subkey per row.
        batch_size: Number of states (rows) to produce.

    Returns:
        JAX array of shape ``(batch_size, 2)``.
    """
    low = np.deg2rad(-160)
    high = np.deg2rad(160)

    def _draw_angle(subkey):
        # Scalar uniform sample; vmapped below over one subkey per row.
        return jax.random.uniform(subkey, minval=low, maxval=high)

    subkeys = jax.random.split(key, batch_size)
    angles = jax.vmap(_draw_angle)(subkeys)

    # Wrap into [-pi, pi). The sampled range already lies inside that
    # interval, so this only matters if the bounds are ever widened.
    period = 2 * jnp.pi
    angles = jnp.remainder(angles + jnp.pi, period) - jnp.pi

    normalized = angles / jnp.pi
    second = jnp.zeros_like(angles)
    return jnp.stack((normalized, second), axis=-1)

def _plot_loss_curve(losses):
    """Plot the per-epoch training loss on a log-scaled y-axis in a new figure."""
    fig, ax = plt.subplots()
    # A log axis gives the same curve shape as plotting log(loss) but keeps
    # tick labels in real loss units, so the 'Loss' axis label is accurate.
    ax.semilogy(np.asarray(losses))
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training Loss')
    ax.grid(True)
    plt.show()


def _plot_trajectories(state_trajectories, action_trajectories, tau, max_plots):
    """Overlay every epoch's state/action trajectories, one subplot per batch element."""
    # len() rather than truthiness: the container may be an array, whose
    # truth value is ambiguous.
    if len(state_trajectories) == 0:
        return

    # Each trajectory is indexed [time, batch, ...]: axis 0 is time (used for
    # the x-axis) and axis 1 enumerates the test batch (one subplot each).
    num_steps, batch_size = np.shape(state_trajectories[0])[:2]
    num_plots = min(max_plots, batch_size)
    if num_plots <= 0:
        return

    time_axis = np.arange(num_steps) * tau
    latest_epoch = len(state_trajectories) - 1

    # squeeze=False keeps `axes` 2-D even when there is a single subplot, so
    # indexing works uniformly; take the only column.
    fig, axes = plt.subplots(
        num_plots, 1, figsize=(10, 4 * num_plots), sharex=True, squeeze=False
    )
    axes = axes[:, 0]

    for i, ax in enumerate(axes):
        for epoch, states in enumerate(state_trajectories):
            actions = action_trajectories[epoch]
            is_latest = epoch == latest_epoch
            # Fade older epochs so the latest stands out, and label only the
            # latest so each quantity appears once in the legend.
            alpha = 1.0 if is_latest else 0.3
            theta_label = r'$\theta$ (Observed)' if is_latest else '_nolegend_'
            omega_label = r'$\omega$ (Observed)' if is_latest else '_nolegend_'
            # NOTE(review): the action curve is labelled with theta; confirm
            # this is the intended name for the action signal.
            action_label = r'$\theta$ (Action)' if is_latest else '_nolegend_'
            ax.plot(time_axis, states[:, i, 0], label=theta_label, color='b', alpha=alpha)
            ax.plot(time_axis, states[:, i, 1], label=omega_label, color='g', alpha=alpha)
            ax.plot(time_axis, actions[:, i], 'r--', label=action_label, alpha=alpha)
        ax.set_title(f'Test Batch {i + 1}')
        ax.set_ylabel('State')
        ax.legend(loc='upper right')
        # Explicit True: a bare grid() call toggles, which previously turned
        # the last subplot's grid back off.
        ax.grid(True)

    axes[-1].set_xlabel('Time (seconds)')
    fig.tight_layout()
    plt.show()


def plot_results(results, tau, max_plots=10):
    """Plot the training-loss curve and per-batch state/action trajectories.

    Shows two figures. The first plots ``results.losses`` against epoch on a
    logarithmic y-axis. The second has one subplot per test-batch element, up
    to ``max_plots``. Each subplot overlays every epoch's trajectories, with
    the latest epoch opaque and labelled and earlier epochs faded. Nothing is
    drawn for the second figure when there are no epochs or the batch is
    empty.

    Args:
        results: Object exposing ``losses`` (per-epoch scalar losses),
            ``state_trajectories`` and ``action_trajectories`` (sequences
            indexed by epoch). A state trajectory is indexed
            ``[time, batch, component]``; component 0 is plotted as theta and
            component 1 as omega. An action trajectory is indexed
            ``[time, batch]``.
        tau: Time between consecutive trajectory samples; the x-axis is
            labelled in seconds.
        max_plots: Maximum number of batch elements to draw subplots for.
    """
    _plot_loss_curve(results.losses)
    _plot_trajectories(
        results.state_trajectories, results.action_trajectories, tau, max_plots
    )
