In [19]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from model import LeNet

In [20]:
# Seed torch's RNGs (CPU and every CUDA device) so that the LeNet weight
# initialisation and the masks drawn by prune.random_unstructured are
# reproducible when the notebook is re-run from a fresh kernel.
SEED = 42
torch.manual_seed(SEED)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [21]:
# Build the network and move its parameters/buffers onto the selected device.
model = LeNet().to(device)

In [22]:
# Pick the first convolution layer of the model as the pruning target.
module = getattr(model, "conv1")

In [23]:
# Parameters of the layer before any pruning is applied.
print("named_parameters_before:", [*module.named_parameters()])

named_parameters_before: [('weight', Parameter containing:
tensor([[[[ 0.2344, -0.1534, -0.1717],
          [ 0.1423,  0.1479, -0.1161],
          [ 0.1598, -0.0189,  0.0788]]],


        [[[-0.1830, -0.1412, -0.1916],
          [ 0.1269,  0.0265, -0.2616],
          [ 0.0813,  0.0653, -0.0322]]],


        [[[ 0.0394, -0.2268,  0.3219],
          [ 0.1310, -0.1745, -0.2131],
          [ 0.0869, -0.2119, -0.1545]]],


        [[[-0.1306, -0.0579,  0.1979],
          [-0.0960,  0.0246,  0.2391],
          [-0.0072, -0.0899, -0.0895]]],


        [[[ 0.1965, -0.2621,  0.1457],
          [ 0.1589, -0.0594,  0.0040],
          [ 0.0165, -0.2806,  0.0036]]],


        [[[ 0.2581, -0.0061,  0.1070],
          [-0.0425,  0.1705,  0.1227],
          [-0.2775, -0.1106, -0.1397]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.2617,  0.0278, -0.2649, -0.1894, -0.2870,  0.1317], device='cuda:0',
       requires_grad=True))]


In [24]:
# Buffers of the layer before pruning (expected to be empty for a plain Conv2d).
print("named_buffers_before:", [*module.named_buffers()])

named_buffers_before: []


In [25]:
# Randomly zero out 30% of the entries of conv1.weight (positional: module, name, amount).
prune.random_unstructured(module, "weight", 0.3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

In [26]:
# After pruning, "weight" is replaced by "weight_orig" in the parameter list.
print("named_parameters_after:", [*module.named_parameters()])

named_parameters_after: [('bias', Parameter containing:
tensor([ 0.2617,  0.0278, -0.2649, -0.1894, -0.2870,  0.1317], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.2344, -0.1534, -0.1717],
          [ 0.1423,  0.1479, -0.1161],
          [ 0.1598, -0.0189,  0.0788]]],


        [[[-0.1830, -0.1412, -0.1916],
          [ 0.1269,  0.0265, -0.2616],
          [ 0.0813,  0.0653, -0.0322]]],


        [[[ 0.0394, -0.2268,  0.3219],
          [ 0.1310, -0.1745, -0.2131],
          [ 0.0869, -0.2119, -0.1545]]],


        [[[-0.1306, -0.0579,  0.1979],
          [-0.0960,  0.0246,  0.2391],
          [-0.0072, -0.0899, -0.0895]]],


        [[[ 0.1965, -0.2621,  0.1457],
          [ 0.1589, -0.0594,  0.0040],
          [ 0.0165, -0.2806,  0.0036]]],


        [[[ 0.2581, -0.0061,  0.1070],
          [-0.0425,  0.1705,  0.1227],
          [-0.2775, -0.1106, -0.1397]]]], device='cuda:0', requires_grad=True))]


In [27]:
# After pruning, the binary mask is registered as the buffer "weight_mask".
print("named_buffers_after:", [*module.named_buffers()])

named_buffers_after: [('weight_mask', tensor([[[[1., 1., 0.],
          [0., 1., 1.],
          [0., 0., 1.]]],


        [[[1., 0., 0.],
          [1., 1., 1.],
          [1., 0., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [1., 0., 0.],
          [0., 0., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 1.],
          [1., 1., 0.]]]], device='cuda:0'))]


In [28]:
# After pruning, `weight` is no longer a Parameter: it is the product of the
# untouched original values and the binary mask. Check that invariant before
# displaying the pruned tensor, so the explanation below is backed by code.
assert torch.equal(module.weight, module.weight_orig * module.weight_mask)
print(module.weight)

tensor([[[[ 0.2344, -0.1534, -0.0000],
          [ 0.0000,  0.1479, -0.1161],
          [ 0.0000, -0.0000,  0.0788]]],


        [[[-0.1830, -0.0000, -0.0000],
          [ 0.1269,  0.0265, -0.2616],
          [ 0.0813,  0.0000, -0.0322]]],


        [[[ 0.0394, -0.0000,  0.3219],
          [ 0.1310, -0.1745, -0.2131],
          [ 0.0869, -0.2119, -0.1545]]],


        [[[-0.1306, -0.0579,  0.0000],
          [-0.0960,  0.0246,  0.2391],
          [-0.0072, -0.0899, -0.0895]]],


        [[[ 0.1965, -0.2621,  0.0000],
          [ 0.1589, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0036]]],


        [[[ 0.2581, -0.0000,  0.1070],
          [-0.0425,  0.1705,  0.1227],
          [-0.2775, -0.1106, -0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)


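经过剪枝后，named_parameters 中原来的 weight 被 weight_orig 取代（其值与剪枝前的 weight 完全相同），buffer 中新增了 weight_mask，即对 weight 剪枝时使用的掩码张量。
换言之，weight = weight_orig * weight_mask（此时 weight 不再是 Parameter，而是由二者相乘得到的普通张量，见上方输出中的 grad_fn=<MulBackward0>）。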

In [29]:
# Randomly zero out 30% of the entries of conv1.bias (positional: module, name, amount).
prune.random_unstructured(module, "bias", 0.3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

prune 提供多种剪枝方法，例如 l1_unstructured 和 random_unstructured：前者按 L1 范数（即元素绝对值大小）剪除数值最小的元素，后者则随机选择要剪除的元素。

In [30]:
# Both weight and bias are now stored as *_orig parameters.
print([*module.named_parameters()])

[('weight_orig', Parameter containing:
tensor([[[[ 0.2344, -0.1534, -0.1717],
          [ 0.1423,  0.1479, -0.1161],
          [ 0.1598, -0.0189,  0.0788]]],


        [[[-0.1830, -0.1412, -0.1916],
          [ 0.1269,  0.0265, -0.2616],
          [ 0.0813,  0.0653, -0.0322]]],


        [[[ 0.0394, -0.2268,  0.3219],
          [ 0.1310, -0.1745, -0.2131],
          [ 0.0869, -0.2119, -0.1545]]],


        [[[-0.1306, -0.0579,  0.1979],
          [-0.0960,  0.0246,  0.2391],
          [-0.0072, -0.0899, -0.0895]]],


        [[[ 0.1965, -0.2621,  0.1457],
          [ 0.1589, -0.0594,  0.0040],
          [ 0.0165, -0.2806,  0.0036]]],


        [[[ 0.2581, -0.0061,  0.1070],
          [-0.0425,  0.1705,  0.1227],
          [-0.2775, -0.1106, -0.1397]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([ 0.2617,  0.0278, -0.2649, -0.1894, -0.2870,  0.1317], device='cuda:0',
       requires_grad=True))]


In [31]:
# Both weight_mask and bias_mask are now registered as buffers.
print([*module.named_buffers()])

[('weight_mask', tensor([[[[1., 1., 0.],
          [0., 1., 1.],
          [0., 0., 1.]]],


        [[[1., 0., 0.],
          [1., 1., 1.],
          [1., 0., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [1., 0., 0.],
          [0., 0., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 1.],
          [1., 1., 0.]]]], device='cuda:0')), ('bias_mask', tensor([1., 1., 1., 0., 1., 0.], device='cuda:0'))]


In [32]:
# Like `weight`, the pruned `bias` is recomputed as original values times the
# binary mask; assert that before displaying it.
assert torch.equal(module.bias, module.bias_orig * module.bias_mask)
print(module.bias)

tensor([ 0.2617,  0.0278, -0.2649, -0.0000, -0.2870,  0.0000], device='cuda:0',
       grad_fn=<MulBackward0>)


In [33]:
# Each pruning call registered one forward pre-hook on the layer; list them.
print(getattr(module, "_forward_pre_hooks"))

OrderedDict([(2, <torch.nn.utils.prune.RandomUnstructured object at 0x7f9f18c0c9d0>), (3, <torch.nn.utils.prune.RandomUnstructured object at 0x7f9f18c10dd0>)])
