# Transformer

Dans ce TP, nous allons implémenter un **transformer** en PyTorch "from scratch". Ce TP est basé sur la vidéo suivante:<br>
https://www.youtube.com/watch?v=U0s0f995w14

<img src="files/figures/transformer.jpg" width="350px" class="center"/>

## Librairies

In [11]:
import torch
import torch.nn as nn

## Model

In [12]:
class SelfAttention(nn.Module):
    """Multi-head attention layer.

    The embedding dimension is split into ``nb_heads`` slices of size
    ``head_dim = embed_dim // nb_heads``. Each slice is linearly projected
    (the same ``head_dim -> head_dim`` projection is shared by all heads),
    attention is computed independently per head, and the concatenated
    heads go through a final linear layer.
    """

    def __init__(self, embed_dim, nb_heads):
        """Build the projection layers.

        Args:
            embed_dim: size of the last dimension of the inputs.
            nb_heads: number of attention heads; must divide ``embed_dim``.

        Raises:
            ValueError: if ``embed_dim`` is not divisible by ``nb_heads``.
        """
        super().__init__()

        # Without this check the reshape in forward() would fail later
        # with a much less explicit error.
        if embed_dim % nb_heads != 0:
            raise ValueError("Embed_dim needs to be divisible by nb_heads")

        self.embed_dim = embed_dim
        self.nb_heads = nb_heads
        self.head_dim = embed_dim // nb_heads

        # Per-head projections, applied on the last (head_dim) axis.
        self.fc_queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # Output projection applied to the concatenated heads.
        self.fc_out = nn.Linear(self.nb_heads * self.head_dim, self.embed_dim)

    def forward(self, queries, keys, values, mask):
        """Compute multi-head attention.

        Args:
            queries: tensor of shape [N, nb_queries, embed_dim].
            keys: tensor of shape [N, nb_keys, embed_dim].
            values: tensor of shape [N, nb_values, embed_dim]; the second
                einsum requires nb_values == nb_keys.
            mask: ``None``, or a tensor broadcastable to
                [N, nb_heads, nb_queries, nb_keys]. Positions where
                ``mask == 0`` are excluded from the attention.

        Returns:
            Tensor of shape [N, nb_queries, embed_dim].
        """
        batch_size = queries.shape[0]
        nb_queries = queries.shape[1]
        nb_keys = keys.shape[1]
        nb_values = values.shape[1]

        # Split the embedding axis into heads:
        # [N, length, embed_dim] -> [N, length, nb_heads, head_dim]
        queries = queries.reshape(batch_size, nb_queries, self.nb_heads, self.head_dim)
        keys = keys.reshape(batch_size, nb_keys, self.nb_heads, self.head_dim)
        values = values.reshape(batch_size, nb_values, self.nb_heads, self.head_dim)

        # Learned projections; nn.Linear acts on the last axis only, so the
        # 4D shapes are preserved.
        queries = self.fc_queries(queries)
        keys = self.fc_keys(keys)
        values = self.fc_values(values)

        # Attention scores QK^T, computed per batch element and per head.
        # n = batch, q = nb_queries, k = nb_keys, h = nb_heads, d = head_dim
        # attn_scores: [N, nb_heads, nb_queries, nb_keys]
        attn_scores = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        if mask is not None:
            # A very large negative score becomes a ~0 weight after softmax.
            attn_scores = attn_scores.masked_fill(mask == 0, float("-1e20"))

        # Softmax over the keys axis.
        # NOTE(review): the scores are scaled by sqrt(embed_dim); the usual
        # formulation uses sqrt(d_k) with d_k = head_dim. Kept as is so the
        # numerical behavior does not change -- confirm which is intended.
        attn_weights = torch.softmax(attn_scores / self.embed_dim ** 0.5, dim=3)

        # Weighted sum of the values; l runs over keys/values positions.
        # outputs: [N, nb_queries, nb_heads, head_dim]
        outputs = torch.einsum("nhql,nlhd->nqhd", attn_weights, values)

        # Concatenate the heads back into a single embedding axis.
        outputs = outputs.reshape(batch_size, nb_queries, self.nb_heads * self.head_dim)

        # Final fully connected layer.
        return self.fc_out(outputs)