In [None]:
import json
import time
from copy import deepcopy
import numpy as np
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
from datetime import datetime
import os
import os.path
import copy
import seaborn as sns
import lovely_tensors as lt # can be removed

from l2o.others import w, detach_var, rsetattr, rgetattr, count_parameters, print_grads, \
    load_l2o_opter_ckpt, load_baseline_opter_ckpt, load_ckpt, dict_to_str, get_baseline_ckpt_dir
from l2o.training import fit_normal, find_best_lr_normal
from l2o.regularization import (
    regularize_updates_translation_constraints,
    regularize_updates_scale_constraints,
    regularize_updates_rescale_constraints,
    regularize_updates_constraints,
    regularize_translation_conservation_law_breaking,
    regularize_rescale_conservation_law_breaking,
)
from l2o.analysis import (
    get_baseline_opter_param_updates,
    collect_rescale_sym_deviations,
    collect_translation_sym_deviations,
    collect_scale_sym_deviations,
    calc_sai,
)
from l2o.data import MNIST, CIFAR10
from l2o.optimizer import Optimizer
from l2o.optimizee import (
    MNISTSigmoid,
    MNISTReLU,
    MNISTNet,
    MNISTNet2Layer,
    MNISTNetBig,
    MNISTRelu,
    MNISTLeakyRelu,
    MNISTSimoidBatchNorm,
    MNISTReluBatchNorm,
    MNISTConv,
    MNISTReluBig,
    MNISTReluBig2Layer,
    MNISTMixtureOfActivations,
    MNISTNetBig2Layer,
)
from l2o.meta_module import *
# from meta_test import meta_test, meta_test_baselines

# Optional lovely_tensors patch (presumably nicer tensor printing while debugging); safe to drop.
lt.monkey_patch() # can be removed
# Global plotting theme. set_style() runs after set() so that "white" overrides the style chosen by set().
sns.set(color_codes=True)
sns.set_style("white")

## Utils

In [None]:
def do_fit(
    opter,
    opter_optim,
    data_cls,
    optee_cls,
    unroll,
    n_iters,
    optee_updates_lr,
    data_config=None,
    train_opter=True,
    log_unroll_losses=False,
    opter_updates_reg_func=None,
    opter_updates_reg_func_config=None,
    reg_mul=1.0,
    optee_config=None,
    eval_iter_freq=10,
    ckpt_iter_freq=None,
    ckpt_prefix="",
    ckpt_dir="",
):
    """Train one freshly initialized optimizee with the L2O optimizer `opter`.

    Optionally meta-trains `opter` at the same time.

    Each iteration does the following:

    - Compute the training loss via ``optee(train_data)``.
    - Backpropagate it to get the optimizee gradients.
    - Feed each trainable parameter's flattened gradient, plus an extra input from
      `calc_sai` (zeros at iteration 1), through `opter`.
    - Apply ``p + optee_updates_lr * update``.

    Every `unroll` iterations the following happens:

    - If `train_opter`, the summed unroll losses are backpropagated into `opter` and
      `opter_optim` takes a step. The losses are log-transformed if
      `log_unroll_losses`, and any regularization losses are added.
    - The optimizee is rebuilt from the updated parameters, and the optimizer's
      hidden/cell states are detached.

    Parameters
    ----------
    opter : module
        The L2O optimizer. Must expose `hidden_sz`. Called as
        ``opter(optee_grads=..., hidden=..., cell=..., additional_inp=...)`` and
        returns ``(updates, new_hidden, new_cell)``.
    opter_optim : torch optimizer
        Optimizer over `opter`'s parameters. Only used when `train_opter` is True.
    data_cls, data_config :
        Dataset class and its kwargs; instantiated with ``training=True/False``.
    optee_cls, optee_config :
        Optimizee class and its kwargs.
    unroll : int
        Number of iterations between meta-updates. Forced to 1 when `train_opter`
        is False.
    n_iters : int
        Number of optimizee iterations.
    optee_updates_lr : float
        Scale applied to `opter`'s updates.
    train_opter : bool
        Whether to meta-train `opter` during this run.
    log_unroll_losses : bool
        Take ``log`` of the summed unroll loss before backprop.
    opter_updates_reg_func, opter_updates_reg_func_config, reg_mul :
        Optional regularizer (only used when `train_opter`). Its absolute value,
        scaled by `reg_mul`, is added to the meta-loss.
    eval_iter_freq : int or None
        Evaluate on the test split at iteration 1 and every `eval_iter_freq`
        iterations. None disables evaluation.
    ckpt_iter_freq : int or None
        Save a checkpoint at iteration 1 and every `ckpt_iter_freq` iterations.
    ckpt_prefix, ckpt_dir : str
        File-name prefix and directory for checkpoints. A `ckpt_dir` that does
        not start with ``$CKPT_PATH`` is placed under it.

    Returns
    -------
    dict
        Metric name -> list of floats:

        - "train_loss" / "train_acc": every iteration.
        - "test_loss" / "test_acc": at the evaluation iterations.
        - "train_reg_loss": every iteration, when a regularizer is used while
          meta-training.
    """
    data_kwargs = data_config if data_config is not None else {}
    optee_kwargs = optee_config if optee_config is not None else {}
    reg_func_kwargs = (
        opter_updates_reg_func_config if opter_updates_reg_func_config is not None else {}
    )

    if train_opter:
        opter.train()
        opter_optim.zero_grad()
    else:
        opter.eval()
        unroll = 1

    train_data = data_cls(training=True, **data_kwargs)
    test_data = data_cls(training=False, **data_kwargs)
    optee = w(optee_cls(**optee_kwargs))
    optee_n_params = sum(
        int(np.prod(p.size())) for _, p in optee.all_named_parameters() if p.requires_grad
    )

    ### save initial optee's parameters for regularization
    optee_init_params = {
        name: p.data.detach().clone() for name, p in optee.all_named_parameters()
    }

    def _zero_states():
        # One zero state tensor per entry, shaped (n trainable optee params, hidden size).
        return [
            w(Variable(torch.zeros(optee_n_params, opter.hidden_sz))) for _ in range(2)
        ]

    ### initialize hidden and cell states
    hidden_states = _zero_states()
    cell_states = _zero_states()

    metrics = {m: [] for m in ["train_loss", "train_acc", "test_loss", "test_acc"]}
    unroll_losses = None
    reg_losses = None
    updates_for_ckpt = dict()
    prev_params = dict()

    ### run optee's training loop
    for iteration in range(1, n_iters + 1):
        is_ckpt_iter = bool(ckpt_iter_freq) and (
            iteration % ckpt_iter_freq == 0 or iteration == 1
        )

        ### train optee
        optee.train()
        train_loss, train_acc = optee(train_data, return_acc=True)  # a single minibatch
        # Keep the graph while meta-training: the unroll loss is backpropagated again later.
        train_loss.backward(retain_graph=train_opter)

        ### track for training opter
        unroll_losses = train_loss if unroll_losses is None else unroll_losses + train_loss

        ### optimizer: gradients -> updates
        result_params = dict()
        updates_for_reg = dict()
        hidden_states2 = _zero_states()
        cell_states2 = _zero_states()
        offset = 0
        for name, p in optee.all_named_parameters():
            if not p.requires_grad:  # batchnorm stats
                result_params[name] = p
                continue

            # Gradients are detached from the graph, so opter is not trained through
            # second-order terms; gradients still flow through the rest.
            cur_sz = int(np.prod(p.size()))
            gradients = detach_var(p.grad.view(cur_sz, 1))

            ### prepare additional input for optimizer
            if iteration == 1:
                additional_inp = w(torch.zeros(cur_sz, 1))
            else:
                additional_inp = w(calc_sai(
                    vec_t0=prev_params[name].detach().view(-1),
                    vec_t1=p.detach().view(-1),
                    time_delta=1,
                    normalize=True,
                ).expand(cur_sz, 1))

            updates, new_hidden, new_cell = opter(
                optee_grads=gradients,
                hidden=[h[offset : offset + cur_sz] for h in hidden_states],
                cell=[c[offset : offset + cur_sz] for c in cell_states],
                additional_inp=additional_inp,
            )

            ### track updates for checkpointing
            if is_ckpt_iter:
                updates_for_ckpt[name] = updates.view(*p.size()).detach()

            ### track updates for regularization
            if train_opter and opter_updates_reg_func is not None:
                updates_for_reg[name] = updates.view(*p.size())

            ### update hidden and cell states
            for i in range(len(new_hidden)):
                hidden_states2[i][offset : offset + cur_sz] = new_hidden[i]
                cell_states2[i][offset : offset + cur_sz] = new_cell[i]

            ### update optee's params (non-leaf; retain_grad so .grad is filled next iteration)
            result_params[name] = p + optee_updates_lr * updates.view(*p.size())
            result_params[name].retain_grad()
            offset += cur_sz

        ### add regularization loss for opter
        if train_opter and opter_updates_reg_func is not None:
            if "conservation_law_breaking" in opter_updates_reg_func.__name__:  # TODO: quick hack
                reg_value = opter_updates_reg_func(
                    optee=optee,
                    params_t0=optee_init_params,
                    **reg_func_kwargs,
                )
            else:
                reg_value = opter_updates_reg_func(
                    updates=updates_for_reg,
                    optee=optee,
                    lr=optee_updates_lr,
                    **reg_func_kwargs,
                )
            reg_loss = torch.abs(reg_mul * reg_value)
            reg_losses = reg_loss if reg_losses is None else reg_losses + reg_loss
            metrics.setdefault("train_reg_loss", []).append(reg_loss.item())

        ### track metrics
        metrics["train_loss"].append(train_loss.item())
        metrics["train_acc"].append(train_acc.item())

        ### eval
        if eval_iter_freq is not None and (iteration % eval_iter_freq == 0 or iteration == 1):
            optee.eval()
            test_loss, test_acc = optee(test_data, return_acc=True)
            metrics["test_loss"].append(test_loss.item())
            metrics["test_acc"].append(test_acc.item())
            optee.train()

        ### checkpoint
        if is_ckpt_iter:
            ckpt = {
                "optimizee": optee.state_dict(),
                "optimizee_grads": {k: v.grad for k, v in optee.all_named_parameters()},
                "optimizee_updates": updates_for_ckpt,
                "optimizer": opter.state_dict(),
                "hidden_states": hidden_states,
                "cell_states": cell_states,
                "metrics": metrics,
            }
            if not ckpt_dir.startswith(os.environ["CKPT_PATH"]):
                print(f"[WARNING] ckpt_dir {ckpt_dir} does not start with CKPT_PATH, prepending CKPT_PATH to it")
                ckpt_dir = os.path.join(os.environ["CKPT_PATH"], ckpt_dir)
            # The (possibly rewritten) directory may not exist yet.
            os.makedirs(ckpt_dir, exist_ok=True)
            torch.save(ckpt, os.path.join(ckpt_dir, f"{ckpt_prefix}{iteration}.pt"))
            updates_for_ckpt = dict()

        ### save current optee params (input to calc_sai at the next iteration)
        prev_params = {n: p.detach().clone() for n, p in optee.all_named_parameters()}

        ### update - continue unrolling or step w/ opter
        if iteration % unroll == 0:
            ### step w/ the optimizer
            if train_opter:
                # Drop opter grads accumulated by the per-iteration train_loss.backward() calls.
                opter_optim.zero_grad()
                if log_unroll_losses:
                    unroll_losses = torch.log(unroll_losses)
                total_loss = unroll_losses if reg_losses is None else unroll_losses + reg_losses
                total_loss.backward()
                opter_optim.step()

            ### reinitialize - start next unroll from the updated (now detached) params
            optee = w(optee_cls(**optee_kwargs))
            optee.load_state_dict(result_params)
            optee.zero_grad()
            hidden_states = [detach_var(v) for v in hidden_states2]
            cell_states = [detach_var(v) for v in cell_states2]
            unroll_losses = None
            reg_losses = None
        else:
            ### update the optimizee and optimizer's states (graph kept for the unroll)
            for name, p in optee.all_named_parameters():
                if p.requires_grad:  # leave the batchnorm stats
                    rsetattr(optee, name, result_params[name])
            hidden_states = hidden_states2
            cell_states = cell_states2

    return metrics

In [None]:
def _accumulate_run_metrics(phase_metrics, run_metrics):
    """Add one run's per-iteration metric lists into `phase_metrics` (name -> summed np.ndarray), in place."""
    for name, values in run_metrics.items():
        if name not in phase_metrics:
            phase_metrics[name] = np.array(values)
        else:
            phase_metrics[name] += np.array(values)


def _summarize_phase_metrics(phase_metrics, n_runs):
    """Average summed arrays over `n_runs` and replace each by {"sum", "last"} of the mean curve, in place.

    A metric that was never recorded (empty array, e.g. test metrics with eval_iter_freq=None)
    gets NaN as "last" instead of raising IndexError.
    """
    for name, summed in phase_metrics.items():
        mean_curve = summed / n_runs
        phase_metrics[name] = {
            "sum": np.sum(mean_curve),
            "last": mean_curve[-1] if len(mean_curve) > 0 else float("nan"),
        }


def fit_optimizer(
    data_cls,
    optee_cls,
    data_config=None,
    optee_config=None,
    opter_cls=Optimizer,
    opter_config=None,
    unroll=20,
    n_epochs=20,
    n_optim_runs_per_epoch=20,
    n_iters=100,
    n_tests=100,
    opter_lr=0.01,
    log_unroll_losses=False,
    opter_updates_reg_func=None,
    opter_updates_reg_func_config=None,
    reg_mul=1.0,
    optee_updates_lr=1.0,
    eval_iter_freq=10,
    ckpt_iter_freq=None,
    ckpt_epoch_freq=None,
    ckpt_prefix="",
    ckpt_dir="",
    load_ckpt=None,
    start_from_epoch=0,
    verbose=1,
):
    """Meta-train an L2O optimizer with Adam over `n_epochs` epochs.

    Each epoch runs `n_optim_runs_per_epoch` meta-training trajectories (`do_fit` with
    ``train_opter=True``). If ``n_tests > 0``, it then runs `n_tests` meta-testing
    trajectories (``train_opter=False``).

    Per-epoch metrics are averaged over runs and stored as ``{"sum": ..., "last": ...}``
    of the mean curve. The best optimizer is selected by the lowest summed mean
    ``train_loss``: from meta-testing if ``n_tests > 0``, otherwise from meta-training.

    Checkpointing (only when `ckpt_epoch_freq` is truthy): on epochs divisible by
    `ckpt_epoch_freq`, the last meta-training run saves iteration checkpoints every
    `ckpt_iter_freq` iterations, and an epoch checkpoint ``{epoch}.pt`` is written to
    `ckpt_dir`. `load_ckpt` is a path to such an epoch checkpoint to resume from
    (together with `start_from_epoch`).

    Returns
    -------
    best_loss : float
        Lowest summed mean train loss seen (``np.inf`` if no epoch improved on it).
    all_metrics : list of dict
        One entry per epoch: ``{"meta_training": {...}, "meta_testing": {...}}``.
    best_opter : dict or None
        Deep copy of the best optimizer's ``state_dict`` (None if never set).
    """
    ckpt_epoch_enabled = bool(ckpt_epoch_freq)
    # Epoch checkpoints are written straight into ckpt_dir, so create it whenever any
    # checkpointing is on ("" means the current directory and needs no creation).
    if ckpt_dir and (ckpt_iter_freq is not None or ckpt_epoch_enabled):
        os.makedirs(ckpt_dir, exist_ok=True)

    opter = w(opter_cls(**(opter_config if opter_config is not None else {})))
    meta_opt = optim.Adam(opter.parameters(), lr=opter_lr)

    best_opter = None
    best_loss = np.inf
    all_metrics = list()

    ### load checkpoint
    if load_ckpt is not None:
        print(f"... loading checkpoint from {load_ckpt} ...")
        ckpt = torch.load(load_ckpt)
        best_opter = ckpt["best_opter"]
        best_loss = ckpt["best_loss"]
        all_metrics = ckpt["metrics"]
        opter.load_state_dict(ckpt["opter"])
        meta_opt = optim.Adam(opter.parameters(), lr=opter_lr)
        meta_opt.load_state_dict(ckpt["meta_opter"])
        meta_opt.zero_grad()
        opter.train()

    ### meta-training epochs
    for epoch_i in range(start_from_epoch, n_epochs):
        start_time = time.time()
        if verbose > 0:
            print(f"[{epoch_i + 1}/{n_epochs}]")
        all_metrics.append({k: dict() for k in ["meta_training", "meta_testing"]})
        train_summary = all_metrics[-1]["meta_training"]
        test_summary = all_metrics[-1]["meta_testing"]
        is_ckpt_epoch = ckpt_epoch_enabled and epoch_i % ckpt_epoch_freq == 0

        ### meta-train
        for run_i in range(n_optim_runs_per_epoch):
            # iteration-level checkpoints only for the last run of a checkpointing epoch
            is_last_run = run_i == n_optim_runs_per_epoch - 1
            curr_ckpt_iter_freq = ckpt_iter_freq if (is_ckpt_epoch and is_last_run) else None
            optim_run_metrics = do_fit(
                opter=opter,
                opter_optim=meta_opt,
                data_cls=data_cls,
                data_config=data_config,
                optee_cls=optee_cls,
                optee_config=optee_config,
                unroll=unroll,
                n_iters=n_iters,
                optee_updates_lr=optee_updates_lr,
                train_opter=True,
                log_unroll_losses=log_unroll_losses,
                opter_updates_reg_func=opter_updates_reg_func,
                opter_updates_reg_func_config=opter_updates_reg_func_config,
                reg_mul=reg_mul,
                eval_iter_freq=eval_iter_freq,
                ckpt_iter_freq=curr_ckpt_iter_freq,
                ckpt_prefix=f"{ckpt_prefix}{epoch_i}e_",
                ckpt_dir=ckpt_dir,
            )
            _accumulate_run_metrics(train_summary, optim_run_metrics)
            if verbose > 1:
                log_msg = f"  [{run_i + 1}/{n_optim_runs_per_epoch}]"
                for k, v in optim_run_metrics.items():
                    if len(v) == 0:  # metric not recorded in this run
                        continue
                    if "acc" in k:
                        log_msg += f"  {k}_last={v[-1]:.3f}"
                    else:
                        log_msg += f"  {k}_sum={np.sum(v):.3f}  {k}_last={v[-1]:.3f}"
                print(log_msg)

        ### average metrics and log
        _summarize_phase_metrics(train_summary, n_optim_runs_per_epoch)
        if verbose > 0:
            print(
                f"[{epoch_i + 1}/{n_epochs}] Meta-training metrics:"
                f"\n{json.dumps(all_metrics[-1]['meta_training'], indent=4, sort_keys=False)}"
            )

        ### meta-test
        if n_tests > 0:
            for _ in range(n_tests):
                optim_run_metrics = do_fit(
                    opter=opter,
                    opter_optim=meta_opt,
                    data_cls=data_cls,
                    data_config=data_config,
                    optee_cls=optee_cls,
                    optee_config=optee_config,
                    unroll=unroll,
                    n_iters=n_iters,
                    optee_updates_lr=optee_updates_lr,
                    train_opter=False,
                    eval_iter_freq=eval_iter_freq,
                    ckpt_iter_freq=None,
                )
                _accumulate_run_metrics(test_summary, optim_run_metrics)

            ### average metrics and log
            _summarize_phase_metrics(test_summary, n_tests)
            if verbose > 0:
                print(
                    f"[{epoch_i + 1}/{n_epochs}] Meta-testing metrics:"
                    f"\n{json.dumps(all_metrics[-1]['meta_testing'], indent=4, sort_keys=False)}"
                )

            if all_metrics[-1]["meta_testing"]["train_loss"]["sum"] < best_loss:
                if verbose > 0:
                    print(
                        f"[{epoch_i + 1}/{n_epochs}] New best loss"
                        f"\n\t previous:\t {best_loss}"
                        f"\n\t current:\t {all_metrics[-1]['meta_testing']['train_loss']['sum']} (at last iter: {all_metrics[-1]['meta_testing']['train_loss']['last']})"
                    )
                best_loss = all_metrics[-1]["meta_testing"]["train_loss"]["sum"]
                best_opter = copy.deepcopy(opter.state_dict())
        else:
            ### no meta-testing, so just save the best model based on meta-training
            if all_metrics[-1]["meta_training"]["train_loss"]["sum"] < best_loss:
                if verbose > 0:
                    print(
                        f"[{epoch_i + 1}/{n_epochs}] New best loss"
                        f"\n\t previous:\t {best_loss}"
                        f"\n\t current:\t {all_metrics[-1]['meta_training']['train_loss']['sum']:.3f} (at last iter: {all_metrics[-1]['meta_training']['train_loss']['last']:.3f})"
                    )
                best_loss = all_metrics[-1]["meta_training"]["train_loss"]["sum"]
                best_opter = copy.deepcopy(opter.state_dict())

        ### save ckpt
        if is_ckpt_epoch:
            ckpt = {
                "best_opter": best_opter if best_opter else None,
                "best_loss": best_loss,
                "opter": opter.state_dict(),
                "meta_opter": meta_opt.state_dict(),
                "metrics": all_metrics,
            }
            torch.save(ckpt, os.path.join(ckpt_dir, f"{epoch_i}.pt"))

        end_time = time.time()
        if verbose > 0:
            print(f"[{epoch_i + 1}/{n_epochs}] Epoch took {end_time - start_time:.2f}s")

    return best_loss, all_metrics, best_opter

In [None]:
def meta_test(opter, optees, config, save_ckpts_for_all_test_runs=False, seed=0):
    """
    Meta-test an L2O optimizer on several optimizees and save the results to disk.

    For each optimizee, `config["eval_n_tests"]` runs of `do_fit` are made with the
    ``config["meta_testing"]`` kwargs. The global seeds are reset to `seed` per optimizee.

    Files written to ``$CKPT_PATH/<config["ckpt_base_dir"]>``:

    - ``metrics_<task>.npy``
    - ``config_<task>.json``
    - ``run_<task>.pt``

    Checkpoints (when ``ckpt_iter_freq`` is set) are named
    ``<task>_run<i>_<iteration>.pt``. The task name is included because all
    optimizees share the same meta-testing ckpt_dir and would otherwise overwrite
    each other's checkpoints.

    Parameters
    ----------
    opter : l2o.optimizer.Optimizer
        The L2O optimizer model to meta-test.
    optees : list of tuples
        List of tuples of the form (optimizee_cls, optimizee_config) to meta-test on.
    config : dict
        The config dictionary for the meta-testing run.
    save_ckpts_for_all_test_runs : bool, optional
        Whether to save checkpoints for all test runs (otherwise only for the first),
        by default False.
    seed : int, optional
        The random seed to use for the meta-testing run, by default 0.

    Returns
    -------
    results : dict
        Meta-task name -> ``{"config": run_config, "metrics": metrics}``. Here
        ``metrics`` maps a metric name to an array with one row per test run.

    Raises
    ------
    ValueError
        If ``config["eval_n_tests"]`` is smaller than 1.
    """
    if config["eval_n_tests"] < 1:
        raise ValueError(f"config['eval_n_tests'] must be >= 1, got {config['eval_n_tests']}")

    ### meta-test the L2O optimizer model on various optimizees
    results = dict()
    for optee_cls, optee_config in optees:
        meta_task_name = (
            f"{optee_cls.__name__}_{dict_to_str(optee_config)}"
            + f"_{config['meta_testing']['data_cls'].__name__}_{dict_to_str(config['meta_testing']['data_config'])}"
        )
        print(f"Meta-testing on {meta_task_name}")

        ### set config
        run_config = deepcopy(config)
        run_config["meta_testing"]["optee_cls"] = optee_cls
        run_config["meta_testing"]["optee_config"] = optee_config

        ### run the meta-testing run with l2o optimizer
        torch.manual_seed(seed)
        np.random.seed(seed)
        run_metrics = []
        for eval_test_i in range(run_config["eval_n_tests"]):
            # Per-run copy: run_config is dumped below and must keep the configured ckpt_iter_freq.
            fit_kwargs = dict(run_config["meta_testing"])
            if not save_ckpts_for_all_test_runs and eval_test_i > 0:
                fit_kwargs["ckpt_iter_freq"] = None  # only save checkpoints for the first run
            run_metrics.append(
                do_fit(
                    opter=opter,
                    **fit_kwargs,
                    ckpt_prefix=f"{meta_task_name}_run{eval_test_i}_",
                )
            )

        ### aggregate metrics by metric name and save
        metrics = {k: np.array([m[k] for m in run_metrics]) for k in run_metrics[0].keys()}
        run_base_dir = os.path.join(os.environ["CKPT_PATH"], run_config["ckpt_base_dir"])
        metrics_path = os.path.join(run_base_dir, f"metrics_{meta_task_name}.npy")
        np.save(metrics_path, metrics)
        print(f"  Metrics saved to {metrics_path}")

        ### save the config for this run
        config_path = os.path.join(run_base_dir, f"config_{meta_task_name}.json")
        with open(config_path, "w") as f:
            json.dump(run_config, f, indent=4, default=str)

        ### save also the whole run information as .pt (config and metrics)
        run_info = {
            "config": run_config,
            "metrics": metrics,
        }
        run_info_path = os.path.join(run_base_dir, f"run_{meta_task_name}.pt")
        torch.save(run_info, run_info_path)

        results[meta_task_name] = run_info

    return results

In [None]:
def meta_test_baselines(baseline_opters, optees, config, use_existing_baselines=True, save_ckpts_for_all_test_runs=False):
    """
    Train every optimizee with every baseline (hand-designed) optimizer and store the results.

    Each (optimizee, baseline) pair gets its own directory
    ``$CKPT_PATH/<config["ckpt_baselines_dir"]>/<get_baseline_ckpt_dir(...)>/``, containing:

    - ``config.json`` and ``config.pt``
    - ``metrics.npy`` and ``run.pt``
    - the ``ckpt/`` directory handed to ``fit_normal``

    If the baseline config's ``lr`` is callable, it is first called to search for a
    learning rate, and the result replaces it.

    Parameters
    ----------
    baseline_opters : list of tuples
        List of tuples of the form (optimizer_name, optimizer_cls, optimizer_config).
    optees : list of tuples
        List of tuples of the form (optimizee_cls, optimizee_config).
    config : dict
        The config dictionary for the training run.
    use_existing_baselines : bool, optional
        Skip a pair whose ``metrics.npy`` and ``ckpt/`` already exist, by default True.
    save_ckpts_for_all_test_runs : bool, optional
        Whether to save checkpoints for all test runs, by default False.

    Returns
    -------
    results : dict
        ``"<opter_name>_<optee class>_<optee config>"`` -> ``{"config", "metrics"}``.
        Only pairs trained in this call are included; skipped (reused) pairs are
        not loaded back.
    """
    results = {}
    for optee_cls, optee_config in optees:
        for opter_name, baseline_opter_cls, opter_config in baseline_opters:
            optee_tag = f"{optee_cls.__name__}_{dict_to_str(optee_config)}"
            print(f"Training {optee_tag} with {opter_name} optimizer")
            torch.manual_seed(0)
            np.random.seed(0)

            # private copy of the config with this optimizee plugged into the meta-testing setup
            run_config = deepcopy(config)
            test_cfg = run_config["meta_testing"]
            test_cfg["optee_cls"] = optee_cls
            test_cfg["optee_config"] = optee_config

            # where this pair's artifacts live
            pair_nickname = get_baseline_ckpt_dir(
                opter_cls=baseline_opter_cls,
                opter_config=opter_config,
                optee_cls=test_cfg["optee_cls"],
                optee_config=test_cfg["optee_config"],
                data_cls=test_cfg["data_cls"],
                data_config=test_cfg["data_config"],
            )
            pair_dir = os.path.join(os.environ["CKPT_PATH"], run_config["ckpt_baselines_dir"], pair_nickname)
            os.makedirs(pair_dir, exist_ok=True)
            metrics_path = os.path.join(pair_dir, "metrics.npy")
            pair_ckpt_dir = os.path.join(pair_dir, "ckpt")

            # reuse earlier results when allowed and complete
            if use_existing_baselines and os.path.exists(metrics_path) and os.path.isdir(pair_ckpt_dir):
                print(
                    f"  Existing metrics and checkpoints for {opter_name} exist, skipping..."
                    f"\n  metrics_file: {metrics_path}"
                )
                continue

            tuned_opter_config = deepcopy(opter_config)

            # a callable "lr" is a learning-rate search routine, replaced by its result
            lr_entry = tuned_opter_config.get("lr")
            if callable(lr_entry):
                print(f"  Finding best lr for {opter_name} optimizer")
                # NOTE(review): data_config is not forwarded to the lr search, although fit_normal
                # below receives it -- confirm whether the search should use the same data config.
                best_lr = lr_entry(
                    data_cls=test_cfg["data_cls"],
                    optee_cls=test_cfg["optee_cls"],
                    opter_cls=baseline_opter_cls,
                    n_tests=3,
                    n_iters=test_cfg["n_iters"] // 2,
                    optee_config=test_cfg["optee_config"],
                    opter_config=tuned_opter_config,
                    consider_metric="train_loss",
                )
                tuned_opter_config["lr"] = best_lr
                print(f"  Best lr for {opter_name} optimizer: {best_lr}")

            # record the exact config used
            test_cfg["baseline_opter_config"] = tuned_opter_config
            with open(os.path.join(pair_dir, "config.json"), "w") as config_file:
                json.dump(run_config, config_file, indent=4, default=str)
            torch.save(run_config, os.path.join(pair_dir, "config.pt"))

            pair_metrics = fit_normal(
                data_cls=test_cfg["data_cls"],
                data_config=test_cfg["data_config"],
                optee_cls=test_cfg["optee_cls"],
                optee_config=test_cfg["optee_config"],
                opter_cls=baseline_opter_cls,
                opter_config=tuned_opter_config,
                n_iters=test_cfg["n_iters"],
                n_tests=run_config["eval_n_tests"],
                ckpt_iter_freq=test_cfg["ckpt_iter_freq"],
                ckpt_dir=pair_ckpt_dir,
                save_ckpts_for_all_test_runs=save_ckpts_for_all_test_runs,
            )

            np.save(metrics_path, pair_metrics)
            print(f"  Metrics of {opter_name} saved to {metrics_path}")

            run_info = {"config": run_config, "metrics": pair_metrics}
            torch.save(run_info, os.path.join(pair_dir, "run.pt"))

            results[f"{opter_name}_{optee_tag}"] = run_info

    return results

# Initialize L2O Optimizer

### Load pre-trained L2O (alternative to meta-training a new L2O below — run only one of the two sections)

In [None]:
### load previous checkpoint (and skip meta-training of a new l2O optimizer)
# Directory (under $CKPT_PATH) of an earlier meta-training run to restore opter/config from.
pretrained_run_dir = os.path.join(
    os.environ["CKPT_PATH"], "10-05-2023_13-26-17_MNISTReluBatchNorm_Optimizer"
)
opter, config, ckpt = load_ckpt(dir_path=pretrained_run_dir)
print(json.dumps(config, indent=4, default=str))

### Meta-train a new L2O

In [None]:
### create a new config
config = { # global config
    "opter_cls": Optimizer,
    "opter_config": {
        "preproc": True,
        "additional_inp_dim": 1,
        "manual_init_output_params": False,
    },
    "eval_n_tests": 5,
    "ckpt_base_dir": datetime.now().strftime('%d-%m-%Y_%H-%M-%S'),
    "ckpt_baselines_dir": "baselines",
}

config["meta_training"] = { # training the optimizer
    "opter_cls": config["opter_cls"],
    "opter_config": config["opter_config"],
    "data_cls": MNIST,
    "data_config": {
        "batch_size": 128,
        # "only_classes": [0, 1],
    },
    "optee_cls": MNISTNet,
    "optee_config": {},
    "n_epochs": 50,
    "n_optim_runs_per_epoch": 20,
    "n_iters": 200,
    "unroll": 20,
    "n_tests": 0,
    "optee_updates_lr": 0.1,
    "opter_lr": 0.001,
    "log_unroll_losses": True,
    "opter_updates_reg_func": None,
    "opter_updates_reg_func_config": {},
    "reg_mul": 0,
    "eval_iter_freq": 10,
    "ckpt_iter_freq": 5,
    "ckpt_epoch_freq": 5,
    "ckpt_dir": None, # will be set later
    "load_ckpt": None,
    "start_from_epoch": 0,
    "verbose": 2,
}

config["meta_testing"] = { # testing the optimizer (these keys are passed straight to do_fit)
    "data_cls": config["meta_training"]["data_cls"],
    "data_config": config["meta_training"]["data_config"],
    "optee_cls": config["meta_training"]["optee_cls"],
    "optee_config": config["meta_training"]["optee_config"],
    "unroll": 1,
    "n_iters": 1000,
    # explicit instead of relying on do_fit's default, so the plots can read the real eval schedule
    "eval_iter_freq": config["meta_training"]["eval_iter_freq"],
    "optee_updates_lr": config["meta_training"]["optee_updates_lr"],
    "train_opter": False,
    "opter_optim": None,
    "ckpt_iter_freq": 5,
    "ckpt_dir": None, # will be set later
}

### additional config: suffix the run dir with optimizee/optimizer names and derive ckpt dirs
config["ckpt_base_dir"] = f"{config['ckpt_base_dir']}_{config['meta_training']['optee_cls'].__name__}_{config['opter_cls'].__name__}"
config["meta_training"]["ckpt_dir"] = os.path.join(os.environ["CKPT_PATH"], config["ckpt_base_dir"], "meta_training")
config["meta_testing"]["ckpt_dir"] = os.path.join(os.environ["CKPT_PATH"], config["ckpt_base_dir"], "meta_testing")

In [None]:
### create the run's directory tree (exist_ok=False on the base dir: never reuse an existing run dir)
run_base_dir = os.path.join(os.environ["CKPT_PATH"], config["ckpt_base_dir"])
os.makedirs(run_base_dir, exist_ok=False)
for phase_name in ("meta_training", "meta_testing"):
    os.makedirs(config[phase_name]["ckpt_dir"], exist_ok=True)

### dump config (values that are not JSON-serializable, e.g. classes, are written via str())
config_json_path = os.path.join(run_base_dir, "config.json")
with open(config_json_path, "w") as config_file:
    json.dump(config, config_file, indent=4, default=str)

print(f"Base directory: {config['ckpt_base_dir']}")

In [None]:
### meta-train a new L2O optimizer model
torch.manual_seed(0)
np.random.seed(0)

best_training_loss, metrics, opter_state_dict = fit_optimizer(
    **config["meta_training"],
)
# fit_optimizer only records a best state when an epoch's summed train loss beats the initial
# +inf (never true for NaN losses); fail clearly here instead of inside load_state_dict(None).
if opter_state_dict is None:
    raise RuntimeError(
        "Meta-training produced no best optimizer state (no epoch improved on the initial best "
        f"loss; best_training_loss={best_training_loss}). Check the meta-training losses for NaN/inf."
    )
opter = w(
    config["opter_cls"](
        **config["opter_config"] if config["opter_config"] is not None else {}
    )
)
opter.load_state_dict(opter_state_dict)
print(best_training_loss)

### save the final model
torch.save({
    "state_dict": opter_state_dict,
    "config": config,
    "loss": best_training_loss,
    "metrics": metrics,
}, os.path.join(os.environ["CKPT_PATH"], config["ckpt_base_dir"], "l2o_optimizer.pt"))

# Evaluate

In [None]:
# Optimizees (class, constructor kwargs) on which both the L2O optimizer and the baselines are evaluated.
optees_to_test = [
    (MNISTRelu, dict()),
    (MNISTConv, dict()),
    (MNISTSigmoid, dict(layer_sizes=[100, 100])),
    (MNISTReLU, dict(layer_sizes=[100, 100])),
    (MNISTNet2Layer, dict()),
    (MNISTNetBig, dict()),
    (MNISTReluBatchNorm, dict(affine=True, track_running_stats=True)),
    (MNISTLeakyRelu, dict()),
    (MNISTNet, dict()),
]
# Baselines as (name, optimizer class, config); a callable "lr" is a search routine run before training.
baselines_to_test_against = [
    ("Adam", optim.Adam, dict(lr=find_best_lr_normal)),
    ("SGD", optim.SGD, dict(lr=find_best_lr_normal, momentum=0.9)),
]
use_existing_baselines = True # loads existing if exists, otherwise trains from scratch

### Meta-test L2O optimizer and baselines

In [None]:
# Meta-test the L2O optimizer on every optimizee; metrics/configs are also written under ckpt_base_dir.
results = meta_test(opter, optees_to_test, config, save_ckpts_for_all_test_runs=False)

In [None]:
# Train the same optimizees with the baseline optimizers (reusing stored results when allowed).
results_baselines = meta_test_baselines(
    baselines_to_test_against, optees_to_test, config,
    use_existing_baselines=use_existing_baselines, save_ckpts_for_all_test_runs=False,
)

In [None]:
### load all l2o metrics from disk (one "metrics_<task>.npy" per meta-tested optimizee)
l2o_metrics = {}
l2o_run_dir = os.path.join(os.environ["CKPT_PATH"], config["ckpt_base_dir"])
METRICS_FILE_PREFIX, METRICS_FILE_SUFFIX = "metrics_", ".npy"

for metrics_file in os.listdir(l2o_run_dir):
    if not metrics_file.startswith(METRICS_FILE_PREFIX):
        continue
    print(f"Loading {metrics_file}")
    task_name = metrics_file[len(METRICS_FILE_PREFIX):-len(METRICS_FILE_SUFFIX)]
    # np.save stored a dict, so allow_pickle + .item() recovers it
    l2o_metrics[task_name] = np.load(os.path.join(l2o_run_dir, metrics_file), allow_pickle=True).item()

In [None]:
### load baseline metrics from disk
# Paths come from get_baseline_ckpt_dir, the same helper meta_test_baselines used when saving,
# so loading cannot drift from saving (the previous hand-built name could).
baseline_metrics = dict()
meta_testing_cfg = config["meta_testing"]

for optee_cls, optee_config in optees_to_test:
    # same naming as meta_test's meta_task_name, so it matches the keys of l2o_metrics
    run_nickname = f"{optee_cls.__name__}_{dict_to_str(optee_config)}_{meta_testing_cfg['data_cls'].__name__}_{dict_to_str(meta_testing_cfg['data_config'])}"
    baseline_metrics[run_nickname] = dict()

    ### load metrics for all considered baselines
    for (opter_name, baseline_opter_cls, baseline_opter_config) in baselines_to_test_against:
        baseline_dir_name = get_baseline_ckpt_dir(
            opter_cls=baseline_opter_cls,
            opter_config=baseline_opter_config,
            optee_cls=optee_cls,
            optee_config=optee_config,
            data_cls=meta_testing_cfg["data_cls"],
            data_config=meta_testing_cfg["data_config"],
        )
        metrics_path = os.path.join(os.environ["CKPT_PATH"], config["ckpt_baselines_dir"], baseline_dir_name, "metrics.npy")

        ### load
        print(f"Loading {metrics_path}")
        baseline_metrics[run_nickname][opter_name] = np.load(metrics_path, allow_pickle=True).item()

### Plot meta-testing results

In [None]:
show_max_iters = 500
log_losses = True
fig_save_dir = os.path.join(os.environ["CKPT_PATH"], config["ckpt_base_dir"])
# fig_save_dir = None
# Eval frequency actually used in meta-testing: meta_test forwards config["meta_testing"]
# to do_fit, whose default eval_iter_freq is 10 when the key is absent.
eval_iter_freq = config["meta_testing"].get("eval_iter_freq", 10)


def _recorded_iterations(metric, max_iter, freq):
    """Return the 1-based iterations (up to `max_iter`) at which do_fit records `metric`.

    Train metrics are recorded every iteration. Test metrics are recorded at
    iteration 1 and at every multiple of `freq`.
    """
    if "test" in metric:
        return np.array([it for it in range(1, max_iter + 1) if it % freq == 0 or it == 1])
    return np.arange(1, max_iter + 1)


def _curve_stats(runs, metric, max_iter, freq):
    """Return (x, mean, min, max) across runs (rows of `runs`) over the first `max_iter` iterations."""
    iters = _recorded_iterations(metric, max_iter, freq)
    runs = np.asarray(runs)[:, : len(iters)]
    x = iters[: runs.shape[1]]  # also truncate x if fewer points were recorded
    return x, runs.mean(axis=0), runs.min(axis=0), runs.max(axis=0)


# Loop variable is not named `metrics`, so it does not clobber the meta-training metrics.
for optee_nickname, l2o_run_metrics in l2o_metrics.items():
    if optee_nickname not in baseline_metrics:
        continue
    curr_baseline_metrics = baseline_metrics[optee_nickname]

    ### plot comparison
    fig = plt.figure(figsize=(26, 16))
    fig.suptitle(f"Meta-testing on {optee_nickname}", fontsize=16, fontweight="bold", y=0.92)

    for m_i, metric in enumerate(["train_loss", "test_loss", "train_acc", "test_acc"]):
        ax = fig.add_subplot(2, 2, m_i + 1)

        ### baseline optimizers
        # NOTE(review): assumes fit_normal records metrics on the same schedule as do_fit
        # (the original plot assumed the same x-axis for both) -- TODO confirm.
        for opter_name, opter_metrics in curr_baseline_metrics.items():
            x, y, y_min, y_max = _curve_stats(opter_metrics[metric], metric, show_max_iters, eval_iter_freq)
            sns.lineplot(x=x, y=y, label=opter_name, linestyle="--", ax=ax)
            ax.fill_between(x=x, y1=y_min, y2=y_max, alpha=0.1)

        ### L2O optimizer
        x, y, y_min, y_max = _curve_stats(l2o_run_metrics[metric], metric, show_max_iters, eval_iter_freq)
        sns.lineplot(x=x, y=y, label="L2O", color="orange", linewidth=2, ax=ax)
        ax.fill_between(x=x, y1=y_min, y2=y_max, alpha=0.2, color="orange")

        ### plot settings
        if log_losses and "loss" in metric:
            ax.set_yscale("log")
        ax.set_title(metric, fontsize=14, fontweight="bold")
        ax.set_xlabel("Iteration")
        ax.set_ylabel(metric)
        if "acc" in metric:
            ax.set_ylim(0.6, 1.0)
        ax.legend()

    ### save the figure (before showing it), then release it
    if fig_save_dir is not None:
        prefix = "log_losses" if log_losses else "losses"
        fig.savefig(os.path.join(fig_save_dir, f"{prefix}_{optee_nickname}_{show_max_iters}.png"))
    plt.show()
    plt.close(fig)