In [None]:
# Add src module directory to system path for subsequent imports.
import os
import sys
sys.path.insert(0, '../src')

In [None]:
from util import is_notebook

# Settings (only in Jupyter Notebooks).
# The IPython magics below only run when is_notebook() (from util) is truthy.
if is_notebook():
    # Module reloading.
    %load_ext autoreload
    # Autoreload mode 2: reload imported modules before each cell executes.
    # NOTE(review): %aimport could restrict reloading to selected modules.
    %autoreload 2
    # Plot settings.
    %matplotlib inline

In [None]:
# Imports.
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from datasets import load_metric

from constants import *
from util import *
from transformer import Transformer
from tokenizer import load_tokenizer
from data import download_data, load_data
from plotting import plot_metric
from metric_logging import MetricLogger

In [None]:
# Seed the run through Lightning before any model or data loader is created.
from pytorch_lightning import seed_everything

GLOBAL_SEED = 0
seed_everything(GLOBAL_SEED, workers=True)

In [None]:
# Experiment parameters.
hparams = dotdict(dict(
    # Language triple: source, pivot and target.
    src_lang='de',
    pvt_lang='nl',
    tgt_lang='en',
    # Saved model weights (pvt_tgt_model_path is the one loaded further down).
    src_pvt_model_path='models/baseline-de-nl.pt',
    pvt_tgt_model_path='models/baseline-nl-en.pt',
    # Data loader / trainer configuration.
    batch_size=80,
    max_epochs=10,
    max_examples=-1,
    max_examples_fine_tune=10_000,
    gpus=1,
    num_workers=4,
    ckpt_path=None,
    # When this is not None, step 2 loads these weights instead of training.
    step_two_model='models/reverse-step-wise-pivoting-de-nl-en-step-2.pt',
))

print('Experiment paramereters:')
print(hparams)

In [None]:
# Constant directories (relative to the current working directory).
data_dir = os.path.join('./', 'data')
tokenizers_dir = os.path.join('./', 'tokenizers')
runs_dir = os.path.join('./', 'runs')

# Experiment directories: each run gets its own folder, named after the
# language triple and suffixed with get_time_as_string().
run_name = (f'reverse-step-wise-pivoting-{hparams.src_lang}-{hparams.pvt_lang}'
            f'-{hparams.tgt_lang}-{get_time_as_string()}')
run_dir = os.path.join(runs_dir, run_name)
model_checkpoints_dir = os.path.join(run_dir, 'checkpoints')
results_dir = os.path.join(run_dir, 'results')
pre_training_eval_results_dir = os.path.join(run_dir, 'pre-training-eval-results')
step_results_dir = os.path.join(run_dir, 'step-results')
step_model_checkpoints_dir = os.path.join(run_dir, 'step_checkpoints')

# Parents are listed before their children (runs_dir, then run_dir, ...).
# NOTE(review): tokenizers_dir is not in this list, so it is not created
# here -- confirm that is intended.
dirs = [
    data_dir,
    runs_dir,
    run_dir,
    model_checkpoints_dir,
    results_dir,
    pre_training_eval_results_dir,
    step_results_dir,
    step_model_checkpoints_dir,
]
# The loop variable is deliberately not called `dir`: that would shadow the
# builtin dir() for the rest of the notebook session.
for dir_path in dirs:
    create_dir(dir_path)

print('Created directories.')

In [None]:
# Load Metrics.
# score_metric is passed to both Transformer models created later on.
# NOTE(review): load_metric is deprecated in newer `datasets` releases in
# favour of the separate `evaluate` package -- confirm the pinned version.
score_metric = load_metric('sacrebleu')

print('Loaded metrics.')

In [None]:
# Download data.
# NOTE(review): only the src/tgt pair is passed here, while step 2 loads
# src/pvt data -- confirm download_data also provides the pivot-language data.
download_data(hparams.src_lang, hparams.tgt_lang)

In [None]:
# Load one tokenizer per language (source, pivot, target). Each tuple holds
# the two arguments handed to load_tokenizer, in the original call order.
src_tokenizer, pvt_tokenizer, tgt_tokenizer = (
    load_tokenizer(first_lang, second_lang)
    for first_lang, second_lang in (
        (hparams.src_lang, hparams.tgt_lang),
        (hparams.pvt_lang, hparams.tgt_lang),
        (hparams.tgt_lang, hparams.src_lang),
    )
)

print('Loaded tokenizers.')

In [None]:
#################
#
#
# !!!STEP 2!!! (step one is already done)
# Freeze decoder and train src-pvt.
#
#
#################

In [None]:
# Load the src-pvt splits that step 2 trains and evaluates on.
(train_dataset,
 val_dataset,
 test_dataset) = load_data(
    hparams.src_lang, hparams.pvt_lang,
    src_tokenizer, pvt_tokenizer,
    hparams.max_examples,
)

# Report the language pair and the size of each split.
print(f'Preprocessed data ({hparams.src_lang}-{hparams.pvt_lang})')
print(f'\tTraining data:   {len(train_dataset)}')
print(f'\tValidation data: {len(val_dataset)}')
print(f'\tTest data:       {len(test_dataset)}')

In [None]:
# Create data loaders; all three splits share batch size and worker count.
# NOTE(review): no shuffle argument is passed, so the training loader keeps
# the DataLoader default ordering -- confirm that is intended.
loader_kwargs = dict(batch_size=hparams.batch_size,
                     num_workers=hparams.num_workers)
train_dataloader = DataLoader(train_dataset, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, **loader_kwargs)

print('Created data loaders.')

In [None]:
# Create models.
# `model` is built from the src and pvt tokenizers, `pvt_tgt_model` from the
# pvt and tgt tokenizers; both receive the same score metric.
model = Transformer(src_tokenizer, pvt_tokenizer, score_metric=score_metric)
pvt_tgt_model = Transformer(pvt_tokenizer, tgt_tokenizer, score_metric=score_metric)

print('Created models.')

In [None]:
# Load the saved pvt-tgt weights into pvt_tgt_model.
# map_location='cpu' deserializes the tensors onto the CPU, so a checkpoint
# written on a GPU machine can also be loaded when no GPU is available;
# load_state_dict then copies the values into the model's own parameters.
pvt_tgt_model.load_state_dict(
    torch.load(hparams.pvt_tgt_model_path, map_location='cpu'))

In [None]:
# Reuse the decoder of the pvt-tgt model and exclude its parameters from
# gradient updates.
model.decoder = pvt_tgt_model.decoder
for frozen_param in model.decoder.parameters():
    frozen_param.requires_grad_(False)

In [None]:
# Create the trainer used for step 2 (and for the pre-training evaluation).
metric_logger = MetricLogger()
checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath=step_model_checkpoints_dir,
                                                   verbose=True,
                                                   save_last=True)

# Request GPUs only when `device` is exactly 'cuda'.
# NOTE(review): a device string such as 'cuda:0' would fall through to 0 GPUs.
if str(device) == 'cuda':
    num_gpus = hparams.gpus
else:
    num_gpus = 0

trainer = Trainer(
    deterministic=True,
    fast_dev_run=False,
    max_epochs=hparams.max_epochs,
    logger=metric_logger,
    log_every_n_steps=1,
    enable_checkpointing=True,
    default_root_dir=model_checkpoints_dir,
    callbacks=[checkpoint_callback],
    gpus=num_gpus,
)

print('Created trainer.')

In [None]:
# Save the model's state dict before any training in this notebook.
# Note: the decoder already holds the weights taken from pvt_tgt_model.
model_path = os.path.join(pre_training_eval_results_dir, 'model.pt')
torch.save(model.state_dict(), model_path)

In [None]:
# Evaluate performance on the test loader before step-2 training.
test_metrics = trainer.test(model, dataloaders=test_dataloader)
print(test_metrics)
# Write what the logger recorded to the pre-training results directory, then
# reset it (presumably clearing the recorded metrics) so that the step-2
# results are kept separate.
metric_logger.manual_save(pre_training_eval_results_dir)
metric_logger.reset()

In [None]:
# Training step 2.
# With no saved step-two weights configured, train on the src-pvt loaders;
# otherwise restore the saved weights and skip training.
if hparams.step_two_model is None:
    trainer.fit(model,
                train_dataloaders=train_dataloader,
                val_dataloaders=val_dataloader)
else:
    print('Loading step two model...')
    # map_location='cpu' lets a checkpoint saved on a GPU machine load on a
    # CPU-only one; load_state_dict copies the values into the model's own
    # parameters.
    model.load_state_dict(
        torch.load(hparams.step_two_model, map_location='cpu'))

In [None]:
# Save the step-2 model's state dict in the run directory.
step_model_path = os.path.join(run_dir, 'step-model.pt')
torch.save(model.state_dict(), step_model_path)

In [None]:
# Testing: evaluate the step-2 model on the src-pvt test loader.
test_metrics = trainer.test(model, dataloaders=test_dataloader)
print(test_metrics)

In [None]:
# Plot loss metrics and save the figure to the step results directory.
# NOTE(review): when hparams.step_two_model is set, trainer.fit is skipped,
# so the logger may hold only test metrics -- confirm plot_metric copes.
save_path = os.path.join(step_results_dir, 'loss.svg')
plot_metric(metric_logger.metrics, 'loss', 'Loss', save_path=save_path)

In [None]:
# Plot score metric and save the figure to the step results directory.
save_path = os.path.join(step_results_dir, 'score.svg')
plot_metric(metric_logger.metrics, 'score', 'Score', save_path=save_path)

In [None]:
# Save recorded step-2 metrics, then reset the logger before fine-tuning.
metric_logger.manual_save(step_results_dir)
metric_logger.reset()

In [None]:
#################
#
#
# !!!FINE-TUNE STEP!!!
# Fine-tune on src-tgt.
#
#
#################

In [None]:
# Load the src-tgt splits for fine-tuning, using max_examples_fine_tune.
(train_dataset,
 val_dataset,
 test_dataset) = load_data(
    hparams.src_lang, hparams.tgt_lang,
    src_tokenizer, tgt_tokenizer,
    hparams.max_examples_fine_tune,
)

# Report the language pair and the size of each split.
print(f'Preprocessed data ({hparams.src_lang}-{hparams.tgt_lang})')
print(f'\tTraining data:   {len(train_dataset)}')
print(f'\tValidation data: {len(val_dataset)}')
print(f'\tTest data:       {len(test_dataset)}')

In [None]:
# Rebuild the data loaders on the src-tgt splits; all three share batch size
# and worker count.
# NOTE(review): no shuffle argument is passed, so the training loader keeps
# the DataLoader default ordering -- confirm that is intended.
loader_kwargs = dict(batch_size=hparams.batch_size,
                     num_workers=hparams.num_workers)
train_dataloader = DataLoader(train_dataset, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, **loader_kwargs)

print('Created data loaders.')

In [None]:
# Unfreeze the decoder and switch the model's target side to the tgt language.
for trainable_param in model.decoder.parameters():
    trainable_param.requires_grad_(True)

# Use the target-language tokenizer from now on.
model.tgt_tokenizer = tgt_tokenizer
# Take the vocabulary size, target embedding and output layer from the
# pvt-tgt model (same objects, not copies).
for shared_attr in ('tgt_vocab_size', 'tgt_embedding', 'output_linear'):
    setattr(model, shared_attr, getattr(pvt_tgt_model, shared_attr))

In [None]:
# Add additional regularization to combat over-fitting on limited data.
model.set_dropout_rate(0.3)
# NOTE(review): weight decay is set to 0.0, which does not add
# regularization -- confirm this value against the stated intent.
model.weight_decay = 0.0

In [None]:
# Create a fresh logger, checkpoint callback and trainer for fine-tuning.
metric_logger = MetricLogger()
checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath=model_checkpoints_dir,
                                                   verbose=True,
                                                   save_last=True)

# Request GPUs only when `device` is exactly 'cuda'.
# NOTE(review): a device string such as 'cuda:0' would fall through to 0 GPUs.
if str(device) == 'cuda':
    num_gpus = hparams.gpus
else:
    num_gpus = 0

trainer = Trainer(
    deterministic=True,
    fast_dev_run=False,
    max_epochs=hparams.max_epochs,
    logger=metric_logger,
    log_every_n_steps=1,
    enable_checkpointing=True,
    default_root_dir=model_checkpoints_dir,
    callbacks=[checkpoint_callback],
    gpus=num_gpus,
)

print('Created trainer.')

In [None]:
# Fine-tuning: train on the src-tgt data loaders created above.
trainer.fit(model,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader)

In [None]:
# Save the fine-tuned model's state dict in the run directory.
model_path = os.path.join(run_dir, 'model.pt')
torch.save(model.state_dict(), model_path)

In [None]:
# Testing: evaluate the fine-tuned model on the src-tgt test loader.
test_metrics = trainer.test(model, dataloaders=test_dataloader)
print(test_metrics)

In [None]:
# Plot loss metrics and save the figure to the results directory.
save_path = os.path.join(results_dir, 'loss.svg')
plot_metric(metric_logger.metrics, 'loss', 'Loss', save_path=save_path)

In [None]:
# Plot score metric and save the figure to the results directory.
save_path = os.path.join(results_dir, 'score.svg')
plot_metric(metric_logger.metrics, 'score', 'Score', save_path=save_path)

In [None]:
# Save hyper parameters into the run directory under the name 'hparams'.
# NOTE(review): this runs only after all training has finished, so a run
# that fails earlier leaves no record of its hyper parameters.
save_dict(run_dir, hparams, 'hparams')

In [None]:
# Save the metrics recorded during fine-tuning to the results directory.
metric_logger.manual_save(results_dir)