In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import gzip
import tensorflow as tf
import glob
import math
import skimage
from matplotlib.patches import Rectangle

## Specify the Experiment Directory

In [None]:
# Root of the Ray results tree. Set the RAY_RESULTS_DIR environment variable to
# point elsewhere instead of editing this hard-coded default.
common_dir = os.environ.get('RAY_RESULTS_DIR', '/home/justinvyu/ray_results')
# Experiments are looked up under <common_dir>/<universe>/<domain>/<task>/.
universe = 'gym'
domain = 'Point2D'
task = 'Maze-v0'

In [None]:
base_path = os.path.join(common_dir, universe, domain, task)


def _prompt_index(options, root, prompt):
    """Print the numbered `options` (with the `root` prefix stripped) and return a valid index.

    Keeps asking until the answer is an integer in ``range(len(options))``, so a
    typo cannot silently pick e.g. the last entry through a negative index.

    Raises:
        FileNotFoundError: if `options` is empty (nothing to choose from).
    """
    if not options:
        raise FileNotFoundError(f'Nothing to choose from under {root}')
    for i, option in enumerate(options):
        print(f'{i} \t {option.replace(root, "")}')
    while True:
        answer = input(prompt)
        try:
            index = int(answer)
        except ValueError:
            print(f'{answer!r} is not a number.')
            continue
        if 0 <= index < len(options):
            return index
        print(f'Please enter a number between 0 and {len(options) - 1}.')


# Only directories are experiments / seeds; stray files in these folders are skipped.
exps = sorted(path for path in glob.glob(os.path.join(base_path, '*')) if os.path.isdir(path))
exp_choice = _prompt_index(exps, base_path, '\n Which experiment do you want to analyze? (ENTER A NUMBER) \t')
exp_path = exps[exp_choice]
print('\n')

seeds = sorted(path for path in glob.glob(os.path.join(exp_path, '*')) if os.path.isdir(path))
# TODO: Extend to analyzing all seeds
seed_choice = _prompt_index(seeds, exp_path, '\n Which seed do you want to analyze? (ENTER A NUMBER) \t')
seed_path = seeds[seed_choice]

print('PATH:\n', seed_path)

## Specify the Checkpoint

In [None]:
# Number N of the checkpoint_<N> directory (under the chosen seed) loaded by the next cell.
checkpoint_to_analyze = 200

In [None]:
# Load the pickled checkpoint for the chosen seed and checkpoint number.
checkpoint_dir = os.path.join(seed_path, f'checkpoint_{checkpoint_to_analyze}')
# NOTE: pickle.load can execute arbitrary code -- only open checkpoints you produced yourself.
with open(os.path.join(checkpoint_dir, 'checkpoint.pkl'), 'rb') as f:
    checkpoint = pickle.load(f)

## Load Distance Function

In [None]:
# Show which objects the checkpoint stores (the cells below read several of them).
checkpoint.keys()

In [None]:
# Learned distance model; the cells below call `distance_fn.predict([states, goals])`.
distance_fn = checkpoint['distance_estimator']

In [None]:
# Environment stored with the checkpoint; used below for its observation space, goal and renders.
train_env = checkpoint['training_environment']

In [None]:
# Goal of the environment's current observation (note: `_get_obs` is a private env method).
target_pos = train_env.unwrapped._get_obs()['state_desired_goal']

In [None]:
# Evaluation grid: n_samples evenly spaced values along each of the first two
# dimensions of the state observation space, combined with meshgrid.
n_samples = 50
obs_space = train_env.observation_space['state_observation']
xs, ys = (
    np.linspace(obs_space.low[axis], obs_space.high[axis], n_samples)
    for axis in (0, 1)
)

xys = np.meshgrid(xs, ys)

In [None]:
# Flatten the meshgrid into an (n_samples**2, 2) array of (x, y) states.
grid_vals = np.stack(xys, axis=-1).reshape((n_samples * n_samples, 2))

In [None]:
# One copy of the goal per grid state, so states and goals can be fed row by row.
goal_vals = np.tile(target_pos, (n_samples * n_samples, 1))

In [None]:
# Predicted distance from every grid state to the goal; reshaped to the grid when plotting.
dists = distance_fn.predict([grid_vals, goal_vals])

In [None]:
# Learned distance d(s, g) from every grid state to the goal, drawn under a faint
# render of the environment; the goal is marked with a white star.
plt.figure(figsize=(8, 8))

env_frame = train_env.render('rgb_array', width=32, height=32)
plt.imshow(env_frame, extent=(-4, 4, -4, 4), origin='lower', alpha=0.25, zorder=3)

ax = plt.gca()
ax.invert_yaxis()
dist_grid = dists.reshape(xys[0].shape)
plt.contourf(xys[0], xys[1], dist_grid, levels=20, zorder=1)
plt.colorbar(fraction=0.046, pad=0.04)

if task == 'BoxWall-v1':
    # 4x4 square outline centred on the origin (presumably the box wall of this task).
    ax.add_patch(Rectangle((-2, -2), 4, 4, alpha=1, fill=None, linewidth=4))

plt.scatter(*target_pos, marker='*', s=250, color='white', zorder=2)

plot_title = (f'd(s, g) for {domain + task} Task @ Checkpoint #{checkpoint_to_analyze}\n'
              f'Target Pos: {target_pos}')
plt.title(plot_title)
plt.show()

## Plot Evolution of Distance Function over Many Checkpoints

In [None]:
def plot_distance_to_goal(distance_fn, train_env, n_samples=50,
                          extent=(-4, 4, -4, 4), render_size=32, levels=20):
    """Draw the learned distance d(s, g) to the environment's goal on the current axes.

    States form an ``n_samples`` x ``n_samples`` grid spanning the first two
    dimensions of ``train_env.observation_space['state_observation']``. The goal g
    is ``train_env.unwrapped._get_obs()['state_desired_goal']``. The filled distance
    contour is drawn under a translucent render of the environment, with a colorbar
    and a white star at the goal. The caller is responsible for creating the
    figure/subplot beforehand.

    Args:
        distance_fn: object with ``predict([states, goals])`` returning one value per
            row (presumably a Keras model -- TODO confirm).
        train_env: environment with a dict observation space containing
            ``'state_observation'``, supporting ``render('rgb_array', width=, height=)``.
        n_samples: grid resolution along each axis.
        extent: ``(left, right, bottom, top)`` data coordinates of the environment render.
        render_size: width and height, in pixels, of the environment render.
        levels: number of contour levels.
    """
    obs_space = train_env.observation_space['state_observation']
    target_pos = train_env.unwrapped._get_obs()['state_desired_goal']

    xs = np.linspace(obs_space.low[0], obs_space.high[0], n_samples)
    ys = np.linspace(obs_space.low[1], obs_space.high[1], n_samples)
    xys = np.meshgrid(xs, ys)
    # (n_samples**2, 2) grid of states, each paired row-wise with the same goal.
    grid_vals = np.array(xys).transpose(1, 2, 0).reshape((n_samples * n_samples, 2))
    goal_vals = np.repeat(target_pos[None], n_samples * n_samples, axis=0)
    dists = distance_fn.predict([grid_vals, goal_vals])

    plt.imshow(train_env.render('rgb_array', width=render_size, height=render_size),
               extent=extent, origin='lower', alpha=0.25, zorder=3)

    plt.gca().invert_yaxis()
    plt.contourf(xys[0], xys[1], dists.reshape(xys[0].shape), levels=levels, zorder=1)
    plt.colorbar(fraction=0.046, pad=0.04)

    plt.scatter(*target_pos, marker='*', s=250, color='white', zorder=2)

In [None]:
def plot_grid(imgs, labels=None):
    """Show a sequence of images in one near-square grid of subplots.

    Args:
        imgs: sequence of images accepted by ``plt.imshow``.
        labels: optional sequence of subplot titles, one per image.

    Does nothing when ``imgs`` is empty.

    Raises:
        ValueError: if ``labels`` is given and its length differs from ``imgs``.
    """
    n_images = len(imgs)
    if n_images == 0:
        return
    if labels is not None and len(labels) != n_images:
        raise ValueError(f'Got {len(labels)} labels for {n_images} images.')

    # plt.subplot requires integer row/column counts; np.sqrt/np.ceil return floats.
    n_columns = int(np.ceil(np.sqrt(n_images)))
    n_rows = int(np.ceil(n_images / n_columns))
    plt.figure(figsize=(5 * n_columns, 5 * n_rows))
    for i, img in enumerate(imgs):
        plt.subplot(n_rows, n_columns, i + 1)
        plt.axis('off')
        plt.imshow(img)
        if labels is not None:
            plt.title(labels[i], fontsize=20)
    plt.show()

In [None]:
# Every saved checkpoint of this seed, ordered numerically by the trailing
# checkpoint number (so checkpoint_100 comes after checkpoint_20).
checkpoint_paths = sorted(
    glob.glob(os.path.join(seed_path, 'checkpoint_*')),
    key=lambda path: int(path.rsplit('_', 1)[-1]),
)

In [None]:
# One distance-contour subplot per saved checkpoint, titled with its number.
# NOTE: the loop rebinds `checkpoint`, `distance_fn` and `train_env` globally, so later
# cells see the objects of the LAST checkpoint.
n_plots = len(checkpoint_paths)
n_columns = int(np.sqrt(n_plots) + 1)
# plt.subplot requires integer row/column counts; np.ceil returns a float.
n_rows = int(np.ceil(n_plots / n_columns))
plt.figure(figsize=(5 * n_columns, 5 * n_rows))

for i, path in enumerate(checkpoint_paths):
    # NOTE: pickle.load can execute arbitrary code -- only load trusted checkpoints.
    with open(os.path.join(path, 'checkpoint.pkl'), 'rb') as f:
        checkpoint = pickle.load(f)
    distance_fn = checkpoint['distance_estimator']
    train_env = checkpoint['training_environment']
    plt.subplot(n_rows, n_columns, i + 1, aspect=1)
    plot_distance_to_goal(distance_fn, train_env)
    plt.title(int(path.split("_")[-1]), fontsize=20)

plt.show()

## Generate a GIF of the Distance Function Across Checkpoints

In [None]:
# Render one distance-contour frame per checkpoint and collect the RGB arrays in
# `imgs` for the GIF written by the next cell.
checkpoint_paths = list(glob.iglob(os.path.join(seed_path, 'checkpoint_*')))
# Sort by the checkpoint number at the end
checkpoint_paths = sorted(checkpoint_paths, key=lambda s: int(s.split("_")[-1]))

imgs = []
for i, path in enumerate(checkpoint_paths):
    fig = plt.figure(figsize=(8, 8))
    # NOTE: pickle.load can execute arbitrary code -- only load trusted checkpoints.
    with open(os.path.join(path, 'checkpoint.pkl'), 'rb') as f:
        checkpoint = pickle.load(f)
    distance_fn = checkpoint['distance_estimator']
    train_env = checkpoint['training_environment']
    plot_distance_to_goal(distance_fn, train_env)
    plt.title(int(path.split("_")[-1]), fontsize=20)
    fig.canvas.draw()
    # buffer_rgba() replaces the deprecated tostring_rgb()/np.fromstring pair and is
    # sized in physical pixels, so there is no reshape mismatch on HiDPI canvases.
    # ascontiguousarray copies the pixels, so the frame survives closing the figure.
    data = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[..., :3])
    imgs.append(data)
    # Close each frame; otherwise the inline backend keeps every figure in memory
    # and displays all of them at the end of the cell.
    plt.close(fig)

In [None]:
import imageio
# Write the frames collected in `imgs` as an animated GIF in the working directory.
# NOTE(review): the unit of `duration` depends on the imageio version / plugin
# (seconds in the legacy v2 behaviour, milliseconds with the newer pillow plugin);
# confirm the installed version actually gives one frame per second.
imageio.mimsave('./test.gif', imgs, duration=1.0)

## Plot Goal Examples

In [None]:
# Goal examples ("positives", presumably for the goal classifier, per the path);
# `data['state_observation']` is plotted in the next cell.
# NOTE: hard-coded absolute path on the author's machine; pickle.load can execute
# arbitrary code, so only load files you trust.
with open('/home/justinvyu/dev/vice/goal_classifier/pointmass_nowalls/bottom_middle/positives.pkl', 'rb') as f:
    data = pickle.load(f)

In [None]:
# Scatter of the loaded goal-example states (first two coordinates of
# 'state_observation') on the same [-4, 4] axes and inverted y-axis as the distance plots.
plt.figure(figsize=(8, 8))

goal_states = data['state_observation']
plt.scatter(goal_states[:, 0], goal_states[:, 1], s=5)
plt.xlim([-4, 4])
plt.ylim([-4, 4])
plt.gca().invert_yaxis()
plt.xlabel('x')
plt.ylabel('y')
plt.title('Goal examples (positives.pkl)')
plt.show()

## Plot Ground Truth Rewards

In [None]:
# Sparse ground-truth reward of the environment over the state grid, evaluated
# against the same goal (`target_pos`) that is drawn as the white star.
feed_dict = {
    'state_achieved_goal': grid_vals,
    # Use the plotted goal; a hard-coded (2, 2) goal disagreed with the star
    # whenever `target_pos` lay elsewhere.
    'state_desired_goal': np.repeat(target_pos[None], len(grid_vals), axis=0),
}
unwrapped_env = train_env.unwrapped
_missing = object()
previous_reward_type = getattr(unwrapped_env, 'reward_type', _missing)
unwrapped_env.reward_type = 'sparse'
try:
    gtr = unwrapped_env.compute_rewards(None, feed_dict)
finally:
    # Undo the temporary switch so later cells see the env's original reward mode.
    if previous_reward_type is _missing:
        del unwrapped_env.reward_type
    else:
        unwrapped_env.reward_type = previous_reward_type

plt.figure(figsize=(8, 8))
plt.gca().invert_yaxis()

plt.contourf(xys[0], xys[1], gtr.reshape(xys[0].shape))
plt.colorbar(fraction=0.046, pad=0.04)

if task == 'BoxWall-v1':
    # 4x4 square outline centred on the origin (presumably the box wall of this task).
    plt.gca().add_patch(Rectangle((-2, -2), 4, 4, alpha=1, fill=None, linewidth=4))

plt.title(f'Ground Truth Reward for {domain + task} Task @ Checkpoint #{checkpoint_to_analyze}')

plt.scatter(*target_pos, marker='*', s=250, color='white')
plt.show()

## Q-Function Value Estimates

In [None]:
# Switch to another checkpoint for the Q-function analysis. NOTE: this overwrites
# `checkpoint_to_analyze`, `checkpoint_dir` and `checkpoint`, so the titles and
# loads in all later cells refer to this checkpoint.
checkpoint_to_analyze = 100
checkpoint_dir = os.path.join(seed_path, f'checkpoint_{checkpoint_to_analyze}')

# NOTE: pickle.load can execute arbitrary code -- only open checkpoints you produced yourself.
with open(os.path.join(checkpoint_dir, 'checkpoint.pkl'), 'rb') as f:
    checkpoint = pickle.load(f)

In [None]:
# Training configuration and environment of the reloaded checkpoint, plus the env's current goal.
variant, env = checkpoint['variant'], checkpoint['training_environment']
target_pos = env.unwrapped._get_obs()['state_desired_goal']

In [None]:
from softlearning.value_functions.utils import get_Q_function_from_variant

In [None]:
# Rebuild the Q networks from the training variant; their weights are loaded in the next cell.
Qs = get_Q_function_from_variant(variant, env)

In [None]:
# Restore each Q network from its `Qs_<index>` weights file in the checkpoint directory.
for q_index, q_network in enumerate(Qs):
    q_network.load_weights(os.path.join(checkpoint_dir, f'Qs_{q_index}'))

In [None]:
# Random actions at which Q(s, a) is evaluated for every grid state.
# NOTE: the action space is not seeded, so these samples (and the plot) change on every run.
n_action_samples = 20
sample_actions = np.vstack([env.action_space.sample() for _ in range(n_action_samples)])

In [None]:
# State grid for the Q-function plot, spanning this env's observation bounds.
n_samples = 50

obs_space = env.observation_space['state_observation']
xs, ys = (
    np.linspace(low, high, n_samples)
    for low, high in zip(obs_space.low[:2], obs_space.high[:2])
)

xys = np.meshgrid(xs, ys)
grid_vals = np.stack(xys, axis=-1).reshape((n_samples * n_samples, 2))

In [None]:
# Value estimate per grid state: min over the Q ensemble and over the sampled
# actions of Q(s, a), computed with one batched predict call per Q network.
n_grid = len(grid_vals)
# Each state repeated once per sampled action: row k*A + j holds grid_vals[k].
# (np.repeat(pos, A) on a 1-D `pos` repeats each *coordinate*, so reshaping it gave
# rows [x, x], ..., [y, y] instead of [x, y].)
obs_batch = np.repeat(grid_vals, n_action_samples, axis=0)
# Matching actions: row k*A + j holds sample_actions[j].
action_batch = np.tile(sample_actions, (n_grid, 1))
# NOTE(review): inputs are passed as [actions, observations] as in the original cell;
# softlearning Q networks are usually built with observations first -- confirm the
# input order of this fork's Q model.
q_values = np.stack([
    np.asarray(Q.predict([action_batch, obs_batch])).reshape(n_grid, n_action_samples)
    for Q in Qs
])
value_estimates = q_values.min(axis=(0, 2))

In [None]:
# Contour of the value estimates from the previous cell over the state grid;
# the goal is marked with a white star.
plt.figure(figsize=(8, 8))
ax = plt.gca()
ax.invert_yaxis()

value_grid = np.array(value_estimates).reshape(xys[0].shape)
plt.contourf(xys[0], xys[1], value_grid)
plt.colorbar(fraction=0.046, pad=0.04)

if task == 'BoxWall-v1':
    # 4x4 square outline centred on the origin (presumably the box wall of this task).
    ax.add_patch(Rectangle((-2, -2), 4, 4, alpha=1, fill=None, linewidth=4))

plt.scatter(*target_pos, marker='*', s=250, color='white')
plt.title(f'Value function estimates for {domain + task} Task @ Checkpoint #{checkpoint_to_analyze}\n'
          f'Target Pos: {target_pos}')
plt.show()

In [None]:
# Sanity check of the per-state observation batch fed to the Q networks: every row
# should equal grid_vals[0]. (np.repeat on the 1-D position instead repeats each
# coordinate, producing rows like [x, x] and [y, y].)
np.tile(grid_vals[0], (n_action_samples, 1))

## Embedding-Space Distances

In [None]:
# Embedding-space distance |phi(g) - phi(s)| over a grid of states, with the
# checkpoint's `distance_estimator` used as the embedding phi.
embedding_fn = checkpoint['distance_estimator']
train_env = checkpoint['training_environment']
target_pos = train_env.unwrapped._get_obs()['state_desired_goal']

n_samples = 50
obs_space = train_env.observation_space['state_observation']
xs, ys = (
    np.linspace(obs_space.low[axis], obs_space.high[axis], n_samples)
    for axis in (0, 1)
)
xys = np.meshgrid(xs, ys)

grid_vals = np.stack(xys, axis=-1).reshape((n_samples * n_samples, 2))
goal_vals = np.tile(target_pos, (n_samples * n_samples, 1))
goal_embeddings = embedding_fn.predict(goal_vals)
state_embeddings = embedding_fn.predict(grid_vals)
dists = np.linalg.norm(goal_embeddings - state_embeddings, axis=-1)

In [None]:
# Embedding-distance contour under a 256x256 render of the environment; the goal
# is marked with a white star.
plt.figure(figsize=(8, 8))

env_frame = train_env.render('rgb_array', width=256, height=256)
plt.imshow(env_frame, extent=(-4, 4, -4, 4), origin='lower', alpha=0.25, zorder=3)
plt.gca().invert_yaxis()

dist_grid = dists.reshape(xys[0].shape)
plt.contourf(xys[0], xys[1], dist_grid, levels=50, zorder=1)
plt.colorbar(fraction=0.046, pad=0.04)

plt.scatter(*target_pos, marker='*', s=250, color='white', zorder=2)
plt.title(f'|phi(g) - phi(s)| for {domain + task} Task @ Checkpoint #{checkpoint_to_analyze}\n'
          f'Target Pos: {target_pos}')
plt.show()

In [None]:
# Embedding of the goal position (batch of one); plotted as the reference point in embedding space.
embedded_goal = embedding_fn.predict(target_pos[None])

In [None]:
# Circles of radius r around the goal (angle step pi/30), clipped to the [-4, 4]
# arena and mapped through the embedding, to see how Euclidean rings deform.
radii = np.arange(0.5, 4, 0.5)
thetas = np.arange(0, 2 * np.pi + np.pi / 30, np.pi / 30)
unit_circle = np.stack([np.cos(thetas), np.sin(thetas)], axis=-1)

pts_by_radius = []
for r in radii:
    ring = target_pos + r * unit_circle
    in_arena = np.all((ring >= -4) & (ring <= 4), axis=1)
    if in_arena.any():
        # One batched predict per radius instead of one call per point.
        pts_by_radius.append(embedding_fn.predict(ring[in_arena]))
    else:
        # Keep one (empty) entry per radius so the plot legend stays aligned with `radii`.
        pts_by_radius.append(np.empty((0, 2)))

In [None]:
# Points just inside the four edges of the [-4, 4] x [-4, 4] arena (0.1 margin),
# stacked in the order: right, left, bottom, top edge.
border_lo, border_hi = -4 + 0.1, 4 - 0.1
border_range = np.arange(border_lo, border_hi, 0.1).reshape(-1, 1)
hi_column = np.full(border_range.shape, border_hi)
lo_column = np.full(border_range.shape, border_lo)
border_pts = np.vstack([
    np.hstack((hi_column, border_range)),   # x = 3.9, y sweeps the range
    np.hstack((lo_column, border_range)),   # x = -3.9
    np.hstack((border_range, lo_column)),   # y = -3.9, x sweeps the range
    np.hstack((border_range, hi_column)),   # y = 3.9
])

In [None]:
# Embedding of the arena-border points (currently only referenced by commented-out plotting code).
embedded_border = embedding_fn.predict(border_pts)

In [None]:
# Embedding space: the embedded goal, the embedded rings around it (one line per
# radius) and, if already computed, the embedded replay-pool trajectory.
plt.figure(figsize=(8, 8))
# Hand-picked limits for the embedding learned in this run -- adjust per run.
plt.xlim(-5, 5)
plt.ylim(-10, 0)
plt.scatter(embedded_goal[0][0], embedded_goal[0][1], label='goal')
# Label each ring explicitly: plt.legend(labels) assigns labels by artist order,
# which can give the first radius label to the goal scatter.
for r, pts in zip(radii, pts_by_radius):
    plt.plot(pts[:, 0], pts[:, 1], label=f'r = {r}')

# `embedded_trajectory` is created by a cell further down (replay-pool section);
# skip the overlay on a fresh top-to-bottom run instead of raising NameError.
if 'embedded_trajectory' in globals():
    plt.plot(embedded_trajectory[:, 0], embedded_trajectory[:, 1], 'black', label='trajectory')
plt.legend()
# plt.plot(embedded_border)
# plt.quiver(embedded_trajectory[:-1, 0],
#            embedded_trajectory[:-1, 1],
#            embedded_actions[:, 0],
#            embedded_actions[:, 1],
#            color='black',
#            alpha=0.5,
#            linewidth=2,
#            headwidth=4)

In [None]:
# A replay-pool trajectory in state space, drawn under a faint render of the
# environment, with the goal marked as a white star.
plt.figure(figsize=(8, 8))

plt.imshow(train_env.render('rgb_array', width=256, height=256),
           extent=(-4, 4, -4, 4), origin='lower', alpha=0.25, zorder=3)

plt.scatter(*target_pos, marker='*', s=250, color='white', zorder=2)

# `sample_trajectory` is created by a cell further down (replay-pool section);
# skip the trajectory on a fresh top-to-bottom run instead of raising NameError.
if 'sample_trajectory' in globals():
    trajectory = sample_trajectory['observations']['state_observation']
    next_obs = sample_trajectory['next_observations']['state_observation']
    # Per-step state displacement (next state minus current state).
    actions = next_obs - trajectory
    plt.plot(trajectory[:, 0], trajectory[:, 1])

plt.gca().invert_yaxis()

# No distance contour is drawn here, so the title describes the trajectory plot.
plt.title(f'Replay-pool trajectory for {domain + task} Task @ Checkpoint #{checkpoint_to_analyze}\n'
          + f'Target Pos: {target_pos}')
plt.show()

In [None]:
# List the objects stored in the currently loaded checkpoint.
checkpoint.keys()

In [None]:
from softlearning.replay_pools.utils import get_replay_pool_from_variant

# Rebuild the replay pool from the training variant, then fill it with the
# experience saved next to the checkpoint (replay_pool.pkl).
variant = checkpoint['variant']
train_env = checkpoint['training_environment']
replay_pool = get_replay_pool_from_variant(variant, train_env)

replay_pool_path = os.path.join(checkpoint_dir, 'replay_pool.pkl')
replay_pool.load_experience(replay_pool_path)

In [None]:
# The 100 most recent samples in the pool. NOTE(review): these may span more than one
# episode, in which case the plotted "trajectory" contains jumps -- confirm.
sample_trajectory = replay_pool.last_n_batch(100)

In [None]:
# Embed the sampled states; consecutive differences give the per-step
# displacement of the trajectory in embedding space.
trajectory_states = sample_trajectory['observations']['state_observation']
embedded_trajectory = embedding_fn.predict(trajectory_states)
embedded_actions = np.diff(embedded_trajectory, axis=0)