In [2]:
%load_ext autoreload

In [3]:
import torch
from torch import nn
import torch.nn.functional as F

## Self attention
sequence $x$ is $b x t x k$. batch sequence length and dimension
Follwoing we add linear transformations in order to provide learnable parameters.
* $W_q$ transform $x_i$ to $q_i$ query vector. Compared to every other vector inorder to establish weights for its own output $y_i$
* $W_k$ transform $x_i$ to $k_i$ key vector. Used for $x_i$ being compared to the queries, to establish the weights for other $y_j$
* $W_v$ transform $x_i$ to $v_i$ vector to be weighted and that actually encodes the information.

<img src="images/learnable-weights.png"  width="500" height="600">

Hence, the self attention layer for each $y_i$ will be the following


<img src="images/learnable-structure.png"  width="500" height="600">

Small tricks:

* Divide by $\sqrt{k}$ to reduce input values of the softmax (as dimension increases it reduces by the euclidean lenght, read bibliography for more explanations)
* Multi-head attention: learn many query, keys values for each input (paralllel self attention) and concatenate at the end. Allows to focus and learn different queries for each input. Each $W_q^r$, $W_k^r$ and $W_v^r$ is an attention head.




In [4]:
class SelfAttention(nn.Module):
    def __init__(self, k, heads=10):
        super().__init__()
        #Compute Linear Transformations
        self.transform_queries = nn.Linear(k, k*heads, bias=False)
        self.transform_keys = nn.Linear(k, k*heads, bias=False)
        self.transform_values = nn.Linear(k, k*heads, bias=False)
        # Linear Transform that reduces dimensionality
        self.dimension_reduce = nn.Linear(k*heads, k, bias=False)
        self.k = k
        self.h = heads
        
    def forward(self,x): #input b x t x k
        b, t, k = x.size()
        h = self.h
        
        #Transform: b,t,k => b,t,k*h 
        queries = self.transform_queries(x)
        keys = self.transform_keys(x)
        values = self.transform_values(x)
        #Separate heads from dimension
        # b,t,k*h => b,t,h,k
        queries = queries.view(b,t,h,k)
        keys = keys.view(b,t,h,k)
        values = values.view(b,t,h,k)
        #Matrix multiplication for each batch and each head. Hence we merge heads and batch in order tu use torch.bmm
        # Transpose b,t,h,k => b,h,t,k
        # Merge dim b,h,t,k => b*h,t,k
        queries = queries.transpose(1,2).contiguous().view(b*h,t,k)
        keys = keys.transpose(1,2).contiguous().view(b*h,t,k)
        values = values.transpose(1,2).contiguous().view(b*h,t,k)
        # Scale
        queries = queries /(k**(0.25))
        keys = keys/(k**(0.25))
        
        #Use torch batch matrix mult, performs a matrix mult for each elemt of the batch.
        #Batch consist of batch sample and head
        # b*h,t,k x b*h,k,t => b*h,t,t 
        weights = torch.bmm(queries,keys.transpose(1,2))
        soft_weights = F.softmax(weights, dim=2)
        #Multiply weights b*h,t,t x b*h,t,k => b*h,t,k 
        #(each row weight contains weights for each row vectors, row weights linear combination of rowvectors) 
        output = torch.bmm(soft_weights, values).view(b, h, t, k)
        #Merge the h heads on the k dimension
        #Transpose b,h,t,k =>  b,t,h,k 
        #View(merge) b,t,h,k => b,t,h*k 
        output = output.transpose(1,2).contiguous().view(b,t,h*k)
        #Reduce dimension b,t,h*k => b,t,k
        return self.dimension_reduce(output)


In [5]:
# Fast check
self_attention = SelfAttention(k=10,heads=20)
sample_input = torch.rand(8,5,10)
output = self_attention(sample_input)
output, output.shape

(tensor([[[ 0.0296,  0.0317,  0.1882,  0.0721, -0.2167, -0.1055,  0.0324,
           -0.0840,  0.1289, -0.0255],
          [ 0.0317,  0.0318,  0.1885,  0.0736, -0.2177, -0.1046,  0.0331,
           -0.0845,  0.1283, -0.0263],
          [ 0.0310,  0.0303,  0.1882,  0.0735, -0.2160, -0.1048,  0.0335,
           -0.0832,  0.1291, -0.0247],
          [ 0.0298,  0.0334,  0.1883,  0.0719, -0.2170, -0.1078,  0.0327,
           -0.0856,  0.1299, -0.0260],
          [ 0.0306,  0.0320,  0.1886,  0.0722, -0.2164, -0.1058,  0.0328,
           -0.0838,  0.1290, -0.0247]],
 
         [[ 0.0150,  0.0536,  0.2762,  0.0174, -0.2163, -0.1906, -0.0490,
           -0.1417,  0.1454, -0.0242],
          [ 0.0178,  0.0540,  0.2771,  0.0179, -0.2177, -0.1874, -0.0473,
           -0.1413,  0.1419, -0.0235],
          [ 0.0167,  0.0547,  0.2763,  0.0187, -0.2182, -0.1894, -0.0504,
           -0.1425,  0.1438, -0.0251],
          [ 0.0171,  0.0545,  0.2762,  0.0186, -0.2178, -0.1880, -0.0506,
           -0.1417,

## Transformers

"Any architecture designed to process a connected set of units—such where the only interaction between units is through self-attention"
<img src="images/transformers.png"  width="500" height="600">

It is to say seq-to-seq all at once

Basic structure:
* Self attention
* Norm layer
* FeedForward Net
* Norm Layer

Additionaly residual blocks and keypoint is to combine feedforward with self attention

In [10]:
class TransformerModule(nn.Module):
    
    def __init__(self, k, heads, hidden_layer_mult=4):
        super().__init__()
        self.self_attention = SelfAttention(k, heads)
        self.layer_norm_1 = nn.LayerNorm(k)
        self.feed_forward = nn.Sequential(
            nn.Linear(k, hidden_layer_mult * k),
            nn.ReLU(),
            nn.Linear(hidden_layer_mult * k, k),
        )
        self.layer_norm_2 = nn.LayerNorm(k)
        
    def forward(self, x):
        attention = self.self_attention(x)
        norm = self.layer_norm_1(attention + x)
        ff_output = self.feed_forward(norm)
        output = self.layer_norm_2(ff_output+norm)
        return output

In [11]:
#Fast check
transformer_module = TransformerModule(k=10, heads=20)
sample_input = torch.rand(8,5,10)
output = transformer_module(sample_input)
output, output.shape

(tensor([[[ 0.8233, -1.3898, -1.3552,  0.6779,  0.1417, -1.2642,  1.5271,
            0.8112, -0.4786,  0.5067],
          [ 0.6343, -0.0228, -2.2422,  1.1804, -0.8153,  0.3628,  1.2845,
           -0.4620,  0.5467, -0.4664],
          [ 0.6151, -2.0556, -0.0430,  0.2308, -0.3274, -1.1467, -0.3801,
            1.4909,  0.4577,  1.1584],
          [-0.7969, -0.4526, -0.9660, -1.2747, -0.7350,  1.1137,  0.9998,
           -0.5436,  1.3457,  1.3098],
          [-1.6160,  1.1272, -0.5850,  0.0466, -1.1241,  0.3151,  1.6314,
           -0.9376,  0.9011,  0.2412]],
 
         [[ 1.3050, -2.1407, -0.7351,  0.0604,  1.0004,  0.2593,  0.2542,
           -0.2618,  1.1129, -0.8547],
          [-0.2232, -0.6451, -2.3826, -0.3343,  0.0539,  0.5688,  1.5108,
            0.7899,  0.7139, -0.0521],
          [ 0.5460, -1.1643, -1.9265, -0.5806,  0.4076,  0.2140,  0.9163,
           -0.7309,  1.0479,  1.2703],
          [ 0.7612, -0.3981, -1.6886, -0.4407,  0.5392, -1.0161,  1.1695,
            1.8046,