### Naive Approach Benchmarks

In [2]:
# Given naive approach benchmarks
# If needed, install deps:
# !pip install torch zstandard

import math, io, pickle, time
from time import perf_counter
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    import zstandard as zstd
except ImportError as e:
    raise RuntimeError("Please install `zstandard` (pip install zstandard) before running this cell.")

# Seed the torch and NumPy global RNGs so the random init and inputs below repeat across runs.
torch.manual_seed(42)
np.random.seed(42)
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)


# -----------------------------
# Tiny Transformer layer
# -----------------------------
class TinyTransformerLayer(nn.Module):
    """Minimal pre-norm Transformer block: self-attention, then a GELU feed-forward.

    Both sub-layers are wrapped in a residual connection and preceded by a
    LayerNorm.

    Args:
        d_model: Width used by both LayerNorms, the attention and the FFN in/out.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability; it is only handed to the attention module.
    """

    def __init__(self, d_model=64, nhead=4, dim_feedforward=128, dropout=0.0):
        super().__init__()
        # Sub-modules are created in the order ln1, self_attn, ln2, ff so the
        # parameter order and the seeded initialisation stay stable.
        self.ln1 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=nhead, dropout=dropout, batch_first=True
        )
        self.ln2 = nn.LayerNorm(d_model)
        expand = nn.Linear(d_model, dim_feedforward)
        project = nn.Linear(dim_feedforward, d_model)
        self.ff = nn.Sequential(expand, nn.GELU(), project)

    def forward(self, x, attn_mask=None):
        """Apply attention and feed-forward sub-layers; returns a tensor shaped like `x`."""
        normed = self.ln1(x)
        # Only the attention output is needed; the attention weights are dropped.
        attended = self.self_attn(normed, normed, normed, attn_mask=attn_mask)[0]
        residual = x + attended
        return residual + self.ff(self.ln2(residual))

# Build the reference ("base") layer with default sizes and put it in eval mode.
base_layer = TinyTransformerLayer().eval()
# Report the total number of scalar parameters in the layer.
print("Tiny layer params:", sum(p.numel() for p in base_layer.parameters()))


# -----------------------------
# Simulate finetuning (sparse updates)
# -----------------------------
from copy import deepcopy

def perturb_subset_(module: nn.Module, frac: float = 0.03, magnitude: float = 1e-2, seed: int = 0):
    """In-place: add Gaussian noise to a random `frac` of the elements of each float parameter.

    This simulates sparse finetuning, so that XOR(base, finetuned) is mostly zeros.

    Args:
        module: Module whose parameters are modified in place.
        frac: Fraction of elements to perturb per parameter. At least one element
            is always perturbed (``max(1, int(numel * frac))``).
        magnitude: Scale applied to the standard-normal noise.
        seed: Seed for the private CPU generator, making the perturbation repeatable.

    Returns:
        None. The parameters of `module` are updated in place.
    """
    rng = torch.Generator(device='cpu').manual_seed(seed)
    with torch.no_grad():
        for param in module.parameters():
            if not param.dtype.is_floating_point:
                continue
            numel = param.numel()
            k = max(1, int(numel * frac))
            # Sample on the CPU (where the generator lives) and move the result:
            # passing a CPU generator together with a CUDA `device=` raises in torch,
            # and sampling on the CPU keeps results identical across devices.
            idx = torch.randperm(numel, generator=rng)[:k].to(param.device)
            noise = magnitude * torch.randn(k, generator=rng, dtype=param.dtype)
            flat = param.view(-1)  # shares storage with the parameter
            flat[idx] += noise.to(param.device)

# Clone the base layer and perturb ~3% of each parameter to stand in for a finetuned model.
finetuned_layer = deepcopy(base_layer)
perturb_subset_(finetuned_layer, frac=0.03, magnitude=1e-2, seed=123)

# Count changed elements (exact because we add noise)
changed, total = 0, 0
# Walk both models' parameters in lockstep; names must line up pairwise.
for (n1, p_base), (n2, p_ft) in zip(base_layer.named_parameters(), finetuned_layer.named_parameters()):
    assert n1 == n2
    if p_base.dtype.is_floating_point:
        # Element-wise inequality, summed to the number of differing entries.
        diff = (p_base != p_ft).sum().item()
        changed += diff
        total += p_base.numel()
print(f"Changed elements: {changed} / {total} (~{changed/total*100:.2f}%)")


# -----------------------------
# XOR + Zstd compression helpers
# -----------------------------
def tensor_to_uint8_view(t: torch.Tensor) -> np.ndarray:
    """Return the raw bytes of `t` as a NumPy uint8 array (on the CPU).

    For a C-contiguous CPU tensor with at least one dimension the result is a
    zero-copy view whose last axis is multiplied by the element size. A 0-d
    tensor is returned as a 1-d array of its bytes, and a non-contiguous tensor
    is first copied into C order; a plain ``ndarray.view(np.uint8)`` raises in
    both of those cases.

    Args:
        t: Tensor whose dtype NumPy can represent.

    Returns:
        uint8 array over the tensor's bytes.
    """
    # ascontiguousarray is a no-op for C-contiguous arrays with ndim >= 1, and it
    # promotes 0-d arrays to 1-d so the dtype reinterpretation below is legal.
    arr = np.ascontiguousarray(t.detach().cpu().numpy())
    return arr.view(np.uint8)

def xor_uint8(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Element-wise XOR of two uint8 arrays that hold the same number of elements."""
    both_uint8 = a.dtype == np.uint8 and b.dtype == np.uint8
    assert both_uint8 and a.size == b.size
    return a ^ b

def compress_state_dict_with_xor_zstd(base_sd: dict, finetuned_sd: dict, level: int = 10):
    """Encode a finetuned state dict as Zstd-compressed byte-wise XOR deltas against a base.

    Args:
        base_sd: Mapping of parameter name to base tensor.
        finetuned_sd: Mapping with the same names holding the finetuned tensors.
        level: Zstd compression level.

    Returns:
        A dict with:
          - 'meta'[name]: shape / dtype_str / numel / param_bytes / uint8_shape
          - 'deltas'[name]: Zstd-compressed XOR bytes
          - 'zstd_level': the level used
          - 'stats': aggregate byte counts plus a 'per_param' breakdown
    """
    cctx = zstd.ZstdCompressor(level=level)
    meta, deltas, per_param = {}, {}, {}
    sum_param = sum_delta = sum_comp = 0

    for name, base_t in base_sd.items():
        ft_t = finetuned_sd[name]
        base_u8 = tensor_to_uint8_view(base_t)
        ft_u8 = tensor_to_uint8_view(ft_t)
        assert base_u8.size == ft_u8.size, name

        # XOR the raw bytes; unchanged elements become zero bytes.
        xor_blob = xor_uint8(base_u8, ft_u8).tobytes()
        packed = cctx.compress(xor_blob)
        n_param, n_delta, n_comp = ft_u8.size, len(xor_blob), len(packed)

        meta[name] = {
            'shape': list(ft_t.shape),
            'dtype_str': str(ft_t.detach().cpu().numpy().dtype.str),  # e.g. '<f4'
            'numel': ft_t.numel(),
            'param_bytes': n_param,
            'uint8_shape': list(base_u8.shape),
        }
        deltas[name] = packed

        sum_param += n_param
        sum_delta += n_delta
        sum_comp += n_comp
        per_param[name] = {
            'param_bytes': n_param,
            'delta_bytes': n_delta,
            'compressed_bytes': n_comp,
            # max(1, ...) guards the ratio against zero-byte parameters.
            'compressed_ratio_vs_param': n_comp / max(1, n_param),
        }

    stats = {
        'total_params_bytes': sum_param,
        'total_delta_bytes': sum_delta,
        'total_compressed_bytes': sum_comp,
        'per_param': per_param,
    }
    return {'meta': meta, 'deltas': deltas, 'zstd_level': level, 'stats': stats}

def reconstruct_finetuned_state_dict_with_timing(base_sd: dict, pkg: dict):
    """Rebuild the finetuned tensors from a delta package, timing the two main stages.

    Args:
        base_sd: Mapping of parameter name to base tensor.
        pkg: Package with 'meta' and 'deltas' entries per parameter name.

    Returns:
        (rec_sd, timing) where rec_sd maps name -> reconstructed tensor and timing
        has the accumulated seconds under 't_decompress' and 't_xor_apply'.
    """
    dctx = zstd.ZstdDecompressor()
    rebuilt = {}
    timing = {'t_decompress': 0.0, 't_xor_apply': 0.0}

    for name, info in pkg['meta'].items():
        blob = pkg['deltas'][name]

        # Stage 1: Zstd decompression only.
        started = perf_counter()
        raw = dctx.decompress(blob)
        timing['t_decompress'] += perf_counter() - started
        delta_u8 = np.frombuffer(raw, dtype=np.uint8).reshape(info['uint8_shape'])

        # Stage 2: only the XOR itself is timed; the byte view of the base is prepared first.
        base_u8 = tensor_to_uint8_view(base_sd[name])
        started = perf_counter()
        patched = np.bitwise_xor(base_u8, delta_u8)
        timing['t_xor_apply'] += perf_counter() - started

        # Reinterpret the patched bytes with the recorded dtype and shape.
        restored = patched.view(np.dtype(info['dtype_str'])).reshape(info['shape']).copy()
        rebuilt[name] = torch.from_numpy(restored)

    return rebuilt, timing


# -----------------------------
# Build base/finetuned SDs and compress
# -----------------------------
# Snapshot both models' state dicts as detached, contiguous CPU tensors.
base_sd = {k: v.detach().cpu().contiguous() for k, v in base_layer.state_dict().items()}
ft_sd   = {k: v.detach().cpu().contiguous() for k, v in finetuned_layer.state_dict().items()}

# Encode the finetuned weights as Zstd-compressed XOR deltas against the base.
pkg = compress_state_dict_with_xor_zstd(base_sd, ft_sd, level=10)
print("=== Compression Stats ===")
print("Total param bytes  :", pkg['stats']['total_params_bytes'])
print("Total delta bytes  :", pkg['stats']['total_delta_bytes'])
print("Total compressed   :", pkg['stats']['total_compressed_bytes'])
print("Global ratio       :", pkg['stats']['total_compressed_bytes']/pkg['stats']['total_params_bytes'])
# Show the per-parameter breakdown for the first five entries only.
for name, st in list(pkg['stats']['per_param'].items())[:5]:
    print(f"- {name:40s} | param={st['param_bytes']} | comp={st['compressed_bytes']} | ratio={st['compressed_ratio_vs_param']:.4f}")

# Serialize to in-memory bytes (so we can measure "load" as deserialization;
# replace with actual disk I/O in your system if you want true file load timings).
pkg_bytes = pickle.dumps(pkg)


# -----------------------------
# End-to-end function with timers (load → reconstruct → load_state_dict → forward)
# -----------------------------
def e2e_forward_from_pkg_bytes(base_layer: nn.Module, pkg_bytes: bytes, x: torch.Tensor, device: str = 'cpu'):
    """End-to-end: load (deserialize) → decompress → XOR-apply → load weights → forward.

    Args:
        base_layer: Layer providing the base weights the deltas are applied to.
        pkg_bytes: Pickled package produced by compress_state_dict_with_xor_zstd.
        x: Input batch for the single demo forward pass.
        device: Device on which the reconstructed layer is built and run.

    Returns:
        (timings, checksum). `timings` holds seconds for 't_load_pkg',
        't_decompress', 't_xor_apply', 't_state_load', 't_forward' and their sum
        't_end_to_end'; `checksum` is the sum of absolute output values.

    NOTE: 't_load_pkg' here measures pickle.loads (in-memory). Replace with disk I/O for on-disk timings.
    """
    def _sync():
        # CUDA work is queued asynchronously; without waiting for it, perf_counter
        # would only measure kernel launch. This is a no-op on CPU.
        if torch.cuda.is_available() and torch.device(device).type == 'cuda':
            torch.cuda.synchronize()

    # 0) Prepare inputs
    x = x.to(device)
    base_sd = {k: v.detach().cpu().contiguous() for k, v in base_layer.state_dict().items()}

    # 1) Load package (deserialize)
    # SECURITY: pickle.loads can execute arbitrary code; only pass bytes produced
    # by a trusted source (here: pickle.dumps in this notebook).
    t0 = perf_counter()
    pkg = pickle.loads(pkg_bytes)
    t_load = perf_counter() - t0

    # 2) Decompress + XOR apply (split timings)
    rec_sd, t_parts = reconstruct_finetuned_state_dict_with_timing(base_sd, pkg)

    # 3) Materialize a new layer and copy weights
    # NOTE(review): the layer is rebuilt with default sizes, so `base_layer` must
    # use the default TinyTransformerLayer configuration.
    _sync()  # keep the pending input transfer out of this stage
    t2 = perf_counter()
    rec_layer = TinyTransformerLayer().to(device).eval()
    with torch.no_grad():
        # state_dict() tensors share storage with the parameters, so copy_ updates them.
        for k, p in rec_layer.state_dict().items():
            p.copy_(rec_sd[k].to(device))
    _sync()
    t_state = perf_counter() - t2

    # 4) Forward once (demo)
    t3 = perf_counter()
    with torch.no_grad():
        y = rec_layer(x)
    _sync()
    t_fwd = perf_counter() - t3

    timings = {
        't_load_pkg': t_load,
        't_decompress': t_parts['t_decompress'],
        't_xor_apply': t_parts['t_xor_apply'],
        't_state_load': t_state,
        't_forward': t_fwd,
        't_end_to_end': t_load + t_parts['t_decompress'] + t_parts['t_xor_apply'] + t_state + t_fwd,
    }
    checksum = float(y.abs().sum().detach().cpu())  # simple checksum to ensure consistent output
    return timings, checksum


# -----------------------------
# Micro-benchmark: mean ± std per stage
# -----------------------------
def benchmark_e2e_breakdown(base_layer, pkg_bytes: bytes, x: torch.Tensor, device: str = 'cpu', repeats: int = 7):
    """Run the end-to-end pipeline `repeats` times and summarise each stage.

    Returns:
        (stats, checks): `stats` maps each stage key to {'mean', 'std'} in seconds
        and `checks` lists the output checksum of every run.
    """
    stage_names = ('t_load_pkg', 't_decompress', 't_xor_apply', 't_state_load', 't_forward', 't_end_to_end')
    runs = [
        e2e_forward_from_pkg_bytes(base_layer, pkg_bytes, x, device=device)
        for _ in range(repeats)
    ]
    checks = [checksum for _, checksum in runs]
    stats = {}
    for stage in stage_names:
        samples = [timings[stage] for timings, _ in runs]
        stats[stage] = {'mean': float(np.mean(samples)), 'std': float(np.std(samples))}
    return stats, checks


# -----------------------------
# What-if speed windows
# -----------------------------
def print_what_if_bounds(stats):
    """Print two hypothetical end-to-end times derived from the measured stage means.

    - Bound A: the XOR-apply stage is skipped (decompression is still paid).
    - Bound B: both XOR-apply and decompression are skipped (theoretical).

    Args:
        stats: Mapping of stage key -> {'mean': seconds, ...} for the six stages
            't_load_pkg', 't_decompress', 't_xor_apply', 't_state_load',
            't_forward' and 't_end_to_end'.
    """
    stage_keys = ('t_load_pkg', 't_decompress', 't_xor_apply',
                  't_state_load', 't_forward', 't_end_to_end')
    # The XOR mean is read for completeness but does not enter either bound.
    load_s, decomp_s, _xor_s, state_s, fwd_s, baseline_s = (stats[key]['mean'] for key in stage_keys)

    without_xor = load_s + decomp_s + state_s + fwd_s   # skip XOR
    without_xor_and_decomp = load_s + state_s + fwd_s   # skip XOR + decompress (theoretical)

    def describe(old, new):
        if new > 0:
            return f"{old/new:.2f}× faster (↓{(1 - new/old)*100:.1f}%)"
        return "∞"

    print("=== What-if Speed Windows (vs current E2E mean) ===")
    print(f"Baseline E2E: {baseline_s:.6f}s")
    print(f"Skip XOR (keep decompress): {without_xor:.6f}s  → {describe(baseline_s, without_xor)}")
    print(f"Skip XOR+Decompress (Zstd-domain compute): {without_xor_and_decomp:.6f}s  → {describe(baseline_s, without_xor_and_decomp)}")
    for line in (
        "\nInterpretation:",
        "- (Baseline → Bound A) is the best-case improvement if we avoid the global XOR-apply step by",
        "  computing directly on decompressed XOR bytes (tile JIT patching, patch-based kernels, etc.).",
        "- (Baseline → Bound B) is a theoretical ceiling if we could avoid both decompression and XOR",
        "  (i.e., compute in Zstd domain); not realistic today, but shows maximum headroom.",
    ):
        print(line)


# -----------------------------
# Per-parameter hotspots (identify where XOR time concentrates)
# -----------------------------
def per_param_decompress_xor_timing(base_sd: dict, pkg: dict, topk: int = 5):
    """Time decompression and XOR-apply per parameter and print the slowest XORs.

    Args:
        base_sd: Mapping of parameter name to base tensor.
        pkg: Package with 'meta' and 'deltas' entries per parameter name.
        topk: How many of the slowest entries (by XOR time) to print.

    Returns:
        List of (name, param_bytes, t_decompress, t_xor) tuples sorted by XOR
        time, slowest first.
    """
    dctx = zstd.ZstdDecompressor()
    records = []
    for name, meta in pkg['meta'].items():
        comp = pkg['deltas'][name]

        t0 = perf_counter()
        raw = dctx.decompress(comp)
        t_de = perf_counter() - t0
        delta = np.frombuffer(raw, dtype=np.uint8).reshape(meta['uint8_shape'])

        # Prepare the base byte view outside the timed region so that only the XOR
        # is measured, as in reconstruct_finetuned_state_dict_with_timing.
        base_bytes = tensor_to_uint8_view(base_sd[name])
        t1 = perf_counter()
        np.bitwise_xor(base_bytes, delta)
        t_x = perf_counter() - t1

        records.append((name, meta['param_bytes'], t_de, t_x))

    # Sort by XOR time, slowest first.
    records.sort(key=lambda r: r[3], reverse=True)
    print(f"=== Top {topk} params by XOR-apply time ===")
    for name, sz, tde, tx in records[:topk]:
        print(f"{name:40s} | bytes={sz:8d} | decompress={tde:.6f}s | xor={tx:.6f}s")
    return records


# -----------------------------
# XOR sparsity & naive sparse storage estimate
# -----------------------------
def estimate_xor_sparsity_and_storage(pkg: dict):
    """Count non-zero XOR bytes across the package and cost a naive sparse encoding.

    The sparse model stores one uint32 index plus one uint8 value (5 bytes) per
    non-zero byte; the dense model stores every XOR byte.

    Returns:
        Dict with total and non-zero byte counts, the non-zero density, both
        storage estimates and the sparse/dense ratio.
    """
    dctx = zstd.ZstdDecompressor()
    byte_count = 0
    nonzero_count = 0
    for name in pkg['meta']:
        xor_bytes = np.frombuffer(dctx.decompress(pkg['deltas'][name]), dtype=np.uint8)
        byte_count += xor_bytes.size
        nonzero_count += int(np.count_nonzero(xor_bytes))

    bytes_per_sparse_entry = 4 + 1  # uint32 index + uint8 value
    sparse_cost = nonzero_count * bytes_per_sparse_entry
    dense_cost = byte_count  # 1 byte per entry
    return {
        'xor_total_bytes': byte_count,
        'xor_nonzero_bytes': nonzero_count,
        'nonzero_density': nonzero_count / max(1, byte_count),
        'dense_storage_bytes': dense_cost,
        'simple_sparse_storage_bytes': sparse_cost,
        'sparse_over_dense_ratio': sparse_cost / max(1, dense_cost),
    }


# -----------------------------
# Run benchmarks
# -----------------------------
# End-to-end breakdown
# Random input batch of shape (4, 16, 64) for the benchmark runs.
x_bench = torch.randn(4, 16, 64)
stats, checks = benchmark_e2e_breakdown(base_layer, pkg_bytes, x_bench, device=device, repeats=7)
# The "7 runs" in the header below is a literal and must be kept in sync with `repeats` above.
print("=== Component-wise Timing (seconds): mean ± std over 7 runs ===")
for k, s in stats.items():
    print(f"{k:>14s}: {s['mean']:.6f} ± {s['std']:.6f}")
# Identical checksums across runs indicate every repetition produced the same output.
print("Sanity checksums (first 3):", checks[:3])

# What-if bounds
print_what_if_bounds(stats)

# Hotspots
_ = per_param_decompress_xor_timing(base_sd, pkg, topk=5)

# XOR sparsity summary
sparsity = estimate_xor_sparsity_and_storage(pkg)
print("\n=== XOR Sparsity Estimate ===")
for k, v in sparsity.items():
    print(f"{k:28s}: {v}")
print("\nNOTE: If non-zero density is low, consider block-sparse indexing (e.g., 256B/4KB blocks), "
      "RLE on zero-runs, or bitplane packing so **many XOR adapters can co-reside on a single GPU.**")



Using device: cpu
Tiny layer params: 33472
Changed elements: 994 / 33472 (~2.97%)
=== Compression Stats ===
Total param bytes  : 133888
Total delta bytes  : 133888
Total compressed   : 5633
Global ratio       : 0.04207247848948375
- ln1.weight                               | param=256 | comp=27 | ratio=0.1055
- ln1.bias                                 | param=256 | comp=26 | ratio=0.1016
- self_attn.in_proj_weight                 | param=49152 | comp=1963 | ratio=0.0399
- self_attn.in_proj_bias                   | param=768 | comp=53 | ratio=0.0690
- self_attn.out_proj.weight                | param=16384 | comp=707 | ratio=0.0432
=== Component-wise Timing (seconds): mean ± std over 7 runs ===
    t_load_pkg: 0.000016 ± 0.000003
  t_decompress: 0.000050 ± 0.000015
   t_xor_apply: 0.000053 ± 0.000013
  t_state_load: 0.000938 ± 0.001317
     t_forward: 0.012634 ± 0.030183
  t_end_to_end: 0.013692 ± 0.031508
Sanity checksums (first 3): [3409.4951171875, 3409.4951171875, 3409.4951171875]
==

### GEMM on zstd Compressed XORed matrices

In [37]:
# HELPER FUNCTION: Tiled Hooks # -----------------------------

import json


def update_mha_bias_in_place(module, param_name, new_data, device):
    """Overwrite one bias of a multi-head-attention module with `new_data`.

    The in-projection bias is replaced when 'in_proj_bias' occurs in
    `param_name`; otherwise the out-projection bias is replaced. The new values
    are moved to the existing parameter's device and dtype. `device` is accepted
    but not used. Always returns None (no forward hook is needed).
    """
    print(f"  Updated MHA bias parameter in-place for {param_name}")
    targets_in_proj = 'in_proj_bias' in param_name
    target = module.in_proj_bias if targets_in_proj else module.out_proj.bias
    target.data = new_data.to(target.device, target.dtype)
    return None  # No forward hook needed


def track_timing(timing_data, tile_data, param_name, t_decompress, t_xor=0.0, t_forward=0.0, num_tiles=1):
    """Append one parameter's timing and tile-count entries to the tracking dicts.

    Both dicts map a series name to a list; every call appends exactly one value
    to each series, so the lists stay aligned by position.
    """
    entries = (
        (timing_data, 'param_names', param_name),
        (timing_data, 't_partial_decompress', t_decompress),
        (timing_data, 't_tile_xor_patch', t_xor),
        (timing_data, 't_forward_tile', t_forward),
        (tile_data, 'param_names', param_name),
        (tile_data, 'num_tiles', num_tiles),
    )
    for store, series, value in entries:
        store[series].append(value)


def create_tiled_hook(param_name, module, pkg, device, tile_rows, tile_cols, timing_data, size_data, tile_data, is_bias=False, is_norm=False, is_mha=False):
    """Make `module` behave as finetuned for one parameter, via a forward hook or an in-place update.

    The XOR delta for `param_name` is decompressed from ``pkg['deltas']``. For
    biases and norm weights it is XOR-ed with the module's current parameter
    bytes to recover the finetuned values, and then:

    - MHA biases (``is_mha and is_bias``) are written into the module in place
      via update_mha_bias_in_place (this mutates `module`); no hook is returned.
    - Other biases return a hook that adds ``finetuned - base`` to the module
      output. The base bias is already contained in that output, so adding the
      full finetuned bias would count it twice.
    - LayerNorm weights return a hook that adds
      ``normalized(x) * (finetuned - base)`` to the output.
    - Plain weights are not handled yet; they are only recorded.

    Args:
        param_name: Key of the parameter in ``pkg['meta']`` / ``pkg['deltas']``.
        module: Module that owns the parameter.
        pkg: Delta package ('meta' and 'deltas' per parameter name).
        device: Device on which the hook's tensors are placed.
        tile_rows, tile_cols: Currently unused (reserved for weight tiling).
        timing_data, size_data, tile_data: Tracking dicts that are appended to.
        is_bias: The parameter is a bias.
        is_norm: The parameter is a normalization weight.
        is_mha: `module` is a multi-head-attention module.

    Returns:
        A ``hook_fn(module, inputs, output)`` callable, or None when no hook is
        needed (missing delta, plain weight, or in-place MHA bias update).
    """
    print(f"-----------------{param_name} Hook-----------------")
    print(is_mha, is_bias, is_norm)
    if param_name not in pkg['deltas']:
        print(f"Skipping {param_name}: no update in pkg['deltas']")
        return None

    meta = pkg['meta'][param_name]
    shape = meta['shape']
    uint8_shape = meta['uint8_shape']
    compressed_xor = pkg['deltas'][param_name]
    print(f"  Shape: {shape}, uint8_shape: {uint8_shape}, numel: {meta['numel']}, bytes: {meta['param_bytes']}")

    size_data['param_names'].append(param_name)
    size_data['numel'].append(meta['numel'])

    # Decompress the XOR delta and view it with the recorded byte shape.
    t0 = perf_counter()
    xor_bytes = zstd.decompress(compressed_xor)
    xor_arr = np.frombuffer(xor_bytes, dtype=np.uint8).reshape(uint8_shape)
    t_decompress = perf_counter() - t0
    print(f"  Decompressed XOR bytes in {t_decompress:.7f}s, shape: {xor_arr.shape}")

    if not (is_bias or is_norm):
        # Weights: skip for now
        print(f"  Skipping weight {param_name} (tiling in later steps)")
        track_timing(timing_data, tile_data, param_name, 0.0)
        return None

    # Locate the base parameter this delta applies to.
    if is_mha and is_bias:
        # The out-projection bias lives on the nested `out_proj` module.
        base_param = module.in_proj_bias if 'in_proj_bias' in param_name else module.out_proj.bias
    else:
        base_param = module.bias if is_bias else module.weight
    base_data = base_param.detach().cpu().numpy()
    print(f"  Base data shape: {base_data.shape}, dtype: {base_data.dtype}")

    # View as bytes, XOR, view back as the stored dtype, reshape.
    t1 = perf_counter()
    base_bytes = base_data.view(np.uint8).reshape(uint8_shape)
    ft_bytes = np.bitwise_xor(base_bytes, xor_arr)
    new_np = ft_bytes.view(np.dtype(meta['dtype_str'])).reshape(shape)
    t_xor = perf_counter() - t1
    new_data = torch.from_numpy(new_np).to(dtype=base_param.dtype, device=device)
    print(f"  Finetuned data shape: {new_data.shape}, sample values: {new_data.flatten()[:5]}")

    track_timing(timing_data, tile_data, param_name, t_decompress, t_xor=t_xor)

    if is_mha and is_bias:
        print(f"  DEBUG: is_mha={is_mha}, is_bias={is_bias} for {param_name} - updating in-place, NO HOOK")
        update_mha_bias_in_place(module, param_name, new_data, device)
        return None  # Exit early, absolutely no hook

    # Non-MHA biases/norms: create hook
    print(f"  DEBUG: Creating forward hook for non-MHA {param_name}")
    # (finetuned - base), captured once; the hook corrects the base output with it.
    delta = new_data - base_param.detach().to(device=device, dtype=new_data.dtype)

    def hook_fn(module, inputs, output):
        call_count = getattr(hook_fn, 'call_count', 0)
        hook_fn.call_count = call_count + 1
        print(f"  Applying hook for {param_name} (call #{call_count + 1}), input shape: {inputs[0].shape}")
        out_is_tuple = isinstance(output, tuple)
        primary = output[0] if out_is_tuple else output
        if is_norm:
            if isinstance(module, nn.LayerNorm):
                # LayerNorm computes normalized(x) * weight + bias, so switching to the
                # finetuned weight adds normalized(x) * (w_ft - w_base).
                primary = primary + F.layer_norm(inputs[0], module.normalized_shape, delta, None, module.eps)
                print(f"  Applied LayerNorm weight delta (output += normalized(x) * (finetuned - base))")
            else:
                # NOTE(review): kept from the original for non-LayerNorm modules; this
                # multiplies an already-scaled output and is presumably not exact.
                primary = primary * new_data
                print(f"  Applied norm multiplication (output *= finetuned weight)")
        else:
            # The delta has the size of the last output dim and broadcasts over the rest.
            primary = primary + delta
            print(f"  Applied bias delta (output += finetuned - base), delta shape: {delta.shape}")
        print(f"  Output shape after hook: {primary.shape}")
        return (primary,) + tuple(output[1:]) if out_is_tuple else primary

    print(f"  Created and returning hook function for {param_name}")
    return hook_fn


In [38]:
# HELPER FUNCTION: Visualize metrics # -----------------------------

import json

def visualize_metrics(timing_data, size_data, tile_data):
    """Print three bar-chart specifications (timing, size, tiles) as indented JSON.

    Nothing is printed when ``timing_data['param_names']`` is empty. Each chart
    is a dict with 'type', 'data' (labels + datasets) and 'options' (title and
    axis titles), dumped with ``json.dumps(..., indent=2)``.
    """
    if not timing_data['param_names']:
        return

    def titled(text):
        return {"display": True, "text": text}

    def series(label, values, rgb):
        return {
            "label": label,
            "data": values,
            "backgroundColor": f"rgba({rgb}, 0.6)",
            "borderColor": f"rgba({rgb}, 1)",
            "borderWidth": 1,
        }

    def bar_chart(title, y_label, labels, datasets):
        return {
            "type": "bar",
            "data": {"labels": labels, "datasets": datasets},
            "options": {
                "plugins": {"title": titled(title)},
                "scales": {
                    "y": {"beginAtZero": True, "title": titled(y_label)},
                    "x": {"title": titled("Parameter")},
                },
            },
        }

    # Charts are built and printed one at a time, in this order.
    timing_chart = bar_chart(
        "Timing Breakdown per Parameter",
        "Time (seconds)",
        timing_data['param_names'],
        [
            series("Decompress Time (s)", timing_data['t_partial_decompress'], "75, 192, 192"),
            series("XOR Patch Time (s)", timing_data['t_tile_xor_patch'], "255, 99, 132"),
            series("Forward Tile Time (s)", timing_data['t_forward_tile'], "54, 162, 235"),
        ],
    )
    print("\nTiming Chart JSON:")
    print(json.dumps(timing_chart, indent=2))

    size_chart = bar_chart(
        "Parameter Sizes (Number of Elements)",
        "Number of Elements",
        size_data['param_names'],
        [series("Number of Elements", size_data['numel'], "153, 102, 255")],
    )
    print("\nSize Chart JSON:")
    print(json.dumps(size_chart, indent=2))

    tile_chart = bar_chart(
        "Number of Tiles per Parameter",
        "Number of Tiles",
        tile_data['param_names'],
        [series("Number of Tiles", tile_data['num_tiles'], "255, 159, 64")],
    )
    print("\nTile Chart JSON:")
    print(json.dumps(tile_chart, indent=2))

In [39]:
# -----------------------------
# TODO: compute on decompressed XOR without global XOR
# -----------------------------
def e2e_forward_on_delta_no_full_reconstruct_TODO(base_layer: nn.Module, pkg_bytes: bytes, x: torch.Tensor,
                                                  *, device: str = 'cpu', tile_rows: int = 64, tile_cols: int = 64,
                                                  test_mode: bool = False):
    """Run `base_layer` with per-parameter hooks built from a delta package (work in progress).

    The package is unpickled, create_tiled_hook is called for every Linear,
    LayerNorm and MultiheadAttention parameter, one forward pass is run and the
    progress is printed. On exit, including when an error is raised, all
    registered hooks are removed and the MHA biases captured before registration
    are put back, so `base_layer` is left as it was passed in.

    Args:
        base_layer: Layer holding the base weights; hooks are attached temporarily.
        pkg_bytes: Pickled delta package.
        x: Input batch for the forward pass.
        device: Device the input is moved to and hooks place their tensors on.
        tile_rows, tile_cols: Forwarded to create_tiled_hook.
        test_mode: Only echoed in the log; currently unused otherwise.

    Returns:
        None.
    """
    print(f"\n=== Step 1 - Starting Function ===")
    print(f"Input x shape: {x.shape}, device: {device}, tile_rows: {tile_rows}, tile_cols: {tile_cols}, test_mode: {test_mode}")
    # SECURITY: pickle.loads can execute arbitrary code; only pass trusted bytes.
    t0 = perf_counter()
    pkg = pickle.loads(pkg_bytes)
    t_load = perf_counter() - t0
    print(f"Package load time: {t_load:.7f}s")
    print(f"pkg keys: {list(pkg.keys())}")
    print(f"Model param names: {[name for name, _ in base_layer.named_parameters()]}")
    print(f"Tiny layer params (numel): {[p.numel() for p in base_layer.parameters()]}")
    print(f"Total parameters: {sum(p.numel() for p in base_layer.parameters())}")

    print(f"Model param shapes:")
    for name, p in base_layer.named_parameters():
        print(f"  {name}: {p.shape}")

    delta_keys = set(pkg['deltas'].keys())
    param_keys = set(name for name, _ in base_layer.named_parameters())
    print(f"Delta keys (updates): {delta_keys}")
    print(f"Param keys (model): {param_keys}")
    if not delta_keys.issubset(param_keys):
        print(f"Warning: pkg['deltas'] contains extra keys: {delta_keys - param_keys}")


    print(f"\n=== Step 2 - Module Collection ===")
    # NOTE(review): the attention module's `out_proj` presumably also matches
    # nn.Linear here — TODO confirm hooks registered on it actually fire.
    named_linears = [(name, module) for name, module in base_layer.named_modules() if isinstance(module, nn.Linear)]
    named_norms = [(name, module) for name, module in base_layer.named_modules() if isinstance(module, nn.LayerNorm)]
    named_mha = [(name, module) for name, module in base_layer.named_modules() if isinstance(module, nn.MultiheadAttention)]
    print(f"Linear layers: {[name for name, _ in named_linears]}")
    print(f"Norm layers: {[name for name, _ in named_norms]}")
    print(f"MHA layers: {[name for name, _ in named_mha]}")


    print(f"\n=== Step 3 - Initialize Visualization Data Structures ===")
    timing_data = {'param_names': [], 't_partial_decompress': [], 't_tile_xor_patch': [], 't_forward_tile': []}
    size_data = {'param_names': [], 'numel': []}
    tile_data = {'param_names': [], 'num_tiles': []}
    print(f"Initialized tracking data structures:")
    print(f"  Timing keys: {list(timing_data.keys())}")
    print(f"  Size keys: {list(size_data.keys())}")
    print(f"  Tile keys: {list(tile_data.keys())}")

    hook_handles = []
    # MHA biases are overwritten in place while hooks are created; keep copies so
    # the base layer can be restored afterwards.
    mha_bias_backup = [
        (
            module,
            None if module.in_proj_bias is None else module.in_proj_bias.detach().clone(),
            None if module.out_proj.bias is None else module.out_proj.bias.detach().clone(),
        )
        for _, module in named_mha
    ]

    try:
        print(f"\n=== Registering hooks for linear layers ===")
        for name, module in named_linears:
            weight_name = name + '.weight' if name else 'weight'
            bias_name = name + '.bias' if name else 'bias'
            print(f"  Processing linear layer '{name}': weight='{weight_name}', bias='{bias_name}'")
            weight_hook = create_tiled_hook(weight_name, module, pkg, device, tile_rows, tile_cols, timing_data, size_data, tile_data)
            bias_hook = create_tiled_hook(bias_name, module, pkg, device, tile_rows, tile_cols, timing_data, size_data, tile_data, is_bias=True)
            if weight_hook:
                hook_handles.append(module.register_forward_hook(weight_hook))
                print(f"    Registered weight hook for '{weight_name}'")
            if bias_hook:
                hook_handles.append(module.register_forward_hook(bias_hook))
                print(f"    Registered bias hook for '{bias_name}'")
        print(f"  Total linear hooks registered: {len(hook_handles)}")


        print(f"\n=== Registering hooks for norm layers ===")
        linear_attempts = len(named_linears) * 2  # weight + bias per linear
        norm_attempts = 0
        for name, module in named_norms:
            w_name = name + '.weight' if name else 'weight'
            b_name = name + '.bias' if name else 'bias'
            print(f"  Processing norm layer '{name}': weight='{w_name}', bias='{b_name}'")
            w_hook = create_tiled_hook(w_name, module, pkg, device, tile_rows, tile_cols, timing_data, size_data, tile_data, is_norm=True)
            b_hook = create_tiled_hook(b_name, module, pkg, device, tile_rows, tile_cols, timing_data, size_data, tile_data, is_bias=True)
            if w_hook:
                hook_handles.append(module.register_forward_hook(w_hook))
                print(f"    Registered weight hook for '{w_name}'")
            if b_hook:
                hook_handles.append(module.register_forward_hook(b_hook))
                print(f"    Registered bias hook for '{b_name}'")
            norm_attempts += 2  # weight + bias attempt
        print(f"  Processed {linear_attempts} linear parameters, {norm_attempts} norm parameters (total attempts: {linear_attempts + norm_attempts})")
        print(f"  Actual hooks registered so far: {len(hook_handles)} (placeholders return None)")


        print(f"\n=== Registering hooks for MHA layers ===")
        mha_attempts = 0
        for name, module in named_mha:
            print(f"  Processing MHA layer '{name}':")
            # Weights
            # NOTE(review): named_parameters() reports 'out_proj.weight' / 'out_proj.bias'
            # (with a dot), so the 'out_proj_weight' / 'out_proj_bias' lookups below are
            # expected to be skipped as "no update" — TODO confirm the intended names.
            for w_name in ['in_proj_weight', 'out_proj_weight']:
                full_name = name + '.' + w_name if name else w_name
                print(f"    Processing weight '{full_name}'")
                hook = create_tiled_hook(full_name, module, pkg, device, tile_rows, tile_cols, timing_data, size_data, tile_data, is_mha=True)
                if hook:
                    hook_handles.append(module.register_forward_hook(hook))
                    print(f"      Registered weight hook for '{full_name}'")
            # Biases
            for b_name in ['in_proj_bias', 'out_proj_bias']:
                full_name = name + '.' + b_name if name else b_name
                print(f"    Processing bias '{full_name}'")
                hook = create_tiled_hook(full_name, module, pkg, device, tile_rows, tile_cols, timing_data, size_data, tile_data, is_bias=True, is_mha=True)
                if hook:
                    hook_handles.append(module.register_forward_hook(hook))
                    print(f"      Registered bias hook for '{full_name}'")
            mha_attempts += 4  # 2 weights + 2 biases per MHA
        print(f"  Processed {mha_attempts} MHA parameters (total attempts: {linear_attempts + norm_attempts + mha_attempts})")
        print(f"  Actual hooks registered so far: {len(hook_handles)} (placeholders return None)")

        print(f"\n=== Running forward pass with hooks ===")
        t0 = perf_counter()
        with torch.no_grad():
            output = base_layer(x.to(device))
        t_total = perf_counter() - t0
        print(f"Forward pass time: {t_total:.7f}s")
        print(f"Total time: {t_total + t_load:.7f}s (load: {t_load:.7f}s, forward: {t_total:.7f}s)")
        print(f"Final output shape: {output.shape if not isinstance(output, tuple) else output[0].shape}")
    finally:
        # Always detach the hooks: a failure during registration or the forward
        # pass would otherwise leave them attached to the caller's layer.
        print(f"\n=== Cleaning up hooks ===")
        for handle in hook_handles:
            handle.remove()
        print(f"Removed {len(hook_handles)} hook handles")
        # Undo the in-place MHA bias updates so repeated calls start from the base weights.
        for module, in_proj_bias, out_proj_bias in mha_bias_backup:
            if in_proj_bias is not None:
                module.in_proj_bias.data = in_proj_bias
            if out_proj_bias is not None:
                module.out_proj.bias.data = out_proj_bias


# Drive the hook-based path once on a random (4, 16, 64) batch.
# The function has no return statement, so `temp` is None.
x_bench = torch.randn(4, 16, 64)
temp = e2e_forward_on_delta_no_full_reconstruct_TODO(base_layer, pkg_bytes, x_bench, device=device)



=== Step 1 - Starting Function ===
Input x shape: torch.Size([4, 16, 64]), device: cpu, tile_rows: 64, tile_cols: 64, test_mode: False
Package load time: 0.0000205s
pkg keys: ['meta', 'deltas', 'zstd_level', 'stats']
Model param names: ['ln1.weight', 'ln1.bias', 'self_attn.in_proj_weight', 'self_attn.in_proj_bias', 'self_attn.out_proj.weight', 'self_attn.out_proj.bias', 'ln2.weight', 'ln2.bias', 'ff.0.weight', 'ff.0.bias', 'ff.2.weight', 'ff.2.bias']
Tiny layer params (numel): [64, 64, 12288, 192, 4096, 64, 64, 64, 8192, 128, 8192, 64]
Total parameters: 33472
Model param shapes:
  ln1.weight: torch.Size([64])
  ln1.bias: torch.Size([64])
  self_attn.in_proj_weight: torch.Size([192, 64])
  self_attn.in_proj_bias: torch.Size([192])
  self_attn.out_proj.weight: torch.Size([64, 64])
  self_attn.out_proj.bias: torch.Size([64])
  ln2.weight: torch.Size([64])
  ln2.bias: torch.Size([64])
  ff.0.weight: torch.Size([128, 64])
  ff.0.bias: torch.Size([128])
  ff.2.weight: torch.Size([64, 128])


RuntimeError: The size of tensor a (64) must match the size of tensor b (192) at non-singleton dimension 2