In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [7]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook survives "Restart & Run All" on machines without CUDA (a hard-coded
# "cuda" device makes every later `.to(device)` call fail there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNG: weight initialisation and `prune.random_unstructured` are both
# random, so without a seed every re-run produces different masks and numbers.
SEED = 0
torch.manual_seed(SEED)

class LeNet(nn.Module):
  """Small LeNet-style CNN for single-channel images with 10 outputs.

  The fc1 input size (16 * 5 * 5) implies a 32x32 input image:
  32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5.
  """

  def __init__(self):
    super().__init__()

    # 1 input image channel, 6 output channels, 5x5 square conv kernel
    self.conv1 = nn.Conv2d(1, 6, 5)
    # 6 input channels, 16 output channels, 5x5 square conv kernel
    self.conv2 = nn.Conv2d(6, 16, 5)
    # 16 channels x 5x5 spatial map left after the second conv + pool stage
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    """Map a batch `x` (assumed shape (N, 1, 32, 32)) to outputs of shape (N, 10)."""
    x = F.max_pool2d(F.relu(self.conv1(x)), 2)
    x = F.max_pool2d(F.relu(self.conv2(x)), 2)
    # Flatten all dimensions except the batch one. This avoids the float
    # division `nelement / batch` (which also divides by zero on an empty batch).
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

# Build the network, then move its parameters and buffers onto `device`
# (nn.Module.to works in place and returns the same module).
model = LeNet()
model.to(device)

In [8]:
# Inspect conv1 before any pruning: it only exposes its ordinary parameters.
module = model.conv1
print([*module.named_parameters()])

[('weight', Parameter containing:
tensor([[[[ 0.0054,  0.1585, -0.0415, -0.0173,  0.1092],
          [ 0.0304,  0.1292, -0.0807,  0.0465, -0.1192],
          [ 0.1186,  0.0707, -0.1917,  0.0334, -0.1369],
          [-0.1165, -0.1298,  0.1269, -0.1175, -0.0077],
          [-0.0231, -0.1308,  0.1173,  0.1724, -0.1338]]],


        [[[-0.0367,  0.1611,  0.1250, -0.0599, -0.1525],
          [ 0.0013,  0.1350,  0.1147, -0.1733, -0.0767],
          [ 0.1225, -0.1981,  0.1737,  0.0225,  0.1848],
          [ 0.0158, -0.1383, -0.0958, -0.0053, -0.0439],
          [ 0.1980, -0.1148,  0.0593, -0.0423, -0.0234]]],


        [[[ 0.1603, -0.0760, -0.0866, -0.0023, -0.0098],
          [-0.1082, -0.0529, -0.1309,  0.1340,  0.1907],
          [ 0.0633,  0.1868,  0.0369,  0.0926, -0.1605],
          [-0.1349, -0.1510, -0.0604, -0.0316,  0.1803],
          [-0.0571,  0.1223, -0.1137,  0.1718, -0.0852]]],


        [[[ 0.0999,  0.1339,  0.0968,  0.0373, -0.0722],
          [-0.0567,  0.1202,  0.0716, -0.0

In [9]:
# Nothing has been pruned yet, so conv1 carries no buffers. Inspect `module`
# (not the whole model) so this is directly comparable with the post-pruning
# buffer listings of the same layer.
print(list(module.named_buffers()))

[]


In [10]:
# Randomly prune 30% of conv1's weight entries. Pruning reparametrises the
# tensor: the original values move to the `weight_orig` parameter, a
# `weight_mask` buffer is added, and a forward pre-hook rebuilds `weight`.
# The mask is random, so it differs between runs unless the RNG is seeded.
prune.random_unstructured(module, "weight", amount=0.3)

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

In [11]:
# `weight` is no longer a parameter; its unpruned values now live in `weight_orig`.
print([*module.named_parameters()])

[('bias', Parameter containing:
tensor([-0.1638,  0.0620, -0.1033,  0.0749, -0.1727, -0.1186], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.0054,  0.1585, -0.0415, -0.0173,  0.1092],
          [ 0.0304,  0.1292, -0.0807,  0.0465, -0.1192],
          [ 0.1186,  0.0707, -0.1917,  0.0334, -0.1369],
          [-0.1165, -0.1298,  0.1269, -0.1175, -0.0077],
          [-0.0231, -0.1308,  0.1173,  0.1724, -0.1338]]],


        [[[-0.0367,  0.1611,  0.1250, -0.0599, -0.1525],
          [ 0.0013,  0.1350,  0.1147, -0.1733, -0.0767],
          [ 0.1225, -0.1981,  0.1737,  0.0225,  0.1848],
          [ 0.0158, -0.1383, -0.0958, -0.0053, -0.0439],
          [ 0.1980, -0.1148,  0.0593, -0.0423, -0.0234]]],


        [[[ 0.1603, -0.0760, -0.0866, -0.0023, -0.0098],
          [-0.1082, -0.0529, -0.1309,  0.1340,  0.1907],
          [ 0.0633,  0.1868,  0.0369,  0.0926, -0.1605],
          [-0.1349, -0.1510, -0.0604, -0.0316,  0.1803],
          [-0.0

In [12]:
# The pruning mask is stored as the `weight_mask` buffer (1 = kept, 0 = pruned).
print([*module.named_buffers()])

[('weight_mask', tensor([[[[1., 1., 1., 1., 0.],
          [1., 0., 1., 1., 0.],
          [1., 1., 1., 0., 0.],
          [1., 1., 1., 1., 1.],
          [0., 1., 1., 0., 0.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 0.],
          [1., 0., 1., 0., 1.],
          [1., 1., 1., 0., 1.]]],


        [[[1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 1., 0., 1., 0.],
          [1., 0., 0., 1., 1.]]],


        [[[0., 1., 0., 0., 1.],
          [0., 1., 0., 1., 1.],
          [1., 1., 1., 1., 1.],
          [0., 1., 1., 1., 1.],
          [0., 0., 1., 1., 1.]]],


        [[[0., 0., 0., 1., 0.],
          [1., 1., 0., 1., 1.],
          [1., 0., 0., 1., 1.],
          [1., 1., 1., 1., 1.],
          [0., 1., 0., 1., 1.]]],


        [[[0., 1., 1., 1., 0.],
          [0., 1., 1., 0., 1.],
          [1., 1., 0., 1., 1.],
          [1., 0., 1., 1., 0.],
          [0., 1., 0., 1., 1.]]]], 

In [14]:
# The effective weight is weight_orig masked by weight_mask: entries whose
# mask is 0 now print as (+/-)0.0000.
print(module.weight)

tensor([[[[ 0.0054,  0.1585, -0.0415, -0.0173,  0.0000],
          [ 0.0304,  0.0000, -0.0807,  0.0465, -0.0000],
          [ 0.1186,  0.0707, -0.1917,  0.0000, -0.0000],
          [-0.1165, -0.1298,  0.1269, -0.1175, -0.0077],
          [-0.0000, -0.1308,  0.1173,  0.0000, -0.0000]]],


        [[[-0.0367,  0.1611,  0.1250, -0.0599, -0.1525],
          [ 0.0013,  0.1350,  0.1147, -0.1733, -0.0767],
          [ 0.1225, -0.1981,  0.1737,  0.0000,  0.0000],
          [ 0.0158, -0.0000, -0.0958, -0.0000, -0.0439],
          [ 0.1980, -0.1148,  0.0593, -0.0000, -0.0234]]],


        [[[ 0.1603, -0.0760, -0.0866, -0.0023, -0.0000],
          [-0.1082, -0.0529, -0.1309,  0.1340,  0.0000],
          [ 0.0633,  0.1868,  0.0369,  0.0926, -0.1605],
          [-0.1349, -0.1510, -0.0000, -0.0316,  0.0000],
          [-0.0571,  0.0000, -0.0000,  0.1718, -0.0852]]],


        [[[ 0.0000,  0.1339,  0.0000,  0.0000, -0.0722],
          [-0.0000,  0.1202,  0.0000, -0.0564, -0.0484],
          [-0.1369,

In [15]:
# Pruning registered a forward pre-hook (the RandomUnstructured object) that
# recomputes `weight` before each forward pass. Note: `_forward_pre_hooks` is
# a private nn.Module attribute, used here only for inspection.
print(module._forward_pre_hooks)

OrderedDict({0: <torch.nn.utils.prune.RandomUnstructured object at 0x7f8cde9b59a0>})


In [16]:
# An integer `amount` is an absolute count rather than a fraction: zero out
# the 3 bias entries with the smallest absolute value.
prune.l1_unstructured(module, "bias", amount=3)

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

In [18]:
# After pruning the bias, it is reparametrised the same way (`bias_orig`).
print([*module.named_parameters()])

[('weight_orig', Parameter containing:
tensor([[[[ 0.0054,  0.1585, -0.0415, -0.0173,  0.1092],
          [ 0.0304,  0.1292, -0.0807,  0.0465, -0.1192],
          [ 0.1186,  0.0707, -0.1917,  0.0334, -0.1369],
          [-0.1165, -0.1298,  0.1269, -0.1175, -0.0077],
          [-0.0231, -0.1308,  0.1173,  0.1724, -0.1338]]],


        [[[-0.0367,  0.1611,  0.1250, -0.0599, -0.1525],
          [ 0.0013,  0.1350,  0.1147, -0.1733, -0.0767],
          [ 0.1225, -0.1981,  0.1737,  0.0225,  0.1848],
          [ 0.0158, -0.1383, -0.0958, -0.0053, -0.0439],
          [ 0.1980, -0.1148,  0.0593, -0.0423, -0.0234]]],


        [[[ 0.1603, -0.0760, -0.0866, -0.0023, -0.0098],
          [-0.1082, -0.0529, -0.1309,  0.1340,  0.1907],
          [ 0.0633,  0.1868,  0.0369,  0.0926, -0.1605],
          [-0.1349, -0.1510, -0.0604, -0.0316,  0.1803],
          [-0.0571,  0.1223, -0.1137,  0.1718, -0.0852]]],


        [[[ 0.0999,  0.1339,  0.0968,  0.0373, -0.0722],
          [-0.0567,  0.1202,  0.0716,

In [19]:
# A `bias_mask` buffer should now sit alongside `weight_mask`.
print([*module.named_buffers()])

[('weight_mask', tensor([[[[1., 1., 1., 1., 0.],
          [1., 0., 1., 1., 0.],
          [1., 1., 1., 0., 0.],
          [1., 1., 1., 1., 1.],
          [0., 1., 1., 0., 0.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 0.],
          [1., 0., 1., 0., 1.],
          [1., 1., 1., 0., 1.]]],


        [[[1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 1., 0., 1., 0.],
          [1., 0., 0., 1., 1.]]],


        [[[0., 1., 0., 0., 1.],
          [0., 1., 0., 1., 1.],
          [1., 1., 1., 1., 1.],
          [0., 1., 1., 1., 1.],
          [0., 0., 1., 1., 1.]]],


        [[[0., 0., 0., 1., 0.],
          [1., 1., 0., 1., 1.],
          [1., 0., 0., 1., 1.],
          [1., 1., 1., 1., 1.],
          [0., 1., 0., 1., 1.]]],


        [[[0., 1., 1., 1., 0.],
          [0., 1., 1., 0., 1.],
          [1., 1., 0., 1., 1.],
          [1., 0., 1., 1., 0.],
          [0., 1., 0., 1., 1.]]]], 

In [20]:
# The three smallest-magnitude bias entries are now zero; `bias` is a computed
# tensor (grad_fn=MulBackward0), not a leaf parameter.
print(module.bias)

tensor([-0.1638,  0.0000, -0.0000,  0.0000, -0.1727, -0.1186], device='cuda:0',
       grad_fn=<MulBackward0>)


In [21]:
# One pruning pre-hook per pruned tensor: the RandomUnstructured hook for
# `weight` and the L1Unstructured hook for `bias`.
print(module._forward_pre_hooks)

OrderedDict({0: <torch.nn.utils.prune.RandomUnstructured object at 0x7f8cde9b59a0>, 1: <torch.nn.utils.prune.L1Unstructured object at 0x7f8cde7c0bf0>})


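# Iterative pruning

Pruning a parameter that is already pruned combines the new mask with the existing one; all pruning methods applied to that parameter are collected, in order, in a `PruningContainer`.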

In [24]:
# Structured pruning on top of the earlier random mask: remove 50% of conv1's
# output channels (dim=0), ranked by their L2 norm (n=2).
# NOTE: run this cell exactly once. Every execution appends another
# LnStructured pass to the weight's pruning container; the hook list printed
# further down contains two, i.e. this cell was executed twice.
prune.ln_structured(module, "weight", amount=0.5, n=2, dim=0)
print(module.weight)

tensor([[[[ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000,  0.0000, -0.0000]]],


        [[[-0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000]]],


        [[[ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000,  0.0000, -0.0000]]],


        [[[ 0.0000,  0.0000,  0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,

In [25]:
# Locate the pruning hook that manages conv1's `weight`. Since `weight` has
# been pruned more than once, this hook is a PruningContainer that lists every
# pruning method applied to it, in order. Pre-hooks that are not pruning
# methods have no `_tensor_name`, hence the getattr default.
hook = next(
    (
        candidate
        for candidate in module._forward_pre_hooks.values()
        if getattr(candidate, "_tensor_name", None) == "weight"
    ),
    None,
)
if hook is None:
    # Fail loudly instead of silently falling through to an unrelated hook.
    raise RuntimeError("No pruning hook found for 'weight' on this module")

# A tensor pruned only once holds a bare pruning method, not a container.
print(list(hook) if isinstance(hook, prune.PruningContainer) else [hook])

[<torch.nn.utils.prune.RandomUnstructured object at 0x7f8cde9b59a0>, <torch.nn.utils.prune.LnStructured object at 0x7f8cde7c21b0>, <torch.nn.utils.prune.LnStructured object at 0x7f8cde9b72f0>]


In [26]:
# Make the weight pruning permanent. `weight_orig`, `weight_mask` and the
# weight hook are dropped, leaving a plain `weight` parameter that already
# contains the zeros. The bias remains reparametrised (`bias_orig`).
prune.remove(module, "weight")
print([*module.named_parameters()])

[('bias_orig', Parameter containing:
tensor([-0.1638,  0.0620, -0.1033,  0.0749, -0.1727, -0.1186], device='cuda:0',
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000,  0.0000, -0.0000]]],


        [[[-0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000]]],


        [[[ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
          [-0.0

In [27]:
# Start from a fresh, unpruned network for global pruning. This rebinds
# `model`; the earlier `module` still refers to the previous model's conv1.
model = LeNet()

# (layer, parameter-name) pairs whose weights compete in one global ranking.
parameters_to_prune = tuple(
    (getattr(model, layer_name), "weight")
    for layer_name in ("conv1", "conv2", "fc1", "fc2", "fc3")
)

# Remove the 20% of weights with the smallest absolute value across all of
# the listed tensors together, rather than 20% from each layer.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

In [28]:
# Report the share of exactly-zero weights per pruned layer and overall.
# Because the pruning above ranks weights globally, individual layers end up
# more or less sparse than 20% while the total over all five tensors is 20%.
def sparsity_percent(num_zeros: int, num_elements: int) -> float:
    """Return `num_zeros` as a percentage of `num_elements`."""
    return 100. * float(num_zeros) / float(num_elements)


total_zeros = 0
total_elements = 0
for layer_name in ("conv1", "conv2", "fc1", "fc2", "fc3"):
    layer_weight = getattr(model, layer_name).weight
    layer_zeros = int(torch.sum(layer_weight == 0))
    layer_elements = layer_weight.nelement()
    print(
        "Sparsity in {}.weight: {:.2f}%".format(
            layer_name, sparsity_percent(layer_zeros, layer_elements)
        )
    )
    total_zeros += layer_zeros
    total_elements += layer_elements

print(
    "Global sparsity: {:.2f}%".format(
        sparsity_percent(total_zeros, total_elements)
    )
)

Sparsity in conv1.weight: 6.00%
Sparsity in conv2.weight: 13.71%
Sparsity in fc1.weight: 22.13%
Sparsity in fc2.weight: 12.23%
Sparsity in fc3.weight: 11.79%
Global sparsity: 20.00%
