[文章参考来源](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#serializing-a-pruned-model)

In [1]:
%matplotlib inline

In [2]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

## 定义网络结构

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

## 查看网络结构

In [4]:
model.parameters

<bound method Module.parameters of LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)>

In [5]:
model.modules

<bound method Module.modules of LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)>

## 监测conv1，其中包含了weight和bias,因为还没使用conv1，所以没有buffers产生

In [6]:
module = model.conv1

In [7]:
module

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

先说一下后面会用到的几个概念，
model.conv1.weight   这里的weight是属性,运行prune操作之后产生的稀疏权重会存储在这里，然后我们需要运行prune.remove()操作，才能让model.conv1.weight和model.conv1.named_parameters()的显示结果变成一样

model.conv1.named_buffers()  运行prune操作后，生成的mask会在这里

model.conv1.named_parameters()   模型参数的存储位置,这里面的值会跟随torch.save()存储到本地的.pth

In [8]:
list(module.named_parameters())

[('weight',
  Parameter containing:
  tensor([[[[-0.2455,  0.0065, -0.1007],
            [-0.2746,  0.1460, -0.1434],
            [ 0.0533, -0.1866, -0.2310]]],
  
  
          [[[ 0.2056, -0.0561, -0.1685],
            [ 0.1527, -0.1482,  0.2526],
            [ 0.0145, -0.0996,  0.2773]]],
  
  
          [[[ 0.2489, -0.1215,  0.1214],
            [ 0.1938, -0.2379,  0.2340],
            [-0.2849,  0.2429,  0.1251]]],
  
  
          [[[-0.2392, -0.1834,  0.1596],
            [-0.1394,  0.2261, -0.2334],
            [-0.2830, -0.1771, -0.1306]]],
  
  
          [[[-0.0442,  0.2732,  0.3017],
            [-0.0153,  0.3309,  0.0380],
            [ 0.0040,  0.2164, -0.1804]]],
  
  
          [[[-0.2539,  0.0888, -0.0319],
            [ 0.1115, -0.1317, -0.1421],
            [ 0.1865, -0.3027,  0.0986]]]], device='cuda:0', requires_grad=True)),
 ('bias',
  Parameter containing:
  tensor([-0.2467,  0.2349, -0.2630, -0.0814,  0.1895,  0.2932], device='cuda:0',
         requires_grad=True)

In [9]:
module.weight

Parameter containing:
tensor([[[[-0.2455,  0.0065, -0.1007],
          [-0.2746,  0.1460, -0.1434],
          [ 0.0533, -0.1866, -0.2310]]],


        [[[ 0.2056, -0.0561, -0.1685],
          [ 0.1527, -0.1482,  0.2526],
          [ 0.0145, -0.0996,  0.2773]]],


        [[[ 0.2489, -0.1215,  0.1214],
          [ 0.1938, -0.2379,  0.2340],
          [-0.2849,  0.2429,  0.1251]]],


        [[[-0.2392, -0.1834,  0.1596],
          [-0.1394,  0.2261, -0.2334],
          [-0.2830, -0.1771, -0.1306]]],


        [[[-0.0442,  0.2732,  0.3017],
          [-0.0153,  0.3309,  0.0380],
          [ 0.0040,  0.2164, -0.1804]]],


        [[[-0.2539,  0.0888, -0.0319],
          [ 0.1115, -0.1317, -0.1421],
          [ 0.1865, -0.3027,  0.0986]]]], device='cuda:0', requires_grad=True)

因为还没使用过conv1,所以这里的buffers为空，调用prune之后你会发现差别的

In [10]:
list(module.named_buffers())

[]

## 调用prune对conv1里的weight进行剪枝

In [11]:
prune.random_unstructured(module, name="weight", amount=0.3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

你可以发现这里的weight被重命名为weight_orig

In [12]:
print(list(module.named_parameters()))

[('bias', Parameter containing:
tensor([-0.2467,  0.2349, -0.2630, -0.0814,  0.1895,  0.2932], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.2455,  0.0065, -0.1007],
          [-0.2746,  0.1460, -0.1434],
          [ 0.0533, -0.1866, -0.2310]]],


        [[[ 0.2056, -0.0561, -0.1685],
          [ 0.1527, -0.1482,  0.2526],
          [ 0.0145, -0.0996,  0.2773]]],


        [[[ 0.2489, -0.1215,  0.1214],
          [ 0.1938, -0.2379,  0.2340],
          [-0.2849,  0.2429,  0.1251]]],


        [[[-0.2392, -0.1834,  0.1596],
          [-0.1394,  0.2261, -0.2334],
          [-0.2830, -0.1771, -0.1306]]],


        [[[-0.0442,  0.2732,  0.3017],
          [-0.0153,  0.3309,  0.0380],
          [ 0.0040,  0.2164, -0.1804]]],


        [[[-0.2539,  0.0888, -0.0319],
          [ 0.1115, -0.1317, -0.1421],
          [ 0.1865, -0.3027,  0.0986]]]], device='cuda:0', requires_grad=True))]


model.conv1.weight 这里的weight是属性,运行prune操作之后产生的稀疏权重会存储在这里

In [13]:
module.weight

tensor([[[[-0.2455,  0.0065, -0.1007],
          [-0.2746,  0.1460, -0.1434],
          [ 0.0533, -0.1866, -0.2310]]],


        [[[ 0.2056, -0.0561, -0.0000],
          [ 0.1527, -0.0000,  0.2526],
          [ 0.0145, -0.0000,  0.0000]]],


        [[[ 0.2489, -0.1215,  0.1214],
          [ 0.0000, -0.0000,  0.2340],
          [-0.2849,  0.2429,  0.0000]]],


        [[[-0.2392, -0.1834,  0.1596],
          [-0.0000,  0.2261, -0.2334],
          [-0.0000, -0.1771, -0.1306]]],


        [[[-0.0442,  0.0000,  0.3017],
          [-0.0153,  0.3309,  0.0380],
          [ 0.0040,  0.0000, -0.1804]]],


        [[[-0.2539,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.1865, -0.3027,  0.0986]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

model.conv1.named_buffers() 运行prune操作后，生成的mask会在这里

In [14]:
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [1., 0., 1.],
          [1., 0., 0.]]],


        [[[1., 1., 1.],
          [0., 0., 1.],
          [1., 1., 0.]]],


        [[[1., 1., 1.],
          [0., 1., 1.],
          [0., 1., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 1.],
          [1., 0., 1.]]],


        [[[1., 0., 0.],
          [0., 0., 0.],
          [1., 1., 1.]]]], device='cuda:0'))]


In [15]:
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f110f8550b8>)])


# 进行L1的Pruning操作

In [16]:
prune.l1_unstructured(module, name='bias', amount=3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

In [17]:
print(list(module.named_parameters()))

[('weight_orig', Parameter containing:
tensor([[[[-0.2455,  0.0065, -0.1007],
          [-0.2746,  0.1460, -0.1434],
          [ 0.0533, -0.1866, -0.2310]]],


        [[[ 0.2056, -0.0561, -0.1685],
          [ 0.1527, -0.1482,  0.2526],
          [ 0.0145, -0.0996,  0.2773]]],


        [[[ 0.2489, -0.1215,  0.1214],
          [ 0.1938, -0.2379,  0.2340],
          [-0.2849,  0.2429,  0.1251]]],


        [[[-0.2392, -0.1834,  0.1596],
          [-0.1394,  0.2261, -0.2334],
          [-0.2830, -0.1771, -0.1306]]],


        [[[-0.0442,  0.2732,  0.3017],
          [-0.0153,  0.3309,  0.0380],
          [ 0.0040,  0.2164, -0.1804]]],


        [[[-0.2539,  0.0888, -0.0319],
          [ 0.1115, -0.1317, -0.1421],
          [ 0.1865, -0.3027,  0.0986]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.2467,  0.2349, -0.2630, -0.0814,  0.1895,  0.2932], device='cuda:0',
       requires_grad=True))]


In [18]:
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [1., 0., 1.],
          [1., 0., 0.]]],


        [[[1., 1., 1.],
          [0., 0., 1.],
          [1., 1., 0.]]],


        [[[1., 1., 1.],
          [0., 1., 1.],
          [0., 1., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 1.],
          [1., 0., 1.]]],


        [[[1., 0., 0.],
          [0., 0., 0.],
          [1., 1., 1.]]]], device='cuda:0')), ('bias_mask', tensor([1., 0., 1., 0., 0., 1.], device='cuda:0'))]


In [19]:
print(module.bias)

tensor([-0.2467,  0.0000, -0.2630, -0.0000,  0.0000,  0.2932], device='cuda:0',
       grad_fn=<MulBackward0>)


In [20]:
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f110f8550b8>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7f110f86d860>)])


## 这里进行了structured pruning，注意看输出的weight,都是一整个channel为0的

In [21]:
prune.ln_structured(module, name='weight', amount=0.5, n=2, dim=0)
# as we can verify, this will zero out all the connections corresponding to 50%(3 out of 6) of the channels,
# while preserving the action of the previous mask

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

For the forward pass to work without modification, the weight attribute needs to exist. The pruning techniques implemented in torch.nn.utils.prune compute the pruned version of the weight (by combining the mask with the original parameter) and store them in the attribute weight. Note, this is no longer a parameter of the module, it is now simply an attribute

In [22]:
print(module.weight)

tensor([[[[-0.2455,  0.0065, -0.1007],
          [-0.2746,  0.1460, -0.1434],
          [ 0.0533, -0.1866, -0.2310]]],


        [[[ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]],


        [[[ 0.2489, -0.1215,  0.1214],
          [ 0.0000, -0.0000,  0.2340],
          [-0.2849,  0.2429,  0.0000]]],


        [[[-0.2392, -0.1834,  0.1596],
          [-0.0000,  0.2261, -0.2334],
          [-0.0000, -0.1771, -0.1306]]],


        [[[-0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000]]],


        [[[-0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)


In [23]:
for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == 'weight':
        break

print(list(hook))

[<torch.nn.utils.prune.RandomUnstructured object at 0x7f110f8550b8>, <torch.nn.utils.prune.LnStructured object at 0x7f1106a14a90>]


# 可以看到存在weight_orig和weight_mask，两者运算之后产生prune之后的结果

In [24]:
print(model.state_dict().keys())

odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])


In [25]:
print(list(module.named_parameters()))

[('weight_orig', Parameter containing:
tensor([[[[-0.2455,  0.0065, -0.1007],
          [-0.2746,  0.1460, -0.1434],
          [ 0.0533, -0.1866, -0.2310]]],


        [[[ 0.2056, -0.0561, -0.1685],
          [ 0.1527, -0.1482,  0.2526],
          [ 0.0145, -0.0996,  0.2773]]],


        [[[ 0.2489, -0.1215,  0.1214],
          [ 0.1938, -0.2379,  0.2340],
          [-0.2849,  0.2429,  0.1251]]],


        [[[-0.2392, -0.1834,  0.1596],
          [-0.1394,  0.2261, -0.2334],
          [-0.2830, -0.1771, -0.1306]]],


        [[[-0.0442,  0.2732,  0.3017],
          [-0.0153,  0.3309,  0.0380],
          [ 0.0040,  0.2164, -0.1804]]],


        [[[-0.2539,  0.0888, -0.0319],
          [ 0.1115, -0.1317, -0.1421],
          [ 0.1865, -0.3027,  0.0986]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.2467,  0.2349, -0.2630, -0.0814,  0.1895,  0.2932], device='cuda:0',
       requires_grad=True))]


In [26]:
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[1., 1., 1.],
          [0., 0., 1.],
          [1., 1., 0.]]],


        [[[1., 1., 1.],
          [0., 1., 1.],
          [0., 1., 1.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]], device='cuda:0')), ('bias_mask', tensor([1., 0., 1., 0., 0., 1.], device='cuda:0'))]


In [27]:
print(module.weight)

tensor([[[[-0.2455,  0.0065, -0.1007],
          [-0.2746,  0.1460, -0.1434],
          [ 0.0533, -0.1866, -0.2310]]],


        [[[ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]],


        [[[ 0.2489, -0.1215,  0.1214],
          [ 0.0000, -0.0000,  0.2340],
          [-0.2849,  0.2429,  0.0000]]],


        [[[-0.2392, -0.1834,  0.1596],
          [-0.0000,  0.2261, -0.2334],
          [-0.0000, -0.1771, -0.1306]]],


        [[[-0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000]]],


        [[[-0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)


运行prune.remove的作用，To make the pruning permanent, remove the re-parametrization in terms of weight_orig and weight_mask, and remove the forward_pre_hook, we can use the remove functionality from torch.nn.utils.prune. Note that this doesn't undo the pruning, as if it never happened. It simply makes it permanent, instead, by reassigning the parameter weight to the model parameters, in its pruned version.

# prune.remove之后，我们发现weight_orig变成了weight,其实就是把module.weight的值赋值给了weight_orig

In [28]:
prune.remove(module, 'weight')
print(list(module.named_parameters()))

[('bias_orig', Parameter containing:
tensor([-0.2467,  0.2349, -0.2630, -0.0814,  0.1895,  0.2932], device='cuda:0',
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[-0.2455,  0.0065, -0.1007],
          [-0.2746,  0.1460, -0.1434],
          [ 0.0533, -0.1866, -0.2310]]],


        [[[ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]],


        [[[ 0.2489, -0.1215,  0.1214],
          [ 0.0000, -0.0000,  0.2340],
          [-0.2849,  0.2429,  0.0000]]],


        [[[-0.2392, -0.1834,  0.1596],
          [-0.0000,  0.2261, -0.2334],
          [-0.0000, -0.1771, -0.1306]]],


        [[[-0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000]]],


        [[[-0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000]]]], device='cuda:0', requires_grad=True))]


In [29]:
print(list(module.named_buffers()))

[('bias_mask', tensor([1., 0., 1., 0., 0., 1.], device='cuda:0'))]


## Prunning multiple parameters in a model

In [30]:
new_model = LeNet()  

In [31]:
for name, module in new_model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

In [32]:
print(dict(new_model.named_buffers()).keys())

dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])


## Global Prunning

So far, we only looked at what is usually referred to as "local" pruning, i.e. the practice of pruning tensors in a model one by one, by comparing the statistics (weight magnitude, activation, gradient, etc.) of each entry exclusively to the other entries in that tensor. However, a common and perhaps more powerful technique is to prune the model all at once, by removing (for example) the lowest 20% of connections across the whole model, instead of removing the lowest 20% of connections in each layer. This is likely to result in different pruning percentages per layer. Let's see how to do that using global_unstructured from torch.nn.utils.prune

In [33]:
model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight')
)

In [34]:
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount = 0.7
)

In [35]:
print("Sparsity in conv1.weight: {:.2f}%".format(100.0 * float(torch.sum(model.conv1.weight == 0)) / 
     float(model.conv1.weight.nelement())))

Sparsity in conv1.weight: 11.11%


In [36]:
print("Sparsity in conv2.weight: {:.2f}%".format(100.0 * float(torch.sum(model.conv2.weight == 0)) /
                                                float(model.conv2.weight.nelement())))

Sparsity in conv2.weight: 27.31%


In [37]:
print("Sparsity in fc1.weight: {:.2f}%".format(100.0 * float(torch.sum(model.fc1.weight == 0)) / 
     float(model.fc1.weight.nelement())))

Sparsity in fc1.weight: 77.34%


In [38]:
print("Sparsity in fc2.weight: {:.2f}%".format(100.0 * float(torch.sum(model.fc2.weight == 0)) / 
     float(model.fc2.weight.nelement())))

Sparsity in fc2.weight: 41.76%


In [39]:
print("Sparsity in fc3.weight: {:.2f}%".format(100.0 * float(torch.sum(model.fc3.weight == 0)) /
                                              float(model.fc3.weight.nelement())))

Sparsity in fc3.weight: 37.50%


In [40]:
print("Global sparsity: {:.2f}%".format(
    100.0 * float(
        torch.sum(model.conv1.weight == 0)
        + torch.sum(model.conv2.weight == 0)
        + torch.sum(model.fc1.weight == 0)
        + torch.sum(model.fc2.weight == 0)
        + torch.sum(model.fc3.weight == 0)
    )
    / float(
        model.conv1.weight.nelement()
        + model.conv2.weight.nelement()
        + model.fc1.weight.nelement()
        + model.fc2.weight.nelement()
        + model.fc3.weight.nelement()
    )
))

Global sparsity: 70.00%


# 计算0值的个数，结果看来是吻合的，计算结果为14.81%稀疏度，在conv1层的weight上

In [41]:
torch.sum(model.conv1.weight == 0)

tensor(6)

In [42]:
model.conv1.weight.nelement()

54

In [43]:
8 / 54.0 * 100

14.814814814814813

In [44]:
model.conv1.weight

tensor([[[[ 0.0000, -0.2941, -0.0869],
          [ 0.2473, -0.1696, -0.0801],
          [-0.1297, -0.3271,  0.2942]]],


        [[[-0.3062, -0.2663, -0.2162],
          [ 0.0476, -0.2757,  0.2436],
          [-0.2258,  0.0679, -0.0000]]],


        [[[ 0.2015,  0.0000,  0.2316],
          [ 0.2343, -0.2772,  0.3242],
          [-0.1783, -0.0607, -0.2033]]],


        [[[ 0.0815, -0.2245,  0.3070],
          [-0.0744, -0.1191,  0.1764],
          [ 0.3127,  0.2398,  0.0000]]],


        [[[ 0.3139,  0.0565, -0.2029],
          [ 0.3229, -0.2615, -0.2801],
          [-0.0000, -0.2270, -0.0000]]],


        [[[-0.1660,  0.1511,  0.1712],
          [ 0.2690, -0.0737, -0.2197],
          [ 0.2962,  0.1628, -0.0903]]]], grad_fn=<MulBackward0>)

In [45]:
list(model.conv1.named_parameters())

[('bias',
  Parameter containing:
  tensor([-0.2342, -0.2550, -0.1118,  0.2898, -0.0040,  0.0299],
         requires_grad=True)),
 ('weight_orig',
  Parameter containing:
  tensor([[[[ 0.0311, -0.2941, -0.0869],
            [ 0.2473, -0.1696, -0.0801],
            [-0.1297, -0.3271,  0.2942]]],
  
  
          [[[-0.3062, -0.2663, -0.2162],
            [ 0.0476, -0.2757,  0.2436],
            [-0.2258,  0.0679, -0.0202]]],
  
  
          [[[ 0.2015,  0.0066,  0.2316],
            [ 0.2343, -0.2772,  0.3242],
            [-0.1783, -0.0607, -0.2033]]],
  
  
          [[[ 0.0815, -0.2245,  0.3070],
            [-0.0744, -0.1191,  0.1764],
            [ 0.3127,  0.2398,  0.0317]]],
  
  
          [[[ 0.3139,  0.0565, -0.2029],
            [ 0.3229, -0.2615, -0.2801],
            [-0.0097, -0.2270, -0.0180]]],
  
  
          [[[-0.1660,  0.1511,  0.1712],
            [ 0.2690, -0.0737, -0.2197],
            [ 0.2962,  0.1628, -0.0903]]]], requires_grad=True))]

# 在保存模型之前，对每个layer运行prune.remove操作

In [46]:
for module, name in parameters_to_prune:
    print(module, name)
    prune.remove(module, name)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1)) weight
Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1)) weight
Linear(in_features=400, out_features=120, bias=True) weight
Linear(in_features=120, out_features=84, bias=True) weight
Linear(in_features=84, out_features=10, bias=True) weight


In [47]:
list(model.conv1.named_parameters())

[('bias',
  Parameter containing:
  tensor([-0.2342, -0.2550, -0.1118,  0.2898, -0.0040,  0.0299],
         requires_grad=True)),
 ('weight',
  Parameter containing:
  tensor([[[[ 0.0000, -0.2941, -0.0869],
            [ 0.2473, -0.1696, -0.0801],
            [-0.1297, -0.3271,  0.2942]]],
  
  
          [[[-0.3062, -0.2663, -0.2162],
            [ 0.0476, -0.2757,  0.2436],
            [-0.2258,  0.0679, -0.0000]]],
  
  
          [[[ 0.2015,  0.0000,  0.2316],
            [ 0.2343, -0.2772,  0.3242],
            [-0.1783, -0.0607, -0.2033]]],
  
  
          [[[ 0.0815, -0.2245,  0.3070],
            [-0.0744, -0.1191,  0.1764],
            [ 0.3127,  0.2398,  0.0000]]],
  
  
          [[[ 0.3139,  0.0565, -0.2029],
            [ 0.3229, -0.2615, -0.2801],
            [-0.0000, -0.2270, -0.0000]]],
  
  
          [[[-0.1660,  0.1511,  0.1712],
            [ 0.2690, -0.0737, -0.2197],
            [ 0.2962,  0.1628, -0.0903]]]], requires_grad=True))]

In [48]:
list(model.conv1.named_buffers())

[]

# 保存模型以及比较.pth个压缩后的.zip格式模型的大小

In [49]:
torch.save(model.state_dict(), 'sparse_model.pth')

In [50]:
(241.8 - 91.1) / 241.8

0.6232423490488007

In [51]:
model.conv1.weight

Parameter containing:
tensor([[[[ 0.0000, -0.2941, -0.0869],
          [ 0.2473, -0.1696, -0.0801],
          [-0.1297, -0.3271,  0.2942]]],


        [[[-0.3062, -0.2663, -0.2162],
          [ 0.0476, -0.2757,  0.2436],
          [-0.2258,  0.0679, -0.0000]]],


        [[[ 0.2015,  0.0000,  0.2316],
          [ 0.2343, -0.2772,  0.3242],
          [-0.1783, -0.0607, -0.2033]]],


        [[[ 0.0815, -0.2245,  0.3070],
          [-0.0744, -0.1191,  0.1764],
          [ 0.3127,  0.2398,  0.0000]]],


        [[[ 0.3139,  0.0565, -0.2029],
          [ 0.3229, -0.2615, -0.2801],
          [-0.0000, -0.2270, -0.0000]]],


        [[[-0.1660,  0.1511,  0.1712],
          [ 0.2690, -0.0737, -0.2197],
          [ 0.2962,  0.1628, -0.0903]]]], requires_grad=True)