<a href="https://colab.research.google.com/github/Aaban-Saad/CSE465-post-training-pruning/blob/main/pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install ultralytics

In [93]:
import torch
import torch.nn.utils.prune as prune
import torch.nn as nn
from ultralytics import YOLO  # Import YOLO model

# Load YOLOv11 model architecture
model = YOLO("yolo11x.pt")  # Load model structure

print(model.model)

DetectionModel(
  (model): Sequential(
    (0): Conv(
      (conv): Conv2d(3, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(96, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (1): Conv(
      (conv): Conv2d(96, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (2): C3k2(
      (cv1): Conv(
        (conv): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(192, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (cv2): Conv(
        (conv): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(384, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
   

In [94]:
def get_pruning_ratio(layer_idx, total_layers, min_prune=0.30, max_prune=0.05):
    """
    Determines the pruning amount for a layer based on its depth.
    - `min_prune`: Minimum pruning ratio (e.g., 10%)
    - `max_prune`: Maximum pruning ratio (e.g., 50%)
    - Uses a linear scaling formula: deeper layers get pruned more.
    """
    return min_prune + (max_prune - min_prune) * (layer_idx / total_layers)

def structured_prune_model(model):
    """
    Prunes deeper layers more aggressively while keeping initial layers less pruned.
    """
    conv_layers = [module for name, module in model.model.named_modules() if isinstance(module, nn.Conv2d)]
    total_layers = len(conv_layers)

    for idx, module in enumerate(conv_layers):
        pruning_ratio = get_pruning_ratio(idx, total_layers)
        prune.l1_unstructured(module, name='weight', amount=pruning_ratio)
        prune.remove(module, 'weight')  # Remove pruning reparameterization

        print(f"Layer {idx+1}/{total_layers} pruned by {pruning_ratio:.2%}")

    return model

# Prune the model using structured pruning
pruned_model = structured_prune_model(model)

# After pruning, evaluate the model
pruned_model.model.eval()  # Set model to eval mode explicitly
results_after_pruning = pruned_model.val(data="coco8.yaml")
print(f"After pruning - MAP: {results_after_pruning.box.map}")

# Save the pruned model
pruned_model_path = "yolo11s_pruned.pt"
pruned_model.save(pruned_model_path)  # Save entire pruned model
print(f"Pruned model saved as {pruned_model_path}")

Layer 1/174 pruned by 30.00%
Layer 2/174 pruned by 29.86%
Layer 3/174 pruned by 29.71%
Layer 4/174 pruned by 29.57%
Layer 5/174 pruned by 29.43%
Layer 6/174 pruned by 29.28%
Layer 7/174 pruned by 29.14%
Layer 8/174 pruned by 28.99%
Layer 9/174 pruned by 28.85%
Layer 10/174 pruned by 28.71%
Layer 11/174 pruned by 28.56%
Layer 12/174 pruned by 28.42%
Layer 13/174 pruned by 28.28%
Layer 14/174 pruned by 28.13%
Layer 15/174 pruned by 27.99%
Layer 16/174 pruned by 27.84%
Layer 17/174 pruned by 27.70%
Layer 18/174 pruned by 27.56%
Layer 19/174 pruned by 27.41%
Layer 20/174 pruned by 27.27%
Layer 21/174 pruned by 27.13%
Layer 22/174 pruned by 26.98%
Layer 23/174 pruned by 26.84%
Layer 24/174 pruned by 26.70%
Layer 25/174 pruned by 26.55%
Layer 26/174 pruned by 26.41%
Layer 27/174 pruned by 26.26%
Layer 28/174 pruned by 26.12%
Layer 29/174 pruned by 25.98%
Layer 30/174 pruned by 25.83%
Layer 31/174 pruned by 25.69%
Layer 32/174 pruned by 25.55%
Layer 33/174 pruned by 25.40%
Layer 34/174 pruned

[34m[1mval: [0mScanning /content/datasets/coco8/labels/val.cache... 4 images, 0 backgrounds, 0 corrupt: 100%|██████████| 4/4 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 1/1 [00:19<00:00, 19.90s/it]


                   all          4         17      0.868      0.931      0.979      0.853
                person          3         10          1      0.713        0.9      0.614
                   dog          1          1      0.813          1      0.995      0.995
                 horse          1          2      0.851          1      0.995      0.806
              elephant          1          2          1      0.875      0.995      0.714
              umbrella          1          1      0.752          1      0.995      0.995
          potted plant          1          1      0.792          1      0.995      0.995
Speed: 2.0ms preprocess, 4960.3ms inference, 0.0ms loss, 1.8ms postprocess per image
Results saved to [1mruns/detect/val49[0m
After pruning - MAP: 0.8532307371965265
Pruned model saved as yolo11s_pruned.pt
