In [3]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from model import LeNet

ModuleNotFoundError: No module named 'torch.nn.utils.prune'

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
model = LeNet().to(device=device)

In [4]:
print(model.state_dict().keys())

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])


In [5]:
print(type(model))

<class 'model.LeNet'>


In [6]:
isinstance(model, nn.Module)

True

In [7]:
print(model)

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [8]:
print(model.conv1)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))


In [9]:
parameters_to_prune = (
            (model.conv1, 'weight'),
            (model.conv2, 'weight'),
            (model.fc1, 'weight'),
            (model.fc2, 'weight'),
            (model.fc3, 'weight'))

In [10]:
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)

In [11]:
for name, module in model.named_children():
    print("Sparsity in {}: {:.2f}%".format(
        name,
        100. * float(torch.sum(module.weight == 0))
        / float(module.weight.nelement())
        ))

Sparsity in conv1: 9.26%
Sparsity in conv2: 7.52%
Sparsity in fc1: 22.17%
Sparsity in fc2: 11.68%
Sparsity in fc3: 9.40%


p.s.将无训练参数的网络层用nn.functional的形式写入forward中，则在print网络的时候不会被打印，因为关于functional的关系被写入forward函数中，只会在前向的过程中实现：

In [12]:
print(model)

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
