In [18]:
import torch
import numpy as np
import os
from pathlib import Path
import yaml
import ipywidgets as widgets
from ipywidgets import interact, Output, GridspecLayout
from IPython.display import display
import mediapy as media
import ast
import matplotlib.pyplot as plt
# Wide, short default figure: create_discrim_graphs lays all of its subplots out in a single row.
plt.rcParams["figure.figsize"] = (20,3)

In [19]:
# Run whose rollouts are visualized; the cells below read ../data/<run_name>/rollouts.
run_name = 'LunarLander-v2_mlp_3_naive_1709239562'
# Keep the config *file name* and the parsed config *dict* under different names;
# reusing `config` for both made the cell's data flow hard to follow.
config_file = 'lunarlander_lang_naive_s3.yml'

path = os.path.join('../config', config_file)
# safe_load (not yaml.load) so the YAML cannot construct arbitrary Python objects.
config = yaml.safe_load(Path(path).read_text(encoding='utf-8'))
print(config)

{'action_size': 4, 'batch_size': 256, 'buffer_size': 10000, 'discrim_lr': 0.0001, 'discrim_momentum': 0.99, 'env_name': 'LunarLander-v2', 'episodes': 5000, 'eps_decay': 0.995, 'eps_end': 0.01, 'eps_start': 1.0, 'gamma': 0.95, 'max_steps_per_episode': 300, 'policy_lr': 0.001, 'skill_size': 3, 'state_size': 8, 'tau': 0.01, 'update_every': 1, 'exp_type': 'mlp', 'discrim_units': 512, 'policy_units': 512, 'embedding_type': 'naive', 'embedding_size': 3}


In [20]:
def get_folders(path, substrings):
    """Return the sorted names of sub-directories of `path` containing every substring.

    Args:
        path: Directory to scan (non-recursively).
        substrings: Iterable of strings. A folder matches only if each one occurs
            somewhere in its name. This is a plain substring test, so e.g.
            'iter500' also matches 'iter5000_...'.

    Returns:
        Lexicographically sorted list of matching directory names (not full
        paths). Note that lexicographic order puts 'iter1000' before 'iter500'.
    """
    # Materialize once: a generator argument would otherwise be exhausted
    # after the first directory entry is checked.
    required = tuple(substrings)
    with os.scandir(path) as entries:
        # Only directories: callers list the contents of every returned name.
        matches = [
            entry.name
            for entry in entries
            if entry.is_dir() and all(part in entry.name for part in required)
        ]
    return sorted(matches)

def create_discrim_graphs(folders):
    """Plot each rollout's discriminator probabilities side by side in one row.

    For every folder name in `folders`, reads
    ``../data/<run_name>/rollouts/<folder>/discriminator_probs.txt`` (a Python
    literal parsed with ``ast.literal_eval``) and plots it on its own subplot,
    with the y-axis fixed to [-0.05, 1.05]. Relies on the notebook-level
    ``run_name``.

    Args:
        folders: Rollout folder names, e.g. as returned by ``get_folders``.
            An empty list draws nothing.
    """
    if not folders:
        # plt.subplots(1, 0) raises; nothing was selected, so draw nothing.
        return
    path_prefix = os.path.join('../data/', run_name, 'rollouts')
    # squeeze=False keeps `axs` 2-D even for a single folder, where plain
    # plt.subplots(1, 1) would return a bare Axes that cannot be indexed.
    _, axs = plt.subplots(1, len(folders), squeeze=False)
    for ax, folder in zip(axs[0], folders):
        full_path = os.path.join(path_prefix, folder, 'discriminator_probs.txt')
        with open(full_path, 'r') as f:
            discriminator_probs = ast.literal_eval(f.read())
        ax.plot(discriminator_probs)
        ax.set_ylim(-0.05, 1.05)
        # Label each panel so it can be matched to its video.
        ax.set_title(folder, fontsize='small')
    plt.show()

## Visualize all rollouts for a specific timestep $t$ and skill $z$

In [21]:
# Sliders: training iteration t (folders named iter{t}) and skill index z.
# NOTE(review): max=4500/step=500 is hard-coded; presumably it mirrors this run's
# rollout save schedule — TODO derive it from the data folders.
cell1_t = widgets.IntSlider(min=0, max=4500, step=500, description='Timestep:', continuous_update=False,
                            orientation='horizontal', readout=True, readout_format='d')
cell1_z = widgets.IntSlider(min=0, max=config["skill_size"] - 1, step=1, description='Skill:', continuous_update=False,
                            orientation='horizontal', readout=True, readout_format='d')

def display_cell1_videos(t, z):
    """Show every saved rollout of skill `z` at training iteration `t`.

    Scans ``../data/<run_name>/rollouts`` for folders whose first two
    '_'-separated name tokens are exactly ``iter{t}`` and ``skill{z}``. It shows
    the alphabetically first ``.mp4`` from each, titled ``r=<rollout index>``
    (taken from the third token). It then plots the matching discriminator
    probabilities via ``create_discrim_graphs``.

    Args:
        t: Training iteration, matched against the first name token.
        z: Skill index, matched against the second name token.
    """
    print('Generating videos...')
    path_prefix = os.path.join('../data/', run_name, 'rollouts')
    iter_token, skill_token = f'iter{t}', f'skill{z}'
    # get_folders matches plain substrings ('iter500' also hits 'iter5000_...'),
    # so keep only exact token matches.
    folders = []
    for name in get_folders(path_prefix, [iter_token, skill_token]):
        tokens = name.split('_')
        if len(tokens) > 2 and tokens[0] == iter_token and tokens[1] == skill_token:
            folders.append(name)
    if not folders:
        # Without this, create_discrim_graphs would be asked for zero subplots.
        print(f'No rollouts found for t={t}, z={z}')
        return

    videos = {}
    for folder in folders:
        rollout_path = os.path.join(path_prefix, folder)
        # Sort so the chosen video is deterministic; os.listdir order is arbitrary.
        clips = sorted(fn for fn in os.listdir(rollout_path) if fn.endswith('.mp4'))
        if not clips:
            continue
        rollout_idx = folder.split('_')[2].split('rollout')[-1]
        videos[f'r={rollout_idx}'] = media.read_video(os.path.join(rollout_path, clips[0]))

    if videos:
        media.show_videos(videos, fps=20, height=200)
    create_discrim_graphs(folders)
    print('Generated')

# Trailing ';' suppresses the function repr that interact() returns.
interact(display_cell1_videos, t=cell1_t, z=cell1_z);

interactive(children=(IntSlider(value=0, continuous_update=False, description='Timestep:', max=4500, step=500)…

<function __main__.display_cell1_videos(t, z)>

## Visualize all skills for a specific timestep $t$ and rollout $r$ (diversity of skills)

In [22]:
# Sliders for this cell get their own cell2_* names instead of silently
# overwriting the previous cell's cell1_* widgets.
# NOTE(review): max=4500/step=500 and the rollout max=4 are hard-coded;
# presumably they mirror what was saved for this run — TODO confirm.
cell2_t = widgets.IntSlider(min=0, max=4500, step=500, description='Timestep:', continuous_update=False,
                            orientation='horizontal', readout=True, readout_format='d')
cell2_r = widgets.IntSlider(min=0, max=4, step=1, description='Rollout:', continuous_update=False,
                            orientation='horizontal', readout=True, readout_format='d')

def display_cell2_videos(t, r):
    """Show rollout `r` of every skill at training iteration `t`.

    Scans ``../data/<run_name>/rollouts`` for folders whose first and third
    '_'-separated name tokens are exactly ``iter{t}`` and ``rollout{r}``. It
    shows the alphabetically first ``.mp4`` from each, titled ``z=<skill index>``
    (taken from the second token). It then plots the matching discriminator
    probabilities via ``create_discrim_graphs``.

    Args:
        t: Training iteration, matched against the first name token.
        r: Rollout index, matched against the third name token.
    """
    print('Generating videos...')
    path_prefix = os.path.join('../data/', run_name, 'rollouts')
    iter_token, rollout_token = f'iter{t}', f'rollout{r}'
    # get_folders matches plain substrings ('iter500' also hits 'iter5000_...'),
    # so keep only exact token matches.
    folders = []
    for name in get_folders(path_prefix, [iter_token, rollout_token]):
        tokens = name.split('_')
        if len(tokens) > 2 and tokens[0] == iter_token and tokens[2] == rollout_token:
            folders.append(name)
    if not folders:
        # Without this, create_discrim_graphs would be asked for zero subplots.
        print(f'No rollouts found for t={t}, r={r}')
        return

    videos = {}
    for folder in folders:
        rollout_path = os.path.join(path_prefix, folder)
        # Sort so the chosen video is deterministic; os.listdir order is arbitrary.
        clips = sorted(fn for fn in os.listdir(rollout_path) if fn.endswith('.mp4'))
        if not clips:
            continue
        skill_idx = folder.split('_')[1].split('skill')[-1]
        videos[f'z={skill_idx}'] = media.read_video(os.path.join(rollout_path, clips[0]))

    if videos:
        media.show_videos(videos, fps=20, height=200)
    create_discrim_graphs(folders)
    print('Generated')

# Trailing ';' suppresses the function repr that interact() returns.
interact(display_cell2_videos, t=cell2_t, r=cell2_r);

interactive(children=(IntSlider(value=0, continuous_update=False, description='Timestep:', max=4500, step=500)…

<function __main__.display_cell2_videos(t, r)>

## Visualize all timesteps for a specific skill $z$ and rollout $r$ (progression of the skill over time)

In [23]:
# Sliders for this cell get their own cell3_* names instead of silently
# overwriting the first cell's cell1_* widgets.
# NOTE(review): the rollout max=4 is hard-coded; presumably five rollouts were
# saved per (t, z) — TODO confirm.
cell3_z = widgets.IntSlider(min=0, max=config["skill_size"] - 1, step=1, description='Skill:', continuous_update=False,
                            orientation='horizontal', readout=True, readout_format='d')
cell3_r = widgets.IntSlider(min=0, max=4, step=1, description='Rollout:', continuous_update=False,
                            orientation='horizontal', readout=True, readout_format='d')

def display_cell3_videos(z, r):
    """Show rollout `r` of skill `z` at every saved training iteration, in order.

    Scans ``../data/<run_name>/rollouts`` for folders whose second and third
    '_'-separated name tokens are exactly ``skill{z}`` and ``rollout{r}``. It
    orders them numerically by the iteration in the first token (``iter{t}``)
    and shows the alphabetically first ``.mp4`` from each, titled
    ``t=<iteration>``. It then plots the matching discriminator probabilities
    via ``create_discrim_graphs``.

    Args:
        z: Skill index, matched against the second name token.
        r: Rollout index, matched against the third name token.
    """
    print('Generating videos...')
    path_prefix = os.path.join('../data/', run_name, 'rollouts')
    skill_token, rollout_token = f'skill{z}', f'rollout{r}'
    # get_folders matches plain substrings ('rollout1' would also hit
    # 'rollout10'), so keep only exact token matches.
    folders = []
    for name in get_folders(path_prefix, [rollout_token, skill_token]):
        tokens = name.split('_')
        if len(tokens) > 2 and tokens[1] == skill_token and tokens[2] == rollout_token:
            folders.append(name)
    if not folders:
        # Without this, create_discrim_graphs would be asked for zero subplots.
        print(f'No rollouts found for z={z}, r={r}')
        return
    # Numeric sort so iter500 precedes iter1000 (get_folders sorts lexicographically).
    folders.sort(key=lambda name: int(name.split('_')[0].split('iter')[-1]))

    videos = {}
    for folder in folders:
        rollout_path = os.path.join(path_prefix, folder)
        # Sort so the chosen video is deterministic; os.listdir order is arbitrary.
        clips = sorted(fn for fn in os.listdir(rollout_path) if fn.endswith('.mp4'))
        if not clips:
            continue
        iteration = folder.split('_')[0].split('iter')[-1]
        videos[f't={iteration}'] = media.read_video(os.path.join(rollout_path, clips[0]))

    if videos:
        media.show_videos(videos, fps=20, height=200)
    create_discrim_graphs(folders)
    print('Generated')

# Trailing ';' suppresses the function repr that interact() returns.
interact(display_cell3_videos, z=cell3_z, r=cell3_r);

interactive(children=(IntSlider(value=0, continuous_update=False, description='Skill:', max=2), IntSlider(valu…

<function __main__.display_cell3_videos(z, r)>