# Toy pruning example

> NOTE: The example below is for educational purposes only. One can easily compress the model we define below by just reducing the number of channels manually. The purpose is to show how to integrate channel pruning into a `torch` training loop.

In this toy example we are going to show how to run channel pruning on a really simple model on `MNIST` dataset.
Even though, its just `MNIST` the same workflow and principles apply to running channel pruning with `torch-dag` on other models. The outline of the notebook is as follows:
1. Download the data.
2. Build a `torch.nn.Module` model.
3. Train it and compute accuracy on the test set.
4. Convert the model to `torch-dag` `DagModule` format.
5. Prepare the converted model for pruning.
6. Run trainnig with pruning.
7. Remove channels from the model.
8. Report accuracy after pruning.

In [1]:
import torch.nn as nn
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
import logging

logging.basicConfig(level=logging.INFO)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Download and load the MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)

In [3]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CPU_DEVICE = torch.device('cpu')

## 2. Build the model

In [4]:
# Define the convolutional model
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 32)
        self.fc3 = nn.Linear(32, 10)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.activation(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2)
        x = self.activation(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = self.activation(self.conv3(x))
        x = torch.flatten(x, start_dim=1)
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        x = self.fc3(x)
        return nn.functional.log_softmax(x, dim=1)

## 3. Train the original model

In [5]:
# Initialize the model and optimizer
model = ConvNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [6]:
# Train the model
num_epochs = 2
batch_size = 100
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

model.to(DEVICE)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = nn.functional.nll_loss(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, len(train_loader),
                                                                     loss.item()))

Epoch [1/2], Step [100/600], Loss: 0.2511
Epoch [1/2], Step [200/600], Loss: 0.2141
Epoch [1/2], Step [300/600], Loss: 0.1352
Epoch [1/2], Step [400/600], Loss: 0.0589
Epoch [1/2], Step [500/600], Loss: 0.0338
Epoch [1/2], Step [600/600], Loss: 0.0627
Epoch [2/2], Step [100/600], Loss: 0.0257
Epoch [2/2], Step [200/600], Loss: 0.0297
Epoch [2/2], Step [300/600], Loss: 0.0430
Epoch [2/2], Step [400/600], Loss: 0.0793
Epoch [2/2], Step [500/600], Loss: 0.0444
Epoch [2/2], Step [600/600], Loss: 0.0675


In [7]:
# Test the model
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

Accuracy of the model on the 10000 test images: 98.77 %


## 4. Convert to `DagModule`.

In [8]:
import torch_dag as td
import torch_dag_algorithms as tda
model.to(CPU_DEVICE)
dag = td.build_from_unstructured_module(model)
td.compare_module_outputs(first_module=model, second_module=dag, input_shape=(8, 1, 28, 28)) # sanity check for conversion

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


## 5. Prepare the converted model for pruning

In [9]:
INPUT_SHAPE = (1, 1, 28, 28)
PRUNING_PROPORTION = 0.5  # target model size relative to the original model
NUM_PRUNING_STEPS = 5000
batch_size = 100
initial_normalized_flops = tda.pruning.compute_normalized_flops(dag, input_shape_without_batch=INPUT_SHAPE[1:])

pruning_config = tda.pruning.ChannelPruning(
    model=dag,
    input_shape_without_batch=INPUT_SHAPE[1:],
    pruning_proportion=PRUNING_PROPORTION,
    num_training_steps=NUM_PRUNING_STEPS,
    anneal_losses=False,
)

pruning_model = pruning_config.prepare_for_pruning()
print(f'Prunable proportion: {pruning_config.prunable_proportion}')

INFO:torch_dag_algorithms.pruning.filters:[[1m[96mNonPrunableCustomModulesFilter[0m] Removing orbit [1m[95mOrbit[0m[[1m[93mcolor[0m=3, [1m[93mdiscovery_stage[0m=OrbitsDiscoveryStage.EXTENDED_ORBIT_DISCOVERY, [1m[93msources[0m=[conv3], [1m[93msinks[0m=[fc1], [1m[93mnon_border[0m={flatten, activation_2}, [1m[93mend_path[0m=[(flatten, fc1)]]
INFO:torch_dag_algorithms.pruning.filters:[[1m[96mOutputInScopeFilter[0m] Removing orbit [1m[95mOrbit[0m[[1m[93mcolor[0m=6, [1m[93mdiscovery_stage[0m=OrbitsDiscoveryStage.EXTENDED_ORBIT_DISCOVERY, [1m[93msources[0m=[fc3], [1m[93msinks[0m=[], [1m[93mnon_border[0m={log_softmax}, [1m[93mend_path[0m=[]]
INFO:torch_dag_algorithms.pruning.dag_orbitalizer:[+] Total normalized flops: 21.232244897959177
INFO:torch_dag_algorithms.pruning.dag_orbitalizer:[+] Prunable normalized flops: 21.21624489795918
INFO:torch_dag_algorithms.pruning.dag_orbitalizer:[+] Unprunable normalized flops: 0.015999999999998238
INFO:torch

Prunable proportion: 0.9992464291893347


## 6. Run trainnig with pruning

In [10]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

optimizer = torch.optim.Adam(pruning_model.parameters(), lr=0.001)

global_step = 0
batches_per_epoch = len(train_loader)
num_epochs = NUM_PRUNING_STEPS // batches_per_epoch

_ = pruning_model.to(DEVICE)

In [11]:
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = pruning_model(images)

        proportion, flops_loss_value, entropy_loss_value, bkd_loss_value = \
            pruning_config.compute_current_proportion_and_pruning_losses(global_step=global_step)
        task_loss = nn.functional.nll_loss(outputs, labels)
        loss = task_loss + flops_loss_value + entropy_loss_value + bkd_loss_value
        loss.backward()
        optimizer.step()
        global_step += 1

        if global_step % 100 == 0:
            print(f'global step: {global_step}/{NUM_PRUNING_STEPS}, proportion: {proportion}, task loss: {task_loss}, entropy loss: {entropy_loss_value}')

global step: 100/5000, proportion: 0.9245241284370422, task loss: 0.15911529958248138, entropy loss: 0.20658813416957855
global step: 200/5000, proportion: 0.8771666884422302, task loss: 0.10889792442321777, entropy loss: 0.2939905524253845
global step: 300/5000, proportion: 0.6751950979232788, task loss: 0.0686008483171463, entropy loss: 0.38392454385757446
global step: 400/5000, proportion: 0.4627421498298645, task loss: 0.033269405364990234, entropy loss: 0.39139097929000854
global step: 500/5000, proportion: 0.4674898386001587, task loss: 0.15927720069885254, entropy loss: 0.38195693492889404
global step: 600/5000, proportion: 0.47195300459861755, task loss: 0.04428831487894058, entropy loss: 0.3704186677932739
global step: 700/5000, proportion: 0.4754504859447479, task loss: 0.08102502673864365, entropy loss: 0.35771644115448
global step: 800/5000, proportion: 0.4793959856033325, task loss: 0.11654949188232422, entropy loss: 0.34719157218933105
global step: 900/5000, proportion: 0

## 7. Remove channels from the model

In [12]:
pruned_model = pruning_config.remove_channels()

INFO:torch_dag_algorithms.pruning.channel_removal_primitives:Pruning conv conv1: leaving fraction: 1.0 of out channels.
INFO:torch_dag_algorithms.pruning.channel_removal_primitives:Pruning conv conv2: leaving fraction: 0.46875 of out channels.
INFO:torch_dag_algorithms.pruning.channel_removal_primitives:Pruning conv conv3: leaving fraction: 1.0 of out channels.
  return [torch.tensor(scores_)]
INFO:torch_dag_algorithms.pruning.channel_removal_primitives:Pruning conv fc1: leaving fraction: 0.1171875 of out channels.
INFO:torch_dag_algorithms.pruning.channel_removal_primitives:Pruning conv fc2: leaving fraction: 1.0 of out channels.
INFO:torch_dag_algorithms.pruning.channel_removal_primitives:Pruning conv fc3: leaving fraction: 1.0 of out channels.
INFO:torch_dag_algorithms.pruning.mask_propagation:No explicit mask propagation for vertex: log_softmax, of type <class 'torch.nn.modules.activation.LogSoftmax'>. Returning `None` masks.


In [13]:
final_normalized_flops = tda.pruning.compute_normalized_flops(pruned_model, input_shape_without_batch=INPUT_SHAPE[1:])
final_proportion = final_normalized_flops / initial_normalized_flops

## 8. Report accuracy and model size after pruning

In [14]:
pruned_model.to(DEVICE)

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        outputs = pruned_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the pruned model on the 10000 test images: {} %'.format(100 * correct / total))

Accuracy of the pruned model on the 10000 test images: 99.13 %


In [15]:
print(f'Initial normalized flops: {initial_normalized_flops}, final normalized flops: {final_normalized_flops}')
print(f'Final proportion: {final_proportion}')

Initial normalized flops: 21.232244897959184, final normalized flops: 9.596017857142858
Final proportion: 0.45195493473538517
