# FinCompress — Notebook 04: Structured Pruning

**RUN ON: Colab/GPU**

## Structured vs. Unstructured Pruning

Pruning removes unimportant weights from a neural network. There are two families:

**Unstructured pruning** sets individual weights to zero. The result is a *sparse* weight matrix with the same shape. A standard dense matrix multiplication does not skip zero entries, so it still performs the full computation. A real latency speedup requires sparse kernels (or hardware support for specific sparsity patterns), which the default dense PyTorch ops do not use.
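A minimal sketch using PyTorch's built-in `torch.nn.utils.prune` (L1-magnitude, unstructured) illustrates the point: after pruning half the weights, the layer keeps its full shape, so the forward pass is still a full dense matmul. The layer sizes are arbitrary illustration values.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(128, 64)

# Zero out the 50% of weights with the smallest absolute value (unstructured, L1 magnitude).
prune.l1_unstructured(layer, name="weight", amount=0.5)
prune.remove(layer, "weight")  # bake the mask into layer.weight permanently

sparsity = (layer.weight == 0).float().mean().item()
print("weight shape:", tuple(layer.weight.shape))  # (64, 128): unchanged
print(f"sparsity: {sparsity:.2f}")                 # 0.50

# The forward pass still multiplies by the full (64, 128) matrix, zeros included.
x = torch.randn(4, 128)
y = layer(x)
print("output shape:", tuple(y.shape))             # (4, 64)
```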

**Structured pruning** removes entire structural units; in transformers, this means whole attention heads or entire FFN neurons. The result is a smaller *dense* matrix, which runs faster with ordinary dense kernels and needs no specialized sparse support.
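The sketch below removes half the intermediate neurons of a toy two-layer FFN by slicing rows of the first linear layer and the matching columns of the second, then checks that the smaller dense FFN computes exactly what the original FFN computes with those neurons masked to zero. The weight-norm importance score is a hypothetical stand-in, and ReLU replaces BERT's GELU for simplicity (any elementwise activation behaves the same way here).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 16, 64
fc1 = nn.Linear(d_model, d_ff)  # up-projection: one output row per intermediate neuron
fc2 = nn.Linear(d_ff, d_model)  # down-projection: one input column per intermediate neuron

# Hypothetical importance: L2 norm of each neuron's incoming weights; keep the top half.
scores = fc1.weight.norm(dim=1)                                   # [d_ff]
keep = scores.argsort(descending=True)[: d_ff // 2].sort().values

# Build smaller dense layers that contain only the kept neurons.
small_fc1 = nn.Linear(d_model, len(keep))
small_fc2 = nn.Linear(len(keep), d_model)
with torch.no_grad():
    small_fc1.weight.copy_(fc1.weight[keep])
    small_fc1.bias.copy_(fc1.bias[keep])
    small_fc2.weight.copy_(fc2.weight[:, keep])
    small_fc2.bias.copy_(fc2.bias)                                # output bias is unaffected

# Reference: the original FFN with the dropped neurons masked to zero.
mask = torch.zeros(d_ff)
mask[keep] = 1.0
x = torch.randn(8, d_model)
with torch.no_grad():
    masked_out = fc2(torch.relu(fc1(x)) * mask)
    pruned_out = small_fc2(torch.relu(small_fc1(x)))

print(tuple(small_fc1.weight.shape), tuple(small_fc2.weight.shape))  # (32, 16) (16, 32)
print(torch.allclose(masked_out, pruned_out, atol=1e-5))
```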

## Why BERT is Overparameterized for Downstream Tasks

BERT was designed for general language understanding: BERT-base has 12 layers with 12 attention heads each (144 heads in total) so that one pre-trained model can serve many different tasks. For a specific 3-class sentiment task on financial text, many of these heads are likely redundant. Prior work reports that a substantial fraction of heads, roughly 30-50% depending on model and task, can be removed with minimal accuracy loss (Michel et al., NeurIPS 2019).
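To put the 30-50% figure in concrete terms, the arithmetic below counts heads and per-head attention parameters. It assumes BERT-base dimensions (12 layers, 12 heads, hidden size 768); check `teacher.config` for the actual checkpoint.

```python
# Assumed BERT-base dimensions
num_layers, num_heads, hidden = 12, 12, 768
head_dim = hidden // num_heads                      # 64

total_heads = num_layers * num_heads                # 144
# Per head: Q, K, V projections (hidden x head_dim weights + head_dim bias each)
# plus that head's slice of the output projection (head_dim x hidden; its bias is shared).
params_per_head = 3 * (hidden * head_dim + head_dim) + head_dim * hidden

for pct in (30, 50):
    heads_removed = total_heads * pct // 100
    print(f"{pct}% of heads = {heads_removed} heads, "
          f"~{heads_removed * params_per_head / 1e6:.1f}M attention parameters")
# 30% of heads = 43 heads, ~8.5M attention parameters
# 50% of heads = 72 heads, ~14.2M attention parameters
```

Against BERT-base's roughly 110M parameters, head pruning alone leaves the embeddings and FFN blocks in place; FFN neurons are the other structured unit described earlier.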

## Attention Head Importance Score

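We use an entropy-based importance score: a head with focused attention (low entropy) is treated as more important than one that attends nearly uniformly (high entropy). For each head, the entropy of its attention distribution is divided by its maximum possible value (the log of the sequence length) and averaged over query positions and examples. Score = 1 - normalized_entropy. It is cheap to compute and needs no gradients, but it is a heuristic proxy and need not agree with gradient-based importance scores such as those of Michel et al.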
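A minimal sketch of the score, assuming attention tensors shaped `[batch, heads, seq, seq]` as returned by Hugging Face models with `output_attentions=True`. The helper name `head_importance` is hypothetical. It normalizes by the log of the number of *real* tokens and averages only over real query positions. A toy batch checks the two extremes: a one-hot head should score about 1 and a uniform head about 0.

```python
import torch

def head_importance(attn: torch.Tensor, attention_mask: torch.Tensor,
                    eps: float = 1e-9) -> torch.Tensor:
    """Hypothetical helper: 1 - H(attention) / log(n_real_tokens), per head.

    attn:           [batch, heads, seq, seq] attention probabilities (rows sum to 1)
    attention_mask: [batch, seq], 1 for real tokens and 0 for padding
    returns:        [heads], averaged over real query positions and over the batch
    """
    mask = attention_mask.float()
    lengths = mask.sum(dim=1)                                       # [batch]
    ent = -(attn * (attn + eps).log()).sum(dim=-1)                  # [batch, heads, seq]
    ent = (ent * mask[:, None, :]).sum(dim=-1) / lengths[:, None]   # mean over real queries
    max_ent = lengths.clamp(min=2).log()                            # uniform over real keys
    return (1 - ent / max_ent[:, None]).mean(dim=0)

# Toy batch: one sequence with 4 real tokens + 2 padding positions, 2 heads.
seq, n_real = 6, 4
mask = torch.tensor([[1, 1, 1, 1, 0, 0]])
attn = torch.zeros(1, 2, seq, seq)
attn[0, 0, :, 0] = 1.0                  # head 0: every query attends only to token 0
attn[0, 1, :, :n_real] = 1.0 / n_real   # head 1: uniform over the real tokens
scores = head_importance(attn, mask)
print(scores)                           # approximately [1.0, 0.0]
```

Normalizing by the padded length instead (log 6 here) would give the uniform head a score of 1 - log 4 / log 6 ≈ 0.23 rather than 0, so heads in batches with more padding would look artificially important.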

In [None]:
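import os, sys
# Mount Google Drive so the project folder is reachable (prints a notice if already mounted)
from google.colab import drive
drive.mount('/content/drive')

PROJECT_PATH = '/content/drive/MyDrive/fincompress'
os.chdir(PROJECT_PATH)            # relative 'fincompress/...' paths below resolve from here
sys.path.insert(0, PROJECT_PATH)  # make the fincompress package importable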

In [None]:
# Run iterative pruning with recovery fine-tuning
!python -m fincompress.pruning.prune_finetune
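The script's internals are not shown in this notebook. As a rough sketch of the kind of schedule an iterative head pruner can follow, the hypothetical `schedule_head_pruning` below removes the globally lowest-scoring heads in fixed cumulative increments and never removes the last head of a layer. For simplicity it uses one fixed importance matrix; an iterative method would typically fine-tune, evaluate, and possibly recompute scores between rounds.

```python
import numpy as np

def schedule_head_pruning(importance, step_pct=10, max_pct=50):
    """Hypothetical greedy schedule: yield (cumulative_pct, {layer: [heads pruned this round]}).

    Heads are removed globally in ascending order of importance, but the last
    remaining head of a layer is never removed.
    """
    n_layers, n_heads = importance.shape
    total = n_layers * n_heads
    order = [divmod(int(i), n_heads) for i in np.argsort(importance, axis=None, kind="stable")]
    remaining = [n_heads] * n_layers
    pruned = set()
    for pct in range(step_pct, max_pct + 1, step_pct):
        target = total * pct // 100
        this_round = {}
        for layer, head in order:
            if len(pruned) >= target:
                break
            if (layer, head) in pruned or remaining[layer] == 1:
                continue  # already removed, or last head left in this layer
            pruned.add((layer, head))
            remaining[layer] -= 1
            this_round.setdefault(layer, []).append(head)
        yield pct, this_round
        # A real loop would run recovery fine-tuning and evaluation here.

toy = np.arange(8, dtype=float).reshape(2, 4)   # 2 layers x 4 heads, made-up scores
for pct, heads in schedule_head_pruning(toy, step_pct=25, max_pct=75):
    print(pct, heads)
# 25 {0: [0, 1]}
# 50 {0: [2], 1: [0]}
# 75 {1: [1, 2]}
```

Each per-round dict has the `{layer_index: [head_indices]}` form accepted by Hugging Face `PreTrainedModel.prune_heads`; how `fincompress.pruning.prune_finetune` actually selects and removes heads may differ.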

In [None]:
# Head importance heatmap BEFORE pruning
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import pandas as pd

# Load teacher. Eager attention is requested because fused SDPA/flash kernels
# do not return the attention weights that output_attentions=True needs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained('fincompress/checkpoints/teacher/tokenizer')
teacher = AutoModelForSequenceClassification.from_pretrained(
    'fincompress/checkpoints/teacher', attn_implementation='eager').to(device)
teacher.eval()

# Minimal dataset wrapper for calibration batches
class DS(Dataset):
    def __init__(self, df, tok, ml):
        self.texts = df['text'].tolist(); self.labels = df['label'].tolist()
        self.tok = tok; self.ml = ml
    def __len__(self): return len(self.texts)
    def __getitem__(self, i):
        e = self.tok(self.texts[i], max_length=self.ml, padding='max_length', truncation=True, return_tensors='pt')
        return {'input_ids': e['input_ids'].squeeze(0), 'attention_mask': e['attention_mask'].squeeze(0),
                'token_type_ids': e.get('token_type_ids', torch.zeros(1, self.ml, dtype=torch.long)).squeeze(0),
                'label': torch.tensor(self.labels[i])}

df_val = pd.read_csv('fincompress/data/val.csv')
val_loader = DataLoader(DS(df_val, tokenizer, 128), batch_size=32)

# Importance = 1 - H(attention) / log(n_real_tokens), averaged over real (non-padding)
# query positions, then over sequences and calibration batches.
num_batches = 30  # calibration budget
eps = 1e-9
importance = torch.zeros(teacher.config.num_hidden_layers, teacher.config.num_attention_heads)
count = 0
with torch.no_grad():
    for bi, batch in enumerate(val_loader):
        if bi >= num_batches: break
        attn_mask = batch['attention_mask'].to(device)
        out = teacher(batch['input_ids'].to(device), attention_mask=attn_mask, output_attentions=True)
        mask = attn_mask.float()                                  # [batch, seq]
        lengths = mask.sum(dim=1)                                 # real tokens per sequence
        max_ent = lengths.clamp(min=2).log()                      # entropy of uniform attention over real tokens
        for li, aw in enumerate(out.attentions):                  # aw: [batch, heads, seq, seq]
            ent = -(aw * (aw + eps).log()).sum(dim=-1)            # [batch, heads, seq], per query position
            ent = (ent * mask[:, None, :]).sum(dim=-1) / lengths[:, None]  # mean over real queries
            norm_ent = ent / max_ent[:, None]
            importance[li] += (1 - norm_ent).mean(dim=0).cpu()
        count += 1
importance /= count

fig, ax = plt.subplots(figsize=(12, 5))
sns.heatmap(importance.numpy(), annot=True, fmt='.2f', cmap='YlOrRd', ax=ax,
            xticklabels=[f'H{i}' for i in range(importance.shape[1])],
            yticklabels=[f'L{i}' for i in range(importance.shape[0])])
ax.set_xlabel('Head Index')
ax.set_ylabel('Layer Index')
ax.set_title('Attention Head Importance Scores (pre-pruning)\n(higher/redder = more important; pale yellow = prime pruning candidate)')
plt.tight_layout()
plt.show()

In [None]:
# Pruning curve plot
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

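curve_df = pd.read_csv('fincompress/results/pruning_curve.csv')
curve_df = curve_df.sort_values('heads_pruned_pct').reset_index(drop=True)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(curve_df['heads_pruned_pct'], curve_df['val_f1'], 'o-', color='steelblue', linewidth=2, label='Val macro F1')

# Find the 'cliff': first sparsity level at or after the peak where val F1 is more than 0.02 below the peak
peak_idx = curve_df['val_f1'].idxmax()
peak_f1 = curve_df.loc[peak_idx, 'val_f1']
after_peak = curve_df.loc[peak_idx:]
cliff_row = after_peak[after_peak['val_f1'] < peak_f1 - 0.02]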
if not cliff_row.empty:
    cliff_x = cliff_row.iloc[0]['heads_pruned_pct']
    ax.axvline(x=cliff_x, color='red', linestyle='--', label=f'Accuracy cliff ({cliff_x:.0f}%)')
    ax.axvspan(cliff_x, curve_df['heads_pruned_pct'].max() + 5, alpha=0.15, color='red')
    ax.annotate('Accuracy cliff', xy=(cliff_x, peak_f1 - 0.015), xytext=(cliff_x + 3, peak_f1 - 0.005),
                arrowprops=dict(arrowstyle='->', color='red'), color='red')

ax.set_xlabel('Cumulative Heads Pruned (%)')
ax.set_ylabel('Val Macro F1')
ax.set_title('Validation F1 vs. Attention Head Sparsity\n(red region = accuracy cliff)')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Pruning summary: val F1 at each sparsity level, peak vs. final
import pandas as pd

curve_df = pd.read_csv('fincompress/results/pruning_curve.csv')
curve_df = curve_df.sort_values('heads_pruned_pct').reset_index(drop=True)
print('Pruning Curve Summary')
print('=' * 50)
print(curve_df.to_string(index=False))
print()
print(f'Peak val F1:  {curve_df["val_f1"].max():.4f} (at {curve_df.loc[curve_df["val_f1"].idxmax(), "heads_pruned_pct"]:.0f}% sparsity)')
print(f'Final val F1: {curve_df["val_f1"].iloc[-1]:.4f} (at {curve_df["heads_pruned_pct"].iloc[-1]:.0f}% sparsity)')
print(f'Total F1 drop: {curve_df["val_f1"].max() - curve_df["val_f1"].iloc[-1]:+.4f}')

## Interpretation: What Does It Mean That Certain Heads Are Prunable?

If 30-50% of attention heads can be removed with minimal accuracy loss, those heads were most likely:

1. **Redundant**: several heads in the same layer attend to similar patterns, so one of them is enough
2. **Task-irrelevant**: useful for general language understanding but not for sentiment classification
3. **Overfit to pre-training**: specialized for the masked language modeling objective in ways that don't transfer to sentiment
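One way to probe the redundancy explanation is to compare heads' attention maps directly. The hypothetical `head_similarity` helper below computes pairwise cosine similarity between the flattened attention maps of the heads in one layer; values near 1 flag near-duplicate heads. It is a diagnostic sketch, not the scoring used in this notebook, and it ignores padding for simplicity (with padded batches, padded positions should be masked out first).

```python
import torch
import torch.nn.functional as F

def head_similarity(attn: torch.Tensor) -> torch.Tensor:
    """Hypothetical diagnostic: pairwise cosine similarity between the heads of one layer.

    attn:    [batch, heads, seq, seq] attention probabilities for a single layer
    returns: [heads, heads]; values near 1 mean near-identical attention patterns
    """
    b, h, s, _ = attn.shape
    flat = attn.permute(1, 0, 2, 3).reshape(h, b * s * s)  # one long vector per head
    flat = F.normalize(flat, dim=1)                        # unit L2 norm per head
    return flat @ flat.T

attn = torch.zeros(1, 3, 3, 3)   # batch=1, 3 heads, seq=3
attn[0, 0, :, 0] = 1.0           # head 0: every query attends to token 0
attn[0, 1, :, 0] = 1.0           # head 1: same pattern as head 0
attn[0, 2, :, 2] = 1.0           # head 2: attends to token 2 only
sim = head_similarity(attn)
print(sim)  # heads 0 and 1 have similarity 1; head 2 has similarity 0 with both
```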

The heads that survive pruning are often expected to be (compare against the importance heatmap above):
- **Lower layers**: heads that track syntactic patterns (subject-verb agreement, negation scope) that sentiment interpretation depends on
- **Final-layer heads**: task-relevant attention patterns that directly shape the [CLS] representation used for classification
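To check this layer-wise pattern on a given run, the hypothetical helper below counts how many heads each layer keeps when the globally lowest-scoring `prune_pct`% of heads are removed in one shot. This single global threshold is a simplification and may not match the iterative schedule used by `prune_finetune`.

```python
import numpy as np

def surviving_heads_per_layer(importance: np.ndarray, prune_pct: int) -> np.ndarray:
    """Hypothetical helper: heads kept per layer after removing the lowest prune_pct% globally."""
    n_layers, n_heads = importance.shape
    k = n_layers * n_heads * prune_pct // 100
    to_prune = np.argsort(importance, axis=None, kind="stable")[:k]  # flat indices, least important first
    keep = np.ones(importance.size, dtype=bool)
    keep[to_prune] = False
    return keep.reshape(n_layers, n_heads).sum(axis=1)

# Made-up 3-layer x 3-head importance matrix
toy = np.array([[0.10, 0.20, 0.90],
                [0.80, 0.70, 0.60],
                [0.30, 0.95, 0.05]])
survivors = surviving_heads_per_layer(toy, 50)
print(survivors)  # [1 3 1]
```

In this notebook, `surviving_heads_per_layer(importance.numpy(), 50)` would apply the same count to the entropy scores from the heatmap cell.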

This pattern is consistent with BERT-style models being significantly overparameterized for narrow downstream tasks, which is a key motivation for the entire compression field.