In [1]:
import os
import torch
import random
from lmc_utils import BatchScale1d, BatchScale2d, interpolate_state_dicts
from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
from model_merger import ModelMerge
from graphs.base_graph import NodeType
from torch import nn
from utils import get_config_from_name, get_device, prepare_experiment_config, get_merging_fn

from lmc_utils import ResetLayer, RescaleLayer, TrackLayer

def _wrap_vgg16_layers(net, wrapper_cls, num_classifier_layers=4):
    """Wrap, in place, the layers of ``net`` that the ``make_*`` factories instrument.

    Every ``nn.Conv2d`` in ``net.features`` and every ``nn.Linear`` whose index
    in ``net.classifier`` is below ``num_classifier_layers`` is replaced by
    ``wrapper_cls(layer)``. Linear layers at or past that index are left as
    they are (presumably the final classification layer -- confirm against the
    model definition).

    Args:
        net: module exposing indexable ``features`` and ``classifier`` containers.
        wrapper_cls: callable taking the original layer and returning its wrapper.
        num_classifier_layers: exclusive upper bound on the classifier indices
            that are eligible for wrapping.

    Returns:
        The same ``net`` object, modified in place.
    """
    for idx, layer in enumerate(net.features):
        if isinstance(layer, nn.Conv2d):
            net.features[idx] = wrapper_cls(layer)
    for idx, layer in enumerate(net.classifier):
        if idx < num_classifier_layers and isinstance(layer, nn.Linear):
            net.classifier[idx] = wrapper_cls(layer)
    return net


def make_repaired_imagenet_vgg16(net, device=None):
    """Return an eval-mode deep copy of ``net`` with its layers wrapped in ``ResetLayer``.

    The input ``net`` is not modified. The copy is moved to ``device`` before
    wrapping and once more afterwards, in case the wrappers allocate tensors
    of their own.
    """
    wrapped = _wrap_vgg16_layers(deepcopy(net).to(device), ResetLayer)
    return wrapped.eval().to(device)


def make_rescaled_imagenet_vgg16(net, device=None):
    """Return an eval-mode deep copy of ``net`` with its layers wrapped in ``RescaleLayer``.

    The input ``net`` is not modified. The copy is moved to ``device`` before
    wrapping and once more afterwards, in case the wrappers allocate tensors
    of their own.
    """
    wrapped = _wrap_vgg16_layers(deepcopy(net).to(device), RescaleLayer)
    return wrapped.eval().to(device)


def make_tracked_imagenet_vgg16(net, device=None):
    """Return an eval-mode deep copy of ``net`` with its layers wrapped in ``TrackLayer``.

    The input ``net`` is not modified. Unlike the two factories above, the
    copy is wrapped on whatever device ``net`` lives on and only moved to
    ``device`` at the end.
    """
    wrapped = _wrap_vgg16_layers(deepcopy(net), TrackLayer)
    return wrapped.eval().to(device)


def validate(model, testloader, criterion, device, half=False, num_iters=None, print_freq=None):
    """Evaluate ``model`` on ``testloader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: module to evaluate. It is switched to eval mode and left there.
        testloader: iterable yielding ``(images, labels)`` batches.
        criterion: callable ``criterion(outputs, labels)`` returning the loss
            tensor for one batch.
        device: device each batch is moved to before the forward pass.
        half: accepted for backward compatibility but currently ignored;
            inputs are always cast to float32.
        num_iters: if not None, stop after this many batches.
        print_freq: if not None, print the running accuracy every
            ``print_freq`` batches.

    Returns:
        Tuple ``(loss, accuracy)``: the mean per-sample loss and the fraction
        of correctly classified samples. Both are NaN when no sample was
        evaluated (empty loader or ``num_iters=0``).
    """
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    # Loss modules expose their reduction mode; a plain callable is assumed
    # to return the batch mean.
    reduction = getattr(criterion, 'reduction', 'mean')

    with torch.no_grad():
        for it, (images, labels) in enumerate(testloader):
            if num_iters is not None and it >= num_iters:
                break
            images = images.to(device).float()
            labels = labels.to(device).long()
            outputs = model(images)
            loss = criterion(outputs, labels)
            batch_size = labels.size(0)
            if loss.dim() > 0:
                # Unreduced output: one loss value per sample.
                loss_sum += loss.sum().item()
            elif reduction == 'sum':
                loss_sum += loss.item()
            else:
                # The batch mean must be re-weighted by the batch size, otherwise
                # dividing by the sample count below shrinks the loss by roughly
                # the batch size and mis-weights a short final batch.
                loss_sum += loss.item() * batch_size
            _, predicted = torch.max(outputs, 1)
            total += batch_size
            correct += (predicted == labels).sum().item()
            if print_freq is not None and (it + 1) % print_freq == 0:
                print('Accuracy so far: {}%'.format(100 * correct / total))

    if total == 0:
        # Nothing was evaluated; avoid a ZeroDivisionError and make that explicit.
        print('No samples were evaluated; returning NaN loss and accuracy.')
        return float('nan'), float('nan')

    print('Accuracy of the network on the {} test images: {}%'.format(total, 100 * correct / total))
    return loss_sum / total, correct / total


def imagenet_reset_bn_stats(model, loader, reset=True, num_iters=None):
    """Recompute the running statistics of the model's normalization layers.

    Layers whose exact type is ``nn.BatchNorm2d``, ``nn.BatchNorm1d``,
    ``BatchScale2d`` or ``BatchScale1d`` are targeted. If the model has none,
    it is returned untouched. Otherwise the model is put in train mode,
    batches from ``loader`` are forwarded without gradients, and the model is
    switched to eval mode before being returned.

    Args:
        model: module to update in place.
        loader: iterable of ``(images, _)`` batches; must support ``len()``.
        reset: when True, each targeted layer first gets ``momentum = None``
            and has its running statistics reset.
        num_iters: if not None, forward at most this many batches.

    Returns:
        The same ``model`` object.
    """
    device = get_device(model)
    norm_types = (nn.BatchNorm2d, BatchScale2d, BatchScale1d, nn.BatchNorm1d)
    # Exact type match on purpose: subclasses of these layers are not picked up.
    norm_layers = [module for module in model.modules() if type(module) in norm_types]

    # resetting stats to baseline first is necessary for stability
    if reset:
        for layer in norm_layers:
            layer.momentum = None  # use simple average
            layer.reset_running_stats()

    if not norm_layers:
        return model

    # forward the loader once in train mode to recalc stats
    model.train()
    with torch.no_grad():
        print('Resetting batch norm stats')
        for step, (images, _) in enumerate(loader):
            if num_iters is not None and step >= num_iters:
                break
            if step == len(loader):  # hack for fractional loader
                break
            model(images.to(device).float())
    model.eval()
    return model


  from .autonotebook import tqdm as notebook_tqdm


In [7]:
config_name = 'imagenet_vgg16'

device = 'cpu'
raw_config = get_config_from_name(config_name, device=device)

# Point the base-model checkpoints at the desired pair: '1_2' selects
# checkpoints 1 and 2.
pairs = [int(idx) for idx in '1_2'.split('_')]
for i, model_idx in enumerate(pairs):
    path = raw_config['model']['bases'][i]  # e.g. ..._1.pth
    # Swap the whole trailing "_<index>" for the requested one, so multi-digit
    # indices and other file extensions are handled too.
    root, ext = os.path.splitext(path)
    stem, sep, old_idx = root.rpartition('_')
    if sep and old_idx.isdigit():
        path = f'{stem}_{model_idx}{ext}'
    else:
        # Unexpected naming scheme: replace the single character in front of
        # the '.pth' suffix, as this cell always did.
        path = path[:-5] + str(model_idx) + '.pth'
    print(path)
    raw_config['model']['bases'][i] = path

model_paths = deepcopy(raw_config['model']['bases'])

# Prepare the experiment from a copy so raw_config keeps the unprepared settings.
cur_config = deepcopy(raw_config)
config = prepare_experiment_config(cur_config)

train_loader = config['data']['train']['full']
test_loader = config['data']['test']['full']
# NOTE(review): a stock torch DataLoader rejects reassigning batch_size after
# construction; presumably these loaders are a project-specific type -- confirm.
train_loader.batch_size = 32
test_loader.batch_size = 32
test_loader.num_workers = 0
# batch_size * len(loader) counts a partial final batch as a full one, so these
# figures are upper bounds (assuming len(loader) is the number of batches).
print(f"Training samples: {train_loader.batch_size * len(train_loader)}")
print(f"Testing samples: {test_loader.batch_size * len(test_loader)}")

./checkpoints/imagenet_vgg16_1.pth
./checkpoints/imagenet_vgg16_2.pth
Preparing Models




Training samples: 10016
Testing samples: 50016


In [8]:
# Pull what the following cells need out of the prepared config: the base
# models (as a fresh list), the class used to build a graph from a model, and
# the loss used for evaluation.
Grapher = config['graph']
base_models = list(config['models']['bases'])
criterion = nn.CrossEntropyLoss()

In [9]:
# One graph per base model, each built from its own deep copy so the objects
# in `base_models` are not the ones handed to the grapher.
graphs = [
    Grapher(deepcopy(model)).graphify()
    for model in base_models
]

In [None]:
# Destination file for the metrics of this run; the file name repeats the
# config name (imagenet_vgg16) and the model pair (1_2) chosen above.
# NOTE(review): absolute path into one user's home directory -- this only works
# on that machine; consider deriving it from a configurable results directory.
metrics_save_path = '/home/xingyu/Repos/my_ZipIt/pfm_results/imagenet/imagenet_vgg16_1_2_metrics.pth'