In [4]:
import torch

from utils import Config, get_latest_weights_file_path
from train import Transformer
from inference import translate_text
from tokenizer import load_tokenizer

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = Config('config.yaml')

# Transformer hyperparameters that are not read from config.yaml.
# NOTE(review): presumably these must match the architecture the checkpoint was
# trained with — TODO move them into config.yaml so training and inference share them.
DROPOUT = 0.1
N_HEADS = 8
FF_HIDDEN_SIZE = 2048
N_LAYERS = 6

# Load the source/target tokenizers; their vocab sizes size the embedding layers.
src_tokenizer = load_tokenizer(config=config, lang=config.LANG_SRC)
tgt_tokenizer = load_tokenizer(config=config, lang=config.LANG_TGT)

# Build the transformer architecture on the selected device.
model = Transformer(
    embed_size=config.D_MODEL,
    src_vocab_size=src_tokenizer.get_vocab_size(),
    tgt_vocab_size=tgt_tokenizer.get_vocab_size(),
    max_len=config.MAX_LEN,
    dropout=DROPOUT,
    heads=N_HEADS,
    hidden_size=FF_HIDDEN_SIZE,
    N=N_LAYERS,
).to(device)

# Load the pretrained weights.
model_filename = get_latest_weights_file_path(config=config)
if model_filename is None:
    # NOTE(review): assumes the helper returns None when no checkpoint exists — confirm.
    raise FileNotFoundError("No pretrained weights found for the current config")

# Map tensors to CPU so the checkpoint also loads on CPU-only machines;
# load_state_dict copies the values into the model's parameters on `device`.
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
state = torch.load(f=model_filename, map_location=torch.device('cpu'))
load_result = model.load_state_dict(state_dict=state['model_state_dict'])

# nn.Module starts in training mode; switch to eval mode so dropout is disabled
# for inference.
model.eval()

# Display the key-matching report as the cell output.
load_result

<All keys matched successfully>

In [7]:
# Russian sample sentence ("Tell me, what time is it now?").
text = "Скажи мне, сколько сейчас времени?"

# Translate with the tokenizers and model set up in the previous cell;
# the returned string is shown as the cell output.
# NOTE(review): the saved output for this cell is not a meaningful translation —
# re-check after Restart Kernel → Run All.
translate_text(
    text=text,
    src_tokenizer=src_tokenizer,
    tgt_tokenizer=tgt_tokenizer,
    model=model,
    device=device,
    config=config,
)

'— Ma , ?'