In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import Dense, Input, Embedding, LayerNormalization, Dropout
from tensorflow.keras.models import Model


In [None]:
def load_text(path: str) -> str:
    """Read a UTF-8 text file and return its contents on a single line.

    Every newline character is replaced by one space; no other
    normalisation is applied.

    Args:
        path: Location of the text file to read.

    Returns:
        The file contents with newlines turned into spaces.
    """
    with open(path, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    # Splitting on "\n" and re-joining with " " swaps each newline for a space.
    return " ".join(raw_text.split("\n"))

# Read the whole training corpus into memory as one string.
data = load_text(path="training_data.txt")


In [None]:
# Corpus size in characters (after newlines were replaced by spaces).
print(len(data))


In [None]:
def build_vocab(data: str):
    """
    Build the character vocabulary and its integer encodings.

    Characters are sorted before ids are assigned so that the mapping is the
    same on every run. Iterating a ``set`` of strings gives an order that can
    change between interpreter sessions, which would make previously encoded
    data or saved weights inconsistent after a kernel restart.

    Convention:
    - Id 0 is reserved (no character is mapped to it).
    - The distinct characters are mapped to ids 1..V-1.

    Args:
        data: Text whose distinct characters form the vocabulary.

    Returns:
        A tuple ``(characters, vocab_size, char2idx, idx2char)``:
        ``characters`` is the sorted list of distinct characters,
        ``vocab_size`` is ``len(characters) + 1`` (the extra slot is id 0),
        ``char2idx`` maps each character to its id and ``idx2char`` is the
        inverse mapping.
    """
    characters = sorted(set(data))
    vocab_size = len(characters) + 1

    # Ids start at 1 so that 0 stays free as the reserved id.
    char2idx = {ch: idx for idx, ch in enumerate(characters, start=1)}
    idx2char = {idx: ch for ch, idx in char2idx.items()}

    return characters, vocab_size, char2idx, idx2char

# Derive the vocabulary and both lookup tables from the training corpus.
(
    characters,
    input_vocab_size,
    character_to_integer_encoding,
    integer_to_character_encoding,
) = build_vocab(data)
print(len(characters))


In [None]:
def encode(string: str):
    """Convert string -> list of token ids.

    Looks each character up in the module-level
    ``character_to_integer_encoding`` table. A character that is missing from
    the table is encoded as 0, the id that ``build_vocab`` leaves unassigned,
    so text containing unseen characters (for example a hand-written prompt)
    does not raise ``KeyError``.

    Args:
        string: Text to encode.

    Returns:
        List of integer token ids, one per character of ``string``.
    """
    lookup = character_to_integer_encoding
    return [lookup.get(ch, 0) for ch in string]

def decode(lst):
    """Convert list of token ids -> string.

    Looks each id up in the module-level ``integer_to_character_encoding``
    table. Ids that have no character (such as the reserved id 0, which
    ``generate_text`` uses for padding and which the model can also emit)
    decode to the empty string instead of raising ``KeyError``.

    Args:
        lst: Iterable of integer token ids.

    Returns:
        The decoded text.
    """
    lookup = integer_to_character_encoding
    return "".join(lookup.get(i, "") for i in lst)


In [None]:
# Tokenise the corpus, then cut it at the 90% mark:
# the first part is the training split, the remainder the test split.
input_data = encode(data)
train_data, test_data = (
    input_data[:int(0.9 * len(input_data))],
    input_data[int(0.9 * len(input_data)):],
)


In [None]:
# Hyperparameters for the character-level transformer.
batch_size = 32             # sequences per training batch
block_size = 128            # context length in tokens (characters)
num_heads = 8               # attention heads in each transformer block
num_transformer_blocks = 4  # number of stacked transformer blocks
embed_dim = 256             # width of token/position embeddings
feed_forward_dim = 256      # hidden width of each block's feed-forward network
dropout_rate = 0.1          # dropout rate used inside each transformer block

# Seed NumPy and TensorFlow so weight initialisation, dataset shuffling and
# sampling during generation are repeatable across runs.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)


In [None]:
def causal_attention_mask(batch_size, n_dest, n_src):
    """
    Build the boolean mask for causal (autoregressive) attention.

    Entry [i, j] is True when destination position i may attend to source
    position j, i.e. when j <= i; every j > i is False (masked out).

    ``batch_size`` is accepted for call-site compatibility but is not used:
    the returned mask has shape (n_dest, n_src) with no batch dimension.
    """
    # Column vector of destination indices and row vector of source indices.
    dest_positions = tf.range(n_dest)[:, tf.newaxis]
    src_positions = tf.range(n_src)[tf.newaxis, :]
    # Broadcasting the comparison yields the lower-triangular pattern directly.
    return tf.less_equal(src_positions, dest_positions)


class TransformerBlock(layers.Layer):
    """
    One transformer block:
    - Causal self-attention
    - Residual + LayerNorm
    - Feed-forward MLP
    - Residual + LayerNorm

    Args:
        embed_dim: Width of the input/output features.
        num_heads: Number of attention heads; each head uses
            ``embed_dim // num_heads`` as its key dimension.
        ff_dim: Hidden width of the feed-forward network.
        rate: Dropout rate applied to the attention and feed-forward outputs.
        **kwargs: Standard Keras layer arguments (e.g. ``name``), forwarded to
            the base class so the layer can be re-created from its config.
    """
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)

        # Keep the constructor arguments so get_config() can report them.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate

        # Multi-head self-attention layer
        self.att = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim // num_heads
        )

        # Feed-forward network: hidden layer of width ff_dim, then project back
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation='gelu'),
            Dense(embed_dim)
        ])

        self.normalization_layer_1 = LayerNormalization(epsilon=1e-6)
        self.normalization_layer_2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=False):
        """
        inputs: (batch, seq_len, embed_dim)
        Returns a tensor of the same shape.
        """
        # Named n_batch to avoid shadowing the notebook-level batch_size.
        n_batch = tf.shape(inputs)[0]
        seq_len = tf.shape(inputs)[1]

        # Build causal mask for autoregressive attention
        causal_mask = causal_attention_mask(n_batch, seq_len, seq_len)

        # Self-attention with causal mask
        attn_output = self.att(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=causal_mask,
            training=training
        )
        attn_output = self.dropout1(attn_output, training=training)

        # Residual connection + layer normalization
        out1 = self.normalization_layer_1(inputs + attn_output)

        # Feed-forward network (training flag forwarded for consistency with
        # the other sub-layers)
        ffn_output = self.ffn(out1, training=training)
        ffn_output = self.dropout2(ffn_output, training=training)

        # Residual connection + layer normalization
        out2 = self.normalization_layer_2(out1 + ffn_output)

        return out2

    def get_config(self):
        """Return the constructor arguments so the layer can be serialised."""
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "rate": self.rate,
        })
        return config


In [None]:
class TokenAndPositionEmbedding(layers.Layer):
    """Embeds tokens + positions and adds them.

    Args:
        maxlen: Number of distinct positions the position table holds.
        vocab_size: Number of distinct token ids the token table holds.
        embed_dim: Width of both embedding tables.
        **kwargs: Standard Keras layer arguments (e.g. ``name``), forwarded to
            the base class so the layer can be re-created from its config.
    """
    def __init__(self, maxlen, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embedding = Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_embedding = Embedding(input_dim=maxlen, output_dim=embed_dim)
        # Keep the constructor arguments so get_config() can report them.
        self.maxlen = maxlen
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def call(self, x):
        """
        x: (batch, seq_len)
        Returns: (batch, seq_len, embed_dim)
        """
        seq_len = tf.shape(x)[1]

        # Create position indices [0, 1, 2, ..., seq_len-1]
        positions = tf.range(start=0, limit=seq_len, delta=1)

        # Get position embeddings: (seq_len, embed_dim)
        pos_emb = self.pos_embedding(positions)

        # Get token embeddings: (batch, seq_len, embed_dim)
        tok_emb = self.token_embedding(x)

        # Add token and position embeddings (pos_emb broadcasts over batch)
        return tok_emb + pos_emb

    def get_config(self):
        """Return the constructor arguments so the layer can be serialised."""
        config = super().get_config()
        config.update({
            "maxlen": self.maxlen,
            "vocab_size": self.vocab_size,
            "embed_dim": self.embed_dim,
        })
        return config


In [None]:
def get_transformer_model(
    maxlen,
    vocab_size,
    embed_dim,
    num_heads,
    feed_forward_dim,
    num_transformer_blocks=1,
    rate=0.1,
):
    """Assemble the causal next-token model with the Keras functional API.

    Pipeline: int32 token ids of length ``maxlen`` -> token + position
    embedding -> ``num_transformer_blocks`` transformer blocks -> a Dense
    layer that emits ``vocab_size`` values per position (no activation).

    Returns:
        The uncompiled ``Model``.
    """
    token_ids = Input(shape=(maxlen,), dtype=tf.int32)
    embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
    hidden = embedding_layer(token_ids)

    # Stack the blocks one after another, each consuming the previous output.
    for _ in range(num_transformer_blocks):
        block = TransformerBlock(embed_dim, num_heads, feed_forward_dim, rate=rate)
        hidden = block(hidden)

    logits = Dense(vocab_size)(hidden)
    return Model(inputs=token_ids, outputs=[logits])


In [None]:
# Instantiate the network from the hyperparameters defined above.
model = get_transformer_model(
    maxlen=block_size,
    vocab_size=input_vocab_size,
    embed_dim=embed_dim,
    num_heads=num_heads,
    feed_forward_dim=feed_forward_dim,
    num_transformer_blocks=num_transformer_blocks,
    rate=dropout_rate,
)

# The final Dense layer has no activation, so the loss must treat outputs as logits.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer="adam", loss=[loss_fn], metrics=["accuracy"])


In [None]:
def make_next_token_dataset(token_ids, block_size):
    """
    Slice a token sequence into overlapping (input, target) training windows.

    inputs[i]  = token_ids[i : i+block_size]
    targets[i] = token_ids[i+1 : i+block_size+1]

    Every window whose target fits completely inside ``token_ids`` is
    produced, i.e. ``len(token_ids) - block_size`` windows. If the sequence is
    too short for a single full window, both lists are empty.

    Args:
        token_ids: Sequence of token ids (must support ``len`` and slicing).
        block_size: Length of each input and target window.

    Returns:
        ``(inputs, targets)`` as two lists of equal length.
    """
    # The last valid start index is len - block_size - 1, so the window count
    # is len - block_size; clamp at zero for sequences that are too short.
    n_windows = max(len(token_ids) - block_size, 0)
    inputs = [token_ids[i:i + block_size] for i in range(n_windows)]
    targets = [token_ids[i + 1:i + block_size + 1] for i in range(n_windows)]
    return inputs, targets


def build_tf_dataset(inputs, targets, batch_size, shuffle_buffer=10000):
    """
    Turn paired input/target sequences into a training ``tf.data.Dataset``.

    The sequences are converted to int32 arrays, paired element-wise,
    shuffled with a buffer of ``shuffle_buffer`` elements, grouped into
    batches of ``batch_size`` (a short final batch is dropped) and prefetched.
    """
    features = np.array(inputs, dtype=np.int32)
    labels = np.array(targets, dtype=np.int32)

    # Stage order matters: shuffle individual examples first, then batch.
    return (
        tf.data.Dataset.from_tensor_slices((features, labels))
        .shuffle(buffer_size=shuffle_buffer)
        .batch(batch_size, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )


# Slice the training tokens into (input, target) windows of block_size tokens.
inputs, targets = make_next_token_dataset(token_ids=train_data, block_size=block_size)

# Wrap the windows in a shuffled, batched tf.data pipeline for model.fit.
dataset = build_tf_dataset(inputs, targets, batch_size, shuffle_buffer=10000)
print(f"Dataset created with {len(inputs)} samples")


In [None]:
# Print the layer-by-layer structure of the compiled model.
model.summary()


In [None]:
# Training the Transformer model
# The callback defined below can print generated samples during training, but
# it is not passed to model.fit() in this cell, so training runs without it.

class TextGenerationCallback(tf.keras.callbacks.Callback):
    """Keras callback that prints a generated sample every few epochs.

    Args:
        seed_tokens: Token ids used to start each sample; a copy is passed to
            ``generate_fn`` on every call.
        generate_fn: Called as ``generate_fn(model, seed_tokens, num_generate=100)``;
            its result is sliced to 200 items and printed.
        decode_fn: Stored on the instance; not called by this class.
        every_n_epochs: Print a sample whenever the 1-based epoch number is a
            multiple of this value.
    """

    def __init__(self, seed_tokens, generate_fn, decode_fn, every_n_epochs=5):
        super().__init__()
        self.seed_tokens, self.generate_fn = seed_tokens, generate_fn
        self.decode_fn, self.every_n_epochs = decode_fn, every_n_epochs

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes a 0-based epoch index; skip epochs that are not due.
        if (epoch + 1) % self.every_n_epochs != 0:
            return

        print(f"\n--- Sample generation at epoch {epoch + 1} ---")
        sample = self.generate_fn(self.model, self.seed_tokens.copy(), num_generate=100)
        print(sample[:200])
        print("---\n")

# Fit on the training pipeline for 15 epochs; the returned History object
# holds the per-epoch loss/accuracy that the plotting cell reads.
print("Starting training...")
history = model.fit(dataset, epochs=15, verbose=1)

print("\nTraining complete!")


In [None]:
def generate_text(model, seed_tokens, num_generate=200, temperature=0.8):
    """
    Generate text given seed_tokens using the trained model.

    The seed is left-padded with the reserved id 0, or cut to its last
    ``block_size`` tokens, so the model always sees exactly ``block_size``
    tokens. After each step the window slides by one token.

    Args:
        model: Trained transformer model
        seed_tokens: Sequence of token ids to start generation
        num_generate: Number of new tokens to generate
        temperature: Sampling temperature (higher = more random, lower = more
            deterministic). A value <= 0 selects the most likely token at
            every step instead of sampling.

    Returns:
        Generated string (the decoded new tokens only, without the seed)
    """
    # Ensure seed_tokens has the right length
    if len(seed_tokens) < block_size:
        # Pad with zeros if too short
        seed_tokens = [0] * (block_size - len(seed_tokens)) + list(seed_tokens)
    elif len(seed_tokens) > block_size:
        # Take the last block_size tokens
        seed_tokens = seed_tokens[-block_size:]

    generated_tokens = []
    # Plain Python ints keep the list uniform when it is turned into a tensor.
    current_tokens = [int(tok) for tok in seed_tokens]

    for _ in range(num_generate):
        # Prepare input tensor
        input_eval = tf.convert_to_tensor([current_tokens], dtype=tf.int32)

        # Get model predictions
        logits = model(input_eval, training=False)  # (1, block_size, vocab_size)

        # Logits for the last position, in float64: float32 probabilities can
        # fail np.random.choice's "probabilities sum to 1" check.
        next_token_logits = logits[0, -1, :].numpy().astype(np.float64)

        # Id 0 is reserved for padding and has no character, so never emit it
        # (skipped when 0 is the only id available).
        if next_token_logits.shape[0] > 1:
            next_token_logits[0] = -np.inf

        if temperature <= 0:
            # Greedy decoding; also avoids dividing by zero.
            next_token = int(np.argmax(next_token_logits))
        else:
            # Temperature-scaled softmax, shifted by the max for stability
            scaled_logits = next_token_logits / temperature
            scaled_logits -= np.max(scaled_logits)
            probs = np.exp(scaled_logits)
            probs /= probs.sum()

            # Sample from the probability distribution
            next_token = int(np.random.choice(len(probs), p=probs))

        # Append to generated tokens
        generated_tokens.append(next_token)

        # Update current_tokens: shift left and append new token
        current_tokens = current_tokens[1:] + [next_token]

    # Decode and return the generated text
    return decode(generated_tokens)


In [None]:
# Test text generation after training
# Prints the seed text once, then one continuation per sampling temperature.
print("="*60)
print("TEXT GENERATION DEMO")
print("="*60)

# Use the beginning of training data as seed
# (exactly block_size tokens, so generate_text neither pads nor truncates it)
seed = train_data[:block_size]
seed_text = decode(seed)

print(f"\nSeed text (first {block_size} characters):")
print("-"*40)
print(seed_text)
print("-"*40)

# Generate text with different temperatures
for temp in [0.5, 0.8, 1.0]:
    print(f"\n\nGenerated text (temperature={temp}):")
    print("-"*40)
    # A copy of the seed is passed so every temperature starts from the same tokens.
    generated = generate_text(model, seed.copy(), num_generate=300, temperature=temp)
    # generate_text returns only the new characters, so prepend the seed for context.
    print(seed_text + generated)
    print("-"*40)

print("\n" + "="*60)
print("GENERATION COMPLETE")
print("="*60)


In [None]:
# Visualize training history
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# One panel per logged metric: (key in history.history, display name, line colour).
panels = [
    ("loss", "Loss", "blue"),
    ("accuracy", "Accuracy", "green"),
]

for ax, (metric_key, metric_name, line_color) in zip(axes, panels):
    ax.plot(history.history[metric_key], label=f"Training {metric_name}", color=line_color)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(metric_name)
    ax.set_title(f"Training {metric_name} Over Epochs")
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Interactive text generation with custom prompts
def generate_from_prompt(prompt, num_chars=200, temperature=0.8):
    """Generate text starting from a custom prompt.

    The prompt is encoded and handed to ``generate_text``, which already pads
    a short seed with 0 or keeps the last ``block_size`` tokens of a long one,
    so that step is not repeated here.

    Args:
        prompt: Text to continue.
        num_chars: Number of characters to generate after the prompt.
        temperature: Sampling temperature forwarded to ``generate_text``.

    Returns:
        The prompt followed by the generated continuation.
    """
    prompt_tokens = encode(prompt)
    generated = generate_text(
        model,
        prompt_tokens,
        num_generate=num_chars,
        temperature=temperature,
    )
    return prompt + generated

# Example prompts in Shakespearean style
# Each prompt is shorter than block_size, so it is left-padded with 0 before generation.
prompts = [
    "To be or not to be",
    "All the world's a stage",
    "Friends, Romans, countrymen",
]

print("="*60)
print("CUSTOM PROMPT GENERATION")
print("="*60)

# Print each prompt followed by its 150-character continuation.
for prompt in prompts:
    print(f"\nPrompt: '{prompt}'")
    print("-"*40)
    result = generate_from_prompt(prompt, num_chars=150, temperature=0.7)
    print(result)
    print("-"*40)