In [2]:
%load_ext autoreload

In [3]:
import torch
from torch import nn
import torch.nn.functional as F

## Self attention
sequence $x$ is $b x t x k$. batch sequence length and dimension
Follwoing we add linear transformations in order to provide learnable parameters.
* $W_q$ transform $x_i$ to $q_i$ query vector. Compared to every other vector inorder to establish weights for its own output $y_i$
* $W_k$ transform $x_i$ to $k_i$ key vector. Used for $x_i$ being compared to the queries, to establish the weights for other $y_j$
* $W_v$ transform $x_i$ to $v_i$ vector to be weighted and that actually encodes the information.

<img src="images/learnable-weights.png"  width="500" height="600">

Hence, the self attention layer for each $y_i$ will be the following


<img src="images/learnable-structure.png"  width="500" height="600">

Small tricks:

* Divide by $\sqrt{k}$ to reduce input values of the softmax (as dimension increases it reduces by the euclidean lenght, read bibliography for more explanations)
* Multi-head attention: learn many query, keys values for each input (paralllel self attention) and concatenate at the end. Allows to focus and learn different queries for each input. Each $W_q^r$, $W_k^r$ and $W_v^r$ is an attention head.




In [10]:
class SelfAttention(nn.Module):
    def __init__(self, k, heads=10):
        super().__init__()
        #Compute Linear Transformations
        self.transform_queries = nn.Linear(k, k*heads, bias=False)
        self.transform_keys = nn.Linear(k, k*heads, bias=False)
        self.transform_values = nn.Linear(k, k*heads, bias=False)
        # Linear Transform that reduces dimensionality
        self.dimension_reduce = nn.Linear(k*heads, k, bias=False)
        self.k = k
        self.h = heads
        
    def forward(self,x): #input b x t x k
        b, t, k = x.size()
        h = self.h
        
        #Transform: b,t,k => b,t,k*h 
        queries = self.transform_queries(x)
        keys = self.transform_keys(x)
        values = self.transform_values(x)
        #Separate heads from dimension
        # b,t,k*h => b,t,h,k
        queries = queries.view(b,t,h,k)
        keys = keys.view(b,t,h,k)
        values = values.view(b,t,h,k)
        #Matrix multiplication for each batch and each head. Hence we merge heads and batch in order tu use torch.bmm
        # Transpose b,t,h,k => b,h,t,k
        # Merge dim b,h,t,k => b*h,t,k
        queries = queries.transpose(1,2).contiguous().view(b*h,t,k)
        keys = keys.transpose(1,2).contiguous().view(b*h,t,k)
        values = values.transpose(1,2).contiguous().view(b*h,t,k)
        # Scale
        queries = queries /(k**(0.25))
        keys = keys/(k**(0.25))
        
        #Use torch batch matrix mult, performs a matrix mult for each elemt of the batch.
        #Batch consist of batch sample and head
        # b*h,t,k x b*h,k,t => b*h,t,t 
        weights = torch.bmm(queries,keys.transpose(1,2))
        soft_weights = F.softmax(weights, dim=2)
        #Multiply weights b*h,t,t x b*h,t,k => b*h,t,k 
        #(each row weight contains weights for each row vectors, row weights linear combination of rowvectors) 
        output = torch.bmm(soft_weights, values).view(b, h, t, k)
        #Merge the h heads on the k dimension
        #Transpose b,h,t,k =>  b,t,h,k 
        #View(merge) b,t,h,k => b,t,h*k 
        output = output.transpose(1,2).contiguous().view(b,t,h*k)
        #Reduce dimension b,t,h*k => b,t,k
        return self.dimension_reduce(output)


In [12]:
# Fast check
self_attention = SelfAttention(k=10,heads=20)
sample_input = torch.rand(8,5,10)
output = self_attention(sample_input)
output, output.shape

(tensor([[[-0.0158, -0.1386, -0.0345, -0.2060, -0.1662, -0.0297, -0.0587,
            0.0035,  0.0146,  0.0096],
          [-0.0151, -0.1365, -0.0336, -0.2070, -0.1653, -0.0281, -0.0588,
            0.0049,  0.0124,  0.0078],
          [-0.0152, -0.1378, -0.0355, -0.2058, -0.1654, -0.0293, -0.0593,
            0.0035,  0.0136,  0.0091],
          [-0.0155, -0.1352, -0.0342, -0.2078, -0.1673, -0.0283, -0.0595,
            0.0043,  0.0115,  0.0077],
          [-0.0163, -0.1348, -0.0337, -0.2076, -0.1673, -0.0286, -0.0591,
            0.0064,  0.0126,  0.0060]],
 
         [[-0.0696, -0.1451, -0.0073, -0.2360, -0.0398, -0.0289,  0.0026,
           -0.0469,  0.0683, -0.0650],
          [-0.0713, -0.1461, -0.0069, -0.2358, -0.0411, -0.0280,  0.0022,
           -0.0500,  0.0693, -0.0634],
          [-0.0695, -0.1478, -0.0072, -0.2371, -0.0425, -0.0255,  0.0025,
           -0.0532,  0.0684, -0.0637],
          [-0.0699, -0.1484, -0.0040, -0.2378, -0.0415, -0.0244,  0.0030,
           -0.0544,

## Transformers

"Any architecture designed to process a connected set of units—such where the only interaction between units is through self-attention"
<img src="images/transformers.png"  width="500" height="600">

It is to say seq-to-seq all at once

Basic structure:
* Self attention
* Norm layer
* FeedForward Net
* Norm Layer

Additionaly residual blocks and keypoint is to combine feedforward with self attention