## Imports

In [None]:
import copy
import itertools
import logging
import math
from functools import partial
from pathlib import Path
from typing import Dict
import json

import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import pytorch_lightning
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler
from tqdm import tqdm

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger

import ccmm  # noqa
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    load_permutations,
    perm_indices_to_perm_matrix,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    fuse_batch_norm_into_conv,
    get_interpolated_loss_acc_curves,
    l2_norm_models,
    linear_interpolate,
    load_model_from_info,
    map_model_seed_to_symbol,
    normalize_unit_norm,
    project_onto,
    save_factored_permutations,
    vector_to_state_dict,
)

In [None]:
# Figure styling: LaTeX text rendering with a serif font, and seaborn's "talk" context for sizes.
rc_overrides = {
    "text.usetex": True,
    "font.family": "serif",
}
plt.rcParams.update(rc_overrides)
sns.set_context("talk")

cmap_name = "coolwarm_r"

from ccmm.utils.plot import Palette

# Project color palette; the bare expression on the last line lets the notebook display it.
palette = Palette(f"{PROJECT_ROOT}/misc/palette2.json")
palette

In [None]:
# Only let WARNING and above through from these third-party loggers; keep a logger for this notebook.
for _noisy_logger_name in ("lightning.pytorch", "torch", "pytorch_lightning.accelerators.cuda"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)
pylogger = logging.getLogger(__name__)

## Configuration

In [None]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path=str("../conf"), job_name="matching")

In [None]:
# Compose the "matching" config, overriding the model group with "mlp" and the dataset group with "emnist".
cfg = compose(config_name="matching", overrides=["model=mlp", "dataset=emnist"])

In [None]:
# Keep the full config as `core_cfg`; from here on `cfg` refers to its `matching` sub-config.
core_cfg, cfg = cfg, cfg.matching  # NOQA

seed_index_everything(cfg)

## Hyperparameters

In [None]:
# How many items (taken from the start of each split) are used by the train / test loaders.
num_train_samples = 10_000
num_test_samples = 5_000

## Load dataset

In [None]:
# Both splits are built with the transform configured for the *test* split.
transform = instantiate(core_cfg.dataset.test.transform)

train_dataset = instantiate(core_cfg.dataset.train, transform=transform)
test_dataset = instantiate(core_cfg.dataset.test, transform=transform)

# Keep only the first `num_*_samples` items of each split.
train_subset = Subset(train_dataset, list(range(num_train_samples)))
test_subset = Subset(test_dataset, list(range(num_test_samples)))

train_loader = DataLoader(train_subset, batch_size=5000, num_workers=cfg.num_workers)
test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)

In [None]:
# Lightning trainer from the matching config, with progress bar and model summary turned off.
# NOTE(review): `trainer` is not referenced by any visible cell — confirm it is still needed.
trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False)

## Load models

In [None]:
from ccmm.utils.utils import load_model_from_artifact

# NOTE(review): the run's entity is hard-coded, while the model artifacts below are read from
# `core_cfg.core.entity` — confirm the two are meant to differ.
run = wandb.init(
    project=core_cfg.core.project_name, entity="theshadow2030-sapienza-universit-di-roma", job_type="matching"
)

# symbol -> seed, e.g. {a: 1, b: 2, c: 3, ..}
symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}


def artifact_path(seed) -> str:
    """Return the W&B reference to the `:latest` model artifact trained with `seed`."""
    return (
        f"{core_cfg.core.entity}/{core_cfg.core.project_name}/"
        f"{core_cfg.dataset.name}_{core_cfg.model.model_identifier}_{seed}:latest"
    )


# symbol -> loaded model, e.g. {a: model_a, b: model_b, c: model_c, ..}
models: Dict[str, LightningModule] = {
    map_model_seed_to_symbol(seed): load_model_from_artifact(run, artifact_path(seed)) for seed in cfg.model_seeds
}
# Deep copies of the freshly loaded weights, used by `restore_original_weights` in later cells.
model_orig_weights = {symbol: copy.deepcopy(model.model.state_dict()) for symbol, model in models.items()}

num_models = len(models)

pylogger.info(f"Using {num_models} models with architecture {core_cfg.model.model_identifier}")

## Match models

In [None]:
# always permute the model having larger character order, i.e. c -> b, b -> a and so on ...
from ccmm.matching.matcher import GitRebasinMatcher
from ccmm.matching.utils import get_inverse_permutations

symbols = set(symbols_to_seed)
sorted_symbols = sorted(symbols)  # NOTE(review): not used by any visible cell

# The pair is currently hard-coded: model "b" is permuted onto model "a".
fixed_symbol, permutee_symbol = "a", "b"
fixed_model = models[fixed_symbol].cpu()
permutee_model = models[permutee_symbol].cpu()

In [None]:
from ccmm.matching.permutation_spec import CNNPermutationSpecBuilder, MLPPermutationSpecBuilder

# Spec for an MLP with 4 hidden layers; switch to the CNN builder (commented out) for convolutional models.
permutation_spec_builder = MLPPermutationSpecBuilder(num_hidden_layers=4)
# permutation_spec_builder = CNNPermutationSpecBuilder()
permutation_spec = permutation_spec_builder.create_permutation_spec()

# Sanity check: the spec must name exactly the parameters of the first loaded model.
ref_model = next(iter(models.values()))
assert set(permutation_spec.layer_and_axes_to_perm.keys()) == set(ref_model.model.state_dict().keys())

### Weight matching with functional maps

In [None]:
# Reset the models to the weights snapshotted in `model_orig_weights` right after loading.
restore_original_weights(models, model_orig_weights)

In [None]:
from ccmm.matching.func_maps import FM_to_p2p, graph_zoomout_refine

# Functional-map hyperparameters; `func_weight_matching` reads them as notebook globals.

# weights forwarded to `fit_func_map` (descriptor term, `w_lap`, `w_dcomm`)
w_descrs, w_laps, w_dcomms = 100, 10, 0.5

# forwarded to `compute_eigenvectors` as `num_neighbors` and `mode`
num_neighbors = 20
mode = "connectivity"  # connectivity, distance

# initial functional map passed to `fit_func_map`
InitFM_mode = "identity"

# forwarded to `graph_zoomout_refine` as `num_iters` and `step`
num_zoomout_iters, step = 70, 1

# NOTE(review): not read by any visible cell
opt_descriptor_type = "weights"  # "weights", "features", "features_denoised", "eigenneurons", "spectral"

In [None]:
# dicts for permutations and permuted params, D[a][b] refers to the permutation/params to map b -> a
# `symbols - {symb}` removes the whole symbol; `symbols.difference(symb)` would instead remove each
# *character* of the string and only works by accident for single-character symbols.
func_permutations = {symb: {other_symb: None for other_symb in symbols - {symb}} for symb in symbols}

In [None]:
from typing import Tuple
from ccmm.matching.func_maps import compute_eigenvectors, fit_func_map
from ccmm.matching.permutation_spec import PermutationSpec
from ccmm.matching.utils import get_permuted_param, perm_cols, perm_matrix_to_perm_indices, perm_rows
from ccmm.matching.weight_matching import (
    LayerIterationOrder,
    compute_weights_similarity,
    get_layer_iteration_order,
    solve_linear_assignment_problem,
)


def func_weight_matching(
    ps: PermutationSpec,
    fixed,
    permutee,
    max_iter=100,
    init_perm=None,
    layer_iteration_order: LayerIterationOrder = LayerIterationOrder.RANDOM,
    verbose=False,
    method="func",
):
    """
    Find a permutation of the neurons of `permutee` that makes it match `fixed`.

    Coordinate descent over the permutations declared in `ps`. Each permutation is re-estimated in turn
    while the others stay at their current value. The estimate comes either from functional maps
    (`method="func"`) or from a linear assignment problem on the weight similarity (`method="lap"`).
    The loop stops after `max_iter` sweeps, or earlier if a whole sweep increases the similarity of no
    permutation.

    With `method="func"` this reads the notebook-level hyperparameters `num_neighbors`, `mode`,
    `InitFM_mode`, `w_descrs`, `w_laps`, `w_dcomms`, `num_zoomout_iters` and `step`.

    :param ps: PermutationSpec describing which (parameter, axis) pairs each permutation acts on
    :param fixed: state dict of the reference model; not modified
    :param permutee: state dict of the model whose neurons are permuted; not modified
    :param max_iter: maximum number of sweeps over all the permutations
    :param init_perm: optional {perm_name: perm_indices} to start from (identity when None); not modified
    :param layer_iteration_order: order in which the permutations are visited in each sweep
    :param verbose: if False, suppress the per-layer prints and raise `pylogger` to WARNING while this
        call runs (its previous level is restored on exit)
    :param method: "func" or "lap"
    :return: dict mapping each permutation name to its permutation indices
    :raises ValueError: if `method` is neither "func" nor "lap"
    """
    # Validate up front: otherwise `perm` below would be unbound, or stale from the previous layer.
    if method not in ("func", "lap"):
        raise ValueError(f"Unknown matching method {method!r}; expected 'func' or 'lap'.")

    previous_log_level = pylogger.level
    if not verbose:
        pylogger.setLevel(logging.WARNING)

    try:
        params_a, params_b = fixed, permutee

        perm_sizes = _compute_perm_sizes(ps, params_a)

        # Start from the identity unless an initial permutation is given.
        # Copy it so the caller's dict is not mutated.
        all_perm_indices = (
            {p: torch.arange(n) for p, n in perm_sizes.items()} if init_perm is None else dict(init_perm)
        )
        # e.g. P_0, P_1, ..
        perm_names = list(all_perm_indices.keys())
        num_layers = len(perm_names)

        for iteration in tqdm(range(max_iter), desc="Weight matching"):
            progress = False

            for p_ix in get_layer_iteration_order(layer_iteration_order, num_layers):
                p = perm_names[p_ix]
                num_neurons = perm_sizes[p]
                if verbose:
                    print(f"Permuting {p}")

                w_a, w_b = _build_neuron_descriptors(
                    ps, p, params_a, params_b, all_perm_indices, num_neurons, verbose
                )
                if verbose:
                    print(f"w_a shape: {w_a.shape}, w_b shape: {w_b.shape}")

                # inner products between every neuron descriptor of a and every one of b
                sim_matrix = w_a @ w_b.T

                if method == "func":
                    perm = _func_map_permutation(w_a, w_b, num_neurons, verbose)
                else:
                    perm = solve_linear_assignment_problem(sim_matrix, return_matrix=True)

                old_similarity = compute_weights_similarity(sim_matrix, all_perm_indices[p])
                all_perm_indices[p] = perm_matrix_to_perm_indices(perm)
                new_similarity = compute_weights_similarity(sim_matrix, all_perm_indices[p])

                pylogger.info(f"Iteration {iteration}, Permutation {p}: {new_similarity - old_similarity}")

                progress = progress or new_similarity > old_similarity + 1e-12

            if not progress:
                break

        return all_perm_indices
    finally:
        pylogger.setLevel(previous_log_level)


def _compute_perm_sizes(ps: PermutationSpec, params) -> Dict[str, int]:
    """Return {perm_name: size of the axis it permutes}, read from the first (param, axis) it acts on."""
    # All the params permuted by the same matrix share that axis size, so one reference pair suffices.
    sizes = {}
    for p, params_and_axes in ps.perm_to_layers_and_axes.items():
        ref_param_name, ref_axis = params_and_axes[0][0], params_and_axes[0][1]
        sizes[p] = params[ref_param_name].shape[ref_axis]
    return sizes


def _build_neuron_descriptors(ps: PermutationSpec, p, params_a, params_b, all_perm_indices, num_neurons, verbose):
    """Build the per-neuron descriptors for permutation `p`.

    Each descriptor concatenates the neuron's row of the current layer's weight with its slice of the
    next layer's weight. The other current permutations are applied to the permutee's side.
    Returns (w_a, w_b), each reshaped to (num_neurons, -1).
    """
    # All the params that are permuted by this permutation matrix, together with the axis it acts on,
    # e.g. ('layer_0.weight', 0), ('layer_0.bias', 0), ('layer_1.weight', 1)..
    params_and_axes: List[Tuple[str, int]] = ps.perm_to_layers_and_axes[p]
    # Sort by axis so that the axis-0 (current layer) entry comes first.
    params_and_axes = sorted(params_and_axes, key=lambda x: x[1])
    # Biases are not part of the descriptors.
    params_and_axes = [x for x in params_and_axes if "bias" not in x[0]]

    # TODO: check if this is true in more complex architectures (probably not)
    assert len(params_and_axes) == 2, f"Expected 2 params, got {len(params_and_axes)}: {params_and_axes}"

    # AXIS 0, CURRENT LAYER
    curr_params_name = params_and_axes[0][0]
    assert params_and_axes[0][1] == 0

    w_a = copy.deepcopy(params_a[curr_params_name])
    w_b = copy.deepcopy(params_b[curr_params_name])
    assert w_a.shape == w_b.shape

    col_perm_to_apply = ps.layer_and_axes_to_perm[curr_params_name][1]
    if col_perm_to_apply is not None:
        # apply the transpose of the previous permutation to the columns of the current layer
        perm_matrix = perm_indices_to_perm_matrix(all_perm_indices[col_perm_to_apply])
        if verbose:
            print(
                f"Permuting the columns of {curr_params_name} with {col_perm_to_apply}, shape: {perm_matrix.shape}"
            )
        w_b = perm_cols(w_b, perm_matrix.T)

    # AXIS 1, NEXT LAYER
    next_params_name = params_and_axes[1][0]
    w_a_next = copy.deepcopy(params_a[next_params_name])
    w_b_next = copy.deepcopy(params_b[next_params_name])
    assert w_a_next.shape == w_b_next.shape

    perm_row_next_layer = ps.layer_and_axes_to_perm[next_params_name][0]
    if perm_row_next_layer is not None:
        # permute the rows of the next layer by its permutation matrix
        w_b_next = perm_rows(w_b_next, perm_indices_to_perm_matrix(all_perm_indices[perm_row_next_layer]))

    # Move axis 1 (the one permuted by p) to the front, then flatten the remaining axes.
    w_a_next = torch.moveaxis(w_a_next, 1, 0).reshape((num_neurons, -1))
    w_b_next = torch.moveaxis(w_b_next, 1, 0).reshape((num_neurons, -1))

    return torch.cat((w_a, w_a_next), dim=1), torch.cat((w_b, w_b_next), dim=1)


def _func_map_permutation(w_a, w_b, num_neurons, verbose):
    """Estimate the neuron correspondence b -> a with a refined functional map; return FM_to_p2p's output."""
    w_a_np, w_b_np = w_a.numpy(), w_b.numpy()

    w_b_evecs, w_a_evecs, w_b_evals, w_a_evals = compute_eigenvectors(
        w_b_np, w_a_np, radius=None, num_neighbors=num_neighbors, mode=mode, normalize_lap=True
    )

    # spectral sizes passed to fit_func_map for each side: half the number of neurons
    k1 = k2 = int(0.5 * num_neurons)
    if verbose:
        print(f"k1: {k1}, k2: {k2}")

    FM_opt, _ = fit_func_map(
        w_b_np,
        w_a_np,
        w_b_evecs,
        w_a_evecs,
        w_b_evals,
        w_a_evals,
        k1,
        k2,
        InitFM_mode,
        w_descrs,
        w_lap=w_laps,
        w_dcomm=w_dcomms,
        method="optimize",
    )

    FM_opt_zo = graph_zoomout_refine(FM_opt, w_b_evecs, w_a_evecs, num_iters=num_zoomout_iters, step=step)

    # NOTE(review): the caller feeds this to `perm_matrix_to_perm_indices`, i.e. it assumes a bijective
    # permutation matrix; confirm that FM_to_p2p guarantees that (point-to-point maps need not be 1-to-1).
    return FM_to_p2p(FM_opt_zo, w_b_evecs, w_a_evecs, n_jobs=1)

In [None]:
restore_original_weights(models, model_orig_weights)

# Match b -> a with functional maps on the CPU state dicts of the two models.
perm_indices = func_weight_matching(
    permutation_spec,
    fixed_model.model.cpu().state_dict(),
    permutee_model.model.cpu().state_dict(),
    method="func",
    verbose=True,
    max_iter=100,
)

# Store both directions: D[a][b] maps b -> a, and D[b][a] is its inverse.
func_permutations[fixed_symbol][permutee_symbol] = perm_indices
func_permutations[permutee_symbol][fixed_symbol] = get_inverse_permutations(perm_indices)

In [None]:
from scripts.evaluate_matched_models import evaluate_pair_of_models

restore_original_weights(models, model_orig_weights)

pylogger.info(f"Permuting model {permutee_symbol} into {fixed_symbol}.")

# perms[a, b] maps b -> a
updated_params = {
    fixed_symbol: {
        permutee_symbol: apply_permutation_to_statedict(
            permutation_spec,
            func_permutations[fixed_symbol][permutee_symbol],
            models[permutee_symbol].model.state_dict(),
        )
    }
}
restore_original_weights(models, model_orig_weights)

lambdas = [0.0, 0.5, 1.0]

evaluation_args = (models, fixed_symbol, permutee_symbol, updated_params, train_loader, test_loader, lambdas, core_cfg)
func_results = evaluate_pair_of_models(*evaluation_args)

### Naive

In [None]:
# Reset the models to the weights snapshotted in `model_orig_weights` right after loading.
restore_original_weights(models, model_orig_weights)

In [None]:
from ccmm.matching.matcher import DummyMatcher

# Baseline matcher named "naive"; presumably it returns identity permutations — confirm in DummyMatcher.
matcher = DummyMatcher(permutation_spec=permutation_spec, name="naive")

In [None]:
# dicts for permutations and permuted params, D[a][b] refers to the permutation/params to map b -> a
# `symbols - {symb}` removes the whole symbol; `symbols.difference(symb)` would instead remove each
# *character* of the string and only works by accident for single-character symbols.
naive_permutations = {symb: {other_symb: None for other_symb in symbols - {symb}} for symb in symbols}

naive_permutations[fixed_symbol][permutee_symbol], perm_history = matcher(
    fixed=fixed_model.model, permutee=permutee_model.model
)

# The reverse direction (a -> b) is the inverse of b -> a.
naive_permutations[permutee_symbol][fixed_symbol] = get_inverse_permutations(
    naive_permutations[fixed_symbol][permutee_symbol]
)

In [None]:
from scripts.evaluate_matched_models import evaluate_pair_of_models

restore_original_weights(models, model_orig_weights)

pylogger.info(f"Permuting model {permutee_symbol} into {fixed_symbol}.")

# perms[a, b] maps b -> a
updated_params = {
    fixed_symbol: {
        permutee_symbol: apply_permutation_to_statedict(
            permutation_spec,
            naive_permutations[fixed_symbol][permutee_symbol],
            models[permutee_symbol].model.state_dict(),
        )
    }
}
restore_original_weights(models, model_orig_weights)

lambdas = [0.0, 0.5, 1]  # np.linspace(0, 1, num=4)

In [None]:
# Evaluate the naive-matched pair at every value in `lambdas`.
evaluation_args = (models, fixed_symbol, permutee_symbol, updated_params, train_loader, test_loader, lambdas, core_cfg)
naive_results = evaluate_pair_of_models(*evaluation_args)

### Git Re-Basin

In [None]:
# Reset the models to the weights snapshotted in `model_orig_weights` right after loading.
restore_original_weights(models, model_orig_weights)

In [None]:
# dicts for permutations and permuted params, D[a][b] refers to the permutation/params to map b -> a
# `symbols - {symb}` removes the whole symbol; `symbols.difference(symb)` would instead remove each
# *character* of the string and only works by accident for single-character symbols.
gitrebasin_permutations = {symb: {other_symb: None for other_symb in symbols - {symb}} for symb in symbols}

matcher = GitRebasinMatcher(name="git_rebasin", permutation_spec=permutation_spec)
gitrebasin_permutations[fixed_symbol][permutee_symbol], perm_history = matcher(
    fixed=fixed_model.model.cpu(), permutee=permutee_model.model.cpu()
)

# The reverse direction (a -> b) is the inverse of b -> a.
gitrebasin_permutations[permutee_symbol][fixed_symbol] = get_inverse_permutations(
    gitrebasin_permutations[fixed_symbol][permutee_symbol]
)

In [None]:
# for perm in gitrebasin_permutations[fixed_symbol][permutee_symbol].values():

#     plt.imshow(perm_indices_to_perm_matrix(perm), cmap="coolwarm")
#     plt.show()

In [None]:
from scripts.evaluate_matched_models import evaluate_pair_of_models

restore_original_weights(models, model_orig_weights)

pylogger.info(f"Permuting model {permutee_symbol} into {fixed_symbol}.")

# perms[a, b] maps b -> a
updated_params = {
    fixed_symbol: {
        permutee_symbol: apply_permutation_to_statedict(
            permutation_spec,
            gitrebasin_permutations[fixed_symbol][permutee_symbol],
            models[permutee_symbol].model.state_dict(),
        )
    }
}
restore_original_weights(models, model_orig_weights)

lambdas = [0.0, 0.5, 1.0]

evaluation_args = (models, fixed_symbol, permutee_symbol, updated_params, train_loader, test_loader, lambdas, core_cfg)
gitrebasin_results = evaluate_pair_of_models(*evaluation_args)

## QAP (not implemented yet — the cell below only restores the original weights)

In [None]:
# Reset the models to the weights snapshotted in `model_orig_weights` right after loading.
restore_original_weights(models, model_orig_weights)

## Evaluation

In [None]:
results = {"git_rebasin": gitrebasin_results, "naive": naive_results, "func": func_results}

# One color per method, fetched once instead of on every loop iteration.
method_colors = palette.get_colors(len(results))
# The plotting setup enables `text.usetex`. Under LaTeX a bare "_" in a label (e.g. "git_rebasin")
# is an error, so escape it in that case.
use_tex = plt.rcParams["text.usetex"]

# Plot train (solid) and test (dashed) accuracy for every method.
# NOTE: assumes every method was evaluated on the same `lambdas` (the value assigned last).
for i, (method, method_results) in enumerate(results.items()):
    label = method.replace("_", r"\_") if use_tex else method
    color = method_colors[i]

    plt.plot(lambdas, method_results["train_acc"], label=f"{label} (train)", linestyle="solid", color=color)
    plt.plot(lambdas, method_results["test_acc"], label=f"{label} (test)", linestyle="dashed", color=color)

plt.xlabel(r"$\lambda$")
plt.ylabel("accuracy")
plt.legend()
plt.show()