# Train Baseline

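This notebook trains a baseline Transformer translation model for a single source/target language pair (default: `de` → `nl`), evaluates it on the test set, and exports the recorded metrics and plots.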

## Setup

### Environment

In [None]:
# If this is a notebook which is executed in colab [in_colab=True]:
#  1. Mount google drive and use the repository in there [mount_drive=True] (the repository must be in your google drive root folder).
#  2. Clone repository to remote machine [mount_drive=False].
in_colab = False
mount_drive = True

try:
    # Check if running in colab.
    in_colab = 'google.colab' in str(get_ipython())
except:
    pass

if in_colab:
    if mount_drive:
        # Mount google drive and navigate to it.
        from google.colab import drive
        drive.mount('/content/drive')
        %cd drive/MyDrive
    else:
        # Pull repository.
        !git clone https://github.com/HenningBuhl/low-resource-machine-translation

    # Workaround for problem with undefined symbols (https://github.com/scverse/scvi-tools/issues/1464).
    !pip install --quiet scvi-colab
    from scvi_colab import install
    install()

    # Navigate to the repository and install requirements.
    %cd low-resource-machine-translation
    !pip install -r requirements.txt

    # Navigate to notebook location.
    %cd experiments

In [None]:
# Put the repository's src directory at the front of the module search path
# so that the subsequent imports resolve to the project's own modules.
import sys

src_module_dir = '../src'  # Relative to the notebook's working directory.
sys.path.insert(0, src_module_dir)

In [None]:
from util import is_notebook

# Settings and module reloading (only in Jupyter Notebooks).
# The IPython magics below are not valid plain Python, so they are guarded by
# is_notebook().
if is_notebook():
    # Module reloading: load IPython's autoreload extension and switch it to
    # mode 2 so edits to the repository modules are picked up without
    # restarting the kernel.
    %load_ext autoreload
    %autoreload 2

    # Plot settings: render matplotlib figures inline in the notebook output.
    %matplotlib inline

### Imports

In [None]:
# From packages.
import pytorch_lightning as pl

# From repository.
from arguments import *
from benchmark import *
from calc import *
from constants import *
from data import *
from layers import *
from metric_logging import *
from plotting import *
from path_management import *
from tokenizer import *
from transformer import *
from util import *

### Arguments

In [None]:
# Define arguments with argparse.
# Every option has a default so the notebook can run without any command line
# input; boolean options take an explicit value (e.g. `--dev-run true`).
import argparse

try:
    from distutils.util import strtobool
except ImportError:
    # distutils was removed from the standard library in Python 3.12 (PEP 632).
    # This fallback mirrors distutils.util.strtobool, including its 1/0 return values.
    def strtobool(val):
        """Convert a truth-value string ('yes', 'true', 'off', '0', ...) to 1 or 0."""
        val = val.lower()
        if val in ('y', 'yes', 't', 'true', 'on', '1'):
            return 1
        if val in ('n', 'no', 'f', 'false', 'off', '0'):
            return 0
        raise ValueError(f'invalid truth value {val!r}')

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# Experiment.
parser.add_argument('--dev-run', default=False, type=strtobool, help='Executes a fast dev run instead of fully training.')
parser.add_argument('--fresh-run', default=False, type=strtobool, help='Ignores all cached data on disk, reruns generation and overwrites everything.')
parser.add_argument('--seed', default=0, type=int, help='The random seed of the program.')
parser.add_argument('--src-lang', default='de', type=str, help='The source language.')
parser.add_argument('--tgt-lang', default='nl', type=str, help='The target language.')
parser.add_argument('--eval-before-train', default=False, type=strtobool, help='Evaluate the model on the validation data before training.')

# Metrics.
parser.add_argument('--track-bleu', default=True, type=strtobool, help='Whether to track the SacreBLEU score metric.')
parser.add_argument('--track-ter', default=False, type=strtobool, help='Whether to track the translation edit rate metric.')
parser.add_argument('--track-tp', default=False, type=strtobool, help='Whether to track the translation perplexity metric.')
parser.add_argument('--track-chrf', default=False, type=strtobool, help='Whether to track the CHRF score metric.')

# Data.
parser.add_argument('--shuffle-before-split', default=False, type=strtobool, help='Whether to shuffle the data before creating the train, validation and test sets.')
parser.add_argument('--num-val-examples', default=3000, type=int, help='The number of validation examples.')
parser.add_argument('--num-test-examples', default=3000, type=int, help='The number of test examples.')

# Tokenization.
parser.add_argument('--src-vocab-size', default=16000, type=int, help='The vocabulary size of the source language tokenizer.')
parser.add_argument('--src-char-coverage', default=1.0, type=float, help='The character coverage (percentage) of the source language tokenizer.')
parser.add_argument('--tgt-vocab-size', default=16000, type=int, help='The vocabulary size of the target language tokenizer.')
parser.add_argument('--tgt-char-coverage', default=1.0, type=float, help='The character coverage (percentage) of the target language tokenizer.')

# Architecture.
parser.add_argument('--num-layers', default=6, type=int, help='The number of encoder and decoder layers.')
parser.add_argument('--d-model', default=512, type=int, help='The embedding size.')
parser.add_argument('--drop-out-rate', default=0.1, type=float, help='The dropout rate.')
parser.add_argument('--num-heads', default=8, type=int, help='The number of attention heads.')
parser.add_argument('--d-ff', default=2048, type=int, help='The feed forward dimension.')
parser.add_argument('--max-len', default=128, type=int, help='The maximum sequence length.')

# Optimizer.
parser.add_argument('--learning-rate', default=1e-4, type=float, help='The learning rate.')
parser.add_argument('--weight-decay', default=0, type=float, help='The weight decay.')
#parser.add_argument('--beta-1', default=0.9, type=float, help='')
#parser.add_argument('--beta-2', default=0.999, type=float, help='')
#parser.add_argument('--scheduling', default=0, type=float, help='')

# Training.
parser.add_argument('--batch-size', default=80, type=int, help='The batch size.')
parser.add_argument('--label-smoothing', default=0, type=float, help='The amount of smoothing when calculating the loss.')
parser.add_argument('--max-epochs', default=10, type=int, help='The maximum number of training epochs.')
parser.add_argument('--max-examples', default=-1, type=int, help='The maximum number of training examples.')
parser.add_argument('--shuffle-train-data', default=True, type=strtobool, help='Whether to shuffle the training data during training.')
parser.add_argument('--gpus', default=1, type=int, help='The number of GPUs.')
parser.add_argument('--num-workers', default=4, type=int, help='The number of pytorch workers.')
parser.add_argument('--ckpt-path', default=None, type=str, help='The model checkpoint from which to resume training.')

# Early Stopping + Model Checkpoint.
parser.add_argument('--enable-early-stopping', default=False, type=strtobool, help='Whether to enable early stopping.')
parser.add_argument('--enable-checkpointing', default=False, type=strtobool, help='Whether to enable checkpointing. The best and the last version of the model are saved.')
parser.add_argument('--monitor', default='val_loss', type=str, help='The metric to monitor.')
parser.add_argument('--min-delta', default=0, type=float, help='The minimum change the metric must achieve.')
parser.add_argument('--patience', default=3, type=int, help='Number of epochs that the monitored metric has time to improve.')
parser.add_argument('--mode', default='min', type=str, choices=['min', 'max'], help='How the monitored metric should improve.')

# Parse args.
if is_notebook():
    # The kernel's own command line is not meant for this parser, so it is
    # replaced and any leftover unknown arguments are tolerated.
    sys.argv = ['-f']  # Used to make argparse work in jupyter notebooks (all args must be optional).
    args, _ = parser.parse_known_args()  # -f can lead to unknown argument.
else:
    args = parser.parse_args()

# Print args.
print('Arguments:')
print(args)

In [None]:
# Auto-infer args: pass the parsed namespace to auto_infer_args (provided by
# the repository's star imports). The return value is not used, so inferred
# values are presumably written onto `args` in place — TODO confirm.
auto_infer_args(args)

In [None]:
# Shrink the run to a tiny smoke test when executed interactively.
apply_test_overrides = True  # Flip to False to keep the parsed arguments untouched.

if is_notebook() and apply_test_overrides:
    # Optional toggles that are currently left at their parsed values:
    #   dev_run, fresh_run, enable_checkpointing, enable_early_stopping,
    #   eval_before_train, label_smoothing.
    test_overrides = {
        'max_epochs': 2,
        'batch_size': 1,
        'max_examples': 2,
        'num_val_examples': 1,
        'num_test_examples': 1,
        'track_ter': True,
        'track_tp': True,
        'track_chrf': True,
    }
    for arg_name, arg_value in test_overrides.items():
        setattr(args, arg_name, arg_value)
    print('Adjusted args in notebook')

In [None]:
# Sanity check args: hand the (possibly adjusted) namespace to
# sanity_check_args from the repository modules. Its return value is ignored
# here; presumably it raises or reports on invalid combinations — TODO confirm.
sanity_check_args(args)

### Seed

In [None]:
# Set seed.
# Seeds the random number generators through PyTorch Lightning using the
# --seed argument; workers=True additionally requests seeding of dataloader
# worker processes.
from pytorch_lightning import seed_everything
seed_everything(args.seed, workers=True)

### Paths

In [None]:
# Set up the experiment's path manager; the experiment name encodes the
# language pair, and init() prepares the directories and file names.
experiment_name = f'baseline-{args.src_lang}-{args.tgt_lang}'
pm = ExperimentPathManager(experiment_name, 'baseline')
pm.init()

In [None]:
# Persist the effective arguments (as a plain dict) next to the experiment outputs.
save_dict(pm.args_file, vars(args))

## Data Preprocessing

In [None]:
# Create PreProcessor for the configured source/target language pair; it is
# used below for splitting the data and building the dataloaders.
pp = PreProcessor(args.src_lang, args.tgt_lang)

### Splitting

In [None]:
# Partition the corpus into train, validation and test sets.
pp.split_data(
    args.shuffle_before_split,
    args.num_val_examples,
    args.num_test_examples,
    args.fresh_run,
)

### Tokenizers

In [None]:
# Build one tokenizer per direction: the builder takes the language to
# tokenize first and the paired language second.
src_tokenizer_builder = TokenizerBuilder(args.src_lang, args.tgt_lang)
src_tokenizer = src_tokenizer_builder.build(
    args.src_vocab_size,
    args.src_char_coverage,
    fresh_run=args.fresh_run,
)

tgt_tokenizer_builder = TokenizerBuilder(args.tgt_lang, args.src_lang)
tgt_tokenizer = tgt_tokenizer_builder.build(
    args.tgt_vocab_size,
    args.tgt_char_coverage,
    fresh_run=args.fresh_run,
)

### Preparation

In [None]:
# Tokenize the splits and wrap them in train/val/test dataloaders.
dataloaders = pp.pre_process(
    src_tokenizer,
    tgt_tokenizer,
    args.batch_size,
    args.shuffle_train_data,
    args.max_examples,
    args.max_len,
    fresh_run=args.fresh_run,
)
train_dataloader, val_dataloader, test_dataloader = dataloaders

## Experiment

### Model

In [None]:
# Instantiate the Transformer from the parsed arguments (all positional, in
# the order the constructor is called with throughout this notebook).
model = Transformer(
    # Tokenizers.
    src_tokenizer,
    tgt_tokenizer,
    # Optimizer.
    args.learning_rate,
    args.weight_decay,
    # Architecture.
    args.num_layers,
    args.d_model,
    args.drop_out_rate,
    args.num_heads,
    args.d_ff,
    args.max_len,
    # Loss.
    args.label_smoothing,
    # Tracked metrics.
    args.track_bleu,
    args.track_ter,
    args.track_tp,
    args.track_chrf,
)

In [None]:
# Save untrained model: write the freshly initialised model to the path
# manager's untrained-model file before any training step has run.
model.save(pm.baseline.untrained_model_file)

### Training

In [None]:
# Create callbacks.
# Both callbacks watch the same metric (--monitor) and must agree on the
# direction of improvement (--mode); otherwise the "best" checkpoint would be
# chosen by the minimum even when the metric is meant to be maximised.
callbacks = []

if args.enable_checkpointing:
    model_checkpoint = pl.callbacks.ModelCheckpoint(
        monitor=args.monitor,
        mode=args.mode,
        dirpath=pm.baseline.checkpoint_dir,
        # Name checkpoints after the monitored metric; with the default
        # monitor this yields '{epoch}-{val_loss:.2f}'.
        filename='{epoch}-{' + args.monitor + ':.2f}',
        save_top_k=1,  # Keep only the best checkpoint ...
        save_last=True,  # ... plus the most recent one.
        every_n_epochs=1,
        verbose=True,
    )
    callbacks.append(model_checkpoint)

if args.enable_early_stopping:
    early_stopping_callback = pl.callbacks.EarlyStopping(
        monitor=args.monitor,
        min_delta=args.min_delta,
        patience=args.patience,
        mode=args.mode,
        verbose=True,
    )
    callbacks.append(early_stopping_callback)

In [None]:
# Create metric logger: the repository's MetricLogger is passed to the trainer
# as its logger and later used to save and plot the recorded metrics.
metric_logger = MetricLogger()

In [None]:
# Build the Lightning trainer. GPUs are only requested when the repository's
# `device` reports CUDA; otherwise training runs on the CPU.
use_cuda = str(device) == 'cuda'
num_gpus = args.gpus if use_cuda else 0

trainer = pl.Trainer(
    deterministic=True,
    fast_dev_run=args.dev_run,
    max_epochs=args.max_epochs,
    logger=metric_logger,
    log_every_n_steps=1,
    enable_checkpointing=args.enable_checkpointing,
    default_root_dir=pm.baseline.checkpoint_dir,
    callbacks=callbacks,
    gpus=num_gpus,
)

In [None]:
# Evaluate before training.
# Optional validation pass of the untrained model (--eval-before-train), run
# on the validation dataloader before fit() is called.
if args.eval_before_train:
    trainer.validate(model, dataloaders=val_dataloader)

In [None]:
# Run the training loop; when --ckpt-path is given, it is passed on so
# training resumes from that checkpoint.
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    ckpt_path=args.ckpt_path,
)

In [None]:
# Save model.
# With checkpointing enabled, restore the best checkpoint's weights into the
# existing model before saving. load_from_checkpoint() is a classmethod that
# returns a *new* model, so calling it and discarding the result left the
# last-epoch weights in place. Loading the state dict keeps the same model
# object (and its tokenizers) for the test and plotting steps that follow.
# best_model_path is an empty string when no checkpoint was written (e.g.
# during a fast dev run), in which case the current weights are kept.
if args.enable_checkpointing and model_checkpoint.best_model_path:
    import torch

    best_checkpoint = torch.load(model_checkpoint.best_model_path, map_location='cpu')
    # Lightning stores the module weights under the 'state_dict' key.
    model.load_state_dict(best_checkpoint['state_dict'])

model.save(pm.baseline.model_file)

### Testing

In [None]:
# Testing: evaluate the model on the held-out test dataloader and keep the
# values returned by the trainer in test_metrics.
test_metrics = trainer.test(model, dataloaders=test_dataloader)

## Exporting Results

In [None]:
# Save recorded metrics: ask the metric logger to write what it recorded,
# using the metrics directory and metrics file from the path manager.
metric_logger.manual_save(pm.baseline.metrics_dir, pm.baseline.metrics_file)

In [None]:
# Render one plot per tracked metric; the SVG path comes from filling the
# metric name into the path manager's template.
for tracked_metric in model.tracked_metrics:
    svg_path = pm.baseline.metrics_svg_template.format(tracked_metric)
    plot_metric(metric_logger.metrics, tracked_metric, save_path=svg_path)