In [1]:
import torch
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt
from mpc.track.src import simple_track_generator, track_functions

from torch.func import jacfwd, vmap

import utils

import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer

import torch.autograd.functional as F

from mpc import casadi_control

import scipy.linalg

(CVXPY) Aug 12 09:40:18 AM: Encountered unexpected exception importing solver GLOP:
RuntimeError('Unrecognized new version of ortools (9.5.2237). Expected < 9.5.0.Please open a feature request on cvxpy to enable support for this version.')
(CVXPY) Aug 12 09:40:18 AM: Encountered unexpected exception importing solver PDLP:
RuntimeError('Unrecognized new version of ortools (9.5.2237). Expected < 9.5.0.Please open a feature request on cvxpy to enable support for this version.')


In [2]:
from casadi import *

class CasadiControl():
    def __init__(self, track_coordinates, params):
        super().__init__()

        params = params.numpy()
        
        # states: sigma, d, phi, v (4) + sigma_0, sigma_diff (2) + d_pen (1) + v_ub (1) + ac_ub (1)
        self.n_state = 4+2+1+1+1
        print(self.n_state)          # here add amount of states plus amount of exact penalty terms
        # control: a, delta
        self.n_ctrl = 2

        self.track_coordinates = track_coordinates

        # everything to calculate curvature
        self.track_sigma = self.track_coordinates[2,:]
        self.track_curv = self.track_coordinates[4,:]

        self.track_curv_shift = torch.empty(self.track_curv.size())
        self.track_curv_shift[1:] = self.track_curv[0:-1]
        self.track_curv_shift[0] = self.track_curv[-1]
        self.track_curv_diff = self.track_curv - self.track_curv_shift

        self.mask = torch.where(torch.absolute(self.track_curv_diff) < 0.1, False, True)
        self.sigma_f = self.track_sigma[self.mask]
        self.curv_f = self.track_curv_diff[self.mask]

        self.params = params

        self.l_r = params[0]
        self.l_f = params[1]
        
        self.track_width = params[2]
        
        self.delta_threshold_rad = np.pi
        self.dt = params[3]

        self.smooth_curve = params[4]
        
        self.v_max = params[5]
        
        self.delta_max = params[6]
        
        self.a_max = params[7]
        
        self.mpc_T = int(params[8])
        
    def sigmoid(self, x):
        return (tanh(x/2)+1.)/2

    def curv_casadi(self, sigma):
        
        num_sf = self.sigma_f.size()
        num_s = sigma.size()

        sigma_f_mat = self.sigma_f.repeat(num_s[1],1)
   
        sigma_f_mat_np = sigma_f_mat.numpy()
        sigma_f_np = self.sigma_f.numpy()
        curv_f_np = self.curv_f.numpy()

        sigma_shifted = reshape(sigma,num_s[1],1)- sigma_f_mat_np
        curv_unscaled = self.sigmoid(self.smooth_curve*sigma_shifted)
        curv = reshape((curv_unscaled@(curv_f_np.reshape(-1,1))),1,num_s[1])

        return curv
    
    
    def mpc_casadi(self,q,p,x0_np,dx,du):
        mpc_T = self.mpc_T

        x_sym = SX.sym('x_sym',dx,mpc_T+1)
        u_sym = SX.sym('u_sym',du,mpc_T)

        beta = np.arctan(l_r/(l_r+l_f)*np.tan(u_sym[1,0:mpc_T]))

        dyn1 = horzcat(
            (x_sym[0,0] - x0_np[0]), 
            (x_sym[0,1:mpc_T+1] - x_sym[0,0:mpc_T] - dt*(x_sym[3,0:mpc_T]*(np.cos(x_sym[2,0:mpc_T]+beta)/(1.-self.curv_casadi(x_sym[0,0:mpc_T])*x_sym[1,0:mpc_T])))))

        dyn2 = horzcat(
            (x_sym[1,0] - x0_np[1]), 
            (x_sym[1,1:mpc_T+1] - x_sym[1,0:mpc_T] - dt*(x_sym[3,0:mpc_T]*np.sin(x_sym[2,0:mpc_T]+beta))))

        dyn3 = horzcat(
            (x_sym[2,0] - x0_np[2]), 
            (x_sym[2,1:mpc_T+1] - x_sym[2,0:mpc_T] - dt*(x_sym[3,0:mpc_T]*(1/l_f)*np.sin(beta)-self.curv_casadi(x_sym[0,0:mpc_T])*x_sym[3,0:mpc_T]*(np.cos(x_sym[2,0:mpc_T]+beta)/(1-self.curv_casadi(x_sym[0,0:mpc_T])*x_sym[1,0:mpc_T])))))

        dyn4 = horzcat(
            (x_sym[3,0] - x0_np[3]), 
            (x_sym[3,1:mpc_T+1] - x_sym[3,0:mpc_T] - dt*(u_sym[0,0:mpc_T])))

        feat = vertcat(x_sym[0,0:mpc_T]-x0_np[0],x_sym[1:,0:mpc_T],u_sym[:,0:mpc_T])

        q_sym = SX.sym('q_sym',dx+du,mpc_T)
        p_sym = SX.sym('q_sym',dx+du,mpc_T)
        Q_sym = diag(q_sym)

        l = sum2(sum1(q_sym*feat*feat + p_sym*feat))
        dl = substitute(substitute(l,q_sym,q),p_sym,p)

        const = vertcat(
                transpose(dyn1),
                transpose(dyn2),
                transpose(dyn3),
                transpose(dyn4),
                transpose(u_sym[0,0:mpc_T]),
                transpose(u_sym[1,0:mpc_T]),
                transpose(x_sym[1,0:mpc_T+1]),
                transpose(x_sym[3,0:mpc_T+1]))

        lbg = np.r_[np.zeros(mpc_T+1),
                    np.zeros(mpc_T+1),
                    np.zeros(mpc_T+1),
                    np.zeros(mpc_T+1),
                    -self.a_max*np.ones(mpc_T),
                    -self.delta_max*np.ones(mpc_T),
                    -0.35*self.track_width*np.ones(mpc_T+1),
                    -0.1*np.ones(mpc_T+1)]

        ubg = np.r_[np.zeros(mpc_T+1),
                    np.zeros(mpc_T+1),
                    np.zeros(mpc_T+1),
                    np.zeros(mpc_T+1),
                    self.a_max*np.ones(mpc_T),
                    self.delta_max*np.ones(mpc_T),
                    0.35*self.track_width*np.ones(mpc_T+1),
                    self.v_max*np.ones(mpc_T+1)]


        lbx = -np.inf * np.ones(dx*(mpc_T+1)+du*mpc_T)
        ubx = np.inf * np.ones(dx*(mpc_T+1)+du*mpc_T)

        x = vertcat(reshape(x_sym[:,0:mpc_T+1],(dx*(mpc_T+1),1)),
                    reshape(u_sym[:,0:mpc_T],(du*mpc_T,1)))

        options = {
                    'verbose': False,
                    'ipopt.print_level': 0,
                    'print_time': 0,
                    'ipopt.tol': 1e-4,
                    'ipopt.max_iter': 4000,
                    'ipopt.hessian_approximation': 'limited-memory'
                }

        nlp = {'x':x,'f':dl, 'g':const}
        solver = nlpsol('solver','ipopt', nlp, options)

        solver_input = {}
        solver_input['lbx'] = lbx
        solver_input['ubx'] = ubx
        solver_input['lbg'] = lbg
        solver_input['ubg'] = ubg

        solver_output = solver(**solver_input)

        sol = solver_output['x']

        sol_evalf = np.squeeze(evalf(sol))
        u = sol_evalf[-du*mpc_T:].reshape(-1,du)
        x = sol_evalf[:-du*mpc_T].reshape(-1,dx)

        return x, u

In [3]:
class SimpleNN(nn.Module):
    def __init__(self, mpc_H, mpc_T, O, K):
        super(SimpleNN, self).__init__()
        input_size = 3 + mpc_H
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, mpc_T*O)
        self.activation = nn.ReLU()
        self.output_activation = nn.Tanh()
        self.K = K
        self.O = O
        self.mpc_T = mpc_T

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        x = self.fc3(x)
        x = self.output_activation(x) * self.K
        x = x.reshape(-1, self.mpc_T, self.O)
        return x

In [4]:
def sample_init(BS):
    di = 1000
    sigma_sample = torch.randint(0, int(0.5*di), (BS,1))/di
    d_sample = torch.randint(int(-.15*di), int(.15*di), (BS,1))/di
    phi_sample = torch.randint(int(-0.08*di), int(0.08*di), (BS,1))/di
    v_sample = torch.randint(0, int(.22*di), (BS,1))/di
    x_init_sample = torch.hstack((sigma_sample, d_sample, phi_sample, v_sample))
    return x_init_sample

def get_curve_hor_from_x(x, track_coord, H_curve):
    idx_track_batch = ((x[:,0]-track_coord[[2],:].T)**2).argmin(0)
    idcs_track_batch = idx_track_batch[:, None] + torch.arange(H_curve)
    curvs = track_coord[4,idcs_track_batch].float()
    return curvs

In [5]:
class FrenetKinBicycleDx(nn.Module):
    def __init__(self, track_coordinates, params, dev):
        super().__init__()
        
        self.params = params

        # states: sigma, d, phi, v (4) + sigma_0, sigma_diff (2) + d_pen (1) + v_ub (1)
        self.n_state = 4+2+1+1
        print('Number of states:', self.n_state)
        
        self.n_ctrl = 2 # control: a, delta

        self.track_coordinates = track_coordinates.to(dev)

        # everything to calculate curvature
        self.track_sigma = self.track_coordinates[2,:]
        self.track_curv = self.track_coordinates[4,:]

        self.track_curv_shift = torch.empty(self.track_curv.size()).to(dev)
        self.track_curv_shift[1:] = self.track_curv[0:-1]
        self.track_curv_shift[0] = self.track_curv[-1]
        self.track_curv_diff = self.track_curv - self.track_curv_shift

        self.mask = torch.where(torch.absolute(self.track_curv_diff) < 0.1, False, True)
        self.sigma_f = self.track_sigma[self.mask]
        self.curv_f = self.track_curv_diff[self.mask]
     
        self.l_r = params[0]
        self.l_f = params[1]
        
        self.track_width = params[2]
        
        self.delta_threshold_rad = np.pi
        self.dt = params[3]

        self.smooth_curve = params[4]
        
        self.v_max = params[5]
        
        self.delta_max = params[6]
                
        
        
    def curv(self, sigma):

        num_sf = self.sigma_f.size()
        num_s = sigma.size()

        sigma_f_mat = self.sigma_f.repeat(num_s[0],1)

        sigma_shifted = sigma.reshape(-1,1) - sigma_f_mat
        curv_unscaled = torch.sigmoid(self.smooth_curve*sigma_shifted)
        curv = (curv_unscaled@(self.curv_f.reshape(-1,1))).type(torch.float)

        return curv.reshape(-1)
    
    
    def penalty_d(self, d, factor=1000.):  
        overshoot_pos = (d - 0.34*self.track_width).clamp(min=0)
        overshoot_neg = (-d - 0.34*self.track_width).clamp(min=0)
        penalty_pos = torch.exp(overshoot_pos) - 1
        penalty_neg = torch.exp(overshoot_neg) - 1 
        return factor*(penalty_pos + penalty_neg)
    
    def penalty_v(self, v, factor=1000.):          
        overshoot_pos = (v - self.v_max).clamp(min=0)
        overshoot_neg = (-v + 0.001).clamp(min=0)
        penalty_pos = torch.exp(overshoot_pos) - 1
        penalty_neg = torch.exp(overshoot_neg) - 1 
        return factor*(penalty_pos + penalty_neg)
    
    def penalty_delta(self, delta, factor=1000.):          
        overshoot_pos = (delta - self.delta_max).clamp(min=0)
        overshoot_neg = (-delta - self.delta_max).clamp(min=0)
        penalty_pos = torch.exp(overshoot_pos) - 1
        penalty_neg = torch.exp(overshoot_neg) - 1 
        return factor*(penalty_pos + penalty_neg)
    
    def forward(self, state, u):
        squeeze = state.ndimension() == 1
        if squeeze:
            state = state.unsqueeze(0)
            u = u.unsqueeze(0)
        if state.is_cuda and not self.params.is_cuda:
            self.params = self.params.cuda()


        a, delta = torch.unbind(u, dim=1)

        #sigma, d, phi, v, sigma_0, sigma_diff, d_pen, v_ub = torch.unbind(state, dim=1)
        
        sigma, d, phi, v, sigma_0, sigma_diff = torch.unbind(state, dim=1)
        beta = torch.atan(self.l_r/(self.l_r+self.l_f)*torch.tan(delta))       
        k = self.curv(sigma)

        dsigma = v*(torch.cos(phi+beta)/(1.-k*d))
        dd = v*torch.sin(phi+beta)
        dphi = v/self.l_f*torch.sin(beta)-k*v*(torch.cos(phi+beta)/(1-k*d))       
        
        dv = a      

        sigma = sigma + self.dt * dsigma
        d = d + self.dt * dd
        phi = phi + self.dt * dphi
        v = v + self.dt * dv 
        
        sigma_diff = sigma - sigma_0 
                
        #d_pen = self.penalty_d(d)        
        #v_ub = self.penalty_v(v)

        #state = torch.stack((sigma, d, phi, v, sigma_0, sigma_diff, d_pen, v_ub), 1)
        
        state = torch.stack((sigma, d, phi, v, sigma_0, sigma_diff), 1)

        return state

In [6]:
def solve_casadi(q_np,p_np,x0_np,dx,du):
    
    x_curr_opt, u_curr_opt = control.mpc_casadi(q_np,p_np,x0_np,dx,du)

    sigzero_curr_opt = np.expand_dims(x_curr_opt[[0],0].repeat(mpc_T+1), 1)
    sigsiff_curr_opt = x_curr_opt[:,[0]]-x_curr_opt[0,0]

    x_curr_opt_plus = np.concatenate((
        x_curr_opt,sigzero_curr_opt,sigsiff_curr_opt), axis = 1)

    x_star = x_curr_opt_plus[1:]
    u_star = u_curr_opt
    
    return x_star, u_star


def formulate_QP(x0_cp, x_star, u_star, true_dx):

    J_x = vmap(jacfwd(true_dx.forward, 0))(torch.tensor(x_star), torch.tensor(u_star)).squeeze()
    J_u = vmap(jacfwd(true_dx.forward, 1))(torch.tensor(x_star), torch.tensor(u_star)).squeeze()

    x_cp = cp.Variable((mpc_T, dx + 2))
    u_cp = cp.Variable((mpc_T, du))

    Q_sigma_diff_sqrt = cp.Parameter((mpc_T, mpc_T))
    Q_d_sqrt = cp.Parameter((mpc_T, mpc_T))

    p_sigma_diff = cp.Parameter((mpc_T))
    p_d = cp.Parameter((mpc_T))

    objective = cp.Minimize(
        p_sigma_diff@x_cp[:,5] \
        + p_d@x_cp[:,1] \
        + cp.sum_squares(Q_sigma_diff_sqrt @ x_cp[:,5])
        + cp.sum_squares(Q_d_sqrt @ x_cp[:,1])
    )

    constraints = []

    constraints += [x_cp[0,:4] == x0_cp]
    constraints += [x_cp[:, 4] == x0_cp[0]]
    constraints += [x_cp[0, 5] == 0]

    constraints += [u_cp[:,1] >= -delta_max, u_cp[:,1] <= delta_max]
    constraints += [u_cp[:,0] >= -a_max, u_cp[:,0] <= a_max]

    constraints += [x_cp[:,3] >= 0, x_cp[:,3] <= v_max]
    constraints += [x_cp[:,1] >= -track_width*0.35, x_cp[:,1] <= track_width*0.35]
    
    constraints += [x_cp >= -999., x_cp <= 999.]

    for tt in range(mpc_T-1):
        Jdx = (x_cp - x_star)[tt]@J_x[tt].T
        Jdu = (u_cp - u_star)[tt]@J_u[tt].T
        constraints += [x_cp[tt+1] == x_star[tt+1] + Jdx + Jdu]

    problem = cp.Problem(objective, constraints)
    assert problem.is_dpp() #the problem should be DPP for us to backpropagate (that is why Q_sqrt)

    cvxpylayer = CvxpyLayer(
        problem, 
        parameters=[Q_sigma_diff_sqrt, Q_d_sqrt, p_sigma_diff, p_d], 
        variables=[x_cp, u_cp])
    
    return cvxpylayer


def q_and_p(mpc_T, q_p_pred):
    # Cost order: 
    # sigma_diff, d, phi, v, a, delta
    q = torch.zeros((6,mpc_T))
    p = torch.zeros((6,mpc_T))

    q[0,:] = q_p_pred[0,:,0].clamp(0.1)
    q[1,:] = 5 + q_p_pred[0,:,1].clamp(0)
    q[2,:] = 5

    p[0,:] = -10 + q_p_pred[0,:,2]
    p[1,:] = q_p_pred[0,:,3]
    
    return q, p

In [7]:
k_curve = 100.

dt = 0.04

mpc_T = 30
mpc_H = 60
n_batch = 32

l_r = 0.2
l_f = 0.2

v_max = 2.5

delta_max = 0.6

a_max = 2.

track_density = 300
track_width = 0.5
t_track = 0.3
init_track = [0,0,0]

max_p = 100 

BS = 1

params = torch.tensor([l_r, l_f, track_width, dt, k_curve, v_max, delta_max, a_max, mpc_T])

In [8]:
gen = simple_track_generator.trackGenerator(track_density,track_width)
track_name = 'DEMO_TRACK'

track_function = {
    'DEMO_TRACK'    : track_functions.demo_track,
    'HARD_TRACK'    : track_functions.hard_track,
    'LONG_TRACK'    : track_functions.long_track,
    'LUCERNE_TRACK' : track_functions.lucerne_track,
    'BERN_TRACK'    : track_functions.bern_track,
    'INFINITY_TRACK': track_functions.infinity_track,
    'SNAIL_TRACK'   : track_functions.snail_track
}.get(track_name, track_functions.demo_track)

In [9]:
track_function(gen, t_track, init_track)
gen.populatePointsAndArcLength()
gen.centerTrack()
track_coord = torch.from_numpy(np.vstack(
    [gen.xCoords, 
     gen.yCoords, 
     gen.arcLength, 
     gen.tangentAngle, 
     gen.curvature]))

In [10]:
true_dx = FrenetKinBicycleDx(track_coord, params, 'cpu')

Number of states: 8


In [11]:
x0 = torch.tensor([0.0, 0.1, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0])
u0 = torch.tensor([0.0, 0.0])

In [12]:
dx=4
du=2

In [13]:
control = CasadiControl(track_coord, params)

9


In [14]:
model = SimpleNN(mpc_H, mpc_T, 4, max_p)
opt = torch.optim.Adam(model.parameters(), lr=0.002)


In [46]:
x0 = sample_init(BS)
print(x0)

tensor([[ 0.4440,  0.0450, -0.0450,  0.0020]])


In [47]:
T_approx = 20

x_ap = torch.zeros((BS,T_approx,dx))
u_ap = torch.zeros((BS,T_approx,du))

for t in range(T_approx):

    curv = get_curve_hor_from_x(x0, track_coord, mpc_H)
    inp = torch.hstack((x0[:,1:], curv))
    q_p_pred = model(inp)
    q, p = q_and_p(mpc_T, q_p_pred)

    x0_np = x0.detach().numpy().squeeze()
    q_np = q.detach().numpy()
    p_np = p.detach().numpy()
    x_star, u_star = solve_casadi(q_np,p_np,x0_np,dx,du)

    cvxpylayer = formulate_QP(x0_np, x_star, u_star, true_dx)

    x_approx, u_approx = cvxpylayer(
        q[0,:].sqrt().diag(), 
        q[1,:].sqrt().diag(), 
        p[0,:], 
        p[1,:])

    x0 = x_approx[1,:4].unsqueeze(0)
    
    x_ap[:,t] = x0
    u_ap[:,t] = u_approx[1].unsqueeze(0)

In [48]:
x_ap

tensor([[[ 0.4441,  0.0461, -0.0396,  0.0820],
         [ 0.4474,  0.0481, -0.0286,  0.1619],
         [ 0.4542,  0.0514, -0.0115,  0.2420],
         [ 0.4646,  0.0561,  0.0127,  0.3219],
         [ 0.4789,  0.0630,  0.0460,  0.4019],
         [ 0.4976,  0.0729,  0.0914,  0.4819],
         [ 0.5214,  0.0870,  0.1510,  0.5619],
         [ 0.5459,  0.0938,  0.1663,  0.6419],
         [ 0.5713,  0.0897,  0.1247,  0.7219],
         [ 0.6055,  0.1084,  0.1977,  0.8016],
         [ 0.6372,  0.1043,  0.1459,  0.8816],
         [ 0.6721,  0.0982,  0.0891,  0.9616],
         [ 0.7096,  0.0892,  0.0269,  1.0416],
         [ 0.7495,  0.0769, -0.0405,  1.1216],
         [ 0.8084,  0.0965,  0.0734,  1.2016],
         [ 0.8587,  0.0950,  0.0500,  1.4189],
         [ 0.9345,  0.1398,  0.2637,  1.4989],
         [ 0.9941,  0.1355,  0.1702,  1.5789],
         [ 1.0182,  0.1251,  0.3892,  1.6589],
         [ 1.1163,  0.1341,  0.0554,  1.7389]]], grad_fn=<CopySlices>)

In [51]:
x_ap.sum().backward()

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


ValueError: Can't apply Jacobian with a quadratic objective.

In [None]:
from matplotlib import cm


cart_3 = np.zeros((120, 3))

for i in range(120):
    cart_3[i,:2] = utils.frenet_to_cartesian(torch.tensor(x[i]), track_coord).numpy()
    cart_3[i,2] = x[i,3]

fig, ax = plt.subplots(1,1, figsize=(8,4), dpi=150)
gen.plotPoints(ax)


custom_cmap = plt.get_cmap('cubehelix').reversed()
sct = ax.scatter(cart_3[:,0], cart_3[:,1], 
                 c=cart_3[:,2], cmap=custom_cmap, s=4)

cbar = plt.colorbar(sct)
cbar.set_label('Velocity') 
#sct = ax.scatter(cart[:,0].numpy(), cart[:,1].numpy(), s=4, color='red')
#sct = ax.scatter(cart_2[:,0].numpy(), cart_2[:,1].numpy(), s=4, color='green')
#sct = ax.scatter(cart_3[:,0].detach().numpy(), cart_3[:,1].detach().numpy(), s=4, color='blue')

plt.show()

In [None]:
from matplotlib import cm

cart_3 = torch.zeros((n_sim, 3))

x = x0_1
for i in range(n_sim):
    x_aux = true_dx.forward(x, a_delta_learned[-190].detach()[i])
    x = x_aux.clone().squeeze()
    cart_3[i,:2] = utils.frenet_to_cartesian(x, track_coord)
    cart_3[i,2] = x[3]

fig, ax = plt.subplots(1,1, figsize=(8,4), dpi=150)
gen.plotPoints(ax)

custom_cmap = plt.get_cmap('cubehelix').reversed()
sct = ax.scatter(cart_3[:,0].detach().numpy(), cart_3[:,1].detach().numpy(), 
                 c=cart_3[:,2], cmap=custom_cmap, s=4)
cbar = plt.colorbar(sct)
cbar.set_label('Velocity') 
#sct = ax.scatter(cart[:,0].numpy(), cart[:,1].numpy(), s=4, color='red')
#sct = ax.scatter(cart_2[:,0].numpy(), cart_2[:,1].numpy(), s=4, color='green')
#sct = ax.scatter(cart_3[:,0].detach().numpy(), cart_3[:,1].detach().numpy(), s=4, color='blue')

plt.show()

In [None]:
sigma = cp.Variable((BS, mpc_T))
d = cp.Variable((BS, mpc_T))
phi = cp.Variable((BS, mpc_T))
v = cp.Variable((BS, mpc_T))
sigma_0 = cp.Variable((BS, mpc_T))
sigma_diff = cp.Variable((BS, mpc_T))
d_pen = cp.Variable((BS, mpc_T))
v_ub = cp.Variable((BS, mpc_T))
a = cp.Variable((BS, mpc_T))
delta = cp.Variable((BS, mpc_T))

In [None]:
p_d = cp.Parameter((BS, mpc_T))
Q_v = cp.Parameter((BS, mpc_T))
p_v = torch.ones((BS, mpc_T))
Q_d = torch.ones((BS, mpc_T))

In [None]:
objective = cp.Minimize(
    cp.sum(
        cp.sum(Q_d @ cp.square(d).T, 1) +
        cp.sum(p_d @ d.T, 1) +
        cp.sum(Q_v @ cp.square(v).T, 1) +
        cp.sum(p_v @ v.T, 1)
    )
)

In [None]:
A[0][:4,:4]

In [None]:
constraints = []

constraints += [delta >= -delta_max, delta <= delta_max]

constraints += [v >= 0, v <= v_max]
constraints += [d >= -track_width*0.4, d <= track_width*0.4]

In [None]:
for t in range(mpc_T-1):
    state_t = cp.hstack([sigma[t], d[t], phi[t], v[t], sigma_0[t], sigma_diff[t], d_pen[t], v_ub[t]])
    state_t1 = cp.hstack([sigma[t+1], d[t+1], phi[t+1], v[t+1], sigma_0[t+1], sigma_diff[t+1], d_pen[t+1], v_ub[t+1]])
    control_t = cp.hstack([a[t], delta[t]])
    constraints += [state_t1 == A[t] @ state_t + B[t] @ control_t]

In [None]:
A.shape

In [None]:
n_sim = 100

In [None]:
a_delta_learned = []

In [None]:
p_d

In [None]:
for s in range(2,n_sim+1):
    
    for it in range(400):

        loss = torch.zeros((s,))
        x = x0_1
        for i in range(s):
            x_aux = true_dx.forward(x, a_delta[i])
            x = x_aux.clone().squeeze()
            loss[i] = true_dx.penalty_d(x[1]) \
            + true_dx.penalty_v(x[3]) \
            + true_dx.penalty_a(a_delta[i,0]) \
            + true_dx.penalty_delta(a_delta[i,1])
            #cart_3[i] = x[[0,1,2]]

        loss_total = loss.sum() - x[0]
        opt.zero_grad()
        loss_total.backward()
        opt.step()

        if it%50==0:
            with torch.no_grad():
                print(s, x[0].item(), x[1].item(), x[3].item())
                a_delta_learned.append(a_delta.detach().clone())

In [None]:
from matplotlib import cm

cart_3 = torch.zeros((n_sim, 3))

x = x0_1
for i in range(n_sim):
    x_aux = true_dx.forward(x, a_delta_learned[-190].detach()[i])
    x = x_aux.clone().squeeze()
    cart_3[i,:2] = utils.frenet_to_cartesian(x, track_coord)
    cart_3[i,2] = x[3]
    #cart_3[i] = x[[0,1,2]]

fig, ax = plt.subplots(1,1, figsize=(8,4), dpi=150)
gen.plotPoints(ax)

custom_cmap = plt.get_cmap('cubehelix').reversed()
sct = ax.scatter(cart_3[:,0].detach().numpy(), cart_3[:,1].detach().numpy(), 
                 c=cart_3[:,2], cmap=custom_cmap, s=4)
cbar = plt.colorbar(sct)
cbar.set_label('Velocity') 
#sct = ax.scatter(cart[:,0].numpy(), cart[:,1].numpy(), s=4, color='red')
#sct = ax.scatter(cart_2[:,0].numpy(), cart_2[:,1].numpy(), s=4, color='green')
#sct = ax.scatter(cart_3[:,0].detach().numpy(), cart_3[:,1].detach().numpy(), s=4, color='blue')

plt.show()