# XOR + Zstd Compression Demo on a Tiny Transformer Layer

**Outline:** Show how to (1) create a tiny Transformer layer, (2) simulate finetuning by perturbing a small subset of parameters, (3) compress the finetuned weights by XOR-ing their raw bytes against the base weights and then Zstd-compressing the XOR result, and (4) decompress to **exactly** recover the finetuned layer and run a forward pass.

**Why this demo:** It mirrors our current workflow and marks the exact place we want to optimize next—*computing on XOR-compressed data directly (without full decompression)*.



In [None]:
# If needed, install deps:
# !pip install torch zstandard

import math, io, pickle, time
from time import perf_counter
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    import zstandard as zstd
except ImportError as e:
    # Chain the original ImportError so the underlying cause stays visible.
    raise RuntimeError("Please install `zstandard` (pip install zstandard) before running this cell.") from e

# Fixed seeds keep weights, perturbations and inputs reproducible across runs.
torch.manual_seed(42)
np.random.seed(42)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)


# -----------------------------
# Tiny Transformer layer
# -----------------------------
class TinyTransformerLayer(nn.Module):
    """Pre-LayerNorm Transformer block: self-attention, then a GELU MLP.

    Both sub-blocks are wrapped in residual connections. The block is kept tiny
    so the compression demo runs quickly.

    Args:
        d_model: width of the input/output features.
        nhead: number of attention heads (must divide ``d_model``).
        dim_feedforward: hidden width of the MLP.
        dropout: dropout probability used inside attention.
    """

    def __init__(self, d_model=64, nhead=4, dim_feedforward=128, dropout=0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, batch_first=True, dropout=dropout
        )
        self.ln2 = nn.LayerNorm(d_model)
        # Built in the same order as before so RNG-driven init is unchanged.
        expand = nn.Linear(d_model, dim_feedforward)
        project = nn.Linear(dim_feedforward, d_model)
        self.ff = nn.Sequential(expand, nn.GELU(), project)

    def forward(self, x, attn_mask=None):
        """Apply the block to ``x`` laid out as (batch, seq, d_model) (batch_first)."""
        normed = self.ln1(x)
        attn_out = self.self_attn(normed, normed, normed, attn_mask=attn_mask)[0]
        resid = x + attn_out
        return resid + self.ff(self.ln2(resid))

# Reference ("base") layer that every finetuned variant is XOR-diffed against.
base_layer = TinyTransformerLayer()
base_layer.eval()
print("Tiny layer params:", sum(param.numel() for param in base_layer.parameters()))


# -----------------------------
# Simulate finetuning (sparse updates)
# -----------------------------
from copy import deepcopy

def perturb_subset_(module: nn.Module, frac: float = 0.03, magnitude: float = 1e-2, seed: int = 0):
    """In-place: add Gaussian noise to a random ``frac`` of elements in each float parameter.

    This simulates sparse finetuning, so XOR(base, finetuned) is mostly zero bytes.

    Args:
        module: module whose parameters are modified in place.
        frac: fraction of elements perturbed per parameter (at least one element
            in every non-empty float parameter).
        magnitude: scale (standard deviation) of the added noise.
        seed: seed of the CPU generator that picks indices and noise. Because
            all randomness is drawn on CPU, results are reproducible and
            identical regardless of which device the parameters live on.
    """
    g = torch.Generator(device='cpu').manual_seed(seed)
    with torch.no_grad():
        for _, p in module.named_parameters():
            if not p.dtype.is_floating_point:
                continue
            numel = p.numel()
            if numel == 0:
                continue
            k = max(1, int(numel * frac))
            # A CPU generator cannot drive CUDA RNG ops, so draw on CPU and
            # move the results to the parameter's device afterwards.
            idx = torch.randperm(numel, generator=g, device='cpu')[:k]
            noise = magnitude * torch.randn(k, generator=g, device='cpu', dtype=p.dtype)
            flat = p.view(-1)
            flat[idx.to(p.device)] += noise.to(p.device)

finetuned_layer = deepcopy(base_layer)
perturb_subset_(finetuned_layer, frac=0.03, magnitude=1e-2, seed=123)

# Tally float elements that differ from the base. Noise of this scale virtually
# always changes the stored value, so this should match the perturbed count.
changed = total = 0
for (n1, p_base), (n2, p_ft) in zip(
    base_layer.named_parameters(), finetuned_layer.named_parameters()
):
    assert n1 == n2
    if not p_base.dtype.is_floating_point:
        continue
    diff = (p_base != p_ft).sum().item()
    changed += diff
    total += p_base.numel()
print(f"Changed elements: {changed} / {total} (~{changed/total*100:.2f}%)")


# -----------------------------
# XOR + Zstd compression helpers
# -----------------------------
def tensor_to_uint8_view(t: torch.Tensor) -> np.ndarray:
    """Return the raw bytes of ``t`` as a NumPy uint8 array (on CPU).

    For a contiguous CPU tensor of rank >= 1 this is a zero-copy *view*, and its
    shape is ``t.shape`` with the last dimension multiplied by the itemsize.
    Non-contiguous tensors are first copied into C order. 0-d tensors are
    treated as shape ``(1,)``, because NumPy cannot reinterpret a 0-d array with
    a different itemsize.
    """
    a = t.detach().cpu().numpy()
    if a.ndim == 0:
        a = a.reshape(1)
    # .view(np.uint8) requires a contiguous last axis; this is a no-op for
    # already C-contiguous arrays.
    return np.ascontiguousarray(a).view(np.uint8)

def xor_uint8(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return the element-wise XOR of two equally sized uint8 arrays.

    Raises:
        TypeError: if either array is not uint8.
        ValueError: if the arrays differ in size.
    """
    # Explicit checks (not ``assert``) so validation survives ``python -O``.
    if a.dtype != np.uint8 or b.dtype != np.uint8:
        raise TypeError(f"xor_uint8 expects uint8 arrays, got {a.dtype} and {b.dtype}")
    if a.size != b.size:
        raise ValueError(f"xor_uint8 size mismatch: {a.size} vs {b.size}")
    return np.bitwise_xor(a, b)

def compress_state_dict_with_xor_zstd(base_sd: dict, finetuned_sd: dict, level: int = 10):
    """Encode each finetuned tensor as Zstd(XOR(base_bytes, finetuned_bytes)).

    Args:
        base_sd: base state_dict (name -> tensor).
        finetuned_sd: finetuned state_dict with the same keys, shapes and dtypes.
        level: Zstd compression level.

    Returns a dict with:
      - meta[name]: shape/dtype_str/numel/param_bytes/uint8_shape
      - deltas[name]: Zstd-compressed XOR bytes
      - zstd_level: the level used
      - stats: aggregate and per-parameter sizes

    Raises:
        ValueError: if the state_dicts differ in keys, shapes or dtypes. XOR-ing
            mismatched byte layouts would otherwise silently produce garbage.
    """
    if base_sd.keys() != finetuned_sd.keys():
        missing = sorted(set(base_sd) - set(finetuned_sd))
        extra = sorted(set(finetuned_sd) - set(base_sd))
        raise ValueError(f"state_dict keys differ: missing={missing}, extra={extra}")

    compressor = zstd.ZstdCompressor(level=level)
    meta, deltas = {}, {}
    stats = {'total_params_bytes': 0, 'total_delta_bytes': 0, 'total_compressed_bytes': 0, 'per_param': {}}

    for name, base_t in base_sd.items():
        ft_t = finetuned_sd[name]
        if base_t.shape != ft_t.shape or base_t.dtype != ft_t.dtype:
            raise ValueError(
                f"{name}: base {tuple(base_t.shape)}/{base_t.dtype} vs "
                f"finetuned {tuple(ft_t.shape)}/{ft_t.dtype}"
            )
        base_bytes = tensor_to_uint8_view(base_t)
        ft_bytes = tensor_to_uint8_view(ft_t)

        delta_u8 = xor_uint8(base_bytes, ft_bytes)
        raw_delta = delta_u8.tobytes()
        comp = compressor.compress(raw_delta)
        param_bytes = ft_bytes.size

        meta[name] = {
            'shape': list(ft_t.shape),
            'dtype_str': ft_t.detach().cpu().numpy().dtype.str,  # e.g. '<f4'
            'numel': ft_t.numel(),
            'param_bytes': param_bytes,
            'uint8_shape': list(base_bytes.shape),
        }
        deltas[name] = comp

        stats['total_params_bytes'] += param_bytes
        stats['total_delta_bytes'] += len(raw_delta)
        stats['total_compressed_bytes'] += len(comp)
        stats['per_param'][name] = {
            'param_bytes': param_bytes,
            'delta_bytes': len(raw_delta),
            'compressed_bytes': len(comp),
            'compressed_ratio_vs_param': len(comp) / max(1, param_bytes),
        }

    return {'meta': meta, 'deltas': deltas, 'zstd_level': level, 'stats': stats}

def reconstruct_finetuned_state_dict_with_timing(base_sd: dict, pkg: dict):
    """Rebuild the finetuned state_dict from ``base_sd`` and an XOR+Zstd package.

    For each entry, the function does three things:
    1. Zstd-decompresses the XOR bytes.
    2. XORs them onto the base tensor's bytes.
    3. Reinterprets the result with the stored dtype and shape.

    Returns:
        (rec_sd, timing_dict). ``timing_dict`` has keys ``t_decompress`` and
        ``t_xor_apply``, in seconds, summed over all entries.

    Raises:
        ValueError: if a decompressed delta does not have the expected byte length.
    """
    dctx = zstd.ZstdDecompressor()
    rec = {}
    t_decomp = 0.0
    t_xor = 0.0
    for name, meta in pkg['meta'].items():
        comp = pkg['deltas'][name]
        # Decompress
        t0 = perf_counter()
        raw_delta = dctx.decompress(comp)
        t_decomp += perf_counter() - t0
        if len(raw_delta) != meta['param_bytes']:
            raise ValueError(
                f"{name}: decompressed {len(raw_delta)} bytes, expected {meta['param_bytes']}"
            )
        delta_u8 = np.frombuffer(raw_delta, dtype=np.uint8).reshape(meta['uint8_shape'])

        # XOR apply
        base_bytes = tensor_to_uint8_view(base_sd[name])
        t1 = perf_counter()
        ft_bytes = np.bitwise_xor(base_bytes, delta_u8)  # shape-aligned XOR
        t_xor += perf_counter() - t1

        # bitwise_xor returns a fresh, writable array that can back the tensor
        # directly; an extra .copy() would only double the memory traffic.
        arr = ft_bytes.view(np.dtype(meta['dtype_str'])).reshape(meta['shape'])
        rec[name] = torch.from_numpy(arr)
    return rec, {'t_decompress': t_decomp, 't_xor_apply': t_xor}


# -----------------------------
# Build base/finetuned SDs and compress
# -----------------------------
# CPU, contiguous snapshots: the XOR helpers operate on raw tensor bytes.
base_sd = {k: v.detach().cpu().contiguous() for k, v in base_layer.state_dict().items()}
ft_sd = {k: v.detach().cpu().contiguous() for k, v in finetuned_layer.state_dict().items()}

pkg = compress_state_dict_with_xor_zstd(base_sd, ft_sd, level=10)

print("=== Compression Stats ===")
print("Total param bytes  :", pkg['stats']['total_params_bytes'])
print("Total delta bytes  :", pkg['stats']['total_delta_bytes'])
print("Total compressed   :", pkg['stats']['total_compressed_bytes'])
print(
    "Global ratio       :",
    pkg['stats']['total_compressed_bytes'] / pkg['stats']['total_params_bytes'],
)
# Show only the first five parameters to keep the output short.
for name, st in list(pkg['stats']['per_param'].items())[:5]:
    print(
        f"- {name:40s} | param={st['param_bytes']} | comp={st['compressed_bytes']}"
        f" | ratio={st['compressed_ratio_vs_param']:.4f}"
    )

# Keep the package as in-memory bytes so "load" can be timed as pickle
# deserialization; swap in real disk I/O for on-disk load timings.
pkg_bytes = pickle.dumps(pkg)


# -----------------------------
# End-to-end function with timers (load → reconstruct → load_state_dict → forward)
# -----------------------------
def e2e_forward_from_pkg_bytes(base_layer: nn.Module, pkg_bytes: bytes, x: torch.Tensor, device: str = 'cpu'):
    """End-to-end pipeline: deserialize → decompress → XOR-apply → copy weights → forward.

    Args:
        base_layer: base model. Its state_dict is the XOR reference.
        pkg_bytes: pickled package from ``compress_state_dict_with_xor_zstd``.
        x: input batch for the forward pass.
        device: device for the reconstructed layer and the forward pass.

    Returns:
        (timings_dict, checksum). ``checksum`` is ``sum(|y|)`` of the output.

    NOTE: 't_load_pkg' measures pickle.loads in memory. Replace it with disk I/O
    for on-disk timings. pickle.loads can execute arbitrary code, so only load
    packages you produced yourself.

    NOTE: the reconstructed layer is a default-configured TinyTransformerLayer,
    so ``base_layer`` must share that architecture.
    """
    on_cuda = torch.device(device).type == 'cuda'

    def _sync():
        # CUDA work is launched asynchronously. Wait for it so perf_counter
        # measures the actual work, not just the kernel launch.
        if on_cuda:
            torch.cuda.synchronize()

    # 0) Prepare inputs
    x = x.to(device)
    base_sd = {k: v.detach().cpu().contiguous() for k, v in base_layer.state_dict().items()}

    # 1) Load package (deserialize)
    t0 = perf_counter()
    pkg = pickle.loads(pkg_bytes)
    t_load = perf_counter() - t0

    # 2) Decompress + XOR apply (split timings)
    rec_sd, t_parts = reconstruct_finetuned_state_dict_with_timing(base_sd, pkg)

    # 3) Materialize a new layer and copy weights. The state_dict() tensors
    #    alias the parameters, so copy_ writes into the layer.
    _sync()  # don't bill the pending input transfer to this stage
    t2 = perf_counter()
    rec_layer = TinyTransformerLayer().to(device).eval()
    with torch.no_grad():
        for k, p in rec_layer.state_dict().items():
            p.copy_(rec_sd[k].to(device))
    _sync()
    t_state = perf_counter() - t2

    # 4) Forward once (demo)
    t3 = perf_counter()
    with torch.no_grad():
        y = rec_layer(x)
    _sync()
    t_fwd = perf_counter() - t3

    timings = {
        't_load_pkg': t_load,
        't_decompress': t_parts['t_decompress'],
        't_xor_apply': t_parts['t_xor_apply'],
        't_state_load': t_state,
        't_forward': t_fwd,
        't_end_to_end': t_load + t_parts['t_decompress'] + t_parts['t_xor_apply'] + t_state + t_fwd,
    }
    checksum = float(y.abs().sum().detach().cpu())  # simple checksum to ensure consistent output
    return timings, checksum


# -----------------------------
# Micro-benchmark: mean ± std per stage
# -----------------------------
def benchmark_e2e_breakdown(base_layer, pkg_bytes: bytes, x: torch.Tensor, device: str = 'cpu', repeats: int = 7):
    """Run the end-to-end pipeline ``repeats`` times and summarize each stage.

    Returns:
        (stats, checks). ``stats`` maps each timing key to ``{'mean': ..., 'std': ...}``,
        using NumPy's population std. ``checks`` lists the output checksum of every run.
    """
    stage_keys = ('t_load_pkg', 't_decompress', 't_xor_apply',
                  't_state_load', 't_forward', 't_end_to_end')
    runs = [
        e2e_forward_from_pkg_bytes(base_layer, pkg_bytes, x, device=device)
        for _ in range(repeats)
    ]
    checks = [checksum for _, checksum in runs]
    stats = {}
    for key in stage_keys:
        samples = [timings[key] for timings, _ in runs]
        stats[key] = {'mean': float(np.mean(samples)), 'std': float(np.std(samples))}
    return stats, checks


# -----------------------------
# What-if speed windows
# -----------------------------
def print_what_if_bounds(stats):
    """Print how fast the pipeline would be if whole stages disappeared.

    Uses the per-stage means in ``stats``, as produced by ``benchmark_e2e_breakdown``:
       - Bound A: drop the global XOR-apply but still pay for decompression.
         This is the target for computing directly on decompressed XOR bytes.
       - Bound B: drop both XOR-apply and decompression. This is the
         theoretical Zstd-domain compute.
    """
    stage_order = ('t_load_pkg', 't_decompress', 't_xor_apply',
                   't_state_load', 't_forward', 't_end_to_end')
    # The XOR time is exactly what both bounds remove, so it is read but unused.
    load_s, decomp_s, _xor_s, state_s, fwd_s, e2e_s = [stats[key]['mean'] for key in stage_order]

    bound_a = load_s + decomp_s + state_s + fwd_s   # skip XOR
    bound_b = load_s + state_s + fwd_s              # skip XOR + decompress (theoretical)

    def speedup(old, new):
        if not new > 0:
            return "∞"
        return f"{old/new:.2f}× faster (↓{(1 - new/old)*100:.1f}%)"

    print("=== What-if Speed Windows (vs current E2E mean) ===")
    print(f"Baseline E2E: {e2e_s:.6f}s")
    print(f"Skip XOR (keep decompress): {bound_a:.6f}s  → {speedup(e2e_s, bound_a)}")
    print(f"Skip XOR+Decompress (Zstd-domain compute): {bound_b:.6f}s  → {speedup(e2e_s, bound_b)}")
    print("\nInterpretation:")
    for line in (
        "- (Baseline → Bound A) is the best-case improvement if we avoid the global XOR-apply step by",
        "  computing directly on decompressed XOR bytes (tile JIT patching, patch-based kernels, etc.).",
        "- (Baseline → Bound B) is a theoretical ceiling if we could avoid both decompression and XOR",
        "  (i.e., compute in Zstd domain); not realistic today, but shows maximum headroom.",
    ):
        print(line)


# -----------------------------
# Per-parameter hotspots (identify where XOR time concentrates)
# -----------------------------
def per_param_decompress_xor_timing(base_sd: dict, pkg: dict, topk: int = 5):
    """Time decompression and XOR-apply separately for each package entry.

    Prints the ``topk`` entries with the largest XOR-apply time. Returns every
    record as ``(name, param_bytes, t_decompress, t_xor)``, sorted by XOR time
    in descending order.
    """
    dctx = zstd.ZstdDecompressor()
    records = []
    for name, meta in pkg['meta'].items():
        blob = pkg['deltas'][name]

        start = perf_counter()
        raw = dctx.decompress(blob)
        decompress_s = perf_counter() - start

        delta = np.frombuffer(raw, dtype=np.uint8).reshape(meta['uint8_shape'])

        # The timed region includes fetching the base byte view, as before.
        start = perf_counter()
        _ = np.bitwise_xor(tensor_to_uint8_view(base_sd[name]), delta)
        xor_s = perf_counter() - start

        records.append((name, meta['param_bytes'], decompress_s, xor_s))

    records = sorted(records, key=lambda rec: rec[3], reverse=True)
    print(f"=== Top {topk} params by XOR-apply time ===")
    for name, sz, tde, tx in records[:topk]:
        print(f"{name:40s} | bytes={sz:8d} | decompress={tde:.6f}s | xor={tx:.6f}s")
    return records


# -----------------------------
# XOR sparsity & naive sparse storage estimate
# -----------------------------
def estimate_xor_sparsity_and_storage(pkg: dict):
    """Measure how many XOR bytes are non-zero and price a naive sparse encoding.

    Cost model: each non-zero byte is stored as a (uint32 index, uint8 value)
    pair, i.e. 5 bytes. Dense storage is 1 byte per XOR byte.
    """
    dctx = zstd.ZstdDecompressor()
    total_bytes = total_nz = 0
    for name in pkg['meta']:
        xor_bytes = np.frombuffer(dctx.decompress(pkg['deltas'][name]), dtype=np.uint8)
        total_bytes += xor_bytes.size
        total_nz += int(np.count_nonzero(xor_bytes))

    index_bytes, value_bytes = 4, 1
    dense_bytes = total_bytes
    sparse_bytes = total_nz * (index_bytes + value_bytes)
    return {
        'xor_total_bytes': total_bytes,
        'xor_nonzero_bytes': total_nz,
        'nonzero_density': total_nz / max(1, total_bytes),
        'dense_storage_bytes': dense_bytes,
        'simple_sparse_storage_bytes': sparse_bytes,
        'sparse_over_dense_ratio': sparse_bytes / max(1, dense_bytes),
    }


# -----------------------------
# Run benchmarks
# -----------------------------
# End-to-end breakdown
# Input laid out as (batch, seq, d_model); d_model=64 matches the layer's default.
x_bench = torch.randn(4, 16, 64)
stats, checks = benchmark_e2e_breakdown(
    base_layer, pkg_bytes, x_bench, device=device, repeats=7
)
print("=== Component-wise Timing (seconds): mean ± std over 7 runs ===")
for k, s in stats.items():
    print(f"{k:>14s}: {s['mean']:.6f} ± {s['std']:.6f}")
print("Sanity checksums (first 3):", checks[:3])

# What-if bounds: how much headroom removing stages would give.
print_what_if_bounds(stats)

# Hotspots: parameters that dominate XOR-apply time.
_ = per_param_decompress_xor_timing(base_sd, pkg, topk=5)

# XOR sparsity summary.
sparsity = estimate_xor_sparsity_and_storage(pkg)
print("\n=== XOR Sparsity Estimate ===")
for k, v in sparsity.items():
    print(f"{k:28s}: {v}")
print(
    "\nNOTE: If non-zero density is low, consider block-sparse indexing (e.g., 256B/4KB blocks), "
    "RLE on zero-runs, or bitplane packing so **many XOR adapters can co-reside on a single GPU.**"
)





In [None]:
# -----------------------------
# TODO: compute on decompressed XOR without global XOR
# -----------------------------
def e2e_forward_on_delta_no_full_reconstruct_TODO(base_layer: nn.Module, pkg_bytes: bytes, x: torch.Tensor,
                                                  *, device: str = 'cpu', tile_rows: int = 64, tile_cols: int = 64):
    """TODO (design & prototype): an end-to-end pipeline that computes without full reconstruction.

    Goal for a first prototype:
      1) Deserialize pkg.
      2) Convert a single Linear layer's weight (or a toy GEMM) into a tiled compute:
         - For each tile, decode only the XOR bytes needed for that tile.
         - XOR just-in-time into a small working buffer and run GEMM on that tile.
      3) Measure timings: t_partial_decompress, t_tile_xor_patch, t_forward_tile, total.

    For now the function deserializes the package and prints a compact summary:
    the entry count, and the tile count for each 2-D weight of ``base_layer``.
    It then raises.

    Raises:
        ValueError: if ``tile_rows`` or ``tile_cols`` is not positive.
        NotImplementedError: always, until the tiled kernel path exists.
    """
    if tile_rows <= 0 or tile_cols <= 0:
        raise ValueError("tile_rows and tile_cols must be positive")

    # NOTE: pickle.loads can execute arbitrary code; only load trusted packages.
    t0 = perf_counter()
    pkg = pickle.loads(pkg_bytes)
    t_load = perf_counter() - t0

    # Print a summary instead of the whole package, which would dump every
    # compressed byte string. nn.Module has no .size(), so report weight
    # shapes from the state_dict instead.
    print(f"Deserialized pkg in {t_load:.6f}s; {len(pkg.get('meta', {}))} entries")
    for name, tensor in base_layer.state_dict().items():
        if tensor.dim() == 2:
            n_tiles = math.ceil(tensor.shape[0] / tile_rows) * math.ceil(tensor.shape[1] / tile_cols)
            print(f"  {name:40s} shape={tuple(tensor.shape)} tiles({tile_rows}x{tile_cols})={n_tiles}")

    # Pseudocode structure (replace with a working tiled GEMM path):
    raise NotImplementedError(
        "Prototype here: map layer weights to tiles, for each tile decompress-needed bytes "
        "→ XOR into a small buffer → microkernel(GEMM) on that buffer. "
        "Add timers: t_partial_decompress, t_tile_xor_patch, t_forward_tile."
    )


# Exercise the TODO prototype. It deliberately raises NotImplementedError until
# the tiled path exists; catch that so "Run All" continues past this cell.
try:
    temp = e2e_forward_on_delta_no_full_reconstruct_TODO(base_layer, pkg_bytes, x_bench, device=device)
    print(temp)
except NotImplementedError as err:
    print("TODO prototype not implemented yet:", err)






In [1]:
# If needed, install deps:
# !pip install torch zstandard

import math, io, pickle, time
from time import perf_counter
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    import zstandard as zstd
except ImportError as e:
    # Chain the original ImportError so the underlying cause stays visible.
    raise RuntimeError("Please install `zstandard` (pip install zstandard) before running this cell.") from e

# Fixed seeds keep weights, perturbations and inputs reproducible across runs.
torch.manual_seed(42)
np.random.seed(42)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)


# -----------------------------
# Tiny Transformer layer
# -----------------------------
class TinyTransformerLayer(nn.Module):
    """Pre-LayerNorm Transformer block: self-attention, then a GELU MLP.

    Both sub-blocks are wrapped in residual connections. The block is kept tiny
    so the compression demo runs quickly.

    Args:
        d_model: width of the input/output features.
        nhead: number of attention heads (must divide ``d_model``).
        dim_feedforward: hidden width of the MLP.
        dropout: dropout probability used inside attention.
    """

    def __init__(self, d_model=64, nhead=4, dim_feedforward=128, dropout=0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, batch_first=True, dropout=dropout
        )
        self.ln2 = nn.LayerNorm(d_model)
        # Built in the same order as before so RNG-driven init is unchanged.
        expand = nn.Linear(d_model, dim_feedforward)
        project = nn.Linear(dim_feedforward, d_model)
        self.ff = nn.Sequential(expand, nn.GELU(), project)

    def forward(self, x, attn_mask=None):
        """Apply the block to ``x`` laid out as (batch, seq, d_model) (batch_first)."""
        normed = self.ln1(x)
        attn_out = self.self_attn(normed, normed, normed, attn_mask=attn_mask)[0]
        resid = x + attn_out
        return resid + self.ff(self.ln2(resid))

# Shared base layer; each tenant's weights are stored as an XOR delta against it.
base_layer = TinyTransformerLayer()
base_layer.eval()
print("Tiny layer params:", sum(param.numel() for param in base_layer.parameters()))


# -----------------------------
# Simulate finetuning (sparse updates)
# -----------------------------
from copy import deepcopy

def perturb_subset_(module: nn.Module, frac: float = 0.03, magnitude: float = 1e-2, seed: int = 0):
    """In-place: add Gaussian noise to a random ``frac`` of elements in each float parameter.

    This simulates sparse finetuning, so XOR(base, finetuned) is mostly zero bytes.

    Args:
        module: module whose parameters are modified in place.
        frac: fraction of elements perturbed per parameter (at least one element
            in every non-empty float parameter).
        magnitude: scale (standard deviation) of the added noise.
        seed: seed of the CPU generator that picks indices and noise. Because
            all randomness is drawn on CPU, results are reproducible and
            identical regardless of which device the parameters live on.
    """
    g = torch.Generator(device='cpu').manual_seed(seed)
    with torch.no_grad():
        for _, p in module.named_parameters():
            if not p.dtype.is_floating_point:
                continue
            numel = p.numel()
            if numel == 0:
                continue
            k = max(1, int(numel * frac))
            # A CPU generator cannot drive CUDA RNG ops, so draw on CPU and
            # move the results to the parameter's device afterwards.
            idx = torch.randperm(numel, generator=g, device='cpu')[:k]
            noise = magnitude * torch.randn(k, generator=g, device='cpu', dtype=p.dtype)
            flat = p.view(-1)
            flat[idx.to(p.device)] += noise.to(p.device)

# Two finetuned variants to simulate two tenants
# Two finetuned variants simulate two tenants with different sparsity and seeds.
ft1 = deepcopy(base_layer)
perturb_subset_(ft1, frac=0.03, magnitude=1e-2, seed=123)
ft2 = deepcopy(base_layer)
perturb_subset_(ft2, frac=0.05, magnitude=1e-2, seed=456)

# Count float elements of ft1 that differ from the base.
changed = total = 0
for (n1, p_base), (n2, p_ft) in zip(base_layer.named_parameters(), ft1.named_parameters()):
    assert n1 == n2
    if not p_base.dtype.is_floating_point:
        continue
    diff = (p_base != p_ft).sum().item()
    changed += diff
    total += p_base.numel()
print(f"[ft1] Changed elements: {changed} / {total} (~{changed/total*100:.2f}%)")


# -----------------------------
# XOR + Zstd compression helpers (pack entire state_dict)
# -----------------------------
def tensor_to_uint8_view(t: torch.Tensor) -> np.ndarray:
    """Return the raw bytes of ``t`` as a NumPy uint8 array (on CPU).

    For a contiguous CPU tensor of rank >= 1 this is a zero-copy *view*, and its
    shape is ``t.shape`` with the last dimension multiplied by the itemsize.
    Non-contiguous tensors are first copied into C order. 0-d tensors are
    treated as shape ``(1,)``, because NumPy cannot reinterpret a 0-d array with
    a different itemsize.
    """
    a = t.detach().cpu().numpy()
    if a.ndim == 0:
        a = a.reshape(1)
    # .view(np.uint8) requires a contiguous last axis; this is a no-op for
    # already C-contiguous arrays.
    return np.ascontiguousarray(a).view(np.uint8)

def xor_uint8(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return the element-wise XOR of two equally sized uint8 arrays.

    Raises:
        TypeError: if either array is not uint8.
        ValueError: if the arrays differ in size.
    """
    # Explicit checks (not ``assert``) so validation survives ``python -O``.
    if a.dtype != np.uint8 or b.dtype != np.uint8:
        raise TypeError(f"xor_uint8 expects uint8 arrays, got {a.dtype} and {b.dtype}")
    if a.size != b.size:
        raise ValueError(f"xor_uint8 size mismatch: {a.size} vs {b.size}")
    return np.bitwise_xor(a, b)

def compress_state_dict_with_xor_zstd(base_sd: dict, finetuned_sd: dict, level: int = 10):
    """Encode each finetuned tensor as Zstd(XOR(base_bytes, finetuned_bytes)).

    Args:
        base_sd: base state_dict (name -> tensor).
        finetuned_sd: finetuned state_dict with the same keys, shapes and dtypes.
        level: Zstd compression level.

    Returns a dict with:
      - meta[name]: shape/dtype_str/numel/param_bytes/uint8_shape
      - deltas[name]: Zstd-compressed XOR bytes
      - zstd_level: the level used
      - stats: aggregate and per-parameter sizes

    Raises:
        ValueError: if the state_dicts differ in keys, shapes or dtypes. XOR-ing
            mismatched byte layouts would otherwise silently produce garbage.
    """
    if base_sd.keys() != finetuned_sd.keys():
        missing = sorted(set(base_sd) - set(finetuned_sd))
        extra = sorted(set(finetuned_sd) - set(base_sd))
        raise ValueError(f"state_dict keys differ: missing={missing}, extra={extra}")

    compressor = zstd.ZstdCompressor(level=level)
    meta, deltas = {}, {}
    stats = {'total_params_bytes': 0, 'total_delta_bytes': 0, 'total_compressed_bytes': 0, 'per_param': {}}

    for name, base_t in base_sd.items():
        ft_t = finetuned_sd[name]
        if base_t.shape != ft_t.shape or base_t.dtype != ft_t.dtype:
            raise ValueError(
                f"{name}: base {tuple(base_t.shape)}/{base_t.dtype} vs "
                f"finetuned {tuple(ft_t.shape)}/{ft_t.dtype}"
            )
        base_bytes = tensor_to_uint8_view(base_t)
        ft_bytes = tensor_to_uint8_view(ft_t)

        delta_u8 = xor_uint8(base_bytes, ft_bytes)
        raw_delta = delta_u8.tobytes()
        comp = compressor.compress(raw_delta)
        param_bytes = ft_bytes.size

        meta[name] = {
            'shape': list(ft_t.shape),
            'dtype_str': ft_t.detach().cpu().numpy().dtype.str,  # e.g. '<f4'
            'numel': ft_t.numel(),
            'param_bytes': param_bytes,
            'uint8_shape': list(base_bytes.shape),
        }
        deltas[name] = comp

        stats['total_params_bytes'] += param_bytes
        stats['total_delta_bytes'] += len(raw_delta)
        stats['total_compressed_bytes'] += len(comp)
        stats['per_param'][name] = {
            'param_bytes': param_bytes,
            'delta_bytes': len(raw_delta),
            'compressed_bytes': len(comp),
            'compressed_ratio_vs_param': len(comp) / max(1, param_bytes),
        }

    return {'meta': meta, 'deltas': deltas, 'zstd_level': level, 'stats': stats}

def reconstruct_finetuned_state_dict_with_timing(base_sd: dict, pkg: dict):
    """Rebuild the finetuned state_dict from ``base_sd`` and an XOR+Zstd package.

    For each entry, the function does three things:
    1. Zstd-decompresses the XOR bytes.
    2. XORs them onto the base tensor's bytes.
    3. Reinterprets the result with the stored dtype and shape.

    Returns:
        (rec_sd, timing_dict). ``timing_dict`` has keys ``t_decompress`` and
        ``t_xor_apply``, in seconds, summed over all entries.

    Raises:
        ValueError: if a decompressed delta does not have the expected byte length.
    """
    dctx = zstd.ZstdDecompressor()
    rec = {}
    t_decomp = 0.0
    t_xor = 0.0
    for name, meta in pkg['meta'].items():
        comp = pkg['deltas'][name]
        # Decompress
        t0 = perf_counter()
        raw_delta = dctx.decompress(comp)
        t_decomp += perf_counter() - t0
        if len(raw_delta) != meta['param_bytes']:
            raise ValueError(
                f"{name}: decompressed {len(raw_delta)} bytes, expected {meta['param_bytes']}"
            )
        delta_u8 = np.frombuffer(raw_delta, dtype=np.uint8).reshape(meta['uint8_shape'])

        # XOR apply
        base_bytes = tensor_to_uint8_view(base_sd[name])
        t1 = perf_counter()
        ft_bytes = np.bitwise_xor(base_bytes, delta_u8)  # shape-aligned XOR
        t_xor += perf_counter() - t1

        # bitwise_xor returns a fresh, writable array that can back the tensor
        # directly; an extra .copy() would only double the memory traffic.
        arr = ft_bytes.view(np.dtype(meta['dtype_str'])).reshape(meta['shape'])
        rec[name] = torch.from_numpy(arr)
    return rec, {'t_decompress': t_decomp, 't_xor_apply': t_xor}


# Build packages for two tenants
# One XOR+Zstd package per tenant, all diffed against the same CPU base snapshot.
base_sd = {k: v.detach().cpu().contiguous() for k, v in base_layer.state_dict().items()}
pkg_list = [
    compress_state_dict_with_xor_zstd(
        base_sd,
        {k: v.detach().cpu().contiguous() for k, v in tenant.state_dict().items()},
        level=10,
    )
    for tenant in (ft1, ft2)
]
pkg1, pkg2 = pkg_list
print("Tenants:", len(pkg_list))
print("pkg1 global ratio:", pkg1['stats']['total_compressed_bytes'] / pkg1['stats']['total_params_bytes'])
print("pkg2 global ratio:", pkg2['stats']['total_compressed_bytes'] / pkg2['stats']['total_params_bytes'])


# -----------------------------
# Standard E2E baseline (deserialize → decompress → XOR → load_state → forward)
# -----------------------------
def e2e_forward_from_pkg(base_layer: nn.Module, pkg: dict, x: torch.Tensor, device: str = 'cpu'):
    """End-to-end baseline: decompress → XOR-apply → copy weights → forward.

    The package is already in memory, so unlike the byte-based variant there
    is no load stage.

    Returns:
        (timings_dict, checksum). ``timings_dict`` has keys t_decompress,
        t_xor_apply, t_state_load, t_forward and t_end_to_end. ``checksum``
        is ``sum(|y|)`` of the output.

    NOTE: the reconstructed layer is a default-configured TinyTransformerLayer,
    so ``base_layer`` must share that architecture.
    """
    on_cuda = torch.device(device).type == 'cuda'

    def _sync():
        # CUDA work is launched asynchronously. Wait for it so perf_counter
        # measures the actual work, not just the kernel launch.
        if on_cuda:
            torch.cuda.synchronize()

    # Inputs
    x = x.to(device)
    base_sd_local = {k: v.detach().cpu().contiguous() for k, v in base_layer.state_dict().items()}

    # Decompress + XOR apply
    rec_sd, t_parts = reconstruct_finetuned_state_dict_with_timing(base_sd_local, pkg)

    # Materialize a new layer and copy weights. The state_dict() tensors alias
    # the parameters, so copy_ writes into the layer.
    _sync()  # don't bill the pending input transfer to this stage
    t2 = perf_counter()
    rec_layer = TinyTransformerLayer().to(device).eval()
    with torch.no_grad():
        for k, p in rec_layer.state_dict().items():
            p.copy_(rec_sd[k].to(device))
    _sync()
    t_state = perf_counter() - t2

    # Forward once
    t3 = perf_counter()
    with torch.no_grad():
        y = rec_layer(x)
    _sync()
    t_fwd = perf_counter() - t3

    timings = {
        't_decompress': t_parts['t_decompress'],
        't_xor_apply': t_parts['t_xor_apply'],
        't_state_load': t_state,
        't_forward': t_fwd,
        't_end_to_end': t_parts['t_decompress'] + t_parts['t_xor_apply'] + t_state + t_fwd,
    }
    checksum = float(y.abs().sum().detach().cpu())
    return timings, checksum


# -----------------------------
# ====== NEW: Multi-tenant serving with base-compute reuse (LoRA-style idea) ======
# We only patch ff.0.weight for demonstration: reuse base forward up to h2,
# compute y_base_ff1 once, add tenant-specific y_delta = ΔW @ x, then finish the layer.
# -----------------------------

def sparse_delta_from_pkg_for_linear(base_sd: dict, pkg: dict, param_name: str):
    """Extract a sparse ΔW = W_ft - W_base for a 2-D Linear weight from an XOR(Zstd) pkg.

    Only the elements whose XOR bytes are non-zero are reconstructed.

    Args:
        base_sd: base state_dict (CPU tensors).
        pkg: package from ``compress_state_dict_with_xor_zstd``.
        param_name: key of a 2-D weight, e.g. ``'ff.0.weight'``.

    Returns:
      rows: np.ndarray[int64], cols: np.ndarray[int64], delta_vals: np.ndarray[float32],
      (out_features, in_features), and a timing dict with keys
      {'t_param_decompress', 't_sparse_extract'}.

    Raises:
        KeyError: if ``param_name`` is not in the package.
        ValueError: if the parameter is not 2-D.
    """
    if param_name not in pkg['meta']:
        raise KeyError(f"{param_name} not in pkg")
    meta = pkg['meta'][param_name]
    if len(meta['shape']) != 2:
        raise ValueError(f"{param_name} must be a 2-D Linear weight, got shape {meta['shape']}")
    out_features, in_features = meta['shape']
    dtype = np.dtype(meta['dtype_str'])
    itemsize = dtype.itemsize  # e.g., 4 for float32

    # Decompress only this parameter's XOR bytes
    dctx = zstd.ZstdDecompressor()
    t0 = perf_counter()
    raw = dctx.decompress(pkg['deltas'][param_name])
    t_decomp = perf_counter() - t0
    delta_u8 = np.frombuffer(raw, dtype=np.uint8)

    # Sparse extraction timing now also covers locating the changed elements.
    t1 = perf_counter()
    # An element changed if any of its bytes changed. flatnonzero is sorted, so
    # unique() only merges bytes belonging to the same element.
    changed_elem_idx = np.unique(np.flatnonzero(delta_u8) // itemsize)
    if changed_elem_idx.size == 0:
        return (np.array([], dtype=np.int64),
                np.array([], dtype=np.int64),
                np.array([], dtype=np.float32),
                (out_features, in_features),
                {'t_param_decompress': t_decomp, 't_sparse_extract': perf_counter() - t1})

    # Vectorized per-element XOR on the changed byte rows only.
    base_sel = tensor_to_uint8_view(base_sd[param_name]).reshape(-1, itemsize)[changed_elem_idx]
    delta_sel = delta_u8.reshape(-1, itemsize)[changed_elem_idx]
    ft_vals = np.bitwise_xor(base_sel, delta_sel).view(dtype).reshape(-1)
    # Base values come from the same selected bytes, with no second tensor→NumPy pass.
    base_vals = base_sel.view(dtype).reshape(-1)
    delta_vals = (ft_vals - base_vals).astype(np.float32)

    rows = (changed_elem_idx // in_features).astype(np.int64)
    cols = (changed_elem_idx % in_features).astype(np.int64)
    t_sparse_extract = perf_counter() - t1

    return rows, cols, delta_vals, (out_features, in_features), \
           {'t_param_decompress': t_decomp, 't_sparse_extract': t_sparse_extract}


def apply_sparse_delta_linear(h2: torch.Tensor, rows, cols, delta_vals, shape, device='cpu'):
    """Compute ``y_delta = h2 @ ΔWᵀ`` for a sparse ΔW of shape (out_features, in_features).

    Args:
        h2: input whose last dimension is ``in_features``. Any leading
            dimensions, e.g. (B, L) or (N,), are preserved.
        rows, cols: COO coordinates of the non-zero ΔW entries.
        delta_vals: values of those entries.
        shape: ``(out_features, in_features)``.
        device: device for the computation and the result.

    Returns:
        Tensor of shape ``h2.shape[:-1] + (out_features,)`` in ``h2``'s dtype.
    """
    out_features, in_features = shape
    X = h2.to(device)
    out_shape = (*X.shape[:-1], out_features)
    if len(delta_vals) == 0:
        return torch.zeros(out_shape, dtype=X.dtype, device=device)

    indices = torch.as_tensor(np.vstack([rows, cols]), dtype=torch.long, device=device)  # (2, nnz)
    values = torch.as_tensor(delta_vals, dtype=torch.float32, device=device)
    deltaW = torch.sparse_coo_tensor(indices, values, size=(out_features, in_features), device=device).coalesce()

    # Compute in float32 (the dtype of the ΔW values) so sparse.mm doesn't
    # fail on mixed dtypes; then cast back to the input's dtype.
    X2d = X.reshape(-1, in_features).to(values.dtype)   # (N, in)
    Y_delta = torch.sparse.mm(deltaW, X2d.T).T           # (N, out)
    return Y_delta.to(X.dtype).reshape(out_shape)


def multitenant_forward_ff0_xor(base_layer: nn.Module, pkg_list: list[dict], x: torch.Tensor, device='cpu'):
    """Multi-tenant serving prototype that reuses base compute across tenants.

       - Runs the base layer once, up to ``h2`` and ``y_base_ff1 = ff[0](h2)``.
       - For each tenant, builds a sparse ΔW for ``ff.0.weight`` only and adds
         ``y_delta = h2 @ ΔWᵀ`` to ``y_base_ff1``. It then finishes
         GELU → ff.2 → residual.

    NOTE: only the ``ff.0.weight`` delta is applied. Every other finetuned
    parameter (attention, LayerNorms, ``ff.0.bias``, ``ff.2``) keeps its base
    value, so each output approximates rather than reproduces that tenant's
    fully reconstructed layer.

    NOTE: ``base_layer`` is moved to ``device`` and switched to eval mode in
    place, because ``nn.Module.to``/``eval`` modify the module itself.

    Returns:
        (outputs, timings). ``outputs`` is a list with one tensor per package.
        ``timings`` has ``t_shared_base`` and a ``per_tenant`` list of
        timing/nnz records.
    """
    param_name = 'ff.0.weight'  # the only parameter whose delta is applied
    on_cuda = torch.device(device).type == 'cuda'

    def _sync():
        # CUDA work is asynchronous; wait so perf_counter captures it.
        if on_cuda:
            torch.cuda.synchronize()

    base_layer = base_layer.to(device).eval()
    x = x.to(device)

    # Shared base forward up to h2 and y_base_ff1
    _sync()
    t_shared0 = perf_counter()
    with torch.no_grad():
        h = base_layer.ln1(x)
        a, _ = base_layer.self_attn(h, h, h)
        x1 = x + a                             # residual before FFN
        h2 = base_layer.ln2(x1)
        y_base_ff1 = base_layer.ff[0](h2)      # (B, L, dim_ff)
    _sync()
    t_shared = perf_counter() - t_shared0

    # Pre-bind layers reused for finishing path
    gelu = base_layer.ff[1]                    # GELU
    ff2 = base_layer.ff[2]                     # second linear

    # Build base_sd once (CPU) for delta extraction
    base_sd_local = {k: v.detach().cpu().contiguous() for k, v in base_layer.state_dict().items()}

    outputs = []
    per_tenant = []
    for pkg in pkg_list:
        rows, cols, delta_vals, shape, t_part = sparse_delta_from_pkg_for_linear(base_sd_local, pkg, param_name)

        # Sparse matmul: y_delta = h2 @ ΔWᵀ
        t1 = perf_counter()
        y_delta = apply_sparse_delta_linear(h2, rows, cols, delta_vals, shape, device=device)
        _sync()
        t_sparse_mm = perf_counter() - t1

        # Merge into ff.0 output, then finish FFN + residual
        t2 = perf_counter()
        with torch.no_grad():
            y_out = x1 + ff2(gelu(y_base_ff1 + y_delta))
        _sync()
        t_finish = perf_counter() - t2

        outputs.append(y_out)
        per_tenant.append({
            't_param_decompress': t_part['t_param_decompress'],
            't_sparse_extract': t_part['t_sparse_extract'],
            't_sparse_mm': t_sparse_mm,
            't_finish': t_finish,
            't_total_per_tenant': (t_part['t_param_decompress'] + t_part['t_sparse_extract'] +
                                   t_sparse_mm + t_finish),
            'nnz_elements': int(len(delta_vals)),
        })

    timings = {'t_shared_base': t_shared, 'per_tenant': per_tenant}
    return outputs, timings


# -----------------------------
# Demo inputs
# -----------------------------
# Batch, sequence length, and d_model (D must match TinyTransformerLayer's default).
B, L, D = 2, 16, 64
x_demo = torch.randn(B, L, D)


# -----------------------------
# 1) Baseline E2E per-tenant (full reconstruct) for reference
# -----------------------------
print("\n=== Baseline E2E (per-tenant, full reconstruct) ===")
for i, pkg in enumerate(pkg_list, start=1):
    t, _ = e2e_forward_from_pkg(base_layer, pkg, x_demo, device=device)
    print(
        f"[tenant {i}] decompress={t['t_decompress']:.6f}s  xor={t['t_xor_apply']:.6f}s  "
        f"state={t['t_state_load']:.6f}s  fwd={t['t_forward']:.6f}s  E2E={t['t_end_to_end']:.6f}s"
    )


# -----------------------------
# 2) Multi-tenant reuse (single shared base path)
# -----------------------------
print("\n=== Multi-tenant reuse on ff.0.weight (shared base compute) ===")
outs, timing_mt = multitenant_forward_ff0_xor(base_layer, pkg_list, x_demo, device=device)
print(f"shared t_base={timing_mt['t_shared_base']:.6f}s")
for i, t in enumerate(timing_mt['per_tenant'], start=1):
    print(
        f"[tenant {i}] dec={t['t_param_decompress']:.6f}s  sparse_extract={t['t_sparse_extract']:.6f}s  "
        f"sparse_mm={t['t_sparse_mm']:.6f}s  finish={t['t_finish']:.6f}s  "
        f"total_tenant={t['t_total_per_tenant']:.6f}s  "
        f"nnz={t['nnz_elements']}"
    )





In [17]:
# ==========================================
# Memory impact estimator for multi-tenant XOR
# ==========================================

def _bytes_str(n):
    units = ["B","KiB","MiB","GiB","TiB"]
    i = 0
    x = float(n)
    while x >= 1024.0 and i < len(units)-1:
        x /= 1024.0
        i += 1
    return f"{x:.2f} {units[i]}"

def _floating_param_bytes(sd: dict):
    """Sum of bytes of all floating-point parameters in a state_dict (CPU tensors)."""
    total = 0
    for t in sd.values():
        if torch.tensor([], dtype=t.dtype).dtype.is_floating_point:
            total += t.numel() * t.element_size()
    return total

def _count_changed_elements(base_sd: dict, ft_sd: dict):
    """Count changed float *elements* (not bytes) between base and finetuned."""
    changed, total = 0, 0
    for k in base_sd.keys():
        b = base_sd[k]
        f = ft_sd[k]
        if torch.tensor([], dtype=b.dtype).dtype.is_floating_point:
            bb = b.detach().cpu()
            ff = f.detach().cpu()
            total += bb.numel()
            changed += (bb != ff).sum().item()
    return int(changed), int(total)

def summarize_vram_for_multi_tenant(base_sd, pkg_list, finetuned_modules, tenants=10):
    """
    Estimate VRAM for different strategies if all tenant deltas are kept on GPU memory.
    - Strategy A: keep *decompressed XOR bytes* (uint8) for each tenant
    - Strategy B: keep *compressed Zstd* bytes for each tenant
    - Strategy C: keep *sparse COO* (rows, cols, values) of changed elements (32-bit vs 64-bit indices)

    Args:
        base_sd: mapping name -> tensor for the base model; only floating-point
            entries are counted.
        pkg_list: non-empty list of per-tenant packages. Each is read as
            ``p['stats']['total_compressed_bytes']`` and
            ``p['stats']['total_delta_bytes']``; ``pkg_list[0]['zstd_level']``
            labels strategy B.
        finetuned_modules: modules whose ``state_dict()`` contains every key of
            ``base_sd``; their measured changed-element fraction drives the
            sparse-COO estimate.
        tenants: number of tenants to project the totals for.

    Raises:
        ValueError: if ``pkg_list`` is empty (the per-tenant averages would be NaN).

    Notes:
    - Decompressed XOR bytes size ~= full-precision weight bytes (because XOR is per-byte).
    - Sparse COO values are sized with the base weights' average float element
      width (4 bytes for an all-fp32 model) rather than assuming fp32.
    - This does NOT include activation buffers, optimizer states, KV cache, etc. It's weights/deltas only.

    The report is printed; nothing is returned.
    """
    if not pkg_list:
        raise ValueError("pkg_list must contain at least one tenant package")

    # Base model size (float weights, one copy on GPU)
    base_bytes = _floating_param_bytes(base_sd)

    # Use per-tenant package stats (assume similar across tenants)
    comp_per_tenant = np.mean([p['stats']['total_compressed_bytes'] for p in pkg_list])
    # Raw (pre-Zstd) XOR delta size; expected to match the float weight bytes.
    rawxor_per_tenant = np.mean([p['stats']['total_delta_bytes'] for p in pkg_list])

    # Changed elements fraction measured on the provided finetuned modules (ft1, ft2, ...)
    ch_fracs = []
    for ft in finetuned_modules:
        ft_sd_local = {k: v.detach().cpu().contiguous() for k, v in ft.state_dict().items()}
        ch, tot = _count_changed_elements(base_sd, ft_sd_local)
        ch_fracs.append(ch / max(1, tot))
    avg_changed_frac = float(np.mean(ch_fracs)) if ch_fracs else 0.0

    # Total float elements in the model
    total_elems = sum(int(t.numel()) for t in base_sd.values() if t.dtype.is_floating_point)

    nnz_elems_est = int(round(avg_changed_frac * total_elems))

    # Sparse COO memory per element = one value + (row, col) indices.
    # - values: same width as the base float weights (4 B for fp32); fall back to
    #   fp32 if base_sd has no float tensors.
    # - indices: 32-bit or 64-bit; many frameworks use 32-bit if dims < 2^31.
    value_bytes = base_bytes / total_elems if total_elems else 4
    per_elem_coo32 = value_bytes + 4 + 4   # 12 bytes for fp32 values
    per_elem_coo64 = value_bytes + 8 + 8   # 20 bytes for fp32 values

    sparse_coo32_per_tenant = nnz_elems_est * per_elem_coo32
    sparse_coo64_per_tenant = nnz_elems_est * per_elem_coo64

    # Totals on GPU (weights + tenants' deltas in the chosen form)
    total_A = base_bytes + tenants * rawxor_per_tenant         # decompressed XOR kept resident
    total_B = base_bytes + tenants * comp_per_tenant           # compressed kept resident (for on-the-fly decode/compute)
    total_C32 = base_bytes + tenants * sparse_coo32_per_tenant # sparse COO 32-bit idx
    total_C64 = base_bytes + tenants * sparse_coo64_per_tenant # sparse COO 64-bit idx

    print("\n=== VRAM usage estimate (weights/deltas only) ===")
    print(f"Base model weights (1 copy, float):   {_bytes_str(base_bytes)}")
    print(f"Avg per-tenant compressed XOR (Zstd): {_bytes_str(comp_per_tenant)}  "
          f"(ratio={comp_per_tenant/max(1, rawxor_per_tenant):.4f} vs raw)")
    print(f"Per-tenant *decompressed* XOR bytes:  {_bytes_str(rawxor_per_tenant)}  (≈ full model bytes)")

    print(f"\nChanged elements fraction (avg over provided tenants): {avg_changed_frac*100:.2f}%")
    print(f"  → Estimated nnz elements per tenant: {nnz_elems_est:,}")
    print(f"  → Sparse COO per tenant (idx32): {_bytes_str(sparse_coo32_per_tenant)}")
    print(f"  → Sparse COO per tenant (idx64): {_bytes_str(sparse_coo64_per_tenant)}")

    print(f"\n--- If you keep deltas for {tenants} tenants on GPU ---")
    print(f"A) Keep *decompressed XOR* (uint8): {_bytes_str(total_A)}   "
          f"[= base {_bytes_str(base_bytes)} + {tenants} × {_bytes_str(rawxor_per_tenant)}]")
    print(f"B) Keep *compressed* (Zstd level {pkg_list[0]['zstd_level']}): {_bytes_str(total_B)}   "
          f"[= base {_bytes_str(base_bytes)} + {tenants} × {_bytes_str(comp_per_tenant)}]")
    print(f"C) Keep *sparse COO* (idx32): {_bytes_str(total_C32)}   "
          f"[= base {_bytes_str(base_bytes)} + {tenants} × {_bytes_str(sparse_coo32_per_tenant)}]")
    print(f"   Keep *sparse COO* (idx64): {_bytes_str(total_C64)}")

    print("\nKey takeaway:")
    print("- Decompressed XOR per tenant ≈ one full model's worth of bytes. "
          "With many tenants, VRAM blows up linearly (A).")
    print("- Compressed bytes are small (B) but require decode-time compute (decompress→operator fusion or compute in compressed domain).")
    print("- Sparse COO (C) scales with actual changed fraction; when changed is rare, it saves VRAM, but requires index and sparse kernel support.")

# ---- Run the summary for N=10 tenants ----
# Detached, CPU-resident, contiguous base weights. The finetuned modules'
# state_dicts are compared against these key by key.
base_sd_local = {
    name: tensor.detach().cpu().contiguous()
    for name, tensor in base_layer.state_dict().items()
}
summarize_vram_for_multi_tenant(
    base_sd=base_sd_local,
    pkg_list=pkg_list,
    finetuned_modules=[ft1, ft2],  # measured sparsity of the two real tenants drives the estimate
    tenants=10,
)




## Takeaways
- XOR against a base + Zstd yields **lossless** and often highly compressible deltas when updates are sparse.
- We can **exactly** reconstruct the finetuned layer and verify forward-pass equivalence.
- The next step (our main research problem) is to **avoid full decompression**, eliminating the peak memory footprint that materializing the full weights causes while keeping computation efficient.
