In [1]:
import tensorflow as tf
import gin.tf
import numpy as np
import functools
from dataclasses import dataclass
from typing import Optional

In [2]:
# Presumably lets gin tolerate re-registering configurables when notebook cells
# are re-executed (gin-config "interactive mode") — TODO confirm against gin docs.
gin.enter_interactive_mode()

### SelfAttentionLayer

In [None]:
class SelfAttentionLayer(tf.keras.layers.Layer):
    """Multi-head self-attention over a single input tensor.

    Wraps ``tf.keras.layers.MultiHeadAttention`` and feeds the same tensor as
    both ``query`` and ``value``; ``key`` is left unset, so MultiHeadAttention
    falls back to ``value`` for it. Positions therefore attend to positions of
    the same sequence.

    Args:
        heads_count: Number of attention heads (MultiHeadAttention ``num_heads``).
        attention_head_dimension: Per-head projection size (``key_dim``).
        dropout_rate: MultiHeadAttention ``dropout`` rate.
    """

    def __init__(self, heads_count: int, attention_head_dimension: int, dropout_rate: float = 0.0):
        super().__init__()
        attention_config = dict(
            num_heads=heads_count,
            key_dim=attention_head_dimension,
            dropout=dropout_rate,
        )
        self.attention_layer = tf.keras.layers.MultiHeadAttention(**attention_config)

    def call(self, inputs, mask=None, training=None):
        """Attends ``inputs`` to themselves.

        Args:
            inputs: Tensor used as both query and value.
            mask: Optional tensor forwarded unchanged as ``attention_mask``.
                NOTE(review): because the parameter is named ``mask``, Keras may
                fill it implicitly from an upstream layer's mask; confirm that
                such a mask has a shape MultiHeadAttention accepts.
            training: Forwarded so attention dropout only applies in training.

        Returns:
            The MultiHeadAttention output.
        """
        return self.attention_layer(
            query=inputs,
            value=inputs,
            attention_mask=mask,
            training=training,
        )

In [None]:
layer = SelfAttentionLayer(heads_count=4, attention_head_dimension=256, dropout_rate=0.5)
embeddings = tf.random.uniform(shape=(64, 10, 256))
attention_outputs = layer(embeddings)
# Self-attention should keep the (batch, length, features) shape of its input.
assert attention_outputs.shape == embeddings.shape
attention_outputs.shape

TensorShape([64, 10, 256])

### PointwiseFeedforwardLayer

In [None]:
class PointwiseFeedforwardLayer(tf.keras.layers.Layer):
    """Position-wise feed-forward block: a ReLU expansion, then a projection back.

    The first Dense layer has ``hidden_layer_dimension`` units with ReLU. The
    second has as many units as the input's last axis, so the output has the
    same shape as the input. Both Dense layers are created in ``build``
    because the output width depends on the input shape.

    Args:
        hidden_layer_dimension: Width of the hidden (ReLU) layer.
    """

    def __init__(self, hidden_layer_dimension: int):
        super().__init__()
        self.hidden_layer_dimension = hidden_layer_dimension
        # Placeholders; the real Dense layers are created in build().
        self.dense_layer1 = self.dense_layer2 = None

    def build(self, input_shape):
        output_units = input_shape[-1]
        self.dense_layer1 = tf.keras.layers.Dense(self.hidden_layer_dimension, activation="relu")
        self.dense_layer2 = tf.keras.layers.Dense(units=output_units)

    def call(self, inputs, training=None):
        """Applies expand (ReLU) -> project; ``training`` is passed through to both Dense calls."""
        hidden = self.dense_layer1(inputs, training=training)
        projected = self.dense_layer2(hidden, training=training)
        return projected

In [None]:
layer = PointwiseFeedforwardLayer(hidden_layer_dimension=512)
pointwise_inputs = tf.random.uniform(shape=(32, 10, 256))
pointwise_outputs = layer(pointwise_inputs)
# The second Dense layer projects back to the input width, so shapes must match.
assert pointwise_outputs.shape == pointwise_inputs.shape
pointwise_outputs.shape

TensorShape([32, 10, 256])

### TransformerEncoderLayer

In [None]:
class TransformerEncoderLayer(tf.keras.layers.Layer):
    """Post-norm Transformer encoder block.

    Computes the following, with dropout active only when ``training`` is True::

        x = LayerNorm(inputs + Dropout(SelfAttention(inputs)))
        y = LayerNorm(x + Dropout(PointwiseFeedforward(x)))

    The residual additions require both sublayers to return tensors shaped
    like their input.

    Args:
        attention_heads_count: Number of attention heads.
        attention_head_dimension: Per-head projection size of the attention.
        pointwise_hidden_layer_dimension: Hidden width of the feed-forward sublayer.
        dropout_rate: Used for the attention dropout and both residual dropouts.
    """

    def __init__(
        self,
        attention_heads_count: int,
        attention_head_dimension: int,
        pointwise_hidden_layer_dimension: int,
        dropout_rate: float = 0.0,
    ):
        super().__init__()
        self.attention_layer = SelfAttentionLayer(
            heads_count=attention_heads_count,
            attention_head_dimension=attention_head_dimension,
            dropout_rate=dropout_rate,
        )
        self.pointwise_layer = PointwiseFeedforwardLayer(
            hidden_layer_dimension=pointwise_hidden_layer_dimension,
        )
        self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout_layer1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout_layer2 = tf.keras.layers.Dropout(dropout_rate)

    def call(self, inputs, training=None, attention_mask=None):
        """Applies the encoder block to ``inputs``.

        Args:
            inputs: Sequence tensor; the residual additions require the sublayer
                outputs to have the same shape.
            training: Enables the dropout layers when True.
            attention_mask: Optional mask forwarded to the self-attention
                sublayer, which passes it on as MultiHeadAttention's
                ``attention_mask`` (e.g. to keep positions from attending to
                padding). It is named ``attention_mask`` rather than ``mask`` so
                Keras does not fill it implicitly from an upstream layer's mask.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        # Forward ``mask`` only when one was given. In tf.keras an explicit
        # ``mask=None`` counts as a passed argument and would suppress any
        # implicit mask the attention sublayer received before this change.
        attention_kwargs = {} if attention_mask is None else {"mask": attention_mask}
        attention_outputs = self.attention_layer(inputs, training=training, **attention_kwargs)
        attention_outputs = self.dropout_layer1(attention_outputs, training=training)
        attention_outputs = self.layer_norm1(inputs + attention_outputs)
        # training is passed for consistency; the feed-forward sublayer has no dropout of its own.
        pointwise_outputs = self.pointwise_layer(attention_outputs, training=training)
        pointwise_outputs = self.dropout_layer2(pointwise_outputs, training=training)
        return self.layer_norm2(attention_outputs + pointwise_outputs)

In [None]:
layer = TransformerEncoderLayer(
    attention_heads_count=4,
    attention_head_dimension=512,
    pointwise_hidden_layer_dimension=512,
    dropout_rate=0.5,
)
encoder_inputs = tf.random.uniform(shape=(32, 10, 256))
encoder_outputs = layer(encoder_inputs)
# The residual connections require the block to preserve its input shape.
assert encoder_outputs.shape == encoder_inputs.shape
encoder_outputs.shape, layer.count_params()

(TensorShape([32, 10, 256]), 2367488)

### StackedTransformerEncodersLayer

In [None]:
@gin.configurable
class StackedTransformerEncodersLayer(tf.keras.layers.Layer):
    """A stack of ``layers_count`` TransformerEncoderLayers applied in sequence.

    All sublayers share the same hyper-parameters but not weights, since a fresh
    layer is constructed for each position in the stack.

    Every constructor argument defaults to ``gin.REQUIRED``, so it must be given
    either explicitly or through a gin binding. The ``@gin.configurable``
    registration is what lets gin supply those bindings and report a missing
    one clearly. Without it, the ``gin.REQUIRED`` sentinel would leak into the
    constructor, e.g. into ``range(layers_count)``.

    Args:
        layers_count: Number of stacked encoder layers.
        attention_heads_count: Attention heads per encoder layer.
        attention_head_dimension: Per-head projection size.
        pointwise_hidden_layer_dimension: Hidden width of each feed-forward sublayer.
        dropout_rate: Dropout rate used throughout each encoder layer.
    """

    def __init__(
            self,
            layers_count: int = gin.REQUIRED,
            attention_heads_count: int = gin.REQUIRED,
            attention_head_dimension: int = gin.REQUIRED,
            pointwise_hidden_layer_dimension: int = gin.REQUIRED,
            dropout_rate: float = gin.REQUIRED,
    ):
        super().__init__()
        encoder_layer_initializer = functools.partial(
            TransformerEncoderLayer,
            attention_heads_count=attention_heads_count,
            attention_head_dimension=attention_head_dimension,
            pointwise_hidden_layer_dimension=pointwise_hidden_layer_dimension,
            dropout_rate=dropout_rate,
        )
        # Each call builds a new layer, so weights are not shared across the stack.
        self.sublayers = [encoder_layer_initializer() for _ in range(layers_count)]

    def call(self, inputs, training=None):
        """Feeds ``inputs`` through every sublayer in order and returns the last output."""
        outputs = inputs
        for sublayer in self.sublayers:
            outputs = sublayer(outputs, training=training)
        return outputs

In [None]:
layer = StackedTransformerEncodersLayer(
    layers_count=12,
    attention_heads_count=8,
    attention_head_dimension=512,
    pointwise_hidden_layer_dimension=2048,
    dropout_rate=0.5,
)
stacked_inputs = tf.random.uniform(shape=(32, 10, 512))
stacked_outputs = layer(stacked_inputs)
# Each encoder layer preserves shape, so the whole stack must as well.
assert stacked_outputs.shape == stacked_inputs.shape
stacked_outputs.shape, layer.count_params()

(TensorShape([32, 10, 512]), 126038016)

### Embeddings layers

In [None]:
class PositionEmbeddingsLayer(tf.keras.layers.Layer):
    """Position embeddings broadcast to the shape of the inputs.

    When built, the layer creates a ``(max_inputs_length, D)`` table, where ``D``
    is the size of the inputs' last axis. On each call it takes the first ``L``
    rows, where ``L`` is the size of the inputs' second-to-last axis, and
    broadcasts them to the full input shape. The result can then be added
    element-wise to the inputs.

    Args:
        max_inputs_length: Number of rows in the table, i.e. the longest
            sequence length the layer accepts.
        use_fourier_series: If True, initialise the table with sinusoidal
            values (sine on even columns, cosine on odd columns); otherwise
            draw it from ``tf.random.truncated_normal``.
        trainable: Whether the table variable is trainable.
    """

    def __init__(self, max_inputs_length: int, use_fourier_series: bool, trainable: bool):
        super(PositionEmbeddingsLayer, self).__init__()
        self.max_inputs_length = max_inputs_length
        self.use_fourier_series = use_fourier_series
        self.trainable = trainable
        # Created in build() once the embedding dimension is known.
        self.position_embeddings = None

    def _get_fourier_angles(self, embeddings_dimension):
        """Returns the ``(max_inputs_length, embeddings_dimension)`` angle grid.

        Entry ``[pos, i]`` is ``pos / 10000 ** (2 * (i // 2) / embeddings_dimension)``,
        so each adjacent (even, odd) pair of columns shares one frequency.
        """
        input_positions = np.arange(self.max_inputs_length).reshape((-1, 1))
        embedding_positions = np.arange(embeddings_dimension).reshape((1, -1))
        relative_embeddings_positions = (2.0 * (embedding_positions // 2)) / embeddings_dimension
        return input_positions / np.power(10000, relative_embeddings_positions)

    def _get_fourier_positional_embeddings(self, embeddings_dimension):
        """Sinusoidal table: sin of the angles on even columns, cos on odd columns (NumPy float64)."""
        angles = self._get_fourier_angles(embeddings_dimension)
        positional_embeddings = np.zeros(angles.shape)
        positional_embeddings[:, 0::2] = np.sin(angles[:, 0::2])
        positional_embeddings[:, 1::2] = np.cos(angles[:, 1::2])
        return positional_embeddings

    def _get_initial_embeddings(self, embeddings_dimension):
        """Initial table values: sinusoidal if ``use_fourier_series`` else truncated normal."""
        if self.use_fourier_series:
            return self._get_fourier_positional_embeddings(embeddings_dimension)
        return tf.random.truncated_normal(shape=(self.max_inputs_length, embeddings_dimension))

    def build(self, input_shape):
        initial_embeddings = self._get_initial_embeddings(
            embeddings_dimension=input_shape[-1]
        )
        # The NumPy Fourier table is float64 while the random path is float32.
        # Cast both to the layer's dtype so the variable's dtype no longer
        # depends on the initialisation path and matches float inputs.
        self.position_embeddings = tf.Variable(
            tf.cast(initial_embeddings, self.dtype),
            name='position_embeddings',
            trainable=self.trainable,
        )

    def call(self, inputs, training=None):
        """Returns the first ``L`` table rows broadcast to ``tf.shape(inputs)``.

        Raises:
            tf.errors.InvalidArgumentError: if the inputs' second-to-last axis
                is longer than ``max_inputs_length``.
        """
        inputs_shape = tf.shape(inputs)
        inputs_length = inputs_shape[-2]
        tf.debugging.assert_less_equal(
            inputs_length,
            self.max_inputs_length,
            message='Inputs are longer than max_inputs_length.',
        )
        chosen_embeddings = self.position_embeddings[:inputs_length, :]
        # Broadcast to the dynamic shape: the static ``inputs.shape`` may contain
        # None (e.g. an unknown batch size when traced by tf.function).
        return tf.broadcast_to(chosen_embeddings, inputs_shape)

In [None]:
layer = PositionEmbeddingsLayer(
    max_inputs_length=12,
    use_fourier_series=True,
    trainable=True,
)
position_inputs = tf.random.uniform(shape=(64, 10, 256))
position_embeddings = layer(position_inputs)
# The (length, features) table slice is broadcast to the full input shape.
assert position_embeddings.shape == position_inputs.shape
position_embeddings.shape

TensorShape([64, 10, 256])