In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import gc
from IPython.display import clear_output
import time
import wandb
from hypnettorch.data import FashionMNISTData, MNISTData
from hypnettorch.data.special.split_mnist import get_split_mnist_handlers
from hypnettorch.data.special.split_cifar import get_split_cifar_handlers
from hypnettorch.mnets import LeNet, ZenkeNet, ResNet
from hypnettorch.hnets import HMLP, StructuredHMLP, ChunkedHMLP

from utils.data import get_mnist_data_loaders, get_emnist_data_loaders, randomize_targets, select_from_classes
from utils.visualization import show_imgs, get_model_dot
from utils.others import measure_alloc_mem, count_parameters
from utils.timing import func_timer
from utils.metrics import get_accuracy, calc_accuracy
from utils.hypnettorch_utils import correct_param_shapes, calc_delta_theta, get_reg_loss_for_cond, get_reg_loss, \
    infer, print_stats, print_metrics, clip_grads, take_training_step, init_hnet_unconditionals, remove_hnet_uncondtionals, \
    validate_cells_training_inputs, train_cells


from IPython.display import clear_output

torch.set_printoptions(precision=3, linewidth=180)
%env "WANDB_NOTEBOOK_NAME" "main.ipynb"
wandb.login()

In [None]:
# Experiment configuration: every knob the notebook reads lives in this one dict.
config = dict(
    epochs=20,
    max_minibatches_per_epoch=None,  # None = whole epoch; e.g. 1200 caps the minibatches per epoch
    # Training schedule, one entry per pass over all tasks. Alternative tried before:
    # ["hnet->solver", "hnet->hnet->solver"] * 3
    phases=["hnet->hnet->solver"],
    n_training_iters_solver=25,  # hnet -> solver optimization steps per minibatch
    n_training_iters_hnet=5,  # hnet -> hnet (theta-target) optimization steps per minibatch
    data=dict(
        name="splitcifar",  # one of "mnist|fmnist", "splitmnist", "splitcifar"
        batch_size=32,
        data_dir="data_tmp",
        num_tasks=5,
        num_classes_per_task=2,
        validation_size=0,
    ),
    solver=dict(
        use="zenkenet",  # selects one of the architecture-specific kwarg dicts below
        lenet=dict(arch="mnist_large", no_weights=True),
        zenkenet=dict(arch="cifar", no_weights=True, dropout_rate=0.15),
        resnet=dict(n=5, k=1, use_bias=True, no_weights=True),
    ),
    hnet=dict(
        lr=1e-3,
        reg_lr=1e-3,  # step size of the look-ahead d_theta used by the forgetting regularizer
        model=dict(
            layers=[20, 20],  # previously [100, 100]
            # -1 = no dropout; the HMLP never sees images, so image dropout would have to be
            # added to the solver (e.g. resnet) instead
            dropout_rate=-1,
        ),
        chunk_emb_size=80,
        chunk_size=60_000,  # previously 8000
        cond_in_size=48,
        cond_chunk_embs=True,
        reg_alpha=1e-4,  # L2 regularization of the solvers' generated parameters (previously 5e-3)
        reg_beta=1e-4,  # regularization against forgetting other contexts/tasks (previously 8e-2)
        detach_d_theta=True,  # True: no second-order gradients through the look-ahead step
        reg_clip_grads_max_norm=None,
        reg_clip_grads_max_value=1.,  # only used when reg_clip_grads_max_norm is None
    ),
    device="cuda" if torch.cuda.is_available() else "cpu",
    wandb_logging=False,
)

print(f"... Running on {config['device']} ...")

In [None]:
torch.manual_seed(0)
np.random.seed(0)

# One data handler per task; all of them return one-hot targets (use_one_hot=True).
if config["data"]["name"] == "mnist|fmnist":
    mnist = MNISTData(config["data"]["data_dir"], use_one_hot=True, validation_size=config["data"]["validation_size"])
    fmnist = FashionMNISTData(config["data"]["data_dir"], use_one_hot=True, validation_size=config["data"]["validation_size"])
    data_handlers = [mnist, fmnist]
elif config["data"]["name"] == "splitmnist":
    data_handlers = get_split_mnist_handlers(config["data"]["data_dir"], use_one_hot=True, num_tasks=config["data"]["num_tasks"], num_classes_per_task=config["data"]["num_classes_per_task"], validation_size=config["data"]["validation_size"])
elif config["data"]["name"] == "splitcifar":
    data_handlers = get_split_cifar_handlers(config["data"]["data_dir"], use_one_hot=True, num_tasks=config["data"]["num_tasks"], num_classes_per_task=config["data"]["num_classes_per_task"], validation_size=config["data"]["validation_size"])
else:
    # Without this branch an unknown name would leave `data_handlers` undefined -- or silently
    # reuse the handlers from a previous run of this cell (hidden kernel state).
    raise ValueError(f"Unknown dataset: {config['data']['name']}")

assert config["data"]["num_tasks"] == len(data_handlers), (
    f"config['data']['num_tasks'] is {config['data']['num_tasks']}, but {len(data_handlers)} data handlers were created"
)

In [None]:
torch.manual_seed(0)
np.random.seed(0)

# Target networks (solvers). Both share one architecture and constructor arguments:
# `solver_child` is fed weights generated via `hnet_child`, `solver_root` weights generated by
# `hnet_root` (see `infer`). The per-architecture kwargs come from config["solver"][<name>].
_solver_classes = {"lenet": LeNet, "zenkenet": ZenkeNet, "resnet": ResNet}
_solver_name = config["solver"]["use"]
if _solver_name not in _solver_classes:
    raise ValueError(f"Unknown solver: {config['solver']['use']}")

_solver_cls = _solver_classes[_solver_name]
_solver_kwargs = dict(
    in_shape=data_handlers[0].in_shape,
    num_classes=config["data"]["num_classes_per_task"],
    **config["solver"][_solver_name],
)
# construction order (child first, then root) is kept so the seeded RNG stream is unchanged
solver_child = _solver_cls(**_solver_kwargs).to(config["device"])
solver_root = _solver_cls(**_solver_kwargs).to(config["device"])

In [None]:
torch.manual_seed(0)
np.random.seed(0)

# Constructor arguments shared by both chunked hypernetworks.
_shared_hnet_kwargs = dict(
    layers=config["hnet"]["model"]["layers"],
    chunk_size=config["hnet"]["chunk_size"],
    chunk_emb_size=config["hnet"]["chunk_emb_size"],
    cond_chunk_embs=config["hnet"]["cond_chunk_embs"],
    cond_in_size=config["hnet"]["cond_in_size"],
    # The root needs 2 * num_tasks condition ids: `task` for "hnet->solver" and
    # `num_tasks + task` for "hnet->hnet->solver" (see hnets_cond_ids).
    # NOTE(review): the child is created with the same count but is only ever given ids < num_tasks.
    num_cond_embs=config["data"]["num_tasks"] * 2,
    no_cond_weights=False,
)

# Child hnet: generates the weights of `solver_child`. It stores no unconditional weights of its
# own (no_uncond_weights=True); those are produced by `hnet_root`.
hnet_child = ChunkedHMLP(
    solver_child.param_shapes,
    no_uncond_weights=True,
    **_shared_hnet_kwargs,
).to(config["device"])
hnet_child_optim = torch.optim.Adam(hnet_child.parameters(), lr=config["hnet"]["lr"])

# Root hnet: its targets are the child's unconditional weights; for "hnet->solver" its output is
# reshaped into `solver_root`'s parameters instead. Dropout is configured for the root only.
hnet_root = ChunkedHMLP(
    hnet_child.unconditional_param_shapes,
    dropout_rate=config["hnet"]["model"]["dropout_rate"],
    no_uncond_weights=False,
    **_shared_hnet_kwargs,
).to(config["device"])
# hnet_root.apply_chunked_hyperfan_init(mnet=hnet_child)
hnet_root_optim = torch.optim.Adam(hnet_root.parameters(), lr=config["hnet"]["lr"])

# Parameter budget: parameters each network actually stores vs. how many its `param_shapes`
# describe (networks built with no_weights / no_uncond_weights store fewer).
_param_counts = {
    name: (
        sum(p.numel() for p in model.parameters()),
        sum([np.prod(shape) for shape in model.param_shapes]),
    )
    for name, model in (
        ("solver_child", solver_child),
        ("solver_root", solver_root),
        ("hnet_child", hnet_child),
        ("hnet_root", hnet_root),
    )
}

print("\nSummary of parameters:")
for name, (n_maintained, n_possible) in _param_counts.items():
    print(f"- {name}:\t{n_maintained}\t({n_possible} possible)")
num_of_maintained_params = sum(n_kept for n_kept, _ in _param_counts.values())
max_possible_num_of_maintained_params = sum(n_all for _, n_all in _param_counts.values())
print(f"---\nTotal available parameters:\t{max_possible_num_of_maintained_params}")
print(f"Parameters maintained:\t\t{num_of_maintained_params}")
print(f"-> Coefficient of compression:\t{(num_of_maintained_params / max_possible_num_of_maintained_params):.5f}")

In [None]:
def correct_param_shapes(solver, params):
    """Carve generated values into the parameter shapes of ``solver``.

    A chunked hypernetwork emits tensors whose boundaries need not line up with the solver's
    parameter tensors. All of ``params`` is treated as one contiguous stream of values (in
    order) and cut into consecutive pieces shaped like ``solver.param_shapes``. Values left
    over after the last shape is filled are ignored.

    Args:
        solver: Network exposing ``param_shapes`` (one shape per parameter tensor). The shapes
            are read from this argument -- not from a global network.
        params: Sequence of tensors produced for ``solver``; together they must hold at least
            as many values as ``solver.param_shapes`` requires.

    Returns:
        List with one tensor per entry of ``solver.param_shapes``, viewed to that shape.
        ``cat``/``split``/``view`` keep the autograd graph, so gradients flow back into ``params``.
        The tensors live on the device of ``params`` (no global device setting is used).

    Raises:
        ValueError: If ``params`` hold fewer values than the shapes require.
    """
    target_shapes = list(solver.param_shapes)
    sizes = [math.prod(shape) for shape in target_shapes]
    n_required = sum(sizes)
    n_available = sum(p.numel() for p in params)
    if n_available < n_required:
        raise ValueError(
            f"Generated parameters hold {n_available} values, but the solver needs {n_required}"
        )
    if not target_shapes:
        return []

    # one flat stream of all generated values
    flat = torch.cat([p.reshape(-1) for p in params])
    chunks = torch.split(flat[:n_required], sizes)
    return [chunk.view(shape) for chunk, shape in zip(chunks, target_shapes)]

In [None]:
def calc_delta_theta(hnet, lr, clip_delta=True, detach=False):
    """Plain gradient-descent step ``-lr * grad`` for every internal parameter of ``hnet``.

    Args:
        hnet: Object exposing ``internal_params`` (an iterable of tensors).
        lr: Step size.
        clip_delta: Currently unused; kept for interface compatibility.
        detach: If truthy, the step is computed from ``grad.detach()`` and therefore carries no
            autograd graph (matters when the gradients were built with ``create_graph=True``).

    Returns:
        List aligned with ``hnet.internal_params``: a new step tensor (cloned, never a view of
        ``.grad``), or None for parameters whose ``.grad`` is None.
    """
    def _step(grad):
        if grad is None:
            return None
        source = grad.detach() if detach else grad
        return -lr * source.clone()

    return [_step(param.grad) for param in hnet.internal_params]

In [None]:
def get_reg_loss_for_cond(hnet, hnet_prev_params, lr, reg_cond_id, detach_d_theta=False):
    """Forgetting-regularization term for a single condition id of ``hnet``.

    Target: what ``hnet`` outputs for ``reg_cond_id`` in eval mode and without gradients, using
    ``hnet_prev_params`` as its unconditional weights (or its current weights when
    ``hnet_prev_params`` is None).
    Prediction: the output for the same condition computed with look-ahead weights
    ``theta + d_theta``, where ``d_theta = -reg_lr * grad`` comes from ``calc_delta_theta``;
    weights without a gradient are used unchanged. ``d_theta`` is detached unless
    ``detach_d_theta is False``.

    Returns:
        Scalar tensor: sum of squared differences between the flattened target and prediction.
    """
    # --- target: frozen output for this condition (eval mode, no graph) ---
    was_training = hnet.training
    hnet.eval()
    target_weights = None if hnet_prev_params is None else {"uncond_weights": hnet_prev_params}
    with torch.no_grad():
        target_outputs = hnet(cond_id=reg_cond_id, weights=target_weights)
    # the target must never receive gradients -> detach explicitly
    target = torch.cat([t.detach().clone().view(-1) for t in target_outputs])
    hnet.train(mode=was_training)

    # --- prediction: output with the look-ahead weights ---
    def _look_ahead(theta, delta):
        if delta is None:
            return theta
        if detach_d_theta is False:
            return theta + delta
        return theta + delta.detach()

    deltas = calc_delta_theta(hnet, lr, detach=detach_d_theta)
    lookahead_weights = [_look_ahead(theta, delta) for theta, delta in zip(hnet.internal_params, deltas)]
    predicted = torch.cat([t.view(-1) for t in hnet(cond_id=reg_cond_id, weights=lookahead_weights)])

    return (target - predicted).pow(2).sum()

In [None]:
def get_reg_loss(hnet, hnet_prev_params, curr_cond_id, lr=1e-3, clip_grads_max_norm=1., detach_d_theta=False):
    """Mean forgetting-regularization loss over every condition id except ``curr_cond_id``.

    Args:
        hnet: Hypernetwork with ``_num_cond_embs`` condition ids.
        hnet_prev_params: Snapshot of unconditional weights used for the targets
            (None -> current weights), forwarded to ``get_reg_loss_for_cond``.
        curr_cond_id: Condition currently trained on the task; it is not regularized.
            None regularizes all conditions.
        lr: Look-ahead step size forwarded to ``get_reg_loss_for_cond``.
        clip_grads_max_norm: Currently unused; kept for interface compatibility.
        detach_d_theta: Forwarded to ``get_reg_loss_for_cond``.

    Returns:
        Scalar tensor. The sum is divided by the number of conditions actually regularized
        (the old formula assumed exactly one was skipped, which is wrong when
        ``curr_cond_id`` is out of range). With no condition to regularize, a zero tensor
        that still supports ``.backward()`` is returned instead of dividing by zero.
    """
    reg_loss = 0
    n_regularized = 0
    for c_i in range(hnet._num_cond_embs):
        if c_i == curr_cond_id:
            continue  # the current condition is trained on the task itself
        reg_loss = reg_loss + get_reg_loss_for_cond(hnet, hnet_prev_params, lr, c_i, detach_d_theta)
        n_regularized += 1

    if n_regularized == 0:
        first_param = next(iter(hnet.parameters()), None)
        device = None if first_param is None else first_param.device
        # requires_grad=True so callers can call .backward() unconditionally
        return torch.zeros((), device=device, requires_grad=True)
    return reg_loss / n_regularized

In [None]:
def infer(X, scenario, hnet_parent_cond_id, hnet_child_cond_id, hnet_parent, hnet_child, solver_parent, solver_child):
    """Run one forward pass through the hypernetwork hierarchy and the generated solver.

    Scenarios:
        "hnet->solver": ``hnet_parent`` (cond id ``hnet_parent_cond_id``) generates the weights
            of ``solver_parent``; its output is reshaped to the solver's parameter shapes with
            ``correct_param_shapes``.
        "hnet->hnet->solver": ``hnet_parent`` generates the unconditional weights of
            ``hnet_child``, which (cond id ``hnet_child_cond_id``) generates the weights of
            ``solver_child``.

    Modules are invoked as ``module(...)`` rather than ``module.forward(...)`` so that hooks
    registered on them run (calling ``forward`` directly silently skips them).

    Returns:
        ``(y_hat, params_solver)``: the solver's output for ``X`` and the generated solver
        weights (for "hnet->solver": the raw hnet output, before reshaping).

    Raises:
        AssertionError: "hnet->hnet->solver" without ``hnet_child_cond_id``.
        ValueError: Unknown ``scenario``.
    """
    assert scenario != "hnet->hnet->solver" or hnet_child_cond_id is not None, f"Scenario {scenario} requires hnet_child_cond_id to be set"

    if scenario == "hnet->solver":
        params_solver = hnet_parent(cond_id=hnet_parent_cond_id)  # parent hnet -> theta of parent solver
        y_hat = solver_parent(X, weights=correct_param_shapes(solver_parent, params_solver))
    elif scenario == "hnet->hnet->solver":
        # parent hnet -> unconditional theta of child hnet -> theta of child solver
        params_hnet_child = hnet_parent(cond_id=hnet_parent_cond_id)
        params_solver = hnet_child(cond_id=hnet_child_cond_id, weights=params_hnet_child)
        y_hat = solver_child(X, weights=params_solver)
    else:
        raise ValueError(f"Unknown inference scenario {scenario}")
    return y_hat, params_solver

In [None]:
def _eval_scenario_metrics(X, y, scenario, hnet_root_cond_id, hnet_child_cond_id, hnet_root, hnet_child, solver_root, solver_child):
    """Evaluate one inference scenario on (X, y), print the result and return it.

    Returns:
        ``(loss, accuracy)`` as Python floats: cross-entropy loss and accuracy in percent
        (argmax of the prediction vs. argmax of the targets).
    """
    print(f"    {scenario}")
    y_hat, _ = infer(
        X, scenario, hnet_parent_cond_id=hnet_root_cond_id, hnet_child_cond_id=hnet_child_cond_id,
        hnet_parent=hnet_root, hnet_child=hnet_child, solver_parent=solver_root, solver_child=solver_child,
    )
    loss = F.cross_entropy(y_hat, y).item()
    # .item(): log a plain float (previously a device tensor was passed to wandb)
    acc = ((y_hat.argmax(dim=-1) == y.argmax(dim=-1)).float().mean() * 100.).item()
    print(f"        Loss: {loss:.3f} | Accuracy: {acc:.3f}")
    return loss, acc


def print_metrics(datasets : dict, hnet_root, hnet_child, solver_root, solver_child, prefix="", skip_phases=(), wandb_run=None, additional_metrics=None):
    """Evaluate both inference scenarios on the test split of every dataset; print and log them.

    Args:
        datasets: ``{name: (hnet_root cond id for "hnet->solver", hnet_root cond id for
            "hnet->hnet->solver", hnet_child cond id, data handler)}``.
        hnet_root, hnet_child, solver_root, solver_child: The networks. They are switched to
            eval mode for the evaluation and restored to their previous mode afterwards --
            also when an exception is raised.
        prefix: Header line printed first.
        skip_phases: Scenarios to skip ("hnet->solver" and/or "hnet->hnet->solver"); their
            metrics are reported as NaN.
        wandb_run: If given, ``wandb_run.log`` is called once with all metrics.
        additional_metrics: Extra ``{name: number}`` entries that are printed and logged.
    """
    models = (hnet_root, hnet_child, solver_root, solver_child)
    prev_modes = [m.training for m in models]
    for m in models:
        m.eval()

    try:
        wandb_metrics = {}
        print(prefix)
        with torch.no_grad():
            for data_name, (root_cond_id_h_s, root_cond_id_h_h_s, child_cond_id, dataset) in datasets.items():
                print(data_name)

                # the whole test split is evaluated as a single batch
                X = dataset.input_to_torch_tensor(dataset.get_test_inputs(), config["device"], mode="inference")
                y = dataset.output_to_torch_tensor(dataset.get_test_outputs(), config["device"], mode="inference")

                h_s_loss, h_s_acc, h_h_s_loss, h_h_s_acc = np.nan, np.nan, np.nan, np.nan
                if "hnet->solver" not in skip_phases:
                    h_s_loss, h_s_acc = _eval_scenario_metrics(X, y, "hnet->solver", root_cond_id_h_s, None, *models)
                if "hnet->hnet->solver" not in skip_phases:
                    h_h_s_loss, h_h_s_acc = _eval_scenario_metrics(X, y, "hnet->hnet->solver", root_cond_id_h_h_s, child_cond_id, *models)

                wandb_metrics[str(data_name)] = {
                    "h->s loss": h_s_loss,
                    "h->s acc": h_s_acc,
                    "h->h->s loss": h_h_s_loss,
                    "h->h->s acc": h_h_s_acc,
                }

        if additional_metrics:
            wandb_metrics.update(additional_metrics)
            for n, v in additional_metrics.items():
                print(f"{n}: {v:.3f}")

        if wandb_run is not None:
            wandb_run.log(wandb_metrics)
    finally:
        for m, mode in zip(models, prev_modes):
            m.train(mode=mode)

In [None]:
def print_stats(stats):
    """Print per-cell training statistics, one section per cell.

    Args:
        stats: Iterable of ``{name: tensor}`` dicts. The first section is labelled "(root)".
            Keys are printed sorted and right-aligned to width 30. Single-element tensors are
            shown with 4 decimals; others are shown as lists.
    """
    for cell_idx, cell_stats in enumerate(stats):
        header = f"{cell_idx} (root)" if cell_idx == 0 else f"{cell_idx}"
        print(f"{header}:")
        rows = []
        for key in sorted(cell_stats):
            value = cell_stats[key]
            shown = f"{value.item():.4f}" if value.numel() == 1 else value.tolist()
            rows.append(f"{key:>30}\t{shown}")
        print("\n".join(rows))

In [None]:
def clip_grads(models, reg_clip_grads_max_norm, reg_clip_grads_max_value):
    """Clip the gradients of every model in ``models`` in place.

    Norm clipping (``torch.nn.utils.clip_grad_norm_``, applied separately per model) takes
    precedence. Value clipping (``clip_grad_value_``) is used only when no max-norm is given.
    With both limits None, nothing happens.
    """
    if reg_clip_grads_max_norm is not None and reg_clip_grads_max_value is not None:
        import warnings
        # warnings.warn instead of print: this runs on every training step, and the default
        # warning filter reports it only once per call site instead of flooding the output
        warnings.warn(
            "both reg_clip_grads_max_norm and reg_clip_grads_max_value are set. Using reg_clip_grads_max_norm.",
            stacklevel=2,
        )
    for m in models:
        if reg_clip_grads_max_norm is not None:
            torch.nn.utils.clip_grad_norm_(m.parameters(), reg_clip_grads_max_norm)
        elif reg_clip_grads_max_value is not None:
            torch.nn.utils.clip_grad_value_(m.parameters(), reg_clip_grads_max_value)

In [None]:
# Optional Weights & Biases logging; `wandb_run` stays None when logging is disabled.
wandb_run = None
if config["wandb_logging"]:
    wandb_run = wandb.init(
        project="Hypernets",
        entity="johnny1188",
        group=config["data"]["name"],
        config=config,
        tags=[],
        notes="",
    )
    # track all four networks (log="all": parameters and gradients) every 100 steps
    wandb.watch((hnet_root, hnet_child, solver_root, solver_child), log="all", log_freq=100)

#### Training in continual learning setting

In [None]:
# Conditional-embedding ids for every task (list index = task id):
#   "hnet->solver":       hnet_root generates the solver weights with cond id `task`.
#   "hnet->hnet->solver": hnet_root uses cond id `num_tasks + task` to generate hnet_child's
#                         unconditional weights; hnet_child then uses cond id `task`.
_num_tasks = len(data_handlers)
hnets_cond_ids = [
    {
        "hnet->solver": {"hnet_root": task, "hnet_child": None},
        "hnet->hnet->solver": {"hnet_root": _num_tasks + task, "hnet_child": task},
    }
    for task in range(_num_tasks)
]
# Evaluation view per task, in the tuple layout `print_metrics` unpacks:
# (root cond id for "hnet->solver", root cond id for "hnet->hnet->solver", child cond id, data handler)
datasets_for_eval = {
    task: (
        ids["hnet->solver"]["hnet_root"],
        ids["hnet->hnet->solver"]["hnet_root"],
        ids["hnet->hnet->solver"]["hnet_child"],
        handler,
    )
    for task, (ids, handler) in enumerate(zip(hnets_cond_ids, data_handlers))
}

In [None]:
def take_training_step(X, y, parent, child, phase, hnet_parent_prev_params, config, loss_fn=F.cross_entropy):
    """Take one optimization step of the hypernetwork(s) on a minibatch.

    Args:
        X, y: Minibatch inputs and targets (targets in the form ``loss_fn`` expects).
        parent: Tuple ``(hnet, solver, hnet_optimizer, hnet_cond_id)``; this hnet also receives
            the forgetting regularizer.
        child: Tuple with the same structure for the child hnet, or None (equivalent to a tuple
            of Nones) when ``phase`` is "hnet->solver".
        phase: Inference scenario passed to ``infer``.
        hnet_parent_prev_params: Snapshot of the parent's unconditional weights used as the
            regularization target (None -> the current weights are used).
        config: Experiment config; reads ``config["device"]`` and ``config["hnet"][...]``.
        loss_fn: Task loss, called as ``loss_fn(y_hat, y)``.

    Returns:
        Detached copies of (task loss, solver-weight L2 term, forgetting-regularization term,
        variance of ``y_hat`` over the batch dimension).
    """
    if child is None:
        child = (None, None, None, None)
    hnet_parent, solver_parent, hnet_parent_optim, hnet_parent_cond_id = parent
    hnet_child, solver_child, hnet_child_optim, hnet_child_cond_id = child
    hnet_cfg = config["hnet"]
    hnets = [m for m in (hnet_parent, hnet_child) if m is not None]
    optims = [o for o in (hnet_parent_optim, hnet_child_optim) if o is not None]

    for m in (hnet_parent, solver_parent, hnet_child, solver_child):
        if m is not None:
            m.train(mode=True)
    for optim in optims:
        optim.zero_grad()

    # generate theta and predict
    y_hat, params_solver = infer(X, phase, hnet_parent_cond_id=hnet_parent_cond_id, hnet_child_cond_id=hnet_child_cond_id,
        hnet_parent=hnet_parent, hnet_child=hnet_child, solver_parent=solver_parent, solver_child=solver_child)

    # task loss
    loss_class = loss_fn(y_hat, y)
    # solvers' params regularization (mean L2 norm of the generated parameter tensors)
    loss_solver_params_reg = torch.tensor(0., device=config["device"])
    if hnet_cfg["reg_alpha"] is not None and hnet_cfg["reg_alpha"] > 0.:
        loss_solver_params_reg = hnet_cfg["reg_alpha"] * sum([p.norm(p=2) for p in params_solver]) / len(params_solver)
    # Out-of-place on purpose: `loss = loss_class; loss += reg` is an in-place Tensor add that
    # would also modify `loss_class` and corrupt the returned task-loss statistic.
    loss = loss_class + loss_solver_params_reg

    perform_forgetting_reg = hnet_cfg["reg_beta"] is not None and hnet_cfg["reg_beta"] > 0.
    # keep the graph (and build one for the grads) when the look-ahead in get_reg_loss needs it
    loss.backward(retain_graph=perform_forgetting_reg, create_graph=not hnet_cfg["detach_d_theta"])
    clip_grads(hnets, hnet_cfg["reg_clip_grads_max_norm"], hnet_cfg["reg_clip_grads_max_value"])

    # regularization against forgetting other contexts
    loss_reg = torch.tensor(0., device=config["device"])
    if perform_forgetting_reg:
        loss_reg = hnet_cfg["reg_beta"] * get_reg_loss(hnet_parent, hnet_parent_prev_params, curr_cond_id=hnet_parent_cond_id, lr=hnet_cfg["reg_lr"], detach_d_theta=hnet_cfg["detach_d_theta"])
        loss_reg.backward()
        clip_grads(hnets, hnet_cfg["reg_clip_grads_max_norm"], hnet_cfg["reg_clip_grads_max_value"])

    for optim in optims:
        optim.step()
    for optim in optims:
        optim.zero_grad()

    return loss_class.detach().clone(), loss_solver_params_reg.detach().clone(), loss_reg.detach().clone(), y_hat.var(dim=0).detach().clone()

In [None]:
def init_hnet_unconditionals(hnet, uncond_theta):
    """Install externally generated tensors as trainable unconditional weights of ``hnet``.

    Rebuilds the hypernetwork's internal parameter list so that it holds the existing
    conditional parameters plus new ``nn.Parameter`` copies of ``uncond_theta`` in the
    unconditional slots. It does so by writing private hypnettorch attributes
    (``_unconditional_params_ref``, ``_internal_params``, ``_hnet._internal_params``).
    NOTE(review): this is tied to hypnettorch internals; re-check after library upgrades.
    ``remove_hnet_uncondtionals`` performs the reverse operation.

    Args:
        hnet: Hypernetwork whose unconditional weights are currently not stored internally.
            In this notebook it is used for ``hnet_child`` (built with ``no_uncond_weights=True``).
        uncond_theta: Sequence of tensors whose shapes must equal
            ``hnet.unconditional_param_shapes`` (asserted). They are detached and cloned, so no
            gradient flows back to whatever produced them.

    Returns:
        Clones of ``hnet.internal_params`` taken before the modification.
    """
    assert [s for s in hnet.unconditional_param_shapes] == [list(p.shape) for p in uncond_theta], f"uncond_theta shapes don't match hnet.unconditional_param_shapes"
    
    # snapshot of the internal params before they are replaced (returned to the caller)
    params_before = [p.clone() for p in hnet.internal_params]
    # one slot per entry of hnet.param_shapes; the *_param_shapes_ref lists give each group's slot
    # indices (NOTE(review): assumes together they cover every slot, otherwise a None is left here)
    params_final = [None] * len(hnet.param_shapes)
    # add conditional params
    for p_idx, p in zip(hnet.conditional_param_shapes_ref, hnet.conditional_params):
        params_final[p_idx] = p
    # add unconditional params
    # mark the unconditional slots as internally maintained from now on
    hnet._unconditional_params_ref = hnet.unconditional_param_shapes_ref
    for p_idx, p in zip(hnet.unconditional_param_shapes_ref, uncond_theta):
        # detach + clone -> a new leaf Parameter, independent of the tensor it was generated from
        params_final[p_idx] = nn.Parameter(p.detach().clone(), requires_grad=True)
    # set internal params
    # both the wrapped inner network (`_hnet`, presumably the ChunkedHMLP's inner HMLP) and the
    # outer object are updated so their parameter lists agree
    hnet._hnet._internal_params = nn.ParameterList(params_final)
    hnet._internal_params = nn.ParameterList(hnet._hnet._internal_params)

    return params_before

In [None]:
def remove_hnet_uncondtionals(hnet, prev_params=None):
    """Reverse of ``init_hnet_unconditionals``: strip the unconditional weights from ``hnet``.

    Args:
        hnet: Hypernetwork previously prepared with ``init_hnet_unconditionals``.
            ``hnet.internal_params`` is indexed with ``unconditional_param_shapes_ref``, so all
            parameter slots must be present.
        prev_params: Optional sequence. When it is given and ``hnet._cemb_shape`` is not None,
            its last ``hnet._num_cond_embs`` entries are appended to ``hnet._internal_params``
            (described in the code as conditional chunk embeddings). ``train_cells`` never
            passes it.

    Returns:
        List of detached clones of the unconditional weights, in
        ``unconditional_param_shapes_ref`` order. ``train_cells`` uses them as the parent cell's
        ``hnet_theta_out_target``.

    Raises:
        ValueError: If an appended entry of ``prev_params`` is neither exactly ``nn.Parameter``
            nor exactly ``torch.Tensor`` (checked with ``type(...) ==``, not ``isinstance``).
    """
    # store the unconditional parameters
    unconditionals = []
    for p_idx in hnet.unconditional_param_shapes_ref:
        unconditionals.append(hnet.internal_params[p_idx].detach().clone())
    
    # restore previous state of parameters
    # unconditional weights are no longer maintained internally
    hnet._unconditional_params_ref = None
    
    # keep only the slots that are not unconditional
    hnet._hnet._internal_params = nn.ParameterList([
        p for p_idx, p in enumerate(hnet.internal_params) if p_idx not in hnet.unconditional_param_shapes_ref
    ])
    hnet._internal_params = nn.ParameterList(hnet._hnet._internal_params) # TODO: chunked_mlp_hnet line 201
    
    # append additional conditional chunk embeddings
    # NOTE(review): these are appended to hnet._internal_params only, not to
    # hnet._hnet._internal_params, so the two lists diverge from here on
    if hnet._cemb_shape is not None and prev_params is not None:
        for c_i in range(hnet._num_cond_embs):
            param_to_add = prev_params[-hnet._num_cond_embs + c_i]
            if type(param_to_add) == nn.Parameter:
                # NOTE(review): .detach() on a Parameter yields a plain Tensor, while the Tensor branch
                # below wraps into nn.Parameter -- looks inverted. Newer PyTorch ParameterList.append
                # may convert plain tensors to Parameters anyway; verify for the installed version.
                param_to_add = param_to_add.detach().clone()
            elif type(param_to_add) == torch.Tensor:
                param_to_add = nn.Parameter(param_to_add.detach().clone(), requires_grad=True)
            else:
                raise ValueError(f"prev_params includes a value of type {type(param_to_add)}")
            hnet._internal_params.append(param_to_add)
    return unconditionals

In [None]:
def validate_cells_training_inputs(X, y, cells, config):
    """Sanity-check a minibatch and a chain of cells before calling ``train_cells``.

    Args:
        X: Minibatch inputs (first dimension = batch).
        y: Minibatch targets; the second dimension must equal
            ``config["data"]["num_classes_per_task"]``.
        cells: Cells sorted from the one furthest from the root (the leaf) to the root. Each is
            a dict with the keys listed in ``train_cells``.
        config: Experiment config; reads ``config["data"]["num_classes_per_task"]``.

    Returns:
        None. Raises ``AssertionError`` at the first violated invariant.
    """
    required_keys = {
        "hnet", "solver", "hnet_optim", "hnet_to_hnet_cond_id", "hnet_to_solver_cond_id", "hnet_init_theta",
        "hnet_prev_params", "hnet_theta_out_target", "n_training_iters_solver", "n_training_iters_hnet",
    }

    assert X.shape[0] == y.shape[0], "X and y have different number of samples"
    assert y.shape[1] == config["data"]["num_classes_per_task"], "y has incorrect number of features"
    assert len(cells) > 0, "cells must contain at least one cell"
    # check the keys before any of them is accessed, so a missing key gives a clear message
    for c_i, c in enumerate(cells):
        assert required_keys.issubset(c.keys()), f"Cell {c_i} is missing some of the required keys"

    assert cells[0]["hnet_to_hnet_cond_id"] is None and cells[0]["hnet_theta_out_target"] is None and cells[0]["n_training_iters_hnet"] in (None, 0), \
        "The last cell should have no child cells (list of cells sorted from the furthest from the root to the closest to the root)"
    assert cells[-1]["hnet_init_theta"] is None, "The root cell (last in the cells list) should have no initial theta - its parameters are being learned"

    # Every adjacent pair: the parent's hnet must output exactly the child's unconditional
    # weights. (The previous bound `c_i + 1 < len(cells) - 1` skipped this for chains of <= 2 cells.)
    for c_i in range(len(cells) - 1):
        n_uncond = sum([np.prod(p) for p in cells[c_i]["hnet"].unconditional_param_shapes])
        assert n_uncond == cells[c_i + 1]["hnet"].num_outputs, \
            f"Number of outputs of the {c_i + 1}-th cell's hnet should be equal to the number of unconditional parameters of the {c_i}-th cell's hnet"
    return None

In [None]:
def _load_hnet_init_theta(hnet, hnet_init_theta, parent_cells, config):
    """Install ``hnet_init_theta`` as the unconditional weights of a non-root hnet; return a new Adam over its weights."""
    # a non-root hnet must hand its trained theta to a parent cell afterwards
    assert len(parent_cells) > 0, "hnet_init_theta is set, but no parent cell is left to receive the trained theta"
    init_hnet_unconditionals(hnet, hnet_init_theta)
    # the unconditional weights are brand-new Parameter objects -> a fresh optimizer is needed
    return torch.optim.Adam([*hnet.unconditional_params, *hnet.conditional_params], lr=config["hnet"]["lr"])


def _hand_theta_to_parent(hnet, parent_cells):
    """Strip the trained unconditional weights off ``hnet`` and make them the next cell's theta target."""
    parent_cells[0]["hnet_theta_out_target"] = remove_hnet_uncondtionals(hnet)


def _fit_hnet_to_theta_target(hnet, hnet_optim, cond_id, theta_out_target, prev_params, n_iters, config):
    """Train ``hnet`` (condition ``cond_id``) to output ``theta_out_target``; return the last RMSE loss (detached)."""
    if theta_out_target is None:
        raise ValueError("n_training_iters_hnet > 0, but hnet_theta_out_target is None (no child cell produced a target)")
    hnet_cfg = config["hnet"]
    perform_forgetting_reg = hnet_cfg["reg_beta"] is not None and hnet_cfg["reg_beta"] > 0.
    # the target does not change inside the loop -> flatten it once
    theta_target = torch.cat([p.detach().clone().view(-1) for p in theta_out_target])

    last_loss = None
    for _ in range(n_iters):
        theta_hat = torch.cat([p.view(-1) for p in hnet(cond_id=cond_id)])
        loss_hnet_hnet = torch.sqrt(F.mse_loss(theta_hat, theta_target))  # RMSE
        loss_hnet_hnet.backward(retain_graph=perform_forgetting_reg, create_graph=not hnet_cfg["detach_d_theta"])
        clip_grads([hnet], hnet_cfg["reg_clip_grads_max_norm"], hnet_cfg["reg_clip_grads_max_value"])

        # regularization against forgetting other contexts
        if perform_forgetting_reg:
            loss_reg = hnet_cfg["reg_beta"] * get_reg_loss(hnet, prev_params, curr_cond_id=cond_id, lr=hnet_cfg["reg_lr"], detach_d_theta=hnet_cfg["detach_d_theta"])
            loss_reg.backward()
            clip_grads([hnet], hnet_cfg["reg_clip_grads_max_norm"], hnet_cfg["reg_clip_grads_max_value"])

        hnet_optim.step()
        hnet_optim.zero_grad()
        last_loss = loss_hnet_hnet.detach().clone()
    return last_loss


def train_cells(X, y, cells, config, stats):
    """Train a chain of hypernetwork "cells" on one minibatch, leaf first, root last.

    Each cell is a dict with (at least) the keys
        "hnet", "solver", "hnet_optim", "hnet_to_hnet_cond_id", "hnet_to_solver_cond_id",
        "hnet_init_theta", "hnet_prev_params", "hnet_theta_out_target",
        "n_training_iters_solver", "n_training_iters_hnet".
    Values are looked up by key, so key order does not matter and extra keys are ignored.
    ``cells`` is sorted from the cell furthest from the root to the root itself.

    For every cell, in order:
      1. ``n_training_iters_solver`` > 0: train hnet -> solver on (X, y) with condition
         ``hnet_to_solver_cond_id``.
      2. ``n_training_iters_hnet`` > 0: train the hnet to reproduce ``hnet_theta_out_target``
         (set by the previous cell) with condition ``hnet_to_hnet_cond_id``.
    A non-root cell (``hnet_init_theta`` is not None) receives ``hnet_init_theta`` as its
    unconditional weights before each phase. After the phase they are removed again and become
    the next (parent) cell's ``hnet_theta_out_target``.

    Side effects: ``cells`` is consumed (popped), and the next cell's dict is mutated.

    Returns:
        ``stats``, with one dict of detached statistics appended per cell (for each phase, the
        values of its last iteration).
    """
    while cells:
        cell = cells.pop(0)
        hnet = cell["hnet"]
        hnet_optim = cell["hnet_optim"]
        hnet_init_theta = cell["hnet_init_theta"]  # None only for the root hnet
        n_iters_solver = cell["n_training_iters_solver"]
        n_iters_hnet = cell["n_training_iters_hnet"]

        # statistics for logging
        c_stats = {l: torch.tensor(0.) for l in ("loss_hnet_hnet", "loss_hnet_solver_class", "loss_hnet_solver_theta_reg", "loss_hnet_forgetting_reg", "y_hat_var")}

        # phase 1: hnet -> solver on the given X, y (=> theta target for the parent hnet)
        if n_iters_solver is not None and n_iters_solver > 0:
            if hnet_init_theta is not None:
                hnet_optim = _load_hnet_init_theta(hnet, hnet_init_theta, cells, config)
            for _ in range(n_iters_solver):
                (c_stats["loss_hnet_solver_class"], c_stats["loss_hnet_solver_theta_reg"],
                 c_stats["loss_hnet_forgetting_reg"], c_stats["y_hat_var"]) = take_training_step(
                    X, y, parent=(hnet, cell["solver"], hnet_optim, cell["hnet_to_solver_cond_id"]),
                    child=(None, None, None, None), phase="hnet->solver",
                    hnet_parent_prev_params=cell["hnet_prev_params"], config=config, loss_fn=F.cross_entropy,
                )
            if hnet_init_theta is not None:
                _hand_theta_to_parent(hnet, cells)

        # phase 2: hnet -> hnet on the target theta produced by the child cell
        if n_iters_hnet is not None and n_iters_hnet > 0:
            if hnet_init_theta is not None:
                # NOTE(review): if phase 1 also ran, this restarts from hnet_init_theta and the
                # parent's target set in phase 1 is overwritten below -- confirm this is intended.
                hnet_optim = _load_hnet_init_theta(hnet, hnet_init_theta, cells, config)
            c_stats["loss_hnet_hnet"] = _fit_hnet_to_theta_target(
                hnet, hnet_optim, cell["hnet_to_hnet_cond_id"], cell["hnet_theta_out_target"],
                cell["hnet_prev_params"], n_iters_hnet, config,
            )
            if hnet_init_theta is not None:
                _hand_theta_to_parent(hnet, cells)

        stats.append(c_stats)
    return stats

In [None]:
# select cond_ids for hypernets
# hnet_root_cond_id = hnets_cond_ids[d_i][phase]["hnet_root"]
# hnet_child_cond_id = hnets_cond_ids[d_i][phase]["hnet_child"]

# for i in range(10):
#     cells = [
#         {
#             "hnet": hnet_child,
#             "solver": solver_child,
#             "hnet_optim": hnet_child_optim,
#             "hnet_to_hnet_cond_id": None,
#             "hnet_to_solver_cond_id": hnets_cond_ids[d_i]["hnet->hnet->solver"]["hnet_child"],
#             "hnet_init_theta": hnet_root(cond_id=hnets_cond_ids[d_i]["hnet->hnet->solver"]["hnet_root"]),
#             "hnet_prev_params": None,
#             "hnet_theta_out_target": None,
#             "n_training_iters_solver": 200,
#             "n_training_iters_hnet": 0,
#         },
#         {
#             "hnet": hnet_root,
#             "solver": solver_root,
#             "hnet_optim": hnet_root_optim,
#             "hnet_to_hnet_cond_id": hnets_cond_ids[d_i]["hnet->hnet->solver"]["hnet_root"],
#             "hnet_to_solver_cond_id": hnets_cond_ids[d_i]["hnet->solver"]["hnet_root"],
#             "hnet_init_theta": None,
#             "hnet_prev_params": hnet_root_prev_params,
#             "hnet_theta_out_target": None,
#             "n_training_iters_solver": 0,
#             "n_training_iters_hnet": 50,
#         }
#     ]
#     train_cells(X, y, cells, config)

#     # generate theta and predict
#     y_hat, params_solver = infer(X, phase, hnet_parent_cond_id=hnet_root_cond_id, hnet_child_cond_id=hnet_child_cond_id, hnet_parent=hnet_root, hnet_child=hnet_child, solver_parent=solver_root, solver_child=solver_child)

#     # solvers' params regularization
#     loss_solver_params_reg = sum([p.norm(p=2) for p in params_solver]) / len(params_solver)
#     # task loss
#     loss_class = loss_fn(y_hat, y)
#     print(loss_class.item(), " \t" ,loss_solver_params_reg.item(), " \t", y_hat.var(dim=0))

#### Continual learning - segmented backprop

In [None]:
torch.manual_seed(0)
np.random.seed(0)
loss_fn = nn.CrossEntropyLoss(reduction="mean")
hnet_root_prev_params = None
log_step = 0
phases = config["phases"]

for p_i, phase in enumerate(phases):
    for d_i, data in enumerate(data_handlers):
        # save parameters before solving the task for regularization against forgetting
        hnet_root_prev_params = [p.detach().clone() for p_idx, p in enumerate(hnet_root.unconditional_params)]
        # select cond_ids for hypernets
        hnet_root_cond_id = hnets_cond_ids[d_i][phase]["hnet_root"]
        hnet_child_cond_id = hnets_cond_ids[d_i][phase]["hnet_child"]
        
        for epoch in range(config["epochs"]):
            for i, (batch_size, X, y) in enumerate(data.train_iterator(config["data"]["batch_size"])):
                if config["max_minibatches_per_epoch"] is not None and i > config["max_minibatches_per_epoch"]:
                    break

                X = data.input_to_torch_tensor(X, config["device"], mode="train")
                y = data.output_to_torch_tensor(y, config["device"], mode="train")

                if phase == "hnet->solver":
                    cells = [
                        {
                            "hnet": hnet_root,
                            "solver": solver_root,
                            "hnet_optim": hnet_root_optim,
                            "hnet_to_hnet_cond_id": None,
                            "hnet_to_solver_cond_id": hnets_cond_ids[d_i]["hnet->solver"]["hnet_root"],
                            "hnet_init_theta": None,
                            "hnet_prev_params": hnet_root_prev_params,
                            "hnet_theta_out_target": None,
                            "n_training_iters_solver": config["n_training_iters_solver"],
                            "n_training_iters_hnet": config["n_training_iters_hnet"],
                        }
                    ]
                elif phase == "hnet->hnet->solver":
                    cells = [
                        {
                            "hnet": hnet_child,
                            "solver": solver_child,
                            "hnet_optim": hnet_child_optim,
                            "hnet_to_hnet_cond_id": None,
                            "hnet_to_solver_cond_id": hnets_cond_ids[d_i]["hnet->hnet->solver"]["hnet_child"],
                            "hnet_init_theta": hnet_root(cond_id=hnets_cond_ids[d_i]["hnet->hnet->solver"]["hnet_root"]),
                            "hnet_prev_params": None, # TODO: would those also regularize the root hypernet?
                            "hnet_theta_out_target": None,
                            "n_training_iters_solver": config["n_training_iters_solver"],
                            "n_training_iters_hnet": 0,
                        },
                        {
                            "hnet": hnet_root,
                            "solver": solver_root,
                            "hnet_optim": hnet_root_optim,
                            "hnet_to_hnet_cond_id": hnets_cond_ids[d_i]["hnet->hnet->solver"]["hnet_root"],
                            "hnet_to_solver_cond_id": hnets_cond_ids[d_i]["hnet->solver"]["hnet_root"],
                            "hnet_init_theta": None,
                            "hnet_prev_params": hnet_root_prev_params,
                            "hnet_theta_out_target": None, # will get set during the train_cells() call
                            # "n_training_iters_solver": config["n_training_iters_hnet"], # root cell shouldn't take more steps on task than on hnet
                            "n_training_iters_solver": 0, # root cell shouldn't take more steps on task than on hnet
                            "n_training_iters_hnet": config["n_training_iters_hnet"],
                        }
                    ]

                validate_cells_training_inputs(X, y, cells, config)
                stats = train_cells(X, y, cells, config, [])
                # clear_output(wait=True)
                if i % 3 == 2:
                    print_stats(reversed(stats))
                    print_metrics(
                        datasets_for_eval, hnet_root, hnet_child, solver_root, solver_child,
                        prefix=f"[{p_i + 1}:{phase} | {d_i}/{len(data_handlers) - 1} | {epoch + 1}/{config['epochs']} | {i + 1}]",
                        skip_phases=[],
                        wandb_run=wandb_run,
                        additional_metrics=None
                    )
                    print("---")
    #             break
    #         break
    #     break
    # break

In [None]:
# Log evaluation metrics one last time after training.
# NOTE(review): the prefix is built from loop variables (p_i, phase, d_i, epoch, i) left over
# from the training loop in the previous cell, so this cell only runs after that loop has
# executed in the same kernel session.
print_metrics(
    datasets_for_eval,
    hnet_root,
    hnet_child,
    solver_root,
    solver_child,
    prefix=(
        f"[{p_i + 1}:{phase} | {d_i}/{len(data_handlers) - 1} | "
        f"{epoch + 1}/{config['epochs']} | {i + 1}]"
    ),
    skip_phases=[],
    wandb_run=wandb_run,
    additional_metrics=None,
)
print("---")

#### Continual learning - full backprop

In [None]:
# Continual learning with full backprop: for every phase in config["phases"], train
# sequentially on each task in `data_handlers`. The root hypernet is regularised against
# drifting away from the parameters it had before the current task started.
torch.manual_seed(0)
np.random.seed(0)
loss_fn = nn.CrossEntropyLoss(reduction="mean")
hnet_root_prev_params = None
log_step = 0
phases = config["phases"]
max_minibatches = config["max_minibatches_per_epoch"]  # None -> use every minibatch

for p_i, phase in enumerate(phases):
    for d_i, data in enumerate(data_handlers):
        # save parameters before solving the task for regularization against forgetting
        hnet_root_prev_params = [p.detach().clone() for p in hnet_root.unconditional_params]
        # select cond_ids for hypernets; they depend only on (task, phase), so look them up once
        hnet_root_cond_id = hnets_cond_ids[d_i][phase]["hnet_root"]
        hnet_child_cond_id = hnets_cond_ids[d_i][phase]["hnet_child"]
        for epoch in range(config["epochs"]):
            for i, (batch_size, X, y) in enumerate(data.train_iterator(config["data"]["batch_size"])):
                # `i` is 0-based, so `>=` stops after exactly `max_minibatches` minibatches
                if max_minibatches is not None and i >= max_minibatches:
                    break

                X = data.input_to_torch_tensor(X, config["device"], mode="train")
                y = data.output_to_torch_tensor(y, config["device"], mode="train")

                hnet_root_optim.zero_grad()
                hnet_child_optim.zero_grad()

                # generate theta and predict
                y_hat, params_solver = infer(X, phase, hnet_parent_cond_id=hnet_root_cond_id, hnet_child_cond_id=hnet_child_cond_id, hnet_parent=hnet_root, hnet_child=hnet_child, solver_parent=solver_root, solver_child=solver_child)

                # solvers' params regularization: mean L2 norm over the generated solver tensors
                loss_solver_params_reg = sum([p.norm(p=2) for p in params_solver]) / len(params_solver)
                # task loss; `y` is passed as-is (accuracy below takes its argmax, so it is
                # presumably one-hot / class probabilities — TODO confirm)
                loss_class = loss_fn(y_hat, y)
                loss = loss_class + config["hnet"]["reg_alpha"] * loss_solver_params_reg
                # create_graph keeps the task-gradient graph when d_theta is not detached,
                # presumably so get_reg_loss can differentiate through its lookahead step — TODO confirm
                loss.backward(retain_graph=True, create_graph=not config["hnet"]["detach_d_theta"])
                # gradient clipping
                clip_grads([hnet_child, hnet_root], config["hnet"]["reg_clip_grads_max_norm"], config["hnet"]["reg_clip_grads_max_value"])

                # regularization against forgetting other contexts (compared to the snapshot taken
                # before this task started)
                loss_reg = config["hnet"]["reg_beta"] * get_reg_loss(hnet_root, hnet_root_prev_params, hnet_root_optim, curr_cond_id=hnet_root_cond_id, lr=config["hnet"]["reg_lr"], detach_d_theta=config["hnet"]["detach_d_theta"])
                loss_reg.backward()
                clip_grads([hnet_child, hnet_root], config["hnet"]["reg_clip_grads_max_norm"], config["hnet"]["reg_clip_grads_max_value"])

                hnet_root_optim.step()
                hnet_child_optim.step()
                hnet_root_optim.zero_grad()
                hnet_child_optim.zero_grad()

                if i % 100 == 99:
                    # batch accuracy in percent; only needed when logging
                    acc = (y_hat.argmax(dim=-1) == y.argmax(dim=-1)).float().mean() * 100.
                    print_metrics(
                        datasets_for_eval, hnet_root, hnet_child, solver_root, solver_child,
                        prefix=f"[{p_i + 1}:{phase} | {d_i}/{len(data_handlers) - 1} | {epoch + 1}/{config['epochs']} | {i + 1}]",
                        skip_phases=[],
                        wandb_run=wandb_run, additional_metrics={
                            "loss_class": loss_class.item(),
                            "acc_class": acc.item(),
                            "loss_solver_params_reg": loss_solver_params_reg.item(),
                            "loss_reg": loss_reg.item(),
                        }
                    )
                    print("---")
                    log_step += 1

#### Training in multitask setting

In [None]:
# Multitask training on MNIST + FashionMNIST: each minibatch pair gets one optimisation step
# per dataset, cycling through the hard-coded `phases` list below.
# NOTE(review): relies on `mnist` / `fmnist` data handlers created in an earlier cell; the
# config at the top of the notebook uses splitcifar, so confirm they still exist.
# NOTE(review): `infer` is called here with hnet_root_* / solver_root* keywords, while the
# continual-learning cell uses hnet_parent_* / solver_parent*; only one spelling can match
# utils.hypnettorch_utils.infer — TODO confirm which.
torch.manual_seed(0)
np.random.seed(0)
loss_fn = nn.CrossEntropyLoss(reduction="mean")
hnet_root_prev_phase_params = None
log_step = 0

# (hnet_root_cond_id, hnet_child_cond_id) used for each dataset in each phase.
multitask_cond_ids = {
    "hnet->solver": {"MNIST": (0, None), "FashionMNIST": (1, None)},
    "hnet->hnet->solver": {"MNIST": (2, 0), "FashionMNIST": (3, 1)},
}


def _multitask_step(X, y, phase, hnet_root_cond_id, hnet_child_cond_id):
    """Take one optimisation step of both hypernetworks on a single minibatch.

    Args:
        X, y: minibatch already converted to torch tensors. `y` is reduced with max/argmax
            over dim 1, so it is presumably one-hot encoded — TODO confirm.
        phase: "hnet->solver" or "hnet->hnet->solver".
        hnet_root_cond_id, hnet_child_cond_id: condition ids forwarded to `infer`.

    Returns:
        dict of detached tensors: "loss_class", "acc" (percent), "loss_solver_params_reg"
        and "loss_reg" (already scaled by reg_beta). Callers `.item()` them only when logging,
        which avoids a device sync on every step.
    """
    hnet_root_optim.zero_grad()
    hnet_child_optim.zero_grad()

    y_hat, params_solver = infer(X, phase, hnet_root_cond_id=hnet_root_cond_id, hnet_child_cond_id=hnet_child_cond_id, hnet_root=hnet_root, hnet_child=hnet_child, solver_root=solver_root, solver_child=solver_child)

    # solvers' params regularization + task loss
    loss_solver_params_reg = sum([p.norm(p=2) for p in params_solver]) / len(params_solver)
    loss_class = loss_fn(y_hat, y.max(dim=1)[1])
    loss = loss_class + config["hnet"]["reg_alpha"] * loss_solver_params_reg
    acc = (y_hat.argmax(dim=-1) == y.argmax(dim=-1)).float().mean() * 100.
    loss.backward()
    if config["hnet"]["reg_clip_grads_max_norm"] is not None:
        clip_grads([hnet_child, hnet_root], config["hnet"]["reg_clip_grads_max_norm"])

    # regularization against forgetting other contexts (relative to the snapshot taken at the
    # end of the previous phase; None during the first phase)
    loss_reg = config["hnet"]["reg_beta"] * get_reg_loss(hnet_root, hnet_root_prev_phase_params, hnet_root_optim, curr_cond_id=hnet_root_cond_id, lr=config["hnet"]["reg_lr"], detach_d_theta=False)
    loss_reg.backward()
    if config["hnet"]["reg_clip_grads_max_norm"] is not None:
        clip_grads([hnet_child, hnet_root], config["hnet"]["reg_clip_grads_max_norm"])

    hnet_root_optim.step()
    hnet_child_optim.step()
    hnet_root_optim.zero_grad()
    hnet_child_optim.zero_grad()

    return {
        "loss_class": loss_class.detach(),
        "acc": acc.detach(),
        "loss_solver_params_reg": loss_solver_params_reg.detach(),
        "loss_reg": loss_reg.detach(),
    }


phases = ["hnet->solver", "hnet->hnet->solver", "hnet->solver", "hnet->hnet->solver", "hnet->solver", "hnet->hnet->solver"]
max_minibatches = config["max_minibatches_per_epoch"]  # None -> use every minibatch pair
for p_i, phase in enumerate(phases):
    if phase not in multitask_cond_ids:
        raise ValueError(f"Unknown phase {phase}")
    print(f"\n\n.... Starting phase {phase} ...")
    if wandb_run is not None:
        wandb_run.log({"Phase": wandb.Table(columns=["phase", "step"], data=[[phase, log_step]])}) # log what phase the training is in
    m_root_cond_id, m_child_cond_id = multitask_cond_ids[phase]["MNIST"]
    f_root_cond_id, f_child_cond_id = multitask_cond_ids[phase]["FashionMNIST"]
    for epoch in range(config["epochs"]):
        for i, ((_, m_X, m_y), (_, f_X, f_y)) in enumerate(zip(mnist.train_iterator(config["data"]["batch_size"]), fmnist.train_iterator(config["data"]["batch_size"]))):
            # `i` is 0-based, so `>=` stops after exactly `max_minibatches` minibatch pairs
            if max_minibatches is not None and i >= max_minibatches:
                break

            # Mini-batch of MNIST samples
            m_X = mnist.input_to_torch_tensor(m_X, config["device"], mode="train")
            m_y = mnist.output_to_torch_tensor(m_y, config["device"], mode="train")
            # Mini-batch of FashionMNIST samples
            f_X = fmnist.input_to_torch_tensor(f_X, config["device"], mode="train")
            f_y = fmnist.output_to_torch_tensor(f_y, config["device"], mode="train")

            # one step on MNIST, then one step on FashionMNIST
            m_stats = _multitask_step(m_X, m_y, phase, m_root_cond_id, m_child_cond_id)
            f_stats = _multitask_step(f_X, f_y, phase, f_root_cond_id, f_child_cond_id)

            if i % 100 == 99:
                print_metrics(
                    # tuples mirror multitask_cond_ids: (root id for hnet->solver,
                    # root id for hnet->hnet->solver, child id, data handler)
                    {"MNIST": (0, 2, 0, mnist), "FashionMNIST": (1, 3, 1, fmnist)},
                    hnet_root, hnet_child, solver_root, solver_child,
                    prefix=f"[{phase} | {epoch + 1}/{config['epochs']} | {i + 1}]\nM: {m_stats['loss_reg'].item():.3f} F: {f_stats['loss_reg'].item():.3f}",
                    skip_phases=[],
                    wandb_run=wandb_run, additional_metrics={
                        "m_loss_class": m_stats["loss_class"].item(),
                        "f_loss_class": f_stats["loss_class"].item(),
                        "m_acc_class": m_stats["acc"].item(),
                        "f_acc_class": f_stats["acc"].item(),
                        "m_loss_solver_params_reg": m_stats["loss_solver_params_reg"].item(),
                        "f_loss_solver_params_reg": f_stats["loss_solver_params_reg"].item(),
                        "m_loss_reg": m_stats["loss_reg"].item(),
                        "f_loss_reg": f_stats["loss_reg"].item(),
                    }
                )
                print("---")
                log_step += 1
    # snapshot root-hypernet parameters; the next phase regularises against them
    hnet_root_prev_phase_params = [p.detach().clone() for p in hnet_root.unconditional_params]