In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Embedding, LSTM, Dense, Input, Layer, Dropout, GlobalAveragePooling1D
)
from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing.sequence import pad_sequences

# Configuration parameters
max_features = 20000  # Number of distinct words kept in the vocabulary
max_len = 200         # Length every review is padded/truncated to
embedding_dim = 128   # Size of each word-embedding vector

# Load the IMDB dataset restricted to the `max_features` most frequent words,
# then pad/truncate both splits to exactly `max_len` tokens per review
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
x_train, x_test = (
    pad_sequences(sequences, maxlen=max_len) for sequences in (x_train, x_test)
)

# Custom additive attention layer
class Attention(Layer):
    """Attention pooling over the time axis of a sequence.

    For an input of shape (batch, timesteps, features) -- here the output of
    ``LSTM(..., return_sequences=True)`` -- the layer computes

        score   = tanh(inputs . W + b)
        weights = softmax(score . u) over the time axis
        context = sum over time of inputs * weights

    Calling the layer returns ``(context_vector, attention_weights)`` with
    shapes (batch, features) and (batch, timesteps, 1).
    """

    def __init__(self, **kwargs):
        # Forward standard Layer arguments (name, dtype, trainable, ...) to the
        # base class: without this the layer cannot be given an explicit name
        # and cannot be rebuilt from its config when a saved model is reloaded.
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create the projection matrix, bias and scoring vector."""
        feature_dim = input_shape[-1]
        self.W = self.add_weight(name="W",
                                 shape=(feature_dim, feature_dim),
                                 initializer="random_normal",
                                 trainable=True)
        self.b = self.add_weight(name="b",
                                 shape=(feature_dim,),
                                 initializer="zeros",
                                 trainable=True)
        self.u = self.add_weight(name="u",
                                 shape=(feature_dim, 1),
                                 initializer="random_normal",
                                 trainable=True)
        super().build(input_shape)

    def call(self, inputs):
        """Return the attention-weighted sum of `inputs` and the weights used."""
        # Per-timestep hidden representation used to score each position
        score = tf.nn.tanh(tf.tensordot(inputs, self.W, axes=1) + self.b)
        # Normalize over axis 1 (time) so the weights of one sequence sum to 1
        attention_weights = tf.nn.softmax(tf.tensordot(score, self.u, axes=1), axis=1)
        # The (batch, timesteps, 1) weights broadcast over the feature axis
        context_vector = tf.reduce_sum(inputs * attention_weights, axis=1)
        return context_vector, attention_weights

# Build the LSTM + attention classifier
def build_model():
    """Assemble the binary sentiment classifier.

    Pipeline: token ids -> Embedding -> LSTM (one output per timestep)
    -> Attention pooling -> Dropout -> single sigmoid unit.

    Returns:
        An uncompiled Keras ``Model`` taking sequences of ``max_len`` token
        ids and producing one probability per sequence.
    """
    token_ids = Input(shape=(max_len,))
    embedded = Embedding(max_features, embedding_dim)(token_ids)
    # return_sequences=True keeps every timestep so attention can weight them
    hidden_states = LSTM(128, return_sequences=True)(embedded)
    # The attention weights are not part of the classifier output
    context_vector, _ = Attention()(hidden_states)
    regularized = Dropout(0.5)(context_vector)
    probability = Dense(1, activation="sigmoid")(regularized)

    return Model(inputs=token_ids, outputs=probability)

# Instantiate and compile the model
model = build_model()
model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["accuracy"],
)
model.summary()

# Train the model; note that the test split also serves as validation data
history = model.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    epochs=5,
    batch_size=64,
)

# Evaluate on the test split and report the accuracy as a percentage
test_loss, test_accuracy = model.evaluate(x_test, y_test)
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

# Visualize the attention weights for one review
def visualize_attention(input_sequence, word_index):
    """Print each token of a review next to the attention weight it received.

    Args:
        input_sequence: Sequence of integer token ids encoded the way
            ``imdb.load_data`` encodes them (default ``index_from=3``). It may
            be already padded or of arbitrary length.
        word_index: Mapping word -> rank as returned by
            ``imdb.get_word_index()``.

    Returns:
        None. One ``word: weight`` line is printed per position of the padded
        sequence.
    """
    # imdb.load_data shifts every word rank by index_from (3 by default) and
    # reserves 0 for padding, 1 for start-of-sequence and 2 for
    # out-of-vocabulary words, so ids must be shifted back before the lookup.
    index_offset = 3
    special_tokens = {0: "<PAD>", 1: "<START>", 2: "<OOV>"}
    reverse_word_index = {v: k for k, v in word_index.items()}

    padded_sequence = pad_sequences([input_sequence], maxlen=max_len)

    # Decode the *padded* ids: the model emits exactly max_len weights, so
    # decoding the raw input would misalign words and weights whenever the
    # input is shorter or longer than max_len.
    words = [
        special_tokens.get(i, reverse_word_index.get(i - index_offset, '?'))
        for i in padded_sequence[0].tolist()
    ]

    # Look the layer up by type: the auto-generated name ('attention',
    # 'attention_1', ...) changes every time the model is rebuilt.
    attention_layer = next(
        layer for layer in model.layers if isinstance(layer, Attention)
    )
    model_attention = Model(inputs=model.input, outputs=attention_layer.output)
    context_vector, attention_weights = model_attention.predict(padded_sequence)
    # Weights have shape (1, max_len, 1); keep only the time axis
    attention_weights = attention_weights[0, :, 0]

    # Print every word with its weight
    for word, weight in zip(words, attention_weights):
        print(f"{word}: {weight:.4f}")

# Example: show the attention weights for one training review
sample_index = 0  # Position of the review inside the training split
visualize_attention(
    input_sequence=x_train[sample_index],
    word_index=imdb.get_word_index(),
)


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz
[1m17464789/17464789[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


Epoch 1/5
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m273s[0m 693ms/step - accuracy: 0.7144 - loss: 0.5248 - val_accuracy: 0.8676 - val_loss: 0.3079
Epoch 2/5
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m335s[0m 728ms/step - accuracy: 0.9260 - loss: 0.2035 - val_accuracy: 0.8689 - val_loss: 0.3172
Epoch 3/5
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m310s[0m 697ms/step - accuracy: 0.9647 - loss: 0.1057 - val_accuracy: 0.8587 - val_loss: 0.3637
Epoch 4/5
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m322s[0m 698ms/step - accuracy: 0.9800 - loss: 0.0654 - val_accuracy: 0.8509 - val_loss: 0.4907
Epoch 5/5
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m322s[0m 697ms/step - accuracy: 0.9891 - loss: 0.0359 - val_accuracy: 0.8416 - val_loss: 0.4851
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m96s[0m 123ms/step - accuracy: 0.8435 - loss: 0.4837
Test Accuracy: 84.16%
Downloading data from https://storage.goo