# In [2]:
from keras import activations, constraints, initializers, regularizers
from keras import backend as K
from keras.layers import Layer
from keras.layers import Input, GRU
from keras.layers import InputSpec, LSTM
from keras.layers.wrappers import Bidirectional
from keras.models import Model

class AttentionGRU(Layer):
    """GRU-style decoder with additive attention over an encoded sequence.

    The layer consumes an encoded sequence ``h`` of shape
    ``(batch, timesteps, input_dim)`` and runs one decoding step per input
    timestep. Each step:

    1. updates the hidden state ``s_t`` with GRU gates driven by the previous
       output ``y_{t-1}`` and the previous state ``s_{t-1}``;
    2. scores every encoder position,
       ``e_ti = Mu^T tanh(W_s s_t + W_h h_i + b_a)``, and normalises the
       scores over the time axis, ``a_t = softmax(e_t)``;
    3. builds the context vector ``C_t = sum_i a_ti h_i``;
    4. emits ``y_t = softmax(tanh([s_t; C_t] V) V_ + b_y)``.

    # Arguments
        units: size of the hidden state ``s_t`` and of the attention space.
        output_dim: size of each output vector ``y_t``.
        name: layer name, forwarded to the base ``Layer``.
        activation: activation of the GRU candidate state (default ``tanh``).
        return_probabilities: if True, return the attention weights ``a_t``
            of every step, shape ``(batch, timesteps, timesteps)``, instead of
            the outputs ``y_t``, shape ``(batch, timesteps, output_dim)``.
        kernel_initializer / kernel_regularizer / kernel_constraint: settings
            for the attention and output projections
            (``Mu_T``, ``W_h``, ``W_s``, ``V``, ``V_``).
        recurrent_initializer / recurrent_regularizer / recurrent_constraint:
            settings for the gate weights
            (``W_z``, ``U_z``, ``W_r``, ``U_r``, ``W_c``, ``U_c``).
        bias_initializer / bias_regularizer / bias_constraint: settings for
            every bias vector.
        activity_regularizer: regularizer on the layer output, stored as
            ``self.activity_regularizer``.
    """

    def __init__(self,
                 units,
                 output_dim,
                 name='AttentionGRU',
                 activation='tanh',
                 return_probabilities=False,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 activity_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        # The base constructor assigns ``self.name`` itself, so pass the name
        # through instead of setting the attribute beforehand.
        super(AttentionGRU, self).__init__(name=name, **kwargs)

        self.units = units
        self.output_dim = output_dim
        self.activation = activations.get(activation)
        self.return_probabilities = return_probabilities

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        # The layer always emits one vector per timestep.
        self.return_sequences = True

    def _add_param(self, name, shape, kind):
        """Create weight ``name`` from the ``<kind>_initializer/_regularizer/_constraint`` settings."""
        return self.add_weight(shape=shape,
                               name=name,
                               initializer=getattr(self, kind + '_initializer'),
                               regularizer=getattr(self, kind + '_regularizer'),
                               constraint=getattr(self, kind + '_constraint'))

    def build(self, input_shape):
        """Create all weights for an input of shape ``(batch, timesteps, input_dim)``."""
        self.batch_size, self.timesteps, self.input_dim = input_shape
        units, out_dim = self.units, self.output_dim

        # Attention scores: e_ti = Mu^T tanh(W_s s_t + W_h h_i + b_a)
        self.Mu_T = self._add_param('Mu_T', (units,), 'kernel')
        self.W_h = self._add_param('W_h', (self.input_dim, units), 'kernel')
        self.W_s = self._add_param('W_s', (units, units), 'kernel')
        self.b_a = self._add_param('b_a', (units,), 'bias')

        # Update (z) gate.
        self.W_z = self._add_param('W_z', (out_dim, units), 'recurrent')
        self.U_z = self._add_param('U_z', (units, units), 'recurrent')
        self.b_z = self._add_param('b_z', (units,), 'bias')

        # Reset (r) gate.
        self.W_r = self._add_param('W_r', (out_dim, units), 'recurrent')
        self.U_r = self._add_param('U_r', (units, units), 'recurrent')
        self.b_r = self._add_param('b_r', (units,), 'bias')

        # Candidate (c) state.
        self.W_c = self._add_param('W_c', (out_dim, units), 'recurrent')
        self.U_c = self._add_param('U_c', (units, units), 'recurrent')
        self.b_c = self._add_param('b_c', (units,), 'bias')

        # Output: y_t = softmax(tanh([s_t; C_t] V) V_ + b_y).
        # [s_t; C_t] has units + input_dim features, projected to ``units``
        # by V and then to ``output_dim`` by V_.
        self.V = self._add_param('V', (units + self.input_dim, units), 'kernel')
        self.V_ = self._add_param('V_', (units, out_dim), 'kernel')
        self.b_y = self._add_param('b_y', (out_dim,), 'bias')

        self.input_spec = [InputSpec(shape=(self.batch_size, self.timesteps, self.input_dim))]

        super(AttentionGRU, self).build(input_shape)

    def call(self, x):
        """Run the decoder over ``x`` and return the stacked per-step outputs."""
        # Keep the encoded sequence and its step-independent projection
        # W_h h_i + b_a on the layer so ``step`` can read them.
        self.x_seq = x
        self.Wh_hi = K.dot(x, self.W_h) + self.b_a

        initial_states = self.get_initial_state(x)
        # K.rnn iterates ``step`` over axis 1 of ``x``. ``step`` ignores the
        # per-step slice (it reads the whole sequence through attention); the
        # slices only fix the number of decoding steps.
        _, outputs, _ = K.rnn(self.step, x, initial_states)
        return outputs

    def get_initial_state(self, inputs):
        """Return all-zero ``[y0, s0]`` of shapes (batch, output_dim) and (batch, units)."""
        # A (batch, 1) zero column derived from ``inputs``, so the batch size
        # comes from the runtime tensor rather than a static shape.
        zeros = K.expand_dims(K.sum(K.zeros_like(inputs), axis=(1, 2)))
        y0 = K.tile(zeros, [1, self.output_dim])
        s0 = K.tile(zeros, [1, self.units])
        return [y0, s0]

    def step(self, x, states):
        """Compute one decoding step.

        # Arguments
            x: the current input slice (unused; see ``call``).
            states: ``[y_{t-1}, s_{t-1}]``.

        # Returns
            ``(output, [y_t, s_t])``. ``output`` is ``y_t``, or the attention
            weights ``a_t`` when ``return_probabilities`` is set.
        """
        yt_prev, st_prev = states

        # update gate eq.
        z_t = activations.sigmoid(
            K.dot(yt_prev, self.W_z) + K.dot(st_prev, self.U_z) + self.b_z)

        # reset gate eq.
        r_t = activations.sigmoid(
            K.dot(yt_prev, self.W_r) + K.dot(st_prev, self.U_r) + self.b_r)

        # memory content (candidate) eq.
        c_t = self.activation(
            K.dot(yt_prev, self.W_c) + K.dot(r_t * st_prev, self.U_c) + self.b_c)

        # new hidden state
        st = (1 - z_t) * st_prev + z_t * c_t

        # eq(3): (batch, 1, units) broadcasts against Wh_hi (batch, timesteps, units);
        # the dot with Mu gives (batch, timesteps, 1), squeezed to (batch, timesteps).
        Ws_st = K.expand_dims(K.dot(st, self.W_s), 1)
        e_t = K.squeeze(
            K.dot(activations.tanh(Ws_st + self.Wh_hi), K.expand_dims(self.Mu_T)),
            -1)

        # eq(4): normalise over the encoder timesteps (last axis of the 2-D e_t).
        a_t = K.softmax(e_t)

        # eq(5): attention-weighted sum of the encoded sequence -> (batch, input_dim)
        C = K.sum(K.expand_dims(a_t) * self.x_seq, axis=1)

        # output: [s_t; C_t] is the concatenation along the feature axis.
        hidden = activations.tanh(K.dot(K.concatenate([st, C], axis=-1), self.V))
        yt = activations.softmax(K.dot(hidden, self.V_) + self.b_y)

        if self.return_probabilities:
            return a_t, [yt, st]
        return yt, [yt, st]

    def compute_output_shape(self, input_shape):
        """Return (batch, timesteps, timesteps) for probabilities, else (batch, timesteps, output_dim)."""
        batch_size, timesteps = input_shape[0], input_shape[1]
        if self.return_probabilities:
            return (batch_size, timesteps, timesteps)
        return (batch_size, timesteps, self.output_dim)

    def get_config(self):
        """Return the constructor arguments so the layer can be rebuilt from a saved model."""
        config = {
            'units': self.units,
            'output_dim': self.output_dim,
            'activation': activations.serialize(self.activation),
            'return_probabilities': self.return_probabilities,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'recurrent_initializer': initializers.serialize(self.recurrent_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'recurrent_constraint': constraints.serialize(self.recurrent_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint),
        }

        base_config = super(AttentionGRU, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

# check if it compiles

if __name__ == '__main__':
    # Smoke test: a bidirectional-LSTM encoder feeding the attention decoder.
    # Each sample is a sequence of 100 timesteps with 200 features.
    i = Input(shape=(100, 200), dtype='float32')

    enc = Bidirectional(LSTM(64, return_sequences=True), merge_mode='concat')(i)
    dec = AttentionGRU(32, 4)(enc)

    model = Model(inputs=i, outputs=dec)
    model.summary()

# Output: Using TensorFlow backend.
