In [None]:
# Add the src module directory to the system path for subsequent imports.
# The path is relative, so it is resolved against the kernel's working directory.
import os
import sys
sys.path.insert(0, '../src')

In [None]:
from util import is_notebook

# Settings (only in Jupyter Notebooks).
if is_notebook():
    # Module reloading.
    %load_ext autoreload
    # Mode 2 reloads every imported module before each execution, so no
    # explicit %aimport list is required.
    %autoreload 2
    # Plot settings: render matplotlib figures inline in the notebook.
    %matplotlib inline

In [None]:
# Imports.
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from datasets import load_metric

# NOTE(review): later cells use `dotdict`, `get_time_as_string`, `create_dir`,
# `save_dict`, `torch` and `device` without importing them by name; they
# presumably arrive through the two star imports below — confirm, and consider
# importing them explicitly.
from constants import *
from util import *
from transformer import Transformer, cascaded_inference
from tokenizer import load_tokenizer
from data import download_data, load_data
from plotting import plot_metric
from metric_logging import MetricLogger

In [None]:
# Set seed.
# Fix the global random seed (0) so that repeated runs are comparable.
# NOTE(review): workers=True asks Lightning to also seed dataloader workers;
# confirm it has an effect here, since the loaders are not run through a Trainer.
from pytorch_lightning import seed_everything
seed_everything(0, workers=True)

In [None]:
# Experiment parameters.
# Language triple of the cascade (source -> pivot -> target), the two
# pre-trained checkpoints it chains, and the data-loading settings.
hparams = dotdict({
    'src_lang': 'de',
    'pvt_lang': 'nl',
    'tgt_lang': 'en',
    'src_pvt_model_path': 'models/baseline-de-nl.pt',
    'pvt_tgt_model_path': 'models/baseline-nl-en.pt',
    'batch_size': 80,
    # NOTE(review): the three settings below are not read by any cell shown in
    # this notebook; they are only stored together with the run.
    'max_epochs': 10,
    # Passed straight to load_data; -1 presumably means "no limit" — TODO confirm.
    'max_examples': -1,
    'gpus': 1,
    'num_workers': 4,
    'ckpt_path': None,
})

# Echo the configuration so it is visible in the notebook output.
print('Experiment parameters:')
print(hparams)

In [None]:
# Constant directories (relative to the kernel's working directory).
data_dir = os.path.join('./', 'data')
tokenizers_dir = os.path.join('./', 'tokenizers')
runs_dir = os.path.join('./', 'runs')

# Experiment directories: one time-stamped folder per run, named after the
# language triple, with a nested folder for the results.
run_dir = os.path.join(
    runs_dir,
    f'cascaded-{hparams.src_lang}-{hparams.pvt_lang}-{hparams.tgt_lang}-{get_time_as_string()}',
)
results_dir = os.path.join(run_dir, 'results')

# Parents are listed before their children so each directory is created in order.
dirs = [data_dir, tokenizers_dir, runs_dir, run_dir, results_dir]
# The loop variable is deliberately not called `dir`: notebook variables are
# global, so that name would shadow the builtin dir() for the whole session.
for directory in dirs:
    create_dir(directory)

print('Created directories.')

In [None]:
# Load Metrics.
# The SacreBLEU metric object is shared by both models and the test loop below.
# NOTE(review): `datasets.load_metric` is deprecated upstream in favour of the
# separate `evaluate` package (`evaluate.load`); this cell needs a `datasets`
# release that still ships it.
score_metric = load_metric('sacrebleu')

print('Loaded metrics.')

In [None]:
# Download data.
# Only the direct source-target pair is requested; the pivot language is not passed.
download_data(hparams.src_lang, hparams.tgt_lang)

In [None]:
# Load tokenizers.
src_tokenizer = load_tokenizer(hparams.src_lang, hparams.tgt_lang)
# NOTE(review): the pivot tokenizer is loaded with exactly the same arguments
# as src_tokenizer, and hparams.pvt_lang is not passed to any load_tokenizer
# call — confirm this is intended and not a copy-paste slip.
pvt_tokenizer = load_tokenizer(hparams.src_lang, hparams.tgt_lang)
tgt_tokenizer = load_tokenizer(hparams.tgt_lang, hparams.src_lang)

print('Loaded tokenizers.')

In [None]:
# Load data.
# Build the train/validation/test datasets for the direct source-target pair;
# both tokenizers and the example cap are handed to load_data.
train_dataset, val_dataset, test_dataset = load_data(
    hparams.src_lang,
    hparams.tgt_lang,
    src_tokenizer,
    tgt_tokenizer,
    hparams.max_examples,
)

# Report the size of every split with a single print call (one line each).
print(
    f'Preprocessed data ({hparams.src_lang}-{hparams.tgt_lang})',
    f'\tTraining data:   {len(train_dataset)}',
    f'\tValidation data: {len(val_dataset)}',
    f'\tTest data:       {len(test_dataset)}',
    sep='\n',
)

In [None]:
# Create data loaders.
# Training and validation batches use the configured batch size.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=hparams.batch_size,
    num_workers=hparams.num_workers,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=hparams.batch_size,
    num_workers=hparams.num_workers,
)
# The test loader yields one example per batch.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    num_workers=hparams.num_workers,
)

print('Created data loaders.')

In [None]:
# Create models.
# First stage of the cascade: source -> pivot.
src_pvt_model = Transformer(src_tokenizer, pvt_tokenizer, score_metric=score_metric)

# Second stage of the cascade: pivot -> target.
pvt_tgt_model = Transformer(pvt_tokenizer, tgt_tokenizer, score_metric=score_metric)

print('Created models.')

In [None]:
# Load models.
# Restore the pre-trained weights of both cascade stages. map_location remaps
# the stored tensors onto `device`, so checkpoints that were saved on a GPU can
# still be restored on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
src_pvt_model.load_state_dict(torch.load(hparams.src_pvt_model_path, map_location=device))
pvt_tgt_model.load_state_dict(torch.load(hparams.pvt_tgt_model_path, map_location=device))

src_pvt_model.to(device)
pvt_tgt_model.to(device)

# This notebook only runs inference: switch both models to evaluation mode so
# that train-only behaviour such as dropout does not perturb the test scores.
# (Harmless if cascaded_inference already does this itself.)
src_pvt_model.eval()
pvt_tgt_model.eval()

print('Loaded models.')

In [None]:
# Testing.
# Run the cascade over the test loader and average the per-batch score.
test_results = {'test_score_epoch': 0}
# Number of batches actually scored; stays 0 when the loader is empty.
num_batches = 0

# Inference only, so gradient tracking is switched off for the whole loop.
with torch.no_grad():
    for num_batches, batch in enumerate(test_dataloader, start=1):
        # Cascaded inference. The decoded texts are returned alongside the
        # score; only the score is accumulated here.
        # NOTE(review): pvt_tokenizer is not passed to cascaded_inference —
        # confirm the helper does not need it.
        score, src_text, pvt_text, tgt_text, label_text = cascaded_inference(
            batch,
            src_tokenizer, tgt_tokenizer,
            src_pvt_model, pvt_tgt_model,
            score_metric,
        )

        # Accumulate metrics.
        test_results['test_score_epoch'] += score
        print(f'Score: {score}')

# Fail with a clear message instead of an opaque ZeroDivisionError.
if num_batches == 0:
    raise ValueError('The test data loader produced no batches; cannot compute a test score.')

# Average over the batches that were scored (equal to len(test_dataset) while
# the test loader uses batch_size=1). The mean is stored as a one-element list,
# as before — presumably the shape the saved metric files use; TODO confirm.
test_results['test_score_epoch'] = [test_results['test_score_epoch'] / num_batches]

In [None]:
# Save hyper parameters.
# Stored in the run directory under the name 'hparams', so the configuration
# is kept together with the results of this run.
save_dict(run_dir, hparams, 'hparams')

In [None]:
# Save test results.
# Stored in the run's results directory under the name 'metrics'; the dict
# holds the averaged test score computed by the testing cell.
save_dict(results_dir, test_results, 'metrics')