# 2. Pruning

This notebook demonstrates how to prune a model using the `torh.torch.nn.utils.prune` and `torch-pruning` library. Pruning is a technique to reduce the size of a neural network by removing weights that are deemed unnecessary, which can lead to faster inference times and reduced memory usage.

There is 2 types of pruning:
- **Unstructured pruning**: Removes individual weights using an importance metric (e.g., low-magnitude weights are pruned). This can lead to sparse models, which reduce drastically the number of parameters but must rely on specialized hardware and/or libraries to take advantage of the sparsity during inference.
- **Structured pruning**: Removes entire channels or layers, using a metric measuring an entire channel or layer importance (e.g., low-magnitude channels are pruned). This leads to a more regular model that can be used on standard hardware without requiring specialized libraries.

Metrics used for pruning are typically based on the magnitude of weights, gradients, or other statistics that indicate the importance of a weight or a channel.

The process is defined as such:
* A Torch model is loaded.
* A pruning strategy is defined, which specifies how to prune the model (e.g., unstructured or structured pruning, and the importance metric to use).
* The model is pruned using the defined strategy.
* The model is exported PyTorch format for further optimization or deployment.

2 pruning methods will be used in this notebook, both for 2 models (image and audio classification):
* L1-magntiude unstructured pruning using `torch.torch.nn.utils.prune`.
* L1-magnitude structured pruning using `torch-pruning`.

# Setup

## General

In [1]:
import os
from typing import Callable, Literal

from PIL import Image

import torch
import torchvision
import torchaudio
from torch.optim import Optimizer
from torch.amp import GradScaler, autocast
import torch.nn.utils.prune as unstruct_prune
import torch_pruning as struct_prune

import logging
from tqdm import tqdm

import psutil
try:
    import pynvml
    pynvml.nvmlInit()
    pynvml_available = True
except (ImportError, pynvml.NVMLError):
    pynvml_available = False
print(f"pynvml available: {pynvml_available}")

# Setup device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32

# Base directories for datasets and models
BASE_DATA_DIR = "../data"
IMAGE_DATA_DIR = os.path.join(BASE_DATA_DIR, "image")
BASE_MODEL_DIR = "../models"
MODEL_BASELINE_DIR = os.path.join(BASE_MODEL_DIR, "baseline")

# CIFAR-10 dataset directories
CIFAR10_DIR = os.path.join(IMAGE_DATA_DIR, "cifar10")
CIFAR10_TRAIN_DIR = os.path.join(CIFAR10_DIR, "train")
CIFAR10_TRAIN_PT_FILE = os.path.join(CIFAR10_TRAIN_DIR, "data.pt")
CIFAR10_VAL_DIR = os.path.join(CIFAR10_DIR, "val")
CIFAR10_VAL_PT_FILE = os.path.join(CIFAR10_VAL_DIR, "data.pt")
CIFAR10_TEST_DIR = os.path.join(CIFAR10_DIR, "test")
CIFAR10_TEST_PT_FILE = os.path.join(CIFAR10_TEST_DIR, "data.pt")

# MobileNetV2 model directory
MOBILENETV2_CIFAR10_BASELINE_PT_FILE = os.path.join(MODEL_BASELINE_DIR, "mobilenetv2_cifar10.pt")

# Setup logging
logging.basicConfig(level=logging.DEBUG, 
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    handlers=[logging.StreamHandler()])

logger = logging.getLogger(__name__)

pynvml available: True
Using device: cuda


## Models

In [2]:
# 1. Instantiate the base MobileNetV2 architecture
# We use weights=None because we will load our own fine-tuned weights.
mobilnet_v2_cifar10_base = torchvision.models.mobilenet_v2(weights=None)
num_classes_cifar10 = 10

# 2. Adapt the classifier head to match the CIFAR-10 adaptation (10 classes)
# This is necessary so the architecture matches the saved state_dict.
if hasattr(mobilnet_v2_cifar10_base, 'classifier') and isinstance(mobilnet_v2_cifar10_base.classifier, torch.nn.Sequential):
    if hasattr(mobilnet_v2_cifar10_base.classifier[-1], 'in_features'):
        in_features = mobilnet_v2_cifar10_base.classifier[-1].in_features
        mobilnet_v2_cifar10_base.classifier[-1] = torch.nn.Linear(in_features, num_classes_cifar10)
    else:
        # This case should ideally not be hit if MobileNetV2 structure is standard
        logger.error("Could not find 'in_features' in the last layer of the classifier to adapt.")
        raise AttributeError("Could not find 'in_features' in the last layer of the classifier.")
elif hasattr(mobilnet_v2_cifar10_base, 'fc'): # Fallback for models using 'fc'
     in_features = mobilnet_v2_cifar10_base.fc.in_features
     mobilnet_v2_cifar10_base.fc = torch.nn.Linear(in_features, num_classes_cifar10)
else:
    logger.error("Model does not have a known 'classifier' (Sequential) or 'fc' (Linear) attribute to adapt.")
    raise AttributeError("Model does not have a known classifier structure to adapt.")

# 3. Load the saved state_dict
# The MOBILENETV2_CIFAR10_BASELINE_PT_FILE contains the state_dict.
saved_state_dict = torch.load(MOBILENETV2_CIFAR10_BASELINE_PT_FILE, map_location=DEVICE)
mobilnet_v2_cifar10_base.load_state_dict(saved_state_dict)

# 4. Move model to the correct device and set to evaluation mode
mobilnet_v2_cifar10_base.to(DEVICE)
mobilnet_v2_cifar10_base.eval() # Important if you're not immediately training

logger.info(f"MobileNetV2 model for CIFAR-10 loaded from state_dict and prepared on {DEVICE}.")

2025-06-10 13:05:51,220 - __main__ - INFO - MobileNetV2 model for CIFAR-10 loaded from state_dict and prepared on cuda.


## Pruning

### L1 magnitude unstructured pruning

In [3]:
def mark_weights_with_l1_unstructured_pruning(model: torch.nn.Module, 
                                  pruning_amount: float, 
                                  layers_to_prune: tuple = (torch.nn.Linear, torch.nn.Conv2d),
                                  parameter_name: str = "weight") -> torch.nn.Module:
    """
    Applies L1 unstructured pruning to specified layers of a model.

    Args:
        model (torch.nn.Module): The model to prune.
        pruning_amount (float): The fraction of connections to prune (e.g., 0.2 for 20%).
        layers_to_prune (tuple): A tuple of layer types to prune (e.g., (torch.nn.Linear, torch.nn.Conv2d)).
        parameter_name (str): The name of the parameter to prune within the layers (e.g., "weight", "bias").

    Returns:
        torch.nn.Module: The model with pruning applied (reparameterized).
    """
    if not (0.0 < pruning_amount < 1.0):
        raise ValueError("Pruning amount must be between 0.0 and 1.0 (exclusive).")

    logger.info(f"Applying L1 unstructured pruning with amount: {pruning_amount:.2f} for parameter '{parameter_name}' in layers: {[layer.__name__ for layer in layers_to_prune]}")
    num_pruned_layers = 0
    for module in model.modules():
        if isinstance(module, layers_to_prune):
            try:
                unstruct_prune.l1_unstructured(module, name=parameter_name, amount=pruning_amount)
                logger.debug(f"Pruned {parameter_name} of layer: {module}")
                num_pruned_layers +=1
            except Exception as e:
                logger.warning(f"Could not prune {parameter_name} of layer {module}: {e}")
    
    if num_pruned_layers == 0:
        logger.warning("No layers were pruned. Check 'layers_to_prune' and model structure.")
    else:
        logger.info(f"Applied L1 unstructured pruning to {num_pruned_layers} layers.")
    
    return model


def remove_pruning_reparameterization(model: torch.nn.Module,
                                      layers_to_prune: tuple = (torch.nn.Linear, torch.nn.Conv2d),
                                      parameter_name: str = "weight") -> torch.nn.Module:
    """
    Removes the pruning reparameterization, making the pruning permanent.
    The pruned weights are set to zero directly in the parameter tensor.

    Args:
        model (nn.Module): The model with pruning applied.
        layers_to_prune (tuple): A tuple of layer types from which to remove reparameterization.
        parameter_name (str): The name of the parameter that was pruned.

    Returns:
        nn.Module: The model with pruning made permanent.

    Notes:
        Must be called after `mark_weights_with_l1_unstructured_pruning`.
    """
    logger.info("Making pruning permanent by removing reparameterization...")
    num_permanent_layers = 0
    for module in model.modules():
        if isinstance(module, layers_to_prune):
            if unstruct_prune.is_pruned(module): # Check if the module has pruning hooks
                try:
                    unstruct_prune.remove(module, name=parameter_name)
                    logger.debug(f"Made pruning permanent for {parameter_name} of layer: {module}")
                    num_permanent_layers +=1
                except Exception as e:
                    logger.warning(f"Could not make pruning permanent for {parameter_name} of layer {module}: {e}")
            
    if num_permanent_layers == 0:
        logger.warning("No pruning reparameterization was removed. Was the model pruned?")
    else:
        logger.info(f"Made pruning permanent for {num_permanent_layers} layers.")
    return model


def calculate_sparsity(model: torch.nn.Module, 
                       layers_to_check: tuple = (torch.nn.Linear, torch.nn.Conv2d),
                       parameter_name: str = "weight") -> dict[str, float]:
    """
    Calculates the sparsity of specified parameters in the model.

    Args:
        model (nn.Module): The model to check.
        layers_to_check (tuple): Layer types to inspect.
        parameter_name (str): Name of the parameter (e.g., "weight").

    Returns:
        Dict[str, float]: A dictionary containing overall sparsity and sparsity per layer.

    Notes:
        Must be applied after `mark_weights_with_l1_unstructured_pruning` and `remove_pruning_reparameterization`.
    """
    results = {}
    total_zeros = 0
    total_elements = 0

    for name, module in model.named_modules():
        if isinstance(module, layers_to_check):
            if hasattr(module, parameter_name):
                param = getattr(module, parameter_name)
                if param is not None:
                    # If pruning is applied but not made permanent, the original tensor might not be zero.
                    # We need to check the 'weight_orig' and 'weight_mask' if they exist.
                    # However, if remove_pruning_reparameterization has been called,
                    # the 'weight' tensor itself will contain zeros.
                    
                    # For simplicity after `remove`, we check the actual parameter.
                    # If `remove` hasn't been called, this will report sparsity of the
                    # underlying tensor before the mask is applied during forward pass.
                    # For true sparsity *during* forward pass before `remove`, one would
                    # need to access module.weight_mask and module.weight_orig.
                    
                    layer_zeros = torch.sum(param.data == 0).item()
                    layer_elements = param.data.numel()
                    total_zeros += layer_zeros
                    total_elements += layer_elements
                    if layer_elements > 0:
                        layer_sparsity = layer_zeros / layer_elements
                        results[f"{name}.{parameter_name}_sparsity"] = layer_sparsity
                        logger.debug(f"Sparsity of {name}.{parameter_name}: {layer_sparsity:.4f} ({layer_zeros}/{layer_elements})")
                    else:
                        results[f"{name}.{parameter_name}_sparsity"] = 0.0
                        logger.debug(f"Layer {name}.{parameter_name} has 0 elements.")


    overall_sparsity = total_zeros / total_elements if total_elements > 0 else 0.0
    results["overall_sparsity"] = overall_sparsity
    logger.info(f"Overall sparsity ({parameter_name}): {overall_sparsity:.4f} ({total_zeros}/{total_elements})")
    
    return results

def apply_l1_unstructured_pruning(model: torch.nn.Module, 
                                  pruning_amount: float, 
                                  layers_to_prune: tuple = (torch.nn.Linear, torch.nn.Conv2d),
                                  parameter_name: str = "weight") -> torch.nn.Module:
    """
    Applies L1 unstructured pruning to the model and returns the pruned model.

    Args:
        model (torch.nn.Module): The model to prune.
        pruning_amount (float): The fraction of connections to prune.
        layers_to_prune (tuple): Layers to apply pruning to.
        parameter_name (str): The name of the parameter to prune.

    Returns:
        torch.nn.Module: The pruned model.
    """
    model = mark_weights_with_l1_unstructured_pruning(model, pruning_amount, layers_to_prune, parameter_name)
    model = remove_pruning_reparameterization(model, layers_to_prune, parameter_name)
    logger.info("L1 unstructured pruning completed.")
    _ = calculate_sparsity(model, layers_to_prune, parameter_name)
    return model

### L1 magnitude structured pruning

In [4]:
def count_model_parameters(model: torch.nn.Module, only_trainable: bool = True) -> int:
    """Counts the total number of parameters in a model."""
    if only_trainable:
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    return sum(p.numel() for p in model.parameters())

def apply_l1_structured_pruning(
    model: torch.nn.Module,
    example_inputs: torch.Tensor,
    pruning_amount: float,
    layers_to_prune: tuple[type, ...] = (torch.nn.Conv2d, torch.nn.Linear),
    ignored_layers: list[torch.nn.Module] | None = None,
    prune_output_channels: bool = True
) -> torch.nn.Module:
    """
    Applies L1 magnitude structured pruning to specified layers of a model
    by removing a fraction of channels/features from each targeted layer.

    By default, it prunes output channels of Conv2d layers and output features
    of Linear layers.

    Args:
        model (nn.Module): The model to prune.
        example_inputs (torch.Tensor): A batch of example inputs for dependency tracing.
                                       Should be on the same device as the model.
        pruning_amount (float): The fraction of channels/features to prune from each
                                targeted layer (e.g., 0.2 for 20%).
        layers_to_prune (tuple[type, ...]): Tuple of layer types to prune.
        ignored_layers (list[torch.nn.Module] | None): A list of specific layer
                                                    modules to ignore during pruning.
        prune_output_channels (bool): If True, prunes output channels/features.
                                      If False, attempts to prune input channels/features
                                      (Note: Input channel importance calculation here is simplified).

    Returns:
        torch.nn.Module: The model with structured pruning applied. The model is modified in-place.
    """
    if not (0.0 <= pruning_amount < 1.0):
        raise ValueError("Pruning amount must be between 0.0 (inclusive) and 1.0 (exclusive).")

    if pruning_amount == 0.0:
        logger.info("Pruning amount is 0.0. No structured pruning will be applied.")
        return model

    # torch-pruning usually expects the model in eval mode for graph construction
    original_mode_is_train = model.training
    model.eval()

    device = next(model.parameters()).device
    example_inputs = example_inputs.to(device)

    logger.info(f"Applying L1 structured pruning with amount: {pruning_amount:.2f}")
    initial_params = count_model_parameters(model)
    logger.info(f"Initial model parameters: {initial_params}")

    DG = struct_prune.DependencyGraph()
    DG.build_dependency(model, example_inputs=example_inputs)

    num_pruned_overall_layers = 0

    for name, module in model.named_modules():
        if isinstance(module, layers_to_prune):
            if ignored_layers and module in ignored_layers:
                logger.debug(f"Skipping ignored layer: {name} ({type(module).__name__})")
                continue

            current_channels = 0
            pruning_fn = None
            dim_type = ""
            weights = module.weight.data

            if prune_output_channels:
                if isinstance(module, torch.nn.Conv2d):
                    current_channels = module.out_channels
                    pruning_fn = struct_prune.prune_conv_out_channels
                    dim_type = "output channels"
                    # L1 norm for each output filter: (C_out, C_in, K_h, K_w) -> sum over C_in, K_h, K_w
                    channel_importance = torch.norm(weights.flatten(1), p=1, dim=1)
                elif isinstance(module, torch.nn.Linear):
                    current_channels = module.out_features
                    pruning_fn = struct_prune.prune_linear_out_features
                    dim_type = "output features"
                    # L1 norm for each output feature's weights: (F_out, F_in) -> sum over F_in
                    channel_importance = torch.norm(weights, p=1, dim=1)
                else: # Should not be reached if layers_to_prune is respected
                    continue
            else: # Pruning input channels
                if isinstance(module, torch.nn.Conv2d):
                    current_channels = module.in_channels
                    pruning_fn = struct_prune.prune_conv_in_channels
                    dim_type = "input channels"
                    # Simplified L1 for input channels: (C_out, C_in, K_h, K_w) -> transpose to (C_in, C_out, K_h, K_w)
                    channel_importance = torch.norm(weights.transpose(0,1).contiguous().flatten(1), p=1, dim=1)
                elif isinstance(module, torch.nn.Linear):
                    current_channels = module.in_features
                    pruning_fn = struct_prune.prune_linear_in_features
                    dim_type = "input features"
                    # Simplified L1 for input features: (F_out, F_in) -> transpose to (F_in, F_out)
                    channel_importance = torch.norm(weights.T.contiguous(), p=1, dim=1)
                else: # Should not be reached
                    continue

            if current_channels == 0:
                logger.warning(f"Layer {name} ({type(module).__name__}) has 0 {dim_type} to prune. Skipping.")
                continue

            num_to_prune = int(pruning_amount * current_channels)

            if num_to_prune == 0:
                logger.debug(f"Layer {name}: No {dim_type} to prune with amount {pruning_amount:.2f} (Total: {current_channels}).")
                continue

            # Ensure we don't prune all channels, as torch-pruning might error or lead to a dead network.
            # It's safer to leave at least one channel.
            if num_to_prune >= current_channels:
                num_to_prune = current_channels - 1
                logger.warning(
                    f"Layer {name}: Pruning amount {pruning_amount:.2f} would remove all or too many {dim_type}. "
                    f"Adjusting to prune {num_to_prune} {dim_type} to keep at least one."
                )
                if num_to_prune <= 0: # If current_channels was 1
                    logger.info(f"Layer {name}: Cannot prune, only 1 {dim_type} exists. Skipping.")
                    continue
            
            # Get indices of channels to prune (those with smallest L1 norm)
            sorted_channel_indices = torch.argsort(channel_importance)
            pruning_indices = sorted_channel_indices[:num_to_prune].tolist()

            try:
                pruning_plan = DG.get_pruning_plan(module, pruning_fn, idxs=pruning_indices)
                if pruning_plan:
                    logger.debug(f"Pruning {num_to_prune} {dim_type} from layer {name} ({type(module).__name__}). Smallest L1 norm indices (first 10): {pruning_indices[:10]}...")
                    pruning_plan.exec()
                    num_pruned_overall_layers += 1
                else:
                    logger.warning(f"Could not generate pruning plan for layer {name} ({type(module).__name__}).")
            except Exception as e:
                logger.error(f"Failed to prune layer {name} ({type(module).__name__}): {e}", exc_info=True)

    if num_pruned_overall_layers == 0:
        logger.warning("No layers were structurally pruned. Check 'layers_to_prune', model structure, and pruning_amount.")
    else:
        logger.info(f"Applied structured pruning to {num_pruned_overall_layers} layers.")

    final_params = count_model_parameters(model)
    logger.info(f"Final model parameters after structured pruning: {final_params}")
    if initial_params > 0 :
        reduction_percent = (initial_params - final_params) / initial_params * 100
        logger.info(f"Parameter reduction: {initial_params - final_params} ({reduction_percent:.2f}%)")
    else:
        logger.info(f"Parameter reduction: {initial_params - final_params}")


    if original_mode_is_train:
        model.train() # Set back to original mode

    return model

# L1 unstructured pruning

## One-step pruning

In [5]:
# Apply L1 unstructured pruning to the MobileNetV2 model
mobilnet_v2_cifar10_unstruct_prune = apply_l1_unstructured_pruning(mobilnet_v2_cifar10_base, pruning_amount=0.2)

2025-06-10 13:05:54,967 - __main__ - INFO - Applying L1 unstructured pruning with amount: 0.20 for parameter 'weight' in layers: ['Linear', 'Conv2d']
2025-06-10 13:05:55,141 - __main__ - DEBUG - Pruned weight of layer: Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
2025-06-10 13:05:55,142 - __main__ - DEBUG - Pruned weight of layer: Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
2025-06-10 13:05:55,143 - __main__ - DEBUG - Pruned weight of layer: Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
2025-06-10 13:05:55,144 - __main__ - DEBUG - Pruned weight of layer: Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
2025-06-10 13:05:55,144 - __main__ - DEBUG - Pruned weight of layer: Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=96, bias=False)
2025-06-10 13:05:55,145 - __main__ - DEBUG - Pruned weight of layer: Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=Fa

## Iterative pruning

# L1 structured pruning

## One-step pruning

In [6]:
mobilnet_v2_cifar10_struct_prune = apply_l1_structured_pruning(
    mobilnet_v2_cifar10_base,
    example_inputs=torch.randn(1, 3, 224, 224, device=DEVICE, dtype=DTYPE),
    pruning_amount=0.2,
    layers_to_prune=(torch.nn.Conv2d, torch.nn.Linear),
    prune_output_channels=True
)

2025-06-10 13:06:07,582 - __main__ - INFO - Applying L1 structured pruning with amount: 0.20
2025-06-10 13:06:07,583 - __main__ - INFO - Initial model parameters: 2236682


RuntimeError: Input type (CUDABFloat16Type) and weight type (torch.cuda.FloatTensor) should be the same