In [None]:
import copy
import math
import numbers
import time

import altair as alt
import pandas as pd
import torch
import torch.nn as nn

# Encoder-Decoder Architecture
A standard encoder decoder architecture as set out in the paper "Attention is all you need" at https://arxiv.org/pdf/1706.03762.pdf
Note: At each step the model is auto-regressive (consuming prev generated symbol as new input)
Args: 
- encoder: (nn.Module)
    - neural net that takes in a sequence of symbol representations `(x1, x2, .... xn)` and outputs a continuous representation `z = (z1, z2, .... zn)`
    - the token embeddings form the input symbol sequence; the continuous representation `z` is one context-aware vector per input position (a sequence of `n` vectors, not a single vector for the whole sentence)
    - takes input the `source embedding` and the `mask` as the embeddings are padded to a constant size
- decoder: (nn.Module)
    - takes the continuous representation `z` and generates an output sequence `(y1, y2, ... ym)`
- generator: (nn.Module)
    - takes the vectors output by the decoder, projects them onto the vocabulary and applies a log-softmax, giving the distribution from which each output text token is picked
    
<div>
<img src="img/enc-dec.svg" width="200"/>
</div>


In [None]:
# Hyperparameters
# NOTE(review): neither constant is referenced by the cells visible in this notebook yet.
D_MODEL = 16       # model width (d_model in the paper): dimensionality of the vectors passed between sub-layers
N = 6              # number of identical layers in the encoder stack (the text uses the same N for the decoder stack)

In [None]:
class EncoderDecoder(nn.Module):
    """
    Black-box encoder-decoder architecture as set out in the transformers paper
    (translation use case).

    The class only wires the injected sub-modules together:

    - encoder: called as ``encoder(embedded_source, src_mask)``
    - decoder: called as ``decoder(embedded_target, memory, src_mask, tgt_mask)``
    - src_embeddings / tgt_embeddings: map source / target tokens to vectors
    - generator: stored for the caller; ``forward`` does not apply it
    """

    def __init__(self, encoder, decoder, src_embeddings, tgt_embeddings, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator
        self.src_embeddings = src_embeddings        # embeddings table for input tokens
        self.tgt_embeddings = tgt_embeddings        # embeddings table for target tokens

    def forward(self, x, tgt=None, src_mask=None, tgt_mask=None):
        """
        Encode the source sequence ``x`` and decode ``tgt`` against it.

        Args:
            x: source tokens, fed through ``src_embeddings``.
            tgt: target tokens, fed through ``tgt_embeddings``. Required; it only has a
                default so that the positional signature ``forward(x)`` stays valid.
            src_mask: optional mask handed to the encoder and to the decoder.
            tgt_mask: optional mask handed to the decoder.

        Returns:
            Whatever the decoder returns (the generator is not applied here).

        Raises:
            ValueError: if ``tgt`` is not supplied.
        """
        if tgt is None:
            raise ValueError("EncoderDecoder.forward needs the target sequence `tgt` to decode")
        memory = self.encode(x, src_mask)
        return self.decode(memory, tgt, src_mask, tgt_mask)

    def encode(self, x, src_mask=None):
        """Embed the source tokens ``x`` and run them through the encoder."""
        return self.encoder(self.src_embeddings(x), src_mask)

    def decode(self, memory, tgt, src_mask=None, tgt_mask=None):
        """Embed the target tokens ``tgt`` and decode them against the encoder output ``memory``."""
        return self.decoder(self.tgt_embeddings(tgt), memory, src_mask, tgt_mask)

class Generator(nn.Module):
    """
    Projects model-sized vectors onto the vocabulary and applies a log-softmax,
    so the output holds log-probabilities over the tokens.

    Args:
        model_dims: size of the last dimension of the incoming vectors.
        vocab_size: number of tokens to score.
    """

    #! see if you can auto-initialize this with hyperparams
    def __init__(self, model_dims, vocab_size):
        super().__init__()

        self.linear = nn.Linear(model_dims, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)    # normalises over the vocabulary axis

    def forward(self, x):
        """Return log-probabilities over the vocabulary; the last axis becomes ``vocab_size``."""
        return self.softmax(self.linear(x))
        

<div>
<img src="img/encoder_1.1.svg" width="300"/>
</div>

In [None]:
def clone(module_to_clone, num_of_clones):
    """
    Build a ModuleList holding ``num_of_clones`` independent deep copies of a module.

    Args:
        module_to_clone: the module to copy; the original is not placed in the list.
        num_of_clones: how many copies to make; must be an integer.

    Returns:
        nn.ModuleList of deep copies. The copies start with identical values but do not
        share parameters with each other or with the original.

    Raises:
        AssertionError: if ``num_of_clones`` is not an integer.
    """
    assert isinstance(num_of_clones, numbers.Integral)
    return nn.ModuleList([copy.deepcopy(module_to_clone) for _ in range(num_of_clones)])

class Encoder(nn.Module):
    """
    Core encoder: a stack of N copies of one encoder layer applied in sequence.

    Args:
        encoder_layer: the layer to replicate.
        N: number of copies in the stack.
    """

    def __init__(self, encoder_layer, N):
        super().__init__()
        self.encoder_stack = clone(encoder_layer, N)

    def forward(self, x, mask=None):
        """
        Feed ``x`` through every layer of the stack in order.

        Args:
            x: input to the first layer.
            mask: optional mask forwarded to each layer as a second argument. When it is
                None the layers are called with ``x`` only, so layers that take no mask
                still work.

        Returns:
            Output of the last layer.
        """
        for layer in self.encoder_stack:
            x = layer(x) if mask is None else layer(x, mask)
        return x

We employ a residual connection (cite) around each of the two sub-layers, followed by layer normalization (cite).


<div>
    <img src="img/enc-sublayers.svg" width='400'/>
</div>

`LayerNorm` and `BatchNorm` are similar, but they differ in where they apply their normalization. In BatchNorm, the `mean` and the `standard deviation` are calculated across the incoming batch, whereas in LayerNorm the same statistics are calculated across the dimensions of each individual input to the layer. 
- Hence BatchNorm is across the batch 
- LayerNorm is across each input. 

This is a good article explaining the differences:
https://www.pinecone.io/learn/batch-layer-normalization/

Forcing every input to zero mean and unit variance can remove variation that the network actually needs. To address this, batch normalization introduces two parameters: a scaling factor gamma (γ) and an offset beta (β). These are learnable parameters, so if a different scale or shift of the normalized values helps the neural network learn a certain class better, then the network learns suitable values of gamma and beta during training. Layer normalization uses the same two parameters (equation 4 below).

$$
\begin{aligned}
\mu_l &= \frac{1}{d}\sum_{i=1}^{d} x_i && (1)\\
\sigma_l^2 &= \frac{1}{d}\sum_{i=1}^{d}(x_i - \mu_l)^2 && (2)\\
\hat{x}_i &= \frac{x_i - \mu_l}{\sqrt{\sigma_l^2}} && (3)\\
\text{or}\quad \hat{x}_i &= \frac{x_i - \mu_l}{\sqrt{\sigma_l^2 + \epsilon}} && (3')\\
&\text{Adding } \epsilon \text{ helps when } \sigma_l^2 \text{ is small}\\
y_i &= \mathcal{LN}(x_i) = \gamma \hat{x}_i + \beta && (4)
\end{aligned}
$$

<div>
    <img src="https://d33wubrfki0l68.cloudfront.net/5863322b42dcdf4b45ffef4de43f6ef0385db477/e6251/images/batch-normalization-example.png" width='400'/>
    <img src="https://d33wubrfki0l68.cloudfront.net/c8f1f7a886548f82234f8a3b06faeecfbb88c657/42d49/images/layer-normalization.png" width='400'/>
</div>
    


In [None]:
class LayerNorm(nn.Module):
    """
    Layer normalisation over the last dimension of the input.

    Each vector along the last axis is shifted to zero mean, divided by its standard
    deviation, then rescaled by the learnable ``gamma`` and shifted by the learnable ``beta``.

    Args:
        features: size of the last dimension of the inputs.
        eps: small constant added to the standard deviation to avoid division by zero.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.eps = eps         # handle division by zero
        self.gamma = nn.Parameter(torch.ones(features))        # scaling factor that the network learns
        self.beta = nn.Parameter(torch.zeros(features))        # offset factor that the network learns

    def forward(self, x):
        """Normalise ``x`` along its last axis; the output has the same shape as ``x``."""
        mean = x.mean(dim=-1, keepdim=True)
        # Tensor.std uses Bessel's correction by default; eps is added to the std itself,
        # not to the variance under a square root.
        std = x.std(dim=-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

That is, the output of each sub-layer is $\mathrm{LayerNorm}(x + \mathrm{Sublayer}(x))$, where $\mathrm{Sublayer}(x)$ is the function implemented by the sub-layer itself. We apply dropout (cite) to the output of each sub-layer, before it is added to the sub-layer input and normalized.

To facilitate these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension $d_{\text{model}} = 512$.

In [None]:
class Sublayer(nn.Module):
    """
    Residual wrapper around one sub-layer (attention or feed-forward).

    Computes ``x + dropout(sublayer(layernorm(x)))``: the input is normalised *before* the
    sub-layer runs and the un-normalised input is added back afterwards.

    Args:
        size: feature size handed to LayerNorm.
        dropout: dropout probability applied to the sub-layer output.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.layernorm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer_protagonist):
        """
        Apply ``sublayer_protagonist`` with normalisation, dropout and a residual connection.

        Args:
            x: input tensor; also the residual that is added back.
            sublayer_protagonist: callable taking the normalised input and returning a
                tensor that can be added to ``x``.
        """
        normed = self.layernorm(x)
        return x + self.dropout(sublayer_protagonist(normed))
        

Each layer has two sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position-wise fully connected feed-forward network.

In [None]:
class EncoderLayer(nn.Module):
    """
    A single encoder layer: self-attention followed by a feed-forward network, each
    wrapped in its own residual Sublayer.

    Args:
        size: feature size handed to the Sublayer wrappers.
        multihead_self_attention: module called as ``attn(query, key, value, mask)``.
        feed_forward: module called with a single tensor.
        dropout: dropout probability used by the Sublayer wrappers.
    """

    def __init__(self, size, multihead_self_attention, feed_forward, dropout):
        super().__init__()
        self.attention = multihead_self_attention
        self.feed_forward = feed_forward
        self.sublayers = clone(Sublayer(size, dropout), 2)
        self.size = size

    def forward(self, x, mask=None):
        """
        Run self-attention and then the feed-forward network over ``x``.

        Args:
            x: input tensor.
            mask: optional attention mask passed through to the attention module.
        """
        # Sublayer expects a callable; in self-attention query, key and value are the same tensor.
        x = self.sublayers[0](x, lambda normed: self.attention(normed, normed, normed, mask))
        return self.sublayers[1](x, self.feed_forward)

# Decoder
The decoder is also composed of a stack of N=6 identical layers.

In [None]:

class Decoder(nn.Module):
    """
    Core decoder: a stack of N copies of one decoder layer applied in sequence.

    Args:
        decoder_layer: the layer to replicate.
        N: number of copies in the stack.
    """

    def __init__(self, decoder_layer, N):
        super().__init__()
        self.decoder_stack = clone(decoder_layer, N)

    def forward(self, x, memory=None, src_mask=None, tgt_mask=None):
        """
        Feed ``x`` through every layer of the stack in order.

        Args:
            x: input to the first layer.
            memory: optional encoder output handed to every layer.
            src_mask: optional mask handed to every layer.
            tgt_mask: optional mask handed to every layer.

        When ``memory`` and both masks are None the layers are called with ``x`` only;
        otherwise each layer is called as ``layer(x, memory, src_mask, tgt_mask)``.

        Returns:
            Output of the last layer.
        """
        extras = (memory, src_mask, tgt_mask)
        has_extras = any(arg is not None for arg in extras)
        for layer in self.decoder_stack:
            x = layer(x, *extras) if has_extras else layer(x)
        return x

In addition to the two sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head attention over the output of the encoder stack. Similar to the encoder, we employ residual connections around each of the sub-layers, followed by layer normalization.

In [None]:
class DecoderLayer(nn.Module):
    """
    A single decoder layer made of three sub-layers, each wrapped in a residual Sublayer:

    1. masked self-attention over the decoder input,
    2. attention over the encoder output (``memory``),
    3. a feed-forward network.

    Args:
        size: feature size handed to the Sublayer wrappers.
        masked_multi_head_attention: module called as ``attn(query, key, value, mask)``.
        multi_head_attention: module called as ``attn(query, key, value, mask)``.
        feed_forward: module called with a single tensor.
        dropout: dropout probability used by the Sublayer wrappers.
    """

    def __init__(
        self,
        size,
        masked_multi_head_attention,
        multi_head_attention,
        feed_forward,
        dropout
    ):
        super().__init__()
        self.masked_attention = masked_multi_head_attention
        self.attention = multi_head_attention
        self.feed_forward = feed_forward
        self.sublayers = clone(Sublayer(size, dropout), 3)
        self.size = size    # kept for parity with EncoderLayer

    def forward(self, x, memory=None, src_mask=None, tgt_mask=None):
        """
        Run the three sub-layers over ``x``.

        Args:
            x: decoder input.
            memory: encoder output used as key and value of the second sub-layer. Required;
                it only has a default so that the positional signature ``forward(x)`` stays valid.
            src_mask: optional mask for the attention over ``memory``.
            tgt_mask: optional mask for the self-attention over ``x``.

        Raises:
            ValueError: if ``memory`` is not supplied.
        """
        if memory is None:
            raise ValueError("DecoderLayer.forward needs `memory` (the encoder output)")
        # Sublayer expects callables, so each attention call is wrapped in a lambda.
        x = self.sublayers[0](x, lambda normed: self.masked_attention(normed, normed, normed, tgt_mask))
        x = self.sublayers[1](x, lambda normed: self.attention(normed, memory, memory, src_mask))
        return self.sublayers[2](x, self.feed_forward)

We also modify the self-attention sub-layer in the decoder stack to prevent positions from attending to subsequent positions. This masking, combined with the fact that the output embeddings are offset by one position, ensures that the predictions for position `i` can depend only on the known outputs at positions less than `i`.

In [None]:
def forward_mask(size):
    """
    Mask to prevent the current word being affected by words after it.
    Only the word itself and past words are visible.

    Args:
        size: sequence length.

    Returns:
        Boolean tensor of shape (1, size, size). Entry [0, i, j] is True when position i
        may attend to position j, i.e. when j <= i.

    For size 4 the result is:
    [ True, False, False, False]
    [ True,  True, False, False]
    [ True,  True,  True, False]
    [ True,  True,  True,  True]
    """
    attn_shape = (1, size, size)
    # ones strictly above the diagonal mark the future positions ...
    future_positions = torch.triu(torch.ones(attn_shape), diagonal=1)
    # ... and everything that is not a future position is allowed
    return future_positions == 0


In [None]:
# a visualization of a sample mask
def show_sample_mask(size):
    """
    Build an Altair heat-map of ``forward_mask(size)``.

    Args:
        size: side length of the mask to draw.

    Returns:
        An interactive Altair chart with one rectangle per (Masking, Window) cell,
        coloured by the mask value at that cell.
    """
    # Build the mask once instead of once per cell.
    mask = forward_mask(size)[0]

    # One record per cell, collected in a list and turned into a frame in a single step.
    records = [
        {"Forward Mask": bool(mask[x, y]), "Window": y, "Masking": x}
        for y in range(size)
        for x in range(size)
    ]
    mask_frame = pd.DataFrame(records, columns=["Forward Mask", "Window", "Masking"])

    return (
        alt.Chart(mask_frame)
        .mark_rect()
        .properties(height=250, width=250)
        .encode(
            alt.X("Window:O"),
            alt.Y("Masking:O"),
            alt.Color("Forward Mask:Q", scale=alt.Scale(scheme="viridis")),
        )
        .interactive()
    )


# Draw the mask for a sequence of 20 positions; the chart is the cell's last expression.
show_sample_mask(20)

## Attention
An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.

We call our particular attention “Scaled Dot-Product Attention”. The input consists of queries and keys of dimension $d_k$, and values of dimension $d_v$. We compute the dot products of the query with all keys, divide each by $\sqrt{d_k}$, and apply a softmax function to obtain the weights on the values. 

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

<div>
    <img src="img/attention.svg" width='450'/>
</div>

In [None]:
def attention(query, key, value, mask=None, dropout: nn.Module=None):
    """
    Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        query: tensor whose last dimension is d_k.
        key: tensor whose last two dimensions are swapped before the product with ``query``.
        value: tensor combined with the attention weights.
        mask: optional tensor broadcastable to the score matrix; positions where
            ``mask == 0`` are excluded from the softmax.
        dropout: optional module applied to the attention weights.

    Returns:
        The attention-weighted combination of ``value``.
    """
    d_k = query.shape[-1]
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # A large negative score becomes ~0 after the softmax. A value near zero would
        # leave the masked position with a real share of the weight.
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = scores.softmax(dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, value)

The two most commonly used attention functions are additive attention (cite), and dot-product (multiplicative) attention. Dot-product attention is identical to our algorithm, except for the scaling factor of $\frac{1}{\sqrt{d_k}}$. Additive attention computes the compatibility function using a feed-forward network with a single hidden layer. While the two are similar in theoretical complexity, dot-product attention is much faster and more space-efficient in practice, since it can be implemented using highly optimized matrix multiplication code.

While for small values of $d_k$ the two mechanisms perform similarly, additive attention outperforms dot product attention without scaling for larger values of $d_k$ (cite). We suspect that for large values of $d_k$, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. (To illustrate why the dot products get large, assume that the components of $q$ and $k$ are independent random variables with mean $0$ and variance $1$. Then their dot product, $q \cdot k = \sum_{i=1}^{d_k} q_ik_i$, has mean $0$ and variance $d_k$.) To counteract this effect, we scale the dot products by $\frac{1}{\sqrt{d_k}}$.


Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.

$$
\begin{aligned}
\mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}(\mathrm{head_1}, ..., \mathrm{head_h})W^O \\
\text{where}~\mathrm{head_i} &= \mathrm{Attention}(QW^Q_i, KW^K_i, VW^V_i)
\end{aligned}
$$

Where the projections are parameter matrices

$$
W_i^Q \in \mathbb{R}^{{d_{\text{model}}} \times d_k}
$$



In [None]:
class MultiHeadAttention(nn.Module):
    """
    Several attention heads that each attend to a different subspace of the model
    dimensions; their outputs are concatenated and projected at the end.

    - Keys and values use the same per-head size: d_v = d_k = d_model / h.
    - With h attention heads, each head attends to d_model / h dimensions (a vector subspace).
    - Query, key and value each go through their own d_model -> d_model linear layer
      before being split into heads, and a fourth linear layer projects the concatenated output.

    Args:
        h: number of attention heads; must divide ``d_model``.
        d_model: model width.
        dropout: dropout probability applied inside ``attention``.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        # assuming d_v always equal to d_k -----> ! is this ever not true
        self.d_k = d_model // h       # this is the number of dimensions each head will attend to
        self.h = h

        # we need one linear for Q, K, V each and one for the concatenated output of the attention
        self.linears = clone(nn.Linear(d_model, d_model), 4)
        # Placeholder for attention weights. attention() returns only the weighted values,
        # so nothing is stored here at the moment.
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """
        Apply multi-head attention.

        Args:
            query, key, value: tensors whose first dimension is the batch and whose last
                dimension is d_model.
            mask: optional mask; an axis is inserted at position 1 so the same mask is
                shared by every head.

        Returns:
            Tensor with last dimension ``h * d_k`` after the final linear projection.
        """
        if mask is not None:
            # unsqueeze is not in-place: keep the result, which adds the head axis
            mask = mask.unsqueeze(1)

        num_batches = query.shape[0]

        # 1) Linear projections, then split d_model into (h, d_k) and move the head axis
        #    in front of the sequence axis: (batch, seq, d_model) -> (batch, h, seq, d_k).
        #    zip stops after the first three linears; the fourth is used in step 3.
        query, key, value = [
            lin(x).view(num_batches, -1, self.h, self.d_k).transpose(1, 2)
            for lin, x in zip(self.linears, (query, key, value))
        ]

        # 2) Attention for all heads at once. attention() returns a single tensor.
        x = attention(query, key, value, mask, self.dropout)

        # 3) Put the heads back next to each other and apply the output projection.
        x = (
            x.transpose(1, 2)
            .contiguous()
            .view(num_batches, -1, self.h * self.d_k)
        )

        return self.linears[-1](x)
        

In [None]:
# Scratch check: integer quotient and remainder of the same division
a = 10
divmod(a, 3)

In [None]:
# Scratch check: swapping the last two axes gives the same result whichever order they are named in.
a = torch.arange(25).reshape(5, 5)
assert torch.equal(a.transpose(-2, -1), a.transpose(-1, -2))
# transpose always needs two axes, and a 2-D tensor only has axes 0 and 1
a, a.T, a.transpose(1, 0)