In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from cplex.exceptions import CplexSolverError
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.dpi'] = 216
import numpy as np
import cv2
from cv2 import resize
from PIL import Image
from gym.wrappers import Monitor
import glob
import pickle
import os
import seaborn as sns
sns.set()

# Reacher

* Load the experiment setup and compute the ground-truth safe set on the joint-angle grid.

In [None]:
# Experiment identity; results live in the directory '<episode_length>-<env_name>'.
env_name, episode_length = 'reacher', 300
confidence = 0.8   # threshold used when classifying grid cells as safe
grid_points = 21   # number of grid samples per joint-angle axis

name = '-'.join([str(int(episode_length)), env_name])

In [None]:
# Total training length and logging / checkpoint cadence (in training steps).
steps, log_interval, save_interval = (int(x) for x in (2e6, 1e4, 1e5))

In [None]:
# Ground-truth safe set on a grid_points x grid_points grid of joint angles in [-pi, pi].
# A cell is safe (1.0) when |0.1*sin(q1) + 0.11*sin(q1 + q2)| <= 0.1 -- presumably a
# fingertip coordinate for link lengths 0.1 and 0.11; TODO confirm against the env.
# The second angle is the outer loop, so after a row-major reshape rows index angle 2.
PI = np.pi
angles = [2 * PI * k / (grid_points - 1) - PI for k in range(grid_points)]

ans = np.array([
    (np.abs(0.1 * np.sin(q1) + 0.11 * np.sin(q1 + q2)) <= 0.1) * 1.
    for q2 in angles
    for q1 in angles
])
max_safe_set = np.sum(ans >= confidence)  # number of truly safe cells

In [None]:
# Reachability map saved at step 0 of the 'ddpg-initial' run (presumably the untrained policy).
baseline_step = 0
baseline_dir = os.path.join(name, 'ddpg-initial')
baseline_file = os.path.join(baseline_dir, '%d-reachability-map.npz' % baseline_step)
bl_map = np.load(baseline_file)['arr_0']

In [None]:
# Compare the baseline map's claimed safe set against the ground truth `ans`.
bl_claims_safe = bl_map <= 1. - confidence
init_found = np.sum(bl_claims_safe & (ans >= confidence))    # claimed safe and truly safe
init_notsafe = np.sum(bl_claims_safe & (ans < confidence))   # claimed safe but truly unsafe
init_error = np.mean((bl_map - ans) ** 2)                    # mean squared difference

* Seeds used for each method, and figure-export arguments.

In [None]:
# Training seeds per method. NOTE(review): seed 8210 is skipped for the explorer runs -- confirm intent.
bl_seeds = [seed for seed in range(8001, 8011)]
lyap_seeds = bl_seeds.copy()
exp_seeds = [*range(8201, 8210), 8211]

In [None]:
# Shared keyword arguments for exporting figures as EPS.
# NOTE(review): newer Matplotlib releases no longer accept `frameon` in `savefig`; confirm
# against the installed version before calling `fig.savefig(..., **fig_kwargs)`.
fig_kwargs = dict(
    format='eps',
    dpi=216,
    rasterized=True,
    bbox_inches='tight',
    pad_inches=0,
    frameon=False,
)
# Figsize default: (6., 4.); do not change this

## Compute necessary statistics.

In [None]:
# Number of saved checkpoints, and their step counts including step 0.
# (`xaxis` is redefined further down without the step-0 entry.)
ckpts = int(steps // save_interval)
xaxis = save_interval * np.arange(ckpts + 1)

In [None]:
ans = ans.reshape(-1)  # flatten (C order) to one entry per grid cell for `get_stats`

In [None]:
def get_stats(seeds, dir_name):
    """Per-seed, per-checkpoint statistics of the learned safe sets.

    For each seed this loads the maps stored at
    ``<name>/<dir_name>-<seed>/<save_interval * i>-reachability-map.npz`` for
    ``i = 1 .. ckpts`` and compares each one with the ground-truth grid ``ans``.
    A cell is "claimed safe" when its map value is ``<= 1 - confidence`` and
    "truly safe" when ``ans >= confidence``.

    Args:
        seeds: training seeds, one run directory per seed.
        dir_name: run-directory prefix, e.g. ``'spec-def-ddpg'``.

    Returns:
        Tuple ``(error, found, notsafe, cover)``; each is an array of shape
        ``(len(seeds), ckpts)``:
          error   -- mean squared difference between the map and ``ans``;
          found   -- number of cells claimed safe that are truly safe;
          notsafe -- number of cells claimed safe that are truly unsafe;
          cover   -- fraction of the previous map's claimed-safe cells that are
                     still claimed safe (the first checkpoint is compared with the
                     map in ``baseline_dir``); NaN if the previous map claims none.
    """
    threshold = 1. - confidence
    # The baseline map is identical for every seed, so read it only once.
    baseline_map = np.load(os.path.join(baseline_dir, '{}-reachability-map.npz'.format(int(baseline_step))))['arr_0']
    baseline_safe = baseline_map <= threshold

    seeds = list(seeds)
    error = []
    found = []
    notsafe = []
    cover = []

    for seed in seeds:
        prev_safe = baseline_safe
        for i in range(1, ckpts + 1):
            map_now = np.load(os.path.join(name, '{}-{}'.format(dir_name, seed),
                                           '{}-reachability-map.npz'.format(int(save_interval * i))))['arr_0']
            # Other cells flatten `ans` or reshape it to a 2-D grid depending on execution
            # order; align it with the map layout so either state works.
            truth = np.reshape(ans, map_now.shape)
            now_safe = map_now <= threshold

            found.append(np.sum(now_safe & (truth >= confidence)))
            notsafe.append(np.sum(now_safe & (truth < confidence)))
            error.append(np.mean((map_now - truth) ** 2))
            n_prev = np.sum(prev_safe)
            # 0/0 would be NaN plus a RuntimeWarning; produce the NaN explicitly.
            cover.append(np.sum(now_safe & prev_safe) / n_prev if n_prev else np.nan)

            prev_safe = now_safe

    shape = (len(seeds), ckpts)
    error = np.array(error).reshape(shape)
    found = np.array(found).reshape(shape)
    notsafe = np.array(notsafe).reshape(shape)
    cover = np.array(cover).reshape(shape)

    return error, found, notsafe, cover

In [None]:
# Baseline runs ('spec-def-ddpg': double Q, double replay).
b1_error, b1_found, b1_notsafe, b1_cover = get_stats(seeds=bl_seeds, dir_name='spec-def-ddpg')

In [None]:
# LSS runs ('spec-lyap-ddpg': double Q).
l1_error, l1_found, l1_notsafe, l1_cover = get_stats(seeds=lyap_seeds, dir_name='spec-lyap-ddpg')

In [None]:
# ESS runs ('spec-exp-ddpg': double Q, double replay, explorer only).
e1_error, e1_found, e1_notsafe, e1_cover = get_stats(seeds=exp_seeds, dir_name='spec-exp-ddpg')

# Option 1: Create a state space color map

## Get the color maps.

In [None]:
# Back to a 2-D grid for plotting: rows index the second joint angle, columns the first.
ans = np.reshape(np.array(ans), (grid_points, grid_points))

In [None]:
def get_reachability(name, logdir, seeds, ckpts, reshape=True, reference=None):
    """Load reachability maps for a group of runs, one map per checkpoint.

    Maps are read from
    ``<name>/<logdir>-<seed>/<save_interval * i>-reachability-map.npz`` for
    ``i = 1 .. ckpts``.

    Args:
        name: experiment root directory.
        logdir: run-directory prefix.
        seeds: training seeds, one run directory each.
        ckpts: number of checkpoints to load per seed.
        reshape: if True, try to reshape the result to
            ``(ckpts, grid_points, grid_points)``; on failure print
            "Reshape unavailable." and keep the loaded shape.
        reference: optional per-seed statistic with one row per seed. When
            given, only the seed with the largest value in the last column is
            returned; otherwise the maps are averaged over seeds.

    Returns:
        Array whose leading axis indexes checkpoints.
    """
    seeds = list(seeds)
    if reference is not None:
        # Only the best run is returned, so the other runs never need to be read.
        seeds = [seeds[np.argmax(reference[:, -1])]]

    maps = np.array([
        [np.load(os.path.join(name, '{}-{}'.format(logdir, seed),
                              '{}-reachability-map.npz'.format(int(save_interval * i))))['arr_0']
         for i in range(1, ckpts + 1)]
        for seed in seeds
    ])
    reachability_list = maps.mean(0) if reference is None else maps[0]

    if reshape:
        try:
            reachability_list = reachability_list.reshape((ckpts, grid_points, grid_points))
        except ValueError:
            print("Reshape unavailable.")
    return reachability_list

In [None]:
# Checkpoint step counts 1..ckpts (the step-0 baseline is excluded from here on).
ckpts = int(steps // save_interval)
xaxis = np.arange(1, ckpts + 1) * save_interval

In [None]:
# For each method, keep the maps of the seed whose final `found` count is largest.
bl_list = get_reachability(name, logdir='spec-def-ddpg', seeds=bl_seeds, ckpts=ckpts, reference=b1_found)
lyap_list = get_reachability(name, logdir='spec-lyap-ddpg', seeds=lyap_seeds, ckpts=ckpts, reference=l1_found)
exp_list = get_reachability(name, logdir='spec-exp-ddpg', seeds=exp_seeds, ckpts=ckpts, reference=e1_found)

## Create the sequence of images (run just once)

In [None]:
# Render one frame per checkpoint comparing the safe sets learned by the three methods.
# Pixel values: 1.0 = claimed safe and truly safe, 0.5 = claimed safe but truly unsafe,
# 0.0 = not claimed safe.
truth_grid = ans.reshape((grid_points, grid_points))
angle_extent = [-180., +180., +180., -180.]
panels = [('Baseline', bl_list), ('LSS', lyap_list), ('ESS (ours)', exp_list)]


def safe_set_overlay(reach_map):
    """Return the true-positive (1.0) / false-positive (0.5) overlay of one reachability map."""
    claimed_safe = (1. - reach_map) >= confidence
    return claimed_safe * (truth_grid >= confidence) + claimed_safe * (truth_grid < confidence) * 0.5


for idx in range(ckpts):
    fig, axes = plt.subplots(1, 3, sharey=True, figsize=(16, 5))

    # Every panel shows the same checkpoint `idx`.
    for ax, (title, maps) in zip(axes, panels):
        im = ax.imshow(safe_set_overlay(maps[idx]), cmap='inferno', extent=angle_extent, aspect=1.)
        im.set_clim(0., 1.)
        ax.set_title(title)
        ax.set_xlabel('Angle 1 (degree)')  # center
        ax.set_xticks(np.arange(-180., 180. + 1, 60.))
        ax.set_ylabel('Angle 2 (degree)')  # arm tip
        ax.set_yticks(np.arange(-180., 180. + 1, 60.))
        ax.patch.set_facecolor('none')
        ax.patch.set_alpha(0)
        ax.grid(False)

    fig.set_dpi(216)
    fig.patch.set_facecolor('none')
    fig.patch.set_alpha(0)
    fig.tight_layout()
    fig.savefig(os.path.join(name, 'visualize-frame-{}.png'.format(idx)),
                format='png', facecolor=fig.get_facecolor(), edgecolor='none')
    # One figure per checkpoint: close each so they do not accumulate in memory.
    plt.close(fig)

## Create the video

In [None]:
# Stitch the saved checkpoint frames into an MP4 at 4 frames per second.
# Sources:
# https://theailearner.com/2018/10/15/creating-video-from-images-using-opencv-python/
# https://stackoverflow.com/questions/30509573/writing-an-mp4-video-using-python-opencv

img_array = []
for idx in range(ckpts):
    fn = os.path.join(name, 'visualize-frame-{}.png'.format(idx))
    img = cv2.imread(fn)
    if img is None:
        # cv2.imread returns None (instead of raising) when a file is missing or unreadable.
        raise FileNotFoundError('Could not read frame: {}'.format(fn))
    img_array.append(img)

if not img_array:
    raise ValueError('No frames to encode (ckpts == 0).')
height, width, layers = img_array[0].shape
size = (width, height)  # OpenCV expects (width, height)

out = cv2.VideoWriter(os.path.join(name, 'visualize.mp4'), cv2.VideoWriter_fourcc(*'MP4V'), 4, size)
try:
    for frame in img_array:
        out.write(frame)
finally:
    out.release()

## Create the GIF image

In [None]:
# Release the frame buffer from the video cell. Rebinding instead of `del` keeps this cell re-runnable.
img_array = None
# Source: https://stackoverflow.com/questions/753190/programmatically-generate-video-or-animated-gif-in-python

fp_out = os.path.join(name, 'visualize.gif')
# List frames explicitly in checkpoint order. A sorted glob orders names lexicographically
# ('visualize-frame-10.png' before 'visualize-frame-2.png') and would also pick up stale frames.
frame_paths = [os.path.join(name, 'visualize-frame-{}.png'.format(idx)) for idx in range(ckpts)]

gif_frames = []
for path in frame_paths:
    with Image.open(path) as im:
        gif_frames.append(im.copy())  # copy() loads the pixels so the file can be closed

# https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#gif
img, *imgs = gif_frames
img.save(fp=fp_out, format='gif', append_images=imgs,
         save_all=True, duration=200, loop=0)

# Option 2: Record MuJoCo rollouts of the trained agents

In [None]:
import torch.nn as nn
from lyapunov_reachability.speculation_ddpg import DefaultDDPG, ExplorerDDPG, DefaultBCQ, ExplorerBCQ, LyapunovDDPG, LyapExpDDPG
from lyapunov_reachability.common.networks import Mlp, Cnn

In [None]:
from complex_envs.reacher import FixedReacherEnv

# Rollout settings; overlapping values repeat the configuration cell at the top.
episode_length, confidence, grid_points = 300, 0.8, 21
batch_size = 256
gamma = 1. - 1e-3  # factor applied to the running reward in `run_once`
strict_done = True
n = 11
replay_size = int(4e5)

env = FixedReacherEnv()
name = '%d-reacher' % episode_length

# NOTE(review): `steps` is rebound to a single checkpoint a few cells below. Re-running this
# cell after that changes `ckpts` and `xaxis`, so restart from the top instead.
ckpts = int(steps // save_interval)
xaxis = np.arange(1, ckpts + 1) * save_interval

In [None]:
# Checkpoint (index into `xaxis`) whose saved agents are rolled out below.
# NOTE(review): this rebinds the global `steps`, which previously held the total training length.
rollout_ckpt = 4
steps = xaxis[rollout_ckpt]

In [None]:
# Drop any previously loaded agents before reloading them. Rebinding to None frees them
# like `del`, but does not raise NameError on a fresh kernel.
bl_act = lyap_act = exp_act = None

In [None]:
# Baseline agent from the seed with the largest final `found` count, at checkpoint `steps`.
best_bl_seed = bl_seeds[np.argmax(b1_found[:, -1])]
bl_act = DefaultDDPG.load(os.path.join(name, 'spec-def-ddpg-{}'.format(best_bl_seed)), steps, env=env)

In [None]:
# LSS agent from the seed with the largest final `found` count, at checkpoint `steps`.
best_lyap_seed = lyap_seeds[np.argmax(l1_found[:, -1])]
lyap_act = LyapunovDDPG.load(os.path.join(name, 'spec-lyap-ddpg-{}'.format(best_lyap_seed)), steps, env=env)

In [None]:
# ESS agent at checkpoint `steps`.
# NOTE(review): the seed is hard-coded to 8204. The other two agents use the seed with the
# largest final `found` count (for ESS: exp_seeds[np.argmax(e1_found[:, -1])]), which is also
# the run shown in Option 1, so this may be a different run -- confirm the override is intended.
exp_act = ExplorerDDPG.load(os.path.join(name, 'spec-exp-ddpg-{}'.format(8204)), steps, env=env)
# Seed-selection alternative: exp_seeds[np.argmax(e1_found[:, -1])]

In [None]:
# Rendered frames are resized to 180x180 pixels (cv2 `dsize` is given as (width, height)).
frame_size = (180,) * 2

In [None]:
def run_once(model):
    """Roll out `model` for one episode in the global `env` and return the rendered frames.

    The rollout stops when the environment reports `done`, once more than `episode_length`
    steps have been taken, or as soon as the running product of `info['safety']` hits 0.
    Before every action a frame is rendered and resized to `frame_size`. The step count,
    accumulated reward and final safety product are printed, and `env` is closed at the end.
    """
    observation = env.reset()
    frames = []
    episode_rew = 0
    episode_safety = 1.
    n_steps = 0
    finished = False

    while not finished and n_steps <= episode_length and episode_safety != 0.:
        rendered = env.render(mode='rgb_array')
        frames.append(resize(rendered, dsize=frame_size,))

        action = model.step(observation)
        observation, rew, finished, info = env.step(action)
        episode_safety = episode_safety * info['safety']
        episode_rew = gamma * episode_rew + rew  # G <- gamma * G + r
        n_steps += 1

    print("Total runtime: %.4f" % n_steps)
    print("Total reward: %.4f" % episode_rew)
    print("Total safety: %.4f" % episode_safety)
    env.close()
    return frames

In [None]:
# Roll out each agent once (baseline, LSS, ESS, in that order); each call prints a summary.
bl_frames, lyap_frames, exp_frames = [run_once(agent) for agent in (bl_act, lyap_act, exp_act)]

## Create the video and GIF image

In [None]:
# Compose the three rollouts side by side (Baseline | LSS | ESS) on a white canvas,
# save every composite as a JPEG, and encode the sequence as an MP4 at 15 frames per second.
# Sources:
# https://theailearner.com/2018/10/15/creating-video-from-images-using-opencv-python/
# https://stackoverflow.com/questions/30509573/writing-an-mp4-video-using-python-opencv

canvas_shape = (240, 620, 3)    # (height, width, channels)
panel_top = 30
panel_lefts = (15, 210, 405)    # left edge of each panel, in the order of `rollouts`
rollouts = (bl_frames, lyap_frames, exp_frames)
size = (canvas_shape[1], canvas_shape[0])  # OpenCV expects (width, height)

img_array = []
for t in range(episode_length + 1):
    # 8-bit canvas: `np.int` no longer exists in NumPy >= 1.24, and OpenCV writes 8-bit images.
    img = np.full(canvas_shape, 255, dtype=np.uint8)

    for left, rollout in zip(panel_lefts, rollouts):
        # A rollout that ended early keeps showing its final frame.
        frame = rollout[min(t, len(rollout) - 1)]
        h, w = frame.shape[:2]
        img[panel_top:panel_top + h, left:left + w, :] = frame

    # Frames come from env.render(mode='rgb_array'), but OpenCV writers expect BGR order.
    img_bgr = np.ascontiguousarray(img[..., ::-1])
    cv2.imwrite(os.path.join(name, "trial-frame-{}-{}.jpg".format(steps, t)), img_bgr)
    img_array.append(img_bgr)

out = cv2.VideoWriter(os.path.join(name, 'trial-{}.mp4'.format(steps)), cv2.VideoWriter_fourcc(*'MP4V'), 15, size)
try:
    for frame in img_array:
        out.write(frame)
finally:
    out.release()

## Create the GIF image

In [None]:
# Release the frame buffer from the video cell. Rebinding instead of `del` keeps this cell re-runnable.
img_array = None
# Source: https://stackoverflow.com/questions/753190/programmatically-generate-video-or-animated-gif-in-python

fp_out = os.path.join(name, 'trial-{}.gif'.format(steps))
# List frames explicitly in time order. A sorted glob orders names lexicographically
# ('...-10.jpg' before '...-2.jpg') and would also pick up stale frames.
frame_paths = [os.path.join(name, 'trial-frame-{}-{}.jpg'.format(steps, t))
               for t in range(episode_length + 1)]

gif_frames = []
for path in frame_paths:
    with Image.open(path) as im:
        gif_frames.append(im.copy())  # copy() loads the pixels so the file can be closed

# https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#gif
img, *imgs = gif_frames
img.save(fp=fp_out, format='gif', append_images=imgs,
         save_all=True, duration=200, loop=0)