In [None]:
# Install the third-party packages imported by the scripts written below (torchmetrics, tokenizers, datasets).
!pip install torchmetrics tokenizers datasets

In [None]:
%%writefile config.py
from pathlib import Path

def get_config():
    """Return the training / inference settings as a fresh dict.

    Keys:
        batch_size, num_epochs, lr: optimisation settings.
        seq_len: fixed token length of every encoder input, decoder input and label.
        d_model, N, h, dropout: model size (512 / 8 heads = 64 dims per head), layer count,
            head count and dropout probability.
        lang_src, lang_tgt: language keys used to read each dataset row.
        model_folder, model_basename: checkpoints live at
            <model_folder>/<model_basename><epoch>.pt.
        preload: None to start fresh, 'latest' to resume from the newest checkpoint,
            or an epoch string such as '05' to resume from that checkpoint.
        tokenizer_file: tokenizer path template, formatted with the language code.
        experiment_name: TensorBoard log directory.
    """
    return dict(
        batch_size=64,
        num_epochs=20,
        lr=2e-4,
        seq_len=128,
        d_model=512,
        lang_src="en",
        lang_tgt="hi",
        model_folder="weights",
        model_basename="tmodel_",
        preload=None,
        tokenizer_file="tokenizer_{0}.json",
        experiment_name="runs/tmodel",
        N=6,
        h=8,
        dropout=0.1,
    )

def get_weights_file_path(config, epoch: str):
    """Build the checkpoint path ``<model_folder>/<model_basename><epoch>.pt``.

    ``epoch`` is interpolated verbatim, so the caller decides any zero-padding.
    Returns the path as a string.
    """
    checkpoint_name = "{}{}.pt".format(config['model_basename'], epoch)
    return str(Path('.').joinpath(config['model_folder'], checkpoint_name))

def latest_weights_file_path(config):
    """Return the path (str) of the newest checkpoint in ``config['model_folder']``, or None.

    Checkpoints are the ``<model_basename><epoch>.pt`` files. When the epoch suffix is
    all digits it is compared numerically, so ``tmodel_100.pt`` ranks after
    ``tmodel_99.pt`` (a plain string sort would pick ``tmodel_99.pt``). Matching ``.pt``
    files whose suffix is not a number still count, but rank below every numbered one.
    Files with other extensions are ignored.
    """
    model_folder = Path(config['model_folder'])
    model_basename = config['model_basename']

    if not model_folder.is_dir():
        return None

    def _epoch_key(path):
        # Rank numbered checkpoints by epoch value; non-numeric suffixes go last.
        suffix = path.stem[len(model_basename):]
        if suffix.isdecimal():
            return (1, int(suffix), path.name)
        return (0, 0, path.name)

    candidates = list(model_folder.glob(f"{model_basename}*.pt"))
    if not candidates:
        return None
    return str(max(candidates, key=_epoch_key))

In [None]:
%%writefile dataset.py
import torch
import torch.nn as nn
from torch.utils.data import Dataset

class BilingualDataset(Dataset):
    """Map-style dataset that turns ``{src_lang: text, tgt_lang: text}`` pairs into fixed-length tensors.

    For every item the texts are tokenised, truncated and padded to ``seq_len``:

    * ``encoder_input``: ``[SOS] src_ids [EOS] [PAD]...`` using the source tokenizer's special ids.
    * ``decoder_input``: ``[SOS] tgt_ids [PAD]...`` using the target tokenizer's special ids.
    * ``label``: ``tgt_ids [EOS] [PAD]...``, i.e. the decoder input shifted left by one.

    Decoder-side special ids are read from ``tokenizer_tgt``. Label padding therefore uses
    the same [PAD] id that a loss built from ``tokenizer_tgt`` (e.g. via ``ignore_index``)
    will skip, even if the two tokenizers number their special tokens differently.
    """

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        super().__init__()
        self.seq_len = seq_len
        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        # Special-token ids stored as 1-element int64 tensors, ready for torch.cat.
        # Source-side ids (encoder input and encoder mask); names kept for existing callers.
        self.sos_token = self._special_id(tokenizer_src, '[SOS]')
        self.eos_token = self._special_id(tokenizer_src, '[EOS]')
        self.pad_token = self._special_id(tokenizer_src, '[PAD]')
        # Target-side ids (decoder input, label and decoder mask).
        self.tgt_sos_token = self._special_id(tokenizer_tgt, '[SOS]')
        self.tgt_eos_token = self._special_id(tokenizer_tgt, '[EOS]')
        self.tgt_pad_token = self._special_id(tokenizer_tgt, '[PAD]')

    @staticmethod
    def _special_id(tokenizer, token):
        """Return ``token``'s id in ``tokenizer`` as a 1-element int64 tensor.

        Raises:
            ValueError: if ``tokenizer.token_to_id`` returns None for ``token``.
        """
        token_id = tokenizer.token_to_id(token)
        if token_id is None:
            raise ValueError(f"tokenizer has no {token} token")
        return torch.tensor([token_id], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def _build_tensors(self, src_text, tgt_text):
        """Tokenise, truncate and pad one pair; return ``(encoder_input, decoder_input, label)``.

        Each returned tensor is 1-D int64 of length ``seq_len``.
        """
        # Encoder needs room for [SOS] and [EOS]; the decoder for [SOS] (input) or [EOS] (label).
        enc_tokens = self.tokenizer_src.encode(src_text).ids[: self.seq_len - 2]
        dec_tokens = self.tokenizer_tgt.encode(tgt_text).ids[: self.seq_len - 1]

        enc_num_padding = self.seq_len - len(enc_tokens) - 2
        dec_num_padding = self.seq_len - len(dec_tokens) - 1

        enc_padding = torch.full((enc_num_padding,), self.pad_token.item(), dtype=torch.int64)
        dec_padding = torch.full((dec_num_padding,), self.tgt_pad_token.item(), dtype=torch.int64)
        dec_body = torch.tensor(dec_tokens, dtype=torch.int64)

        encoder_input = torch.cat(
            [self.sos_token, torch.tensor(enc_tokens, dtype=torch.int64), self.eos_token, enc_padding],
            dim=0,
        )
        decoder_input = torch.cat([self.tgt_sos_token, dec_body, dec_padding], dim=0)
        label = torch.cat([dec_body, self.tgt_eos_token, dec_padding], dim=0)

        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len
        return encoder_input, decoder_input, label

    def __getitem__(self, index):
        src_target_pair = self.ds[index]
        src_text = src_target_pair[self.src_lang]
        tgt_text = src_target_pair[self.tgt_lang]

        encoder_input, decoder_input, label = self._build_tensors(src_text, tgt_text)

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            # (1, 1, seq_len): 1 for real tokens, 0 for padding.
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),
            # (1, seq_len) padding mask & (1, seq_len, seq_len) causal mask -> (1, seq_len, seq_len).
            "decoder_mask": (decoder_input != self.tgt_pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text,
        }

def causal_mask(size):
    """Return a (1, size, size) boolean mask that is True where attention is allowed.

    Entry [0, i, j] is True iff j <= i: each position may attend to itself and to
    earlier positions, never to later ones.
    """
    allowed = torch.tril(torch.ones((1, size, size), dtype=torch.int))
    return allowed == 1

In [None]:
%%writefile model.py
import torch
import torch.nn as nn
import math

class InputEmbeddings(nn.Module):
    """Token-id embedding lookup scaled by sqrt(d_model), as in the original Transformer paper."""

    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embeddings = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # (batch, seq_len) ids -> (batch, seq_len, d_model) vectors
        scale = math.sqrt(self.d_model)
        return scale * self.embeddings(x)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a (batch, seq_len, d_model) input, then apply dropout.

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    The table is precomputed for ``seq_length`` positions and registered as a buffer, so it
    follows ``.to(device)`` and is stored in ``state_dict`` without being trained. Odd
    ``d_model`` is supported; the last column is then a sine with no cosine partner.
    """

    def __init__(self, d_model: int, seq_length: int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.seq_length = seq_length
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(seq_length, d_model)
        position = torch.arange(0, seq_length, dtype=torch.float32).unsqueeze(1)  # (seq_length, 1)
        # 1 / 10000^(2i / d_model) for each even column 2i, computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # There are ceil(d_model/2) frequencies but only floor(d_model/2) odd columns;
        # trimming keeps odd d_model from raising a shape mismatch (no-op when even).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # (1, seq_length, d_model) so it broadcasts over the batch dimension.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        seq_len = x.shape[1]
        max_len = self.pe.shape[1]
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional-encoding table size {max_len}"
            )
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

class LayerNormalisation(nn.Module):
    """Normalise over the last dimension with a single learnable scalar gain and bias.

    Computes ``alpha * (x - mean) / (std + eps) + bias``. This differs from ``nn.LayerNorm``:
    ``std`` is the Bessel-corrected ``Tensor.std``, ``eps`` is added to the std rather than
    to the variance, and ``alpha``/``bias`` are one scalar each, shared by all features.
    """

    def __init__(self, eps: float = 10**-6):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(1))   # multiplicative gain
        self.bias = nn.Parameter(torch.zeros(1))   # additive shift

    def forward(self, x):
        centred = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * (centred / spread) + self.bias

class FeedForwardBlock(nn.Module):
    """Position-wise feed-forward net: Linear(d_model->d_ff) -> ReLU -> Dropout -> Linear(d_ff->d_model)."""

    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # (batch, seq_len, d_model) -> (batch, seq_len, d_ff) -> (batch, seq_len, d_model)
        hidden = torch.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)

class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention.

    Projects q/k/v with learned ``d_model x d_model`` maps and splits each result into ``h``
    heads of size ``d_k = d_model // h``. It attends per head, merges the heads and applies
    the output projection ``w_o``. The latest attention weights (after softmax and dropout)
    are kept on ``self.attention_scores``.
    """

    def __init__(self, d_model: int, h: int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, "d_model must be divisible by h"
        self.d_k = d_model // h

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Return ``(weights @ value, weights)`` with ``weights = softmax(query @ key^T / sqrt(d_k))``.

        Scores where ``mask == 0`` are set to -1e9 before the softmax, so those positions get
        (effectively) zero weight. ``mask`` and ``dropout`` may each be None.
        """
        scale = math.sqrt(query.shape[-1])
        scores = query @ key.transpose(-2, -1) / scale   # (batch, h, q_len, k_len)
        if mask is not None:
            scores.masked_fill_(mask == 0, -1e9)
        weights = scores.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return weights @ value, weights

    def _split_heads(self, t):
        # (batch, seq_len, d_model) -> (batch, seq_len, h, d_k) -> (batch, h, seq_len, d_k)
        batch = t.shape[0]
        return t.view(batch, -1, self.h, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask):
        query = self._split_heads(self.w_q(q))
        key = self._split_heads(self.w_k(k))
        value = self._split_heads(self.w_v(v))

        out, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

        # (batch, h, seq_len, d_k) -> (batch, seq_len, h, d_k) -> (batch, seq_len, d_model)
        out = out.transpose(1, 2).contiguous().view(out.shape[0], -1, self.h * self.d_k)
        return self.w_o(out)

class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper computing ``x + dropout(sublayer(norm(x)))``.

    The input is normalised *before* the sublayer (Pre-LN). The original Transformer paper
    uses Post-LN, ``norm(x + sublayer(x))``; Pre-LN is commonly reported to train more stably.
    """

    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalisation()

    def forward(self, x, sublayer):
        # ``sublayer`` is any callable whose output can be added back onto ``x``.
        normed = self.norm(x)
        return x + self.dropout(sublayer(normed))

class EncoderBlock(nn.Module):
    """One encoder layer: a residual self-attention sublayer followed by a residual feed-forward sublayer."""

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        attn_residual, ff_residual = self.residual_connections
        # Queries, keys and values all come from the same sequence; src_mask hides padding.
        x = attn_residual(x, lambda y: self.self_attention_block(y, y, y, src_mask))
        return ff_residual(x, self.feed_forward_block)

class Encoder(nn.Module):
    """Stack of encoder layers applied in order, followed by a final layer normalisation."""

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalisation()

    def forward(self, x, mask):
        hidden = x
        for encoder_layer in self.layers:
            # Every layer receives the same source mask.
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)

class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention, then feed-forward, each residual."""

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, target_mask):
        res_self, res_cross, res_ff = self.residual_connections
        # Self-attention over the target; target_mask prevents looking ahead.
        x = res_self(x, lambda y: self.self_attention_block(y, y, y, target_mask))
        # Cross-attention: queries from the decoder, keys/values from the encoder output.
        x = res_cross(x, lambda y: self.cross_attention_block(y, encoder_output, encoder_output, src_mask))
        # Position-wise feed-forward.
        return res_ff(x, self.feed_forward_block)

class Decoder(nn.Module):
    """Stack of decoder layers applied in order, followed by a final layer normalisation."""

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalisation()

    def forward(self, x, encoder_output, src_mask, target_mask):
        hidden = x
        for decoder_layer in self.layers:
            # Every layer attends to the same encoder output with the same masks.
            hidden = decoder_layer(hidden, encoder_output, src_mask, target_mask)
        return self.norm(hidden)

class ProjectionLayer(nn.Module):
    """Final linear map from decoder features to vocabulary logits.

    Returns raw (unnormalised) logits. ``nn.CrossEntropyLoss`` applies log-softmax itself,
    and arg-max decoding does not need probabilities.
    """

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # (batch, seq_len, d_model) -> (batch, seq_len, vocab_size)
        logits = self.proj(x)
        return logits

class Transformer(nn.Module):
    """Encoder-decoder Transformer exposing separate ``encode`` / ``decode`` / ``project`` steps.

    The steps are separate so that greedy decoding can run the encoder once and the decoder
    repeatedly.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, target_embed: InputEmbeddings, src_pos: PositionalEncoding, target_pos: PositionalEncoding, projection_layer: ProjectionLayer):
        super().__init__()
        # Assignment order fixes submodule registration order (and hence parameters() order).
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.target_embed = target_embed
        self.src_pos = src_pos
        self.target_pos = target_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        """Embed and position-encode source ids, then run the encoder stack."""
        return self.encoder(self.src_pos(self.src_embed(src)), src_mask)

    def decode(self, encoder_output, src_mask, target, target_mask):
        """Embed and position-encode target ids, then run the decoder stack against ``encoder_output``."""
        embedded_target = self.target_pos(self.target_embed(target))
        return self.decoder(embedded_target, encoder_output, src_mask, target_mask)

    def project(self, x):
        """Map decoder features to vocabulary logits."""
        return self.projection_layer(x)

def build_transformer(src_vocab_size: int, target_vocab_size: int, src_seq_len: int, target_seq_len: int, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.1, d_ff: int = 2048):
    """Assemble an encoder-decoder Transformer and Xavier-initialise its weight matrices.

    Args:
        src_vocab_size, target_vocab_size: vocabulary sizes of the two tokenizers.
        src_seq_len, target_seq_len: number of rows in each positional-encoding table.
        d_model: model width; must be divisible by ``h``.
        N: number of encoder layers and of decoder layers.
        h: attention heads per attention block.
        dropout: dropout probability used everywhere.
        d_ff: hidden width of the feed-forward sublayers.

    Returns:
        Transformer whose parameters with more than one dimension were re-initialised with
        ``nn.init.xavier_uniform_``. 1-D parameters (biases, norm gains) keep their defaults.
    """
    src_embed = InputEmbeddings(src_vocab_size, d_model)
    target_embed = InputEmbeddings(target_vocab_size, d_model)
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    target_pos = PositionalEncoding(d_model, target_seq_len, dropout)

    def new_encoder_block():
        # Each layer owns its own (unshared) attention and feed-forward weights.
        return EncoderBlock(
            MultiHeadAttentionBlock(d_model, h, dropout),
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )

    def new_decoder_block():
        return DecoderBlock(
            MultiHeadAttentionBlock(d_model, h, dropout),   # masked self-attention
            MultiHeadAttentionBlock(d_model, h, dropout),   # cross-attention over encoder output
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )

    encoder = Encoder(nn.ModuleList([new_encoder_block() for _ in range(N)]))
    decoder = Decoder(nn.ModuleList([new_decoder_block() for _ in range(N)]))
    projection_layer = ProjectionLayer(d_model, target_vocab_size)

    transformer = Transformer(encoder, decoder, src_embed, target_embed, src_pos, target_pos, projection_layer)

    for weight in (p for p in transformer.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)

    return transformer

In [None]:
%%writefile train.py

import warnings
from pathlib import Path
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import torchmetrics

# Huggingface imports
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
# CHANGED: Added Punctuation and Sequence imports
from tokenizers.pre_tokenizers import Whitespace, Punctuation, Sequence

# Local imports
from model import build_transformer
from dataset import BilingualDataset, causal_mask
from config import get_config, get_weights_file_path, latest_weights_file_path

def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Greedily decode one source sequence and return the 1-D tensor of target ids.

    The output starts with [SOS]. Decoding stops after [EOS] is emitted (it is kept) or
    once ``max_len`` ids have been produced. ``next_word.item()`` requires a batch of
    exactly one sequence. ``tokenizer_src`` is unused and kept for call compatibility.
    """
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # The encoder output does not depend on the decoded prefix, so compute it once.
    encoder_output = model.encode(source, source_mask)
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

    while decoder_input.size(1) != max_len:
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
        logits = model.project(out[:, -1])   # scores for the next position only
        _, next_word = torch.max(logits, dim=1)
        next_token = torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)
        decoder_input = torch.cat([decoder_input, next_token], dim=1)
        if next_word == eos_idx:
            break

    return decoder_input.squeeze(0)

def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_state, writer, num_examples = 2):
    """Greedy-decode the first ``num_examples`` validation items, print them and log metrics.

    Args:
        model: Transformer; switched to eval mode and left there.
        validation_ds: iterable of batches of size 1 (asserted) with encoder_input,
            encoder_mask, src_text and tgt_text entries.
        max_len: maximum number of decoded target ids.
        print_msg: callable that receives every printed line.
        global_state: step value recorded with the TensorBoard scalars.
        writer: SummaryWriter or None. When given, character error rate, word error rate
            and BLEU of the predictions are logged.
    """
    import shutil

    model.eval()

    # get_terminal_size() falls back to 80 columns when stdout is not a terminal.
    console_width = shutil.get_terminal_size().columns

    expected = []
    predicted = []
    count = 0

    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch["encoder_input"].to(device)
            encoder_mask = batch["encoder_mask"].to(device)

            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

            expected.append(target_text)
            predicted.append(model_out_text)

            print_msg('-'*console_width)
            print_msg(f"{'SOURCE: ':>12}{source_text}")
            print_msg(f"{'TARGET: ':>12}{target_text}")
            print_msg(f"{'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg('-'*console_width)
                break

    # Nothing to score if the validation set was empty.
    if writer and predicted:
        cer = torchmetrics.CharErrorRate()(predicted, expected)
        wer = torchmetrics.WordErrorRate()(predicted, expected)
        # BLEUScore expects one *list* of reference translations per prediction.
        bleu = torchmetrics.BLEUScore()(predicted, [[reference] for reference in expected])

        writer.add_scalar('validation cer', cer, global_state)
        writer.add_scalar('validation wer', wer, global_state)
        writer.add_scalar('validation BLEU', bleu, global_state)
        writer.flush()

def get_all_sentences(ds, lang):
    """Yield the ``lang`` text of every item, unwrapping nested ``{'translation': {...}}`` rows."""
    for item in ds:
        record = item['translation'] if 'translation' in item else item
        yield record[lang]

def get_or_build_tokenizer(config, ds, lang):
    """Load the word-level tokenizer for ``lang``, training and saving a new one if the file is missing.

    The path is ``config['tokenizer_file'].format(lang)``. A newly trained tokenizer:
    - splits on whitespace;
    - isolates punctuation;
    - keeps words seen at least twice;
    - reserves [UNK], [PAD], [SOS] and [EOS] as special tokens.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    print(f"Tokenizer file {tokenizer_path} not found. Training a new one...")
    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Sequence([Whitespace(), Punctuation(behavior="Isolated")])
    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

def get_ds(config):
    """Load the parallel corpus and tokenizers.

    Returns ``(train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt)``.

    Optional config keys (the defaults are the values previously hard-coded):
        'datasource'     -- Hugging Face dataset id, default 'cfilt/iitb-english-hindi'.
        'dataset_split'  -- split expression for load_dataset, default 'train[:100000]'.
        'train_fraction' -- leading fraction of rows used for training, default 0.9.

    The train/validation split is contiguous (not shuffled), and validation uses batch
    size 1. Tokenizers are loaded from disk or trained on all loaded rows.
    """
    datasource = config.get('datasource', 'cfilt/iitb-english-hindi')
    dataset_split = config.get('dataset_split', 'train[:100000]')
    train_fraction = config.get('train_fraction', 0.9)
    src_lang, tgt_lang = config['lang_src'], config['lang_tgt']

    print("Loading Hindi Dataset...")
    ds_raw = load_dataset(datasource, split=dataset_split)
    # Unwrap {'translation': {lang: text}} rows into plain {lang: text} dicts.
    ds_raw = [item['translation'] if 'translation' in item else item for item in ds_raw]

    tokenizer_src = get_or_build_tokenizer(config, ds_raw, src_lang)
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, tgt_lang)

    train_ds_size = int(train_fraction * len(ds_raw))
    train_ds = BilingualDataset(ds_raw[:train_ds_size], tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, config['seq_len'])
    val_ds = BilingualDataset(ds_raw[train_ds_size:], tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, config['seq_len'])

    # Longest tokenised sentences (before [SOS]/[EOS]); longer inputs than seq_len allows get truncated.
    max_len_src = max((len(tokenizer_src.encode(item[src_lang]).ids) for item in ds_raw), default=0)
    max_len_tgt = max((len(tokenizer_tgt.encode(item[tgt_lang]).ids) for item in ds_raw), default=0)
    print(f'Max length of the source sentence : {max_len_src}')
    print(f'Max length of the target sentence : {max_len_tgt}')

    # Pinned host memory only speeds up copies to a GPU; skip it on CPU-only machines.
    pin_memory = torch.cuda.is_available()
    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True, num_workers=2, pin_memory=pin_memory)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=1, pin_memory=pin_memory)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build the Transformer described by ``config`` for the given vocabulary sizes.

    Uses ``config['seq_len']`` for both positional tables and ``config['d_model']`` for width.
    Also honours the optional 'N', 'h', 'dropout' and 'd_ff' keys. Absent keys fall back to
    build_transformer's defaults (N=6, h=8, dropout=0.1, d_ff=2048).
    """
    return build_transformer(
        vocab_src_len,
        vocab_tgt_len,
        config['seq_len'],
        config['seq_len'],
        d_model=config['d_model'],
        N=config.get('N', 6),
        h=config.get('h', 8),
        dropout=config.get('dropout', 0.1),
        d_ff=config.get('d_ff', 2048),
    )

def _resolve_checkpoint_path(config):
    """Translate ``config['preload']`` into a checkpoint path, or None for a fresh start.

    'latest' selects the newest checkpoint in the model folder (None if there is none).
    Any other truthy value is treated as an epoch string such as '05'.
    """
    preload = config['preload']
    if preload == 'latest':
        return latest_weights_file_path(config)
    if preload:
        return get_weights_file_path(config, preload)
    return None


def _load_checkpoint(model_filename, model, optimizer, device):
    """Restore model and optimizer state in place and return ``(initial_epoch, global_step)``.

    A falsy ``model_filename`` means there is nothing to resume, so ``(0, 0)`` is returned.
    """
    if not model_filename:
        return 0, 0
    print(f'Preloading model {model_filename}')
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    return state['epoch'] + 1, state['global_step']


def _make_print_msg(log_file_path):
    """Return a printer that writes via ``tqdm.write`` and appends each line to ``log_file_path``."""
    def print_msg(msg):
        tqdm.write(msg)
        with open(log_file_path, "a", encoding='utf-8') as f:
            f.write(msg + "\n")
    return print_msg


def train_model(config):
    """Train the Transformer described by ``config``, validating and checkpointing after every epoch.

    Resumes from a checkpoint when ``config['preload']`` is 'latest' or an epoch string.
    Validation samples are printed and appended to ``<model_folder>/validation_logs.txt``.
    Training loss and validation metrics go to TensorBoard under ``config['experiment_name']``.
    Each epoch saves model/optimizer state, epoch and global_step to
    ``get_weights_file_path(config, f'{epoch:02d}')``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Using device {device}')

    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)
    print_msg = _make_print_msg(Path(config['model_folder']) / "validation_logs.txt")

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()
    model = get_model(config, tokenizer_src.get_vocab_size(), tgt_vocab_size).to(device)

    writer = SummaryWriter(log_dir=config['experiment_name'])
    try:
        optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
        initial_epoch, global_step = _load_checkpoint(_resolve_checkpoint_path(config), model, optimizer, device)

        # Padded label positions carry no signal, so they are excluded from the loss.
        loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

        for epoch in range(initial_epoch, config['num_epochs']):
            torch.cuda.empty_cache()
            model.train()
            batch_iterator = tqdm(train_dataloader, desc=f'Processing epoch {epoch:02d}')

            for batch in batch_iterator:
                encoder_input = batch['encoder_input'].to(device)
                decoder_input = batch['decoder_input'].to(device)
                encoder_mask = batch['encoder_mask'].to(device)
                decoder_mask = batch['decoder_mask'].to(device)
                label = batch['label'].to(device)

                encoder_output = model.encode(encoder_input, encoder_mask)
                decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
                proj_output = model.project(decoder_output)

                # Flatten (batch, seq_len, vocab) logits against (batch * seq_len) labels.
                loss = loss_fn(proj_output.view(-1, tgt_vocab_size), label.view(-1))
                loss_value = loss.item()

                batch_iterator.set_postfix({"loss": f"{loss_value:6.3f}"})
                # No per-step flush; events are flushed at the end of each epoch and on close
                # (SummaryWriter also flushes periodically on its own).
                writer.add_scalar('train loss', loss_value, global_step)

                loss.backward()
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                global_step += 1

            print_msg(f"\n--- Epoch {epoch:02d} Validation ---")
            run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, print_msg, global_step, writer)

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': global_step
            }, get_weights_file_path(config, f'{epoch:02d}'))
            writer.flush()
    finally:
        # Flush pending TensorBoard events and release the event file, even on error.
        writer.close()
        
if __name__ == '__main__':
    # Suppress all Python warnings for the whole training run.
    warnings.filterwarnings('ignore')
    train_model(get_config())

In [None]:
%%writefile translate.py
from pathlib import Path
import torch
import sys
from config import get_config, latest_weights_file_path
from model import build_transformer
from tokenizers import Tokenizer
from dataset import causal_mask

def translate(sentence: str):
    """Translate ``sentence`` with the newest checkpoint, print the result and return it.

    Loads the tokenizers named by ``config['tokenizer_file']`` and the latest checkpoint in
    ``config['model_folder']``, then greedy-decodes up to ``config['seq_len']`` target ids.
    The source is truncated to ``seq_len - 2`` tokens so that, with [SOS]/[EOS], it fits
    the model's positional-encoding table.

    Returns:
        The decoded translation string, or None (after printing a message) when no
        checkpoint exists.
    """
    # 1. Config and device
    config = get_config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 2. Tokenizers are loaded from disk, never retrained here.
    tokenizer_src = Tokenizer.from_file(config['tokenizer_file'].format(config['lang_src']))
    tokenizer_tgt = Tokenizer.from_file(config['tokenizer_file'].format(config['lang_tgt']))

    # 3. Rebuild the architecture from the same config keys so the checkpoint's state_dict
    #    shapes match; absent keys use build_transformer's defaults.
    model = build_transformer(
        tokenizer_src.get_vocab_size(),
        tokenizer_tgt.get_vocab_size(),
        config['seq_len'],
        config['seq_len'],
        d_model=config['d_model'],
        N=config.get('N', 6),
        h=config.get('h', 8),
        dropout=config.get('dropout', 0.1),
        d_ff=config.get('d_ff', 2048),
    ).to(device)

    # 4. Load the newest checkpoint.
    model_filename = latest_weights_file_path(config)
    if not model_filename:
        print("No weights found! Train the model first.")
        return None

    print(f"Loading weights from: {model_filename}")
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    model.eval()

    seq_len = config['seq_len']

    # 5. Encoder input is [SOS] tokens [EOS] built with the *source* tokenizer's ids,
    #    matching how BilingualDataset builds training encoder inputs.
    src_tokens = tokenizer_src.encode(sentence).ids[: seq_len - 2]
    encoder_input = torch.tensor(
        [[tokenizer_src.token_to_id('[SOS]'), *src_tokens, tokenizer_src.token_to_id('[EOS]')]],
        dtype=torch.int64,
        device=device,
    )  # (1, src_len)
    # (1, 1, 1, src_len). No padding is added for a single sentence; the mask is kept for
    # parity with training.
    encoder_mask = (encoder_input != tokenizer_src.token_to_id('[PAD]')).unsqueeze(0).unsqueeze(0).int()

    sos_token = tokenizer_tgt.token_to_id('[SOS]')
    eos_token = tokenizer_tgt.token_to_id('[EOS]')

    # 6-7. Encode once, then greedy-decode one target id at a time.
    with torch.no_grad():
        encoder_output = model.encode(encoder_input, encoder_mask)
        decoder_input = torch.tensor([[sos_token]], dtype=torch.int64, device=device)

        while decoder_input.size(1) < seq_len:
            decoder_mask = causal_mask(decoder_input.size(1)).type_as(encoder_mask).to(device)
            out = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
            next_word = model.project(out[:, -1]).argmax(dim=1)  # (1,)
            decoder_input = torch.cat([decoder_input, next_word.unsqueeze(0)], dim=1)
            if next_word.item() == eos_token:
                break

    # 8. Back to text. Tokenizer.decode skips special tokens such as [SOS]/[EOS] by default.
    translated_text = tokenizer_tgt.decode(decoder_input.squeeze(0).tolist())
    print(f"\nEnglish: {sentence}")
    print(f"Hindi:   {translated_text}")
    return translated_text

if __name__ == '__main__':
    # Usage: python translate.py "Hello world"
    # Unquoted input (python translate.py Hello world) is joined back into one sentence.
    if len(sys.argv) > 1:
        translate(" ".join(sys.argv[1:]))
    else:
        print("Please provide a sentence. Example: python translate.py 'Hello world'")

In [None]:
# Copy the pre-built tokenizers from the read-only Kaggle input directory into the working
# directory. That is where the relative "tokenizer_{0}.json" template from config.py resolves,
# so get_or_build_tokenizer loads these files instead of training new tokenizers.
!cp /kaggle/input/toktok/tokenizer_en.json .
!cp /kaggle/input/toktok/tokenizer_hi.json .

In [None]:
# Train: train.py's __main__ guard runs train_model(get_config()); checkpoints are written under weights/.
!python train.py

In [None]:
# Translate one sentence using the newest checkpoint found in weights/.
!python translate.py "I am a student of this institute."