In [None]:
# Install the third-party packages imported later in this notebook
# (tensorrt, pytorch_bench, torch_pruning). %pip targets the running kernel's environment.
%pip install tensorrt
%pip install pytorch-bench
%pip install torch-pruning

In [None]:
# Build and install torch2trt from source (it is imported below as `from torch2trt import torch2trt`).
# NOTE: %cd changes the kernel's working directory to ./torch2trt for every later cell,
# so relative paths used afterwards (./data, results_experiments/, final_model.pth)
# are resolved inside the cloned repository.
!git clone https://github.com/NVIDIA-AI-IOT/torch2trt
%cd torch2trt
!python setup.py install

In [3]:
#@title Main function
"""
Apply pruning, knowledge distillation and quantization to a model
"""

# Imports
import torch_pruning as tp
import torch.nn.functional as F
import torch
import os
import copy
import random
import numpy as np
import torch
from torch import nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from torchvision.datasets import *
from torchvision.transforms import *
from tqdm.auto import tqdm
from functools import partial
import logging
from datetime import datetime
from torch2trt import torch2trt
from pytorch_bench import get_model_size
# Fail fast: the code below unconditionally uses a CUDA device.
assert torch.cuda.is_available(), "Cuda Not Available!"

# Logging: configure the root logger at INFO level, printing only the message text.
logging.basicConfig(level=logging.INFO, format='%(message)s')
# Device: module-level global read by evaluate(), train(), train_kd(),
# apply_pruning_and_kd() and optimize().
device = torch.device("cuda")

def setup_logging(base_path):
    """
    Create `base_path` and make sure log records are also written to `<base_path>/log.txt`.

    `logging.basicConfig` does nothing once the root logger already has handlers
    (which is the case after the module-level `basicConfig` call), so the file
    handler is attached to the root logger explicitly instead.

    Args:
        base_path: directory that receives `log.txt`; created if it does not exist.

    Returns:
        A timestamp string (`YYYY-MM-DD_HH-MM-SS`) identifying the run.
    """
    os.makedirs(base_path, exist_ok=True)
    log_path = os.path.abspath(f"{base_path}/log.txt")

    root_logger = logging.getLogger()
    # INFO records must not be filtered out by a stricter root level.
    if root_logger.getEffectiveLevel() > logging.INFO:
        root_logger.setLevel(logging.INFO)

    # Attach the file handler only once per file so repeated calls do not duplicate lines.
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
        for handler in root_logger.handlers
    )
    if not already_attached:
        root_logger.addHandler(logging.FileHandler(log_path))

    run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    logging.info(f"Run ID: {run_id}")
    return run_id

# Evaluation loop
@torch.no_grad()
def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    verbose=True,
) -> "tuple[float, float]":
    """
    Evaluate the model on the given dataset

    The model is switched to eval mode and is left in eval mode on return.
    Batches are moved to the module-level `device` before inference.

    Args:
        model: classifier returning one row of logits per sample.
        dataloader: yields `(inputs, targets)` batches.
        verbose: when falsy, the tqdm progress bar is disabled.

    Returns:
        `(accuracy, loss)` as Python floats: accuracy in percent and the
        cross-entropy loss averaged over all samples.

    Raises:
        ValueError: if the dataloader yields no samples (the averages would be
            undefined).
    """
    model.eval()

    num_samples = 0
    num_correct = 0
    loss = 0

    for inputs, targets in tqdm(dataloader, desc="eval", leave=False, disable=not verbose):
        # Move the data from CPU to GPU
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Inference
        outputs = model(inputs)
        # Sum (not mean) per batch so the final average is exact even if the
        # last batch is smaller than the others
        loss += F.cross_entropy(outputs, targets, reduction="sum")
        # Convert logits to class indices
        predictions = outputs.argmax(dim=1)
        # Update metrics
        num_samples += targets.size(0)
        num_correct += (predictions == targets).sum()

    if num_samples == 0:
        raise ValueError("evaluate() received a dataloader that yields no samples")

    return (num_correct / num_samples * 100).item(), (loss / num_samples).item()

# training loop
def train(
    model: nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    epochs: int,
    lr: float,
    weight_decay=5e-4,
    callbacks=None,
    save=None,
    save_only_state_dict=False,
) -> None:
    """
    Training loop

    Trains `model` with SGD (momentum 0.9) and a cosine-annealed learning rate,
    evaluates on `test_loader` after every epoch and, once training is done,
    reloads the weights of the epoch with the best validation accuracy.

    Args:
        model: network to train in place.
        train_loader: yields `(inputs, targets)` training batches.
        test_loader: loader passed to `evaluate` after each epoch.
        epochs: number of epochs; also the period of the cosine schedule.
        lr: initial learning rate.
        weight_decay: SGD weight decay.
        callbacks: optional iterable of zero-argument callables, each invoked
            after every optimizer step.
        save: optional path; when set, the best model is written there.
        save_only_state_dict: save only `state_dict()` instead of the whole module.
    """
    optimizer = torch.optim.SGD(
        model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = nn.CrossEntropyLoss()
    best_acc = -1
    # Stays None when no epoch produced a usable accuracy (epochs == 0, NaN accuracy)
    best_state_dict = None

    for epoch in range(epochs):
        model.train()
        for inputs, targets in tqdm(train_loader, leave=False):
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Reset the gradients (from the last iteration)
            optimizer.zero_grad()

            # Forward inference
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Backward propagation
            loss.backward()

            # Update optimizer
            optimizer.step()

            if callbacks is not None:
                for callback in callbacks:
                    callback()

        acc, val_loss = evaluate(model, test_loader)
        logging.info(
            f'Epoch {epoch + 1}/{epochs} | Val acc: {acc:.2f} | Val loss: {val_loss:.4f} | LR: {optimizer.param_groups[0]["lr"]:.6f}')

        if best_acc < acc:
            # deepcopy: state_dict() returns references that later steps would overwrite
            best_state_dict = copy.deepcopy(model.state_dict())
            best_acc = acc
        # Update LR scheduler
        scheduler.step()

    # Restore the best weights; keep the current ones if no checkpoint was recorded
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)
    if save:
        if save_only_state_dict:
            torch.save(model.state_dict(), save)
        else:
            torch.save(model, save)
    logging.info(f'Best val acc: {best_acc:.2f}')


def get_pruner(model, example_input, num_classes):
    """
    Build a Torch-Pruning `GroupNormPruner` for `model`.

    Every `Linear` layer with `out_features == num_classes` and every
    convolution with `out_channels == num_classes` is treated as an output
    layer and excluded from pruning.

    Args:
        model: network to be pruned.
        example_input: sample input handed to the pruner.
        num_classes: number of classes, used to recognise the output layers.

    Returns:
        The configured pruner (`iterative_steps=400`, `pruning_ratio=1.0`,
        local isomorphic pruning with L2 group-norm importance).
    """
    def is_output_layer(module):
        if isinstance(module, torch.nn.Linear):
            return module.out_features == num_classes
        if isinstance(module, torch.nn.modules.conv._ConvNd):
            return module.out_channels == num_classes
        return False

    # ignore output layers
    output_layers = [module for module in model.modules() if is_output_layer(module)]

    return tp.pruner.GroupNormPruner(
        model,
        example_input,
        importance=tp.importance.GroupNormImportance(p=2),
        isomorphic=True,
        global_pruning=False,
        iterative_steps=400,
        pruning_ratio=1.0,
        pruning_ratio_dict={},
        ignored_layers=output_layers,
        unwrapped_parameters=[],
    )

# prune until the requested speed-up is reached
def progressive_pruning_speedup(pruner, model, speed_up, example_inputs):
    """
    Step `pruner` until the op count of `model` has dropped by a factor of `speed_up`.

    The speed-up is (ops before pruning) / (ops after pruning), both measured
    with `tp.utils.count_ops_and_params`.

    Args:
        pruner: object exposing `step(interactive=False)`; if it also exposes
            `iterative_steps`, that value bounds the number of steps taken here.
        model: network being pruned in place; switched to eval mode.
        speed_up: target ratio.
        example_inputs: sample input used for op counting.

    Returns:
        The speed-up actually reached. It can be lower than `speed_up` when the
        pruner's step budget is used up first (a warning is logged).
    """
    model.eval()
    base_ops, _ = tp.utils.count_ops_and_params(
        model, example_inputs=example_inputs)

    # NOTE(review): Torch-Pruning pruners are expected to stop removing channels
    # after `iterative_steps` calls to step(); without a bound the loop below
    # would spin forever on an unreachable target. Confirm against the
    # installed torch-pruning version.
    max_steps = getattr(pruner, "iterative_steps", None)
    steps_taken = 0

    current_speed_up = 1
    while current_speed_up < speed_up:
        if max_steps is not None and steps_taken >= max_steps:
            logging.warning(
                f"Pruning stopped after {steps_taken} steps: reached speed-up "
                f"{current_speed_up:.2f}, requested {speed_up:.2f}")
            break
        pruner.step(interactive=False)
        steps_taken += 1
        pruned_ops, _ = tp.utils.count_ops_and_params(
            model, example_inputs=example_inputs)
        if pruned_ops <= 0:
            # Nothing left to count: avoid dividing by zero
            current_speed_up = float("inf")
            break
        current_speed_up = float(base_ops) / pruned_ops
    return current_speed_up


# prune until the requested compression ratio is reached
def progressive_pruning_compression_ratio(pruner, model, compression_ratio, example_inputs):
    """
    Step `pruner` until the parameter count of `model` has dropped by a factor of
    `compression_ratio`.

    The compression ratio is (parameters before pruning) / (parameters after
    pruning), both measured with `tp.utils.count_ops_and_params`.

    Args:
        pruner: object exposing `step(interactive=False)`; if it also exposes
            `iterative_steps`, that value bounds the number of steps taken here.
        model: network being pruned in place; switched to eval mode.
        compression_ratio: target ratio.
        example_inputs: sample input used for parameter counting.

    Returns:
        The compression ratio actually reached. It can be lower than
        `compression_ratio` when the pruner's step budget is used up first
        (a warning is logged).
    """
    model.eval()
    _, base_params = tp.utils.count_ops_and_params(
        model, example_inputs=example_inputs)

    # NOTE(review): Torch-Pruning pruners are expected to stop removing channels
    # after `iterative_steps` calls to step(); without a bound the loop below
    # would spin forever on an unreachable target. Confirm against the
    # installed torch-pruning version.
    max_steps = getattr(pruner, "iterative_steps", None)
    steps_taken = 0

    current_compression_ratio = 1
    while current_compression_ratio < compression_ratio:
        if max_steps is not None and steps_taken >= max_steps:
            logging.warning(
                f"Pruning stopped after {steps_taken} steps: reached compression ratio "
                f"{current_compression_ratio:.2f}, requested {compression_ratio:.2f}")
            break
        pruner.step(interactive=False)
        steps_taken += 1
        _, pruned_params = tp.utils.count_ops_and_params(
            model, example_inputs=example_inputs)
        if pruned_params <= 0:
            # Nothing left to count: avoid dividing by zero
            current_compression_ratio = float("inf")
            break
        current_compression_ratio = float(base_params) / pruned_params
    return current_compression_ratio

# knowledge-distillation training loop
def train_kd(
    model_student: nn.Module,
    model_teacher: nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    epochs: int,
    lr: float,
    temperature: float,
    alpha: float,
    weight_decay=5e-4,
    callbacks=None,
    save=None,
    save_only_state_dict=False,
) -> None:
    """
    Train `model_student` with knowledge distillation from `model_teacher`.

    The loss is
    `alpha * T^2 * KL(softmax(teacher / T) || softmax(student / T)) + (1 - alpha) * CE(student, targets)`
    with `T = temperature`. The student is optimised with SGD (momentum 0.9)
    and a cosine-annealed learning rate, evaluated on `test_loader` after every
    epoch, and finally reset to the weights of its best epoch.

    The teacher only provides targets: it runs in eval mode and without
    gradient tracking, and its previous train/eval mode is restored afterwards.

    Args:
        model_student: network to train in place.
        model_teacher: frozen reference network.
        train_loader: yields `(inputs, targets)` training batches.
        test_loader: loader passed to `evaluate` after each epoch.
        epochs: number of epochs; also the period of the cosine schedule.
        lr: initial learning rate.
        temperature: softmax temperature applied to both sets of logits.
        alpha: weight of the distillation term; `1 - alpha` weights the
            cross-entropy term.
        weight_decay: SGD weight decay.
        callbacks: optional iterable of zero-argument callables, each invoked
            after every optimizer step.
        save: optional path; when set, the best student is written there.
        save_only_state_dict: save only `state_dict()` instead of the whole module.
    """
    optimizer = torch.optim.SGD(
        model_student.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = nn.CrossEntropyLoss()
    kd_criterion = nn.KLDivLoss(reduction="batchmean")
    best_acc = -1
    # Stays None when no epoch produced a usable accuracy (epochs == 0, NaN accuracy)
    best_state_dict = None

    # The teacher must behave deterministically and keep its BatchNorm statistics
    teacher_was_training = model_teacher.training
    model_teacher.eval()

    for epoch in range(epochs):
        model_student.train()
        for inputs, targets in tqdm(train_loader, leave=False):
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Reset the gradients (from the last iteration)
            optimizer.zero_grad()

            # Forward inference
            out_student = model_student(inputs)
            # No graph for the teacher: it is never updated
            with torch.no_grad():
                out_teacher = model_teacher(inputs)

            # kd loss
            predict_student = F.log_softmax(out_student / temperature, dim=1)
            predict_teacher = F.softmax(out_teacher / temperature, dim=1)
            loss = kd_criterion(predict_student, predict_teacher) * (alpha * temperature * temperature) + criterion(out_student, targets) * (1 - alpha)

            loss.backward()

            # Update optimizer
            optimizer.step()

            if callbacks is not None:
                for callback in callbacks:
                    callback()

        acc, val_loss = evaluate(model_student, test_loader)
        logging.info(
            f'KD - Epoch {epoch + 1}/{epochs} | Val acc: {acc:.2f} | Val loss: {val_loss:.4f} | LR: {optimizer.param_groups[0]["lr"]:.6f}')

        if best_acc < acc:
            # deepcopy: state_dict() returns references that later steps would overwrite
            best_state_dict = copy.deepcopy(model_student.state_dict())
            best_acc = acc
        # Update LR scheduler
        scheduler.step()

    model_teacher.train(teacher_was_training)

    # Restore the best weights; keep the current ones if no checkpoint was recorded
    if best_state_dict is not None:
        model_student.load_state_dict(best_state_dict)
    if save:
        # the best model is the one that gets saved
        if save_only_state_dict:
            torch.save(model_student.state_dict(), save)
        else:
            torch.save(model_student, save)
    logging.info(f'Best val acc after KD: {best_acc:.2f}')



def get_compression_ratio_and_bitwidth_from_compression_ratio(compression_ratio):
    """
    Split a total compression ratio into a pruning ratio and a bit width.

    Targets up to 2 are handled by pruning alone at 32 bits; targets in (2, 4]
    use 16 bits, which accounts for a factor of 2; larger targets use 8 bits,
    which accounts for a factor of 4.

    Returns:
        `(pruning_compression_ratio, bitwidth)`.
    """
    # Guard clauses, ordered from the smallest target upwards
    if compression_ratio <= 2:
        return compression_ratio, 32
    if 2 < compression_ratio <= 4:
        pruning_share = compression_ratio / 2
        return pruning_share, 16
    pruning_share = compression_ratio / 4
    return pruning_share, 8


def get_speed_up_and_bitwidth_from_speed_up(speed_up):
    """
    Split a total speed-up target into a pruning speed-up and a bit width.

    Targets up to 2 are handled by pruning alone at 32 bits; targets in (2, 4]
    use 16 bits, which is counted as a factor of 2; larger targets use 8 bits,
    which is counted as a factor of 4.

    Returns:
        `(pruning_speed_up, bitwidth)`.
    """
    # Guard clauses, ordered from the smallest target upwards
    if speed_up <= 2:
        return speed_up, 32
    if 2 < speed_up <= 4:
        pruning_share = speed_up / 2
        return pruning_share, 16
    pruning_share = speed_up / 4
    return pruning_share, 8


def apply_pruning_and_kd(model: nn.Module, train_loader: DataLoader, test_loader: DataLoader,
                         example_input: torch.Tensor, num_classes: int,
                         epochs: int = 120, lr: float = 0.01, temperature: float = 4,
                         alpha: float = 0.9, compression_ratio: float = None,
                         speed_up: float = None, random_seed: int = 42) -> nn.Module:
    """
    Prune `model` to a target ratio, then recover accuracy with knowledge distillation.

    A deep copy of the unpruned model serves as the teacher. The model is pruned
    in place until `compression_ratio` (parameter count) or, if that is not
    given, `speed_up` (op count) is reached; with neither target no pruning is
    done. The result is fine-tuned with `train_kd`, evaluated, and its
    `state_dict` is written to `final_model.pth` in the working directory.

    Args:
        model: network to compress; moved to the module-level `device`.
        train_loader: training batches for distillation.
        test_loader: batches used for all evaluations.
        example_input: sample input for op/parameter counting and the pruner.
        num_classes: number of classes (output layers are not pruned).
        epochs: distillation epochs.
        lr: distillation learning rate.
        temperature: distillation temperature.
        alpha: weight of the distillation loss term.
        compression_ratio: target parameters-before / parameters-after ratio.
        speed_up: target ops-before / ops-after ratio (used only when
            `compression_ratio` is falsy).
        random_seed: seed for `random`, NumPy and PyTorch.

    Returns:
        The pruned and fine-tuned model (the same object as `model` moved to
        `device`).
    """
    # Setup
    base_path = "results_experiments"
    setup_logging(base_path)
    logging.info(f"Model: {model.__class__.__name__}")

    # Set random seed
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # Prepare models
    model = model.to(device)
    example_input = example_input.to(device)
    # The teacher must be copied before pruning modifies the model in place
    model_teacher = copy.deepcopy(model)

    # Initial evaluation
    start_macs, start_params = tp.utils.count_ops_and_params(model, example_input)
    start_acc, start_loss = evaluate(model, test_loader)
    logging.info(f'Initial Model: MACs = {start_macs/1e6:.3f}M, Params = {start_params/1e6:.3f}M, Accuracy = {start_acc:.2f}%, Loss = {start_loss:.3f}')

    # Pruning
    pruner = get_pruner(model, example_input, num_classes)
    if compression_ratio:
        progressive_pruning_compression_ratio(pruner, model, compression_ratio, example_input)
    elif speed_up:
        progressive_pruning_speedup(pruner, model, speed_up, example_input)

    # Knowledge Distillation
    logging.info('Starting Knowledge Distillation')
    train_kd(model, model_teacher, train_loader, test_loader, epochs, lr, temperature, alpha)

    # Final evaluation (evaluate() reads the module-level device; its third
    # positional parameter is `verbose`, so the device must not be passed here)
    end_macs, end_params = tp.utils.count_ops_and_params(model, example_input)
    end_acc, end_loss = evaluate(model, test_loader)
    logging.info(f'Final Model: MACs = {end_macs/1e6:.3f}M, Params = {end_params/1e6:.3f}M, Accuracy = {end_acc:.2f}%, Loss = {end_loss:.3f}')

    # Save model
    torch.save(model.state_dict(), 'final_model.pth')

    return model


def optimize(model,
            traindataloader,
            testdataloader,
            example_input,
            num_classes,
            epochs=120,
            lr=0.01,
            temperature=4,
            alpha=0.9,
            compression_ratio=2,
            speed_up=None,
            bitwidth=16,
            wandb_project=None,
            random_seed=42):
    """
    Prune, distil and convert `model` to a TensorRT engine.

    Every call writes its artefacts (`log.txt`, `kd_model.pth`, `model_trt.pth`)
    to a fresh `results_experiments/<run id>/` folder.

    Args:
        model: network to optimise; moved to the module-level `device` and
            pruned in place.
        traindataloader: training batches (distillation and INT8 calibration).
        testdataloader: batches used for all evaluations.
        example_input: sample input for counting, pruning and TensorRT conversion.
        num_classes: number of classes (output layers are not pruned).
        epochs: distillation epochs.
        lr: distillation learning rate.
        temperature: distillation temperature.
        alpha: weight of the distillation loss term.
        compression_ratio: total target compression; split into a pruning
            ratio and a bit width. Takes precedence over `speed_up` for pruning.
        speed_up: total target speed-up; split into a pruning speed-up and a
            bit width. If both targets are given, pruning follows
            `compression_ratio` while the bit width comes from `speed_up`.
        bitwidth: 8, 16 or anything else for 32 bits. Only honoured when both
            `compression_ratio` and `speed_up` are falsy; otherwise it is
            derived from the target. In that case pruning is skipped.
        wandb_project: optional Weights & Biases project name; `wandb` is only
            imported when this is set.
        random_seed: seed for `random`, NumPy and PyTorch.

    Returns:
        The TensorRT module produced by `torch2trt`.
    """
    # every run logs its intermediate results into its own sub-folder
    run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    base_path = f"results_experiments/{run_id}"
    os.makedirs(base_path, exist_ok=True)

    # logging.basicConfig(filename=...) would be ignored here because the root
    # logger is already configured, so the run's log file is attached explicitly
    root_logger = logging.getLogger()
    if root_logger.getEffectiveLevel() > logging.INFO:
        root_logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(f"{base_path}/log.txt")
    root_logger.addHandler(file_handler)
    try:
        return _run_optimization(
            model, traindataloader, testdataloader, example_input, num_classes,
            epochs, lr, temperature, alpha, compression_ratio, speed_up,
            bitwidth, wandb_project, random_seed, base_path, run_id)
    finally:
        # Detach so that later runs do not keep writing into this run's file
        root_logger.removeHandler(file_handler)
        file_handler.close()


def _run_optimization(model, traindataloader, testdataloader, example_input, num_classes,
                      epochs, lr, temperature, alpha, compression_ratio, speed_up,
                      bitwidth, wandb_project, random_seed, base_path, run_id):
    """Run the prune / distil / TensorRT pipeline of `optimize`, writing into `base_path`."""
    logging.info(f"Run ID: {run_id}")
    logging.info(f"Model: {model.__class__.__name__}")

    # Model and example input must live on the same device as the batches
    model = model.to(device)
    example_input = example_input.to(device)

    if wandb_project:
        # Optional dependency: imported only when tracking is requested
        import wandb
        wandb.init(project=wandb_project)
        logging.info("Wandb initialized")
        logging.info(f"Wandb project: {wandb_project}")

    # Fix the seeds for reproducibility
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # Copy the initial model for KD
    model_teacher = copy.deepcopy(model)

    # Before pruning
    start_macs, start_params = tp.utils.count_ops_and_params(model, example_input)
    start_acc, start_loss = evaluate(model, testdataloader)
    logging.info(' ----- Initial Model: -----')
    logging.info(f'Number of MACs = {start_macs/1e6:.3f} M')
    logging.info(f'Number of Parameters = {start_params/1e6:.3f} M')
    logging.info(f'Accuracy = {start_acc:.2f} %')
    logging.info(f'Loss = {start_loss:.3f}')
    logging.info(' ---------------------------')
    if wandb_project:
        wandb.run.summary["start_macs (M)"] = f'{start_macs/1e6:.3f}'
        wandb.run.summary["start_params (M)"] = f'{start_params/1e6:.3f}'
        wandb.run.summary["start_acc (%)"] = f'{start_acc:.2f}'
        wandb.run.summary["start_loss"] = f'{start_loss:.3f}'

    # Split the total target between pruning and quantization
    if compression_ratio:
        compression_ratio, bitwidth = get_compression_ratio_and_bitwidth_from_compression_ratio(compression_ratio)
    if speed_up:
        speed_up, bitwidth = get_speed_up_and_bitwidth_from_speed_up(speed_up)

    logging.info('----- Pruning -----')
    if compression_ratio:
        pruner = get_pruner(model, example_input, num_classes)
        progressive_pruning_compression_ratio(pruner, model, compression_ratio, example_input)
    elif speed_up:
        pruner = get_pruner(model, example_input, num_classes)
        progressive_pruning_speedup(pruner, model, speed_up, example_input)
    else:
        logging.info('No compression_ratio or speed_up given: pruning skipped')

    # Fine tuning
    logging.info('----- Fine tuning with KD -----')
    train_kd(model, model_teacher, traindataloader, testdataloader, epochs=epochs, lr=lr, temperature=temperature, alpha=alpha,save=f'{base_path}/kd_model.pth')

    # Post fine tuning
    end_macs, end_params = tp.utils.count_ops_and_params(model, example_input)
    end_acc, end_loss = evaluate(model, testdataloader)
    logging.info('----- Results after fine tuning -----')
    logging.info(f'Number of Parameters: {start_params/1e6:.2f} M => {end_params/1e6:.2f} M')
    logging.info(f'MACs: {start_macs/1e6:.2f} M => {end_macs/1e6:.2f} M')
    logging.info(f'Accuracy: {start_acc:.2f} % => {end_acc:.2f} %')
    logging.info(f'Loss: {start_loss:.2f} => {end_loss:.2f}')
    if wandb_project:
        # log the values to wandb
        wandb.run.summary["best_acc"] = end_acc
        wandb.run.summary["best_loss"] = end_loss
        wandb.run.summary["end macs (M)"] = end_macs/1e6
        wandb.run.summary["end num_params (M)"] = end_params/1e6
        # NOTE(review): the /8e6 presumably converts bits to megabytes - confirm
        # the unit returned by pytorch_bench.get_model_size
        wandb.run.summary["size (MB)"] = get_model_size(model)/8e6

    # Quantization part
    logging.info('----- Quantization -----')
    # free cache memory
    torch.cuda.empty_cache()
    if bitwidth == 8:
        logging.info('Calibrating on train dataset...')
        # Individual input tensors taken from at most 2001 training batches
        calib_dataset = list()
        for i, img in enumerate(traindataloader):
            calib_dataset.extend(img[0])
            if i == 2000:
                break
        model_trt = torch2trt(model,[example_input], fp16_mode=True, int8_mode=True, int8_calib_dataset=calib_dataset, max_batch_size=128)
        compression_ratio_quant = 4
    elif bitwidth == 16:
        model_trt = torch2trt(model,[example_input], fp16_mode=True, max_batch_size=128)
        compression_ratio_quant = 2
    else: # run with fp32 if not quantization -> speed up from tensorrt inference engine
        model_trt = torch2trt(model,[example_input], max_batch_size=128)
        bitwidth = 32
        compression_ratio_quant = 1
    logging.info(f"Final Compression Ratio: {start_params/end_params*compression_ratio_quant:.2f}")
    logging.info(f"Bit width: {bitwidth}")
    torch.save(model_trt.state_dict(), f'{base_path}/model_trt.pth')

    return model_trt

### Example
For this example, we will train a model on the CIFAR-10 dataset and then compress it.

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim

# Set device
# NOTE: this rebinds the notebook-level `device` that evaluate(), train() and
# train_kd() read at call time.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define transforms: convert to tensor, then normalise every channel with mean 0.5 / std 0.5
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset (downloaded to ./data on first use)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)

# NOTE(review): the test split is also what train() uses to select the best
# checkpoint, so the accuracies reported later are not from held-out data.
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=0)

# Load pre-trained ResNet18 model
model = torchvision.models.resnet18(weights='DEFAULT')

# Modify the last fully connected layer for CIFAR-10 (10 classes)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)

# Move model to device
model = model.to(device)

# Train model for 10 epochs; train() reloads the best-epoch weights into `model` when it finishes
train(model, trainloader, testloader, epochs=10, lr=0.01)


In [None]:
# Evaluate the trained baseline model
acc, loss = evaluate(model, testloader)

print(f'Accuracy = {acc:.2f} Loss = {loss:.3f}')
# Benchmark helpers from pytorch-bench (measure_latency_gpu is used in later cells)
from pytorch_bench import benchmark, measure_latency_gpu

# Single random input of shape (1, 3, 32, 32) on the GPU, reused by the following cells
example_input = torch.randn(1, 3, 32, 32).cuda()
results = benchmark(model, example_input)

### Pruning
Apply the first optimization method: prune the model, then retrain it using knowledge distillation.

In [None]:
# Prune until the parameter count is halved (compression_ratio=2), then fine-tune for
# 10 epochs with a copy of the unpruned model as teacher. `model` is rebound to the result.
model = apply_pruning_and_kd(model, trainloader, testloader, epochs=10, lr=0.01, example_input=example_input, num_classes=10, compression_ratio=2)

In [7]:
# Measure the GPU latency of the pruned model on a single random input.
# Only the third returned value (printed as FPS) is used; the first two are ignored here.
example_input = torch.randn(1, 3, 32, 32).cuda()
x,y,fps = measure_latency_gpu(model, example_input)
print(fps)

344.0524477536497


In [None]:
# Evaluate the optimized (pruned + distilled) model
print('### RESULTS PRUNED MODEL ###')
acc, loss = evaluate(model, testloader)
print(f'Accuracy = {acc:.2f} Loss = {loss:.3f}')

# Same benchmark as for the baseline model, for comparison
results = benchmark(model, example_input)


### Quantization
Now we can apply a second optimization method: quantization

In [None]:
# Run the full optimize() pipeline on the already-pruned model (1 epoch of KD fine-tuning).
# NOTE(review): compression_ratio is left at its default of 2, for which
# get_compression_ratio_and_bitwidth_from_compression_ratio returns 32 bits. This call
# therefore prunes by a further 2x and builds an FP32 TensorRT engine; a compression_ratio
# above 2 (FP16) or above 4 (INT8) is needed for actual quantization.
model_trt = optimize(model, trainloader, testloader, example_input, num_classes=10, epochs=1, lr=0.01)

In [10]:
# Measure the GPU latency of the TensorRT module on a single random input.
# Only the third returned value (reported as FPS) is used; the first two are ignored here.
example_input = torch.randn(1, 3, 32, 32).cuda()
x,y,fps = measure_latency_gpu(model_trt, example_input)
print(f'FPS after quantization = {fps:.2f}')

FPS after quantization = 1083.96
