In [1]:
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense, Attention, Bidirectional, Dropout
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Загрузка данных
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=10000)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz


In [3]:
# Предобработка данных
max_len = 200
x_train = pad_sequences(x_train, maxlen=max_len, truncating='post')
x_test = pad_sequences(x_test, maxlen=max_len, truncating='post')

In [4]:
# Модель RNN с attention
embedding_dim = 100
hidden_dim = 256
output_dim = 1
dropout_rate = 0.5

In [5]:
inputs = Input(shape=(max_len,))
embedding = Embedding(input_dim=10000, output_dim=embedding_dim, input_length=max_len)(inputs)
lstm = Bidirectional(LSTM(hidden_dim, return_sequences=True))(embedding)
attention = Attention()([lstm, lstm])
context = tf.reduce_sum(attention * lstm, axis=1)
dropout = Dropout(dropout_rate)(context)
output = Dense(output_dim, activation='sigmoid')(dropout)
model = Model(inputs=inputs, outputs=output)

In [6]:
# Компиляция модели
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

In [7]:
# Добавленные строки для сохранения весов внимания
attention_weights = []

In [8]:
# Обучение модели
for epoch in range(3):
    history = model.fit(x_train, y_train, epochs=1, batch_size=64, validation_data=(x_test, y_test))

    # Получение весов внимания
    attention_layer = model.layers[2]  # Получаем attention слой
    weights = attention_layer.get_weights()[0]
    mean_attention_weight = np.mean(weights)
    attention_weights.append(mean_attention_weight)



In [9]:
# Оценка точности модели
loss, accuracy = model.evaluate(x_test, y_test)
print(f'Test Loss: {loss:.3f}, Test Accuracy: {accuracy * 100:.2f}%')

Test Loss: 0.418, Test Accuracy: 84.59%


In [10]:
attention_weights

[0.0004714129, 0.00037414717, 6.578885e-05]

In [11]:
accuracy

0.8458799719810486