In [None]:
import json
import os
import numpy as np
import scipy.ndimage
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager

In [None]:
# Plotting defaults
# Register the fonts found in the user's font directory (presumably including
# Roboto, which `font.family` below requests). `FontManager.addfont` was added
# in matplotlib 3.2, where `createFontList` was deprecated (it has since been
# removed); only fall back to the old API on versions that predate `addfont`.
font_paths = font_manager.findSystemFonts(fontpaths="/users/amtseng/modules/fonts")
if hasattr(font_manager.fontManager, "addfont"):
    for font_path in font_paths:
        font_manager.fontManager.addfont(font_path)
else:
    font_manager.fontManager.ttflist.extend(font_manager.createFontList(font_paths))
plot_params = {
    "figure.titlesize": 22,
    "axes.titlesize": 22,
    "axes.labelsize": 20,
    "axes.labelweight": "bold",
    "legend.fontsize": 18,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "font.family": "Roboto",
    "font.weight": "bold"
}
plt.rcParams.update(plot_params)

### Define paths for the model and data of interest

In [None]:
# Family of trained models to analyze: "binary" or "profile". Later cells treat
# any value other than "binary" as profile, so reject typos up front.
model_type = "binary"
assert model_type in ("binary", "profile"), "model_type must be 'binary' or 'profile', got %r" % model_type

In [None]:
# Trained models are stored in a per-model-type subdirectory; any model type
# other than "binary" uses the profile models directory
models_base_path = (
    "/users/amtseng/att_priors/models/trained_models/binary/"
    if model_type == "binary"
    else "/users/amtseng/att_priors/models/trained_models/profile/"
)

In [None]:
# Runs trained without the prior live in `{condition_name}`; runs trained with
# the prior live in the sibling directory `{condition_name}_prior`
condition_name = "SPI1"
noprior_models_path = os.path.join(models_base_path, condition_name)
prior_models_path = os.path.join(models_base_path, condition_name + "_prior")

### Importing saved metrics JSON files

In [None]:
def import_metrics_json(models_path, run_num):
    """
    Reads `{models_path}/{run_num}/metrics.json` and parses it as JSON.

    Returns the parsed Python dictionary. When the file is missing, or its
    contents are not valid JSON, a short message naming the path is printed
    and None is returned instead.
    """
    metrics_path = os.path.join(models_path, str(run_num), "metrics.json")
    if os.path.exists(metrics_path):
        try:
            with open(metrics_path, "r") as metrics_file:
                contents = metrics_file.read()
            return json.loads(contents)
        except json.JSONDecodeError:
            print("Malformed JSON: %s" % metrics_path)
            return None
    print("Path does not exist: %s" % metrics_path)
    return None

In [None]:
def import_all_metrics_json(models_path):
    """
    Looks in `models_path` and finds all instances of
    `{models_path}/{run_num}/metrics.json`, returning a dictionary that maps
    `{run_num}` (the subdirectory name, as a string) to the metrics dictionary.

    Only subdirectories of `models_path` are considered, so stray files there
    do not produce spurious "Path does not exist" messages. Runs are visited in
    natural order (decimal names by value, e.g. "2" before "10", then any other
    names alphabetically), so the order of the returned dictionary, and hence
    any downstream tie-breaking or plotting order, is deterministic rather than
    depending on `os.listdir`. Runs whose metrics are missing, malformed, or
    empty are omitted.
    """
    def run_order(name):
        # Decimal run names sort numerically; all other names sort after them
        return (0, int(name), name) if name.isdecimal() else (1, 0, name)

    run_nums = sorted(
        (
            name for name in os.listdir(models_path)
            if os.path.isdir(os.path.join(models_path, name))
        ),
        key=run_order
    )
    all_metrics = {run_num : import_metrics_json(models_path, run_num) for run_num in run_nums}
    return {key : val for key, val in all_metrics.items() if val}  # Remove empties

In [None]:
def import_config_json(models_path, run_num):
    """
    Loads the JSON file `{models_path}/{run_num}/config.json`.

    Returns the decoded Python dictionary. If the file is absent, or cannot be
    decoded as JSON, a message naming the path is printed and None is returned.
    """
    config_path = os.path.join(models_path, str(run_num), "config.json")
    missing = not os.path.exists(config_path)
    if missing:
        print("Path does not exist: %s" % config_path)
        return None
    with open(config_path, "r") as config_file:
        try:
            return json.load(config_file)
        except json.JSONDecodeError:
            pass  # Reported below, once the file has been closed
    print("Malformed JSON: %s" % config_path)
    return None

In [None]:
def import_all_config_json(models_path):
    """
    Looks in `models_path` and finds all instances of
    `{models_path}/{run_num}/config.json`, returning a dictionary that maps
    `{run_num}` (the subdirectory name, as a string) to the config dictionary.

    Only subdirectories of `models_path` are considered, so stray files there
    do not produce spurious "Path does not exist" messages. Runs are visited in
    natural order (decimal names by value, e.g. "2" before "10", then any other
    names alphabetically), so the order of the returned dictionary is
    deterministic rather than depending on `os.listdir`. Runs whose config is
    missing, malformed, or empty are omitted.
    """
    def run_order(name):
        # Decimal run names sort numerically; all other names sort after them
        return (0, int(name), name) if name.isdecimal() else (1, 0, name)

    run_nums = sorted(
        (
            name for name in os.listdir(models_path)
            if os.path.isdir(os.path.join(models_path, name))
        ),
        key=run_order
    )
    all_config = {run_num : import_config_json(models_path, run_num) for run_num in run_nums}
    return {key : val for key, val in all_config.items() if val}  # Remove empties

In [None]:
def extract_metrics_values(metrics, key):
    """
    Returns the "values" entry stored under `key` in a single run's metrics
    dictionary (i.e. the imported metrics.json of one run). Raises KeyError if
    either `key` or its "values" entry is absent.
    """
    metric_entry = metrics[key]
    return metric_entry["values"]

In [None]:
def extract_metrics_values_at_best_run(all_metrics, key, val_key=None):
    """
    From a metrics dictionary of all runs (i.e. the imported metrics from
    `import_all_metrics_json`), extracts the set of values with the given key,
    but only for the run that yielded the minimal validation loss.

    A run's validation loss at each epoch is the NaN-ignoring mean of that
    epoch's entries under `val_key`. The best run is the one whose best epoch
    has the lowest such loss; ties go to the run seen first. Runs whose
    validation losses are all NaN are skipped.

    Arguments:
        `all_metrics`: dictionary mapping run number to that run's metrics
            dictionary
        `key`: metrics key whose values are returned for the best run
        `val_key`: metrics key of the validation losses used to rank runs; if
            None, it is chosen from the global `model_type` ("val_corr_losses"
            for binary models, "val_prof_corr_losses" otherwise)
    Returns the best run number, the best epoch number (1-indexed, as an int),
    and the values under `key` for the best run.
    Raises ValueError if no run has a non-NaN validation loss (including when
    `all_metrics` is empty).
    """
    import warnings

    if val_key is None:
        val_key = "val_corr_losses" if model_type == "binary" else "val_prof_corr_losses"

    best_run, best_epoch, best_val = None, None, None
    for run, metrics in all_metrics.items():
        with warnings.catch_warnings():
            # Epochs with only NaN entries legitimately average to NaN; silence
            # numpy's "Mean of empty slice" warning for them
            warnings.simplefilter("ignore", category=RuntimeWarning)
            # Same lookup as `extract_metrics_values`, inlined to keep this
            # function self-contained
            epoch_vals = np.nanmean(metrics[val_key]["values"], axis=1)
        if np.all(np.isnan(epoch_vals)):
            continue  # No usable validation loss for this run
        epoch = int(np.nanargmin(epoch_vals))
        val = epoch_vals[epoch]
        # NaN never wins here, so one bad run cannot block better ones
        if best_val is None or val < best_val:
            best_run, best_epoch, best_val = run, epoch + 1, val
    if best_run is None:
        raise ValueError("No run has a non-NaN validation loss under key %s" % val_key)
    return best_run, best_epoch, all_metrics[best_run][key]["values"]

In [None]:
def smooth_signal(signal, sigma, axis=-1):
    """
    Applies a 1D Gaussian smoothing filter to `signal` along `axis`.

    The Gaussian kernel is cut off at one standard deviation (`truncate=1`).
    A `sigma` of 0 means no smoothing: the filter is then run with `truncate=0`,
    which collapses the kernel to a single unit weight.
    """
    effective_sigma, cutoff = (1, 0) if sigma == 0 else (sigma, 1)
    return scipy.ndimage.gaussian_filter1d(signal, effective_sigma, axis=axis, truncate=cutoff)

### Plot training statistics tracked through time

In [None]:
noprior_metrics = import_all_metrics_json(noprior_models_path)
prior_metrics = import_all_metrics_json(prior_models_path)
# The cells below compare runs with and without the prior; fail early with a
# clear message if either set of runs is empty
assert noprior_metrics, "No runs with metrics found in %s" % noprior_models_path
assert prior_metrics, "No runs with metrics found in %s" % prior_models_path

In [None]:
# Plot training and validation correctness losses with vs without the prior, for every run
if model_type == "binary":
    train_key = "train_corr_losses"
    val_key = "val_corr_losses"
else:
    train_key = "train_prof_corr_losses"
    val_key = "val_prof_corr_losses"

noprior_train_corr_losses = {key : extract_metrics_values(m, train_key) for key, m in noprior_metrics.items()}
prior_train_corr_losses = {key : extract_metrics_values(m, train_key) for key, m in prior_metrics.items()}
noprior_val_corr_losses = {key : extract_metrics_values(m, val_key) for key, m in noprior_metrics.items()}
prior_val_corr_losses = {key : extract_metrics_values(m, val_key) for key, m in prior_metrics.items()}

# (legend label, per-run losses, color, line style), in drawing order
curve_groups = [
    ("Training loss without prior", noprior_train_corr_losses, "forestgreen", ":"),
    ("Training loss with Fourier prior", prior_train_corr_losses, "purple", ":"),
    ("Validation loss without prior", noprior_val_corr_losses, "coral", "-"),
    ("Validation loss with Fourier prior", prior_val_corr_losses, "royalblue", "-"),
]
legend_order = [
    "Training loss without prior", "Validation loss without prior",
    "Training loss with Fourier prior", "Validation loss with Fourier prior"
]

fig, ax = plt.subplots(figsize=(12, 12))
legend_handles = {}  # One representative line per group; groups with no runs get no entry
for label, losses_by_run, color, linestyle in curve_groups:
    for corr_losses in losses_by_run.values():
        # Average each epoch's entries (axis 1) to get one loss per epoch
        line, = ax.plot(np.nanmean(corr_losses, axis=1), color=color, linestyle=linestyle, alpha=0.7)
        legend_handles[label] = line
shown_labels = [label for label in legend_order if label in legend_handles]
ax.legend([legend_handles[label] for label in shown_labels], shown_labels)
if model_type == "binary":
    title = "Correctness loss without/with Fourier priors"
else:
    title = "Profile NLL loss without/with Fourier priors"
title += "\n%s, %d/%d %s models" % (condition_name, len(noprior_metrics), len(prior_metrics), model_type)
ax.set_title(title)
ax.set_xlabel("Epoch number")
# One tick per plotted epoch, labeled 1-indexed; derived from the data rather
# than from the auto-generated ticks, which can extend past the last epoch
num_epochs = max((len(line.get_xdata()) for line in ax.get_lines()), default=0)
xticks = np.arange(num_epochs)
ax.set_xticks(xticks)
ax.set_xticklabels(xticks + 1)
if model_type == "binary":
    ax.set_ylabel("Cross-entropy loss")
else:
    ax.set_ylabel("Profile NLL loss")

In [None]:
# Plot training and validation correctness losses with vs without the prior, for the best run of each
if model_type == "binary":
    train_key = "train_corr_losses"
    val_key = "val_corr_losses"
else:
    train_key = "train_prof_corr_losses"
    val_key = "val_prof_corr_losses"

noprior_best_run, noprior_best_epoch, noprior_train_corr_losses = extract_metrics_values_at_best_run(noprior_metrics, train_key)
prior_best_run, prior_best_epoch, prior_train_corr_losses = extract_metrics_values_at_best_run(prior_metrics, train_key)
# Validation losses of those same best runs; no need to search for the best run again
noprior_val_corr_losses = extract_metrics_values(noprior_metrics[noprior_best_run], val_key)
prior_val_corr_losses = extract_metrics_values(prior_metrics[prior_best_run], val_key)

print("Best run/epoch without prior: run %s, epoch %d" % (noprior_best_run, noprior_best_epoch))
print("Best run/epoch with priors: run %s, epoch %d" % (prior_best_run, prior_best_epoch))

fig, ax = plt.subplots(figsize=(12, 12))
# Average each epoch's entries (axis 1) to get one loss per epoch
noprior_train_line, = ax.plot(np.nanmean(noprior_train_corr_losses, axis=1), color="forestgreen", linestyle=":")
prior_train_line, = ax.plot(np.nanmean(prior_train_corr_losses, axis=1), color="purple", linestyle=":")
noprior_val_line, = ax.plot(np.nanmean(noprior_val_corr_losses, axis=1), color="coral")
prior_val_line, = ax.plot(np.nanmean(prior_val_corr_losses, axis=1), color="royalblue")
ax.legend(
    [noprior_train_line, noprior_val_line, prior_train_line, prior_val_line],
    [
        "Training loss without prior", "Validation loss without prior",
        "Training loss with Fourier prior", "Validation loss with Fourier prior"
    ]
)
if model_type == "binary":
    title = "Correctness loss without/with Fourier priors of best run"
else:
    title = "Profile NLL loss without/with Fourier priors of best run"
title += "\nComparison of best-performing %s %s models" % (condition_name, model_type)
ax.set_title(title)
ax.set_xlabel("Epoch number")
# One tick per plotted epoch, labeled 1-indexed; derived from the data rather
# than from the auto-generated ticks, which can extend past the last epoch
num_epochs = max((len(line.get_xdata()) for line in ax.get_lines()), default=0)
xticks = np.arange(num_epochs)
ax.set_xticks(xticks)
ax.set_xticklabels(xticks + 1)
if model_type == "binary":
    ax.set_ylabel("Cross-entropy loss")
else:
    ax.set_ylabel("Profile NLL loss")

In [None]:
# Lowest per-epoch mean validation loss of each best run. NaN-aware, like the
# per-epoch means, so one epoch with no valid entries cannot mask the minimum
print("Best validation loss without prior: %f" % np.nanmin(np.nanmean(noprior_val_corr_losses, axis=1)))
print("Best validation loss with prior: %f" % np.nanmin(np.nanmean(prior_val_corr_losses, axis=1)))