In [None]:
# Load the autoreload extension and enable mode 2 so that edits to the
# imported project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
from classic_envs.random_integrator import DiscRandomIntegratorEnv
from lyapunov_reachability.speculation_tabular import DefaultQAgent, ExplorerQAgent, LyapunovQAgent
from lyapunov_reachability.shortest_path.dp import SimplePIAgent
from gridworld.utils import test, play, visualize

from cplex.exceptions import CplexSolverError
import matplotlib.pyplot as plt
import torch.nn as nn
import numpy as np
import pickle
import os

# Integrator

In [None]:
# DO NOT CHANGE THIS ----
# Both values are forwarded to the environment constructor and are baked
# into the experiment directory name below.
n = 10
grid_points = 40
# You can change them ----
episode_length = 1000   # passed to every Q-agent run() call
confidence = 0.8        # passed to every Q-agent constructor
strict_done = True      # forwarded to the per-seed agents
safe_init = True        # forwarded to the per-seed agents

# Root directory name for everything this notebook saves ('10-40-integrator').
name = '-'.join((str(n), str(grid_points), 'integrator'))
# NOTE(review): seed=None presumably leaves the environment unseeded -- confirm.
env = DiscRandomIntegratorEnv(n=n, grid_points=grid_points, seed=None)
# Positional arguments that get unpacked into each Q-agent constructor.
configure_light = env.speculate_light()

In [None]:
# Schedule shared by the per-seed runs; every count is handed to run().
steps = 10 ** 8
improve_interval = 10 ** 6
log_interval = 5 * 10 ** 6
save_interval = 5 * 10 ** 6
# Extra keyword arguments unpacked into each per-seed run() call.
auxiliary_args = dict(
    learning_rate=.1,
    gamma=1. - 1e-4,
    epsilon=0.5,
    epsilon_decay=2e-8,
    epsilon_last=0.1,
)

In [None]:
# Twenty consecutive seeds (1001..1020); each one labels its own run directory.
seeds = [1001 + offset for offset in range(20)]

In [None]:
# To run without a baseline, use these two lines instead of the assignments below:
#baseline_dir = None
#baseline_step = None

# ----
# Directory the initial tabular agent is saved to, and the step count the
# per-seed agents receive as `baseline_step`.
baseline_dir = os.path.join(name, 'tabular-initial')
baseline_step = 5 * 10 ** 6

In [None]:
# Hyper-parameters used only for the baseline agent's run() call.
def_args = dict(learning_rate=.1, gamma=1. - 1e-6, epsilon=0.5, epsilon_decay=1e-6, epsilon_last=0.1)

# Baseline agent, configured to save into `baseline_dir`.
baseline_agent = DefaultQAgent(
    env, confidence, *configure_light,
    seed=1234, strict_done=True, safe_init=False, save_dir=baseline_dir,
)
# Scale every entry of the table in place by a random factor in (0.9, 1.0].
# NOTE(review): the factor comes from NumPy's global RNG, so it is only
# reproducible if that RNG is seeded elsewhere -- confirm.
shrink = 1. - np.random.random(baseline_agent.reachability_q.shape) * 1e-1
baseline_agent.reachability_q *= shrink
baseline_stats = baseline_agent.run(
    baseline_step, episode_length,
    improve_interval=10 ** 6, log_interval=10 ** 6, save_interval=5 * 10 ** 6,
    **def_args
)
# Persist whatever run() returned next to the agent's own saved files.
log_path = os.path.join(baseline_dir, 'log.pkl')
with open(log_path, 'wb') as log_file:
    pickle.dump(baseline_stats, log_file, pickle.HIGHEST_PROTOCOL)
# Drop the temporaries so they do not linger in the kernel.
del baseline_agent, baseline_stats, shrink, log_path

log	:: steps=1000000, episode_safety=0.21, episode_runtime=102.13<br>
log	:: steps=2000000, episode_safety=0.72, episode_runtime=173.26<br>
log	:: steps=3000000, episode_safety=0.98, episode_runtime=146.23<br>
log	:: steps=4000000, episode_safety=0.89, episode_runtime=126.14<br>
log	:: steps=5000000, episode_safety=0.85, episode_runtime=113.51<br>
chart	:: safe_set_size=0.05125

### Check out the "answer" first.

In [None]:
# Compute the reference ("answer") safety values with SimplePIAgent
# (presumably policy iteration -- confirm against the agent's source).
# speculate() is given ten times the training episode length.
configure = env.speculate(grid_points=grid_points, confidence=confidence, episode_length=episode_length*10)
simple_pi = SimplePIAgent(env, *configure,)
# The return value of run() is not used; the first argument (10) is
# presumably the number of iterations -- confirm.
_ = simple_pi.run(10, print_freq=1, verbose=False, name=os.path.join(name, 'answer'),)
# Reference values that the plotting cell reshapes onto the grid.
safety_v = simple_pi.safety_v

In [None]:
# Alternative to the previous cell: reuse a stored answer instead of
# recomputing it. The commented-out lines are the recomputation; note that
# they pass 5 to run() whereas the previous cell passes 10.
# configure = env.speculate(grid_points=grid_points, confidence=confidence, episode_length=episode_length*10)
# simple_pi = SimplePIAgent(env, *configure,)
# _ = simple_pi.run(5, print_freq=1, verbose=False, name=os.path.join(name, 'answer'),)
# safety_v = simple_pi.safety_v

# NOTE(review): answer.npz is presumably written by SimplePIAgent.run(name=...) -- confirm.
# From here on `simple_pi` is the object returned by np.load(), not the agent.
simple_pi = np.load(os.path.join(name, 'answer', 'answer.npz'))
safety_v = simple_pi['safety_v']

In [None]:
# Heat-map of the reference safety values laid out on the
# grid_points x grid_points grid, saved as an EPS file under `name`.
fig, ax = plt.subplots(1, 1)
img = ax.imshow(
    safety_v.reshape((grid_points, grid_points)),
    cmap='plasma',
    extent=[.5, -.5, -1., 1.],
    aspect=.5,
)
# Pin the colour scale to [0, 1] regardless of the data range.
img.set_clim(0., 1.)
ax.set_xlabel('Velocity')
ax.set_ylabel('Position')
fig.colorbar(img)
fig.savefig(os.path.join(name, 'answer.eps'), format='eps', dpi=300)

## Get maximal safe set without exploration.

In [None]:
# One DefaultQAgent run per seed; each run is configured with the baseline
# directory/step and saves into its own 'spec-tb-default-<seed>' directory.
for seed in seeds:
    run_dir = os.path.join(name, 'spec-tb-default-{}'.format(seed))
    # Pass the loop's seed to the agent. Previously `seed` only appeared in
    # the directory name, so every run used the constructor's default seed.
    def_q = DefaultQAgent(env, confidence, *configure_light, seed=seed,
                          strict_done=strict_done, safe_init=safe_init,
                          baseline_dir=baseline_dir, baseline_step=baseline_step,
                          save_dir=run_dir)
    def_stats = def_q.run(steps, episode_length, improve_interval=improve_interval,
                          log_interval=log_interval, save_interval=save_interval, **auxiliary_args)
    # Store whatever run() returned alongside the agent's own saved files.
    with open(os.path.join(run_dir, 'log.pkl'), 'wb') as f:
        pickle.dump(def_stats, f, pickle.HIGHEST_PROTOCOL)
    # Release the agent and its statistics before the next seed starts.
    del def_q
    del def_stats

## Get maximal safe set with Lyapunov (still no exploration)

In [None]:
# One LyapunovQAgent run per seed; each run is configured with the baseline
# directory/step and saves into its own 'spec-tb-lyapunov-<seed>' directory.
for seed in seeds:
    run_dir = os.path.join(name, 'spec-tb-lyapunov-{}'.format(seed))
    # Pass the loop's seed to the agent. Previously `seed` only appeared in
    # the directory name, so every run used the constructor's default seed.
    # NOTE(review): assumes LyapunovQAgent accepts the same `seed` keyword
    # that DefaultQAgent does -- confirm.
    lyap_q = LyapunovQAgent(env, confidence, *configure_light, seed=seed,
                            strict_done=strict_done, safe_init=safe_init,
                            baseline_dir=baseline_dir, baseline_step=baseline_step,
                            save_dir=run_dir)
    lyap_stats = lyap_q.run(steps, episode_length, improve_interval=improve_interval,
                            log_interval=log_interval, save_interval=save_interval, **auxiliary_args)
    # Store whatever run() returned alongside the agent's own saved files.
    with open(os.path.join(run_dir, 'log.pkl'), 'wb') as f:
        pickle.dump(lyap_stats, f, pickle.HIGHEST_PROTOCOL)
    # Release the agent and its statistics before the next seed starts.
    del lyap_q
    del lyap_stats

## Get maximal safe set with MSE.

In [None]:
# One ExplorerQAgent run per seed; each run is configured with the baseline
# directory/step and saves into its own 'spec-tb-explorer-<seed>' directory.
for seed in seeds:
    run_dir = os.path.join(name, 'spec-tb-explorer-{}'.format(seed))
    # Pass the loop's seed to the agent. Previously `seed` only appeared in
    # the directory name, so every run used the constructor's default seed.
    # NOTE(review): assumes ExplorerQAgent accepts the same `seed` keyword
    # that DefaultQAgent does -- confirm.
    exp_q = ExplorerQAgent(env, confidence, *configure_light, seed=seed,
                           strict_done=strict_done, safe_init=safe_init,
                           baseline_dir=baseline_dir, baseline_step=baseline_step,
                           save_dir=run_dir)
    exp_stats = exp_q.run(steps, episode_length, improve_interval=improve_interval,
                          log_interval=log_interval, save_interval=save_interval, **auxiliary_args)
    # Store whatever run() returned alongside the agent's own saved files.
    with open(os.path.join(run_dir, 'log.pkl'), 'wb') as f:
        pickle.dump(exp_stats, f, pickle.HIGHEST_PROTOCOL)
    # Release the agent and its statistics before the next seed starts.
    del exp_q
    del exp_stats