In [58]:

import os
import torch
import random
from lmc_utils import BatchScale1d, BatchScale2d
from copy import deepcopy
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt
from utils import get_config_from_name, get_device, prepare_experiment_config, get_merging_fn
from model_merger import ModelMerge
from torch import nn

def validate(model, testloader, criterion, device, half=False, num_iters=None, print_freq=None):
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0

    with torch.no_grad():
        it = 0
        for data in tqdm(testloader):
            if num_iters is not None and it >= num_iters:
                break
            images, labels = data
            images = images.to(device).float()
            labels = labels.to(device).long()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss_sum += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            it += 1
            if print_freq is not None and it % print_freq == 0:
                print('Accuracy so far: {}%'.format(100 * correct / total))
            
            # it += 1
            # if it % 10:
            #     print('Accuracy so far: {}'.format(100 * correct / total))
        
    print('Accuracy of the network on the 10000 test images: {}%'.format(100 * correct / total))
    return loss_sum / total, correct / total


def imagenet_reset_bn_stats(model, loader, reset=True, num_iters=None):
    """Reset batch norm stats if nn.BatchNorm2d present in the model."""
    device = get_device(model)
    has_bn = False
    # resetting stats to baseline first as below is necessary for stability
    for m in model.modules():
        if type(m) in (nn.BatchNorm2d, BatchScale2d, BatchScale1d, nn.BatchNorm1d):
            if reset:
                m.momentum = None # use simple average
                m.reset_running_stats()
            has_bn = True

    if not has_bn:
        return model

    # run a single train epoch with augmentations to recalc stats
    model.train()
    iter = 0
    with torch.no_grad():
        for images, _ in tqdm(loader, desc='Resetting batch norm'):
            if images.shape[0] == 1:
                break
            if num_iters is not None and iter >= num_iters:
                break
            if iter == len(loader): # hack for fractional loader
                break
            images = images.to(device).float()
            _ = model(images)
            iter += 1
    model.eval()
    return model

In [None]:
config_name = 'imagenet_vgg16'

device = 'cuda:1'
raw_config = get_config_from_name(config_name, device=device)  # 返回config文件中的字典，添加了一个device键值对
cur_config = deepcopy(raw_config)
config = prepare_experiment_config(cur_config)

train_loader = config['data']['train']['full']
test_loader = config['data']['test']['full']
test_loader.num_workers = 0


base_models = [base_model for base_model in config['models']['bases']]
Grapher = config['graph']
merging_fn = 'match_tensors_permute'
# merging_fn = 'match_tensors_zipit'

criterion = torch.nn.CrossEntropyLoss()

Preparing Models




In [3]:
# get prefix nodes
from graphs.base_graph import NodeType

graph = Grapher(deepcopy(base_models[0])).graphify().G
prefix_nodes = []
for node in graph.nodes:
    info = graph.nodes[node]
    if info['type'] == NodeType.PREFIX:
        prefix_nodes.append(node)
prefix_nodes

# prefix_nodes = [3, 6, 10, 13, 17, 20, 23, 27, 30, 33, 37, 40, 43]
# 21 42 63
# 3 7 11

[3, 6, 10, 13, 17, 20, 23, 27, 30, 33, 37, 40, 43, 47, 50]

In [4]:
graphs = [Grapher(deepcopy(base_model)).graphify() for base_model in base_models]

merging_fn = 'match_tensors_permute' # 'match_tensors_permute'
start_at = 50  # [3, 6, 10, 13, 17, 20, 23, 27, 30, 33, 37, 40, 43, 47, 50]
stop_at = None

Merge = ModelMerge(*graphs, device=device)
prepared_metrics = torch.load('pfm_results/imagenet_vgg16_metrics.pth', weights_only=True)
for key in prepared_metrics:
    for key2 in prepared_metrics[key]:
        prepared_metrics[key][key2] = prepared_metrics[key][key2].to(device)
Merge.metrics = prepared_metrics

Merge.transform(
    deepcopy(config['models']['new']), 
    train_loader, 
    transform_fn=get_merging_fn(merging_fn),
    metric_classes=config['metric_fns'],
    stop_at=stop_at,
    start_at=start_at,
    prepared_metrics=prepared_metrics,
    # a=0.3,
    # b=0.8
)


imagenet_reset_bn_stats(Merge, train_loader)

merged_model_backup = deepcopy(Merge.merged_model)

Computing transformations: 


In [5]:
validate(Merge, test_loader, criterion, device)

100%|██████████| 1563/1563 [02:04<00:00, 12.56it/s]

Accuracy of the network on the 10000 test images: 71.314%





(0.04456307440429926, 0.71314)

### get permuted model

In [6]:
conv_linear_module_names = []
for name, module in base_models[0].named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        conv_linear_module_names.append(name)
conv_linear_module_names = conv_linear_module_names[:-1] # remove the last linear layer

# permute the model
total_nodes = list(Merge.merges.keys())
base_model_merge_s = [deepcopy(base_model) for base_model in base_models]

for model_idx, base_model_merge in enumerate(base_model_merge_s):
    for i, (last_node, node, layer_name) in enumerate(zip([None]+total_nodes[:-1], total_nodes, conv_linear_module_names[-len(total_nodes):])):
        # print(layer_name)
        if 'features' in layer_name:
            idx = int(layer_name.split('.')[1])
            layer = base_model_merge.features[idx]
        elif 'classifier' in layer_name:
            idx = int(layer_name.split('.')[1])
            layer = base_model_merge.classifier[idx]
        merge_matrix = Merge.merges[node][model_idx]
        # print(merge_matrix.shape)
        if i > 0:
            unmerge_matrix = Merge.unmerges[last_node][model_idx]
            # print(unmerge_matrix.shape)
            # hack for avgpool
            if layer_name == 'classifier.0':
                group_size = 7 * 7
                num_groups = layer.weight.shape[1] // (group_size)
                weight_matrix_grouped = layer.weight.data.view(-1, num_groups, group_size)
                permuted_groups = torch.einsum('ij,bjk->bik', unmerge_matrix, weight_matrix_grouped)
                layer.weight.data = permuted_groups.reshape(-1, num_groups * group_size)
            elif 'features' in layer_name:
                layer.weight.data = torch.einsum('OIHW,IU->OUHW', layer.weight.data, unmerge_matrix)
            else:
                layer.weight.data = layer.weight.data @ unmerge_matrix
    
        if 'features' in layer_name:
            layer.weight.data = torch.einsum('UO,OIHW->UIHW', merge_matrix, layer.weight.data)
        elif 'classifier' in layer_name:
            layer.weight.data = merge_matrix @ layer.weight.data
        if hasattr(layer, 'bias') and layer.bias is not None:
            layer.bias.data = merge_matrix @ layer.bias.data
    unmerge_matrix = Merge.unmerges[node][model_idx]
    
    base_model_merge.classifier[4].weight.data = base_model_merge.classifier[4].weight.data @ unmerge_matrix

In [7]:
# # baseline to check correctness
# from lmc_utils import interpolate_state_dicts

# sd_mid = interpolate_state_dicts(base_model_merge_s[1].state_dict(), base_model_merge_s[0].state_dict(), 0.5)

# for k in sd_mid.keys():
#     print(k,":",  (sd_mid[k] - Merge.merged_model.state_dict()[k]).abs().max().item())

### bias correction

In [8]:
from lmc_utils import interpolate_state_dicts

Merge.merged_model = deepcopy(merged_model_backup)
merged_sd = Merge.merged_model.state_dict()
merged_sd = interpolate_state_dicts(merged_sd, merged_sd, 0.5, True)
# merged_model = 
Merge.merged_model.load_state_dict(deepcopy(merged_sd))
# Merge.merged_model = merged_model.to(device)
validate(Merge, test_loader, criterion, device)

  0%|          | 0/1563 [00:00<?, ?it/s]

100%|██████████| 1563/1563 [02:04<00:00, 12.57it/s]

Accuracy of the network on the 10000 test images: 71.278%





(0.04460170504137874, 0.71278)

In [9]:
train_loader.batch_size = 32

### remove bias

In [10]:
# sd_ori = merged_model_backup.state_dict()
# sd_removebias = deepcopy(sd_ori)

# keys = list(sd_ori.keys())
# bias_keys = []
# for k in keys:
#     if 'bias' in k:
#         bias_keys.append(k)
# bias_keys = np.array(bias_keys)


# # forward_test_acc_s = []
# # forward_test_loss_s = []

# Merge.merged_model = deepcopy(merged_model_backup)
# model_removebias = Merge.merged_model
# # for i in range(len(bias_keys)+1):
# #     sd_removebias = deepcopy(sd_ori)
# #     for k in bias_keys[:i]:
# #         if 'bias' in k:
# #             sd_removebias[k].fill_(0)
# #     model_removebias.load_state_dict(sd_removebias)
# #     loss, acc = validate(Merge, test_loader, criterion, device, num_iters=30, print_freq=10)
# #     forward_test_acc_s.append(acc)
# #     forward_test_loss_s.append(loss)

# backward_test_acc_s = []
# backward_test_loss_s = []
# for i in range(1, len(bias_keys)+1):
#     sd_removebias = deepcopy(sd_ori)
#     for k in bias_keys[-i:]:
#         if 'bias' in k:
#             sd_removebias[k].fill_(0)
#     model_removebias.load_state_dict(sd_removebias)
#     loss, acc = validate(Merge, test_loader, criterion, device, num_iters=30, print_freq=10)
#     backward_test_acc_s.append(acc)
#     backward_test_loss_s.append(loss)

In [11]:
# import matplotlib.pyplot as plt

# plt.plot(backward_test_acc_s, label='backward')
# plt.legend()

### repair/rescale

In [15]:
from lmc_utils import repair, ResetLayer, RescaleLayer, TrackLayer

def make_repaired_imagenet_vgg16(net, device=None):
    net1 = deepcopy(net).to(device)
    for i, layer in enumerate(net1.features):
        if isinstance(layer, (nn.Conv2d)):
            net1.features[i] = ResetLayer(layer)
    for i, layer in enumerate(net1.classifier):
        if i < 4 and isinstance(layer, nn.Linear):
            net1.classifier[i] = ResetLayer(layer)
    return net1.eval().to(device)

def make_rescaled_imagenet_vgg16(net, device=None):
    net1 = deepcopy(net).to(device)
    for i, layer in enumerate(net1.features):
        if isinstance(layer, (nn.Conv2d)):
            net1.features[i] = RescaleLayer(layer)
    for i, layer in enumerate(net1.classifier):
        if i < 4 and isinstance(layer, nn.Linear):
            net1.classifier[i] = RescaleLayer(layer)
    return net1.eval().to(device)


def make_tracked_imagenet_vgg16(net, device=None):
    net1 = deepcopy(net)
    for i, layer in enumerate(net1.features):
        if isinstance(layer, (nn.Conv2d)):
            net1.features[i] = TrackLayer(layer)
    for i, layer in enumerate(net1.classifier):
        if i < 4 and isinstance(layer, nn.Linear):
            net1.classifier[i] = TrackLayer(layer)
    return net1.eval().to(device)

# model_tracked_s = [make_tracked_net(model, device, name) for model in model_tracked_s]
# for model in model_tracked_s:
#     reset_bn_stats(model, loader)

In [16]:
model_tracked_s = [make_tracked_imagenet_vgg16(model, device) for model in base_model_merge_s]
for model in model_tracked_s:
    imagenet_reset_bn_stats(model, train_loader, num_iters=2)
num_models = len(model_tracked_s)
alpha_s = [1/num_models] * num_models

Resetting batch norm:   1%|          | 2/313 [00:01<04:37,  1.12it/s]
Resetting batch norm:   1%|          | 2/313 [00:00<00:35,  8.78it/s]


In [17]:
means = [[], []]
stds = [[], []]
goal_means = []
goal_stds = []
for layers in zip(*[model_tracked.modules() for model_tracked in model_tracked_s]):
    
    if not isinstance(layers[0], TrackLayer):
        continue
    # get neuronal statistics of original networks
    mu_s = [layer.get_stats()[0] for layer in layers[:-1]]
    std_s = [layer.get_stats()[1] for layer in layers[:-1]]
    
    goal_mean = sum([alpha * mu for alpha, mu in zip(alpha_s, mu_s)])
    goal_std = sum([alpha * std for alpha, std in zip(alpha_s, std_s)])
    
    means[0].append(mu_s)
    stds[0].append(std_s)
    goal_means.append(goal_mean)
    goal_stds.append(goal_std)

In [None]:
# activation_statistics = {'means': means, 'stds': stds, 'goal_means': goal_means, 'goal_stds': goal_stds}
# activation_statistics_save_path = 'pfm_results/imagenet_vgg16_activation_statistics.pth'
# torch.save(activation_statistics, activation_statistics_save_path)

In [None]:
# activation_statistics_save_path = 'pfm_results/imagenet_vgg16_activation_statistics.pth'
# activation_statistics = torch.load(activation_statistics_save_path)

# goal_means = activation_statistics['goal_means']
# goal_stds = activation_statistics['goal_stds']

In [18]:
Merge.merged_model = deepcopy(merged_model_backup)
model_repaired = Merge.merged_model
variant = 'rescale'

if variant == 'repair':
    model_repaired = make_repaired_imagenet_vgg16(model_repaired, device)
elif variant == 'rescale':
    model_repaired = make_rescaled_imagenet_vgg16(model_repaired, device)

i = 0
for layer in model_repaired.modules():
    if isinstance(layer, (ResetLayer)):
        layer.set_stats(goal_means[i], goal_stds[i])
        i += 1
    elif isinstance(layer, (RescaleLayer)):
        layer.set_stats(goal_stds[i])
        i += 1

Merge.merged_model = model_repaired
imagenet_reset_bn_stats(Merge, train_loader)
validate(Merge, test_loader, criterion, device, num_iters=30, print_freq=10)

Resetting batch norm: 100%|██████████| 313/313 [00:24<00:00, 12.59it/s]
  1%|          | 11/1563 [00:01<02:15, 11.42it/s]

Accuracy so far: 83.75%


  1%|▏         | 21/1563 [00:01<02:03, 12.49it/s]

Accuracy so far: 84.53125%


  2%|▏         | 30/1563 [00:02<02:15, 11.28it/s]

Accuracy so far: 86.97916666666667%
Accuracy of the network on the 10000 test images: 86.97916666666667%





(0.03266051756218076, 0.8697916666666666)

In [None]:
Merge.merged_model = deepcopy(merged_model_backup)
model_repaired = Merge.merged_model
variant = 'repair'

if variant == 'repair':
    model_repaired = make_repaired_imagenet_vgg16(model_repaired, device)
elif variant == 'rescale':
    model_repaired = make_rescaled_imagenet_vgg16(model_repaired, device)

i = 0
for layer in model_repaired.modules():
    if isinstance(layer, (ResetLayer)):
        layer.set_stats(goal_means[i], goal_stds[i])
        i += 1
    elif isinstance(layer, (RescaleLayer)):
        layer.set_stats(goal_stds[i])
        i += 1

Merge.merged_model = model_repaired
imagenet_reset_bn_stats(Merge, train_loader)
validate(Merge, test_loader, criterion, device, num_iters=30, print_freq=10)

Resetting batch norm:   0%|          | 0/313 [00:00<?, ?it/s]


ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 4096])

In [59]:
imagenet_reset_bn_stats(Merge, train_loader)

Resetting batch norm:   0%|          | 0/313 [00:00<?, ?it/s]

Resetting batch norm:   0%|          | 0/313 [00:00<?, ?it/s]


ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 4096])

In [43]:
data, target = next(iter(test_loader))

data = data.to(device)[:2].float()

In [45]:
data[0]

tensor([[[-0.3540, -0.3027, -0.2000,  ...,  0.6904,  0.3823,  0.8105],
         [-0.5254, -0.6450, -0.0116,  ...,  0.8447,  0.2281,  0.6392],
         [-0.6450, -0.3711, -0.0629,  ...,  0.7935,  0.6221,  0.0569],
         ...,
         [-0.2683, -0.2512, -0.2341,  ...,  0.6904, -0.1486, -0.4226],
         [-0.3711, -0.3882, -0.3711,  ..., -0.4397, -0.0287,  0.0056],
         [-0.3369, -0.1829, -0.2856,  ..., -0.3711,  0.1255,  0.0912]],

        [[-0.5127,  0.2578, -0.3025,  ...,  1.0635,  0.7480,  1.2031],
         [-0.5649, -0.3550, -0.1975,  ...,  1.2207,  0.5376,  0.9756],
         [-0.5825, -0.3726, -0.0574,  ...,  1.0977,  1.0107,  0.4502],
         ...,
         [-0.1274, -0.1099, -0.0574,  ...,  0.8003,  0.0126, -0.0924],
         [-0.1975, -0.2849, -0.2150,  ..., -0.1274,  0.2578,  0.5728],
         [ 0.2402,  0.0126, -0.0049,  ...,  0.0126,  0.3977,  0.1527]],

        [[-0.6367,  0.3916, -0.3926,  ...,  0.8271,  0.6528,  0.8447],
         [-0.7412, -0.4451, -0.4102,  ...,  1

In [35]:
out = model_repaired.features(data)
out = model_repaired.avgpool(out)
out = torch.flatten(out, 1)
out = model_repaired.classifier[0](out)
out = model_repaired.classifier[1](out)
out = model_repaired.classifier[2](out)
out = model_repaired.classifier[3](out)
out = model_repaired.classifier[4](out)

In [68]:

model_repaired.classifier

Sequential(
  (0): ResetLayer(
    (layer): Linear(in_features=25088, out_features=4096, bias=True)
    (bn): BatchNorm1d(4096, eps=1e-05, momentum=None, affine=True, track_running_stats=True)
  )
  (1): ReLU(inplace=True)
  (2): ResetLayer(
    (layer): Linear(in_features=4096, out_features=4096, bias=True)
    (bn): BatchNorm1d(4096, eps=1e-05, momentum=None, affine=True, track_running_stats=True)
  )
  (3): ReLU(inplace=True)
  (4): Linear(in_features=4096, out_features=1000, bias=True)
)

In [72]:
from collections import defaultdict
from model_merger import MergedModelStop

def new_forward(self, x, cat_dim=None, start_idx=None):
    """ Evaluate the combined model. """
    if self.start_at is not None:
        start_val = defaultdict(lambda: 0)
        total = 0

        for idx, model in enumerate(self.start_at_models):
            if start_idx is not None and idx != start_idx:
                continue

            try:
                model(x)
            except MergedModelStop as e:
                for k, v in e.x.items():
                    start_val[k] = start_val[k] + v
                total += 1
        
        self.start_at_ptr.clear()
        for k, v in start_val.items():
            self.start_at_ptr[k] = v / total # / len(self.graphs)
        x = x[0, None].detach()
    
    try:
        print("shape of x", x.shape)
        out = self.merged_model(x)
        # print("shape of out", out.shape)
        # print(out)
        return out
    except MergedModelStop as e:
        self.stop_at_ptr[0] = e.x[0]

        dummy_x = x[0, None].detach()
        out = []
        for idx, model in enumerate(self.head_models):
            out.append(model(dummy_x))

        self.stop_at_ptr[0] = None
        
        if cat_dim is not None:
            out = torch.cat(out, dim=cat_dim)
        
        return out

Merge.forward = new_forward.__get__(Merge, Merge.__class__) 

In [77]:
Merge.merged_model.classifier[0]

ResetLayer(
  (layer): Linear(in_features=25088, out_features=4096, bias=True)
  (bn): BatchNorm1d(4096, eps=1e-05, momentum=None, affine=True, track_running_stats=True)
)

In [73]:
Merge(data)

shape of x torch.Size([1, 3, 224, 224])


ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 4096])

In [None]:
Merge

VGG(
  (features): Sequential(
    (0): ResetLayer(
      (layer): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): BatchNorm2d(64, eps=1e-05, momentum=None, affine=True, track_running_stats=True)
    )
    (1): ReLU(inplace=True)
    (2): ResetLayer(
      (layer): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): BatchNorm2d(64, eps=1e-05, momentum=None, affine=True, track_running_stats=True)
    )
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): ResetLayer(
      (layer): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): BatchNorm2d(128, eps=1e-05, momentum=None, affine=True, track_running_stats=True)
    )
    (6): ReLU(inplace=True)
    (7): ResetLayer(
      (layer): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): BatchNorm2d(128, eps=1e-05, momentum=None, affine=True, track_running_stats=True)

In [52]:
images, _ = next(iter(train_loader))
images = images.to(device).float()

In [56]:
model_repaired(images)

tensor([[-0.3517, -1.4258, -1.2641,  ..., -0.3146,  0.9565, -0.9646],
        [-1.8688,  0.8482, -0.8746,  ...,  0.8126,  2.3886,  1.3361],
        [-1.9659, -1.9149,  0.0967,  ..., -1.1985, -0.1174,  0.5273],
        ...,
        [ 0.6457, -0.9827, -0.0270,  ...,  2.4435,  2.0778,  0.8790],
        [-2.5298, -2.6307, -0.5199,  ..., -1.7217,  0.8399,  2.5990],
        [ 1.0522, -0.1755,  0.3616,  ...,  0.6595,  3.9731,  0.8954]],
       device='cuda:1', grad_fn=<AddmmBackward0>)

In [60]:
# resetting stats to baseline first as below is necessary for stability
for m in Merge.modules():
    if type(m) in (nn.BatchNorm2d, BatchScale2d, BatchScale1d, nn.BatchNorm1d):

            m.momentum = None # use simple average
            m.reset_running_stats()

# run a single train epoch with augmentations to recalc stats
Merge.train()

iter = 0
with torch.no_grad():
    for images, _ in tqdm(train_loader, desc='Resetting batch norm'):
        if images.shape[0] == 1:
            break
        if  iter >= 10:
            break
        if iter == len(train_loader): # hack for fractional loader
            break
        images = images.to(device).float()
        _ = Merge(images)
        iter += 1
model.eval()

Resetting batch norm:   0%|          | 0/313 [00:00<?, ?it/s]


ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 4096])

In [None]:
out = 

tensor([[ 0.8551,  0.7926,  0.2835,  ...,  0.8156, -1.7074, -3.7925],
        [-1.7701, -2.0483, -3.1649,  ..., -2.7772,  0.9288,  0.7644]],
       device='cuda:1', grad_fn=<NativeBatchNormBackward0>)

In [None]:
# torch.cuda.is_available()

In [None]:
# prepared_metrics = torch.load('pfm_results/imagenet_vgg16_metrics.pth', weights_only=True)
# Merge.metrics = prepared_metrics
