In [1]:
import torch
import torch.nn as nn

In [2]:
def chunk_delta_rule(Q, K, V, beta, chunk_size):
    """
    Chunk-wise parallel evaluation of the causal delta-rule recurrence.

    For every batch element, with a state ``S`` that starts at zero, the result
    equals the token-by-token recurrence

        u_t = beta_t * (v_t - k_t @ S_{t-1})
        S_t = S_{t-1} + k_t^T @ u_t
        o_t = q_t @ S_t

    but ``chunk_size`` tokens are processed at a time.

    Args:
        Q: Query tensor of shape (batch_size, seq_len, d_k)
        K: Key tensor of shape (batch_size, seq_len, d_k)
        V: Value tensor of shape (batch_size, seq_len, d_v)
        beta: Per-token write strength of shape (batch_size, seq_len)
        chunk_size: Number of tokens per chunk (positive int). seq_len does not
            need to be a multiple of it; the tail is zero-padded internally and
            the padding is stripped from the result.
    Returns:
        Output tensor of shape (batch_size, seq_len, d_v)
    Raises:
        ValueError: If chunk_size is not positive.
    """
    B, L, d = Q.shape
    d_v = V.shape[-1]
    C = chunk_size
    if C <= 0:
        raise ValueError(f"chunk_size must be positive, got {C}")
    if L == 0:
        return torch.empty_like(V)

    # Zero-pad the tail up to a whole number of chunks. Padded tokens have
    # k = v = beta = 0, so they never write to the state, and because the
    # computation is causal they cannot influence any real position.
    pad = (-L) % C
    if pad:
        Q = torch.nn.functional.pad(Q, (0, 0, 0, pad))
        K = torch.nn.functional.pad(K, (0, 0, 0, pad))
        V = torch.nn.functional.pad(V, (0, 0, 0, pad))
        beta = torch.nn.functional.pad(beta, (0, pad))
    num_chunks = (L + pad) // C

    # Chunked views: (B, num_chunks, C, feature)
    Q_chunks = Q.reshape(B, num_chunks, C, d)
    K_chunks = K.reshape(B, num_chunks, C, d)
    V_chunks = V.reshape(B, num_chunks, C, d_v)
    beta_chunks = beta.reshape(B, num_chunks, C, 1)

    K_beta = K_chunks * beta_chunks  # (B, num_chunks, C, d)
    V_beta = V_chunks * beta_chunks  # (B, num_chunks, C, d_v)

    # T = (I + tril(diag(beta) K K^T, -1))^{-1}, computed independently for each
    # chunk by forward substitution. It must be block-local: tokens of earlier
    # chunks are accounted for through the running state S below.
    T = -(K_beta @ K_chunks.transpose(-2, -1)).tril(-1)  # (B, num_chunks, C, C)
    for i in range(1, C):
        # Operands are cloned so the in-place row write does not invalidate
        # tensors that autograd saved for the backward pass.
        T[..., i, :i] = T[..., i, :i].clone() + (
            T[..., i, :i, None].clone() * T[..., :i, :i].clone()
        ).sum(-2)
    T = T + torch.eye(C, dtype=T.dtype, device=T.device)

    W = T @ K_beta  # (B, num_chunks, C, d)
    U = T @ V_beta  # (B, num_chunks, C, d_v)

    # Running state carried across chunks, created on the inputs' device/dtype.
    S = Q.new_zeros(B, d, d_v)
    outputs = []
    for n in range(num_chunks):
        q_n = Q_chunks[:, n]  # (B, C, d)
        k_n = K_chunks[:, n]  # (B, C, d)
        u_n = U[:, n] - W[:, n] @ S  # (B, C, d_v) corrected values for this chunk
        o_inter = q_n @ S  # contribution of all previous chunks
        A_n = (q_n @ k_n.transpose(-2, -1)).tril()  # causal scores inside the chunk
        outputs.append(o_inter + A_n @ u_n)
        S = S + k_n.transpose(-2, -1) @ u_n  # (B, d, d_v)

    O = torch.cat(outputs, dim=1)  # (B, L + pad, d_v)
    return O[:, :L] if pad else O

In [3]:
# Fix the RNG so the toy tensors below are identical on every Restart & Run All.
torch.manual_seed(0)

# Toy problem: batch size, sequence length and feature width.
B, L, d = 2, 10, 4
Q = torch.randn(B, L, d)
K = torch.randn(B, L, d)
V = torch.randn(B, L, d)
chunk_size = 2
beta = torch.ones(B, L)  # write strength of 1 for every token

In [4]:
# Smoke check on the toy tensors: the result should have the same shape as V.
chunk_delta_rule(Q, K, V, beta=beta, chunk_size=chunk_size).shape

torch.Size([2, 10, 4])

In [5]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    Args:
        d_model: Size of the last dimension; one gain value is learned per feature.
        eps: Constant added to the RMS before dividing, to avoid division by zero.
    """

    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to 1 so the layer starts as a pure normaliser.
        self.scale = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """Return ``x / (rms(x) + eps) * scale``; the output has the shape of ``x``."""
        width = x.size(-1)
        # RMS expressed through the L2 norm: ||x||_2 / sqrt(width).
        rms = x.norm(2, dim=-1, keepdim=True) / (width ** 0.5)
        return (x / (rms + self.eps)) * self.scale

In [7]:
class DeltaNet(nn.Module):
    """Sequence-mixing layer built around ``chunk_delta_rule``.

    The input is projected to queries, keys and values (linear layer, width-3
    convolution along the sequence, ReLU). Queries and keys are additionally
    scaled to unit L2 norm, and a per-token ``beta`` in (0, 1) is produced by a
    sigmoid-activated linear head. The delta-rule output is RMS-normalised and
    passed through a final linear layer.

    NOTE(review): the convolutions use kernel_size=3 with padding=1, i.e. they
    pad both ends, so every position also sees the next time step. Confirm this
    look-ahead is intended, given that the delta rule itself is causal.

    Args:
        d_model: Feature width of the input and the output.
        chunk_size: Chunk length forwarded to ``chunk_delta_rule``.
    """

    def __init__(self, d_model, chunk_size):
        super().__init__()
        self.d_model = d_model
        self.chunk_size = chunk_size

        # Sub-modules are created in a fixed order so that seeded parameter
        # initialisation stays reproducible.
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.beta_linear = nn.Linear(d_model, 1)
        self.q_conv = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1)
        self.k_conv = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1)
        self.v_conv = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1)
        self.rms_norm = RMSNorm(d_model)
        self.output_linear = nn.Linear(d_model, d_model)

    @staticmethod
    def _project(x, linear, conv):
        """Apply ``linear``, then ``conv`` along the sequence axis, then ReLU."""
        # Conv1d convolves over the last axis, so swap (seq, feature) around it.
        hidden = linear(x).transpose(1, 2)
        hidden = nn.functional.relu(conv(hidden))
        return hidden.transpose(1, 2)

    @staticmethod
    def _unit_norm(t):
        """Scale every feature vector to (approximately) unit L2 norm."""
        return t / (t.norm(dim=-1, keepdim=True) + 1e-6)

    def forward(self, x):
        """Mix ``x`` of shape (batch, seq_len, d_model) along the sequence; same shape out."""
        q = self._unit_norm(self._project(x, self.q_linear, self.q_conv))
        k = self._unit_norm(self._project(x, self.k_linear, self.k_conv))
        v = self._project(x, self.v_linear, self.v_conv)

        # One scalar in (0, 1) per token: (batch, seq_len, 1) -> (batch, seq_len).
        beta = torch.sigmoid(self.beta_linear(x)).squeeze(-1)

        mixed = chunk_delta_rule(q, k, v, beta, self.chunk_size)
        return self.output_linear(self.rms_norm(mixed))

In [8]:
class SwiGLU(nn.Module):
    """SwiGLU gating: ``silu(x1) * x2`` where ``x1, x2`` are the halves of the last dim.

    The last dimension of the input must be even; the output has half that width.
    """

    def __init__(self):
        super(SwiGLU, self).__init__()

    def forward(self, x):
        """Split ``x`` in two along the last dimension and gate one half with the other.

        Raises:
            ValueError: If the last dimension of ``x`` is odd.
        """
        if x.size(-1) % 2 != 0:
            # An odd width would make chunk() return unequal pieces that could
            # silently broadcast against each other.
            raise ValueError(
                f"SwiGLU expects an even last dimension, got {x.size(-1)}"
            )
        x1, x2 = x.chunk(2, dim=-1)
        # Swish/SiLU on the first half (x * sigmoid(x)) is what distinguishes
        # SwiGLU from plain GLU, which uses a bare sigmoid gate.
        return nn.functional.silu(x1) * x2

In [9]:
class TransformerBlock(nn.Module):
    """Post-norm residual block: DeltaNet mixing followed by a gated feed-forward net.

    Args:
        d_model: Feature width of the input and the output.
        chunk_size: Chunk length used by the DeltaNet sub-layer.
        ff_hidden_dim: Width of the feed-forward hidden layer after gating; the
            first linear layer emits twice this width because the gating module
            halves the last dimension.
    """

    def __init__(self, d_model, chunk_size, ff_hidden_dim):
        super().__init__()
        self.deltanet = DeltaNet(d_model, chunk_size)
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

        expand = nn.Linear(d_model, ff_hidden_dim * 2)
        gate = SwiGLU()
        contract = nn.Linear(ff_hidden_dim, d_model)
        self.ffn = nn.Sequential(expand, gate, contract)

    def forward(self, x):
        """Apply both sub-layers, each as residual-add followed by RMSNorm."""
        mixed = self.norm1(x + self.deltanet(x))
        return self.norm2(mixed + self.ffn(mixed))

In [10]:
class TransformerDeltaNet(nn.Module):
    """Stack of DeltaNet layers with residual connections and a final RMSNorm.

    Args:
        d_model: Feature width shared by every layer.
        chunk_size: Chunk length used by every DeltaNet layer.
        num_layers: Number of DeltaNet layers in the stack.
    """

    def __init__(self, d_model, chunk_size, num_layers):
        super().__init__()
        stack = []
        for _ in range(num_layers):
            stack.append(DeltaNet(d_model, chunk_size))
        self.layers = nn.ModuleList(stack)
        self.final_rms_norm = RMSNorm(d_model)

    def forward(self, x):
        """Run ``x`` through every layer (adding each layer's input back) and normalise."""
        hidden = x
        for block in self.layers:
            # Residual connection around each DeltaNet layer.
            hidden = block(hidden) + hidden
        return self.final_rms_norm(hidden)

In [None]:
# End-to-end smoke run: a 6-layer stack on a random batch; the output should
# keep the input's (batch_size, seq_len, d_model) shape.
model = TransformerDeltaNet(num_layers=6, chunk_size=4, d_model=64)
input_tensor = torch.randn((8, 128, 64))  # (batch_size, seq_len, d_model)
output = model(input_tensor)
output.shape  # (8, 128, 64)

torch.Size([8, 128, 64])