# Setup

In [None]:
# Imports
import collections
import os
import os.path as osp

import joblib
import numpy as np
import gym
import pandas as pd
import matplotlib.pyplot as plt

import pirl
from pirl.experiments import config, experiments, plots as myplots

In [None]:
# Config
# Name of the experiment run to analyse; it is joined onto 'data' to locate the run's files.
experiment = 'few-jungle-20180508_170607-0f47c1243e034ef403d65d8db95c256c3ed75f3b'
experiment_dir = osp.join('data', experiment)
# Root directory that the savefig calls further down in this notebook write into.
figs_dir = osp.join('../population-irl-paper', 'figs')

# Value difference

In [None]:
def plot_value(experiment_dir, algo_pattern='(.*)', env_pattern='(.*)', algos=('.*',), dps=2):
    '''Load an experiment's results and return a tidied table of values.

       Despite the name, nothing is plotted: the function returns a DataFrame.

       Args:
           experiment_dir: directory containing 'results.pkl'.
           algo_pattern: regex with one capture group; each column label is
               replaced by the text the group captures.
           env_pattern: regex with one capture group; applied the same way to
               the labels of the first index level.
           algos: iterable of regexes selecting which columns to keep. Columns
               come out in the order of the first pattern they match (matching
               is anchored at the start of the label); columns matching no
               pattern are dropped.
           dps: number of decimal places the values are rounded to.

       Returns:
           The table produced by myplots.extract_value, relabelled, filtered and
           rounded. Underscores in column labels are replaced by spaces.

       NOTE(review): presumably columns are algorithm names and the first index
       level holds environment names -- confirm against myplots.extract_value.'''
    fname = osp.join(experiment_dir, 'results.pkl')
    data = pd.read_pickle(fname)

    value = myplots.extract_value(data)
    value.columns = value.columns.str.extract(algo_pattern, expand=False)
    envs = value.index.levels[0].str.extract(env_pattern, expand=False)
    value.index = value.index.set_levels(envs, level=0)

    # Collect columns pattern by pattern; `mask` remembers what has already been
    # taken so a column matched by several patterns is only selected once.
    matches = []
    mask = pd.Series(False, index=value.columns)
    for p in algos:
        m = value.columns.str.match(p)
        matches += list(value.columns[m & (~mask)])
        mask |= m
    value = value.loc[:, matches]

    value.columns = value.columns.str.split('_').str.join(' ')  # so lines wrap
    value = value.round(dps)
    return value

def plot_ci(df):
    '''Format mean/standard-error rows as 'mean +/- half-width' strings.

       Args:
           df: DataFrame whose row MultiIndex is selected on its fifth level for
               the labels 'mean' and 'se', and which has a level named 'type'
               (presumably that same fifth level -- confirm with the caller).

       Returns:
           A DataFrame of strings with the 'type' level removed. Each cell reads
           '<mean> +/- <1.96 * se>', both numbers printed to three decimals.'''
    everything = slice(None)

    def _rows_of(kind):
        # Pick the rows whose fifth index level equals `kind`, then drop 'type'.
        rows = df.loc[(everything, everything, everything, everything, kind), :]
        rows.index = rows.index.droplevel('type')
        return rows

    def _format_cells(frame, fmt):
        # Column-wise Series.map is used instead of DataFrame.applymap, which is
        # deprecated in recent pandas releases; the result is the same.
        return frame.apply(lambda col: col.map(fmt))

    mean = _format_cells(_rows_of('mean'), lambda x: '{:.3f} +/- '.format(x))
    se = _format_cells(_rows_of('se'), lambda x: '{:.3f}'.format(1.96 * x))
    # Element-wise string concatenation of the two aligned tables.
    return mean + se

In [None]:
# Identity pattern: the whole column label is captured, so labels are kept as they are.
algo_pattern = '(.*)'
# Captures the segment just before the trailing '-v0',
# e.g. 'Soda' from 'pirl/GridWorld-Jungle-9x9-Soda-v0'.
env_pattern = '.*-(.*)-v0'
df = plot_value(experiment_dir, algo_pattern, env_pattern)
# Bare expression so the notebook displays the table.
df

# Policy rollout

In [None]:
def expert_cached_value(rl, env_name, pol_discount=0.99, eval_discount=1.00, seed=1234, episodes=100):
    '''Rollout a cached expert policy for episodes.
       WARNING: This will be slow or just break if policy is not in cache!

       Args:
           rl: key into config.RL_ALGORITHMS.
           env_name: gym environment id to create and evaluate in.
           pol_discount: discount passed to experiments._train_policy.
           eval_discount: discount passed to the algorithm's compute_value.
           seed: seed passed to both _train_policy and compute_value.
           episodes: number of rollout episodes.

       Returns:
           ((vmean, vse), (rmean, rse)): the value pair returned alongside the
           policy, and the pair from the fresh rollout. Both are also printed
           as 'mean +/- 1.96 * se'.'''
    _gen_policy, _sample, compute_value = config.RL_ALGORITHMS[rl]
    policy, value = experiments._train_policy(rl, pol_discount, env_name, seed, None)
    vmean, vse = value
    print('Cached value: {:.3f} +/- {:.3f}'.format(vmean, 1.96 * vse))

    env = gym.make(env_name)
    try:
        rmean, rse = compute_value(env, policy, eval_discount, num_episodes=episodes, seed=seed)
    finally:
        # Release the environment even if the rollout fails.
        env.close()
    print('Rollout value: {:.3f} +/- {:.3f}'.format(rmean, 1.96 * rse))
    return (vmean, vse), (rmean, rse)

def _policy_value(results_dir, rl, env_name, pol_discount, eval_discount, episodes, seed):
    '''Load 'policy.pkl' from results_dir and evaluate it by rollout.

       Args:
           results_dir: directory containing 'policy.pkl'.
           rl: key into config.RL_ALGORITHMS; its compute_value does the rollout.
           env_name: gym environment id to create and evaluate in.
           pol_discount: accepted for signature parity with the callers; unused here.
           eval_discount: discount passed to compute_value.
           episodes: number of rollout episodes.
           seed: seed passed to compute_value.

       Returns:
           (mean, se) as returned by compute_value; also printed as
           'mean +/- 1.96 * se'.'''
    _gen_policy, _sample, compute_value = config.RL_ALGORITHMS[rl]
    fname = osp.join(results_dir, 'policy.pkl')
    print('Loading policy from ', fname)
    policy = joblib.load(fname)
    env = gym.make(env_name)

    try:
        mean, se = compute_value(env, policy, eval_discount, num_episodes=episodes, seed=seed)
    finally:
        # Release the environment even if the rollout fails.
        env.close()
    print('Rollout value: {:.3f} +/- {:.3f}'.format(mean, 1.96 * se))
    return mean, se

def expert_value(experiment_dir, rl, env_name, pol_discount=0.99, eval_discount=1.00, episodes=100, seed=1234):
    '''Evaluate the expert policy stored under <experiment_dir>/expert/<env_name>/<rl>.

       Returns whatever _policy_value returns for that directory.'''
    policy_dir = osp.join(experiment_dir, 'expert', env_name, rl)
    return _policy_value(
        results_dir=policy_dir,
        rl=rl,
        env_name=env_name,
        pol_discount=pol_discount,
        eval_discount=eval_discount,
        episodes=episodes,
        seed=seed,
    )

def irl_eval_value(experiment_dir, irl_name, num_traj, rl, env_name, pol_discount=0.99, eval_discount=1.00, episodes=100, seed=1234):
    '''Evaluate the policy stored under
       <experiment_dir>/eval/<env_name>/<irl_name>:<num_traj>:<num_traj>/<rl>.

       Returns whatever _policy_value returns for that directory.'''
    # num_traj deliberately appears twice in the directory label.
    run_label = '{}:{}:{}'.format(irl_name, num_traj, num_traj)
    policy_dir = osp.join(experiment_dir, 'eval', env_name, run_label, rl)
    return _policy_value(
        results_dir=policy_dir,
        rl=rl,
        env_name=env_name,
        pol_discount=pol_discount,
        eval_discount=eval_discount,
        episodes=episodes,
        seed=seed,
    )
    
def irl_value(experiment_dir, irl_name, env_name, num_traj, eval_discount=1.00, episodes=100):
    '''Evaluate by rollout the policy saved by an IRL algorithm.

       Two on-disk layouts are tried, in this order:
         <experiment_dir>/irl/<irl_name>/<num_traj>/policies.pkl
             a mapping indexed by env_name;
         <experiment_dir>/irl/<irl_name>/<env_name>/<num_traj>/policy.pkl
             a single policy.

       Args:
           experiment_dir: root directory of the experiment.
           irl_name: name passed to experiments.make_irl_algo and used as the
               results sub-directory.
           env_name: gym environment id to evaluate in.
           num_traj: sub-directory name selecting the run (converted with str).
           eval_discount: discount passed to compute_value.
           episodes: number of rollout episodes.

       Returns:
           (mean, se) as returned by compute_value, matching the other
           *_value helpers; the result is also printed.

       Raises:
           FileNotFoundError: if the IRL directory or both policy files are missing.'''
    _irl_algo, _reward_wrapper, compute_value = experiments.make_irl_algo(irl_name)
    irl_dir = osp.join(experiment_dir, 'irl', irl_name)
    if not osp.exists(irl_dir):
        raise FileNotFoundError("No result directory {}".format(irl_dir))

    pop_fname = osp.join(irl_dir, str(num_traj), 'policies.pkl')
    sin_fname = osp.join(irl_dir, env_name, str(num_traj), 'policy.pkl')
    if osp.exists(pop_fname):
        policies = joblib.load(pop_fname)
        print(policies.keys())
        policy = policies[env_name]
    elif osp.exists(sin_fname):
        policy = joblib.load(sin_fname)
    else:
        raise FileNotFoundError("Neither {} or {} exists".format(pop_fname, sin_fname))

    env = gym.make(env_name)
    try:
        mean, se = compute_value(env, policy, discount=eval_discount, num_episodes=episodes)
    finally:
        # Release the environment even if the rollout fails.
        env.close()
    print('Rollout value: {} +/- {}'.format(mean, 1.96 * se))
    return mean, se

In [None]:
# Roll out the cached 'ppo_cts' expert on Reacher-v2 for 100 episodes.
expert_cached_value('ppo_cts', 'Reacher-v2', episodes=100)

In [None]:
# Roll out the expert policy saved in this experiment's directory for 500 episodes.
expert_value(experiment_dir, 'ppo_cts', 'Reacher-v2', episodes=500, seed=1234)

In [None]:
# Roll out the policy saved by 'airl' for Reacher-v2 (num_traj=1000) for 100 episodes.
irl_value(experiment_dir, 'airl', 'Reacher-v2', 1000, episodes=100)

# Visualizing rewards (gridworld only)

In [None]:
def show_heatmaps(irl_algo, kind='inline', out_dir=None, shape=(9,9), **kwargs):
    '''Render the inferred rewards of one IRL algorithm as gridworld heatmaps.

       Reads 'results.pkl' from the notebook-level `experiment_dir`.

       Args:
           irl_algo: key into the 'rewards' entry of the results.
           kind: 'inline' or 'pdf' build static figures; 'movie' hands the
               rewards to myplots.gridworld_heatmap_movie.
           out_dir: for static figures, None displays them in the notebook and
               anything else is passed to myplots.save_figs; for 'movie' it is
               passed to gridworld_heatmap_movie.
           shape: grid shape forwarded to the plotting helpers.
           **kwargs: accepted but currently ignored.

       Raises:
           ValueError: if kind is not one of 'inline', 'pdf' or 'movie'.'''
    # Validate up front with a real exception: an `assert` disappears under
    # `python -O`, and there is no point loading the pickle for a bad `kind`.
    if kind not in ('inline', 'pdf', 'movie'):
        raise ValueError(
            "Unknown kind {!r}; expected 'inline', 'pdf' or 'movie'".format(kind))

    data = pd.read_pickle(osp.join(experiment_dir, 'results.pkl'))
    rewards = data['rewards'][irl_algo]

    if kind == 'movie':
        myplots.gridworld_heatmap_movie(out_dir, rewards, shape)
        return

    figs = myplots.gridworld_heatmap(rewards, shape)
    if out_dir is None:
        for fig in figs:
            # `display` is the IPython built-in; the second item of each
            # entry is what gets shown.
            display(fig[1])
    else:
        myplots.save_figs(figs, out_dir)

In [None]:
# IRL algorithms whose reward heatmaps are exported, one output directory per algorithm.
irl_algos = ['mce', 'mcec', 'mcep_reg1e0', 'mcep_reg1e-1', 'mcep_reg1e-2']
for irl in irl_algos:
    # An out_dir is given, so the figures are saved rather than displayed inline.
    show_heatmaps(irl, kind='pdf', out_dir='figs/few-jungle/' + irl, shape=(9,9))
    # Alternative movie exports, kept for reference:
    #show_heatmaps(irl, kind='movie', out_dir='figs/jungle/movies/' + irl)
    #show_heatmaps(irl, kind='movie', out_dir='figs/jungle/' + irl)

# Jungle experiments

In [None]:
# Short label -> suffix used in the environment id, in display order.
jungle_types = collections.OrderedDict(
    zip(('A', 'B', 'A+B'), ('Soda', 'Water', 'Liquid'))
)

# Short label -> full gym environment id, same order as jungle_types.
jungle_envs = collections.OrderedDict(
    (label, 'pirl/GridWorld-Jungle-9x9-{}-v0'.format(suffix))
    for label, suffix in jungle_types.items()
)

# Column label (underscores already turned into spaces) -> name shown in figures and tables.
default_algos = collections.OrderedDict()
default_algos['mce'] = 'Single'
default_algos['mcec'] = 'Concat'
default_algos['mcep reg1e-1'] = 'Multi-task'
default_algos['value iteration'] = 'Oracle'

## Ground truth

In [None]:
# Draw the 9x9 gridworld cartoon under the 'default' + 'twocol' styles and save it as gt.pdf.
with plt.style.context([myplots.style('default'), myplots.style('twocol')]):
    fig = myplots.gridworld_cartoon((9,9))
    fig.savefig(osp.join(figs_dir, 'jungle', 'gt.pdf'))
    # Close the figure once it has been written to disk.
    plt.close(fig)

## Reoptimized policy value

In [None]:
# Relies on algo_pattern and env_pattern having been defined by an earlier cell.
values = plot_value(experiment_dir, algo_pattern, env_pattern)
# Bare expression so the notebook lists the available column labels.
values.columns

## Extract data

In [None]:
# Identity pattern: column labels are kept as they are.
algo_pattern = '(.*)'
# Captures the segment just before the trailing '-v0' of each environment id.
env_pattern = '.*-(.*)-v0'

values = plot_value(experiment_dir, algo_pattern, env_pattern)
# Swap the captured names ('Soda', 'Water', 'Liquid') for the short labels 'A', 'B', 'A+B'.
values = values.rename(index={v: k for k, v in jungle_types.items()}, level=0)
# Keep only the cross-section with n == 1000 and eval == 'value_iteration'.
values = values.xs((1000, 'value_iteration'), level=('n', 'eval'))
# Bare expression so the notebook displays the table.
values

## Figures

In [None]:
# Rename the columns to their display names and keep just those, in default_algos order.
baseline_comparison = values.rename(columns=default_algos).loc[:, default_algos.values()]
with plt.style.context([myplots.style('default'), myplots.style('twocol')]):
    # relative='Oracle' presumably plots values relative to the 'Oracle' column -- confirm in myplots.
    fig = myplots.value_bar_chart_by_env(baseline_comparison, envs=jungle_types.keys(), relative='Oracle')
    fig.savefig(osp.join(figs_dir, 'jungle', 'baseline_comparison.pdf'), bbox_inches='tight')

In [None]:
# Column label -> display name for the regularisation sweep. The labels are raw
# strings because '\l' is not a valid escape sequence: in a plain string literal
# it triggers an invalid-escape warning, while the resulting text is the same.
reg_algorithms = collections.OrderedDict([
    ('mcep reg0', r'$\lambda = 0$'),
    ('mcep reg1e-2', r'$\lambda = 10^{-2}$'),
    ('mcep reg1e-1', r'$\lambda = 10^{-1}$'),
    ('mcep reg1e0', r'$\lambda = 10^0$'),
    ('value iteration', 'Oracle'),
])
# Rename the columns to their display names and keep just those, in reg_algorithms order.
reg_comparison = values.rename(columns=reg_algorithms).loc[:, reg_algorithms.values()]
with plt.style.context([myplots.style('default'), myplots.style('twocol')]):
    fig = myplots.value_bar_chart_by_env(reg_comparison, envs=jungle_types.keys(), relative='Oracle')
    fig.savefig(osp.join(figs_dir, 'jungle', 'reg_comparison.pdf'), bbox_inches='tight')

## Tables

In [None]:
# Same column selection and renaming as the baseline comparison figure.
table_values = values.rename(columns=default_algos).loc[:, default_algos.values()]
# The environment labels are spelled out here; they match the keys of jungle_types.
print(myplots.value_latex_table(table_values, envs=['A', 'B', 'A+B'], relative='Oracle'))

# Ad-hoc experiments

In [None]:
# Load the raw results once so the ad-hoc cells below can index into them directly.
fname = osp.join(experiment_dir, 'results.pkl')
data = pd.read_pickle(fname)

In [None]:
env_name = 'pirl/GridWorld-Jungle-9x9-Soda-v0'
algo = 'mce'
fig, axs = plt.subplots(1, 2, figsize=(13, 6))
# Earlier comparison, kept for reference (it drew on the left axis):
# zeroshot = data['rewards'][algo][1000][0][env_name]
# oneshot = data['rewards'][algo][1000][1][env_name]
# myplots._gridworld_heatmap(oneshot - zeroshot, (9,9), fmt='.2f', ax=axs[0])
fiveshot = data['rewards'][algo][1000][5][env_name]
hundredshot = data['rewards'][algo][1000][100][env_name]
# Compare against `fiveshot`: the previous baseline `zeroshot` is only defined in
# the commented-out lines above, so referencing it here raised NameError.
myplots._gridworld_heatmap(hundredshot - fiveshot, (9,9), fmt='.2f', ax=axs[1])

# Fetch the environment's reward to compare the inferred rewards against.
env = gym.make(env_name)
gt = env.unwrapped.reward
env.close()
r = data['rewards'][algo][1000]
# Norm of the difference between each inferred reward and the environment's
# reward, after subtracting each one's own mean.
reward_delta = {k: np.linalg.norm(v[env_name] - np.mean(v[env_name]) - gt + np.mean(gt))
                for k, v in r.items()}
# Bare expression so the notebook displays the dictionary.
reward_delta