In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
from hypnettorch.data import FashionMNISTData, MNISTData
from hypnettorch.data.special.split_mnist import get_split_mnist_handlers
from hypnettorch.data.special.split_cifar import get_split_cifar_handlers
from hypnettorch.mnets import LeNet, ResNet
from hypnettorch.hnets import HMLP, StructuredHMLP, ChunkedHMLP

from utils.data import get_mnist_data_loaders, get_emnist_data_loaders, randomize_targets, select_from_classes
from utils.visualization import show_imgs, get_model_dot
from utils.others import measure_alloc_mem, count_parameters
from utils.timing import func_timer
from utils.metrics import get_accuracy


from IPython.display import clear_output

torch.set_printoptions(precision=3, linewidth=180)

In [None]:
# Central experiment configuration. Every later cell reads its settings from this dict.
config = {
    "epochs": 2,  # number of passes over the training iterators in each training phase
    "max_minibatches_per_epoch": 600,  # cap on mini-batches per epoch, checked inside the training loop
    "data": {
        # Selects the dataset branch in the loading cell: "mnist|fmnist", "splitmnist" or "splitcifar".
        # NOTE(review): the solver and training cells reference `mnist`/`fmnist` directly,
        # so only "mnist|fmnist" appears to be usable end-to-end - confirm.
        "name": "mnist|fmnist",
        "batch_size": 32,  # mini-batch size for the train and validation iterators
        "data_dir": "data_tmp",  # passed as the first argument to the dataset constructors
        "num_tasks": 5, # only for split tasks
        "num_classes_per_task": 2, # only for split tasks
        "validation_size": 960,  # forwarded as `validation_size` to the dataset constructors
    },
    "solver": {
        "use": "resnet",  # which target network to build: "lenet" or "resnet"
        # The two dicts below are unpacked as keyword arguments into LeNet(...) / ResNet(...).
        "lenet": {
            "arch": "mnist_large",
            "num_classes": 10,
            "no_weights": True,
        },
        "resnet": {
            "n": 5,
            "k": 1,
            "use_bias": True,
            "num_classes": 10,
            "no_weights": True,
        },
    },
    "hnet": {
        "lr": 1e-3,  # Adam learning rate of both hypernetwork optimizers
        "reg_lr": 1e-3,  # passed as `lr` to get_reg_loss (step size of the look-ahead update)
        "model": {
            "layers": [128, 128],  # passed as `layers` to both ChunkedHMLP instances
        },
        # The following five keys are forwarded under the same names to both ChunkedHMLP instances.
        "chunk_emb_size": 16,
        "chunk_size": 3500,
        "cond_in_size": 32,
        "num_cond_embs": 4,  # the training loop uses root condition ids 0..3
        "cond_chunk_embs": True,
        "reg_beta": 0.03,  # weight of the regularization terms in the total training loss
    },
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    # "device": "cpu",
}

print(f"... Running on {config['device']} ...")

In [None]:
# Load the datasets selected by config["data"]["name"].
# Seeds are reset first so that anything random in dataset construction is reproducible.
torch.manual_seed(0)
np.random.seed(0)

data_cfg = config["data"]

if data_cfg["name"] == "mnist|fmnist":
    # Two full 10-class datasets, each with its own held-out validation split.
    mnist = MNISTData(data_cfg["data_dir"], use_one_hot=True, validation_size=data_cfg["validation_size"])
    fmnist = FashionMNISTData(data_cfg["data_dir"], use_one_hot=True, validation_size=data_cfg["validation_size"])
elif data_cfg["name"] == "splitmnist":
    dhandlers = get_split_mnist_handlers(
        data_cfg["data_dir"],
        use_one_hot=True,
        num_tasks=data_cfg["num_tasks"],
        num_classes_per_task=data_cfg["num_classes_per_task"],
        validation_size=data_cfg["validation_size"],
    )
elif data_cfg["name"] == "splitcifar":
    # Fixed: this branch previously called get_split_mnist_handlers and therefore
    # silently loaded SplitMNIST instead of SplitCIFAR.
    dhandlers = get_split_cifar_handlers(
        data_cfg["data_dir"],
        use_one_hot=True,
        num_tasks=data_cfg["num_tasks"],
        num_classes_per_task=data_cfg["num_classes_per_task"],
        validation_size=data_cfg["validation_size"],
    )
else:
    # Fail loudly instead of leaving no dataset variable defined.
    raise ValueError(f"Unknown dataset: {data_cfg['name']}")

In [None]:
# Target networks (solvers): two identically configured instances, one fed by the
# child hypernetwork and one fed directly by the root hypernetwork.
torch.manual_seed(0)
np.random.seed(0)

solver_classes = {"lenet": LeNet, "resnet": ResNet}
solver_name = config["solver"]["use"]
if solver_name not in solver_classes:
    raise ValueError(f"Unknown solver: {config['solver']['use']}")

solver_cls = solver_classes[solver_name]
# The sub-dict named after the solver holds its constructor keyword arguments.
# The child is built before the root, as before.
solver_child = solver_cls(in_shape=mnist.in_shape, **config["solver"][solver_name]).to(config["device"])
solver_root = solver_cls(in_shape=mnist.in_shape, **config["solver"][solver_name]).to(config["device"])

In [None]:
# hnet = StructuredHMLP(
#     solver.param_shapes,
#     chunk_shapes=[[[16]], [[32]], [[64]], [[16, 16, 3, 3], [16]], [[32, 32, 3, 3], [32]], [[64, 64, 3, 3], [64]], [[10, 64], [10]]],
#     num_per_chunk=[14, 12, 12, 6, 5, 5, 1],
#     chunk_emb_sizes=32,
#     hmlp_kwargs=config["hnet"]["model"],
#     assembly_fct=assembly_fct,
#     uncond_in_size=0, cond_in_size=8, num_cond_embs=2).to(config["device"]
# )
# """
# missing in chunk_shapes:
# [16, 1, 3, 3] [16]
# [32, 16, 3, 3] [32]
# [64, 32, 3, 3] [64]
# """

In [None]:
# Build the two chunked hypernetworks.
#   hnet_child produces the parameter shapes of `solver_child`.
#   hnet_root  produces the *unconditional* parameter shapes of `hnet_child`.
torch.manual_seed(0)
np.random.seed(0)

hnet_cfg = config["hnet"]
# Keyword arguments that are identical for both hypernetworks.
shared_hnet_kwargs = dict(
    layers=hnet_cfg["model"]["layers"],
    chunk_size=hnet_cfg["chunk_size"],
    chunk_emb_size=hnet_cfg["chunk_emb_size"],
    cond_chunk_embs=hnet_cfg["cond_chunk_embs"],
    cond_in_size=hnet_cfg["cond_in_size"],
    num_cond_embs=hnet_cfg["num_cond_embs"],
    no_cond_weights=False,
)

# The child is created with no_uncond_weights=True; `infer` passes the root's output
# to it through the `weights` argument.
hnet_child = ChunkedHMLP(
    solver_child.param_shapes,
    no_uncond_weights=True,
    **shared_hnet_kwargs,
).to(config["device"])

# The root must be created after the child because it targets the child's shapes.
hnet_root = ChunkedHMLP(
    hnet_child.unconditional_param_shapes,
    no_uncond_weights=False,
    **shared_hnet_kwargs,
).to(config["device"])
# hnet_root.apply_chunked_hyperfan_init(mnet=hnet_child)

In [None]:
# with chunking (hypernet -> hypernet -> target net)
# def assembly_fct(list_of_chunks):
#     assert len(list_of_chunks) == 107
#     params_out = [
#         list_of_chunks[0][0],
#         list_of_chunks[1][0],
#         list_of_chunks[2][0],
#         list_of_chunks[0][1],
#         list_of_chunks[2][1],
#         list_of_chunks[1][1]
#     ]
#     to_concat = []
#     for i in range(3, 52 + 3):
#         to_concat.append(list_of_chunks[i][0])
#     params_out.append(torch.cat(to_concat, dim=0))

#     to_concat = []
#     for i in range(52 + 3, 52 + 52 + 3):
#         to_concat.append(list_of_chunks[i][0])
#     params_out.append(torch.cat(to_concat, dim=0))
    
#     return params_out

# hnet_root = StructuredHMLP(
#     hnet_child.param_shapes,
#     chunk_shapes=[[[8],[100]], [[100, 8], [100, 100]], [[420, 100]], [[420]]],
#     num_per_chunk=[2, 1, 52, 52],
#     chunk_emb_sizes=32,
#     hmlp_kwargs=dict(layers=[100, 100]),
#     assembly_fct=assembly_fct,
#     uncond_in_size=0, cond_in_size=8, num_cond_embs=2).to(config["device"]
# )

In [None]:
def calc_accuracy(data, solver, solver_weights, use_data_from="validation"):
    """Return the classification accuracy (in percent) of `solver` on one split of `data`.

    Args:
        data: Dataset handler providing `val_iterator`, `num_val_samples`,
            `get_test_inputs`, `get_test_outputs` and the `*_to_torch_tensor` converters.
        solver: Target network; called with the inputs and an optional `weights` argument.
        solver_weights: Weights handed to the solver. In the "train" branch, `None`
            means the solver is called without a `weights` argument.
        use_data_from: Which split to evaluate.
            "validation" iterates over the validation split in mini-batches of
            `config["data"]["batch_size"]`.
            "train" - despite its name - evaluates the *test* split
            (`get_test_inputs` / `get_test_outputs`) as one single batch.

    Returns:
        float: Accuracy in percent. Both branches return a plain Python float.

    Raises:
        ValueError: If `use_data_from` is neither "train" nor "validation".
    """
    assert use_data_from == "train" or data.num_val_samples > 0, "No validation data available."
    was_training = solver.training
    solver.eval()

    try:
        with torch.no_grad():
            if use_data_from == "validation":
                num_correct = 0

                for _, X, y, _ in data.val_iterator(config["data"]["batch_size"], return_ids=True):
                    X = data.input_to_torch_tensor(X, config["device"], mode='inference')
                    y = data.output_to_torch_tensor(y, config["device"], mode='inference')
                    y_hat = solver.forward(X, weights=solver_weights)
                    # Labels are one-hot encoded, so argmax recovers the class index.
                    num_correct += int(torch.sum(y_hat.argmax(dim=1) == y.argmax(dim=1)).detach().cpu())

                acc = num_correct / data.num_val_samples * 100.
            elif use_data_from == "train":
                # Process complete test set as one batch.
                test_in = data.input_to_torch_tensor(
                    data.get_test_inputs(), config["device"], mode='inference')
                # Fixed: labels are converted with the *output* converter; the input
                # converter may apply input-specific preprocessing.
                test_out = data.output_to_torch_tensor(
                    data.get_test_outputs(), config["device"], mode='inference')
                test_lbls = test_out.max(dim=1)[1]

                if solver_weights is not None:
                    logits = solver(test_in, weights=solver_weights)
                else:
                    logits = solver(test_in)
                pred_lbls = logits.max(dim=1)[1]

                # Convert to a Python number so both branches return the same type.
                acc = int(torch.sum(test_lbls == pred_lbls)) / test_lbls.numel() * 100.
            else:
                raise ValueError("Unknown data source (use 'train' or 'validation').")
    finally:
        # Restore the original train/eval mode even if evaluation failed.
        solver.train(mode=was_training)

    return acc

In [None]:
def correct_param_shapes(solver, params):
    """Re-chunk a list of tensors into the parameter shapes expected by `solver`.

    The tensors in `params` are treated as one flat stream of values (list order,
    each tensor flattened). Consecutive slices of that stream are reshaped to the
    entries of `solver.param_shapes`. Values left over at the end of the stream
    are ignored. The result stays connected to `params` in the autograd graph.

    Args:
        solver: Object exposing `param_shapes`, a sequence of shapes.
        params: Sequence of tensors providing the values.

    Returns:
        list[torch.Tensor]: One tensor per entry of `solver.param_shapes`.

    Raises:
        ValueError: If `params` holds fewer values than the solver needs.
    """
    # Fixed: the shapes come from the `solver` argument, not from the global `solver_root`.
    shapes = [tuple(shape) for shape in solver.param_shapes]
    if not shapes:
        return []

    sizes = [math.prod(shape) for shape in shapes]
    needed = sum(sizes)
    available = sum(p.numel() for p in params)
    if len(params) == 0 or available < needed:
        raise ValueError(
            f"Not enough values to fill the solver parameters: need {needed}, got {available}."
        )

    # A single flat stream keeps device and dtype of the source tensors, so no
    # global device setting is needed.
    flat = torch.cat([p.reshape(-1) for p in params])

    params_solver = []
    offset = 0
    for shape, size in zip(shapes, sizes):
        params_solver.append(flat[offset:offset + size].view(shape))
        offset += size
    return params_solver

In [None]:
def calc_delta_theta(optimizer, lr, clip_delta=True, detach=False):
    """Return the plain gradient-descent step `-lr * grad` for every optimizer parameter.

    Args:
        optimizer: Optimizer whose `param_groups` are walked in order.
        lr: Step size that scales the gradients.
        clip_delta: Not used by this function; kept for interface compatibility.
        detach: If True, the gradient is detached before it is copied and scaled.

    Returns:
        list: One entry per parameter, in `param_groups` order. An entry is `None`
        when the parameter has no gradient, otherwise a new tensor `-lr * grad`.
    """
    deltas = []
    for group in optimizer.param_groups:
        for param in group["params"]:
            grad = param.grad
            if grad is None:
                deltas.append(None)
            else:
                source = grad.detach() if detach else grad
                # clone() so the stored gradient is never aliased by the result
                deltas.append(-lr * source.clone())
    return deltas

In [None]:
def get_reg_loss_for_cond(hnet, hnet_prev_params, hnet_optimizer, lr, reg_cond_id, detach_d_theta=False):
    """Squared distance between the old and the look-ahead output of `hnet` for one condition.

    The target is the output of `hnet` for `reg_cond_id`, computed without gradient
    tracking and in eval mode. If `hnet_prev_params` is given, it is passed as
    `{"uncond_weights": hnet_prev_params}`; otherwise the current weights are used.
    The prediction is the output of `hnet` for the same condition when every
    internal parameter is shifted by the step from `calc_delta_theta`.

    Args:
        hnet: Hypernetwork to regularize.
        hnet_prev_params: Earlier unconditional weights used for the target, or None.
        hnet_optimizer: Optimizer whose parameters carry the current gradients.
        lr: Step size of the look-ahead update.
        reg_cond_id: Condition id whose output is compared.
        detach_d_theta: If True, the look-ahead step is detached from the graph.

    Returns:
        Scalar tensor: sum of squared differences between target and prediction.
    """
    # --- target: output before the update, computed in eval mode ---
    was_training = hnet.training
    hnet.eval()
    if hnet_prev_params is None:
        target_weights = None
    else:
        target_weights = {"uncond_weights": hnet_prev_params}
    with torch.no_grad():
        target_parts = hnet(cond_id=reg_cond_id, weights=target_weights)
    # The target must not receive gradients, hence the explicit detach.
    target_flat = torch.cat([part.detach().clone().view(-1) for part in target_parts])
    hnet.train(mode=was_training)

    # --- prediction: output after a virtual step theta + delta_theta ---
    deltas = calc_delta_theta(hnet_optimizer, lr, detach=detach_d_theta)
    lookahead_params = []
    for theta, delta in zip(hnet.internal_params, deltas):
        if delta is None:
            # No gradient for this parameter, so it stays where it is.
            lookahead_params.append(theta)
        elif detach_d_theta is False:
            lookahead_params.append(theta + delta)
        else:
            lookahead_params.append(theta + delta.detach())
    predicted_parts = hnet(cond_id=reg_cond_id, weights=lookahead_params)
    predicted_flat = torch.cat([part.view(-1) for part in predicted_parts])

    return (target_flat - predicted_flat).pow(2).sum()

In [None]:
def get_reg_loss(hnet, hnet_prev_params, hnet_optimizer, curr_cond_id, lr=1e-3, clip_grads_max_norm=1., detach_d_theta=False):
    """Average the per-condition regularization loss over all conditions except the current one.

    Args:
        hnet: Hypernetwork to regularize; `hnet._num_cond_embs` gives the number of conditions.
        hnet_prev_params: Forwarded to `get_reg_loss_for_cond`.
        hnet_optimizer: Forwarded to `get_reg_loss_for_cond`.
        curr_cond_id: Condition currently being trained; it is skipped. `None` skips nothing.
        lr: Step size of the look-ahead update.
        clip_grads_max_norm: If not None, the gradients of `hnet.parameters()` are
            clipped IN PLACE to this norm before they are used.
        detach_d_theta: Forwarded to `get_reg_loss_for_cond`.

    Returns:
        Mean regularization loss over the regularized conditions. A zero scalar
        tensor is returned when there is no condition to regularize.
    """
    if clip_grads_max_norm is not None:
        torch.nn.utils.clip_grad_norm_(hnet.parameters(), clip_grads_max_norm)

    other_cond_ids = [
        c_i for c_i in range(hnet._num_cond_embs)
        if curr_cond_id is None or c_i != curr_cond_id
    ]
    if not other_cond_ids:
        # Fixed: this used to divide by zero when the current condition was the only one.
        return torch.zeros(())

    reg_loss = 0
    for c_i in other_cond_ids:
        reg_loss += get_reg_loss_for_cond(hnet, hnet_prev_params, hnet_optimizer, lr, c_i, detach_d_theta)
    # Divide by the number of terms actually summed, which stays correct even if
    # `curr_cond_id` is not a valid condition id.
    return reg_loss / len(other_cond_ids)

In [None]:
def infer(X, scenario, hnet_root_cond_id, hnet_child_cond_id, hnet_root, hnet_child, solver_root, solver_child):
    """Run one forward pass through the selected chain of networks.

    Scenarios:
        "hnet->solver": the root hypernetwork output is re-chunked with
            `correct_param_shapes` and used as weights of `solver_root`.
        "hnet->hnet->solver": the root hypernetwork output is passed as `weights`
            to `hnet_child`, whose output becomes the weights of `solver_child`.

    Returns:
        tuple: `(y_hat, params_solver)`. For "hnet->solver", `params_solver` is the
        raw root output (before re-chunking). For "hnet->hnet->solver" it is the
        output of `hnet_child`.

    Raises:
        ValueError: For any other scenario.
    """
    assert scenario != "hnet->hnet->solver" or hnet_child_cond_id is not None, f"Scenario {scenario} requires hnet_child_cond_id to be set"

    if scenario == "hnet->solver":
        # root hnet -> params root solver
        root_output = hnet_root.forward(cond_id=hnet_root_cond_id)
        reshaped = correct_param_shapes(solver_root, root_output)
        return solver_root.forward(X, weights=reshaped), root_output

    if scenario == "hnet->hnet->solver":
        # root hnet -> params child hnet (only the unconditional ones)
        child_hnet_weights = hnet_root.forward(cond_id=hnet_root_cond_id)
        child_output = hnet_child.forward(cond_id=hnet_child_cond_id, weights=child_hnet_weights)
        return solver_child.forward(X, weights=child_output), child_output

    raise ValueError(f"Unknown inference scenario {scenario}")

In [None]:
def print_metrics(datasets : dict, hnet_root, hnet_child, solver_root, solver_child, loss_fn, prefix="", skip_phases=[]):
    """Print loss and validation accuracy of both inference paths for every dataset.

    Args:
        datasets: Maps a display name to a 6-tuple
            `(root_cond_id for "hnet->solver", root_cond_id for "hnet->hnet->solver",
            child_cond_id, dataset handler, X, y)`. `X` and `y` are the batch used for
            the loss; the accuracy comes from `calc_accuracy` on the dataset handler.
        hnet_root, hnet_child, solver_root, solver_child: Networks forwarded to `infer`.
        loss_fn: Callable `loss_fn(y_hat, y)` returning a scalar tensor.
        prefix: Header line printed first.
        skip_phases: Scenario names ("hnet->solver", "hnet->hnet->solver") to leave out.
    """
    print(prefix)
    with torch.no_grad():
        for data_name, entry in datasets.items():
            root_cond_direct, root_cond_chained, child_cond, dataset, X, y = entry
            print(data_name)

            if "hnet->solver" not in skip_phases:
                print("    hnet->solver")
                y_hat, root_params = infer(X, "hnet->solver", root_cond_direct, None, hnet_root, hnet_child, solver_root, solver_child)
                loss_root = loss_fn(y_hat, y)
                # `infer` returns the raw root output here, so it is re-chunked again.
                acc_root = calc_accuracy(dataset, solver_root, correct_param_shapes(solver_root, root_params))
                print(f"        Loss: {loss_root.item():.3f} | Accuracy: {acc_root:.3f}")

            if "hnet->hnet->solver" not in skip_phases:
                print("    hnet->hnet->solver")
                y_hat, child_params = infer(X, "hnet->hnet->solver", root_cond_chained, child_cond, hnet_root, hnet_child, solver_root, solver_child)
                loss_child = loss_fn(y_hat, y)
                acc_child = calc_accuracy(dataset, solver_child, child_params)
                print(f"        Loss: {loss_child.item():.3f} | Accuracy: {acc_child:.3f}")

In [None]:
# Training: alternate between the direct path (root hnet -> solver) and the chained
# path (root hnet -> child hnet -> solver), on MNIST and FashionMNIST in parallel.
torch.manual_seed(0)
np.random.seed(0)
hnet_root_optim = torch.optim.Adam(hnet_root.internal_params, lr=config["hnet"]["lr"])
hnet_child_optim = torch.optim.Adam(hnet_child.internal_params, lr=config["hnet"]["lr"])
loss_fn = nn.CrossEntropyLoss(reduction="mean")
# Unconditional root weights saved at the end of the previous phase. They define the
# regularization targets; None during the first phase.
hnet_root_prev_phase_params = None

# (root condition id, child condition id) for each dataset in each phase.
phase_cond_ids = {
    "hnet->solver": {"mnist": (0, None), "fmnist": (1, None)},
    "hnet->hnet->solver": {"mnist": (2, 0), "fmnist": (3, 1)},
}

phases = ["hnet->solver", "hnet->hnet->solver", "hnet->solver", "hnet->hnet->solver"]
for p_i, phase in enumerate(phases):
    if phase not in phase_cond_ids:
        raise ValueError(f"Unknown phase {phase}")
    print(f"\n\n.... Starting phase {phase} ...")
    for epoch in range(config["epochs"]):
        paired_batches = zip(
            mnist.train_iterator(config["data"]["batch_size"]),
            fmnist.train_iterator(config["data"]["batch_size"]),
        )
        # The raw batches get their own names. Otherwise the iteration in which the
        # loop breaks would overwrite the last tensors in m_X / m_y / f_X / f_y with
        # unconverted arrays, and later cells would see those.
        for i, ((_, m_X_raw, m_y_raw), (_, f_X_raw, f_y_raw)) in enumerate(paired_batches):
            # Fixed off-by-one: `>=` processes exactly max_minibatches_per_epoch
            # batches (indices 0 .. max - 1); `>` ran one batch too many.
            if i >= config["max_minibatches_per_epoch"]:
                break

            # Mini-batch of MNIST samples
            m_X = mnist.input_to_torch_tensor(m_X_raw, config["device"], mode="train")
            m_y = mnist.output_to_torch_tensor(m_y_raw, config["device"], mode="train")
            # Mini-batch of FashionMNIST samples
            f_X = fmnist.input_to_torch_tensor(f_X_raw, config["device"], mode="train")
            f_y = fmnist.output_to_torch_tensor(f_y_raw, config["device"], mode="train")

            hnet_root_optim.zero_grad()
            hnet_child_optim.zero_grad()

            # Compute MNIST loss
            hnet_root_cond_id, hnet_child_cond_id = phase_cond_ids[phase]["mnist"]
            y_hat, _ = infer(m_X, phase, hnet_root_cond_id=hnet_root_cond_id, hnet_child_cond_id=hnet_child_cond_id, hnet_root=hnet_root, hnet_child=hnet_child, solver_root=solver_root, solver_child=solver_child)
            m_loss = loss_fn(y_hat, m_y.max(dim=1)[1])
            # This backward pass only fills .grad for the look-ahead step used by the
            # regularizer. The graph is kept because m_loss is part of total_loss below.
            m_loss.backward(retain_graph=True)
            # regularization against forgetting other contexts
            m_loss_reg = get_reg_loss(hnet_root, hnet_root_prev_phase_params, hnet_root_optim, curr_cond_id=hnet_root_cond_id, lr=config["hnet"]["reg_lr"], clip_grads_max_norm=1., detach_d_theta=False)
            hnet_root_optim.zero_grad()
            hnet_child_optim.zero_grad()

            # Compute FashionMNIST loss
            hnet_root_cond_id, hnet_child_cond_id = phase_cond_ids[phase]["fmnist"]
            y_hat, _ = infer(f_X, phase, hnet_root_cond_id=hnet_root_cond_id, hnet_child_cond_id=hnet_child_cond_id, hnet_root=hnet_root, hnet_child=hnet_child, solver_root=solver_root, solver_child=solver_child)
            f_loss = loss_fn(y_hat, f_y.max(dim=1)[1])
            f_loss.backward(retain_graph=True)
            # regularization against forgetting other contexts
            f_loss_reg = get_reg_loss(hnet_root, hnet_root_prev_phase_params, hnet_root_optim, curr_cond_id=hnet_root_cond_id, lr=config["hnet"]["reg_lr"], clip_grads_max_norm=1., detach_d_theta=False)
            hnet_root_optim.zero_grad()
            hnet_child_optim.zero_grad()

            # The actual update uses both task losses plus both weighted regularizers.
            total_loss = m_loss + f_loss + config["hnet"]["reg_beta"] * m_loss_reg + config["hnet"]["reg_beta"] * f_loss_reg
            total_loss.backward()

            hnet_root_optim.step()
            hnet_child_optim.step()

            if i % 100 == 99:
                print_metrics(
                    {"MNIST": (0, 2, 0, mnist, m_X, m_y.max(dim=1)[1]), "FashionMNIST": (1, 3, 1, fmnist, f_X, f_y.max(dim=1)[1])},
                    hnet_root, hnet_child, solver_root, solver_child, loss_fn, prefix=f"[{phase} | {epoch}/{config['epochs']} | {i + 1}]\nM: {m_loss_reg:.2f} F: {f_loss_reg:.2f}",
                    # The chained path has not been trained yet during the very first phase.
                    skip_phases=["hnet->hnet->solver"] if p_i == 0 and phase == "hnet->solver" else [],
                )
    # Snapshot of the root weights; the next phase regularizes against these.
    hnet_root_prev_phase_params = [p.detach().clone() for p in hnet_root.unconditional_params]

In [None]:
# Final report for both inference paths, using the last batches of the training cell.
def _as_eval_batch(dataset, X, y):
    """Return `(inputs, class indices)` as tensors for the loss reported by print_metrics.

    The batch variables left behind by the training loop may still be unconverted
    arrays (the loop rebinds them before it breaks), so they are converted here
    only if they are not tensors yet.
    """
    if not torch.is_tensor(X):
        X = dataset.input_to_torch_tensor(X, config["device"], mode="inference")
    if not torch.is_tensor(y):
        y = dataset.output_to_torch_tensor(y, config["device"], mode="inference")
    # Labels are one-hot encoded; the loss expects class indices.
    return X, y.max(dim=1)[1]


final_m_X, final_m_lbls = _as_eval_batch(mnist, m_X, m_y)
final_f_X, final_f_lbls = _as_eval_batch(fmnist, f_X, f_y)

# Fixed: the call used to pass an undefined `hnet_root_optimizer` as an extra positional
# argument, which print_metrics does not accept.
print_metrics(
    {"MNIST": (0, 2, 0, mnist, final_m_X, final_m_lbls), "FashionMNIST": (1, 3, 1, fmnist, final_f_X, final_f_lbls)},
    hnet_root, hnet_child, solver_root, solver_child, loss_fn,
    prefix=f"[{phase} | {epoch}/{config['epochs']} | {i + 1}]",
)

- 98 mnist, 89 fmnist (hypernet -> target net)
- 98 mnist, 88 fmnist (hypernet -> hypernet -> target net)