In [None]:
from CartPole import CartPole, remap_angle, cartpole_step, remap_angle2
import numpy as np
import matplotlib
import pandas as pd
import seaborn as sns
matplotlib.use('TkAgg') 
import scipy


import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)

import pickle

---
# Task 4.1

In [None]:
# important helper functions for all tasks

def get_std(X):
    """Return the standard deviation of ``X`` taken along axis 0 (one value per column)."""
    per_column_std = np.std(X, axis=0)
    return per_column_std

def convert_dict_to_array(data):
    """Stack the value sequences of ``data`` as the columns of one numpy array.

    The keys do not need to be known in advance: columns follow the
    dictionary's own iteration order, and row ``i`` gathers the i-th element
    of every value sequence.
    """
    columns = data.values()
    rows = list(zip(*columns))
    return np.array(rows)


def rollout(initial_state, initial_force, num_steps, visual=True, max_force=20):
    """
    Simulate the CartPole environment under a constant force.

    Args:
        initial_state (tuple): Starting state, in the order (cart_location,
            cart_velocity, pole_angle, pole_velocity).
        initial_force (float): Force passed to ``performAction`` at every step.
        num_steps (int): Number of steps to simulate.
        visual (bool): Forwarded to ``CartPole``; when true the plot is closed
            once the rollout is finished.
        max_force (float): Forwarded to ``CartPole``.

    Returns:
        data: A dictionary with keys 'cart_location', 'cart_velocity',
              'pole_angle' and 'pole_velocity'. Each maps to a list of
              ``num_steps + 1`` values: the starting state followed by the
              state after each step.
    """
    env = CartPole(visual=visual, max_force=max_force)
    env.reset()
    env.setState(initial_state)

    # the dictionary keys double as the attribute names read from the environment
    recorded = ('cart_location', 'cart_velocity', 'pole_angle', 'pole_velocity')
    data = {name: [] for name in recorded}

    for _ in range(num_steps + 1):
        # record first, then advance: entry 0 is therefore the initial state
        for name in recorded:
            data[name].append(getattr(env, name))

        env.performAction(initial_force)

        # remap the angle to be between -pi and pi
        env.remap_angle()

    if visual:
        # tidy up the live plot
        env.close_plot()
        plt.close()

    return data

# Plotting functions --------------------------------------------------------
def plot_policy(X, target, graph_title):
    """Plot every state component of a policy rollout against its target value.

    Args:
        X: Rollout states; column ``k`` (``X[:, k]``) is drawn in panel ``k``
            in the order cart location, cart velocity, pole angle, pole velocity.
        target: Target state; ``target[k]`` is drawn as a dashed horizontal line.
        graph_title (str): Title placed above the 2x2 grid of panels.
    """
    fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15, 6))
    titles = ['cart_location', 'cart_velocity', 'pole_angle', 'pole_velocity']

    # ax.flat walks the grid row by row, so panel k shows state component k
    for k, (panel, title) in enumerate(zip(ax.flat, titles)):
        series = X[:, k]
        iterations = np.arange(0, len(series))
        panel.plot(iterations, series, 'r-', label='policy')
        panel.plot(iterations, [target[k]] * len(series), 'b--', label='target')
        panel.set_xlabel('Iterations')
        panel.set_ylabel(title)
        panel.grid()
        panel.legend()

    fig.suptitle(graph_title)
    # rect keeps a 3% margin free at the bottom of the figure
    fig.tight_layout(rect=[0, 0.03, 1, 1])
    # NOTE(review): the font size is changed after the figure is built, so it
    # presumably only takes effect for figures created later — confirm intent
    plt.rcParams.update({'font.size': 16})
    plt.show()

    
def plot_fit(Y_actual, Y_pred, graph_title):
    """Scatter predicted against actual state changes, one panel per component.

    Args:
        Y_actual: Actual changes in state; column ``k`` is the x-axis of panel ``k``.
        Y_pred: Predicted changes in state, same layout as ``Y_actual``.
        graph_title (str): Title placed above the 2x2 grid of panels.
    """
    fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15, 8))
    titles = ['cart_location', 'cart_velocity', 'pole_angle', 'pole_velocity']

    # ax.flat walks the grid row by row, so panel k shows state component k
    for k, (panel, title) in enumerate(zip(ax.flat, titles)):
        actual = Y_actual[:, k]
        # small dark-blue dots for the predictions
        panel.scatter(actual, Y_pred[:, k], label='pred', s=2, color='darkblue')
        # reference diagonal: a perfect prediction would lie on this line
        panel.plot(actual, actual, 'r--', label='Y = X')
        panel.set_xlabel('actual change in state')
        panel.set_ylabel(title)
        panel.grid()
        panel.legend()

    fig.suptitle(graph_title)
    # rect keeps a 3% margin free at the bottom of the figure
    fig.tight_layout(rect=[0, 0.03, 1, 1])
    # NOTE(review): the font size is changed after the figure is built, so it
    # presumably only takes effect for figures created later — confirm intent
    plt.rcParams.update({'font.size': 15})
    plt.show()
    plt.close()

def plot_actual_pred_iterations(X_actual, X_forecast, graph_title):
    """Overlay actual and forecast trajectories against the iteration index.

    Args:
        X_actual: Actual states; column ``k`` is drawn in panel ``k``.
        X_forecast: Forecast states, same column layout as ``X_actual``.
        graph_title (str): Title placed above the 2x2 grid of panels.
    """
    fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15, 6))
    titles = ['cart_location', 'cart_velocity', 'pole_angle', 'pole_velocity']

    # ax.flat walks the grid row by row, so panel k shows state component k
    for k, (panel, title) in enumerate(zip(ax.flat, titles)):
        actual = X_actual[:, k]
        forecast = X_forecast[:, k]
        panel.plot(np.arange(0, len(actual)), actual, 'r-', label='actual')
        panel.plot(np.arange(0, len(forecast)), forecast, 'b--', label='forecast')
        panel.set_xlabel('Iterations')
        panel.set_ylabel(title)
        panel.grid()
        panel.legend()

    fig.suptitle(graph_title)
    # rect keeps a 3% margin free at the bottom of the figure
    fig.tight_layout(rect=[0, 0.03, 1, 1])
    # NOTE(review): the font size is changed after the figure is built, so it
    # presumably only takes effect for figures created later — confirm intent
    plt.rcParams.update({'font.size': 16})
    plt.show()

def plot_actual_pred_time(X_actual, X_forecast, graph_title):
    """Overlay actual and forecast trajectories against time.

    The x-axis is the sample index multiplied by 0.1.

    Args:
        X_actual: Actual states; column ``k`` is drawn in panel ``k``.
        X_forecast: Forecast states, same column layout as ``X_actual``.
        graph_title (str): Title placed above the 2x2 grid of panels.
    """
    fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15, 6))
    titles = ['cart_location', 'cart_velocity', 'pole_angle', 'pole_velocity']

    # ax.flat walks the grid row by row, so panel k shows state component k
    for k, (panel, title) in enumerate(zip(ax.flat, titles)):
        actual = X_actual[:, k]
        forecast = X_forecast[:, k]
        panel.plot(0.1 * np.arange(0, len(actual)), actual, 'r-', label='actual')
        panel.plot(0.1 * np.arange(0, len(forecast)), forecast, 'b--', label='forecast')
        panel.set_xlabel('time')
        panel.set_ylabel(title)
        panel.grid()
        panel.legend()

    fig.suptitle(graph_title)
    # rect keeps a 3% margin free at the bottom of the figure
    fig.tight_layout(rect=[0, 0.03, 1, 1])
    # NOTE(review): the font size is changed after the figure is built, so it
    # presumably only takes effect for figures created later — confirm intent
    plt.rcParams.update({'font.size': 16})
    plt.show()

def forecast_nonlinear_force(initial_state, num_steps, alpha, sigma, X_prime, kernel_fn, max_force=20):
    """Compare a multi-step model forecast with the simulated trajectory.

    Args:
        initial_state: Sequence of five values: the four state components
            followed by the force, which is held constant for the whole run.
        num_steps (int): Number of forecast steps.
        alpha: Model weights; ``kernel_fn(...) @ alpha`` is the predicted
            change in state.
        sigma: Kernel length scales passed to ``kernel_fn``.
        X_prime: Basis points passed to ``kernel_fn``.
        kernel_fn: Callable ``kernel_fn(X, X_prime, sigma)`` returning the kernel matrix.
        max_force (float): Forwarded to ``rollout`` for the reference simulation.

    Returns nothing; two figures are shown (against iterations and against time).
    """
    x_state = initial_state[:-1]  # state without the trailing force
    initial_force = initial_state[-1]

    # reference trajectory from the simulator
    simulated = rollout(x_state, initial_force, num_steps, visual=False, max_force=max_force)
    X_actual = convert_dict_to_array(simulated)

    model_input = np.array(initial_state)
    trajectory = [np.array(x_state).copy()]
    for _ in range(num_steps):
        K = kernel_fn(np.expand_dims(model_input, axis=0), X_prime, sigma)
        predicted_change = K @ alpha
        next_x_state = (model_input[:-1] + predicted_change).flatten()
        # the force entry is carried over unchanged
        model_input = np.concatenate([next_x_state, model_input[-1:]])

        # the angle is wrapped to [-pi, pi] PURELY FOR PLOTTING;
        # the forecast itself keeps iterating on the unwrapped state
        plotted = next_x_state.copy()
        plotted[2] = remap_angle(plotted[2])
        trajectory.append(plotted)

    X_forecast = np.array(trajectory)

    title = f"Forecast for initial state: {initial_state[0]:.2f}, {initial_state[1]:.2f}, {initial_state[2]:.2f}, {initial_state[3]:.2f} with force {initial_state[4]:.2f}"
    plot_actual_pred_iterations(X_actual, X_forecast, graph_title=title)
    plot_actual_pred_time(X_actual, X_forecast, graph_title=title)


# training functions --------------------------------------------------------
def generate_data_random_force(num_steps, max_force=20):
    """Sample one-step transitions from random states under random forces.

    Each sample draws a state uniformly (location and velocity in [-10, 10],
    angle in [-pi, pi], angular velocity in [-15, 15]) and a force uniformly
    in [-max_force, max_force], then advances the environment by one action.

    Args:
        num_steps (int): Number of samples to generate.
        max_force (float): Forwarded to ``CartPole``.

    Returns:
        (X, Y): ``X`` rows are [cart_location, cart_velocity, pole_angle,
        pole_velocity, force]; ``Y`` rows are the change in the four state
        components (next state minus sampled state).
    """
    env = CartPole(visual=False, max_force=max_force)
    env.reset()

    state_keys = ('cart_location', 'cart_velocity', 'pole_angle', 'pole_velocity')
    x_data = {key: [] for key in state_keys}
    x_data['force'] = []
    y_data = {key: [] for key in state_keys}

    for _ in range(num_steps):
        start = [np.random.uniform(-10, 10), np.random.uniform(-10, 10),
                 np.random.uniform(-np.pi, np.pi), np.random.uniform(-15, 15)]
        push = np.random.uniform(-env.max_force, env.max_force)

        env.reset()
        env.setState(start)
        env.performAction(push)
        # the angle is intentionally left un-remapped here (no env.remap_angle() call)
        end = env.getState()

        for idx, key in enumerate(state_keys):
            x_data[key].append(start[idx])
            y_data[key].append(end[idx] - start[idx])
        x_data['force'].append(push)

    X = convert_dict_to_array(x_data)
    Y = convert_dict_to_array(y_data)
    return X, Y


def generate_data_random_force_observed_noise(num_steps, max_force, y_std, noise_factor):
    """Sample one-step transitions and corrupt the observed targets with noise.

    Sampling matches the noise-free generator: a uniformly drawn state and
    force, advanced by one action. Zero-mean Gaussian noise with standard
    deviation ``noise_factor * y_std`` is then added to the four target values.

    Args:
        num_steps (int): Number of samples to generate.
        max_force (float): Forwarded to ``CartPole``.
        y_std: Scale of each of the four targets; multiplied by ``noise_factor``.
        noise_factor (float): Noise level relative to ``y_std``.

    Returns:
        (X, Y): ``X`` rows are [cart_location, cart_velocity, pole_angle,
        pole_velocity, force]; ``Y`` rows are the noisy changes in state.
    """
    env = CartPole(visual=False, max_force=max_force)
    env.reset()

    state_keys = ('cart_location', 'cart_velocity', 'pole_angle', 'pole_velocity')
    x_data = {key: [] for key in state_keys}
    x_data['force'] = []
    y_data = {key: [] for key in state_keys}

    # the noise scale does not change between samples
    noise_std = noise_factor * y_std

    for _ in range(num_steps):
        start = [np.random.uniform(-10, 10), np.random.uniform(-10, 10),
                 np.random.uniform(-np.pi, np.pi), np.random.uniform(-15, 15)]
        push = np.random.uniform(-env.max_force, env.max_force)

        env.reset()
        env.setState(start)
        env.performAction(push)
        # the angle is intentionally left un-remapped here (no env.remap_angle() call)
        end = env.getState()

        # observation noise on y, drawn after the simulator step
        noise = np.random.normal(0, noise_std, size=(4,))

        for idx, key in enumerate(state_keys):
            x_data[key].append(start[idx])
            y_data[key].append((end[idx] - start[idx]) + noise[idx])
        x_data['force'].append(push)

    X = convert_dict_to_array(x_data)
    Y = convert_dict_to_array(y_data)
    return X, Y

def tikhonov_solve(K, regularisation_matrix, Y,lamb):
    """
    Solve the Tikhonov-regularised least-squares problem for the weights.

    The weights ``alpha`` are the least-squares solution of
    ``(K.T @ K + lamb * regularisation_matrix) @ alpha = K.T @ Y``.

    Args:
        K (numpy.ndarray): The kernel matrix. (N, M)
        regularisation_matrix (numpy.ndarray): The regularization matrix. (M, M)
        Y (numpy.ndarray): The output data. (N, D)
        lamb (float): The regularization parameter.

    Returns:
        numpy.ndarray: The weights of the model. (M, D)
    """
    lhs = K.T @ K + lamb * regularisation_matrix  # (M, M)
    rhs = K.T @ Y  # (M, D)

    # lstsq rather than a direct solve, so a singular system still yields an answer
    alpha, *_ = np.linalg.lstsq(lhs, rhs, rcond=None)
    return alpha

def train_nonlinear_models(X, Y, M, lamb, sigma, kernel_fn):
    """Fit the kernel model using the first ``M`` samples of ``X`` as basis points.

    Args:
        X: Training inputs, one sample per row.
        Y: Training targets, one sample per row.
        M (int): Number of basis points; they are ``X[:M]``, not a random subset.
        lamb (float): Regularisation strength passed to ``tikhonov_solve``.
        sigma: Kernel length scales passed to ``kernel_fn``.
        kernel_fn: Callable ``kernel_fn(A, B, sigma)`` returning the kernel matrix.

    Returns:
        (alpha, X_prime, K): the fitted weights, the basis points and the
        kernel matrix between ``X`` and the basis points.
    """
    basis = X[:M]  # (M, D)

    K = kernel_fn(X, basis, sigma)  # (N, M)
    # the kernel among the basis points acts as the regularisation matrix
    basis_gram = kernel_fn(basis, basis, sigma)  # (M, M)

    weights = tikhonov_solve(K, basis_gram, Y, lamb)
    return weights, basis, K

def kernel_expanded(X, X_prime, sigma):
    """Squared-exponential kernel with the angle column replaced by its sine and cosine.

    Column 2 of both inputs is treated as an angle and expanded into
    ``sin`` and ``cos`` features, so a D-column input becomes D+1 features.
    ``sigma`` holds the length scale(s) applied to those features.

    Returns:
        numpy.ndarray: Kernel matrix of shape (N, M) for ``X`` with N rows
        and ``X_prime`` with M rows.
    """
    def expand_angle(points):
        # [cols 0-1, sin(angle), cos(angle), cols 3...]
        angle = points[:, 2:3]
        return np.hstack((points[:, 0:2], np.sin(angle), np.cos(angle), points[:, 3:]))

    features = expand_angle(X)  # (N, D+1)
    basis_features = expand_angle(X_prime)  # (M, D+1)

    # broadcast to every (sample, basis point) pair
    diff = features[:, np.newaxis, :] - basis_features[np.newaxis, :, :]  # (N, M, D+1)
    exponent = np.sum((diff ** 2)/(2 * sigma ** 2), axis=-1)  # (N, M)
    return np.exp(-exponent)


# jax helper functions ---------------------------------------------------

def train_nonlinear_models_j(X, Y, M, lamb, sigma, kernel_fn):
    """Fit the kernel model via ``tikhonov_solve_j``, using ``X[:M]`` as basis points.

    Args:
        X: Training inputs, one sample per row.
        Y: Training targets, one sample per row.
        M (int): Number of basis points, taken from the start of ``X``.
        lamb: Regularisation strength.
        sigma: Kernel length scales passed to ``kernel_fn``.
        kernel_fn: Callable ``kernel_fn(A, B, sigma)`` returning the kernel matrix.

    Returns:
        (alpha, X_prime, K): the fitted weights, the basis points and the
        kernel matrix between ``X`` and the basis points.
    """
    basis = X[:M]  # (M, D)

    K = kernel_fn(X, basis, sigma)  # (N, M)
    # the kernel among the basis points acts as the regularisation matrix
    basis_gram = kernel_fn(basis, basis, sigma)  # (M, M)

    weights = tikhonov_solve_j(K, basis_gram, Y, lamb)
    return weights, basis, K



def tikhonov_solve_j(K, regularisation_matrix, Y,lamb):
    """
    Solve the Tikhonov-regularised least-squares problem with ``jax.numpy``.

    The weights ``alpha`` are the least-squares solution of
    ``(K.T @ K + lamb * regularisation_matrix) @ alpha = K.T @ Y``.

    Args:
        K: The kernel matrix. (N, M)
        regularisation_matrix: The regularization matrix. (M, M)
        Y: The output data. (N, D)
        lamb: The regularization parameter.

    Returns:
        The weights of the model. (M, D)
    """
    lhs = K.T @ K + lamb * regularisation_matrix  # (M, M)
    rhs = K.T @ Y  # (M, D)

    alpha, *_ = jnp.linalg.lstsq(lhs, rhs, rcond=None, numpy_resid=True)
    return alpha

@jax.jit
def kernel_expanded_j(X, X_prime, sigma):
    """JAX squared-exponential kernel with the angle column replaced by sine and cosine.

    Column 2 of both inputs is treated as an angle and expanded into
    ``sin`` and ``cos`` features, so a D-column input becomes D+1 features.
    ``sigma`` holds the length scale(s) applied to those features.

    Returns:
        Kernel matrix of shape (N, M) for ``X`` with N rows and ``X_prime``
        with M rows.
    """
    def expand_angle(points):
        # [cols 0-1, sin(angle), cos(angle), cols 3...]
        angle = points[:, 2:3]
        return jnp.hstack((points[:, 0:2], jnp.sin(angle), jnp.cos(angle), points[:, 3:]))

    features = expand_angle(X)  # (N, D+1)
    basis_features = expand_angle(X_prime)  # (M, D+1)

    # broadcast to every (sample, basis point) pair
    diff = features[:, jnp.newaxis, :] - basis_features[jnp.newaxis, :, :]  # (N, M, D+1)
    exponent = jnp.sum((diff ** 2)/(2 * sigma ** 2), axis=-1)  # (N, M)
    return jnp.exp(-exponent)


In [None]:
# Noise-free reference data: its per-column spread sets the scale of the
# observation noise used in the following cells.
N_train = 4096
N_test = 2048
X_no_noise, Y_no_noise = generate_data_random_force(num_steps=N_train + N_test)
Y_std = np.std(Y_no_noise, axis=0)
print("Standard deviation of Y:", Y_std)

In [None]:
# Sine/cosine kernel model: data for the hyper-parameter search

N_train, N_test, M = 4096, 2048, 1024
max_force = 15

# one noisy data set, split into a training head and a validation tail
X, Y = generate_data_random_force_observed_noise(
    num_steps=N_train + N_test,
    max_force=max_force,
    y_std=Y_std,
    noise_factor=0.1,
)
X, Y = jnp.array(X), jnp.array(Y)
X_prime = X[:M]  # first M samples

X_train, Y_train = X[:N_train], Y[:N_train]
X_val, Y_val = X[-N_test:], Y[-N_test:]

def loss(parameters):
    """Validation mean-squared error of the kernel model for one hyper-parameter vector.

    ``parameters[0]`` is the regularisation strength and ``parameters[1:]``
    are the kernel length scales. The model is fitted on the module-level
    ``X_train``/``Y_train`` with ``M`` basis points and scored on
    ``X_val``/``Y_val``.
    """
    lamb, sigma = parameters[0], parameters[1:]

    # fit on the training split
    weights, basis, _ = train_nonlinear_models_j(X_train, Y_train, M=M, lamb=lamb, sigma=sigma, kernel_fn=kernel_expanded_j)

    # score on the validation split
    residual = Y_val - kernel_expanded_j(X_val, basis, sigma) @ weights
    return jnp.mean(residual ** 2)

# gradient of the validation loss with respect to [lambda, sigma...] via jax.grad
grad_loss = jax.grad(loss)

initial_lamb = 1E-4
std_force = max_force / (3**0.5)  # std of a force drawn uniformly from [-max_force, max_force]
# per-column std of X: [cart_location, cart_velocity, pole_angle, pole_velocity, force]
x_sigma = get_std(X)
std_sine, std_cos = (0.125)**0.5, (0.125)**0.5  # initial length scales for the sine and cosine features

# Length scales follow the expanded feature order built by kernel_expanded_j:
# [cart_location, cart_velocity, sin(angle), cos(angle), pole_velocity, force].
# The pole-velocity scale must come from column 3 of X; x_sigma[-1] is the
# force column and was previously used here by mistake.
initial_sigma = jnp.array([x_sigma[0], x_sigma[1], std_sine, std_cos, x_sigma[3], std_force])
initial_hyperparameters = jnp.array([initial_lamb] + initial_sigma.tolist())

losses = [loss(initial_hyperparameters)]

def callback(intermediate_result):
    """Print the optimiser's intermediate result and record its loss in ``losses``."""
    print(intermediate_result)
    losses.append(intermediate_result.fun)


# The kernel divides by 2 * sigma ** 2, so a length scale of exactly 0 gives a
# NaN loss; keep every sigma strictly positive.
min_sigma = 1E-3
# bounds for lamb followed by the six length scales
bounds = ([(1E-6, 1E-1)] + [(min_sigma, 30)] + [(min_sigma, 40)] + [(min_sigma, 1)] * 2
          + [(min_sigma, 10)] + [(min_sigma, max_force*3)])
res = scipy.optimize.minimize(loss, x0=initial_hyperparameters, method='L-BFGS-B', jac=grad_loss, bounds=bounds, callback=callback)

In [None]:
# summary of the hyper-parameter search
print("optimal lambda:", res.x[0])
print("optimal sigma:", res.x[1:])
print("initial loss", losses[0])
print("Final loss:", res.fun)
# losses[0] is the starting loss, so it is not counted as an iteration
print("number of iterations:", len(losses) - 1)

def plot_loss(losses):
    """Plot the recorded loss values against their index in ``losses``."""
    plt.figure(figsize=(10, 6))
    plt.rcParams.update({'font.size': 16})
    axes = plt.gca()
    axes.plot(losses, label='Loss')
    axes.set_xlabel('Iteration')
    axes.set_ylabel('Loss')
    axes.set_title('Loss over iterations')
    axes.grid()
    axes.legend()
    plt.show()

plot_loss(losses)

In [None]:
# Refit the sine/cosine model on fresh noisy data with the optimised hyper-parameters
N, M = 4096, 1024
lamb = res.x[0]
# previously recorded values for lamb: 4.543e-03 (mine), 4.918e-04 (andrew)
max_force = 15

# generate training data
X, Y = generate_data_random_force_observed_noise(num_steps=N, max_force=max_force, y_std=Y_std, noise_factor=0.1)

# kernel length scales found by the optimiser
sigma = res.x[1:]
# previously recorded values for sigma:
#   mine:   [1.000e+01, 1.000e+01, 9.916e-01, 6.063e-01, 7.005e+00, 2.000e+01]
#   andrew: [15.41, 1.413e+01, 5.24, 0.97, 7.356, 13.52]

# train model
alpha, X_prime, K = train_nonlinear_models(X, Y, M=M, lamb=lamb, sigma=sigma, kernel_fn=kernel_expanded)

# in-sample predictions of the change in state
Y_pred = K @ alpha
plot_fit(Y, Y_pred, graph_title="Change in state")

# forecast test cases: [cart_location, cart_velocity, pole_angle, pole_velocity, force]
initial_states = [
    [0, -2, np.pi, 4, 1],
    [0, 0, np.pi, 5, 2],
    [0, 0, np.pi, 0, 10],
    [0, 0, 0.1, 0, 8],
]
for initial_state in initial_states:
    forecast_nonlinear_force(initial_state, num_steps=100, alpha=alpha, sigma=sigma, X_prime=X_prime, kernel_fn=kernel_expanded)

# everything the policy cells need to evaluate the fitted model
non_linear_model_sin_force = {
    'lambda': res.x[0],
    'sigma': res.x[1:],
    'alpha': alpha,
    'X_prime': X_prime,
}

In [None]:
# Policy training set-up (start close to pole_angle = 0)
num_steps = 30
max_force = 8

initial_state = jnp.array([0, 0, 0.1, 0])
target = jnp.array([0, 0, 0, 0])

# per-component widths used by the policy loss
sigma = jnp.array([10, 10, 7, 10])
# alternative widths tried: [5.8, 5.8, 1.5, 8.5]

initial_p = jnp.array([1, 1, 1, 1])  # starting guess for the linear policy weights

# fitted dynamics model, converted to jax arrays for the rollout
model_x_prime, model_sigma, model_alpha = (
    jnp.array(non_linear_model_sin_force[field])
    for field in ('X_prime', 'sigma', 'alpha')
)

@jax.jit
def kernel_expanded_jax(X, X_prime, sigma):
    """JAX squared-exponential kernel with the angle column replaced by sine and cosine.

    Column 2 of both inputs is treated as an angle and expanded into
    ``sin`` and ``cos`` features, so a D-column input becomes D+1 features.
    ``sigma`` holds the length scale(s) applied to those features.

    Returns:
        Kernel matrix of shape (N, M) for ``X`` with N rows and ``X_prime``
        with M rows.
    """
    def expand_angle(points):
        # [cols 0-1, sin(angle), cos(angle), cols 3...]
        angle = points[:, 2:3]
        return jnp.hstack((points[:, 0:2], jnp.sin(angle), jnp.cos(angle), points[:, 3:]))

    features = expand_angle(X)  # (N, D+1)
    basis_features = expand_angle(X_prime)  # (M, D+1)

    # broadcast to every (sample, basis point) pair
    diff = features[:, jnp.newaxis, :] - basis_features[jnp.newaxis, :, :]  # (N, M, D+1)
    exponent = jnp.sum((diff ** 2)/(2 * sigma ** 2), axis=-1)  # (N, M)
    return jnp.exp(-exponent)

@jax.jit
def loss_policy_jax(state, target, sigma):
    """Return ``1 - exp(-0.5 * |(state - target) / sigma|^2)``.

    The value is 0 when ``state`` equals ``target`` and approaches 1 as the
    sigma-scaled distance grows.
    """
    scaled_error = (state - target) / sigma
    half_sq_dist = 0.5 * jnp.dot(scaled_error, scaled_error)
    return 1 - jnp.exp(-half_sq_dist)

@jax.jit
def loss_rollout_linear_jax(P):
    """Total policy loss of a model-based rollout under the linear policy ``P``.

    Starting from the module-level ``initial_state``, each of the ``num_steps``
    steps applies the force ``max_force * tanh(P @ state / max_force)``,
    advances the state with the fitted kernel model and scores the new state
    with ``loss_policy_jax``. The loss of the initial state is included.

    NOTE(review): ``initial_state``, ``target``, ``sigma``, ``num_steps``,
    ``max_force`` and the model arrays are read from module scope inside a
    jitted function; they are presumably fixed at trace time — confirm before
    changing them without re-defining this function.
    """
    def advance(state, _):
        # saturated linear feedback: tanh keeps the force within +/- max_force
        force = max_force * jnp.tanh((P @ state) / max_force)

        # model input is the state with the force appended as last element
        model_input = jnp.concatenate([state, jnp.array([force])])
        K = kernel_expanded_jax(jnp.expand_dims(model_input, axis=0), model_x_prime, model_sigma)
        next_state = jnp.ravel(model_input[:-1] + K @ model_alpha)

        return next_state, loss_policy_jax(next_state, target, sigma)

    _, step_losses = jax.lax.scan(advance, initial_state, None, length=num_steps)
    return loss_policy_jax(initial_state, target, sigma) + step_losses.sum()

# gradient of the rollout loss with respect to the policy weights
grad_loss_linear_jax = jax.grad(loss_rollout_linear_jax)

# losses[0] is the loss of the starting guess
losses = [loss_rollout_linear_jax(initial_p)]
print("Initial loss:", losses[0])

def callback(intermediate_result):
    """Print the optimiser's current weights and loss, then record the loss."""
    iteration = len(losses)
    print("Iteration:", iteration)
    print("P:", intermediate_result.x)
    print("Loss:", intermediate_result.fun)
    print()
    losses.append(intermediate_result.fun)

# every policy weight is confined to [-30, 30]
res = scipy.optimize.minimize(
    loss_rollout_linear_jax,
    x0=initial_p,
    method='L-BFGS-B',
    jac=grad_loss_linear_jax,
    callback=callback,
    bounds=[(-30, 30) for _ in range(4)],
)

def rollout_linear_force(initial_state, num_steps, P, max_force):
    """Roll the fitted model forward under the linear policy ``P``.

    The dynamics come from the module-level ``model_x_prime``, ``model_sigma``
    and ``model_alpha``.

    Args:
        initial_state: Starting state (four components).
        num_steps (int): Number of model steps.
        P: Policy weights; the raw force is ``P @ state``.
        max_force (float): Saturation level applied through ``tanh``.

    Returns:
        numpy.ndarray: ``num_steps + 1`` states. The first is the initial state
        as given; every later one has its angle passed through ``remap_angle2``.
    """
    trajectory = [initial_state.copy()]

    state = initial_state.copy()
    for _ in range(num_steps):
        force = max_force * np.tanh((P @ state) / max_force)

        # model input is the state with the force appended as last element
        model_input = jnp.concatenate([state, jnp.array([force])])
        K = kernel_expanded_jax(jnp.expand_dims(model_input, axis=0), model_x_prime, model_sigma)
        state = jnp.ravel(model_input[:-1] + K @ model_alpha)

        # the angle is remapped PURELY FOR PLOTTING; the rollout itself
        # continues from the un-remapped state
        trajectory.append(np.array([state[0], state[1], remap_angle2(state[2]), state[3]]))

    return np.array(trajectory)

# roll out the optimised policy weights on the model and plot the result
P = res.x
X = rollout_linear_force(initial_state, 50, P, max_force)
plot_policy(X, target, f"Rollout Policy, initial_state: {initial_state}")


---
# upside down

In [None]:
# Noise-free reference data: its per-column spread sets the scale of the
# observation noise used in the following cells.
N_train = 4096
N_test = 2048
X_no_noise, Y_no_noise = generate_data_random_force(num_steps=N_train + N_test)
Y_std = np.std(Y_no_noise, axis=0)
print("Standard deviation of Y:", Y_std)

In [None]:
# Sine/cosine kernel model: data for the hyper-parameter search

N_train, N_test, M = 4096, 2048, 1024
max_force = 15

# one noisy data set, split into a training head and a validation tail
X, Y = generate_data_random_force_observed_noise(
    num_steps=N_train + N_test,
    max_force=max_force,
    y_std=Y_std,
    noise_factor=0.1,
)
X, Y = jnp.array(X), jnp.array(Y)
X_prime = X[:M]  # first M samples

X_train, Y_train = X[:N_train], Y[:N_train]
X_val, Y_val = X[-N_test:], Y[-N_test:]

def loss(parameters):
    """Validation mean-squared error of the kernel model for one hyper-parameter vector.

    ``parameters[0]`` is the regularisation strength and ``parameters[1:]``
    are the kernel length scales. The model is fitted on the module-level
    ``X_train``/``Y_train`` with ``M`` basis points and scored on
    ``X_val``/``Y_val``.
    """
    lamb, sigma = parameters[0], parameters[1:]

    # fit on the training split
    weights, basis, _ = train_nonlinear_models_j(X_train, Y_train, M=M, lamb=lamb, sigma=sigma, kernel_fn=kernel_expanded_j)

    # score on the validation split
    residual = Y_val - kernel_expanded_j(X_val, basis, sigma) @ weights
    return jnp.mean(residual ** 2)

# gradient of the validation loss with respect to [lambda, sigma...] via jax.grad
grad_loss = jax.grad(loss)

initial_lamb = 1E-4
std_force = max_force / (3**0.5)  # std of a force drawn uniformly from [-max_force, max_force]
# per-column std of X: [cart_location, cart_velocity, pole_angle, pole_velocity, force]
x_sigma = get_std(X)
std_sine, std_cos = (0.125)**0.5, (0.125)**0.5  # initial length scales for the sine and cosine features

# Length scales follow the expanded feature order built by kernel_expanded_j:
# [cart_location, cart_velocity, sin(angle), cos(angle), pole_velocity, force].
# The pole-velocity scale must come from column 3 of X; x_sigma[-1] is the
# force column and was previously used here by mistake.
initial_sigma = jnp.array([x_sigma[0], x_sigma[1], std_sine, std_cos, x_sigma[3], std_force])
initial_hyperparameters = jnp.array([initial_lamb] + initial_sigma.tolist())

losses = [loss(initial_hyperparameters)]

def callback(intermediate_result):
    """Print the optimiser's intermediate result and record its loss in ``losses``."""
    print(intermediate_result)
    losses.append(intermediate_result.fun)


# The kernel divides by 2 * sigma ** 2, so a length scale of exactly 0 gives a
# NaN loss; keep every sigma strictly positive.
min_sigma = 1E-3
# bounds for lamb followed by the six length scales
bounds = ([(1E-6, 1E-1)] + [(min_sigma, 30)] + [(min_sigma, 40)] + [(min_sigma, 1)] * 2
          + [(min_sigma, 10)] + [(min_sigma, max_force*3)])
res = scipy.optimize.minimize(loss, x0=initial_hyperparameters, method='L-BFGS-B', jac=grad_loss, bounds=bounds, callback=callback)

In [None]:
# Refit the sine/cosine model on fresh noisy data with the optimised hyper-parameters
N, M = 4096, 1024
lamb = res.x[0]
# previously recorded values for lamb: 4.543e-03 (mine), 4.918e-04 (andrew)
max_force = 15

# generate training data
X, Y = generate_data_random_force_observed_noise(num_steps=N, max_force=max_force, y_std=Y_std, noise_factor=0.1)

# kernel length scales found by the optimiser
sigma = res.x[1:]
# previously recorded values for sigma:
#   mine:   [1.000e+01, 1.000e+01, 9.916e-01, 6.063e-01, 7.005e+00, 2.000e+01]
#   andrew: [15.41, 1.413e+01, 5.24, 0.97, 7.356, 13.52]

# train model
alpha, X_prime, K = train_nonlinear_models(X, Y, M=M, lamb=lamb, sigma=sigma, kernel_fn=kernel_expanded)

# in-sample predictions of the change in state
Y_pred = K @ alpha
plot_fit(Y, Y_pred, graph_title="Change in state")

# forecast test cases: [cart_location, cart_velocity, pole_angle, pole_velocity, force]
initial_states = [
    [0, -2, np.pi, 4, 1],
    [0, 0, np.pi, 5, 2],
    [0, 0, np.pi, 0, 10],
    [0, 0, 0.1, 0, 8],
]
for initial_state in initial_states:
    forecast_nonlinear_force(initial_state, num_steps=100, alpha=alpha, sigma=sigma, X_prime=X_prime, kernel_fn=kernel_expanded)

# everything the policy cells need to evaluate the fitted model
non_linear_model_sin_force_down = {
    'lambda': res.x[0],
    'sigma': res.x[1:],
    'alpha': alpha,
    'X_prime': X_prime,
}

In [None]:
# Policy training set-up (start at pole_angle = pi, target pole_angle = 0)
num_steps = 50
max_force = 10

initial_state = jnp.array([0, 0, np.pi, 0])
target = jnp.array([0, 0, 0, 0])

# per-component widths used by the policy loss
sigma = jnp.array([10, 8, 4, 10])
# alternative widths tried: [5.8, 5.8, 1.5, 8.5]

initial_p = jnp.array([5, 5, 5, 5])  # starting guess for the linear policy weights

# fitted dynamics model, converted to jax arrays for the rollout
model_x_prime, model_sigma, model_alpha = (
    jnp.array(non_linear_model_sin_force_down[field])
    for field in ('X_prime', 'sigma', 'alpha')
)

@jax.jit
def kernel_expanded_jax(X, X_prime, sigma):
    """JAX squared-exponential kernel with the angle column replaced by sine and cosine.

    Column 2 of both inputs is treated as an angle and expanded into
    ``sin`` and ``cos`` features, so a D-column input becomes D+1 features.
    ``sigma`` holds the length scale(s) applied to those features.

    Returns:
        Kernel matrix of shape (N, M) for ``X`` with N rows and ``X_prime``
        with M rows.
    """
    def expand_angle(points):
        # [cols 0-1, sin(angle), cos(angle), cols 3...]
        angle = points[:, 2:3]
        return jnp.hstack((points[:, 0:2], jnp.sin(angle), jnp.cos(angle), points[:, 3:]))

    features = expand_angle(X)  # (N, D+1)
    basis_features = expand_angle(X_prime)  # (M, D+1)

    # broadcast to every (sample, basis point) pair
    diff = features[:, jnp.newaxis, :] - basis_features[jnp.newaxis, :, :]  # (N, M, D+1)
    exponent = jnp.sum((diff ** 2)/(2 * sigma ** 2), axis=-1)  # (N, M)
    return jnp.exp(-exponent)

@jax.jit
def loss_policy_jax(state, target, sigma):
    """Return ``1 - exp(-0.5 * |(state - target) / sigma|^2)``.

    The value is 0 when ``state`` equals ``target`` and approaches 1 as the
    sigma-scaled distance grows.
    """
    scaled_error = (state - target) / sigma
    half_sq_dist = 0.5 * jnp.dot(scaled_error, scaled_error)
    return 1 - jnp.exp(-half_sq_dist)

@jax.jit
def loss_rollout_linear_jax(P):
    """Total policy loss of a model-based rollout under the linear policy ``P``.

    Starting from the module-level ``initial_state``, each of the ``num_steps``
    steps applies the force ``max_force * tanh(P @ state / max_force)``,
    advances the state with the fitted kernel model and scores the new state
    with ``loss_policy_jax``. The loss of the initial state is included.

    NOTE(review): ``initial_state``, ``target``, ``sigma``, ``num_steps``,
    ``max_force`` and the model arrays are read from module scope inside a
    jitted function; they are presumably fixed at trace time — confirm before
    changing them without re-defining this function.
    """
    def advance(state, _):
        # saturated linear feedback: tanh keeps the force within +/- max_force
        force = max_force * jnp.tanh((P @ state) / max_force)

        # model input is the state with the force appended as last element
        model_input = jnp.concatenate([state, jnp.array([force])])
        K = kernel_expanded_jax(jnp.expand_dims(model_input, axis=0), model_x_prime, model_sigma)
        next_state = jnp.ravel(model_input[:-1] + K @ model_alpha)

        return next_state, loss_policy_jax(next_state, target, sigma)

    _, step_losses = jax.lax.scan(advance, initial_state, None, length=num_steps)
    return loss_policy_jax(initial_state, target, sigma) + step_losses.sum()

# gradient of the rollout loss with respect to the policy weights
grad_loss_linear_jax = jax.grad(loss_rollout_linear_jax)

# losses[0] is the loss of the starting guess
losses = [loss_rollout_linear_jax(initial_p)]
print("Initial loss:", losses[0])

def callback(intermediate_result):
    """Print the optimiser's current weights and loss, then record the loss."""
    iteration = len(losses)
    print("Iteration:", iteration)
    print("P:", intermediate_result.x)
    print("Loss:", intermediate_result.fun)
    print()
    losses.append(intermediate_result.fun)

# the pole-angle weight (index 2) gets a wider range than the other three
res = scipy.optimize.minimize(
    loss_rollout_linear_jax,
    x0=initial_p,
    method='L-BFGS-B',
    jac=grad_loss_linear_jax,
    callback=callback,
    bounds=[(-20, 20), (-20, 20), (-30, 30), (-20, 20)],
)

def rollout_linear_force(initial_state, num_steps, P, max_force):
    """Roll the fitted model forward under the linear policy ``P``.

    The dynamics come from the module-level ``model_x_prime``, ``model_sigma``
    and ``model_alpha``.

    Args:
        initial_state: Starting state (four components).
        num_steps (int): Number of model steps.
        P: Policy weights; the raw force is ``P @ state``.
        max_force (float): Saturation level applied through ``tanh``.

    Returns:
        numpy.ndarray: ``num_steps + 1`` states. The first is the initial state
        as given; every later one has its angle passed through ``remap_angle2``.
    """
    trajectory = [initial_state.copy()]

    state = initial_state.copy()
    for _ in range(num_steps):
        force = max_force * np.tanh((P @ state) / max_force)

        # model input is the state with the force appended as last element
        model_input = jnp.concatenate([state, jnp.array([force])])
        K = kernel_expanded_jax(jnp.expand_dims(model_input, axis=0), model_x_prime, model_sigma)
        state = jnp.ravel(model_input[:-1] + K @ model_alpha)

        # the angle is remapped PURELY FOR PLOTTING; the rollout itself
        # continues from the un-remapped state
        trajectory.append(np.array([state[0], state[1], remap_angle2(state[2]), state[3]]))

    return np.array(trajectory)

# roll out the optimised policy weights on the model and plot the result
P = res.x
X = rollout_linear_force(initial_state, 25, P, max_force)
plot_policy(X, target, f"Rollout Policy, initial_state: {initial_state}")

---
### Real dynamics noise

In [None]:
def generate_data_random_force_real_noise(num_steps, max_force, std):
    """Sample one-step transitions using the environment's noisy action.

    Each sample draws a state uniformly (location and velocity in [-10, 10],
    angle in [-pi, pi], angular velocity in [-15, 15]) and a force uniformly
    in [-max_force, max_force], then advances the environment with
    ``performAction_noise(force, std)``.

    Args:
        num_steps (int): Number of samples to generate.
        max_force (float): Forwarded to ``CartPole``.
        std: Noise level forwarded to ``performAction_noise``.

    Returns:
        (X, Y): ``X`` rows are [cart_location, cart_velocity, pole_angle,
        pole_velocity, force]; ``Y`` rows are the change in the four state
        components (next state minus sampled state).
    """
    env = CartPole(visual=False, max_force=max_force)
    env.reset()

    state_keys = ('cart_location', 'cart_velocity', 'pole_angle', 'pole_velocity')
    x_data = {key: [] for key in state_keys}
    x_data['force'] = []
    y_data = {key: [] for key in state_keys}

    for _ in range(num_steps):
        start = [np.random.uniform(-10, 10), np.random.uniform(-10, 10),
                 np.random.uniform(-np.pi, np.pi), np.random.uniform(-15, 15)]
        push = np.random.uniform(-env.max_force, env.max_force)

        env.reset()
        env.setState(start)
        env.performAction_noise(push, std)
        # the angle is intentionally left un-remapped here (no env.remap_angle() call)
        end = env.getState()

        for idx, key in enumerate(state_keys):
            x_data[key].append(start[idx])
            y_data[key].append(end[idx] - start[idx])
        x_data['force'].append(push)

    X = convert_dict_to_array(x_data)
    Y = convert_dict_to_array(y_data)
    return X, Y

In [None]:
# Sine/cosine kernel model on noisy dynamics: data for the hyper-parameter search

N_train, N_test, M = 4096, 2048, 1024
max_force = 15
state_std = 0.05  # noise level handed to the environment's noisy action

# one data set, split into a training head and a validation tail
X, Y = generate_data_random_force_real_noise(
    num_steps=N_train + N_test,
    max_force=max_force,
    std=state_std,
)
X, Y = jnp.array(X), jnp.array(Y)
X_prime = X[:M]  # first M samples

X_train, Y_train = X[:N_train], Y[:N_train]
X_val, Y_val = X[-N_test:], Y[-N_test:]

def loss(parameters):
    """
    Validation mean squared error of the model for a hyperparameter vector.

    Args:
        parameters: 1-D array; element 0 is the regulariser ``lamb`` and the
            remaining elements are the kernel length scales ``sigma``.

    Returns:
        Mean of the squared differences between ``Y_val`` and the model
        prediction on ``X_val``.
    """
    lamb, sigma = parameters[0], parameters[1:]

    # Fit on the training split; the second return value is used as the
    # kernel's second argument for prediction (presumably the basis points).
    alpha, basis, _ = train_nonlinear_models_j(
        X_train, Y_train, M=M, lamb=lamb, sigma=sigma, kernel_fn=kernel_expanded_j
    )

    # Score on the held-out validation split.
    residual = Y_val - kernel_expanded_j(X_val, basis, sigma) @ alpha
    return jnp.mean(residual ** 2)

# create a function that calculates the gradient of the loss function using jax.grad
grad_loss = jax.grad(loss)

initial_lamb = 1E-4
# standard deviation of a uniform distribution on [-max_force, max_force]
std_force = max_force / (3**0.5)
# per-column std of the raw inputs; columns of X are
# (cart location, cart velocity, pole angle, pole velocity, force)
x_sigma = get_std(X)
# initial_sigma = jnp.array([6, 6, 0.5, 0.5, 6])
std_sine, std_cos = (0.125)**0.5, (0.125)**0.5  # standard deviation for sine and cosine
# One length scale per expanded feature:
# (cart location, cart velocity, sin(angle), cos(angle), pole velocity, force).
# The pole-velocity scale comes from column 3 of X; the last column of X is
# the force, which already has its own entry (std_force).
# NOTE(review): assumes kernel_expanded_j uses this feature order - confirm.
initial_sigma = jnp.array([x_sigma[0], x_sigma[1], std_sine, std_cos, x_sigma[3], std_force])
initial_hyperparameters = jnp.array([initial_lamb] + initial_sigma.tolist())

# loss history, seeded with the loss at the starting hyperparameters
losses = [loss(initial_hyperparameters)]

def callback(intermediate_result):
    """
    Print the optimiser's intermediate result and record its loss.

    The parameter name ``intermediate_result`` is significant: scipy inspects
    the callback signature and passes an OptimizeResult (with ``x`` and
    ``fun``) only when the parameter has exactly this name.
    """
    print(intermediate_result)
    losses.append(intermediate_result.fun)



# Box constraints in the order [lamb, sigma_0, ..., sigma_5].
# The length-scale lower bounds are strictly positive rather than 0: L-BFGS-B
# may evaluate the loss exactly on a bound, and a zero length scale would
# divide by zero in a kernel of the form diff**2 / (2 * sigma**2)
# (kernel_expanded_j is presumably of that form - confirm), producing
# inf/NaN losses and gradients.
bounds = [
    (1E-6, 1E-1),           # lamb
    (1E-6, 30),             # cart location
    (1E-6, 40),             # cart velocity
    (1E-6, 1),              # sine of pole angle
    (1E-6, 1),              # cosine of pole angle
    (1E-6, 10),             # pole velocity
    (1E-6, max_force*3),    # force
]
res = scipy.optimize.minimize(loss, x0=initial_hyperparameters, method='L-BFGS-B', jac=grad_loss, bounds=bounds, callback=callback)

In [None]:
# Refit the model on freshly sampled data using the hyperparameters found by
# the optimiser (res.x = [lamb, sigma...]), then inspect fit and forecasts.
N, M = 4096, 1024
lamb = res.x[0]
# N, M, lamb = 4096, 1024, 4.543e-03 # mine
# N, M, lamb = 4096, 1024, 4.918e-04 # andrew

# kernel length scales found by the optimiser
sigma = res.x[1:]
# sigma = np.array([1.000e+01,  1.000e+01,  9.916e-01,  6.063e-01, 7.005e+00,  2.000e+01]) # mine
# sigma = np.array([15.41,  1.413e+01,  5.24,  0.97, 7.356,  13.52])    # andrew

# fresh training data
X, Y = generate_data_random_force_real_noise(num_steps=N, max_force=max_force, std=state_std)

# fit, then predict on the training inputs themselves
alpha, X_prime, K = train_nonlinear_models(
    X, Y, M=M, lamb=lamb, sigma=sigma, kernel_fn=kernel_expanded
)
Y_pred = K @ alpha

# plot_fit(X, Y, Y_pred, graph_title="Fit of the model")
plot_fit(Y, Y_pred, graph_title="Change in state")

# Example starting points for multi-step forecasts; each entry is
# (cart location, cart velocity, pole angle, pole velocity, force).
initial_states = [
    [0, -2, np.pi, 4, 1],
    [0, 0, np.pi, 5, 2],
    [0, 0, np.pi, 0, 10],
    [0, 0, 0.1, 0, 8],
]
# initial_states = [[0, 0, np.pi, 0, 15]]
for initial_state in initial_states:
    forecast_nonlinear_force(
        initial_state,
        num_steps=100,
        alpha=alpha,
        sigma=sigma,
        X_prime=X_prime,
        kernel_fn=kernel_expanded,
    )

# Bundle everything needed to reuse the fitted model in later cells.
non_linear_model_sin_force_real_noise_up = {
    'lambda': lamb,
    'sigma': sigma,
    'alpha': alpha,
    'X_prime': X_prime,
}

In [None]:
# Policy training configuration: start at pole angle 0.1, drive the state to `target`.
initial_state = jnp.array([0, 0, 0.1, 0])
target = jnp.array([0, 0, 0, 0])
sigma = jnp.array([10, 10, 10, 10])  # per-dimension scale used by the policy loss
# sigma = jnp.array([5.8, 5.8, 1.5, 8.5])  # Example sigma values
num_steps = 25
max_force = 7.5

# Fitted dynamics model (basis points, length scales, weights) as JAX arrays.
model_x_prime, model_sigma, model_alpha = (
    jnp.array(non_linear_model_sin_force_real_noise_up[key])
    for key in ('X_prime', 'sigma', 'alpha')
)

@jax.jit
def kernel_expanded_jax(X, X_prime, sigma):
    """
    Gaussian kernel between the rows of X and X_prime on angle-expanded features.

    Column 2 of both inputs is treated as an angle and replaced by its sine
    and cosine, so an input with D columns becomes D+1 features.

    Args:
        X: (N, D) array of query points.
        X_prime: (M, D) array of basis points.
        sigma: length scales, broadcast against the D+1 expanded features.

    Returns:
        (N, M) array with entries exp(-sum_d diff_d**2 / (2 * sigma_d**2)).
    """
    def expand_angle(points):
        # (rows, D) -> (rows, D+1): angle column replaced by (sin, cos)
        angle = points[:, 2:3]
        return jnp.hstack((points[:, 0:2], jnp.sin(angle), jnp.cos(angle), points[:, 3:]))

    queries = expand_angle(X)[:, None, :]        # (N, 1, D+1)
    centres = expand_angle(X_prime)[None, :, :]  # (1, M, D+1)

    per_feature = ((queries - centres) ** 2) / (2 * sigma ** 2)  # (N, M, D+1)
    return jnp.exp(-jnp.sum(per_feature, axis=-1))  # (N, M)

@jax.jit
def loss_policy_jax(state, target, sigma):
    """
    Saturating distance loss between a state and the target.

    Computes 1 - exp(-0.5 * ||(state - target) / sigma||^2), which is 0 at
    the target and approaches 1 far away from it.

    Args:
        state: 1-D array, current state.
        target: 1-D array, desired state (same length as ``state``).
        sigma: per-dimension scale dividing the state error.

    Returns:
        Scalar loss in [0, 1).
    """
    delta = (state - target) / sigma
    exponent = 0.5 * jnp.dot(delta, delta)
    # -expm1(-x) equals 1 - exp(-x) but avoids the cancellation that rounds
    # the loss to exactly 0 when the state is very close to the target.
    return -jnp.expm1(-exponent)

@jax.jit
def loss_rollout_linear_jax(P):
    """
    Total policy loss of a model rollout under the linear policy ``P``.

    Starting from the global ``initial_state``, each of ``num_steps`` steps
    computes the force ``P @ state``, squashes it into
    (-max_force, max_force) with tanh, predicts the state change with the
    fitted kernel model and adds it to the state. The returned value is the
    loss of the initial state plus the loss of every predicted state.

    NOTE: reads the module-level ``initial_state``, ``target``, ``sigma``,
    ``num_steps``, ``max_force`` and ``model_*`` values. Because the function
    is jitted, those values are captured when it is traced; later changes to
    the globals are not picked up by an already-traced function.

    Args:
        P: 1-D array of policy gains, one per state dimension.

    Returns:
        Scalar summed loss.
    """
    def advance(state, _):
        # linear policy followed by a smooth saturation at +-max_force
        raw_force = P @ state
        force = max_force * jnp.tanh(raw_force / max_force)

        # model input = state with the force appended as last element
        model_input = jnp.concatenate([state, jnp.array([force])])
        K = kernel_expanded_jax(model_input[None, :], model_x_prime, model_sigma)
        predicted_change = K @ model_alpha
        following = jnp.ravel(model_input[:-1] + predicted_change)

        return following, loss_policy_jax(following, target, sigma)

    _, step_losses = jax.lax.scan(advance, initial_state, None, length=num_steps)
    return loss_policy_jax(initial_state, target, sigma) + step_losses.sum()

# starting gains of the linear policy for the start close to the target angle
initial_p = jnp.array([1, 1, 1, 1])

grad_loss_linear_jax = jax.grad(loss_rollout_linear_jax)

# loss history, seeded with the loss of the starting gains
losses = [loss_rollout_linear_jax(initial_p)]
print("Initial loss:", losses[0])


def callback(intermediate_result):
    """
    Print the current iterate and record its loss.

    The parameter name ``intermediate_result`` makes scipy pass an
    OptimizeResult exposing ``x`` and ``fun``.
    """
    iteration = len(losses)
    print("Iteration:", iteration)
    print("P:", intermediate_result.x)
    print("Loss:", intermediate_result.fun)
    print()
    losses.append(intermediate_result.fun)


# bounded quasi-Newton search over the policy gains, each kept in [-30, 30]
res = scipy.optimize.minimize(
    loss_rollout_linear_jax,
    x0=initial_p,
    jac=grad_loss_linear_jax,
    method='L-BFGS-B',
    bounds=[(-30, 30)] * len(initial_p),
    callback=callback,
)

def rollout_linear_force(initial_state, num_steps, P, max_force):
    """
    Roll the fitted model forward under the linear policy ``P``.

    Args:
        initial_state: 1-D array, state the rollout starts from.
        num_steps (int): Number of model steps to simulate.
        P: 1-D array of policy gains; the raw force is ``P @ state``.
        max_force (float): Saturation level of the tanh-squashed force.

    Returns:
        Array with ``num_steps + 1`` rows: the initial state followed by
        every predicted state. In the predicted rows the angle (index 2) is
        passed through ``remap_angle2``; the initial row is stored as given.
    """
    trajectory = [initial_state.copy()]
    state = initial_state.copy()

    for _ in range(num_steps):
        force = max_force * np.tanh((P @ state) / max_force)

        # model input = state with the force appended as last element
        model_input = jnp.concatenate([state, jnp.array([force])])
        K = kernel_expanded_jax(jnp.expand_dims(model_input, axis=0), model_x_prime, model_sigma)
        state = jnp.ravel(model_input[:-1] + K @ model_alpha)

        # angle is remapped PURELY FOR PLOTTING; `state` itself keeps the
        # unwrapped angle for the next model step
        trajectory.append(np.array([state[0], state[1], remap_angle2(state[2]), state[3]]))

    return np.array(trajectory)

# optimised policy gains found by the optimiser
P = res.x
# simulate the learned policy on the model for 50 steps and plot it against the target
X = rollout_linear_force(
    initial_state=initial_state,
    num_steps=50,
    P=P,
    max_force=max_force,
)
plot_policy(X, target, f"Rollout Policy, initial_state: {initial_state}")

In [None]:
# Policy training configuration: start at pole angle pi, drive the state to `target`.
initial_state = jnp.array([0, 0, np.pi, 0])
target = jnp.array([0, 0, 0, 0])
sigma = jnp.array([10, 10, 4, 10])  # per-dimension scale used by the policy loss
num_steps = 50
max_force = 10
initial_p = jnp.array([5, 5, 5, 5])  # starting gains of the linear policy for the start at pole angle pi

# Fitted dynamics model (basis points, length scales, weights) as JAX arrays.
model_x_prime, model_sigma, model_alpha = (
    jnp.array(non_linear_model_sin_force_real_noise_up[key])
    for key in ('X_prime', 'sigma', 'alpha')
)

@jax.jit
def kernel_expanded_jax(X, X_prime, sigma):
    """
    Gaussian kernel between the rows of X and X_prime on angle-expanded features.

    Column 2 of both inputs is treated as an angle and replaced by its sine
    and cosine, so an input with D columns becomes D+1 features.

    Args:
        X: (N, D) array of query points.
        X_prime: (M, D) array of basis points.
        sigma: length scales, broadcast against the D+1 expanded features.

    Returns:
        (N, M) array with entries exp(-sum_d diff_d**2 / (2 * sigma_d**2)).
    """
    def expand_angle(points):
        # (rows, D) -> (rows, D+1): angle column replaced by (sin, cos)
        angle = points[:, 2:3]
        return jnp.hstack((points[:, 0:2], jnp.sin(angle), jnp.cos(angle), points[:, 3:]))

    queries = expand_angle(X)[:, None, :]        # (N, 1, D+1)
    centres = expand_angle(X_prime)[None, :, :]  # (1, M, D+1)

    per_feature = ((queries - centres) ** 2) / (2 * sigma ** 2)  # (N, M, D+1)
    return jnp.exp(-jnp.sum(per_feature, axis=-1))  # (N, M)

@jax.jit
def loss_policy_jax(state, target, sigma):
    """
    Saturating distance loss between a state and the target.

    Computes 1 - exp(-0.5 * ||(state - target) / sigma||^2), which is 0 at
    the target and approaches 1 far away from it.

    Args:
        state: 1-D array, current state.
        target: 1-D array, desired state (same length as ``state``).
        sigma: per-dimension scale dividing the state error.

    Returns:
        Scalar loss in [0, 1).
    """
    delta = (state - target) / sigma
    exponent = 0.5 * jnp.dot(delta, delta)
    # -expm1(-x) equals 1 - exp(-x) but avoids the cancellation that rounds
    # the loss to exactly 0 when the state is very close to the target.
    return -jnp.expm1(-exponent)

@jax.jit
def loss_rollout_linear_jax(P):
    """
    Total policy loss of a model rollout under the linear policy ``P``.

    Starting from the global ``initial_state``, each of ``num_steps`` steps
    computes the force ``P @ state``, squashes it into
    (-max_force, max_force) with tanh, predicts the state change with the
    fitted kernel model and adds it to the state. The returned value is the
    loss of the initial state plus the loss of every predicted state.

    NOTE: reads the module-level ``initial_state``, ``target``, ``sigma``,
    ``num_steps``, ``max_force`` and ``model_*`` values. Because the function
    is jitted, those values are captured when it is traced; later changes to
    the globals are not picked up by an already-traced function.

    Args:
        P: 1-D array of policy gains, one per state dimension.

    Returns:
        Scalar summed loss.
    """
    def advance(state, _):
        # linear policy followed by a smooth saturation at +-max_force
        raw_force = P @ state
        force = max_force * jnp.tanh(raw_force / max_force)

        # model input = state with the force appended as last element
        model_input = jnp.concatenate([state, jnp.array([force])])
        K = kernel_expanded_jax(model_input[None, :], model_x_prime, model_sigma)
        predicted_change = K @ model_alpha
        following = jnp.ravel(model_input[:-1] + predicted_change)

        return following, loss_policy_jax(following, target, sigma)

    _, step_losses = jax.lax.scan(advance, initial_state, None, length=num_steps)
    return loss_policy_jax(initial_state, target, sigma) + step_losses.sum()

# gradient of the rollout loss with respect to the policy gains
grad_loss_linear_jax = jax.grad(loss_rollout_linear_jax)

# loss history, seeded with the loss of the starting gains
losses = [loss_rollout_linear_jax(initial_p)]
print("Initial loss:", losses[0])


def callback(intermediate_result):
    """
    Print the current iterate and record its loss.

    The parameter name ``intermediate_result`` makes scipy pass an
    OptimizeResult exposing ``x`` and ``fun``.
    """
    iteration = len(losses)
    print("Iteration:", iteration)
    print("P:", intermediate_result.x)
    print("Loss:", intermediate_result.fun)
    print()
    losses.append(intermediate_result.fun)


# bounded quasi-Newton search over the policy gains, each kept in [-30, 30]
res = scipy.optimize.minimize(
    loss_rollout_linear_jax,
    x0=initial_p,
    jac=grad_loss_linear_jax,
    method='L-BFGS-B',
    bounds=[(-30, 30)] * len(initial_p),
    callback=callback,
)

def rollout_linear_force(initial_state, num_steps, P, max_force):
    """
    Roll the fitted model forward under the linear policy ``P``.

    Args:
        initial_state: 1-D array, state the rollout starts from.
        num_steps (int): Number of model steps to simulate.
        P: 1-D array of policy gains; the raw force is ``P @ state``.
        max_force (float): Saturation level of the tanh-squashed force.

    Returns:
        Array with ``num_steps + 1`` rows: the initial state followed by
        every predicted state. In the predicted rows the angle (index 2) is
        passed through ``remap_angle2``; the initial row is stored as given.
    """
    trajectory = [initial_state.copy()]
    state = initial_state.copy()

    for _ in range(num_steps):
        force = max_force * np.tanh((P @ state) / max_force)

        # model input = state with the force appended as last element
        model_input = jnp.concatenate([state, jnp.array([force])])
        K = kernel_expanded_jax(jnp.expand_dims(model_input, axis=0), model_x_prime, model_sigma)
        state = jnp.ravel(model_input[:-1] + K @ model_alpha)

        # angle is remapped PURELY FOR PLOTTING; `state` itself keeps the
        # unwrapped angle for the next model step
        trajectory.append(np.array([state[0], state[1], remap_angle2(state[2]), state[3]]))

    return np.array(trajectory)

# optimised policy gains found by the optimiser
P = res.x
# simulate the learned policy on the model for 25 steps and plot it against the target
X = rollout_linear_force(
    initial_state=initial_state,
    num_steps=25,
    P=P,
    max_force=max_force,
)
plot_policy(X, target, f"Rollout Policy, initial_state: {initial_state}")