## Automatic pruning

This notebook demonstrates an end-to-end pipeline for MobileNetV2 pruning.

This notebook assumes that you want to integrate pruning into your training script and modify it for fine-tuning. 

Our pruning process consists of four steps: calibration for pruning, selection of the least important channels, channel pruning, and model fine-tuning.
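The selection step can be pictured as ranking the channels of each layer by an importance score and keeping the top ones. The following is a minimal sketch under stated assumptions. The helper `select_channels_to_keep` and the scores are hypothetical. Rounding the number of removed channels down is also an assumption. This notebook does not show how ENOT computes channel importance during calibration.

```python
import math


def select_channels_to_keep(scores, pruning_ratio):
    """Return sorted indices of channels to keep (hypothetical helper).

    scores: per-channel importance values (higher = more important).
    pruning_ratio: fraction of channels to remove, in [0, 1).
    """
    if not 0.0 <= pruning_ratio < 1.0:
        raise ValueError('pruning_ratio must be in [0, 1)')
    n_channels = len(scores)
    # Remove floor(n * ratio) channels, but always keep at least one.
    n_keep = max(1, n_channels - math.floor(n_channels * pruning_ratio))
    # Rank channels by importance, most important first.
    ranked = sorted(range(n_channels), key=lambda i: scores[i], reverse=True)
    # Return kept indices in their original order to preserve channel layout.
    return sorted(ranked[:n_keep])


scores = [0.9, 0.1, 0.5, 0.05, 0.7, 0.3]
print(select_channels_to_keep(scores, pruning_ratio=0.433))
```

Here, 6 channels with a ratio of 0.433 remove 2 channels (indices 1 and 3, the two lowest scores). Applying the same ratio to every layer is what "equal" pruning suggests. The optimal tool instead searches for per-layer ratios.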

### Main chapters of this notebook:
1. Setup the environment
1. Prepare dataset and create dataloaders
1. Evaluate pretrained MobileNetV2
1. Calibrate, prune and evaluate pruned model (equal pruning)
1. Finetune and evaluate pruned model
1. Usage of optimal pruning tool
1. Finetune and evaluate optimal pruned model

Short summary of this notebook:

```python
def loss_function(model_output, sample):
    labels = sample[1].cuda()
    return F.cross_entropy(model_output, labels, reduction='mean')

def sample_to_n_samples(sample):
    return sample[0].shape[0]

def sample_to_model_inputs(sample):
    images = sample[0].cuda()
    return (images,), {}

def latency_calculation_function(model, dataloader):
    # Use one dataloader batch as example network input
    inputs, _ = sample_to_model_inputs(next(iter(dataloader)))
    ...
    return model_latency


lcf = partial(latency_calculation_function, dataloader=train_dataloader)

# Target a x3 latency reduction relative to the baseline model
optimal_pruned_model = calibrate_and_prune_model_optimal(
    model=baseline_model,
    dataloader=train_dataloader,
    loss_function=loss_function,
    latency_calculation_function=lcf,
    target_latency=lcf(baseline_model) / 3,
    finetune_bn=True,
    calibration_steps=None,
    calibration_epochs=1,
    n_search_steps=200,
    sample_to_n_samples=sample_to_n_samples,
    sample_to_model_inputs=sample_to_model_inputs,
    show_tqdm=True,
)
```

## Setup the environment

First, let's set up the environment and make some common imports.

In [None]:
import os

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# You may need to uncomment and change this variable to match the index of a free GPU
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
# Common:
from functools import partial
import torch
import torch.nn.functional as F
from pathlib import Path
from torch import nn
from tqdm.auto import tqdm
from tutorial_utils.checkpoints import download_imagenette_mobilenet
from tutorial_utils.dataset import create_imagenette_dataloaders_for_pruning
from tutorial_utils.train import accuracy

# Training:
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch_optimizer import RAdam
from tutorial_utils.phases import tutorial_train_loop
from tutorial_utils.train import WarmupScheduler

# Pruning:
from enot.pruning import calibrate_and_prune_model_equal
from enot.pruning import calibrate_and_prune_model_optimal

# Latency:
from fvcore.nn.flop_count import FlopCountAnalysis

Define the model evaluation function:

In [None]:
# This function can evaluate both nn.Modules and plain callables.
def eval_model(model, dataloader):

    # Only nn.Modules have an eval mode; plain callables are used as-is.
    if isinstance(model, nn.Module):
        model.eval()

    total = 0
    total_loss = 0.0
    total_correct = 0.0

    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
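        for inputs, labels in tqdm(dataloader):
            # Move the batch to the GPU (a no-op if it is already there).
            inputs, labels = inputs.cuda(), labels.cuda()

            n = inputs.shape[0]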

            pred_labels = model(inputs)
            batch_loss = criterion(pred_labels, labels)
            batch_accuracy = accuracy(pred_labels, labels)

            total += n
            total_loss += batch_loss.item() * n
            total_correct += batch_accuracy.item() * n

    return total_loss / total, total_correct / total

### In the following cell, we set up all necessary directories

* `HOME_DIR` - experiments home directory
* `DATASETS_DIR` - root directory for datasets (imagenette2, ...)
* `PROJECT_DIR` - project directory to save training logs, checkpoints, ...

In [None]:
HOME_DIR = Path.home() / '.optimization_experiments'
DATASETS_DIR = HOME_DIR / 'datasets'
PROJECT_DIR = HOME_DIR / 'e2e_pruning'

HOME_DIR.mkdir(exist_ok=True)
DATASETS_DIR.mkdir(exist_ok=True)
PROJECT_DIR.mkdir(exist_ok=True)

## Prepare dataset and create dataloaders

In [None]:
train_dataloader, validation_dataloader = create_imagenette_dataloaders_for_pruning(
    dataset_root_dir=DATASETS_DIR,
    project_dir=PROJECT_DIR,
    input_size=224,
    batch_size=32,
    num_workers=4,
)

## Evaluate pretrained MobileNetV2

In [None]:
from torchvision.models.mobilenetv2 import mobilenet_v2

regular_model = mobilenet_v2(pretrained=False, num_classes=10).cuda()

# Turn off the dropout in front of the fully-connected classifier layer.
# This is required to stabilize the fine-tuning procedure.
regular_model.classifier[0].p = 0.0

In [None]:
checkpoint_path = PROJECT_DIR / 'e2e_imagenette_pruning.pth'
download_imagenette_mobilenet(checkpoint_path)

regular_model.load_state_dict(
    torch.load(checkpoint_path)['model'],
)

In [None]:
val_loss, val_accuracy = eval_model(regular_model, validation_dataloader)
print(f'Regular (non-pruned) model: accuracy={val_accuracy:.3f}, loss={val_loss:.3f}')

## Calibrate, prune and evaluate pruned model

First, let's define:

* ``pruning_ratio`` - the fraction of channels removed from the network;
* ``loss_function`` - calculates the mean loss for a single data loader batch;
* ``sample_to_n_samples`` - returns the number of samples (images) in one data loader sample;
* ``sample_to_model_inputs`` - transforms a data loader sample into model inputs, returned as a tuple of positional arguments and a dict of keyword arguments.
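You can sanity-check the comment next to ``pruning_ratio = 0.433`` ("about x3 FLOPs reduction") with back-of-the-envelope arithmetic. This rests on one assumption: a regular convolution has both its input and output channels pruned by the same ratio. Under that assumption, its MACs scale with the square of the kept fraction. The estimate ignores three things. Depthwise convolutions scale linearly, the first layer's RGB input is not pruned, and the classifier is left out. The layer sizes below are hypothetical. For all these reasons, the real reduction for MobileNetV2 will differ somewhat.

```python
def conv_macs(c_in, c_out, kernel_size, h_out, w_out, groups=1):
    """MACs of a 2D convolution: one multiply-add per weight per output pixel."""
    return (c_in // groups) * c_out * kernel_size * kernel_size * h_out * w_out


pruning_ratio = 0.433
keep = 1.0 - pruning_ratio

# A hypothetical 3x3 convolution with 96 -> 96 channels on a 56x56 feature map.
full = conv_macs(96, 96, 3, 56, 56)
c_kept = round(96 * keep)  # channels kept on both the input and output side
pruned = conv_macs(c_kept, c_kept, 3, 56, 56)

print(f'kept fraction squared: {keep ** 2:.3f}')
print(f'reduction for this layer: {full / pruned:.2f}x')
```

The squared kept fraction is about 0.321, which corresponds to roughly a x3.1 reduction for such a layer.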

In [None]:
pruning_ratio = 0.433  # This gives about x3 FLOPs reduction.


def loss_function(model_output, sample):
    labels = sample[1].cuda()
    return F.cross_entropy(model_output, labels, reduction='mean')


def sample_to_n_samples(sample):
    return sample[0].shape[0]


def sample_to_model_inputs(sample):
    images = sample[0].cuda()
    return (images,), {}

Perform model pruning by calling ``calibrate_and_prune_model_equal``.
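The pruning functions receive ``loss_function``, ``sample_to_n_samples`` and ``sample_to_model_inputs`` as callbacks. This notebook does not show ENOT's internal loop. The following is only a hypothetical sketch of how those three callbacks fit together to compute a sample-weighted mean loss over a data loader. It uses plain-Python stand-ins instead of tensors, and the function name `mean_loss_over_dataloader` is made up for illustration.

```python
def mean_loss_over_dataloader(model, dataloader, loss_function,
                              sample_to_model_inputs, sample_to_n_samples):
    """Hypothetical loop showing how the three helper callbacks fit together."""
    total_loss, total_samples = 0.0, 0
    for sample in dataloader:
        args, kwargs = sample_to_model_inputs(sample)
        output = model(*args, **kwargs)
        n = sample_to_n_samples(sample)
        # loss_function returns a batch mean, so weight it by the batch size.
        total_loss += float(loss_function(output, sample)) * n
        total_samples += n
    return total_loss / total_samples


# Plain-Python stand-ins: a "sample" is (inputs, targets) as lists of numbers.
toy_loader = [([1.0, 2.0], [1.0, 2.5]), ([3.0], [2.0])]


def toy_model(xs):
    return list(xs)  # identity "model"


def toy_loss(output, sample):
    return sum((o - t) ** 2 for o, t in zip(output, sample[1])) / len(output)


result = mean_loss_over_dataloader(
    toy_model, toy_loader, toy_loss,
    sample_to_model_inputs=lambda s: ((s[0],), {}),
    sample_to_n_samples=lambda s: len(s[0]),
)
print(result)
```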

In [None]:
pruned_model = calibrate_and_prune_model_equal(
    model=regular_model,
    dataloader=train_dataloader,
    loss_function=loss_function,
    pruning_ratio=pruning_ratio,
    finetune_bn=True,
    calibration_steps=None,  # When None - uses epochs argument to set the number of steps.
    calibration_epochs=1,
    sample_to_n_samples=sample_to_n_samples,
    sample_to_model_inputs=sample_to_model_inputs,
    show_tqdm=True,
)
pruned_model.cuda()

After pruning, the pruned model has the same structure as the original model, except that some convolution, fully-connected and batch norm layers now have fewer channels.
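To make this concrete, here is a minimal hand-written sketch of what removing channels means structurally for a `Conv2d -> BatchNorm2d -> Conv2d` chain. Three things are sliced with the same channel indices: the output channels of the first convolution, the matching batch norm entries, and the input channels of the next convolution. The helper `prune_conv_bn_conv` is hypothetical. It assumes `groups=1` and an affine batch norm. It is not how ENOT implements pruning internally.

```python
import torch
from torch import nn


def prune_conv_bn_conv(conv1, bn, conv2, keep_idx):
    """Keep only channels `keep_idx` between conv1 and conv2 (hypothetical helper).

    Assumes groups=1 for both convolutions and an affine BatchNorm2d.
    """
    keep_idx = torch.as_tensor(keep_idx, dtype=torch.long)
    n = len(keep_idx)

    new_conv1 = nn.Conv2d(conv1.in_channels, n, conv1.kernel_size,
                          stride=conv1.stride, padding=conv1.padding,
                          dilation=conv1.dilation, bias=conv1.bias is not None)
    new_bn = nn.BatchNorm2d(n, eps=bn.eps, momentum=bn.momentum)
    new_conv2 = nn.Conv2d(n, conv2.out_channels, conv2.kernel_size,
                          stride=conv2.stride, padding=conv2.padding,
                          dilation=conv2.dilation, bias=conv2.bias is not None)

    with torch.no_grad():
        # Output channels of conv1 live in dim 0 of its weight.
        new_conv1.weight.copy_(conv1.weight[keep_idx])
        if conv1.bias is not None:
            new_conv1.bias.copy_(conv1.bias[keep_idx])
        # BatchNorm parameters and running statistics are per channel.
        new_bn.weight.copy_(bn.weight[keep_idx])
        new_bn.bias.copy_(bn.bias[keep_idx])
        new_bn.running_mean.copy_(bn.running_mean[keep_idx])
        new_bn.running_var.copy_(bn.running_var[keep_idx])
        # Input channels of conv2 live in dim 1 of its weight.
        new_conv2.weight.copy_(conv2.weight[:, keep_idx])
        if conv2.bias is not None:
            new_conv2.bias.copy_(conv2.bias)
    return new_conv1, new_bn, new_conv2


conv1, bn, conv2 = nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.Conv2d(8, 4, 1)
p_conv1, p_bn, p_conv2 = prune_conv_bn_conv(conv1, bn, conv2, keep_idx=[0, 2, 5, 7])
print(p_conv1, p_bn, p_conv2, sep='\n')
```

The surrounding layers keep their types and ordering; only the channel dimensions shrink.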

In [None]:
val_loss, val_accuracy = eval_model(pruned_model, validation_dataloader)
print(f'Pruned model: accuracy={val_accuracy:.3f}, loss={val_loss:.3f}')

## Finetune and evaluate pruned model

In [None]:
N_EPOCHS = 5
N_WARMUP_EPOCHS = 1
learning_rate = 0.05

# Uncomment lines below if you want to reach the best pruned model
# performance (~93% accuracy for pruned model).

# N_EPOCHS = 50  # Increase the number of model fine-tuning epochs.
# N_WARMUP_EPOCHS = 10  # Increase the number of warmup epochs.
# learning_rate = 0.01  # Decrease the learning rate.

len_train = len(train_dataloader)

optimizer = RAdam(pruned_model.parameters(), lr=learning_rate, weight_decay=4e-5)
scheduler = CosineAnnealingLR(optimizer, T_max=len_train * N_EPOCHS)
scheduler = WarmupScheduler(scheduler, warmup_steps=len_train * N_WARMUP_EPOCHS)
loss = nn.CrossEntropyLoss().cuda()

tutorial_train_loop(
    epochs=N_EPOCHS,
    model=pruned_model,
    optimizer=optimizer,
    metric_function=accuracy,
    loss_function=loss,
    train_loader=train_dataloader,
    validation_loader=validation_dataloader,
    scheduler=scheduler,
)

In [None]:
val_loss, val_accuracy = eval_model(pruned_model, validation_dataloader)
print(f'Fine-tuned pruned model: accuracy={val_accuracy:.3f}, loss={val_loss:.3f}')

## Usage of optimal pruning tool

Let's try to find an optimal pruned model with the same latency as the equally pruned model.

For simplicity, we use MACs as the latency function, but feel free to use your own latency metric, e.g.:
1. CPU time
2. GPU time
3. Any other metric, such as memory consumption
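A latency function only needs to map a model to a number, so measured time works as well as MACs. Below is a minimal, framework-agnostic sketch of a wall-clock timer that reports the median over repeated runs after a warm-up. The helper name `median_latency_ms` and its default run counts are assumptions. For GPU timing with PyTorch, you would also need to call `torch.cuda.synchronize()` before reading the clock, because CUDA kernels run asynchronously.

```python
import statistics
import time


def median_latency_ms(run_once, n_warmup=10, n_runs=50):
    """Median wall-clock time of `run_once()` in milliseconds (hypothetical helper)."""
    for _ in range(n_warmup):
        run_once()  # warm-up: caches, lazy initialization, autotuning
    timings = []
    for _ in range(n_runs):
        start = time.perf_counter()
        run_once()
        timings.append((time.perf_counter() - start) * 1e3)
    return statistics.median(timings)


# Example with a dummy workload standing in for a forward pass.
print(f'{median_latency_ms(lambda: sum(range(100_000))):.3f} ms')
```

To adapt this as a `latency_calculation_function`, one could time `model(*inputs)` under `torch.no_grad()`, with `inputs` obtained from `sample_to_model_inputs`. For GPU models, the timed callable should also synchronize. Treat this adaptation as untested. The median is used instead of the mean so that occasional slow runs do not dominate the result.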

In [None]:
def latency_calculation_function(model, dataloader):
    # Use one dataloader batch as example network input.
    inputs, _ = sample_to_model_inputs(next(iter(dataloader)))
    fca = FlopCountAnalysis(model=model.eval(), inputs=inputs)
    fca.unsupported_ops_warnings(False)
    fca.uncalled_modules_warnings(False)
    # fvcore counts one fused multiply-add as one flop, so this total is in MACs
    # (for the whole batch).
    mmacs = fca.total() / 1e6
    return mmacs


lcf = partial(latency_calculation_function, dataloader=train_dataloader)

In [None]:
desired_latency = lcf(pruned_model.eval())

print(f'Previously pruned model latency={desired_latency:.3f} MMACs')

`latency_penalty` controls how strongly the search penalizes the difference between the target latency and the latency of the searched model. A larger value leads to a more precise match.
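One way to picture a latency penalty is as a scalar objective. It adds a penalty term to the quality loss whenever a candidate's latency deviates from the target. The formula below is a hypothetical illustration of that idea, not the objective used inside `calibrate_and_prune_model_optimal`. The penalty values 1 and 10 are arbitrary and not comparable to ENOT's defaults. The sketch only shows why a larger penalty makes off-target candidates lose more often.

```python
def penalized_score(loss, latency, target_latency, latency_penalty):
    """Lower is better. Hypothetical objective: loss + penalty * relative latency error."""
    relative_error = abs(latency - target_latency) / target_latency
    return loss + latency_penalty * relative_error


target = 100.0
# Candidate A: exactly on target, worse loss. Candidate B: better loss, 2% over target.
a = dict(loss=0.40, latency=100.0)
b = dict(loss=0.30, latency=102.0)

for penalty in (1.0, 10.0):
    score_a = penalized_score(a['loss'], a['latency'], target, penalty)
    score_b = penalized_score(b['loss'], b['latency'], target, penalty)
    winner = 'A' if score_a < score_b else 'B'
    print(f'latency_penalty={penalty}: A={score_a:.2f}, B={score_b:.2f} -> {winner}')
```

With the small penalty, B wins (0.32 vs 0.40) despite missing the target. With the larger penalty, the on-target candidate A wins (0.40 vs 0.50).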

If you need to find the model "closest" to your desired latency, increase this parameter. Because of the **latency/accuracy trade-off**, a stricter latency match can make it harder to find a model that is optimal in terms of quality. In that case, increase `n_search_steps`. A larger `n_search_steps` increases the total execution time of `calibrate_and_prune_model_optimal`, but leads to a better searched model. In particular, if the search does not find a model better than the equally pruned one, try increasing `n_search_steps`.
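The effect of more search steps can be illustrated with a toy random search. With the same random seed, the best candidate found after more steps is never worse than after fewer steps, because the longer run evaluates the same candidates first. The cost is more evaluations. Everything below is a made-up stand-in and not the search algorithm used by ENOT. That includes the per-layer MAC numbers, the quadratic cost model and the quality proxy.

```python
import random


def toy_search(n_search_steps, layer_macs, target_macs, seed=0):
    """Random search over per-layer keep fractions (toy illustration).

    Quality proxy (hypothetical): kept fraction weighted by layer size.
    Feasibility: pruned MACs must not exceed target_macs.
    """
    rng = random.Random(seed)
    best_quality, best_keep = float('-inf'), None
    for _ in range(n_search_steps):
        keep = [rng.uniform(0.2, 0.9) for _ in layer_macs]
        # Assume a layer's MACs scale with the square of its kept fraction.
        macs = sum(m * k ** 2 for m, k in zip(layer_macs, keep))
        if macs > target_macs:
            continue  # violates the latency budget
        quality = sum(m * k for m, k in zip(layer_macs, keep))
        if quality > best_quality:
            best_quality, best_keep = quality, keep
    return best_quality, best_keep


layer_macs = [50.0, 120.0, 80.0, 30.0]
target = sum(layer_macs) / 3  # aim for roughly a x3 reduction

for steps in (10, 200):
    quality, _ = toy_search(steps, layer_macs, target)
    print(f'n_search_steps={steps}: best quality proxy={quality:.1f}')
```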


Note: if you use small pruning rates (up to x2) or your task is an "easy" one, the developers recommend `calibrate_and_prune_model_equal` as the best choice, since almost any pruned model is acceptable in this setting.


The default values are `n_search_steps=200` and `latency_penalty=300`. The following cell lowers `latency_penalty` to 100.

In [None]:
kwargs = {
    'latency_penalty': 100,
}

optimal_pruned_model = calibrate_and_prune_model_optimal(
    model=regular_model,
    dataloader=train_dataloader,
    loss_function=loss_function,
    latency_calculation_function=lcf,
    target_latency=desired_latency,
    finetune_bn=True,
    calibration_steps=None,  # When None - uses epochs argument to set the number of steps.
    calibration_epochs=1,
    n_search_steps=200,
    sample_to_n_samples=sample_to_n_samples,
    sample_to_model_inputs=sample_to_model_inputs,
    show_tqdm=True,
    **kwargs,
)
optimal_pruned_model.cuda()

In [None]:
opt_val_loss, opt_val_accuracy = eval_model(optimal_pruned_model, validation_dataloader)
opt_lat = lcf(optimal_pruned_model.eval())

print(f'Optimal pruned model: accuracy={opt_val_accuracy:.3f}, loss={opt_val_loss:.3f}, latency={opt_lat:.3f} MMACs')

## Finetune and evaluate optimal pruned model

In [None]:
N_EPOCHS = 5
N_WARMUP_EPOCHS = 1
learning_rate = 0.05

# Uncomment lines below if you want to reach the best pruned model
# performance (~93% accuracy for pruned model).

# N_EPOCHS = 50  # Increase the number of model fine-tuning epochs.
# N_WARMUP_EPOCHS = 10  # Increase the number of warmup epochs.
# learning_rate = 0.01  # Decrease the learning rate.

len_train = len(train_dataloader)

optimizer = RAdam(optimal_pruned_model.parameters(), lr=learning_rate, weight_decay=4e-5)
scheduler = CosineAnnealingLR(optimizer, T_max=len_train * N_EPOCHS)
scheduler = WarmupScheduler(scheduler, warmup_steps=len_train * N_WARMUP_EPOCHS)
loss = nn.CrossEntropyLoss().cuda()

tutorial_train_loop(
    epochs=N_EPOCHS,
    model=optimal_pruned_model,
    optimizer=optimizer,
    metric_function=accuracy,
    loss_function=loss,
    train_loader=train_dataloader,
    validation_loader=validation_dataloader,
    scheduler=scheduler,
)