# Model Pruning with a Pre-Trained Transformer
1.  Load a pre-trained sentiment analysis model (`DistilBERT`).
2.  Evaluate its performance on a sample sentence **before** pruning.
3.  Inspect the model's layers to choose a target for pruning.
4.  Apply **unstructured magnitude pruning** to the target layer.
5.  Verify the effect of pruning by measuring the layer's sparsity.
6.  Evaluate the model's performance **after** pruning to observe the impact.

---
## 1. Setup Environment

In [1]:
!pip install datasets transformers torch numpy huggingface_hub -q

In [2]:
!pip install 'accelerate>=0.26.0' --upgrade -q

---
## 2. Import Libraries
We'll import all the necessary components for our task.

In [3]:
import torch
import numpy as np
import torch.nn.utils.prune as prune
from transformers import AutoTokenizer, AutoModelForSequenceClassification

---
## 3. Define Model and Device
Let's specify the model we'll use from the Hugging Face Hub and set up our compute device (GPU if available).

In [4]:
model_name = "distilbert-base-uncased-finetuned-sst-2-english"

In [5]:
print(f"Using model: {model_name}")

Using model: distilbert-base-uncased-finetuned-sst-2-english


In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [7]:
print(f"Using device: {device}")

Using device: cuda


---
## 4. Load Pre-Trained Model and Tokenizer
Here, we download the model and its corresponding tokenizer. We don't need to define the model architecture ourselves.

In [8]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [9]:
model = AutoModelForSequenceClassification.from_pretrained(model_name)

In [10]:
model.to(device)

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


### DEBUG: Inspect the loaded model structure
Let's print the model object to understand its layers. This is how we'll find a target for pruning.

In [11]:
print(model)

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


---
## 5. Evaluate the Original Model's Performance
Before we change anything, let's establish a baseline for the model's performance on a sample sentence.

### Create a helper function for evaluation

In [14]:
def evaluate_model(model, tokenizer, sentence, device):
    inputs = tokenizer(sentence, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**inputs).logits

    probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]
    predicted_class_id = torch.argmax(logits, dim=1).item()
    prediction = model.config.id2label[predicted_class_id]

    print(f'Sentence: "{sentence}"')
    print(f'Prediction: {prediction} (Confidence: {probabilities[predicted_class_id]:.4f})')

In [15]:
test_sentence = "The performances in this movie are absolutely stellar."

In [16]:
print("--- Evaluating Original Model ---")
evaluate_model(model, tokenizer, test_sentence, device)

--- Evaluating Original Model ---
Sentence: "The performances in this movie are absolutely stellar."
Prediction: POSITIVE (Confidence: 0.9999)


---
## 6. The Pruning Process
Now we'll perform the actual pruning. We'll target the **query projection layer** (`q_lin`) in the first attention block of the transformer.

### Step 6.1: Select the Target Layer

In [17]:
module_to_prune = model.distilbert.transformer.layer[0].attention.q_lin

In [18]:
print("Selected module for pruning:")
print(module_to_prune)

Selected module for pruning:
Linear(in_features=768, out_features=768, bias=True)


### Step 6.2: Check Sparsity Before Pruning
Sparsity is the percentage of weights that are zero. For an unpruned model, this should be 0%.

In [19]:
def calculate_sparsity(module):
    return 100. * float(torch.sum(module.weight == 0)) / float(module.weight.nelement())

In [20]:
initial_sparsity = calculate_sparsity(module_to_prune)
print(f"Sparsity before pruning: {initial_sparsity:.2f}%")

Sparsity before pruning: 0.00%


### Step 6.3: Apply Pruning
We will prune 30% of the weights in the layer with the lowest L1 magnitude (i.e., closest to zero).

In [21]:
prune.l1_unstructured(module_to_prune, name="weight", amount=0.3)

Linear(in_features=768, out_features=768, bias=True)

### Step 6.4: Check Sparsity After Pruning
The pruning is applied via a 'forward hook'. The original weights are still there, but a mask is applied. The sparsity calculation should now reflect the pruned weights.

In [22]:
sparsity_after_pruning = calculate_sparsity(module_to_prune)
print(f"Sparsity after applying pruning mask: {sparsity_after_pruning:.2f}%")

Sparsity after applying pruning mask: 30.00%


### Step 6.5: Make the Pruning Permanent
The `prune.remove` function removes the hook and permanently sets the pruned weights to zero in the weight tensor.

In [23]:
prune.remove(module_to_prune, 'weight')

Linear(in_features=768, out_features=768, bias=True)

In [24]:
final_sparsity = calculate_sparsity(module_to_prune)
print(f"Sparsity after making pruning permanent: {final_sparsity:.2f}%")

Sparsity after making pruning permanent: 30.00%


---
## 7. Evaluate the Pruned Model (Before Fine-Tuning)
Now let's see how our model performs on the same sentence after we've removed 30% of the weights from a key layer. We expect a drop in performance or confidence.

In [26]:
print("--- Evaluating Pruned Model ---")
evaluate_model(model, tokenizer, test_sentence, device)

--- Evaluating Pruned Model ---
Sentence: "The performances in this movie are absolutely stellar."
Prediction: POSITIVE (Confidence: 0.9999)


In [27]:
print(model)

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


In [28]:
from datasets import load_dataset

# Load a small sample of the dataset
train_dataset = load_dataset("imdb", split="train[:500]")
eval_dataset = load_dataset("imdb", split="test[:100]")

# Tokenize the text
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True)
tokenized_eval_dataset = eval_dataset.map(tokenize_function, batched=True)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [29]:
from transformers import Trainer, TrainingArguments

# Define training arguments

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,              # One epoch is enough for a quick demo
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_dir="./logs",
    report_to="none",
)

# Create the Trainer instance
trainer = Trainer(
    model=model,                         # Your pruned model
    args=training_args,                  # Training configuration
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_eval_dataset
)

# Start fine-tuning
print("Starting fine-tuning...")
trainer.train()
print("Fine-tuning complete.")

Starting fine-tuning...


Step,Training Loss


Fine-tuning complete.


In [30]:
# Evaluate the model to see the recovered performance
print("--- Evaluating Fine-Tuned Model ---")
evaluation_results = trainer.evaluate()
print(evaluation_results)

--- Evaluating Fine-Tuned Model ---


{'eval_loss': 1.1162684131704737e-05, 'eval_runtime': 1.8064, 'eval_samples_per_second': 55.359, 'eval_steps_per_second': 7.197, 'epoch': 1.0}


In [32]:
new_sentence = "This was a fantastic film, I highly recommend it."
inputs = tokenizer(new_sentence, return_tensors="pt").to(device)

# Use the model directly
with torch.no_grad():
    logits = model(**inputs).logits

predicted_class_id = torch.argmax(logits, dim=1).item()
prediction = model.config.id2label[predicted_class_id]

print(f"Prediction: {prediction}")

Prediction: POSITIVE
