# Imports for the entire model and training loop

In [None]:
import torch
import torch.nn as nn
import math
from torch.utils.data import Dataset, DataLoader, random_split

from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
import sys
from pathlib import Path
import warnings
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

warnings.filterwarnings("ignore")

# Model Code

The following is the code for the treansformer itself, including the position encoding, input embedding, Encoder and Decoder Block, LayerNormalization layer, Feed Forward Network, Residual Connections for both Encoders and Decoder, output layer and the Transformer

In [2]:
class InputEmbedding(nn.Module):
    def __init__(self, d_model: int, vocab_size: int, ):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        return self.embedding(x)*math.sqrt(self.d_model)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        # (seq_len, d_model)
        pe = torch.zeros(self.seq_len, d_model)
        # (seq_len, 1)
        position = torch.arange(0, self.seq_len, dtype=torch.float).unsqueeze(1)
        # 2i*log(10000)/d_model, div_term = (d_model,)
        div_term = torch.exp(torch.arange(0, d_model, 2).float()* 
                            (-math.log(10000)/self.d_model))
        # apply trig, shape (seq_len, d_model)
        pe[:,0::2] = torch.sin(position*div_term)
        pe[:,1::2] = torch.cos(position*div_term)
        # [1, seq_len, d_model]
        pe = pe.unsqueeze(0)

        self.register_buffer('pe', pe)

    @torch.no_grad()
    def forward(self, x):
        x = x + self.pe[:, :x.shape[1], :] # for different broadcasting purposes
        return self.dropout(x)
    
class LayerNorm(nn.Module):
    def __init__(self, features:int, eps: float = 10**-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
    
    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.gamma * (x-mean)/(std + self.eps) + self.beta

class FeedForwardNetwork(nn.Module):
    def __init__(self, d_model: int, d_ff:int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.fnn1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.fnn2 = nn.Linear(d_ff, d_model)
    
    def forward(self, x):
        return self.fnn2(self.dropout(torch.relu(self.fnn1(x))))
    
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, h:int, dropout:float):
        super().__init__()
        self.d_model = d_model
        self.h = h
        self.dropout = dropout
        assert d_model%h == 0,  "d_model not divisible by h"

        self.d_k = d_model//h
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        
        self.Wo = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
    
    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        d_k = query.shape[-1]

        # (batch, seq_len, d_k) -> (batch, seq_len, seq_len)
        attention_scores = ((query @ key.transpose(-2,-1))/math.sqrt(d_k))
        if mask is not None:
            attention_scores.masked_fill_(mask==0, -1e9)
        attention_scores = attention_scores.softmax(dim=-1) # batch, h, seq_len, seq_len
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        
        return (attention_scores @ value), attention_scores


    def forward(self, q, k, v, mask):
        # batch, seq len, d_model
        query = self.Wq(q)
        value = self.Wv(v)
        key = self.Wk(k)
        # batch, seq_len, h, d_k -> batch, h, seq_len, d_k
        query = query.view(query.shape[0], -1, self.h, self.d_k).transpose(1,2)
        key = key.view(key.shape[0], -1, self.h, self.d_k).transpose(1,2)
        # key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).permute([0,2,1,3])
        value = value.view(value.shape[0], -1, self.h, self.d_k).transpose(1,2)

        # (batch, h, seq_len, d_k) 
        x, self.attention_scores = MultiHeadAttention.attention(query, key, value, mask, self.dropout)

        x = x.transpose(1,2)
        x = x.contiguous().view(x.shape[0], -1, self.h*self.d_k)

        return self.Wo(x)

class ResidualConnections(nn.Module):
    def __init__(self, features:int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNorm(features)
    
    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))
    
class EncoderBlock(nn.Module):
    def __init__(self, features, self_attention_block: MultiHeadAttention, feed_forward_block: FeedForwardNetwork,
                 dropout: float, ):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_foward_network = feed_forward_block
        self.dropout = dropout
        self.residual_connections = nn.ModuleList([ResidualConnections(features, dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        # Residual_connections stores residual blocks, 
        # hence following means ResidualConnections.forward(x, sublayer)
        x   = self.residual_connections[0](x, 
            lambda x: self.self_attention_block(x, x, x, src_mask))
        # residual_connections__Call__() expects only one input function, hence we do this
        x = self.residual_connections[1](x, self.feed_foward_network)
        return x
    
class Encoder(nn.Module):
    def __init__(self, features, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNorm(features)
    
    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

class DecoderBlock(nn.Module):
    def __init__(self, features:int, self_attention_block: MultiHeadAttention, cross_attention_block: MultiHeadAttention, 
                feed_forward_block: FeedForwardNetwork, dropout:float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_network = feed_forward_block
        self.dropout = dropout
        
        self.residual_connections = nn.ModuleList([ResidualConnections(features, dropout) for _ in range(3)])
    
    def forward(self, x, e_x, target_mask, src_mask):
        x = self.residual_connections[0](x, 
            lambda x: self.self_attention_block(x, x, x, target_mask))
        
        x = self.residual_connections[1](x, 
            lambda x: self.cross_attention_block(x, e_x, e_x, src_mask))
        
        x = self.residual_connections[2](x, self.feed_forward_network)
        return x
    
class Decoder(nn.Module):
    def __init__(self, features: int, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNorm(features)
    
    def forward(self, x, encoder_output, src_mask, target_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, target_mask, src_mask)
        return self.norm(x)

    
class ProjectionLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return torch.log_softmax(self.proj(x), dim =-1)

class Transformer(nn.Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, 
                src_embed: InputEmbedding, target_embed: InputEmbedding, 
                src_pos: PositionalEncoding, target_pos: PositionalEncoding, 
                proj_layer:ProjectionLayer):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.target_embed= target_embed
        self.src_pos = src_pos
        self.target_pos= target_pos
        self.proj_layer = proj_layer
    
    def encode(self, src, src_mask):
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)
    
    def decode(self, encoder_output, src_mask, target, target_mask):
        target = self.target_embed(target)
        target = self.target_pos(target)
        return self.decoder(target, encoder_output, src_mask, target_mask)
    
    def proj(self, x):
        return self.proj_layer(x)

## Defining a function to build a transformer

In [3]:
def build_transformer(src_vocab_size: int, target_vocab_size: int, src_seq_len:int, 
                    target_seq_len: int, d_model: int = 512, N_blocks: int = 6, heads: int = 8,
                    dropout: float = 0.1, d_ff: int = 2048) -> Transformer:
    
    src_embed = InputEmbedding(d_model, src_vocab_size)
    target_embed = InputEmbedding(d_model, target_vocab_size)

    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    target_pos = PositionalEncoding(d_model, target_seq_len, dropout)

    encoder_blocks = []
    # creating the encoder block parameters
    for _ in range(N_blocks):
        encoder_self_attention_block = MultiHeadAttention(d_model, heads, dropout)
        encoder_ffn = FeedForwardNetwork(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(d_model, encoder_self_attention_block, encoder_ffn, dropout)
        encoder_blocks.append(encoder_block)
    
    decoder_blocks = []
    # creating the decoder block parameters
    for _ in range(N_blocks):
        decoder_self_attention_block = MultiHeadAttention(d_model, heads, dropout)
        decoder_cross_attention_block = MultiHeadAttention(d_model, heads, dropout)
        decoder_ffn = FeedForwardNetwork(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(d_model, decoder_self_attention_block, decoder_cross_attention_block,
                                    decoder_ffn, dropout)
        decoder_blocks.append(decoder_block)

    # Creating the encoder and the decoder
    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))

    # Projection layer
    projection_layer = ProjectionLayer(d_model, target_vocab_size)

    # transformer
    transformer = Transformer(encoder, decoder, src_embed, target_embed, src_pos, target_pos,
                            projection_layer)
    
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer  


## Config File for the training and dataset loading

In [4]:
def get_config():
    return{
        "batch_size": 8,
        "num_epoches": 20,
        "lr": 1e-4,
        "seq_len": 512,
        "d_model": 512,
        "lang_src": 'en',
        "lang_target": 'fr',
        "model_folder": "weights",
        "model_filename": "tmodel_",
        'preload': None,
        "tokenizer_file": "tokenizer_{0}.json",
        "experiment_name": "runs/tmodel",
    }
def get_weights_file_path(config, epoch: str):
    model_folder = config['model_folder']
    model_basename = config['model_filename']
    model_filename = f"{model_basename}{epoch}.pt"
    return str(Path('.') / model_folder / model_filename)

## Defining the dataset modification loop

In [5]:
class BilungialDataset(Dataset):
    def __init__(self, dataset, tokenizer_src, tokenizer_target, 
                src_lang, target_lang, seq_len):
        super().__init__()

        self.dataset = dataset
        self.tokenizer_src = tokenizer_src
        self.tokenizer_target = tokenizer_target
        self.src_lang = src_lang
        self.target_lang = target_lang
        self.seq_len = seq_len

        self.sos_token = torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_src.token_to_id('[PAD]')], dtype=torch.int64)

    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, index):
        src_target_pair = self.dataset[index]
        src_text = src_target_pair['translation'][self.src_lang]
        target_text = src_target_pair['translation'][self.target_lang]
        
        enc_input_token = self.tokenizer_src.encode(src_text).ids
        dec_input_token = self.tokenizer_target.encode(target_text).ids

        enc_num_padding_tokens = self.seq_len - len(enc_input_token) - 2
        dec_num_padding_tokens = self.seq_len - len(dec_input_token) - 1

        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")
        
        # Adding sos and eos to the source text
        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(enc_input_token, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token]*enc_num_padding_tokens, dtype=torch.int64)
            ], dim=0
        )
        # No eos for decoder input, we only add sos
        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_token, dtype=torch.int64),
                torch.tensor([self.pad_token]*dec_num_padding_tokens, dtype=torch.int64)
            ], dim = 0
        )
        # Add sos to decoder output since this is what we expect the output to be
        label = torch.cat(
            [
                torch.tensor(dec_input_token, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token]*dec_num_padding_tokens, dtype=torch.int64)
            ], dim=0
        )

        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len
        
        return {
            'encoder_input': encoder_input,
            'decoder_input': decoder_input,
            'encoder_mask': (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),
            'decoder_mask': (decoder_input != self.pad_token).unsqueeze(0).int() & casual_mask(
                decoder_input.size(0)),
            'label': label,
            "src_target": src_text,
            "target_target": target_text
        }
def casual_mask(size):
    mask = torch.triu(torch.ones(1,size,size), diagonal=1).type(torch.int)
    return mask==0

## Training Loop

In [None]:
def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_target, max_len, device):
    sos_index = tokenizer_target.token_to_id('[SOS]')
    eos_index = tokenizer_target.token_to_id('[EOS]')

    encoder_output = model.encode(source, source_mask)

    decoder_input = torch.empty(1,1).fill_(sos_index).type_as(source).to(device)
    while True:
        if decoder_input.size(1) == max_len:
            break

        decoder_mask = casual_mask(decoder_input.size(1)).type_as(source).to(device)

        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        prob = model.proj(out[:,-1])
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat([decoder_input, torch.empty(1,1).type_as(source).fill_(next_word.item()).to(device)], dim=1)

        if next_word == eos_index:
            break

    return decoder_input.squeeze(0)

def run_validation(model, validation_ds, tokenizer_src, tokenizer_target,
                   max_len, device, print_msg, global_state, writer, num_examples=2):
    model.eval()
    count = 0

    # source_texts = []
    # expected = []
    # predicted = []

    console_width = 80

    with torch.no_grad():
        for batch in validation_ds:
            count+=1
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            
            assert encoder_input.size(0) == 1, "Batch must be 1 for validation"
            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_target, max_len, device)

            source_text = batch['src_target'][0]
            target_text = batch['target_target'][0]
            model_out_text = tokenizer_target.decode(model_out.detach().cpu().numpy())\
            
            # source_texts.append(source_text)
            # expected.append(target_text)
            # predicted.append(model_out_text)

            print_msg('_'*console_width)

            print_msg(f"source: {source_text}")
            print_msg(f"target: {target_text}")
            print_msg(f"model: {model_out_text}")

            if count == num_examples:
                break

    

def get_all_sentences(dataset, lang):
    for item in dataset:
        yield item['translation'][lang]
    

def get_or_build_tokenizer(config, dataset, lang):
    # Same as {'tokenizer_file': f"tokenizer_file_{lang}.json"}
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if not Path.exists(tokenizer_path):
        # Maps word it hasnt seen to unknown
        tokenizer = Tokenizer(WordLevel(unk_token="[UNK]")) 
        # Defining split as whitespace
        tokenizer.pre_tokenizer = Whitespace()
        # defining a trainer to train with special tokens
        trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
        # train the tokenizer
        tokenizer.train_from_iterator(get_all_sentences(dataset, lang), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    else: 
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer

def get_dataset(config):
    ds_raw = load_dataset('opus_books', f'{config["lang_src"]}-{config["lang_target"]}', 
                            split="train")
    
    # build tokenizers
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
    tokenizer_target = get_or_build_tokenizer(config, ds_raw, config['lang_target'])

    train_ds_size = int(0.9*len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

    train_ds = BilungialDataset(train_ds_raw, tokenizer_src, tokenizer_target, 
                                config['lang_src'], config['lang_target'], config['seq_len'])
    val_ds = BilungialDataset(val_ds_raw, tokenizer_src, tokenizer_target, 
                                config['lang_src'], config['lang_target'], config['seq_len'])
    
    # max_len_src = 0
    # max_len_target = 0

    # for item in ds_raw:
    #     src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
    #     target_ids = tokenizer_target.encode(item['translation'][config['lang_target']]).ids
    #     max_len_src = max(max_len_src, len(src_ids))
    #     max_len_target = max(max_len_target, len(target_ids))
    
    # print(f"Max length of source sequence: {max_len_src}")
    # print(f"Max length of target sequence: {max_len_target}")

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_target

def get_model(config, vocab_src_len, vocab_target_len):
    model = build_transformer(vocab_src_len, vocab_target_len, 
                            config['seq_len'], config['seq_len'], config['d_model'])
    return model

def train_model(config):
    # Define the device 
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    print(f'Using device {device}')

    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_target = get_dataset(config)
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_target.get_vocab_size()).to(device)
    # Tensorboard
    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps = 1e-9)

    initial_epoch = 0
    global_step = 0
    if config['preload']:
        model_filename = get_weights_file_path(config, config['preload'])
        print(f'Preloading model {model_filename}')
        state = torch.load(model_filename)
        initial_epoch = state['epoch']+1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']

    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    for epoch in range(initial_epoch, config['num_epoches']):
        batch_iterator = tqdm(train_dataloader, desc=f'Processing epoch {epoch:02d}')
        for batch in batch_iterator:
            model.train()
            
            encoder_input = batch['encoder_input'].to(device)
            decoder_input = batch['decoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            decoder_mask = batch['decoder_mask'].to(device)

            encoder_output = model.encode(encoder_input, encoder_mask)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
            proj_output = model.proj(decoder_output)

            label = batch['label'].to(device)

            loss = loss_fn(proj_output.view(-1, tokenizer_target.get_vocab_size()), label.view(-1))
            batch_iterator.set_postfix({f"loss": f"{loss.item():6.3f}"})

            writer.add_scalar('train_loss', loss.item(), global_step)
            writer.flush()


            loss.backward()

            optimizer.step()
            optimizer.zero_grad()
            if global_step%100 == 0:
                run_validation(model, val_dataloader, tokenizer_src, tokenizer_target, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)

            global_step+=1
   
        model_filename = get_weights_file_path(config, f"{epoch:02d}")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step
        }, model_filename)


warnings.filterwarnings("ignore")
config = get_config()
train_model(config)

# Inference


In [6]:
def casual_mask(size):
    mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)
    return mask == 0

def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_target, max_len, device):
    sos_idx = tokenizer_target.token_to_id('[SOS]')
    eos_idx = tokenizer_target.token_to_id('[EOS]')

    encoder_output = model.encode(source, source_mask)
    
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

    while True:
        if decoder_input.size(1) == max_len:
            break

        decoder_mask = casual_mask(decoder_input.size(1)).type_as(source).to(device)

        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        prob = model.proj(out[:, -1])
        
        _, next_word = torch.max(prob, dim=1)
        
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1
        )

        if next_word == eos_idx:
            break

    return decoder_input.squeeze(0)

def get_translation(sentence: str):
    config = get_config()
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    tokenizer_src_path = Path(config['tokenizer_file'].format(config['lang_src']))
    tokenizer_tgt_path = Path(config['tokenizer_file'].format(config['lang_target']))
    
    if not tokenizer_src_path.exists() or not tokenizer_tgt_path.exists():
        print(f"Error: Tokenizer files not found at {tokenizer_src_path}")
        return

    tokenizer_src = Tokenizer.from_file(str(tokenizer_src_path))
    tokenizer_target = Tokenizer.from_file(str(tokenizer_tgt_path))

    model = build_transformer(
        tokenizer_src.get_vocab_size(), 
        tokenizer_target.get_vocab_size(), 
        config['seq_len'], 
        config['seq_len'], 
        config['d_model']
    ).to(device)

    model_filename = get_weights_file_path(config, "latest") 
    print(f"Loading weights from: {model_filename}")
    
    if not Path(model_filename).exists():
        print(f"Error: Weights file not found at {model_filename}")
        return

    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    model.eval()

    print(f"Translating: '{sentence}'")
    with torch.no_grad():
        source = tokenizer_src.encode(sentence)
        
        source = torch.cat([
            torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64), 
            torch.tensor(source.ids, dtype=torch.int64),
            torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64),
            torch.tensor([tokenizer_src.token_to_id('[PAD]')] * (config['seq_len'] - len(source.ids) - 2), dtype=torch.int64)
        ], dim=0).to(device)

        source = source.unsqueeze(0)

        source_mask = (source != tokenizer_src.token_to_id('[PAD]')).unsqueeze(0).unsqueeze(0).int().to(device)
        
        model_out = greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_target, config['seq_len'], device)
        
        model_out_text = tokenizer_target.decode(model_out.detach().cpu().numpy())
        
        return model_out_text

if __name__ == '__main__':
    if len(sys.argv) > 1:
        sentence = sys.argv[1]
        result = get_translation(sentence)
        print(f"\nPREDICTED: {result}")
    else:
        print("Enter a sentence to translate (or '_' to exit):")
        while True:
            text = input("> ")
            if text.lower() == '_': break
            try:
                result = get_translation(text)
                print(f"PREDICTED: {result}")
                print("-" * 30)
            except Exception as e:
                print(f"Error: {e}")

Using device: cuda
Loading weights from: weights\tmodel_latest.pt
Error: Weights file not found at weights\tmodel_latest.pt

PREDICTED: None
