In [6]:
import torch

from utils import Config, get_last_checkpoint_path
from train import Transformer
from inference import translate_text
from tokenizer import load_tokenizer

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = Config('config.yaml')

# Load the source- and target-language tokenizers selected by the config.
src_tokenizer = load_tokenizer(config=config, lang=config.LANG_SRC)
tgt_tokenizer = load_tokenizer(config=config, lang=config.LANG_TGT)

# Build the transformer architecture. Vocabulary sizes come from the
# tokenizers; every other hyperparameter is read from the config.
model = Transformer(
    src_vocab_size = src_tokenizer.get_vocab_size(),
    tgt_vocab_size = tgt_tokenizer.get_vocab_size(),
    embed_size     = config.EMBED_SIZE,
    hidden_size    = config.HIDDEN_SIZE,
    max_len        = config.MAX_LEN,
    dropout        = config.DROPOUT,
    heads          = config.HEADS,
    N              = config.LAYERS,
).to(device)

# This notebook only runs inference: put the model in evaluation mode so that
# train-only behaviour (e.g. dropout layers) is disabled and repeated
# translations of the same text are deterministic. Kept as a plain statement
# in the middle of the cell so the module repr is not echoed as cell output.
model.eval()

# Load the pretrained weights from the most recent checkpoint.
model_filename = get_last_checkpoint_path(config=config)
# The checkpoint is deserialized onto the CPU; load_state_dict() then copies
# the tensors into the model's parameters on `device`. This keeps anything
# else stored in the checkpoint off the accelerator.
# NOTE(review): torch.load() unpickles the file, so only open checkpoints from
# a trusted source; consider weights_only=True if the installed PyTorch
# version and the checkpoint contents allow it.
state = torch.load(f=model_filename, map_location=torch.device('cpu'))
model.load_state_dict(state_dict=state['model_state_dict'])

<All keys matched successfully>

In [9]:
# Sample source sentence (Russian for "What time is it now?").
text = "Сколько сейчас времени?"

# Translate it with the loaded model and tokenizers; the call is the last
# expression of the cell, so the returned translation is shown as its output.
translate_text(
    model=model,
    src_tokenizer=src_tokenizer,
    tgt_tokenizer=tgt_tokenizer,
    config=config,
    device=device,
    text=text,
)

'Quante volte è tempo ora ?'