In [1]:
import os

import pandas as pd

import librosa
import librosa.display

import numpy as np
import copy

import IPython.display as ipd

import matplotlib.pyplot as plt

import random

from collections import Counter
from torch.utils.data import DataLoader

from sklearn.model_selection import train_test_split

import torch
import torchaudio

from dataclasses import dataclass
from typing import Any, Dict, List, Union
from datasets import DatasetDict
from datasets import concatenate_datasets
from datasets import Dataset as DS

from transformers import (
    WhisperFeatureExtractor,
    WhisperTokenizer,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    TrainerCallback,
    TrainingArguments,
    TrainerState,
    TrainerControl,
    EarlyStoppingCallback,
    pipeline
)
from torch.optim import AdamW

from torchmetrics.text import WordErrorRate, CharErrorRate

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN kernels.

    Args:
        seed: Integer seed handed to every generator (default 42).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# !pip freeze > requirements.txt

# Dataset

In [4]:
BASE_DIR = '/kaggle/input/final-splits'
train_data_dir = f"{BASE_DIR}/final_splits/train/"
val_data_dir = f"{BASE_DIR}/final_splits/valid/"
# Derive the sheet path from train_data_dir so BASE_DIR is the only place to edit.
data_path = os.path.join(train_data_dir, "train.xlsx")

data = pd.read_excel(data_path)
data["transcriptions"] = data["transcriptions"].str.strip()
data["file_name"] = data["file_name"].str.strip()
# Metadata columns are not needed for training; keep only file_name + transcriptions.
data = data.drop(columns=["External_ID", "district", "Split", "annotator"])
# .str.strip() turns empty / non-string cells into NaN, and the text-cleaning cell
# calls str methods on every value -- fail here with a readable message instead.
assert data["transcriptions"].notna().all(), "train sheet has empty or non-text transcriptions"
data

Unnamed: 0,file_name,transcriptions
0,train_barishal (1).wav,"আসসালামু আলাইকুম, আমার নাম হাসিবুর রহমান শুব, ..."
1,train_barishal (10).wav,<> ভালোই লাগছে। অ্যাসাইনমেন্ট করতে বসছি একটা। ...
2,train_barishal (100).wav,"আমি তো বলছি ভাই ও আমারে কইবে, মানে তুই কও ক্যা..."
3,train_barishal (103).wav,মোর পছন্দের শখ হইলো গান হোনা। গান হোনতে ব্যাক ...
4,train_barishal (104).wav,"মুই সকল ধরনের গান হুনি, খালি বলিউডের হিন্দি গা..."
...,...,...
13337,train_tangail (995).wav,টাঙ্গাইল শহর থেইকা প্রাহ ছয় কিলোমিটার দক্ষিণে ...
13338,train_tangail (996).wav,মসজিদটি মূলত বর্গাকৃতির একটি গম্বুজ বিশিষ্ট। এ...
13339,train_tangail (997).wav,লাল ইট দ্বারা নির্মিত এই মসজিদটি আকারে বেশ ছোট...
13340,train_tangail (998).wav,সুলতানী ও মুঘোল এই দুই আমলের স্থাপত্য রিতীর সু...


In [5]:
# Derive the sheet path from val_data_dir so BASE_DIR is the only place to edit.
val_data_path = os.path.join(val_data_dir, "valid.xlsx")

val_data = pd.read_excel(val_data_path)
val_data["transcriptions"] = val_data["transcriptions"].str.strip()
val_data["file_name"] = val_data["file_name"].str.strip()
# Metadata columns are not needed for evaluation; keep only file_name + transcriptions.
val_data = val_data.drop(columns=["External_ID", "district", "Split", "annotator"])
# .str.strip() turns empty / non-string cells into NaN, and the text-cleaning cell
# calls str methods on every value -- fail here with a readable message instead.
assert val_data["transcriptions"].notna().all(), "valid sheet has empty or non-text transcriptions"
val_data

Unnamed: 0,file_name,transcriptions
0,valid_barishal (1).wav,"সঞ্জয়দা, কাম কইরা আইবে না এইরম এইরম তার, স্বাম..."
1,valid_barishal (10).wav,"কিছু কিছু ছাত্র আছে মাস্টার গোলাইয়া ওই, খাইলেও..."
2,valid_barishal (100).wav,"জমার পরে পানিডা হালাইয়া দিয়া আবার পানি দিমু, প..."
3,valid_barishal (101).wav,এইরপর এই বান্ড বরমু। বান্ডে রাখমু। বাডা-বোডা ভ...
4,valid_barishal (102).wav,কি খেলা যাইয়া পছন্দ ছিলো? ছোডো বেলায় তো এই ই খ...
...,...,...
1661,valid_tangail (95).wav,আসসালামু আলাইকুম। আমি মোহাম্মদ সামিউল ইসলাম সৈ...
1662,valid_tangail (96).wav,আঞ্চলিক ভাষায় কিছুক্ষণ কথা বলবো প্রেসেন্ট ইস্য...
1663,valid_tangail (97).wav,আলহামদুলিল্লা! ভালো। তারপর? দিনকাল কেমন যাইতাছ...
1664,valid_tangail (98).wav,"আগে তোর ঈদের প্ল্যান বল। ঈদের প্ল্যান, আসলাম, ..."


In [6]:
import re

# Literal tokens deleted from every transcription. Multi-character emoticons come
# first so they are removed whole before the single-character entries run.
punctuations = [
    r"/::\)", "/::", "(-_-)", "(*_*)", "(>_<)", ":)", ";)", ":P", "xD", "-_-", "*_*", "...",
    ".", ",", ";", ":", "!", "?", "'", "অ�", "অাবার", "।",
    "\"", "-", "_", "/", "\\", "|", "{", "}", "[", "]", "(", ")", "<", ">",
    "@", "#", "$", "%", "^", "&", "*", "~", "`", "+", "=",
    "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "৳",
    "০", "১", "২", "৩", "৪", "৫", "৬", "৭", "৮", "৯",
]

# Unicode blocks that should be stripped as well. These used to be listed as
# "\uXXXX-\uYYYY" strings in `punctuations`, where str.replace treated them as
# literal text and never matched. U+200C/U+200D (ZWNJ/ZWJ) are deliberately
# excluded because Bengali uses them to control conjunct rendering.
_UNICODE_NOISE = re.compile(
    "["
    "\u00C0-\u017F"                 # Latin-1 supplement letters, Latin Extended-A
    "\u2000-\u200B\u200E-\u206F"    # general punctuation: dashes, curly quotes, ellipsis, ...
    "\u25A0-\u25FF"                 # geometric shapes
    "\u2600-\u27BF"                 # misc symbols, dingbats
    "\u2B00-\u2BFF"                 # misc symbols and arrows
    "\u3000-\u303F"                 # CJK symbols and punctuation
    "\uFB00-\uFB4F"                 # alphabetic presentation forms
    "\uFE00-\uFE0F"                 # variation selectors
    "\uFE30-\uFE4F"                 # CJK compatibility forms
    "\U0001F1E0-\U0001F1FF"         # regional indicators (flags)
    "\U0001F300-\U0001F64F"         # pictographs, emoticons
    "\U0001F680-\U0001F6FF"         # transport and map symbols
    "\U0001F900-\U0001F9FF"         # supplemental symbols and pictographs
    "]"
)


def remove_punctuations(text):
    """Remove punctuation, digits, symbols and emoticons from a transcription.

    Whitespace characters (newline, tab, Unicode spaces) are turned into a plain
    space rather than deleted, so the words on either side stay separate.
    """
    text = re.sub(r"\s", " ", text)
    for punctuation in punctuations:
        text = text.replace(punctuation, "")
    return _UNICODE_NOISE.sub("", text)


_EMOJI_PATTERN = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # emoticons
    "\U0001F300-\U0001F5FF"  # symbols & pictographs
    "\U0001F680-\U0001F6FF"  # transport & map symbols
    "\U0001F1E0-\U0001F1FF"  # flags (iOS)
    "\U00002702-\U000027B0"
    "\U000024C2-\U0001F251"
    "]+",
    flags=re.UNICODE,
)


def remove_emoji(text):
    """Delete runs of emoji / pictograph code points from ``text``."""
    return _EMOJI_PATTERN.sub(r"", text)


def remove_extra_space(text):
    """Delete ASCII Latin letters, then collapse whitespace runs to one space.

    Despite the name, this also removes English words from the transcription.
    """
    text = re.sub(r"[a-zA-Z]+", "", text)
    text = re.sub(r"\s+", " ", text)
    return text


def remove_extra(text):
    """Collapse whitespace runs to a single space and trim both ends."""
    return re.sub(r"\s+", " ", text).strip()

# Run the same four cleaning steps, in the same order, over the train split and
# then over the validation split.
cleaning_steps = (remove_punctuations, remove_emoji, remove_extra_space, remove_extra)

for transcripts_frame in (data, val_data):
    for clean_step in cleaning_steps:
        transcripts_frame["transcriptions"] = transcripts_frame["transcriptions"].apply(clean_step)

# Model

In [7]:
# Base checkpoint; model_id doubles as the Trainer output directory later on.
MODEL_NAME = "openai/whisper-medium"
model_id = "whisper-medium"
TASK = "transcribe"

model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME, device_map="auto")

# Tokenizer and processor are both configured with the Bengali / transcribe prompt.
prompt_kwargs = {"language": 'bn', "task": TASK}
feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL_NAME)
tokenizer = WhisperTokenizer.from_pretrained(MODEL_NAME, **prompt_kwargs)
processor = WhisperProcessor.from_pretrained(MODEL_NAME, **prompt_kwargs)

# Sanity check: encoding an empty string and decoding it back shows only the
# special-token prefix the tokenizer adds.
ids = tokenizer.encode("")
tokenizer.decode(ids)

'<|startoftranscript|><|bn|><|transcribe|><|notimestamps|><|endoftext|>'

In [8]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Batch Whisper examples by padding audio features and label ids separately.

    Attributes:
        processor: Whisper processor exposing ``feature_extractor.pad`` and
            ``tokenizer.pad`` / ``tokenizer.convert_tokens_to_ids``.
        decoder_start_token_id: Token the model puts in front of decoder inputs
            (``model.config.decoder_start_token_id``). When None, it is looked up
            as ``<|startoftranscript|>`` in the processor's tokenizer.
    """

    processor: Any
    decoder_start_token_id: Union[int, None] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Audio inputs and labels differ in length and padding rules, so pad them independently.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Padded label positions become -100 so the loss ignores them.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # The tokenizer already starts every label with <|startoftranscript|>, and the
        # model prepends the decoder start token again when it shifts labels right, so
        # drop the duplicate here. Comparing against tokenizer.bos_token_id does not work
        # for the stock Whisper tokenizer, whose bos token is <|endoftext|>.
        start_id = self.decoder_start_token_id
        if start_id is None:
            start_id = self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if labels.shape[1] > 0 and (labels[:, 0] == start_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        # No torch.cuda.empty_cache() here: running it on every batch does not lower
        # peak memory and only makes the CUDA caching allocator re-acquire memory.
        return batch

# One collator instance, shared by the Trainer and by evaluate_model.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor)

In [9]:
def prepare_dataset(example, split):
    """Turn one row into Whisper inputs: ``input_features`` plus token ``labels``.

    Args:
        example: Row dict with at least ``file_name`` and ``transcriptions``.
        split: ``"train"`` or ``"val"``; selects the directory holding the audio file.

    Returns:
        The same ``example`` dict, updated in place with the two new keys.

    Raises:
        ValueError: if ``split`` is neither ``"train"`` nor ``"val"``.
    """
    if split not in ("train", "val"):
        raise ValueError("Invalid split specified. Expected 'train' or 'val'.")
    audio_dir = {"train": train_data_dir, "val": val_data_dir}[split]

    # Load (and resample) the clip at 16 kHz.
    waveform, sample_rate = librosa.load(audio_dir + example["file_name"], sr=16_000)

    extracted = feature_extractor(waveform, sampling_rate=sample_rate)
    example["input_features"] = extracted.input_features[0]

    encoded = tokenizer(f"{example['transcriptions']}", max_length=448, padding=True, truncation=True)
    example["labels"] = encoded.input_ids

    return example


def filter_inputs(input_audio):
    """Return True when the audio input has at least one element (filters zero-length inputs)."""
    return len(input_audio) > 0


def filter_labels(input_labels):
    """Return True when the label sequence is non-empty (filters empty label sequences)."""
    return len(input_labels) > 0


# Metric objects shared by compute_metrics and evaluate_model.
wer = WordErrorRate()
cer = CharErrorRate()



def compute_metrics(pred):
    """Decode generated and reference ids and score them with WER and CER.

    Args:
        pred: Trainer evaluation output with ``predictions`` (generated token ids,
            since ``predict_with_generate=True``) and ``label_ids``. Labels use -100
            for padding (see the data collator); generated ids can also be -100-padded
            when the Trainer concatenates batches of different lengths.

    Returns:
        ``{"wer": float, "cer": float}``.
    """
    pad_id = tokenizer.pad_token_id
    # -100 cannot be decoded, so swap it for the pad token (dropped by skip_special_tokens).
    # np.where builds new arrays, leaving the Trainer's buffers untouched.
    pred_ids = np.where(pred.predictions != -100, pred.predictions, pad_id)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_id)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # torchmetrics takes (preds, target); cast the 0-d tensors to plain floats for logging.
    wer_res = float(wer(pred_str, label_str))
    cer_res = float(cer(pred_str, label_str))

    return {"wer": wer_res, "cer": cer_res}

In [10]:
# The cleaned DataFrames are used unchanged as the training / validation sets.
train_set, val_set = data, val_data

for split_label, split_frame in (("Train", train_set), ("Val", val_set)):
    print(f"{split_label} size: {len(split_frame)}")

Train size: 13342
Val size: 1666


In [11]:
def evaluate_model(model, ds_eval, tokenizer, data_collator, batch_size=1, device="cuda"):
    """Generate transcripts for ``ds_eval`` and score them with word error rate.

    Args:
        model: Seq2seq model exposing ``generate`` (here the Whisper model).
        ds_eval: Dataset whose examples ``data_collator`` can batch
            (``input_features`` / ``labels``).
        tokenizer: Tokenizer used to decode generated ids and reference labels.
        data_collator: Collate function that pads features and masks label padding with -100.
        batch_size: Examples per generation batch.
        device: Device the input features are moved to before generation.

    Returns:
        ``(test_wer, predictions, references)``: the WER over all examples as returned
        by the shared ``wer`` metric, plus the decoded prediction and reference strings.
    """
    test_loader = DataLoader(ds_eval, batch_size=batch_size, collate_fn=data_collator)
    model.eval()

    predictions, references = [], []
    for batch in test_loader:
        input_features = batch["input_features"].to(device)
        with torch.no_grad():
            pred_ids = model.generate(input_features)

        # The collator fills label padding with -100, which cannot be decoded (this
        # crashed for batch_size > 1); map it back to the pad token first.
        label_ids = batch["labels"].masked_fill(batch["labels"] == -100, tokenizer.pad_token_id)

        predictions.extend(tokenizer.batch_decode(pred_ids, skip_special_tokens=True))
        references.extend(tokenizer.batch_decode(label_ids, skip_special_tokens=True))

    # torchmetrics expects (preds, target); the old (references, predictions) order
    # normalised the error count by the prediction length instead of the reference length.
    test_wer = wer(predictions, references)

    return test_wer, predictions, references


# Main Loop

In [None]:
set_seed()

# Wrap the cleaned DataFrames as Hugging Face datasets, one entry per split.
train_split, val_split = DS.from_pandas(train_set), DS.from_pandas(val_set)
ds_splits = DatasetDict({"train": train_split, "val": val_split})

print("Dataset Preparation Starts")

# Load audio and tokenize transcripts for every row; the original columns
# are removed from the mapped output.
for split_name in ("train", "val"):
    ds_splits[split_name] = ds_splits[split_name].map(
        prepare_dataset,
        fn_kwargs={"split": split_name},
        remove_columns=ds_splits[split_name].column_names,
    )


 

training_args = Seq2SeqTrainingArguments(
    output_dir=model_id,
    per_device_train_batch_size=12,
    per_device_eval_batch_size=12,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    fp16=True,
    # NOTE(review): 1e-3 is far above the ~1e-5 commonly used to fully fine-tune
    # Whisper checkpoints; kept as-is, but re-check it if the loss diverges.
    learning_rate=1e-3,
    weight_decay=1e-2,
    warmup_steps=2,
    num_train_epochs=1,
    eval_strategy="epoch",
    save_strategy="epoch",  # must match eval_strategy when load_best_model_at_end=True
    predict_with_generate=True,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
    report_to="none",
    remove_unused_columns=False,
)

# Same 448-token cap that prepare_dataset applies when tokenizing labels.
model.generation_config.max_length = 448
# Pin decoding to Bengali transcription. Left unset, Whisper generation detects the
# language per clip, so eval predictions can come out in another language/script.
model.generation_config.language = "bn"
model.generation_config.task = "transcribe"
# Let the language/task settings above define the decoder prompt instead of the
# checkpoint's forced_decoder_ids.
model.generation_config.forced_decoder_ids = None

# Custom AdamW; weight decay is taken from training_args so the two stay in sync
# (1e-2, the same value torch's AdamW uses by default). With None as the scheduler,
# the Trainer creates its own LR schedule from training_args (including warmup_steps).
optimizer = AdamW(
    model.parameters(),
    lr=training_args.learning_rate,
    weight_decay=training_args.weight_decay,
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=ds_splits["train"],
    eval_dataset=ds_splits["val"],
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,
    compute_metrics=compute_metrics,
    optimizers=(optimizer, None),
)


# Fine-tune, then keep the model weights and the processor files together in one folder.
train_output = trainer.train()

output_dir = training_args.output_dir
trainer.save_model(output_dir)
processor.save_pretrained(output_dir)


# Evaluate the model
val_wer, predictions, references = evaluate_model(
    model=model,
    ds_eval=ds_splits["val"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    batch_size=1,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# print(f"Loss at iteration {i}: {trainer.state.log_history[-1]['loss']}")
print(f"Val WER after training: {val_wer}")


print("Process complete.")