# Demonstrating Wanda (Pruning by Weights and Activations) on GPT‑2

In this notebook we will:

1. Install dependencies  
2. Load a small LLM (`gpt2`) and keep a copy for comparison  
3. Run a small **calibration set** through the model to collect input activations for one Linear layer  
4. Compute **activation norms** and the Wanda **importance scores**:

   $$
   S_{ij} = \lvert W_{ij}\rvert \;\times\; \lVert X_{j}\rVert_2
   $$

   on a **per‑output** (per‑row) basis  
5. Zero out the bottom *s*% of scores **within each output neuron**  
6. Compare parameter counts and text generation before vs. after pruning  

This illustrates how Wanda leverages both weight magnitudes and activation statistics to select which weights to prune.
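
As a small worked illustration (the numbers are made up for this example and are not taken from GPT-2), consider one output neuron $i$ with two incoming weights:

$$
W_{i1} = 0.9, \quad W_{i2} = -0.2, \qquad \lVert X_{1}\rVert_2 = 0.1, \quad \lVert X_{2}\rVert_2 = 1.0
$$

$$
S_{i1} = \lvert 0.9\rvert \times 0.1 = 0.09, \qquad S_{i2} = \lvert -0.2\rvert \times 1.0 = 0.20
$$

Magnitude pruning alone would drop the second weight, since $0.2 < 0.9$. Wanda drops the first instead, because its input channel carries very little activation and so contributes less to the output.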

In [1]:
# 1. Install Dependencies
!pip install transformers torch --quiet


## 2. Imports and Utilities

We bring in PyTorch, Hugging Face Transformers, and define a helper to count nonzero weights.


In [2]:
import copy
import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer

def count_nonzero(model):
    total, nonzero = 0, 0
    for p in model.parameters():
        total += p.numel()
        nonzero += (p.data != 0).sum().item()
    return total, nonzero


## 3. Load GPT‑2 and Make a Copy

We load `gpt2`, move it to device, and clone it for “before‐pruning” generation.


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer     = GPT2Tokenizer.from_pretrained("gpt2")
model         = GPT2LMHeadModel.from_pretrained("gpt2").to(device).eval()
model_before  = copy.deepcopy(model)



## 4. Choose a Target Linear Layer and Collect Activations

We’ll prune the second feed‑forward layer (`c_proj`) in the first transformer block; its input has 3072 channels.  
To collect its **input activations**, we register a forward hook.
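
To see what a forward hook receives, here is a minimal sketch on a toy `nn.Linear` layer. The layer, its sizes, and the names `toy_layer`, `capture_input`, and `captured` are assumptions made for illustration; they are not part of the GPT-2 pipeline. The hook gets the module's positional inputs as a tuple, so the activation tensor is `inputs[0]`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the layer we want to observe
toy_layer = nn.Linear(4, 2)
captured = []

def capture_input(module, inputs, output):
    # `inputs` is a tuple of positional arguments; the activations come first
    x = inputs[0]                                   # [B, L, C_in]
    captured.append(x.detach().reshape(-1, x.shape[-1]))

handle = toy_layer.register_forward_hook(capture_input)
with torch.no_grad():
    toy_layer(torch.randn(3, 5, 4))                 # batch of 3, length 5
handle.remove()                                     # detach the hook when done

toy_X = torch.cat(captured, dim=0)                  # [B * L, C_in]
print(toy_X.shape)
```

Keeping the handle returned by `register_forward_hook` and calling `handle.remove()` is the public way to detach a hook once calibration is finished.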


In [14]:
# 1. Remove all forward hooks on both c_fc and c_proj
model.transformer.h[0].mlp.c_fc._forward_hooks.clear()
model.transformer.h[0].mlp.c_proj._forward_hooks.clear()

# 2. Reset the activation list
collected_X = []

# 3. Register only the c_proj hook
def hook_fn(module, inp, outp):
    x = inp[0]                  # [B, L, 3072]
    B, L, C_in = x.shape
    collected_X.append(x.detach().reshape(-1, C_in))

hook = model.transformer.h[0].mlp.c_proj.register_forward_hook(hook_fn)


## 5. Run Calibration Texts

We feed a few example sentences to accumulate activation statistics.


In [15]:
calibration_texts = [
    "The quick brown fox jumps over the lazy dog.",
    "In a distant future, AI and humans will collaborate closely.",
    "Once upon a time, language models could be pruned effectively."
]

for txt in calibration_texts:
    inputs = tokenizer(txt, return_tensors="pt").to(device)
    with torch.no_grad():
        _ = model(**inputs)

# Combine into [total_tokens, 3072]
X_all = torch.cat(collected_X, dim=0)
print("Collected activations shape:", X_all.shape)  # -> [N, 3072]


Collected activations shape: torch.Size([35, 3072])


## 6. Compute Activation Norms

From the concatenated activations `X_all`, compute the ℓ₂ norm of each input channel **j**, taken over all calibration tokens.
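
Storing every activation is fine for three short sentences, but memory grows with the calibration set. A possible alternative, sketched here under the assumption that only the per-channel norms are needed, is to accumulate squared sums batch by batch. The random tensors and the names `batches`, `sq_sum`, and `streamed_norm` are illustrative stand-ins.

```python
import torch

torch.manual_seed(0)

# Two hypothetical calibration batches, already flattened to [tokens, C_in]
batches = [torch.randn(5, 8), torch.randn(7, 8)]

# Running sum of squares per input channel
sq_sum = torch.zeros(8)
for xb in batches:
    sq_sum += xb.pow(2).sum(dim=0)

streamed_norm = sq_sum.sqrt()                       # [C_in]

# Reference: concatenate everything, then take the norm
full_norm = torch.norm(torch.cat(batches, dim=0), p=2, dim=0)
print(torch.allclose(streamed_norm, full_norm, atol=1e-5))
```

Both routes give the same vector because the squared ℓ₂ norm of a column is just the sum of its squared entries, which can be added up in any order.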


In [16]:
# Compute per‑channel L2 norm: shape (C_in,)
X_norm = torch.norm(X_all, p=2, dim=0)
print("X_norm shape (should equal in‑features of layer):", X_norm.shape)


X_norm shape (should equal in‑features of layer): torch.Size([3072])


## 7. Compute Wanda Importance Scores

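For the chosen layer’s weight matrix \( W \) of shape \( (\text{C}_{\text{out}},\, \text{C}_{\text{in}}) \), compute:

$$
S = |W| \times X_{\text{norm}}
$$

where the product is element‑wise and \( X_{\text{norm}} \) is broadcast across the rows of \( W \). Note that Hugging Face’s GPT‑2 implements this layer as `Conv1D`, which stores its weight as \( (\text{C}_{\text{in}},\, \text{C}_{\text{out}}) = (3072,\, 768) \), so the weight must be transposed before applying the formula.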
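
The broadcasting step can be checked on a tiny example. This is a sketch with made-up numbers; `toy_W`, `toy_norm`, and `stored_W` are hypothetical names. It also shows the same scores computed from a weight stored in the transposed `[C_in, C_out]` layout.

```python
import torch

# Toy weight in [C_out, C_in] = [2, 3] layout and per-channel norms [C_in]
toy_W = torch.tensor([[ 0.9, -0.2,  0.5],
                      [-0.1,  0.4, -0.3]])
toy_norm = torch.tensor([0.1, 1.0, 2.0])

# [2, 3] * [1, 3]: each column j is scaled by the norm of input channel j
toy_S = toy_W.abs() * toy_norm.unsqueeze(0)
print(toy_S)

# Same weight stored as [C_in, C_out]: scale rows instead, then transpose back
stored_W = toy_W.t()                                # [3, 2]
S_from_stored = (stored_W.abs() * toy_norm.unsqueeze(1)).t()
print(torch.allclose(toy_S, S_from_stored))
```

Either route works as long as the norm vector is aligned with the input-channel axis of the weight.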


In [26]:
# Step 7: Compute importance scores S

layer = model.transformer.h[0].mlp.c_proj

# 1. Extract the weight matrix and its magnitude.
#    GPT-2's Conv1D stores weights as [C_in, C_out] = [3072, 768],
#    so transpose to the [C_out, C_in] layout used in the formula.
W = layer.weight.data.t().clone()  # [C_out, C_in] = [768, 3072]
absW = W.abs()                     # [768, 3072]

# 2. Broadcast the activation norms across rows
#    X_norm is [3072], so make it [1, 3072] and multiply
X_norm_row = X_norm.unsqueeze(0)   # [1, 3072]
print(X_norm_row.shape)
print(absW.shape)
S = absW * X_norm_row              # [768, 3072]

print(f"S shape: {S.shape}  (should match W shape [768, 3072])")


## 8. Prune by Zeroing Bottom-s% Scores Per-Output

We’ll prune 30% (`s=0.3`) of the **smallest** scores **within each row** (output neuron).
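
Per-row selection can be sketched on a toy score matrix. The numbers, the 50% sparsity, and the names `toy_scores`, `toy_weights`, and `keep` are assumptions for illustration only. `torch.topk` with `largest=False` returns the indices of the `k` smallest scores in each row, and `scatter_` marks them in the mask in one call.

```python
import torch

# Toy scores and weights in [C_out, C_in] = [2, 4] layout
toy_scores = torch.tensor([[0.09, 0.20, 1.00, 0.50],
                           [0.60, 0.01, 0.40, 0.30]])
toy_weights = torch.tensor([[ 0.9, -0.2,  0.5,  0.25],
                            [-0.6,  0.1,  0.2, -0.15]])

sparsity = 0.5
k = int(toy_scores.shape[1] * sparsity)             # weights to prune per row

# Indices of the k smallest scores within each row
_, prune_idx = torch.topk(toy_scores, k, dim=1, largest=False)

# True = keep; set the selected positions to False
keep = torch.ones_like(toy_scores, dtype=torch.bool)
keep.scatter_(1, prune_idx, False)

pruned = toy_weights * keep
print(keep)
print(pruned)
```

Because the selection is done row by row, every output neuron loses exactly `k` weights, which differs from picking one global threshold over the whole matrix.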


In [20]:
# Step 8: Zero out the lowest‑importance weights per row

s        = 0.3
C_out, C_in = S.shape
k        = int(C_in * s)  # number of weights to prune per row

# Create a boolean mask of ones (True = keep)
mask = torch.ones_like(S, dtype=torch.bool)

# For each output row i, find its k smallest scores and set mask[i, idx] = False
_, idx = torch.topk(S, k, largest=False, dim=1)  # [768, k]
mask.scatter_(1, idx, False)

# Apply the mask in‑place to the layer’s weights.
# S and mask are [C_out, C_in], but Conv1D stores its weight as [C_in, C_out],
# so transpose the mask before multiplying.
layer.weight.data *= mask.t()


## 9. Cleanup Hook & Compare Sparsity

Remove the forward hook and report total vs. nonzero parameter counts before and after pruning.
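
Since only one layer is pruned, the whole-model figures barely move: 30% of a 768 × 3072 matrix is about 0.7 million weights out of GPT‑2's roughly 124 million parameters, so overall sparsity stays well under 1%. A per-layer check is often more informative. The helper below is a minimal sketch; `weight_sparsity` and the toy layer are hypothetical names introduced here, not part of the notebook.

```python
import torch
import torch.nn as nn

def weight_sparsity(weight: torch.Tensor) -> float:
    """Fraction of entries in `weight` that are exactly zero."""
    return (weight == 0).float().mean().item()

# Hypothetical toy layer standing in for the pruned c_proj
toy_layer = nn.Linear(4, 2)
with torch.no_grad():
    toy_layer.weight.fill_(1.0)
    toy_layer.weight[:, :2] = 0.0                   # zero half of each row

toy_sparsity = weight_sparsity(toy_layer.weight)
print(f"Layer sparsity: {100 * toy_sparsity:.1f}%")
```

On the notebook's model the same helper could be called as `weight_sparsity(model.transformer.h[0].mlp.c_proj.weight)`, which should come out close to 30% after pruning, assuming the layer had no exact zeros beforehand.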


In [None]:
# Step 9: Remove hook and compare sparsity

hook.remove()

before_total, before_nonzero = count_nonzero(model_before)
after_total,  after_nonzero  = count_nonzero(model)

print(f"Total parameters:       {before_total:,}")
print(f"Nonzero before pruning: {before_nonzero:,} ({100*before_nonzero/before_total:.1f}% dense)")
print(f"Nonzero after pruning:  {after_nonzero:,} ({100*after_nonzero/after_total:.1f}% dense)")
print(f"Overall sparsity:       {100*(1 - after_nonzero/after_total):.1f}%")


## 10. Compare Text Generation

Prompt both the **original** and the **Wanda‑pruned** model on the same input to observe any differences.


In [None]:
# Step 10: Generate and compare outputs

prompt    = "In a world where LLMs are optimized, we"
inputs    = tokenizer(prompt, return_tensors="pt").to(device)

with torch.no_grad():
    out_before = model_before.generate(**inputs, max_new_tokens=30)
    out_after  = model.generate(**inputs, max_new_tokens=30)

print("=== Original GPT‑2 ===")
print(tokenizer.decode(out_before[0], skip_special_tokens=True))
print("\n=== Wanda‑Pruned GPT‑2 ===")
print(tokenizer.decode(out_after[0], skip_special_tokens=True))
