In [6]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import keras
from keras.models import Model
from keras import backend as K
from keras.engine.topology import Layer
from keras.layers import Input, Dense, Conv2D, MaxPool2D, Flatten, Multiply, Permute, Embedding, LSTM, Bidirectional
from keras.preprocessing import sequence
from keras.datasets import imdb

In [2]:
# Hyper-parameters shared by the cells below.
max_features = 20000  # vocabulary size: only the most common words are kept
maxlen = 80           # number of tokens kept per review after padding/cutting
batch_size = 32

print('Loading data...')
train_split, test_split = imdb.load_data(num_words=max_features)
x_train, y_train = train_split
x_test, y_test = test_split
print(len(x_train), 'train sequences')
print(len(x_test), 'test sequences')

print('Pad sequences (samples x time)')
# Bring every review to exactly `maxlen` tokens so each split becomes a
# 2-D (samples, maxlen) array.
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)

Loading data...
25000 train sequences
25000 test sequences
Pad sequences (samples x time)
x_train shape: (25000, 80)
x_test shape: (25000, 80)


In [149]:
class MultiHeadAttention(Layer):
    """Additive (tanh) attention over the timesteps of a 3-D sequence tensor.

    For an input ``h`` of shape ``(batch, time, features)`` the layer computes

        u_t     = tanh(W h_t + b)        # (alignment_vector_size,)
        score_t = v . u_t                # scalar per timestep
        alpha   = softmax(score) over the time axis

    and returns either the re-weighted sequence ``alpha_t * h_t`` or its sum
    over time (the context vector).

    Note: despite the class name, a single attention distribution (one weight
    per timestep) is produced; there are no separate heads.

    # Arguments
        alignment_vector_size: int, size of the hidden projection ``u_t`` and
            of the learned alignment vector ``v``.
        return_sequence: bool. If True the output keeps the time axis and has
            the same shape as the input; otherwise the output has shape
            ``(batch, features)``.

    # Input shape
        3-D tensor ``(batch, time, features)``.

    # Output shape
        ``(batch, time, features)`` if ``return_sequence`` else
        ``(batch, features)``.
    """

    def __init__(self, alignment_vector_size, return_sequence=False, **kwargs):
        self.alignment_vector_size = alignment_vector_size
        self.return_sequence = return_sequence
        super(MultiHeadAttention, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the projection kernel, bias and alignment vector.

        # Raises
            ValueError: if the input is not a 3-D tensor.
        """
        if len(input_shape) != 3:
            raise ValueError(
                'MultiHeadAttention expects a 3-D input (batch, time, features); '
                'got input shape ' + str(input_shape))

        # Remembered so that `call` can re-assert the static shape.
        self.input_shape_ = input_shape
        feature_dim = input_shape[2]

        self.alignment_vector = self.add_weight(
            name="alignment_vector",
            shape=(self.alignment_vector_size,),
            initializer='uniform',
            trainable=True)
        self.kernel = self.add_weight(
            name="kernel",
            shape=(self.alignment_vector_size, feature_dim),
            initializer='uniform',
            trainable=True)
        self.bias = self.add_weight(
            name="bias",
            shape=(self.alignment_vector_size,),
            initializer='uniform',
            trainable=True)

        super(MultiHeadAttention, self).build(input_shape)

    def call(self, hidden_state_sequence):
        """Apply attention to `hidden_state_sequence` (batch, time, features)."""
        # Pin the static shape recorded in `build` onto the incoming tensor.
        hidden_state_sequence.set_shape(self.input_shape_)

        # (batch, time, features) x (features, size) -> (batch, time, size).
        # The kernel is stored as (size, features), hence the transpose.
        u = K.tanh(K.dot(hidden_state_sequence, K.transpose(self.kernel)) + self.bias)

        # Project every u_t onto the alignment vector -> (batch, time).
        aligned_u = K.squeeze(K.dot(u, K.expand_dims(self.alignment_vector)), axis=-1)
        # Normalise over the time axis so the weights of one sample sum to 1.
        attention_weights = K.softmax(aligned_u, axis=1)

        # Broadcast the per-timestep weight across the feature axis.
        weighted_sequence = K.expand_dims(attention_weights, axis=2) * hidden_state_sequence

        if self.return_sequence:
            return weighted_sequence
        # Collapse the time axis into a single context vector per sample.
        return K.sum(weighted_sequence, axis=1)

    def compute_output_shape(self, input_shape):
        """Return the input shape, or (batch, features) when the time axis is summed."""
        if self.return_sequence:
            return input_shape
        return (input_shape[0], input_shape[-1])

    def get_config(self):
        """Return the constructor arguments so the layer can be serialised."""
        config = super(MultiHeadAttention, self).get_config()
        config.update({
            'alignment_vector_size': self.alignment_vector_size,
            'return_sequence': self.return_sequence,
        })
        return config

In [140]:
# Replace the current backend session with one that allocates GPU memory on
# demand instead of reserving it all up front.
# `tf` is the TensorFlow module imported at the top of the notebook; `K.tf`
# is not part of the Keras backend API, so it is not relied on here.
K.get_session().close()
cfg = tf.ConfigProto()
cfg.gpu_options.allow_growth = True
K.set_session(tf.Session(config=cfg))

In [150]:
# Integer token ids of length `maxlen` -> 128-d embeddings.
inputs = Input(shape=(maxlen,))
x = Embedding(max_features, 128)(inputs)

# BiLSTM keeps the time axis, attention re-weights each timestep, and a second
# LSTM reduces the re-weighted sequence to a single vector.
x = Bidirectional(LSTM(128, dropout=0.2, recurrent_dropout=0.2, return_sequences=True))(x)
x = MultiHeadAttention(10, return_sequence=True)(x)
x = LSTM(128, dropout=0.2, recurrent_dropout=0.2)(x)
predictions = Dense(1, activation='sigmoid')(x)


model = Model(inputs=inputs, outputs=predictions)
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# summary() prints the table itself and returns None, so wrapping it in
# print() would emit a stray "None" line after the table.
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_71 (InputLayer)        (None, 80)                0         
_________________________________________________________________
embedding_71 (Embedding)     (None, 80, 128)           2560000   
_________________________________________________________________
bidirectional_67 (Bidirectio (None, 80, 256)           263168    
_________________________________________________________________
multi_head_attention_66 (Mul (None, 80, 256)           2580      
_________________________________________________________________
lstm_74 (LSTM)               (None, 128)               197120    
_________________________________________________________________
dense_18 (Dense)             (None, 1)                 129       
Total params: 3,022,997
Trainable params: 3,022,997
Non-trainable params: 0
_________________________________________________________________


In [None]:
# Train for 15 epochs; the test split is passed as validation data.
model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=15,
    validation_data=(x_test, y_test)
)

Train on 25000 samples, validate on 25000 samples
Epoch 1/15
Epoch 2/15
Epoch 3/15