# **Template for Torch-Pruning**

This template is provided for your convenience.

You are not required to follow the steps and method given below.

In [28]:
# Use the %pip magic (not !pip) so the package is installed into the environment
# of the running kernel rather than whichever `pip` happens to be first on PATH.
%pip install --upgrade torch_pruning



In [29]:
import torch
import torchvision
from torchvision.models import mobilenet_v2
import torch_pruning as tp
from functools import partial

import copy
import math
import random
import time
from collections import OrderedDict, defaultdict
from typing import Union, List

import numpy as np
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from torchprofile import profile_macs
from torchvision.datasets import *
from torchvision.transforms import *
from tqdm.auto import tqdm
import torch.nn.functional as F
from torchprofile import profile_macs
import os

## A Minimal Example   
In this section, you will perform channel pruning using the library [Torch-Pruning](https://github.com/VainF/Torch-Pruning).  

The pruner in Torch-Pruning has three main functions: sparse training (optional), importance estimation, and parameter removal.  
Torch-pruning offers two core features to support this process:

tp.importance(): This criterion is used to measure the importance of weights.  

tp.pruner(): This is a pruner used for the actual pruning of the parameters.  

For detailed information on this process, please refer to this [tutorial](https://github.com/VainF/Torch-Pruning/wiki/4.-High%E2%80%90level-Pruners/). Additionally, a more practical example is available [here](https://github.com/VainF/Torch-Pruning/blob/master/benchmarks/main.py).

### 1. Load model


In [30]:
# Choose the compute device first: the first CUDA GPU when one is visible,
# otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The checkpoint is deserialized onto the CPU and only then moved to `device`,
# so loading does not depend on the device the file was saved from.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
# NOTE(review): recent PyTorch releases reportedly default to weights_only=True,
# which rejects fully pickled modules like this one -- confirm against the
# installed version.
model = torch.load('./mobilenetv2_0.963.pth', map_location="cpu").to(device)

### 2. Prepare a pruner
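Some models hold parameters that do not belong to a standard layer (bare `nn.Parameter`s, which Torch-Pruning calls *unwrapped parameters*). By default, Torch-Pruning automatically prunes the last non-singleton dimension of such parameters. If you want to customize this behaviour, provide an `unwrapped_parameters` list as in the following example.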

In [31]:
# Per-channel statistics passed to Normalize, and the spatial size every image
# is resized to before being converted to a tensor.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
_INPUT_SIZE = (224, 224)

# Both splits use the same preprocessing; each split still gets its own
# Compose instance.
transforms = {
    split: Compose([
        Resize(_INPUT_SIZE),
        ToTensor(),
        Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ])
    for split in ("train", "test")
}

# CIFAR-10 is downloaded into data/cifar10 when it is not already present.
dataset = {
    split: CIFAR10(
        root="data/cifar10",
        train=(split == "train"),
        download=True,
        transform=transforms[split],
    )
    for split in ("train", "test")
}

# You can apply your own batch_size
# Only the training split is shuffled; incomplete final batches are dropped.
dataloader = {
    split: DataLoader(
        dataset[split],
        batch_size=2,
        shuffle=(split == 'train'),
        num_workers=0,
        pin_memory=True,
        drop_last=True,
    )
    for split in ('train', 'test')
}

Files already downloaded and verified
Files already downloaded and verified


In [41]:
def progressive_pruning(pruner, model, speed_up, example_inputs, train_loader=None):
    """Prune ``model`` step by step until a target reduction in operations is reached.

    After every ``pruner.step()`` the operation count is measured again with
    ``tp.utils.count_ops_and_params`` and compared with the count taken before
    any pruning.

    Args:
        pruner: Object exposing ``step()``, ``current_step`` and ``iterative_steps``.
        model: Model being pruned; it is switched to eval mode first.
        speed_up: Target value of ``base_ops / pruned_ops``. Pruning stops as
            soon as the measured ratio is no longer below this value.
        example_inputs: Inputs forwarded to the operation counter.
        train_loader: Not used by this function; kept so existing calls that
            pass it keep working.

    Returns:
        The last measured ``base_ops / pruned_ops`` ratio, or ``1`` when no
        pruning step was taken (``speed_up <= 1``).
    """
    model.eval()
    base_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
    current_speed_up = 1
    while current_speed_up < speed_up:
        pruner.step()
        pruned_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
        current_speed_up = float(base_ops) / pruned_ops
        # Stop once the step budget is used up. `>=` (not `==`) matters when this
        # function is called again on a pruner that already exhausted its steps:
        # if the counter has moved past `iterative_steps`, an equality test would
        # never match and, with no further pruning happening, the loop would
        # never end.
        if pruner.current_step >= pruner.iterative_steps:
            break
    return current_speed_up

In [42]:
output_dir = "./pruning_output"
def eval(model, test_loader, device=None, verbose=True):
    """Compute the top-1 accuracy (in percent) of ``model`` on ``test_loader``.

    NOTE: within this notebook the name shadows Python's built-in ``eval``.

    Args:
        model: Classifier returning one row of logits per input sample.
        test_loader: Iterable of ``(inputs, targets)`` batches.
        device: Device used for the model and every batch. When ``None`` the
            device of the model's first parameter is used (CPU if the model has
            no parameters).
        verbose: Show a tqdm progress bar when true.

    Returns:
        Accuracy as a Python float in the range [0, 100].

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    num_samples = 0
    # Accumulated on `device` so the loop does not synchronise on every batch.
    num_correct = torch.zeros((), dtype=torch.long, device=device)

    model.to(device)
    model.eval()
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader, desc="eval", leave=False,
                                    disable=not verbose):
            # Move the batch to the evaluation device (works on CPU-only hosts,
            # unlike an unconditional .cuda()).
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Inference, then convert logits to class indices
            predictions = model(inputs).argmax(dim=1)

            # Update metrics
            num_samples += targets.size(0)
            num_correct += (predictions == targets).sum()

    if num_samples == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")

    return int(num_correct.item()) / num_samples * 100

def train_model(
    model,
    train_loader,
    test_loader,
    epochs,
    lr,
    save_as=None,

    # For pruning
    weight_decay=1e-4,
    save_state_dict_only=True,
    pruner=None,
    device=None,
    verbose=False
):
    """Train ``model`` with SGD, evaluating and checkpointing after every epoch.

    The optimizer is SGD (momentum 0.9) with a cosine-annealed learning rate
    over ``epochs`` epochs and a cross-entropy loss. After each epoch the model
    is evaluated with ``eval``; whenever the accuracy improves on the best seen
    so far, the model is written to ``save_as``.

    Args:
        model: Network to train; it is moved to ``device``.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        test_loader: Iterable of ``(inputs, targets)`` batches passed to ``eval``.
        epochs: Number of epochs; also the period given to the LR scheduler.
        lr: Initial learning rate.
        save_as: Checkpoint path. When ``None`` it defaults to
            ``CIFAR10_mobilenet_group_norm.pth`` inside ``output_dir``.
        weight_decay: Weight decay passed to SGD.
        save_state_dict_only: Save ``model.state_dict()`` when true, otherwise
            the whole model object.
        pruner: Optional pruner used for sparsity learning. Its
            ``update_regularizer()`` is called once before training and
            ``regularize(model)`` after every backward pass.
        device: Device to train on; defaults to CUDA when available, else CPU.
        verbose: Print the loss every 10 iterations when true.

    Returns:
        None. The best accuracy is printed at the end.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    if pruner is not None:
        pruner.update_regularizer()
    model.to(device)
    best_acc = -1

    for epoch in range(1, epochs+1):
        model.train()

        for i, (inputs, targets) in enumerate(tqdm(train_loader, desc='train', leave=False)):
            # Move the batch to the training device. Using `device` instead of
            # an unconditional .cuda() keeps the CPU fallback above working.
            inputs = inputs.to(device)
            targets = targets.to(device)
            # Reset the gradients (from the last iteration)
            optimizer.zero_grad()

            # Forward inference
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Backward propagation
            loss.backward()

            if pruner is not None:
                pruner.regularize(model) # for sparsity learning

            # Update the weights (the LR scheduler is stepped once per epoch below)
            optimizer.step()
            if i % 10 == 0 and verbose:
                print(
                    "Epoch {:d}/{:d}, iter {:d}/{:d}, loss={:.4f}, lr={:.4f}".format(
                        epoch,
                        epochs,
                        i,
                        len(train_loader),
                        loss.item(),
                        optimizer.param_groups[0]["lr"],
                    )
                )

        if pruner is not None and isinstance(pruner, tp.pruner.GrowingRegPruner):
            pruner.update_reg() # increase the strength of regularization

        model.eval()
        acc = eval(model, test_loader, device=device)
        print(
            "Epoch {:d}/{:d}, Acc={:.4f}, lr={:.4f}".format(
                epoch, epochs, acc, optimizer.param_groups[0]["lr"]
            )
        )
        if best_acc < acc:
            os.makedirs(output_dir, exist_ok=True)

            if save_as is None:
                save_as = os.path.join(output_dir, "{}_{}_{}.pth".format('CIFAR10', 'mobilenet', 'group_norm'))

            # A caller-supplied path may live outside output_dir; make sure its
            # directory exists so torch.save does not fail on a missing folder.
            checkpoint_dir = os.path.dirname(save_as)
            if checkpoint_dir:
                os.makedirs(checkpoint_dir, exist_ok=True)

            if save_state_dict_only:
                torch.save(model.state_dict(), save_as)
            else:
                torch.save(model, save_as)
            best_acc = acc
        scheduler.step()
    print("Best Acc=%.4f" % (best_acc))

In [43]:
NUM_CLASSES = 10

# Importance criterion
imp = tp.importance.GroupNormImportance(p=2) # or GroupTaylorImportance(), GroupHessianImportance(), etc.

# Dummy batch handed to the pruner together with the model.
example_inputs = torch.randn(1, 3, 224, 224).to(device)


def _emits_class_scores(module):
    """Return True for a Linear/conv layer whose output width equals NUM_CLASSES."""
    if isinstance(module, torch.nn.Linear):
        return module.out_features == NUM_CLASSES
    if isinstance(module, torch.nn.modules.conv._ConvNd):
        return module.out_channels == NUM_CLASSES
    return False


unwrapped_parameters = []
pruning_ratio_dict = {}

# Layers that produce the class scores must keep their width, so they are
# excluded from pruning (ignore the classifier).
ignored_layers = []
for m in model.modules():
    if _emits_class_scores(m):
        ignored_layers.append(m)

# Initialize a pruner with the model and the importance criterion.
# global_pruning=False is bound up front; everything else is passed at call time.
pruner_entry = partial(tp.pruner.GroupNormPruner, global_pruning=False)
pruner = pruner_entry(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=400,
    pruning_ratio=1.0,
    pruning_ratio_dict=pruning_ratio_dict,
    max_pruning_ratio=1.0,
    ignored_layers=ignored_layers,
    unwrapped_parameters=unwrapped_parameters,
)

### 3. Prune the model

In [45]:
total_epochs = 30
lr = 0.01

# --- Baseline: size and accuracy of the model before pruning ---
model.eval()
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
base_acc = eval(model, dataloader['test'], device=device)
print(f'base acc: {base_acc}')

if isinstance(imp, tp.importance.GroupTaylorImportance):
    # Taylor expansion requires gradients for importance estimation
    loss = model(example_inputs).sum() # A dummy loss, please replace this line with your loss function and data!
    loss.backward() # before pruner.step()

# --- 1. Pruning: step the pruner until the op count is roughly halved ---
print("Pruning...")
progressive_pruning(
    pruner,
    model,
    speed_up=2,
    example_inputs=example_inputs,
    train_loader=dataloader['train'],
)

# --- Report: parameter / operation counts and accuracy after pruning ---
pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(model, example_inputs)
pruned_acc = eval(model, dataloader['test'], device=device)

params_report = "Params: {:.2f} M => {:.2f} M ({:.2f}%)".format(
    base_nparams / 1e6, pruned_nparams / 1e6, pruned_nparams / base_nparams * 100
)
flops_report = "FLOPs: {:.2f} M => {:.2f} M ({:.2f}%, {:.2f}X )".format(
    base_macs / 1e6,
    pruned_macs / 1e6,
    pruned_macs / base_macs * 100,
    base_macs / pruned_macs,
)
acc_report = "Acc: {:.4f} => {:.4f}".format(base_acc, pruned_acc)
for report_line in (params_report, flops_report, acc_report):
    print(report_line)

# Operation count of the pruned model, in millions.
MFLOPs = pruned_macs/1e6
print("The pruned model:")
print(model)
print("Summary:")
print("MFLOPs: ")
print(MFLOPs)

# 2. Finetuning
# The accuracy lost by pruning is recovered by training the pruned model; the
# whole model object (not just the state dict) is saved to the default path.
print('Finetuning...')
finetune_kwargs = dict(
    train_loader=dataloader['train'],
    test_loader=dataloader['test'],
    epochs=total_epochs,
    lr=lr,
    save_as=None,
    weight_decay=1e-4,
    save_state_dict_only=False,
    pruner=None,
    device=device,
    verbose=False,
)
train_model(model, **finetune_kwargs)

                                                          

base acc: 73.58999633789062
Pruning...


                                                         

Params: 0.10 M => 0.05 M (44.65%)
FLOPs: 18.53 M => 9.18 M (49.56%, 2.02X )
Acc: 73.5900 => 10.0000
The pruned model:
MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=3, bias=False)
          (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(3, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
   

                                                            

Epoch 0/30, Acc=62.3600, lr=0.0100


                                                            

Epoch 1/30, Acc=65.2200, lr=0.0100


                                                            

Epoch 2/30, Acc=67.3800, lr=0.0099


                                                            

Epoch 3/30, Acc=65.0600, lr=0.0098


                                                            

Epoch 4/30, Acc=63.7700, lr=0.0096


                                                            

Epoch 5/30, Acc=69.3100, lr=0.0093


                                                            

Epoch 6/30, Acc=67.7500, lr=0.0090


                                                            

Epoch 7/30, Acc=69.0700, lr=0.0087


                                                            

Epoch 8/30, Acc=68.9200, lr=0.0083


                                                            

Epoch 9/30, Acc=68.6200, lr=0.0079


                                                            

Epoch 10/30, Acc=68.8400, lr=0.0075


                                                            

Epoch 11/30, Acc=70.3700, lr=0.0070


                                                            

Epoch 12/30, Acc=67.8000, lr=0.0065


                                                            

Epoch 13/30, Acc=71.1600, lr=0.0060


                                                            

Epoch 14/30, Acc=70.4100, lr=0.0055


                                                            

Epoch 15/30, Acc=73.6100, lr=0.0050


                                                            

Epoch 16/30, Acc=70.8800, lr=0.0045


                                                            

Epoch 17/30, Acc=72.1500, lr=0.0040


                                                            

Epoch 18/30, Acc=74.3300, lr=0.0035


                                                            

Epoch 19/30, Acc=73.2600, lr=0.0030


                                                            

Epoch 20/30, Acc=73.6500, lr=0.0025


                                                            

Epoch 21/30, Acc=74.4900, lr=0.0021


                                                            

Epoch 22/30, Acc=76.4100, lr=0.0017


                                                            

Epoch 23/30, Acc=77.0000, lr=0.0013


                                                            

Epoch 24/30, Acc=76.9300, lr=0.0010


                                                            

Epoch 25/30, Acc=77.7800, lr=0.0007


                                                            

Epoch 26/30, Acc=77.6700, lr=0.0004


                                                            

Epoch 27/30, Acc=78.6100, lr=0.0002


                                                            

Epoch 28/30, Acc=78.7400, lr=0.0001


                                                            

Epoch 29/30, Acc=78.9700, lr=0.0000
Best Acc=78.9700


