# **Template for Torch-Pruning**

This template is just built for your convinience.

You are not required to follow the steps and method given below.

In [1]:
!pip install --upgrade torch_pruning



In [2]:
import torch
import torchvision
from torchvision.models import mobilenet_v2
import torch_pruning as tp

##A Minimal Example ##  
In this section, you will perform channel pruning using the library [Torch-Pruning](https://github.com/VainF/Torch-Pruning).  

The puuner in Torch-Pruning has three main functions: sparse training (optional), importance estimation, and parameter removal.  
Torch-pruning offers two core features to support this process:

tp.importance(): This criteria is utilized to measure the importance of weights.  

tp.pruner(): This is a pruner used for the actual pruning of the parameters.  

For detailed information on this process, please refer to this [tutorial](https://github.com/VainF/Torch-Pruning/wiki/4.-High%E2%80%90level-Pruners/). Additionally, a more practical example is available in [here](https://github.com/VainF/Torch-Pruning/blob/master/benchmarks/main.py).

### 1. Load model


In [3]:
model = torch.load('./model/mobilenetv2_0.963.pth', map_location="cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

### 2. Prepare a pruner
By default, Torch-Pruning will automatically prune the last non-singleton dim of these parameters. If you want to customize this behaviour, please provide an `unwrapped_parameters` list as the following example.

In [4]:
# Importance criterion
imp = tp.importance.GroupNormImportance(p=2) #GroupTaylorImportance() # or , GroupHessianImportance(), etc.

# Initialize a pruner with the model and the importance criterion
example_inputs = torch.randn(1, 3, 224, 224).to(device)

ignored_layers = [model.features[2].conv[2]]
channel_groups = {}
for m in model.modules():
  if isinstance(m, torch.nn.Linear) and m.out_features == 10: # ignore the classifier
    ignored_layers.append(m)

pruner = tp.pruner.GroupNormPruner ( # you can choose any pruner you like.
    model,
    example_inputs,
    importance=imp,   # Importance Estimator
    pruning_ratio=0.5, # remove 50% channels, ex :ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    # ignored_layers=ignored_layers,
    # pruning_ratio_dict = {model.features[15].conv[0][0]:0.90,
    #                       model.features[15].conv[2]:0.85,
    #                       model.features[16].conv[0][0]:0.93,
    #                       model.features[16].conv[2]:0.60,
    #                       model.features[17].conv[0][0]:0.92,
    #                       model.features[17].conv[2]:0.85,
    #                       model.features[18][0]: 0.93}, # manually set the sparsity of model.conv1
    iterative_steps = 1,  # number of steps to achieve the target ch_sparsity.
    ignored_layers = ignored_layers,        # ignore some layers such as the finall linear classifier
    channel_groups = channel_groups,  # round channels
)

In [5]:
model.features[15].conv[0][0]

Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)

### 3. Prune the model

In [6]:
model

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

In [7]:
import random
import copy
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from torchprofile import profile_macs
from torchvision.datasets import *
from torchvision.transforms import *
from tqdm.auto import tqdm


assert torch.cuda.is_available(), \
"The current runtime does not have CUDA support." \
"Please go to menu bar (Runtime - Change runtime type) and select GPU"

In [8]:
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x723887f3b970>

In [9]:
transforms = {
    "train": Compose([
      Resize((224, 224)),
      ToTensor(),
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
    "test": Compose([
      Resize((224, 224)),
      ToTensor(),
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
}

dataset = {}
for split in ["train", "test"]:
  dataset[split] = CIFAR10(
    root="data/cifar10",
    train=(split == "train"),
    download=True,
    transform=transforms[split],
  )

# You can apply your own batch_size
dataloader = {}
for split in ['train', 'test']:
  dataloader[split] = DataLoader(
    dataset[split],
    batch_size=64,
    shuffle=(split == 'train'),
    num_workers=0,
    pin_memory=True,
    drop_last=True
  )

Files already downloaded and verified
Files already downloaded and verified


In [10]:
def train(
  model: nn.Module,
  dataloader: DataLoader,
  criterion: nn.Module,
  optimizer: Optimizer,
  scheduler: LambdaLR,
  callbacks = None
) -> None:
  model.train()

  for inputs, targets in tqdm(dataloader, desc='train', leave=False):
    # Move the data from CPU to GPU
    inputs = inputs.cuda()
    targets = targets.cuda()

    # Reset the gradients (from the last iteration)
    optimizer.zero_grad()

    # Forward inference
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # Backward propagation
    loss.backward()

    # Update optimizer and LR scheduler
    optimizer.step()
    scheduler.step()

    if callbacks is not None:
        for callback in callbacks:
            callback()

In [11]:
@torch.inference_mode()
def evaluate(
  model: nn.Module,
  dataloader: DataLoader,
  verbose=True,
) -> float:
  model.eval()

  num_samples = 0
  num_correct = 0

  for inputs, targets in tqdm(dataloader, desc="eval", leave=False,
                              disable=not verbose):
    # Move the data from CPU to GPU
    inputs = inputs.cuda()
    targets = targets.cuda()

    # Inference
    outputs = model(inputs)

    # Convert logits to class indices
    outputs = outputs.argmax(dim=1)

    # Update metrics
    num_samples += targets.size(0)
    num_correct += (outputs == targets).sum()

  return (num_correct / num_samples * 100).item()

In [12]:
def finetune(model, num_finetune_epochs):
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_finetune_epochs)
    criterion = nn.CrossEntropyLoss()

    best_sparse_model_checkpoint = dict()
    best_accuracy = 0
    print(f'Finetuning Fine-grained Pruned Sparse Model')
    for epoch in range(num_finetune_epochs):
        # At the end of each train iteration, we have to apply the pruning mask
        #    to keep the model sparse during the training
        train(model, dataloader['train'], criterion, optimizer, scheduler)
        accuracy = evaluate(model, dataloader['test'])
        is_best = accuracy > best_accuracy
        if is_best:
            best_sparse_model_checkpoint['state_dict'] = copy.deepcopy(model.state_dict())
            best_accuracy = accuracy
        print(f'    Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%')

    # load the best sparse model checkpoint to evaluate the final performance
    model.load_state_dict(best_sparse_model_checkpoint['state_dict'])
    sparse_model_accuracy = evaluate(model, dataloader['test'])
    print(f"Sparse model has accuracy={sparse_model_accuracy:.2f}% after fintuning")
    
    return model

In [13]:
EPOCHS = 60

In [14]:
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
base_MFLOPs = base_macs/1e6
print("The base model:")
print(model)
print("Base Model Summary:")
print("#Parameter: ")
print(base_nparams)
print("MFLOPs: ")
print(base_MFLOPs)

The base model:
MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): Batc

In [15]:
# Model size before pruning

model.eval()

if isinstance(imp, tp.importance.GroupTaylorImportance):
  # Taylor expansion requires gradients for importance estimation
  for inputs, targets in tqdm(dataloader['test'], desc='eval', leave=False):
    loss = nn.CrossEntropyLoss(model(inputs), targets) # A dummy loss, please replace this line with your loss function and data!
    loss.backward() # before pruner.step()
    break

# prune
pruner.step()

# Parameter & MACs Counter
pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(model, example_inputs)
pruned_MFLOPs = pruned_macs/1e6
print("The pruned model:")
print(model)
print("Summary:")
print("#Parameter: ")
print(pruned_nparams)
print("MFLOPs: ")
print(pruned_MFLOPs)

The pruned model:
MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(8, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): Batch

In [16]:
# finetune the pruned model here
finetuned_model = finetune(model, EPOCHS)
# ...

finetuned_macs, finetuned_nparams = tp.utils.count_ops_and_params(finetuned_model, example_inputs)
finetuned_MFLOPs = finetuned_macs/1e6
print("The finetuned pruned model:")
print(finetuned_model)

Finetuning Fine-grained Pruned Sparse Model


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 1 Accuracy 76.20% / Best Accuracy: 76.20%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 2 Accuracy 81.82% / Best Accuracy: 81.82%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 3 Accuracy 85.31% / Best Accuracy: 85.31%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 4 Accuracy 86.54% / Best Accuracy: 86.54%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 5 Accuracy 87.75% / Best Accuracy: 87.75%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 6 Accuracy 87.46% / Best Accuracy: 87.75%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 7 Accuracy 88.74% / Best Accuracy: 88.74%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 8 Accuracy 87.01% / Best Accuracy: 88.74%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 9 Accuracy 88.85% / Best Accuracy: 88.85%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 10 Accuracy 88.09% / Best Accuracy: 88.85%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 11 Accuracy 89.75% / Best Accuracy: 89.75%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 12 Accuracy 89.37% / Best Accuracy: 89.75%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 13 Accuracy 89.50% / Best Accuracy: 89.75%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 14 Accuracy 89.53% / Best Accuracy: 89.75%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 15 Accuracy 89.70% / Best Accuracy: 89.75%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 16 Accuracy 88.95% / Best Accuracy: 89.75%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 17 Accuracy 89.77% / Best Accuracy: 89.77%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 18 Accuracy 89.40% / Best Accuracy: 89.77%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 19 Accuracy 89.83% / Best Accuracy: 89.83%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 20 Accuracy 89.52% / Best Accuracy: 89.83%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 21 Accuracy 90.43% / Best Accuracy: 90.43%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 22 Accuracy 89.81% / Best Accuracy: 90.43%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 23 Accuracy 90.29% / Best Accuracy: 90.43%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 24 Accuracy 89.44% / Best Accuracy: 90.43%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 25 Accuracy 90.41% / Best Accuracy: 90.43%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 26 Accuracy 90.15% / Best Accuracy: 90.43%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 27 Accuracy 90.46% / Best Accuracy: 90.46%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 28 Accuracy 90.12% / Best Accuracy: 90.46%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 29 Accuracy 90.46% / Best Accuracy: 90.46%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 30 Accuracy 89.95% / Best Accuracy: 90.46%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 31 Accuracy 90.39% / Best Accuracy: 90.46%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 32 Accuracy 90.17% / Best Accuracy: 90.46%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 33 Accuracy 90.44% / Best Accuracy: 90.46%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 34 Accuracy 90.27% / Best Accuracy: 90.46%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 35 Accuracy 90.13% / Best Accuracy: 90.46%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 36 Accuracy 90.49% / Best Accuracy: 90.49%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 37 Accuracy 90.42% / Best Accuracy: 90.49%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 38 Accuracy 90.54% / Best Accuracy: 90.54%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 39 Accuracy 90.83% / Best Accuracy: 90.83%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 40 Accuracy 90.24% / Best Accuracy: 90.83%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 41 Accuracy 90.43% / Best Accuracy: 90.83%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 42 Accuracy 90.24% / Best Accuracy: 90.83%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 43 Accuracy 90.43% / Best Accuracy: 90.83%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 44 Accuracy 90.11% / Best Accuracy: 90.83%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 45 Accuracy 90.57% / Best Accuracy: 90.83%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 46 Accuracy 90.38% / Best Accuracy: 90.83%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 47 Accuracy 90.61% / Best Accuracy: 90.83%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 48 Accuracy 90.59% / Best Accuracy: 90.83%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 49 Accuracy 90.46% / Best Accuracy: 90.83%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 50 Accuracy 90.42% / Best Accuracy: 90.83%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 51 Accuracy 90.50% / Best Accuracy: 90.83%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 52 Accuracy 90.36% / Best Accuracy: 90.83%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 53 Accuracy 90.71% / Best Accuracy: 90.83%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 54 Accuracy 90.81% / Best Accuracy: 90.83%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 55 Accuracy 90.94% / Best Accuracy: 90.94%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 56 Accuracy 90.98% / Best Accuracy: 90.98%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 57 Accuracy 90.48% / Best Accuracy: 90.98%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 58 Accuracy 91.03% / Best Accuracy: 91.03%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 59 Accuracy 90.11% / Best Accuracy: 91.03%


train:   0%|          | 0/781 [00:00<?, ?it/s]

eval:   0%|          | 0/156 [00:00<?, ?it/s]

    Epoch 60 Accuracy 90.69% / Best Accuracy: 91.03%


eval:   0%|          | 0/156 [00:00<?, ?it/s]

Sparse model has accuracy=91.03% after fintuning
The finetuned pruned model:
MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(8, 48, kernel_s

In [17]:
torch.save(finetuned_model, './model/part2_model.pth')


In [None]:
# finetune the pruned model here
# finetuned_model = finetune(finetuned_model, EPOCHS)

# finetuned_macs, finetuned_nparams = tp.utils.count_ops_and_params(finetuned_model, example_inputs)
# finetuned_MFLOPs = finetuned_macs/1e6
# print("The finetuned pruned model:")
# print(finetuned_model)