In [9]:
from train import get_model, greedy_decode
from dataset import get_source_tokenizer, get_target_tokenizer
from config import get_config, get_weights_file_path
from pathlib import Path
import torch
import warnings

In [None]:
def load_model_for_inference(config, device=None):
    """
    Build the translation model and load its latest checkpoint for inference.

    Args:
        config (dict): Configuration dictionary passed to the tokenizer,
            model-construction and weights-path helpers.
        device (torch.device | str | None): Device to place the model on.
            When None (the default), CUDA is used if available, else CPU.

    Returns:
        tuple: ``(model, source_tokenizer, target_tokenizer)`` where ``model``
        is the ``get_model`` result with the latest weights loaded and put in
        eval mode, and the tokenizers are the source/target tokenizers built
        from ``config``.

    Raises:
        FileNotFoundError: If no checkpoint exists at the "latest" weights path.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Normalise strings such as "cpu" / "cuda:0" to a torch.device.
    device = torch.device(device)

    source_tokenizer = get_source_tokenizer(config)
    target_tokenizer = get_target_tokenizer(config)

    model_filename = get_weights_file_path(config, "latest")
    # Fail before allocating the model (possibly on GPU) if there is nothing to load.
    if not Path(model_filename).exists():
        raise FileNotFoundError(f"No model weights found at {model_filename}")

    # Embedding/projection sizes are taken from the tokenizers' vocabularies.
    model = get_model(
        config, source_tokenizer.get_vocab_size(), target_tokenizer.get_vocab_size()
    ).to(device)

    print(f"Loading model weights from {model_filename}")
    # map_location remaps tensors saved on another device (e.g. a GPU) onto `device`.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state["model_state_dict"])
    # Disable dropout etc. so the returned model is ready for inference.
    model.eval()

    return model, source_tokenizer, target_tokenizer

In [None]:
def infer(model, source_sentence, source_tokenizer, target_tokenizer, config):
    """
    Translate a single sentence with the model using greedy decoding.

    Args:
        model (torch.nn.Module): The trained sequence-to-sequence model.
        source_sentence (str): The input sentence in the source language.
        source_tokenizer (Tokenizer): Tokenizer for the source language.
        target_tokenizer (Tokenizer): Tokenizer for the target language.
        config (dict): Configuration dictionary; ``config["max_len"]`` is
            forwarded to ``greedy_decode`` as the decoding length limit.

    Returns:
        str: The target tokenizer's decoding of the generated token ids.
    """
    # Feed inputs to the device the model's weights actually live on, so a model
    # deliberately kept on CPU is not handed CUDA tensors (or vice versa).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    # NOTE(review): no start/end-of-sequence tokens are added here; confirm this
    # matches how source sequences were built during training.
    source_tokens = source_tokenizer.encode(source_sentence).ids
    # dtype=torch.long keeps the ids integer-typed even when the sentence encodes
    # to zero tokens (torch.tensor([]) would otherwise default to float32).
    source_tensor = torch.tensor(source_tokens, dtype=torch.long, device=device).unsqueeze(0)  # (1, seq_len)

    pad_id = source_tokenizer.token_to_id("<PAD>")
    if pad_id is None:
        # No "<PAD>" in the vocabulary: no position needs masking.
        source_mask = torch.ones_like(source_tensor, dtype=torch.bool)
    else:
        source_mask = source_tensor != pad_id
    source_mask = source_mask.unsqueeze(1).unsqueeze(2)  # (1, 1, 1, seq_len)

    with torch.no_grad():
        output_tokens = greedy_decode(
            model, source_tensor, source_mask, target_tokenizer, config["max_len"], device
        )

    # assumes greedy_decode returns a 1-D tensor of token ids -- TODO confirm.
    # decode() is documented to take a list of ints, so pass a plain list.
    translated_sentence = target_tokenizer.decode(output_tokens.cpu().tolist())
    return translated_sentence


In [None]:
# Silence warnings only while running this example, rather than installing a
# permanent process-wide "ignore everything" filter that would also hide
# warnings raised by every later cell or module.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    config = get_config()
    # Example inference
    model, source_tokenizer, target_tokenizer = load_model_for_inference(config)
    source_sentence = "This is a test sentence."
    translated_sentence = infer(model, source_sentence, source_tokenizer, target_tokenizer, config)

print(f"Source: {source_sentence}")
print(f"Translated: {translated_sentence}")