# Implementing A Simple Transformer with Attention

## courtesy

https://arxiv.org/abs/1706.03762

Attention Is All You Need

Different GPT blogs

# Install libraries

In [1]:
!pip install torch==2.0.1 torchtext==0.15.2  torchvision

Collecting torch==2.0.1
  Downloading torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl.metadata (24 kB)
Collecting torchtext==0.15.2
  Downloading torchtext-0.15.2-cp310-cp310-manylinux1_x86_64.whl.metadata (7.4 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.7.99 (from torch==2.0.1)
  Downloading nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu11==11.7.99 (from torch==2.0.1)
  Downloading nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cuda-cupti-cu11==11.7.101 (from torch==2.0.1)
  Downloading nvidia_cuda_cupti_cu11-11.7.101-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu11==8.5.0.96 (from torch==2.0.1)
  Downloading nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu11==11.10.3.66 (from torch==2.0.1)
  Downloading nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Co

# Step 1: Import Libraries & Prepare Dataset


We break down sentences into words (tokenization) and convert these words into numerical values (numericalization) so the computer can process them. To ensure that all sentences have the same length, we add padding where necessary.


Imagine you have different sentences, but some are long and some are short. The computer likes things to be the same size, so we add extra blank spaces (padding) to make all sentences the same length.

In [7]:
# Step 1: Import Libraries & Prepare Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define dataset with diverse sentences
text = [
    "deep learning enables AI to process data efficiently",
    "attention is the key mechanism behind transformers",
    "transformers improve machine learning and NLP models",
    "models generalize better with diverse training datasets",
    "neural networks learn from both structured and unstructured data",
    "machine learning optimizes healthcare and financial decisions",
    "artificial intelligence advancements rely on large datasets",
    "language models predict sequences for a variety of applications",
    "transfer learning helps models reuse knowledge across domains",
    "self-attention allows transformers to process entire sequences in parallel",
]


text.extend([
    "transformers use multi-head self-attention to capture dependencies",
    "machine learning requires extensive training data",
    "deep neural networks power modern AI models",
    "natural language processing is a key area of artificial intelligence",
    "self-supervised learning reduces the need for labeled data",
    "vision transformers are effective for image recognition tasks",
    "sequence-to-sequence models generate text translations",
    "speech recognition systems use recurrent networks",
    "deep learning applications span across multiple industries",
    "transfer learning helps models adapt to new tasks efficiently",
    "reinforcement learning is used in robotics and game AI",
    "contrastive learning improves representations in self-supervised learning",
    "large language models like GPT understand and generate text",
    "computer vision detects objects, faces, and scenes in images",
    "attention mechanisms help models focus on relevant information",
    "transformers surpass recurrent networks in processing sequences",
    "neural networks learn hierarchical features from raw data",
    "speech synthesis models generate human-like voices",
    "convolutional neural networks excel in image processing",
    "language models predict the next word in a sentence",
    "unsupervised learning finds hidden patterns in unlabeled data",
    "generative models create realistic images and videos",
    "transformers process entire sequences in parallel",
    "BERT is a transformer model pre-trained on large text corpora",
    "autonomous vehicles rely on deep learning for perception",
    "recurrent neural networks handle sequential data efficiently",
    "zero-shot learning enables models to generalize to unseen tasks",
    "few-shot learning trains models with very few examples",
    "multi-modal models integrate text, images, and audio",
    "transformers achieve state-of-the-art results in NLP",
    "self-attention captures long-range dependencies in sequences",
    "meta-learning allows models to learn how to learn",
    "language understanding is crucial for conversational AI",
    "convolutional networks extract spatial features in vision tasks",
    "transformers are trained using large-scale parallel computing",
    "speech-to-text models convert spoken words into written text",
    "recurrent networks struggle with long-term dependencies",
    "contrastive learning is widely used in representation learning",
    "neural networks optimize parameters using backpropagation",
    "few-shot learning mimics human ability to learn from few examples",
    "transformers enable contextualized embeddings in language models",
    "deep learning accelerates drug discovery and medical research",
    "attention is all you need is a seminal paper on transformers",
    "transformer-based architectures improve translation accuracy",
    "transformers eliminate the need for recurrence in NLP models",
    "speech models use hidden Markov models and neural networks",
    "transformers outperform traditional sequence models in NLP",
    "pre-training on large datasets enhances model generalization",
    "tokenization is a fundamental step in NLP preprocessing",
    "transformers generalize well across different NLP tasks",
    "AI systems leverage deep learning for better decision-making",
])


# Tokenization
tokenizer = get_tokenizer("basic_english")
tokens = [tokenizer(sentence) for sentence in text]

# Debugging: Print tokenized text
print("Tokenized sentences:", tokens)


Tokenized sentences: [['deep', 'learning', 'enables', 'ai', 'to', 'process', 'data', 'efficiently'], ['attention', 'is', 'the', 'key', 'mechanism', 'behind', 'transformers'], ['transformers', 'improve', 'machine', 'learning', 'and', 'nlp', 'models'], ['models', 'generalize', 'better', 'with', 'diverse', 'training', 'datasets'], ['neural', 'networks', 'learn', 'from', 'both', 'structured', 'and', 'unstructured', 'data'], ['machine', 'learning', 'optimizes', 'healthcare', 'and', 'financial', 'decisions'], ['artificial', 'intelligence', 'advancements', 'rely', 'on', 'large', 'datasets'], ['language', 'models', 'predict', 'sequences', 'for', 'a', 'variety', 'of', 'applications'], ['transfer', 'learning', 'helps', 'models', 'reuse', 'knowledge', 'across', 'domains'], ['self-attention', 'allows', 'transformers', 'to', 'process', 'entire', 'sequences', 'in', 'parallel'], ['transformers', 'use', 'multi-head', 'self-attention', 'to', 'capture', 'dependencies'], ['machine', 'learning', 'requires

# Step 1.1 Vocabulary Building


A vocabulary is like a dictionary where each word is assigned a unique number. Special tokens like <pad> (for padding) and <unk> (for unknown words) help handle cases where a word isn't recognized.


Think of it like a secret codebook where every word gets its own number. If the computer doesn’t know a word, it uses a special code for "unknown" words.

In [8]:
# Build Vocabulary
vocab = build_vocab_from_iterator(tokens, specials=["<pad>", "<unk>"], min_freq=1)
vocab.set_default_index(vocab["<unk>"])

# Debugging: Print vocab size
print(f"Vocabulary Size: {len(vocab)}")

# Convert text to numerical format
numerical_text = [[vocab[token] for token in sentence] for sentence in tokens]

# Debugging: Print numericalized sentences
print("Numericalized sentences:", numerical_text)

Vocabulary Size: 225
Numericalized sentences: [[14, 3, 53, 18, 9, 39, 10, 31], [28, 8, 26, 64, 146, 88, 5], [5, 62, 36, 3, 7, 16, 2], [2, 33, 50, 45, 106, 72, 29], [12, 6, 19, 32, 90, 195, 7, 211, 10], [36, 3, 160, 127, 7, 118, 102], [49, 63, 81, 69, 17, 24, 29], [15, 2, 67, 20, 11, 13, 213, 66, 47], [73, 3, 59, 2, 181, 138, 27, 107], [41, 46, 5, 9, 39, 54, 20, 4, 38], [5, 43, 152, 41, 9, 91, 30], [36, 3, 178, 115, 72, 10], [14, 12, 6, 166, 151, 18, 2], [155, 15, 40, 8, 13, 64, 84, 66, 49, 63], [70, 3, 173, 26, 37, 11, 139, 10], [44, 5, 48, 109, 11, 61, 68, 21], [187, 2, 34, 22, 206], [42, 68, 71, 43, 25, 6], [14, 3, 47, 189, 27, 154, 134], [73, 3, 59, 2, 80, 9, 156, 21, 31], [174, 3, 8, 74, 4, 182, 7, 122, 18], [51, 3, 133, 177, 4, 70, 3], [24, 15, 2, 142, 125, 207, 7, 34, 22], [93, 44, 103, 158, 23, 117, 23, 7, 183, 4, 35], [28, 147, 128, 2, 120, 17, 175, 135], [5, 197, 25, 6, 4, 40, 20], [12, 6, 19, 129, 56, 32, 170, 10], [42, 198, 2, 34, 132, 217], [52, 12, 6, 114, 4, 61, 40], [15,

# Step 1.2  Padding

Since sentences have different lengths, we use padding tokens (<pad>) to make them uniform in size. This helps in batch processing, making computations more efficient.


It’s like writing a sentence on a blank page and adding extra spaces at the end so that all pages have the same number of words. This way, everything lines up properly.

In [9]:

# Pad sequences
max_seq_length = max(len(seq) for seq in numerical_text)
padded_text = [F.pad(torch.tensor(seq, dtype=torch.long), (0, max_seq_length - len(seq)), value=vocab["<pad>"]) for seq in numerical_text]

# Convert to TensorDataset
padded_text = torch.stack(padded_text)
dataset = TensorDataset(padded_text)

# Create DataLoader
batch_size = 4
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Step 2: Define Multi-Head Attention

Multi-head attention allows a model to focus on different parts of a sentence at the same time. Instead of just looking at one word at a time, it splits the information into smaller pieces (heads), each attending to different relationships within the text. This helps capture meaning better.

For example, in the sentence "The cat sat on the mat", different heads might focus on:

The relationship between cat and sat,
The connection between mat and on,
The subject (cat) of the sentence.
Each head processes information differently, and the results are combined for better understanding.


Imagine you’re watching a movie with your friends. One friend pays attention to the music, another to the colors, another to the action scenes. When you talk about it later, you all bring different details, making the story richer. That’s what multi-head attention does—it looks at different parts of the sentence at the same time.

### Detailed Explanation
Multi-head attention is a mechanism that allows a model to focus on different parts of a sentence simultaneously. Instead of analyzing one word at a time, it breaks the information into multiple smaller parts (called "heads"), with each head attending to different relationships between words. This helps capture the full meaning of the sentence more effectively.

For example, consider the sentence:
📝 "The cat sat on the mat."

A single-head attention model might focus only on a single aspect of the sentence, such as:

The connection between cat and sat (i.e., who is performing the action?).
However, a multi-head attention model splits its focus into multiple heads, where:

One head looks at the relationship between cat and sat (who is performing the action?).
Another head looks at the relationship between mat and on (where did the action take place?).
A third head looks at the (determiner) and how it relates to cat (subject).
Each head processes different relationships separately, and their results are combined to create a stronger understanding of the sentence.

### How Multi-Head Attention Works (Step by Step)
Each word in a sentence is transformed into a numerical representation called an embedding. Multi-head attention processes these embeddings using the following steps:

#### 1.Linear Transformations:

Each word embedding is passed through three linear layers (fully connected layers) to generate:
Q (Query): What this word is looking for in other words.
K (Key): How relevant this word is to other words.
V (Value): The actual information stored in the word.
These transformations allow each word to have different perspectives when interacting with other words.



##### Understanding Query (Q), Key (K), and Value (V)

In the sentence **"The cat sat on the mat"**, we use **Query (Q)**, **Key (K)**, and **Value (V)** to determine which words should focus on each other.

###### 1. Query (Q) → The word we are focusing on.
   - **Example**: `"cat"`

###### 2. Key (K) → The words we compare with.
   - **Example**: `["The", "cat", "sat", "on", "the", "mat"]`

###### 3. Value (V) → The actual word representations used in the output.
   - **Example**: `["The", "cat", "sat", "on", "the", "mat"]`


Step-by-Step Breakdown
Let's say we want to find what "cat" should pay attention to.

We compute attention scores by comparing Q ("cat") with each K (all words).

If "sat" has a high score with "cat", then "sat" is important in relation to "cat".

The word representations from V (Values) are then weighted by these attention scores to form a meaningful output.

##### Example of Attention Scores for "cat"

| Key Word (K) | Score (Similarity with Q "cat") |
|-------------|--------------------------------|
| "The"       | Low (not very related)       |
| "cat"       | High (itself)                |
| "sat"       | High (action related to cat) |
| "on"        | Medium (less relevant)       |
| "the"       | Low (not important here)     |
| "mat"       | Medium (cat is on the mat)   |

Final Output: "cat" will focus most on "sat" and "mat" because they are relevant.


###### Explanation For a 12-Year-Old
Think of a classroom:

You (Query Q) are "cat", and you are curious about something.
Your classmates (Keys K) are all the words in the sentence.
You ask, "Who is most important to me?"
Each classmate (word) gives their answer.
The teacher (Attention Mechanism) decides which answers (Values V) matter the most based on how similar they are to what you asked.
##### Example:

"cat" will listen most to "sat" because cats sit.
"cat" will also listen to "mat", since cats often sit on mats.
But "The" and "on" aren't very important to "cat", so it pays less attention to them.
##### 🚀 Conclusion: Attention helps words focus on the most meaningful parts of the sentence!

#### 2.Splitting into Multiple Heads:

The transformed Q, K, and V vectors are split into multiple smaller parts (heads). Each head independently performs attention calculations.

#### 3.Scaled Dot-Product Attention:

For each head, the model computes attention scores using:
$$ Attention \ Score = \frac{Q \times K^T}{\sqrt{d_k}} $$



This means:
The query (Q) of one word is compared against the key (K) of every other word to determine how much attention it should give.
The dot product measures similarity—higher values mean stronger relationships.
The square root of d_k (size of each head) helps stabilize gradients during training.
#### 4.Masking (Optional):

If using masked attention (common in language models like GPT), we block information from certain positions to prevent the model from "cheating."
This is important for autoregressive generation (predicting words one by one) to ensure the model does not see future words.

#### 5.Softmax & Weighted Sum:

The attention scores are normalized using a softmax function, turning them into probabilities.
These probabilities determine how much of each word’s value (V) contributes to the final output.

#### 6.Concatenation & Final Transformation:

Once each head has computed its result, all heads are concatenated back together and passed through another linear layer to merge their insights.


### Simple Explanation For a 12-Year-Old
Think of multi-head attention like a group of detectives solving a mystery together.

Imagine you're watching a movie with your friends, and each friend focuses on something different:

One friend listens to the music 🎵.
Another pays attention to the action scenes 🎬.
Another watches the characters' emotions 😢.
Another looks at the background scenery 🌆.
After the movie, when you all discuss it, each friend shares a different part of the story, making your understanding richer and more complete.

Similarly, multi-head attention lets a model look at different parts of a sentence at the same time instead of just one word at a time. It understands relationships between words better, just like your friends understand different aspects of the movie!

Why is Multi-Head Attention So Powerful?
Parallel Processing: Unlike older models (like RNNs), which process words one by one, transformers look at the entire sentence at once.
Focus on Different Relationships: Each head pays attention to different parts of the sentence, making the understanding more detailed.
Better Context Awareness: Since multiple words can be attended to simultaneously, the model understands long-range dependencies better.
This concept is one of the key reasons transformers (like GPT, BERT, and T5) are so powerful in language processing! 🚀

In [10]:
# Step 2: Define Multi-Head Attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)

        # Debugging print
        #print(f"Input to attention - Q: {q.shape}, K: {k.shape}, V: {v.shape}")

        q = self.q_linear(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = self.k_linear(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Debugging print
        #print(f"After Linear Layers - Q: {q.shape}, K: {k.shape}, V: {v.shape}")

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn = torch.softmax(scores, dim=-1)

        # Debugging print
        #print(f"Attention Scores Shape: {attn.shape}, V Shape: {v.shape}")

        output = torch.matmul(attn, v)

        # Debugging print
        #print(f"Output before reshape: {output.shape}")

        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)

        #print(f"Output after reshape: {output.shape}")

        return self.out_linear(output)


# Step 3. Transformer Model

A transformer is a deep learning model that understands text by using layers of multi-head attention and feedforward networks. It doesn’t process words one by one (like RNNs) but instead looks at the whole sentence at once, making it very efficient.

How It Works:

Embedding Layer: Converts words into numerical vectors that capture their meanings.
Positional Encoding: Since transformers don’t read words in order like humans, they add position information to keep track of word order.
Multi-Head Attention: Helps the model look at different words in a sentence at once.
Feedforward Layers: Refine the information and prepare it for the next layer.
Output Layer: Predicts the next word or the meaning of the sentence.
Transformers are the backbone of powerful AI models like GPT, BERT, and T5.


Think of a transformer like a really smart detective. Instead of reading a book word by word, the detective looks at the entire page, finds important clues, and connects them instantly. It doesn’t just follow the order of words—it figures out how everything fits together to understand the meaning quickly.

In [11]:
# Transformer Model
class MiniTransformer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers, vocab_size, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_len, d_model)
        self.layers = nn.ModuleList([MultiHeadAttention(d_model, num_heads) for _ in range(num_layers)])
        self.fc_out1 = nn.Linear(d_model, d_model)
        self.fc_out2 = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        pos = torch.arange(0, x.size(1)).unsqueeze(0).to(x.device)
        x = self.embedding(x) + self.position_embedding(pos)

        # Debugging print
        #print(f"Embedding Shape: {x.shape}")

        for layer in self.layers:
            x = layer(x, x, x)

        x = self.fc_out1(x)
        return self.fc_out2(x)

# Model Initialization
model = MiniTransformer(
    d_model=256, num_heads=8, d_ff=512, num_layers=2,
    vocab_size=len(vocab), max_len=max_seq_length
).to(device)

# Debugging: Check embedding layer size
print(f"Embedding Layer Shape: {model.embedding.weight.shape}")


Embedding Layer Shape: torch.Size([225, 256])


# Step 4. Train the Model


CrossEntropy Loss measures how different the model’s prediction is from the actual answer. Ignoring the <pad> token ensures that padding doesn’t affect training since it's not real data.

It’s like grading a test. If a student leaves blank spaces (padding) on the answer sheet, the teacher ignores them while checking answers.


AdamW is an optimization algorithm that updates model weights efficiently while preventing overfitting. It’s a version of Adam that includes weight decay, which helps stabilize training.


Imagine you’re learning to ride a bike. If you keep adjusting too much, you’ll wobble. If you don’t adjust at all, you’ll fall. AdamW carefully adjusts the learning process so the model doesn’t "wobble" too much.

In [16]:
# Training
def train_model(model, train_loader, epochs=5):
    criterion = nn.CrossEntropyLoss(ignore_index=vocab["<pad>"])
    #criterion = nn.CrossEntropyLoss(ignore_index=vocab["<pad>"], label_smoothing=0.1)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)

    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch in train_loader:
            batch_tensor = batch[0].to(device)
            input_tensor = batch_tensor[:, :-1]
            target_tensor = batch_tensor[:, 1:]

            optimizer.zero_grad()
            output = model(input_tensor)
            loss = criterion(output.reshape(-1, output.size(-1)), target_tensor.reshape(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

# Train the Model
train_model(model, train_loader, epochs=200)

Epoch 1, Loss: 0.8043
Epoch 2, Loss: 5.1186
Epoch 3, Loss: 9.9844
Epoch 4, Loss: 12.6110
Epoch 5, Loss: 6.6701
Epoch 6, Loss: 6.1210
Epoch 7, Loss: 4.2512
Epoch 8, Loss: 4.8079
Epoch 9, Loss: 4.4147
Epoch 10, Loss: 2.8314
Epoch 11, Loss: 3.6111
Epoch 12, Loss: 5.9703
Epoch 13, Loss: 6.3211
Epoch 14, Loss: 5.1674
Epoch 15, Loss: 3.6164
Epoch 16, Loss: 4.8582
Epoch 17, Loss: 3.8594
Epoch 18, Loss: 3.0023
Epoch 19, Loss: 2.5497
Epoch 20, Loss: 2.2677
Epoch 21, Loss: 1.7869
Epoch 22, Loss: 3.2531
Epoch 23, Loss: 4.1885
Epoch 24, Loss: 7.8143
Epoch 25, Loss: 10.1721
Epoch 26, Loss: 13.6266
Epoch 27, Loss: 14.6170
Epoch 28, Loss: 12.0695
Epoch 29, Loss: 11.2052
Epoch 30, Loss: 7.7528
Epoch 31, Loss: 6.3641
Epoch 32, Loss: 5.5414
Epoch 33, Loss: 5.4329
Epoch 34, Loss: 5.9932
Epoch 35, Loss: 4.2969
Epoch 36, Loss: 3.2927
Epoch 37, Loss: 2.5046
Epoch 38, Loss: 1.5655
Epoch 39, Loss: 2.6855
Epoch 40, Loss: 1.7061
Epoch 41, Loss: 1.3156
Epoch 42, Loss: 1.1453
Epoch 43, Loss: 0.9160
Epoch 44, Loss

# Step 5. Prediction

Top-K sampling limits the number of possible word choices to the top K most likely ones, preventing low-quality predictions. Temperature controls randomness—higher values make the model more creative, lower values make it more predictable.


Imagine picking ice cream flavors. With Top-K, instead of choosing from all flavors, you only pick from the top 5 best ones. Temperature is like how adventurous you are—if it's low, you always choose vanilla; if it's high, you might try a crazy new flavor.

Beam search expands multiple possible sequences at each step, keeping the most likely ones. Instead of picking the best word at each step, it looks at different possibilities and selects the best overall sentence.


Imagine you’re solving a maze, and you can explore multiple paths at once. Instead of choosing the first way that looks good, you try a few paths and pick the best one.



In [17]:



import torch.nn.functional as F

def beam_search_predict(sentence, max_words=10, temperature=1.2, top_k=5):
    model.eval()

    # Convert sentence to token indices
    tokens = [vocab[token] if token in vocab else vocab["<unk>"] for token in tokenizer(sentence)]
    x = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)

    #print(f"🔍 Input Tokens: {tokens}")  # Debugging
    #print(f"🔍 Initial Tensor Shape: {x.shape}")  # Debugging

    generated_tokens = set()  # To track repeated tokens

    with torch.no_grad():
        for _ in range(max_words):
            output = model(x)  # Get model predictions
            logits = output[:, -1, :]  # Take last word's logits

            # **Apply temperature scaling**
            logits = logits / temperature

            # **Apply Top-k Sampling**
            probs = F.softmax(logits, dim=-1)
            top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
            sampled_index = torch.multinomial(top_k_probs, 1).item()
            output_token = top_k_indices[0, sampled_index].item()

            # Debugging: Show predicted token
            #print(f"🔍 Predicted Token Index: {output_token}, Token: {vocab.lookup_token(output_token)}")

            # Prevent infinite loop: Stop if same token is repeated too much
            if output_token in generated_tokens:
                #print(f"⚠️ Stopping early: Model is repeating token '{vocab.lookup_token(output_token)}'")
                break
            generated_tokens.add(output_token)

            # Append the new token
            x = torch.cat([x, torch.tensor([[output_token]], dtype=torch.long).to(device)], dim=1)

    #print(f"🔍 Final Predicted Sequence Indices: {x.tolist()}")  # Debugging

    return " ".join([vocab.lookup_token(i) for i in x.squeeze().tolist()])

# Run Beam Search Prediction
# print(beam_search_predict("hello", temperature=1.2, top_k=5))
# print(beam_search_predict("attention", temperature=1.2, top_k=5))
# print(beam_search_predict("learning", temperature=1.2, top_k=5))


# print(beam_search_predict("hello", temperature=0.8, top_k=5))
# print(beam_search_predict("transformers", temperature=0.8, top_k=5))
# print(beam_search_predict("vision", temperature=0.8, top_k=5))
# print(beam_search_predict("attention", temperature=0.8, top_k=5))
# print(beam_search_predict("learning", temperature=0.8, top_k=5))


print(beam_search_predict("hello", temperature=0.8, top_k=10))
print(beam_search_predict("transformers", temperature=0.8, top_k=10))
print(beam_search_predict("vision", temperature=0.8, top_k=10))
print(beam_search_predict("attention", temperature=0.8, top_k=10))
print(beam_search_predict("learning", temperature=0.8, top_k=10))

<unk> process is
transformers for the
vision captures is
attention captures sequences
learning learning


# Enhance prediction

In [18]:
def beam_search_predict(sentence, max_words=8, temperature=0.8, top_k=5, top_p=0.85, penalty=1.5, beam_width=3, length_penalty=1.0):
    model.eval()

    # Convert sentence to token indices
    tokens = [vocab[token] if token in vocab else vocab["<unk>"] for token in tokenizer(sentence)]
    x = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)

    generated_tokens = set()  # Track repeated tokens
    sequences = [(x, 0)]  # Store (sequence, score)

    with torch.no_grad():
        for _ in range(max_words):
            all_candidates = []
            for seq, score in sequences:
                output = model(seq)

                # Apply temperature scaling
                logits = output[:, -1, :] / temperature
                probs = F.softmax(logits, dim=-1)

                # Apply top-k and top-p filtering
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

                # Apply top-p filtering (keep the most probable tokens that sum to top_p)
                top_p_mask = cumulative_probs <= top_p
                top_p_mask[:, 1:] = top_p_mask[:, :-1].clone()
                top_p_mask[:, 0] = True  # Always keep at least one token

                # Get valid token indices
                valid_indices = sorted_indices[top_p_mask]
                valid_probs = sorted_probs[top_p_mask]

                # Sample from filtered distribution
                sampled_index = valid_indices[torch.multinomial(valid_probs, 1)].item()

                # Apply repetition penalty
                if sampled_index in generated_tokens:
                    probs[:, sampled_index] /= penalty

                # Append the new token
                new_seq = torch.cat([seq, torch.tensor([[sampled_index]], dtype=torch.long).to(device)], dim=1)
                new_score = score - torch.log(probs[:, sampled_index]) / (len(new_seq) ** length_penalty)  # Normalize score

                all_candidates.append((new_seq, new_score))

            # Keep best `beam_width` sequences
            sequences = sorted(all_candidates, key=lambda x: x[1])[:beam_width]

            # Stop early if all beams generate `<unk>` or repeated tokens
            if all(seq[0][:, -1].item() == vocab["<unk>"] or seq[0][:, -1].item() in generated_tokens for seq in sequences):
                break

            # Track generated tokens
            generated_tokens.update([seq[0][:, -1].item() for seq in sequences])

    # Choose the best sequence
    best_sequence = sequences[0][0].squeeze().tolist()
    return " ".join([vocab.lookup_token(i) for i in best_sequence])

# 🔥 Test the improved function
print(beam_search_predict("hello", temperature=0.8, top_k=5))
print(beam_search_predict("transformers", temperature=0.8, top_k=5))
print(beam_search_predict("vision", temperature=0.8, top_k=5))
print(beam_search_predict("attention", temperature=0.8, top_k=5))
print(beam_search_predict("learning", temperature=0.8, top_k=5))


<unk> process is is
transformers for for
vision captures transformers is from on on
attention captures sequences sequences
learning learning learning


In [19]:
def beam_search_predict(sentence, max_words=6, temperature=0.8, top_k=5, top_p=0.85, penalty=2.0, beam_width=5, diversity_penalty=1.2):
    model.eval()

    # Convert sentence to token indices
    tokens = [vocab[token] if token in vocab else vocab["<unk>"] for token in tokenizer(sentence)]
    x = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)

    generated_tokens = set()  # Track repeated tokens
    sequences = [(x, 0)]  # Store (sequence, score)

    with torch.no_grad():
        for _ in range(max_words):
            all_candidates = []
            for seq, score in sequences:
                output = model(seq)

                # Apply temperature scaling
                logits = output[:, -1, :] / temperature
                probs = F.softmax(logits, dim=-1)

                # Apply top-k and top-p filtering
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

                # Apply top-p filtering
                top_p_mask = cumulative_probs <= top_p
                top_p_mask[:, 1:] = top_p_mask[:, :-1].clone()
                top_p_mask[:, 0] = True  # Always keep at least one token

                # Get valid token indices
                valid_indices = sorted_indices[top_p_mask]
                valid_probs = sorted_probs[top_p_mask]

                # Sample from filtered distribution
                sampled_index = valid_indices[torch.multinomial(valid_probs, 1)].item()

                # Apply repetition penalty
                if sampled_index in generated_tokens:
                    probs[:, sampled_index] /= penalty

                # Apply diversity penalty (promotes different outputs)
                probs /= (diversity_penalty ** len(seq))

                # Append the new token
                new_seq = torch.cat([seq, torch.tensor([[sampled_index]], dtype=torch.long).to(device)], dim=1)
                new_score = score - torch.log(probs[:, sampled_index])  # Normalize score

                all_candidates.append((new_seq, new_score))

            # Keep best `beam_width` sequences
            sequences = sorted(all_candidates, key=lambda x: x[1])[:beam_width]

            # Stop early if all beams generate `<unk>` or repeated tokens
            if all(seq[0][:, -1].item() == vocab["<unk>"] or seq[0][:, -1].item() in generated_tokens for seq in sequences):
                break

            # Track generated tokens
            generated_tokens.update([seq[0][:, -1].item() for seq in sequences])

    # Choose the best sequence
    best_sequence = sequences[0][0].squeeze().tolist()
    return " ".join([vocab.lookup_token(i) for i in best_sequence])

# 🔥 Test the improved function
print(beam_search_predict("hello", temperature=0.8, top_k=5))
print(beam_search_predict("transformers", temperature=0.8, top_k=5))
print(beam_search_predict("vision", temperature=0.8, top_k=5))
print(beam_search_predict("attention", temperature=0.8, top_k=5))
print(beam_search_predict("learning", temperature=0.8, top_k=5))


<unk> process is is
transformers for the for
vision captures captures
attention captures sequences sequences
learning learning learning
