# Train Baseline

In [None]:
# Name of this experiment. It is fixed for the notebook and is passed on to
# argument inference and to the ExperimentManager further down.
experiment = 'baseline'

print('Running experiment: ' + experiment)

## Setup

### Environment

In [None]:
# If this is a notebook which is executed in colab [in_colab=True]:
#  1. Mount google drive and use the repository in there [mount_drive=True] (the repository must be in your google drive root folder).
#  2. Clone repository to remote machine [mount_drive=False].
in_colab = False
mount_drive = True

try:
    # Check if running in colab.
    in_colab = 'google.colab' in str(get_ipython())
except:
    pass

if in_colab:
    if mount_drive:
        # Mount google drive and navigate to it.
        from google.colab import drive
        drive.mount('/content/drive')
        %cd drive/MyDrive
    else:
        # Pull repository.
        !git clone https://github.com/HenningBuhl/low-resource-machine-translation

    # Workaround for problem with undefined symbols (https://github.com/scverse/scvi-tools/issues/1464).
    !pip install --quiet scvi-colab
    from scvi_colab import install
    install()

    # Navigate to the repository and install requirements.
    %cd low-resource-machine-translation
    !pip install -r requirements.txt

    # Navigate to notebook location.
    %cd experiments

In [None]:
# Add src module directory to system path for subsequent imports.
# The path is relative, so this relies on the working directory being the
# notebook's own folder.
import sys
sys.path.insert(0, '../src')

In [None]:
from util import is_notebook

# Settings and module reloading (only in Jupyter Notebooks).
# The % lines are IPython magics and are only valid inside an IPython kernel.
if is_notebook():
    # Module reloading: re-import edited project modules automatically before
    # each cell execution.
    %load_ext autoreload
    %autoreload 2

    # Plot settings: render matplotlib figures inline in the notebook.
    %matplotlib inline

### Imports

In [None]:
# From packages.
import pytorch_lightning as pl
import argparse
from distutils.util import strtobool

# From repository.
from arg_management import *
from constants import *
from data import *
from layers import *
from metric_logging import *
from plotting import *
from path_management import *
from tokenizer import *
from transformer import *
from util import *

### Arguments

In [None]:
# Build the command-line interface: two experiment-specific options plus the
# shared argument groups provided by ArgManager.
arg_manager = ArgManager()
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# Experiment.
parser.add_argument('--src-lang', default='de', type=str, help='The source language.')
parser.add_argument('--tgt-lang', default='nl', type=str, help='The target language.')

# Shared argument groups, registered in this fixed order.
for register_group in (
    arg_manager.add_run_args,            # Run.
    arg_manager.add_metrics_args,        # Metrics.
    arg_manager.add_data_args,           # Data.
    arg_manager.add_tokenization_args,   # Tokenization.
    arg_manager.add_architecture_args,   # Architecture.
    arg_manager.add_optimizer_args,      # Optimizer.
    arg_manager.add_scheduler_args,      # Scheduler.
    arg_manager.add_training_args,       # Training.
    # Early Stopping + Model Checkpoint ('checkpoiting' is the spelling used
    # by the ArgManager call and must be kept as is).
    arg_manager.add_early_stopping_and_checkpoiting_args,
):
    register_group(parser)

# Parse args.
if not is_notebook():
    # Regular script: strict parsing of the real command line.
    args = parser.parse_args()
else:
    # In a notebook the kernel's own command line must not be parsed, so
    # sys.argv is replaced and every option falls back to its default
    # (all args must be optional).
    sys.argv = ['-f']
    args, _ = parser.parse_known_args()  # -f can lead to unknown argument.

# Print args.
print('Arguments:', args, sep='\n')

In [None]:
# Auto-infer args: let the ArgManager derive argument values for this
# experiment (the rules live in ArgManager.auto_infer_args).
# NOTE(review): the return value is not used, so this presumably mutates
# `args` in place - confirm.
arg_manager.auto_infer_args(args, experiment)

In [None]:
# Adjust arguments for test purposes.
# Manual override hook: uncomment one of the assignments below to change an
# argument for interactive runs. It never triggers outside a notebook.
if is_notebook() and True:  # Quickly turn on and off with 'and True/False'.
    #args.dev_run = True
    #args.fresh_run = True
    
    print('Adjusted args in notebook')

In [None]:
# Sanity check args. Runs after auto-inference and the manual notebook
# overrides, so the final argument values are the ones being checked.
arg_manager.sanity_check_args(args)

### Seed

In [None]:
# Set seed.
# seed_everything seeds the random number generators through Lightning;
# workers=True additionally asks Lightning to seed dataloader worker
# processes (see the seed_everything documentation).
from pytorch_lightning import seed_everything
seed_everything(args.seed, workers=True)

### Paths

In [None]:
# Create directories and create file names.
# The first argument is the run name '<experiment>-<src_lang>-<tgt_lang>';
# the second is the experiment name on its own.
# NOTE(review): presumably em.init() is what creates the directories - confirm
# in path_management.
em = ExperimentManager(f'{experiment}-{args.src_lang}-{args.tgt_lang}', experiment)
em.init()

In [None]:
# Persist the parsed arguments (the namespace's attribute dictionary) to the
# experiment's args file.
save_dict(em.args_file, vars(args))

## Data Preprocessing

In [None]:
# Create ParallelDataPreProcessor for the (source, target) language pair.
# It is used below for splitting the data and for building the dataloaders.
pp = ParallelDataPreProcessor(args.src_lang, args.tgt_lang)

### Splitting

In [None]:
# Split data into (train, val, test) sets.
# Arguments are positional; keep them in this order.
pp.split_data(
    args.shuffle_before_split,
    args.num_val_examples,
    args.num_test_examples,
    args.fresh_run,
)

### Tokenizers

In [None]:
# Load tokenizers: one per language, each with its own vocabulary size and
# character coverage.
# Note the swapped language order in the two TokenizerBuilder calls.
# NOTE(review): presumably the first language is the one being tokenized and
# the second identifies the language pair - confirm in tokenizer.py.
src_tokenizer = TokenizerBuilder(args.src_lang, args.tgt_lang).build(
    args.src_vocab_size, args.src_char_coverage, fresh_run=args.fresh_run)
tgt_tokenizer = TokenizerBuilder(args.tgt_lang, args.src_lang).build(
    args.tgt_vocab_size, args.tgt_char_coverage, fresh_run=args.fresh_run)

### Preparation

In [None]:
# Build the train / validation / test dataloaders from the split data.
# All arguments except fresh_run are positional; keep them in this order.
train_dataloader, val_dataloader, test_dataloader = pp.pre_process(
    src_tokenizer,
    tgt_tokenizer,
    args.batch_size,
    args.shuffle_train_data,
    args.max_examples,
    args.max_len,
    fresh_run=args.fresh_run,
)

## Experiment

### Create Model

In [None]:
# Create model.
# All arguments are passed positionally, so their order must match the
# parameter order of Transformer.__init__ exactly. Be careful when adding or
# removing an argument here.
model = Transformer(
    # Tokenizers.
    src_tokenizer,
    tgt_tokenizer,
    # Optimizer.
    args.learning_rate,
    args.weight_decay,
    args.beta_1,
    args.beta_2,
    # Scheduler.
    args.enable_scheduling,
    args.warm_up_steps,
    # Architecture.
    args.num_layers,
    args.d_model,
    args.dropout,
    args.num_heads,
    args.d_ff,
    args.max_len,
    # Training.
    args.label_smoothing,
    # Metric tracking switches.
    args.track_bleu,
    args.track_ter,
    args.track_tp,
    args.track_chrf,
)

In [None]:
# Save untrained model (the freshly initialised weights, before any call to
# trainer.fit) to the experiment's untrained-model file.
model.save(em.baseline.untrained_model_file)

### Training Setup

In [None]:
# Create callbacks and loggers.
# Each callback is optional and only added when enabled through the arguments.
callbacks = []

if args.enable_checkpointing:
    # Keep the best checkpoint according to args.monitor (save_top_k=1) plus
    # the most recent one (save_last). `mode` is passed explicitly because
    # ModelCheckpoint defaults to 'min', which would keep the *worst*
    # checkpoint for a monitored metric that is maximised. Using args.mode
    # keeps it in agreement with the early-stopping callback below.
    model_checkpoint = pl.callbacks.ModelCheckpoint(
        monitor=args.monitor,
        mode=args.mode,
        dirpath=em.baseline.checkpoint_dir,
        filename='{epoch}-{step}-{val_loss:.2f}',
        save_top_k=1,
        save_last=True,
        every_n_epochs=1,
        verbose=True,
    )
    callbacks.append(model_checkpoint)

if args.enable_early_stopping:
    # Stop training once args.monitor has not improved by at least
    # args.min_delta for args.patience validation checks.
    early_stopping_callback = pl.callbacks.EarlyStopping(
        monitor=args.monitor,
        min_delta=args.min_delta,
        patience=args.patience,
        mode=args.mode,
        verbose=True,
    )
    callbacks.append(early_stopping_callback)

if args.enable_scheduling:
    # Log the learning rate (and momentum) at every step so the schedule can
    # be inspected afterwards.
    lr_monitor = pl.callbacks.LearningRateMonitor(
        logging_interval='step',
        log_momentum=True
    )
    callbacks.append(lr_monitor)

# Create metric logger.
metric_logger = MetricLogger()

In [None]:
# Create trainer.
# GPUs are only requested when the project-wide `device` is CUDA; otherwise
# training runs on the CPU.
num_gpus = args.gpus if str(device) == 'cuda' else 0

trainer = pl.Trainer(
    deterministic=True,
    fast_dev_run=args.dev_run,
    max_epochs=args.max_epochs,
    logger=metric_logger,
    log_every_n_steps=1,
    enable_checkpointing=args.enable_checkpointing,
    default_root_dir=em.baseline.checkpoint_dir,
    callbacks=callbacks,
    gpus=num_gpus,
)

In [None]:
# Evaluate before training: optionally run one validation pass with the
# untrained model on the validation dataloader.
if args.eval_before_train:
    trainer.validate(model, dataloaders=val_dataloader)

### Train Model

In [None]:
# Training.
# args.ckpt_path is forwarded to Lightning as the checkpoint to resume from.
# NOTE(review): presumably None by default, i.e. training starts from
# scratch - confirm in arg_management.
trainer.fit(model,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader,
            ckpt_path=args.ckpt_path)

### Save Model

In [None]:
# Save model.
# Restore the best checkpoint before exporting. load_from_checkpoint is a
# classmethod that RETURNS a new model instead of updating the instance it is
# called on, so its result has to be bound to `model`; otherwise the
# last-epoch weights would be saved and tested.
# best_model_path is an empty string when no checkpoint was written (e.g. in
# a fast dev run); in that case the in-memory weights are kept.
if args.enable_checkpointing and model_checkpoint.best_model_path:
    model = Transformer.load_from_checkpoint(model_checkpoint.best_model_path)

model.save(em.baseline.model_file)

### Test Model

In [None]:
# Testing: evaluate the model on the held-out test dataloader and keep the
# metrics returned by the trainer in `test_metrics`.
test_metrics = trainer.test(model, dataloaders=test_dataloader)

## Export Results

In [None]:
# Save recorded metrics: write everything the MetricLogger collected during
# validation, training and testing to the experiment's metrics file.
metric_logger.manual_save(em.baseline.metrics_dir, em.baseline.metrics_file)

In [None]:
# Save one plot per metric tracked by the model. The output path is obtained
# by filling the metric name into the experiment's SVG path template.
for metric_name in model.tracked_metrics:
    svg_path = em.baseline.metric_svg_template.format(metric_name)
    plot_metric(metric_logger.metrics, metric_name, save_path=svg_path)