In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
import numpy as np

# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# For reproducibility
torch.manual_seed(42)

Using device: cuda


<torch._C.Generator at 0x79df428a7d70>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """
    Implements Multi-Head Attention from scratch.
    Configuration:
    - d_model = 128
    - num_heads = 4
    - batch_size = 64
    - seq_length = 32
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Each head will get d_k = 128 / 4 = 32 dimensions

        # Linear projections for queries, keys, and values
        self.W_q = nn.Linear(d_model, d_model)  # (128 -> 128)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Final output projection
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        Q, K, V shapes: (batch_size, num_heads, seq_len, d_k) = (64, 4, 32, 32)
        """
        # Compute attention scores: Q x K^T
        # Resulting shape: (batch_size, num_heads, seq_len, seq_len) = (64, 4, 32, 32)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply attention mask if provided
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

        # Softmax over last dimension (keys)
        attn_probs = F.softmax(attn_scores, dim=-1)

        # Multiply attention weights with values
        # Output shape: (batch_size, num_heads, seq_len, d_k) = (64, 4, 32, 32)
        output = torch.matmul(attn_probs, V)

        return output

    def split_heads(self, x):
        """
        Split d_model into num_heads.
        Input shape:  (batch_size, seq_len, d_model) = (64, 32, 128)
        Output shape: (batch_size, num_heads, seq_len, d_k) = (64, 4, 32, 32)
        """
        batch_size, seq_length, d_model = x.size()
        # Reshape and transpose: (64, 32, 128) → (64, 32, 4, 32) → (64, 4, 32, 32)
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """
        Combine heads back into original shape.
        Input shape:  (batch_size, num_heads, seq_len, d_k) = (64, 4, 32, 32)
        Output shape: (batch_size, seq_len, d_model) = (64, 32, 128)
        """
        batch_size, num_heads, seq_length, d_k = x.size()
        # Transpose and reshape: (64, 4, 32, 32) → (64, 32, 4, 32) → (64, 32, 128)
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """
        Input Q, K, V shape: (batch_size, seq_len, d_model) = (64, 32, 128)
        Output shape:       (batch_size, seq_len, d_model) = (64, 32, 128)
        """
        # Linear projections: all outputs shape (64, 32, 128)
        Q = self.W_q(Q)
        K = self.W_k(K)
        V = self.W_v(V)

        # Split into heads: (64, 32, 128) → (64, 4, 32, 32)
        Q = self.split_heads(Q)
        K = self.split_heads(K)
        V = self.split_heads(V)

        # Scaled Dot-Product Attention: (64, 4, 32, 32)
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)

        # Combine heads: (64, 4, 32, 32) → (64, 32, 128)
        combined_output = self.combine_heads(attn_output)

        # Final linear projection: (64, 32, 128)
        output = self.W_o(combined_output)

        return output

In [3]:
class PositionWiseFeedForward(nn.Module):
    """
    Implements the Position-wise Feed-Forward network.
    Each position in the sequence is passed through the same MLP independently.
    """

    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()

        # First linear layer: expands dimensionality (128 → 512)
        self.fc1 = nn.Linear(d_model, d_ff)

        # Second linear layer: projects back to original dimension (512 → 128)
        self.fc2 = nn.Linear(d_ff, d_model)

        # Non-linearity in between
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        Input shape : (batch_size, seq_length, d_model) = (64, 32, 128)
        Output shape: (batch_size, seq_length, d_model) = (64, 32, 128)
        """
        # Apply FFN: position-wise
        return self.fc2(self.relu(self.fc1(x)))


class TransformerBlock(nn.Module):
    """
    A single Transformer encoder block with:
    - Multi-Head Attention (MHA)
    - Feed-Forward Network (FFN)
    - Residual connections + Layer Normalization
    - Dropout (regularization)
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(TransformerBlock, self).__init__()

        # Multi-Head Self-Attention module
        self.attention = MultiHeadAttention(d_model, num_heads)

        # Position-wise Feed-Forward Network
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)

        # Layer normalization for each sub-layer (after residual addition)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Input shape : (batch_size, seq_length, d_model) = (64, 32, 128)
        Output shape: (batch_size, seq_length, d_model) = (64, 32, 128)
        """

        # --- First Sub-layer: Multi-Head Self-Attention ---

        # Self-attention: Q=K=V=x
        attn_output = self.attention(x, x, x, mask)  # → (64, 32, 128)

        # Add & Norm: residual connection + layer normalization
        x = self.norm1(x + self.dropout(attn_output))  # → (64, 32, 128)

        # --- Second Sub-layer: Feed-Forward Network ---

        # Position-wise FFN
        ff_output = self.feed_forward(x)  # → (64, 32, 128)

        # Add & Norm: residual connection + layer normalization
        x = self.norm2(x + self.dropout(ff_output))  # → (64, 32, 128)

        return x

In [4]:
class SimpleTransformerLM(nn.Module):
    """
    A simple Transformer-based Language Model built using:
    - Token embeddings + positional embeddings
    - Stacked Transformer blocks (multi-head attention + FFN)
    - Final linear layer to predict the next token (character)
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, max_seq_length, dropout=0.1):
        super(SimpleTransformerLM, self).__init__()

        # 🔹 Embedding layer to convert token indices → dense vectors of dim d_model
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        #   Input shape: (batch_size, seq_length)
        #   Output shape: (batch_size, seq_length, d_model)

        # 🔹 Positional embedding: learnable embeddings for each position in the sequence
        self.positional_embedding = nn.Embedding(max_seq_length, d_model)
        #   Input: position indices [0, 1, 2, ..., max_seq_length - 1]
        #   Output: (batch_size, seq_length, d_model)

        # 🔹 Stack of Transformer blocks (each with attention + FFN)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        #   This will hold `num_layers` blocks — each applied in sequence in forward()

        # 🔹 Final output layer: projects each position's final hidden state → vocab logits
        self.fc_out = nn.Linear(d_model, vocab_size)
        #   Output shape: (batch_size, seq_length, vocab_size)

        # 🔹 Dropout layer for regularization
        self.dropout = nn.Dropout(dropout)

        # 🔹 Store sequence length limit for reference
        self.max_seq_length = max_seq_length

    def forward(self, x, mask=None):
        """
        x:       (batch_size, seq_length) – input token indices
        mask:    (optional) attention mask, used in decoder (not used here)
        Returns: (batch_size, seq_length, vocab_size) – prediction logits
        """

        batch_size, seq_length = x.size()

        # 🔹 Create a tensor of positions [0, 1, 2, ..., seq_length - 1]
        positions = torch.arange(0, seq_length).expand(batch_size, seq_length).to(x.device)
        #   Shape: (batch_size, seq_length)

        # 🔹 Embed tokens and positions, then sum and apply dropout
        x = self.token_embedding(x) + self.positional_embedding(positions)
        x = self.dropout(x)
        #   Shape: (batch_size, seq_length, d_model)

        # 🔹 Pass through each Transformer block sequentially
        for block in self.transformer_blocks:
            x = block(x, mask)
        #   Shape stays the same throughout: (batch_size, seq_length, d_model)

        # 🔹 Final projection: predict vocabulary logits at each position
        logits = self.fc_out(x)
        #   Output shape: (batch_size, seq_length, vocab_size)

        return logits

In [5]:
# ================================================
# 🔹 Define a Simple Character-Level Corpus
# ================================================

text = """
Transformers have revolutionized the field of natural language processing.
The core idea behind the Transformer is self-attention, a mechanism that allows the model to weigh the importance of different words in a sequence.
This contrasts with previous models like RNNs, which processed words sequentially.
The parallel nature of Transformers allows for much faster training on large datasets.
"""
# ➤ This multi-line string will be used to train a character-level language model.
# ➤ Each character (including punctuation and whitespace) is treated as a token.

# ================================================
# 🔹 Create Vocabulary from Unique Characters
# ================================================

chars = sorted(list(set(text)))  # Extract all unique characters and sort them
vocab_size = len(chars)          # Number of unique characters

print(f"Vocabulary size: {vocab_size}")
print(f"Vocabulary: {''.join(chars)}")

# Example output:
# Vocabulary size: 65
# Vocabulary:
# .ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz

# ================================================
# 🔹 Create Character ↔ Integer Mappings
# ================================================

stoi = {ch: i for i, ch in enumerate(chars)}  # String to integer
itos = {i: ch for i, ch in enumerate(chars)}  # Integer to string

# Lambda functions for encoding and decoding sequences
encode = lambda s: [stoi[c] for c in s]       # String → List of ints
decode = lambda l: ''.join([itos[i] for i in l])  # List of ints → String

# ================================================
# 🔹 Encode the Full Text into a Sequence of Integers
# ================================================

data = torch.tensor(encode(text), dtype=torch.long)
# Shape: (total_characters,) → 1D tensor of integers representing the text

# ================================================
# 🔹 Prepare Input and Target Sequences
# ================================================

block_size = 32  # Context window size (sequence length)

X, Y = [], []  # Lists to store input-output pairs

for i in range(len(data) - block_size):
    # ➤ Input sequence: current block of 32 characters
    X.append(data[i     : i + block_size])

    # ➤ Target sequence: next 32 characters (shifted by one position)
    Y.append(data[i + 1 : i + block_size + 1])
    # ➤ This sets up a next-character prediction task

# Convert lists of tensors to a 2D tensor:
# X shape: (num_sequences, block_size) = (num_examples, 32)
# Y shape: same
X = torch.stack(X)  # e.g., (435, 32)
Y = torch.stack(Y)

# ================================================
# 🔹 Print Final Shapes
# ================================================

print(f"Shape of input data (X): {X.shape}")
print(f"Shape of target data (Y): {Y.shape}")
# ➤ Both should be [num_sequences, block_size]
# ➤ Example: Shape of input data (X): torch.Size([435, 32])

Vocabulary size: 32
Vocabulary: 
 ,-.NRTabcdefghiklmnopqrstuvwyz
Shape of input data (X): torch.Size([362, 32])
Shape of target data (Y): torch.Size([362, 32])


In [7]:
# =====================================================
# 🔹 Hyperparameters for Transformer Language Model
# =====================================================

D_MODEL = 128           # Hidden dimension of token embeddings and model output
NUM_LAYERS = 4          # Number of stacked Transformer blocks (depth of the model)
NUM_HEADS = 4           # Number of attention heads in Multi-Head Attention
D_FF = 512              # Dimension of Feed-Forward Network inside Transformer block
MAX_SEQ_LENGTH = block_size  # Maximum context window (sequence length), e.g., 32
DROPOUT = 0.1           # Dropout probability for regularization
LEARNING_RATE = 3e-4    # Learning rate for the optimizer
EPOCHS = 5000           # Total number of training iterations
BATCH_SIZE = 64         # Number of examples per mini-batch

# =====================================================
# 🔹 Model, Optimizer, and Loss Function Setup
# =====================================================

# Instantiate the model with the defined architecture and move it to device (CPU/GPU)
model = SimpleTransformerLM(
    vocab_size,         # Output vocabulary size (number of tokens to predict)
    D_MODEL,            # Embedding and hidden dimensions
    NUM_LAYERS,         # Number of Transformer blocks
    NUM_HEADS,          # Number of attention heads
    D_FF,               # FFN dimensionality
    MAX_SEQ_LENGTH,     # Context size
    DROPOUT             # Dropout rate
).to(device)

# Optimizer: AdamW is typically used for Transformers
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# Loss function: CrossEntropyLoss for multi-class classification per character
# Note: logits are raw scores; targets are integer indices
criterion = nn.CrossEntropyLoss()

# =====================================================
# 🔹 Training Loop
# =====================================================

print("\n--- Starting Training ---")

for epoch in range(EPOCHS):

    # -----------------------------------------
    # 🔸 Step 1: Sample a random mini-batch
    # -----------------------------------------
    ix = torch.randint(0, X.shape[0], (BATCH_SIZE,))      # Sample random indices
    x_batch = X[ix].to(device)                            # Input batch: shape (64, 32)
    y_batch = Y[ix].to(device)                            # Target batch: shape (64, 32)

    # -----------------------------------------
    # 🔸 Step 2: Forward pass through the model
    # -----------------------------------------
    logits = model(x_batch)       # Shape: (64, 32, vocab_size)

    # -----------------------------------------
    # 🔸 Step 3: Reshape for loss computation
    # -----------------------------------------
    B, T, C = logits.shape        # B=batch, T=sequence length, C=classes
    logits = logits.view(B * T, C)      # Reshape to (64*32, vocab_size)
    targets = y_batch.view(B * T)       # Reshape to (64*32,) for CE loss

    # -----------------------------------------
    # 🔸 Step 4: Compute loss
    # -----------------------------------------
    loss = criterion(logits, targets)

    # -----------------------------------------
    # 🔸 Step 5: Backward pass and optimization
    # -----------------------------------------
    optimizer.zero_grad(set_to_none=True)  # Clear gradients (more efficient with set_to_none)
    loss.backward()                        # Backpropagation
    optimizer.step()                       # Update model weights

    # -----------------------------------------
    # 🔸 Step 6: Print loss periodically
    # -----------------------------------------
    if epoch % 500 == 0:
        print(f"Epoch {epoch}/{EPOCHS}, Loss: {loss.item():.4f}")

print("--- Training Complete ---")


--- Starting Training ---
Epoch 0/5000, Loss: 3.5758
Epoch 500/5000, Loss: 0.0119
Epoch 1000/5000, Loss: 0.0022
Epoch 1500/5000, Loss: 0.0013
Epoch 2000/5000, Loss: 0.0004
Epoch 2500/5000, Loss: 0.0003
Epoch 3000/5000, Loss: 0.0035
Epoch 3500/5000, Loss: 0.0004
Epoch 4000/5000, Loss: 0.0002
Epoch 4500/5000, Loss: 0.0003
--- Training Complete ---


In [8]:
def generate(model, start_string, max_new_tokens, method, **kwargs):
    """
    Generates text from the model using various decoding strategies.

    Args:
        model: The trained Transformer model.
        start_string: The initial string to start generation from.
        max_new_tokens: The maximum number of tokens to generate.
        method: One of 'greedy', 'sample', 'top_k', 'top_p'.
        **kwargs: Arguments for the decoding methods (temp, top_k, top_p).
    """
    model.eval()

    # Get configuration from kwargs
    temperature = kwargs.get('temp', 1.0)
    top_k = kwargs.get('top_k', None)
    top_p = kwargs.get('top_p', None)

    # Encode the starting string
    idx = torch.tensor(encode(start_string), dtype=torch.long, device=device).unsqueeze(0)

    for _ in range(max_new_tokens):
        # Crop context if it exceeds max_seq_length
        idx_cond = idx if idx.size(1) <= MAX_SEQ_LENGTH else idx[:, -MAX_SEQ_LENGTH:]

        with torch.no_grad():
            logits = model(idx_cond)

        # Focus only on the last time step
        logits = logits[:, -1, :]

        # --- Apply Decoding Strategy ---
        if method == 'greedy':
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)
        else: # For all sampling methods
            # Apply temperature scaling
            logits = logits / temperature

            # Apply Top-k sampling
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            # Apply Top-p (Nucleus) sampling
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift the indices to the right to keep the first token above the threshold
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = -float('Inf')

            # Get probabilities and sample
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

        # Append the new token to the sequence
        idx = torch.cat((idx, idx_next), dim=1)

    return decode(idx[0].tolist())

# --- Run Demonstrations ---
print("\n--- Starting Inference Demonstrations ---\n")
start_prompt = "The Transformer is"
max_tokens_to_gen = 100

# 1. Greedy Decoding
print("--- 1. Greedy Decoding ---")
print("Description: Always picks the token with the highest probability. Deterministic but can be repetitive.")
generated_text = generate(model, start_prompt, max_tokens_to_gen, method='greedy')
print(f"Output:\n{generated_text}\n")


# 2. Sampling with Temperature
print("--- 2. Sampling with Temperature ---")
print("Description: Samples from the probability distribution. Temperature controls randomness.")

print("a) Low Temperature (T=0.5): More predictable, closer to greedy.")
torch.manual_seed(42) # for reproducibility
generated_text = generate(model, start_prompt, max_tokens_to_gen, method='sample', temp=0.5)
print(f"Output:\n{generated_text}\n")

print("b) High Temperature (T=1.5): More random, creative, but can be nonsensical.")
torch.manual_seed(42) # for reproducibility
generated_text = generate(model, start_prompt, max_tokens_to_gen, method='sample', temp=1.5)
print(f"Output:\n{generated_text}\n")


# 3. Top-k Sampling
print("--- 3. Top-k Sampling ---")
print("Description: Restricts sampling to the 'k' most likely tokens. Avoids very unlikely tokens.")
torch.manual_seed(42) # for reproducibility
generated_text = generate(model, start_prompt, max_tokens_to_gen, method='sample', temp=1.0, top_k=10)
print(f"Output (k=10):\n{generated_text}\n")


# 4. Top-p (Nucleus) Sampling
print("--- 4. Top-p (Nucleus) Sampling ---")
print("Description: Samples from the smallest set of tokens whose cumulative probability exceeds 'p'. More dynamic than top-k.")
torch.manual_seed(42) # for reproducibility
generated_text = generate(model, start_prompt, max_tokens_to_gen, method='sample', temp=1.0, top_p=0.90)
print(f"Output (p=0.90):\n{generated_text}\n")


--- Starting Inference Demonstrations ---

--- 1. Greedy Decoding ---
Description: Always picks the token with the highest probability. Deterministic but can be repetitive.
Output:
The Transformer is Thanarmely.
The Tr allows wor bs for  fanceisteis ais lis tisetis ts like RNNs, whicousta witid wof

--- 2. Sampling with Temperature ---
Description: Samples from the probability distribution. Temperature controls randomness.
a) Low Temperature (T=0.5): More predictable, closer to greedy.
Output:
The Transformer is Thase Troweh sfas the forelfel natureisf Transformers allows for much faster training on large data

b) High Temperature (T=1.5): More random, creative, but can be nonsensical.
Output:
The Transformer isequcernse.
This tis ite.
Tally.
The Transformer Tras former ms s ts fisetis sis RNs tiksts likestRNN

--- 3. Top-k Sampling ---
Description: Restricts sampling to the 'k' most likely tokens. Avoids very unlikely tokens.
Output (k=10):
The Transformer is Thasehey.
TrmTr Trmer is