## Setup

In [1]:
!pip install "nvidia-modelopt[all]" -U --extra-index-url https://pypi.nvidia.com

Looking in indexes: https://pypi.org/simple, https://pypi.nvidia.com
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [2]:
import os
import time
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import modelopt.torch.prune as mtp
import modelopt.torch.opt as mto

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device=}")

device=device(type='cuda')


## Get CIFAR-10 train and test sets

In [3]:
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_loader = DataLoader(
    datasets.CIFAR10(root="./data", train=True, download=True, transform=transform),
    batch_size=128, shuffle=True
)

test_loader = DataLoader(
    datasets.CIFAR10(root="./data", train=False, download=True, transform=transform),
    batch_size=256
)

## Adjust ResNet18 network for CIFAR-10 dataset

In [4]:
def get_resnet18_for_cifar10():
    model = models.resnet18(weights=None, num_classes=10)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    return model.to(device)

full_model = get_resnet18_for_cifar10()

## Define Train and Evaluate functions

In [5]:
def train(model, loader, epochs, lr=0.01, save_path="model.pth", silent=False):
    if os.path.exists(save_path):
        if not silent:
            print(f"Model already trained. Loading from {save_path}")
        model.load_state_dict(torch.load(save_path))
        return

    # no saved model found. training from given model state

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    model.train()

    for epoch in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
        if not silent:
            print(f"Epoch {epoch+1}: loss={loss.item():.4f}")

    torch.save(model.state_dict(), save_path)
    if not silent:
        print(f"Training complete. Model saved to {save_path}")

In [6]:
def evaluate(model):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total

## Define helper functions to measure latency

In [7]:
class Timer:
    def __init__(self):
        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            self.starter = torch.cuda.Event(enable_timing=True)
            self.ender = torch.cuda.Event(enable_timing=True)

    def start(self):
        if self.use_cuda:
            self.starter.record()
        else:
            self.start_time = time.time()

    def stop(self):
        if self.use_cuda:
            self.ender.record()
            torch.cuda.synchronize()
            return self.starter.elapsed_time(self.ender)  # ms
        else:
            return (time.time() - self.start_time) * 1000  # ms

In [8]:
def estimate_latency(model, example_inputs, repetitions=50):
    timer = Timer()
    timings = np.zeros((repetitions, 1))

    # warm-up
    for _ in range(5):
        _ = model(example_inputs)

    with torch.no_grad():
        for rep in range(repetitions):
            timer.start()
            _ = model(example_inputs)
            elapsed = timer.stop()
            timings[rep] = elapsed

    return np.mean(timings), np.std(timings)

## Train and Evaluate full model

In [9]:
train(full_model, train_loader, epochs=10, save_path="full_model.pth")
accuracy_full = evaluate(full_model)

example_input = torch.rand(128, 3, 32, 32).to(device)
latency_mu, latency_std = estimate_latency(full_model, example_input)
print(f"[full model] \t\tLatency: {latency_mu:.2f} ± {latency_std:.2f} ms \tAccuracy: {accuracy_full*100:.2f}%")

Model already trained. Loading from full_model.pth
[full model] 		Latency: 16.74 ± 0.06 ms 	Accuracy: 76.85%


## Prune

In [10]:
# clone full model before pruning
pruned_model = copy.deepcopy(full_model)
pruned_model = pruned_model.to(device)

# set which layers to skip pruning. important to keep final classifier layer
ignored_layers = []
for m in pruned_model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 10:
        ignored_layers.append(m)

In [11]:

   
# iterative pruning
iterative_steps = 10


In [12]:
for iter in range(iterative_steps):
    # prune
    prune_constraints = {"params": f"{100 - (iter+1)*10}%"}

    pruned_model, prune_res = mtp.prune(
        model=copy.deepcopy(full_model).to(device),
        mode="fastnas",
        constraints=prune_constraints,
        dummy_input=example_input,
        config={
            "data_loader": train_loader,  # training data is used for calibrating BN layers
            "score_func": evaluate,  # validation score is used to rank the subnets
            # checkpoint to store the search state and resume or re-run the search with different constraint
            "checkpoint": f"modelopt_fastnas_search_checkpoint_{iter}.pth",
        },
    )
    
    # evaluate after prune
    acc_before = evaluate(pruned_model)
    # fine-tune pruned model
    train(pruned_model, train_loader, epochs=1, save_path=f"pruned_model_mtp_{iter}.pth", silent=True)
    # evaluate after fine-tune
    acc_after = evaluate(pruned_model)
    latency_mu, latency_std = estimate_latency(pruned_model, example_input)
    current_pruning_ratio = 1 / iterative_steps * (iter + 1)
    print(f"[pruned model] \tPrun constraints: {prune_constraints['params']}, \tLatency: {latency_mu:.2f} ± {latency_std:.2f} ms \tAccuracy pruned: {acc_before*100:.2f}%\tFinetuned: {acc_after*100:.2f}%")

    mto.save(pruned_model, f"modelopt_pruned_model_iter_{iter}.pth")




Profiling the following subnets from the given model: ('min', 'centroid', 'max').
--------------------------------------------------------------------------------


[3m                                                                             [0m
[3m                              Profiling Results                              [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mConstraint  [0m[1m [0m┃[1m [0m[1mmin         [0m[1m [0m┃[1m [0m[1mcentroid    [0m[1m [0m┃[1m [0m[1mmax         [0m[1m [0m┃[1m [0m[1mmax/min ratio[0m[1m [0m┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ flops        │ 7.46G        │ 21.52G       │ 71.09G       │ 9.54          │
│ params       │ 586.60K      │ 4.60M        │ 11.16M       │ 19.03         │
└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘
[3m                                              [0m
[3m            Constraints Evaluation            [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃[1m              [0m┃[1m              [0m┃[1m [0m[1mSatisfiable [0m[1m [

Collecting pre-search statistics: 100%|██████████| 74/74 [04:23<00:00,  3.56s/it, cur=layer4.1.conv1.out_channels(512/512): 0.00] 
[num_satisfied] = 16:   0%|          | 16/5000 [00:14<1:14:05,  1.12it/s]


[best_subnet_constraints] = {'params': '9.69M', 'flops': '68.07G'}
[pruned model] 	Prun constraints: 90%, 	Latency: 17.51 ± 0.09 ms 	Accuracy pruned: 78.55%	Finetuned: 78.10%

Profiling the following subnets from the given model: ('min', 'centroid', 'max').
--------------------------------------------------------------------------------


[3m                                                                             [0m
[3m                              Profiling Results                              [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mConstraint  [0m[1m [0m┃[1m [0m[1mmin         [0m[1m [0m┃[1m [0m[1mcentroid    [0m[1m [0m┃[1m [0m[1mmax         [0m[1m [0m┃[1m [0m[1mmax/min ratio[0m[1m [0m┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ flops        │ 7.46G        │ 21.52G       │ 71.09G       │ 9.54          │
│ params       │ 586.60K      │ 4.60M        │ 11.16M       │ 19.03         │
└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘
[3m                                              [0m
[3m            Constraints Evaluation            [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃[1m              [0m┃[1m              [0m┃[1m [0m[1mSatisfiable [0m[1m [

Collecting pre-search statistics: 100%|██████████| 74/74 [04:08<00:00,  3.36s/it, cur=layer4.1.conv1.out_channels(512/512): 0.00]
[num_satisfied] = 10:   0%|          | 16/5000 [00:12<1:04:53,  1.28it/s]


[best_subnet_constraints] = {'params': '8.66M', 'flops': '65.96G'}
[pruned model] 	Prun constraints: 80%, 	Latency: 16.50 ± 0.04 ms 	Accuracy pruned: 77.54%	Finetuned: 76.08%

Profiling the following subnets from the given model: ('min', 'centroid', 'max').
--------------------------------------------------------------------------------


[3m                                                                             [0m
[3m                              Profiling Results                              [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mConstraint  [0m[1m [0m┃[1m [0m[1mmin         [0m[1m [0m┃[1m [0m[1mcentroid    [0m[1m [0m┃[1m [0m[1mmax         [0m[1m [0m┃[1m [0m[1mmax/min ratio[0m[1m [0m┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ flops        │ 7.46G        │ 21.52G       │ 71.09G       │ 9.54          │
│ params       │ 586.60K      │ 4.60M        │ 11.16M       │ 19.03         │
└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘
[3m                                              [0m
[3m            Constraints Evaluation            [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃[1m              [0m┃[1m              [0m┃[1m [0m[1mSatisfiable [0m[1m [

Collecting pre-search statistics: 100%|██████████| 74/74 [04:15<00:00,  3.45s/it, cur=layer4.1.conv1.out_channels(512/512): 0.00] 
[num_satisfied] = 11:   0%|          | 16/5000 [00:12<1:04:30,  1.29it/s]


[best_subnet_constraints] = {'params': '7.77M', 'flops': '64.15G'}
[pruned model] 	Prun constraints: 70%, 	Latency: 16.25 ± 0.03 ms 	Accuracy pruned: 77.34%	Finetuned: 76.61%

Profiling the following subnets from the given model: ('min', 'centroid', 'max').
--------------------------------------------------------------------------------


[3m                                                                             [0m
[3m                              Profiling Results                              [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mConstraint  [0m[1m [0m┃[1m [0m[1mmin         [0m[1m [0m┃[1m [0m[1mcentroid    [0m[1m [0m┃[1m [0m[1mmax         [0m[1m [0m┃[1m [0m[1mmax/min ratio[0m[1m [0m┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ flops        │ 7.46G        │ 21.52G       │ 71.09G       │ 9.54          │
│ params       │ 586.60K      │ 4.60M        │ 11.16M       │ 19.03         │
└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘
[3m                                              [0m
[3m            Constraints Evaluation            [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃[1m              [0m┃[1m              [0m┃[1m [0m[1mSatisfiable [0m[1m [

Collecting pre-search statistics: 100%|██████████| 74/74 [04:10<00:00,  3.38s/it, cur=layer4.1.conv1.out_channels(512/512): 0.00] 
[num_satisfied] = 11:   0%|          | 16/5000 [00:12<1:03:15,  1.31it/s]


[best_subnet_constraints] = {'params': '6.41M', 'flops': '60.68G'}
[pruned model] 	Prun constraints: 60%, 	Latency: 16.47 ± 0.04 ms 	Accuracy pruned: 74.26%	Finetuned: 67.26%

Profiling the following subnets from the given model: ('min', 'centroid', 'max').
--------------------------------------------------------------------------------


[3m                                                                             [0m
[3m                              Profiling Results                              [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mConstraint  [0m[1m [0m┃[1m [0m[1mmin         [0m[1m [0m┃[1m [0m[1mcentroid    [0m[1m [0m┃[1m [0m[1mmax         [0m[1m [0m┃[1m [0m[1mmax/min ratio[0m[1m [0m┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ flops        │ 7.46G        │ 21.52G       │ 71.09G       │ 9.54          │
│ params       │ 586.60K      │ 4.60M        │ 11.16M       │ 19.03         │
└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘
[3m                                              [0m
[3m            Constraints Evaluation            [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃[1m              [0m┃[1m              [0m┃[1m [0m[1mSatisfiable [0m[1m [

Collecting pre-search statistics: 100%|██████████| 74/74 [04:08<00:00,  3.36s/it, cur=layer4.1.conv1.out_channels(512/512): 0.00] 
[num_satisfied] = 10:   0%|          | 16/5000 [00:13<1:11:16,  1.17it/s]


[best_subnet_constraints] = {'params': '5.52M', 'flops': '57.96G'}
[pruned model] 	Prun constraints: 50%, 	Latency: 15.41 ± 0.03 ms 	Accuracy pruned: 69.86%	Finetuned: 69.02%

Profiling the following subnets from the given model: ('min', 'centroid', 'max').
--------------------------------------------------------------------------------


[3m                                                                             [0m
[3m                              Profiling Results                              [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mConstraint  [0m[1m [0m┃[1m [0m[1mmin         [0m[1m [0m┃[1m [0m[1mcentroid    [0m[1m [0m┃[1m [0m[1mmax         [0m[1m [0m┃[1m [0m[1mmax/min ratio[0m[1m [0m┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ flops        │ 7.46G        │ 21.52G       │ 71.09G       │ 9.54          │
│ params       │ 586.60K      │ 4.60M        │ 11.16M       │ 19.03         │
└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘
[3m                                              [0m
[3m            Constraints Evaluation            [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃[1m              [0m┃[1m              [0m┃[1m [0m[1mSatisfiable [0m[1m [

Collecting pre-search statistics: 100%|██████████| 74/74 [04:05<00:00,  3.31s/it, cur=layer4.1.conv1.out_channels(512/512): 0.00]
[num_satisfied] = 10:   0%|          | 16/5000 [00:12<1:03:22,  1.31it/s]


[best_subnet_constraints] = {'params': '4.39M', 'flops': '53.17G'}
[pruned model] 	Prun constraints: 40%, 	Latency: 14.81 ± 0.02 ms 	Accuracy pruned: 55.97%	Finetuned: 71.66%

Profiling the following subnets from the given model: ('min', 'centroid', 'max').
--------------------------------------------------------------------------------


[3m                                                                             [0m
[3m                              Profiling Results                              [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mConstraint  [0m[1m [0m┃[1m [0m[1mmin         [0m[1m [0m┃[1m [0m[1mcentroid    [0m[1m [0m┃[1m [0m[1mmax         [0m[1m [0m┃[1m [0m[1mmax/min ratio[0m[1m [0m┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ flops        │ 7.46G        │ 21.52G       │ 71.09G       │ 9.54          │
│ params       │ 586.60K      │ 4.60M        │ 11.16M       │ 19.03         │
└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘
[3m                                              [0m
[3m            Constraints Evaluation            [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃[1m              [0m┃[1m              [0m┃[1m [0m[1mSatisfiable [0m[1m [

Collecting pre-search statistics: 100%|██████████| 74/74 [04:14<00:00,  3.44s/it, cur=layer4.1.conv1.out_channels(512/512): 0.00] 
[num_satisfied] = 8:   0%|          | 16/5000 [00:11<1:00:09,  1.38it/s]


[best_subnet_constraints] = {'params': '3.28M', 'flops': '45.36G'}
[pruned model] 	Prun constraints: 30%, 	Latency: 18.69 ± 0.04 ms 	Accuracy pruned: 38.84%	Finetuned: 72.00%

Profiling the following subnets from the given model: ('min', 'centroid', 'max').
--------------------------------------------------------------------------------


[3m                                                                             [0m
[3m                              Profiling Results                              [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mConstraint  [0m[1m [0m┃[1m [0m[1mmin         [0m[1m [0m┃[1m [0m[1mcentroid    [0m[1m [0m┃[1m [0m[1mmax         [0m[1m [0m┃[1m [0m[1mmax/min ratio[0m[1m [0m┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ flops        │ 7.46G        │ 21.52G       │ 71.09G       │ 9.54          │
│ params       │ 586.60K      │ 4.60M        │ 11.16M       │ 19.03         │
└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘
[3m                                              [0m
[3m            Constraints Evaluation            [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃[1m              [0m┃[1m              [0m┃[1m [0m[1mSatisfiable [0m[1m [

Collecting pre-search statistics: 100%|██████████| 74/74 [04:06<00:00,  3.33s/it, cur=layer4.1.conv1.out_channels(512/512): 0.00] 
[num_satisfied] = 9:   0%|          | 16/5000 [00:09<51:45,  1.61it/s]  


[best_subnet_constraints] = {'params': '2.18M', 'flops': '40.99G'}
[pruned model] 	Prun constraints: 20%, 	Latency: 17.29 ± 0.06 ms 	Accuracy pruned: 32.16%	Finetuned: 72.17%

Profiling the following subnets from the given model: ('min', 'centroid', 'max').
--------------------------------------------------------------------------------


[3m                                                                             [0m
[3m                              Profiling Results                              [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mConstraint  [0m[1m [0m┃[1m [0m[1mmin         [0m[1m [0m┃[1m [0m[1mcentroid    [0m[1m [0m┃[1m [0m[1mmax         [0m[1m [0m┃[1m [0m[1mmax/min ratio[0m[1m [0m┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ flops        │ 7.46G        │ 21.52G       │ 71.09G       │ 9.54          │
│ params       │ 586.60K      │ 4.60M        │ 11.16M       │ 19.03         │
└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘
[3m                                              [0m
[3m            Constraints Evaluation            [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃[1m              [0m┃[1m              [0m┃[1m [0m[1mSatisfiable [0m[1m [

Collecting pre-search statistics: 100%|██████████| 74/74 [04:11<00:00,  3.39s/it, cur=layer4.1.conv1.out_channels(512/512): 0.00] 
[num_satisfied] = 9:   0%|          | 16/5000 [00:06<33:27,  2.48it/s]  


[best_subnet_constraints] = {'params': '1.04M', 'flops': '19.39G'}
[pruned model] 	Prun constraints: 10%, 	Latency: 13.73 ± 0.03 ms 	Accuracy pruned: 10.82%	Finetuned: 73.66%

Profiling the following subnets from the given model: ('min', 'centroid', 'max').
--------------------------------------------------------------------------------


[3m                                                                             [0m
[3m                              Profiling Results                              [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mConstraint  [0m[1m [0m┃[1m [0m[1mmin         [0m[1m [0m┃[1m [0m[1mcentroid    [0m[1m [0m┃[1m [0m[1mmax         [0m[1m [0m┃[1m [0m[1mmax/min ratio[0m[1m [0m┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ flops        │ 7.46G        │ 21.52G       │ 71.09G       │ 9.54          │
│ params       │ 586.60K      │ 4.60M        │ 11.16M       │ 19.03         │
└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘
[3m                                              [0m
[3m            Constraints Evaluation            [0m
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓
┃[1m              [0m┃[1m              [0m┃[1m [0m[1mSatisfiable [0m[1m [

ValueError: NOT all constraints can be satisfied within the search space, see above!

In [None]:
# prune
prune_constraints = {"params": "1%"}

pruned_model, prune_res = mtp.prune(
    model=copy.deepcopy(full_model).to(device),
    mode="fastnas",
    constraints=prune_constraints,
    dummy_input=example_input,
    config={
        "data_loader": train_loader,  # training data is used for calibrating BN layers
        "score_func": evaluate,  # validation score is used to rank the subnets
        # checkpoint to store the search state and resume or re-run the search with different constraint
        "checkpoint": f"modelopt_fastnas_search_checkpoint_9.pth",
    },
)

# evaluate after prune
acc_before = evaluate(pruned_model)
# fine-tune pruned model
train(pruned_model, train_loader, epochs=1, save_path=f"pruned_model_mtp_9.pth", silent=True)
# evaluate after fine-tune
acc_after = evaluate(pruned_model)
latency_mu, latency_std = estimate_latency(pruned_model, example_input)
current_pruning_ratio = 1 / iterative_steps * (iter + 1)
print(f"[pruned model] \tPrun constraints: {prune_constraints['params']}, \tLatency: {latency_mu:.2f} ± {latency_std:.2f} ms \tAccuracy pruned: {acc_before*100:.2f}%\tFinetuned: {acc_after*100:.2f}%")

mto.save(pruned_model, f"modelopt_pruned_model_iter_{iter}.pth")

## Extra fine-tune last pruned model

In [None]:
train(pruned_model, train_loader, epochs=5, save_path=f"pruned_model_final_tuning.pth")

In [None]:
accuracy_final = evaluate(pruned_model)
print(f"Pruned extra fine-tuned model accuracy: {accuracy_final*100:.2f}%")