In [None]:
#default_exp layers

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
# Development convenience: enable IPython's autoreload extension in mode 2,
# so imported modules are reloaded before each cell execution.
%load_ext autoreload
%autoreload 2

In [None]:
#export
import torch
from torch import nn, einsum
import torch.nn.functional as F
from fastai.basics import *

from functools import partial, reduce
from inspect import isfunction
from operator import mul
from copy import deepcopy

from torch import Tensor
from typing import Tuple

from einops import rearrange, repeat
try:
    from axial_positional_embedding import AxialPositionalEmbedding, AxialPositionalEmbeddingImage
except ImportError as e:
    print(e)

# Layers 

## helper functions

In [None]:
#export

def exists(val):
    """Return True for every value except `None` (falsy values such as 0 or '' still count as existing)."""
    return not (val is None)

def default(val, d):
    """
    Return `val` unless it is None; otherwise fall back to `d`.
    When `d` is a plain Python function (as judged by `inspect.isfunction`) it is
    called with no arguments and its result is used, which allows lazy defaults.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val

def expand_dim1(x):
    """Prepend a leading axis to a 1-d input; inputs of any other rank are returned unchanged."""
    is_vector = len(x.shape) == 1
    return x[None, :] if is_vector else x

## Wrappers

based on https://github.com/lucidrains/all-normalization-transformer/blob/master/all_normalization_transformer/all_normalization_transformer.py

In [None]:
#export
class Residual(Module):
    """Skip-connection wrapper: returns `x + sublayer(x, *args, **kwargs)`"""
    def __init__(self, sublayer:Module):
        store_attr()

    def forward(self, x, *args, **kwargs):
        # any extra positional/keyword arguments are handed to the wrapped sublayer untouched
        branch = self.sublayer(x, *args, **kwargs)
        return x + branch

In [None]:
#export
class PostNorm(Module):
    """Runs `sublayer` first and applies LayerNorm (over `d_model` features) to its output"""
    def __init__(self, d_model:int, sublayer:Module):
        store_attr('sublayer')
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, *args, **kwargs):
        # extra arguments go to the sublayer only; the norm sees just the sublayer output
        return self.norm(self.sublayer(x, *args, **kwargs))

In [None]:
#export    
class PreNorm(Module):
    """Applies LayerNorm (over `d_model` features) to the input and then runs `sublayer` on the result"""
    def __init__(self, d_model:int, sublayer:Module):
        store_attr('sublayer')
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, *args, **kwargs):
        # only `x` is normalized; extra arguments are passed to the sublayer as-is
        normed = self.norm(x)
        return self.sublayer(normed, *args, **kwargs)

## Pointwise FeedForward

In [None]:
#export
class FeedForward(Module):
    """
    Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.
    The hidden width `d_ff` falls back to 4*d_model when it is None.
    """
    def __init__(self, d_model:int, d_ff:int=None, dropout:float=0.):
        d_ff = default(d_ff, 4 * d_model)
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self._init()

    def forward(self, x):
        return self.net(x)

    def _init(self):
        # Xavier-uniform for weight matrices only; 1-d parameters (biases) keep their default init
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

In [None]:
# Smoke test: a residual, pre-normalized FeedForward must return a tensor of the input's shape.
# bs / sl / d (batch size, sequence length, model width) are reused by later cells in this notebook.
bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
ff  = Residual(PreNorm(d, FeedForward(d)))
out = ff(x)
assert (bs, sl, d) == out.size()
out.shape

torch.Size([4, 128, 64])

## Attention

In [None]:
#export
# Value written into attention logits at masked positions before the softmax.
# It has to be a large *negative* number so that masked positions receive (numerically) zero
# attention weight; -5e4 still fits in the fp16 range (about +/-65504).
MASK_VAL = -5e4

In [None]:
#export
class AttnInProj(Module):
    """
    Input projections for attention: queries are computed from `x`, keys and values
    from `context` (which falls back to `x` when no context is given).
    """
    def __init__(self, d_model:int, bias:bool=False):
        self.to_q = nn.Linear(d_model, d_model, bias=bias)
        self.to_k = nn.Linear(d_model, d_model, bias=bias)
        self.to_v = nn.Linear(d_model, d_model, bias=bias)

    def forward(self, x, context=None):
        kv_source = ifnone(context, x)
        queries = self.to_q(x)
        keys = self.to_k(kv_source)
        values = self.to_v(kv_source)
        return queries, keys, values

In [None]:
# Without a context, q, k and v are all projected from x and therefore share x's shape.
# `context` (16 positions shorter than x) is prepared here for the next cell.
bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-16, d)
proj = AttnInProj(d)
q1, k1, v1 = proj(x)
assert (bs, sl, d) == q1.size() == k1.size() == v1.size()
q1.shape, k1.shape, v1.shape

(torch.Size([4, 128, 64]), torch.Size([4, 128, 64]), torch.Size([4, 128, 64]))

In [None]:
# With a context, q is still projected from x (so it must equal q1 from the previous cell),
# while k and v are projected from the context and take its shape.
q2, k2, v2 = proj(x, context)
assert (bs, sl, d) == q2.size()
assert k2.size() == v2.size() == context.size()
assert all_equal(q1, q2)
assert not all_equal(k1, k2)
q2.shape, k2.shape, v2.shape

(torch.Size([4, 128, 64]), torch.Size([4, 112, 64]), torch.Size([4, 112, 64]))

In [None]:
#export
#TODO make sure store_attention works
class ScaledDotProdAttention(Module):
    """
    Multi-head scaled dot-product attention over already projected `q`, `k`, `v`.

    Inputs are laid out as (batch, positions, n_heads * head_dim) and split into heads internally;
    the output has the same layout as `q`.

    Arguments to `forward`:
    - `mask`: optional (batch, n_queries) boolean mask, False at query positions to ignore
    - `context_mask`: optional (batch, n_keys) boolean mask, False at key positions to ignore.
      When only `mask` is given, it is reused for the keys if keys and queries have the same
      length (self-attention); otherwise all keys are treated as valid.

    If `causal` is set, every query is additionally prevented from attending to later keys.
    If `store_attention` is set, the attention weights of the last call are kept (detached, on cpu)
    in `self.attention`.
    """
    def __init__(self, d_model, n_heads, causal=False, dropout=0., store_attention:bool=False):
        store_attr()
        self.scale = (d_model//n_heads)**-0.5
        # replaces the float stored by store_attr with the actual dropout layer
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None, context_mask=None):
        device = q.device
        # sizes have to be read before the heads are split off
        b, n = q.shape[:2]
        n_keys = k.shape[1]
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.n_heads), (q, k, v))

        # boolean input_mask is False at positions not to attend to
        input_mask = None
        if exists(mask) or exists(context_mask):
            if exists(mask): q_mask = mask.bool()
            else: q_mask = torch.ones((b, n), dtype=torch.bool, device=device)

            if exists(context_mask): k_mask = context_mask.bool()
            elif n_keys == n: k_mask = q_mask      # self-attention: same padding on both sides
            else: k_mask = torch.ones((b, n_keys), dtype=torch.bool, device=device)

            q_mask = rearrange(q_mask, 'b i -> b () i ()')
            k_mask = rearrange(k_mask, 'b j -> b () () j')
            # broadcasts to (b, 1, i, j): a pair is kept only if both query and key are valid
            input_mask = q_mask & k_mask

        # classic dot-product attention
        dots = torch.einsum('bhid,bhjd->bhij', q*self.scale, k)

        if exists(input_mask):
            dots.masked_fill_(~input_mask, MASK_VAL)
            del input_mask

        if self.causal:
            i, j = dots.shape[-2:]
            # the diagonal offset aligns the last query with the last key when i != j
            causal_mask = torch.ones((i, j), device = device).triu_(j - i + 1).bool()
            dots.masked_fill_(causal_mask, MASK_VAL)
            del causal_mask

        attn = F.softmax(dots, -1)
        if self.store_attention: self.attention = attn.detach().cpu()

        attn = self.dropout(attn)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        # merge the heads back into the feature dimension
        out = rearrange(out, 'b h n d -> b n (h d)')
        return out

In [None]:
# Smoke test with random q, k, v and 4 heads; bs, sl and d come from an earlier cell.
# No masks are passed here.
q = torch.randn(bs, sl, d)
k = torch.randn(bs, sl, d)
v = torch.randn(bs, sl, d)
attn_func = ScaledDotProdAttention(d, 4)
out = attn_func(q, k, v)
out.shape

torch.Size([4, 128, 64])

In [None]:
#export
class Attention(Module):
    """
    Standard attention module using scaled dot-product attention.

    Projects `x` (and the optional `context`) to q, k, v, runs multi-head attention and applies
    an output projection followed by dropout. `out_dropout` falls back to `dropout` when None.

    Arguments to `forward`:
    - `x`: input used for the queries (and for keys/values when `context` is None)
    - `context`: optional separate input for keys/values (cross-attention)
    - `mask`: optional boolean mask for the positions of `x`
    - `context_mask`: optional boolean mask for the key/value positions
    """
    def __init__(self, 
                 d_model:int, 
                 n_heads:int = 8, 
                 causal:bool = False,
                 mask:Tensor = None,
                 dropout:float=0.1,
                 out_dropout:float=None,
                 bias:bool=True,
                 store_attention:bool=False):
        store_attr('causal, mask, n_heads, bias')
        out_dropout = ifnone(out_dropout, dropout)
        self.in_proj = AttnInProj(d_model, bias=bias)
        self.attn = ScaledDotProdAttention(d_model, n_heads, causal=causal,
                                           dropout=dropout, store_attention=store_attention)
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)
        self.dropout = nn.Dropout(out_dropout)
        self._init()

    def forward(self, x, context = None, mask = None, context_mask = None):
        q, k, v = self.in_proj(x, context)

        # Only this module knows whether keys come from `x` or from a separate context, so the
        # key-side mask is resolved here instead of leaving it ambiguous for the attention function.
        if context is None:
            # self-attention: keys are the positions of `x`, so they share the query mask
            if context_mask is None: context_mask = mask
        elif context_mask is None and mask is not None:
            # cross-attention with only a query mask: every context position is valid
            context_mask = torch.ones(context.shape[:2], dtype=torch.bool, device=context.device)

        out = self.attn(q, k, v, mask, context_mask)

        out = self.out_proj(out)
        return self.dropout(out)

    def _init(self):
        # Xavier-uniform for weight matrices; 1-d parameters are zeroed when biases are enabled
        for p in self.parameters():
            if p.dim() > 1: nn.init.xavier_uniform_(p)
        if self.bias:
            for p in self.parameters():
                if p.dim() == 1: nn.init.constant_(p, 0)

In [None]:
# Smoke test: residual pre-norm self-attention (default 8 heads); bs, sl and d come from an earlier cell.
x = torch.randn(bs, sl, d)
attn  = Residual(PreNorm(d, Attention(d)))
out = attn(x)
out.shape

torch.Size([4, 128, 64])

In [None]:
# TODO test with context, masks
# Cross-attention smoke test: the context is 20 positions shorter than x and is passed
# positionally through Residual and PreNorm to Attention. No masks are exercised yet.
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl-20, d)
attn  = Residual(PreNorm(d, Attention(d)))

out = attn(x, context)
out.shape

torch.Size([4, 128, 64])

In [None]:
#export
# decoder attention class combining self and cross attention 
# may be replaced with generalized attention in future
class AdditiveAttention(nn.Module):
    """
    Attention layer in which every query from `x` attends jointly over the positions of `x`
    and the positions of `context` (keys/values are computed from their concatenation).

    Arguments to `forward`:
    - `x`: input of shape (batch, n, d_model)
    - `context`: optional tensor concatenated to `x` along the sequence axis for keys/values
    - `mask`: optional (batch, n) boolean mask, False at positions of `x` to ignore
    - `context_mask`: optional boolean mask for the context positions
    - `store_attention`: keep the attention weights of this call in `self.attention`
      (also enabled when `self.store_attention` is set)

    If `causal` is set, a query may not attend to later positions of `x`; context positions are
    not affected by the causal mask.
    """
    def __init__(self, 
                 d_model, 
                 heads = 8, 
                 causal = False,
                 mask = None,
                 dropout=0.1, 
                 bias=True):
        super().__init__()
        self.causal = causal
        self.store_attention = False
        self.mask = mask #??
        self.heads = heads
        self.scale = (d_model//heads) ** -0.5
        
        self.to_q = nn.Linear(d_model, d_model, bias = bias)
        # keys and values come from one fused projection and are split in forward
        self.to_kv = nn.Linear(d_model, d_model * 2, bias = bias)
        self.dropout = nn.Dropout(dropout)

        self.to_out = nn.Linear(d_model, d_model)

        self._init()

    def forward(self, x, context = None, mask = None, context_mask = None, store_attention=False):
        b, n, d, h, device = *x.shape, self.heads, x.device
        # an empty context turns the layer into plain self-attention
        if not exists(context):
            context = torch.empty(b, 0, d, dtype=x.dtype, device=device)
        kv_input = torch.cat([x, context], dim=-2)
        
        q = self.to_q(x)
        kv = self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, *kv))

        # boolean input_mask is False at positions not to attend to
        input_mask = None
        if exists(mask) or exists(context_mask):
            if exists(mask): q_mask = mask.bool()
            else: q_mask = torch.ones((b, n), dtype=torch.bool, device=device)
            query_side = q_mask[:, None, :, None]
            self_mask = query_side & q_mask[:, None, None, :]
            if context.size(-2) != 0:
                if exists(context_mask): k_mask = context_mask.bool()
                else: k_mask = torch.ones((b, context.shape[-2]), dtype=torch.bool, device=device)
                cross_mask = query_side & k_mask[:, None, None, :]
            else:
                # zero-width block with matching leading dims, so the concatenation is well-formed
                cross_mask = self_mask.new_empty((b, 1, n, 0))
            input_mask = torch.cat([self_mask, cross_mask], dim=-1)
        
        # classic scaled dot-product attention
        dots = torch.einsum('bhid,bhjd->bhij', q * self.scale, k)
        
        # might need to tune MASK_VAL for fp16 to work
        if exists(input_mask):
            dots.masked_fill_(~input_mask, MASK_VAL)
            del input_mask

        if self.causal:
            # strictly-upper-triangular entries of the leading (n, n) self-attention part
            i, j = torch.triu_indices(n, n, 1)
            dots[:,:,i,j] = MASK_VAL

        attn = F.softmax(dots, -1)
        if self.store_attention or store_attention: # and not self.training
            self.attention = attn.detach().cpu()
        attn = self.dropout(attn)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        return out

    def _init(self):
        # Xavier-uniform weights, zero biases (input projections only have a bias when bias=True)
        for w in (self.to_q.weight, self.to_kv.weight, self.to_out.weight):
            nn.init.xavier_uniform_(w)
        if getattr(self.to_q, 'bias', None) is not None: nn.init.constant_(self.to_q.bias, 0)
        if getattr(self.to_kv, 'bias', None) is not None: nn.init.constant_(self.to_kv.bias, 0)
        nn.init.constant_(self.to_out.bias, 0)

## Transformer blocks

### Encoder

In [None]:
#export
class TransformerEncoderBlock(Module):
    """
    Basic transformer encoder block: multi-head attention followed by a position-wise
    feed-forward layer, with dropout applied in between.
    With `prenorm=True` each sublayer is wrapped as Residual(PreNorm(...)),
    otherwise as PostNorm(Residual(...)).
    The `mask` constructor argument is accepted but not used by this block.
    """
    def __init__(self, 
                 d_model, 
                 n_heads = 8, 
                 d_ff = None, 
                 attn_dropout = 0.1,
                 ff_dropout = 0.1,
                 causal = False, 
                 mask = None, 
                 attn_bias = True, 
                 prenorm=False):
        store_attr('attn_dropout') # mb separate argument attn_post_dropout
        # build the bare sublayers first, then wrap them according to the norm placement
        attn_layer = Attention(d_model, n_heads=n_heads, causal=causal, dropout=attn_dropout, bias=attn_bias)
        ff_layer = FeedForward(d_model, d_ff=d_ff, dropout=ff_dropout)
        if prenorm:
            self.attn = Residual(PreNorm(d_model, attn_layer))
            self.ff = Residual(PreNorm(d_model, ff_layer))
        else:
            self.attn = PostNorm(d_model, Residual(attn_layer))
            self.ff = PostNorm(d_model, Residual(ff_layer))
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, x, mask=None): #? more args
        attended = self.dropout(self.attn(x, mask=mask))
        return self.ff(attended)

In [None]:
# Smoke test: one encoder block with default settings; bs, sl and d come from an earlier cell.
x = torch.randn(bs, sl, d)
m = TransformerEncoderBlock(d)
out = m(x)
out.shape

torch.Size([4, 128, 64])

In [None]:
#export
class TransformerEncoder(Module):
    """
    Stack of `n_layers` TransformerEncoderBlock layers applied one after another.
    `final_norm`, when given, is called with `d_model` to build a layer applied to the output
    of the last block.
    """
    def __init__(self, 
                 d_model, 
                 n_layers=6, 
                 n_heads=8, 
                 d_ff=None,
                 ff_dropout=0.1, 
                 attn_dropout=0.1,
                 attn_bias=True,
                 causal=False, 
                 prenorm=False, 
                 final_norm=None):
        store_attr('d_model')
        blocks = [TransformerEncoderBlock(d_model, n_heads, causal=causal,
                                          d_ff=d_ff, attn_dropout=attn_dropout, ff_dropout=ff_dropout,
                                          prenorm=prenorm, attn_bias=attn_bias)
                  for _ in range(n_layers)]
        self.layers = nn.ModuleList(blocks)
        self.norm = final_norm(d_model) if final_norm is not None else None

    def forward(self, x, mask=None):
        # the same mask is handed to every block
        for block in self.layers:
            x = block(x, mask=mask)
        if self.norm is None:
            return x
        return self.norm(x)

In [None]:
# Smoke test: encoder with default settings (6 blocks, no final norm); bs, sl and d come from an earlier cell.
x = torch.randn(bs, sl, d)
m = TransformerEncoder(d)
out = m(x)
out.shape

torch.Size([4, 128, 64])

### Decoder

In [None]:
#export
class TransformerDecoderBlock(Module):
    """
    Transformer decoder block: causal self-attention, cross-attention over `context` and a
    position-wise feed-forward layer, with dropout after each attention step.
    With `prenorm=True` each sublayer is wrapped as Residual(PreNorm(...)),
    otherwise as PostNorm(Residual(...)).
    The `mask` constructor argument is accepted but not used by this block.
    """
    def __init__(self, 
                 d_model, 
                 n_heads = 8, 
                 d_ff = None,
                 attn_dropout = 0.1, 
                 ff_dropout=0.1,
                 mask = None ,
                 attn_bias = True,
                 prenorm=False):
        store_attr('attn_dropout')     # mb separate argument attn_post_dropout
        # bare sublayers are created in the order self-attention, cross-attention, feed-forward
        self_attn = Attention(d_model, n_heads=n_heads, causal=True, dropout=attn_dropout, bias=attn_bias)
        cross_attn = Attention(d_model, n_heads=n_heads, causal=False, dropout=attn_dropout, bias=attn_bias)
        ff_layer = FeedForward(d_model, d_ff=d_ff, dropout=ff_dropout)
        if prenorm:
            self.attn = Residual(PreNorm(d_model, self_attn))
            self.cross = Residual(PreNorm(d_model, cross_attn))
            self.ff = Residual(PreNorm(d_model, ff_layer))
        else:
            self.attn = PostNorm(d_model, Residual(self_attn))
            self.cross = PostNorm(d_model, Residual(cross_attn))
            self.ff = PostNorm(d_model, Residual(ff_layer))
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, x, context, mask=None, context_mask=None):
        hidden = self.dropout(self.attn(x, mask=mask))
        hidden = self.dropout(self.cross(hidden, context, mask=mask, context_mask=context_mask))
        return self.ff(hidden)

In [None]:
#export
class TransformerDecoderBlockV2(nn.Module):
    """
    Decoder block that uses a single AdditiveAttention layer (causal over `x`, attending to
    `context` as well) followed by a position-wise feed-forward layer.
    With `prenorm=True` each sublayer is wrapped as Residual(PreNorm(...)),
    otherwise as PostNorm(Residual(...)).
    The `mask` constructor argument is accepted but not used by this block.
    """
    def __init__(self, d_model, n_heads = 8, mask = None, d_ff=None,
                 attn_dropout=0.1, ff_dropout=0.1, attn_bias=True,
                 prenorm=False):
        super().__init__()
        self.attn_dropout = attn_dropout # mb separate argument attn_post_dropout
        # AdditiveAttention names its head-count parameter `heads`, not `n_heads`
        if prenorm:
            self.attn = Residual(PreNorm(d_model, AdditiveAttention(d_model, heads=n_heads, causal=True, dropout=attn_dropout, bias=attn_bias)))
            self.ff = Residual(PreNorm(d_model, FeedForward(d_model, d_ff=d_ff, dropout=ff_dropout)))
        else:
            self.attn = PostNorm(d_model, Residual(AdditiveAttention(d_model, heads=n_heads, causal=True, dropout=attn_dropout, bias=attn_bias)))
            self.ff = PostNorm(d_model, Residual(FeedForward(d_model, d_ff=d_ff, dropout=ff_dropout)))
        
    def forward(self, x, context, mask=None, context_mask=None):
        out = self.attn(x, context, mask=mask, context_mask=context_mask)
        # F.dropout is active by default; tie it to the module's train/eval state
        out = F.dropout(out, p=self.attn_dropout, training=self.training)
        out = self.ff(out)
        return out

In [None]:
#export   
class TransformerDecoder(Module):
    """
    Stack of `n_layers` decoder blocks applied one after another.
    `comb_attn=True` selects TransformerDecoderBlockV2, otherwise TransformerDecoderBlock is used.
    `final_norm`, when given, is called with `d_model` to build a layer applied to the output
    of the last block.
    """
    def __init__(self, 
                 d_model, 
                 n_layers=6, 
                 n_heads=8, 
                 d_ff=None, 
                 attn_dropout=0.1, 
                 ff_dropout=0.1, 
                 prenorm=False, 
                 comb_attn=False, 
                 attn_bias=True, 
                 final_norm=None):
        store_attr('d_model')
        #TODO(Arto) refactor
        if comb_attn: block = TransformerDecoderBlockV2
        else: block = TransformerDecoderBlock
        blocks = [block(d_model, n_heads, d_ff=d_ff, attn_dropout=attn_dropout,
                        ff_dropout=ff_dropout, prenorm=prenorm, attn_bias=attn_bias)
                  for _ in range(n_layers)]
        self.layers = nn.ModuleList(blocks)
        self.norm = final_norm(d_model) if final_norm is not None else None

    def forward(self, x, context, mask=None, context_mask=None):
        # context and both masks are passed unchanged to every block
        for blk in self.layers:
            x = blk(x, context, mask, context_mask)
        if self.norm is None:
            return x
        return self.norm(x)

## Embedding

In [None]:
#export
class AbsolutePositionalEmbedding(Module):
    """Learnable absolute positional encodings: one `d_emb`-sized vector per position, up to `max_seq_len` positions"""
    def __init__(self, d_emb:int, max_seq_len:int):
        self.emb = nn.Embedding(max_seq_len, d_emb)

    def forward(self, x):
        # only the length of axis 1 and the device of `x` are used
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device=x.device)
        return self.emb(positions)

In [None]:
#export
class FixedPositionalEmbedding(Module):
    """Fixed (non-learnable) sinusoidal positional encodings"""
    def __init__(self, d_emb:int):
        # one inverse frequency per pair of embedding channels
        exponents = torch.arange(0, d_emb, 2).float() / d_emb
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, x):
        # only the length of axis 1 and the device of `x` are used
        positions = torch.arange(x.shape[1], device=x.device).type_as(self.inv_freq)
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        # sines fill the first half of the channels, cosines the second half
        table = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return table[None, :, :]

In [None]:
#export
class TransformerEmbedding(Module):
    """
    Combines token embeddings with positional encodings.

    Token embeddings are scaled by sqrt(d_emb), the positional encoding is added and dropout
    is applied to the sum.

    Arguments:
    - emb_sz: number of rows of the token embedding table
    - d_emb: embedding dimension
    - max_seq_len: number of positions for 'absolute' / 'axial' encodings
    - dropout: dropout probability applied to the output
    - pos_enc: str from {'absolute', 'fixed', 'axial'}; any other value raises ValueError
    - axial_shape: required for 'axial'; the product of its entries must equal max_seq_len
    - axial_emb_dims: optional per-axis embedding dims for 'axial'
    """
    def __init__(self, 
                 emb_sz:int, 
                 d_emb:int, 
                 max_seq_len:int=512, 
                 dropout:float=0., 
                 pos_enc:str='absolute', 
                 axial_shape:Tuple=None, 
                 axial_emb_dims:Tuple=None):
        store_attr('d_emb')
        self.scale = d_emb ** 0.5
        self.std = 0.02    # fairseq: d_emb ** -0.5, fastai: 0.01
        self.emb = nn.Embedding(emb_sz, d_emb)
        self.dropout = nn.Dropout(dropout)
        
        if pos_enc == 'absolute': self.pos_enc = AbsolutePositionalEmbedding(d_emb, max_seq_len)
        elif pos_enc == 'fixed': self.pos_enc = FixedPositionalEmbedding(d_emb)
        elif pos_enc == 'axial':
            assert axial_shape is not None
            assert reduce(mul, axial_shape) == max_seq_len
            # compute the fallback only when it is needed, so explicitly passed dims never touch it
            if axial_emb_dims is None:
                # NOTE(review): get_axial_dims is not defined in this notebook - confirm where it comes from
                axial_emb_dims = get_axial_dims(d_emb, len(axial_shape))
            self.pos_enc = AxialPositionalEmbedding(d_emb, axial_shape, axial_emb_dims)
        else:
            # fail early with a clear message instead of an AttributeError on the missing `pos_enc`
            raise ValueError(f"pos_enc must be one of 'absolute', 'fixed', 'axial', got {pos_enc!r}")
        self._init()
        
    def forward(self, x):
        # out-of-place ops: same values as in-place scaling/adding, without mutating the embedding output
        x = self.emb(x) * self.scale
        x = x + self.pos_enc(x)
        return self.dropout(x)
    
    def _init(self):
        nn.init.trunc_normal_(self.emb.weight, std = self.std)
        # learnable positional table: either directly on the module or on its `emb` sub-layer
        # (the layout used by AbsolutePositionalEmbedding)
        pos_weight = getattr(self.pos_enc, 'weight', None)
        if pos_weight is None:
            pos_weight = getattr(getattr(self.pos_enc, 'emb', None), 'weight', None)
        if pos_weight is not None: nn.init.trunc_normal_(pos_weight, std = self.std)

In [None]:
#hide
# Run nbdev's notebook-to-script conversion; the recorded output lists the notebooks that were converted.
from nbdev.export import notebook2script; notebook2script()

Converted 00_core.ipynb.
Converted 01_layers.ipynb.
Converted 02_models.ipynb.
Converted index.ipynb.
