Reference: [A Friendly Introduction to Model Pruning with PyTorch](https://levelup.gitconnected.com/a-friendly-introduction-to-model-pruning-with-pytorch-73245d5d28f3)

In [6]:
import torch
import torch.nn.utils.prune as prune

# Define a simple neural network
class SimpleNet(torch.nn.Module):
    """Toy pruning target: a single 5x5 convolution followed by one linear layer.

    The flatten size ``20 * 24 * 24`` hard-codes the spatial size. A 5x5
    convolution with stride 1 and no padding maps an HxW input to
    (H - 4)x(W - 4), so inputs must be shaped ``(N, 1, 28, 28)``.

    NOTE(review): there is no non-linearity between ``conv1`` and ``fc1``, so
    the network is a composition of linear maps. It is left unchanged here;
    add an activation (e.g. ``torch.relu``) if the model is trained for real.
    """

    def __init__(self):
        super().__init__()
        # 1 input channel -> 20 output channels (filters), 5x5 kernel.
        self.conv1 = torch.nn.Conv2d(1, 20, 5)
        self.fc1 = torch.nn.Linear(20 * 24 * 24, 500)

    def forward(self, x):
        """Map an ``(N, 1, 28, 28)`` batch to ``(N, 500)`` outputs."""
        x = self.conv1(x)
        # flatten(x, 1) preserves the batch dimension. Unlike view(-1, ...),
        # which can silently regroup a mis-sized input into a different
        # number of rows, a wrong spatial size now fails loudly in fc1.
        x = torch.flatten(x, 1)
        return self.fc1(x)

# Seed torch's RNG so the random weight initialization is reproducible on
# Restart & Run All. The magnitude-based pruning below (L2-norm structured,
# L1 global) depends on these initial values.
torch.manual_seed(0)

# Create model instance
model = SimpleNet()

In [9]:
# Baseline before any pruning: percentage of conv1 weight entries that are
# exactly zero.
layer = model.conv1
sparsity = 100.0 * (layer.weight == 0).sum().item() / layer.weight.numel()
print(f"Sparsity in {layer}: {sparsity:.2f}%")

Sparsity in Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1)): 0.00%


In [11]:
# Structured pruning of conv1: remove 30% of its filters, i.e. whole slices
# along dim 0 (output channels), choosing the ones with the lowest L2 norm (n=2).
# https://docs.pytorch.org/docs/stable/generated/torch.nn.utils.prune.ln_structured.html
# NOTE: not idempotent. Re-running this cell prunes additional (still unpruned)
# filters on top of the existing mask, so use Restart & Run All.
layer = prune.ln_structured(model.conv1, name='weight', amount=0.3, n=2, dim=0)

# 6 of the 20 filters are zeroed, so 30% of conv1's weight entries become zero.
sparsity = 100.0 * (layer.weight == 0).sum().item() / layer.weight.numel()
print(f"Sparsity in {layer}: {sparsity:.2f}%")

Sparsity in Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1)): 30.00%


In [13]:
import torch.nn.utils.prune as prune

# Collect every module that carries a weight tensor. A bare hasattr() check
# would also match modules whose `weight` is None (e.g. norm layers built
# without affine parameters), which would break the pruning call below.
# Check for torch.Tensor rather than torch.nn.Parameter: conv1 was already
# pruned above, so its `weight` is no longer a registered Parameter.
parameters_to_prune = [
    (module, 'weight') for module in model.modules()
    if isinstance(getattr(module, 'weight', None), torch.Tensor)
]

# Global unstructured pruning: pool all listed weights, rank them together by
# L1 magnitude and zero the smallest 30% of the pool. The threshold is global,
# so per-layer sparsity need not equal 30%. conv1's figure also includes the
# filters zeroed by the structured pruning cell above.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.3
)

# Print the sparsity (percentage of exactly-zero entries) of each pruned parameter.
for layer, param_name in parameters_to_prune:
    tensor = getattr(layer, param_name)
    sparsity = 100.0 * float(torch.sum(tensor == 0)) / float(tensor.nelement())
    print(f"Sparsity in {layer.__class__.__name__}: {sparsity:.2f}%")


Sparsity in Conv2d: 30.40%
Sparsity in Linear: 30.00%


In [14]:
# Fine-tuning setup: a cross-entropy loss plus an Adam optimizer over the
# parameters `model` exposes at this point, i.e. after the pruning cells above.
# `fine_tune` below uses these two objects.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

def fine_tune(model, train_loader, epochs=5, optimizer=None, criterion=None):
    """Train ``model`` on ``train_loader`` for ``epochs`` full passes.

    Args:
        model: module to train; it is put into training mode first.
        train_loader: iterable yielding ``(data, target)`` batches.
        epochs: number of passes over ``train_loader``.
        optimizer: optimizer that steps ``model``'s parameters. Defaults to
            the notebook-level ``optimizer``, which was built from the global
            ``model``'s parameters. Pass your own when training any other
            model, otherwise gradients are computed but never applied to it.
        criterion: loss callable ``criterion(output, target)``. Defaults to
            the notebook-level ``criterion``.

    Returns:
        None. ``model`` is updated in place.
    """
    # Backward-compatible fallback: resolve the notebook globals at call time,
    # exactly as the two-argument form always did.
    if optimizer is None:
        optimizer = globals()["optimizer"]
    if criterion is None:
        criterion = globals()["criterion"]

    model.train()
    for _epoch in range(epochs):
        for data, target in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()