<a href="https://colab.research.google.com/github/google/project-euphonia-app/blob/main/training_colabs/Project_Euphonia_Finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-Tuning Open-Source ASR Models for the Project Euphonia App

This notebook demonstrates how to fine-tune OpenAI's Whisper model for automatic speech recognition (ASR) on recordings that the Project Euphonia App stores in Firebase Storage.

This notebook is meant to run in Google Colab. To run it as a regular Jupyter notebook, replace the Colab-specific access to Google Cloud / Firebase Storage (for example, download the recordings manually with `gsutil`).

**Note:** This setup provides a minimal training configuration, primarily for end-to-end demonstration. Further hyperparameter tuning is recommended for improved results.

## Prerequisites

- A Google Cloud project with Firebase Storage containing the audio recordings and their transcriptions.
- A Google Colab environment with a GPU (an L4 is sufficient for the smaller models; an A100 is needed for the large models).
- Audio recordings collected with the Euphonia App.

## Authenticate access to Google Cloud Storage

This cell authenticates your Google Colab session with your Google Cloud account, allowing access to Google Cloud Storage.

In [None]:
from google.colab import auth
auth.authenticate_user()

## Imports

This cell installs or upgrades the required Python libraries: `datasets` (with audio support), `transformers`, `accelerate`, `evaluate`, `jiwer`, and `tensorboard`.

In [None]:
!pip install --upgrade --quiet pip
!pip install --upgrade --quiet datasets[audio] transformers accelerate evaluate jiwer tensorboard

This cell imports the required Python modules.

In [None]:
import ipywidgets as widgets

import os
import csv
import shutil
import soundfile as sf
import numpy as np

from dataclasses import dataclass
from typing import Any, Dict, List, Union

from datasets import load_dataset
from evaluate import load as metrics_loader
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
from transformers import WhisperForConditionalGeneration
from transformers import WhisperProcessor
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

import torch

## Function definition

This section defines the Word Error Rate (WER) calculation function, which is used to evaluate the model's performance.

In [None]:
wer_metric = metrics_loader("wer")
transcript_normalizer = BasicTextNormalizer()

def get_wer(references, predictions, normalize=True, verbose=True):
  rs = references
  ps = predictions
  if normalize:
    ps = [transcript_normalizer(x) for x in predictions]
    rs = [transcript_normalizer(x) for x in references]
  if verbose:
    for r, p in zip(rs, ps):
      print(r)
      print(p)
      print()

  return wer_metric.compute(references=rs, predictions=ps)
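The `wer` metric loaded above (backed by `jiwer`) counts word substitutions S, deletions D and insertions I against the N reference words, WER = (S + D + I) / N, summed over all utterances rather than averaged per utterance, so it can exceed 1 when the model inserts many words. The following pure-Python sketch re-implements that definition with a word-level edit distance to show what the number means; `word_edit_distance` and `corpus_wer` are illustrative names, not part of `evaluate` or `jiwer`.

```python
from typing import List, Sequence


def word_edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Minimum number of word substitutions, deletions and insertions turning ref into hyp."""
    # prev[j] holds the distance between the reference words seen so far (minus one) and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, hyp_word in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,                           # delete ref_word
                curr[j - 1] + 1,                       # insert hyp_word
                prev[j - 1] + (ref_word != hyp_word),  # match (0) or substitution (1)
            )
        prev = curr
    return prev[-1]


def corpus_wer(references: List[str], predictions: List[str]) -> float:
    """Total word edits across all pairs divided by the total number of reference words."""
    errors = sum(word_edit_distance(r.split(), p.split()) for r, p in zip(references, predictions))
    total_words = sum(len(r.split()) for r in references)
    if total_words == 0:
        raise ValueError("references contain no words")
    return errors / total_words


# one substitution (sat -> sit) and one deletion (the) over 6 reference words
print(corpus_wer(["the cat sat on the mat"], ["the cat sit on mat"]))  # 0.3333333333333333
```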

This function counts the number of trainable parameters in the model.

In [None]:
def count_trainable_parameters(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    return params

## Prepare data

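This cell downloads the audio recordings and transcriptions from your Firebase Storage bucket to the Colab environment. Replace `<YOUR FIREBASE PROJECT NAME>` with your Firebase project ID, i.e. the ID you chose when you created your Firebase project. Newer Firebase projects may use a default bucket named `<project-id>.firebasestorage.app` instead of `<project-id>.appspot.com`; adjust the `gs://` path if the copy fails.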

In [None]:
AUDIO_DATA_DIR = '/content/asr_data'
FIREBASE_PROJECT = '<YOUR FIREBASE PROJECT NAME>'
!mkdir -p {AUDIO_DATA_DIR}

# copy everything under data/ from the Firebase Storage bucket
!gsutil -m cp -r gs://{FIREBASE_PROJECT}.appspot.com/data/ {AUDIO_DATA_DIR}
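The dataset cell below assumes the download produced `data/0/`, `data/1/`, ... with no gaps, each folder holding `recording.wav` and `phrase.txt`, because it derives the indices from the number of entries in `data/`. A quick check like the following sketch can catch missing or deleted recordings before the split fails with a `FileNotFoundError`; `find_incomplete_recordings` is a hypothetical helper, not part of the Euphonia App.

```python
import os
from typing import List


def find_incomplete_recordings(data_dir: str) -> List[str]:
    """Report problems in a data/<index>/ layout (hypothetical helper).

    Assumes every recording lives in data_dir/<i>/ with 'recording.wav' and
    'phrase.txt', and that indices run contiguously from 0 to N-1, where N is
    the number of entries in data_dir.
    """
    problems = []
    num_entries = len(os.listdir(data_dir))
    for i in range(num_entries):
        folder = os.path.join(data_dir, str(i))
        if not os.path.isdir(folder):
            problems.append(f"missing folder: {folder}")
            continue
        for name in ("recording.wav", "phrase.txt"):
            path = os.path.join(folder, name)
            if not os.path.isfile(path):
                problems.append(f"missing file: {path}")
    return problems
```

In the notebook it could be called as `print(find_incomplete_recordings(os.path.join(AUDIO_DATA_DIR, 'data')))`; an empty list means the layout matches what the next cell expects.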

## Prepare Dataset

This cell splits the recordings into training, test, and development (dev) sets, writes a `metadata.csv` file for each split, and loads the result into a [Hugging Face Dataset](https://huggingface.co/blog/audio-datasets) object. The loader maps the `dev` folder to the `validation` split, which is the name used in the cells below.
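The `audiofolder` loader expects each split folder to contain the audio files plus a `metadata.csv` whose `file_name` column gives each file's path relative to that CSV; additional columns, here `transcription`, become dataset features. The sketch below writes one split's metadata with context managers and the `csv` module, so that commas or quotes inside a transcript are escaped correctly. `write_split_metadata` is a hypothetical helper; it does not copy the audio files, and stripping surrounding whitespace from transcripts is an assumption.

```python
import csv
import os
from typing import Iterable, Tuple


def write_split_metadata(split_dir: str, rows: Iterable[Tuple[str, str]]) -> str:
    """Write <split_dir>/metadata.csv with a header and one (file_name, transcription) row per clip."""
    os.makedirs(split_dir, exist_ok=True)
    metadata_path = os.path.join(split_dir, "metadata.csv")
    with open(metadata_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["file_name", "transcription"])
        for file_name, transcription in rows:
            # file_name is relative to split_dir, e.g. "recording_0.wav"
            writer.writerow([file_name, transcription.strip()])
    return metadata_path
```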

In [None]:
#@title Create dataset
AUDIO_FOLDER_DIR = '/content/audio_folder'
!mkdir -p {AUDIO_FOLDER_DIR}

DEV_METADATA_FILE = os.path.join(AUDIO_FOLDER_DIR, 'dev', 'metadata.csv')
TEST_METADATA_FILE = os.path.join(AUDIO_FOLDER_DIR, 'test', 'metadata.csv')
TRAIN_METADATA_FILE = os.path.join(AUDIO_FOLDER_DIR, 'train', 'metadata.csv')

csv_file_map = {
    'dev': DEV_METADATA_FILE,
    'test': TEST_METADATA_FILE,
    'train': TRAIN_METADATA_FILE
}

# fraction of recordings used for train and test; the remainder goes to dev
#@markdown Train and test together may not exceed 90% of the data, so that at least 10% is left for the dev set.
TRAIN_PORTION = 0.8 #@param{type: 'number'}
TEST_PORTION = 0.1 #@param{type: 'number'}

verbose = False #@param{type: 'boolean'}

assert (TRAIN_PORTION + TEST_PORTION <= 0.9)


# get sizes
num_audios = len(os.listdir(os.path.join(AUDIO_DATA_DIR, 'data')))
train_offset = int(num_audios * TRAIN_PORTION)
test_offset = train_offset + int(num_audios * TEST_PORTION)

# copy audios and create metadata
!mkdir -p {AUDIO_FOLDER_DIR}/train
!mkdir -p {AUDIO_FOLDER_DIR}/dev
!mkdir -p {AUDIO_FOLDER_DIR}/test


# write the header row of each split's metadata.csv
for split, metadata_file in csv_file_map.items():
  with open(metadata_file, 'w', newline='', encoding='utf-8') as f:
    csv.writer(f).writerow(['file_name', 'transcription'])

for i in range(num_audios):
  # assign recordings to splits by index order
  if i < train_offset:
    current_split = 'train'
  elif i < test_offset:
    current_split = 'test'
  else:
    current_split = 'dev'

  orig_audio_file = os.path.join(AUDIO_DATA_DIR, 'data', str(i), 'recording.wav')
  transcript_file = os.path.join(AUDIO_DATA_DIR, 'data', str(i), 'phrase.txt')
  # file name as written to metadata.csv (relative to the split folder)
  relative_target_audio_file = f'recording_{i}.wav'
  target_audio_file = os.path.join(AUDIO_FOLDER_DIR, current_split, relative_target_audio_file)
  with open(transcript_file, 'r', encoding='utf-8') as tf:
    # strip surrounding whitespace such as a trailing newline
    transcript = tf.read().strip()

  if verbose:
    print(orig_audio_file + ' --> ' + target_audio_file + '\t' + relative_target_audio_file + '\t' + transcript)

  shutil.copyfile(orig_audio_file, target_audio_file)
  with open(csv_file_map[current_split], 'a', newline='', encoding='utf-8') as f:
    csv.writer(f).writerow([relative_target_audio_file, transcript])

# create huggingface dataset
my_audio_dataset = load_dataset("audiofolder", data_dir=AUDIO_FOLDER_DIR, streaming=False)
print(my_audio_dataset)
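The cell above assigns recordings to splits by index order: the first `TRAIN_PORTION` of indices go to train, the next `TEST_PORTION` to test, and the remainder to dev. If recordings were made in sessions, neighbouring indices may be correlated, and a seeded shuffle before splitting is one common alternative. The sketch below reproduces the same offset arithmetic and adds that shuffle as an option; `split_indices` and its `seed` parameter are hypothetical and not used by the notebook.

```python
import random
from typing import Dict, List, Optional


def split_indices(num_audios: int, train_portion: float, test_portion: float,
                  seed: Optional[int] = None) -> Dict[str, List[int]]:
    """Assign recording indices to train/test/dev using the notebook's offset arithmetic."""
    indices = list(range(num_audios))
    if seed is not None:
        # optional: shuffle reproducibly so the splits don't follow recording order
        random.Random(seed).shuffle(indices)
    train_offset = int(num_audios * train_portion)
    test_offset = train_offset + int(num_audios * test_portion)
    return {
        "train": indices[:train_offset],
        "test": indices[train_offset:test_offset],
        "dev": indices[test_offset:],
    }


# 25 recordings with 0.8 / 0.1: int(20.0) = 20 train, int(2.5) = 2 test, 3 dev
print({name: len(idx) for name, idx in split_indices(25, 0.8, 0.1).items()})  # {'train': 20, 'test': 2, 'dev': 3}
```

Because `int()` truncates, the dev split absorbs the rounding remainder, so it can end up slightly larger than the nominal 10%.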

In [None]:
# we can now inspect examples in the dataset
my_audio_dataset['train'][0]

## Model Training
This cell uses an interactive dropdown widget to select the language for the ASR model.

In [None]:
#@title Select ASR model language
#@markdown Run this cell to get the list of languages supported by the model, then select the language.
#@markdown Don't re-run the cell afterwards, as this resets the language selection.
from transformers.models.whisper import tokenization_whisper

languages_list = list(tokenization_whisper.TO_LANGUAGE_CODE.values())
language_picker = widgets.Dropdown(options=languages_list, value='en')
print('Language:')
language_picker

In [None]:
WHISPER_MODEL_TYPE = "openai/whisper-small" #@param['openai/whisper-tiny.en', 'openai/whisper-tiny', 'openai/whisper-base', 'openai/whisper-small', 'openai/whisper-medium', 'openai/whisper-large', 'openai/whisper-large-v3', 'openai/whisper-large-v3-turbo']{}
LANGUAGE = language_picker.value
TASK = "transcribe"

print('Using Language: ', LANGUAGE)
print('Using model:', WHISPER_MODEL_TYPE)
processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_TYPE, language=LANGUAGE, task=TASK)
base_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_TYPE)

## Extract features from the dataset

In [None]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device is: ', device)

# limit PyTorch to one thread per process so the parallel map below doesn't oversubscribe the CPUs
torch.set_num_threads(1)
print('# torch threads:', torch.get_num_threads())
num_proc = os.cpu_count()
print('# processors:', num_proc)


In [None]:
%%time
def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_features"] = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    batch["labels"] = processor.tokenizer(batch["transcription"]).input_ids
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]
    return batch


my_audio_dataset = my_audio_dataset.map(prepare_dataset,
                                        writer_batch_size=1,
                                        num_proc=num_proc,
                                        )
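Whisper's feature extractor works on fixed 30-second windows: each waveform is zero-padded or truncated to 480,000 samples at 16 kHz and converted to a log-mel spectrogram with a 160-sample (10 ms) hop, giving 3,000 frames (80 mel bins for most checkpoints, 128 for `whisper-large-v3` and `whisper-large-v3-turbo`). Two practical consequences: recordings longer than 30 s are cut off, and the extractor rejects audio whose sampling rate is not 16 kHz. If the Euphonia recordings use another rate, casting the column with `my_audio_dataset.cast_column("audio", datasets.Audio(sampling_rate=16000))` before the `map` call resamples them on load. The pure-Python sketch below only illustrates the padding arithmetic; it is not the library's implementation.

```python
from typing import List

SAMPLING_RATE = 16_000  # Whisper expects 16 kHz audio
CHUNK_SECONDS = 30      # every input is padded or trimmed to 30 s
HOP_LENGTH = 160        # 10 ms hop between log-mel frames at 16 kHz


def pad_or_trim(samples: List[float], target_len: int = SAMPLING_RATE * CHUNK_SECONDS) -> List[float]:
    """Zero-pad or truncate a waveform to exactly target_len samples."""
    if len(samples) >= target_len:
        return samples[:target_len]
    return samples + [0.0] * (target_len - len(samples))


def num_feature_frames(target_len: int = SAMPLING_RATE * CHUNK_SECONDS,
                       hop_length: int = HOP_LENGTH) -> int:
    """Number of log-mel frames produced for one padded window."""
    return target_len // hop_length


print(num_feature_frames())  # 3000
```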

## Configure Training

In [None]:
#@title Training Hyper Parameters
OUTPUT_DIR = '/content/whisper_tuning' #@param
LOG_DIR = os.path.join(OUTPUT_DIR, 'logs')

LEARNING_RATE = 1e-5 #@param
BATCH_SIZE = 8 #@param
MAX_EPOCHS = 10 #@param
WARMUP_STEPS = 10 #@param
# maximum number of tokens generated per example during evaluation; keep it as short as your transcripts allow
MAX_GEN_LEN = 32 #@param
# if save steps is 0, only last and best model will be written
SAVE_STEPS = 0 #@param

# see
# https://huggingface.co/docs/transformers/v4.46.2/en/main_classes/trainer#transformers.TrainingArguments
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    logging_dir=LOG_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    fp16=True,
    num_train_epochs=MAX_EPOCHS,
    #
    lr_scheduler_type='constant_with_warmup',
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    #
    eval_strategy="steps",  # named evaluation_strategy in older transformers releases
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=MAX_GEN_LEN,
    eval_steps=5,
    metric_for_best_model="wer",
    greater_is_better=False,
    #
    save_steps=SAVE_STEPS,
    logging_steps=1,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    #
    push_to_hub=False,
    remove_unused_columns=False,
    eval_on_start=True,
)
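With a small personal dataset the whole run can be quite short, so `eval_steps=5` and `WARMUP_STEPS=10` are large relative to training, and each evaluation generates transcripts for the entire dev set. The sketch below estimates the number of optimizer steps and evaluation passes, assuming one GPU and `gradient_accumulation_steps=1` as configured above, and reproduces the multiplier of the `constant_with_warmup` schedule: a linear ramp from 0 to `LEARNING_RATE` over the warmup steps, then constant. Both function names are hypothetical, and the step count is an estimate rather than the Trainer's exact bookkeeping.

```python
import math
from typing import Tuple


def training_schedule(num_train: int, batch_size: int, epochs: int, eval_steps: int,
                      eval_on_start: bool = True) -> Tuple[int, int]:
    """Estimate (optimizer steps, evaluation passes) for one GPU without gradient accumulation."""
    steps_per_epoch = math.ceil(num_train / batch_size)  # the last partial batch is still a step
    total_steps = steps_per_epoch * epochs
    num_evals = total_steps // eval_steps + (1 if eval_on_start else 0)
    return total_steps, num_evals


def constant_with_warmup_lr(step: int, base_lr: float, warmup_steps: int) -> float:
    """Learning rate at a given step: linear warmup, then constant."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr


# a hypothetical 80 training clips, batch size 8, 10 epochs, eval every 5 steps
print(training_schedule(80, 8, 10, 5))  # (100, 21)
```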

In [None]:
#@title Ensure we set the language for training

base_model.generation_config.language = LANGUAGE
base_model.generation_config.task = TASK
base_model.generation_config.forced_decoder_ids = None
base_model.config.forced_decoder_ids = None

# disable the KV cache, which is incompatible with gradient checkpointing during training
base_model.config.use_cache = False

print('language set to:', base_model.generation_config.language)

In [None]:
#@title define which parameters to update
#@markdown For personalization, we typically only update the encoder and the output projection layer.
#@markdown Updating the decoder layer may lead to overfitting.
UPDATE_ENCODER = True #@param{type: 'boolean'}
UPDATE_DECODER = False #@param{type: 'boolean'}
UPDATE_PROJ = True #@param{type: 'boolean'}
base_model.model.encoder.requires_grad_(UPDATE_ENCODER)
base_model.model.decoder.requires_grad_(UPDATE_DECODER)
base_model.proj_out.requires_grad_(UPDATE_PROJ)


print('encoder params to update/total:', count_trainable_parameters(base_model.model.encoder), base_model.model.encoder.num_parameters())
print('decoder params to update/total:', count_trainable_parameters(base_model.model.decoder), base_model.model.decoder.num_parameters())

print('overall # trainable parameters:', count_trainable_parameters(base_model))
print('.   overall # model parameters:', base_model.model.num_parameters())

In [None]:
#@title Define Trainer
import evaluate
metric = evaluate.load("wer")
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # decode token IDs to text, dropping special tokens such as language and task markers
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=base_model.config.decoder_start_token_id,
)



trainer = Seq2SeqTrainer(
    args=training_args,
    model=base_model,
    train_dataset=my_audio_dataset["train"],
    eval_dataset=my_audio_dataset["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor.feature_extractor,  # replaces the deprecated tokenizer= argument
)
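The data collator pads the label token IDs, replaces the padding with `-100` so the loss ignores it, and drops the decoder start token when every sequence begins with it, because `WhisperForConditionalGeneration` prepends that token itself when it builds decoder inputs from `labels`. `compute_metrics` reverses the masking before decoding. The list-based sketch below mirrors that logic without PyTorch; `pad_and_mask_labels` and `unmask_labels` are hypothetical helpers, and the token IDs in the example are made up.

```python
from typing import List

IGNORE_INDEX = -100  # label value skipped by the cross-entropy loss


def pad_and_mask_labels(label_seqs: List[List[int]], decoder_start_id: int) -> List[List[int]]:
    """Right-pad label sequences with -100 and drop a leading start token shared by all of them."""
    max_len = max(len(seq) for seq in label_seqs)
    labels = [seq + [IGNORE_INDEX] * (max_len - len(seq)) for seq in label_seqs]
    # the model adds the decoder start token itself, so remove it if every sequence has it
    if max_len > 0 and all(seq[0] == decoder_start_id for seq in labels):
        labels = [seq[1:] for seq in labels]
    return labels


def unmask_labels(labels: List[List[int]], pad_id: int) -> List[List[int]]:
    """Replace -100 with the pad token ID so the tokenizer can decode the labels."""
    return [[pad_id if tok == IGNORE_INDEX else tok for tok in seq] for seq in labels]


print(pad_and_mask_labels([[1, 5, 6, 2], [1, 7, 2]], decoder_start_id=1))  # [[5, 6, 2], [7, 2, -100]]
```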


## Run Training

In [None]:
# start tensorboard
%load_ext tensorboard
%tensorboard --logdir {LOG_DIR}

In [None]:
%%time
trainer.train()

In [None]:
print('evaluating best model after fine-tuning, language:', LANGUAGE)
trainer.evaluate(my_audio_dataset["validation"], language=LANGUAGE)

In [None]:
#@title Save trained model

#@markdown It is recommended to save the final model in Google Drive as it is much faster to download it from there than from Colab (especially true for large models).

save_in_drive = False #@param {type: 'boolean'}

model_output_dir_name = 'finetuned_whisper_model' #@param {type: 'string'}


if save_in_drive:
  from google.colab import drive
  drive.mount('/content/drive')
  output_dir = os.path.join('/content/drive/MyDrive/', model_output_dir_name)
else:
  output_dir = os.path.join('/content/', model_output_dir_name)

!mkdir -p {output_dir}

print('Saving model in:', output_dir)

# save model and processor, so we can later load as pretrained
save_model_dir = os.path.join(output_dir, 'saved_model')
trainer.model.save_pretrained(save_model_dir, safe_serialization=False)

# save processor also
save_processor_dir = os.path.join(output_dir, 'saved_processor')
processor.save_pretrained(save_processor_dir, safe_serialization=False)

## Test adapted model

In [None]:
def transcribe_from_dataset(dataset_sample, whisper_model, max_new_tokens=128):
  input_features = processor.feature_extractor(
    dataset_sample["array"],
    sampling_rate=dataset_sample["sampling_rate"],
    return_tensors="pt").input_features

  predicted_ids = whisper_model.generate(
      input_features, max_new_tokens=max_new_tokens,
      language=LANGUAGE, task=TASK, forced_decoder_ids=None)
  transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
  return transcription[0].strip()

In [None]:
#@title Get WER on default and tuned model for comparison

default_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_TYPE, local_files_only=True)
finetuned_model = WhisperForConditionalGeneration.from_pretrained(save_model_dir, local_files_only=True)

num_test_samples = 10 #@param{type: 'number'}
normalize_for_wer_calc = True #@param{type: 'boolean'}

num_test_samples = min(num_test_samples, len(my_audio_dataset['test']))
print('number of test examples to process:', num_test_samples)

predictions = []
finetuned_predictions = []
references  = []

for idx in range(num_test_samples):
  print('inference on example:', idx)
  sample = my_audio_dataset['test'][idx]["audio"]
  predictions.append(transcribe_from_dataset(sample, default_model))
  finetuned_predictions.append(transcribe_from_dataset(sample, finetuned_model))
  references.append(my_audio_dataset['test'][idx]['transcription'])

default_wer = get_wer(references=references, predictions=predictions, normalize=normalize_for_wer_calc, verbose=False)
finetuned_wer = get_wer(references=references, predictions=finetuned_predictions, normalize=normalize_for_wer_calc, verbose=False)

print(f'DEFAULT WER: {default_wer}')
print(f'FINETUNED WER: {finetuned_wer}')
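`get_wer` returns WER as a fraction (0.25 means 25% of reference words are wrong), whereas `compute_metrics` in the trainer multiplies by 100, so the two numbers are on different scales. A relative reduction is often easier to read than two raw values; the hypothetical helper below computes it. With only `num_test_samples` utterances, small differences between the two models may be noise.

```python
def relative_wer_reduction(default_wer: float, finetuned_wer: float) -> float:
    """Fraction of the baseline WER removed by fine-tuning (negative means fine-tuning made it worse)."""
    if default_wer == 0:
        raise ValueError("baseline WER is 0; relative reduction is undefined")
    return (default_wer - finetuned_wer) / default_wer


print(f'{relative_wer_reduction(0.4, 0.1):.1%}')  # 75.0%
```

In the notebook this would be `relative_wer_reduction(default_wer, finetuned_wer)`.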


In [None]:
#@title Run inference on an individual example from the test set
test_set_idx = 5 #@param
sample = my_audio_dataset['test'][test_set_idx]["audio"]
transcript = my_audio_dataset['test'][test_set_idx]["transcription"]
# use default_model or finetuned_model
model = finetuned_model
# %time (line magic) works here; %%time must be the first line of a cell
%time pred = transcribe_from_dataset(sample, model, max_new_tokens=32)
print('Ground truth: ', transcript)
print('  Prediction: ', pred)