# Get the permuted model parameters after applying the permutation found by ZipIt!

This notebook computes the permuted model parameters obtained by applying the permutation found by ZipIt! These parameters are then used to evaluate the performance of different post-merging normalization strategies, as reported in Figure 3 of our paper.

To run the code, first train the models with our training scripts, then move the resulting checkpoints to the directories referenced in the code below.

In [1]:
import torch
from copy import deepcopy
import os
from utils import get_config_from_name, prepare_experiment_config,\
     get_merging_fn
from model_merger import ModelMerge
from lmc_utils import interpolate_state_dicts

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Align pairs of trained models with ZipIt! and save the resulting (permuted)
# state dicts of both base models for every pair.
config_name = 'cifar10_my_vgg16'
device = 'cuda:0'
raw_config = get_config_from_name(config_name, device=device)
# Pairs to process, written as '<idx_a>_<idx_b>'. A repeated entry is processed
# and saved once per occurrence.
pairs_s = ['1_2']*2 + ['2_3']*2 + ['1_3']*1

# Keep the configured checkpoint paths untouched and derive every pair's paths
# from them, instead of repeatedly re-slicing already rewritten paths.
base_path_templates = list(raw_config['model']['bases'])

num_saved_pairs = 0
for pair_name in pairs_s:
    # Output files are numbered consecutively: the k-th pair (0-based) is
    # written to files 2k+1 and 2k+2.
    sd_1_save_name = f'checkpoints/cifar10_my_vgg16_{num_saved_pairs*2+1}_zipit.pth'
    sd_2_save_name = f'checkpoints/cifar10_my_vgg16_{(num_saved_pairs+1)*2}_zipit.pth'

    # Refuse to overwrite existing results, and do so *before* the expensive
    # merge below rather than after it.
    for save_name in (sd_1_save_name, sd_2_save_name):
        if os.path.exists(save_name):
            raise FileExistsError(f'{save_name} already exists')
    # Make sure the output directory exists so torch.save cannot fail on a
    # missing folder once the merge has finished.
    os.makedirs(os.path.dirname(sd_1_save_name), exist_ok=True)

    # change the model bases to the desired pair
    pairs = [int(idx) for idx in pair_name.split('_')]
    cur_config = deepcopy(raw_config)
    for i, model_idx in enumerate(pairs):
        template = base_path_templates[i]
        # Drop the last 5 characters and append '_<model_idx>.pt'.
        # NOTE(review): this assumes the configured path ends in a
        # single-digit index followed by '.pt' (e.g. '..._1.pt') - confirm
        # against the config file.
        cur_config['model']['bases'][i] = template[:-5] + '_' + str(model_idx) + '.pt'

    # prepare models
    model_paths = deepcopy(cur_config['model']['bases'])
    config = prepare_experiment_config(cur_config)
    train_loader = config['data']['train']['full']
    # test_loader and criterion are not used below; they are kept in case
    # later cells rely on them.
    test_loader = config['data']['test']['full']
    base_models = list(config['models']['bases'])
    Grapher = config['graph']
    criterion = torch.nn.CrossEntropyLoss()

    merging_fn = 'match_tensors_zipit' # use zipit to match tensors
    # Each graph wraps its own copy of a base model, so the base models held
    # in `config` are not modified by the merge.
    graphs = [Grapher(deepcopy(base_model)).graphify() for base_model
                        in base_models]
    Merge = ModelMerge(*graphs, device=device)
    Merge.transform(
        deepcopy(config['models']['new']),
        train_loader,
        transform_fn=get_merging_fn(merging_fn),
        metric_classes=config['metric_fns'],
        stop_at=None,
        start_at=None
    )

    # get permuted model
    graphs = Merge.graphs
    base_model_merge_s = [deepcopy(graph.model) for graph in graphs]
    # remove all hooks registered on the top-level module of each copy
    for model in base_model_merge_s:
        model._forward_hooks = {}
        model._backward_hooks = {}
    sd_1_permuted = base_model_merge_s[0].state_dict()
    sd_2_permuted = base_model_merge_s[1].state_dict()

    torch.save(sd_1_permuted, sd_1_save_name)
    torch.save(sd_2_permuted, sd_2_save_name)
    num_saved_pairs += 1