# 导入函数库

In [1]:
import torch
import torchvision
import torch.nn.utils.prune as prune
from torchvision import models
import torch.nn as nn

# 定义网络结构，以及替换网络的最后一层

In [2]:
# Build an ImageNet-pretrained ResNet-18 and replace its final fully connected layer.
# This reuses the binary-classification setup from the earlier transfer-learning
# notebook, so the new head has NUM_CLASSES output neurons.
NUM_CLASSES = 2

# `pretrained=True` is deprecated (torchvision >= 0.13) in favour of the `weights=` enum.
# Use the enum when available and fall back to the legacy keyword on older releases;
# IMAGENET1K_V1 is the weight set the legacy `pretrained=True` flag loads.
_resnet18_weights = getattr(models, "ResNet18_Weights", None)
if _resnet18_weights is not None:
    model_ft = models.resnet18(weights=_resnet18_weights.IMAGENET1K_V1)
else:
    model_ft = models.resnet18(pretrained=True)

num_ftrs = model_ft.fc.in_features  # input width of the original classifier head
model_ft.fc = nn.Linear(num_ftrs, NUM_CLASSES)

## 查看网络结构

In [3]:
# Display the network's module tree. Evaluate the module itself: referencing
# `model_ft.parameters` without calling it only shows a bound-method object.
model_ft

<bound method Module.parameters of ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)


# 收集要prune的模块

In [4]:
# Start with an empty collection of (module, parameter_name) pairs to prune.
parameters_to_prune = list()

In [5]:
# Collect every convolution and fully connected layer for global pruning.
# Only each layer's `weight` is pruned; biases are left untouched.
# The list is rebuilt here (instead of appending to the list from the previous
# cell) so re-running this cell never adds duplicate entries.
parameters_to_prune = [
    (module, 'weight')
    for module in model_ft.modules()
    if isinstance(module, (nn.Conv2d, nn.Linear))
]

## 展示收集到的模块

In [6]:
# Display the collected (module, 'weight') pairs to check which layers will be pruned.
parameters_to_prune

[(Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
  'weight'),
 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
  'weight'),
 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
  'weight'),
 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
  'weight'),
 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
  'weight'),
 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
  'weight'),
 (Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
  'weight'),
 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False), 'weight'),
 (Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
  'weight'),
 (Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
  'weight'),
 (Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), b

In [7]:
# How many conv / linear layers take part in the global pruning.
n_prunable_layers = len(parameters_to_prune)
n_prunable_layers

21

## 转化成tuple形式（与PyTorch文档示例保持一致；实际上list同样可以传入）

In [8]:
# Freeze the collected pairs into a tuple.
# NOTE(review): the torch docs describe `parameters` as an Iterable of
# (module, name) pairs, so the list would likely work as well; the tuple just
# mirrors the documentation example.
parameters_to_prune = (*parameters_to_prune,)

# 进行全局prune

In [9]:
# Fraction of all collected weights to zero out.
PRUNE_AMOUNT = 0.5

# Global unstructured pruning: the weights of every collected layer are ranked
# together by L1 magnitude and the smallest PRUNE_AMOUNT fraction is masked, so
# individual layers can end up with different sparsity levels.
# `parameters` may be a list or a tuple of (module, name) pairs. It is iterated
# several times internally, so it must be re-iterable (not a one-shot generator).
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,  # per-weight (unstructured) L1-magnitude pruning
    amount=PRUNE_AMOUNT,                  # 0.5 -> remove 50% of the connections
)

## 计算conv1层的稀疏度

In [10]:
# .nelement()可以统计array里面的元素个数
100.0 * float(torch.sum(model_ft.conv1.weight == 0)) / float(model_ft.conv1.weight.nelement())

26.12670068027211

## 计算BasicBlock里的conv1的稀疏度

In [11]:
# Sparsity (%) of conv1 inside the first BasicBlock of layer1.
block_conv_weight = model_ft.layer1[0].conv1.weight
n_zero = float((block_conv_weight == 0).sum())
n_total = float(block_conv_weight.numel())
100.0 * n_zero / n_total

35.68250868055556

可以看到，全局剪枝时各层的实际剪枝比例并不相同（上面 conv1 约为 26%，layer1[0].conv1 约为 36%，而所有被收集参数的总体剪枝比例为 50%）。通常来说，全局剪枝的效果会比对每一层都按固定比例剪枝来得更好。