In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [3]:
class LeNet(nn.Module):
    """LeNet-style convolutional classifier.

    Two 5x5 convolution + ReLU + 2x2 max-pool stages are followed by three
    fully connected layers producing 10 output scores per sample.

    ``fc1`` expects ``16 * 5 * 5`` features per sample, which is what the two
    conv/pool stages yield for a single-channel 32x32 input
    (32 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 channels x 5x5 feature map
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x (torch.Tensor): batch of shape ``(N, 1, H, W)``; ``H`` and ``W``
                must reduce to 5x5 after both conv/pool stages so that the
                flattened features match ``fc1``.

        Returns:
            torch.Tensor: raw (un-normalized) scores of shape ``(N, 10)``.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything except the batch dimension. Unlike dividing
        # nelement() by the batch size, this stays in integer arithmetic and
        # does not divide by zero when the batch is empty.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [4]:
# Instantiate the LeNet network defined above.
model = LeNet()

In [5]:
class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor.

    Entries are counted over the flattened (row-major) tensor; those at even
    flat indices (0, 2, 4, ...) are masked out.
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        """Build the mask that zeroes every other entry.

        Args:
            t (torch.Tensor): tensor being pruned (not inspected here; the
                pattern depends only on position).
            default_mask (torch.Tensor): starting mask; it is left
                untouched, and entries it already zeroes stay zero.

        Returns:
            torch.Tensor: a copy of ``default_mask`` with the entries at even
            flat indices set to 0.
        """
        # Force a contiguous copy: a plain clone() keeps the source strides,
        # and view(-1) raises on a non-contiguous (e.g. transposed) mask.
        mask = default_mask.clone(memory_format=torch.contiguous_format)
        mask.view(-1)[::2] = 0
        return mask

In [6]:
def foobar_unstructured(module, name):
    """Apply `FooBarPruningMethod` to one parameter of `module`.

    Every other entry of the tensor stored under `name` is pruned. The
    module is modified in place and also returned. After the call:

    1) a named buffer `name+'_mask'` holds the binary mask that the
       pruning method applies to the parameter;
    2) `name` refers to the pruned version of the parameter, while the
       original (unpruned) values live in a new parameter called
       `name+'_orig'`.

    Args:
        module (nn.Module): module that owns the tensor to prune
        name (string): name of the parameter inside `module` that the
            pruning acts on

    Returns:
        module (nn.Module): the same module object, now pruned

    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
        Linear(in_features=3, out_features=4, bias=True)
    """
    # apply() mutates `module` itself; hand the same object back to the caller.
    FooBarPruningMethod.apply(module, name)
    return module

In [9]:
# Rebuild the model so the demo starts from a freshly constructed network.
model = LeNet()
# Prune every other entry of conv1's `weight`; conv1 is modified in place.
foobar_unstructured(model.conv1, name='weight')

# Inspect the binary mask buffer that pruning attached to conv1.
print(model.conv1.weight_mask)

tensor([[[[0., 1., 0., 1., 0.],
          [1., 0., 1., 0., 1.],
          [0., 1., 0., 1., 0.],
          [1., 0., 1., 0., 1.],
          [0., 1., 0., 1., 0.]]],


        [[[1., 0., 1., 0., 1.],
          [0., 1., 0., 1., 0.],
          [1., 0., 1., 0., 1.],
          [0., 1., 0., 1., 0.],
          [1., 0., 1., 0., 1.]]],


        [[[0., 1., 0., 1., 0.],
          [1., 0., 1., 0., 1.],
          [0., 1., 0., 1., 0.],
          [1., 0., 1., 0., 1.],
          [0., 1., 0., 1., 0.]]],


        [[[1., 0., 1., 0., 1.],
          [0., 1., 0., 1., 0.],
          [1., 0., 1., 0., 1.],
          [0., 1., 0., 1., 0.],
          [1., 0., 1., 0., 1.]]],


        [[[0., 1., 0., 1., 0.],
          [1., 0., 1., 0., 1.],
          [0., 1., 0., 1., 0.],
          [1., 0., 1., 0., 1.],
          [0., 1., 0., 1., 0.]]],


        [[[1., 0., 1., 0., 1.],
          [0., 1., 0., 1., 0.],
          [1., 0., 1., 0., 1.],
          [0., 1., 0., 1., 0.],
          [1., 0., 1., 0., 1.]]]])
