In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.nn.utils import prune
import numpy as np

In [2]:
from utils.load_data import load_data
from utils.load_model import load_model
from utils.test_model import test_model

In [3]:
dataloaders, dataset_size, class_names = load_data('../data')

In [4]:
model = load_model('../models/resnet18_01')

<All keys matched successfully>


In [5]:
test_model(model, dataloaders['test'])

Accuracy of the model on 540 test images is 90.37


In [16]:
conv_idx = []
for i, m in enumerate(model.modules()):
    if isinstance(m, torch.nn.Conv2d):
        conv_idx.append(i+1)
conv_idx

[2, 8, 11, 14, 17, 21, 24, 27, 30, 33, 37, 40, 43, 46, 49, 53, 56, 59, 62, 65]

In [17]:
skip = [2, 8, 14, 16, 26, 28, 30]
layers_to_prune = [x for x in conv_idx if x not in skip]
cfg = []
cfg_mask = []
prune_prob = [0.6, 0.4, 0.2, 0.1, 0.1]
layer_id = 1

In [18]:
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        if m.kernel_size == (1,1):
            continue
        out_channels = m.weight.data.shape[0]
        if layer_id in skip:
            cfg_mask.append(torch.ones(out_channels))
            cfg.append(out_channels)
            layer_id += 1
            continue
        if layer_id % 2 == 0:
            if layer_id <= 6:
                stage = 0
            elif layer_id <= 14:
                stage = 1
            elif layer_id <= 26:
                stage = 2
            else:
                stage = 3
            prune_prob_stage = prune_prob[stage]
            weight_copy = m.weight.data.abs().clone().cpu().numpy()
            L1_norm = np.sum(weight_copy, axis=(1,2,3))
            num_keep = int(out_channels * (1 - prune_prob_stage))
            arg_max = np.argsort(L1_norm)
            arg_max_rev = arg_max[::-1][:num_keep]
            mask = torch.zeros(out_channels)
            mask[arg_max_rev.tolist()] = 1
            cfg_mask.append(mask)
            cfg.append(num_keep)
            layer_id += 1
            continue
    layer_id += 1


In [19]:
newmodel = load_model('../models/resnet18_01')

<All keys matched successfully>


In [20]:
start_mask = torch.ones(3)
layer_id_in_cfg = 0
conv_count = 0
batchnorm_layers = [x+1 for x in layers_to_prune]
len(cfg_mask)

9

In [21]:
for [m0, m1] in zip(model.modules(), newmodel.modules()):
    if isinstance(m0, nn.Conv2d):
        if m0.kernel_size == (1,1):
            # Cases for down-sampling convolution.
            m1.weight.data = m0.weight.data.clone()
            continue
        if conv_count == 1:
            m1.weight.data = m0.weight.data.clone()
            conv_count += 1
            continue
        if conv_count % 2 == 0:
            mask = cfg_mask[layer_id_in_cfg]
            idx = np.squeeze(np.argwhere(np.asarray(mask.cpu().numpy())))
            if idx.size == 1:
                idx = np.resize(idx, (1,))
            w = m0.weight.data[idx.tolist(), :, :, :].clone()
            m1.weight.data = w.clone()
            layer_id_in_cfg += 1
            conv_count += 1
            continue
        if conv_count % 2 == 1:
            mask = cfg_mask[layer_id_in_cfg-1]
            idx = np.squeeze(np.argwhere(np.asarray(mask.cpu().numpy())))
            if idx.size == 1:
                idx = np.resize(idx, (1,))
            w = m0.weight.data[:, idx.tolist(), :, :].clone()
            m1.weight.data = w.clone()
            conv_count += 1
            continue
    elif isinstance(m0, nn.BatchNorm2d):
        assert isinstance(m1, nn.BatchNorm2d), "There should not be bn layer here."
        if conv_count % 2 == 1:
            mask = cfg_mask[layer_id_in_cfg-1]
            idx = np.squeeze(np.argwhere(np.asarray(mask.cpu().numpy())))
            if idx.size == 1:
                idx = np.resize(idx, (1,))
            m1.weight.data = m0.weight.data[idx.tolist()].clone()
            m1.bias.data = m0.bias.data[idx.tolist()].clone()
            m1.running_mean = m0.running_mean[idx.tolist()].clone()
            m1.running_var = m0.running_var[idx.tolist()].clone()
            continue
        m1.weight.data = m0.weight.data.clone()
        m1.bias.data = m0.bias.data.clone()
        m1.running_mean = m0.running_mean.clone()
        m1.running_var = m0.running_var.clone()
    elif isinstance(m0, nn.Linear):
        m1.weight.data = m0.weight.data.clone()
        m1.bias.data = m0.bias.data.clone()

In [22]:
layer_id_in_cfg

9

In [23]:
from torchsummary import summary

In [24]:
summary(model, (3,224,244))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 122]           9,408
       BatchNorm2d-2         [-1, 64, 112, 122]             128
              ReLU-3         [-1, 64, 112, 122]               0
         MaxPool2d-4           [-1, 64, 56, 61]               0
            Conv2d-5           [-1, 64, 56, 61]          36,864
       BatchNorm2d-6           [-1, 64, 56, 61]             128
              ReLU-7           [-1, 64, 56, 61]               0
            Conv2d-8           [-1, 64, 56, 61]          36,864
       BatchNorm2d-9           [-1, 64, 56, 61]             128
             ReLU-10           [-1, 64, 56, 61]               0
       BasicBlock-11           [-1, 64, 56, 61]               0
           Conv2d-12           [-1, 64, 56, 61]          36,864
      BatchNorm2d-13           [-1, 64, 56, 61]             128
             ReLU-14           [-1, 64,

In [13]:
for i, m0 in enumerate(model.modules()):
    if isinstance(m0, nn.Conv2d):
        if m0.kernel_size == (1,1):
            print(i, m0.weight.data.shape[0])

26 128
42 256
58 512


In [25]:
test_model(newmodel, dataloaders['test'])

RuntimeError: running_mean should contain 128 elements not 102