In [None]:
import torch
import torch.nn as nn
from torch.utils.data import dataloader, dataset, TensorDataset
import torch.nn.functional as F
import math

# Sinusoidal positional encodings (as in "Attention Is All You Need").
class Positional_Encoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    The table is precomputed for ``max_len`` positions and registered as a
    (persistent) buffer of shape ``[1, max_len, d_model]``, so it follows the
    module across ``.to(device)`` and is stored in ``state_dict``.

    Args:
        max_len: Largest sequence length ``forward`` can handle.
        d_model: Embedding size. Even feature indices get sine terms and odd
            indices cosine terms; odd ``d_model`` is supported.
    """

    def __init__(self, max_len, d_model):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Frequencies 1 / 10000^(2i / d_model), computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer odd column than even column;
        # trim the frequencies so the shapes match (this used to raise).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dim 1 is the sequence axis and whose last dim is
                ``d_model`` (``[batch, seq_len, d_model]`` in this file).
        """
        return x + self.pe[:, :x.size(1)]

class Masked_Attention(nn.Module):
    """Causal (masked) multi-head self-attention used by the decoder.

    Each query position may attend only to itself and earlier positions.
    Key positions flagged in an optional padding mask are excluded as well.

    Args:
        num_heads: Number of attention heads.
        embedings: Model width; must be divisible by ``num_heads``.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, num_heads, embedings, dropout=0.1):
        super().__init__()
        assert embedings % num_heads == 0, "embedings must be divisible by num_heads"

        self.num_heads = num_heads
        self.embedings = embedings
        self.heads = embedings // num_heads  # width of one head
        self.dropout = nn.Dropout(dropout)

        # Query (fc1), key (fc2), value (fc3) and output projections.
        self.fc1 = nn.Linear(embedings, embedings)
        self.fc2 = nn.Linear(embedings, embedings)
        self.fc3 = nn.Linear(embedings, embedings)
        self.outlayer = nn.Linear(embedings, embedings)

        # Cached causal mask. It is a NON-persistent buffer: it still follows
        # .to(device), but it is derived data. It must not be written to
        # state_dict, where its batch-dependent size made exact reloading
        # (strict=True) impossible.
        self.register_buffer('causal_mask', None, persistent=False)

    def _get_causal_mask(self, seq_len, device):
        """Return a ``[seq_len, seq_len]`` bool mask that is True above the diagonal.

        The cache only grows. Shorter sequences reuse its top-left corner,
        which is itself upper-triangular, so batches of varying length do not
        rebuild the mask on every call.
        """
        cached = self.causal_mask
        if cached is None or cached.size(0) < seq_len or cached.device != device:
            cached = torch.triu(
                torch.ones(seq_len, seq_len, device=device),
                diagonal=1
            ).bool()
            self.causal_mask = cached
        return cached[:seq_len, :seq_len]

    def forward(self, x, mask=None):
        """Apply causal multi-head self-attention.

        Args:
            x: Input of shape ``[batch, seq_len, embedings]``.
            mask: Optional ``[batch, seq_len]`` bool tensor, True at key
                positions to ignore (padding).

        Returns:
            Tensor of shape ``[batch, seq_len, embedings]``.
        """
        batch_size, seq_len, embed = x.size()

        # Project, then split into heads: [B, num_heads, S, head_dim].
        Q = self.fc1(x).view(batch_size, seq_len, self.num_heads, self.heads).transpose(1, 2)
        K = self.fc2(x).view(batch_size, seq_len, self.num_heads, self.heads).transpose(1, 2)
        V = self.fc3(x).view(batch_size, seq_len, self.num_heads, self.heads).transpose(1, 2)

        # Scaled dot-product scores: [B, num_heads, S, S].
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.heads)

        if mask is not None:
            # [B, S] -> [B, 1, 1, S] so it broadcasts over heads and queries.
            scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf'))

        # Hide future positions.
        scores = scores.masked_fill(self._get_causal_mask(seq_len, x.device), float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        output = torch.matmul(weights, V)

        # Merge heads back to [B, S, E] and project.
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed)
        return self.outlayer(output)


class Multihead(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    Queries always come from ``x``. Keys and values come from ``enc_output``
    when it is supplied (encoder-decoder cross-attention) and from ``x``
    otherwise (self-attention, as used by the encoder). No causal mask is
    applied here.

    Args:
        num_heads: Number of attention heads.
        embeding: Model width; must be divisible by ``num_heads``.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, num_heads, embeding, dropout=0.1):
        super().__init__()
        assert embeding % num_heads == 0, "embeding must be divisible by num_heads"

        self.num_heads, self.embeding = num_heads, embeding
        self.head = embeding // num_heads  # width of a single head
        self.dropout = nn.Dropout(dropout)

        # Projections: queries (fc1), keys (fc2), values (fc3), then output.
        self.fc1 = nn.Linear(embeding, embeding)
        self.fc2 = nn.Linear(embeding, embeding)
        self.fc3 = nn.Linear(embeding, embeding)
        self.out_layer = nn.Linear(embeding, embeding)

    def forward(self, x, enc_output=None, mask=None):
        """Attend from ``x`` to ``enc_output`` (or to ``x`` itself).

        Args:
            x: Query source, ``[batch, seq_len, embeding]``.
            enc_output: Optional key/value source, ``[batch, kv_len, embeding]``.
            mask: Optional ``[batch, kv_len]`` bool tensor, True at key
                positions to ignore (padding).

        Returns:
            Tensor of shape ``[batch, seq_len, embeding]``.
        """
        batch_size, seq_len, embed = x.size()

        def split(t, length):
            # [B, length, E] -> [B, num_heads, length, head]
            return t.view(batch_size, length, self.num_heads, self.head).transpose(1, 2)

        kv_source = x if enc_output is None else enc_output
        kv_len = kv_source.size(1)

        queries = split(self.fc1(x), seq_len)
        keys = split(self.fc2(kv_source), kv_len)
        values = split(self.fc3(kv_source), kv_len)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head)
        if mask is not None:
            # [B, kv_len] -> [B, 1, 1, kv_len] to broadcast over heads and queries.
            scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(weights, values)

        merged = context.transpose(1, 2).contiguous().view(batch_size, seq_len, embed)
        return self.out_layer(merged)

# Position-wise feed-forward network
class Position_Feedforward(nn.Module):
    """Two-layer MLP applied independently at every sequence position.

    Computes ``fc2(dropout(relu(fc1(x))))``. It expands from ``input_ch`` to
    ``output_ch`` features and projects back to ``input_ch``.

    Args:
        input_ch: Model width (input and output feature size).
        output_ch: Hidden width of the intermediate layer.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, input_ch, output_ch, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(input_ch, output_ch)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(output_ch, input_ch)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(hidden)

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import dataset, TensorDataset, dataloader
import torch.nn.functional as F
import math

# A single Transformer encoder layer (stacked by FullTransformer_Custom).
class Multi_Encoder(nn.Module):
    """One post-layer-norm Transformer encoder layer.

    ``x -> norm1(x + dropout(self_attn(x))) -> norm2(x + dropout(ffn(x)))``

    Args:
        d_model: Model width.
        heads: Number of attention heads.
        hidden_lay: Hidden width of the feed-forward sublayer.
        dropout: Dropout probability used by the attention weights, the
            feed-forward sublayer and both residual branches.
    """

    def __init__(self, d_model, heads, hidden_lay, dropout=0.1):
        super().__init__()
        self.E_Multi = Multihead(heads, d_model, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        # Forward the configured dropout. Previously the FFN silently kept its
        # own default of 0.1 whatever this layer was configured with.
        self.E_FeedFor = Position_Feedforward(d_model, hidden_lay, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask=None):
        """Run the encoder layer.

        Args:
            x: Source hidden states, ``[batch, src_len, d_model]``.
            src_mask: Optional ``[batch, src_len]`` bool mask, True at padding.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x1 = self.E_Multi(x, mask=src_mask)
        x = self.norm1(x + self.dropout(x1))

        x2 = self.E_FeedFor(x)
        enc_output = self.norm2(x + self.dropout(x2))
        return enc_output

class Multi_Decoder(nn.Module):
    """One post-layer-norm Transformer decoder layer.

    It has three sublayers, each followed by dropout, a residual add and a
    LayerNorm: causal self-attention over the target, cross-attention from the
    target to ``enc_output``, and a position-wise feed-forward network.

    Args:
        d_model: Model width.
        heads: Number of attention heads.
        hidden_lay: Hidden width of the feed-forward sublayer.
        dropout: Dropout probability used by both attention blocks, the
            feed-forward sublayer and the residual branches.
    """

    def __init__(self, d_model, heads, hidden_lay, dropout=0.1):
        super().__init__()
        self.D_Mask = Masked_Attention(heads, d_model, dropout)
        self.norm3 = nn.LayerNorm(d_model)
        self.Cross_att = Multihead(heads, d_model, dropout)
        self.norm4 = nn.LayerNorm(d_model)
        # Forward the configured dropout. Previously the FFN always used its
        # own default of 0.1.
        self.D_Feedfor = Position_Feedforward(d_model, hidden_lay, dropout)
        self.norm5 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, tgt_mask=None, src_mask=None):
        """Run the decoder layer.

        Args:
            x: Target hidden states, ``[batch, tgt_len, d_model]``.
            enc_output: Encoder output, ``[batch, src_len, d_model]``.
            tgt_mask: Optional ``[batch, tgt_len]`` bool mask, True at target
                padding; combined with the causal mask in ``D_Mask``.
            src_mask: Optional ``[batch, src_len]`` bool mask, True at source
                padding; used by cross-attention.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x3 = self.D_Mask(x, mask=tgt_mask)
        x = self.norm3(x + self.dropout(x3))

        x4 = self.Cross_att(x, enc_output, mask=src_mask)
        x = self.norm4(x + self.dropout(x4))

        x5 = self.D_Feedfor(x)
        x = self.norm5(x + self.dropout(x5))
        return x



class FullTransformer_Custom(nn.Module):
    """Encoder-decoder Transformer for sequence-to-sequence translation.

    Source and target have separate embedding tables and positional-encoding
    modules. Both tables, and the output projection, are sized by the single
    ``vocab_size``.

    Args:
        vocab_size: Rows of each embedding table and number of output logits.
        heads: Attention heads per layer.
        d_model: Model width.
        hidden_lay: Hidden width of every feed-forward sublayer.
        seq_len: Maximum sequence length supported by the positional encodings.
        num_layers: Number of encoder layers, and separately of decoder layers.
        dropout: Dropout probability passed to every layer.
    """

    def __init__(self, vocab_size, heads, d_model, hidden_lay, seq_len, num_layers, dropout=0.1):
        super().__init__()
        self.src_embedding = nn.Embedding(vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model
        self.pos = Positional_Encoding(seq_len, d_model)
        self.pos1 = Positional_Encoding(seq_len, d_model)

        # num_layers encoder layers and num_layers decoder layers.
        self.encoder_layers = nn.ModuleList(
            [Multi_Encoder(d_model, heads, hidden_lay, dropout) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [Multi_Decoder(d_model, heads, hidden_lay, dropout) for _ in range(num_layers)]
        )

        self.out_layer = nn.Linear(d_model, vocab_size)

    def _embed(self, token_ids, table, pos_enc):
        """Look up embeddings, scale by sqrt(d_model), and add positions."""
        return pos_enc(table(token_ids) * math.sqrt(self.d_model))

    def encoder(self, src, src_mask=None):
        """Encode ``src`` token ids into hidden states ``[batch, src_len, d_model]``."""
        hidden = self._embed(src, self.src_embedding, self.pos)
        for enc_layer in self.encoder_layers:
            hidden = enc_layer(hidden, src_mask=src_mask)
        return hidden

    def decoder(self, tar, enc_output, tgt_mask=None, src_mask=None):
        """Decode ``tar`` token ids against ``enc_output``; return vocabulary logits."""
        hidden = self._embed(tar, self.tgt_embedding, self.pos1)
        for dec_layer in self.decoder_layers:
            hidden = dec_layer(hidden, enc_output, tgt_mask=tgt_mask, src_mask=src_mask)
        return self.out_layer(hidden)

    def forward(self, src, tar, src_mask=None, tgt_mask=None):
        """Full pass: encode ``src``, then decode ``tar`` into logits ``[batch, tgt_len, vocab_size]``."""
        memory = self.encoder(src, src_mask=src_mask)
        return self.decoder(tar, memory, tgt_mask=tgt_mask, src_mask=src_mask)





In [None]:
import os
# Hard-kill this notebook's own Python process (signal 9). Everything defined
# in earlier cells is lost and must be re-run before it is used again.
# NOTE(review): presumably this forces a Colab runtime restart (e.g. after
# installing packages) — confirm. Never run this cell as part of a script.
os.kill(os.getpid(), 9)

In [None]:
'''
Data loading for Multi30k English->German translation: load the splits,
tokenize, build vocabularies, numericalize, pad, and wrap in DataLoaders.
'''
import warnings
# Silence DeprecationWarning messages (presumably from torchtext) globally.
warnings.filterwarnings("ignore", category=DeprecationWarning)
from torchtext.datasets import Multi30k
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, dataset, TensorDataset
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import math
from itertools import islice

# Load the (English, German) sentence-pair splits. They are re-created below
# before each further full pass over the data.
train_iter, valid_iter, test_iter = Multi30k(split=('train', 'valid', 'test'), language_pair=('en', 'de'))

# spaCy tokenizers; the en_core_web_sm and de_core_news_sm models must be installed.
tokenizer_en = get_tokenizer("spacy", language="en_core_web_sm")
tokenizer_de = get_tokenizer("spacy", language="de_core_news_sm")

# Generator that tokenizes one side of every sentence pair.
def yeild_tokens(datasets, tokenizer, index=0):
    """Yield ``tokenizer(pair[index])`` for each pair in ``datasets``.

    Args:
        datasets: Iterable of sentence pairs, e.g. ``(english, german)``.
        tokenizer: Callable mapping a string to a list of tokens.
        index: Which element of each pair to tokenize (0 = English,
            1 = German in this notebook).

    The misspelled name is kept so existing callers keep working.
    """
    yield from (tokenizer(pair[index]) for pair in datasets)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Build the English vocabulary from the training split. The four special
# tokens are added explicitly, and max_tokens=8000 caps the vocabulary size
# (keeping the most frequent tokens, per torchtext docs).
eng_tokens = build_vocab_from_iterator(yeild_tokens(train_iter, tokenizer_en), specials=["<unk>", "<pad>", "<bos>", "<eos>"], max_tokens=8000)
# Out-of-vocabulary lookups return the <unk> id instead of raising.
eng_tokens.set_default_index(eng_tokens["<unk>"])

# Fresh split iterables for the German pass (the training split was just
# iterated to build the English vocabulary).
train_iter, valid_iter, test_iter = Multi30k(split=('train', 'valid', 'test'), language_pair=('en', 'de'))
# German is element 1 of each (English, German) pair, hence index=1.
ger_tokens = build_vocab_from_iterator(yeild_tokens(train_iter, tokenizer_de, index=1), specials=["<unk>", "<pad>", "<bos>", "<eos>"], max_tokens=8000)
ger_tokens.set_default_index(ger_tokens["<unk>"])
#print(len(eng_tokens))
#print(len(ger_tokens))

# Numericalize one (English, German) sentence pair.
def apply_vocab(src_text, tgr_text):
    """Tokenize a sentence pair and map both sides to token ids.

    Each side becomes ``<bos> tokens... <eos>`` and is looked up in its own
    vocabulary (``eng_tokens`` for the source, ``ger_tokens`` for the target).

    Returns:
        ``(src_tensor, tgr_tensor)``: 1-D tensors of token ids.
    """
    def encode(text, tokenize, vocab):
        ids = [vocab["<bos>"]]
        ids.extend(vocab[tok] for tok in tokenize(text))
        ids.append(vocab["<eos>"])
        return torch.tensor(ids)

    src_tensor = encode(src_text, tokenizer_en, eng_tokens)
    tgr_tensor = encode(tgr_text, tokenizer_de, ger_tokens)
    return src_tensor, tgr_tensor

# Pad a batch of variable-length id sequences and build padding masks.
def padding(src, tar):
    """Pad both sides of a batch to their longest sequence and build masks.

    Args:
        src: List of 1-D English id tensors.
        tar: List of 1-D German id tensors.

    Returns:
        ``(pad_src, pad_tar, src_op, tar_op)``: the padded ``[batch, max_len]``
        tensors and bool masks of the same shapes, True where the id equals
        that language's ``<pad>`` id.
    """
    src_pad_id = eng_tokens["<pad>"]
    tgt_pad_id = ger_tokens["<pad>"]

    pad_src = pad_sequence(src, batch_first=True, padding_value=src_pad_id)
    pad_tar = pad_sequence(tar, batch_first=True, padding_value=tgt_pad_id)

    return pad_src, pad_tar, pad_src == src_pad_id, pad_tar == tgt_pad_id

# DataLoader collate function: numericalize, then pad, a batch of raw pairs.
def collate_fn(batch):
    """Turn a list of raw ``(english, german)`` strings into padded tensors.

    Returns the 4-tuple produced by ``padding``:
    ``(src, tgt, src_pad_mask, tgt_pad_mask)``.
    """
    encoded = [apply_vocab(src, tgr) for src, tgr in batch]
    src_list = [pair[0] for pair in encoded]
    tgr_list = [pair[1] for pair in encoded]
    return padding(src_list, tgr_list)

# Wrap fresh split iterables in DataLoaders. Batches of 32 raw sentence pairs
# are numericalized and padded by collate_fn.
# NOTE(review): train_loader is built without shuffle=True, so unless the
# dataset shuffles internally, batches arrive in the same order every epoch —
# confirm this is intended.
train_iter, valid_iter, test_iter = Multi30k(split=('train', 'valid', 'test'), language_pair=('en', 'de'))
train_loader = DataLoader(train_iter, batch_size = 32, collate_fn=collate_fn)
val_loader = DataLoader(valid_iter, batch_size = 32, collate_fn=collate_fn)
test_loader = DataLoader(test_iter, batch_size = 32, collate_fn=collate_fn)

# Smoke test: pull one training batch and print the padded shapes
# ([batch, longest sequence in that batch]).
print("\nTesting DataLoader...")
src, tgt, src_mask, tgt_mask = next(iter(train_loader))
print(f"Source shape: {src.shape}")
print(f"Target shape: {tgt.shape}")
print(f"Source mask shape: {src_mask.shape}")
print(f"Target mask shape: {tgt_mask.shape}")

Using device: cuda

Testing DataLoader...
Source shape: torch.Size([32, 24])
Target shape: torch.Size([32, 21])
Source mask shape: torch.Size([32, 24])
Target mask shape: torch.Size([32, 21])


In [None]:
'''
Train the custom Transformer with mixed precision (AMP), gradient clipping,
per-epoch validation, LR reduction on plateau, and best-checkpoint saving.
'''
from torch.optim.lr_scheduler import ReduceLROnPlateau

# One vocab_size (the larger vocabulary) serves both languages, because the
# model sizes both embedding tables and its single output layer with it.
vocab_size = max(len(eng_tokens), len(ger_tokens))
model = FullTransformer_Custom(vocab_size, heads = 8, d_model=256, hidden_lay=512, seq_len=100, num_layers=3, dropout=0.2)
model = model.to(device)
# Padding targets are ignored. Because of label smoothing, the loss (and the
# "perplexity" printed from it) is for the smoothed objective, not plain NLL.
criterion = nn.CrossEntropyLoss(
    ignore_index=ger_tokens["<pad>"],
    label_smoothing=0.1
)
optimizer = optim.AdamW(  # Adam with decoupled weight decay
    model.parameters(),
    lr=0.0001,           # Reduced from 0.0003
    betas=(0.9, 0.98),
    eps=1e-9,
    weight_decay=0.01
)
# Halve the LR (down to 1e-6) when validation loss stops improving.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,          # Reduced from 5
    min_lr=1e-6
)
epochs = 30
# Loss scaling for fp16 autocast.
# NOTE(review): torch.cuda.amp.GradScaler / torch.cuda.amp.autocast are
# deprecated in recent PyTorch in favor of torch.amp.GradScaler("cuda") /
# torch.autocast("cuda") — consider migrating.
scaler = torch.cuda.amp.GradScaler()
best_val_loss = float('inf')

for i in range(epochs):
    print("Running Epoch Number", i)
    model.train()
    total_loss = 0
    batch_count = 0
    for src, tgr, src_mask, tgr_mask in train_loader:
        src = src.to(device)
        tgr = tgr.to(device)
        src_mask = src_mask.to(device)
        tgr_mask = tgr_mask.to(device)

        # Teacher forcing: the decoder input drops the last token, and the
        # targets are the same sequence shifted left by one (dropping <bos>).
        tgr_input = tgr[:, :-1]
        tgr_mask_input = tgr_mask[:, :-1]
        tgr_output = tgr[:, 1:]
        optimizer.zero_grad()

        # Mixed-precision forward pass and loss.
        with torch.cuda.amp.autocast():
             output = model(src, tgr_input, src_mask=src_mask, tgt_mask=tgr_mask_input)
             loss = criterion(output.reshape(-1, vocab_size), tgr_output.reshape(-1))
        scaler.scale(loss).backward()
        # Unscale first so the 0.5 max-norm clip applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        batch_count += 1
    # Mean of per-batch losses (not weighted by token count).
    avg_loss = total_loss / batch_count
    print(f"Epoch {i+1}/{epochs}, Loss: {avg_loss:.4f}, Perplexity: {math.exp(avg_loss):.2f}")

    # Validation pass after each epoch; its loss drives the LR scheduler and
    # the best-checkpoint selection below.
    model.eval()
    val_loss = 0
    val_count = 0
    with torch.no_grad():
         for src, tgr, src_mask, tgr_mask in val_loader:
             src = src.to(device)
             tgr = tgr.to(device)
             src_mask = src_mask.to(device)
             tgr_mask = tgr_mask.to(device)
             tgr_input = tgr[:, :-1]
             tgr_output = tgr[:, 1:]
             tgr_mask_input = tgr_mask[:, :-1]

             # Mixed-precision forward pass only (no backward here).
             with torch.cuda.amp.autocast():
                  output = model(src, tgr_input, src_mask=src_mask, tgt_mask=tgr_mask_input)
                  loss = criterion(output.reshape(-1, vocab_size), tgr_output.reshape(-1))

             val_loss += loss.item()
             val_count += 1
    avg_val_loss = val_loss / val_count
    print(f"Epoch {i+1}/{epochs}, Val Loss: {avg_val_loss:.4f}, Perplexity: {math.exp(avg_val_loss):.2f}\n")
    scheduler.step(avg_val_loss)
    # Keep a checkpoint of the epoch with the lowest validation loss.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save({
            'epoch': i + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': avg_loss,
            'val_loss': avg_val_loss,
        }, 'best_transformer_model.pth')
        print(f"✓ Best model saved! Val Loss: {avg_val_loss:.4f}\n")


# Save the last-epoch state as well (in addition to the best checkpoint).
# Save the model after training completes
torch.save({
    'epoch': epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'train_loss': avg_loss,
    'val_loss': avg_val_loss,
}, 'final_transformer_model.pth')
print(f"Training complete!")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Final model saved as 'final_transformer_model.pth'")
print(f"Best model saved as 'best_transformer_model.pth'")
# Colab-only: trigger browser downloads of both checkpoint files.
from google.colab import files

print("Downloading models to your computer...")
files.download('best_transformer_model.pth')
files.download('final_transformer_model.pth')
print("Download complete!")

Running Epoch Number 0




Epoch 1/30, Loss: 5.1561, Perplexity: 173.49
Epoch 1/30, Val Loss: 4.4798, Perplexity: 88.22

✓ Best model saved! Val Loss: 4.4798

Running Epoch Number 1
Epoch 2/30, Loss: 4.3313, Perplexity: 76.04
Epoch 2/30, Val Loss: 4.0865, Perplexity: 59.53

✓ Best model saved! Val Loss: 4.0865

Running Epoch Number 2
Epoch 3/30, Loss: 4.0311, Perplexity: 56.32
Epoch 3/30, Val Loss: 3.8636, Perplexity: 47.64

✓ Best model saved! Val Loss: 3.8636

Running Epoch Number 3
Epoch 4/30, Loss: 3.8284, Perplexity: 45.99
Epoch 4/30, Val Loss: 3.7103, Perplexity: 40.86

✓ Best model saved! Val Loss: 3.7103

Running Epoch Number 4
Epoch 5/30, Loss: 3.6751, Perplexity: 39.45
Epoch 5/30, Val Loss: 3.5921, Perplexity: 36.31

✓ Best model saved! Val Loss: 3.5921

Running Epoch Number 5
Epoch 6/30, Loss: 3.5524, Perplexity: 34.90
Epoch 6/30, Val Loss: 3.5122, Perplexity: 33.52

✓ Best model saved! Val Loss: 3.5122

Running Epoch Number 6
Epoch 7/30, Loss: 3.4498, Perplexity: 31.49
Epoch 7/30, Val Loss: 3.4263, P

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Download complete!


In [None]:
#TESTING THE TRANSFORMER ON TEST DATA

import urllib.request
import gzip
from pathlib import Path
from torch.utils.data import Dataset, DataLoader

# Download and prepare test data
def download_multi30k_test(
    data_dir='/tmp/multi30k_manual',
    base_url='https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/',
):
    """Fetch the Multi30k 2016 Flickr test split, bypassing a corrupted cache.

    Each language's ``.gz`` file is downloaded, decompressed to
    ``<data_dir>/test_<lang>.txt`` and cached there. Later calls read the
    cached text files without touching the network.

    Args:
        data_dir: Cache directory; created (including parents) if missing.
        base_url: Directory URL containing ``test_2016_flickr.{en,de}.gz``.
            Any scheme ``urllib`` supports works, e.g. ``file://`` for a
            local mirror.

    Returns:
        ``(english_lines, german_lines)``: lists of whitespace-stripped lines.
    """
    import shutil

    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)

    filenames = {'en': 'test_2016_flickr.en.gz', 'de': 'test_2016_flickr.de.gz'}

    data = {}
    for lang, filename in filenames.items():
        txt_path = data_dir / f'test_{lang}.txt'

        if not txt_path.exists():
            print(f"Downloading {lang} test data...")
            gz_path = data_dir / filename
            part_path = data_dir / f'test_{lang}.txt.part'
            try:
                urllib.request.urlretrieve(base_url + filename, gz_path)
                # Decompress under a temporary name and rename only on
                # success, so an interrupted run never leaves a truncated
                # cache file that later calls would silently trust.
                with gzip.open(gz_path, 'rb') as f_in, open(part_path, 'wb') as f_out:
                    shutil.copyfileobj(f_in, f_out)
                part_path.replace(txt_path)
            finally:
                # Remove the archive and any partial output, even on failure.
                gz_path.unlink(missing_ok=True)
                part_path.unlink(missing_ok=True)

        with open(txt_path, 'r', encoding='utf-8') as f:
            data[lang] = [line.strip() for line in f]

    return data['en'], data['de']

class SimpleDataset(Dataset):
    """Map-style dataset over two parallel lists of sentences.

    Item ``i`` is the tuple ``(src[i], tgt[i])``. As with ``zip``, the length
    is that of the shorter list.
    """

    def __init__(self, src, tgt):
        # Materialize the pairs once so indexing is a plain list lookup.
        self.data = [*zip(src, tgt)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        pair = self.data[idx]
        return pair

# Download test data
print("Downloading test data...")
test_en, test_de = download_multi30k_test()
print(f"Test set size: {len(test_en)} examples\n")

test_loader = DataLoader(
    SimpleDataset(test_en, test_de),
    batch_size=8,
    collate_fn=collate_fn,
    shuffle=False
)

# Recreate the model with the same hyperparameters as the training cell.
# Dropout is left at its default; it is inactive in eval mode.
print("Recreating model to match checkpoint...")
vocab_size = max(len(eng_tokens), len(ger_tokens))
model = FullTransformer_Custom(
    vocab_size,
    heads=8,
    d_model=256,
    hidden_lay=512,  # must match the training run (hidden_lay=512)
    seq_len=100,
    num_layers=3
)
model = model.to(device)

# Load the best model weights
print("Loading best model weights...")
checkpoint = torch.load('best_transformer_model.pth', map_location=device)
# strict=False is only there to tolerate extra 'causal_mask' entries. These
# are a cached attention mask, not a learned weight, and checkpoints saved by
# the training cell may contain them. Any other mismatch means the
# architecture above is wrong, so fail loudly rather than silently
# evaluating partly random weights.
load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
unexpected_keys = [k for k in load_result.unexpected_keys if not k.endswith('causal_mask')]
if load_result.missing_keys or unexpected_keys:
    raise RuntimeError(
        "Checkpoint does not match the model architecture; "
        f"missing keys: {load_result.missing_keys}, unexpected keys: {unexpected_keys}"
    )
print(f"✓ Model loaded from epoch {checkpoint['epoch']}")
print(f"✓ Best Val Loss: {checkpoint['val_loss']:.4f}\n")

# Evaluate on test set.
# NOTE: `criterion` comes from the training cell (it ignores <pad> and uses
# label smoothing 0.1). These numbers are therefore comparable to the
# validation numbers but are not plain (unsmoothed) perplexity. The average
# is taken over batches, not tokens.
print("Evaluating on test set...")
model.eval()
test_loss = 0
batch_count = 0

with torch.no_grad():
    for src, tgt, src_mask, tgt_mask in test_loader:
        src = src.to(device)
        tgt = tgt.to(device)
        src_mask = src_mask.to(device)
        tgt_mask = tgt_mask.to(device)

        # Same teacher-forcing shift as in training.
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]
        tgt_mask_input = tgt_mask[:, :-1]

        with torch.cuda.amp.autocast():
            output = model(src, tgt_input, src_mask=src_mask, tgt_mask=tgt_mask_input)
            loss = criterion(output.reshape(-1, vocab_size), tgt_output.reshape(-1))

        test_loss += loss.item()
        batch_count += 1

avg_test_loss = test_loss / batch_count
test_perplexity = math.exp(avg_test_loss)

print("\n" + "="*50)
print("TEST RESULTS")
print("="*50)
print(f"Test Loss:       {avg_test_loss:.4f}")
print(f"Test Perplexity: {test_perplexity:.2f}")
print(f"Batches:         {batch_count}")
print("="*50)

Downloading test data...
Test set size: 1000 examples

Recreating model to match checkpoint...
Loading best model weights...
✓ Model loaded from epoch 29
✓ Best Val Loss: 2.9589

Evaluating on test set...

TEST RESULTS
Test Loss:       2.9718
Test Perplexity: 19.53
Batches:         125


In [None]:
def translate_sentence_greedy(sentence, model, eng_tokens, ger_tokens, device, max_len=100):
    """Translate one English sentence to German with greedy decoding.

    The source is lower-cased, tokenized with ``tokenizer_en``, wrapped in
    ``<bos>``/``<eos>``, and padded (or truncated) to ``max_len``. It is
    encoded once. At each step the target prefix, padded to ``max_len``, is
    decoded and the arg-max token at the last filled position is appended.
    Decoding stops at ``<eos>`` or after ``max_len - 1`` steps.

    Args:
        sentence: English input text.
        model: Model exposing ``encoder(src, src_mask=...)`` and
            ``decoder(tgt, enc_output, tgt_mask=..., src_mask=...)``, as
            ``FullTransformer_Custom`` does.
        eng_tokens: Source (English) torchtext vocabulary.
        ger_tokens: Target (German) torchtext vocabulary.
        device: Device to run on.
        max_len: Padded length of source and target. It must not exceed the
            positional-encoding length the model was built with.

    Returns:
        Space-joined German tokens (special tokens and ``<unk>`` dropped), or
        ``"<empty>"`` if nothing remains.
    """
    model.eval()

    # Vocab mappings and special ids.
    eng_stoi = eng_tokens.get_stoi()
    ger_stoi = ger_tokens.get_stoi()
    ger_itos = ger_tokens.get_itos()
    src_pad = eng_stoi["<pad>"]
    tgt_pad = ger_stoi["<pad>"]
    tgt_eos = ger_stoi["<eos>"]

    # Tokenize the source, then pad or truncate it to exactly max_len ids.
    tokens = tokenizer_en(sentence.lower())
    src_indices = [eng_stoi["<bos>"]] + [eng_stoi.get(t, eng_stoi["<unk>"]) for t in tokens] + [eng_stoi["<eos>"]]
    src_indices = src_indices[:max_len] + [src_pad] * (max_len - len(src_indices))

    src = torch.tensor([src_indices]).to(device)
    src_mask = (src == src_pad)  # [1, max_len], True at padding

    tgt_indices = [ger_stoi["<bos>"]]
    enc_output = None

    with torch.no_grad():
        for _ in range(max_len - 1):
            seq_len = len(tgt_indices)

            tgt_padded = tgt_indices + [tgt_pad] * (max_len - seq_len)
            tgt = torch.tensor([tgt_padded]).to(device)
            tgt_mask = (tgt == tgt_pad)  # [1, max_len], True at padding

            with torch.cuda.amp.autocast():
                if enc_output is None:
                    # The source never changes: encode it once and reuse it
                    # on every step (same result as model(src, tgt, ...)).
                    enc_output = model.encoder(src, src_mask=src_mask)
                output = model.decoder(tgt, enc_output, tgt_mask=tgt_mask, src_mask=src_mask)

            # Prediction for the position after the last generated token.
            next_token = output[0, seq_len - 1].argmax().item()

            if next_token == tgt_eos:
                break

            tgt_indices.append(next_token)

    # Convert ids to words, skipping <bos> and any special/unknown tokens.
    translated_tokens = []
    for idx in tgt_indices[1:]:
        if idx == tgt_eos:
            break
        token = ger_itos[idx]
        if token not in ["<bos>", "<eos>", "<pad>", "<unk>"]:
            translated_tokens.append(token)

    return ' '.join(translated_tokens) if translated_tokens else "<empty>"


def translate_sentence_beam(sentence, model, eng_tokens, ger_tokens, device, max_len=100, beam_width=5):
    """Translate one English sentence to German with beam search.

    Source preparation matches ``translate_sentence_greedy``: lower-case,
    tokenize, wrap in ``<bos>``/``<eos>``, pad or truncate to ``max_len``.
    The source is encoded once. At each step every unfinished beam is
    extended by its ``beam_width`` most likely next tokens. Beams ending in
    ``<eos>`` are carried over unchanged. The ``beam_width`` candidates with
    the best length-normalized score (sum of log-probs / sequence length,
    counting ``<bos>``) are kept. Search stops when all kept beams end in
    ``<eos>`` or after ``max_len - 1`` steps.

    Args:
        sentence: English input text.
        model: Model exposing ``encoder``/``decoder`` like
            ``FullTransformer_Custom``.
        eng_tokens: Source (English) torchtext vocabulary.
        ger_tokens: Target (German) torchtext vocabulary.
        device: Device to run on.
        max_len: Padded length of source and target.
        beam_width: Number of hypotheses kept per step.

    Returns:
        Space-joined German tokens of the best beam (special tokens and
        ``<unk>`` dropped), or ``"<empty>"``.
    """
    model.eval()

    tokens = tokenizer_en(sentence.lower())

    eng_stoi = eng_tokens.get_stoi()
    ger_stoi = ger_tokens.get_stoi()
    src_pad = eng_stoi["<pad>"]
    tgt_pad = ger_stoi["<pad>"]
    tgt_eos = ger_stoi["<eos>"]

    # Pad or truncate the source to exactly max_len ids.
    src_indices = [eng_stoi["<bos>"]] + [eng_stoi.get(t, eng_stoi["<unk>"]) for t in tokens] + [eng_stoi["<eos>"]]
    src_indices = src_indices[:max_len] + [src_pad] * (max_len - len(src_indices))

    src = torch.tensor([src_indices]).to(device)
    src_mask = (src == src_pad)  # True at padding

    # Each beam is (token ids, cumulative log-probability).
    beams = [([ger_stoi["<bos>"]], 0.0)]
    enc_output = None

    with torch.no_grad():
        for _ in range(max_len - 1):
            all_candidates = []

            for seq, score in beams:
                if seq[-1] == tgt_eos:
                    # Finished hypothesis: keep it competing unchanged.
                    all_candidates.append((seq, score))
                    continue

                seq_len = len(seq)
                tgt_padded = seq + [tgt_pad] * (max_len - seq_len)
                tgt = torch.tensor([tgt_padded]).to(device)
                tgt_mask = (tgt == tgt_pad)

                with torch.cuda.amp.autocast():
                    if enc_output is None:
                        # The source is the same for every beam and step:
                        # encode it once instead of on every model call.
                        enc_output = model.encoder(src, src_mask=src_mask)
                    output = model.decoder(tgt, enc_output, tgt_mask=tgt_mask, src_mask=src_mask)

                # Log-probabilities for the position after the last real token.
                log_probs = torch.log_softmax(output[0, seq_len - 1], dim=-1)
                top_probs, top_indices = torch.topk(log_probs, beam_width)

                for prob, idx in zip(top_probs, top_indices):
                    all_candidates.append((seq + [idx.item()], score + prob.item()))

            # Keep the best beam_width hypotheses by length-normalized score.
            beams = sorted(all_candidates, key=lambda x: x[1] / len(x[0]), reverse=True)[:beam_width]

            if all(seq[-1] == tgt_eos for seq, _ in beams):
                break

    best_seq = beams[0][0]
    ger_itos = ger_tokens.get_itos()

    # Convert ids to words, skipping <bos> and any special/unknown tokens.
    translated_tokens = []
    for idx in best_seq[1:]:
        if idx == tgt_eos:
            break
        token = ger_itos[idx]
        if token not in ["<bos>", "<eos>", "<pad>", "<unk>"]:
            translated_tokens.append(token)

    return ' '.join(translated_tokens) if translated_tokens else "<empty>"


# Try both decoders on a few hand-written sentences.
test_sentences = [
    "a dog is running in the park",
    "the cat is sleeping on the bed",
    "two people are walking together",
    "children are playing with a ball"
]

decoding_runs = [
    ("GREEDY DECODING",
     lambda text: translate_sentence_greedy(text, model, eng_tokens, ger_tokens, device)),
    ("BEAM SEARCH (width=5)",
     lambda text: translate_sentence_beam(text, model, eng_tokens, ger_tokens, device, beam_width=5)),
]

for run_number, (title, decode_fn) in enumerate(decoding_runs):
    # Every section after the first starts with a blank line.
    print(("\n" if run_number else "") + "=" * 60)
    print(title)
    print("=" * 60)

    for sentence in test_sentences:
        translation = decode_fn(sentence)
        print(f"\nEnglish:  {sentence}")
        print(f"German:   {translation}")

GREEDY DECODING

English:  a dog is running in the park
German:   Ein Hund läuft im Park .

English:  the cat is sleeping on the bed
German:   Eine Katze schläft auf dem Bett .

English:  two people are walking together
German:   Zwei Personen gehen zusammen spazieren .

English:  children are playing with a ball
German:   Kinder spielen mit einem Ball .

BEAM SEARCH (width=5)

English:  a dog is running in the park
German:   Ein Hund läuft in einem Park .

English:  the cat is sleeping on the bed
German:   Eine Katze schläft auf dem Bett .

English:  two people are walking together
German:   Zwei Personen gehen zusammen spazieren .

English:  children are playing with a ball
German:   Kinder spielen mit einem Ball .
