### Implémentation de la couche d'Attention d'un Transformer 

- Base : https://github.com/charlesollion/dlexperiments/tree/master/5-Transformers-Intro
- *À faire* : Implémenter l'attention de Bahdanau et de Luong (<x, Ay>)

In [5]:
import torch
import torch.nn as nn

### I. __Q__, __K__ et __V__

In [6]:
dim = 4

query_layer = nn.Linear(dim, dim)
key_layer = nn.Linear(dim, dim)
value_layer = nn.Linear(dim ,dim)

In [7]:
# size = (Batch_size, sequence_length, input_dimension)
X = torch.normal(0, 1, size=(1, 3, 4))

query = query_layer(X)
key = key_layer(X)
value = value_layer(X)

query.size(), key.size(), value.size()

(torch.Size([1, 3, 4]), torch.Size([1, 3, 4]), torch.Size([1, 3, 4]))

$$SelfAttention(Q_i, \mathbf{K}, \mathbf{V}) = \sum_j softmax_j(\frac{Q_i \cdot \mathbf{K}^T}{\sqrt{d_k}}) V_j $$

### Attention d'un Transformer

In [8]:
import math

attention_scores = torch.matmul(query, key.mT)
attention_scores /= math.sqrt(dim)

attention_probs = nn.Softmax(dim=1)(attention_scores)
attention_probs, attention_probs.size()

(tensor([[[0.2109, 0.3170, 0.2986],
          [0.5714, 0.3812, 0.4213],
          [0.2177, 0.3017, 0.2800]]], grad_fn=<SoftmaxBackward0>),
 torch.Size([1, 3, 3]))

In [47]:
torch.matmul(attention_probs, value) # Somme pondérée avec V

tensor([[[ 0.2359, -0.1721,  0.0009,  0.4688],
         [ 0.2439, -0.2217,  0.0210,  0.4918],
         [ 0.2290, -0.1905,  0.0565,  0.4676]]], grad_fn=<UnsafeViewBackward0>)

#### *C'est tout !*

## II. Attention, Normalisation & Prop. avant :
(Bloc Encodeur d'un Transformer quoi)

In [12]:
class SimpleEncoder(nn.Module):

    def __init__(self, hidden_dim) -> None:
        super().__init__()
        self.dim = hidden_dim
        self.query_layer = nn.Linear(self.dim, self.dim)
        self.value_layer = nn.Linear(self.dim, self.dim)
        self.key_layer = nn.Linear(self.dim, self.dim)

        self.output_layer = nn.Linear(self.dim, self.dim)
        self.dropout = nn.Dropout(0.2)
        self.LayerNorm = nn.LayerNorm(self.dim, eps=1e-4)

    def self_attention(self, X):
        query = self.query_layer(X)
        value = self.value_layer(X)
        key = self.key_layer(X)

        attention_scores = torch.matmul(query, key.mT)
        attention_scores = attention_scores /  math.sqrt(self.dim)

        attention_probs = nn.Softmax(dim=1)(attention_scores)

        return torch.matmul(attention_probs, value)

    def forward(self, X):
        Z = self.self_attention(X)
        Z = self.LayerNorm(X + Z)
        Z = self.output_layer(Z)
        Z = self.dropout(Z)
        return Z
    

In [13]:
T = SimpleEncoder(hidden_dim=4)

In [14]:
T(X)

tensor([[[-0.2061, -0.0000, -0.6968,  0.0000],
         [-0.4542, -0.0000, -0.5351, -0.0959],
         [ 0.2486, -1.1297,  0.0000, -0.4160]]], grad_fn=<MulBackward0>)