In [1]:
import os
import torch
import random

from copy import deepcopy
from tqdm.auto import tqdm
import numpy as np

from utils import *
from model_merger import ModelMerge

def validate(model, testloader, criterion, device):
    """Evaluate `model` on every batch of `testloader` and report loss and accuracy.

    The model is switched to eval mode (and left there) and the forward passes run
    under `torch.no_grad()`. The accuracy is also printed as a percentage.

    Args:
        model: module mapping a batch of inputs to class scores; the predicted class
            is the arg-max over dim 1 of its output.
        testloader: iterable of `(images, labels)` batches.
        criterion: loss called as `criterion(outputs, labels)`. Assumed to average
            over the batch (e.g. the default `reduction='mean'` of
            `torch.nn.CrossEntropyLoss`) — TODO confirm for other criteria.
        device: device every batch is moved to before the forward pass.

    Returns:
        (mean_loss, accuracy): per-sample mean loss over the whole loader and the
        fraction of correctly classified samples, in [0, 1].

    Raises:
        ValueError: if `testloader` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0

    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            # The criterion returns a batch *mean*; weight it by the batch size so that
            # dividing by `total` below gives a true per-sample mean, including when the
            # last batch is smaller than the others.
            loss_sum += criterion(outputs, labels).item() * batch_size
            _, predicted = torch.max(outputs, 1)
            total += batch_size
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError('testloader yielded no samples')

    print('Accuracy of the network on the {} test images: {} '.format(total, 100 * correct / total))
    return loss_sum / total, correct / total

config_name = 'cifar10_my_vgg16_bn'

# Seed every RNG the notebook can touch (Python, NumPy, PyTorch) so that random
# operations — e.g. any shuffling in the data loaders — are reproducible across
# Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Returns the dict defined in the config file, with an added 'device' key/value pair.
raw_config = get_config_from_name(config_name, device=device)
# Hand a deep copy to prepare_experiment_config so raw_config is kept as loaded
# (presumably prepare_experiment_config modifies its argument — confirm).
cur_config = deepcopy(raw_config)
config = prepare_experiment_config(cur_config)

train_loader = config['data']['train']['full']
test_loader = config['data']['test']['full']

# Run reset_bn_stats on every base model with the full training loader before merging.
base_models = [reset_bn_stats(base_model, train_loader) for base_model in config['models']['bases']]
# Graph builder class for this architecture; called as Grapher(model).graphify() below.
Grapher = config['graph']

criterion = torch.nn.CrossEntropyLoss()

Files already downloaded and verified
Files already downloaded and verified


Preparing Models:   0%|          | 0/2 [00:00<?, ?it/s]



Resetting batch norm:   0%|          | 0/391 [00:00<?, ?it/s]

Resetting batch norm:   0%|          | 0/391 [00:00<?, ?it/s]

In [2]:
from graphs.base_graph import NodeType

# Graphify a copy of the first base model (the original stays untouched) and keep
# the ids of all PREFIX nodes, in the order graph.nodes yields them.
graph = Grapher(deepcopy(base_models[0])).graphify().G
prefix_nodes = [
    node_id
    for node_id in graph.nodes
    if graph.nodes[node_id]['type'] == NodeType.PREFIX
]
prefix_nodes

[4, 8, 13, 17, 22, 26, 30, 35, 39, 43, 48, 52, 56]

In [3]:
# Add `None` as the first candidate start point. `None` is later passed as `start_at`
# to ModelMerge.transform — presumably meaning "merge from the start of the network";
# confirm. Dropping any existing `None`s first keeps this cell idempotent: re-running
# it no longer stacks up extra leading `None` entries.
prefix_nodes = [None] + [node for node in prefix_nodes if node is not None]
prefix_nodes

[None, 4, 8, 13, 17, 22, 26, 30, 35, 39, 43, 48, 52, 56]

In [4]:
# Merging strategies to compare; each name is later resolved via get_merging_fn.
merging_fn_s = [
    'match_tensors_permute',
    'match_tensors_zipit',
    'match_tensors_identity',
]
# Per strategy, the accuracy lists filled in by the experiment loop: 'merger' before
# and 'merger_reset' after calling reset_bn_stats on the merged model.
# A comprehension (not dict.fromkeys) so that every strategy gets its own lists.
res_dict = {
    fn_name: dict(merger=[], merger_reset=[])
    for fn_name in merging_fn_s
}

In [5]:
merging_fn = "match_tensors_permute"  # default strategy for one-off runs; the experiment loop below rebinds it

In [None]:
# Evaluate every merging strategy at every start point. For each (strategy, start_at)
# pair, record the merged model's test accuracy before ('merger') and after
# ('merger_reset') calling reset_bn_stats on it.

# Candidate start points: None first, then every PREFIX node. Built here (and
# de-duplicated) instead of mutating prefix_nodes, so every strategy sees the same
# list no matter how many times this cell or the prepend-None cell has been run.
start_points = [None] + [node for node in prefix_nodes if node is not None]
stop_at = None

for merging_fn in merging_fn_s:
    print(f"merging_fn: {merging_fn}")
    # Start from empty result lists so re-running the cell does not append duplicates;
    # afterwards entry i of each list corresponds to start_points[i].
    res_dict[merging_fn] = {'merger': [], 'merger_reset': []}

    for start_at in start_points:
        print(f"start_at: {start_at}")
        # Fresh deep copies of the base models for every merge, so nothing a previous
        # ModelMerge run may have left on the graphs can leak into this one.
        graphs = [Grapher(deepcopy(base_model)).graphify() for base_model in base_models]
        Merge = ModelMerge(*graphs, device=device)
        Merge.transform(
            deepcopy(config['models']['new']),
            train_loader,
            transform_fn=get_merging_fn(merging_fn),
            metric_classes=config['metric_fns'],
            stop_at=stop_at,
            start_at=start_at,
        )
        merger_acc = validate(Merge, test_loader, criterion, device)[1]
        print(f"merger_acc: {merger_acc}")

        # Evaluate again after resetting the merged model's batch-norm statistics on
        # the training data.
        reset_bn_stats(Merge, train_loader)
        merger_reset_acc = validate(Merge, test_loader, criterion, device)[1]
        print(f"merger_reset_acc: {merger_reset_acc}")

        res_dict[merging_fn]['merger'].append(merger_acc)
        res_dict[merging_fn]['merger_reset'].append(merger_reset_acc)

merging_fn: match_tensors_permute


Forward Pass to Compute Merge Metrics:   0%|          | 0/391 [00:00<?, ?it/s]

Computing transformations:   0%|          | 0/13 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 17.67 


Resetting batch norm:   0%|          | 0/391 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 89.61 


Forward Pass to Compute Merge Metrics:   0%|          | 0/391 [00:00<?, ?it/s]

Computing transformations:   0%|          | 0/13 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 17.38 


Resetting batch norm:   0%|          | 0/391 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 89.57 


Forward Pass to Compute Merge Metrics:   0%|          | 0/391 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# NOTE(review): dead scratch cell — everything below is commented out and duplicates
# the experiment loop above, so running it does nothing. The outputs shown under this
# cell are stale (presumably from when the code was still active). Consider deleting.
# def tmp_f(start_at):
#     graphs = [Grapher(deepcopy(base_model)).graphify() for base_model in base_models]
#     # start_at = prefix_nodes[idx] # 0
#     stop_at = None

#     Merge = ModelMerge(*graphs, device=device)
#     Merge.transform(
#         deepcopy(config['models']['new']), 
#         train_loader, 
#         transform_fn=get_merging_fn(merging_fn),
#         metric_classes=config['metric_fns'],
#         stop_at=stop_at,
#         start_at=start_at
#     )

#     merger_acc = validate(Merge, test_loader, criterion, device)[1]
#     res_dict[merging_fn]['merger'].append(merger_acc)

#     reset_bn_stats(Merge, train_loader)

#     merger_acc = validate(Merge, test_loader, criterion, device)[1]
#     res_dict[merging_fn]['merger_reset'].append(merger_acc)


# # for idx in range(len(prefix_nodes)):
# #     tmp_f(idx)

# # for start_at in prefix_nodes:
# #     tmp_f(start_at)

# for start_at in prefix_nodes:
#     graphs = [Grapher(deepcopy(base_model)).graphify() for base_model in base_models]
#     # start_at = prefix_nodes[idx] # 0
#     stop_at = None

#     Merge = ModelMerge(*graphs, device=device)
#     Merge.transform(
#         deepcopy(config['models']['new']), 
#         train_loader, 
#         transform_fn=get_merging_fn(merging_fn),
#         metric_classes=config['metric_fns'],
#         stop_at=stop_at,
#         start_at=start_at
#     )

#     merger_acc = validate(Merge, test_loader, criterion, device)[1]
#     res_dict[merging_fn]['merger'].append(merger_acc)

#     reset_bn_stats(Merge, train_loader)

#     merger_acc = validate(Merge, test_loader, criterion, device)[1]
#     res_dict[merging_fn]['merger_reset'].append(merger_acc)

Forward Pass to Compute Merge Metrics:   0%|          | 0/391 [00:00<?, ?it/s]

Computing transformations:   0%|          | 0/13 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 18.53 


Resetting batch norm:   0%|          | 0/391 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 89.54 


Forward Pass to Compute Merge Metrics:   0%|          | 0/391 [00:00<?, ?it/s]

Computing transformations:   0%|          | 0/13 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 11.23 


Resetting batch norm:   0%|          | 0/391 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 89.91 


Forward Pass to Compute Merge Metrics:   0%|          | 0/391 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Sanity check on the start-point list: one leading None plus one entry per PREFIX
# node found above. A larger count means None was prepended more than once.
len(prefix_nodes)

14