In [None]:
#default_exp attention

In [2]:
#hide
from nbdev.showdoc import *

In [1]:
#hide
%load_ext autoreload
%autoreload 2

In [13]:
#export
import torch
from torch import nn, einsum
import torch.nn.functional as F
from fastai.basics import *

from functools import partial, reduce
from inspect import isfunction
from operator import mul
from copy import deepcopy

from torch import Tensor
from typing import Tuple

from einops import rearrange, repeat

from reformer_fastai.core import *
from reformer_fastai.layers import *

# Attention Modules 

In [14]:
#export
# Fill values written into attention logits *before* the softmax, so they have to be
# large and negative: exp(value) must underflow to ~0 for a masked position.
# Magnitudes stay below the fp16 maximum (~6.5e4) so they remain finite in half precision.
MASK_VAL = -5e4
# Softer penalty used for a token's own position: still ~0 probability next to
# unmasked logits, but it wins against MASK_VAL when every other target is masked.
SELF_ATTN_MASK_VAL = -1e4

## Scaled Dot Product Attention

In [15]:
#export
class AttnInProj(Module):
    """Projects `x` to queries and `context` (falling back to `x`) to keys and values"""
    def __init__(self, d_model:int, bias:bool=False):
        # three independent d_model -> d_model projections, created in q, k, v order
        self.to_q, self.to_k, self.to_v = (nn.Linear(d_model, d_model, bias=bias) for _ in range(3))
    def forward(self, x, context=None):
        # without a context this is plain self-attention: keys/values come from x itself
        kv_source = ifnone(context, x)
        return self.to_q(x), self.to_k(kv_source), self.to_v(kv_source)

In [16]:
# Self-attention case: with no context, q, k and v are all projected from x.
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)   # shorter sequence, used by the cross-attention check below
proj = AttnInProj(d)
q1, k1, v1 = proj(x)
assert all(t.size() == (bs, sl, d) for t in (q1, k1, v1))
q1.shape, k1.shape, v1.shape

(torch.Size([4, 128, 64]), torch.Size([4, 128, 64]), torch.Size([4, 128, 64]))

In [17]:
# Cross-attention case: queries still come from x, keys/values follow the context length.
q2, k2, v2 = proj(x, context)
assert q2.size() == (bs, sl, d)
assert k2.size() == context.size() and v2.size() == k2.size()
assert all_equal(q1, q2)        # queries do not depend on the context
assert not all_equal(k1, k2)    # keys do
q2.shape, k2.shape, v2.shape

(torch.Size([4, 128, 64]), torch.Size([4, 112, 64]), torch.Size([4, 112, 64]))

In [18]:
#export
#TODO make sure store_attention works
class ScaledDotProdAttention(Module):
    """
    Multi-head scaled dot-product attention over already projected q, k, v.

    q, k, v are `[bs, seq_len, d_model]`; they are split into `n_heads` heads,
    attended per head and merged back into `[bs, q_len, d_model]`.
    If `store_attention` is set, the post-softmax weights of the last call are
    kept (detached, on CPU) in `self.attention`.
    """
    def __init__(self, d_model, n_heads, causal=False, dropout=0., store_attention:bool=False):
        store_attr()
        self.scale = (d_model//n_heads)**-0.5   # 1/sqrt(per-head dim)
        self.dropout = nn.Dropout(dropout)      # replaces the float stored by store_attr

    def forward(self, q, k, v, input_mask=None):
        # boolean `input_mask` is False at positions not to attend to
        device = q.device
        q, k, v = [rearrange(t, 'b n (h d) -> b h n d', h=self.n_heads) for t in (q, k, v)]

        # classic dot-product attention: scores are [bs, heads, q_len, k_len]
        dots = torch.einsum('bhid,bhjd->bhij', q*self.scale, k)

        if exists(input_mask):
            dots.masked_fill_(~input_mask, MASK_VAL)
            del input_mask

        if self.causal:
            n_q, n_k = dots.shape[-2:]
            # strictly-upper triangle, shifted so the last query row sees every key
            future = torch.ones((n_q, n_k), device = device).triu_(n_k - n_q + 1).bool()
            dots.masked_fill_(future, MASK_VAL)
            del future

        attn = dots.softmax(dim=-1)
        if self.store_attention:
            self.attention = attn.detach().cpu()

        attn = self.dropout(attn)
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        # merge the heads back into the model dimension
        return rearrange(out, 'b h n d -> b n (h d)')

Scaled dot-product attention is calculated as:

$$\textbf{Attention}(Q,K,V) = \textbf{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

In [19]:
# Smoke test: 4 heads over random q, k, v; the output keeps the [bs, sl, d] shape.
q, k, v = (torch.randn(bs, sl, d) for _ in range(3))
attn_func = ScaledDotProdAttention(d, 4)
out = attn_func(q, k, v)
out.shape

torch.Size([4, 128, 64])

In [20]:
#hide
#TODO: add tests with input mask

In [21]:
#export
class Attention(Module):
    """
    Standard attention module using scaled dot-product attention

    Pipeline: input projection (`AttnInProj`) -> `ScaledDotProdAttention` ->
    output projection -> dropout.

    `out_dropout` defaults to `dropout` when not given. `bias` applies to the
    input and output projections. The constructor's `mask` is only stored;
    the masks actually used are the ones passed to `forward`.
    """
    def __init__(self, 
                 d_model:int, 
                 n_heads:int = 8, 
                 causal:bool = False,
                 mask:Tensor = None,
                 dropout:float=0.1,
                 out_dropout:float=None,
                 bias:bool=True,
                 store_attention:bool=False):
        store_attr('causal, mask, n_heads, bias')
        out_dropout = ifnone(out_dropout, dropout)
        self.in_proj = AttnInProj(d_model, bias=bias)
        self.attn = ScaledDotProdAttention(d_model, n_heads, causal=causal,
                                           dropout=dropout, store_attention=store_attention)
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)
        self.dropout = nn.Dropout(out_dropout)
        self._init()

    def forward(self, x, context = None, mask = None, context_mask = None):
        """
        x:            [bs, q_len, d_model] input
        context:      optional [bs, k_len, d_model] source of keys/values (defaults to `x`)
        mask:         optional boolean [bs, q_len], False at positions to ignore
        context_mask: optional boolean [bs, k_len], only used when `context` is given
        """
        q, k, v = self.in_proj(x, context)
        input_mask = self._make_input_mask(mask, context_mask, context, x.device)
        out = self.attn(q, k, v, input_mask)

        out = self.out_proj(out)
        return self.dropout(out)

    def _init(self):
        "Xavier-uniform for weight matrices; zeros for 1-d parameters when `bias` is on"
        for p in self.parameters():
            if p.dim() > 1: nn.init.xavier_uniform_(p)
            elif self.bias and p.dim() == 1: nn.init.constant_(p, 0)

    def _make_input_mask(self, mask, context_mask, context, device):
        """
        Combines the query and key padding masks into one boolean mask that
        broadcasts against the `[bs, heads, q_len, k_len]` attention scores.
        Returns None if both `mask` and `context_mask` are None.
        """
        if not any(map(exists, (mask, context_mask))): return None

        # The batch size comes from whichever mask was supplied. A missing mask means
        # "everything visible"; a length-1 axis of ones broadcasts to any sequence
        # length, so neither q_len nor k_len has to be known here.
        b = (mask if exists(mask) else context_mask).shape[0]
        all_visible = torch.ones((b, 1), dtype=torch.bool, device=device)

        q_mask = mask if exists(mask) else all_visible
        if not exists(context):
            k_mask = q_mask     # self-attention: keys are the queries' own sequence
        else:
            k_mask = context_mask if exists(context_mask) else all_visible

        q_mask = rearrange(q_mask, 'b i -> b () i ()')
        k_mask = rearrange(k_mask, 'b j -> b () () j')
        return q_mask * k_mask

In [22]:
# Self-attention smoke test: the output has the same shape as the input.
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
attn = Attention(d)
out = attn(x)
assert out.size() == (bs, sl, d)
out.shape

torch.Size([4, 128, 64])

In [23]:
#hide
#TODO: add tests with input mask

## Additive Attention

In [24]:
#hide
# previous version of additive attention for reference
class DecoderAttention(Module):
    """
    Previous single-module version of additive attention, kept for reference.

    Queries are projected from `x`; keys and values are projected (by one fused
    `to_kv` layer) from `x` concatenated with `context` along the sequence axis.
    """
    def __init__(self, 
                 d_model, 
                 heads = 8, 
                 causal = False,
                 mask = None,
                 dropout=0.1, 
                 bias=True):
        self.causal = causal
        self.store_attention = False
        self.mask = mask # NOTE(review): stored but never read in this class
        self.heads = heads
        self.scale = (d_model//heads) ** -0.5
        
        self.to_q = nn.Linear(d_model, d_model, bias = bias)
        # fused key/value projection; split in two halves in `forward`
        self.to_kv = nn.Linear(d_model, d_model * 2, bias = bias)
        self.dropout = nn.Dropout(dropout)

        # output projection always has a bias (the `bias` argument is not forwarded)
        self.to_out = nn.Linear(d_model, d_model)

        self._init()

    # NOTE(review): the `store_attention` argument is not read; only `self.store_attention` is
    def forward(self, x, context = None, mask = None, context_mask = None, store_attention=False):
        b, n, d, h, device = *x.shape, self.heads, x.device
        # no context -> zero-length context, so the concatenation below is just `x`
        context = default(context, torch.empty(b, 0, d, dtype=x.dtype, device=device))
        kv_input = torch.cat([x, context], dim=-2)
        
        q = self.to_q(x)
        kv = self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, *kv))

        # boolean input_mask is False at positions not to attend to
        input_mask = None
        if any(map(exists, (mask, context_mask))):
            q_mask = default(mask, lambda: torch.ones((b, n), device = device).bool())
            # mask for the x-to-x part of the scores: [b, 1, n, n]
            self_mask = q_mask[:, None, :, None] * q_mask[:, None, None, :]
            if context.size(-2) != 0:
                k_mask = default(context_mask, lambda: torch.ones((b, context.shape[-2]), device = device).bool())
                # mask for the x-to-context part of the scores: [b, 1, n, context_len]
                cross_mask = q_mask[:, None, :, None] * k_mask[:, None, None, :]
            else: cross_mask = torch.empty(0, dtype=self_mask.dtype, device=device)
            # same key order as `kv_input`: x first, then context
            input_mask = torch.cat([self_mask, cross_mask], dim=-1)
        
        # classic scaled dot-product attention
        dots = torch.einsum('bhid,bhjd->bhij', q * self.scale, k)
        
        # might need to tune MASK_VAL for fp16 to work
        if exists(input_mask):
            dots.masked_fill_(~input_mask, MASK_VAL)
            del input_mask

        if self.causal:
            # only the first n key columns (the x part) get the strictly-upper-triangular mask
            i, j = torch.triu_indices(n, n, 1)
            dots[:,:,i,j] = MASK_VAL

        attn = F.softmax(dots, -1)
        if self.store_attention: # and not self.training
            self.attention = attn.detach().cpu()
        attn = self.dropout(attn)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        return out

    def _init(self):
        # Xavier-uniform weights, zero biases
        [nn.init.xavier_uniform_(w) for w in [self.to_q.weight, self.to_kv.weight, self.to_out.weight]]
        if getattr(self.to_q, 'bias', None) is not None: nn.init.constant_(self.to_q.bias, 0)
        if getattr(self.to_kv, 'bias', None) is not None: nn.init.constant_(self.to_kv.bias, 0)
        nn.init.constant_(self.to_out.bias, 0)

In [25]:
#export
class AdditiveInProj(Module):
    """Computes q from `x`, and k, v from `x` concatenated with the [optional] context"""
    def __init__(self, d_model:int, bias:bool=False):
        # three independent d_model -> d_model projections, created in q, k, v order
        self.to_q, self.to_k, self.to_v = (nn.Linear(d_model, d_model, bias=bias) for _ in range(3))
    def forward(self, x, context=None):
        b, _, d = x.size()
        if context is None:
            # zero-length context: the concatenation below leaves `x` unchanged
            context = torch.empty(b, 0, d, dtype=x.dtype, device=x.device)
        # keys/values see the input sequence followed by the context
        kv_input = torch.cat([x, context], dim=-2)
        return self.to_q(x), self.to_k(kv_input), self.to_v(kv_input)

In [26]:
# Without a context the additive projection behaves like plain self-attention.
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
proj = AdditiveInProj(d)
q1, k1, v1 = proj(x)
assert all(t.size() == (bs, sl, d) for t in (q1, k1, v1))
q1.shape, k1.shape, v1.shape

(torch.Size([4, 128, 64]), torch.Size([4, 128, 64]), torch.Size([4, 128, 64]))

In [27]:
# With a context, keys/values span both sequences: length = len(x) + len(context).
q2, k2, v2 = proj(x, context)
assert q2.size() == (bs, sl, d)
assert k2.size() == (bs, x.size(1)+context.size(1), d) and v2.size() == k2.size()
assert all_equal(q1, q2)        # queries do not depend on the context
assert not all_equal(k1, k2)    # keys do
q2.shape, k2.shape, v2.shape

(torch.Size([4, 128, 64]), torch.Size([4, 240, 64]), torch.Size([4, 240, 64]))

In [28]:
#export
class AdditiveAttention(Attention):
    """
    Additive attention module: keys and values are computed from `x` concatenated
    with `context` along the sequence axis (see `AdditiveInProj`), so each query
    attends to its own sequence and to the context within a single softmax.

    NOTE(review): with `causal=True`, `ScaledDotProdAttention` builds its triangular
    mask over the whole `[q_len, q_len + context_len]` score matrix, whereas the
    reference `DecoderAttention` masks only the first `q_len` key columns - confirm
    which behaviour is intended.
    """
    def __init__(self, 
                 d_model:int, 
                 n_heads:int = 8, 
                 causal:bool = False,
                 dropout:float=0.1,
                 out_dropout:float=None,
                 bias:bool=True,
                 store_attention:bool=False):
        store_attr('causal, n_heads, bias')
        out_dropout = ifnone(out_dropout, dropout)
        self.in_proj = AdditiveInProj(d_model, bias=bias)
        self.attn = ScaledDotProdAttention(d_model, n_heads, causal=causal,
                                           dropout=dropout, store_attention=store_attention)
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)
        self.dropout = nn.Dropout(out_dropout)
        self._init()

    def forward(self, x, context = None, mask = None, context_mask = None):
        """
        x:            [bs, q_len, d_model] input
        context:      optional [bs, c_len, d_model], appended to `x` for keys/values
        mask:         optional boolean [bs, q_len], False at positions to ignore
        context_mask: optional boolean [bs, c_len], False at positions to ignore
        """
        # `_make_input_mask` is not given `x`, so the default (all True) query mask
        # has to be materialised here, where the query length is known.
        if exists(context_mask) and not exists(mask):
            mask = torch.ones(x.shape[:2], dtype=torch.bool, device=x.device)
        return super().forward(x, context=context, mask=mask, context_mask=context_mask)

    def _make_input_mask(self, mask, context_mask, context, device):
        """
        Builds the boolean mask for scores over the keys `[x ; context]`:
        `[bs, 1, q_len, q_len]` without a context, `[bs, 1, q_len, q_len + c_len]` with one.
        Returns None if both `mask` and `context_mask` are None.
        """
        if not any(map(exists, (mask, context_mask))): return None
        if not exists(mask):
            raise ValueError('`mask` is needed to size the self-attention part of the mask; '
                             'call the module through `forward`, which supplies a default')
        q_mask = mask
        self_mask = q_mask[:, None, :, None] * q_mask[:, None, None, :]
        # no context (None or zero-length): keys are just `x`
        if not exists(context) or context.size(-2) == 0: return self_mask

        if exists(context_mask): k_mask = context_mask
        else: k_mask = torch.ones(context.shape[:2], dtype=torch.bool, device=device)
        cross_mask = q_mask[:, None, :, None] * k_mask[:, None, None, :]
        # same key order as `AdditiveInProj`: x first, then context
        return torch.cat([self_mask, cross_mask], dim=-1)

In [29]:
# Smoke test without a context: the output has the same shape as the input.
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
attn = AdditiveAttention(d)
out = attn(x)
assert out.size() == (bs, sl, d)
out.shape

torch.Size([4, 128, 64])

## Shared QK Attention

In [30]:
#export
class SharedAttnInProj(Module):
    """Computes the shared query/key projection `qk` from `x`, and `v` from [optional] context"""
    def __init__(self, d_model:int, bias:bool=False):
        # one projection serves as both query and key
        self.to_qk = nn.Linear(d_model, d_model, bias=bias)
        self.to_v = nn.Linear(d_model, d_model, bias=bias)
    def forward(self, x, context=None):
        # values fall back to `x` when no context is given
        v_source = x if context is None else context
        return self.to_qk(x), self.to_v(v_source)

In [224]:
class SharedQKAttention(Module):
    """
    Attention module with a shared query/key projection

    Queries are the `to_qk` projection of `x`; keys are the same vectors L2-normalised
    along the feature axis. A position is not allowed to attend to itself, except
    when it has no other permitted target.

    `out_dropout` defaults to `dropout` when not given. The constructor's `mask` is
    only stored; the masks actually used are the ones passed to `forward`.
    """
    def __init__(self, 
                 d_model:int, 
                 n_heads:int = 8, 
                 causal:bool = False,
                 mask:Tensor = None,
                 dropout:float=0.1,
                 out_dropout:float=None,
                 bias:bool=True,
                 store_attention:bool=False):
        store_attr('causal, mask, n_heads, bias')
        out_dropout = ifnone(out_dropout, dropout)
        self.in_proj = SharedAttnInProj(d_model, bias=bias)
        self.attn = ScaledDotProdAttention(d_model, n_heads, causal=causal,
                                           dropout=dropout, store_attention=store_attention)
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)
        self.dropout = nn.Dropout(out_dropout)
        self._init()

    def forward(self, x, context = None, mask = None, context_mask = None):
        """
        x:            [bs, seq_len, d_model] input (source of queries and keys)
        context:      optional source of values (defaults to `x`)
        mask:         optional boolean [bs, seq_len], False at positions to ignore
        context_mask: optional boolean key mask, only used when `context` is given
        """
        q, v = self.in_proj(x, context)
        # keys share the query projection, normalised to unit length
        k = F.normalize(q, 2, dim=-1).type_as(q)

        input_mask = self._make_input_mask(mask, context_mask, context, x.device)
        final_mask = self._exclude_self(input_mask, x.shape[-2], x.device)

        out = self.attn(q, k, v, final_mask)

        out = self.out_proj(out)
        return self.dropout(out)

    def _exclude_self(self, input_mask, n, device):
        """
        Returns a boolean mask (True = may attend) equal to `input_mask` with each
        position's own key removed. Rows that would be left with no target at all
        (e.g. the first token under causal masking) keep their original permissions,
        so such a row can still attend to itself.
        """
        self_mask = torch.eye(n, dtype=torch.bool, device=device)[None, None]   # [1, 1, n, n]
        allowed = input_mask.bool() if exists(input_mask) else torch.ones_like(self_mask)
        if self.causal:
            # the attention layer hides future keys; account for that when looking for "other" targets
            allowed = allowed & torch.ones((n, n), dtype=torch.bool, device=device).tril()
        others = allowed & ~self_mask
        no_other = ~others.any(dim=-1, keepdim=True)
        return others | (allowed & no_other)

    def _init(self):
        "Xavier-uniform for weight matrices; zeros for 1-d parameters when `bias` is on"
        for p in self.parameters():
            if p.dim() > 1: nn.init.xavier_uniform_(p)
            elif self.bias and p.dim() == 1: nn.init.constant_(p, 0)

    def _make_input_mask(self, mask, context_mask, context, device):
        """
        Combines the query and key padding masks into one boolean mask that
        broadcasts against the `[bs, heads, q_len, k_len]` attention scores.
        Returns None if both `mask` and `context_mask` are None.
        """
        if not any(map(exists, (mask, context_mask))): return None

        # A missing mask means "everything visible"; a length-1 axis of ones
        # broadcasts to any sequence length.
        b = (mask if exists(mask) else context_mask).shape[0]
        all_visible = torch.ones((b, 1), dtype=torch.bool, device=device)

        q_mask = mask if exists(mask) else all_visible
        if not exists(context):
            k_mask = q_mask
        else:
            k_mask = context_mask if exists(context_mask) else all_visible

        q_mask = rearrange(q_mask, 'b i -> b () i ()')
        k_mask = rearrange(k_mask, 'b j -> b () () j')
        return q_mask * k_mask

In [225]:
# Smoke test: the output has the same shape as the input.
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
attn = SharedQKAttention(d)
out = attn(x)
assert out.size() == (bs, sl, d)
out.shape

torch.Size([4, 128, 64])

In [116]:
np.equal(q, k)

tensor([[1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.uint8)

In [None]:
# NOTE(review): scratch cell, never executed (`In [None]`). `self` does not exist at
# notebook level, and `q` is 2-d here while the rearrange pattern expects 3 axes.
q = torch.rand((4,8))

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.n_heads), (q, k, v))

# classic dot-product attention
dots = torch.einsum('bhid,bhjd->bhij', q*self.scale, k)

In [107]:
q[:, None].size(), q[:, None][0]

(torch.Size([4, 1, 8]),
 tensor([[0.6418, 0.9309, 0.0875, 0.4311, 0.7610, 0.5909, 0.5121, 0.8897]]))

In [108]:
k[None, :].size(), k[None, :][0]

(torch.Size([1, 4, 8]),
 tensor([[0.6418, 0.9309, 0.0875, 0.4311, 0.7610, 0.5909, 0.5121, 0.8897],
         [0.8414, 0.9780, 0.0447, 0.7468, 0.6824, 0.6221, 0.3255, 0.1995],
         [0.5803, 0.7266, 0.5935, 0.8492, 0.6448, 0.8659, 0.1144, 0.8912],
         [0.5869, 0.2278, 0.8895, 0.7219, 0.8223, 0.7393, 0.4513, 0.1532]]))

In [110]:
q[:, None][0] == k[None, :][0]

tensor([[ True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False]])

In [131]:
# Scratch: compare a tensor with itself along its last axis via broadcasting.
q = torch.rand((4,4,12))
k = q

# [4, 4, 12, 1] == [4, 4, 1, 12] -> True only where the two last-axis indices coincide
self_mask = q[:, :, :, None] == k[:, :, None, :]
self_mask.shape      # [4, 4, 12, 12] here

torch.Size([4, 4, 12, 12])

In [132]:
q.size(), k.size()

(torch.Size([4, 4, 12]), torch.Size([4, 4, 12]))

In [133]:
q[:, :, :, None][0,0,:10,0], k[:, :, None, :][0,0,0,:10]

(tensor([0.6686, 0.8243, 0.4995, 0.7482, 0.0835, 0.0371, 0.5566, 0.4160, 0.2992,
         0.4209]),
 tensor([0.6686, 0.8243, 0.4995, 0.7482, 0.0835, 0.0371, 0.5566, 0.4160, 0.2992,
         0.4209]))

In [134]:
self_mask[0,0,:5,:5]

tensor([[ True, False, False, False, False],
        [False,  True, False, False, False],
        [False, False,  True, False, False],
        [False, False, False,  True, False],
        [False, False, False, False,  True]])

In [178]:
# Scratch inputs for the shared-QK experiments: batch, seq len, per-head dim, heads.
# NOTE(review): this rebinds the notebook-level `d` (64 in the test cells above) to 2.
b, n, d, h = 4,5,2,8

q = torch.rand((b,n,d*h))
k = q   # shared QK: keys are the very same tensor as the queries
v = torch.rand((b,n,d*h))

q.size()

torch.Size([4, 5, 16])

In [193]:
# Split the model dimension into heads: [b, n, h*d] -> [b, h, n, d]
qm, km, vm, = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
#q, k, v, = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=8), (q, k, v))

# classic dot-product attention
# dots = torch.einsum('bhid,bhjd->bhij', q*0.01, k)
# 0.01 is an arbitrary stand-in for the attention scale in this experiment
dots = torch.einsum('bhid,bhjd->bhij', qm*0.01, km)
dots.size()

torch.Size([4, 8, 5, 5])

In [196]:
# Same computation as the previous cell, re-run to get a fresh `dots` of shape [b, h, n, n].
qm, km, vm, = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
dots = torch.einsum('bhid,bhjd->bhij', qm*0.01, km)
dots.size()

torch.Size([4, 8, 5, 5])

In [198]:
# Using the same index vector on the last two axes addresses the diagonal of every
# [n, n] score matrix, i.e. each position's score against itself.
i = torch.arange(n)
dots[:, :, i, i] = -float('inf')   # mutates `dots` in place
dots[0,0,:5,:5]

tensor([[  -inf, 0.0048, 0.0117, 0.0098, 0.0084],
        [0.0048,   -inf, 0.0046, 0.0042, 0.0033],
        [0.0117, 0.0046,   -inf, 0.0091, 0.0081],
        [0.0098, 0.0042, 0.0091,   -inf, 0.0063],
        [0.0084, 0.0033, 0.0081, 0.0063,   -inf]])

tensor([[  -inf, 0.0048, 0.0117, 0.0098, 0.0084],
        [0.0048,   -inf, 0.0046, 0.0042, 0.0033],
        [0.0117, 0.0046,   -inf, 0.0091, 0.0081],
        [0.0098, 0.0042, 0.0091,   -inf, 0.0063],
        [0.0084, 0.0033, 0.0081, 0.0063,   -inf]])

In [183]:
q.size(), k.size()

(torch.Size([4, 5, 16]), torch.Size([4, 5, 16]))

In [180]:
# NOTE(review): with q of shape [b, n, h*d] this compares entries along the *feature*
# axis (result [b, n, h*d, h*d]), not sequence positions against each other.
self_mask = q[:, :, :, None] == k[:, :, None, :]
self_mask.size(), self_mask[0,0,:5,:5]

(torch.Size([4, 5, 16, 16]), tensor([[ True, False, False, False, False],
         [False,  True, False, False, False],
         [False, False,  True, False, False],
         [False, False, False,  True, False],
         [False, False, False, False,  True]]))

In [182]:
qm[:, :, None].size(), km[:, None, :].size()

(torch.Size([4, 8, 1, 5, 2]), torch.Size([4, 1, 8, 5, 2]))

In [181]:
# [b, h, 1, n, d] == [b, 1, h, n, d] -> [b, h, h, n, d]: this compares heads with
# each other, elementwise; it is not a position-vs-position mask.
mself_mask = qm[:, :, None] == km[:, None, :]
mself_mask.size(), mself_mask[0,0,:5,:5]

(torch.Size([4, 8, 8, 5, 2]), tensor([[[ True,  True],
          [ True,  True],
          [ True,  True],
          [ True,  True],
          [ True,  True]],
 
         [[False, False],
          [False, False],
          [False, False],
          [False, False],
          [False, False]],
 
         [[False, False],
          [False, False],
          [False, False],
          [False, False],
          [False, False]],
 
         [[False, False],
          [False, False],
          [False, False],
          [False, False],
          [False, False]],
 
         [[False, False],
          [False, False],
          [False, False],
          [False, False],
          [False, False]]]))

In [174]:
q.size(), q[:, :, :, None].size()

(torch.Size([4, 5, 16]), torch.Size([4, 5, 16, 1]))

In [None]:
# NOTE(review): never executed (`In [None]`); the `mself_mask` built earlier in this
# notebook is 5-d, which would not fit this 3-axis pattern - confirm before keeping.
mself_mask = rearrange(mself_mask, 'b n (h d) -> b h n d', h=8)

In [153]:
q.size(), q[:, :, :, None].size()

(torch.Size([4, 8, 4, 4]), torch.Size([4, 8, 4, 1, 4]))

In [155]:
k.size(), k[:, :, None, :].size()

(torch.Size([4, 8, 4, 4]), torch.Size([4, 8, 1, 4, 4]))

In [152]:
# NOTE(review): the recorded output below lists three sizes, so it appears to come
# from an earlier version of this cell - re-run to refresh.
self_mask = q[:, :, :, None] == k[:, :, None, :]
self_mask.size()

(torch.Size([4, 8, 4, 1, 4]),
 torch.Size([4, 8, 1, 4, 4]),
 torch.Size([4, 8, 4, 4, 4]))

In [148]:
# Inserting the new axis right after the head axis (instead of after the sequence
# axis) yields a head-vs-head comparison.
self_mask = q[:, :, None] == k[:, None, :]
self_mask.size()

torch.Size([4, 8, 8, 4, 4])

In [149]:
# `torch.ne` variant: True where the broadcast entries differ (the inverse of the `==` form).
self_mask = torch.ne(q.unsqueeze(-1), k.unsqueeze(-2))
self_mask.size()

torch.Size([4, 8, 4, 4, 4])

In [141]:
self_mask[0,0,:5,:5]

(torch.Size([4, 8, 4, 4, 4]), tensor([[[ True,  True,  True,  True],
          [False, False, False, False],
          [False, False, False, False],
          [False, False, False, False]],
 
         [[False, False, False, False],
          [ True,  True,  True,  True],
          [False, False, False, False],
          [False, False, False, False]],
 
         [[False, False, False, False],
          [False, False, False, False],
          [ True,  True,  True,  True],
          [False, False, False, False]],
 
         [[False, False, False, False],
          [False, False, False, False],
          [False, False, False, False],
          [ True,  True,  True,  True]]]))

In [94]:
q.size(), k.size(), q

(torch.Size([4, 8]),
 torch.Size([4, 8]),
 tensor([[0.8887, 0.3320, 0.5010, 0.7032, 0.8009, 0.7979, 0.6450, 0.5553],
         [0.0773, 0.2908, 0.9349, 0.5278, 0.2663, 0.2752, 0.1811, 0.1575],
         [0.0947, 0.6816, 0.0076, 0.3070, 0.6063, 0.3572, 0.3927, 0.9506],
         [0.6972, 0.4863, 0.6777, 0.8588, 0.3077, 0.9452, 0.4496, 0.4202]]))

In [96]:
# Plain 2-d case: row-by-row dot products, [4, 8] @ [8, 4] -> [4, 4].
# The recorded result is symmetric, presumably because `k` held the same values as `q`.
dots = torch.matmul(q, k.transpose(1,0))
dots.size(), dots

(torch.Size([4, 4]), tensor([[3.6479, 1.6419, 2.0818, 3.2484],
         [1.6419, 1.4474, 0.8552, 1.7718],
         [2.0818, 0.8552, 2.1208, 1.7664],
         [3.2484, 1.7718, 1.7664, 3.2862]]))

In [97]:
q.size(), q #, k

(torch.Size([4, 8]),
 tensor([[0.8887, 0.3320, 0.5010, 0.7032, 0.8009, 0.7979, 0.6450, 0.5553],
         [0.0773, 0.2908, 0.9349, 0.5278, 0.2663, 0.2752, 0.1811, 0.1575],
         [0.0947, 0.6816, 0.0076, 0.3070, 0.6063, 0.3572, 0.3927, 0.9506],
         [0.6972, 0.4863, 0.6777, 0.8588, 0.3077, 0.9452, 0.4496, 0.4202]]))

In [69]:
self_mask.size(), self_mask[0]

(torch.Size([4, 4, 8]),
 tensor([[ True,  True,  True,  True,  True,  True,  True,  True],
         [False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False]]))

In [68]:
self_mask.squeeze(0).size(), self_mask.squeeze(0)

(torch.Size([4, 4, 8]),
 tensor([[[ True,  True,  True,  True,  True,  True,  True,  True],
          [False, False, False, False, False, False, False, False],
          [False, False, False, False, False, False, False, False],
          [False, False, False, False, False, False, False, False]],
 
         [[False, False, False, False, False, False, False, False],
          [ True,  True,  True,  True,  True,  True,  True,  True],
          [False, False, False, False, False, False, False, False],
          [False, False, False, False, False, False, False, False]],
 
         [[False, False, False, False, False, False, False, False],
          [False, False, False, False, False, False, False, False],
          [ True,  True,  True,  True,  True,  True,  True,  True],
          [False, False, False, False, False, False, False, False]],
 
         [[False, False, False, False, False, False, False, False],
          [False, False, False, False, False, False, False, False],
          [Fals

## LSH Attention

LSH attention from [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451). Based on [lucidrains/reformer-pytorch](https://github.com/lucidrains/reformer-pytorch/), but simplified and refactored.

In [6]:
#export
class LSHAttention(Module):
    """
    Additive attention module: 
    """
    def __init__(self,
                 dropout = 0.,                       # attention matrix dropout
                 bucket_size = 64,                   # at least 64 suggested in trax
                 n_hashes = 8,                       # paper suggests 8
                 causal = False,
                 allow_duplicate_attention = False,  # as in the paper
                 attend_across_buckets = False,      # as in the paper
                 drop_for_hash_rate = 0.0,           # unsure of default, not mentioned in paper
                 return_attn = False,
                 **kwargs):
        """Validates the two dropout rates, stores the remaining arguments as attributes
        and creates the dropout layers and the bucket cache."""
        if any(rate >= 1.0 for rate in (dropout, drop_for_hash_rate)):
            raise ValueError('Dropout rates must be lower than 1.')

        # fastcore - store attributes; the two rates are excluded because they
        # are turned into nn.Dropout layers right below
        store_attr(but=['dropout', 'drop_for_hash_rate'])
        self.dropout = nn.Dropout(dropout)
        self.dropout_for_hash = nn.Dropout(drop_for_hash_rate)
        # cache buckets for reversible network, required to make Reformer work at depth
        self._cache = {}

    @cache_method_decorator('_cache', 'buckets', reexecute=True)
    def hash_vectors(self, n_buckets, vecs):
        """
        Assigns every vector in `vecs` ([bs, sl, dim]) to one of `n_buckets` buckets
        in each of `self.n_hashes` hashing rounds, using random rotations.
        Returns bucket ids of shape [bs, n_hashes*sl]; ids of round `r` are offset
        by `r * n_buckets` so that different rounds never share an id.
        """
        # An even number of buckets is required: each rotation yields a +/- pair.
        assert n_buckets % 2 == 0

        bs, _, dim = vecs.shape
        device = vecs.device

        # One random rotation matrix per hash round, copied across the batch
        # (see exploration notebook for details).
        rotations = torch.randn((dim, self.n_hashes, n_buckets // 2), device=device)
        rotations = repeat(rotations, 'd nh nb -> bs d nh nb', bs=bs)   # [bs, dim, n_hashes, n_buckets//2]

        dropped_vecs = self.dropout_for_hash(vecs)                       # [bs, sl, dim]
        rotated = torch.einsum('bsd,bdhn->bhsn', dropped_vecs, rotations)  # [bs, n_hashes, sl, n_buckets//2]

        # Bucket id = argmax over the concatenation of the rotated vector and its negation.
        rotated = torch.cat([rotated, -rotated], dim=-1)                 # [bs, n_hashes, sl, n_buckets]
        buckets = rotated.argmax(dim=-1)                                 # [bs, n_hashes, sl]

        # Shift each round's ids into its own range, then lay the rounds out
        # one after another along the last axis.
        round_offsets = torch.arange(self.n_hashes, device=device) * n_buckets
        round_offsets = rearrange(round_offsets, 'nh -> 1 nh 1')         # [1, n_hashes, 1]
        return rearrange(buckets + round_offsets, 'bs nh sl -> bs (nh sl)')  # [bs, n_hashes*sl]

    def forward(self, qk, v, input_mask = None, **kwargs):
        batch_size, seqlen, dim, device = *qk.shape, qk.device
        #print(qk.device)

        # caching
        is_reverse = kwargs.pop('_reverse', False)
        depth = kwargs.pop('_depth', None)
        
        # We will have an even number of buckets, and our attention chunks needs to fit completely within a seqlen
        assert seqlen % (self.bucket_size * 2) == 0, f'Sequence length ({seqlen}) needs to be divisible by target bucket size  x 2 - {self.bucket_size * 2}'
        
        # get the hash buckets for our qk input vectors
        n_buckets = seqlen // self.bucket_size
        buckets = self.hash_vectors(n_buckets, qk, key_namespace=depth, fetch=is_reverse, set_cache=self.training)

        # We use the same vector as both a query and a key.
        assert int(buckets.shape[1]) == self.n_hashes * seqlen
        
        # Create an index that reflexts both bucket id and sequence id. This let's us sort qk according 
        # to both simultaneously. Repeated across the batch dimension.
        ticker = repeat(torch.arange((self.n_hashes * seqlen),device=device), 'l -> bs l', bs=batch_size)
        buckets_and_t = seqlen * buckets + (ticker % seqlen) 
        buckets_and_t = buckets_and_t.detach()                # [bs, seqlen*n_hashes]

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = sort_key_val(buckets_and_t, ticker, dim=-1)  # [bs, seqlen*n_hashes]
        _, undo_sort = sticker.sort(dim=-1)                                    # indexes to undo sortings
        del ticker

        sbuckets_and_t = sbuckets_and_t.detach()   # no need to store gradiens for indexes
        sticker = sticker.detach()
        undo_sort = undo_sort.detach()

        st = (sticker % seqlen)             # index of [0..seqlen-1] for each hash round
        sqk = batched_index_select(qk, st)  # get the sorted qk, [bs, seqlen*n_hashes, dim]
        sv = batched_index_select(v, st)    # get the sorted v, [bs, seqlen*n_hashes, dim] 

        # Reshape to include a n_chunks axis.
        n_chunks = self.n_hashes * n_buckets
        bq_t = bkv_t = rearrange(st, 'bs (n s) -> bs n s', n=n_chunks) # [bs, n_chunks, chunk_size]
        bqk = rearrange(sqk, 'bs (n s) d -> bs n s d', n=n_chunks)     # [bs, n_chunks, chunk_size, dim]
        bv = rearrange(sv, 'bs (n s) d -> bs n s d', n=n_chunks)       # [bs, n_chunks, chunk_size, dim]

        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = F.normalize(bqk, p=2, dim=-1).type_as(bq)

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bucket, so this increases the chances of attending to relevant items.
        # Note: no look_back for queries

        bk = look_one_back(bk)        # [bs, n_chunks, chunk_size*2, dim]
        bv = look_one_back(bv)        # [bs, n_chunks, chunk_size*2, dim]
        bkv_t = look_one_back(bkv_t)

        # Dot-product attention.
        dots = torch.einsum('bnsd,bnzd->bnsz', 
                    bq,                  # [bs, n_chunks, chunk_size, dim]
                    bk                   # [bs, n_chunks, chunk_size*2, dim]
                   ) * (dim ** -0.5)     # dots: [bs, n_chunks, chunk_size, chunk_size*2]
        masked_value = max_neg_value(dots)

        # Input mask for padding in variable lengthed sequences
        if input_mask is not None:
            input_mask = F.pad(input_mask, (0, seqlen - input_mask.shape[1]), value=True)
            mq = input_mask.gather(1, st).reshape((batch_size, n_chunks, -1))
            mkv = look_one_back(mq)
            mask = mq[:, :, :, None] * mkv[:, :, None, :]
            dots.masked_fill_(~mask, masked_value)
            del mask

        # Causal masking
        if self.causal:
            mask = bq_t[:, :, :, None] < bkv_t[:, :, None, :]
            dots.masked_fill_(mask, masked_value)
            del mask

        # Mask out attention to self except when no other targets are available.
        self_mask = bq_t[:, :, :, None] == bkv_t[:, :, None, :]
        dots.masked_fill_(self_mask, SELF_ATTN_MASK_VAL)
        del self_mask

        # Mask out attention to other hash buckets.
        if not self.attend_across_buckets:
            bq_buckets = bkv_buckets = torch.reshape(sbuckets_and_t // seqlen, (batch_size, n_chunks, -1))
            bkv_buckets = look_one_back(bkv_buckets)
            bucket_mask = bq_buckets[:, :, :, None] != bkv_buckets[:, :, None, :]
            dots.masked_fill_(bucket_mask, masked_value)
            del bucket_mask

        # Don't double-count query-key pairs across multiple rounds of hashing.
        # There are two possible strategies here. (1) The default is to count how
        # many times a query-key pair is repeated, and to lower its log-prob
        # correspondingly at each repetition.
        
        if not self.allow_duplicate_attention:
            locs1 = undo_sort // bq_t.shape[-1]
            locs2 = (locs1 + 1) % n_chunks
            if not self.attend_across_buckets:
                locs1 = buckets * n_chunks + locs1
                locs2 = buckets * n_chunks + locs2
            locs = torch.cat([
                torch.reshape(locs1, (batch_size, self.n_hashes, seqlen)),
                torch.reshape(locs2, (batch_size, self.n_hashes, seqlen)),
            ], 1).permute((0, 2, 1))

            slocs = batched_index_select(locs, st)
            b_locs = torch.reshape(slocs, (batch_size, n_chunks, -1, 2 * self.n_hashes))

            b_locs1 = b_locs[:, :, :, None, :self.n_hashes]

            bq_locs = b_locs1.expand(b_locs.shape[:3] + (2, self.n_hashes))
            bq_locs = torch.reshape(bq_locs, b_locs.shape)
            bkv_locs = look_one_back(b_locs)

            dup_counts = (bq_locs[:, :, :, None, :] == bkv_locs[:, :, None, :, :])
            # for memory considerations, chunk summation of last dimension for counting duplicates
            dup_counts = chunked_sum(dup_counts, chunks=(self.n_hashes * batch_size))
            dup_counts = dup_counts.detach()
            assert dup_counts.shape == dots.shape
            dots = dots - torch.log(dup_counts + 1e-9)
            del dup_counts

        # Softmax.
        dots_logsumexp = torch.logsumexp(dots, dim=-1, keepdim=True)
        dots = torch.exp(dots - dots_logsumexp).type_as(dots)
        dropped_dots = self.dropout(dots)
        
        # calculate self-attention (attn * values)
        bo = torch.einsum('bnsz,bnzd->bnsd', 
                          dropped_dots,      # [bs, n_chunks, chunk_size, chunk_size*2]
                          bv)                # [bs, n_chunks, chunk_size*2, dim]    
                                             # bo: [bs, n_chunks, chunk_size, dim]
        
        # unchunk, unsort and reshape self-attention
        so = rearrange(bo, 'b n s d -> b (n s) d')                     # [bs, seqlen*n_hashes, dim]
        o = batched_index_select(so, undo_sort)                        # [bs, seqlen*n_hashes, dim]
        o = rearrange(o, 'b (nh sl) d -> b nh sl d', nh=self.n_hashes) # [bs, n_hashes, seqlen, dim]
        
        # unchunk, unsort and reshape logits
        slogits = rearrange(dots_logsumexp, 'bs n s 1 -> bs (n s)')              # [bs, seqlen*n_hashes]
        logits = slogits.gather(1, undo_sort)                                    # [bs, seqlen*n_hashes]
        logits = rearrange(logits, 'bs (nr sl) -> bs nr sl 1', nr=self.n_hashes) # [bs, n_hashes, seqlen, 1]
        
        # average probabilites across hash rounds (dim 1) and get weighted attention
        probs = torch.exp(logits - torch.logsumexp(logits, dim=1, keepdim=True)) # [bs, n_rounds, seqlen, 1]
        out = torch.sum(o * probs, dim=1)                                        # [bs, seqlen, dim]

        # return unsorted attention weights - empty otherwise
        attn = torch.empty(0, device=device)
        if self.return_attn:
            attn_unsort = ((bq_t * seqlen)[:, :, :, None] + bkv_t[:, :, None, :])
            attn_unsort = attn_unsort.view(batch_size * self.n_hashes, -1).long()
            unsorted_dots = torch.zeros(batch_size * self.n_hashes, seqlen * seqlen, device=device)
            unsorted_dots.scatter_add_(1, attn_unsort, dots.view_as(attn_unsort))
            del attn_unsort
            unsorted_dots = unsorted_dots.reshape(batch_size, self.n_hashes, seqlen, seqlen)
            attn = torch.sum(unsorted_dots * probs, dim=1)

        # return output, attention matrix, and bucket distribution
        return out, attn, buckets

Test the attention layer. Note: `d_model` is inferred from the input. Assumes shared keys and queries.

In [8]:
# Shared query/key tensor and a value tensor of the same shape.
qk = torch.randn(bs, sl, d)
v = torch.randn(bs, sl, d)

lsh_attn = LSHAttention()

# forward returns (output, attention matrix, buckets); only the output is checked here.
out, _, _ = lsh_attn(qk, v)
assert out.size() == (bs, sl, d)

## Reformer Attention

In [None]:
#export
#TODO: proto to be implemented...
class ReformerAttention(Module):
    """
    Reformer attention container (prototype)

    Intended to switch between FullSharedQKAttention and LSHAttention.

    The input projection (`in_proj`) and the attention core (`attn`) are not
    created yet: only the output projection and the output dropout are set up.
    Until both submodules are assigned, `forward` raises `NotImplementedError`.

    Arguments:
        d_model: in/out feature size of the output projection
        n_heads: stored on the instance; not used by this class yet
        causal: stored on the instance; not used by this class yet
        mask: stored on the instance; not used by this class yet
        dropout: fallback value for `out_dropout` when that is None
        out_dropout: dropout probability applied after the output projection
        bias: whether `out_proj` has a bias; also gates zero-init of 1-d parameters
        store_attention: accepted but currently unused
    """
    def __init__(self, 
                 d_model:int, 
                 n_heads:int = 8, 
                 causal:bool = False,
                 mask:Tensor = None,
                 dropout:float=0.1,
                 out_dropout:float=None,
                 bias:bool=True,
                 store_attention:bool=False):
        store_attr('causal, mask, n_heads, bias')
        out_dropout = ifnone(out_dropout, dropout)
        # TODO: create `self.in_proj` and `self.attn` here. The placeholder wiring
        # below is kept for reference and is intentionally disabled:
        #   self.in_proj = AttnInProj(d_model, bias=bias)
        #   self.attn = ScaledDotProdAttention(d_model, n_heads, causal=causal,
        #                                      dropout=dropout, store_attention=store_attention)
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)
        self.dropout = nn.Dropout(out_dropout)
        self._init()

    def forward(self, x, context = None, mask = None, context_mask = None):
        """
        Project `x` (and optional `context`) to q, k, v, run the attention core,
        then apply the output projection followed by dropout.

        Raises:
            NotImplementedError: if `in_proj` or `attn` has not been set on the instance.
        """
        # Look the submodules up defensively: `__init__` does not create them yet,
        # and a bare attribute access would fail with an opaque AttributeError.
        in_proj = getattr(self, 'in_proj', None)
        attn = getattr(self, 'attn', None)
        if in_proj is None or attn is None:
            raise NotImplementedError(
                "ReformerAttention is a prototype: `in_proj` and `attn` must be set "
                "before calling forward")

        q, k, v = in_proj(x, context)
        
        out = attn(q, k, v, mask, context_mask)
        
        out = self.out_proj(out)
        return self.dropout(out)
        
    def _init(self):
        """Xavier-uniform init for parameters with more than one dim; zero 1-d parameters when `bias` is set."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            elif self.bias and p.dim() == 1:
                nn.init.constant_(p, 0)

In [None]:
#hide
# Run nbdev's notebook-to-script export; it reports each converted notebook.
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_layers.ipynb.
Converted 02_attention.ipynb.
Converted 03_transformer.ipynb.
Converted 04_reformer.ipynb.
Converted index.ipynb.
