In [9]:
import torch
from torch import nn
import torch.nn.functional as F


class LeNet5(nn.Module):
    """LeNet-5-style CNN: two conv -> ReLU -> 2x2 max-pool stages, then three FC layers.

    ``conv1`` takes 3 input channels. ``fc1`` expects ``16 * 4 * 4`` flattened
    features. That is what the two (5x5 conv, 2x2 pool) stages produce from a
    28x28 input: 28 -> 24 -> 12 -> 8 -> 4. Inputs whose spatial size does not
    reduce to 4x4 (e.g. 32x32) will fail at ``fc1``.
    NOTE(review): 3 channels at 28x28 is an unusual combination
    (MNIST is 1x28x28, CIFAR-10 is 3x32x32); confirm the intended dataset.

    ``forward`` returns the raw ``fc3`` outputs (10 per sample, no softmax).
    """

    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten every dimension except the leading batch dimension.
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        """Return the number of elements per sample, i.e. the product of ``x.shape[1:]``.

        Uses only the shape, so it also works for an empty batch (``x.shape[0] == 0``).
        """
        return x.size()[1:].numel()

In [10]:
# Seed torch's RNG (all devices) so weight initialisation and the random
# pruning in later cells are reproducible under Restart & Run All.
torch.manual_seed(0)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = LeNet5().to(device)

# Show conv1's (name, Parameter) pairs before any pruning is applied.
print(list(model.conv1.named_parameters()))

[('weight', Parameter containing:
tensor([[[[-0.0711, -0.0159,  0.0370,  0.0129, -0.1070],
          [-0.0443,  0.0649, -0.1031,  0.0578,  0.0238],
          [ 0.0742,  0.1049,  0.0835,  0.0871, -0.1083],
          [ 0.1130, -0.0133,  0.0526, -0.0261, -0.0450],
          [-0.0103,  0.0812, -0.0013, -0.0191, -0.1074]],

         [[-0.0725,  0.0426,  0.0019, -0.0095,  0.0440],
          [-0.0331,  0.0488, -0.0885,  0.0363,  0.0209],
          [-0.1130, -0.0900, -0.0899,  0.1149,  0.0595],
          [-0.0535,  0.1001, -0.1148,  0.0096,  0.0483],
          [ 0.0608, -0.1080,  0.0387, -0.0534, -0.0288]],

         [[-0.0711,  0.1122, -0.0558,  0.0072, -0.0142],
          [ 0.0182, -0.0048, -0.1034, -0.1003, -0.0324],
          [-0.0595,  0.0746,  0.0849, -0.0298, -0.0252],
          [-0.0317,  0.0905,  0.1022, -0.0477,  0.0275],
          [-0.0710,  0.0721,  0.0969, -0.0644,  0.0946]]],


        [[[-0.0570, -0.0887,  0.1047,  0.0757,  0.0887],
          [ 0.0169,  0.0631,  0.0365, -0.0831,

# Local/Global Pruning

In [11]:
import torch.nn.utils.prune as prune

# Randomly zero 30% of conv1's weight entries, then (independently) 30% of its bias entries.
for tensor_name in ("weight", "bias"):
    prune.random_unstructured(model.conv1, name=tensor_name, amount=0.3)

model.conv1

Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))

In [12]:
model = LeNet5().to(device)

# Layer type -> fraction of that layer's weight entries to zero out at random.
layer_prune_amounts = (
    (torch.nn.Conv2d, 0.3),
    (torch.nn.Linear, 0.5),
)

for module in model.modules():
    for layer_type, amount in layer_prune_amounts:
        if isinstance(module, layer_type):
            prune.random_unstructured(module, name="weight", amount=amount)

In [13]:
model = LeNet5().to(device)

# Pool the weight tensors of all five layers so the 20% smallest-|w| cut-off
# is chosen across the whole network rather than layer by layer.
parameters_to_prune = tuple(
    (getattr(model, layer_name), "weight")
    for layer_name in ("conv1", "conv2", "fc1", "fc2", "fc3")
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

# Custom Pruning

In [15]:
class MyPruningMethod(prune.BasePruningMethod):
    """Custom pruning method that zeroes every other entry in flattened order.

    Entries at flat indices 0, 2, 4, ... are set to 0 in the new mask.
    Entries already 0 in ``default_mask`` stay 0, and ``default_mask``
    itself is not modified.
    """

    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        new_mask = default_mask.clone()
        flat_view = new_mask.view(-1)  # shares storage with new_mask
        flat_view[0::2] = 0
        return new_mask


def my_unstructured(module, name):
    """Prune the tensor ``module.<name>`` in place using ``MyPruningMethod``.

    Args:
        module: the module that owns the tensor to prune.
        name: attribute name of that tensor on ``module``, e.g. ``"weight"``.

    Returns:
        The same ``module`` object, so the call can be assigned back or chained.
    """
    MyPruningMethod.apply(module, name=name)
    return module

In [16]:
# Fresh model; prune conv1.weight with the custom every-other-entry method.
model = LeNet5().to(device)
conv1_layer = model.conv1
model.conv1 = my_unstructured(conv1_layer, name="weight")