In [4]:
import torch
from Models import modelpool
from Preprocess import datapool
from torch import nn
from spiking_layer_ours import SPIKE_layer
import numpy as np
from copy import deepcopy
from utils_my import add_dimension
from utils_my import replace_maxpool2d_by_avgpool2d, replace_layer_by_tdlayer
from slerp import slerp
from collections import OrderedDict
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

def isActivation(name):
    if 'relu' in name.lower():
        return True
    return False

def replace_activation_by_spike(model, thresholds, thresholds1, n_steps, counter=0):
    thresholds_new = deepcopy(thresholds)
    thresholds_new1 = deepcopy(thresholds1)
    
    for name, module in model._modules.items():
        if hasattr(module,"_modules"):
            model._modules[name], counter, thresholds_new = replace_activation_by_spike(module, thresholds_new, thresholds_new1, n_steps, counter)
        if isActivation(module.__class__.__name__.lower()):
            thresholds_new[counter, n_steps:] = thresholds_new1[counter, 1] / n_steps  # thresholds_out_sum/n_steps# thresholds1[counter,1] / n_steps
            thresholds_new[counter, :n_steps] = thresholds_new1[counter, 0] / n_steps  # thresholds_inner_sum/n_steps#thresholds1[counter,0] / n_steps
            model._modules[name] = SPIKE_layer(thresholds_new[counter, n_steps:], thresholds_new[counter, 0:n_steps])
            counter += 1
    return model, counter, thresholds_new


def interpolate_state_dicts(state_dict_1, state_dict_2, weight,
                            bias_norm=False, use_slerp=False):
    if use_slerp:
        model_state = deepcopy(state_dict_1)
        for p_name in model_state:
            if "batches" not in p_name:
                model_state[p_name] = slerp(weight, state_dict_1[p_name], state_dict_2[p_name])
        return model_state
    elif bias_norm:
        model_state = deepcopy(state_dict_1)
        height = 0
        for p_name in model_state:
            if "batches" not in p_name:
                model_state[p_name].zero_()
                if "weight" in p_name:
                    model_state[p_name].add_(1.0 - weight, state_dict_1[p_name])
                    model_state[p_name].add_(weight, state_dict_2[p_name])
                    height += 1
                if "bias" in p_name:
                    model_state[p_name].add_((1.0 - weight)**height, state_dict_1[p_name])
                    model_state[p_name].add_(weight**height, state_dict_2[p_name])
                if "res_scale" in p_name:
                    model_state[p_name].add_(1.0 - weight, state_dict_1[p_name])
                    model_state[p_name].add_(weight, state_dict_2[p_name])
        return model_state
    else:
        return {key: (1 - weight) * state_dict_1[key] + 
                weight * state_dict_2[key] for key in state_dict_1.keys()}


def interpolate_multi_state_dicts(sd_s, weight_s, use_slerp=False):
    if use_slerp:
        sd_interpolated = deepcopy(sd_s[0])
        weight_interpolated = weight_s[0]
        for i in range(1, len(sd_s)):
            sd_next = sd_s[i]
            weight_next = weight_s[i]
            t = weight_next / (weight_interpolated + weight_next)
            sd_interpolated = interpolate_state_dicts(sd_interpolated, sd_next, t, use_slerp=use_slerp)
            weight_interpolated += weight_next
        return sd_interpolated
    else:
        sd_interpolated = deepcopy(sd_s[0])
        for key in sd_s[0].keys():
            sd_interpolated[key] = weight_s[0] * sd_s[0][key]
            for i in range(1, len(sd_s)):
                sd_interpolated[key] += weight_s[i] * sd_s[i][key]
        return sd_interpolated
        
        
def validate_snn(model, loader, n_steps, device, verbose=1, use_double=False):
    device_old = get_device(model)
    model.to(device)
    model.eval()
    total = 0
    correct = 0
    for data, target in loader:
        data = add_dimension(data, n_steps)
        data, target = data.to(device), target.to(device)
        if use_double:
            data = data.double()
        output = model(data, L=0, t=n_steps)
        output = torch.mean(output, dim=1)
        total += target.size(0)
        correct += (output.argmax(1) == target).sum().item()
    if verbose != 0:
        print('Accuracy of the network on the test images: %f' % (100 * correct / total))
    acc = 100 * correct / total
    model.to(device_old)
    return acc

def validate_snn_ensemble(models, loader, n_steps, device, verbose=1):
    device_old = get_device(models[0])
    for model in models:
        model.to(device)
        model.eval()
    total = 0
    correct = 0
    for data, target in loader:
        data = add_dimension(data, n_steps)
        data, target = data.to(device), target.to(device)
        output = torch.zeros(data.size(0), 10).to(device)
        for model in models:
            output += torch.mean(model(data, L=0, t=n_steps), dim=1)
        total += target.size(0)
        correct += (output.argmax(1) == target).sum().item()
    if verbose != 0:
        print('Accuracy of the network on the test images: %f' % (100 * correct / total))
    acc = 100 * correct / total
    for model in models:
        model.to(device_old)
    return acc
    
def validate_ann(model, loader, device, verbose=1, use_double=False):
    device_old = get_device(model)
    model.to(device)
    model.eval()
    total = 0
    correct = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        if use_double:
            data = data.double()
        output = model(data)
        total += target.size(0)
        correct += (output.argmax(1) == target).sum().item()
    if verbose != 0:
        print('Accuracy of the network on the test images: %f' % (100 * correct / total))
    acc = 100 * correct / total
    model.to(device_old)
    return acc

def ann_to_snn(model, thresholds, thresholds1, n_steps):
    model, counter, thresholds_new = replace_activation_by_spike(model, thresholds, thresholds1, n_steps)
    model = replace_maxpool2d_by_avgpool2d(model)
    model = replace_layer_by_tdlayer(model)
    return model, thresholds_new


def bn_calibration_init(m):
    """ calculating post-statistics of batch normalization """
    if getattr(m, 'track_running_stats', False):
        # reset all values for post-statistics
        m.reset_running_stats()
        # set bn in training mode to update post-statistics
        m.training = True
        # use cumulative moving average
        m.momentum = None

def reset_bn_stats(model, device, bn_loader, n_steps, layerwise=False, ann=False, use_double=False):
    # Reset batch norm statistics
    model.to(device)
    # get batchnorm layer names
    bn_layer_names = [name for name, layer in model.named_modules() if isinstance(layer, (nn.BatchNorm2d))]
    for m in model.modules():
        bn_calibration_init(m)
    model.train()
    with torch.no_grad():
        L_range = range(1, 14) if layerwise else [0]
        for L in L_range:
            for data, _ in bn_loader:
                if not ann:
                    data = add_dimension(data, n_steps)
                data = data.to(device)
                if use_double:
                    data = data.double()
                model(data, L=L, t=n_steps)
            # set current bn layer in eval mode
            if layerwise:
                layer = dict(model.named_modules())[bn_layer_names[L-1]]
                layer.eval()
    model.eval()
    model.to('cpu')

def get_device(item):
    if isinstance(item, nn.Module):
        return next(item.parameters()).device
    elif isinstance(item, OrderedDict):
        return next(iter(item.values())).device

In [5]:
class args:
    model = 'vgg16'
    dataset = 'cifar10'
    batch_size = 128
    t = 1
device = torch.device('cuda:1')
criterion = nn.CrossEntropyLoss()
n_steps = args.t
train_loader, test_loader = datapool(args.dataset, args.batch_size, 0, shuffle=True)

Files already downloaded and verified


Files already downloaded and verified


In [6]:
# v_idx_s = list(range(2, 16)) + list(range(22, 37))
v_idx_s = list(range(2, 13))
# v_idx_s = [2]
num_models = len(v_idx_s) + 1

model_s = [modelpool(args.model, args.dataset) for _ in range(num_models)]
ckpt_model_idx = 3
sd = torch.load(f'saved_models/cifar10_vgg16_{ckpt_model_idx}.pth', map_location='cpu', weights_only=True)
for model in model_s:
    model.load_state_dict(sd)

num_relu = str(model_s[0]).count('ReLU')
# prefix = 'cifar10_vgg16_3_lr_1e3/'
prefix = ''
thresholds = torch.zeros(num_relu, 2*n_steps)
thresholds1 = torch.Tensor(np.load(f'{prefix}cifar10_vgg16_{ckpt_model_idx}_threshold_all_noaug{n_steps}.npy'))

In [7]:
model_s = [ann_to_snn(model, thresholds, thresholds1, n_steps) for model in model_s]
threshold_new = model_s[0][1]
model_s = [pair[0] for pair in model_s]

snn_init = modelpool(args.model, args.dataset)
ann_sd = torch.load(f'saved_models/cifar10_vgg16_{ckpt_model_idx}.pth', map_location='cpu', weights_only=True)
snn_init.load_state_dict(ann_sd)
snn_init = ann_to_snn(snn_init, thresholds, thresholds1, n_steps)[0]

ann_model = modelpool(args.model, args.dataset)
ann_sd = torch.load(f'saved_models/cifar10_vgg16_{ckpt_model_idx}.pth', map_location='cpu', weights_only=True)
ann_model.load_state_dict(ann_sd)

<All keys matched successfully>

In [8]:
# model_s = [ann_to_snn(model, thresholds, thresholds1, n_steps) for model in model_s]
sd_s = [torch.load(f'{prefix}cifar10_vgg16_{ckpt_model_idx}_updated_snn1_1.pth', weights_only=True, map_location='cpu')] + \
        [torch.load(f'{prefix}cifar10_vgg16_{ckpt_model_idx}_updated_snn1_1_v{v}.pth', weights_only=True, map_location='cpu') for v in v_idx_s]
for model, sd in zip(model_s, sd_s):
    model.load_state_dict(sd)

In [None]:
acc_s = [validate_snn(model, test_loader, n_steps, device, use_double=False) for model in model_s]
acc_ensemble = validate_snn_ensemble(model_s, test_loader, n_steps, device)
print(acc_s, acc_ensemble)

Accuracy of the network on the test images: 92.370000
Accuracy of the network on the test images: 92.430000
Accuracy of the network on the test images: 92.370000
Accuracy of the network on the test images: 92.450000
Accuracy of the network on the test images: 92.300000
Accuracy of the network on the test images: 92.420000
Accuracy of the network on the test images: 92.720000
Accuracy of the network on the test images: 92.570000
Accuracy of the network on the test images: 92.370000
Accuracy of the network on the test images: 92.570000
Accuracy of the network on the test images: 92.380000
Accuracy of the network on the test images: 92.690000
Accuracy of the network on the test images: 94.320000
[92.37, 92.43, 92.37, 92.45, 92.3, 92.42, 92.72, 92.57, 92.37, 92.57, 92.38, 92.69] 94.32


#### Merging

##### Test PFM

In [9]:
def PFM_outputs(model_merge_s, model_merged, data, L, n_steps):
    outputs_s = [model(data, L=L, t=n_steps) for model in model_merge_s]
    outputs_avg = torch.zeros_like(outputs_s[0])
    for outputs in outputs_s:
        outputs_avg += outputs
    outputs_avg /= len(outputs_s)
    outputs = model_merged(outputs_avg, L=0, t=n_steps, prev_L=L).mean(1)
    return outputs

In [10]:
model_merge_s = [deepcopy(model_s[i]) for i in [-3, 1, 3]]
num_merge_models = len(model_merge_s)
model_merge_mid = deepcopy(model_s[0])
sd_merge_mid = interpolate_multi_state_dicts([model.state_dict() for model in model_merge_s], [1/num_merge_models]*num_merge_models)
model_merge_mid.load_state_dict(sd_merge_mid)

for model_merge in model_merge_s:
    model_merge.to(device)
    model_merge.eval()
model_merge_mid.to(device)
model_merge_mid.eval()

validate_snn(model_merge_mid, test_loader, n_steps, device)
validate_snn_ensemble(model_merge_s, test_loader, n_steps, device)

Accuracy of the network on the test images: 92.380000
Accuracy of the network on the test images: 93.970000


93.97

In [54]:
cur_test_loader = test_loader
for L in range(1, 17):
    total = 0
    correct = 0
    for data, target in cur_test_loader:
        data = add_dimension(data, n_steps)
        data, target = data.to(device), target.to(device)
        outputs = PFM_outputs(model_merge_s, model_merge_mid, data, L, n_steps)
        total += target.size(0)
        correct += (outputs.argmax(1) == target).sum().item()

    print('Accuracy of the network on the test images: %f' % (100 * correct / total), f", L={L}")
    acc = 100 * correct / total

Accuracy of the network on the test images: 92.590000 , L=1
Accuracy of the network on the test images: 92.990000 , L=2
Accuracy of the network on the test images: 92.850000 , L=3
Accuracy of the network on the test images: 93.030000 , L=4
Accuracy of the network on the test images: 93.220000 , L=5
Accuracy of the network on the test images: 93.590000 , L=6
Accuracy of the network on the test images: 93.750000 , L=7
Accuracy of the network on the test images: 94.200000 , L=8
Accuracy of the network on the test images: 93.930000 , L=9
Accuracy of the network on the test images: 94.090000 , L=10
Accuracy of the network on the test images: 93.980000 , L=11
Accuracy of the network on the test images: 93.990000 , L=12
Accuracy of the network on the test images: 94.030000 , L=13
Accuracy of the network on the test images: 94.030000 , L=14
Accuracy of the network on the test images: 93.960000 , L=15
Accuracy of the network on the test images: 93.970000 , L=16


In [55]:
L = 3
total = 0
correct = 0
for data, target in test_loader:
    data = add_dimension(data, n_steps)
    data, target = data.to(device), target.to(device)
    outputs = PFM_outputs(model_merge_s, model_merge_mid, data, L, n_steps)
    total += target.size(0)
    correct += (outputs.argmax(1) == target).sum().item()
print('Accuracy of the network on the test images: %f' % (100 * correct / total), f", L={L}")
acc = 100 * correct / total

Accuracy of the network on the test images: 92.850000 , L=3


##### Try to improve merging on L = 1

In [58]:
print(acc_s)

[92.37, 92.43, 92.37, 92.45, 92.3, 92.42, 92.72, 92.57, 92.37, 92.57, 92.38, 92.69]


In [141]:
model_merge_s = [deepcopy(model_s[i]) for i in [0, 1]]
num_merge_models = len(model_merge_s)
model_merge_mid = deepcopy(model_s[0])
sd_merge_mid = interpolate_multi_state_dicts([model.state_dict() for model in model_merge_s], [1/num_merge_models]*num_merge_models)
model_merge_mid.load_state_dict(sd_merge_mid)

for model_merge in model_merge_s:
    model_merge.to(device)
    model_merge.eval()
model_merge_mid.to(device)
model_merge_mid.eval()

validate_snn(model_merge_mid, test_loader, n_steps, device)
validate_snn_ensemble(model_merge_s, test_loader, n_steps, device)

Accuracy of the network on the test images: 92.340000
Accuracy of the network on the test images: 93.380000


93.38

one batch

In [134]:
data, _ = next(iter(test_loader))
data = add_dimension(data, n_steps)
data = data.to(device)

pre activation difference

In [142]:
L = 1
outputs_s = [model.layer1[:2](data).mean(1) for model in model_merge_s]
outputs_avg = torch.zeros_like(outputs_s[0])
with torch.no_grad():
    for outputs in outputs_s:
        outputs_avg += outputs
    outputs_avg /= len(outputs_s)
    outputs_mid = model_merge_mid.layer1[:2](data).mean(1)

In [148]:
(outputs_mid[0] - outputs_avg[0]).abs().max()

tensor(0.0757, device='cuda:1')

In [159]:
threshold_new

tensor([[0.2019, 0.1009],
        [0.1956, 0.0978],
        [0.1320, 0.0660],
        [0.1984, 0.0992],
        [0.1295, 0.0647],
        [0.1269, 0.0634],
        [0.1491, 0.0745],
        [0.0798, 0.0399],
        [0.0590, 0.0295],
        [0.0704, 0.0352],
        [0.0393, 0.0197],
        [0.0481, 0.0241],
        [0.2805, 0.1402],
        [0.0458, 0.0229],
        [0.0836, 0.0418]])

In [151]:
diff = (outputs_mid[0] - outputs_avg[0]).abs().flatten()
# sort by descending order
diff, indices = diff.sort(descending=True)
diff

tensor([0.0757, 0.0724, 0.0698,  ..., 0.0000, 0.0000, 0.0000], device='cuda:1')

In [154]:
diff.shape

torch.Size([65536])

NameError: name 'threshold' is not defined

In [157]:
diff[1000]

tensor(0.0144, device='cuda:1')

In [160]:
L = 1
outputs_s = [model(data, L=L, t=n_steps).mean(1) for model in model_merge_s]
outputs_avg = torch.zeros_like(outputs_s[0])
with torch.no_grad():
    for outputs in outputs_s:
        outputs_avg += outputs
    outputs_avg /= len(outputs_s)
    outputs_mid = model_merge_mid(data, L=L, t=n_steps).mean(1)
nonzero_mask = (outputs_avg != 0) | (outputs_mid != 0)

In [59]:
diff = (outputs_avg[nonzero_mask] - outputs_mid[nonzero_mask]).flatten()

In [None]:
diff = diff[diff!=0]
# plot the cdf of the difference
plt.hist(diff.cpu().numpy(), bins=100, cumulative=True, density=True, histtype='step')

In [162]:
abs_diff_s = []
max_diff_s = []
scale_avg_s = []
scale_mid_s = []
mse_loss = nn.MSELoss()

for L in range(1, 17):
    # L = 1
    total = 0
    nonzero_avg = 0
    nonzero_mid = 0
    abs_diff = 0
    max_diff = 0
    scale_avg = 0
    scale_mid = 0
    with torch.no_grad():
        for data, target in train_loader:
            data = add_dimension(data, n_steps)
            data, target = data.to(device), target.to(device)
            outputs_avg = None
            for model_merge in model_merge_s:
                outputs = model_merge(data, L=L, t=n_steps).mean(1)
                if outputs_avg is None:
                    outputs_avg = outputs
                else:
                    outputs_avg += outputs
            outputs_avg /= len(outputs_s)
            outputs_mid = model_merge_mid(data, L=L, t=n_steps).mean(1)
            # post_act_loss += criterion(outputs_mid, outputs_avg) * target.size(0)
            neq_num = (outputs_mid != outputs_avg).sum()
            abs_diff += (outputs_mid - outputs_avg).abs().sum() / neq_num
            batch_max_diff = (outputs_mid - outputs_avg).abs().max()
            # calculate average scale of nonzero outputs
            scale_avg += outputs_avg[outputs_avg != 0].abs().sum()
            scale_mid += outputs_mid[outputs_mid != 0].abs().sum()
            max_diff = max(max_diff, batch_max_diff)
            total += target.size(0)
            nonzero_avg += (outputs_avg != 0).sum()
            nonzero_mid += (outputs_mid != 0).sum()
        abs_diff = abs_diff / total
        scale_avg = scale_avg / nonzero_avg
        scale_mid = scale_mid / nonzero_mid
        abs_diff_s.append(abs_diff)
        max_diff_s.append(max_diff)
        scale_avg_s.append(scale_avg)
        scale_mid_s.append(scale_mid)
        print(f"abs_diff: {abs_diff}, max_diff: {max_diff}, scale_avg ({nonzero_avg}): {scale_avg}, scale_mid ({nonzero_mid}): {scale_mid}, L={L}")

abs_diff: 0.0007899306365288794, max_diff: 0.20188574492931366, scale_avg (754814115): 0.1837773323059082, scale_mid (670974811): 0.20188568532466888, L=1
abs_diff: 0.0008815779001452029, max_diff: 0.19557975232601166, scale_avg (592921508): 0.1721639782190323, scale_mid (523387731): 0.19557981193065643, L=2
abs_diff: 0.0005983338924124837, max_diff: 0.13197259604930878, scale_avg (443334791): 0.1090385913848877, scale_mid (359660086): 0.1319725662469864, L=3
abs_diff: 0.0009561455808579922, max_diff: 0.19841402769088745, scale_avg (162789967): 0.15083763003349304, scale_mid (123600573): 0.19841410219669342, L=4
abs_diff: 0.0006224726093932986, max_diff: 0.12946902215480804, scale_avg (128215390): 0.09683099389076233, scale_mid (94190887): 0.12946906685829163, L=5
abs_diff: 0.0006244082469493151, max_diff: 0.12687665224075317, scale_avg (95757875): 0.09000459313392639, scale_mid (67899902): 0.12687665224075317, L=6
abs_diff: 0.0007428361568599939, max_diff: 0.14906498789787292, scale_a

Distill the pre-activation in the first layer

In [86]:
train_model_merge_mid = deepcopy(model_merge_mid)
train_model_merge_mid.to(device)
validate_snn(train_model_merge_mid, test_loader, n_steps, device)

Accuracy of the network on the test images: 92.340000


92.34

In [87]:
optimizer = torch.optim.SGD(train_model_merge_mid.parameters(), lr=4e-5, weight_decay=5e-4)
# only train the first two layers
train_model_merge_mid.requires_grad_(False)
train_model_merge_mid.layer1[0].requires_grad_(True)
train_model_merge_mid.layer1[1].requires_grad_(True)
epoch_loss_s = []

In [91]:
for _ in range(10):

    epoch_loss = 0
    total = 0
    for data, _ in train_loader:
        data = add_dimension(data, n_steps)
        data = data.to(device)
        with torch.no_grad():
            outputs_s = [model.layer1[:2](data).mean(1) for model in model_merge_s]
            outputs_avg = torch.zeros_like(outputs_s[0])
            for outputs in outputs_s:
                outputs_avg += outputs
            outputs_avg /= len(outputs_s)
            
        outputs_mid = train_model_merge_mid.layer1[:2](data).mean(1)
        loss = (outputs_mid - outputs_avg).abs().mean()
        # print(loss.item())
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() * data.size(0)
        total += data.size(0)
    print(epoch_loss / total)
    epoch_loss_s.append(epoch_loss / total)
validate_snn(train_model_merge_mid, test_loader, n_steps, device)

0.0007875407226569951
0.0007164565998129547
0.0007435208694636822
0.0007602160020358861
0.0006976782626286149
0.0007211794837377965
0.0007060918065719306
0.0007507258039340377
0.0006767478171177208
0.000726337816286832
Accuracy of the network on the test images: 92.260000


92.26

activation matching

In [106]:
model_merge_1 = deepcopy(model_merge_s[0])
model_merge_1.to(device)
model_merge_2 = deepcopy(model_merge_s[1])
model_merge_2.to(device)

train_model_merge_mid = deepcopy(model_merge_mid)
train_model_merge_mid.to(device)
validate_snn(train_model_merge_mid, test_loader, n_steps, device)
validate_snn(model_merge_1, test_loader, n_steps, device)
validate_snn(model_merge_2, test_loader, n_steps, device)

Accuracy of the network on the test images: 92.340000
Accuracy of the network on the test images: 92.370000
Accuracy of the network on the test images: 92.430000


92.43

In [None]:
from einops import rearrange
from scipy.optimize import linear_sum_assignment

class OnlineMean:
    def __init__(self, num_features, device=None):
        self.sum = torch.zeros(num_features).to(device)
        self.count = 0

    def update(self, batch):
        # batch shape: (batch_size, channels, width, height)
        # or
        # batch shape: (batch_size, num_features)
        if len(batch.shape) == 4:
            self.sum += torch.sum(batch, dim=(0, 2, 3))
        elif len(batch.shape) == 2:
            self.sum += torch.sum(batch, dim=0)
        else:
            raise ValueError("batch shape must be (batch_size, channels, "
                             "width, height) or (batch_size, num_features)")
        self.count += batch.shape[0]

    def mean(self):
        return self.sum / self.count


class OnlineCovariance:
    def __init__(self, a_mean, b_mean, count, device=None):
        assert a_mean.shape == b_mean.shape
        d = a_mean.shape[0]
        # ensure brodcast calculation
        self.a_mean = torch.zeros_like(a_mean).double()
        self.b_mean = torch.zeros_like(b_mean).double()
        self.cov = torch.zeros((d, d)).to(device).double()
        self.std_a = torch.zeros_like(a_mean).double()
        self.std_b = torch.zeros_like(b_mean).double()
        self.count = count

    def update(self, a_batch, b_batch):
        assert a_batch.shape == b_batch.shape
        a_batch = a_batch.double()
        b_batch = b_batch.double()
        self.a_mean += a_batch.mean(dim=1) / self.count
        self.b_mean += b_batch.mean(dim=1) / self.count
        self.cov += torch.matmul(a_batch, b_batch.T) / a_batch.shape[1] / self.count
        self.std_a += a_batch.std(dim=1) / self.count
        self.std_b += b_batch.std(dim=1) / self.count

    def pearson_correlation(self):
        eps = 1e-4

        self.cov = self.cov - torch.outer(self.a_mean, self.b_mean)
        return torch.nan_to_num(self.cov / (torch.outer(self.std_a, self.std_b) + eps))

In [None]:
oldL_s = []
newL_s = []
for L in range(1, 17):
    model_merge_1.eval()
    model_merge_2.eval()
    dummy_input = train_loader.dataset[0][0].unsqueeze(0).to(device)
    with torch.no_grad():
        dummy_input = add_dimension(dummy_input, n_steps)
        dummy_out = model_merge_1(dummy_input, L=L, t=n_steps).mean(1)
    num_channels = dummy_out.shape[1]
    means_merge_1 = OnlineMean(num_channels, device)
    means_merge_2 = OnlineMean(num_channels, device)
    # run one epoch
    num_batches = len(train_loader)
    covs = OnlineCovariance(means_merge_1.mean(), means_merge_2.mean(), num_batches, device)
    for data, _ in train_loader:
        with torch.no_grad():
            data = data.to(device)
            data = add_dimension(data, n_steps)
            # forward
            acts_1 = model_merge_1(data, L=L, t=n_steps).mean(1)
            acts_2 = model_merge_2(data, L=L, t=n_steps).mean(1)
            # flatten activations: 'b c w h -> c (b h w)'
            # or 'b c -> c b'
            if len(acts_1.shape) == 4:
                acts_1 = rearrange(acts_1, 'b c w h -> c (b h w)')
                acts_2 = rearrange(acts_2, 'b c w h -> c (b h w)')
            elif len(acts_1.shape) == 2:
                acts_1 = rearrange(acts_1, 'b c -> c b')
                acts_2 = rearrange(acts_2, 'b c -> c b')
            covs.update(acts_1, acts_2)

    # matching
    perm_values = []
    # calculate permutation
    correlation = covs.pearson_correlation()
    ri, ci = linear_sum_assignment(correlation.cpu().detach().numpy(),
                                    maximize=True)
    ci = torch.from_numpy(ci).to(device)
    perm_values.append(ci)
    oldL = torch.einsum('ij,ij->i', correlation,
                        torch.eye(len(ci), device=device, dtype=torch.double)).sum()
    newL = torch.einsum('ij,ij->i', correlation,
                        torch.eye(len(ci), device=device, dtype=torch.double)[ci, :]).sum()
    print(f"L = {L}")
    print(f"oldL: {oldL}")
    print(f"newL: {newL}")
    print(f"newL - oldL: {newL - oldL}")
    oldL_s.append(oldL)
    newL_s.append(newL)

In [127]:
perm_values = []

for L in range(14, 16):
    model_merge_1.eval()
    model_merge_2.eval()
    dummy_input = train_loader.dataset[0][0].unsqueeze(0).to(device)
    with torch.no_grad():
        dummy_input = add_dimension(dummy_input, n_steps)
        dummy_out = model_merge_1(dummy_input, L=L, t=n_steps).mean(1)
    num_channels = dummy_out.shape[1]
    means_merge_1 = OnlineMean(num_channels, device)
    means_merge_2 = OnlineMean(num_channels, device)
    # run one epoch
    num_batches = len(train_loader)
    covs = OnlineCovariance(means_merge_1.mean(), means_merge_2.mean(), num_batches, device)
    for data, _ in train_loader:
        with torch.no_grad():
            data = data.to(device)
            data = add_dimension(data, n_steps)
            # forward
            acts_1 = model_merge_1(data, L=L, t=n_steps).mean(1)
            acts_2 = model_merge_2(data, L=L, t=n_steps).mean(1)
            # flatten activations: 'b c w h -> c (b h w)'
            # or 'b c -> c b'
            if len(acts_1.shape) == 4:
                acts_1 = rearrange(acts_1, 'b c w h -> c (b h w)')
                acts_2 = rearrange(acts_2, 'b c w h -> c (b h w)')
            elif len(acts_1.shape) == 2:
                acts_1 = rearrange(acts_1, 'b c -> c b')
                acts_2 = rearrange(acts_2, 'b c -> c b')
            covs.update(acts_1, acts_2)

    # matching
    # calculate permutation
    correlation = covs.pearson_correlation()
    ri, ci = linear_sum_assignment(correlation.cpu().detach().numpy(),
                                    maximize=True)
    ci = torch.from_numpy(ci).to(device)
    perm_values.append(ci)
    oldL = torch.einsum('ij,ij->i', correlation,
                        torch.eye(len(ci), device=device, dtype=torch.double)).sum()
    newL = torch.einsum('ij,ij->i', correlation,
                        torch.eye(len(ci), device=device, dtype=torch.double)[ci, :]).sum()
    print(f"L = {L}")
    print(f"oldL: {oldL}")
    print(f"newL: {newL}")
    print(f"newL - oldL: {newL - oldL}")

L = 14
oldL: 2312.218413115928
newL: 2327.8608753852195
newL - oldL: 15.642462269291627
L = 15
oldL: 2082.692316341654
newL: 2087.361372362972
newL - oldL: 4.669056021317829


In [128]:
model_merge_2_am = deepcopy(model_merge_2)

ci = perm_values[0]
model_merge_2_am.classifier[1].layer.module.weight.data = model_merge_2_am.classifier[1].layer.module.weight.data[ci]
model_merge_2_am.classifier[1].layer.module.bias.data = model_merge_2_am.classifier[1].layer.module.bias.data[ci]
model_merge_2_am.classifier[4].layer.module.weight.data = torch.index_select(model_merge_2_am.classifier[4].layer.module.weight.data, 1, ci)

ci = perm_values[0]
model_merge_2_am.classifier[4].layer.module.weight.data = model_merge_2_am.classifier[4].layer.module.weight.data[ci]
model_merge_2_am.classifier[4].layer.module.bias.data = model_merge_2_am.classifier[4].layer.module.bias.data[ci]
model_merge_2_am.classifier[7].layer.module.weight.data = torch.index_select(model_merge_2_am.classifier[7].layer.module.weight.data, 1, ci)

sd_merge_mid_am = interpolate_state_dicts(model_merge_1.state_dict(), model_merge_2_am.state_dict(), 0.5)
model_merge_mid.load_state_dict(sd_merge_mid_am)

<All keys matched successfully>

In [132]:
validate_snn(model_merge_mid, test_loader, n_steps, device)
reset_bn_stats(model_merge_mid, device, train_loader, n_steps, layerwise=False, ann=False, use_double=False)
validate_snn(model_merge_mid, test_loader, n_steps, device)

Accuracy of the network on the test images: 92.240000
Accuracy of the network on the test images: 91.990000


91.99