In [1]:
import torch 
import time

In [4]:
# Benchmark device and tensor shape (B, T, H, D); dim 1 (T) is the axis that
# holo_op accumulates over.
device = torch.device("cuda")
B, T, H, D = 128, 2048, 16, 64

# Inputs: three independent complex64 tensors of shape (B, T, H, D).
# The generator is consumed left to right, so the draws happen in the order
# v, k, q.
v, k, q = (
    torch.randn(B, T, H, D, dtype=torch.complex64, device=device)
    for _ in range(3)
)

def holo_op(v, k, q):
    """Running sum of ``v * k`` along dim 1, read out with ``q`` and rescaled.

    Computes ``Re(cumsum(v * k, dim=1) * q)`` and divides the entry at
    1-based position ``t`` along dim 1 by ``sqrt(t)``. The scale is shaped
    ``(1, T, 1, 1)``, i.e. laid out to broadcast against 4-D inputs.
    """
    # This is the chain of elementwise ops + scan that we would like fused.
    bound = v * k
    running = bound.cumsum(dim=1)

    # Per-position normalizer: sqrt(1), sqrt(2), ..., sqrt(T).
    seq_len = running.shape[1]
    positions = torch.arange(
        1, seq_len + 1, device=running.device, dtype=torch.float32
    )
    norm = positions.sqrt().view(1, seq_len, 1, 1)

    readout = running * q
    return readout.real / norm

# 2. Compile it
# fullgraph=True requires the whole function to be captured as a single graph
# (compilation errors out instead of silently falling back to Python).
# mode="max-autotune" requests the most aggressive tuning torch.compile offers.
compiled_holo = torch.compile(holo_op, mode="max-autotune", fullgraph=True)

# Single source of truth for the loop bounds. The per-call time below is
# derived from N_ITERS, so changing the iteration count cannot silently
# corrupt the reported numbers (the previous hard-coded "* 10" was only
# correct for exactly 100 iterations).
N_WARMUP = 5
N_ITERS = 100

# --- Warmup ---
print("Warming up...")
for _ in range(N_WARMUP):
    _ = holo_op(v, k, q)        # Eager
    _ = compiled_holo(v, k, q)  # Compiled

# --- Benchmark Eager ---
# Synchronize before starting and before stopping the clock so that all
# queued GPU work is included in (and only in) the timed region.
torch.cuda.synchronize()
start = time.perf_counter()  # monotonic, high-resolution clock
for _ in range(N_ITERS):
    _ = holo_op(v, k, q)
torch.cuda.synchronize()
eager_time = (time.perf_counter() - start) * 1000 / N_ITERS  # ms per call

# --- Benchmark Compile ---
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(N_ITERS):
    _ = compiled_holo(v, k, q)
torch.cuda.synchronize()
compiled_time = (time.perf_counter() - start) * 1000 / N_ITERS  # ms per call

print(f"Eager PyTorch:   {eager_time:.3f} ms")
print(f"torch.compile:   {compiled_time:.3f} ms")
print(f"Speedup:         {eager_time / compiled_time:.2f}x")

Warming up...
Eager PyTorch:   23.697 ms
torch.compile:   38.974 ms
Speedup:         0.61x


In [7]:
import torch
import torch.nn as nn
import time
import triton
import triton.language as tl

# ==========================================
# 1. Fixed Triton Kernel
# ==========================================
@triton.jit
def _holo_scan_fused_kernel(
    v_ptr, k_ptr, q_ptr, out_ptr,
    # Strides along (b, t, h, d) for each buffer, in units of the element type
    # the corresponding pointer addresses.
    stride_v_b, stride_v_t, stride_v_h, stride_v_d,
    stride_k_b, stride_k_t, stride_k_h, stride_k_d,
    stride_q_b, stride_q_t, stride_q_h, stride_q_d,
    stride_out_b, stride_out_t, stride_out_h, stride_out_d,
    T, H, D,
    BLOCK_D: tl.constexpr
):
    """Fused complex scan with real-valued readout.

    For every (b, h) pair and every step t in [0, T):

        acc[t]       = acc[t-1] + v[b, t, h, :] * k[b, t, h, :]   (complex)
        out[b,t,h,:] = Re(acc[t] * q[b, t, h, :]) / sqrt(t + 1)

    One program instance handles one (b, h) pair (program id = b * H + h) and
    walks the T axis sequentially, carrying the running sum in two float32
    accumulators (real and imaginary parts).

    v/k/q are addressed as interleaved (real, imag) pairs: the imaginary part
    of an element is loaded from the slot immediately after its real part.
    ``out`` holds one real value per element.

    Lanes with ``offs_d >= D`` are masked off; only the first BLOCK_D entries
    of the d axis are visited, so BLOCK_D is expected to be >= D.
    """
    # 1. Parallelize over (Batch * Head).
    # The program id is widened to int64 so that every offset derived from it
    # (b_idx * stride_b, ...) is computed in 64-bit arithmetic; with int32 the
    # products wrap once a tensor exceeds 2**31 addressable elements.
    pid = tl.program_id(0).to(tl.int64)

    # 2. Decompose the flattened program id into batch and head indices.
    b_idx = pid // H
    h_idx = pid % H

    # 3. Base pointers of this (batch, head) sequence:
    #    base = ptr + b * stride_b + h * stride_h
    # Using the strides directly means the (B, T, H, D) layout is read as-is.
    v_base = v_ptr + b_idx * stride_v_b + h_idx * stride_v_h
    k_base = k_ptr + b_idx * stride_k_b + h_idx * stride_k_h
    q_base = q_ptr + b_idx * stride_q_b + h_idx * stride_q_h
    out_base = out_ptr + b_idx * stride_out_b + h_idx * stride_out_h

    # 4. Lanes along the d axis.
    offs_d = tl.arange(0, BLOCK_D)
    mask_d = offs_d < D

    # 5. Running complex sum, kept as separate real / imaginary accumulators.
    acc_real = tl.zeros([BLOCK_D], dtype=tl.float32)
    acc_imag = tl.zeros([BLOCK_D], dtype=tl.float32)

    # 6. Sequential loop over time.
    for t in range(T):
        # 64-bit time index so t * stride_t cannot overflow either.
        t_row = t.to(tl.int64)

        # offset = t * stride_t + d * stride_d (relative to the base pointers)
        off_v = t_row * stride_v_t + offs_d * stride_v_d
        off_k = t_row * stride_k_t + offs_d * stride_k_d
        off_q = t_row * stride_q_t + offs_d * stride_q_d

        # Load real parts at the offset and imaginary parts one slot later.
        v_real = tl.load(v_base + off_v, mask=mask_d, other=0.0)
        v_imag = tl.load(v_base + off_v + 1, mask=mask_d, other=0.0)
        k_real = tl.load(k_base + off_k, mask=mask_d, other=0.0)
        k_imag = tl.load(k_base + off_k + 1, mask=mask_d, other=0.0)
        q_real = tl.load(q_base + off_q, mask=mask_d, other=0.0)
        q_imag = tl.load(q_base + off_q + 1, mask=mask_d, other=0.0)

        # Complex product v * k.
        term_real = v_real * k_real - v_imag * k_imag
        term_imag = v_real * k_imag + v_imag * k_real

        # Accumulate (inclusive scan).
        acc_real += term_real
        acc_imag += term_imag

        # Real part of acc * q; the imaginary part is never needed.
        out_val = acc_real * q_real - acc_imag * q_imag

        # Scale by 1 / sqrt(t + 1) (1-based position).
        scale = tl.sqrt((t + 1).to(tl.float32))
        out_val = out_val / scale

        # Store one real value per d lane.
        off_out = t_row * stride_out_t + offs_d * stride_out_d
        tl.store(out_base + off_out, out_val, mask=mask_d)


# ==========================================
# 2. Wrapper (No Copies!)
# ==========================================

def holo_op_triton(v, k, q):
    """Launch the fused Triton scan kernel on complex64 inputs.

    Computes the same quantity as the eager reference:
    ``Re(cumsum(v * k, dim=1) * q) / sqrt(t)`` with ``t`` the 1-based
    position along dim 1.

    Args:
        v, k, q: complex64 tensors of identical 4-D shape ``(B, T, H, D)`` on
            the same device. They are not permuted; a copy is only made for an
            input whose last dimension is not unit-stride (or whose conjugate
            bit is set).

    Returns:
        A float32 tensor of shape ``(B, T, H, D)`` on ``v``'s device.

    Raises:
        TypeError: if any input is not ``torch.complex64`` (the kernel
            reinterprets the storage as pairs of float32 values).
        ValueError: if ``v`` is not 4-D, or ``k``/``q`` differ from ``v`` in
            shape or device.
    """
    # --- Validation: the kernel trusts these invariants and would otherwise
    # read out of bounds or silently misinterpret the bytes. ---
    for arg_name, x in (("v", v), ("k", k), ("q", q)):
        if x.dtype != torch.complex64:
            raise TypeError(
                f"{arg_name} must be torch.complex64, got {x.dtype}"
            )
    if v.dim() != 4:
        raise ValueError(
            f"v must be 4-D (B, T, H, D), got shape {tuple(v.shape)}"
        )
    for arg_name, x in (("k", k), ("q", q)):
        if x.shape != v.shape:
            raise ValueError(
                f"{arg_name} has shape {tuple(x.shape)}, expected {tuple(v.shape)}"
            )
        if x.device != v.device:
            raise ValueError(
                f"{arg_name} is on {x.device}, expected {v.device}"
            )

    B, T, H, D = v.shape

    # Output holds one real number per element, so its native strides apply.
    out = torch.empty((B, T, H, D), device=v.device, dtype=torch.float32)
    if out.numel() == 0:
        # Nothing to compute; also avoids a zero-sized kernel launch.
        return out

    def as_float_view(x):
        """Return (float32 view of x, strides of x expressed in float32 units)."""
        # A lazily conjugated tensor keeps the un-conjugated values in storage;
        # materialize it so the raw float view matches the logical values.
        x = x.resolve_conj()
        # .view(torch.float32) needs a unit-stride last dimension.
        if x.stride(-1) != 1:
            x = x.contiguous()
        # Memory layout is [R0, I0, R1, I1, ...]: one complex64 element spans
        # two float32 slots, so every complex stride doubles when the kernel
        # does pointer arithmetic on float32 (including the last one: R0 -> R1
        # is a jump of 2 floats).
        return x.view(torch.float32), tuple(2 * s for s in x.stride())

    v_f, s_v = as_float_view(v)
    k_f, s_k = as_float_view(k)
    q_f, s_q = as_float_view(q)
    s_out = out.stride()

    # One program per (batch, head) pair; each covers the whole D axis.
    grid = (B * H, )
    BLOCK_D = triton.next_power_of_2(D)

    _holo_scan_fused_kernel[grid](
        v_f, k_f, q_f, out,
        *s_v,
        *s_k,
        *s_q,
        *s_out,
        T=T, H=H, D=D,
        BLOCK_D=BLOCK_D
    )

    return out

In [8]:
# ==========================================
# 1. Implementations
# ==========================================

# --- A. Eager Implementation ---
def holo_op_eager(v, k, q):
    """Unfused PyTorch reference for the scan.

    Returns ``Re(cumsum(v * k, dim=1) * q)`` with the entry at 1-based
    position ``t`` along dim 1 divided by ``sqrt(t)``. The scale has shape
    ``(1, T, 1, 1)`` and broadcasts against 4-D inputs.
    """
    # 1. Bind & accumulate. The full running-sum tensor is materialized here,
    #    which is the step the fused kernel is meant to avoid.
    mem = (v * k).cumsum(dim=1)

    # 2. Scaling factors sqrt(1) ... sqrt(T), shaped for broadcasting.
    n_steps = mem.size(1)
    steps = torch.arange(1, n_steps + 1, device=mem.device, dtype=torch.float32)
    scale = steps.sqrt().view(1, n_steps, 1, 1)

    # 3. Retrieve: keep only the real part of the readout.
    retrieved = mem * q
    return retrieved.real / scale

In [9]:
# ==========================================
# 2. Benchmark Setup
# ==========================================

device = torch.device("cuda")
torch.set_float32_matmul_precision('high')

# Problem size for this comparison.
B, T, H, D = 64, 2048, 16, 64
# NOTE: each complex64 input is B*T*H*D * 8 bytes = 1 GiB, i.e. about 3 GiB
# for v, k and q together (outputs and intermediates come on top of that).

print(f"Benchmarking with B={B}, T={T}, H={H}, D={D}")
print("Generating inputs...")

# Drawn in the order v, k, q.
v, k, q = (
    torch.randn(B, T, H, D, dtype=torch.complex64, device=device)
    for _ in range(3)
)

# Compiled variant of the eager reference.
print("Compiling torch.compile version...")
holo_op_compiled = torch.compile(holo_op_eager, mode="max-autotune", fullgraph=True)

# The first call triggers compilation. If it fails, report why and retry with
# torch.compile's default settings so the benchmark can still run.
try:
    _ = holo_op_compiled(v, k, q)
except Exception as e:
    print(f"Compilation failed (common with complex64 sometimes): {e}")
    holo_op_compiled = torch.compile(holo_op_eager)
    _ = holo_op_compiled(v, k, q)

# Sanity check: Triton kernel vs eager reference.
print("Verifying correctness...")
out_eager = holo_op_eager(v, k, q)
out_triton = holo_op_triton(v, k, q)

# The kernel accumulates sequentially in float32 while torch.cumsum may sum in
# a different order, so small discrepancies are expected. The check is a loose
# bound on the maximum absolute difference.
diff = (out_eager - out_triton).abs().max()
print(f"Max Difference Eager vs Triton: {diff:.6f}")
assert diff < 1e-2, "Triton implementation mismatch!"

# ==========================================
# 3. Run Benchmark
# ==========================================

def run_bench(name, func, iters=100, warmup=5, args=None):
    """Time ``func(*args)`` and report the mean latency per call.

    Args:
        name: Label printed next to the result (left-aligned to 15 columns).
        func: Callable to benchmark.
        iters: Number of timed calls; must be positive.
        warmup: Number of untimed calls made before the clock starts.
        args: Positional arguments passed to ``func``. Defaults to the
            module-level benchmark inputs ``(v, k, q)``.

    Returns:
        Mean wall time per call in milliseconds (float).

    Raises:
        ValueError: if ``iters`` is not positive (the average would be
            undefined).
    """
    if iters <= 0:
        raise ValueError(f"iters must be a positive integer, got {iters}")
    if args is None:
        args = (v, k, q)  # notebook-level inputs, kept as the default

    # GPU work is asynchronous: drain the queue around the timed region.
    # Skipped when CUDA is unavailable so CPU callables can be timed too.
    if torch.cuda.is_available():
        sync = torch.cuda.synchronize
    else:
        sync = lambda: None

    # Warmup (compilation, autotuning, allocator growth) stays off the clock.
    for _ in range(warmup):
        _ = func(*args)
    sync()

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    for _ in range(iters):
        _ = func(*args)
    sync()

    total_time_ms = (time.perf_counter() - start) * 1000
    avg_time_ms = total_time_ms / iters
    print(f"{name:<15}: {avg_time_ms:.3f} ms")
    return avg_time_ms

print("\n--- Starting Benchmark ---")
# (label, implementation) pairs, timed in this order.
_bench_cases = (
    ("Eager PyTorch", holo_op_eager),
    ("Torch Compile", holo_op_compiled),
    ("Triton Kernel", holo_op_triton),
)
t_eager, t_compile, t_triton = [
    run_bench(label, impl) for label, impl in _bench_cases
]

# Speedups are relative to the eager baseline (> 1 means faster than eager).
print("\n--- Results ---")
print(f"Compile Speedup vs Eager: {t_eager / t_compile:.2f}x")
print(f"Triton Speedup vs Eager:  {t_eager / t_triton:.2f}x")

Benchmarking with B=64, T=2048, H=16, D=64
Generating inputs...
Compiling torch.compile version...
Verifying correctness...
Max Difference Eager vs Triton: 0.000003

--- Starting Benchmark ---
Eager PyTorch  : 11.820 ms
Torch Compile  : 19.416 ms
Triton Kernel  : 4.770 ms

--- Results ---
Compile Speedup vs Eager: 0.61x
Triton Speedup vs Eager:  2.48x
