In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [2]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class LeNet(nn.Module):
    """LeNet-style CNN: two 5x5 convolutions followed by three linear layers.

    The first linear layer takes 16 * 5 * 5 features, which is what the two
    conv + 2x2 max-pool stages produce for a single-channel 32x32 input
    (32 -> 28 -> 14 -> 10 -> 5 per spatial side). The network emits 10 values
    per sample.
    """

    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 channels x 5x5 feature map
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: tensor of shape (N, 1, H, W); H = W = 32 yields the 400
                features that ``fc1`` expects.

        Returns:
            Tensor of shape (N, 10) with the raw outputs of ``fc3`` (no
            softmax is applied).
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything except the batch dimension. Unlike dividing
        # nelement() by shape[0], this also works for an empty batch.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Build the network, then move its parameters onto the selected device.
model = LeNet()
model = model.to(device)

In [9]:
# Add up the element counts of every parameter tensor in the model.
pytorch_total_params = sum(
    [param.numel() for param in model.parameters()]
)
print(pytorch_total_params)

61706


In [10]:
# Take the first conv layer and list its (name, parameter) pairs.
module = model.conv1
print([*module.named_parameters()])

[('weight', Parameter containing:
tensor([[[[ 0.0589, -0.1224,  0.0884,  0.1056,  0.0677],
          [-0.0132, -0.0639,  0.0321, -0.0326, -0.0756],
          [-0.0839,  0.1315,  0.1541, -0.0199, -0.1780],
          [ 0.1895,  0.0444,  0.1327,  0.0973, -0.1855],
          [ 0.1290, -0.1621, -0.0438, -0.0392,  0.0958]]],


        [[[-0.0294,  0.1710, -0.0709, -0.1965,  0.0258],
          [ 0.0577, -0.1364, -0.0339,  0.0797, -0.0121],
          [ 0.0374,  0.1337, -0.0736,  0.0487,  0.1601],
          [-0.0716, -0.1105, -0.0420, -0.1950, -0.0470],
          [ 0.1307, -0.1387,  0.1195,  0.0025, -0.0476]]],


        [[[-0.0681, -0.1309, -0.0585, -0.0875, -0.0764],
          [-0.1996,  0.0377, -0.0751, -0.1404,  0.1814],
          [-0.1874, -0.1908, -0.0563, -0.0602, -0.1702],
          [-0.1683, -0.1378,  0.0075,  0.0454,  0.0593],
          [-0.0528, -0.0723, -0.1319, -0.0541, -0.1055]]],


        [[[ 0.1962,  0.0700, -0.1572,  0.0175, -0.0665],
          [ 0.0678,  0.1529, -0.0304, -0.1

In [11]:
# List the (name, buffer) pairs currently registered on the module.
print([*module.named_buffers()])

[]


In [12]:
# Unstructured random pruning of the module's "weight": amount=0.3 means
# 30% of the entries are pruned at random.
prune.random_unstructured(
    module,
    name="weight",
    amount=0.3,
)

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

In [14]:
# Number of weights pruned: count the zero entries in the mask.
(module.weight_mask == 0).sum()

tensor(45)

In [15]:
# The original weights are now stored under the name 'weight_orig'.
print([*module.named_parameters()])

[('bias', Parameter containing:
tensor([-0.0894,  0.1327, -0.1635, -0.0425,  0.1689, -0.1784],
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.0589, -0.1224,  0.0884,  0.1056,  0.0677],
          [-0.0132, -0.0639,  0.0321, -0.0326, -0.0756],
          [-0.0839,  0.1315,  0.1541, -0.0199, -0.1780],
          [ 0.1895,  0.0444,  0.1327,  0.0973, -0.1855],
          [ 0.1290, -0.1621, -0.0438, -0.0392,  0.0958]]],


        [[[-0.0294,  0.1710, -0.0709, -0.1965,  0.0258],
          [ 0.0577, -0.1364, -0.0339,  0.0797, -0.0121],
          [ 0.0374,  0.1337, -0.0736,  0.0487,  0.1601],
          [-0.0716, -0.1105, -0.0420, -0.1950, -0.0470],
          [ 0.1307, -0.1387,  0.1195,  0.0025, -0.0476]]],


        [[[-0.0681, -0.1309, -0.0585, -0.0875, -0.0764],
          [-0.1996,  0.0377, -0.0751, -0.1404,  0.1814],
          [-0.1874, -0.1908, -0.0563, -0.0602, -0.1702],
          [-0.1683, -0.1378,  0.0075,  0.0454,  0.0593],
          [-0.0528, -0.0723, -0.

In [16]:
# List the module's buffers again; the recorded output shows a 'weight_mask' entry.
print([*module.named_buffers()])

[('weight_mask', tensor([[[[1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 0.],
          [1., 1., 1., 1., 1.],
          [0., 0., 1., 1., 1.],
          [1., 1., 1., 0., 1.]]],


        [[[1., 1., 1., 1., 0.],
          [1., 1., 1., 0., 1.],
          [1., 1., 1., 0., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.]]],


        [[[0., 1., 0., 1., 1.],
          [1., 0., 1., 1., 1.],
          [1., 1., 0., 1., 1.],
          [1., 1., 1., 0., 0.],
          [1., 1., 1., 1., 0.]]],


        [[[0., 0., 0., 0., 1.],
          [1., 0., 1., 0., 1.],
          [1., 1., 0., 1., 1.],
          [0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.]]],


        [[[0., 1., 1., 1., 0.],
          [1., 1., 0., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 0., 1., 0., 0.],
          [1., 1., 1., 0., 1.]]],


        [[[0., 1., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [1., 1., 1., 0., 1.],
          [1., 1., 1., 1., 0.],
          [0., 1., 0., 0., 0.]]]]))

In [17]:
# Show the module's `weight` attribute; in the recorded output its entries are
# zero at the positions where `weight_mask` is 0.
print(module.weight)

tensor([[[[ 0.0589, -0.1224,  0.0884,  0.1056,  0.0677],
          [-0.0132, -0.0639,  0.0321, -0.0000, -0.0000],
          [-0.0839,  0.1315,  0.1541, -0.0199, -0.1780],
          [ 0.0000,  0.0000,  0.1327,  0.0973, -0.1855],
          [ 0.1290, -0.1621, -0.0438, -0.0000,  0.0958]]],


        [[[-0.0294,  0.1710, -0.0709, -0.1965,  0.0000],
          [ 0.0577, -0.1364, -0.0339,  0.0000, -0.0121],
          [ 0.0374,  0.1337, -0.0736,  0.0000,  0.1601],
          [-0.0716, -0.1105, -0.0420, -0.1950, -0.0470],
          [ 0.1307, -0.1387,  0.1195,  0.0000, -0.0476]]],


        [[[-0.0000, -0.1309, -0.0000, -0.0875, -0.0764],
          [-0.1996,  0.0000, -0.0751, -0.1404,  0.1814],
          [-0.1874, -0.1908, -0.0000, -0.0602, -0.1702],
          [-0.1683, -0.1378,  0.0075,  0.0000,  0.0000],
          [-0.0528, -0.0723, -0.1319, -0.0541, -0.0000]]],


        [[[ 0.0000,  0.0000, -0.0000,  0.0000, -0.0665],
          [ 0.0678,  0.0000, -0.0304, -0.0000, -0.0271],
          [ 0.0717,