In [None]:
import os
import json
import time
import copy
from datetime import datetime
from functools import partial
import matplotlib.pyplot as plt
import seaborn as sns
import lovely_tensors as lt # can be removed
import numpy as np
import dill
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from fl2o.optimizee import MLPOptee, CustomParams
from fl2o.optimizee_modules import MetaParameter
from fl2o.optimizer import GD, Adam, FGD, AFOGD, CFGD, CFGD_ClosedForm, L2O_Update
from fl2o.l2o import L2O
from fl2o.data import MNIST, CustomTask, generate_least_squares_task, H1, H2, H3
from fl2o.training import do_fit, find_best_lr, meta_train, get_optimal_lr, n_step_lookahead_lr_search_hfunc_tanh_twolayer_optee
from fl2o.utils import dict_to_str, plot_log, plotter, plot_metrics, apply_publication_plt_settings, plot_strategy

# lovely_tensors patch; per the original note it can be removed (together with its import)
lt.monkey_patch()

# Locations and the compute device are read from the environment.
# Only DEVICE has a fallback; the two paths are None when the variable is unset.
DATA_PATH = os.environ.get("DATA_PATH")
CKPT_PATH = os.environ.get("CKPT_PATH")
DEVICE = os.environ.get("DEVICE", "cpu")

# echo the resolved settings so the notebook output records them
print(f"{DATA_PATH=}\n{CKPT_PATH=}\n{DEVICE=}")

In [None]:
### load previous checkpoint (and skip meta-training of a new l2O optimizer)
# Fail early with an actionable message: os.path.join(None, ...) would otherwise
# raise an opaque TypeError when the environment variable is missing.
if CKPT_PATH is None:
    raise RuntimeError(
        "CKPT_PATH environment variable is not set; cannot locate the L2O checkpoint."
    )

# NOTE(security): the checkpoint is unpickled with dill, which can execute arbitrary
# code stored in the file -- only load checkpoints that come from a trusted source.
ckpt = torch.load(
    os.path.join(
        CKPT_PATH,
        "l2o",
        "31-05_10-21__L2O__CFGD",
        "ckpt_1000.pt"
        # "meta_training",
        # "20.pt",
    ),
    map_location=torch.device(DEVICE),
    pickle_module=dill,
)

# latest and best meta-trained optimizer states stored in the checkpoint
l2o_dict = ckpt["l2o_dict"]
l2o_dict_best = ckpt["l2o_dict_best"]

# point the restored optimizers at the device used in this session
l2o_dict_best["best_l2o_dict"]["opter"].device = DEVICE
l2o_dict["opter"].device = DEVICE

log = ckpt["log"]

# NOTE: `config` is only printed here; it is reassigned by the experiment
# configuration cell further down in this notebook.
config = ckpt["config"]
config["device"] = DEVICE
print(json.dumps(config, indent=4, default=str))

### l2o
# ckpt_2 = torch.load(
#     os.path.join(
#         CKPT_PATH,
#         "l2o",
#         "02-02_11-34__L2O__L2O_Update",
#         "ckpt.pt"
#         # "meta_training",
#         # "20.pt",
#     ),
#     map_location=torch.device(DEVICE),
#     pickle_module=dill,
# )
# l2o_dict_2 = ckpt_2["l2o_dict"]
# l2o_dict_best_2 = ckpt_2["l2o_dict_best"]
# log_2 = ckpt_2["log"]
# config_2 = ckpt_2["config"]
# print(json.dumps(config_2, indent=4, default=str))

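## Quadratic Objective Function (Figure 1)
$$
\min_{x, y} f(x, y) = 10 \cdot x^2 + y^2
$$

Note: the task defined in the next cell evaluates $\tfrac{1}{2} z^\top A z + b^\top z$ with $A = \mathrm{diag}(10, 1)$, $b = 0$ and $z = (x, y)$, i.e. $f(x, y) / 2$; the minimizer $(0, 0)$ is the same.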

In [None]:
"""
Quadratic objective function
    min_x f(x, y) = 10 * x^2 + y^2
"""
def task_gen(device="cpu"):
    """Build the fixed 2-D quadratic task used in this notebook.

    The loss is 1/2 * y_hat A y_hat^T + b^T y_hat^T with A = diag(10, 1) and b = 0,
    which for a row vector y_hat = [[x, y]] evaluates to 5 * x^2 + 0.5 * y^2.

    Args:
        device: torch device on which A, b and x_solution are created.

    Returns:
        dict with the keys "A" (2x2 matrix), "b" (2x1 column vector),
        "loss_fn" (callable taking y_hat) and "x_solution" (tensor [0, 0]).
    """
    quad_matrix = torch.tensor([[10., 0.], [0., 1.]], device=device)
    linear_coeff = torch.tensor([[0., 0.]], device=device).T  # transposed to a column
    minimizer = torch.tensor([0., 0.], device=device)

    def loss_fn(y_hat):
        # evaluated left to right, exactly like 0.5 * y_hat @ A @ y_hat.T + b.T @ y_hat.T
        quadratic_part = 0.5 * y_hat @ quad_matrix @ y_hat.T
        linear_part = linear_coeff.T @ y_hat.T
        return quadratic_part + linear_part

    return dict(
        A=quad_matrix,
        b=linear_coeff,
        loss_fn=loss_fn,
        x_solution=minimizer,
    )

In [None]:
### base experiment configuration (the run loop copies it and lets each run
### replace whole top-level entries such as "opter" or "additional_metrics")
config = dict(
    # optimizee: a single (1, 2) parameter tensor initialised at (1, -10)
    optee=dict(
        optee_cls=CustomParams,
        optee_config=dict(
            dim=(1, 2),
            init_params=torch.tensor([[1., -10.]], device=DEVICE),
            param_func=None,
        ),
    ),
    # default optimizer, used by runs that do not provide their own "opter"
    opter=dict(
        opter_cls=CFGD_ClosedForm,
        opter_config=dict(
            lr=get_optimal_lr,
            gamma=-1.,
            c=None,
            version="AT",
            init_points=[torch.tensor([-1., -1.], device=DEVICE)],
            device=DEVICE,
        ),
    ),
    # data: the quadratic task produced by task_gen
    data=dict(
        data_cls=CustomTask,
        data_config=dict(
            task=task_gen,
            task_config=dict(
                # verbose=False,
                device=DEVICE,
            ),
        ),
    ),
    n_iters=50,
    l2o_dict=None,
    additional_metrics={
        # Euclidean distance between the task's solution and the current parameters
        "l2_dist(x*, x)": lambda task, optee, **kwargs: torch.norm(
            task["x_solution"] - optee.params.detach(), p=2
        ).item(),
    },
    device=DEVICE,
    seed=0,
)

In [None]:
### runs config
# Metric callbacks shared by the runs below. Each one accepts the keyword
# arguments it needs and ignores the rest via **kwargs, matching the signature
# of the lambdas they replace.
def _l2_dist_to_solution(task, optee, **kwargs):
    """Euclidean distance between the task's solution and the current parameters."""
    return torch.norm(task["x_solution"] - optee.params.detach(), p=2).item()


def _current_x(optee, **kwargs):
    """Current optimizee parameters as a NumPy array."""
    return optee.params.detach().cpu().numpy()


def _cos_sim_update_grad(opter, **kwargs):
    """Cosine similarity between the base optimizer's last update and last gradient."""
    base_state = opter.base_opter.state[0]
    return torch.cosine_similarity(
        base_state["last_update"].flatten(),
        base_state["last_grad"].flatten(),
        dim=0
    ).item()


def _last_lr(opter, **kwargs):
    """Last learning rate used by the base optimizer, as a plain number when it is a tensor."""
    last_lr = opter.base_opter.state[0]["last_lr"]
    # isinstance (instead of an exact type comparison) also unwraps Tensor subclasses
    return last_lr.item() if isinstance(last_lr, torch.Tensor) else last_lr


runs = dict()

# baseline: plain gradient descent
runs["GD"] = {
    "update_config": {
        "opter": {
            "opter_cls": GD,
            "opter_config": {
                "lr": get_optimal_lr,
                "device": DEVICE,
            },
        },
        "additional_metrics": {
            "l2_dist(x*, x)": _l2_dist_to_solution,
            "x": _current_x,
        },
    },
    "plot_config": {
        "color": "black",
        "linestyle": "dashed",
    },
}

# CFGD with a fixed gamma; one run per value in the list
# for gamma in [-1., 0.01, 0.05, 0.25, 0.5, 1., 10.]:
for gamma in [0.05]:
    runs[r"CFGD_ClosedForm, $\gamma$=" + str(gamma)] = {
        "update_config": {
            "opter": {
                "opter_cls": CFGD_ClosedForm,
                "opter_config": {
                    "lr": get_optimal_lr,
                    "gamma": gamma,
                    "c": None,
                    "version": "AT",
                    "init_points": [torch.tensor([-1., -1.], device=DEVICE)],
                    "device": DEVICE,
                },
            },
            "additional_metrics": {
                "l2_dist(x*, x)": _l2_dist_to_solution,
                "x": _current_x,
            },
        },
        # "plot_config": {
        #     "linestyle": "dashed",
        # },
    }

# meta-trained L2O optimizer wrapping CFGD, restored from the loaded checkpoint
runs["L2O-CFGD"] = {
    "update_config": {
        "opter": {
            "opter_cls": L2O,
            "opter_config": {
                "in_dim": 3, # len(in_features) + 1
                "out_dim": 3,
                "hidden_sz": 40,
                "in_features": ("grad", "iter_num_enc"),
                "base_opter_cls": CFGD_ClosedForm,
                "base_opter_config": {
                    "lr": get_optimal_lr,
                    "gamma": None,
                    "c": None,
                    "version": "NA",
                    # "init_points": [torch.randn(1, config["data"]["d"], requires_grad=False, device=DEVICE)],
                    "device": DEVICE,
                },
                "params_to_optimize": {
                    # "gamma": {
                    #     "idx": 0,
                    #     "act_fns": ("identity", "diag"),
                    # },
                    # "gamma": {
                    #     "idx": 0,
                    #     "act_fns": ("alpha_to_gamma", "diag"),
                    #     "beta": 0.,
                    # },
                    "gamma": {
                        "idx": (0, 1),
                        "act_fns": ("alpha_beta_to_gamma", "diag"),
                    },
                    "c": {
                        "idx": 2,
                        "act_fns": ("identity",),
                    },
                },
            },
        },
        "additional_metrics": {
            # "gamma": lambda opter, **kwargs: \
            #     # opter.base_opter.param_groups[0]["gamma"].item() \
            #     opter.base_opter.param_groups[0]["gamma"].detach().cpu().numpy() \
            #     if hasattr(opter, "base_opter") else opter.param_groups[0].get("gamma", None),
            # "c": lambda opter, **kwargs: \
            #     # opter.base_opter.param_groups[0]["c"].item() \
            #     opter.base_opter.param_groups[0]["c"].detach().cpu().numpy() \
            #     if hasattr(opter, "base_opter") else opter.param_groups[0].get("c", None),
            # "alpha": lambda opter, **kwargs: \
            #     # opter.base_opter.param_groups[0]["gamma"].item() \
            #     opter.base_opter.param_groups[0]["alpha"] \
            #     if hasattr(opter, "base_opter") else opter.param_groups[0].get("alpha", None),
            # "beta": lambda opter, **kwargs: \
            #     # opter.base_opter.param_groups[0]["c"].item() \
            #     opter.base_opter.param_groups[0]["beta"] \
            #     if hasattr(opter, "base_opter") else opter.param_groups[0].get("beta", None),
            # "grad": lambda opter, **kwargs: \
            #     opter.base_opter.state[0]["last_grad"].detach().cpu().numpy() \
            #     if hasattr(opter, "base_opter") else opter.state[0]["last_grad"].detach().cpu().numpy(),
            "x": _current_x,
            "cos_sim(d, x.grad)": _cos_sim_update_grad,
            "last_lr": _last_lr,
        },
        # "l2o_dict": l2o_dict,
        "l2o_dict": l2o_dict_best["best_l2o_dict"],
    },
    "plot_config": {
        # "color": "orange",
        # "color": c_palette[3],
        # "linewidth": 1.5,
    },
}

In [None]:
### run all
for run_name, run in runs.items():
    if "log" in run:
        continue # already run

    # Start from a private copy of the base config. dict.update() is shallow, so
    # every top-level key present in update_config replaces the base entry wholesale.
    run_config = copy.deepcopy(config)
    update_config = run.get("update_config")
    if update_config is not None:
        run_config.update(update_config)
    print(f"{run_name}:")

    # Seed from the run's effective config so that a per-run "seed" override is
    # honored (identical to the base seed when the run does not override it).
    torch.manual_seed(run_config["seed"])
    np.random.seed(run_config["seed"])

    # The find_best_lr function itself is used as a sentinel meaning
    # "search for the learning rate before running"; compare by identity.
    opter_config = run_config["opter"]["opter_config"]
    if opter_config.get("lr") is find_best_lr:
        print("  > Finding best lr...")
        opter_config["lr"] = find_best_lr(
            opter_cls=run_config["opter"]["opter_cls"],
            opter_config=opter_config,
            optee_cls=run_config["optee"]["optee_cls"],
            optee_config=run_config["optee"]["optee_config"],
            data_cls=run_config["data"]["data_cls"],
            data_config=run_config["data"]["data_config"],
            # loss_fn=run_config["loss_fn"],
            n_iters=120,
            n_tests=1,
            consider_metric="loss",
            lrs_to_try=[0.01, 0.05, 0.1, 0.3, 0.5, 0.7, 1.0, 1.5, 2.0],
        )
        print(f"  > Best lr: {opter_config['lr']}")

    print("  > Running...")
    # do_fit returns three values; only the first (the log) is kept
    run["log"], _, _ = do_fit(
        opter_cls=run_config["opter"]["opter_cls"],
        opter_config=opter_config,
        optee_cls=run_config["optee"]["optee_cls"],
        optee_config=run_config["optee"]["optee_config"],
        data_cls=run_config["data"]["data_cls"],
        data_config=run_config["data"]["data_config"],
        n_iters=run_config["n_iters"],
        l2o_dict=run_config["l2o_dict"],
        in_meta_training=False,
        additional_metrics=run_config["additional_metrics"],
    )
    # keep the effective config next to the log for later inspection
    run["config"] = run_config

In [None]:
# Plot loss, distance to the solution and time for all runs.
# NOTE(review): "l2_dist(x_tik*, x)" is not produced by any run defined in this
# notebook -- presumably a leftover from another experiment; confirm before removing.
plot_log(
    runs,
    conv_win=1,
    only_metrics=["loss", "l2_dist(x*, x)", "time"],
    log_metrics=["loss", "l2_dist(x_tik*, x)", "l2_dist(x*, x)"],
)