# **Template for Torch-Pruning**

This template is provided purely for your convenience.

You are not required to follow the steps and method given below.

In [1]:
!pip install --upgrade torch_pruning
!pip install torchprofile
# !pip install torch torchvision torchaudio



In [2]:
import torch
import torchvision
from torchvision.models import mobilenet_v2
import torch_pruning as tp
from functools import partial
import copy
import math
import random
import time
from collections import OrderedDict, defaultdict
from typing import Union, List

import numpy as np
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from torchprofile import profile_macs
from torchvision.datasets import *
from torchvision.transforms import *
from tqdm.auto import tqdm
import torch.nn.functional as F
from torchprofile import profile_macs
import os

  from .autonotebook import tqdm as notebook_tqdm


## A Minimal Example   
In this section, you will perform channel pruning using the library [Torch-Pruning](https://github.com/VainF/Torch-Pruning).  

Pruning with Torch-Pruning involves three main stages: sparse training (optional), importance estimation, and parameter removal.  
Torch-Pruning offers two core components to support this process:

`tp.importance`: importance criteria used to measure how important the weights are (e.g. `tp.importance.GroupNormImportance`).  

`tp.pruner`: pruners that perform the actual removal of parameters (e.g. `tp.pruner.GroupNormPruner`).  

For detailed information on this process, please refer to this [tutorial](https://github.com/VainF/Torch-Pruning/wiki/4.-High%E2%80%90level-Pruners/). A more practical example is also available [here](https://github.com/VainF/Torch-Pruning/blob/master/benchmarks/main.py).

### 1. Load model


In [3]:
# Pick the compute device first so the checkpoint can be moved there right after loading.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The checkpoint is a fully pickled nn.Module (not a state_dict), so it has to be
# unpickled with weights_only=False. PyTorch >= 2.6 defaults to weights_only=True,
# which refuses to load it. SECURITY: unpickling can execute arbitrary code, so only
# load checkpoints from a trusted source.
_CHECKPOINT_PATH = './mobilenetv2_0.963.pth'
try:
    model = torch.load(_CHECKPOINT_PATH, map_location="cpu", weights_only=False)
except TypeError:
    # PyTorch releases older than 1.13 do not accept the weights_only argument.
    model = torch.load(_CHECKPOINT_PATH, map_location="cpu")
model = model.to(device)

### 2. Prepare a pruner
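For unwrapped parameters (bare `nn.Parameter`s that do not belong to a standard layer such as `Conv2d` or `Linear`), Torch-Pruning by default prunes their last non-singleton dimension. To customize this behaviour, pass an `unwrapped_parameters` list to the pruner, as in the example below (where the list is left empty).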

In [4]:
# Preprocessing: every image is resized to 224x224 (the same resolution as the dummy
# `example_inputs` used for pruning) and normalised with the standard ImageNet mean/std.
# NOTE(review): "train" currently uses exactly the same pipeline as "test" (no data
# augmentation); consider adding e.g. RandomHorizontalFlip() for finetuning.
transforms = {
    "train": Compose([
      Resize((224, 224)),
      ToTensor(),
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
    "test": Compose([
      Resize((224, 224)),
      ToTensor(),
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
}

dataset = {}
for split in ["train", "test"]:
  dataset[split] = CIFAR10(
    root="data/cifar10",
    train=(split == "train"),
    download=True,
    transform=transforms[split],
  )

# You can apply your own batch_size
BATCH_SIZE = 10
dataloader = {}
for split in ['train', 'test']:
  dataloader[split] = DataLoader(
    dataset[split],
    batch_size=BATCH_SIZE,
    shuffle=(split == 'train'),
    num_workers=0,
    # Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only machines.
    pin_memory=torch.cuda.is_available(),
    # Drop the ragged final batch only while training. Evaluation must see every test
    # sample, otherwise the reported accuracy silently ignores some of them.
    drop_last=(split == 'train'),
  )

Files already downloaded and verified
Files already downloaded and verified


In [5]:
def progressive_pruning(pruner, model, speed_up, example_inputs, train_loader=None):
    """Prune `model` one pruner step at a time until its op count shrinks by `speed_up`.

    After every ``pruner.step()`` the model's operation count is re-measured with
    ``tp.utils.count_ops_and_params``. The loop stops as soon as
    ``base_ops / current_ops >= speed_up`` or once ``pruner.current_step`` reaches
    ``pruner.iterative_steps``, so the returned speed-up may fall short of the target.

    Args:
        pruner: Torch-Pruning pruner built for `model`; it must expose ``step()``,
            ``current_step`` and ``iterative_steps``.
        model: the network being pruned (modified by ``pruner.step()``).
        speed_up: target ratio of original ops to pruned ops.
        example_inputs: dummy input used for op counting.
        train_loader: unused. It is kept only so existing calls keep working.

    Returns:
        The achieved speed-up as a float, or ``1`` if no pruning step was needed.
    """
    model.eval()
    base_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
    current_speed_up = 1
    while current_speed_up < speed_up:
        pruner.step()
        pruned_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
        current_speed_up = float(base_ops) / pruned_ops
        if pruner.current_step == pruner.iterative_steps:
            # The pruner has used up all of its scheduled steps; stop even if short of target.
            break
    return current_speed_up

In [6]:
# Directory where train_model() writes its best checkpoint when no explicit `save_as`
# path is given (created on demand with os.makedirs).
output_dir = "./pruning_output"
def eval(model, test_loader, device=None, verbose=True):
    """Return the top-1 accuracy (in percent, as a Python float) of `model` on `test_loader`.

    NOTE: the name shadows Python's built-in ``eval``; it is kept because the rest of
    this notebook calls it by this name.

    Args:
        model: classifier whose output holds per-class scores along dim 1.
        test_loader: iterable of ``(inputs, targets)`` batches.
        device: device to run on. If None, CUDA is used when available, else CPU
            (the same default as train_model()).
        verbose: show a tqdm progress bar.

    Raises:
        ValueError: if `test_loader` yields no samples.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    num_samples = 0
    num_correct = 0

    model.to(device)
    model.eval()
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader, desc="eval", leave=False,
                                    disable=not verbose):
            # Honour `device` instead of hard-coding .cuda(), so CPU-only runs work.
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Convert logits to predicted class indices.
            predictions = model(inputs).argmax(dim=1)

            num_samples += targets.size(0)
            # Accumulate as a Python int to keep the final ratio in full precision.
            num_correct += (predictions == targets).sum().item()

    if num_samples == 0:
        raise ValueError("eval(): test_loader yielded no samples")
    return num_correct / num_samples * 100

def train_model(
    model,
    train_loader,
    test_loader,
    epochs,
    lr,
    lr_decay_milestones,
    lr_decay_gamma = 0.7,
    save_as=None,
    
    # For pruning
    weight_decay=1e-4,
    save_state_dict_only=True,
    pruner=None,
    device=None,
    verbose=False
):
    """Train/finetune `model` with SGD and save the best checkpoint seen so far.

    The learning rate follows CosineAnnealingWarmRestarts (restart every 5 epochs,
    floor 0.0008). After each epoch the model is evaluated with ``eval()``. Whenever
    accuracy improves, the model is saved to `save_as`, which defaults to
    ``<output_dir>/CIFAR10_mobilenet_group_norm.pth``.

    Args:
        model: network to train (moved to `device`).
        train_loader / test_loader: iterables of ``(inputs, targets)`` batches.
        epochs: number of epochs.
        lr: initial SGD learning rate (momentum 0.9).
        lr_decay_milestones, lr_decay_gamma: NOT used by the current cosine schedule.
            They are accepted only for backward compatibility.
        save_as: checkpoint path; its parent directory is created if missing.
        weight_decay: SGD weight decay.
        save_state_dict_only: save ``model.state_dict()`` instead of the whole module.
        pruner: optional Torch-Pruning pruner whose regularizer is applied to the
            gradients every step (sparse training).
        device: training device. If None, CUDA is used when available, else CPU.
        verbose: print the loss every 10 iterations.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    # The deprecated `verbose` argument is omitted; recent PyTorch versions warn about it.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=5, T_mult=1, eta_min=0.0008
    )
    if pruner is not None:
        pruner.update_regularizer()
    model.to(device)
    best_acc = -1

    for epoch in range(1, epochs + 1):
        model.train()

        for i, (inputs, targets) in enumerate(tqdm(train_loader, desc='train', leave=False)):
            # Honour `device` instead of hard-coding .cuda(), so CPU-only runs work.
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()

            if pruner is not None:
                pruner.regularize(model)  # sparsity learning: applied after backward, before the update

            optimizer.step()
            if verbose and i % 10 == 0:
                print(
                    "Epoch {:d}/{:d}, iter {:d}/{:d}, loss={:.4f}, lr={:.4f}".format(
                        epoch,
                        epochs,
                        i,
                        len(train_loader),
                        loss.item(),
                        optimizer.param_groups[0]["lr"],
                    )
                )

        if pruner is not None and isinstance(pruner, tp.pruner.GrowingRegPruner):
            pruner.update_reg()  # increase the strength of regularization

        model.eval()
        acc = eval(model, test_loader, device=device)
        print(
            "Epoch {:d}/{:d}, Acc={:.4f}, lr={:.4f}".format(
                epoch, epochs, acc, optimizer.param_groups[0]["lr"]
            )
        )
        if best_acc < acc:
            if save_as is None:
                save_as = os.path.join(output_dir, "{}_{}_{}.pth".format('CIFAR10', 'mobilenet', 'group_norm'))
            # Create the directory the checkpoint actually goes to (this may differ from
            # output_dir when the caller passes a custom save_as).
            os.makedirs(os.path.dirname(save_as) or ".", exist_ok=True)

            if save_state_dict_only:
                torch.save(model.state_dict(), save_as)
            else:
                torch.save(model, save_as)
            best_acc = acc
        scheduler.step()
    print("Best Acc=%.4f" % (best_acc))

In [7]:
NUM_CLASSES = 10

# Importance criterion: L2 group norm, normalised by the maximum score (setting used
# for CIFAR). Alternatives that can be swapped in include
# tp.importance.GroupTaylorImportance() / GroupHessianImportance(), or
# tp.pruner.importance.OBDCImportance(group_reduction='mean', num_classes=NUM_CLASSES)
# combined with tp.pruner.MagnitudePruner.
imp = tp.importance.GroupNormImportance(p=2, normalizer='max')

# Pruner factory: GroupNormPruner with regularisation strength 5e-4 and global_pruning enabled.
pruner_entry = partial(tp.pruner.GroupNormPruner, reg=5e-4, global_pruning=True)

# Dummy input handed to the pruner (and reused later for MACs counting).
example_inputs = torch.randn(1, 3, 224, 224).to(device)

unwrapped_parameters = []
pruning_ratio_dict = {}

# Keep the output layer intact: any Linear or Conv layer producing exactly
# NUM_CLASSES outputs is treated as the classifier and excluded from pruning.
ignored_layers = [
    layer
    for layer in model.modules()
    if (isinstance(layer, torch.nn.Linear) and layer.out_features == NUM_CLASSES)
    or (isinstance(layer, torch.nn.modules.conv._ConvNd) and layer.out_channels == NUM_CLASSES)
]

pruner = pruner_entry(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=500,
    pruning_ratio=0.7,
    pruning_ratio_dict=pruning_ratio_dict,
    max_pruning_ratio=0.95,
    ignored_layers=ignored_layers,
    unwrapped_parameters=unwrapped_parameters,
)

### 3. Prune the model

In [8]:

# ---------------------------------------------------------------------------
# Step 1: measure the unpruned baseline, then prune progressively
# ---------------------------------------------------------------------------
target_speed_up = 7  # stop once base MACs / pruned MACs reaches this (or the pruner runs out of steps)

model.eval()
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
base_acc = eval(model, dataloader['test'], device=device)
print(f'base acc: {base_acc}')

if isinstance(imp, tp.importance.GroupTaylorImportance):
    # Taylor importance needs gradients before pruner.step(). This dummy loss is only a
    # placeholder; replace it with a real loss computed on real data.
    loss = model(example_inputs).sum()
    loss.backward()

print("Pruning...")
progressive_pruning(
    pruner,
    model,
    speed_up=target_speed_up,
    example_inputs=example_inputs,
    train_loader=dataloader['train'],
)

# Re-measure size, cost and accuracy after pruning.
pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(model, example_inputs)
pruned_acc = eval(model, dataloader['test'], device=device)
print(f"Params: {base_nparams / 1e6:.2f} M => {pruned_nparams / 1e6:.2f} M "
      f"({pruned_nparams / base_nparams * 100:.2f}%)")
print(f"FLOPs: {base_macs / 1e6:.2f} M => {pruned_macs / 1e6:.2f} M "
      f"({pruned_macs / base_macs * 100:.2f}%, {base_macs / pruned_macs:.2f}X )")
print(f"Acc: {base_acc:.4f} => {pruned_acc:.4f}")

# NOTE: the value reported as "MFLOPs" is derived from pruned_macs (millions of ops).
MFLOPs = pruned_macs / 1e6
print("The first pruned model:", model, sep="\n")
print("Summary:", "MFLOPs: ", MFLOPs, sep="\n")

# ---------------------------------------------------------------------------
# Step 2: finetune the pruned model to recover accuracy
# NOTE(review): the visible train_model() does not use lr_decay_milestones /
# lr_decay_gamma (its MultiStepLR scheduler is commented out).
# ---------------------------------------------------------------------------
print('Finetuning...')
train_model(
    model,
    train_loader=dataloader['train'],
    test_loader=dataloader['test'],
    epochs=100,
    lr=0.01,
    lr_decay_milestones="60, 100, 150, 180",
    lr_decay_gamma=0.1,
    save_as=None,
    weight_decay=1e-4,
    save_state_dict_only=False,
    pruner=None,
    device=device,
    verbose=False,
)

                                                        

base acc: 96.30000305175781
Pruning...




Params: 2.24 M => 0.35 M (15.86%)
FLOPs: 318.97 M => 49.51 M (15.52%, 6.44X )
Acc: 96.3000 => 10.0000
The first pruned model:
MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 7, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(7, 7, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=7, bias=False)
          (1): BatchNorm2d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(7, 13, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(13, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActiv

                                                          

Epoch 1/100, Acc=44.0500, lr=0.0100


                                                          

Epoch 2/100, Acc=51.4500, lr=0.0091


                                                          

Epoch 3/100, Acc=57.2200, lr=0.0068


                                                          

Epoch 4/100, Acc=61.8400, lr=0.0040


                                                          

Epoch 5/100, Acc=63.6400, lr=0.0017


                                                          

Epoch 6/100, Acc=64.6400, lr=0.0100


                                                          

Epoch 7/100, Acc=68.8600, lr=0.0091


                                                          

Epoch 8/100, Acc=71.4700, lr=0.0068


                                                          

Epoch 9/100, Acc=74.5200, lr=0.0040


                                                          

Epoch 10/100, Acc=76.2000, lr=0.0017


                                                          

Epoch 11/100, Acc=72.8800, lr=0.0100


                                                          

Epoch 12/100, Acc=74.9900, lr=0.0091


                                                          

Epoch 13/100, Acc=75.5700, lr=0.0068


                                                          

Epoch 14/100, Acc=78.4200, lr=0.0040


                                                          

Epoch 15/100, Acc=78.9600, lr=0.0017


                                                          

Epoch 16/100, Acc=76.3100, lr=0.0100


                                                          

Epoch 17/100, Acc=75.7900, lr=0.0091


                                                          

Epoch 18/100, Acc=78.9200, lr=0.0068


                                                          

Epoch 19/100, Acc=80.0600, lr=0.0040


                                                          

Epoch 20/100, Acc=80.8700, lr=0.0017


                                                          

Epoch 21/100, Acc=77.6800, lr=0.0100


                                                          

Epoch 22/100, Acc=77.1400, lr=0.0091


                                                          

Epoch 23/100, Acc=77.5000, lr=0.0068


                                                          

Epoch 24/100, Acc=80.3800, lr=0.0040


                                                          

Epoch 25/100, Acc=81.6100, lr=0.0017


                                                          

Epoch 26/100, Acc=78.4300, lr=0.0100


                                                          

Epoch 27/100, Acc=79.3300, lr=0.0091


                                                          

Epoch 28/100, Acc=80.1700, lr=0.0068


                                                          

Epoch 29/100, Acc=81.3800, lr=0.0040


                                                          

Epoch 30/100, Acc=82.0800, lr=0.0017


                                                          

Epoch 31/100, Acc=80.0800, lr=0.0100


                                                          

Epoch 32/100, Acc=79.5600, lr=0.0091


                                                          

Epoch 33/100, Acc=81.0500, lr=0.0068


                                                          

Epoch 34/100, Acc=82.4000, lr=0.0040


                                                          

Epoch 35/100, Acc=83.4300, lr=0.0017


                                                          

Epoch 36/100, Acc=79.1700, lr=0.0100


                                                          

Epoch 37/100, Acc=79.1900, lr=0.0091


                                                          

Epoch 38/100, Acc=81.2400, lr=0.0068


                                                          

Epoch 39/100, Acc=82.4500, lr=0.0040


                                                          

Epoch 40/100, Acc=84.0200, lr=0.0017


                                                          

Epoch 41/100, Acc=79.2200, lr=0.0100


                                                          

Epoch 42/100, Acc=80.7900, lr=0.0091


                                                          

Epoch 43/100, Acc=82.0300, lr=0.0068


                                                          

Epoch 44/100, Acc=82.6900, lr=0.0040


                                                          

Epoch 45/100, Acc=84.1300, lr=0.0017


                                                          

Epoch 46/100, Acc=80.2100, lr=0.0100


                                                          

Epoch 47/100, Acc=80.6100, lr=0.0091


                                                          

Epoch 48/100, Acc=82.1500, lr=0.0068


                                                          

Epoch 49/100, Acc=83.5000, lr=0.0040


                                                          

Epoch 50/100, Acc=83.6600, lr=0.0017


                                                          

Epoch 51/100, Acc=80.5800, lr=0.0100


                                                          

Epoch 52/100, Acc=81.0300, lr=0.0091


                                                          

Epoch 53/100, Acc=81.8600, lr=0.0068


                                                          

Epoch 54/100, Acc=83.5400, lr=0.0040


                                                          

Epoch 55/100, Acc=84.2700, lr=0.0017


                                                          

Epoch 56/100, Acc=79.7200, lr=0.0100


                                                          

Epoch 57/100, Acc=80.9800, lr=0.0091


                                                          

Epoch 58/100, Acc=82.2000, lr=0.0068


                                                          

Epoch 59/100, Acc=83.6600, lr=0.0040


                                                          

Epoch 60/100, Acc=84.3600, lr=0.0017


                                                          

Epoch 61/100, Acc=79.2500, lr=0.0100


                                                          

Epoch 62/100, Acc=81.9900, lr=0.0091


                                                          

Epoch 63/100, Acc=82.9500, lr=0.0068


                                                          

Epoch 64/100, Acc=83.5300, lr=0.0040


                                                          

Epoch 65/100, Acc=84.6700, lr=0.0017


                                                          

Epoch 66/100, Acc=78.9300, lr=0.0100


                                                          

Epoch 67/100, Acc=80.1900, lr=0.0091


                                                          

Epoch 68/100, Acc=82.1400, lr=0.0068


                                                          

Epoch 69/100, Acc=83.4300, lr=0.0040


                                                          

Epoch 70/100, Acc=84.8700, lr=0.0017


                                                          

Epoch 71/100, Acc=81.5700, lr=0.0100


                                                          

Epoch 72/100, Acc=81.5100, lr=0.0091


                                                          

Epoch 73/100, Acc=82.4000, lr=0.0068


                                                          

Epoch 74/100, Acc=83.9300, lr=0.0040


                                                          

Epoch 75/100, Acc=85.0300, lr=0.0017


                                                          

Epoch 76/100, Acc=81.4900, lr=0.0100


                                                          

Epoch 77/100, Acc=82.3400, lr=0.0091


                                                          

Epoch 78/100, Acc=83.5600, lr=0.0068


                                                          

Epoch 79/100, Acc=84.5200, lr=0.0040


                                                          

Epoch 80/100, Acc=85.1100, lr=0.0017


                                                          

Epoch 81/100, Acc=81.6000, lr=0.0100


                                                          

Epoch 82/100, Acc=81.1300, lr=0.0091


                                                          

Epoch 83/100, Acc=83.8200, lr=0.0068


                                                          

Epoch 84/100, Acc=84.5900, lr=0.0040


                                                          

Epoch 85/100, Acc=85.4600, lr=0.0017


                                                          

Epoch 86/100, Acc=81.6000, lr=0.0100


                                                          

Epoch 87/100, Acc=82.0600, lr=0.0091


                                                          

Epoch 88/100, Acc=83.1300, lr=0.0068


                                                          

Epoch 89/100, Acc=85.1200, lr=0.0040


                                                          

Epoch 90/100, Acc=85.4700, lr=0.0017


                                                          

Epoch 91/100, Acc=81.7700, lr=0.0100


                                                          

Epoch 92/100, Acc=82.7300, lr=0.0091


                                                          

Epoch 93/100, Acc=83.2700, lr=0.0068


                                                          

Epoch 94/100, Acc=84.3600, lr=0.0040


                                                          

Epoch 95/100, Acc=86.1200, lr=0.0017


                                                          

Epoch 96/100, Acc=82.0500, lr=0.0100


                                                          

Epoch 97/100, Acc=82.2600, lr=0.0091


                                                          

Epoch 98/100, Acc=84.0100, lr=0.0068


                                                          

Epoch 99/100, Acc=85.8200, lr=0.0040


                                                          

Epoch 100/100, Acc=86.0900, lr=0.0017
Best Acc=86.1200


