In [10]:
import numpy as np
import matplotlib.pyplot as plt
import mat73
import time

import jax
import jax.numpy as jnp
from jax import random, jit
import matplotlib.pyplot as plt



# ------------------     MATLAB STUFF  ----------------------------------
def get_matlab_variables(mat_file_path):
    variables = mat73.loadmat(mat_file_path)
    #double gets converted to np array by default
    value_func_data = variables['Vx']
    lx_data = variables['lx'] 
    tau2 = variables['tau2']

    #Deriv is cell which gets converted into list of lists
    deriv_x_data = np.array(variables['Deriv'][0])
    deriv_x_data = deriv_x_data.squeeze()
    deriv_y_data = np.array(variables['Deriv'][1])
    deriv_y_data = deriv_y_data.squeeze()
    deriv_th_data = np.array(variables['Deriv'][2])
    deriv_th_data = deriv_th_data.squeeze()
    
    #uOpt is also cell which gets converted into list of lists
    uOpt_vel = np.array(variables['uOpt'][0])
    uOpt_vel = uOpt_vel.squeeze()
    uOpt_angle = np.array(variables['uOpt'][1])
    uOpt_angle = uOpt_angle.squeeze() 

    #g is struct whic gets converted into dic
    #vs is cell which give a list
    x_coord=np.array(variables['g']['vs'][0])
    y_coord=np.array(variables['g']['vs'][1])
    th_coord=np.array(variables['g']['vs'][2])
    x_coord = x_coord.squeeze()
    y_coord = y_coord.squeeze()
    th_coord = th_coord.squeeze()

    matlab_var_dict = dict( value_func_data=value_func_data,
                            lx_data=lx_data,
                            deriv_x_data=deriv_x_data,
                            deriv_y_data=deriv_y_data,
                            deriv_th_data=deriv_th_data,
                            uOpt_vel=uOpt_vel,
                            uOpt_angle=uOpt_angle,
                            x_coord=x_coord,
                            y_coord=y_coord,
                            th_coord=th_coord,
                            tau2=tau2
                           )
    return matlab_var_dict


#---------------------- Load MATLAB ---------------------------------------------------------
#v3 added uopt lookup table
matlab_var_dict= get_matlab_variables('/home/javier/jax_work/mppi/rc_car_mppi/brt_rc_wh_coarse_v3.mat')

data = matlab_var_dict['value_func_data']
data_lx = matlab_var_dict['lx_data']
uOpt_vel = matlab_var_dict['uOpt_vel']
uOpt_angle = matlab_var_dict['uOpt_angle']
coords = [matlab_var_dict['x_coord'], matlab_var_dict['y_coord'], matlab_var_dict['th_coord']]

data = jnp.array(data)
uOpt_vel = jnp.array(uOpt_vel)
uOpt_angle = jnp.array(uOpt_angle)
coords = [jnp.array(coord) for coord in coords]



In [2]:
import sys
print(type(uOpt_vel))
print(data.shape)

<class 'jaxlib.xla_extension.ArrayImpl'>
(152, 102, 181)


In [11]:
# Experiment Constants
DT = 0.02
L = 0.235
V_MIN = 0.5
V_MAX = 1.5
DELTA_MIN = -0.5
DELTA_MAX = 0.5

@jit
def ackerman_dynamics(state, control, dt=DT, L=L):
    x, y, theta = state
    v, delta = control
    
    x_dot = v * jnp.cos(theta)
    y_dot = v * jnp.sin(theta)
    theta_dot = v * jnp.tan(delta) * (1 / L)
    
    new_x = x + x_dot * dt
    new_y = y + y_dot * dt
    new_theta = theta + theta_dot * dt
    
    # Handle the angle wrap around
    new_theta = ((new_theta + jnp.pi) % (2 * jnp.pi)) - jnp.pi
    
    return new_x, new_y, new_theta

@jit
def cost_function(states, controls, v_target=V_MAX, delta_target=0.0):
    x = states[0, :]
    v = controls[:, 0]
    delta = controls[:, 1]
    cost = jnp.sum((v - v_target) ** 2)
    cost += jnp.sum((delta - delta_target) ** 2) * 5.0
    return cost 

@jit
def simulate_ackerman(initial_state, disturbed_controls, data, uOpt_vel, uOpt_angle, coords, dt=DT, L=L):
  
    def step(state, control):
        #find index in coords closest to state to get value and optimal control
        x_idx = jnp.argmin(jnp.abs(coords[0] - state[0]))
        y_idx = jnp.argmin(jnp.abs(coords[1] - state[1]))
        th_idx = jnp.argmin(jnp.abs(coords[2] - state[2]))        
        value_now = data[x_idx, y_idx, th_idx]
        uOpt_vel_now = uOpt_vel[x_idx, y_idx, th_idx]
        uOpt_angle_now = uOpt_angle[x_idx, y_idx, th_idx]
        
        # Update control with optimal values
        control = jnp.array([uOpt_vel_now, uOpt_angle_now])*(value_now<0.1) + control*(value_now>=0.1)
        new_state = ackerman_dynamics(state, control, dt, L)
        return new_state, new_state
    
    # Use jax.lax.scan to iterate over the controls and accumulate the states
    _, states = jax.lax.scan(step, initial_state, disturbed_controls)
    
    # Convert states to a JAX array
    states = jnp.array(states)
    
    # Compute the cost for the entire array of states and controls
    total_cost = cost_function(states, disturbed_controls)
    
    return states, total_cost

# Vectorize the simulation function to run multiple trajectories in parallel
simulate_ackerman_parallel = jax.vmap(simulate_ackerman, in_axes=(None, 0, None, None, None, None))

# ------------------     SIMULATION INIT  ----------------------------------

# Example usage with jax.random
key = random.PRNGKey(0)

# Simulation parameters
HALLUCINATION_STEPS = 100
NUM_THREADS = 500
TEMPERATURE = 1.0
EXPERIMENT_T = 12.0
EXPERIMENT_STEPS = int(EXPERIMENT_T / DT)
INITIAL_STATE = (3.0, 0.5, 0)

# Safety filter parameters
FILTER_EXPERIMENT = True
EXPERIMENT_THRESHOLD = 0.1
FILTER_HALLUCINATIONS = False

# Data structures
state_now = INITIAL_STATE
state_history = []
control_now = (0.0, 0.0)
control_history = []
hallucination_history = []
m_t = np.zeros((EXPERIMENT_STEPS, 3))

# timing structures
time_rollouts = []
time_sim_step = [] 

# Generate nominal control inputs
nominal_velocities = jnp.ones(HALLUCINATION_STEPS)
nominal_steering_angles = jnp.zeros(HALLUCINATION_STEPS)
nominal_controls = jnp.stack((nominal_velocities, nominal_steering_angles), axis=1)

# ------------------     SIMULATION LOOP  ----------------------------------

for i in range(EXPERIMENT_STEPS):
    start_time_sim_step = time.time()
    # Generate random noise for multiple trajectories
    control_noise = random.normal(key, shape=(NUM_THREADS, HALLUCINATION_STEPS, 2)) * jnp.array([0.2, 0.1])  # Adjust the scale of noise as needed

    # Combine controls and noise before passing to the simulation function, clip to valid range
    disturbed_controls = nominal_controls + control_noise
    disturbed_controls = jnp.clip(disturbed_controls, jnp.array([V_MIN, DELTA_MIN]), jnp.array([V_MAX, DELTA_MAX]))
    # Grab the noise that was actually applied
    control_noise = disturbed_controls - nominal_controls    

    # Perform the simulation
    start_time_rollouts = time.time()
    states_parallel, costs_parallel = simulate_ackerman_parallel(state_now, disturbed_controls, data, uOpt_vel, uOpt_angle, coords)
    end_time_rollouts = time.time()
    
    # Update nominal controls using the costs and noise
    weights = jnp.exp(-TEMPERATURE * (costs_parallel))
    weights = weights[:, jnp.newaxis, jnp.newaxis]  # Adjust shape for broadcasting
    nominal_controls = nominal_controls + jnp.sum(weights * control_noise, axis=0) / jnp.sum(weights)
    # Clip the controls to the valid range
    nominal_controls = jnp.clip(nominal_controls, jnp.array([V_MIN, DELTA_MIN]), jnp.array([V_MAX, DELTA_MAX]))
    
    # Check value function and apply LR filter
    #find index in coords closest to state to get value and optimal control
    x_idx = jnp.argmin(jnp.abs(coords[0] - state_now[0]))
    y_idx = jnp.argmin(jnp.abs(coords[1] - state_now[1]))
    th_idx = jnp.argmin(jnp.abs(coords[2] - state_now[2]))        
    value_now = data[x_idx, y_idx, th_idx]
    uOpt_vel_now = uOpt_vel[x_idx, y_idx, th_idx]
    uOpt_angle_now = uOpt_angle[x_idx, y_idx, th_idx]
  
    if FILTER_EXPERIMENT and value_now < EXPERIMENT_THRESHOLD:
        control_now = jnp.array([uOpt_vel_now,uOpt_angle_now])
        m_t[i] = [0.8, 0.4, 0.4] 
    else:
        control_now = nominal_controls[0]
        m_t[i] = [0.4, 0.4, 0.8] 
    
    # Apply the first control from the updated nominal controls to the system and store the new state
    state_now = ackerman_dynamics(state_now, control_now)
    state_history.append(state_now)
    control_history.append(control_now)
    hallucination_history.append(states_parallel)
    
    # Move the control sequence one step forward and maintain the last control
    nominal_controls = jnp.roll(nominal_controls, -1, axis=0)
    nominal_controls = nominal_controls.at[-1].set(nominal_controls[-2])

    end_time_sim_step = time.time()
    # Print and save the elapsed times
    print(f"Elapsed time for rollouts {(end_time_rollouts - start_time_rollouts)*1000:.1f} ms")
    print(f"Elapsed time for sim step {(end_time_sim_step - start_time_sim_step)*1000:.1f} ms")
    time_rollouts.append((end_time_rollouts - start_time_rollouts)*1000)
    time_sim_step.append((end_time_sim_step - start_time_sim_step)*1000)
    

#remove first 5 timings
time_rollouts = time_rollouts[5:]
time_sim_step = time_sim_step[5:]
# timing statistics avg+-std
print(f"Average time for rollouts {np.mean(time_rollouts):.1f} ± {np.std(time_rollouts):.1f} ms")
print(f"Average time for sim step {np.mean(time_sim_step):.1f} ± {np.std(time_sim_step):.1f} ms")


Elapsed time for rollouts 230.8 ms
Elapsed time for sim step 302.3 ms
Elapsed time for rollouts 213.5 ms
Elapsed time for sim step 278.3 ms
Elapsed time for rollouts 2.3 ms
Elapsed time for sim step 9.1 ms
Elapsed time for rollouts 2.4 ms
Elapsed time for sim step 9.0 ms
Elapsed time for rollouts 2.2 ms
Elapsed time for sim step 8.7 ms
Elapsed time for rollouts 2.2 ms
Elapsed time for sim step 9.0 ms
Elapsed time for rollouts 2.1 ms
Elapsed time for sim step 8.7 ms
Elapsed time for rollouts 2.2 ms
Elapsed time for sim step 9.2 ms
Elapsed time for rollouts 2.2 ms
Elapsed time for sim step 8.8 ms
Elapsed time for rollouts 2.2 ms
Elapsed time for sim step 8.6 ms
Elapsed time for rollouts 2.1 ms
Elapsed time for sim step 8.8 ms
Elapsed time for rollouts 2.2 ms
Elapsed time for sim step 8.6 ms
Elapsed time for rollouts 2.1 ms
Elapsed time for sim step 8.6 ms
Elapsed time for rollouts 2.2 ms
Elapsed time for sim step 8.4 ms
Elapsed time for rollouts 2.1 ms
Elapsed time for sim step 8.5 ms
El

In [6]:
from ipywidgets import interact

# Assuming state_history , matlab_var_dict, data_lx, data, m_t, and list_hallucinations_at_idx are defined elsewhere in your code
state_history = np.array(state_history)
control_history = np.array(control_history)

def plot_func(idx_to_plot):
    state_plot = state_history [idx_to_plot]
    # Print the state with 2 decimals
    print('[x,y,th]')
    print(np.around(state_plot, decimals=2))
    # Get the index of the closest v and th to the state_first_fltr
    th_idx = np.argmin(np.abs(matlab_var_dict['th_coord'] - state_plot[2]))

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(6, 18))

    ####race track#####
    data1 = data_lx
    data1 = data1.transpose()
    CP1 = ax1.contour(matlab_var_dict['x_coord'], matlab_var_dict['y_coord'], data1, 0, colors='red', linewidths=1)

    data2 = data[:, :, th_idx]  # Slice of lx
    data2 = data2.transpose()
    CP2 = ax1.contour(matlab_var_dict['x_coord'], matlab_var_dict['y_coord'], data2, 0, colors='blue', linewidths=1)

    # Grab closest theta
    th = matlab_var_dict['th_coord']
    th = th[th_idx]
    # Plot trajectory and final state
    SP1 = ax1.scatter(state_history [0:idx_to_plot, 0], state_history[0:idx_to_plot, 1], c=m_t[0:idx_to_plot], s=5, alpha=0.5) #
    SP2 = ax1.scatter(state_history [idx_to_plot, 0], state_history [idx_to_plot, 1], s=10, c=[[0.0, 0.0, 0.0]], alpha=1.0)
    # Plot arrow at final state
    ax1.arrow(state_history [idx_to_plot, 0], state_history [idx_to_plot, 1], np.cos(th) * 0.2, np.sin(th) * 0.2, head_width=0.1, head_length=0.1, fc='k', ec='k')

    # Plot the hallucinations corresponding to idx_to_plot        
    for i in range(0, len(hallucination_history[idx_to_plot]), 20):
        h = hallucination_history[idx_to_plot][i]
        ax1.plot(h[0, :], h[1, :], color='green', alpha=0.2)

    # Set x and y limits
    ax1.set_xlim([0, 6])
    ax1.set_ylim([0, 4])

    # Change aspect ratio to match the grid
    ax1.set_aspect('equal')

    ####velocity plot####
    ax2.plot(control_history[:idx_to_plot+1, 0])
    ax2.scatter(idx_to_plot, control_history[idx_to_plot, 0], s=10, c=[[0.0, 0.0, 0.0]], alpha=1.0)
    ax2.set_title('velocity control')
    ax2.set_xlabel('k')
    ax2.set_aspect(aspect=50)

    ####angle plot####
    ax3.plot(control_history[:idx_to_plot+1, 1])
    ax3.scatter(idx_to_plot, control_history[idx_to_plot, 1], s=10, c=[[0.0, 0.0, 0.0]], alpha=1.0)
    ax3.set_title('angle control')
    ax3.set_xlabel('k')
    ax3.set_aspect(aspect=50)
       
    
    plt.show()

# Use interact to create a slider for idx_to_plot
interact(plot_func, idx_to_plot=(0, len(state_history ) - 2, 1))

interactive(children=(IntSlider(value=299, description='idx_to_plot', max=598), Output()), _dom_classes=('widg…

<function __main__.plot_func(idx_to_plot)>

In [None]:
print(hallucination_history[0][0].shape)

In [None]:
# Visualize the state trajectories
hallucinations = hallucination_history[0]
plt.figure(figsize=(10, 6))
for i in range(NUM_THREADS):
    x_coords = jnp.array(hallucinations)[i, 0, :]
    y_coords = jnp.array(hallucinations)[i, 1, :]
    plt.plot(x_coords, y_coords, color='tab:blue', alpha=0.3)  
plt.xlabel('X Position')
plt.ylabel('Y Position')
plt.title('State Trajectories of Ackerman Steering System with External Noise')
plt.grid(True)
# Set x and y limits
plt.xlim([0, 6])
plt.ylim([0, 4])
plt.show()

# Print the total cost for each trajectory
print("Total costs for each trajectory:", costs_parallel)

In [None]:
# Ensure you are in a Jupyter Notebook environment
# Use %timeit to measure the execution time
%timeit -n 20 -r 10 states_parallel = simulate_ackerman_parallel(initial_state, disturbed_controls); jax.block_until_ready(states_parallel)