In [1]:
from pathlib import Path
import torch
import torch.nn as nn
from Config import get_config, latest_weights_file_path
from Train import get_model, get_dataset, run_validation
from Translate import translate



In [2]:
from tokenizers import Tokenizer  # provides Tokenizer.from_file used below

# Device every later cell runs on; switch the string to "cuda" for a GPU.
device = "cpu"
print("Using device:", device)
config = get_config()

# The config stores one filename template; fill it in once per language.
tokenizer_file_template = config['tokenizer_file']
tokenizer_src_path_str = tokenizer_file_template.format(config['lang_src'])
tokenizer_tgt_path_str = tokenizer_file_template.format(config['lang_tgt'])
tokenizer_src_path, tokenizer_tgt_path = (
    Path(tokenizer_src_path_str),
    Path(tokenizer_tgt_path_str),
)

# Load the source tokenizer first, then the target one, announcing each file.
print(f"Loading source tokenizer from: {tokenizer_src_path}")
tokenizer_src = Tokenizer.from_file(str(tokenizer_src_path))
print(f"Loading target tokenizer from: {tokenizer_tgt_path}")
tokenizer_tgt = Tokenizer.from_file(str(tokenizer_tgt_path))

# Record the vocabulary sizes now: the model-building cell passes them to get_model.
current_src_vocab_size, current_tgt_vocab_size = (
    tokenizer_src.get_vocab_size(),
    tokenizer_tgt.get_vocab_size(),
)
print(f"Loaded source tokenizer vocab size: {current_src_vocab_size}")
print(f"Loaded target tokenizer vocab size: {current_tgt_vocab_size}")

Using device: cpu
Loading source tokenizer from: tokenizer_en.json
Loading target tokenizer from: tokenizer_vi.json
Loaded source tokenizer vocab size: 30000
Loaded target tokenizer vocab size: 23029


In [None]:
# Inspect the latest checkpoint: print every entry of its model state dict
# together with the tensor shape.
#
# The checkpoint is loaded in this cell so that it works on a fresh kernel
# (Restart & Run All); previously it relied on `state` being left over from
# the model-loading cell further down, which raises NameError when the
# notebook is executed top to bottom.
inspect_checkpoint_path = latest_weights_file_path(config)
if not inspect_checkpoint_path:
    # Fail with a clear message instead of letting torch.load choke on a missing path.
    raise FileNotFoundError("Could not find model weights to load.")

# weights_only=True restricts unpickling to tensors and plain containers.
state = torch.load(inspect_checkpoint_path, map_location=torch.device(device), weights_only=True)
state_dict = state['model_state_dict']

for key, tensor in state_dict.items():
    print(f"{key}: {tuple(tensor.shape)}")

In [None]:
# Build the network from the config and the vocabulary sizes measured from the
# loaded tokenizers, then move it to the selected device.
model = get_model(config, current_src_vocab_size, current_tgt_vocab_size).to(device)
print("Model initialized with current tokenizer vocabulary sizes.")

# Look up the newest checkpoint file; bail out immediately when there is none.
model_filename = latest_weights_file_path(config)
if not model_filename:
    print(f"No model checkpoint found at path pattern: {config['datasource']}_{config['model_folder']}/{config['model_basename']}*")
    raise FileNotFoundError("Could not find model weights to load.")

print(f"Attempting to load model weights from: {model_filename}")
# map_location is given a torch.device object built from the device string.
state = torch.load(
    model_filename,
    map_location=torch.device(device),
    weights_only=True,
)
model.load_state_dict(state['model_state_dict'])
print(f"Successfully loaded model weights from {model_filename}")

# Only the two dataloaders are kept; the last two return values are discarded
# (per the original note they are tokenizers, which this notebook has already
# loaded from file).
print("Loading dataset...")
train_dataloader, val_dataloader, _, _ = get_dataset(config)
print("Dataset loaded.")

In [None]:
def _print_msg(msg):
    """Forward a validation message to standard output."""
    print(msg)


# Run the validation routine on the loaded model, printing its messages.
run_validation(
    model,
    val_dataloader,
    tokenizer_src,
    tokenizer_tgt,
    config['seq_len'],
    device,
    _print_msg,
    0,     # NOTE(review): presumably a step counter — confirm against Train.run_validation
    None,  # NOTE(review): presumably an optional logger/writer — confirm against Train.run_validation
    num_examples=10,
)

In [None]:
# Translate one literal sentence; the returned value is kept in `t`.
source_sentence = "Why do I need to translate this?"
t = translate(source_sentence)

In [None]:
# Call translate with an integer instead of text; the result replaces `t`.
# NOTE(review): presumably the integer selects a dataset example by index —
# confirm against Translate.translate.
example_ref = 34
t = translate(example_ref)