In [1]:
# Add src module directory to system path for subsecuent imports.
import os
import sys
sys.path.insert(0, '../src')

In [2]:
from util import is_notebook

# Settings (only in Jupyter Notebooks).
if is_notebook():
    # Module reloading.
    %load_ext autoreload
    # aimport?
    %autoreload 2
    # Plot settings.
    %matplotlib inline

In [3]:
# Imports.
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from datasets import load_metric

from constants import *
from util import *
from transformer import Transformer
from tokenizer import load_tokenizer
from data import download_data, load_data
from plotting import plot_metric
from metric_logging import MetricLogger

In [4]:
# Set seed.
from pytorch_lightning import seed_everything
seed_everything(0, workers=True)

Global seed set to 0


0

In [5]:
# Set torch device to GPU is available, otherwise cpu.
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Torch uses {device} device.')

Torch uses cpu device.


In [6]:
# Experiment paramereters.
hparams = dotdict({
    'src_lang': 'de',
    'tgt_lang': 'nl',
    'model_path': './models/baseline-de-nl.pt',
})

print('Experiment paramereters:')
print(hparams)

Experiment paramereters:
{'src_lang': 'de', 'tgt_lang': 'nl', 'model_path': './models/baseline-de-nl.pt'}


In [7]:
# Load Metrics.
score_metric = load_metric('sacrebleu')

print('Loaded metrics.')

Loaded metrics.


In [8]:
# Load tokenizers.
src_tokenizer = load_tokenizer(hparams.src_lang, hparams.tgt_lang)
tgt_tokenizer = load_tokenizer(hparams.tgt_lang, hparams.src_lang)

print('Loaded tokenizers.')

Tokenizer exists. Skipping training.
Tokenizer exists. Skipping training.
Loaded tokenizers.


In [9]:
# Create model.
model = Transformer(src_tokenizer,
                    tgt_tokenizer,
                    score_metric=score_metric)

print('Created model.')

Created model.


In [10]:
# Load model.
model.load_state_dict(torch.load(hparams.model_path))

print('Loaded model')

Loaded model


In [41]:
# Perform translation.
model = model.to(device)
src_text = 'Ich will das nicht mehr.'
tgt_text = ' '

translation = model.translate(src_text, method='greedy')
score = score_metric.compute(predictions=[translation], references=[[tgt_text]])

print(f'"{src_text}" --> "{translation}"')
print(f'"{translation}" =?= "{tgt_text}"')
print(f'Sacre Bleu Score: {score["score"]}')

"Ich will das nicht mehr." --> "Ik wil het niet meer...."
"Ik wil het niet meer...." =?= " "
Sacre Bleu Score: 0.0
