In [None]:
import os
import json

# Root of the cluster data snapshot. Set the IL_RUNS_DIRECTORY environment variable
# to analyse a different snapshot without editing the notebook.
runs_directory = os.environ.get(
    "IL_RUNS_DIRECTORY",
    "/scratch/sam/il-representations-gcp-volume/cluster-data/cluster-2020-09-27T03:04Z",
)
exp_name = "10"  # name of the experiment subdirectory under chain_runs/ to inspect
exp_dir = os.path.join(runs_directory, 'chain_runs', exp_name)
assert os.path.isdir(exp_dir), f"experiment directory not found: {exp_dir}"

# Collect trial information (repl → il_train → il_test run chains)

In [None]:
# Define Trial object
class Trial:
    """
    Directory bookkeeping for one chained repl -> il_train -> il_test trial.

    Starting from the latest stage run directory that exists (``last_run_dir``),
    the constructor walks back through each stage's ``config.json`` to locate
    the run directories of the earlier stages. A trial whose il_train / il_test
    stages have not started (yet) has ``il_train_dir`` / ``il_test_dir`` set to
    None; a stage whose config does not name its predecessor is also left None.

    Attributes:
        il_test_dir: il_test run directory, or None.
        il_train_dir: il_train run directory, or None.
        repl_dir: repl run directory, or None.
    """
    def __init__(self, last_run_dir):
        self.il_test_dir = None
        self.il_train_dir = None
        self.repl_dir = None
        self.set_dirs(last_run_dir)

    @staticmethod
    def _stage_of(run_dir):
        """Return the stage name of a run dir laid out as ``<exp_dir>/<stage>/<run_id>``."""
        return os.path.basename(os.path.dirname(run_dir))

    @staticmethod
    def _sibling_run(run_dir, stage, run_id):
        """Return ``<exp_dir>/<stage>/<run_id>`` for the experiment that ``run_dir`` belongs to."""
        exp_dir = os.path.dirname(os.path.dirname(run_dir))
        return os.path.join(exp_dir, stage, run_id)

    @staticmethod
    def _load_config(run_dir):
        """Load and return the parsed ``config.json`` of a run directory."""
        with open(os.path.join(run_dir, 'config.json')) as json_file:
            return json.load(json_file)

    def set_dirs(self, last_run_dir):
        """
        Fill in the stage directories by walking back from ``last_run_dir``.

        The stage of a run is read from its parent directory name rather than
        by searching the whole path for a substring, so ancestor directories
        whose names happen to contain 'repl', 'il_train' or 'il_test' do not
        cause a run to be misclassified.
        """
        run_dir = os.path.normpath(last_run_dir)
        if self._stage_of(run_dir) == 'il_test':
            self.il_test_dir = run_dir
            policy_path = self._load_config(run_dir).get('policy_path')
            # Assumes the policy file sits directly inside its il_train run
            # directory, so the second-to-last path component is that run's id.
            run_dir = (self._sibling_run(run_dir, 'il_train', policy_path.split('/')[-2])
                       if policy_path else None)
        if run_dir is not None and self._stage_of(run_dir) == 'il_train':
            self.il_train_dir = run_dir
            encoder_path = self._load_config(run_dir).get('encoder_path')
            # Assumes the encoder file sits three levels below its repl run
            # directory, so the fourth-from-last component is that run's id.
            # An il_train run with no encoder_path (presumably a baseline run)
            # has no repl stage to resolve.
            run_dir = (self._sibling_run(run_dir, 'repl', encoder_path.split('/')[-4])
                       if encoder_path else None)
        if run_dir is not None and self._stage_of(run_dir) == 'repl':
            self.repl_dir = run_dir

    def __str__(self):
        return '\n'.join([f'{key}: {value}' for key, value in self.__dict__.items()])

In [None]:
# Collect trial objects from specified dirs
def get_trial_objects(root_dir, exp_type, trial_dirs):
    """
    Build a Trial for each run directory under ``root_dir/exp_type``.

    Args:
        root_dir: experiment directory (e.g. ``exp_dir``).
        exp_type: stage subdirectory name, typically 'repl', 'il_train' or 'il_test'.
        trial_dirs: entry names inside ``root_dir/exp_type`` (e.g. from ``os.listdir``).

    Returns:
        A list of Trial objects in the order of ``trial_dirs``. The ``_sources``
        entry and any entry that is not a directory are skipped, since neither
        is a run directory.
    """
    trial_list = []
    for trial_dir in trial_dirs:
        dir_abspath = os.path.join(root_dir, exp_type, trial_dir)
        # '_sources' is not a run (presumably the experiment framework's source
        # snapshot dir); stray files such as logs are not runs either.
        if trial_dir == '_sources' or not os.path.isdir(dir_abspath):
            continue
        trial_list.append(Trial(dir_abspath))
    return trial_list

In [None]:
# Identify trial type and get trial objects accordingly.
# Each trial is classified by the furthest stage it reached:
#   full_exp      -- built from an il_test run
#   il_train_only -- an il_train run not referenced by any full_exp trial
#   repl_only     -- a repl run not referenced by any full_exp / il_train_only trial
def _list_stage_dir(stage):
    """Return sorted entries of exp_dir/<stage>, or [] if that stage directory does not exist yet."""
    stage_dir = os.path.join(exp_dir, stage)
    # os.listdir order is arbitrary; sort so trial ordering is reproducible
    return sorted(os.listdir(stage_dir)) if os.path.isdir(stage_dir) else []


def _recorded_run_ids(run_dirs):
    """Return the set of run ids (basenames) of the given run dirs, skipping unresolved (None) ones."""
    return {os.path.basename(run_dir) for run_dir in run_dirs if run_dir is not None}


trials = {
    'full_exp': [],
    'il_train_only': [],
    'repl_only': []
}

il_test_dirs = _list_stage_dir('il_test')
trials['full_exp'] = get_trial_objects(exp_dir, 'il_test', il_test_dirs)

il_train_dirs = _list_stage_dir('il_train')
recorded_train_dirs = _recorded_run_ids(t.il_train_dir for t in trials['full_exp'])
il_train_only_dirs = [d for d in il_train_dirs if d not in recorded_train_dirs]
trials['il_train_only'] = get_trial_objects(exp_dir, 'il_train', il_train_only_dirs)

repl_dirs = _list_stage_dir('repl')
recorded_repl_dirs = _recorded_run_ids(
    t.repl_dir for t in trials['full_exp'] + trials['il_train_only'])
repl_only_dirs = [d for d in repl_dirs if d not in recorded_repl_dirs]
trials['repl_only'] = get_trial_objects(exp_dir, 'repl', repl_only_dirs)

print('Experiment info: \n')
print('\n'.join([f'{trial_type}: {len(trial_list)} runs' for trial_type, trial_list in trials.items()]))

# Filter for runs whose performance exceeds the vanilla IL baseline

In [None]:
import glob

In [None]:
# Glob the il_test result files of every experiment under chain_runs/ (not only exp_name).
# glob returns matches in arbitrary, filesystem-dependent order; sort them so the
# ordering, and hence evaluation_files[0], is reproducible across runs.
glob_query = os.path.join("chain_runs", "[0-9]*", "il_test", "[0-9]*", "run.json")
glob_arg = os.path.join(runs_directory, glob_query)
evaluation_files = sorted(glob.glob(glob_arg))

In [None]:
# Helper function to parse mean return from .json file
def get_evaluation_result(filename):
    """
    Return the mean return recorded in an il_test ``run.json`` file.

    Args:
        filename: path to the run's ``run.json``.

    Returns:
        ``results['result']['return_mean']['value']``, or ``float('-inf')`` when the
        file has no result (the ``result`` field is null or absent).
    """
    with open(filename, 'r') as f:
        results = json.load(f)
    result = results.get('result')
    if result is None:
        # -inf compares below every real return, so a run without a result
        # never counts as beating a baseline.
        return float("-inf")
    return result['return_mean']['value']

In [None]:
# Read the mean return of every test result file, in the same order as evaluation_files
# (files without a result contribute -inf).
evaluation_results = [get_evaluation_result(result_file) for result_file in evaluation_files]
print(evaluation_results)

In [None]:
# Filter for test result files whose mean return strictly exceeds the IL baseline's.
# Runs without a result score -inf in get_evaluation_result, so they never qualify.
# NOTE: if the baseline itself has no result (-inf), every finished run qualifies.
assert evaluation_files, f'no il_test run.json files found under {runs_directory}'
baseline_filename = evaluation_files[0]  #TODO put the filename of the baseline IL run here
baseline_return = get_evaluation_result(baseline_filename)
runs_with_improvement = [
    os.path.dirname(run_file)  # the run directory that contains this run.json
    for run_file, run_return in zip(evaluation_files, evaluation_results)
    if run_return > baseline_return
]

In [None]:
# Show the evaluated files, their mean returns, and the runs that beat the
# baseline, each separated by two blank lines.
print(evaluation_files, evaluation_results, runs_with_improvement, sep='\n\n\n')