# Sensitivity Analysis

> Per-layer sensitivity analysis for compression methods (sparsity, pruning, quantization)

In [None]:
#| default_exp analysis.sensitivity

In [None]:
#| export
from __future__ import annotations
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy
from dataclasses import dataclass, field, asdict
from typing import Callable, Any, Literal
from collections import OrderedDict

# fasterai imports (absolute imports of sibling modules in the fasterai package)
from fasterai.sparse.all import Sparsifier
from fasterai.prune.all import Pruner
from fasterai.core.all import large_final, Criteria, Granularities

In [None]:
#| hide
from nbdev.showdoc import *

## Data Classes

In [None]:
#| export
@dataclass(slots=True)
class LayerSensitivity:
    """Sensitivity result for a single layer."""
    name: str                    # layer name
    layer_type: str              # e.g., "Conv2d", "Linear"
    params: int                  # number of parameters
    baseline_metric: float       # metric before compression
    compressed_metric: float     # metric after compression
    delta: float                 # metric change (positive = degradation)
    
    def as_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return asdict(self)


@dataclass(slots=True)
class SensitivityResult:
    """Structured result from sensitivity analysis."""
    compression_type: str                   # "sparsity", "pruning", "quantization"
    compression_level: float                # e.g., 50 for 50% sparsity
    baseline_metric: float                  # overall baseline metric
    layers: list[LayerSensitivity]          # per-layer results
    metric_name: str = "accuracy"           # name of the metric
    higher_is_better: bool = True           # whether higher metric is better
    _results: list[LayerSensitivity] = field(default=None, init=False, repr=False)  # for top() compatibility
    
    def __post_init__(self):
        self._results = self.layers  # for compatibility with top() pattern
    
    def as_dict(self) -> dict[str, Any]:
        """Convert to flat dictionary."""
        return {
            "compression_type": self.compression_type,
            "compression_level": self.compression_level,
            "baseline_metric": self.baseline_metric,
            "metric_name": self.metric_name,
            "higher_is_better": self.higher_is_better,
            "layers": [l.as_dict() for l in self.layers],
        }
    
    def top(
        self,
        n: int = 5,                    # number of layers to return
        *,
        most_sensitive: bool = True,   # True=highest delta (fragile), False=lowest (robust)
    ) -> list[LayerSensitivity]:
        """Return top N most or least sensitive layers."""
        sorted_layers = sorted(self.layers, key=lambda x: x.delta, reverse=most_sensitive)
        return sorted_layers[:n]
    
    def summary(
        self,
        *,
        top: int = 5,  # number of layers to show per category
    ) -> None:
        """Print a formatted summary of sensitivity analysis."""
        print(f"{'═' * 60}")
        print(f"Sensitivity Analysis: {self.compression_type} @ {self.compression_level}%")
        print(f"{'═' * 60}")
        print(f"  Baseline {self.metric_name}: {self.baseline_metric:.4f}")
        print(f"  Layers analyzed: {len(self.layers)}")
        print()
        
        # Most sensitive (fragile) layers
        print(f"  🔴 Most Sensitive (fragile):")
        for i, layer in enumerate(self.top(top, most_sensitive=True), 1):
            sign = "+" if layer.delta > 0 else ""
            print(f"     {i}. {layer.name:30} Δ={sign}{layer.delta:.4f}")
        print()
        
        # Most robust layers
        print(f"  🟢 Most Robust (compressible):")
        for i, layer in enumerate(self.top(top, most_sensitive=False), 1):
            sign = "+" if layer.delta > 0 else ""
            print(f"     {i}. {layer.name:30} Δ={sign}{layer.delta:.4f}")
    
    def to_dataframe(self):
        """Convert to pandas DataFrame."""
        import pandas as pd
        rows = [layer.as_dict() for layer in self.layers]
        return pd.DataFrame(rows)
    
    def to_schedule(
        self,
        model: nn.Module,          # model (used for parameter counts)
        target_pct: float = 50,    # target mean compression percentage
        min_pct: float = 0,        # minimum compression for any layer
        max_pct: float = 90,       # maximum compression for any layer
        gamma: float = 1.0,        # exponent for sensitivity scaling (higher = more differentiation)
    ) -> dict[str, float]:
        """Convert sensitivity to non-uniform compression schedule.
        
        High sensitivity layers get lower compression, robust layers get higher.
        Uses parameter-weighted optimization to hit target_pct exactly.
        """
        if not self.layers:
            return {}
        
        # Convert to fractions
        target = target_pct / 100.0
        smin = min_pct / 100.0
        smax = max_pct / 100.0
        
        # Get sensitivity scores
        names = [l.name for l in self.layers]
        deltas = np.array([max(0.0, l.delta) for l in self.layers], dtype=float)
        weights = np.array([float(l.params) for l in self.layers], dtype=float)
        
        if weights.sum() == 0:
            return {n: target_pct for n in names}
        
        # Normalize sensitivity and invert (high sensitivity -> low compression)
        if np.allclose(deltas, deltas[0]):
            s0 = np.full_like(deltas, target)
        else:
            norm = (deltas - deltas.min()) / (np.ptp(deltas) + 1e-12)
            inv = (1.0 - norm) ** gamma
            s0 = smin + (smax - smin) * inv
        
        # Binary search for lambda to hit target weighted mean
        W = weights.sum()
        tgt = target * W
        
        def f(lam):
            s = np.clip(s0 + lam, smin, smax)
            return float(np.dot(weights, s))
        
        # Find lambda via bisection
        lam_lo, lam_hi = -1.0, 1.0
        while f(lam_lo) > tgt:
            lam_lo *= 2
        while f(lam_hi) < tgt:
            lam_hi *= 2
        
        for _ in range(60):
            lam_mid = 0.5 * (lam_lo + lam_hi)
            if f(lam_mid) < tgt:
                lam_lo = lam_mid
            else:
                lam_hi = lam_mid
        
        final_s = np.clip(s0 + 0.5 * (lam_lo + lam_hi), smin, smax)
        
        return {name: round(s * 100, 2) for name, s in zip(names, final_s)}
    
    def plot(
        self,
        figsize: tuple = (12, 5),  # figure size (width, height)
    ) -> None:
        """Plot sensitivity as a bar chart."""
        import matplotlib.pyplot as plt
        
        names = [l.name for l in self.layers]
        deltas = np.array([l.delta for l in self.layers], dtype=float)
        
        # Color by sensitivity
        norm = (deltas - deltas.min()) / (np.ptp(deltas) + 1e-9)
        colors = plt.cm.RdYlGn_r(norm)  # Red=sensitive, Green=robust
        
        plt.figure(figsize=figsize)
        plt.bar(range(len(deltas)), deltas, color=colors)
        plt.axhline(0, color='gray', linewidth=1.2, linestyle='--')
        plt.xticks(range(len(names)), names, rotation=60, ha='right')
        plt.ylabel(f"{self.metric_name} drop (Δ)")
        plt.title(f"Layer Sensitivity to {self.compression_type} @ {self.compression_level}%", 
                  pad=12, weight='bold')
        plt.grid(axis='y', linestyle=':', alpha=0.6)
        plt.tight_layout()
        plt.show()

## SensitivityAnalyzer

In [None]:
#| export
class SensitivityAnalyzer:
    """Analyze per-layer sensitivity to compression methods.
    
    Uses fasterai's Sparsifier for sparsity analysis and Pruner for structural pruning.
    Supports sparsity (weight zeroing), pruning (structural), and quantization.
    """
    
    VALID_COMPRESSIONS = frozenset({"sparsity", "pruning", "quantization"})
    COMPRESSIBLE_LAYERS = Granularities.available_modules()  # Use fasterai's layer registry
    
    def __init__(
        self,
        model: nn.Module,                              # model to analyze
        sample: torch.Tensor,                          # example input (for Pruner dependency analysis)
        eval_fn: Callable[[nn.Module], float],         # evaluation function returning metric
        *,
        criteria: Criteria = large_final,              # fasterai criteria for importance scoring
        higher_is_better: bool = True,                 # whether higher metric values are better
        metric_name: str = "accuracy",                 # name of the metric for display
        device: str | torch.device | None = None,      # device for computation
        calibration_data: torch.Tensor | None = None,  # for observer-based quantization
    ):
        """Store the model, evaluation callback and analysis options.

        When ``device`` is not given it is taken from the model's first
        parameter, falling back to CPU for a model without parameters.
        """
        self.model = model
        self.sample = sample
        self.eval_fn = eval_fn
        self.criteria = criteria
        self.higher_is_better = higher_is_better
        self.metric_name = metric_name

        if device is None:
            # next(..., None) avoids StopIteration on parameterless models.
            first_param = next(model.parameters(), None)
            device = first_param.device if first_param is not None else torch.device("cpu")
        self.device = device

        self.calibration_data = calibration_data
        self._results: SensitivityResult | None = None       # result of the last analyze() call
        self._sparsifier: Sparsifier | None = None           # created lazily by _init_sparsifier()
        self._activation_hooks: list[Any] = []               # handles of registered forward hooks
        self._activation_quantize_config: dict[str, bool] = {}  # layer name -> hook enabled
    
    def _get_compressible_layers(self) -> list[tuple[str, nn.Module]]:
        """Collect ``(name, module)`` pairs eligible for compression.

        A module qualifies when it is an instance of ``COMPRESSIBLE_LAYERS`` and
        has a non-``None`` ``weight`` attribute. Order follows ``named_modules()``.
        """
        selected = []
        for layer_name, layer in self.model.named_modules():
            if not isinstance(layer, self.COMPRESSIBLE_LAYERS):
                continue
            if getattr(layer, 'weight', None) is None:
                continue
            selected.append((layer_name, layer))
        return selected
    
    def _init_sparsifier(
        self,
        granularity: str = "weight",  # sparsity granularity
    ) -> None:
        """Create the fasterai Sparsifier once and cache it on the instance.

        Does nothing when a sparsifier already exists, even if ``granularity``
        differs from the one it was built with.
        """
        if self._sparsifier is not None:
            return
        # NOTE(review): Sparsifier presumably snapshots the initial weights that
        # _restore_layer() reads back — confirm against fasterai.
        sparsifier_options = dict(
            granularity=granularity,
            context='local',
            criteria=self.criteria,
        )
        self._sparsifier = Sparsifier(self.model, **sparsifier_options)
    
    def _cleanup_sparsifier(self) -> None:
        """Drop the cached Sparsifier after calling its ``_clean_buffers()``.

        No-op when no sparsifier has been created.
        """
        sparsifier = self._sparsifier
        if sparsifier is None:
            return
        sparsifier._clean_buffers()
        self._sparsifier = None
    
    def _apply_sparsity(
        self,
        module: nn.Module,  # layer to sparsify
        level: float,       # sparsity percentage (0-100)
    ) -> None:
        """Sparsify one layer through the cached Sparsifier.

        Requires `_init_sparsifier()` to have been called beforehand.
        """
        sparsifier = self._sparsifier
        sparsifier.sparsify_layer(module, level)
    
    def _restore_layer(
        self,
        module: nn.Module,  # layer to restore
    ) -> None:
        """Copy saved initial weights and biases back into a layer and drop its mask.

        Reads ``_init_weights``, ``_init_biases`` and ``_mask`` from the module
        (presumably attached by the fasterai Sparsifier — confirm). Each step is
        skipped when the corresponding attribute is absent.
        """
        has_saved_weights = hasattr(module, '_init_weights')
        if has_saved_weights:
            module.weight.data.copy_(module._init_weights)

        saved_biases = getattr(module, '_init_biases', None)
        if saved_biases is not None:
            module.bias.data.copy_(saved_biases)

        has_mask = hasattr(module, '_mask')
        if has_mask:
            # The mask is removed from the buffer registry directly.
            del module._buffers['_mask']
    
    def _apply_structural_pruning(
        self,
        target_name: str,  # name of layer to prune
        level: float,      # pruning ratio (0-100)
    ) -> nn.Module:
        """Prune one layer on a deep copy of the model with fasterai's Pruner.

        Every other compressible layer in the copy is passed as ``ignored_layers``.
        ``self.model`` is never modified. If the Pruner raises, a warning is
        emitted and the copy is returned as it stands at that point.
        """
        model_copy = deepcopy(self.model)

        compressible = [
            (name, module)
            for name, module in model_copy.named_modules()
            if isinstance(module, self.COMPRESSIBLE_LAYERS)
        ]

        target_module = None
        for name, module in compressible:
            if name == target_name:
                target_module = module

        # With no match, target_module stays None and every layer is ignored.
        ignored_layers = [module for _, module in compressible if module is not target_module]

        try:
            pruner = Pruner(
                model_copy,
                pruning_ratio=level,
                context='local',
                criteria=self.criteria,
                ignored_layers=ignored_layers,
                example_inputs=self.sample,
            )
            pruner.prune_model()
        except Exception as e:
            # Best effort: report the failure and fall through to return the copy.
            import warnings
            warnings.warn(f"Structural pruning failed for {target_name}: {e}")

        return model_copy
    
    # ─── Quantization helpers ────────────────────────────────────────────────────
    
    def _compute_qparams_symmetric(
        self,
        tensor: torch.Tensor,       # tensor to compute qparams for
        bits: int = 8,              # quantization bits
        per_channel: bool = False,  # per-channel or per-tensor
        channel_axis: int = 0,      # axis for per-channel quantization (negative values count from the end)
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute scale and zero_point for symmetric quantization.

        The scale is ``max(|x|) / (2**(bits-1) - 1)``, floored at ``1e-8``; the
        zero point is always 0 (int32). Per-channel mode is used only when
        ``per_channel`` is true and the tensor has more than one dimension; it
        returns one entry per index along ``channel_axis``. Otherwise both
        results have shape ``(1,)``.

        Raises ``ValueError`` when ``channel_axis`` is out of range in
        per-channel mode.
        """
        qmax = 2 ** (bits - 1) - 1

        if per_channel and tensor.dim() > 1:
            ndim = tensor.dim()
            if not -ndim <= channel_axis < ndim:
                raise ValueError(
                    f"channel_axis {channel_axis} is out of range for a {ndim}-D tensor"
                )
            axis = channel_axis % ndim  # normalize negative axes
            reduce_dims = tuple(d for d in range(ndim) if d != axis)
            # Largest magnitude per channel, reducing every other dimension at once.
            amax = tensor.abs().amax(dim=reduce_dims)
            scale = torch.clamp(amax / qmax, min=1e-8)
            zero_point = torch.zeros_like(scale, dtype=torch.int32)
        else:
            amax = tensor.abs().max()
            scale = torch.tensor([max(amax.item() / qmax, 1e-8)], device=tensor.device)
            zero_point = torch.tensor([0], dtype=torch.int32, device=tensor.device)

        return scale, zero_point
    
    def _compute_qparams_asymmetric(
        self,
        tensor: torch.Tensor,       # tensor to compute qparams for
        bits: int = 8,              # quantization bits
        per_channel: bool = False,  # per-channel or per-tensor
        channel_axis: int = 0,      # axis for per-channel quantization (negative values count from the end)
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute scale and zero_point for asymmetric quantization.

        Maps the observed ``[min, max]`` range, extended to include zero, onto
        the unsigned grid ``[0, 2**bits - 1]``. The scale is floored at ``1e-8``
        and the zero point is an int32 inside the grid. Per-channel mode is used
        only when ``per_channel`` is true and the tensor has more than one
        dimension; otherwise both results have shape ``(1,)``.

        Raises ``ValueError`` when ``channel_axis`` is out of range in
        per-channel mode.
        """
        qmin, qmax = 0, 2 ** bits - 1

        if per_channel and tensor.dim() > 1:
            ndim = tensor.dim()
            if not -ndim <= channel_axis < ndim:
                raise ValueError(
                    f"channel_axis {channel_axis} is out of range for a {ndim}-D tensor"
                )
            axis = channel_axis % ndim  # normalize negative axes
            reduce_dims = tuple(d for d in range(ndim) if d != axis)
            t_min = tensor.amin(dim=reduce_dims)
            t_max = tensor.amax(dim=reduce_dims)
        else:
            t_min, t_max = tensor.min(), tensor.max()

        # The representable range must contain 0. Without this, a tensor whose
        # values are all positive (or all negative) gets a zero point outside
        # the grid; clamping it then shifts the range away from the data and
        # the values are clipped.
        t_min = torch.clamp(t_min, max=0)
        t_max = torch.clamp(t_max, min=0)

        scale = torch.clamp((t_max - t_min) / (qmax - qmin), min=1e-8)
        zero_point = torch.clamp(torch.round(-t_min / scale), qmin, qmax).to(torch.int32)

        if not per_channel or tensor.dim() <= 1:
            # Per-tensor results are 0-dim here; return them with shape (1,).
            scale = scale.view(1) if scale.dim() == 0 else scale
            zero_point = zero_point.view(1) if zero_point.dim() == 0 else zero_point

        return scale, zero_point
    
    def _fake_quantize_per_channel(
        self,
        tensor: torch.Tensor,      # tensor to quantize
        bits: int = 8,             # quantization bits
        symmetric: bool = True,    # symmetric or asymmetric
        channel_axis: int = 0,     # axis for per-channel quantization
    ) -> torch.Tensor:
        """Fake-quantize ``tensor`` with one scale/zero-point per channel.

        Symmetric mode uses the signed grid ``[-2**(bits-1), 2**(bits-1) - 1]``;
        asymmetric mode uses the unsigned grid ``[0, 2**bits - 1]``.
        """
        if symmetric:
            qmin, qmax = -(2 ** (bits - 1)), (2 ** (bits - 1)) - 1
            compute_qparams = self._compute_qparams_symmetric
        else:
            qmin, qmax = 0, (2 ** bits) - 1
            compute_qparams = self._compute_qparams_asymmetric

        scale, zero_point = compute_qparams(
            tensor, bits, per_channel=True, channel_axis=channel_axis
        )

        return torch.fake_quantize_per_channel_affine(
            tensor, scale, zero_point, channel_axis, qmin, qmax
        )
    
    def _fake_quantize_per_tensor(
        self,
        tensor: torch.Tensor,    # tensor to quantize
        bits: int = 8,           # quantization bits
        symmetric: bool = True,  # symmetric or asymmetric
    ) -> torch.Tensor:
        """Fake-quantize ``tensor`` with a single scale/zero-point pair.

        Symmetric mode uses the signed grid ``[-2**(bits-1), 2**(bits-1) - 1]``;
        asymmetric mode uses the unsigned grid ``[0, 2**bits - 1]``.
        """
        if symmetric:
            qmin, qmax = -(2 ** (bits - 1)), (2 ** (bits - 1)) - 1
            compute_qparams = self._compute_qparams_symmetric
        else:
            qmin, qmax = 0, (2 ** bits) - 1
            compute_qparams = self._compute_qparams_asymmetric

        scale, zero_point = compute_qparams(tensor, bits, per_channel=False)

        # The per-tensor op is called with Python scalars, not 1-element tensors.
        scale_value = scale.item()
        zero_point_value = int(zero_point.item())
        return torch.fake_quantize_per_tensor_affine(
            tensor, scale_value, zero_point_value, qmin, qmax
        )
    
    def _apply_weight_quantization(
        self,
        module: nn.Module,       # layer to quantize
        bits: int = 8,           # quantization bits
        per_channel: bool = True,  # per-channel or per-tensor
    ) -> None:
        """Overwrite a layer's weights in place with their symmetric fake-quantized values.

        Per-channel quantization (along axis 0) is used only when requested and
        the weight has more than one dimension; otherwise per-tensor is used.
        The caller is responsible for saving and restoring the original weights.
        """
        weight = module.weight.data
        use_per_channel = per_channel and weight.dim() > 1
        if use_per_channel:
            quantized = self._fake_quantize_per_channel(weight, bits, symmetric=True, channel_axis=0)
        else:
            quantized = self._fake_quantize_per_tensor(weight, bits, symmetric=True)
        weight.copy_(quantized)
    
    def _create_activation_quantize_hook(
        self,
        layer_name: str,  # layer name for config lookup
        bits: int = 8,    # quantization bits
    ):
        """Build a forward hook that fake-quantizes a layer's output on demand.

        The hook looks up ``layer_name`` in ``_activation_quantize_config`` on
        every call. When the flag is set it returns the asymmetric per-tensor
        fake-quantized output; otherwise it returns the output unchanged.
        """
        def hook(module, inputs, output):
            enabled = self._activation_quantize_config.get(layer_name, False)
            if not enabled:
                return output
            return self._fake_quantize_per_tensor(output, bits, symmetric=False)
        return hook
    
    def _setup_activation_hooks(
        self,
        bits: int = 8,  # quantization bits
    ) -> None:
        """Attach a disabled activation-quantization hook to every compressible layer.

        Any hooks registered earlier are removed first. Each layer's flag in
        ``_activation_quantize_config`` starts as ``False``.
        """
        self._remove_activation_hooks()
        for layer_name, layer in self._get_compressible_layers():
            quantize_hook = self._create_activation_quantize_hook(layer_name, bits)
            self._activation_hooks.append(layer.register_forward_hook(quantize_hook))
            self._activation_quantize_config[layer_name] = False
    
    def _remove_activation_hooks(self) -> None:
        """Detach every registered activation hook and reset the per-layer flags."""
        for hook_handle in self._activation_hooks:
            hook_handle.remove()
        # Rebind to fresh containers rather than clearing in place.
        self._activation_hooks, self._activation_quantize_config = [], {}
    
    # ─── Main analysis method ────────────────────────────────────────────────────
    
    def analyze(
        self,
        compression: Literal["sparsity", "pruning", "quantization"] = "sparsity",  # compression type
        level: float = 50,                    # compression level (% for sparsity/pruning, bits for quant)
        *,
        granularity: str = "weight",          # granularity for sparsity (fasterai granularities)
        layers: list[str] | None = None,      # specific layer names to analyze (None = all)
        quant_per_channel: bool = True,       # use per-channel quantization
        quant_activations: bool = False,      # also quantize activations
        verbose: bool = True,                 # print progress
    ) -> SensitivityResult:
        """Analyze per-layer sensitivity to compression.

        Evaluates the baseline metric, then compresses one layer at a time,
        re-evaluates, and restores the layer. ``delta`` is oriented so that a
        positive value means degradation (``baseline - metric`` when higher is
        better, ``metric - baseline`` otherwise).

        For quantization, ``level`` is a bit-width; values ``<= 1`` fall back to
        8 bits. The model is switched to ``eval()`` mode and left there.

        Layer weights, sparsifier buffers and activation hooks are restored or
        removed even when ``eval_fn`` or a compression step raises.

        Raises ``ValueError`` for an unknown ``compression``. Emits a warning
        when names in ``layers`` do not match any compressible layer.
        """
        if compression not in self.VALID_COMPRESSIONS:
            raise ValueError(f"compression must be one of {self.VALID_COMPRESSIONS}")

        is_quant = compression == "quantization"
        bits = (int(level) if level > 1 else 8) if is_quant else None
        use_act_hooks = is_quant and quant_activations

        self.model.eval()

        results: list[LayerSensitivity] = []

        try:
            if compression == "sparsity":
                self._init_sparsifier(granularity)
            if use_act_hooks:
                self._setup_activation_hooks(bits)

            if verbose:
                print(f"Computing baseline {self.metric_name}...", end=" ", flush=True)
            baseline = self.eval_fn(self.model)
            if verbose:
                print(f"{baseline:.4f}")

            all_layers = self._get_compressible_layers()
            if layers is not None:
                # A bare string is treated as a single layer name.
                wanted = {layers} if isinstance(layers, str) else set(layers)
                all_layers = [(n, m) for n, m in all_layers if n in wanted]
                missing = wanted - {n for n, _ in all_layers}
                if missing:
                    import warnings
                    warnings.warn(
                        f"Requested layers not found among compressible layers: {sorted(missing)}"
                    )

            if verbose:
                # Only build the description when it is printed; the criteria
                # callable may lack __name__ (e.g. a functools.partial).
                if is_quant:
                    mode_info = f" (per-{'channel' if quant_per_channel else 'tensor'}"
                    mode_info += f", {'weights+activations' if quant_activations else 'weights only'})"
                else:
                    criteria_fn = getattr(self.criteria, "f", self.criteria)
                    criteria_name = getattr(criteria_fn, "__name__", type(criteria_fn).__name__)
                    if compression == "sparsity":
                        mode_info = f" (granularity={granularity}, criteria={criteria_name})"
                    else:
                        mode_info = f" (structural, criteria={criteria_name})"
                unit = 'bits' if is_quant else '%'
                print(f"Analyzing {len(all_layers)} layers for {compression} @ {level}{unit}{mode_info}")

            for i, (name, module) in enumerate(all_layers):
                if verbose:
                    print(f"  [{i+1}/{len(all_layers)}] {name}...", end=" ", flush=True)

                # The weight count does not depend on the compression step below.
                param_count = module.weight.numel()

                if compression == "sparsity":
                    try:
                        self._apply_sparsity(module, level)
                        compressed_metric = self.eval_fn(self.model)
                    finally:
                        self._restore_layer(module)

                elif compression == "pruning":
                    # Pruning works on a deep copy, so there is nothing to restore.
                    pruned_model = self._apply_structural_pruning(name, level)
                    compressed_metric = self.eval_fn(pruned_model)
                    del pruned_model

                else:  # quantization
                    saved_weight = module.weight.data.clone()
                    bias = getattr(module, 'bias', None)
                    saved_bias = bias.data.clone() if bias is not None else None
                    try:
                        self._apply_weight_quantization(module, bits, per_channel=quant_per_channel)
                        if quant_activations:
                            self._activation_quantize_config[name] = True
                        compressed_metric = self.eval_fn(self.model)
                    finally:
                        if quant_activations:
                            self._activation_quantize_config[name] = False
                        module.weight.data.copy_(saved_weight)
                        if saved_bias is not None:
                            module.bias.data.copy_(saved_bias)

                if self.higher_is_better:
                    delta = baseline - compressed_metric
                else:
                    delta = compressed_metric - baseline

                if verbose:
                    sign = "+" if delta > 0 else ""
                    print(f"Δ={sign}{delta:.4f}")

                results.append(LayerSensitivity(
                    name=name,
                    layer_type=module.__class__.__name__,
                    params=param_count,
                    baseline_metric=baseline,
                    compressed_metric=compressed_metric,
                    delta=delta,
                ))
        finally:
            # Always leave the model free of analysis artifacts.
            if compression == "sparsity":
                self._cleanup_sparsifier()
            if use_act_hooks:
                self._remove_activation_hooks()

        compression_desc = compression
        if is_quant:
            compression_desc = f"quantization-{bits}bit"
            if quant_activations:
                compression_desc += "+act"

        self._results = SensitivityResult(
            compression_type=compression_desc,
            compression_level=level,
            baseline_metric=baseline,
            layers=results,
            metric_name=self.metric_name,
            higher_is_better=self.higher_is_better,
        )

        if verbose:
            print("✓ Analysis complete")

        return self._results
    
    def sweep(
        self,
        compression: Literal["sparsity", "pruning", "quantization"] = "sparsity",  # compression type
        levels: list[float] | None = None,  # levels to test (default: [25, 50, 75] %, or [8, 4, 2] bits for quantization)
        **kwargs,
    ) -> list[SensitivityResult]:
        """Run sensitivity analysis at multiple compression levels.

        Calls `analyze` once per level, forwarding ``**kwargs``, and returns the
        results in the order of ``levels``. The per-level banner is printed only
        when ``verbose`` (forwarded to `analyze`, default ``True``) is true.
        """
        if levels is None:
            # Quantization levels are bit-widths; the percentage defaults would
            # be read as 25/50/75-bit quantization.
            levels = [8, 4, 2] if compression == "quantization" else [25, 50, 75]

        verbose = kwargs.get("verbose", True)
        unit = 'bits' if compression == 'quantization' else '%'
        rule = '=' * 60

        results = []
        for level in levels:
            if verbose:
                print(f"\n{rule}")
                print(f"Sweep: {compression} @ {level}{unit}")
                print(rule)
            results.append(self.analyze(compression, level, **kwargs))
        return results

## Convenience Functions

In [None]:
#| export
def analyze_sensitivity(
    model: nn.Module,                    # model to analyze
    sample: torch.Tensor,                # example input tensor
    eval_fn: Callable[[nn.Module], float],  # evaluation function returning metric
    compression: Literal["sparsity", "pruning", "quantization"] = "sparsity",  # compression type
    level: float = 50,                   # compression level (% for sparsity/pruning, bits for quant)
    *,
    criteria: Criteria = large_final,    # fasterai criteria for importance scoring
    higher_is_better: bool = True,       # whether higher metric values are better
    metric_name: str = "accuracy",       # name of the metric for display
    granularity: str = "weight",         # granularity for sparsity
    verbose: bool = True,                # print progress
    **kwargs,
) -> SensitivityResult:
    """Build a `SensitivityAnalyzer` and run a single `analyze` call.

    Extra keyword arguments are forwarded to `SensitivityAnalyzer.analyze`.
    """
    analyzer_options = {
        "criteria": criteria,
        "higher_is_better": higher_is_better,
        "metric_name": metric_name,
    }
    analyzer = SensitivityAnalyzer(model, sample, eval_fn, **analyzer_options)
    result = analyzer.analyze(
        compression,
        level,
        granularity=granularity,
        verbose=verbose,
        **kwargs,
    )
    return result

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()