In [1]:
import json
import os

In [2]:
def import_metrics_json(models_path, run_num):
    """
    Looks in {models_path}/{run_num}/metrics.json and returns the contents as a
    Python object (normally a dictionary). Returns None if the file does not
    exist or is empty.

    Args:
        models_path: directory containing one subdirectory per run.
        run_num: run identifier (int or str); converted with `str` to build
            the path.
    Returns:
        The parsed JSON contents, or None if metrics.json is missing, is not a
        regular file, or contains only whitespace.
    Raises:
        json.JSONDecodeError: if the file is non-empty but not valid JSON.
    """
    path = os.path.join(models_path, str(run_num), "metrics.json")
    # `isfile` rather than `exists`: a directory at this path would otherwise
    # pass the check and then fail in `open` with IsADirectoryError.
    if not os.path.isfile(path):
        return None
    with open(path, "r") as f:
        contents = f.read()
    # An empty file (e.g. left behind by an interrupted run) is not valid JSON;
    # treat it the same as a missing file instead of raising.
    if not contents.strip():
        return None
    return json.loads(contents)

In [14]:
def get_best_validation_auprc(models_path, metric_key="val_corr_auprc", task_index=0):
    """
    Given the path to a set of runs, determines the run and epoch with the
    best validation auPRC for a single task (the first task by default).

    Every entry of `models_path` is treated as a candidate run. Entries without
    a readable, non-empty metrics.json (as loaded by `import_metrics_json`), or
    whose metrics lack `metric_key`, are skipped.

    Args:
        models_path: directory containing one subdirectory per run.
        metric_key: key in metrics.json whose "values" entry is a list with one
            item per epoch, each item being a list of per-task auPRCs.
        task_index: which task's auPRC to compare within each epoch's list.
    Returns:
        A tuple (best_auprc, best_run, best_epoch), where best_run is the run
        directory name and best_epoch is 1-indexed. Returns (None, None, None)
        if no run has a usable value.
    """
    # os.listdir returns entries in arbitrary order; sort (numerically where
    # the name is an integer) so ties are always broken toward the lowest run.
    run_names = sorted(
        os.listdir(models_path),
        key=lambda name: (not name.isdigit(), int(name) if name.isdigit() else 0, name),
    )

    # Start at -inf so that any real value (including 0) can be selected.
    best_auprc, best_run, best_epoch = float("-inf"), None, None
    for run_num in run_names:
        metrics = import_metrics_json(models_path, run_num)
        # Skip missing/empty metrics files and runs that never logged this metric.
        if not metrics or metric_key not in metrics:
            continue
        for epoch_num, auprc_list in enumerate(metrics[metric_key]["values"], start=1):
            auprc = auprc_list[task_index]
            # Strict `>` keeps the earliest epoch/run on ties; NaN never wins.
            if auprc > best_auprc:
                best_auprc, best_run, best_epoch = auprc, run_num, epoch_num

    if best_run is None:
        print("No validation auPRC values found in %s" % models_path)
        return None, None, None
    print("Best auPRC: %f" % best_auprc)
    print("Epoch %d in run %s" % (best_epoch, best_run))
    return best_auprc, best_run, best_epoch

In [15]:
# Root directory containing one subdirectory per training run. The default is
# the original experiment location; set the MODELS_PATH environment variable
# to analyze a different set of runs without editing the notebook.
models_path = os.environ.get(
    "MODELS_PATH", "/users/amtseng/att_priors/models/trained_models/SPI1_DREAM/"
)
get_best_validation_auprc(models_path)

Best auPRC: 0.504651
Epoch 3 in run 23
