### Pruning ResNet-56 for CIFAR-10

#### In this notebook we...
- Load the ResNet-56 model we previously trained on CIFAR-10
- Prune 20% of the weights in every convolutional and fully connected layer using L1 unstructured pruning (easily switched to random unstructured pruning)
- Test the pruned model
- Save the pruned model to "./models/resnet_56_sparse.th"


#### Important Notes
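- Run the cells in order, starting from a fresh kernel. In particular, do not re-run the pruning cell on an already-pruned model: each re-run can remove additional weights (e.g. with random unstructured pruning).
- The execution counts in this saved copy are out of order: the pruning, evaluation and save cells ran before the loading cells were last executed. Restart the kernel and run all cells to reproduce the reported accuracies.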

In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torchsummary import summary
import torch.nn.utils.prune as prune

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from resnet import resnet56

In [19]:
criterion = nn.CrossEntropyLoss()

def test():
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0) 
            correct += predicted.eq(targets).sum().item()
    
    acc = 100.*correct/total
    print(acc)

In [20]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [21]:
print('==> Preparing data..')

# Per-channel normalisation statistics shared by both pipelines. Presumably these
# are the statistics used when the dense checkpoint was trained; keep them in sync.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)
DATA_ROOT = './data'

# Training pipeline: random crop + horizontal flip augmentation, then normalise.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion and normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# The training split is not used by the cells shown here (no fine-tuning step);
# it is kept available for retraining after pruning.
trainset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=0)

classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [22]:
PATH = './models/resnet_56_dense.th'
model = resnet56()
model = model.to(device)
# map_location lets a checkpoint saved from GPU tensors load on a CPU-only machine.
trained_model = torch.load(PATH, map_location=device)
state_dict = trained_model['state_dict']
if device == 'cuda':
    model = torch.nn.DataParallel(model)
    cudnn.benchmark = True
else:
    # The checkpoint keys carry DataParallel's 'module.' prefix (they match the
    # wrapped model exactly on CUDA). On CPU the model is not wrapped, so strip
    # the prefix to line the keys up with the bare model.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
model.load_state_dict(state_dict)

<All keys matched successfully>

In [23]:
# Baseline: top-1 accuracy (%) of the dense model on the CIFAR-10 test set, before pruning.
test()

92.33


In [13]:
# Pruning 20% of weights using l1 unstructured
# Fraction of weights to zero out in every prunable layer.
PRUNE_AMOUNT = 0.2
# Any torch.nn.utils.prune function with a (module, name, amount) signature works here,
# e.g. prune.random_unstructured instead of prune.l1_unstructured.
PRUNE_METHOD = prune.l1_unstructured
# Layer types whose 'weight' tensor is pruned (all 2D-conv and linear layers).
PRUNABLE_LAYERS = (torch.nn.Conv2d, torch.nn.Linear)

# NOTE: this mutates `model` in place. Re-running the cell on an already-pruned
# model can remove additional weights (e.g. with random_unstructured).
for module in model.modules():
    if isinstance(module, PRUNABLE_LAYERS):
        PRUNE_METHOD(module, name='weight', amount=PRUNE_AMOUNT)
        # Make the pruning permanent: fold the mask into 'weight' so the zeros are
        # stored in the plain parameter that state_dict() later saves.
        prune.remove(module, 'weight')

In [14]:
# Sanity check for sparsity:
# Sanity check: print the fraction of exactly-zero weights in each Conv2d / Linear layer,
# in module-traversal order.
for layer in model.modules():
    if not isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear)):
        continue
    weight = layer.weight
    zero_fraction = (weight == 0).sum().item() / weight.numel()
    print(zero_fraction)

0.19907407407407407
0.20008680555555555
0.20008680555555555
0.20008680555555555
0.20008680555555555
0.20008680555555555
0.20008680555555555
0.20008680555555555
0.20008680555555555
0.20008680555555555
0.20008680555555555
0.20008680555555555
0.20008680555555555
0.20008680555555555
0.20008680555555555
0.20008680555555555
0.20008680555555555
0.20008680555555555
0.20008680555555555
0.20008680555555555
0.1999782986111111
0.1999782986111111
0.1999782986111111
0.1999782986111111
0.1999782986111111
0.1999782986111111
0.1999782986111111
0.1999782986111111
0.1999782986111111
0.1999782986111111
0.1999782986111111
0.1999782986111111
0.1999782986111111
0.1999782986111111
0.1999782986111111
0.1999782986111111
0.1999782986111111
0.1999782986111111
0.2000054253472222
0.2000054253472222
0.2000054253472222
0.2000054253472222
0.2000054253472222
0.2000054253472222
0.2000054253472222
0.2000054253472222
0.2000054253472222
0.2000054253472222
0.2000054253472222
0.2000054253472222
0.2000054253472222
0.200005425

In [16]:
# Top-1 accuracy (%) of the pruned model on the CIFAR-10 test set; compare with the dense baseline.
test()

92.37


In [17]:
# Persist the pruned weights. Unlike the dense checkpoint (a dict holding the weights
# under 'state_dict'), this writes the bare state_dict. When the model is wrapped in
# DataParallel, its keys keep the 'module.' prefix.
SPARSE_PATH = './models/resnet_56_sparse.th'
pruned_state = model.state_dict()
torch.save(pruned_state, SPARSE_PATH)