# 导入依赖（notebook 风格）

先将当前工作目录的上级目录加入 `sys.path`，使上级目录中的模块可以被导入，然后再导入 PyTorch 与 `torch_pruning`。

In [1]:
import sys, os
sys.path.append(os.path.dirname(os.path.abspath('.')))

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_pruning as tp

# 定义一个结构较复杂的网络
该网络混合使用了多种注册子模块的方式：直接属性赋值（`nn.Module`）、`Module.add_module`、`nn.Sequential`、`nn.ModuleList` 和 `nn.ModuleDict`。

In [2]:
class DeepFCN(nn.Module):
    """Fully connected classifier used as a pruning testbed.

    Layer widths: input_size -> 256 -> 64 -> 64 (x3) -> 32 -> num_classes,
    with a ReLU after every hidden layer and none after the last one.
    Submodules are deliberately registered through several mechanisms
    (plain attribute, ``add_module``, ``Sequential``, ``ModuleList`` and
    ``ModuleDict``) so the model exercises each container type.

    Args:
        input_size: number of input features of ``fc1``.
        num_classes: number of output features of ``fc5``.
    """

    def __init__(self, input_size, num_classes):
        super().__init__()
        # Creation order is significant: it fixes both the printed/traversal
        # order of submodules and the order parameters draw from the RNG.
        self.fc1 = nn.Linear(input_size, 256)
        self.add_module('first_relu', nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(256, 64), nn.ReLU())
        self.fc3 = nn.ModuleList(
            [nn.Sequential(nn.Linear(64, 64), nn.ReLU()) for _ in range(3)]
        )
        # 'fc4-1' contains a hyphen, so it can only be reached by key lookup.
        self.fc4 = nn.ModuleDict({
            'fc4-1': nn.Linear(64, 32),
            'relu': nn.ReLU(),
        })
        self.fc5 = nn.Linear(32, num_classes)

    def forward(self, x):
        """Run the stack and return the raw outputs of ``fc5`` (no softmax).

        Args:
            x: input tensor whose last dimension is ``input_size``
               (presumably shaped (batch, input_size)).
        """
        hidden = self.first_relu(self.fc1(x))
        hidden = self.fc2(hidden)
        for block in self.fc3:
            hidden = block(hidden)
        hidden = self.fc4['relu'](self.fc4['fc4-1'](hidden))
        return self.fc5(hidden)

# Instantiate for 128-dimensional inputs and 10 output classes.
model = DeepFCN(input_size=128, num_classes=10)
print(model)

DeepFCN(
  (fc1): Linear(in_features=128, out_features=256, bias=True)
  (first_relu): ReLU()
  (fc2): Sequential(
    (0): Linear(in_features=256, out_features=64, bias=True)
    (1): ReLU()
  )
  (fc3): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
    )
    (1): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
    )
    (2): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
    )
  )
  (fc4): ModuleDict(
    (fc4-1): Linear(in_features=64, out_features=32, bias=True)
    (relu): ReLU()
  )
  (fc5): Linear(in_features=32, out_features=10, bias=True)
)


# 注册不参与剪枝的层 `static_layers`

输出层 `fc5` 的输出维度由类别数 `num_classes` 决定，因此它的输出不能被剪掉。

In [3]:
# fc5 produces the num_classes outputs, so it must keep its full width.
static_layers = [model.fc5]
print(static_layers)

[Linear(in_features=32, out_features=10, bias=True)]


# 递归遍历模型，随机选取各层待剪枝的索引（idxs）

对每个不在 `static_layers` 中的 `nn.Linear` 层，随机选取 20% 的输出特征作为剪枝对象。

In [4]:
# Random selection: each prunable layer gets its own random index set.
strategy = tp.strategy.RandomStrategy()

# Maps each prunable Linear module to the indices chosen for removal.
module_to_idxs = {}

def init_strategy(m):
    """``model.apply`` callback: pick random pruning indices for ``m``.

    Only ``nn.Linear`` layers that are not listed in ``static_layers`` are
    considered; 20% of the rows of their weight are selected.
    """
    if not isinstance(m, nn.Linear):
        return
    if m in static_layers:
        return
    module_to_idxs[m] = strategy(m.weight, amount=0.2)

# Module.apply visits every submodule recursively, then the model itself.
model.apply(init_strategy)

print(module_to_idxs)

{Linear(in_features=128, out_features=256, bias=True): [89, 52, 120, 132, 73, 195, 148, 33, 9, 170, 44, 98, 209, 203, 51, 94, 11, 36, 117, 253, 17, 204, 78, 1, 130, 174, 173, 39, 123, 205, 37, 143, 255, 157, 246, 47, 12, 29, 30, 93, 248, 241, 105, 22, 208, 197, 122, 91, 96, 3, 190], Linear(in_features=256, out_features=64, bias=True): [55, 5, 26, 25, 56, 31, 33, 43, 40, 10, 7, 53], Linear(in_features=64, out_features=64, bias=True): [43, 16, 50, 52, 24, 29, 3, 44, 10, 46, 18, 13], Linear(in_features=64, out_features=64, bias=True): [63, 31, 14, 32, 35, 19, 48, 62, 25, 44, 26, 58], Linear(in_features=64, out_features=64, bias=True): [54, 48, 21, 12, 57, 7, 34, 61, 43, 6, 52, 46], Linear(in_features=64, out_features=32, bias=True): [13, 15, 3, 18, 30, 4]}


# 构建依赖图，并根据 idxs 为每一层生成剪枝计划

In [5]:
# Trace the model with a dummy batch (width must equal input_size=128) so the
# graph knows which layers are coupled, e.g. fc1's outputs feed fc2.0's inputs.
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,128))

pruning_plans = []

def get_pruning_plans(m):
    """``model.apply`` callback: build a plan for ``m`` if it has chosen indices.

    Each plan covers ``m`` itself plus every dependent layer the graph found.
    """
    if m not in module_to_idxs:
        return
    plan_for_m = DG.get_pruning_plan(m, tp.prune_linear, idxs=module_to_idxs[m])
    pruning_plans.append(plan_for_m)

model.apply(get_pruning_plans)
for plan in pruning_plans:
    print(plan)


-------------
[ <DEP: prune_linear => prune_linear on fc1 (Linear(in_features=128, out_features=256, bias=True))>, Index=[89, 52, 120, 132, 73, 195, 148, 33, 9, 170, 44, 98, 209, 203, 51, 94, 11, 36, 117, 253, 17, 204, 78, 1, 130, 174, 173, 39, 123, 205, 37, 143, 255, 157, 246, 47, 12, 29, 30, 93, 248, 241, 105, 22, 208, 197, 122, 91, 96, 3, 190], NumPruned=6579]
[ <DEP: prune_linear => _prune_elementwise_op on _ElementWiseOp()>, Index=[89, 52, 120, 132, 73, 195, 148, 33, 9, 170, 44, 98, 209, 203, 51, 94, 11, 36, 117, 253, 17, 204, 78, 1, 130, 174, 173, 39, 123, 205, 37, 143, 255, 157, 246, 47, 12, 29, 30, 93, 248, 241, 105, 22, 208, 197, 122, 91, 96, 3, 190], NumPruned=0]
[ <DEP: _prune_elementwise_op => prune_related_linear on fc2.0 (Linear(in_features=256, out_features=64, bias=True))>, Index=[89, 52, 120, 132, 73, 195, 148, 33, 9, 170, 44, 98, 209, 203, 51, 94, 11, 36, 117, 253, 17, 204, 78, 1, 130, 174, 173, 39, 123, 205, 37, 143, 255, 157, 246, 47, 12, 29, 30, 93, 248, 241, 105,

## 执行剪枝计划，并对比剪枝前后的模型结构

注意：该单元格不是幂等的。重复运行会对已剪枝的模型再次执行同一批计划，可能因索引越界而出错。如需重新运行，请重启内核并从头执行。

In [6]:
# Architecture before pruning.
print(model)
# plan.exec() modifies the layers in place. NOTE: not idempotent, because
# re-running this cell re-applies the same plans to the already-pruned model.
for plan in pruning_plans:
    plan.exec()
print('<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< before',
      'after >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>',
      sep='\n')
# Architecture after pruning.
print(model)

DeepFCN(
  (fc1): Linear(in_features=128, out_features=256, bias=True)
  (first_relu): ReLU()
  (fc2): Sequential(
    (0): Linear(in_features=256, out_features=64, bias=True)
    (1): ReLU()
  )
  (fc3): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
    )
    (1): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
    )
    (2): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
    )
  )
  (fc4): ModuleDict(
    (fc4-1): Linear(in_features=64, out_features=32, bias=True)
    (relu): ReLU()
  )
  (fc5): Linear(in_features=32, out_features=10, bias=True)
)
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< before
after >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
DeepFCN(
  (fc1): Linear(in_features=128, out_features=205, bias=True)
  (first_relu): ReLU()
  (fc2): Sequential(
    (0): Linear(in_features=205, out_features=52, bias=True)
    (1): ReLU()
 