<a href="https://colab.research.google.com/github/HappyGithub-dev/Pruning_Network/blob/main/RoboReg_Network_Pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [None]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class LeNet(nn.Module):
    """LeNet-style CNN mapping single-channel images to 10 class scores.

    The layer sizes imply a (N, 1, 32, 32) input: a 5x5 conv takes 32 -> 28,
    a 2x2 pool takes 28 -> 14, a second 5x5 conv takes 14 -> 10, and a
    second pool takes 10 -> 5. That leaves 16 maps of 5x5, which is what
    ``fc1`` (16 * 5 * 5 inputs) is sized for.
    """

    def __init__(self):
        super().__init__()
        # 1 input channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        # 6 -> 16 channels, 5x5 kernel
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 16 channels x 5x5 spatial map left after the second pooling stage
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw (un-normalized) class scores of shape (N, 10)."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten all but the batch dimension. Dividing nelement() by the
        # batch size would raise ZeroDivisionError for an empty batch.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


model = LeNet().to(device=device)

In [None]:
# Inspect conv1 of the freshly built model before any pruning is applied.
module = model.conv1
print([*module.named_parameters()])

[('weight', Parameter containing:
tensor([[[[-0.0556, -0.1685, -0.0267,  0.0891, -0.0903],
          [ 0.0120, -0.0897,  0.1352,  0.1212,  0.1581],
          [ 0.0368,  0.1208,  0.0078,  0.1517,  0.0689],
          [ 0.1190,  0.0377, -0.1386,  0.0400,  0.0682],
          [ 0.1607, -0.1372, -0.1669,  0.0656,  0.1678]]],


        [[[ 0.0761, -0.0414, -0.1494,  0.0254, -0.0140],
          [-0.1772, -0.0460, -0.0340,  0.1283,  0.1575],
          [-0.0344, -0.0971,  0.0822, -0.0786, -0.1271],
          [ 0.0131,  0.1167, -0.0495, -0.1338,  0.0584],
          [-0.0645,  0.0195,  0.0872,  0.1457, -0.1807]]],


        [[[ 0.0944, -0.1234,  0.0688,  0.1474,  0.1353],
          [-0.1278,  0.1907, -0.1457,  0.0170, -0.1785],
          [-0.1118,  0.0336,  0.1021, -0.0990,  0.0113],
          [ 0.1292, -0.1263, -0.0493, -0.1824,  0.0060],
          [ 0.0383, -0.1724,  0.0344, -0.0721,  0.1184]]],


        [[[-0.0926,  0.1704, -0.0636,  0.1602, -0.0021],
          [ 0.1682,  0.1021,  0.0067,  0.1

In [None]:
print([*module.named_buffers()])  # an unpruned layer has no buffers

[]


In [None]:
# Randomly zero 30% of conv1's weight entries. Afterwards the original tensor
# is kept as `weight_orig`, the mask as the `weight_mask` buffer, and `weight`
# is recomputed by a forward pre-hook.
prune.random_unstructured(module, "weight", amount=0.3)

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

In [None]:
print([*module.named_parameters()])  # `weight` now lives on as `weight_orig`

[('bias', Parameter containing:
tensor([-0.1914, -0.0752, -0.1610, -0.1747,  0.1057,  0.1590],
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.0556, -0.1685, -0.0267,  0.0891, -0.0903],
          [ 0.0120, -0.0897,  0.1352,  0.1212,  0.1581],
          [ 0.0368,  0.1208,  0.0078,  0.1517,  0.0689],
          [ 0.1190,  0.0377, -0.1386,  0.0400,  0.0682],
          [ 0.1607, -0.1372, -0.1669,  0.0656,  0.1678]]],


        [[[ 0.0761, -0.0414, -0.1494,  0.0254, -0.0140],
          [-0.1772, -0.0460, -0.0340,  0.1283,  0.1575],
          [-0.0344, -0.0971,  0.0822, -0.0786, -0.1271],
          [ 0.0131,  0.1167, -0.0495, -0.1338,  0.0584],
          [-0.0645,  0.0195,  0.0872,  0.1457, -0.1807]]],


        [[[ 0.0944, -0.1234,  0.0688,  0.1474,  0.1353],
          [-0.1278,  0.1907, -0.1457,  0.0170, -0.1785],
          [-0.1118,  0.0336,  0.1021, -0.0990,  0.0113],
          [ 0.1292, -0.1263, -0.0493, -0.1824,  0.0060],
          [ 0.0383, -0.1724,  0.

In [None]:
print([*module.named_buffers()])  # the 0/1 pruning mask, stored as `weight_mask`

[('weight_mask', tensor([[[[1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.],
          [1., 0., 1., 0., 0.],
          [0., 0., 1., 1., 0.],
          [1., 0., 1., 1., 1.]]],


        [[[0., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 0., 0., 1., 1.],
          [1., 1., 0., 1., 1.],
          [1., 1., 1., 0., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [0., 1., 1., 0., 0.],
          [0., 0., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [0., 1., 0., 1., 1.]]],


        [[[1., 0., 0., 0., 1.],
          [0., 1., 1., 0., 1.],
          [1., 1., 0., 1., 1.],
          [1., 0., 0., 0., 1.],
          [1., 1., 0., 1., 0.]]],


        [[[1., 1., 1., 0., 1.],
          [0., 0., 1., 0., 1.],
          [0., 1., 1., 0., 1.],
          [0., 1., 1., 1., 1.],
          [1., 1., 0., 1., 0.]]],


        [[[0., 1., 1., 1., 0.],
          [1., 1., 0., 0., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]]]))

In [None]:
# Masked weight: weight_orig with the entries where weight_mask == 0 zeroed.
print(module.weight)

tensor([[[[-0.0556, -0.1685, -0.0267,  0.0891, -0.0903],
          [ 0.0120, -0.0897,  0.1352,  0.0000,  0.1581],
          [ 0.0368,  0.0000,  0.0078,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.1386,  0.0400,  0.0000],
          [ 0.1607, -0.0000, -0.1669,  0.0656,  0.1678]]],


        [[[ 0.0000, -0.0414, -0.1494,  0.0254, -0.0140],
          [-0.1772, -0.0460, -0.0340,  0.1283,  0.1575],
          [-0.0344, -0.0000,  0.0000, -0.0786, -0.1271],
          [ 0.0131,  0.1167, -0.0000, -0.1338,  0.0584],
          [-0.0645,  0.0195,  0.0872,  0.0000, -0.1807]]],


        [[[ 0.0944, -0.1234,  0.0688,  0.1474,  0.1353],
          [-0.0000,  0.1907, -0.1457,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.1021, -0.0990,  0.0000],
          [ 0.1292, -0.1263, -0.0493, -0.1824,  0.0060],
          [ 0.0000, -0.1724,  0.0000, -0.0721,  0.1184]]],


        [[[-0.0926,  0.0000, -0.0000,  0.0000, -0.0021],
          [ 0.0000,  0.1021,  0.0067,  0.0000,  0.1163],
          [ 0.1408,

In [None]:
# Private attribute, read here only for inspection: pruning registered a
# forward pre-hook (the RandomUnstructured method object) for `weight`.
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7dcfebe6ec50>)])


In [None]:
# Zero the 3 bias entries with the smallest absolute value. An int amount is
# an absolute count of entries; a float amount would be a fraction.
prune.l1_unstructured(module, "bias", amount=3)

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

In [None]:
print([*module.named_parameters()])  # both weight and bias are now *_orig parameters

[('weight_orig', Parameter containing:
tensor([[[[-0.0556, -0.1685, -0.0267,  0.0891, -0.0903],
          [ 0.0120, -0.0897,  0.1352,  0.1212,  0.1581],
          [ 0.0368,  0.1208,  0.0078,  0.1517,  0.0689],
          [ 0.1190,  0.0377, -0.1386,  0.0400,  0.0682],
          [ 0.1607, -0.1372, -0.1669,  0.0656,  0.1678]]],


        [[[ 0.0761, -0.0414, -0.1494,  0.0254, -0.0140],
          [-0.1772, -0.0460, -0.0340,  0.1283,  0.1575],
          [-0.0344, -0.0971,  0.0822, -0.0786, -0.1271],
          [ 0.0131,  0.1167, -0.0495, -0.1338,  0.0584],
          [-0.0645,  0.0195,  0.0872,  0.1457, -0.1807]]],


        [[[ 0.0944, -0.1234,  0.0688,  0.1474,  0.1353],
          [-0.1278,  0.1907, -0.1457,  0.0170, -0.1785],
          [-0.1118,  0.0336,  0.1021, -0.0990,  0.0113],
          [ 0.1292, -0.1263, -0.0493, -0.1824,  0.0060],
          [ 0.0383, -0.1724,  0.0344, -0.0721,  0.1184]]],


        [[[-0.0926,  0.1704, -0.0636,  0.1602, -0.0021],
          [ 0.1682,  0.1021,  0.0067,

In [None]:
print([*module.named_buffers()])  # one mask buffer per pruned tensor

[('weight_mask', tensor([[[[1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.],
          [1., 0., 1., 0., 0.],
          [0., 0., 1., 1., 0.],
          [1., 0., 1., 1., 1.]]],


        [[[0., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 0., 0., 1., 1.],
          [1., 1., 0., 1., 1.],
          [1., 1., 1., 0., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [0., 1., 1., 0., 0.],
          [0., 0., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [0., 1., 0., 1., 1.]]],


        [[[1., 0., 0., 0., 1.],
          [0., 1., 1., 0., 1.],
          [1., 1., 0., 1., 1.],
          [1., 0., 0., 0., 1.],
          [1., 1., 0., 1., 0.]]],


        [[[1., 1., 1., 0., 1.],
          [0., 0., 1., 0., 1.],
          [0., 1., 1., 0., 1.],
          [0., 1., 1., 1., 1.],
          [1., 1., 0., 1., 0.]]],


        [[[0., 1., 1., 1., 0.],
          [1., 1., 0., 0., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]]]))

In [None]:
# Masked bias, recomputed as a product (note grad_fn=<MulBackward0>); the
# three smallest-magnitude entries are now 0.
print(module.bias)

tensor([-0.1914, -0.0000, -0.1610, -0.1747,  0.0000,  0.0000],
       grad_fn=<MulBackward0>)


In [None]:
# Two forward pre-hooks now: one for `weight` (random) and one for `bias` (L1).
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7dcfebe6ec50>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7dcfebe6f4c0>)])


In [None]:
# Structured pruning on top of the existing random mask: drop the 50% of
# output channels (dim=0) with the smallest L2 norm (n=2). The new mask is
# combined with the previous one, so entire filters become zero.
prune.ln_structured(module, "weight", amount=0.5, n=2, dim=0)

print(module.weight)

tensor([[[[-0.0556, -0.1685, -0.0267,  0.0891, -0.0903],
          [ 0.0120, -0.0897,  0.1352,  0.0000,  0.1581],
          [ 0.0368,  0.0000,  0.0078,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.1386,  0.0400,  0.0000],
          [ 0.1607, -0.0000, -0.1669,  0.0656,  0.1678]]],


        [[[ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000, -0.0000]]],


        [[[ 0.0944, -0.1234,  0.0688,  0.1474,  0.1353],
          [-0.0000,  0.1907, -0.1457,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.1021, -0.0990,  0.0000],
          [ 0.1292, -0.1263, -0.0493, -0.1824,  0.0060],
          [ 0.0000, -0.1724,  0.0000, -0.0721,  0.1184]]],


        [[[-0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,

In [None]:
# Select the forward pre-hook that recomputes `module.weight`. Since `weight`
# has been pruned more than once, this hook is iterable and yields every
# pruning method applied to it, in order (see the printed list).
hook = next(
    (
        candidate
        for candidate in module._forward_pre_hooks.values()
        if getattr(candidate, "_tensor_name", None) == "weight"
    ),
    None,
)
if hook is None:
    # Fail loudly rather than silently inspecting an unrelated hook.
    raise RuntimeError("no pruning hook registered for 'weight'")

print(list(hook))

[<torch.nn.utils.prune.RandomUnstructured object at 0x7dcfebe6ec50>, <torch.nn.utils.prune.LnStructured object at 0x7dcfebe6dd80>]


In [None]:
# Pruned tensors are saved as <name>_orig plus a <name>_mask buffer; unpruned
# layers keep their plain weight/bias keys.
print(model.state_dict().keys())  # check that the masks are part of the state dict


odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])


In [None]:
print([*module.named_parameters()])  # weight_orig is untouched by pruning

[('weight_orig', Parameter containing:
tensor([[[[-0.0556, -0.1685, -0.0267,  0.0891, -0.0903],
          [ 0.0120, -0.0897,  0.1352,  0.1212,  0.1581],
          [ 0.0368,  0.1208,  0.0078,  0.1517,  0.0689],
          [ 0.1190,  0.0377, -0.1386,  0.0400,  0.0682],
          [ 0.1607, -0.1372, -0.1669,  0.0656,  0.1678]]],


        [[[ 0.0761, -0.0414, -0.1494,  0.0254, -0.0140],
          [-0.1772, -0.0460, -0.0340,  0.1283,  0.1575],
          [-0.0344, -0.0971,  0.0822, -0.0786, -0.1271],
          [ 0.0131,  0.1167, -0.0495, -0.1338,  0.0584],
          [-0.0645,  0.0195,  0.0872,  0.1457, -0.1807]]],


        [[[ 0.0944, -0.1234,  0.0688,  0.1474,  0.1353],
          [-0.1278,  0.1907, -0.1457,  0.0170, -0.1785],
          [-0.1118,  0.0336,  0.1021, -0.0990,  0.0113],
          [ 0.1292, -0.1263, -0.0493, -0.1824,  0.0060],
          [ 0.0383, -0.1724,  0.0344, -0.0721,  0.1184]]],


        [[[-0.0926,  0.1704, -0.0636,  0.1602, -0.0021],
          [ 0.1682,  0.1021,  0.0067,

In [None]:
print([*module.named_buffers()])  # combined random + structured mask for weight

[('weight_mask', tensor([[[[1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.],
          [1., 0., 1., 0., 0.],
          [0., 0., 1., 1., 0.],
          [1., 0., 1., 1., 1.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[1., 1., 1., 1., 1.],
          [0., 1., 1., 0., 0.],
          [0., 0., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [0., 1., 0., 1., 1.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 1., 1., 1., 0.],
          [1., 1., 0., 0., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]]]))

In [None]:
# Masked weight after both pruning passes: whole output channels are zero.
print(module.weight)

tensor([[[[-0.0556, -0.1685, -0.0267,  0.0891, -0.0903],
          [ 0.0120, -0.0897,  0.1352,  0.0000,  0.1581],
          [ 0.0368,  0.0000,  0.0078,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.1386,  0.0400,  0.0000],
          [ 0.1607, -0.0000, -0.1669,  0.0656,  0.1678]]],


        [[[ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000, -0.0000]]],


        [[[ 0.0944, -0.1234,  0.0688,  0.1474,  0.1353],
          [-0.0000,  0.1907, -0.1457,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.1021, -0.0990,  0.0000],
          [ 0.1292, -0.1263, -0.0493, -0.1824,  0.0060],
          [ 0.0000, -0.1724,  0.0000, -0.0721,  0.1184]]],


        [[[-0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,

In [None]:
# Make the weight pruning permanent: weight_orig/weight_mask and the hook go
# away, and `weight` becomes a plain Parameter holding the masked values.
prune.remove(module, 'weight')
print([*module.named_parameters()])

[('bias_orig', Parameter containing:
tensor([-0.1914, -0.0752, -0.1610, -0.1747,  0.1057,  0.1590],
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[-0.0556, -0.1685, -0.0267,  0.0891, -0.0903],
          [ 0.0120, -0.0897,  0.1352,  0.0000,  0.1581],
          [ 0.0368,  0.0000,  0.0078,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.1386,  0.0400,  0.0000],
          [ 0.1607, -0.0000, -0.1669,  0.0656,  0.1678]]],


        [[[ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000, -0.0000]]],


        [[[ 0.0944, -0.1234,  0.0688,  0.1474,  0.1353],
          [-0.0000,  0.1907, -0.1457,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.1021, -0.0990,  0.0000],
          [ 0.1292, -0.1263, -0.0493, -0.1824,  0.0060],
          [ 0.0000, -0.1724,  0.

In [None]:
print([*module.named_buffers()])  # only bias is still reparametrized

[('bias_mask', tensor([1., 0., 1., 1., 0., 0.]))]


In [None]:
# Prune each layer of a fresh LeNet independently with L1 (magnitude)
# unstructured pruning. The loop variable deliberately is not `module`, so the
# earlier inspection cells keep referring to model.conv1.
new_model = LeNet()
for layer in new_model.modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(layer, torch.nn.Conv2d):
        prune.l1_unstructured(layer, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(layer, torch.nn.Linear):
        prune.l1_unstructured(layer, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])


In [None]:
# Fresh, unpruned network for global pruning.
model = LeNet()

# (layer, tensor-name) pairs that are ranked together as one pool.
parameters_to_prune = tuple(
    (getattr(model, layer_name), 'weight')
    for layer_name in ('conv1', 'conv2', 'fc1', 'fc2', 'fc3')
)

# Global L1 pruning: zero the 20% smallest-magnitude weights across all listed
# tensors combined, so per-layer sparsity can differ from 20%.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

In [None]:
def report_sparsity(net, layer_names=("conv1", "conv2", "fc1", "fc2", "fc3")):
    """Print the sparsity of ``<layer>.weight`` for each named layer of *net*.

    Sparsity is the percentage of weight entries that are exactly zero. One
    line is printed per layer, followed by the global sparsity pooled over
    all listed layers (total zeros / total entries).

    Args:
        net: module whose attributes named in *layer_names* have a ``weight``.
        layer_names: attribute names of the layers to report, in print order.

    Returns:
        The global sparsity percentage as a float.
    """
    total_zeros = 0
    total_elements = 0
    for layer_name in layer_names:
        weight = getattr(net, layer_name).weight
        zeros = int(torch.sum(weight == 0))
        elements = weight.nelement()
        total_zeros += zeros
        total_elements += elements
        print(
            "Sparsity in {}.weight: {:.2f}%".format(
                layer_name, 100. * zeros / elements
            )
        )
    global_pct = 100. * total_zeros / total_elements
    print("Global sparsity: {:.2f}%".format(global_pct))
    return global_pct


# Assigned so the notebook does not echo the return value as an extra output.
global_sparsity = report_sparsity(model)

Sparsity in conv1.weight: 4.67%
Sparsity in conv2.weight: 14.21%
Sparsity in fc1.weight: 22.11%
Sparsity in fc2.weight: 12.45%
Sparsity in fc3.weight: 9.40%
Global sparsity: 20.00%
