In [None]:
# !! TO MODIFY default_exp xformer

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#hide
%load_ext autoreload
%autoreload 2

In [None]:
#export
from fastai.basics import *
from transformers_sandbox.core import *
from transformers_sandbox.layers import *
from transformers_sandbox.attention import *
from transformers_sandbox.transformer import LMMixin, EncDecMixin

# MODELNAME

> This is a template notebook intended to speed up adding new model architectures.

## Helpers

In [None]:
# !! example

#export
class ChunkedFeedForward(Module):
    """
    Position-wise feed-forward network that can process its input chunk by chunk.

    The network is Linear(d, d_ff) -> GELU -> Dropout -> Linear(d_ff, d) -> Dropout.
    When `n_chunks` > 1 the input is split into `n_chunks` pieces along `dim`, every
    piece is run through the network separately and the results are concatenated
    back along `dim`.

    Parameters:
        * d: int - size of the feature (last) dimension expected by the network
        * d_ff: int - hidden size of the network, if None defaults to 4*d
        * n_chunks: int (default: 1) - number of chunks, 1 disables chunking
        * dropout: float - dropout probability of both dropout layers
        * dim: int (default: -1) - dimension to chunk along

    NOTE: each chunk is fed to `nn.Linear(d, d_ff)`, so chunking along the feature
    dimension (the default `dim=-1`) is only usable with `n_chunks == 1`; pass a
    non-feature dimension (e.g. the sequence dimension) when chunking.
    """
    def __init__(self, d:int, d_ff:int=None, n_chunks:int=1, dropout:float=0., dim:int=-1):
        store_attr('n_chunks,dim')
        hidden_sz = default(d_ff, 4*d)
        stages = [
            nn.Linear(d, hidden_sz),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_sz, d),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x, **kwargs):
        # extra keyword arguments are accepted and ignored
        if self.n_chunks != 1:
            pieces = [self.net(piece) for piece in x.chunk(self.n_chunks, dim=self.dim)]
            return torch.cat(pieces, dim=self.dim)
        return self.net(x)

In [None]:
# !! don't forget the tests

# shape check: chunking along the sequence dimension must not change the output shape
bs, sl, d = 4, 64, 128
x = torch.randn(bs, sl, d)
ff = ChunkedFeedForward(d, n_chunks=8, dim=1)
out = ff(x)
assert out.shape == (bs, sl, d)

## Bricks

> Architecture-specific layers, blocks and containers. Consider moving general-purpose layers and attention modules to the `01_layers` and `02*_attention_func_name` notebooks respectively.

### Encoder

In [None]:
#export
class TransformerEncoderBlock(Module):
    """
    Basic transformer encoder block: a multi-head attention sublayer followed by a
    position-wise feed-forward sublayer.

    With `prenorm=True` every sublayer is wrapped as Residual(PreNorm(sublayer)),
    otherwise as PostNorm(Residual(sublayer)).

    Parameters:
        * d_model: int - model dimension
        * n_heads: int (default: 8) - passed to `Attention`
        * d_ff: int - passed to `FeedForward`
        * attn_dropout: float - dropout passed to `Attention`
        * ff_dropout: float - dropout passed to `FeedForward`
        * causal: bool - passed to `Attention`
        * attn_bias: bool - passed to `Attention` as `bias`
        * prenorm: bool - selects PreNorm or PostNorm wrapping
        * shared_qk: bool - passed to `Attention`
    """
    def __init__(self,
                 d_model:int, 
                 n_heads:int = 8, 
                 d_ff:int = None, 
                 attn_dropout:float = 0.1,
                 ff_dropout:float = 0.1,
                 causal:bool = False, 
                 attn_bias:bool = False, 
                 prenorm:bool=False,
                 shared_qk:bool=False):
        store_attr('attn_dropout') # mb separate argument attn_post_dropout
        # sublayers are built in the order attention -> feed-forward
        attention = Attention(d_model, n_heads=n_heads, causal=causal, dropout=attn_dropout,
                              bias=attn_bias, shared_qk=shared_qk)
        self.attn = Residual(PreNorm(d_model, attention)) if prenorm else PostNorm(d_model, Residual(attention))
        feed_forward = FeedForward(d_model, d_ff=d_ff, dropout=ff_dropout)
        self.ff = Residual(PreNorm(d_model, feed_forward)) if prenorm else PostNorm(d_model, Residual(feed_forward))

    def forward(self, x, mask=None): #? more args
        attended = self.attn(x, mask=mask)
        return self.ff(attended)

In [None]:
# shape check for the default (post-norm) encoder block
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)
m = TransformerEncoderBlock(d)
out = m(x)
assert out.shape == (bs, sl, d)
out.shape

torch.Size([4, 128, 64])

In [None]:
#hide
# same shape check with shared query/key projection enabled
m = TransformerEncoderBlock(d, shared_qk=True)
out = m(x)
assert out.shape == (bs, sl, d)

In [None]:
#export
class TransformerEncoder(Module):
    """
    Stack of `n_layers` TransformerEncoderBlocks, optionally followed by a final norm.

    `final_norm`, when given, is called as `final_norm(d_model)` to build the layer
    applied to the output of the last block (e.g. `nn.LayerNorm`). All other
    arguments are forwarded to every `TransformerEncoderBlock`.
    """
    def __init__(self, 
                 d_model, 
                 n_layers=6, 
                 n_heads=8, 
                 d_ff=None,
                 ff_dropout=0.1, 
                 attn_dropout=0.1,
                 attn_bias=False,
                 causal=False, 
                 prenorm=False,
                 shared_qk:bool=False,
                 final_norm=None):
        store_attr('d_model')
        blocks = [TransformerEncoderBlock(d_model, n_heads, causal=causal, d_ff=d_ff,
                                          attn_dropout=attn_dropout, ff_dropout=ff_dropout,
                                          prenorm=prenorm, attn_bias=attn_bias, shared_qk=shared_qk)
                  for _ in range(n_layers)]
        self.layers = nn.ModuleList(blocks)
        self.norm = final_norm(d_model) if final_norm is not None else None

    def forward(self, x, mask=None):
        # the same mask is passed to every block
        for block in self.layers:
            x = block(x, mask=mask)
        if self.norm is None:
            return x
        return self.norm(x)

In [None]:
# shape check for the default encoder stack
x = torch.randn(bs, sl, d)
m = TransformerEncoder(d)
out = m(x)
assert out.shape == (bs, sl, d)
out.shape

torch.Size([4, 128, 64])

### Decoder

In [None]:
#export
class TransformerDecoderBlock(Module):
    """
    Standard transformer decoder block: causal self-attention, then encoder-decoder
    (cross) attention over `context`, then a position-wise feed-forward layer.

    With `prenorm=True` every sublayer is wrapped as Residual(PreNorm(sublayer)),
    otherwise as PostNorm(Residual(sublayer)).

    NOTE: the `mask` constructor argument is accepted but not used by this block.
    """
    def __init__(self, 
                 d_model, 
                 n_heads = 8, 
                 d_ff = None,
                 attn_dropout = 0.1, 
                 ff_dropout=0.1,
                 mask = None ,
                 attn_bias=False,
                 prenorm=False):
        # mb separate argument attn_post_dropout
        # sublayers are built in the order self-attention -> cross-attention -> feed-forward
        self_attn = Attention(d_model, n_heads=n_heads, causal=True, dropout=attn_dropout, bias=attn_bias)
        self.attn = Residual(PreNorm(d_model, self_attn)) if prenorm else PostNorm(d_model, Residual(self_attn))
        cross_attn = Attention(d_model, n_heads=n_heads, causal=False, dropout=attn_dropout, bias=attn_bias)
        self.cross = Residual(PreNorm(d_model, cross_attn)) if prenorm else PostNorm(d_model, Residual(cross_attn))
        feed_forward = FeedForward(d_model, d_ff=d_ff, dropout=ff_dropout)
        self.ff = Residual(PreNorm(d_model, feed_forward)) if prenorm else PostNorm(d_model, Residual(feed_forward))

    def forward(self, x, context, mask=None, context_mask=None):
        attended = self.attn(x, mask=mask)
        crossed = self.cross(attended, context, mask=mask, context_mask=context_mask)
        return self.ff(crossed)

In [None]:
#export
class TransformerDecoderBlockV2(Module):
    """
    Transformer decoder block that uses a single `AdditiveAttention` layer in place of
    self-attention followed by cross-attention, then a position-wise feed-forward layer.

    With `prenorm=True` every sublayer is wrapped as Residual(PreNorm(sublayer)),
    otherwise as PostNorm(Residual(sublayer)).

    NOTE: the `mask` constructor argument is accepted but not used by this block.
    """
    def __init__(self,
                 d_model, 
                 n_heads = 8, 
                 mask = None, 
                 d_ff=None,
                 attn_dropout=0.1, 
                 ff_dropout=0.1, 
                 attn_bias=False,
                 prenorm=False):
        # sublayers are built in the order attention -> feed-forward
        attention = AdditiveAttention(d_model, n_heads=n_heads, causal=True, dropout=attn_dropout, bias=attn_bias)
        self.attn = Residual(PreNorm(d_model, attention)) if prenorm else PostNorm(d_model, Residual(attention))
        feed_forward = FeedForward(d_model, d_ff=d_ff, dropout=ff_dropout)
        self.ff = Residual(PreNorm(d_model, feed_forward)) if prenorm else PostNorm(d_model, Residual(feed_forward))

    def forward(self, x, context, mask=None, context_mask=None):
        attended = self.attn(x, context, mask=mask, context_mask=context_mask)
        return self.ff(attended)

In [None]:
#export   
class TransformerDecoder(Module):
    """
    Stack of `n_layers` decoder blocks, optionally followed by a final norm.

    `comb_attn=True` builds the stack from `TransformerDecoderBlockV2`, otherwise from
    `TransformerDecoderBlock`. `final_norm`, when given, is called as
    `final_norm(d_model)` to build the layer applied to the output of the last block.
    """
    def __init__(self, 
                 d_model, 
                 n_layers=6, 
                 n_heads=8, 
                 d_ff=None, 
                 attn_dropout=0.1, 
                 ff_dropout=0.1, 
                 prenorm=False, 
                 comb_attn=False, 
                 attn_bias=False, 
                 final_norm=None):
        store_attr('d_model')
        #TODO(Arto) refactor
        if comb_attn: block_cls = TransformerDecoderBlockV2
        else: block_cls = TransformerDecoderBlock
        blocks = [block_cls(d_model, n_heads, d_ff=d_ff, attn_dropout=attn_dropout,
                            ff_dropout=ff_dropout, prenorm=prenorm, attn_bias=attn_bias)
                  for _ in range(n_layers)]
        self.layers = nn.ModuleList(blocks)
        self.norm = final_norm(d_model) if final_norm is not None else None

    def forward(self, x, context, mask=None, context_mask=None):
        # every block receives the same context and masks
        for block in self.layers:
            x = block(x, context, mask, context_mask)
        if self.norm is None:
            return x
        return self.norm(x)

In [None]:
# shape check for the default decoder stack attending over a random context
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl, d)
m = TransformerDecoder(d)
out = m(x, context)
assert out.shape == (bs, sl, d)
out.shape

torch.Size([4, 128, 64])

## Models

### Language model

In [None]:
#export
class TransformerLM(Module, LMMixin):
    """
    Basic Transformer for language modelling: token/positional embedding, a stack of
    encoder blocks and a linear projection to vocabulary logits.
    
    Parameters:
        * vocab_sz: int
        * d_model: int - inner dimension of the model
        * n_layers: int (default: 6) 
        * n_heads: int (default: 8)
        * d_ff: int - inner dimension of the pointwise FeedForward net, passed to the encoder
        * attn_dropout: float - attention dropout
        * ff_dropout: float - feed-forward dropout
        * emb_dropout: float - embedding dropout
        * tie_weights: bool - if True the output projection shares its weight with the token embedding
        * causal: bool (default: True) - passed to the encoder's attention layers
        * pos_enc: str - type of positional encoding, passed to `TransformerEmbedding`
        * max_seq_len: int (default: 512)
        * axial_shape: tuple - [optional] passed to `TransformerEmbedding`
        * axial_emb_dims: tuple - [optional] passed to `TransformerEmbedding`
        * pad_idx: int - padding token id; stored on the model, `forward` of this class
                does not build a padding mask from it (presumably used by `LMMixin` - confirm)
        * prenorm: bool - whether to use PreNorm or PostNorm; with PreNorm a final LayerNorm is added
        * attn_bias: bool - whether to allow biases in attention projection layers
        * shared_qk: bool - passed to the encoder's attention layers
    Inputs:
        * x - input ids, shape [bs, sl]
        * mask - optional mask passed to the encoder
    Returns:
        * logits - target token logits, shape [bs, sl, vocab_sz]
    """
    def __init__(self, 
                 vocab_sz:int, 
                 d_model:int, 
                 n_layers:int=6,
                 n_heads:int=8,
                 d_ff:int=None,
                 attn_dropout:float=0.1,
                 ff_dropout:float=0.1,
                 emb_dropout:float=0.1,
                 tie_weights:bool=True,
                 causal:bool=True,
                 pos_enc:str='absolute',
                 max_seq_len:int=512,
                 axial_shape:tuple=None,
                 axial_emb_dims:tuple=None,
                 pad_idx:int=None,
                 prenorm:bool=False,
                 attn_bias:bool=False,
                 shared_qk:bool=False):
        store_attr()
        emb_kwargs = dict(dropout=emb_dropout, pos_enc=pos_enc,
                          axial_shape=axial_shape, axial_emb_dims=axial_emb_dims)
        self.emb = TransformerEmbedding(vocab_sz, d_model, max_seq_len, **emb_kwargs)
        # a pre-norm stack gets a LayerNorm after its last block
        if prenorm: final_norm = nn.LayerNorm
        else: final_norm = None
        enc_kwargs = dict(causal=causal, d_ff=d_ff, attn_dropout=attn_dropout,
                          ff_dropout=ff_dropout, prenorm=prenorm, attn_bias=attn_bias,
                          shared_qk=shared_qk, final_norm=final_norm)
        self.encoder = TransformerEncoder(d_model, n_layers, n_heads, **enc_kwargs)
        self.proj = nn.Linear(d_model, vocab_sz)
        if tie_weights:
            self.proj.weight = self.emb.emb.weight

    def forward(self, x, mask=None):
        hidden = self.encoder(self.emb(x), mask=mask)
        return self.proj(hidden)


In [None]:
# shape check for a small non-causal language model
bs, sl, d = 4, 128, 64
vocab_sz = 256
x = torch.randint(vocab_sz, (bs, sl))
model = TransformerLM(vocab_sz, d, n_layers=2, causal=False)
out = model(x)
assert out.shape == (bs, sl, vocab_sz)
out.shape

torch.Size([4, 128, 256])

In [None]:
#hide
# causal pre-norm variant must produce the same output shape
model = TransformerLM(vocab_sz, d, n_layers=2, causal=True, prenorm=True)
out = model(x)
assert out.shape == (bs, sl, vocab_sz)

In [None]:
#hide
# shared-QK variant: check output shape and that the shared projection is actually used
model = TransformerLM(vocab_sz, d, n_layers=2, causal=True, shared_qk=True)
out = model(x)
assert out.shape == (bs, sl, vocab_sz)
first_attn = model.encoder.layers[0].attn.sublayer.sublayer
assert isinstance(first_attn.in_proj, SharedQKAttnInProj)

In [None]:
#export
def transformer_lm_splits(model):
    """
    Splits TransformerLM `model` into groups for differential learning rates.

    Groups are, in order: the embedding, one group per encoder layer, and the output
    projection. Returns an `L` of parameter lists.

    A parameter shared between modules (e.g. the projection weight when the model is
    built with `tie_weights=True`) is kept only in the first group it appears in, so
    that no parameter is handed to the optimizer in more than one group.
    """
    modules = [model.emb, *model.encoder.layers, model.proj]
    seen = set()   # ids of parameters already assigned to a group
    groups = L()
    for m in modules:
        group = [p for p in params(m) if id(p) not in seen]
        seen.update(id(p) for p in group)
        groups.append(group)
    return groups

In [None]:
#hide
# embedding + 2 encoder layers + output projection
n_expected_groups = 1 + 2 + 1
assert len(transformer_lm_splits(model)) == n_expected_groups

### Encoder-Decoder model

In [None]:
#export
class Transformer(Module, EncDecMixin):
    """
    Basic Transformer Encoder-Decoder model
    Parameters:
        * enc_vocab_sz: int - source vocab size 
        * dec_vocab_sz: int - target vocab size
        * d_model: int - inner dimension of the model
        * n_enc_layers: int (default: 6) 
        * n_dec_layers: int (default: 6) 
        * n_heads: int (default: 8)
        * d_ff: int - inner dimension of the pointwise FeedForward net, passed to encoder and decoder
        * pad_idx: int - padding token id, if pad_idx is provided, and no mask/context_mask are 
                passed to forward method will be used to generate padding masks
        * tie_weights: bool - if True target embedding weights are used for computation output projection
        * shared_emb: bool - if True encoder and decoder will use shared embedding layer
                (requires enc_vocab_sz == dec_vocab_sz)
        * attn_dropout: float - attention dropout
        * ff_dropout: float - feed-forward dropout
        * emb_dropout: float - embedding dropout
        * prenorm: bool - whether to use PreNorm or PostNorm; with PreNorm a final LayerNorm
                is added to both encoder and decoder
        * attn_bias: bool - whether to allow biases in attention projection layers
        * comb_attn: bool - passed to `TransformerDecoder`, selects the decoder block type
        * pos_enc: str - type of positional encoding, passed to `TransformerEmbedding`
        * max_seq_len: int (default: 512)
        * axial_shape: tuple - [optional] passed to `TransformerEmbedding`
        * axial_emb_dims: tuple - [optional] passed to `TransformerEmbedding`
    Inputs:
        * src - source input ids, shape [bs, src_sl]
        * tgt - target input ids, shape [bs, tgt_sl]
        * src_mask - optional boolean source mask, shape [bs, src_sl]
        * tgt_mask - optional boolean target mask, shape [bs, tgt_sl]
    Returns:
        * logits - target token logits, shape [bs, tgt_sl, dec_vocab_sz]
    """
    def __init__(self, 
                 enc_vocab_sz, 
                 dec_vocab_sz, 
                 d_model, 
                 n_enc_layers=6, 
                 n_dec_layers=6, 
                 n_heads=8, 
                 d_ff=None,
                 pad_idx=None, 
                 tie_weights=True,
                 shared_emb = False,
                 attn_dropout=0.1, 
                 ff_dropout=0.1, 
                 emb_dropout=0.1,
                 prenorm=False, 
                 attn_bias=False,
                 comb_attn=False, 
                 pos_enc='absolute', 
                 max_seq_len=512, 
                 axial_shape=None, 
                 axial_emb_dims=None):
        store_attr()
        self.enc_emb = TransformerEmbedding(enc_vocab_sz, d_model, max_seq_len, dropout=emb_dropout, pos_enc=pos_enc,
                                            axial_shape=axial_shape, axial_emb_dims=axial_emb_dims)
        if shared_emb:
            assert (enc_vocab_sz == dec_vocab_sz), "Encoder and decoder vocab size doesn't match"
            # the very same embedding module serves both source and target
            self.dec_emb = self.enc_emb
        else:
            self.dec_emb = TransformerEmbedding(dec_vocab_sz, d_model, max_seq_len, dropout=emb_dropout, pos_enc=pos_enc,
                                                axial_shape=axial_shape, axial_emb_dims=axial_emb_dims)
        # a pre-norm stack gets a LayerNorm after its last block
        final_norm = nn.LayerNorm if prenorm else None
        self.encoder = TransformerEncoder(d_model, n_enc_layers, n_heads, d_ff=d_ff, attn_dropout=attn_dropout, ff_dropout=ff_dropout, 
                                          prenorm=prenorm, attn_bias=attn_bias, final_norm=final_norm, causal=False)
        self.decoder = TransformerDecoder(d_model, n_dec_layers, n_heads, d_ff=d_ff, attn_dropout=attn_dropout, ff_dropout=ff_dropout, 
                                          prenorm=prenorm, comb_attn=comb_attn, attn_bias=attn_bias, final_norm=final_norm)
        self.proj = nn.Linear(d_model, dec_vocab_sz)
        if tie_weights: self.proj.weight = self.dec_emb.emb.weight

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # padding masks are only derived from the ids when the caller did not pass a mask
        if src_mask is None: src_mask = self.get_padding_mask(src)
        if tgt_mask is None: tgt_mask = self.get_padding_mask(tgt)
        enc = self.encoder(self.enc_emb(src), mask=src_mask)
        out = self.decoder(self.dec_emb(tgt), context=enc, mask=tgt_mask, context_mask=src_mask)
        return self.proj(out)
    
    def get_padding_mask(self, x):
        "Returns a boolean mask that is True where `x` differs from `pad_idx`, or None if `pad_idx` is not set"
        if self.pad_idx is None: return None
        return (x != self.pad_idx)


In [None]:
# shape check for a small encoder-decoder model with different source/target lengths
bs, d = 4, 64
src_sl, tgt_sl = 70, 80
src_vocab_sz = 256
tgt_vocab_sz = 256
src = torch.randint(src_vocab_sz, (bs, src_sl))
tgt = torch.randint(tgt_vocab_sz, (bs, tgt_sl))
model = Transformer(src_vocab_sz, tgt_vocab_sz, d, n_enc_layers=2, n_dec_layers=2)
out = model(src, tgt)
assert out.shape == (bs, tgt_sl, tgt_vocab_sz)
out.shape

torch.Size([4, 80, 256])

In [None]:
#export
#TODO find out what is the best way to split encoder-decoder architecture
def transformer_splits(model):
    """
    [v0] Splits Transformer `model` into groups for differential learning rates.

    Groups are, in order: both embeddings together, one group per encoder layer, one
    group per decoder layer, and the output projection. Returns an `L` of parameter lists.

    A parameter shared between modules (e.g. the projection weight when the model is
    built with `tie_weights=True`) is kept only in the first group it appears in, so
    that no parameter is handed to the optimizer in more than one group.
    """
    modules = [nn.ModuleList([model.enc_emb, model.dec_emb]),
               *model.encoder.layers, *model.decoder.layers, model.proj]
    seen = set()   # ids of parameters already assigned to a group
    groups = L()
    for m in modules:
        group = [p for p in params(m) if id(p) not in seen]
        seen.update(id(p) for p in group)
        groups.append(group)
    return groups

In [None]:
#hide
# embeddings + 2 encoder layers + 2 decoder layers + output projection
n_expected_groups = 1 + 2 + 2 + 1
assert len(transformer_splits(model)) == n_expected_groups

In [None]:
#hide
# export the notebooks to library modules
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_layers.ipynb.
Converted 02_attention.ipynb.
Converted 03_transformer.ipynb.
Converted 04_reformer.ipynb.
Converted 05_tokenizers.ipynb.
Converted 06_data.ipynb.
Converted 07_metrics.ipynb.
Converted 08_optimizers.ipynb.
Converted 09_tracking.ipynb.
Converted index.ipynb.
