# Experiment 1: Computing Sensitivity Layer-by-Layer

This experiment computes per-layer sensitivity using two methods: the Fisher Information Matrix and the L1 norm of the weights (magnitude). Magnitude-based scoring is a blunt instrument that barely distinguishes one layer from another, whereas Fisher Information shows a high dynamic range, with sharp sensitivity peaks at specific weight matrices across the network.
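To make the contrast concrete, here is a minimal pure-Python sketch of the two scores on a hypothetical two-weight linear model. It assumes the Fisher score is the common diagonal empirical Fisher (the mean of squared log-likelihood gradients) and that the magnitude score is the L1 norm. `scripts/01_compute_sensitivity.py` may estimate, aggregate, or normalize these differently. The names `l1_score` and `diag_fisher` are illustrative only.

```python
# Toy model: y ≈ w[0]*x[0] + w[1]*x[1] with a unit-variance Gaussian likelihood.
# This is an illustrative assumption, not the model or estimator used in TaskQuant.

def l1_score(weights):
    """Magnitude sensitivity: L1 norm of a group of weights."""
    return sum(abs(w) for w in weights)

def diag_fisher(weights, samples):
    """Diagonal empirical Fisher: per-weight mean of squared log-likelihood gradients."""
    fisher = [0.0] * len(weights)
    for x, y in samples:
        residual = y - sum(w * xi for w, xi in zip(weights, x))
        for i, xi in enumerate(x):
            grad = residual * xi          # d/dw_i of log N(y | w.x, 1)
            fisher[i] += grad ** 2
    return [f / len(samples) for f in fisher]

weights = [2.0, 2.0]                      # identical magnitudes
samples = [([1.0, 0.0], 3.0),             # input feature x[1] is always zero,
           ([2.0, 0.0], 3.0)]             # so w[1] never affects the loss

print("L1 per weight:", [l1_score([w]) for w in weights])   # [2.0, 2.0]
print("Fisher per weight:", diag_fisher(weights, samples))  # [2.5, 0.0]
```

Both weights have the same magnitude, yet only the one that actually influences the output receives any Fisher mass. The magnitude score cannot make that distinction.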

---

## Magnitude vs Fisher Behavior

In `magnitude_gsm8k.json`, the sensitivity scores for MLP layers cluster tightly. For example:

- Layer 4: 440,320  
- Layer 5: 430,080  
- Layer 6: 417,792  

These values differ by only about 5% from largest to smallest, so magnitude-based scoring does not meaningfully separate critical from non-critical components.
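The tightness of this cluster can be quantified with a few lines of arithmetic on the three scores quoted above. The dictionary simply restates them; the variable names are illustrative.

```python
# MLP magnitude scores for layers 4-6, as listed above from magnitude_gsm8k.json.
magnitude = {4: 440_320, 5: 430_080, 6: 417_792}

hi, lo = max(magnitude.values()), min(magnitude.values())
relative_spread = (hi - lo) / hi   # gap between extremes as a fraction of the max
dynamic_range = hi / lo            # max/min ratio across the three layers

print(f"relative spread: {relative_spread:.1%}")   # 5.1%
print(f"max/min ratio:   {dynamic_range:.3f}")     # 1.054
```

A max/min ratio this close to 1 means the score barely ranks one layer above another. Compare it with the Fisher values that follow.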

In contrast, `fisher_gsm8k.json` exhibits large variations in importance. For instance:

- Layer 1 (down_proj): 8.47  
- Layer 0 (gate_proj): 0.33  

This represents a roughly 26× difference in sensitivity (8.47 / 0.33 ≈ 25.7).

---

## Peak Sensitivity Values

The highest non-embedding Fisher values are:

| Dataset            | Most Sensitive Weight Matrix | Fisher Score |
|--------------------|------------------------------|-------------:|
| GSM8K (Math)       | `layers.0.self_attn.v_proj`  |         22.0 |
| WikiText (English) | `layers.0.self_attn.v_proj`  |        431.0 |

In both datasets, `model.layers.0.self_attn.v_proj` is the single most sensitive weight matrix.
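Finding this entry in a score file reduces to filtering out embedding parameters and taking the maximum. The sketch below assumes the JSON output is a flat `{parameter_name: score}` object and that names follow the Hugging Face Llama style (`model.embed_tokens`, `mlp.down_proj`). Neither assumption is confirmed here. `load_scores` and `top_non_embedding` are hypothetical helpers, and the embedding value in the example is a placeholder rather than a measured score.

```python
import json
from pathlib import Path

def load_scores(path):
    """Assumes the file is a flat {parameter_name: score} JSON object (hypothetical schema)."""
    return json.loads(Path(path).read_text())

def top_non_embedding(scores, skip=("embed",)):
    """Return the (name, score) pair with the largest score, ignoring any name
    that contains one of the substrings in `skip`."""
    kept = {name: s for name, s in scores.items()
            if not any(token in name for token in skip)}
    return max(kept.items(), key=lambda item: item[1])

# Illustrative input: values quoted in this write-up, plus a placeholder embedding entry.
example = {
    "model.embed_tokens": 1e6,                 # placeholder, not a measured score
    "model.layers.0.self_attn.v_proj": 21.97,
    "model.layers.0.self_attn.q_proj": 0.04,
    "model.layers.0.mlp.gate_proj": 0.33,
    "model.layers.1.mlp.down_proj": 8.47,
}
print(top_non_embedding(example))   # ('model.layers.0.self_attn.v_proj', 21.97)

# With real output (file location and schema are assumptions):
# print(top_non_embedding(load_scores("fisher_gsm8k.json")))
```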

---

## Dataset-Dependent Interpretation

Early MLP layers are disproportionately important in mathematical reasoning relative to general language modeling.

Example: Layer 1 MLP `down_proj`

- Math context: score = **8.47**  
  - Relative to max score (22.0): ~**38%**
- English context: score = **18.57**  
  - Relative to max score (431.0): ~**4%**

Sensitivity structure is therefore task-dependent: mathematical reasoning appears to rely on early MLP pathways more heavily than general English text does.
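The percentages above come from dividing each score by the peak non-embedding score of the same run. Raw Fisher magnitudes differ widely between datasets, so only these within-run fractions are comparable. A short sketch of that normalization, using only the numbers quoted in this section (the dictionary layout is illustrative):

```python
# (layer-1 MLP down_proj score, peak non-embedding score) per dataset, from above.
runs = {"GSM8K": (8.47, 22.0), "WikiText": (18.57, 431.0)}

# Raw Fisher scores are not comparable across datasets (the peaks differ ~20x),
# so express each score as a fraction of its own run's peak.
share_of_peak = {name: score / peak for name, (score, peak) in runs.items()}

for name, share in share_of_peak.items():
    print(f"{name}: {share:.1%} of peak")
```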

---

## Attention Projection Sensitivity

Across both datasets, Value projections (`v_proj`) dominate sensitivity, while Query projections (`q_proj`) are nearly irrelevant for quantization stability.

Example from `fisher_gsm8k.json`:

- Layer 0 v_proj: **21.97** (critical)
- Layer 0 q_proj: **0.04** (negligible)
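To check the `v_proj` versus `q_proj` pattern across all layers rather than one, you can sum the scores by projection type. This is a sketch assuming the same flat `{parameter_name: score}` layout as above. It strips an optional `.weight` suffix in case the real keys include one, which is an assumption. `totals_by_projection` is a hypothetical helper, and `str.removesuffix` requires Python 3.9+.

```python
from collections import defaultdict

def totals_by_projection(scores):
    """Sum scores by projection type, i.e. the last dotted component of each
    parameter name ('model.layers.0.self_attn.v_proj' -> 'v_proj')."""
    totals = defaultdict(float)
    for name, score in scores.items():
        proj = name.removesuffix(".weight").rsplit(".", 1)[-1]
        totals[proj] += score
    return dict(totals)

# Layer-0 values quoted above from fisher_gsm8k.json.
layer0 = {
    "model.layers.0.self_attn.v_proj": 21.97,
    "model.layers.0.self_attn.q_proj": 0.04,
}
totals = totals_by_projection(layer0)
print(totals)
print(f"v_proj / q_proj = {totals['v_proj'] / totals['q_proj']:.0f}x")
```

Applied to a full score file, the same function gives one total per projection type (`q_proj`, `k_proj`, `v_proj`, `o_proj`, and the MLP projections), summed over every layer.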

---



In [1]:
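import os, sys
# Filter TensorFlow C++ INFO, WARNING and ERROR logs (only FATAL is shown).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Route Python-level stderr to /dev/null; the original stream is kept in `stderr`.
stderr = sys.stderr
sys.stderr = open(os.devnull, 'w')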
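The cell above replaces `sys.stderr` for the rest of the session, which also hides later Python-level messages such as warnings. A narrower alternative uses the standard-library `contextlib.redirect_stderr` to silence stderr only inside a `with` block. To undo the global redirect instead, assign the saved stream back with `sys.stderr = stderr`.

```python
import contextlib
import os
import sys

# Silence Python-level stderr only for the duration of the `with` block.
# Note: this does not capture output written directly to file descriptor 2
# by native code (e.g. TensorFlow's C++ logger); TF_CPP_MIN_LOG_LEVEL covers that.
with open(os.devnull, "w") as devnull, contextlib.redirect_stderr(devnull):
    print("hidden warning", file=sys.stderr)

# After the block, sys.stderr is the original stream again.
print("stderr restored:", sys.stderr is not devnull)
```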

In [None]:
!pip install torch transformers accelerate datasets bitsandbytes scipy tqdm numpy matplotlib sentencepiece protobuf

In [None]:
!git clone https://github.com/TarunNagarajan/TaskQuant.git

In [None]:
!pip install --upgrade datasets

In [5]:
import warnings
import logging

warnings.filterwarnings('ignore')
logging.getLogger('datasets.load').setLevel(logging.CRITICAL)

In [7]:
import logging
logging.getLogger("tensorflow").setLevel(logging.ERROR)

In [None]:
%cd TaskQuant
!git pull

## Fisher Computation: GSM8K

In [None]:
!python scripts/01_compute_sensitivity.py --method fisher --dataset gsm8k

## Magnitude Computation: GSM8K

In [None]:
!python scripts/01_compute_sensitivity.py --method magnitude --dataset gsm8k

## Fisher Computation: WikiText-2

In [None]:
!python scripts/01_compute_sensitivity.py --method fisher --dataset wikitext

## Magnitude Computation: WikiText-2

In [None]:
!python scripts/01_compute_sensitivity.py --method magnitude --dataset wikitext
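The four cells above differ only in `--method` and `--dataset`. If you prefer a single cell, a sketch like the following builds the same four invocations with `itertools.product` and runs them through `subprocess.run`. It uses only the two flags shown above, assumes the working directory is `TaskQuant`, and defaults to a dry run. `sensitivity_commands` and `run_all` are hypothetical helpers, not part of the repository.

```python
import itertools
import subprocess
import sys

DATASETS = ("gsm8k", "wikitext")
METHODS = ("fisher", "magnitude")

def sensitivity_commands(script="scripts/01_compute_sensitivity.py"):
    """One argument list per (dataset, method) pair, in the same order as the cells above."""
    return [
        [sys.executable, script, "--method", method, "--dataset", dataset]
        for dataset, method in itertools.product(DATASETS, METHODS)
    ]

def run_all(dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in sensitivity_commands():
        print(" ".join(cmd[1:]))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop on the first failing run

run_all()  # dry run: prints the four invocations
```

Call `run_all(dry_run=False)` to execute the runs. `sys.executable` runs the kernel's own interpreter, which may differ from the `python` found on the shell `PATH` that `!python` uses.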