# Triton Delta Rule Kernel

**Target:** RTX 4050, PyTorch 2.10.0, Triton 3.6.0, CUDA 12.x

**Fix:** Triton doesn't support `tensor[i,j] = x`. Must use vectorized block operations.

In [1]:
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
print(f"GPU: {torch.cuda.get_device_name(0)}")

import triton
print(f"Triton: {triton.__version__}")

PyTorch: 2.11.0.dev20260128+cu128
CUDA: 12.8
GPU: NVIDIA GeForce RTX 4050 Laptop GPU
Triton: 3.6.0


In [2]:
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import time

# =============================================================================
# TRITON KERNEL - PROPER VECTORIZED OPERATIONS
# =============================================================================

@triton.jit
def delta_rule_fwd_kernel(
    K_ptr, V_ptr, Beta_ptr, G_ptr, State_ptr, Out_ptr,
    stride_k_b, stride_k_t, stride_k_h, stride_k_d,
    stride_v_b, stride_v_t, stride_v_h, stride_v_d,
    stride_bg_b, stride_bg_t, stride_bg_h,
    stride_s_b, stride_s_h, stride_s_k, stride_s_v,
    stride_o_b, stride_o_t, stride_o_h, stride_o_v,
    T,
    K_DIM: tl.constexpr,
    V_DIM: tl.constexpr,
):
    """
    Delta Rule: S = g*S + β*(v - S@k)⊗k
    
    One program per (batch, head). Sequential over T, vectorized over K×V.
    """
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    
    # Offset vectors for block operations
    k_offs = tl.arange(0, K_DIM)
    v_offs = tl.arange(0, V_DIM)
    
    # State pointers [K_DIM, V_DIM]
    state_base = State_ptr + pid_b * stride_s_b + pid_h * stride_s_h
    state_ptrs = state_base + k_offs[:, None] * stride_s_k + v_offs[None, :] * stride_s_v
    
    # Load initial state
    state = tl.load(state_ptrs).to(tl.float32)
    
    for t in range(T):
        # Load k[t] [K_DIM]
        k_base = K_ptr + pid_b * stride_k_b + t * stride_k_t + pid_h * stride_k_h
        k_t = tl.load(k_base + k_offs * stride_k_d).to(tl.float32)
        
        # Load v[t] [V_DIM]
        v_base = V_ptr + pid_b * stride_v_b + t * stride_v_t + pid_h * stride_v_h
        v_t = tl.load(v_base + v_offs * stride_v_d).to(tl.float32)
        
        # Load scalars
        bg_offset = pid_b * stride_bg_b + t * stride_bg_t + pid_h * stride_bg_h
        beta_t = tl.load(Beta_ptr + bg_offset).to(tl.float32)
        g_t = tl.load(G_ptr + bg_offset).to(tl.float32)
        
        # Prediction: sum_k state[k,v] * k_t[k]
        pred = tl.sum(state * k_t[:, None], axis=0)  # [V_DIM]
        
        # Error
        error = v_t - pred
        
        # Outer product
        outer = k_t[:, None] * error[None, :]  # [K_DIM, V_DIM]
        
        # Update state
        state = g_t * state + beta_t * outer
        
        # Output
        out_t = tl.sum(state * k_t[:, None], axis=0)
        
        # Store output
        out_base = Out_ptr + pid_b * stride_o_b + t * stride_o_t + pid_h * stride_o_h
        tl.store(out_base + v_offs * stride_o_v, out_t)
    
    # Store final state
    tl.store(state_ptrs, state)


def triton_delta_rule(k, v, beta, g, initial_state=None):
    B, T, H, K_DIM = k.shape
    V_DIM = v.shape[-1]
    device = k.device
    orig_dtype = k.dtype
    
    k = k.contiguous().float()
    v = v.contiguous().float()
    beta = beta.contiguous().float()
    g = g.contiguous().float()
    
    if initial_state is None:
        state = torch.zeros(B, H, K_DIM, V_DIM, device=device, dtype=torch.float32)
    else:
        state = initial_state.contiguous().float().clone()
    
    out = torch.empty(B, T, H, V_DIM, device=device, dtype=torch.float32)
    
    grid = (B, H)
    delta_rule_fwd_kernel[grid](
        k, v, beta, g, state, out,
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        beta.stride(0), beta.stride(1), beta.stride(2),
        state.stride(0), state.stride(1), state.stride(2), state.stride(3),
        out.stride(0), out.stride(1), out.stride(2), out.stride(3),
        T, K_DIM, V_DIM,
    )
    
    return out.to(orig_dtype), state.to(orig_dtype)

print("Triton kernel loaded.")

Triton kernel loaded.


In [3]:
# =============================================================================
# REFERENCE (PyTorch sequential)
# =============================================================================

def sequential_delta_rule(k, v, beta, g, initial_state=None):
    B, T, H, K = k.shape
    V = v.shape[-1]
    device, dtype = k.device, k.dtype
    
    state = torch.zeros(B, H, K, V, device=device, dtype=dtype) if initial_state is None else initial_state.clone()
    
    outputs = []
    for t in range(T):
        k_t, v_t = k[:, t], v[:, t]
        beta_t, g_t = beta[:, t], g[:, t]
        
        pred = torch.einsum('bhkv,bhk->bhv', state, k_t)
        error = v_t - pred
        update = torch.einsum('bhv,bhk->bhkv', error, k_t)
        state = g_t[..., None, None] * state + beta_t[..., None, None] * update
        outputs.append(torch.einsum('bhkv,bhk->bhv', state, k_t))
    
    return torch.stack(outputs, dim=1), state

print("Reference implementation loaded.")

Reference implementation loaded.


In [4]:
# =============================================================================
# TEST CORRECTNESS
# =============================================================================

print("=" * 60)
print("CORRECTNESS TEST")
print("=" * 60)

B, T, H, K, V = 2, 64, 4, 32, 64

k = F.normalize(torch.randn(B, T, H, K, device='cuda'), dim=-1)
v = torch.randn(B, T, H, V, device='cuda')
beta = torch.sigmoid(torch.randn(B, T, H, device='cuda') - 2)
g = torch.sigmoid(torch.randn(B, T, H, device='cuda') + 2)

out_ref, state_ref = sequential_delta_rule(k, v, beta, g)
out_tri, state_tri = triton_delta_rule(k, v, beta, g)

out_err = (out_tri.float() - out_ref.float()).norm() / out_ref.float().norm()
state_err = (state_tri.float() - state_ref.float()).norm() / state_ref.float().norm()

print(f"Output error:  {out_err.item():.8f}")
print(f"State error:   {state_err.item():.8f}")
print(f"→ {'✓ PASS' if out_err < 1e-4 else '✗ FAIL'}")

CORRECTNESS TEST
Output error:  0.00000007
State error:   0.00000008
→ ✓ PASS


In [5]:
# =============================================================================
# SPEED BENCHMARK
# =============================================================================

print("\n" + "=" * 60)
print("SPEED BENCHMARK")
print("=" * 60)

configs = [
    (4, 64, 8, 32, 64),
    (8, 128, 8, 32, 64),
    (8, 256, 8, 32, 64),
    (16, 128, 8, 32, 64),
]

n_warmup, n_runs = 5, 20

print(f"{'Config':>25} | {'PyTorch':>10} | {'Triton':>10} | {'Speedup':>8}")
print("-" * 60)

for B, T, H, K, V in configs:
    k = F.normalize(torch.randn(B, T, H, K, device='cuda'), dim=-1)
    v = torch.randn(B, T, H, V, device='cuda')
    beta = torch.sigmoid(torch.randn(B, T, H, device='cuda') - 2)
    g = torch.sigmoid(torch.randn(B, T, H, device='cuda') + 2)
    
    # PyTorch
    for _ in range(n_warmup):
        sequential_delta_rule(k, v, beta, g)
        torch.cuda.synchronize()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_runs):
        sequential_delta_rule(k, v, beta, g)
    torch.cuda.synchronize()
    pytorch_ms = (time.perf_counter() - start) / n_runs * 1000
    
    # Triton
    for _ in range(n_warmup):
        triton_delta_rule(k, v, beta, g)
        torch.cuda.synchronize()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_runs):
        triton_delta_rule(k, v, beta, g)
    torch.cuda.synchronize()
    triton_ms = (time.perf_counter() - start) / n_runs * 1000
    
    speedup = pytorch_ms / triton_ms
    print(f"B={B},T={T},H={H}:>25 | {pytorch_ms:>8.2f}ms | {triton_ms:>8.2f}ms | {speedup:>7.2f}x")


SPEED BENCHMARK
                   Config |    PyTorch |     Triton |  Speedup
------------------------------------------------------------
B=4,T=64,H=8:>25 |    15.40ms |     0.10ms |  154.17x
B=8,T=128,H=8:>25 |    27.40ms |     0.28ms |   98.48x
B=8,T=256,H=8:>25 |    62.21ms |     0.53ms |  116.52x
B=16,T=128,H=8:>25 |    29.33ms |     0.48ms |   61.56x


In [6]:
# =============================================================================
# DELTA RULE VALIDATION
# =============================================================================

print("\n" + "=" * 60)
print("DELTA RULE VALIDATION")
print("=" * 60)

B, H, K, V = 1, 4, 32, 64
state = torch.zeros(B, H, K, V, device='cuda')
k = F.normalize(torch.randn(B, H, K, device='cuda'), dim=-1)
v = torch.randn(B, H, V, device='cuda')

# First write
pred1 = torch.einsum('bhkv,bhk->bhv', state, k)
error1 = v - pred1
state = state + torch.einsum('bhv,bhk->bhkv', error1, k)
norm1 = state.norm().item()

# Second write (SAME k, v)
pred2 = torch.einsum('bhkv,bhk->bhv', state, k)
error2 = v - pred2
state = state + torch.einsum('bhv,bhk->bhkv', error2, k)
norm2 = state.norm().item()

print(f"Error1: {error1.norm().item():.4f}")
print(f"Error2: {error2.norm().item():.8f} (should be ~0)")
print(f"State growth: {norm2/norm1:.6f}x (should be ~1.0)")
print(f"→ {'✓ PASS' if error2.norm().item() < 1e-5 else '✗ FAIL'}")


DELTA RULE VALIDATION
Error1: 16.6302
Error2: 0.00000227 (should be ~0)
State growth: 1.000000x (should be ~1.0)
→ ✓ PASS


In [7]:
# =============================================================================
# INTEGRATION: GatedDeltaNetLayer
# =============================================================================

import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    def forward(self, x):
        return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


class GatedDeltaNetLayer(nn.Module):
    def __init__(self, d_model, n_heads, head_dim, value_dim):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.value_dim = value_dim
        
        self.q_proj = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, n_heads * value_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * value_dim, d_model, bias=False)
        
        self.beta_proj = nn.Linear(d_model, n_heads, bias=True)
        nn.init.constant_(self.beta_proj.bias, -2.0)
        
        self.g_proj = nn.Linear(d_model, n_heads, bias=True)
        nn.init.constant_(self.g_proj.bias, 2.0)
        
        self.norm = RMSNorm(d_model)
        
    def forward(self, x, initial_state=None):
        B, T, D = x.shape
        H, K, V = self.n_heads, self.head_dim, self.value_dim
        
        x_norm = self.norm(x)
        
        k = self.k_proj(x_norm).view(B, T, H, K)
        v = self.v_proj(x_norm).view(B, T, H, V)
        k = F.normalize(k.float(), p=2, dim=-1).to(x.dtype)
        
        beta = torch.sigmoid(self.beta_proj(x_norm))
        g = torch.sigmoid(self.g_proj(x_norm))
        
        out, state = triton_delta_rule(k, v, beta, g, initial_state)
        
        out = out.to(x.dtype).reshape(B, T, H * V)
        return x + self.o_proj(out), state


# Test
layer = GatedDeltaNetLayer(256, 8, 32, 64).cuda()
x = torch.randn(2, 128, 256, device='cuda')
out, state = layer(x)
print(f"GDN test: {x.shape} -> {out.shape}, state {state.shape}")
print("✓ Integration successful")

GDN test: torch.Size([2, 128, 256]) -> torch.Size([2, 128, 256]), state torch.Size([2, 8, 32, 64])
✓ Integration successful
