In [None]:
import torch
import torch.nn as nn
from config import get_config, latest_weights_file_path
from train import get_model, get_ds, run_validation, beam_search_decode

# Select the compute device: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device)

# Build the dataloaders, the two tokenizers, and a model sized to both vocabularies.
config = get_config()
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

if torch.__version__.startswith("2."):
    print("Compiling the model before loading weights...")
    model = torch.compile(model)

# Load the pre-trained weights
model_filename = latest_weights_file_path(config)
if not model_filename:
    # NOTE(review): presumably the helper returns None when no checkpoint exists —
    # confirm in config.py. Fail with an actionable message instead of an opaque
    # error from torch.load.
    raise FileNotFoundError("No saved weights found for this config; train the model first.")
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
state = torch.load(model_filename, map_location=device)

# torch.compile wraps the module, and the wrapper's state_dict keys carry an
# "_orig_mod." prefix. Depending on whether the checkpoint was saved from a compiled
# or a plain model, its keys may or may not have that prefix, so normalize the keys
# and load into the underlying (unwrapped) module. This works for both layouts.
compiled_prefix = "_orig_mod."
model_state_dict = {
    (key[len(compiled_prefix):] if key.startswith(compiled_prefix) else key): value
    for key, value in state['model_state_dict'].items()
}
getattr(model, "_orig_mod", model).load_state_dict(model_state_dict)

In [None]:
# Run the validation routine imported from train.py on the validation dataloader,
# printing its messages straight to the cell output.
run_validation(
    model,
    val_dataloader,
    tokenizer_src,
    tokenizer_tgt,
    config['seq_len'],
    device,
    lambda msg: print(msg),  # message callback: forward each message to stdout
    0,     # NOTE(review): presumably the global step — confirm against train.run_validation
    None,  # NOTE(review): presumably the (unused) summary writer — confirm against train.run_validation
    num_examples=30,
)

In [None]:
def translate(sentence: str):
    """Translate a single sentence with the notebook's loaded ``model``.

    The sentence is tokenized with ``tokenizer_src``, wrapped in ``[SOS]`` /
    ``[EOS]``, right-padded with ``[PAD]`` up to ``config['seq_len']`` and handed
    to ``beam_search_decode`` together with a mask that is 0 on padding positions.

    Args:
        sentence: Raw source-language text.

    Returns:
        The result of ``tokenizer_tgt.decode`` applied to the decoded token ids
        (presumably the translated string — confirm against the tokenizer).

    Raises:
        ValueError: If the tokenized sentence plus the two special tokens is
            longer than ``config['seq_len']``.
    """
    model.eval()
    with torch.no_grad():
        max_len = config['seq_len']

        # Tokenize and frame the sentence with the start/end markers.
        encoding = tokenizer_src.encode(sentence)
        sos_id = tokenizer_src.token_to_id('[SOS]')
        eos_id = tokenizer_src.token_to_id('[EOS]')
        pad_id = tokenizer_src.token_to_id('[PAD]')
        framed_ids = [sos_id, *encoding.ids, eos_id]

        # Work out how much right-padding is needed; reject over-long input.
        pad_count = max_len - len(framed_ids)
        if pad_count < 0:
            raise ValueError("The sentence is too long")

        # Batch of one: shape (1, seq_len).
        encoder_input = torch.tensor(
            framed_ids + [pad_id] * pad_count,
            dtype=torch.int32,
            device=device,
        ).unsqueeze(0)

        # Padding mask with two singleton axes inserted: shape (1, 1, 1, seq_len).
        encoder_mask = encoder_input.ne(pad_id)[:, None, None, :].int()

        # Decode with beam search (helper imported from train.py).
        decoded_ids = beam_search_decode(
            model,
            encoder_input,
            encoder_mask,
            tokenizer_src,
            tokenizer_tgt,
            max_len,
            device,
        )

        # Move the ids to host memory and turn them back into text.
        return tokenizer_tgt.decode(decoded_ids.detach().cpu().numpy())

In [None]:
# Example: translate one sentence and show it next to its source text.
my_sentence = "お元気ですか？"
translation = translate(my_sentence)

print(
    f"Source: {my_sentence}",
    f"Translation: {translation}",
    sep="\n",
)