In [48]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import yaml
from pathlib import Path
import os
import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import display

%cd /nas/ucb/arvindrajaraman/language-skills/diayn

from models import Discriminator
from captioner_ll import naive_captioner, language_captioner

/nas/ucb/arvindrajaraman/language-skills/diayn


In [55]:
# Select the training run to inspect and the YAML config it was trained with

run_name = 'LunarLander-v2_mlp_3_naive_1709239562'
config_name = 'lunarlander_lang_naive_s3.yml'
iteration = 1500

config_path = Path('config') / config_name
config = yaml.safe_load(config_path.read_text())
print(config)

# Resolve the state-embedding function named by config["embedding_type"].
# "identity" passes the state through unchanged; the other two options map to
# the captioners imported from captioner_ll.
embedding_type = config["embedding_type"]
embedding_choices = {
    "identity": lambda x: x,
    "naive": naive_captioner,
    "language": language_captioner,
}
if embedding_type not in embedding_choices:
    raise ValueError(f"Invalid embedding type: {config['embedding_type']}")
embedding_fn = embedding_choices[embedding_type]

# Load discriminator
#
# Rebuild the network with the sizes recorded in the config, then restore the
# weights saved for `iteration` of `run_name`.
discriminator = Discriminator(
    state_embedding_size=config["embedding_size"],
    skill_size=config["skill_size"],
    fc1_units=config["discrim_units"],
    fc2_units=config["discrim_units"],
)
# The model is wrapped before load_state_dict, so the checkpoint keys must match
# the wrapped module's keys — presumably it was saved from a DataParallel model
# during training; TODO confirm against the training script.
discriminator = nn.DataParallel(discriminator)
discriminator.to('cuda:0')

checkpoint_path = f'./data/{run_name}/discriminator_iter{iteration}.pth'
# map_location pins deserialization to the device the model lives on, so a
# checkpoint written from a different GPU index still loads here.
discriminator.load_state_dict(torch.load(checkpoint_path, map_location='cuda:0'))
# Inference only from here on: put the module in evaluation mode.
discriminator.eval()

def _make_slider(description, lo, hi, step):
    """Build one horizontal FloatSlider for a fixed (non-grid) state feature.

    Args:
        description: Label shown next to the slider.
        lo: Minimum slider value.
        hi: Maximum slider value.
        step: Increment between selectable values.

    Returns:
        A `widgets.FloatSlider` that only reports a new value when the handle
        is released (continuous_update=False).
    """
    return widgets.FloatSlider(
        min=lo,
        max=hi,
        step=step,
        description=description,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='.1f',
        # The labels are longer than the default description width and would
        # be cut off; 'initial' lets the label take the width it needs.
        style={'description_width': 'initial'},
    )


lin_vel_x = _make_slider('Linear Velocity (x):', -5., 5., 0.1)
lin_vel_y = _make_slider('Linear Velocity (y):', -5., 5., 0.1)
angle = _make_slider('Angle:', -3.14, 3.14, 0.1)
ang_vel = _make_slider('Angular Velocity:', -5., 5., 0.1)
# Leg contact is a 0/1 flag, so these two sliders step directly between 0 and 1.
left_leg = _make_slider('Left Leg on Ground?', 0., 1., 1.0)
right_leg = _make_slider('Right Leg on Ground?', 0., 1., 1.0)

def update(lin_vel_x, lin_vel_y, angle, ang_vel, left_leg, right_leg):
    """Plot the discriminator's skill prediction over a 2-D grid of states.

    The first two state features are swept over [-1.5, 1.5] in steps of 0.01
    while the remaining six are held at the given values. Every grid state is
    embedded with the notebook-level `embedding_fn`, scored by the notebook-level
    `discriminator`, and the three outputs per state are drawn as the red, green
    and blue channels of a colour mesh.

    Args:
        lin_vel_x: Fixed value for state feature 2 (slider "Linear Velocity (x)").
        lin_vel_y: Fixed value for state feature 3 (slider "Linear Velocity (y)").
        angle: Fixed value for state feature 4.
        ang_vel: Fixed value for state feature 5.
        left_leg: Fixed value for state feature 6 (0 or 1 from the slider).
        right_leg: Fixed value for state feature 7 (0 or 1 from the slider).

    Returns:
        None. The function prints a colour legend and draws on the current
        matplotlib figure.

    Raises:
        ValueError: If the discriminator does not return exactly three values
            per state, since the outputs are rendered as RGB.
    """
    other_features = torch.tensor([lin_vel_x, lin_vel_y, angle, ang_vel, left_leg, right_leg])
    # The end point is nudged to 1.51 so that 1.5 itself is included in the grid.
    xs = torch.arange(-1.5, 1.51, 0.01)
    ys = torch.arange(-1.5, 1.51, 0.01)

    # 'ij' is the layout the original call relied on implicitly; naming it
    # avoids the deprecation warning about the default changing.
    X, Y = torch.meshgrid(xs, ys, indexing='ij')
    grid = torch.stack((X, Y), dim=2).view(-1, 2)
    points = torch.cat([grid, other_features.repeat(grid.size(0), 1)], dim=1)

    # embedding_fn is applied one state at a time; whether it also accepts a
    # batch is not visible from this notebook, so the per-row call is kept.
    embeddings = torch.stack([embedding_fn(point) for point in points], dim=0)

    # Send the batch to wherever the discriminator's weights live rather than
    # assuming a particular device string.
    device = next(discriminator.parameters()).device

    # Pure inference: no_grad avoids recording an autograd graph for every grid
    # point, and calling the module (not .forward) keeps module hooks working.
    with torch.no_grad():
        skill_probs = discriminator(embeddings.to(device)).cpu().numpy()

    n_skills = skill_probs.shape[-1]
    if n_skills != 3:
        raise ValueError(
            f"Expected 3 discriminator outputs per state to render as RGB, got {n_skills}"
        )
    # NOTE(review): the outputs are used directly as colours, which presumably
    # means they are probabilities in [0, 1] — confirm against Discriminator.
    skill_probs = skill_probs.reshape((X.shape[0], X.shape[1], 3))

    print('Red = skill 0, Green = skill 1, Blue = skill 2')
    plt.pcolormesh(X, Y, skill_probs, shading='gouraud')
    plt.xlabel('x')
    plt.ylabel('y')

# Wire the sliders to `update`; the plot is redrawn each time a slider is
# released. interact() returns the wrapped function, so the trailing semicolon
# stops the notebook from echoing its repr below the widget.
interact(
    update,
    lin_vel_x=lin_vel_x,
    lin_vel_y=lin_vel_y,
    angle=angle,
    ang_vel=ang_vel,
    left_leg=left_leg,
    right_leg=right_leg,
);

{'action_size': 4, 'batch_size': 256, 'buffer_size': 10000, 'discrim_lr': 0.0001, 'discrim_momentum': 0.99, 'env_name': 'LunarLander-v2', 'episodes': 5000, 'eps_decay': 0.995, 'eps_end': 0.01, 'eps_start': 1.0, 'gamma': 0.95, 'max_steps_per_episode': 300, 'policy_lr': 0.001, 'skill_size': 3, 'state_size': 8, 'tau': 0.01, 'update_every': 1, 'exp_type': 'mlp', 'discrim_units': 512, 'policy_units': 512, 'embedding_type': 'naive', 'embedding_size': 3}


interactive(children=(FloatSlider(value=0.0, continuous_update=False, description='Linear Velocity (x):', max=…

<function __main__.update(lin_vel_x, lin_vel_y, angle, ang_vel, left_leg, right_leg)>