In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to the imported
# project modules are picked up live without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("..")
sys.path.append(".")
import torch
import numpy as np
import copy
import plotly
import plotly.graph_objs as go
import pickle
from datetime import datetime
plotly.offline.init_notebook_mode(connected=True)

import robust_value_approx.samples_generator as samples_generator
import robust_value_approx.samples_buffer as samples_buffer
import robust_value_approx.value_approximation as value_approximation
import robust_value_approx.training_log as training_log
import robust_value_approx.controllers as controllers

import pendulum_utils
import double_pendulum_utils
import acrobot_utils



You may have already (directly or indirectly) imported `torch` which uses
`RTLD_GLOBAL`. Using `RTLD_GLOBAL` may cause symbol collisions which manifest
themselves in bugs like "free(): invalid pointer". Please consider importing
`pydrake` (and related C++-wrapped libraries like `cv2`, `open3d`, etc.)
*before* importing `torch`. For more details, see:
https://github.com/pytorch/pytorch/issues/3059#issuecomment-534676459




In [866]:
# Floating-point precision used for every tensor built in the experiment
# configuration of this cell (stored in the config under the 'dtype' key).
dtype = torch.float64

"""
sys_name = 'pendulum'
opt = dict(
    dtype = dtype,
    
    sys_name = sys_name,
    validation_file = '../data/validation_' + sys_name,
    init_file = '../data/init_' + sys_name,
    
    offline_horizon = 50,
    offline_dt = .1,

    include_time = True,

    online_horizon = 5,
    control_interp = "foh",

    sim_x0 = torch.tensor([np.pi+1.5, 1.], dtype=dtype),
    sim_dt = .1,
    sim_horizon = 50,
    
    x_nom = torch.tensor([np.pi, 0.], dtype=dtype),
    u_nom = torch.zeros(1, dtype=dtype),

    lqr_Q = torch.diag(torch.tensor([1., 1.], dtype=dtype)),
    lqr_R = torch.diag(torch.tensor([.1], dtype=dtype)),
    
    x0_lo = torch.tensor([np.pi-.5, -1.], dtype=dtype),
    x0_up = torch.tensor([np.pi+.5, 1.], dtype=dtype),
    
    batch_size = 20,
    
    learning_rate_value = 1e-3,
    learning_rate_policy = 1e-3,
    
    nn_width_value = 30,
    nn_depth_value = 2,
    
    nn_width_policy = 30,
    nn_depth_policy = 2,

    num_samples_validation = [10, 10],
    
    max_buffer_size = None,

    init_num_samples = [10, 10],
    init_num_trainig_step = 0,
    
    num_generations = 20,
    num_samples_per_generation = 20,
    num_train_step_per_gen = 500,
    
    adv_max_iter = 5,
    adv_conv_tol = 1e-5,
    adv_learning_rate = .1,

    # benchmark params
    bench_x0 = torch.tensor([np.pi, 0.], dtype=dtype),
    bench_x_goal = torch.tensor([np.pi, 0.], dtype=dtype),
    bench_x0_eps = torch.tensor([2., 3.], dtype=dtype),
    bench_num_breaks = [10, 10],
    
    sys_utils = pendulum_utils,
)
"""

"""
sys_name = 'double_pendulum'
opt = dict(
    dtype = dtype,
    
    sys_name = sys_name,
    validation_file = '../data/validation_' + sys_name,
    init_file = '../data/init_' + sys_name,
    
    offline_horizon = 50,
    offline_dt = .1,
    
    include_time = True,
    
    online_horizon = 5,
    control_interp = "foh",

    sim_x0 = torch.tensor([np.pi+.5, .5, 1., 1.], dtype=dtype),
    sim_dt = .1,
    sim_horizon = 50,
    
    x_nom = torch.tensor([np.pi, 0., 0., 0.], dtype=dtype),
    u_nom = torch.zeros(2, dtype=dtype),

    lqr_Q = torch.diag(torch.tensor([1., 1., 1., 1.], dtype=dtype)),
    lqr_R = torch.diag(torch.tensor([.1, .1], dtype=dtype)),
    
    x0_lo = torch.tensor([np.pi-.5, -.5, -1., -1.], dtype=dtype),
    x0_up = torch.tensor([np.pi+.5, .5, 1., 1.], dtype=dtype),
    
    batch_size = 30,
    
    learning_rate_value = 1e-3,
    learning_rate_policy = 1e-3,
    
    nn_width_value = 30,
    nn_depth_value = 2,
    
    nn_width_policy = 30,
    nn_depth_policy = 2,

    num_samples_validation = [3, 3, 3, 3],
    
    max_buffer_size = None,

    init_num_samples = [3, 3, 3, 3],
    init_num_trainig_step = 0,
    
    num_generations = 10,
    num_samples_per_generation = 20,
    num_train_step_per_gen = 500,
    
    adv_max_iter = 5,
    adv_conv_tol = 1e-5,
    adv_learning_rate = .1,

    # benchmark params
    bench_x0 = torch.tensor([np.pi, 0., 0., 0.], dtype=dtype),
    bench_x_goal = torch.tensor([np.pi, 0., 0., 0.], dtype=dtype),
    bench_x0_eps = torch.tensor([2., 2., 5., 5.], dtype=dtype),
    bench_num_breaks = [10, 10, 1, 1],
    
    sys_utils = double_pendulum_utils,
)
"""

# Active experiment configuration (acrobot). The pendulum and double-pendulum
# variants are kept as string-quoted blocks earlier in this cell.
sys_name = 'acrobot'
opt = {
    'dtype': dtype,

    # system identifier and path prefixes of the cached sample files
    'sys_name': sys_name,
    'validation_file': '../data/validation_' + sys_name,
    'init_file': '../data/init_' + sys_name,

    # horizon / time step handed to sys_utils.get_value_function (offline problem)
    'offline_horizon': 80,
    'offline_dt': .1,

    # whether time is part of the network input (alternative tried: True)
    'include_time': False,

    # horizon of the limited-lookahead controllers and the simulation interpolation mode
    'online_horizon': 5,
    'control_interp': "foh",

    # closed-loop simulation: start state, step, number of steps
    'sim_x0': torch.tensor([np.pi+.05, .05, 0., 0.], dtype=dtype),
    'sim_dt': .1,
    'sim_horizon': 80,

    # nominal point and weights passed to controllers.get_lqr_controller
    'x_nom': torch.tensor([np.pi, 0., 0., 0.], dtype=dtype),
    'u_nom': torch.zeros(1, dtype=dtype),
    'lqr_Q': torch.diag(torch.tensor([1., 1., 1., 1.], dtype=dtype)),
    'lqr_R': torch.diag(torch.tensor([.1], dtype=dtype)),

    # bounds given to the sample generators; the last two components have
    # zero width here (lower bound == upper bound == 0)
    'x0_lo': torch.tensor([np.pi-.1, -.1, 0., 0.], dtype=dtype),
    'x0_up': torch.tensor([np.pi+.1, .1, 0., 0.], dtype=dtype),

    'batch_size': 100,

    # optimiser settings of the value / policy approximators
    'learning_rate_value': 1e-3,
    'weight_decay_value': 0,
    'learning_rate_policy': 1e-3,
    'weight_decay_policy': 1e-3,

    # network sizes of the value / policy approximators
    'nn_width_value': 20,
    'nn_depth_value': 1,
    'nn_width_policy': 80,
    'nn_depth_policy': 2,

    # an int selects the random sample generator, a list the grid generator
    'num_samples_validation': 100,

    'max_buffer_size': None,
    'init_num_samples': 100,
    'init_num_trainig_step': 100,

    # generation loop (alternative tried: num_generations = 20)
    'num_generations': 5,
    'num_adv_samples_per_generation': 20,
    'num_rand_samples_per_generation': 5,
    'num_train_step_per_gen': 200,

    # settings of the adversarial sample generator
    'adv_max_iter': 5,
    'adv_conv_tol': 1e-5,
    'adv_learning_rate': .75,

    # benchmark params
    'bench_x0': torch.tensor([np.pi, 0., 0., 0.], dtype=dtype),
    'bench_x_goal': torch.tensor([np.pi, 0., 0., 0.], dtype=dtype),
    'bench_x0_eps': torch.tensor([.2, .2, 0., 0.], dtype=dtype),
    'bench_num_breaks': [20, 20, 1, 1],

    # project module providing get_value_function for this system
    'sys_utils': acrobot_utils,
}

# Derived options: with time as an input the horizon end time is recorded and
# the network input gains one extra dimension; otherwise final_time is None.
opt['final_time'] = opt['sim_dt']*opt['sim_horizon'] if opt['include_time'] else None
opt['input_dim'] = 1 + len(opt['x0_lo']) if opt['include_time'] else len(opt['x0_lo'])

In [725]:
# Build the offline value function and the system object for the selected system.
# NOTE(review): this rebinds the name `sys`, which the import cell bound to the
# standard-library module; later cells use `sys` as the system object (sys.dx).
vf, sys = opt['sys_utils'].get_value_function(
    opt['offline_horizon'],
    opt['offline_dt'],
    dtype=opt['dtype'],
)

In [496]:
# Evaluate the offline value function at the simulation start state and plot
# the result object it returns alongside the value.
V = vf.get_value_function()
v, res = V(opt['sim_x0'])
result_fig = sys.plot_result(res)
result_fig.show()

In [567]:
# LQR controller about the nominal point, limited to the first input bounds
# stored on the value function, then simulated from sim_x0.
ctrl, S = controllers.get_lqr_controller(
    sys.dx,
    opt['x_nom'], opt['u_nom'],
    opt['lqr_Q'], opt['lqr_R'],
    vf.u_lo[0], vf.u_up[0],
)
# ctrl = controllers.get_limited_lookahead_controller(opt['sys_utils'].get_value_function(opt['offline_horizon'],
#                                                                                         opt['offline_dt'],
#                                                                                         dtype=opt['dtype'])[0])
# NOTE(review): the interpolation mode is passed positionally here, whereas the
# other simulation cells pass it as integration_mode=...; confirm they agree.
x_traj_sim, t_traj_sim = controllers.sim_ctrl(
    opt['sim_x0'], vf.u_dim[0], sys.dx, ctrl,
    opt['sim_dt'], opt['sim_horizon'],
    opt['control_interp'],
)
controllers.plot_sim(t_traj_sim, x_traj_sim, "Full Horizon").show()

In [867]:
# Three sample generators over the same value function and initial-state box:
# random, grid, and adversarial (the latter with its own optimisation settings).
x0_box = (opt['x0_lo'], opt['x0_up'])
samples_gen_rand = samples_generator.RandomSampleGenerator(vf, *x0_box)
samples_gen_grid = samples_generator.GridSampleGenerator(vf, *x0_box)
samples_gen_adv = samples_generator.AdversarialSampleGenerator(
    vf, *x0_box,
    max_iter=opt['adv_max_iter'],
    conv_tol=opt['adv_conv_tol'],
    learning_rate=opt['adv_learning_rate'],
)

In [868]:
# Generate the validation set. A list selects the grid generator and an int the
# random generator; anything else is rejected explicitly instead of silently
# leaving x_validation / v_validation / u_validation stale or undefined.
num_validation = opt['num_samples_validation']
if isinstance(num_validation, list):
    validation_gen = samples_gen_grid
elif isinstance(num_validation, int):
    validation_gen = samples_gen_rand
else:
    raise TypeError("opt['num_samples_validation'] must be a list or an int, got "
                    + type(num_validation).__name__)

x_validation, v_validation, u_validation = validation_gen.generate_samples(
    num_validation,
    include_time=opt['include_time'],
    show_progress=True,
)

Progress: [########################################] 100.0%


In [656]:
# Cache the validation samples on disk, one file per tensor (states, values, inputs).
for file_suffix, tensor in (('_x.pt', x_validation),
                            ('_v.pt', v_validation),
                            ('_u.pt', u_validation)):
    torch.save(tensor, opt['validation_file'] + file_suffix)

In [None]:
# Reload the cached validation samples. Running this after the generation cell
# replaces the freshly generated tensors with whatever is on disk.
x_validation, v_validation, u_validation = (
    torch.load(opt['validation_file'] + file_suffix)
    for file_suffix in ('_x.pt', '_v.pt', '_u.pt')
)

In [869]:
# Generate the initial training set. A list selects the grid generator and an
# int the random generator; anything else is rejected explicitly instead of
# silently leaving x_init / v_init / u_init stale or undefined.
num_init = opt['init_num_samples']
if isinstance(num_init, list):
    init_gen = samples_gen_grid
elif isinstance(num_init, int):
    init_gen = samples_gen_rand
else:
    raise TypeError("opt['init_num_samples'] must be a list or an int, got "
                    + type(num_init).__name__)

x_init, v_init, u_init = init_gen.generate_samples(
    num_init,
    include_time=opt['include_time'],
    show_progress=True,
)

Progress: [########################################] 100.0%


In [658]:
# Cache the initial training samples on disk, one file per tensor (states, values, inputs).
for file_suffix, tensor in (('_x.pt', x_init),
                            ('_v.pt', v_init),
                            ('_u.pt', u_init)):
    torch.save(tensor, opt['init_file'] + file_suffix)

In [432]:
# Reload the cached initial samples. Running this after the generation cell
# replaces the freshly generated tensors with whatever is on disk.
x_init, v_init, u_init = (
    torch.load(opt['init_file'] + file_suffix)
    for file_suffix in ('_x.pt', '_v.pt', '_u.pt')
)

In [885]:
def _standardize(samples, mean, std):
    """Subtract `mean` and divide by `std + 1e-5`, both broadcast over dim 0.

    The small constant keeps the division finite for components whose
    standard deviation is zero.
    """
    return (samples - mean.unsqueeze(0)) / (std.unsqueeze(0) + 1e-5)


# Statistics come from the initial training set only (reduced over dim 0);
# the validation set is standardised with the same statistics.
x_mean, x_std = torch.mean(x_init, axis=0), torch.std(x_init, axis=0)
v_mean, v_std = torch.mean(v_init, axis=0), torch.std(v_init, axis=0)
u_mean, u_std = torch.mean(u_init, axis=0), torch.std(u_init, axis=0)

# *_i: standardised initial samples, *_v: standardised validation samples
x_i = _standardize(x_init, x_mean, x_std)
x_v = _standardize(x_validation, x_mean, x_std)
v_i = _standardize(v_init, v_mean, v_std)
v_v = _standardize(v_validation, v_mean, v_std)
u_i = _standardize(u_init, u_mean, u_std)
u_v = _standardize(u_validation, u_mean, u_std)

In [911]:
# Baseline value-function approximator: network, trainer wrapper and training log.
vf_model = value_approximation.NeuralNetworkModel(
    vf.dtype,
    opt['input_dim'],
    opt['nn_width_value'],
    opt['nn_depth_value'],
)
vf_approx = value_approximation.FunctionApproximation(
    vf_model,
    learning_rate=opt['learning_rate_value'],
    weight_decay=opt['weight_decay_value'],
)
vf_log = training_log.TrainingLog(prefix="value_baseline")

# Baseline policy approximator: output has the dimension of the first control
# input and is given that input's lower / upper bounds.
policy_model = value_approximation.NeuralNetworkModel(
    vf.dtype,
    opt['input_dim'],
    opt['nn_width_policy'],
    opt['nn_depth_policy'],
    dim_out=vf.u_dim[0],
    out_lo=vf.u_lo[0],
    out_up=vf.u_up[0],
)
policy_approx = value_approximation.FunctionApproximation(
    policy_model,
    learning_rate=opt['learning_rate_policy'],
    weight_decay=opt['weight_decay_policy'],
)
policy_log = training_log.TrainingLog(prefix="policy_baseline")

In [912]:
# Buffer holding (input, value, control) samples for training.
# The literal 1 is presumably the value dimension (scalar cost-to-go); confirm
# against SamplesBuffer's signature.
samples_buff = samples_buffer.SamplesBuffer(
    opt['input_dim'],
    1,
    vf.u_dim[0],
    vf.dtype,
    max_size=opt['max_buffer_size'],
)

In [898]:
# Disabled: adding the raw (un-normalised) samples here would mix them with the
# normalised samples (x_i, v_i, u_i) added in the next cell, while training,
# validation (x_v, v_v, u_v) and the policy controllers all work on normalised
# data. Re-enable only if the whole pipeline is switched back to raw samples.
# samples_buff.add_samples(x_init, v_init, u_init)

In [913]:
# Seed the buffer with the normalised initial samples (inputs, values, controls).
samples_buff.add_samples(x_i, v_i, u_i)

In [918]:
# Initial training shared by the baseline and adversarial runs: each step draws
# one mini-batch from the buffer, updates both approximators, and records the
# loss on the normalised validation set.
# (Un-normalised alternative: validate on x_validation / v_validation / u_validation.)
for train_step_i in range(opt['init_num_trainig_step']):
    x, v, u = samples_buff.get_random_samples(opt['batch_size'])

    vf_log.add_train_loss(vf_approx.train_step(x, v))
    policy_log.add_train_loss(policy_approx.train_step(x, u))

    vf_log.add_validation_loss(vf_approx.validation_loss(x_v, v_v))
    policy_log.add_validation_loss(policy_approx.validation_loss(x_v, u_v))

In [863]:
# Fork the current buffer, approximators and logs so the adversarial run starts
# from the same state as the baseline run.
samples_buff_adv = copy.deepcopy(samples_buff)

vf_approx_adv = copy.deepcopy(vf_approx)
vf_log_adv = training_log.TrainingLog.get_copy(
    vf_log, prefix="value_adversarial", keep_writer=True)

policy_approx_adv = copy.deepcopy(policy_approx)
policy_log_adv = training_log.TrainingLog.get_copy(
    policy_log, prefix="policy_adversarial", keep_writer=True)

In [921]:
# Baseline run: every generation draws fresh *random* samples and trains on them
# together with a few samples replayed from the buffer.
# (Un-normalised alternative: validate on x_validation / v_validation / u_validation.)
loss = vf_approx.validation_loss(x_v, v_v)
vf_log.add_validation_loss(loss)
loss = policy_approx.validation_loss(x_v, u_v)
policy_log.add_validation_loss(loss)

for gen_i in range(opt['num_generations']):

    # Same sample count as the adversarial run uses per generation.
    (x, v, u) = samples_gen_rand.generate_samples(opt['num_adv_samples_per_generation'],
                                                  include_time=opt['include_time'],
                                                  show_progress=True)

    # Normalise exactly like the initial / validation samples. The 1e-5 term is
    # required: a component with zero spread in x_init (e.g. lower bound equal
    # to upper bound) has zero std, and dividing by it yields NaN inputs.
    x_ = (x - x_mean.unsqueeze(0)) / (x_std.unsqueeze(0) + 1e-5)
    v_ = (v - v_mean.unsqueeze(0)) / (v_std.unsqueeze(0) + 1e-5)
    u_ = (u - u_mean.unsqueeze(0)) / (u_std.unsqueeze(0) + 1e-5)

    for train_step_i in range(opt['num_train_step_per_gen']):
        # Replay a few (already normalised) samples from the buffer each step.
        x_ran, v_ran, u_ran = samples_buff.get_random_samples(opt['num_rand_samples_per_generation'])

        loss = vf_approx.train_step(torch.cat([x_, x_ran]), torch.cat([v_, v_ran]))
        vf_log.add_train_loss(loss)

        loss = policy_approx.train_step(torch.cat([x_, x_ran]), torch.cat([u_, u_ran]))
        policy_log.add_train_loss(loss)

        loss = vf_approx.validation_loss(x_v, v_v)
        vf_log.add_validation_loss(loss)

        loss = policy_approx.validation_loss(x_v, u_v)
        policy_log.add_validation_loss(loss)
    
# #     samples_buff.add_samples(x, v, u)
#     samples_buff.add_samples(x_, v_, u_)

#     for train_step_i in range(opt['num_train_step_per_gen']):
#         samples_indices = samples_buff.get_random_sample_indices(opt['batch_size'])
#         x, v, u = samples_buff.get_samples_from_indices(samples_indices)
    
#         loss = vf_approx.train_step(x, v)
#         vf_log.add_train_loss(loss)

#         loss = policy_approx.train_step(x, u)
#         policy_log.add_train_loss(loss)
    
# #     loss = vf_approx.validation_loss(x_validation, v_validation)
#     loss = vf_approx.validation_loss(x_v, v_v)
#     vf_log.add_validation_loss(loss)
# #     loss = policy_approx.validation_loss(x_validation, u_validation)
#     loss = policy_approx.validation_loss(x_v, u_v)
#     policy_log.add_validation_loss(loss)

Progress: [########################################] 100.0%


In [None]:
# Adversarial run: every generation asks the adversarial generator for samples
# (it is given the current value approximation) and trains on them together
# with a few samples replayed from the adversarial buffer.
# (Un-normalised alternative: validate on x_validation / v_validation / u_validation.)
loss = vf_approx_adv.validation_loss(x_v, v_v)
vf_log_adv.add_validation_loss(loss)
loss = policy_approx_adv.validation_loss(x_v, u_v)
policy_log_adv.add_validation_loss(loss)

for gen_i in range(opt['num_generations']):

    (x, v, u) = samples_gen_adv.generate_samples(opt['num_adv_samples_per_generation'],
                                                 vf_approx_adv,
                                                 include_time=opt['include_time'],
                                                 show_progress=True)

    # Normalise exactly like the initial / validation samples. The 1e-5 term is
    # required: a component with zero spread in x_init (e.g. lower bound equal
    # to upper bound) has zero std, and dividing by it yields NaN inputs.
    x_ = (x - x_mean.unsqueeze(0)) / (x_std.unsqueeze(0) + 1e-5)
    v_ = (v - v_mean.unsqueeze(0)) / (v_std.unsqueeze(0) + 1e-5)
    u_ = (u - u_mean.unsqueeze(0)) / (u_std.unsqueeze(0) + 1e-5)

    for train_step_i in range(opt['num_train_step_per_gen']):
        # Replay a few (already normalised) samples from the buffer each step.
        x_ran, v_ran, u_ran = samples_buff_adv.get_random_samples(opt['num_rand_samples_per_generation'])

        loss = vf_approx_adv.train_step(torch.cat([x_, x_ran]), torch.cat([v_, v_ran]))
        vf_log_adv.add_train_loss(loss)

        loss = policy_approx_adv.train_step(torch.cat([x_, x_ran]), torch.cat([u_, u_ran]))
        policy_log_adv.add_train_loss(loss)

        loss = vf_approx_adv.validation_loss(x_v, v_v)
        vf_log_adv.add_validation_loss(loss)

        loss = policy_approx_adv.validation_loss(x_v, u_v)
        policy_log_adv.add_validation_loss(loss)
    
# #     samples_buff_adv.add_samples(x, v, u)
#     samples_buff_adv.add_samples(x_, v_, u_)

#     for train_step_i in range(opt['num_train_step_per_gen']):
#         samples_indices = samples_buff_adv.get_random_sample_indices(opt['batch_size'])
#         x, v, u = samples_buff_adv.get_samples_from_indices(samples_indices)
    
#         loss = vf_approx_adv.train_step(x, v)
#         vf_log_adv.add_train_loss(loss)

#         loss = policy_approx_adv.train_step(x, u)
#         policy_log_adv.add_train_loss(loss)
    
# #     loss = vf_approx_adv.validation_loss(x_validation, v_validation)
#     loss = vf_approx_adv.validation_loss(x_v, v_v)
#     vf_log_adv.add_validation_loss(loss)
# #     loss = policy_approx_adv.validation_loss(x_validation, u_validation)
#     loss = policy_approx_adv.validation_loss(x_v, u_v)
#     policy_log_adv.add_validation_loss(loss)

In [None]:
# Limited-lookahead controller over the short online horizon, with no learned
# model supplied, simulated from sim_x0.
vf_online = opt['sys_utils'].get_value_function(
    opt['online_horizon'], opt['offline_dt'], dtype=opt['dtype'])[0]
ctrl_no_model = controllers.get_limited_lookahead_controller(vf_online)

x_traj_sim, t_traj_sim = controllers.sim_ctrl(
    opt['sim_x0'], vf.u_dim[0], sys.dx,
    ctrl_no_model, opt['sim_dt'], opt['sim_horizon'],
    integration_mode=opt['control_interp'],
)
controllers.plot_sim(t_traj_sim, x_traj_sim, "No lookahead").show()

In [None]:
# Limited-lookahead controller over the short online horizon, given the
# baseline value approximation and the final time, simulated from sim_x0.
vf_online = opt['sys_utils'].get_value_function(
    opt['online_horizon'], opt['offline_dt'], dtype=opt['dtype'])[0]
ctrl_baseline = controllers.get_limited_lookahead_controller(
    vf_online, vf_approx, opt['final_time'])

x_traj_sim, t_traj_sim = controllers.sim_ctrl(
    opt['sim_x0'], vf.u_dim[0], sys.dx,
    ctrl_baseline, opt['sim_dt'], opt['sim_horizon'],
    integration_mode=opt['control_interp'],
)
controllers.plot_sim(t_traj_sim, x_traj_sim, "Baseline").show()

In [None]:
# Limited-lookahead controller over the short online horizon, given the
# adversarially trained value approximation and the final time, simulated
# from sim_x0.
vf_online = opt['sys_utils'].get_value_function(
    opt['online_horizon'], opt['offline_dt'], dtype=opt['dtype'])[0]
ctrl_adv = controllers.get_limited_lookahead_controller(
    vf_online, vf_approx_adv, opt['final_time'])

x_traj_sim, t_traj_sim = controllers.sim_ctrl(
    opt['sim_x0'], vf.u_dim[0], sys.dx,
    ctrl_adv, opt['sim_dt'], opt['sim_horizon'],
    integration_mode=opt['control_interp'],
)
controllers.plot_sim(t_traj_sim, x_traj_sim, "Adversarial").show()

In [920]:
# Controller built from the baseline policy network. The normalisation
# statistics are passed along with it, since the network was trained on
# normalised inputs and targets.
ctrl_policy = controllers.get_learned_policy_controller(
    sys.dx,
    policy_approx,
    opt['final_time'],
    opt['sim_dt'],
    x_mean, x_std,
    u_mean, u_std,
)

# Variant without normalisation statistics:
# ctrl_policy = controllers.get_learned_policy_controller(sys.dx,
#                                                         policy_approx,
#                                                         opt['final_time'],
#                                                         opt['sim_dt'])

x_traj_sim, t_traj_sim = controllers.sim_ctrl(
    opt['sim_x0'], vf.u_dim[0], sys.dx,
    ctrl_policy, opt['sim_dt'], opt['sim_horizon'],
    integration_mode=opt['control_interp'],
)

# Variant starting exactly at the upright state:
# x_traj_sim, t_traj_sim = controllers.sim_ctrl(torch.tensor([np.pi, 0., 0., 0.], dtype=vf.dtype), vf.u_dim[0], sys.dx,
#                                               ctrl_policy, opt['sim_dt'], opt['sim_horizon'],
#                                               integration_mode=opt['control_interp'])

controllers.plot_sim(t_traj_sim, x_traj_sim, "Policy baseline").show()

In [733]:
# Controller built from the adversarially trained policy network, with the
# same normalisation statistics as the baseline policy controller.
ctrl_policy_adv = controllers.get_learned_policy_controller(sys.dx,
                                                            policy_approx_adv,
                                                            opt['final_time'],
                                                            opt['sim_dt'],
                                                            x_mean,
                                                            x_std,
                                                            u_mean,
                                                            u_std)
# Simulate the adversarial controller itself. Previously this call was handed
# ctrl_policy, so the "Policy adversarial" plot showed the baseline policy.
x_traj_sim, t_traj_sim = controllers.sim_ctrl(opt['sim_x0'], vf.u_dim[0], sys.dx,
                                              ctrl_policy_adv, opt['sim_dt'], opt['sim_horizon'],
                                              integration_mode=opt['control_interp'])
controllers.plot_sim(t_traj_sim, x_traj_sim, "Policy adversarial").show()

In [None]:
# ctrl_policy_adv = controllers.get_learned_policy_controller(policy_approx_adv,
#                                                             opt['final_time'],
#                                                             opt['sim_dt'])
# x_traj_sim, t_traj_sim = controllers.sim_ctrl(opt['sim_x0'], vf.u_dim[0], sys.dx,
#                                               ctrl_policy_adv, opt['sim_dt'], opt['sim_horizon'],
#                                               integration_mode=opt['control_interp'])
# controllers.plot_sim(t_traj_sim, x_traj_sim, "Policy adversarial").show()

In [746]:
# LQR controller (same construction as the simulation cell) benchmarked with
# the shared benchmark options.
ctrl_lqr, S = controllers.get_lqr_controller(
    sys.dx,
    opt['x_nom'], opt['u_nom'],
    opt['lqr_Q'], opt['lqr_R'],
    vf.u_lo[0], vf.u_up[0],
)

# Positional benchmark arguments: centre state, per-dimension half-width,
# per-dimension number of breaks, goal state, time step, number of steps.
bench_args = (opt['bench_x0'], opt['bench_x0_eps'], opt['bench_num_breaks'],
              opt['bench_x_goal'], opt['sim_dt'], opt['sim_horizon'])
bench_lqr = controllers.benchmark_controller(
    vf.u_dim[0], sys.dx, ctrl_lqr, *bench_args,
    integration_mode=opt['control_interp'])



In [None]:
# Benchmark the look-ahead controller that has no learned model.
bench_args = (opt['bench_x0'], opt['bench_x0_eps'], opt['bench_num_breaks'],
              opt['bench_x_goal'], opt['sim_dt'], opt['sim_horizon'])
bench_no_model = controllers.benchmark_controller(
    vf.u_dim[0], sys.dx, ctrl_no_model, *bench_args,
    integration_mode=opt['control_interp'])

In [None]:
# Benchmark the look-ahead controller that uses the baseline value approximation.
bench_args = (opt['bench_x0'], opt['bench_x0_eps'], opt['bench_num_breaks'],
              opt['bench_x_goal'], opt['sim_dt'], opt['sim_horizon'])
bench_baseline = controllers.benchmark_controller(
    vf.u_dim[0], sys.dx, ctrl_baseline, *bench_args,
    integration_mode=opt['control_interp'])

In [None]:
# Benchmark the look-ahead controller that uses the adversarial value approximation.
bench_args = (opt['bench_x0'], opt['bench_x0_eps'], opt['bench_num_breaks'],
              opt['bench_x_goal'], opt['sim_dt'], opt['sim_horizon'])
bench_adv = controllers.benchmark_controller(
    vf.u_dim[0], sys.dx, ctrl_adv, *bench_args,
    integration_mode=opt['control_interp'])

In [747]:
# Benchmark the baseline learned-policy controller. dim1 / dim2 are passed
# explicitly here (0 and 1), unlike the LQR and look-ahead benchmark cells.
bench_args = (opt['bench_x0'], opt['bench_x0_eps'], opt['bench_num_breaks'],
              opt['bench_x_goal'], opt['sim_dt'], opt['sim_horizon'])
bench_policy = controllers.benchmark_controller(
    vf.u_dim[0], sys.dx, ctrl_policy, *bench_args,
    integration_mode=opt['control_interp'],
    dim1=0,
    dim2=1)

In [748]:
# Benchmark the adversarial learned-policy controller. dim1 / dim2 are passed
# explicitly here (0 and 1), unlike the LQR and look-ahead benchmark cells.
bench_args = (opt['bench_x0'], opt['bench_x0_eps'], opt['bench_num_breaks'],
              opt['bench_x_goal'], opt['sim_dt'], opt['sim_horizon'])
bench_policy_adv = controllers.benchmark_controller(
    vf.u_dim[0], sys.dx, ctrl_policy_adv, *bench_args,
    integration_mode=opt['control_interp'],
    dim1=0,
    dim2=1)

In [811]:
# Colour range shared by the raw benchmark heatmaps, and figure size in pixels.
zmin = 0.
zmax = 500
width = 500
height = 500

# Axis ticks: the benchmark grid along the first two state dimensions
# (centre +/- half-width, with the configured number of breaks).
x_ticks = torch.linspace(opt['bench_x0'][0] - opt['bench_x0_eps'][0], opt['bench_x0'][0] + opt['bench_x0_eps'][0], opt['bench_num_breaks'][0])
y_ticks = torch.linspace(opt['bench_x0'][1] - opt['bench_x0_eps'][1], opt['bench_x0'][1] + opt['bench_x0_eps'][1], opt['bench_num_breaks'][1])
x_label = r'$\theta_1$'
y_label = r'$\theta_2$'


def _show_bench_heatmap(z, title, z_lo, z_hi):
    """Display `z` as a heatmap over the benchmark grid and return the figure.

    Args:
        z: 2-D values to plot (anything go.Heatmap accepts).
        title: figure title.
        z_lo, z_hi: limits of the colour scale.

    Reads x_ticks, y_ticks, x_label, y_label, width and height from the
    notebook namespace at call time.
    """
    heatmap_fig = go.Figure()
    heatmap_fig.add_trace(go.Heatmap(
        x = x_ticks,
        y = y_ticks,
        z = z,
        zmin = z_lo,
        zmax = z_hi,
        ))
    heatmap_fig.update_layout(
        title=title,
        width=width,
        height=height,
        xaxis=dict(title=x_label),
        yaxis=dict(title=y_label),
    )
    heatmap_fig.show()
    return heatmap_fig


# Raw benchmark results per controller (`fig` always holds the last figure shown).
fig = _show_bench_heatmap(bench_lqr.detach().numpy(), "LQR", zmin, zmax)
fig = _show_bench_heatmap(bench_policy.detach().numpy(), "Policy", zmin, zmax)
fig = _show_bench_heatmap(bench_policy_adv.detach().numpy(), "Policy adversarial", zmin, zmax)

# Adversarial minus baseline, on a tight symmetric colour range.
policy_gap = bench_policy_adv.detach().numpy() - bench_policy.detach().numpy()
fig = _show_bench_heatmap(policy_gap, "Policy difference (negative is good)", -.05, .05)

# Binary map of the cells where the gap is below .02.
fig = _show_bench_heatmap(torch.Tensor(policy_gap < .02).type(vf.dtype),
                          "Policy improvement", 0, 1)

In [807]:
# A benchmark cell counts as converged when its value is below this threshold.
conv_thresh = .005
width = 500
height = 500

# Axis ticks: the benchmark grid along the first two state dimensions
# (centre +/- half-width, with the configured number of breaks).
x_ticks = torch.linspace(opt['bench_x0'][0] - opt['bench_x0_eps'][0], opt['bench_x0'][0] + opt['bench_x0_eps'][0], opt['bench_num_breaks'][0])
y_ticks = torch.linspace(opt['bench_x0'][1] - opt['bench_x0_eps'][1], opt['bench_x0'][1] + opt['bench_x0_eps'][1], opt['bench_num_breaks'][1])
x_label = r'$\theta_1$'
y_label = r'$\theta_2$'


def _show_convergence_heatmap(bench, title):
    """Display a two-colour map of where `bench` is below conv_thresh; return the figure.

    Args:
        bench: benchmark result tensor (detached and converted to numpy here).
        title: figure title.

    Cells below the threshold map to 1 (light green), the others to 0 (red).
    Reads conv_thresh, x_ticks, y_ticks, x_label, y_label, width and height
    from the notebook namespace at call time.
    """
    converged = torch.Tensor(bench.detach().numpy() < conv_thresh).type(vf.dtype)
    conv_fig = go.Figure()
    conv_fig.add_trace(go.Heatmap(
        x = x_ticks,
        y = y_ticks,
        z = converged,
        zmin = 0,
        zmax = 1,
        colorscale=[
                    [0., "#c70039"],
                    [1., "#c7f0db"],
                    ],
        showscale=False,
        ))
    conv_fig.update_layout(
        title=title,
        width=width,
        height=height,
        xaxis=dict(title=x_label),
        yaxis=dict(title=y_label),
    )
    conv_fig.show()
    return conv_fig


# One convergence map per controller (`fig` always holds the last figure shown).
fig = _show_convergence_heatmap(bench_lqr, "LQR")
fig = _show_convergence_heatmap(bench_policy, "Policy")
fig = _show_convergence_heatmap(bench_policy_adv, "Policy adversarial")

In [None]:
# fig = go.Figure()
# fig.add_trace(go.Heatmap(
#     z = bench_no_model.detach().numpy(),
#     zmin = zmin,
#     zmax = zmax,
#     ))
# fig.update_layout(
#     title="No Model",
# )
# fig.show()

# fig = go.Figure()
# fig.add_trace(go.Heatmap(
#     z = bench_baseline.detach().numpy(),
#     zmin = zmin,
#     zmax = zmax,
#     ))
# fig.update_layout(
#     title="Baseline",
# )
# fig.show()

# fig = go.Figure()
# fig.add_trace(go.Heatmap(
#     z = bench_adv.detach().numpy(),
#     zmin = zmin,
#     zmax = zmax,
#     ))
# fig.update_layout(
#     title="Adversarial",
# )
# fig.show()

# fig = go.Figure()
# fig.add_trace(go.Heatmap(
#     z = torch.Tensor(bench_no_model.detach().numpy() < conv_thresh).type(vf.dtype),
#     ))
# fig.update_layout(
#     title="No Model",
# )
# fig.show()

# fig = go.Figure()
# fig.add_trace(go.Heatmap(
#     z = torch.Tensor(bench_baseline.detach().numpy() < conv_thresh).type(vf.dtype),
#     zmin = 0,
#     zmax = 1,
#     ))
# fig.update_layout(
#     title="Baseline",
# )
# fig.show()

# fig = go.Figure()
# fig.add_trace(go.Heatmap(
#     z = torch.Tensor(bench_adv.detach().numpy() < conv_thresh).type(vf.dtype),
#     zmin = 0,
#     zmax = 1,
#     ))
# fig.update_layout(
#     title="Adversarial",
# )
# fig.show()