In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("..")
import double_integrator_utils
import robust_value_approx.relu_mpc as relu_mpc
import robust_value_approx.model_bounds as model_bounds
import robust_value_approx.utils as utils

import torch
import numpy as np
import cvxpy as cp
import plotly
import plotly.graph_objs as go

plotly.offline.init_notebook_mode(connected=True)

In [None]:
# Build the double-integrator value function and its optimal-value oracle V,
# load the non-robust and robust ReLU models, wrap each in a ReLUMPC
# controller, then search the initial-state box [x0_lo, x0_up] for an
# adversarial initial state of the non-robust model by solving the MIQP
# returned by ModelBounds.upper_bound_opt.
vf = double_integrator_utils.get_value_function()
V = vf.get_value_function()

# Load the non-robust and robust models.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
model = torch.load('../data/robust_value_demo_model.pt')
model_robust = torch.load('../data/robust_value_demo_robust_model.pt')

# Box of initial states searched by the bound problem; one controller per model.
x0_lo = -1. * torch.ones(vf.sys.x_dim, dtype=vf.dtype)
x0_up = 1. * torch.ones(vf.sys.x_dim, dtype=vf.dtype)
ctrl = relu_mpc.ReLUMPC(vf, model)
ctrl_robust = relu_mpc.ReLUMPC(vf, model_robust)

# Compute an adversarial example on the non-robust model. The bound problem
# comes back as data for the mixed-integer QP
#   min  0.5 y'Q1 y + 0.5 g'Q2 g + q1'y + q2'g + k
#   s.t. A1 y + A2 g == b,  G1 y + G2 g <= h,  g binary
mb = model_bounds.ModelBounds(model, vf)
bound_opt = mb.upper_bound_opt(model, x0_lo, x0_up)
Q1, Q2, q1, q2, k, G1, G2, h, A1, A2, b = utils.torch_to_numpy(bound_opt)

num_y = Q1.shape[0]
num_gamma = Q2.shape[0]
y = cp.Variable(num_y)
gamma = cp.Variable(num_gamma, boolean=True)

obj = cp.Minimize(.5 * cp.quad_form(y, Q1) + .5 *
                  cp.quad_form(gamma, Q2) + q1@y + q2@gamma + k)
con = [A1@y + A2@gamma == b,
       G1@y + G2@gamma <= h]

prob = cp.Problem(obj, con)
prob.solve(solver=cp.GUROBI, verbose=False)
# Fail loudly here rather than reading None out of y.value when the MIQP is
# infeasible/unbounded or the solver stops without a solution.
if prob.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
    raise RuntimeError(
        f"adversarial-example MIQP was not solved (status={prob.status})")
# Optimal objective of the bound problem.
epsilon = obj.value
# Build the state directly in vf.dtype: torch.Tensor(ndarray) would first
# round the float64 solver output to the default float dtype (float32).
# Assumes the first x_dim entries of y are the initial state — TODO confirm
# against ModelBounds.upper_bound_opt.
x0_adv = torch.tensor(y.value[:vf.sys.x_dim], dtype=vf.dtype)

In [None]:
# Roll out the non-robust (L2) and robust controllers side by side from the
# adversarial initial state, then print the optimal value V(x0_adv) next to
# each controller's state/input trajectories and trajectory cost.
value, _, _ = V(x0_adv)

state_dim, input_dim, horizon = vf.sys.x_dim, vf.sys.u_dim, vf.N
x_rollout = torch.zeros((state_dim, horizon), dtype=vf.dtype)
u_rollout = torch.zeros((input_dim, horizon), dtype=vf.dtype)
x_rollout_robust = torch.zeros((state_dim, horizon), dtype=vf.dtype)
u_rollout_robust = torch.zeros((input_dim, horizon), dtype=vf.dtype)
x_rollout[:, 0] = x0_adv
x_rollout_robust[:, 0] = x0_adv

# Advance both controllers in lock-step (non-robust first at every step).
# The successor state predicted at the final step has no column to go into.
for step in range(horizon):
    for controller, xs, us in ((ctrl, x_rollout, u_rollout),
                               (ctrl_robust, x_rollout_robust, u_rollout_robust)):
        u_step, x_next = controller.get_ctrl(xs[:, step])
        us[:, step] = u_step
        if step + 1 < horizon:
            xs[:, step + 1] = x_next

print("Optimal")
print(value)
report = (("L2", x_rollout, u_rollout),
          ("Robust", x_rollout_robust, u_rollout_robust))
for idx, (label, xs, us) in enumerate(report):
    if idx > 0:
        print("===")
    print(label)
    print(xs)
    print(us)
    print(vf.traj_cost(xs[:,1:], us))

In [None]:
# Empty accumulator, one row per sampled initial state; columns are
# [optimal value, L2 controller cost, L2+ε controller cost].
cost_to_go = torch.empty((0, 3), dtype=vf.dtype)

In [None]:
# Monte-Carlo comparison: sample initial states uniformly in [x0_lo, x0_up],
# roll out both controllers from each, and append one row
# [V(x0), L2 controller cost, L2+ε controller cost] to cost_to_go.
# cost_to_go is extended rather than reset, so re-running this cell adds
# further samples to the ones already collected.
num_samples = 250
for i in range(num_samples):
    # Sample directly in vf.dtype (no float32 intermediate).
    x0_samp = torch.rand(vf.sys.x_dim, dtype=vf.dtype) * (x0_up - x0_lo) + x0_lo
    # Clamp to the box in case rounding pushes a sample just outside it.
    x0_samp = torch.max(torch.min(x0_samp, x0_up), x0_lo)
    value_optimal, _, _ = V(x0_samp)
    if value_optimal is None:
        # V gives no value for this state (presumably infeasible); skip it.
        continue
    x_rollout = torch.zeros((vf.sys.x_dim, vf.N), dtype=vf.dtype)
    u_rollout = torch.zeros((vf.sys.u_dim, vf.N), dtype=vf.dtype)
    x_rollout[:, 0] = x0_samp
    x_rollout_robust = torch.zeros((vf.sys.x_dim, vf.N), dtype=vf.dtype)
    u_rollout_robust = torch.zeros((vf.sys.u_dim, vf.N), dtype=vf.dtype)
    x_rollout_robust[:, 0] = x0_samp
    for n in range(vf.N):
        x = x_rollout[:, n]
        (u_ctrl, x_) = ctrl.get_ctrl(x)
        u_rollout[:, n] = u_ctrl
        if n < (vf.N - 1):
            x_rollout[:, n+1] = x_
        x = x_rollout_robust[:, n]
        (u_ctrl, x_) = ctrl_robust.get_ctrl(x)
        u_rollout_robust[:, n] = u_ctrl
        if n < (vf.N - 1):
            x_rollout_robust[:, n+1] = x_
    value = vf.traj_cost(x_rollout[:,1:], u_rollout)
    value_robust = vf.traj_cost(x_rollout_robust[:,1:], u_rollout_robust)
    # torch.tensor with an explicit dtype keeps full precision; the previous
    # torch.Tensor([...]).type(...) rounded the costs to float32 before the
    # cast, which corrupts the small cost differences analysed later.
    row = torch.tensor([[float(value_optimal), value.item(), value_robust.item()]],
                       dtype=vf.dtype)
    cost_to_go = torch.cat((cost_to_go, row), 0)
    utils.update_progress((i + 1) / num_samples)

In [None]:
# Histogram of each controller's sub-optimality (trajectory cost minus the
# optimal value V(x0), per sampled initial state), with dashed vertical
# markers at the mean and max of each distribution.
assert cost_to_go.shape[0] > 0, "cost_to_go is empty: run the sampling cell first"
sub_opt = cost_to_go[:,1] - cost_to_go[:,0]
sub_opt_robust = cost_to_go[:,2] - cost_to_go[:,0]

nbin = 100
bartop = 70  # height (in histogram counts) of the dashed marker lines
marker_line_color = '#d8c962'

# One marker per (statistic, controller), in the order
# L2 mean, L2+ε mean, L2 max, L2+ε max. Values are converted to Python
# floats so plotly never has to serialize torch tensors.
markers = []
for stat_name, reduce_fn, text_opacity, line_opacity in (("mean", torch.mean, .6, .5),
                                                         ("max", torch.max, .95, 1.)):
    for label, data, y_frac in (("L2", sub_opt, .65), ("L2+ε", sub_opt_robust, .85)):
        markers.append(dict(x=reduce_fn(data).item(),
                            text=f"{label}: {stat_name}",
                            y=int(y_frac * bartop),
                            text_opacity=text_opacity,
                            line_opacity=line_opacity))

layout = go.Layout(
    annotations=[dict(showarrow=False, x=m["x"], y=m["y"], text=m["text"],
                      xanchor="right", xshift=-4, opacity=m["text_opacity"], textangle=0)
                 for m in markers],
    barmode='overlay',
    title="Sub-optimality of ReLU-MPC controllers over sampled initial states",
    xaxis=dict(title="trajectory cost − optimal value"),
    yaxis=dict(title="count"),
)
fig = go.Figure(layout=layout)
fig.add_trace(go.Histogram(x=sub_opt.tolist(), name="L2", nbinsx=nbin, bingroup=1,
                           marker_color='#e32249'))
fig.add_trace(go.Histogram(x=sub_opt_robust.tolist(), name="L2+ε", bingroup=1,
                           marker_color='#0f4c75'))
fig.update_traces(opacity=0.75)
for m in markers:
    fig.add_shape(go.layout.Shape(type='line', xref='x', yref='y',
                                  x0=m["x"], y0=0, x1=m["x"], y1=bartop,
                                  line=dict(dash='dash', color=marker_line_color),
                                  opacity=m["line_opacity"]))
fig.show()