In [None]:
!pip install tokenizers datasets

In [4]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

from pathlib import Path
from tqdm import tqdm
import math

In [5]:
def get_config():
    """Return a fresh dict holding every tunable setting for training and inference.

    Key insertion order is part of the printed output, so it is kept stable.
    """
    return dict(
        batch_size=32,                        # training DataLoader batch size
        epochs=4,                             # full passes over the training split
        lr=1e-4,                              # Adam learning rate
        seq_len=128,                          # fixed tensor length, special tokens included
        d_model=384,                          # embedding / hidden width
        lang_src='en',                        # source language code (opus100 pair, first half)
        lang_tgt='fr',                        # target language code (opus100 pair, second half)
        model_folder='weights',               # directory for checkpoints
        model_basename='tmodel_',             # checkpoint filename prefix
        preload=None,                         # NOTE(review): not read anywhere in this notebook
        tokenizer_file='tokenizer_{0}.json',  # {0} is replaced by the language code
        experiment_name='runs/tmodel',
    )

def get_weight_file_path(config, epoch):
    """Return ``<model_folder>/<model_basename><epoch>.pt`` (relative to the CWD) as a string."""
    folder = config['model_folder']
    filename = f"{config['model_basename']}{epoch}.pt"
    return str(Path('.').joinpath(folder, filename))

# Materialise the configuration once; later cells read this module-level `config`.
config = get_config()
print(f"Configuration: {config}")

Configuration: {'batch_size': 32, 'epochs': 4, 'lr': 0.0001, 'seq_len': 128, 'd_model': 384, 'lang_src': 'en', 'lang_tgt': 'fr', 'model_folder': 'weights', 'model_basename': 'tmodel_', 'preload': None, 'tokenizer_file': 'tokenizer_{0}.json', 'experiment_name': 'runs/tmodel'}


In [6]:
class InputEmbeddings(nn.Module):
    """Look up token embeddings and scale them by sqrt(d_model).

    Args:
        d_model: embedding width.
        vocab_size: number of rows in the lookup table.
    """

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model, self.vocab_size = d_model, vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # x holds integer token ids; the output gains a trailing d_model axis.
        scale = math.sqrt(self.d_model)
        return self.embedding(x) * scale

In [7]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to an embedded sequence, then apply dropout.

    The table is precomputed for ``seq_len`` positions and stored as the buffer ``pe``
    with shape (1, seq_len, d_model). Even feature columns hold sin, odd columns cos.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float):
        super().__init__()
        self.seq_len = seq_len
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

        positions = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        # 1 / 10000^(2i/d_model), computed in log space for stability
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq

        table = torch.zeros(seq_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        # Buffer: saved in state_dict and moved with .to(), but never trained.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        # Only the first x.shape[1] positions are needed for a sequence of that length.
        length = x.shape[1]
        x = x + self.pe[:, :length, :].requires_grad_(False)
        return self.dropout(x)

In [8]:
class LayerNormalization(nn.Module):
    """Normalise over the last axis with a single learnable scale and shift.

    ``alpha`` and ``bias`` are one-element parameters shared by all features, not
    per-feature vectors. The spread is ``Tensor.std`` (unbiased), and ``eps`` is added
    to it rather than to the variance.
    """

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        scale = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centered / scale + self.bias

In [9]:
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(d_model→d_ff) → ReLU → Dropout → Linear(d_ff→d_model)."""

    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        hidden = torch.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

In [10]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split across ``h`` heads of width ``d_model // h``.

    ``forward`` expects 3-D inputs laid out as (batch, length, d_model). It stores the
    last attention weights on ``self.attention_scores`` as (batch, h, len_q, len_k).
    The boolean ``mask`` keeps positions where it is True; False positions are filled
    with the dtype's minimum before the softmax.
    """

    def __init__(self, d_model: int, h: int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, "d_model is not divisible by h"
        self.d_k = d_model // h
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.wo = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(q, k, v, mask, dropout):
        """Return (weights @ v, weights) for per-head tensors q, k, v."""
        scale = math.sqrt(q.shape[-1])
        scores = (q @ k.transpose(-2, -1)) / scale
        if mask is not None:
            # finfo.min (not -inf) keeps fully-masked rows finite under fp16 autocast.
            scores.masked_fill_(~mask, torch.finfo(scores.dtype).min)
        weights = scores.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return (weights @ v), weights

    def _split_heads(self, t):
        # (B, L, d_model) -> (B, h, L, d_k)
        return t.view(t.shape[0], t.shape[1], self.h, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask):
        query = self._split_heads(self.wq(q))
        key = self._split_heads(self.wk(k))
        value = self._split_heads(self.wv(v))

        context, self.attention_scores = MultiHeadAttention.attention(query, key, value, mask, self.dropout)

        # (B, h, L, d_k) -> (B, L, h * d_k) before the output projection
        batch = context.shape[0]
        merged = context.transpose(1, 2).contiguous().view(batch, -1, self.h * self.d_k)
        return self.wo(merged)

In [11]:
class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()

    def forward(self, x, sublayer):
        normed = self.norm(x)
        return x + self.dropout(sublayer(normed))

In [12]:
class EncoderBlock(nn.Module):
    """One encoder layer: residual self-attention followed by a residual feed-forward."""

    def __init__(self, self_attention: MultiHeadAttention, self_feed_forward: FeedForward, dropout: float):
        super().__init__()
        self.attention_block = self_attention
        self.feed_forward_block = self_feed_forward
        self.residuals = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        attn_res, ffn_res = self.residuals
        # Self-attention: queries, keys and values all come from the same sequence.
        x = attn_res(x, lambda y: self.attention_block(y, y, y, src_mask))
        return ffn_res(x, self.feed_forward_block)

In [13]:
class Encoder(nn.Module):
    """Run the input through each layer in ``layers`` in order, then apply a final normalisation."""

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, mask):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)

In [14]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention over the encoder output, feed-forward.

    Each sub-layer is wrapped in its own ResidualConnection.
    """

    def __init__(self, self_attention: MultiHeadAttention, cross_attention: MultiHeadAttention, self_feed_forward: FeedForward, dropout: float):
        super().__init__()
        self.attention_block = self_attention
        self.cross_attention = cross_attention
        self.feed_forward_block = self_feed_forward
        self.residuals = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

    def forward(self, x, encoder_output, tgt_mask, src_mask):
        self_res, cross_res, ffn_res = self.residuals
        x = self_res(x, lambda y: self.attention_block(y, y, y, tgt_mask))
        # Queries from the decoder; keys/values from the encoder, masked by the source mask.
        x = cross_res(x, lambda y: self.cross_attention(y, encoder_output, encoder_output, src_mask))
        return ffn_res(x, self.feed_forward_block)

In [15]:
class Decoder(nn.Module):
    """Run the target through each decoder layer (with encoder output and masks), then normalise."""

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, encoder_output, tgt_mask, src_mask):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_output, tgt_mask, src_mask)
        return self.norm(hidden)

In [16]:
class ProjectionLayer(nn.Module):
    """Map decoder states to log-probabilities over the target vocabulary.

    Note: train_model feeds this output into nn.CrossEntropyLoss, which applies
    log_softmax again. That is mathematically a no-op on log-probabilities, so it
    only costs extra compute.
    """

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        logits = self.proj(x)
        return logits.log_softmax(dim=-1)

In [17]:
class Transformer(nn.Module):
    """Encoder-decoder container exposing separate encode / decode / project steps.

    Submodules are registered in a fixed order (encoder, decoder, src_embed, tgt_embed,
    src_pos, tgt_pos, projection). That order determines parameters() order and the
    state_dict layout.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, projection_layer: ProjectionLayer, tgt_pos: PositionalEncoding):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection = projection_layer

    def encode(self, src, src_mask):
        embedded = self.src_pos(self.src_embed(src))
        return self.encoder(embedded, src_mask)

    def decode(self, tgt, encoder_output, src_mask, tgt_mask):
        embedded = self.tgt_pos(self.tgt_embed(tgt))
        # The Decoder takes (tgt_mask, src_mask), the reverse of this method's order.
        return self.decoder(embedded, encoder_output, tgt_mask, src_mask)

    def project(self, x):
        return self.projection(x)

In [18]:
def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int = 512, N: int = 6, h: int = 4, dropout: float = 0.1, d_ff: int = 1024):
    """Assemble a Transformer with N encoder and N decoder blocks.

    Each block uses ``h`` attention heads and a feed-forward width of ``d_ff``. All
    parameters with more than one dimension are re-initialised with Xavier-uniform.
    Module construction order is fixed (embeddings, encoder blocks, decoder blocks,
    projection), so RNG-dependent default initialisation is reproducible.
    """
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    encoder_blocks = [
        EncoderBlock(
            MultiHeadAttention(d_model, h, dropout),
            FeedForward(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]
    decoder_blocks = [
        DecoderBlock(
            MultiHeadAttention(d_model, h, dropout),  # masked self-attention
            MultiHeadAttention(d_model, h, dropout),  # cross-attention
            FeedForward(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]

    transformer = Transformer(
        encoder=Encoder(nn.ModuleList(encoder_blocks)),
        decoder=Decoder(nn.ModuleList(decoder_blocks)),
        src_embed=src_embed,
        tgt_embed=tgt_embed,
        src_pos=src_pos,
        projection_layer=ProjectionLayer(d_model, tgt_vocab_size),
        tgt_pos=tgt_pos,
    )

    for param in transformer.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)

    return transformer

In [19]:
def causal_mask(size, device=None):
    """Return a (1, size, size) bool look-ahead mask; entry (i, j) is True iff j <= i.

    If ``device`` is given the mask is moved there; otherwise it stays on CPU.
    """
    mask = torch.tril(torch.ones((size, size), dtype=torch.bool)).unsqueeze(0)
    if device is not None:
        mask = mask.to(device)
    return mask

# -----------------------------------------------------------------------------
# 3) DATASET
# -----------------------------------------------------------------------------
class BilingualDataset(Dataset):
    """Map-style dataset turning raw translation pairs into fixed-length training tensors.

    Each item is a dict with:
        encoder_input: [SOS] + src_ids + [EOS] + [PAD]...   LongTensor (seq_len,)
        decoder_input: [SOS] + tgt_ids + [PAD]...           LongTensor (seq_len,)
        label:         tgt_ids + [EOS] + [PAD]...           LongTensor (seq_len,)
        encoder_mask:  True at non-pad source positions     bool (1, 1, seq_len)
        decoder_mask:  non-pad AND causal                   bool (1, seq_len, seq_len)
        src_text / tgt_text: the raw strings

    Source-side sequences use the source tokenizer's special-token ids. Decoder
    input and labels use the target tokenizer's ids, because they are
    target-vocabulary sequences.

    Raises:
        ValueError: if a pair does not fit in ``seq_len`` once special tokens are added.
    """

    def __init__(self, raw_split, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        self.raw = raw_split
        self.tok_src = tokenizer_src
        self.tok_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.seq_len = seq_len
        # Source-vocabulary special ids (original attribute names kept for callers).
        self.sos_id = tokenizer_src.token_to_id('[SOS]')
        self.eos_id = tokenizer_src.token_to_id('[EOS]')
        self.pad_id = tokenizer_src.token_to_id('[PAD]')
        # Target-vocabulary special ids. Previously the source ids were reused here,
        # which is only correct when both vocabularies assign identical ids.
        self.tgt_sos_id = tokenizer_tgt.token_to_id('[SOS]')
        self.tgt_eos_id = tokenizer_tgt.token_to_id('[EOS]')
        self.tgt_pad_id = tokenizer_tgt.token_to_id('[PAD]')

    def __len__(self):
        return len(self.raw)

    def __getitem__(self, idx):
        item = self.raw[idx]
        src_txt = item['translation'][self.src_lang]
        tgt_txt = item['translation'][self.tgt_lang]

        src_ids = self.tok_src.encode(src_txt).ids
        tgt_ids = self.tok_tgt.encode(tgt_txt).ids

        # The encoder side needs room for SOS+EOS; each decoder-side sequence adds one special.
        if len(src_ids) + 2 > self.seq_len or len(tgt_ids) + 1 > self.seq_len:
            raise ValueError("Sentence too long")

        # encoder_input: [SOS] + src_ids + [EOS] + PAD...
        enc = [self.sos_id] + src_ids + [self.eos_id]
        enc += [self.pad_id] * (self.seq_len - len(enc))
        enc = torch.tensor(enc, dtype=torch.long)

        # decoder_input: [SOS] + tgt_ids + PAD...
        dec = [self.tgt_sos_id] + tgt_ids
        dec += [self.tgt_pad_id] * (self.seq_len - len(dec))
        dec = torch.tensor(dec, dtype=torch.long)

        # label: tgt_ids + [EOS] + PAD... (decoder_input shifted left by one)
        lbl = tgt_ids + [self.tgt_eos_id]
        lbl += [self.tgt_pad_id] * (self.seq_len - len(lbl))
        lbl = torch.tensor(lbl, dtype=torch.long)

        # Boolean masks; True means "attend to this position".
        enc_mask = (enc != self.pad_id).unsqueeze(0).unsqueeze(0)      # (1, 1, seq_len)
        dec_mask = (dec != self.tgt_pad_id).unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len)
        dec_mask = dec_mask & causal_mask(self.seq_len)                # (1, seq_len, seq_len)

        return {
            'encoder_input': enc,
            'decoder_input': dec,
            'label':         lbl,
            'encoder_mask':  enc_mask.bool(),
            'decoder_mask':  dec_mask.bool(),
            'src_text':      src_txt,
            'tgt_text':      tgt_txt
        }

In [20]:
def get_all_sentences(ds, lang):
    """Lazily yield the ``lang`` side of every ``translation`` record in ``ds``."""
    for record in ds:
        pair = record['translation']
        yield pair[lang]

def get_or_build_tokenizer(config, ds, lang):
    """Return a word-level tokenizer for ``lang``, loading it from disk when cached.

    The cache path comes from ``config['tokenizer_file']`` formatted with ``lang``. On a
    cache miss, a whitespace-split WordLevel tokenizer is trained on ``ds`` with the
    specials [UNK], [PAD], [SOS], [EOS] and min_frequency=2, then saved to that path.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

def get_ds(config):
    """Build train/validation DataLoaders and tokenizers for the opus100 pair in ``config``.

    Steps:
        1. Load the opus100 ``train`` split.
        2. Build (or load cached) word-level tokenizers for both languages.
        3. Drop pairs that would not fit in ``config['seq_len']`` once special tokens are added.
        4. Split the rest 90/10 into train/validation.

    If ``config`` has a non-None ``'seed'`` entry, the split is drawn from a dedicated
    ``torch.Generator`` seeded with it, so the same pairs land in validation on every
    run. Without it, the global RNG is used (previous behaviour).

    NOTE(review): tokenizers are trained on the whole corpus before the split, so
    validation sentences contribute to the vocabulary.

    Returns:
        (train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt). The validation
        loader uses batch_size=1, which run_validation asserts.
    """
    print("Loading dataset...")
    ds_raw = load_dataset('opus100', f"{config['lang_src']}-{config['lang_tgt']}", split='train')

    print("Building tokenizers...")
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

    print("Filtering dataset...")
    def filter_long_sentences(example):
        # Same limits BilingualDataset enforces: SOS+EOS on source, one special on target.
        src_ids = tokenizer_src.encode(example['translation'][config['lang_src']]).ids
        tgt_ids = tokenizer_tgt.encode(example['translation'][config['lang_tgt']]).ids
        return len(src_ids) + 2 <= config['seq_len'] and len(tgt_ids) + 1 <= config['seq_len']

    ds_raw = ds_raw.filter(filter_long_sentences)
    print(f"Dataset size after filtering: {len(ds_raw)}")

    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size

    # Only pass a generator when a seed is configured, so the default path is unchanged.
    split_kwargs = {}
    seed = config.get('seed')
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size], **split_kwargs)

    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

In [21]:
def greedy(model, src, src_mask, tokenizer_tgt, max_len, device):
    """Greedy-decode a single source sentence.

    Args:
        model: a Seq2SeqModel (optionally wrapped in nn.DataParallel) exposing
            ``.transformer`` with encode / decode / project.
        src: source token ids for ONE sentence. Batch size 1 is required, because the
            stop check uses ``.item()``.
        src_mask: encoder padding mask matching ``src``.
        tokenizer_tgt: supplies the [SOS] / [EOS] ids.
        max_len: maximum length of the returned sequence, including [SOS].
        device: device on which generated ids are created.

    Returns:
        LongTensor of shape (1, L), L <= max_len. It starts with [SOS] and ends with
        [EOS] if one was produced.
    """
    # Unwrap DataParallel
    real = model.module if hasattr(model, 'module') else model
    transformer = real.transformer

    sos = tokenizer_tgt.token_to_id('[SOS]')
    eos = tokenizer_tgt.token_to_id('[EOS]')

    # The source never changes between steps, so encode it exactly once. Previously the
    # full model was called each step, which re-ran the encoder max_len times.
    enc_out = transformer.encode(src, src_mask)

    ys = torch.full((1, 1), sos, dtype=torch.long, device=device)

    for _ in range(max_len - 1):
        dec_mask = causal_mask(ys.size(1), device)                  # (1, L, L)
        dec_out = transformer.decode(ys, enc_out, src_mask, dec_mask)
        # Only the newest position is needed to choose the next token.
        log_probs = transformer.project(dec_out[:, -1])             # (1, V)
        nxt = log_probs.argmax(-1, keepdim=True)                    # (1, 1)
        ys = torch.cat([ys, nxt], dim=1)
        if nxt.item() == eos:
            break

    return ys

def run_validation(model, val_loader, tokenizer_tgt, max_len, device, print_fn, num_examples=2):
    """Greedy-decode up to ``num_examples`` validation pairs and report SRC / REFERENCE / PREDICTED.

    ``val_loader`` must use batch_size=1 and yield BilingualDataset items (it reads
    ``encoder_input``, ``encoder_mask``, ``src_text``, ``tgt_text``). The model is
    switched to eval mode and left there. With ``num_examples <= 0`` nothing is decoded.
    """
    model.eval()
    seen = 0
    if num_examples > 0:
        with torch.no_grad():
            for batch in val_loader:
                src      = batch['encoder_input'].to(device).long()
                src_mask = batch['encoder_mask'].to(device)
                assert src.size(0) == 1, "Please use batch_size=1 for validation"

                gen = greedy(model, src, src_mask, tokenizer_tgt, max_len, device)
                # src holds *source*-vocabulary ids, which the target tokenizer cannot
                # decode; the dataset already carries the raw source string.
                src_txt  = batch['src_text'][0]
                ref_txt  = batch['tgt_text'][0]
                pred_txt = tokenizer_tgt.decode(gen.squeeze(0).tolist())

                print_fn('-'*80)
                print_fn(f"SRC:       {src_txt}")
                print_fn(f"REFERENCE: {ref_txt}")
                print_fn(f"PREDICTED: {pred_txt}")

                seen += 1
                if seen >= num_examples:
                    break
    print_fn("Validation complete.")

In [27]:
class Seq2SeqModel(nn.Module):
    """Thin nn.Module wrapper turning a Transformer's encode → decode → project into one forward().

    Having a single forward() is what lets nn.DataParallel split batches across GPUs.
    """

    def __init__(self, transformer):
        super().__init__()
        self.transformer = transformer

    def forward(self, src, tgt, src_mask, tgt_mask):
        t = self.transformer
        memory = t.encode(src, src_mask)
        return t.project(t.decode(tgt, memory, src_mask, tgt_mask))

In [None]:
from torch.cuda.amp import autocast, GradScaler
       
def train_model(config):
    """Train the translation Transformer described by ``config`` and return it.

    Each epoch does three things:
      - one pass over the training loader, using mixed precision when CUDA is
        available and optional gradient accumulation via ``config['accum_steps']``
        (default 1);
      - greedy decoding of 3 validation examples;
      - saving a checkpoint to ``get_weight_file_path(config, epoch)``.

    Returns:
        The trained Seq2SeqModel, wrapped in nn.DataParallel when more than one GPU is visible.
    """
    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # torch.cuda.amp autocast / GradScaler are only meaningful on a CUDA device.
    use_amp = device.type == 'cuda'
    print(f"Using device: {device}")
    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

    # Data
    print("Loading data...")
    train_dl, val_dl, tok_src, tok_tgt = get_ds(config)
    tgt_vocab_size = tok_tgt.get_vocab_size()

    # Model build
    print("Building model...")
    core = build_transformer(
        tok_src.get_vocab_size(),
        tgt_vocab_size,
        config['seq_len'], config['seq_len'],
        config['d_model']
    )
    model = Seq2SeqModel(core)

    # Multi-GPU
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)

    # .to(device) instead of .cuda(): the original crashed on CPU-only machines.
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
    # Labels are target-vocabulary ids, so ignore the *target* tokenizer's pad id.
    loss_fn = nn.CrossEntropyLoss(
        ignore_index=tok_tgt.token_to_id("[PAD]"),
        label_smoothing=0.1
    ).to(device)

    scaler      = GradScaler(enabled=use_amp)
    accum_steps = config.get('accum_steps', 1)
    num_batches = len(train_dl)

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print("Starting training...")

    for epoch in range(1, config['epochs'] + 1):
        model.train()
        total_loss = 0.0
        loop = tqdm(train_dl, desc=f"Epoch {epoch}/{config['epochs']}",
                    mininterval=1.0, miniters=100)

        optimizer.zero_grad()
        for batch_idx, batch in enumerate(loop):
            # Move inputs to the model's device. Without DataParallel (0 or 1 GPU) the
            # model cannot consume CPU tensors. With DataParallel, inputs on the primary
            # device are scattered to the replicas as usual.
            src       = batch['encoder_input'].to(device).long()
            tgt       = batch['decoder_input'].to(device).long()
            src_mask  = batch['encoder_mask'].to(device)
            tgt_mask  = batch['decoder_mask'].to(device)
            labels    = batch['label'].to(device).long()

            # Mixed-precision forward (a no-op context on CPU)
            with autocast(enabled=use_amp):
                logits = model(src, tgt, src_mask, tgt_mask)  # (B, T, V)
                loss   = loss_fn(
                    logits.view(-1, tgt_vocab_size),
                    labels.view(-1)
                ) / accum_steps

            # Backward + step (scaled). Also step on the last batch so a partial
            # accumulation window is not silently discarded at the end of the epoch.
            scaler.scale(loss).backward()
            if (batch_idx + 1) % accum_steps == 0 or (batch_idx + 1) == num_batches:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            batch_loss = loss.item() * accum_steps
            total_loss += batch_loss
            loop.set_postfix({
                "loss":     f"{batch_loss:.3f}",
                "avg_loss": f"{(total_loss / (batch_idx + 1)):.3f}"
            })

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch} done — avg loss: {avg_loss:.4f}")

        # Validation (greedy() unwraps DataParallel and runs on a single device)
        print("Running validation...")
        run_validation(model, val_dl, tok_tgt, config['seq_len'], device, print, num_examples=3)

        # Checkpoint (unwrap real model if DataParallel)
        real = model.module if hasattr(model, 'module') else model
        ckpt = {
            'epoch': epoch,
            'model_state_dict': real.transformer.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'avg_loss': avg_loss
        }
        fn = get_weight_file_path(config, epoch)
        torch.save(ckpt, fn)
        print(f"Saved checkpoint: {fn}")

    print("Training complete!")
    return model



# A Jupyter kernel runs cells with __name__ == '__main__', so executing this cell trains.
if __name__ == '__main__':
    model = train_model(config=config)

In [25]:
# --- Inference setup: rebuild the architecture and restore trained weights ---
# NOTE(review): absolute Kaggle paths; adjust for other environments. The
# architecture arguments below must match the ones used when the checkpoint was trained.
CHECKPOINT = "/kaggle/usr/lib/transformer_from_scratch/weights/tmodel_4.pt"
SRC_TOKENIZER = "/kaggle/usr/lib/transformer_from_scratch/tokenizer_en.json"
TGT_TOKENIZER = "/kaggle/usr/lib/transformer_from_scratch/tokenizer_fr.json"
SEQ_LEN = 128
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load tokenizers ---
tokenizer_src = Tokenizer.from_file(SRC_TOKENIZER)
tokenizer_tgt = Tokenizer.from_file(TGT_TOKENIZER)

# --- Build model & load weights ---
model = build_transformer(
    src_vocab_size=tokenizer_src.get_vocab_size(),
    tgt_vocab_size=tokenizer_tgt.get_vocab_size(),
    src_seq_len=SEQ_LEN, tgt_seq_len=SEQ_LEN,
    d_model=384, N=6, h=4, dropout=0.1, d_ff=1024
).to(DEVICE)
# weights_only=True refuses to unpickle arbitrary Python objects from the file. The
# checkpoint written by train_model holds only tensors, dicts, lists and numbers.
ckpt = torch.load(CHECKPOINT, map_location=DEVICE, weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

Transformer(
  (encoder): Encoder(
    (layers): ModuleList(
      (0-5): 6 x EncoderBlock(
        (attention_block): MultiHeadAttention(
          (wq): Linear(in_features=384, out_features=384, bias=True)
          (wk): Linear(in_features=384, out_features=384, bias=True)
          (wv): Linear(in_features=384, out_features=384, bias=True)
          (wo): Linear(in_features=384, out_features=384, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward_block): FeedForward(
          (dropout): Dropout(p=0.1, inplace=False)
          (linear1): Linear(in_features=384, out_features=1024, bias=True)
          (linear2): Linear(in_features=1024, out_features=384, bias=True)
        )
        (residuals): ModuleList(
          (0-1): 2 x ResidualConnection(
            (dropout): Dropout(p=0.1, inplace=False)
            (norm): LayerNormalization()
          )
        )
      )
    )
    (norm): LayerNormalization()
  )
  (decoder): Decoder(
    (la

In [29]:
# greedy() expects a Seq2SeqModel (it reads `.transformer`). The guard keeps this cell
# idempotent: re-running it must not nest a second wrapper around the first.
if not isinstance(model, Seq2SeqModel):
    model = Seq2SeqModel(model)

In [33]:
def translate(text: str) -> str:
    """Translate one English sentence to French with greedy decoding.

    Uses the module-level ``tokenizer_src``, ``tokenizer_tgt``, ``model``, ``SEQ_LEN``
    and ``DEVICE``. Inputs longer than ``SEQ_LEN - 2`` tokens are truncated *before*
    [SOS]/[EOS] are added, so the encoder always sees a terminating [EOS], as it did
    for every training example.
    """
    sos_id = tokenizer_src.token_to_id("[SOS]")
    eos_id = tokenizer_src.token_to_id("[EOS]")
    pad_id = tokenizer_src.token_to_id("[PAD]")

    # 1) Tokenize, truncating the content (never the specials) to fit SEQ_LEN
    content = tokenizer_src.encode(text).ids[: max(SEQ_LEN - 2, 0)]
    ids = [sos_id] + content + [eos_id]
    # 2) Pad to SEQ_LEN
    ids += [pad_id] * (SEQ_LEN - len(ids))
    src = torch.tensor([ids], dtype=torch.long, device=DEVICE)

    # 3) Encoder mask, (1, 1, 1, SEQ_LEN): True on real tokens
    src_mask = (src != pad_id).unsqueeze(1).unsqueeze(2)

    # 4) Greedy decode
    with torch.no_grad():
        pred_ids = greedy(model, src, src_mask, tokenizer_tgt, SEQ_LEN, DEVICE)

    # 5) Strip at EOS and decode
    out = pred_ids.squeeze(0).tolist()
    tgt_eos_id = tokenizer_tgt.token_to_id("[EOS]")
    if tgt_eos_id in out:
        out = out[: out.index(tgt_eos_id)]
    return tokenizer_tgt.decode(out).strip()

# ─── Try it out ───────────────────────────────────────────────────────────────
# Interactive demo: translate each line typed by the user until they enter "exit".
while True:
    try:
        inp = input("Enter your english text")
    except (EOFError, NotImplementedError):
        # No stdin: end of input, or a headless Jupyter frontend, which raises
        # StdinNotImplementedError (a NotImplementedError subclass). Stop instead of crashing.
        break
    inp = inp.strip()
    if inp.lower() == 'exit':
        break
    if not inp:
        continue  # nothing to translate
    fr = translate(inp)
    print("EN →", inp)
    print("FR →", fr, "\n")

Enter your english text France is a beautiful place to visit. I currently reside in India


EN → France is a beautiful place to visit. I currently reside in India
FR → La France est un endroit magnifique pour visiter . Je vis en Inde 



Enter your english text How are you today


EN → How are you today
FR → Comment tu es aujourd ' hui 



Enter your english text Merci


EN → Merci
FR → Merci . 



Enter your english text Bonjur


EN → Bonjur
FR →  



Enter your english text How many people died yesterday


EN → How many people died yesterday
FR → Combien de gens sont morts hier 



Enter your english text My name is Aniketh Reddy


EN → My name is Aniketh Reddy
FR → Mon nom est 



Enter your english text exit
