# Structural + CSR analysis

Objectives of this notebook:

- **Per layer:**
  - `nonzero`, `total`, `sparsity`
  - dimensions (shape of the main weight tensor)
  - estimated FLOPs per token for linear / CSR layers
- **Per group:**
  - `embedding`
  - `attention_linear` (Q/K/V/out)
  - `mlp_linear` (fully-connected layers in the MLP)
  - `lm_head`
  - `other_*`

We compare three setups:

- **Dense**
- **Masked30** (global 30% pruning with dense execution)
- **CSR30** (same pruning, CSR conversion for all prunable linears)


# What the metrics mean in this notebook (clear explanation)

This notebook performs a **structural analysis** of a Transformer model for three variants:

* **Dense** — original model
* **Masked30** — 30% global magnitude pruning (weights set to 0, tensors stay dense)
* **CSR30** — same pruning, but selected Linear layers are replaced by `LinearCSRForward` (true sparse CSR matrices)

The goal is to understand:
-> **Where parameters live**
-> **Where sparsity appears**
-> **How much theoretical compute is removed**
-> **How the structure differs across Dense / Masked / CSR**

---

## Per-layer metrics (`df_layers_*`)

Each **row** in `df_layers` corresponds to **one module inside the model**.
Example rows:

* `model.decoder.layers.0.self_attn.q_proj`
* `model.decoder.layers.3.mlp.fc1`
* `model.decoder.embed_tokens`
* `lm_head`
* etc.

For each module we record:

### **Basic identifiers**

* **`model`** — which variant this row belongs to (`Dense`, `Masked30`, `CSR30`)
* **`module_name`** — full dotted name inside the model
* **`group`** — high-level category:

  * `embedding`
  * `attention_linear` (q/k/v/out projections)
  * `mlp_linear` (feed-forward layers)
  * `lm_head`
  * `norm`
  * `other_linear` / `other`

### **Parameter statistics**

* **`nonzero`** — number of parameters ≠ 0
* **`total`** — total parameters
* **`sparsity`** — fraction of zeros:

$\text{sparsity} = 1 - \frac{\text{nonzero}}{\text{total}}$

* **`shape`** — weight matrix shape (e.g. `(out, in)` for Linear)

### **Compute estimate**

* **Dense Linear**:
  FLOPs per token ≈ `2 * in_features * out_features`
* **CSR Linear**:
  FLOPs per token ≈ `2 * nnz`
  (only non-zeros count)

### **Parameter share**

* **`param_frac`** — proportion of model parameters in that module:

$\text{param_frac} = \frac{\text{total parameters in layer}}{\text{total parameters in model}}$

-> This lets you zoom in on **specific layers** and see how pruning alters their structure.

---

## Group-level metrics (`df_groups_*`)

We aggregate the layer-level stats by `(model, group)`.

For each group (e.g., all attention projections):

* **`nonzero`** — sum of non-zeros across modules
* **`total`** — total parameters
* **`sparsity`**:

$\text{sparsity}*{\text{group}} = 1 - \frac{\text{nonzero}*{\text{group}}}{\text{total}_{\text{group}}}$

* **`flops_per_token`** — sum of FLOPs for all modules in this group
* **`param_frac`** — fraction of the model’s parameters inside this group

-> This highlights which *parts of the network* dominate size and compute.

---

## How to interpret the plots

The notebook generates three comparisons across **Dense / Masked30 / CSR30**:

### **1. Group sparsity**

Shows **where zeroing actually happens**.

* Embeddings / lm_head remain dense (excluded from pruning)
* Attention and MLP linear layers become sparse

### **2. Parameter share**

Tells you **where parameters naturally live**, regardless of sparsity.

* MLP layers → large fraction
* Attention projections → significant
* Embeddings → often huge chunk in small LLMs

### **3. FLOPs per token**

A **theoretical compute estimate**:

* Dense compute ∝ number of weight multiplications
* CSR compute ∝ number of non-zeros
* Masked compute ≈ Dense compute (dense kernels ignore zeros)

This illustrates:

* Masked pruning is **structural sparsity only** (no speedup)
* CSR pruning is **algorithmic sparsity** (real reduction in multiply-adds)

In [1]:
import os, sys, warnings, pandas as pd, torch, torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

sys.path.append('..'); sys.path.append('../src')

from src.eval.metrics import params_size_and_sparsity
from src.eval.plotting import bar_plot
from src.pruning.policies import apply_global_magnitude_pruning_cpu_safe, select_prunable_linears
from src.pruning.pipeline import freeze_pruning_, convert_linear_weights_to_csr_
from src.wrappers.linear_csr import LinearCSRForward

warnings.filterwarnings('ignore', message='.*Sparse CSR tensor support is in beta state.*')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', device)

RESULTS_DIR = os.path.join('..', 'results')
STRUCT_DIR = os.path.join(RESULTS_DIR, 'structural_layers')
os.makedirs(STRUCT_DIR, exist_ok=True)


Device: cpu


## 1. Model loading

We load a small model depending on the device:

- On **GPU**: `EleutherAI/pythia-410m` in fp16
- On **CPU**: `facebook/opt-125m` in fp32

If you want to use a local snapshot (for example on Narval), simply replace `model_name` with the local path.


In [2]:
def load_fresh():
    """
    Load a small model depending on the device.

    - CUDA -> EleutherAI/pythia-410m (fp16)
    - CPU  -> facebook/opt-125m     (fp32)
    """
    if device == "cuda":
        model_name = "EleutherAI/pythia-410m"
        torch_dtype = torch.float16
    else:
        model_name = "facebook/opt-125m"
        torch_dtype = None  # fp32

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    kwargs = {}
    if torch_dtype is not None:
        kwargs["torch_dtype"] = torch_dtype

    mdl = AutoModelForCausalLM.from_pretrained(
        model_name,
        **kwargs
    ).to(device).eval()
    print(f"Loaded: {model_name}")
    return mdl, tok, model_name


## 2. Per-layer analysis helpers

We define the following utilities:

- `tensor_stats(t)` → `(nonzero, total)`
- `linear_flops(weight)` → approximate FLOPs per token for a linear layer (≈ `2 * in * out`)
- `classify_module(name, module)` → assign each module to a group (`embedding`, `attention_linear`, `mlp_linear`, `lm_head`, etc.)
- `analyze_layers(model, label)` → returns:
  - a *per-layer* DataFrame, and
  - a *per-group* aggregation.


In [3]:
def tensor_stats(t: torch.Tensor):
    if t.numel() == 0:
        return 0, 0
    nnz = int((t != 0).sum().item())
    total = t.numel()
    return nnz, total

def linear_flops(weight: torch.Tensor):
    """Approximate FLOPs per token for a linear layer.
    We count 2 * in * out (mul + add)."""
    if weight is None or weight.dim() != 2:
        return 0
    out_features, in_features = weight.shape
    return int(2 * in_features * out_features)

def classify_module(name: str, module: nn.Module) -> str:
    lname = name.lower()

    if isinstance(module, nn.Embedding) or 'embed' in lname:
        return 'embedding'

    if 'lm_head' in lname:
        return 'lm_head'

    if isinstance(module, (nn.Linear, LinearCSRForward)):
        # Attention
        if 'attn' in lname or 'attention' in lname or 'self_attn' in lname:
            return 'attention_linear'
        # MLP
        if 'mlp' in lname or 'ff' in lname or 'fc1' in lname or 'fc2' in lname:
            return 'mlp_linear'
        return 'other_linear'

    if 'norm' in lname:
        return 'norm'

    return 'other'

def analyze_layers(model: nn.Module, label: str):
    rows = []

    for name, module in model.named_modules():
        if name == '':
            continue

        nonzero = 0
        total = 0
        flops = 0
        shape = None

        if isinstance(module, LinearCSRForward):
            # Read CSR meta-data
            nonzero = module.meta_nnz
            total = module.meta_total_params
            sparsity = module.meta_sparsity
            shape = tuple(module.meta_dense_shape.tolist())
            # FLOPs ≈ 2 * nnz (mul + add)
            flops = 2 * nonzero
        else:
            # Standard path (nn.Linear, Embedding, etc.)
            for p_name, p in module.named_parameters(recurse=False):
                if not isinstance(p, torch.Tensor):
                    continue
                nnz, tot = tensor_stats(p)
                nonzero += nnz
                total += tot
                if p_name in ('weight', 'weight_orig'):
                    flops += linear_flops(p)
                    if p.dim() == 2:
                        shape = tuple(p.shape)

            if total == 0:
                continue
            sparsity = 1.0 - nonzero / total

        group = classify_module(name, module)
        rows.append({
            'model': label,
            'module_name': name,
            'group': group,
            'nonzero': nonzero,
            'total': total,
            'sparsity': sparsity,
            'shape': str(shape) if shape is not None else '',
            'flops_per_token': flops
        })

    df_layers = pd.DataFrame(rows)
    total_params_model = df_layers['total'].sum()
    df_layers['param_frac'] = df_layers['total'] / total_params_model

    df_groups = (
        df_layers
        .groupby(['model', 'group'], as_index=False)
        .agg({'nonzero': 'sum', 'total': 'sum', 'flops_per_token': 'sum'})
    )
    df_groups['sparsity'] = 1.0 - df_groups['nonzero'] / df_groups['total']
    df_groups['param_frac'] = df_groups['total'] / total_params_model

    return df_layers, df_groups


## 3. Dense — baseline structure

We first analyse the **unpruned dense model**.

For this variant we compute for each layer and group:

- number of non-zero parameters,
- total parameter count,
- sparsity,
- parameter fraction within the model,
- estimated FLOPs per token.


In [4]:
model_dense, tok, model_name = load_fresh()
dense_stats = params_size_and_sparsity(model_dense)
print('Dense global stats:', dense_stats)

df_layers_dense, df_groups_dense = analyze_layers(model_dense, 'Dense')
display(df_groups_dense.sort_values('total', ascending=False))

dense_layers_csv = os.path.join(STRUCT_DIR, 'layers_dense.csv')
dense_groups_csv = os.path.join(STRUCT_DIR, 'groups_dense.csv')
df_layers_dense.to_csv(dense_layers_csv, index=False)
df_groups_dense.to_csv(dense_groups_csv, index=False)
print('Saved:', dense_layers_csv)
print('Saved:', dense_groups_csv)

# Overview plots
bar_plot(df_groups_dense, 'group', 'sparsity', 'Dense: sparsity per group', 'dense_sparsity_per_group.png', STRUCT_DIR, y_min=0.0)
bar_plot(df_groups_dense, 'group', 'param_frac', 'Dense: parameter share per group', 'dense_param_frac_per_group.png', STRUCT_DIR, y_min=0.0)
bar_plot(df_groups_dense, 'group', 'flops_per_token', 'Dense: FLOPs per token per group', 'dense_flops_per_group.png', STRUCT_DIR, y_min=0.0)


Loaded: facebook/opt-125m
Dense global stats: {'nonzero': 125238422, 'total': 125239296, 'sparsity': 6.978640314292406e-06, 'size_mb': 477.75}


Unnamed: 0,model,group,nonzero,total,flops_per_token,sparsity,param_frac
0,Dense,attention_linear,21673527,21673536,56623560423104,4.152638e-07,0.173004
1,Dense,embedding,30671751,30671808,80366959266592,1.862567e-06,0.244873
2,Dense,lm_head,29598169,29598208,77217792721792,1.315463e-06,0.23649
3,Dense,mlp_linear,43384793,43384832,113246208462208,9.028209e-07,0.346477
4,Dense,norm,0,0,0,,


Saved: ..\results\structural_layers\layers_dense.csv
Saved: ..\results\structural_layers\groups_dense.csv
Saved: ..\results\structural_layers\dense_sparsity_per_group.png
Saved: ..\results\structural_layers\dense_param_frac_per_group.png
Saved: ..\results\structural_layers\dense_flops_per_group.png


## 4. Masked30 — global magnitude pruning (30%) on prunable linears

We now apply global magnitude pruning (30%) on the prunable linear layers:

- We prune only the `nn.Linear` modules returned by `select_prunable_linears`.
- The `lm_head` is blacklisted.
- Embeddings and LayerNorms are never pruned.

We then inspect the effect on:

- sparsity per layer / per group,
- theoretical FLOPs per token per group.


In [5]:
SP = 0.30

model_masked, tok_m, _ = load_fresh()
layers_prunable = select_prunable_linears(model_masked, blacklist=("lm_head",))
print('Prunable linear layers:', len(layers_prunable))

apply_global_magnitude_pruning_cpu_safe(layers_prunable, amount=SP)
freeze_pruning_(layers_prunable)

masked_stats = params_size_and_sparsity(model_masked)
print('Masked global stats:', masked_stats)

df_layers_masked, df_groups_masked = analyze_layers(model_masked, f'Masked{int(SP*100)}')
display(df_groups_masked.sort_values('total', ascending=False))

masked_layers_csv = os.path.join(STRUCT_DIR, 'layers_masked30.csv')
masked_groups_csv = os.path.join(STRUCT_DIR, 'groups_masked30.csv')
df_layers_masked.to_csv(masked_layers_csv, index=False)
df_groups_masked.to_csv(masked_groups_csv, index=False)
print('Saved:', masked_layers_csv)
print('Saved:', masked_groups_csv)

bar_plot(df_groups_masked, 'group', 'sparsity', 'Masked30: sparsity per group', 'masked30_sparsity_per_group.png', STRUCT_DIR, y_min=0.0)
bar_plot(df_groups_masked, 'group', 'param_frac', 'Masked30: parameter share per group', 'masked30_param_frac_per_group.png', STRUCT_DIR, y_min=0.0)
bar_plot(df_groups_masked, 'group', 'flops_per_token', 'Masked30: FLOPs per token per group', 'masked30_flops_per_group.png', STRUCT_DIR, y_min=0.0)


SP = 0.3...


## 5. CSR30 — pruning + CSR conversion (GPU-ready)

The pipeline for the CSR variant is:

1. Load a fresh model.
2. Apply the same global 30% pruning on prunable linear layers.
3. Call `freeze_pruning_` to materialise the masks.
4. Run `convert_linear_weights_to_csr_`.
5. Replace all pruned linear layers by `LinearCSRForward` modules.

We then repeat the layer- and group-level analysis.

> On CPU this is slower, but on GPU this corresponds to the “CSR, GPU-ready” configuration that we will benchmark on Narval.


In [6]:
model_csr, tok_c, _ = load_fresh()
layers_prunable_csr = select_prunable_linears(model_csr, blacklist=("lm_head",))
print('Prunable linear layers (CSR):', len(layers_prunable_csr))

apply_global_magnitude_pruning_cpu_safe(layers_prunable_csr, amount=SP)
freeze_pruning_(layers_prunable_csr)
convert_linear_weights_to_csr_(layers_prunable_csr)

# Replace all pruned layers with LinearCSRForward
def find_parent(root, child):
    for _, mod in root.named_modules():
        for cn, cc in mod.named_children():
            if cc is child:
                return mod, cn
    raise RuntimeError('Parent not found')

for lin in layers_prunable_csr:
    parent, attr = find_parent(model_csr, lin)
    csr_module = LinearCSRForward(
        lin.weight.detach(),
        lin.bias.detach() if lin.bias is not None else None
    ).to(device)
    setattr(parent, attr, csr_module)

csr_stats = params_size_and_sparsity(model_csr)
print('CSR global stats:', csr_stats)

df_layers_csr, df_groups_csr = analyze_layers(model_csr, f'CSR{int(SP*100)}')
display(df_groups_csr.sort_values('total', ascending=False))

csr_layers_csv = os.path.join(STRUCT_DIR, 'layers_csr30.csv')
csr_groups_csv = os.path.join(STRUCT_DIR, 'groups_csr30.csv')
df_layers_csr.to_csv(csr_layers_csv, index=False)
df_groups_csr.to_csv(csr_groups_csv, index=False)
print('Saved:', csr_layers_csv)
print('Saved:', csr_groups_csv)

bar_plot(df_groups_csr, 'group', 'sparsity', 'CSR30: sparsity per group', 'csr30_sparsity_per_group.png', STRUCT_DIR, y_min=0.0)
bar_plot(df_groups_csr, 'group', 'param_frac', 'CSR30: parameter share per group', 'csr30_param_frac_per_group.png', STRUCT_DIR, y_min=0.0)
bar_plot(df_groups_csr, 'group', 'flops_per_token', 'CSR30: FLOPs per token per group', 'csr30_flops_per_group.png', STRUCT_DIR, y_min=0.0)


Loaded: facebook/opt-125m
Prunable linear layers (CSR): ...


## 6. Global comparison: Dense vs Masked30 vs CSR30

We merge the per-group DataFrames to build an overview with:

- sparsity per group and per variant,
- parameter share per group and per variant,
- total FLOPs per token per group and per variant.

This makes it easy to see where parameters and compute are concentrated, and how pruning + CSR change the picture.


In [7]:
dfg_d = df_groups_dense.copy();   dfg_d['variant'] = 'Dense'
dfg_m = df_groups_masked.copy();  dfg_m['variant'] = 'Masked30'
dfg_c = df_groups_csr.copy();     dfg_c['variant'] = 'CSR30'

dfg_all = pd.concat([dfg_d, dfg_m, dfg_c], ignore_index=True)
display(dfg_all.sort_values(['group', 'variant']))

pivot_sparsity = dfg_all.pivot(index='group', columns='variant', values='sparsity').fillna(0.0)
pivot_frac = dfg_all.pivot(index='group', columns='variant', values='param_frac').fillna(0.0)
pivot_flops = dfg_all.pivot(index='group', columns='variant', values='flops_per_token').fillna(0.0)

print('\nSparsity per group:')
display(pivot_sparsity)
print('\nParameter share per group:')
display(pivot_frac)
print('\nFLOPs per token per group:')
display(pivot_flops)


## 7. Using these results in the report

With this notebook you can document, with quantitative evidence:

- **Parameter distribution** across groups: embeddings vs attention vs MLP vs `lm_head` (`param_frac`).
- **Where sparsity actually appears**: `sparsity` per group, highlighting that embeddings and the output head typically remain dense.
- **Impact on theoretical compute**: `flops_per_token` per group and per variant, which lets you argue about potential FLOPs savings if CSR kernels are efficient.

You can also zoom in *per layer* using the `df_layers_*` tables to show, for example, that:

- MLP layers often contain more parameters and benefit more from pruning than attention projections.
- Some early or late layers are more sensitive to pruning (which you can correlate with top-1 accuracy or perplexity experiments).
