In [2]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import math

## The Embedding Layer

In [4]:
class Embeddings(tf.keras.layers.Layer):
    """Token embedding lookup whose output is multiplied by sqrt(d_model).

    Args:
      vocab_size: number of distinct token ids the lookup table holds.
      d_model:    width of each embedding vector.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        # Trainable lookup table: one d_model-wide row per token id.
        self.lut = tf.keras.layers.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def call(self, x):
        """Embed token ids and scale the result.

        Args:
          x: integer token ids, presumably (batch_size, seq_length).
        Returns:
          The looked-up embeddings multiplied by sqrt(d_model).
        """
        scale = tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        embedded = self.lut(x)
        return embedded * scale


## Positional Encoding

In [5]:
class PositionalEncoding(layers.Layer):
    """Adds a fixed (non-trainable) sinusoidal position table to its input.

    Even feature columns hold sin(pos * 10000^(-2i/d_model)) and odd columns hold
    the matching cos term. The table is precomputed for up to `max_seq_length`
    positions.

    Args:
      d_model:        feature width of the inputs (last axis); may be odd.
      max_seq_length: largest sequence length the layer accepts.
    """

    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.max_seq_length = max_seq_length

        pe_init = np.zeros((max_seq_length, d_model), dtype=np.float32)
        position = np.arange(0, max_seq_length, dtype=np.float32)[:, np.newaxis]  # (max_seq_length, 1)
        # One frequency per even column index 2i: 10000^(-2i / d_model).
        div_term = np.power(10_000, (-np.arange(0, d_model, 2, dtype=np.float32) / d_model))

        pe_init[:, 0::2] = np.sin(position * div_term)
        # There are d_model // 2 odd columns, one fewer than even columns when d_model
        # is odd, so trim the frequencies to match (a no-op for even d_model).
        pe_init[:, 1::2] = np.cos(position * div_term[: d_model // 2])
        pe_init = pe_init[np.newaxis, :]  # leading axis of 1 broadcasts over the batch

        self.pe = tf.cast(pe_init, dtype=tf.float32)

    def call(self, x):
        """Return `x` plus the first `seq_length` rows of the position table.

        Args:
          x: float tensor whose axis 1 is the sequence axis and whose last axis is
             d_model, presumably (batch, seq_length, d_model).
        Raises:
          InvalidArgumentError: if seq_length exceeds max_seq_length. Without this
            check the failure would be an opaque broadcast error.
        """
        seq_length = tf.shape(x)[1]
        tf.debugging.assert_less_equal(
            seq_length, self.max_seq_length,
            message="sequence length exceeds PositionalEncoding.max_seq_length")
        return x + self.pe[:, :seq_length, :]

## Multi-Head Attention

In [21]:
class MultiHeadAttention(layers.Layer):
    """Multi-head scaled dot-product attention.

    Q, K and V are projected with learned Dense layers and split into
    `num_heads` heads of width d_k = d_model // num_heads. Each head is attended
    independently, then the heads are concatenated and projected by W_o.

    Mask convention: `mask` is a boolean tensor broadcastable to
    (batch, num_heads, query_len, key_len). True marks scores that are KEPT;
    False marks scores replaced by -1e9 before the softmax. This matches
    the masks built by `Transformer.generate_mask` (True = real token / causally
    visible position).
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = layers.Dense(d_model)
        self.W_k = layers.Dense(d_model)
        self.W_v = layers.Dense(d_model)
        self.W_o = layers.Dense(d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """softmax(Q K^T / sqrt(d_k)) V, with blocked scores set to -1e9 first."""
        attn_scores = tf.matmul(Q, K, transpose_b=True) / math.sqrt(self.d_k)
        if mask is not None:
            # Keep scores where mask is True. Blocked positions get a large
            # negative value, so they receive ~0 probability after the softmax.
            attn_scores = tf.where(mask, attn_scores, -1e9)

        attn_probs = tf.nn.softmax(attn_scores, axis=-1)
        return tf.matmul(attn_probs, V)

    def split_heads(self, x):
        """(batch, seq, d_model) -> (batch, num_heads, seq, d_k)."""
        # Use dynamic shapes: the static batch dim is None when Keras traces
        # `fit`, and reshaping with None would fail.
        batch_size = tf.shape(x)[0]
        seq_len = tf.shape(x)[1]
        x = tf.reshape(x, (batch_size, seq_len, self.num_heads, self.d_k))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def combine_heads(self, x):
        """(batch, num_heads, seq, d_k) -> (batch, seq, d_model)."""
        batch_size = tf.shape(x)[0]
        seq_len = tf.shape(x)[2]
        x = tf.transpose(x, perm=[0, 2, 1, 3])
        return tf.reshape(x, (batch_size, seq_len, self.d_model))

    def call(self, Q, K, V, mask=None):
        """Attend from queries Q to keys K / values V.

        Returns a tensor of shape (batch, query_len, d_model).
        """
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        return self.W_o(self.combine_heads(attn_output))

## Position-Wise Feed-Forward Network

In [22]:

class PositionWiseFeedForward(layers.Layer):
    """Two stacked Dense layers: ReLU expansion to d_ff, then a linear projection back to d_model.

    Dense acts on the last axis only, so each sequence position is transformed
    with the same weights, independently of the others.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = layers.Dense(d_ff, activation='relu')  # expand + ReLU
        self.fc2 = layers.Dense(d_model)                  # project back

    def call(self, x):
        hidden = self.fc1(x)
        return self.fc2(hidden)

## Encoder Layer

In [23]:
class EncoderLayer(layers.Layer):
    """One encoder block: self-attention, then a position-wise feed-forward network.

    Each sub-layer's output passes through dropout, is added back to its input
    (residual), and the sum is layer-normalized (post-norm ordering).
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = layers.LayerNormalization()
        self.norm2 = layers.LayerNormalization()
        # A single Dropout layer is reused after both sub-layers.
        self.dropout = layers.Dropout(dropout)

    def call(self, x, mask=None):
        attended = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attended))
        transformed = self.feed_forward(x)
        return self.norm2(x + self.dropout(transformed))
     

## Decoder Layer

In [24]:

class DecoderLayer(layers.Layer):
    """One decoder block with three sub-layers, each applied as norm(x + dropout(sublayer)).

    1. masked self-attention over the decoder input (uses `tgt_mask`);
    2. cross-attention: queries from the decoder, keys/values from `enc_output`
       (uses `src_mask`);
    3. a position-wise feed-forward network.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = layers.LayerNormalization()
        self.norm2 = layers.LayerNormalization()
        self.norm3 = layers.LayerNormalization()
        self.dropout = layers.Dropout(dropout)

    def call(self, x, enc_output, src_mask, tgt_mask):
        self_attended = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attended))

        cross_attended = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(cross_attended))

        transformed = self.feed_forward(x)
        return self.norm3(x + self.dropout(transformed))

## Transformer Model

In [25]:
class Transformer(tf.keras.Model):
    """Encoder-decoder Transformer that maps (src, tgt) token ids to target-vocabulary logits.

    Token id 0 is treated as padding; `generate_mask` excludes it.
    NOTE(review): the embeddings here are not multiplied by sqrt(d_model). The
    `Embeddings` layer in this notebook does that; consider switching to it.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout=0.1):
        super(Transformer, self).__init__()
        self.encoder_embedding = layers.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = layers.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        self.encoder_layers = [
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ]
        self.decoder_layers = [
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ]
        self.fc = layers.Dense(tgt_vocab_size)
        self.dropout = layers.Dropout(dropout)

    def generate_mask(self, src, tgt):
        """Build boolean attention masks from token ids.

        Returns:
          src_mask: (batch, 1, 1, src_len). True where the source token is non-zero.
          tgt_mask: (batch, 1, tgt_len, tgt_len). True where the query token is
                    non-zero AND the key position is <= the query position (causal).
        """
        src_mask = tf.not_equal(src, 0)[:, tf.newaxis, tf.newaxis, :]
        tgt_pad_mask = tf.not_equal(tgt, 0)[:, tf.newaxis, :, tf.newaxis]
        # Dynamic length so the mask also works when the static length is unknown.
        seq_length = tf.shape(tgt)[1]
        # Lower triangle incl. diagonal: row i may look at columns 0..i.
        nopeak_mask = tf.cast(
            tf.linalg.band_part(tf.ones((seq_length, seq_length)), -1, 0), dtype=tf.bool)
        tgt_mask = tgt_pad_mask & nopeak_mask
        return src_mask, tgt_mask

    def call(self, inputs):
        """Args:
          inputs: pair (src, tgt) of integer token-id tensors, presumably
                  (batch, src_len) and (batch, tgt_len).
        Returns:
          Logits over the target vocabulary for every target position.
        """
        src, tgt = inputs
        src_mask, tgt_mask = self.generate_mask(src, tgt)

        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        return self.fc(dec_output)

In [26]:
# Hyperparameters, matching the "base" configuration of Vaswani et al. (2017).
src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1
SEED = 42

# Seed TF's global RNG so weight init, dropout and the synthetic data are reproducible.
tf.random.set_seed(SEED)

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

# Generate random sample data: ids in [1, vocab_size). `maxval` is exclusive, and 0
# is reserved for padding, so every id is a valid row of the embedding table and a
# valid class index for the vocab_size-wide logits.
src_data = tf.random.uniform((64, max_seq_length), minval=1, maxval=src_vocab_size, dtype=tf.int32)  # (batch_size, seq_length)
tgt_data = tf.random.uniform((64, max_seq_length), minval=1, maxval=tgt_vocab_size, dtype=tf.int32)  # (batch_size, seq_length)
     

In [27]:
def masked_loss(label, pred):
    """Cross-entropy averaged over non-padding tokens only.

    Args:
      label: integer target ids; 0 marks padding.
      pred:  unnormalized logits whose last axis spans the target vocabulary.
    Returns:
      Scalar: sum of per-token losses at non-zero labels / number of non-zero labels.
    """
    is_token = label != 0
    scce = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')
    per_token = scce(label, pred)

    weights = tf.cast(is_token, dtype=per_token.dtype)
    return tf.reduce_sum(per_token * weights) / tf.reduce_sum(weights)


def masked_accuracy(label, pred):
    """Fraction of non-padding positions where argmax(pred) equals the label.

    Args:
      label: integer target ids; 0 marks padding.
      pred:  logits with the vocabulary on axis 2, presumably (batch, seq, vocab).
    """
    predicted_ids = tf.argmax(pred, axis=2)
    label = tf.cast(label, predicted_ids.dtype)

    is_token = tf.not_equal(label, 0)
    is_hit = tf.logical_and(tf.equal(label, predicted_ids), is_token)

    hits = tf.reduce_sum(tf.cast(is_hit, dtype=tf.float32))
    total = tf.reduce_sum(tf.cast(is_token, dtype=tf.float32))
    return hits / total

In [28]:
# Padding-aware loss/metric; Adam with Keras defaults (no warmup schedule).
transformer.compile(
    loss=masked_loss,
    metrics=[masked_accuracy],
    optimizer='adam',
)
    

In [None]:

# Teacher forcing: the decoder reads tgt[:, :-1] and is trained to predict tgt[:, 1:].
transformer.fit((src_data, tgt_data[:, :-1]), tgt_data[:, 1:], epochs=10)