In [1]:
import json
import os
import numpy as np

In [2]:
def import_metrics_json(models_path, run_num):
    """
    Looks in {models_path}/{run_num}/metrics.json and returns the parsed JSON
    contents (a Python dictionary for a JSON object).
    Returns None if the file does not exist, is empty (or whitespace only), or
    cannot be parsed as JSON, so that callers scanning many runs can skip it
    instead of aborting. A warning is printed in the unparsable case.
    Arguments:
        `models_path`: directory containing one subdirectory per run
        `run_num`: name/number of the run's subdirectory; converted with `str`
    """
    path = os.path.join(models_path, str(run_num), "metrics.json")
    # `isfile` rather than `exists`: a directory at this path cannot be opened
    if not os.path.isfile(path):
        return None
    with open(path, "r") as f:
        contents = f.read()
    if not contents.strip():
        # An empty file is not valid JSON; treat it the same as a missing file
        return None
    try:
        return json.loads(contents)
    except ValueError:  # json.JSONDecodeError is a subclass of ValueError
        print("Warning: Could not parse %s as JSON" % path)
        return None

In [3]:
def get_best_metric(models_path, metric_name, reduce_func, compare_func):
    """
    Scans every run found under `models_path` and picks the run whose value of
    `metric_name` is best.
    Arguments:
        `models_path`: directory whose entries are the individual runs
        `metric_name`: key of the metric to look up in each run's metrics
        `reduce_func`: takes the run's full array of values for the metric and
            returns the (scalar) value used for comparison
        `compare_func`: takes two such values and returns whether or not the
            _first_ one is better
    Runs for which the value cannot be computed or compared are reported with a
    warning and skipped.
    Returns the best run (its directory name), the value associated with that
    run, and a dict mapping each run to the value used for comparison.
    """
    # Load each run's metrics, keeping only runs that yielded non-empty metrics
    loaded = {}
    for run_dir in os.listdir(models_path):
        run_metrics = import_metrics_json(models_path, run_dir)
        if run_metrics:
            loaded[run_dir] = run_metrics

    # Track the winner while recording every run's comparison value
    best_run = None
    best_val = None
    all_vals = {}
    for run_dir, run_metrics in loaded.items():
        try:
            score = reduce_func(run_metrics[metric_name]["values"])
            all_vals[run_dir] = score
            is_better = best_val is None or compare_func(score, best_val)
            if is_better:
                best_run = run_dir
                best_val = score
        except Exception:
            print("Warning: Was not able to compute values for run %s" % run_dir)
    return best_run, best_val, all_vals

In [4]:
def get_best_metric_at_best_epoch(models_path, metric_name, reduce_func, compare_func):
    """
    Given the path to a set of runs, determines the run and epoch with the best
    metric value, for the given `metric_name`. Each entry of the metric's value
    array is treated as one epoch: `reduce_func` is applied to EACH
    SUBARRAY/VALUE in the value array and must return a (scalar) value to use
    for comparison. The best metric value is determined by `compare_func`,
    which must take in two arguments, and return whether or not the _first_ one
    is better.
    Runs without usable metrics are dropped silently. Runs whose values cannot
    be reduced/compared, or which have no values recorded for the metric, are
    skipped with a warning and do not appear in the returned dict.
    Returns the number of the run (its directory name), the (one-indexed) number
    of the epoch, the value associated with that run and epoch, and a dict
    mapping each pair of (run number, best epoch in that run) to the best value
    in that run. If no run is usable, returns (None, None, None, {}).
    """
    # Get the metrics, ignoring empty or nonexistent metrics.json files
    metrics = {}
    for run_num in os.listdir(models_path):
        run_metrics = import_metrics_json(models_path, run_num)
        if run_metrics:
            metrics[run_num] = run_metrics

    # Get the best value
    best_run, best_epoch, best_val, all_vals = None, None, None, {}
    for run_num, run_metrics in metrics.items():
        try:
            # Find the best epoch within that run
            best_epoch_in_run, best_val_in_run = None, None
            for i, subarr in enumerate(run_metrics[metric_name]["values"]):
                val = reduce_func(subarr)
                if best_val_in_run is None or compare_func(val, best_val_in_run):
                    best_epoch_in_run, best_val_in_run = i + 1, val

            if best_epoch_in_run is None:
                # The value array was empty. Recording (run, None) -> None here
                # would hand callers an epoch/value they cannot format or
                # compare, so treat the run like any other unusable run.
                print("Warning: Was not able to compute values for run %s" % run_num)
                continue
            all_vals[(run_num, best_epoch_in_run)] = best_val_in_run

            # If the best value in the best epoch of the run is best so far, update
            if best_val is None or compare_func(best_val_in_run, best_val):
                best_run, best_epoch, best_val = run_num, best_epoch_in_run, best_val_in_run
        except Exception:
            print("Warning: Was not able to compute values for run %s" % run_num)
    return best_run, best_epoch, best_val, all_vals

In [52]:
def print_validation_profile_and_prior_losses(
    condition,
    models_root="/users/amtseng/att_priors/models/trained_models/profile_models",
    metric_name="summit_prof_nll"
):
    """
    Prints a summary of the runs stored under {models_root}/{condition}:
      1. the run and epoch with the lowest mean value of `metric_name`,
      2. the best epoch (and its value) for each run, and
      3. for each run, the mean of "val_prof_corr_losses" and of
         "val_pos_att_losses" over axis 1 of their value arrays (one number per
         entry along axis 0).
    Arguments:
        `condition`: name of the subdirectory of `models_root` holding the runs
        `models_root`: directory containing one subdirectory per condition;
            defaults to the location originally hard-coded here
        `metric_name`: metric used to select the best run/epoch (lower is
            better); defaults to the metric originally hard-coded here
    Prints a message and returns early if no run has usable values for
    `metric_name`. Returns nothing.
    """
    models_path = os.path.join(models_root, condition)

    # NOTE(review): with "summit_prof_nll" the saved outputs report epoch 1 for
    # every run -- presumably that metric is not recorded per epoch; confirm,
    # or pass metric_name="val_prof_corr_losses" to select on the validation
    # loss that is printed per epoch below.
    print("Best profile loss overall:")
    best_run, best_epoch, best_val, all_vals = get_best_metric_at_best_epoch(
        models_path,
        metric_name,
        lambda values: np.mean(values),
        lambda x, y: x < y  # Lower is better
    )
    if best_run is None or best_epoch is None:
        # Formatting None with %d below would raise a TypeError
        print("\tNo runs with usable values for %s found in %s" % (metric_name, models_path))
        return
    print("\tBest run: %s" % best_run)
    print("\tBest epoch in run: %d" % best_epoch)
    print("\tAssociated value: %s" % best_val)

    # Runs in numeric order, computed once for both listings. Entries without
    # a best epoch (nothing recorded for the run) cannot be formatted; skip them.
    run_keys = sorted(
        (key for key in all_vals if key[1] is not None),
        key=lambda p: int(p[0])
    )

    print("Best epoch in each run:")
    for run_num, epoch in run_keys:
        print("\tRun %s, epoch %d: %6.2f" % (run_num, epoch, all_vals[(run_num, epoch)]))

    print("All validation profile and prior losses:")
    for run_num, _ in run_keys:
        print(run_num)
        metrics = import_metrics_json(models_path, run_num)
        prof_losses = np.mean(metrics["val_prof_corr_losses"]["values"], axis=1)
        prior_losses = np.mean(metrics["val_pos_att_losses"]["values"], axis=1)
        print("\t" + " ".join(["%6.2f" % i for i in prof_losses]))
        print("\t" + " ".join(["%6.4f" % i for i in prior_losses]))

In [53]:
# Summarize the runs stored under the "SPI1" condition directory.
print_validation_profile_and_prior_losses("SPI1")

Best profile loss overall:
	Best run: 33
	Best epoch in run: 1
	Associated value: 155.00478504608748
Best epoch in each run:
	Run 1, epoch 1: 156.31
	Run 2, epoch 1: 156.22
	Run 3, epoch 1: 157.18
	Run 4, epoch 1: 156.02
	Run 5, epoch 1: 156.76
	Run 6, epoch 1: 155.66
	Run 7, epoch 1: 155.13
	Run 8, epoch 1: 157.11
	Run 9, epoch 1: 156.52
	Run 10, epoch 1: 157.89
	Run 11, epoch 1: 155.57
	Run 12, epoch 1: 156.53
	Run 13, epoch 1: 157.05
	Run 14, epoch 1: 156.63
	Run 15, epoch 1: 156.55
	Run 16, epoch 1: 155.27
	Run 17, epoch 1: 155.32
	Run 18, epoch 1: 155.17
	Run 19, epoch 1: 157.56
	Run 20, epoch 1: 158.47
	Run 21, epoch 1: 155.26
	Run 22, epoch 1: 155.56
	Run 23, epoch 1: 156.60
	Run 24, epoch 1: 157.80
	Run 25, epoch 1: 157.22
	Run 26, epoch 1: 158.76
	Run 27, epoch 1: 155.79
	Run 28, epoch 1: 155.08
	Run 29, epoch 1: 155.39
	Run 30, epoch 1: 155.72
	Run 31, epoch 1: 156.80
	Run 32, epoch 1: 155.26
	Run 33, epoch 1: 155.00
	Run 34, epoch 1: 155.59
	Run 35, epoch 1: 156.00
	Run 36, 

	 88.49  87.18  86.28  86.03  85.75  85.67  85.56  85.41  85.36  85.26
	0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
49
	 90.68  89.40  88.60  88.18  87.84  87.43  87.22  86.98  86.89  86.77
	0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
50
	 87.78  87.62  87.96  87.58  87.41  87.30  87.34  87.33  87.35  87.44
	0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000


In [54]:
# Summarize the runs stored under the "SPI1_prior" condition directory.
print_validation_profile_and_prior_losses("SPI1_prior")

Best profile loss overall:
	Best run: 11
	Best epoch in run: 1
	Associated value: 156.0538430891092
Best epoch in each run:
	Run 1, epoch 1: 159.12
	Run 2, epoch 1: 156.44
	Run 3, epoch 1: 158.80
	Run 4, epoch 1: 161.46
	Run 5, epoch 1: 156.34
	Run 6, epoch 1: 157.15
	Run 7, epoch 1: 156.76
	Run 8, epoch 1: 156.12
	Run 9, epoch 1: 156.22
	Run 10, epoch 1: 157.13
	Run 11, epoch 1: 156.05
	Run 12, epoch 1: 157.45
	Run 13, epoch 1: 156.51
	Run 14, epoch 1: 158.77
	Run 15, epoch 1: 157.23
	Run 16, epoch 1: 156.40
	Run 17, epoch 1: 158.05
	Run 18, epoch 1: 156.75
	Run 19, epoch 1: 158.58
	Run 20, epoch 1: 157.39
	Run 21, epoch 1: 159.86
	Run 22, epoch 1: 158.75
	Run 23, epoch 1: 156.47
	Run 24, epoch 1: 157.26
All validation profile and prior losses:
1
	104.18  92.49  90.64  90.17  89.20  88.90  88.57  88.28  88.08  87.98
	0.3877 0.1215 0.1044 0.0859 0.0804 0.0744 0.0695 0.0674 0.0644 0.0605
2
	 87.50  86.96  86.70  86.57  86.42  86.33  86.61  86.44  86.55  86.42
	0.1853 0.0792 0.0669 0.062

In [43]:
# Summarize the runs stored under the "SPI1_noise10" condition directory.
# NOTE(review): this cell's execution count (43) is lower than that of the cell
# defining the function (52), and the best values in the saved output coincide
# with entries of the val_prof_corr_losses rows -- presumably the output came
# from an earlier version of the function; re-run to confirm.
print_validation_profile_and_prior_losses("SPI1_noise10")

Best profile loss overall:
	Best run: 2
	Best epoch in run: 9
	Associated value: 95.21198285420736
Best epoch in each run:
	Run 1, epoch 9:  97.19
	Run 2, epoch 9:  95.21
	Run 3, epoch 6:  95.30
	Run 4, epoch 7:  96.07
All validation profile and prior losses:
1
	103.58 101.14  98.86  98.90  98.18  99.33  97.43  98.10  97.19  98.43
	0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
2
	111.29  98.64  97.25  95.86  95.77  96.70  96.75  95.67  95.21  97.31
	0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
3
	 96.72  95.99  96.42  96.21  95.91  95.30  96.06  95.73  95.66  96.53
	0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
4
	109.80  99.07  97.11  97.20  98.15  96.77  96.07  98.21  96.32  96.81
	0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000


In [44]:
# Summarize the runs stored under the "SPI1_noise20" condition directory.
# NOTE(review): this cell's execution count (44) is lower than that of the cell
# defining the function (52), and the best values in the saved output coincide
# with entries of the val_prof_corr_losses rows -- presumably the output came
# from an earlier version of the function; re-run to confirm.
print_validation_profile_and_prior_losses("SPI1_noise20")

Best profile loss overall:
	Best run: 2
	Best epoch in run: 9
	Associated value: 95.85062360410338
Best epoch in each run:
	Run 1, epoch 4:  97.29
	Run 2, epoch 9:  95.85
	Run 3, epoch 1:  96.21
	Run 4, epoch 7:  96.60
	Run 5, epoch 8:  98.88
	Run 6, epoch 8:  97.63
	Run 7, epoch 10:  97.07
	Run 8, epoch 6:  97.58
All validation profile and prior losses:
1
	 98.21  97.64 101.21  97.29  97.87  97.91  98.62  98.45  99.08  98.61
	0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
2
	 98.74  97.61  97.37  97.39  96.80  96.96  97.25  96.94  95.85  96.54
	0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
3
	 96.21 100.25  97.65  98.13  97.95  96.92  97.87  98.96  98.20  98.28
	0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
4
	101.85 100.52  97.28  97.95  97.32  97.73  96.60  97.62  97.39  97.96
	0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
5
	100.81 104.66 101.69  98.89  99.70 100.04 101.11  98.88  99.25  

In [45]:
# Summarize the runs stored under the "SPI1_prior_noise10" condition directory.
# NOTE(review): this cell's execution count (45) is lower than that of the cell
# defining the function (52), and the best values in the saved output coincide
# with entries of the val_prof_corr_losses rows -- presumably the output came
# from an earlier version of the function; re-run to confirm.
print_validation_profile_and_prior_losses("SPI1_prior_noise10")

Best profile loss overall:
	Best run: 1
	Best epoch in run: 9
	Associated value: 96.98021186546043
Best epoch in each run:
	Run 1, epoch 9:  96.98
	Run 2, epoch 2: 100.15
All validation profile and prior losses:
1
	113.47 101.51  98.03  99.49  98.93  98.24  98.21  98.11  96.98  97.49
	0.3955 0.1120 0.0834 0.0727 0.0683 0.0703 0.0615 0.0590 0.0571 0.0556
2
	113.38 100.15 101.71
	0.3797 0.1099 0.0865


In [46]:
# Summarize the runs stored under the "SPI1_prior_noise20" condition directory.
# NOTE(review): this cell's execution count (46) is lower than that of the cell
# defining the function (52), and the best values in the saved output coincide
# with entries of the val_prof_corr_losses rows -- presumably the output came
# from an earlier version of the function; re-run to confirm.
print_validation_profile_and_prior_losses("SPI1_prior_noise20")

Best profile loss overall:
	Best run: 2
	Best epoch in run: 10
	Associated value: 94.8845773767542
Best epoch in each run:
	Run 1, epoch 8:  96.76
	Run 2, epoch 10:  94.88
	Run 3, epoch 10:  97.64
	Run 4, epoch 2:  98.40
All validation profile and prior losses:
1
	 99.21  97.73  97.04  97.45  97.59  99.23  97.01  96.76  98.85  97.25
	0.1374 0.0727 0.0683 0.0644 0.0558 0.0561 0.0557 0.0536 0.0554 0.0569
2
	 97.21  99.75  97.18  97.35  96.78  97.44  96.18  96.85  96.13  94.88
	0.1338 0.0769 0.0676 0.0600 0.0595 0.0590 0.0589 0.0581 0.0575 0.0572
3
	 99.72 100.82  99.12  99.20  99.94  98.24  97.81 100.14  98.31  97.64
	0.1305 0.0728 0.0618 0.0601 0.0582 0.0571 0.0578 0.0536 0.0545 0.0558
4
	100.48  98.40  98.92
	0.1256 0.0739 0.0685


In [48]:
# Summarize the runs stored under the "SPI1_drop20" condition directory.
# NOTE(review): this cell's execution count (48) is lower than that of the cell
# defining the function (52), and the best values in the saved output coincide
# with entries of the val_prof_corr_losses rows -- presumably the output came
# from an earlier version of the function; re-run to confirm.
print_validation_profile_and_prior_losses("SPI1_drop20")

Best profile loss overall:
	Best run: 2
	Best epoch in run: 9
	Associated value: 85.410839426959
Best epoch in each run:
	Run 1, epoch 10:  85.77
	Run 2, epoch 9:  85.41
	Run 3, epoch 8:  85.59
	Run 4, epoch 9:  86.81
	Run 5, epoch 10:  86.79
	Run 6, epoch 9:  85.63
	Run 7, epoch 8:  87.99
All validation profile and prior losses:
1
	 89.26  87.91  87.09  86.55  86.27  86.13  85.93  85.83  85.86  85.77
	0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
2
	 87.26  86.48  86.00  85.71  85.59  85.56  85.48  85.42  85.41  85.47
	0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
3
	 87.22  86.21  85.90  85.69  85.67  85.68  85.77  85.59  85.65  85.70
	0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
4
	 88.10  87.12  86.88  87.05  86.87  86.89  86.89  86.89  86.81  86.85
	0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
5
	 88.71  87.52  87.38  87.07  87.18  87.08  86.99  86.93  86.84  86.79
	0.0000 0.0000 0.00

In [49]:
# Summarize the runs stored under the "SPI1_prior_drop20" condition directory.
# NOTE(review): this cell's execution count (49) is lower than that of the cell
# defining the function (52), and the best values in the saved output coincide
# with entries of the val_prof_corr_losses rows -- presumably the output came
# from an earlier version of the function; re-run to confirm.
print_validation_profile_and_prior_losses("SPI1_prior_drop20")

Best profile loss overall:
	Best run: 3
	Best epoch in run: 10
	Associated value: 86.2115753809611
Best epoch in each run:
	Run 1, epoch 7:  87.25
	Run 2, epoch 4:  86.25
	Run 3, epoch 10:  86.21
	Run 4, epoch 8:  86.46
All validation profile and prior losses:
1
	 87.84  87.85  87.51  87.42  87.56  87.64  87.25  87.34  87.29  87.80
	0.1531 0.0816 0.0684 0.0613 0.0580 0.0566 0.0559 0.0560 0.0541 0.0538
2
	 86.71  86.40  86.50  86.25  86.41  86.49  86.27  86.48  86.37  86.29
	0.1513 0.0748 0.0659 0.0623 0.0598 0.0582 0.0567 0.0569 0.0559 0.0563
3
	 91.86  88.57  87.79  87.26  87.03  86.77  86.55  86.59  86.28  86.21
	0.1903 0.0984 0.0819 0.0718 0.0654 0.0600 0.0572 0.0520 0.0509 0.0491
4
	 88.45  87.56  87.22  86.91  86.72  86.69  86.49  86.46
	0.1517 0.0913 0.0756 0.0644 0.0593 0.0555 0.0527 0.0516
