# Evaluate Models on Benchmarks

In [None]:
# Static experiment settings.
# Name of this experiment; it is handed to auto_infer_args() together with the parsed args further below.
experiment = 'benchmark'

## Setup

### Environment

In [None]:
# If this is a notebook which is executed in colab [in_colab=True]:
#  1. Mount google drive and use the repository in there [mount_drive=True] (the repository must be in your google drive root folder).
#  2. Clone repository to remote machine [mount_drive=False].
in_colab = False
mount_drive = True

try:
    # Check if running in colab.
    in_colab = 'google.colab' in str(get_ipython())
except:
    pass

if in_colab:
    if mount_drive:
        # Mount google drive and navigate to it.
        from google.colab import drive
        drive.mount('/content/drive')
        %cd drive/MyDrive
    else:
        # Pull repository.
        !git clone https://github.com/HenningBuhl/low-resource-machine-translation

    # Workaround for problem with undefined symbols (https://github.com/scverse/scvi-tools/issues/1464).
    !pip install --quiet scvi-colab
    from scvi_colab import install
    install()

    # Navigate to the repository and install requirements.
    %cd low-resource-machine-translation
    !pip install -r requirements.txt

    # Navigate to notebook location.
    %cd experiments

In [None]:
# Add src module directory to system path for subsequent imports.
# Inserted at index 0 so that '../src' is searched before every other entry of sys.path.
import sys
sys.path.insert(0, '../src')

In [None]:
# NOTE(review): util is presumably the repository module under '../src' (added to sys.path above) — confirm.
from util import is_notebook

# Settings and module reloading (only in Jupyter Notebooks).
# The magics below are IPython syntax, so they are guarded by is_notebook().
if is_notebook():
    # Module reloading.
    # Mode 2 of the autoreload extension reloads all modules before user code is executed.
    %load_ext autoreload
    %autoreload 2

    # Plot settings.
    # Show matplotlib figures inline in the notebook output.
    %matplotlib inline

### Imports

In [None]:
# From packages.
import pytorch_lightning as pl
import argparse
from distutils.util import strtobool

# From repository.
from arg_management import *
from constants import *
from data import *
from layers import *
from metric_logging import *
from plotting import *
from path_management import *
from tokenizer import *
from transformer import *
from util import *

### Arguments

In [None]:
# Define arguments with argparse.
# ArgumentDefaultsHelpFormatter appends each option's default value to its --help text.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# Experiment.
# Every option declared with nargs="*" parses to a list, so its default is a list as well;
# otherwise code iterating over the value would see a scalar (or the characters of a string)
# whenever the option is not given on the command line.
parser.add_argument('--seed', default=0, type=int, help='The random seed of the program.')
# '--inferece-methods' is the old, misspelled flag; it is kept as an alias so existing command lines keep working.
# The value is stored as args.inference_methods, which is the name the benchmark loop reads.
parser.add_argument('--inference-methods', '--inferece-methods', dest='inference_methods', default=['greedy'], type=str, nargs="*", choices=['greedy', 'beam-search', 'top-k', 'top-p'], help='The inference methods used.')
parser.add_argument('--beam-size', default=[8], type=int, nargs="*", help='The beam sizes to be used.')
parser.add_argument('--top-k', default=[15], type=int, nargs="*", help='The differnt top-Ks being used.')
# top-p is a fraction (default 0.7), so it has to be parsed as float; type=int rejects values such as '0.9'.
parser.add_argument('--top-p', default=[0.7], type=float, nargs="*", help='The differnt top-ps being used.')

# Metrics.
# NOTE(review): arg_manager is not bound by name in this notebook; presumably exported by `from arg_management import *` — confirm.
arg_manager.add_metrics_args(parser)

# Parse args.
if is_notebook():
    sys.argv = ['-f']  # Used to make argparse work in jupyter notebooks (all args must be optional).
    args, _ = parser.parse_known_args()  # -f can lead to unknown argument.
else:
    args = parser.parse_args()

# Print args.
print('Arguments:')
print(args)

In [None]:
# Auto-infer args.
# Called with the parsed args and the experiment name defined at the top of the notebook.
# NOTE(review): auto_infer_args is defined outside this notebook; its return value is not used here,
# so it presumably updates args in place — confirm.
auto_infer_args(args, experiment)

In [None]:
# Adjust arguments for test purposes.
# Only runs inside a notebook; change the literal True to False to skip this cell's body
# without deleting the overrides.
if is_notebook() and True:  # Quickly turn on and off with 'and True/False'.
    # Uncomment individual lines to override the corresponding argument for a manual test run.
    #args.dev_run = True
    #args.fresh_run = True

    print('Adjusted args in notebook')

In [None]:
# Sanity check args.
# NOTE(review): sanity_check_args is defined outside this notebook; presumably it raises on
# invalid argument combinations — confirm.
sanity_check_args(args)

### Seed

In [None]:
# Set seed.
# Seeds with the --seed argument (default 0).
# NOTE(review): workers=True is forwarded to pytorch_lightning; presumably it also seeds
# dataloader worker processes — confirm against the Lightning documentation.
from pytorch_lightning import seed_everything
seed_everything(args.seed, workers=True)

### Paths

In [None]:
# Read and create directories and files.
# NOTE(review): BenchmarkManager is not defined in this notebook; presumably it is provided by
# one of the star imports from the repository — confirm.
# bm.args_file is read further below when the arguments are saved.
bm = BenchmarkManager()
bm.init()

In [None]:
# Save arguments.
# Hand the parsed namespace to save_dict as a plain dict, together with the
# args file path of the benchmark manager.
save_dict(bm.args_file, vars(args))

## Benchmark

In [None]:
# Create run dir.
# Each execution gets its own directory below CONST_RUNS_DIR, named 'benchmark-<time string>'.
run_dir = os.path.join(CONST_RUNS_DIR, f'benchmark-{get_time_as_string()}')
create_dir(run_dir)

# Which metrics to record.
# Still empty: no metric is tracked until this list is filled from the parsed arguments.
track_metrics = []  # TODO from args.

In [None]:
# Skeleton of the benchmark loop: benchmark -> model -> inference method -> method parameter value.
# Data loading, model loading, inference and metric calculation are still TODO.


def _as_list(value):
    """Return value unchanged if it is a list or tuple, otherwise wrap it in a one-element list."""
    return list(value) if isinstance(value, (list, tuple)) else [value]


# Iterate over benchmarks.
for benchmark_name in get_dirs(CONST_BENCHMARKS_DIR):
    print(f'Benchmark: {benchmark_name}')

    # Create benchmark result dir.
    benchmark_dir = os.path.join(run_dir, benchmark_name)
    create_dir(benchmark_dir)

    # Load benchmark data.
    # TODO (with BenchmarkDataPreProcessor [automatically detect /de,/en or /de-en])

    # Iterate over models.
    for model_name in get_dirs(CONST_MODELS_DIR):
        print(f'Model: {model_name}')

        # Create model result dir.
        # It lives inside the benchmark dir so that results of the same model on
        # different benchmarks do not end up in (and overwrite) the same directory.
        model_dir = os.path.join(benchmark_dir, model_name)
        create_dir(model_dir)

        # Load model args.
        # Stored under its own name: the notebook-level `args` namespace is still
        # needed below for the inference settings and must not be overwritten.
        model_args = load_dict(os.path.join(CONST_MODELS_DIR, model_name, 'args.json'))

        # Load tokenizers and model(s).
        experiment_type = model_args['experiment']
        if experiment_type == 'cascaded':
            # Load tokenizers.
            # TODO

            # Load model(s).
            # TODO

            # Create function that translates input text and returns it (experiment_type agnostic for further code below).
            # TODO
            pass
        else:
            # TODO: same three steps for all other experiment types.
            pass

        # Parameter values to sweep for each inference method.
        # Greedy decoding has no parameter, hence a single None entry.
        method_values = {
            'greedy': [None],
            'beam-search': _as_list(args.beam_size),
            'top-k': _as_list(args.top_k),
            'top-p': _as_list(args.top_p),
        }

        # Iterate over inference methods.
        for method_name in _as_list(args.inference_methods):
            print(f'Method: {method_name}')

            # Iterate over inference method params.
            for value in method_values[method_name]:
                #print(f'Value: {value}')

                # Perform inference.
                # TODO

                # Calculate metrics.
                metrics = {}
                # TODO

                # Save metrics to separate file each.
                # TODO
                # for metric in track_metrics:
                #     calculate the metric and store it in metrics[metric]
                #     write metrics[metric] to <model_dir>/<method_name>/[<value>/]<metric>.json
                #     (the <value>/ level is omitted when value is None)

In [None]:
# Perform benchmark...
# For every benchmark config and every model config, each inference method is evaluated on the
# benchmark data and the mean score per method is saved to the model's results directory.
# NOTE(review): benchmark_configs, model_configs, hparams, data_dir, score_metric, device,
# DataLoader, torch, load_tokenizer and cascaded_inference are not defined in this notebook —
# confirm they are in scope before running this cell.
for bc in benchmark_configs:
    print(f'Performing "{bc.name}" benchmark on {len(model_configs)} models.')

    # Create directories.
    benchmark_dir = os.path.join(run_dir, bc.name)
    create_dir(benchmark_dir)

    # Download and unpack data.
    bc.collate_fn(data_dir)

    # ... on every model.
    for mc in model_configs:
        print(f'\tBenchmarking {mc.name} ({mc.langs}) on {bc.name}.')

        # Create directories.
        results_dir = os.path.join(benchmark_dir, mc.name)
        create_dir(results_dir)

        # Keyword arguments handed to the model for each inference method.
        method_kwargs = {
            'greedy': {},
            'beam': {'beam_size': hparams.beam_size},
            'top_k': {'top_k': hparams.top_k},
            'top_p': {'top_p': hparams.top_p},
        }
        # Running score sum per method; replaced by [mean score] once the method is finished.
        test_results = {k: 0 for k in method_kwargs}
        for method, kwargs in method_kwargs.items():
            print(f'\t\tBenchmarking {mc.name} ({mc.langs}) on {bc.name} with inference method {method}.')

            # 'top_k' and 'top_p' are both passed to the model as method 'sampling';
            # their kwargs tell them apart.
            decode_method = 'sampling' if 'top' in method else method

            # Perform evaluation base on mc.type.
            if mc.type == 'single':
                src_lang, tgt_lang = mc.langs

                # Load tokenizers.
                src_tokenizer = load_tokenizer(src_lang, tgt_lang)
                tgt_tokenizer = load_tokenizer(tgt_lang, src_lang)
                print('\t\tLoaded tokenizers.')

                # Load data.
                dataset = bc.pp_fn(data_dir, src_lang, tgt_lang, src_tokenizer, tgt_tokenizer)
                print('\t\tLoaded data.')

                # Create dataloader.
                test_dataloader = DataLoader(dataset, batch_size=1, num_workers=hparams.num_workers)
                print('\t\tCreated data loader.')

                # Create model.
                model = Transformer(src_tokenizer,
                        tgt_tokenizer,
                        score_metric=score_metric)
                print('\t\tCreated model.')

                # Load model.
                # map_location moves the stored tensors onto the target device while loading, so a
                # checkpoint saved on a GPU can also be loaded on a machine without one.
                model.load_state_dict(torch.load(mc.paths, map_location=device))
                model.to(device)
                print('\t\tLoaded model.')

                # Testing.
                for batch in test_dataloader:
                    # The third batch element (target output) is not needed here.
                    src_input, tgt_input, _ = batch

                    # Convert preprocessed input back to text.
                    # batch_size is 1, so index 0 is the only sample of the batch.
                    src_text = src_tokenizer.Decode(src_input.tolist())[0]
                    label_text = tgt_tokenizer.Decode(tgt_input.tolist())[0]

                    # Pass through model.
                    tgt_text = model.translate(src_text, method=decode_method, kwargs=kwargs)

                    # Calculate metrics.
                    score = score_metric.compute(predictions=[tgt_text], references=[[label_text]])['score']

                    # Accumulate metrics.
                    test_results[method] += score
                    print(f'\t\t{score}')
                # Mean score over the dataset, wrapped in a list.
                test_results[method] = [test_results[method] / len(dataset)]

            elif mc.type == 'cascaded':
                src_lang, pvt_lang, tgt_lang = mc.langs

                # Load tokenizers.
                src_tokenizer = load_tokenizer(src_lang, tgt_lang)
                pvt_tokenizer = load_tokenizer(pvt_lang, tgt_lang)
                tgt_tokenizer = load_tokenizer(tgt_lang, src_lang)
                print('\t\tLoaded tokenizers.')

                # Load data.
                dataset = bc.pp_fn(data_dir, src_lang, tgt_lang, src_tokenizer, tgt_tokenizer)
                print('\t\tLoaded data.')

                # Create dataloader.
                test_dataloader = DataLoader(dataset, batch_size=1, num_workers=hparams.num_workers)
                print('\t\tCreated data loader.')

                # Create model.
                src_pvt_model = Transformer(src_tokenizer,
                                            pvt_tokenizer,
                                            score_metric=score_metric)
                pvt_tgt_model = Transformer(pvt_tokenizer,
                                            tgt_tokenizer,
                                            score_metric=score_metric)
                print('\t\tCreated models.')

                # Load model.
                # mc.paths holds two checkpoints here: source->pivot first, pivot->target second.
                # map_location: see the comment in the 'single' branch.
                src_pvt_model.load_state_dict(torch.load(mc.paths[0], map_location=device))
                pvt_tgt_model.load_state_dict(torch.load(mc.paths[1], map_location=device))
                src_pvt_model.to(device)
                pvt_tgt_model.to(device)
                print('\t\tLoaded models.')

                # Testing.
                for batch in test_dataloader:
                    # Cascaded inference.
                    # Only the score is used below; the returned texts are kept for inspection.
                    score, src_text, pvt_text, tgt_text, label_text = cascaded_inference(batch,
                                                                                         src_tokenizer, tgt_tokenizer,
                                                                                         src_pvt_model, pvt_tgt_model,
                                                                                         score_metric,
                                                                                         method=decode_method,
                                                                                         kwargs=kwargs)
                    # Accumulate metrics.
                    test_results[method] += score
                    print(f'\t\t{score}')
                # Mean score over the dataset, wrapped in a list.
                test_results[method] = [test_results[method] / len(dataset)]

            else:
                raise ValueError(f'Unknown model_config.type: {mc.type}')

        # Save recorded metrics.
        save_dict(results_dir, test_results, 'metrics')