In [1]:
import os
import sys
# Make the local DDN checkout importable (it provides `ddn.pytorch.node` below).
# Set the DDN_ROOT environment variable to point elsewhere instead of editing
# this hard-coded default.
DDN_ROOT = os.environ.get("DDN_ROOT", "/Users/shashank/University/rrc/ddn/")
if DDN_ROOT not in sys.path:  # keep the cell idempotent when it is re-run
    sys.path.append(DDN_ROOT)
import warnings
# NOTE(review): this silences *every* warning (including PyTorch deprecation and
# performance warnings); consider narrowing the filter.
warnings.filterwarnings('ignore')

import torch
import numpy as np
import scipy.special
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from scipy.linalg import block_diag
from torch.utils.data import Dataset, DataLoader
from bernstein_coeff_order10_arbitinterval import bernstein_coeff_order10_new
from ddn.pytorch.node import LinEqConstDeclarativeNode, AbstractDeclarativeNode, EqConstDeclarativeNode

#### Device selection (CUDA if available, otherwise CPU)

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Using %s device' % device)

Using cpu device


In [3]:
class OPTNode(AbstractDeclarativeNode):
    """Declarative node that plans a planar trajectory with Bernstein polynomials.

    Heading ``psi`` and speed ``v`` are each represented as ``P @ c`` where ``P``
    is the order-10 Bernstein basis from ``bernstein_coeff_order10_new`` sampled
    at ``num`` points on ``[0, t_fin]`` (``nvar`` coefficients per profile).
    ``solve`` runs ``maxiter`` alternating, augmented-Lagrangian style updates
    over three quadratic sub-problems:

    * ``compute_w_psi`` - auxiliary ``wc ~ cos(psi)``, ``ws ~ sin(psi)`` subject
      to goal / mid-point position penalties;
    * ``compute_psi`` - ``psi`` fitted to ``atan2(ws, wc)`` plus smoothness and
      boundary (``psi``, ``psidot``) penalties;
    * ``compute_v`` - ``v`` subject to goal / mid-point penalties, smoothness and
      the initial-speed constraint.

    ``objective`` evaluates the matching penalty cost so DDN can differentiate
    the solution with respect to the inputs.

    Input layout per batch row:
        fixed_params:    (x_init, y_init, v_init, psi_init, psidot_init)
        variable_params: (x_fin, y_fin, v_fin, psi_fin, psidot_fin,
                          x_mid_0, y_mid_0, x_mid_1, y_mid_1, ...)
    Output per batch row: ``hstack(c_psi, c_v)`` of length ``2 * nvar``.
    ``v_fin`` is part of the layout but is not used by the solver or objective.
    """

    def __init__(self, rho_eq=1.0, rho_goal=1.0, rho_lane=1.0, rho_nonhol=1.0, rho_w_psi=1.0, rho_psi=1.0, maxiter=500, weight_smoothness=1.0, weight_smoothness_psi=1.0, t_fin=8.0, num=100, rho_mid=0.01):
        super().__init__()
        # Penalty / multiplier weights.  rho_lane and rho_nonhol are stored but
        # not used by any method of this class.
        self.rho_eq = rho_eq
        self.rho_goal = rho_goal
        self.rho_lane = rho_lane
        self.rho_nonhol = rho_nonhol
        self.rho_w_psi = rho_w_psi
        self.rho_psi = rho_psi
        self.maxiter = maxiter
        self.weight_smoothness = weight_smoothness
        self.weight_smoothness_psi = weight_smoothness_psi

        self.t_fin = t_fin
        self.num = num
        # Step of the cumulative-sum position model.
        # NOTE(review): the sample grid below is spaced t_fin / (num - 1) while
        # this step is t_fin / num; kept unchanged to preserve existing results.
        self.t = self.t_fin / self.num

        # Build the basis on the CPU: the matrices below are assembled with
        # NumPy (vstack / dot), which cannot read CUDA tensors.  Everything is
        # moved to `device` once converted to torch further down.
        tot_time = torch.linspace(0.0, self.t_fin, self.num)
        tot_time_copy = tot_time.reshape(self.num, 1)
        self.P, self.Pdot, self.Pddot = bernstein_coeff_order10_new(10, tot_time_copy[0], tot_time_copy[-1], tot_time_copy)

        self.nvar = np.shape(self.P)[1]
        # Boundary rows paired with b_eq_psi = (psi_init, psidot_init, psi_fin, psidot_fin).
        self.A_eq_psi = np.vstack((self.P[0], self.Pdot[0], self.P[-1], self.Pdot[-1]))

        # Quadratic smoothness costs built from Pddot (presumably the
        # second-derivative basis).
        self.cost_smoothness = self.weight_smoothness * np.dot(self.Pddot.T, self.Pddot)
        self.cost_smoothness_v = self.weight_smoothness * np.dot(self.Pddot.T, self.Pddot)
        self.cost_smoothness_psi = self.weight_smoothness_psi * np.dot(self.Pddot.T, self.Pddot)
        self.lincost_smoothness_psi = np.zeros(self.nvar)

        self.P = self._as_double_on_device(self.P)
        self.Pdot = self._as_double_on_device(self.Pdot)
        self.Pddot = self._as_double_on_device(self.Pddot)
        self.A_eq_psi = self._as_double_on_device(self.A_eq_psi)
        self.cost_smoothness = self._as_double_on_device(self.cost_smoothness)
        self.cost_smoothness_v = self._as_double_on_device(self.cost_smoothness_v)
        self.cost_smoothness_psi = self._as_double_on_device(self.cost_smoothness_psi)
        self.lincost_smoothness_psi = self._as_double_on_device(self.lincost_smoothness_psi)

        self.rho_mid = rho_mid
        # Sample indices of the three way-point constraints (quarter, half, three-quarter).
        self.mid_idx = torch.tensor([int(self.num/4), int(self.num/2), int(3*self.num/4)], device=device)

        self.A_w = self.P
        self.A_w_psi = self.P
        self.A_psi = self.P

    @staticmethod
    def _as_double_on_device(value):
        """Return `value` as a float64 tensor on `device` (no copy if it already is one).

        Unlike ``torch.tensor([...])`` this keeps tensors on their device and
        preserves their autograd history.
        """
        return torch.as_tensor(value, dtype=torch.double, device=device)

    def compute_w_psi(self, x_init, y_init, x_fin, y_fin, x_mid, y_mid, psi, v, lamda_wc, lamda_ws):
        """Update the auxiliary heading variables wc ~ cos(psi) and ws ~ sin(psi).

        Positions are modelled as x = x_init + cumsum(v * wc * t) (likewise y with
        ws), which is linear in the coefficients of wc / ws, so each sub-problem is
        an unconstrained quadratic solved in closed form.

        Returns:
            wc, ws: sampled auxiliary profiles (P @ c).
            c_wc_psi, c_ws_psi: their Bernstein coefficients.
        """
        b_wc_psi = torch.cos(psi).double()
        b_ws_psi = torch.sin(psi).double()

        # Row k maps coefficients to sum_{j<=k} v_j * t * w_j, the displacement
        # after k + 1 steps.  x and y share this map because the heading enters
        # only through the separate unknowns wc and ws.
        A_disp = torch.cumsum(self.P * (v * self.t).unsqueeze(1), dim=0)[0:self.num - 1]
        A_x = A_disp
        A_y = A_disp

        A_x_goal = A_x[-1].reshape(1, self.nvar)
        b_x_goal = self._as_double_on_device(x_fin - x_init).reshape(1)

        A_y_goal = A_y[-1].reshape(1, self.nvar)
        b_y_goal = self._as_double_on_device(y_fin - y_init).reshape(1)

        A_x_mid = A_x[self.mid_idx]
        A_y_mid = A_y[self.mid_idx]

        b_x_mid = x_mid - x_init
        b_y_mid = y_mid - y_init

        obj_x_goal = self.rho_goal * torch.matmul(A_x_goal.T, A_x_goal)
        linterm_augment_x_goal = -self.rho_goal * torch.matmul(A_x_goal.T, b_x_goal)

        obj_y_goal = self.rho_goal * torch.matmul(A_y_goal.T, A_y_goal)
        linterm_augment_y_goal = -self.rho_goal * torch.matmul(A_y_goal.T, b_y_goal)

        obj_x_mid = self.rho_mid * torch.matmul(A_x_mid.T, A_x_mid)
        linterm_augment_x_mid = -self.rho_mid * torch.matmul(A_x_mid.T, b_x_mid)

        obj_y_mid = self.rho_mid * torch.matmul(A_y_mid.T, A_y_mid)
        linterm_augment_y_mid = -self.rho_mid * torch.matmul(A_y_mid.T, b_y_mid)

        obj_wc_psi = self.rho_w_psi * torch.matmul(self.A_w_psi.T, self.A_w_psi)
        linterm_augment_wc_psi = -self.rho_w_psi * torch.matmul(self.A_w_psi.T, b_wc_psi)

        obj_ws_psi = self.rho_w_psi * torch.matmul(self.A_w_psi.T, self.A_w_psi)
        linterm_augment_ws_psi = -self.rho_w_psi * torch.matmul(self.A_w_psi.T, b_ws_psi)

        cost_wc = obj_wc_psi + obj_x_goal + obj_x_mid
        lincost_wc = -lamda_wc + linterm_augment_x_goal + linterm_augment_wc_psi + linterm_augment_x_mid

        cost_ws = obj_y_goal + obj_ws_psi + obj_y_mid
        lincost_ws = -lamda_ws + linterm_augment_y_goal + linterm_augment_ws_psi + linterm_augment_y_mid

        # Stationarity of 0.5 c^T Q c + q^T c:  Q c = -q.
        c_wc_psi = torch.linalg.solve(-cost_wc, lincost_wc)
        c_ws_psi = torch.linalg.solve(-cost_ws, lincost_ws)

        wc = torch.matmul(self.P, c_wc_psi)
        ws = torch.matmul(self.P, c_ws_psi)

        return wc, ws, c_wc_psi, c_ws_psi

    def compute_psi(self, wc, ws, b_eq_psi, lamda_psi):
        """Fit psi to atan2(ws, wc) with smoothness and boundary penalties.

        Args:
            wc, ws: current auxiliary profiles from compute_w_psi.
            b_eq_psi: boundary targets (psi_init, psidot_init, psi_fin, psidot_fin).
            lamda_psi: current multiplier estimate.

        Returns:
            psi (sampled), c_psi (coefficients), norm of the fit residual,
            norm of the boundary residual, and the updated lamda_psi.
        """
        # NOTE(review): atan2 returns values in (-pi, pi], so a heading that
        # wraps past +/-pi would appear as a jump in this target.
        b_psi = torch.atan2(ws, wc).double()
        obj_psi = self.rho_psi * torch.matmul(self.A_psi.T, self.A_psi)
        linterm_augment_psi = -self.rho_psi * torch.matmul(self.A_psi.T, b_psi)

        cost_psi = self.cost_smoothness_psi + obj_psi + self.rho_eq * torch.matmul(self.A_eq_psi.T, self.A_eq_psi)
        lincost_psi = -lamda_psi + linterm_augment_psi - self.rho_eq * torch.matmul(self.A_eq_psi.T, b_eq_psi)
        sol = torch.linalg.solve(-cost_psi, lincost_psi)

        c_psi = sol[0:self.nvar]
        psi = torch.matmul(self.P, c_psi)

        res_psi = torch.matmul(self.A_psi, c_psi) - b_psi
        res_eq_psi = torch.matmul(self.A_eq_psi, c_psi) - b_eq_psi
        lamda_psi = lamda_psi - self.rho_psi * torch.matmul(self.A_psi.T, res_psi) - self.rho_eq * torch.matmul(self.A_eq_psi.T, res_eq_psi)
        return psi, c_psi, torch.linalg.norm(res_psi), torch.linalg.norm(res_eq_psi), lamda_psi

    def compute_v(self, v_init, x_init, x_fin, x_mid, y_init, y_fin, y_mid, psi, lamda_v):
        """Update the speed profile for a fixed heading psi.

        Positions are x = x_init + cumsum(v * cos(psi) * t) (likewise y with sin),
        which is linear in the speed coefficients for fixed psi.

        Returns:
            sol (speed coefficients c_v), the updated lamda_v, and v = P @ sol.
        """
        temp_x = torch.cumsum(self.P * (torch.cos(psi) * self.t).unsqueeze(1), dim=0)
        temp_y = torch.cumsum(self.P * (torch.sin(psi) * self.t).unsqueeze(1), dim=0)

        # Row k -> displacement after k + 1 steps (same convention as compute_w_psi).
        A_x = temp_x[0:self.num - 1]
        A_y = temp_y[0:self.num - 1]

        A_x_goal = A_x[-1].reshape(1, self.nvar)
        b_x_goal = self._as_double_on_device(x_fin - x_init).reshape(1)

        A_y_goal = A_y[-1].reshape(1, self.nvar)
        b_y_goal = self._as_double_on_device(y_fin - y_init).reshape(1)

        A_x_mid = A_x[self.mid_idx]
        A_y_mid = A_y[self.mid_idx]

        b_x_mid = x_mid - x_init
        b_y_mid = y_mid - y_init

        A_vel_init = self.P[0].reshape(1, self.nvar)
        b_vel_init = self._as_double_on_device(v_init).reshape(1)

        obj_x_goal = self.rho_goal * torch.matmul(A_x_goal.T, A_x_goal)
        linterm_augment_x_goal = -self.rho_goal * torch.matmul(A_x_goal.T, b_x_goal)

        obj_y_goal = self.rho_goal * torch.matmul(A_y_goal.T, A_y_goal)
        linterm_augment_y_goal = -self.rho_goal * torch.matmul(A_y_goal.T, b_y_goal)

        obj_x_mid = self.rho_mid * torch.matmul(A_x_mid.T, A_x_mid)
        linterm_augment_x_mid = -self.rho_mid * torch.matmul(A_x_mid.T, b_x_mid)

        obj_y_mid = self.rho_mid * torch.matmul(A_y_mid.T, A_y_mid)
        linterm_augment_y_mid = -self.rho_mid * torch.matmul(A_y_mid.T, b_y_mid)

        obj_v_init = self.rho_eq * torch.matmul(A_vel_init.T, A_vel_init)
        linterm_augment_v_init = -self.rho_eq * torch.matmul(A_vel_init.T, b_vel_init)

        cost = obj_x_goal + obj_y_goal + self.cost_smoothness_v + obj_x_mid + obj_y_mid + obj_v_init
        lincost = -lamda_v + linterm_augment_x_goal + linterm_augment_y_goal + linterm_augment_x_mid + linterm_augment_y_mid + linterm_augment_v_init

        sol = torch.linalg.solve(-cost, lincost)

        v = torch.matmul(self.P, sol)

        res_v_init = torch.matmul(A_vel_init, sol) - b_vel_init
        lamda_v = lamda_v - self.rho_eq * torch.matmul(A_vel_init.T, res_v_init)

        return sol, lamda_v, v

    def optimize(self, fixed_params, variable_params):
        """Solve a single (unbatched) problem instance.

        Args:
            fixed_params: 1-D tensor (x_init, y_init, v_init, psi_init, psidot_init).
            variable_params: 1-D tensor (x_fin, y_fin, v_fin, psi_fin, psidot_fin,
                x_mid_0, y_mid_0, ...); the number of way-points must match mid_idx.

        Returns:
            hstack(c_psi, c_v) after `maxiter` iterations.

        Raises:
            ValueError: if maxiter < 1 (no iterate would be produced).
        """
        if self.maxiter < 1:
            raise ValueError("maxiter must be at least 1, got {}".format(self.maxiter))

        x_init, y_init, v_init, psi_init, psidot_init = fixed_params
        x_fin, y_fin, v_fin, psi_fin, psidot_fin = variable_params[:5]
        x_mid = variable_params[5::2]
        y_mid = variable_params[6::2]

        # Initial guess: constant speed and heading.
        v = v_init * torch.ones(self.num, dtype=torch.double, device=device)
        psi = psi_init * torch.ones(self.num, dtype=torch.double, device=device)

        # Per-iteration residual norms (diagnostics only; not returned).
        res_psi = torch.ones(self.maxiter, dtype=torch.double, device=device)
        res_w = torch.ones(self.maxiter, dtype=torch.double, device=device)
        res_eq_psi = torch.ones(self.maxiter, dtype=torch.double, device=device)
        b_eq_psi = torch.hstack((psi_init, psidot_init, psi_fin, psidot_fin))

        lamda_wc = torch.zeros(self.nvar, dtype=torch.double, device=device)
        lamda_ws = torch.zeros(self.nvar, dtype=torch.double, device=device)
        lamda_psi = torch.zeros(self.nvar, dtype=torch.double, device=device)
        lamda_v = torch.zeros(self.nvar, dtype=torch.double, device=device)

        for i in range(self.maxiter):
            wc, ws, c_wc_psi, c_ws_psi = self.compute_w_psi(x_init, y_init, x_fin, y_fin, x_mid, y_mid, psi, v, lamda_wc, lamda_ws)
            psi, c_psi, res_psi[i], res_eq_psi[i], lamda_psi = self.compute_psi(wc, ws, b_eq_psi, lamda_psi)
            c_v, lamda_v, v = self.compute_v(v_init, x_init, x_fin, x_mid, y_init, y_fin, y_mid, psi, lamda_v)

            # Multiplier update for the coupling wc = cos(psi), ws = sin(psi).
            res_wc = wc - torch.cos(psi)
            res_ws = ws - torch.sin(psi)

            lamda_wc = lamda_wc - self.rho_w_psi * torch.matmul(self.A_w.T, res_wc)
            lamda_ws = lamda_ws - self.rho_w_psi * torch.matmul(self.A_w.T, res_ws)

            res_w[i] = torch.linalg.norm(torch.hstack((res_wc, res_ws)))

        primal_sol = torch.hstack((c_psi, c_v))
        return primal_sol

    def solve(self, fixed_params, variable_params):
        """Solve each batch row independently.

        Args:
            fixed_params: (batch, 5) tensor.
            variable_params: (batch, 5 + 2 * len(mid_idx)) tensor.

        Returns:
            (y, None) where y is (batch, 2 * nvar) = hstack(c_psi, c_v) per row.
        """
        batch_size, _ = fixed_params.size()
        y = torch.zeros(batch_size, 2 * self.nvar, dtype=torch.double, device=device)
        for i in range(batch_size):
            y[i, :] = self.optimize(fixed_params[i], variable_params[i])
        return y, None

    def objective(self, fixed_params, variable_params, y):
        """Penalty objective that `solve` (approximately) minimises, used by DDN.

        DDN's implicit gradient assumes the gradient of this objective w.r.t. `y`
        vanishes at the solution, so every term here mirrors a term in the
        solver's quadratic sub-problems.

        Args:
            fixed_params: (batch, 5) tensor.
            variable_params: (batch, 5 + 2 * len(mid_idx)) tensor.
            y: (batch, 2 * nvar) tensor hstack(c_psi, c_v).

        Returns:
            (batch, 1) tensor of costs.
        """
        x_init, y_init, v_init, psi_init, psidot_init = torch.chunk(fixed_params, 5, dim=1)
        x_fin, y_fin, v_fin, psi_fin, psidot_fin = torch.chunk(variable_params[:, :5], 5, dim=1)
        x_mid = variable_params[:, 5::2]
        y_mid = variable_params[:, 6::2]

        c_psi = y[:, :self.nvar]
        c_v = y[:, self.nvar:]

        # Sampled profiles, shape (num, batch).
        v_new = torch.matmul(self.P, c_v.T)
        psi_new = torch.matmul(self.P, c_psi.T)
        psidot_new = torch.matmul(self.Pdot, c_psi.T)

        # x_temp[:, k] is the position after k + 1 integration steps, (batch, num).
        x_temp = x_init + torch.cumsum(v_new * torch.cos(psi_new) * self.t, dim=0).T
        y_temp = y_init + torch.cumsum(v_new * torch.sin(psi_new) * self.t, dim=0).T

        # Positions starting from the initial point, (batch, num).
        x_pos = torch.hstack((x_init, x_temp[:, :-1]))
        y_pos = torch.hstack((y_init, y_temp[:, :-1]))

        cost_final_pos = 0.5 * self.rho_goal * ((x_pos[:, -1:] - x_fin) ** 2 + (y_pos[:, -1:] - y_fin) ** 2)
        cost_psi_term = 0.5 * self.rho_eq * ((psi_new[:1, :].T - psi_init) ** 2 + (psidot_new[:1, :].T - psidot_init) ** 2 + (psi_new[-1:, :].T - psi_fin) ** 2 + (psidot_new[-1:, :].T - psidot_fin) ** 2)
        cost_v_term = 0.5 * self.rho_eq * (v_new[:1, :].T - v_init) ** 2
        # The solver penalises A_x[mid_idx], i.e. the position after mid_idx + 1
        # steps == x_temp[:, mid_idx] (== x_pos[:, mid_idx + 1]); use the same samples.
        cost_mid_term = 0.5 * self.rho_mid * (torch.sum((x_temp[:, self.mid_idx] - x_mid) ** 2, dim=1, keepdim=True) + torch.sum((y_temp[:, self.mid_idx] - y_mid) ** 2, dim=1, keepdim=True))
        # Smoothness terms minimised by compute_psi / compute_v.  Without them the
        # solver output is not a stationary point of this objective, and the
        # Hessian in y can be rank deficient.
        cost_smoothness = 0.5 * (torch.sum(torch.matmul(c_psi, self.cost_smoothness_psi) * c_psi, dim=1, keepdim=True) + torch.sum(torch.matmul(c_v, self.cost_smoothness_v) * c_v, dim=1, keepdim=True))

        cost = cost_final_pos + cost_psi_term + cost_v_term + cost_mid_term + cost_smoothness
        return cost

#### Test with original inputs

In [4]:
# Start state: position, speed (and its rate), heading (and its rates).
x_init, y_init = 129, 202.3 - 2.0
v_init, vdot_init = 20.0, 0.0
psi_init, psidot_init, psiddot_init = 0.0, 0.0, 0.0

# Goal state; the final heading is given in degrees and converted to radians.
x_fin, y_fin = 238, 139.0 - 6.0
v_fin, vdot_fin = 16.0, 0.0
psi_fin = -91 * np.pi / 180
psidot_fin, psiddot_fin = 0.0, 0.0

# Intermediate way-points, paired element-wise as (x_mid[k], y_mid[k]).
x_mid = np.array([162, 198, 228])
y_mid = np.array([202, 193, 170])

In [5]:
# Pack one problem instance: the start state, then the goal state followed by
# the three way-points interleaved as x_mid[0], y_mid[0], x_mid[1], ...
fixed_params = np.array([x_init, y_init, v_init, psi_init, psidot_init])
goal_state = [x_fin, y_fin, v_fin, psi_fin, psidot_fin]
waypoints = [coord for k in range(3) for coord in (x_mid[k], y_mid[k])]
variable_params = np.array(goal_state + waypoints)
fixed_params.shape, variable_params.shape

((5,), (11,))

In [8]:
# Add a leading batch axis of size 1 (OPTNode.solve expects (batch, n_params)).
# Adding the axis in NumPy avoids building a tensor from a list of ndarrays,
# which is slow and triggers a PyTorch warning; torch.tensor still copies.
fixed_params_new = torch.tensor(fixed_params[np.newaxis], dtype=torch.double, device=device)
variable_params_new = torch.tensor(variable_params[np.newaxis], dtype=torch.double, device=device)

In [9]:
# Build the planner with its default settings and solve the single instance.
opt_node = OPTNode()
(primal_sol, _) = opt_node.solve(
    fixed_params_new,
    variable_params_new,
)

In [10]:
# Each row is hstack(c_psi, c_v): the first nvar entries are the heading
# coefficients, the remaining nvar are the speed coefficients.
primal_sol

tensor([[ 4.2368e-05,  7.0907e-05,  2.2459e-03,  2.2370e-02, -3.5275e-01,
         -6.6612e-02, -8.7917e-01, -7.2940e-01, -1.3671e+00, -1.5883e+00,
         -1.5883e+00,  2.0000e+01,  1.9729e+01,  1.9459e+01,  1.9205e+01,
          1.8979e+01,  1.8772e+01,  1.8586e+01,  1.8404e+01,  1.8225e+01,
          1.8046e+01,  1.7867e+01]], dtype=torch.float64)

#### Test with random inputs

In [None]:
# Seed the RNG so the random batch (and every result derived from it) is
# reproducible on Restart & Run All.
torch.manual_seed(0)
# 10 random problems: 5 fixed parameters and 5 + 2 * 3 variable parameters
# (goal state followed by three (x, y) way-points).  Gradients are wanted
# w.r.t. the variable parameters only.
fixed_params_new = torch.randn(10, 5, dtype=torch.double, device=device)
variable_params_new = torch.randn(10, 11, dtype=torch.double, device=device, requires_grad=True)

In [None]:
# Inspect one sample (row 7) of the random variable parameters.
variable_params_new[7]

In [None]:
# Matching fixed parameters for the same sample (row 7).
fixed_params_new[7]

In [None]:
# Fresh planner with default settings; solve every row of the random batch.
opt_node = OPTNode()
solution_and_ctx = opt_node.solve(fixed_params_new, variable_params_new)
primal_sol, _ = solution_and_ctx

In [None]:
primal_sol.shape  # (batch, 2 * nvar)

In [None]:
# Solution coefficients (c_psi then c_v) for sample 7.
primal_sol[7]

In [None]:
# Objective value for every sample at the computed solution, one row per sample.
cost = opt_node.objective(fixed_params_new, variable_params_new, y=primal_sol)
cost

In [None]:
# Objective value for sample 7.
cost[7]

In [None]:
# DDN differentiates the objective w.r.t. y at the solution, so y must be a leaf
# tensor that requires grad.  `primal_sol` was computed from
# `variable_params_new` (which requires grad), so it is a non-leaf: setting
# `.requires_grad` on it raises RuntimeError.  Detach it into a fresh leaf
# instead (a new name keeps this cell idempotent on re-run).
primal_sol_leaf = primal_sol.detach().requires_grad_(True)
variable_params_new.requires_grad = True
opt_node.gradient(fixed_params_new, variable_params_new, y=primal_sol_leaf)