In [1]:
import numpy as np
import matplotlib.pyplot as plt
import mat73
import time

import jax
import jax.numpy as jnp
from jax import random, jit
import matplotlib.pyplot as plt

# ------------------     MATLAB STUFF  ----------------------------------
def _squeeze_cell(cell, idx):
    """Return element `idx` of a MATLAB cell array (loaded as a list) as a squeezed NumPy array."""
    return np.array(cell[idx]).squeeze()


def get_matlab_variables(mat_file_path):
    """Load the reachability data saved from MATLAB into a flat dict.

    Args:
        mat_file_path: path to a .mat file readable by ``mat73.loadmat`` that
            contains the variables 'Vx', 'lx', 'tau2', 'Deriv', 'uOpt' and 'g'.

    Returns:
        dict with keys, in this order:
          value_func_data, lx_data -- 'Vx' and 'lx' exactly as loaded;
          deriv_x_data, deriv_y_data, deriv_th_data -- Deriv{1}, Deriv{2}, Deriv{3};
          uOpt_vel, uOpt_angle -- uOpt{1}, uOpt{2};
          x_coord, y_coord, th_coord -- g.vs{1}, g.vs{2}, g.vs{3};
          tau2 -- 'tau2' exactly as loaded.
        Cell-array entries are converted to NumPy arrays with singleton
        dimensions squeezed out.
    """
    variables = mat73.loadmat(mat_file_path)

    # MATLAB doubles are already NumPy arrays and are returned unchanged.
    value_func_data = variables['Vx']
    lx_data = variables['lx']
    tau2 = variables['tau2']

    # MATLAB cell arrays come back as Python lists and structs as dicts,
    # so each cell entry is converted and squeezed individually.
    deriv = variables['Deriv']
    u_opt = variables['uOpt']
    grid_vectors = variables['g']['vs']

    return dict(value_func_data=value_func_data,
                lx_data=lx_data,
                deriv_x_data=_squeeze_cell(deriv, 0),
                deriv_y_data=_squeeze_cell(deriv, 1),
                deriv_th_data=_squeeze_cell(deriv, 2),
                uOpt_vel=_squeeze_cell(u_opt, 0),
                uOpt_angle=_squeeze_cell(u_opt, 1),
                x_coord=_squeeze_cell(grid_vectors, 0),
                y_coord=_squeeze_cell(grid_vectors, 1),
                th_coord=_squeeze_cell(grid_vectors, 2),
                tau2=tau2,
                )


#---------------------- Load MATLAB ---------------------------------------------------------
# Pre-computed reachability data exported from MATLAB (v3 added the uOpt lookup table).
# NOTE(review): absolute, machine-specific path; consider making it relative to a
# configurable data directory.
BRT_MAT_PATH = '/home/javier/jax_work/mppi/rc_car_mppi/brt_rc_wh_coarse_v3.mat'
matlab_var_dict = get_matlab_variables(BRT_MAT_PATH)

# Value function and uOpt lookup tables, converted once to JAX arrays.
data = jnp.array(matlab_var_dict['value_func_data'])
uOpt_vel = jnp.array(matlab_var_dict['uOpt_vel'])
uOpt_angle = jnp.array(matlab_var_dict['uOpt_angle'])
# Grid coordinate vectors in (x, y, th) order, as JAX arrays.
coords = [jnp.array(matlab_var_dict[key]) for key in ('x_coord', 'y_coord', 'th_coord')]
# lx is deliberately left as the array returned by the loader (not converted to JAX).
data_lx = matlab_var_dict['lx_data']



In [2]:

# Open the npz file in the experiment folder. Using NpzFile as a context manager
# closes the underlying file once the arrays have been read into memory.
exp_path = '/home/javier/jax_work/mppi/rc_car_mppi/experiments/mppi_data_20240819-181814.npz'
with np.load(exp_path) as npzfile:
    state_history = npzfile['state_history']
    control_history = npzfile['control_history']
    hallucination_history = npzfile['hallucination_history']
    lr_active = npzfile['lr_active']

# Per-step RGB colour used for the trajectory scatter in the plotting cell:
# [0.8, 0.4, 0.4] where lr_active == 1, [0.4, 0.4, 0.8] otherwise.
m_t = [[0.8, 0.4, 0.4] if active == 1 else [0.4, 0.4, 0.8] for active in lr_active]

print(state_history.shape)
print(control_history.shape)
print(hallucination_history.shape)
print(lr_active.shape)

# Print lr_active over a window of steps.
range_init = 850
range_ends = 890
print('lr_active:', lr_active[range_init:range_ends])

# Print the first 6 horizon entries of hallucinated rollout 100 at step range_test_point.
# NOTE(review): assumes hallucination_history axes are (step, rollout, state, horizon),
# consistent with the printed shape (1243, 500, 3, 100) -- confirm.
range_test_point = 885
print('hallucinations:\n', hallucination_history[range_test_point, 100, :, 0:6])




(1243, 3)
(1243, 2)
(1243, 500, 3, 100)
(1243,)
lr_active: [ True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True]
hallucinations:
 [[1.4814475 1.4792186 1.4769634 1.4746821 1.4723748 1.4700416]
 [3.7621891 3.7655106 3.7688143 3.7721    3.7753675 3.7786164]
 [2.1618335 2.1697705 2.177708  2.1856453 2.1935828 2.2015202]]


In [3]:
from ipywidgets import interact

# Relies on state_history, control_history, hallucination_history and m_t from the
# previous cell, and on matlab_var_dict, data and data_lx from the first cell.
state_history, control_history = np.array(state_history), np.array(control_history)

def _nearest_angle_index(grid_angles, angle):
    """Return the index of the grid angle closest to `angle`, treating angles as periodic (radians).

    The difference is wrapped into [-pi, pi) so an unwrapped heading (e.g. 2*pi + 0.1),
    or one just below +pi, is matched to the angularly nearest grid point.
    """
    diff = (np.asarray(grid_angles) - angle + np.pi) % (2 * np.pi) - np.pi
    return int(np.argmin(np.abs(diff)))


def _plot_control_history(ax, values, idx_to_plot, title):
    """Plot values[:idx_to_plot + 1] against step k and mark the current sample in black."""
    ax.plot(values[:idx_to_plot + 1])
    ax.scatter(idx_to_plot, values[idx_to_plot], s=10, c=[[0.0, 0.0, 0.0]], alpha=1.0)
    ax.set_title(title)
    ax.set_xlabel('k')
    ax.set_aspect(aspect=50)


def plot_func(idx_to_plot):
    """Plot the run up to step `idx_to_plot` in three stacked panels.

    1. Track: zero level set of lx (red), zero level set of the value function
       sliced at the theta grid point nearest the current heading (blue), the
       trajectory so far coloured by m_t, the current state (black) with a heading
       arrow, and every 20th hallucinated rollout at this step (green).
    2. Velocity control history up to this step.
    3. Angle control history up to this step.

    Reads the notebook globals state_history, control_history,
    hallucination_history, m_t, data, data_lx and matlab_var_dict.
    """
    state_plot = state_history[idx_to_plot]
    # Print the state with 2 decimals
    print('[x,y,th]')
    print(np.around(state_plot, decimals=2))

    x_coord = matlab_var_dict['x_coord']
    y_coord = matlab_var_dict['y_coord']
    th_coord = matlab_var_dict['th_coord']
    # Theta grid point closest to the current heading (periodic distance).
    th_idx = _nearest_angle_index(th_coord, state_plot[2])
    th = th_coord[th_idx]

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(6, 18))

    #### race track ####
    # levels=[0] draws exactly the zero level set; an int here would instead ask
    # matplotlib to auto-choose "nice" levels, which need not include 0.
    ax1.contour(x_coord, y_coord, data_lx.transpose(), levels=[0], colors='red', linewidths=1)
    # Slice of the value function at the chosen theta.
    ax1.contour(x_coord, y_coord, data[:, :, th_idx].transpose(), levels=[0], colors='blue', linewidths=1)

    # Trajectory so far (coloured per step by m_t) and the current state in black.
    ax1.scatter(state_history[:idx_to_plot, 0], state_history[:idx_to_plot, 1],
                c=m_t[:idx_to_plot], s=5, alpha=0.5)
    ax1.scatter(state_plot[0], state_plot[1], s=10, c=[[0.0, 0.0, 0.0]], alpha=1.0)
    # Heading arrow at the current state. It uses the grid theta of the plotted
    # value-function slice, not the exact state heading.
    ax1.arrow(state_plot[0], state_plot[1], np.cos(th) * 0.2, np.sin(th) * 0.2,
              head_width=0.1, head_length=0.1, fc='k', ec='k')

    # Every 20th hallucinated rollout at this step; rows 0 and 1 are plotted as x and y.
    for rollout in hallucination_history[idx_to_plot][::20]:
        ax1.plot(rollout[0, :], rollout[1, :], color='green', alpha=0.2)

    ax1.set_xlim([0, 6])
    ax1.set_ylim([0, 4])
    # Equal aspect so the track is not distorted.
    ax1.set_aspect('equal')

    #### control histories ####
    _plot_control_history(ax2, control_history[:, 0], idx_to_plot, 'velocity control')
    _plot_control_history(ax3, control_history[:, 1], idx_to_plot, 'angle control')

    plt.show()

# Use interact to create a slider for idx_to_plot (min 0, step 1).
# NOTE(review): the upper bound is len(state_history) - 2, so the last recorded
# step is not reachable from the slider, although the printed shapes show every
# array plot_func indexes has len(state_history) entries -- confirm whether
# excluding the final step is intended.
interact(plot_func, idx_to_plot=(0, len(state_history ) - 2, 1))

interactive(children=(IntSlider(value=620, description='idx_to_plot', max=1241), Output()), _dom_classes=('wid…

<function __main__.plot_func(idx_to_plot)>