# Setup code

In [None]:
%matplotlib inline
import collections
import copy
import html
import os
import re
import shutil
import json
import glob
from functools import partial
from numpy import trapz
import csv
import pandas as pd

from datetime import datetime

from matplotlib import pyplot as plt 
import seaborn as sns 
from math import sqrt 
import pandas as pd 
import math

from IPython.core.display import display, HTML
import numpy as np 
import pandas as pd

sns.set()

# FIXME(sam): make cluster subpath selectable with a dropdown. Also, automatically infer
# gfs_mount from the hostname (have sensible default for svm/perceptron)

# this identifies data for a particular cluster on the GFS volume
# cluster_subpath = "with-resnet-128-2020-01-27/"
# cluster_subpath = "cluster-data/cluster-2021-01-29-set3-try4/"
# on svm I think gfs_mount is /scratch/sam/repl-vol/ or something like that
# gfs_mount = "/scratch/cynthiachen/"
# gfs_mount = "/scratch/sam/il-representations-gcp-volume/cluster-data/cluster-2021-09-23-sam-new-vis-gail-v2/"  # Google Filestore mount point (local)
# gfs_mount = "/scratch/sam/ilr-gail-procgen-runs-2021-09-29/"
gfs_mount = "/scratch/sam/il-representations-gcp-volume/cluster-data/cluster-2021-09-29-sam-new-vis-repl-ablations/"  # Google Filestore mount point (local)
# gfs_mount = "/home/sam/repos/il-representations/runs/"
# gfs_mount = "/scratch/sam/il-rep-quals-runs-2021-04-17/"
cluster_subpath = "./"
runs_directory = os.path.join(gfs_mount, cluster_subpath)
path_translations = {
    # when loading things like 'encoder_path' and 'policy_path' from configs,
    # replace the thing on the left with the thing on the right
    "/data/il-representations/": gfs_mount,
    "/root/il-rep/runs/": os.path.join(gfs_mount, cluster_subpath),
    # "/home/sam/repos/il-representations/cloud/runs/": os.path.join(gfs_mount, cluster_subpath),
    "/home/sam/repos/il-representations/runs/": gfs_mount,
}

# All runs are expected under <runs_directory>/chain_runs. Fail early with an
# actionable message (rather than a bare AssertionError, which is also
# stripped under `python -O`) if the mount point / subpath above is wrong.
exp_dir = os.path.join(runs_directory, 'chain_runs')
if not os.path.exists(exp_dir):
    raise FileNotFoundError(
        f"No 'chain_runs' directory found at {exp_dir!r}; check that "
        "gfs_mount and cluster_subpath point at a cluster-data directory.")

# Preparing a table of results

In [None]:
def get_parent_relpath(sample_parent_file, local_root_dir, translations=None):
    """Get root-relative path to a 'parent' directory, such as the directory
    containing a saved encoder or policy. This is somewhat tricky because
    we need to replace paths that might be different on svm/perceptron or on
    a laptop compared to what they were on GCP. e.g. inside the Ray docker
    container, '/root/il-rep/runs' maps to 'cluster-data/' in the GFS volume.
    The `path_translations` variable handles all the necessary changes.

    Args:
        sample_parent_file: a file path as recorded in a run config (e.g. the
            value of 'encoder_path' or 'policy_path').
        local_root_dir: directory that the returned path is relative to.
        translations: optional mapping of remote path prefix -> local
            replacement. Defaults to the module-level `path_translations`.

    Returns:
        The directory containing `sample_parent_file` (after prefix
        translation), expressed relative to `local_root_dir`.
    """
    if translations is None:
        translations = path_translations
    for prefix, replacement in translations.items():
        if sample_parent_file.startswith(prefix):
            sample_parent_file = replacement + sample_parent_file[len(prefix):]
            # Translate at most once: without this, a replacement that happens
            # to start with another prefix would be rewritten a second time.
            break

    full_dir = os.path.dirname(os.path.abspath(sample_parent_file))
    return os.path.relpath(full_dir, local_root_dir)

class SubexperimentRun:
    """A SubexperimentRun associates all the information associated
    with a run of a particular Sacred sub-experiment. That means a
    run (single execution) of the 'repl', 'il_train', or 'il_test'
    experiments.

    Attributes:
        ident: path of the run directory relative to `experiment_dir_root`;
            used as the unique identifier (equality and hashing use only it).
        subexp_dir, experiment_dir_root: absolute versions of the
            constructor arguments.
        mode: name of the run directory's parent: 'repl', 'il_train' or
            'il_test'.
        config: parsed contents of the run's config.json.
        progress_path: path to the run's progress.csv, or None if absent.
        eval_json_path: path to the run's eval.json, or None if absent.
        parent_ident: ident of the run this one was derived from (the run
            whose encoder an il_train run loaded, or whose policy an il_test
            run evaluated), or None.
    """
    def __init__(self, subexp_dir, experiment_dir_root):
        # Subexperiment dir is used as a unique identifier.
        # We strip out the leading 'experiment_dir_root' to shorten identifiers.
        subexp_dir = os.path.abspath(subexp_dir)
        experiment_dir_root = os.path.abspath(experiment_dir_root)
        self.ident = os.path.relpath(subexp_dir, experiment_dir_root)
        self.subexp_dir = subexp_dir
        self.experiment_dir_root = experiment_dir_root

        # usually paths are like, e.g., 'chain_runs/repl/42' or
        # 'chain_runs/il_train/13'; the name of the parent directory is the mode
        self.mode = os.path.basename(os.path.dirname(subexp_dir))
        assert self.mode in {'repl', 'il_train', 'il_test'}, (self.mode, subexp_dir)

        # Load experiment config
        config_path = os.path.join(subexp_dir, 'config.json')
        with open(config_path, 'r') as fp:
            self.config = json.load(fp)

        # Paths to the progress.csv (only il_train/repl write one) and
        # eval.json files, if they exist.
        progress_path = os.path.join(subexp_dir, "progress.csv")
        self.progress_path = progress_path if os.path.exists(progress_path) else None
        eval_json_path = os.path.join(subexp_dir, "eval.json")
        self.eval_json_path = eval_json_path if os.path.exists(eval_json_path) else None

        # Infer the .ident attribute for the parent experiment (if it exists)
        encoder_path = self.config.get('encoder_path')
        if self.mode == 'il_train' and encoder_path is not None:
            encoder_relpath = get_parent_relpath(encoder_path, experiment_dir_root)
            # The relpath is going to be something like
            # "chain_runs/10/repl/5/checkpoints"; dropping the last two
            # components gives the repl run's ident. dirname() uses the same
            # separator as os.path.relpath, so this also works on Windows.
            self.parent_ident = os.path.dirname(os.path.dirname(encoder_relpath))
        elif self.mode == 'il_test':
            policy_relpath = get_parent_relpath(
                self.config['policy_path'], experiment_dir_root)
            # policies saved under <il_train run>/snapshots/ belong to that run
            head, tail = os.path.split(policy_relpath)
            if tail == 'snapshots' and head:
                policy_relpath = head
            self.parent_ident = policy_relpath
        else:
            # "repl" runs and "il_train" runs without an encoder_path
            # have no parents (the mode check above rules out anything else)
            self.parent_ident = None

        # HACK: adding a use_repl key so that we can see whether il_train runs used repL
        if self.mode == 'il_train':
            self.config['use_repl'] = self.parent_ident is not None

    def get_merged_config(self, index):
        """Get a 'merged' config dictionary for this subexperiment and
        all of its parents. The dict will have a format like this:

        {"env_cfg": {…}, "il_train": {…}, "il_test": {…}, "repl": {…}}

        Note that some keys might not be present (e.g. if this is a `repl` run,
        it will not have the `il_train` key; if this is an `il_train` run with
        no parent, then the `repl` key will be absent). Parent configs are
        merged with dict.update(), so for shared top-level keys (such as
        'env_cfg') the parent's value replaces this run's value."""
        config = {self.mode: dict(self.config)}
        extract_keys = ('env_cfg', 'venv_opts', 'env_data')
        for extract_key in extract_keys:
            if extract_key in config[self.mode]:
                # move environment-related ingredient configs to the top level,
                # because those ingredient names are shared between experiments
                config[extract_key] = config[self.mode].pop(extract_key)
        parent = self.get_parent(index)
        if parent is not None:
            # TODO: merge this properly, erroring on incompatible duplicate
            # keys. I think Cody has code for this.
            config.update(parent.get_merged_config(index))
        return config

    def get_parent(self, index):
        """Return the parent run from `index`, or None if this run has none."""
        if self.parent_ident is None:
            return None
        return index.get_subexp(self.parent_ident)

    def __repr__(self):
        return f"{type(self).__name__}(ident={self.ident!r})"

    def __hash__(self):
        return hash(self.ident)

    def __eq__(self, other):
        if not isinstance(other, SubexperimentRun):
            return NotImplemented
        return self.ident == other.ident

class SubexperimentIndex:
    """Registry of subexperiment runs keyed by their `ident` attribute.

    Supports lookup by identifier and a simple attribute-equality search.
    """
    def __init__(self):
        # ident -> subexperiment, in insertion order
        self.subexp_by_ident = {}

    def add_subexp(self, subexp):
        """Register `subexp`; raise ValueError if its ident is already taken."""
        key = subexp.ident
        if key not in self.subexp_by_ident:
            self.subexp_by_ident[key] = subexp
            return
        raise ValueError("duplicate subexperiment:", subexp)

    def get_subexp(self, ident):
        """Return the subexperiment registered under `ident` (KeyError if absent)."""
        return self.subexp_by_ident[ident]

    def search(self, **attrs):
        """Return, in insertion order, every subexperiment whose attributes
        equal all of the values given in `attrs`."""
        return [
            candidate
            for candidate in self.subexp_by_ident.values()
            if not any(getattr(candidate, name) != wanted
                       for name, wanted in attrs.items())
        ]

def get_experiment_directories(root_dir, skip_skopt=True):
    """Return sorted absolute paths of all sub-experiment run directories.

    A run directory is one whose path ends in `<mode>/<digits>`, where mode is
    'il_test', 'il_train' or 'repl'. The walk does not descend into run
    directories, into `grid_search` or `_sources` directories, or (when
    `skip_skopt` is true) into any directory whose `grid_search`
    subdirectory contains a file named `search-alg-*`.
    """
    run_dir_re = re.compile(r'^.*/(il_test|il_train|repl)/\d+$')
    skip_dir_re = re.compile(r'^.*/(grid_search|_sources)$')  # ignore the grid_search subdir
    found = set()
    for cur_root, subdirs, _files in os.walk(root_dir, followlinks=True, topdown=True):
        if skip_dir_re.match(cur_root):
            subdirs.clear()
            continue

        # skopt searches leave search-alg-* files inside grid_search/
        if skip_skopt and 'grid_search' in subdirs:
            search_files = os.listdir(os.path.join(cur_root, 'grid_search'))
            if any(name.startswith('search-alg-') for name in search_files):
                print("skipping skopt directory in", cur_root)
                subdirs.clear()
                continue

        candidates = (os.path.abspath(os.path.join(cur_root, sub)) for sub in subdirs)
        matches = [path for path in candidates if run_dir_re.match(path)]
        found.update(matches)
        if matches:
            # pruning in place stops os.walk from recursing further here
            subdirs.clear()
    return sorted(found)

# Build an index of every run directory found under a root directory
def load_all_subexperiments(root_dir, skip_skopt=True):
    """Discover every run directory under `root_dir` and index it.

    Each directory returned by `get_experiment_directories` is loaded as a
    `SubexperimentRun` (identified relative to `root_dir`) and added to a
    fresh `SubexperimentIndex`, which is returned.
    """
    print("Searching for experiment directories (might take a minute or two)")
    run_dirs = get_experiment_directories(root_dir, skip_skopt=skip_skopt)
    print("Loading experiments (might take another minute or two)")
    index = SubexperimentIndex()
    for run in (SubexperimentRun(run_dir, root_dir) for run_dir in run_dirs):
        index.add_subexp(run)
    return index

In [None]:
# Walk runs_directory and index every run found there (slow on large volumes).
subexp_index = load_all_subexperiments(runs_directory, skip_skopt=True)
print('Discovered', len(subexp_index.subexp_by_ident), 'subexperiments')

# Quick sanity check: display the merged config of one il_test run.
# NOTE(review): uses index [1], so this raises IndexError when fewer than two
# il_test runs were discovered.
test_expts = subexp_index.search(mode='il_test')
test_expts[1].get_merged_config(subexp_index)

## Print a table of il_test results

Shows a separate table of il_test results for each benchmark setting, and also stores the underlying numbers in `raw_return_data` so that a later cell can turn them into DataFrames.

In [None]:
def flatten_dict(d):
    """Flatten nested dicts/lists into one level with 'path/like/keys'.

    List elements are keyed by their position as a string ('0', '1', ...).
    Empty nested containers contribute no keys. Raises AssertionError if `d`
    is neither a dict nor a list.
    """
    if isinstance(d, dict):
        pairs = d.items()
    else:
        assert isinstance(d, list), type(d)
        pairs = ((str(position), item) for position, item in enumerate(d))
    flat = {}
    for key, value in pairs:
        if not isinstance(value, (dict, list)):
            flat[key] = value
            continue
        flat.update((f'{key}/{sub_key}', sub_value)
                    for sub_key, sub_value in flatten_dict(value).items())
    return flat

def combine_dicts_multiset(dicts):
    """Count, for every key seen in `dicts`, how often each value occurs.

    Returns a plain dict mapping each key to a `collections.Counter` of the
    values observed for that key across all input dicts.
    """
    counts = {}
    for single in dicts:
        for key, value in single.items():
            counts.setdefault(key, collections.Counter())[value] += 1
    return counts

def remove_inapplicable_keys(flat_dict):
    """Drop keys of a flattened config that cannot affect the run.

    Heuristics (extend as needed):
    - env_cfg/<bench>* and env_data/<bench>* keys are dropped for each of
      'magical' and 'dm_control' unless 'env_cfg/benchmark_name' equals that
      benchmark;
    - all 'repl/' keys are dropped when 'il_train/use_repl' is False.

    Returns a new dict; the input is not modified.
    """
    chosen_benchmark = flat_dict.get('env_cfg/benchmark_name')
    doomed = set()
    for bench in ('magical', 'dm_control'):
        if chosen_benchmark == bench:
            continue
        prefixes = ('env_cfg/' + bench, 'env_data/' + bench)
        doomed.update(key for key in flat_dict if key.startswith(prefixes))

    if flat_dict.get('il_train/use_repl') is False:
        doomed.update(key for key in flat_dict if key.startswith('repl/'))

    return {key: value for key, value in flat_dict.items() if key not in doomed}

def simplify_config_dicts(hierarchical_dicts,
                          base_thresh=0.75,
                          remove_seeds=True,
                          prohibited_base_keys=('env_cfg/task_name', 'env_cfg/benchmark_name', 'il_test/exp_ident'),
                          force_remove_keys=('il_test/policy_path', 'il_train/encoder_path')):
    """Reduce nested config dicts to the keys that actually distinguish them.

    Steps:
    0. Flatten every dict with `flatten_dict`.
    1. If `remove_seeds`, drop every key whose last path component is 'seed'.
    2. Pad each dict with None for keys only other dicts have, then drop
       inapplicable keys via `remove_inapplicable_keys`.
    3. Drop keys with a single observed value, and keys listed in
       `force_remove_keys`, unless the key is in `prohibited_base_keys`.
    4. For other non-prohibited keys, if one value occurs in more than a
       fraction `base_thresh` of the dicts, record it in a shared base config
       and omit the key from every dict whose value equals it.

    Returns:
        (base_config, simplified_dicts), with simplified_dicts parallel to
        `hierarchical_dicts`.
    """
    flat_dicts = [dict(flatten_dict(cfg)) for cfg in hierarchical_dicts]

    if remove_seeds:
        for cfg in flat_dicts:
            for seed_key in [key for key in cfg if key.split('/')[-1] == 'seed']:
                del cfg[seed_key]

    # pad so that every dict has every key
    all_keys = set()
    for cfg in flat_dicts:
        all_keys |= cfg.keys()
    for cfg in flat_dicts:
        for missing_key in all_keys - cfg.keys():
            cfg[missing_key] = None

    flat_dicts = [remove_inapplicable_keys(cfg) for cfg in flat_dicts]

    threshold = len(flat_dicts) * base_thresh
    base_config = {}
    dropped_keys = set()
    for key, counter in combine_dicts_multiset(flat_dicts).items():
        if key in prohibited_base_keys:
            # always kept verbatim in every output dict
            continue
        if len(counter) == 1 or key in force_remove_keys:
            dropped_keys.add(key)
            continue
        (common_value, common_count), = counter.most_common(1)
        if common_count > threshold:
            base_config[key] = common_value

    simplified = [
        {
            key: value for key, value in cfg.items()
            if key not in dropped_keys
            and not (key in base_config and base_config[key] == value)
        }
        for cfg in flat_dicts
    ]
    return base_config, simplified

Note: the code below can take 2–3 minutes to run.

In [None]:
def escape(obj):
    """Return `str(obj)` with HTML special characters (including quotes) escaped."""
    text = str(obj)
    return html.escape(text)

main_expts = subexp_index.search(mode='il_test')
all_configs = [subexp.get_merged_config(subexp_index) for subexp in main_expts]
base_config, flat_configs = simplify_config_dicts(all_configs)
flat_config_tups = [tuple(sorted(d.items())) for d in flat_configs]

# Group runs by benchmark settings: every surviving env_*/venv_* (key, value).
subexp_by_benchmark = {}
for flat_cfg, subexp in zip(flat_config_tups, main_expts):
    bench_key = tuple((k, v) for k, v in flat_cfg if k.startswith(('env_', 'venv_')))
    subexp_by_benchmark.setdefault(bench_key, []).append((flat_cfg, subexp))

display(HTML('<p><strong>Base config</strong></p>'))
display(HTML('<p>Unless specified otherwise, all config dicts include these keys:</p>'))
print(base_config)

# task_name -> {'row_indexes': [...], 'rows': [...], 'columns': [...]};
# converted into DataFrames by a later cell.
raw_return_data = {}

for idx, (bench_key, cfgs_subexps) in enumerate(subexp_by_benchmark.items(), start=1):
    # print out benchmark details
    display(HTML(f'<p><strong>Results for benchmark config &#35;{idx}</strong></p>'))
    display(HTML('<p>Config:</p>'))
    rows = [f'<tr><th>{escape(key)}</th><td>{escape(value)}</td></tr>' for key, value in bench_key]
    display(HTML(f'<table>{"".join(rows)}</table>'))
    display(HTML('<p>Runs:</p>'))

    # cluster subexperiments by their non-benchmark config
    by_cfg = {}
    for tup_cfg, subexp in cfgs_subexps:
        tup_cfg = tuple(kv for kv in tup_cfg if kv not in bench_key)
        by_cfg.setdefault(tup_cfg, []).append(subexp)

    # load all eval.json files and figure out what columns we need
    stats_dicts = {}
    columns = set()
    for _, subexp in cfgs_subexps:
        if not subexp.eval_json_path:
            stats_dicts[subexp] = {}
            continue
        with open(subexp.eval_json_path, 'r') as fp:
            eval_dict = json.load(fp)
        is_magical = 'full_data' in eval_dict
        is_procgen = 'train_level' in eval_dict
        assert not (is_magical and is_procgen)
        # These branches must be exclusive: previously a plain `if` for
        # procgen let its `else` overwrite the magical per-variant stats.
        if is_magical:
            stats_dict = {
                '-'.join(env_dict['test_env'].split('-')[:2]): env_dict['mean_score']
                for env_dict in eval_dict['full_data']
            }
            stats_dict['Average on all envs'] = eval_dict['return_mean']
        elif is_procgen:
            stats_dict = {
                level_type: eval_dict[level_type]['return_mean']
                for level_type in ['train_level', 'test_level']
            }
        else:
            stats_dict = {'return_mean': eval_dict['return_mean']}
        stats_dicts[subexp] = stats_dict
        columns |= stats_dict.keys()
    columns = sorted(columns)

    # now produce a table with one row per config
    table_parts = ['<table>']                                         # begin table
    table_parts.append('<tr>')                                        # begin header row
    table_parts.append('<th style="border-collapse: collapse;">Config</th>')
    table_parts.extend(f'<th style="border-collapse: collapse;">{html.escape(col_name)}</th>' for col_name in columns)
    table_parts.append('</tr>')                                       # end header row

    row_indexes = []
    rows = []

    for cfg, subexps in sorted(by_cfg.items(), key=lambda cfg_se: dict(cfg_se[0])['il_test/exp_ident']):
        table_parts.append('<tr>')                                    # begin row

        # cell containing config
        if True:  # remove to show full config
            d = dict(cfg)
            exp_ident = d['il_test/exp_ident']
            table_parts.append(f'<td style="border-collapse: collapse;">{escape(exp_ident)}</td>')
        else:
            kv_cfg = ', '.join(f'{key}={value!r}' for key, value in cfg)
            table_parts.append(f'<td style="max-width: 600px; border-collapse: collapse;">{escape(kv_cfg)}</td>')
        row_indexes.append(exp_ident)

        # cells containing mean±std (count) over runs sharing this config
        row_values = []
        for column in columns:
            column_values = [stats_dicts[subexp][column] for subexp in subexps
                             if column in stats_dicts[subexp]]
            if not column_values:
                table_parts.append('<td style="border-collapse: collapse;">-</td>')
                row_values.append({'mean': np.nan, 'std': np.nan, 'n': np.nan})
            else:
                mean = np.mean(column_values)
                std = np.std(column_values)
                n = len(column_values)
                table_parts.append(f'<td style="border-collapse: collapse;">{mean:.3g}±{std:.1g} ({n})</td>')
                row_values.append({'mean': mean, 'std': std, 'n': n})
        rows.append(row_values)

        table_parts.append('</tr>')                                   # end row
    raw_return_data[dict(bench_key)['env_cfg/task_name']] = {'row_indexes': row_indexes, 'rows': rows, 'columns': columns}
    table_parts.append('</table>')                                    # end table
    display(HTML(''.join(table_parts)))

In [None]:
def make_monster_frame(main_expts, all_configs):
    """Build one long-format DataFrame of evaluation returns for all runs.

    Runs without an eval.json are skipped. Every other run contributes one row
    per test environment:
    - magical runs ('full_data' in eval.json): one row per test variant,
      labelled '-<Variant>', plus an 'Average' row;
    - procgen runs ('train_level' in eval.json): 'train_level' and
      'test_level' rows;
    - anything else: a single row labelled with the training task name.

    Args:
        main_expts: runs exposing an `eval_json_path` attribute.
        all_configs: merged config dicts, parallel to `main_expts`.

    Returns:
        DataFrame with columns eval_json_path, exp_ident, train_env,
        test_env, return, is_magical.
    """
    column_names = ('eval_json_path', 'exp_ident', 'train_env', 'test_env', 'return', 'is_magical')
    records = collections.defaultdict(list)
    for expt, cfg in zip(main_expts, all_configs):
        eval_path = expt.eval_json_path
        if not eval_path:
            continue
        with open(eval_path, 'r') as fp:
            eval_dict = json.load(fp)
        train_env = cfg['env_cfg']['task_name']
        exp_ident = cfg['il_train']['exp_ident']
        is_magical = 'full_data' in eval_dict
        is_procgen = not is_magical and 'train_level' in eval_dict

        if is_magical:
            env_returns = [('-' + entry['test_env'].split('-')[1], entry['mean_score'])
                           for entry in eval_dict['full_data']]
            env_returns.append(('Average', eval_dict['return_mean']))
        elif is_procgen:
            env_returns = [(level, eval_dict[level]['return_mean'])
                           for level in ('train_level', 'test_level')]
        else:
            env_returns = [(train_env, eval_dict['return_mean'])]

        for test_env, ret in env_returns:
            row = (eval_path, exp_ident, train_env, test_env, ret, is_magical)
            for column, value in zip(column_names, row):
                records[column].append(value)

    return pd.DataFrame.from_dict(records)

# One row per (il_test run, test environment). Relies on main_expts/all_configs
# from the il_test results cell; note the AUC cell further down rebinds
# all_configs to il_train configs, so run this before that cell.
monster_frame = make_monster_frame(main_expts, all_configs)

In [None]:
# pivot the results so we can insert them here: https://docs.google.com/spreadsheets/d/1bIRKNlLHOeZSsQTkP1oluL1KfJACb-WeTID7Ve-3R_U/edit#gid=0
def return_mean(x):
    """Mean of `x`; the function name becomes the pivot-table column name."""
    return np.mean(x)


def return_std(x):
    """Population standard deviation of `x` (np.std default, ddof=0)."""
    return np.std(x)


def count(x):
    """Number of elements in `x`."""
    return len(x)
# Aggregate 'return' per (exp_ident, train_env, test_env) with mean, std and
# run count. NOTE(review): `count` must still be the aggregator defined above;
# the original AUC cell's loop variable rebinds the global `count` to an int.
return_pivot = monster_frame.pivot_table(index=['exp_ident', 'train_env', 'test_env'], values='return', aggfunc=[return_mean, return_std, count])
# Keep only the first level of each column tuple (the aggregator's name).
return_pivot.columns = [tup[0] for tup in return_pivot.columns.values]
# Written relative to the notebook's current working directory.
return_pivot.to_csv('./results.csv')

In [None]:
class UnrecognisedExpIdent(Exception):
    """Signals that factor_exp_ident could not parse an exp_ident string.

    The single argument is a human-readable message naming the bad exp_ident.
    """

def factor_exp_ident(row):
    """Split `row['exp_ident']` into human-readable IL / repL components.

    Returns a dict with keys 'il_algo', 'repl_algo' and 'repl_data'. Each raw
    value is mapped through a table of display names when one exists and is
    passed through unchanged otherwise; a missing data suffix maps to 'n/a'.

    Raises:
        UnrecognisedExpIdent: if exp_ident does not match the expected
            '<il_algo>[_icml]_<repl_algo>[_<repl_data>]' pattern.
    """
    exp_ident = row['exp_ident']
    pattern = (r'^(?P<il_algo>bc|gail)(_icml)?_'
               r'(?P<repl_algo>dynamics|vae|inv_dyn|identity_cpc|control|tcpc_8step)'
               r'(_(?P<repl_data>.*))?$')
    parsed = re.match(pattern, exp_ident)
    if not parsed:
        raise UnrecognisedExpIdent(f"Could not parse exp_ident='{exp_ident}'")

    display_names = {
        'il_algo': {'bc': 'BC', 'gail': 'GAIL'},
        'repl_algo': {
            'dynamics': 'dynamics',
            'vae': 'VAE',
            'inv_dyn': 'inv. dyn.',
            'identity_cpc': 'CPC',
            'tcpc_8step': 'TCPC-8',
            'control': 'control',
        },
        'repl_data': {
            'rand_demos_magical_mt': 'MT rollouts + MT demos',
            'rand_demos': 'rollouts + demos',
            'rand_demos_test': 'test demos + rollouts',
            None: 'n/a',
        },
    }
    return {
        group: display_names.get(group, {}).get(raw_value, raw_value)
        for group, raw_value in parsed.groupdict().items()
    }

def human_readable_name(row):
    """Build a display name from the columns added by factor_exp_ident.

    Control runs without repL data become '<il_algo> control'; everything
    else becomes '<il_algo> + <repl_algo> w/ <repl_data>'.
    """
    if row['repl_data'] != 'n/a' or row['repl_algo'] != 'control':
        return f"{row['il_algo']} + {row['repl_algo']} w/ {row['repl_data']}"
    return row['il_algo'] + ' control'
    
# Box plots of the 'Average' test-variant return (only magical runs produce
# 'Average' rows in monster_frame), one figure per (train_env, test_env),
# faceted by IL algorithm (rows) and repL algorithm (columns).
allowed_test_variants = ['Average']
allowed_mask = monster_frame['test_env'].isin(allowed_test_variants)
allowed_monster = monster_frame[allowed_mask].sort_values(by='exp_ident')
grouped_monster = allowed_monster.groupby(by=['train_env', 'test_env'])
for (train_env, test_env), frame in grouped_monster:
    try:
        # this will only work for my quals runs from 2021-04-17
        # (factor_exp_ident raises UnrecognisedExpIdent for other naming schemes)
        new_cols = frame.apply(factor_exp_ident, result_type='expand', axis=1)
        frame = pd.concat([frame, new_cols], axis=1)
        frame['human_name'] = frame.apply(human_readable_name, axis=1)
    except UnrecognisedExpIdent:
        # NOTE(review): re-raises unchanged, so this handler is currently a
        # no-op placeholder; presumably meant to skip/report such groups.
        raise
    sns.set_context("talk")
    # using a facetgrid
    g = sns.catplot(data=frame, row='il_algo', col='repl_algo', orient='v',
                    x='repl_data', hue='repl_data', y='return', kind='box',
                    sharex=False, sharey=True, margin_titles=True, height=5, dodge=False, aspect=0.7)
    g.set_xticklabels(rotation=30)
    # Fixed y-range of [0, 1] on the current (shared-y) axes; values outside
    # this range are not visible.
    plt.ylim([0, 1])
    # g.legend.hide()
    plt.suptitle(f"{test_env} return for IL+repL approaches (trained on {train_env})")
    g.fig.tight_layout()
    plt.show()

In [None]:
# Convert raw_return_data (built by the il_test results cell) into DataFrames:
# raw_return_dataframes[task][stat] has one row per il_test exp_ident and one
# column per evaluation statistic, for stat in 'mean', 'std' and 'n'.
raw_return_dataframes = dict()

for k in raw_return_data.keys():
    print(f"Processing task {k}")
    raw_return_dataframes[k] = dict()
    # columns / row labels are the same for every statistic of a task
    columns = raw_return_data[k]['columns']
    index = raw_return_data[k]['row_indexes']
    for stat in ['mean', 'std', 'n']:
        rows = [[el[stat] for el in raw_row] for raw_row in raw_return_data[k]['rows']]
        try:
            raw_return_dataframes[k][stat] = pd.DataFrame(rows, index=index, columns=columns)
        except ValueError as err:
            # Previously a bare `except:` dropped into pdb, which hangs a
            # notebook; surface the failing task/stat instead.
            raise ValueError(
                f"Could not build the {stat!r} DataFrame for task {k!r} "
                f"({len(rows)} rows, {len(index)} row labels, {len(columns)} columns)"
            ) from err

## Print raw loss-AUC data and convert it to DataFrames

In [None]:
# Number of equal-length chunks each loss curve is split into; a cumulative
# AUC is reported at the end of every chunk.
num_split = 6

# Number of initial loss values skipped before integrating.
start_count = 3

def calculate_auc(y, dx=1):
    """Trapezoidal-rule area under the sampled curve `y` with spacing `dx`."""
    # numpy.trapz is deprecated since NumPy 2.0 in favour of numpy.trapezoid;
    # fall back to trapz on older NumPy releases.
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    return integrate(y, dx=dx)
    
train_expts = subexp_index.search(mode='il_train')
all_configs = [subexp.get_merged_config(subexp_index) for subexp in train_expts]
# 'il_train/exp_ident' is read from every flattened config below, so it must
# never be moved into the base config or dropped as a constant key.
base_config, flat_configs = simplify_config_dicts(
    all_configs,
    prohibited_base_keys=('env_cfg/task_name', 'env_cfg/benchmark_name', 'il_train/exp_ident'))
flat_config_tups = [tuple(sorted(d.items())) for d in flat_configs]
subexp_by_benchmark = {}
for flat_cfg, subexp in zip(flat_config_tups, train_expts):
    bench_key = tuple((k, v) for k, v in flat_cfg if k.startswith(('env_', 'venv_')))
    subexp_by_benchmark.setdefault(bench_key, []).append((flat_cfg, subexp))

display(HTML('<p><strong>Base config</strong></p>'))
display(HTML('<p>Unless specified otherwise, all config dicts include these keys:</p>'))
print(base_config)

# task_name -> {'columns', 'row_indexes', 'rows'} of loss-AUC statistics
raw_auc_data = {}
# task_name -> long-format DataFrame of per-batch losses (None if no usable run)
raw_loss_data = {}
loss_df_cols = ['exp_ident', 'batches', 'seed', 'loss']
cell_style = 'border-collapse: collapse;'

for idx, (bench_key, cfgs_subexps) in enumerate(subexp_by_benchmark.items(), start=1):
    task_name = dict(bench_key)['env_cfg/task_name']
    display(HTML(f'<p><strong>Loss AUC for benchmark config &#35;{idx} '
                 f'({html.escape(str(task_name))})</strong></p>'))

    # cluster subexperiments by their non-benchmark config
    by_cfg = {}
    for tup_cfg, subexp in cfgs_subexps:
        tup_cfg = tuple(kv for kv in tup_cfg if kv not in bench_key)
        by_cfg.setdefault(tup_cfg, []).append(subexp)

    # load progress files, compute cumulative AUCs and collect raw losses
    stats_dicts = {}
    columns = set()
    loss_frames = []
    for cfg, subexp in cfgs_subexps:
        d = dict(cfg)
        exp_ident = d['il_train/exp_ident']
        benchmark_name = d.get('env_cfg/benchmark_name', base_config.get('env_cfg/benchmark_name'))
        # runs whose progress.csv has a different number of rows are skipped
        full_length = 400 if benchmark_name == 'dm_control' else 40

        if not subexp.progress_path:
            stats_dicts[subexp] = {}
            continue
        try:
            df = pd.read_csv(subexp.progress_path)
        except Exception:
            print(f'Read csv reported error for exp {exp_ident}, skipping...')
            continue
        losses = df['loss']
        n_losses = len(losses)
        if n_losses != full_length:
            print(f'Experiment {exp_ident} only has len(loss) {n_losses}, skipping... ')
            continue

        # AUC of the loss curve (ignoring the first `start_count` values) up to
        # every multiple of step_length, plus over the whole curve. Positional
        # slicing via .iloc; max(1, ...) keeps range() valid for any num_split.
        step_length = max(1, n_losses // num_split)
        stats_dict = {}
        for step in range(step_length, n_losses, step_length):
            stats_dict[f"step {step:02d}"] = calculate_auc(losses.iloc[start_count:step])
        stats_dict[f"step {n_losses}"] = calculate_auc(losses.iloc[start_count:n_losses])
        stats_dicts[subexp] = stats_dict
        columns |= stats_dict.keys()

        # Raw per-batch losses. `batch` (not `count`) so the notebook-global
        # `count` aggregator used by the pivot cell is not clobbered.
        seed = subexp.config['seed']
        loss_frames.append(pd.DataFrame(
            [[exp_ident, batch + 1, seed, loss] for batch, loss in enumerate(losses)],
            columns=loss_df_cols))

    # DataFrame.append was removed in pandas 2.0; concatenate once instead.
    raw_loss_data[task_name] = (
        pd.concat(loss_frames, ignore_index=True) if loss_frames else None)
    # NOTE: columns are sorted as strings, e.g. 'step 132' < 'step 66'.
    columns = sorted(columns)

    # one table row per config, one column per AUC checkpoint
    table_parts = ['<table>', '<tr>', f'<th style="{cell_style}">Config</th>']
    table_parts.extend(f'<th style="{cell_style}">{html.escape(col_name)}</th>' for col_name in columns)
    table_parts.append('</tr>')

    row_indexes = []
    rows = []
    for cfg, subexps in sorted(by_cfg.items(), key=lambda cfg_se: dict(cfg_se[0])['il_train/exp_ident']):
        exp_ident = dict(cfg)['il_train/exp_ident']
        row_indexes.append(exp_ident)
        table_parts.append(f'<tr><td style="{cell_style}">{html.escape(str(exp_ident))}</td>')

        # cells containing mean±std (count) over runs sharing this config;
        # runs skipped above have no stats_dicts entry
        row = []
        for column in columns:
            column_values = [stats_dicts[subexp][column] for subexp in subexps
                             if column in stats_dicts.get(subexp, {})]
            if not column_values:
                table_parts.append(f'<td style="{cell_style}">-</td>')
                row.append({'mean': np.nan, 'std': np.nan, 'n': np.nan})
            else:
                mean = np.mean(column_values)
                std = np.std(column_values)
                n = len(column_values)
                table_parts.append(f'<td style="{cell_style}">{mean:.3g}±{std:.1g} ({n})</td>')
                row.append({'mean': mean, 'std': std, 'n': n})
        rows.append(row)
        table_parts.append('</tr>')
    table_parts.append('</table>')
    display(HTML(''.join(table_parts)))

    raw_auc_data[task_name] = {'columns': columns, 'row_indexes': row_indexes, 'rows': rows}

In [None]:
# Convert raw_auc_data (built by the AUC cell above) into DataFrames:
# raw_auc_dataframes[task][stat] has one row per il_train exp_ident and one
# column per "step N" AUC checkpoint, for stat in 'mean', 'std' and 'n'.
raw_auc_dataframes = dict()

for k in raw_auc_data.keys(): 
    print(k)
    raw_auc_dataframes[k] = dict()
    for stat in ['mean', 'std', 'n']:
        columns = raw_auc_data[k]['columns']
        index = raw_auc_data[k]['row_indexes']
        # pick out this statistic from each {'mean', 'std', 'n'} cell dict
        rows = []
        for raw_row in raw_auc_data[k]['rows']: 
            rows.append([el[stat] for el in raw_row])

        raw_auc_dataframes[k][stat] = pd.DataFrame(rows, index=index, columns=columns)

## Checkpoint: cache the DataFrames to disk (or reload them)

In [None]:
# for task in raw_auc_dataframes.keys(): 
#     for stat in raw_auc_dataframes[task].keys(): 
#         raw_auc_dataframes[task][stat].to_csv(f"cached_dfs/raw_auc_dataframes_{task}_{stat}.csv")

# for task in raw_return_dataframes.keys(): 
#     for stat in raw_return_dataframes[task].keys(): 
#         raw_return_dataframes[task][stat].to_csv(f"cached_dfs/raw_return_dataframes_{task}_{stat}.csv")

# print(f"These dataframes were last saved as a failsafe cache at {datetime.now()} PST")

In [None]:
## RUN TO RELOAD DATA FROM DISK ## 
# Replaces any DataFrames computed above with the CSV copies under
# cached_dfs/ (relative to the working directory). The task list is
# hard-coded, and every <kind>_<task>_<stat>.csv file must exist.
raw_return_dataframes = {}
raw_auc_dataframes = {}
for task in ['MatchRegions-Demo-v0', 'MoveToCorner-Demo-v0', 'MoveToRegion-Demo-v0', 'finger-spin', 'cheetah-run']: 
    raw_return_dataframes[task] = {}
    raw_auc_dataframes[task] = {}
    for stat in ['mean', 'std', 'n']: 
        # index_col=0 restores the exp_ident row labels written by to_csv
        raw_return_dataframes[task][stat] = pd.read_csv(f"cached_dfs/raw_return_dataframes_{task}_{stat}.csv", 
                                                            index_col=0)
        raw_auc_dataframes[task][stat] = pd.read_csv(f"cached_dfs/raw_auc_dataframes_{task}_{stat}.csv", 
                                                         index_col=0)
        


In [None]:
# Display the nested {task: {stat: DataFrame}} dict as the cell output.
raw_auc_dataframes

At this point, the raw data is stored in two dicts of DataFrames, `raw_auc_dataframes` and `raw_return_dataframes`, keyed first by task name and then by statistic (`mean`, `std`, `n`).

## Make Plots

In [None]:
# def clean_index_val(index, check_algo=False, check_data=False): 
#     new_index = index 
#     for lookup_key, lookup_val in merge_lookups.items(): 
#         new_index = new_index.replace(lookup_key, lookup_val)
#     return new_index 

def order_idx(indexes, algo_order):
    """Extend ``algo_order`` with every label of ``indexes`` it does not yet contain.

    Labels already present keep their position. Each missing label is reported
    and appended in the order it is encountered; repeated labels are added once.

    NOTE: ``algo_order`` is modified in place, and that same list object is
    returned. The notebook later passes the (possibly grown) ``algo_order`` as
    the ``hue_order`` of the loss-curve plots.
    """
    for label in indexes:
        if label not in algo_order:
            print(f"{label} not arranged in algo_order, append to the list.")
            algo_order.append(label)
    return algo_order

In [None]:
# def control_not_key(col, control_key): 
#     if 'control' in col: 
#         if col == control_key: 
#             return True 
#         else: 
#             return False 
#     else: 
#         return True 
    
# def should_keep_exp_ident(ident): 
#     if 'froco' in ident:
#         return False 
#     if 'no_ortho' in ident and 'actual' not in ident: 
#         return False 
#     return True 

In [None]:
def check_algo_or_data(index, lookup_dict, check_algo=False, check_data=False, verbose=True, blacklist=()):
    """Map a raw experiment identifier to a display label via regex lookups.

    Exactly one of ``check_algo`` / ``check_data`` must be set. It selects
    whether ``lookup_dict['algo_lookups']`` or ``lookup_dict['data_lookups']``
    is consulted.

    Args:
        index: raw experiment identifier string to classify.
        lookup_dict: dict with ``'algo_lookups'`` and ``'data_lookups'`` keys,
            each mapping a regex pattern to a display label. Patterns are tried
            in insertion order and the first ``re.search`` hit wins.
        check_algo: look up the algorithm label.
        check_data: look up the data-source label.
        verbose: print why an identifier was skipped or left unmatched.
        blacklist: regex patterns; if any matches ``index`` it is rejected.

    Returns:
        The matched display label, or ``None`` when ``index`` is blacklisted or
        matches no pattern.

    Raises:
        ValueError: if neither or both of ``check_algo`` / ``check_data`` are set.
    """
    # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
    if bool(check_algo) == bool(check_data):
        raise ValueError('Exactly one of check_algo and check_data must be True')
    inner_lookup_key = 'algo_lookups' if check_algo else 'data_lookups'
    inner_lookup_dict = lookup_dict[inner_lookup_key]
    check_type = 'algo' if check_algo else 'data'
    for term in blacklist:
        if re.search(term, index):
            if verbose:
                print(f"Run {index} skipped due to blacklist term {term}")
            return None
    for pattern, label in inner_lookup_dict.items():
        if re.search(pattern, index):
            return label
    if verbose:
        print(f"Didn't find any {check_type} entries for {index}")
    return None

In [None]:
def _order_like_lookup(found_labels, lookup):
    """Order ``found_labels`` by first appearance among ``lookup``'s values.

    Any label not present in ``lookup`` is appended in sorted order. The result
    is therefore deterministic instead of following set-iteration order.
    """
    ordered = [label for label in dict.fromkeys(lookup.values()) if label in found_labels]
    ordered += sorted(set(found_labels) - set(ordered))
    return ordered


def create_pivoted_dfs(base_df_dict, ret_col_lookup, blacklist_terms, 
                      whitelist_lookups, control_idx, algo_order, verbose=True): 
    """Pivot per-run results into (algorithm x data source) tables.

    For every task and statistic in ``base_df_dict``, each run identifier in the
    row index is classified with ``check_algo_or_data`` into a data-source label
    (column) and an algorithm label (row). The run's value in column
    ``ret_col_lookup[task]`` is written into the matching cell. A run labelled
    ``control_idx`` that has no data-source label is written into every column.
    Cells without a run, and NaN values, stay NaN.

    Args:
        base_df_dict: ``{task: {stat: DataFrame}}`` indexed by run identifier.
        ret_col_lookup: ``{task: column name}`` selecting the value to pivot.
        blacklist_terms: regex patterns of runs to skip.
        whitelist_lookups: ``{'algo_lookups': ..., 'data_lookups': ...}`` regex
            -> label dicts.
        control_idx: algorithm label of the control runs.
        algo_order: preferred row order. It is passed to ``order_idx``, which may
            append unlisted labels to it.
        verbose: forwarded to ``check_algo_or_data``; also prints every run used.

    Returns:
        ``{task: {stat: DataFrame}}`` of pivoted float tables. Columns are ordered
        like the data-lookup values; rows are reindexed by ``order_idx``.
    """
    data_dataframes = dict()
    for task in base_df_dict.keys():
        data_dataframes[task] = dict()
        for stat in base_df_dict[task].keys():
            stat_df = base_df_dict[task][stat]
            ret_col = ret_col_lookup[task]
            values = stat_df[ret_col]

            # Classify every run once as (identifier, data label, algo label).
            # Either label may be None (blacklisted or unmatched).
            run_labels = [
                (exp_ident,
                 check_algo_or_data(exp_ident, whitelist_lookups, check_data=True,
                                    verbose=verbose, blacklist=blacklist_terms),
                 check_algo_or_data(exp_ident, whitelist_lookups, check_algo=True,
                                    verbose=verbose, blacklist=blacklist_terms))
                for exp_ident in values.index
            ]
            found_data = {data for _, data, _ in run_labels if data is not None}
            found_algos = {algo for _, _, algo in run_labels if algo is not None}
            # Ordered lists, not sets: set order varies between interpreter runs
            # (so did the column order), and recent pandas rejects set-valued
            # index/columns.
            dataset_list = _order_like_lookup(found_data, whitelist_lookups['data_lookups'])
            algo_list = _order_like_lookup(found_algos, whitelist_lookups['algo_lookups'])
            table = pd.DataFrame(np.nan, index=algo_list, columns=dataset_list)

            for exp_ident, data, algo in run_labels:
                if data and algo:
                    if verbose:
                        print(f"Using information from {exp_ident}, is that OK?")
                    value = float(values[exp_ident])
                    if not math.isnan(value):
                        # .loc writes into ``table`` itself; the former chained
                        # ``table[data][algo] = ...`` is a no-op under Copy-on-Write.
                        table.loc[algo, data] = value
                elif algo == control_idx:
                    # Control run with no data-source label: use it for every column.
                    value = float(values[exp_ident])
                    if not math.isnan(value):
                        table.loc[algo, :] = value

            # Move control to top of table (rows follow algo_order; unlisted
            # algorithms are appended by order_idx).
            idx = order_idx(table.index, algo_order)
            table = table.reindex(idx)
            data_dataframes[task][stat] = table
    return data_dataframes

In [None]:
def create_pooled_dfs(base_df_dict, control_idx): 
    """Return deep copies of ``base_df_dict`` with each task's control row pooled.

    In every task, the ``control_idx`` row of each statistic table is replaced by
    one value broadcast across all columns:

    - ``mean``: the row mean.
    - ``std``: the row mean, i.e. a plain average of the per-column stds rather
      than a pooled standard deviation.
    - ``n``: the row sum.

    Other rows and the input dicts/frames are left untouched.

    NOTE(review): create_pivoted_dfs copies a control run that has no
    data-source label into every column, so summing ``n`` may count the same
    runs once per column. Confirm this is intended before relying on the
    resulting t-test p-values.
    """
    # (statistic, how its control row collapses to a single number)
    reducers = (
        ('mean', lambda row: row.mean()),
        ('std', lambda row: row.mean()),
        ('n', lambda row: row.sum()),
    )
    pooled = {}
    for task, stat_frames in base_df_dict.items():
        copies = {name: frame.copy(deep=True) for name, frame in stat_frames.items()}
        # Compute all pooled values from the untouched copies, then write them back.
        pooled_values = {name: reduce(copies[name].loc[control_idx]) for name, reduce in reducers}
        for name, value in pooled_values.items():
            copies[name].loc[control_idx] = value
        pooled[task] = copies
    return pooled

In [None]:
def filter_loss_dfs(base_df, blacklist_terms, whitelist_lookups, control_idx, verbose=False):
    """Split a long-format loss dataframe into one dataframe per data source.

    Each row's ``exp_ident`` is classified with ``check_algo_or_data``:

    - A row with both a data-source and an algorithm label goes to that data
      source, with ``exp_ident`` replaced by the algorithm label.
    - A row labelled ``control_idx`` that has no data-source label is copied
      into every data source.
    - All other rows (blacklisted or unmatched) are dropped.

    Newlines are stripped from data-source labels.

    Args:
        base_df: DataFrame with an ``exp_ident`` column plus the loss columns.
        blacklist_terms: regex patterns of runs to skip.
        whitelist_lookups: ``{'algo_lookups': ..., 'data_lookups': ...}``.
        control_idx: algorithm label of the control runs.
        verbose: forwarded to ``check_algo_or_data``.

    Returns:
        ``{data source label: DataFrame}`` with the same columns as ``base_df``.
        Keys are in first-seen order, and every data source seen in
        ``base_df`` gets a key, even when none of its rows survive.
    """
    def _relabel(row, algo):
        # Overwrite exp_ident in a copy of the row so every value stays aligned
        # with base_df.columns wherever the exp_ident column sits. The former
        # ``[algo] + other values`` only lined up when exp_ident was column 0.
        new_row = row.copy()
        new_row['exp_ident'] = algo
        return new_row.tolist()

    # Single pass classifying every row as (row, data label, algo label).
    classified = []
    for _, row in base_df.iterrows():
        exp_ident = row['exp_ident']
        data_source = check_algo_or_data(exp_ident, whitelist_lookups, check_data=True,
                                         verbose=verbose, blacklist=blacklist_terms)
        algo = check_algo_or_data(exp_ident, whitelist_lookups, check_algo=True,
                                  verbose=verbose, blacklist=blacklist_terms)
        if data_source is not None:
            data_source = data_source.replace('\n', '')
        classified.append((row, data_source, algo))

    dataset_list = list(dict.fromkeys(ds for _, ds, _ in classified if ds is not None))
    return_data = {data_source: [] for data_source in dataset_list}
    for row, data_source, algo in classified:
        if data_source and algo:
            return_data[data_source].append(_relabel(row, algo))
        elif algo == control_idx:
            for dataset in dataset_list:
                return_data[dataset].append(_relabel(row, algo))

    # Make dataframes
    return {data_source: pd.DataFrame(rows, columns=base_df.columns)
            for data_source, rows in return_data.items()}

In [None]:
from pathlib import Path

def save_eps(sns_plot, env, plot_type):
    """Write the figure that owns ``sns_plot`` to ./plots/<subdir>/<env>-<plot_type>.eps.

    ``<subdir>`` is the second '/'-separated component of the global
    ``cluster_subpath``. For "./" that component is empty, so files land
    directly in ./plots/. The directory is created if missing, and the absolute
    output path is printed.

    NOTE(review): a ``cluster_subpath`` containing no '/' raises IndexError here.
    """
    subdir = cluster_subpath.split('/')[1]
    target_dir = f"./plots/{subdir}"
    target_file = os.path.join(target_dir, f"{env}-{plot_type}.eps")
    Path(target_dir).mkdir(parents=True, exist_ok=True)
    sns_plot.get_figure().savefig(target_file, bbox_inches="tight", format='eps', dpi=1200)
    print(f"Plot saved at {os.path.abspath(target_file)}")

#### Plotting Code

In [None]:
from scipy.stats import ttest_ind_from_stats

In [None]:
def get_p_val_df(mean_df, std_df, n_df, control_idx, alternative):
    """Welch's t-test p-values of every non-control row against the control row.

    Args:
        mean_df: table of per-cell sample means (rows = algorithms).
        std_df: per-cell standard deviations, same index/columns as ``mean_df``.
        n_df: per-cell sample sizes, same index/columns as ``mean_df``.
        control_idx: index label of the control row.
        alternative: forwarded to ``scipy.stats.ttest_ind_from_stats``
            (``'two-sided'``, ``'less'`` or ``'greater'``). The test is
            "row vs control", so ``'greater'`` asks whether a row exceeds control.

    Returns:
        P-values for the non-control rows (in ``mean_df`` order) by column,
        broadcast against the control row.

    Raises:
        Whatever ``ttest_ind_from_stats`` raises. Errors now propagate: the
        former bare ``except`` opened pdb, which hangs non-interactive runs,
        and then failed on an unbound ``p_vals``.
    """
    test_means = mean_df[mean_df.index != control_idx]
    control_means = mean_df.loc[control_idx]

    test_std = std_df[std_df.index != control_idx]
    control_std = std_df.loc[control_idx]

    test_ns = n_df[n_df.index != control_idx]
    control_ns = n_df.loc[control_idx]

    # equal_var=False selects Welch's unequal-variance t-test.
    _, p_vals = ttest_ind_from_stats(test_means, test_std, test_ns,
                                     control_means, control_std, control_ns,
                                     equal_var=False, alternative=alternative)
    return p_vals

In [None]:
def task_return_heatmap(task, df_dict, control_key, fontsize=30, 
                        show_ylabel=True, font_scale=2, narrow=False, min_max_vals=None): 
    """Draw a mean-return heatmap for ``task`` relative to the control row.

    Cell colour is ``mean - mean[control_key]`` on a diverging palette centred
    at 0. Each cell is annotated "mean (std)". Non-control cells get "**"
    appended when a one-sided Welch t-test (row greater than control) gives
    p < 0.05.

    Args:
        task: key into ``df_dict``; also shown in the title.
        df_dict: ``{task: {'mean'|'std'|'n': DataFrame}}`` with algorithms as
            rows and data sources as columns.
        control_key: index label of the control row.
        fontsize: annotation font size.
        show_ylabel: draw row tick labels (the figure is widened by 3).
        font_scale: passed to ``sns.set``.
        narrow: 4-wide instead of 13-wide figure.
        min_max_vals: optional ``(vmin, vmax)`` colour limits; autoscaled if None.

    Returns:
        The matplotlib Axes returned by ``sns.heatmap``.
    """
    sns.set(font="Times New Roman", font_scale=font_scale)

    mean_df = df_dict[task]['mean']
    std_df = df_dict[task]['std']
    n_df = df_dict[task]['n']

    # Two decimals when the minimum of the first column is below 1, else whole
    # numbers. ``.iloc[0]``: positional ``Series[0]`` on string labels is
    # deprecated in pandas and becomes a label lookup (KeyError) in pandas 3.
    str_format = "{:.2f}" if mean_df.min().iloc[0] < 1 else "{:.0f}"

    p_vals = get_p_val_df(mean_df, std_df, n_df, control_key, alternative="greater")
    p_val_df = pd.DataFrame(p_vals, index=[el for el in mean_df.index if el != control_key],
                            columns=mean_df.columns)
    normed_df = mean_df - mean_df.loc[control_key]

    mean_text = mean_df.applymap(str_format.format)
    std_text = std_df.applymap(str_format.format)
    text_df = pd.DataFrame()
    for col in mean_text.columns:
        text_df[col] = mean_text[col].astype(str) + ' (' + std_text[col].astype(str) + ')'

    # Flag significant cells. Write through .loc: the former chained
    # ``text_df[col][ind] = ...`` silently does nothing under Copy-on-Write.
    for col in text_df.columns:
        for ind in text_df.index:
            if ind != control_key and p_val_df.loc[ind, col] < 0.05:
                text_df.loc[ind, col] = text_df.loc[ind, col] + '**'

    if narrow:
        figsize = (4, 7)
    else:
        figsize = (13, 7)
    if show_ylabel:
        figsize = (figsize[0] + 3, figsize[1])
    plt.figure(figsize=figsize)

    plt.title(f"Mean Return: {task}", fontsize=30)

    if min_max_vals is not None:
        vmin, vmax = min_max_vals
    else:
        vmin, vmax = None, None

    # With string annotations, fmt='.20' caps each label at 20 characters.
    ax = sns.heatmap(normed_df, annot=text_df, annot_kws={"fontsize": fontsize}, cbar=False,
                     cmap=sns.diverging_palette(250, 10, s=60, l=45, as_cmap=True),
                     center=0, fmt='.20', yticklabels=show_ylabel, vmin=vmin, vmax=vmax)

    return ax

In [None]:
def task_auc_heatmap(task, df_dict, control_key, fontsize=30,
                     show_ylabel=True, font_scale=2, show_sig=False, narrow=False, min_max_vals=None): 
    """Draw an AUC heatmap for ``task`` relative to the control row.

    Cell colour is ``mean - mean[control_key]`` on a diverging palette centred
    at 0. The palette endpoints are swapped relative to task_return_heatmap,
    presumably because lower AUC is the better direction here (confirm). Each
    cell is annotated "mean (std)" with two decimals. With ``show_sig``,
    non-control cells get "**" when a one-sided Welch t-test (row less than
    control) gives p < 0.05.

    Args:
        task: key into ``df_dict``; also shown in the title.
        df_dict: ``{task: {'mean'|'std'|'n': DataFrame}}`` with algorithms as
            rows and data sources as columns.
        control_key: index label of the control row.
        fontsize: annotation font size.
        show_ylabel: draw row tick labels (the figure is widened by 3).
        font_scale: passed to ``sns.set``.
        show_sig: append "**" to significant cells.
        narrow: 4-wide instead of 13-wide figure.
        min_max_vals: optional ``(vmin, vmax)`` colour limits; autoscaled if None.

    Returns:
        The matplotlib Axes returned by ``sns.heatmap``.
    """
    sns.set(font="Times New Roman", font_scale=font_scale)

    str_format = "{:.2f}"
    subset_df = df_dict[task]['mean']
    std_df = df_dict[task]['std']
    n_df = df_dict[task]['n']

    p_vals = get_p_val_df(subset_df, std_df, n_df, control_key, alternative="less")
    p_val_df = pd.DataFrame(p_vals,
                            index=[el for el in subset_df.index if el != control_key],
                            columns=subset_df.columns)

    normed_df = subset_df - subset_df.loc[control_key]

    mean_text = subset_df.applymap(str_format.format)
    std_text = std_df.applymap(str_format.format)
    text_df = pd.DataFrame()
    for col in mean_text.columns:
        text_df[col] = mean_text[col].astype(str) + ' (' + std_text[col].astype(str) + ')'

    if show_sig:
        # Write through .loc: the former chained ``text_df[col][ind] = ...``
        # silently does nothing under pandas Copy-on-Write.
        for col in text_df.columns:
            for ind in text_df.index:
                if ind != control_key and p_val_df.loc[ind, col] < 0.05:
                    text_df.loc[ind, col] = text_df.loc[ind, col] + '**'

    if narrow:
        figsize = (4, 7)
    else:
        figsize = (13, 7)
    if show_ylabel:
        figsize = (figsize[0] + 3, figsize[1])

    plt.figure(figsize=figsize)
    plt.title(f"AUC: {task}", fontsize=30)

    if min_max_vals is not None:
        vmin, vmax = min_max_vals
    else:
        vmin, vmax = None, None

    # With string annotations, fmt='.20' caps each label at 20 characters.
    ax = sns.heatmap(normed_df, annot=text_df, annot_kws={"fontsize": fontsize}, cbar=False,
                     cmap=sns.diverging_palette(10, 250, s=60, l=45, as_cmap=True),
                     center=0, fmt='.20', yticklabels=show_ylabel, vmin=vmin, vmax=vmax)

    return ax

In [None]:
def plot_loss_curves(env, data_source, task_df, algo_order, show_ylabel=False, show_legend=False,
                    ylim=None):
    """Plot loss curves, one line per algorithm, for a single data source.

    Args:
        env: environment name, used in the y-axis label.
        data_source: data-source label, used as the title.
        task_df: long-format frame with ``batches``, ``loss`` and ``exp_ident``
            columns (``exp_ident`` holding algorithm labels).
        algo_order: passed to seaborn as ``hue_order``.
        show_ylabel: label the y axis "<env> Loss".
        show_legend: draw a "brief" legend titled "Algorithm".
        ylim: optional y-axis limits.

    Returns:
        The matplotlib Axes.
    """
    sns.set(font="Times New Roman", font_scale=1.5)
    plt.figure(figsize=(6, 5))
    plt.title(f"{data_source}")
    kw_legend = "brief" if show_legend else False
    ax = sns.lineplot(data=task_df, x="batches", y="loss", hue="exp_ident", hue_order=algo_order,
                      legend=kw_legend)

    ylabel = f"{env} Loss" if show_ylabel else None
    ax.set(xlabel="Epoch", ylabel=ylabel)
    if ylim:
        ax.set(ylim=ylim)
    if show_legend:
        legend = ax.legend()
        first_text = legend.texts[0] if legend.texts else None
        if first_text is not None and first_text.get_text() == "exp_ident":
            # Older seaborn lists the hue variable name as the first legend entry.
            first_text.set_text("Algorithm")
        else:
            # Newer seaborn has no such entry, so renaming texts[0] would relabel
            # the first algorithm; set the legend title instead.
            legend.set_title("Algorithm")
        plt.setp(ax.get_legend().get_texts(), fontsize='12')  # for legend text

    return ax

## Baseline Plots

### Return Plots

In [None]:
# Algorithm label of the control row in the baseline heatmaps.
control_idx = 'Control'

envs = ['cheetah-run', 'finger-spin', 'MatchRegions-Demo-v0', 'MoveToRegion-Demo-v0']

# Runs whose identifier matches any of these regexes are dropped by check_algo_or_data.
dmc_blacklist_terms = ['froco', 'ablation', 'newbcaugs']
# Regex on the run identifier -> algorithm display label. check_algo_or_data
# returns the first match in insertion order, so dict order matters.
dmc_algo_lookups = {
    '^icml_inv_dyn_': "Inverse Dynamics", 
    "^icml_ac_tcpc_": "Action Conditioned TCPC", 
    "^icml_vae_": "VAE", 
    "^icml_dynamics_": "Dynamics Model",
    "^control_ortho_init_": "Control",
    "^icml_identity_cpc_": "Temporal CPC (TCPC)", # original note: "due to error" — presumably these runs are TCPC despite the identity_cpc name; confirm
}


magical_blacklist_terms = ['froco', 'ablation']
magical_algo_lookups = {
    '^icml_inv_dyn_': "Inverse Dynamics", 
    "^icml_ac_tcpc_": "Action Conditioned TCPC", 
    "^icml_vae_": "VAE", 
    "^icml_dynamics_": "Dynamics Model",
    "^control_ortho_init": "Control",
    "^icml_identity_cpc_": "Temporal CPC (TCPC)", 
}

# Regex on the run identifier -> data-source (column) label. No pattern is a
# substring of another, so at most one can match. The '\n' in labels is kept
# for heatmap column labels (presumably as a line break) and stripped by
# filter_loss_dfs.
data_lookups = {
    "cfg_data_repl_random": "Random Rollouts", 
    "cfg_data_repl_demos_random": "Demos & \nRandom Rollouts", 
    "cfg_data_repl_demos_magical_mt": "Multitask Demos", 
    "cfg_data_repl_rand_demos_magical_mt": "Multitask Demos & \nRandom Rollouts",
}

# Shape expected by check_algo_or_data's lookup_dict argument.
magical_whitelist_lookups = {'data_lookups': data_lookups, 'algo_lookups': magical_algo_lookups}
dmc_whitelist_lookups = {'data_lookups': data_lookups, 'algo_lookups': dmc_algo_lookups}


## ***CHANGE THESE NAMES FOR NEW PLOTS *** 
baseline_magical_configs = dict(control_idx=control_idx, 
                                blacklist_terms=magical_blacklist_terms, 
                                whitelist_lookups=magical_whitelist_lookups) 
baseline_dmc_configs = dict(control_idx=control_idx, 
                            blacklist_terms=dmc_blacklist_terms, 
                            whitelist_lookups=dmc_whitelist_lookups) 

# env -> plot config used by the baseline plotting cells below.
baseline_plot_config_lookup = {'MoveToRegion-Demo-v0': baseline_magical_configs, 
                             'MatchRegions-Demo-v0': baseline_magical_configs, 
                             'cheetah-run': baseline_dmc_configs, 
                             'finger-spin': baseline_dmc_configs
                               }

In [None]:
## Configs for baseline plots 

# Column of raw_return_dataframes[task][stat] holding each task's return value.
ret_col_lookup = {
    'MatchRegions-Demo-v0': 'Average on all envs', 
    'MoveToRegion-Demo-v0': 'Average on all envs', 
    'MoveToCorner-Demo-v0': 'Average on all envs', 
    'finger-spin': 'return_mean', 
    'cheetah-run': 'return_mean'
}

# Set plot properties according to where we want to position it in paper.
# Only left_envs is read below: those envs hide the heatmap row labels.
right_envs = ['MatchRegions-Demo-v0', 'cheetah-run']
left_envs = ['MoveToRegion-Demo-v0', 'finger-spin']

# Heatmap row order; order_idx appends any unlisted algorithm to this list in place.
algo_order = ["Control", "Temporal CPC (TCPC)", "Action Conditioned TCPC", "VAE", "Dynamics Model",
              "Inverse Dynamics"]

for env in envs:
    print(f"Creating plot for {env}")
    plot_config = baseline_plot_config_lookup[env]
    # Pivots every task in raw_return_dataframes; only `env` is plotted below.
    pivoted_dfs = create_pivoted_dfs(raw_return_dataframes, ret_col_lookup, 
                                     plot_config['blacklist_terms'], plot_config['whitelist_lookups'], 
                                     plot_config['control_idx'], algo_order, verbose=False)
    pooled_dfs = create_pooled_dfs(pivoted_dfs, control_idx=plot_config['control_idx'])
    kwargs = {}
    if env in left_envs:
        kwargs = {'show_ylabel': False}
    # Fixed colour limits of +/-0.1 around the control value for every env.
    env_plot = task_return_heatmap(env, pooled_dfs, control_key=control_idx, min_max_vals=(-.1, .1), **kwargs)
    display(env_plot)
    save_eps(env_plot, env, 'return')

# Drop the loop results so later cells cannot accidentally reuse them.
del pivoted_dfs 
del pooled_dfs

###  AUCs

In [None]:
# Column of raw_auc_dataframes[task][stat] holding the AUC value to plot.
ret_col_lookup = {
    'MatchRegions-Demo-v0': 'step 40', 
    'MoveToRegion-Demo-v0': 'step 40', 
    'MoveToCorner-Demo-v0': 'step 40', 
    'finger-spin': 'step 400', 
    'cheetah-run': 'step 400'
}

# Reuses envs, left_envs, algo_order and baseline_plot_config_lookup from the cells above.
for env in envs:
    plot_config = baseline_plot_config_lookup[env]
    pivoted_dfs = create_pivoted_dfs(raw_auc_dataframes, ret_col_lookup, 
                                     plot_config['blacklist_terms'], plot_config['whitelist_lookups'], 
                                     plot_config['control_idx'], algo_order, verbose=False)
    pooled_dfs = create_pooled_dfs(pivoted_dfs, control_idx=plot_config['control_idx'])
    kwargs = {}
    if env in left_envs:
        kwargs = {'show_ylabel': False}
    env_plot = task_auc_heatmap(env, pooled_dfs, control_key=control_idx, **kwargs)
    display(env_plot)
    save_eps(env_plot, env, 'AUC')

# to be paranoid: drop loop results so later cells cannot reuse them by accident
del pivoted_dfs 
del pooled_dfs

### Loss curves

In [None]:
# Per-env y-axis limits for the baseline loss curves.
env_ylim = {
    'MatchRegions-Demo-v0': (0.05, 0.2),
    'MoveToRegion-Demo-v0': (0.05, 0.2),
    'cheetah-run': (1, 4),
    'finger-spin': (-0.7, 0)
}

for env in envs:
    plot_config = baseline_plot_config_lookup[env]
    # raw_loss_data comes from an earlier cell not shown here — presumably
    # {env: long-format DataFrame with exp_ident/batches/loss columns}; confirm.
    filtered_dfs = filter_loss_dfs(raw_loss_data[env], 
                                   plot_config['blacklist_terms'], 
                                   plot_config['whitelist_lookups'], 
                                   plot_config['control_idx'],
                                   verbose=False)
    print(f"Start plotting curves for env {env}...")
    for data_source, df in filtered_dfs.items():
        print(set(df['exp_ident']))
        kwargs = {}
        # filter_loss_dfs strips '\n' from labels, so this matches the
        # "Demos & \nRandom Rollouts" lookup value. Only that panel gets a y
        # label, and only MatchRegions' copy of it gets the legend.
        if data_source == 'Demos & Random Rollouts':
            kwargs['show_ylabel'] = True
            if env == 'MatchRegions-Demo-v0':
                kwargs['show_legend'] = True
        if env in env_ylim.keys():
            kwargs['ylim'] = env_ylim[env]
        # algo_order (as possibly extended by order_idx above) is the hue order.
        env_plot = plot_loss_curves(env, data_source, df, algo_order, **kwargs)
        display(env_plot)
        filename = f"{env}-{data_source}".replace(" ", "").replace("&", "-")
        save_eps(env_plot, filename, 'loss')

## FROCO plots

In [None]:
# Algorithm label of the control row in the FROCO heatmaps.
control_idx = 'Control'

envs = ['cheetah-run', 'finger-spin', 'MatchRegions-Demo-v0', 'MoveToRegion-Demo-v0']


# Unlike the baseline configs, 'froco' is not blacklisted here; the algo
# patterns below only match identifiers that start with 'froco_'.
froco_blacklist_terms = ['ablation', 'newbcaugs']
froco_algo_lookups = {
    '^froco_icml_inv_dyn_': "Inverse Dynamics", 
    "^froco_icml_ac_tcpc_": "Action Conditioned TCPC", 
    "^froco_icml_vae_": "VAE", 
    "^froco_icml_dynamics_": "Dynamics Model",
    "^froco_control_ortho_init": "Control",
    "^froco_icml_identity_cpc_": "Temporal CPC (TCPC)", # original note: "Due to error" — presumably TCPC runs despite the identity_cpc name; confirm
}

# Same data-source patterns/labels as the baseline config (rebinds the global).
data_lookups = {
    "cfg_data_repl_random": "Random Rollouts", 
    "cfg_data_repl_demos_random": "Demos & \nRandom Rollouts", 
    "cfg_data_repl_demos_magical_mt": "Multitask Demos", 
    "cfg_data_repl_rand_demos_magical_mt": "Multitask Demos & \nRandom Rollouts",
}

froco_whitelist_lookups = {'data_lookups': data_lookups, 'algo_lookups': froco_algo_lookups}


## ***CHANGE THESE NAMES FOR NEW PLOTS *** 
froco_configs = dict(control_idx=control_idx, 
                                blacklist_terms=froco_blacklist_terms, 
                                whitelist_lookups=froco_whitelist_lookups) 

# Every env shares the same FROCO config.
froco_plot_config_lookup = {'MoveToRegion-Demo-v0': froco_configs, 
                             'MatchRegions-Demo-v0': froco_configs, 
                             'cheetah-run': froco_configs, 
                             'finger-spin': froco_configs}

In [None]:
## Configs for FROCO return plots 

# Column of raw_return_dataframes[task][stat] holding each task's return value.
ret_col_lookup = {
    'MatchRegions-Demo-v0': 'Average on all envs', 
    'MoveToRegion-Demo-v0': 'Average on all envs', 
    'MoveToCorner-Demo-v0': 'Average on all envs', 
    'finger-spin': 'return_mean', 
    'cheetah-run': 'return_mean'
}

# Set plot properties according to where we want to position it in paper.
# Only left_envs is read below: those envs hide the heatmap row labels.
right_envs = ['MatchRegions-Demo-v0', 'cheetah-run']
left_envs = ['MoveToRegion-Demo-v0', 'finger-spin']

# Heatmap row order. Labels must equal the froco_algo_lookups values exactly,
# with no trailing spaces. An unmatched entry such as "VAE " becomes an
# all-NaN row after reindexing, while order_idx appends the real "VAE" row at
# the bottom.
algo_order = ["Control", "Temporal CPC (TCPC)", "Action Conditioned TCPC", "VAE", "Dynamics Model",
              "Inverse Dynamics"]

for env in envs:
    print(f"Creating FROCO plot for {env}")
    plot_config = froco_plot_config_lookup[env]
    pivoted_dfs = create_pivoted_dfs(raw_return_dataframes, ret_col_lookup, 
                                     plot_config['blacklist_terms'], plot_config['whitelist_lookups'], 
                                     plot_config['control_idx'], algo_order, verbose=False)
    pooled_dfs = create_pooled_dfs(pivoted_dfs, control_idx=plot_config['control_idx'])
    kwargs = {}

    if env in left_envs:
        kwargs = {'show_ylabel': False}

    # MoveToRegion uses a fixed +/-0.2 colour range; other envs autoscale.
    if env == 'MoveToRegion-Demo-v0': 
        kwargs['min_max_vals'] = (-.2, .2)
    env_plot = task_return_heatmap(env, pooled_dfs, control_key=control_idx, **kwargs)
    display(env_plot)
    save_eps(env_plot, env, 'froco_return')

# to be paranoid: drop loop results so later cells cannot reuse them by accident
del pivoted_dfs 
del pooled_dfs

### Froco AUC

In [None]:
## Configs for FROCO AUC plots 

# Column of raw_auc_dataframes[task][stat] holding the AUC value to plot.
ret_col_lookup = {
    'MatchRegions-Demo-v0': 'step 40', 
    'MoveToRegion-Demo-v0': 'step 40', 
    'MoveToCorner-Demo-v0': 'step 40', 
    'finger-spin': 'step 400', 
    'cheetah-run': 'step 400'
}
# Set plot properties according to where we want to position it in paper.
# Only left_envs is read below: those envs hide the heatmap row labels.
right_envs = ['MatchRegions-Demo-v0', 'cheetah-run']
left_envs = ['MoveToRegion-Demo-v0', 'finger-spin']

# Heatmap row order; the FROCO loss-curve cell that follows also uses it as
# hue order. Labels must equal the froco_algo_lookups values exactly, with no
# trailing spaces. Otherwise reindexing adds all-NaN placeholder rows, and the
# real labels end up appended at the bottom by order_idx.
algo_order = ["Control", "Temporal CPC (TCPC)", "Action Conditioned TCPC", "VAE", "Dynamics Model",
              "Inverse Dynamics"]

for env in envs:
    print(f"Creating FROCO plot for {env}")
    plot_config = froco_plot_config_lookup[env]
    pivoted_dfs = create_pivoted_dfs(raw_auc_dataframes, ret_col_lookup, 
                                     plot_config['blacklist_terms'], plot_config['whitelist_lookups'], 
                                     plot_config['control_idx'], algo_order, verbose=False)
    pooled_dfs = create_pooled_dfs(pivoted_dfs, control_idx=plot_config['control_idx'])
    kwargs = {}

    if env in left_envs:
        kwargs = {'show_ylabel': False}
    env_plot = task_auc_heatmap(env, pooled_dfs, control_key=control_idx, **kwargs)
    display(env_plot)
    save_eps(env_plot, env, 'froco_auc')

# to be paranoid: drop loop results so later cells cannot reuse them by accident
del pivoted_dfs 
del pooled_dfs

### Loss curves

In [None]:
# Per-env y-axis limits for the FROCO loss curves.
env_ylim = {
    'MatchRegions-Demo-v0': (0, 4),
    'MoveToRegion-Demo-v0': (0, 2),
    'cheetah-run': (3.5, 7),
    'finger-spin': (0, 2.5)
}

for env in envs:
    plot_config = froco_plot_config_lookup[env]
    # raw_loss_data comes from an earlier cell not shown here — presumably
    # {env: long-format DataFrame with exp_ident/batches/loss columns}; confirm.
    filtered_dfs = filter_loss_dfs(raw_loss_data[env], 
                                   plot_config['blacklist_terms'], 
                                   plot_config['whitelist_lookups'], 
                                   plot_config['control_idx'],
                                   verbose=False)
    print(f"Start plotting curves for env {env}...")
    for data_source, df in filtered_dfs.items():
        print(set(df['exp_ident']))
        kwargs = {}
        # filter_loss_dfs strips '\n' from labels, so this matches the
        # "Demos & \nRandom Rollouts" lookup value. Only that panel gets a y
        # label, and only MatchRegions' copy of it gets the legend.
        if data_source == 'Demos & Random Rollouts':
            kwargs['show_ylabel'] = True
            if env == 'MatchRegions-Demo-v0':
                kwargs['show_legend'] = True
        if env in env_ylim.keys():
            kwargs['ylim'] = env_ylim[env]
        # algo_order is the global left by the FROCO AUC cell (hue order).
        env_plot = plot_loss_curves(env, data_source, df, algo_order, **kwargs)
        display(env_plot)
        filename = f"{env}-{data_source}".replace(" ", "").replace("&", "-")
        save_eps(env_plot, filename, 'froco_loss')

In [None]:
# To sync plots: '/home/cody/il-representations/analysis/plots/cluster-2021-01-29-set3-try4/`

# Ablations Plots

In [None]:
# In the ablation plots, plain Temporal CPC is the reference ("control") row.
control_idx = 'Temporal CPC'

envs = ['cheetah-run', 'finger-spin', 'MatchRegions-Demo-v0', 'MoveToRegion-Demo-v0']


ablation_blacklist_terms = ['froco', 'newbcaugs']
# Regex -> label, first match in insertion order wins (check_algo_or_data).
ablation_algo_lookups = {
    '^ablation_icml_tcpc_no_augs_': "Temporal CPC - No Augmentations",
    '^ablation_icml_tcpc_momentum_': "Momentum Temporal CPC", 
    "^icml_ac_tcpc_": "Action Conditioned TCPC", 
    "^ablation_icml_tceb_": "Temporal CEB", 
    "^ablation_icml_four_tcpc_": "Temporal CPC (t=4)",
    "^icml_identity_cpc": "Temporal CPC",
    # The '^' anchors already stop this and "^icml_identity_cpc" from matching
    # the same identifier, so the TODO below may be stale — confirm.
    "^ablation_icml_identity_cpc_": "Identity CPC (Not temporal)", # TODO fix overlapping issue
}

# Ablations only cover the random-rollout data source.
ablation_data_lookups = {
    "cfg_data_repl_random": "Random Rollouts", 
}

ablation_whitelist_lookups = {'data_lookups': ablation_data_lookups, 'algo_lookups': ablation_algo_lookups}


## ***CHANGE THESE NAMES FOR NEW PLOTS *** 
ablation_configs = dict(control_idx=control_idx, 
                                blacklist_terms=ablation_blacklist_terms, 
                                whitelist_lookups=ablation_whitelist_lookups) 

# Every env shares the same ablation config.
ablation_plot_config_lookup = {'MoveToRegion-Demo-v0': ablation_configs, 
                             'MatchRegions-Demo-v0': ablation_configs, 
                             'cheetah-run': ablation_configs, 
                             'finger-spin': ablation_configs}

### Ablation Return

In [None]:
## Configs for ablation return plots 

# Column of raw_return_dataframes[task][stat] holding each task's return value.
ret_col_lookup = {
    'MatchRegions-Demo-v0': 'Average on all envs', 
    'MoveToRegion-Demo-v0': 'Average on all envs', 
    'MoveToCorner-Demo-v0': 'Average on all envs', 
    'finger-spin': 'return_mean', 
    'cheetah-run': 'return_mean'
}

# Set plot properties according to where we want to position it in paper
# NOTE these are inverted left/right atm 
# Left and center envs both hide the row labels below; only right_envs
# (cheetah-run) keeps them.
right_envs = ['cheetah-run']
left_envs = ['MoveToRegion-Demo-v0']
center_envs = ['finger-spin', 'MatchRegions-Demo-v0']

# Heatmap row order, starting with the "Temporal CPC" reference row.
algo_order = ["Temporal CPC", "Identity CPC (Not temporal)", "Temporal CPC (t=4)", "Temporal CEB", "Action Conditioned TCPC",
              "Momentum Temporal CPC", "Temporal CPC - No Augmentations"]



for env in envs:
    print(f"Creating ablations plot for {env}")
    plot_config = ablation_plot_config_lookup[env]
    pivoted_dfs = create_pivoted_dfs(raw_return_dataframes, ret_col_lookup, 
                                     plot_config['blacklist_terms'], plot_config['whitelist_lookups'], 
                                     plot_config['control_idx'], algo_order, verbose=False)
    pooled_dfs = create_pooled_dfs(pivoted_dfs, control_idx=plot_config['control_idx'])
    kwargs = {}

    if env in left_envs:
        kwargs = {'show_ylabel': False}
    if env in center_envs: 
        kwargs = {'show_ylabel': False}
    # narrow=True: single data-source column, so a 4-wide figure.
    env_plot = task_return_heatmap(env, pooled_dfs, control_key=control_idx, narrow=True, **kwargs)
    display(env_plot)
    save_eps(env_plot, env, 'ablation_return')

# to be paranoid: drop loop results so later cells cannot reuse them by accident
del pivoted_dfs 
del pooled_dfs

### Ablation AUC

In [None]:

# Column of raw_auc_dataframes[task][stat] holding the AUC value to plot.
ret_col_lookup = {
    'MatchRegions-Demo-v0': 'step 40', 
    'MoveToRegion-Demo-v0': 'step 40', 
    'MoveToCorner-Demo-v0': 'step 40', 
    'finger-spin': 'step 400', 
    'cheetah-run': 'step 400'
}

# NOTE these are inverted left/right atm 
# Left and center envs both hide the row labels below; only right_envs keeps them.
right_envs = ['cheetah-run']
left_envs = ['MoveToRegion-Demo-v0']
center_envs = ['finger-spin', 'MatchRegions-Demo-v0']

algo_order = ["Temporal CPC", "Identity CPC (Not temporal)", "Temporal CPC (t=4)", "Temporal CEB", "Action Conditioned TCPC",
              "Momentum Temporal CPC", "Temporal CPC - No Augmentations"]



for env in envs:
    plot_config = ablation_plot_config_lookup[env]
    pivoted_dfs = create_pivoted_dfs(raw_auc_dataframes, ret_col_lookup, 
                                     plot_config['blacklist_terms'], plot_config['whitelist_lookups'], 
                                     plot_config['control_idx'], algo_order, verbose=False)
    pooled_dfs = create_pooled_dfs(pivoted_dfs, control_idx=plot_config['control_idx'])
    kwargs = {}

    if env in left_envs:
        kwargs = {'show_ylabel': False}
    if env in center_envs: 
        kwargs = {'show_ylabel': False}
    # Unlike the baseline/FROCO AUC plots, significance markers are shown here.
    env_plot = task_auc_heatmap(env, pooled_dfs, control_key=control_idx, show_sig=True, narrow=True, **kwargs)
    display(env_plot)
    save_eps(env_plot, env, 'ablation_auc')

# to be paranoid: drop loop results so later cells cannot reuse them by accident
del pivoted_dfs 
del pooled_dfs

## Hail Mary


In [None]:
# Goal: For Magical, 

In [None]:
magical_envs = ['MatchRegions-Demo-v0', 'MoveToRegion-Demo-v0', 'MoveToCorner-Demo-v0']
columns_to_not_avg = ['Average on all envs', 'MatchRegions-TestColour', 
                      'MoveToRegion-TestColour', 'MoveToCorner-TestColour']


def create_non_colour_average_column(base_df_dict):
    """Add a 'Manual Average' column that excludes the TestColour variants.

    For MAGICAL tasks (those in ``magical_envs``) this returns deep copies of the
    ``mean``/``std``/``n`` tables, each with an extra ``'Manual Average'`` column:

    - ``mean`` / ``std``: the row-wise mean over all columns not listed in
      ``columns_to_not_avg`` (for ``std`` this is a plain average of stds).
    - ``n``: a copy of the ``'Average on all envs'`` column.

    Any other task's entry is passed through by reference, not copied.
    """
    def _with_manual_average(frame):
        # Average only the kept columns of the original frame, then attach the
        # result to a deep copy so the caller's frame is not modified.
        kept = [column for column in frame.columns if column not in columns_to_not_avg]
        augmented = frame.copy(deep=True)
        augmented['Manual Average'] = frame[kept].mean(axis=1)
        return augmented

    result = {}
    for task, stat_frames in base_df_dict.items():
        if task not in magical_envs:
            result[task] = stat_frames
            continue
        mean_frame = _with_manual_average(stat_frames['mean'])
        std_frame = _with_manual_average(stat_frames['std'])
        n_frame = stat_frames['n'].copy(deep=True)
        n_frame['Manual Average'] = n_frame['Average on all envs']
        result[task] = {'mean': mean_frame, 'std': std_frame, 'n': n_frame}
    return result
            
        
    

In [None]:
control_idx = 'Control'

# Envs that the Hail Mary plotting cells iterate over.
envs = ['cheetah-run', 'finger-spin', 'MatchRegions-Demo-v0', 'MoveToRegion-Demo-v0']

# Terms handed to create_pivoted_dfs as its blacklist (presumably runs whose
# identifiers contain one of these are dropped — confirm in create_pivoted_dfs).
hail_mary_blacklist_terms = ['froco', 'ablation', 'newbcaugs']

# Run-identifier pattern (keys are '^'-anchored, so presumably regexes) ->
# display name of the algorithm.
hail_mary_algo_lookups = {
    "^icml_inv_dyn_": "Inverse Dynamics ",
    "^icml_ac_tcpc_": "Action Conditioned TCPC ",
    "^icml_vae_": "VAE ",
    "^icml_dynamics_": "Dynamics Model ",
    "^control_ortho_init_": "Control",
    # Deliberately labelled as TCPC; the original note says this is "due to
    # error" (presumably the TCPC runs were launched under the identity_cpc
    # prefix — TODO confirm).
    "^icml_identity_cpc_": "Temporal CPC (TCPC)",
}

# Data-config identifier -> display name of the pretraining data source.
data_lookups = {
    "cfg_data_repl_random": "Random Rollouts",
    "cfg_data_repl_demos_random": "Demos & \nRandom Rollouts",
    "cfg_data_repl_demos_magical_mt": "Multitask Demos",
    "cfg_data_repl_rand_demos_magical_mt": "Multitask Demos & \nRandom Rollouts",
}

hail_mary_whitelist_lookups = dict(data_lookups=data_lookups,
                                   algo_lookups=hail_mary_algo_lookups)


## ***CHANGE THESE NAMES FOR NEW PLOTS ***
hail_mary_configs = {
    'control_idx': control_idx,
    'blacklist_terms': hail_mary_blacklist_terms,
    'whitelist_lookups': hail_mary_whitelist_lookups,
}

# Every env maps to the very same config object.
hail_mary_plot_config_lookup = dict.fromkeys(
    ('MoveToRegion-Demo-v0', 'MatchRegions-Demo-v0', 'cheetah-run', 'finger-spin'),
    hail_mary_configs,
)

### Hail Mary Return

In [None]:
## Configs for Hail Mary return plots

# Column of each env's return table to plot. MAGICAL envs use the
# 'Manual Average' column added by create_non_colour_average_column.
new_ret_col_lookup = {
    'MatchRegions-Demo-v0': 'Manual Average',
    'MoveToRegion-Demo-v0': 'Manual Average',
    'MoveToCorner-Demo-v0': 'Manual Average',
    'finger-spin': 'return_mean',
    'cheetah-run': 'return_mean'
}

# Per-env value range, passed to task_return_heatmap as min_max_vals.
range_lookup = {
    'MatchRegions-Demo-v0': (-.1, .1),
    'MoveToRegion-Demo-v0': (-.1, .1),
    'MoveToCorner-Demo-v0': (-.1, .1),
    'finger-spin': (-100, 100),
    'cheetah-run': (-100, 100)
}

# Set plot properties according to where we want to position it in paper
right_envs = ['MatchRegions-Demo-v0', 'cheetah-run']
left_envs = ['MoveToRegion-Demo-v0', 'finger-spin']

algo_order = ["Control", "Temporal CPC (TCPC)", "Action Conditioned TCPC ", "VAE ", "Dynamics Model ",
              "Inverse Dynamics "]

# CHANGE ME IF SWITCHING BETWEEN AUC AND RETURN
# This does not depend on `env`, so build it once rather than once per env.
manually_averaged_dfs = create_non_colour_average_column(raw_return_dataframes)

for env in envs:
    print(f"Creating plot for {env}")
    # CHANGE ME FOR EACH PLOT
    plot_config = hail_mary_plot_config_lookup[env]
    pivoted_dfs = create_pivoted_dfs(manually_averaged_dfs, new_ret_col_lookup,
                                     plot_config['blacklist_terms'], plot_config['whitelist_lookups'],
                                     plot_config['control_idx'], algo_order, verbose=False)
    pooled_dfs = create_pooled_dfs(pivoted_dfs, control_idx=plot_config['control_idx'])
    kwargs = {}
    if env in left_envs:
        kwargs = {'show_ylabel': False}
    min_max_vals = range_lookup[env]
    # Use the control key the dataframes were pivoted/pooled with.
    env_plot = task_return_heatmap(env, pooled_dfs, control_key=plot_config['control_idx'],
                                   min_max_vals=min_max_vals, **kwargs)
    display(env_plot)
    # CHANGE ME TO SAVE PROPERLY
    save_eps(env_plot, env, 'hail_mary_return')

# to be paranoid: drop the loop's final dataframes so a later cell cannot
# silently reuse them
del pivoted_dfs
del pooled_dfs

### Hail Mary AUC

In [None]:
## Configs for Hail Mary AUC plots

# Column of each env's AUC table to plot ('step N', chosen per benchmark).
new_ret_col_lookup = {
    'MatchRegions-Demo-v0': 'step 40',
    'MoveToRegion-Demo-v0': 'step 40',
    'MoveToCorner-Demo-v0': 'step 40',
    'finger-spin': 'step 400',
    'cheetah-run': 'step 400'
}


# Set plot properties according to where we want to position it in paper
right_envs = ['MatchRegions-Demo-v0', 'cheetah-run']
left_envs = ['MoveToRegion-Demo-v0', 'finger-spin']

algo_order = ["Control", "Temporal CPC (TCPC)", "Action Conditioned TCPC ", "VAE ", "Dynamics Model ",
              "Inverse Dynamics "]

for env in envs:
    print(f"Creating plot for {env}")
    # CHANGE ME FOR EACH PLOT
    plot_config = hail_mary_plot_config_lookup[env]
    # CHANGE ME IF SWITCHING BETWEEN AUC AND RETURN
    pivoted_dfs = create_pivoted_dfs(raw_auc_dataframes, new_ret_col_lookup,
                                     plot_config['blacklist_terms'], plot_config['whitelist_lookups'],
                                     plot_config['control_idx'], algo_order, verbose=False)
    pooled_dfs = create_pooled_dfs(pivoted_dfs, control_idx=plot_config['control_idx'])
    kwargs = {}
    if env in left_envs:
        kwargs = {'show_ylabel': False}
    # Use the control key the dataframes were pivoted/pooled with.
    env_plot = task_auc_heatmap(env, pooled_dfs, control_key=plot_config['control_idx'], **kwargs)
    display(env_plot)
    # CHANGE ME TO SAVE PROPERLY
    save_eps(env_plot, env, 'hail_mary_auc')

# to be paranoid: drop the loop's final dataframes so a later cell cannot
# silently reuse them
del pivoted_dfs
del pooled_dfs

# Older Code, Possibly Deprecated

In [None]:
# import csv
# def prepare_files(index, mode, exp_index, out_dir):
#     """
#     Create a folder named `out_dir`. This really just copies over files from il_train or il_test, as appropriate.
#     For instance, if il_train looks like this:
    
#     il_train
#     │   ├── 1
#     │   │   ├── ...
#     │   │   ├── config.json
#     │   │   └── progress.csv
#     │   └── _sources
#     …
    
#     Then the output will look like this:
#     ├── progress
#     │   └── 1
#     │       ├── params.json   (same as config.json)
#     │       └── progress.csv
#     …

#     After you run this, you can execute viskit with: python viskit/frontend.py path/to/out_dir/
#     """
#     experiments = index.search(mode=mode)
#     # compute merged configs (nested/hierarchical dicts), and
#     # also throw out experiments with no progress.csv
#     hierarchical_dicts = []
#     new_experiments = []
#     for experiment in experiments:
#         if not experiment.progress_path:
#             print("Skipping experiment", experiment.ident, "because it has no progress.csv")
#             continue
#         merged_config = experiment.get_merged_config(exp_index)
#         hierarchical_dicts.append(merged_config)
#         new_experiments.append(experiment)
#     experiments = new_experiments

#     # first flatten all dicts
#     dicts = [dict(flatten_dict(d)) for d in hierarchical_dicts]
    
#     # make sure that every dict has every key
#     all_keys = set()
#     for d in dicts:
#         all_keys |= d.keys()
#     for d in dicts:
#         for new_key in all_keys - d.keys():
#             d[new_key] = None
    
#     # now generate outputs for experiments
#     for flat_config, experiment in zip(dicts, experiments):
#         exp_out_dir = os.path.join(out_dir, experiment.ident.replace('/', '-'))
#         os.makedirs(exp_out_dir, exist_ok=True)

#         params_json_path = os.path.join(exp_out_dir, 'params.json')
#         with open(params_json_path, 'w') as fp:
#             json.dump(flat_config, fp)

#         progress_out_path = os.path.join(exp_out_dir, 'progress.csv')
#         shutil.copyfile(experiment.progress_path, progress_out_path)
        
#         if mode == 'il_test':
#             eval_json_path = experiment.eval_json_path
#             with open(eval_json_path, 'r') as fp:
#                 eval_dict = json.load(fp)
                
#             result_keys, result_vals = [], []
#             for key, value in eval_dict.items():
# #                 print(eval_dict.keys())
#                 if key == 'return_mean' and value < 0.1:
#                     print(eval_dict['policy_path'])
#                 if isinstance(value, str) or isinstance(value, int) or isinstance(value, float):
#                     result_keys.append(key)
#                     result_vals.append(value)
            
#             with open(progress_out_path, mode='w') as result_file:
#                 result_writer = csv.writer(result_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
#                 result_writer.writerow(result_keys)
#                 result_writer.writerow(result_vals)

                
# prepare_files(subexp_index, 'repl', subexp_index, 'viskit-repl')
# prepare_files(subexp_index, 'il_train', subexp_index, 'viskit-il-train')
# prepare_files(subexp_index, 'il_test', subexp_index, 'viskit-il-test')

# Plot

In [None]:
# # import matplotlib
# # matplotlib.use("tkagg")
# import matplotlib.pyplot as plt
# import pandas as pd
# import seaborn as sns
# from pathlib import Path

# # include_ident_keywords = ['control', 'tcpc']
# # exclude_ident_keywords = ['mt', 'rand_only']

# include_ident_keywords = []
# exclude_ident_keywords = []

# def get_data(mode, data_type, include_ident_kw=None, exclude_ident_kw=None):
#     assert mode in ['repl', 'il_train', 'il_test']
#     assert data_type in ['loss', 'return']
#     expts = subexp_index.search(mode=mode)
#     all_configs = [subexp.get_merged_config(subexp_index) for subexp in expts]
#     base_config, flat_configs = simplify_config_dicts(all_configs)
#     flat_config_tups = [tuple(sorted(d.items())) for d in flat_configs]
#     subexp_by_benchmark = {}
#     for flat_cfg, subexp in zip(flat_config_tups, expts):
#         bench_key = tuple((k, v) for k, v in flat_cfg if k.startswith('env_') or k.startswith('venv_'))
#         subexp_by_benchmark.setdefault(bench_key, []).append((flat_cfg, subexp))
    
#     """
#         ret_dict has structure {'env_1': {'exp_ident_1': required_data, 'exp_ident_2': required_data, ...}, ...}
#         required_data can be either a list (i.e. loss over time) or a number (int or float, like return_mean)
#     """
#     ret_dict = {}
#     for idx, (bench_key, cfgs_subexps) in enumerate(subexp_by_benchmark.items(), start=1):
#         # cluster subexperiments by config
#         by_cfg = {}
#         for tup_cfg, subexp in cfgs_subexps:
#             tup_cfg = tuple(k for k in tup_cfg if k not in bench_key)
#             by_cfg.setdefault(tup_cfg, []).append(subexp)

#         task_name = bench_key[1][1]
            
#         ret_dict[task_name] = {}
#         for cfg, subexp in cfgs_subexps:
#             d = dict(cfg)
#             exp_ident = d['il_train/exp_ident']
        
#             if subexp.progress_path:
#                 try:
#                     df = pd.read_csv(subexp.progress_path)
#                 except:
#                     print(f'Read csv reported error for exp {exp_ident}, skipping...')
#                     continue
#                 full_length = 400 if d['env_cfg/benchmark_name'] == 'dm_control' else 40
#                 if len(df['loss']) != full_length:
#                     print(f'Experiment {exp_ident} only has len(loss) {len(df["loss"])}, skipping... ')
#                     continue
                
#                 ret_dict[task_name][exp_ident] = []
#                 if data_type == 'loss':
#                     ret_dict[task_name][exp_ident].append(df['loss'])
                    
#     print(ret_dict.keys())
#     return ret_dict


In [None]:
# def plot_curves(data_dict):
#     sns.set(rc={'figure.figsize':(7, 6)})
#     for task_key, exp_results in data_dict.items():
#         df = None
#         col_name = []
#         plt.figure()
#         for exp_ident, value, in exp_results.items():
#             col_name = [f"seed_{x}" for x in range(len(value))]
#             col_name += ['step', 'exp_ident']
#             value.append([s for s in range(1, len(value[0])+1)])
#             value.append([exp_ident for s in range(1, len(value[0])+1)])
#             value = np.array(value).transpose(1, 0)
#             sub_df = pd.DataFrame(data=value, columns=col_name)
#             df = pd.concat([df, sub_df])
#         df = pd.melt(df, id_vars=['step', 'exp_ident'])
#         df['step'] = pd.to_numeric(df['step'])
#         df['value'] = pd.to_numeric(df['value'])
#         print(df.dtypes)
        
#         ax = sns.lineplot(x='step', y='value', hue='exp_ident', data=df)
#         plt.setp(ax.get_legend().get_texts(), fontsize='12')
#         plt.legend(bbox_to_anchor=(1.01, 1),borderaxespad=0)
#         ax.set_title(bench_key)
    
# plot_curves(get_data('il_train', 
#                      'loss', 
#                      include_ident_kw=include_ident_keywords,
#                      exclude_ident_kw=exclude_ident_keywords))

In [None]:
# Interpret encoders

# Interpret encoders

Save the encoder interpretation videos. Each sub_exp might take one or two minutes to save.

In [None]:
# import glob
# import subprocess
# from pathlib import Path

# train_expts = subexp_index.search(mode='il_test')
# all_configs = [subexp.get_merged_config(subexp_index) for subexp in train_expts]
# base_config, flat_configs = simplify_config_dicts(all_configs)
# flat_config_tups = [tuple(sorted(d.items())) for d in flat_configs]
# subexp_by_benchmark = {}
# for flat_cfg, subexp in zip(flat_config_tups, train_expts):
#     bench_key = tuple((k, v) for k, v in flat_cfg if k.startswith('env_') or k.startswith('venv_'))
#     subexp_by_benchmark.setdefault(bench_key, []).append((flat_cfg, subexp))

# # Create a folder to save videos
# Path(f"./runs/{cluster_subpath.split('/')[1]}").mkdir(parents=True, exist_ok=True)
# interp_algo = 'saliency'
                    
# for idx, (bench_key, cfgs_subexps) in enumerate(subexp_by_benchmark.items(), start=1):
#     # cluster subexperiments by config
#     by_cfg = {}
#     for tup_cfg, subexp in cfgs_subexps:
#         tup_cfg = tuple(k for k in tup_cfg if k not in bench_key)
#         by_cfg.setdefault(tup_cfg, []).append(subexp)

#     for tup_cfg, subexp in by_cfg.items():
#         exp = subexp[0]
#         encoder_path = exp.config['encoder_path']
#         if encoder_path:
#             for prefix, replacement in path_translations.items():
#                 if encoder_path.startswith(prefix):
#                     encoder_path = replacement + encoder_path[len(prefix):]
#             command = "python ../src/il_representations/scripts/interpret.py with "
#             command += f"log_dir=runs/{cluster_subpath.split('/')[1]} "
#             command += f"env_cfg.benchmark_name={exp.config['env_cfg']['benchmark_name']} "
#             command += f"env_cfg.task_name={exp.config['env_cfg']['task_name']} "
#             command += f"save_video=True "
#             command += f"chosen_algo={interp_algo} "
#             command += f"encoder_path={encoder_path} "
#             command += f"filename={exp.config['env_cfg']['task_name']}_{exp.config['exp_ident']} "

#             print(f"Generating videos for exp {exp.config['exp_ident']} on {exp.config['env_cfg']['task_name']}...")
#             process = subprocess.Popen(command.split(), stdout=subprocess.PIPE)
#             output, error = process.communicate()