In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [2]:
# Fix the global PyTorch RNG seed so that the random weight initialisation and
# the random pruning masks used later in this notebook are identical on every
# "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    """LeNet-style CNN: two conv + max-pool stages followed by three linear layers.

    ``fc1`` consumes ``16 * 5 * 5`` features, so a single-channel input must
    shrink to a 5x5 feature map after the two conv/pool stages
    (e.g. 32x32 images: 32 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 channels of a 5x5 feature map
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map a batch of single-channel images to 10 raw output scores per sample.

        Args:
            x: batched input tensor whose first dimension is the batch size.

        Returns:
            Tensor of shape ``(batch, 10)``; no softmax is applied.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Collapse every non-batch dimension. Unlike dividing nelement() by the
        # batch size, this is also well defined for an empty batch.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Build the network and place its parameters on the selected device.
model = LeNet().to(device)

In [3]:
# Bare expression: the notebook renders the module's repr, i.e. the list of
# layers registered on the LeNet instance.
model

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [4]:
# Work on the first convolution layer and list its (name, parameter) pairs
# as they are before any pruning is applied.
conv = model.conv1
print([*conv.named_parameters()])

[('weight', Parameter containing:
tensor([[[[-0.1953, -0.0010, -0.1085,  0.1715,  0.1367],
          [ 0.0376, -0.1104,  0.0248, -0.1407,  0.0730],
          [ 0.1880,  0.1731, -0.0466, -0.1832, -0.1856],
          [ 0.0817,  0.1655,  0.1003,  0.1301, -0.1494],
          [-0.1244,  0.0909, -0.1905, -0.1454,  0.0191]]],


        [[[-0.0866,  0.1612, -0.1080, -0.1353, -0.0048],
          [-0.0390,  0.0409,  0.1372,  0.1084,  0.0089],
          [ 0.0773, -0.0910,  0.0907, -0.0005, -0.0688],
          [-0.1594, -0.0870,  0.1363, -0.0502, -0.1091],
          [-0.1946,  0.1958,  0.0182,  0.0292,  0.0192]]],


        [[[-0.0018,  0.1886,  0.1196, -0.0536,  0.0363],
          [ 0.0043, -0.0062, -0.1419, -0.0246,  0.1308],
          [-0.0145, -0.0925, -0.0921, -0.0057, -0.1508],
          [-0.0317,  0.1312, -0.0472,  0.1026,  0.0955],
          [-0.1895, -0.1796, -0.0120,  0.1371, -0.1712]]],


        [[[-0.1677,  0.0754,  0.1224, -0.1442, -0.1595],
          [-0.1937,  0.1224,  0.1674, -0.0

In [5]:
# List the (name, buffer) pairs registered on the layer before pruning.
print([*conv.named_buffers()])

[]


In [6]:
# Show the layer's weight and bias, one after the other, before pruning.
print(conv.weight, conv.bias, sep="\n")

Parameter containing:
tensor([[[[-0.1953, -0.0010, -0.1085,  0.1715,  0.1367],
          [ 0.0376, -0.1104,  0.0248, -0.1407,  0.0730],
          [ 0.1880,  0.1731, -0.0466, -0.1832, -0.1856],
          [ 0.0817,  0.1655,  0.1003,  0.1301, -0.1494],
          [-0.1244,  0.0909, -0.1905, -0.1454,  0.0191]]],


        [[[-0.0866,  0.1612, -0.1080, -0.1353, -0.0048],
          [-0.0390,  0.0409,  0.1372,  0.1084,  0.0089],
          [ 0.0773, -0.0910,  0.0907, -0.0005, -0.0688],
          [-0.1594, -0.0870,  0.1363, -0.0502, -0.1091],
          [-0.1946,  0.1958,  0.0182,  0.0292,  0.0192]]],


        [[[-0.0018,  0.1886,  0.1196, -0.0536,  0.0363],
          [ 0.0043, -0.0062, -0.1419, -0.0246,  0.1308],
          [-0.0145, -0.0925, -0.0921, -0.0057, -0.1508],
          [-0.0317,  0.1312, -0.0472,  0.1026,  0.0955],
          [-0.1895, -0.1796, -0.0120,  0.1371, -0.1712]]],


        [[[-0.1677,  0.0754,  0.1224, -0.1442, -0.1595],
          [-0.1937,  0.1224,  0.1674, -0.0874,  0.1728

In [7]:
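# Alternative, kept disabled: structured pruning that removes 50% of the
# channels along dim 0 of `weight` at random, instead of individual entries.
# prune.random_structured(conv, name='weight', amount=0.5,  dim=0)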

In [8]:
# Randomly prune 30% of the entries of `weight`, then 30% of the entries of
# `bias`. The last call's return value (the module) is the cell output.
prune.random_unstructured(module=conv, name="weight", amount=0.3)
prune.random_unstructured(module=conv, name="bias", amount=0.3)
# prune.remove(conv, 'weight')  # update the weights (make the pruning permanent)

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

In [9]:
# Show the layer's weight and bias again, now that pruning has been applied.
print(conv.weight, conv.bias, sep="\n")

tensor([[[[-0.0000, -0.0000, -0.1085,  0.1715,  0.1367],
          [ 0.0376, -0.1104,  0.0248, -0.1407,  0.0000],
          [ 0.1880,  0.1731, -0.0466, -0.1832, -0.0000],
          [ 0.0817,  0.1655,  0.1003,  0.0000, -0.0000],
          [-0.1244,  0.0000, -0.1905, -0.1454,  0.0191]]],


        [[[-0.0866,  0.1612, -0.1080, -0.1353, -0.0048],
          [-0.0000,  0.0409,  0.1372,  0.1084,  0.0000],
          [ 0.0773, -0.0910,  0.0907, -0.0005, -0.0688],
          [-0.1594, -0.0870,  0.0000, -0.0502, -0.1091],
          [-0.0000,  0.1958,  0.0182,  0.0292,  0.0192]]],


        [[[-0.0018,  0.1886,  0.0000, -0.0536,  0.0363],
          [ 0.0043, -0.0062, -0.0000, -0.0246,  0.1308],
          [-0.0145, -0.0925, -0.0000, -0.0057, -0.1508],
          [-0.0000,  0.1312, -0.0472,  0.1026,  0.0000],
          [-0.1895, -0.1796, -0.0000,  0.1371, -0.0000]]],


        [[[-0.1677,  0.0000,  0.0000, -0.1442, -0.1595],
          [-0.0000,  0.0000,  0.1674, -0.0874,  0.1728],
          [-0.1931,

In [10]:
# List the layer's (name, parameter) pairs after pruning.
print([*conv.named_parameters()])

[('weight_orig', Parameter containing:
tensor([[[[-0.1953, -0.0010, -0.1085,  0.1715,  0.1367],
          [ 0.0376, -0.1104,  0.0248, -0.1407,  0.0730],
          [ 0.1880,  0.1731, -0.0466, -0.1832, -0.1856],
          [ 0.0817,  0.1655,  0.1003,  0.1301, -0.1494],
          [-0.1244,  0.0909, -0.1905, -0.1454,  0.0191]]],


        [[[-0.0866,  0.1612, -0.1080, -0.1353, -0.0048],
          [-0.0390,  0.0409,  0.1372,  0.1084,  0.0089],
          [ 0.0773, -0.0910,  0.0907, -0.0005, -0.0688],
          [-0.1594, -0.0870,  0.1363, -0.0502, -0.1091],
          [-0.1946,  0.1958,  0.0182,  0.0292,  0.0192]]],


        [[[-0.0018,  0.1886,  0.1196, -0.0536,  0.0363],
          [ 0.0043, -0.0062, -0.1419, -0.0246,  0.1308],
          [-0.0145, -0.0925, -0.0921, -0.0057, -0.1508],
          [-0.0317,  0.1312, -0.0472,  0.1026,  0.0955],
          [-0.1895, -0.1796, -0.0120,  0.1371, -0.1712]]],


        [[[-0.1677,  0.0754,  0.1224, -0.1442, -0.1595],
          [-0.1937,  0.1224,  0.1674,

In [11]:
# List the layer's (name, buffer) pairs after pruning.
print([*conv.named_buffers()])

[('weight_mask', tensor([[[[0., 0., 1., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 0., 0.],
          [1., 0., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [0., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 1., 0., 1., 1.],
          [0., 1., 1., 1., 1.]]],


        [[[1., 1., 0., 1., 1.],
          [1., 1., 0., 1., 1.],
          [1., 1., 0., 1., 1.],
          [0., 1., 1., 1., 0.],
          [1., 1., 0., 1., 0.]]],


        [[[1., 0., 0., 1., 1.],
          [0., 0., 1., 1., 1.],
          [1., 1., 1., 1., 0.],
          [0., 1., 0., 0., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 0., 1., 0.],
          [0., 0., 1., 1., 1.],
          [1., 0., 0., 0., 0.],
          [0., 0., 0., 1., 1.]]],


        [[[1., 0., 1., 0., 1.],
          [0., 0., 1., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 0., 0., 0., 1.],
          [1., 1., 1., 1., 1.]]]]))

In [12]:
# Print the `weight` attribute of the pruned layer; compare it with the
# `weight_orig` parameter and the `weight_mask` buffer printed in the cells above.
print(conv.weight)

tensor([[[[-0.0000, -0.0000, -0.1085,  0.1715,  0.1367],
          [ 0.0376, -0.1104,  0.0248, -0.1407,  0.0000],
          [ 0.1880,  0.1731, -0.0466, -0.1832, -0.0000],
          [ 0.0817,  0.1655,  0.1003,  0.0000, -0.0000],
          [-0.1244,  0.0000, -0.1905, -0.1454,  0.0191]]],


        [[[-0.0866,  0.1612, -0.1080, -0.1353, -0.0048],
          [-0.0000,  0.0409,  0.1372,  0.1084,  0.0000],
          [ 0.0773, -0.0910,  0.0907, -0.0005, -0.0688],
          [-0.1594, -0.0870,  0.0000, -0.0502, -0.1091],
          [-0.0000,  0.1958,  0.0182,  0.0292,  0.0192]]],


        [[[-0.0018,  0.1886,  0.0000, -0.0536,  0.0363],
          [ 0.0043, -0.0062, -0.0000, -0.0246,  0.1308],
          [-0.0145, -0.0925, -0.0000, -0.0057, -0.1508],
          [-0.0000,  0.1312, -0.0472,  0.1026,  0.0000],
          [-0.1895, -0.1796, -0.0000,  0.1371, -0.0000]]],


        [[[-0.1677,  0.0000,  0.0000, -0.1442, -0.1595],
          [-0.0000,  0.0000,  0.1674, -0.0874,  0.1728],
          [-0.1931,

In [None]:
# A new LeNet instance; this rebinds `model`, while `conv` from the earlier
# cells still refers to the previous instance's first convolution.
model = LeNet()

# One (module, parameter-name) pair per layer whose `weight` takes part in
# the global pruning.
parameters_to_prune = tuple(
    (getattr(model, layer_name), 'weight')
    for layer_name in ('conv1', 'conv2', 'fc1', 'fc2', 'fc3')
)

# Prune 20% of the listed weights, ranked by L1 magnitude across all listed
# tensors together rather than layer by layer.
prune.global_unstructured(
    parameters=parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)