In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

In [12]:
module = model.conv1
print(model)
print(module)
print(list(module.named_parameters()))

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
[('bias', Parameter containing:
tensor([ 0.2740,  0.0881, -0.3035, -0.0603, -0.1819,  0.0966], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.2874,  0.2522,  0.0900],
          [-0.0997, -0.0479,  0.2821],
          [-0.1305, -0.1103, -0.3043]]],


        [[[ 0.1046,  0.0790, -0.1596],
          [-0.1487,  0.0541,  0.2125],
          [ 0.0021, -0.1950, -0.2188]]],


        [[[-0.0556, -0.3216, -0.2750],
          [-0.2837,  0.1171, -0.1186],
          [ 0.0486,  0.0043,  0.0214]]],


        [[[-0.1516, -0.1821, -0.1250],
          [ 0.0812, -0.0642,  0.2846],
          [-0.1217,  

In [16]:
print(list(module.named_buffers()))


[('weight_mask', tensor([[[[0., 1., 0.],
          [0., 0., 0.],
          [1., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [1., 0., 0.]]],


        [[[1., 1., 0.],
          [1., 1., 1.],
          [0., 1., 1.]]],


        [[[1., 0., 1.],
          [0., 1., 1.],
          [0., 1., 0.]]],


        [[[0., 0., 1.],
          [0., 1., 1.],
          [1., 1., 0.]]],


        [[[1., 1., 1.],
          [0., 0., 1.],
          [1., 1., 1.]]]], device='cuda:0'))]


In [14]:
prune.random_unstructured(module, name="weight", amount=0.3)


Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

In [None]:
print(list(module.named_parameters()))


In [17]:
print(module.weight)


tensor([[[[-0.0000,  0.2522,  0.0000],
          [-0.0000, -0.0000,  0.0000],
          [-0.1305, -0.0000, -0.0000]]],


        [[[ 0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0021, -0.0000, -0.0000]]],


        [[[-0.0556, -0.3216, -0.0000],
          [-0.2837,  0.1171, -0.1186],
          [ 0.0000,  0.0043,  0.0214]]],


        [[[-0.1516, -0.0000, -0.1250],
          [ 0.0000, -0.0642,  0.2846],
          [-0.0000,  0.1433, -0.0000]]],


        [[[-0.0000, -0.0000,  0.2462],
          [-0.0000, -0.2874,  0.1703],
          [-0.2923, -0.2939,  0.0000]]],


        [[[-0.2007,  0.0580, -0.0412],
          [ 0.0000, -0.0000,  0.1449],
          [-0.2900, -0.1460, -0.0546]]]], device='cuda:0',
       grad_fn=<MulBackward0>)


In [18]:
print(module._forward_pre_hooks)


OrderedDict([(1, <torch.nn.utils.prune.PruningContainer object at 0x7fd3591ba080>)])
