In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("..")
import torch
import plotly
import plotly.graph_objs as go
plotly.offline.init_notebook_mode(connected=True)

import robust_value_approx.relu_mpc as relu_mpc
import robust_value_approx.adversarial_sample as adversarial_sample
import robust_value_approx.utils as utils

# Double Integrator Example

In [None]:
import double_integrator_utils

N = 5  # horizon of the value function the controller relies on after its first step

# Benchmark (N+1)-step problem: controller costs are compared against its optimum.
vf = double_integrator_utils.get_value_function(N=N + 1)
V = vf.get_value_function()

# Tail N-step problem: the controller acts for one step, then follows this solution.
vf_next = double_integrator_utils.get_value_function(N=N)
V_next = vf_next.get_value_function()

# Initial states are drawn from the box [-1, 1] in every coordinate.
x0_up = torch.ones(vf.sys.x_dim, dtype=vf.dtype)
x0_lo = -x0_up

# Path prefix of the saved baseline / robust model files for this example.
model_file = '../data/double_int'

# Vertical Ball Paddle Example

In [None]:
import ball_paddle_utils

N = 5  # horizon of the value function the controller relies on after its first step

# Benchmark (N+1)-step problem: controller costs are compared against its optimum.
vf = ball_paddle_utils.get_value_function_vertical(N=N + 1)
V = vf.get_value_function()

# Tail N-step problem: the controller acts for one step, then follows this solution.
vf_next = ball_paddle_utils.get_value_function_vertical(N=N)
V_next = vf_next.get_value_function()

# Per-coordinate box of initial states; the second coordinate has zero width
# (pinned at 0.15).
x0_lo, x0_up = (torch.Tensor(bounds).type(vf.dtype)
                for bounds in ([1.5, .15, -5., -1.], [2., .15, 1., 5.]))

# Path prefix of the saved baseline / robust model files for this example.
model_file = '../data/vertical_ball_paddle'

# Get Controllers

In [None]:
def _load_model(path):
    """Load a model file written with torch.save.

    The result is handed directly to ReLUMPC, so it is presumably a whole pickled
    model object rather than a state_dict. Recent torch releases default
    torch.load to weights_only=True, which refuses to unpickle such objects, so
    full unpickling is requested explicitly. Old torch versions without the
    `weights_only` keyword fall back to the plain call.

    NOTE: full unpickling can execute arbitrary code -- only load trusted files.
    """
    try:
        return torch.load(path, weights_only=False)
    except TypeError:  # torch < 1.13 has no `weights_only` keyword
        return torch.load(path)


baseline_model = _load_model(model_file + '_baseline_model.pt')
robust_model = _load_model(model_file + '_robust_model.pt')
baseline_ctrl = relu_mpc.ReLUMPC(vf, baseline_model)
robust_ctrl = relu_mpc.ReLUMPC(vf, robust_model)
def eval_one_step_ctrl(ctrl, x0_samp, benchmark_vf=None, next_vf=None):
    """Roll out a one-step controller plus the tail problem's solution and score it.

    The controller chooses the first input u0 and the resulting state x1. The
    rest of the trajectory comes from solving `next_vf`'s problem from x1. The
    stitched trajectory is scored with `benchmark_vf.traj_cost`.

    Args:
        ctrl: object with get_ctrl(x0) -> (u0, x1); u0 is None when it fails.
        x0_samp: initial state, a 1-D tensor.
        benchmark_vf: value function whose traj_cost scores the trajectory.
            Defaults to the notebook-global `vf`.
        next_vf: value function solved from x1 for the rest of the trajectory.
            Defaults to the notebook-global `vf_next` (with its evaluator `V_next`).

    Returns:
        (value, x_traj, u_traj), with x0 / u0 prepended as column 0 of the
        trajectories, or (None, None, None) if the controller returns no input
        or the tail problem has no solution.
    """
    if benchmark_vf is None:
        benchmark_vf = vf
    if next_vf is None:
        next_vf, next_V = vf_next, V_next
    else:
        next_V = next_vf.get_value_function()

    (u0, x1) = ctrl.get_ctrl(x0_samp)
    if u0 is None:
        return (None, None, None)
    # Check the tail problem's value before converting its solution to a
    # trajectory; a None value means there is no solution to convert.
    next_sol = next_V(x1)
    if next_sol[0] is None:
        return (None, None, None)
    (x_traj_next, u_traj_next, _) = next_vf.sol_to_traj(x1, *next_sol[1:])
    if x_traj_next is None:
        return (None, None, None)
    # Trajectories are stored column-wise: prepend x0 / u0 as the first column.
    x_traj = torch.cat((x0_samp.unsqueeze(0).t(), x_traj_next), dim=1)
    u_traj = torch.cat((u0.unsqueeze(0).t(), u_traj_next), dim=1)
    # assumes no cost on alpha! (true on all benchmarks)
    value = benchmark_vf.traj_cost(x_traj[:, 1:], u_traj)
    return (value, x_traj, u_traj)

# Evaluate Average Controller Performance

In [None]:
# Per-sample cost accumulators of shape (num_collected_samples, 1). They are reset
# only here, so re-running just the evaluation cell appends further samples.
cost_opt, cost_baseline, cost_robust = (torch.empty(0, 1, dtype=vf.dtype) for _ in range(3))

In [None]:
num_samples = 100


def compare_costs(x0):
    """Return (optimal, baseline, robust) costs at x0 as floats, or None if any is unavailable.

    Baseline and robust controllers are rolled out with eval_one_step_ctrl; the
    optimal cost is the value of V at x0. Evaluation stops at the first None.
    """
    optimal_value = V(x0)[0]
    if optimal_value is None:
        return None
    baseline_value = eval_one_step_ctrl(baseline_ctrl, x0)[0]
    if baseline_value is None:
        return None
    robust_value = eval_one_step_ctrl(robust_ctrl, x0)[0]
    if robust_value is None:
        return None
    return (float(optimal_value), baseline_value.item(), robust_value.item())


# Collect rows in a list and concatenate once; torch.cat inside the loop is quadratic.
new_rows = []
for i in range(num_samples):
    # Uniform sample in the box [x0_lo, x0_up], drawn directly in the value function's dtype.
    x0_samp = torch.rand(vf.sys.x_dim, dtype=vf.dtype) * (x0_up - x0_lo) + x0_lo
    costs = compare_costs(x0_samp)
    if costs is not None:
        new_rows.append(costs)
    # Update on every iteration, including skipped samples, so the bar reaches 100%.
    utils.update_progress((i + 1) / num_samples)

if new_rows:
    # Build directly in vf.dtype (torch.Tensor(...) would round through float32 first).
    new_costs = torch.tensor(new_rows, dtype=vf.dtype)  # columns: opt, baseline, robust
    cost_opt = torch.cat((cost_opt, new_costs[:, 0:1]), 0)
    cost_baseline = torch.cat((cost_baseline, new_costs[:, 1:2]), 0)
    cost_robust = torch.cat((cost_robust, new_costs[:, 2:3]), 0)
print("{} of {} samples had all three costs available".format(len(new_rows), num_samples))

In [None]:
nbin = 100
bartop = cost_opt.shape[0]  # number of collected samples; height of the marker lines
assert bartop > 0, "no samples collected yet; run the evaluation cells above first"

# Per-sample suboptimality of each controller. flatten() (not squeeze()) keeps a
# 1-D array even when only one sample was collected.
sub_opt_baseline = (cost_baseline - cost_opt).flatten()
sub_opt_robust = (cost_robust - cost_opt).flatten()


def add_marker(fig, x, text, y_frac, line_opacity, text_opacity):
    """Draw a dashed vertical line at x from 0 to bartop, labelled `text` at height y_frac * bartop."""
    fig.add_shape(go.layout.Shape(type='line', xref='x', yref='y',
                                  x0=x, y0=0, x1=x, y1=bartop,
                                  line=dict(dash='dash', color='#d8c962'), opacity=line_opacity))
    fig.add_annotation(showarrow=False, x=x, y=int(y_frac * bartop), text=text,
                       xanchor="right", xshift=-4, opacity=text_opacity, textangle=0)


fig = go.Figure()
fig.update_layout(barmode='overlay',
                  xaxis=dict(title="suboptimality (controller cost - optimal cost)"),
                  yaxis=dict(title="number of samples"))
# Both traces use bingroup=1 so plotly gives them compatible bins.
fig.add_trace(go.Histogram(x=sub_opt_baseline.tolist(), name="baseline", nbinsx=nbin, bingroup=1, marker_color='#e32249'))
fig.add_trace(go.Histogram(x=sub_opt_robust.tolist(), name="robust", bingroup=1, marker_color='#0f4c75'))
fig.update_traces(opacity=0.75)
# Means: faint lines with prominent labels; maxima: solid lines with faint labels.
add_marker(fig, sub_opt_baseline.mean().item(), "baseline: mean", .9, .5, .95)
add_marker(fig, sub_opt_robust.mean().item(), "robust (ours): mean", .8, .5, .95)
add_marker(fig, sub_opt_baseline.max().item(), "baseline: max", .65, 1., .5)
add_marker(fig, sub_opt_robust.max().item(), "robust (ours): max", .85, 1., .5)
fig.show()

In [None]:
# Model to probe with the adversarial-sample generator (baseline_model or robust_model).
adv_model = robust_model

as_generator = adversarial_sample.AdversarialSampleGenerator(vf_next, x0_lo, x0_up)
# Alternative search (returns a single state rather than a history):
#   (eps_adv, x_adv) = as_generator.get_upper_bound_global(adv_model, requires_grad=False)
(eps_adv, x_adv_, V_adv, _) = as_generator.get_squared_bound_sample(
    adv_model, max_iter=15, conv_tol=1e-4, learning_rate=.1)
print(x_adv_)
# presumably one row per iterate of the search -- the last row is used as the sample
x_adv = x_adv_[-1, :]

In [None]:
use_adversarial_x0 = False  # True: inspect x_adv from the adversarial-sample cell instead
plot_states = False  # True: plot the first two state components instead of the first input

if use_adversarial_x0:
    x0_samp = x_adv
else:
    x0_samp = torch.rand(vf.sys.x_dim, dtype=vf.dtype) * (x0_up - x0_lo) + x0_lo

(optimal_value, opt_s, opt_alpha) = V(x0_samp)
assert optimal_value is not None, "optimal problem has no solution at this x0; re-run the cell to resample"
(x_traj_opt, u_traj_opt, alpha_traj_opt) = vf.sol_to_traj(x0_samp, opt_s, opt_alpha)
(robust_value, robust_x_traj, robust_u_traj) = eval_one_step_ctrl(robust_ctrl, x0_samp)
(baseline_value, baseline_x_traj, baseline_u_traj) = eval_one_step_ctrl(baseline_ctrl, x0_samp)
assert robust_value is not None and baseline_value is not None, \
    "a controller rollout failed at this x0; re-run the cell to resample"

# (state trajectory, input trajectory, legend suffix, line style) for each rollout
rollouts = [(robust_x_traj, robust_u_traj, "robust", dict(color='green')),
            (baseline_x_traj, baseline_u_traj, "baseline", dict(color='blue')),
            (x_traj_opt, u_traj_opt, "opt", dict(dash='dash', color='red'))]
data = []
for (xs, us, label, line_style) in rollouts:
    if plot_states:
        data += [go.Scatter(y=xs[0, :].tolist(), name="x[0] " + label, line=line_style),
                 go.Scatter(y=xs[1, :].tolist(), name="x[1] " + label, line=line_style)]
    else:
        data += [go.Scatter(y=us[0, :].tolist(), name="u[0] " + label, line=line_style)]
layout = go.Layout(xaxis=dict(title="time step"),
                   yaxis=dict(title="state" if plot_states else "control input u[0]"))
plotly.offline.iplot(go.Figure(data=data, layout=layout))

print("Optimal value: " + str(optimal_value))
print("Baseline value: " + str(baseline_value.item()))
print("Robust value: " + str(robust_value.item()))