In [2]:
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.models import Model, Sequential
from tensorflow import keras
from tensorflow.keras import backend as K

In [7]:
class ScaledDotProductAttention(keras.layers.Layer):
    r"""Scaled dot-product attention over query, key and value tensors.

    \text{Attention}(Q, K, V) = \text{softmax}(\frac{Q K^T}{\sqrt{d_k}}) V

    See: https://arxiv.org/pdf/1706.03762.pdf

    The layer accepts either a single tensor, which is used as query, key and
    value (self-attention), or a list ``[query, key, value]``. Scores are
    ``batch_dot(query, key, axes=2)``, so query and key must share their last
    dimension. The scores are scaled by ``sqrt(d_k)``, normalised with a
    (mask-aware) softmax, and then applied to ``value`` with ``batch_dot``.
    NOTE(review): inputs are presumably (batch, seq, dim) tensors; confirm
    against callers.
    """

    def __init__(self,
                 return_attention=False,
                 **kwargs):
        """Initialize the layer.

        :param return_attention: Whether to also return the attention weights.
            If True, ``call`` returns ``[output, attention_weights]``.
        :param kwargs: Arguments for the parent ``Layer`` class (name, dtype, ...).
        """
        super(ScaledDotProductAttention, self).__init__(**kwargs)
        self.supports_masking = True
        self.return_attention = return_attention

    def call(self, inputs, mask=None, **kwargs):
        """Apply attention.

        :param inputs: A single tensor (self-attention) or ``[query, key, value]``.
        :param mask: Optional mask. With list inputs Keras passes one mask per
            input; only the key mask (index 1) is used to hide key positions.
        :return: The attended values, or ``[values, attention_weights]`` when
            ``return_attention`` is True.
        """
        if isinstance(inputs, list):
            query, key, value = inputs
        else:
            query = key = value = inputs
        if isinstance(mask, list):
            # Masking applies to the positions being attended to, i.e. the keys.
            mask = mask[1]

        feature_dim = K.shape(query)[-1]
        e = K.batch_dot(query, key, axes=2) / K.sqrt(K.cast(feature_dim, dtype=K.floatx()))
        # Subtract the row max before exponentiating for numerical stability
        # (the manual softmax below is then equivalent but cannot overflow).
        e = K.exp(e - K.max(e, axis=-1, keepdims=True))

        if mask is not None:
            # Broadcast the key mask over query positions and zero out masked keys.
            e *= K.cast(K.expand_dims(mask, axis=-2), K.floatx())

        # epsilon keeps fully-masked rows from dividing by zero (they become all zeros).
        a = e / (K.sum(e, axis=-1, keepdims=True) + K.epsilon())
        v = K.batch_dot(a, value)

        if self.return_attention:
            return [v, a]
        return v

    def compute_mask(self, inputs, mask=None):
        """Propagate the query mask to the output(s).

        The attended output has one row per query position, so the query mask
        (the first entry when a list of masks is given) applies to it. The
        attention-weight output, when requested, carries no mask.
        """
        if isinstance(mask, list):
            mask = mask[0]
        if self.return_attention:
            return [mask, None]
        return mask

    def get_config(self):
        """Return the layer config, including ``return_attention``.

        Required so models containing this layer can be saved and re-created.
        Without the override the base implementation refuses to serialize a
        layer with extra ``__init__`` arguments.
        """
        config = super(ScaledDotProductAttention, self).get_config()
        config.update({'return_attention': self.return_attention})
        return config

In [8]:
# Model hyper-parameters, named instead of inlined as magic numbers.
SEQUENCE_LENGTH = 100        # positions per input sequence
VOCAB_SIZE = 1000            # Embedding input_dim (number of distinct token ids)
EMBEDDING_DIM = 128          # Embedding output_dim
SPATIAL_DROPOUT_RATE = 0.3
LSTM_UNITS = 256             # per direction; the bidirectional output is 2 * LSTM_UNITS wide
NUM_CLASSES = 3              # width of the softmax output

input_shape = (SEQUENCE_LENGTH,)
model_input = Input(shape=input_shape)
# NOTE: the Embedding is built without mask_zero=True, so no mask reaches
# ScaledDotProductAttention; padding positions (if any) are attended to.
embedding_layer = Embedding(VOCAB_SIZE, EMBEDDING_DIM)(model_input)
# SpatialDropout1D drops whole embedding channels rather than single elements.
embedding_dropout_layer = SpatialDropout1D(SPATIAL_DROPOUT_RATE)(embedding_layer)
bilstm_layer = Bidirectional(LSTM(LSTM_UNITS, return_sequences=True))(embedding_dropout_layer)
# Self-attention: the same sequence is used as query, key and value.
scaled_attention_layer = ScaledDotProductAttention()(bilstm_layer)
max_pool_layer = GlobalMaxPooling1D()(scaled_attention_layer)
output = Dense(NUM_CLASSES, activation='softmax')(max_pool_layer)
full_model = Model(inputs=model_input, outputs=output)

In [9]:
# Layer-by-layer table of output shapes and parameter counts for the model built above.
full_model.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
embedding_2 (Embedding)      (None, 100, 128)          128000    
_________________________________________________________________
spatial_dropout1d_2 (Spatial (None, 100, 128)          0         
_________________________________________________________________
bidirectional_2 (Bidirection (None, 100, 512)          788480    
_________________________________________________________________
scaled_dot_product_attention (None, 100, 512)          0         
_________________________________________________________________
global_max_pooling1d_2 (Glob (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 3)                 1539