In [0]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.nn.utils.prune as prune

from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import numpy as np
import os
import seaborn as sns

In [0]:
# Import dataset (MNIST).

# Transforms which will be applied to the data.
transform = transforms.Compose([transforms.ToTensor(),
                                ])

# Split the train dataset into a train + valid datasets.
# Must set the values of the samples in each split (here, 50000, 10000).
dataset = datasets.MNIST(root='../data', train=True, download=True, transform=transform)
train_set, valid_set = torch.utils.data.random_split(dataset, [50000, 10000])

# Load the test dataset.
test_set = datasets.MNIST(root='../data', train=False, transform=transform)

In [0]:
# Transformations will not be applied until you call a DataLoader on it.
train_loader = torch.utils.data.DataLoader(train_set, batch_size = 32, shuffle=True, num_workers=0, drop_last=True)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=32, shuffle=False, num_workers=0, drop_last=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size = 32, shuffle=False, num_workers=0, drop_last=True)

In [0]:
# Create a neural net.
class LeNetTrash(nn.Module):
    def __init__(self):
        super(LeNetTrash, self).__init__()
        self.fc1 = nn.Linear(784, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc3 = nn.Linear(100, 10)
    
    def forward(self, input):
        x = input.flatten(start_dim=1, end_dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [54]:
# Instantiate the model.
net = LeNetTrash()
print(net)

LeNetTrash(
  (fc1): Linear(in_features=784, out_features=300, bias=True)
  (fc2): Linear(in_features=300, out_features=100, bias=True)
  (fc3): Linear(in_features=100, out_features=10, bias=True)
)


In [55]:
# Initialise the model weights.
def weight_init(m):
    if(isinstance(m, nn.Linear)):
        init.xavier_normal_(m.weight.data)
        init.normal_(m.bias.data)
    # Other branches if other types of layers are present.

# Apply the weight initialization function recursively for each layer of the net.
net.apply(weight_init)

LeNetTrash(
  (fc1): Linear(in_features=784, out_features=300, bias=True)
  (fc2): Linear(in_features=300, out_features=100, bias=True)
  (fc3): Linear(in_features=100, out_features=10, bias=True)
)

In [0]:
# Code for training one epoch (one pass through the dataset).
def train_one_epoch(model, train_loader, optimizer, criterion):
    model.train() #Sets nn.Module in train mode (has effects only on some models)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for batch_idx, (imgs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        imgs, targets = imgs.to(device), targets.to(device)

        output = model(imgs)
        
        train_loss = criterion(output, targets)
        train_loss.backward()

        # In original, gradient of the pruned nodes were made 0.
        optimizer.step()
    
    # train_loss is a tensor with one value;
    # tensor.item() returns the value held by a tensor with one value;
    return train_loss.item()

In [0]:
def calculate_accuracy_and_loss(model, loader, criterion):
    # Put the model in evaluation mode.
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    nr_batches = len(loader)
    total_loss = 0
    accuracy = 0

    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (imgs, targets) in enumerate(loader):
            imgs, targets = imgs.to(device), targets.to(device)

            output = model(imgs)

            total_loss += (1/nr_batches) * criterion(output, targets)

            _, predicted = torch.max(output.data, 1)
            total += targets.shape[0]
            correct += (predicted == targets).sum().item()

    accuracy = correct/total * 100
    return accuracy, total_loss

In [0]:
def save_checkpoint(PATH, epoch, model, optimizer):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, PATH)

# model, optimizer must be instantiated objects of the same respective types
# as the model/optimizer that were saved.
def load_checkpoint(PATH, model, optimizer):
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    return model, optimizer, epoch

def remove_checkpoint(PATH):
    if(os.path.exists(PATH)):
        os.remove(PATH)

In [59]:
def get_best_model_PATH(experiment_folder, best_model_filename, prune_iter):
    return experiment_folder + '/' + best_model_filename + '_prune_iter_' + str(prune_iter) + '.tar'

print(get_best_model_PATH('experiments/FC_test_run', 'best_model', 0))

experiments/FC_test_run/best_model_prune_iter_0.tar


In [0]:
# Define a pruning strategy for the layers of the model.
# prune_methods must contain a mapping from modules names to pruning methods.
# prune_args must contain a mapping from modules names to the arguments of the
# respective pruning methods.
def prune_model(model, prune_method, prune_args):
    for name, module in model.named_modules():
        prune_method[name](module, **prune_args[name])

In [61]:
def restore_orig_weights(model):
    for name, module in model.named_modules():
        
        #print(list(module.named_parameters()))
        #print(dict(module.named_buffers()).keys())
        if(name == ''):
            continue
        print('\n\n\n')
        print(name + str(type(module)))
        print(module.weight)
        #for parameter_name, parameter in module.named_parameters():
        #    print(parameter_name + ' ' + str(type(parameter)))
restore_orig_weights(net)





fc1<class 'torch.nn.modules.linear.Linear'>
Parameter containing:
tensor([[ 0.0222,  0.0469,  0.0494,  ...,  0.0173, -0.0053,  0.0220],
        [ 0.0992,  0.0046, -0.0264,  ..., -0.0449, -0.0249,  0.0866],
        [-0.0044,  0.0049, -0.0662,  ..., -0.0534, -0.0521,  0.0599],
        ...,
        [ 0.0701, -0.0238, -0.0836,  ...,  0.0071,  0.0211, -0.0638],
        [-0.0789,  0.0138, -0.0051,  ...,  0.0111, -0.0370,  0.0138],
        [ 0.0094,  0.0133, -0.0426,  ...,  0.0109, -0.0360,  0.0143]],
       requires_grad=True)




fc2<class 'torch.nn.modules.linear.Linear'>
Parameter containing:
tensor([[ 4.5328e-02, -5.3506e-02,  5.0311e-03,  ..., -1.9942e-02,
         -2.8242e-02,  8.6791e-02],
        [ 2.4255e-02,  1.2922e-02,  6.4602e-02,  ..., -1.9510e-02,
         -8.2361e-02, -1.2523e-01],
        [-8.3200e-02,  9.2748e-02,  1.2278e-01,  ..., -1.8978e-02,
         -1.0114e-01, -6.0964e-03],
        ...,
        [ 1.8643e-01, -9.5648e-03, -5.3288e-02,  ..., -1.3616e-02,
          

In [14]:
print(net.state_dict().keys())

odict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])


In [15]:
list(net.parameters())

[Parameter containing:
 tensor([[-0.0009,  0.0425,  0.0261,  ...,  0.0066,  0.0240, -0.0359],
         [ 0.0507, -0.0170, -0.0142,  ...,  0.0015, -0.0572, -0.0091],
         [-0.0427, -0.0477,  0.0192,  ..., -0.0263,  0.0187,  0.0216],
         ...,
         [-0.0842, -0.0107,  0.0155,  ...,  0.0940,  0.1008, -0.0216],
         [ 0.0334, -0.0214,  0.0110,  ..., -0.0394,  0.0147,  0.0060],
         [ 0.0098,  0.0278,  0.0219,  ..., -0.0508, -0.0205,  0.0089]],
        requires_grad=True), Parameter containing:
 tensor([ 0.1604,  1.0814,  1.2004, -0.9492, -1.2613, -0.6145, -0.2033,  1.0035,
         -1.1604,  0.3370, -0.0480,  0.7957,  0.5526,  1.7292,  0.6536, -0.1919,
          1.3983,  1.0776,  0.4900,  0.4933, -1.1510, -1.2824, -0.2843,  1.6931,
          1.3025, -0.1345, -0.4162,  0.6173, -0.3863,  1.1685,  0.0172, -1.1021,
          1.1570, -1.2995,  1.4132,  0.3960,  0.1702, -0.3318, -0.4430, -1.6753,
          0.5717, -1.2890, -1.5643, -1.8697,  0.8611, -0.0628,  0.3078,  0.0175,

In [16]:
print(list(net.named_modules()))

[('', LeNetTrash(
  (fc1): Linear(in_features=784, out_features=300, bias=True)
  (fc2): Linear(in_features=300, out_features=100, bias=True)
  (fc3): Linear(in_features=100, out_features=10, bias=True)
)), ('fc1', Linear(in_features=784, out_features=300, bias=True)), ('fc2', Linear(in_features=300, out_features=100, bias=True)), ('fc3', Linear(in_features=100, out_features=10, bias=True))]


In [62]:
optimizer = torch.optim.Adam(net.parameters(), weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

train_one_epoch(net, train_loader, optimizer, criterion)

0.0746472030878067

In [63]:
prune.l1_unstructured(net.fc1, 'weight', 0.1)
prune.l1_unstructured(net.fc2, 'weight', 0.1)
prune.l1_unstructured(net.fc3, 'weight', 0.1)

Linear(in_features=100, out_features=10, bias=True)

In [19]:
print(net.state_dict().keys())

odict_keys(['fc1.bias', 'fc1.weight_orig', 'fc1.weight_mask', 'fc2.bias', 'fc2.weight_orig', 'fc2.weight_mask', 'fc3.bias', 'fc3.weight_orig', 'fc3.weight_mask'])


In [66]:
for parameter in net.fc1.named_parameters():
    print(parameter[0])

for parameter in net.fc1.parameters():
    print(parameter.shape)

print(net.fc1.weight.shape)

bias
weight_orig
torch.Size([300])
torch.Size([300, 784])
torch.Size([300, 784])


In [0]:
print(net.fc1.weight_orig)

Parameter containing:
tensor([[ 1.2672e-39,  1.0505e-37,  4.4345e-39,  ...,  1.4049e-39,
          1.1837e-37, -4.2345e-39],
        [ 3.0337e-37,  1.5924e-37, -5.6929e-39,  ..., -3.0686e-39,
         -1.0979e-37, -4.8788e-39],
        [-5.9204e-39,  2.1624e-39,  6.6162e-39,  ..., -1.0876e-39,
         -1.8000e-39,  1.2207e-37],
        ...,
        [ 4.0793e-39, -3.6856e-40,  4.0646e-38,  ..., -3.0056e-39,
         -1.2159e-37,  4.5388e-39],
        [ 1.3105e-38, -3.9764e-39,  4.1318e-39,  ..., -4.8299e-39,
         -1.6280e-37,  2.9855e-40],
        [ 9.4934e-40,  1.7330e-39, -2.0795e-38,  ..., -2.8807e-39,
         -3.9555e-39,  2.9408e-39]], requires_grad=True)


In [0]:
weight_orig_clone1 = net.fc1.weight_orig.detach().clone()
weight_clone1 = net.fc1.weight.detach().clone()

In [0]:
print(weight_orig_clone1)

tensor([[ 1.2672e-39,  1.0505e-37,  4.4345e-39,  ...,  1.4049e-39,
          1.1837e-37, -4.2345e-39],
        [ 3.0337e-37,  1.5924e-37, -5.6929e-39,  ..., -3.0686e-39,
         -1.0979e-37, -4.8788e-39],
        [-5.9204e-39,  2.1624e-39,  6.6162e-39,  ..., -1.0876e-39,
         -1.8000e-39,  1.2207e-37],
        ...,
        [ 4.0793e-39, -3.6856e-40,  4.0646e-38,  ..., -3.0056e-39,
         -1.2159e-37,  4.5388e-39],
        [ 1.3105e-38, -3.9764e-39,  4.1318e-39,  ..., -4.8299e-39,
         -1.6280e-37,  2.9855e-40],
        [ 9.4934e-40,  1.7330e-39, -2.0795e-38,  ..., -2.8807e-39,
         -3.9555e-39,  2.9408e-39]])


In [0]:
print(weight_clone1)

tensor([[ 0.0000e+00,  1.0505e-37,  4.4345e-39,  ...,  0.0000e+00,
          1.1837e-37, -4.2345e-39],
        [ 0.0000e+00,  0.0000e+00, -5.6929e-39,  ..., -0.0000e+00,
         -1.0979e-37, -4.8788e-39],
        [-5.9204e-39,  0.0000e+00,  6.6162e-39,  ..., -0.0000e+00,
         -0.0000e+00,  0.0000e+00],
        ...,
        [ 4.0793e-39, -0.0000e+00,  4.0646e-38,  ..., -0.0000e+00,
         -0.0000e+00,  0.0000e+00],
        [ 0.0000e+00, -3.9764e-39,  4.1318e-39,  ..., -4.8299e-39,
         -1.6280e-37,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
         -3.9555e-39,  0.0000e+00]])


In [0]:
train_one_epoch(net, train_loader, optimizer, criterion)

0.015482593327760696

In [0]:
print(weight_orig_clone1)

tensor([[ 1.2672e-39,  1.0505e-37,  4.4345e-39,  ...,  1.4049e-39,
          1.1837e-37, -4.2345e-39],
        [ 3.0337e-37,  1.5924e-37, -5.6929e-39,  ..., -3.0686e-39,
         -1.0979e-37, -4.8788e-39],
        [-5.9204e-39,  2.1624e-39,  6.6162e-39,  ..., -1.0876e-39,
         -1.8000e-39,  1.2207e-37],
        ...,
        [ 4.0793e-39, -3.6856e-40,  4.0646e-38,  ..., -3.0056e-39,
         -1.2159e-37,  4.5388e-39],
        [ 1.3105e-38, -3.9764e-39,  4.1318e-39,  ..., -4.8299e-39,
         -1.6280e-37,  2.9855e-40],
        [ 9.4934e-40,  1.7330e-39, -2.0795e-38,  ..., -2.8807e-39,
         -3.9555e-39,  2.9408e-39]])


In [0]:
print(net.fc1.weight_orig)

Parameter containing:
tensor([[ 1.2672e-39, -1.7521e-37,  4.4345e-39,  ...,  1.4049e-39,
          1.1837e-37, -4.2345e-39],
        [ 3.0337e-37,  1.9106e-38, -5.6929e-39,  ..., -3.0686e-39,
          1.7047e-37, -4.8788e-39],
        [-5.9204e-39,  2.1624e-39,  6.6162e-39,  ..., -1.0876e-39,
         -1.8000e-39, -1.8060e-38],
        ...,
        [ 4.0793e-39, -3.6856e-40,  4.0646e-38,  ..., -3.0056e-39,
          1.8539e-38,  4.5388e-39],
        [ 1.3105e-38, -3.9764e-39,  4.1318e-39,  ..., -4.8299e-39,
         -1.6280e-37,  2.9855e-40],
        [ 9.4934e-40,  1.7330e-39,  1.1934e-37,  ..., -2.8807e-39,
         -3.9555e-39,  2.9408e-39]], requires_grad=True)


In [0]:
print(weight_orig_clone1.eq(net.fc1.weight_orig))
print(weight_clone1.eq(net.fc1.weight))

tensor([[ True, False,  True,  ...,  True,  True,  True],
        [ True, False,  True,  ...,  True, False,  True],
        [ True,  True,  True,  ...,  True,  True, False],
        ...,
        [ True,  True,  True,  ...,  True, False,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True, False,  ...,  True,  True,  True]])
tensor([[ True, False,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True, False,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        ...,
        [ True,  True, False,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True, False,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])


In [0]:
optimizer.state_dict().keys()

dict_keys(['state', 'param_groups'])

In [0]:
# Initialise loss function.
criterion = nn.CrossEntropyLoss()

# Testing training a few epochs, saving the best model then loading it.
pruning_iterations = 5
epochs = 20

# Frequency for getting train/validation losses + (saving the best model).
test_freq = 2

# Experiments results + logs will be saved in the "experiments" folder.
base_dir = "pruning_experiments"
experiment_name = "FC_test_run"
best_model_filename = "best_model"

# Initialise a SummaryWriter in order to use tensorboard + log relevant results.
writer = SummaryWriter(experiment_PATH)

for pruning_iter in range(pruning_iterations):

    experiment_PATH = base_dir + '/' + experiment_name + '/pruning_' + str(pruning_iter)
    os.makedirs(experiment_PATH, exist_ok=True)

    print('\n\nStarting pruning iteration: ' + str(pruning_iter) + '\n')

    if(pruning_iter != 0):


    # Initialize optimizer:
    optimizer = torch.optim.Adam(net.parameters(), weight_decay=1e-4)

    best_accuracy = 0
    for epoch in range(epochs):
        train_one_epoch(net, train_loader, optimizer, criterion)

        if epoch % test_freq == 0:
            train_acc, train_loss = calculate_accuracy_and_loss(net, train_loader, criterion)
            valid_acc, valid_loss = calculate_accuracy_and_loss(net, valid_loader, criterion)

            writer.add_scalar('Accuracy/train', train_acc, epoch)
            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Accuracy/valid', valid_acc, epoch)
            writer.add_scalar('Loss/valid', valid_loss, epoch)

            if(valid_acc > best_accuracy):
                best_accuracy = valid_acc
                PATH = get_best_model_PATH(experiment_PATH, best_model_filename, 0)
                remove_checkpoint(PATH)
                save_checkpoint(PATH, epoch, net, optimizer)

            print('Epoch: ' + str(epoch) +  ', Train loss: {:.4f}, Train Acc: {:.2f}, Valid loss: {:.4f}, Valid Acc: {:.2f}'.format(train_loss, train_acc, valid_loss, valid_acc))
    





  0%|          | 0/10 [00:00<?, ?it/s][A[A[A[A



Epoch: 0, Train loss: 0.0379, Train Acc: 98.72, Valid loss: 0.0379, Valid Acc: 98.72:   0%|          | 0/10 [00:54<?, ?it/s][A[A[A[A



Epoch: 0, Train loss: 0.0379, Train Acc: 98.72, Valid loss: 0.0379, Valid Acc: 98.72:  10%|█         | 1/10 [00:54<08:12, 54.68s/it][A[A[A[A



Epoch: 0, Train loss: 0.0379, Train Acc: 98.72, Valid loss: 0.0379, Valid Acc: 98.72:  20%|██        | 2/10 [01:17<06:00, 45.08s/it][A[A[A[A



Epoch: 2, Train loss: 0.0342, Train Acc: 98.86, Valid loss: 0.0342, Valid Acc: 98.86:  20%|██        | 2/10 [01:55<06:00, 45.08s/it][A[A[A[A



Epoch: 2, Train loss: 0.0342, Train Acc: 98.86, Valid loss: 0.0342, Valid Acc: 98.86:  30%|███       | 3/10 [01:55<04:59, 42.85s/it][A[A[A[A



Epoch: 2, Train loss: 0.0342, Train Acc: 98.86, Valid loss: 0.0342, Valid Acc: 98.86:  40%|████      | 4/10 [02:14<03:34, 35.83s/it][A[A[A[A



Epoch: 4, Train loss: 0.0271, Train Acc: 99.13, Valid loss: 0.

In [0]:
print(experiment_PATH)

experiments/FC_test_run


In [0]:
# Start a tensorboard session.
%load_ext tensorboard
%tensorboard --logdir experiments/FC_test_run

In [0]:
for prune_iteration in range(nr_prune_iterations):
    if(prune_iteration != 0):
        # Need to save original parameters of the model.
        # Prune parameters.
        # Re-initialise parameters.
        # Re-initialize optimizer (add only the parameters that learn).
        # Something to do with the device?
        pass
    
    for train_iteration in range(train_iterations):
        # If train_iteration... record the accuracy/loss.
        # If it's the best accuracy, save the model + optimiser parameters.
        # Train the model.
        # Need a train function (for one iteration?) + test function for evaluating...
    # Make figures, save them.

# TODO: use pruningmethod for weight magnitudes then for gradients.
# Apparently, you need to take into account the device you are working with
# (CPU vs GPU) when doing this shit.

SyntaxError: ignored

In [0]:
nr_prune_iterations = 10
epochs = 100

best_loss = np.inf
train_losses = np.zeros((epochs, ))
test_losses = np.zeros((epochs, ))
for epoch in range(epochs):
    

0.407292902469635

In [0]:
def report_test_loss(model, test_loader, optimizer, criterion):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    

In [0]:
net.state_dict

<bound method Module.state_dict of LeNetTrash(
  (fc1): Linear(in_features=784, out_features=300, bias=True)
  (fc2): Linear(in_features=300, out_features=100, bias=True)
  (fc3): Linear(in_features=100, out_features=10, bias=True)
)>

In [0]:
print(net.fc1.weight.data.numpy())

[[ 7.3078976e-40 -3.5359847e-39 -5.1571147e-40 ... -1.5360235e-39
   5.1942575e-39 -2.8576764e-39]
 [-4.2151913e-39 -5.1506715e-39 -2.7939061e-39 ... -2.4624598e-39
  -3.9547263e-39  3.9824174e-39]
 [ 1.9701430e-39 -4.2787303e-39 -4.1677834e-38 ...  1.7528899e-39
   7.7572408e-39  5.2655611e-39]
 ...
 [ 6.2717579e-39  2.4628928e-39 -7.6535122e-38 ... -2.7414684e-38
   1.0152351e-39 -1.1631659e-34]
 [ 2.4143823e-38  3.6544967e-39 -8.4877349e-39 ...  7.0279462e-40
  -5.5321960e-39 -2.6077940e-39]
 [ 2.4695559e-39  5.0258018e-39 -4.5261940e-42 ...  6.2589556e-40
  -5.3935642e-39  2.2270416e-40]]


In [0]:
optimizer.

<bound method Optimizer.state_dict of Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0.0001
)>

In [0]:
# -*- coding: utf-8 -*-
"""
Pruning Tutorial
=====================================
**Author**: `Michela Paganini <https://github.com/mickypaganini>`_

State-of-the-art deep learning techniques rely on over-parametrized models 
that are hard to deploy. On the contrary, biological neural networks are 
known to use efficient sparse connectivity. Identifying optimal  
techniques to compress models by reducing the number of parameters in them is 
important in order to reduce memory, battery, and hardware consumption without 
sacrificing accuracy, deploy lightweight models on device, and guarantee 
privacy with private on-device computation. On the research front, pruning is 
used to investigate the differences in learning dynamics between 
over-parametrized and under-parametrized networks, to study the role of lucky 
sparse subnetworks and initializations
("`lottery tickets <https://arxiv.org/abs/1803.03635>`_") as a destructive 
neural architecture search technique, and more.

In this tutorial, you will learn how to use ``torch.nn.utils.prune`` to 
sparsify your neural networks, and how to extend it to implement your 
own custom pruning technique.

Requirements
------------
``"torch>=1.4.0a0+8e8a5e0"``

"""
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

######################################################################
# Create a model
# --------------
#
# In this tutorial, we use the `LeNet 
# <http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf>`_ architecture from 
# LeCun et al., 1998.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

In [0]:
######################################################################
# Inspect a Module
# ----------------
# 
# Let's inspect the (unpruned) ``conv1`` layer in our LeNet model. It will contain two 
# parameters ``weight`` and ``bias``, and no buffers, for now.
module = model.conv1
print(list(module.named_parameters()))

[('weight', Parameter containing:
tensor([[[[ 0.2138,  0.1507,  0.1297],
          [-0.0887, -0.3002, -0.2179],
          [ 0.1616, -0.2809, -0.2802]]],


        [[[ 0.1787, -0.2232,  0.0195],
          [ 0.2123, -0.2549,  0.2789],
          [ 0.2916, -0.1297,  0.2314]]],


        [[[-0.1468, -0.1988, -0.0314],
          [ 0.2852,  0.0238,  0.1755],
          [-0.0203,  0.2721, -0.1099]]],


        [[[-0.1998,  0.1009, -0.0759],
          [-0.0180,  0.0602, -0.2492],
          [ 0.0734, -0.2812, -0.0847]]],


        [[[-0.0311,  0.3279, -0.1438],
          [-0.2013, -0.1219, -0.0993],
          [ 0.1979,  0.0930, -0.0901]]],


        [[[ 0.0004, -0.2402, -0.1435],
          [ 0.2260,  0.0478,  0.1106],
          [ 0.0376,  0.0913, -0.1595]]]], requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.2093, -0.2762,  0.2606, -0.2037,  0.1396, -0.3161],
       requires_grad=True))]


In [0]:
######################################################################
print(list(module.named_buffers()))

[]


In [0]:
######################################################################
# Pruning a Module
# ----------------
# 
# To prune a module (in this example, the ``conv1`` layer of our LeNet 
# architecture), first select a pruning technique among those available in 
# ``torch.nn.utils.prune`` (or
# `implement <#extending-torch-nn-utils-pruning-with-custom-pruning-functions>`_
# your own by subclassing 
# ``BasePruningMethod``). Then, specify the module and the name of the parameter to 
# prune within that module. Finally, using the adequate keyword arguments 
# required by the selected pruning technique, specify the pruning parameters.
#
# In this example, we will prune at random 30% of the connections in 
# the parameter named ``weight`` in the ``conv1`` layer.
# The module is passed as the first argument to the function; ``name`` 
# identifies the parameter within that module using its string identifier; and 
# ``amount`` indicates either the percentage of connections to prune (if it 
# is a float between 0. and 1.), or the absolute number of connections to 
# prune (if it is a non-negative integer).
prune.random_unstructured(module, name="weight", amount=0.3) 

######################################################################
# Pruning acts by removing ``weight`` from the parameters and replacing it with 
# a new parameter called ``weight_orig`` (i.e. appending ``"_orig"`` to the 
# initial parameter ``name``). ``weight_orig`` stores the unpruned version of 
# the tensor. The ``bias`` was not pruned, so it will remain intact.
print(list(module.named_parameters()))

######################################################################
# The pruning mask generated by the pruning technique selected above is saved 
# as a module buffer named ``weight_mask`` (i.e. appending ``"_mask"`` to the 
# initial parameter ``name``).
print(list(module.named_buffers()))

[('bias', Parameter containing:
tensor([ 0.2093, -0.2762,  0.2606, -0.2037,  0.1396, -0.3161],
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.2138,  0.1507,  0.1297],
          [-0.0887, -0.3002, -0.2179],
          [ 0.1616, -0.2809, -0.2802]]],


        [[[ 0.1787, -0.2232,  0.0195],
          [ 0.2123, -0.2549,  0.2789],
          [ 0.2916, -0.1297,  0.2314]]],


        [[[-0.1468, -0.1988, -0.0314],
          [ 0.2852,  0.0238,  0.1755],
          [-0.0203,  0.2721, -0.1099]]],


        [[[-0.1998,  0.1009, -0.0759],
          [-0.0180,  0.0602, -0.2492],
          [ 0.0734, -0.2812, -0.0847]]],


        [[[-0.0311,  0.3279, -0.1438],
          [-0.2013, -0.1219, -0.0993],
          [ 0.1979,  0.0930, -0.0901]]],


        [[[ 0.0004, -0.2402, -0.1435],
          [ 0.2260,  0.0478,  0.1106],
          [ 0.0376,  0.0913, -0.1595]]]], requires_grad=True))]
[('weight_mask', tensor([[[[0., 0., 1.],
          [1., 0., 1.],
          [1., 1., 1.]]],


In [0]:
######################################################################
# For the forward pass to work without modification, the ``weight`` attribute 
# needs to exist. The pruning techniques implemented in 
# ``torch.nn.utils.prune`` compute the pruned version of the weight (by 
# combining the mask with the original parameter) and store them in the 
# attribute ``weight``. Note, this is no longer a parameter of the ``module``,
# it is now simply an attribute.
print(module.weight)

tensor([[[[ 0.0000,  0.0000,  0.1297],
          [-0.0887, -0.0000, -0.2179],
          [ 0.1616, -0.2809, -0.2802]]],


        [[[ 0.1787, -0.0000,  0.0000],
          [ 0.0000, -0.2549,  0.2789],
          [ 0.2916, -0.0000,  0.2314]]],


        [[[-0.1468, -0.0000, -0.0314],
          [ 0.2852,  0.0238,  0.1755],
          [-0.0203,  0.2721, -0.1099]]],


        [[[-0.1998,  0.1009, -0.0759],
          [-0.0180,  0.0602, -0.2492],
          [ 0.0734, -0.0000, -0.0847]]],


        [[[-0.0311,  0.3279, -0.0000],
          [-0.2013, -0.0000, -0.0993],
          [ 0.1979,  0.0930, -0.0901]]],


        [[[ 0.0000, -0.2402, -0.0000],
          [ 0.0000,  0.0478,  0.0000],
          [ 0.0376,  0.0913, -0.0000]]]], grad_fn=<MulBackward0>)


In [0]:
######################################################################
# Finally, pruning is applied prior to each forward pass using PyTorch's
# ``forward_pre_hooks``. Specifically, when the ``module`` is pruned, as we 
# have done here, it will acquire a ``forward_pre_hook`` for each parameter 
# associated with it that gets pruned. In this case, since we have so far 
# only pruned the original parameter named ``weight``, only one hook will be
# present.
print(module._forward_pre_hooks)

OrderedDict([(28, <torch.nn.utils.prune.RandomUnstructured object at 0x7f8cda9824a8>)])


In [0]:
######################################################################
# For completeness, we can now prune the ``bias`` too, to see how the 
# parameters, buffers, hooks, and attributes of the ``module`` change.
# Just for the sake of trying out another pruning technique, here we prune the 
# 3 smallest entries in the bias by L1 norm, as implemented in the 
# ``l1_unstructured`` pruning function.
prune.l1_unstructured(module, name="bias", amount=3)

######################################################################
# We now expect the named parameters to include both ``weight_orig`` (from 
# before) and ``bias_orig``. The buffers will include ``weight_mask`` and 
# ``bias_mask``. The pruned versions of the two tensors will exist as 
# module attributes, and the module will now have two ``forward_pre_hooks``.
print(list(module.named_parameters()))

######################################################################
print(list(module.named_buffers()))

######################################################################
print(module.bias)

######################################################################
print(module._forward_pre_hooks)

######################################################################
# Iterative Pruning
# -----------------
# 
# The same parameter in a module can be pruned multiple times, with the 
# effect of the various pruning calls being equal to the combination of the
# various masks applied in series.
# The combination of a new mask with the old mask is handled by the 
# ``PruningContainer``'s ``compute_mask`` method.
#
# Say, for example, that we now want to further prune ``module.weight``, this
# time using structured pruning along the 0th axis of the tensor (the 0th axis 
# corresponds to the output channels of the convolutional layer and has 
# dimensionality 6 for ``conv1``), based on the channels' L2 norm. This can be 
# achieved using the ``ln_structured`` function, with ``n=2`` and ``dim=0``.
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this will zero out all the connections corresponding to 
# 50% (3 out of 6) of the channels, while preserving the action of the 
# previous mask.
print(module.weight)

############################################################################
# The corresponding hook will now be of type 
# ``torch.nn.utils.prune.PruningContainer``, and will store the history of 
# pruning applied to the ``weight`` parameter.
for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container 

######################################################################
# Serializing a pruned model
# --------------------------
# All relevant tensors, including the mask buffers and the original parameters
# used to compute the pruned tensors are stored in the model's ``state_dict`` 
# and can therefore be easily serialized and saved, if needed.
print(model.state_dict().keys())

[('weight_orig', Parameter containing:
tensor([[[[ 0.2138,  0.1507,  0.1297],
          [-0.0887, -0.3002, -0.2179],
          [ 0.1616, -0.2809, -0.2802]]],


        [[[ 0.1787, -0.2232,  0.0195],
          [ 0.2123, -0.2549,  0.2789],
          [ 0.2916, -0.1297,  0.2314]]],


        [[[-0.1468, -0.1988, -0.0314],
          [ 0.2852,  0.0238,  0.1755],
          [-0.0203,  0.2721, -0.1099]]],


        [[[-0.1998,  0.1009, -0.0759],
          [-0.0180,  0.0602, -0.2492],
          [ 0.0734, -0.2812, -0.0847]]],


        [[[-0.0311,  0.3279, -0.1438],
          [-0.2013, -0.1219, -0.0993],
          [ 0.1979,  0.0930, -0.0901]]],


        [[[ 0.0004, -0.2402, -0.1435],
          [ 0.2260,  0.0478,  0.1106],
          [ 0.0376,  0.0913, -0.1595]]]], requires_grad=True)), ('bias_orig', Parameter containing:
tensor([ 0.2093, -0.2762,  0.2606, -0.2037,  0.1396, -0.3161],
       requires_grad=True))]
[('weight_mask', tensor([[[[0., 0., 1.],
          [1., 0., 1.],
          [1., 1., 1.

In [0]:
######################################################################
# Remove pruning re-parametrization
# ---------------------------------
#
# To make the pruning permanent, remove the re-parametrization in terms
# of ``weight_orig`` and ``weight_mask``, and remove the ``forward_pre_hook``,
# we can use the ``remove`` functionality from ``torch.nn.utils.prune``.
# Note that this doesn't undo the pruning, as if it never happened. It simply 
# makes it permanent, instead, by reassigning the parameter ``weight`` to the 
# model parameters, in its pruned version.

######################################################################
# Prior to removing the re-parametrization:
print(list(module.named_parameters()))
######################################################################
print(list(module.named_buffers()))
######################################################################
print(module.weight)

######################################################################
# After removing the re-parametrization:
prune.remove(module, 'weight')
print(list(module.named_parameters()))
######################################################################
print(list(module.named_buffers()))

[('weight_orig', Parameter containing:
tensor([[[[ 0.2138,  0.1507,  0.1297],
          [-0.0887, -0.3002, -0.2179],
          [ 0.1616, -0.2809, -0.2802]]],


        [[[ 0.1787, -0.2232,  0.0195],
          [ 0.2123, -0.2549,  0.2789],
          [ 0.2916, -0.1297,  0.2314]]],


        [[[-0.1468, -0.1988, -0.0314],
          [ 0.2852,  0.0238,  0.1755],
          [-0.0203,  0.2721, -0.1099]]],


        [[[-0.1998,  0.1009, -0.0759],
          [-0.0180,  0.0602, -0.2492],
          [ 0.0734, -0.2812, -0.0847]]],


        [[[-0.0311,  0.3279, -0.1438],
          [-0.2013, -0.1219, -0.0993],
          [ 0.1979,  0.0930, -0.0901]]],


        [[[ 0.0004, -0.2402, -0.1435],
          [ 0.2260,  0.0478,  0.1106],
          [ 0.0376,  0.0913, -0.1595]]]], requires_grad=True)), ('bias_orig', Parameter containing:
tensor([ 0.2093, -0.2762,  0.2606, -0.2037,  0.1396, -0.3161],
       requires_grad=True))]
[('weight_mask', tensor([[[[0., 0., 1.],
          [1., 0., 1.],
          [1., 1., 1.

In [0]:
######################################################################
# Pruning multiple parameters in a model 
# --------------------------------------
#
# By specifying the desired pruning technique and parameters, we can easily 
# prune multiple tensors in a network, perhaps according to their type, as we 
# will see in this example.

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers 
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers 
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

######################################################################
# Global pruning
# --------------
#
# So far, we only looked at what is usually referred to as "local" pruning,
# i.e. the practice of pruning tensors in a model one by one, by 
# comparing the statistics (weight magnitude, activation, gradient, etc.) of 
# each entry exclusively to the other entries in that tensor. However, a 
# common and perhaps more powerful technique is to prune the model all at 
# once, by removing (for example) the lowest 20% of connections across the 
# whole model, instead of removing the lowest 20% of connections in each 
# layer. This is likely to result in different pruning percentages per layer.
# Let's see how to do that using ``global_unstructured`` from 
# ``torch.nn.utils.prune``.

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

######################################################################
# Now we can check the sparsity induced in every pruned parameter, which will 
# not be equal to 20% in each layer. However, the global sparsity will be 
# (approximately) 20%.
print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)

dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])
Sparsity in conv1.weight: 0.00%
Sparsity in conv2.weight: 8.80%
Sparsity in fc1.weight: 22.05%
Sparsity in fc2.weight: 12.34%
Sparsity in fc3.weight: 7.74%
Global sparsity: 20.00%


In [0]:
######################################################################
# Extending ``torch.nn.utils.prune`` with custom pruning functions
# ------------------------------------------------------------------
# To implement your own pruning function, you can extend the
# ``nn.utils.prune`` module by subclassing the ``BasePruningMethod``
# base class, the same way all other pruning methods do. The base class
# implements the following methods for you: ``__call__``, ``apply_mask``,
# ``apply``, ``prune``, and ``remove``. Beyond some special cases, you shouldn't
# have to reimplement these methods for your new pruning technique.
# You will, however, have to implement ``__init__`` (the constructor),
# and ``compute_mask`` (the instructions on how to compute the mask
# for the given tensor according to the logic of your pruning
# technique). In addition, you will have to specify which type of
# pruning this technique implements (supported options are ``global``,
# ``structured``, and ``unstructured``). This is needed to determine
# how to combine masks in the case in which pruning is applied
# iteratively. In other words, when pruning a pre-pruned parameter,
# the current prunining techique is expected to act on the unpruned
# portion of the parameter. Specifying the ``PRUNING_TYPE`` will
# enable the ``PruningContainer`` (which handles the iterative
# application of pruning masks) to correctly identify the slice of the
# parameter to prune.
#
# Let's assume, for example, that you want to implement a pruning
# technique that prunes every other entry in a tensor (or -- if the
# tensor has previously been pruned -- in the remaining unpruned
# portion of the tensor). This will be of ``PRUNING_TYPE='unstructured'``
# because it acts on individual connections in a layer and not on entire
# units/channels (``'structured'``), or across different parameters
# (``'global'``).

class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0 
        return mask

######################################################################
# Now, to apply this to a parameter in an ``nn.Module``, you should
# also provide a simple function that instantiates the method and
# applies it.
def foobar_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensors.
    Modifies module in place (and also return the modified module) 
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the 
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the 
    original (unpruned) parameter is stored in a new parameter named 
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module
    
    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

######################################################################
# Let's try it out!
model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)


[('weight', Parameter containing:
tensor([[[[ 0.0960,  0.1990, -0.3284],
          [-0.0901, -0.0498,  0.2432],
          [ 0.0512,  0.1423, -0.1570]]],


        [[[-0.1582,  0.2889,  0.3102],
          [ 0.1464, -0.2813, -0.0036],
          [ 0.2368, -0.2903,  0.1698]]],


        [[[ 0.2098,  0.0130, -0.1306],
          [-0.2480, -0.2332, -0.1625],
          [-0.1637, -0.3274, -0.2308]]],


        [[[-0.1492,  0.1220, -0.0337],
          [-0.1939,  0.2457, -0.1599],
          [-0.2925,  0.1766,  0.1953]]],


        [[[-0.0476,  0.2885, -0.0508],
          [ 0.2852, -0.2461,  0.2614],
          [-0.2236, -0.1489,  0.1698]]],


        [[[-0.2876, -0.3044,  0.0094],
          [-0.2870,  0.0434, -0.2159],
          [-0.3075, -0.0028, -0.2630]]]], requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.0297,  0.3308,  0.3307,  0.2654, -0.1183, -0.1151],
       requires_grad=True))]
[]
[('bias', Parameter containing:
tensor([ 0.0297,  0.3308,  0.3307,  0.2654, -0.1183, -0.1151]