# Fine-tuning Whisper-large-v2 on Colab using PEFT LoRA + BNB INT8 training

In this Colab, we present a step-by-step guide on how to fine-tune Whisper for any multilingual ASR dataset using Hugging Face 🤗 Transformers and 🤗 PEFT. Using 🤗 PEFT and `bitsandbytes`, you can train `whisper-large-v2` seamlessly on a Colab T4 GPU (16 GB VRAM). This notebook takes most of its parts from [fine_tune_whisper.ipynb](https://colab.research.google.com/github/sanchit-gandhi/notebooks/blob/main/fine_tune_whisper.ipynb#scrollTo=BRdrdFIeU78w) and adapts them to train with PEFT LoRA + BNB INT8.

For more details on the model, datasets and metrics, refer to the blog post [Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers](https://huggingface.co/blog/fine-tune-whisper).



## Initial Setup

In [None]:
!add-apt-repository -y ppa:jonathonf/ffmpeg-4
!apt update
!apt install -y ffmpeg

Repository: 'deb https://ppa.launchpadcontent.net/jonathonf/ffmpeg-4/ubuntu/ jammy main'
Adding repository.
Adding deb entry to /etc/apt/sources.list.d/jonathonf-ubuntu-ffmpeg-4-jammy.list
Adding disabled deb-src entry to /etc/apt/sources.list.d/jonathonf-ubuntu-ffmpeg-4-jammy.list
Adding key to /etc/apt/trusted.gpg.d/jonathonf-ubuntu-ffmpeg-4.gpg with fingerprint 4AB0F789CBA31744CC7DA76A8CF63AD3F06FC659
Get:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ In

In [None]:
!pip install "datasets>=2.6.1"
!pip install git+https://github.com/huggingface/transformers
!pip install librosa
!pip install "evaluate>=0.3.0"
!pip install jiwer
!pip install gradio
!pip install -q bitsandbytes datasets accelerate loralib
!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git@main

Collecting git+https://github.com/huggingface/transformers
  Cloning https://github.com/huggingface/transformers to /tmp/pip-req-build-yfb501vx
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers /tmp/pip-req-build-yfb501vx
  Resolved https://github.com/huggingface/transformers to commit 976189a6df796a2ff442dd81b022626c840d8c27
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: transformers
  Building wheel for transformers (pyproject.toml) ... done
  Created wheel for transformers: filename=transformers-4.37.0.dev0-py3-none-any.whl size=8342965 sha256=9590537959909ff61a9df7e57aa600dca184b931adca50ef9e402fb0e5a52fdb
  Stored in directory: /tmp/pip-ephem-wheel-cache-76rwu0sq/wheels/c0/14/d6/6c9a5582d2ac191ec0a483be151a4495fe1eb2a6706ca49f1b
Successfully built transformers

Linking the notebook to the Hub is straightforward - it simply requires entering your Hub authentication token when prompted. Find your Hub authentication token [here](https://huggingface.co/settings/tokens):

In [None]:
from huggingface_hub import notebook_login

notebook_login()

In [1]:
# Select CUDA device index
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
model_name_or_path = "openai/whisper-large-v2"
language = "Chinese"
language_abbr = "zh-TW"
task = "transcribe"
dataset_name = "mozilla-foundation/common_voice_11_0"

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# IMPORTS
import os
import re
import subprocess
import torch
import evaluate
import json
import pandas as pd
import glob

# from datasets import load_dataset, DatasetDict
# from transformers import WhisperFeatureExtractor
# from transformers import WhisperTokenizer
# from transformers import WhisperProcessor
# from transformers import WhisperForConditionalGeneration
# from peft import prepare_model_for_training
# from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model, PeftConfig
# from transformers import Seq2SeqTrainingArguments
# from transformers import Seq2SeqTrainer

from datasets import Audio
from dataclasses import dataclass
from typing import Any, Dict, List, Union


# Convert MP4 into WAV (resampled to 16 kHz later)

In [3]:
mp4_folder = '/content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/data'
wav_folder = '/content/drive/MyDrive/WAV_data'
cut_wav_folder = '/content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_cut'
srt_folder = '/content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/SRT'
train_test_split_folder = "/content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_train_test_split/"
train_test_split_data_folder = "/content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_train_test_split/data"
testing_folder = os.path.join(train_test_split_data_folder, 'test')
training_folder = os.path.join(train_test_split_data_folder, 'train')


# assert(os.path.exists(mp4_folder))
# assert(os.path.exists(wav_folder))
assert(os.path.exists(cut_wav_folder))
assert(os.path.exists(srt_folder))
assert(os.path.exists(train_test_split_folder))
assert(os.path.exists(train_test_split_data_folder))
assert(os.path.exists(testing_folder))
assert(os.path.exists(training_folder))

In [None]:
for mp4_file in os.listdir(mp4_folder):
  if mp4_file.endswith('.mp4'):
    mp4_path = os.path.join(mp4_folder, mp4_file)

    wav_file = os.path.splitext(mp4_file)[0] + '.wav'
    wav_path = os.path.join(wav_folder, wav_file)

    # pass arguments as a list so paths with spaces, '&' or quotes need no shell quoting;
    # -y overwrites an existing WAV instead of waiting for an interactive prompt
    subprocess.run(["ffmpeg", "-y", "-i", mp4_path, wav_path])

    print(f"Converted {mp4_file} to {wav_path}")


Converted CLC015-3-字幕版.mp4 to /content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_data/CLC015-3-字幕版.wav
Converted CLC015-2-字幕版.mp4 to /content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_data/CLC015-2-字幕版.wav
Converted CLC015-1-字幕版.mp4 to /content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_data/CLC015-1-字幕版.wav
Converted CLC014-2-字幕版.mp4 to /content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_data/CLC014-2-字幕版.wav
Converted CLC014-1-字幕版.mp4 to /content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_data/CLC014-1-字幕版.wav
Converted 第13課-4字幕版.mp4 to /content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_data/第13課-4字幕版.wav
Converted 第13課-1-字幕版.mp4 to /content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_data/第13課-1-字幕版.wav
Converted 第13課-3-字幕版.mp4 to /content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_da
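The loop above calls ffmpeg without `-ar`, so the WAV files keep the source sample rate and resampling to 16 kHz is left to 🤗 Datasets later on. If you would rather have ffmpeg write 16 kHz mono WAV directly, a minimal sketch of the command follows; the helper name `build_wav_command` and the example path are hypothetical, while `-vn`, `-ac`, `-ar` and `-y` are standard ffmpeg options.

```python
import shlex
from pathlib import Path
from typing import List


def build_wav_command(mp4_path: str, wav_path: str, sample_rate: int = 16000) -> List[str]:
    """ffmpeg arguments that turn a video into a mono WAV file at `sample_rate` Hz."""
    return [
        "ffmpeg",
        "-y",                     # overwrite an existing output file without prompting
        "-i", mp4_path,           # input video
        "-vn",                    # drop the video stream
        "-ac", "1",               # downmix to one (mono) channel
        "-ar", str(sample_rate),  # resample to the target rate
        wav_path,
    ]


src = "/content/data/CLC015-3-字幕版.mp4"  # hypothetical input path
cmd = build_wav_command(src, str(Path(src).with_suffix(".wav")))
print(shlex.join(cmd))  # shell-quoted preview; run it with subprocess.run(cmd, check=True)
```

If you convert this way, the later `cast_column` step has nothing left to resample, which is harmless.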

# Cut and Convert Using ffmpeg

In [4]:
def replace_spaces(text: str) -> str:
  return text.replace(" ", "")

def remove_punctuation(text: str) -> str:
  # \w is Unicode-aware, so CJK characters are kept while CJK punctuation is removed
  return re.sub(r'[^\w\s]', '', text)

def read_srt(file_path):
  """Parse an .srt file into a list of {index, start, end, text} dicts."""
  # utf-8-sig drops a leading byte-order mark if the file has one
  with open(file_path, 'r', encoding='utf-8-sig') as file:
    lines = file.readlines()

  subtitles = []
  current_subtitle = None

  for line in lines:
    line = line.strip()

    if line.isdigit():
      # a bare number starts a new cue
      if current_subtitle is not None:
        subtitles.append(current_subtitle)
      current_subtitle = {"index": int(line), "text": ""}
    elif "-->" in line and current_subtitle is not None:
      # timing line: "HH:MM:SS,mmm --> HH:MM:SS,mmm"
      start, end = line.split("-->")
      current_subtitle["start"] = start.strip()
      current_subtitle["end"] = end.strip()
    elif line and current_subtitle is not None:
      # text line(s): remove spaces and punctuation, separate lines with a space
      line = remove_punctuation(replace_spaces(line))
      current_subtitle["text"] += line + " "

  if current_subtitle is not None:
    subtitles.append(current_subtitle)

  return subtitles
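`read_srt` walks the file line by line: a bare number starts a cue, a line containing `-->` carries its timing, and any other non-empty line is cleaned transcript text. The same idea can be sketched on an in-memory string by splitting on the blank lines that separate SRT cues. `parse_srt_text` is a hypothetical helper, the sample is based on the first cues in the JSON output below with punctuation added to show the cleaning, and unlike `read_srt` it joins multi-line cues without a separating space.

```python
import re
from typing import Dict, List


def parse_srt_text(srt_text: str) -> List[Dict]:
    """Parse SRT content into [{'index', 'start', 'end', 'text'}, ...]."""
    subtitles = []
    # SRT cues are separated by one or more blank lines
    for block in re.split(r"\n\s*\n", srt_text.strip()):
        lines = [line.strip() for line in block.splitlines() if line.strip()]
        if len(lines) < 2 or "-->" not in lines[1]:
            continue  # skip malformed blocks
        start, end = (part.strip() for part in lines[1].split("-->"))
        # same cleaning as read_srt: drop spaces, then anything that is not a word character
        text = "".join(re.sub(r"[^\w\s]", "", line.replace(" ", "")) for line in lines[2:])
        subtitles.append({"index": int(lines[0]), "start": start, "end": end, "text": text})
    return subtitles


sample = """1
00:00:03,545 --> 00:00:06,423
大家看那個章天亮教授，那個笑談風雲裏面

2
00:00:06,423 --> 00:00:10,510
這都是真實的事情。
"""

for cue in parse_srt_text(sample):
    print(cue)
```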


In [5]:
# Function to convert timestamp to seconds
def timestamp_to_seconds(timestamp):
    timestamp = timestamp.replace(",", ".")
    h, m, s = map(float, timestamp.split(':'))
    return h * 3600 + m * 60 + s

In [6]:
import json

# Opening JSON file
with open('/content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/srt_subtitles_json.json', 'r') as openfile:
  # Reading from json file
  srt_subtitles = json.load(openfile)

print(type(srt_subtitles))
print(srt_subtitles.keys())
print(srt_subtitles[next(iter(srt_subtitles.keys()))].keys())
print(srt_subtitles[next(iter(srt_subtitles.keys()))])


<class 'dict'>
dict_keys(['第10課-2', '第10課-4', '第13課-4', '第11課-4', '第2課-2', '第10課-1', 'CLC014-1', '第12课-4', '第11課-1', '第3課-3', 'CLC015-2', '第13課-1', '第7課-3', '第4課-2', '第2課-1', '第5課-2', '第6課-1', '第8課-1', '第12课-3', '第4課-1', '第2課-3', 'CLC014-2', '第9課-2', 'CLC015-1', '第7課-1', '第8課-4', '第6課-3', '第9課-3', '第8課-2', 'CLC014-3', '第5課-1', '第11課-3', '第3課-1', '第13課-3', '第11課-2', '第3課-2', '第13課-2', '第6課-2', '第12课-1', '第12课-2', '第9課-1', 'CLC015-3', '第7課-2', '第9課-4', '第5課-3', '第8課-3', '第10課-3'])
dict_keys(['srt_filepath', 'subtitles', 'audio_filepath', 'audio_files_subtitle_text'])
{'srt_filepath': '/content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/SRT/第10課-2-final.srt', 'subtitles': [{'index': 1, 'text': '大家看那個章天亮教授那個笑談風雲裏面 ', 'start': '00:00:03,545', 'end': '00:00:06,423'}, {'index': 2, 'text': '他前面講了這個三場大風造成改變了中國的歷史 ', 'start': '00:00:06,423', 'end': '00:00:10,510'}, {'index': 3, 'text': '這都是真實的事情 ', 'start': '00:00:10,510', 'end': '00:00:11,970'}, {'index': 4, 'text': '那我再補充一

In [None]:
finished

['第10課-2',
 '第10課-4',
 '第13課-4',
 '第11課-4',
 '第2課-2',
 '第10課-1',
 'CLC014-1',
 '第12课-4',
 '第11課-1',
 '第3課-3',
 'CLC015-2',
 '第13課-1',
 '第7課-3',
 '第4課-2',
 '第2課-1',
 '第5課-2',
 '第6課-1',
 '第8課-1',
 '第12课-3',
 '第4課-1',
 '第2課-3',
 'CLC014-2',
 '第9課-2',
 'CLC015-1',
 '第7課-1',
 '第8課-4',
 '第6課-3',
 '第9課-3',
 '第8課-2',
 'CLC014-3',
 '第5課-1',
 '第11課-3',
 '第3課-1',
 '第13課-3',
 '第11課-2',
 '第3課-2',
 '第13課-2',
 '第6課-2',
 '第12课-1',
 '第12课-2',
 '第9課-1',
 'CLC015-3']

In [None]:
files_left = ['第6課-1']

In [None]:
all_wav_files = set(os.listdir(wav_folder))
all_srt_files = set(os.listdir(srt_folder))
all_srt_files = [i for i in all_srt_files if i.endswith('.srt')]
all_wav_files = [i for i in all_wav_files if i.endswith('.wav')]

for srt_file in all_srt_files:
  if not srt_file.endswith('.srt'):
    continue

  srt_path = os.path.join(srt_folder, srt_file)

  lecture_name = "-".join( srt_file.split("-")[:2])

  if '.srt' in lecture_name:
    lecture_name = lecture_name.split('.srt')[0]

  lecture_name = lecture_name.strip()

  # if lecture_name in finished:
  #   continue
  if lecture_name not in files_left:
    continue
  elif srt_subtitles.get(lecture_name) is not None:
    print(f"{srt_file} already has SRT stored?")
    print(f"overwriting {srt_file}")

  subtitles = read_srt(srt_path)
  print(subtitles[0])
  print(subtitles[1])


  srt_subtitles[lecture_name] = {}
  srt_subtitles[lecture_name]['srt_filepath'] = srt_path
  srt_subtitles[lecture_name]['subtitles'] = subtitles
  srt_subtitles[lecture_name]['audio_filepath'] =  []
  srt_subtitles[lecture_name]['audio_files_subtitle_text'] =  []

  # Loop through all WAV files in the folder
  for wav_file_indx, wav_file in enumerate(all_wav_files):

    if not wav_file.endswith('.wav'):
      continue

    wav_lecture_name = "-".join( wav_file.split("-")[:2] )
    wav_lecture_name = wav_lecture_name.strip()

    if wav_lecture_name != lecture_name:
      if wav_file_indx  >= len(all_wav_files) - 1:
        print("FAILURE," + lecture_name +  " could not be found " + wav_lecture_name)
      continue


    wav_path = os.path.join(wav_folder, wav_file)



    for subtitle_indx, subtitle in enumerate(subtitles):
      start_seconds = timestamp_to_seconds(subtitle["start"])
      end_seconds = timestamp_to_seconds(subtitle["end"])

      # pad each cue by 0.5 s on both sides, clamping the start so it never goes below 0
      start_seconds = max(start_seconds - 0.5, 0)
      end_seconds += 0.5

      difference = end_seconds - start_seconds



      cut_wav_file = os.path.splitext(wav_file)[0] + f'.{subtitle_indx}._cut.wav'
      cut_wav_path = os.path.join(cut_wav_folder, cut_wav_file)



      # list arguments avoid shell quoting issues; -y overwrites clips left over from earlier runs
      subprocess.run(["ffmpeg", "-y", "-i", wav_path, "-ss", str(start_seconds), "-t", str(difference), cut_wav_path])

      srt_subtitles[lecture_name]['audio_filepath'].append(cut_wav_path)
      srt_subtitles[lecture_name]['audio_files_subtitle_text'].append(subtitle['text'])

      # print(f"Cut {wav_path} to {cut_wav_file} for subtitle {subtitle['index']}")

    print("SUCCESS: ", wav_lecture_name)
    break
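The cutting loop boils down to three steps: match a WAV file to its SRT through a shared lecture key, turn each cue into a padded time window, and ask ffmpeg for that slice. A self-contained sketch of those steps follows; the helper names `lecture_key`, `clip_window` and `cut_command` are hypothetical, the 0.5 s padding mirrors the loop above, and the `-y` flag is an assumption added so reruns overwrite old clips.

```python
from typing import List, Tuple


def lecture_key(filename: str) -> str:
    """'第10課-2-final.srt' -> '第10課-2': first two '-'-separated parts, extension removed."""
    key = "-".join(filename.split("-")[:2])
    return key.split(".srt")[0].split(".wav")[0].strip()


def timestamp_to_seconds(timestamp: str) -> float:
    """'HH:MM:SS,mmm' -> seconds; SRT puts a comma before the milliseconds."""
    h, m, s = map(float, timestamp.replace(",", ".").split(":"))
    return h * 3600 + m * 60 + s


def clip_window(start: str, end: str, pad: float = 0.5) -> Tuple[float, float]:
    """(offset, duration) in seconds for one cue, widened by `pad` seconds on each side."""
    start_s = max(timestamp_to_seconds(start) - pad, 0.0)  # never seek before the file start
    end_s = timestamp_to_seconds(end) + pad
    return start_s, end_s - start_s


def cut_command(wav_path: str, out_path: str, offset: float, duration: float) -> List[str]:
    """ffmpeg arguments that write `duration` seconds starting at `offset` to `out_path`."""
    return ["ffmpeg", "-y", "-i", wav_path, "-ss", f"{offset:.3f}", "-t", f"{duration:.3f}", out_path]


print(lecture_key("第10課-2-final.srt"), lecture_key("第10課-2-字幕版.wav"))  # 第10課-2 第10課-2
offset, duration = clip_window("00:00:03,545", "00:00:06,423")
print(cut_command("lecture.wav", "lecture.0._cut.wav", offset, duration))
```

Placing `-ss` after `-i`, as the notebook does, makes ffmpeg decode the input from the beginning up to the offset; moving it before `-i` seeks in the input instead, which is usually much faster on long recordings.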


In [None]:

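# ensure_ascii=False keeps the Chinese subtitle text readable in the saved file
srt_subtitles_json_obj = json.dumps(srt_subtitles, indent=4, ensure_ascii=False)

with open("/content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/srt_subtitles_json.json", "w", encoding="utf-8") as outfile:
  outfile.write(srt_subtitles_json_obj)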


## Load Dataset

In [None]:
test_subtitles_text = []
test_subtitles_audio = []
train_subtitles_text = []
train_subtitles_audio = []

for lecture_name, lecture_dict in srt_subtitles.items():
  if "CLC014" in lecture_name or "CLC015" in lecture_name:
    for indx, subtitle_text in enumerate(lecture_dict['audio_files_subtitle_text']):
      test_subtitles_text.append(subtitle_text)
      test_subtitles_audio.append(lecture_dict.get('audio_filepath')[indx])
    continue

  for indx, subtitle_text in enumerate(lecture_dict['audio_files_subtitle_text']):
    train_subtitles_text.append(subtitle_text)
    train_subtitles_audio.append(lecture_dict.get('audio_filepath')[indx])

test_subtitles_text = [i.strip() for i in test_subtitles_text]
test_subtitles_audio = [i.strip() for i in test_subtitles_audio]
test_subtitles_audio = [i.split('/')[-1] for i in test_subtitles_audio]
test_subtitles_audio = [os.path.join('data/test',i) for i in test_subtitles_audio]

train_subtitles_text = [i.strip() for i in train_subtitles_text]
train_subtitles_audio = [i.strip() for i in train_subtitles_audio]
train_subtitles_audio = [i.split('/')[-1] for i in train_subtitles_audio]
train_subtitles_audio = [os.path.join('data/train',i) for i in train_subtitles_audio]
assert(len(test_subtitles_text) == len(test_subtitles_audio))
assert(len(train_subtitles_text) == len(train_subtitles_audio))
print(len(train_subtitles_audio), len(test_subtitles_audio))
print(len(glob.glob(training_folder + '/*')), len(glob.glob(testing_folder + '/*')))

18399 2961
18395 2960
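The split above assigns whole lectures, rather than individual clips, to the test set: every clip from a lecture whose name contains `CLC014` or `CLC015` is held out. A self-contained sketch of that rule on toy data; the helper name `split_by_lecture`, the toy paths and the second transcription are made up for illustration.

```python
import os
from typing import Dict, List, Tuple

TEST_TAGS = ("CLC014", "CLC015")  # lectures held out for evaluation


def split_by_lecture(srt_subtitles: Dict[str, dict]) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
    """Return (train, test) lists of (relative_audio_path, text) pairs.

    Whole lectures go to a single split, so no clip of a held-out lecture is seen in training.
    """
    train, test = [], []
    for lecture_name, info in srt_subtitles.items():
        is_test = any(tag in lecture_name for tag in TEST_TAGS)
        split_dir = "data/test" if is_test else "data/train"
        target = test if is_test else train
        for audio_path, text in zip(info["audio_filepath"], info["audio_files_subtitle_text"]):
            # keep only the file name and re-root it under the split directory
            rel_path = os.path.join(split_dir, os.path.basename(audio_path.strip()))
            target.append((rel_path, text.strip()))
    return train, test


toy = {
    "第10課-2": {"audio_filepath": ["/drive/WAV_cut/第10課-2-字幕版.0._cut.wav"],
                 "audio_files_subtitle_text": ["大家看那個章天亮教授那個笑談風雲裏面 "]},
    "CLC014-1": {"audio_filepath": ["/drive/WAV_cut/CLC014-1-字幕版.0._cut.wav"],
                 "audio_files_subtitle_text": ["測試句子 "]},
}
train, test = split_by_lecture(toy)
print(train)  # [('data/train/第10課-2-字幕版.0._cut.wav', '大家看那個章天亮教授那個笑談風雲裏面')]
print(test)   # [('data/test/CLC014-1-字幕版.0._cut.wav', '測試句子')]
```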


In [None]:
print(len(os.listdir(training_folder)), len(os.listdir(testing_folder)))

18395 2960


In [None]:
all_subtitles_text = train_subtitles_text + test_subtitles_text
all_subtitles_audio = train_subtitles_audio + test_subtitles_audio
assert(len(all_subtitles_text) == len(all_subtitles_audio))

In [None]:
'/content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_train_test_split/data/test/CLC014-2-字幕版.56._cut.wav'.split('WAV_train_test_split/')[-1]

'data/test/CLC014-2-字幕版.56._cut.wav'

In [None]:
all_training_files = glob.glob(training_folder + '/*')
all_testing_files = glob.glob(testing_folder + '/*')
# keep paths relative to the dataset root, e.g. 'data/train/xxx.wav'; sets make lookups O(1)
all_training_files = {i.split('WAV_train_test_split/')[-1] for i in all_training_files}
all_testing_files = {i.split('WAV_train_test_split/')[-1] for i in all_testing_files}

kept_audio, kept_text = [], []
for i_file, text in zip(all_subtitles_audio, all_subtitles_text):
  if i_file not in all_training_files and i_file not in all_testing_files:
    print(i_file)

  # drop pairs with an empty path or empty transcription; build new lists instead of
  # popping from the lists being iterated, which would skip the following element
  if not i_file or not text:
    print(f"Removed {i_file}, {text}")
    continue
  kept_audio.append(i_file)
  kept_text.append(text)

all_subtitles_audio, all_subtitles_text = kept_audio, kept_text


data/train/第10課-1-字幕版.129._cut.wav
Removed data/train/第10課-1-字幕版.129._cut.wav, 
data/train/第4課-1-片頭片尾字幕版.115._cut.wav
Removed data/train/第4課-1-片頭片尾字幕版.115._cut.wav, 
data/train/第9課-3-字幕版.113._cut.wav
Removed data/train/第9課-3-字幕版.113._cut.wav, 
data/train/第9課-4-字幕版.193._cut.wav
Removed data/train/第9課-4-字幕版.193._cut.wav, 
data/test/CLC014-2-字幕版.300._cut.wav
Removed data/test/CLC014-2-字幕版.300._cut.wav, 
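Removing items from a list while enumerating it shifts every later item one position to the left, so the item right after each removal is never checked. A toy demonstration of why rebuilding the lists is safer than popping; both helper names and the data are hypothetical.

```python
from typing import List, Tuple


def drop_empty_pop_while_iterating(paths: List[str], texts: List[str]) -> Tuple[List[str], List[str]]:
    """Fragile pattern: pop from the lists while enumerating them."""
    for i, path in enumerate(paths):
        if not path or not texts[i]:
            paths.pop(i)
            texts.pop(i)
    return paths, texts


def drop_empty_rebuild(paths: List[str], texts: List[str]) -> Tuple[List[str], List[str]]:
    """Safer pattern: build new lists that keep only complete (path, text) pairs."""
    kept = [(p, t) for p, t in zip(paths, texts) if p and t]
    return [p for p, _ in kept], [t for _, t in kept]


paths = ["a.wav", "b.wav", "c.wav", "d.wav"]
texts = ["", "", "x", "y"]  # the first two clips have empty transcriptions
print(drop_empty_pop_while_iterating(list(paths), list(texts)))  # (['b.wav', 'c.wav', 'd.wav'], ['', 'x', 'y'])
print(drop_empty_rebuild(paths, texts))                         # (['c.wav', 'd.wav'], ['x', 'y'])
```

The popping version lets `b.wav` through with an empty transcription because it is never inspected.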


In [None]:
!rm -rf '/content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_train_test_split/data/test/CLC014-2-字幕版.300._cut.wav'
!rm -rf '/content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_train_test_split/data/train/第9課-3-字幕版.113._cut.wav'
!rm -rf '/content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_train_test_split/data/train/第9課-4-字幕版.193._cut.wav'
!rm -rf '/content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_train_test_split/data/train/第4課-1-片頭片尾字幕版.115._cut.wav'
!rm -rf '/content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_train_test_split/data/train/第10課-1-字幕版.129._cut.wav'

In [None]:
metadata_df = pd.DataFrame()
metadata_df['file_name'] = all_subtitles_audio
metadata_df['sentence'] = all_subtitles_text

metadata_df.to_csv(f"{train_test_split_folder}/metadata.csv", index = False)
metadata_df

|       | file_name                              | sentence                                   |
|-------|----------------------------------------|--------------------------------------------|
| 0     | data/train/第10課-2-字幕版.0._cut.wav   | 大家看那個章天亮教授那個笑談風雲裏面          |
| 1     | data/train/第10課-2-字幕版.1._cut.wav   | 他前面講了這個三場大風造成改變了中國的歷史     |
| 2     | data/train/第10課-2-字幕版.2._cut.wav   | 這都是真實的事情                            |
| 3     | data/train/第10課-2-字幕版.3._cut.wav   | 那我再補充一個                              |
| 4     | data/train/第10課-2-字幕版.4._cut.wav   | 當時就是明太祖跟陳友諒大戰的時候              |
| ...   | ...                                    | ...                                        |
| 21350 | data/test/CLC015-3-字幕版.565._cut.wav | 就是寫史書                                  |
| 21351 | data/test/CLC015-3-字幕版.566._cut.wav | 學生史書啊                                  |
| 21352 | data/test/CLC015-3-字幕版.567._cut.wav | 對你用白話文把我給你的這些材料寫成通順的文章   |
| 21353 | data/test/CLC015-3-字幕版.568._cut.wav | 就是介紹這一朝講這一朝歷史                    |
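`load_dataset("audiofolder", ...)` reads `metadata.csv`, links each row to an audio file through the `file_name` column (a path relative to the folder that contains the CSV), and infers the `train`/`test` splits from the `data/train` and `data/test` directory names. A row whose clip is missing cannot become an example, so it can help to list such rows before loading. A minimal sketch; `find_missing_audio`, `make_toy_dataset` and the toy folder are hypothetical.

```python
import csv
import os
import tempfile
from typing import List


def find_missing_audio(dataset_root: str) -> List[str]:
    """Return `file_name` entries of metadata.csv that do not exist under `dataset_root`."""
    with open(os.path.join(dataset_root, "metadata.csv"), newline="", encoding="utf-8") as f:
        rows = list(csv.DictReader(f))
    # file_name is resolved relative to the folder holding metadata.csv
    return [row["file_name"] for row in rows
            if not os.path.isfile(os.path.join(dataset_root, row["file_name"]))]


def make_toy_dataset() -> str:
    """Build a tiny folder shaped like WAV_train_test_split/ with one clip missing."""
    root = tempfile.mkdtemp()
    os.makedirs(os.path.join(root, "data", "train"))
    os.makedirs(os.path.join(root, "data", "test"))
    open(os.path.join(root, "data", "train", "a.wav"), "wb").close()  # empty placeholder clip
    with open(os.path.join(root, "metadata.csv"), "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["file_name", "sentence"])
        writer.writerow(["data/train/a.wav", "就是寫史書"])
        writer.writerow(["data/test/b.wav", "學生史書啊"])  # no such file on disk
    return root


print(find_missing_audio(make_toy_dataset()))  # ['data/test/b.wav']
```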


In [None]:
full_dataset

DatasetDict({
    train: Dataset({
        features: ['audio', 'transcription'],
        num_rows: 0
    })
    test: Dataset({
        features: ['audio', 'transcription'],
        num_rows: 0
    })
})

In [None]:
from datasets import load_dataset, DatasetDict

entire_dataset = load_dataset("audiofolder", data_dir=train_test_split_folder)

In [None]:
print(entire_dataset)

DatasetDict({
    train: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 18395
    })
    test: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 2960
    })
})


In [None]:
type(entire_dataset)

datasets.dataset_dict.DatasetDict

In [None]:
entire_dataset.keys()

dict_keys(['train', 'test'])

In [None]:
type(entire_dataset['train'])

In [None]:
print(entire_dataset["train"][0])

In [None]:
test_dataset

## Prepare Feature Extractor, Tokenizer and Data

In [7]:
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [8]:
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=language, task=task)

In [9]:
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained(model_name_or_path, language=language, task=task)

### Prepare Data

Since our input audio is sampled at 48 kHz, we need to _downsample_ it to 16 kHz before passing it to the Whisper feature extractor, as 16 kHz is the sampling rate the Whisper model expects.

We'll set the audio inputs to the correct sampling rate using the dataset's [`cast_column`](https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=cast_column#datasets.DatasetDict.cast_column) method. This operation does not change the audio in place; rather, it signals to `datasets` to resample audio samples _on the fly_ the first time they are loaded:
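🤗 Datasets performs this resampling internally when a sample is decoded. The sketch below only illustrates what the 48 kHz to 16 kHz step means, using SciPy's polyphase resampler as a stand-in (not necessarily the resampler Datasets uses); the test tone is synthetic.

```python
import numpy as np
from scipy.signal import resample_poly

SOURCE_RATE = 48_000  # rate of the extracted WAV files
TARGET_RATE = 16_000  # rate expected by Whisper's feature extractor

# one second of a 440 Hz tone at 48 kHz as stand-in audio
t = np.arange(SOURCE_RATE) / SOURCE_RATE
audio_48k = np.sin(2 * np.pi * 440 * t).astype(np.float32)

# 48 kHz -> 16 kHz is an integer factor of 3; resample_poly low-pass filters before decimating
audio_16k = resample_poly(audio_48k, up=1, down=SOURCE_RATE // TARGET_RATE)
print(audio_48k.shape, audio_16k.shape)  # (48000,) (16000,)
```

The duration is unchanged; only the number of samples per second drops by a factor of three.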

In [12]:
from datasets import Audio

entire_dataset = entire_dataset.cast_column("audio", Audio(sampling_rate=16000))

Re-loading the first audio sample of our dataset will resample it to the desired sampling rate:

In [None]:
import pickle

with open('entire_dataset.pkl', 'wb') as f:
    pickle.dump(entire_dataset, f)

In [None]:
print(entire_dataset["train"][0])

{'audio': {'path': '/content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_train_test_split/data/train/第10課-1-字幕版.0._cut.wav', 'array': array([ 0.        ,  0.        ,  0.        , ..., -0.11165189,
       -0.07839113, -0.02468587]), 'sampling_rate': 16000}, 'sentence': '這節課我們結束了第一個單元'}


Now we can write a function to prepare our data for the model:
1. We load and resample the audio data by calling `batch["audio"]`. As explained above, 🤗 Datasets performs any necessary resampling operations on the fly.
2. We use the feature extractor to compute the log-Mel spectrogram input features from our 1-dimensional audio array.
3. We encode the transcriptions to label ids through the use of the tokenizer.

In [13]:
def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch
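The feature extractor pads (or truncates) every clip to 30 seconds before computing the spectrogram, so each `input_features` entry has the same shape however short the subtitle is. A sketch of the arithmetic, assuming Whisper's standard front-end settings (16 kHz input, 30 s window, 160-sample hop, 80 mel bins for large-v2); it reimplements only the padding step, not the log-Mel computation, and `pad_or_trim` is a hypothetical helper.

```python
import numpy as np

SAMPLE_RATE = 16_000   # Whisper's input sampling rate
CHUNK_SECONDS = 30     # Whisper always looks at a fixed 30-second window
HOP_LENGTH = 160       # 10 ms between spectrogram frames at 16 kHz
N_MELS = 80            # mel bins for whisper-large-v2 (large-v3 uses 128)
N_SAMPLES = SAMPLE_RATE * CHUNK_SECONDS  # 480,000 samples


def pad_or_trim(audio: np.ndarray, length: int = N_SAMPLES) -> np.ndarray:
    """Zero-pad a short clip, or cut a long one, to exactly `length` samples."""
    if len(audio) >= length:
        return audio[:length]
    return np.pad(audio, (0, length - len(audio)))


clip = np.zeros(int(2.5 * SAMPLE_RATE), dtype=np.float32)  # stand-in for a 2.5 s subtitle clip
padded = pad_or_trim(clip)
n_frames = len(padded) // HOP_LENGTH
print(padded.shape, (N_MELS, n_frames))  # (480000,) (80, 3000)
```

Because the cut clips are only a few seconds long, most of each 30-second window is padding.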

We can apply the data preparation function to all of our training examples using dataset's `.map` method. The argument `num_proc` specifies how many CPU cores to use. Setting `num_proc` > 1 will enable multiprocessing. If the `.map` method hangs with multiprocessing, set `num_proc=1` and process the dataset sequentially.

In [None]:
entire_dataset['train'][33]

{'audio': {'path': '/content/drive/Shareddrives/FTCM/CIS335 ML&AI/Group Bonus Project/Part 2/WAV_train_test_split/data/train/第10課-1-字幕版.128._cut.wav',
  'array': array([-0.00620656, -0.00869003, -0.00709007, ..., -0.08543612,
         -0.05781476, -0.02823488]),
  'sampling_rate': 16000},
 'sentence': '然後不是一個'}

In [None]:
!lscpu

Architecture:            x86_64
  CPU op-mode(s):        32-bit, 64-bit
  Address sizes:         46 bits physical, 48 bits virtual
  Byte Order:            Little Endian
CPU(s):                  8
  On-line CPU(s) list:   0-7
Vendor ID:               GenuineIntel
  Model name:            Intel(R) Xeon(R) CPU @ 2.00GHz
    CPU family:          6
    Model:               85
    Thread(s) per core:  2
    Core(s) per socket:  4
    Socket(s):           1
    Stepping:            3
    BogoMIPS:            4000.28
    Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clf
                         lush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_
                         good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fm
                         a cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hyp
                         ervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd i

In [14]:
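entire_dataset = entire_dataset.map(prepare_dataset,
                                    # drop the raw audio/sentence columns once features and labels exist
                                    remove_columns=entire_dataset.column_names["train"],
                                    num_proc=8)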

In [None]:
entire_dataset["test"]

In [None]:
entire_dataset["train"]

## Training and Evaluation

### Define a Data Collator

In [15]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

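        # the tokenizer prepends <|startoftranscript|> to every label sequence, and the model
        # prepends its decoder start token again when shifting labels right, so cut it here
        # (Whisper's bos_token_id is <|endoftext|>, so checking it would never match)
        decoder_start_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if (labels[:, 0] == decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]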

        batch["labels"] = labels

        return batch
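The label handling inside the collator can be traced with plain NumPy: sequences are padded to the longest one, padded positions become -100 so the loss ignores them, and a leading start token shared by every row is removed. The collator pads with the tokenizer and then masks through the attention mask; this sketch fills with -100 directly, which gives the same result. `pad_labels` is a hypothetical helper and the token ids are made up.

```python
from typing import List

import numpy as np

IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss


def pad_labels(label_ids: List[List[int]], start_id: int) -> np.ndarray:
    """Right-pad label sequences with -100 and drop a start token shared by every row."""
    max_len = max(len(ids) for ids in label_ids)
    labels = np.full((len(label_ids), max_len), IGNORE_INDEX, dtype=np.int64)
    for row, ids in enumerate(label_ids):
        labels[row, : len(ids)] = ids
    # strip the first column only when every sequence begins with the start token
    if (labels[:, 0] == start_id).all():
        labels = labels[:, 1:]
    return labels


# made-up token ids: 7 plays the role of the start token
print(pad_labels([[7, 1, 2, 3], [7, 4]], start_id=7).tolist())  # [[1, 2, 3], [4, -100, -100]]
```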

Let's initialise the data collator we've just defined:

In [16]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

### Evaluation Metrics

We'll use the word error rate (WER) metric, the 'de-facto' metric for assessing
ASR systems. For more information, refer to the WER [docs](https://huggingface.co/metrics/wer). We'll load the WER metric from 🤗 Evaluate:

In [17]:
import evaluate

metric = evaluate.load("wer")
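WER splits text on whitespace, but the cleaned Chinese subtitles contain (almost) no spaces, so each sentence counts as a single "word" and any mistake marks the whole sentence wrong. Character error rate (CER) is the usual metric for Chinese, and 🤗 Evaluate also provides it as `evaluate.load("cer")`. A pure-Python sketch of both metrics on top of a Levenshtein distance; the helper names are hypothetical and the hypothesis string is made up.

```python
from typing import Sequence


def edit_distance(ref: Sequence, hyp: Sequence) -> int:
    """Levenshtein distance (substitutions, insertions and deletions each cost 1)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            curr.append(min(prev[j] + 1,              # deletion
                            curr[j - 1] + 1,          # insertion
                            prev[j - 1] + (r != h)))  # substitution (free if equal)
        prev = curr
    return prev[-1]


def wer(ref: str, hyp: str) -> float:
    """Word error rate: tokens are whitespace-separated words."""
    return edit_distance(ref.split(), hyp.split()) / len(ref.split())


def cer(ref: str, hyp: str) -> float:
    """Character error rate: every character is a token."""
    return edit_distance(list(ref), list(hyp)) / len(ref)


reference = "這都是真實的事情"
hypothesis = "這都是真的事情"  # one character dropped
print(wer(reference, hypothesis), cer(reference, hypothesis))  # 1.0 0.125
```

One dropped character out of eight gives a CER of 12.5%, while WER reports the sentence as entirely wrong.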

We then simply have to define a function that takes our model
predictions and returns the WER metric. This function, called
`compute_metrics`, first replaces `-100` with the `pad_token_id`
in the `label_ids` (undoing the step we applied in the
data collator to ignore padded tokens correctly in the loss).
It then decodes the predicted and label ids to strings. Finally,
it computes the WER between the predictions and reference labels:

In [18]:
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

### Load a Pre-Trained Checkpoint

Now let's load the pre-trained Whisper `large-v2` checkpoint in 8-bit precision. Again, this is trivial with 🤗 Transformers!

In [19]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, load_in_8bit=True, device_map="auto")

# model.hf_device_map - this should be {" ": 0}

Override generation arguments - no tokens are forced as decoder outputs (see [`forced_decoder_ids`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)), no tokens are suppressed during generation (see [`suppress_tokens`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens)):

In [None]:
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

### Post-processing on the model

Finally, we need to apply some post-processing to the 8-bit model to enable training: we freeze all layers and cast the layer norms to `float32` for stability. We also cast the output of the last layer to `float32` for the same reason.

In [21]:
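# prepare_model_for_int8_training is deprecated in recent PEFT releases;
# prepare_model_for_kbit_training is its replacement for 8-bit (and 4-bit) models
from peft import prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)
# model = prepare_model_for_training(model, output_embedding_layer_name="proj_out")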



### Apply LoRA

Here comes the magic with `peft`! Let's load a `PeftModel` and specify that we are going to use low-rank adapters (LoRA) using `get_peft_model` utility function from `peft`.

In [22]:
from peft import LoraConfig, get_peft_model

config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")

model = get_peft_model(model, config)
model.print_trainable_parameters()

trainable params: 15,728,640 || all params: 1,559,033,600 || trainable%: 1.0088711365810203


We are training only **~1%** of the total parameters, thereby performing **Parameter-Efficient Fine-Tuning**.
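The trainable-parameter count printed above can be reproduced by hand. This sketch assumes the published whisper-large-v2 dimensions (hidden size 1280, 32 encoder and 32 decoder layers); with `bias="none"`, LoRA adds only two low-rank matrices per adapted projection, and `lora_alpha` is a scaling factor rather than a parameter.

```python
D_MODEL = 1280          # hidden size of whisper-large-v2
ENCODER_LAYERS = 32
DECODER_LAYERS = 32
RANK = 32               # r in the LoraConfig above

# q_proj and v_proj sit in every encoder self-attention block and in both the
# self-attention and the cross-attention block of every decoder layer
n_adapted = ENCODER_LAYERS * 2 + DECODER_LAYERS * 2 * 2  # 192 linear layers

# each d_model x d_model projection gets A (r x d_model) and B (d_model x r)
params_per_layer = 2 * RANK * D_MODEL                     # 81,920
trainable = n_adapted * params_per_layer
print(f"{trainable:,}")  # 15,728,640
```

Doubling `r` would double this count, since both LoRA matrices scale linearly with the rank.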

### Define the Training Configuration

In the final step, we define all the parameters related to training. For more detail on the training arguments, refer to the Seq2SeqTrainingArguments [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments).

In [35]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="temp",  # change to a repo name of your choice
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-3,
    warmup_steps=50,
    num_train_epochs=3,
    evaluation_strategy="epoch",
    fp16=True,
    per_device_eval_batch_size=4,
    generation_max_length=128,
    logging_steps=25,
    remove_unused_columns=False,  # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
    label_names=["labels"],  # same reason as above
)
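A quick estimate of the run length implied by these arguments, assuming a single GPU and the Trainer's default of keeping the last incomplete batch (`dataloader_drop_last=False`); the 18,395 training rows come from the dataset printout above.

```python
import math

train_examples = 18_395  # rows in entire_dataset["train"]
per_device_batch = 4     # per_device_train_batch_size
grad_accum = 1           # gradient_accumulation_steps
n_gpus = 1               # a single Colab T4
epochs = 3               # num_train_epochs

effective_batch = per_device_batch * grad_accum * n_gpus
batches_per_epoch = math.ceil(train_examples / (per_device_batch * n_gpus))
optimizer_steps = epochs * (batches_per_epoch // grad_accum)
print(effective_batch, batches_per_epoch, optimizer_steps)  # 4 4599 13797
```

With `warmup_steps=50`, the learning rate reaches its peak of 1e-3 within the first 1% of these steps.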

**A few important notes:**
1. `remove_unused_columns=False` and `label_names=["labels"]` are required because the PeftModel's forward doesn't have the signature of the base model's forward.

2. INT8 training requires autocasting. `predict_with_generate` can't be passed to the Trainer because it internally calls Transformers' `generate` without autocasting, which leads to errors.

3. Because of point 2, `compute_metrics` shouldn't be passed to `Seq2SeqTrainer`.

In [36]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=entire_dataset["train"],
    eval_dataset=entire_dataset["test"],
    data_collator=data_collator,
    # compute_metrics=compute_metrics,  # requires predict_with_generate, see note 3 above
    tokenizer=processor.feature_extractor,
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!

Before training

In [None]:
trainer.evaluate()

In [28]:
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!

In [38]:
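import gc
import torch

# Free cached GPU memory without tearing down the CUDA context. (numba's
# cuda.get_current_device().reset() destroys the context, invalidating the
# model weights PyTorch has already placed on the GPU.)
gc.collect()
torch.cuda.empty_cache()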

In [None]:
trainer.train()

In [None]:
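model_name_or_path = "openai/whisper-large-v2"
hub_username = "your-username"  # placeholder: replace with your own Hugging Face username
# peft_config is a dict keyed by adapter name; .value gives the plain string "LORA"
peft_model_id = f"{hub_username}/" + f"{model_name_or_path}-{model.peft_config['default'].peft_type.value}-colab".replace("/", "-")
model.push_to_hub(peft_model_id)
print(peft_model_id)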

# Evaluation and Inference

**Important points to note while inferencing**:
1. As `predict_with_generate` can't be used, we write the eval loop with `torch.cuda.amp.autocast()`, as shown below.
2. As the base model is frozen, the PEFT model sometimes fails to recognise the language while decoding. Hence, we force the starting tokens to specify the language we are transcribing. This is done via `forced_decoder_ids = processor.get_decoder_prompt_ids(language="Marathi", task="transcribe")` and passing that to the `model.generate` call.
3. Please note that the [AutoEvaluate Leaderboard](https://huggingface.co/spaces/autoevaluate/leaderboards?dataset=mozilla-foundation%2Fcommon_voice_11_0&only_verified=0&task=automatic-speech-recognition&config=mr&split=test&metric=wer) for the `mr` language on `common_voice_11_0` has a bug wherein OpenAI's `BasicTextNormalizer` is used during evaluation, leading to degenerate output text. An example is shown below:
```
without normalizer: 'स्विच्चान नरुवित्तीची पद्दत मोठ्या प्रमाणात आमलात आणल्या बसोन या दुपन्याने अनेक राथ प्रवेश केला आहे.'
with normalizer: 'स व च च न नर व त त च पद दत म ठ य प रम ण त आमल त आणल य बस न य द पन य न अन क र थ प रव श क ल आह'
```
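To see why the normalized text degenerates, here is a minimal, simplified re-implementation of the core of `BasicTextNormalizer`. It is a sketch that mirrors the normalizer's symbol-removal step (after NFKC normalization, every character whose Unicode category starts with `M` for mark, `S` for symbol or `P` for punctuation becomes a space) and omits its optional letter-splitting mode; `basic_normalize` is a hypothetical helper name, not the library's API.

```python
import re
import unicodedata


def basic_normalize(text: str) -> str:
    """Simplified sketch of Whisper's BasicTextNormalizer (hypothetical helper)."""
    text = text.lower()
    text = re.sub(r"[<\[][^>\]]*[>\]]", "", text)  # drop text between <...> or [...]
    text = re.sub(r"\(([^)]+?)\)", "", text)  # drop text between parentheses
    # Replace marks (M*), symbols (S*) and punctuation (P*) with spaces.
    # Devanagari vowel signs and the virama are marks, so they are lost.
    text = "".join(
        " " if unicodedata.category(ch)[0] in "MSP" else ch
        for ch in unicodedata.normalize("NFKC", text)
    )
    return re.sub(r"\s+", " ", text).strip()  # collapse runs of whitespace


print(basic_normalize("स्विच्चान नरुवित्तीची"))  # → स व च च न नर व त त च
```

The output reproduces the start of the degenerate "with normalizer" string above: only the bare consonants survive.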
After fixing this bug, we report two metrics for the top model of the leaderboard and for the PEFT model:
1. `wer`: `wer` without the `BasicTextNormalizer`, as the normalizer doesn't cater to most Indic languages. This is what we consider the true performance metric.
2. `normalized_wer`: `wer` with the `BasicTextNormalizer`, to be comparable to the leaderboard metrics.
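Both metrics are the standard word error rate, WER = (S + D + I) / N, where S, D and I count word substitutions, deletions and insertions and N is the number of reference words; they differ only in whether predictions and references are normalized first. The sketch below is a pure-Python illustration for a single sentence pair, assuming whitespace tokenization; `word_error_rate` and `strip_marks` are hypothetical helpers, not the API of the `metric` object used in this notebook (corpus-level WER sums errors and reference words over all pairs).

```python
import unicodedata


def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level Levenshtein distance / number of reference words (non-empty reference assumed)."""
    ref, hyp = reference.split(), hypothesis.split()
    # prev[j] = edit distance between the reference prefix processed so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,  # deletion
                curr[j - 1] + 1,  # insertion
                prev[j - 1] + (r != h),  # substitution (cost 0 on a match)
            )
        prev = curr
    return prev[-1] / len(ref)


def strip_marks(text: str) -> str:
    """Hypothetical normalizer: lowercase, then marks/symbols/punctuation -> spaces."""
    cleaned = "".join(" " if unicodedata.category(c)[0] in "MSP" else c for c in text.lower())
    return " ".join(cleaned.split())


ref, hyp = "Hello, world.", "hello world"
print(100 * word_error_rate(ref, hyp))  # → 100.0  (raw `wer`)
print(100 * word_error_rate(strip_marks(ref), strip_marks(hyp)))  # → 0.0  (`normalized_wer`)
```

Normalization can only hide differences, which is why `normalized_wer` is much lower than `wer` in the table below, and why it is misleading for scripts whose vowel signs the normalizer erases.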
Below are the results:

| Metric (%, lower is better) | DrishtiSharma/whisper-large-v2-marathi (full fine-tuning) | smangrul/openai-whisper-large-v2-LORA-colab (PEFT LoRA) |
|-----------------------------|-----------------------------------------------------------|---------------------------------------------------------|
| wer                         | 35.6457                                                   | 36.1356                                                 |
| normalized_wer              | 13.6440                                                   | 14.0165                                                 |

We see that the PEFT model's performance is comparable to that of the fully fine-tuned model at the top of the leaderboard. At the same time, we are able to train the large model in a Colab notebook with limited GPU memory, with the added advantage that the resulting checkpoint is just `63` MB.
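The `63` MB figure can be sanity-checked with back-of-the-envelope arithmetic. This sketch assumes the LoRA configuration used earlier in the notebook is rank `r=32` applied to the `q_proj` and `v_proj` attention projections, and that the adapter weights are stored in float32; neither is shown in this section, so treat both as assumptions. It uses Whisper-large-v2's dimensions (`d_model=1280`, 32 encoder and 32 decoder layers). A LoRA adapter on a `d_in × d_out` linear layer adds A (r × d_in) and B (d_out × r), i.e. r·(d_in + d_out) parameters.

```python
# Back-of-the-envelope size of the LoRA checkpoint (assumed config: r=32 on q_proj/v_proj).
D_MODEL = 1280  # Whisper-large-v2 hidden size
ENCODER_LAYERS = DECODER_LAYERS = 32
RANK = 32  # assumed LoRA rank
BYTES_PER_PARAM = 4  # assumed float32 adapter weights


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters added by one LoRA adapter: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)


# q_proj and v_proj in every attention block: 1 block per encoder layer (self-attention),
# 2 per decoder layer (self-attention and cross-attention).
n_adapted = 2 * (ENCODER_LAYERS * 1 + DECODER_LAYERS * 2)
total_params = n_adapted * lora_params(D_MODEL, D_MODEL, RANK)
size_mb = total_params * BYTES_PER_PARAM / 1e6

print(f"{n_adapted=} {total_params=:,} {size_mb=:.1f} MB")
# → n_adapted=192 total_params=15,728,640 size_mb=62.9 MB
```

Under these assumptions the adapter holds about 15.7M parameters, roughly 1% of the 1.5B-parameter base model, which is consistent with the reported checkpoint size.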



In [None]:
from peft import PeftModel, PeftConfig
from transformers import WhisperForConditionalGeneration, Seq2SeqTrainer

peft_model_id = "smangrul/openai-whisper-large-v2-LORA-colab"
peft_config = PeftConfig.from_pretrained(peft_model_id)
model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path, load_in_8bit=True, device_map="auto"
)
model = PeftModel.from_pretrained(model, peft_model_id)

In [None]:
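import gc

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# `common_voice`, `data_collator`, `tokenizer` and `metric` are defined in earlier cells.
eval_dataloader = DataLoader(common_voice["test"], batch_size=8, collate_fn=data_collator)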

model.eval()
for step, batch in enumerate(tqdm(eval_dataloader)):
    with torch.cuda.amp.autocast():
        with torch.no_grad():
            generated_tokens = (
                model.generate(
                    input_features=batch["input_features"].to("cuda"),
                    decoder_input_ids=batch["labels"][:, :4].to("cuda"),
                    max_new_tokens=255,
                )
                .cpu()
                .numpy()
            )
            labels = batch["labels"].cpu().numpy()
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
            decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
            metric.add_batch(
                predictions=decoded_preds,
                references=decoded_labels,
            )
    del generated_tokens, labels, batch
    gc.collect()
wer = 100 * metric.compute()
print(f"{wer=}")
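Inside the loop, label positions that the collator padded with `-100` (the index PyTorch's cross-entropy loss ignores by default) are mapped back to the tokenizer's pad id before decoding, because `-100` is not a valid token id. A tiny NumPy illustration of that `np.where` step; the pad id `50257` and the token ids here are illustrative values, not read from the tokenizer.

```python
import numpy as np

PAD_ID = 50257  # illustrative; in the notebook this is tokenizer.pad_token_id

# Two label sequences of different lengths, right-padded with -100 by the collator.
labels = np.array([
    [101, 102, 103, 104],
    [201, 202, -100, -100],
])

# Keep real token ids and map ignored positions (-100) to the pad id, so that
# batch_decode(..., skip_special_tokens=True) can drop them as special tokens.
decodable = np.where(labels != -100, labels, PAD_ID)
print(decodable.tolist())  # → [[101, 102, 103, 104], [201, 202, 50257, 50257]]
```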

## Using AutomaticSpeechRecognitionPipeline

**A few important notes:**
1. `pipe()` should be called inside the autocast context manager `with torch.cuda.amp.autocast():`.
2. `forced_decoder_ids` specifying the `language` being transcribed should be provided in the `generate_kwargs` dict.
3. You will get a warning along the lines below, which is **safe to ignore**.
```
The model 'PeftModel' is not supported for . Supported models are ['SpeechEncoderDecoderModel', 'Speech2TextForConditionalGeneration', 'SpeechT5ForSpeechToText', 'WhisperForConditionalGeneration', 'Data2VecAudioForCTC', 'HubertForCTC', 'MCTCTForCTC', 'SEWForCTC', 'SEWDForCTC', 'UniSpeechForCTC', 'UniSpeechSatForCTC', 'Wav2Vec2ForCTC', 'Wav2Vec2ConformerForCTC', 'WavLMForCTC'].

```

In [None]:
import torch
import gradio as gr
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperProcessor,
)
from peft import PeftModel, PeftConfig


peft_model_id = "smangrul/openai-whisper-large-v2-LORA-colab"
language = "Marathi"
task = "transcribe"
peft_config = PeftConfig.from_pretrained(peft_model_id)
model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path, load_in_8bit=True, device_map="auto"
)

model = PeftModel.from_pretrained(model, peft_model_id)
tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
feature_extractor = processor.feature_extractor
forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
pipe = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)


def transcribe(audio):
    with torch.cuda.amp.autocast():
        text = pipe(audio, generate_kwargs={"forced_decoder_ids": forced_decoder_ids}, max_new_tokens=255)["text"]
    return text


iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(source="microphone", type="filepath"),
    outputs="text",
    title="PEFT LoRA + INT8 Whisper Large V2 Marathi",
    description="Realtime demo for Marathi speech recognition using `PEFT-LoRA+INT8` fine-tuned Whisper Large V2 model.",
)

iface.launch(share=True)