# KVShuttle: GPU-Native Compression Calibration (Zero-Copy)

This notebook measures the **true GPU kernel time** for KV cache compression by:
1. Pre-loading all data to GPU memory before timing starts
2. Running compression entirely on GPU tensors (no numpy round-trips)
3. Using `torch.cuda.Event` for microsecond-precision GPU timing

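This eliminates the CPU↔GPU copy overhead that dominated the v1 results (a measured speedup of only 1.5x).
The goal is to approximate the compression cost a production serving system (vLLM, TRT-LLM) would pay, since there the KV cache already resides in GPU memory. These are eager-mode PyTorch implementations rather than fused custom kernels, so they approximate production cost rather than reproduce it.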

---

## GPU Runtime Options

| GPU | Colab Tier | Est. Runtime | Notes |
|-----|-----------|-------------|-------|
| **T4** (16 GB) | Free | ~15-25 min | Baseline. Seq ≤ 2048 only |
| **A100** (40/80 GB) | Pro+ | ~5-10 min | Adds seq 4096 + large models |
| **H100** (80 GB) | Enterprise / self-hosted | ~3-7 min | Native FP8, adds high-BW tests |

**Instructions**:
1. Runtime > Change runtime type > Select your GPU
2. Run All
3. Download `gpu_calibration_results_{gpu_tier}.json` from the file browser
4. To combine results from multiple GPUs, see the last cell

In [None]:
# Cell 1: Check GPU and setup
import torch
import numpy as np
import time
import json
import statistics

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    props = torch.cuda.get_device_properties(0)
    vram_gb = props.total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"VRAM: {vram_gb:.1f} GB")
    device = torch.device("cuda")

    # Detect GPU tier
    gpu_name_lower = gpu_name.lower()
    if "h100" in gpu_name_lower:
        GPU_TIER = "h100"
    elif "a100" in gpu_name_lower:
        GPU_TIER = "a100"
    elif "l4" in gpu_name_lower:
        GPU_TIER = "l4"
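    else:
        # Fallback: unrecognized GPUs are labeled "t4", so the bandwidth
        # metadata below will use the T4 figure.
        GPU_TIER = "t4"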
    print(f"Detected GPU tier: {GPU_TIER}")

    # Nominal peak memory bandwidth (GB/s), stored as metadata only. The A100
    # entry is the 80 GB SXM figure (40 GB A100: ~1555 GB/s); H100 is SXM.
    MEMORY_BW_LOOKUP = {
        "t4": 320,
        "l4": 300,
        "a100": 2039,
        "h100": 3350,
    }

    # Collect GPU properties for multi-GPU comparison
    gpu_info = {
        "name": gpu_name,
        "gpu_tier": GPU_TIER,
        "compute_capability": f"{props.major}.{props.minor}",
        "total_memory_gb": round(vram_gb, 1),
        "memory_bandwidth_gbps": MEMORY_BW_LOOKUP.get(GPU_TIER, 0),
        "multi_processor_count": props.multi_processor_count,
        "cuda_version": torch.version.cuda,
        "pytorch_version": torch.__version__,
    }
    print(f"Compute capability: {gpu_info['compute_capability']}")
    print(f"SMs: {gpu_info['multi_processor_count']}")
    print(f"Memory bandwidth: {gpu_info['memory_bandwidth_gbps']} GB/s")
    print(f"CUDA: {gpu_info['cuda_version']}")
else:
    print("WARNING: No CUDA GPU! Go to Runtime > Change runtime type > T4 GPU")
    gpu_name = "CPU"
    GPU_TIER = "cpu"
    gpu_info = {"name": "CPU", "gpu_tier": "cpu"}
    device = torch.device("cpu")

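# Config: model shapes as (num_layers, num_kv_heads, head_dim).
# Models, seq_lens, bandwidths, and repeats are extended below based on VRAM / GPU tier.
MODELS = {
    "llama-3.2-3b": (28, 8, 128),
    "qwen2.5-7b":   (28, 4, 128),
    "llama-3.1-8b": (32, 8, 128),
    "phi-3.5-mini": (32, 32, 96),
}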

# Add larger models when VRAM permits
if torch.cuda.is_available() and vram_gb >= 70:
    MODELS["llama-3.1-70b"] = (80, 8, 128)
    print(f"Added llama-3.1-70b config (VRAM={vram_gb:.0f} GB >= 70 GB)")

# Sequence lengths — add 4096 on high-VRAM GPUs
SEQ_LENS = [256, 512, 1024, 2048]
if torch.cuda.is_available() and vram_gb >= 40:
    SEQ_LENS.append(4096)
    print(f"Added seq_len=4096 (VRAM={vram_gb:.0f} GB >= 40 GB)")

# Link bandwidths in Gbit/s (consumed by transfer_ms). Add high-BW points on H100.
BANDWIDTHS_GBPS = [1, 5, 10, 25, 50, 100, 200, 400]
if GPU_TIER == "h100":
    BANDWIDTHS_GBPS.extend([800, 1600])
    print("Added 800/1600 Gbps bandwidth points for H100 NVLink")

WARMUP = 5

# More repeats on faster GPUs for stable measurements
if GPU_TIER in ("a100", "h100"):
    REPEATS = 50
    print(f"Using REPEATS={REPEATS} (faster kernels need more samples)")
else:
    REPEATS = 20

print(f"\nConfig: {len(MODELS)} models, seq_lens={SEQ_LENS}, "
      f"bandwidths={BANDWIDTHS_GBPS}, repeats={REPEATS}")
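
The `cache_mb` column reported by the calibration cells follows directly from each `MODELS` tuple: keys and values each hold `num_layers × num_kv_heads × seq_len × head_dim` FP16 elements. A minimal sketch of that arithmetic; the helper name `kv_cache_bytes` is hypothetical and not part of the notebook:

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, seq_len: int,
                   head_dim: int, bytes_per_elem: int = 2) -> int:
    """Size of one sequence's K + V cache (FP16 by default)."""
    per_tensor = num_layers * num_kv_heads * seq_len * head_dim * bytes_per_elem
    return 2 * per_tensor  # keys and values

# llama-3.1-8b tuple from MODELS: (32, 8, 128)
size = kv_cache_bytes(32, 8, 2048, 128)
print(size / (1024 * 1024))  # 256.0 (MiB)
```

With 32 KV heads of width 96, phi-3.5-mini comes out 3x larger than llama-3.1-8b at the same sequence length (768 MiB at seq 2048), which is why it dominates memory on the T4.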

## Part 1: GPU-Native Compressor Kernels (Zero-Copy)

These functions operate **entirely on GPU tensors**: no numpy, no CPU↔GPU copies.
This mirrors the setting of a serving system, where the KV cache already lives in GPU memory.

In [None]:
# Cell 2: GPU-native INT8 quantization (zero-copy)

def gpu_int8_compress(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-layer symmetric INT8 quantization. Input/output stay on GPU."""
    qmax = 127
    num_layers = x.shape[0]
    flat = x.reshape(num_layers, -1)
    amax = flat.abs().amax(dim=1)
    amax = torch.where(amax == 0, torch.ones_like(amax), amax)
    scales = amax / qmax
    scales_exp = scales.view(num_layers, *([1] * (x.ndim - 1)))
    quantized = (x / scales_exp).round().clamp(-qmax, qmax).to(torch.int8)
    return quantized, scales

def gpu_int8_decompress(quantized: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """Dequantize INT8 → FP16. Input/output stay on GPU."""
    num_layers = quantized.shape[0]
    s_exp = scales.view(num_layers, *([1] * (quantized.ndim - 1)))
    return (quantized.float() * s_exp).half()

print("GPU-native INT8 defined.")
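
The symmetric scheme in `gpu_int8_compress` has a simple error bound: with `scale = amax / 127` and round-to-nearest, every in-range value is reconstructed within `scale / 2`. A pure-Python sketch of the same per-layer round trip on one small, made-up layer (the helper `int8_roundtrip` is illustrative only):

```python
def int8_roundtrip(layer: list[float], qmax: int = 127) -> tuple[list[float], float]:
    """Symmetric per-layer INT8 quantize + dequantize (pure-Python mirror)."""
    amax = max(abs(v) for v in layer) or 1.0  # same zero guard as gpu_int8_compress
    scale = amax / qmax
    # round() is round-half-to-even, like torch.round
    codes = [max(-qmax, min(qmax, round(v / scale))) for v in layer]
    return [c * scale for c in codes], scale

layer = [0.03, -1.27, 0.5, 0.999, -0.001]
recon, scale = int8_roundtrip(layer)
max_err = max(abs(a - b) for a, b in zip(layer, recon))
print(scale, max_err)
```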

In [None]:
# Cell 3: GPU-native KIVI 2-bit (zero-copy; 2-bit codes are kept one per uint8, not bit-packed)

def gpu_kivi_compress_keys(x: torch.Tensor, qmax: int = 3):
    """Per-channel 2-bit quant for keys (along seq_len dim=2). All on GPU."""
    tmin = x.amin(dim=2)
    tmax = x.amax(dim=2)
    rng = tmax - tmin
    rng = torch.where(rng == 0, torch.ones_like(rng), rng)
    scales = rng / qmax
    zeros = tmin
    quantized = ((x - zeros.unsqueeze(2)) / scales.unsqueeze(2)).round().clamp(0, qmax).to(torch.uint8)
    return quantized, scales, zeros

def gpu_kivi_compress_values(x: torch.Tensor, qmax: int = 3):
    """Per-token 2-bit quant for values (along head_dim dim=3). All on GPU."""
    tmin = x.amin(dim=3)
    tmax = x.amax(dim=3)
    rng = tmax - tmin
    rng = torch.where(rng == 0, torch.ones_like(rng), rng)
    scales = rng / qmax
    zeros = tmin
    quantized = ((x - zeros.unsqueeze(3)) / scales.unsqueeze(3)).round().clamp(0, qmax).to(torch.uint8)
    return quantized, scales, zeros

def gpu_kivi_decompress_keys(quantized, scales, zeros):
    return (quantized.float() * scales.unsqueeze(2) + zeros.unsqueeze(2)).half()

def gpu_kivi_decompress_values(quantized, scales, zeros):
    return (quantized.float() * scales.unsqueeze(3) + zeros.unsqueeze(3)).half()

print("GPU-native KIVI 2-bit defined.")
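
The axis choice in KIVI matters when a few channels carry outliers, which is the usual motivation for quantizing keys per-channel. The pure-Python sketch below applies the same asymmetric min-max 2-bit quantizer along each axis of a made-up 4-token × 3-channel key block; the data and the `asym_quant` helper are illustrative assumptions:

```python
def asym_quant(vals: list[float], qmax: int = 3) -> list[float]:
    """Asymmetric min-max quantize + dequantize one group (mirrors the KIVI cells)."""
    lo, hi = min(vals), max(vals)
    rng = (hi - lo) or 1.0  # same zero-range guard as the torch code
    scale = rng / qmax
    codes = [max(0, min(qmax, round((v - lo) / scale))) for v in vals]
    return [c * scale + lo for c in codes]

# Hypothetical key block: 4 tokens x 3 channels; channel 2 is an outlier channel.
keys = [[0.1, 0.5, 10.0],
        [0.2, 0.6, 11.0],
        [0.3, 0.7, 12.0],
        [0.4, 0.8, 13.0]]

# Per-channel (how KIVI treats keys): one group per column, spanning tokens.
cols = [asym_quant([row[c] for row in keys]) for c in range(3)]
per_channel_err = max(abs(keys[t][c] - cols[c][t]) for t in range(4) for c in range(3))

# Per-token (how KIVI treats values): one group per row, spanning channels.
rows = [asym_quant(row) for row in keys]
per_token_err = max(abs(keys[t][c] - rows[t][c]) for t in range(4) for c in range(3))
print(per_channel_err, per_token_err)
```

Per-token grouping lets the outlier channel stretch every row's range, so the small channels collapse onto the minimum; per-channel grouping keeps each channel's range local.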

In [None]:
# Cell 3b: GPU-native INT4 per-group quantization (zero-copy)

def gpu_int4_compress(x, group_size=128):
    """Per-group asymmetric INT4 quantization. Packs 2 values per byte on GPU."""
    original_shape = list(x.shape)
    flat = x.reshape(-1)
    n = flat.numel()
    padded_len = ((n + group_size - 1) // group_size) * group_size
    if padded_len > n:
        flat = torch.nn.functional.pad(flat, (0, padded_len - n))
    grouped = flat.reshape(-1, group_size)
    gmin = grouped.amin(dim=1)
    gmax = grouped.amax(dim=1)
    rng = gmax - gmin
    rng = torch.where(rng == 0, torch.ones_like(rng), rng)
    scales = rng / 15.0
    zeros = gmin
    quantized = ((grouped - zeros.unsqueeze(1)) / scales.unsqueeze(1)).round().clamp(0, 15).to(torch.uint8)
    flat_q = quantized.reshape(-1)
    if flat_q.numel() % 2 != 0:
        flat_q = torch.nn.functional.pad(flat_q, (0, 1))
    packed = (flat_q[0::2] << 4) | flat_q[1::2]
    return packed, scales, zeros, original_shape

def gpu_int4_decompress(packed, scales, zeros, original_shape, group_size=128):
    """Dequantize packed INT4 back to float16 on GPU."""
    high = (packed >> 4) & 0x0F
    low = packed & 0x0F
    flat_q = torch.empty(packed.numel() * 2, dtype=torch.uint8, device=packed.device)
    flat_q[0::2] = high
    flat_q[1::2] = low
    total_elements = 1
    for s in original_shape:
        total_elements *= s
    padded_len = scales.numel() * group_size
    grouped = flat_q[:padded_len].reshape(-1, group_size).float()
    result = grouped * scales.unsqueeze(1) + zeros.unsqueeze(1)
    return result.reshape(-1)[:total_elements].reshape(original_shape).half()

print("GPU-native INT4 defined.")
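
The packed layout produced by `gpu_int4_compress` stores the even-indexed code in the high nibble and the odd-indexed code in the low nibble, padding an odd-length input with a zero. A pure-Python sketch of that layout and its inverse (the helpers `pack_nibbles` / `unpack_nibbles` are hypothetical):

```python
def pack_nibbles(codes: list[int]) -> bytes:
    """Pack 4-bit codes two per byte, first code in the high nibble."""
    if len(codes) % 2:
        codes = codes + [0]  # pad odd-length input, like the torch version
    return bytes((hi << 4) | lo for hi, lo in zip(codes[0::2], codes[1::2]))

def unpack_nibbles(packed: bytes) -> list[int]:
    """Inverse of pack_nibbles (includes the padding code, if any)."""
    out = []
    for b in packed:
        out.extend(((b >> 4) & 0x0F, b & 0x0F))
    return out

codes = [15, 0, 7, 8, 1]
packed = pack_nibbles(codes)
print(packed.hex())                         # f07810
print(unpack_nibbles(packed)[:len(codes)])  # [15, 0, 7, 8, 1]
```

As in `gpu_int4_decompress`, the caller must trim the padding using the known element count.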

In [None]:
# Cell 3c: GPU-native FP8 E4M3 (native on H100, simulated fallback)

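# Detect native FP8 support. torch.float8_e4m3fn exists in PyTorch 2.1+, but the
# dtype can also be allocated on older GPUs, so additionally require compute
# capability >= 8.9 (Ada/Hopper, which support FP8 in hardware).
HAS_NATIVE_FP8 = hasattr(torch, "float8_e4m3fn") and torch.cuda.is_available()
if HAS_NATIVE_FP8:
    try:
        _test = torch.zeros(1, device="cuda", dtype=torch.float8_e4m3fn)
        del _test
        HAS_NATIVE_FP8 = torch.cuda.get_device_capability(0) >= (8, 9)
    except Exception:
        HAS_NATIVE_FP8 = False

if HAS_NATIVE_FP8:
    print("Using NATIVE FP8 E4M3 (torch.float8_e4m3fn): Ada/Hopper (L4/H100) path")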

    def gpu_fp8_compress(x):
        """Native FP8 E4M3: per-layer scale, cast to torch.float8_e4m3fn."""
        num_layers = x.shape[0]
        flat = x.reshape(num_layers, -1)
        amax = flat.abs().amax(dim=1)
        amax = torch.where(amax == 0, torch.ones_like(amax), amax)
        scales = amax / torch.finfo(torch.float8_e4m3fn).max
        scales_exp = scales.view(num_layers, *([1] * (x.ndim - 1)))
        quantized = (x / scales_exp).to(torch.float8_e4m3fn)
        return quantized, scales

    def gpu_fp8_decompress(quantized, scales):
        """Dequantize native FP8 back to float16."""
        num_layers = quantized.shape[0]
        s_exp = scales.view(num_layers, *([1] * (quantized.ndim - 1)))
        return (quantized.float() * s_exp).half()
else:
    print("Using SIMULATED FP8 E4M3 (uint8 fallback) — T4/A100 path")

    def gpu_fp8_compress(x):
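        """Simulated FP8 E4M3: per-layer scale = amax/240, quantize to uint8.

        Uniform 8-bit proxy used for timing only. With the +128 offset, scaled
        values outside [-128, 127] (roughly |x| > 0.53 * amax) are clipped,
        so this path does not reproduce FP8 accuracy."""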
        num_layers = x.shape[0]
        flat = x.reshape(num_layers, -1)
        amax = flat.abs().amax(dim=1)
        amax = torch.where(amax == 0, torch.ones_like(amax), amax)
        scales = amax / 240.0
        scales_exp = scales.view(num_layers, *([1] * (x.ndim - 1)))
        quantized = (x / scales_exp + 128.0).round().clamp(0, 255).to(torch.uint8)
        return quantized, scales

    def gpu_fp8_decompress(quantized, scales):
        """Dequantize simulated FP8 back to float16."""
        num_layers = quantized.shape[0]
        s_exp = scales.view(num_layers, *([1] * (quantized.ndim - 1)))
        return ((quantized.float() - 128.0) * s_exp).half()

FP8_PATH = "native" if HAS_NATIVE_FP8 else "simulated"
print(f"GPU-native FP8 E4M3 defined (path={FP8_PATH}).")
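
Unlike the uniform uint8 proxy, the native path rounds onto the non-uniform E4M3FN grid (4 exponent bits with bias 7, 3 mantissa bits, no infinities, one NaN code per sign). The pure-Python sketch below enumerates that grid and confirms the 448 maximum used via `torch.finfo`. The rounding helper is a simple saturating nearest-value search, an assumption for illustration, not PyTorch's exact cast semantics:

```python
def e4m3fn_values() -> list[float]:
    """All non-negative finite FP8 E4M3FN values, sorted ascending."""
    vals = []
    for e in range(16):
        for m in range(8):
            if e == 15 and m == 7:
                continue  # exponent 1111, mantissa 111 encodes NaN
            if e == 0:
                vals.append((m / 8) * 2.0 ** -6)           # subnormals
            else:
                vals.append((1 + m / 8) * 2.0 ** (e - 7))  # normals
    return sorted(vals)

GRID = e4m3fn_values()

def round_to_e4m3fn(x: float) -> float:
    """Saturating round-to-nearest onto the grid (ties go to the smaller magnitude)."""
    mag = min(abs(x), GRID[-1])
    nearest = min(GRID, key=lambda v: abs(v - mag))
    return nearest if x >= 0 else -nearest

print(GRID[-1])              # 448.0
print(round_to_e4m3fn(0.3))  # 0.3125
```

Grid spacing grows with magnitude (0.03125 near 0.3, 32 near 448), which is why per-layer scaling into the `[-448, 448]` range matters for accuracy.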

In [None]:
# Cell 3d: GPU-native CacheGen (anchor INT8 + delta INT4, chunked)

def gpu_cachegen_compress(x, chunk_size=10):
    """CacheGen: anchor tokens (INT8) + delta tokens (INT4 packed). All on GPU."""
    original_shape = list(x.shape)
    L, H, S, D = x.shape
    N = L * H
    flat = x.reshape(N, S, D)
    num_chunks = (S + chunk_size - 1) // chunk_size

    anchor_indices = torch.arange(0, S, chunk_size, device=x.device)
    anchors = flat[:, anchor_indices, :]

    a_flat = anchors.reshape(N * num_chunks, D)
    a_min = a_flat.amin(dim=1, keepdim=True)
    a_max = a_flat.amax(dim=1, keepdim=True)
    a_rng = a_max - a_min
    a_rng = torch.where(a_rng == 0, torch.ones_like(a_rng), a_rng)
    a_scales = (a_rng / 255.0).squeeze(1)
    a_zeros = a_min.squeeze(1)
    a_quant = ((a_flat - a_zeros.unsqueeze(1)) / a_scales.unsqueeze(1)).round().clamp(0, 255).to(torch.uint8)

    # Gather each delta token's chunk anchor in one indexing op (no per-chunk
    # Python loop inside the timed region). Deltas are taken against the
    # unquantized anchors.
    all_indices = torch.arange(S, device=x.device)
    delta_mask = (all_indices % chunk_size) != 0
    delta_indices = all_indices[delta_mask]
    deltas = flat[:, delta_indices, :] - anchors[:, delta_indices // chunk_size, :]

    d_flat = deltas.reshape(-1, D)
    d_min = d_flat.amin(dim=1, keepdim=True)
    d_max = d_flat.amax(dim=1, keepdim=True)
    d_rng = d_max - d_min
    d_rng = torch.where(d_rng == 0, torch.ones_like(d_rng), d_rng)
    d_scales = (d_rng / 15.0).squeeze(1)
    d_zeros = d_min.squeeze(1)
    d_quant = ((d_flat - d_zeros.unsqueeze(1)) / d_scales.unsqueeze(1)).round().clamp(0, 15).to(torch.uint8)

    d_flat_q = d_quant.reshape(-1)
    if d_flat_q.numel() % 2 != 0:
        d_flat_q = torch.nn.functional.pad(d_flat_q, (0, 1))
    delta_packed = (d_flat_q[0::2] << 4) | d_flat_q[1::2]

    return a_quant, a_scales, a_zeros, delta_packed, d_scales, d_zeros, original_shape

def gpu_cachegen_decompress(a_quant, a_scales, a_zeros, delta_packed, d_scales, d_zeros, original_shape, chunk_size=10):
    """Decompress CacheGen anchor+delta on GPU."""
    L, H, S, D = original_shape
    N = L * H
    num_chunks = (S + chunk_size - 1) // chunk_size

    anchors = (a_quant.float() * a_scales.unsqueeze(1) + a_zeros.unsqueeze(1))
    anchors = anchors.reshape(N, num_chunks, D)

    high = (delta_packed >> 4) & 0x0F
    low = delta_packed & 0x0F
    d_unpacked = torch.empty(delta_packed.numel() * 2, dtype=torch.uint8, device=delta_packed.device)
    d_unpacked[0::2] = high
    d_unpacked[1::2] = low

    all_indices = torch.arange(S, device=a_quant.device)
    num_delta_tokens = S - num_chunks  # one anchor per chunk; avoids a .item() host sync
    total_delta_elements = N * num_delta_tokens * D
    d_flat = d_unpacked[:total_delta_elements].reshape(-1, D).float()

    deltas = d_flat * d_scales[:d_flat.shape[0]].unsqueeze(1) + d_zeros[:d_flat.shape[0]].unsqueeze(1)
    deltas = deltas.reshape(N, num_delta_tokens, D)

    result = torch.zeros(N, S, D, device=a_quant.device, dtype=torch.float32)
    anchor_positions = torch.arange(0, S, chunk_size, device=a_quant.device)
    result[:, anchor_positions, :] = anchors

    delta_positions = all_indices[all_indices % chunk_size != 0]
    chunk_ids = delta_positions // chunk_size
    result[:, delta_positions, :] = deltas + anchors[:, chunk_ids, :]

    return result.reshape(original_shape).half()

print("GPU-native CacheGen defined.")
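
Because anchors sit at positions 0, chunk_size, 2·chunk_size, …, every sequence has exactly `num_chunks` anchors and `S - num_chunks` delta tokens, and each delta at position `t` refers to anchor `t // chunk_size`. A pure-Python sketch of that bookkeeping, with quantization left out so the round trip is exact for integer inputs (helper names are hypothetical):

```python
def cachegen_layout(S: int, chunk_size: int = 10):
    """Anchor/delta index bookkeeping used by the CacheGen cells."""
    num_chunks = (S + chunk_size - 1) // chunk_size
    anchor_pos = list(range(0, S, chunk_size))
    delta_pos = [t for t in range(S) if t % chunk_size != 0]
    chunk_ids = [t // chunk_size for t in delta_pos]  # anchor each delta refers to
    return num_chunks, anchor_pos, delta_pos, chunk_ids

def encode_decode(seq: list[float], chunk_size: int = 10) -> list[float]:
    """Anchor + delta round trip without quantization."""
    _, anchor_pos, delta_pos, chunk_ids = cachegen_layout(len(seq), chunk_size)
    anchors = [seq[p] for p in anchor_pos]
    deltas = [seq[t] - anchors[c] for t, c in zip(delta_pos, chunk_ids)]
    out = [0.0] * len(seq)
    for c, p in enumerate(anchor_pos):
        out[p] = anchors[c]
    for t, c, d in zip(delta_pos, chunk_ids, deltas):
        out[t] = anchors[c] + d
    return out

num_chunks, anchor_pos, delta_pos, _ = cachegen_layout(23)
print(anchor_pos, len(delta_pos), 23 - num_chunks)  # [0, 10, 20] 20 20
```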

In [None]:
# Cell 3e: GPU-native Cascade (prune 50% + INT4)

def gpu_cascade_compress(keys, values, keep_ratio=0.5, group_size=128):
    """Cascade: keep `keep_ratio` of the tokens (always the first and last 2), ranked by mean key norm, then INT4-quantize keys and values."""
    L, H, S, D = keys.shape
    keep_count = max(1, int(S * keep_ratio))

    key_norms = torch.linalg.norm(keys, dim=3)
    importance = key_norms.mean(dim=(0, 1))

    protect = min(2, S)
    imp_copy = importance.clone()
    imp_copy[:protect] = -float('inf')
    if S > protect:
        imp_copy[max(protect, S - protect):] = -float('inf')

    remaining_budget = keep_count - min(2 * protect, S)
    if remaining_budget > 0:
        _, topk_idx = torch.topk(imp_copy, remaining_budget)
    else:
        topk_idx = torch.tensor([], dtype=torch.long, device=keys.device)

    # Built on-device to avoid a host-to-device copy inside the timed region.
    protected_t = torch.cat([
        torch.arange(protect, device=keys.device),
        torch.arange(max(protect, S - protect), S, device=keys.device),
    ])
    keep_indices = torch.sort(torch.cat([protected_t, topk_idx]))[0][:keep_count]

    pruned_keys = keys[:, :, keep_indices, :]
    pruned_values = values[:, :, keep_indices, :]

    k_packed, k_scales, k_zeros, k_shape = gpu_int4_compress(pruned_keys, group_size)
    v_packed, v_scales, v_zeros, v_shape = gpu_int4_compress(pruned_values, group_size)

    return (k_packed, k_scales, k_zeros, k_shape,
            v_packed, v_scales, v_zeros, v_shape,
            keep_indices, list(keys.shape))

def gpu_cascade_decompress(k_packed, k_scales, k_zeros, k_shape,
                           v_packed, v_scales, v_zeros, v_shape,
                           keep_indices, original_shape, group_size=128):
    """Decompress cascade: INT4 decompress then scatter back."""
    pruned_keys = gpu_int4_decompress(k_packed, k_scales, k_zeros, k_shape, group_size)
    pruned_values = gpu_int4_decompress(v_packed, v_scales, v_zeros, v_shape, group_size)
    L, H, S, D = original_shape
    keys = torch.zeros(L, H, S, D, device=k_packed.device, dtype=torch.float16)
    values = torch.zeros(L, H, S, D, device=k_packed.device, dtype=torch.float16)
    keys[:, :, keep_indices, :] = pruned_keys
    values[:, :, keep_indices, :] = pruned_values
    return keys, values

print("GPU-native Cascade (prune50+INT4) defined.")
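
The token selection in `gpu_cascade_compress` always keeps the first and last `protect` tokens, then fills the remaining budget with the highest-importance tokens in between. A pure-Python sketch of that selection on a made-up importance vector (the helper `cascade_keep_indices` is hypothetical; note the CPU baseline `np_cascade_compress` ranks all tokens without this protection):

```python
def cascade_keep_indices(importance: list[float], keep_ratio: float = 0.5,
                         protect: int = 2) -> list[int]:
    """Mirror of the keep-index logic in gpu_cascade_compress."""
    S = len(importance)
    keep_count = max(1, int(S * keep_ratio))
    protect = min(protect, S)
    protected = list(range(protect)) + list(range(max(protect, S - protect), S))
    candidates = [i for i in range(S) if i not in protected]
    budget = keep_count - min(2 * protect, S)
    # Highest mean key norm first among the unprotected middle tokens.
    top = sorted(candidates, key=lambda i: importance[i], reverse=True)[:max(budget, 0)]
    return sorted(protected + top)[:keep_count]

imp = [5, 5, 0.1, 0.9, 0.5, 0.8, 0.2, 0.7, 0.3, 0.4, 5, 5]
print(cascade_keep_indices(imp))  # [0, 1, 3, 5, 10, 11]
```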

In [None]:
# Cell 3f: GPU-native Palu (truncated SVD, rank_ratio=0.25)

def gpu_palu_compress(x, rank_ratio=0.25):
    """Truncated SVD: batched over (L*H) slices. Stores US and Vt as float16."""
    original_shape = list(x.shape)
    L, H, S, D = x.shape
    rank = max(1, int(min(S, D) * rank_ratio))
    N = L * H
    mats = x.reshape(N, S, D)
    U, Sigma, Vt = torch.linalg.svd(mats, full_matrices=False)
    US = U[:, :, :rank] * Sigma[:, None, :rank]
    Vt_r = Vt[:, :rank, :]
    return US.half(), Vt_r.half(), original_shape, rank

def gpu_palu_decompress(US, Vt, original_shape):
    """Reconstruct from SVD factors via batched matmul."""
    result = torch.bmm(US.float(), Vt.float())
    return result.reshape(original_shape).half()

print("GPU-native Palu (SVD) defined.")
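
Truncated SVD stores `US` (S × r) and `Vt` (r × D) instead of the dense S × D slice, so the stored fraction is `r(S + D) / (S·D)`; both factors are FP16 in `gpu_palu_compress`, so this is also the byte ratio against an FP16 cache. A short sketch of that formula (the helper `palu_ratio` is hypothetical):

```python
def palu_ratio(S: int, D: int, rank_ratio: float = 0.25) -> float:
    """Stored elements of US plus Vt relative to the dense S x D slice."""
    r = max(1, int(min(S, D) * rank_ratio))
    return r * (S + D) / (S * D)

for S in (256, 2048):
    print(S, round(palu_ratio(S, 128), 4))
# 256 0.375
# 2048 0.2656
```

The factorization only saves memory while `r(S + D) < S·D`; with `head_dim = 128` and `rank_ratio = 0.25`, short sequences keep a noticeably larger fraction than long ones.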

In [None]:
# Cell 4: CPU numpy compressors (for comparison baseline)

def np_int8_compress(tensor: np.ndarray):
    fp32 = tensor.astype(np.float32)
    num_layers = fp32.shape[0]
    qmax = 127
    scales = np.zeros(num_layers, dtype=np.float32)
    quantized = np.zeros(fp32.shape, dtype=np.int8)
    for i in range(num_layers):
        layer = fp32[i]
        amax = np.abs(layer).max()
        scales[i] = amax / qmax if amax != 0 else 1.0
        quantized[i] = np.clip(np.round(layer / scales[i]), -qmax, qmax).astype(np.int8)
    return quantized, scales

def np_int8_decompress(quantized, scales):
    num_layers = quantized.shape[0]
    result = np.zeros(quantized.shape, dtype=np.float32)
    for i in range(num_layers):
        result[i] = quantized[i].astype(np.float32) * scales[i]
    return result.astype(np.float16)

def np_kivi_compress_keys(tensor):
    fp32 = tensor.astype(np.float32)
    qmax = 3
    tmin = fp32.min(axis=2)
    tmax = fp32.max(axis=2)
    rng = tmax - tmin
    rng[rng == 0] = 1.0
    scales = rng / qmax
    zeros = tmin
    quantized = np.clip(np.round((fp32 - zeros[:,:,np.newaxis,:]) / scales[:,:,np.newaxis,:]), 0, qmax).astype(np.uint8)
    return quantized, scales, zeros

def np_kivi_compress_values(tensor):
    fp32 = tensor.astype(np.float32)
    qmax = 3
    tmin = fp32.min(axis=3)
    tmax = fp32.max(axis=3)
    rng = tmax - tmin
    rng[rng == 0] = 1.0
    scales = rng / qmax
    zeros = tmin
    quantized = np.clip(np.round((fp32 - zeros[:,:,:,np.newaxis]) / scales[:,:,:,np.newaxis]), 0, qmax).astype(np.uint8)
    return quantized, scales, zeros

def np_kivi_decompress_keys(quant, scales, zeros):
    return (quant.astype(np.float32) * scales[:,:,np.newaxis,:] + zeros[:,:,np.newaxis,:]).astype(np.float16)

def np_kivi_decompress_values(quant, scales, zeros):
    return (quant.astype(np.float32) * scales[:,:,:,np.newaxis] + zeros[:,:,:,np.newaxis]).astype(np.float16)

# --- CPU baselines for INT4, FP8, CacheGen, Cascade, Palu ---

def np_int4_compress(tensor, group_size=128):
    """CPU INT4 per-group quantization (numpy baseline)."""
    fp32 = tensor.astype(np.float32)
    flat = fp32.reshape(-1)
    padded_len = ((len(flat) + group_size - 1) // group_size) * group_size
    padded = np.zeros(padded_len, dtype=np.float32)
    padded[:len(flat)] = flat
    grouped = padded.reshape(-1, group_size)
    gmin = grouped.min(axis=1)
    gmax = grouped.max(axis=1)
    rng = gmax - gmin
    rng[rng == 0] = 1.0
    scales = rng / 15.0
    zeros = gmin
    quantized = np.clip(np.round((grouped - zeros[:, None]) / scales[:, None]), 0, 15).astype(np.uint8)
    flat_q = quantized.reshape(-1)
    if len(flat_q) % 2 != 0:
        flat_q = np.append(flat_q, np.uint8(0))
    packed = (flat_q[0::2] << 4) | flat_q[1::2]
    return packed, scales, zeros, list(tensor.shape)

def np_int4_decompress(packed, scales, zeros, original_shape, group_size=128):
    high = (packed >> 4) & 0x0F
    low = packed & 0x0F
    flat_q = np.empty(len(packed) * 2, dtype=np.uint8)
    flat_q[0::2] = high
    flat_q[1::2] = low
    total = 1
    for s in original_shape:
        total *= s
    padded_len = len(scales) * group_size
    grouped = flat_q[:padded_len].reshape(-1, group_size).astype(np.float32)
    result = grouped * scales[:, None] + zeros[:, None]
    return result.reshape(-1)[:total].reshape(original_shape).astype(np.float16)

def np_fp8_compress(tensor):
    """CPU FP8 E4M3 simulation (numpy baseline)."""
    fp32 = tensor.astype(np.float32)
    num_layers = fp32.shape[0]
    scales = np.zeros(num_layers, dtype=np.float32)
    quantized = np.zeros(fp32.shape, dtype=np.uint8)
    for i in range(num_layers):
        layer = fp32[i]
        amax = np.abs(layer).max()
        scales[i] = amax / 240.0 if amax != 0 else 1.0
        mapped = np.clip(np.round(layer / scales[i] + 128.0), 0, 255)
        quantized[i] = mapped.astype(np.uint8)
    return quantized, scales

def np_fp8_decompress(quantized, scales):
    num_layers = quantized.shape[0]
    result = np.zeros(quantized.shape, dtype=np.float32)
    for i in range(num_layers):
        result[i] = (quantized[i].astype(np.float32) - 128.0) * scales[i]
    return result.astype(np.float16)

def np_cachegen_compress(tensor, chunk_size=10):
    """CPU CacheGen anchor+delta (numpy baseline)."""
    fp32 = tensor.astype(np.float32)
    L, H, S, D = fp32.shape
    N = L * H
    flat = fp32.reshape(N, S, D)
    num_chunks = (S + chunk_size - 1) // chunk_size

    all_anchors = []
    all_deltas = []
    for s in range(N):
        for c in range(num_chunks):
            start = c * chunk_size
            end = min(start + chunk_size, S)
            anchor = flat[s, start]
            all_anchors.append(anchor)
            if end - start > 1:
                deltas = flat[s, start+1:end] - anchor[None, :]
                all_deltas.append(deltas.reshape(-1))

    # Quantize anchors to uint8 (per-anchor scale) and deltas to 4-bit codes
    # (one global scale); delta codes are left unpacked, one per byte.
    a_arr = np.array(all_anchors, dtype=np.float32)
    a_min = a_arr.min(axis=1, keepdims=True)
    a_max = a_arr.max(axis=1, keepdims=True)
    a_rng = a_max - a_min
    a_rng[a_rng == 0] = 1.0
    a_scales = a_rng / 255.0
    a_quant = np.clip(np.round((a_arr - a_min) / a_scales), 0, 255).astype(np.uint8)

    if all_deltas:
        d_arr = np.concatenate(all_deltas)
        d_min = d_arr.min()
        d_max = d_arr.max()
        d_rng = d_max - d_min if d_max != d_min else 1.0
        d_scale = d_rng / 15.0
        d_quant = np.clip(np.round((d_arr - d_min) / d_scale), 0, 15).astype(np.uint8)
    else:
        d_quant = np.array([], dtype=np.uint8)
        d_scale = 1.0
        d_min = 0.0

    return a_quant, a_scales.squeeze(1), a_min.squeeze(1), d_quant, d_scale, d_min, list(tensor.shape)

def np_cachegen_decompress(a_quant, a_scales, a_zeros, d_quant, d_scale, d_zero, original_shape, chunk_size=10):
    L, H, S, D = original_shape
    N = L * H
    result = np.zeros((N, S, D), dtype=np.float32)
    num_chunks = (S + chunk_size - 1) // chunk_size
    a_idx = 0
    d_idx = 0
    for s in range(N):
        for c in range(num_chunks):
            start = c * chunk_size
            end = min(start + chunk_size, S)
            anchor = a_quant[a_idx].astype(np.float32) * a_scales[a_idx] + a_zeros[a_idx]
            result[s, start] = anchor
            a_idx += 1
            if end - start > 1:
                n_vals = (end - start - 1) * D
                dq = d_quant[d_idx:d_idx+n_vals].astype(np.float32) * d_scale + d_zero
                result[s, start+1:end] = anchor[None, :] + dq.reshape(-1, D)
                d_idx += n_vals
    return result.reshape(original_shape).astype(np.float16)

def np_cascade_compress(keys, values, keep_ratio=0.5, group_size=128):
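    """CPU Cascade: prune + INT4 (numpy baseline).

    Ranks all tokens by mean key norm; unlike gpu_cascade_compress, it does not
    protect the first/last tokens."""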
    fp32_k = keys.astype(np.float32)
    L, H, S, D = fp32_k.shape
    keep_count = max(1, int(S * keep_ratio))
    key_norms = np.linalg.norm(fp32_k, axis=3).mean(axis=(0, 1))
    top_indices = np.argsort(key_norms)[-keep_count:]
    keep_indices = np.sort(top_indices)
    pruned_k = keys[:, :, keep_indices, :]
    pruned_v = values[:, :, keep_indices, :]
    kp, ks, kz, ksh = np_int4_compress(pruned_k, group_size)
    vp, vs, vz, vsh = np_int4_compress(pruned_v, group_size)
    return kp, ks, kz, ksh, vp, vs, vz, vsh, keep_indices, list(keys.shape)

def np_cascade_decompress(kp, ks, kz, ksh, vp, vs, vz, vsh, keep_indices, original_shape, group_size=128):
    pk = np_int4_decompress(kp, ks, kz, ksh, group_size)
    pv = np_int4_decompress(vp, vs, vz, vsh, group_size)
    L, H, S, D = original_shape
    keys = np.zeros(original_shape, dtype=np.float16)
    values = np.zeros(original_shape, dtype=np.float16)
    keys[:, :, keep_indices, :] = pk
    values[:, :, keep_indices, :] = pv
    return keys, values

def np_palu_compress(tensor, rank_ratio=0.25):
    """CPU Palu SVD (numpy baseline)."""
    fp32 = tensor.astype(np.float32)
    L, H, S, D = fp32.shape
    rank = max(1, int(min(S, D) * rank_ratio))
    all_US = []
    all_Vt = []
    for l in range(L):
        for h in range(H):
            mat = fp32[l, h]
            U, s, Vt = np.linalg.svd(mat, full_matrices=False)
            US = U[:, :rank] * s[None, :rank]
            all_US.append(US.astype(np.float16))
            all_Vt.append(Vt[:rank, :].astype(np.float16))
    return all_US, all_Vt, list(tensor.shape), rank

def np_palu_decompress(all_US, all_Vt, original_shape, rank):
    L, H, S, D = original_shape
    result = np.zeros(original_shape, dtype=np.float32)
    idx = 0
    for l in range(L):
        for h in range(H):
            result[l, h] = all_US[idx].astype(np.float32) @ all_Vt[idx].astype(np.float32)
            idx += 1
    return result.astype(np.float16)

print("CPU numpy compressors defined (all 7).")

## Part 2: Timing Utilities

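Key difference from v1: **CUDA Event timing** measures elapsed time on the GPU stream between two recorded events, so host-side work outside the timed call is excluded.
Data is pre-loaded to the GPU before timing starts. For small tensors the measurement can still include idle gaps caused by Python and kernel-launch overhead, because these eager-mode kernels are launched one at a time.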

In [None]:
# Cell 5: Timing utilities

def gpu_time_ms(fn, *args, warmup=WARMUP, repeats=REPEATS):
    """Time a GPU function using CUDA Events. All args must be on GPU already.
    Returns (median_ms, std_ms, last_result)."""
    # Warmup
    for _ in range(warmup):
        result = fn(*args)
        torch.cuda.synchronize()

    times = []
    for _ in range(repeats):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        start_event.record()
        result = fn(*args)
        end_event.record()

        torch.cuda.synchronize()
        times.append(start_event.elapsed_time(end_event))  # ms, from CUDA

    return statistics.median(times), statistics.stdev(times) if len(times) > 1 else 0.0, result


def cpu_time_ms(fn, *args, warmup=WARMUP, repeats=REPEATS):
    """Time a CPU function. Returns (median_ms, std_ms, last_result)."""
    for _ in range(warmup):
        result = fn(*args)
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter_ns()
        result = fn(*args)
        t1 = time.perf_counter_ns()
        times.append((t1 - t0) / 1e6)
    return statistics.median(times), statistics.stdev(times) if len(times) > 1 else 0.0, result


def generate_kv_gpu(num_layers, num_heads, seq_len, head_dim, seed=42):
    """Generate a synthetic Gaussian KV cache as FP16 numpy arrays on the host,
    then copy it once to the GPU as float32 (not timed)."""
    rng = np.random.default_rng(seed)
    shape = (num_layers, num_heads, seq_len, head_dim)
    keys_np = rng.standard_normal(shape).astype(np.float16)
    values_np = rng.standard_normal(shape).astype(np.float16)
    # Move to GPU once — this is NOT timed
    keys_gpu = torch.from_numpy(keys_np).float().to(device)
    values_gpu = torch.from_numpy(values_np).float().to(device)
    return keys_np, values_np, keys_gpu, values_gpu


def transfer_ms(size_bytes, bandwidth_gbps):
    """Analytical transfer time in ms for size_bytes over a bandwidth_gbps (Gbit/s) link; ignores latency and protocol overhead."""
    return (size_bytes * 8) / (bandwidth_gbps * 1e6)


print("Timing utilities defined.")
print(f"Config: warmup={WARMUP}, repeats={REPEATS}")
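
The calibrated kernel times feed a simple transfer model: compression pays off only if compress + transfer(compressed) + decompress beats transfer(raw). The sketch below re-states the `transfer_ms` formula under a new name so it runs on its own; `end_to_end_ms`, its no-overlap assumption, and the 1.0 ms compress/decompress figures are hypothetical placeholders, not measurements:

```python
def transfer_time_ms(size_bytes: float, bandwidth_gbps: float) -> float:
    """Same formula as transfer_ms: bytes -> bits, Gbit/s -> ms."""
    return (size_bytes * 8) / (bandwidth_gbps * 1e6)

def end_to_end_ms(raw_bytes: float, ratio: float, compress_ms: float,
                  decompress_ms: float, bandwidth_gbps: float) -> float:
    """Compress, send the smaller payload, decompress (stages do not overlap)."""
    return compress_ms + transfer_time_ms(raw_bytes / ratio, bandwidth_gbps) + decompress_ms

raw = 256 * 1024 * 1024  # llama-3.1-8b, seq 2048, FP16 K+V
print(round(transfer_time_ms(raw, 10), 2))              # 214.75
print(round(end_to_end_ms(raw, 2.0, 1.0, 1.0, 10), 2))  # 109.37
```

Under this model, compression wins whenever `compress_ms + decompress_ms < transfer(raw) × (1 − 1/ratio)`, so the break-even point moves toward faster kernels as link bandwidth grows.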

## Part 3: Run GPU-Native Calibration

For each (model, seq_len):
1. Generate KV cache, pre-load to GPU
2. Time **CPU numpy** compress/decompress
3. Time **GPU-native** compress/decompress (data already on GPU, CUDA event timing)
4. Compute speedup factor

In [None]:
# Cell 6: INT8 calibration — GPU-native zero-copy timing

int8_results = []

print(f"{'Model':<16} {'Seq':>5} {'MB':>6}  {'CPU comp':>10} {'GPU comp':>10} {'Speedup':>8}  {'CPU dec':>10} {'GPU dec':>10} {'Speedup':>8}")
print("-" * 100)

for model_name, (num_layers, num_heads, head_dim) in MODELS.items():
    for seq_len in SEQ_LENS:
        keys_np, vals_np, keys_gpu, vals_gpu = generate_kv_gpu(num_layers, num_heads, seq_len, head_dim)
        cache_mb = (keys_np.nbytes + vals_np.nbytes) / (1024 * 1024)

        # CPU timing (keys only, for speed; values are symmetric). Absolute times
        # here cover half the cache, unlike the KIVI cell, which times keys + values.
        cpu_comp_ms, _, (kq_np, ks_np) = cpu_time_ms(np_int8_compress, keys_np)
        cpu_decomp_ms, _, _ = cpu_time_ms(np_int8_decompress, kq_np, ks_np)

        # GPU-native timing (data already on GPU, CUDA event timing)
        gpu_comp_ms, gpu_comp_std, (kq_gpu, ks_gpu) = gpu_time_ms(gpu_int8_compress, keys_gpu)
        gpu_decomp_ms, gpu_decomp_std, _ = gpu_time_ms(gpu_int8_decompress, kq_gpu, ks_gpu)

        comp_speedup = cpu_comp_ms / gpu_comp_ms if gpu_comp_ms > 0 else float('inf')
        decomp_speedup = cpu_decomp_ms / gpu_decomp_ms if gpu_decomp_ms > 0 else float('inf')

        entry = {
            "model": model_name, "seq_len": seq_len, "cache_mb": round(cache_mb, 1),
            "cpu_compress_ms": round(cpu_comp_ms, 3),
            "gpu_compress_ms": round(gpu_comp_ms, 3),
            "gpu_compress_std_ms": round(gpu_comp_std, 3),
            "compress_speedup": round(comp_speedup, 1),
            "cpu_decompress_ms": round(cpu_decomp_ms, 3),
            "gpu_decompress_ms": round(gpu_decomp_ms, 3),
            "gpu_decompress_std_ms": round(gpu_decomp_std, 3),
            "decompress_speedup": round(decomp_speedup, 1),
        }
        int8_results.append(entry)

        print(f"{model_name:<16} {seq_len:>5} {cache_mb:>5.0f}M  "
              f"{cpu_comp_ms:>8.2f}ms {gpu_comp_ms:>8.2f}ms {comp_speedup:>7.1f}x  "
              f"{cpu_decomp_ms:>8.2f}ms {gpu_decomp_ms:>8.2f}ms {decomp_speedup:>7.1f}x")

        # Free GPU memory
        del keys_gpu, vals_gpu, kq_gpu, ks_gpu
        torch.cuda.empty_cache()

print(f"\nCompleted {len(int8_results)} INT8 calibration runs.")

In [None]:
# Cell 7: KIVI 2-bit calibration — GPU-native zero-copy timing

kivi_results = []

print(f"{'Model':<16} {'Seq':>5} {'MB':>6}  {'CPU comp':>10} {'GPU comp':>10} {'Speedup':>8}  {'CPU dec':>10} {'GPU dec':>10} {'Speedup':>8}")
print("-" * 100)

for model_name, (num_layers, num_heads, head_dim) in MODELS.items():
    for seq_len in SEQ_LENS:
        keys_np, vals_np, keys_gpu, vals_gpu = generate_kv_gpu(num_layers, num_heads, seq_len, head_dim)
        cache_mb = (keys_np.nbytes + vals_np.nbytes) / (1024 * 1024)

        # CPU timing
        cpu_k_ms, _, (kq, ks, kz) = cpu_time_ms(np_kivi_compress_keys, keys_np)
        cpu_v_ms, _, (vq, vs, vz) = cpu_time_ms(np_kivi_compress_values, vals_np)
        cpu_comp_ms = cpu_k_ms + cpu_v_ms

        cpu_dk_ms, _, _ = cpu_time_ms(np_kivi_decompress_keys, kq, ks, kz)
        cpu_dv_ms, _, _ = cpu_time_ms(np_kivi_decompress_values, vq, vs, vz)
        cpu_decomp_ms = cpu_dk_ms + cpu_dv_ms

        # GPU-native timing (data already on GPU)
        gpu_k_ms, _, (kq_g, ks_g, kz_g) = gpu_time_ms(gpu_kivi_compress_keys, keys_gpu)
        gpu_v_ms, _, (vq_g, vs_g, vz_g) = gpu_time_ms(gpu_kivi_compress_values, vals_gpu)
        gpu_comp_ms = gpu_k_ms + gpu_v_ms

        gpu_dk_ms, _, _ = gpu_time_ms(gpu_kivi_decompress_keys, kq_g, ks_g, kz_g)
        gpu_dv_ms, _, _ = gpu_time_ms(gpu_kivi_decompress_values, vq_g, vs_g, vz_g)
        gpu_decomp_ms = gpu_dk_ms + gpu_dv_ms

        comp_speedup = cpu_comp_ms / gpu_comp_ms if gpu_comp_ms > 0 else float('inf')
        decomp_speedup = cpu_decomp_ms / gpu_decomp_ms if gpu_decomp_ms > 0 else float('inf')

        entry = {
            "model": model_name, "seq_len": seq_len, "cache_mb": round(cache_mb, 1),
            "cpu_compress_ms": round(cpu_comp_ms, 3),
            "gpu_compress_ms": round(gpu_comp_ms, 3),
            "compress_speedup": round(comp_speedup, 1),
            "cpu_decompress_ms": round(cpu_decomp_ms, 3),
            "gpu_decompress_ms": round(gpu_decomp_ms, 3),
            "decompress_speedup": round(decomp_speedup, 1),
        }
        kivi_results.append(entry)

        print(f"{model_name:<16} {seq_len:>5} {cache_mb:>5.0f}M  "
              f"{cpu_comp_ms:>8.2f}ms {gpu_comp_ms:>8.2f}ms {comp_speedup:>7.1f}x  "
              f"{cpu_decomp_ms:>8.2f}ms {gpu_decomp_ms:>8.2f}ms {decomp_speedup:>7.1f}x")

        del keys_gpu, vals_gpu, kq_g, ks_g, kz_g, vq_g, vs_g, vz_g
        torch.cuda.empty_cache()

print(f"\nCompleted {len(kivi_results)} KIVI calibration runs.")
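
KIVI's central idea is asymmetric granularity: keys are quantized per channel (statistics taken across tokens) and values per token. The minimal numpy sketch below implements that 2-bit scheme. It assumes plain min/max asymmetric quantization over the whole sequence. Real KIVI also groups tokens and keeps a recent-token window in full precision. The function names are hypothetical, and the code is not the notebook's `np_kivi_compress_keys`/`np_kivi_compress_values`. Codes are stored one per `uint8`, so the sketch shows the arithmetic rather than packed 2-bit storage.

```python
import numpy as np

def quantize_2bit(x, axis):
    """Asymmetric 2-bit quantization with min/max taken along `axis`."""
    lo = x.min(axis=axis, keepdims=True)
    hi = x.max(axis=axis, keepdims=True)
    scale = (hi - lo) / 3.0                    # 2 bits -> integer levels 0..3
    scale = np.where(scale == 0, 1.0, scale)   # guard constant slices against divide-by-zero
    q = np.clip(np.rint((x - lo) / scale), 0, 3).astype(np.uint8)
    return q, scale, lo

def dequantize_2bit(q, scale, zero):
    return q.astype(np.float32) * scale + zero

def kivi_like_compress(keys, values):
    """keys/values: (..., seq_len, head_dim). Hypothetical KIVI-style sketch."""
    k = quantize_2bit(keys, axis=-2)    # per-channel: one scale per head_dim column
    v = quantize_2bit(values, axis=-1)  # per-token: one scale per token row
    return k, v
```

Because each slice is rounded to the nearest of four levels, the reconstruction error per element is at most half that slice's `scale`.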

In [None]:
# Cell 7b: INT4 calibration (times the key tensor only; cache_mb still counts keys + values)

int4_results = []

print(f"{'Model':<16} {'Seq':>5} {'MB':>6}  {'CPU comp':>10} {'GPU comp':>10} {'Speedup':>8}  {'CPU dec':>10} {'GPU dec':>10} {'Speedup':>8}")
print("-" * 100)

for model_name, (num_layers, num_heads, head_dim) in MODELS.items():
    for seq_len in SEQ_LENS:
        keys_np, vals_np, keys_gpu, vals_gpu = generate_kv_gpu(num_layers, num_heads, seq_len, head_dim)
        cache_mb = (keys_np.nbytes + vals_np.nbytes) / (1024 * 1024)

        # CPU timing
        cpu_comp_ms, _, (kp, ks, kz, ksh) = cpu_time_ms(np_int4_compress, keys_np)
        cpu_decomp_ms, _, _ = cpu_time_ms(np_int4_decompress, kp, ks, kz, ksh)

        # GPU timing
        gpu_comp_ms, gpu_comp_std, (gp, gs, gz, gsh) = gpu_time_ms(gpu_int4_compress, keys_gpu)
        gpu_decomp_ms, gpu_decomp_std, _ = gpu_time_ms(gpu_int4_decompress, gp, gs, gz, gsh)

        comp_speedup = cpu_comp_ms / gpu_comp_ms if gpu_comp_ms > 0 else float('inf')
        decomp_speedup = cpu_decomp_ms / gpu_decomp_ms if gpu_decomp_ms > 0 else float('inf')

        entry = {
            "model": model_name, "seq_len": seq_len, "cache_mb": round(cache_mb, 1),
            "cpu_compress_ms": round(cpu_comp_ms, 3),
            "gpu_compress_ms": round(gpu_comp_ms, 3),
            "gpu_compress_std_ms": round(gpu_comp_std, 3),
            "compress_speedup": round(comp_speedup, 1),
            "cpu_decompress_ms": round(cpu_decomp_ms, 3),
            "gpu_decompress_ms": round(gpu_decomp_ms, 3),
            "gpu_decompress_std_ms": round(gpu_decomp_std, 3),
            "decompress_speedup": round(decomp_speedup, 1),
        }
        int4_results.append(entry)

        print(f"{model_name:<16} {seq_len:>5} {cache_mb:>5.0f}M  "
              f"{cpu_comp_ms:>8.2f}ms {gpu_comp_ms:>8.2f}ms {comp_speedup:>7.1f}x  "
              f"{cpu_decomp_ms:>8.2f}ms {gpu_decomp_ms:>8.2f}ms {decomp_speedup:>7.1f}x")

        del keys_gpu, vals_gpu, gp, gs, gz
        torch.cuda.empty_cache()

print(f"\nCompleted {len(int4_results)} INT4 calibration runs.")
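
The INT4 cell unpacks four outputs (`kp, ks, kz, ksh`). These read as packed codes, scale, zero point, and original shape. The minimal numpy sketch below shows one way to produce such a tuple. It assumes per-row (last-axis) asymmetric quantization with two 4-bit codes packed per byte. The function names are hypothetical. The notebook's `np_int4_compress` is not shown here and may use a different granularity.

```python
import numpy as np

def int4_compress(x):
    """Hypothetical INT4 sketch: per-row asymmetric quantization + nibble packing."""
    shape = x.shape
    rows = x.reshape(-1, shape[-1]).astype(np.float32)
    lo = rows.min(axis=1, keepdims=True)
    hi = rows.max(axis=1, keepdims=True)
    scale = (hi - lo) / 15.0                      # 4 bits -> integer levels 0..15
    scale = np.where(scale == 0, 1.0, scale)
    q = np.clip(np.rint((rows - lo) / scale), 0, 15).astype(np.uint8).ravel()
    if q.size % 2:                                # pad to an even number of codes
        q = np.append(q, np.uint8(0))
    packed = (q[0::2] | (q[1::2] << 4)).astype(np.uint8)  # low nibble, high nibble
    return packed, scale, lo, shape

def int4_decompress(packed, scale, zero, shape):
    q = np.empty(packed.size * 2, dtype=np.uint8)
    q[0::2] = packed & 0x0F                       # unpack low nibbles
    q[1::2] = packed >> 4                         # unpack high nibbles
    n = int(np.prod(shape))
    rows = q[:n].reshape(-1, shape[-1]).astype(np.float32) * scale + zero
    return rows.reshape(shape)
```

Packing two codes per byte makes the 4-bit format half the size of INT8. The per-row `scale` and `zero` add a small per-row overhead on top of that.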

In [None]:
# Cell 7c: FP8 calibration (times the key tensor only; cache_mb still counts keys + values)

fp8_results = []

print(f"{'Model':<16} {'Seq':>5} {'MB':>6}  {'CPU comp':>10} {'GPU comp':>10} {'Speedup':>8}  {'CPU dec':>10} {'GPU dec':>10} {'Speedup':>8}")
print("-" * 100)

for model_name, (num_layers, num_heads, head_dim) in MODELS.items():
    for seq_len in SEQ_LENS:
        keys_np, vals_np, keys_gpu, vals_gpu = generate_kv_gpu(num_layers, num_heads, seq_len, head_dim)
        cache_mb = (keys_np.nbytes + vals_np.nbytes) / (1024 * 1024)

        cpu_comp_ms, _, (kq, ks) = cpu_time_ms(np_fp8_compress, keys_np)
        cpu_decomp_ms, _, _ = cpu_time_ms(np_fp8_decompress, kq, ks)

        gpu_comp_ms, gpu_comp_std, (gq, gsc) = gpu_time_ms(gpu_fp8_compress, keys_gpu)
        gpu_decomp_ms, gpu_decomp_std, _ = gpu_time_ms(gpu_fp8_decompress, gq, gsc)

        comp_speedup = cpu_comp_ms / gpu_comp_ms if gpu_comp_ms > 0 else float('inf')
        decomp_speedup = cpu_decomp_ms / gpu_decomp_ms if gpu_decomp_ms > 0 else float('inf')

        entry = {
            "model": model_name, "seq_len": seq_len, "cache_mb": round(cache_mb, 1),
            "cpu_compress_ms": round(cpu_comp_ms, 3),
            "gpu_compress_ms": round(gpu_comp_ms, 3),
            "gpu_compress_std_ms": round(gpu_comp_std, 3),
            "compress_speedup": round(comp_speedup, 1),
            "cpu_decompress_ms": round(cpu_decomp_ms, 3),
            "gpu_decompress_ms": round(gpu_decomp_ms, 3),
            "gpu_decompress_std_ms": round(gpu_decomp_std, 3),
            "decompress_speedup": round(decomp_speedup, 1),
        }
        fp8_results.append(entry)

        print(f"{model_name:<16} {seq_len:>5} {cache_mb:>5.0f}M  "
              f"{cpu_comp_ms:>8.2f}ms {gpu_comp_ms:>8.2f}ms {comp_speedup:>7.1f}x  "
              f"{cpu_decomp_ms:>8.2f}ms {gpu_decomp_ms:>8.2f}ms {decomp_speedup:>7.1f}x")

        del keys_gpu, vals_gpu, gq, gsc
        torch.cuda.empty_cache()

print(f"\nCompleted {len(fp8_results)} FP8 calibration runs.")

In [None]:
# Cell 7d: CacheGen calibration (times the key tensor only; cache_mb still counts keys + values)

cachegen_results = []

print(f"{'Model':<16} {'Seq':>5} {'MB':>6}  {'CPU comp':>10} {'GPU comp':>10} {'Speedup':>8}  {'CPU dec':>10} {'GPU dec':>10} {'Speedup':>8}")
print("-" * 100)

for model_name, (num_layers, num_heads, head_dim) in MODELS.items():
    for seq_len in SEQ_LENS:
        keys_np, vals_np, keys_gpu, vals_gpu = generate_kv_gpu(num_layers, num_heads, seq_len, head_dim)
        cache_mb = (keys_np.nbytes + vals_np.nbytes) / (1024 * 1024)

        cpu_comp_ms, _, cpu_comp_out = cpu_time_ms(np_cachegen_compress, keys_np)
        cpu_decomp_ms, _, _ = cpu_time_ms(np_cachegen_decompress, *cpu_comp_out)

        gpu_comp_ms, gpu_comp_std, gpu_comp_out = gpu_time_ms(gpu_cachegen_compress, keys_gpu)
        # Unpack for decompress timing
        a_q, a_s, a_z, dp, ds, dz, osh = gpu_comp_out
        gpu_decomp_ms, gpu_decomp_std, _ = gpu_time_ms(gpu_cachegen_decompress, a_q, a_s, a_z, dp, ds, dz, osh)

        comp_speedup = cpu_comp_ms / gpu_comp_ms if gpu_comp_ms > 0 else float('inf')
        decomp_speedup = cpu_decomp_ms / gpu_decomp_ms if gpu_decomp_ms > 0 else float('inf')

        entry = {
            "model": model_name, "seq_len": seq_len, "cache_mb": round(cache_mb, 1),
            "cpu_compress_ms": round(cpu_comp_ms, 3),
            "gpu_compress_ms": round(gpu_comp_ms, 3),
            "gpu_compress_std_ms": round(gpu_comp_std, 3),
            "compress_speedup": round(comp_speedup, 1),
            "cpu_decompress_ms": round(cpu_decomp_ms, 3),
            "gpu_decompress_ms": round(gpu_decomp_ms, 3),
            "gpu_decompress_std_ms": round(gpu_decomp_std, 3),
            "decompress_speedup": round(decomp_speedup, 1),
        }
        cachegen_results.append(entry)

        print(f"{model_name:<16} {seq_len:>5} {cache_mb:>5.0f}M  "
              f"{cpu_comp_ms:>8.2f}ms {gpu_comp_ms:>8.2f}ms {comp_speedup:>7.1f}x  "
              f"{cpu_decomp_ms:>8.2f}ms {gpu_decomp_ms:>8.2f}ms {decomp_speedup:>7.1f}x")

        # Drop gpu_comp_out too: the tuple still references the compressed tensors
        del keys_gpu, vals_gpu, gpu_comp_out, a_q, a_s, a_z, dp, ds, dz
        torch.cuda.empty_cache()

print(f"\nCompleted {len(cachegen_results)} CacheGen calibration runs.")
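
CacheGen exploits token-wise locality. Nearby tokens have similar K/V values, so each token is stored as a delta from an anchor token in its chunk, and those small deltas quantize well. The seven outputs unpacked above (`a_q, a_s, a_z, dp, ds, dz, osh`) suggest quantized anchors plus packed deltas. The minimal numpy sketch below covers only the chunked anchor/delta transform, with no quantization or entropy coding. It assumes chunks along the sequence axis, with each chunk's first token as the anchor and an arbitrary chunk size of 16. The function names are hypothetical, and the code is not the notebook's `gpu_cachegen_compress`.

```python
import numpy as np

def chunked_delta_encode(x, chunk=16):
    """Split the sequence axis (-2) into chunks; keep each chunk's first token
    as an anchor and store the remaining tokens as deltas from that anchor."""
    seq = x.shape[-2]
    pad = (-seq) % chunk
    if pad:  # repeat the last token so the sequence divides evenly into chunks
        x = np.concatenate([x, np.repeat(x[..., -1:, :], pad, axis=-2)], axis=-2)
    # (..., n_chunks, chunk, head_dim)
    blocks = x.reshape(*x.shape[:-2], -1, chunk, x.shape[-1])
    anchors = blocks[..., :1, :]
    deltas = blocks - anchors            # first row of every chunk becomes 0
    return anchors, deltas, seq

def chunked_delta_decode(anchors, deltas, seq):
    blocks = deltas + anchors
    x = blocks.reshape(*blocks.shape[:-3], -1, blocks.shape[-1])
    return x[..., :seq, :]               # drop the padding tokens
```

On smoothly varying data the deltas span a much narrower range than the raw values. That narrower range is what lets a later quantization step spend fewer bits per element.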

In [None]:
# Cell 7e: Cascade (prune50+INT4) calibration

cascade_results = []

print(f"{'Model':<16} {'Seq':>5} {'MB':>6}  {'CPU comp':>10} {'GPU comp':>10} {'Speedup':>8}  {'CPU dec':>10} {'GPU dec':>10} {'Speedup':>8}")
print("-" * 100)

for model_name, (num_layers, num_heads, head_dim) in MODELS.items():
    for seq_len in SEQ_LENS:
        keys_np, vals_np, keys_gpu, vals_gpu = generate_kv_gpu(num_layers, num_heads, seq_len, head_dim)
        cache_mb = (keys_np.nbytes + vals_np.nbytes) / (1024 * 1024)

        cpu_comp_ms, _, cpu_out = cpu_time_ms(np_cascade_compress, keys_np, vals_np)
        cpu_decomp_ms, _, _ = cpu_time_ms(np_cascade_decompress, *cpu_out)

        gpu_comp_ms, gpu_comp_std, gpu_out = gpu_time_ms(gpu_cascade_compress, keys_gpu, vals_gpu)
        gpu_decomp_ms, gpu_decomp_std, _ = gpu_time_ms(gpu_cascade_decompress, *gpu_out)

        comp_speedup = cpu_comp_ms / gpu_comp_ms if gpu_comp_ms > 0 else float('inf')
        decomp_speedup = cpu_decomp_ms / gpu_decomp_ms if gpu_decomp_ms > 0 else float('inf')

        entry = {
            "model": model_name, "seq_len": seq_len, "cache_mb": round(cache_mb, 1),
            "cpu_compress_ms": round(cpu_comp_ms, 3),
            "gpu_compress_ms": round(gpu_comp_ms, 3),
            "gpu_compress_std_ms": round(gpu_comp_std, 3),
            "compress_speedup": round(comp_speedup, 1),
            "cpu_decompress_ms": round(cpu_decomp_ms, 3),
            "gpu_decompress_ms": round(gpu_decomp_ms, 3),
            "gpu_decompress_std_ms": round(gpu_decomp_std, 3),
            "decompress_speedup": round(decomp_speedup, 1),
        }
        cascade_results.append(entry)

        print(f"{model_name:<16} {seq_len:>5} {cache_mb:>5.0f}M  "
              f"{cpu_comp_ms:>8.2f}ms {gpu_comp_ms:>8.2f}ms {comp_speedup:>7.1f}x  "
              f"{cpu_decomp_ms:>8.2f}ms {gpu_decomp_ms:>8.2f}ms {decomp_speedup:>7.1f}x")

        # Drop gpu_out too so empty_cache() can actually release the compressed tensors
        del keys_gpu, vals_gpu, gpu_out
        torch.cuda.empty_cache()

print(f"\nCompleted {len(cascade_results)} Cascade calibration runs.")

In [None]:
# Cell 7f: Palu (SVD) calibration (times the key tensor only; cache_mb still counts keys + values)

palu_results = []

print(f"{'Model':<16} {'Seq':>5} {'MB':>6}  {'CPU comp':>10} {'GPU comp':>10} {'Speedup':>8}  {'CPU dec':>10} {'GPU dec':>10} {'Speedup':>8}")
print("-" * 100)

for model_name, (num_layers, num_heads, head_dim) in MODELS.items():
    for seq_len in SEQ_LENS:
        keys_np, vals_np, keys_gpu, vals_gpu = generate_kv_gpu(num_layers, num_heads, seq_len, head_dim)
        cache_mb = (keys_np.nbytes + vals_np.nbytes) / (1024 * 1024)

        cpu_comp_ms, _, (c_US, c_Vt, c_sh, c_r) = cpu_time_ms(np_palu_compress, keys_np)
        cpu_decomp_ms, _, _ = cpu_time_ms(np_palu_decompress, c_US, c_Vt, c_sh, c_r)

        gpu_comp_ms, gpu_comp_std, (g_US, g_Vt, g_sh, g_r) = gpu_time_ms(gpu_palu_compress, keys_gpu)
        gpu_decomp_ms, gpu_decomp_std, _ = gpu_time_ms(gpu_palu_decompress, g_US, g_Vt, g_sh)

        comp_speedup = cpu_comp_ms / gpu_comp_ms if gpu_comp_ms > 0 else float('inf')
        decomp_speedup = cpu_decomp_ms / gpu_decomp_ms if gpu_decomp_ms > 0 else float('inf')

        entry = {
            "model": model_name, "seq_len": seq_len, "cache_mb": round(cache_mb, 1),
            "cpu_compress_ms": round(cpu_comp_ms, 3),
            "gpu_compress_ms": round(gpu_comp_ms, 3),
            "gpu_compress_std_ms": round(gpu_comp_std, 3),
            "compress_speedup": round(comp_speedup, 1),
            "cpu_decompress_ms": round(cpu_decomp_ms, 3),
            "gpu_decompress_ms": round(gpu_decomp_ms, 3),
            "gpu_decompress_std_ms": round(gpu_decomp_std, 3),
            "decompress_speedup": round(decomp_speedup, 1),
        }
        palu_results.append(entry)

        print(f"{model_name:<16} {seq_len:>5} {cache_mb:>5.0f}M  "
              f"{cpu_comp_ms:>8.2f}ms {gpu_comp_ms:>8.2f}ms {comp_speedup:>7.1f}x  "
              f"{cpu_decomp_ms:>8.2f}ms {gpu_decomp_ms:>8.2f}ms {decomp_speedup:>7.1f}x")

        del keys_gpu, vals_gpu, g_US, g_Vt
        torch.cuda.empty_cache()

print(f"\nCompleted {len(palu_results)} Palu calibration runs.")

## Part 4: Pipeline Comparison (Sequential vs Pipelined)

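The **GPU-native kernel times** measured above are combined with a modeled network transfer to estimate end-to-end latency. The *sequential* model runs compress, transfer, and decompress back to back. The *pipelined* model splits the cache into one chunk per layer so the stages overlap. Its total is one chunk's compress + transfer + decompress time plus (n − 1) times the slowest per-chunk stage, where n is the number of layers.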
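
The two latency models in Cell 8 reduce to a few lines of arithmetic. The minimal sketch below writes them as a standalone function; the name `pipeline_latency_ms` is hypothetical. It assumes, as Cell 8 does, that each stage's work divides evenly into `n_chunks` per-layer chunks and that the three stages can run concurrently on different chunks.

```python
def pipeline_latency_ms(comp_ms, xfer_ms, decomp_ms, n_chunks):
    """Return (sequential_ms, pipelined_ms, bottleneck) for a 3-stage transfer.

    Hypothetical helper: assumes an even split into `n_chunks` chunks and that
    compress, transfer, and decompress can overlap across chunks.
    """
    sequential = comp_ms + xfer_ms + decomp_ms
    # Per-chunk stage times under an even split.
    stages = {
        "compress": comp_ms / n_chunks,
        "transfer": xfer_ms / n_chunks,
        "decompress": decomp_ms / n_chunks,
    }
    slowest = max(stages.values())
    # The first chunk pays every stage once; each later chunk completes one
    # slowest-stage interval after the previous one.
    pipelined = sum(stages.values()) + (n_chunks - 1) * slowest
    # On ties, max() keeps the first key, matching Cell 8's compress > transfer > decompress order.
    bottleneck = max(stages, key=stages.get)
    return sequential, pipelined, bottleneck

# Example: 8 ms compress, 40 ms transfer, 4 ms decompress over 32 layers
print(pipeline_latency_ms(8.0, 40.0, 4.0, 32))
```

With a single chunk the two models coincide. As `n_chunks` grows, the pipelined total approaches the bottleneck stage's full time plus one chunk of each other stage.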

In [None]:
# Cell 8: Pipeline comparison with GPU-native timing (all 7 compressors)

pipeline_results = []

# Compression ratio map (approximate, for transfer size estimation)
COMP_RATIOS = {
    "uniform_int8": 2.0,
    "kivi_2bit": 7.0,
    "uniform_int4": 4.0,
    "fp8_e4m3": 2.0,
    "cachegen": 3.8,
    "cascade_prune50_int4": 8.0,
    "palu_lr": 4.0,
}

for model_name, (num_layers, num_heads, head_dim) in MODELS.items():
    for seq_len in [512, 1024, 2048]:
        keys_np, vals_np, keys_gpu, vals_gpu = generate_kv_gpu(num_layers, num_heads, seq_len, head_dim)
        original_bytes = keys_np.nbytes + vals_np.nbytes

        for comp_name in COMP_RATIOS:
            ratio = COMP_RATIOS[comp_name]
            compressed_bytes = int(original_bytes / ratio)

            # Measure GPU compress time
            if comp_name == "uniform_int8":
                comp_ms_k, _, _ = gpu_time_ms(gpu_int8_compress, keys_gpu, warmup=3, repeats=10)
                comp_ms_v, _, _ = gpu_time_ms(gpu_int8_compress, vals_gpu, warmup=3, repeats=10)
                full_comp_ms = comp_ms_k + comp_ms_v
                kq, ks = gpu_int8_compress(keys_gpu)
                vq, vs_ = gpu_int8_compress(vals_gpu)
                dec_k, _, _ = gpu_time_ms(gpu_int8_decompress, kq, ks, warmup=3, repeats=10)
                dec_v, _, _ = gpu_time_ms(gpu_int8_decompress, vq, vs_, warmup=3, repeats=10)
                full_decomp_ms = dec_k + dec_v
                del kq, ks, vq, vs_

            elif comp_name == "kivi_2bit":
                comp_ms_k, _, _ = gpu_time_ms(gpu_kivi_compress_keys, keys_gpu, warmup=3, repeats=10)
                comp_ms_v, _, _ = gpu_time_ms(gpu_kivi_compress_values, vals_gpu, warmup=3, repeats=10)
                full_comp_ms = comp_ms_k + comp_ms_v
                kq, ks, kz = gpu_kivi_compress_keys(keys_gpu)
                vq, vs_, vz = gpu_kivi_compress_values(vals_gpu)
                dec_k, _, _ = gpu_time_ms(gpu_kivi_decompress_keys, kq, ks, kz, warmup=3, repeats=10)
                dec_v, _, _ = gpu_time_ms(gpu_kivi_decompress_values, vq, vs_, vz, warmup=3, repeats=10)
                full_decomp_ms = dec_k + dec_v
                del kq, ks, kz, vq, vs_, vz

            elif comp_name == "uniform_int4":
                comp_ms_k, _, _ = gpu_time_ms(gpu_int4_compress, keys_gpu, warmup=3, repeats=10)
                comp_ms_v, _, _ = gpu_time_ms(gpu_int4_compress, vals_gpu, warmup=3, repeats=10)
                full_comp_ms = comp_ms_k + comp_ms_v
                gp, gs, gz, gsh = gpu_int4_compress(keys_gpu)
                dec_k, _, _ = gpu_time_ms(gpu_int4_decompress, gp, gs, gz, gsh, warmup=3, repeats=10)
                gp2, gs2, gz2, gsh2 = gpu_int4_compress(vals_gpu)
                dec_v, _, _ = gpu_time_ms(gpu_int4_decompress, gp2, gs2, gz2, gsh2, warmup=3, repeats=10)
                full_decomp_ms = dec_k + dec_v
                del gp, gs, gz, gp2, gs2, gz2

            elif comp_name == "fp8_e4m3":
                comp_ms_k, _, _ = gpu_time_ms(gpu_fp8_compress, keys_gpu, warmup=3, repeats=10)
                comp_ms_v, _, _ = gpu_time_ms(gpu_fp8_compress, vals_gpu, warmup=3, repeats=10)
                full_comp_ms = comp_ms_k + comp_ms_v
                gq, gsc = gpu_fp8_compress(keys_gpu)
                dec_k, _, _ = gpu_time_ms(gpu_fp8_decompress, gq, gsc, warmup=3, repeats=10)
                gq2, gsc2 = gpu_fp8_compress(vals_gpu)
                dec_v, _, _ = gpu_time_ms(gpu_fp8_decompress, gq2, gsc2, warmup=3, repeats=10)
                full_decomp_ms = dec_k + dec_v
                del gq, gsc, gq2, gsc2

            elif comp_name == "cachegen":
                full_comp_ms_k, _, comp_out_k = gpu_time_ms(gpu_cachegen_compress, keys_gpu, warmup=3, repeats=10)
                full_comp_ms_v, _, comp_out_v = gpu_time_ms(gpu_cachegen_compress, vals_gpu, warmup=3, repeats=10)
                full_comp_ms = full_comp_ms_k + full_comp_ms_v
                dec_k, _, _ = gpu_time_ms(gpu_cachegen_decompress, *comp_out_k, warmup=3, repeats=10)
                dec_v, _, _ = gpu_time_ms(gpu_cachegen_decompress, *comp_out_v, warmup=3, repeats=10)
                full_decomp_ms = dec_k + dec_v
                del comp_out_k, comp_out_v

            elif comp_name == "cascade_prune50_int4":
                full_comp_ms, _, comp_out = gpu_time_ms(gpu_cascade_compress, keys_gpu, vals_gpu, warmup=3, repeats=10)
                full_decomp_ms, _, _ = gpu_time_ms(gpu_cascade_decompress, *comp_out, warmup=3, repeats=10)
                del comp_out

            elif comp_name == "palu_lr":
                comp_ms_k, _, (g_US, g_Vt, g_sh, g_r) = gpu_time_ms(gpu_palu_compress, keys_gpu, warmup=3, repeats=10)
                comp_ms_v, _, (g_US2, g_Vt2, g_sh2, g_r2) = gpu_time_ms(gpu_palu_compress, vals_gpu, warmup=3, repeats=10)
                full_comp_ms = comp_ms_k + comp_ms_v
                dec_k, _, _ = gpu_time_ms(gpu_palu_decompress, g_US, g_Vt, g_sh, warmup=3, repeats=10)
                dec_v, _, _ = gpu_time_ms(gpu_palu_decompress, g_US2, g_Vt2, g_sh2, warmup=3, repeats=10)
                full_decomp_ms = dec_k + dec_v
                del g_US, g_Vt, g_US2, g_Vt2

            for bw in BANDWIDTHS_GBPS:
                raw_ms = transfer_ms(original_bytes, bw)
                xfer_ms = transfer_ms(compressed_bytes, bw)

                seq_total = full_comp_ms + xfer_ms + full_decomp_ms
                seq_speedup = raw_ms / seq_total if seq_total > 0 else float('inf')

                n = num_layers
                comp_c = full_comp_ms / n
                xfer_c = xfer_ms / n
                dec_c = full_decomp_ms / n
                bottleneck = max(comp_c, xfer_c, dec_c)
                pipe_total = (comp_c + xfer_c + dec_c) + (n - 1) * bottleneck
                pipe_speedup = raw_ms / pipe_total if pipe_total > 0 else float('inf')
                saving = (seq_total - pipe_total) / seq_total * 100 if seq_total > 0 else 0
                bn_name = "compress" if comp_c >= max(xfer_c, dec_c) else ("transfer" if xfer_c >= dec_c else "decompress")

                pipeline_results.append({
                    "model": model_name, "seq_len": seq_len, "compressor": comp_name,
                    "bandwidth_gbps": bw,
                    "gpu_compress_ms": round(full_comp_ms, 4),
                    "gpu_decompress_ms": round(full_decomp_ms, 4),
                    "transfer_ms": round(xfer_ms, 4),
                    "raw_transfer_ms": round(raw_ms, 4),
                    "sequential_total_ms": round(seq_total, 4),
                    "sequential_speedup": round(seq_speedup, 4),
                    "pipelined_total_ms": round(pipe_total, 4),
                    "pipelined_speedup": round(pipe_speedup, 4),
                    "pipeline_saving_pct": round(saving, 1),
                    "bottleneck_stage": bn_name,
                })

            torch.cuda.empty_cache()

        del keys_gpu, vals_gpu
        torch.cuda.empty_cache()

print(f"Completed {len(pipeline_results)} pipeline runs.")

print(f"\n{'Compressor':<25} {'BW':>6} {'Seq Speed':>10} {'Pipe Speed':>10} {'Bottleneck':>12}")
print("-" * 70)
for e in pipeline_results:
    if e["seq_len"] == 1024 and e["model"] == "llama-3.1-8b" and e["bandwidth_gbps"] in [10, 50, 100]:
        print(f"{e['compressor']:<25} {e['bandwidth_gbps']:>5}G "
              f"{e['sequential_speedup']:>9.2f}x {e['pipelined_speedup']:>9.2f}x "
              f"{e['bottleneck_stage']:>12}")

## Part 5: Results Summary

In [None]:
# Cell 9: Summary (all 7 compressors)

print("=" * 80)
print("GPU-NATIVE CALIBRATION RESULTS (Zero-Copy, CUDA Event Timing)")
print("=" * 80)
print(f"GPU: {gpu_name}")
print(f"Timing: torch.cuda.Event (microsecond precision, no CPU\u2194GPU copies)")

ALL_CAL = [
    ("uniform_int8", int8_results),
    ("kivi_2bit", kivi_results),
    ("uniform_int4", int4_results),
    ("fp8_e4m3", fp8_results),
    ("cachegen", cachegen_results),
    ("cascade_prune50_int4", cascade_results),
    ("palu_lr", palu_results),
]

for name, results in ALL_CAL:
    comp_sp = [e["compress_speedup"] for e in results]
    dec_sp = [e["decompress_speedup"] for e in results]
    print(f"\n--- {name} ---")
    print(f"  Compress:   mean={np.mean(comp_sp):.1f}x, range=[{min(comp_sp):.1f}x, {max(comp_sp):.1f}x]")
    print(f"  Decompress: mean={np.mean(dec_sp):.1f}x, range=[{min(dec_sp):.1f}x, {max(dec_sp):.1f}x]")

# Pipeline
pipe_savings = [e["pipeline_saving_pct"] for e in pipeline_results]
print(f"\n--- Pipeline Savings (all compressors) ---")
print(f"  Mean: {np.mean(pipe_savings):.1f}%, Range: [{min(pipe_savings):.1f}%, {max(pipe_savings):.1f}%]")

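# Break-even: speedup falls as bandwidth rises, so scan from the highest bandwidth
# down and report the first (i.e. highest) one at which compression still wins.
print(f"\n--- Break-Even Bandwidth (seq=1024, llama-3.1-8b) ---")
for comp_name in COMP_RATIOS:
    entries = sorted(
        [e for e in pipeline_results if e["compressor"] == comp_name and e["model"] == "llama-3.1-8b" and e["seq_len"] == 1024],
        key=lambda e: e["bandwidth_gbps"],
        reverse=True,
    )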
    seq_be = pipe_be = "N/A"
    for e in entries:
        if e["sequential_speedup"] > 1.0 and seq_be == "N/A":
            seq_be = f"{e['bandwidth_gbps']}G"
        if e["pipelined_speedup"] > 1.0 and pipe_be == "N/A":
            pipe_be = f"{e['bandwidth_gbps']}G"
    print(f"  {comp_name}: sequential={seq_be}, pipelined={pipe_be}")
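
In the sequential model the speedup, raw / (T_comp + T_decomp + raw / r), falls monotonically as bandwidth rises. Compression therefore pays off only below a single crossover bandwidth, which can be solved in closed form. The minimal sketch below uses a hypothetical helper. It assumes 1 Gbps = 10^9 bits/s and a compressed size of `original_bytes / ratio`, as in Cell 8. `transfer_ms` is defined earlier in the notebook and not shown here, so confirm its units match before comparing against the scanned values.

```python
def sequential_breakeven_gbps(original_bytes, ratio, comp_ms, decomp_ms):
    """Bandwidth (Gbit/s) below which compress + send + decompress beats a raw send.

    Hypothetical helper. Solves  raw > comp + decomp + raw / ratio
    with raw = original_bytes * 8 / (gbps * 1e9) * 1e3  (milliseconds).
    """
    overhead_ms = comp_ms + decomp_ms
    if ratio <= 1.0:
        return 0.0            # no size reduction: compression never wins
    if overhead_ms <= 0.0:
        return float("inf")   # free compression wins at every bandwidth
    bits_saved = original_bytes * 8 * (1.0 - 1.0 / ratio)
    # bits_saved / (gbps * 1e9) * 1e3 > overhead_ms  <=>  gbps < this value
    return bits_saved * 1e3 / (overhead_ms * 1e9)
```

The pipelined model has no single formula of this kind because its bottleneck stage switches with bandwidth. Scanning `BANDWIDTHS_GBPS`, as in the cell above, remains the simpler route there.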

In [None]:
# Cell 10: Publication figures
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams.update({"font.size": 12, "figure.dpi": 150})

# Figure 1: CPU vs GPU compress time (log scale)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for idx, (results, title) in enumerate([(int8_results, "Uniform INT8"), (kivi_results, "KIVI 2-bit")]):
    ax = axes[idx]
    for model_name in MODELS:
        data = [e for e in results if e["model"] == model_name]
        sls = [e["seq_len"] for e in data]
        cpu = [e["cpu_compress_ms"] for e in data]
        gpu = [e["gpu_compress_ms"] for e in data]
        ax.plot(sls, cpu, "--o", label=f"{model_name} (CPU)", alpha=0.7)
        ax.plot(sls, gpu, "-s", label=f"{model_name} (GPU)", alpha=0.7)
    ax.set_xlabel("Sequence Length")
    ax.set_ylabel("Compress Time (ms)")
    ax.set_title(f"{title}: CPU vs CUDA (zero-copy)")
    ax.legend(fontsize=7)
    ax.set_xscale("log", base=2)
    ax.set_yscale("log")
    ax.grid(True, alpha=0.3)

plt.suptitle(f"GPU-Native Compression Timing — {gpu_name}", fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig("fig_cpu_vs_gpu_timing_v2.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Cell 11: Figure 2 — Speedup curves

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
target_model = "llama-3.1-8b"
target_seq = 1024

for idx, comp_name in enumerate(["uniform_int8", "kivi_2bit"]):
    ax = axes[idx]
    data = [e for e in pipeline_results
            if e["compressor"] == comp_name and e["model"] == target_model and e["seq_len"] == target_seq]
    bws = [e["bandwidth_gbps"] for e in data]
    seq_s = [e["sequential_speedup"] for e in data]
    pipe_s = [e["pipelined_speedup"] for e in data]

    ax.semilogx(bws, seq_s, "-o", label="Sequential", linewidth=2, markersize=8)
    ax.semilogx(bws, pipe_s, "-s", label="Pipelined", linewidth=2, markersize=8)
    ax.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="Break-even")
    ax.set_xlabel("Bandwidth (Gbps)")
    ax.set_ylabel("Speedup vs Raw Transfer")
    ax.set_title(f"{comp_name} ({target_model}, seq={target_seq})")
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_ylim(bottom=0)

plt.suptitle(f"GPU-Calibrated Speedup (Zero-Copy) — {gpu_name}", fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig("fig_gpu_speedup_curves_v2.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Cell 12: Figure 3 — Speedup comparison bar chart

fig, ax = plt.subplots(figsize=(14, 6))
bw_targets = [1, 5, 10, 25, 50, 100]
bar_data = {}

for comp_name in ["uniform_int8", "kivi_2bit"]:
    for bw in bw_targets:
        key = f"{comp_name}\n{bw}G"
        pipe_s = [e["pipelined_speedup"] for e in pipeline_results
                  if e["compressor"] == comp_name and e["bandwidth_gbps"] == bw and e["seq_len"] == 1024]
        seq_s = [e["sequential_speedup"] for e in pipeline_results
                 if e["compressor"] == comp_name and e["bandwidth_gbps"] == bw and e["seq_len"] == 1024]
        if pipe_s:
            bar_data[key] = {"sequential": np.mean(seq_s), "pipelined": np.mean(pipe_s)}

x = np.arange(len(bar_data))
w = 0.35
seq_v = [v["sequential"] for v in bar_data.values()]
pipe_v = [v["pipelined"] for v in bar_data.values()]

b1 = ax.bar(x - w/2, seq_v, w, label="Sequential", alpha=0.8, color="steelblue")
b2 = ax.bar(x + w/2, pipe_v, w, label="Pipelined", alpha=0.8, color="darkorange")

ax.set_ylabel("Speedup vs Raw Transfer")
ax.set_title(f"GPU-Native Speedup (seq=1024, avg across models) — {gpu_name}")
ax.set_xticks(x)
ax.set_xticklabels(bar_data.keys(), fontsize=9)
ax.legend()
ax.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, linewidth=2)
ax.grid(True, alpha=0.3, axis="y")

# Label every bar (sequential and pipelined) with its speedup value
for bar in (*b1, *b2):
    if bar.get_height() > 0.05:
        ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.02,
                f"{bar.get_height():.1f}x", ha="center", va="bottom", fontsize=8)

plt.tight_layout()
plt.savefig("fig_gpu_speedup_bars_v2.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Cell 13: LaTeX tables (all 7 compressors)

ALL_CAL_LATEX = [
    ("int8_calibration", "uniform\\_int8", int8_results),
    ("kivi_calibration", "kivi\\_2bit", kivi_results),
    ("int4_calibration", "uniform\\_int4", int4_results),
    ("fp8_calibration", "fp8\\_e4m3", fp8_results),
    ("cachegen_calibration", "cachegen", cachegen_results),
    ("cascade_calibration", "cascade\\_prune50\\_int4", cascade_results),
    ("palu_calibration", "palu\\_lr", palu_results),
]

print("% === GPU Calibration Table ===")
print("\\begin{table*}[t]")
print("\\centering")
print("\\caption{GPU-native compression timing (zero-copy, CUDA events) on " + gpu_name.replace('_', '\\_') + "}")
print("\\label{tab:gpu_calibration}")
print("\\begin{tabular}{llrrrr}")
print("\\toprule")
print("Compressor & Model & Seq Len & CPU (ms) & GPU (ms) & Speedup \\\\")
print("\\midrule")
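for i, (_, cname, results) in enumerate(ALL_CAL_LATEX):
    for e in results:
        if e["seq_len"] in [512, 1024, 2048]:
            m = e['model'].replace('_', '\\_')
            # One decimal so small speedups (e.g. Palu's 2-5x) are not rounded away
            print(f"  {cname} & {m} & {e['seq_len']} & {e['cpu_compress_ms']:.2f} & {e['gpu_compress_ms']:.3f} & {e['compress_speedup']:.1f}$\\times$ \\\\")
    if i < len(ALL_CAL_LATEX) - 1:  # avoid a \midrule directly above \bottomrule
        print("\\midrule")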
print("\\bottomrule")
print("\\end{tabular}")
print("\\end{table*}")

print("\n% === Break-Even Table ===")
print("\\begin{table}[t]")
print("\\centering")
print("\\caption{GPU-calibrated break-even bandwidth (highest tested Gbps where speedup $> 1$, seq $=1024$)}")
print("\\label{tab:gpu_breakeven}")
print("\\begin{tabular}{llcc}")
print("\\toprule")
print("Strategy & Model & Sequential & Pipelined \\\\")
print("\\midrule")
for cn in COMP_RATIOS:
    for mn in MODELS:
        # Descending scan: speedup shrinks as bandwidth grows, so the first hit is the break-even
        entries = sorted([e for e in pipeline_results if e["compressor"]==cn and e["model"]==mn and e["seq_len"]==1024], key=lambda e: e["bandwidth_gbps"], reverse=True)
        sbe = pbe = "N/A"
        for e in entries:
            if e["sequential_speedup"] > 1 and sbe == "N/A": sbe = f"{e['bandwidth_gbps']}G"
            if e["pipelined_speedup"] > 1 and pbe == "N/A": pbe = f"{e['bandwidth_gbps']}G"
        c = cn.replace('_','\\_')
        m = mn.replace('_','\\_')
        print(f"  {c} & {m} & {sbe} & {pbe} \\\\")
print("\\bottomrule")
print("\\end{tabular}")
print("\\end{table}")

In [None]:
# Cell 14: Save all results (all 7 compressors)

ALL_CAL_SAVE = {
    "int8": int8_results,
    "kivi": kivi_results,
    "int4": int4_results,
    "fp8": fp8_results,
    "cachegen": cachegen_results,
    "cascade": cascade_results,
    "palu": palu_results,
}

# Build summary stats for each compressor
summary_dict = {}
for short_name, results in ALL_CAL_SAVE.items():
    comp_sp = [e["compress_speedup"] for e in results]
    dec_sp = [e["decompress_speedup"] for e in results]
    summary_dict[f"{short_name}_compress_speedup_mean"] = round(float(np.mean(comp_sp)), 1)
    summary_dict[f"{short_name}_compress_speedup_range"] = [round(float(min(comp_sp)), 1), round(float(max(comp_sp)), 1)]
    summary_dict[f"{short_name}_decompress_speedup_mean"] = round(float(np.mean(dec_sp)), 1)
    summary_dict[f"{short_name}_decompress_speedup_range"] = [round(float(min(dec_sp)), 1), round(float(max(dec_sp)), 1)]

pipe_savings = [e["pipeline_saving_pct"] for e in pipeline_results]
summary_dict["pipeline_saving_mean_pct"] = round(float(np.mean(pipe_savings)), 1)
summary_dict["pipeline_saving_range_pct"] = [round(float(min(pipe_savings)), 1), round(float(max(pipe_savings)), 1)]

all_results = {
    "metadata": {
        "gpu": gpu_name,
        "gpu_tier": GPU_TIER,
        "gpu_info": gpu_info,
        "pytorch_version": torch.__version__,
        "cuda_version": torch.version.cuda if torch.cuda.is_available() else "N/A",
        "timing_method": "torch.cuda.Event (zero-copy, GPU-native)",
        "fp8_path": FP8_PATH,
        "version": "v3_all_compressors",
        "models": {k: {"layers": v[0], "kv_heads": v[1], "head_dim": v[2]} for k, v in MODELS.items()},
        "seq_lens": SEQ_LENS,
        "warmup": WARMUP,
        "repeats": REPEATS,
    },
    "int8_calibration": int8_results,
    "kivi_calibration": kivi_results,
    "int4_calibration": int4_results,
    "fp8_calibration": fp8_results,
    "cachegen_calibration": cachegen_results,
    "cascade_calibration": cascade_results,
    "palu_calibration": palu_results,
    "pipeline_comparison": pipeline_results,
    "summary": summary_dict,
}

out_filename = f"gpu_calibration_results_{GPU_TIER}.json"
with open(out_filename, "w") as f:
    json.dump(all_results, f, indent=2)

print(f"Saved: {out_filename}")
print(f"  GPU: {gpu_name} (tier={GPU_TIER})")
print(f"  GPU info: {json.dumps(gpu_info, indent=2)}")
print(f"  FP8 path: {FP8_PATH}")
print(f"  - {len(int8_results)} INT8 + {len(kivi_results)} KIVI + {len(int4_results)} INT4")
print(f"  - {len(fp8_results)} FP8 + {len(cachegen_results)} CacheGen")
print(f"  - {len(cascade_results)} Cascade + {len(palu_results)} Palu")
print(f"  - {len(pipeline_results)} pipeline entries")
print(f"\nDone! Download {out_filename} for your paper.")

## What Changed from v2

| Aspect | v2 (2 compressors) | v3 (all 7 compressors) |
|--------|-------------------|----------------------|
| Compressors | INT8, KIVI 2-bit | + INT4, FP8, CacheGen, Cascade, Palu |
| Pipeline comparison | 2 strategies | 7 strategies |
| JSON sections | int8, kivi calibration | + int4, fp8, cachegen, cascade, palu |
| Paper Table 1 | 2 GPU rows filled | All 7 GPU rows filled |
| Runtime | ~5-10 min | ~15-25 min (SVD is compute-heavy) |

### Notes
- **Palu (SVD)**: Expect only a modest GPU speedup (roughly 2-5x), since CPU SVD is already parallelized through LAPACK.
- **Cascade**: The timing includes the `torch.topk` selection cost, as a real deployment would.
- **CacheGen**: The most complex compressor; its chunked delta encoding is fully vectorized on the GPU.
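
The Palu calibration cell unpacks `(US, Vt, shape, rank)` from the compressor, which has the shape of a truncated-SVD factorization. The minimal numpy sketch below shows that kind of low-rank compression. It assumes the cache is flattened to a (tokens × head_dim) matrix and the rank is a quarter of `head_dim`. The function names and the rank choice are hypothetical, and the code is not the notebook's `np_palu_compress`.

```python
import numpy as np

def lowrank_compress(x, rank_fraction=0.25):
    """Hypothetical truncated-SVD sketch: flatten to (tokens, head_dim), keep the top r."""
    shape = x.shape
    mat = x.reshape(-1, shape[-1])
    r = max(1, int(shape[-1] * rank_fraction))
    U, S, Vt = np.linalg.svd(mat, full_matrices=False)
    US = U[:, :r] * S[:r]          # fold the singular values into the left factor
    return US, Vt[:r], shape, r

def lowrank_decompress(US, Vt, shape):
    return (US @ Vt).reshape(shape)
```

Storing `US` (tokens × r) and `Vt` (r × head_dim) instead of the full matrix gives a ratio of about head_dim / r once tokens ≫ head_dim. That is about 4 at `rank_fraction=0.25`, in line with the 4.0 assumed for `palu_lr` in `COMP_RATIOS`. The SVD itself is the expensive step, which is consistent with the longer runtime noted above.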

## Multi-GPU Comparison

After running this notebook on multiple GPU tiers (T4, A100, H100), combine the results
using the integration script:

```bash
# Integrate T4 baseline with additional GPU results
python experiments/scripts/integrate_gpu_calibration.py \
    experiments/results/model_sweep/results.json \
    experiments/notebooks/gpu_calibration_results_t4.json \
    --gpu-results experiments/notebooks/gpu_calibration_results_a100.json \
                  experiments/notebooks/gpu_calibration_results_h100.json
```

This produces multi-GPU comparison figures in `paper/figures/gpu_calibrated/` including:
- **Break-even bandwidth shift** across GPU tiers (faster GPUs → higher break-even bandwidth)
- **Speedup comparison** bars per compressor per GPU
- **Pipeline saving** differences across GPU tiers

The visualization functions `plot_multi_gpu_speedup()` and `plot_breakeven_shift()` in
`kvshuttle/visualization/gpu_calibration.py` handle the rendering automatically.
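
To inspect the break-even shift across tiers without the plotting script, a small sketch can read the `pipeline_comparison` rows that Cell 14 saves and report, per file, the highest tested bandwidth at which a compressor still beats raw transfer. Speedup falls as bandwidth rises, which is why the sketch takes the highest winning bandwidth. The helper names are hypothetical; only the JSON keys come from this notebook.

```python
import json
from pathlib import Path

def breakeven_gbps(rows, compressor, model, seq_len, key="pipelined_speedup"):
    """Highest tested bandwidth at which `compressor` still beats raw transfer."""
    wins = [r["bandwidth_gbps"] for r in rows
            if r["compressor"] == compressor and r["model"] == model
            and r["seq_len"] == seq_len and r[key] > 1.0]
    return max(wins) if wins else None

def breakeven_by_tier(paths, compressor, model="llama-3.1-8b", seq_len=1024):
    """Map gpu_tier -> break-even Gbps for each saved gpu_calibration_results_*.json."""
    out = {}
    for path in paths:
        data = json.loads(Path(path).read_text())
        out[data["metadata"]["gpu_tier"]] = breakeven_gbps(
            data["pipeline_comparison"], compressor, model, seq_len)
    return out
```

For example, `breakeven_by_tier(["gpu_calibration_results_t4.json", "gpu_calibration_results_a100.json"], "kivi_2bit")` returns one entry per tier. Pass `key="sequential_speedup"` to get the sequential break-even instead.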