In [1]:
import torch
import pickle
from models import mobilenetv2, resnet56
from torchvision.models import resnet50
from embedl.plumbing.torch.metrics.target import Target
from embedl.torch.pruning.methods import UniformPruning
from embedl.torch.viewer import view_model
from embedl.torch.metrics.performances import Flops  
from embedl.torch.metrics.measure_performance import measure_flops
import torchvision.datasets as datasets
from embedl.torch.pruning.methods import plot_pruning_profile 
import torchvision.transforms as transforms
import torch.nn as nn
from embedl.torch.metrics.performances import Flops
from embedl.torch.pruning.methods import (
    PruningMethod,
    ChannelPruningTactic,
)
from embedl.plumbing.torch.metrics.scorers import ChannelPruningScorer, PruningBalancer
from embedl.torch.metrics.importance_scores import WeightMagnitude
from embedl.plumbing.torch.pruning.method import apply_pruning_steps

# Per-channel mean / std handed to Normalize (presumably the CIFAR-10
# training-set statistics -- TODO confirm against the training pipeline).
normalize = transforms.Normalize(
    mean=[0.4914, 0.4822, 0.4465],
    std=[0.247, 0.243, 0.261],
)

# Evaluation-time preprocessing: tensor conversion followed by normalisation.
val_transform = transforms.Compose([transforms.ToTensor(), normalize])

# train=False selects the CIFAR-10 test split stored under the given root.
val_dataset = datasets.CIFAR10(
    root="/home/jonna/data",
    train=False,
    transform=val_transform,
)

# Fixed order (shuffle=False) so repeated evaluations see identical batches.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

def accuracy(output, target, topk=(1,)):
    """Compute the precision@k for the specified values of k.

    Args:
        output: 2-D tensor of per-class scores, one row per sample.
        target: 1-D tensor holding the true class index of each sample.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one 0-dim tensor per entry of ``topk``: the percentage of
        samples whose true class is among the ``k`` highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk best classes per sample, transposed to (maxk, batch).
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape (not view): the transposed layout is non-contiguous, so
        # view(-1) raises a RuntimeError as soon as k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def validate(val_loader, model, criterion):
    """
    Evaluate ``model`` on ``val_loader`` and report its top-1 precision.

    Args:
        val_loader: iterable yielding ``(input, target)`` batches.
        model: network to evaluate; it is switched to eval mode.
        criterion: kept for backward compatibility only. The loss was computed
            but never reported, so it is no longer evaluated.

    Returns:
        The sample-weighted top-1 precision in percent as a float. The same
        value is printed.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    # Same placement as the former hard-coded .cuda() calls when a GPU is
    # present; falls back to the CPU otherwise.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    weighted_prec1 = 0.0
    count = 0

    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        for images, target in val_loader:
            # Each tensor is transferred exactly once.
            images = images.to(device)
            target = target.to(device)

            output = model(images).float()

            # Weight each batch by its size so a smaller last batch does not
            # skew the average.
            batch_size = target.size(0)
            weighted_prec1 += accuracy(output, target)[0] * batch_size
            count += batch_size

    if count == 0:
        raise ValueError("val_loader yielded no samples")

    prec1 = float(weighted_prec1 / count)
    print(f" * Prec@1 {prec1:.3f}")
    return prec1

  from .autonotebook import tqdm as notebook_tqdm


In [20]:
def load_mobilenet(checkpoint_path):
    """Build a MobileNetV2 from a training checkpoint and move it to the GPU.

    Args:
        checkpoint_path: path to a checkpoint whose ``"state_dict"`` entry
            holds the model weights.

    Returns:
        The loaded model, placed on the current CUDA device.
    """
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    # map_location="cpu" avoids depending on the device the checkpoint was
    # saved from; the model is moved to the GPU after loading.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Keys presumably carry a "module." prefix from DataParallel training.
    # Strip it only where present instead of blindly dropping 7 characters.
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): weights
        for key, weights in checkpoint["state_dict"].items()
    }

    loaded = mobilenetv2()
    loaded.load_state_dict(state_dict)
    return loaded.cuda()


mobilenet_top1 = load_mobilenet(
    "/home/jonna/hyperparameter_sensitivity_pruning/experiments/cifar10/mobilenetv2/base_model/extended_grid_3/results/lr_10**-2.20_wd_10**-2.60/checkpoint_final.th"
)
mobilenet_top2 = load_mobilenet(
    "/home/jonna/hyperparameter_sensitivity_pruning/experiments/cifar10/mobilenetv2/base_model/extended_grid_4/results/lr_10**-2.40_wd_10**-2.20/checkpoint_final.th"
)
mobilenet_top3 = load_mobilenet(
    "/home/jonna/hyperparameter_sensitivity_pruning/experiments/cifar10/mobilenetv2/base_model/extended_grid_4/results/lr_10**-2.20_wd_10**-2.20/checkpoint_final.th"
)

print("Validate before combining")

# Evaluate each base model on its own, in the same order as before.
for base_model in (mobilenet_top1, mobilenet_top2, mobilenet_top3):
    validate(
        val_loader, torch.nn.DataParallel(base_model), nn.CrossEntropyLoss().cuda()
    )

Validate before combining
 * Prec@1 95.320
 * Prec@1 95.200
 * Prec@1 94.690


In [21]:
sd1 = mobilenet_top1.state_dict()
sd2 = mobilenet_top2.state_dict()
sd3 = mobilenet_top3.state_dict()


# Element-wise mean of the three checkpoints, collected in a fresh dict so
# none of the source state dicts is overwritten.
# Only floating-point tensors are averaged. Integer entries (e.g. BatchNorm's
# num_batches_tracked counters, if the model has any) are copied from the
# first model instead of being pushed through float division.
averaged_sd = {}
for key in sd1:
    if sd1[key].is_floating_point():
        averaged_sd[key] = (sd1[key] + sd2[key] + sd3[key]) / 3.
    else:
        averaged_sd[key] = sd1[key].clone()


# NOTE(review): the recorded result of this cell is chance level (10.000 on a
# 10-class task), although each source model scores ~95%. Presumably the three
# checkpoints come from independent training runs, in which case averaging
# their weights is not expected to yield a working network. Consider averaging
# the models' predictions instead, or re-estimating any normalisation
# statistics after averaging -- TODO confirm which combination is intended.

# Build a fresh model from the averaged weights and move it to the GPU like
# the source models. DataParallel expects the wrapped module's parameters on
# its first device.
model = mobilenetv2()
model.load_state_dict(averaged_sd)
model.cuda()
validate(
    val_loader, torch.nn.DataParallel(model), nn.CrossEntropyLoss().cuda()
)

 * Prec@1 10.000
