In [6]:
import json
import os
import numpy as np

In [7]:
def import_metrics_json(models_path, run_num):
    """
    Looks in {models_path}/{run_num}/metrics.json and returns the contents as a
    Python dictionary. Returns None if the path does not exist, or if it exists
    but is not a regular file (e.g. a directory that happens to have that name).
    Arguments:
        `models_path`: directory that contains one subdirectory per run
        `run_num`: name of the run's subdirectory; it is passed through `str`,
            so an integer and its string form refer to the same run
    Raises `json.JSONDecodeError` if the file exists but is not valid JSON.
    """
    path = os.path.join(models_path, str(run_num), "metrics.json")
    # Use isfile rather than exists: a directory at this path should count as
    # "no metrics" instead of making open() fail
    if not os.path.isfile(path):
        return None
    # JSON text is UTF-8; be explicit so that reading does not depend on the
    # locale's default encoding of the machine running the notebook
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

In [12]:
def _is_nan(value):
    """
    Returns True if `value` is a NaN scalar (NaN is the only scalar that is not
    equal to itself). Values for which that self-comparison cannot be reduced
    to a single boolean are reported as not-NaN.
    """
    try:
        return bool(value != value)
    except (TypeError, ValueError):
        return False


def get_best_metric(models_path, metric_extract_func, metric_compare_func):
    """
    Given the path to a set of runs, determines the run with the best metric value,
    where the metric value is fetched by `metric_extract_func`. This function must
    take the imported metrics JSON and return the (scalar) value to use for
    comparison. The best metric value is determined by `metric_compare_func`, which
    must take in two arguments, and return whether or not the _first_ one is better.
    Runs are visited in sorted order of their directory names, so the result does
    not depend on the order in which the filesystem lists them; on ties, the run
    that is visited first is kept (as long as `metric_compare_func` is a strict
    comparison). Runs for which the metric cannot be extracted are skipped with a
    warning. Runs whose value is NaN are kept in the dict of all values, but are
    never chosen as the best run.
    Returns the number of the run, the value associated with that run, and a dict of
    all the values used for comparison. The run number and value are both None if
    no run had a usable value.
    """
    # Get the metrics, ignoring empty or nonexistent metrics.json files; sort the
    # directory listing because os.listdir returns entries in arbitrary order
    metrics = {}
    for run_num in sorted(os.listdir(models_path)):
        run_metrics = import_metrics_json(models_path, run_num)
        if run_metrics:  # Remove empties
            metrics[run_num] = run_metrics

    # Get the best value
    best_run, best_val, all_vals = None, None, {}
    for run_num, run_metrics in metrics.items():
        try:
            val = metric_extract_func(run_metrics)
        except Exception:
            # Best effort: a run with missing/malformed metrics should not stop
            # the comparison of the other runs
            print("Warning: Was not able to extract metric for run %s" % run_num)
            continue
        all_vals[run_num] = val
        if _is_nan(val):
            # Comparisons against NaN are always False, so a NaN that became the
            # current best could never be displaced by a real value
            continue
        if best_val is None or metric_compare_func(val, best_val):
            best_val, best_run = val, run_num
    return best_run, best_val, all_vals

In [16]:
# Rank the runs of this model directory by their lowest summit profile NLL
# (smaller is better), then report the winner followed by every run's value
models_path = "/users/amtseng/att_priors/models/trained_models/profile_models/SPI1_prior_overfit/"
best_run, best_val, all_vals = get_best_metric(
    models_path,
    metric_extract_func=lambda metrics: np.min(metrics["summit_prof_nll"]["values"]),
    metric_compare_func=lambda x, y: x < y,
)
print("Best run: %s" % best_run)
print("Associated value: %s" % best_val)
# Run names are numeric strings; order them numerically rather than lexically
for key in sorted(all_vals, key=int):
    print(key, all_vals[key])

Best run: 15
Associated value: 104.38627925135756
1 105.23240835466534
2 108.44316901735286
3 151.87730817694498
4 151.86776815111844
5 151.80312478272788
6 104.48911513265023
7 151.77995798217728
8 152.06106237013918
9 104.7685573637444
10 151.82967111766132
11 107.64243772520153
12 104.65566430495839
13 151.92846948490546
14 152.0134945726129
15 104.38627925135756
16 151.8923343126129
17 151.8077179371474
18 104.67592518949661
19 104.88231685955147
20 104.92297000566589
21 104.71059927507306
22 151.8921340204271
23 151.7692143255662
24 106.04085368238636
25 151.8235596838277


In [19]:
# Load the metrics of run 1 for manual inspection; `m` is None if that run has
# no metrics.json
m = import_metrics_json("/users/amtseng/att_priors/models/trained_models/profile_models/SPI1_prior_overfit/", 1)

In [20]:
# List the names of the metrics that were recorded for this run
m.keys()

dict_keys(['summit_count_mse', 'summit_count_pearson', 'summit_count_spearman', 'summit_prof_auprc_bin1', 'summit_prof_auprc_bin10', 'summit_prof_auprc_bin4', 'summit_prof_jsd', 'summit_prof_mse_bin1', 'summit_prof_mse_bin10', 'summit_prof_mse_bin4', 'summit_prof_nll', 'summit_prof_pearson_bin1', 'summit_prof_pearson_bin10', 'summit_prof_pearson_bin4', 'summit_prof_spearman_bin1', 'summit_prof_spearman_bin10', 'summit_prof_spearman_bin4', 'train_att_losses', 'train_batch_losses', 'train_corr_losses', 'train_epoch_loss', 'val_att_losses', 'val_batch_losses', 'val_corr_losses', 'val_epoch_loss'])

In [23]:
# Show the full `val_batch_losses` entry: a dict with parallel 'steps',
# 'timestamps', and 'values' lists
m["val_batch_losses"]

{'steps': [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19],
 'timestamps': ['2019-11-21T01:25:47.719453',
  '2019-11-21T01:50:57.239655',
  '2019-11-21T02:16:10.823540',
  '2019-11-21T02:41:27.474379',
  '2019-11-21T03:06:15.592786',
  '2019-11-21T03:25:23.139656',
  '2019-11-21T03:46:01.214803',
  '2019-11-21T04:07:15.245930',
  '2019-11-21T04:27:45.186197',
  '2019-11-21T04:52:54.816325',
  '2019-11-21T05:19:14.704548',
  '2019-11-21T05:40:02.020635',
  '2019-11-21T06:02:50.087246',
  '2019-11-21T06:28:26.630961',
  '2019-11-21T06:54:12.689903',
  '2019-11-21T07:20:11.081298',
  '2019-11-21T07:45:57.931090',
  '2019-11-21T08:12:02.843759',
  '2019-11-21T08:38:07.432688',
  '2019-11-21T09:04:27.523268'],
 'values': [[93.19227600097656,
   86.19403839111328,
   92.30490112304688,
   84.74042510986328,
   89.09931182861328,
   80.48370361328125,
   83.98051452636719,
   95.2027359008789,
   101.05570983886719,
   86.29988861083