## Training notebook for the SALT dataset using a pretrained mBART50 model. 

### Training Steps:
* Loading libraries
* Loading dataset files
* Preprocessing, tokenizing and adding source language tokens 
* Loading model 
* Training
* Saving Results

In [None]:
from IPython import display
# Install the third-party packages that the import cell below relies on.
!pip install transformers
!pip install sacrebleu
!pip install sacremoses
!pip install datasets
!pip install wandb
!pip install sentencepiece
# Remove the verbose pip output from the cell once every install has run.
display.clear_output()

In [None]:
import datasets
from IPython import display
import numpy as np
import os
import pandas as pd
import random
import sentencepiece
import sacrebleu
import sacremoses
import tqdm
import transformers
import torch
import wandb

from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, Seq2SeqTrainer

In [None]:
# Work from a scratch directory: every relative path used later in the
# notebook (dataset download, trainer output) is resolved against it.
# exist_ok=True keeps this cell re-runnable; a plain os.mkdir raises
# FileExistsError the second time the cell is executed.
os.makedirs("/kaggle/temp", exist_ok=True)
os.chdir("/kaggle/temp")

Storing hyperparameters and arguments in a config dictionary is useful because the whole dictionary can be passed to wandb alongside the performance metrics, which makes runs easy to replicate.


In [None]:
# Hyperparameters and run switches for the mul-en (many languages -> English)
# model. The whole dictionary is later handed to wandb.init as the run config.
config = dict(
    source_languages=["ach", "lgg", "lug", "nyn", "teo"],
    target_languages=['en'],
    metric_for_best_model='loss',
    train_batch_size=1,
    gradient_accumulation_steps=2400,
    max_input_length=128,
    max_target_length=128,
    validation_samples_per_language=500,
    validation_train_merge=True,
    eval_batch_size=1,
    eval_languages=["ach", "lgg", "lug", "nyn", "teo"],
    eval_pretrained_model=False,
    learning_rate=1e-4,
    num_train_epochs=2,
    label_smoothing_factor=0.1,
    flores101_training_data=True,
    mt560_training_data=True,
    back_translation_training_data=False,
    front_translation_training_data=False,  # not implemented
    named_entities_training_data=False,
    recycle_language_tokens=True,
)

# Run identification, the pretrained checkpoint, and the directory holding
# the extra (non-SALT) training data.
config.update(
    language_pair='salt-en',
    wandb_project='salt-mbart',
    wandb_entity='sunbird',
    model_checkpoint='facebook/mbart-large-50',
    training_extra_data_dir='v7-dataset/v7.0/supervised/mul-en/',
)

# Evaluate roughly every 10 minutes.
# NOTE(review): the constants 350, 60 and 7 are not explained anywhere in the
# notebook - confirm what throughput/time figures they stand for.
effective_batch_size = (config['gradient_accumulation_steps']
                        * config['train_batch_size'])
raw_eval_interval = 350 * 60 * 7 / effective_batch_size

# Use at least one tenth-step, then scale by four.
eval_steps_interval = 4 * max(1, int(raw_eval_interval / 10))

print(f'Evaluating every {eval_steps_interval} training steps.')

# Trainer arguments. Checkpoints and logs share one output directory named
# after the language pair; evaluation and checkpointing happen on the same
# step interval.
run_output_dir = f'output-{config["language_pair"]}'

config['train_settings'] = transformers.Seq2SeqTrainingArguments(
    run_output_dir,
    logging_dir = run_output_dir,
    run_name = f'{config["language_pair"]}',
    report_to = 'wandb',
    # Schedule: evaluate and save together, keep only a few checkpoints.
    evaluation_strategy = 'steps',
    eval_steps = eval_steps_interval,
    save_steps = eval_steps_interval,
    save_total_limit = 3,
    load_best_model_at_end=True,
    metric_for_best_model = config['metric_for_best_model'],
    # Optimisation hyperparameters, taken from the config dictionary.
    num_train_epochs = config['num_train_epochs'],
    learning_rate = config['learning_rate'],
    weight_decay = 0.01,
    label_smoothing_factor = config['label_smoothing_factor'],
    gradient_accumulation_steps = config['gradient_accumulation_steps'],
    per_device_train_batch_size = config['train_batch_size'],
    per_device_eval_batch_size = config['eval_batch_size'],
    # Half precision only when a GPU is available.
    fp16 = torch.cuda.is_available(),
    # Generate text during evaluation so compute_metrics receives predictions.
    predict_with_generate = True,
)

The data is loaded by specifying the source and target language (and file path) for each pair. This is slightly redundant, since the same language files are loaded multiple times.
This could be implemented as a generator or some other form of lazy loading; however, the dataset is small enough that it does not matter.

In [None]:
# Each entry describes one parallel corpus: the language code and file path of
# the source side, and the language code and file path of the target side.
## Be careful to keep the order the same for source and target dataset pairs
_SALT_DATA_ROOT = "v7-dataset/v7.0/supervised"
_SALT_LANGUAGES = ("ach", "lgg", "lug", "nyn", "teo")


def _salt_pair(language, split):
    """Return the source/target spec for one SALT language and data split.

    Args:
        language: source language code, e.g. "ach".
        split: file stem of the split, "train" or "val".

    Returns:
        A dict with "source" and "target" entries, each holding a
        "language" code and a file "path".
    """
    return {
        "source": {"language": language,
                   "path": f"{_SALT_DATA_ROOT}/en-{language}/{split}.{language}"},
        # The English side is always read from the en-lug folder.
        # NOTE(review): presumably every en-* folder holds the same English
        # sentences in the same order - confirm against the dataset.
        "target": {"language": "en",
                   "path": f"{_SALT_DATA_ROOT}/en-lug/{split}.en"},
    }


# One entry per SALT language, in the order ach, lgg, lug, nyn, teo.
config['training_subset_paths'] = [
    _salt_pair(language, "train") for language in _SALT_LANGUAGES
]

# Same languages in the same order, read from the validation split files.
config['validation_subset_paths'] = [
    _salt_pair(language, "val") for language in _SALT_LANGUAGES
]

#why not luo?
# Optionally add the FLORES Luganda->English pairs to the training corpora.
if config['flores101_training_data']:
    flores_dict = dict(
        source=dict(
            language="lug",
            path="v7-dataset/v7.0/supervised/mul-en/train_flores_lug.src",
        ),
        target=dict(
            language="en",
            path="v7-dataset/v7.0/supervised/mul-en/train_flores_lug.tgt",
        ),
    )
    config['training_subset_paths'].append(flores_dict)

# if config['back_translation_training_data']:
#     raise NotImplementedError("Have not split bt data by language yet")
#     config['training_subset_ids'].append('back_translated')

# Over-sample the non-religious training text
#config['training_subset_ids'] = config['training_subset_ids'] * 5
# Will oversample from interleave datasets

# Optionally add the MT560 corpora (ach, lug, nyn -> en), in that order.
if config['mt560_training_data']:
    mt560_list = [
        {
            "source": {
                "language": language,
                "path": f"v7-dataset/v7.0/supervised/mul-en/train_mt560_{language}.src",
            },
            "target": {
                "language": "en",
                "path": f"v7-dataset/v7.0/supervised/mul-en/train_mt560_{language}.tgt",
            },
        }
        for language in ("ach", "lug", "nyn")
    ]
    config['training_subset_paths'].extend(mt560_list)

# if config['named_entities_training_data']:
#     raise NotImplementedError("NER pairs are aggregate not separate")
#     config['training_subset_ids'].append('named_entities')

In [None]:
# Download and unpack the dataset archive into the current working directory,
# but only when the 'v7-dataset' folder is not already there, so re-running
# the cell does not fetch it again.
if not os.path.exists('v7-dataset'):
    !wget https://sunbird-translate.s3.us-east-2.amazonaws.com/v7-dataset.zip
    !unzip v7-dataset.zip
    # Clear the wget/unzip output from the cell.
    display.clear_output()

This is where the model is loaded. Ideally any changes/arguments here should be passed from the config dict.

In [None]:
# Pretrained seq2seq model, loaded from the checkpoint named in the config.
model = MBartForConditionalGeneration.from_pretrained(config["model_checkpoint"])
# NOTE(review): the tokenizer checkpoint is hard-coded and differs from
# config["model_checkpoint"] - confirm both share the same vocabulary.
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
# Collator handed to the trainer; it is given both the tokenizer and the model.
data_collator = transformers.DataCollatorForSeq2Seq(tokenizer, model = model) 
# BLEU metric used by compute_metrics.
# NOTE(review): datasets.load_metric is deprecated in later `datasets`
# releases - confirm it exists in the installed version.
metric = datasets.load_metric('sacrebleu')

Since we're using a pretrained model with its own tokenizer/BPE encoding, we need to either create new tokens and expand the existing model embeddings,
as described at https://www.depends-on-the-definition.com/how-to-add-new-tokens-to-huggingface-transformers/ (feel free to add this if you need it), or, since mBART50 is trained on 50 languages, 49 of which we don't need, simply reuse the tokens of unused languages for our own languages.

From a code-cleanliness perspective it is better to change the tokenizer's values for the different languages, so that tokenizer.encode("teo") gives a numerical value directly; but replacing all occurrences of "teo" with "ar_AR" or similar would also work.

In [None]:
# Only the "recycle existing language tokens" strategy is implemented; stop
# early when the config asks for anything else.
if not config["recycle_language_tokens"]:
    # If you want to add it refer to https://www.depends-on-the-definition.com/how-to-add-new-tokens-to-huggingface-transformers/
    raise NotImplementedError("Code to add tokens and resize embedding layer not added")

# Maps each language code used in this notebook to the mBART-50 language
# token that is reused for it.
token_conversion_dict = dict(
    teo='ar_AR',
    ach='cs_CZ',
    lug='de_DE',
    lgg='es_XX',
    nyn='et_EE',
    en='en_XX',
)

Since validation takes a long time to run, we limit the validation data using the validation_cutoff argument, whose value comes from config["validation_samples_per_language"]: for each file in config["validation_subset_paths"] we load only a subset of the validation data. Whether the rest is discarded or added to the training set is controlled by config["validation_train_merge"] (True to use the extra validation data for training, False to discard it).

In [None]:
def _file_to_list(path):
    with open(path) as file:
        lines = file.readlines()
        lines = [line.rstrip() for line in lines]
        return lines

def dataset_from_folders(pair_dicts_list, language_token_dict, validation_cutoff = 0,mode = "cutoff_maximum"):
    """Build one `datasets.Dataset` per source/target file pair.

    Args:
        pair_dicts_list: list of dicts, each with a "source" and a "target"
            entry holding a "language" code and a file "path".
        language_token_dict: maps a language code to the token string that is
            prepended (followed by a space) to every sentence of that language.
        validation_cutoff: size of the leading validation slice; a falsy value
            disables slicing altogether.
        mode: "cutoff_maximum" keeps the first `validation_cutoff` lines,
            "cutoff_minimum" keeps the lines after them; any other value
            leaves the lines untouched.

    Returns:
        A list of datasets in the order of `pair_dicts_list`. Each has a single
        "translation" column whose rows are {"src": ..., "tgt": ...} dicts.
    """

    def _select(sentences):
        # Apply the validation split requested through cutoff/mode.
        if not validation_cutoff:
            return sentences
        if mode == "cutoff_maximum":
            return sentences[:validation_cutoff]
        if mode == "cutoff_minimum":
            return sentences[validation_cutoff:]
        return sentences

    def _tag(language, sentences):
        # Prefix every sentence with the token assigned to its language.
        return [language_token_dict[language] + " " + sentence
                for sentence in sentences]

    subsets = []
    for spec in pair_dicts_list:
        source_language = spec["source"]["language"]
        source_lines = _file_to_list(spec["source"]["path"])
        target_language = spec["target"]["language"]
        target_lines = _file_to_list(spec["target"]["path"])

        source_lines = _select(source_lines)
        target_lines = _select(target_lines)

        tagged_sources = _tag(source_language, source_lines)
        tagged_targets = _tag(target_language, target_lines)

        # zip pairs the lines positionally and stops at the shorter file.
        rows = [{"src": source, "tgt": target}
                for source, target in zip(tagged_sources, tagged_targets)]
        subsets.append(datasets.Dataset.from_dict({'translation': rows}))

    return subsets

def dataset_from_src_tgt_files(data_dir, dataset_id, validation_cutoff = 0, mode = "train"):
    """Build a dataset from a `<dataset_id>.src` / `<dataset_id>.tgt` file pair.

    Rows use the "src"/"tgt" keys that `preprocess` reads. The previous
    implementation keyed rows by config['source_language'] and
    config['target_language'], which the config never defines, so every call
    raised KeyError. Unlike `dataset_from_folders`, no language token is
    prepended to the sentences.

    Args:
        data_dir: directory containing the two files.
        dataset_id: file stem shared by the `.src` and `.tgt` files.
        validation_cutoff: number of leading lines forming the validation slice.
        mode: "cutoff_maximum" keeps the first `validation_cutoff` lines,
            "cutoff_minimum" keeps the lines after them; any other value
            (including the default "train") keeps every line.

    Returns:
        A `datasets.Dataset` with a single "translation" column whose rows are
        {"src": ..., "tgt": ...} dicts.
    """
    path = os.path.join(data_dir, dataset_id)
    source = _file_to_list(path + '.src')
    target = _file_to_list(path + '.tgt')

    if mode == "cutoff_maximum":
        source = source[:validation_cutoff]
        target = target[:validation_cutoff]
    elif mode == "cutoff_minimum":
        source = source[validation_cutoff:]
        target = target[validation_cutoff:]

    # zip pairs the lines positionally and stops at the shorter file.
    pairs = {'translation': [{"src": s, "tgt": t}
                             for s, t in zip(source, target)]}

    return datasets.Dataset.from_dict(pairs)

The MT560 dataset is larger than the others and focuses on religious text. The Luganda MT560 subset is especially large, so we undersample the MT560 sentence pairs using the sample_probabilities array, which tells datasets.interleave_datasets how much to sample from each dataset when building the final combined dataset.

In [None]:
# dataset_from_folders returns one dataset per entry of
# config['training_subset_paths'], in the same order.
training_subsets = dataset_from_folders(config['training_subset_paths'],
                                        token_conversion_dict,
                                        validation_cutoff = 0)

# Optionally train on the validation lines that are NOT used for evaluation.
# These are appended after the regular training subsets.
if config["validation_train_merge"]:
    extra_training_data = dataset_from_folders(config['validation_subset_paths'],
                                               token_conversion_dict,
                                               validation_cutoff = config['validation_samples_per_language'],
                                               mode = "cutoff_minimum")
    training_subsets.extend(extra_training_data)

training_subsets = [s.shuffle() for s in training_subsets]

# Down-sampling factors for the oversized MT560 corpora, keyed by the stem of
# the source file. Looking the subsets up by path instead of by the fixed
# positions 6, 7 and 8 keeps the weights right when a flag such as
# 'flores101_training_data' is switched off and the list order shifts.
mt560_downsample_factors = {
    "train_mt560_ach": 10,
    "train_mt560_lug": 20,
    "train_mt560_nyn": 10,
}

# Start from the subset sizes, i.e. sampling proportional to corpus size.
sample_probabilities = np.array([len(s) for s in training_subsets])
for index, pair in enumerate(config['training_subset_paths']):
    stem = os.path.splitext(os.path.basename(pair["source"]["path"]))[0]
    factor = mt560_downsample_factors.get(stem, 1)
    sample_probabilities[index] = sample_probabilities[index] // factor

# Normalise the weights into probabilities.
sample_probabilities = sample_probabilities / np.sum(sample_probabilities)

train_data_raw = datasets.interleave_datasets(
    training_subsets, sample_probabilities)

In [None]:
# Keep only the first `validation_samples_per_language` lines of every
# validation file, then stack the per-language subsets in list order
# (compute_metrics slices the predictions by that order).
validation_subsets = dataset_from_folders(
    config['validation_subset_paths'],
    token_conversion_dict,
    mode = "cutoff_maximum",
    validation_cutoff = config['validation_samples_per_language'],
)

validation_data_raw = datasets.concatenate_datasets(validation_subsets)

In [None]:
def sentence_format(input):
    '''Ensure capital letter at the start and full stop at the end.

    Args:
        input: the sentence to format. The parameter name is kept for
            backward compatibility even though it shadows the builtin.

    Returns:
        The sentence with its first character capitalised and a '.' appended
        unless it already ends in '.', '!' or '?'. An empty value is returned
        unchanged instead of raising IndexError.
    '''
    if not input:
        return input

    formatted = input[0].capitalize() + input[1:]
    if not formatted.endswith(('.', '!', '?')):
        formatted += '.'
    return formatted

def preprocess(examples):
    """Normalise, format and tokenize a batch of translation examples.

    Each sentence arrives with a language token prepended (for example
    "ar_AR <sentence>"). Formatting the whole string capitalised that token
    ("Ar_AR ...") and left the sentence itself untouched, so the token no
    longer matched the tokenizer's reserved language code. The token is now
    split off, only the sentence is normalised and formatted, and the token
    is put back unchanged.

    Args:
        examples: a batch with a 'translation' column whose items are
            {"src": ..., "tgt": ...} dicts.

    Returns:
        The tokenized source sentences with the tokenized target ids stored
        under "labels".
    """
    normalizer = sacremoses.MosesPunctNormalizer()
    # The mBART-50 tokenizer registers its language codes as additional
    # special tokens; fall back to "no known tokens" if the attribute is
    # missing, which reproduces the previous whole-string formatting.
    language_tokens = set(getattr(tokenizer, "additional_special_tokens", ()))

    def _clean(text):
        # Separate a leading language token from the sentence proper.
        prefix, separator, sentence = text.partition(" ")
        if not (separator and prefix in language_tokens):
            prefix, separator, sentence = "", "", text
        sentence = normalizer.normalize(sentence)
        # sentence_format indexes the first character, so skip empty text.
        if sentence:
            sentence = sentence_format(sentence)
        return prefix + separator + sentence

    translations = examples['translation']
    inputs = [_clean(ex["src"]) for ex in translations]
    targets = [_clean(ex["tgt"]) for ex in translations]

    model_inputs = tokenizer(
        inputs, max_length=config['max_input_length'], truncation=True)

    # Tokenize the targets in target-language mode.
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            targets, max_length=config['max_target_length'], truncation=True)

    model_inputs["labels"] = labels["input_ids"]

    return model_inputs

def postprocess(preds, labels):
    """Strip surrounding whitespace from decoded predictions and labels.

    Returns:
        A tuple (predictions, references) where predictions is a list of
        stripped strings and references wraps every stripped label in its own
        single-item list.
    """
    stripped_predictions = [text.strip() for text in preds]
    wrapped_references = [[text.strip()] for text in labels]
    return stripped_predictions, wrapped_references

def compute_metrics(eval_preds, eval_languages, samples_per_language):
    """Compute a BLEU score per evaluation language plus their mean.

    The decoded predictions are assumed to be grouped by language, in the
    order of `eval_languages`, with `samples_per_language` consecutive rows
    for each language.

    Args:
        eval_preds: a (predictions, labels) pair; when predictions is a tuple
            only its first element is used.
        eval_languages: language codes, in the order the rows are grouped.
        samples_per_language: number of consecutive rows per language.

    Returns:
        A dict with a "BLEU_<lang>" entry per language and "BLEU_mean", all
        rounded to four decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # -100 cannot be decoded, so swap it for the pad token id first.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess(decoded_preds, decoded_labels)

    scores = {}
    for position, lang in enumerate(eval_languages):
        # Rows belonging to the language at this position.
        window = slice(position * samples_per_language,
                       (position + 1) * samples_per_language)
        bleu = metric.compute(predictions=decoded_preds[window],
                              references=decoded_labels[window])
        scores[f"BLEU_{lang}"] = bleu["score"]

    scores["BLEU_mean"] = np.mean(
        [scores[f"BLEU_{lang}"] for lang in eval_languages])

    return {name: round(value, 4) for name, value in scores.items()}

This is where the tokenizer's target language token is set. To generate multiple languages from this same model, we would need a way to inject a different language token for each desired output language.

In [None]:
# Every target in this notebook is English, so set the tokenizer's target
# language to the English code once, before the datasets are tokenized.
tokenizer.tgt_lang = 'en_XX'

In [None]:


# Tokenize both splits in batches; the raw "translation" column is dropped so
# only the fields returned by preprocess remain.
train_data = train_data_raw.map(
    preprocess,
    batched=True,
    remove_columns=["translation"],
)

validation_data = validation_data_raw.map(
    preprocess,
    batched=True,
    remove_columns=["translation"],
)


In [None]:
# Turn off the use_cache flag in the model config before training.
# NOTE(review): presumably done to save memory or silence warnings during
# training - confirm the intent.
model.config.use_cache = False

Cross your fingers, say your prayers, train !

In [None]:
# NOTE: replace the placeholder with a real key before running, and avoid
# committing a real key to a shared copy of the notebook.
os.environ["WANDB_API_KEY"] = "ENTER YOUR WANDB API KEY HERE"
wandb.init(entity=config["wandb_entity"], project=config['wandb_project'], config=config)


def _metrics_for_trainer(eval_preds):
    """Adapt compute_metrics to the single-argument callable the trainer is given."""
    return compute_metrics(
        eval_preds, config['eval_languages'], config['validation_samples_per_language'])


# Stop training after five evaluations without improvement.
early_stopping = transformers.EarlyStoppingCallback(early_stopping_patience = 5)

trainer = transformers.Seq2SeqTrainer(
    model,
    config['train_settings'],
    tokenizer = tokenizer,
    data_collator = data_collator,
    train_dataset = train_data,
    eval_dataset = validation_data,
    compute_metrics = _metrics_for_trainer,
    callbacks = [early_stopping],
)

In [None]:
# Optionally evaluate the model before any fine-tuning, as a baseline.
# The returned metrics are not stored or printed here.
if config['eval_pretrained_model']:
    trainer.evaluate()

In [None]:
# Run fine-tuning with the settings stored in config['train_settings'].
trainer.train()

In [None]:
# Save the trainer's current model. With load_best_model_at_end=True in the
# training arguments this is expected to be the best checkpoint.
trainer.save_model("/kaggle/working/best_mBART_salt")