In [None]:
import json
from copy import deepcopy
import numpy as np
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
from datetime import datetime
import os.path
import copy
import seaborn as sns
import lovely_tensors as lt # can be removed

from l2o.others import w, detach_var, rsetattr, rgetattr, count_parameters, print_grads, \
    load_l2o_opter_ckpt, load_baseline_opter_ckpt, load_ckpt, dict_to_str, get_baseline_ckpt_dir
from l2o.visualization import get_model_dot
from l2o.training import do_fit, fit_normal, fit_optimizer, find_best_lr_normal
from l2o.regularization import (
    regularize_updates_translation_constraints,
    regularize_updates_scale_constraints,
    regularize_updates_rescale_constraints,
    regularize_updates_constraints,
    regularize_translation_conservation_law_breaking,
    regularize_rescale_conservation_law_breaking,
)
from l2o.analysis import (
    get_baseline_opter_param_updates,
    collect_rescale_sym_deviations,
    collect_translation_sym_deviations,
    collect_scale_sym_deviations,
)
from l2o.data import MNIST, CIFAR10
from l2o.optimizer import Optimizer
from l2o.optimizee import (
    MNISTSigmoid,
    MNISTReLU,
    MNISTNet,
    MNISTNet2Layer,
    MNISTNetBig,
    MNISTRelu,
    MNISTLeakyRelu,
    MNISTSimoidBatchNorm,
    MNISTReluBatchNorm,
    MNISTConv,
    MNISTReluBig,
    MNISTReluBig2Layer,
    MNISTMixtureOfActivations,
    MNISTNetBig2Layer,
)
from l2o.meta_module import *
from meta_test import meta_test, meta_test_baselines

# Optional lovely_tensors patch (presumably only affects how tensors are printed -- TODO confirm).
lt.monkey_patch() # can be removed
# Global seaborn styling; applies to every figure created later in the notebook.
sns.set(color_codes=True)
sns.set_style("white")

# Initialize L2O Optimizer

### Load pre-trained L2O

In [None]:
### load a previous checkpoint (and skip meta-training of a new L2O optimizer)
# Name of the run directory under $CKPT_PATH that holds the pre-trained optimizer.
pretrained_run_name = "10-05-2023_13-26-17_MNISTReluBatchNorm_Optimizer"
pretrained_run_dir = os.path.join(os.environ["CKPT_PATH"], pretrained_run_name)

opter, config, ckpt = load_ckpt(dir_path=pretrained_run_dir)

# default=str stringifies any config value that json cannot serialize natively.
config_as_json = json.dumps(config, indent=4, default=str)
print(config_as_json)

### Meta-train a new L2O

In [None]:
### create a new config
# Values used by more than one section are bound once, so every section
# references the very same objects (the dicts are shared, not copied).
meta_opter_cls = Optimizer
meta_opter_config = {
    "preproc": True,
    "additional_inp_dim": 0,
    "manual_init_output_params": False,
}
task_data_cls = MNIST
task_data_config = {
    "batch_size": 128,
    # "only_classes": [0, 1],
}
task_optee_cls = MNISTLeakyRelu
task_optee_config = {}
task_optee_updates_lr = 0.1

# Run directory name: <timestamp>_<optimizee class>_<optimizer class>
run_name = "_".join([
    datetime.now().strftime('%d-%m-%Y_%H-%M-%S'),
    task_optee_cls.__name__,
    meta_opter_cls.__name__,
])
run_dir = os.path.join(os.environ["CKPT_PATH"], run_name)

config = { # global config
    "opter_cls": meta_opter_cls,
    "opter_config": meta_opter_config,
    "eval_n_tests": 5,
    "ckpt_base_dir": run_name,
    "ckpt_baselines_dir": "baselines",
}

config["meta_training"] = { # training the optimizer
    "opter_cls": meta_opter_cls,
    "opter_config": meta_opter_config,
    "data_cls": task_data_cls,
    "data_config": task_data_config,
    "optee_cls": task_optee_cls,
    "optee_config": task_optee_config,
    "n_epochs": 50,
    "n_optim_runs_per_epoch": 20,
    "n_iters": 200,
    "unroll": 20,
    "n_tests": 0,
    "optee_updates_lr": task_optee_updates_lr,
    "opter_lr": 0.001,
    "log_unroll_losses": True,
    "opter_updates_reg_func": None,
    "opter_updates_reg_func_config": {},
    "reg_mul": 0,
    "eval_iter_freq": 10,
    "ckpt_iter_freq": 5,
    "ckpt_dir": os.path.join(run_dir, "meta_training"),
    "load_ckpt": None,
    "start_from_epoch": 0,
    "verbose": 2,
}

config["meta_testing"] = { # testing the optimizer (same data/optimizee as meta-training)
    "data_cls": task_data_cls,
    "data_config": task_data_config,
    "optee_cls": task_optee_cls,
    "optee_config": task_optee_config,
    "unroll": 1,
    "n_iters": 1000,
    "optee_updates_lr": task_optee_updates_lr,
    "train_opter": False,
    "opter_optim": None,
    "ckpt_iter_freq": 5,
    "ckpt_dir": os.path.join(run_dir, "meta_testing"),
}

In [None]:
### create directories
run_root = os.path.join(os.environ["CKPT_PATH"], config["ckpt_base_dir"])

# exist_ok=False: raises FileExistsError if this run directory already exists.
os.makedirs(run_root, exist_ok=False)
for phase in ("meta_training", "meta_testing"):
    os.makedirs(config[phase]["ckpt_dir"], exist_ok=True)

### dump config
# default=str stringifies values json cannot serialize (e.g. the classes stored in the config).
config_path = os.path.join(run_root, "config.json")
with open(config_path, "w") as f:
    json.dump(config, f, indent=4, default=str)

print(f"Base directory: {config['ckpt_base_dir']}")

In [None]:
### meta-train a new L2O optimizer model
# Fix the torch and numpy RNG seeds before training.
torch.manual_seed(0)
np.random.seed(0)

meta_training_kwargs = config["meta_training"]
best_training_loss, metrics, opter_state_dict = fit_optimizer(**meta_training_kwargs)

# Rebuild an optimizer instance and load the weights returned by fit_optimizer.
# NOTE(review): `w` comes from l2o.others; presumably it moves the module to the
# active device -- confirm.
opter_kwargs = config["opter_config"] if config["opter_config"] is not None else {}
opter = w(config["opter_cls"](**opter_kwargs))
opter.load_state_dict(opter_state_dict)
print(best_training_loss)

### save the final model
final_ckpt = {
    "state_dict": opter_state_dict,
    "config": config,
    "loss": best_training_loss,
    "metrics": metrics,
}
final_ckpt_path = os.path.join(os.environ["CKPT_PATH"], config["ckpt_base_dir"], "l2o_optimizer.pt")
torch.save(final_ckpt, final_ckpt_path)

# Evaluate

In [None]:
# Optimizees to evaluate on: (optimizee class, constructor config) pairs.
# The cells below build run nicknames / directory names from the class name and
# this config, so changing an entry changes which metrics files are looked up.
# NOTE(review): both MNISTRelu and MNISTReLU are imported from l2o.optimizee and
# listed here -- presumably distinct classes; confirm this is intentional.
optees_to_test = [
    (MNISTRelu, {}),
    (MNISTConv, {}),
    (MNISTSigmoid, {"layer_sizes": [100, 100]}),
    (MNISTReLU, {"layer_sizes": [100, 100]}),
    (MNISTNet2Layer, {}),
    (MNISTNetBig, {}),
    (MNISTReluBatchNorm, {"affine": True, "track_running_stats": True}),
    (MNISTNet, {}),
    (MNISTLeakyRelu, {}),
]
# Baseline optimizers: (display name, torch.optim class, optimizer config).
# "lr" is given as the function find_best_lr_normal rather than a number;
# presumably meta_test_baselines calls it to pick the learning rate -- TODO confirm.
baselines_to_test_against = [
    ("Adam", optim.Adam, {"lr": find_best_lr_normal}),
    ("SGD", optim.SGD, {"lr": find_best_lr_normal, "momentum": 0.9}),
]
use_existing_baselines = True # loads existing if exists, otherwise trains from scratch

### Meta-test L2O optimizer and baselines

In [None]:
# Meta-test the L2O optimizer `opter` on every optimizee in `optees_to_test`.
# NOTE(review): the metrics-loading cell below reads "metrics_*" files from the run
# directory; presumably meta_test writes them there -- confirm.
results = meta_test(
    opter=opter,
    optees=optees_to_test,
    config=config,
    save_ckpts_for_all_test_runs=False,
)

In [None]:
# Run the baseline optimizers (Adam, SGD) on the same optimizees for comparison.
# NOTE(review): the baseline-loading cell below reads metrics.npy files from
# $CKPT_PATH/<ckpt_baselines_dir>/...; presumably this call produces them -- confirm.
results_baselines = meta_test_baselines(
    baseline_opters=baselines_to_test_against,
    optees=optees_to_test,
    config=config,
    use_existing_baselines=use_existing_baselines,
    save_ckpts_for_all_test_runs=False,
)

In [None]:
### load all l2o metrics from disk
# Metrics files are named "metrics_<run nickname>.npy"; the nickname becomes the dict key.
METRICS_FILE_PREFIX = "metrics_"
METRICS_FILE_SUFFIX = ".npy"

l2o_run_dir = os.path.join(os.environ["CKPT_PATH"], config["ckpt_base_dir"])
l2o_metrics = {}

# sorted() makes the load order (and hence the later plotting order) deterministic.
for metrics_file in sorted(os.listdir(l2o_run_dir)):
    # Require BOTH prefix and suffix: the key is derived by stripping both, and any
    # other "metrics_*" file in the directory would make np.load fail.
    if not (metrics_file.startswith(METRICS_FILE_PREFIX) and metrics_file.endswith(METRICS_FILE_SUFFIX)):
        continue

    print(f"Loading {metrics_file}")
    metrics_name = metrics_file[len(METRICS_FILE_PREFIX):-len(METRICS_FILE_SUFFIX)]
    # NOTE(review): allow_pickle=True unpickles arbitrary objects -- only load files
    # produced by your own runs. .item() unwraps the stored Python object from the
    # 0-d array np.load returns.
    l2o_metrics[metrics_name] = np.load(os.path.join(l2o_run_dir, metrics_file), allow_pickle=True).item()

In [None]:
### load baseline metrics from disk
# Result layout: baseline_metrics[run_nickname][opter_name] -> loaded metrics object.
baseline_metrics = {}
baselines_root = os.path.join(os.environ["CKPT_PATH"], config["ckpt_baselines_dir"])

for optee_cls, optee_config in optees_to_test:
    # Nickname identifying the optimizee + dataset combination.
    run_nickname = "_".join([
        optee_cls.__name__,
        dict_to_str(optee_config),
        config['meta_testing']['data_cls'].__name__,
        dict_to_str(config['meta_testing']['data_config']),
    ])
    baseline_metrics[run_nickname] = {}

    ### load metrics for all considered baselines
    for opter_name, baseline_opter_cls, baseline_opter_config in baselines_to_test_against:
        baseline_opter_config_copy = deepcopy(baseline_opter_config)

        # A callable "lr" is represented by the function's name in the directory name.
        if callable(baseline_opter_config.get("lr")):
            baseline_opter_config_copy["lr"] = baseline_opter_config_copy["lr"].__name__

        # Directory name: <optimizer name>_<optimizer config>_<run nickname>
        baseline_dir_name = f"{opter_name}_{dict_to_str(baseline_opter_config_copy)}_{run_nickname}"
        metrics_path = os.path.join(baselines_root, baseline_dir_name, "metrics.npy")

        ### load
        print(f"Loading {metrics_path}")
        loaded = np.load(metrics_path, allow_pickle=True)
        baseline_metrics[run_nickname][opter_name] = loaded.item()

### Plot meta-testing results

In [None]:
### plotting options
show_max_iters = 500  # plot at most this many optimization iterations
log_losses = False  # if True, loss panels use a logarithmic y axis
fig_save_dir = os.path.join(os.environ["CKPT_PATH"], config["ckpt_base_dir"])
# fig_save_dir = None
# Test metrics are recorded once every `eval_iter_freq` iterations.
# NOTE(review): the same frequency is applied to the baseline metrics (the previous
# code hard-coded 10 for them, which equals the configured value) -- confirm that
# baselines are evaluated at this frequency.
eval_iter_freq = config["meta_training"]["eval_iter_freq"]


def _summarize_runs(run_values, is_eval_metric, max_iters, eval_freq):
    """Reduce per-run curves to ``(x, mean, min, max)`` across runs.

    Args:
        run_values: 2-D array indexed as ``[run, step]``.
        is_eval_metric: True if one step corresponds to ``eval_freq`` iterations
            (test metrics), False if one step is one iteration (train metrics).
        max_iters: maximum number of iterations to show.
        eval_freq: number of iterations between two evaluation steps.

    Returns:
        Tuple ``(x, y_mean, y_min, y_max)`` of equally long 1-D arrays. ``x`` is
        derived from the number of steps actually available, so it can never be
        longer than the data.
    """
    n_steps = max_iters // eval_freq if is_eval_metric else max_iters
    clipped = run_values[:, :n_steps]
    if is_eval_metric:
        # evaluation step k (1-based) happens at iteration k * eval_freq
        x = eval_freq * np.arange(1, clipped.shape[1] + 1)
    else:
        x = np.arange(clipped.shape[1])
    return x, np.mean(clipped, axis=0), np.min(clipped, axis=0), np.max(clipped, axis=0)


def _plot_mean_with_range(ax, x, y, y_min, y_max, label, band_alpha, line_kwargs=None, band_kwargs=None):
    """Draw the mean curve plus a shaded min/max band on ``ax``."""
    sns.lineplot(x=x, y=y, label=label, ax=ax, **(line_kwargs or {}))
    ax.fill_between(x=x, y1=y_min, y2=y_max, alpha=band_alpha, **(band_kwargs or {}))


# The loop variable is deliberately NOT called `metrics`, so the training metrics
# returned by fit_optimizer are not overwritten.
for optee_nickname, optee_metrics in l2o_metrics.items():
    if optee_nickname not in baseline_metrics:
        continue
    curr_baseline_metrics = baseline_metrics[optee_nickname]

    ### plot comparison
    fig = plt.figure(figsize=(26, 16))
    fig.suptitle(f"Meta-testing on {optee_nickname}", fontsize=16, fontweight="bold", y=0.92)

    for m_i, metric in enumerate(["train_loss", "test_loss", "train_acc", "test_acc"]):
        ax = fig.add_subplot(2, 2, m_i + 1)
        is_eval_metric = "test" in metric

        ### baseline optimizers
        for opter_name, opter_metrics in curr_baseline_metrics.items():
            _plot_mean_with_range(
                ax,
                *_summarize_runs(opter_metrics[metric], is_eval_metric, show_max_iters, eval_iter_freq),
                label=opter_name,
                band_alpha=0.1,
                line_kwargs={"linestyle": "--"},
            )

        ### L2O optimizer
        _plot_mean_with_range(
            ax,
            *_summarize_runs(optee_metrics[metric], is_eval_metric, show_max_iters, eval_iter_freq),
            label="L2O",
            band_alpha=0.2,
            line_kwargs={"color": "orange", "linewidth": 2},
            band_kwargs={"color": "orange"},
        )

        ### plot settings
        ax.set_title(metric, fontsize=14, fontweight="bold")
        ax.set_xlabel("Iteration")
        ax.set_ylabel(metric)
        if log_losses and "loss" in metric:
            # applied per axis, independent of whether any baseline was plotted
            ax.set_yscale("log")
        if "acc" in metric:
            ax.set_ylim(0.6, 1.0)
        ax.legend()

    plt.show()

    ### save the figure
    if fig_save_dir is not None:
        prefix = "log_losses" if log_losses else "losses"
        fig.savefig(os.path.join(fig_save_dir, f"{prefix}_{optee_nickname}_{show_max_iters}.png"))

    # release the figure so looping over many optimizees does not accumulate open figures
    plt.close(fig)