In [1]:
import torch
import torch.nn as nn

In [43]:
def chunk_delta_rule(Q, K, V, beta, chunk_size):
    """
    Chunkwise-parallel forward pass of the delta rule.

    Produces the same output as the token-by-token recurrence (the state S
    starts at zero and is carried across the whole sequence):

        u_t = beta_t * (v_t - k_t @ S)
        S   = S + outer(k_t, u_t)
        o_t = q_t @ S

    but handles `chunk_size` positions at a time. Inside each chunk the
    dependency of u_t on earlier positions of the same chunk is resolved by
    inverting a unit lower-triangular (chunk_size x chunk_size) matrix; only
    the state S is passed from one chunk to the next.

    Args:
        Q: Query tensor of shape (batch_size, seq_len, d_model)
        K: Key tensor of shape (batch_size, seq_len, d_model)
        V: Value tensor of shape (batch_size, seq_len, d_value); d_value may
            equal d_model
        beta: Per-position write strength of shape (batch_size, seq_len)
        chunk_size: Positive number of positions per chunk. seq_len need not
            be a multiple of it; the tail is zero-padded internally and the
            padding is dropped from the result.
    Returns:
        Output tensor of shape (batch_size, seq_len, d_value)
    Raises:
        ValueError: If chunk_size is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    B, L, d = Q.shape
    d_v = V.shape[-1]
    C = chunk_size

    # Zero-pad the tail so the sequence splits into whole chunks. Padded
    # positions have k = 0 and beta = 0, so they add nothing to S, and their
    # outputs are sliced off before returning.
    pad = (-L) % C
    if pad:
        Q, K, V = (torch.nn.functional.pad(x, (0, 0, 0, pad)) for x in (Q, K, V))
        beta = torch.nn.functional.pad(beta, (0, pad))
    num_chunks = (L + pad) // C

    # Chunk everything up front: (B, num_chunks, C, *). The triangular system
    # below must be solved per chunk, never across chunk boundaries.
    Q_chunks = Q.reshape(B, num_chunks, C, d)
    K_chunks = K.reshape(B, num_chunks, C, d)
    V_chunks = V.reshape(B, num_chunks, C, d_v)
    beta_chunks = beta.reshape(B, num_chunks, C, 1)

    K_beta = K_chunks * beta_chunks  # (B, N, C, d)
    V_beta = V_chunks * beta_chunks  # (B, N, C, d_v)

    # T = (I + tril(K_beta K^T, -1))^{-1}, built by forward substitution.
    # Start from the negated strictly-lower part, then fix one row at a time;
    # when row i is processed, rows < i already hold their final values.
    T = -(K_beta @ K_chunks.transpose(-2, -1)).tril(-1)  # (B, N, C, C)
    for i in range(1, C):
        # Clone the operands so the in-place row write below does not modify
        # tensors that the product reads from (keeps the update autograd-safe).
        row = T[..., i, :].clone()      # (B, N, C)
        solved = T[..., :, :i].clone()  # (B, N, C, i)
        T[..., i, :i] = row[..., :i] + (row.unsqueeze(-1) * solved).sum(-2)
    T = T + torch.eye(C, dtype=T.dtype, device=T.device)

    W = T @ K_beta  # (B, N, C, d)
    U = T @ V_beta  # (B, N, C, d_v)

    S = Q.new_zeros(B, d, d_v)  # running state, carried across chunks
    O = V.new_empty(B, num_chunks * C, d_v)

    for i in range(num_chunks):
        q_i = Q_chunks[:, i]  # (B, C, d)
        k_i = K_chunks[:, i]  # (B, C, d)
        w_i = W[:, i]         # (B, C, d)
        # Remove what the incoming state already predicts for these keys.
        u_i = U[:, i] - w_i @ S  # (B, C, d_v)
        o_inter = q_i @ S  # contribution of all previous chunks
        # tril() keeps the diagonal: position t reads the state after its own write.
        A_i = (q_i @ k_i.transpose(-2, -1)).tril()  # (B, C, C)
        o_intra = A_i @ u_i  # contribution from inside this chunk
        O[:, i * C:(i + 1) * C] = o_inter + o_intra
        S = S + k_i.transpose(-2, -1) @ u_i  # (B, d, d_v)

    return O[:, :L]

In [44]:
# Fix the RNG seed so Restart & Run All reproduces the same Q, K, V.
torch.manual_seed(0)

# Toy problem: batch of 2, sequence length 10, feature size 4, chunks of 2.
B, L, d = 2, 10, 4
Q = torch.randn(B, L, d)
K = torch.randn(B, L, d)
V = torch.randn(B, L, d)
chunk_size = 2
# beta is all ones, one entry per (batch, position).
beta = torch.ones(B, L)

In [45]:
# Smoke check: display the shape of the result for the toy inputs.
chunk_delta_rule(Q=Q, K=K, V=V, beta=beta, chunk_size=chunk_size).shape

torch.Size([2, 10, 4])