In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def chunk_delta_rule(q, k, v, beta, chunk_size):
    """
    Chunkwise-parallel evaluation of the delta rule (WY-representation form),
    applied independently to every batch element and head.

    Token by token, the state ``S`` (shape ``(D, Dv)``) evolves as
    ``S_t = S_{t-1} + beta_t * k_t (v_t - S_{t-1}^T k_t)^T`` and the output is
    ``o_t = S_t^T q_t``. This function produces the same outputs, but processes
    ``chunk_size`` tokens at a time with matrix products.

    Args:
        q: Query tensor of shape (B, H, L, D).
        k: Key tensor of shape (B, H, L, D).
        v: Value tensor of shape (B, H, L, Dv).
        beta: Per-token write strength of shape (B, H, L).
        chunk_size: Number of tokens processed in parallel per chunk (>= 1).
            ``L`` need not be a multiple of it; the tail is zero-padded internally.

    Returns:
        Output tensor of shape (B, H, L, Dv).

    Raises:
        ValueError: if ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    B, H, L, D = q.shape
    Dv = v.shape[-1]
    C = chunk_size

    # Zero-pad the sequence tail so its length is a multiple of C. Padded tokens
    # have k = 0 and beta = 0, so they never change the state. They also come
    # after every real token, so (the rule being causal) real outputs are unaffected.
    pad = (-L) % C
    if pad:
        q = F.pad(q, (0, 0, 0, pad))
        k = F.pad(k, (0, 0, 0, pad))
        v = F.pad(v, (0, 0, 0, pad))
        beta = F.pad(beta, (0, pad))
    L_padded = L + pad
    num_chunks = L_padded // C

    # Reshape inputs for chunk processing: [B, H, num_chunks, C, *]
    q = q.reshape(B, H, num_chunks, C, D)
    k = k.reshape(B, H, num_chunks, C, D)
    v = v.reshape(B, H, num_chunks, C, Dv)
    beta = beta.reshape(B, H, num_chunks, C)

    k_beta = k * beta.unsqueeze(-1)  # (B, H, num_chunks, C, D)
    v_beta = v * beta.unsqueeze(-1)  # (B, H, num_chunks, C, Dv)

    # T = (I + tril(diag(beta) K K^T, -1))^{-1}, built row by row by forward substitution.
    T = -(k_beta @ k.transpose(-1, -2)).tril(-1)  # (B, H, num_chunks, C, C)
    for i in range(1, C):
        # The .clone() calls matter: row i of T is overwritten in place right after
        # being read. Without copies, the views saved by `mul` for backward are
        # invalidated, and autograd raises "modified by an inplace operation".
        T[..., i, :i] = T[..., i, :i] + (T[..., i, :, None].clone() * T[..., :, :i].clone()).sum(-2)
    T = T + torch.eye(C, device=q.device, dtype=q.dtype)

    W = T @ k_beta  # (B, H, num_chunks, C, D)
    U = T @ v_beta  # (B, H, num_chunks, C, Dv)

    S = torch.zeros(B, H, D, Dv, device=q.device, dtype=q.dtype)
    O = torch.empty_like(v)  # (B, H, num_chunks, C, Dv)

    for i in range(num_chunks):
        q_i = q[:, :, i]  # (B, H, C, D)
        k_i = k[:, :, i]  # (B, H, C, D)
        w_i = W[:, :, i]  # (B, H, C, D)

        # Chunk pseudo-values, corrected for what the incoming state already stores.
        u_i = U[:, :, i] - w_i @ S                   # (B, H, C, Dv)
        o_inter = q_i @ S                            # contribution of earlier chunks
        a_i = (q_i @ k_i.transpose(-1, -2)).tril()   # causal mask, diagonal included
        o_intra = a_i @ u_i                          # contribution of this chunk

        O[:, :, i] = o_intra + o_inter
        S = S + k_i.transpose(-1, -2) @ u_i          # fold this chunk into the state

    # Drop the padded tail (no-op when pad == 0).
    return O.reshape(B, H, L_padded, Dv)[:, :, :L]

In [3]:
# Toy inputs for a shape check of chunk_delta_rule; seeded so reruns are reproducible.
torch.manual_seed(0)
B, H, L, d = 2, 8, 1024, 64
Q = torch.randn(B, H, L, d)
K = torch.randn(B, H, L, d)
V = torch.randn(B, H, L, d)
chunk_size = 16
beta = torch.ones(B, H, L)  # write strength of 1 for every token

In [4]:
delta_out = chunk_delta_rule(Q, K, V, beta, chunk_size)
# Output has the same (B, H, L, d) shape as the value tensor.
assert delta_out.shape == V.shape
delta_out.shape

torch.Size([2, 8, 1024, 64])

In [5]:
class RMSNorm(nn.Module):
    """
    Root-mean-square normalization over the last dimension with a learned gain.

    Computes ``scale * x / (rms(x) + eps)`` where ``rms(x) = ||x||_2 / sqrt(d)``
    and ``d = x.size(-1)``. Note that ``eps`` is added to the RMS itself, not
    inside the square root.

    Args:
        d_model: Size of the last dimension; also the length of ``scale``.
        eps: Small constant added to the RMS to avoid division by zero.
    """

    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d_model))  # per-feature gain, starts at 1

    def forward(self, x):
        feature_dim = x.size(-1)
        root_mean_square = x.norm(2, dim=-1, keepdim=True) / (feature_dim ** 0.5)
        return self.scale * (x / (root_mean_square + self.eps))

In [6]:
class DeltaNet(nn.Module):
    """
    Multi-head DeltaNet token-mixing layer.

    Pipeline: linear q/k/v projections -> causal short convolution + SiLU ->
    split into heads -> L2-normalize q and k per head -> ``chunk_delta_rule`` with
    a per-head, per-token sigmoid gate ``beta`` -> merge heads -> RMSNorm ->
    output projection.

    Args:
        d_model: Model width; must be divisible by ``num_heads``.
        chunk_size: Chunk length passed to ``chunk_delta_rule``.
        num_heads: Number of heads; each head has width ``d_model // num_heads``.
    """

    def __init__(self, d_model, chunk_size=64, num_heads=8):
        super(DeltaNet, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.chunk_size = chunk_size
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.beta_linear = nn.Linear(d_model, num_heads)

        # Short convolutions over the sequence axis. No built-in padding:
        # _causal_conv pads on the left only, so position t never sees tokens > t.
        # (Symmetric padding=1 would leak token t+1 into position t.)
        # NOTE(review): these are dense (groups=1) convs; confirm whether depthwise was intended.
        self.q_conv = nn.Conv1d(d_model, d_model, kernel_size=3)
        self.k_conv = nn.Conv1d(d_model, d_model, kernel_size=3)
        self.v_conv = nn.Conv1d(d_model, d_model, kernel_size=3)

        self.rms_norm = RMSNorm(d_model)
        self.output_linear = nn.Linear(d_model, d_model)

    def _causal_conv(self, x, conv):
        """Apply ``conv`` followed by SiLU along the sequence axis of a (B, L, C) tensor, causally."""
        x = x.transpose(1, 2)                        # (B, C, L) layout expected by Conv1d
        x = F.pad(x, (conv.kernel_size[0] - 1, 0))   # left-pad only: no look-ahead
        return F.silu(conv(x)).transpose(1, 2)       # back to (B, L, C); length L preserved

    def forward(self, x):
        """
        Args:
            x: Input of shape (B, L, d_model).

        Returns:
            Tensor of shape (B, L, d_model).
        """
        B, L, _ = x.shape
        H = self.num_heads
        D = self.head_dim

        # Linear projections followed by causal short conv + SiLU: each (B, L, d_model)
        q = self._causal_conv(self.q_linear(x), self.q_conv)
        k = self._causal_conv(self.k_linear(x), self.k_conv)
        v = self._causal_conv(self.v_linear(x), self.v_conv)

        # Reshape to multi-head format: (B, L, d_model) -> (B, H, L, head_dim)
        q = q.view(B, L, H, D).transpose(1, 2)
        k = k.view(B, L, H, D).transpose(1, 2)
        v = v.view(B, L, H, D).transpose(1, 2)

        # L2-normalize q and k per head
        q = q / (q.norm(dim=-1, keepdim=True) + 1e-6)
        k = k / (k.norm(dim=-1, keepdim=True) + 1e-6)

        # Per-head write strength in (0, 1): (B, L, H) -> (B, H, L)
        beta = torch.sigmoid(self.beta_linear(x)).transpose(1, 2)

        out = chunk_delta_rule(q, k, v, beta, self.chunk_size)  # (B, H, L, D)

        # Merge heads: (B, H, L, D) -> (B, L, d_model)
        out = out.transpose(1, 2).contiguous().view(B, L, self.d_model)

        return self.output_linear(self.rms_norm(out))

In [7]:
class SwiGLU(nn.Module):
    """
    Gated feed-forward layer: ``w3(silu(w1(x)) * w2(x))``.

    Args:
        dim: Input and output width.
        hidden_dim: Width of the gated hidden layer.
        bias: Whether the three projections carry a bias term.
    """

    def __init__(self, dim, hidden_dim, bias=False):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)  # gate projection
        self.w2 = nn.Linear(dim, hidden_dim, bias=bias)  # value projection
        self.w3 = nn.Linear(hidden_dim, dim, bias=bias)  # projection back to `dim`

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w3(gate * self.w2(x))

In [8]:
class DeltaNetBlock(nn.Module):
    """
    Pre-norm residual block: DeltaNet token mixing followed by a SwiGLU MLP.

    Args:
        dim: Model width.
        num_heads: Number of DeltaNet heads.
        mlp_ratio: Hidden width of the SwiGLU MLP as a multiple of ``dim``.
        chunk_size: Chunk length for the DeltaNet layer (64, matching DeltaNet's default).
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, chunk_size=64):
        super().__init__()
        # First residual path: DeltaNet
        self.norm1 = RMSNorm(dim)
        # Pass by keyword: DeltaNet's signature is (d_model, chunk_size, num_heads),
        # so a positional num_heads would silently be taken as chunk_size.
        self.delta_net = DeltaNet(dim, chunk_size=chunk_size, num_heads=num_heads)

        # Second residual path: SwiGLU
        self.norm2 = RMSNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = SwiGLU(dim, mlp_hidden_dim)

    def forward(self, x):
        # x + DeltaNet(RMSNorm(x))
        x = x + self.delta_net(self.norm1(x))
        # x + SwiGLU(RMSNorm(x))
        x = x + self.mlp(self.norm2(x))
        return x

In [9]:
class DeltaNetModel(nn.Module):
    """
    Token embedding -> ``depth`` stacked DeltaNetBlocks -> final RMSNorm ->
    bias-free linear head producing vocabulary logits.

    Args:
        vocab_size: Number of embedding rows and of output logits per position.
        dim: Model width.
        depth: Number of stacked DeltaNetBlocks.
        num_heads: Heads per DeltaNet layer.
    """

    def __init__(self, vocab_size, dim, depth, num_heads):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleList(
            [DeltaNetBlock(dim, num_heads) for _ in range(depth)]
        )
        self.norm_f = RMSNorm(dim)
        self.head = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, input_ids):
        """Map integer token ids (presumably (B, L)) to logits with a trailing vocab_size axis."""
        hidden = self.embedding(input_ids)
        for block in self.layers:
            hidden = block(hidden)
        return self.head(self.norm_f(hidden))

In [10]:
torch.manual_seed(0)  # model init and input ids are random; seed for reproducible reruns
B, L, D = 2, 64, 128
V = 1000  # vocabulary size (rebinds V, which held the value tensor in the earlier cell)
heads = 4

model = DeltaNetModel(vocab_size=V, dim=D, depth=2, num_heads=heads)
input_ids = torch.randint(0, V, (B, L))
logits = model(input_ids)
assert logits.shape == (B, L, V)
logits.shape

torch.Size([2, 64, 1000])
