In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import gc
from IPython.display import clear_output
import time
import wandb
from hypnettorch.data import FashionMNISTData, MNISTData
from hypnettorch.data.special.split_mnist import get_split_mnist_handlers
from hypnettorch.data.special.split_cifar import get_split_cifar_handlers
from hypnettorch.mnets import LeNet, ZenkeNet, ResNet
from hypnettorch.hnets import HMLP, StructuredHMLP, ChunkedHMLP

from utils.data import get_mnist_data_loaders, get_emnist_data_loaders, randomize_targets, select_from_classes, get_data_handlers
from utils.visualization import show_imgs, get_model_dot
from utils.others import measure_alloc_mem, count_parameters
from utils.timing import func_timer
from utils.metrics import get_accuracy, calc_accuracy, print_arch_summary
from utils.hypnettorch_utils import correct_param_shapes, calc_delta_theta, get_reg_loss_for_cond, get_reg_loss, \
    infer, print_stats, print_metrics, clip_grads, take_training_step, init_hnet_unconditionals, remove_hnet_uncondtionals, \
    validate_cells_training_inputs, train_cells
from utils.models import get_target_nets, get_hnets, create_tree
from single_hypernet import finish_arch_config, get_arch_config, get_config
# from main import get_config
from configs.hypercl_zenke_splitcifar100 import get_config


from IPython.display import clear_output

torch.set_printoptions(precision=3, linewidth=180)
%env "WANDB_NOTEBOOK_NAME" "main.ipynb"
wandb.login()

In [None]:
# Experiment configuration: one nested dict shared by all helper calls below.
config = dict(
    num_cells=2,
    epochs=10,
    # max_minibatches_per_epoch=1200,
    max_minibatches_per_epoch=None,
    # longer alternating schedule:
    # phases=["hnet->solver", "hnet->hnet->solver", "hnet->solver", "hnet->hnet->solver", "hnet->solver", "hnet->hnet->solver"],
    phases=["hnet->solver", "hnet->hnet->solver"],
    n_training_iters_solver=14,
    n_training_iters_hnet=6,
    data=dict(
        # name="mnist|fmnist",
        # name="splitmnist",
        in_shape=[28, 28, 1],
        batch_size=32,
        data_dir="data_tmp",
        num_tasks=5,
        num_classes_per_task=2,
        validation_size=0,
    ),
    solver=dict(
        # `use` selects which of the solver sub-configs below is applied
        use="lenet",
        lenet=dict(arch="mnist_large", no_weights=True),
        zenkenet=dict(arch="cifar", no_weights=True, dropout_rate=0.15),
        resnet=dict(n=5, k=1, use_bias=True, no_weights=True),
    ),
    hnet=dict(
        model=dict(
            # layers=[100, 100],
            layers=[25, 25],
            dropout_rate=-1,  # hnet doesn't get images -> need to be added to resnet
            chunk_emb_size=80,
            chunk_size=60_000,  # 8000
            num_cond_embs=None,  # filled in right after the dict is built
            cond_in_size=48,
            cond_chunk_embs=True,
            root_no_uncond_weights=False,
            root_no_cond_weights=False,
            children_no_uncond_weights=True,
            children_no_cond_weights=False,
        ),
        lr=1e-3,
        reg_lr=1e-3,
        # previously tried: reg_alpha=5e-3, reg_beta=8e-2
        reg_alpha=1e-2,  # L2 regularization of solvers' parameters
        reg_beta=1e-3,  # regularization against forgetting other contexts (tasks)
        detach_d_theta=True,
        reg_clip_grads_max_norm=None,
        reg_clip_grads_max_value=1.,
    ),
    device="cuda" if torch.cuda.is_available() else "cpu",
    wandb_logging=False,
)
# Two conditional embeddings per task: the first num_tasks ids are used for generating
# target nets, the last num_tasks ids for generating a child hnet.
config["hnet"]["model"]["num_cond_embs"] = 2 * config["data"]["num_tasks"]

print(f"... Running on {config['device']} ...")

In [None]:
# Build the dataset handlers from the config. Later cells index, slice and iterate this
# object and call train_iterator()/input_to_torch_tensor() on its elements
# (presumably one handler per task — confirm against utils.data.get_data_handlers).
data_handlers = get_data_handlers(config=config)
# torch.manual_seed(0)
# np.random.seed(0)

# if config["data"]["name"] == "mnist|fmnist":
#     mnist = MNISTData(config["data"]["data_dir"], use_one_hot=True, validation_size=config["data"]["validation_size"])
#     fmnist = FashionMNISTData(config["data"]["data_dir"], use_one_hot=True, validation_size=config["data"]["validation_size"])
#     data_handlers = [mnist, fmnist]
# elif config["data"]["name"] == "splitmnist":
#     data_handlers = get_split_mnist_handlers(config["data"]["data_dir"], use_one_hot=True, num_tasks=config["data"]["num_tasks"], num_classes_per_task=config["data"]["num_classes_per_task"], validation_size=config["data"]["validation_size"])
# elif config["data"]["name"] == "splitcifar":
#     data_handlers = get_split_cifar_handlers(config["data"]["data_dir"], use_one_hot=True, num_tasks=config["data"]["num_tasks"], num_classes_per_task=config["data"]["num_classes_per_task"], validation_size=config["data"]["validation_size"])

# assert config["data"]["num_tasks"] == len(data_handlers)

In [None]:
# Target networks (solvers): index 0 is used as the root solver, index 1 as the child solver.
target_nets = get_target_nets(config=config)
solver_root, solver_child = target_nets[0], target_nets[1]
# torch.manual_seed(0)
# np.random.seed(0)
# # target networks (solvers)
# if config["solver"]["use"] == "lenet":
#     solver_child = LeNet(in_shape=data_handlers[0].in_shape, num_classes=config["data"]["num_classes_per_task"], **config["solver"]["lenet"]).to(config["device"])
#     solver_root = LeNet(in_shape=data_handlers[0].in_shape, num_classes=config["data"]["num_classes_per_task"], **config["solver"]["lenet"]).to(config["device"])
# elif config["solver"]["use"] == "zenkenet":
#     solver_child = ZenkeNet(in_shape=data_handlers[0].in_shape, num_classes=config["data"]["num_classes_per_task"], **config["solver"]["zenkenet"]).to(config["device"])
#     solver_root = ZenkeNet(in_shape=data_handlers[0].in_shape, num_classes=config["data"]["num_classes_per_task"], **config["solver"]["zenkenet"]).to(config["device"])
# elif config["solver"]["use"] == "resnet":
#     solver_child = ResNet(in_shape=data_handlers[0].in_shape, num_classes=config["data"]["num_classes_per_task"], **config["solver"]["resnet"]).to(config["device"])
#     solver_root = ResNet(in_shape=data_handlers[0].in_shape, num_classes=config["data"]["num_classes_per_task"], **config["solver"]["resnet"]).to(config["device"])
# else:
#     raise ValueError(f"Unknown solver: {config['solver']['use']}")

In [None]:
# Hypernetworks: index 0 is the root hnet, index 1 the child hnet. The child solver's
# parameter shapes are passed in the second slot of target_nets_shapes; the first slot
# is the placeholder -1.
hnets = get_hnets(
    config=config,
    target_nets_shapes=[-1, solver_child.param_shapes],
)
hnet_root, hnet_child = hnets[0], hnets[1]

# One Adam optimizer per hypernetwork, both using the same learning rate.
hnet_root_optim, hnet_child_optim = (
    torch.optim.Adam(net.parameters(), lr=config["hnet"]["lr"])
    for net in (hnet_root, hnet_child)
)

arch = [
    ("solver_child", solver_child),
    ("solver_root", solver_root),
    ("hnet_child", hnet_child),
    ("hnet_root", hnet_root),
]
print_arch_summary(arch)

# torch.manual_seed(0)
# np.random.seed(0)

# hnet_child = ChunkedHMLP(
#     solver_child.param_shapes,
#     layers=config["hnet"]["model"]["layers"],
#     chunk_size=config["hnet"]["chunk_size"],
#     chunk_emb_size=config["hnet"]["chunk_emb_size"],
#     cond_chunk_embs=config["hnet"]["cond_chunk_embs"],
#     cond_in_size=config["hnet"]["cond_in_size"],
#     num_cond_embs=config["data"]["num_tasks"] * 2, # num_tasks * 2 for child hypernetwork
#     no_uncond_weights=True,
#     no_cond_weights=False,
# ).to(config["device"])
# hnet_child_optim = torch.optim.Adam(hnet_child.parameters(), lr=config["hnet"]["lr"])

# hnet_root = ChunkedHMLP(
#     hnet_child.unconditional_param_shapes,
#     layers=config["hnet"]["model"]["layers"],
#     dropout_rate=config["hnet"]["model"]["dropout_rate"], # only for the root hypernetwork
#     chunk_size=config["hnet"]["chunk_size"],
#     chunk_emb_size=config["hnet"]["chunk_emb_size"],
#     cond_chunk_embs=config["hnet"]["cond_chunk_embs"],
#     cond_in_size=config["hnet"]["cond_in_size"],
#     num_cond_embs=config["data"]["num_tasks"] * 2, # num_tasks * 2 for child hypernetwork
#     no_uncond_weights=False,
#     no_cond_weights=False,
# ).to(config["device"])
# # hnet_root.apply_chunked_hyperfan_init(mnet=hnet_child)
# hnet_root_optim = torch.optim.Adam(hnet_root.parameters(), lr=config["hnet"]["lr"])

# print("\nSummary of parameters:")
# max_possible_num_of_maintained_params = 0
# num_of_maintained_params = 0
# for name, m in [("solver_child", solver_child), ("solver_root", solver_root), ("hnet_child", hnet_child), ("hnet_root", hnet_root)]:
#     print(f"- {name}:\t{sum(p.numel() for p in m.parameters())}\t({sum([np.prod(p) for p in m.param_shapes])} possible)")
#     num_of_maintained_params += sum(p.numel() for p in m.parameters())
#     max_possible_num_of_maintained_params += sum([np.prod(p) for p in m.param_shapes])
# print(f"---\nTotal available parameters:\t{max_possible_num_of_maintained_params}")
# print(f"Parameters maintained:\t\t{num_of_maintained_params}")
# print(f"-> Coefficient of compression:\t{(num_of_maintained_params / max_possible_num_of_maintained_params):.5f}")

In [None]:
# Optional Weights & Biases run; `wandb_run` stays None when logging is disabled so that
# later cells can pass it through unconditionally.
if config["wandb_logging"]:
    wandb_run = wandb.init(
        project="Hypernets", entity="johnny1188", config=config,
        # "name" is not always present in config["data"] (it is commented out in the
        # config cell), so fall back to no group instead of raising KeyError.
        group=config["data"].get("name"),
        tags=[], notes=""
    )
    wandb.watch((hnet_root, hnet_child, solver_root, solver_child), log="all", log_freq=100)
else:
    wandb_run = None

#### Training in continual learning setting

In [None]:
# Conditioning ids per task and phase:
#   "hnet->solver":       the root hnet uses id == task index, the child hnet is unused (None)
#   "hnet->hnet->solver": the root hnet uses id == task index + number of handlers,
#                         the child hnet uses id == task index
hnets_cond_ids = [
    {
        "hnet->solver": {"hnet_root": task_idx, "hnet_child": None},
        "hnet->hnet->solver": {
            "hnet_root": task_idx + len(data_handlers),
            "hnet_child": task_idx,
        },
    }
    for task_idx in range(len(data_handlers))
]

# Evaluation table keyed by task index:
# (root id for "hnet->solver", root id for "hnet->hnet->solver", child id, data handler)
datasets_for_eval = {
    task_idx: (
        ids["hnet->solver"]["hnet_root"],
        ids["hnet->hnet->solver"]["hnet_root"],
        ids["hnet->hnet->solver"]["hnet_child"],
        data_handlers[task_idx],
    )
    for task_idx, ids in enumerate(hnets_cond_ids)
}

#### Continual learning - segmented backprop

In [None]:
# select cond_ids for hypernets
# hnet_root_cond_id = hnets_cond_ids[d_i][phase]["hnet_root"]
# hnet_child_cond_id = hnets_cond_ids[d_i][phase]["hnet_child"]

# for i in range(10):
#     cells = [
#         {
#             "hnet": hnet_child,
#             "solver": solver_child,
#             "hnet_optim": hnet_child_optim,
#             "hnet_to_hnet_cond_id": None,
#             "hnet_to_solver_cond_id": hnets_cond_ids[d_i]["hnet->hnet->solver"]["hnet_child"],
#             "hnet_init_theta": hnet_root(cond_id=hnets_cond_ids[d_i]["hnet->hnet->solver"]["hnet_root"]),
#             "hnet_prev_params": None,
#             "hnet_theta_out_target": None,
#             "n_training_iters_solver": 200,
#             "n_training_iters_hnet": 0,
#         },
#         {
#             "hnet": hnet_root,
#             "solver": solver_root,
#             "hnet_optim": hnet_root_optim,
#             "hnet_to_hnet_cond_id": hnets_cond_ids[d_i]["hnet->hnet->solver"]["hnet_root"],
#             "hnet_to_solver_cond_id": hnets_cond_ids[d_i]["hnet->solver"]["hnet_root"],
#             "hnet_init_theta": None,
#             "hnet_prev_params": hnet_root_prev_params,
#             "hnet_theta_out_target": None,
#             "n_training_iters_solver": 0,
#             "n_training_iters_hnet": 50,
#         }
#     ]
#     train_cells(X, y, cells, config)

#     # generate theta and predict
#     y_hat, params_solver = infer(X, phase, hnet_parent_cond_id=hnet_root_cond_id, hnet_child_cond_id=hnet_child_cond_id, hnet_parent=hnet_root, hnet_child=hnet_child, solver_parent=solver_root, solver_child=solver_child)

#     # solvers' params regularization
#     loss_solver_params_reg = sum([p.norm(p=2) for p in params_solver]) / len(params_solver)
#     # task loss
#     loss_class = loss_fn(y_hat, y)
#     print(loss_class.item(), " \t" ,loss_solver_params_reg.item(), " \t", y_hat.var(dim=0))

In [None]:
# Continual learning with segmented backprop: every minibatch is handed to train_cells()
# together with a list of "cells". Each cell is a dict bundling a hypernetwork, a solver,
# the hypernetwork's optimizer, the conditioning ids to use and per-cell iteration counts.
torch.manual_seed(0)
np.random.seed(0)
hnet_root_prev_params = None
log_step = 0

# NOTE(review): the [1:] slices skip the first configured phase and the first data
# handler — confirm this is intended and not a debugging leftover.
for p_i, phase in enumerate(config["phases"][1:]):
    # start=1 keeps d_i equal to the handler's index in the full data_handlers list
    for d_i, data in enumerate(data_handlers[1:], start=1):
        # save parameters before solving the task for regularization against forgetting
        hnet_root_prev_params = [p.detach().clone() for p in hnet_root.unconditional_params]

        for epoch in range(config["epochs"]):
            for i, (batch_size, X, y) in enumerate(data.train_iterator(config["data"]["batch_size"])):
                # `i` is zero-based, so `>=` caps the epoch at exactly
                # max_minibatches_per_epoch minibatches (None disables the cap)
                if config["max_minibatches_per_epoch"] is not None and i >= config["max_minibatches_per_epoch"]:
                    break

                X = data.input_to_torch_tensor(X, config["device"], mode="train")
                y = data.output_to_torch_tensor(y, config["device"], mode="train")

                if phase == "hnet->solver":
                    # single cell: the root hnet paired with the root solver
                    cells = [
                        {
                            "hnet": hnet_root,
                            "solver": solver_root,
                            "hnet_optim": hnet_root_optim,
                            "hnet_to_hnet_cond_id": None,
                            "hnet_to_solver_cond_id": hnets_cond_ids[d_i]["hnet->solver"]["hnet_root"],
                            "hnet_init_theta": None,
                            "hnet_prev_params": hnet_root_prev_params,
                            "hnet_theta_out_target": None,
                            # "n_training_iters_solver": config["n_training_iters_solver"],
                            "n_training_iters_solver": 1,
                            "n_training_iters_hnet": 0,
                        }
                    ]
                elif phase == "hnet->hnet->solver":
                    # two cells: the child cell (its initial theta is the root hnet's output)
                    # followed by the root cell
                    cells = [
                        {
                            "hnet": hnet_child,
                            "solver": solver_child,
                            "hnet_optim": hnet_child_optim,
                            "hnet_to_hnet_cond_id": None,
                            "hnet_to_solver_cond_id": hnets_cond_ids[d_i]["hnet->hnet->solver"]["hnet_child"],
                            "hnet_init_theta": hnet_root(cond_id=hnets_cond_ids[d_i]["hnet->hnet->solver"]["hnet_root"]),
                            "hnet_prev_params": None, # TODO: would those also regularize the root hypernet?
                            "hnet_theta_out_target": None,
                            "n_training_iters_solver": config["n_training_iters_solver"],
                            "n_training_iters_hnet": 0,
                        },
                        {
                            "hnet": hnet_root,
                            "solver": solver_root,
                            "hnet_optim": hnet_root_optim,
                            "hnet_to_hnet_cond_id": hnets_cond_ids[d_i]["hnet->hnet->solver"]["hnet_root"],
                            "hnet_to_solver_cond_id": hnets_cond_ids[d_i]["hnet->solver"]["hnet_root"],
                            "hnet_init_theta": None,
                            "hnet_prev_params": hnet_root_prev_params,
                            "hnet_theta_out_target": None, # will get set during the train_cells() call
                            # "n_training_iters_solver": config["n_training_iters_hnet"], # root cell shouldn't take more steps on task than on hnet
                            "n_training_iters_solver": 0, # root cell shouldn't take more steps on task than on hnet
                            "n_training_iters_hnet": config["n_training_iters_hnet"],
                        }
                    ]
                else:
                    # fail loudly instead of silently reusing `cells` from a previous iteration
                    raise ValueError(f"Unknown phase {phase}")

                validate_cells_training_inputs(X, y, cells, config)
                stats = train_cells(X, y, cells, config, [])
                # clear_output(wait=True)
                # report metrics on every 5th minibatch
                if i % 5 == 4:
                    print_metrics(
                        datasets_for_eval, config, hnet_root, hnet_child, solver_root, solver_child,
                        prefix=f"[{p_i + 1}:{phase} | {d_i}/{len(data_handlers) - 1} | {epoch + 1}/{config['epochs']} | {i + 1}]",
                        skip_phases=[],
                        wandb_run=wandb_run,
                        additional_metrics=None
                    )
                    print(".")
                    print_stats(reversed(stats))
                    print("---")

In [None]:
# Report metrics once more after training has finished.
# NOTE: the prefix reads p_i, phase, d_i, epoch and i, which are left over from the last
# iteration of a previously executed training loop — this cell fails with NameError on a
# fresh kernel unless one of the training cells has been run first.
print_metrics(
    datasets_for_eval, config, hnet_root, hnet_child, solver_root, solver_child,
    prefix=f"[{p_i + 1}:{phase} | {d_i}/{len(data_handlers) - 1} | {epoch + 1}/{config['epochs']} | {i + 1}]",
    skip_phases=[],
    wandb_run=wandb_run,
    additional_metrics=None
)
print("---")

#### Continual learning - full backprop

In [None]:
# Continual learning with full backprop: for every phase and every data handler in turn,
# the task loss (plus solver-parameter regularization) and the anti-forgetting
# regularizer are backpropagated separately, then both hypernet optimizers take one step.
torch.manual_seed(0)
np.random.seed(0)
loss_fn = nn.CrossEntropyLoss(reduction="mean")
hnet_root_prev_params = None
log_step = 0
phases = config["phases"]

for p_i, phase in enumerate(phases):
    for d_i, data in enumerate(data_handlers):
        # save parameters before solving the task for regularization against forgetting
        hnet_root_prev_params = [p.detach().clone() for p in hnet_root.unconditional_params]
        for epoch in range(config["epochs"]):
            for i, (batch_size, X, y) in enumerate(data.train_iterator(config["data"]["batch_size"])):
                # `i` is zero-based, so `>=` caps the epoch at exactly
                # max_minibatches_per_epoch minibatches (None disables the cap)
                if config["max_minibatches_per_epoch"] is not None and i >= config["max_minibatches_per_epoch"]:
                    break

                X = data.input_to_torch_tensor(X, config["device"], mode="train")
                y = data.output_to_torch_tensor(y, config["device"], mode="train")

                hnet_root_optim.zero_grad()
                hnet_child_optim.zero_grad()

                # select cond_ids for hypernets
                hnet_root_cond_id = hnets_cond_ids[d_i][phase]["hnet_root"]
                hnet_child_cond_id = hnets_cond_ids[d_i][phase]["hnet_child"]
                # generate theta and predict
                y_hat, params_solver = infer(X, phase, hnet_parent_cond_id=hnet_root_cond_id, hnet_child_cond_id=hnet_child_cond_id,
                    hnet_parent=hnet_root, hnet_child=hnet_child, solver_parent=solver_root, solver_child=solver_child, config=config)

                # solvers' params regularization: mean L2 norm over the generated parameter tensors
                loss_solver_params_reg = sum([p.norm(p=2) for p in params_solver]) / len(params_solver)
                # task loss
                loss_class = loss_fn(y_hat, y)
                loss = loss_class + config["hnet"]["reg_alpha"] * loss_solver_params_reg
                loss.backward(retain_graph=True, create_graph=not config["hnet"]["detach_d_theta"])
                # gradient clipping
                clip_grads([hnet_child, hnet_root], config["hnet"]["reg_clip_grads_max_norm"], config["hnet"]["reg_clip_grads_max_value"])

                # regularization against forgetting other contexts
                loss_reg = config["hnet"]["reg_beta"] * get_reg_loss(hnet_root, hnet_root_prev_params, curr_cond_id=hnet_root_cond_id, lr=config["hnet"]["reg_lr"], detach_d_theta=config["hnet"]["detach_d_theta"])
                loss_reg.backward()
                clip_grads([hnet_child, hnet_root], config["hnet"]["reg_clip_grads_max_norm"], config["hnet"]["reg_clip_grads_max_value"])

                hnet_root_optim.step()
                hnet_child_optim.step()
                hnet_root_optim.zero_grad()
                hnet_child_optim.zero_grad()

                # report metrics on every 100th minibatch
                if i % 100 == 99:
                    # y is compared via argmax over the last dim, i.e. it is treated as one-hot here
                    acc = (y_hat.argmax(dim=-1) == y.argmax(dim=-1)).float().mean() * 100.
                    print_metrics(
                        datasets_for_eval, config=config, hnet_root=hnet_root, hnet_child=hnet_child, solver_root=solver_root, solver_child=solver_child,
                        prefix=f"[{p_i + 1}:{phase} | {d_i}/{len(data_handlers) - 1} | {epoch + 1}/{config['epochs']} | {i + 1}]",
                        skip_phases=[],
                        wandb_run=wandb_run, additional_metrics={
                            "loss_class": loss_class.item(),
                            "acc_class": acc,
                            "loss_solver_params_reg": loss_solver_params_reg.item(),
                            "loss_reg": loss_reg.item(),
                        }
                    )
                    print("---")
                    log_step += 1

#### Training in multitask setting

In [None]:
# Multitask training: MNIST and FashionMNIST minibatches are drawn in lockstep and each
# gets its own forward/backward/optimizer step within the same loop iteration.
# NOTE(review): `mnist` and `fmnist` are only created in commented-out code elsewhere in
# this notebook — confirm they are defined before running this cell.
# NOTE(review): the keyword names passed to infer() and the arguments passed to
# get_reg_loss(), clip_grads() and print_metrics() differ from the continual-learning
# cells of this notebook — confirm them against the current helper signatures.
torch.manual_seed(0)
np.random.seed(0)
loss_fn = nn.CrossEntropyLoss(reduction="mean")
# NOTE(review): stays None for the whole first phase and is passed to get_reg_loss() as-is
# — confirm the helper accepts None.
hnet_root_prev_phase_params = None
log_step = 0

phases = ["hnet->solver", "hnet->hnet->solver", "hnet->solver", "hnet->hnet->solver", "hnet->solver", "hnet->hnet->solver"]
for p_i, phase in enumerate(phases):
    print(f"\n\n.... Starting phase {phase} ...")
    if wandb_run is not None:
        wandb_run.log({"Phase": wandb.Table(columns=["phase", "step"], data=[[phase, log_step]])}) # log what phase the training is in
    for epoch in range(config["epochs"]):
        for i, ((_, m_X, m_y),(_, f_X, f_y)) in enumerate(zip(mnist.train_iterator(config["data"]["batch_size"]), fmnist.train_iterator(config["data"]["batch_size"]))):
            # max_minibatches_per_epoch may be None (no cap); comparing an int with None
            # raises TypeError, so guard it. `i` is zero-based, hence `>=`.
            if config["max_minibatches_per_epoch"] is not None and i >= config["max_minibatches_per_epoch"]:
                break

            # Mini-batch of MNIST samples
            m_X = mnist.input_to_torch_tensor(m_X, config["device"], mode="train")
            m_y = mnist.output_to_torch_tensor(m_y, config["device"], mode="train")
            # Mini-batch of FashionMNIST samples
            f_X = fmnist.input_to_torch_tensor(f_X, config["device"], mode="train")
            f_y = fmnist.output_to_torch_tensor(f_y, config["device"], mode="train")

            hnet_root_optim.zero_grad()
            hnet_child_optim.zero_grad()

            # MNIST ------------------------------------------------------------
            # conditioning ids: root 0 / no child when solving directly, root 2 / child 0 via the child hnet
            if phase == "hnet->solver":
                hnet_root_cond_id = 0
                hnet_child_cond_id = None
            elif phase == "hnet->hnet->solver":
                hnet_root_cond_id = 2
                hnet_child_cond_id = 0
            else:
                raise ValueError(f"Unknown phase {phase}")

            y_hat, params_solver = infer(m_X, phase, hnet_root_cond_id=hnet_root_cond_id, hnet_child_cond_id=hnet_child_cond_id, hnet_root=hnet_root, hnet_child=hnet_child, solver_root=solver_root, solver_child=solver_child)

            # solvers' params regularization + task loss
            m_loss_solver_params_reg = sum([p.norm(p=2) for p in params_solver]) / len(params_solver)
            # targets are converted to class indices via the argmax over dim 1
            m_loss = loss_fn(y_hat, m_y.max(dim=1)[1]) + config["hnet"]["reg_alpha"] * m_loss_solver_params_reg
            m_acc = (y_hat.argmax(dim=-1) == m_y.argmax(dim=-1)).float().mean() * 100.
            m_loss.backward()
            if config["hnet"]["reg_clip_grads_max_norm"] is not None:
                clip_grads([hnet_child, hnet_root], config["hnet"]["reg_clip_grads_max_norm"])

            # regularization against forgetting other contexts
            m_loss_reg = config["hnet"]["reg_beta"] * get_reg_loss(hnet_root, hnet_root_prev_phase_params, hnet_root_optim, curr_cond_id=hnet_root_cond_id, lr=config["hnet"]["reg_lr"], detach_d_theta=False)
            m_loss_reg.backward()
            if config["hnet"]["reg_clip_grads_max_norm"] is not None:
                clip_grads([hnet_child, hnet_root], config["hnet"]["reg_clip_grads_max_norm"])

            hnet_root_optim.step()
            hnet_child_optim.step()
            hnet_root_optim.zero_grad()
            hnet_child_optim.zero_grad()


            # FashionMNIST ------------------------------------------------------------
            # conditioning ids: root 1 / no child when solving directly, root 3 / child 1 via the child hnet
            if phase == "hnet->solver":
                hnet_root_cond_id = 1
                hnet_child_cond_id = None
            elif phase == "hnet->hnet->solver":
                hnet_root_cond_id = 3
                hnet_child_cond_id = 1
            else:
                raise ValueError(f"Unknown phase {phase}")

            y_hat, params_solver = infer(f_X, phase, hnet_root_cond_id=hnet_root_cond_id, hnet_child_cond_id=hnet_child_cond_id, hnet_root=hnet_root, hnet_child=hnet_child, solver_root=solver_root, solver_child=solver_child)

            # solvers' params regularization + task loss
            f_loss_solver_params_reg = sum([p.norm(p=2) for p in params_solver]) / len(params_solver)
            f_loss = loss_fn(y_hat, f_y.max(dim=1)[1]) + config["hnet"]["reg_alpha"] * f_loss_solver_params_reg
            f_acc = (y_hat.argmax(dim=-1) == f_y.argmax(dim=-1)).float().mean() * 100.
            f_loss.backward()
            if config["hnet"]["reg_clip_grads_max_norm"] is not None:
                clip_grads([hnet_child, hnet_root], config["hnet"]["reg_clip_grads_max_norm"])

            # regularization against forgetting other contexts
            f_loss_reg = config["hnet"]["reg_beta"] * get_reg_loss(hnet_root, hnet_root_prev_phase_params, hnet_root_optim, curr_cond_id=hnet_root_cond_id, lr=config["hnet"]["reg_lr"], detach_d_theta=False)
            f_loss_reg.backward()
            if config["hnet"]["reg_clip_grads_max_norm"] is not None:
                clip_grads([hnet_child, hnet_root], config["hnet"]["reg_clip_grads_max_norm"])

            hnet_root_optim.step()
            hnet_child_optim.step()
            hnet_root_optim.zero_grad()
            hnet_child_optim.zero_grad()


            # report metrics on every 100th minibatch
            if i % 100 == 99:
                print_metrics(
                    {"MNIST": (0, 2, 0, mnist), "FashionMNIST": (1, 3, 1, fmnist)},
                    hnet_root, hnet_child, solver_root, solver_child,
                    prefix=f"[{phase} | {epoch}/{config['epochs']} | {i + 1}]\nM: {m_loss_reg:.3f} F: {f_loss_reg:.3f}",
                    # skip_phases=["hnet->hnet->solver"] if p_i == 0 and phase == "hnet->solver" else [],
                    skip_phases=[],
                    wandb_run=wandb_run, additional_metrics={
                        # the classification part is recovered by subtracting the weighted regularizer
                        "m_loss_class": m_loss.item() - config["hnet"]["reg_alpha"] * m_loss_solver_params_reg.item(),
                        "f_loss_class": f_loss.item() - config["hnet"]["reg_alpha"] * f_loss_solver_params_reg.item(),
                        "m_acc_class": m_acc,
                        "f_acc_class": f_acc,
                        "m_loss_solver_params_reg": m_loss_solver_params_reg.item(),
                        "f_loss_solver_params_reg": f_loss_solver_params_reg.item(),
                        "m_loss_reg": m_loss_reg.item(),
                        "f_loss_reg": f_loss_reg.item(),
                    }
                )
                print("---")
                log_step += 1
    # snapshot the root hnet's unconditional params at the end of the phase; the next phase
    # passes this snapshot to get_reg_loss()
    hnet_root_prev_phase_params = [p.detach().clone() for p in hnet_root.unconditional_params]