In [1]:
%load_ext autoreload

In [3]:
import torch
from torch import nn
import torch.nn.functional as F

## Self attention
The input sequence $x$ has shape $b \times t \times k$: batch size, sequence length and embedding dimension.
Next, we add linear transformations in order to provide learnable parameters.
* $W_q$ transforms $x_i$ into the query vector $q_i$, which is compared to every other vector in order to establish the weights for its own output $y_i$.
* $W_k$ transforms $x_i$ into the key vector $k_i$, which is compared to the queries of the other vectors to establish the weights for their outputs $y_j$.
* $W_v$ transforms $x_i$ into the value vector $v_i$, which is the vector that gets weighted and that actually encodes the information passed to the outputs.

<img src="images/learnable-weights.png"  width="500" height="600">

Hence, the self-attention layer computes each output as $y_i = \sum_j w_{ij} v_j$, with weights $w_{ij} = \mathrm{softmax}_j\left(q_i^\top k_j\right)$:


<img src="images/learnable-structure.png"  width="500" height="600">

Small tricks:

* Divide the dot products by $\sqrt{k}$ to keep the inputs of the softmax small. The magnitude of a dot product grows with the dimension $k$, and $\sqrt{k}$ is the Euclidean length of the all-ones vector in $\mathbb{R}^k$. See the bibliography for a more detailed explanation.
* Multi-head attention: learn several query, key and value transformations for each input (parallel self-attention) and concatenate the results at the end. This lets each head focus on, and learn, different queries for each input. Each triple $W_q^r$, $W_k^r$, $W_v^r$ is an attention head.




In [None]:
class SelfAttention(nn.Module):
    """Multi-head self-attention over a batch of sequences.

    Every head has its own query/key/value projection of width ``k``. The
    per-head outputs are concatenated (width ``k * heads``) and projected
    back down to ``k`` by ``dimension_reduce``.

    Args:
        k: embedding dimension of each input vector.
        heads: number of attention heads.
    """

    def __init__(self, k, heads=10):
        super().__init__()
        # Query/key/value projections for all heads at once: k -> k*heads.
        self.transform_queries = nn.Linear(k, k*heads, bias=False)
        self.transform_keys = nn.Linear(k, k*heads, bias=False)
        self.transform_values = nn.Linear(k, k*heads, bias=False)
        # Projects the concatenated head outputs back to k dimensions.
        self.dimension_reduce = nn.Linear(k*heads, k, bias=False)
        self.k = k
        self.h = heads

    def forward(self, x):
        """Apply multi-head self-attention.

        Args:
            x: tensor of shape (b, t, k) -- batch, sequence length, embedding.

        Returns:
            Tensor of shape (b, t, k): for every position, the attention-
            weighted sum of the value vectors of all heads, reduced back to k.
        """
        b, t, k = x.size()
        h = self.h

        # Transform: b,t,k => b,t,k*h
        queries = self.transform_queries(x)
        keys = self.transform_keys(x)
        values = self.transform_values(x)

        # Separate heads from dimension: b,t,k*h => b,t,h,k
        queries = queries.view(b, t, h, k)
        keys = keys.view(b, t, h, k)
        values = values.view(b, t, h, k)

        # torch.bmm works on 3D tensors, so fold the heads into the batch:
        # b,t,h,k => b,h,t,k => b*h,t,k
        queries = queries.transpose(1, 2).contiguous().view(b*h, t, k)
        keys = keys.transpose(1, 2).contiguous().view(b*h, t, k)
        values = values.transpose(1, 2).contiguous().view(b*h, t, k)

        # Scaling q and k by k**(1/4) each is equivalent to dividing their
        # dot product by sqrt(k), but keeps intermediate values smaller.
        queries = queries / (k ** 0.25)
        keys = keys / (k ** 0.25)

        # Raw attention logits: (b*h,t,k) x (b*h,k,t) => b*h,t,t
        weights = torch.bmm(queries, keys.transpose(1, 2))
        # Normalise over the key axis so each query's weights sum to 1.
        soft_weights = F.softmax(weights, dim=2)

        # Weighted sum of values: (b*h,t,t) x (b*h,t,k) => b*h,t,k
        out = torch.bmm(soft_weights, values).view(b, h, t, k)
        # Concatenate the heads: b,h,t,k => b,t,h,k => b,t,h*k
        out = out.transpose(1, 2).contiguous().view(b, t, h*k)
        # Reduce back to the input dimension: b,t,h*k => b,t,k
        return self.dimension_reduce(out)
