# Persistent Megakernel for Qwen3-0.6B

A single CUDA kernel that fuses the entire Qwen3-0.6B decode forward pass.
Targeting **1,200+ tok/s on A100** and **2,600+ tok/s on H100**.

**Architecture**: Non-cooperative persistent kernel with:
- Atomic barriers (not cooperative `grid.sync()`)
- Productive spin: idle blocks prefetch weights during attention
- Redundant RMSNorm to eliminate barriers
- On-device argmax sampling (no CPU readback)
- 128-bit vectorized weight loads with L1 bypass

**References**:
- [MegaQwen](https://elliotarledge.com/blog/megaqwen) (530 tok/s, RTX 3090)
- [1k tok/s kernel](https://blog.alpindale.net/posts/5090_decode_optimization/) (RTX 5090)
- [Hazy Research Megakernels](https://hazyresearch.stanford.edu/blog/2025-05-27-no-bubbles) (<1ms, H100)

## Cell 1: Setup & GPU Detection

In [None]:
# Install dependencies
!pip install -q transformers torch safetensors accelerate

# Check GPU
!nvidia-smi

import torch
import os
import subprocess
import time
import json
import numpy as np

# GPU detection
device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(device)
cc = props.major * 10 + props.minor

GPU_NAME = props.name
SM_COUNT = props.multi_processor_count
TOTAL_MEM_GB = getattr(props, 'total_global_memory', getattr(props, 'total_mem', 0)) / 1e9

if cc >= 90:
    ARCH = "sm_90"
    GPU_CLASS = "H100"
    PEAK_BW = 3350.0
elif cc >= 80:
    ARCH = "sm_80"
    GPU_CLASS = "A100"
    PEAK_BW = 2039.0 if TOTAL_MEM_GB > 50 else 1555.0
elif cc >= 75:
    ARCH = "sm_75"
    GPU_CLASS = "T4"
    PEAK_BW = 300.0
else:
    ARCH = f"sm_{cc}"
    GPU_CLASS = "Unknown"
    PEAK_BW = 200.0

# Actual model size: 751M params * 2 bytes (BF16) = ~1.5 GB
MODEL_BYTES_BF16 = 1503e6
THEORETICAL_MAX = PEAK_BW * 1e9 / MODEL_BYTES_BF16

print(f"GPU: {GPU_NAME}")
print(f"Class: {GPU_CLASS} ({ARCH})")
print(f"SMs: {SM_COUNT}")
print(f"Memory: {TOTAL_MEM_GB:.1f} GB")
print(f"Peak Bandwidth: {PEAK_BW:.0f} GB/s")
print(f"Theoretical max tok/s (Qwen3-0.6B BF16): {THEORETICAL_MAX:.0f}")

## Cell 2: Download & Convert Weights

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "Qwen/Qwen3-0.6B"
CACHE_DIR = "./model_cache"

print(f"Downloading {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="cpu",  # load to CPU first, we'll manage GPU memory
    cache_dir=CACHE_DIR
)
print(f"Model loaded. Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

# Verify the architecture matches what the weight packer and kernel are written for.
cfg = model.config
assert cfg.hidden_size == 1024, f"Expected hidden_size=1024, got {cfg.hidden_size}"
assert cfg.num_hidden_layers == 28, f"Expected 28 layers, got {cfg.num_hidden_layers}"
assert cfg.num_attention_heads == 16, f"Expected 16 Q heads, got {cfg.num_attention_heads}"
assert cfg.num_key_value_heads == 8, f"Expected 8 KV heads, got {cfg.num_key_value_heads}"
# head_dim is set explicitly (128) and is NOT hidden_size / num_attention_heads (64);
# Q_DIM and KV_DIM in the packing cell are derived from it.
assert cfg.head_dim == 128, f"Expected head_dim=128, got {cfg.head_dim}"

# RoPE base: use `rope_theta` if the config has it, otherwise look in the
# `rope_parameters` dict. That attribute can exist but be None, so check before .get().
rope_theta = getattr(cfg, 'rope_theta', None)
if rope_theta is None and getattr(cfg, 'rope_parameters', None) is not None:
    rope_theta = cfg.rope_parameters.get('rope_theta', 1000000)

print("\nArchitecture verified!")
print(f"  Hidden dim: {cfg.hidden_size}")
print(f"  Layers: {cfg.num_hidden_layers}")
print(f"  Q heads: {cfg.num_attention_heads}, KV heads: {cfg.num_key_value_heads}")
print(f"  Head dim: {cfg.head_dim}")
print(f"  Intermediate size: {cfg.intermediate_size}")
print(f"  Vocab size: {cfg.vocab_size}")
print(f"  RoPE theta: {rope_theta}")

In [None]:
# Convert weights to flat binary format
import struct

# Model dimensions come from the loaded config, never hardcoded.
# Trailing comments give the values expected for Qwen3-0.6B.
HIDDEN_DIM = cfg.hidden_size              # 1024
NUM_LAYERS = cfg.num_hidden_layers        # 28
NUM_Q_HEADS = cfg.num_attention_heads     # 16
NUM_KV_HEADS = cfg.num_key_value_heads    # 8
HEAD_DIM = cfg.head_dim                   # 128 (explicit; not HIDDEN_DIM // NUM_Q_HEADS)
INTERMEDIATE_DIM = cfg.intermediate_size  # 3072
VOCAB_SIZE = cfg.vocab_size               # 151936

# Attention projection widths: with these dims Q is twice the hidden size,
# while K/V (8 heads shared by 16 Q heads) match it.
Q_DIM = NUM_Q_HEADS * HEAD_DIM      # 2048
KV_DIM = NUM_KV_HEADS * HEAD_DIM    # 1024

print(f"Dimensions: HEAD_DIM={HEAD_DIM}, Q_DIM={Q_DIM}, KV_DIM={KV_DIM}, INTERMEDIATE_DIM={INTERMEDIATE_DIM}")

# Packing state mutated by add(): byte chunks in file order, the starting byte
# offset of each named tensor, and the running write position.
state = model.state_dict()
parts, offsets, pos = [], {}, 0

def add(name, tensor):
    """Append ``tensor`` to the packed weight blob as raw BF16 bytes.

    Records the tensor's starting byte offset under ``name``. Mutates the module-level
    packing state: ``parts`` (byte chunks in file order), ``offsets`` (name -> byte
    offset) and ``pos`` (running byte count).

    The tensor is detached, copied to CPU, cast to bfloat16 and made contiguous
    (row-major) before serialisation. This means the cell also works when the model
    has already been moved to the GPU, e.g. when re-running it after the validation
    cell.
    """
    global pos
    t = tensor.detach().to(device="cpu", dtype=torch.bfloat16).contiguous()
    # Reinterpret the 2-byte bf16 storage as int16, a dtype numpy supports natively.
    # The bytes are identical to a uint16 view.
    data = t.view(torch.int16).numpy().tobytes()
    offsets[name] = pos
    parts.append(data)
    pos += len(data)

print("Packing weights...")

# weights.bin stores exactly one vocab-sized matrix (the embedding) and no separate
# lm_head. It is therefore only complete for checkpoints whose output projection is
# tied to the input embedding. Fail loudly rather than write an incomplete file.
if not getattr(cfg, "tie_word_embeddings", False):
    raise ValueError(
        "Model does not tie word embeddings: weights.bin has no lm_head entry, "
        "so the packed file would be incomplete."
    )

# Embedding
add("embedding", state["model.embed_tokens.weight"])
print(f"  embedding: {state['model.embed_tokens.weight'].shape}")

# Final norm
add("final_norm", state["model.norm.weight"])

# Per-layer weights as (packed suffix, HF state_dict suffix), in the exact order they
# are written to the file for every layer.
_LAYER_WEIGHTS = (
    ("attn_norm", "input_layernorm.weight"),
    ("w_q", "self_attn.q_proj.weight"),
    ("w_k", "self_attn.k_proj.weight"),
    ("w_v", "self_attn.v_proj.weight"),
    ("q_norm", "self_attn.q_norm.weight"),
    ("k_norm", "self_attn.k_norm.weight"),
    ("w_o", "self_attn.o_proj.weight"),
    ("ffn_norm", "post_attention_layernorm.weight"),
    ("w_gate", "mlp.gate_proj.weight"),
    ("w_up", "mlp.up_proj.weight"),
    ("w_down", "mlp.down_proj.weight"),
)
for i in range(NUM_LAYERS):
    for short_name, hf_name in _LAYER_WEIGHTS:
        add(f"layer.{i}.{short_name}", state[f"model.layers.{i}.{hf_name}"])

print(f"  {NUM_LAYERS} layers packed ({len(_LAYER_WEIGHTS)} weights each)")

# Save binary. writelines() streams the chunks, avoiding the ~1.2 GB temporary that
# b"".join(parts) would allocate.
weights_path = "weights.bin"
with open(weights_path, "wb") as f:
    f.writelines(parts)

offsets_path = "weights_offsets.json"
with open(offsets_path, "w") as f:
    json.dump(offsets, f, indent=2)

total_mb = pos / 1e6
print(f"\nSaved {total_mb:.1f} MB to {weights_path}")
print(f"Offsets saved to {offsets_path}")
print(f"Total offsets: {len(offsets)}")

## Cell 3: Compile Kernel

In [None]:
# Check if source files exist
import os
if not os.path.exists("src/megakernel.cu"):
    print("Source files not found! Please clone the repo first:")
    print("  !git clone https://github.com/MaruthiV/megakernel.git")
    print("  Then cd into the repo directory.")
else:
    print("Source files found.")

# Compile standalone binary (for info/testing)
print(f"\nCompiling standalone for {GPU_CLASS} ({ARCH})...")
result = subprocess.run(
    ["nvcc", "-O3", f"-arch={ARCH}", "-std=c++17", "--use_fast_math",
     "-lineinfo", "-Isrc", "src/megakernel.cu", "-o", "megakernel"],
    capture_output=True, text=True
)

if result.returncode == 0:
    print("Standalone binary: OK")
else:
    print(f"Standalone compilation FAILED!\n{result.stderr}")

# Compile shared library (for Python ctypes bridge)
print(f"Compiling shared library for {GPU_CLASS} ({ARCH})...")
result_lib = subprocess.run(
    ["nvcc", "-O3", f"-arch={ARCH}", "-std=c++17", "--use_fast_math",
     "-shared", "-Xcompiler", "-fPIC", "-DMEGAKERNEL_LIBRARY_MODE",
     "-lineinfo", "-Isrc", "src/megakernel.cu", "-o", "megakernel.so"],
    capture_output=True, text=True
)

if result_lib.returncode == 0:
    print("Shared library: OK")
else:
    print(f"Library compilation FAILED!\n{result_lib.stderr}")

# Run standalone info
if result.returncode == 0:
    print("\n--- Kernel Info ---")
    !./megakernel

## Cell 4: Correctness Validation

In [None]:
# Run PyTorch reference forward pass
# nn.Module.to() moves parameters in place, so `model_gpu` is the same object as
# `model` (which is therefore also on the GPU after this cell).
model_gpu = model.to("cuda")
model_gpu.eval()

# Reference: generate 50 tokens with greedy decoding
prompt = "The meaning of life is"
enc = tokenizer(prompt, return_tensors="pt").to("cuda")
input_ids = enc.input_ids
# Pass the mask explicitly. generate() warns when none is given and it cannot infer one.
attention_mask = enc.attention_mask

print(f"Prompt: '{prompt}'")
print(f"Input IDs: {input_ids[0].tolist()[:10]}...")

with torch.no_grad():
    ref_output = model_gpu.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=50,
        do_sample=False,
    )

ref_text = tokenizer.decode(ref_output[0], skip_special_tokens=True)
ref_tokens = ref_output[0].tolist()
# For a decoder-only model the returned sequence is prompt + continuation.
n_generated = len(ref_tokens) - input_ids.shape[1]
print(f"\nReference output ({len(ref_tokens)} tokens, {n_generated} generated):")
print(f"  {ref_text[:300]}")

# Get reference logits for first token
with torch.no_grad():
    ref_logits = model_gpu(input_ids, attention_mask=attention_mask).logits[0, -1, :]  # [VOCAB_SIZE]
    ref_next_token = ref_logits.argmax().item()
    print(f"\nReference next token: {ref_next_token} = '{tokenizer.decode([ref_next_token])}'")
    print(f"Top-5 logits: {ref_logits.topk(5)}")

# TODO: Compare with megakernel output once the kernel produces correct results
print("\n[Megakernel validation will be added once kernel is verified]")

## Cell 5: HuggingFace Baseline Throughput

In [None]:
# Measure HuggingFace baseline throughput
print("Measuring HuggingFace baseline throughput...")


def _hf_decode_tok_per_sec(prefix_ids, n_warmup, n_iters):
    """Return HF greedy-decode throughput (tok/s) at a context of ~len(prefix_ids).

    The prefix is run once (untimed) with the KV cache enabled. Then
    n_warmup + n_iters single-token decode steps are issued, each feeding back the
    argmax token and the updated cache. Only the last n_iters steps are timed.
    The argmax stays on the GPU, so no host readback happens inside the loop.
    """
    with torch.no_grad():
        out = model_gpu(prefix_ids, use_cache=True)
        past = out.past_key_values
        next_tok = out.logits[:, -1:].argmax(dim=-1)  # shape [1, 1]
        start = time.perf_counter()
        for step in range(n_warmup + n_iters):
            if step == n_warmup:
                torch.cuda.synchronize()
                start = time.perf_counter()
            out = model_gpu(next_tok, past_key_values=past, use_cache=True)
            past = out.past_key_values
            next_tok = out.logits[:, -1:].argmax(dim=-1)
        torch.cuda.synchronize()
    return n_iters / (time.perf_counter() - start)


# Short context: decode positions ~1..60 after a one-token prefix
input_ids = torch.tensor([[1]], device="cuda")
N = 50
hf_toks = _hf_decode_tok_per_sec(input_ids, n_warmup=10, n_iters=N)

print(f"\nHuggingFace baseline (short context): {hf_toks:.1f} tok/s")
print(f"Per-token latency: {1e6/hf_toks:.1f} us")

# Long context: prefill 1024 random tokens once, then time decode steps at ~1024 context.
# (Timing repeated full 1024-token forwards would measure prefill passes, not decode.)
long_input = torch.randint(0, cfg.vocab_size, (1, 1024), device="cuda")
N_long = 20
hf_long_toks = _hf_decode_tok_per_sec(long_input, n_warmup=5, n_iters=N_long)

print(f"HuggingFace baseline (1024 context): {hf_long_toks:.1f} tok/s")

HF_BASELINE = hf_toks
print(f"\nBaseline to beat: {HF_BASELINE:.1f} tok/s")
print(f"Target: {THEORETICAL_MAX * 0.71:.0f}+ tok/s (71% bandwidth utilization)")

## Cell 6: Megakernel Benchmarks

Once the kernel is fully working, this cell runs the full benchmark suite.

In [None]:
# Initialize megakernel engine
import sys
if "." not in sys.path:  # avoid stacking duplicate entries on cell re-runs
    sys.path.insert(0, ".")
from host.bridge import MegakernelEngine

engine = MegakernelEngine(
    weights_path="weights.bin",
    offsets_path="weights_offsets.json",
    lib_path="./megakernel.so",
    tokenizer=tokenizer,
)

# Benchmark sweep across context positions
print("=" * 70)
print("  MEGAKERNEL BENCHMARK")
print("=" * 70)
print(f"  GPU:             {GPU_NAME}")
print(f"  Peak BW:         {PEAK_BW:.0f} GB/s")
print(f"  Theoretical max: {THEORETICAL_MAX:.0f} tok/s")
print(f"  HF baseline:     {HF_BASELINE:.1f} tok/s")
print("=" * 70)
print()

results = engine.benchmark_sweep()

print(f"{'Position':>10} {'Tok/s':>10} {'Latency(us)':>12} {'BW(GB/s)':>10} {'BW Util%':>10} {'vs HF':>8}")
print("-" * 62)
for r in results:
    bw_util = r["bw_gbps"] / PEAK_BW * 100
    speedup = r["tok_per_sec"] / HF_BASELINE if HF_BASELINE > 0 else 0
    print(f"{r['position']:>10} {r['tok_per_sec']:>10.1f} {r['latency_us']:>12.1f} "
          f"{r['bw_gbps']:>10.1f} {bw_util:>9.1f}% {speedup:>7.1f}x")

# Peak result (max() raises on an empty sequence, so report that case instead)
print()
if not results:
    print("  No results returned by engine.benchmark_sweep().")
else:
    best = max(results, key=lambda r: r["tok_per_sec"])
    best_util = best["bw_gbps"] / PEAK_BW * 100
    best_speedup = best["tok_per_sec"] / HF_BASELINE if HF_BASELINE > 0 else 0
    print(f"  PEAK: {best['tok_per_sec']:.0f} tok/s "
          f"({best_util:.1f}% BW, {best_speedup:.1f}x vs HuggingFace)")

## Cell 7: Text Generation Demo

In [None]:
# Interactive text generation with the megakernel
# An empty/blank prompt presumably tokenizes to nothing, leaving the engine no input;
# fall back to the reference prompt used in the validation cell.
DEFAULT_PROMPT = "The meaning of life is"
prompt = input("Enter your prompt: ")
if not prompt.strip():
    prompt = DEFAULT_PROMPT

print(f"\nPrompt: {prompt}")
print("Generating...\n")

text, tok_per_sec = engine.generate(prompt, max_tokens=200, print_stream=True)

print(f"\n{'='*50}")
print(f"Throughput: {tok_per_sec:.1f} tok/s")
if tok_per_sec > 0:
    print(f"Latency:   {1e6/tok_per_sec:.0f} us/token")

    bw_gbps = MODEL_BYTES_BF16 * tok_per_sec / 1e9
    bw_util = bw_gbps / PEAK_BW * 100
    print(f"Bandwidth: {bw_gbps:.1f} GB/s ({bw_util:.1f}% utilization)")
    if HF_BASELINE > 0:
        print(f"Speedup:   {tok_per_sec/HF_BASELINE:.1f}x vs HuggingFace")
else:
    # Latency/bandwidth are undefined without a positive throughput.
    print("Latency:   n/a (no throughput measured)")

## Cell 8: Profiling (Optional)

In [None]:
# Optional profiling step; nothing below runs unless uncommented.
# Needs Nsight Systems / Nsight Compute installed (available on Colab).
# Both commands target the standalone ./megakernel binary built in the compile cell.
#
# Timeline and summary statistics (Nsight Systems):
# !nsys profile --stats=true --output=megakernel_profile ./megakernel
# !nsys stats megakernel_profile.nsys-rep
#
# Detailed metrics for a single kernel launch (Nsight Compute):
# !ncu --set full --launch-count 1 ./megakernel

print("Profiling is optional. Uncomment to run:")

## Summary

| Metric | HuggingFace | MegaQwen | Our Kernel | Theoretical Max |
|--------|-------------|----------|------------|------------------|
| Tok/s (short ctx) | ~136 | ~530 | TBD | See above |
| Tok/s (long ctx) | ~59 | ~158 | TBD | - |
| BW Utilization | ~5% | ~5% | Target 71%+ | 100% |
| GPU | RTX 3090 | RTX 3090 | A100/H100 | - |