## Setup

In [None]:
import json
import re
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import autograd
from torch.autograd import Variable
import matplotlib.pyplot as plt
import math
import random
from datetime import datetime
import os.path
import copy
import seaborn as sns
import lovely_tensors as lt # can be removed

from l2o.others import w, detach_var, rsetattr, rgetattr, count_parameters, print_grads, \
    load_l2o_opter_ckpt, load_baseline_opter_ckpt, load_ckpt, get_baseline_ckpt_dir, dict_to_str
from l2o.visualization import get_model_dot
from l2o.training import do_fit, fit_normal, fit_optimizer, find_best_lr_normal
from l2o.regularization import (
    regularize_updates_translation_constraints,
    regularize_updates_scale_constraints,
    regularize_updates_rescale_constraints,
    regularize_updates_constraints,
    regularize_translation_conservation_law_breaking,
    regularize_rescale_conservation_law_breaking,
)
from l2o.analysis import (
    get_rescale_sym_constraint_deviation,
    get_translation_sym_constraint_deviations,
    get_scale_sym_constraint_deviation,
    get_baseline_opter_param_updates,
    collect_rescale_sym_deviations,
    collect_translation_sym_deviations,
    collect_scale_sym_deviations,
    collect_conservation_law_deviations,
    calc_sai,
)
from l2o.tail_index_utils import (
    alpha_estimator,
)
from l2o.data import MNIST, CIFAR10
from l2o.optimizer import Optimizer
from l2o.optimizee import (
    MNISTSigmoid,
    MNISTReLU,
    MNISTNet,
    MNISTNet2Layer,
    MNISTNetBig,
    MNISTRelu,
    MNISTLeakyRelu,
    MNISTSimoidBatchNorm,
    MNISTReluBatchNorm,
    MNISTConv,
    MNISTReluBig,
    MNISTReluBig2Layer,
    MNISTMixtureOfActivations,
    MNISTNetBig2Layer,
)
from l2o.meta_module import *
from meta_test import meta_test, meta_test_baselines

# lovely_tensors: presumably only changes how tensors are displayed (optional, can be removed)
lt.monkey_patch() # can be removed
# global seaborn theme used by every plot in this notebook (white background)
sns.set(color_codes=True)
sns.set_style("white")

In [None]:
### publication figure settings:
# Applied globally through rcParams so every figure below shares the same look.
plt.rcParams.update({
    # fonts & text sizes
    "font.family": "serif",
    "legend.fontsize": 12,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "axes.labelsize": 13,
    "axes.titlesize": 13,
    # axes frame & lines
    "axes.linewidth": 0.5,
    "axes.labelpad": 10,
    "lines.linewidth": 1.,
    # figures shown in the notebook
    "figure.dpi": 300,
    "figure.figsize": (6, 4),
    # figures written with savefig
    "savefig.dpi": 300,
    "savefig.format": "pdf",
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.1,
})

In [None]:
### config
# Registry of runs to analyse; uncomment the entries to include.
# - `l2os`: display name (used as plot label, may contain LaTeX) -> L2O checkpoint directory
#   name relative to $CKPT_PATH. The loading cell replaces each value with the loaded
#   checkpoint/config/metrics dict.
# - `baselines`: display name -> baseline run directory name relative to $CKPT_PATH/baselines
#   (loaded by `load_baselines`).
# NOTE: duplicate display names silently overwrite earlier entries (plain dict literal).
l2os = {
    ### MNISTReluBatchNorm ###
    # r"L2O, $\beta$=0": "26-04-2023_01-22-31_MNISTReluBatchNorm_Optimizer", # DEPRECATED
    # r"L2O, $\beta$=0": "12-05-2023_00-38-54_MNISTReluBatchNorm_Optimizer", # lr: 1e-3
    # r"L2O, $\beta$=0.01": "03-05-2023_22-42-56_MNISTReluBatchNorm_Optimizer", # scale
    # r"L2O, $\beta$=0.1": "02-05-2023_16-51-19_MNISTReluBatchNorm_Optimizer", # scale
    # r"L2O, $\beta$=1.": "03-05-2023_22-42-36_MNISTReluBatchNorm_Optimizer", # scale
    
    # "L2O - unroll 10": "10-05-2023_13-26-17_MNISTReluBatchNorm_Optimizer", # already collected
    # "L2O - unroll 20": "12-05-2023_00-38-54_MNISTReluBatchNorm_Optimizer", # lr: 1e-3, already collected
    # "L2O - unroll 30": "11-05-2023_01-27-43_MNISTReluBatchNorm_Optimizer", # already collected

    # "L2O - randomized restarts": "20-06-2023_00-06-50_MNISTReluBatchNorm_Optimizer",
    # "L2O - baseline": "12-05-2023_00-38-54_MNISTReluBatchNorm_Optimizer",

    # "L2O - w/ SAI input": "21-06-2023_12-34-25_MNISTReluBatchNorm_Optimizer",
    # "L2O - w/ SAI input & random restarts": "22-06-2023_11-52-12_MNISTReluBatchNorm_Optimizer",
    ##########################


    ### MNISTLeakyRelu ###
    # r"L2O, $\beta$=0": "19-02-2023_18-20-21_MNISTLeakyRelu_Optimizer", # DEPRECATED
    # r"L2O, $\beta$=0": "12-05-2023_00-35-50_MNISTLeakyRelu_Optimizer", # lr: 1e-3
    # r"L2O, $\beta$=0.01": "30-04-2023_13-08-09_MNISTLeakyRelu_Optimizer", # rescale
    # r"L2O, $\beta$=0.1": "06-05-2023_01-54-04_MNISTLeakyRelu_Optimizer", # rescale
    # r"L2O, $\beta$=0.5": "05-05-2023_01-42-29_MNISTLeakyRelu_Optimizer", # rescale
    
    # "L2O - unroll 10": "09-05-2023_20-50-27_MNISTLeakyRelu_Optimizer", # already collected
    # "L2O - unroll 20": "12-05-2023_00-35-50_MNISTLeakyRelu_Optimizer", # lr: 1e-3, already collected
    # "L2O - unroll 30": "10-05-2023_11-11-39_MNISTLeakyRelu_Optimizer", # already collected
    
    # "L2O - randomized restarts": "19-06-2023_20-42-55_MNISTLeakyRelu_Optimizer",
    # "L2O - baseline": "12-05-2023_00-35-50_MNISTLeakyRelu_Optimizer",

    # "L2O - w/ SAI input": "21-06-2023_12-26-53_MNISTLeakyRelu_Optimizer",
    # "L2O - w/ log10(SAI) input": "21-06-2023_12-28-08_MNISTLeakyRelu_Optimizer",
    # "L2O - w/ SAI input & random restarts": "22-06-2023_11-51-28_MNISTLeakyRelu_Optimizer",
    # "L2O - w/ time input & random restarts": "23-07-2023_14-07-30_MNISTLeakyRelu_Optimizer",

    # "L2O, shared state": "23-07-2023_23-32-08_MNISTLeakyRelu_Optimizer", # shared state (256)
    ######################


    ### MNISTNet ###
    # r"L2O, $\beta$=0": "05-03-2023_01-33-57_MNISTNet_Optimizer", # DEPRECATED
    # r"L2O, $\beta$=0": "07-05-2023_20-52-18_MNISTNet_Optimizer", # lr: 1e-3
    # r"L2O, $\beta$=0.01": "29-04-2023_01-35-00_MNISTNet_Optimizer", # translation
    # r"L2O, $\beta$=0.1": "28-04-2023_13-45-29_MNISTNet_Optimizer", # translation
    # r"L2O, $\beta$=0.5": "30-04-2023_12-41-11_MNISTNet_Optimizer", # translation
    
    # "L2O - unroll 10": "09-05-2023_20-41-10_MNISTNet_Optimizer", # already collected
    # "L2O - unroll 20": "07-05-2023_20-52-18_MNISTNet_Optimizer", # lr: 1e-3, already collected
    # "L2O - unroll 30": "09-05-2023_20-42-08_MNISTNet_Optimizer", # already collected
    
    # "L2O - randomized restarts": "18-06-2023_13-29-45_MNISTNet_Optimizer",
    # "L2O - baseline": "05-03-2023_01-33-57_MNISTNet_Optimizer",

    # "L2O - w/ SAI input": "20-06-2023_21-57-17_MNISTNet_Optimizer",
    # "L2O - w/ log10(SAI) input": "20-06-2023_22-00-58_MNISTNet_Optimizer",
    # "L2O - w/ SAI input (shared)": "21-06-2023_16-34-47_MNISTNet_Optimizer",
    # "L2O - w/ SAI input (init ones)": "21-06-2023_18-44-45_MNISTNet_Optimizer",
    # "L2O - w/ SAI input & random restarts": "22-06-2023_11-50-14_MNISTNet_Optimizer",
    # "L2O - w/ time input & random restarts": "23-07-2023_23-37-16_MNISTNet_Optimizer",
    # r"L2O - hidden_sz=30 & w/ sym. reg. enc & $\beta$=0.1": "11-08-2023_21-13-36_MNISTNet_Optimizer", # translation

    # "L2O, shared state": "23-07-2023_14-37-19_MNISTNet_Optimizer", # shared state (152)
    #################


    ### Meta-training for generalization ###
    # "L2O - multi-task": "23-06-2023_21-49-22_MNISTMixtureOfActivationsFeatureDim_Optimizer",
    # "L2O - multi-task": "06-06-2023_15-48-07_MNISTMixtureOfActivations_Optimizer",
    # "L2O - fine-tuning": "06-06-2023_16-16-33_MNISTRelu_Optimizer",
    # "L2O - baseline": "07-05-2023_20-52-18_MNISTNet_Optimizer",
    ########################################
}

# Baseline runs; the directory name encodes optimizer config, optimizee (+ its config) and data (+ its config).
baselines = {
    ### MNISTReluBatchNorm
    # "Adam": "Adam_{lr=find_best_lr_normal}_MNISTReluBatchNorm_{affine=True_track_running_stats=True}_MNIST_{batch_size=128}",
    # "SGD": "SGD_{lr=find_best_lr_normal_momentum=0.9}_MNISTReluBatchNorm_{affine=True_track_running_stats=True}_MNIST_{batch_size=128}",

    ### MNISTLeakyRelu
    # "Adam": "Adam_{lr=find_best_lr_normal}_MNISTLeakyRelu_{}_MNIST_{batch_size=128}",
    # "SGD": "SGD_{lr=find_best_lr_normal_momentum=0.9}_MNISTLeakyRelu_{}_MNIST_{batch_size=128}",

    ### MNISTNet
    # "Adam": "Adam_{lr=find_best_lr_normal}_MNISTNet_{}_MNIST_{batch_size=128}",
    # "SGD": "SGD_{lr=find_best_lr_normal_momentum=0.9}_MNISTNet_{}_MNIST_{batch_size=128}",
}

In [None]:
### load l2os from disk
# Replace every `l2os` value (checkpoint directory name relative to $CKPT_PATH) by a dict with
# the final L2O checkpoint, its config, and all meta-testing metrics saved in the checkpoint's
# base directory as `metrics_<run_nickname>.npy` (keyed by `<run_nickname>`).
# NOTE: torch.load and np.load(allow_pickle=True) unpickle objects -- only use on trusted checkpoints.
for l2o_name, l2o_dir in list(l2os.items()):
    if isinstance(l2o_dir, dict):
        # already loaded (e.g. this cell is being re-run) -- the directory name is gone, skip
        continue

    ### load final l2o checkpoint
    ckpt = torch.load(os.path.join(os.environ["CKPT_PATH"], l2o_dir, "l2o_optimizer.pt"), map_location="cpu")

    ### load all metrics
    metrics_dir = os.path.join(os.environ["CKPT_PATH"], ckpt["config"]["ckpt_base_dir"])
    metrics_prefix, metrics_suffix = "metrics_", ".npy"
    l2o_metrics = {}
    for metrics_file in os.listdir(metrics_dir):
        # only `metrics_*.npy`; any other file with the same prefix would make np.load fail
        if not (metrics_file.startswith(metrics_prefix) and metrics_file.endswith(metrics_suffix)):
            continue
        metrics_name = metrics_file[len(metrics_prefix):-len(metrics_suffix)]
        l2o_metrics[metrics_name] = np.load(os.path.join(metrics_dir, metrics_file), allow_pickle=True).item()

    l2os[l2o_name] = {
        "ckpt": ckpt,
        "config": ckpt["config"],
        "metrics": l2o_metrics,
    }

In [None]:
### load baselines from disk
def load_baselines(baselines_dict):
    """Load the config and metrics of baseline runs, replacing the dict's values in place.

    Args:
        baselines_dict: display name -> run directory name under ``$CKPT_PATH/baselines``.
            Values that are already loaded dicts (e.g. when a cell is re-run) are left as-is.

    Returns:
        The same dict object. Each loaded value is replaced by
        ``{"baseline_dir": <full run dir>, "config": <config.pt>, "metrics": <metrics.npy>}``.

    Raises:
        NotImplementedError: if ``config["meta_testing"]`` has no ``"baseline_opter_cls"`` and it
            cannot be inferred from the display name (which must contain "sgd" or "adam").
    """
    baselines_root_dir = os.path.join(os.environ["CKPT_PATH"], "baselines")
    for baseline_name, baseline_dir in list(baselines_dict.items()):
        if isinstance(baseline_dir, dict):
            continue  # already loaded

        ### load config
        # NOTE: torch.load / np.load(allow_pickle=True) unpickle objects -- trusted checkpoints only.
        config = torch.load(os.path.join(baselines_root_dir, baseline_dir, "config.pt"), map_location="cpu")

        # The optimizer class is read from config["meta_testing"] downstream, so that is the
        # dict that must be checked (not the top-level config).
        meta_testing = config["meta_testing"]
        if "baseline_opter_cls" not in meta_testing:
            name = baseline_name.lower()
            if "baseline_opter_cls" in config:
                # NOTE(review): assumes a top-level entry means the same thing -- confirm
                meta_testing["baseline_opter_cls"] = config["baseline_opter_cls"]
            elif "sgd" in name:
                meta_testing["baseline_opter_cls"] = optim.SGD
            elif "adam" in name:
                meta_testing["baseline_opter_cls"] = optim.Adam
            else:
                raise NotImplementedError

        ### load metrics
        metrics = np.load(os.path.join(baselines_root_dir, baseline_dir, "metrics.npy"), allow_pickle=True).item()
        baselines_dict[baseline_name] = {
            "baseline_dir": os.path.join(baselines_root_dir, baseline_dir),
            "config": config,
            "metrics": metrics,
        }
    return baselines_dict

# Load config + metrics for every registered baseline (load_baselines replaces the values in place).
baselines = load_baselines(baselines)

## Plot performance

In [None]:
def _trailing_moving_average(y, window):
    """Smooth a 1-D array with a trailing box filter while keeping its length.

    For ``i >= window - 1`` the result is the mean of ``y[i - window + 1 : i + 1]``; the first
    ``window - 1`` entries (no full window yet) are copied from ``y`` unchanged. ``y`` is
    returned as-is when ``window`` is falsy, ``<= 1``, or longer than ``y``.
    """
    if not window or window <= 1 or len(y) < window:
        return y
    smoothed = np.convolve(y, np.ones(window), "valid") / window
    return np.concatenate([y[:window - 1], smoothed])


def _mean_std_curve(runs, metric, show_max_iters, config, conv_window):
    """Return ``(x, mean, std)`` over the first axis of a 2-D metric array.

    Test metrics (``"test" in metric``) are treated as holding one point every
    ``config["meta_training"]["eval_iter_freq"]`` iterations (x = freq, 2*freq, ...); other metrics
    as one point per iteration (x = 0, 1, ...). Only iterations up to ``show_max_iters`` are kept,
    and x is cut to the number of points actually available. Mean and std are both smoothed with
    ``_trailing_moving_average(·, conv_window)``.
    """
    if "test" in metric:
        eval_iter_freq = config["meta_training"]["eval_iter_freq"]
        runs = runs[:, :show_max_iters // eval_iter_freq]
        x = np.arange(1, runs.shape[1] + 1) * eval_iter_freq
    else:
        runs = runs[:, :show_max_iters]
        x = np.arange(runs.shape[1])
    y = _trailing_moving_average(np.mean(runs, axis=0), conv_window)
    y_std = _trailing_moving_average(np.std(runs, axis=0), conv_window)
    return x, y, y_std


def _plot_curve(ax, x, y, y_std, label, with_err_bars, **line_kwargs):
    """Draw one mean curve on ``ax`` and, optionally, a ±1 std band in the line's colour."""
    sns.lineplot(x=x, y=y, label=label, ax=ax, **line_kwargs)
    if with_err_bars:
        # colour the band like the line just drawn instead of relying on matplotlib's separate
        # fill colour cycle staying in step with the line colour cycle
        ax.fill_between(x, y - y_std, y + y_std, alpha=0.2, color=ax.lines[-1].get_color())


_METRIC_LABELS = {
    "train_loss": "Train Loss",
    "test_loss": "Test Loss",
    "train_acc": "Train Accuracy",
    "test_acc": "Test Accuracy",
}


def plot_performance(
    plot_baselines,
    plot_l2os,
    run_nickname,
    show_max_iters,
    metric,
    log_loss=False,
    save_fig_to_path=None,
    with_err_bars=False,
    conv_window=None,
):
    """Plot mean learning curves (optionally ±1 std) of baseline and L2O optimizers.

    Args:
        plot_baselines: name -> dict with "metrics" (metric name -> 2-D array averaged over its
            first axis) and "config", as produced by ``load_baselines``. Drawn dashed.
        plot_l2os: name -> dict with "metrics" (run nickname -> metric name -> 2-D array) and
            "config". Drawn solid.
        run_nickname: key selecting the meta-test run inside each L2O's "metrics".
        show_max_iters: last iteration shown.
        metric: e.g. "train_loss", "test_loss", "train_acc" or "test_acc".
        log_loss: use a log y-axis for loss metrics.
        save_fig_to_path: if not None, the figure is saved there (before being shown).
        with_err_bars: shade ±1 standard deviation around each mean curve.
        conv_window: width of the trailing moving average applied to the curves (None/1: off).
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)

    ### baseline optimizers
    for baseline_name, baseline_dict in plot_baselines.items():
        x, y, y_std = _mean_std_curve(
            baseline_dict["metrics"][metric], metric, show_max_iters, baseline_dict["config"], conv_window
        )
        _plot_curve(ax, x, y, y_std, baseline_name, with_err_bars, linestyle="--")

    ### L2O optimizers
    for l2o_name, l2o_dict in plot_l2os.items():
        x, y, y_std = _mean_std_curve(
            l2o_dict["metrics"][run_nickname][metric], metric, show_max_iters, l2o_dict["config"], conv_window
        )
        # alternative label: fr"{l2o_name}, $\beta$={config['meta_training']['reg_mul']}"
        _plot_curve(ax, x, y, y_std, str(l2o_name), with_err_bars)

    ### plot settings
    ax.set_xlabel("Iteration")
    ax.set_ylabel(_METRIC_LABELS.get(metric, metric))

    # set y to log scale
    if log_loss and "loss" in metric:
        ax.set_yscale("log")

    if "acc" in metric:
        ax.set_ylim(0.6, 1.0)
    elif not log_loss:
        ax.set_ylim(0.0, None)

    # ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.32), ncol=2)
    ax.legend(loc="upper center", bbox_to_anchor=(0.5, 1.25), ncol=3)
    legend = ax.get_legend()
    # `Legend.legendHandles` was renamed to `legend_handles` in matplotlib 3.7 (old name later removed)
    legend_handles = getattr(legend, "legend_handles", None)
    if legend_handles is None:
        legend_handles = legend.legendHandles
    for legend_handle in legend_handles:
        legend_handle.set_linewidth(3.0)

    # x-ticks: 0, show_max_iters / 2, show_max_iters
    ax.set_xticks(np.linspace(0, show_max_iters, 3))

    ### y-ticks
    if not log_loss:
        y_top = np.ceil(ax.get_ylim()[1] / 0.5) * 0.5  # round the upper limit up to a multiple of 0.5
        ax.set_yticks(np.linspace(0, y_top, 3))
    else:
        # hard-coded ticks at 1 and 0.1
        ax.set_yticks([1e0, 1e-1])
        # ax.set_ylim(0.2e-1, None)
        # ax.set_ylim(1.6e-2, 1e1)

    ### add zoom region inset axes (zoom on the first 100 iterations)
    # axins = ax.inset_axes([0.55, 0.56, 0.35, 0.44])
    # for l_i, line in enumerate(ax.lines):
    #     xy_data = line.get_xydata()
    #     axins.plot(xy_data[:,0], xy_data[:,1], linewidth=1., color=line.get_color(), linestyle=line.get_linestyle())

    # axins.set_xlim(0, 80)
    # axins.set_ylim(0.35, 2.5)
    # # axins.set_yscale("log")
    # axins.set_xticks([0, 40, 80])
    # # axins.set_yticks([1.])
    # axins.set_yticklabels([1.0], fontsize=7)
    # axins.set_xticklabels([0, 40, 80], fontsize=7.5, position=(0., .08))
    # ax.indicate_inset_zoom(axins)

    ### save before showing: some backends release the figure on show()
    if save_fig_to_path is not None:
        fig.savefig(save_fig_to_path, bbox_inches="tight")

    plt.show()

In [None]:
### config
# what to plot
show_max_iters = 500
log_loss = True
metric = "train_loss" # ["train_loss", "test_loss", "train_acc", "test_acc"]
conv_window = 5
with_err_bars = True

### run specification: optimizee + data the L2Os / baselines were meta-tested on
optee_cls = MNISTLeakyRelu
optee_config = {}
# optee_config = {"affine": True, "track_running_stats": True}
# optee_config = {"layer_sizes": [100,100]}
data_cls = MNIST
data_config = {"batch_size": 128}
optee_nickname = f"{optee_cls.__name__}_{dict_to_str(optee_config)}"
run_nickname = f"{optee_nickname}_{data_cls.__name__}_{dict_to_str(data_config)}"

### baselines run on the same optimizee/data (directory names end with the run nickname)
plot_baselines = load_baselines({
    "Adam": "Adam_{lr=find_best_lr_normal}_" + run_nickname,
    "SGD": "SGD_{lr=find_best_lr_normal_momentum=0.9}_" + run_nickname,
})

### where to save the figure (saving is switched off by `save_fig`)
fig_dir = "../results/publication/meta_training_for_generalization/MNISTNet_meta_training"
fig_name = ("log_" if log_loss is True else "") + f"{metric}_comparison_{optee_nickname}_{show_max_iters}.pdf"
save_fig = False
save_fig_to_path = os.path.join(fig_dir, fig_name) if save_fig else None

print(f"Final destination: {save_fig_to_path if save_fig_to_path is not None else 'None'}")

In [None]:
# Draw the comparison figure for the run configured in the cell above.
performance_plot_kwargs = {
    "plot_baselines": plot_baselines,
    "plot_l2os": l2os,
    "run_nickname": run_nickname,
    "show_max_iters": show_max_iters,
    "metric": metric,
    "log_loss": log_loss,
    "save_fig_to_path": save_fig_to_path,
    "with_err_bars": with_err_bars,
    "conv_window": conv_window,
}
plot_performance(**performance_plot_kwargs)

## Plot performance comparison between L2Os
- TODO

In [None]:
### settings
ckpt_root_dir = "./ckpt"
# Meta-trained L2O checkpoint directories (relative to `ckpt_root_dir`), grouped under the key
# selected by `opter_key` below; trailing comments give each run's regularization multiplier.
l2o_opter_ckpt_dirs_all = {
    "MNISTNet": [
        "05-03-2023_01-33-57_MNISTNet_Optimizer", # reg_mul=0
        "29-04-2023_01-35-00_MNISTNet_Optimizer", # reg_mul=0.01
        "28-04-2023_13-45-29_MNISTNet_Optimizer", # reg_mul=0.1
    ],
    "MNISTLeakyRelu_normalized_params": [
        "19-02-2023_18-20-21_MNISTLeakyRelu_Optimizer", # reg_mul=0
        "29-04-2023_01-38-19_MNISTLeakyRelu_Optimizer", # reg_mul=0.01
        "28-04-2023_02-11-33_MNISTLeakyRelu_Optimizer", # reg_mul=0.1
        "29-04-2023_23-46-08_MNISTLeakyRelu_Optimizer", # reg_mul=0.5
    ],
    "MNISTLeakyRelu": [
        "19-02-2023_18-20-21_MNISTLeakyRelu_Optimizer", # reg_mul=0
        # "...", # reg_mul=0.01
        "27-04-2023_16-39-10_MNISTLeakyRelu_Optimizer", # reg_mul=0.1
    ],
    "MNISTReluBatchNorm": [
        "26-04-2023_01-22-31_MNISTReluBatchNorm_Optimizer", # reg_mul=0
        "26-04-2023_01-31-59_MNISTReluBatchNorm_Optimizer", # reg_mul=0.01
        "27-04-2023_01-37-23_MNISTReluBatchNorm_Optimizer", # reg_mul=0.1
    ],
    "MNISTReluBatchNorm_normalized_params": [
        "26-04-2023_01-22-31_MNISTReluBatchNorm_Optimizer", # reg_mul=0
        "29-04-2023_23-51-43_MNISTReluBatchNorm_Optimizer", # reg_mul=0.01
        "28-04-2023_11-48-30_MNISTReluBatchNorm_Optimizer", # reg_mul=0.1
    ],
}
l2o_opter_ckpt_dirs_all = {
    key: [os.path.join(ckpt_root_dir, d) for d in dirs]
    for key, dirs in l2o_opter_ckpt_dirs_all.items()
}

baselines_dir = "./ckpt/baselines"
# (name, optimizer class, optimizer config); "lr" may be a callable (here `find_best_lr_normal`)
baselines_to_test_against = [
    ("Adam", optim.Adam, {"lr": find_best_lr_normal}),
    ("SGD", optim.SGD, {"lr": find_best_lr_normal, "momentum": 0.9}),
]

data_cls = MNIST
data_config = {
    "batch_size": 128,
}

opter_key = "MNISTNet"  # which group of L2O checkpoints above to compare
optee_cls = MNISTReluBig2Layer
optee_config = {}
# optee_config = {"affine": True, "track_running_stats": True}
# optee_config = {"layer_sizes": [100,100]}

show_max_iters = 200
log_loss = True
metric = "train_loss" # ["train_loss", "test_loss", "train_acc", "test_acc"]

l2o_opter_ckpt_dirs = l2o_opter_ckpt_dirs_all[opter_key]
# NOTE(review): nicknames here use the raw dict repr (not `dict_to_str`, as the earlier config cell
# does); presumably this matches how these older runs were saved -- confirm before changing.
optee_nickname = f"{optee_cls.__name__}_{optee_config}"
run_nickname = f"{optee_nickname}_{data_cls.__name__}_{data_config}"

fig_name = f"{metric}_comparison_{optee_nickname}_meta_trained_on_{opter_key}_{show_max_iters}.pdf"
if log_loss is True:
    fig_name = f"log_{fig_name}"

save_fig = False  # set to True to write the figure into `fig_dir`
fig_dir = "../results/publication/reg_comparison"
if save_fig:
    os.makedirs(fig_dir, exist_ok=True)
else:
    fig_dir = None  # the plotting cell only saves when fig_dir is not None

print(f"Final destination: {os.path.join(fig_dir, fig_name) if fig_dir is not None else 'None'}")

In [None]:
### load baseline metrics from disk
# optimizer name -> metrics dict stored in that baseline run's `metrics.npy`
baseline_metrics = {}

for opter_name, baseline_opter_cls, baseline_opter_config in baselines_to_test_against:
    # The run directory name embeds the optimizer config; a callable learning rate appears
    # there under its function name.
    baseline_opter_config_copy = deepcopy(baseline_opter_config)
    lr_setting = baseline_opter_config.get("lr")
    if callable(lr_setting):
        baseline_opter_config_copy["lr"] = lr_setting.__name__

    baseline_dir_name = (
        f"{opter_name}_{baseline_opter_config_copy}"
        f"_{optee_cls.__name__}_{optee_config}"
        f"_{data_cls.__name__}_{data_config}"
    )
    metrics_path = os.path.join(baselines_dir, baseline_dir_name, "metrics.npy")

    print(f"Loading {metrics_path}")
    baseline_metrics[opter_name] = np.load(metrics_path, allow_pickle=True).item()

In [None]:
### load all l2o opters from disk (results of meta-testing + configs)
# Each entry: {"l2o_opter_ckpt_dir", "config", "metrics": run_nickname -> metrics dict}.
l2o_opters = []

for l2o_opter_ckpt_dir in l2o_opter_ckpt_dirs:
    ### load previous checkpoint (and skip meta-training of a new L2O optimizer)
    print(f"Loading {l2o_opter_ckpt_dir}")
    _, config, _ = load_ckpt(dir_path=l2o_opter_ckpt_dir)
    # sanity check: the checkpoint must describe the directory it was loaded from
    assert l2o_opter_ckpt_dir == config["ckpt_base_dir"]
    l2o_opter_dict = {
        "l2o_opter_ckpt_dir": l2o_opter_ckpt_dir,
        "config": config,
        "metrics": dict(),
    }

    # meta-testing results are stored as `metrics_<run_nickname>.npy`; load only those files
    # (another file with the same prefix would make np.load fail)
    metrics_prefix, metrics_suffix = "metrics_", ".npy"
    for metrics_file in os.listdir(config["ckpt_base_dir"]):
        if not (metrics_file.startswith(metrics_prefix) and metrics_file.endswith(metrics_suffix)):
            continue
        metrics_name = metrics_file[len(metrics_prefix):-len(metrics_suffix)]
        l2o_opter_dict["metrics"][metrics_name] = np.load(
            os.path.join(config["ckpt_base_dir"], metrics_file), allow_pickle=True
        ).item()

    l2o_opters.append(l2o_opter_dict)

In [None]:
### plot comparison
fig = plt.figure()
ax = fig.add_subplot(111)


def _mean_curve(runs, metric, show_max_iters, eval_iter_freq):
    """Return ``(x, mean over the first axis)`` of a 2-D metric array up to ``show_max_iters``.

    Test metrics (``"test" in metric``) are treated as one point every ``eval_iter_freq``
    iterations (x = freq, 2*freq, ...); other metrics as one point per iteration (x = 0, 1, ...).
    x is cut to the number of points actually available.
    """
    if "test" in metric:
        runs = runs[:, :show_max_iters // eval_iter_freq]
        x = np.arange(1, runs.shape[1] + 1) * eval_iter_freq
    else:
        runs = runs[:, :show_max_iters]
        x = np.arange(runs.shape[1])
    return x, np.mean(runs, axis=0)


### baseline optimizers
# No baseline config is loaded in this section: assume the baselines were evaluated as often as
# the last loaded L2O (10 if no L2O is loaded, the value previously hard-coded for baselines).
baseline_eval_iter_freq = (
    l2o_opters[-1]["config"]["meta_training"]["eval_iter_freq"] if l2o_opters else 10
)
for opter_name, opter_metrics in baseline_metrics.items():
    x, y = _mean_curve(opter_metrics[metric], metric, show_max_iters, baseline_eval_iter_freq)
    sns.lineplot(
        x=x,
        y=y,
        label=opter_name,
        linestyle="--",
        ax=ax,
    )

### L2O optimizers
for l2o_opter_dict in l2o_opters:
    metrics = l2o_opter_dict["metrics"][run_nickname]
    config = l2o_opter_dict["config"]
    x, y = _mean_curve(metrics[metric], metric, show_max_iters, config["meta_training"]["eval_iter_freq"])

    # NOTE(review): reg_mul is labelled $\alpha$ here but $\beta$ elsewhere in this notebook
    sns.lineplot(
        x=x,
        y=y,
        label=fr"L2O, $\alpha$={config['meta_training']['reg_mul']}",
        linewidth=1,
        ax=ax,
    )

### plot settings
ax.set_xlabel("Iteration")
metric_as_label = {
    "train_loss": "Train Loss",
    "test_loss": "Test Loss",
    "train_acc": "Train Accuracy",
    "test_acc": "Test Accuracy",
}.get(metric, metric)
ax.set_ylabel(metric_as_label)

# set y to log scale
if log_loss and "loss" in metric:
    ax.set_yscale("log")

if "acc" in metric:
    ax.set_ylim(0.6, 1.0)
elif log_loss is not True:
    ax.set_ylim(0.0, None)

ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.25), ncol=3)
# legend = ax.get_legend()
# for legend_handle in legend.legendHandles:
#     legend_handle.set_linewidth(3.0)

# x-ticks: 0, show_max_iters / 2, show_max_iters
ax.set_xticks(np.linspace(0, show_max_iters, 3))

### y-ticks
if log_loss is not True:
    y_top = np.ceil(ax.get_ylim()[1] / 0.5) * 0.5  # round the upper limit up to a multiple of 0.5
    ax.set_yticks(np.linspace(0, y_top, 3))

### save the figure (before showing: some backends release the figure on show())
if fig_dir is not None:
    fig.savefig(os.path.join(fig_dir, fig_name), bbox_inches="tight")

plt.show()

## Params, Grads, Updates
Plot the L2 norm and the mean absolute value of the parameters, gradients, and parameter updates over the course of training (mean across test runs, shaded band: ±1 standard deviation).

In [None]:
### config
# statistics are read from checkpoints every `ckpt_iter_freq` iterations, up to `max_iters`
ckpt_iter_freq, max_iters = 10, 500

In [None]:
### baseline optimizers
# For every baseline and test run, reload the checkpoints at iterations ckpt_iter_freq, 2*ckpt_iter_freq,
# ..., max_iters and record, per trainable parameter tensor, the L2 norm and mean absolute value of the
# parameter, its gradient and the optimizer's update.
# Result: baselines[name]["params_grads"][stat][param_name] -> np.array (one row per test run).
params_grads_keys = ["param_norms", "param_abs_means", "grad_norms", "grad_abs_means", "updates_norms", "updates_abs_means"]

for baseline_name, baseline_entry in baselines.items():
    params_grads = {k: {} for k in params_grads_keys}
    baseline_entry["params_grads"] = params_grads
    config = baseline_entry["config"]

    # the iterations read here must be multiples of the checkpointing frequency
    assert ckpt_iter_freq % config["meta_testing"]["ckpt_iter_freq"] == 0

    # `baseline_dir` already starts with $CKPT_PATH/baselines (see load_baselines); joining it onto
    # that prefix again only resolved correctly when $CKPT_PATH was an absolute path.
    ckpt_dir = os.path.join(baseline_entry["baseline_dir"], "ckpt")

    ### collect for all test runs
    for test_run_i in range(config["eval_n_tests"]):
        run_stats = {k: {} for k in params_grads_keys}

        for iter_i in range(ckpt_iter_freq, max_iters + 1, ckpt_iter_freq):
            optee, opter, _, _ = load_baseline_opter_ckpt(
                path=os.path.join(ckpt_dir, f"run{test_run_i}_{iter_i}.pt"),
                optee_cls=config["meta_testing"]["optee_cls"],
                opter_cls=config["meta_testing"]["baseline_opter_cls"],
                optee_config=config["meta_testing"]["optee_config"],
                opter_config=config["meta_testing"]["baseline_opter_config"],
            )
            optee_updates = get_baseline_opter_param_updates(optee, opter)

            for n, p in optee.all_named_parameters():
                if not p.requires_grad:
                    continue
                stats = {
                    "param_norms": p.norm().item(),
                    "param_abs_means": p.abs().mean().item(),
                    "grad_norms": p.grad.norm().item(),
                    "grad_abs_means": p.grad.abs().mean().item(),
                    "updates_norms": optee_updates[n].norm().item(),
                    "updates_abs_means": optee_updates[n].abs().mean().item(),
                }
                for k, value in stats.items():
                    run_stats[k].setdefault(n, []).append(value)

        ### add to all test runs
        for k, per_param in run_stats.items():
            for n, trace in per_param.items():
                params_grads[k].setdefault(n, []).append(trace)

    ### convert to np arrays
    for per_param in params_grads.values():
        for n in per_param:
            per_param[n] = np.array(per_param[n])

In [None]:
### l2o optimizers
# Same per-parameter statistics as for the baselines, for each meta-trained L2O optimizer. Updates
# are multiplied by the meta-testing `optee_updates_lr`; L2Os with statistics already collected
# are skipped so this cell can be re-run cheaply.
for l2o_name, l2o_entry in l2os.items():
    config = l2o_entry["config"]
    assert ckpt_iter_freq % config["meta_testing"]["ckpt_iter_freq"] == 0
    if "params_grads" in l2o_entry and l2o_entry["params_grads"]["param_norms"]:
        print(f"Skipping {l2o_name}")
        continue

    stat_names = ("param_norms", "param_abs_means", "grad_norms", "grad_abs_means", "updates_norms", "updates_abs_means")
    collected = {stat: {} for stat in stat_names}
    l2o_entry["params_grads"] = collected

    ### collect for all test runs
    for test_run_i in range(config["eval_n_tests"]):
        run_collected = {stat: {} for stat in stat_names}

        for iter_i in range(ckpt_iter_freq, max_iters + 1, ckpt_iter_freq):
            ### load L2O optimizer state at this iteration
            ckpt_path = os.path.join(os.environ["CKPT_PATH"], config["meta_testing"]["ckpt_dir"], f"run{test_run_i}_{iter_i}.pt")
            optee, opter, optee_grads, optee_updates, loss_history = load_l2o_opter_ckpt(
                path=ckpt_path,
                optee_cls=config["meta_testing"]["optee_cls"],
                opter_cls=config["opter_cls"],
                optee_config=config["meta_testing"]["optee_config"],
                opter_config=config["opter_config"],
            )
            # the stored updates are multiplied by the optimizee update lr
            optee_updates = {name: upd * config["meta_testing"]["optee_updates_lr"] for name, upd in optee_updates.items()}

            for n, p in optee.all_named_parameters():
                if not p.requires_grad:
                    continue
                values = {
                    "param_norms": p.norm().item(),
                    "param_abs_means": p.abs().mean().item(),
                    "grad_norms": p.grad.norm().item(),
                    "grad_abs_means": p.grad.abs().mean().item(),
                    "updates_norms": optee_updates[n].norm().item(),
                    "updates_abs_means": optee_updates[n].abs().mean().item(),
                }
                for stat, value in values.items():
                    run_collected[stat].setdefault(n, []).append(value)

        ### add this run's traces to the per-L2O collection
        for stat, per_param in run_collected.items():
            for n, trace in per_param.items():
                collected[stat].setdefault(n, []).append(trace)

    ### convert to np arrays
    for per_param in collected.values():
        for n in per_param:
            per_param[n] = np.array(per_param[n])

In [None]:
### config for plotting
# every (quantity, statistic) pair: param/grad/updates x norms/abs_means, quantity-major order
params_grads_metrics = [f"{quantity}_{stat}" for quantity in ("param", "grad", "updates") for stat in ("norms", "abs_means")]
save_to_dir = "../results/sym_breaking_regularization/MNISTReluBatchNorm_meta_training"

In [None]:
def smooth(y, box_pts=10):
    """Smooth a 1-D series with a centered moving average (box filter).

    Each output value is the mean of the samples that actually fall inside the
    window. The ends of the series are therefore not pulled towards zero, as
    they are with a plain ``np.convolve(..., mode="same")``. The output always
    has ``len(y)`` entries; the window is shrunk to ``len(y)`` for short inputs.

    Args:
        y: 1-D array-like of values.
        box_pts: Window width in samples (must be >= 1).

    Returns:
        np.ndarray of floats with the same length as ``y``.

    Raises:
        ValueError: if ``box_pts`` < 1.
    """
    if box_pts < 1:
        raise ValueError(f"box_pts must be >= 1, got {box_pts}")
    y = np.asarray(y, dtype=float)
    if y.size == 0:
        return y.copy()

    # mode="same" returns max(len(y), width) samples, so clamp the window to keep len(y)
    width = min(box_pts, y.size)
    window = np.ones(width)
    window_sums = np.convolve(y, window, mode="same")
    # number of real samples under the window at each position (smaller at the edges)
    window_counts = np.convolve(np.ones(y.size), window, mode="same")
    return window_sums / window_counts

def _plot_mean_std_band(ax, runs, x_ticks, label, **line_kwargs):
    """Plot the mean over test runs of ``runs`` with a shaded +/-1 std band.

    ``runs`` is indexed ``[test_run, checkpoint]``. Only the first
    ``len(x_ticks)`` checkpoints are plotted, and the x values are trimmed when
    fewer checkpoints were recorded.
    """
    runs = np.asarray(runs)[:, :len(x_ticks)]
    y_mean = runs.mean(axis=0)
    y_std = runs.std(axis=0)
    xs = x_ticks[:len(y_mean)]
    sns.lineplot(x=xs, y=y_mean, alpha=0.8, linewidth=1.5, ax=ax, label=label, **line_kwargs)
    ax.fill_between(x=xs, y1=y_mean - y_std, y2=y_mean + y_std, alpha=0.2)


### one figure per metric, one subplot per optimizee parameter
if save_to_dir is not None:
    os.makedirs(save_to_dir, exist_ok=True)  # savefig does not create missing directories

# iterations at which checkpoints were evaluated
x_ticks = np.arange(ckpt_iter_freq, max_iters + 1, ckpt_iter_freq)

for k in params_grads_metrics:
    # the SGD baseline defines which parameters get a subplot
    param_names = list(baselines["SGD"]["params_grads"][k])
    n_rows = max(3, math.ceil(len(param_names) / 2))  # 3x2 grid, grown when there are > 6 params
    fig = plt.figure(figsize=(22, 20))

    for i, n in enumerate(param_names):
        ax = fig.add_subplot(n_rows, 2, i + 1)

        for baseline_name, baseline_dict in baselines.items():
            if n in baseline_dict["params_grads"][k]:
                _plot_mean_std_band(ax, baseline_dict["params_grads"][k][n], x_ticks, baseline_name, linestyle="--")

        for l2o_name, l2o_dict in l2os.items():
            if n in l2o_dict["params_grads"][k]:
                _plot_mean_std_band(ax, l2o_dict["params_grads"][k][n], x_ticks, l2o_name)

        ax.set_xlabel("Iteration")
        ax.set_title(k.replace("_", " ").title() + ": " + n, fontsize=13, fontweight="bold")
        ax.legend()

    # lay out before showing so the displayed and saved figures look the same
    fig.tight_layout()
    if save_to_dir is not None:
        fig.savefig(os.path.join(save_to_dir, f"{k}.png"), bbox_inches="tight")
    plt.show()

## Conservation Law Breaking

In [None]:
phase = "meta_testing"
max_iters = 500
reg_func = regularize_translation_conservation_law_breaking


def _resolve_baseline_opter_cls(baseline_name, config):
    """Return the torch optimizer class used for the baseline ``baseline_name``.

    Checks ``config["baseline_opter_cls"]``, then
    ``config["meta_testing"]["baseline_opter_cls"]``, and finally infers SGD/Adam
    from the baseline's name.

    Raises:
        NotImplementedError: if none of the above identifies the class.
    """
    if "baseline_opter_cls" in config:
        return config["baseline_opter_cls"]
    if "baseline_opter_cls" in config.get("meta_testing", {}):
        return config["meta_testing"]["baseline_opter_cls"]
    lowered = baseline_name.lower()
    if "sgd" in lowered:
        return optim.SGD
    if "adam" in lowered:
        return optim.Adam
    raise NotImplementedError(f"Cannot infer the optimizer class for baseline {baseline_name!r}")


### L2O optimizers: collect deviations
for l2o_name in l2os:
    config = l2os[l2o_name]["config"]
    l2os[l2o_name][reg_func.__name__] = np.array(collect_conservation_law_deviations(
        func=reg_func,
        opter_cls=config["opter_cls"],
        opter_config=config["opter_config"],
        optee_cls=config["meta_testing"]["optee_cls"],
        optee_config=config["meta_testing"]["optee_config"],
        ckpt_iter_freq=config["meta_testing"]["ckpt_iter_freq"],
        n_iters=config["meta_testing"]["n_iters"],
        ckpt_path_prefix=os.path.join(os.environ["CKPT_PATH"], config["meta_testing"]["ckpt_dir"], ""),
        is_l2o=True,
        max_iters=max_iters,
    ))

### Baseline optimizers: collect deviations
for baseline_name in baselines:
    config = baselines[baseline_name]["config"]
    baselines[baseline_name][reg_func.__name__] = np.array(collect_conservation_law_deviations(
        func=reg_func,
        # the optimizer *class*; its keyword config goes in opter_config
        opter_cls=_resolve_baseline_opter_cls(baseline_name, config),
        opter_config=config["meta_testing"]["baseline_opter_config"],
        optee_cls=config["meta_testing"]["optee_cls"],
        optee_config=config["meta_testing"]["optee_config"],
        ckpt_iter_freq=config["meta_testing"]["ckpt_iter_freq"],
        n_iters=config["meta_testing"]["n_iters"],
        ckpt_path_prefix=os.path.join(os.environ["CKPT_PATH"], baselines[baseline_name]["baseline_dir"], "ckpt/"),
        is_l2o=False,
        max_iters=max_iters,
    ))

## Breaking Geometric Constraints on Gradients

In [None]:
phase = "meta_testing"
max_iters = None
collect_func = collect_translation_sym_deviations


def _resolve_baseline_opter_cls(baseline_name, config):
    """Return the torch optimizer class used for the baseline ``baseline_name``.

    Checks ``config["baseline_opter_cls"]``, then
    ``config["meta_testing"]["baseline_opter_cls"]``, and finally infers SGD/Adam
    from the baseline's name.

    Raises:
        NotImplementedError: if none of the above identifies the class.
    """
    if "baseline_opter_cls" in config:
        return config["baseline_opter_cls"]
    if "baseline_opter_cls" in config.get("meta_testing", {}):
        return config["meta_testing"]["baseline_opter_cls"]
    lowered = baseline_name.lower()
    if "sgd" in lowered:
        return optim.SGD
    if "adam" in lowered:
        return optim.Adam
    raise NotImplementedError(f"Cannot infer the optimizer class for baseline {baseline_name!r}")


def _sum_over_param_groups(grad_deviations, param_update_deviations):
    """Sum 2-D deviations over their last axis (weight and bias); return 1-D input unchanged."""
    if np.ndim(grad_deviations) == 2:
        return grad_deviations.sum(-1), param_update_deviations.sum(-1)
    return grad_deviations, param_update_deviations


grads_key = collect_func.__name__ + "_grads"
updates_key = collect_func.__name__ + "_updates"

### L2O optimizers
for l2o_name in l2os:
    config = l2os[l2o_name]["config"]
    l2os[l2o_name][grads_key] = []
    l2os[l2o_name][updates_key] = []
    for test_run_i in range(config["eval_n_tests"]):
        grad_deviations, param_update_deviations = collect_func(
            ckpt_iter_freq=config["meta_testing"]["ckpt_iter_freq"],
            n_iters=config["meta_testing"]["n_iters"],
            optee_cls=config["meta_testing"]["optee_cls"],
            opter_cls=config["opter_cls"],
            optee_config=config["meta_testing"]["optee_config"],
            opter_config=config["opter_config"],
            phase=phase,
            ckpt_path_prefix=os.path.join(os.environ["CKPT_PATH"], config["meta_testing"]["ckpt_dir"], f"run{test_run_i}_"),
            max_iters=max_iters,
        )
        grad_deviations, param_update_deviations = _sum_over_param_groups(grad_deviations, param_update_deviations)
        l2os[l2o_name][grads_key].append(grad_deviations)
        l2os[l2o_name][updates_key].append(param_update_deviations)

### Baseline optimizers
for baseline_name in baselines:
    config = baselines[baseline_name]["config"]
    baseline_opter_cls = _resolve_baseline_opter_cls(baseline_name, config)
    baseline_opter_config = config["meta_testing"]["baseline_opter_config"]

    baselines[baseline_name][grads_key] = []
    baselines[baseline_name][updates_key] = []
    for test_run_i in range(config["eval_n_tests"]):
        grad_deviations, param_update_deviations = collect_func(
            ckpt_iter_freq=config["meta_testing"]["ckpt_iter_freq"],
            n_iters=config["meta_testing"]["n_iters"],
            optee_cls=config["meta_testing"]["optee_cls"],
            optee_config=config["meta_testing"]["optee_config"],
            opter_cls=baseline_opter_cls,
            opter_config=baseline_opter_config,
            phase=phase,
            # NOTE(review): unlike the conservation-law cell, this prefix is not joined with
            # os.environ["CKPT_PATH"]; correct only if baseline_dir is absolute or cwd-relative — confirm
            ckpt_path_prefix=os.path.join(baselines[baseline_name]["baseline_dir"], f"ckpt/run{test_run_i}_"),
            max_iters=max_iters,
        )
        grad_deviations, param_update_deviations = _sum_over_param_groups(grad_deviations, param_update_deviations)
        baselines[baseline_name][grads_key].append(grad_deviations)
        baselines[baseline_name][updates_key].append(param_update_deviations)

## Plotting - Deviations

In [None]:
def plot_deviations(dict_key_to_plot, log=False, abs_values=False, max_iters=None, save_fig_to_path=None,
                    n_iters=None, ckpt_iter_freq=None):
    """Plot the mean +/- std over test runs of one collected deviation, for every L2O and baseline.

    L2O optimizers are drawn with solid lines and baselines with dashed lines.

    Args:
        dict_key_to_plot: Key under which each entry of the global ``l2os`` and
            ``baselines`` dicts stores its deviations, indexed ``[test_run, checkpoint]``.
        log: Use a logarithmic y-axis.
        abs_values: Plot absolute deviations.
        max_iters: Last iteration to show; None shows all ``n_iters``.
        save_fig_to_path: If given, the figure is also saved to this path.
        n_iters: Meta-testing length. Defaults to the notebook-global
            ``config["meta_testing"]["n_iters"]``.
        ckpt_iter_freq: Checkpoint spacing. Defaults to the notebook-global
            ``config["meta_testing"]["ckpt_iter_freq"]``.
    """
    if n_iters is None:
        n_iters = config["meta_testing"]["n_iters"]
    if ckpt_iter_freq is None:
        ckpt_iter_freq = config["meta_testing"]["ckpt_iter_freq"]
    last_iter = n_iters if max_iters is None else min(max_iters, n_iters)
    x_ticks = np.arange(0, last_iter + 1, ckpt_iter_freq)

    fig = plt.figure()
    ax = fig.add_subplot(111)

    def _draw(deviations, label, **line_kwargs):
        # (n_test_runs, n_checkpoints) -> mean line with a +/-1 std band
        values = np.array(np.abs(deviations) if abs_values else deviations)[:, :len(x_ticks)]
        mean, std = values.mean(0), values.std(0)
        xs = x_ticks[:values.shape[1]]  # tolerate runs with fewer recorded checkpoints
        sns.lineplot(x=xs, y=mean, label=label, ax=ax, **line_kwargs)
        ax.fill_between(x=xs, y1=mean - std, y2=mean + std, alpha=0.2)

    ### L2O optimizers
    for l2o_name, l2o_dict in l2os.items():
        _draw(l2o_dict[dict_key_to_plot], l2o_name)

    ### baseline optimizers
    for baseline_name, baseline_dict in baselines.items():
        _draw(baseline_dict[dict_key_to_plot], baseline_name, linestyle="--", linewidth=1.5)

    ### plot settings
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Absolute Deviation" if abs_values else "Deviation")
    if log:
        ax.set_yscale("log")

    ax.legend(loc="upper right")
    legend = ax.get_legend()
    # `legendHandles` was renamed `legend_handles` in matplotlib 3.7
    handles = getattr(legend, "legend_handles", None)
    if handles is None:
        handles = legend.legendHandles
    for legend_handle in handles:
        legend_handle.set_linewidth(3.0)

    # three evenly spaced x-ticks over the plotted range (works for max_iters=None too)
    ax.set_xticks(np.linspace(0, last_iter, 3))

    ### save before showing so the saved file matches what is displayed
    if save_fig_to_path is not None:
        fig.savefig(save_fig_to_path, bbox_inches="tight")
    plt.show()

In [None]:
# deviations of the parameter updates collected above
# (use `reg_func.__name__` instead to plot the conservation-law deviations)
key_to_plot = f"{collect_func.__name__}_updates"

# NOTE(review): the "{}" placeholder in this file name is never filled in; format it before enabling saving
save_fig_to_path = os.path.join("..", "results", "publication", "constraint_deviations", "translation_sym_meta_testing_MNISTNet_{}_200.pdf")
save_fig_to_path = None  # saving disabled; remove this line to save

plot_deviations(key_to_plot, log=False, abs_values=True, max_iters=200, save_fig_to_path=save_fig_to_path)

## Spectral Analysis

In [None]:
### collect per-checkpoint LSTM state averages for each L2O (first test run only)
for opter_name in l2os:
    print(f"Running {opter_name}...")
    config = l2os[opter_name]["config"]
    meta_testing_cfg = config["meta_testing"]
    ckpts_dir = os.path.join(os.environ["CKPT_PATH"], meta_testing_cfg["ckpt_dir"])
    run_history_all_test_runs = []  # NOTE(review): never filled in this cell

    # use range(config["eval_n_tests"]) to cover every test run
    for test_run_i in range(1):
        print(f"  Test run {test_run_i}...")
        ckpt_prefix = ""  # per-run checkpoints would use f"run{test_run_i}_"

        history = {
            "hidden_rnn_1": [],
            "hidden_rnn_2": [],
            "cell_rnn_1": [],
            "cell_rnn_2": [],
            "updates": [],  # NOTE(review): nothing is appended to this key here
        }
        freq = meta_testing_cfg["ckpt_iter_freq"]
        for iter_i in range(freq, meta_testing_cfg["n_iters"] + 1, freq):
            ckpt_path = os.path.join(ckpts_dir, f"{ckpt_prefix}{iter_i}.pt")
            ckpt = torch.load(ckpt_path, map_location="cpu")

            hidden_states = ckpt["hidden_states"]  # (2, num_params, hidden_size)
            cell_states = ckpt["cell_states"]  # (2, num_params, hidden_size)
            optee_updates = ckpt["optimizee_updates"]

            # average each LSTM layer's state over dim 0 (the optimizee parameters, per the shapes above)
            for layer_i in (0, 1):
                history[f"hidden_rnn_{layer_i + 1}"].append(hidden_states[layer_i].mean(dim=0).tolist())
            for layer_i in (0, 1):
                history[f"cell_rnn_{layer_i + 1}"].append(cell_states[layer_i].mean(dim=0).tolist())

        history = {key: np.array(values) for key, values in history.items()}

In [None]:
from scipy import signal

to_plot = "hidden_rnn_1"
nperseg = 20
noverlap = 19
nfft = 20

### one spectrogram per LSTM channel of the chosen state history
n_channels = history[to_plot].shape[-1]
n_cols = 4
n_rows = max(5, math.ceil(n_channels / n_cols))  # 5x4 grid, grown when there are > 20 channels

fig = plt.figure(figsize=(15, 15))

for channel_idx in range(n_channels):
    ax = fig.add_subplot(n_rows, n_cols, channel_idx + 1)
    f, t, Sxx = signal.spectrogram(
        x=history[to_plot][:, channel_idx],
        fs=1.0,  # one sample per checkpoint
        window="hann",
        nperseg=nperseg,
        noverlap=noverlap,
        nfft=nfft,
    )
    ax.pcolormesh(t, f, Sxx, shading="gouraud")
    # with fs=1.0, frequencies are in cycles per checkpoint and times in checkpoint samples
    ax.set_ylabel("Frequency [cycles/checkpoint]")
    ax.set_xlabel("Checkpoint")
    ax.set_title(f"Channel {channel_idx}")

fig.tight_layout()
plt.show()

## Heavy-tail gradient/update noise
- From the [paper](http://proceedings.mlr.press/v97/simsekli19a/simsekli19a.pdf) *A Tail-Index Analysis of Stochastic Gradient Noise in Deep Neural Networks. U. Simsekli, L. Sagun, M. Gurbuzbalaban. In Proceedings of the 36th International Conference on Machine Learning (ICML), 2019.*
- [GitHub repository](https://github.com/umutsimsekli/sgd_tail_index)

In [None]:
def _largest_divisor_at_most_sqrt(n_total):
    """Return the largest divisor of ``n_total`` that does not exceed sqrt(n_total).

    Used as the block size ``m`` passed to ``alpha_estimator`` together with the
    ``n_total`` flattened noise samples. The divisibility requirement presumably
    lets the samples be split into equal blocks; confirm in l2o.tail_index_utils.

    Raises:
        ValueError: if ``n_total`` is not a positive integer.
    """
    for candidate in range(math.isqrt(n_total), 0, -1):
        if n_total % candidate == 0:
            return candidate
    raise ValueError(f"n_total must be a positive integer, got {n_total}")


def _noise_norm_and_alpha(samples):
    """Return the per-row noise norms and the tail-index estimate for ``samples``.

    ``samples`` is an (n_minibatches, n_params) tensor. The noise is each row
    minus the mean row.
    """
    noise = samples - samples.mean(dim=0)
    m = _largest_divisor_at_most_sqrt(noise.numel())
    return noise.norm(dim=1), alpha_estimator(m, noise.view(-1, 1))


def _l2o_param_updates(opter, trainable, hidden_states, cell_states, optee_updates_lr):
    """Return the flattened, lr-scaled L2O update for the current gradients of ``trainable``.

    Each parameter's gradient is fed to ``opter`` as a (numel, 1) column together
    with its consecutive slice of every hidden/cell state tensor. Slices are
    taken in the order of ``trainable``.
    """
    chunks = []
    offset = 0
    for p in trainable:
        size = p.numel()
        updates, _, _ = opter(
            optee_grads=p.grad.detach().view(size, 1),
            hidden=[h[offset:offset + size] for h in hidden_states],
            cell=[c[offset:offset + size] for c in cell_states],
            additional_inp=None,
        )
        offset += size
        chunks.append(optee_updates_lr * updates.detach().view(-1))
    return torch.cat(chunks).cpu()


def eval_metrics(optee, data_loader, opter=None, hidden_states=None, cell_states=None, optee_updates_lr=None):
    """Measure loss/accuracy and the gradient (and update) noise of ``optee`` over ``data_loader``.

    For every minibatch the loss is back-propagated, and the gradients of all
    trainable optimizee parameters are flattened into one vector. If ``opter`` is
    given, the update it would produce for that gradient is recorded as well:
    - an L2O ``Optimizer`` is queried parameter by parameter with slices of
      ``hidden_states``/``cell_states``, and its output is scaled by ``optee_updates_lr``;
    - any other optimizer goes through ``get_baseline_opter_param_updates``.

    Args:
        optee: Optimizee module, put into eval mode. Must expose ``all_named_parameters()``
            and accept ``(inp=, out=, return_acc=True)``.
        data_loader: Iterable of ``(x, y)`` minibatches; ``x`` is flattened per sample.
        opter: Optional L2O ``Optimizer`` or baseline torch optimizer.
        hidden_states, cell_states: L2O LSTM states, sliced along dim 0 per
            parameter. Required for an L2O ``opter``.
        optee_updates_lr: Scale applied to the L2O output. Required for an L2O ``opter``.

    Returns:
        ``(losses, accs, grads_noise_norm, grads_alpha, updates_noise_norm, updates_alpha)``:
        - ``losses``, ``accs``: per-minibatch loss and accuracy lists;
        - ``grads_noise_norm``, ``grads_alpha``: per-minibatch norm of the centered
          gradients and their tail-index estimate;
        - ``updates_noise_norm``, ``updates_alpha``: the same for the updates, or
          ``None`` when ``opter`` is None.

    Raises:
        ValueError: if ``data_loader`` is empty, ``optee`` has no trainable
            parameters, or an L2O ``opter`` lacks states / ``optee_updates_lr``.
    """
    is_l2o = isinstance(opter, Optimizer)
    if is_l2o and (hidden_states is None or cell_states is None or optee_updates_lr is None):
        raise ValueError("An L2O opter requires hidden_states, cell_states and optee_updates_lr")

    optee.eval()
    # tensors that don't require grad (e.g. batchnorm running stats) are skipped throughout
    trainable = [p for _, p in optee.all_named_parameters() if p.requires_grad]
    if not trainable:
        raise ValueError("optee has no trainable parameters")
    device = trainable[0].device

    grads, param_updates, losses, accs = [], [], [], []
    for x, y in data_loader:
        # reset exactly the gradients that are read below, so minibatches don't accumulate
        for p in trainable:
            p.grad = None

        x, y = x.view(x.size(0), -1).to(device), y.to(device)
        loss, acc = optee(inp=x, out=y, return_acc=True)
        loss.backward()

        ### collect gradients
        grads.append(torch.cat([p.grad.detach().view(-1) for p in trainable]).cpu())

        ### collect updates
        if is_l2o:
            param_updates.append(_l2o_param_updates(opter, trainable, hidden_states, cell_states, optee_updates_lr))
        elif opter is not None:
            param_updates.append(
                torch.cat([
                    p.detach().view(-1).cpu() for p in get_baseline_opter_param_updates(optee, opter, verbose=False).values()
                ], dim=0)
            )

        ### track history
        losses.append(loss.item())
        accs.append(acc.item())

    if not grads:
        raise ValueError("data_loader yielded no minibatches")

    ### gradients: (n_minibatches, optee_total_params)
    grads_noise_norm, grads_alpha = _noise_norm_and_alpha(torch.stack(grads, dim=0))

    ### updates: (n_minibatches, optee_total_params)
    updates_noise_norm, updates_alpha = None, None
    if opter is not None:
        updates_noise_norm, updates_alpha = _noise_norm_and_alpha(torch.stack(param_updates, dim=0))

    return (
        losses,
        accs,
        grads_noise_norm,
        grads_alpha,
        updates_noise_norm,
        updates_alpha,
    )

In [None]:
def collect_grad_update_noise_for_l2o(ckpt_path, run_history, config, train_data, test_data):
    """Append loss/accuracy/noise statistics of one L2O checkpoint to ``run_history``.

    Restores the optimizee and the L2O optimizer from ``ckpt_path``. It then
    runs ``eval_metrics`` on the train loader and then the test loader, using the
    checkpoint's LSTM states and ``config["meta_testing"]["optee_updates_lr"]``.

    Args:
        ckpt_path: Checkpoint with "optimizee", "optimizer", "hidden_states" and "cell_states".
        run_history: ``{"train"|"test": {metric_name: list}}``; extended in place.
        config: L2O experiment config.
        train_data, test_data: Objects exposing a ``.loader``.

    Returns:
        The same ``run_history`` object.
    """
    ckpt = torch.load(ckpt_path)
    meta_cfg = config["meta_testing"]

    ### restore optimizee
    optee = meta_cfg["optee_cls"](**meta_cfg["optee_config"]).cuda()
    optee.load_state_dict(ckpt["optimizee"])

    ### restore the L2O optimizer so the noise in its updates can be measured
    opter = config["opter_cls"](**config["opter_config"]).cuda()
    opter.load_state_dict(ckpt["optimizer"])

    loaders = {"train": train_data.loader, "test": test_data.loader}
    for phase, data_loader in loaders.items():
        results = eval_metrics(
            optee=optee,
            data_loader=data_loader,
            opter=opter,
            hidden_states=ckpt["hidden_states"],
            cell_states=ckpt["cell_states"],
            optee_updates_lr=meta_cfg["optee_updates_lr"],
        )
        losses, accs, grads_noise_norm, grads_alpha, updates_noise_norm, updates_alpha = results

        phase_history = run_history[phase]
        phase_history["loss"].append(np.mean(losses))
        phase_history["acc"].append(np.mean(accs))
        phase_history["grads_noise_norm"].append(grads_noise_norm)
        phase_history["grads_alpha"].append(grads_alpha.item())
        phase_history["updates_noise_norm"].append(updates_noise_norm)
        phase_history["updates_alpha"].append(updates_alpha.item())

    return run_history

In [None]:
load_existing = True
max_n_tests = 3
ckpt_iter_freq = 10
max_iters = 1000


def _ckpt_iters_to_eval(ckpt_iter_freq, last_iter):
    """Return the checkpoint iterations to evaluate: 1, then every ``ckpt_iter_freq`` up to ``last_iter``.

    Duplicates are removed, so ``ckpt_iter_freq == 1`` does not evaluate iteration 1 twice.
    """
    return sorted({1, *range(ckpt_iter_freq, last_iter + 1, ckpt_iter_freq)})


# the evaluation datasets don't depend on the optimizer or the test run: built once, on first use
train_data = test_data = None

### collect for L2Os
for opter_name in l2os:
    print(f"Running {opter_name}...")
    config = l2os[opter_name]["config"]
    # NOTE: this overrides n_iters in the shared config dict, so later cells also see 1000
    config["meta_testing"]["n_iters"] = 1000
    ckpts_dir = os.path.join(os.environ["CKPT_PATH"], config["meta_testing"]["ckpt_dir"])
    # NOTE(review): the file name records eval_n_tests even though at most max_n_tests runs are collected
    save_run_history_path = os.path.join(
        os.environ["CKPT_PATH"],
        config["ckpt_base_dir"],
        "grads_updates_noise_heavy_tail_alpha_estimates"
        + f"_{config['meta_testing']['optee_cls'].__name__}_{dict_to_str(config['meta_testing']['optee_config'])}"
        + f"_{config['meta_testing']['data_cls'].__name__}_{dict_to_str(config['meta_testing']['data_config'])}"
        + f"_{config['eval_n_tests']}_tests.pt"
    )

    if load_existing and os.path.exists(save_run_history_path):
        print(f"  Loading existing run history from {save_run_history_path}")
        run_history = torch.load(save_run_history_path)
        l2os[opter_name]["run_history"] = run_history
        continue

    last_iter = min(max_iters, config["meta_testing"]["n_iters"])
    n_tests = config["eval_n_tests"]
    if max_n_tests:
        n_tests = min(n_tests, max_n_tests)

    run_history_all_test_runs = []
    for test_run_i in range(n_tests):
        print(f"  Test run {test_run_i}...")
        ckpt_prefix = f"run{test_run_i}_"

        ### collect
        run_history = {
            split: {metric: [] for metric in ("loss", "acc", "grads_noise_norm", "grads_alpha", "updates_noise_norm", "updates_alpha")}
            for split in ("train", "test")
        }
        if train_data is None:
            train_data = MNIST(training=True, batch_size=128)
            test_data = MNIST(training=False, batch_size=128)

        ### evaluate the selected checkpoints of this test run
        for iter_i in _ckpt_iters_to_eval(ckpt_iter_freq, last_iter):
            print(f"  [{iter_i}/{last_iter}]")
            ckpt_path = os.path.join(ckpts_dir, f"{ckpt_prefix}{iter_i}.pt")
            run_history = collect_grad_update_noise_for_l2o(
                ckpt_path=ckpt_path,
                run_history=run_history,
                config=config,
                train_data=train_data,
                test_data=test_data,
            )

        ### save current test run
        run_history_all_test_runs.append(run_history)

    ### save
    l2os[opter_name]["run_history"] = run_history_all_test_runs
    torch.save(run_history_all_test_runs, save_run_history_path)

In [None]:
def collect_grad_update_noise_for_normal(ckpt_path, run_history, config, train_data, test_data):
    """Append loss/accuracy/noise statistics of one baseline-optimizer checkpoint to ``run_history``.

    Restores the optimizee and the baseline torch optimizer from ``ckpt_path``.
    It then runs ``eval_metrics`` (without LSTM states) on the train loader and
    then the test loader.

    Args:
        ckpt_path: Checkpoint with "optimizee" and "optimizer" state dicts.
        run_history: ``{"train"|"test": {metric_name: list}}``; extended in place.
        config: Baseline experiment config; ``config["meta_testing"]`` provides the
            optimizee and optimizer classes and their configs.
        train_data, test_data: Objects exposing a ``.loader``.

    Returns:
        The same ``run_history`` object.
    """
    ckpt = torch.load(ckpt_path)
    meta_cfg = config["meta_testing"]

    ### restore optimizee
    optee = meta_cfg["optee_cls"](**meta_cfg["optee_config"]).cuda()
    optee.load_state_dict(ckpt["optimizee"])

    ### restore the baseline optimizer on the optimizee's parameters
    opter = meta_cfg["baseline_opter_cls"](optee.parameters(), **meta_cfg["baseline_opter_config"])
    opter.load_state_dict(ckpt["optimizer"])

    loaders = {"train": train_data.loader, "test": test_data.loader}
    for phase, data_loader in loaders.items():
        results = eval_metrics(
            optee=optee,
            data_loader=data_loader,
            opter=opter,
            hidden_states=None,
            cell_states=None,
        )
        losses, accs, grads_noise_norm, grads_alpha, updates_noise_norm, updates_alpha = results

        phase_history = run_history[phase]
        phase_history["loss"].append(np.mean(losses))
        phase_history["acc"].append(np.mean(accs))
        phase_history["grads_noise_norm"].append(grads_noise_norm)
        phase_history["grads_alpha"].append(grads_alpha.item())
        phase_history["updates_noise_norm"].append(updates_noise_norm)
        phase_history["updates_alpha"].append(updates_alpha.item())

    return run_history

In [None]:
load_existing = True
max_n_tests = 3
ckpt_iter_freq = 10
max_iters = 1000


def _ckpt_iters_to_eval(ckpt_iter_freq, last_iter):
    """Return the checkpoint iterations to evaluate: 1, then every ``ckpt_iter_freq`` up to ``last_iter``.

    Duplicates are removed, so ``ckpt_iter_freq == 1`` does not evaluate iteration 1 twice.
    """
    return sorted({1, *range(ckpt_iter_freq, last_iter + 1, ckpt_iter_freq)})


# the evaluation datasets don't depend on the optimizer or the test run: built once, on first use
train_data = test_data = None

### collect for baselines
for baseline_name in baselines:
    print(f"Running {baseline_name}...")
    config = baselines[baseline_name]["config"]
    baseline_dir = baselines[baseline_name]["baseline_dir"]
    save_run_history_path = os.path.join(
        baseline_dir,
        f"grad_noise_run_history_{config['eval_n_tests']}_tests.pt"
    )

    if load_existing and os.path.exists(save_run_history_path):
        print(f"  Loading existing run history from {save_run_history_path}")
        run_history = torch.load(save_run_history_path)
        baselines[baseline_name]["run_history"] = run_history
        continue

    last_iter = min(max_iters, config["meta_testing"]["n_iters"])
    n_tests = config["eval_n_tests"]
    if max_n_tests:
        n_tests = min(n_tests, max_n_tests)

    run_history_all_test_runs = []
    for test_run_i in range(n_tests):
        print(f"  Test run {test_run_i}...")
        ckpt_prefix = f"run{test_run_i}_"

        ### collect
        run_history = {
            split: {metric: [] for metric in ("loss", "acc", "grads_noise_norm", "grads_alpha", "updates_noise_norm", "updates_alpha")}
            for split in ("train", "test")
        }
        if train_data is None:
            train_data = MNIST(training=True, batch_size=128)
            test_data = MNIST(training=False, batch_size=128)

        ### evaluate the selected checkpoints of this test run
        for iter_i in _ckpt_iters_to_eval(ckpt_iter_freq, last_iter):
            print(f"  [{iter_i}/{last_iter}]")
            ckpt_path = os.path.join(baseline_dir, "ckpt", f"{ckpt_prefix}{iter_i}.pt")
            run_history = collect_grad_update_noise_for_normal(
                ckpt_path=ckpt_path,
                run_history=run_history,
                config=config,
                train_data=train_data,
                test_data=test_data,
            )

        ### save current test run
        run_history_all_test_runs.append(run_history)

    ### save
    baselines[baseline_name]["run_history"] = run_history_all_test_runs
    torch.save(run_history_all_test_runs, save_run_history_path)

In [None]:
### preprocess
### turn <opter>["run_history"][<test run>][<split>][<metric>]
### into <opter>["run_history"][<split>][<metric>][<test run>]


def _transpose_run_history(runs, n_runs):
    """Regroup a per-test-run history list by split and metric.

    Args:
        runs: List of ``{"train"|"test": {metric: value}}`` dicts, one per test run.
        n_runs: Maximum number of leading runs to keep; clamped to ``len(runs)``.

    Returns:
        ``{"train": {metric: [run_0, run_1, ...]}, "test": {...}}``. Metric names
        are taken from the first run's "train" entry.
    """
    n_runs = min(n_runs, len(runs))
    if n_runs == 0:
        return {"train": {}, "test": {}}
    metric_names = list(runs[0]["train"])
    return {
        split: {metric: [runs[i][split][metric] for i in range(n_runs)] for metric in metric_names}
        for split in ("train", "test")
    }


for opters in (l2os, baselines):
    for opter_entry in opters.values():
        if isinstance(opter_entry["run_history"], dict):
            continue  # already regrouped (cell re-run)
        n_runs = opter_entry["config"]["eval_n_tests"]
        if max_n_tests:  # None/0 means "no cap", matching the collection cells
            n_runs = min(n_runs, max_n_tests)
        opter_entry["run_history"] = _transpose_run_history(opter_entry["run_history"], n_runs)

In [None]:
### config for plotting
show_max_iters = max_iters  # e.g. 200 to zoom in on the beginning
plot_l2o_grad_alpha = True
plot_l2o_update_alpha = True
plot_baseline_grad_alpha = True
plot_baseline_update_alpha = True
optee_name = "MNISTNet"
phase = "train"
save_fig_to_path = os.path.join(
    "../results/heavy_tail_grad_update_noise/publication",
    f"update_noise_alpha_estimates_{optee_name}_{phase}_{show_max_iters}.pdf"
)
save_fig_to_path = None  # remove this line to save the figure
if save_fig_to_path is not None:
    print(f"Saving figure to {save_fig_to_path}")
else:
    print("Not saving the figure (save_fig_to_path is None)")

In [None]:
### plot
# x positions = the iterations evaluated by the noise-collection cells: 1, ckpt_iter_freq, 2*ckpt_iter_freq, ...
# (index 0 of every history is iteration 1, so the axis must include it to keep points aligned)
last_plot_iter = config["meta_testing"]["n_iters"]
if show_max_iters is not None:
    last_plot_iter = min(show_max_iters, last_plot_iter)
x_ticks = sorted({1, *range(ckpt_iter_freq, last_plot_iter + 1, ckpt_iter_freq)})


def _plot_alpha_band(ax, values, x_ticks, label, **line_kwargs):
    """Plot the mean over test runs of ``values`` (n_tests, n_ckpts) with a +/-1 std band."""
    values = np.asarray(values)[:, :len(x_ticks)]
    mean, std = values.mean(0), values.std(0)
    xs = x_ticks[:values.shape[1]]  # tolerate histories with fewer checkpoints
    sns.lineplot(x=xs, y=mean, label=label, ax=ax, **line_kwargs)
    ax.fill_between(x=xs, y1=mean - std, y2=mean + std, alpha=0.2)


### plot alpha estimates
fig = plt.figure()
ax = fig.add_subplot(111)

for baseline_name in baselines:
    baseline_history = baselines[baseline_name]["run_history"][phase]
    if plot_baseline_update_alpha:
        _plot_alpha_band(ax, baseline_history["updates_alpha"], x_ticks, f"{baseline_name} - updates", linestyle="--")
    if plot_baseline_grad_alpha:
        _plot_alpha_band(ax, baseline_history["grads_alpha"], x_ticks, f"{baseline_name} - gradients", linestyle="--")

for opter_name in l2os:
    l2o_history = l2os[opter_name]["run_history"][phase]
    if plot_l2o_update_alpha:
        _plot_alpha_band(ax, l2o_history["updates_alpha"], x_ticks, f"{opter_name} - updates")
    if plot_l2o_grad_alpha:
        _plot_alpha_band(ax, l2o_history["grads_alpha"], x_ticks, f"{opter_name} - gradients")

ax.set_xlabel("Iteration")
ax.set_ylabel("Alpha estimate")

### fixed y-range, three x-ticks over the plotted range
ax.set_ylim(0.4, 1.06)
ax.set_xticks([0, last_plot_iter // 2, last_plot_iter])

### legend
ax.legend(loc="upper center", bbox_to_anchor=(0.5, 1.25), ncol=3)
legend = ax.get_legend()
# `legendHandles` was renamed `legend_handles` in matplotlib 3.7
handles = getattr(legend, "legend_handles", None)
if handles is None:
    handles = legend.legendHandles
for legend_handle in handles:
    legend_handle.set_linewidth(3.0)
for legend_text in legend.get_texts():
    legend_text.set_fontsize(9)

### save before showing so the saved file matches what is displayed
if save_fig_to_path:
    fig.savefig(save_fig_to_path, bbox_inches="tight")
plt.show()

### Validating results from the paper

In [None]:
class FullyConnected(nn.Module):
    """Bias-free ReLU MLP: input -> width -> ... -> width -> num_classes.

    ``depth`` counts the linear layers when ``depth >= 2``. For ``depth < 2``
    the network still has the input and output layers (two linear layers).
    """

    def __init__(self, input_dim=28*28 , width=50, depth=3, num_classes=10):
        super().__init__()
        self.input_dim = input_dim
        self.width = width
        self.depth = depth
        self.num_classes = num_classes

        # hidden layers are built before the input/output layers; keep this order
        # if seeded weight initialisations must stay reproducible
        hidden = self.get_layers()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, width, bias=False),
            nn.ReLU(inplace=True),
            *hidden,
            nn.Linear(width, num_classes, bias=False),
        )

    def get_layers(self):
        """Return the ``depth - 2`` hidden (Linear, ReLU) pairs as one flat list."""
        return [
            module
            for _ in range(self.depth - 2)
            for module in (nn.Linear(self.width, self.width, bias=False), nn.ReLU())
        ]

    def forward(self, x):
        """Flatten ``x`` to (batch, input_dim) and return the class logits."""
        return self.fc(x.view(x.size(0), self.input_dim))

In [None]:
def get_data(batch_size=100, shuffle_train=False, root=None):
    """Build MNIST data loaders for the standalone tail-index experiment.

    Inputs are converted to tensors, cast to the default float dtype, and
    normalized with MNIST mean/std. Train and test data use the same transform.

    Args:
        batch_size: Minibatch size of all three loaders.
        shuffle_train: Shuffle ``train_loader``, the loader used for training steps.
            The default (False) visits the training data in a fixed order every epoch.
            NOTE(review): SGD training usually shuffles; pass True if that was intended.
        root: Dataset directory. Defaults to ``os.environ["DATA_PATH"]``. The data
            must already be downloaded (``download=False``).

    Returns:
        ``(train_loader, test_loader_eval, train_loader_eval, num_classes)``.
        Note that the test loader comes second. Both eval loaders are unshuffled.
    """
    from torchvision import datasets, transforms

    if root is None:
        root = os.environ['DATA_PATH']
    data_class = 'MNIST'
    num_classes = 10
    stats = {
        'mean': [0.1307],
        'std': [0.3081],
    }

    # ToTensor -> cast to the default float dtype -> normalize (shared by train and test)
    transform = transforms.Compose([
        transforms.ToTensor(),
        lambda t: t.type(torch.get_default_dtype()),
        transforms.Normalize(**stats),
    ])

    dataset_cls = getattr(datasets, data_class)
    tr_data = dataset_cls(root=root, train=True, download=False, transform=transform)
    te_data = dataset_cls(root=root, train=False, download=False, transform=transform)

    def make_loader(dataset, shuffle=False):
        return torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)

    # train_loader drives training; the *_eval loaders are used for evaluation
    train_loader = make_loader(tr_data, shuffle=shuffle_train)
    train_loader_eval = make_loader(tr_data)
    test_loader_eval = make_loader(te_data)

    return train_loader, test_loader_eval, train_loader_eval, num_classes

In [None]:
### run separately: standalone check of the tail-index experiment with a small MLP
run_history = {metric: [] for metric in ("train_loss", "train_acc", "noise_norm", "alpha")}
train_loader, test_loader_eval, train_loader_eval, num_classes = get_data()

optee = FullyConnected(width=20, depth=1).cuda()
optee_optim = optim.SGD(optee.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

def eval():
    """Evaluate ``optee`` on every minibatch of ``train_loader_eval`` and log the results.

    For each minibatch the loss, the accuracy and the flattened gradient of all
    trainable parameters are computed. ``optee`` is switched to eval mode, but
    gradients are still taken. Appends to ``run_history``:
      - "train_loss" / "train_acc": means over minibatches;
      - "noise_norm": per-minibatch norms ``||g_i - mean_j g_j||``, a tensor of
        shape (n_minibatches,);
      - "alpha": tail-index estimate (``alpha_estimator``) of the centred gradient
        entries. The block size is the largest divisor of their count that does
        not exceed its square root.

    Relies on the globals ``optee``, ``optee_optim``, ``loss_fn``,
    ``train_loader_eval``, ``run_history`` and ``alpha_estimator``. The name
    shadows the builtin ``eval``; it is kept because the training loop calls it.

    Raises:
        ValueError: if ``train_loader_eval`` yields no minibatches.
    """
    optee.eval()
    grads, losses, accs = [], [], []
    for x, y in train_loader_eval:
        x, y = x.view(-1, 784).cuda(), y.cuda()
        # Start from clean gradients. A training step interrupted between
        # backward() and zero_grad() would otherwise leak into this batch.
        optee_optim.zero_grad()
        y_hat = optee(x)
        loss = loss_fn(y_hat, y)
        loss.backward()

        ### collect the flattened gradient of all trainable parameters
        grads.append(torch.cat([
            p.grad.detach().view(-1) for _, p in optee.named_parameters() if p.requires_grad
        ]).cpu())

        ### track history
        losses.append(loss.item())
        accs.append((y_hat.argmax(dim=1) == y).float().mean().item())
    # Leave no gradients behind for the next training step.
    optee_optim.zero_grad()

    if not grads:
        raise ValueError("train_loader_eval yielded no minibatches")

    grads = torch.stack(grads, dim=0)  # (n_minibatches, n_params)
    noise = grads - grads.mean(dim=0)  # (n_minibatches, n_params)
    noise_norm = noise.norm(dim=1)  # (n_minibatches,)

    ### tail index alpha; block size m is chosen to divide the number of entries
    N = noise.numel()
    m = max(i for i in range(1, 1 + math.isqrt(N)) if N % i == 0)
    alpha = alpha_estimator(m, noise.view(-1, 1))

    ### collect history
    run_history["train_loss"].append(np.mean(losses))
    run_history["train_acc"].append(np.mean(accs))
    run_history["noise_norm"].append(noise_norm)
    run_history["alpha"].append(alpha.item())

def cyclic_loader(loader):
    """Yield items from ``loader`` indefinitely, restarting it after each full pass.

    ``loader`` is iterated afresh on every pass. If a pass yields nothing (an
    empty loader, or a one-shot iterator that is already exhausted), the
    generator stops instead of spinning forever without producing a value.
    """
    while True:
        yielded_any = False
        for item in loader:
            yielded_any = True
            yield item
        if not yielded_any:
            return

# Optional cap on the number of SGD steps. ``None`` trains until the cell is
# interrupted (the previous behavior).
max_train_iters = None
eval_freq = 100  # run eval() every `eval_freq` steps

cyclic_train_loader = cyclic_loader(train_loader)

for i, (x, y) in enumerate(cyclic_train_loader):
    if max_train_iters is not None and i >= max_train_iters:
        break
    if i % eval_freq == 0:
        eval()
        print(i, run_history["train_loss"][-1], run_history["train_acc"][-1], run_history["alpha"][-1])

    optee.train()
    x, y = x.view(-1, 784).cuda(), y.cuda()
    # Drop any gradient left over from an interrupted eval()/step before backward().
    optee_optim.zero_grad()
    y_hat = optee(x)
    loss = loss_fn(y_hat, y)
    loss.backward()
    optee_optim.step()
    optee_optim.zero_grad()


In [None]:
# One figure per tracked metric. The x axis counts evaluations, not SGD steps.
for metric in ("train_loss", "train_acc", "alpha"):
    fig = plt.figure()
    plt.plot(run_history[metric])
    plt.xlabel("Evaluation index")
    plt.ylabel(metric)
    plt.title(metric)
    plt.show()

## Heavy-tail distribution of parameters
- From the [paper](http://proceedings.mlr.press/v139/gurbuzbalaban21a/gurbuzbalaban21a.pdf) *The Heavy-Tail Phenomenon in SGD* by Gurbuzbalaban, M., Simsekli, U., & Zhu, L., International Conference on Machine Learning (ICML), 2021.

In [None]:
def get_sampling_ms(N, targets=(5, 10, 20, 50, 100, 500, 1000)):
    """Pick block sizes ``m`` for the tail-index estimator applied to ``N`` samples.

    The result always starts with 2. Then the divisors ``i`` of ``N`` with
    ``3 <= i <= sqrt(N)`` are scanned in increasing order. For each target in
    ``targets`` (sorted ascending), the first divisor that reaches it is kept.
    A kept divisor satisfies every target it reaches, so targets it overshoots
    are skipped instead of blocking all later ones.

    NOTE: 2 is included even when ``N`` is odd.

    Args:
        N: number of samples.
        targets: desired approximate block sizes (default matches previous behavior).

    Returns:
        List of block sizes, e.g.

        >>> get_sampling_ms(5394)
        [2, 6, 29, 58]
    """
    targets = sorted(targets)
    ms = [2]
    t = 0
    # Skip targets already met by the initial block size.
    while t < len(targets) and targets[t] <= ms[-1]:
        t += 1
    for i in range(3, 1 + int(np.sqrt(N))):
        if t >= len(targets):
            break
        if N % i == 0 and i >= targets[t]:
            ms.append(i)
            # Advance past every target this divisor already satisfies.
            while t < len(targets) and targets[t] <= i:
                t += 1
    return ms

In [None]:
def _ht_largest_divisor_up_to_sqrt(N):
    """Return the largest divisor of ``N`` that does not exceed ``sqrt(N)``."""
    return max((i for i in range(1, 1 + int(np.sqrt(N))) if N % i == 0), default=1)


def _ht_params_alpha_and_noise_norm(x):
    """Alpha estimate and mean noise norm of ``x`` with shape (n_runs, dim).

    ``x`` is centred across runs (dim 0). Alpha is estimated from all centred
    entries; the noise norm is the per-run L2 norm of the centred vector,
    averaged over runs.
    """
    centered = x - x.mean(dim=0)
    m = _ht_largest_divisor_up_to_sqrt(centered.numel())
    alpha = alpha_estimator(m, centered.view(-1, 1))
    return alpha.item(), centered.norm(dim=1).mean().item()


def _ht_params_alpha_total(values):
    """Median alpha over the ``get_sampling_ms`` block sizes of globally-centred ``values``."""
    values = values.reshape(-1, 1)
    values = values - torch.mean(values)
    return float(np.median([
        alpha_estimator(m, values).item() for m in get_sampling_ms(values.shape[0])
    ]))


def collect_parameter_heavy_tail_alpha_estimates(
    config,
    ckpts_dir,
    iters_window,
    ckpt_iter_freq,
    iter_print_freq=50,
):
    """Estimate the tail index (alpha) of optimizee parameters across test runs.

    Loads ``run{test_run_i}_{iter_i}.pt`` from ``ckpts_dir`` for every test run in
    ``range(config["eval_n_tests"])`` and every iteration in
    ``range(iters_window[0], iters_window[1], ckpt_iter_freq)``. From each
    checkpoint's ``"optimizee"`` state it collects the optimizee's trainable
    parameters. Other state entries (e.g. batch-norm buffers) are ignored.

    Returns a dict with:
        - "params_alpha_per_iter", "params_noise_norm_per_iter": ``{iter_i: float}``
          computed over all trainable parameters concatenated. For each
          iteration the parameters are centred across test runs.
        - "params_alpha_total": float; median alpha (over ``get_sampling_ms``
          block sizes) of all collected parameter values across runs and
          iterations, centred by their global mean.
        - "params_alpha_per_iter_per_param", "params_noise_norm_per_iter_per_param":
          ``{param_name: {iter_i: float}}``, the same quantities per parameter tensor.
        - "params_alpha_total_per_param": ``{param_name: float}``.
        - "loss", "acc": always empty lists (kept for backward compatibility).

    Raises:
        ValueError: if the iteration window is empty or the optimizee has no
            trainable parameters.
    """
    iter_range = range(iters_window[0], iters_window[1], ckpt_iter_freq)

    _tmp_optee = config["meta_testing"]["optee_cls"](**config["meta_testing"]["optee_config"])
    param_names = [n for n, p in _tmp_optee.named_parameters() if p.requires_grad]
    del _tmp_optee
    if not param_names or len(iter_range) == 0:
        raise ValueError("need at least one trainable parameter and one iteration in the window")

    ### collect: params[iter_i][name] -> one flattened tensor per test run
    params = {iter_i: {n: [] for n in param_names} for iter_i in iter_range}
    for test_run_i in range(config["eval_n_tests"]):
        print(f"  Test run {test_run_i}...")
        for iter_i in iter_range:
            if iter_i % iter_print_freq == 0:
                print(f"  [COLLECTING-PARAMS-{test_run_i}][{iter_i}/{iters_window[1]}]")

            ckpt_path = os.path.join(ckpts_dir, f"run{test_run_i}_{iter_i}.pt")
            optee_state = torch.load(ckpt_path, map_location="cpu")["optimizee"]
            for n in param_names:
                params[iter_i][n].append(optee_state[n].detach().view(-1).cpu())

    metrics = {k: [] for k in ("loss", "acc")}
    metrics["params_noise_norm_per_iter"] = dict()
    metrics["params_alpha_per_iter"] = dict()
    metrics["params_noise_norm_per_iter_per_param"] = {n: dict() for n in param_names}
    metrics["params_alpha_per_iter_per_param"] = {n: dict() for n in param_names}
    metrics["params_alpha_total_per_param"] = dict()

    ### post-process - noise norm and alpha estimates per iteration
    print("Post-processing (per iteration)...")
    for iter_i in iter_range:
        if iter_i % iter_print_freq == 0:
            print(f"  [POSTPROCESSING-PER-ITER][{iter_i}/{iters_window[1]}]")
        # each entry: (eval_n_tests, param_dim)
        params[iter_i] = {n: torch.stack(v) for n, v in params[iter_i].items()}
        for n, x in params[iter_i].items():
            alpha, noise_norm = _ht_params_alpha_and_noise_norm(x)
            metrics["params_alpha_per_iter_per_param"][n][iter_i] = alpha
            metrics["params_noise_norm_per_iter_per_param"][n][iter_i] = noise_norm
        # all parameters concatenated: (eval_n_tests, total_param_dim)
        alpha, noise_norm = _ht_params_alpha_and_noise_norm(
            torch.cat([params[iter_i][n] for n in param_names], dim=1)
        )
        metrics["params_alpha_per_iter"][iter_i] = alpha
        metrics["params_noise_norm_per_iter"][iter_i] = noise_norm

    ### post-process - alpha over all runs and iterations
    print("Post-processing (total)...")
    for n in param_names:
        metrics["params_alpha_total_per_param"][n] = _ht_params_alpha_total(
            torch.cat([params[iter_i][n] for iter_i in iter_range], dim=0)
        )
    metrics["params_alpha_total"] = _ht_params_alpha_total(
        torch.cat([params[iter_i][n].reshape(-1) for iter_i in iter_range for n in param_names])
    )

    return metrics

In [None]:
### config for the parameter heavy-tail analysis
load_existing = True  # reuse a saved result file when it exists
iters_window = (1, 1000)  # [start, end) of the analysed checkpoint iterations
iter_print_freq = 100  # progress-print stride
ckpt_iter_freq = 10  # config["meta_testing"]["ckpt_iter_freq"]

In [None]:
### collect for baselines
for baseline_name in baselines:
    print(f"Collecting for {baseline_name}...")
    config = baselines[baseline_name]["config"]

    ckpts_dir = os.path.join(baselines[baseline_name]["baseline_dir"], "ckpt")
    # The cache file name encodes the iteration window, the checkpoint stride and
    # the optimizee/data settings. Changing ``ckpt_iter_freq`` therefore does not
    # silently reload estimates computed from a different set of checkpoints.
    save_parameter_history_metrics_path = os.path.join(
        baselines[baseline_name]["baseline_dir"],
        "parameter_heavy_tail_alpha_estimates"
        + f"_{iters_window[0]}-{iters_window[1]}"
        + f"_freq{ckpt_iter_freq}"
        + f"_{config['meta_testing']['optee_cls'].__name__}_{dict_to_str(config['meta_testing']['optee_config'])}"
        + f"_{config['meta_testing']['data_cls'].__name__}_{dict_to_str(config['meta_testing']['data_config'])}"
        + ".pt"
    )

    if load_existing and os.path.exists(save_parameter_history_metrics_path):
        print(f"  Loading existing run history from {save_parameter_history_metrics_path}")
        baselines[baseline_name]["parameter_history_metrics"] = torch.load(save_parameter_history_metrics_path)
    else:
        print(f"  Existing run history at path {save_parameter_history_metrics_path} not found, collecting it...")
        baselines[baseline_name]["parameter_history_metrics"] = collect_parameter_heavy_tail_alpha_estimates(
            config=config,
            ckpts_dir=ckpts_dir,
            iters_window=iters_window,
            ckpt_iter_freq=ckpt_iter_freq,
            iter_print_freq=iter_print_freq,
        )
        torch.save(baselines[baseline_name]["parameter_history_metrics"], save_parameter_history_metrics_path)

In [None]:
### collect for l2os
for opter_name in l2os:
    print(f"Collecting for {opter_name}...")
    config = l2os[opter_name]["config"]

    # Long meta-testing runs are stored in the "<ckpt_dir>_long" directory.
    ckpts_dir = os.path.join(os.environ["CKPT_PATH"], config["meta_testing"]["ckpt_dir"] + "_long")
    # The cache file name encodes the iteration window, the checkpoint stride and
    # the optimizee/data settings. Changing ``ckpt_iter_freq`` therefore does not
    # silently reload estimates computed from a different set of checkpoints.
    save_parameter_history_metrics_path = os.path.join(
        os.environ["CKPT_PATH"],
        config["ckpt_base_dir"],
        "parameter_heavy_tail_alpha_estimates"
        + f"_{iters_window[0]}-{iters_window[1]}"
        + f"_freq{ckpt_iter_freq}"
        + f"_{config['meta_testing']['optee_cls'].__name__}_{dict_to_str(config['meta_testing']['optee_config'])}"
        + f"_{config['meta_testing']['data_cls'].__name__}_{dict_to_str(config['meta_testing']['data_config'])}"
        + ".pt"
    )

    if load_existing and os.path.exists(save_parameter_history_metrics_path):
        print(f"  Loading existing run history from {save_parameter_history_metrics_path}")
        l2os[opter_name]["parameter_history_metrics"] = torch.load(save_parameter_history_metrics_path)
    else:
        print(f"  Existing run history not found at path {save_parameter_history_metrics_path}, collecting it...")
        l2os[opter_name]["parameter_history_metrics"] = collect_parameter_heavy_tail_alpha_estimates(
            config=config,
            ckpts_dir=ckpts_dir,
            iters_window=iters_window,
            ckpt_iter_freq=ckpt_iter_freq,
            iter_print_freq=iter_print_freq,
        )
        torch.save(l2os[opter_name]["parameter_history_metrics"], save_parameter_history_metrics_path)

In [None]:
# Parameter tail index of the "baseline_no_reg" L2O: first over all iterations, then per checkpointed iteration.
print(l2os["baseline_no_reg"]["parameter_history_metrics"]["params_alpha_total"])
_alpha_per_iter = l2os["baseline_no_reg"]["parameter_history_metrics"]["params_alpha_per_iter"]
plt.plot(list(_alpha_per_iter.keys()), list(_alpha_per_iter.values()))  # x = actual iteration numbers
plt.xlabel("Iteration")
plt.ylabel("alpha (parameters)")

In [None]:
# Parameter tail index of the SGD baseline: first over all iterations, then per checkpointed iteration.
print(baselines["SGD"]["parameter_history_metrics"]["params_alpha_total"])
_alpha_per_iter = baselines["SGD"]["parameter_history_metrics"]["params_alpha_per_iter"]
plt.plot(list(_alpha_per_iter.keys()), list(_alpha_per_iter.values()))  # x = actual iteration numbers
plt.xlabel("Iteration")
plt.ylabel("alpha (parameters)")
plt.ylim(0, 20)

#### Using the implementation from the authors
- [GitHub repository](https://github.com/umutsimsekli/sgd_ht/tree/main)

In [None]:
nets = []

### baselines: optimizee snapshots of test run 0 at iterations 1500, 1505, ..., 1995
for baseline_name in baselines:
    print(f"Collecting for {baseline_name}...")
    config = baselines[baseline_name]["config"]
    ckpts_dir = os.path.join(baselines[baseline_name]["baseline_dir"], "ckpt")

    for iter_i in range(1500, 2000, 5):
        if not iter_i % 100:
            print(f"  iter_i={iter_i}")
        ckpt_path = os.path.join(ckpts_dir, f"run0_{iter_i}.pt")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        optee = config["meta_testing"]["optee_cls"](**config["meta_testing"]["optee_config"])
        optee.load_state_dict(ckpt["optimizee"])
        nets.append(optee.cpu())

### L2Os: same iterations, read from $CKPT_PATH/<meta_testing ckpt_dir>
# NOTE(review): the heavy-tail cell above reads the "_long" variant of this
# directory; confirm iterations 1500-1995 exist here.
# NOTE(review): baseline and L2O nets are appended to the same `nets` list.
for opter_name in l2os:
    print(f"Collecting for {opter_name}...")
    config = l2os[opter_name]["config"]
    ckpts_dir = os.path.join(os.environ["CKPT_PATH"], config["meta_testing"]["ckpt_dir"])

    for iter_i in range(1500, 2000, 5):
        if not iter_i % 100:
            print(f"  iter_i={iter_i}")
        ckpt_path = os.path.join(ckpts_dir, f"run0_{iter_i}.pt")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        optee = config["meta_testing"]["optee_cls"](**config["meta_testing"]["optee_config"])
        optee.load_state_dict(ckpt["optimizee"])
        nets.append(optee.cpu())

In [None]:
depth = 2  # expected number of weight tensors (non-bias, non-batch-norm) per net
num_nets = len(nets)
alphas_mc = np.full(depth, -1.0)  # one alpha per layer; -1 marks "not computed yet"

# Corollary 2.4 in Mohammadi 2014 - for 1d
def alpha_estimator_one(m, X):
    """Estimate the tail index alpha of the 1-d sample ``X`` from block sums of size ``m``.

    The sample is flattened and truncated to ``n * m`` entries with
    ``n = len // m``. With ``Y`` the sums of the ``n`` consecutive blocks,

        1 / alpha = (mean(log|Y|) - mean(log|X|)) / log(m)

    where a machine epsilon is added inside the logs to avoid ``log(0)``.

    Args:
        m: block size, ``2 <= m <= number of samples``.
        X: array-like sample (any shape; it is flattened).

    Returns:
        The alpha estimate (numpy float).

    Raises:
        ValueError: if ``m < 2`` (log(m) would be <= 0) or ``m`` exceeds the sample size.
    """
    X = np.asarray(X).reshape(-1)
    N = X.shape[0]
    if m < 2:
        raise ValueError(f"block size m must be >= 2, got {m}")
    if m > N:
        raise ValueError(f"block size m={m} exceeds the number of samples N={N}")

    n = N // m
    X = X[: n * m]
    Y = X.reshape(n, m).sum(axis=1)
    eps = np.spacing(1)

    Y_log_norm = np.log(np.abs(Y) + eps).mean()
    X_log_norm = np.log(np.abs(X) + eps).mean()
    diff = (Y_log_norm - X_log_norm) / math.log(m)
    return 1 / diff

In [None]:
### collect
weights = [[] for _ in range(depth)]

# For each net, record its weight tensors (every parameter whose name contains
# neither "bias" nor "batch_norm") as column vectors. The k-th such tensor of
# a net goes to weights[k].
for i in range(num_nets):
    tmp_net = nets[i]
    ix = 0
    for n, p in tmp_net.all_named_parameters():
        if "bias" not in n and "batch_norm" not in n:
            layer = p.detach().numpy().reshape(-1, 1)
            weights[ix].append(layer)
            ix += 1

# Side by side per layer: weights[k] has one column per net.
weights = [np.concatenate(columns, axis=1) for columns in weights]

# Per layer: average each weight over the nets, centre the averages, and take
# the median alpha over a few block sizes.
for i in range(depth):
    tmp_weights = np.mean(weights[i], axis=1).reshape(-1, 1)
    tmp_weights = tmp_weights - np.mean(tmp_weights)
    tmp_alphas = [alpha_estimator_one(mm, tmp_weights) for mm in (2, 5, 10, 20)]
    alphas_mc[i] = np.median(tmp_alphas)

print(alphas_mc)

## Covariance of gradients/updates

In [None]:
### covariance-analysis settings
ckpt_iter_freq = 20  # checkpoint stride in iterations (10 was used before)
max_iters = 1000
max_test_runs = 3
trace_estimate_data_frac = 0.2  # fraction of the single-sample batches used for the trace estimate
eigen_estimate_data_frac = 0.2  # fraction of the minibatches of data_mini_batched used for the max-eigenvalue estimate

### MNIST train set: single samples, minibatches of 128, and one batch sized
### by the number of single-sample batches (presumably the whole training set)
data_samples = MNIST(training=True, batch_size=1, preload=True)
data_mini_batched = MNIST(training=True, batch_size=128, preload=True)
data_full = MNIST(training=True, batch_size=len(data_samples.loader))

In [None]:
def power_method(A, max_iters):
    """Estimate the dominant eigenvalue of a square matrix ``A`` by power iteration.

    ``A`` is assumed symmetric (e.g. a Gram matrix), so the Rayleigh quotient
    of the final unit vector is the eigenvalue estimate.

    Args:
        A: (n, n) floating-point tensor, on any device / dtype.
        max_iters: number of power iterations.

    Returns:
        A (1, 1) tensor ``x^T A x`` for the final unit vector ``x``. If an
        iterate becomes exactly zero (e.g. ``A == 0``), a zero (1, 1) tensor is
        returned instead of dividing by zero.
    """
    # Draw on the default generator as before, then match A's device and dtype.
    x = torch.randn(A.shape[0], 1).to(device=A.device, dtype=A.dtype)
    x /= torch.norm(x)
    for _ in range(max_iters):
        x = A @ x
        norm = torch.norm(x)
        if norm == 0:
            return torch.zeros(1, 1, dtype=A.dtype, device=A.device)
        x /= norm
    return x.T @ A @ x

In [None]:
def _cov_trace_grads_and_updates(optee, opter, hidden_states, cell_states, optee_updates_lr):
    """Flatten the optimizee's current gradients and compute the matching optimizer update.

    Meant to be called right after ``loss.backward()``. Visits
    ``optee.all_named_parameters()`` in order, skipping parameters that do not
    require grad or have no gradient, and resets every visited ``p.grad`` to ``None``.

    Returns:
        ``(grad, update)``. ``grad`` is the (n, 1) CPU tensor of concatenated
        gradients. ``update`` is ``None`` when ``opter`` is ``None``; otherwise it
        is a (., 1) CPU tensor of parameter updates, taken from
        ``get_baseline_opter_param_updates`` for a baseline optimizer, or computed
        as ``optee_updates_lr * opter(...)`` per parameter (using that
        parameter's slice of ``hidden_states``/``cell_states``) for an L2O
        ``Optimizer``.
    """
    is_l2o = isinstance(opter, Optimizer)

    update = None
    if opter is not None and not is_l2o:
        update = torch.cat([
            p.detach().view(-1, 1).to("cpu")
            for p in get_baseline_opter_param_updates(optee, opter, verbose=False).values()
        ], dim=0)

    grads, l2o_updates = [], []
    offset = 0  # position of the current parameter inside the flat hidden/cell states
    for _, p in optee.all_named_parameters():
        if not p.requires_grad or p.grad is None:
            continue
        param_grad = p.grad.view(-1, 1).detach()
        grads.append(param_grad)

        if is_l2o:
            curr_sz = p.numel()
            with torch.no_grad():
                param_update, _, _ = opter(
                    optee_grads=param_grad,
                    hidden=[h[offset : offset + curr_sz] for h in hidden_states],
                    cell=[c[offset : offset + curr_sz] for c in cell_states],
                    additional_inp=None,
                )
                l2o_updates.append(optee_updates_lr * param_update.view(-1, 1).detach())
            offset += curr_sz

        p.grad = None  # clear so the next backward() starts from zero

    grad = torch.cat(grads, dim=0).to("cpu")
    if is_l2o:
        update = torch.cat(l2o_updates, dim=0).to("cpu")
    return grad, update


def get_cov_trace(optee, data_full, data_samples, trace_estimate_data_frac, opter=None, hidden_states=None, cell_states=None, optee_updates_lr=None):
    """Estimate the trace of the per-sample gradient covariance (and optionally of the updates).

    The reference gradient ``g`` is the gradient of ``optee(data=data_full)``.
    Per-sample gradients ``g_i`` come from the first
    ``int(len(data_samples.batches) * trace_estimate_data_frac)`` batches of
    ``data_samples``. The estimate is ``mean_i ||g - g_i||^2``. When ``opter`` is
    given, the same is computed for the updates the optimizer produces from
    ``g`` and from each ``g_i``.

    Args:
        optee: optimizee; all visited ``p.grad`` are left as ``None``.
        data_full: data object passed as ``optee(data=...)`` for the reference gradient.
        data_samples: object whose ``.batches`` are ``(inp, out)`` pairs.
        trace_estimate_data_frac: fraction of ``data_samples.batches`` to use.
        opter: ``None``, a baseline torch optimizer, or an L2O ``Optimizer``.
        hidden_states, cell_states, optee_updates_lr: L2O state and update scale
            (only used when ``opter`` is an ``Optimizer``).

    Returns:
        ``tr_grads`` (float) if ``opter`` is ``None``, else ``(tr_grads, tr_updates)``.

    Raises:
        ValueError: if the fraction selects no samples.
    """
    num_samples = int(len(data_samples.batches) * trace_estimate_data_frac)
    if num_samples <= 0:
        raise ValueError(
            f"trace_estimate_data_frac={trace_estimate_data_frac} selects no samples "
            f"out of {len(data_samples.batches)}"
        )

    ### reference ("true") gradient g and the update it induces
    loss = optee(data=data_full)
    loss.backward()
    g, update_full = _cov_trace_grads_and_updates(optee, opter, hidden_states, cell_states, optee_updates_lr)

    ### accumulate squared deviations of per-sample gradients / updates
    tr_grads, tr_updates = 0, 0
    for inp, out in data_samples.batches[:num_samples]:
        inp = w(Variable(inp.view(inp.size(0), -1)))  # flatten each sample
        out = w(Variable(out))
        loss = optee(inp=inp, out=out)
        loss.backward()
        g_i, update_i = _cov_trace_grads_and_updates(optee, opter, hidden_states, cell_states, optee_updates_lr)

        tr_grads += torch.norm(g - g_i) ** 2
        if opter is not None:
            tr_updates += torch.norm(update_full - update_i) ** 2

    ### record trace
    tr_grads /= num_samples
    if opter is not None:
        tr_updates /= num_samples
        return tr_grads.item(), tr_updates.item()
    return tr_grads.item()

In [None]:
def _cov_centered_gram(columns):
    """Return the L x L matrix ``G G^T / L`` for the centred stack ``G`` of ``columns``.

    ``columns`` is a list of L tensors of shape (n, 1). They are stacked into
    ``G`` of shape (L, n), and the mean over the L rows is subtracted. The
    nonzero eigenvalues of the result equal those of the n x n covariance
    ``G^T G / L``, which is much larger when n >> L.
    """
    G = torch.cat(columns, dim=1).T  # (L, n)
    G = G - G.mean(dim=0)
    return (G @ G.T) / len(columns)


def get_cov_max_eigenval(optee, data_mini_batched, eigen_estimate_data_frac, opter=None, hidden_states=None, cell_states=None, optee_updates_lr=None):
    """Estimate the largest eigenvalue of the minibatch-gradient covariance (and of the updates).

    Uses the first ``int(len(data_mini_batched.batches) * eigen_estimate_data_frac)``
    minibatches. Each minibatch gradient (and, when ``opter`` is given, the
    update the optimizer would apply) is flattened. They are centred by their
    mean, and ``power_method`` (100 iterations) is run on the centred Gram matrix.

    Args:
        optee: optimizee; all visited ``p.grad`` are left as ``None``.
        data_mini_batched: object whose ``.batches`` are ``(inp, out)`` pairs.
        eigen_estimate_data_frac: fraction of the minibatches to use.
        opter: ``None``, a baseline torch optimizer, or an L2O ``Optimizer``.
        hidden_states, cell_states, optee_updates_lr: L2O state and update scale
            (only used when ``opter`` is an ``Optimizer``).

    Returns:
        ``max_eigen_grad`` (float) if ``opter`` is ``None``, else
        ``(max_eigen_grad, max_eigen_update)``.

    Raises:
        ValueError: if the fraction selects no minibatches.
    """
    L = int(len(data_mini_batched.batches) * eigen_estimate_data_frac)
    if L <= 0:
        raise ValueError(
            f"eigen_estimate_data_frac={eigen_estimate_data_frac} selects no minibatches "
            f"out of {len(data_mini_batched.batches)}"
        )
    is_l2o = isinstance(opter, Optimizer)

    grads, updates = [], []  # one (n, 1) CPU tensor per minibatch
    for inp, out in data_mini_batched.batches[:L]:
        inp = w(Variable(inp.view(inp.size(0), -1)))  # flatten each sample
        out = w(Variable(out))
        loss = optee(inp=inp, out=out)
        loss.backward()

        ### update for a baseline optimizer (computed from the current grads)
        update_i = None
        if opter is not None and not is_l2o:
            update_i = torch.cat([
                p.view(-1, 1).detach().to("cpu")
                for p in get_baseline_opter_param_updates(optee, opter, verbose=False).values()
            ], dim=0)

        g_i, l2o_update_i = [], []
        offset = 0  # position of the current parameter inside the flat hidden/cell states
        for _, p in optee.all_named_parameters():
            if not p.requires_grad or p.grad is None:
                continue
            param_grad = p.grad.view(-1, 1).detach()
            g_i.append(param_grad)

            ### update for an L2O optimizer
            if is_l2o:
                curr_sz = p.numel()
                with torch.no_grad():
                    param_update, _, _ = opter(
                        optee_grads=param_grad,
                        hidden=[h[offset : offset + curr_sz] for h in hidden_states],
                        cell=[c[offset : offset + curr_sz] for c in cell_states],
                        additional_inp=None,
                    )
                    l2o_update_i.append(optee_updates_lr * param_update.view(-1, 1).detach())
                offset += curr_sz

            p.grad = None  # clear so the next backward() starts from zero

        grads.append(torch.cat(g_i, dim=0).to("cpu"))
        if opter is not None:
            if is_l2o:
                update_i = torch.cat(l2o_update_i, dim=0).to("cpu")
            updates.append(update_i)

    ### max eigenvalue of cov(grads)
    max_eigen_grad = power_method(_cov_centered_gram(grads), 100)

    ### max eigenvalue of cov(updates)
    if opter is not None:
        max_eigen_update = power_method(_cov_centered_gram(updates), 100)
        return max_eigen_grad.item(), max_eigen_update.item()

    return max_eigen_grad.item()

In [None]:
### collect for L2Os
for opter_name in l2os:
    print(f"Collecting for {opter_name}...")
    config = l2os[opter_name]["config"]
    ckpts_dir = os.path.join(os.environ["CKPT_PATH"], config["meta_testing"]["ckpt_dir"])
    save_path = os.path.join(os.environ["CKPT_PATH"], config["ckpt_base_dir"], "cov_analysis.pt")

    # Each metric becomes a list with one inner list per test run.
    for k in ("cov_grad_tr", "cov_grad_max_eigen", "cov_update_tr", "cov_update_max_eigen"):
        l2os[opter_name][k] = []

    ### reuse results saved by a previous run of this cell
    if os.path.exists(save_path):
        print(f"  Loading previously collected results...")
        cov_analysis = torch.load(save_path)
        for k in ("cov_grad_tr", "cov_grad_max_eigen", "cov_update_tr", "cov_update_max_eigen"):
            l2os[opter_name][k] = cov_analysis[k]
        continue

    # A falsy max_test_runs means "use all test runs".
    n_test_runs = config["eval_n_tests"]
    if max_test_runs:
        n_test_runs = min(n_test_runs, max_test_runs)

    for test_run_i in range(n_test_runs):
        print(f"  Test run {test_run_i}...")
        for k in ("cov_grad_tr", "cov_grad_max_eigen", "cov_update_tr", "cov_update_max_eigen"):
            l2os[opter_name][k].append([])

        for iter_i in [1, *range(ckpt_iter_freq, max_iters + 1, ckpt_iter_freq)]:
            if not iter_i % 100:
                print(f"    iter_i={iter_i}")

            ckpt_path = os.path.join(ckpts_dir, f"run{test_run_i}_{iter_i}.pt")
            ckpt = torch.load(ckpt_path)

            ### restore the optimizee (in eval mode), then the L2O optimizer
            optee_ckpt = ckpt["optimizee"]
            optee = w(config["meta_testing"]["optee_cls"](**config["meta_testing"]["optee_config"]))
            optee.load_state_dict(optee_ckpt)
            optee.eval()

            opter_ckpt = ckpt["optimizer"]
            opter = w(config["opter_cls"](**config["opter_config"]))
            opter.load_state_dict(opter_ckpt)

            l2o_kwargs = dict(
                opter=opter,
                hidden_states=ckpt["hidden_states"],
                cell_states=ckpt["cell_states"],
                optee=optee,
                optee_updates_lr=config["meta_testing"]["optee_updates_lr"],
            )

            ### trace of cov(grad) and cov(updates)
            tr_grads, tr_updates = get_cov_trace(
                data_full=data_full,
                data_samples=data_samples,
                trace_estimate_data_frac=trace_estimate_data_frac,
                **l2o_kwargs,
            )
            l2os[opter_name]["cov_grad_tr"][-1].append(tr_grads)
            l2os[opter_name]["cov_update_tr"][-1].append(tr_updates)

            ### max eigenvalue of cov(grad) and cov(updates)
            max_eigenval_grads, max_eigenval_updates = get_cov_max_eigenval(
                data_mini_batched=data_mini_batched,
                eigen_estimate_data_frac=eigen_estimate_data_frac,
                **l2o_kwargs,
            )
            l2os[opter_name]["cov_grad_max_eigen"][-1].append(max_eigenval_grads)
            l2os[opter_name]["cov_update_max_eigen"][-1].append(max_eigenval_updates)

    ### convert to (n_test_runs, n_checkpoints) arrays and save
    for k in ("cov_grad_tr", "cov_grad_max_eigen", "cov_update_tr", "cov_update_max_eigen"):
        l2os[opter_name][k] = np.array(l2os[opter_name][k])

    torch.save(
        {k: l2os[opter_name][k] for k in ("cov_grad_tr", "cov_grad_max_eigen", "cov_update_tr", "cov_update_max_eigen")},
        save_path,
    )

In [None]:
### collect for baselines
for baseline_name in baselines:
    print(f"Collecting for {baseline_name}...")
    config = baselines[baseline_name]["config"]
    ckpts_dir = os.path.join(baselines[baseline_name]["baseline_dir"], "ckpt")
    save_path = os.path.join(baselines[baseline_name]["baseline_dir"], "cov_analysis.pt")

    baselines[baseline_name]["cov_grad_tr"] = []
    baselines[baseline_name]["cov_grad_max_eigen"] = []
    baselines[baseline_name]["cov_update_tr"] = []
    baselines[baseline_name]["cov_update_max_eigen"] = []

    ### load already collected
    if os.path.exists(save_path):
        print(f"  Loading previously collected results...")
        cov_analysis = torch.load(save_path)
        for k in ["cov_grad_tr", "cov_grad_max_eigen", "cov_update_tr", "cov_update_max_eigen"]:
            baselines[baseline_name][k] = cov_analysis[k]
        continue

    for test_run_i in range(config["eval_n_tests"]):
        if max_test_runs and test_run_i >= max_test_runs:
            break
        print(f"  Test run {test_run_i}...")

        baselines[baseline_name]["cov_grad_tr"].append([])
        baselines[baseline_name]["cov_grad_max_eigen"].append([])
        baselines[baseline_name]["cov_update_tr"].append([])
        baselines[baseline_name]["cov_update_max_eigen"].append([])

        for iter_i in [1, *range(ckpt_iter_freq, max_iters + 1, ckpt_iter_freq)]:
            if iter_i % 50 == 0:
                print(f"    iter_i={iter_i}")

            ### load ckpt
            ckpt_path = os.path.join(ckpts_dir, f"run{test_run_i}_{iter_i}.pt")
            ckpt = torch.load(ckpt_path)

            ### init optee
            optee_ckpt = ckpt["optimizee"]
            optee = w(config["meta_testing"]["optee_cls"](**config["meta_testing"]["optee_config"]))
            optee.load_state_dict(optee_ckpt)
            # The checkpointed gradients (ckpt["optimizee_grads"]) are deliberately
            # NOT copied into p.grad. get_cov_trace / get_cov_max_eigenval call
            # loss.backward() themselves, which would accumulate on top of them
            # and corrupt the full-data gradient and the derived baseline update.
            if hasattr(optee, "batch_norm"):
                optee.batch_norm.eval()

            ### init opter
            opter_ckpt = ckpt["optimizer"]
            opter = config["meta_testing"]["baseline_opter_cls"](optee.parameters(), **config["meta_testing"]["baseline_opter_config"])
            opter.load_state_dict(opter_ckpt)

            ### get trace of cov(grad) and cov(updates)
            tr_grads, tr_updates = get_cov_trace(
                opter=opter,
                hidden_states=None,
                cell_states=None,
                optee=optee,
                data_full=data_full,
                data_samples=data_samples,
                trace_estimate_data_frac=trace_estimate_data_frac,
            )
            baselines[baseline_name]["cov_grad_tr"][-1].append(tr_grads)
            baselines[baseline_name]["cov_update_tr"][-1].append(tr_updates)

            ### get max eigenvalue of cov(grad) and cov(updates)
            max_eigenval_grads, max_eigenval_updates = get_cov_max_eigenval(
                opter=opter,
                hidden_states=None,
                cell_states=None,
                optee=optee,
                data_mini_batched=data_mini_batched,
                eigen_estimate_data_frac=eigen_estimate_data_frac,
            )
            baselines[baseline_name]["cov_grad_max_eigen"][-1].append(max_eigenval_grads)
            baselines[baseline_name]["cov_update_max_eigen"][-1].append(max_eigenval_updates)

    ### to np arrays
    baselines[baseline_name]["cov_grad_tr"] = np.array(baselines[baseline_name]["cov_grad_tr"])
    baselines[baseline_name]["cov_grad_max_eigen"] = np.array(baselines[baseline_name]["cov_grad_max_eigen"])
    baselines[baseline_name]["cov_update_tr"] = np.array(baselines[baseline_name]["cov_update_tr"])
    baselines[baseline_name]["cov_update_max_eigen"] = np.array(baselines[baseline_name]["cov_update_max_eigen"])

    ### save
    torch.save({
        "cov_grad_tr": baselines[baseline_name]["cov_grad_tr"],
        "cov_grad_max_eigen": baselines[baseline_name]["cov_grad_max_eigen"],
        "cov_update_tr": baselines[baseline_name]["cov_update_tr"],
        "cov_update_max_eigen": baselines[baseline_name]["cov_update_max_eigen"],
    }, save_path)

In [None]:
### plot config
to_plot = "cov_update_tr"  # which collected metric to plot
plot_iters = max_iters  # e.g. 200 to zoom in on the early iterations
log_scale = True

# Output directory for the figure; the second assignment disables saving.
save_to_dir = "../results/sym_breaking_regularization/MNISTLeakyRelu_meta_training"
save_to_dir = None

In [None]:
### plot
fig = plt.figure()
ax = fig.add_subplot(111)
x_ticks = [1, *range(ckpt_iter_freq, plot_iters + 1, ckpt_iter_freq)]
n_points = len(x_ticks)  # == plot_iters // ckpt_iter_freq + 1 checkpoints per run


def _plot_mean_std_band(ax, x, runs, label, **line_kwargs):
    """Plot the mean over runs (axis 0 of ``runs``) with a +-1 std band.

    The band is filled with the colour of the line just drawn, so each
    method's band matches its line.
    """
    y_mean = np.mean(runs, axis=0)
    y_std = np.std(runs, axis=0)
    sns.lineplot(x=x, y=y_mean, label=label, ax=ax, **line_kwargs)
    ax.fill_between(
        x=x,
        y1=y_mean - y_std,
        y2=y_mean + y_std,
        alpha=0.2,
        color=ax.lines[-1].get_color(),
    )


### baselines (dashed) and L2Os (solid)
for baseline_name in baselines:
    _plot_mean_std_band(ax, x_ticks, baselines[baseline_name][to_plot][:, :n_points], baseline_name, linestyle="--")
for opter_name in l2os:
    _plot_mean_std_band(ax, x_ticks, l2os[opter_name][to_plot][:, :n_points], opter_name)

ax.set_xlabel("Iteration")
ax.set_ylabel(to_plot)
ax.legend(bbox_to_anchor=(0.5, 1.08), loc="center", ncol=3)
legend = ax.get_legend()
# `legend_handles` is the current name; `legendHandles` only exists in older Matplotlib.
legend_handles = getattr(legend, "legend_handles", None)
if legend_handles is None:
    legend_handles = legend.legendHandles
for legend_handle in legend_handles:
    legend_handle.set_linewidth(3.0)

if log_scale:
    ax.set_yscale("log")

if save_to_dir is not None:
    fig_name = f"{to_plot}"
    fig_name += f"_iters_{plot_iters}"
    fig_name += f"_log" if log_scale else ""
    fig.savefig(os.path.join(save_to_dir, f"{fig_name}.png"))

## Stiffness

In [None]:
### stiffness analysis settings - meta-testing (active)
phase = "meta_testing"
max_iters, max_test_runs, ckpt_iter_freq = 1000, 3, 1

### alternative settings for meta-training (uncomment to use instead)
# phase = "meta_training"
# max_iters = 200
# max_test_runs = 1
# epoch = 30
# ckpt_iter_freq = 5

In [None]:
### collect for L2O
# For each L2O optimizer, compare pairs of optimizee checkpoints and record the
# SAI returned by `calc_sai` for the parameters ("param_sai") and the gradients
# ("grad_sai"), per tensor and for all tracked tensors concatenated
# ("params_combined"). After collection each entry is an np.ndarray indexed as
# [test_run, iteration_pair]. Results are cached at `save_path` and reloaded
# instead of recomputed on later runs.


def _l2o_ckpt_iter_pairs(max_iters, ckpt_iter_freq):
    """Return the (start, end) checkpoint iterations whose SAI is computed.

    The start sequence is iteration 1 followed by multiples of
    ``ckpt_iter_freq``. Degenerate pairs with ``start == end`` (produced when
    ``ckpt_iter_freq == 1``) are dropped: they would compare a checkpoint with
    itself over a zero time delta and add one extra entry compared with the
    baselines' (i, i + 1) pairs.
    """
    starts = [1, *range(ckpt_iter_freq, max_iters, ckpt_iter_freq)]
    ends = range(ckpt_iter_freq, max_iters + ckpt_iter_freq, ckpt_iter_freq)
    return [(start, end) for start, end in zip(starts, ends) if start != end]


def _l2o_new_sai_series(sai_store, param_name):
    """Create the series for a tensor first seen during the current test run.

    The current run's list is zero-padded to the number of iterations already
    recorded under "params_combined" in this run, so the lists stay aligned.
    """
    sai_store[param_name] = [[]]
    combined = sai_store.get("params_combined")
    if combined:
        sai_store[param_name][-1].extend([0] * len(combined[-1]))


for opter_name in l2os:
    print(f"Collecting for {opter_name}...")
    config = l2os[opter_name]["config"]
    if phase == "meta_training":
        ckpts_dir = os.path.join(os.environ["CKPT_PATH"], config["meta_training"]["ckpt_dir"])
        save_path = os.path.join(os.environ["CKPT_PATH"], config["ckpt_base_dir"], f"stiffness_analysis_meta_training_{epoch}e.pt")
    else:
        ckpts_dir = os.path.join(os.environ["CKPT_PATH"], config["meta_testing"]["ckpt_dir"])
        save_path = os.path.join(os.environ["CKPT_PATH"], config["ckpt_base_dir"], "stiffness_analysis.pt")

    param_sais = l2os[opter_name]["param_sai"] = dict()
    grad_sais = l2os[opter_name]["grad_sai"] = dict()

    ### load already collected
    if os.path.exists(save_path):
        print(f"  Loading previously collected results...")
        stiffness_analysis = torch.load(save_path)
        for k in ["param_sai", "grad_sai"]:
            l2os[opter_name][k] = stiffness_analysis[k]
        continue

    for test_run_i in range(config["eval_n_tests"]):
        if max_test_runs and test_run_i >= max_test_runs:
            break
        print(f"  Test run {test_run_i}...")

        # open a new per-run list for every series created in earlier runs
        for sai_store in (param_sais, grad_sais):
            for series in sai_store.values():
                series.append([])

        for start_iter, end_iter in _l2o_ckpt_iter_pairs(max_iters, ckpt_iter_freq):
            time_delta = end_iter - start_iter

            ### load checkpoints
            if phase == "meta_training":
                ckpt_1_path = os.path.join(ckpts_dir, f"{epoch}e_{start_iter}.pt")
                ckpt_2_path = os.path.join(ckpts_dir, f"{epoch}e_{end_iter}.pt")
                if not (os.path.isfile(ckpt_1_path) and os.path.isfile(ckpt_2_path)):
                    ### missing checkpoint(s) are recorded as zero sai
                    for sai_store in (param_sais, grad_sais):
                        sai_store.setdefault("params_combined", [[]])
                        for series in sai_store.values():
                            series[-1].append(0)
                    continue
            else:
                ckpt_1_path = os.path.join(ckpts_dir, f"run{test_run_i}_{start_iter}.pt")
                ckpt_2_path = os.path.join(ckpts_dir, f"run{test_run_i}_{end_iter}.pt")

            ckpt_1 = torch.load(ckpt_1_path, map_location="cpu")
            ckpt_2 = torch.load(ckpt_2_path, map_location="cpu")
            assert ckpt_1["optimizee"].keys() == ckpt_2["optimizee"].keys()
            assert ckpt_1["optimizee_grads"].keys() == ckpt_2["optimizee_grads"].keys()

            # Only tensors with a gradient in BOTH checkpoints are analysed. The
            # same selection feeds the per-tensor and the combined SAI, so the
            # concatenated vectors of the two checkpoints always match in length.
            with_grads = {
                name for name, grad in ckpt_1["optimizee_grads"].items()
                if grad is not None and ckpt_2["optimizee_grads"][name] is not None
            }
            tracked_params = [name for name in ckpt_1["optimizee"] if name in with_grads]
            tracked_grads = [name for name in ckpt_1["optimizee_grads"] if name in with_grads]

            ### calculate sai per tensor
            for param_name in tracked_params:
                for sai_store in (param_sais, grad_sais):
                    if param_name not in sai_store:
                        _l2o_new_sai_series(sai_store, param_name)

                ### param sai
                param_sai = calc_sai(
                    vec_t0=ckpt_1["optimizee"][param_name].view(-1),
                    vec_t1=ckpt_2["optimizee"][param_name].view(-1),
                    time_delta=time_delta,
                    normalize=True
                )
                param_sais[param_name][-1].append(param_sai.item())

                ### grad sai
                grad_sai = calc_sai(
                    vec_t0=ckpt_1["optimizee_grads"][param_name].view(-1),
                    vec_t1=ckpt_2["optimizee_grads"][param_name].view(-1),
                    time_delta=time_delta,
                    normalize=True
                )
                grad_sais[param_name][-1].append(grad_sai.item())

            ### params combined - param sai
            params_combined_sai = calc_sai(
                vec_t0=torch.cat([ckpt_1["optimizee"][name].view(-1) for name in tracked_params]),
                vec_t1=torch.cat([ckpt_2["optimizee"][name].view(-1) for name in tracked_params]),
                time_delta=time_delta,
                normalize=True
            )
            param_sais.setdefault("params_combined", [[]])[-1].append(params_combined_sai.item())

            ### params combined - grad sai
            grads_combined_sai = calc_sai(
                vec_t0=torch.cat([ckpt_1["optimizee_grads"][name].view(-1) for name in tracked_grads]),
                vec_t1=torch.cat([ckpt_2["optimizee_grads"][name].view(-1) for name in tracked_grads]),
                time_delta=time_delta,
                normalize=True
            )
            grad_sais.setdefault("params_combined", [[]])[-1].append(grads_combined_sai.item())

    ### to np arrays ([test_run, iteration_pair])
    for sai_store in (param_sais, grad_sais):
        for param_name in sai_store:
            sai_store[param_name] = np.array(sai_store[param_name])

    ### save
    torch.save({
        "param_sai": param_sais,
        "grad_sai": grad_sais
    }, save_path)
    print(f"  Saved to {save_path}")

In [None]:
### collect for baselines
# For each baseline optimizer, compare every consecutive pair of checkpoints
# (1, 2), (2, 3), ..., (max_iters - 1, max_iters) stored in
# `<baseline_dir>/ckpt`. Record the SAI returned by `calc_sai` for the
# parameters ("param_sai") and gradients ("grad_sai"), per tensor and for all
# tracked tensors concatenated ("params_combined"). After collection each
# entry is an np.ndarray indexed as [test_run, iteration_pair]. Results are
# cached at `save_path` and reloaded instead of recomputed on later runs.
for baseline_name in baselines:
    print(f"Collecting for {baseline_name}...")
    config = baselines[baseline_name]["config"]
    ckpts_dir = os.path.join(baselines[baseline_name]["baseline_dir"], "ckpt")
    save_path = os.path.join(baselines[baseline_name]["baseline_dir"], "stiffness_analysis.pt")

    param_sais = baselines[baseline_name]["param_sai"] = dict()
    grad_sais = baselines[baseline_name]["grad_sai"] = dict()

    ### load already collected
    if os.path.exists(save_path):
        print(f"  Loading previously collected results...")
        stiffness_analysis = torch.load(save_path)
        for k in ["param_sai", "grad_sai"]:
            baselines[baseline_name][k] = stiffness_analysis[k]
        continue

    for test_run_i in range(config["eval_n_tests"]):
        if max_test_runs and test_run_i >= max_test_runs:
            break
        print(f"  Test run {test_run_i}...")

        # open a new per-run list for every series created in earlier runs
        for sai_store in (param_sais, grad_sais):
            for series in sai_store.values():
                series.append([])

        for start_iter, end_iter in zip(range(1, max_iters), range(2, max_iters + 1)):
            time_delta = end_iter - start_iter

            ### load checkpoints
            ckpt_1_path = os.path.join(ckpts_dir, f"run{test_run_i}_{start_iter}.pt")
            ckpt_2_path = os.path.join(ckpts_dir, f"run{test_run_i}_{end_iter}.pt")

            ckpt_1 = torch.load(ckpt_1_path, map_location="cpu")
            ckpt_2 = torch.load(ckpt_2_path, map_location="cpu")
            assert ckpt_1["optimizee"].keys() == ckpt_2["optimizee"].keys()
            assert ckpt_1["optimizee_grads"].keys() == ckpt_2["optimizee_grads"].keys()

            # Only tensors with a gradient in BOTH checkpoints are analysed. The
            # same selection feeds the per-tensor and the combined SAI, so the
            # concatenated vectors of the two checkpoints always match in length.
            with_grads = {
                name for name, grad in ckpt_1["optimizee_grads"].items()
                if grad is not None and ckpt_2["optimizee_grads"][name] is not None
            }
            tracked_params = [name for name in ckpt_1["optimizee"] if name in with_grads]
            tracked_grads = [name for name in ckpt_1["optimizee_grads"] if name in with_grads]

            ### calculate sai per tensor
            for param_name in tracked_params:
                param_sai = calc_sai(
                    vec_t0=ckpt_1["optimizee"][param_name].view(-1),
                    vec_t1=ckpt_2["optimizee"][param_name].view(-1),
                    time_delta=time_delta,
                    normalize=True
                )
                param_sais.setdefault(param_name, [[]])[-1].append(param_sai.item())

                grad_sai = calc_sai(
                    vec_t0=ckpt_1["optimizee_grads"][param_name].view(-1),
                    vec_t1=ckpt_2["optimizee_grads"][param_name].view(-1),
                    time_delta=time_delta,
                    normalize=True
                )
                grad_sais.setdefault(param_name, [[]])[-1].append(grad_sai.item())

            ### params combined - param sai
            params_combined_sai = calc_sai(
                vec_t0=torch.cat([ckpt_1["optimizee"][name].view(-1) for name in tracked_params], dim=0),
                vec_t1=torch.cat([ckpt_2["optimizee"][name].view(-1) for name in tracked_params], dim=0),
                time_delta=time_delta,
                normalize=True
            )
            param_sais.setdefault("params_combined", [[]])[-1].append(params_combined_sai.item())

            ### params combined - grad sai
            params_combined_grads_sai = calc_sai(
                vec_t0=torch.cat([ckpt_1["optimizee_grads"][name].view(-1) for name in tracked_grads], dim=0),
                vec_t1=torch.cat([ckpt_2["optimizee_grads"][name].view(-1) for name in tracked_grads], dim=0),
                time_delta=time_delta,
                normalize=True
            )
            grad_sais.setdefault("params_combined", [[]])[-1].append(params_combined_grads_sai.item())

    ### to np arrays ([test_run, iteration_pair])
    for sai_store in (param_sais, grad_sais):
        for param_name in sai_store:
            sai_store[param_name] = np.array(sai_store[param_name])

    ### save
    torch.save({
        "param_sai": param_sais,
        "grad_sai": grad_sais
    }, save_path)
    print(f"  Saved to {save_path}")

In [None]:
to_plot = "grad_sai"
log_scale = True
iters_window = (0, max_iters - 1)
# iters_window = (0, 200)
x_ticks = range(iters_window[0], iters_window[1], ckpt_iter_freq)
save_fig = True

# Columns of the [test_run, iteration_pair] arrays shown in the window. The
# slice length is tied to len(x_ticks) so x and y always have as many points
# (an `end // freq + 1` bound yields one extra point whenever the data extends
# past the window, e.g. iters_window = (0, 200) with ckpt_iter_freq = 1).
window_start = iters_window[0] // ckpt_iter_freq
window_slice = slice(window_start, window_start + len(x_ticks))


def _plot_sai_mean_std(ax, runs, label, columns, xs, **line_kwargs):
    """Plot the mean over test runs (axis 0) of `runs[:, columns]` with a +/- std band."""
    windowed = runs[:, columns]
    y_mean = np.mean(windowed, axis=0)
    y_std = np.std(windowed, axis=0)
    xs = xs[:len(y_mean)]  # data may end before the window does
    sns.lineplot(x=xs, y=y_mean, label=label, ax=ax, **line_kwargs)
    ax.fill_between(x=xs, y1=y_mean - y_std, y2=y_mean + y_std, alpha=0.2)


# The subplot layout follows the tensor names of the last L2O in `l2os`.
layout_opter_name = list(l2os)[-1]
param_names = list(l2os[layout_opter_name][to_plot].keys())

# Two-column grid: each tensor gets one cell; "params_combined" spans a full
# row and therefore always starts at the left column.
subplot_positions = []
next_slot = 1
for param_name in param_names:
    if param_name == "params_combined":
        if next_slot % 2 == 0:
            next_slot += 1
        subplot_positions.append((next_slot, next_slot + 1))
        next_slot += 2
    else:
        subplot_positions.append(next_slot)
        next_slot += 1
n_rows = max(3, math.ceil((next_slot - 1) / 2))  # at least the original 3x2 grid

y_label = to_plot.split("_")
y_label = " ".join((y_label[0].title(), y_label[1].upper()))
if log_scale:
    y_label += " (log scale)"

fig = plt.figure(figsize=(22, 6 * n_rows))
for param_name, position in zip(param_names, subplot_positions):
    ax = fig.add_subplot(n_rows, 2, position)
    ax.set_title(param_name, fontsize=18)

    ### plot baselines
    for baseline_name in baselines:
        _plot_sai_mean_std(
            ax,
            baselines[baseline_name][to_plot][param_name],
            baseline_name,
            window_slice,
            x_ticks,
            linestyle="--",
            linewidth=2,
        )

    ### plot L2O
    for opter_name in l2os:
        _plot_sai_mean_std(
            ax,
            l2os[opter_name][to_plot][param_name],
            opter_name,
            window_slice,
            x_ticks,
            linewidth=2.5,
        )

    ax.set_xlabel("Iteration", fontsize=14)
    if log_scale:
        ax.set_yscale("log")
    ax.set_ylabel(y_label, fontsize=14)
    ax.legend(fontsize=14)

fig.tight_layout()
if save_fig:
    fig_name = f"{to_plot}_iters_{iters_window[0]}_{iters_window[1]}"
    if phase == "meta_training":
        fig_name += f"_{epoch}e"
    if log_scale:
        fig_name += "_log"
    fig.savefig(f"{fig_name}.png")

## Parameter updates histogram

In [None]:
def plot_hist(updates, grads, iterations, plot_abs=True, log_scale=True, save_to_dir=None):
    """Plot overlaid histograms of updates and gradients, one subplot per iteration.

    All subplots share the same x-limits so the distributions can be compared
    across iterations.

    Args:
        updates: 2D array/tensor ``[len(iterations), n_values]``; row ``i``
            holds the update values at ``iterations[i]``. torch tensors and
            numpy arrays both work (absolute values use the builtin ``abs``).
        grads: same number of rows and columns as ``updates``; gradient values.
        iterations: iteration numbers, used for the subplot titles.
        plot_abs: histogram absolute values instead of signed values.
        log_scale: use a logarithmic count (y) axis.
        save_to_dir: if not None, save the figure into this directory as
            ``update_hist_all[_abs][_log].png``.
    """
    assert updates.ndim == 2, "updates must be a 2D array [iterations, updates]"
    assert len(updates) == len(grads)
    assert len(updates) == len(iterations)
    assert updates.shape[1] == grads.shape[1]

    n_axes = len(updates)
    fig = plt.figure(figsize=(12, 3 * n_axes))
    axes = [fig.add_subplot(n_axes, 1, i + 1) for i in range(n_axes)]
    min_val, max_val = np.inf, -np.inf

    for ax, iteration, iter_updates, iter_grads in zip(axes, iterations, updates, grads):
        ax.set_title(f"Iteration {iteration}")
        ax.set_xlabel("Value" if not plot_abs else "Absolute Value")
        ax.set_ylabel("Count")

        if not plot_abs:
            # dashed reference line at zero spanning the full axes height
            ax.axvline(0, color="black", linewidth=0.5, linestyle="--", alpha=0.5)

        sns.histplot(
            abs(iter_updates) if plot_abs else iter_updates,
            ax=ax,
            # binwidth=0.0075,
            # bins=50,
            label="update",
            legend=False,
            element="step"
        )
        sns.histplot(
            abs(iter_grads) if plot_abs else iter_grads,
            ax=ax,
            # binwidth=0.0075,
            # bins=50,
            color="red",
            alpha=0.35,
            label="gradient",
            legend=False,
            element="step"
        )

        ax.legend()
        if log_scale:
            ax.set_yscale("log")

        # widest x-range seen so far across all axes
        x_lo, x_hi = ax.get_xlim()
        min_val = min(min_val, x_lo)
        max_val = max(max_val, x_hi)

    # set all axes to have the same x limits
    for ax in axes:
        ax.set_xlim(min_val, max_val)

    fig.tight_layout()
    if save_to_dir is not None:
        fig_name = "update_hist_all"
        if plot_abs:
            fig_name += "_abs"
        if log_scale:
            fig_name += "_log"
        fig.savefig(os.path.join(save_to_dir, f"{fig_name}.png"))

In [None]:
### plot histogram of updates - config
# plotting switches
plot_abs = True
log_scale = True
# factor multiplied onto the stored optimizee updates before plotting
optee_updates_lr = 0.1

# the first L2O in `l2os` is the one whose checkpoints are plotted
l2o_to_plot = [*l2os][0]
config = l2os[l2o_to_plot]["config"]
print(f"Plotting {l2o_to_plot} updates")
# NOTE(review): directory name says "meta_training" while the active settings
# below read meta-testing checkpoints; confirm this is intended.
save_to_dir = "../results/sym_breaking_regularization/MNISTNet_meta_training"
# save_to_dir = None

### collect for meta-testing
ckpts_dir = os.path.join(os.environ["CKPT_PATH"], config["meta_testing"]["ckpt_dir"])
show_iters = [1, 5, 10, 20, 50, 100, 200, 500, 1000]
ckpt_prefix = "run0_"

### collect for meta-training (uncomment to use instead)
# ckpts_dir = os.path.join(os.environ["CKPT_PATH"], config["meta_training"]["ckpt_dir"])
# show_iters = [1, 5, 10, 20, 50, 100, 200]
# ckpt_prefix = "40e_"

In [None]:
### collect and plot histogram of updates
def _flat_concat(tensors_by_name, names):
    """Flatten each named tensor and concatenate them into one 1D tensor."""
    return torch.cat([tensors_by_name[name].view(-1) for name in names])


all_updates, all_grads, iterations = [], [], []
for iteration in show_iters:
    ckpt_path = os.path.join(ckpts_dir, f"{ckpt_prefix}{iteration}.pt")
    ckpt = torch.load(ckpt_path, map_location="cpu")

    # One shared list of tensor names for updates and gradients, so the two
    # vectors cover the same tensors. Names containing "running" (presumably
    # BatchNorm running statistics) and tensors with no stored update or
    # gradient (gradients can be None in these checkpoints) are skipped.
    names = [
        name for name in ckpt["optimizee_updates"]
        if "running" not in name
        and ckpt["optimizee_updates"][name] is not None
        and ckpt["optimizee_grads"].get(name) is not None
    ]
    updates = optee_updates_lr * _flat_concat(ckpt["optimizee_updates"], names)
    grads = _flat_concat(ckpt["optimizee_grads"], names)

    all_updates.append(updates)
    all_grads.append(grads)
    iterations.append(iteration)

# [len(show_iters), n_values]
all_updates = torch.stack(all_updates, dim=0)
all_grads = torch.stack(all_grads, dim=0)
plot_hist(all_updates, all_grads, iterations, plot_abs=plot_abs, log_scale=log_scale, save_to_dir=save_to_dir)