In [1]:
import torch
import torch.nn as nn
from einops import rearrange

#### Fixed-decay with projections

The objective of this notebook is to provide an example of the fixed-decay approach discussed in the Based paper (https://arxiv.org/abs/2402.18668). While Based achieves high quality with *no* decay whatsoever, the following addition may be helpful for your use case.

In [2]:
# inputs
#
# Seed the global RNG first so that every random tensor below (and any module
# initialised later from the global RNG) is identical after a kernel restart.
SEED = 0
torch.manual_seed(SEED)

# Shapes used below: q and k are (b, h, n, f), v is (b, h, n, d) and
# hidden_states is (b, n, d_model) with d_model = h * d.
b, h, n, d, f = 2, 4, 64, 16, 16
eps = 1e-12  # added to the attention normalizer's denominator
d_model = h * d
q = torch.randn(b, h, n, f)
k = torch.randn(b, h, n, f)
v = torch.randn(b, h, n, d)
hidden_states = torch.randn(b, n, d_model)

In [3]:
# construct the fixed decay matrices

class DecayClass(nn.Module):
    """Fixed (non-learned) per-head causal decay masks.

    Head ``i`` is assigned the log-decay ``log(1 - 2 ** (decay_const - i))``,
    which is registered as the ``decay`` buffer of shape ``(n_kv_heads,)``.

    Args:
        l_max: side length of the square mask built by ``forward``; must be > 0.
        decay_const: offset in the exponent of the per-head decay formula.
        decay_denom: when True, ``forward`` divides every mask row by the
            square root of that row's sum.
        n_kv_heads: number of heads (length of the ``decay`` buffer).

    Raises:
        AssertionError: if ``l_max`` is not positive.
    """

    def __init__(self, l_max, decay_const=-3, decay_denom=False, n_kv_heads=16):
        super().__init__()
        # The message has to be a plain string: passing print(...) here would
        # evaluate to None and leave the AssertionError without any text.
        assert l_max > 0, f"l_max must be positive, got {l_max}"
        self.l_max = l_max
        self.decay_denom = decay_denom
        self.num_heads = n_kv_heads
        decay = torch.log(1 - 2 ** (decay_const - torch.arange(self.num_heads, dtype=torch.float)))
        # A buffer (not a parameter): it follows the module across devices but
        # is never updated by an optimizer.
        self.register_buffer("decay", decay)

    def forward(self):
        """Build the decay mask.

        Returns:
            A tuple ``(mask, step_decay)`` where

            * ``mask`` has shape ``(num_heads, l_max, l_max)``. For a finite
              log-decay, entry ``[h, i, j]`` is ``exp(decay[h] * (i - j))``
              when ``j <= i`` and 0 when ``j > i``; if ``decay_denom`` is set,
              each row is additionally divided by the square root of its sum.
            * ``step_decay`` is ``exp(decay)`` with shape ``(num_heads,)``.
        """
        # .to(self.decay) matches the buffer's dtype and device.
        index = torch.arange(self.l_max).to(self.decay)
        mask = torch.tril(torch.ones(self.l_max, self.l_max).to(self.decay))
        # Pairwise distance i - j, with every position above the diagonal set
        # to +inf so that it is driven to 0 by the exponential below.
        mask = torch.masked_fill(index[:, None] - index[None, :], ~mask.bool(), float("inf"))
        mask = torch.exp(mask * self.decay[:, None, None])
        # inf * 0 yields NaN if a head's log-decay is exactly 0; map NaN to 0.
        mask = torch.nan_to_num(mask)
        if self.decay_denom:
            mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
        return mask, torch.exp(self.decay)


# One decay module for the whole notebook: the mask side length is the
# sequence length n and there is one decay rate per head (h heads).
decay_cls = DecayClass(
    l_max=n,
    decay_const=-3,
    decay_denom=False,
    n_kv_heads=h,
)


In [4]:
# plug into linear attention

# Version 1: default, no decay (https://github.com/HazyResearch/based/blob/9db60a33d20e6c024de97703715768da9d872e30/based/models/mixers/linear_attention.py#L136)
# Pairwise q.k scores, restricted to the lower triangle (key index m <= query index n).
A_qk = torch.einsum("bhnd,bhmd->bhnm", q, k).tril()
y = torch.einsum("bhnm,bhme->bhne", A_qk.to(q.dtype), v.to(q.dtype))
# Normalizer: reciprocal of q at each position dotted with the running sum of k
# along the sequence axis, with eps added to the denominator.
k_cumsum = k.cumsum(2)
z = 1 / (torch.einsum("bhld,bhld->bhl", q, k_cumsum) + eps)
# Scale per (batch, head, position), then merge the head axis into the feature axis.
y = rearrange(y * z.unsqueeze(-1), 'b h l d -> b l (h d)')


# Version 2: with decay
use_decay_proj = True
decay_proj = nn.Linear(d_model, h)  # one scalar per head for every token
# Plain causal (lower-triangular) mask, used only when no decay is available.
cumsum_matrix = torch.tril(torch.ones((n, n))).to(q.device, q.dtype)

decay_out = decay_cls()
decay, decay_recurrent = decay_out if decay_out is not None else (None, None)

A_qk = torch.einsum("bhnd,bhmd->bhnm", q, k)
if decay is not None:
    decay = decay[:, :n, :n]
    if decay.dim() == 3:
        decay = decay.unsqueeze(0)  # (h, l, l) --> (1, h, l, l)
    # Checked for both branches below: a decay matrix shorter than n would
    # otherwise fail late, or broadcast silently when its length is 1.
    assert decay.shape[2] >= n, f"decay matrix {decay.shape} too short for sequence length {n}"
    if use_decay_proj:
        dt_out = decay_proj(hidden_states)  # (b, l, d_model) --> (b, l, h)
        decay_mat = dt_out.transpose(1, 2).unsqueeze(-1) * decay  # (b, h, l, 1) * (1, h, l, l)
    else:
        decay_mat = decay
    A_qk = A_qk * decay_mat
else:
    A_qk = A_qk * cumsum_matrix
out = torch.einsum("bhnm,bhme->bhne", A_qk.to(hidden_states.dtype), v.to(hidden_states.dtype))
# NOTE(review): the normalizer uses the undecayed running sum of k, exactly as
# in the no-decay version -- confirm this is intended when decay is applied.
z = 1 / (torch.einsum("bhld,bhld->bhl", q, k.cumsum(2)) + eps)
y = out * z[..., None]
# NOTE(review): y is left as (b, h, n, d) here; no head-merging rearrange is applied.
y = y.to(hidden_states.dtype)


  from .autonotebook import tqdm as notebook_tqdm
