In [1]:
%matplotlib inline


# Pruning Tutorial
**Author**: [Michela Paganini](https://github.com/mickypaganini)

State-of-the-art deep learning techniques rely on over-parametrized models 
that are hard to deploy. On the contrary, biological neural networks are 
known to use efficient sparse connectivity. Identifying optimal  
techniques to compress models by reducing the number of parameters in them is 
important in order to reduce memory, battery, and hardware consumption without 
sacrificing accuracy. This in turn allows you to deploy lightweight models on device, and guarantee 
privacy with private on-device computation. On the research front, pruning is 
used to investigate the differences in learning dynamics between 
over-parametrized and under-parametrized networks, to study the role of lucky 
sparse subnetworks and initializations
("[lottery tickets](https://arxiv.org/abs/1803.03635)") as a destructive 
neural architecture search technique, and more.

In this tutorial, you will learn how to use ``torch.nn.utils.prune`` to 
sparsify your neural networks, and how to extend it to implement your 
own custom pruning technique.

## Requirements
``"torch>=1.4.0a0+8e8a5e0"``


In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [2]:
import sys
import logging
import pathlib
import random
import shutil
import time
import functools
import numpy as np
import argparse

import torch
import torchvision
from tensorboardX import SummaryWriter
from torch.nn import functional as F
from torch.utils.data import DataLoader
# from dataset_1 import SliceData,KneeData
from models import DCTeacherNet,DCStudentNet,DCTeacherNetSFTN
from torchsummary import summary
import torchvision
from torch import nn
from torch.autograd import Variable
from torch import optim
from tqdm import tqdm

## Create a model

In this tutorial, we use the [LeNet](http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf) architecture from 
LeCun et al., 1998.



In [3]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# class LeNet(nn.Module):
#     def __init__(self):
#         super(LeNet, self).__init__()
#         # 1 input image channel, 6 output channels, 3x3 square conv kernel
#         self.conv1 = nn.Conv2d(1, 6, 3)
#         self.conv2 = nn.Conv2d(6, 16, 3)
#         self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
#         self.fc2 = nn.Linear(120, 84)
#         self.fc3 = nn.Linear(84, 10)

#     def forward(self, x):
#         x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
#         x = F.max_pool2d(F.relu(self.conv2(x)), 2)
#         x = x.view(-1, int(x.nelement() / x.shape[0]))
#         x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
#         x = self.fc3(x)
#         return x

# model = LeNet().to(device=device)

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
teacher = DCTeacherNet().to(device)

In [5]:
student = DCTeacherNet()

parameters_to_prune = (
    (student.tcascade1.conv1, 'weight'),
    (student.tcascade1.conv2, 'weight'),
    (student.tcascade1.conv3, 'weight'),
    (student.tcascade1.conv4, 'weight'),
    (student.tcascade1.conv5, 'weight'),
    (student.tcascade2.conv1, 'weight'),
    (student.tcascade2.conv2, 'weight'),
    (student.tcascade2.conv3, 'weight'),
    (student.tcascade2.conv4, 'weight'),
    (student.tcascade2.conv5, 'weight'),
    (student.tcascade3.conv1, 'weight'),
    (student.tcascade3.conv2, 'weight'),
    (student.tcascade3.conv3, 'weight'),
    (student.tcascade3.conv4, 'weight'),
    (student.tcascade3.conv5, 'weight'),
    (student.tcascade4.conv1, 'weight'),
    (student.tcascade4.conv2, 'weight'),
    (student.tcascade4.conv3, 'weight'),
    (student.tcascade4.conv4, 'weight'),
    (student.tcascade4.conv5, 'weight'),
    (student.tcascade5.conv1, 'weight'),
    (student.tcascade5.conv2, 'weight'),
    (student.tcascade5.conv3, 'weight'),
    (student.tcascade5.conv4, 'weight'),
    (student.tcascade5.conv5, 'weight'),
)
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

In [6]:
print(
    "Sparsity in tcascade1.conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(student.tcascade1.conv1.weight == 0))
        / float(student.tcascade1.conv1.weight.nelement())
    )
)
print(
    "Sparsity in tcascade1.conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(student.tcascade1.conv2.weight == 0))
        / float(student.tcascade1.conv2.weight.nelement())
    )
)
print(
    "Sparsity in tcascade1.conv3.weight: {:.2f}%".format(
        100. * float(torch.sum(student.tcascade1.conv3.weight == 0))
        / float(student.tcascade1.conv3.weight.nelement())
    )
)
print(
    "Sparsity in tcascade1.conv4.weight: {:.2f}%".format(
        100. * float(torch.sum(student.tcascade1.conv4.weight == 0))
        / float(student.tcascade1.conv4.weight.nelement())
    )
)
print(
    "Sparsity in tcascade1.conv5.weight: {:.2f}%".format(
        100. * float(torch.sum(student.tcascade1.conv5.weight == 0))
        / float(student.tcascade1.conv5.weight.nelement())
    )
)
print(
    "Sparsity in tcascade2.conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(student.tcascade2.conv1.weight == 0))
        / float(student.tcascade2.conv1.weight.nelement())
    )
)
print(
    "Sparsity in tcascade2.conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(student.tcascade2.conv2.weight == 0))
        / float(student.tcascade2.conv2.weight.nelement())
    )
)
print(
    "Sparsity in tcascade3.conv5.weight: {:.2f}%".format(
        100. * float(torch.sum(student.tcascade3.conv5.weight == 0))
        / float(student.tcascade3.conv5.weight.nelement())
    )
)
print(
    "Sparsity in tcascade5.conv3.weight: {:.2f}%".format(
        100. * float(torch.sum(student.tcascade5.conv3.weight == 0))
        / float(student.tcascade5.conv3.weight.nelement())
    )
)
print(
    "Sparsity in tcascade4.conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(student.tcascade4.conv1.weight == 0))
        / float(student.tcascade4.conv1.weight.nelement())
    )
)
print(
    "Sparsity in tcascade5.conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(student.tcascade5.conv1.weight == 0))
        / float(student.tcascade5.conv1.weight.nelement())
    )
)


print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(student.tcascade1.conv1.weight == 0)
            + torch.sum(student.tcascade1.conv2.weight == 0)
            + torch.sum(student.tcascade1.conv3.weight == 0)
            + torch.sum(student.tcascade1.conv4.weight == 0)
            + torch.sum(student.tcascade1.conv5.weight == 0)
            +torch.sum(student.tcascade2.conv1.weight == 0)
            + torch.sum(student.tcascade2.conv2.weight == 0)
            + torch.sum(student.tcascade2.conv3.weight == 0)
            + torch.sum(student.tcascade2.conv4.weight == 0)
            + torch.sum(student.tcascade2.conv5.weight == 0)
            +torch.sum(student.tcascade3.conv1.weight == 0)
            + torch.sum(student.tcascade3.conv2.weight == 0)
            + torch.sum(student.tcascade3.conv3.weight == 0)
            + torch.sum(student.tcascade3.conv4.weight == 0)
            + torch.sum(student.tcascade3.conv5.weight == 0)
            +torch.sum(student.tcascade4.conv1.weight == 0)
            + torch.sum(student.tcascade4.conv2.weight == 0)
            + torch.sum(student.tcascade4.conv3.weight == 0)
            + torch.sum(student.tcascade4.conv4.weight == 0)
            + torch.sum(student.tcascade4.conv5.weight == 0)
            +torch.sum(student.tcascade5.conv1.weight == 0)
            + torch.sum(student.tcascade5.conv2.weight == 0)
            + torch.sum(student.tcascade5.conv3.weight == 0)
            + torch.sum(student.tcascade5.conv4.weight == 0)
            + torch.sum(student.tcascade5.conv5.weight == 0)
        )
        / float(
            student.tcascade1.conv1.weight.nelement()
            + student.tcascade1.conv2.weight.nelement()
            + student.tcascade1.conv3.weight.nelement()
            + student.tcascade1.conv4.weight.nelement()
            + student.tcascade1.conv5.weight.nelement()
            
            + student.tcascade2.conv1.weight.nelement()
            + student.tcascade2.conv2.weight.nelement()
            + student.tcascade2.conv3.weight.nelement()
            + student.tcascade2.conv4.weight.nelement()
            + student.tcascade2.conv5.weight.nelement()
            
            + student.tcascade3.conv1.weight.nelement()
            + student.tcascade3.conv2.weight.nelement()
            + student.tcascade3.conv3.weight.nelement()
            + student.tcascade3.conv4.weight.nelement()
            + student.tcascade3.conv5.weight.nelement()
            
            + student.tcascade4.conv1.weight.nelement()
            + student.tcascade4.conv2.weight.nelement()
            + student.tcascade4.conv3.weight.nelement()
            + student.tcascade4.conv4.weight.nelement()
            + student.tcascade4.conv5.weight.nelement()
            
            + student.tcascade5.conv1.weight.nelement()
            + student.tcascade5.conv2.weight.nelement()
            + student.tcascade5.conv3.weight.nelement()
            + student.tcascade5.conv4.weight.nelement()
            + student.tcascade5.conv5.weight.nelement()            
        )
    )
)

Sparsity in tcascade1.conv1.weight: 5.90%
Sparsity in tcascade1.conv2.weight: 20.71%
Sparsity in tcascade1.conv3.weight: 19.51%
Sparsity in tcascade1.conv4.weight: 20.21%
Sparsity in tcascade1.conv5.weight: 18.40%
Sparsity in tcascade2.conv1.weight: 4.86%
Sparsity in tcascade2.conv2.weight: 20.78%
Sparsity in tcascade3.conv5.weight: 22.92%
Sparsity in tcascade5.conv3.weight: 20.04%
Sparsity in tcascade4.conv1.weight: 2.43%
Sparsity in tcascade5.conv1.weight: 4.17%
Global sparsity: 20.00%


In [7]:
print(summary(teacher, [(1, 240, 240),(1,240,240,2),(1,240,240)]))
print(summary(student, [(1, 240, 240),(1,240,240,2),(1,240,240)]))

Layer (type:depth-idx)                   Output Shape              Param #
├─TeacherNet: 1-1                        [-1, 32, 240, 240]        --
|    └─Conv2d: 2-1                       [-1, 32, 240, 240]        320
|    └─Conv2d: 2-2                       [-1, 32, 240, 240]        9,248
|    └─Conv2d: 2-3                       [-1, 32, 240, 240]        9,248
|    └─Conv2d: 2-4                       [-1, 32, 240, 240]        9,248
|    └─Conv2d: 2-5                       [-1, 1, 240, 240]         289
├─DataConsistencyLayer: 1-2              [-1, 1, 240, 240]         --
├─TeacherNet: 1-3                        [-1, 32, 240, 240]        --
|    └─Conv2d: 2-6                       [-1, 32, 240, 240]        320
|    └─Conv2d: 2-7                       [-1, 32, 240, 240]        9,248
|    └─Conv2d: 2-8                       [-1, 32, 240, 240]        9,248
|    └─Conv2d: 2-9                       [-1, 32, 240, 240]        9,248
|    └─Conv2d: 2-10                      [-1, 1, 240, 240]      

In [8]:
# for module, thing in parameters_to_prune:
#     print(module)
#     prune.remove(module,'weight')

In [9]:
print(teacher.tcascade1.conv5.weight)
print(teacher.tcascade1.conv5.weight.nelement(), 'are the total elements in student')
print(torch.sum(teacher.tcascade1.conv5.weight == 0), 'are total zeros')

Parameter containing:
tensor([[[[ 3.8560e-02,  3.9496e-02, -3.3408e-02],
          [ 3.3913e-02, -1.7869e-02, -3.6605e-03],
          [ 4.0564e-02, -1.7470e-02, -7.2544e-03]],

         [[ 4.8953e-02, -5.6545e-02,  1.1524e-02],
          [ 5.5961e-03,  3.5293e-02, -1.7527e-02],
          [ 3.7410e-02, -4.6230e-02,  5.1977e-02]],

         [[ 1.3630e-02, -1.1948e-02, -1.6120e-02],
          [-4.7507e-02, -2.2926e-02,  4.3085e-02],
          [-2.8819e-02,  3.1354e-03,  4.8159e-02]],

         [[ 5.1011e-02, -3.9950e-02, -3.1601e-02],
          [-3.5990e-02, -2.4796e-03, -2.8009e-02],
          [-1.9417e-02, -5.3305e-03, -2.9557e-02]],

         [[ 3.1049e-02,  2.4259e-02,  5.2541e-02],
          [ 1.9686e-02,  4.5637e-02, -1.6989e-02],
          [-5.6354e-02, -3.3767e-02, -3.7016e-02]],

         [[ 2.6803e-02, -9.3838e-03, -1.8184e-02],
          [-5.8764e-02,  7.1971e-03,  4.6151e-02],
          [-3.8687e-02, -6.3523e-03,  1.0612e-02]],

         [[ 3.4743e-02, -3.0508e-02,  5.6733e-02

In [10]:
print(student.tcascade1.conv5.weight)

tensor([[[[ 0.0415, -0.0398,  0.0474],
          [-0.0141,  0.0520,  0.0570],
          [-0.0419,  0.0528,  0.0561]],

         [[-0.0242,  0.0328,  0.0470],
          [ 0.0000, -0.0381, -0.0336],
          [-0.0584,  0.0319,  0.0164]],

         [[-0.0254,  0.0386, -0.0188],
          [-0.0130, -0.0194, -0.0394],
          [-0.0330,  0.0000, -0.0424]],

         [[-0.0246, -0.0246,  0.0335],
          [ 0.0000,  0.0495,  0.0195],
          [-0.0307, -0.0183,  0.0317]],

         [[ 0.0313,  0.0168,  0.0424],
          [ 0.0258,  0.0434, -0.0000],
          [-0.0000, -0.0274,  0.0258]],

         [[-0.0000,  0.0425,  0.0379],
          [ 0.0135,  0.0000, -0.0546],
          [-0.0428,  0.0422,  0.0586]],

         [[ 0.0309, -0.0498, -0.0473],
          [-0.0242, -0.0198, -0.0000],
          [ 0.0498,  0.0429, -0.0000]],

         [[-0.0387,  0.0574,  0.0456],
          [ 0.0278,  0.0000, -0.0287],
          [ 0.0167, -0.0429, -0.0213]],

         [[-0.0545,  0.0335,  0.0347],
         

In [11]:
# print(student.tcascade1.conv5.weight)
print(student.tcascade1.conv5.weight.nelement(), 'are the total elements in student')
print(torch.sum(student.tcascade1.conv5.weight == 0), 'are total zeros')

288 are the total elements in student
tensor(53, device='cuda:0') are total zeros


In [12]:
# parameters_to_prune
50/288

0.1736111111111111

In [21]:
for module, thing in parameters_to_prune:
#     print(module)
    prune.remove(module,'weight')

In [17]:
for name, p in teacher.named_parameters():
    print(p)

Parameter containing:
tensor([[[[ 0.0172,  0.3011, -0.2654],
          [-0.2297, -0.2844,  0.0250],
          [-0.2247, -0.0522,  0.3060]]],


        [[[ 0.2902, -0.2479,  0.0419],
          [ 0.2955,  0.1588,  0.3082],
          [-0.2418, -0.3303,  0.0853]]],


        [[[ 0.2772, -0.1923,  0.2228],
          [-0.1866,  0.0261,  0.0243],
          [-0.2592, -0.1999,  0.3031]]],


        [[[-0.0314, -0.2055, -0.2776],
          [-0.3325,  0.2308,  0.1119],
          [-0.1389, -0.1043,  0.1712]]],


        [[[-0.3243,  0.1163,  0.0553],
          [-0.0155,  0.2661,  0.1732],
          [-0.1525, -0.0182, -0.2124]]],


        [[[-0.0072, -0.2901, -0.1952],
          [-0.1345,  0.0213,  0.3184],
          [ 0.1236,  0.0596,  0.3001]]],


        [[[-0.0961,  0.1128, -0.2050],
          [ 0.2597,  0.1865,  0.1858],
          [ 0.1207,  0.0679, -0.1170]]],


        [[[ 0.0039, -0.0340, -0.1556],
          [ 0.1495, -0.1394, -0.1068],
          [ 0.1680,  0.2435,  0.2856]]],


        [[

In [13]:
print(student.tcascade1.conv5.weight.nelement(), 'are the total elements in student')
print(torch.sum(student.tcascade1.conv5.weight == 0), 'are total zeros')


288 are the total elements in student
tensor(53, device='cuda:0') are total zeros


In [19]:
nonzero = total = 0
verbose=True
for name, p in teacher.named_parameters():
    tensor = p.data.cpu().numpy()
    nz_count = np.count_nonzero(tensor)
    total_params = np.prod(tensor.shape)
    nonzero += nz_count
    total += total_params
    if verbose:
        print(
            f"{name:20} | nonzeros = {nz_count:7} / {total_params:7} ({100 * nz_count / total_params:6.2f}%) | total_pruned = {total_params - nz_count :7} | shape = {tensor.shape}"
        )
if verbose:
    print(
        f"alive: {nonzero}, pruned : {total - nonzero}, total: {total}, Compression rate : {total/nonzero:10.2f}x  ({100 * (total-nonzero) / total:6.2f}% pruned)"
    )
round((nonzero / total) * 100, 1)

tcascade1.conv1.weight | nonzeros =     288 /     288 (100.00%) | total_pruned =       0 | shape = (32, 1, 3, 3)
tcascade1.conv1.bias | nonzeros =      32 /      32 (100.00%) | total_pruned =       0 | shape = (32,)
tcascade1.conv2.weight | nonzeros =    9216 /    9216 (100.00%) | total_pruned =       0 | shape = (32, 32, 3, 3)
tcascade1.conv2.bias | nonzeros =      32 /      32 (100.00%) | total_pruned =       0 | shape = (32,)
tcascade1.conv3.weight | nonzeros =    9216 /    9216 (100.00%) | total_pruned =       0 | shape = (32, 32, 3, 3)
tcascade1.conv3.bias | nonzeros =      32 /      32 (100.00%) | total_pruned =       0 | shape = (32,)
tcascade1.conv4.weight | nonzeros =    9216 /    9216 (100.00%) | total_pruned =       0 | shape = (32, 32, 3, 3)
tcascade1.conv4.bias | nonzeros =      32 /      32 (100.00%) | total_pruned =       0 | shape = (32,)
tcascade1.conv5.weight | nonzeros =     288 /     288 (100.00%) | total_pruned =       0 | shape = (1, 32, 3, 3)
tcascade1.conv5.bias

100.0

In [22]:
nonzero = total = 0
verbose=True
for name, p in student.named_parameters():
    tensor = p.data.cpu().numpy()
    nz_count = np.count_nonzero(tensor)
    total_params = np.prod(tensor.shape)
    nonzero += nz_count
    total += total_params
    if verbose:
        print(
            f"{name:20} | nonzeros = {nz_count:7} / {total_params:7} ({100 * nz_count / total_params:6.2f}%) | total_pruned = {total_params - nz_count :7} | shape = {tensor.shape}"
        )
if verbose:
    print(
        f"alive: {nonzero}, pruned : {total - nonzero}, total: {total}, Compression rate : {total/nonzero:10.2f}x  ({100 * (total-nonzero) / total:6.2f}% pruned)"
    )
round((nonzero / total) * 100, 1)

tcascade1.conv1.bias | nonzeros =      32 /      32 (100.00%) | total_pruned =       0 | shape = (32,)
tcascade1.conv1.weight | nonzeros =     271 /     288 ( 94.10%) | total_pruned =      17 | shape = (32, 1, 3, 3)
tcascade1.conv2.bias | nonzeros =      32 /      32 (100.00%) | total_pruned =       0 | shape = (32,)
tcascade1.conv2.weight | nonzeros =    7307 /    9216 ( 79.29%) | total_pruned =    1909 | shape = (32, 32, 3, 3)
tcascade1.conv3.bias | nonzeros =      32 /      32 (100.00%) | total_pruned =       0 | shape = (32,)
tcascade1.conv3.weight | nonzeros =    7418 /    9216 ( 80.49%) | total_pruned =    1798 | shape = (32, 32, 3, 3)
tcascade1.conv4.bias | nonzeros =      32 /      32 (100.00%) | total_pruned =       0 | shape = (32,)
tcascade1.conv4.weight | nonzeros =    7353 /    9216 ( 79.79%) | total_pruned =    1863 | shape = (32, 32, 3, 3)
tcascade1.conv5.bias | nonzeros =       1 /       1 (100.00%) | total_pruned =       0 | shape = (1,)
tcascade1.conv5.weight | nonzer

80.1

In [18]:
print(student.tcascade1.conv5.weight)

Parameter containing:
tensor([[[[ 0.0457,  0.0212,  0.0000],
          [-0.0462, -0.0000,  0.0494],
          [ 0.0427,  0.0000, -0.0368]],

         [[ 0.0146, -0.0488, -0.0245],
          [-0.0306,  0.0177,  0.0243],
          [ 0.0164,  0.0452,  0.0507]],

         [[-0.0189,  0.0506,  0.0516],
          [ 0.0533, -0.0324,  0.0582],
          [ 0.0531, -0.0160,  0.0352]],

         [[-0.0169,  0.0396,  0.0000],
          [ 0.0000, -0.0000, -0.0161],
          [ 0.0505, -0.0231, -0.0133]],

         [[-0.0555, -0.0557, -0.0353],
          [-0.0527,  0.0254, -0.0296],
          [-0.0388,  0.0000,  0.0380]],

         [[ 0.0317,  0.0414, -0.0487],
          [ 0.0366,  0.0359, -0.0143],
          [ 0.0519,  0.0000, -0.0000]],

         [[-0.0348, -0.0284, -0.0288],
          [ 0.0000,  0.0534, -0.0000],
          [ 0.0481, -0.0173,  0.0483]],

         [[-0.0188, -0.0528, -0.0138],
          [-0.0000, -0.0518,  0.0360],
          [ 0.0449, -0.0231,  0.0287]],

         [[ 0.0189, -0.034

## Inspect a Module

Let's inspect the (unpruned) ``conv1`` layer in our LeNet model. It will contain two 
parameters ``weight`` and ``bias``, and no buffers, for now.



In [6]:
module = model.tcascade1.conv1
print(list(module.named_parameters()))

[('weight', Parameter containing:
tensor([[[[ 0.2001,  0.2638, -0.0628],
          [-0.1813,  0.0317, -0.1810],
          [-0.3260,  0.3072, -0.2151]]],


        [[[ 0.1975, -0.2577, -0.0574],
          [ 0.1201, -0.2992, -0.1244],
          [ 0.0175, -0.2571, -0.3195]]],


        [[[-0.0764,  0.1673, -0.2599],
          [ 0.0737, -0.2706, -0.1304],
          [-0.1478,  0.1394,  0.3143]]],


        [[[ 0.0226,  0.1425, -0.3321],
          [ 0.3090, -0.2263,  0.0307],
          [-0.1808,  0.1974,  0.1166]]],


        [[[-0.2904, -0.2400,  0.0246],
          [-0.1724,  0.0355, -0.1658],
          [-0.0176, -0.1508, -0.3249]]],


        [[[-0.0133,  0.1745, -0.1607],
          [ 0.1716,  0.1826,  0.1978],
          [ 0.1046, -0.0289,  0.1318]]],


        [[[ 0.1915, -0.3072,  0.0677],
          [ 0.1486,  0.0074,  0.0396],
          [-0.2496, -0.0719,  0.2147]]],


        [[[ 0.3088,  0.2707,  0.2863],
          [ 0.1902,  0.1706, -0.3112],
          [-0.2775, -0.3095,  0.1772]]],


In [7]:
print(list(module.named_buffers()))

[]


## Pruning a Module

To prune a module (in this example, the ``conv1`` layer of our LeNet 
architecture), first select a pruning technique among those available in 
``torch.nn.utils.prune`` (or
[implement](#extending-torch-nn-utils-pruning-with-custom-pruning-functions)
your own by subclassing 
``BasePruningMethod``). Then, specify the module and the name of the parameter to 
prune within that module. Finally, using the adequate keyword arguments 
required by the selected pruning technique, specify the pruning parameters.

In this example, we will prune at random 30% of the connections in 
the parameter named ``weight`` in the ``conv1`` layer.
The module is passed as the first argument to the function; ``name`` 
identifies the parameter within that module using its string identifier; and 
``amount`` indicates either the percentage of connections to prune (if it 
is a float between 0. and 1.), or the absolute number of connections to 
prune (if it is a non-negative integer).



In [8]:
prune.random_unstructured(module, name="weight", amount=0.3)

Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

Pruning acts by removing ``weight`` from the parameters and replacing it with 
a new parameter called ``weight_orig`` (i.e. appending ``"_orig"`` to the 
initial parameter ``name``). ``weight_orig`` stores the unpruned version of 
the tensor. The ``bias`` was not pruned, so it will remain intact.



In [9]:
print(list(module.named_parameters()))

[('bias', Parameter containing:
tensor([ 0.0257, -0.1146, -0.2197, -0.0810, -0.2314, -0.0029, -0.1543,  0.2895,
        -0.1083,  0.2504, -0.0092,  0.2466,  0.1418,  0.1376,  0.0554, -0.2786,
        -0.2634,  0.1401, -0.0530, -0.1261,  0.0748,  0.1323, -0.0645, -0.3297,
        -0.0208, -0.3264,  0.1944,  0.2929,  0.3126, -0.1498,  0.1150,  0.3315],
       device='cuda:0', requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.2001,  0.2638, -0.0628],
          [-0.1813,  0.0317, -0.1810],
          [-0.3260,  0.3072, -0.2151]]],


        [[[ 0.1975, -0.2577, -0.0574],
          [ 0.1201, -0.2992, -0.1244],
          [ 0.0175, -0.2571, -0.3195]]],


        [[[-0.0764,  0.1673, -0.2599],
          [ 0.0737, -0.2706, -0.1304],
          [-0.1478,  0.1394,  0.3143]]],


        [[[ 0.0226,  0.1425, -0.3321],
          [ 0.3090, -0.2263,  0.0307],
          [-0.1808,  0.1974,  0.1166]]],


        [[[-0.2904, -0.2400,  0.0246],
          [-0.1724,  0.0355, -0.1658],
 

The pruning mask generated by the pruning technique selected above is saved 
as a module buffer named ``weight_mask`` (i.e. appending ``"_mask"`` to the 
initial parameter ``name``).



In [10]:
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[1., 0., 0.],
          [1., 1., 1.],
          [0., 1., 0.]]],


        [[[1., 0., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [1., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [1., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [0., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 0.],
          [0., 1., 1.],
          [1., 0., 1.]]],


        [[[1., 1., 1.],
          [0., 0., 0.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [0., 0., 1.],
          [0., 1., 1.]]],


        [[[1., 0., 1.],
          [0., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 0.],
          [0., 0., 1.]]],


        [[[1., 0., 1.],
          [0., 1., 0.],
          [1., 1., 1.]]

For the forward pass to work without modification, the ``weight`` attribute 
needs to exist. The pruning techniques implemented in 
``torch.nn.utils.prune`` compute the pruned version of the weight (by 
combining the mask with the original parameter) and store them in the 
attribute ``weight``. Note, this is no longer a parameter of the ``module``,
it is now simply an attribute.



In [11]:
print(module.weight)

tensor([[[[ 0.2001,  0.0000, -0.0000],
          [-0.1813,  0.0317, -0.1810],
          [-0.0000,  0.3072, -0.0000]]],


        [[[ 0.1975, -0.0000, -0.0000],
          [ 0.1201, -0.2992, -0.1244],
          [ 0.0175, -0.2571, -0.3195]]],


        [[[-0.0764,  0.1673, -0.0000],
          [ 0.0737, -0.0000, -0.1304],
          [-0.1478,  0.1394,  0.3143]]],


        [[[ 0.0226,  0.1425, -0.0000],
          [ 0.3090, -0.0000,  0.0307],
          [-0.1808,  0.1974,  0.1166]]],


        [[[-0.2904, -0.2400,  0.0246],
          [-0.0000,  0.0355, -0.1658],
          [-0.0176, -0.1508, -0.3249]]],


        [[[-0.0133,  0.1745, -0.1607],
          [ 0.1716,  0.1826,  0.1978],
          [ 0.1046, -0.0289,  0.1318]]],


        [[[ 0.1915, -0.0000,  0.0000],
          [ 0.0000,  0.0074,  0.0396],
          [-0.2496, -0.0000,  0.2147]]],


        [[[ 0.3088,  0.2707,  0.2863],
          [ 0.0000,  0.0000, -0.0000],
          [-0.2775, -0.3095,  0.1772]]],


        [[[ 0.1118,  0.1810, -0.

Finally, pruning is applied prior to each forward pass using PyTorch's
``forward_pre_hooks``. Specifically, when the ``module`` is pruned, as we 
have done here, it will acquire a ``forward_pre_hook`` for each parameter 
associated with it that gets pruned. In this case, since we have so far 
only pruned the original parameter named ``weight``, only one hook will be
present.



In [12]:
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fa257019978>)])


For completeness, we can now prune the ``bias`` too, to see how the 
parameters, buffers, hooks, and attributes of the ``module`` change.
Just for the sake of trying out another pruning technique, here we prune the 
3 smallest entries in the bias by L1 norm, as implemented in the 
``l1_unstructured`` pruning function.



In [13]:
prune.l1_unstructured(module, name="bias", amount=3)

Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

We now expect the named parameters to include both ``weight_orig`` (from 
before) and ``bias_orig``. The buffers will include ``weight_mask`` and 
``bias_mask``. The pruned versions of the two tensors will exist as 
module attributes, and the module will now have two ``forward_pre_hooks``.



In [14]:
print(list(module.named_parameters()))

[('weight_orig', Parameter containing:
tensor([[[[ 0.2001,  0.2638, -0.0628],
          [-0.1813,  0.0317, -0.1810],
          [-0.3260,  0.3072, -0.2151]]],


        [[[ 0.1975, -0.2577, -0.0574],
          [ 0.1201, -0.2992, -0.1244],
          [ 0.0175, -0.2571, -0.3195]]],


        [[[-0.0764,  0.1673, -0.2599],
          [ 0.0737, -0.2706, -0.1304],
          [-0.1478,  0.1394,  0.3143]]],


        [[[ 0.0226,  0.1425, -0.3321],
          [ 0.3090, -0.2263,  0.0307],
          [-0.1808,  0.1974,  0.1166]]],


        [[[-0.2904, -0.2400,  0.0246],
          [-0.1724,  0.0355, -0.1658],
          [-0.0176, -0.1508, -0.3249]]],


        [[[-0.0133,  0.1745, -0.1607],
          [ 0.1716,  0.1826,  0.1978],
          [ 0.1046, -0.0289,  0.1318]]],


        [[[ 0.1915, -0.3072,  0.0677],
          [ 0.1486,  0.0074,  0.0396],
          [-0.2496, -0.0719,  0.2147]]],


        [[[ 0.3088,  0.2707,  0.2863],
          [ 0.1902,  0.1706, -0.3112],
          [-0.2775, -0.3095,  0.1772

In [15]:
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[1., 0., 0.],
          [1., 1., 1.],
          [0., 1., 0.]]],


        [[[1., 0., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [1., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [1., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [0., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 0.],
          [0., 1., 1.],
          [1., 0., 1.]]],


        [[[1., 1., 1.],
          [0., 0., 0.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [0., 0., 1.],
          [0., 1., 1.]]],


        [[[1., 0., 1.],
          [0., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 0.],
          [0., 0., 1.]]],


        [[[1., 0., 1.],
          [0., 1., 0.],
          [1., 1., 1.]]

In [16]:
print(module.bias)

tensor([ 0.0257, -0.1146, -0.2197, -0.0810, -0.2314, -0.0000, -0.1543,  0.2895,
        -0.1083,  0.2504, -0.0000,  0.2466,  0.1418,  0.1376,  0.0554, -0.2786,
        -0.2634,  0.1401, -0.0530, -0.1261,  0.0748,  0.1323, -0.0645, -0.3297,
        -0.0000, -0.3264,  0.1944,  0.2929,  0.3126, -0.1498,  0.1150,  0.3315],
       device='cuda:0', grad_fn=<MulBackward0>)


In [17]:
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fa257019978>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7fa257007b38>)])


## Iterative Pruning

The same parameter in a module can be pruned multiple times, with the 
effect of the various pruning calls being equal to the combination of the
various masks applied in series.
The combination of a new mask with the old mask is handled by the 
``PruningContainer``'s ``compute_mask`` method.

Say, for example, that we now want to further prune ``module.weight``, this
time using structured pruning along the 0th axis of the tensor (the 0th axis 
corresponds to the output channels of the convolutional layer and has 
dimensionality 6 for ``conv1``), based on the channels' L2 norm. This can be 
achieved using the ``ln_structured`` function, with ``n=2`` and ``dim=0``.



In [18]:
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this will zero out all the connections corresponding to 
# 50% (3 out of 6) of the channels, while preserving the action of the 
# previous mask.
print(module.weight)

tensor([[[[ 0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000]]],


        [[[ 0.1975, -0.0000, -0.0000],
          [ 0.1201, -0.2992, -0.1244],
          [ 0.0175, -0.2571, -0.3195]]],


        [[[-0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[-0.2904, -0.2400,  0.0246],
          [-0.0000,  0.0355, -0.1658],
          [-0.0176, -0.1508, -0.3249]]],


        [[[-0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]],


        [[[ 0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000]]],


        [[[ 0.3088,  0.2707,  0.2863],
          [ 0.0000,  0.0000, -0.0000],
          [-0.2775, -0.3095,  0.1772]]],


        [[[ 0.0000,  0.0000, -0.

The corresponding hook will now be of type 
``torch.nn.utils.prune.PruningContainer``, and will store the history of 
pruning applied to the ``weight`` parameter.



In [19]:
for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container

[<torch.nn.utils.prune.RandomUnstructured object at 0x7fa257019978>, <torch.nn.utils.prune.LnStructured object at 0x7fa2548f7f28>]


## Serializing a pruned model
All relevant tensors, including the mask buffers and the original parameters
used to compute the pruned tensors are stored in the model's ``state_dict`` 
and can therefore be easily serialized and saved, if needed.



In [20]:
print(model.state_dict().keys())

odict_keys(['tcascade1.conv1.weight_orig', 'tcascade1.conv1.bias_orig', 'tcascade1.conv1.weight_mask', 'tcascade1.conv1.bias_mask', 'tcascade1.conv2.weight', 'tcascade1.conv2.bias', 'tcascade1.conv3.weight', 'tcascade1.conv3.bias', 'tcascade1.conv4.weight', 'tcascade1.conv4.bias', 'tcascade1.conv5.weight', 'tcascade1.conv5.bias', 'tcascade2.conv1.weight', 'tcascade2.conv1.bias', 'tcascade2.conv2.weight', 'tcascade2.conv2.bias', 'tcascade2.conv3.weight', 'tcascade2.conv3.bias', 'tcascade2.conv4.weight', 'tcascade2.conv4.bias', 'tcascade2.conv5.weight', 'tcascade2.conv5.bias', 'tcascade3.conv1.weight', 'tcascade3.conv1.bias', 'tcascade3.conv2.weight', 'tcascade3.conv2.bias', 'tcascade3.conv3.weight', 'tcascade3.conv3.bias', 'tcascade3.conv4.weight', 'tcascade3.conv4.bias', 'tcascade3.conv5.weight', 'tcascade3.conv5.bias', 'tcascade4.conv1.weight', 'tcascade4.conv1.bias', 'tcascade4.conv2.weight', 'tcascade4.conv2.bias', 'tcascade4.conv3.weight', 'tcascade4.conv3.bias', 'tcascade4.conv4.w

In [52]:
# print(student.state_dict().keys())

odict_keys(['tcascade1.conv1.bias', 'tcascade1.conv1.weight', 'tcascade1.conv2.bias', 'tcascade1.conv2.weight', 'tcascade1.conv3.bias', 'tcascade1.conv3.weight', 'tcascade1.conv4.bias', 'tcascade1.conv4.weight', 'tcascade1.conv5.bias', 'tcascade1.conv5.weight', 'tcascade2.conv1.bias', 'tcascade2.conv1.weight', 'tcascade2.conv2.bias', 'tcascade2.conv2.weight', 'tcascade2.conv3.bias', 'tcascade2.conv3.weight', 'tcascade2.conv4.bias', 'tcascade2.conv4.weight', 'tcascade2.conv5.bias', 'tcascade2.conv5.weight', 'tcascade3.conv1.bias', 'tcascade3.conv1.weight', 'tcascade3.conv2.bias', 'tcascade3.conv2.weight', 'tcascade3.conv3.bias', 'tcascade3.conv3.weight', 'tcascade3.conv4.bias', 'tcascade3.conv4.weight', 'tcascade3.conv5.bias', 'tcascade3.conv5.weight', 'tcascade4.conv1.bias', 'tcascade4.conv1.weight', 'tcascade4.conv2.bias', 'tcascade4.conv2.weight', 'tcascade4.conv3.bias', 'tcascade4.conv3.weight', 'tcascade4.conv4.bias', 'tcascade4.conv4.weight', 'tcascade4.conv5.bias', 'tcascade4.con

## Remove pruning re-parametrization

To make the pruning permanent, remove the re-parametrization in terms
of ``weight_orig`` and ``weight_mask``, and remove the ``forward_pre_hook``,
we can use the ``remove`` functionality from ``torch.nn.utils.prune``.
Note that this doesn't undo the pruning, as if it never happened. It simply 
makes it permanent, instead, by reassigning the parameter ``weight`` to the 
model parameters, in its pruned version.



Prior to removing the re-parametrization:



In [21]:
print(list(module.named_parameters()))

[('weight_orig', Parameter containing:
tensor([[[[ 0.2001,  0.2638, -0.0628],
          [-0.1813,  0.0317, -0.1810],
          [-0.3260,  0.3072, -0.2151]]],


        [[[ 0.1975, -0.2577, -0.0574],
          [ 0.1201, -0.2992, -0.1244],
          [ 0.0175, -0.2571, -0.3195]]],


        [[[-0.0764,  0.1673, -0.2599],
          [ 0.0737, -0.2706, -0.1304],
          [-0.1478,  0.1394,  0.3143]]],


        [[[ 0.0226,  0.1425, -0.3321],
          [ 0.3090, -0.2263,  0.0307],
          [-0.1808,  0.1974,  0.1166]]],


        [[[-0.2904, -0.2400,  0.0246],
          [-0.1724,  0.0355, -0.1658],
          [-0.0176, -0.1508, -0.3249]]],


        [[[-0.0133,  0.1745, -0.1607],
          [ 0.1716,  0.1826,  0.1978],
          [ 0.1046, -0.0289,  0.1318]]],


        [[[ 0.1915, -0.3072,  0.0677],
          [ 0.1486,  0.0074,  0.0396],
          [-0.2496, -0.0719,  0.2147]]],


        [[[ 0.3088,  0.2707,  0.2863],
          [ 0.1902,  0.1706, -0.3112],
          [-0.2775, -0.3095,  0.1772

In [22]:
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[1., 0., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[1., 1., 1.],
          [0., 1., 1.],
          [1., 1., 1.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[1., 1., 1.],
          [0., 0., 0.],
          [1., 1., 1.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[1., 0., 1.],
          [0., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]

In [23]:
print(module.weight)

tensor([[[[ 0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000]]],


        [[[ 0.1975, -0.0000, -0.0000],
          [ 0.1201, -0.2992, -0.1244],
          [ 0.0175, -0.2571, -0.3195]]],


        [[[-0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[-0.2904, -0.2400,  0.0246],
          [-0.0000,  0.0355, -0.1658],
          [-0.0176, -0.1508, -0.3249]]],


        [[[-0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]],


        [[[ 0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000]]],


        [[[ 0.3088,  0.2707,  0.2863],
          [ 0.0000,  0.0000, -0.0000],
          [-0.2775, -0.3095,  0.1772]]],


        [[[ 0.0000,  0.0000, -0.

After removing the re-parametrization:



In [24]:
prune.remove(module, 'weight')
print(list(module.named_parameters()))

[('bias_orig', Parameter containing:
tensor([ 0.0257, -0.1146, -0.2197, -0.0810, -0.2314, -0.0029, -0.1543,  0.2895,
        -0.1083,  0.2504, -0.0092,  0.2466,  0.1418,  0.1376,  0.0554, -0.2786,
        -0.2634,  0.1401, -0.0530, -0.1261,  0.0748,  0.1323, -0.0645, -0.3297,
        -0.0208, -0.3264,  0.1944,  0.2929,  0.3126, -0.1498,  0.1150,  0.3315],
       device='cuda:0', requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000]]],


        [[[ 0.1975, -0.0000, -0.0000],
          [ 0.1201, -0.2992, -0.1244],
          [ 0.0175, -0.2571, -0.3195]]],


        [[[-0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[-0.2904, -0.2400,  0.0246],
          [-0.0000,  0.0355, -0.1658],
 

In [25]:
print(list(module.named_buffers()))

[('bias_mask', tensor([1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1.],
       device='cuda:0'))]


## Pruning multiple parameters in a model 

By specifying the desired pruning technique and parameters, we can easily 
prune multiple tensors in a network, perhaps according to their type, as we 
will see in this example.



In [27]:
new_model = DCTeacherNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers 
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers 
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

dict_keys(['tcascade1.conv1.weight_mask', 'tcascade1.conv2.weight_mask', 'tcascade1.conv3.weight_mask', 'tcascade1.conv4.weight_mask', 'tcascade1.conv5.weight_mask', 'tcascade2.conv1.weight_mask', 'tcascade2.conv2.weight_mask', 'tcascade2.conv3.weight_mask', 'tcascade2.conv4.weight_mask', 'tcascade2.conv5.weight_mask', 'tcascade3.conv1.weight_mask', 'tcascade3.conv2.weight_mask', 'tcascade3.conv3.weight_mask', 'tcascade3.conv4.weight_mask', 'tcascade3.conv5.weight_mask', 'tcascade4.conv1.weight_mask', 'tcascade4.conv2.weight_mask', 'tcascade4.conv3.weight_mask', 'tcascade4.conv4.weight_mask', 'tcascade4.conv5.weight_mask', 'tcascade5.conv1.weight_mask', 'tcascade5.conv2.weight_mask', 'tcascade5.conv3.weight_mask', 'tcascade5.conv4.weight_mask', 'tcascade5.conv5.weight_mask'])


## Global pruning

So far, we only looked at what is usually referred to as "local" pruning,
i.e. the practice of pruning tensors in a model one by one, by 
comparing the statistics (weight magnitude, activation, gradient, etc.) of 
each entry exclusively to the other entries in that tensor. However, a 
common and perhaps more powerful technique is to prune the model all at 
once, by removing (for example) the lowest 20% of connections across the 
whole model, instead of removing the lowest 20% of connections in each 
layer. This is likely to result in different pruning percentages per layer.
Let's see how to do that using ``global_unstructured`` from 
``torch.nn.utils.prune``.



In [35]:
# model = LeNet()
model = DCTeacherNet()

# parameters_to_prune = (
#     (model.conv1, 'weight'),
#     (model.conv2, 'weight'),
#     (model.fc1, 'weight'),
#     (model.fc2, 'weight'),
#     (model.fc3, 'weight'),
# )
parameters_to_prune = (
    (model.tcascade1.conv1, 'weight'),
    (model.tcascade1.conv2, 'weight'),
    (model.tcascade1.conv3, 'weight'),
    (model.tcascade1.conv4, 'weight'),
    (model.tcascade1.conv5, 'weight'),
    (model.tcascade2.conv1, 'weight'),
    (model.tcascade2.conv2, 'weight'),
    (model.tcascade2.conv3, 'weight'),
    (model.tcascade2.conv4, 'weight'),
    (model.tcascade2.conv5, 'weight'),
    (model.tcascade3.conv1, 'weight'),
    (model.tcascade3.conv2, 'weight'),
    (model.tcascade3.conv3, 'weight'),
    (model.tcascade3.conv4, 'weight'),
    (model.tcascade3.conv5, 'weight'),
    (model.tcascade4.conv1, 'weight'),
    (model.tcascade4.conv2, 'weight'),
    (model.tcascade4.conv3, 'weight'),
    (model.tcascade4.conv4, 'weight'),
    (model.tcascade4.conv5, 'weight'),
    (model.tcascade5.conv1, 'weight'),
    (model.tcascade5.conv2, 'weight'),
    (model.tcascade5.conv3, 'weight'),
    (model.tcascade5.conv4, 'weight'),
    (model.tcascade5.conv5, 'weight'),
)
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

Now we can check the sparsity induced in every pruned parameter, which will 
not be equal to 20% in each layer. However, the global sparsity will be 
(approximately) 20%.



In [36]:
# print(
#     "Sparsity in conv1.weight: {:.2f}%".format(
#         100. * float(torch.sum(model.conv1.weight == 0))
#         / float(model.conv1.weight.nelement())
#     )
# )
# print(
#     "Sparsity in conv2.weight: {:.2f}%".format(
#         100. * float(torch.sum(model.conv2.weight == 0))
#         / float(model.conv2.weight.nelement())
#     )
# )
# print(
#     "Sparsity in fc1.weight: {:.2f}%".format(
#         100. * float(torch.sum(model.fc1.weight == 0))
#         / float(model.fc1.weight.nelement())
#     )
# )
# print(
#     "Sparsity in fc2.weight: {:.2f}%".format(
#         100. * float(torch.sum(model.fc2.weight == 0))
#         / float(model.fc2.weight.nelement())
#     )
# )
# print(
#     "Sparsity in fc3.weight: {:.2f}%".format(
#         100. * float(torch.sum(model.fc3.weight == 0))
#         / float(model.fc3.weight.nelement())
#     )
# )
# print(
#     "Global sparsity: {:.2f}%".format(
#         100. * float(
#             torch.sum(model.conv1.weight == 0)
#             + torch.sum(model.conv2.weight == 0)
#             + torch.sum(model.fc1.weight == 0)
#             + torch.sum(model.fc2.weight == 0)
#             + torch.sum(model.fc3.weight == 0)
#         )
#         / float(
#             model.conv1.weight.nelement()
#             + model.conv2.weight.nelement()
#             + model.fc1.weight.nelement()
#             + model.fc2.weight.nelement()
#             + model.fc3.weight.nelement()
#         )
#     )
# )

In [37]:
print(
    "Sparsity in tcascade1.conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.tcascade1.conv1.weight == 0))
        / float(model.tcascade1.conv1.weight.nelement())
    )
)
print(
    "Sparsity in tcascade1.conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.tcascade1.conv2.weight == 0))
        / float(model.tcascade1.conv2.weight.nelement())
    )
)
print(
    "Sparsity in tcascade1.conv3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.tcascade1.conv3.weight == 0))
        / float(model.tcascade1.conv3.weight.nelement())
    )
)
print(
    "Sparsity in tcascade1.conv4.weight: {:.2f}%".format(
        100. * float(torch.sum(model.tcascade1.conv4.weight == 0))
        / float(model.tcascade1.conv4.weight.nelement())
    )
)
print(
    "Sparsity in tcascade1.conv5.weight: {:.2f}%".format(
        100. * float(torch.sum(model.tcascade1.conv5.weight == 0))
        / float(model.tcascade1.conv5.weight.nelement())
    )
)
print(
    "Sparsity in tcascade2.conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.tcascade2.conv1.weight == 0))
        / float(model.tcascade2.conv1.weight.nelement())
    )
)
print(
    "Sparsity in tcascade2.conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.tcascade2.conv2.weight == 0))
        / float(model.tcascade2.conv2.weight.nelement())
    )
)
print(
    "Sparsity in tcascade3.conv5.weight: {:.2f}%".format(
        100. * float(torch.sum(model.tcascade3.conv5.weight == 0))
        / float(model.tcascade3.conv5.weight.nelement())
    )
)
print(
    "Sparsity in tcascade5.conv3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.tcascade5.conv3.weight == 0))
        / float(model.tcascade5.conv3.weight.nelement())
    )
)
print(
    "Sparsity in tcascade4.conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.tcascade4.conv1.weight == 0))
        / float(model.tcascade4.conv1.weight.nelement())
    )
)
print(
    "Sparsity in tcascade5.conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.tcascade5.conv1.weight == 0))
        / float(model.tcascade5.conv1.weight.nelement())
    )
)


print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.tcascade1.conv1.weight == 0)
            + torch.sum(model.tcascade1.conv2.weight == 0)
            + torch.sum(model.tcascade1.conv3.weight == 0)
            + torch.sum(model.tcascade1.conv4.weight == 0)
            + torch.sum(model.tcascade1.conv5.weight == 0)
            +torch.sum(model.tcascade2.conv1.weight == 0)
            + torch.sum(model.tcascade2.conv2.weight == 0)
            + torch.sum(model.tcascade2.conv3.weight == 0)
            + torch.sum(model.tcascade2.conv4.weight == 0)
            + torch.sum(model.tcascade2.conv5.weight == 0)
            +torch.sum(model.tcascade3.conv1.weight == 0)
            + torch.sum(model.tcascade3.conv2.weight == 0)
            + torch.sum(model.tcascade3.conv3.weight == 0)
            + torch.sum(model.tcascade3.conv4.weight == 0)
            + torch.sum(model.tcascade3.conv5.weight == 0)
            +torch.sum(model.tcascade4.conv1.weight == 0)
            + torch.sum(model.tcascade4.conv2.weight == 0)
            + torch.sum(model.tcascade4.conv3.weight == 0)
            + torch.sum(model.tcascade4.conv4.weight == 0)
            + torch.sum(model.tcascade4.conv5.weight == 0)
            +torch.sum(model.tcascade5.conv1.weight == 0)
            + torch.sum(model.tcascade5.conv2.weight == 0)
            + torch.sum(model.tcascade5.conv3.weight == 0)
            + torch.sum(model.tcascade5.conv4.weight == 0)
            + torch.sum(model.tcascade5.conv5.weight == 0)
        )
        / float(
            model.tcascade1.conv1.weight.nelement()
            + model.tcascade1.conv2.weight.nelement()
            + model.tcascade1.conv3.weight.nelement()
            + model.tcascade1.conv4.weight.nelement()
            + model.tcascade1.conv5.weight.nelement()
            
            + model.tcascade2.conv1.weight.nelement()
            + model.tcascade2.conv2.weight.nelement()
            + model.tcascade2.conv3.weight.nelement()
            + model.tcascade2.conv4.weight.nelement()
            + model.tcascade2.conv5.weight.nelement()
            
            + model.tcascade3.conv1.weight.nelement()
            + model.tcascade3.conv2.weight.nelement()
            + model.tcascade3.conv3.weight.nelement()
            + model.tcascade3.conv4.weight.nelement()
            + model.tcascade3.conv5.weight.nelement()
            
            + model.tcascade4.conv1.weight.nelement()
            + model.tcascade4.conv2.weight.nelement()
            + model.tcascade4.conv3.weight.nelement()
            + model.tcascade4.conv4.weight.nelement()
            + model.tcascade4.conv5.weight.nelement()
            
            + model.tcascade5.conv1.weight.nelement()
            + model.tcascade5.conv2.weight.nelement()
            + model.tcascade5.conv3.weight.nelement()
            + model.tcascade5.conv4.weight.nelement()
            + model.tcascade5.conv5.weight.nelement()            
        )
    )
)

Sparsity in tcascade1.conv1.weight: 3.82%
Sparsity in tcascade1.conv2.weight: 19.63%
Sparsity in tcascade1.conv3.weight: 21.17%
Sparsity in tcascade1.conv4.weight: 20.24%
Sparsity in tcascade1.conv5.weight: 14.58%
Sparsity in tcascade2.conv1.weight: 4.17%
Sparsity in tcascade2.conv2.weight: 19.91%
Sparsity in tcascade3.conv5.weight: 18.06%
Sparsity in tcascade5.conv3.weight: 20.19%
Sparsity in tcascade4.conv1.weight: 3.47%
Sparsity in tcascade5.conv1.weight: 4.51%
Global sparsity: 20.00%


In [38]:
# model = ConvNet()
summary(model, [(1, 240, 240),(1,240,240,2),(1,240,240)])

Layer (type:depth-idx)                   Output Shape              Param #
├─TeacherNet: 1-1                        [-1, 32, 240, 240]        --
|    └─Conv2d: 2-1                       [-1, 32, 240, 240]        320
|    └─Conv2d: 2-2                       [-1, 32, 240, 240]        9,248
|    └─Conv2d: 2-3                       [-1, 32, 240, 240]        9,248
|    └─Conv2d: 2-4                       [-1, 32, 240, 240]        9,248
|    └─Conv2d: 2-5                       [-1, 1, 240, 240]         289
├─DataConsistencyLayer: 1-2              [-1, 1, 240, 240]         --
├─TeacherNet: 1-3                        [-1, 32, 240, 240]        --
|    └─Conv2d: 2-6                       [-1, 32, 240, 240]        320
|    └─Conv2d: 2-7                       [-1, 32, 240, 240]        9,248
|    └─Conv2d: 2-8                       [-1, 32, 240, 240]        9,248
|    └─Conv2d: 2-9                       [-1, 32, 240, 240]        9,248
|    └─Conv2d: 2-10                      [-1, 1, 240, 240]      

Layer (type:depth-idx)                   Output Shape              Param #
├─TeacherNet: 1-1                        [-1, 32, 240, 240]        --
|    └─Conv2d: 2-1                       [-1, 32, 240, 240]        320
|    └─Conv2d: 2-2                       [-1, 32, 240, 240]        9,248
|    └─Conv2d: 2-3                       [-1, 32, 240, 240]        9,248
|    └─Conv2d: 2-4                       [-1, 32, 240, 240]        9,248
|    └─Conv2d: 2-5                       [-1, 1, 240, 240]         289
├─DataConsistencyLayer: 1-2              [-1, 1, 240, 240]         --
├─TeacherNet: 1-3                        [-1, 32, 240, 240]        --
|    └─Conv2d: 2-6                       [-1, 32, 240, 240]        320
|    └─Conv2d: 2-7                       [-1, 32, 240, 240]        9,248
|    └─Conv2d: 2-8                       [-1, 32, 240, 240]        9,248
|    └─Conv2d: 2-9                       [-1, 32, 240, 240]        9,248
|    └─Conv2d: 2-10                      [-1, 1, 240, 240]      

In [48]:
print(dict(model.named_buffers()).keys()) 

dict_keys(['tcascade1.conv1.weight_mask', 'tcascade1.conv2.weight_mask', 'tcascade1.conv3.weight_mask', 'tcascade1.conv4.weight_mask', 'tcascade1.conv5.weight_mask', 'tcascade2.conv1.weight_mask', 'tcascade2.conv2.weight_mask', 'tcascade2.conv3.weight_mask', 'tcascade2.conv4.weight_mask', 'tcascade2.conv5.weight_mask', 'tcascade3.conv1.weight_mask', 'tcascade3.conv2.weight_mask', 'tcascade3.conv3.weight_mask', 'tcascade3.conv4.weight_mask', 'tcascade3.conv5.weight_mask', 'tcascade4.conv1.weight_mask', 'tcascade4.conv2.weight_mask', 'tcascade4.conv3.weight_mask', 'tcascade4.conv4.weight_mask', 'tcascade4.conv5.weight_mask', 'tcascade5.conv1.weight_mask', 'tcascade5.conv2.weight_mask', 'tcascade5.conv3.weight_mask', 'tcascade5.conv4.weight_mask', 'tcascade5.conv5.weight_mask'])


In [49]:
print(model.state_dict().keys())

odict_keys(['tcascade1.conv1.bias', 'tcascade1.conv1.weight_orig', 'tcascade1.conv1.weight_mask', 'tcascade1.conv2.bias', 'tcascade1.conv2.weight_orig', 'tcascade1.conv2.weight_mask', 'tcascade1.conv3.bias', 'tcascade1.conv3.weight_orig', 'tcascade1.conv3.weight_mask', 'tcascade1.conv4.bias', 'tcascade1.conv4.weight_orig', 'tcascade1.conv4.weight_mask', 'tcascade1.conv5.bias', 'tcascade1.conv5.weight_orig', 'tcascade1.conv5.weight_mask', 'tcascade2.conv1.bias', 'tcascade2.conv1.weight_orig', 'tcascade2.conv1.weight_mask', 'tcascade2.conv2.bias', 'tcascade2.conv2.weight_orig', 'tcascade2.conv2.weight_mask', 'tcascade2.conv3.bias', 'tcascade2.conv3.weight_orig', 'tcascade2.conv3.weight_mask', 'tcascade2.conv4.bias', 'tcascade2.conv4.weight_orig', 'tcascade2.conv4.weight_mask', 'tcascade2.conv5.bias', 'tcascade2.conv5.weight_orig', 'tcascade2.conv5.weight_mask', 'tcascade3.conv1.bias', 'tcascade3.conv1.weight_orig', 'tcascade3.conv1.weight_mask', 'tcascade3.conv2.bias', 'tcascade3.conv2.w

In [50]:
# prune.remove(model, 'weight')
print(list(model.named_parameters()))

[('tcascade1.conv1.bias', Parameter containing:
tensor([ 0.3206,  0.2277, -0.0224, -0.1190, -0.1219, -0.2464, -0.3047,  0.1758,
         0.2369, -0.2260, -0.2610, -0.0997,  0.2779, -0.1167, -0.2425,  0.1505,
        -0.1895, -0.0439, -0.3020, -0.1299,  0.0540,  0.0038,  0.1878,  0.2220,
         0.1483,  0.0487, -0.0083,  0.1714, -0.2891, -0.3311,  0.2973, -0.2278],
       device='cuda:0', requires_grad=True)), ('tcascade1.conv1.weight_orig', Parameter containing:
tensor([[[[-0.0712, -0.1197, -0.1174],
          [ 0.0717, -0.0932,  0.3034],
          [-0.1384,  0.1265,  0.0146]]],


        [[[-0.1617,  0.1207,  0.1376],
          [ 0.2227, -0.2087,  0.1456],
          [ 0.0589, -0.0848,  0.1665]]],


        [[[ 0.0352,  0.2914,  0.0686],
          [ 0.0218, -0.0485, -0.2296],
          [ 0.0912, -0.0778,  0.1898]]],


        [[[ 0.2657,  0.1528,  0.1511],
          [-0.2501, -0.2871, -0.0963],
          [-0.2052,  0.0216, -0.0112]]],


        [[[-0.2729,  0.1264,  0.0964],
        

In [51]:
# prune.remove(parameters_to_prune)
print(list(student.named_parameters()))

[('tcascade1.conv1.bias', Parameter containing:
tensor([ 0.0311, -0.0142, -0.3228, -0.2544, -0.2019, -0.1662, -0.0742, -0.2472,
        -0.2262,  0.1288,  0.2413, -0.0203, -0.2599,  0.1825,  0.1367, -0.1379,
         0.2336, -0.0838,  0.3217,  0.1526, -0.2622, -0.2148,  0.1181,  0.3123,
        -0.0679,  0.0685, -0.3036,  0.1446, -0.2558, -0.0536, -0.2969,  0.2541],
       device='cuda:0', requires_grad=True)), ('tcascade1.conv1.weight', Parameter containing:
tensor([[[[-0.3107, -0.3135,  0.0000],
          [-0.2999,  0.0000, -0.0493],
          [ 0.0523, -0.2555,  0.2878]]],


        [[[-0.1362,  0.1813,  0.2024],
          [ 0.0422,  0.0719,  0.0399],
          [ 0.2823,  0.2276, -0.0504]]],


        [[[ 0.0235, -0.2656, -0.0618],
          [ 0.2222,  0.0725,  0.0357],
          [-0.2199,  0.0253,  0.0759]]],


        [[[ 0.2683,  0.1784, -0.2634],
          [-0.1947, -0.2500,  0.0634],
          [-0.2611,  0.0555,  0.3294]]],


        [[[ 0.0000, -0.2222,  0.2909],
          [ 0

In [28]:
for module, thing in parameters_to_prune:
    print(module)
#     prune.remove(module,'weight')

Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(1, 32, ke

## Extending ``torch.nn.utils.prune`` with custom pruning functions
To implement your own pruning function, you can extend the
``nn.utils.prune`` module by subclassing the ``BasePruningMethod``
base class, the same way all other pruning methods do. The base class
implements the following methods for you: ``__call__``, ``apply_mask``,
``apply``, ``prune``, and ``remove``. Beyond some special cases, you shouldn't
have to reimplement these methods for your new pruning technique.
You will, however, have to implement ``__init__`` (the constructor),
and ``compute_mask`` (the instructions on how to compute the mask
for the given tensor according to the logic of your pruning
technique). In addition, you will have to specify which type of
pruning this technique implements (supported options are ``global``,
``structured``, and ``unstructured``). This is needed to determine
how to combine masks in the case in which pruning is applied
iteratively. In other words, when pruning a pre-pruned parameter,
the current prunining techique is expected to act on the unpruned
portion of the parameter. Specifying the ``PRUNING_TYPE`` will
enable the ``PruningContainer`` (which handles the iterative
application of pruning masks) to correctly identify the slice of the
parameter to prune.

Let's assume, for example, that you want to implement a pruning
technique that prunes every other entry in a tensor (or -- if the
tensor has previously been pruned -- in the remaining unpruned
portion of the tensor). This will be of ``PRUNING_TYPE='unstructured'``
because it acts on individual connections in a layer and not on entire
units/channels (``'structured'``), or across different parameters
(``'global'``).



In [None]:
class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0 
        return mask

Now, to apply this to a parameter in an ``nn.Module``, you should
also provide a simple function that instantiates the method and
applies it.



In [None]:
def foobar_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensors.
    Modifies module in place (and also return the modified module) 
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the 
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the 
    original (unpruned) parameter is stored in a new parameter named 
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module
    
    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

Let's try it out!



In [None]:
model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)