In [None]:
import os
import json
import time
import copy
from datetime import datetime
from functools import partial
import matplotlib.pyplot as plt
import seaborn as sns
import lovely_tensors as lt # can be removed
import numpy as np
import dill
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from fl2o.optimizee import MLPOptee, CustomParams
from fl2o.optimizee_modules import MetaParameter
from fl2o.optimizer import GD, Adam, FGD, AFOGD, CFGD, CFGD_ClosedForm, L2O_Update
from fl2o.l2o import L2O
from fl2o.data import MNIST, CustomTask, generate_least_squares_task
from fl2o.training import do_fit, find_best_lr, meta_train, get_optimal_lr
from fl2o.utils import plot_log, plotter

lt.monkey_patch() # can be removed

# Data/checkpoint locations and the compute device come from the environment.
# Only DEVICE has a fallback ("cpu"); the two paths are None when unset.
DATA_PATH, CKPT_PATH = (os.environ.get(var_name) for var_name in ("DATA_PATH", "CKPT_PATH"))
DEVICE = os.environ.get("DEVICE", "cpu")

print(f"{DATA_PATH=}\n{CKPT_PATH=}\n{DEVICE=}")

In [None]:
### load previous checkpoint (and skip meta-training of a new l2O optimizer)
# Final checkpoint of an earlier run. To load an intermediate meta-training
# checkpoint instead, replace "ckpt.pth" by: "meta_training", "20.pt".
ckpt_file = os.path.join(
    CKPT_PATH,
    "l2o",
    "29-01_00-06__L2O__CFGD_ClosedForm",
    "ckpt.pth",
)
# NOTE: dill/pickle deserialization can execute arbitrary code - only load
# checkpoints that you produced yourself or otherwise trust.
ckpt = torch.load(ckpt_file, map_location=torch.device(DEVICE), pickle_module=dill)

# Unpack the pieces the following cells rely on (`config` is re-defined by the
# configuration cell below if that cell is executed afterwards).
l2o_dict, l2o_dict_best = ckpt["l2o_dict"], ckpt["l2o_dict_best"]
log, config = ckpt["log"], ckpt["config"]
print(json.dumps(config, indent=4, default=str))

## Least Squares

$$
\begin{aligned}
    &\min_{x} f(x) = \frac{1}{2} \|W^T x - y\|_2^2 \\
    &\text{where } W \in \mathbb{R}^{d \times m},\ x \in \mathbb{R}^d,\ y \in \mathbb{R}^m
\end{aligned}
$$

In [None]:
# Fresh experiment configuration; the timestamp later becomes part of the
# checkpoint directory name.
config = {"time": datetime.now().strftime("%d-%m_%H-%M")}

### data (task): least-squares problem with W of shape (d, m)
config["data"] = {"d": 100, "m": 100, "data_cls": CustomTask}
config["data"]["data_config"] = {
    "task": generate_least_squares_task,
    "task_config": {
        # the task generator receives the same dimensions as declared above
        **{dim_name: config["data"][dim_name] for dim_name in ("d", "m")},
        "verbose": False,
        "device": DEVICE,
    },
}

### optimizee: one (1, d) parameter block, initialised via "randn"
config["optee"] = {
    "optee_cls": CustomParams,
    "optee_config": {"dim": (1, config["data"]["d"]), "init_params": "randn"},
}

### optimizer L2O-CFGD
# config["opter"] = {
#     "opter_cls": L2O,
#     "opter_config": {
#         "in_dim": 3, # len(in_features) + 1
#         "out_dim": 3,
#         "hidden_sz": 40,
#         "in_features": ("grad", "iter_num_enc"),
#         "base_opter_cls": CFGD_ClosedForm,
#         "base_opter_config": {
#             "lr": get_optimal_lr,
#             "gamma": None,
#             "c": None,
#             "version": "NA",
#             "device": DEVICE,
#         },
#         "params_to_optimize": {
#             # "gamma": {
#             #     "idx": 0,
#             #     "act_fns": ("identity", "diag"),
#             # },
#             # "gamma": {
#             #     "idx": 0,
#             #     "act_fns": ("alpha_to_gamma", "diag"),
#             #     "beta": 0.,
#             # },
#             "gamma": {
#                 "idx": (0, 1),
#                 "act_fns": ("alpha_beta_to_gamma", "diag"),
#             },
#             "c": {
#                 "idx": 2,
#                 "act_fns": ("identity",),
#             },
#         },
#     },
# }

### optimizer L2O
# Learned optimizer wrapping the L2O_Update base optimizer.
# presumably the single network output (idx 0, identity activation) is used as
# the parameter update itself - confirm against fl2o.l2o / fl2o.optimizer.
config["opter"] = {"opter_cls": L2O}
config["opter"]["opter_config"] = {
    "in_dim": 3,  # len(in_features) + 1
    "out_dim": 1,
    "hidden_sz": 40,
    "in_features": ("grad", "iter_num_enc"),
    "base_opter_cls": L2O_Update,
    "base_opter_config": {"lr": get_optimal_lr, "device": DEVICE},
    "params_to_optimize": {
        "update": {"idx": 0, "act_fns": ("identity",)},
    },
}

### meta-training config
# Adam is the meta-optimizer that trains the L2O network.
config["meta_training_config"] = dict(
    meta_opter_cls=optim.Adam,
    meta_opter_config=dict(lr=1e-3),
    n_runs=2000,
    unroll=20,
    # Optional periodic loggers, e.g.:
    #   {"every_nth_run": 20, "logger_fn": partial(plotter, to_plot="gamma")},
    #   {"every_nth_run": 20, "logger_fn": partial(plotter, to_plot="c")},
    loggers=[],
)

### other
# All remaining top-level settings are added in one atomic update so that a
# failure while building the checkpoint path leaves `config` untouched.
config.update(
    n_iters=800,
    l2o_dict=None,
    additional_metrics={
        # "gamma": lambda opter, **kwargs: \
        #     # opter.base_opter.param_groups[0]["gamma"].item() \
        #     opter.base_opter.param_groups[0]["gamma"].mean().item() \
        #     if hasattr(opter, "base_opter") else opter.param_groups[0].get("gamma", None),
        # "c": lambda opter, **kwargs: \
        #     # opter.base_opter.param_groups[0]["c"].item() \
        #     opter.base_opter.param_groups[0]["c"].mean().item() \
        #     if hasattr(opter, "base_opter") else opter.param_groups[0].get("c", None),
        # "l2_dist(x_tik*, x)": lambda task, optee, **kwargs: \
        #     torch.norm(task["x_tik_solution"](gamma=1., c=1) - optee.params.detach(), p=2).item(),
        # "l2_dist(x*, x)": lambda task, optee, **kwargs: \
        #     torch.norm(task["x_solution"] - optee.params.detach(), p=2).item(),
    },
    ckpt_config={
        "ckpt_every_nth_run": 50,
        # <CKPT_PATH>/l2o/<time>__<optimizer class>__<base optimizer class>
        "ckpt_dir": os.path.join(
            CKPT_PATH,
            "l2o",
            "__".join((
                config["time"],
                config["opter"]["opter_cls"].__name__,
                config["opter"]["opter_config"]["base_opter_cls"].__name__,
            )),
        ),
    },
    device=DEVICE,
    seed=0,
)

# Sub-directories for the two experiment phases live below the run directory.
for phase in ("meta_training", "meta_testing"):
    config["ckpt_config"][f"ckpt_dir_{phase}"] = os.path.join(
        config["ckpt_config"]["ckpt_dir"],
        phase,
    )

### make dirs
for dir_key in ("ckpt_dir", "ckpt_dir_meta_training", "ckpt_dir_meta_testing"):
    os.makedirs(config["ckpt_config"][dir_key], exist_ok=True)

### save config
# Non-JSON values (classes, functions, lambdas) are stringified via default=str.
config_json_path = os.path.join(config["ckpt_config"]["ckpt_dir"], "config.json")
with open(config_json_path, "w") as f:
    json.dump(config, f, indent=4, default=str)

print(f"Path to checkpoints: {config['ckpt_config']['ckpt_dir']}")

In [None]:
### meta train
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])

### keep meta-training
# The previous state is handed to meta_train so that training continues from it.
# NOTE(review): l2o_dict / l2o_dict_best / log only exist if the checkpoint-loading
# cell was executed; that checkpoint is named "...L2O__CFGD_ClosedForm" while the
# active config uses L2O_Update - confirm the two are compatible.
resume_state = dict(l2o_dict=l2o_dict, l2o_dict_best=l2o_dict_best, log=log)
l2o_dict, l2o_dict_best, log = meta_train(config=config, **resume_state)

### save final checkpoint
final_ckpt_path = os.path.join(config["ckpt_config"]["ckpt_dir"], "ckpt.pth")
torch.save(
    dict(l2o_dict=l2o_dict, l2o_dict_best=l2o_dict_best, log=log, config=config),
    final_ckpt_path,
    pickle_module=dill,
)

# meta-training loss curve
plt.plot(log["loss_sum"])

In [None]:
### meta-testing config
test_d, test_m = 300, 300
n_test_runs = 3
test_run_iters = 5000
test_runs_seed = 0

runs = {}

# Overrides shared by the baseline runs; each run merges them into a deep copy
# of `config` (see the "run all" cell).
update_config_base = {
    "n_iters": test_run_iters,
    "data": {
        "d": test_d,
        "m": test_m,
        "data_cls": CustomTask,
        "data_config": {
            "task": generate_least_squares_task,
            "task_config": {
                "d": test_d,
                "m": test_m,
                "verbose": False,
                "device": DEVICE,
            },
        },
    },
    "optee": {
        "optee_cls": CustomParams,
        "optee_config": {"dim": (1, test_d), "init_params": "randn"},
    },
    "additional_metrics": {
        # "l2_dist(x_tik*, x)": lambda task, optee, **kwargs: \
        #     torch.norm(task["x_tik_solution"](gamma=gamma, c=1) - optee.params.detach(), p=2).item(),
        # "l2_dist(x*, x)": lambda task, optee, **kwargs: \
        #     torch.norm(task["x_solution"] - optee.params.detach(), p=2).item(),
        # "x": lambda y_hat, y, optee: optee.params.detach().cpu().numpy(),

        # Cosine similarity between the last applied update and the last gradient.
        # Both are read from opter.state[0] when present there, otherwise from
        # opter.param_groups[0].
        "cos_sim(d, x.grad)": lambda opter, **kwargs: torch.cosine_similarity(
            (opter.state[0] if "last_update" in opter.state[0] else opter.param_groups[0])["last_update"].flatten(),
            (opter.state[0] if "last_grad" in opter.state[0] else opter.param_groups[0])["last_grad"].flatten(),
            dim=0,
        ).item(),
        # Learning rate used in the last step (same lookup order as above).
        "last_lr": lambda opter, **kwargs: (
            opter.state[0] if "last_lr" in opter.state[0] else opter.param_groups[0]
        )["last_lr"],
    },
}


### baseline: gradient descent
runs["GD"] = {
    "update_config": {
        **update_config_base,
        "opter": {
            "opter_cls": GD,
            "opter_config": {"lr": get_optimal_lr, "device": DEVICE},
        },
    },
    "plot_config": {"color": "black", "linestyle": "dashed"},
}

# runs["Adam"] = {
#     "update_config": {
#         **update_config_base,
#         "opter": {
#             "opter_cls": Adam,
#             "opter_config": {
#                 "lr": get_optimal_lr,
#                 "device": DEVICE,
#             },
#         },
#     },
#     "plot_config": {
#         "color": "gray",
#         "linestyle": "dashed",
#     },
# }

### baselines: closed-form CFGD with fixed gamma, once per `version` ("NA", then "AT")
for cfgd_version in ("NA", "AT"):
    for gamma in [0.1]:
        runs[rf"{cfgd_version}-CFGD, $\gamma$={gamma}"] = {
            "update_config": {
                **update_config_base,
                "opter": {
                    "opter_cls": CFGD_ClosedForm,
                    "opter_config": {
                        "lr": get_optimal_lr,
                        "gamma": gamma,
                        "c": 1.,
                        "version": cfgd_version,
                        # only the "AT" variant is given explicit initial points
                        "init_points": (
                            None
                            if cfgd_version == "NA"
                            else [-1. * torch.ones(1, test_d, device=DEVICE)]
                        ),
                        "device": DEVICE,
                    },
                },
            },
            "plot_config": {"linestyle": "dashed"},
        }

### meta-trained L2O optimizer
# This run does not override "opter", so the run loop uses config["opter"] together
# with the best meta-trained weights passed through "l2o_dict".
# NOTE(review): the metrics below read "c", "alpha" and "beta" from
# opter.base_opter.param_groups[0]; presumably these only exist for the CFGD-based
# base optimizers (the commented-out "L2O-CFGD" optimizer config) - verify that
# config["opter"] matches before running.
runs["L2O-CFGD"] = {
    "update_config": {
        "n_iters": test_run_iters,
        "data": update_config_base["data"],
        "optee": update_config_base["optee"],
        "additional_metrics": {
            # "gamma": lambda opter, **kwargs: \
            #     # opter.base_opter.param_groups[0]["gamma"].item() \
            #     opter.base_opter.param_groups[0]["gamma"].detach().cpu().numpy() \
            #     if hasattr(opter, "base_opter") else opter.param_groups[0].get("gamma", None),

            # Per-parameter hyper-parameters of the wrapped base optimizer; for a
            # plain (non-L2O) optimizer they are looked up on the optimizer itself.
            "c": lambda opter, **kwargs: (
                opter.base_opter.param_groups[0]["c"].detach().cpu().numpy()
                if hasattr(opter, "base_opter")
                else opter.param_groups[0].get("c", None)
            ),
            "alpha": lambda opter, **kwargs: (
                opter.base_opter.param_groups[0]["alpha"]
                if hasattr(opter, "base_opter")
                else opter.param_groups[0].get("alpha", None)
            ),
            "beta": lambda opter, **kwargs: (
                opter.base_opter.param_groups[0]["beta"]
                if hasattr(opter, "base_opter")
                else opter.param_groups[0].get("beta", None)
            ),
            # Gradient seen in the last step, as a NumPy array.
            "grad": lambda opter, **kwargs: (
                opter.base_opter.state[0]["last_grad"].detach().cpu().numpy()
                if hasattr(opter, "base_opter")
                else opter.state[0]["last_grad"].detach().cpu().numpy()
            ),
            # Cosine similarity between the last update and the last gradient.
            "cos_sim(d, x.grad)": lambda opter, **kwargs: torch.cosine_similarity(
                opter.base_opter.state[0]["last_update"].flatten(),
                opter.base_opter.state[0]["last_grad"].flatten(),
                dim=0,
            ).item(),
            # Last learning rate as a plain number. isinstance (instead of an exact
            # type comparison) also converts torch.Tensor subclasses such as
            # nn.Parameter.
            "last_lr": lambda opter, **kwargs: (
                opter.base_opter.state[0]["last_lr"].item()
                if isinstance(opter.base_opter.state[0]["last_lr"], torch.Tensor)
                else opter.base_opter.state[0]["last_lr"]
            ),
        },
        # "l2o_dict": l2o_dict,
        "l2o_dict": l2o_dict_best["best_l2o_dict"],
    },
    "plot_config": {
        "color": "orange",
        "linewidth": "3",
    },
}

# runs["CFGD"] = {
#     "update_config": {
#         **update_config_base,
#         "opter": {
#             "opter_cls": CFGD,
#             "opter_config": {
#                 "lr": get_optimal_lr,
#                 # "lr": 0.1,
#                 "alpha": 0.01,
#                 # "beta": 30.23,
#                 "beta": 0.8,
#                 "c": 1.,
#                 "s": 3,
#                 "version": "AT",
#                 "init_points": [-1. * torch.ones(1, test_d, device=DEVICE)],
#                 "device": DEVICE,
#             },
#         },
#     },
#     "plot_config": {
#         "color": "orange",
#         "linewidth": "3",
#     },
# }

In [None]:
### run all
# Executes every configured run that has no "log" yet. A run's log is stored only
# after all of its test runs finished, so a failed or interrupted run is NOT marked
# as done and is picked up again when the cell is re-executed.
for run_name, run in runs.items():
    if "log" in run:
        continue  # already run

    # start from the global config and apply the per-run overrides (top-level keys)
    run_config = copy.deepcopy(config)
    if run.get("update_config") is not None:
        run_config.update(run["update_config"])
    print(f"{run_name}:")

    ### check if L2O has been meta-trained (loop-invariant, so checked once up front)
    assert not run_config["opter"]["opter_cls"] == L2O or run_config["l2o_dict"] is not None, \
        f"Run '{run_name}' uses L2O but no meta-trained l2o_dict was provided."

    # every run starts from the same random state
    torch.manual_seed(test_runs_seed)
    np.random.seed(test_runs_seed)

    # `find_best_lr` used as the lr value is a sentinel requesting a grid search
    opter_config = run_config["opter"]["opter_config"]
    if "lr" in opter_config and opter_config["lr"] is find_best_lr:
        print("  > Finding best lr...")
        opter_config["lr"] = find_best_lr(
            opter_cls=run_config["opter"]["opter_cls"],
            opter_config=opter_config,
            optee_cls=run_config["optee"]["optee_cls"],
            optee_config=run_config["optee"]["optee_config"],
            data_cls=run_config["data"]["data_cls"],
            data_config=run_config["data"]["data_config"],
            # loss_fn=run_config["loss_fn"],
            n_iters=120,
            n_tests=1,
            consider_metric="loss",
            lrs_to_try=[0.01, 0.05, 0.1, 0.3, 0.5, 0.7, 1.0, 1.5, 2.0],
        )
        print(f"  > Best lr: {opter_config['lr']}")

    print("  > Running...")
    run_log = dict()  # metric name -> list with one entry per test run
    for i in range(n_test_runs):
        print(f"    > Run {i+1}/{n_test_runs}...")

        # only the first element returned by do_fit (the metric log) is used here
        curr_log = do_fit(
            opter_cls=run_config["opter"]["opter_cls"],
            opter_config=run_config["opter"]["opter_config"],
            optee_cls=run_config["optee"]["optee_cls"],
            optee_config=run_config["optee"]["optee_config"],
            data_cls=run_config["data"]["data_cls"],
            data_config=run_config["data"]["data_config"],
            n_iters=run_config["n_iters"],
            l2o_dict=run_config["l2o_dict"],
            in_meta_training=False,
            additional_metrics=run_config["additional_metrics"],
        )[0]

        for metric_name, metric_values in curr_log.items():
            run_log.setdefault(metric_name, []).append(metric_values)

    # publish the results only once the run completed successfully
    run["log"] = run_log
    run["config"] = run_config

In [None]:
# Plot loss, update/gradient cosine similarity and last learning rate for all runs.
# presumably `log_metrics` selects the metrics drawn on a log scale and
# `min_max_y_config` clips the y-range per metric - confirm in fl2o.utils.plot_log.
plot_log(
    runs,
    conv_win=1,
    only_metrics=["loss", "cos_sim(d, x.grad)", "last_lr"],
    log_metrics=["loss", "l2_dist(x_tik*, x)", "l2_dist(x*, x)"],
    min_max_y_config={"last_lr": (0, 100)},
    # save_to=os.path.join(
    #     config["ckpt_config"]["ckpt_dir"],
    #     f"loss_cos_sim_l2o_best_dict_{test_d}d_{test_m}m_{n_test_runs}runs_{test_run_iters}iters.png"
    # ),
)

#### Analyze strategy

In [None]:
# Stack the per-iteration quantities logged for the meta-trained optimizer over all
# test runs. The run is registered under the key "L2O-CFGD" in the meta-testing
# config, so that is the key that has to be used here.
l2o_run_log = runs["L2O-CFGD"]["log"]

# gammas = np.stack(l2o_run_log["gamma"])
# gammas = gammas.T.diagonal().transpose(1, 0, 2)  # (n_test_runs, n_iters, D)

alphas = np.stack(l2o_run_log["alpha"])
# alphas = alphas.T.diagonal().transpose(1, 0, 2)  # (n_test_runs, n_iters, D)

betas = np.stack(l2o_run_log["beta"])
# betas = betas.T.diagonal().transpose(1, 0, 2)  # (n_test_runs, n_iters, D)

# NOTE(review): the shapes below are the original author's annotations; later cells
# call .squeeze(), so an extra singleton axis may be present - confirm.
cs = np.stack(l2o_run_log["c"])  # (n_test_runs, n_iters, D)

grads = np.stack(l2o_run_log["grad"])  # (n_test_runs, n_iters, D)

In [None]:
### how do alphas, beta and grads correlate with each other? -> bubble plot
test_run_idx = 0
iters_to_show = [0, 1, 2, 5, 20, 100, 500]

fig = plt.figure(figsize=(14, 18), facecolor="white")
# fig.suptitle("L2O + CFGD_ClosedForm")
ax_idx = 1

# one grid row per selected iteration, one column per learned quantity
panels = (
    (alphas, r"$\alpha$"),
    (betas, r"$\beta$"),
    (cs, r"$c$"),
)

for iter_idx in iters_to_show:
    for col_offset, (values, symbol) in enumerate(panels):
        ax = fig.add_subplot(len(iters_to_show), 3, ax_idx + col_offset)
        # gradient component (x) against the quantity chosen for that component (y)
        sns.scatterplot(
            x=grads[test_run_idx, iter_idx].squeeze(),
            y=values[test_run_idx, iter_idx].squeeze(),
            ax=ax,
        )
        ax.set_xlabel(r"$\partial_x f(x)$")
        ax.set_ylabel(symbol)
        ax.set_title(f"{symbol} (iter {iter_idx})")

    ax_idx += 3

fig.tight_layout(h_pad=1.5)
save_to = os.path.join(
    config["ckpt_config"]["ckpt_dir"],
    f"strategy_grad_alpha_beta_c_{test_d}d_{test_m}m_{n_test_runs}runs_{test_run_iters}iters.png"
)
fig.savefig(save_to)

plt.show()

In [None]:
# Evolution of one learned quantity over the iterations of a single test run.
# Left: all parameters (grey) with their mean (orange); right: the first three.
test_run_idx = 0
to_plot_name = "c"
plot_plot_label = r"$c$"
log_plot = False

fig = plt.figure(figsize=(15, 6), facecolor="w")
fig.suptitle(rf"L2O-CFGD: {plot_plot_label}", fontsize=16)

ax = fig.add_subplot(121)

# Lazy lookup: an array is only touched when it is actually selected.
series_by_name = {
    "alpha": lambda: alphas,
    "beta": lambda: betas,
    "gamma": lambda: gammas,
    "c": lambda: cs,
    "grad": lambda: grads,
}
if to_plot_name not in series_by_name:
    raise ValueError(f"Unknown to_plot: {to_plot_name}")
to_plot = series_by_name[to_plot_name]()

ax.plot(to_plot[test_run_idx].squeeze(), alpha=0.1, color="grey")
ax.plot(to_plot[test_run_idx].squeeze().mean(-1), color="orange", linewidth=3)
ax.set_xlabel("Iteration", fontsize=13)
ax.set_ylabel(plot_plot_label, fontsize=13)

### share y-axis with left plot
ax = fig.add_subplot(122, sharey=ax)
ax.plot(to_plot[test_run_idx].squeeze()[:, :3])
ax.set_xlabel("Iteration", fontsize=13)
ax.set_ylabel(plot_plot_label, fontsize=13)
ax.legend(["Parameter #1", "Parameter #2", "Parameter #3"])

if log_plot:
    ax.set_yscale("log")

fig.tight_layout(h_pad=2.5)
# NOTE(review): `file_name` is not referenced again in this cell; the figure is
# written to `save_to`.
file_name = f"l2o_cfgd_{to_plot_name}_{test_d}d_{test_m}m_{n_test_runs}runs_{test_run_iters}iters.png"
save_to = os.path.join(
    config["ckpt_config"]["ckpt_dir"],
    f"strategy_l2o_best_dict_{to_plot_name}_{test_d}d_{test_m}m_{n_test_runs}runs_{test_run_iters}iters.png"
)
fig.savefig(save_to)

plt.show()

In [None]:
# Evolution of gamma over the iterations of one test run.
# `gammas` is only available when the "gamma" metric is logged for the L2O run and
# the corresponding (currently commented-out) stacking code is enabled; otherwise
# the plot is skipped instead of failing with a NameError.
test_run_idx = 0

if "gammas" not in globals():
    print("Skipping the gamma plot: `gammas` is not defined (the 'gamma' metric is not logged).")
else:
    fig = plt.figure(figsize=(15, 6))
    fig.suptitle(r"L2O-CFGD: $\gamma$", fontsize=16)

    # all parameters (grey) and their mean (orange)
    ax = fig.add_subplot(121)
    ax.plot(gammas[test_run_idx], alpha=0.1, color="grey")
    ax.plot(gammas[test_run_idx].mean(-1), color="orange", linewidth=3)
    ax.set_xlabel("Iteration", fontsize=13)
    ax.set_ylabel(r"$\gamma$", fontsize=13)

    ### share y-axis with left plot
    ax = fig.add_subplot(122, sharey=ax)
    ax.plot(gammas[test_run_idx, :, :3])
    ax.set_xlabel("Iteration", fontsize=13)
    ax.set_ylabel(r"$\gamma$", fontsize=13)
    ax.legend(["Parameter #1", "Parameter #2", "Parameter #3"])

    plt.show()

In [None]:
# Evolution of c over the iterations of one test run.
# Singleton axes are squeezed first, exactly as in the generic strategy plot above,
# so that the array handed to matplotlib is 2-D (iterations x parameters).
# NOTE(review): assumes more than one iteration and more than one parameter, so
# that squeeze() only removes a possible extra singleton axis - confirm.
test_run_idx = 0
cs_run = cs[test_run_idx].squeeze()

fig = plt.figure(figsize=(15, 6))
fig.suptitle(r"L2O-CFGD: $c$", fontsize=16)

# all parameters (grey) and their mean (orange)
ax = fig.add_subplot(121)
ax.plot(cs_run, alpha=0.1, color="grey")
ax.plot(cs_run.mean(-1), color="orange", linewidth=3)
ax.set_xlabel("Iteration", fontsize=13)
ax.set_ylabel(r"$c$", fontsize=13)

### share y-axis with left plot
ax = fig.add_subplot(122, sharey=ax)
ax.plot(cs_run[:, :3])
ax.set_xlabel("Iteration", fontsize=13)
ax.set_ylabel(r"$c$", fontsize=13)
ax.legend(["Parameter #1", "Parameter #2", "Parameter #3"])

plt.show()

## Quadratic Objective Function (Figure 1)

In [None]:
def task_gen(device="cpu", verbose=False):
    """
    Build the 2-D quadratic task used for the optimizer-trajectory figure.

    The loss is ``0.5 * x A x^T + b^T x^T`` for a row vector ``x`` of shape (1, 2),
    with ``A = diag(10, 1)`` and ``b = 0``, i.e. ``5 * x_1^2 + 0.5 * x_2^2``. This is
    one half of ``10 * x^2 + y^2``, so both share the minimizer (0, 0) and the shape
    of their level sets.

    Args:
        device: torch device on which the tensors are created.
        verbose: if True, print the problem definition. The parameter also makes
            the function callable as ``partial(task_gen, verbose=False, ...)``.

    Returns:
        dict with keys "A" (2x2 tensor), "b" (2x1 tensor), "loss_fn" (callable
        mapping a (1, 2) tensor to a (1, 1) loss tensor) and "x_solution"
        (tensor of shape (2,)).
    """
    ### Quadratic problem 1/2 x A x^T + b^T x^T
    A = torch.tensor([[10., 0.], [0., 1.]], device=device)
    b = torch.tensor([[0., 0.]], device=device).T
    x_solution = torch.tensor([0., 0.], device=device)

    def loss_fn(y_hat):
        # y_hat: (1, 2) row vector -> (1, 1) loss
        return 0.5 * y_hat @ A @ y_hat.T + b.T @ y_hat.T

    if verbose:
        print(f"Quadratic task: A={A.tolist()}, b={b.flatten().tolist()}, x*={x_solution.tolist()}")

    return {
        "A": A,
        "b": b,
        "loss_fn": loss_fn,
        "x_solution": x_solution,
    }

In [None]:
# Base configuration for the 2-D quadratic experiment; individual runs override
# "opter" and "additional_metrics" (see the runs config cell).
config = {}

### optimizee: a single 2-D point with a fixed starting position
config["optee"] = {
    "optee_cls": CustomParams,
    "optee_config": {
        "dim": (1, 2),
        "init_params": torch.tensor([[1., -10.]], device=DEVICE),
        "param_func": None,
    },
}

### default optimizer
config["opter"] = {
    "opter_cls": CFGD_ClosedForm,
    "opter_config": {
        "lr": get_optimal_lr,
        "gamma": -1.,
        "c": None,
        "version": "AT",
        "init_points": [torch.tensor([-1., -1.], device=DEVICE)],
        "device": DEVICE,
    },
}

### data (task): the quadratic task generator with its arguments pre-bound
config["data"] = {
    "data_cls": CustomTask,
    "data_config": {
        "task": partial(task_gen, verbose=False, device=DEVICE),
        "device": DEVICE,
    },
}

### other
config["n_iters"] = 50
config["l2o_dict"] = None
config["additional_metrics"] = {
    # Euclidean distance between the current parameters and the known minimizer
    "l2_dist(x*, x)": lambda task, optee, **kwargs: torch.norm(
        task["x_solution"] - optee.params.detach(), p=2
    ).item(),
}
config["device"] = DEVICE
config["seed"] = 0

In [None]:
### runs config
# Every run logs the distance to the minimizer and the raw iterate "x"; the latter
# is what the trajectory plot further below draws.
runs = {}

### baseline: gradient descent
runs["GD"] = {
    "update_config": {
        "opter": {
            "opter_cls": GD,
            "opter_config": {"lr": get_optimal_lr, "device": DEVICE},
        },
        "additional_metrics": {
            "l2_dist(x*, x)": lambda task, optee, **kwargs: torch.norm(
                task["x_solution"] - optee.params.detach(), p=2
            ).item(),
            "x": lambda optee, **kwargs: optee.params.detach().cpu().numpy(),
        },
    },
    "plot_config": {"color": "black", "linestyle": "dashed"},
}

### closed-form CFGD for a sweep over gamma
for gamma in [-1., 0.01, 0.05, 0.25, 0.5, 1., 10.]:
    cfgd_opter = {
        "opter_cls": CFGD_ClosedForm,
        "opter_config": {
            "lr": get_optimal_lr,
            "gamma": gamma,
            "c": None,
            "version": "AT",
            "init_points": [torch.tensor([-1., -1.], device=DEVICE)],
            "device": DEVICE,
        },
    }
    runs[rf"CFGD_ClosedForm, $\gamma$={gamma}"] = {
        "update_config": {
            "opter": cfgd_opter,
            "additional_metrics": {
                "l2_dist(x*, x)": lambda task, optee, **kwargs: torch.norm(
                    task["x_solution"] - optee.params.detach(), p=2
                ).item(),
                "x": lambda optee, **kwargs: optee.params.detach().cpu().numpy(),
            },
        },
        # "plot_config": {
        #     "linestyle": "dashed",
        # },
    }

In [None]:
### run all
# Single fit per run (the task is deterministic); runs that already carry a "log"
# are skipped.
for run_name, run in runs.items():
    if "log" in run:
        continue  # already run

    # global config + per-run overrides (top-level keys are replaced)
    run_config = copy.deepcopy(config)
    overrides = run["update_config"] if "update_config" in run else None
    if overrides is not None:
        run_config.update(overrides)
    print(f"{run_name}:")

    torch.manual_seed(config["seed"])
    np.random.seed(config["seed"])

    # `find_best_lr` used as the lr value requests a grid search over fixed candidates
    opter_config = run_config["opter"]["opter_config"]
    if "lr" in opter_config and opter_config["lr"] == find_best_lr:
        print("  > Finding best lr...")
        opter_config["lr"] = find_best_lr(
            opter_cls=run_config["opter"]["opter_cls"],
            opter_config=opter_config,
            optee_cls=run_config["optee"]["optee_cls"],
            optee_config=run_config["optee"]["optee_config"],
            data_cls=run_config["data"]["data_cls"],
            data_config=run_config["data"]["data_config"],
            # loss_fn=run_config["loss_fn"],
            n_iters=120,
            n_tests=1,
            consider_metric="loss",
            lrs_to_try=[0.01, 0.05, 0.1, 0.3, 0.5, 0.7, 1.0, 1.5, 2.0],
        )
        print(f"  > Best lr: {opter_config['lr']}")

    print("  > Running...")
    fit_kwargs = dict(
        opter_cls=run_config["opter"]["opter_cls"],
        opter_config=run_config["opter"]["opter_config"],
        optee_cls=run_config["optee"]["optee_cls"],
        optee_config=run_config["optee"]["optee_config"],
        data_cls=run_config["data"]["data_cls"],
        data_config=run_config["data"]["data_config"],
        n_iters=run_config["n_iters"],
        l2o_dict=run_config["l2o_dict"],
        in_meta_training=False,
        additional_metrics=run_config["additional_metrics"],
    )
    # do_fit returns three values; only the first (the metric log) is kept
    run["log"], _, _ = do_fit(**fit_kwargs)
    run["config"] = run_config

In [None]:
# Loss, distance to the minimizer and time for all quadratic runs.
plot_log(
    runs,
    conv_win=1,
    only_metrics=["loss", "l2_dist(x*, x)", "time"],
    log_metrics=["loss", "l2_dist(x_tik*, x)", "l2_dist(x*, x)"],
)

In [None]:
def objective_function(x, y):
    """Evaluate the quadratic bowl ``10 * x^2 + y^2`` (works elementwise on arrays)."""
    steep_term = 10 * x**2
    flat_term = y**2
    return steep_term + flat_term

def plot_optimizer_steps(runs, num_steps=10, only_opters=None, starting_points=None):
    """
    Draw the first iterates of each run on top of the contour lines of
    ``objective_function``.

    The second coordinate is drawn on the horizontal axis (labelled 'y') and the
    first coordinate on the vertical axis (labelled 'x').

    Args:
        runs: dict mapping a run name to a dict whose ``["log"]["x"]`` is a sequence
            of iterates; each iterate must provide ``.flatten()``.
        num_steps: number of leading iterates to draw per run.
        only_opters: optional collection of run names; other runs are skipped.
        starting_points: optional list of points prepended to every trajectory.
    """
    plt.figure(figsize=(15, 8))

    # Plot the contour of the objective function
    grid_x, grid_y = np.meshgrid(
        np.linspace(-1.5, 1.5, 100),
        np.linspace(-10.5, 2, 100),
    )
    contour_set = plt.contour(
        grid_y, grid_x, objective_function(grid_x, grid_y), levels=20, cmap='viridis'
    )
    plt.colorbar(contour_set, label='Objective Function Value')

    for opter_name, run in runs.items():
        skip = only_opters is not None and opter_name not in only_opters
        if skip:
            continue

        trajectory = [step.flatten() for step in run["log"]["x"][:num_steps]]
        if starting_points is not None:
            trajectory = starting_points + trajectory
        trajectory = np.stack(trajectory)

        plt.plot(
            trajectory[:, 1],
            trajectory[:, 0],
            label=opter_name,
            marker='o',
            markersize=8,
            linewidth=3,
        )

    plt.title('Optimizer Steps')
    plt.xlabel('y')
    plt.ylabel('x')
    plt.legend()
    plt.show()

# Example usage:
# Replace 'runs' with your actual dictionary
# For demonstration, I'm using a simplified structure
# runs = {
#     "GD": {"log": {"x": [torch.randn(1, 2) for _ in range(20)]}},
#     "Adam": {"log": {"x": [torch.randn(1, 2) for _ in range(20)]}},
#     # Add more optimizers as needed
# }

# First five steps of GD and of three CFGD settings, all starting from (1, -10).
# The names must match the keys created in the runs config cell.
plot_optimizer_steps(
    runs,
    num_steps=5,
    only_opters=[
        "GD",
        r"CFGD_ClosedForm, $\gamma$=-1.0",
        r"CFGD_ClosedForm, $\gamma$=0.5",
        r"CFGD_ClosedForm, $\gamma$=10.0",
    ],
    starting_points=[[1., -10.]],
)

## MNIST

In [None]:
# Fresh configuration for the MNIST experiment; the timestamp later becomes part
# of the checkpoint directory name.
config = {"time": datetime.now().strftime("%d-%m_%H-%M")}

### data (task)
config["data"] = {"data_cls": MNIST}
config["data"]["data_config"] = {"device": DEVICE, "preload": True}

### optimizee: MLP with layer_sizes=[20] and ReLU activation
config["optee"] = {
    "optee_cls": MLPOptee,
    "optee_config": {"layer_sizes": [20], "act_fn": nn.ReLU()},
}

### optimizer
# L2O wrapping CFGD. alpha, beta and c are left as None in the base optimizer config
# and listed under params_to_optimize, each tied to one network output index.
config["opter"] = {"opter_cls": L2O}
config["opter"]["opter_config"] = {
    "in_dim": 3,  # len(in_features) + 1
    "out_dim": 3,
    "hidden_sz": 40,
    "in_features": ("grad", "iter_num_enc"),
    "base_opter_cls": CFGD,
    "base_opter_config": {
        "lr": 0.05,
        "alpha": None,
        "beta": None,
        "c": None,
        "s": 1,
        "version": "NA",
        "init_points": None,
        "device": DEVICE,
    },
    "params_to_optimize": {
        "alpha": {"idx": 0, "act_fns": ("sigmoid",)},
        "beta": {"idx": 1, "act_fns": ("identity",)},
        "c": {"idx": 2, "act_fns": ("identity",)},
    },
}
### meta-training config
# Adam is the meta-optimizer that trains the L2O network.
config["meta_training_config"] = dict(
    meta_opter_cls=optim.Adam,
    meta_opter_config=dict(lr=3e-4),
    n_runs=500,
    unroll=30,
    # Optional periodic loggers, e.g.:
    #   {"every_nth_run": 20, "logger_fn": partial(plotter, to_plot="gamma")},
    #   {"every_nth_run": 20, "logger_fn": partial(plotter, to_plot="c")},
    loggers=[],
)

### other
# All remaining top-level settings are added in one atomic update so that a
# failure while building the checkpoint path leaves `config` untouched.
config.update(
    n_iters=200,
    l2o_dict=None,
    additional_metrics={
        # "gamma": lambda opter, **kwargs: \
        #     # opter.base_opter.param_groups[0]["gamma"].item() \
        #     opter.base_opter.param_groups[0]["gamma"].mean().item() \
        #     if hasattr(opter, "base_opter") else opter.param_groups[0].get("gamma", None),
        # "c": lambda opter, **kwargs: \
        #     # opter.base_opter.param_groups[0]["c"].item() \
        #     opter.base_opter.param_groups[0]["c"].mean().item() \
        #     if hasattr(opter, "base_opter") else opter.param_groups[0].get("c", None),
        # "l2_dist(x_tik*, x)": lambda task, optee, **kwargs: \
        #     torch.norm(task["x_tik_solution"](gamma=1., c=1) - optee.params.detach(), p=2).item(),
        # "l2_dist(x*, x)": lambda task, optee, **kwargs: \
        #     torch.norm(task["x_solution"] - optee.params.detach(), p=2).item(),
    },
    ckpt_config={
        "ckpt_every_nth_run": 20,
        # <CKPT_PATH>/l2o/<time>__<optimizer class>__<base optimizer class>
        "ckpt_dir": os.path.join(
            CKPT_PATH,
            "l2o",
            "__".join((
                config["time"],
                config["opter"]["opter_cls"].__name__,
                config["opter"]["opter_config"]["base_opter_cls"].__name__,
            )),
        ),
    },
    device=DEVICE,
    seed=0,
)

# Sub-directories for the two experiment phases live below the run directory.
for phase in ("meta_training", "meta_testing"):
    config["ckpt_config"][f"ckpt_dir_{phase}"] = os.path.join(
        config["ckpt_config"]["ckpt_dir"],
        phase,
    )

### make dirs
for dir_key in ("ckpt_dir", "ckpt_dir_meta_training", "ckpt_dir_meta_testing"):
    os.makedirs(config["ckpt_config"][dir_key], exist_ok=True)

### save config
# Non-JSON values (classes, functions, modules) are stringified via default=str.
config_json_path = os.path.join(config["ckpt_config"]["ckpt_dir"], "config.json")
with open(config_json_path, "w") as f:
    json.dump(config, f, indent=4, default=str)

print(f"Path to checkpoints: {config['ckpt_config']['ckpt_dir']}")

In [None]:
### meta train
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])

# Meta-training is started with the config only. To keep meta-training from a
# previous state, additionally pass:
#   l2o_dict=l2o_dict, l2o_dict_best=l2o_dict_best, log=log
l2o_dict, l2o_dict_best, log = meta_train(config=config)

### save checkpoint
final_ckpt_path = os.path.join(config["ckpt_config"]["ckpt_dir"], "ckpt.pth")
torch.save(
    dict(l2o_dict=l2o_dict, l2o_dict_best=l2o_dict_best, log=log, config=config),
    final_ckpt_path,
    pickle_module=dill,
)

# meta-training loss curve
plt.plot(log["loss_sum"])

In [None]:
### meta-testing config
n_test_runs = 1
test_run_iters = 500
test_runs_seed = config["seed"]

runs = {}


def _first_state_or_group_entry(opter, key):
    """Return `key` from `opter.state[0]` if present, else from `opter.param_groups[0]`."""
    first_state = opter.state[0]
    if key in first_state:
        return first_state[key]
    return opter.param_groups[0][key]


def _cos_sim_update_vs_grad(opter, **kwargs):
    """Cosine similarity between the optimizer's flattened last update and last gradient."""
    last_update = _first_state_or_group_entry(opter, "last_update").flatten()
    last_grad = _first_state_or_group_entry(opter, "last_grad").flatten()
    return torch.cosine_similarity(last_update, last_grad, dim=0).item()


# Settings shared by the baseline runs; each run spreads this dict and adds
# its own "opter" entry.
update_config_base = {
    "n_iters": test_run_iters,
    "optee": {
        "optee_cls": MLPOptee,
        "optee_config": {
            "layer_sizes": [60],
            "act_fn": nn.ReLU(),
        },
    },
    "additional_metrics": {
        # "l2_dist(x_tik*, x)": lambda task, optee, **kwargs: \
        #     torch.norm(task["x_tik_solution"](gamma=gamma, c=1) - optee.params.detach(), p=2).item(),
        # "l2_dist(x*, x)": lambda task, optee, **kwargs: \
        #     torch.norm(task["x_solution"] - optee.params.detach(), p=2).item(),
        # "x": lambda y_hat, y, optee: optee.params.detach().cpu().numpy(),
        "cos_sim(d, x.grad)": _cos_sim_update_vs_grad,
        # "last_lr": lambda opter, **kwargs: \
        #     opter.state[0]["last_lr"] if "last_lr" in opter.state[0] else opter.param_groups[0]["last_lr"],
    },
}

# Throwaway optimizee instance built from the shared optee config; the
# AT-CFGD run below iterates over its named parameters to build init points.
_optee_spec = update_config_base["optee"]
_tmp_optee = _optee_spec["optee_cls"](**_optee_spec["optee_config"])

### GD baseline with a fixed learning rate (black dashed line in the plots)
_gd_opter = {
    "opter_cls": GD,
    "opter_config": {"lr": 0.3, "device": DEVICE},
}
runs["GD"] = {
    # shared base settings first, then this run's optimizer
    "update_config": dict(update_config_base, opter=_gd_opter),
    "plot_config": {"color": "black", "linestyle": "dashed"},
}

# for gamma in [0.1]:
#     runs[r"AT-CFGD_ClosedForm, $\gamma$=" + str(gamma)] = {
#         "update_config": {
#             "opter": {
#                 "opter_cls": CFGD_ClosedForm,
#                 "opter_config": {
#                     "lr": 0.3,
#                     "gamma": gamma,
#                     "c": 1.,
#                     "version": "AT",
#                     "init_points": [-1. * torch.ones(1, test_d, device=DEVICE)],
#                     "device": DEVICE,
#                 },
#             },
#             "additional_metrics": {
#                 # "l2_dist(x_tik*, x)": lambda task, optee, **kwargs: \
#                 #     torch.norm(task["x_tik_solution"](gamma=gamma, c=1) - optee.params.detach(), p=2).item(),
#                 # "l2_dist(x*, x)": lambda task, optee, **kwargs: \
#                 #     torch.norm(task["x_solution"] - optee.params.detach(), p=2).item(),
#                 # "x": lambda y_hat, y, optee: optee.params.detach().cpu().numpy(),
#                 "cos_sim(d, x.grad)": lambda opter, **kwargs: \
#                     torch.cosine_similarity(
#                         opter.state[0]["last_update"].flatten(),
#                         opter.state[0]["last_grad"].flatten(),
#                         dim=0
#                     ).item(),
#                 # "last_lr": lambda opter, **kwargs: \
#                 #     opter.state[0]["last_lr"].item() if type(opter.state[0]["last_lr"]) == torch.Tensor else opter.state[0]["last_lr"],
#             },
#             "n_iters": test_run_iters,
#         },
#         "plot_config": {
#             "linestyle": "dashed",
#         },
#     }

def _cfgd_run(alpha, version, init_points):
    """Build the `runs` entry for a CFGD baseline.

    Only `alpha`, `version` and `init_points` vary between the NA and AT
    variants; every other optimizer setting is shared.
    """
    cfgd_opter = {
        "opter_cls": CFGD,
        "opter_config": {
            "lr": 0.1,
            "alpha": alpha,
            "beta": 0.,
            "c": 0,
            "s": 1,
            "version": version,
            "init_points": init_points,
            "device": DEVICE,
        },
    }
    return {
        "update_config": dict(update_config_base, opter=cfgd_opter),
        "plot_config": {"linestyle": "dashed"},
    }


### NA-CFGD: no explicit init points
for alpha in [0.02]:
    runs[r"NA-CFGD, $\alpha$=" + str(alpha)] = _cfgd_run(
        alpha,
        version="NA",
        init_points=None,
    )

### AT-CFGD: one init point per optimizee parameter, filled with -1
for alpha in [0.02]:
    runs[r"AT-CFGD, $\alpha$=" + str(alpha)] = _cfgd_run(
        alpha,
        version="AT",
        init_points=[
            [-1. * torch.ones_like(p, requires_grad=False, device=DEVICE)]
            for _, p in _tmp_optee.all_named_parameters()
        ],
    )

def _base_opter_param_metric(name):
    """Create a metric that logs hyperparameter `name` of the first param group.

    If the optimizer wraps a `base_opter`, the value is read from it and
    converted to a NumPy array; otherwise the optimizer's own param group is
    queried (yielding None when the key is absent).
    """
    def metric(opter, **kwargs):
        if hasattr(opter, "base_opter"):
            return opter.base_opter.param_groups[0][name].detach().cpu().numpy()
        return opter.param_groups[0].get(name, None)
    return metric


def _l2o_last_grad(opter, **kwargs):
    """Last gradient (as NumPy) from the wrapped base optimizer, or from `opter` itself."""
    source = opter.base_opter if hasattr(opter, "base_opter") else opter
    return source.state[0]["last_grad"].detach().cpu().numpy()


def _l2o_cos_sim(opter, **kwargs):
    """Cosine similarity between the base optimizer's last update and last gradient."""
    return torch.cosine_similarity(
        opter.base_opter.state[0]["last_update"].flatten(),
        opter.base_opter.state[0]["last_grad"].flatten(),
        dim=0,
    ).item()


def _l2o_last_lr(opter, **kwargs):
    """Last learning rate of the base optimizer, unwrapped to a Python number if it is a Tensor."""
    last_lr = opter.base_opter.state[0]["last_lr"]
    if type(last_lr) == torch.Tensor:
        return last_lr.item()
    return last_lr


### learned optimizer: no "opter" override, so the optimizer comes from `config`
runs["L2O-CFGD"] = {
    "update_config": {
        "n_iters": update_config_base["n_iters"],
        "optee": update_config_base["optee"],
        "additional_metrics": {
            # "gamma" can be logged the same way: _base_opter_param_metric("gamma")
            "c": _base_opter_param_metric("c"),
            "alpha": _base_opter_param_metric("alpha"),
            "beta": _base_opter_param_metric("beta"),
            "grad": _l2o_last_grad,
            "cos_sim(d, x.grad)": _l2o_cos_sim,
            "last_lr": _l2o_last_lr,
        },
        "l2o_dict": l2o_dict,
        # "l2o_dict": l2o_dict_best["best_l2o_dict"],
    },
    "plot_config": {
        "color": "orange",
        "linewidth": "3",
    },
}


# runs[r"NA-CFGD"] = {
#     "update_config": {
#         "opter": {
#             "opter_cls": CFGD,
#             "opter_config": {
#                 "lr": 0.1,
#                 "alpha": 0.02,
#                 "beta": 0.,
#                 "c": 0,
#                 "s": 1,
#                 "device": DEVICE,
#             },
#         },
#         "additional_metrics": {
#             # "l2_dist(x_tik*, x)": lambda task, optee, **kwargs: \
#             #     torch.norm(task["x_tik_solution"](gamma=gamma, c=1) - optee.params.detach(), p=2).item(),
#             # "l2_dist(x*, x)": lambda task, optee, **kwargs: \
#             #     torch.norm(task["x_solution"] - optee.params.detach(), p=2).item(),
#             # "x": lambda y_hat, y, optee: optee.params.detach().cpu().numpy(),
#             "cos_sim(d, x.grad)": lambda opter, **kwargs: \
#                 torch.cosine_similarity(
#                     opter.state[0]["last_update"].flatten(),
#                     opter.state[0]["last_grad"].flatten(),
#                     dim=0
#                 ).item(),
#             # "last_lr": lambda opter, **kwargs: \
#             #     opter.state[0]["last_lr"].item() if type(opter.state[0]["last_lr"]) == torch.Tensor else opter.state[0]["last_lr"],
#         },
#         "n_iters": test_run_iters,
#     },
#     "plot_config": {
#         "linestyle": "dashed",
#     },
# }

In [None]:
### run all
for run_name, run in runs.items():
    if "log" in run:
        continue  # already run

    # Start from a deep copy of the global config, then overlay the
    # run-specific settings (shallow update: nested dicts are shared with `runs`).
    run_config = copy.deepcopy(config)
    run_overrides = run.get("update_config")
    if run_overrides is not None:
        run_config.update(run_overrides)
    print(f"{run_name}:")

    # Same seed for every run.
    torch.manual_seed(test_runs_seed)
    np.random.seed(test_runs_seed)

    # A run may set lr to the `find_best_lr` function itself as a sentinel;
    # in that case the learning rate is searched for before the actual run.
    opter_config = run_config["opter"]["opter_config"]
    if "lr" in opter_config and opter_config["lr"] == find_best_lr:
        print("  > Finding best lr...")
        opter_config["lr"] = find_best_lr(
            opter_cls=run_config["opter"]["opter_cls"],
            opter_config=opter_config,
            optee_cls=run_config["optee"]["optee_cls"],
            optee_config=run_config["optee"]["optee_config"],
            data_cls=run_config["data"]["data_cls"],
            data_config=run_config["data"]["data_config"],
            # loss_fn=run_config["loss_fn"],
            n_iters=120,
            n_tests=1,
            consider_metric="loss",
            lrs_to_try=[0.01, 0.05, 0.1, 0.3, 0.5, 0.7, 1.0, 1.5, 2.0],
        )
        print(f"  > Best lr: {run_config['opter']['opter_config']['lr']}")

    print("  > Running...")
    run_log = run["log"] = {}
    for i in range(n_test_runs):
        print(f"    > Run {i+1}/{n_test_runs}...")

        ### check if L2O has been meta-trained
        assert not (run_config["opter"]["opter_cls"] == L2O) or run_config["l2o_dict"] is not None

        fit_kwargs = dict(
            opter_cls=run_config["opter"]["opter_cls"],
            opter_config=run_config["opter"]["opter_config"],
            optee_cls=run_config["optee"]["optee_cls"],
            optee_config=run_config["optee"]["optee_config"],
            data_cls=run_config["data"]["data_cls"],
            data_config=run_config["data"]["data_config"],
            n_iters=run_config["n_iters"],
            l2o_dict=run_config["l2o_dict"],
            in_meta_training=False,
            additional_metrics=run_config["additional_metrics"],
        )
        # only the first element of do_fit's result (the log) is kept
        curr_log = do_fit(**fit_kwargs)[0]

        # collect one entry per test run under each metric name
        for metric_name in curr_log.keys():
            run_log.setdefault(metric_name, []).append(curr_log[metric_name])

    run["config"] = run_config

In [None]:
# Forwarded to plot_log as `conv_win` (presumably a smoothing window — confirm in fl2o.utils.plot_log).
conv_win = 8

_plot_log_kwargs = dict(
    only_metrics=["loss", "cos_sim(d, x.grad)"],
    # only_metrics=["loss"],
    # only_metrics=["cos_sim(d, x.grad)"],
    log_metrics=["loss", "l2_dist(x_tik*, x)", "l2_dist(x*, x)"],
    conv_win=conv_win,
    min_max_y_config={
        "last_lr": (0, 100),
    },
    # save_to=os.path.join(
    #     config["ckpt_config"]["ckpt_dir"],
    #     f"loss_cos_sim_l2o_best_dict_{conv_win}conv_{n_test_runs}runs_{test_run_iters}iters.png"
    # ),
)
plot_log(runs, **_plot_log_kwargs)

#### Analyze strategy

In [None]:
# Per-run logs of the hyperparameters recorded during the "L2O-CFGD" test runs.
_l2o_cfgd_log = runs["L2O-CFGD"]["log"]


def _stack_logged(metric_name):
    """Stack the per-run logs of `metric_name` into one array (first axis: test run)."""
    return np.stack(_l2o_cfgd_log[metric_name])


# gammas = np.stack(runs["L2O + CFGD_ClosedForm"]["log"]["gamma"])
# gammas = gammas.T.diagonal().transpose(1, 0, 2)  # (n_test_runs, n_iters, D)

alphas = _stack_logged("alpha")
# alphas = alphas.T.diagonal().transpose(1, 0, 2)  # (n_test_runs, n_iters, D)

betas = _stack_logged("beta")
# betas = betas.T.diagonal().transpose(1, 0, 2)  # (n_test_runs, n_iters, D)

cs = _stack_logged("c")  # (n_test_runs, n_iters, D)

# grads = np.stack(runs["L2O + CFGD_ClosedForm"]["log"]["grad"])  # (n_test_runs, n_iters, D)

In [None]:
### what to plot
test_run_idx = 0
to_plot_name = "beta"
plot_plot_label = r"$\beta$"
log_plot = False

# Resolve the requested series before opening a figure, so an unknown name
# does not leave an empty figure behind.
# NOTE: `gammas` / `grads` only exist if the corresponding (currently
# commented-out) lines in the previous cell are enabled.
if to_plot_name == "alpha":
    to_plot = alphas
elif to_plot_name == "beta":
    to_plot = betas
elif to_plot_name == "gamma":
    to_plot = gammas
elif to_plot_name == "c":
    to_plot = cs
elif to_plot_name == "grad":
    to_plot = grads
else:
    raise ValueError(f"Unknown to_plot: {to_plot_name}")

# One row per logged iteration, one column per scalar entry of the
# hyperparameter. The row count comes from the log itself rather than a
# hard-coded 1000, which broke (or silently mis-shaped the data) whenever
# `test_run_iters` differed from 1000.
run_values = np.asarray(to_plot[test_run_idx])
per_iter_values = run_values.reshape(run_values.shape[0], -1)

fig = plt.figure(figsize=(15, 6), facecolor="w")
fig.suptitle(rf"L2O-CFGD: {plot_plot_label}", fontsize=16)

### left: every entry (grey) and their mean (orange)
ax = fig.add_subplot(121)
ax.plot(per_iter_values, alpha=0.1, color="grey")
ax.plot(per_iter_values.mean(-1), color="orange", linewidth=3)
ax.set_xlabel("Iteration", fontsize=13)
ax.set_ylabel(plot_plot_label, fontsize=13)

### right: first three entries only; share y-axis with left plot
ax = fig.add_subplot(122, sharey=ax)
ax.plot(per_iter_values[:, :3])
ax.set_xlabel("Iteration", fontsize=13)
ax.set_ylabel(plot_plot_label, fontsize=13)
ax.legend(["Parameter #1", "Parameter #2", "Parameter #3"])

if log_plot:
    # the y-axis is shared, so this applies to both panels
    ax.set_yscale("log")

fig.tight_layout(h_pad=2.5)
save_to = os.path.join(
    config["ckpt_config"]["ckpt_dir"],
    f"strategy_{to_plot_name}_{n_test_runs}runs_{test_run_iters}iters.png"
)
fig.savefig(save_to)

plt.show()