In [19]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from model import LeNet

In [20]:
# Fix the global RNG seed so that the model's random initialisation and the
# random pruning masks below are identical on every Restart & Run All.
# torch.manual_seed seeds the CPU and CUDA generators alike.
# NOTE: outputs recorded before this seed was added will differ until re-run.
SEED = 0
torch.manual_seed(SEED)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [21]:
# Instantiate LeNet and place it on the device selected above.
model = LeNet().to(device)

In [22]:
# Work on a single layer: the model's `conv1` submodule is the pruning target
# for the cells below.
module = model.conv1

In [23]:
# Snapshot of the layer's parameters before any pruning is applied.
print(f"named_parameters_before: {list(module.named_parameters())}")

named_parameters_before: [('weight', Parameter containing:
tensor([[[[ 0.0284,  0.2955, -0.1782],
          [ 0.0582,  0.2633,  0.1184],
          [ 0.0152, -0.1134,  0.1717]]],


        [[[-0.2635,  0.0592,  0.2944],
          [-0.0558, -0.2884, -0.1755],
          [ 0.2850,  0.1975, -0.0298]]],


        [[[ 0.0245, -0.0228, -0.0329],
          [-0.2639, -0.0598,  0.0321],
          [ 0.1138,  0.1018,  0.1236]]],


        [[[ 0.1103,  0.0776,  0.1747],
          [ 0.1857,  0.2623, -0.1519],
          [ 0.2165,  0.0449,  0.1655]]],


        [[[-0.0517,  0.1594, -0.0714],
          [-0.3314, -0.1445, -0.3316],
          [-0.2006, -0.3020, -0.0486]]],


        [[[-0.1472,  0.2213,  0.1104],
          [-0.1598, -0.2827,  0.3098],
          [-0.2300,  0.2098, -0.3126]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.0197,  0.0293, -0.1146,  0.0632, -0.2255, -0.2628], device='cuda:0',
       requires_grad=True))]


In [24]:
# Snapshot of the layer's buffers before any pruning is applied.
print(f"named_buffers_before: {list(module.named_buffers())}")

named_buffers_before: []


In [25]:
# Randomly prune 30% (amount=0.3) of the individual entries of the layer's
# `weight`. The module is modified in place; the call also returns the module,
# which is what this cell displays.
prune.random_unstructured(module, name="weight", amount=0.3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

In [26]:
# Parameters after pruning `weight`.
print(f"named_parameters_after: {list(module.named_parameters())}")

named_parameters_after: [('bias', Parameter containing:
tensor([-0.0197,  0.0293, -0.1146,  0.0632, -0.2255, -0.2628], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.0284,  0.2955, -0.1782],
          [ 0.0582,  0.2633,  0.1184],
          [ 0.0152, -0.1134,  0.1717]]],


        [[[-0.2635,  0.0592,  0.2944],
          [-0.0558, -0.2884, -0.1755],
          [ 0.2850,  0.1975, -0.0298]]],


        [[[ 0.0245, -0.0228, -0.0329],
          [-0.2639, -0.0598,  0.0321],
          [ 0.1138,  0.1018,  0.1236]]],


        [[[ 0.1103,  0.0776,  0.1747],
          [ 0.1857,  0.2623, -0.1519],
          [ 0.2165,  0.0449,  0.1655]]],


        [[[-0.0517,  0.1594, -0.0714],
          [-0.3314, -0.1445, -0.3316],
          [-0.2006, -0.3020, -0.0486]]],


        [[[-0.1472,  0.2213,  0.1104],
          [-0.1598, -0.2827,  0.3098],
          [-0.2300,  0.2098, -0.3126]]]], device='cuda:0', requires_grad=True))]


In [27]:
# Buffers after pruning `weight`.
print(f"named_buffers_after: {list(module.named_buffers())}")

named_buffers_after: [('weight_mask', tensor([[[[0., 1., 1.],
          [1., 1., 1.],
          [1., 0., 1.]]],


        [[[1., 1., 1.],
          [0., 1., 0.],
          [1., 1., 0.]]],


        [[[0., 1., 0.],
          [0., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 1.]]]], device='cuda:0'))]


In [28]:
# After pruning, `module.weight` is the masked tensor (weight_orig * weight_mask)
# rather than a Parameter -- note the grad_fn=<MulBackward0> in the output.
print(module.weight)

tensor([[[[ 0.0000,  0.2955, -0.1782],
          [ 0.0582,  0.2633,  0.1184],
          [ 0.0152, -0.0000,  0.1717]]],


        [[[-0.2635,  0.0592,  0.2944],
          [-0.0000, -0.2884, -0.0000],
          [ 0.2850,  0.1975, -0.0000]]],


        [[[ 0.0000, -0.0228, -0.0000],
          [-0.0000, -0.0000,  0.0321],
          [ 0.1138,  0.1018,  0.1236]]],


        [[[ 0.1103,  0.0776,  0.1747],
          [ 0.1857,  0.2623, -0.1519],
          [ 0.2165,  0.0449,  0.0000]]],


        [[[-0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000, -0.3316],
          [-0.2006, -0.3020, -0.0486]]],


        [[[-0.1472,  0.2213,  0.1104],
          [-0.1598, -0.2827,  0.0000],
          [-0.2300,  0.2098, -0.3126]]]], device='cuda:0',
       grad_fn=<MulBackward0>)


经过剪枝后，named_parameters 中不再有 weight，取而代之的是 weight_orig（其值与剪枝前的 weight 相同），bias 保持不变；buffers 中新增了 weight_mask，即对 weight 剪枝时使用的掩码张量。
换言之，weight = weight_orig * weight_mask；此时 module.weight 是由二者相乘得到的普通属性，而不再是 Parameter。

In [29]:
# Apply the same random unstructured pruning to the layer's `bias`
# (amount=0.3). The call returns the module, which this cell displays.
prune.random_unstructured(module, name="bias", amount=0.3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

prune 提供了多种剪枝方式，例如 l1_unstructured 和 random_unstructured：前者按 L1 范数（绝对值大小）剪掉数值最小的权重，后者则随机选择要剪掉的权重。

In [30]:
# Parameters once both `weight` and `bias` have been pruned.
print([*module.named_parameters()])

[('weight_orig', Parameter containing:
tensor([[[[ 0.0284,  0.2955, -0.1782],
          [ 0.0582,  0.2633,  0.1184],
          [ 0.0152, -0.1134,  0.1717]]],


        [[[-0.2635,  0.0592,  0.2944],
          [-0.0558, -0.2884, -0.1755],
          [ 0.2850,  0.1975, -0.0298]]],


        [[[ 0.0245, -0.0228, -0.0329],
          [-0.2639, -0.0598,  0.0321],
          [ 0.1138,  0.1018,  0.1236]]],


        [[[ 0.1103,  0.0776,  0.1747],
          [ 0.1857,  0.2623, -0.1519],
          [ 0.2165,  0.0449,  0.1655]]],


        [[[-0.0517,  0.1594, -0.0714],
          [-0.3314, -0.1445, -0.3316],
          [-0.2006, -0.3020, -0.0486]]],


        [[[-0.1472,  0.2213,  0.1104],
          [-0.1598, -0.2827,  0.3098],
          [-0.2300,  0.2098, -0.3126]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.0197,  0.0293, -0.1146,  0.0632, -0.2255, -0.2628], device='cuda:0',
       requires_grad=True))]


In [31]:
# Buffers once both `weight` and `bias` have been pruned.
print([*module.named_buffers()])

[('weight_mask', tensor([[[[0., 1., 1.],
          [1., 1., 1.],
          [1., 0., 1.]]],


        [[[1., 1., 1.],
          [0., 1., 0.],
          [1., 1., 0.]]],


        [[[0., 1., 0.],
          [0., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 1.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 1., 1.], device='cuda:0'))]


In [32]:
# Like `weight`, `module.bias` is now the masked tensor (bias_orig * bias_mask);
# the output carries grad_fn=<MulBackward0> instead of being a Parameter.
print(module.bias)

tensor([-0.0000,  0.0000, -0.1146,  0.0632, -0.2255, -0.2628], device='cuda:0',
       grad_fn=<MulBackward0>)


In [33]:
# Inspect the forward pre-hooks registered on the module: the output lists two
# RandomUnstructured objects, matching the two pruning calls above
# (presumably one for `weight` and one for `bias`).
print(module._forward_pre_hooks)

OrderedDict([(2, <torch.nn.utils.prune.RandomUnstructured object at 0x7f86bf13da10>), (3, <torch.nn.utils.prune.RandomUnstructured object at 0x7f86b5547d50>)])
