In [8]:
import json
import os

In [9]:
def import_metrics_json(models_path, run_num):
    """
    Looks in {models_path}/{run_num}/metrics.json and returns the parsed contents
    (expected to be a Python dictionary). Returns None if there is nothing to
    load: the path does not exist, it is not a regular file, or the file is empty
    (zero-length or whitespace only).
    Arguments:
        `models_path`: directory containing one subdirectory per run
        `run_num`: name of the run's subdirectory; it is converted with `str`,
            so both integers and strings are accepted
    Raises `json.JSONDecodeError` if the file has content that is not valid JSON.
    """
    path = os.path.join(models_path, str(run_num), "metrics.json")
    # `isfile` rather than `exists`: a directory that happens to be named
    # metrics.json should count as "no metrics", not fail inside `open`
    if not os.path.isfile(path):
        return None
    with open(path, "r") as f:
        contents = f.read()
    # An empty file is not valid JSON and would make the parser raise; treat it
    # the same as a missing file
    if not contents.strip():
        return None
    return json.loads(contents)

In [10]:
def get_best_metric(models_path, metric_extract_func, metric_compare_func):
    """
    Given the path to a set of runs, determines the run with the best metric value,
    where the metric value is fetched by `metric_extract_func`. This function must
    take the imported metrics JSON and return the (scalar) value to use for
    comparison. The best metric value is determined by `metric_compare_func`, which
    must take in two arguments, and return whether or not the _first_ one is better.
    Arguments:
        `models_path`: directory whose entries are the run subdirectories
        `metric_extract_func`: maps an imported metrics JSON to a value
        `metric_compare_func`: takes (candidate, current best) and returns a
            truthy value if the candidate is better
    Runs are visited in sorted order (numeric names by value, then any other
    names alphabetically), so the order of the returned values is reproducible
    and ties are won by the earliest run in that order. Runs whose metrics.json
    is missing, empty, unreadable, or lacks the requested metric are skipped.
    Returns the number of the run (as named on disk), the value associated with
    that run, and a list of all the values used for comparison. If no run yields
    a value, returns (None, None, []).
    """
    def run_sort_key(run_name):
        # `os.listdir` returns entries in arbitrary order; sort numerically where
        # possible so that "10" comes after "9" rather than after "1"
        if run_name.isdecimal():
            return (0, int(run_name), "")
        return (1, 0, run_name)

    best_run, best_val, all_vals = None, None, []
    for run_num in sorted(os.listdir(models_path), key=run_sort_key):
        # A single unreadable or malformed metrics.json should not abort the
        # whole search; JSON decoding errors are subclasses of ValueError
        try:
            run_metrics = import_metrics_json(models_path, run_num)
        except (OSError, ValueError):
            print("Warning: Was not able to read metrics.json for run %s" % run_num)
            continue
        if not run_metrics:
            # Nonexistent or empty metrics
            continue

        try:
            val = metric_extract_func(run_metrics)
        except Exception:
            print("Warning: Was not able to extract metric for run %s" % run_num)
            continue
        all_vals.append(val)
        if best_val is None or metric_compare_func(val, best_val):
            best_val, best_run = val, run_num
    return best_run, best_val, all_vals

In [11]:
models_path = "/users/amtseng/att_priors/models/trained_profile_models/SPI1/"


def extract_summit_prof_nll(metrics):
    """Returns entry [0][0] of the "values" stored under "summit_prof_nll"."""
    # First task, arbitrarily
    return metrics["summit_prof_nll"]["values"][0][0]


def lower_is_better(x, y):
    """Returns whether `x` is strictly smaller than `y`."""
    return x < y


# Pick the run with the smallest extracted value
best_run, best_val, all_vals = get_best_metric(
    models_path, extract_summit_prof_nll, lower_is_better
)
print("Best run: %s" % best_run)
print("Associated value: %s" % best_val)

Best run: 2
Associated value: 171.80329483709107
