In [1]:
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune

In [2]:
# Select the CPU as the torch device.
# NOTE(review): `device` is not referenced by any other cell visible in this
# notebook; the models below are created without an explicit device.
device = torch.device("cpu")

In [3]:
class LeNet(nn.Module):
    """LeNet-style CNN: two 3x3 convolutions followed by three linear layers.

    ``fc1`` expects ``16 * 5 * 5`` input features. With unpadded 3x3
    convolutions and 2x2 max-pooling that corresponds to single-channel
    inputs of e.g. 28x28 pixels (28 -> 26 -> 13 -> 11 -> 5).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 input channel -> 6 -> 16 feature maps, 3x3 kernels.
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        # Classifier head: 16 maps of 5x5 -> 120 -> 84 -> 10 outputs.
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the network on a batch ``x`` and return the raw 10-way outputs.

        Args:
            x: input tensor; a 4D ``(batch, 1, H, W)`` batch whose spatial
                size reduces to 5x5 after the two conv/pool stages.

        Returns:
            Tensor of shape ``(batch, 10)``; no softmax is applied.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), (2,2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2,2))
        # Flatten everything except the batch dimension. Unlike the manual
        # `view(-1, nelement / batch)` this needs no float division (which
        # raised ZeroDivisionError for an empty batch) and also works when
        # the activations are not contiguous in memory.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [4]:
# Instantiate the LeNet defined above; its `conv1` layer is the one pruned
# step by step in the following cells.
model = LeNet()

In [6]:
# NOTE(review): this evaluates the bound method without calling it, so the
# cell displays the method's repr (which embeds the model summary). Call
# `model.modules()` to actually iterate over the module and its submodules.
model.modules

<bound method Module.modules of LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)>

In [7]:
# NOTE(review): the method is referenced, not called, so only its bound-method
# repr is shown. Call `model.children()` to iterate over the direct child
# modules.
model.children

<bound method Module.children of LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)>

In [8]:
# NOTE(review): the method is referenced, not called, so only its bound-method
# repr is shown. Call `model.named_children()` to get (name, module) pairs for
# the direct child modules.
model.named_children

<bound method Module.named_children of LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)>

In [9]:
# NOTE(review): the method is referenced, not called, so only its bound-method
# repr is shown. Call `model.named_modules()` to get (name, module) pairs for
# the model and all of its submodules.
model.named_modules

<bound method Module.named_modules of LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)>

In [7]:
# Work on the first convolution layer. Before any pruning its parameters are
# the plain `weight` and `bias` tensors.
module = model.conv1
list(module.named_parameters())

[('weight',
  Parameter containing:
  tensor([[[[-0.2487,  0.3102,  0.2069],
            [-0.0560, -0.1190, -0.3054],
            [-0.1513, -0.0013,  0.1598]]],
  
  
          [[[-0.3121, -0.2036, -0.0970],
            [-0.1208,  0.0731, -0.0441],
            [ 0.3108, -0.1555,  0.3227]]],
  
  
          [[[-0.0645, -0.0914, -0.0298],
            [ 0.3153, -0.1837,  0.1004],
            [ 0.0495,  0.3147, -0.1659]]],
  
  
          [[[ 0.2732,  0.0143,  0.2799],
            [ 0.2853,  0.0012, -0.2103],
            [-0.2442,  0.3331, -0.2162]]],
  
  
          [[[ 0.2175,  0.2603,  0.0033],
            [-0.0258,  0.1743,  0.2421],
            [ 0.2314, -0.1907, -0.1123]]],
  
  
          [[[-0.1852,  0.2778, -0.0128],
            [-0.1567,  0.1004, -0.1757],
            [ 0.1725,  0.0271,  0.2683]]]], requires_grad=True)),
 ('bias',
  Parameter containing:
  tensor([ 0.0315, -0.3157, -0.0437,  0.1053, -0.2900, -0.1562],
         requires_grad=True))]

In [11]:
# Buffers of the layer; the list is empty before any pruning is applied.
list(module.named_buffers())

[]

In [12]:
# Randomly prune 30% of the individual entries of `weight` (16 of the 54
# entries end up masked in the recorded output).
# NOTE(review): no random seed is set in the visible cells, so the selected
# entries will differ from run to run.
prune.random_unstructured(module, name='weight', amount=0.3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

In [13]:
# After pruning, the unmodified tensor is kept as the parameter `weight_orig`;
# `weight` itself is no longer listed among the parameters.
list(module.named_parameters())

[('bias',
  Parameter containing:
  tensor([ 0.0315, -0.3157, -0.0437,  0.1053, -0.2900, -0.1562],
         requires_grad=True)),
 ('weight_orig',
  Parameter containing:
  tensor([[[[-0.2487,  0.3102,  0.2069],
            [-0.0560, -0.1190, -0.3054],
            [-0.1513, -0.0013,  0.1598]]],
  
  
          [[[-0.3121, -0.2036, -0.0970],
            [-0.1208,  0.0731, -0.0441],
            [ 0.3108, -0.1555,  0.3227]]],
  
  
          [[[-0.0645, -0.0914, -0.0298],
            [ 0.3153, -0.1837,  0.1004],
            [ 0.0495,  0.3147, -0.1659]]],
  
  
          [[[ 0.2732,  0.0143,  0.2799],
            [ 0.2853,  0.0012, -0.2103],
            [-0.2442,  0.3331, -0.2162]]],
  
  
          [[[ 0.2175,  0.2603,  0.0033],
            [-0.0258,  0.1743,  0.2421],
            [ 0.2314, -0.1907, -0.1123]]],
  
  
          [[[-0.1852,  0.2778, -0.0128],
            [-0.1567,  0.1004, -0.1757],
            [ 0.1725,  0.0271,  0.2683]]]], requires_grad=True))]

In [14]:
# The pruning mask is stored as the buffer `weight_mask`
# (1 = entry kept, 0 = entry pruned).
list(module.named_buffers())

[('weight_mask',
  tensor([[[[1., 0., 1.],
            [1., 0., 1.],
            [1., 1., 1.]]],
  
  
          [[[1., 1., 1.],
            [1., 1., 0.],
            [1., 1., 1.]]],
  
  
          [[[1., 1., 1.],
            [0., 1., 1.],
            [1., 1., 0.]]],
  
  
          [[[1., 1., 1.],
            [0., 0., 1.],
            [1., 0., 1.]]],
  
  
          [[[1., 1., 1.],
            [0., 1., 0.],
            [1., 0., 0.]]],
  
  
          [[[0., 1., 1.],
            [0., 1., 1.],
            [1., 0., 0.]]]]))]

In [18]:
# `weight` is now the masked tensor: entries at masked positions are zero.
# The `MulBackward0` grad_fn in the output indicates it is produced by a
# multiplication (presumably `weight_orig * weight_mask`).
module.weight

tensor([[[[-0.2487,  0.0000,  0.2069],
          [-0.0560, -0.0000, -0.3054],
          [-0.1513, -0.0013,  0.1598]]],


        [[[-0.3121, -0.2036, -0.0970],
          [-0.1208,  0.0731, -0.0000],
          [ 0.3108, -0.1555,  0.3227]]],


        [[[-0.0645, -0.0914, -0.0298],
          [ 0.0000, -0.1837,  0.1004],
          [ 0.0495,  0.3147, -0.0000]]],


        [[[ 0.2732,  0.0143,  0.2799],
          [ 0.0000,  0.0000, -0.2103],
          [-0.2442,  0.0000, -0.2162]]],


        [[[ 0.2175,  0.2603,  0.0033],
          [-0.0000,  0.1743,  0.0000],
          [ 0.2314, -0.0000, -0.0000]]],


        [[[-0.0000,  0.2778, -0.0128],
          [-0.0000,  0.1004, -0.1757],
          [ 0.1725,  0.0000,  0.0000]]]], grad_fn=<MulBackward0>)

In [20]:
# Pruning registered a forward pre-hook on the module; the recorded output
# shows a single RandomUnstructured instance.
module._forward_pre_hooks

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured at 0x7efcd7556828>)])

In [21]:
# Prune `bias` by L1 magnitude. An integer `amount` is an absolute count, so
# the 3 entries with the smallest absolute value are masked.
prune.l1_unstructured(module, 'bias', amount=3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

In [23]:
# Both parameters are now stored under their `_orig` names:
# `weight_orig` and `bias_orig`.
list(module.named_parameters())

[('weight_orig',
  Parameter containing:
  tensor([[[[-0.2487,  0.3102,  0.2069],
            [-0.0560, -0.1190, -0.3054],
            [-0.1513, -0.0013,  0.1598]]],
  
  
          [[[-0.3121, -0.2036, -0.0970],
            [-0.1208,  0.0731, -0.0441],
            [ 0.3108, -0.1555,  0.3227]]],
  
  
          [[[-0.0645, -0.0914, -0.0298],
            [ 0.3153, -0.1837,  0.1004],
            [ 0.0495,  0.3147, -0.1659]]],
  
  
          [[[ 0.2732,  0.0143,  0.2799],
            [ 0.2853,  0.0012, -0.2103],
            [-0.2442,  0.3331, -0.2162]]],
  
  
          [[[ 0.2175,  0.2603,  0.0033],
            [-0.0258,  0.1743,  0.2421],
            [ 0.2314, -0.1907, -0.1123]]],
  
  
          [[[-0.1852,  0.2778, -0.0128],
            [-0.1567,  0.1004, -0.1757],
            [ 0.1725,  0.0271,  0.2683]]]], requires_grad=True)),
 ('bias_orig',
  Parameter containing:
  tensor([ 0.0315, -0.3157, -0.0437,  0.1053, -0.2900, -0.1562],
         requires_grad=True))]

In [24]:
# The buffers now hold one mask per pruned parameter:
# `weight_mask` and `bias_mask`.
list(module.named_buffers())

[('weight_mask',
  tensor([[[[1., 0., 1.],
            [1., 0., 1.],
            [1., 1., 1.]]],
  
  
          [[[1., 1., 1.],
            [1., 1., 0.],
            [1., 1., 1.]]],
  
  
          [[[1., 1., 1.],
            [0., 1., 1.],
            [1., 1., 0.]]],
  
  
          [[[1., 1., 1.],
            [0., 0., 1.],
            [1., 0., 1.]]],
  
  
          [[[1., 1., 1.],
            [0., 1., 0.],
            [1., 0., 0.]]],
  
  
          [[[0., 1., 1.],
            [0., 1., 1.],
            [1., 0., 0.]]]])),
 ('bias_mask', tensor([0., 1., 0., 0., 1., 1.]))]

In [25]:
# `bias` is now the masked tensor: the three pruned entries are zero.
module.bias

tensor([ 0.0000, -0.3157, -0.0000,  0.0000, -0.2900, -0.1562],
       grad_fn=<MulBackward0>)

In [32]:
# There is now one forward pre-hook per pruned parameter: RandomUnstructured
# (for `weight`) and L1Unstructured (for `bias`).
module._forward_pre_hooks

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured at 0x7efcd7556828>),
             (1, <torch.nn.utils.prune.L1Unstructured at 0x7efcd7437940>)])

In [33]:
# Iterative pruning: on top of the earlier random mask, prune 50% of the
# slices along dim 0 of `weight` (the 6 output channels), ranked by their
# L2 norm (n=2).
prune.ln_structured(module, name='weight', amount=0.5, n=2, dim=0)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

In [34]:
# Three of the six channels are now entirely zero; the remaining channels
# still carry the zeros from the earlier unstructured pruning.
module.weight

tensor([[[[-0.2487,  0.0000,  0.2069],
          [-0.0560, -0.0000, -0.3054],
          [-0.1513, -0.0013,  0.1598]]],


        [[[-0.3121, -0.2036, -0.0970],
          [-0.1208,  0.0731, -0.0000],
          [ 0.3108, -0.1555,  0.3227]]],


        [[[-0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000]]],


        [[[ 0.2732,  0.0143,  0.2799],
          [ 0.0000,  0.0000, -0.2103],
          [-0.2442,  0.0000, -0.2162]]],


        [[[ 0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000]]],


        [[[-0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000]]]], grad_fn=<MulBackward0>)

In [35]:
# The hook for `weight` is now a PruningContainer, which replaced the earlier
# RandomUnstructured hook; the L1Unstructured hook for `bias` is unchanged.
module._forward_pre_hooks

OrderedDict([(1, <torch.nn.utils.prune.L1Unstructured at 0x7efcd7437940>),
             (2, <torch.nn.utils.prune.PruningContainer at 0x7efcd7423940>)])

In [42]:
# Locate the pruning hook attached to `weight`. Hooks that carry no
# `_tensor_name` attribute (i.e. hooks not created by pruning) are skipped
# instead of raising AttributeError.
hook = next(
    (
        candidate
        for candidate in module._forward_pre_hooks.values()
        if getattr(candidate, "_tensor_name", None) == "weight"
    ),
    None,
)
if hook is None:
    # Fail loudly rather than silently continuing with an unrelated hook,
    # which is what a fall-through `for ... break` search would do.
    raise LookupError("no pruning hook registered for 'weight' on this module")

# Iterating the hook lists the pruning methods applied to `weight`, in order.
# This relies on the hook being iterable, as the PruningContainer is.
print(list(hook))  # pruning history in the container

[<torch.nn.utils.prune.RandomUnstructured object at 0x7efcd7556828>, <torch.nn.utils.prune.LnStructured object at 0x7efcd7423ac8>]


In [41]:
# Same pruning history, obtained by position: the second entry of
# `_forward_pre_hooks` is the PruningContainer for `weight`.
# NOTE(review): positional indexing depends on the hook registration order.
list(list(module._forward_pre_hooks.values())[1])

[<torch.nn.utils.prune.RandomUnstructured at 0x7efcd7556828>,
 <torch.nn.utils.prune.LnStructured at 0x7efcd7423ac8>]

In [44]:
# The state dict contains the `_orig` tensors and the masks for conv1; the
# unpruned layers keep their plain `weight` / `bias` keys.
model.state_dict().keys()

odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

In [45]:
# Parameters before the pruning is made permanent: still `weight_orig` and
# `bias_orig`.
list(module.named_parameters())

[('weight_orig',
  Parameter containing:
  tensor([[[[-0.2487,  0.3102,  0.2069],
            [-0.0560, -0.1190, -0.3054],
            [-0.1513, -0.0013,  0.1598]]],
  
  
          [[[-0.3121, -0.2036, -0.0970],
            [-0.1208,  0.0731, -0.0441],
            [ 0.3108, -0.1555,  0.3227]]],
  
  
          [[[-0.0645, -0.0914, -0.0298],
            [ 0.3153, -0.1837,  0.1004],
            [ 0.0495,  0.3147, -0.1659]]],
  
  
          [[[ 0.2732,  0.0143,  0.2799],
            [ 0.2853,  0.0012, -0.2103],
            [-0.2442,  0.3331, -0.2162]]],
  
  
          [[[ 0.2175,  0.2603,  0.0033],
            [-0.0258,  0.1743,  0.2421],
            [ 0.2314, -0.1907, -0.1123]]],
  
  
          [[[-0.1852,  0.2778, -0.0128],
            [-0.1567,  0.1004, -0.1757],
            [ 0.1725,  0.0271,  0.2683]]]], requires_grad=True)),
 ('bias_orig',
  Parameter containing:
  tensor([ 0.0315, -0.3157, -0.0437,  0.1053, -0.2900, -0.1562],
         requires_grad=True))]

In [46]:
# Buffers before `prune.remove`: `weight_mask` (now combining the random and
# the structured pruning) and `bias_mask`.
list(module.named_buffers())

[('weight_mask',
  tensor([[[[1., 0., 1.],
            [1., 0., 1.],
            [1., 1., 1.]]],
  
  
          [[[1., 1., 1.],
            [1., 1., 0.],
            [1., 1., 1.]]],
  
  
          [[[0., 0., 0.],
            [0., 0., 0.],
            [0., 0., 0.]]],
  
  
          [[[1., 1., 1.],
            [0., 0., 1.],
            [1., 0., 1.]]],
  
  
          [[[0., 0., 0.],
            [0., 0., 0.],
            [0., 0., 0.]]],
  
  
          [[[0., 0., 0.],
            [0., 0., 0.],
            [0., 0., 0.]]]])),
 ('bias_mask', tensor([0., 1., 0., 0., 1., 1.]))]

In [47]:
# Masked `weight` while the re-parametrization is still in place
# (note the `MulBackward0` grad_fn in the output).
module.weight

tensor([[[[-0.2487,  0.0000,  0.2069],
          [-0.0560, -0.0000, -0.3054],
          [-0.1513, -0.0013,  0.1598]]],


        [[[-0.3121, -0.2036, -0.0970],
          [-0.1208,  0.0731, -0.0000],
          [ 0.3108, -0.1555,  0.3227]]],


        [[[-0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000]]],


        [[[ 0.2732,  0.0143,  0.2799],
          [ 0.0000,  0.0000, -0.2103],
          [-0.2442,  0.0000, -0.2162]]],


        [[[ 0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000]]],


        [[[-0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000]]]], grad_fn=<MulBackward0>)

In [49]:
# Make the pruning of `weight` permanent by removing its re-parametrization.
prune.remove(module, 'weight')

# `weight_mask` is gone afterwards; only `bias_mask` remains, because the
# pruning of `bias` was not removed.
list(module.named_buffers())


[('bias_mask', tensor([0., 1., 0., 0., 1., 1.]))]

In [50]:
# Fresh, unpruned model used below to prune several layers in one loop.
new_model = LeNet()

In [56]:
# `named_modules()` yields (name, module) pairs, starting with the root model
# under the empty name '' and followed by each layer.
list(new_model.named_modules())

[('',
  LeNet(
    (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
    (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
    (fc1): Linear(in_features=400, out_features=120, bias=True)
    (fc2): Linear(in_features=120, out_features=84, bias=True)
    (fc3): Linear(in_features=84, out_features=10, bias=True)
  )),
 ('conv1', Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))),
 ('conv2', Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))),
 ('fc1', Linear(in_features=400, out_features=120, bias=True)),
 ('fc2', Linear(in_features=120, out_features=84, bias=True)),
 ('fc3', Linear(in_features=84, out_features=10, bias=True))]

In [59]:
# Manual iterator over the (name, module) pairs, advanced with `next(i)` in
# the following cell.
i=iter(new_model.named_modules())

In [62]:
# NOTE(review): the recorded output is the third pair ('conv2'), so this cell
# was evidently executed several times. On a fresh run the first `next(i)`
# returns the root entry ('', LeNet(...)).
next(i)

('conv2', Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1)))

In [68]:
# Fraction of connections to prune per layer type: 20% for convolutions,
# 40% for linear layers. The first matching type wins.
amount_by_layer_type = (
    (nn.Conv2d, 0.2),
    (nn.Linear, 0.4),
)

for name, module in new_model.named_modules():
    for layer_type, amount in amount_by_layer_type:
        if isinstance(module, layer_type):
            # L1-magnitude, unstructured pruning of this layer's weight.
            prune.l1_unstructured(module, name='weight', amount=amount)
            break

In [72]:
# Every conv and linear layer of `new_model` now carries a `weight_mask`
# buffer.
dict(new_model.named_buffers()).keys()

dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])

In [82]:
# Start again from an unpruned model.
model = LeNet()

# (module, parameter name) pairs for the weight tensor of every layer.
layers_to_prune = [model.conv1, model.conv2, model.fc1, model.fc2, model.fc3]
parameters_to_prune = tuple((layer, 'weight') for layer in layers_to_prune)

# Global pruning: L1-magnitude pruning with amount=0.2 applied across all
# listed tensors taken together, not 20% within each layer.
prune.global_unstructured(
    parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2
)

In [None]:
def _weight_sparsity_percent(*weights):
    """Return the percentage of exactly-zero entries across ``weights``.

    Args:
        *weights: one or more tensors; zero counts and element counts are
            summed over all of them before the ratio is taken.

    Returns:
        float: ``100 * (#zeros) / (#elements)``.
    """
    zero_count = sum(int(torch.sum(w == 0)) for w in weights)
    total_count = sum(w.nelement() for w in weights)
    return 100. * float(zero_count) / float(total_count)


# Layers whose `weight` sparsity is reported, in print order.
_layer_names = ("conv1", "conv2", "fc1", "fc2", "fc3")

# Per-layer sparsity.
for _layer_name in _layer_names:
    print(
        "Sparsity in {}.weight: {:.2f}%".format(
            _layer_name,
            _weight_sparsity_percent(getattr(model, _layer_name).weight),
        )
    )

# Sparsity over all reported weight tensors taken together.
print(
    "Global sparsity: {:.2f}%".format(
        _weight_sparsity_percent(
            *(getattr(model, _layer_name).weight for _layer_name in _layer_names)
        )
    )
)

In [99]:
# Percentage of exactly-zero entries in conv2.weight.
conv2_weight = model.conv2.weight
zero_fraction = torch.sum(conv2_weight == 0) / float(conv2_weight.nelement())
zero_fraction * 100

tensor(8.2176)