# Evaluate Models on Benchmarks

In [None]:
# Static experiment settings.
# Name of this experiment; it is passed to auto_infer_args() and
# ExperimentPathManager() further down in this notebook.
experiment = 'benchmark'

## Setup

### Environment

In [None]:
# If this is a notebook which is executed in colab [in_colab=True]:
#  1. Mount google drive and use the repository in there [mount_drive=True] (the repository must be in your google drive root folder).
#  2. Clone repository to remote machine [mount_drive=False].
in_colab = False
mount_drive = True

try:
    # Check if running in colab.
    in_colab = 'google.colab' in str(get_ipython())
except:
    pass

if in_colab:
    if mount_drive:
        # Mount google drive and navigate to it.
        from google.colab import drive
        drive.mount('/content/drive')
        %cd drive/MyDrive
    else:
        # Pull repository.
        !git clone https://github.com/HenningBuhl/low-resource-machine-translation

    # Workaround for problem with undefined symbols (https://github.com/scverse/scvi-tools/issues/1464).
    !pip install --quiet scvi-colab
    from scvi_colab import install
    install()

    # Navigate to the repository and install requirements.
    %cd low-resource-machine-translation
    !pip install -r requirements.txt

    # Navigate to notebook location.
    %cd experiments

In [None]:
# Add src module directory to system path for subsequent imports.
# The entry is relative ('../src'), so it resolves against the current working
# directory; it is inserted at index 0 so it is searched before other entries.
import sys
sys.path.insert(0, '../src')

In [None]:
from util import is_notebook

# Settings and module reloading (only in Jupyter Notebooks).
# The IPython magics below are guarded so they only execute inside a notebook.
if is_notebook():
    # Module reloading: load the autoreload extension and reload all modules
    # automatically before code is executed (mode 2).
    %load_ext autoreload
    %autoreload 2

    # Plot settings: render matplotlib figures inline.
    %matplotlib inline

### Imports

In [None]:
# From packages.
import pytorch_lightning as pl
import argparse
from distutils.util import strtobool

# From repository.
from arg_management import *
from benchmark import *
from constants import *
from data import *
from layers import *
from metric_logging import *
from plotting import *
from path_management import *
from tokenizer import *
from transformer import *
from util import *

### Arguments

In [None]:
# Define arguments with argparse.
# ArgumentDefaultsHelpFormatter appends each option's default to its --help text.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# Experiment.
# NOTE(review): all four options use nargs="*", so values given on the command
# line arrive as lists while the defaults below are scalars; consumers have to
# accept both shapes. The first option is spelled '--inferece-methods', so the
# parsed attribute is args.inferece_methods.
parser.add_argument('--inferece-methods', default='greedy', type=str, nargs="*", choices=['greedy', 'beam-search', 'top-k', 'top-p'], help='The inference methods used.')
parser.add_argument('--beam-sizes', default=8, type=int, nargs="*", help='The number of different beam sizes to be used.')
parser.add_argument('--top-ks', default=15, type=int, nargs="*", help='The differnt top-Ks being used.')
# Top-p values are fractions (the default is 0.7), so they are parsed as floats;
# with type=int a command-line value such as "0.7" was rejected by argparse.
parser.add_argument('--top-ps', default=0.7, type=float, nargs="*", help='The differnt top-ps being used.')

# The remaining argument groups are registered by the repository's arg_manager.
# Run.
arg_manager.add_run_args(parser)
# Metrics.
arg_manager.add_metrics_args(parser)
# Data.
arg_manager.add_data_args(parser)
# Tokenization.
arg_manager.add_tokenization_args(parser)
# Architecture.
arg_manager.add_architecture_args(parser)
# Optimizer.
arg_manager.add_optimizer_args(parser)
# Scheduler.
arg_manager.add_scheduler_args(parser)
# Training.
arg_manager.add_training_args(parser)
# Early Stopping + Model Checkpoint.
arg_manager.add_early_stopping_and_checkpoiting_args(parser)

# Parse args.
if is_notebook():
    # argparse reads sys.argv[1:], so replacing sys.argv with a single
    # placeholder entry discards the kernel's own command-line arguments and
    # every option falls back to its default.
    sys.argv = ['-f']  # Used to make argparse work in jupyter notebooks (all args must be optional).
    args, _ = parser.parse_known_args()  # -f can lead to unknown argument.
else:
    args = parser.parse_args()

# Print args.
print('Arguments:')
print(args)

In [None]:
# Auto-infer args.
# The return value is not used, so any inferred values are presumably written
# onto `args` in place (helper is defined outside this notebook; verify there).
auto_infer_args(args, experiment)

In [None]:
# Shrink the run to a tiny smoke test while working interactively.
if is_notebook() and True:  # Flip the literal to True/False to toggle the overrides.
    #args.dev_run = True
    #args.fresh_run = True

    # vars(args) is the namespace's attribute dict, so updating it sets the
    # attributes exactly like individual assignments would.
    vars(args).update(
        max_epochs=5,
        batch_size=1,
        max_examples=2,
        num_val_examples=1,
        num_test_examples=1,
    )

    print('Adjusted args in notebook')

In [None]:
# Sanity check args.
# Runs after the notebook-only overrides above, so the adjusted values are the
# ones being checked (helper is defined outside this notebook).
sanity_check_args(args)

### Seed

In [None]:
# Set seed.
# Seeds the random number generators with the configured seed; workers=True
# asks Lightning to also derive seeds for dataloader worker processes.
from pytorch_lightning import seed_everything
seed_everything(args.seed, workers=True)

### Paths

In [None]:
# Create directories and create file names.
# NOTE(review): `experiment` is passed both positionally and as the keyword
# `experiment=`. If the first positional parameter of ExperimentPathManager is
# itself named `experiment`, this call raises TypeError; confirm the signature.
pm = ExperimentPathManager(experiment, experiment=experiment)
pm.init()

In [None]:
# Save arguments.
# Persists the parsed arguments (as a plain attribute dict) to the args file
# provided by the path manager.
save_dict(pm.args_file, args.__dict__)

In [None]:
from util import is_notebook

# Settings (only in Jupyter Notebooks).
# NOTE(review): this repeats the notebook settings cell that appears earlier in
# this file (same import, same magics); one of the two is likely redundant.
if is_notebook():
    # Module reloading.
    %load_ext autoreload
    # aimport?
    %autoreload 2
    # Plot settings.
    %matplotlib inline

In [None]:
# Imports.
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from datasets import load_metric

from constants import *
from util import *
from transformer import Transformer, cascaded_inference
from tokenizer import load_tokenizer
from data import download_data, load_data
from plotting import plot_metric
from metric_logging import MetricLogger
from benchmark import *

In [None]:
# Set seed.
# NOTE(review): this re-seeds with the constant 0 after the earlier
# seed_everything(args.seed, workers=True) call, so the configured seed no
# longer applies to anything executed below; confirm this is intended.
from pytorch_lightning import seed_everything
seed_everything(0, workers=True)

In [None]:
# Experiment variables.

# List of models to be evaluated.
# The benchmark loop below reads each entry through mc.type ('single' or
# 'cascaded'), mc.langs, mc.paths and mc.name. For 'single' entries mc.langs is
# unpacked as (source, target) and mc.paths is passed directly to torch.load;
# for 'cascaded' entries mc.langs is unpacked as (source, pivot, target) and
# mc.paths is indexed as [source->pivot checkpoint, pivot->target checkpoint].
# NOTE(review): only the cascaded entry passes an explicit fourth argument
# ('cascaded-de-nl-en'); presumably it is the name and the other entries derive
# theirs from the checkpoint path. Verify in ModelConfig.
model_configs = [
    # Baseline and cascaded.
    ModelConfig('single', ('de', 'en'), './models/baseline-de-en.pt'),
    ModelConfig('single', ('de', 'nl'), './models/baseline-de-nl.pt'),
    ModelConfig('single', ('nl', 'en'), './models/baseline-nl-en.pt'),
    ModelConfig('cascaded', ('de', 'nl', 'en'), ['./models/baseline-de-nl.pt', './models/baseline-nl-en.pt'], 'cascaded-de-nl-en'),
    # Baseline (limited de-en).
    ModelConfig('single', ('de', 'en'), './models/baseline-de-en-10000-examples.pt'),
    ModelConfig('single', ('de', 'en'), './models/baseline-de-en-20000-examples.pt'),
    ModelConfig('single', ('de', 'en'), './models/baseline-de-en-50000-examples.pt'),
    ModelConfig('single', ('de', 'en'), './models/baseline-de-en-100000-examples.pt'),
    # Direct pivoting.
    ModelConfig('single', ('de', 'en'), './models/direct-pivoting-de-nl-en.pt'),
    ModelConfig('single', ('de', 'en'), './models/direct-pivoting-de-nl-en-10000-examples.pt'),
    ModelConfig('single', ('de', 'en'), './models/direct-pivoting-de-nl-en-20000-examples.pt'),
    ModelConfig('single', ('de', 'en'), './models/direct-pivoting-de-nl-en-50000-examples.pt'),
    ModelConfig('single', ('de', 'en'), './models/direct-pivoting-de-nl-en-100000-examples.pt'),
    # Step-wise pivoting.
    ModelConfig('single', ('de', 'en'), './models/step-wise-pivoting-de-nl-en.pt'),
    ModelConfig('single', ('de', 'en'), './models/step-wise-pivoting-de-nl-en-10000-examples.pt'),
    ModelConfig('single', ('de', 'en'), './models/step-wise-pivoting-de-nl-en-20000-examples.pt'),
    ModelConfig('single', ('de', 'en'), './models/step-wise-pivoting-de-nl-en-50000-examples.pt'),
    ModelConfig('single', ('de', 'en'), './models/step-wise-pivoting-de-nl-en-100000-examples.pt'),
    ModelConfig('single', ('nl', 'en'), './models/step-wise-pivoting-de-nl-en-step-2.pt'),
    # Reverse step-wise pivoting.
    ModelConfig('single', ('de', 'en'), './models/reverse-step-wise-pivoting-de-nl-en.pt'),
    ModelConfig('single', ('de', 'en'), './models/reverse-step-wise-pivoting-de-nl-en-10000-examples.pt'),
    ModelConfig('single', ('de', 'en'), './models/reverse-step-wise-pivoting-de-nl-en-20000-examples.pt'),
    ModelConfig('single', ('de', 'en'), './models/reverse-step-wise-pivoting-de-nl-en-50000-examples.pt'),
    ModelConfig('single', ('de', 'en'), './models/reverse-step-wise-pivoting-de-nl-en-100000-examples.pt'),
    ModelConfig('single', ('de', 'nl'), './models/reverse-step-wise-pivoting-de-nl-en-step-2.pt'),
]

# List of benchmarks to be applied to models.
# The benchmark loop reads each entry through bc.name, bc.collate_fn (called
# once per benchmark with the data directory) and bc.pp_fn (called per model to
# build the dataset).
benchmark_configs = [
    BenchmarkConfig('flores', flores_collate_fn, flores_pp_fn),
    BenchmarkConfig('tatoeba', tatoeba_collate_fn, tatoeba_pp_fn),
]

In [None]:
# Constant directories (all relative to the current working directory).
data_dir = os.path.join('./', 'data')
tokenizers_dir = os.path.join('./', 'tokenizers')
runs_dir = os.path.join('./', 'runs')

# Experiment directories.
# Each execution gets its own timestamped run directory below runs_dir.
run_dir = os.path.join(runs_dir, f'benchmark-{get_time_as_string()}')

# Parents are listed before children so they are created first.
dirs = [data_dir, tokenizers_dir, runs_dir, run_dir]
# The loop variable is deliberately not called `dir`: that would shadow the
# builtin dir() for every cell executed afterwards in this notebook.
for directory in dirs:
    create_dir(directory)

print('Created directories.')

In [None]:
# Load Metrics.
# The sacrebleu metric is used in the benchmark loop below to score each
# translation against its reference.
# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets`
# releases in favour of the separate `evaluate` package; check the pinned
# version before upgrading.
score_metric = load_metric('sacrebleu')

print('Loaded metrics.')

In [None]:
# Experiment parameters.
# These values (not the --beam-sizes/--top-ks/--top-ps command-line options)
# are the ones the benchmark loop below reads.
# NOTE(review): top_p is 0.6 here while the --top-ps default is 0.7; confirm
# which value is intended.
hparams = dotdict({
    'beam_size': 8,
    'top_k': 15,
    'top_p': 0.6,
    'num_workers': 4,
})

print('Experiment paramereters:')
print(hparams)

In [None]:
# Perform benchmark...
# For every benchmark, every model and every inference method: translate each
# benchmark example, score it with `score_metric` and store the mean score per
# method in the model's results directory.
for bc in benchmark_configs:
    print(f'Performing "{bc.name}" benchmark on {len(model_configs)} models.')

    # Create directories.
    # (No loop variable named `dir`, which would shadow the builtin dir().)
    benchmark_dir = os.path.join(run_dir, bc.name)
    create_dir(benchmark_dir)

    # Download and unpack data.
    bc.collate_fn(data_dir)

    # ... on every model.
    for mc in model_configs:
        print(f'\tBenchmarking {mc.name} ({mc.langs}) on {bc.name}.')

        # Create directories.
        results_dir = os.path.join(benchmark_dir, mc.name)
        create_dir(results_dir)

        # Inference methods and the keyword arguments forwarded to each of them.
        method_kwargs = {
            'greedy': {},
            'beam': {'beam_size': hparams.beam_size},
            'top_k': {'top_k': hparams.top_k},
            'top_p': {'top_p': hparams.top_p},
        }
        # Per-method score accumulator; replaced by [mean score] once a method is done.
        test_results = dict.fromkeys(method_kwargs, 0)
        for method, kwargs in method_kwargs.items():
            print(f'\t\tBenchmarking {mc.name} ({mc.langs}) on {bc.name} with inference method {method}.')

            # 'top_k' and 'top_p' are both passed to the model as method 'sampling'.
            inference_method = 'sampling' if 'top' in method else method

            # NOTE(review): tokenizers, data and checkpoints are reloaded for every
            # inference method although they do not depend on it. Hoisting them out
            # of this loop would save time but changes the order in which the seeded
            # RNG is consumed, so it is left as is here.

            # Perform evaluation base on mc.type.
            if mc.type == 'single':
                src_lang, tgt_lang = mc.langs

                # Load tokenizers.
                src_tokenizer = load_tokenizer(src_lang, tgt_lang)
                tgt_tokenizer = load_tokenizer(tgt_lang, src_lang)
                print('\t\tLoaded tokenizers.')

                # Load data.
                dataset = bc.pp_fn(data_dir, src_lang, tgt_lang, src_tokenizer, tgt_tokenizer)
                print('\t\tLoaded data.')

                # Create dataloader.
                test_dataloader = DataLoader(dataset, batch_size=1, num_workers=hparams.num_workers)
                print('\t\tCreated data loader.')

                # Create model.
                model = Transformer(src_tokenizer,
                                    tgt_tokenizer,
                                    score_metric=score_metric)
                print('\t\tCreated model.')

                # Load model.
                # map_location lets checkpoints saved on a GPU load on CPU-only machines.
                model.load_state_dict(torch.load(mc.paths, map_location=device))
                model.to(device)
                # Modules start in training mode; switch to eval so dropout and
                # similar layers are deterministic while benchmarking.
                model.eval()
                print('\t\tLoaded model.')

                # Testing.
                for batch in test_dataloader:
                    src_input, tgt_input, tgt_output = batch

                    # Convert preprocessed input back to text.
                    # NOTE(review): the reference text is decoded from tgt_input, not
                    # tgt_output; confirm this is the intended reference.
                    src_text = src_tokenizer.Decode(src_input.tolist())[0]
                    label_text = tgt_tokenizer.Decode(tgt_input.tolist())[0]

                    # Pass through model.
                    tgt_text = model.translate(src_text, method=inference_method, kwargs=kwargs)

                    # Calculate metrics.
                    score = score_metric.compute(predictions=[tgt_text], references=[[label_text]])['score']

                    # Accumulate metrics.
                    test_results[method] += score
                    print(f'\t\t{score}')
                # Mean score over the benchmark, stored as a one-element list.
                test_results[method] = [test_results[method] / len(dataset)]

            elif mc.type == 'cascaded':
                src_lang, pvt_lang, tgt_lang = mc.langs

                # Load tokenizers.
                src_tokenizer = load_tokenizer(src_lang, tgt_lang)
                pvt_tokenizer = load_tokenizer(pvt_lang, tgt_lang)
                tgt_tokenizer = load_tokenizer(tgt_lang, src_lang)
                print('\t\tLoaded tokenizers.')

                # Load data.
                dataset = bc.pp_fn(data_dir, src_lang, tgt_lang, src_tokenizer, tgt_tokenizer)
                print('\t\tLoaded data.')

                # Create dataloader.
                test_dataloader = DataLoader(dataset, batch_size=1, num_workers=hparams.num_workers)
                print('\t\tCreated data loader.')

                # Create model.
                src_pvt_model = Transformer(src_tokenizer,
                                            pvt_tokenizer,
                                            score_metric=score_metric)
                pvt_tgt_model = Transformer(pvt_tokenizer,
                                            tgt_tokenizer,
                                            score_metric=score_metric)
                print('\t\tCreated models.')

                # Load model.
                # map_location lets checkpoints saved on a GPU load on CPU-only machines.
                src_pvt_model.load_state_dict(torch.load(mc.paths[0], map_location=device))
                pvt_tgt_model.load_state_dict(torch.load(mc.paths[1], map_location=device))
                src_pvt_model.to(device)
                pvt_tgt_model.to(device)
                # Evaluation mode for both stages (see the single-model branch).
                src_pvt_model.eval()
                pvt_tgt_model.eval()
                print('\t\tLoaded models.')

                # Testing.
                for batch in test_dataloader:
                    # Cascaded inference.
                    score, src_text, pvt_text, tgt_text, label_text = cascaded_inference(batch,
                                                                                         src_tokenizer, tgt_tokenizer,
                                                                                         src_pvt_model, pvt_tgt_model,
                                                                                         score_metric,
                                                                                         method=inference_method,
                                                                                         kwargs=kwargs)
                    # Accumulate metrics.
                    test_results[method] += score
                    print(f'\t\t{score}')
                # Mean score over the benchmark, stored as a one-element list.
                test_results[method] = [test_results[method] / len(dataset)]

            else:
                raise ValueError(f'Unknown model_config.type: {mc.type}')

        # Save recorded metrics.
        save_dict(results_dir, test_results, 'metrics')