In [1]:
import torchvision
from torchvision.models import ResNet50_Weights
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Instantiate ResNet-50 initialised with torchvision's IMAGENET1K_V2 pretrained weights.
model = torchvision.models.resnet50(
    weights=ResNet50_Weights.IMAGENET1K_V2,
)

In [3]:
import os

# Destination for the pretrained ResNet-50 weights.
CHECKPOINT_PATH = './checkpoints/imagenet_resnet50_2.pth'
# torch.save does not create missing parent directories, so ensure ./checkpoints exists
# (otherwise this cell fails on a fresh checkout).
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
sd = model.state_dict()
torch.save(sd, CHECKPOINT_PATH)

_____

In [4]:
import os
import torch
import random
from lmc_utils import BatchScale1d, BatchScale2d
from copy import deepcopy
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt
from utils import get_config_from_name, get_device, prepare_experiment_config, get_merging_fn
from model_merger import ModelMerge
from torch import nn

def validate(model, testloader, criterion, device, half=False, num_iters=None, print_freq=None):
    """Evaluate ``model`` on ``testloader`` and report mean loss and top-1 accuracy.

    Args:
        model: module to evaluate; it is switched to eval mode and left there.
        testloader: iterable of ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``. Assumed to use
            mean reduction (the ``nn.CrossEntropyLoss`` default): each batch loss
            is re-weighted by its batch size so the result is a per-sample mean.
        device: device the batches are moved to.
        half: accepted for backward compatibility but currently ignored; inputs
            are always cast to float32.
        num_iters: if given, evaluate at most this many batches.
        print_freq: if given, print the running accuracy every ``print_freq`` batches.

    Returns:
        ``(mean_loss, accuracy)`` with accuracy as a fraction in [0, 1].
        Both are NaN when no samples were evaluated (empty loader or ``num_iters=0``).
    """
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(testloader):
            if num_iters is not None and batch_idx >= num_iters:
                break
            images = images.to(device).float()
            labels = labels.to(device).long()
            outputs = model(images)
            batch_size = labels.size(0)
            # criterion returns a batch *mean*; weight by batch size so that
            # dividing by the sample count below yields a true per-sample mean.
            loss_sum += criterion(outputs, labels).item() * batch_size
            _, predicted = torch.max(outputs, 1)
            total += batch_size
            correct += (predicted == labels).sum().item()
            if print_freq is not None and (batch_idx + 1) % print_freq == 0:
                print('Accuracy so far: {}%'.format(100 * correct / total))

    if total == 0:
        print('No test samples were evaluated.')
        return float('nan'), float('nan')

    print('Accuracy of the network on the {} test images: {}%'.format(total, 100 * correct / total))
    return loss_sum / total, correct / total


def imagenet_reset_bn_stats(model, loader, reset=True, num_iters=None):
    """Re-estimate normalization running statistics by forwarding ``loader`` in train mode.

    Matched layers are those whose *exact* type is ``nn.BatchNorm2d``, ``BatchScale2d``,
    ``BatchScale1d`` or ``nn.BatchNorm1d`` (subclasses are not matched). If the model
    has none, it is returned untouched.

    Args:
        model: module whose normalization statistics are recomputed.
        loader: iterable of ``(images, _)`` batches; the second element is ignored.
        reset: if True, every matched layer gets ``momentum = None`` (cumulative
            average) and has its running stats cleared before the pass.
        num_iters: optional cap on the number of batches forwarded.

    Returns:
        ``model`` itself, left in eval mode.
    """
    device = get_device(model)
    norm_types = (nn.BatchNorm2d, BatchScale2d, BatchScale1d, nn.BatchNorm1d)
    norm_layers = [module for module in model.modules() if type(module) in norm_types]
    if not norm_layers:
        return model

    if reset:
        # Starting from cleared stats with a simple (cumulative) average is
        # necessary for stable estimates.
        for layer in norm_layers:
            layer.momentum = None
            layer.reset_running_stats()

    # Forward passes in train mode (no gradients) update the running statistics.
    model.train()
    batches_done = 0
    with torch.no_grad():
        for images, _ in loader:
            # Stop on a single-sample batch (presumably because batch statistics
            # are degenerate with one sample).
            if images.shape[0] == 1:
                break
            if num_iters is not None and batches_done >= num_iters:
                break
            # Guard for a "fractional" loader that yields more batches than len() reports.
            if batches_done == len(loader):
                break
            model(images.to(device).float())
            batches_done += 1
    model.eval()
    return model

In [5]:
config_name = 'imagenet_resnet50'

# Prefer the second GPU (the original setup), but fall back so the notebook still
# runs end-to-end on single-GPU or CPU-only machines.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Returns the dict from the config file, with a 'device' key/value pair added.
raw_config = get_config_from_name(config_name, device=device)
cur_config = deepcopy(raw_config)
config = prepare_experiment_config(cur_config)

train_loader = config['data']['train']['full']
test_loader = config['data']['test']['full']
# Load test batches without worker processes (assumes a torch DataLoader — TODO confirm).
test_loader.num_workers = 0

base_models = list(config['models']['bases'])
Grapher = config['graph']
merging_fn = 'match_tensors_permute'  # alternative: 'match_tensors_zipit'

criterion = torch.nn.CrossEntropyLoss()

Preparing Models: 100%|██████████| 2/2 [00:01<00:00,  1.96it/s]


In [6]:
from graphs.base_graph import NodeType

# Count PREFIX-type nodes in the graph built from (a copy of) the first base model.
graph = Grapher(deepcopy(base_models[0])).graphify().G
prefix_nodes = [
    node for node in graph.nodes
    if graph.nodes[node]['type'] == NodeType.PREFIX
]
len(prefix_nodes)

37

In [7]:
# One graph per base model (each built from a deep copy, so the originals are untouched).
graphs = [Grapher(deepcopy(net)).graphify() for net in base_models]

merging_fn = 'match_tensors_permute'  # alternative: 'match_tensors_zipit'
# Optional layer limits passed to Merge.transform (None presumably means "no limit").
# Previously tried start_at: [3, 6, 10, 13, 17, 20, 23, 27, 30, 33, 37, 40, 43, 47, 50]
start_at = None
stop_at = None

Merge = ModelMerge(*graphs, device=device)

# Optionally reuse cached metrics instead of None:
#   prepared_metrics = torch.load('pfm_results/imagenet_vgg16_metrics.pth', weights_only=True)
#   for key in prepared_metrics:
#       for key2 in prepared_metrics[key]:
#           prepared_metrics[key][key2] = prepared_metrics[key][key2].to(device)
#   Merge.metrics = prepared_metrics
prepared_metrics = None

Merge.transform(
    deepcopy(config['models']['new']),
    train_loader,
    transform_fn=get_merging_fn(merging_fn),
    metric_classes=config['metric_fns'],
    stop_at=stop_at,
    start_at=start_at,
    prepared_metrics=prepared_metrics,
    # extra options previously tried: a=0.3, b=0.8
)

# Optional follow-ups (currently disabled):
#   imagenet_reset_bn_stats(Merge, train_loader)
#   merged_model_backup = deepcopy(Merge.merged_model)

Forward Pass to Compute Merge Metrics: 
Computing transformations: 


In [13]:
# Validate the ensemble of base models: average their logits and take the argmax.

# Eval mode makes normalization layers use their running statistics instead of
# updating them with test batches. Note: base models are left in eval mode.
for ensemble_member in base_models:
    ensemble_member.eval()

total = 0
correct = 0

# Inference only: skip autograd bookkeeping, which otherwise retains every
# forward graph and wastes a lot of memory over the full test set.
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device).float()
        labels = labels.to(device).long()
        outputs = torch.stack([member(images) for member in base_models], dim=0).mean(dim=0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total == 0:
    print('No test samples were evaluated.')
else:
    print('Accuracy of the ensemble on the {} test images: {}%'.format(total, 100 * correct / total))

Accuracy of the ensemble on the 10000 test images: 79.774%
