# Train Direct Pivoting

## Setup

### Environment

In [None]:
# If this is a notebook which is executed in colab [in_colab=True]:
#  1. Mount google drive and use the repository in there [mount_drive=True] (the repository must be in your google drive root folder).
#  2. Clone repository to remote machine [mount_drive=False].
in_colab = False
mount_drive = True

try:
    # Check if running in colab.
    in_colab = 'google.colab' in str(get_ipython())
except:
    pass

if in_colab:
    if mount_drive:
        # Mount google drive and navigate to it.
        from google.colab import drive
        drive.mount('/content/drive')
        %cd drive/MyDrive
    else:
        # Pull repository.
        !git clone https://github.com/HenningBuhl/low-resource-machine-translation

    # Workaround for problem with undefined symbols (https://github.com/scverse/scvi-tools/issues/1464).
    !pip install --quiet scvi-colab
    from scvi_colab import install
    install()

    # Navigate to the repository and install requirements.
    %cd low-resource-machine-translation
    !pip install -r requirements.txt

    # Navigate to notebook location.
    %cd experiments

In [None]:
# Add the src module directory to the front of the module search path so the
# repository modules imported below (util, arguments, data, ...) resolve.
import sys
sys.path.insert(0, '../src')

In [None]:
from util import is_notebook

# Settings and module reloading (only in Jupyter Notebooks).
# The lines below are IPython magics, so they only work under IPython/Jupyter.
if is_notebook():
    # Module reloading: `%autoreload 2` reloads imported modules before each
    # cell runs, so edits to ../src are picked up without a kernel restart.
    %load_ext autoreload
    %autoreload 2

    # Plot settings: render matplotlib figures inline in the notebook.
    %matplotlib inline

### Imports

In [None]:
# From packages.
import pytorch_lightning as pl

# From repository.
from arguments import *
from benchmark import *
from calc import *
from constants import *
from data import *
from layers import *
from metric_logging import *
from plotting import *
from path_management import *
from tokenizer import *
from transformer import *
from util import *

### Arguments

In [None]:
# Define arguments with argparse.
import argparse
try:
    from distutils.util import strtobool
except ImportError:
    # distutils was removed in Python 3.12 (PEP 632); replicate its strtobool so
    # boolean flags keep parsing to 1/0 exactly as before.
    def strtobool(val):
        """Convert a truthy/falsy string to 1 or 0 (same rules as distutils.util.strtobool).

        Raises:
            ValueError: if `val` is not a recognised truth value.
        """
        val = val.lower()
        if val in ('y', 'yes', 't', 'true', 'on', '1'):
            return 1
        if val in ('n', 'no', 'f', 'false', 'off', '0'):
            return 0
        raise ValueError(f'invalid truth value {val!r}')

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# Experiment.
parser.add_argument('--dev-run', default=False, type=strtobool, help='Executes a fast dev run instead of fully training.')
parser.add_argument('--fresh-run', default=False, type=strtobool, help='Ignores all cached data on disk, reruns generation and overwrites everything.')
parser.add_argument('--seed', default=0, type=int, help='The random seed of the program.')
parser.add_argument('--src-lang', default='de', type=str, help='The source language.')
parser.add_argument('--tgt-lang', default='nl', type=str, help='The target language.')
parser.add_argument('--eval-before-train', default=False, type=strtobool, help='Evaluate the model on the validation data before training.')

# Metrics.
parser.add_argument('--track-bleu', default=True, type=strtobool, help='Whether to track the SacreBLEU score metric.')
parser.add_argument('--track-ter', default=False, type=strtobool, help='Whether to track the translation edit rate metric.')
parser.add_argument('--track-tp', default=False, type=strtobool, help='Whether to track the translation perplexity metric.')
parser.add_argument('--track-chrf', default=False, type=strtobool, help='Whether to track the CHRF score metric.')

# Data.
parser.add_argument('--shuffle-before-split', default=False, type=strtobool, help='Whether to shuffle the data before creating the train, validation and test sets.')
parser.add_argument('--num-val-examples', default=3000, type=int, help='The number of validation examples.')
parser.add_argument('--num-test-examples', default=3000, type=int, help='The number of test examples.')

# Tokenization.
parser.add_argument('--src-vocab-size', default=16000, type=int, help='The vocabulary size of the source language tokenizer.')
parser.add_argument('--src-char-coverage', default=1.0, type=float, help='The character coverage (percentage) of the source language tokenizer.')
parser.add_argument('--tgt-vocab-size', default=16000, type=int, help='The vocabulary size of the target language tokenizer.')
parser.add_argument('--tgt-char-coverage', default=1.0, type=float, help='The character coverage (percentage) of the target language tokenizer.')

# Architecture.
parser.add_argument('--num-layers', default=6, type=int, help='The number of encoder and decoder layers.')
parser.add_argument('--d-model', default=512, type=int, help='The embedding size.')
parser.add_argument('--dropout-rate', default=0.1, type=float, help='The dropout rate.')
parser.add_argument('--num-heads', default=8, type=int, help='The number of attention heads.')
parser.add_argument('--d-ff', default=2048, type=int, help='The feed forward dimension.')
parser.add_argument('--max-len', default=128, type=int, help='The maximum sequence length.')

# Optimizer.
parser.add_argument('--learning-rate', default=1e-4, type=float, help='The learning rate.')
parser.add_argument('--weight-decay', default=0, type=float, help='The weight decay.')
parser.add_argument('--beta-1', default=0.9, type=float, help='Beta_1 parameter of Adam.')
parser.add_argument('--beta-2', default=0.98, type=float, help='Beta_2 parameter of Adam.')

# Scheduler.
parser.add_argument('--enable-scheduling', default=False, type=strtobool, help='Whether to enable scheduling.')
parser.add_argument('--warm-up-steps', default=4000, type=int, help='The number of warm up steps.')

# Training.
parser.add_argument('--batch-size', default=80, type=int, help='The batch size.')
parser.add_argument('--label-smoothing', default=0, type=float, help='The amount of smoothing when calculating the loss.')
parser.add_argument('--max-epochs', default=10, type=int, help='The maximum number of training epochs.')
parser.add_argument('--max-examples', default=-1, type=int, help='The maximum number of training examples.')
parser.add_argument('--shuffle-train-data', default=True, type=strtobool, help='Whether to shuffle the training data during training.')
parser.add_argument('--gpus', default=1, type=int, help='The number of GPUs.')
parser.add_argument('--num-workers', default=4, type=int, help='The number of pytorch workers.')
parser.add_argument('--ckpt-path', default=None, type=str, help='The model checkpoint from which to resume training.')

# Early Stopping + Model Checkpoint.
parser.add_argument('--enable-early-stopping', default=False, type=strtobool, help='Whether to enable early stopping.')
parser.add_argument('--enable-checkpointing', default=False, type=strtobool, help='Whether to enable checkpointing. The best and the last version of the model are saved.')
parser.add_argument('--monitor', default='val_loss', type=str, help='The metric to monitor.')
parser.add_argument('--min-delta', default=0, type=float, help='The minimum change the metric must achieve.')
parser.add_argument('--patience', default=3, type=int, help='Number of epochs that the monitored metric has time to improve.')
parser.add_argument('--mode', default='min', type=str, choices=['min', 'max'], help='How the monitored metric should improve.')

# Parse args.
if is_notebook():
    sys.argv = ['-f']  # Used to make argparse work in jupyter notebooks (all args must be optional).
    args, _ = parser.parse_known_args()  # -f can lead to unknown argument.
else:
    args = parser.parse_args()

# Print args.
print('Arguments:')
print(args)

In [None]:
# Auto-infer args.
# NOTE(review): auto_infer_args comes from one of the star imports; its return
# value is ignored, so it presumably fills in/derives fields of `args` in place.
auto_infer_args(args)

In [None]:
# Shrink the run for quick notebook smoke tests (has no effect outside Jupyter).
# Flip the trailing `and True` to `and False` to disable these overrides.
if is_notebook() and True:
    #args.dev_run = True
    #args.fresh_run = True

    vars(args).update(
        max_epochs=5,
        batch_size=1,
        max_examples=2,
        num_val_examples=1,
        num_test_examples=1,
    )

    print('Adjusted args in notebook')

In [None]:
# Sanity check args.
# NOTE(review): sanity_check_args comes from one of the star imports; presumably
# it validates combinations of `args` and raises on invalid settings — confirm.
sanity_check_args(args)

In [None]:
# Experiment parameters.
# These hard-coded values (not the parsed `args`) select the languages, the
# pretrained baseline checkpoints, the data size and the training setup used by
# the cells below.
hparams = dotdict({
    'src_lang': 'de',
    'pvt_lang': 'nl',
    'tgt_lang': 'en',
    'src_pvt_model_path': 'models/baseline-de-nl.pt',
    'pvt_tgt_model_path': 'models/baseline-nl-en.pt',
    'batch_size': 80,
    'max_epochs': 10,
    'max_examples': 10_000,
    'gpus': 1,
    'num_workers': 4,
    'ckpt_path': None,
})

print('Experiment parameters:')
print(hparams)

### Seed

In [None]:
# Seed the global RNGs through Lightning for reproducible runs; workers=True
# also derives seeds for dataloader worker processes.
from pytorch_lightning import seed_everything

seed_everything(seed=args.seed, workers=True)

### Paths

In [None]:
# Create directories and create file names.
# The experiment is named after the full src-pvt-tgt chain (the same naming as
# run_dir below); the previous name repeated the target language twice.
pm = ExperimentPathManager(f'direct-pivoting-{hparams.src_lang}-{hparams.pvt_lang}-{hparams.tgt_lang}', 'direct-pivoting')
pm.init()

In [None]:
# Persist the parsed command-line arguments (as a plain dict) to the
# experiment's args file.
save_dict(pm.args_file, vars(args))

In [None]:
# Constant directories (relative to the current working directory).
data_dir = os.path.join('./', 'data')
tokenizers_dir = os.path.join('./', 'tokenizers')
runs_dir = os.path.join('./', 'runs')

# Experiment directories. The run directory name ends with get_time_as_string()
# (presumably a timestamp) so repeated runs get separate folders.
run_dir = os.path.join(runs_dir, f'direct-pivoting-{hparams.src_lang}-{hparams.pvt_lang}-{hparams.tgt_lang}-{get_time_as_string()}')
model_checkpoints_dir = os.path.join(run_dir, 'checkpoints')
results_dir = os.path.join(run_dir, 'results')
pre_training_eval_results_dir = os.path.join(run_dir, 'pre-training-eval-results')

# Parents are listed before their children, so creating them in order works.
dirs = [data_dir, tokenizers_dir, runs_dir, run_dir, model_checkpoints_dir, results_dir, pre_training_eval_results_dir]
for directory in dirs:  # Not `dir`: that would shadow the builtin for the rest of the notebook.
    create_dir(directory)

print('Created directories.')

In [None]:
# Load Metrics.
# The SacreBLEU metric is handed to both Transformer models below as their
# score metric.
# NOTE(review): load_metric comes from a star import (presumably the
# `datasets` library's loader) — confirm the source and version.
score_metric = load_metric('sacrebleu')

print('Loaded metrics.')

In [None]:
# Download data.
# Fetches the direct src-tgt pair (de-en per hparams), which is the corpus the
# stitched model is trained and evaluated on below.
download_data(hparams.src_lang, hparams.tgt_lang)

In [None]:
# Load tokenizers.
# pvt_tokenizer is shared by both Transformers created below: it is the second
# tokenizer of the src-pvt model and the first of the pvt-tgt model.
# NOTE(review): the pretrained checkpoints are named baseline-de-nl and
# baseline-nl-en, but these calls pair de with en and nl with en. If
# load_tokenizer returns pair-specific tokenizers, their vocabularies may not
# match the checkpoints' embeddings — verify against how the baselines were built.
src_tokenizer = load_tokenizer(hparams.src_lang, hparams.tgt_lang)
pvt_tokenizer = load_tokenizer(hparams.pvt_lang, hparams.tgt_lang)
tgt_tokenizer = load_tokenizer(hparams.tgt_lang, hparams.src_lang)

print('Loaded tokenizers.')

In [None]:
# Load the direct src-tgt parallel data, tokenized with the src/tgt tokenizers
# and presumably capped at hparams.max_examples, as train/val/test splits.
train_dataset, val_dataset, test_dataset = load_data(
    hparams.src_lang,
    hparams.tgt_lang,
    src_tokenizer,
    tgt_tokenizer,
    hparams.max_examples,
)

# Report the size of every split, one per line.
print(
    f'Preprocessed data ({hparams.src_lang}-{hparams.tgt_lang})',
    f'\tTraining data:   {len(train_dataset)}',
    f'\tValidation data: {len(val_dataset)}',
    f'\tTest data:       {len(test_dataset)}',
    sep='\n',
)

In [None]:
# Create data loaders.
# The training loader is shuffled according to --shuffle-train-data (default
# True); previously that flag was ignored and every epoch saw the batches in
# corpus order. Validation and test keep a fixed order.
train_dataloader = DataLoader(train_dataset,
                              batch_size=hparams.batch_size,
                              shuffle=bool(args.shuffle_train_data),
                              num_workers=hparams.num_workers)
val_dataloader = DataLoader(val_dataset, batch_size=hparams.batch_size, num_workers=hparams.num_workers)
test_dataloader = DataLoader(test_dataset, batch_size=hparams.batch_size, num_workers=hparams.num_workers)

print('Created data loaders.')

In [None]:
# Build the two baseline Transformers, one from the (src, pvt) tokenizers and
# one from the (pvt, tgt) tokenizers, both scored with the loaded metric.
src_pvt_model = Transformer(
    src_tokenizer,
    pvt_tokenizer,
    score_metric=score_metric,
)
pvt_tgt_model = Transformer(
    pvt_tokenizer,
    tgt_tokenizer,
    score_metric=score_metric,
)

print('Created models.')

In [None]:
# Load the pretrained weights of both baseline models.
# map_location=device remaps stored tensors onto the current device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine instead of
# torch.load failing during deserialization.
src_pvt_model.load_state_dict(torch.load(hparams.src_pvt_model_path, map_location=device))
pvt_tgt_model.load_state_dict(torch.load(hparams.pvt_tgt_model_path, map_location=device))

src_pvt_model.to(device)
pvt_tgt_model.to(device)

print('Loaded models.')

In [None]:
# Create direct pivoting model.
# `model` is the pvt-tgt model itself (an alias, not a copy), so the following
# assignments also mutate `pvt_tgt_model`. Its source side (tokenizer, vocab
# size, embedding and encoder) is replaced with the src-pvt model's; every other
# part of pvt_tgt_model is kept as-is.
model = pvt_tgt_model
model.src_tokenizer = src_tokenizer
model.src_vocab_size = src_pvt_model.src_vocab_size
model.src_embedding = src_pvt_model.src_embedding
model.encoder = src_pvt_model.encoder

In [None]:
# Regularization settings to combat over-fitting on the limited training data:
# dropout is set to 0.3, while weight decay is explicitly disabled (0.0).
model.set_dropout_rate(0.3)
model.weight_decay = 0.0

In [None]:
# Create trainer.
# MetricLogger collects the logged metrics; they are saved and reset manually
# by later cells.
metric_logger = MetricLogger()

# Checkpoints go to the run's checkpoint folder; save_last=True also keeps a
# copy of the most recent checkpoint.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=model_checkpoints_dir,
    verbose=True,
    save_last=True,
)

# GPUs are requested only when the global `device` stringifies to 'cuda'.
# NOTE(review): an indexed device such as 'cuda:0' would fall back to 0 GPUs —
# confirm how `device` is defined.
trainer = Trainer(
    deterministic=True,
    fast_dev_run=False,
    max_epochs=hparams.max_epochs,
    logger=metric_logger,
    log_every_n_steps=1,
    enable_checkpointing=True,
    default_root_dir=model_checkpoints_dir,
    callbacks=[checkpoint_callback],
    gpus=hparams.gpus if str(device) == 'cuda' else 0,
)

print('Created trainer.')

In [None]:
# Save untrained model.
# Snapshot of the stitched model's weights before any training, stored next to
# the pre-training evaluation results.
model_path = os.path.join(pre_training_eval_results_dir, 'model.pt')
torch.save(model.state_dict(), model_path)

In [None]:
# Evaluate performance.
# Baseline: test-set metrics of the stitched model before training. The logged
# metrics are written to the pre-training results folder, then the logger is
# reset (presumably clearing them) so the training run starts from a clean log.
test_metrics = trainer.test(model, dataloaders=test_dataloader)
print(test_metrics)
metric_logger.manual_save(pre_training_eval_results_dir)
metric_logger.reset()

In [None]:
# Training.
# hparams.ckpt_path (None by default) allows resuming from a saved Lightning
# checkpoint; previously it was defined but never passed. With None, training
# starts from the current model weights exactly as before.
trainer.fit(model,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader,
            ckpt_path=hparams.ckpt_path)

In [None]:
# Save model.
# Weights after training, stored at the top level of the run directory.
model_path = os.path.join(run_dir, 'model.pt')
torch.save(model.state_dict(), model_path)

In [None]:
# Testing.
# Evaluate the trained model on the same test loader used for the pre-training
# baseline so the two results are comparable.
test_metrics = trainer.test(model, dataloaders=test_dataloader)
print(test_metrics)

In [None]:
# Plot loss metrics.
# Plots the logged metrics selected by the key 'loss' (presumably train and
# validation loss) and saves the figure as an SVG in the results folder.
save_path = os.path.join(results_dir, 'loss.svg')
plot_metric(metric_logger.metrics, 'loss', 'Loss', save_path=save_path)

In [None]:
# Plot score metric.
# Plots the logged metrics selected by the key 'score' (presumably the SacreBLEU
# score) and saves the figure as an SVG in the results folder.
save_path = os.path.join(results_dir, 'score.svg')
plot_metric(metric_logger.metrics, 'score', 'Score', save_path=save_path)

In [None]:
# Save hyper parameters.
# NOTE(review): save_dict is called here as (directory, dict, name), but earlier
# as save_dict(pm.args_file, args.__dict__) with a file path and no name. One of
# the two call forms likely does not match the helper's signature — verify.
save_dict(run_dir, hparams, 'hparams')

In [None]:
# Save recorded metrics.
# Writes everything the logger collected since its reset after the pre-training
# evaluation (training, validation and final test metrics) to the results folder.
metric_logger.manual_save(results_dir)