# Exercice 4

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


def softmax(x):
    """Return the softmax of ``x`` along its last axis.

    The per-row maximum is subtracted before exponentiating. Softmax is
    invariant to adding a constant to every entry of a row, so the result is
    unchanged, but ``np.exp`` can no longer overflow to ``inf`` (which made the
    quotient ``nan``) when ``x`` holds large scores.

    Args:
        x: array-like of scores; normalisation is over ``axis=-1``.

    Returns:
        A float array with the shape of ``x`` whose last-axis slices sum to 1.
    """
    x = np.asarray(x)
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)


def dropout(x, taille=0.1):
    """Apply inverted dropout to ``x``.

    Each element is zeroed independently with probability ``taille``; the
    surviving elements are scaled by ``1 / (1 - taille)`` so the expected value
    of the result equals ``x``. The mask is drawn from NumPy's global generator
    (``np.random``), so results depend on ``np.random.seed``.

    Args:
        x: input array.
        taille: drop probability, in ``[0, 1]``.

    Returns:
        ``x`` itself when ``taille == 0``; an all-zero float array of the same
        shape when ``taille == 1``; otherwise a new masked and rescaled array.

    Raises:
        ValueError: if ``taille`` is not in ``[0, 1]``.
    """
    if not 0 <= taille <= 1:
        raise ValueError(f"taille must be in [0, 1], got {taille!r}")
    if taille == 0:
        return x
    if taille == 1:
        # Every element is dropped. Rescaling by 1 / (1 - taille) would be 0 / 0
        # and turn the whole output into NaN.
        return np.zeros_like(x, dtype=float)
    keep = 1 - taille
    mask = np.random.binomial(1, keep, size=np.shape(x)) / keep
    return x * mask


def multi_head_attention(Q, K, V, num_heads, taille_drop=0.0):
    """Toy multi-head scaled dot-product attention with random projections.

    Each head projects ``Q``, ``K`` and ``V`` to ``d_k = d_model // num_heads``
    features with a freshly drawn ``np.random.rand`` matrix (uniform in
    [0, 1)), then runs scaled dot-product attention. The head outputs are
    concatenated and mapped back to ``d_model`` features by one more random
    matrix.

    Nothing is learned or cached: every call draws new weights from NumPy's
    global generator. The order is Q heads, K heads, V heads, then any
    per-head dropout masks, then the output projection.

    Args:
        Q, K, V: arrays whose last axis is the feature axis (``d_model`` is read
            from ``Q``; ``K`` and ``V`` must have the same feature size). The
            second-to-last axis is the sequence axis. Extra leading (batch)
            axes are allowed.
        num_heads: number of heads, between 1 and ``d_model``. When ``d_model``
            is not divisible by ``num_heads``, each head simply gets
            ``d_model // num_heads`` features.
        taille_drop: dropout probability applied to every head's attention
            weights (0 disables dropout).

    Returns:
        ``(output, weights)``. ``output`` has shape ``(..., len_q, d_model)``.
        ``weights`` is a list with one ``(..., len_q, len_k)`` attention
        matrix per head.

    Raises:
        ValueError: if ``num_heads`` is not in ``[1, d_model]``. In that case
            ``d_k`` would be zero and the ``1 / sqrt(d_k)`` scaling would yield
            NaN.
    """
    d_model = Q.shape[-1]
    if not 1 <= num_heads <= d_model:
        raise ValueError(
            f"num_heads must be between 1 and d_model={d_model}, got {num_heads}"
        )
    d_k = d_model // num_heads

    # Draw the projections in a fixed order (all Q heads, then K, then V) so
    # that runs after np.random.seed(...) are reproducible.
    Q_heads = [Q @ np.random.rand(d_model, d_k) for _ in range(num_heads)]
    K_heads = [K @ np.random.rand(d_model, d_k) for _ in range(num_heads)]
    V_heads = [V @ np.random.rand(d_model, d_k) for _ in range(num_heads)]

    liste_attention_outputs = []
    liste_attention_weights = []
    for q, k, v in zip(Q_heads, K_heads, V_heads):
        # swapaxes transposes only the last two axes (unlike .T), so batched
        # (..., seq, d_k) inputs work as well as plain 2-D ones.
        scores = (q @ np.swapaxes(k, -1, -2)) / np.sqrt(d_k)
        attention_weights = dropout(softmax(scores), taille_drop)
        liste_attention_outputs.append(attention_weights @ v)
        liste_attention_weights.append(attention_weights)

    concatenated_output = np.concatenate(liste_attention_outputs, axis=-1)
    W_o = np.random.rand(concatenated_output.shape[-1], d_model)
    output = concatenated_output @ W_o
    return output, liste_attention_weights


def single_head_attention(Q, K, V, taille_drop=0.0):
    """Scaled dot-product attention with no projection matrices.

    Computes ``softmax(Q K^T / sqrt(d_k)) V``, where ``d_k`` is the size of the
    last (feature) axis of ``K``. Dropout is optionally applied to the
    attention weights before they multiply ``V``.

    Args:
        Q, K, V: arrays whose last axis is the feature axis and whose
            second-to-last axis is the sequence axis. Extra leading (batch)
            axes are allowed.
        taille_drop: dropout probability for the attention weights (0 means no
            dropout and no random draws).

    Returns:
        ``(output, attention_weights)``. ``attention_weights`` has shape
        ``(..., len_q, len_k)``; ``output`` is ``attention_weights @ V``.
    """
    d_k = K.shape[-1]
    # Transpose only the last two axes of K so batched inputs are supported.
    scores = (Q @ np.swapaxes(K, -1, -2)) / np.sqrt(d_k)
    attention_weights = dropout(softmax(scores), taille_drop)
    output = attention_weights @ V
    return output, attention_weights



sentence = "I am very happy to work at Paris 8 university".split()

# The vocabulary is the set of *distinct* words, in first-occurrence order.
# Using len(sentence) as the vocabulary size is only correct when no word
# repeats. With duplicates, some one-hot columns would stay unused.
vocab = list(dict.fromkeys(sentence))
vocab_size = len(vocab)

word_to_index = {word: idx for idx, word in enumerate(vocab)}
# Row i of the identity matrix is the one-hot vector for vocab[i].
embeddings = np.eye(vocab_size)
# One one-hot row per token: shape (len(sentence), vocab_size).
inputs = np.array([embeddings[word_to_index[word]] for word in sentence])

# Seed NumPy's global generator so the random projections drawn below are reproducible.
np.random.seed(0)


print("Changement de dimension")
dimension_embed = [4, 8, 16]
# One figure row per embedding size, three panels per row.
n_rows = len(dimension_embed)
plt.figure(figsize=(18, 4 * n_rows))

for indice_dimension, dimension in enumerate(dimension_embed):
    print("\n dimension: ",dimension)

    # Random (untrained) projection of the one-hot inputs to `dimension`
    # features, followed by random Q/K/V matrices. The draw order (prj, Wq,
    # Wk, Wv) is kept fixed so seeded runs reproduce.
    prj = np.random.rand(vocab_size, dimension)
    inputs_p = np.dot(inputs, prj)

    Wq = np.random.rand(dimension, dimension)
    Wk = np.random.rand(dimension, dimension)
    Wv = np.random.rand(dimension, dimension)

    Q = np.dot(inputs_p, Wq)
    K = np.dot(inputs_p, Wk)
    V = np.dot(inputs_p, Wv)

    # Use the shared attention helper instead of re-implementing it here.
    # With the default taille_drop=0.0, dropout returns its input unchanged.
    output, attention_weights = single_head_attention(Q, K, V)

    panels = (
        (Q, range(dimension), f'Query ({dimension})'),
        (attention_weights, sentence, f'Attention Weights ({dimension})'),
        (output, range(dimension), f'Output ({dimension})'),
    )
    for col, (data, xlabels, title) in enumerate(panels, start=1):
        plt.subplot(n_rows, 3, indice_dimension * 3 + col)
        sns.heatmap(data, annot=True, fmt='.2f', cmap='viridis', xticklabels=xlabels, yticklabels=sentence)
        plt.title(title)

plt.tight_layout()
plt.show()


print("\nMulti-Head Attention ===")


# Fixed embedding size for the head-count comparison. The projection and the
# Q/K/V matrices are drawn once, so every head count attends over the same
# Q, K and V.
dimension = 8
prj = np.random.rand(vocab_size, dimension)
inputs_p = np.dot(inputs, prj)

Wq = np.random.rand(dimension, dimension)
Wk = np.random.rand(dimension, dimension)
Wv = np.random.rand(dimension, dimension)

Q = np.dot(inputs_p, Wq)
K = np.dot(inputs_p, Wk)
V = np.dot(inputs_p, Wv)

liste_num_head = [1, 2, 4]
# One row per head count and exactly three panels per row (Q, weights,
# output). A 4-column grid would leave an empty column on the right.
n_rows, n_cols = len(liste_num_head), 3
plt.figure(figsize=(20, 10))

for indice_tete, num_heads in enumerate(liste_num_head):
    print("\n nombre tete: ",num_heads)

    if num_heads == 1:
        # NOTE: the 1-head case runs plain attention on Q/K/V. By contrast,
        # multi_head_attention also applies random per-head and output
        # projections, so the rows are not strictly like-for-like.
        output, attention_weights = single_head_attention(Q, K, V)
        shown_weights = attention_weights
    else:
        output, liste_attention_weights = multi_head_attention(Q, K, V, num_heads)
        # Only the first head's weight matrix is plotted.
        shown_weights = liste_attention_weights[0]

    base = indice_tete * n_cols
    plt.subplot(n_rows, n_cols, base + 1)
    sns.heatmap(Q, annot=True, fmt='.2f', cmap='viridis', xticklabels=range(dimension), yticklabels=sentence)
    plt.title(f'Query (Q) ({num_heads})')

    plt.subplot(n_rows, n_cols, base + 2)
    sns.heatmap(shown_weights, annot=True, fmt='.2f', cmap='viridis', xticklabels=sentence, yticklabels=sentence)
    plt.title(f'Attention Weights ({num_heads})')

    plt.subplot(n_rows, n_cols, base + 3)
    sns.heatmap(output, annot=True, fmt='.2f', cmap='viridis', xticklabels=range(dimension), yticklabels=sentence)
    plt.title(f'Output ({num_heads})')

plt.tight_layout()
plt.show()
















print("\nDropout")

# Compare single-head attention without and with dropout on the attention
# weights. There is one row per rate: weights on the left, output on the
# right. Reuses Q, K, V and `dimension` defined earlier in the script. With
# dropout, kept weights are rescaled by 1 / (1 - rate), so rows no longer sum
# to 1.
ensemble_drops = [0.0, 0.3]
fig, axes = plt.subplots(2, 2, figsize=(15, 8))

for indice_drop, taille_drop in enumerate(ensemble_drops):
    print("\nDropout : ",taille_drop)

    output, attention_weights = single_head_attention(Q, K, V, taille_drop)
    ax_weights, ax_output = axes[indice_drop]

    sns.heatmap(
        attention_weights,
        annot=True,
        fmt='.2f',
        cmap='viridis',
        xticklabels=sentence,
        yticklabels=sentence,
        ax=ax_weights,
    )
    ax_weights.set_title(f'Attention Weights (Dropout={taille_drop})')

    sns.heatmap(
        output,
        annot=True,
        fmt='.2f',
        cmap='viridis',
        xticklabels=range(dimension),
        yticklabels=sentence,
        ax=ax_output,
    )
    ax_output.set_title(f'Output (Dropout={taille_drop})')

plt.tight_layout()
plt.show()

