In [1]:
# Add src module directory to system path for subsecuent imports.
import os
import sys
sys.path.insert(0, '../src')

In [2]:
from util import is_notebook

# Settings (only in Jupyter Notebooks).
if is_notebook():
    # Module reloading.
    %load_ext autoreload
    # aimport?
    %autoreload 2
    # Plot settings.
    %matplotlib inline

In [3]:
# Imports.
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from datasets import load_metric

from constants import *
from util import *
from transformer import Transformer, cascaded_inference
from tokenizer import load_tokenizer
from data import download_data, load_data
from plotting import plot_metric
from metric_logging import MetricLogger

In [4]:
# Set seed.
from pytorch_lightning import seed_everything
seed_everything(0, workers=True)

Global seed set to 0


0

In [5]:
# Experiment paramereters.
hparams = dotdict({
    'src_lang': 'de',
    'pvt_lang': 'nl',
    'tgt_lang': 'en',
    'src_pvt_model_path': 'models/baseline-de-nl.pt',
    'pvt_tgt_model_path': 'models/baseline-nl-en.pt',
    'batch_size': 80,
    'max_epochs': 10,
    'max_examples': -1,
    'gpus': 1,
    'num_workers': 4,
    'ckpt_path': None,
})

print('Experiment paramereters:')
print(hparams)

Experiment paramereters:
{'src_lang': 'de', 'pvt_lang': 'nl', 'tgt_lang': 'en', 'src_pvt_model_path': 'models/baseline-de-nl.pt', 'pvt_tgt_model_path': 'models/baseline-nl-en.pt', 'batch_size': 80, 'max_epochs': 10, 'max_examples': -1, 'gpus': 1, 'num_workers': 4, 'ckpt_path': None}


In [6]:
# Constant directories.
data_dir = os.path.join('./', 'data')
tokenizers_dir = os.path.join('./', 'tokenizers')
runs_dir = os.path.join('./', 'runs')

# Experiment directories.
run_dir = os.path.join(runs_dir, f'cascaded-{hparams.src_lang}-{hparams.pvt_lang}-{hparams.tgt_lang}-{get_time_as_string()}')
results_dir = os.path.join(run_dir, 'results')

dirs = [data_dir, tokenizers_dir, runs_dir, run_dir, results_dir]
for dir in dirs:
    create_dir(dir)

print('Created directories.')

Dir "./data" already exists.
Dir "./tokenizers" already exists.
Dir "./runs" already exists.
Dir "./runs\cascaded-de-nl-en-2022.08.27-16.55.31" does not exist, creating it.
Dir "./runs\cascaded-de-nl-en-2022.08.27-16.55.31\results" does not exist, creating it.
Created directories.


In [7]:
# Load Metrics.
score_metric = load_metric('sacrebleu')

print('Loaded metrics.')

Loaded metrics.


In [8]:
# Download data.
download_data(hparams.src_lang, hparams.tgt_lang)

File "data\de-en.zip" already exists. Skipping download.
Directory data\de-en already exists. Skipping unzipping.


In [9]:
# Load tokenizers.
src_tokenizer = load_tokenizer(hparams.src_lang, hparams.tgt_lang)
pvt_tokenizer = load_tokenizer(hparams.src_lang, hparams.tgt_lang)
tgt_tokenizer = load_tokenizer(hparams.tgt_lang, hparams.src_lang)

print('Loaded tokenizers.')

Tokenizer exists. Skipping training.
Tokenizer exists. Skipping training.
Tokenizer exists. Skipping training.
Loaded tokenizers.


In [10]:
# Load data.
train_dataset, val_dataset, test_dataset = load_data(hparams.src_lang,
                                                     hparams.tgt_lang,
                                                     src_tokenizer,
                                                     tgt_tokenizer,
                                                     hparams.max_examples)

print(f'Preprocessed data ({hparams.src_lang}-{hparams.tgt_lang})')
print(f'\tTraining data:   {len(train_dataset)}')
print(f'\tValidation data: {len(val_dataset)}')
print(f'\tTest data:       {len(test_dataset)}')

Preprocessed data exists, loading from disk...
Splitting de-en data...
Data (de-en) split.
Preprocessed data (de-en)
	Training data:   1416094
	Validation data: 78671
	Test data:       78671


In [11]:
# Create data loaders.
train_dataloader = DataLoader(train_dataset, batch_size=hparams.batch_size, num_workers=hparams.num_workers)
val_dataloader = DataLoader(val_dataset, batch_size=hparams.batch_size, num_workers=hparams.num_workers)
test_dataloader = DataLoader(test_dataset, batch_size=1, num_workers=hparams.num_workers)

print('Created data loaders.')

Created data loaders.


In [12]:
# Create models.
src_pvt_model = Transformer(src_tokenizer,
                            pvt_tokenizer,
                            score_metric=score_metric)

pvt_tgt_model = Transformer(pvt_tokenizer,
                            tgt_tokenizer,
                            score_metric=score_metric)

print('Created models.')

Created models.


In [13]:
# Load models.
src_pvt_model.load_state_dict(torch.load(hparams.src_pvt_model_path))
pvt_tgt_model.load_state_dict(torch.load(hparams.pvt_tgt_model_path))

src_pvt_model.to(device)
pvt_tgt_model.to(device)

print('Loaded models.')

Loaded models.


In [None]:
# Testing.
test_results = {'test_score_epoch': 0}

for batch_idx, batch in enumerate(test_dataloader):
    # Cascaded inference.
    score, src_text, pvt_text, tgt_text, label_text = cascaded_inference(batch,
                                                                         src_tokenizer, tgt_tokenizer,
                                                                         src_pvt_model, pvt_tgt_model,
                                                                         score_metric)
    
    # Accumulate metrics.
    test_results['test_score_epoch'] += score
    print(f'Score: {score}')
test_results['test_score_epoch'] = [test_results['test_score_epoch'] / len(test_dataset)]

Score: 4.196114906296549
Score: 8.276064952530392
Score: 6.477300151702164
Score: 25.29920735938594
Score: 14.923729480049115
Score: 6.567274736060395
Score: 14.0332996877368
Score: 17.64093567849862
Score: 5.0912128230977505
Score: 33.010083098515025
Score: 14.6798691397542
Score: 19.692104496063735
Score: 13.059620291793733
Score: 6.27465531099474
Score: 27.846127764465503
Score: 28.339296176052862
Score: 6.68225620936445
Score: 20.482706926412007
Score: 17.395797375642243
Score: 5.606668411195422
Score: 13.534889927489722
Score: 21.620380142073646
Score: 30.426693999798907
Score: 5.2956899456325806


In [None]:
# Save hyper parameters.
save_dict(run_dir, hparams, 'hparams')

In [None]:
# Save test results.
save_dict(results_dir, test_results, 'metrics')