In [3]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import yaml
from pathlib import Path
import os
import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import display
from ml_collections import ConfigDict

%cd /nas/ucb/arvindrajaraman/language-skills/diayn

from models import Discriminator
from lunarlander.captioner import naive_captioner

/nas/ucb/arvindrajaraman/language-skills/diayn


In [7]:
# Choose run and config
#
# `run_name` and `iteration` pick the checkpoint file under ./data, and
# `config_name` picks the YAML file under ./config. Both paths are relative,
# so they resolve against the current working directory.

run_name = 'LunarLander-v2_mlp_3_naive_1710264325'
config_name = 'lunarlander_s3.yml'
iteration = 220000

config = ConfigDict(yaml.safe_load((Path('config') / config_name).read_text()))

# Select the function that turns raw feature rows into discriminator inputs.
if config.embedding_type == 'identity':
    def embedding_fn(x):
        """Return the input features unchanged."""
        return x
elif config.embedding_type == 'naive':
    embedding_fn = naive_captioner
else:
    raise ValueError(f"Invalid embedding type: {config.embedding_type}")

# Load discrim

discrim = Discriminator(
    embedding_size=config.embedding_size,
    skill_size=config.skill_size,
    fc1_units=config.discrim_units,
    fc2_units=config.discrim_units,
)
discrim.to('cuda:0')
# map_location pins the deserialised tensors to the device the model lives on;
# without it torch.load tries to restore them onto whatever device they were
# saved from, which fails if that device does not exist on this machine.
discrim.load_state_dict(
    torch.load(f'./data/{run_name}/discrim_iter{iteration}.pth', map_location='cuda:0')
)
# Inference only: put the module in evaluation mode.
discrim.eval()

# One slider per feature that is held fixed while the plot sweeps the grid.
# Display options common to all six sliders live in a single dict so that each
# definition below only states what is specific to it (range, step, label).
_slider_style = dict(continuous_update=False, orientation='horizontal',
                     readout=True, readout_format='.1f')

lin_vel_x = widgets.FloatSlider(description='Linear Velocity (x):', min=-5., max=5., step=0.1, **_slider_style)
lin_vel_y = widgets.FloatSlider(description='Linear Velocity (y):', min=-5., max=5., step=0.1, **_slider_style)
angle = widgets.FloatSlider(description='Angle:', min=-3.14, max=3.14, step=0.1, **_slider_style)
ang_vel = widgets.FloatSlider(description='Angular Velocity:', min=-5., max=5., step=0.1, **_slider_style)
# The leg-contact sliders step by 1.0 over [0, 1], so they only take 0.0 or 1.0.
left_leg = widgets.FloatSlider(description='Left Leg on Ground?', min=0., max=1., step=1.0, **_slider_style)
right_leg = widgets.FloatSlider(description='Right Leg on Ground?', min=0., max=1., step=1.0, **_slider_style)

def update(lin_vel_x, lin_vel_y, angle, ang_vel, left_leg, right_leg):
    """Plot the discriminator's output over an (x, y) grid as an RGB image.

    A regular grid is built with x in [-1, 1] and y in [0, 1] (step 0.005).
    Every grid point is extended with the six values passed in, giving one
    8-feature row per point. The rows are embedded with ``embedding_fn``, run
    through ``discrim``, and the three output channels are drawn as the red,
    green and blue components of a ``pcolormesh`` image.

    Args:
        lin_vel_x, lin_vel_y, angle, ang_vel, left_leg, right_leg: scalar
            values appended, in this order, after the (x, y) grid coordinates.

    Returns:
        None. The figure is shown as a side effect.
    """
    other_features = torch.tensor([lin_vel_x, lin_vel_y, angle, ang_vel, left_leg, right_leg])
    # The stop value is one step past the intended upper bound so that the
    # bound itself (1.0) is included in the grid.
    xs = torch.arange(-1.0, 1.005, 0.005)
    ys = torch.arange(0.0, 1.005, 0.005)

    # 'ij' is the layout the code below relies on (X varies along dim 0);
    # stating it explicitly avoids torch's deprecation warning on each redraw.
    X, Y = torch.meshgrid(xs, ys, indexing='ij')
    points = torch.stack((X, Y), dim=2).view(-1, 2)
    points = torch.cat([points, other_features.repeat(points.size(0), 1)], dim=1)

    # as_tensor accepts both a NumPy array and a tensor, so this also works
    # when embedding_fn is the identity and hands the torch tensor back
    # (torch.from_numpy would raise TypeError in that case).
    embeddings = torch.as_tensor(embedding_fn(points)).to('cuda:0').float()

    # Pure inference: skip autograd bookkeeping and call the module itself
    # rather than .forward so that any registered hooks still run.
    with torch.no_grad():
        skill_probs = discrim(embeddings).cpu().numpy()
    # Three channels are required here: they are rendered as R, G, B.
    # NOTE(review): assumes the discriminator outputs values in [0, 1]
    # (e.g. probabilities) -- confirm against models.Discriminator.
    skill_probs = skill_probs.reshape((X.shape[0], X.shape[1], 3))
    plt.pcolormesh(X, Y, skill_probs, shading='gouraud')
    plt.xlabel('x')
    plt.ylabel('y')

    # Add legend with red, green, blue
    plt.legend(handles=[mpatches.Patch(color='#ff0000', label='Skill 0'),
                       mpatches.Patch(color='#00ff00', label='Skill 1'),
                       mpatches.Patch(color='#0000ff', label='Skill 2')])

    plt.show()

# Wire the sliders to `update`; the plot is redrawn whenever a slider changes.
# The trailing semicolon keeps the notebook from echoing interact's return
# value (the `update` function object) as the cell output.
interact(update, lin_vel_x=lin_vel_x, lin_vel_y=lin_vel_y, angle=angle, ang_vel=ang_vel, left_leg=left_leg, right_leg=right_leg);

interactive(children=(FloatSlider(value=0.0, continuous_update=False, description='Linear Velocity (x):', max=…

<function __main__.update(lin_vel_x, lin_vel_y, angle, ang_vel, left_leg, right_leg)>