# Pruning Notebook
This notebook demonstrates unstructured pruning with PyTorch's `torch.nn.utils.prune` module and pruning during training with PyTorch Lightning's `ModelPruning` callback.

In [1]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import ModelPruning

# Define a simple LightningModule
class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(128, 10)
    def forward(self, x):
        return self.fc(x)

model = MyModel()

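# Unstructured pruning: zero out the 30% of fc weights with the smallest
# absolute value (L1 magnitude). PyTorch stores the result as a
# reparametrization: weight = weight_orig * weight_mask.
prune.l1_unstructured(model.fc, name="weight", amount=0.3)

# Sparsity = percentage of weight entries that are exactly zero
sparsity = 100.0 * float(torch.sum(model.fc.weight == 0)) / model.fc.weight.nelement()
print(f"Sparsity in fc layer: {sparsity:.2f}%")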



Sparsity in fc layer: 30.00%
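Unstructured pruning zeroes individual weights wherever they fall in the matrix. Structured pruning instead zeroes whole slices, such as entire output neurons. The following is a minimal sketch using PyTorch's `prune.ln_structured` on a fresh layer with the same shape as `fc`. The seed, the 30% amount, and the choice of the L2 norm (`n=2`) over rows (`dim=0`) are illustrative assumptions. The sketch also shows that pruning is stored as a reparametrization until `prune.remove` makes it permanent.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)  # illustrative seed
fc = nn.Linear(128, 10)  # weight shape: (10, 128), one row per output neuron

# Structured pruning: zero the 30% of rows (dim=0) with the smallest L2 norm (n=2).
prune.ln_structured(fc, name="weight", amount=0.3, n=2, dim=0)

# Pruned rows are entirely zero; the matrix keeps its (10, 128) shape.
zero_rows = int((fc.weight == 0).all(dim=1).sum())
print(f"Pruned output neurons: {zero_rows} of {fc.weight.shape[0]}")

# Until now, fc.weight is recomputed from the parameter "weight_orig" and the
# buffer "weight_mask". prune.remove writes the masked values back into a
# plain "weight" parameter and drops the mask.
prune.remove(fc, "weight")
print(sorted(name for name, _ in fc.named_parameters()))
```

Note that masking does not physically shrink the layer. Actually removing the zeroed neurons to save compute would require building a smaller `nn.Linear`, which this sketch does not do.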


### Pruning with PyTorch Lightning's `ModelPruning` Callback

In [2]:
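# Using Lightning's ModelPruning callback. "l1_unstructured" prunes individual
# weights by L1 magnitude; with the default use_global_unstructured=True the
# 50% amount is applied across all targeted parameters at once.
# Note: the callback only prunes while training, i.e. inside trainer.fit(...),
# which would also require MyModel to define training_step and configure_optimizers.
model = MyModel()
trainer = Trainer(callbacks=[ModelPruning("l1_unstructured", amount=0.5)], max_epochs=1)
print("ModelPruning callback configured for 50% global sparsity.")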

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


ModelPruning callback configured for 50% global sparsity.
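"Global" sparsity means the pruning threshold is chosen over all targeted weights together rather than layer by layer. To make this concrete, here is a sketch using PyTorch's `prune.global_unstructured` directly rather than the Lightning callback. The two-layer network is a hypothetical example, not part of this notebook. The 50% target holds for the union of both weight matrices, while each individual layer may end up more or less sparse than 50%.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)  # illustrative seed
# Hypothetical two-layer network, used only to show per-layer vs global sparsity
net = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
to_prune = [(net[0], "weight"), (net[2], "weight")]

# Rank all listed weights together by |w| and zero the smallest 50% overall.
prune.global_unstructured(to_prune, pruning_method=prune.L1Unstructured, amount=0.5)

total_zeros = sum(int((m.weight == 0).sum()) for m, _ in to_prune)
total = sum(m.weight.nelement() for m, _ in to_prune)
for i, (m, _) in enumerate(to_prune):
    layer_sparsity = 100.0 * float((m.weight == 0).sum()) / m.weight.nelement()
    print(f"layer {i}: {layer_sparsity:.2f}% zeros")
print(f"global: {100.0 * total_zeros / total:.2f}% zeros")
```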


