In [3]:
# Import TensorFlow and NumPy libraries
import tensorflow as tf
import numpy as np


In [4]:
def get_positional_encoding(seq_len, d_model):
    """
    Generates positional encoding using sinusoidal patterns to indicate token position.

    Args:
        seq_len (int): Length of the sequence
        d_model (int): Embedding dimension size

    Returns:
        tf.Tensor: A tensor of shape (seq_len, d_model) with positional encodings
    """
    # Compute position and divide terms for the encoding
    position = np.arange(seq_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))

    # Apply sine and cosine functions to even and odd indices respectively
    pos_encoding = np.zeros((seq_len, d_model))
    pos_encoding[:, 0::2] = np.sin(position * div_term)
    pos_encoding[:, 1::2] = np.cos(position * div_term)

    return tf.constant(pos_encoding, dtype=tf.float32)


In [6]:
class TransformerXL:
    """
    A basic Transformer-XL model for long-context sequence modeling.
    """
    def __init__(self, d_model, num_heads, num_layers, memory_len, max_len):
        """
        Initializes model hyperparameters and weights.

        Args:
            d_model (int): Dimension of the model embeddings
            num_heads (int): Number of attention heads
            num_layers (int): Number of Transformer layers
            memory_len (int): Length of memory to retain past context
            max_len (int): Maximum length of input sequences
        """
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.memory_len = memory_len
        self.max_len = max_len

        # Positional encoding for each token position
        self.positional_encoding = get_positional_encoding(max_len, d_model)

        # Initialize weight matrices for query, key, value, and output transformations
        self.wq = tf.Variable(tf.random.normal([d_model, d_model]), trainable=True)
        self.wk = tf.Variable(tf.random.normal([d_model, d_model]), trainable=True)
        self.wv = tf.Variable(tf.random.normal([d_model, d_model]), trainable=True)
        self.wo = tf.Variable(tf.random.normal([d_model, d_model]), trainable=True)

        # Feed-forward network weights for processing outputs
        self.ffn_w1 = tf.Variable(tf.random.normal([d_model, d_model * 4]), trainable=True)
        self.ffn_w2 = tf.Variable(tf.random.normal([d_model * 4, d_model]), trainable=True)

    def split_heads(self, x):
        """
        Splits the embedding tensor into multiple attention heads.

        Args:
            x (tf.Tensor): Input tensor of shape (batch_size, seq_len, d_model)

        Returns:
            tf.Tensor: Reshaped tensor with multiple heads for parallel attention
        """
        batch_size, seq_len, depth = x.shape
        x = tf.reshape(x, (batch_size, seq_len, self.num_heads, depth // self.num_heads))
        return tf.transpose(x, perm=[0, 2, 1, 3])  # [batch_size, num_heads, seq_len, depth_per_head]

    def attention(self, q, k, v):
        """
        Calculates scaled dot-product attention scores and applies them.

        Args:
            q, k, v (tf.Tensor): Query, key, and value tensors for attention

        Returns:
            tf.Tensor: Output after applying attention scores
        """
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_logits = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(dk)
        attention_weights = tf.nn.softmax(scaled_logits, axis=-1)
        return tf.matmul(attention_weights, v)

    def forward(self, x, memory):
        """
        Forward pass of the Transformer-XL model, with memory handling.

        Args:
            x (tf.Tensor): Input tensor of shape (batch_size, seq_len, d_model)
            memory (tf.Tensor): Memory tensor of past context to extend sequence length

        Returns:
            tf.Tensor: Model output after self-attention and memory update
        """
        # Add positional encoding to input
        seq_len = tf.shape(x)[1]
        x = x + self.positional_encoding[:seq_len, :]

        for layer in range(self.num_layers):
            # Combine memory from previous segments with current sequence
            x_combined = tf.concat([memory, x], axis=1)

            # Compute queries, keys, and values
            q = tf.matmul(x, self.wq)
            k = tf.matmul(x_combined, self.wk)
            v = tf.matmul(x_combined, self.wv)

            # Split for multi-head attention
            q, k, v = self.split_heads(q), self.split_heads(k), self.split_heads(v)

            # Apply attention mechanism
            attention_output = self.attention(q, k, v)
            attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
            concat_attention = tf.reshape(attention_output, (x.shape[0], -1, self.d_model))
            attention_output = tf.matmul(concat_attention, self.wo)

            # Add residual connection and layer normalization
            layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
            x = layernorm1(x + attention_output)

            # Feed-forward layer with activation and layer normalization
            ffn_output = tf.nn.relu(tf.matmul(x, self.ffn_w1))
            ffn_output = tf.matmul(ffn_output, self.ffn_w2)
            layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
            x = layernorm2(x + ffn_output)

            # Update memory with current output, trimming to memory_len
            memory = tf.concat([memory, x], axis=1)[:, -self.memory_len:]

        return x, memory  # Return output and updated memory


In [7]:
# Define model parameters
d_model = 512          # Dimension of token embeddings
num_heads = 8          # Number of attention heads
num_layers = 4         # Number of transformer layers
memory_len = 128       # Length of memory for context
max_len = 512          # Maximum sequence length

# Create an instance of the Transformer-XL model
model = TransformerXL(d_model, num_heads, num_layers, memory_len, max_len)

# Example input data and initial memory
input_data = tf.random.normal([1, 100, d_model])  # (batch_size, seq_len, d_model)
memory = tf.zeros([1, 0, d_model])  # Initial empty memory

# Perform a forward pass
output, updated_memory = model.forward(input_data, memory)


In [8]:
# Print the output and updated memory shapes
print("Model Output Shape:", output.shape)          # Expected shape: (batch_size, seq_len, d_model)
print("Updated Memory Shape:", updated_memory.shape)  # Expected shape: (batch_size, memory_len, d_model)


Model Output Shape: (1, 100, 512)
Updated Memory Shape: (1, 128, 512)
