In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("..")
import torch
import plotly
import plotly.graph_objs as go
plotly.offline.init_notebook_mode(connected=True)

import robust_value_approx.train_value as train_value

# Double Integrator Example

In [None]:
import double_integrator_utils

# Value function object for the double integrator example.
vf = double_integrator_utils.get_value_function(N=5)

# Box of initial states: every state coordinate ranges over [-1, 1].
x0_up = torch.ones(vf.sys.x_dim, dtype=vf.dtype)
x0_lo = -x0_up

# validation options: 100 grid breaks along each state dimension
num_breaks_validation = [100] * vf.sys.x_dim

# file options (path prefixes; suffixes are appended where files are saved/loaded)
model_file = '../data/double_int'
x_samples_file = '../data/learn_value_function_double_int_x'
v_samples_file = '../data/learn_value_function_double_int_v'

# neural network options
nn_width, nn_depth = 8, 0

# setting up adversarial training options
train_opt = train_value.AdversarialWithBaselineTrainingOptions()
for opt_name, opt_value in (
        ("num_iter_desired", 1001),
        ("batch_size", 10),
        ("max_buffer_size", 100000),
        ("init_buffer_size", 1),
        ("sample_refresh_rate", 100),
        ("num_rand_extra", 0),
        ("x_adv_max_iter", 2),
        ("x_adv_conv_tol", 1e-5),
        ("x_adv_lr", .25),
        ("x_adv0_noise", 1.),
):
    setattr(train_opt, opt_name, opt_value)

# number of times adv.train(train_opt) is called by the training loop
num_training_run = 15

# Vertical Ball Paddle Example

In [None]:
import ball_paddle_utils

# Value function object for the vertical ball-paddle example.
vf = ball_paddle_utils.get_value_function_vertical(N=5)

# Box of initial states. The bounds are built directly in vf.dtype with
# torch.tensor: the legacy torch.Tensor(...) constructor always creates a
# float32 tensor first, so a later cast to float64 would keep the float32
# rounding error of values such as .15.
x0_lo = torch.tensor([1.5, .15, -5., -1.], dtype=vf.dtype)
x0_up = torch.tensor([2., .15, 1., 5.], dtype=vf.dtype)

# validation options: grid breaks per state dimension (the second state has
# identical lower/upper bounds, hence a single break)
num_breaks_validation = [20, 1, 20, 20]

# data file options (path prefixes; suffixes are appended where files are saved/loaded)
x_samples_file = '../data/learn_value_function_ball_paddle_vertical_x'
v_samples_file = '../data/learn_value_function_ball_paddle_vertical_v'
model_file = '../data/vertical_ball_paddle'

# neural network options
nn_width = 16
nn_depth = 0

# setting up adversarial training options
train_opt = train_value.AdversarialWithBaselineTrainingOptions()
train_opt.num_iter_desired = 1001
train_opt.batch_size = 100
train_opt.max_buffer_size = 100000
train_opt.init_buffer_size = 100
train_opt.sample_refresh_rate = 100
train_opt.num_rand_extra = 2
train_opt.x_adv_max_iter = 3
train_opt.x_adv_conv_tol = 1e-5
train_opt.x_adv_lr = .25
train_opt.x_adv0_noise = 1.

# number of times adv.train(train_opt) is called by the training loop
num_training_run = 10

In [None]:
# checking the spread of trajectories for sanity check
V = vf.get_value_function()


def _rollout(x0):
    """Evaluate V at x0 and return the state trajectory from vf.sol_to_traj.

    Returns None when sol_to_traj reports no trajectory for this initial state.
    """
    (x_traj, _, _) = vf.sol_to_traj(x0, *V(x0)[1:])
    return x_traj


# Element-wise envelope (per state, per time step) over all rollouts seen so far.
x_traj_min = torch.full((vf.sys.x_dim, vf.N), float('inf'), dtype=vf.dtype)
x_traj_max = torch.full((vf.sys.x_dim, vf.N), float('-inf'), dtype=vf.dtype)

# 50 initial states drawn uniformly from the box [x0_lo, x0_up]
for _sample in range(50):
    x0 = torch.rand(vf.sys.x_dim, dtype=vf.dtype) * (x0_up - x0_lo) + x0_lo
    x_traj = _rollout(x0)
    if x_traj is None:
        continue
    x_traj_min = torch.min(x_traj_min, x_traj)
    x_traj_max = torch.max(x_traj_max, x_traj)

# The two corners of the box also contribute to the envelope; their rollouts
# are kept so they are not solved a second time for the line plots below.
x_traj_lo = _rollout(x0_lo)
x_traj_up = _rollout(x0_up)
for x_traj in (x_traj_lo, x_traj_up):
    if x_traj is None:
        continue
    x_traj_min = torch.min(x_traj_min, x_traj)
    x_traj_max = torch.max(x_traj_max, x_traj)

# Shaded min/max bands for the first two states.
data = [go.Scatter(y=x_traj_min[0,:], marker=dict(color="#444"), line=dict(width=0), mode='lines', name=""),
        go.Scatter(y=x_traj_max[0,:], marker=dict(color="#444"), line=dict(width=0), mode='lines', name="ball y", fillcolor='rgba(200, 68, 10, 0.3)', fill='tonexty'),
        go.Scatter(y=x_traj_min[1,:], marker=dict(color="#444"), line=dict(width=0), mode='lines', name=""),
        go.Scatter(y=x_traj_max[1,:], marker=dict(color="#444"), line=dict(width=0), mode='lines', name="paddle y", fillcolor='rgba(38, 98, 68, 0.3)', fill='tonexty')]

# Reference rollouts from the centre and the two corners of the box, drawn in
# that order. A start with no trajectory is skipped instead of raising on None.
x_traj_mid = _rollout(.5 * (x0_lo + x0_up))
for x_traj in (x_traj_mid, x_traj_lo, x_traj_up):
    if x_traj is None:
        continue
    data += [go.Scatter(y=x_traj[0,:], name="ball y"),
             go.Scatter(y=x_traj[1,:], name="paddle y")]
plotly.offline.iplot(data)

# Generating Validation Data

In [None]:
# Sample the value function on a grid over the box [x0_lo, x0_up], using
# num_breaks_validation breaks per state dimension.
(x_samples_validation,
 v_samples_validation) = vf.get_value_sample_grid(
    x0_lo, x0_up, num_breaks_validation, update_progress=True)

In [None]:
# Write the validation grid to disk so it can be reloaded from file later.
x_validation_path = x_samples_file + '_validation.pt'
v_validation_path = v_samples_file + '_validation.pt'
torch.save(x_samples_validation, x_validation_path)
torch.save(v_samples_validation, v_validation_path)

# Or Loading It From File

In [None]:
# Read back a previously saved validation grid.
x_validation_path = x_samples_file + '_validation.pt'
v_validation_path = v_samples_file + '_validation.pt'
x_samples_validation = torch.load(x_validation_path)
v_samples_validation = torch.load(v_validation_path)

# Adversarial Training

In [None]:
# Trainer object; it exposes a robust model and a baseline model that are
# evaluated against the validation samples passed in here.
adv = train_value.AdversarialWithBaseline(
    vf, x0_lo, x0_up,
    nn_width=nn_width,
    nn_depth=nn_depth,
    x_samples_validation=x_samples_validation,
    v_samples_validation=v_samples_validation,
)
# for plotting: each list receives one entry per training run
adv_data_buffers, adv_label_buffers = [], []
rand_data_buffers, rand_label_buffers = [], []
robust_val_losses, baseline_val_losses = [], []
robust_losses, baseline_losses = [], []

In [None]:
for run_idx in range(num_training_run):
    # training
    adv.train(train_opt)
    # logging for plotting: snapshot the sample buffers and the per-sample
    # squared errors of both models after this run
    with torch.no_grad():
        adv_x, adv_v = adv.adv_data_buffer, adv.adv_label_buffer
        rand_x, rand_v = adv.rand_data_buffer, adv.rand_label_buffer
        # clones, so snapshots are not affected by later changes to the buffers
        adv_data_buffers.append(adv_x.clone())
        adv_label_buffers.append(adv_v.clone())
        rand_data_buffers.append(rand_x.clone())
        rand_label_buffers.append(rand_v.clone())

        # squared error on the validation set
        val_x = adv.x_samples_validation
        val_v = adv.v_samples_validation.squeeze()
        robust_val_losses.append((val_v - adv.robust_model(val_x).squeeze()) ** 2)
        baseline_val_losses.append((val_v - adv.baseline_model(val_x).squeeze()) ** 2)

        # squared error of each model on its own buffer (first label column)
        robust_losses.append((adv_v[:, 0].squeeze() - adv.robust_model(adv_x).squeeze()) ** 2)
        baseline_losses.append((rand_v[:, 0].squeeze() - adv.baseline_model(rand_x).squeeze()) ** 2)

In [None]:
# plotting options
# `import plotly` does not load the subplots submodule by itself; import it
# explicitly so plotly.subplots.make_subplots resolves on a fresh kernel.
import plotly.subplots

max_scale_down = .01  # colour-scale ceiling as a fraction of the max validation loss
ix = 0  # state index shown on the horizontal axis
# State index shown on the vertical axis; clamped so the cell also works for
# systems with fewer than four states (e.g. the double integrator).
iy = min(3, vf.sys.x_dim - 1)
lim_eps = .5  # margin added around the initial-state box on both axes


def _buffer_trace(points, point_losses, val_losses, visible):
    """Scatter of buffer samples (columns ix, iy) coloured by per-sample loss.

    The colour range comes from the validation losses of the same run:
    from their minimum up to max_scale_down times their maximum.
    """
    return go.Scatter(
        visible=visible,
        x=points[:, ix],
        y=points[:, iy],
        mode='markers',
        marker=dict(
            size=7,
            color=point_losses,
            colorscale='Viridis',
            cmin=torch.min(val_losses).item(),
            cmax=torch.max(val_losses).item() * max_scale_down,
            symbol=0,
            showscale=False))


fig = plotly.subplots.make_subplots(rows=1, cols=2, subplot_titles=("Adversarial Buffer", "Baseline Buffer"))
fig.update_layout(showlegend=False)
# axis ranges are passed as plain Python floats rather than 0-d tensors
fig.update_xaxes(showgrid=False, zeroline=True,
                 range=[x0_lo[ix].item() - lim_eps, x0_up[ix].item() + lim_eps])
fig.update_yaxes(showgrid=False, zeroline=True,
                 range=[x0_lo[iy].item() - lim_eps, x0_up[iy].item() + lim_eps])

# Two traces per logged run (adversarial buffer left, baseline buffer right);
# only the first run is visible initially. Nothing is indexed when no run has
# been logged, so an empty log yields an empty figure instead of an IndexError.
num_runs = len(adv_data_buffers)
for i in range(num_runs):
    show = (i == 0)
    fig.add_trace(_buffer_trace(adv_data_buffers[i], robust_losses[i],
                                robust_val_losses[i], show), row=1, col=1)
    fig.add_trace(_buffer_trace(rand_data_buffers[i], baseline_losses[i],
                                baseline_val_losses[i], show), row=1, col=2)

# One slider step per run: show that run's pair of traces, hide all others.
steps = []
for i in range(len(fig.data) // 2):
    visibility = [False] * len(fig.data)
    visibility[2*i] = True
    visibility[2*i+1] = True
    steps.append(dict(method="restyle", args=["visible", visibility]))
sliders = [dict(
    active=0,
    currentvalue={"prefix": "Step: "},
    pad={"t": 50},
    steps=steps
)]
fig.update_layout(
    sliders=sliders
)
fig.show()

In [None]:
# Advantage of the robust model after each training run: mean baseline
# validation loss minus mean robust validation loss (positive values mean the
# robust model has the lower validation loss).
with torch.no_grad():
    advantage = [
        baseline_val_losses[k].mean() - robust_val_losses[k].mean()
        for k in range(len(adv_data_buffers))
    ]
data = [go.Scatter(y=advantage, name="Advantage")]
plotly.offline.iplot(data)

# Saving Models

In [None]:
# Save both trained models; the model objects themselves are handed to
# torch.save, under the model_file prefix.
for model, suffix in ((adv.baseline_model, '_baseline_model.pt'),
                      (adv.robust_model, '_robust_model.pt')):
    torch.save(model, model_file + suffix)