<a href="https://colab.research.google.com/github/SunbirdAI/salt/blob/main/notebooks/NMT_training_for_Kenyan_Sign_Language_gloss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Machine translation example: Kenyan Sign Language Gloss

This example shows how to train a machine translation model for Kenyan Sign Language (KSL) gloss, English and Swahili, using Maseno University datasets and the HuggingFace libraries, extended with Sunbird AI's [SALT](https://github.com/SunbirdAI/salt/) utilities.

In this example, we combine English-KSL and English-Swahili training data to obtain a model which can translate in any direction, including Swahili to KSL.

The model used here is [NLLB 600M](https://huggingface.co/facebook/nllb-200-distilled-600M), which can be trained on a free Colab instance. If more GPU memory is available, then a larger model (e.g. NLLB 1.3B) is likely to give better performance.

In [None]:
#@title Install packages
# HuggingFace training stack. -U upgrades transformers and datasets to the
# newest release available; -q keeps pip's output short.
!pip install -qU transformers
!pip install -qU datasets
!pip install -q accelerate
# Text tokenization libraries.
!pip install -q sentencepiece
!pip install -q sacremoses
# Experiment tracking.
!pip install -q wandb

# Sunbird African Language Technology (SALT) utilities
# The repository is cloned into ./salt and its requirements are installed from
# the cloned copy.
!git clone https://github.com/sunbirdai/salt.git
!pip install -qr salt/requirements.txt

In [None]:
#@title Imports

import os
import numpy as np
import pandas as pd
import torch
import transformers
import huggingface_hub
import datasets
import evaluate
import tqdm
import salt.dataset
import salt.utils
import salt.metrics
import yaml
import wandb
from IPython import display

In [None]:
# Log in to the HuggingFace Hub from inside the notebook.
# NOTE(review): presumably needed to access the datasets/model or to push
# results later — confirm whether this step is required.
huggingface_hub.notebook_login()

In [None]:
# Where should output files be stored locally
drive_folder = "./artifacts"

# Create the output folder (and any missing parents) with plain Python rather
# than an IPython magic, so the cell also runs outside a notebook. exist_ok
# makes re-running the cell harmless.
os.makedirs(drive_folder, exist_ok=True)

# Large batch sizes generally give good results for translation
effective_train_batch_size = 480
train_batch_size = 6
eval_batch_size = train_batch_size

# Gradients are accumulated over several small batches to reach the effective
# batch size. Floor division keeps the result an exact integer without going
# through a float.
gradient_accumulation_steps = effective_train_batch_size // train_batch_size

# Everything in one yaml string, so that it can all be logged.
# The {placeholders} in curly braces are filled in with str.format() right
# after this string is defined, so the template must not contain any other
# literal curly braces.
yaml_config = '''
training_args:
  output_dir: "{drive_folder}"
  eval_strategy: steps
  eval_steps: 100
  save_steps: 100
  gradient_accumulation_steps: {gradient_accumulation_steps}
  learning_rate: 3.0e-4  # Include decimal point to parse as float
  # optim: adafactor
  per_device_train_batch_size: {train_batch_size}
  per_device_eval_batch_size: {eval_batch_size}
  weight_decay: 0.01
  save_total_limit: 3
  max_steps: 500
  predict_with_generate: True
  fp16: True
  logging_dir: "{drive_folder}"
  load_best_model_at_end: True
  metric_for_best_model: loss
  seed: 123
  push_to_hub: False

max_input_length: 128
eval_pretrained_model: False
early_stopping_patience: 4
data_dir: .

# Use a 600M parameter model here, which is easier to train on a free Colab
# instance. Bigger models work better, however: results will be improved
# if able to train on nllb-200-1.3B instead.
model_checkpoint: facebook/nllb-200-distilled-600M

datasets:
  train:
    huggingface_load:
      # We will load two datasets here: English/KSL Gloss, and also SALT
      # Swahili/English, so that we can try out multi-way translation.

      - path: EzekielMW/Eng_KSLGloss
        split: train[:-500]
      - path: sunbird/salt
        name: text-all
        split: train
    source:
      # This is a text translation only, no audio.
      type: text
      # The source text can be any of English, KSL or Swahili.
      language: [eng,ksl,swa]
      preprocessing:
        # The models are case sensitive, so if the training text is all
        # capitals, then it will only learn to translate capital letters and
        # won't understand lower case. Make everything lower case for now.
        - lower_case
        # We can also augment the spelling of the input text, which makes the
        # model more robust to spelling errors.
        - augment_characters
    target:
      type: text
      # The target text with any of English, KSL or Swahili.
      language: [eng,ksl,swa]
      # The models are case sensitive: make everything lower case for now.
      preprocessing:
        - lower_case

    shuffle: True
    allow_same_src_and_tgt_language: False

  validation:
    huggingface_load:
      # Use the last 500 of the KSL examples for validation.
      - path: EzekielMW/Eng_KSLGloss
        split: train[-500:]
      # Add some Swahili validation text.
      - path: sunbird/salt
        name: text-all
        split: dev
    source:
      type: text
      language: [swa,ksl,eng]
      preprocessing:
        - lower_case
    target:
      type: text
      language: [swa,ksl,eng]
      preprocessing:
        - lower_case
    allow_same_src_and_tgt_language: False
'''

# Fill the template's placeholders with the Python-side settings.
yaml_config = yaml_config.format(**{
    'drive_folder': drive_folder,
    'train_batch_size': train_batch_size,
    'eval_batch_size': eval_batch_size,
    'gradient_accumulation_steps': gradient_accumulation_steps,
})

# Parse the completed YAML text into a plain Python dictionary.
config = yaml.safe_load(yaml_config)

# The `training_args` section is passed straight through as keyword arguments.
training_settings = transformers.Seq2SeqTrainingArguments(
    **config["training_args"],
)

In [None]:
# Stock HuggingFace NLLB models can only be trained towards a single target
# language. Load the checkpoint through the SALT wrapper class instead, which
# makes it trainable for multilingual translation.
model = salt.utils.TrainableM2MForConditionalGeneration.from_pretrained(
    config['model_checkpoint'],
)

# The languages given here are only the tokenizer's initial setting; the
# language tokens are overwritten per example later in this notebook.
tokenizer = transformers.NllbTokenizer.from_pretrained(
    config['model_checkpoint'],
    tgt_lang='eng_Latn',
    src_lang='eng_Latn',
)

In [None]:
# The pre-trained model that we use has support for some African languages, but
# we need to adapt the tokenizer to languages that it wasn't trained with,
# such as KSL. Here we reuse the token from a different language.
LANGUAGE_CODES = ["eng", "swa", "ksl"]

code_mapping = {
    # Exact/close mapping
    'eng': 'eng_Latn',
    'swa': 'swh_Latn',
    # Random mapping
    'ksl': 'ace_Latn',
}

# Register every short code as an alias for the id of an existing NLLB
# language token, so that convert_tokens_to_ids('ksl') etc. resolve to it.
# NOTE(review): this writes to a private tokenizer attribute, so it may need
# revisiting when the transformers version changes.
for code in LANGUAGE_CODES:
    nllb_token = code_mapping[code]
    token_id = tokenizer.convert_tokens_to_ids(nllb_token)
    # An unknown token name resolves to the <unk> id; aliasing a language code
    # to <unk> would silently corrupt training, so fail loudly instead.
    if token_id == tokenizer.unk_token_id:
        raise ValueError(
            f"Language token {nllb_token!r} (mapped from {code!r}) is not in "
            "the tokenizer vocabulary.")
    tokenizer._added_tokens_encoder[code] = token_id

In [None]:
# Value used to pad label sequences when batching. -100 is the default
# ignore_index of PyTorch's cross-entropy loss.
label_pad_token_id = -100

# Collator that batches and pads the tokenized examples.
data_collator = transformers.DataCollatorForSeq2Seq(
    tokenizer, model=model, label_pad_token_id=label_pad_token_id)

In [None]:
def preprocess(examples):
    """Tokenize a batch of examples and tag each row with its languages.

    Reads the 'source', 'target', 'source.language' and 'target.language'
    columns of `examples`. The first token of every tokenized source and of
    every label sequence is overwritten with the id of that row's language
    code.

    Returns the tokenizer output with an extra 'forced_bos_token_id' entry
    holding the target-language token id of every row.
    """
    model_inputs = tokenizer(
        examples['source'],
        text_target=examples['target'],
        truncation=True,
        max_length=config['max_input_length'],
    )

    # For NLLB models, set the language code for the sources and targets.
    # The per-row lists are modified in place.
    forced_bos_token_ids = []
    rows = zip(
        model_inputs['input_ids'],
        model_inputs['labels'],
        examples['source.language'],
        examples['target.language'],
    )
    for input_ids, labels, source_language, target_language in rows:
        input_ids[0] = tokenizer.convert_tokens_to_ids(source_language)
        target_language_token = tokenizer.convert_tokens_to_ids(
            target_language)
        labels[0] = target_language_token
        forced_bos_token_ids.append(target_language_token)
    model_inputs['forced_bos_token_id'] = forced_bos_token_ids

    return model_inputs


# Build the training set first, then the validation set, each from its own
# section of the config.
train_dataset, eval_dataset = (
    salt.dataset.create(config['datasets'][split])
    for split in ('train', 'validation')
)

# Take a look at some of the data rows after shuffling and preprocessing
salt.utils.show_dataset(train_dataset, N=10)

# Tokenize both sets. The raw text columns are dropped from the training set
# only; the validation set keeps them.
train_dataset = train_dataset.map(
    preprocess,
    remove_columns=['source', 'source.language', 'target', 'target.language'],
    batched=True,
)
eval_dataset = eval_dataset.map(preprocess, batched=True)

# Use a SALT function which computes the evaluation score separately for
# different languages.
compute_metrics = salt.metrics.multilingual_eval_fn(
    eval_dataset,
    [evaluate.load('sacrebleu')],
    tokenizer,
    log_first_N_predictions=10,
)


In [None]:
# Start a Weights & Biases run and attach the parsed config to it, so the
# settings used for this run are recorded with it.
import wandb
wandb.init(project='translate-ksl-eng-swa', config=config)

In [None]:
# Swap the stock forced-BOS logits processor for the SALT variant.
# NOTE(review): presumably this lets each example force its own target-language
# token during generation — confirm against salt.utils.
transformers.generation.utils.ForcedBOSTokenLogitsProcessor = salt.utils.ForcedVariableBOSTokenLogitsProcessor

# Newer transformers releases take the tokenizer through the
# `processing_class` argument and deprecate `tokenizer`. This notebook installs
# the latest release, so pick whichever keyword the installed Trainer accepts.
import inspect

_trainer_parameters = inspect.signature(
    transformers.Seq2SeqTrainer.__init__).parameters
if 'processing_class' in _trainer_parameters:
    _tokenizer_keyword = 'processing_class'
else:
    _tokenizer_keyword = 'tokenizer'

trainer = transformers.Seq2SeqTrainer(
    model,
    training_settings,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[
        # Stop early once the evaluation metric has not improved for the
        # configured number of evaluations.
        transformers.EarlyStoppingCallback(
            early_stopping_patience=config['early_stopping_patience']),
    ],
    **{_tokenizer_keyword: tokenizer},
)

trainer.train()

In [None]:
# model.push_to_hub('your_repo_id')
# tokenizer.push_to_hub('your_repo_id')

In [None]:
# Put the stock forced-BOS logits processor back (the training cell replaced
# it with a SALT variant), since translate() below passes a single
# forced_bos_token_id per call.
transformers.generation.utils.ForcedBOSTokenLogitsProcessor = transformers.ForcedBOSTokenLogitsProcessor

def translate(text, source_language, target_language, *, max_length=100,
              num_beams=5):
    """Translate one piece of text with the fine-tuned model.

    Args:
        text: The text to translate. It is lower-cased before tokenization.
        source_language: Language code of `text`, e.g. 'eng', 'swa' or 'ksl'.
        target_language: Language code to translate into.
        max_length: Maximum length passed to `generate`.
        num_beams: Number of beams used for beam search.

    Returns:
        The decoded translation as a string. KSL output is upper-cased.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Training leaves the model in training mode, where dropout is active and
    # makes the output noisy. Switch to evaluation mode for inference.
    model.to(device).eval()

    # We trained the model above with only lower-case text, so make sure the
    # inputs here are lower-case too.
    inputs = tokenizer(text.lower(), return_tensors="pt").to(device)
    # Overwrite the leading language token with the requested source language.
    inputs['input_ids'][0][0] = tokenizer.convert_tokens_to_ids(
        source_language)

    translated_tokens = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_language),
        max_length=max_length,
        num_beams=num_beams,
    )

    result = tokenizer.batch_decode(
        translated_tokens, skip_special_tokens=True)[0]

    # Change KSL glosses to upper case
    if target_language == 'ksl':
        result = result.upper()

    return result

In [None]:
# Example: English -> KSL gloss.
translate('where is the nearest hospital', 'eng', 'ksl')

In [None]:
# Example: KSL gloss -> English. translate() lower-cases its input, so the
# gloss can be given in upper case.
translate('ME SCHOOL GO', 'ksl', 'eng')

In [None]:
# Example: Swahili -> KSL gloss, the direction called out in the introduction.
translate('nataka kununua tikiti', 'swa', 'ksl')

In [None]:
# Example: KSL gloss -> Swahili.
translate('ME SCHOOL GO', 'ksl', 'swa')