In [None]:
# Environment setup: install the third-party packages imported by the next
# cell and download the two helper modules (salt.py and augmentations.py)
# from the SunbirdAI/nmt_training repository into the working directory.
from IPython import display
!pip install transformers
!pip install sacrebleu
!pip install sacremoses
!pip install datasets
!pip install wandb
!pip install sentencepiece
!pip install numpy requests nlpaug
!wget https://raw.githubusercontent.com/SunbirdAI/nmt_training/main/salt_v2/salt.py
!wget https://raw.githubusercontent.com/SunbirdAI/nmt_training/main/nmt_clean/augmentations.py
# Remove the installer and download logs from the cell output.
display.clear_output()

In [None]:
from augmentations import Augmentations
import datasets
from IPython import display
import nlpaug.augmenter.char as nac
import nlpaug.augmenter.word as naw
import numpy as np
import os
import pandas as pd
import random
import sentencepiece
import sacrebleu
import sacremoses
import salt
import tqdm
import transformers
import torch
import wandb


In [None]:
# Show whether a CUDA device is visible to PyTorch.
torch.cuda.is_available()

In [None]:
# Experiment configuration for the many-to-English ("mul-en") model.
config = {
    # Translation direction.
    'source_language': 'many',
    'target_language': 'eng',

    # Optimisation and evaluation settings.
    'metric_for_best_model': 'loss',
    'effective_train_batch_size': 5000,
    'max_input_length': 128,
    'validation_samples_per_language': 500,
    'eval_languages': ['ach', 'lgg', 'lug', 'nyn', 'teo', 'luo'],
    'eval_pretrained_model': True,
    'learning_rate': 5e-5,
    'num_train_epochs': 10,
    'label_smoothing_factor': 0.1,

    # Switches and weights for the individual training data sources.
    'mt560_relative_sample_rate': 0.2,
    'flores200_training_data': True,
    'mt560_training_data': True,
    'monolingual_training_data': False,
    'back_translation_training_data': True,
    'google_back_translation_data': True,
    'named_entities_training_data': False,
    'lafand_training_data': True,
    'tag_subsets': True,

    # Training loop control.
    'early_stopping_patience': 4,
    'eval_steps_interval': 50,

    # Directory holding the unpacked .jsonl subsets.
    'data_dir': 'salt-translation-plus-external-datasets-15-3-23',
}

# Values derived from the settings above.
config['language_pair'] = '-'.join(
    [config['source_language'], config['target_language']])
config['wandb_project'] = 'sunbird-translate-' + config['language_pair']
config['model_checkpoint'] = (
    '/kaggle/input/nmt-marianmt-w-mafand/best/checkpoint-600')
# Alternative checkpoint location, kept for reference:
#config['model_checkpoint'] = (
#    '/content/gdrive/MyDrive/Translation/saved_models/'
#    'marianmt-many-eng/checkpoint-1400')


# Find the biggest batch size that fits in GPU memory.
#
# The per-device batch size is the largest proper divisor of the effective
# batch size that stays below (GPU memory / approximate per-sample memory);
# gradient accumulation makes up the difference.
APPROX_MODEL_MEMORY_SIZE_MB = 310
if torch.cuda.is_available():
  # Query the CUDA runtime for total device memory rather than indexing into
  # fixed row/column positions of the text printed by `nvidia-smi`, whose
  # layout is not guaranteed to stay the same.
  gpu_memory_mb = torch.cuda.get_device_properties(0).total_memory // (1024 ** 2)
  per_device_max_batch_size = int(gpu_memory_mb / APPROX_MODEL_MEMORY_SIZE_MB)
  B = config['effective_train_batch_size']
  factors = np.array([x for x in range(1, B) if B % x == 0])
  fitting_factors = factors[factors < per_device_max_batch_size]
  # Fall back to a batch size of 1 when no divisor fits, instead of letting
  # max() fail on an empty selection.
  config['train_batch_size'] = (
      int(fitting_factors.max()) if fitting_factors.size else 1)
  config['eval_batch_size'] = config['train_batch_size']
else:
  config['train_batch_size'] = 1
  config['eval_batch_size'] = 1
# train_batch_size always divides the effective batch size, so integer
# division is exact here.
config['gradient_accumulation_steps'] = (
    config['effective_train_batch_size'] // config['train_batch_size'])

# Arguments handed to the Seq2SeqTrainer; most values come from `config`.
config['train_settings'] = transformers.Seq2SeqTrainingArguments(
    # Checkpoints and logs are written to the same directory.
    output_dir='/kaggle/working/best',
    logging_dir='/kaggle/working/best',
    save_total_limit=3,
    # Evaluate and save on the same step interval.
    evaluation_strategy='steps',
    eval_steps=config['eval_steps_interval'],
    save_steps=config['eval_steps_interval'],
    load_best_model_at_end=True,
    metric_for_best_model=config['metric_for_best_model'],
    predict_with_generate=True,
    # Batch sizes and accumulation steps computed earlier in this cell.
    per_device_train_batch_size=config['train_batch_size'],
    per_device_eval_batch_size=config['eval_batch_size'],
    gradient_accumulation_steps=config['gradient_accumulation_steps'],
    # Optimisation hyperparameters.
    learning_rate=config['learning_rate'],
    weight_decay=0.01,
    num_train_epochs=config['num_train_epochs'],
    label_smoothing_factor=config['label_smoothing_factor'],
    # Half precision is enabled only when a CUDA device is available.
    fp16=torch.cuda.is_available(),
    # Experiment tracking.
    report_to='wandb',
    run_name=f'{config["source_language"]}-{config["target_language"]}',
)

In [None]:
# Build the list of training subset ids (file stems inside the data
# directory) from the data-source switches in `config`.
config['training_subset_ids'] = ['salt-train', 'ai4d']

if config['lafand_training_data']:
    config['training_subset_ids'].extend(
        ['lafand-en-lug-combined', 'lafand-en-luo-combined'])

if config['flores200_training_data']:
    config['training_subset_ids'].append('flores200')

if config['back_translation_training_data']:
    config['training_subset_ids'].extend(
        ['bt_ach_en_14_3_23', 'bt_lug_en_14_3_23'])

# The Google back-translations have their own switch; this branch previously
# re-tested 'back_translation_training_data', so 'google_back_translation_data'
# had no effect.
if config['google_back_translation_data']:
    config['training_subset_ids'].extend(
        ['backtranslated-from-eng-google', 'backtranslated-from-lug-google'])

# Everything collected so far is listed five times; the mt560 subsets added
# below are listed once.
config['training_subset_ids'] = config['training_subset_ids'] * 5

if config['mt560_training_data']:
    config['training_subset_ids'].extend([
        'mt560_ach', 'mt560_lug', 'mt560_nyn', 'mt560_luo'])



In [None]:
# Download and unpack the dataset archive, unless the unpacked directory is
# already present in the working directory.
# NOTE(review): the directory name is repeated here as a literal; it matches
# the 'data_dir' value in `config` and must be kept in sync with it.
if not os.path.exists('salt-translation-plus-external-datasets-15-3-23'):
    !wget https://sunbird-translate.s3.us-east-2.amazonaws.com/salt-translation-plus-external-datasets-15-3-23.zip
    !unzip salt-translation-plus-external-datasets-15-3-23.zip
    # Remove the download and unzip logs from the cell output.
    display.clear_output()

In [None]:
from tqdm import tqdm

# Load one dataset per entry of config['training_subset_ids'], reading
# "<data_dir>/<subset id>.jsonl" through the salt helper module. Ids that
# occur several times in the list are loaded once per occurrence.
training_subsets = [
    salt.translation_dataset(
        path=f'{config["data_dir"]}/{subset_id}.jsonl',
        source_language=config['source_language'],
        target_language=config['target_language'],
        languages_to_include=config['eval_languages'],
        allow_target_language_in_source=False,
        prefix_target_language_in_source=False,
        keep_unaugmented_src=False,
    )
    for subset_id in tqdm(config['training_subset_ids'])
]

In [None]:
# Combine all training subsets into one dataset, shuffle it, and flatten the
# shuffled indices.
#
# A previous variant, now disabled, used datasets.interleave_datasets with
# sampling probabilities proportional to subset size, scaled by
# config['mt560_relative_sample_rate'] for subset ids containing 'mt560' or
# 'google'. The subsets are currently just concatenated.
train_data_raw = (
    datasets.concatenate_datasets(training_subsets)
    .shuffle()
    .flatten_indices()
)

In [None]:
# One validation subset per evaluation language, each read from salt-dev.jsonl
# with that language as the source and English as the target.
validation_subsets = [
    salt.translation_dataset(
        path=f'{config["data_dir"]}/salt-dev.jsonl',
        source_language=lang_code,
        target_language="eng",
        allow_target_language_in_source=False,
        prefix_target_language_in_source=False,
        keep_unaugmented_src=False,
    )
    for lang_code in config['eval_languages']
]
# The subsets are stacked in the order of config['eval_languages'].
validation_data_raw = datasets.concatenate_datasets(validation_subsets)

In [None]:
def sentence_format(input):
    '''Ensure capital letter at the start and full stop at the end.

    Args:
        input: The sentence to format.

    Returns:
        The sentence with its first character capitalised and a '.' appended
        unless it already ends in '.', '!' or '?'. An empty string is returned
        for empty or non-string input.
    '''
    # Explicit guard instead of a bare `except:`: the result for empty or
    # non-string input is still "", but unrelated errors and interrupts such
    # as KeyboardInterrupt are no longer swallowed.
    if not isinstance(input, str) or not input:
        return ""
    input = input[0].capitalize() + input[1:]
    if not input.endswith(('.', '!', '?')):
        input = input + '.'
    return input

def preprocess(examples):
    '''Normalise, format and tokenise a batch of source/target sentences.

    Args:
        examples: A batch with 'source' and 'target' entries that are
            iterated in parallel.

    Returns:
        The output of the module-level `tokenizer` called on the cleaned
        sources, with the cleaned targets passed as `text_target`, truncated
        to config['max_input_length'].
    '''
    punct_normalizer = sacremoses.MosesPunctNormalizer()

    def _clean(text):
        # Punctuation normalisation first, then capitalisation/full stop.
        return sentence_format(punct_normalizer.normalize(text))

    sources, references = [], []
    for src, tgt in zip(examples['source'], examples['target']):
        # An empty source is replaced by its target text.
        if len(src) == 0:
            src = tgt
        sources.append(_clean(src))
        references.append(_clean(tgt))

    return tokenizer(
        sources,
        text_target=references,
        max_length=config['max_input_length'],
        truncation=True)

def postprocess(preds, labels):
    '''Strip whitespace from decoded predictions and labels.

    Returns:
        A pair (preds, labels): preds is a list of stripped strings, and
        labels is a list in which every stripped label is wrapped in its own
        single-element list.
    '''
    stripped_preds = []
    for text in preds:
        stripped_preds.append(text.strip())

    wrapped_labels = []
    for text in labels:
        wrapped_labels.append([text.strip()])

    return stripped_preds, wrapped_labels

def compute_metrics(eval_preds, eval_languages, samples_per_language):
    '''Compute one BLEU score per evaluation language plus their mean.

    Predictions and labels are split positionally: the i-th language in
    `eval_languages` is scored on rows
    [i * samples_per_language, (i + 1) * samples_per_language).

    Args:
        eval_preds: A (predictions, labels) pair; if predictions is a tuple,
            its first element is used.
        eval_languages: Language codes, in the order of the evaluation rows.
        samples_per_language: Number of consecutive rows scored per language.

    Returns:
        A dict with a "BLEU_<lang>" entry per language and "BLEU_mean", all
        rounded to 4 decimal places.
    '''
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 entries in the labels cannot be decoded, so swap them for the pad id.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds, decoded_labels = postprocess(decoded_preds, decoded_labels)

    scores = {}
    for position, lang in enumerate(eval_languages):
        rows = slice(position * samples_per_language,
                     (position + 1) * samples_per_language)
        bleu = metric.compute(
            predictions=decoded_preds[rows],
            references=decoded_labels[rows])
        scores[f"BLEU_{lang}"] = bleu["score"]

    scores["BLEU_mean"] = np.mean(
        [scores[f"BLEU_{lang}"] for lang in eval_languages])

    return {name: round(value, 4) for name, value in scores.items()}

In [None]:
# Model and tokenizer are both loaded from config['model_checkpoint'].
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
    config['model_checkpoint'])
tokenizer = transformers.AutoTokenizer.from_pretrained(
    config['model_checkpoint'])
# The collator is built from the tokenizer and the model.
data_collator = transformers.DataCollatorForSeq2Seq(
    tokenizer=tokenizer, model=model)
# Metric object used by compute_metrics.
metric = datasets.load_metric('sacrebleu')

In [None]:
# Only relevant when translating into many languages: make sure every
# evaluation language has a ">>code<<" entry in the tokenizer vocabulary.
if config['target_language'] == 'many':
    # New language code -> existing vocabulary entry whose id it takes over.
    # NOTE(review): the existing entries are looked up without the ">>...<<"
    # wrapping; confirm that this is intended.
    replacements = {
        'nyn': 'kin',
        'lgg': 'lin',
        'ach': 'tso',
        'teo': 'som',
        'luo': 'sna',
    }
    for new_code, donor_token in replacements.items():
        new_tag = f'>>{new_code}<<'
        tag_missing = new_tag not in tokenizer.encoder
        if tag_missing and donor_token in tokenizer.encoder:
            # Move the donor's id to the new tag and drop the donor entry.
            tokenizer.encoder[new_tag] = tokenizer.encoder.pop(donor_token)

    # Check that all the evaluation language codes are mapped to something.
    for code in config['eval_languages']:
        if f'>>{code}<<' not in tokenizer.encoder:
            raise ValueError(f'Language code {code} not found in the encoder.')

In [None]:
# Run `preprocess` over both splits in batches and drop the raw text and
# language columns from the result.
train_data = train_data_raw.map(
    preprocess,
    batched=True,
    remove_columns=['source', 'target', 'source_language', 'target_language'],
)

validation_data = validation_data_raw.map(
    preprocess,
    batched=True,
    remove_columns=['source', 'target', 'source_language', 'target_language'],
)

In [None]:
import os
from kaggle_secrets import UserSecretsClient

# Fetch the "wandb_key" secret through kaggle_secrets and expose it to wandb
# via the WANDB_API_KEY environment variable before starting the run.
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb_key")
os.environ["WANDB_API_KEY"] = secret_value_0
wandb.init(
    entity="azawahry",
    project=config['wandb_project'],
    config=config,
)

# Only show errors from the transformers logger.
transformers.logging.set_verbosity_error()

trainer = transformers.Seq2SeqTrainer(
    model=model,
    args=config['train_settings'],
    train_dataset=train_data,
    eval_dataset=validation_data,
    data_collator=data_collator,
    tokenizer=tokenizer,
    # BLEU is computed for every evaluation language except the last entry of
    # config['eval_languages'].
    # NOTE(review): presumably the last language has no validation rows;
    # confirm against the validation data.
    compute_metrics=lambda eval_preds: compute_metrics(
        eval_preds,
        config['eval_languages'][:-1],
        config['validation_samples_per_language']),
    callbacks=[
        transformers.EarlyStoppingCallback(
            early_stopping_patience=config['early_stopping_patience']),
    ],
)

In [None]:
# if config['eval_pretrained_model']:
#     trainer.evaluate()

In [None]:
# Start training with the trainer configured above.
trainer.train()