In [37]:
import torch, torch.nn as nn
from torch.nn.utils import prune
from ultralytics import YOLO


In [None]:
# Checkpoint locations, given relative to the notebook's working directory.
# INPUT_PATH is the checkpoint that gets loaded and pruned below.
INPUT_PATH = "training_output/runs/detect/train/weights/best.pt"
# OUTPUT_PATH is where the pruned checkpoint is written by torch.save.
OUTPUT_PATH = "training_output/runs/detect/train/weights/model_pruned.pt"


In [38]:
# Load the full training checkpoint. As the cell output shows, this is a dict of
# metadata whose 'model' entry holds a pickled DetectionModel, so it cannot be
# read in weights-only mode (the torch.load default from PyTorch 2.6 onward).
# NOTE: unpickling executes arbitrary code -- only load checkpoints you trust.
model = torch.load(INPUT_PATH, weights_only=False)
model


{'date': '2024-12-11T09:39:49.714782',
 'version': '8.3.49',
 'license': 'AGPL-3.0 (https://ultralytics.com/license)',
 'docs': 'https://docs.ultralytics.com',
 'epoch': -1,
 'best_fitness': None,
 'model': DetectionModel(
   (model): Sequential(
     (0): Conv(
       (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
       (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
       (act): SiLU(inplace=True)
     )
     (1): Conv(
       (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
       (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
       (act): SiLU(inplace=True)
     )
     (2): C2f(
       (cv1): Conv(
         (conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
         (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
         (act): SiLU(inplace=True)
       )
       (cv2): 

In [39]:
def sparsity(model):
    '''Return the global sparsity of ``model``.

    Sparsity is the fraction of parameter elements that are exactly zero,
    counted over every tensor yielded by ``model.parameters()``.

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (e.g. an ``nn.Module``).

    Returns:
        A 0-dim tensor holding ``zero_elements / total_elements``. When the
        model has no parameter elements at all, a 0-dim tensor holding ``0.0``
        is returned instead of dividing by zero.
    '''
    total = 0
    zeros = 0

    for p in model.parameters():
        total += p.numel()
        zeros += (p == 0).sum()

    if total == 0:
        # Nothing to measure: report zero sparsity rather than raising
        # ZeroDivisionError (or producing NaN for zero-element tensors).
        return torch.tensor(0.0)

    return zeros / total


In [40]:
# Fraction of each eligible layer's weight entries to zero out.
pruning_param = 0.3

# Prune every convolutional and linear layer of the checkpoint's model in place.
for m in model["model"].modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        # Zero the entries with the smallest absolute value in this layer ...
        prune.l1_unstructured(m, name="weight", amount=pruning_param)
        # ... then make the pruning permanent by dropping the mask
        # re-parametrization, so `weight` is a plain parameter again.
        prune.remove(m, "weight")

# Outer double quotes keep this f-string valid on Python versions before 3.12,
# where reusing the enclosing quote character inside an f-string is a SyntaxError.
print(f"Model pruned to {sparsity(model['model']):.3g} global sparsity")


Model pruned to 0.299 global sparsity


In [41]:
# Write the whole checkpoint dict (metadata plus the pruned 'model' entry)
# to OUTPUT_PATH.
torch.save(model, OUTPUT_PATH)
# Echo the checkpoint as the cell output.
model


{'date': '2024-12-11T09:39:49.714782',
 'version': '8.3.49',
 'license': 'AGPL-3.0 (https://ultralytics.com/license)',
 'docs': 'https://docs.ultralytics.com',
 'epoch': -1,
 'best_fitness': None,
 'model': DetectionModel(
   (model): Sequential(
     (0): Conv(
       (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
       (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
       (act): SiLU(inplace=True)
     )
     (1): Conv(
       (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
       (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
       (act): SiLU(inplace=True)
     )
     (2): C2f(
       (cv1): Conv(
         (conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
         (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
         (act): SiLU(inplace=True)
       )
       (cv2): 