<a href="https://colab.research.google.com/github/anroyus/score/blob/master/Attention_Is_All_You_Need_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
# Core PyTorch & TorchText, pinned as a mutually compatible pair.
# NOTE(review): restart the runtime after this cell so the pinned versions are the ones imported.
!pip install torch==2.0.1 torchtext==0.15.2 -U



In [5]:
pip freeze torch

absl-py==1.4.0
accelerate==1.8.1
aiofiles==24.1.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.15
aiosignal==1.4.0
alabaster==1.0.0
albucore==0.0.24
albumentations==2.0.8
ale-py==0.11.1
altair==5.5.0
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.9.0
argon2-cffi==25.1.0
argon2-cffi-bindings==21.2.0
array_record==0.7.2
arviz==0.21.0
astropy==7.1.0
astropy-iers-data==0.2025.7.7.0.39.39
astunparse==1.6.3
atpublic==5.1
attrs==25.3.0
audioread==3.0.1
autograd==1.8.0
babel==2.17.0
backcall==0.2.0
backports.tarfile==1.2.0
beautifulsoup4==4.13.4
betterproto==2.0.0b6
bigframes==2.8.0
bigquery-magics==0.10.1
bleach==6.2.0
blinker==1.9.0
blis==1.3.0
blobfile==3.0.0
blosc2==3.5.1
bokeh==3.7.3
Bottleneck==1.4.2
bqplot==0.12.45
branca==0.8.1
build==1.2.2.post1
CacheControl==0.14.3
cachetools==5.5.2
catalogue==2.0.10
certifi==2025.7.9
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.4.2
chex==0.1.89
clarabel==0.11.1
click==8.2.1
cloudpathlib==0.21.1
cloudpickle==3.1.1
cmake==3.31.6
cmdstanpy

In [10]:
"""
Full Transformer Data Pipeline

1	Load and tokenize Hindi-English
2	Build vocabularies
3	Encode tokens into tensors
4	Create dataset + __getitem__
5	collate_fn for dynamic padding
6	Positional Encoding
7	Define Transformer model
8	Training loop
"""

'\nFull Transformer Data Pipeline\n\n1\tLoad and tokenize Hindi-English\n2\tBuild vocabularies\n3\tEncode tokens into tensors\n4\tCreate dataset + __getitem__\n5\tcollate_fn for dynamic padding\n6\tPositional Encoding\n7\tDefine Transformer model\n8\tTraining loop\n'

In [11]:
# Hugging Face `datasets`, used to load the IITB Hindi-English corpus.
# NOTE(review): unpinned; consider pinning a version for reproducible runs.
!pip install -q datasets

In [12]:
# Pin fsspec: newer releases (2025.x) caused local caching errors when loading
# Hugging Face datasets in this environment, so an older version is installed.
!pip install fsspec==2023.9.2



In [2]:
# Step 1: Load the dataset.
# IITB English-Hindi parallel corpus (train split). Each row looks like
# {'translation': {'en': ..., 'hi': ...}}; see the sample printed in the next cell.
from datasets import load_dataset

raw_dataset = load_dataset("cfilt/iitb-english-hindi", split="train")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [6]:
# Sanity check: show the first example and the dataset schema.
first_example = raw_dataset[0]
print(first_example)
print(raw_dataset.features)

{'translation': {'en': 'Give your application an accessibility workout', 'hi': 'अपने अनुप्रयोग को पहुंचनीयता व्यायाम का लाभ दें'}}
{'translation': {'en': Value(dtype='string', id=None), 'hi': Value(dtype='string', id=None)}}


In [7]:
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

from datasets import load_dataset
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [8]:
import string

def hindi_tokenizer(text):
    """Whitespace tokenizer for Hindi text.

    Splits on whitespace and strips the Devanagari danda ("।") from both ends of
    every token. Tokens that become empty after stripping (e.g. a free-standing
    "।" separated by a space) are dropped, so empty strings never reach the vocabulary.

    Args:
        text: A Hindi sentence.

    Returns:
        List of non-empty token strings.
    """
    stripped = (word.strip("।") for word in text.split())
    return [token for token in stripped if token]

HI_TOKENIZER = hindi_tokenizer

In [9]:
# Quick tokenizer check; expected: ['भारत', 'एक', 'सुंदर', 'देश', 'है']
sample_tokens = hindi_tokenizer("भारत एक सुंदर देश है।")
print(sample_tokens)


['भारत', 'एक', 'सुंदर', 'देश', 'है']


In [11]:
# Spot-check: print early Hindi sentences that contain 'भारत'.
# The scan is bounded by the dataset length so a smaller split cannot raise IndexError.
for i in range(min(9999, len(raw_dataset))):
    text = raw_dataset[i]["translation"]["hi"]
    if 'भारत' in text:
        print(f"{i}: {text}")


In [14]:
# Step 2: Define tokenizers.
# English uses torchtext's "basic_english" tokenizer. Hindi uses the custom
# whitespace tokenizer `hindi_tokenizer`, bound to HI_TOKENIZER in an earlier cell.

from torchtext.data.utils import get_tokenizer

# English tokenizer
EN_TOKENIZER = get_tokenizer("basic_english")


In [15]:
# Step 3: Yield token lists for vocabulary building.
def yield_tokens(data_iter, tokenizer, lang):
    """Lazily tokenize one side of a parallel corpus.

    Args:
        data_iter: Iterable of examples shaped like {'translation': {lang: text, ...}}.
        tokenizer: Callable mapping a string to a list of tokens.
        lang: Key of the language to tokenize (e.g. "hi" or "en").

    Yields:
        One token list per example, in iteration order.
    """
    yield from (tokenizer(row["translation"][lang]) for row in data_iter)



In [17]:
# Step 4: Build vocabularies.
# The special tokens are listed first, so they occupy indices 0-3 in both
# vocabularies; any out-of-vocabulary word falls back to <unk>.
from torchtext.vocab import build_vocab_from_iterator

specials = ['<unk>', '<pad>', '<bos>', '<eos>']


def _build_lang_vocab(tokenizer, lang):
    """Build a vocabulary of at most 25k tokens for one side of `raw_dataset`, defaulting to <unk>."""
    lang_vocab = build_vocab_from_iterator(
        yield_tokens(raw_dataset, tokenizer, lang),
        specials=specials,
        max_tokens=25000,
    )
    lang_vocab.set_default_index(lang_vocab['<unk>'])
    return lang_vocab


HI_VOCAB = _build_lang_vocab(HI_TOKENIZER, "hi")
EN_VOCAB = _build_lang_vocab(EN_TOKENIZER, "en")



In [18]:
# Inspect the first 1000 entries of the Hindi vocabulary (specials first).
hi_itos = HI_VOCAB.get_itos()
print(hi_itos[:1000])

['<unk>', '<pad>', '<bos>', '<eos>', 'के', 'और', 'में', 'है', 'की', 'से', 'को', 'का', 'कि', 'पर', 'लिए', 'एक', 'हैं', 'तो', 'नहीं', 'जो', 'भी', 'ने', 'यह', 'हो', 'किया', 'ही', 'कर', 'इस', 'या', 'वह', 'करने', 'है,', 'अपने', 'न', 'तुम', 'वे', 'तथा', 'कुछ', 'किसी', 'गया', 'था', 'कोई', 'हम', 'उनके', 'साथ', 'रूप', 'द्वारा', 'है.', 'करते', 'लोगों', 'जब', 'लोग', 'दिया', 'भारत', 'तक', 'ये', 'जाता', 'फिर', 'उसके', 'थे', 'क्या', 'अल्लाह', 'ख़ुदा', 'अपनी', 'उस', 'जा', 'करना', 'रहे', 'उन्हें', 'करता', 'मैं', 'उन', 'होता', 'उसे', 'रहा', 'हुए', 'करें', 'आप', 'सकता', 'बहुत', 'समय', 'उन्होंने', 'कहा', 'अधिक', 'वाले', 'बात', 'दो', 'हमने', 'तरह', 'पहले', 'उनकी', 'तुम्हारे', 'यदि', 'पास', 'बाद', 'हुआ', 'थी', 'सभी', 'ऐसे', 'हमारे', 'ओर', 'होने', 'हैं,', 'गए', 'दिन', 'गई', 'अन्य', 'इसके', 'काम', 'सकते', 'प्रकार', 'उसकी', 'प्राप्त', 'दी', 'किए', 'होती', 'इन', 'एवं', 'ऐसा', 'कार्य', 'जाने', 'राष्ट्रपति', 'क्षेत्र', 'अगर', 'होगा', 'भारतीय', '(_', 'मुझे', 'सरकार', 'चाहिए', 'उसने', 'कारण', 'अब', 'देश', 'विकास',

In [19]:
# Membership check: is 'भारत' in the Hindi vocabulary?
hi_vocab_tokens = HI_VOCAB.get_itos()
print('भारत' in hi_vocab_tokens)

True


In [20]:
# Locate the vocabulary index (or indices) holding 'भारत'.
bharat_positions = [pos for pos, tok in enumerate(HI_VOCAB.get_itos()) if tok == 'भारत']
for idx in bharat_positions:
    print(f"'भारत' found at index {idx}")

'भारत' found at index 53


In [21]:
# Persist both vocabularies so later sessions can reload them without re-scanning the corpus.
import pickle

for vocab_path, vocab_obj in (("hi_vocab.pkl", HI_VOCAB), ("en_vocab.pkl", EN_VOCAB)):
    with open(vocab_path, "wb") as f:
        pickle.dump(vocab_obj, f)


In [26]:
# Not relevant for the current Hindi-to-English translation; kept as a spaCy experiment.
# Installs spaCy plus its English (en_core_web_sm) and German (de_core_news_sm)
# pipelines, then loads both.
# NOTE: the download output asks for a runtime restart before newly installed packages load reliably.
!pip install spacy --quiet
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm

import spacy
spacy_de = spacy.load("de_core_news_sm")
spacy_en = spacy.load("en_core_web_sm")

Collecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m61.5 MB/s[0m eta [36m0:00:00[0m
[?25h[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.
Collecting de-core-news-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.8.0/de_core_news_sm-3.8.0-py3-none-any.whl (14.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.6/14.6 MB[0m [31m68.0 MB/s[0m eta [36m0:00:00[0m
[?25h[38;5;2m✔ Download and insta

In [27]:
# Not linked to Hindi-English: a German ("de", de_core_news_sm) spaCy tokenizer,
# used here for experimentation only.

def tokenize_de(text):
    # Tokenize with spaCy's German tokenizer and lowercase each token's text.
    return [tok.text.lower() for tok in spacy_de.tokenizer(text)]

# NOTE(review): these sentences are English. t1/t2/t3 run them through the German
# tokenizer, while t1_d runs the full English pipeline (inspected in the next cell).
t1= tokenize_de("Cow is a sacred animal in hindu tradition")
t1_d = spacy_en("Cow is a sacred animal in hindu tradition")
t2 = tokenize_de("Muslims eat Cow meat")
t3 = tokenize_de("Cow milk is popular")

In [28]:
'''
token.text → the word itself
token.pos_ → part of speech (noun, verb, etc.)
token.dep_ → syntactic relation (subject, object, etc.)
token.lemma_ → base form (e.g., "running" → "run")
doc.ents → named entities (like "Amazon", "Las Vegas")
doc.sents → sentences (if multiple)
'''

# One aligned row of annotations per token of the English doc t1_d.
for token in t1_d:
    row = f"{token.text:15} | POS: {token.pos_:10} | DEP: {token.dep_:10} | Lemma: {token.lemma_}"
    print(row)


Cow             | POS: NOUN       | DEP: nsubj      | Lemma: cow
is              | POS: AUX        | DEP: ROOT       | Lemma: be
a               | POS: DET        | DEP: det        | Lemma: a
sacred          | POS: ADJ        | DEP: amod       | Lemma: sacred
animal          | POS: NOUN       | DEP: attr       | Lemma: animal
in              | POS: ADP        | DEP: prep       | Lemma: in
hindu           | POS: NOUN       | DEP: compound   | Lemma: hindu
tradition       | POS: NOUN       | DEP: pobj       | Lemma: tradition


In [22]:
# Sanity-check vocabulary sizes and the first entries of each vocabulary.
vocab_by_label = (("Hindi", HI_VOCAB), ("English", EN_VOCAB))
for label, lang_vocab in vocab_by_label:
    print(f"{label} vocab size: {len(lang_vocab)}")

for label, lang_vocab in vocab_by_label:
    print(f"Sample {label} tokens:", lang_vocab.get_itos()[:10])

Hindi vocab size: 25000
English vocab size: 25000
Sample Hindi tokens: ['<unk>', '<pad>', '<bos>', '<eos>', 'के', 'और', 'में', 'है', 'की', 'से']
Sample English tokens: ['<unk>', '<pad>', '<bos>', '<eos>', 'the', '.', ',', 'of', 'and', 'to']


In [23]:
# Round-trip a sample sentence: tokenize it, then look up each token's index.
test_sent = "भारत एक सुंदर देश है"
tokens = HI_TOKENIZER(test_sent)
print("Tokens:", tokens)
indices = [HI_VOCAB[tok] for tok in tokens]
print("Indices:", indices)

Tokens: ['भारत', 'एक', 'सुंदर', 'देश', 'है']
Indices: [53, 15, 1773, 133, 7]


In [24]:
# Confirm a few common words survived the 25k vocabulary cut, then show the head of the vocab.
hi_itos = HI_VOCAB.get_itos()
for word in ('भारत', 'सुंदर', 'देश'):
    print(word in hi_itos)
print(hi_itos[:50])

True
True
True
['<unk>', '<pad>', '<bos>', '<eos>', 'के', 'और', 'में', 'है', 'की', 'से', 'को', 'का', 'कि', 'पर', 'लिए', 'एक', 'हैं', 'तो', 'नहीं', 'जो', 'भी', 'ने', 'यह', 'हो', 'किया', 'ही', 'कर', 'इस', 'या', 'वह', 'करने', 'है,', 'अपने', 'न', 'तुम', 'वे', 'तथा', 'कुछ', 'किसी', 'गया', 'था', 'कोई', 'हम', 'उनके', 'साथ', 'रूप', 'द्वारा', 'है.', 'करते', 'लोगों']


In [25]:
## Step 5: Encode sentences with <bos> and <eos>.
# encode() tokenizes a sentence, maps every token to its vocabulary index and
# wraps the sequence in <bos> ... <eos>. It returns a plain list of ints; callers
# (the Dataset below) convert it to a LongTensor.

# Special-token indices, read from the English (target) vocabulary.
BOS_IDX, EOS_IDX, PAD_IDX = (EN_VOCAB[special] for special in ('<bos>', '<eos>', '<pad>'))


def encode(text, tokenizer, vocab):
    """Return [<bos>, *token_ids, <eos>] for `text`, using `tokenizer` and `vocab`."""
    token_ids = [vocab[tok] for tok in tokenizer(text)]
    return [vocab['<bos>'], *token_ids, vocab['<eos>']]




In [26]:
## Step 6: Build a PyTorch Dataset wrapper returning (src_tensor, tgt_tensor) pairs.
# NOTE: a later cell redefines TranslationDataset with a different constructor
# signature; that later definition is the one used to build `dataset`.
class TranslationDataset(torch.utils.data.Dataset):
    """Map-style dataset of (source, target) index tensors.

    Args:
        data: Indexable examples shaped like {'translation': {lang: text, ...}}.
        src_lang / tgt_lang: Keys selecting the source and target text.
        src_tokenizer / tgt_tokenizer: Callables mapping text to tokens.
        src_vocab / tgt_vocab: Token -> index lookups used by `encode`.
    """

    def __init__(self, data, src_lang, tgt_lang, src_tokenizer, tgt_tokenizer, src_vocab, tgt_vocab):
        self.data = data
        self.src_lang, self.tgt_lang = src_lang, tgt_lang
        self.src_tokenizer, self.tgt_tokenizer = src_tokenizer, tgt_tokenizer
        self.src_vocab, self.tgt_vocab = src_vocab, tgt_vocab

    def __len__(self):
        return len(self.data)

    def _to_tensor(self, text, tokenizer, vocab):
        # Encode with <bos>/<eos> via the module-level encode() and wrap as a LongTensor.
        return torch.tensor(encode(text, tokenizer, vocab), dtype=torch.long)

    def __getitem__(self, idx):
        pair = self.data[idx]['translation']
        source_text, target_text = pair[self.src_lang], pair[self.tgt_lang]
        return (
            self._to_tensor(source_text, self.src_tokenizer, self.src_vocab),
            self._to_tensor(target_text, self.tgt_tokenizer, self.tgt_vocab),
        )


# Step 7: Creating the dataset class
A custom `TranslationDataset` returns (input, output) pairs.
Transformers require:
 - an input sequence (e.g., a Hindi sentence),
 - a target sequence (e.g., its English translation),
 - both converted into tensors,
 - with special tokens added: `<bos>` at the start and `<eos>` at the end.

So, for each example, we:
 1. tokenize each sentence,
 2. encode each token as an index from the vocabulary,
 3. wrap the result in a PyTorch-friendly format (tensors).


In [27]:
from torch.utils.data import Dataset
import torch

class TranslationDataset(Dataset):
    """Map-style dataset of (source, target) index tensors for translation.

    Each example of `data` is expected to look like {'translation': {lang: text, ...}}
    (the Hugging Face IITB layout). The text is lowercased and stripped, tokenized,
    mapped through the vocabulary, and wrapped in <bos> ... <eos>.

    Args:
        data: Indexable collection of examples.
        src_tokenizer / tgt_tokenizer: Callables mapping a string to a list of tokens.
        src_vocab / tgt_vocab: Lookups supporting `vocab[token]`, containing
            '<bos>' and '<eos>'.
        src_lang / tgt_lang: Keys into example['translation']. They default to
            Hindi -> English, the direction previously hard-coded here.
    """

    def __init__(self, data, src_tokenizer, tgt_tokenizer, src_vocab, tgt_vocab,
                 src_lang='hi', tgt_lang='en'):
        self.data = data
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

    def __len__(self):
        return len(self.data)

    def encode(self, text, tokenizer, vocab):
        """Return [<bos>, *token_ids, <eos>] for `text` after lowercasing and stripping it."""
        tokens = tokenizer(text.lower().strip())
        return [vocab['<bos>']] + [vocab[token] for token in tokens] + [vocab['<eos>']]

    def __getitem__(self, idx):
        example = self.data[idx]
        src_text = example['translation'][self.src_lang]
        tgt_text = example['translation'][self.tgt_lang]

        src_tensor = torch.tensor(self.encode(src_text, self.src_tokenizer, self.src_vocab), dtype=torch.long)
        tgt_tensor = torch.tensor(self.encode(tgt_text, self.tgt_tokenizer, self.tgt_vocab), dtype=torch.long)

        return src_tensor, tgt_tensor


In [28]:
# Smoke test: build the dataset over the full corpus and inspect the first pair.
dataset = TranslationDataset(
    data=raw_dataset,
    src_vocab=HI_VOCAB,
    tgt_vocab=EN_VOCAB,
    src_tokenizer=HI_TOKENIZER,
    tgt_tokenizer=EN_TOKENIZER,
)

first_pair = dataset[0]
src_tensor, tgt_tensor = first_pair
print(src_tensor)  # [<bos>, id1, id2, ..., <eos>]
print(tgt_tensor)  # [<bos>, id1, id2, ..., <eos>]


tensor([    2,    32,  1223,    10, 22171,  5950,    11,   368,   538,     3])
tensor([   2,  173,   50,  424,   45, 6507,    0,    3])


In [29]:
# Step 6: Collate function and batching (handling variable-length sequences with padding)
"""
Problem Statement:
Sentences have different lengths.

We need batches of uniform shape (e.g., [batch_size, max_len]).

Solution:
The collate function:

Pads each sentence in the batch to the length of the longest sentence.

Returns padded tensors + lengths (optional).

Also creates masks (for padding).
A visual interpretation would be that, we are shaping all sequences in the batch into a same size rectangle.


e.g. Unequal length token sequences below
["<bos>", 12, 18, 92, "<eos>"]. --> length : 5
["<bos>", 45, 63, "<eos>"]  --> length : 4
["<bos>", 89, 12, 77, 23, "<eos>"]  --> length : 6

After Padding:

[
 [ <bos>, 12, 18, 92, <eos>, <pad> ],  --> length : 6
 [ <bos>, 45, 63, <eos>, <pad>, <pad> ],  --> length : 6
 [ <bos>, 89, 12, 77, 23, <eos> ].  --> length : 6
]

DURING TRAINING :
Notes: We should use dynamic padding within each batch
So if a batch has lengths: [6, 12, 10], we pad to 12 for that batch only
This is memory-efficient and doesn't require hardcoding a global max length


DURING INFERENCE :
At Inference Time or in Model Definition:
This is where max sequence length matters.

Transformer models must know:
 - How many positions to encode (in positional encodings)
 - What shape to expect during attention computation
 - So here’s what we typically do:
        Choose a Safe Maximum Sequence Length
Strategy	Max Length:
Observe our training set.
Take the 95th percentile length — e.g., 60 tokens
Add a buffer	E.g., use max_len = 128 even if your training set only goes up to 60
Very long sequences	Use 512 or 1024 (like BERT), but that increases memory cost

MAX_SEQ_LEN = 128  # Set by you

positional_encoding = PositionalEncoding(
    d_model=EMBED_DIM,
    max_len=MAX_SEQ_LEN
)

Future-Proofing Tips:
Train with dynamic padding, but define positional encodings with a large enough max_len (e.g., 256 or 512)

Monitor real-world inference traffic — log token lengths

For production: set MAX_LEN = 2x the typical input length



"""

# Take a batch of (src_tensor, tgt_tensor) pairs, pad each side to the longest
# sequence in that batch (dynamic padding), and return the two padded tensors.
from torch.nn.utils.rnn import pad_sequence

SRC_PAD_IDX = HI_VOCAB['<pad>']
TGT_PAD_IDX = EN_VOCAB['<pad>']


def collate_fn(batch):
    """Pad a list of (src, tgt) 1-D tensors into two [batch, max_len] tensors.

    Source sequences are padded with SRC_PAD_IDX and targets with TGT_PAD_IDX.
    Only the padded tensors are returned; attention masks are derived later
    from these pad indices.
    """
    sources, targets = zip(*batch)
    padded_src = pad_sequence(sources, batch_first=True, padding_value=SRC_PAD_IDX)
    padded_tgt = pad_sequence(targets, batch_first=True, padding_value=TGT_PAD_IDX)
    return padded_src, padded_tgt


In [30]:
# DataLoader: shuffle, batch 32 examples, and pad each batch with collate_fn.
from torch.utils.data import DataLoader

train_loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

# Pull a single batch to check the padded shapes.
batch_iter = iter(train_loader)
src_batch, tgt_batch = next(batch_iter)
print(src_batch.shape)  # [batch_size, max_src_len]
print(tgt_batch.shape)  # [batch_size, max_tgt_len]


torch.Size([32, 61])
torch.Size([32, 59])


In [31]:
#Step 7: Positional encoding
#Inject order information into embeddings
# Inject position awareness into the model
# Use either sinusoidal encoding (from the paper ) or learned encoding
# Paper Section 3.5
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding ("Attention Is All You Need", Section 3.5).

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    The table is precomputed for `max_len` positions and registered as a buffer.
    It therefore moves with the module across devices and is not a learnable parameter.

    Args:
        d_model: Embedding dimension. Odd values are supported.
        max_len: Longest sequence length that can be encoded.
        dropout: Dropout applied to (embeddings + encodings), as the paper does.
            Defaults to 0.0, i.e. no dropout (the previous behaviour).
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.0):
        super().__init__()

        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )  # [ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # even columns
        # With an odd d_model there is one fewer odd column than frequencies, so trim.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])  # odd columns

        # Shape [1, max_len, d_model] so it broadcasts over the batch dimension.
        self.register_buffer("pe", pe.unsqueeze(0))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to `x` of shape [batch_size, seq_len, d_model]."""
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise RuntimeError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(1)}"
            )
        return self.dropout(x + self.pe[:, :seq_len, :])



In [32]:
# example use
# Let's say you have a batch of embeddings (e.g., from token IDs → embedding layer):
"""
embed = nn.Embedding(len(HI_VOCAB), d_model)
pos_enc = PositionalEncoding(d_model=d_model)

x = embed(src_batch)           # shape: [batch_size, seq_len, d_model]
x = pos_enc(x)                 # positionally encoded embeddings

"""
# register_buffer ensures the position matrix is stored with the model (and moves to CUDA if needed) but doesn’t get updated during backprop.
# max_len=5000 is typical; you can increase this if your sentences go longer.
# This is purely additive to embeddings.

'\nembed = nn.Embedding(len(HI_VOCAB), d_model)\npos_enc = PositionalEncoding(d_model=d_model)\n\nx = embed(src_batch)           # shape: [batch_size, seq_len, d_model]\nx = pos_enc(x)                 # positionally encoded embeddings\n\n'

In [40]:
# Step 8: Transformer model architecture
# Encoder, decoder, multi-head attention, masking
#
# | Component                                    | Description                             |
# | -------------------------------------------- | --------------------------------------- |
# | 1. `Embedding` + `PositionalEncoding`        | For both source and target              |
# | 2. `Transformer` from `torch.nn.Transformer` | Main encoder-decoder logic              |
# | 3. Final `Linear` layer                      | To project decoder output to vocab size |


In [33]:
"""
src_tok_emb and tgt_tok_emb: Token embeddings
positional_encoding: Injects order into the embeddings
transformer: Applies multi-head attention & encoder-decoder logic
generator: Final linear layer → logits over target vocab


batch_first=True makes tensor shapes [batch, seq, feature], which matches your padded batches.

We still need to define:
src_mask and tgt_mask (for causal masking in decoder)
Padding masks (to ignore <pad> tokens)
"""

import torch.nn as nn

class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder Transformer for sequence-to-sequence translation.

    Pipeline: token embedding + positional encoding for each side ->
    nn.Transformer (batch_first=True) -> linear projection to target-vocabulary logits.

    Args:
        num_encoder_layers / num_decoder_layers: Depth of each stack.
        emb_size: Model width, shared by the embeddings and the transformer.
        nhead: Number of attention heads.
        src_vocab_size / tgt_vocab_size: Vocabulary sizes.
        dim_feedforward: Hidden size of each feed-forward sub-layer.
        dropout: Dropout used inside nn.Transformer.
        max_len: Longest sequence the positional encoding supports.
    """

    def __init__(self, num_encoder_layers: int, num_decoder_layers: int, emb_size: int,
                 nhead: int, src_vocab_size: int, tgt_vocab_size: int,
                 dim_feedforward: int = 512, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()

        # Token embeddings for each language.
        self.src_tok_emb = nn.Embedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = nn.Embedding(tgt_vocab_size, emb_size)

        # Shared positional encoding, added to both embeddings.
        self.positional_encoding = PositionalEncoding(emb_size, max_len)

        # Encoder-decoder stack; inputs are [batch, seq, feature].
        transformer_kwargs = dict(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.Transformer(**transformer_kwargs)

        # Projection from decoder states to target-vocabulary logits.
        self.generator = nn.Linear(emb_size, tgt_vocab_size)

    def forward(self, src, tgt, src_mask, tgt_mask,
                src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
        """Return logits of shape [batch, tgt_len, tgt_vocab_size] for `src`/`tgt` id tensors of shape [batch, len]."""
        embed = self.positional_encoding
        src_emb = embed(self.src_tok_emb(src))
        tgt_emb = embed(self.tgt_tok_emb(tgt))

        decoded = self.transformer(
            src_emb,
            tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.generator(decoded)


In [34]:
# Step 9: Loss function and masks.
# Padding masks plus the look-ahead (causal) mask used in decoder self-attention:
# at step t the decoder may only attend to target tokens at positions <= t.

def generate_square_subsequent_mask(sz):
    """Return a [sz, sz] bool causal mask.

    True marks positions that may NOT be attended (PyTorch's bool-mask convention):
    the strictly upper triangle, i.e. every future position.
    """
    return torch.ones(sz, sz, dtype=torch.bool).triu(diagonal=1)


# Padding mask
"""
The padding mask (built in `create_masks`) is a bool tensor of shape
[batch_size, seq_len] that is:

- True  for <pad> tokens (positions attention must ignore)
- False for real tokens

It is passed as a *_key_padding_mask argument and broadcast inside attention.

We need this for:

 - encoder self-attention
 - decoder cross-attention
 - sometimes decoder self-attention (combined with look-ahead)

"""


'\nThis creates a mask that’s:\n\n- 1 for real tokens\n- 0 for <pad> tokens\n\nCan be broadcasted in attention computations\n\nWe need this for:\n\n - encoder self-attention\n - decoder cross-attention\n - sometimes decoder self-attention (combined with look-ahead)\n\n'

In [35]:
# Hyperparameters and model instantiation.
# NOTE(review): the paper's base model uses a 2048-wide feed-forward layer; 512 is used here.
SRC_VOCAB_SIZE = len(HI_VOCAB)
TGT_VOCAB_SIZE = len(EN_VOCAB)
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
NUM_ENCODER_LAYERS = 6
NUM_DECODER_LAYERS = 6

model = Seq2SeqTransformer(
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    emb_size=EMB_SIZE,
    nhead=NHEAD,
    src_vocab_size=SRC_VOCAB_SIZE,
    tgt_vocab_size=TGT_VOCAB_SIZE,
    dim_feedforward=FFN_HID_DIM,
).to(device)



# # Example vocab: Syntheic Data
# # <pad>: 0, <bos>: 1, <eos>: 2, I: 3, love: 4, cats: 5, dogs: 6

# pad_idx = 0

# # src: "I love dogs <pad>"
# # tgt: "<bos> I love cats"

# src = torch.tensor([[3, 4, 6, 0]], device=device)  # e.g., "I love dogs <pad>"
# tgt = torch.tensor([[1, 3, 4, 5]], device=device)  # e.g., "<bos> I love cats"

# pad_idx = HI_VOCAB['<pad>']

# # Padding masks (shape: [batch_size, seq_len])
# src_padding_mask = (src == pad_idx)
# tgt_padding_mask = (tgt == pad_idx)

# # Look-ahead mask (shape: [tgt_seq_len, tgt_seq_len])
# tgt_mask = generate_square_subsequent_mask(tgt.size(1)).to(device)

# # No src mask typically, so pass None or a zeros mask
# src_mask = None

# # memory_key_padding_mask = same as src_padding_mask
# output = model(
#     src,
#     tgt,
#     src_mask=src_mask,
#     tgt_mask=tgt_mask,
#     src_padding_mask=src_padding_mask,
#     tgt_padding_mask=tgt_padding_mask,
#     memory_key_padding_mask=src_padding_mask
# )


In [36]:
# Put the model in evaluation mode; the returned module is displayed as the cell output.
# train_epoch() calls model.train() again before training.
model.eval()

Seq2SeqTransformer(
  (src_tok_emb): Embedding(25000, 512)
  (tgt_tok_emb): Embedding(25000, 512)
  (positional_encoding): PositionalEncoding()
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=512, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): Transf

In [37]:
# # Test mask
# generate_square_subsequent_mask(4)

# # Expected output (bool mask; True = blocked future position):
# # tensor([[False,  True,  True,  True],
# #         [False, False,  True,  True],
# #         [False, False, False,  True],
# #         [False, False, False, False]])

In [46]:
#Step 10: Training loop
#Batching, optimizer, loss, learning rate schedule
"""
nn.CrossEntropyLoss(ignore_index=PAD_IDX)
Optimizer (Adam)
Learning rate scheduler (optional)
Forward + loss computation
Backward + optimizer step
Logging every N batches
"""

'\nnn.CrossEntropyLoss(ignore_index=PAD_IDX)\nOptimizer (Adam)\nLearning rate scheduler (optional)\nForward + loss computation\nBackward + optimizer step\nLogging every N batches\n'

In [38]:
# Loss function: padded target positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=TGT_PAD_IDX)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


# Mask generators
def create_masks(src, tgt_input):
    """Build the attention masks for one training batch.

    Args:
        src: [batch_size, src_len] source token ids.
        tgt_input: [batch_size, tgt_len] decoder input ids (target without its last token).

    Returns:
        (src_mask, tgt_mask, src_padding_mask, tgt_padding_mask):
        - src_mask: None (encoder self-attention is unrestricted).
        - tgt_mask: [tgt_len, tgt_len] bool causal mask, on tgt_input's device.
        - src_padding_mask / tgt_padding_mask: bool tensors, True at <pad> positions.
    """
    src_mask = None
    tgt_mask = generate_square_subsequent_mask(tgt_input.size(1)).to(tgt_input.device)

    src_padding_mask = src == SRC_PAD_IDX
    tgt_padding_mask = tgt_input == TGT_PAD_IDX
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask


In [39]:
def train_epoch(model, data_loader, optimizer, criterion, clip_grad_norm=None):
    """Run one training epoch with teacher forcing.

    For each batch the decoder receives tgt[:, :-1] and is trained to predict
    tgt[:, 1:] (next-token prediction).

    Args:
        model: Model with the Seq2SeqTransformer forward signature.
        data_loader: Yields (src, tgt) padded id tensors of shape [batch, len].
        optimizer: Optimizer over model.parameters().
        criterion: Token-level loss applied to the flattened logits and targets.
        clip_grad_norm: If given, clip the global gradient norm to this value
            before each optimizer step. The default None keeps the previous behaviour.

    Returns:
        Mean per-batch loss for the epoch, or float('nan') if the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for src, tgt in data_loader:
        src = src.to(device)
        tgt = tgt.to(device)

        tgt_input = tgt[:, :-1]  # decoder input: everything but the last token
        tgt_out = tgt[:, 1:]     # prediction target: everything but <bos>

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_masks(src, tgt_input)

        logits = model(
            src, tgt_input,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_padding_mask=src_padding_mask,
            tgt_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=src_padding_mask,
        )

        # Flatten to [batch * tgt_len, vocab] vs [batch * tgt_len].
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        if clip_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches


In [40]:
# Save model and optimizer state dicts
def save_checkpoint(model, optimizer, epoch, path="checkpoint.pth"):
    """Write the epoch number and model/optimizer state dicts to `path` with torch.save."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, path)
    print(f"Checkpoint saved at epoch {epoch} to {path}")


# Load model and optimizer state dicts
def load_checkpoint(model, optimizer, path="checkpoint.pth"):
    """Restore state written by save_checkpoint and return the epoch to resume from.

    Tensors are mapped onto the notebook's `device`. torch.load unpickles the
    file, so only load checkpoints you trust.
    """
    state = torch.load(path, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    resume_epoch = state['epoch'] + 1
    print(f"Checkpoint loaded. Resuming from epoch {resume_epoch}")
    return resume_epoch


In [None]:
EPOCHS = 10
CHECKPOINT_PATH = "last_checkpoint.pth"

for epoch in range(1, EPOCHS + 1):
    epoch_start = time.time()
    train_loss = train_epoch(model, train_loader, optimizer, criterion)
    elapsed = time.time() - epoch_start
    print(f"Epoch {epoch}, Train Loss: {train_loss:.4f} ({elapsed:.1f}s)")

    # Checkpoint after every epoch so a dropped Colab session loses at most one
    # epoch of work; the file always holds the latest completed epoch.
    save_checkpoint(model, optimizer, epoch, path=CHECKPOINT_PATH)

print("Training complete. Checkpoint saved.")


In [None]:
import torch
from google.colab import files

# After training finishes: write the final checkpoint with the shared helper
# (instead of duplicating torch.save) and download it.
checkpoint_path = "last_checkpoint.pth"
save_checkpoint(model, optimizer, EPOCHS, path=checkpoint_path)

# Download to local machine
files.download(checkpoint_path)


In [None]:
# Greedy decoding:
# Starts with <bos>
# Predicts one token at a time
# Feeds each predicted token back into the decoder until it generates <eos> or hits max length


def greedy_decode(model, src, src_mask, max_len, start_symbol, src_padding_mask):
    """Greedy autoregressive decoding for a single source sentence.

    Encodes `src` once. Starting from `start_symbol`, it then repeatedly appends the
    arg-max next token until EOS_IDX is produced or the output reaches `max_len`
    tokens (including the start symbol).

    Args:
        model: Seq2SeqTransformer; its embeddings, positional encoding, encoder,
            decoder and generator are used directly.
        src: [1, src_len] source token ids (batch size must be 1).
        src_mask: Unused; accepted for call-site compatibility.
        max_len: Maximum output length, including the start symbol.
        start_symbol: Index of <bos> in the target vocabulary.
        src_padding_mask: [1, src_len] bool mask (True at <pad>), or None.

    Returns:
        [1, out_len] LongTensor of target token ids on `device`.
    """
    model.eval()

    src = src.to(device)
    if src_padding_mask is not None:
        src_padding_mask = src_padding_mask.to(device)

    with torch.no_grad():
        memory = model.transformer.encoder(
            model.positional_encoding(model.src_tok_emb(src)),
            src_key_padding_mask=src_padding_mask,
        )

        ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=device)

        for _ in range(max_len - 1):
            tgt_mask = generate_square_subsequent_mask(ys.size(1)).to(device)

            out = model.transformer.decoder(
                model.positional_encoding(model.tgt_tok_emb(ys)),
                memory,
                tgt_mask=tgt_mask,
                memory_key_padding_mask=src_padding_mask,
            )

            # Only the newest position's logits matter for the next token.
            next_token = model.generator(out[:, -1]).argmax(dim=-1).item()
            ys = torch.cat(
                [ys, torch.full((1, 1), next_token, dtype=torch.long, device=device)], dim=1
            )

            if next_token == EOS_IDX:
                break

    return ys


In [None]:
import torch
import torch.nn.functional as F

def beam_search_decode(model, src, src_mask, max_len, start_symbol, src_padding_mask, beam_width=3):
    """Beam-search decoding for a single source sentence.

    Keeps the `beam_width` partial hypotheses with the highest summed
    log-probability. A hypothesis that already ends in EOS_IDX is carried forward
    unchanged. Decoding stops after `max_len` expansion steps, or earlier once every
    beam has finished (further steps could not change the result). Scores are not
    length-normalised.

    Args:
        model: Seq2SeqTransformer; its sub-modules are used directly.
        src: [1, src_len] source token ids (batch size 1).
        src_mask: Unused; accepted for call-site compatibility.
        max_len: Maximum number of expansion steps.
        start_symbol: Index of <bos> in the target vocabulary.
        src_padding_mask: [1, src_len] bool mask (True at <pad>), or None.
        beam_width: Number of hypotheses kept per step (also the per-beam top-k).

    Returns:
        1-D LongTensor of target ids for the best beam, beginning with start_symbol.
    """
    model.eval()

    src = src.to(device)
    if src_padding_mask is not None:
        src_padding_mask = src_padding_mask.to(device)

    with torch.no_grad():
        memory = model.transformer.encoder(
            model.positional_encoding(model.src_tok_emb(src)),
            src_key_padding_mask=src_padding_mask,
        )  # [1, src_len, d_model]

        # Each beam is (token_sequence [seq_len], cumulative log-probability).
        beams = [(torch.tensor([start_symbol], dtype=torch.long, device=device), 0)]

        for _ in range(max_len):
            if all(seq[-1].item() == EOS_IDX for seq, _ in beams):
                break  # every hypothesis has finished

            candidates = []
            for seq, score in beams:
                if seq[-1].item() == EOS_IDX:  # finished beam: keep as-is
                    candidates.append((seq, score))
                    continue

                tgt_emb = model.positional_encoding(model.tgt_tok_emb(seq.unsqueeze(0)))  # [1, seq_len, d_model]
                tgt_mask = generate_square_subsequent_mask(seq.size(0)).to(device)

                out = model.transformer.decoder(
                    tgt_emb,
                    memory,
                    tgt_mask=tgt_mask,
                    memory_key_padding_mask=src_padding_mask,
                )  # [1, seq_len, d_model]

                log_probs = F.log_softmax(model.generator(out[:, -1]), dim=-1)  # [1, vocab_size]
                topk_log_probs, topk_indices = log_probs.topk(beam_width)

                for log_p, token in zip(topk_log_probs[0].tolist(), topk_indices[0]):
                    candidates.append((torch.cat([seq, token.unsqueeze(0)]), score + log_p))

            # Keep the best `beam_width` hypotheses.
            beams = sorted(candidates, key=lambda beam: beam[1], reverse=True)[:beam_width]

    return beams[0][0]

In [None]:
def translate_beam_search(model, src_sentence, beam_width=3, max_len=50):
    """Translate a Hindi sentence into English using beam search.

    Args:
        model: Trained Seq2SeqTransformer.
        src_sentence: Raw Hindi text.
        beam_width: Number of hypotheses kept per step.
        max_len: Maximum number of decoding steps.

    Returns:
        The decoded English tokens joined by spaces, with <bos>/<eos>/<pad> removed.
    """
    model.eval()

    # Tokenize and encode the source sentence as [<bos>, ..., <eos>].
    tokens = HI_TOKENIZER(src_sentence)
    src_indices = [HI_VOCAB['<bos>']] + [HI_VOCAB[token] for token in tokens] + [HI_VOCAB['<eos>']]
    src_tensor = torch.tensor(src_indices, dtype=torch.long, device=device).unsqueeze(0)  # [1, src_len]

    # A single unpadded sentence gives an all-False mask, kept for the expected shape.
    src_padding_mask = (src_tensor == SRC_PAD_IDX)

    src_mask = None

    with torch.no_grad():
        decoded_ids = beam_search_decode(
            model,
            src_tensor,
            src_mask,
            max_len,
            BOS_IDX,
            src_padding_mask,
            beam_width,
        ).flatten()

    # Compare plain ints: torch tensors hash by identity, so `tensor in {ints}` is
    # always False and special tokens would otherwise leak into the output.
    special_ids = {BOS_IDX, EOS_IDX, PAD_IDX}
    out_tokens = [EN_VOCAB.lookup_token(idx) for idx in decoded_ids.tolist() if idx not in special_ids]

    return " ".join(out_tokens)

In [None]:
def translate(model, src_sentence, beam_width=3, max_len=50):
    """Translate a Hindi sentence into English using beam search.

    Args:
        model: Trained Seq2SeqTransformer.
        src_sentence: Raw Hindi text.
        beam_width: Number of hypotheses kept per step.
        max_len: Maximum number of decoding steps.

    Returns:
        The decoded English tokens joined by spaces, with <bos>/<eos>/<pad> removed.
    """
    model.eval()

    tokens = HI_TOKENIZER(src_sentence)
    src_indices = [HI_VOCAB['<bos>']] + [HI_VOCAB[token] for token in tokens] + [HI_VOCAB['<eos>']]

    # Batch-first input of shape [1, src_len]. unsqueeze(0) already makes the
    # padding mask 2-D ([batch_size, src_len]); it is all False for one unpadded sentence.
    src_tensor = torch.tensor(src_indices, dtype=torch.long, device=device).unsqueeze(0)
    src_padding_mask = (src_tensor == SRC_PAD_IDX)

    src_mask = None
    with torch.no_grad():
        decoded_ids = beam_search_decode(
            model,
            src_tensor,
            src_mask,
            max_len,
            BOS_IDX,
            src_padding_mask,
            beam_width,
        ).flatten()

    # Compare plain ints: torch tensors hash by identity, so `tensor in {ints}` is
    # always False and special tokens would otherwise leak into the output.
    special_ids = {BOS_IDX, EOS_IDX, PAD_IDX}
    words = [EN_VOCAB.lookup_token(idx) for idx in decoded_ids.tolist() if idx not in special_ids]
    return " ".join(words)

In [None]:
# Translate a sample sentence; the output is arbitrary until the model has been trained.
sample_sentence = "भारत एक सुंदर देश है"
print(translate(model, sample_sentence))
