# S1 — Minimal: Dense → Masked → CSR
Simple, readable baseline with robust perplexity and measured sparsity.

In [1]:
import os, sys, warnings, pandas as pd, torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
sys.path.append('..'); sys.path.append('../src')
from src.eval.metrics import params_size_and_sparsity, eval_ppl_causal
from src.eval.utils import measure_latency_ms
from src.eval.csvlog import append_row
from src.eval.plotting import bar_plot
from src.pruning.policies import apply_global_magnitude_pruning_cpu_safe, select_prunable_linears
from src.pruning.pipeline import freeze_pruning_, convert_linear_weights_to_csr_
from src.wrappers.linear_csr import LinearCSRForward
warnings.filterwarnings('ignore', message='.*Sparse CSR tensor support is in beta state.*')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', device)
RESULTS_DIR = os.path.join('..','results'); CSV_PATH = os.path.join(RESULTS_DIR,'S1_minimal.csv')
os.makedirs(RESULTS_DIR, exist_ok=True)
pd.DataFrame(columns=["setup","size_mb","sparsity","latency_ms","perplexity"]).to_csv(CSV_PATH, index=False)

def load_fresh():
    """
    Load exactly one model depending on the device:
      - CUDA  -> EleutherAI/pythia-410m (fp16)
      - CPU   -> facebook/opt-125m     (fp32)
    """
    if device == "cuda":
        model_name = "EleutherAI/pythia-410m"
        torch_dtype = torch.float16
    else:
        model_name = "facebook/opt-125m"
        torch_dtype = None  # use default (fp32)
    tok = AutoTokenizer.from_pretrained(model_name)
    tok.pad_token = tok.eos_token
    kwargs = {}
    if torch_dtype is not None:
        kwargs["torch_dtype"] = torch_dtype
    mdl = AutoModelForCausalLM.from_pretrained(model_name, **kwargs).to(device).eval()
    print(f"Loaded: {model_name}")
    return mdl, tok, model_name
    
def latency_fn(model, tokenizer):
    def f(L=128, B=1):
        inp = torch.randint(0, tokenizer.vocab_size, (B, L), device=device)
        att = torch.ones(B, L, device=device, dtype=torch.long)
        return model(input_ids=inp, attention_mask=att, labels=inp).logits
    return f
ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
SAMPLE_TEXTS = [t for t in ds["text"] if isinstance(t, str) and t.strip()]


Device: cpu


README.md: 0.00B [00:00, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


train-00000-of-00001.parquet:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

## 1) Dense baseline

In [2]:
model, tok, name = load_fresh()
stats = params_size_and_sparsity(model)
ppl   = eval_ppl_causal(model, tok, SAMPLE_TEXTS, device)
lat   = measure_latency_ms(latency_fn(model, tok), 128, 1, warmup=3, iters=10)
append_row(CSV_PATH, setup='Dense', size_mb=stats['size_mb'], sparsity=stats['sparsity'], latency_ms=lat, perplexity=ppl)
stats, ppl, lat

Loaded: facebook/opt-125m


RuntimeError: [enforce fail at alloc_cpu.cpp:121] data. DefaultCPUAllocator: not enough memory: you tried to allocate 7741636608 bytes.

## 2) Masked pruning (30%) — dense execution

In [None]:
SP_MASK = 0.30
model, tok, name = load_fresh()
layers = select_prunable_linears(model, blacklist=("lm_head",))
apply_global_magnitude_pruning_cpu_safe(layers, amount=SP_MASK)
stats = params_size_and_sparsity(model)
ppl   = eval_ppl_causal(model, tok, SAMPLE_TEXTS, device)
lat   = measure_latency_ms(latency_fn(model, tok), 128, 1, warmup=3, iters=10)
append_row(CSV_PATH, setup=f'Masked{int(SP_MASK*100)}', size_mb=stats['size_mb'], sparsity=stats['sparsity'], latency_ms=lat, perplexity=ppl)
stats, ppl, lat

## 3) CSR execution (50%) — real sparse kernels

In [None]:
SP_CSR = 0.50
model, tok, name = load_fresh()
layers = select_prunable_linears(model, blacklist=("lm_head",))
apply_global_magnitude_pruning_cpu_safe(layers, amount=SP_CSR)
freeze_pruning_(layers); convert_linear_weights_to_csr_(layers)
swapped = 0
def find_parent(root, child):
    for _, mod in root.named_modules():
        for cn, cc in mod.named_children():
            if cc is child: return mod, cn
    raise RuntimeError('Parent not found')
for lin in layers:
    if swapped >= 4: break
    parent, attr = find_parent(model, lin)
    setattr(parent, attr, LinearCSRForward(lin.weight.detach(), lin.bias.detach() if lin.bias is not None else None).to(device))
    swapped += 1
stats = params_size_and_sparsity(model)
ppl   = eval_ppl_causal(model, tok, SAMPLE_TEXTS, device)
lat   = measure_latency_ms(latency_fn(model, tok), 128, 1, warmup=3, iters=10)
append_row(CSV_PATH, setup=f'CSR{int(SP_CSR*100)}', size_mb=stats['size_mb'], sparsity=stats['sparsity'], latency_ms=lat, perplexity=ppl)
stats, ppl, lat

## 4) Plots

In [None]:
df = pd.read_csv(CSV_PATH); display(df)
bar_plot(df, 'setup', 'size_mb', 'Model size (MB)', 'size_vs_sparsity.png', RESULTS_DIR, y_min=450)
bar_plot(df, 'setup', 'latency_ms', 'Latency (ms / forward)', 'latency_vs_sparsity.png', RESULTS_DIR)
bar_plot(df, 'setup', 'perplexity', 'Perplexity', 'ppl_vs_sparsity.png', RESULTS_DIR)
