In [2]:
# import os
import torch
# import random
from copy import deepcopy
# from tqdm.auto import tqdm
# import numpy as np
import os
from graphs.base_graph import NodeType
from utils import get_config_from_name, prepare_experiment_config,\
     get_merging_fn
from lmc_utils import reset_bn_stats
from model_merger import ModelMerge, MergeHandler
from lmc_utils import interpolate_state_dicts, repair

In [3]:
config_name = 'cifar10_my_vgg16'

device = 'cuda:0'
raw_config = get_config_from_name(config_name, device=device)

# change the model bases to the desired pair
pairs = '2_3'.split('_')
pairs = [int(pair) for pair in pairs]
for i, model_idx in enumerate(pairs):
    path = raw_config['model']['bases'][i]  # ..._1.pth
    # replace the last digit with the model_idx
    path = path[:-5] + '_' + str(model_idx) + '.pt'
    # remove ./ from the path
    # if path.startswith('./'):
    #     path = path[2:]
    raw_config['model']['bases'][i] = path

# prepare models
model_paths = deepcopy(raw_config['model']['bases'])
cur_config = deepcopy(raw_config)
config = prepare_experiment_config(cur_config)
train_loader = config['data']['train']['full']
test_loader = config['data']['test']['full']
base_models = [base_model for base_model in
                config['models']['bases']]
Grapher = config['graph']
criterion = torch.nn.CrossEntropyLoss()


merging_fn = 'match_tensors_permute'
graphs = [Grapher(deepcopy(base_model)).graphify() for base_model
                    in base_models]
Merge = ModelMerge(*graphs, device=device)
Merge.transform(
    deepcopy(config['models']['new']),
    train_loader,
    transform_fn=get_merging_fn(merging_fn),
    metric_classes=config['metric_fns'],
    stop_at=None,
    start_at=None
)

# get permuted model
graphs = Merge.graphs
base_model_merge_s = [deepcopy(graph.model) for graph in graphs]
# remove all hooks from the model
for model in base_model_merge_s:
    model._forward_hooks = {}
    model._backward_hooks = {}
sd_1_permuted = base_model_merge_s[0].state_dict()
sd_2_permuted = base_model_merge_s[1].state_dict()

sd_1_save_name = f'checkpoints/cifar10_my_vgg16_1_permute.pth'
sd_2_save_name = f'checkpoints/cifar10_my_vgg16_2_permute.pth'
if os.path.exists(sd_1_save_name):
    raise FileExistsError(f'{sd_1_save_name} already exists')
elif os.path.exists(sd_2_save_name):
    raise FileExistsError(f'{sd_2_save_name} already exists')
else:
    torch.save(sd_1_permuted, sd_1_save_name)
    torch.save(sd_2_permuted, sd_2_save_name)

Files already downloaded and verified
Files already downloaded and verified
Preparing Models




Forward Pass to Compute Merge Metrics: 
Computing transformations: 


In [4]:
def validate(model, testloader, criterion, device):
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data 
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss_sum += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy of the network on the 10000 test images: \
        {}'.format(100 * correct / total))
    return loss_sum / total, correct / total

validate(Merge, test_loader, criterion, device)

Accuracy of the network on the 10000 test images:         44.75


(0.013312624871730804, 0.4475)

In [5]:
sd_merge = Merge.merged_model.state_dict()

merged_model_test = deepcopy(base_models[0])

sd_merged = interpolate_state_dicts(sd_merge, sd_merge, 0.5, True)
merged_model_test.load_state_dict(sd_merged)

validate(merged_model_test, test_loader, criterion, device)

Accuracy of the network on the 10000 test images:         81.63


(0.016766719698905946, 0.8163)