<a href="https://colab.research.google.com/github/SunbirdAI/salt/blob/main/notebooks/DSA_2024_Tutorial_on_speech_recognition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
!pip install -q jiwer
!pip install -q evaluate
!pip install -qU accelerate
!pip install -q transformers[torch]
!git clone https://github.com/sunbirdai/salt.git
!pip install -qr salt/requirements.txt

In [None]:
import salt.dataset
import salt.utils
import salt.metrics
import yaml
import transformers
from IPython import display
import torch
import evaluate

We start with some of Meta's models: [MMS](https://huggingface.co/docs/transformers/en/model_doc/mms) for speech recognition and generation, and [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb) for translation. These support several African languages, so they make a useful starting point. However, training data was limited for some languages, so the models need some refinement. We'll look at how to add a new language that the model doesn't know about at all, and also how to improve performance for a language that the model supports, but not very well.

First of all, let's run through an example: how to fine-tune an English speech recognition model to work better with a specific accent, in this case Ugandan. If you want to try this for another language, select one below (`lug`=Luganda, `ach`=Acholi, `teo`=Ateso, `nyn`=Runyankole, `lgg`=Lugbara).

We'll start by loading some evaluation data.

In [None]:
language = 'eng' #@param ["eng", "lug", "ach", "nyn", "teo", "lgg"]

In [None]:
validation_dataset_config = f'''
huggingface_load:
  path: Sunbird/salt
  split: dev
  name: multispeaker-{language}
source:
  type: speech
  language: {language}
  preprocessing:
    - set_sample_rate:
        rate: 16_000
target:
  type: text
  language: {language}
  preprocessing:
    - lower_case
    - clean_and_remove_punctuation:
        allowed_punctuation: "'"
shuffle: True
'''

config = yaml.safe_load(validation_dataset_config)
ds_validation = salt.dataset.create(config)
salt.utils.show_dataset(ds_validation, N=5, audio_features=['source'])

Next, load a base model (Meta MMS) and see what output it gives for one of these audio samples.

In [None]:
pretrained_model = 'facebook/mms-1b-all'

processor = transformers.Wav2Vec2Processor.from_pretrained(pretrained_model)
processor.tokenizer.set_target_lang(language)
data_collator = salt.utils.DataCollatorCTCWithPadding(
    processor=processor, padding=True)
# Is there a GPU?
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = transformers.Wav2Vec2ForCTC.from_pretrained(pretrained_model).to(device)

model.load_adapter(language)

Get a single example from the validation set

In [None]:
example = next(iter(ds_validation))

Hear the audio and see what the correct text should be

In [None]:
display.display(display.Audio(data=example['source'], rate=16000))
print('Correct text: ' + example['target'])

What does the model say?

In [None]:
inputs = processor(example['source'], sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs.to(device)).logits

ids = torch.argmax(outputs, dim=-1)[0]
transcription = processor.decode(ids)

print('Model prediction: ' + transcription)
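
The `argmax` followed by `processor.decode` above amounts to greedy CTC decoding: take the most likely token for each audio frame, merge consecutive repeats, then drop the blank token. Below is a minimal pure-Python sketch of that collapse step. The toy vocabulary and the blank ID of 0 are assumptions for illustration, not the actual values used by the MMS tokenizer.

```python
def ctc_greedy_collapse(frame_ids, id_to_char, blank_id=0):
    """Merge consecutive repeated IDs, then drop blanks."""
    chars = []
    previous = None
    for token_id in frame_ids:
        # Emit a character only when the ID changes and is not the blank.
        if token_id != previous and token_id != blank_id:
            chars.append(id_to_char[token_id])
        previous = token_id
    return "".join(chars)


# Hypothetical toy vocabulary; "|" marks a word boundary.
toy_vocab = {1: "h", 2: "i", 3: "|", 4: "o"}
# One predicted ID per audio frame (0 is the assumed blank).
frames = [1, 1, 0, 2, 2, 3, 0, 4, 0, 4]

decoded = ctc_greedy_collapse(frames, toy_vocab).replace("|", " ")
print(decoded)  # hi oo
```

The blank between the two final `4`s is what allows a genuinely doubled letter to survive the merge step.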

Now we'll evaluate the base model on this data, to see how well it does on Ugandan English.

In [None]:
def prepare_dataset(batch):
    batch["input_values"] = processor(
        batch["source"], sampling_rate=16000
    ).input_values
    batch["labels"] = processor(text=batch["target"]).input_ids
    return batch

validation_data_tokenised = ds_validation.map(
    prepare_dataset,
    batch_size=4,
    batched=True,
)

compute_metrics = salt.metrics.multilingual_eval_fn(
      ds_validation, [evaluate.load('wer'), evaluate.load('cer')],
      processor.tokenizer, log_first_N_predictions=2,
      speech_processor=processor)

In [None]:
transformers.Trainer(
    model=model,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    eval_dataset=validation_data_tokenised,
    tokenizer=processor.feature_extractor,
).evaluate()
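
The two metrics loaded above are word error rate (WER) and character error rate (CER). Both are an edit distance divided by the length of the reference, counted in words or in characters. The sketch below shows the per-sentence definition in plain Python; it is an illustration, not the implementation used by `evaluate` or `jiwer`, and the function names are made up for this example.

```python
def edit_distance(reference, hypothesis):
    """Levenshtein distance between two sequences (lists or strings)."""
    previous_row = list(range(len(hypothesis) + 1))
    for i, ref_item in enumerate(reference, start=1):
        current_row = [i]
        for j, hyp_item in enumerate(hypothesis, start=1):
            substitution = previous_row[j - 1] + (ref_item != hyp_item)
            deletion = previous_row[j] + 1
            insertion = current_row[j - 1] + 1
            current_row.append(min(substitution, deletion, insertion))
        previous_row = current_row
    return previous_row[-1]


def word_error_rate(reference, hypothesis):
    """Word-level edits divided by the number of reference words."""
    ref_words = reference.split()
    return edit_distance(ref_words, hypothesis.split()) / len(ref_words)


def character_error_rate(reference, hypothesis):
    """Character-level edits divided by the number of reference characters."""
    return edit_distance(reference, hypothesis) / len(reference)


reference = "the cat sat"
hypothesis = "the cat sit"
wer = word_error_rate(reference, hypothesis)       # 1 wrong word out of 3
cer = character_error_rate(reference, hypothesis)  # 1 wrong character out of 11
print(f"WER={wer:.3f} CER={cer:.3f}")
```

A single wrong letter costs a whole word in WER but only one character in CER, which is why CER is usually the lower of the two.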

Next, fine-tune the model with training data that is more representative. Note that we'll augment the training data by adding some noise. This makes the task a little more difficult and adds some extra variation, which in practice makes the model more robust to audio samples with background noise. It's also possible to augment the speed and pitch, for example.

In [None]:
train_dataset_config = f'''
huggingface_load:
  path: Sunbird/salt
  split: train
  name: multispeaker-{language}
source:
  type: speech
  language: {language}
  preprocessing:
    - set_sample_rate:
        rate: 16_000
    - augment_audio_noise:
        max_relative_amplitude: 0.5
target:
  type: text
  language: {language}
  preprocessing:
    - lower_case
    - clean_and_remove_punctuation:
        allowed_punctuation: "'"
shuffle: True
'''

config = yaml.safe_load(train_dataset_config)
ds_train = salt.dataset.create(config)
salt.utils.show_dataset(ds_train, N=5, audio_features=['source'])
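
To give a feel for what noise augmentation does to a waveform, here is a small standard-library sketch. It assumes that `max_relative_amplitude` caps the noise level as a fraction of the clip's peak amplitude; that reading of the parameter is a guess, and the hypothetical `add_noise` function below is not the SALT implementation.

```python
import random


def add_noise(samples, max_relative_amplitude=0.5, rng=None):
    """Add uniform noise scaled relative to the clip's peak (assumed semantics)."""
    if not samples:
        return []
    rng = rng or random.Random()
    peak = max(abs(s) for s in samples)
    # Draw one noise level per clip, anywhere from silent up to the maximum.
    noise_scale = rng.uniform(0.0, max_relative_amplitude) * peak
    return [s + noise_scale * rng.uniform(-1.0, 1.0) for s in samples]


clean = [0.0, 0.5, -1.0, 0.25]
noisy = add_noise(clean, max_relative_amplitude=0.5, rng=random.Random(0))
print(noisy)
```

Because the noise level is drawn afresh for each clip, some training examples stay almost clean while others are quite noisy, which is the variation the model benefits from.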

Start training. We'll just run a few training steps here to see the process, but leaving it to run for longer usually results in a better model.

In [None]:
training_args = yaml.safe_load('''
    output_dir: stt
    per_device_train_batch_size: 8
    gradient_accumulation_steps: 4
    evaluation_strategy: steps
    max_steps: 100
    gradient_checkpointing: True
    gradient_checkpointing_kwargs:
      use_reentrant: True
    fp16: True
    save_steps: 100
    eval_steps: 20
    logging_steps: 100
    learning_rate: 3.0e-4
    warmup_steps: 100
    save_total_limit: 2
    push_to_hub: False
    load_best_model_at_end: True
    metric_for_best_model: loss
    greater_is_better: False
    weight_decay: 0.01
''')

train_data_tokenised = ds_train.map(
    prepare_dataset,
    batch_size=4,
    batched=True,
)

# We don't train all of the model, only the language-specific adapter layers.
model.freeze_base_model()
adapter_weights = model._get_adapters()
for param in adapter_weights.values():
    param.requires_grad = True

# Set up the trainer and get started.
trainer = transformers.Trainer(
    model=model,
    data_collator=data_collator,
    args=transformers.TrainingArguments(**training_args, report_to="none"),
    compute_metrics=compute_metrics,
    train_dataset=train_data_tokenised,
    eval_dataset=validation_data_tokenised,
    tokenizer=processor.feature_extractor,
)

trainer.train()
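
It can help to work out what the training arguments above imply. The sketch below computes the effective batch size and the learning rate during warmup, assuming a single GPU and the linear warmup used by the `Trainer`'s default schedule; the decay phase after warmup is left out.

```python
# Values copied from the training arguments above.
per_device_train_batch_size = 8
gradient_accumulation_steps = 4
max_steps = 100
warmup_steps = 100
peak_learning_rate = 3.0e-4
num_devices = 1  # assumption: a single Colab GPU

# Gradients are accumulated over several small batches before each update.
effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * num_devices
)
examples_seen = effective_batch_size * max_steps


def warmup_learning_rate(step):
    """Linear warmup from 0 to the peak rate (decay after warmup omitted)."""
    return peak_learning_rate * min(1.0, step / warmup_steps)


print(effective_batch_size, examples_seen)  # 32 3200
print(warmup_learning_rate(50))
```

Since `warmup_steps` equals `max_steps` here, the learning rate is still ramping up for the whole of this short run. For a longer run you would normally raise `max_steps` well above the warmup length.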

Let's take a look at what the performance on the validation set is now, at whatever point training was stopped.

In [None]:
metrics = trainer.evaluate()

# Multilingual speech recognition

We can also train a single model to recognise more than one language. This helps with code-switching, for example, where someone speaks mainly in one language but uses some terms from another. Here's an example of how we can create Luganda + English training data.

In [None]:
multilingual_dataset_config = '''
huggingface_load:
  - path: Sunbird/salt
    split: train
    name: multispeaker-eng
  - path: Sunbird/salt
    split: train
    name: multispeaker-lug
source:
  type: speech
  language: [eng, lug]
  preprocessing:
    - set_sample_rate:
        rate: 16_000
target:
  type: text
  language: [eng, lug]
  preprocessing:
    - lower_case
    - clean_and_remove_punctuation:
        allowed_punctuation: "'"
shuffle: True
'''

config = yaml.safe_load(multilingual_dataset_config)
ds_multilingual = salt.dataset.create(config)
salt.utils.show_dataset(ds_multilingual, N=5, audio_features=['source'])
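
If you want to try other language combinations, the same configuration can be built as a Python dictionary instead of a YAML string, since `yaml.safe_load` produces a dictionary anyway. The helper below is a hypothetical convenience function, not part of SALT, and it assumes that a `multispeaker-<code>` subset exists for each language you pass in.

```python
def make_multilingual_config(languages, split="train"):
    """Build a dict shaped like the parsed YAML config above."""
    text_cleaning = [
        "lower_case",
        {"clean_and_remove_punctuation": {"allowed_punctuation": "'"}},
    ]
    return {
        # One dataset entry per language, all from the same repository.
        "huggingface_load": [
            {"path": "Sunbird/salt", "split": split, "name": f"multispeaker-{lang}"}
            for lang in languages
        ],
        "source": {
            "type": "speech",
            "language": list(languages),
            "preprocessing": [{"set_sample_rate": {"rate": 16_000}}],
        },
        "target": {
            "type": "text",
            "language": list(languages),
            "preprocessing": text_cleaning,
        },
        "shuffle": True,
    }


config = make_multilingual_config(["eng", "lug", "ach"])
print(config["huggingface_load"])
```

The resulting dictionary can be passed to `salt.dataset.create(config)` in the same way as the parsed YAML.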

## Your turn!

# Speech recognition data collection

Select a language you'd like to work on, and get the ISO 639-3 code [here](https://iso639-3.sil.org/code_tables/639/data). For example, Swahili is `swh` and Luganda is `lug`. Then we can form some groups in the classroom so that everyone interested in a particular language can work together. The more voices that are used to train a model, the better it will work.

In [None]:
language_code = "eng" #@param {type:"string"}
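
A quick format check can catch typos in the code before it ends up in a dataset config. The sketch below only tests the shape of an ISO 639-3 code (three lowercase letters); it does not check that the code is actually assigned to a language, and the function name is made up for this example.

```python
import re


def looks_like_iso639_3(code):
    """True if the code is exactly three lowercase ASCII letters."""
    return re.fullmatch(r"[a-z]{3}", code) is not None


for candidate in ["swh", "lug", "sw", "LUG"]:
    print(candidate, looks_like_iso639_3(candidate))
```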

We need to create some phrases in your language of interest and then make some recordings of what they sound like when spoken.

For our simple example, make a copy of [this spreadsheet](https://docs.google.com/spreadsheets/d/1w6TbJsv1gTZmPI8kkZ0Wr_yKpRjBHqRNJsKYROx4zQQ/edit#gid=1347849995), which has English phrases. Translate some of the phrases to your language of interest. These are the sentences to be read out and recorded.

Select some rows from the translation spreadsheet, and paste them into [this tool](https://sunbirdai.github.io/dsa2024-speech-data-recording/).
Download the resulting files. These can be used to create a HuggingFace dataset, using [this notebook](https://colab.research.google.com/drive/1UuacvElXeS58GGw_-CfUXj2KuqJHqCbM#scrollTo=hzSVaYc3mYrJ) as an example.

In [None]:
your_huggingface_repo = '' # e.g. yourusername/datasetname

# Load the new data with this config
eval_dataset_config = f'''
huggingface_load:
  path: {your_huggingface_repo}
  split: train
source:
  type: speech
  language: {language_code}
  preprocessing:
    - set_sample_rate:
        rate: 16_000
target:
  type: text
  language: {language_code}
  preprocessing:
    - lower_case
    - clean_and_remove_punctuation:
        allowed_punctuation: "'"
shuffle: True
'''

config = yaml.safe_load(eval_dataset_config)
ds_eval = salt.dataset.create(config)
salt.utils.show_dataset(ds_eval, N=5, audio_features=['source'])

# Pointers on translation

We'll also give some tips here on how text-to-text translation models can be trained. The output of a speech recognition model can be fed into a translation model, creating some interesting application possibilities.

*   Make a copy of [this](https://docs.google.com/spreadsheets/d/1FNGvg_IkUNvRbK8_6XFcmI4J5DxFHYEmS7d4KF0R1YI/edit#gid=1347849995) spreadsheet. Notice that there are two tabs: train and test.
*   Add a new column `text_[languagecode]`, where `languagecode` is the ISO 639-3 code you found above.
*   Download the `train` and `test` tabs as CSV files.
*   These can be uploaded to HuggingFace.
*   [Reference code](https://github.com/SunbirdAI/salt/blob/main/notebooks/NLLB_training.ipynb) for training a translation model.
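
If you would rather add the new column after downloading a CSV than in the spreadsheet itself, the standard `csv` module is enough. This is a minimal sketch: the `text_eng` column name and the placeholder translations are assumptions for illustration, so check them against the actual spreadsheet.

```python
import csv
import io


def add_translation_column(csv_text, language_code, translations):
    """Return the CSV text with a new text_<language_code> column appended."""
    reader = csv.DictReader(io.StringIO(csv_text))
    rows = list(reader)
    if len(rows) != len(translations):
        raise ValueError("Need exactly one translation per row.")
    new_column = f"text_{language_code}"
    output = io.StringIO()
    writer = csv.DictWriter(
        output, fieldnames=list(reader.fieldnames) + [new_column], lineterminator="\n"
    )
    writer.writeheader()
    for row, translation in zip(rows, translations):
        row[new_column] = translation
        writer.writerow(row)
    return output.getvalue()


# Hypothetical example data; replace the placeholders with real translations.
example_csv = "text_eng\nGood morning\nThank you\n"
updated = add_translation_column(
    example_csv, "lug", ["<translation 1>", "<translation 2>"]
)
print(updated)
```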

# Pointers on speech generation (text-to-speech)

Text-to-speech models are trained using data very similar to that used for speech recognition. However, the recordings need to be made under more controlled conditions. Ideally this is done in a studio, so that the sound quality is good and free of background noise, with sentences spoken by someone who is trained as a presenter or voice actor.

As above, the Meta MMS models do support text-to-speech for several African languages, though quality varies depending on the language and some need to be retrained for practical usage.

# Pan-African models 🌍

All of this data can be joined up, so that we train single models which can understand and translate between many African languages.

Add your HuggingFace repository IDs in the Slack channel.