In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt

import casadi as ca

from intent_predict.irl.data_processing.utils import IrlDataLoader
from intent_predict.irl.optimizer import WeightOptimizer, DecisionMaker

# Seed NumPy's global RNG so that draws made through np.random.* are
# reproducible across kernel restarts.
RANDOM_SEED = 1
np.random.seed(RANDOM_SEED)

In [2]:
# Load data
# `file_path` is a relative path with no file extension.
# NOTE(review): how it is resolved (file vs. directory/prefix, and relative to
# which base directory) is decided inside IrlDataLoader -- confirm there.
file_path = 'data/DJI_0012'
# The loader supports integer indexing; later in this notebook
# `irl_dataset[0]` is unpacked into a (feature, label) pair.
irl_dataset = IrlDataLoader(file_path)

In [9]:
# Draw a random 6-dimensional weight vector and rescale it to unit L2 norm.
# NOTE(review): 6 is presumably the number of feature channels -- confirm
# against the first dimension of `feature`.
w = np.random.randn(6)
w = w / np.linalg.norm(w)

# Bookkeeping: a snapshot of the initial weights and an (initially empty)
# list of losses.
all_w = w.copy()
loss_list = []

decision_maker = DecisionMaker()

# Take the first sample of the dataset as the test case.
feature, label = irl_dataset[0]

# Solve the decision problem for this sample under the current weights.
# NOTE(review): the three returned values are named as a primal solution and
# two multipliers -- confirm the exact meaning in DecisionMaker.solve.
p_val, lam_val, mu_val = decision_maker.solve(phi=feature, w=w)

print(p_val, lam_val, mu_val)

[0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] (18,) [-0.01054809]


In [12]:
# Build a residual function of (p, lambda, mu) symbolically and evaluate its
# Jacobian at the values returned by the solver.
# NOTE(review): the three residuals read as the KKT conditions (stationarity,
# complementary slackness, primal equality) of the DecisionMaker problem --
# confirm against the objective/constraints used in DecisionMaker.solve.
dim = p_val.shape[0]

# Symbolic unknowns: two vectors of length `dim` and one scalar.
p = ca.MX.sym('p', dim)
lam = ca.MX.sym('lambda', dim)
mu = ca.MX.sym('mu')

# First residual: feature.T @ w - lambda + mu * 1 + p.
dot_L = feature.T @ w - lam + mu * ca.MX.ones(dim, 1) + p
print(dot_L.shape)
# Second residual: elementwise product p_i * lambda_i.
comp_slack = p * lam
print(comp_slack.shape)
# Third residual: the entries of p must sum to one.
primal_eq = ca.sum1(p) - 1

f = ca.Function('f', [p, lam, mu], [dot_L, comp_slack, primal_eq], ['p', 'lambda', 'mu'], ['dot_L', 'comp_slack', 'primal_eq'])

# Jacobian of the stacked residual [dot_L; comp_slack; primal_eq] with respect
# to the stacked unknowns [p; lambda; mu].
#
# This is built with the symbolic ca.jacobian instead of the legacy
# f.jacobian(): the latter returns a single stacked matrix in CasADi 3.5 but
# one output per Jacobian block in newer releases, which would turn J into a
# tuple and break the determinant below.
kkt_residual = ca.vertcat(dot_L, comp_slack, primal_eq)
kkt_unknowns = ca.vertcat(p, lam, mu)
jac_expr = ca.jacobian(kkt_residual, kkt_unknowns)

# Structurally-empty placeholder inputs, one per output of `f`. They are
# unused; they only keep the call signature (and name) of `f_jac` identical to
# what f.jacobian() produced under CasADi 3.5, so existing calls keep working.
out_dot_L = ca.MX.sym('out_dot_L', ca.Sparsity(dim, 1))
out_comp_slack = ca.MX.sym('out_comp_slack', ca.Sparsity(dim, 1))
out_primal_eq = ca.MX.sym('out_primal_eq', ca.Sparsity(1, 1))

f_jac = ca.Function(
    'jac_f',
    [p, lam, mu, out_dot_L, out_comp_slack, out_primal_eq],
    [jac_expr],
    ['p', 'lambda', 'mu', 'out_dot_L', 'out_comp_slack', 'out_primal_eq'],
    ['jac'],
)
print(f_jac)

# Evaluate at the solver's values; the trailing three arguments are the
# ignored placeholders.
J = f_jac(p_val, lam_val, mu_val, ca.DM.zeros(dim), ca.DM.zeros(dim), 0)

print(J)
# A non-zero determinant means the Jacobian is non-singular at this point.
print("Determinant", np.linalg.det(J))

(18, 1)
(18, 1)
jac_f:(p[18],lambda[18],mu,out_dot_L[18x1,0nz],out_comp_slack[18x1,0nz],out_primal_eq[1x1,0nz])->(jac[37x37,108nz]) MXFunction
sparse: 37-by-37, 108 nnz
 (0, 0) -> 1
 (18, 0) -> 0.0324767
 (36, 0) -> 1
 (1, 1) -> 1
 (19, 1) -> 0.0298681
 (36, 1) -> 1
 (2, 2) -> 1
 (20, 2) -> 0.000259108
 (36, 2) -> 1
 (3, 3) -> 1
 (21, 3) -> 0.0276467
 (36, 3) -> 1
 (4, 4) -> 1
 (22, 4) -> 0.0283032
 (36, 4) -> 1
 (5, 5) -> 1
 (23, 5) -> 0.0253724
 (36, 5) -> 1
 (6, 6) -> 1
 (24, 6) -> 0.00394729
 (36, 6) -> 1
 (7, 7) -> 1
 (25, 7) -> 0.0208252
 (36, 7) -> 1
 (8, 8) -> 1
 (26, 8) -> 0.00130585
 (36, 8) -> 1
 (9, 9) -> 1
 (27, 9) -> 0
 (36, 9) -> 1
 (10, 10) -> 1
 (28, 10) -> 0.0141779
 (36, 10) -> 1
 (11, 11) -> 1
 (29, 11) -> 0.0251133
 (36, 11) -> 1
 (12, 12) -> 1
 (30, 12) -> 0.0228388
 (36, 12) -> 1
 (13, 13) -> 1
 (31, 13) -> 0.0205653
 (36, 13) -> 1
 (14, 14) -> 1
 (32, 14) -> 0.0182932
 (36, 14) -> 1
 (15, 15) -> 1
 (33, 15) -> 0.0160187
 (36, 15) -> 1
 (16, 16) -> 1
 (34, 16) ->