# Comprendre les mécanismes d'attention

In [1]:
import numpy as np
import math
import tensorflow as tf
from tensorflow.keras.layers import Dense, Embedding


In [2]:
# Vocabulaire

source_vocab = {"<pad>": 0, "J'": 1, "aime": 2, "l'IA": 3}
target_vocab = {"<pad>": 0, "I": 1, "love": 2, "AI": 3}

source_sentence = ["J'", "aime", "l'IA"]
target_sentence = ["I", "love", "AI"]


In [3]:
# Embedding

embedding_dim = 256

source_embedding_layer = Embedding(input_dim=len(source_vocab), output_dim=embedding_dim)
target_embedding_layer = Embedding(input_dim=len(target_vocab), output_dim=embedding_dim)

encoder_input = tf.constant([[source_vocab[w] for w in source_sentence]])
decoder_input = tf.constant([[target_vocab[w] for w in target_sentence]])

encoder_embedded = source_embedding_layer(encoder_input)
decoder_embedded = target_embedding_layer(decoder_input)
print(encoder_embedded)

tf.Tensor(
[[[ 1.22458339e-02  9.03306156e-03 -1.20334700e-03 -3.68554965e-02
   -4.20827642e-02 -7.48722628e-03 -6.57258183e-03  1.84887983e-02
   -2.48702522e-02  2.68221237e-02 -1.95911173e-02  4.16897610e-03
   -2.83405185e-02  2.94088982e-02  4.62104939e-02 -1.19501725e-02
    3.61781977e-02  1.82985924e-02 -6.60772249e-03 -2.00513732e-02
   -2.02218890e-02  2.20030546e-03 -2.28498112e-02  2.75463946e-02
    2.36936249e-02  4.49346192e-02 -4.07230034e-02 -7.70865753e-03
   -4.50745597e-02 -4.78939898e-02 -5.65898418e-03  1.20176189e-02
   -3.54299061e-02  4.49584164e-02  2.07644589e-02  4.00773324e-02
   -3.16075236e-03 -3.41283828e-02 -4.63761389e-04  4.65988629e-02
   -2.98468471e-02 -4.20441478e-03 -4.36969846e-03  3.31059359e-02
    2.30305083e-02 -2.87600514e-02 -1.62031054e-02  1.32355951e-02
   -2.02228669e-02  1.42252184e-02  2.63634808e-02  4.65340056e-02
    3.51757668e-02  9.85424593e-03 -2.80498024e-02 -3.39217559e-02
    2.21589953e-03  4.51902486e-02 -4.31303270e-02 

In [4]:
# Transformation en Queries, Keys et Values

Q = Dense(256, name="query")(decoder_embedded)
K = Dense(256, name="key")(encoder_embedded)
V = Dense(256, name="value")(encoder_embedded)


print("Dimensions de Q:", Q.shape)
print("\nDimensions de K:", K.shape)
print("\nDimensions de V:", V.shape)



Dimensions de Q: (1, 3, 256)

Dimensions de K: (1, 3, 256)

Dimensions de V: (1, 3, 256)


In [5]:

## Calcul de l'Attention
QK = tf.matmul(Q, K, transpose_b=True)
QK_normalized = QK / math.sqrt(256)
softmax = tf.nn.softmax(QK_normalized)
attention_output = tf.matmul(softmax, V)

print("Dimension de l'attention:", attention_output.shape)


Dimension de l'attention: (1, 3, 256)


# Multihead Attention

In [6]:
# Découpage en multi-Têtes
num_heads = 8
depth = 256 // num_heads

def split_heads(x):
    x = tf.reshape(x, (1, -1, num_heads, depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])

Q = split_heads(Q)
print("Dimensions de multi-head Q:", Q.shape)  # (1, 8, 3, 32)

K = split_heads(K)
print("Dimensions de multi-head K:", K.shape)  # (1, 8, 3, 32)

V = split_heads(V)
print("Dimensions de multi-head V:", V.shape)  # (1, 8, 3, 32)




Dimensions de multi-head Q: (1, 8, 3, 32)
Dimensions de multi-head K: (1, 8, 3, 32)
Dimensions de multi-head V: (1, 8, 3, 32)


In [7]:
# Calcul de l'attention multi-tete
QK = tf.matmul(Q, K, transpose_b=True)
print("Dimensions of QK:", QK.shape)  # (1, 8, 3, 3)

QK_normalized = QK / math.sqrt(depth)
softmax = tf.nn.softmax(QK_normalized)
attention_heads = tf.matmul(softmax, V)
print("Dimensions de attention_heads:", attention_heads.shape)  # (1, 8, 3, 32)



Dimensions of QK: (1, 8, 3, 3)
Dimensions de attention_heads: (1, 8, 3, 32)


In [8]:
#Fusion des têtes
attention_heads = tf.transpose(attention_heads, perm=[0, 2, 1, 3])
concat_attention = tf.reshape(attention_heads, (1, -1, 256))
print("Dimensions de concat_attention:", concat_attention.shape)  # (1, 3, 256)


Dimensions de concat_attention: (1, 3, 256)


In [9]:
# transformation linéaire finale
multihead_attention_output = Dense(256, name="output")(concat_attention)
print("Dimensions de multihead_attention_output:", multihead_attention_output.shape)  # (1, 3, 256)


Dimensions de multihead_attention_output: (1, 3, 256)
