In [None]:
!pip install tensorflow

In [None]:
pip install --upgrade "jax[cuda]"

In [None]:
# Automatically reload edited project modules (generate_data, helper_functions, model_learning, ...)
# before each cell executes, so changes to those files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Testing polynomial trajectory generation

In [None]:
from generate_data import generate_polynomial_trajectory
import numpy as np
import matplotlib.pyplot as plt

# Sanity check: polynomial reference from the origin to (1, 1, 1) over T samples.
start = np.zeros(3)
end = np.ones(3)
T = 101
order = 2

trajectory = generate_polynomial_trajectory(start, end, T, order)

print(trajectory.shape)

# Plot the planar path: column 0 (x) against column 1 (y). This is not a time series,
# so the axes are labeled as positions.
fig, ax = plt.subplots()
ax.plot(trajectory[:, 0], trajectory[:, 1])
ax.set_title('Polynomial Reference Trajectory')
ax.set_xlabel('x position')
ax.set_ylabel('y position')
plt.show()

## Data generation

In [None]:
from generate_data import gen_uni_training_data, ILQR, unicycle, save_object
import jax.numpy as onp

# Number of trajectories to generate (the noise cell iterates over this many N-row chunks).
num_iter = 50
uni_ilqr1 = ILQR(unicycle, maxiter=1000)

# NOTE(review): the positional arguments 6 and 3 are interpreted inside generate_data — confirm their meaning.
xtraj, rtraj, rdottraj, costs = gen_uni_training_data(uni_ilqr1, num_iter, 6, 3)

# Persist states, references, reference derivatives and costs together as one pickle.
save_object([xtraj, rtraj, rdottraj, costs], 'data/uni_train-nonoise2.pkl')

## Load data

In [None]:
from helper_functions import compute_tracking_cost
from mlp_jax import MLP
from generate_data import load_object
from model_learning import TrajDataset, train_model, eval_model, numpy_collate, save_checkpoint, restore_checkpoint
import numpy as np


file_path = r"/home/anusha/Research/Layered-architecture-quadrotor-control/Simulations/data/uni_train-nonoise2.pkl"
unicycle_data = load_object(file_path)

# The pickle stores [states, references, reference derivatives, costs] (see save_object in the
# data-generation cell); stack the first three row-wise into single 2-D arrays.
actual_traj, ref_traj, rdot_traj = (np.vstack(unicycle_data[k]) for k in range(3))
for arr in (actual_traj, ref_traj, rdot_traj):
    print(arr.shape)

## Data augmentation: add Gaussian noise to the reference trajectories

In [None]:
from helper_functions import forward_simulate, compute_rdot
import jax

# Standard deviations of the Gaussian noise added to the reference x positions.
noise_level = [0.001, 0.002, 0.003, 0.004, 0.005]

Kp = 50 * np.array([[2, 0, 0], [0, 1, 0]])
key = jax.random.PRNGKey(793)
Kd = 50 * jax.random.uniform(key=key, shape=(2, 3))
N = 101     # samples per trajectory
dt = 0.01   # time step passed to compute_rdot
noise_rng = np.random.default_rng(0)  # seeded so the augmented dataset is reproducible

# Rebuild from the loaded pickle rather than from ref_traj itself, so re-running this cell
# does not keep appending further noisy copies.
clean_actual = np.vstack(unicycle_data[0])
clean_ref = np.vstack(unicycle_data[1])
clean_rdot = np.vstack(unicycle_data[2])
# Derive the number of N-row trajectories from the data instead of using `num_iter` from the
# data-generation cell, which is not defined when the data is loaded from disk.
num_traj = clean_ref.shape[0] // N

# Collect pieces in lists and stack once at the end (re-stacking inside the loop is quadratic).
actual_parts, ref_parts, rdot_parts = [clean_actual], [clean_ref], [clean_rdot]
for sigma in noise_level:
    for j in range(num_traj):
        noisy_ref = clean_ref[j*N:(j+1)*N].copy()
        # Only column 0 (x) is perturbed; columns 1-2 are kept clean, as in the original cell.
        # NOTE(review): the original drew (N, 2) noise but used only its first column — confirm
        # whether y should also be perturbed.
        noisy_ref[:, 0] += noise_rng.normal(0, sigma, N)
        _, act = forward_simulate(noisy_ref[0, :], noisy_ref, Kp, Kd, N)
        actual_parts.append(act)
        ref_parts.append(noisy_ref)
        rdot_parts.append(compute_rdot(noisy_ref, dt))

actual_traj = np.vstack(actual_parts)
ref_traj = np.vstack(ref_parts)
rdot_traj = np.vstack(rdot_parts)

In [None]:
# Shapes after augmentation: states, references, reference derivatives.
for arr in (actual_traj, ref_traj, rdot_traj):
    print(arr.shape)

## Look at trajectories

In [None]:
# Interactive figures for the plotting cells below.
# NOTE(review): the `notebook` (nbagg) backend does not work in JupyterLab / Notebook 7;
# `%matplotlib widget` (ipympl) is the usual replacement there.
%matplotlib notebook

In [None]:
import matplotlib.pyplot as plt

# Rows 0..5049: presumably the 50 clean 101-step trajectories (50 * 101 = 5050).
Tstart = 0
Tend = 5050
fig, ax = plt.subplots()
ax.plot(ref_traj[Tstart:Tend, 0], ref_traj[Tstart:Tend, 1], 'r-', label='reference')
ax.plot(actual_traj[Tstart:Tend, 0], actual_traj[Tstart:Tend, 1], 'b--', label='actual')
ax.set_xlabel('x position')
ax.set_ylabel('y position')
ax.set_title('Reference vs. actual trajectories')
ax.legend();

In [None]:
# Third state component over the stacked rows (presumably the heading — confirm).
Tstart = 0
Tend = 6006
fig, ax = plt.subplots()
ax.plot(ref_traj[Tstart:Tend, 2], 'r-', label='reference')
ax.plot(actual_traj[Tstart:Tend, 2], 'b--', label='actual')
ax.set_xlabel('row (time step across stacked trajectories)')
ax.set_ylabel('state component 2')
ax.legend();

## Prepare data with augmented states and compute cost

In [None]:
import jax

N = 101          # number of reference points in each window
q = 2
p = 3 + 3*N      # augmented-state width: 3 state entries + N reference points of 3 entries each
traj_len = ref_traj.shape[0]

Kp = 50 * np.array([[2, 0, 0], [0, 1, 0]])
key = jax.random.PRNGKey(793)
Kd = 50 * jax.random.uniform(key=key, shape=(2, 3))

cost_traj, input_traj = compute_tracking_cost(ref_traj, actual_traj, rdot_traj, Kp, Kd, N)

# Augmented state for row r: actual state at r followed by reference rows r .. r+N-1.
# Use len - N + 1 windows so that the last full window (r = len - N) is included.
# NOTE(review): windows that start mid-trajectory run into the next trajectory's reference — confirm intended.
aug_state = np.array([np.append(actual_traj[r, :], ref_traj[r:r+N, :])
                      for r in range(len(ref_traj) - N + 1)])
assert aug_state.shape[1] == p, f"expected aug-state width {p}, got {aug_state.shape[1]}"

# Train on all but the last 2000 rows; the fourth TrajDataset argument is the aug state one row later.
Tstart = 0
Tend = traj_len - 2000

dataset = TrajDataset(aug_state[Tstart:Tend-1, :].astype('float64'), input_traj[Tstart:Tend-1, :].astype('float64'),
                      cost_traj[Tstart:Tend-1, None].astype('float64'), aug_state[Tstart+1:Tend, :].astype('float64'))

## Define model parameters and train

In [None]:
# Load model / training hyper-parameters from the yaml config.
from ruamel.yaml import YAML

params_path = r"/home/anusha/Research/Layered-architecture-quadrotor-control/Simulations/data/params.yaml"
# The module-level `ruamel.yaml.load(f, Loader=RoundTripLoader)` API was removed in ruamel.yaml 0.18;
# a YAML() instance (round-trip mode by default) is the supported equivalent.
with open(params_path) as f:
    yaml_data = YAML().load(f)

num_hidden = yaml_data['num_hidden']
batch_size = yaml_data['batch_size']
learning_rate = yaml_data['learning_rate']
num_epochs = yaml_data['num_epochs']
model_save = yaml_data['save_path']

In [None]:
# Single-output MLP, trained below against the tracking-cost targets.
model = MLP(num_outputs=1, num_hidden=num_hidden)
print(model)  # the module's repr lists its attributes

In [None]:
# Derive separate keys for the dummy input batch and for parameter initialization.
rng = jax.random.PRNGKey(427)
rng, inp_rng, init_rng = jax.random.split(rng, 3)

# Dummy batch of shape (batch_size, p); batch_size comes from params.yaml.
inp = jax.random.normal(inp_rng, (batch_size, p))
params = model.init(init_rng, inp)

In [None]:
# Run only if error in next cell
try:
    import optax
except ModuleNotFoundError: 
    !pip install --quiet optax
    import optax 

In [None]:
import optax  # installed by the previous cell if it was missing

# Adam with the learning rate read from params.yaml.
optimizer = optax.adam(learning_rate=learning_rate)

In [None]:
from flax.training import train_state

# Bundle the apply function, parameters and optimizer into a single training state.
model_state = train_state.TrainState.create(
    apply_fn=model.apply,
    tx=optimizer,
    params=params,
)

In [None]:
import torch.utils.data as data

# Shuffled mini-batches, collated with numpy_collate from model_learning.
train_data_loader = data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=numpy_collate,
)
trained_model_state = train_model(model_state, train_data_loader, num_epochs=num_epochs)

In [None]:
# In-sample evaluation of the trained state on the training loader.
eval_model(trained_model_state, train_data_loader, batch_size)

In [None]:
# Bind the trained parameters so the module can be called directly, e.g. trained_model(x).
trained_model = model.bind(trained_model_state.params)

### Save model

In [None]:
# Save the trained state to model_save (path from params.yaml).
# NOTE(review): the third argument (0) is presumably the checkpoint step — confirm against model_learning.save_checkpoint.
save_checkpoint(trained_model_state, model_save, 0)

### Load model

In [None]:
# Restore the trained state from model_save.
# NOTE(review): model_state presumably serves as the target structure for restoration — confirm.
trained_model_state = restore_checkpoint(model_state, model_save)

In [None]:
# Re-bind the restored parameters so trained_model(x) uses them.
trained_model = model.bind(trained_model_state.params)

## Inference 

## Load inference data (reloads the same noise-free pickle used for training)

In [None]:
import numpy as np

# NOTE(review): this reloads the same noise-free pickle that the training data came from.
file_path = r"/home/anusha/Research/Layered-architecture-quadrotor-control/Simulations/data/uni_train-nonoise2.pkl"
unicycle_data = load_object(file_path)

actual_traj, ref_traj, rdot_traj = (np.vstack(unicycle_data[k]) for k in range(3))

for arr in (actual_traj, ref_traj):
    print(arr.shape)

In [None]:
import torch.utils.data as data

N = 101
q = 2
p = 3 + 3*N
# Measure the *reloaded* data before deriving the split from it. Previously Tstart/Tend were computed
# from the stale traj_len of the (possibly noise-augmented) training data, which could leave the
# test slices empty.
traj_len = ref_traj.shape[0]

# Hold-out window near the end of the clean data.
# NOTE(review): if the noise-augmentation cell was run, training covered every clean row,
# so this window is not truly held out — confirm.
Tstart = traj_len - 1500
Tend = traj_len - N

Kp = 50 * np.array([[2, 0, 0], [0, 1, 0]])
key = jax.random.PRNGKey(793)
Kd = 50 * jax.random.uniform(key=key, shape=(2, 3))

cost_traj, input_traj = compute_tracking_cost(ref_traj, actual_traj, rdot_traj, Kp, Kd, N)

# Same construction as for training: actual state at r + reference rows r .. r+N-1, all full windows.
aug_state = np.array([np.append(actual_traj[r, :], ref_traj[r:r+N, :])
                      for r in range(len(ref_traj) - N + 1)])

test_dataset = TrajDataset(aug_state[Tstart:Tend-1, :].astype('float64'), input_traj[Tstart:Tend-1, :].astype('float64'),
                           cost_traj[Tstart:Tend-1, None].astype('float64'), aug_state[Tstart+1:Tend, :].astype('float64'))
test_data_loader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=numpy_collate)
eval_model(trained_model_state, test_data_loader, batch_size)

# Predictions vs. targets on the first hold-out batch.
data_input, _, cost, _ = next(iter(test_data_loader))
out = trained_model(data_input)  # parameters are already bound
fig, ax = plt.subplots()
ax.plot(out.ravel(), 'b-', label="Predictions")
ax.plot(cost.ravel(), 'r--', label="Actual")
ax.set_xlabel("sample in batch")
ax.set_ylabel("tracking cost")
ax.legend()
ax.set_title("MLP with JAX on hold out data");

### Inference on the test data

In [None]:
# Run the trained model over every hold-out batch, keeping predictions and targets.
out, true = [], []
for batch in test_data_loader:
    data_input, _, cost, _ = batch
    true.append(cost)
    out.append(trained_model(data_input))

In [None]:
# Concatenate the per-batch results into single arrays.
out, true = np.vstack(out), np.vstack(true)

In [None]:
# Predictions vs. targets over the whole hold-out set.
fig, ax = plt.subplots()
ax.plot(out.ravel(), 'b-', label="Predictions")
ax.plot(true.ravel(), 'r--', label="Actual")
ax.legend()
ax.set_title("MLP with JAX on hold out data")

## Test time

### Optimize over the trajectories in the test dataset

In [None]:
from model_learning import test_model

# Inspect the first sample of the first hold-out batch before optimizing.
batch_state, _, batch_cost, _ = next(iter(test_data_loader))
data_state, data_cost = batch_state, batch_cost
print(data_state[0])
print(data_cost[0])

# NOTE(review): `data` shadows the `torch.utils.data` alias imported earlier; the loader cells re-import it.
solution, data = test_model(trained_model_state, test_data_loader, batch_size)

In [None]:
from helper_functions import forward_simulate

cost = []       # cost returned by forward_simulate for each optimized reference
reg_cost = []   # learned-model cost at the optimum (sol.fun)
rollout = []    # simulated states for each optimized reference
ref = []        # optimized references, reshaped to (N, 3)

for i, sol in enumerate(solution):
    # Simulate the unicycle on the optimized reference to verify the learned cost.
    reg_cost.append(sol.fun)
    new_aug_state = sol.x
    # NOTE(review): sol.x is reshaped to (N, 3) below, so it holds only the reference; x0 is therefore
    # its first reference point rather than the measured initial state — confirm intended.
    x0 = new_aug_state[0:3]
    print("init", x0)
    ref.append(new_aug_state.reshape([N, 3]))
    c, x = forward_simulate(x0, ref[i], Kp, Kd, N)
    cost.append(c)
    rollout.append(x)

fig, ax = plt.subplots()
ax.plot(data_cost.ravel(), '*', label="true cost")
ax.plot(reg_cost, 'o', label="optimized cost")
ax.set_xlabel("test sample")
ax.set_ylabel("tracking cost")
ax.legend();

In [None]:
# Stack the per-solution (N, 3) references into one (len(solution) * N, 3) array.
ref = np.vstack(ref)
print(np.shape(ref))

In [None]:
# `data` is the second value returned by test_model; stack it into one array.
new_data = np.vstack(data)
print(np.shape(new_data))

In [None]:
# Compare the m-th test sample with its optimized reference in the x-y plane.
m = 0
print(ref[m])
print(data[m])
fig, ax = plt.subplots()
# After stacking, rows m*N .. (m+1)*N-1 of `ref` belong to solution m.
ax.plot(ref[m*N:(m+1)*N, 0], ref[m*N:(m+1)*N, 1], 'r--', label="optimized reference")
# Aug-state rows interleave (x, y, .) triples, so columns 0::3 / 1::3 are the x / y coordinates.
# NOTE(review): assumes data[m] corresponds to solution[m] — confirm against test_model.
ax.plot(new_data[m, 0::3], new_data[m, 1::3], 'b--', label="test aug state")
ax.set_xlabel("x position")
ax.set_ylabel("y position")
ax.legend();

## Evaluate on polynomial trajectories


In [None]:
from generate_data import generate_polynomial_trajectory
import numpy as np
import jax.numpy as jnp

num_inf = 100
N = 101
poly_rng = np.random.default_rng(0)  # seeded so the evaluation set is reproducible

# Columns are (x, y, 0): start x/y in {0, 1}, goal x/y in {1, 2}; the third component is 0.
inits = np.vstack([poly_rng.integers(0, 2, (2, num_inf)), np.zeros((1, num_inf))])
goals = np.vstack([poly_rng.integers(1, 3, (2, num_inf)), np.zeros((1, num_inf))])

# Second-order polynomial reference of N samples from each start to its goal.
poly_traj = [generate_polynomial_trajectory(inits[:, k], goals[:, k], N, 2) for k in range(num_inf)]

print(poly_traj[0].shape)

# Augmented state: first reference point (as the initial state) followed by the full reference.
poly_aug_state = jnp.array([np.append(traj[0, :], traj) for traj in poly_traj])
print(poly_aug_state.shape)

In [None]:
from model_learning import calculate_cost
from jax.scipy.optimize import minimize

# Optimize the reference part of the first three polynomial aug states with BFGS; the initial state,
# the final reference point and the trained state/params are passed to calculate_cost as extra args.
solution = []
for i in range(3):
    x_guess = poly_aug_state[i, 3:]
    cost_args = (poly_aug_state[i, :3], poly_aug_state[i, -3:], trained_model_state, trained_model_state.params)
    solution.append(minimize(calculate_cost, x_guess, args=cost_args, method="BFGS"))
    

In [None]:
# Parse the polynomial-trajectory solutions: sol.fun is the learned cost at the optimum and sol.x the
# optimized reference (N rows of 3). Simulate each optimized reference to get its actual cost.
from helper_functions import forward_simulate

cost = []       # cost returned by forward_simulate for each optimized reference
reg_cost = []   # learned-model cost at the optimum (sol.fun)
rollout = []    # simulated states for each optimized reference
ref = []        # optimized references, each (N, 3)

for i, sol in enumerate(solution):
    reg_cost.append(sol.fun)
    new_aug_state = sol.x
    # NOTE(review): x0 is the first optimized reference point, not poly_aug_state[i, :3] — confirm intended.
    x0 = new_aug_state[0:3]
    print("init", x0)
    ref.append(new_aug_state.reshape([N, 3]))
    c, x = forward_simulate(x0, ref[i], Kp, Kd, N)
    cost.append(c)
    rollout.append(x)

# Compare the learned cost with the simulated cost of the same optimized references.
# (data_cost belongs to the hold-out test batch, not to these polynomial trajectories.)
# NOTE(review): assumes forward_simulate returns a scalar cost — confirm.
fig, ax = plt.subplots()
ax.plot(cost, '*', label="simulated cost")
ax.plot(reg_cost, 'o', label="optimized cost")
ax.set_xlabel("polynomial trajectory")
ax.set_ylabel("tracking cost")
ax.legend();

In [None]:
# Compare the m-th polynomial reference, its optimized version and the simulated rollout in the x-y plane.
# (The previous version plotted the unrelated test-set `new_data` here.)
m = 0
fig, ax = plt.subplots()
ax.plot(poly_traj[m][:, 0], poly_traj[m][:, 1], 'k:', label="polynomial reference")
ax.plot(ref[m][:, 0], ref[m][:, 1], 'r--', label="optimized reference")
# NOTE(review): assumes the rollout is a 2-D (steps, 3) state array, as when stacked into actual_traj.
ax.plot(np.asarray(rollout[m])[:, 0], np.asarray(rollout[m])[:, 1], 'b-', label="rollout")
ax.set_xlabel("x position")
ax.set_ylabel("y position")
ax.legend();

In [None]:
# Optimal learned cost and optimized reference of the first solution.
first_solution = solution[0]
print(first_solution.fun)
print(first_solution.x)

## Visualization plots code

In [None]:
import pandas as pd
import seaborn as sns

# Predictions vs. targets on one (shuffled) training batch. Computed here rather than reusing
# `out`/`cost`, which other cells rebind; np.ravel avoids hard-coding the batch size.
train_input, _, train_cost, _ = next(iter(train_data_loader))
train_pred = trained_model(train_input)
df = pd.DataFrame({'preds': np.ravel(train_pred), 'actual': np.ravel(train_cost)})

order = ['preds', 'actual']
x = "Evaluation on Training"
y = "Tracking Cost"
flierprops = dict(marker='o', markerfacecolor='#FFFFFF', markersize=4,
                  linestyle='none', markeredgecolor='#D3D3D3')
plt.figure()
axes = sns.boxplot(data=pd.melt(df, var_name=x, value_name=y), x=x, y=y, order=order, dodge=False, width=0.5,
                   medianprops=dict(color='black'), flierprops=flierprops,
                   showmeans=True, meanprops={"marker": "*", "markerfacecolor": "black", "markeredgecolor": "black"})
axes.set_xlabel(x, fontsize=15)
axes.set_ylabel(y, fontsize=15)
axes.set_ylim((0, 10))
# plt.savefig('/home/anusha/Downloads/icra_results/train.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Predictions vs. targets on one shuffled training batch (parameters are already bound to trained_model).
data_input, _, cost, _ = next(iter(train_data_loader))
out = trained_model(data_input)
fig, ax = plt.subplots()
ax.plot(out.ravel(), 'o', label="Predictions")
ax.plot(cost.ravel(), 'o', label="Actual")
ax.legend()
ax.set_title("MLP with JAX on training data")

In [None]:
import pandas as pd
import seaborn as sns

# Recompute hold-out predictions here. `out` is rebound to training-batch predictions by the
# training-batch plotting cell, so pairing it with `true` would mix training predictions with test
# targets. The hard-coded sample count is also avoided.
test_preds = np.vstack([trained_model(batch[0]) for batch in test_data_loader])
test_true = np.vstack([batch[2] for batch in test_data_loader])
df = pd.DataFrame({'preds': test_preds.ravel(), 'actual': test_true.ravel()})

order = ['preds', 'actual']
x = "Evaluation on Hold Out Data"
y = "Tracking Cost"
flierprops = dict(marker='o', markerfacecolor='#FFFFFF', markersize=4,
                  linestyle='none', markeredgecolor='#D3D3D3')
plt.figure()
axes = sns.boxplot(data=pd.melt(df, var_name=x, value_name=y), x=x, y=y, order=order, dodge=False, width=0.5,
                   medianprops=dict(color='black'), flierprops=flierprops,
                   showmeans=True, meanprops={"marker": "*", "markerfacecolor": "black", "markeredgecolor": "black"})
axes.set_xlabel(x, fontsize=15)
axes.set_ylabel(y, fontsize=15)
axes.set_ylim((0, 10))
# plt.savefig('/home/anusha/Downloads/icra_results/train.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
import pandas as pd
import seaborn as sns

# True cost of the test batch vs. learned cost at the optimized references.
# NOTE(review): reg_cost must come from the test-set solution cell; the polynomial cell rebinds it
# to only three entries, and the lengths would then not match.
df = pd.DataFrame({'actual': np.ravel(data_cost), 'optimized': np.ravel(np.asarray(reg_cost))})

order = ['actual', 'optimized']
x = "Evaluation on Optimizing new references"
y = "Tracking Cost"
flierprops = dict(marker='o', markerfacecolor='#FFFFFF', markersize=4,
                  linestyle='none', markeredgecolor='#D3D3D3')
plt.figure()
axes = sns.boxplot(data=pd.melt(df, var_name=x, value_name=y), x=x, y=y, order=order, dodge=False, width=0.5,
                   medianprops=dict(color='black'), flierprops=flierprops,
                   showmeans=True, meanprops={"marker": "*", "markerfacecolor": "black", "markeredgecolor": "black"})
axes.set_xlabel(x, fontsize=15)
axes.set_ylabel(y, fontsize=15)
axes.set_ylim((0, 1))
# plt.savefig('/home/anusha/Downloads/icra_results/train.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
from pathlib import Path

# Boxplot of tracking cost for two tracking-penalty values (rho = 0.1 and rho = 10).
# NOTE(review): traj_costs is not defined anywhere in this notebook — presumably a pair of per-trajectory
# cost arrays produced elsewhere. Fail with a clear message instead of a bare NameError.
if "traj_costs" not in globals():
    raise NameError("traj_costs is not defined; compute or load the per-rho tracking costs before running this cell")

# One column per rho value, for any number of trajectories (previously hard-coded to 10).
df = pd.DataFrame(np.column_stack([np.ravel(c) for c in traj_costs]), columns=['0.1', '10'])

order = ['0.1', '10']
x = "Tracking Penalty (rho)"
y = "Tracking Cost"
flierprops = dict(marker='o', markerfacecolor='#FFFFFF', markersize=4,
                  linestyle='none', markeredgecolor='#D3D3D3')
plt.figure()  # start a fresh figure instead of drawing onto the previous cell's axes
axes = sns.boxplot(data=pd.melt(df, var_name=x, value_name=y), x=x, y=y, order=order, dodge=False, width=0.5,
                   medianprops=dict(color='black'), flierprops=flierprops,
                   showmeans=True, meanprops={"marker": "*", "markerfacecolor": "black", "markeredgecolor": "black"})
axes.set_xlabel(x, fontsize=15)
axes.set_ylabel(y, fontsize=15)
fig_path = Path('/home/anusha/Downloads/icra_results/cost-rho-results.png')
fig_path.parent.mkdir(parents=True, exist_ok=True)  # savefig fails if the directory is missing
plt.savefig(fig_path, dpi=300, bbox_inches='tight')
plt.show()