## Automatic pruning

This notebook demonstrates an end-to-end pipeline for MobileNetV2 pruning.

Our pruning process consists of four steps: calibration for pruning, selection of the least important channels, channel pruning, and model fine-tuning.

### Main chapters of this notebook:
1. Set up the environment
1. Prepare dataset and create dataloaders
1. Evaluate pretrained MobileNetV2
1. Calibrate, prune and evaluate the pruned model
1. Fine-tune and evaluate the pruned model

## Set up the environment

First, let's set up the environment and make some common imports.

In [None]:
import os

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# You may need to uncomment and change this variable to match a free GPU index
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
# Common:
import logging
import numpy as np
import torch
from fvcore.nn import FlopCountAnalysis
from pathlib import Path
from torch import nn
from tqdm.auto import tqdm
from typing import List
from tutorial_utils.checkpoints import download_imagenette_mobilenet
from tutorial_utils.dataset import create_imagenette_dataloaders_for_pruning
from tutorial_utils.train import accuracy

# Training:
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import RAdam
from tutorial_utils.phases import tutorial_train_loop
from tutorial_utils.train import WarmupScheduler

# Pruning:
from enot.pruning import KnapsackPruningLabelSelector
from enot.pruning import PruningCalibrator
from enot.pruning import prune_model
from enot.utils.batch_norm import tune_bn_stats

Define the model evaluation function:

In [None]:
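# This function can evaluate both nn.Modules and plain callables (e.g. functions).
def eval_model(model, dataloader):
    # Only nn.Module has .eval(); plain callables are used as-is.
    if isinstance(model, nn.Module):
        model.eval()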

    total = 0
    total_loss = 0.0
    total_correct = 0.0

    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader):
            n = inputs.shape[0]

            pred_labels = model(inputs)
            batch_loss = criterion(pred_labels, labels)
            batch_accuracy = accuracy(pred_labels, labels)

            total += n
            total_loss += batch_loss.item() * n
            total_correct += batch_accuracy.item() * n

    return total_loss / total, total_correct / total
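
`eval_model` multiplies each batch's mean loss and accuracy by the batch size `n` before dividing by `total`. This matters when the last batch is smaller than the others, e.g. when the dataset size is not a multiple of `batch_size`. The sketch below uses made-up per-batch numbers, purely for illustration, to compare that weighted average with a naive mean of batch means.

```python
# Hypothetical per-batch results: (batch_size, mean accuracy within the batch).
# The last batch is smaller, as happens when the dataset size is not
# divisible by the batch size.
batches = [(32, 0.875), (32, 0.75), (8, 0.5)]


def weighted_average(batches):
    """Dataset-level average: each batch mean is weighted by its size (as in eval_model)."""
    total = sum(n for n, _ in batches)
    return sum(n * value for n, value in batches) / total


def naive_average(batches):
    """Unweighted mean of batch means; over-weights the small last batch."""
    return sum(value for _, value in batches) / len(batches)


print(f'weighted: {weighted_average(batches):.4f}')
print(f'naive:    {naive_average(batches):.4f}')
```

Here the weighted version equals 56 correct predictions out of 72 samples, while the naive mean is pulled down by the 8-sample batch.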

### In the following cell we set up all necessary directories

* `HOME_DIR` - experiments home directory
* `DATASETS_DIR` - root directory for datasets (imagenette2, ...)
* `PROJECT_DIR` - project directory to save training logs, checkpoints, ...

In [None]:
HOME_DIR = Path.home() / '.optimization_experiments'
DATASETS_DIR = HOME_DIR / 'datasets'
PROJECT_DIR = HOME_DIR / 'e2e_pruning'

HOME_DIR.mkdir(exist_ok=True)
DATASETS_DIR.mkdir(exist_ok=True)
PROJECT_DIR.mkdir(exist_ok=True)

## Prepare dataset and create dataloaders

In [None]:
train_dataloader, validation_dataloader = create_imagenette_dataloaders_for_pruning(
    dataset_root_dir=DATASETS_DIR,
    project_dir=PROJECT_DIR,
    input_size=224,
    batch_size=32,
    num_workers=4,
)

## Evaluate pretrained MobileNetV2

In [None]:
from torchvision.models.mobilenetv2 import mobilenet_v2

regular_model = mobilenet_v2(pretrained=False, num_classes=10).cuda()

# Turn off the dropout placed before the fully connected layer.
# This is required to stabilize the fine-tuning procedure.
regular_model.classifier[0].p = 0.0

In [None]:
checkpoint_path = PROJECT_DIR / 'e2e_imagenette_pruning.pth'
download_imagenette_mobilenet(checkpoint_path)

regular_model.load_state_dict(
    torch.load(checkpoint_path)['model'],
)

In [None]:
val_loss, val_accuracy = eval_model(regular_model, validation_dataloader)
print(f'Regular (non-pruned) model: accuracy={val_accuracy:.3f}, loss={val_loss:.3f}')

## Calibrate, prune and evaluate the pruned model

First, let's define ``mmac_acceleration_factor`` (the desired acceleration in terms of MACs) and the loss function ``loss_function`` (which calculates the total loss for a single batch from the data loader).

In [None]:
mmac_acceleration_factor = 3  # This gives a 3x MAC reduction.
loss_function = torch.nn.CrossEntropyLoss(reduction='sum')
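
With `reduction='sum'`, `loss_function` returns the sum of the per-sample losses over the batch rather than their mean, which is what "total loss for a single batch" means here. The minimal pure-Python sketch below re-implements cross-entropy on raw logits (the quantity `nn.CrossEntropyLoss` computes, written without PyTorch for illustration) and shows that the two reductions differ by a factor equal to the batch size. The logits and targets are arbitrary toy values.

```python
import math


def cross_entropy(logits, target):
    """Per-sample cross-entropy: -log(softmax(logits)[target]), computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def batch_loss(batch_logits, targets, reduction='mean'):
    """Reduce per-sample losses with 'sum' or 'mean', like nn.CrossEntropyLoss."""
    losses = [cross_entropy(l, t) for l, t in zip(batch_logits, targets)]
    if reduction == 'sum':
        return sum(losses)
    return sum(losses) / len(losses)


# Toy batch: 3 samples, 4 classes (values chosen arbitrarily).
logits = [[2.0, 0.5, 0.1, -1.0], [0.0, 0.0, 0.0, 0.0], [1.0, 3.0, 0.2, 0.0]]
targets = [0, 2, 1]

total = batch_loss(logits, targets, reduction='sum')
mean = batch_loss(logits, targets, reduction='mean')
print(f'sum={total:.4f}, mean={mean:.4f}')
```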

Next, define a function that measures the latency of an arbitrary model in terms of MACs, reported in millions (MMACs).

In [None]:
def mmac_calculation_function(model):
    inputs, _ = next(iter(train_dataloader))

    fca = FlopCountAnalysis(
        model=model.eval(),
        inputs=inputs,
    )
    fca.unsupported_ops_warnings(False)
    fca.uncalled_modules_warnings(False)

    return fca.total() / 1e6
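
fvcore's `FlopCountAnalysis` counts one fused multiply-add as one flop, so `fca.total() / 1e6` is the model cost in millions of MACs. For a convolution the count follows a simple formula. The helper below is a hand-written illustration of that formula (not part of fvcore or ENOT), applied to two early MobileNetV2 layers at a 224×224 input: the 3→32 stem convolution with stride 2, and the 32-channel depthwise convolution of the first inverted-residual block.

```python
def conv2d_macs(in_channels, out_channels, kernel_size, out_h, out_w, groups=1):
    """Multiply-accumulate count of a Conv2d layer (bias ignored).

    Each output element needs (in_channels / groups) * kernel_size**2 MACs.
    """
    assert in_channels % groups == 0 and out_channels % groups == 0
    macs_per_output = (in_channels // groups) * kernel_size * kernel_size
    return out_channels * out_h * out_w * macs_per_output


# MobileNetV2 stem: 3 -> 32 channels, 3x3 kernel, stride 2, so 224x224 -> 112x112.
stem = conv2d_macs(3, 32, 3, 112, 112)
print(f'stem: {stem / 1e6:.2f} MMACs')

# Depthwise 3x3 conv of the first inverted-residual block (32 channels, groups=32).
depthwise = conv2d_macs(32, 32, 3, 112, 112, groups=32)
print(f'depthwise: {depthwise / 1e6:.2f} MMACs')

# Cost scales linearly with the number of output channels: keeping 16 of the
# 32 stem filters halves its MACs (and also shrinks the next layer's input).
pruned_stem = conv2d_macs(3, 16, 3, 112, 112)
print(f'pruned stem: {pruned_stem / 1e6:.2f} MMACs')
```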

Now, let's calibrate the model for pruning. Calibration finds all prunable channels in the network and estimates their importance. The accumulated pruning-related information is stored in the ``pruning_info`` variable.

In [None]:
# This code is implemented in enot.pruning.calibrate_model_for_pruning.

regular_model.eval()
pruning_calibrator = PruningCalibrator(model=regular_model)
with pruning_calibrator:
    for images, labels in train_dataloader:
        predictions = regular_model(images)
        loss = loss_function(predictions, labels)
        loss.backward()

pruning_info = pruning_calibrator.pruning_info
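
Because the calibration loop calls `loss.backward()`, gradients are available for scoring channels. The exact criterion `PruningCalibrator` uses is not described in this notebook; as an assumption, the sketch below uses a common first-order (Taylor) criterion that scores each output channel by the sum of |weight × gradient| over its parameters. The function name and all numbers are hypothetical.

```python
def taylor_channel_importance(weights, grads):
    """First-order (Taylor) importance score per output channel.

    weights[c] and grads[c] hold the flattened parameters of output channel c
    and their accumulated gradients. A larger score suggests the loss depends
    more strongly on that channel.
    """
    return [
        sum(abs(w * g) for w, g in zip(channel_w, channel_g))
        for channel_w, channel_g in zip(weights, grads)
    ]


# Hypothetical layer with 4 output channels, 3 parameters each.
weights = [[0.5, -0.2, 0.1], [0.01, 0.02, -0.01], [1.0, 0.3, -0.4], [0.2, 0.2, 0.2]]
grads = [[0.1, 0.3, -0.2], [0.5, -0.4, 0.3], [0.05, -0.1, 0.2], [0.0, 0.1, -0.1]]

scores = taylor_channel_importance(weights, grads)
ranking = sorted(range(len(scores)), key=lambda c: scores[c])
print('scores:', [round(s, 4) for s in scores])
print('least important first:', ranking)
```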

Select the least important channels so that the pruned model meets the target latency, i.e. the desired acceleration in terms of MACs.

In [None]:
baseline_latency = mmac_calculation_function(model=regular_model)
target_latency = baseline_latency / mmac_acceleration_factor

print('baseline mmacs =', baseline_latency)
print('target mmacs =', target_latency)

label_selector = KnapsackPruningLabelSelector(
    latency_calculation_function=mmac_calculation_function,
    target_latency=target_latency,
    verbose=True,
)
prune_labels = label_selector.select(model=regular_model, pruning_info=pruning_info)
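
The selector's name suggests a 0/1 knapsack formulation: keep the subset of channels with the largest total importance whose total cost fits the latency budget, and prune the rest. The sketch below is a textbook dynamic-programming knapsack over integer costs, shown only to illustrate that idea; it is not ENOT's implementation, and the importance scores and per-channel costs are hypothetical.

```python
def knapsack_keep(importances, costs, budget):
    """Indices of items maximizing total importance subject to sum(costs) <= budget.

    Classic 0/1 knapsack DP over integer costs, O(len(items) * budget).
    """
    n = len(importances)
    # best[i][b]: max total importance using the first i items within budget b.
    best = [[0.0] * (budget + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        value, cost = importances[i - 1], costs[i - 1]
        for b in range(budget + 1):
            best[i][b] = best[i - 1][b]
            if cost <= b and best[i - 1][b - cost] + value > best[i][b]:
                best[i][b] = best[i - 1][b - cost] + value
    # Walk back through the table to recover which items were kept.
    keep, b = [], budget
    for i in range(n, 0, -1):
        if best[i][b] != best[i - 1][b]:
            keep.append(i - 1)
            b -= costs[i - 1]
    return sorted(keep)


# Hypothetical channels: importance scores and MAC cost in arbitrary integer units.
importances = [0.13, 0.016, 0.16, 0.04, 0.09]
costs = [4, 2, 5, 1, 3]
baseline = sum(costs)      # 15 units
budget = baseline // 3     # mimic a 3x reduction target: 5 units
kept = knapsack_keep(importances, costs, budget)
pruned = [c for c in range(len(costs)) if c not in kept]
print('keep:', kept, 'prune:', pruned)
```

Channel 2 has the highest individual score, yet it is pruned: its cost alone would use the whole budget, while channels 0 and 3 together give more total importance for the same cost.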

Create a copy of the original model and remove the selected channels.

In [None]:
pruned_model = prune_model(
    model=regular_model,
    pruning_info=pruning_info,
    prune_labels=prune_labels,
    inplace=False,
)
pruned_model_latency = mmac_calculation_function(model=pruned_model)
print('pruned model mmacs =', pruned_model_latency)
pruned_model.cuda()
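
Conceptually, removing an output channel from one layer means deleting the corresponding row of its weight matrix and the matching input column of every layer that consumes that output. The sketch below shows this for two fully connected layers stored as nested lists with shape `[out_features][in_features]`. The helper name and values are hypothetical, and a real implementation must also keep biases, batch norm parameters and layers tied by residual connections consistent, which this toy ignores.

```python
def prune_linear_pair(w1, w2, keep):
    """Keep only output channels `keep` of layer 1 and the matching inputs of layer 2.

    w1: weights of layer 1, shape [out1][in1]
    w2: weights of layer 2, shape [out2][out1] (it consumes layer 1's output)
    """
    new_w1 = [w1[c] for c in keep]                   # drop rows of layer 1
    new_w2 = [[row[c] for c in keep] for row in w2]  # drop matching columns of layer 2
    return new_w1, new_w2


def matvec(w, x):
    """Plain matrix-vector product for nested lists."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


w1 = [[1.0, 2.0], [0.0, 0.0], [3.0, -1.0]]  # 3 hidden channels; channel 1 is all zeros
w2 = [[1.0, 5.0, 2.0]]                      # 1 output reading all 3 hidden channels
x = [0.5, 1.5]

pruned_w1, pruned_w2 = prune_linear_pair(w1, w2, keep=[0, 2])
print(matvec(w2, matvec(w1, x)), matvec(pruned_w2, matvec(pruned_w1, x)))
```

Because hidden channel 1 has all-zero weights, removing it leaves the output unchanged while both weight matrices shrink.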

Tune the batch normalization layers on training data to re-estimate their running statistics (mean and variance) after pruning.

In [None]:
tune_bn_stats(
    model=pruned_model,
    dataloader=train_dataloader,
    reset_bns=True,  # Drop old batch norm running statistics.
    set_momentums_none=True,  # Accumulate average statistics.
    n_steps=None,
    epochs=1,
)
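
In PyTorch batch norm layers, `momentum=None` switches the running statistics from an exponential moving average to a cumulative (simple) moving average over all batches seen. `set_momentums_none=True` presumably relies on this, so that after `reset_bns=True` the running mean becomes the plain average of the per-batch means over the epoch. The sketch below contrasts the two update rules on hypothetical per-batch means of a single channel.

```python
def ema_update(running, value, momentum=0.1):
    """PyTorch-style exponential moving average (default BatchNorm momentum=0.1)."""
    return (1.0 - momentum) * running + momentum * value


def cumulative_update(running, value, num_batches):
    """Cumulative moving average after `num_batches` batches (momentum=None)."""
    return running + (value - running) / num_batches


batch_means = [0.8, 1.2, 1.0, 0.6, 1.4]  # hypothetical per-batch channel means

ema, cma = 0.0, 0.0  # running_mean starts at zero after a reset
for step, value in enumerate(batch_means, start=1):
    ema = ema_update(ema, value)
    cma = cumulative_update(cma, value, step)

true_mean = sum(batch_means) / len(batch_means)
print(f'EMA: {ema:.4f}  cumulative: {cma:.4f}  true mean: {true_mean:.4f}')
```

With only a few batches, the exponential average started from zero stays far from the true mean, while the cumulative average matches it.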

After pruning, the pruned model has the same structure as the original model, except that some convolution, fully connected and batch norm layers now have fewer channels.

In [None]:
val_loss, val_accuracy = eval_model(pruned_model, validation_dataloader)
print(f'Pruned model: accuracy={val_accuracy:.3f}, loss={val_loss:.3f}')

## Fine-tune and evaluate the pruned model

In [None]:
N_EPOCHS = 5
N_WARMUP_EPOCHS = 1
learning_rate = 0.05

# Uncomment the lines below if you want to reach the best pruned model
# performance (~93% accuracy for the pruned model).

# N_EPOCHS = 50  # Increase the number of model fine-tuning epochs.
# N_WARMUP_EPOCHS = 10  # Increase the number of warmup epochs.
# learning_rate = 0.01  # Decrease the learning rate.

len_train = len(train_dataloader)

optimizer = RAdam(pruned_model.parameters(), lr=learning_rate, weight_decay=0.00004)
scheduler = CosineAnnealingLR(optimizer, T_max=len_train * N_EPOCHS)
scheduler = WarmupScheduler(scheduler, warmup_steps=len_train * N_WARMUP_EPOCHS)
loss_function = nn.CrossEntropyLoss().cuda()

tutorial_train_loop(
    epochs=N_EPOCHS,
    model=pruned_model,
    optimizer=optimizer,
    metric_function=accuracy,
    loss_function=loss_function,
    train_loader=train_dataloader,
    validation_loader=validation_dataloader,
    scheduler=scheduler,
)
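
The learning rate schedule combines a warmup phase with cosine annealing. How `WarmupScheduler` from `tutorial_utils` combines the two is not shown in this notebook. Assuming a linear warmup over the first `warmup_steps` steps followed by cosine annealing over the remaining steps, the per-step learning rate can be sketched with the closed-form cosine formula documented for `CosineAnnealingLR` (with its default `eta_min=0`). The function name and `len_train=100` are hypothetical.

```python
import math


def warmup_cosine_lr(step, base_lr, warmup_steps, total_steps, eta_min=0.0):
    """Learning rate at optimizer step `step` (0-based).

    Assumption: linear warmup up to base_lr, then cosine annealing from
    base_lr down to eta_min over the remaining steps.
    """
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    t = step - warmup_steps
    t_max = max(1, total_steps - warmup_steps)
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2


# Notebook defaults (5 epochs, 1 warmup epoch, lr=0.05) with a hypothetical len_train.
len_train, n_epochs, n_warmup_epochs, base_lr = 100, 5, 1, 0.05
total = len_train * n_epochs
warmup = len_train * n_warmup_epochs
for step in (0, warmup - 1, warmup, (warmup + total) // 2, total - 1):
    print(step, round(warmup_cosine_lr(step, base_lr, warmup, total), 6))
```

The rate ramps up to 0.05 by the end of the first epoch, is halved at the midpoint of the annealing phase, and approaches zero at the last step.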

In [None]:
val_loss, val_accuracy = eval_model(pruned_model, validation_dataloader)
print(f'Fine-tuned pruned model: accuracy={val_accuracy:.3f}, loss={val_loss:.3f}')