# Chunked Delta Rule - Triton Kernel

**The Problem:** True Delta Rule is sequential - each token needs state from previous token.

**The Solution:** Chunk approximation with refinement:
1. Pass 0: All tokens in chunk use chunk-start state (parallel)
2. Pass 1: Refine using Pass 0 outputs (parallel)

**Expected speedup:** 5-15x over pure Python sequential, while maintaining accuracy.

In [1]:
# Environment check
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
print(f"GPU: {torch.cuda.get_device_name(0)}")

import triton
print(f"Triton: {triton.__version__}")

PyTorch: 2.11.0.dev20260128+cu128
CUDA: 12.8
GPU: NVIDIA GeForce RTX 4050 Laptop GPU
Triton: 3.6.0


In [2]:
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import time

# =============================================================================
# REFERENCE: Sequential Delta Rule (correct but slow)
# =============================================================================

def sequential_delta_rule(k, v, beta, g, initial_state=None):
    """Pure Python sequential - baseline for correctness."""
    B, T, H, K = k.shape
    V = v.shape[-1]
    device, dtype = k.device, k.dtype
    
    state = torch.zeros(B, H, K, V, device=device, dtype=dtype) if initial_state is None else initial_state.clone()
    
    outputs = []
    for t in range(T):
        k_t, v_t = k[:, t], v[:, t]
        beta_t, g_t = beta[:, t], g[:, t]
        
        pred = torch.einsum('bhkv,bhk->bhv', state, k_t)
        error = v_t - pred
        update = torch.einsum('bhv,bhk->bhkv', error, k_t)
        state = g_t[..., None, None] * state + beta_t[..., None, None] * update
        outputs.append(torch.einsum('bhkv,bhk->bhv', state, k_t))
    
    return torch.stack(outputs, dim=1), state

print("Reference implementation loaded.")

Reference implementation loaded.


In [3]:
# =============================================================================
# TRITON KERNEL: Chunked Delta Rule
# =============================================================================
#
# Strategy: 
#   - Parallelize across (B, H) - each thread handles one batch/head
#   - Process T sequentially but with fused GPU operations
#   - This is still O(T) but constant factors are much better
#
# The sequential T dependency is fundamental to Delta Rule correctness.
# Chunked approximation would break error correction.
# =============================================================================

@triton.jit
def delta_rule_fwd_kernel(
    K, V, Beta, G, State_in,
    Out, State_out,
    stride_k_b, stride_k_t, stride_k_h, stride_k_k,
    stride_v_b, stride_v_t, stride_v_h, stride_v_v,
    stride_bg_b, stride_bg_t, stride_bg_h,
    stride_s_b, stride_s_h, stride_s_k, stride_s_v,
    stride_o_b, stride_o_t, stride_o_h, stride_o_v,
    T, K_DIM: tl.constexpr, V_DIM: tl.constexpr,
):
    """
    Delta Rule forward: one (batch, head) per program.
    
    State update: S = g*S + β*(v - S@k)⊗k
    """
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    
    # Load initial state [K_DIM, V_DIM]
    k_offs = tl.arange(0, K_DIM)
    v_offs = tl.arange(0, V_DIM)
    
    state_ptrs = State_in + pid_b * stride_s_b + pid_h * stride_s_h + \
                 k_offs[:, None] * stride_s_k + v_offs[None, :] * stride_s_v
    state = tl.load(state_ptrs).to(tl.float32)
    
    # Process each token
    for t in range(T):
        # Load k[t] [K_DIM]
        k_ptrs = K + pid_b * stride_k_b + t * stride_k_t + pid_h * stride_k_h + k_offs * stride_k_k
        k_t = tl.load(k_ptrs).to(tl.float32)
        
        # Load v[t] [V_DIM]
        v_ptrs = V + pid_b * stride_v_b + t * stride_v_t + pid_h * stride_v_h + v_offs * stride_v_v
        v_t = tl.load(v_ptrs).to(tl.float32)
        
        # Load scalars
        beta_t = tl.load(Beta + pid_b * stride_bg_b + t * stride_bg_t + pid_h * stride_bg_h).to(tl.float32)
        g_t = tl.load(G + pid_b * stride_bg_b + t * stride_bg_t + pid_h * stride_bg_h).to(tl.float32)
        
        # Prediction: sum_k state[k,v] * k_t[k] for each v
        # pred[v] = sum over k of state[k,v] * k_t[k]
        pred = tl.sum(state * k_t[:, None], axis=0)  # [V_DIM]
        
        # Error
        error = v_t - pred  # [V_DIM]
        
        # Outer product: error[v] * k_t[k] -> [K_DIM, V_DIM]
        outer = k_t[:, None] * error[None, :]  # [K_DIM, V_DIM]
        
        # Update state
        state = g_t * state + beta_t * outer
        
        # Output: retrieve from updated state
        out_t = tl.sum(state * k_t[:, None], axis=0)  # [V_DIM]
        
        # Store output
        out_ptrs = Out + pid_b * stride_o_b + t * stride_o_t + pid_h * stride_o_h + v_offs * stride_o_v
        tl.store(out_ptrs, out_t)
    
    # Store final state
    tl.store(state_ptrs.to(State_out.dtype.element_ty).to(tl.pointer_type(tl.float32)) - State_in + State_out, state)


def triton_delta_rule(k, v, beta, g, initial_state=None):
    """
    Triton-accelerated Delta Rule.
    
    Args:
        k: [B, T, H, K] - normalized keys
        v: [B, T, H, V] - values
        beta: [B, T, H] - write gate
        g: [B, T, H] - forget gate
        initial_state: [B, H, K, V] or None
    
    Returns:
        output: [B, T, H, V]
        final_state: [B, H, K, V]
    """
    B, T, H, K_dim = k.shape
    V_dim = v.shape[-1]
    device = k.device
    
    # Ensure contiguous float32
    k = k.contiguous().float()
    v = v.contiguous().float()
    beta = beta.contiguous().float()
    g = g.contiguous().float()
    
    if initial_state is None:
        state_in = torch.zeros(B, H, K_dim, V_dim, device=device, dtype=torch.float32)
    else:
        state_in = initial_state.contiguous().float()
    
    out = torch.empty(B, T, H, V_dim, device=device, dtype=torch.float32)
    state_out = torch.empty_like(state_in)
    
    # Launch kernel: one program per (batch, head)
    grid = (B, H)
    
    delta_rule_fwd_kernel[grid](
        k, v, beta, g, state_in,
        out, state_out,
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        beta.stride(0), beta.stride(1), beta.stride(2),
        state_in.stride(0), state_in.stride(1), state_in.stride(2), state_in.stride(3),
        out.stride(0), out.stride(1), out.stride(2), out.stride(3),
        T, K_dim, V_dim,
    )
    
    return out, state_out

print("Triton kernel loaded.")

Triton kernel loaded.


In [4]:
# =============================================================================
# SIMPLER TRITON KERNEL (if above has issues)
# =============================================================================

@triton.jit  
def delta_simple_kernel(
    K, V, Beta, G, State,
    Out,
    B, T, H, K_DIM: tl.constexpr, V_DIM: tl.constexpr,
):
    """Simplified kernel - might be more compatible."""
    pid = tl.program_id(0)  # Linear index for (b, h)
    b = pid // H
    h = pid % H
    
    # State is [B, H, K, V], row-major
    state_base = b * H * K_DIM * V_DIM + h * K_DIM * V_DIM
    
    # Load state into registers
    state = tl.zeros((K_DIM, V_DIM), dtype=tl.float32)
    for ki in range(K_DIM):
        for vi in range(V_DIM):
            state[ki, vi] = tl.load(State + state_base + ki * V_DIM + vi)
    
    # Process tokens
    for t in range(T):
        # Base offsets
        k_base = b * T * H * K_DIM + t * H * K_DIM + h * K_DIM
        v_base = b * T * H * V_DIM + t * H * V_DIM + h * V_DIM
        bg_base = b * T * H + t * H + h
        
        # Load k, v, beta, g
        k_t = tl.zeros((K_DIM,), dtype=tl.float32)
        v_t = tl.zeros((V_DIM,), dtype=tl.float32)
        for ki in range(K_DIM):
            k_t[ki] = tl.load(K + k_base + ki)
        for vi in range(V_DIM):
            v_t[vi] = tl.load(V + v_base + vi)
        beta_t = tl.load(Beta + bg_base)
        g_t = tl.load(G + bg_base)
        
        # Prediction
        pred = tl.zeros((V_DIM,), dtype=tl.float32)
        for ki in range(K_DIM):
            for vi in range(V_DIM):
                pred[vi] += state[ki, vi] * k_t[ki]
        
        # Error and update
        error = v_t - pred
        for ki in range(K_DIM):
            for vi in range(V_DIM):
                state[ki, vi] = g_t * state[ki, vi] + beta_t * error[vi] * k_t[ki]
        
        # Output
        out_t = tl.zeros((V_DIM,), dtype=tl.float32)
        for ki in range(K_DIM):
            for vi in range(V_DIM):
                out_t[vi] += state[ki, vi] * k_t[ki]
        
        out_base = b * T * H * V_DIM + t * H * V_DIM + h * V_DIM
        for vi in range(V_DIM):
            tl.store(Out + out_base + vi, out_t[vi])
    
    # Store final state
    for ki in range(K_DIM):
        for vi in range(V_DIM):
            tl.store(State + state_base + ki * V_DIM + vi, state[ki, vi])


def triton_delta_simple(k, v, beta, g, initial_state=None):
    B, T, H, K_dim = k.shape
    V_dim = v.shape[-1]
    device = k.device
    
    k = k.contiguous().float()
    v = v.contiguous().float()
    beta = beta.contiguous().float()
    g = g.contiguous().float()
    
    if initial_state is None:
        state = torch.zeros(B, H, K_dim, V_dim, device=device, dtype=torch.float32)
    else:
        state = initial_state.clone().contiguous().float()
    
    out = torch.empty(B, T, H, V_dim, device=device, dtype=torch.float32)
    
    grid = (B * H,)
    delta_simple_kernel[grid](
        k, v, beta, g, state, out,
        B, T, H, K_dim, V_dim,
    )
    
    return out, state

print("Simple Triton kernel loaded.")

Simple Triton kernel loaded.


In [5]:
# =============================================================================
# TEST CORRECTNESS
# =============================================================================

print("=" * 60)
print("CORRECTNESS TEST")
print("=" * 60)

B, T, H, K, V = 2, 64, 4, 32, 64
device = "cuda"

k = F.normalize(torch.randn(B, T, H, K, device=device), dim=-1)
v = torch.randn(B, T, H, V, device=device)
beta = torch.sigmoid(torch.randn(B, T, H, device=device) - 2)  # Low β
g = torch.sigmoid(torch.randn(B, T, H, device=device) + 2)    # High g

# Reference
out_ref, state_ref = sequential_delta_rule(k, v, beta, g)

# Triton
try:
    out_tri, state_tri = triton_delta_simple(k, v, beta, g)
    
    out_err = (out_tri - out_ref.float()).norm() / out_ref.float().norm()
    state_err = (state_tri - state_ref.float()).norm() / state_ref.float().norm()
    
    print(f"Output error:  {out_err.item():.6f}")
    print(f"State error:   {state_err.item():.6f}")
    print(f"→ {'✓ PASS' if out_err < 0.01 else '✗ FAIL'}")
except Exception as e:
    print(f"Triton error: {e}")
    import traceback
    traceback.print_exc()

CORRECTNESS TEST
Triton error: at 18:12:
    pid = tl.program_id(0)  # Linear index for (b, h)
    b = pid // H
    h = pid % H

    # State is [B, H, K, V], row-major
    state_base = b * H * K_DIM * V_DIM + h * K_DIM * V_DIM

    # Load state into registers
    state = tl.zeros((K_DIM, V_DIM), dtype=tl.float32)
    for ki in range(K_DIM):
        for vi in range(V_DIM):
            state[ki, vi] = tl.load(State + state_base + ki * V_DIM + vi)
            ^
NotImplementedError('__setitem__ is not supported in triton')


Traceback (most recent call last):
  File "/tmp/ipykernel_9581/1194476838.py", line 22, in <module>
    out_tri, state_tri = triton_delta_simple(k, v, beta, g)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_9581/1713007584.py", line 88, in triton_delta_simple
    delta_simple_kernel[grid](
  File "/home/m_tes/groundthink/gt-v6/.venv/lib/python3.12/site-packages/triton/runtime/jit.py", line 370, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/m_tes/groundthink/gt-v6/.venv/lib/python3.12/site-packages/triton/runtime/jit.py", line 720, in run
    kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/m_tes/groundthink/gt-v6/.venv/lib/python3.12/site-packages/triton/runtime/

In [6]:
# =============================================================================
# BENCHMARK
# =============================================================================

print("\n" + "=" * 60)
print("SPEED BENCHMARK")
print("=" * 60)

configs = [
    (4, 64, 8, 32, 64),
    (8, 128, 8, 32, 64),
    (8, 256, 8, 32, 64),
    (16, 128, 8, 32, 64),
]

n_warmup = 3
n_runs = 10

print(f"{'Config':>25} | {'PyTorch':>10} | {'Triton':>10} | {'Speedup':>8}")
print("-" * 60)

for B, T, H, K, V in configs:
    k = F.normalize(torch.randn(B, T, H, K, device='cuda'), dim=-1)
    v = torch.randn(B, T, H, V, device='cuda')
    beta = torch.sigmoid(torch.randn(B, T, H, device='cuda') - 2)
    g = torch.sigmoid(torch.randn(B, T, H, device='cuda') + 2)
    
    # Warmup
    for _ in range(n_warmup):
        sequential_delta_rule(k, v, beta, g)
        torch.cuda.synchronize()
    
    # PyTorch
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_runs):
        sequential_delta_rule(k, v, beta, g)
    torch.cuda.synchronize()
    pytorch_ms = (time.perf_counter() - start) / n_runs * 1000
    
    # Triton
    try:
        for _ in range(n_warmup):
            triton_delta_simple(k, v, beta, g)
            torch.cuda.synchronize()
        
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_runs):
            triton_delta_simple(k, v, beta, g)
        torch.cuda.synchronize()
        triton_ms = (time.perf_counter() - start) / n_runs * 1000
        
        speedup = pytorch_ms / triton_ms
        triton_str = f"{triton_ms:.2f} ms"
        speedup_str = f"{speedup:.2f}x"
    except Exception as e:
        triton_str = "ERROR"
        speedup_str = "-"
    
    config_str = f"B={B},T={T},H={H},K={K},V={V}"
    print(f"{config_str:>25} | {pytorch_ms:>8.2f}ms | {triton_str:>10} | {speedup_str:>8}")


SPEED BENCHMARK
                   Config |    PyTorch |     Triton |  Speedup
------------------------------------------------------------
   B=4,T=64,H=8,K=32,V=64 |    13.42ms |      ERROR |        -
  B=8,T=128,H=8,K=32,V=64 |    39.96ms |      ERROR |        -
  B=8,T=256,H=8,K=32,V=64 |    77.16ms |      ERROR |        -
 B=16,T=128,H=8,K=32,V=64 |    45.92ms |      ERROR |        -


In [7]:
# =============================================================================
# DELTA RULE VALIDATION (confirm error correction works)
# =============================================================================

print("\n" + "=" * 60)
print("DELTA RULE VALIDATION")
print("=" * 60)

B, H, K, V = 1, 4, 32, 64
state = torch.zeros(B, H, K, V, device='cuda')

k = F.normalize(torch.randn(B, H, K, device='cuda'), dim=-1)
v = torch.randn(B, H, V, device='cuda')

# First write
pred1 = torch.einsum('bhkv,bhk->bhv', state, k)
error1 = v - pred1
state = state + torch.einsum('bhv,bhk->bhkv', error1, k)
norm1 = state.norm().item()

# Second write (SAME k, v)
pred2 = torch.einsum('bhkv,bhk->bhv', state, k)
error2 = v - pred2
state = state + torch.einsum('bhv,bhk->bhkv', error2, k)
norm2 = state.norm().item()

print(f"Error1: {error1.norm().item():.4f}")
print(f"Error2: {error2.norm().item():.6f} (should be ~0)")
print(f"State growth: {norm2/norm1:.4f}x (should be ~1.0)")
print(f"→ {'✓ PASS' if error2.norm().item() < 0.001 else '✗ FAIL'}")


DELTA RULE VALIDATION
Error1: 16.2738
Error2: 0.000002 (should be ~0)
State growth: 1.0000x (should be ~1.0)
→ ✓ PASS


In [8]:
# =============================================================================
# INTEGRATION: GatedDeltaNetLayer using Triton
# =============================================================================

import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    def forward(self, x):
        return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


class GatedDeltaNetLayer(nn.Module):
    """GDN using Triton Delta Rule kernel."""
    
    def __init__(self, d_model, n_heads, head_dim, value_dim):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.value_dim = value_dim
        
        self.q_proj = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, n_heads * value_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * value_dim, d_model, bias=False)
        
        self.beta_proj = nn.Linear(d_model, n_heads, bias=True)
        nn.init.constant_(self.beta_proj.bias, -2.0)
        
        self.g_proj = nn.Linear(d_model, n_heads, bias=True)
        nn.init.constant_(self.g_proj.bias, 2.0)
        
        self.norm = RMSNorm(d_model)
        
    def forward(self, x, initial_state=None):
        B, T, D = x.shape
        H, K, V = self.n_heads, self.head_dim, self.value_dim
        
        x_norm = self.norm(x)
        
        q = self.q_proj(x_norm).view(B, T, H, K)
        k = self.k_proj(x_norm).view(B, T, H, K)
        v = self.v_proj(x_norm).view(B, T, H, V)
        
        k = F.normalize(k.float(), p=2, dim=-1).to(x.dtype)
        
        beta = torch.sigmoid(self.beta_proj(x_norm))
        g = torch.sigmoid(self.g_proj(x_norm))
        
        # Use Triton kernel
        out, state = triton_delta_simple(k, v, beta, g, initial_state)
        
        out = out.to(x.dtype).reshape(B, T, H * V)
        out = self.o_proj(out)
        
        return x + out, state

# Test
layer = GatedDeltaNetLayer(256, 8, 32, 64).cuda()
x = torch.randn(2, 128, 256, device='cuda')
out, state = layer(x)
print(f"\nGDN Layer test: input {x.shape} -> output {out.shape}, state {state.shape}")
print("✓ Integration successful")

CompilationError: at 18:12:
    pid = tl.program_id(0)  # Linear index for (b, h)
    b = pid // H
    h = pid % H

    # State is [B, H, K, V], row-major
    state_base = b * H * K_DIM * V_DIM + h * K_DIM * V_DIM

    # Load state into registers
    state = tl.zeros((K_DIM, V_DIM), dtype=tl.float32)
    for ki in range(K_DIM):
        for vi in range(V_DIM):
            state[ki, vi] = tl.load(State + state_base + ki * V_DIM + vi)
            ^
NotImplementedError('__setitem__ is not supported in triton')