In [None]:
import copy
import glob
import json
import os
import shutil

##### These should be the only things you need to modify in this code block #####
# runs_directory: root folder that contains the `chain_runs` folder to analyse.
# exp_name: name of the experiment sub-folder inside `<runs_directory>/chain_runs`.
runs_directory = "/scratch/sam/il-representations-gcp-volume/cluster-data/cluster-2020-09-28-hypot-1-for-real-hopefully-take-2/"
exp_name = "2"
exp_dir = os.path.join(runs_directory, 'chain_runs', exp_name)
#################################################################################

# Fail early if the selected experiment folder does not exist.
assert os.path.isdir(exp_dir)

# Collect Trial information

Some config keys are not useful for understanding the experiments, so they are filtered out when the configs are printed. 
The lists below contain the config keys that I (Cynthia) think are not useful. You can add or delete keys, or just keep these as the default values.

In [None]:
##### These should be the only things you need to modify in this code block #####
# Each list names config entries to hide when a config is printed; it is passed as
# `key_to_remove` to Trial.get_config. An entry of the form "parent/child" refers to
# the key `child` inside the nested dict stored under `parent`.

# Keys hidden from the repl config.
repl_extra_config_key = ["demo_timesteps", "device", "n_envs", "ppo_finetune", "ppo_timesteps", 
                         "seed", "torch_num_threads", "unit_test_max_train_steps", "use_random_rollouts",
                         "benchmark", "algo_params/device", "algo_params/loss_calculator_kwargs", 
                         "algo_params/optimizer", "algo_params/preprocess_extra_context", 
                         "algo_params/preprocess_target", "algo_params/save_interval", "algo_params/scheduler",
                         "algo_params/scheduler_kwargs", "algo_params/shuffle_batches", 
                         "algo_params/target_pair_constructor_kwargs", "algo_params/unit_test_max_train_steps"]
# Keys hidden from the il_train config.
il_train_extra_config_key = ["benchmark", "device_name", "final_pol_name", "gail", "seed", "torch_num_threads", 
                             "encoder_kwargs"]
# Keys hidden from the il_test config.
il_test_extra_config_key = ["benchmark", "device_name", "run_id", "seed", "torch_num_threads", "write_video",
                            "video_file_name", "policy_path"]
#################################################################################

Within each run of the experiment, say `chain_runs/1/`, there can be many subfolders, each representing one run of a given config. I call one such end-to-end run a "trial". A complete trial is linked to dirs like `repl/1`, `il_train/2`, and `il_test/2`, where `il_test/2`'s trained policy comes from `il_train/2`, and `il_train/2`'s trained encoder comes from `repl/1`.

In [None]:
# Define Trial object
class Trial:
    """
    One end-to-end trial, reconstructed from the directory of its last stage.

    A trial can span up to three run directories that live next to each other
    under the same experiment folder: `repl/<id>`, `il_train/<id>` and
    `il_test/<id>`. Starting from `last_run_dir`, the earlier stages are found
    by following `policy_path` (il_test config) and `encoder_path` (il_train
    config).

    Attributes:
        il_test_dir, il_train_dir, repl_dir: directory of each stage, or None
            when the trial does not have that stage (yet).
        return_mean: `return_mean` read from `<il_test_dir>/eval.json`; stays
            at -999 when no such value could be loaded.
    """

    def __init__(self, last_run_dir):
        self.il_test_dir = None
        self.il_train_dir = None
        self.repl_dir = None
        self.set_dirs(last_run_dir)
        self.return_mean = -999
        self.get_return()

    @staticmethod
    def _run_type(run_dir):
        """Return the name of the folder that directly contains `run_dir`."""
        return os.path.basename(os.path.dirname(run_dir.rstrip('/')))

    @staticmethod
    def _sibling_run_dir(run_dir, run_type, run_id):
        """Return `<experiment folder>/<run_type>/<run_id>` for the experiment owning `run_dir`."""
        exp_root = os.path.dirname(os.path.dirname(run_dir.rstrip('/')))
        return os.path.join(exp_root, run_type, run_id)

    @staticmethod
    def _remove_key(config, key):
        """Drop `key` ("a" or a "/"-separated nested path) from `config` in place, if present."""
        *parents, leaf = key.split('/')
        node = config
        for name in parents:
            node = node.get(name) if isinstance(node, dict) else None
            if node is None:
                return
        if isinstance(node, dict):
            node.pop(leaf, None)

    def set_dirs(self, last_run_dir):
        """
        Fill in il_test_dir / il_train_dir / repl_dir starting from `last_run_dir`.

        The stage of a directory is decided by the name of its parent folder
        (exactly 'il_test', 'il_train' or 'repl'), not by a substring search
        over the whole path, so unrelated ancestor folder names cannot cause a
        run to be misclassified.
        """
        run_dir = last_run_dir
        if self._run_type(run_dir) == 'il_test':
            self.il_test_dir = run_dir
            test_config = self.get_config('il_test')
            # policy_path looks like .../il_train/<id>/<policy file>
            train_id = test_config['policy_path'].split('/')[-2]
            run_dir = self._sibling_run_dir(run_dir, 'il_train', train_id)
        if self._run_type(run_dir) == 'il_train':
            self.il_train_dir = run_dir
            train_config = self.get_config('il_train')
            encoder_path = train_config.get('encoder_path')
            if encoder_path is None:  # This trial doesn't take repl training
                return
            # The repl run id is the fourth path component from the end of encoder_path.
            repl_id = encoder_path.split('/')[-4]
            run_dir = self._sibling_run_dir(run_dir, 'repl', repl_id)
        if self._run_type(run_dir) == 'repl':
            self.repl_dir = run_dir

    def get_config(self, mode, key_to_remove=()):
        """
        Load `config.json` of the stage given by `mode`.

        Args:
            mode: one of 'repl', 'il_train', 'il_test'.
            key_to_remove: config keys to hide. "parent/child" removes `child`
                from the dict stored under `parent`. Keys that are not present
                in the config are ignored.

        Returns:
            The config dict without the removed keys, or None when the trial
            has no directory for `mode`.
        """
        assert mode in ['repl', 'il_train', 'il_test']
        run_dir = getattr(self, mode + '_dir')
        if run_dir is None:
            return None
        with open(os.path.join(run_dir, 'config.json')) as json_file:
            config = json.load(json_file)
        # `config` was freshly loaded, so it can be pruned in place.
        for key in key_to_remove:
            self._remove_key(config, key)
        return config

    def get_return(self):
        """Read `return_mean` from `<il_test_dir>/eval.json`; warn and keep the default if unavailable."""
        if not self.il_test_dir:
            return
        result_file = os.path.join(self.il_test_dir, 'eval.json')
        if not os.path.isfile(result_file):
            print(f'WARNING - {result_file} does not exist.')
            return
        with open(result_file) as json_file:
            result = json.load(json_file)
        if 'return_mean' in result:
            self.return_mean = result['return_mean']
        else:
            print(f'WARNING - {result_file} has no return_mean entry.')

    def __str__(self):
        return '\n'.join([f'{key}: {value}' for key, value in self.__dict__.items()])

In [None]:
# Build Trial objects for the run folders of one stage directory
def get_trial_objects(root_dir, exp_type, trial_dirs):
    """
    Build one Trial for every name in trial_dirs, where each name is a
    subfolder of root_dir/exp_type (exp_type typically being repl,
    il_train, or il_test). The folders '_sources' and 'progress' are not
    trials and are left out.
    """
    non_trial_dirs = ('_sources', 'progress')
    return [
        Trial(os.path.join(root_dir, exp_type, name))
        for name in trial_dirs
        if name not in non_trial_dirs
    ]

In [None]:
# Print a dictionary (with at most one level of nesting expanded) using tab indentation
def pretty_print(d, indent=0):
    """
    Print every key of `d` on its own line, indented by `indent` tabs.

    A dict value is expanded into `sub_key: sub_value` lines one tab deeper;
    any other value is printed as-is one tab deeper. Dicts nested further down
    are printed with their plain `str()` form.
    """
    key_prefix = '\t' * indent
    value_prefix = '\t' * (indent + 1)
    for name, entry in d.items():
        print(key_prefix + str(name))
        if not isinstance(entry, dict):
            print(value_prefix + str(entry))
            continue
        for sub_name, sub_entry in entry.items():
            print(value_prefix + str(sub_name) + ': ' + str(sub_entry))

In [None]:
# Identify trial type and get trial objects accordingly
#   full_exp:      trials that reached il_test
#   il_train_only: il_train runs that no il_test run refers to
#   repl_only:     repl runs that no collected il_train run refers to
trials = {
    'full_exp': [],
    'il_train_only': [],
    'repl_only': []
}

il_test_dirs = os.listdir(os.path.join(exp_dir, 'il_test'))
trials['full_exp'] = get_trial_objects(exp_dir, 'il_test', il_test_dirs)

il_train_dirs = os.listdir(os.path.join(exp_dir, 'il_train'))
recorded_train_dirs = [t.il_train_dir.split('/')[-1] for t in trials['full_exp']]
il_train_only_dirs = [d for d in il_train_dirs if d not in recorded_train_dirs]
trials['il_train_only'] = get_trial_objects(exp_dir, 'il_train', il_train_only_dirs)

if os.path.isdir(os.path.join(exp_dir, 'repl')):
    repl_dirs = os.listdir(os.path.join(exp_dir, 'repl'))
    # Trials trained without a repl encoder have repl_dir set to None; skip them
    # here instead of failing on None.split(...).
    recorded_repl_dirs = [t.repl_dir.split('/')[-1]
                          for t in trials['full_exp'] + trials['il_train_only']
                          if t.repl_dir]
    repl_only_dirs = [d for d in repl_dirs if d not in recorded_repl_dirs]
    trials['repl_only'] = get_trial_objects(exp_dir, 'repl', repl_only_dirs)

print('\nExperiment info: \n')
print('\n'.join([f'{trial_type}: {len(trial_list)} runs' for trial_type, trial_list in trials.items()]))

# Print information of trials with the highest test score

Note that a final return_mean of -999 means that no result could be loaded for the trial, typically because its `il_test/eval.json` file could not be found.

Top-`n` trials will be printed if `n` is smaller than the number of trials you have. Otherwise it will just print out all the trials' scores.

In [None]:
##### These should be the only things you need to modify in this code block #####
sorted_trials = sorted(trials['full_exp'], key=lambda trial: trial.return_mean, reverse=True)
n = 10
#################################################################################

# Keep the n best trials, or all of them when fewer than n are available.
total = len(sorted_trials)
selected_trials = sorted_trials if n >= total else sorted_trials[:n]

for count, trial in enumerate(selected_trials):
    banner = '=' * 40
    print(f"{banner} Trial {count}/{total} {banner}")
    print(f"Final return_mean: {trial.return_mean}")

    # Stages to report, in pipeline order; repl is only listed when the trial has one.
    sections = [
        ('il_train', il_train_extra_config_key),
        ('il_test', il_test_extra_config_key),
    ]
    if trial.repl_dir:
        sections.insert(0, ('repl', repl_extra_config_key))

    for mode, hidden_keys in sections:
        print(f"{mode}_config: ")
        pretty_print(trial.get_config(mode, hidden_keys))

# Prepare files for analyzing with Viskit

If you want to see a trial object's information, you can run `print(trial)`.

In [None]:
def prepare_files(mode, dest_dir=None, src_dir=None):
    """
    Create a folder named "progress" under your specified dir. The structure will be like:
    il_train
        └── progress
            └── 1
                ├── params.json   (same as config.json)
                └── progress.csv


    then the Viskit analysis can be called by: python viskit/frontend.py {dest_dir}/progress/

    Args:
        mode: which stage to export, 'il_train' or 'repl'.
        dest_dir: folder in which "progress" is created. Defaults to the
            hard-coded location below, built from exp_name and mode.
        src_dir: folder holding one sub-folder per trial. Defaults to
            exp_dir/mode.

    Entries of src_dir that are not directories, and the '_sources' and
    'progress' folders, are skipped. A warning is printed for every trial file
    that is missing; the remaining files are still copied.
    """
    assert mode in ['il_train', 'repl']

    ##### This should be the only things you need to modify in this code block #####
    if dest_dir is None:
        dest_dir = f"/home/cynthiachen/il-representations/runs/chain_runs/{exp_name}/{mode}"
    ################################################################################
    if src_dir is None:
        src_dir = os.path.join(exp_dir, mode)

    # (file name in the trial folder, file name expected in the progress folder)
    files_to_copy = (('config.json', 'params.json'), ('progress.csv', 'progress.csv'))

    progress_root = os.path.join(dest_dir, 'progress')
    for trial_dir in sorted(os.listdir(src_dir)):
        trial_src = os.path.join(src_dir, trial_dir)
        if trial_dir in ('_sources', 'progress') or not os.path.isdir(trial_src):
            continue

        trial_dest = os.path.join(progress_root, trial_dir)
        os.makedirs(trial_dest, exist_ok=True)
        for src_name, dest_name in files_to_copy:
            src_file = os.path.join(trial_src, src_name)
            if os.path.isfile(src_file):
                # Copy with the standard library rather than a shell command, so
                # paths with spaces or shell metacharacters are handled safely.
                shutil.copy(src_file, os.path.join(trial_dest, dest_name))
            else:
                print(f'WARNING - {src_file} does not exist.')

# Export the il_train runs of the selected experiment into the "progress" layout described above.
prepare_files('il_train')

# Filter for runs whose performance exceeds vanilla IL baseline

In [None]:
##### These should be the only things you need to modify in this code block #####
# glob_query is a glob pattern (shell-style wildcards, not a regular expression),
# resolved relative to runs_directory.
glob_query = "chain_runs/[0-9]*/il_test/[0-9]*/run.json"  # Make this a glob pattern that will capture all the runs
# The first match is taken as the baseline; this raises IndexError when nothing matches.
# NOTE(review): glob.glob does not promise a sorted order, so the file picked by [0] may
# not be the intended baseline - prefer setting baseline_filename explicitly as below.
baseline_filename = glob.glob(os.path.join(runs_directory, glob_query))[0]  # Make this the baseline
# baseline_filename = "/scratch/sam/il-representations-gcp-volume/cluster-data/cluster-2020-09-28-hypot-1-for-real-hopefully-take-2/chain_runs/7/il_test/5/run.json"

In [None]:
# Glob the test result files: every path under runs_directory that matches glob_query
glob_arg = os.path.join(runs_directory, glob_query)
evaluation_files = glob.glob(glob_arg)

In [None]:
# Helper that reads the mean return stored in a result .json file
def get_evaluation_result(filename):
    """
    Return the value stored under result -> return_mean -> value in `filename`.

    When that entry cannot be reached (a key is missing, or one of the levels
    is not subscriptable, e.g. `result` is null), negative infinity is returned
    instead.
    """
    with open(filename, 'r') as f:
        results = json.load(f)

    try:
        return_mean = results['result']['return_mean']
        return return_mean['value']
    except (TypeError, KeyError):
        return float("-inf")

In [None]:
# Mean return of every globbed test result file, in the same order as evaluation_files
evaluation_results = [get_evaluation_result(filename) for filename in evaluation_files]

In [None]:
# Keep the directories of the runs whose mean return is strictly above the IL baseline return
baseline_return = get_evaluation_result(baseline_filename)
runs_with_improvement = [
    os.path.dirname(run_file)
    for run_file, run_return in zip(evaluation_files, evaluation_results)
    if run_return > baseline_return
]

In [None]:
# Report, separated by blank lines: the candidate files, their returns, and the improved runs.
print(evaluation_files)  # These are all the runs that were being considered
print('\n')
print(evaluation_results)  # These are the 'return_mean' 'value's corresponding to those runs (-inf when the value is missing)
print('\n')
print(runs_with_improvement)  # These are the directories of the runs with performance greater than the baseline

# Interpret encoders

In [None]:
# Make sure your cwd is the il-representations directory
# If the current folder is named 'analysis', step up one level so that the
# `il_representations` imports below and the relative 'runs/...' paths resolve.
# NOTE(review): assumes the 'analysis' folder sits directly inside the repository root - confirm.
if os.getcwd().split('/')[-1] == 'analysis':
    os.chdir("..")
print('Check cwd', os.getcwd())

from il_representations.scripts.interpret import (prepare_network, process_data, save_img, saliency_, deep_lift_, 
integrated_gradient_, layer_conductance_, layer_gradcam_, layer_act_, choose_layer, interp_ex)
from il_representations.envs.config import benchmark_ingredient
import il_representations.envs.auto as auto_env

import sacred
import numpy
from sacred.observers import FileStorageObserver
from sacred import Experiment
from stable_baselines3.common.utils import get_device
from captum.attr import LayerActivation, LayerGradientXActivation

In [None]:
# NOTE(review): render_interp_ex is created here and created again at the top of the
# next cell, which replaces this instance.
render_interp_ex = Experiment('render_interp', ingredients=[benchmark_ingredient, interp_ex], interactive=True)
# Store interp_ex runs under runs/interpret_runs. Each execution of this cell appends
# one more observer to interp_ex.
interp_ex.observers.append(FileStorageObserver('runs/interpret_runs'))

##### These should be the only things you need to modify in this code block #####
# One policy file and one (filtered) repl config per trial that reached il_test.
# An entry of `configs` is None for a trial without a repl directory.
policy_paths = [os.path.join(t.il_train_dir, 'policy_final.pt') for t in trials['full_exp']]
configs = [t.get_config('repl', key_to_remove=repl_extra_config_key) for t in trials['full_exp']]

# Configuration for interp_ex: the files to load and the images to inspect.
# NOTE(review): Sacred reads the body of a config function to build the config, so the
# body is left exactly as it is - confirm against the Sacred docs before restructuring it.
@interp_ex.config
def config():
    encoder_paths = policy_paths
    for path in encoder_paths:
        assert os.path.isfile(path), f'Please double check if {path} exists.'
    
    # Data settings
    # The benchmark is set by detecting il_representations/envs/config's bench_defaults.benchmark_name
    imgs = [8]  # index of the image to be inspected (int)
    assert all(isinstance(im, int) for im in imgs), 'imgs list should contain integers only'

    verbose = False
    #################################################################################

In [None]:
# Experiment that combines the benchmark and interp_ex ingredients; its main function
# is `run` below.
render_interp_ex = Experiment('render_interp', ingredients=[benchmark_ingredient, interp_ex], interactive=True)

# Main function of render_interp_ex: load the vectorised environment, prepare the
# networks for it, load the data, and return all three as a tuple.
@render_interp_ex.main
def run():
    venv = auto_env.load_vec_env()
    networks = prepare_network(venv)
    images, labels = process_data()
    return networks, images, labels

# Run the experiment once and unpack the (networks, images, labels) tuple returned by run().
r = render_interp_ex.run()
networks = r.result[0]
images = r.result[1]
labels = r.result[2]
verbose = True
# log_dir is handed to the interpretation helpers in the cells below.
log_dir = None

In [None]:
def saliency():
    """
    Run `saliency_` for every (image, label) pair on every network.

    Before each network is processed, a separator and the repl config paired
    with that network are printed. Trials without a repl stage have `None` as
    their config; a placeholder line is printed for them instead of failing in
    `pretty_print`.
    """
    for img, label in zip(images, labels):
        # The image conversion depends on the image only, so do it once per image.
        # NOTE(review): assumes img is a (batch, channel, height, width) tensor, so
        # that img[0].permute(1, 2, 0) is channel-last - confirm against process_data.
        original_img = img[0].permute(1, 2, 0).detach().numpy()
        for config, network in zip(configs, networks):
            print('='*50)
            if config is None:
                print('(no repl config available for this trial)')
            else:
                pretty_print(config)
            saliency_(network, img, label, original_img, log_dir, False)

saliency()

In [None]:
def deep_lift():
    """
    Run `deep_lift_` for every (image, label) pair on every network.

    Networks are walked in step with `configs`, and a separator line is
    printed before each one.
    """
    separator = '=' * 50
    for image, target in zip(images, labels):
        for _config, net in zip(configs, networks):
            print(separator)
            background = image[0].permute(1, 2, 0).detach().numpy()
            deep_lift_(net, image, target, background, log_dir, False)

deep_lift()