## 1. Import Modules and Data
This section contains the following steps:
1. Use tokenizers from `spacy` to tokenize the texts of the training dataset.
2. Build the vocabulary, i.e. the token-to-index dictionary. A list of special tokens (e.g. `<eos>`, `<pad>`) is prepended to the entire table.
3. Prepare the datasets and dataloaders.

In [1]:
from data import load_data
from modules import Transformer, make_pad_mask, make_causal_mask
import torch
import torch.nn as nn
import random
import config
import os

# Create the checkpoint directory if it does not exist yet.
os.makedirs(config.checkpoint_dir, exist_ok=True)

# Language codes handed to load_data: source first, target second.
src_lang, tgt_lang = "en", "de"

(
    src_vocab,
    tgt_vocab,
    train_dataloader,
    valid_dataloader,
    test_dataloader,
) = load_data(src_lang, tgt_lang)

# Keep a direct handle on the dataset behind the test dataloader.
dataset = test_dataloader.dataset

# Fix the torch RNG seed; the device is the cell's displayed value.
torch.manual_seed(3407)
config.device

device(type='cuda', index=0)

## 2. Load Trained Model

In [16]:
# Rebuild the architecture from the hyper-parameters in `config` and the
# vocabularies loaded above, so the checkpoint's parameter names line up.
model = Transformer(
    src_pad_idx=src_vocab["<pad>"],
    tgt_pad_idx=tgt_vocab["<pad>"],
    src_vocab_size=len(src_vocab),
    tgt_vocab_size=len(tgt_vocab),
    d_model=config.d_model,
    n_head=config.n_head,
    max_len=config.max_len,
    ffn_hidden=config.ffn_hidden,
    n_layer=config.n_layer,
    dropout=config.dropout,
    device=config.device,
)

# map_location remaps the stored tensors onto config.device, so a checkpoint
# saved on a GPU still loads on a machine that lacks that device.
checkpoint_path = os.path.join(config.checkpoint_dir, "en_de.pth")
state_dict = torch.load(checkpoint_path, map_location=config.device)

# Last expression of the cell: shows which keys (if any) failed to match.
model.load_state_dict(state_dict)

<All keys matched successfully>

## 3. Inference


In [22]:
def greedy_search(model, src_sentence, max_len=50, tokenizer=None):
    """Translate ``src_sentence`` with greedy (arg-max) decoding.

    The sentence is encoded once; the target sequence then grows by one token
    per step, always taking the highest-scoring token at the last position,
    until ``<eos>`` is produced or ``max_len`` steps have been taken.

    Args:
        model: Object providing ``eval()``, ``encode(src)`` and
            ``decode(tgt, memory, memory_mask)`` as called below.
        src_sentence: Source sentence as a plain string.
        max_len: Maximum number of decoding steps, i.e. tokens generated
            after the initial ``<sos>``.
        tokenizer: Optional callable turning the sentence into a list of
            token strings. When ``None`` the sentence is split on
            whitespace, which leaves punctuation attached to words (for
            example ``"kick."``); such tokens fall back to ``<unk>`` unless
            the vocabulary contains them. Pass the tokenizer used to build
            the vocabulary to avoid that mismatch.

    Returns:
        The generated target tokens joined by single spaces, with
        ``<sos>``, ``<pad>``, ``<unk>`` and ``<eos>`` filtered out.
    """
    model.eval()

    words = src_sentence.split() if tokenizer is None else tokenizer(src_sentence)
    src_tokens = (
        [src_vocab["<sos>"]]
        + [src_vocab.get(word, src_vocab["<unk>"]) for word in words]
        + [src_vocab["<eos>"]]
    )

    eos_index = tgt_vocab["<eos>"]
    special_index = {
        tgt_vocab["<sos>"],
        tgt_vocab["<pad>"],
        tgt_vocab["<unk>"],
        eos_index,
    }

    tgt_tokens = [tgt_vocab["<sos>"]]
    # Inference only: disable autograd so no graph is kept for every step.
    with torch.no_grad():
        src_tensor = torch.LongTensor(src_tokens).unsqueeze(0).to(config.device)
        memory = model.encode(src_tensor)
        # NOTE(review): the mask is derived from the encoder output `memory`
        # and the *target* pad index, not from the source token ids and the
        # source pad index. Left unchanged because make_pad_mask is defined
        # elsewhere -- confirm this is what model.decode expects.
        memory_mask = make_pad_mask(memory, tgt_vocab["<pad>"])

        for _ in range(max_len):
            # Re-feed the whole prefix; only the last position's scores are used.
            tgt_tensor = torch.LongTensor(tgt_tokens).unsqueeze(0).to(config.device)
            output = model.decode(tgt_tensor, memory, memory_mask)
            next_token = output.argmax(2)[:, -1].item()
            tgt_tokens.append(next_token)
            if next_token == eos_index:
                break

    # Build the index -> token table once instead of scanning the vocabulary
    # for every output token; setdefault keeps the first token per index.
    index_to_token = {}
    for token, index in zip(tgt_vocab.keys(), tgt_vocab.values()):
        index_to_token.setdefault(index, token)

    return " ".join(
        index_to_token[token] for token in tgt_tokens if token not in special_index
    )


# Try the decoder on one English sentence; as the cell's last expression, the
# returned translation is displayed as the cell output.
greedy_search(model, "A girl in karate uniform breaking a stick with a front kick.")

'Ein Mann einem einem einem .'