<a href="https://colab.research.google.com/github/greyhound101/Multihead_attention/blob/master/local_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
class Attention(Layer):
    """Attention layer over a sequence of encoder (source) hidden states.

    The alignment-type and score-function names follow Luong et al. (2015),
    "Effective Approaches to Attention-based Neural Machine Translation".

    The layer is called with a list ``[encoder_out, decoder_out,
    current_timestep]`` and produces the source hidden states re-weighted by
    the attention weights, optionally together with the weights themselves.

    Args:
        context: ``'many-to-many'`` or ``'many-to-one'``.
        alignment_type: ``'global'``, ``'local-m'``, ``'local-p'`` or
            ``'local-p*'``.
        window_width: Window parameter ``D`` for the local alignment types;
            must be ``None`` for ``'global'``. When left as ``None`` with a
            local alignment type, 8 is used at call time.
        score_function: ``'dot'``, ``'scaled_dot'``, ``'general'``,
            ``'location'`` or ``'concat'``.
        model_api: ``'functional'`` makes ``call`` return
            ``(context_vector, attention_weights)``; ``'sequential'`` makes it
            return only ``context_vector``.
        **kwargs: Forwarded to the base ``Layer`` constructor.

    Raises:
        ValueError: If an argument is not one of the recognised values, or if
            the combination of arguments is not supported.
    """

    def __init__(self, context='many-to-many', alignment_type='global', window_width=None,
                 score_function='general', model_api='functional', **kwargs):
        if context not in ['many-to-many', 'many-to-one']:
            raise ValueError("Argument for param @context is not recognized")
        if alignment_type not in ['global', 'local-m', 'local-p', 'local-p*']:
            raise ValueError("Argument for param @alignment_type is not recognized")
        if alignment_type == 'global' and window_width is not None:
            raise ValueError("Can't use windowed approach with global attention")
        if context == 'many-to-many' and alignment_type == 'local-p*':
            raise ValueError("Can't use local-p* approach in many-to-many scenarios")
        if score_function not in ['dot', 'general', 'location', 'concat', 'scaled_dot']:
            raise ValueError("Argument for param @score_function is not recognized")
        if model_api not in ['sequential', 'functional']:
            raise ValueError("Argument for param @model_api is not recognized")
        super(Attention, self).__init__(**kwargs)
        self.context = context
        self.alignment_type = alignment_type
        self.window_width = window_width  # D
        self.score_function = score_function
        self.model_api = model_api

    def get_config(self):
        """Return the base layer config extended with every constructor argument."""
        base_config = super(Attention, self).get_config()
        # 'context' must be serialized too; otherwise a 'many-to-one' layer would be
        # re-created with the 'many-to-many' default.
        base_config['context'] = self.context
        base_config['alignment_type'] = self.alignment_type
        base_config['window_width'] = self.window_width
        base_config['score_function'] = self.score_function
        base_config['model_api'] = self.model_api
        return base_config

    def build(self, input_shape):
        """Record input dimensions and create the weights required by the configuration.

        Args:
            input_shape: Indexable collection of input shapes; ``input_shape[0]`` is the
                shape of the encoder output and ``input_shape[1]`` that of the decoder
                output.
        """
        # Declare attributes for easy access to dimension values
        self.input_sequence_length, self.hidden_dim = input_shape[0][1], input_shape[0][2]
        if self.context == 'many-to-many':
            # Taken from axis 1 of the second input's shape.
            self.target_sequence_length = input_shape[1][1]

        # Build weight matrices for different alignment types and score functions.
        # The creation order (W_p, v_p, W_a, U_a, v_a) is kept stable so that the
        # ordering of the layer's weights does not change.
        if 'local-p' in self.alignment_type:
            self.W_p = Dense(units=self.hidden_dim, use_bias=False)
            self.W_p.build(input_shape=(None, None, self.hidden_dim))                               # (B, 1, H)
            self._trainable_weights += self.W_p.trainable_weights

            self.v_p = Dense(units=1, use_bias=False)
            self.v_p.build(input_shape=(None, None, self.hidden_dim))                               # (B, 1, H)
            self._trainable_weights += self.v_p.trainable_weights

        if 'dot' not in self.score_function:  # weight matrix not utilized for 'dot' function
            self.W_a = Dense(units=self.hidden_dim, use_bias=False)
            self.W_a.build(input_shape=(None, None, self.hidden_dim))                               # (B, S*, H)
            self._trainable_weights += self.W_a.trainable_weights

        if self.score_function == 'concat':  # define additional weight matrices
            self.U_a = Dense(units=self.hidden_dim, use_bias=False)
            self.U_a.build(input_shape=(None, None, self.hidden_dim))                               # (B, 1, H)
            self._trainable_weights += self.U_a.trainable_weights

            self.v_a = Dense(units=1, use_bias=False)
            self.v_a.build(input_shape=(None, None, self.hidden_dim))                               # (B, S*, H)
            self._trainable_weights += self.v_a.trainable_weights

        super(Attention, self).build(input_shape)

    def call(self, inputs):
        """Compute attention weights and the re-weighted source hidden states.

        Args:
            inputs: A list ``[encoder_out, decoder_out, current_timestep]`` with
                ``encoder_out`` of shape (B, S, H) and ``decoder_out`` of shape (B, H).
                ``current_timestep`` is only used by the ``'local-m'`` alignment type,
                where it is converted with ``int()``.

        Returns:
            ``(context_vector, attention_weights)`` when ``model_api`` is
            ``'functional'``, or just ``context_vector`` when it is ``'sequential'``.
            ``context_vector`` has shape (B, S*, H) and ``attention_weights`` has shape
            (B, S*, 1), where S* is the (possibly windowed) source length.

        Raises:
            ValueError: If ``inputs`` is not a list with at least three elements.
        """
        # Pass decoder output (prev. timestep) alongside encoder output for all scenarios
        if not isinstance(inputs, list) or len(inputs) < 3:
            raise ValueError("Pass a list=[encoder_out (Tensor), decoder_out (Tensor)," +
                             "current_timestep (int)] for all scenarios")

        source_hidden_states = inputs[0]                                                        # (B, S, H)
        target_hidden_state = inputs[1]                                                         # (B, H)
        current_timestep = inputs[2]

        target_hidden_state = tf.expand_dims(input=target_hidden_state, axis=1)                     # (B, 1, H)

        # Resolve the default window locally; writing it back to self.window_width would
        # leak into get_config() and break re-creation of a 'global' layer.
        window_width = 8 if self.window_width is None else self.window_width

        # (1) Alignment: optionally restrict or re-weight the source positions.
        gaussian_factor = None
        if self.alignment_type == 'local-m':
            # Monotonic alignment: attend only to the window [t - D, t + D], clipped to
            # the sequence.
            # NOTE(review): assumes current_timestep is a 0-based Python int and that the
            # window is inclusive on both ends -- confirm against the caller.
            center = int(current_timestep)
            left = max(center - int(window_width), 0)
            right = min(center + int(window_width) + 1, self.input_sequence_length)
            source_hidden_states = source_hidden_states[:, left:right, :]                       # (B, S*, H)
        elif 'local-p' in self.alignment_type:
            # Predictive alignment: p_t = S * sigmoid(v_p^T tanh(W_p h_t)).
            aligned_position = self.W_p(target_hidden_state)                                    # (B, 1, H)
            aligned_position = Activation('tanh')(aligned_position)                             # (B, 1, H)
            aligned_position = self.v_p(aligned_position)                                       # (B, 1, 1)
            aligned_position = Activation('sigmoid')(aligned_position)                          # (B, 1, 1)
            aligned_position = aligned_position * self.input_sequence_length                    # (B, 1, 1)

            # Gaussian centred on p_t with standard deviation D / 2, evaluated at every
            # source position in one broadcast instead of one Concatenate per position.
            positions = tf.range(self.input_sequence_length, dtype=aligned_position.dtype)
            positions = tf.reshape(positions, (1, -1, 1))                                       # (1, S, 1)
            std_dev = window_width / 2
            gaussian_factor = tf.exp(-tf.square(aligned_position - positions) /
                                     (2 * std_dev * std_dev))                                   # (B, S, 1)

        # (2) Score every source hidden state against the target hidden state.
        if self.score_function == 'general':
            weighted_hidden_states = self.W_a(source_hidden_states)                             # (B, S*, H)
            attention_score = Dot(axes=[2, 2])([weighted_hidden_states, target_hidden_state])   # (B, S*, 1)
        elif self.score_function == 'concat':
            weighted_hidden_states = self.W_a(source_hidden_states)                             # (B, S*, H)
            weighted_target_state = self.U_a(target_hidden_state)                               # (B, 1, H)
            weighted_sum = Activation('tanh')(weighted_hidden_states + weighted_target_state)   # (B, S*, H)
            attention_score = self.v_a(weighted_sum)                                            # (B, S*, 1)
        else:
            # 'dot' and 'scaled_dot'.
            # NOTE(review): 'location' also lands here and is scored as a plain dot
            # product. W_a is built as H -> H, so a location-based score (which needs a
            # projection onto source positions) cannot be formed from the built weights.
            attention_score = Dot(axes=[2, 2])([source_hidden_states, target_hidden_state])     # (B, S*, 1)
            if self.score_function == 'scaled_dot':
                attention_score = attention_score * (
                    1 / np.sqrt(float(source_hidden_states.shape[2])))                          # (B, S*, 1)

        # (3) Normalise over the source positions (axis 1). A softmax over the last axis
        # would act on a singleton dimension and make every weight exactly 1.
        attention_weights = tf.nn.softmax(attention_score, axis=1)                              # (B, S*, 1)

        if gaussian_factor is not None:
            attention_weights = attention_weights * gaussian_factor                             # (B, S*, 1)

        # (4) Re-weight the source hidden states (broadcast over the hidden axis).
        context_vector = source_hidden_states * attention_weights                               # (B, S*, H)

        if self.model_api == 'functional':
            return context_vector, attention_weights
        elif self.model_api == 'sequential':
            return context_vector