This notebook compares different implementations of activation matching, including:

1. The one in the "Git Re-Basin: Merging Models modulo Permutation Symmetries" paper,

2. The one in the "REPAIR: REnormalizing Permuted Activations for Interpolation Repair" paper,

3. Ours.

In [5]:
import sys
sys.path.append("..")

In [14]:
import torch
from source.utils.utils import load_model
from source.utils.data_funcs import load_data
from source.utils.activation_matching import activation_matching
from source.utils.connect import interpolate_state_dicts
from source.utils.logger import Logger
from torch.utils.data import DataLoader
from source.utils.train import validate
import numpy as np
from copy import deepcopy

from time import time

In [7]:
class config:
    """Static experiment settings, read as plain class attributes.

    The class itself (not an instance) is passed to ``load_model`` and
    ``validate`` in later cells.
    """
    # what to train on and which architecture to build
    dataset = 'cifar10'
    model = 'cifar_vgg16'
    special_init = None # 'vgg_init' (kaiming init) or None (uniform init)

    # where the dataset is stored (relative to this notebook)
    data_dir = '../../Linear_Mode_Connectivity/data'

    # presumably the logging interval used by the project helpers -- TODO confirm
    print_freq = 100

In [8]:
# necessary to create logger if using the train/validate/eval_line etc. functions
Logger.setup_logging()
logger = Logger()

# Standard train/test split; the train loader is reshuffled every epoch.
# NOTE(review): `trainloader` is not used by any cell visible in this notebook.
trainset, testset = load_data(config.data_dir, config.dataset)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True)
testloader = DataLoader(testset, batch_size=256, shuffle=False)

# Second copy of the training set requested with no_random_aug=True
# (presumably disables random data augmentation -- confirm in load_data).
# This loader feeds activation matching and the train-set evaluation below.
trainset_noaug, _ = load_data(config.data_dir, config.dataset, no_random_aug=True)
trainloader_noaug = DataLoader(trainset_noaug, batch_size=128, shuffle=True)

# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
criterion = torch.nn.CrossEntropyLoss()

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [9]:
def cal_barrier(end_1, end_2, mid, type='loss'):
    """Barrier of an interpolated model relative to the mean of its two endpoints.

    Args:
        end_1: metric value at the first endpoint.
        end_2: metric value at the second endpoint.
        mid: metric value at the interpolated point.
        type: 'loss' (barrier = mid - endpoint mean) or
            'acc' (barrier = endpoint mean - mid).

    Returns:
        The signed barrier; positive means the interpolated point is worse
        than the average of the endpoints.

    Raises:
        NotImplementedError: if ``type`` is neither 'loss' nor 'acc'.
    """
    if type not in ('loss', 'acc'):
        raise NotImplementedError

    endpoint_mean = (end_1 + end_2) / 2
    if type == 'acc':
        return endpoint_mean - mid
    return mid - endpoint_mean

Activation-matching code for VGG-type networks, taken from the paper "REPAIR: REnormalizing Permuted Activations for Interpolation Repair"

Repo: https://github.com/KellerJordan/REPAIR

In [10]:
import scipy
from tqdm import tqdm
import torch.nn as nn

# given two networks net0, net1 which each output a feature map of shape NxCxWxH
# this will reshape both outputs to (N*W*H)xC
# and then compute a CxC correlation matrix between the outputs of the two networks
def run_corr_matrix(net0, net1, epochs=1, norm=True, loader=trainloader_noaug):
    """Estimate the channel-wise cross-correlation between two networks' outputs.

    Both networks are switched to eval mode and run under ``torch.no_grad()``
    on every batch of ``loader``. Each output is flattened so that every
    position along the trailing dimensions of every sample becomes one
    observation of a C-dimensional channel vector.

    Per-batch means, standard deviations and second moments are averaged with
    equal weight per batch (so a smaller final batch is slightly over-weighted,
    and the std used for normalisation is the mean of per-batch stds).

    Args:
        net0: first network; rows of the result index its channels.
        net1: second network; columns of the result index its channels.
        epochs: number of passes over ``loader``.
        norm: if True return the correlation matrix, otherwise the covariance.
        loader: iterable of ``(images, labels)`` batches supporting ``len()``.
            The default is bound when this function is defined.

    Returns:
        A (C0 x C1) double-precision tensor of correlations (or covariances).

    Raises:
        ValueError: if there are no batches to process.
    """
    n = epochs * len(loader)
    if n <= 0:
        # Previously this surfaced as an unrelated NameError on `outer`.
        raise ValueError('run_corr_matrix needs at least one batch (empty loader or epochs < 1)')

    # Feed the inputs to wherever net0's weights live instead of hard-coding
    # CUDA, so the function also works on CPU-only machines. Parameter-free
    # networks keep the old preference for CUDA when it is available.
    first_param = next(net0.parameters(), None)
    if first_param is not None:
        target_device = first_param.device
    else:
        target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    mean0 = mean1 = std0 = std1 = outer = None
    with torch.no_grad():
        net0.eval()
        net1.eval()
        for _epoch in range(epochs):
            for images, _ in tqdm(loader):
                img_t = images.float().to(target_device)

                # (N, C, ...) -> (N, positions, C) -> (N * positions, C)
                out0 = net0(img_t)
                out0 = out0.reshape(out0.shape[0], out0.shape[1], -1).permute(0, 2, 1)
                out0 = out0.reshape(-1, out0.shape[2]).double()

                out1 = net1(img_t)
                out1 = out1.reshape(out1.shape[0], out1.shape[1], -1).permute(0, 2, 1)
                out1 = out1.reshape(-1, out1.shape[2]).double()

                mean0_b = out0.mean(dim=0)
                mean1_b = out1.mean(dim=0)
                std0_b = out0.std(dim=0)
                std1_b = out1.std(dim=0)
                outer_b = (out0.T @ out1) / out0.shape[0]

                # Allocate the accumulators exactly once. Keying this on the
                # batch index would wipe the running sums at the start of
                # every epoch and bias the result whenever epochs > 1.
                if mean0 is None:
                    mean0 = torch.zeros_like(mean0_b)
                    mean1 = torch.zeros_like(mean1_b)
                    std0 = torch.zeros_like(std0_b)
                    std1 = torch.zeros_like(std1_b)
                    outer = torch.zeros_like(outer_b)
                mean0 += mean0_b / n
                mean1 += mean1_b / n
                std0 += std0_b / n
                std1 += std1_b / n
                outer += outer_b / n

    # E[x0 x1^T] - E[x0] E[x1]^T
    cov = outer - torch.outer(mean0, mean1)
    if norm:
        # The small constant guards against division by zero for dead channels.
        corr = cov / (torch.outer(std0, std1) + 1e-4)
        return corr
    else:
        return cov

def get_layer_perm1(corr_mtx):
    """Solve the channel assignment that maximises total correlation.

    Args:
        corr_mtx: 2-D tensor of pairwise scores; row ``i`` is matched to
            exactly one column.

    Returns:
        A 1-D ``torch.long`` tensor ``perm_map`` where ``perm_map[i]`` is the
        column assigned to row ``i``.

    Raises:
        AssertionError: if the rows returned by the solver are not
            ``0..len(corr_mtx)-1`` in order (e.g. more rows than columns).
    """
    # Import the submodule explicitly: a bare `import scipy` does not
    # guarantee that `scipy.optimize` is available as an attribute on every
    # SciPy version.
    from scipy.optimize import linear_sum_assignment

    corr_mtx_a = corr_mtx.cpu().numpy()
    row_ind, col_ind = linear_sum_assignment(corr_mtx_a, maximize=True)
    # Every row must have been matched, in order, for col_ind to be a usable map.
    assert (row_ind == np.arange(len(corr_mtx_a))).all()
    perm_map = torch.tensor(col_ind).long()
    return perm_map

# returns the channel-permutation to make layer1's activations most closely
# match layer0's.
def get_layer_perm(net0, net1):
    """Compute the output-channel permutation that aligns ``net1`` with ``net0``.

    Builds the cross-correlation matrix of the two networks' outputs with the
    default settings of ``run_corr_matrix`` and solves the resulting
    assignment problem with ``get_layer_perm1``.
    """
    return get_layer_perm1(run_corr_matrix(net0, net1))

# modifies the weight matrices of a convolution and batchnorm
# layer given a permutation of the output channels
def permute_output(perm_map, conv, bn):
    """Reorder the output channels of ``conv`` (and ``bn``) in place.

    Args:
        perm_map: index tensor; entry ``k`` names the old channel that
            becomes channel ``k``.
        conv: layer whose ``weight`` (and ``bias``, when present) is indexed
            along dimension 0.
        bn: matching batch-norm layer, or None. Its ``weight``, ``bias``,
            ``running_mean`` and ``running_var`` are permuted the same way.
    """
    if conv.bias is None:
        conv_tensors = [conv.weight]
    else:
        conv_tensors = [conv.weight, conv.bias]

    if bn is None:
        bn_tensors = []
    else:
        bn_tensors = [bn.weight, bn.bias, bn.running_mean, bn.running_var]

    # Swap the underlying storage so the parameter objects keep their identity.
    for tensor in conv_tensors + bn_tensors:
        tensor.data = tensor[perm_map]

# modifies the weight matrix of a layer for a given permutation of the input channels
# works for both conv2d and linear
def permute_input(perm_map, layer):
    """Reorder the input channels of ``layer`` in place.

    ``layer.weight`` is indexed along dimension 1 with ``perm_map``, so the
    layer reads its inputs in the permuted channel order.
    """
    layer.weight.data = layer.weight[:, perm_map]

def subnet(model, n_layers):
    """Return the first ``n_layers`` entries of ``model.features``."""
    feature_stack = model.features
    return feature_stack[:n_layers]

def activation_matching_v1(model_1, model_2):
    """Permute a copy of ``model_2`` so its conv activations line up with ``model_1``.

    Walks ``features`` front to back. For every ``nn.Conv2d`` it
      1. matches the post-ReLU activations of the two truncated networks,
      2. permutes the conv's (and a directly following BatchNorm's) outputs,
      3. permutes the inputs of the next conv, or of ``classifier[0]`` when
         no later conv exists.

    Args:
        model_1: reference model; only used to produce activations.
        model_2: model to align; it is deep-copied and left untouched.

    Returns:
        The permuted deep copy of ``model_2``.

    Raises:
        AssertionError: if a conv is not followed by ``ReLU`` or by
            ``BatchNorm2d`` then ``ReLU``.
    """
    reference = model_1
    aligned = deepcopy(model_2)

    layers = aligned.features
    n_layers = len(layers)

    for idx in range(n_layers):
        conv = layers[idx]
        if not isinstance(conv, nn.Conv2d):
            continue

        # Locate the optional batch-norm and the ReLU that closes this unit.
        follower = layers[idx + 1]
        if isinstance(follower, nn.BatchNorm2d):
            bn, relu_idx = follower, idx + 2
        else:
            bn, relu_idx = None, idx + 1
        assert isinstance(layers[relu_idx], nn.ReLU)

        # Match on the activations produced right after that ReLU.
        depth = relu_idx + 1
        perm_map = get_layer_perm(subnet(reference, depth), subnet(aligned, depth))
        permute_output(perm_map, conv, bn)

        # The next conv (or the first classifier layer) consumes these
        # channels, so its inputs must follow the same permutation.
        consumer = next(
            (layers[j] for j in range(idx + 1, n_layers) if isinstance(layers[j], nn.Conv2d)),
            None,
        )
        if consumer is None:
            consumer = aligned.classifier[0]
        permute_input(perm_map, consumer)

    return aligned

In [11]:
# Load the two endpoint checkpoints (state dicts) directly onto `device`.
# NOTE(review): depending on the installed torch version's `weights_only`
# default, torch.load may unpickle arbitrary objects -- only load trusted files.
sd_1 = torch.load(f'../../Linear_Mode_Connectivity/same_init_ex/{config.dataset}/{config.model}/diff_init/seed_20/model_1_160.pt', map_location=device)
sd_2 = torch.load(f'../../Linear_Mode_Connectivity/same_init_ex/{config.dataset}/{config.model}/diff_init/seed_20/model_2_160.pt', map_location=device)

# Build two fresh models from the config and fill them with the loaded weights.
model_1 = load_model(config).to(device)
model_2 = load_model(config).to(device)
model_1.load_state_dict(sd_1)
model_2.load_state_dict(sd_2)

<All keys matched successfully>

In [None]:
# Wall-clock timing of three activation-matching implementations, each
# aligning model_2 to model_1.

# 1) project implementation with its default `type`
start = time()
sd_2_am, _ = activation_matching(config.model, model_1, model_2, trainloader_noaug, print_freq=100, device=device)
am_time = time() - start

# 2) same entry point with type='git' (presumably the Git Re-Basin variant -- confirm)
start = time()
sd_2_git_am, _ = activation_matching(config.model, model_1, model_2, trainloader_noaug, print_freq=100, device=device, type='git')
git_am_time = time() - start

# 3) REPAIR-style implementation defined in this notebook; the measured time
#    also includes the state_dict() call
start = time()
model_2_repair_am = activation_matching_v1(model_1, model_2) 
sd_2_repair_am = model_2_repair_am.state_dict()
repair_am_time = time() - start

In [15]:
# Interpolate between model 1 and each aligned version of model 2.
# The 0.5 argument is presumably the interpolation weight, i.e. the midpoint
# -- confirm against interpolate_state_dicts.
sd_am_mid = interpolate_state_dicts(sd_1, sd_2_am, 0.5)
sd_git_am_mid = interpolate_state_dicts(sd_1, sd_2_git_am, 0.5)
sd_repair_am_mid = interpolate_state_dicts(sd_1, sd_2_repair_am, 0.5)

# One fresh model per interpolated state dict.
model_am_mid = load_model(config).to(device)
model_git_am_mid = load_model(config).to(device)
model_repair_am_mid = load_model(config).to(device)

model_am_mid.load_state_dict(sd_am_mid)
model_git_am_mid.load_state_dict(sd_git_am_mid)
model_repair_am_mid.load_state_dict(sd_repair_am_mid)

<All keys matched successfully>

In [None]:
# Evaluate both endpoints on the un-augmented training loader and the test loader.
# The first two values returned by validate are taken as (loss, accuracy);
# the remaining two are discarded.
train_loss_1, train_acc_1, _, _ = validate(trainloader_noaug, model_1, criterion, device, config)
test_loss_1, test_acc_1, _, _ = validate(testloader, model_1, criterion, device, config)
train_loss_2, train_acc_2, _, _ = validate(trainloader_noaug, model_2, criterion, device, config)
test_loss_2, test_acc_2, _, _ = validate(testloader, model_2, criterion, device, config)

# Same evaluation for the three interpolated models (ours / Git / REPAIR).
train_loss_am_mid, train_acc_am_mid, _, _ = validate(trainloader_noaug, model_am_mid, criterion, device, config)
test_loss_am_mid, test_acc_am_mid, _, _ = validate(testloader, model_am_mid, criterion, device, config)
train_loss_git_am_mid, train_acc_git_am_mid, _, _ = validate(trainloader_noaug, model_git_am_mid, criterion, device, config)
test_loss_git_am_mid, test_acc_git_am_mid, _, _ = validate(testloader, model_git_am_mid, criterion, device, config)
train_loss_repair_am_mid, train_acc_repair_am_mid, _, _ = validate(trainloader_noaug, model_repair_am_mid, criterion, device, config)
test_loss_repair_am_mid, test_acc_repair_am_mid, _, _ = validate(testloader, model_repair_am_mid, criterion, device, config)

# Pack each model's metrics as [train_loss, train_acc, test_loss, test_acc].
# NOTE(review): these tensors are not read by any later cell shown here.
end_1 = torch.tensor([train_loss_1, train_acc_1, test_loss_1, test_acc_1])
end_2 = torch.tensor([train_loss_2, train_acc_2, test_loss_2, test_acc_2])
end_am_mid = torch.tensor([train_loss_am_mid, train_acc_am_mid, test_loss_am_mid, test_acc_am_mid])
end_git_am_mid = torch.tensor([train_loss_git_am_mid, train_acc_git_am_mid, test_loss_git_am_mid, test_acc_git_am_mid])
end_repair_am_mid = torch.tensor([train_loss_repair_am_mid, train_acc_repair_am_mid, test_loss_repair_am_mid, test_acc_repair_am_mid])

In [17]:
# Each barrier tensor is laid out as [train_loss, train_acc, test_loss, test_acc].

# ours
barrier_am = torch.tensor([
    cal_barrier(train_loss_1, train_loss_2, train_loss_am_mid, type='loss'),
    cal_barrier(train_acc_1, train_acc_2, train_acc_am_mid, type='acc'),
    cal_barrier(test_loss_1, test_loss_2, test_loss_am_mid, type='loss'),
    cal_barrier(test_acc_1, test_acc_2, test_acc_am_mid, type='acc'),
])

# Git Re-Basin variant
barrier_git_am = torch.tensor([
    cal_barrier(train_loss_1, train_loss_2, train_loss_git_am_mid, type='loss'),
    cal_barrier(train_acc_1, train_acc_2, train_acc_git_am_mid, type='acc'),
    cal_barrier(test_loss_1, test_loss_2, test_loss_git_am_mid, type='loss'),
    cal_barrier(test_acc_1, test_acc_2, test_acc_git_am_mid, type='acc'),
])

# REPAIR variant
barrier_repair_am = torch.tensor([
    cal_barrier(train_loss_1, train_loss_2, train_loss_repair_am_mid, type='loss'),
    cal_barrier(train_acc_1, train_acc_2, train_acc_repair_am_mid, type='acc'),
    cal_barrier(test_loss_1, test_loss_2, test_loss_repair_am_mid, type='loss'),
    cal_barrier(test_acc_1, test_acc_2, test_acc_repair_am_mid, type='acc'),
])

In [22]:
# Report wall-clock matching time for each implementation.
print(f'Time (ours): {am_time:.2f}s')
print(f'Time (from Git): {git_am_time:.2f}s')
print(f'Time (from REPAIR): {repair_am_time:.2f}s')
# The last entry of each barrier tensor is the test-accuracy barrier.
print('Test Accuracy:')
print(f'Barrier (ours): {barrier_am[-1]:.2f}%')
print(f'Barrier (from Git): {barrier_git_am[-1]:.2f}%')  
print(f'Barrier (from REPAIR): {barrier_repair_am[-1]:.2f}%')

Time (ours): 28.48s
Time (from Git): 14.34s
Time (from REPAIR): 81.37s
Test Accuracy:
Barrier (ours): 42.22%
Barrier (from Git): 67.47%
Barrier (from REPAIR): 42.28%
