In [None]:
from IPython import display
!pip install transformers
!pip install sacrebleu
!pip install sacremoses
!pip install datasets
!pip install wandb
!pip install sentencepiece
display.clear_output()


In [None]:
import datasets
from IPython import display
import numpy as np
import os
import pandas as pd
import random
import sentencepiece
import sacrebleu
import sacremoses
from tqdm import tqdm
import transformers
import torch
import wandb
import glob

In [None]:
# Confirm a GPU is visible: fp16 training below is enabled only when it is.
cuda_available = torch.cuda.is_available()
cuda_available

In [None]:
# Parameters for mul-en models
config = {
    'source_language': 'mul',
    'target_language': 'en',
    'metric_for_best_model': 'loss',
    'train_batch_size': 20,
    'gradient_accumulation_steps': 150,
    'max_input_length': 128,
    'max_target_length': 128,
    'validation_samples_per_language': 500,
    'eval_batch_size': 16,
    'eval_languages': ["ach", "lgg", "lug", "nyn", "teo"],
    # Gates the trainer.evaluate() cells at the end of the notebook; those run
    # after trainer.train(), so they evaluate the fine-tuned model.
    'eval_pretrained_model': True,
    'learning_rate': 0.0001,
    'num_train_epochs': 10,
    'label_smoothing_factor': 0.1,
    'flores101_training_data': True,
    'mt560_training_data': True,
    'back_translation_training_data': True,
    'back_translation_model_checkpoint': '/content/gdrive/Shareddrives/Sunbird AI/Projects/African Language Technology/Models/en-mul-ethereal-valley',
    'back_translation_token': True,
    'forward_translation': True,
    'forward_translation_percentage':0.1,
    #'forward_translation_token': True,
    'out_of_domain_token': True,
    'named_entities_training_data': False,
    # Seed for Python `random`, NumPy and torch (applied just below) and for
    # the Trainer.
    'seed': 42,
}

# Seed the global RNGs now, before later cells shuffle and subsample the
# training data; otherwise every run trains on a different data order.
transformers.set_seed(config['seed'])

config['language_pair'] = f'{config["source_language"]}-{config["target_language"]}'
config['wandb_project'] = 'salt-fwd-bck-ood'
config['wandb_entity'] = 'sunbird'

#config['model_checkpoint'] = f'Helsinki-NLP/opus-mt-{config["language_pair"]}'
config['model_checkpoint'] = '/kaggle/input/nmt-backtranslation-ood-forward-translation/output-mul-en/checkpoint-800'

# What training data to use (relative to the working directory set by the
# data-download cell).
config['data_dir'] = f'v7-dataset/v7.0/supervised/{config["language_pair"]}/'

# Evaluate roughly every 10 minutes.
# NOTE(review): 350 * 60 * 7 is presumably an estimate of how many training
# examples are processed in that time; dividing by the examples consumed per
# optimizer step (batch size x accumulation steps) converts it to steps. Confirm.
eval_steps_interval = 350 * 60 * 7 / (config['gradient_accumulation_steps']
                                      * config['train_batch_size'])

# Round down to a multiple of 10 steps, with a minimum of 10.
eval_steps_interval = 10 * max(1, int(eval_steps_interval / 10))

print(f'Evaluating every {eval_steps_interval} training steps.')

config['train_settings'] = transformers.Seq2SeqTrainingArguments(
    f'output-{config["language_pair"]}',
    evaluation_strategy = 'steps',
    eval_steps = eval_steps_interval,
    # Checkpoints coincide with evaluations so load_best_model_at_end can
    # restore the best-scoring one.
    save_steps = eval_steps_interval,
    gradient_accumulation_steps = config['gradient_accumulation_steps'],
    learning_rate = config['learning_rate'],
    per_device_train_batch_size = config['train_batch_size'],
    per_device_eval_batch_size = config['eval_batch_size'],
    weight_decay = 0.01,
    save_total_limit = 3,
    num_train_epochs = config['num_train_epochs'],
    predict_with_generate = True,
    fp16 = torch.cuda.is_available(),
    logging_dir = f'output-{config["language_pair"]}',
    report_to = 'wandb',
    run_name = f'{config["source_language"]}-{config["target_language"]}',
    load_best_model_at_end=True,
    metric_for_best_model = config['metric_for_best_model'],
    label_smoothing_factor = config['label_smoothing_factor'],
    seed = config['seed'],
)

In [None]:
# Non-religious training text; these subsets are over-sampled (listed five
# times) below.
oversampled_subset_ids = ['train', 'train_ai4d']

if config['forward_translation']:
    # Add a `forward_` variant of each core subset; the loader maps
    # `forward_<id>` back onto the files of `<id>`.
    oversampled_subset_ids += ["forward_" + subset_id for subset_id in oversampled_subset_ids]

if config['flores101_training_data']:
    oversampled_subset_ids.append('train_flores_lug')

if config['back_translation_training_data']:
    oversampled_subset_ids.append('back_translated')

# Over-sample the non-religious training text
training_subset_ids = oversampled_subset_ids * 5

if config['mt560_training_data']:
    training_subset_ids += ['train_mt560_lug', 'train_mt560_ach', 'train_mt560_nyn']

if config['named_entities_training_data']:
    training_subset_ids.append('named_entities')

config['training_subset_ids'] = training_subset_ids

# One validation subset per evaluation language, in the same order as
# config['eval_languages']: compute_metrics attributes consecutive blocks of
# validation examples to the languages in exactly that order.
config['validation_subset_ids'] = [f'val_{lang}' for lang in config['eval_languages']]

In [None]:
import os
# Work from Kaggle's writable output directory; the relative
# config['data_dir'] resolves against it.
os.chdir("/kaggle/working/")
# Download and unpack the SALT v7 dataset only if it is not already present.
if not os.path.exists('v7-dataset'):
    !wget https://sunbird-translate.s3.us-east-2.amazonaws.com/v7-dataset.zip
    !unzip v7-dataset.zip
    display.clear_output()
# Copy the extra mul-en files from the attached Kaggle input into the
# supervised mul-en directory (this runs on every execution of the cell).
# NOTE(review): presumably these supply subsets listed in
# config['training_subset_ids'] that are not in the v7 archive — confirm.
!cp /kaggle/input/salt-extra-mul-en/* /kaggle/working/v7-dataset/v7.0/supervised/mul-en/

In [None]:
def _file_to_list(path):
    with open(path) as file:
        lines = file.readlines()
        lines = [line.rstrip() for line in lines]
        return lines
    
def dataset_from_src_tgt_files(data_dir, dataset_id, read_first_n = 0):
    """Load one parallel subset as a ``datasets.Dataset`` of translation pairs.

    Sentences are read from ``<data_dir>/<dataset_id>.src`` and ``.tgt``.
    Depending on ``config``, a control tag is prepended to each source
    sentence:

    * ``>>fwd<<`` for ``forward_<id>`` subsets (when
      ``config['forward_translation']``). These re-use the files of ``<id>``
      and keep a random ``config['forward_translation_percentage']`` fraction
      of the pairs.
    * ``>>ood<<`` for subsets whose id contains ``mt560`` (when
      ``config['out_of_domain_token']``).
    * ``>>bck<<`` for ``back_translated`` (when
      ``config['back_translation_token']``).

    Args:
        data_dir: Directory containing the subset files.
        dataset_id: Subset name, without file extension.
        read_first_n: If non-zero, keep only the first ``read_first_n`` lines.
            This is applied before forward-translation subsampling.

    Returns:
        datasets.Dataset: A single ``translation`` column whose rows are
        ``{config['source_language']: src, config['target_language']: tgt}``.
    """
    forward_prefix = "forward_"

    if config["forward_translation"] and dataset_id.startswith(forward_prefix):
        # `forward_<id>` re-uses the files of `<id>`. Strip the prefix from the
        # id only, so a "forward_" elsewhere in data_dir is left untouched.
        base_path = os.path.join(data_dir, dataset_id[len(forward_prefix):])
        # NOTE(review): both sides are read from the `.src` file, so each target
        # is a copy of its (untagged) source sentence. If real translation
        # pairs were intended, this should read `.tgt` — confirm.
        source = _file_to_list(base_path + '.src')
        target = _file_to_list(base_path + '.src')
        source = [">>fwd<< " + sentence for sentence in source]

        if read_first_n:
            source = source[:read_first_n]
            target = target[:read_first_n]

        if config["forward_translation_percentage"] < 1.0:
            # Keep a random fraction of the pairs.
            cutoff_idx = int(len(source) * config["forward_translation_percentage"])
            zipped = list(zip(source, target))
            random.shuffle(zipped)
            kept = zipped[:cutoff_idx]
            # Unzip with comprehensions: `source, target = zip(*zipped)` raises
            # ValueError when the subset file is empty.
            source = [src for src, _ in kept]
            target = [tgt for _, tgt in kept]
    else:
        path = os.path.join(data_dir, dataset_id)
        source = _file_to_list(path + '.src')
        target = _file_to_list(path + '.tgt')

        if read_first_n:
            source = source[:read_first_n]
            target = target[:read_first_n]

        if config['out_of_domain_token'] and "mt560" in dataset_id:
            source = [">>ood<< " + sentence for sentence in source]

    if dataset_id == "back_translated" and config["back_translation_token"]:
        source = [">>bck<< " + sentence for sentence in source]

    src_lang = config['source_language']
    tgt_lang = config['target_language']
    pairs = {'translation': [{src_lang: s, tgt_lang: t}
                             for s, t in zip(source, target)]}

    return datasets.Dataset.from_dict(pairs)

In [None]:
# Load every training subset. Ids may repeat: each repeat is loaded and
# shuffled independently, which is how over-sampling is implemented.
training_subsets = [dataset_from_src_tgt_files(config['data_dir'], subset_id)
                    for subset_id in config['training_subset_ids']]
training_subsets = [subset.shuffle() for subset in training_subsets]

# Draw from each subset in proportion to its size.
subset_sizes = np.array([len(subset) for subset in training_subsets])
sample_probabilities = subset_sizes / np.sum(subset_sizes)

# Seed the interleaving so the example order is reproducible; without a seed
# the source chosen for each example differs between runs. Falls back to
# seed=None if the config defines no seed.
# NOTE(review): the default stopping_strategy ("first_exhausted") ends the
# stream as soon as any one subset runs out, so some examples may never be seen.
train_data_raw = datasets.interleave_datasets(
    training_subsets, probabilities=sample_probabilities, seed=config.get('seed'))

In [None]:
validation_subsets = [
    dataset_from_src_tgt_files(
        config['data_dir'],
        subset_id,
        read_first_n=config['validation_samples_per_language'],
    )
    for subset_id in config['validation_subset_ids']
]

# compute_metrics attributes consecutive blocks of
# config['validation_samples_per_language'] examples to
# config['eval_languages'], in order. That is only correct if there is one
# subset per language and each holds exactly that many examples, so fail fast
# here instead of reporting misattributed BLEU scores during training.
if len(validation_subsets) != len(config['eval_languages']):
    raise ValueError(
        f'{len(validation_subsets)} validation subsets but '
        f'{len(config["eval_languages"])} evaluation languages.')
for subset_id, subset in zip(config['validation_subset_ids'], validation_subsets):
    if len(subset) != config['validation_samples_per_language']:
        raise ValueError(
            f'Validation subset {subset_id} has {len(subset)} examples; '
            f'expected {config["validation_samples_per_language"]}.')

validation_data_raw = datasets.concatenate_datasets(validation_subsets)

In [None]:
def sentence_format(input):
    '''Ensure capital letter at the start and full stop at the end.

    Args:
        input: The sentence to format. The parameter name is kept for
            backward compatibility even though it shadows the builtin.

    Returns:
        The sentence with its first character passed through
        ``str.capitalize`` and ``'.'`` appended unless it already ends in
        ``'.'``, ``'!'`` or ``'?'``. An empty string (e.g. from a blank line
        in a corpus file) is returned unchanged instead of raising IndexError.

    NOTE(review): sentences prefixed with a control tag such as ``'>>fwd<< '``
    start with ``'>'``, so the words after the tag keep their original case —
    confirm whether that is intended.
    '''
    if not input:
        return input
    formatted = input[0].capitalize() + input[1:]
    if formatted[-1] not in ('.', '!', '?'):
        formatted += '.'
    return formatted

def preprocess(examples):
    """Turn a batch of translation pairs into tokenised model inputs with labels.

    Both sides are passed through Moses punctuation normalisation and then
    ``sentence_format``. They are tokenised with the notebook-level
    ``tokenizer`` (targets inside ``as_target_tokenizer``) and truncated to
    ``config['max_input_length']`` / ``config['max_target_length']`` tokens.

    Args:
        examples: A batch from ``Dataset.map(batched=True)`` whose
            ``'translation'`` entries are dicts keyed by language code.

    Returns:
        The tokenizer output for the sources, with the target token ids
        stored under ``"labels"``.
    """
    punct_normalizer = sacremoses.MosesPunctNormalizer()
    translation_pairs = examples['translation']
    src_key = config['source_language']
    tgt_key = config['target_language']

    def _clean(text):
        return sentence_format(punct_normalizer.normalize(text))

    clean_sources = [_clean(pair[src_key]) for pair in translation_pairs]
    clean_targets = [_clean(pair[tgt_key]) for pair in translation_pairs]

    batch = tokenizer(
        clean_sources,
        max_length=config['max_input_length'],
        truncation=True,
    )
    with tokenizer.as_target_tokenizer():
        target_batch = tokenizer(
            clean_targets,
            max_length=config['max_target_length'],
            truncation=True,
        )

    batch["labels"] = target_batch["input_ids"]
    return batch

def postprocess(preds, labels):
    """Strip whitespace from decoded predictions and references.

    Each reference is wrapped in a one-element list, the nested
    list-of-references shape the sacrebleu metric expects.

    Returns:
        tuple: ``(stripped_predictions, wrapped_references)``.
    """
    cleaned_preds = []
    for text in preds:
        cleaned_preds.append(text.strip())

    references = []
    for text in labels:
        references.append([text.strip()])

    return cleaned_preds, references

def compute_metrics(eval_preds, eval_languages, samples_per_language):
    """Compute BLEU for each evaluation language and the mean across languages.

    Decodes with the notebook-level ``tokenizer`` and scores with the
    notebook-level ``metric`` (sacrebleu).

    Args:
        eval_preds: ``(predictions, label_ids)`` from the Trainer. If
            ``predictions`` is a tuple, its first element is used.
        eval_languages: Language codes, in the order their validation subsets
            were concatenated.
        samples_per_language: Size of each language's contiguous block of
            examples.

    Returns:
        dict: ``BLEU_<lang>`` per language plus ``BLEU_mean``, each rounded to
        4 decimal places.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is an ignore index, not a token id, so it cannot be decoded. Labels
    # are padded with it, and the Trainer also pads generated predictions with
    # it when concatenating evaluation batches of different lengths. Map it to
    # the pad token (dropped by skip_special_tokens) on both sides.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess(decoded_preds, decoded_labels)

    # Example i belongs to eval_languages[i // samples_per_language]. This
    # assumes every language contributed exactly samples_per_language examples.
    result = {}
    for i, lang in enumerate(eval_languages):
        start, stop = i * samples_per_language, (i + 1) * samples_per_language
        result_subset = metric.compute(
            predictions=decoded_preds[start:stop],
            references=decoded_labels[start:stop])
        result[f"BLEU_{lang}"] = result_subset["score"]

    result["BLEU_mean"] = np.mean([result[f"BLEU_{lang}"] for lang in eval_languages])

    return {k: round(v, 4) for k, v in result.items()}

In [None]:
# Load the starting checkpoint and its tokenizer, plus the batch collator and
# the sacrebleu metric used by compute_metrics.
checkpoint_path = config['model_checkpoint']
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_path)
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path)
# Pads each batch. The model is passed so the collator can prepare decoder
# inputs from the labels when the model supports it.
data_collator = transformers.DataCollatorForSeq2Seq(tokenizer, model=model)
metric = datasets.load_metric('sacrebleu')

In [None]:
# Control tags prepended to source sentences by dataset_from_src_tgt_files.
# Each enabled tag is added to the tokenizer's encoder under the id of an
# existing vocabulary piece, so the tag shares that piece's id and embedding
# (the model is not resized).
# NOTE(review): earlier comments mentioned ids 735 / 1293 / 471 — presumably
# the ids of these pieces in this checkpoint's vocabulary; confirm.
tag_aliases = [
    # (config flag, tag, existing piece whose id the tag re-uses)
    ('back_translation_token', ">>bck<<", "र"),
    ('forward_translation', ">>fwd<<", "را"),
    ('out_of_domain_token', ">>ood<<", "ش"),
]
for flag_name, tag, donor_piece in tag_aliases:
    if config[flag_name]:
        tokenizer.encoder[tag] = tokenizer.encoder[donor_piece]

#display.clear_output()


In [None]:
if config['target_language'] == 'mul':
    # Our language code -> existing Marian code whose id it takes over. The
    # donor code is removed from the encoder once its id has been reassigned.
    replacements = {'bck': 'kin',
                    'lgg': 'lin',
                    'ach': 'tso',
                    'teo': 'som',
                    'luo': 'sna',
                   }
    for our_code, donor_code in replacements.items():
        our_tag = f'>>{our_code}<<'
        donor_tag = f'>>{donor_code}<<'
        if our_tag in tokenizer.encoder or donor_tag not in tokenizer.encoder:
            continue
        tokenizer.encoder[our_tag] = tokenizer.encoder.pop(donor_tag)

    # Check that all the evaluation language codes are mapped to something.
    unmapped = [code for code in config['eval_languages']
                if f'>>{code}<<' not in tokenizer.encoder]
    if unmapped:
        raise ValueError(f'Language code {unmapped[0]} not found in the encoder.')

In [None]:
# Tokenise both splits in batches; the raw `translation` column is dropped.
tokenize_kwargs = {'remove_columns': ["translation"], 'batched': True}
train_data = train_data_raw.map(preprocess, **tokenize_kwargs)
validation_data = validation_data_raw.map(preprocess, **tokenize_kwargs)

In [None]:
# Optional custom optimizer/scheduler — currently disabled. The trainer cell
# does not pass `optimizers` (its `optimizers=` line is commented out), so the
# Trainer builds its default optimizer and schedule from config['train_settings'].
# NOTE(review): if re-enabled, `num_training_steps` below is set to the warmup
# step count, and `len(train_data)` counts examples rather than optimizer steps
# (batch size x gradient accumulation). Both look wrong; fix before use.
#optimizer = transformers.AdamW( model.parameters(), lr = config["learning_rate"])
#total_steps_warmup = (config['num_train_epochs'] * len(train_data))//2
#scheduler = transformers.get_linear_schedule_with_warmup( optimizer, num_warmup_steps=total_steps_warmup, num_training_steps=total_steps_warmup ) 
#optimizers = (optimizer, scheduler)

In [None]:
import os
from kaggle_secrets import UserSecretsClient

# Authenticate W&B with the API key stored as the Kaggle secret "wandb_key".
# The key goes straight into the environment instead of a notebook variable,
# so it cannot leak through `%whos`, variable inspectors or a later output.
os.environ["WANDB_API_KEY"] = UserSecretsClient().get_secret("wandb_key")
wandb.init(project=config['wandb_project'], entity=config["wandb_entity"], config=config)

# Stop training once config['metric_for_best_model'] has worsened for this
# many evaluation calls.
early_stopping_patience = 5

trainer = transformers.Seq2SeqTrainer(
    model,
    config['train_settings'],
    train_dataset = train_data,
    eval_dataset = validation_data,
    data_collator = data_collator,
    tokenizer = tokenizer,
    #optimizers= optimizers,
    compute_metrics = lambda eval_preds: compute_metrics(
        eval_preds, config['eval_languages'], config['validation_samples_per_language']),
    callbacks = [transformers.EarlyStoppingCallback(
        early_stopping_patience = early_stopping_patience)],
)

In [None]:
# Fine-tune. The returned TrainOutput (global step, training loss, metrics)
# is shown as the cell's output.
train_output = trainer.train()
train_output

In [None]:
# Evaluate on the validation set with the default generation settings.
# Assuming the training cell above has run, this scores the fine-tuned model
# (the best checkpoint, since load_best_model_at_end=True), not the pretrained
# one, despite the flag's name.
# The metrics are the cell's last expression so they are shown in the notebook;
# the value is None (nothing displayed) when the flag is off.
eval_metrics = trainer.evaluate() if config['eval_pretrained_model'] else None
eval_metrics

In [None]:
# Same evaluation with 4-beam search; Seq2SeqTrainer.evaluate forwards
# `num_beams` to generation. The metrics are the cell's last expression so
# they are displayed (None, nothing shown, when the flag is off).
# NOTE(review): these metrics are logged under the same "eval_" prefix as the
# greedy run above, so W&B cannot tell them apart; consider a distinct
# metric_key_prefix.
beam_eval_metrics = trainer.evaluate(num_beams=4) if config['eval_pretrained_model'] else None
beam_eval_metrics

In [None]:
# Mark the W&B run opened in the trainer-setup cell as finished and finish
# uploading its logged data.
wandb.finish()