## Imports and setup

In [None]:
from IPython import display
!pip install transformers
!pip install sacrebleu
!pip install sacremoses
!pip install datasets
!pip install wandb
!pip install sentencepiece
display.clear_output()

In [None]:
import datasets
from IPython import display
import numpy as np
import os
import pandas as pd
import random
import sentencepiece
import sacrebleu
import sacremoses
import tqdm
import transformers
import torch
import wandb

In [None]:
# Check that PyTorch can see a CUDA GPU (fp16 training below is only enabled if so).
cuda_available = torch.cuda.is_available()
cuda_available

## Configuration

Alternative pre-trained checkpoints to start from when translating to English: `Helsinki-NLP/opus-mt-lg-en`, `Helsinki-NLP/opus-mt-mul-en`.

Note 1: V100 GPUs have more memory, so when training on them `train_batch_size` can be increased (to 64?). If it is, decrease `gradient_accumulation_steps` accordingly so that the effective batch size (`train_batch_size` × `gradient_accumulation_steps`) stays the same.

Note 2: there is little difference in BLEU score when using a test set of 500 vs 1000 sentences per language. For rapid parameter tuning, we can therefore use `config['validation_samples_per_language'] = 500`, and then set it to 1000 for the best model config to report numbers in the paper.

In [None]:
# Parameters for mul-en models.
config = {
    # Translation direction; the language codes also key each 'translation' dict.
    'source_language': 'mul',
    'target_language': 'en',
    # Metric used to pick the checkpoint restored at the end of training.
    'metric_for_best_model': 'loss',
    # Effective batch size = train_batch_size * gradient_accumulation_steps.
    'train_batch_size': 20,
    'gradient_accumulation_steps': 150,
    # Token-length limits used when tokenising (longer sequences are truncated).
    'max_input_length': 128,
    'max_target_length': 128,
    # Evaluation: this many sentences from each language's test file.
    'validation_samples_per_language': 500,
    'eval_batch_size': 16,
    'eval_languages': ["ach", "lgg", "lug", "nyn", "teo"],
    'eval_pretrained_model': False,
    # Optimisation.
    'learning_rate': 1e-4,
    'num_train_epochs': 10,
    'label_smoothing_factor': 0.1,
    # Which optional training corpora to include.
    'flores101_training_data': True,
    'mt560_training_data': True,
    'back_translation_training_data': True,
    'named_entities_training_data': True,
}

config['language_pair'] = '-'.join(
    (config['source_language'], config['target_language']))
config.update(
    wandb_project=f'sunbird-translate-{config["language_pair"]}',
    model_checkpoint=f'Helsinki-NLP/opus-mt-{config["language_pair"]}',
    # What training data to use.
    data_dir=f'v7-dataset/v7.0/supervised/{config["language_pair"]}/',
)

# Aim to evaluate at a roughly fixed wall-clock interval. The budget
# 350 * 60 * 7 is presumably ~350 training examples per second over 7 minutes
# of training (the cell originally described this as "roughly every 10
# minutes") — NOTE(review): confirm the throughput assumption.
examples_between_evals = 350 * 60 * 7
examples_per_optimizer_step = (config['gradient_accumulation_steps']
                               * config['train_batch_size'])
eval_steps_interval = examples_between_evals / examples_per_optimizer_step

# Round down to a multiple of 10 optimizer steps, but never below 10.
eval_steps_interval = 10 * max(1, int(eval_steps_interval / 10))

print(f'Evaluating every {eval_steps_interval} training steps.')

# Hugging Face training arguments; checkpoints and logs share one directory per
# language pair, and metrics are reported to Weights & Biases.
config['train_settings'] = transformers.Seq2SeqTrainingArguments(
    output_dir=f'output-{config["language_pair"]}',
    logging_dir=f'output-{config["language_pair"]}',
    run_name=f'{config["source_language"]}-{config["target_language"]}',
    report_to='wandb',
    # Evaluate and checkpoint on the same step schedule so that the best
    # checkpoint (by metric_for_best_model) can be restored at the end.
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy`; adjust if the installed version rejects it.
    evaluation_strategy='steps',
    eval_steps=eval_steps_interval,
    save_steps=eval_steps_interval,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model=config['metric_for_best_model'],
    # Optimisation.
    learning_rate=config['learning_rate'],
    weight_decay=0.01,
    num_train_epochs=config['num_train_epochs'],
    per_device_train_batch_size=config['train_batch_size'],
    gradient_accumulation_steps=config['gradient_accumulation_steps'],
    label_smoothing_factor=config['label_smoothing_factor'],
    # Mixed precision only when a CUDA GPU is available.
    fp16=torch.cuda.is_available(),
    # Evaluation decodes with generate() so that BLEU can be computed.
    per_device_eval_batch_size=config['eval_batch_size'],
    predict_with_generate=True,
)

MT560 is much bigger than the other training sets, so oversample the rest (by 5x) to balance it out.

In [None]:
# Subsets of non-religious text. These are over-sampled below so that they are
# not swamped by the much larger MT560 data.
config['training_subset_ids'] = [
    'train', 'train_ai4d',
    'val_ach', 'val_lgg', 'val_lug', 'val_nyn', 'val_teo',
]

if config['flores101_training_data']:
    config['training_subset_ids'].append('train_flores_lug')

if config['back_translation_training_data']:
    config['training_subset_ids'].append('back_translated')

# Over-sample the non-religious training text by repeating each subset ID.
# The factor defaults to 5 and can be overridden via config['oversample_factor'].
config['training_subset_ids'] = (
    config['training_subset_ids'] * config.get('oversample_factor', 5))

if config['mt560_training_data']:
    # MT560 subsets are listed once each, i.e. not over-sampled.
    config['training_subset_ids'].extend([
        'train_mt560_lug', 'train_mt560_ach', 'train_mt560_nyn',
    ])

if config['named_entities_training_data']:
    config['training_subset_ids'].append('named_entities')


# Set up datasets

Download the raw text data.

In [None]:
# Download and unpack the raw parallel text unless it is already present.
# Python's urllib/zipfile are used instead of `!wget`/`!unzip` for three reasons:
# - A failed download raises a visible error. Previously the shell output was
#   wiped by display.clear_output().
# - A stale v7-dataset.zip left by an interrupted run is overwritten. wget would
#   save the new copy as v7-dataset.zip.1, and unzip would reuse the old file.
# - The cell does not depend on those command-line tools being installed.
if not os.path.exists('v7-dataset'):
    import urllib.request
    import zipfile

    urllib.request.urlretrieve(
        'https://sunbird-translate.s3.us-east-2.amazonaws.com/v7-dataset.zip',
        'v7-dataset.zip')
    with zipfile.ZipFile('v7-dataset.zip') as archive:
        archive.extractall()

Create a training set by interleaving separate training subsets.

Notes:
* This includes MT560 which has many examples (484,925), but which is biased towards religious text so we sample from it sparsely.
* We just use a 2-way train/test split for this experiment, so include the validation sentences in with the training set.
* LGG, ACH and TEO are oversampled a little by duplicating the validation sets, as a simple way to correct for there being more LUG and NYN training data.

In [None]:
def _file_to_list(path):
    with open(path) as file:
        lines = file.readlines()
        lines = [line.rstrip() for line in lines]
        return lines
    
def dataset_from_src_tgt_files(data_dir, dataset_id, read_first_n=0):
    '''Build a `datasets.Dataset` of translation pairs from parallel text files.

    Reads `<data_dir>/<dataset_id>.src` and `<data_dir>/<dataset_id>.tgt`.
    Line i of each file is treated as a translation pair.

    Args:
      data_dir: directory containing the .src/.tgt files.
      dataset_id: file stem shared by the .src/.tgt pair.
      read_first_n: if non-zero, keep only the first `read_first_n` pairs;
        0 (the default) keeps all of them.

    Returns:
      Dataset with a single 'translation' column whose rows are dicts keyed by
      config['source_language'] and config['target_language'].
    '''
    import warnings

    path = os.path.join(data_dir, dataset_id)
    source = _file_to_list(path + '.src')
    target = _file_to_list(path + '.tgt')
    if len(source) != len(target):
        # zip() below silently drops the unmatched tail; a length mismatch
        # usually means the two files are out of sync, so make it visible.
        warnings.warn(
            f'{path}: {len(source)} source lines vs {len(target)} target '
            'lines; unmatched lines will be ignored.')
    if read_first_n:
        source = source[:read_first_n]
        target = target[:read_first_n]
    pairs = {'translation': [{config['source_language']: s,
                              config['target_language']: t}
                             for s, t in zip(source, target)]}
    return datasets.Dataset.from_dict(pairs)

In [None]:
# Load each distinct subset once. An ID that appears k times in
# config['training_subset_ids'] (the over-sampling mechanism) contributes k
# copies of that subset.
_loaded_subsets = {}
for subset_id in config['training_subset_ids']:
    if subset_id not in _loaded_subsets:
        _loaded_subsets[subset_id] = dataset_from_src_tgt_files(
            config['data_dir'], subset_id)
training_subsets = [_loaded_subsets[subset_id]
                    for subset_id in config['training_subset_ids']]

# Merge all subsets into one shuffled training set. This gives the same
# uniform mix as sampling each subset in proportion to its size. It avoids
# datasets.interleave_datasets' default "first_exhausted" stopping strategy,
# which ends the stream as soon as any one subset runs out and silently drops
# the remainder of the others. Shuffling with the trainer's seed keeps the
# data order reproducible.
train_data_raw = datasets.concatenate_datasets(training_subsets).shuffle(
    seed=config['train_settings'].seed)

Make a separate validation set for each evaluation language, then concatenate them.

In [None]:
# One validation subset per evaluation language, each holding the first
# `validation_samples_per_language` sentences of its test file, concatenated in
# `eval_languages` order. compute_metrics() splits the predictions back into
# equal-sized consecutive chunks per language. Every subset must therefore
# contain exactly that many sentences; otherwise BLEU scores would be
# attributed to the wrong languages.
validation_subsets = [
    dataset_from_src_tgt_files(
        config['data_dir'], f'test_{lang}',
        read_first_n=config['validation_samples_per_language'])
    for lang in config['eval_languages']]

for lang, subset in zip(config['eval_languages'], validation_subsets):
    if len(subset) != config['validation_samples_per_language']:
        raise ValueError(
            f'test_{lang} has {len(subset)} sentences, but '
            f"validation_samples_per_language = {config['validation_samples_per_language']}.")

validation_data_raw = datasets.concatenate_datasets(validation_subsets)

## Helper functions

Note that whatever pre-processing we do here (punctuation normalisation and ensuring sentence case), we should also do at test-time when running the model on real queries.

In [None]:
def sentence_format(input):
    '''Ensure capital letter at the start and full stop at the end.

    The first character is capitalised. Unless the text already ends in '.',
    '!' or '?', a full stop is appended. Empty strings (e.g. blank lines in a
    corpus file) are returned unchanged instead of raising IndexError.

    Args:
      input: sentence text.

    Returns:
      The formatted sentence.
    '''
    if not input:
        return input
    input = input[0].capitalize() + input[1:]
    if input[-1] not in ['.', '!', '?']:
        input = input + '.'
    return input

def preprocess(examples):
    '''Normalise, sentence-format and tokenise a batch of translation pairs.

    Written for `Dataset.map(..., batched=True)`. `examples['translation']`
    is a list of dicts keyed by config['source_language'] and
    config['target_language']. Each side is punctuation-normalised with
    sacremoses and passed through sentence_format(). The result is tokenised
    with truncation to the configured maximum lengths.

    Returns:
      The tokenizer output for the source sentences, with an added "labels"
      entry holding the target token ids.
    '''
    punct_normalizer = sacremoses.MosesPunctNormalizer()

    def clean(text):
        return sentence_format(punct_normalizer.normalize(text))

    pairs = examples['translation']
    sources = [clean(pair[config['source_language']]) for pair in pairs]
    targets = [clean(pair[config['target_language']]) for pair in pairs]

    encoded = tokenizer(
        sources, max_length=config['max_input_length'], truncation=True)

    # Targets are tokenised inside the tokenizer's target-side context.
    with tokenizer.as_target_tokenizer():
        target_ids = tokenizer(
            targets, max_length=config['max_target_length'],
            truncation=True)["input_ids"]

    encoded["labels"] = target_ids
    return encoded

def postprocess(preds, labels):
    '''Strip surrounding whitespace from decoded predictions and labels.

    Each label is wrapped in a one-element list, giving one list of reference
    translations per prediction.

    Returns:
      (stripped_predictions, wrapped_labels)
    '''
    stripped_preds = list(map(str.strip, preds))
    wrapped_labels = [[text.strip()] for text in labels]
    return stripped_preds, wrapped_labels

def compute_metrics(eval_preds, eval_languages, samples_per_language):
    '''Compute a BLEU score for each evaluation language plus their mean.

    Args:
      eval_preds: (predictions, label_ids) pair from the trainer. With
        `predict_with_generate` the predictions are generated token ids.
        If the predictions arrive as a tuple, its first element is used.
      eval_languages: languages in the order their examples appear in the
        evaluation set.
      samples_per_language: number of consecutive evaluation examples that
        belong to each language.

    Returns:
      Dict with 'BLEU_<lang>' for each language and 'BLEU_mean', each value
      rounded to 4 decimal places.

    Raises:
      ValueError: if the number of predictions differs from
        len(eval_languages) * samples_per_language. The per-language slices
        would otherwise silently mix up languages.
    '''
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is not a valid token id, but both arrays can contain it:
    # - The trainer pads generated predictions of different lengths with -100
    #   when gathering them across batches.
    # - Labels are padded with -100.
    # Replace it with the pad token, which skip_special_tokens drops, before
    # decoding.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess(decoded_preds, decoded_labels)

    expected = len(eval_languages) * samples_per_language
    if len(decoded_preds) != expected:
        raise ValueError(
            f'Expected {expected} predictions ({len(eval_languages)} languages '
            f'x {samples_per_language} each), got {len(decoded_preds)}.')

    result = {}
    for i, lang in enumerate(eval_languages):
        chunk = slice(i * samples_per_language, (i + 1) * samples_per_language)
        result_subset = metric.compute(
            predictions=decoded_preds[chunk],
            references=decoded_labels[chunk])
        result[f"BLEU_{lang}"] = result_subset["score"]

    result["BLEU_mean"] = np.mean([result[f"BLEU_{lang}"] for lang in eval_languages])

    return {k: round(v, 4) for k, v in result.items()}

# Training

Instantiate the model and tokenizer.

In [None]:
# Load the checkpoint chosen in the configuration cell together with its tokenizer.
tokenizer = transformers.AutoTokenizer.from_pretrained(config['model_checkpoint'])
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(config['model_checkpoint'])

# Pads each batch dynamically. Labels are padded with the collator's
# label_pad_token_id (-100 by default), which compute_metrics masks out
# before decoding.
data_collator = transformers.DataCollatorForSeq2Seq(tokenizer, model=model)

# NOTE(review): `datasets.load_metric` is deprecated in favour of the
# `evaluate` library and is absent from recent `datasets` releases.
metric = datasets.load_metric('sacrebleu')

For multiple language outputs, we need to make sure the language codes have some mapping in the encoder. We can re-use the token indices of some other language codes in the pre-trained model that we don't need.

In `Helsinki-NLP/opus-mt-en-mul`, only Luganda (`lug`) is already supported.

In [None]:
if config['target_language'] == 'mul':
    # New language code -> pre-trained code whose token id it takes over.
    replacements = {'nyn': 'kin',
                    'lgg': 'lin',
                    'ach': 'tso',
                    'teo': 'som',
                    'luo': 'sna'}
    for new_code, donor_code in replacements.items():
        new_token = f'>>{new_code}<<'
        donor_token = f'>>{donor_code}<<'
        # Skip codes that already exist, or whose donor code is unavailable.
        if new_token in tokenizer.encoder or donor_token not in tokenizer.encoder:
            continue
        # Move the donor's token id to the new code; the donor entry is removed.
        # NOTE(review): only `tokenizer.encoder` is edited here; any reverse
        # id->token mapping the tokenizer keeps is left untouched.
        tokenizer.encoder[new_token] = tokenizer.encoder.pop(donor_token)

    # Check that all the evaluation language codes are mapped to something.
    unmapped = [code for code in config['eval_languages']
                if f'>>{code}<<' not in tokenizer.encoder]
    if unmapped:
        raise ValueError(f'Language code {unmapped[0]} not found in the encoder.')

Pre-process the raw text datasets.

In [None]:
# Tokenise both splits in batches. The raw 'translation' column is dropped;
# only the encodings produced by preprocess() are kept.
_map_kwargs = dict(remove_columns=["translation"], batched=True)
train_data = train_data_raw.map(preprocess, **_map_kwargs)
validation_data = validation_data_raw.map(preprocess, **_map_kwargs)

Launch the training.

In [None]:
# Start a Weights & Biases run that records the full configuration.
wandb.init(project=config['wandb_project'], config=config)


def _compute_eval_metrics(eval_preds):
    '''Evaluate with the per-language layout of the validation set from config.'''
    return compute_metrics(
        eval_preds,
        config['eval_languages'],
        config['validation_samples_per_language'])


trainer = transformers.Seq2SeqTrainer(
    model,
    config['train_settings'],
    train_dataset=train_data,
    eval_dataset=validation_data,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=_compute_eval_metrics,
    # Stop once the best-model metric has not improved for 5 evaluations.
    callbacks=[transformers.EarlyStoppingCallback(early_stopping_patience=5)],
)

In [None]:
# Optionally evaluate the pre-trained checkpoint before fine-tuning, as a baseline.
baseline_metrics = trainer.evaluate() if config['eval_pretrained_model'] else None

In [None]:
# Fine-tune; the returned training summary is shown as the cell's output.
train_result = trainer.train()
train_result