# Multi-head-attention 模型

Attention Is All You Need

https://arxiv.org/abs/1706.03762

In [None]:
import sys
# Put the parent directory on the import path so the project package below can
# be imported from this notebook's folder (presumably the repo root that holds
# `deep_recommenders` — TODO confirm the directory layout).
sys.path.append("..")
# Don't write .pyc files when importing the project's source modules.
sys.dont_write_bytecode = True

import tensorflow as tf

from deep_recommenders.layers.nlp.multi_head_attention import MultiHeadAttention

## 数据准备

IMDB 数据集的加载（数据集自带训练集/测试集划分）与预处理：将评论填充到固定长度、构建填充位置的掩码，并将标签转换为 one-hot 编码。

In [None]:
# Keep only the `vocab_size` most frequent words of the IMDB vocabulary.
vocab_size = 5000
# Fixed sequence length. Per the Keras docs, load_data(maxlen=...) filters out
# longer reviews, and pad_sequences pads the remaining ones up to this length.
max_len = 256

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(
    num_words=vocab_size, maxlen=max_len)

# Pad each split to exactly max_len token ids (pad_sequences pads with 0 by
# default; the Keras IMDB encoding reserves id 0 for padding).
x_train, x_test = [
    tf.keras.preprocessing.sequence.pad_sequences(seqs, maxlen=max_len)
    for seqs in (x_train, x_test)
]

# Boolean padding masks: True wherever the token id is 0, i.e. a padded slot.
x_train_masks, x_test_masks = [tf.equal(seqs, 0) for seqs in (x_train, x_test)]

# One-hot encode the 0/1 sentiment labels for categorical_crossentropy.
y_train, y_test = [
    tf.keras.utils.to_categorical(labels) for labels in (y_train, y_test)
]

## 模型构建

inputs, masks => Embedding => MultiHeadAttention => GlobalAveragePooling1D => Dropout(0.2) => Dense(10, ReLU) => Dense(2, Softmax)


In [None]:
# Hyper-parameters.
model_dim = 32    # embedding width fed into the attention layer
batch_size = 256
epochs = 10

# Two parallel model inputs: token ids and the padding mask from the data cell.
inputs = tf.keras.Input(shape=(max_len,), name="inputs")
masks = tf.keras.Input(shape=(max_len,), name="masks")

embeddings = tf.keras.layers.Embedding(vocab_size, model_dim)(inputs)

# Self-attention: the embeddings are passed as query, key and value, followed
# by the mask, as a single 4-element list.
# NOTE(review): MultiHeadAttention(2, 16) presumably means 2 heads x 16 dims
# per head (= model_dim); confirm against deep_recommenders' signature.
x = MultiHeadAttention(2, 16)([embeddings] * 3 + [masks])

# Classification head: pool over the sequence axis, regularize, then a small
# ReLU hidden layer before the 2-way softmax.
# NOTE(review): no mask is passed to GlobalAveragePooling1D explicitly, so
# unless MultiHeadAttention propagates one, padded positions are averaged in.
for head_layer in (
    tf.keras.layers.GlobalAveragePooling1D(),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation="relu"),
):
    x = head_layer(x)
del head_layer
outputs = tf.keras.layers.Dense(2, activation="softmax")(x)

model = tf.keras.Model(inputs=[inputs, masks], outputs=outputs)
# Adam settings follow the Transformer paper (beta_1=0.9, beta_2=0.98, eps=1e-9).
model.compile(
    optimizer=tf.keras.optimizers.Adam(beta_1=0.9, beta_2=0.98, epsilon=1e-9),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

## 模型训练

使用早停策略防止过拟合：监控验证集损失（val_loss），连续 3 个 epoch 没有改善即停止训练。

In [None]:
# Early stopping on validation loss (EarlyStopping's default monitor): stop
# once val_loss has not improved for 3 consecutive epochs.
# restore_best_weights=True rolls the model back to the best epoch. Without it,
# the model kept in memory (and evaluated afterwards) carries the weights of
# the last epoch run, up to `patience` epochs past the best validation loss.
# NOTE: in some TF/Keras versions the restore only happens when stopping
# actually triggers, not when all `epochs` complete.
es = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=3, restore_best_weights=True)
model.fit(
    [x_train, x_train_masks], y_train,
    batch_size=batch_size, epochs=epochs, validation_split=0.2, callbacks=[es],
)

## 模型评估

In [None]:
# Score the model on the held-out test split. With a single metric compiled,
# evaluate() returns [loss, accuracy] in that order.
test_metrics = model.evaluate(
    [x_test, x_test_masks], y_test, batch_size=batch_size, verbose=0)
test_loss, test_acc = test_metrics[0], test_metrics[1]
print(f"loss on Test: {test_loss:.4f}")
print(f"accu on Test: {test_acc:.4f}")

# Output recorded from one earlier run (values vary between runs):
# loss on Test: 0.3308
# accu on Test: 0.8777