In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to the imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [20]:
import sys
sys.path.append("..")
sys.path.append(".")
import torch
import plotly
plotly.offline.init_notebook_mode(connected=True)

import robust_value_approx.relu_mpc as relu_mpc
import robust_value_approx.adversarial_sample as adversarial_sample
import robust_value_approx.utils as utils
import plotting_utils

# Double Integrator Example

In [15]:
import double_integrator_utils

# Horizon parameter shared by the two value functions built below.
N = 5

# Benchmark: the value function (built with N + 1) that the resulting
# control is compared against.
vf = double_integrator_utils.get_value_function(N=N+1)
V = vf.get_value_function()

# Tail: the value function (built with N) that the controller relies on
# beyond its first time step.
vf_next = double_integrator_utils.get_value_function(N=N)
V_next = vf_next.get_value_function()

# Box of initial states: every coordinate lies in [-1, 1].
x0_up = torch.ones(vf.sys.x_dim, dtype=vf.dtype)
x0_lo = -x0_up

# File name prefixes for the sample data and the trained models.
sys_name = 'double_int'
x_samples_file, v_samples_file = (
    '../data/learn_value_function_' + sys_name + suffix
    for suffix in ('_x', '_v')
)
model_file = '../data/' + sys_name

# Vertical Ball Paddle Example

In [56]:
import ball_paddle_utils

# NOTE: this cell assigns the same names (N, vf, V, vf_next, V_next, x0_lo,
# x0_up, sys_name and the *_file prefixes) as the Double Integrator cell;
# whichever of the two cells was run last determines what the cells below use.
N = 5

# value function we benchmark the resulting control against
vf = ball_paddle_utils.get_value_function_vertical(N=N+1)
V = vf.get_value_function()

# value function used by the controller beyond one time step
vf_next = ball_paddle_utils.get_value_function_vertical(N=N)
V_next = vf_next.get_value_function()

# Lower / upper bounds of the initial-state box. Build them directly in
# vf.dtype: the legacy torch.Tensor(...) constructor creates float32 values
# first, so a literal such as .15 would be rounded to single precision before
# being cast to vf.dtype.
x0_lo = torch.tensor([1.5, .15, -5., -1.], dtype=vf.dtype)
x0_up = torch.tensor([2., .15, 1., 5.], dtype=vf.dtype)

# data file options: name prefixes for the sample data and the trained models
sys_name = 'ball_paddle_vertical'
x_samples_file = '../data/learn_value_function_' + sys_name + '_x'
v_samples_file = '../data/learn_value_function_' + sys_name + '_v'
model_file = '../data/' + sys_name

# SLIP Walking Model

In [None]:
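# TODO: add the SLIP walking model setup (not implemented yet); it should define the same names as the example cells above.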

# Get Controllers

In [59]:
# Load the two saved networks stored under the model_file prefix
# (baseline first, then robust).
baseline_model, robust_model = (
    torch.load(model_file + suffix)
    for suffix in ('_baseline_model.pt', '_robust_model.pt')
)
# Wrap each network, together with vf, in a ReLUMPC controller.
baseline_ctrl, robust_ctrl = (
    relu_mpc.ReLUMPC(vf, net) for net in (baseline_model, robust_model)
)
def eval_one_step_ctrl(ctrl, x0_samp):
    """Evaluate the cost of taking one step with `ctrl`, then following `V_next`.

    The controller provides the first input `u0` and the state `x1` it leads
    to; the rest of the trajectory is the solution of `V_next` from `x1`.
    Reads the notebook-level names `vf`, `vf_next` and `V_next`.

    Args:
        ctrl: object whose `get_ctrl(x0_samp)` returns `(u0, x1)`; `u0` is
            None when no control could be computed.
        x0_samp: 1-D tensor holding the initial state.

    Returns:
        `(value, x_traj, u_traj)`, where `x_traj` / `u_traj` have `x0_samp` /
        `u0` as their first column followed by the columns of the `V_next`
        trajectory, and `value` is `vf.traj_cost` of that trajectory.
        `(None, None, None)` if the controller, the `V_next` solve, or the
        trajectory extraction fails.
    """
    (u0, x1) = ctrl.get_ctrl(x0_samp)
    if u0 is None:
        return (None, None, None)
    # Solve the remaining problem from x1 and bail out if it has no solution,
    # instead of forwarding None solution vectors to sol_to_traj.
    (value_next, s_next, alpha_next) = V_next(x1)
    if value_next is None:
        return (None, None, None)
    (x_traj_next, u_traj_next, _) = vf_next.sol_to_traj(x1, s_next, alpha_next)
    if x_traj_next is None:
        return (None, None, None)
    # Prepend the first state / input as a column to the tail trajectory.
    x_traj = torch.cat((x0_samp.unsqueeze(1), x_traj_next), dim=1)
    u_traj = torch.cat((u0.unsqueeze(1), u_traj_next), dim=1)
    # The cost is evaluated without the initial-state column of x_traj.
    # assumes no cost on alpha! (true on all benchmarks)
    value = vf.traj_cost(x_traj[:, 1:], u_traj)
    return (value, x_traj, u_traj)

In [60]:
# Load the validation states saved under the x_samples_file prefix; the
# evaluation loop below treats each row as one initial state.
# NOTE: torch.load unpickles the file, so only load data files you trust.
x_samples_validation = torch.load(x_samples_file + '_validation.pt')

# Evaluate their Performance on Average

In [61]:
# Empty (0, 1) accumulators in vf.dtype; the evaluation loop appends one row
# per successfully evaluated sample.
cost_opt = torch.empty((0, 1), dtype=vf.dtype)
cost_baseline = torch.empty((0, 1), dtype=vf.dtype)
cost_robust = torch.empty((0, 1), dtype=vf.dtype)

In [62]:
# Evaluate every validation state with the optimal solution and with both
# controllers. (To evaluate on random states instead, draw x0_samp uniformly
# between x0_lo and x0_up.)
num_samples = x_samples_validation.shape[0]

# Start from empty (0, 1) accumulators so that re-running this cell does not
# append a second copy of the results to tensors left over from an earlier run.
cost_opt = torch.empty((0, 1), dtype=vf.dtype)
cost_baseline = torch.empty((0, 1), dtype=vf.dtype)
cost_robust = torch.empty((0, 1), dtype=vf.dtype)

for i in range(num_samples):
    x0_samp = x_samples_validation[i, :]
    try:
        optimal_value, opt_s, opt_alpha = V(x0_samp)
        if optimal_value is None:
            continue
        # Optimal trajectory for this sample; not used for the costs below but
        # left available for inspection after the loop.
        (x_traj_opt, u_traj_opt, alpha_traj_opt) = vf.sol_to_traj(x0_samp, opt_s, opt_alpha)

        (baseline_value, baseline_x_traj, baseline_u_traj) = eval_one_step_ctrl(baseline_ctrl, x0_samp)
        if baseline_value is None:
            continue

        (robust_value, robust_x_traj, robust_u_traj) = eval_one_step_ctrl(robust_ctrl, x0_samp)
        if robust_value is None:
            continue

        # A sample is recorded only when all three evaluations succeeded, so
        # the three tensors stay row-aligned. torch.tensor(..., dtype=...)
        # builds each row directly in vf.dtype; the legacy torch.Tensor(...)
        # constructor would round the value to float32 first.
        # Rows are appended immediately, so the costs gathered so far remain
        # available if a solver exception aborts the loop.
        cost_opt = torch.cat((cost_opt, torch.tensor([[float(optimal_value)]], dtype=vf.dtype)), 0)
        cost_baseline = torch.cat((cost_baseline, torch.tensor([[baseline_value.item()]], dtype=vf.dtype)), 0)
        cost_robust = torch.cat((cost_robust, torch.tensor([[robust_value.item()]], dtype=vf.dtype)), 0)
    finally:
        # Runs for recorded and skipped samples alike (continue still triggers
        # finally), so the progress bar reflects every processed sample.
        utils.update_progress((i + 1) / num_samples)

Progress: [##################----------------------] 46.0%


SolverError: Solver 'GUROBI' failed. Try another solver, or solve with verbose=True for more information.

In [70]:
# Compare the optimal, baseline and robust costs collected above using the
# project's plotting helper, then display the returned figure.
# NOTE(review): nbin=40 and bartop=3200 are presumably the histogram bin count
# and the top of the bar axis; confirm against plotting_utils.control_perf.
fig = plotting_utils.control_perf(cost_opt, cost_baseline, cost_robust, nbin=40, bartop=3200)
fig.show()