In [None]:
%load_ext autoreload
%autoreload 2
#%matplotlib inline

In [None]:
from classic_envs.random_integrator import DiscRandomIntegratorEnv
from lyapunov_reachability.speculation_tabular import DefaultQAgent, ExplorerQAgent, LyapunovQAgent
from lyapunov_reachability.shortest_path.dp import SimplePIAgent
from gridworld.utils import test, play, visualize

from cplex.exceptions import CplexSolverError
import matplotlib
matplotlib.rcParams['figure.dpi'] = 300
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
import matplotlib.pyplot as plt
import torch.nn as nn
import numpy as np
import pickle
import os
import seaborn as sns
sns.set()

In [None]:
# Experiment identity. `name` is the results directory and is built from
# `n`, `grid_points` and `env_name`, so changing any of them changes where
# every later cell reads from and writes to.
env_name = 'integrator'
# DO NOT CHANGE THIS ----
n, grid_points = 10, 40
# You can change them ----
episode_length, confidence = 200, 0.8
name = '{}-{}-{}'.format(n, grid_points, env_name)

In [None]:
# Reference ("answer") safety values. A state is treated as truly safe when
# its value is at least `confidence`.
# The context manager closes the .npz archive as soon as the array is read.
with np.load(os.path.join(name, 'answer', 'answer.npz')) as ans_data:
    ans = ans_data['safety_v']
del ans_data
# Number of truly safe states; used to normalise the "safe states found" curves.
max_safe_set = np.sum(ans >= confidence)

### Load baseline log files

In [None]:
baseline_dir = os.path.join(name, 'tabular-initial')
baseline_step = int(5e6)

# Statistics of the starting table that every run is compared against.
# The context manager closes the .npz archive once the array is extracted.
with np.load(os.path.join(baseline_dir, '{}.npz'.format(int(baseline_step)))) as baseline_file:
    # Best-case (minimum over the last axis) reachability value per state.
    a = np.min(baseline_file['reachability_q'], -1)
del baseline_file

init_safe = a <= 1. - confidence
# States estimated safe that are truly safe / that are not.
init_found = np.sum(init_safe * (ans >= confidence))
init_notsafe = np.sum(init_safe * (ans < confidence))
# NOTE(review): `a` is thresholded against `1 - confidence` while `ans` is
# thresholded against `confidence`, so the two look like complementary
# quantities; confirm the squared error should not be taken against `1 - a`.
init_error = np.mean((a - ans) ** 2)
del a, init_safe

### Set necessary parameters to load log files

In [None]:
# Schedule constants used to locate checkpoints (`save_interval`) and log
# entries (`log_interval`) of the stored runs.
steps = 10 ** 8
improve_interval = 10 ** 6
log_interval = 5 * 10 ** 6
save_interval = 5 * 10 ** 6

In [None]:
# Keyword arguments collected for figure export.
# NOTE(review): no visible cell passes `fig_kwargs` to anything; confirm it is
# still needed before relying on it.
fig_kwargs = dict(
    format='eps',
    dpi=300,
    rasterized=True,
    bbox_inches='tight',
    pad_inches=0,
    frameon=False,
)
# Figsize default: (6., 4.); do not change this

### Get statistics

In [None]:
# One checkpoint every `save_interval` steps. The x-axis has ckpts + 1 points
# because it also carries step 0 for the shared starting value.
ckpts = int(steps // save_interval)
xaxis = save_interval * np.arange(0, ckpts + 1)

In [None]:
# Per-checkpoint statistics of the 'spec-tb-default' runs (seeds 1001..1020).
bl_seeds = list(range(1001, 1021))
seeds = bl_seeds

# Every seed starts from the same initial table, so it is read once instead of
# once per seed. The context manager closes the archive right away.
with np.load(os.path.join(name, 'tabular-initial',
                          '{}.npz'.format(int(baseline_step)))) as file_prev:
    map_init = np.sum(file_prev['reachability_q'] * file_prev['policy'], -1)
del file_prev

bl_error, bl_found, bl_notsafe, bl_cover = [], [], [], []

for seed in seeds:
    map_prev = map_init
    for i in range(1, ckpts+1):
        with np.load(os.path.join(name, 'spec-tb-default-{}'.format(seed),
                                  '{}.npz'.format(int(save_interval * i)))) as file_now:
            # Policy-weighted reachability value of each state.
            map_now = np.sum(file_now['reachability_q'] * file_now['policy'], -1)
        safe_now = map_now <= 1. - confidence
        safe_prev = map_prev <= 1. - confidence
        # Estimated-safe states that are truly safe / that are not.
        bl_found.append(np.sum(safe_now * (ans >= confidence)))
        bl_notsafe.append(np.sum(safe_now * (ans < confidence)))
        # NOTE(review): compares a reachability value with a safety value;
        # confirm this is the intended error definition.
        bl_error.append(np.mean((map_now - ans) ** 2))
        # Share of the previous checkpoint's safe set that is still safe now.
        bl_cover.append(np.sum(safe_now * safe_prev) / np.sum(safe_prev))
        map_prev = map_now
del map_init

# Rows are seeds, columns are checkpoints.
bl_error = np.array(bl_error).reshape((len(seeds), ckpts))
bl_found = np.array(bl_found).reshape((len(seeds), ckpts))
bl_notsafe = np.array(bl_notsafe).reshape((len(seeds), ckpts))
bl_cover = np.array(bl_cover).reshape((len(seeds), ckpts))

In [None]:
# Per-checkpoint statistics of the 'spec-tb-lyapunov' runs (seeds 1001..1020).
lyap_seeds = list(range(1001, 1021))
seeds = lyap_seeds

# Every seed starts from the same initial table, so it is read once instead of
# once per seed. The context manager closes the archive right away.
with np.load(os.path.join(name, 'tabular-initial',
                          '{}.npz'.format(int(baseline_step)))) as file_prev:
    map_init = np.sum(file_prev['reachability_q'] * file_prev['policy'], -1)
del file_prev

lyap_error, lyap_found, lyap_notsafe, lyap_cover = [], [], [], []

for seed in seeds:
    map_prev = map_init
    for i in range(1, ckpts+1):
        with np.load(os.path.join(name, 'spec-tb-lyapunov-{}'.format(seed),
                                  '{}.npz'.format(int(save_interval * i)))) as file_now:
            # Policy-weighted reachability value of each state.
            map_now = np.sum(file_now['reachability_q'] * file_now['policy'], -1)
        safe_now = map_now <= 1. - confidence
        safe_prev = map_prev <= 1. - confidence
        # Estimated-safe states that are truly safe / that are not.
        lyap_found.append(np.sum(safe_now * (ans >= confidence)))
        lyap_notsafe.append(np.sum(safe_now * (ans < confidence)))
        # NOTE(review): compares a reachability value with a safety value;
        # confirm this is the intended error definition.
        lyap_error.append(np.mean((map_now - ans) ** 2))
        # Share of the previous checkpoint's safe set that is still safe now.
        lyap_cover.append(np.sum(safe_now * safe_prev) / np.sum(safe_prev))
        map_prev = map_now
del map_init

# Rows are seeds, columns are checkpoints.
lyap_error = np.array(lyap_error).reshape((len(seeds), ckpts))
lyap_found = np.array(lyap_found).reshape((len(seeds), ckpts))
lyap_notsafe = np.array(lyap_notsafe).reshape((len(seeds), ckpts))
lyap_cover = np.array(lyap_cover).reshape((len(seeds), ckpts))

In [None]:
# Per-checkpoint statistics of the 'spec-tb-explorer' runs (seeds 1001..1020).
exp_seeds = list(range(1001, 1021))
seeds = exp_seeds

# Every seed starts from the same initial table, so it is read once instead of
# once per seed. The context manager closes the archive right away.
with np.load(os.path.join(name, 'tabular-initial',
                          '{}.npz'.format(int(baseline_step)))) as file_prev:
    map_init = np.sum(file_prev['reachability_q'] * file_prev['policy'], -1)
del file_prev

exp_error, exp_found, exp_notsafe, exp_cover = [], [], [], []

for seed in seeds:
    map_prev = map_init
    for i in range(1, ckpts+1):
        with np.load(os.path.join(name, 'spec-tb-explorer-{}'.format(seed),
                                  '{}.npz'.format(int(save_interval * i)))) as file_now:
            # Policy-weighted reachability value of each state.
            map_now = np.sum(file_now['reachability_q'] * file_now['policy'], -1)
        safe_now = map_now <= 1. - confidence
        safe_prev = map_prev <= 1. - confidence
        # Estimated-safe states that are truly safe / that are not.
        exp_found.append(np.sum(safe_now * (ans >= confidence)))
        exp_notsafe.append(np.sum(safe_now * (ans < confidence)))
        # NOTE(review): compares a reachability value with a safety value;
        # confirm this is the intended error definition.
        exp_error.append(np.mean((map_now - ans) ** 2))
        # Share of the previous checkpoint's safe set that is still safe now.
        exp_cover.append(np.sum(safe_now * safe_prev) / np.sum(safe_prev))
        map_prev = map_now
del map_init

# Rows are seeds, columns are checkpoints.
exp_error = np.array(exp_error).reshape((len(seeds), ckpts))
exp_found = np.array(exp_found).reshape((len(seeds), ckpts))
exp_notsafe = np.array(exp_notsafe).reshape((len(seeds), ckpts))
exp_cover = np.array(exp_cover).reshape((len(seeds), ckpts))

In [None]:
# Remove the temporary `seeds` alias that the three statistics cells rebind;
# the per-method lists (bl_seeds, lyap_seeds, exp_seeds) stay available.
del seeds

In [None]:
fig, ax = plt.subplots(1, 1, sharex=True)

# Ratio of safe states found: correctly identified safe states divided by the
# size of the true safe set. Point 0 is the shared starting value (zero
# spread); the rest are mean / std across seeds at each checkpoint.
bl_mu = np.concatenate(([init_found], bl_found.mean(axis=0)), axis=0) / max_safe_set
bl_std = np.concatenate(([0], bl_found.std(axis=0)), axis=0) / max_safe_set
lyap_mu = np.concatenate(([init_found], lyap_found.mean(axis=0)), axis=0) / max_safe_set
lyap_std = np.concatenate(([0], lyap_found.std(axis=0)), axis=0) / max_safe_set
exp_mu = np.concatenate(([init_found], exp_found.mean(axis=0)), axis=0) / max_safe_set
exp_std = np.concatenate(([0], exp_found.std(axis=0)), axis=0) / max_safe_set

series = (
    (bl_mu, bl_std, 'No Lyapunov', 'teal'),
    (lyap_mu, lyap_std, 'LSS', 'coral'),
    (exp_mu, exp_std, 'ESS', 'mediumblue'),
)
# Draw all +/- one-std bands first, then all mean curves.
for mu, std, label, color in series:
    ax.fill_between(xaxis, mu - std, mu + std, alpha=0.25, color=color)
for mu, std, label, color in series:
    ax.plot(xaxis, mu, label=label, color=color)

ax.legend(ncol=3, loc='upper right')
ax.set_xlabel('Steps')
ax.ticklabel_format(style='sci', scilimits=(-3,4), axis='both')
ax.set_xlim(0, ckpts*save_interval)
ax.set_ylim(.1, .65)
ax.set_rasterization_zorder(0)
fig.set_dpi(300)
fig.patch.set_alpha(0)
fig.tight_layout()
fig.savefig(os.path.join(name, '{}-spec-tb-[safe_set]over[max_safe_set].pdf'.format(env_name)), format='pdf')

In [None]:
fig, ax = plt.subplots(1, 1, sharex=True)

# Ratio of false-positive safe states: states estimated safe that are not
# truly safe, divided by the total number of states.
n_states = np.prod(ans.shape)
bl_mu = np.concatenate(([init_notsafe], bl_notsafe.mean(axis=0)), axis=0) / n_states
bl_std = np.concatenate(([0], bl_notsafe.std(axis=0)), axis=0) / n_states
lyap_mu = np.concatenate(([init_notsafe], lyap_notsafe.mean(axis=0)), axis=0) / n_states
lyap_std = np.concatenate(([0], lyap_notsafe.std(axis=0)), axis=0) / n_states
exp_mu = np.concatenate(([init_notsafe], exp_notsafe.mean(axis=0)), axis=0) / n_states
exp_std = np.concatenate(([0], exp_notsafe.std(axis=0)), axis=0) / n_states

series = (
    (bl_mu, bl_std, 'No Lyapunov', 'teal'),
    (lyap_mu, lyap_std, 'LSS', 'coral'),
    (exp_mu, exp_std, 'ESS', 'mediumblue'),
)
# Draw all +/- one-std bands first, then all mean curves.
for mu, std, label, color in series:
    ax.fill_between(xaxis, mu - std, mu + std, alpha=0.25, color=color)
for mu, std, label, color in series:
    ax.plot(xaxis, mu, label=label, color=color)

ax.legend(ncol=3, loc='upper right')
ax.set_xlabel('Steps')
ax.ticklabel_format(style='sci', scilimits=(-3,4), axis='both')
# The visible range starts at the first checkpoint, not at step 0.
ax.set_xlim(save_interval, ckpts*save_interval)
ax.set_ylim(-0.05, 0.20)
ax.set_rasterization_zorder(0)
fig.set_dpi(300)
fig.patch.set_alpha(0)
fig.tight_layout()
fig.savefig(os.path.join(name, '{}-spec-tb-[false_positive_safe_set]over[state_space].pdf'.format(env_name)), format='pdf')

In [None]:
fig, ax = plt.subplots(1, 1, sharex=True)

# Cover ratio: share of the previous checkpoint's estimated safe set that is
# still estimated safe at the current checkpoint. The step-0 point is 1 with
# zero spread; the rest are mean / std across seeds.
bl_mu = np.concatenate(([1], np.mean(bl_cover, axis=0)), axis=0)
bl_std = np.concatenate(([0], np.std(bl_cover, axis=0)), axis=0)
lyap_mu = np.concatenate(([1], np.mean(lyap_cover, axis=0)), axis=0)
lyap_std = np.concatenate(([0], np.std(lyap_cover, axis=0)), axis=0)
exp_mu = np.concatenate(([1], np.mean(exp_cover, axis=0)), axis=0)
exp_std = np.concatenate(([0], np.std(exp_cover, axis=0)), axis=0)

series = (
    (bl_mu, bl_std, 'No Lyapunov', 'teal'),
    (lyap_mu, lyap_std, 'LSS', 'coral'),
    (exp_mu, exp_std, 'ESS', 'mediumblue'),
)
# Draw all +/- one-std bands first, then all mean curves.
for mu, std, label, color in series:
    ax.fill_between(xaxis, mu - std, mu + std, alpha=0.25, color=color)
for mu, std, label, color in series:
    ax.plot(xaxis, mu, label=label, color=color)

ax.legend(ncol=2, loc='lower right')
ax.set_xlabel('Steps')
ax.ticklabel_format(style='sci', scilimits=(-3,4), axis='both')
ax.set_xlim(int(0. * save_interval), int(ckpts * save_interval))
ax.set_ylim(0.95, 1.05)
ax.set_rasterization_zorder(0)
fig.set_dpi(300)
fig.patch.set_alpha(0)
# These two calls were fused onto one line, which is a SyntaxError.
fig.tight_layout()
fig.savefig(os.path.join(name, '{}-spec-tb-[cover_ratio].pdf'.format(env_name)), format='pdf')

In [None]:
fig, ax = plt.subplots(1, 1, sharex=True)

# Mean squared error between the estimated map and the reference values.
# Point 0 is the shared starting error; the rest are mean / std across seeds.
bl_mu = np.concatenate(([init_error], bl_error.mean(axis=0)), axis=0)
bl_std = np.concatenate(([0], bl_error.std(axis=0)), axis=0)
lyap_mu = np.concatenate(([init_error], lyap_error.mean(axis=0)), axis=0)
lyap_std = np.concatenate(([0], lyap_error.std(axis=0)), axis=0)
exp_mu = np.concatenate(([init_error], exp_error.mean(axis=0)), axis=0)
exp_std = np.concatenate(([0], exp_error.std(axis=0)), axis=0)

series = (
    (bl_mu, bl_std, 'No Lyapunov', 'teal'),
    (lyap_mu, lyap_std, 'LSS', 'coral'),
    (exp_mu, exp_std, 'ESS', 'mediumblue'),
)
# Draw all +/- one-std bands first, then all mean curves.
for mu, std, label, color in series:
    ax.fill_between(xaxis, mu - std, mu + std, alpha=0.25, color=color)
for mu, std, label, color in series:
    ax.plot(xaxis, mu, label=label, color=color)

ax.legend(ncol=2, loc='lower right')
ax.set_xlabel('Steps')
ax.ticklabel_format(style='sci', scilimits=(-3,4), axis='both')
# The visible range starts at the first checkpoint; the y-range is left automatic.
ax.set_xlim(save_interval, ckpts*save_interval)
ax.set_rasterization_zorder(0)
fig.set_dpi(300)
fig.patch.set_alpha(0)
fig.tight_layout()
fig.savefig(os.path.join(name, '{}-spec-tb-[mean_square_error].pdf'.format(env_name)), format='pdf')

### Learning curve

In [None]:
# Starting point of the learning curves: the last 'average_safety' entry
# logged by the initial (baseline) run.
init_safety = 0
with open(os.path.join(baseline_dir, 'log.pkl'), 'rb') as f:
    init_safety = pickle.load(f)['average_safety'][-1]

In [None]:
# Learning curves are sampled every `log_interval` steps, so the checkpoint
# count and x-axis are rebuilt on that grid (step 0 included).
ckpts = int(steps // log_interval)
xaxis = log_interval * np.arange(0, ckpts + 1)

seeds = list(range(1001, 1021))

# Concatenate each seed's logged 'average_safety' series for one method, then
# fold the result into a (seed, log point) array. Order: default, lyapunov,
# explorer.
loaded = []
for run_template in ('spec-tb-default-{}', 'spec-tb-lyapunov-{}', 'spec-tb-explorer-{}'):
    history = []
    for seed in seeds:
        with open(os.path.join(name, run_template.format(seed), 'log.pkl'), 'rb') as f:
            history += pickle.load(f)['average_safety']
    loaded.append(np.array(history).reshape((len(seeds), ckpts)))

bl_episode_safety, lyap_episode_safety, exp_episode_safety = loaded
del loaded, history

In [None]:
fig, ax = plt.subplots(1, 1, sharex=True)

# Average episode safety during training. Point 0 is the shared starting
# value; the rest are mean / std across seeds at each log point.
bl_mu = np.concatenate(([init_safety], bl_episode_safety.mean(axis=0)), axis=0)
bl_std = np.concatenate(([0], bl_episode_safety.std(axis=0)), axis=0)
lyap_mu = np.concatenate(([init_safety], lyap_episode_safety.mean(axis=0)), axis=0)
lyap_std = np.concatenate(([0], lyap_episode_safety.std(axis=0)), axis=0)
exp_mu = np.concatenate(([init_safety], exp_episode_safety.mean(axis=0)), axis=0)
exp_std = np.concatenate(([0], exp_episode_safety.std(axis=0)), axis=0)

series = (
    (bl_mu, bl_std, 'No Lyapunov', 'teal'),
    (lyap_mu, lyap_std, 'LSS', 'coral'),
    (exp_mu, exp_std, 'ESS', 'mediumblue'),
)
# Draw all +/- one-std bands first, then all mean curves.
for mu, std, label, color in series:
    ax.fill_between(xaxis, mu - std, mu + std, alpha=0.25, color=color)
for mu, std, label, color in series:
    ax.plot(xaxis, mu, label=label, color=color)

# Dashed red reference line at the required confidence level (no legend entry).
ax.plot(xaxis, np.ones((ckpts+1,)) * confidence, 'r--')

ax.legend(ncol=2, loc='lower right')
ax.set_xlabel('Steps')
ax.ticklabel_format(style='sci', scilimits=(-3,4), axis='both')
ax.set_xlim(0, ckpts*log_interval)
ax.set_ylim(.70, 1.05)
ax.set_rasterization_zorder(0)
fig.set_dpi(300)
fig.patch.set_alpha(0)
fig.tight_layout()
fig.savefig(os.path.join(name, '{}-spec-tb-[train_episode_safety].pdf'.format(env_name)), format='pdf')

### Visualization

In [None]:
def get_reachability(name, logdir, seeds, ckpts, reshape=True, reference=None):
    """Load per-checkpoint reachability maps for a group of runs.

    For every seed and every checkpoint ``i`` in ``1..ckpts`` this reads
    ``<name>/<logdir>-<seed>/<save_interval * i>.npz`` and reduces its
    ``'reachability_q'`` array with a minimum over the last axis.

    Args:
        name: Root results directory.
        logdir: Run-directory prefix; the seed is appended after a dash.
        seeds: Iterable of seeds to load.
        ckpts: Number of checkpoints to load per seed.
        reshape: If true, try to reshape the result to
            ``(ckpts, grid_points, grid_points)``. On failure a message is
            printed and the un-reshaped array is returned.
        reference: Optional 2-D array indexed ``[seed, checkpoint]``. When
            given, only the run whose last-column value is largest is
            returned; otherwise the maps are averaged over seeds.

    Returns:
        Array with checkpoints along the first axis.

    Note:
        Reads the notebook-level globals ``save_interval`` and
        ``grid_points``; both must be defined before calling.
    """
    per_seed = []
    for seed in seeds:
        run_dir = os.path.join(name, '{}-{}'.format(logdir, seed))
        maps = []
        for i in range(1, ckpts+1):
            # The context manager closes the archive as soon as the array is
            # extracted, instead of leaving it to garbage collection.
            with np.load(os.path.join(run_dir, '{}.npz'.format(int(save_interval * i)))) as ckpt_file:
                maps.append(np.min(ckpt_file['reachability_q'], -1))
        per_seed.append(np.array(maps))
    stacked = np.array(per_seed)

    if reference is None:
        result = stacked.mean(0)
    else:
        # Keep only the seed with the best final reference score.
        result = stacked[np.argmax(reference[:, -1]), ...]

    if reshape:
        try:
            result = result.reshape((ckpts, grid_points, grid_points))
        except ValueError:
            print("Reshape unavailable.")
    return result

In [None]:
# Back on the checkpoint grid for the visualisations. Unlike the earlier
# x-axis, this one starts at the first checkpoint, not at step 0.
ckpts = int(steps // save_interval)
xaxis = save_interval * np.arange(1, ckpts + 1)

In [None]:
# For each method, keep the maps of the single seed with the highest final
# "found" count (selected through `reference`).
bl_list, lyap_list, exp_list = (
    get_reachability(name, run_prefix, run_seeds, ckpts, reshape=True, reference=found)
    for run_prefix, run_seeds, found in (
        ('spec-tb-default', bl_seeds, bl_found),
        ('spec-tb-lyapunov', lyap_seeds, lyap_found),
        ('spec-tb-explorer', exp_seeds, exp_found),
    )
)

In [None]:
# 1-based checkpoint to visualise. The cells below index the maps with
# idx - 1, so idx = ckpts selects the final checkpoint.
idx = ckpts

In [None]:
fig, ax = plt.subplots(1,1)

# Binary map of the reference safe set (reference value >= confidence).
answer_safe = ans.reshape((grid_points, grid_points)) >= confidence
img = ax.imshow(answer_safe, cmap='inferno', extent=[.5, -.5, -1., 1.,], aspect=.5)

ax.set_xlabel('Velocity')
ax.set_ylabel('Position')
ax.set_xticks(np.arange(-.5, .5+1e-3, .2))
ax.set_yticks(np.arange(-1., 1.+2e-3, .4))

# Fixed colour range so this panel matches the per-method panels.
img.set_clim(0., 1.)
ax.grid(False)
fig.set_dpi(300)
# Fully transparent figure and axes backgrounds.
fig.patch.set_facecolor('none')
fig.patch.set_alpha(0)
fig.tight_layout()
ax.patch.set_facecolor('none')
ax.patch.set_alpha(0)
fig.savefig(os.path.join(name, 'integrator-spec-tb-visualize-answer.pdf'), format='pdf',
            facecolor=fig.get_facecolor(), edgecolor='none', bbox_inches='tight')

In [None]:
fig, ax = plt.subplots(1,1)

# Show False-positive and True-positive altogether:
# 1.0 = estimated safe and truly safe, 0.5 = estimated safe but not truly
# safe, 0.0 = not estimated safe.
answer_grid = ans.reshape((grid_points, grid_points))
claimed_safe = 1.-bl_list[idx-1] >= confidence
overlay = claimed_safe * (answer_grid >= confidence) + claimed_safe * (answer_grid < confidence) * 0.5
img = ax.imshow(overlay, cmap='inferno', extent=[.5, -.5, -1., 1.,], aspect=.5)

ax.set_xlabel('Velocity')
ax.set_ylabel('Position')
ax.set_xticks(np.arange(-.5, .5+1e-3, .2))
ax.set_yticks(np.arange(-1., 1.+2e-3, .4))
# The y-axis is hidden in this panel.
ax.get_yaxis().set_visible(False)

img.set_clim(0., 1.)
ax.grid(False)
fig.set_dpi(300)
# Fully transparent figure and axes backgrounds.
fig.patch.set_facecolor('none')
fig.patch.set_alpha(0)
fig.tight_layout()
ax.patch.set_facecolor('none')
ax.patch.set_alpha(0)
fig.savefig(os.path.join(name, 'integrator-spec-tb-visualize-baseline-{}.pdf'.format(save_interval * idx)), format='pdf',
            facecolor=fig.get_facecolor(), edgecolor='none', bbox_inches='tight')

In [None]:
fig, ax = plt.subplots(1,1)

# Show False-positive and True-positive altogether:
# 1.0 = estimated safe and truly safe, 0.5 = estimated safe but not truly
# safe, 0.0 = not estimated safe.
answer_grid = ans.reshape((grid_points, grid_points))
claimed_safe = 1.-lyap_list[idx-1] >= confidence
overlay = claimed_safe * (answer_grid >= confidence) + claimed_safe * (answer_grid < confidence) * 0.5
img = ax.imshow(overlay, cmap='inferno', extent=[.5, -.5, -1., 1.,], aspect=.5)

ax.set_xlabel('Velocity')
ax.set_ylabel('Position')
ax.set_xticks(np.arange(-.5, .5+1e-3, .2))
ax.set_yticks(np.arange(-1., 1.+2e-3, .4))
# The y-axis is hidden in this panel.
ax.get_yaxis().set_visible(False)

img.set_clim(0., 1.)
ax.grid(False)
fig.set_dpi(300)
# Fully transparent figure and axes backgrounds.
fig.patch.set_facecolor('none')
fig.patch.set_alpha(0)
fig.tight_layout()
ax.patch.set_facecolor('none')
ax.patch.set_alpha(0)
fig.savefig(os.path.join(name, 'integrator-spec-tb-visualize-lyapunov-{}.pdf'.format(save_interval * idx)),
            format='pdf', facecolor=fig.get_facecolor(), edgecolor='none', bbox_inches='tight')

In [None]:
fig, ax = plt.subplots(1,1)

# Show False-positive and True-positive altogether:
# 1.0 = estimated safe and truly safe, 0.5 = estimated safe but not truly
# safe, 0.0 = not estimated safe.
answer_grid = ans.reshape((grid_points, grid_points))
claimed_safe = 1.-exp_list[idx-1] >= confidence
overlay = claimed_safe * (answer_grid >= confidence) + claimed_safe * (answer_grid < confidence) * 0.5
img = ax.imshow(overlay, cmap='inferno', extent=[.5, -.5, -1., 1.,], aspect=.5)

ax.set_xlabel('Velocity')
ax.set_ylabel('Position')
ax.set_xticks(np.arange(-.5, .5+1e-3, .2))
ax.set_yticks(np.arange(-1., 1.+2e-3, .4))
# The y-axis is hidden in this panel.
ax.get_yaxis().set_visible(False)

img.set_clim(0., 1.)
ax.grid(False)
fig.set_dpi(300)
# Fully transparent figure and axes backgrounds.
fig.patch.set_facecolor('none')
fig.patch.set_alpha(0)
fig.tight_layout()
ax.patch.set_facecolor('none')
ax.patch.set_alpha(0)
fig.savefig(os.path.join(name, 'integrator-spec-tb-visualize-explorer-{}.pdf'.format(save_interval * idx)),
            format='pdf', facecolor=fig.get_facecolor(), edgecolor='none', bbox_inches='tight')

### Trial

In [None]:
# Environment for the manual rollout below, built with the same `n` and
# `grid_points` as the stored results. seed=None leaves seeding to the
# environment's own default handling (presumably unseeded; verify).
env = DiscRandomIntegratorEnv(n=n, grid_points=grid_points, seed=None)

In [None]:
# Load a stored agent to roll out in `env`; the second argument is presumably
# the checkpoint step to load (confirm against ExplorerQAgent.load).
# NOTE(review): an ExplorerQAgent is loaded from a 'spec-tb-lyapunov-*' run
# directory, and seed 123 is outside the 1001-1020 range analysed above.
# Confirm that both the agent class and the path are intended.
model = ExplorerQAgent.load(os.path.join(name, 'spec-tb-lyapunov-123'), int(1e8), env=env)

In [None]:
def tester(env, act, trials=100, initial_state=None, gamma=1.):
    count = 0
    safe_run = 0.
    done = False
    while count < trials:
        _ = env.reset()
        if initial_state is not None:
            env.set_state(initial_state)
            st = initial_state
        else:
            st = env.quantize(env.state)
        episode_rew = 0.
        episode_safety = 1.
        t = 0
        while not done: #t <= episode_length and not done:
            _, rew, done, info = env.step(act.step(st))
            st = info['state']
            episode_safety *= info['safety']
            episode_rew = gamma * episode_rew + rew
            t += 1
        count += 1
        if episode_safety > 0.:
            safe_run += 1.
    env.close()
    return 1. * safe_run / trials

In [None]:
# Fraction of 10 rollouts, each forced to start from state 763, that stay safe
# for the whole episode; shown as the cell output.
tester(env, model, trials=10, initial_state=763)