# Compare performances of different post-merging normalization methods

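This notebook compares the post-merging normalization methods studied in Section 4.3 of our paper. Two families are evaluated on top of weight-matching (WM) and ZipIt! merging: REPAIR, RESCALE and bias calibration applied after merging, and bias removal applied before merging.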

For the ZipIt! experiments, first run `PFM/get_zipit_permuted_model.ipynb` to obtain the permuted models. Then place them in `../PFM/checkpoints/` (named `cifar10_my_vgg16_{k}_zipit.pth`), which is where the ZipIt! cell below loads them from.

In [1]:
import sys

# Make the repository root (one directory up) importable so that the
# `source.utils.*` imports in the next cell resolve.
REPO_ROOT = '..'
sys.path.append(REPO_ROOT)

In [None]:
from source.utils.utils import load_model
from source.utils.connect import interpolate_state_dicts, eval_line
from source.utils.weight_matching import weight_matching
from source.utils.data_funcs import load_data
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from source.utils.train import validate
import numpy as np
from copy import deepcopy
from source.utils.logger import Logger
import matplotlib.pyplot as plt

In [None]:
Logger.setup_logging()
logger = Logger()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fix the global RNG state. The training loader below is shuffled and feeds the
# REPAIR/RESCALE statistics, and the experiments are repeated several times, so
# without a seed the reported means/stds change on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

class config:
    """Experiment configuration handed to the project helpers (load_model, validate, eval_line)."""
    model = 'cifar_vgg16'
    dataset = 'cifar10'
    print_freq = 100
    path = '../data' # path to dataset
    special_init = 'vgg_init'

trainset, testset = load_data(config.path, config.dataset)

# Shuffled loader: only used as `bn_loader` for REPAIR/RESCALE statistics below.
trainloader = DataLoader(trainset, batch_size=128, shuffle=True)
testloader = DataLoader(testset, batch_size=128, shuffle=False)

criterion = nn.CrossEntropyLoss()

In [4]:
# Which merging method the cells below evaluate: 'wm' (weight matching) or 'zipit'.
matching_alg = 'zipit' # 'wm' or 'zipit'
# Fail fast on a typo: otherwise both experiment cells are skipped silently and the
# summary cell later fails with a NameError on `stats_dict`.
assert matching_alg in ('wm', 'zipit'), f"unknown matching_alg: {matching_alg!r}"
n = 11 # number of merged models to evaluate
config.n = n

### Evaluate normalization strategies on top of WM-based and ZipIt!-based merging (set `matching_alg` above to choose which cell runs)

Weight Matching

In [5]:
if matching_alg == 'wm':
    sd_1 = torch.load('../ex_results/same_init_ex/cifar10/cifar_vgg16/diff_init/seed_20/model_1_160.pt', map_location=device)
    sd_2 = torch.load('../ex_results/same_init_ex/cifar10/cifar_vgg16/diff_init/seed_20/model_2_160.pt', map_location=device)
    model_1 = load_model(config, 3).to(device)
    model_2 = load_model(config, 3).to(device)
    model_1.load_state_dict(sd_1)
    model_2.load_state_dict(sd_2)

    # 1. REPAIR/RESCALE/Bias Calibration after Merging
    # Weight matching is run 5 times. Each run's permuted state dict is kept in
    # `sd_list` so that step 2 evaluates bias removal on exactly the same runs.
    stats_dict = {}
    sd_list = []

    for i in range(5):
        cur_res_list = {}
        sd_2_wm, _ = weight_matching(config.model, sd_1, sd_2, device=device)
        sd_list.append(sd_2_wm)
        model_2_wm = load_model(config, 3).to(device)
        model_2_wm.load_state_dict(sd_2_wm)

        cur_res_list['wm'] = eval_line(model_1, model_2_wm, testloader, criterion, device, config, n=n, repair=None, name=config.model)
        cur_res_list['repair'] = eval_line(model_1, model_2_wm, testloader, criterion, device, config, n=n, repair='repair', bn_loader=trainloader, name=config.model)
        cur_res_list['rescale'] = eval_line(model_1, model_2_wm, testloader, criterion, device, config, n=n, repair='rescale', bn_loader=trainloader, name=config.model)
        cur_res_list['bias_cal'] = eval_line(model_1, model_2_wm, testloader, criterion, device, config, n=n, repair=None, name=config.model, bias_norm=True)
        stats_dict[i] = cur_res_list

    # 2. Bias Removal before Merging
    # Choose how many trailing bias tensors to zero using the midpoint (alpha = 0.5)
    # merge of the last weight-matching run, then apply that choice to every run.
    sd_ori = interpolate_state_dicts(sd_1, sd_2_wm, 0.5)
    bias_keys = [k for k in sd_1.keys() if 'bias' in k]
    num_bias = len(bias_keys)
    model_removebias = deepcopy(model_1)

    # Forward sweep: zero the first j bias tensors, j = 0 .. num_bias.
    forward_test_acc_s = []
    forward_test_loss_s = []
    for j in range(num_bias + 1):
        sd_removebias = deepcopy(sd_ori)  # copy so fill_ never mutates sd_ori
        for k in bias_keys[:j]:
            sd_removebias[k].fill_(0)
        model_removebias.load_state_dict(sd_removebias)
        loss, acc, _, _ = validate(testloader, model_removebias, criterion, device, config)
        forward_test_acc_s.append(acc)
        forward_test_loss_s.append(loss)

    # Backward sweep: zero the last j bias tensors. j = 0 (nothing removed) is the
    # forward sweep's first entry, so it is reused instead of re-evaluated.
    backward_test_acc_s = [forward_test_acc_s[0]]
    backward_test_loss_s = [forward_test_loss_s[0]]
    for j in range(1, num_bias + 1):
        sd_removebias = deepcopy(sd_ori)
        for k in bias_keys[num_bias - j:]:
            sd_removebias[k].fill_(0)
        model_removebias.load_state_dict(sd_removebias)
        loss, acc, _, _ = validate(testloader, model_removebias, criterion, device, config)
        backward_test_acc_s.append(acc)
        backward_test_loss_s.append(loss)

    plt.plot(forward_test_acc_s, label='forward')
    plt.plot(backward_test_acc_s, label='backward')
    plt.xlabel('Number of bias tensors removed')
    plt.ylabel('Test Accuracy (%)')
    plt.legend()
    plt.show()

    # choose the best result: the argmax index of the backward sweep is the number
    # of trailing bias tensors to drop
    backward_test_acc_s = np.array(backward_test_acc_s)
    print(backward_test_acc_s)
    remove_last_k_bias = np.argmax(backward_test_acc_s)
    print(f'remove last {remove_last_k_bias} bias')
    # Slice from the front: `bias_keys[-k:]` would select *every* key when k == 0.
    removed_bias_keys = bias_keys[num_bias - remove_last_k_bias:]

    alpha_s = np.linspace(0.0, 1.0, n)
    remove_bias_stasts_dict = {}
    for i, sd_2_wm_run in enumerate(sd_list):
        loss_s, acc_s = [], []
        for alpha in alpha_s:
            sd_removebias = deepcopy(interpolate_state_dicts(sd_1, sd_2_wm_run, alpha))
            for k in removed_bias_keys:
                sd_removebias[k].fill_(0)
            model_removebias.load_state_dict(sd_removebias)
            loss, acc, _, _ = validate(testloader, model_removebias, criterion, device, config)
            loss_s.append(loss)
            acc_s.append(acc)
        remove_bias_stasts_dict[i] = {'bias_removal': {'loss': np.array(loss_s), 'acc': np.array(acc_s)}}

ZipIt!

In [None]:
if matching_alg == 'zipit':
    stats_dict = {}
    remove_bias_stasts_dict = {}
    alpha_s = np.linspace(0.0, 1.0, n)
    for i in range(5):
        # Pair i = checkpoints 2i+1 and 2i+2, produced (already permuted) by
        # PFM/get_zipit_permuted_model.ipynb. map_location matches the WM cell so
        # GPU-saved checkpoints also load on a CPU-only machine.
        sd_1 = torch.load(f'../PFM/checkpoints/cifar10_my_vgg16_{2*i+1}_zipit.pth', map_location=device)
        sd_2 = torch.load(f'../PFM/checkpoints/cifar10_my_vgg16_{2*(i+1)}_zipit.pth', map_location=device)
        model_1 = load_model(config, 3).to(device)
        model_2 = load_model(config, 3).to(device)
        model_1.load_state_dict(sd_1)
        model_2.load_state_dict(sd_2)

        # 1. REPAIR/RESCALE/Bias Calibration after Merging
        cur_res_list = {}
        cur_res_list['zipit'] = eval_line(model_1, model_2, testloader, criterion, device, config, n=n, repair=None, name=config.model)
        cur_res_list['repair'] = eval_line(model_1, model_2, testloader, criterion, device, config, n=n, repair='repair', bn_loader=trainloader, name=config.model)
        cur_res_list['rescale'] = eval_line(model_1, model_2, testloader, criterion, device, config, n=n, repair='rescale', bn_loader=trainloader, name=config.model)
        cur_res_list['bias_cal'] = eval_line(model_1, model_2, testloader, criterion, device, config, n=n, repair=None, name=config.model, bias_norm=True)
        stats_dict[i] = cur_res_list

        # 2. Bias Removal before Merging
        # For this pair, choose how many trailing bias tensors to zero using the
        # midpoint (alpha = 0.5) merge, then sweep alpha with that choice.
        sd_ori = interpolate_state_dicts(sd_1, sd_2, 0.5)
        bias_keys = [k for k in sd_1.keys() if 'bias' in k]
        num_bias = len(bias_keys)
        model_removebias = deepcopy(model_1)

        # Forward sweep: zero the first j bias tensors, j = 0 .. num_bias.
        # (`j`, not `i`, so the pair index of the outer loop is never shadowed.)
        forward_test_acc_s = []
        forward_test_loss_s = []
        for j in range(num_bias + 1):
            sd_removebias = deepcopy(sd_ori)  # copy so fill_ never mutates sd_ori
            for k in bias_keys[:j]:
                sd_removebias[k].fill_(0)
            model_removebias.load_state_dict(sd_removebias)
            loss, acc, _, _ = validate(testloader, model_removebias, criterion, device, config)
            forward_test_acc_s.append(acc)
            forward_test_loss_s.append(loss)

        # Backward sweep: zero the last j bias tensors. j = 0 (nothing removed) is
        # the forward sweep's first entry, so it is reused instead of re-evaluated.
        backward_test_acc_s = [forward_test_acc_s[0]]
        backward_test_loss_s = [forward_test_loss_s[0]]
        for j in range(1, num_bias + 1):
            sd_removebias = deepcopy(sd_ori)
            for k in bias_keys[num_bias - j:]:
                sd_removebias[k].fill_(0)
            model_removebias.load_state_dict(sd_removebias)
            loss, acc, _, _ = validate(testloader, model_removebias, criterion, device, config)
            backward_test_acc_s.append(acc)
            backward_test_loss_s.append(loss)

        plt.plot(forward_test_acc_s, label='forward')
        plt.plot(backward_test_acc_s, label='backward')
        plt.xlabel('Number of bias tensors removed')
        plt.ylabel('Test Accuracy (%)')
        plt.title(f'ZipIt! pair {i}')
        plt.legend()
        plt.show()

        # choose the best result: the argmax index of the backward sweep is the
        # number of trailing bias tensors to drop for this pair
        backward_test_acc_s = np.array(backward_test_acc_s)
        print(backward_test_acc_s)
        remove_last_k_bias = np.argmax(backward_test_acc_s)
        print(f'remove last {remove_last_k_bias} bias')
        # Slice from the front: `bias_keys[-k:]` would select *every* key when k == 0.
        removed_bias_keys = bias_keys[num_bias - remove_last_k_bias:]

        loss_s, acc_s = [], []
        for alpha in alpha_s:
            sd_removebias = deepcopy(interpolate_state_dicts(sd_1, sd_2, alpha))
            for k in removed_bias_keys:
                sd_removebias[k].fill_(0)
            model_removebias.load_state_dict(sd_removebias)
            loss, acc, _, _ = validate(testloader, model_removebias, criterion, device, config)
            loss_s.append(loss)
            acc_s.append(acc)
        remove_bias_stasts_dict[i] = {'bias_removal': {'loss': np.array(loss_s), 'acc': np.array(acc_s)}}

Summarize results

In [7]:
# Per-run results: run index -> method -> {'loss': array, 'acc': array}.
# Each eval_line result is sliced as column 0 = loss, column 1 = acc; `.numpy()`
# assumes it is a torch tensor (TODO confirm against eval_line).
summary_dict = {
    run: {
        method: {'loss': res[:, 0].numpy(), 'acc': res[:, 1].numpy()}
        for method, res in run_stats.items()
    }
    for run, run_stats in stats_dict.items()
}

# Attach the bias-removal curves computed separately for the same run indices.
for run, run_summary in summary_dict.items():
    run_summary.update(remove_bias_stasts_dict[run])

# calculate the mean and std across the runs, element-wise along the alpha axis
runs = list(summary_dict)
mean_dict, std_dict = {}, {}
for method, metrics in summary_dict[0].items():
    mean_dict[method], std_dict[method] = {}, {}
    for metric in metrics:
        per_run = [summary_dict[run][method][metric] for run in runs]
        mean_dict[method][metric] = np.mean(per_run, axis=0)
        std_dict[method][metric] = np.std(per_run, axis=0)

summary_dict['mean'] = mean_dict
summary_dict['std'] = std_dict

Save

In [8]:
import os

# Persist the aggregated results; refuse to overwrite the output of an earlier run.
os.makedirs('CPAL_plots', exist_ok=True)
save_name = f'CPAL_plots/renorm_{matching_alg}_cifar10_vgg16.pth'
already_saved = os.path.exists(save_name)
if already_saved:
    raise ValueError('file exists')
torch.save(summary_dict, save_name)

### Visualize

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

# Reload the saved summary so this cell can be re-run without repeating the
# experiments (`matching_alg` from the configuration cell still selects the file).
# NOTE(review): PyTorch versions whose torch.load defaults to weights_only=True may
# reject the numpy arrays stored here; pass weights_only=False (trusted local file)
# if loading fails.
summary_dict = torch.load(f'CPAL_plots/renorm_{matching_alg}_cifar10_vgg16.pth')
mean_dict = summary_dict['mean']
std_dict = summary_dict['std']

# Use the same display name in every legend entry ('WM + Repair', not 'wm + Repair').
alg_label = {'wm': 'WM', 'zipit': 'ZipIt!'}.get(matching_alg, matching_alg)
methods = ['wm', 'zipit', 'repair', 'rescale', 'bias_cal', 'bias_removal']
methods_label_map = {'wm': 'WM', 'zipit': 'ZipIt!', 'repair': f'{alg_label} + Repair', 'rescale': f'{alg_label} + Rescale', 'bias_cal': f'{alg_label} + Bias Cal.', 'bias_removal': f'{alg_label} + Bias Removal'}

# The merging coefficient spans [0, 1] (the bias-removal sweep uses
# np.linspace(0.0, 1.0, n)), not [0, 1.1]. Take the number of points from the saved
# curves so the axis always matches the data, independent of the global `n`.
num_alphas = len(next(iter(mean_dict.values()))['acc'])
alpha_s = np.linspace(0.0, 1.0, num_alphas)

plt.rcParams.update({'font.size': 14})
fig, ax = plt.subplots(figsize=(8, 6))

for method in methods:
    if method in mean_dict:
        ax.errorbar(alpha_s, mean_dict[method]['acc'], yerr=std_dict[method]['acc'], label=methods_label_map[method], capsize=5)

ax.set_xlabel(r'Merging Coefficient ($\alpha$)', fontsize=20)
ax.set_ylabel('Test Accuracy (%)', fontsize=20)
ax.set_yticks([0, 20, 40, 60, 80, 90])
ax.set_ylim(0, 95)
ax.set_title('Compare Post-Merging Normalization Strategy', fontsize=20)
ax.legend()
ax.grid(True)
plt.show()
