# Exercice 6

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def softmax(x):
    """Return the softmax of *x* along its last axis.

    The maximum of each row is subtracted before exponentiating. This does
    not change the mathematical result but keeps ``np.exp`` from overflowing
    to ``inf`` (and the division from producing ``nan``) on large scores.

    Args:
        x: Array-like of scores; normalisation is done over the last axis.

    Returns:
        An array with the same shape as *x* whose entries along the last
        axis are non-negative and sum to 1.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; np.max would raise on an empty axis.
        return np.exp(x)
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)

def layer_norm(x):
    """Normalise *x* to zero mean and unit standard deviation on the last axis.

    Rows whose standard deviation is exactly zero (all values equal) are
    returned as zeros instead of ``nan``: the centred values are already
    zero, so the zero divisor is replaced by 1.

    Args:
        x: Array-like; statistics are computed over the last axis.

    Returns:
        An array with the same shape as *x*.
    """
    x = np.asarray(x)
    centered = x - np.mean(x, axis=-1, keepdims=True)
    std = np.std(x, axis=-1, keepdims=True)
    # Avoid 0/0 for constant rows while leaving every other row untouched.
    safe_std = np.where(std == 0, 1.0, std)
    return centered / safe_std



# Toy named-entity example: every token of the sentence gets a one-hot embedding.
sentence = "Barack Obama was the 44th president of the United States".split()
vocab_size = len(sentence)

# A word that appears twice keeps only its last position in this mapping,
# so both occurrences share the same one-hot row below.
word_to_index = dict(zip(sentence, range(vocab_size)))
embeddings = np.eye(vocab_size)

inputs = embeddings[[word_to_index[token] for token in sentence]]

# Fixed seed so the random projection matrices are reproducible;
# they are drawn in the order query, key, value.
np.random.seed(42)
Wq, Wk, Wv = (np.random.rand(vocab_size, vocab_size) for _ in range(3))

# Project the inputs into queries, keys and values.
Q = inputs @ Wq
K = inputs @ Wk
V = inputs @ Wv

# Scaled dot-product attention of the sentence over itself.
scores = (Q @ K.T) / np.sqrt(K.shape[1])
attention_weights = softmax(scores)
output = attention_weights @ V


plt.figure(figsize=(15, 10))

# Panels 1 and 2 of the 2x3 grid: the one-hot inputs, then the attention matrix.
for _slot, (_matrix, _caption) in enumerate(
    [
        (inputs, 'Inputs pour entités Nommées'),
        (attention_weights, 'Attention Weights pour entités Nommées'),
    ],
    start=1,
):
    plt.subplot(2, 3, _slot)
    sns.heatmap(_matrix, annot=True, cmap='viridis', xticklabels=sentence, yticklabels=sentence)
    plt.title(_caption)


print("\nAnalyse:")

# For each word, report the *other* word that receives its largest attention weight.
for position, (mot, weights_row) in enumerate(zip(sentence, attention_weights)):
    # Work on a copy and zero the word's own column so it cannot match itself.
    masked = weights_row.copy()
    masked[position] = 0

    best = masked.argmax()
    best_score = masked[best]
    # Only weights above 0.1 are reported.
    if best_score > 0.1:
        print(f"   '{mot}' ->  '{sentence[best]}' (score: {best_score:.3f})")
        



print("\n Attention traduction sequence sequence")


def _self_attention(tokens, w_q, w_k, w_v):
    """Return (Q, K, V, scores, weights, output) of scaled dot-product self-attention."""
    q = tokens @ w_q
    k = tokens @ w_k
    v = tokens @ w_v
    raw = (q @ k.T) / np.sqrt(k.shape[1])
    probs = softmax(raw)
    return q, k, v, raw, probs, probs @ v


# Source (encoder) and target (decoder) sentences, tokenised on whitespace.
encoder_sentence = "i am the president".split()
decoder_sentence = "je suis le president ".split()

encoder_vocab_size = len(encoder_sentence)
decoder_vocab_size = len(decoder_sentence)

# One-hot embeddings indexed by token position on each side.
encoder_word_to_index = dict(zip(encoder_sentence, range(encoder_vocab_size)))
decoder_word_to_index = dict(zip(decoder_sentence, range(decoder_vocab_size)))
encoder_embeddings = np.eye(encoder_vocab_size)
decoder_embeddings = np.eye(decoder_vocab_size)

encoder_inputs = encoder_embeddings[[encoder_word_to_index[token] for token in encoder_sentence]]
decoder_inputs = decoder_embeddings[[decoder_word_to_index[token] for token in decoder_sentence]]

# Reproducible random projections: encoder matrices are drawn first, then
# the decoder ones, each in the order query, key, value.
np.random.seed(42)
Wq_enc, Wk_enc, Wv_enc = (
    np.random.rand(encoder_vocab_size, encoder_vocab_size) for _ in range(3)
)
Wq_dec, Wk_dec, Wv_dec = (
    np.random.rand(decoder_vocab_size, decoder_vocab_size) for _ in range(3)
)

# Encoder self-attention.
Q_enc, K_enc, V_enc, scores_enc, weights_enc, output_enc = _self_attention(
    encoder_inputs, Wq_enc, Wk_enc, Wv_enc
)

# Decoder self-attention.
Q_dec, K_dec, V_dec, scores_dec, weights_dec, output_dec = _self_attention(
    decoder_inputs, Wq_dec, Wk_dec, Wv_dec
)


# Translation heatmaps: panels 3 (encoder) and 4 (decoder) of the 2x3 grid.
for _slot, _matrix, _labels, _caption in (
    (3, weights_enc, encoder_sentence, 'Encoder Attention Traduction'),
    (4, weights_dec, decoder_sentence, 'Decoder attention traduction'),
):
    plt.subplot(2, 3, _slot)
    sns.heatmap(_matrix, annot=True, cmap='viridis', xticklabels=_labels, yticklabels=_labels)
    plt.title(_caption)




print("\n attention pour classification de texte ")

# Candidate sentences; only the first one is processed below.
sentences = [
    "i love the all the citizen",
    "the new law is even more unfair",
]

classification_sentence = sentences[0].split()
vocab_size_classification = len(classification_sentence)

# One-hot embeddings; a repeated word keeps its last position in the mapping,
# so its occurrences share one embedding row.
word_to_index_classification = dict(
    zip(classification_sentence, range(vocab_size_classification))
)
embeddings_classification = np.eye(vocab_size_classification)
inputs_classification = embeddings_classification[
    [word_to_index_classification[token] for token in classification_sentence]
]

# Reproducible random projections, drawn in the order query, key, value.
np.random.seed(42)
Wq_classification, Wk_classification, Wv_classification = (
    np.random.rand(vocab_size_classification, vocab_size_classification)
    for _ in range(3)
)

Q_classification = inputs_classification @ Wq_classification
K_classification = inputs_classification @ Wk_classification
V_classification = inputs_classification @ Wv_classification

# Scaled dot-product self-attention over the sentence.
scores_classification = (Q_classification @ K_classification.T) / np.sqrt(
    K_classification.shape[1]
)
weights_classification = softmax(scores_classification)
output_classification = weights_classification @ V_classification

# Average the attended outputs over the tokens to get one feature vector.
attention_classifcaition = output_classification.mean(axis=0)


# Decoder queries attend over the encoder keys and values.
scores_relation = (Q_dec @ K_enc.T) / np.sqrt(K_enc.shape[1])
weights_relation = softmax(scores_relation)
output_relation = weights_relation @ V_enc

# Panels 5 and 6: decoder-to-encoder attention, then the classification attention.
for _slot, _matrix, _columns, _rows, _caption in (
    (5, weights_relation, encoder_sentence, decoder_sentence,
     'Relation entre input et output sequence'),
    (6, weights_classification, classification_sentence, classification_sentence,
     "Classification attetion "),
):
    plt.subplot(2, 3, _slot)
    sns.heatmap(_matrix, annot=True, cmap='viridis',
                xticklabels=_columns, yticklabels=_rows)
    plt.title(_caption)

# Lay out the panels of the figure and display it.
plt.tight_layout()
plt.show()



print("\nAnalyse classification:")

# For each word, report the *other* word that receives its largest attention weight.
for position, (mot, weights_row) in enumerate(
    zip(classification_sentence, weights_classification)
):
    # Work on a copy and zero the word's own column so it cannot match itself.
    masked = weights_row.copy()
    masked[position] = 0

    best = masked.argmax()
    best_score = masked[best]
    # Only weights above 0.1 are reported.
    if best_score > 0.1:
        print(f"   '{mot}' ->  '{classification_sentence[best]}' (score: {best_score:.3f})")


print(f"\nFeatures pour classification: {attention_classifcaition}")


            
Le mécanisme d'attention relie 'Barack Obama' à 'United States',
et 'president' à 'States' : on voit donc bien le lien sémantique.

La self-attention gère donc bien les entités nommées.


On peut alors comprendre son importance.