In [None]:
!git clone https://github.com/sunbirdai/salt.git
!pip install -q -r salt/requirements.txt
!pip install -q sentencepiece sacremoses accelerate peft

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import transformers
import datasets
import evaluate
from tqdm.notebook import tqdm
import salt.dataset
import salt.utils
import salt.metrics
import salt.constants
import yaml
import peft
import mlflow
from IPython import display
import getpass
from functools import partial

In [None]:
# Ask for the MLflow credentials interactively so they are never stored in the notebook.
MLFLOW_TRACKING_USERNAME = getpass.getpass('Enter the MLFLOW_TRACKING_USERNAME: ')
MLFLOW_TRACKING_PASSWORD = getpass.getpass('Enter the MLFLOW_TRACKING_PASSWORD: ')

# Publish the credentials, tracking server and experiment name as environment variables.
os.environ.update({
    'MLFLOW_TRACKING_USERNAME': MLFLOW_TRACKING_USERNAME,
    'MLFLOW_TRACKING_PASSWORD': MLFLOW_TRACKING_PASSWORD,
    'MLFLOW_TRACKING_URI': 'https://mlflow-sunbird-ce0ecfc14244.herokuapp.com',
    'MLFLOW_EXPERIMENT_NAME': 'translation',
})

In [None]:
# Name of the artifacts directory that receives the output files.
output_folder = "nllb-3.3b-salt-lr2e-4"

# Batches of `train_batch_size` examples are accumulated until
# `effective_train_batch_size` examples have contributed to one update.
train_batch_size = 10
eval_batch_size = train_batch_size
effective_train_batch_size = 3000
gradient_accumulation_steps = int(effective_train_batch_size / train_batch_size)

# Every SALT language code in sorted order, except 'ibo'.
ALL_LANGUAGES = sorted(salt.constants.SALT_LANGUAGE_NAMES.keys())
ALL_LANGUAGES.remove('ibo')
# Everything in one yaml string, so that it can all be logged to MLFlow.
# This is an f-string: output_folder, gradient_accumulation_steps,
# train_batch_size, eval_batch_size and ALL_LANGUAGES are substituted from the
# Python variables above before the text is parsed as YAML.
yaml_config = f'''
training_args:
  output_dir: "{output_folder}"
  eval_strategy: steps
  eval_steps: 100
  save_steps: 100
  warmup_steps: 10
  gradient_accumulation_steps: {gradient_accumulation_steps}
  learning_rate: 2.0e-4  # Include decimal point to parse as float
  optim: adafactor
  per_device_train_batch_size: {train_batch_size}
  per_device_eval_batch_size: {eval_batch_size}
  weight_decay: 0
  save_total_limit: 3
  num_train_epochs: 3
  predict_with_generate: True
  # fp16: True
  bf16: True
  logging_dir: "{output_folder}"
  load_best_model_at_end: True
  metric_for_best_model: loss
  seed: 42
  push_to_hub: True

max_input_length: 256
eval_pretrained_model: False
early_stopping_patience: 6
data_dir: .
model_checkpoint: facebook/nllb-200-3.3B

datasets:
  train:
    huggingface_load:
      - path: jq/salt_with_eng_target
        name: translations
        split: train
      - path: sunbird/salt
        name: text-hard
        split: train
      - path: Sunbird/external-translation-data
        name: flores200
      - path: Sunbird/external-translation-data
        name: google_smol_ach
      - path: Sunbird/external-translation-data
        name: google_smol_alz
      - path: Sunbird/external-translation-data
        name: google_smol_cgg
      - path: Sunbird/external-translation-data
        name: google_smol_lug
      - path: Sunbird/external-translation-data
        name: lafand-eng-lug
      - path: Sunbird/external-translation-data
        name: lafand-eng-luo
      - path: Sunbird/external-translation-data
        name: makerere-ai4d
      - path: Sunbird/external-translation-data
        name: mozilla_110
      - path: Sunbird/external-translation-data
        name: tico19
      - path: Sunbird/external-translation-data
        name: bibles 
        split: train[:-200] # Save some for eval
      - path: Sunbird/external-translation-data
        name: makerere-ea-languages
        
      - path: Sunbird/external-translation-data
        name: backtranslated_ach
      - path: Sunbird/external-translation-data
        name: backtranslated_lug
      - path: Sunbird/external-translation-data
        name: backtranslated_eng_to_lug_only
      
      - path: Sunbird/external-translation-data
        name: mt560_ach # optional: _unidirectional
      - path: Sunbird/external-translation-data
        name: mt560_alz_unidirectional
      - path: Sunbird/external-translation-data  
        name: mt560_koo_unidirectional
      - path: Sunbird/external-translation-data        
        name: mt560_lug # optional: _unidirectional
      - path: Sunbird/external-translation-data        
        name: mt560_nyn # optional: _unidirectional
      - path: Sunbird/external-translation-data        
        name: mt560_swa_unidirectional
        split: train[:50000]
      - path: Sunbird/external-translation-data        
        name: mt560_ttj_unidirectional

    source:
      type: text
      language: {ALL_LANGUAGES}
      preprocessing:
        - random_case:
            apply_to_both: False
            p_all_lower_case: 0.05
            p_all_upper_case: 0.005
        - augment_characters:
            p: 0.1
        - prefix_dataset_tag:
            tags:
              mt560: '<mt560>'
              backtranslate: '<bt>'
              bible: '<bible>'
    target:
      type: text
      language: {ALL_LANGUAGES}
      preprocessing:
        - clean_text
        - ensure_text_ends_with_punctuation
    shuffle: True
    src_or_tgt_languages_must_contain: eng  # Limit to xx->eng, eng->xx
    allow_same_src_and_tgt_language: False
    # keep_metadata_features: True
        
  validation:
    huggingface_load:
      path: sunbird/salt
      name: text-all
      split: dev  # optionally use a slice, e.g. dev[:10] for a quick test
    source:
      type: text
      language: [ach,lgg,lug,nyn,teo,eng,xog,ttj,swa]
      preprocessing:
        - clean_text
    target:
      type: text
      language: [ach,lgg,lug,nyn,teo,eng,xog,ttj,swa]
      preprocessing:
        - clean_text
        - ensure_text_ends_with_punctuation
    src_or_tgt_languages_must_contain: eng  # Limit to xx->eng, eng->xx
    allow_same_src_and_tgt_language: False
'''

# Parse the YAML text into nested Python dicts/lists.
config = yaml.safe_load(yaml_config)

# Every key under `training_args` is forwarded as a keyword argument.
training_settings = transformers.Seq2SeqTrainingArguments(**config["training_args"])

In [None]:
# Load the pretrained checkpoint named by `model_checkpoint` in the config.
model = salt.utils.TrainableM2MForConditionalGeneration.from_pretrained(
    config['model_checkpoint'],
)

# The language codes given here are only defaults: preprocess() and translate()
# overwrite the leading language token of each example themselves.
tokenizer = transformers.NllbTokenizer.from_pretrained(
    config['model_checkpoint'], src_lang='eng_Latn', tgt_lang='eng_Latn')

So far, LoRA didn't help with this model, but it might prevent overfitting with some more experimentation.

In [None]:
# Optional LoRA fine-tuning; set use_peft to True to wrap the model with adapters.
# `peft` is already imported at the top of the notebook, so its members are
# referenced through the module instead of being re-imported in this cell.
use_peft = False
if use_peft:
    # Attach LoRA adapters to the q_proj / k_proj / v_proj modules.
    peft_config = peft.LoraConfig(
        target_modules=["q_proj", "k_proj", "v_proj"],
        task_type=peft.TaskType.SEQ_2_SEQ_LM,
        inference_mode=False,
        r=16,
        lora_alpha=32,
        lora_dropout=0.0,
    )
    model = peft.get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    # Alternatively, count and display trainable parameters
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())

    print(f"Trainable Parameters: {trainable_params}")
    print(f"Total Parameters: {total_params}")
    print(f"Percentage of Trainable Parameters: {100 * trainable_params / total_params:.2f}%")

In [None]:
# Value the collator uses to pad the label sequences of a batch.
label_pad_token_id = -100
data_collator = transformers.DataCollatorForSeq2Seq(
    tokenizer, model=model, label_pad_token_id=label_pad_token_id)

In [None]:
def preprocess(examples):
    """Tokenize a batch of examples and insert the NLLB language tokens.

    Args:
        examples: Batch with the parallel lists 'source', 'target',
            'source.language' and 'target.language'.

    Returns:
        The tokenizer output, in which the first token of every input and
        label sequence has been replaced by the matching language token, plus
        a 'forced_bos_token_id' list holding the target language token of
        each example.
    """
    model_inputs = tokenizer(
        examples['source'],
        text_target=examples['target'],
        max_length=config['max_input_length'],
        truncation=True,
    )

    # For NLLB models, set the language code for the sources and targets
    language_tokens = salt.constants.SALT_LANGUAGE_TOKENS_NLLB_TRANSLATION
    model_inputs['forced_bos_token_id'] = []
    for row in range(len(examples['source'])):
        source_token = language_tokens[examples['source.language'][row]]
        target_token = language_tokens[examples['target.language'][row]]
        model_inputs['input_ids'][row][0] = source_token
        model_inputs['labels'][row][0] = target_token
        model_inputs['forced_bos_token_id'].append(target_token)

    return model_inputs

# Build the training and validation datasets from their sections of the config.
train_dataset = salt.dataset.create(
    config['datasets']['train'])
eval_dataset = salt.dataset.create(
    config['datasets']['validation'])

# Process all the train_dataset data up front, instead of lazy loading.
def gen_from_iterable_dataset(iterable_ds):
    """Yield the items of `iterable_ds` one by one, in iteration order."""
    for example in iterable_ds:
        yield example
# Materialize the training examples into a regular in-memory/on-disk Dataset.
train_dataset_preloaded = datasets.Dataset.from_generator(
    partial(gen_from_iterable_dataset, train_dataset), features=train_dataset.features)
# Shuffling is then improved. The shuffle is seeded with the configured
# training seed so that this step adds no run-to-run randomness of its own.
train_dataset = train_dataset_preloaded.shuffle(seed=config['training_args']['seed'])

# Display a sample of the training data as a sanity check.
salt.utils.show_dataset(train_dataset, N=20)

# Tokenize the training set; the raw text columns are dropped afterwards.
train_dataset = train_dataset.map(
    preprocess,
    batched=True,
    remove_columns=['source', 'source.language', 'target', 'target.language'],
    num_proc=32)
# The evaluation set keeps its text columns alongside the tokenized ones.
eval_dataset = eval_dataset.map(
    preprocess,
    batched=True)

compute_metrics = salt.metrics.multilingual_eval_fn(
    eval_dataset, [evaluate.load('sacrebleu')],
    tokenizer, log_first_N_predictions=0)

In [None]:
# Replace the stock forced-BOS logits processor with SALT's variant
# (presumably so the forced first token can differ per example; see salt.utils).
transformers.generation.utils.ForcedBOSTokenLogitsProcessor = (
    salt.utils.ForcedVariableBOSTokenLogitsProcessor)

# Training stops early after `early_stopping_patience` (from the config)
# evaluations without improvement.
trainer = transformers.Seq2SeqTrainer(
    model,
    training_settings,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[
        transformers.EarlyStoppingCallback(
            early_stopping_patience=config['early_stopping_patience']),
    ],
)

trainer.train()

In [None]:
# Run a final evaluation on the dataset passed to the trainer as `eval_dataset`.
trainer.evaluate()

In [None]:
# Upload the trainer's model to the Hugging Face Hub.
trainer.push_to_hub()

Optional: count the lengths (in tokens) of the tokenized input/output sequences, e.g. to check which maximum sequence length is needed.

In [None]:
# Optional diagnostic: token-length statistics of the tokenized training set.
count_sequence_lengths = False
if count_sequence_lengths:
    # matplotlib is only needed for this optional plot, so it is imported here.
    import matplotlib.pyplot as plt

    # train_dataset already holds the tokenized examples, so it is iterated
    # directly (each row is read once) instead of being copied again.
    input_lengths = []
    output_lengths = []
    for example in tqdm(train_dataset):
        input_lengths.append(len(example['input_ids']))
        output_lengths.append(len(example['labels']))

    # Print the maxima explicitly: a bare expression inside an `if` block is
    # not displayed by the notebook.
    print(f"Longest input sequence: {np.max(input_lengths)} tokens")
    print(f"Longest output sequence: {np.max(output_lengths)} tokens")
    plt.hist(input_lengths, bins=50)

Try out the model on a few test examples

In [None]:
def translate(text, source_language, target_language):
    """Translate a single piece of text with the current model.

    Args:
        text: The text to translate.
        source_language: Language code of `text`; must be a key of
            salt.constants.SALT_LANGUAGE_TOKENS_NLLB_TRANSLATION (e.g. 'lug').
        target_language: Language code to translate into; same key set.

    Returns:
        The decoded translation as a string, with special tokens removed.

    Raises:
        KeyError: If either language code is missing from the token table.
    """
    # The table must be reached through salt.constants; the bare name is not
    # defined in this notebook.
    language_tokens = salt.constants.SALT_LANGUAGE_TOKENS_NLLB_TRANSLATION

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = tokenizer(text, return_tensors="pt").to(device)
    # Overwrite the first input token with the source language token.
    inputs['input_ids'][0][0] = language_tokens[source_language]

    # Generation uses the default (non-beam) settings; the target token is
    # passed as a list with one entry per beam.
    num_beams = 1
    translated_tokens = model.to(device).generate(
        **inputs,
        forced_bos_token_id=[language_tokens[target_language]] * num_beams,
        max_length=100,
        #num_beams=num_beams,
    )
    result = tokenizer.batch_decode(
        translated_tokens, skip_special_tokens=True)[0]
    return result

In [None]:
# Spot-check one sentence; uncomment one of the other examples to try it instead.
#translate('Eyai papa kang lore', 'kdj', 'eng') # My father is at home
#translate('Erai ikoku ngin nginikang.', 'kdj', 'eng') # The child is mine
#translate('Ke konya ku sente moku awil ku cam', 'alz', 'eng') # Alur: Could you lend me some money for lunch, please?
#translate('Do lu pondra ya?', 'keo', 'eng') # Kakwa: Where are you coming from?
translate('Do a jakinda na nyo?', 'keo', 'eng') # Kakwa: what have you brought for me?

In [None]:
# Sample sentences that are translated from 'lug' to 'eng' below.
queries = [
    'Genda mu nnyumba olabe oba abaana bakyatunula.',
    'Yakoma mu S.4 era agamba talina buzibu nabuyigirize bwa musajja.',
    'Aba Multiplex baagudde mu kawunyemu misana ttuku.',
    'Bantu mannyo ga mpisi, gasseka kungulu, nga munda mulimu bussi.',
    "Ng'akubira, ssentebe Ssegawa alagidde Ndaula okuwaayo sitampu nga tennamuleetera bizibu.",
    'Wano webawaaba?',
    "Oyita ewala, n'otuuka emirembe.",
    'Waliwo omukadde e Banda mu Wakiso atalina wakusula nga akasiisira keyazimba enkuba bwetonnya konna kajjula',
]

# Print each sentence followed by its translation and a blank separator.
for query in queries:
    print(query)
    print(translate(query, source_language="lug", target_language="eng"))
    print('\n')