In [3]:
import torch
import torch.nn.functional as F
from torch.optim import SGD

from nni_assets.compression.mnist_model import TorchModel, trainer, evaluator, device

# define the model
model = TorchModel().to(device)

# show the model structure, note that pruner will wrap the model layer.
print(model)

TorchModel(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
  (relu1): ReLU()
  (relu2): ReLU()
  (relu3): ReLU()
  (relu4): ReLU()
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)


In [2]:
# define the optimizer and criterion for pre-training

optimizer = SGD(model.parameters(), 1e-2)
criterion = F.nll_loss

# pre-train and evaluate the model on MNIST dataset
for epoch in range(3):
    trainer(model, optimizer, criterion)
    evaluator(model)

RuntimeError: Numpy is not available

In [4]:
total = 0
for name, param in model.state_dict().items():
    print(f"{name}: {param.numel()}")
    total += param.numel()
print(f"total: {total}")

conv1.weight: 150
conv1.bias: 6
conv2.weight: 2400
conv2.bias: 16
fc1.weight: 30720
fc1.bias: 120
fc2.weight: 10080
fc2.bias: 84
fc3.weight: 840
fc3.bias: 10
total: 44426


In [5]:
# Prune configs
config_list = [{
    'op_types': ['Linear', 'Conv2d'], # types of layers to prune
    'exclude_op_names': ['fc3'], # exclude specific layers
    'sparse_ratio': 0.3 # mask 30% of the parameters
}]

In [6]:
from nni.compression.pruning import L1NormPruner
pruner = L1NormPruner(model, config_list)

# show the wrapped model structure, `PrunerModuleWrapper` have wrapped the layers that configured in the config_list.
print(model)

TorchModel(
  (conv1): Conv2d(
    1, 6, kernel_size=(5, 5), stride=(1, 1)
    (_nni_wrapper): ModuleWrapper(module=Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1)), module_name=conv1)
  )
  (conv2): Conv2d(
    6, 16, kernel_size=(5, 5), stride=(1, 1)
    (_nni_wrapper): ModuleWrapper(module=Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)), module_name=conv2)
  )
  (fc1): Linear(
    in_features=256, out_features=120, bias=True
    (_nni_wrapper): ModuleWrapper(module=Linear(in_features=256, out_features=120, bias=True), module_name=fc1)
  )
  (fc2): Linear(
    in_features=120, out_features=84, bias=True
    (_nni_wrapper): ModuleWrapper(module=Linear(in_features=120, out_features=84, bias=True), module_name=fc2)
  )
  (fc3): Linear(in_features=84, out_features=10, bias=True)
  (relu1): ReLU()
  (relu2): ReLU()
  (relu3): ReLU()
  (relu4): ReLU()
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=(2, 2), str

In [7]:
# compress the model and generate the masks
_, masks = pruner.compress()
# show the masks sparsity
for name, mask in masks.items():
    print(name, ' sparsity : ', '{:.2}'.format(mask['weight'].sum() / mask['weight'].numel()))

fc2  sparsity :  0.7
conv2  sparsity :  0.75
fc1  sparsity :  0.7
conv1  sparsity :  0.83


In [8]:
# need to unwrap the model, if the model is wrapped before speedup
pruner.unwrap_model()

# speedup the model, for more information about speedup, please refer :doc:`pruning_speedup`.
from nni.compression.speedup import ModelSpeedup

ModelSpeedup(model, torch.rand(3, 1, 28, 28).to(device), masks).speedup_model()

[2024-07-04 14:34:24] [32mStart to speedup the model...[0m
[2024-07-04 14:34:24] [32mResolve the mask conflict before mask propagate...[0m
[2024-07-04 14:34:24] [32mdim0 sparsity: 0.227273[0m
[2024-07-04 14:34:24] [32mdim1 sparsity: 0.000000[0m
0 Filter
[2024-07-04 14:34:24] [32mdim0 sparsity: 0.227273[0m
[2024-07-04 14:34:24] [32mdim1 sparsity: 0.000000[0m
[2024-07-04 14:34:24] [32mInfer module masks...[0m
[2024-07-04 14:34:24] [32mPropagate original variables[0m
[2024-07-04 14:34:24] [32mPropagate variables for placeholder: x, output mask:  0.0000 [0m
[2024-07-04 14:34:24] [32mPropagate variables for call_module: conv1, weight:  0.1667 bias:  0.1667 , output mask:  0.0000 [0m
[2024-07-04 14:34:24] [32mPropagate variables for call_module: relu1, , output mask:  0.0000 [0m
[2024-07-04 14:34:24] [32mPropagate variables for call_module: pool1, , output mask:  0.0000 [0m
[2024-07-04 14:34:24] [32mPropagate variables for call_module: conv2, weight:  0.2500 bias:  0

TorchModel(
  (conv1): Conv2d(1, 5, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(5, 12, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=192, out_features=84, bias=True)
  (fc2): Linear(in_features=84, out_features=59, bias=True)
  (fc3): Linear(in_features=59, out_features=10, bias=True)
  (relu1): ReLU()
  (relu2): ReLU()
  (relu3): ReLU()
  (relu4): ReLU()
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)