In [1]:
from keras import layers, models

### Custom Attention Layers

In [None]:
class BaseAttention(layers.Layer):
    """Shared building blocks for the attention layers in this file.

    Creates a multi-head attention layer, a layer normalization and an
    ``Add`` layer. It defines no ``call`` of its own; subclasses decide how
    the three pieces are wired together.

    Args:
        name: Layer name, passed on to ``layers.Layer``.
        **kwargs: Forwarded unchanged to ``layers.MultiHeadAttention``.
    """
    def __init__(self, name, **kwargs):
        # Only `name` reaches the base Layer; every other keyword configures
        # the MultiHeadAttention sub-layer.
        super().__init__(name=name)
        self.mha = layers.MultiHeadAttention(**kwargs)
        self.layernorm = layers.LayerNormalization()
        # Add layer used by subclasses for the residual connection.
        self.add = layers.Add()


class CrossAttention(BaseAttention):
    """Attention block in which ``x`` queries a separate ``context`` sequence.

    The attention output is added back onto ``x`` (residual connection) and
    the sum is layer-normalized. The attention scores of the most recent
    call are kept on ``last_attn_scores``; the attribute is ``None`` until
    the layer has been called at least once.
    """

    def __init__(self, name, **kwargs):
        """Build the layer; arguments are forwarded unchanged to ``BaseAttention``."""
        super().__init__(name, **kwargs)
        # Defined up front so the attribute always exists: reading it before
        # the first forward pass yields None instead of an AttributeError.
        self.last_attn_scores = None

    def call(self, x, context):
        """Attend from ``x`` (query) over ``context`` (key and value).

        Args:
            x: Query tensor; also the residual input.
            context: Tensor used as both key and value.

        Returns:
            The layer-normalized sum of ``x`` and the attention output.
        """
        attn_vector, attn_scores = self.mha(
            query=x,
            key=context,
            value=context,
            return_attention_scores=True,
        )
        # Keep the scores around so they can be inspected after the call.
        self.last_attn_scores = attn_scores
        x = self.add([x, attn_vector])
        return self.layernorm(x)


class SelfAttention(BaseAttention):
    """Self-attention block: ``x`` serves as query, key and value."""

    def call(self, x):
        """Return the layer-normalized sum of ``x`` and its self-attention output."""
        attended = self.mha(query=x, key=x, value=x)
        residual_sum = self.add([x, attended])
        return self.layernorm(residual_sum)


class MaskedSelfAttention(BaseAttention):
    """Self-attention block that asks the attention layer for a causal mask."""

    def call(self, x):
        """Return the layer-normalized sum of ``x`` and its causally masked self-attention."""
        masked_attention = self.mha(
            query=x,
            key=x,
            value=x,
            use_causal_mask=True,
        )
        residual_sum = self.add([x, masked_attention])
        return self.layernorm(residual_sum)

In [None]:
class FeedForward(layers.Layer):
    """Two-layer dense block with dropout, a residual connection and layer norm.

    Args:
        d_model: Number of units of the second (output) dense layer.
        dff: Number of units of the first, ReLU-activated dense layer.
        dropout_rate: Rate of the dropout applied after the second dense layer.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, d_model, dff, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Kept in a local so only the Sequential model becomes an attribute.
        stack = [
            layers.Dense(dff, activation='relu'),
            layers.Dense(d_model),
            layers.Dropout(dropout_rate),
        ]
        self.seq = models.Sequential(stack)
        self.add = layers.Add()
        self.layer_norm = layers.LayerNormalization()

    def call(self, x):
        """Return the layer-normalized sum of ``x`` and the dense stack's output."""
        transformed = self.seq(x)
        return self.layer_norm(self.add([x, transformed]))