Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import models, datasets
from torch.autograd import Variable
import shutil

import os
import numpy as np
import pandas as pd 

import matplotlib.pyplot as plt
import seaborn as sns

Hyperparameters

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

torch.cuda.manual_seed(1337)

batch_size = 100
test_batch_size = 1000

kwargs = {'num_workers': 16, 'pin_memory': True}

cuda


DataLoaders

In [3]:
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.Pad(4),
                       transforms.RandomCrop(32),
                       transforms.RandomHorizontalFlip(),
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ])),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ])),
    batch_size=test_batch_size, shuffle=True, **kwargs)


Files already downloaded and verified


Network Model

In [4]:
class sequential_model(nn.Module):
    def __init__(self, layers=None):
        super(sequential_model, self).__init__()
        if layers == None:
            layers = [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512]
        num_classes = 10
        self.feature = self.make_layers(layers)
        self.classifier = nn.Linear(layers[-1], num_classes)
    
    def make_layers(self, structure):
        layers = []
        in_channels = 3
        for v in structure:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.feature(x)
        x = nn.AvgPool2d(2)(x)
        x = x.view(x.size(0), -1)
        y = self.classifier(x)
        return y

Train Epoch method

In [5]:
def train(model, epoch, optimizer, data_loader=train_loader):
    model.train()
    for idx, (data, target) in enumerate(data_loader):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
            
        if idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
            epoch, idx * len(data), len(data_loader.dataset),
            100. * idx / len(data_loader), loss.data.item()))

Validation Method

In [6]:
def test(model, data_loader=test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in data_loader:
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)        
        output = model(data)
        test_loss += F.cross_entropy(output, target, size_average=False).data.item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
        test_loss /= len(data_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, len(data_loader.dataset),
        100. * correct / len(data_loader.dataset)))
    return correct / float(len(data_loader.dataset))   

Save Model Method

In [7]:
def save_checkpoint(state, is_best, filename='checkpoint_sr.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best_sr.pth.tar')

Train network method

In [8]:
def train_model(model, epochs=10):
    model.cuda()
    optimizer = optim.Adam(model.parameters())
    best_prec = 0.
    for i in range(0, epochs):
        train(model, i, optimizer)
        prec = test(model)
        is_best = prec > best_prec
        best_prec1 = max(prec, best_prec)
        save_checkpoint({
            'epoch': i + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec,
            'optimizer': optimizer.state_dict(),
        }, is_best)
    return model

Load existing Model method

In [9]:
def load_model(checkpoint_path="checkpoint_sr.pth.tar", model_path="model_best_sr.pth.tar"):
    model = sequential_model()
    model.cuda()
    if os.path.isfile(model_path):
        print("=> loading checkpoint '{}'".format(model_path))
        checkpoint_path = torch.load(model_path)
        best_prec1 = checkpoint_path['best_prec1']
        model.load_state_dict(checkpoint_path['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(model, checkpoint_path['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at")
    return model

Select weak channels

In [10]:
def selectChannels(model, percent=0.2):
    total = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            total += m.weight.data.shape[0]

    bn = torch.zeros(total)
    index = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            size = m.weight.data.shape[0]
            bn[index:(index+size)] = m.weight.data.abs().clone()
            index += size

    y, i = torch.sort(bn)
    thre_index = int(total * percent)
    thre = y[thre_index]

    pruned = 0
    cfg = []
    cfg_mask = []
    for k, m in enumerate(model.modules()):
        if isinstance(m, nn.BatchNorm2d):
            weight_copy = m.weight.data.clone()
            print(type(weight_copy.abs().gt(thre).float()))
            #mask is a matrix in which 1 marks the channels which are kept and 0 marks the pruned channels
            mask = weight_copy.abs().gt(thre).float().cuda()          
            #pruned is the number of all pruned channels 
            pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            cfg.append(int(torch.sum(mask)))
            cfg_mask.append(mask.clone())
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
                format(k, mask.shape[0], int(torch.sum(mask))))
        elif isinstance(m, nn.MaxPool2d):
            cfg.append('M')
    return cfg, cfg_mask

Build new model and transfer weights from full model to build the new pruned model

In [11]:
def transfer_params(cfg, cfg_mask, model):
    newmodel = sequential_model(layers=cfg)
    newmodel.cuda() 

    layer_id_in_cfg = 0
    start_mask = torch.ones(3)
    end_mask = cfg_mask[layer_id_in_cfg]
    for [m0, m1] in zip(model.modules(), newmodel.modules()):
        if isinstance(m0, nn.BatchNorm2d):
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            m1.weight.data = m0.weight.data[idx1].clone()
            m1.bias.data = m0.bias.data[idx1].clone()
            m1.running_mean = m0.running_mean[idx1].clone()
            m1.running_var = m0.running_var[idx1].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                end_mask = cfg_mask[layer_id_in_cfg]
        elif isinstance(m0, nn.Conv2d):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print('In shape: {:d} Out shape:{:d}'.format(idx0.shape[0], idx1.shape[0]))
            w = m0.weight.data[:, idx0, :, :].clone()
            w = w[idx1, :, :, :].clone()
            m1.weight.data = w.clone()
            # m1.bias.data = m0.bias.data[idx1].clone()
        elif isinstance(m0, nn.Linear):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            m1.weight.data = m0.weight.data[:, idx0].clone()    
    return newmodel

Prune trained model

In [12]:
def prune_model(model):
    cfg, cfg_mask = selectChannels(model)
    prune_model = transfer_params(cfg, cfg_mask, model)
    torch.save({'cfg': cfg, 'state_dict': prune_model.state_dict()}, f='pruned_model.pt')

In [13]:
model = train_model(sequential_model())



  data, target = Variable(data, volatile=True), Variable(target)



Test set: Average loss: 0.1269, Accuracy: 5531/10000 (55.3%)


Test set: Average loss: 0.0870, Accuracy: 6819/10000 (68.2%)


Test set: Average loss: 0.0689, Accuracy: 7500/10000 (75.0%)


Test set: Average loss: 0.0608, Accuracy: 7953/10000 (79.5%)


Test set: Average loss: 0.0589, Accuracy: 7949/10000 (79.5%)


Test set: Average loss: 0.0531, Accuracy: 8338/10000 (83.4%)


Test set: Average loss: 0.0495, Accuracy: 8415/10000 (84.2%)


Test set: Average loss: 0.0410, Accuracy: 8527/10000 (85.3%)


Test set: Average loss: 0.0456, Accuracy: 8453/10000 (84.5%)


Test set: Average loss: 0.0376, Accuracy: 8586/10000 (85.9%)



In [14]:
prune_model(model)

<class 'torch.Tensor'>
layer index: 3 	 total channel: 64 	 remaining channel: 26
<class 'torch.Tensor'>
layer index: 6 	 total channel: 64 	 remaining channel: 56
<class 'torch.Tensor'>
layer index: 10 	 total channel: 128 	 remaining channel: 112
<class 'torch.Tensor'>
layer index: 13 	 total channel: 128 	 remaining channel: 106
<class 'torch.Tensor'>
layer index: 17 	 total channel: 256 	 remaining channel: 196
<class 'torch.Tensor'>
layer index: 20 	 total channel: 256 	 remaining channel: 187
<class 'torch.Tensor'>
layer index: 24 	 total channel: 512 	 remaining channel: 376
<class 'torch.Tensor'>
layer index: 27 	 total channel: 512 	 remaining channel: 427
<class 'torch.Tensor'>
layer index: 31 	 total channel: 512 	 remaining channel: 441
<class 'torch.Tensor'>
layer index: 34 	 total channel: 512 	 remaining channel: 428
In shape: 3 Out shape:26
In shape: 26 Out shape:56
In shape: 56 Out shape:112
In shape: 112 Out shape:106
In shape: 106 Out shape:196
In shape: 196 Out shap

In [15]:
safed = torch.load('pruned_model.pt')
structure = safed['cfg']
weights = safed['state_dict']
pruned_model = sequential_model(structure)
pruned_model.load_state_dict(weights)
pruned_model.cuda()
test(pruned_model)

  data, target = Variable(data, volatile=True), Variable(target)



Test set: Average loss: 0.2018, Accuracy: 3968/10000 (39.7%)



tensor(0.3968)

In [None]:
fine_tuned_model = train_model(pruned_model, epochs=3)
test(fine_tuned_model)