In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import gzip
import tensorflow as tf
import glob
import math
import skimage
from matplotlib.patches import Rectangle

In [None]:
# Switch TF 1.x to eager execution. TF 2.x is eager by default and has no top-level
# `tf.enable_eager_execution`, so only call it when eager mode is not already active
# (this also makes re-running the cell harmless).
if not tf.executing_eagerly():
    tf.enable_eager_execution()

## Specify the Experiment Directory

In [None]:
# Root of the results tree that holds <universe>/<domain>/<task>/<experiment>/<seed>.
# Override with the RAY_RESULTS_DIR environment variable; defaults to ~/ray_results
# instead of a hard-coded absolute home directory.
common_dir = os.environ.get('RAY_RESULTS_DIR', os.path.expanduser('~/ray_results'))
universe = 'gym'
domain = 'Point2D'
task = 'Maze-v0'

In [None]:
base_path = os.path.join(common_dir, universe, domain, task)


def _list_subdirs(parent):
    """Return the sub-directories of `parent`, sorted by path; regular files are skipped."""
    return sorted(p for p in glob.glob(os.path.join(parent, '*')) if os.path.isdir(p))


def _ask_index(n_options, prompt):
    """Prompt until the user types an integer in [0, n_options) and return it."""
    while True:
        answer = input(prompt)
        try:
            choice = int(answer)
        except ValueError:
            print(f'{answer!r} is not a number.')
            continue
        # Reject negatives too: Python would silently index from the end of the list.
        if 0 <= choice < n_options:
            return choice
        print(f'Please enter a number between 0 and {n_options - 1}.')


# Only directories can be experiments; stray files would otherwise be offered as choices.
exps = _list_subdirs(base_path)
if not exps:
    raise FileNotFoundError(f'No experiment directories found under {base_path}')
for i, exp in enumerate(exps):
    print(f'{i} \t {exp.replace(base_path, "")}')

exp_choice = _ask_index(len(exps), '\n Which experiment do you want to analyze? (ENTER A NUMBER) \t')

exp_path = exps[exp_choice]
print('\n')
seeds = _list_subdirs(exp_path)
if not seeds:
    raise FileNotFoundError(f'No seed directories found under {exp_path}')
for i, seed in enumerate(seeds):
    print(f'{i} \t {seed.replace(exp_path, "")}')

# TODO: Extend to analyzing all seeds
seed_choice = _ask_index(len(seeds), '\n Which seed do you want to analyze? (ENTER A NUMBER) \t')

seed_path = seeds[seed_choice]

print('PATH:\n', seed_path)

## Specify the Checkpoint

In [None]:
# Checkpoint number to load, i.e. the `checkpoint_<N>` directory inside `seed_path`.
checkpoint_to_analyze = 60

In [None]:
# Unpickle the checkpoint dict saved in <seed_path>/checkpoint_<N>/checkpoint.pkl.
# NOTE: pickle.load can execute arbitrary code; only open checkpoints you trust.
checkpoint_file = os.path.join(seed_path, f'checkpoint_{checkpoint_to_analyze}', 'checkpoint.pkl')
with open(checkpoint_file, 'rb') as checkpoint_fh:
    checkpoint = pickle.load(checkpoint_fh)

## Load Reward Classifier

In [None]:
# The trained VICE reward classifier stored in the checkpoint dict. Later cells call
# `.predict` and `.summary` on it (presumably a Keras model — confirm).
reward_clf = checkpoint['reward_classifier']

In [None]:
# Training environment pickled with the checkpoint. The cells below use it for the
# observation bounds, the observation discretizer, and the goal position.
train_env = checkpoint['training_environment']

In [None]:
# Goal position read from the unwrapped environment's current observation dict;
# drawn as a white star on the reward plots.
target_pos = train_env.unwrapped._get_obs()['state_desired_goal']

In [None]:
# Width of the one-hot encoding used for each discretized coordinate.
# NOTE(review): the +1 presumably covers `_discretize_observation` returning bin indices
# in [0, n_bins] (n_bins + 1 distinct values) — confirm against the env.
n_bins = train_env.n_bins + 1

In [None]:
n_samples = 50
obs_space = train_env.unwrapped.observation_space['state_observation']
# One evenly spaced axis per coordinate, spanning the observation-space bounds.
xs, ys = (np.linspace(obs_space.low[axis], obs_space.high[axis], n_samples) for axis in (0, 1))

xys = np.meshgrid(xs, ys)

In [None]:
# Continuous (x, y) coordinates of every grid point, one row per point in meshgrid row-major
# order. They are kept under their own name so the classifier features below do not overwrite them.
grid_positions = np.array(xys).transpose(1, 2, 0).reshape((n_samples * n_samples, 2))
# Classifier inputs: each point's discretized bin indices, one-hot encoded with `n_bins` slots
# and flattened into a single row (assumes `_discretize_observation` returns integer indices).
grid_vals = np.array([
    np.eye(n_bins)[train_env.unwrapped._discretize_observation(position)].flatten()
    for position in grid_positions
])

In [None]:
# Sanity check: expect n_samples * n_samples rows, each as wide as the classifier input.
grid_vals.shape

In [None]:
# Classifier output for every grid point, in `grid_vals` row order; reshaped back onto the
# meshgrid when plotted.
rewards = reward_clf.predict(grid_vals)

In [None]:
# Print the reward classifier's architecture summary.
reward_clf.summary()

In [None]:
# Filled contour of the classifier reward over the grid, with the goal marked by a star.
fig, ax = plt.subplots(figsize=(8, 8))
ax.invert_yaxis()

reward_surface = ax.contourf(xys[0], xys[1], rewards.reshape(xys[0].shape), levels=20)
fig.colorbar(reward_surface, ax=ax, fraction=0.046, pad=0.04)

# Outline the wall for the BoxWall task.
if task == 'BoxWall-v1':
    ax.add_patch(Rectangle((-2, -2), 4, 4, alpha=1, fill=None, linewidth=4))

ax.scatter(*target_pos, marker='*', s=250, color='white')
ax.set_title(f'VICE Reward for {domain + task} Task @ Checkpoint #{checkpoint_to_analyze}\n'
             + f'Target Pos: {target_pos}')
plt.show()

## Plot All Checkpoints at Once

In [None]:
def _default_reward_features(train_env):
    """Return a function mapping an (N, 2) array of positions to reward-classifier inputs.

    If the unwrapped env defines `_discretize_observation`, each position's bin indices are
    one-hot encoded with `train_env.n_bins + 1` slots and flattened (the same encoding used to
    build `grid_vals` for the single-checkpoint plot). Otherwise positions pass through unchanged.
    NOTE(review): assumes a discretizing env is always paired with a one-hot classifier — confirm.
    """
    discretize = getattr(train_env.unwrapped, '_discretize_observation', None)
    if discretize is None:
        return lambda positions: positions
    one_hot_rows = np.eye(train_env.n_bins + 1)

    def one_hot(positions):
        return np.array([one_hot_rows[discretize(position)].flatten() for position in positions])

    return one_hot


def plot_vice_reward(clf, train_env, n_samples=50, levels=300, featurize=None):
    """Draw a filled contour of a reward classifier over the 2-D observation space.

    Draws into the current matplotlib axes (no figure is created or shown) and adds a colorbar
    and a white star at the env's goal.

    Args:
        clf: model exposing `predict(features)`, e.g. a checkpoint's 'reward_classifier'.
        train_env: the checkpoint's training environment; supplies the 'state_observation'
            bounds and the goal position.
        n_samples: grid points per axis (the grid has n_samples ** 2 points).
        levels: number of contour levels passed to `plt.contourf`.
        featurize: optional callable mapping an (N, 2) array of (x, y) positions to classifier
            inputs. When None, `_default_reward_features(train_env)` is used.

    Note: reads the notebook-level global `task` to decide whether to outline the BoxWall wall.
    """
    obs_space = train_env.observation_space['state_observation']
    xs = np.linspace(obs_space.low[0], obs_space.high[0], n_samples)
    ys = np.linspace(obs_space.low[1], obs_space.high[1], n_samples)

    xys = np.meshgrid(xs, ys)
    grid_positions = np.array(xys).transpose(1, 2, 0).reshape((n_samples * n_samples, 2))

    if featurize is None:
        featurize = _default_reward_features(train_env)
    # Predictions come back in grid_positions row order and are reshaped onto the meshgrid.
    rewards = clf.predict(featurize(grid_positions))
    plt.gca().invert_yaxis()
    plt.contourf(xys[0], xys[1], rewards.reshape(xys[0].shape), levels=levels)
    plt.colorbar(fraction=0.046, pad=0.04)
    target_pos = train_env.unwrapped._get_obs()['state_desired_goal']

    if task == 'BoxWall-v1':
        currentAxis = plt.gca()
        currentAxis.add_patch(Rectangle((-2, -2), 4, 4,
                              alpha=1, fill=None, linewidth=4))

    plt.scatter(*target_pos, marker='*', s=250, color='white')

In [None]:
def plot_grid(imgs, labels=None):
    """Show images in a near-square grid of subplots (one per cell), then call `plt.show()`.

    Args:
        imgs: sequence of images accepted by `plt.imshow`.
        labels: optional sequence of subplot titles, indexed in parallel with `imgs`.
    """
    n_images = len(imgs)
    if n_images == 0:
        # Nothing to draw; also avoids a 0-column grid and a NaN figure height.
        return
    # plt.subplot requires integer grid dimensions; np.sqrt / np.ceil return floats.
    n_columns = int(np.ceil(np.sqrt(n_images)))
    n_rows = int(np.ceil(n_images / n_columns))
    plt.figure(figsize=(5 * n_columns, 5 * n_rows))
    for i, img in enumerate(imgs):
        plt.subplot(n_rows, n_columns, i + 1)
        plt.axis('off')
        plt.imshow(img)
        if labels is not None:
            plt.title(labels[i], fontsize=20)
    plt.show()

In [None]:
def _checkpoint_number(path):
    """Return N for a path whose final component is `checkpoint_N`."""
    return int(os.path.basename(os.path.normpath(path)).rsplit('_', 1)[-1])


# Keep only `checkpoint_<integer>` directories so stray files or suffixed names
# cannot break the numeric sort below.
checkpoint_paths = [
    path for path in glob.glob(os.path.join(seed_path, 'checkpoint_*'))
    if os.path.isdir(path) and os.path.basename(path).rsplit('_', 1)[-1].isdigit()
]
# Sort by the checkpoint number at the end
checkpoint_paths = sorted(checkpoint_paths, key=_checkpoint_number)

In [None]:
# One reward-landscape subplot per checkpoint, titled with its checkpoint number.
n_plots = len(checkpoint_paths)
n_columns = int(np.sqrt(n_plots) + 1)
# plt.subplot needs integer grid dimensions (np.ceil returns a float); keep at least one row.
n_rows = max(1, int(np.ceil(n_plots / n_columns)))
plt.figure(figsize=(5 * n_columns, 5 * n_rows))

for i, path in enumerate(checkpoint_paths):
    # Cell-local names keep the notebook-level `checkpoint`, `reward_clf` and `train_env`
    # from being silently replaced by the last checkpoint's objects.
    with open(os.path.join(path, 'checkpoint.pkl'), 'rb') as f:
        ckpt = pickle.load(f)
    plt.subplot(n_rows, n_columns, i + 1, aspect=1)
    plot_vice_reward(ckpt['reward_classifier'], ckpt['training_environment'])
    plt.title(int(path.split("_")[-1]), fontsize=20)
    del ckpt  # drop the reference before unpickling the next checkpoint

plt.show()

## Plot Visitations

In [None]:
from softlearning.replay_pools.utils import get_replay_pool_from_variant

if not checkpoint_paths:
    raise FileNotFoundError(f'No checkpoint_* directories found under {seed_path}')

# Build an empty pool from the first checkpoint's variant and training environment...
with open(os.path.join(checkpoint_paths[0], 'checkpoint.pkl'), 'rb') as f:
    checkpoint = pickle.load(f)
variant = checkpoint['variant']
train_env = checkpoint['training_environment']
replay_pool = get_replay_pool_from_variant(variant, train_env)

# ...then load the experience saved next to each checkpoint, in `checkpoint_paths` order.
for path in checkpoint_paths:
    replay_pool_path = os.path.join(path, 'replay_pool.pkl')
    if not os.path.exists(replay_pool_path):
        print(f'Skipping {path}: no replay_pool.pkl')
        continue
    replay_pool.load_experience(replay_pool_path)

In [None]:
# Boolean mask over replay-pool rows whose 'state_observation' has any non-zero entry; used to
# drop rows that were never written (presumably zero-initialized — confirm).
# NOTE(review): a genuine visit to exactly (0, 0) is dropped as well; a fill count from the
# pool, if it exposes one, would be an exact alternative.
non_zero_rows = replay_pool.data[('observations', 'state_observation')].any(axis=-1)

In [None]:
state_obs = replay_pool.data[('observations', 'state_observation')]
visitations = state_obs[non_zero_rows]

fig, ax = plt.subplots(figsize=(8, 8))
ax.set_xlim(-4, 4)
ax.set_ylim(-4, 4)
ax.invert_yaxis()
# Semi-transparent env render stretched over [-4, 4] x [-4, 4], layered above the
# scatter points (zorder=3).
background = train_env.render('rgb_array')
ax.imshow(background,
          extent=(-4, 4, -4, 4),
          origin='lower',
          alpha=0.25,
          zorder=3,
          interpolation='nearest')

x_visits, y_visits = visitations[:, 0], visitations[:, 1]
ax.scatter(x_visits, y_visits, alpha=0.1)

In [None]:
# Show the environment's rendered frame by itself, in default pixel coordinates.
plt.imshow(train_env.render('rgb_array'))

## Plot Goal Examples

In [None]:
# Goal examples (positives) for the point-mass goal classifier. Override the location with the
# GOAL_EXAMPLES_PATH environment variable; the default keeps the original path, relative to ~.
# NOTE: pickle.load can execute arbitrary code; only open files you trust.
goal_examples_path = os.environ.get(
    'GOAL_EXAMPLES_PATH',
    os.path.expanduser('~/dev/vice/goal_classifier/pointmass_nowalls/bottom_middle/positives.pkl'))
with open(goal_examples_path, 'rb') as f:
    data = pickle.load(f)

In [None]:
# Scatter of the goal-example positions on the same [-4, 4] square as the other plots.
goal_positions = data['state_observation']
fig, ax = plt.subplots(figsize=(8, 8))

ax.scatter(goal_positions[:, 0], goal_positions[:, 1], s=5)
ax.set_xlim([-4, 4])
ax.set_ylim([-4, 4])
ax.invert_yaxis()

## Plot Ground Truth Rewards

In [None]:
# Continuous (x, y) grid positions rebuilt from `xys`. `grid_vals` was rebound to one-hot
# classifier features earlier in the notebook, so it is not a valid achieved-goal input.
gt_positions = np.array(xys).transpose(1, 2, 0).reshape((-1, 2))
feed_dict = {
    'state_achieved_goal': gt_positions,
    # Evaluate against the environment's goal (the white star drawn below) instead of a
    # fixed constant, so the ground truth matches the goal shown in the figure.
    'state_desired_goal': np.tile(np.asarray(target_pos), (len(gt_positions), 1)),
}
# NOTE: this switches the env object's reward type to 'sparse' for the rest of the session.
train_env.unwrapped.reward_type = 'sparse'
gtr = train_env.unwrapped.compute_rewards(None, feed_dict)
plt.figure(figsize=(8, 8))

plt.gca().invert_yaxis()

plt.contourf(xys[0], xys[1], gtr.reshape(xys[0].shape))
plt.colorbar(fraction=0.046, pad=0.04)

if task == 'BoxWall-v1':
    currentAxis = plt.gca()
    currentAxis.add_patch(Rectangle((-2, -2), 4, 4,
                          alpha=1, fill=None, linewidth=4))

plt.title(f'Ground Truth Reward for {domain + task} Task @ Checkpoint #{checkpoint_to_analyze}')

plt.scatter(*target_pos, marker='*', s=250, color='white')
plt.show()

## Q-Function Value Estimate Visualization

In [None]:
checkpoint_to_analyze = 100
checkpoint_dir = os.path.join(seed_path, f'checkpoint_{checkpoint_to_analyze}')

# Unpickle the checkpoint dict for the Q-function analysis (trusted files only: pickle can run code).
checkpoint_file = os.path.join(checkpoint_dir, 'checkpoint.pkl')
with open(checkpoint_file, 'rb') as checkpoint_fh:
    checkpoint = pickle.load(checkpoint_fh)

In [None]:
# Experiment configuration and training environment from the checkpoint, plus the env's goal.
variant, env = checkpoint['variant'], checkpoint['training_environment']
target_pos = env.unwrapped._get_obs()['state_desired_goal']

In [None]:
from softlearning.value_functions.utils import get_Q_function_from_variant

In [None]:
# Rebuild the Q-function model(s) described by the checkpoint's variant. The result is iterated
# as a sequence below; its weights are loaded in the next cell.
Qs = get_Q_function_from_variant(variant, env)

In [None]:
# Restore each Q-function from the `Qs_<index>` weights saved in the checkpoint directory.
for q_index, q_model in enumerate(Qs):
    q_model.load_weights(os.path.join(checkpoint_dir, f'Qs_{q_index}'))

In [None]:
n_action_samples = 20
# Draw actions from the env's action space and stack them as rows of one array.
action_draws = [env.action_space.sample() for _ in range(n_action_samples)]
sample_actions = np.vstack(action_draws)

In [None]:
n_samples = 50

obs_space = env.observation_space['state_observation']
xs, ys = (np.linspace(obs_space.low[dim], obs_space.high[dim], n_samples) for dim in range(2))

xys = np.meshgrid(xs, ys)
# (n, n) pair of coordinate grids -> (n, n, 2) -> one (x, y) row per grid point.
grid_vals = np.stack(xys, axis=-1).reshape((n_samples * n_samples, 2))

In [None]:
# Pair every grid position with every sampled action and evaluate each Q once on the whole batch:
# row k * n_action_samples + j pairs grid_vals[k] with sample_actions[j].
# (np.tile repeats whole rows; np.repeat(pos, n).reshape(n, -1) would repeat each coordinate
# separately and produce rows like (x, x) and (y, y) instead of (x, y).)
n_positions = len(grid_vals)
batch_observations = np.repeat(grid_vals, n_action_samples, axis=0)
batch_actions = np.tile(sample_actions, (n_positions, 1))

# NOTE(review): inputs are passed as [actions, observations], the order this cell always used;
# confirm it matches the Q-function's input signature.
per_q_values = np.stack([
    Q.predict([batch_actions, batch_observations]).reshape((n_positions, n_action_samples))
    for Q in Qs
])
# Pessimistic estimate per position: minimum over the Q ensemble and over the sampled actions.
value_estimates = per_q_values.min(axis=(0, 2))

In [None]:
# Filled contour of the Q-based value estimates over the grid, with the goal marked by a star.
fig, ax = plt.subplots(figsize=(8, 8))
ax.invert_yaxis()

value_grid = np.array(value_estimates).reshape(xys[0].shape)
value_surface = ax.contourf(xys[0], xys[1], value_grid)
fig.colorbar(value_surface, ax=ax, fraction=0.046, pad=0.04)

# Outline the wall for the BoxWall task.
if task == 'BoxWall-v1':
    ax.add_patch(Rectangle((-2, -2), 4, 4, alpha=1, fill=None, linewidth=4))

ax.scatter(*target_pos, marker='*', s=250, color='white')
ax.set_title(f'Value function estimates for {domain + task} Task @ Checkpoint #{checkpoint_to_analyze}\n'
             + f'Target Pos: {target_pos}')
plt.show()

In [None]:
# Print the architecture summary of the first Q-function model.
Qs[0].summary()

In [None]:
# Inspect the observation batch paired with the sampled actions for the first grid point:
# every row should be that same (x, y). np.tile repeats the whole row, whereas
# np.repeat(pos, n).reshape(n, -1) repeats each coordinate and yields (x, x) / (y, y) rows.
np.tile(grid_vals[0], (n_action_samples, 1))