<a href="https://colab.research.google.com/github/DevyanshBhattacharya/Self_Attention_for_transformer/blob/main/Transformer_from_tensorflow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import numpy as np


In [None]:
# Turn on memory growth for every GPU device TensorFlow reports.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpu_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

def positional_encoding(length: int, depth: int):
    """Build the sinusoidal positional-encoding table.

    For each position ``p`` and each index ``i`` in ``range(depth // 2)`` the
    angle is ``p / 10000**(i / (depth // 2))``. The sines of all angles fill
    the first half of the last axis and the cosines fill the second half.

    Args:
        length: Number of positions (rows) to generate.
        depth: Width of the encoding. Must be even because it is split into
            equally sized sine and cosine halves.

    Returns:
        A ``tf.float32`` tensor of shape ``(length, depth)``.

    Raises:
        ValueError: If ``depth`` is odd. With float halving an odd depth
            silently produced ``depth + 1`` columns.
    """
    if depth % 2 != 0:
        raise ValueError(f"depth must be even, got {depth}")

    half_depth = depth // 2

    positions = np.arange(length)[:, np.newaxis]                 # (length, 1)
    depths = np.arange(half_depth)[np.newaxis, :] / half_depth   # (1, depth/2)

    angle_rates = 1 / (10000**depths)         # (1, depth/2)
    angle_rads = positions * angle_rates      # (length, depth/2)

    # Sines first, cosines second: concatenated, not interleaved.
    pos_encoding = np.concatenate([np.sin(angle_rads), np.cos(angle_rads)], axis=-1)

    return tf.cast(pos_encoding, dtype=tf.float32)

In [None]:
class PositionalEmbedding(tf.keras.layers.Layer):
    """Token embedding scaled by ``sqrt(d_model)`` plus a positional encoding.

    Args:
        vocab_size: Number of entries in the embedding table (only used when
            no ``embedding`` layer is supplied).
        d_model: Width of the embedding and of the positional encoding.
        embedding: Optional pre-built embedding layer to use instead of
            creating a new ``Embedding(vocab_size, d_model, mask_zero=True)``.
        max_length: Number of positions precomputed in the positional-encoding
            table. The table is sliced to the input length in ``call``, so
            sequences longer than ``max_length`` are not supported. Defaults
            to 2048, the value that used to be hard-coded.
    """

    def __init__(self, vocab_size: int, d_model: int, embedding: tf.keras.layers.Embedding=None, max_length: int=2048):

        super().__init__()
        self.d_model = d_model
        self.max_length = max_length
        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True) if embedding is None else embedding
        # Precompute the table once; `call` only slices it.
        self.pos_encoding = positional_encoding(length=max_length, depth=d_model)

    def compute_mask(self, *args, **kwargs):
        """Delegate mask computation to the wrapped embedding layer."""

        return self.embedding.compute_mask(*args, **kwargs)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Embed ``x`` and add the positional encoding for its sequence length."""

        x = self.embedding(x)
        length = tf.shape(x)[1]
        # This factor sets the relative scale of the embedding and positional_encoding.
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        # Add a leading batch axis and keep only the first `length` positions.
        x = x + self.pos_encoding[tf.newaxis, :length, :]
        return x

In [None]:
class EncoderLayer(tf.keras.layers.Layer):
    """One encoder block: global self-attention followed by a feed-forward block.

    Args:
        d_model: Model width; also passed as ``key_dim`` to the attention layer.
        num_heads: Number of attention heads.
        dff: Width of the hidden layer of the feed-forward block.
        dropout_rate: Dropout rate used by both the attention layer and the
            feed-forward block.
    """

    def __init__(self, d_model: int, num_heads: int, dff: int, dropout_rate: float=0.1):

        super().__init__()

        self.self_attention = GlobalSelfAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate
            )

        # Forward dropout_rate so the feed-forward block honours the configured
        # rate instead of always falling back to its own default.
        self.ffn = FeedForward(d_model, dff, dropout_rate)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Run self-attention and then the feed-forward block on ``x``."""

        x = self.self_attention(x)
        x = self.ffn(x)
        return x

In [None]:
class BaseAttention(tf.keras.layers.Layer):
    """Common sub-layers for the attention layers defined in this file.

    Creates a multi-head attention layer, a layer normalization and an
    element-wise add. Subclasses provide ``call``.
    """

    def __init__(self, **kwargs: dict):

        super().__init__()
        keras_layers = tf.keras.layers
        # Every keyword argument is handed on to MultiHeadAttention unchanged.
        self.mha = keras_layers.MultiHeadAttention(**kwargs)
        self.layernorm = keras_layers.LayerNormalization()
        self.add = keras_layers.Add()

In [None]:
class CrossAttention(BaseAttention):
    """Attention in which ``x`` provides the queries and ``context`` the keys and values."""

    def call(self, x: tf.Tensor, context: tf.Tensor) -> tf.Tensor:
        """Attend from ``x`` to ``context``, then apply residual add and layer norm."""

        attended, scores = self.mha(
            query=x,
            key=context,
            value=context,
            return_attention_scores=True,
        )

        # Remember the most recent attention scores so they can be plotted later.
        self.last_attn_scores = scores

        # Residual connection followed by layer normalization.
        return self.layernorm(self.add([x, attended]))

In [None]:
# Smoke test for CrossAttention: decoder embeddings attend to encoder embeddings.
encoder_vocab_size = 1000
decoder_vocab_size = 1100
d_model = 512

# Size each embedding table from its own vocabulary so that every random token
# id drawn below is a valid index.
encoder_embedding_layer = PositionalEmbedding(encoder_vocab_size, d_model)
decoder_embedding_layer = PositionalEmbedding(decoder_vocab_size, d_model)

random_encoder_input = np.random.randint(0, encoder_vocab_size, size=(1, 100))
# This input was never created in this cell before, which made the decoder
# embedding call below depend on a variable defined by a later cell.
random_decoder_input = np.random.randint(0, decoder_vocab_size, size=(1, 110))

encoder_embeddings = encoder_embedding_layer(random_encoder_input)
decoder_embeddings = decoder_embedding_layer(random_decoder_input)

print("encoder_embeddings shape", encoder_embeddings.shape)
print("decoder_embeddings shape", decoder_embeddings.shape)

cross_attention_layer = CrossAttention(num_heads=2, key_dim=512)
cross_attention_output = cross_attention_layer(decoder_embeddings, encoder_embeddings)

print("cross_attention_output shape", cross_attention_output.shape)

encoder_embeddings shape (1, 100, 512)
decoder_embeddings shape (1, 110, 512)
cross_attention_output shape (1, 110, 512)


In [None]:
encoder_embeddings.shape (1, 100, 512)
decoder_embeddings.shape (1, 110, 512)
cross_attention_output.shape (1, 110, 512)

TypeError: 'TensorShape' object is not callable

In [None]:
class GlobalSelfAttention(BaseAttention):
    """Self-attention where ``x`` is used as query, key and value."""

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Self-attend over ``x``, then apply residual add and layer norm."""

        attended = self.mha(query=x, value=x, key=x)
        # Residual connection followed by layer normalization.
        return self.layernorm(self.add([x, attended]))

In [None]:
# Smoke test for GlobalSelfAttention on random encoder embeddings.
encoder_vocab_size = 1000
d_model = 512

encoder_embedding_layer = PositionalEmbedding(encoder_vocab_size, d_model)

random_encoder_input = np.random.randint(0, encoder_vocab_size, size=(1, 100))

encoder_embeddings = encoder_embedding_layer(random_encoder_input)

print("encoder_embeddings shape", encoder_embeddings.shape)

# Named after the layer under test; these objects used to be stored under
# cross-attention variable names even though the layer is GlobalSelfAttention.
global_self_attention_layer = GlobalSelfAttention(num_heads=2, key_dim=512)
global_self_attention_output = global_self_attention_layer(encoder_embeddings)

print("global_self_attention_output shape", global_self_attention_output.shape)

encoder_embeddings shape (1, 100, 512)
global_self_attention_output shape (1, 100, 512)


In [None]:
encoder_embeddings shape (1, 100, 512)
global_self_attention_output shape (1, 100, 512)

In [None]:
class CausalSelfAttention(BaseAttention):
    """Self-attention over ``x`` with the causal mask enabled."""

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Causally self-attend over ``x``, then apply residual add and layer norm."""

        attended = self.mha(query=x, value=x, key=x, use_causal_mask=True)
        # Residual connection followed by layer normalization.
        return self.layernorm(self.add([x, attended]))

In [None]:
# Smoke test for CausalSelfAttention on random decoder embeddings.
decoder_vocab_size = 1100
d_model = 512

# Size the embedding table from decoder_vocab_size: the random token ids below
# go up to decoder_vocab_size - 1, which a 1000-entry table cannot index.
decoder_embedding_layer = PositionalEmbedding(decoder_vocab_size, d_model)

random_decoder_input = np.random.randint(0, decoder_vocab_size, size=(1, 110))

decoder_embeddings = decoder_embedding_layer(random_decoder_input)

print("decoder_embeddings shape", decoder_embeddings.shape)

causal_self_attention_layer = CausalSelfAttention(num_heads=2, key_dim=512)
causal_self_attention_output = causal_self_attention_layer(decoder_embeddings)

print("causal_self_attention_output shape", causal_self_attention_output.shape)

# In both cases the slice happens before the attention layer is applied.
out1 = causal_self_attention_layer(decoder_embedding_layer(random_decoder_input[:, :50])) # Only the first 50 tokens before applying the embedding layer
out2 = causal_self_attention_layer(decoder_embedding_layer(random_decoder_input)[:, :50]) # Only the first 50 tokens after applying the embedding layer

diff = tf.reduce_max(tf.abs(out1 - out2)).numpy()

print("Difference between the two outputs:", diff)

In [None]:
decoder_embeddings shape (1, 110, 512)
causal_self_attention_output shape (1, 110, 512)
Difference between the two outputs: 0.0

In [None]:
class FeedForward(tf.keras.layers.Layer):
    """Feed-forward block with a residual connection and layer normalization.

    The inner stack is ``Dense(dff, relu) -> Dense(d_model) -> Dropout``.

    Args:
        d_model: Output width of the second dense layer.
        dff: Width of the hidden (ReLU) dense layer.
        dropout_rate: Rate of the dropout applied after the second dense layer.
    """

    def __init__(self, d_model: int, dff: int, dropout_rate: float=0.1):

        super().__init__()
        keras_layers = tf.keras.layers
        self.seq = tf.keras.Sequential([
            keras_layers.Dense(dff, activation='relu'),
            keras_layers.Dense(d_model),
            keras_layers.Dropout(dropout_rate),
        ])
        self.add = keras_layers.Add()
        self.layer_norm = keras_layers.LayerNormalization()

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Add the inner stack's output to ``x`` and layer-normalize the sum."""

        summed = self.add([x, self.seq(x)])
        return self.layer_norm(summed)

In [None]:
# Smoke test for FeedForward: run it on random encoder embeddings and print shapes.
encoder_vocab_size, d_model = 1000, 512

encoder_embedding_layer = PositionalEmbedding(encoder_vocab_size, d_model)
random_encoder_input = np.random.randint(0, encoder_vocab_size, size=(1, 100))
encoder_embeddings = encoder_embedding_layer(random_encoder_input)
print("encoder_embeddings shape", encoder_embeddings.shape)

feed_forward_layer = FeedForward(d_model, dff=2048)
feed_forward_output = feed_forward_layer(encoder_embeddings)
print("feed_forward_output shape", feed_forward_output.shape)

encoder_embeddings shape (1, 100, 512)
feed_forward_output shape (1, 100, 512)


In [None]:
# Smoke test for EncoderLayer: run it on random encoder embeddings and print shapes.
encoder_vocab_size, d_model = 1000, 512

encoder_embedding_layer = PositionalEmbedding(encoder_vocab_size, d_model)
random_encoder_input = np.random.randint(0, encoder_vocab_size, size=(1, 100))
encoder_embeddings = encoder_embedding_layer(random_encoder_input)
print("encoder_embeddings shape", encoder_embeddings.shape)

encoder_layer = EncoderLayer(d_model, num_heads=2, dff=2048)
encoder_layer_output = encoder_layer(encoder_embeddings)
print("encoder_layer_output shape", encoder_layer_output.shape)

encoder_embeddings shape (1, 100, 512)
encoder_layer_output shape (1, 100, 512)


In [None]:
encoder_embeddings shape (1, 100, 512)
encoder_layer_output shape (1, 100, 512)

In [None]:
class Encoder(tf.keras.layers.Layer):
    """Encoder stack: positional embedding, dropout, then ``num_layers`` encoder layers.

    Args:
        num_layers: Number of ``EncoderLayer`` instances to stack.
        d_model: Model width used by the embedding and every encoder layer.
        num_heads: Number of attention heads in each encoder layer.
        dff: Feed-forward hidden width in each encoder layer.
        vocab_size: Vocabulary size passed to the positional embedding.
        dropout_rate: Rate for the dropout after the embedding; also passed
            to every encoder layer.
    """

    def __init__(self, num_layers: int, d_model: int, num_heads: int, dff: int, vocab_size: int, dropout_rate: float=0.1):

        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size, d_model=d_model)

        layer_kwargs = dict(d_model=d_model, num_heads=num_heads, dff=dff, dropout_rate=dropout_rate)
        self.enc_layers = [EncoderLayer(**layer_kwargs) for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Embed the token ids in ``x`` and pass them through every encoder layer."""

        # After embedding, the tensor has shape `(batch_size, seq_len, d_model)`.
        hidden = self.dropout(self.pos_embedding(x))

        for layer in self.enc_layers[:self.num_layers]:
            hidden = layer(hidden)

        return hidden  # Shape `(batch_size, seq_len, d_model)`

In [None]:
# Smoke test for the full Encoder stack on random token ids.
encoder_vocab_size, d_model = 1000, 512

encoder = Encoder(
    num_layers=2,
    d_model=d_model,
    num_heads=2,
    dff=2048,
    vocab_size=encoder_vocab_size,
)

random_encoder_input = np.random.randint(0, encoder_vocab_size, size=(1, 100))
encoder_output = encoder(random_encoder_input)

print("random_encoder_input shape", random_encoder_input.shape)
print("encoder_output shape", encoder_output.shape)

random_encoder_input shape (1, 100)
encoder_output shape (1, 100, 512)


In [None]:
random_encoder_input shape (1, 100)
encoder_output shape (1, 100, 512)

In [None]:
class DecoderLayer(tf.keras.layers.Layer):
    """One decoder block: causal self-attention, cross-attention, feed-forward.

    Args:
        d_model: Model width; also passed as ``key_dim`` to both attention layers.
        num_heads: Number of attention heads.
        dff: Width of the hidden layer of the feed-forward block.
        dropout_rate: Dropout rate used by both attention layers and the
            feed-forward block.
    """

    def __init__(self, d_model: int, num_heads: int, dff: int, dropout_rate: float=0.1):

        super().__init__()

        self.causal_self_attention = CausalSelfAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        self.cross_attention = CrossAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        # Forward dropout_rate so the feed-forward block honours the configured
        # rate instead of always falling back to its own default.
        self.ffn = FeedForward(d_model, dff, dropout_rate)

    def call(self, x: tf.Tensor, context: tf.Tensor) -> tf.Tensor:
        """Run causal self-attention on ``x``, attend to ``context``, then feed-forward."""

        x = self.causal_self_attention(x=x)
        x = self.cross_attention(x=x, context=context)

        # Cache the last attention scores for plotting later
        self.last_attn_scores = self.cross_attention.last_attn_scores

        x = self.ffn(x)  # Shape `(batch_size, seq_len, d_model)`.
        return x

In [None]:
# Test DecoderLayer layer
decoder_vocab_size = 1000
d_model = 512
dff = 2048
num_heads = 8

decoder_layer = DecoderLayer(d_model, num_heads, dff)

random_decoderLayer_input = np.random.randint(0, decoder_vocab_size, size=(1, 110))

# Embed the decoder tokens with a decoder-side embedding sized from
# decoder_vocab_size, rather than borrowing the encoder embedding layer that
# an earlier cell happened to leave behind.
decoder_embedding_layer = PositionalEmbedding(decoder_vocab_size, d_model)
decoder_embeddings = decoder_embedding_layer(random_decoderLayer_input)

# `encoder_output` is produced by the Encoder test cell earlier in the notebook.
decoderLayer_output = decoder_layer(decoder_embeddings, encoder_output)

print("random_decoder_input shape", random_decoderLayer_input.shape)
print("decoder_embeddings shape", decoder_embeddings.shape)
print("decoder_output shape", decoderLayer_output.shape)

random_decoder_input shape (1, 110)
decoder_embeddings shape (1, 110, 512)
decoder_output shape (1, 110, 512)


In [None]:
random_decoder_input shape (1, 110)
decoder_embeddings shape (1, 110, 512)
decoder_output shape (1, 110, 512)

In [None]:
class Decoder(tf.keras.layers.Layer):
    """Decoder stack: positional embedding, dropout, then ``num_layers`` decoder layers.

    Args:
        num_layers: Number of ``DecoderLayer`` instances to stack.
        d_model: Model width used by the embedding and every decoder layer.
        num_heads: Number of attention heads in each decoder layer.
        dff: Feed-forward hidden width in each decoder layer.
        vocab_size: Vocabulary size passed to the positional embedding.
        dropout_rate: Rate for the dropout after the embedding; also passed
            to every decoder layer.
    """

    def __init__(self, num_layers: int, d_model: int, num_heads: int, dff: int, vocab_size: int, dropout_rate: float=0.1):

        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size, d_model=d_model)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

        layer_kwargs = dict(d_model=d_model, num_heads=num_heads, dff=dff, dropout_rate=dropout_rate)
        self.dec_layers = [DecoderLayer(**layer_kwargs) for _ in range(num_layers)]

        # Filled in by `call` with the scores cached by the last decoder layer.
        self.last_attn_scores = None

    def call(self, x: tf.Tensor, context: tf.Tensor) -> tf.Tensor:
        """Embed the token ids in ``x`` and decode them against ``context``."""

        # `x` is token-IDs shape (batch, target_seq_len); embedding makes it
        # (batch_size, target_seq_len, d_model).
        hidden = self.dropout(self.pos_embedding(x))

        for layer in self.dec_layers[:self.num_layers]:
            hidden = layer(hidden, context)

        self.last_attn_scores = self.dec_layers[-1].last_attn_scores

        # The result has shape (batch_size, target_seq_len, d_model).
        return hidden

In [None]:
# Smoke test for the full Decoder stack (the variable keeps the name
# `decoder_layer`, but it holds a Decoder, not a single DecoderLayer).
decoder_vocab_size, d_model = 1000, 512

decoder_layer = Decoder(
    num_layers=2,
    d_model=d_model,
    num_heads=2,
    dff=2048,
    vocab_size=decoder_vocab_size,
)

random_decoder_input = np.random.randint(0, decoder_vocab_size, size=(1, 100))
# `encoder_output` is produced by the Encoder test cell earlier in the notebook.
decoder_output = decoder_layer(random_decoder_input, encoder_output)

print("random_decoder_input shape", random_decoder_input.shape)
print("decoder_output shape", decoder_output.shape)

random_decoder_input shape (1, 100)
decoder_output shape (1, 100, 512)


In [None]:
random_decoder_input shape (1, 100)
decoder_output shape (1, 100, 512)

In [None]:
def Transformer(
    input_vocab_size: int,
    target_vocab_size: int,
    encoder_input_size: int = None,
    decoder_input_size: int = None,
    num_layers: int=6,
    d_model: int=512,
    num_heads: int=8,
    dff: int=2048,
    dropout_rate: float=0.1,
    ) -> tf.keras.Model:
    """Assemble the encoder-decoder Transformer as a functional Keras model.

    Args:
        input_vocab_size: Vocabulary size for the encoder embedding.
        target_vocab_size: Vocabulary size for the decoder embedding and the
            width of the final dense layer.
        encoder_input_size: Length of the encoder input; ``None`` leaves the
            input shape dimension unspecified.
        decoder_input_size: Length of the decoder input; ``None`` leaves the
            input shape dimension unspecified.
        num_layers: Number of layers in both the encoder and the decoder.
        d_model: Model width.
        num_heads: Number of attention heads.
        dff: Feed-forward hidden width.
        dropout_rate: Dropout rate passed to the encoder and the decoder.

    Returns:
        A ``tf.keras.Model`` whose inputs are ``[encoder_input, decoder_input]``
        (both ``int64``) and whose output is the final dense layer applied to
        the decoder output.
    """

    encoder_input = tf.keras.layers.Input(shape=(encoder_input_size,), dtype=tf.int64)
    decoder_input = tf.keras.layers.Input(shape=(decoder_input_size,), dtype=tf.int64)

    # Hyper-parameters shared by the encoder and decoder stacks.
    stack_kwargs = dict(
        num_layers=num_layers,
        d_model=d_model,
        num_heads=num_heads,
        dff=dff,
        dropout_rate=dropout_rate,
    )

    encoder_stack = Encoder(vocab_size=input_vocab_size, **stack_kwargs)
    encoded = encoder_stack(encoder_input)

    decoder_stack = Decoder(vocab_size=target_vocab_size, **stack_kwargs)
    decoded = decoder_stack(decoder_input, encoded)

    # Dense layer without activation, one unit per target-vocabulary entry.
    projected = tf.keras.layers.Dense(target_vocab_size)(decoded)

    return tf.keras.Model(inputs=[encoder_input, decoder_input], outputs=projected)

In [None]:
# Build a small two-layer Transformer and print its Keras summary.
encoder_input_size, decoder_input_size = 100, 110
encoder_vocab_size = decoder_vocab_size = 1000

model = Transformer(
    input_vocab_size=encoder_vocab_size,
    target_vocab_size=decoder_vocab_size,
    encoder_input_size=encoder_input_size,
    decoder_input_size=decoder_input_size,
    num_layers=2,
    d_model=512,
    num_heads=2,
    dff=512,
    dropout_rate=0.1,
)

model.summary()


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 100)]                0         []                            
                                                                                                  
 input_2 (InputLayer)        [(None, 110)]                0         []                            
                                                                                                  
 encoder_1 (Encoder)         (None, 100, 512)             5768192   ['input_1[0][0]']             
                                                                                                  
 decoder_1 (Decoder)         (None, 110, 512)             9971712   ['input_2[0][0]',             
                                                                     'encoder_1[0][0]']       