# ✂️ Model Pruning: Removing Unnecessary Weights

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/model-quantization/blob/main/pruning_demo.ipynb)

Model pruning is a compression technique that removes redundant parameters from a neural network. The most common approach is magnitude pruning: the weights with the smallest absolute values are set to zero.

### Types of Pruning:
1. **Unstructured Pruning**: Individual weights are removed. The weight matrices keep their original shape and become sparse, so real speedups require sparse kernels or hardware with sparsity support.
2. **Structured Pruning**: Entire neurons, channels, or layers are removed. This leads to smaller dense matrices that are easier to accelerate on standard hardware.
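The difference between the two types is easiest to see on a single toy layer. The sketch below is an illustration on hypothetical `nn.Linear` layers, not on GPT-2: `prune.l1_unstructured` zeros individual weights, while `prune.ln_structured(..., n=2, dim=0)` zeros whole output neurons (rows) ranked by their L2 norm. Only in the structured case can the layer be rebuilt as a genuinely smaller dense layer.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Unstructured: zero the 50% of individual weights with the smallest |w|.
unstructured = nn.Linear(6, 4)
prune.l1_unstructured(unstructured, name="weight", amount=0.5)
print(unstructured.weight)  # zeros scattered across the matrix; shape stays (4, 6)

# Structured: zero the 2 output neurons (rows, dim=0) with the smallest L2 norm.
structured = nn.Linear(6, 4)
prune.ln_structured(structured, name="weight", amount=0.5, n=2, dim=0)
kept = structured.weight.abs().sum(dim=1) != 0  # rows that survived
print(kept)

# Because whole rows are zero, the layer can be rebuilt as a smaller dense layer.
smaller = nn.Linear(6, int(kept.sum()))
with torch.no_grad():
    smaller.weight.copy_(structured.weight[kept])
    smaller.bias.copy_(structured.bias[kept])

x = torch.randn(3, 6)
# The surviving outputs match; only the pruned neurons' outputs are dropped.
assert torch.allclose(structured(x)[:, kept], smaller(x))
print(smaller.weight.shape)  # torch.Size([2, 6])
```

In a full network, removing output neurons from one layer also means removing the matching input columns from the next layer; this sketch only shows the first half of that surgery.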

```python
import torch
import torch.nn.utils.prune as prune
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# 1. Load GPT-2
model_id = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_id)
tokenizer = GPT2Tokenizer.from_pretrained(model_id)

print(f"Initial Parameters: {sum(p.numel() for p in model.parameters())/1e6:.1f}M")
```

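## 🔪 Global Unstructured Pruning
We will prune 30% of the weights across all linear layers by L1 magnitude, i.e., the weights with the smallest absolute values are zeroed. Because the ranking is *global*, a single threshold is shared by every layer, so some layers end up sparser than others. Note that Hugging Face's GPT-2 implements its attention and MLP projections as `Conv1D` modules (which behave like linear layers with transposed weights), not `torch.nn.Linear`; the only `nn.Linear` is the output head `lm_head`. Both module types therefore have to be collected.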

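```python
from transformers.pytorch_utils import Conv1D  # GPT-2's linear projections use this class

# Collect every linear projection: the Conv1D layers inside the transformer blocks
# plus the nn.Linear output head (lm_head).
parameters_to_prune = []
for name, module in model.named_modules():
    if isinstance(module, (torch.nn.Linear, Conv1D)):
        parameters_to_prune.append((module, "weight"))

# Zero the 30% of these weights with the smallest |w|, ranked globally across all layers.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.3,  # Prune 30%
)

print(f"Global pruning (30%) applied to {len(parameters_to_prune)} weight tensors.")
```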
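To see what `global_unstructured` with `L1Unstructured` does, here is a minimal sketch that reproduces it by hand on two hypothetical toy layers (not GPT-2): the absolute values of all weights are pooled, the k-th smallest becomes one shared threshold, and every weight at or below it is masked. It assumes no ties among the magnitudes, which holds in practice for randomly initialized float weights.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layers = [nn.Linear(20, 10), nn.Linear(10, 5)]  # hypothetical toy layers
amount = 0.3

# Manual version: rank |w| jointly across ALL layers, then zero the k smallest.
all_abs = torch.cat([layer.weight.detach().abs().flatten() for layer in layers])
k = round(amount * all_abs.numel())            # 75 of 250 weights
threshold = torch.kthvalue(all_abs, k).values  # k-th smallest magnitude
manual_masks = [(layer.weight.detach().abs() > threshold).float() for layer in layers]

# Library version on the same layers (same call as in the cell above).
prune.global_unstructured(
    [(layer, "weight") for layer in layers],
    pruning_method=prune.L1Unstructured,
    amount=amount,
)

for i, (layer, mask) in enumerate(zip(layers, manual_masks)):
    assert torch.equal(layer.weight_mask, mask)
    # Per-layer sparsity differs, because the 30% target is global, not per layer.
    sparsity = 100.0 * float((layer.weight == 0).sum()) / layer.weight.numel()
    print(f"layer {i}: {sparsity:.1f}% zeros")
```

The per-layer printout is the practical consequence of the global ranking: layers whose weights are small overall absorb more of the pruning budget. Per-layer pruning (e.g. calling `prune.l1_unstructured` on each module separately) would instead force exactly 30% in every layer.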

## 📊 Verifying Sparsity
Let's check how many weights are now exactly zero.

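```python
import torch
from transformers.pytorch_utils import Conv1D

def calculate_sparsity(model):
    """Percentage of exactly-zero weights across all Linear/Conv1D layers."""
    total_zeros = 0
    total_elements = 0
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, Conv1D)):
            # After pruning, module.weight is the masked tensor (weight_orig * weight_mask).
            total_zeros += torch.sum(module.weight == 0).item()
            total_elements += module.weight.numel()

    sparsity = 100.0 * total_zeros / total_elements
    print(f"Global Sparsity: {sparsity:.2f}%")
    return sparsity

calculate_sparsity(model)
```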
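`torch.nn.utils.prune` does not actually delete anything: it re-registers the original tensor as `weight_orig`, stores a `weight_mask` buffer, and recomputes `weight` as their product before every forward pass. This is why the sparsity check can read `module.weight` directly, and also why the parameter count printed in the first cell does not shrink after pruning. The sketch below shows this reparameterization on a hypothetical `nn.Linear(8, 4)`, and how `prune.remove` makes the pruning permanent.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(8, 4)  # hypothetical toy layer
prune.l1_unstructured(layer, name="weight", amount=0.5)

# After pruning, the trainable tensor is `weight_orig`; `weight_mask` is a buffer,
# and `weight` is recomputed as weight_orig * weight_mask by a forward pre-hook.
param_names = sorted(n for n, _ in layer.named_parameters())
buffer_names = sorted(n for n, _ in layer.named_buffers())
print(param_names)   # ['bias', 'weight_orig']
print(buffer_names)  # ['weight_mask']

zeros_before = int((layer.weight == 0).sum())

# Make pruning permanent: drops the mask and hook, keeps the masked values as `weight`.
prune.remove(layer, "weight")
print(sorted(n for n, _ in layer.named_parameters()))  # ['bias', 'weight']

zeros_after = int((layer.weight == 0).sum())
print(zeros_before, zeros_after, layer.weight.numel())  # 16 16 32
```

For the GPT-2 model above, the same `prune.remove(module, "weight")` call would be applied to each pair in `parameters_to_prune` before saving. The zeros are still stored densely, so file size and inference speed only improve once the weights are converted to a sparse format or run with sparsity-aware kernels.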

## 🏎️ Generation Test
Does the model still generate coherent text after losing 30% of its connections?

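```python
input_text = "Neural network pruning is used to"
inputs = tokenizer(input_text, return_tensors="pt")

torch.manual_seed(0)  # sampling is random; fix the seed for repeatable output
with torch.no_grad():
    output = model.generate(
        **inputs,
        max_length=30,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token
    )

print(tokenizer.decode(output[0], skip_special_tokens=True))
```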