In [None]:
#default_exp attention.nystrom

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#hide
%load_ext autoreload
%autoreload 2

In [None]:
#export
import torch
from torch import nn, einsum
import torch.nn.functional as F
from fastai.basics import *

from functools import partial, reduce
from inspect import isfunction
from operator import mul
from copy import deepcopy
import math
from torch import Tensor
from typing import Tuple

from einops import rearrange, repeat

from transformers_sandbox.core import *
from transformers_sandbox.layers import *
from transformers_sandbox.attention.core import *

# Nystrom Attention
> Memory-efficient attention computation based on the Nyström approximation

Paper: https://arxiv.org/abs/2102.03902

Authors' code: https://github.com/mlpen/Nystromformer

In [None]:
#export
class reshape2bhnd:
    "Callable splitting the feature dim into heads: b n (h d) -> b h n d"
    def __init__(self, h=8):
        # number of heads the last dimension is split into
        self.h = h

    def __call__(self, x):
        batch, seq_len, features = x.size()
        heads = self.h
        split = x.view(batch, seq_len, heads, features // heads)
        # move the head axis in front of the sequence axis
        return split.permute(0, 2, 1, 3)
    
def reshape2bnd(x):
    "Merges the head dim back into features: b h n d -> b n (h d)"
    batch, _, seq_len, _ = x.size()
    merged = x.permute(0, 2, 1, 3).contiguous()
    return merged.view(batch, seq_len, -1)

## Moore-Penrose iterative pseudoinverse

In [None]:
#export
def iter_pinv(A, n_iter=6):
    """Iteratively approximates the Moore-Penrose pseudoinverse of `A`.

    Iteration used in the Nystromformer paper:
        Z_{j+1} = 1/4 Z_j (13I - A Z_j (15I - A Z_j (7I - A Z_j)))
    starting from Z_0 = A^T / (||A||_1 * ||A||_inf).

    A: tensor of shape (..., n, n); leading dims are a batch and every matrix
       gets its own initial scaling.
    n_iter: number of iterations (more iterations -> closer approximation).
    Returns a tensor with the same shape, dtype and device as `A`.
    """
    # match A's dtype so half/double inputs are not mixed with float32
    I = torch.eye(A.size(-1), device=A.device, dtype=A.dtype)
    # ||A||_1 = max absolute column sum, ||A||_inf = max absolute row sum.
    # For a row-softmax A, ||A||_inf == 1 and this reduces to A^T / ||A||_1;
    # keeping both factors makes the initial guess converge for general A too.
    abs_A = A.abs()
    norm_1 = abs_A.sum(-2).max(-1).values
    norm_inf = abs_A.sum(-1).max(-1).values
    Z = A.transpose(-1, -2) / (norm_1 * norm_inf)[..., None, None]
    for _ in range(n_iter):
        AZ = A @ Z
        Z = 0.25 * Z @ (13*I - AZ @ (15*I - AZ @ (7*I - AZ)))
    return Z

In [None]:
#hide
# Sanity check of the initial guess z = x^T / max_col_sum used by `iter_pinv` on a row-softmax matrix.
x = torch.softmax(torch.randn(64,64), dim=-1)
n_cols = x.shape[-1]
I = torch.eye(n_cols)
col_norm = x.abs().sum(dim=-2).max(dim=-1).values  # largest absolute column sum
z = x.transpose(-2, -1) / col_norm[..., None, None]
# The Frobenius norm of the residual is not < 1 here (see output), so no assert on it.
residual = torch.norm(I - x @ z)
residual, np.linalg.det(x.numpy())

(tensor(7.8314), 0.0)

In [None]:
x = torch.softmax(torch.randn(4,2,64,64), dim=-1)
n_iter = 40
z = iter_pinv(x, n_iter)
# for a full-rank x the pseudoinverse is the inverse, so x @ z should be ~I
eye = torch.eye(x.size(-1))
xz = x @ z
max_err = (xz - eye).abs().max().item()
assert torch.allclose(xz, eye, atol=1e-4, rtol=1e-4), \
        f"Iterative pseudoinverse didn't converge in {n_iter} iterations (max abs error: {max_err:.2e})"

In [None]:
#skip
#hide
# Check that backprop through torch's `pinverse` on these softmax matrices yields no NaN gradients.
# A detached copy is used so `x` from the previous cell is not mutated (no requires_grad / .grad left on it).
x_grad = x.detach().clone().requires_grad_(True)
loss = x_grad.pinverse().sum()
loss.backward()
x_grad.grad.isnan().any()

tensor(False)

In [None]:
#skip
# %timeit x.pinverse()

In [None]:
#skip
# %timeit iter_pinv(x, 10)

## Nyström attention approximation

TODO:
* [ ] add masking
* [ ] more testcases

In [None]:
import pdb

In [None]:
# Interactive help lookup for the reference attention layer (development aid; remove before publishing).
?ScaledDotProdAttention

Init signature:
ScaledDotProdAttention(
    d_model,
    n_heads,
    causal=False,
    dropout=0.0,
    shared_qk=False,
    store_attention: bool = False,
)
Docstring:      Computes scaled dot-product attnetion given q, k, v
Init docstring: Initializes internal Module state, shared by both nn.Module and ScriptModule.
File:           /media/arto/work/dev/git/transformers_sandbox/transformers_sandbox/attention/core.py
Type:           PrePostInitMeta
Subclasses:     


In [None]:
#export
class NystromAttention(Module):
    """Computes attention using the Nyström approximation (Nystromformer, Xiong et al., 2021).

    `q`, `k`, `v` are `(bs, seq_len, d)` with `d` split into `n_heads` heads; the
    output is merged back to `(bs, seq_len, d)`.

    - If `seq_len <= n_landmarks`, exact softmax attention is computed, with
      `attn_mask` (True = keep) applied to the logits before the softmax.
    - Otherwise queries/keys are averaged over `n_landmarks` consecutive segments
      (landmarks) and attention is approximated as
      softmax(Q Kl^T) pinv(softmax(Ql Kl^T)) softmax(Ql K^T) V,
      with the pseudoinverse computed by `iter_pinv` using `pinv_n_iter` iterations.
      `seq_len` must be divisible by `n_landmarks`. `attn_mask` is not supported
      on this path yet and is ignored.
    - If `use_conv`, a per-head (depthwise) convolution of `v` along the sequence
      dim with kernel size `conv_kernel_size` is added to the output.

    `causal`, `dropout`, `store_attention` and extra `**kwargs` are accepted for
    interface compatibility but are not used by this implementation.
    """
    def __init__(self, d_model, n_heads=8, causal=False, n_landmarks=64,
                 store_attention:bool=False, use_conv=False, dropout=0.,
                 pinv_n_iter=6, conv_kernel_size=33, **kwargs):
        store_attr()
        self.scale = (d_model//n_heads)**-0.5
        if use_conv:
            # groups=n_heads -> one (conv_kernel_size x 1) filter per head, sliding over the sequence dim
            self.conv = nn.Conv2d(n_heads, n_heads,
                                  kernel_size=(conv_kernel_size, 1),
                                  padding=(conv_kernel_size//2, 0),
                                  groups=n_heads, bias=False)

    def forward(self, q, k, v, attn_mask=None):
        _, n, _ = q.size()
        q, k, v = map(reshape2bhnd(self.n_heads), (q, k, v))

        if n <= self.n_landmarks:
            # short sequence: exact attention is as cheap as the approximation
            out = self._exact_attention(q, k, v, attn_mask)
        else:
            out = self._landmark_attention(q, k, v)

        if self.use_conv:
            out = out + self.conv(v)
        return reshape2bnd(out)

    def _exact_attention(self, q, k, v, attn_mask=None):
        "Exact softmax attention on `(bs, h, n, dh)` tensors; masked logits are filled with `MASK_VAL` before softmax."
        dots = einsum('...nd, ...md -> ...nm', q*self.scale, k)
        if exists(attn_mask):
            # mask logits, not probabilities, so masked keys get ~zero weight after softmax
            dots = dots.masked_fill(~attn_mask, MASK_VAL)
        attn = F.softmax(dots, dim=-1)
        return einsum('...nm, ...md -> ...nd', attn, v)

    def _landmark_attention(self, q, k, v):
        "Nyström approximation of softmax attention using segment-mean landmarks on `(bs, h, n, dh)` tensors."
        bs, h, n, dh = q.size()
        l = self.n_landmarks
        if n % l != 0:
            raise ValueError(f"Sequence length ({n}) must be divisible by n_landmarks ({l})")
        # landmarks: means of `l` consecutive segments of length n//l
        ql = q.reshape(bs, h, l, n//l, dh).mean(dim=-2)
        kl = k.reshape(bs, h, l, n//l, dh).mean(dim=-2)

        f = F.softmax(einsum('...nd, ...ld -> ...nl', q*self.scale, kl), dim=-1)   # (bs, h, n, l)
        a = F.softmax(einsum('...id, ...jd -> ...ij', ql*self.scale, kl), dim=-1)  # (bs, h, l, l)
        b = F.softmax(einsum('...ld, ...nd -> ...ln', ql*self.scale, k), dim=-1)   # (bs, h, l, n)
        a_pinv = iter_pinv(a, self.pinv_n_iter)

        return einsum('...nl, ...ld -> ...nd',
                      einsum('...ni, ...ij -> ...nj', f, a_pinv),
                      einsum('...ln, ...nd -> ...ld', b, v))

In [None]:
bs = 4
sl = 128
d = 64
q = torch.randn(bs, sl, d)
k = torch.randn(bs, sl, d)
v = torch.randn(bs, sl, d)
l = 8 # number of landmark points
scale = d**-0.5
# landmarks: means over l consecutive segments of length sl//l
ql = torch.reshape(q, (bs, l, -1, d)).mean(dim=-2)
kl = torch.reshape(k, (bs, l, -1, d)).mean(dim=-2)

# the three softmax kernels of the Nyström approximation (single head, no head dim)
f = F.softmax(einsum('...nd, ...md -> ...nm', q*scale, kl), dim=-1)
b = F.softmax(einsum('...md, ...nd -> ...mn', ql*scale, k), dim=-1)
print(f.shape, b.shape, end=' ')
# softmax as in NystromAttention: `a` is the matrix that gets pseudo-inverted
a = F.softmax(einsum('...nd, ...md -> ...nm', ql*scale, kl), dim=-1)
print(a.shape)

torch.Size([4, 128, 8]) torch.Size([4, 8, 128]) torch.Size([4, 8, 8])


In [None]:
bs, sl, d = 4, 128, 64
q, k, v = (torch.randn(bs, sl, d) for _ in range(3))
# 128 tokens > 16 landmarks, so this exercises the landmark (approximate) path
attn_func = NystromAttention(d, 4, n_landmarks=16)
out = attn_func(q, k, v)
assert out.shape == (bs, sl, d)
out.shape

torch.Size([4, 128, 64])

In [None]:
#hide
# sl=128 exceeds the default n_landmarks=64, so this runs the landmark path, which does not use
# attn_mask; with an all-True mask this only checks that the argument is accepted.
attn_func = NystromAttention(d, 4)
mask = torch.ones(1, sl, sl, dtype=torch.bool)
out = attn_func(q, k, v, attn_mask=mask)
assert out.shape == (bs, sl, d)

In [None]:
#cuda
q, k, v = (torch.randn(bs, sl, d).cuda() for _ in range(3))
# with the default use_conv=False the module has no parameters, so it needs no .cuda()
attn_func = NystromAttention(d, 4)
out = attn_func(q, k, v)

## Attention container

In [None]:
#TODO
#skip
# Attention container with NystromAttention plugged in as the attention function.
# NOTE(review): assumes `Attention` (from attention.core) accepts an `attn_func` argument — confirm.
NystromerAttention = partial(Attention, attn_func=NystromAttention)

In [None]:
#skip
#hide
class _Attention(Module):
    """
    Attention module: input projection -> attention function -> output projection + dropout.

    `attn_func` selects the attention implementation (default `ScaledDotProdAttention`). It is
    constructed as `attn_func(d_model, n_heads, causal=..., dropout=..., shared_qk=...,
    store_attention=...)` and called as `attn_func(q, k, v, attn_mask)`.
    """
    def __init__(self, 
                 d_model:int, 
                 n_heads:int = 8, 
                 causal:bool = False,
                 mask:Tensor = None,
                 dropout:float=0.1,
                 out_dropout:float=None,
                 bias:bool=False,
                 shared_qk:bool=False,
                 store_attention:bool=False,
                 attn_func=None):
        store_attr('causal, mask, n_heads, bias, shared_qk')
        out_dropout = ifnone(out_dropout, dropout)
        attn_func = ifnone(attn_func, ScaledDotProdAttention)
        if shared_qk: self.in_proj = SharedQKAttnInProj(d_model, bias=bias)
        else: self.in_proj = AttnInProjV2(d_model, bias=bias)
        self.attn = attn_func(d_model, n_heads, causal=causal,
                              dropout=dropout, shared_qk=shared_qk, 
                              store_attention=store_attention)
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)
        self.dropout = nn.Dropout(out_dropout)
        self._init()

    def forward(self, x, context = None, mask = None, context_mask = None):
        "Self-attention over `x`, or cross-attention to `context` when given."
        q, k, v = self.in_proj(x, context)
        # shared q/k: keys are L2-normalized along the feature dim
        if self.shared_qk: k = F.normalize(k, 2, dim=-1).type_as(k)

        attn_mask = self._make_attn_mask(mask, context_mask, x, context)
        out = self.attn(q, k, v, attn_mask)

        out = self.out_proj(out)
        return self.dropout(out)

    def _init(self):
        "Xavier-uniform init for all >1-d parameters; zero all 1-d parameters (presumably biases) when `bias=True`."
        for w in self.parameters():
            if w.dim() > 1: nn.init.xavier_uniform_(w)
        if self.bias:
            for b in self.parameters():
                if b.dim() == 1: nn.init.constant_(b, 0)

    def _make_attn_mask(self, mask, context_mask, x, context):
        "Combines query and key padding masks into a `(b, 1, n_q, n_k)` boolean mask, or None if neither is given."
        if any(map(exists, (mask, context_mask))):
            b, n, _, device = *x.size(), x.device
            q_mask = default(mask, lambda: torch.ones((b, n), device = device).bool())
            k_mask = q_mask if not exists(context) else context_mask
            k_mask = default(k_mask, lambda: torch.ones((b, context.shape[-2]), device = device).bool())

            q_mask = rearrange(q_mask, 'b i -> b () i ()')
            k_mask = rearrange(k_mask, 'b j -> b () () j')
            return q_mask * k_mask
        else: return None #attn_mask is None if both mask and context_mask are None

In [None]:
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)  # shorter context; only used by the commented-out cross-attention check in the next cell
attn = NystromerAttention(d)
out = attn(x)
assert out.size() == (bs, sl, d)
out.shape

torch.Size([4, 128, 64])

In [None]:
# out = attn(x, context)
# assert (bs, sl, d) == out.size()
# out.shape

In [None]:
# e_msg = "Causal masking error"
# attn = Attention(d, causal=True, dropout=0)
# x1 = torch.randn(bs, sl, d)
# out1 = attn(x1)
# x2 = x1.clone()
# x2[:, sl//2:, :] = torch.randn(bs, sl//2, d)
# out2 = attn(x2)
# # all elements in first half are equal despite second half is defferent
# assert all_equal(out1[:, :sl//2], out2[:, :sl//2]), e_msg
# assert not (out1[:, sl//2:] == out2[:, sl//2:]).any(), e_msg

In [None]:
# e_msg = "Masking error"
# attn = Attention(d, causal=False, dropout=0)
# x1 = torch.randn(bs, sl, d)
# mask = torch.ones(bs, sl)
# # mask out second half of input
# mask[:, sl//2:] = 0
# mask = mask.bool()
# out1 = attn(x1, mask=mask)
# x2 = x1.clone()
# x2[:, sl//2:, :] = torch.randn(bs, sl//2, d)
# out2 = attn(x2, mask=mask)
# # all elements are equal, masked values do not effect result
# assert all_equal(out1[:, :sl//2], out2[:, :sl//2]), e_msg
# out1 = attn(x1)
# out2 = attn(x2)
# assert not (out1[:, :sl//2] == out2[:, :sl//2]).any()

In [None]:
# e_msg = "Context masking error"
# attn = Attention(d, causal=False, dropout=0)
# x = torch.randn(bs, sl, d)
# context = torch.randn(bs, sl, d)
# context_mask = torch.ones(bs, sl)
# # mask out second half of context
# context_mask[:, sl//2:] = 0
# context_mask = context_mask.bool()
# out1 = attn(x, context, context_mask=context_mask)
# context2 = context.clone()
# context2[:, sl//2:, :] = torch.randn(bs, sl//2, d)
# out2 = attn(x, context2, context_mask=context_mask)
# # all elements are equal, masked values do not effect result
# assert all_equal(out1, out2), e_msg
# # all output values are different for different context
# out1 = attn(x, context)
# out2 = attn(x, context2)
# assert not (out1 == out2).any()

In [None]:
# # check stored attention matrix
# torch.manual_seed(842)
# bs = 4
# sl = 16
# csl = sl + 16
# d = 64
# x = torch.rand(bs, sl, d)
# context = torch.rand(bs, csl, d)
# mask = torch.ones(bs, sl)
# mask[:, -5:] = 0
# context_mask = torch.ones(bs, csl)
# context_mask[:, -10:] = 0
# mask, context_mask = mask.bool(), context_mask.bool()
# attn = Attention(d, store_attention=True)
# out = attn(x, context, mask=mask, context_mask=context_mask)
# attention = attn.attn.attention
# assert (bs, sl, d) == out.size()
# assert attention.size() == (bs, attn.attn.n_heads, sl, csl)
# # zeros for masked keys and "don't cares" for masked queries
# plt.matshow(attention[0,0]);

In [None]:
# #hide
# #skip
# # check stored attention matrix
# torch.manual_seed(842)
# bs = 4
# sl = 16
# d = 64
# x = torch.rand(bs, sl, d)
# mask = torch.ones(bs, sl)
# mask[:, -5:] = 0
# mask = mask.bool()
# attn = Attention(d, store_attention=True, causal=True)
# out = attn(x, mask=mask)
# attention = attn.attn.attention
# assert (bs, sl, d) == out.size()
# assert attention.size() == (bs, attn.attn.n_heads, sl, sl)
# # zeros for masked keys and "don't cares" for masked queries
# plt.matshow(attention[0,0]);

In [None]:
#hide
from nbdev.export import notebook2script
notebook2script()  # write the `#export` cells of the notebooks to the library modules

Converted 00_core.ipynb.
Converted 01_layers.ipynb.
Converted 02_attention.core.ipynb.
Converted 02b_attention.nystrom.ipynb.
Converted 03_models.transformer.ipynb.
Converted 04a_models.reformer.ipynb.
Converted 04x_models.xtransformer.ipynb.
Converted 05_tokenizers.ipynb.
Converted 06_data.ipynb.
Converted 07_metrics.ipynb.
Converted 08_optimizers.ipynb.
Converted 09_tracking.ipynb.
Converted 10_config.ipynb.
Converted 40_experimental.ipynb.
Converted index.ipynb.
