In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [2]:
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet().to(device=device)

If you have parameters in your model, which should be saved and restored in the state_dict(), but not trained by the optimizer, you should register them as buffers.

In [4]:
module = model.conv1
print(list(module.named_parameters()))
print(list(module.named_buffers()))

[('weight', Parameter containing:
tensor([[[[ 0.0579, -0.0224,  0.2381],
          [ 0.1789,  0.2221,  0.1635],
          [ 0.3322,  0.3205, -0.2696]]],


        [[[-0.0981,  0.1579,  0.2701],
          [ 0.0863,  0.1182,  0.2472],
          [ 0.1960,  0.0325,  0.2782]]],


        [[[-0.0234,  0.0636, -0.2284],
          [-0.3239, -0.0875,  0.1926],
          [ 0.2080,  0.1421, -0.0819]]],


        [[[-0.1942,  0.2092, -0.1111],
          [ 0.1764, -0.2321, -0.1233],
          [ 0.0385,  0.0964,  0.2867]]],


        [[[ 0.3141,  0.2319, -0.3064],
          [-0.1830,  0.1184,  0.2310],
          [-0.2528, -0.2914,  0.0233]]],


        [[[-0.0335, -0.1931, -0.1619],
          [ 0.2001, -0.2876,  0.1103],
          [ 0.0338, -0.2547,  0.3255]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.0223,  0.1173,  0.0912, -0.0221,  0.2902,  0.0723], device='cuda:0',
       requires_grad=True))]
[]


random_structured() Prunes tensor corresponding to parameter called name in module by removing the specified amount of (currently unpruned) units selected at random. Modifies module in place (and also return the modified module) by:
* 1. adding a named buffer called "name + _mask" corresponding to the binary mask applied to the parameter "name" by the pruning method.
* 2. replacing the parammeter "name" by its pruned version, while the original parameter is stored in a new parameter named "name + _orig".<br>

Parameter:
* module - module containing the tensor to prune.
* name(str) - parameter name within module on which pruning will act.
* amount(int or float) - quantity of parameters to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune.

In [5]:
prune.random_unstructured(module, name="weight", amount=0.3)
print(list(module.named_parameters()))


[('bias', Parameter containing:
tensor([-0.0223,  0.1173,  0.0912, -0.0221,  0.2902,  0.0723], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.0579, -0.0224,  0.2381],
          [ 0.1789,  0.2221,  0.1635],
          [ 0.3322,  0.3205, -0.2696]]],


        [[[-0.0981,  0.1579,  0.2701],
          [ 0.0863,  0.1182,  0.2472],
          [ 0.1960,  0.0325,  0.2782]]],


        [[[-0.0234,  0.0636, -0.2284],
          [-0.3239, -0.0875,  0.1926],
          [ 0.2080,  0.1421, -0.0819]]],


        [[[-0.1942,  0.2092, -0.1111],
          [ 0.1764, -0.2321, -0.1233],
          [ 0.0385,  0.0964,  0.2867]]],


        [[[ 0.3141,  0.2319, -0.3064],
          [-0.1830,  0.1184,  0.2310],
          [-0.2528, -0.2914,  0.0233]]],


        [[[-0.0335, -0.1931, -0.1619],
          [ 0.2001, -0.2876,  0.1103],
          [ 0.0338, -0.2547,  0.3255]]]], device='cuda:0', requires_grad=True))]


In [6]:
print(list(module.named_buffers()))
# 0 is the pruned position.

[('weight_mask', tensor([[[[1., 1., 0.],
          [1., 0., 1.],
          [0., 1., 1.]]],


        [[[0., 1., 0.],
          [1., 1., 0.],
          [0., 1., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 0.],
          [0., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 1.]]],


        [[[0., 1., 1.],
          [1., 1., 0.],
          [0., 1., 0.]]]], device='cuda:0'))]


In [7]:
module.weight

tensor([[[[ 0.0579, -0.0224,  0.0000],
          [ 0.1789,  0.0000,  0.1635],
          [ 0.0000,  0.3205, -0.2696]]],


        [[[-0.0000,  0.1579,  0.0000],
          [ 0.0863,  0.1182,  0.0000],
          [ 0.0000,  0.0325,  0.2782]]],


        [[[-0.0234,  0.0000, -0.2284],
          [-0.3239, -0.0875,  0.1926],
          [ 0.2080,  0.1421, -0.0819]]],


        [[[-0.1942,  0.0000, -0.1111],
          [ 0.1764, -0.2321, -0.0000],
          [ 0.0000,  0.0964,  0.2867]]],


        [[[ 0.3141,  0.2319, -0.3064],
          [-0.1830,  0.1184,  0.0000],
          [-0.2528, -0.2914,  0.0233]]],


        [[[-0.0000, -0.1931, -0.1619],
          [ 0.2001, -0.2876,  0.0000],
          [ 0.0000, -0.2547,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

In [8]:
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fdb42caa650>)])


In [9]:
# prune the bias
prune.l1_unstructured(module, name="bias", amount=3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

In [10]:
print(list(module.named_parameters()))

[('weight_orig', Parameter containing:
tensor([[[[ 0.0579, -0.0224,  0.2381],
          [ 0.1789,  0.2221,  0.1635],
          [ 0.3322,  0.3205, -0.2696]]],


        [[[-0.0981,  0.1579,  0.2701],
          [ 0.0863,  0.1182,  0.2472],
          [ 0.1960,  0.0325,  0.2782]]],


        [[[-0.0234,  0.0636, -0.2284],
          [-0.3239, -0.0875,  0.1926],
          [ 0.2080,  0.1421, -0.0819]]],


        [[[-0.1942,  0.2092, -0.1111],
          [ 0.1764, -0.2321, -0.1233],
          [ 0.0385,  0.0964,  0.2867]]],


        [[[ 0.3141,  0.2319, -0.3064],
          [-0.1830,  0.1184,  0.2310],
          [-0.2528, -0.2914,  0.0233]]],


        [[[-0.0335, -0.1931, -0.1619],
          [ 0.2001, -0.2876,  0.1103],
          [ 0.0338, -0.2547,  0.3255]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.0223,  0.1173,  0.0912, -0.0221,  0.2902,  0.0723], device='cuda:0',
       requires_grad=True))]


In [11]:
prune.ln_structured(module, name="weight", amount=0.3, n=2, dim=0)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

The corresponding hook will now be of type torch.nn.utils.prune.PruningContainer, and will store the history of pruning applied to the weight parameter.

In [12]:
for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":
        break
print(list(hook))

[<torch.nn.utils.prune.RandomUnstructured object at 0x7fdb42caa650>, <torch.nn.utils.prune.LnStructured object at 0x7fdb42cab310>]


In [13]:
# save
print(model.state_dict().keys())

odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])


remove "weight_mask" in named_buffers().

In [16]:
prune.remove(module, "weight")
print(list(module.named_parameters()))
print(list(module.named_buffers()))

ValueError: Parameter 'weight' of module Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1)) has to be pruned before pruning can be removed

In [17]:
new_model = LeNet()
for name, module in new_model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name="weight", amount=0.2)

    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=0.4)


print(dict(new_model.named_buffers()).keys())

dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])


In [24]:
model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

tensor.nelement() : renturn the number of elements in tensor.

In [25]:
print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)

Sparsity in conv1.weight: 5.56%
Sparsity in conv2.weight: 7.06%
Sparsity in fc1.weight: 22.15%
Sparsity in fc2.weight: 11.88%
Sparsity in fc3.weight: 8.93%
Global sparsity: 20.00%


## Custom pruning functions

In [28]:
class FooBarPruningMethod(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask
    
def foobar_unstructured(module,name):
    FooBarPruningMethod.apply(module, name)
    return module

In [26]:
def compute_mask(default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

compute_mask(model.conv1.weight_mask)


tensor([[[[0., 1., 0.],
          [1., 0., 1.],
          [0., 1., 0.]]],


        [[[1., 0., 1.],
          [0., 1., 0.],
          [1., 0., 1.]]],


        [[[0., 1., 0.],
          [1., 0., 1.],
          [0., 1., 0.]]],


        [[[1., 0., 1.],
          [0., 1., 0.],
          [1., 0., 1.]]],


        [[[0., 1., 0.],
          [1., 0., 0.],
          [0., 1., 0.]]],


        [[[1., 0., 1.],
          [0., 1., 0.],
          [1., 0., 1.]]]])

In the source code of apply() method, 
```
if not isinstance(method, PruningContainer):
    # copy `module[name]` to `module[name + '_orig']`
    module.register_parameter(name + "_orig", orig)
    # temporarily delete `module[name]`
    del module._parameters[name]
    default_mask = torch.ones_like(orig)  # temp
```

In [29]:
model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)


tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
