In [2]:
import torch

from utils import Config, get_latest_weights_file_path
from train import Transformer
from inference import translate_text
from tokenizer import load_tokenizer

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = Config('config.yaml')

# load tokenizers
tokenizer_src = load_tokenizer(config=config, lang=config.LANG_SRC)
tokenizer_tgt = load_tokenizer(config=config, lang=config.LANG_TGT)

# build transformer architecture
model = Transformer(
    embed_size=config.D_MODEL,
    src_vocab_size=tokenizer_src.get_vocab_size(),
    tgt_vocab_size=tokenizer_tgt.get_vocab_size(),
    max_len=config.MAX_LEN,
    dropout=0.1,
    heads=8,
    hidden_size=2048,
    N=6
).to(device)

# Load the pretrained weights
model_filename = get_latest_weights_file_path(config=config)
state = torch.load(f=model_filename, map_location=torch.device('cpu'))
model.load_state_dict(state_dict=state['model_state_dict'])

RuntimeError: Error(s) in loading state_dict for Transformer:
	Missing key(s) in state_dict: "encoder.input_embedding.embedding.weight", "encoder.pos_encoding.pe", "encoder.encoder_block.self_attention_block.w_q.weight", "encoder.encoder_block.self_attention_block.w_k.weight", "encoder.encoder_block.self_attention_block.w_v.weight", "encoder.encoder_block.self_attention_block.w_o.weight", "encoder.encoder_block.norm.alpha", "encoder.encoder_block.norm.bias", "encoder.encoder_block.feed_forward_block.fc1.weight", "encoder.encoder_block.feed_forward_block.fc1.bias", "encoder.encoder_block.feed_forward_block.fc2.weight", "encoder.encoder_block.feed_forward_block.fc2.bias", "encoder.layers.0.norm.alpha", "encoder.layers.0.norm.bias", "encoder.layers.1.norm.alpha", "encoder.layers.1.norm.bias", "encoder.layers.2.norm.alpha", "encoder.layers.2.norm.bias", "encoder.layers.3.norm.alpha", "encoder.layers.3.norm.bias", "encoder.layers.4.norm.alpha", "encoder.layers.4.norm.bias", "encoder.layers.5.norm.alpha", "encoder.layers.5.norm.bias", "decoder.input_embedding.embedding.weight", "decoder.pos_encoding.pe", "decoder.decoder_block.attention_block.w_q.weight", "decoder.decoder_block.attention_block.w_k.weight", "decoder.decoder_block.attention_block.w_v.weight", "decoder.decoder_block.attention_block.w_o.weight", "decoder.decoder_block.norm.alpha", "decoder.decoder_block.norm.bias", "decoder.decoder_block.feed_forward_block.fc1.weight", "decoder.decoder_block.feed_forward_block.fc1.bias", "decoder.decoder_block.feed_forward_block.fc2.weight", "decoder.decoder_block.feed_forward_block.fc2.bias", "decoder.layers.0.attention_block.w_q.weight", "decoder.layers.0.attention_block.w_k.weight", "decoder.layers.0.attention_block.w_v.weight", "decoder.layers.0.attention_block.w_o.weight", "decoder.layers.0.norm.alpha", "decoder.layers.0.norm.bias", "decoder.layers.1.attention_block.w_q.weight", "decoder.layers.1.attention_block.w_k.weight", "decoder.layers.1.attention_block.w_v.weight", "decoder.layers.1.attention_block.w_o.weight", "decoder.layers.1.norm.alpha", "decoder.layers.1.norm.bias", "decoder.layers.2.attention_block.w_q.weight", "decoder.layers.2.attention_block.w_k.weight", "decoder.layers.2.attention_block.w_v.weight", "decoder.layers.2.attention_block.w_o.weight", "decoder.layers.2.norm.alpha", "decoder.layers.2.norm.bias", "decoder.layers.3.attention_block.w_q.weight", "decoder.layers.3.attention_block.w_k.weight", "decoder.layers.3.attention_block.w_v.weight", "decoder.layers.3.attention_block.w_o.weight", "decoder.layers.3.norm.alpha", "decoder.layers.3.norm.bias", "decoder.layers.4.attention_block.w_q.weight", "decoder.layers.4.attention_block.w_k.weight", "decoder.layers.4.attention_block.w_v.weight", "decoder.layers.4.attention_block.w_o.weight", "decoder.layers.4.norm.alpha", "decoder.layers.4.norm.bias", "decoder.layers.5.attention_block.w_q.weight", "decoder.layers.5.attention_block.w_k.weight", "decoder.layers.5.attention_block.w_v.weight", "decoder.layers.5.attention_block.w_o.weight", "decoder.layers.5.norm.alpha", "decoder.layers.5.norm.bias", "decoder.proj.fc.weight", "decoder.proj.fc.bias". 
	Unexpected key(s) in state_dict: "src_embed.embedding.weight", "tgt_embed.embedding.weight", "src_pos.pe", "tgt_pos.pe", "projection_layer.proj.weight", "projection_layer.proj.bias", "encoder.layers.0.residual_connection1.norm.alpha", "encoder.layers.0.residual_connection1.norm.bias", "encoder.layers.0.residual_connection2.norm.alpha", "encoder.layers.0.residual_connection2.norm.bias", "encoder.layers.1.residual_connection1.norm.alpha", "encoder.layers.1.residual_connection1.norm.bias", "encoder.layers.1.residual_connection2.norm.alpha", "encoder.layers.1.residual_connection2.norm.bias", "encoder.layers.2.residual_connection1.norm.alpha", "encoder.layers.2.residual_connection1.norm.bias", "encoder.layers.2.residual_connection2.norm.alpha", "encoder.layers.2.residual_connection2.norm.bias", "encoder.layers.3.residual_connection1.norm.alpha", "encoder.layers.3.residual_connection1.norm.bias", "encoder.layers.3.residual_connection2.norm.alpha", "encoder.layers.3.residual_connection2.norm.bias", "encoder.layers.4.residual_connection1.norm.alpha", "encoder.layers.4.residual_connection1.norm.bias", "encoder.layers.4.residual_connection2.norm.alpha", "encoder.layers.4.residual_connection2.norm.bias", "encoder.layers.5.residual_connection1.norm.alpha", "encoder.layers.5.residual_connection1.norm.bias", "encoder.layers.5.residual_connection2.norm.alpha", "encoder.layers.5.residual_connection2.norm.bias", "decoder.layers.0.self_attention_block.w_q.weight", "decoder.layers.0.self_attention_block.w_k.weight", "decoder.layers.0.self_attention_block.w_v.weight", "decoder.layers.0.self_attention_block.w_o.weight", "decoder.layers.0.cross_attention_block.w_q.weight", "decoder.layers.0.cross_attention_block.w_k.weight", "decoder.layers.0.cross_attention_block.w_v.weight", "decoder.layers.0.cross_attention_block.w_o.weight", "decoder.layers.0.residual_connection1.norm.alpha", "decoder.layers.0.residual_connection1.norm.bias", "decoder.layers.0.residual_connection2.norm.alpha", "decoder.layers.0.residual_connection2.norm.bias", "decoder.layers.0.residual_connection3.norm.alpha", "decoder.layers.0.residual_connection3.norm.bias", "decoder.layers.1.self_attention_block.w_q.weight", "decoder.layers.1.self_attention_block.w_k.weight", "decoder.layers.1.self_attention_block.w_v.weight", "decoder.layers.1.self_attention_block.w_o.weight", "decoder.layers.1.cross_attention_block.w_q.weight", "decoder.layers.1.cross_attention_block.w_k.weight", "decoder.layers.1.cross_attention_block.w_v.weight", "decoder.layers.1.cross_attention_block.w_o.weight", "decoder.layers.1.residual_connection1.norm.alpha", "decoder.layers.1.residual_connection1.norm.bias", "decoder.layers.1.residual_connection2.norm.alpha", "decoder.layers.1.residual_connection2.norm.bias", "decoder.layers.1.residual_connection3.norm.alpha", "decoder.layers.1.residual_connection3.norm.bias", "decoder.layers.2.self_attention_block.w_q.weight", "decoder.layers.2.self_attention_block.w_k.weight", "decoder.layers.2.self_attention_block.w_v.weight", "decoder.layers.2.self_attention_block.w_o.weight", "decoder.layers.2.cross_attention_block.w_q.weight", "decoder.layers.2.cross_attention_block.w_k.weight", "decoder.layers.2.cross_attention_block.w_v.weight", "decoder.layers.2.cross_attention_block.w_o.weight", "decoder.layers.2.residual_connection1.norm.alpha", "decoder.layers.2.residual_connection1.norm.bias", "decoder.layers.2.residual_connection2.norm.alpha", "decoder.layers.2.residual_connection2.norm.bias", "decoder.layers.2.residual_connection3.norm.alpha", "decoder.layers.2.residual_connection3.norm.bias", "decoder.layers.3.self_attention_block.w_q.weight", "decoder.layers.3.self_attention_block.w_k.weight", "decoder.layers.3.self_attention_block.w_v.weight", "decoder.layers.3.self_attention_block.w_o.weight", "decoder.layers.3.cross_attention_block.w_q.weight", "decoder.layers.3.cross_attention_block.w_k.weight", "decoder.layers.3.cross_attention_block.w_v.weight", "decoder.layers.3.cross_attention_block.w_o.weight", "decoder.layers.3.residual_connection1.norm.alpha", "decoder.layers.3.residual_connection1.norm.bias", "decoder.layers.3.residual_connection2.norm.alpha", "decoder.layers.3.residual_connection2.norm.bias", "decoder.layers.3.residual_connection3.norm.alpha", "decoder.layers.3.residual_connection3.norm.bias", "decoder.layers.4.self_attention_block.w_q.weight", "decoder.layers.4.self_attention_block.w_k.weight", "decoder.layers.4.self_attention_block.w_v.weight", "decoder.layers.4.self_attention_block.w_o.weight", "decoder.layers.4.cross_attention_block.w_q.weight", "decoder.layers.4.cross_attention_block.w_k.weight", "decoder.layers.4.cross_attention_block.w_v.weight", "decoder.layers.4.cross_attention_block.w_o.weight", "decoder.layers.4.residual_connection1.norm.alpha", "decoder.layers.4.residual_connection1.norm.bias", "decoder.layers.4.residual_connection2.norm.alpha", "decoder.layers.4.residual_connection2.norm.bias", "decoder.layers.4.residual_connection3.norm.alpha", "decoder.layers.4.residual_connection3.norm.bias", "decoder.layers.5.self_attention_block.w_q.weight", "decoder.layers.5.self_attention_block.w_k.weight", "decoder.layers.5.self_attention_block.w_v.weight", "decoder.layers.5.self_attention_block.w_o.weight", "decoder.layers.5.cross_attention_block.w_q.weight", "decoder.layers.5.cross_attention_block.w_k.weight", "decoder.layers.5.cross_attention_block.w_v.weight", "decoder.layers.5.cross_attention_block.w_o.weight", "decoder.layers.5.residual_connection1.norm.alpha", "decoder.layers.5.residual_connection1.norm.bias", "decoder.layers.5.residual_connection2.norm.alpha", "decoder.layers.5.residual_connection2.norm.bias", "decoder.layers.5.residual_connection3.norm.alpha", "decoder.layers.5.residual_connection3.norm.bias". 

In [81]:
text = "Скажи мне, сколько сейчас времени?"

translate_text(
    text=text, 
    model=model, 
    config=config, 
    device=device, 
    tokenizer_src=tokenizer_src, 
    tokenizer_tgt=tokenizer_tgt, 
)

', come sono felice di tempo ?'