# Finetuning Whisper for Punjabi

In [None]:
# Check GPU availability.
# `!nvidia-smi` is IPython shell-capture syntax: it runs the command and yields
# its output lines, so this cell only works inside a notebook/IPython session.
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
# The word 'failed' in the command output is treated as "no GPU attached".
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [2]:
# All the imports
from pathlib import Path
import pandas as pd
import pickle
from copy import deepcopy
from typing import List, Tuple, Dict

from datasets import DatasetDict, Dataset
import soundfile as sf
import pysrt
import json
import jiwer

from typing import Dict, Any, Optional
import logging
import re


In [3]:
# Root of the Punjabi corpus (relative to the notebook's working directory)
# and its three sub-trees.
data_path = Path("../data/Punjabi")
audio_path = data_path / "Audio"  # searched for *.wav files below
transcript_path = data_path / "Text"  # .txt transcripts
srt_path = data_path / "Ground Truth SRT"  # .srt subtitle files

In [None]:
# Load the benchmark list. The first CSV column is read as the index and then
# dropped in favour of a fresh 0..n-1 index.
benchmark_df = pd.read_csv(data_path / "benchmark_list.csv", index_col=0).reset_index(drop=True)
benchmark_df.head()

In [5]:
# Distinct story names and categories from the benchmark list. The story names
# are matched against audio file names further down to pick the "val" split.
benchmark_names = benchmark_df["Story Name"].unique().tolist()
benchmark_categories = benchmark_df["Category"].unique().tolist()


In [None]:
# Per-split lists of file paths (audio, transcript, SRT). Each split gets its
# own deep copy of the template so the lists are not shared between splits.
base_lists = {"audio_path": [], "transcript_path": [], "srt_path": []}
dataset = {"train": deepcopy(base_lists), "val": deepcopy(base_lists)}
dataset

In [None]:
# Assign every recording to a split: a file whose name contains one of the
# benchmark story names goes to "val", everything else goes to "train".
# Directories and files are visited in sorted order so that the resulting lists
# (and any positional indexing into them) are reproducible across runs; plain
# glob() order is filesystem-dependent.
for dir_path in sorted(audio_path.glob("*")):
    # The transcript and SRT trees mirror the audio tree, with "Videos" in the
    # folder name swapped for "Text" / "SRT".
    transcript_dir = transcript_path / dir_path.name.replace("Videos", "Text")
    srt_dir = srt_path / dir_path.name.replace("Videos", "SRT")

    for file_path in sorted(dir_path.glob("*.wav")):
        print(file_path.name)
        is_benchmark = any(benchmark_name in file_path.name for benchmark_name in benchmark_names)
        dataset_type = "val" if is_benchmark else "train"
        print(dataset_type)

        # with_suffix() only touches the final extension, whereas
        # str.replace(".wav", ...) would also rewrite a ".wav" that happens to
        # occur earlier in the file name.
        file_audio_path = str(file_path)
        file_transcript_path = transcript_dir / file_path.with_suffix(".txt").name
        file_srt_path = srt_dir / file_path.with_suffix(".srt").name

        # Append audio, transcript and SRT paths to the chosen split
        dataset[dataset_type]["audio_path"].append(file_audio_path)
        dataset[dataset_type]["transcript_path"].append(file_transcript_path)
        dataset[dataset_type]["srt_path"].append(file_srt_path)

In [17]:
def get_srt_start_end(start, end):
    """Return a pair of SRT timestamps expressed in seconds.

    Args:
        start: start time; an object exposing ``hours``, ``minutes``,
            ``seconds`` and ``milliseconds`` (e.g. pysrt.SubRipTime)
        end: end time; same kind of object as ``start``

    Returns:
        tuple: (start_seconds, end_seconds)
    """
    # Both stamps go through the same field-by-field conversion, start first.
    start_seconds, end_seconds = (
        stamp.hours * 3600 + stamp.minutes * 60 + stamp.seconds + stamp.milliseconds / 1000
        for stamp in (start, end)
    )
    return start_seconds, end_seconds

# def combine_subtitles(srt_data, max_chunk_duration=15.0):
#     """Combine consecutive subtitle chunks into logical sentences.
    
#     Combines subtitles while:
#     1. Respecting maximum duration limits
#     2. Splitting on sentence boundaries (punctuation marks)
#     3. Cleaning up formatting and unwanted symbols
    
#     Args:
#         srt_data: List of pysrt.SubRipItem objects
#         max_chunk_duration: Maximum duration in seconds for combined chunks
    
#     Returns:
#         List of combined pysrt.SubRipItem objects
#     """
#     if not srt_data or max_chunk_duration is None:
#         return srt_data
        
#     # Constants
#     PUNCTUATION_MARKS = {'।', '!', '?', '॥', '...', '...!', '♪'}
#     QUOTE_REPLACEMENTS = {
#         '“': '"',
#         '”': '"', 
#         '‘': "'",
#         '’': "'",
#         '…': '...',
#         '-': ''
#     }
    
#     def clean_subtitle_text(text):
#         """Clean and normalize subtitle text."""
#         # Remove text in square brackets
#         text = re.sub(r'\[.*?\]', '', text)
        
#         # Replace quotes and other symbols
#         for old, new in QUOTE_REPLACEMENTS.items():
#             text = text.replace(old, new)
            
#         # Remove leading/trailing ♪ and whitespace
#         text = text.strip('♪').strip()
        
#         # Normalize whitespace
#         return ' '.join(text.split())
    
#     def should_split_chunk(text):
#         """Check if chunk should be split based on punctuation."""
#         return any(text.endswith(mark) for mark in PUNCTUATION_MARKS)
    
#     combined_srt_data = []
#     current_chunk = None
    
#     for subtitle in srt_data:
#         subtitle_text = clean_subtitle_text(subtitle.text)
        
#         if not subtitle_text:  # Skip empty subtitles
#             continue
            
#         if current_chunk is None:
#             current_chunk = subtitle
#             current_chunk.text = subtitle_text
#             continue
            
#         # Check if combining would exceed duration limit
#         start_sec, end_sec = get_srt_start_end(current_chunk.start, subtitle.end)
#         if end_sec - start_sec <= max_chunk_duration:
#             # Combine chunks
#             current_chunk.end = subtitle.end
#             current_chunk.text = f"{current_chunk.text} {subtitle_text}"
            
#             if should_split_chunk(current_chunk.text):
#                 combined_srt_data.append(current_chunk)
#                 current_chunk = None
#         else:
#             # Duration would be too long, split here
#             combined_srt_data.append(current_chunk)
#             current_chunk = subtitle
#             current_chunk.text = subtitle_text
            
#             if should_split_chunk(current_chunk.text):
#                 combined_srt_data.append(current_chunk)
#                 current_chunk = None
    
#     # Add final chunk if exists
#     if current_chunk is not None:
#         combined_srt_data.append(current_chunk)
        
#     return combined_srt_data

In [18]:
def get_srt_time_str(time_obj, adjust=False):
    """Format an SRT timestamp as seconds snapped to a 0.02 s grid.

    Args:
        time_obj: pysrt.SubRipTime-like object exposing ``hours``,
            ``minutes``, ``seconds`` and ``milliseconds``
        adjust: when true, one grid step (0.02 s) is subtracted after
            rounding; the result can therefore be negative for a time of 0

    Returns:
        str: Time in seconds with two decimals (e.g. "10.50")
    """
    total_seconds = (time_obj.hours * 3600 +
                     time_obj.minutes * 60 +
                     time_obj.seconds +
                     time_obj.milliseconds / 1000)
    # 50 grid steps per second == 0.02 s resolution
    snapped = round(total_seconds * 50) / 50
    if adjust:
        snapped -= 0.02
    return "{:.2f}".format(snapped)

def combine_subtitles(srt_data, max_chunk_duration=15.0):
    """Combine consecutive subtitle chunks with timestamps.

    Every non-empty subtitle is cleaned and wrapped as
    ``<|start|> text <|end|>``, where both times are seconds relative to the
    start of the chunk the subtitle ends up in (the end time is taken with
    ``adjust=True``, i.e. 0.02 s earlier). Consecutive subtitles are merged
    into the current chunk as long as that relative end time does not exceed
    ``max_chunk_duration``; otherwise the current chunk is closed and the
    subtitle starts a new one. A single subtitle longer than the limit is
    still emitted as its own chunk.

    Note: items of ``srt_data`` are modified in place. The first subtitle of
    each chunk is reused as the chunk object and has its ``text`` and ``end``
    overwritten.

    Args:
        srt_data: List of pysrt.SubRipItem objects
        max_chunk_duration: Maximum duration in seconds for combined chunks

    Returns:
        List of combined pysrt.SubRipItem objects with timestamps in text.
        If ``srt_data`` is empty or ``max_chunk_duration`` is None,
        ``srt_data`` itself is returned unchanged (no cleaning, no timestamps).
    """
    if not srt_data or max_chunk_duration is None:
        return srt_data

    # Constants
    QUOTE_REPLACEMENTS = {
        '“': '"',
        '”': '"',
        '‘': "'",
        '’': "'",
        '…': '...',
        '-': ''
    }

    def clean_subtitle_text(text):
        """Clean and normalize subtitle text."""
        # Remove text in square brackets
        text = re.sub(r'\[.*?\]', '', text)

        # Replace quotes and other symbols
        for old, new in QUOTE_REPLACEMENTS.items():
            text = text.replace(old, new)

        # Remove leading/trailing ♪ and whitespace. Whitespace is stripped
        # first: str.strip('♪') alone stops at any space or newline outside
        # the symbol (e.g. the gap left by a removed "[...]" tag) and would
        # leave the ♪ in place.
        text = text.strip().strip('♪').strip()

        # Normalize whitespace
        return ' '.join(text.split())

    def with_timestamps(text, relative_start, relative_end):
        """Wrap text in <|start|> ... <|end|> markers (two decimals)."""
        return f"<|{relative_start:.2f}|> {text} <|{relative_end:.2f}|>"

    combined_srt_data = []
    current_chunk = None
    chunk_start_time = 0.0

    for subtitle in srt_data:
        subtitle_text = clean_subtitle_text(subtitle.text)

        if not subtitle_text:  # Skip empty subtitles
            continue

        start_time = float(get_srt_time_str(subtitle.start))
        end_time = float(get_srt_time_str(subtitle.end, adjust=True))

        # Extend the open chunk while the subtitle still fits in the limit
        if current_chunk is not None and end_time - chunk_start_time <= max_chunk_duration:
            current_chunk.end = subtitle.end
            timestamped_text = with_timestamps(
                subtitle_text,
                start_time - chunk_start_time,
                end_time - chunk_start_time,
            )
            current_chunk.text = f"{current_chunk.text} {timestamped_text}"
            continue

        # Either this is the very first subtitle or the open chunk is full:
        # close the open chunk (if any) and start a new one whose timestamps
        # are counted from this subtitle's start.
        if current_chunk is not None:
            combined_srt_data.append(current_chunk)
        current_chunk = subtitle
        chunk_start_time = start_time
        current_chunk.text = with_timestamps(subtitle_text, 0.0, end_time - chunk_start_time)

    # Add final chunk if exists
    if current_chunk is not None:
        combined_srt_data.append(current_chunk)

    return combined_srt_data

In [19]:
# def get_srt_start_end(start, end):
#     """
#     Get the start and end time in seconds from a pysrt.SubRipItem object
#     """
#     start_sec = (start.hours * 3600 + 
#                 start.minutes * 60 + 
#                 start.seconds + 
#                 start.milliseconds / 1000)
#     end_sec = (end.hours * 3600 + 
#                 end.minutes * 60 + 
#                 end.seconds + 
#                 end.milliseconds / 1000)
#     return start_sec, end_sec



# def combine_subtitles(srt_data, max_chunk_duration=15.0):
#     """
#     Combine consecutive subtitle chunks such that their total duration doesn't exceed max_duration,
#     and split whenever a logical sentence ends with specific punctuation.
    
#     Args:
#         srt_data: List of pysrt.SubRipItem objects
#         max_chunk_duration: Maximum duration in seconds for combined chunks
    
#     Returns:
#         List of combined pysrt.SubRipItem objects
#     """
#     if max_chunk_duration is None:
#         return srt_data
#     combined_srt_data = []
#     current_chunk = None
#     punctuation_marks = ['।', '!', '?', '॥', '...', '...!', '♪']

#     for subtitle in srt_data:
#         # Remove text within square brackets and unwanted punctuation
#         subtitle_text = subtitle.text
#         subtitle_text = re.sub(r'\[.*?\]', '', subtitle_text)  # Remove text in square brackets
        
#         subtitle_text = subtitle_text.replace('“', '"').replace('”', '"').replace('‘', "'").replace('’', "'")  # Replace quotes
#         subtitle_text = subtitle_text.replace('…', '...').replace('-', '')  # Remove ellipsis and replace with ... and remove '-
#         subtitle_text = subtitle_text.strip()  # Remove leading/trailing whitespace
#         subtitle_text = " ".join(subtitle_text.split())

#         # Remove leading and trailing ♪ symbols and treat them as punctuation
#         if subtitle_text.strip().startswith('♪'):
#             subtitle_text = subtitle_text.strip()[1:].strip()

#         # print(subtitle_text)

#         if current_chunk is None:
#             current_chunk = subtitle
#             current_chunk.text = subtitle_text
#         else:
#             # Calculate duration if we combine this subtitle
#             start_sec, end_sec = get_srt_start_end(current_chunk.start, subtitle.end)
#             potential_duration = end_sec - start_sec
            
#             if potential_duration <= max_chunk_duration:
#                 # Combine with current chunk
#                 current_chunk.end = subtitle.end
#                 # Handle the music symbol
#                 current_chunk.text = current_chunk.text.strip().replace('♪', '').strip()
#                 current_chunk.text = f"{current_chunk.text} {subtitle_text}"
                
#                 # Check if the current chunk ends with a punctuation mark
#                 if any(current_chunk.text.endswith(mark) for mark in punctuation_marks):

#                     # Handle the music symbol
#                     current_chunk.text = current_chunk.text.strip().replace('♪', '').strip()
#                     # Add current chunk to results and start new one
#                     combined_srt_data.append(current_chunk)
#                     current_chunk = None  # Reset current_chunk to start a new one
#             else:
#                 # If the potential duration exceeds max_chunk_duration, add the current chunk
#                 # to results and start a new chunk with the current subtitle
#                 # Handle the music symbol
#                 current_chunk.text = current_chunk.text.strip().replace('♪', '').strip()
#                 combined_srt_data.append(current_chunk)
#                 current_chunk = subtitle
#                 current_chunk.text = subtitle_text
                
#                 # Check if the new current_chunk ends with a punctuation mark
#                 if any(current_chunk.text.endswith(mark) for mark in punctuation_marks):

#                     # Handle the music symbol
#                     current_chunk.text = current_chunk.text.strip().replace('♪', '').strip()
#                     combined_srt_data.append(current_chunk)
#                     current_chunk = None  # Reset current_chunk to start a new one

#     # Add the last chunk if it exists
#     if current_chunk is not None:
#         combined_srt_data.append(current_chunk)
    
#     return combined_srt_data

In [None]:
# Sanity check: run combine_subtitles (default 15 s limit) on one validation
# SRT file and print every resulting chunk with its timing and text.
srt_file_path = dataset["val"]["srt_path"][1]
print(srt_file_path)

srt_data = pysrt.open(srt_file_path)
combined_srt_data = combine_subtitles(srt_data)

for i, sub in enumerate(combined_srt_data):
    # Duration is computed directly on the subtitle time objects
    sub_duration = sub.end - sub.start
    print(f"chunk :{i} start: {sub.start} end: {sub.end} duration: {sub_duration}")
    print(sub.text, "\n")

In [21]:
def chunk_dataset(
    dataset: Dict[str, Dict[str, list]], 
    finetuning_data_path: Path,
    min_chunk_duration: float = 1,  # Minimum chunk duration in seconds
    max_chunk_duration: float = 15.0,  # Maximum chunk duration in seconds
    combined_chunk_duration: Optional[float] = None,  # Maximum duration of merged subtitle chunks in seconds
) -> None:
    """
    Chunk each sentence from SRT files into separate chunks of audio and text for finetuning.

    For every split, each SRT file is loaded, passed through
    ``combine_subtitles`` and the matching audio is cut at the resulting
    start/end times. Each chunk is written as
    ``<split>/Audio/<stem>_chunkNNNN.wav`` plus
    ``<split>/Text/<stem>_chunkNNNN.txt`` (UTF-8, whitespace-normalised).
    Errors are logged and the offending chunk or file is skipped; nothing is
    raised to the caller for per-file or per-chunk failures.

    Args:
        dataset: Dictionary containing train/val splits with audio and SRT paths
            (paths may be ``str`` or ``Path``)
        finetuning_data_path: Base path to store chunked data
        min_chunk_duration: Minimum allowed duration for audio chunks in seconds
        max_chunk_duration: Maximum allowed duration for audio chunks in seconds
        combined_chunk_duration: Maximum duration for combined chunks in seconds;
            passed to ``combine_subtitles`` as its ``max_chunk_duration``
            (``None`` leaves the subtitles uncombined)
    """
    # Setup logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    for split in dataset:
        logger.info(f"Processing {split} split...")
        # Create output directories if they don't exist
        audio_out_path = finetuning_data_path / split / "Audio"
        text_out_path = finetuning_data_path / split / "Text"
        audio_out_path.mkdir(parents=True, exist_ok=True)
        text_out_path.mkdir(parents=True, exist_ok=True)

        # Verify required keys exist
        if "srt_path" not in dataset[split] or "audio_path" not in dataset[split]:
            logger.error(f"Missing required paths in {split} split")
            continue

        for srt_file_path, audio_file_path in zip(dataset[split]["srt_path"], dataset[split]["audio_path"]):
            try:
                # Audio paths may be plain strings, so derive name/stem via Path
                # once instead of calling .name on the raw value later.
                audio_name = Path(audio_file_path).name
                audio_stem = Path(audio_file_path).stem

                # Load SRT file
                srt_data = pysrt.open(srt_file_path)

                # Combine subtitles into chunks of combined_chunk_duration
                combined_srt_data = combine_subtitles(srt_data, max_chunk_duration=combined_chunk_duration)

                # Load audio file
                audio_data, sampling_rate = sf.read(audio_file_path)
            except Exception as e:
                # Best-effort processing: a broken file must not stop the run
                logger.error(f"Error processing file {audio_file_path}: {str(e)}")
                continue

            logger.info(f"Processing file: {audio_name}")

            for i, sentence in enumerate(combined_srt_data):
                try:
                    # Convert SRT times to seconds
                    start_sec, end_sec = get_srt_start_end(sentence.start, sentence.end)

                    # Validate chunk duration
                    duration = end_sec - start_sec
                    if duration < min_chunk_duration:
                        logger.warning(f"Skipping chunk {i}: duration too short ({duration:.2f}s)")
                        continue
                    if duration > max_chunk_duration:
                        logger.warning(f"Skipping chunk {i}: duration too long ({duration:.2f}s)")
                        continue

                    # Convert time to samples
                    start_sample = int(start_sec * sampling_rate)
                    end_sample = int(end_sec * sampling_rate)

                    # Validate sample indices
                    if end_sample > len(audio_data):
                        logger.warning(f"Chunk {i} exceeds audio duration, truncating")
                        end_sample = len(audio_data)

                    # A subtitle that starts at or after the end of the audio
                    # would produce an empty wav file; skip it instead.
                    if start_sample >= end_sample:
                        logger.warning(f"Skipping chunk {i}: no audio samples in range")
                        continue

                    # Extract audio chunk
                    audio_chunk = audio_data[start_sample:end_sample]

                    # Generate output paths
                    chunk_name = f"{audio_stem}_chunk{i:04d}"
                    audio_out_file = audio_out_path / f"{chunk_name}.wav"
                    text_out_file = text_out_path / f"{chunk_name}.txt"

                    # Save audio chunk and transcript
                    sf.write(audio_out_file, audio_chunk, sampling_rate)
                    text_out_file.write_text(" ".join(sentence.text.strip().split()), encoding='utf-8')

                except Exception as e:
                    # Only this chunk is skipped; the remaining chunks of the
                    # file are still processed.
                    logger.error(f"Error processing chunk {i} of {audio_name}: {str(e)}")
                    continue
    logger.info("Dataset chunking completed")


In [None]:
# Write the chunked audio/text pairs used for finetuning.
# NOTE(review): the output location is a machine-specific absolute path.
finetuning_data_path = Path("/spiral_hdd_2/workspace/naren/data/timestamped_chunks")
chunk_dataset(
    dataset=dataset,
    finetuning_data_path=finetuning_data_path,
    min_chunk_duration=0.25,  # Minimum quarter second
    max_chunk_duration=30.0,  # Maximum 30 seconds
    combined_chunk_duration=30.0,  # Merge subtitles into chunks of up to 30 seconds
)

In [None]:
# Create the finetuning dataset: per-split lists holding the paths of the
# chunked audio files and their transcripts (each split gets its own copy).
base_lists = {"audio_path": [], "transcript_path": []}
finetuning_dataset = {"train": deepcopy(base_lists), "val": deepcopy(base_lists)}
finetuning_dataset

In [24]:
import numpy as np

# Per-split cap on how many chunks are listed. np.inf disables the cap; use
# small integers (e.g. 100 for train, 10 for val) for a quick smoke test.
train_cutoff = np.inf
val_cutoff = np.inf
split_cutoffs = {"train": train_cutoff, "val": val_cutoff}

for split in finetuning_dataset:
    chunk_audio_dir = finetuning_data_path / split / "Audio"
    chunk_text_dir = finetuning_data_path / split / "Text"
    cutoff = split_cutoffs.get(split, np.inf)

    # The loop variables are deliberately NOT called audio_path /
    # transcript_path: those names hold the corpus directories defined near
    # the top of the notebook and must not be overwritten here.
    # sorted() makes the dataset order reproducible across runs.
    for i, chunk_audio_file in enumerate(sorted(chunk_audio_dir.glob("*.wav"))):
        if i >= cutoff:
            break
        # The transcript shares the chunk's file name, with a .txt extension
        chunk_text_file = chunk_text_dir / chunk_audio_file.with_suffix(".txt").name
        finetuning_dataset[split]["audio_path"].append(str(chunk_audio_file))
        finetuning_dataset[split]["transcript_path"].append(str(chunk_text_file))

In [None]:
# Inspect the collected training paths (audio list first, then transcripts).
print(finetuning_dataset["train"]["audio_path"])
print(finetuning_dataset["train"]["transcript_path"])

In [None]:
# Number of validation chunks
len(finetuning_dataset["val"]["audio_path"])

In [None]:
# Number of training chunks
len(finetuning_dataset["train"]["audio_path"])


In [None]:
# Show the first two validation audio/transcript path pairs
# (requires at least two validation chunks).
print(finetuning_dataset["val"]["audio_path"][0])
print(finetuning_dataset["val"]["transcript_path"][0])
print(finetuning_dataset["val"]["audio_path"][1])
print(finetuning_dataset["val"]["transcript_path"][1])

In [None]:
# Print the first two validation transcripts. The chunk transcripts are written
# as UTF-8, so read them back with the same encoding rather than the
# platform default (which is not UTF-8 on every system).
for sample_transcript_path in finetuning_dataset["val"]["transcript_path"][:2]:
    with open(sample_transcript_path, "r", encoding="utf-8") as f:
        print(f.read())

## Create Whisper Feature Extractor, Tokenizer and Processor

In [30]:
# Checkpoint to finetune plus the language/task passed to the tokenizer,
# the processor and the model's generation config below.
model_name = "openai/whisper-large-v2"
language = "Punjabi"
task = "transcribe"

In [None]:
from transformers import WhisperFeatureExtractor

# Feature extractor of the chosen checkpoint; used by prepare_dataset to
# compute the model's input features from raw audio.
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)

In [32]:
from transformers import WhisperTokenizer

# Tokenizer of the chosen checkpoint, configured with the target language and
# task; used to encode transcripts into label ids and to decode predictions.
tokenizer = WhisperTokenizer.from_pretrained(model_name, language=language, task=task)

In [33]:
from transformers import WhisperProcessor

# Processor of the chosen checkpoint; the data collator uses its
# feature_extractor and tokenizer attributes for padding.
processor = WhisperProcessor.from_pretrained(model_name, language=language, task=task)

In [34]:
def prepare_dataset(batch):
    """Turn one ``{audio_path, transcript_path}`` row into training features.

    Args:
        batch: a single dataset row with ``audio_path`` and
            ``transcript_path`` entries

    Returns:
        The same row with these entries added: ``input_features`` (feature
        extractor output for the audio), ``array`` (the waveform),
        ``sampling_rate`` and ``labels`` (token ids of the transcript).
    """
    # Load audio file
    audio_array, sampling_rate = sf.read(str(batch["audio_path"]))

    # soundfile returns a 2-D (frames, channels) array for multi-channel files.
    # The feature extractor is meant to receive one 1-D waveform here, so
    # average the channels down to mono; mono files are left untouched.
    if audio_array.ndim > 1:
        audio_array = audio_array.mean(axis=1)

    # Read transcript
    with open(batch["transcript_path"], 'r', encoding='utf-8') as file:
        transcript = file.read().strip()

    # Compute log-Mel input features from input audio array
    # NOTE(review): the file's own sampling rate is forwarded as-is; confirm
    # the chunks are stored at the rate the feature extractor expects.
    batch["input_features"] = feature_extractor(audio_array, sampling_rate=sampling_rate).input_features[0]

    # Add array and sampling rate to batch
    batch["array"] = audio_array
    batch["sampling_rate"] = sampling_rate

    # Encode target text to label ids; longer transcripts are truncated.
    max_label_length = 448  # cap on the label sequence length
    batch["labels"] = tokenizer(transcript, padding=True, truncation=True, max_length=max_label_length).input_ids

    return batch

In [None]:
# Wrap the per-split path lists in Hugging Face Dataset objects; the audio and
# transcripts themselves are only loaded later by prepare_dataset.
whisper_dataset = DatasetDict()
whisper_dataset["train"] = Dataset.from_dict(finetuning_dataset["train"])
whisper_dataset["val"] = Dataset.from_dict(finetuning_dataset["val"])
len(whisper_dataset["train"])

In [None]:
import os

# Run prepare_dataset over both splits. The worker count is bounded by the
# number of CPU cores (and by the number of rows, with a floor of one) --
# asking for one process per example, as len(dataset) would, spawns thousands
# of processes on a realistically sized dataset.
for split_name in ("train", "val"):
    num_workers = max(1, min(os.cpu_count() or 1, len(whisper_dataset[split_name])))
    whisper_dataset[split_name] = whisper_dataset[split_name].map(prepare_dataset, num_proc=num_workers)

In [None]:
# Row count of the processed training split
len(whisper_dataset["train"])

In [None]:
# Pickle file (next to the chunked data) that caches the processed dataset
whisper_dataset_path = finetuning_data_path / "whisper_dataset_1200_15.pkl"
whisper_dataset_path

In [None]:
# Cache the processed DatasetDict so feature extraction need not be repeated
with open(whisper_dataset_path, "wb") as f:
    pickle.dump(whisper_dataset, f)

# Load the whisper dataset for finetuning


In [None]:
# Restore the cached DatasetDict.
# NOTE: pickle.load executes arbitrary code from the file -- only load a
# pickle that this notebook (or another trusted source) produced.
with open(whisper_dataset_path, "rb") as f:
    whisper_dataset = pickle.load(f)


In [None]:
# Tokenize the first training transcript as a quick check of the tokenizer
# output. The transcripts are UTF-8 files, so the encoding is given explicitly
# instead of relying on the platform default.
with open(whisper_dataset["train"][0]["transcript_path"], "r", encoding="utf-8") as f:
    print(tokenizer(f.read()))

In [None]:
# Display the dataset summary (splits, columns, row counts)
whisper_dataset

## Load the Whisper Model and Set Up the Trainer

In [None]:
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch
from transformers import WhisperForConditionalGeneration

# Load the pretrained checkpoint. The model is not moved to a device here;
# the commented-out line below is the variant with explicit device placement.
# model = WhisperForConditionalGeneration.from_pretrained(model_name).to("cuda" if torch.cuda.is_available() else "cpu")
model = WhisperForConditionalGeneration.from_pretrained(model_name)

In [28]:
# Configure generation: fix the language and task, and clear any forced
# decoder ids stored in the checkpoint's generation config.
model.generation_config.language = language
model.generation_config.task = task
model.generation_config.forced_decoder_ids = None

In [29]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate prepared rows into one padded batch of tensors.

    Audio features and label ids are padded separately because they have
    different lengths and go through different padding methods.
    """

    # Object whose ``feature_extractor.pad`` and ``tokenizer.pad`` do the padding
    processor: Any
    # Token id that is cut from the front of the labels when every row starts with it
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad ``input_features`` and ``labels`` of ``features`` into a batch.

        Returns:
            The padded audio batch with an added ``labels`` tensor in which
            padded positions are set to -100.
        """
        # audio side: hand the features to the feature extractor, get torch tensors back
        audio_inputs = [{"input_features": row["input_features"]} for row in features]
        batch = self.processor.feature_extractor.pad(audio_inputs, return_tensors="pt")

        # label side: collect the tokenized label sequences and pad them
        tokenizer = self.processor.tokenizer
        label_inputs = [{"input_ids": row["labels"]} for row in features]
        padded_labels = tokenizer.pad(label_inputs, return_tensors="pt", max_length=tokenizer.model_max_length)

        # positions whose attention mask is not 1 are padding; -100 makes the
        # loss ignore them
        is_padding = padded_labels.attention_mask.ne(1)
        labels = padded_labels["input_ids"].masked_fill(is_padding, -100)

        # if every row already starts with the decoder start (bos) token, drop
        # that column, since it gets added again later
        every_row_has_bos = (labels[:, 0] == self.decoder_start_token_id).all().cpu().item()
        batch["labels"] = labels[:, 1:] if every_row_has_bos else labels

        return batch

In [30]:
# Collator instance used by the trainer; the decoder start token id comes from
# the model config so the collator can detect (and drop) a leading start token.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

In [None]:
import evaluate

# Word error rate metric, used by compute_metrics below
metric = evaluate.load("wer")

def compute_metrics(pred):
    """Decode predictions and labels and report the word error rate.

    Args:
        pred: prediction object exposing ``predictions`` and ``label_ids``
            (``label_ids`` is modified in place)

    Returns:
        dict: ``{"wer": <WER in percent>}``
    """
    label_ids = pred.label_ids

    # -100 marks ignored label positions; put the pad token back (in place)
    # so the ids can be decoded
    ignored = label_ids == -100
    label_ids[ignored] = tokenizer.pad_token_id

    # decode both sides without special tokens; empty strings are kept as-is
    # (no stripping or filtering is applied)
    decode = tokenizer.batch_decode
    pred_str = decode(pred.predictions, skip_special_tokens=True)
    label_str = decode(label_ids, skip_special_tokens=True)

    print("Predictions: \n", pred_str)
    print("Labels: \n", label_str)

    # metric.compute returns a fraction; report it as a percentage
    return {"wer": 100 * metric.compute(predictions=pred_str, references=label_str)}

In [None]:
# Quick check of the metric on toy English sentences with deliberate typos
metric.compute(predictions=["Hi, how are iyou?", "Fine, thank youu", "Fine, thank youu"], references=["Hi, how are you?", "Fine, thank you", "Fine, thank you"])

In [33]:
# Directory for training checkpoints and the final saved model. Enter a new
# path per run. NOTE(review): this is a machine-specific absolute path.
checkpoints_path = "/spiral_hdd_2/workspace/naren/data/models/whisper-large-v2-pa-4"

In [34]:
import torch
from transformers import Seq2SeqTrainingArguments

# Training configuration: a short run (250 steps) that evaluates and logs
# every 5 steps, saves every 25 steps and keeps the checkpoint with the
# lowest evaluation loss.
training_args = Seq2SeqTrainingArguments(
    output_dir=checkpoints_path,  # directory where checkpoints are written
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-4,
    warmup_steps=10,
    max_steps=250,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=25,  # a multiple of eval_steps
    eval_steps=5,
    logging_steps=5,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    save_total_limit=2,
    metric_for_best_model="loss",
    greater_is_better=False,  # lower loss is better
    push_to_hub=False,
)

In [35]:
# from transformers import EarlyStoppingCallback

# # Add EarlyStoppingCallback
# early_stopping_callback = EarlyStoppingCallback(
#     early_stopping_patience=3  # Number of evaluations with no improvement before stopping
# )

In [None]:
from transformers import Seq2SeqTrainer

# Wire model, data, collator and metric together. Early stopping is currently
# disabled (see the commented-out callback).
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=whisper_dataset["train"],
    eval_dataset=whisper_dataset["val"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,  # the feature extractor, not the text tokenizer, is passed here
    # callbacks=[early_stopping_callback]
)

In [None]:
# Run finetuning with the configuration above
trainer.train()

In [38]:
# Save the trainer's current model into checkpoints_path (the same directory
# used as output_dir). With load_best_model_at_end=True this is presumably the
# best checkpoint -- confirm against the training log.
trainer.save_model(output_dir=checkpoints_path)

## Convert the Hugging Face Model to an OpenAI Whisper-style `.pt` Checkpoint

In [2]:
from collections import OrderedDict

def convert_hf_to_openai_checkpoint(hf_model, model_size="large-v2"):
    """
    Convert a HF Whisper model to an OpenAI Whisper-style checkpoint
    with correct key naming for 'attn_ln', 'mlp_ln', 'cross_attn', etc.

    Args:
        hf_model: object exposing ``state_dict()`` (HF parameter names) and
            ``config.vocab_size``.
        model_size: which official dimension set to embed in the checkpoint;
            one of ``"tiny"`` or ``"large-v2"``.

    Returns:
        dict with keys ``"model_state_dict"`` (an ``OrderedDict`` of renamed
        keys mapped to the *same* value objects, nothing is copied),
        ``"dims"``, ``"name"`` and ``"multilingual"``.

    Raises:
        ValueError: if ``model_size`` has no dimension table. Previously this
            surfaced as an unbound-local error at the ``return``, after all
            the renaming work had been done.
    """
    # (n_state, n_head, n_layer) per size, from whisper/model.py; encoder and
    # decoder use the same values for these sizes.
    size_table = {
        "tiny": (384, 6, 4),
        "large-v2": (1280, 20, 32),
    }
    if model_size not in size_table:
        raise ValueError(
            f"Unsupported model_size {model_size!r}; "
            f"expected one of {sorted(size_table)}"
        )
    n_state, n_head, n_layer = size_table[model_size]
    dims = {
        "n_mels": 80,
        "n_audio_ctx": 1500,
        "n_audio_state": n_state,
        "n_audio_head": n_head,
        "n_audio_layer": n_layer,
        "n_vocab": hf_model.config.vocab_size,
        "n_text_ctx": 448,
        "n_text_state": n_state,
        "n_text_head": n_head,
        "n_text_layer": n_layer,
    }

    # HF module stem -> OpenAI module stem. Each is applied to both the
    # ".weight" and ".bias" parameter names.
    param_stems = (
        # Self-attention: HF "self_attn.{q,k,v,out}_proj" -> OAI "attn.{query,key,value,out}"
        ("self_attn.q_proj", "attn.query"),
        ("self_attn.k_proj", "attn.key"),
        ("self_attn.v_proj", "attn.value"),
        ("self_attn.out_proj", "attn.out"),
        # Cross-attention (decoder only)
        ("encoder_attn.q_proj", "cross_attn.query"),
        ("encoder_attn.k_proj", "cross_attn.key"),
        ("encoder_attn.v_proj", "cross_attn.value"),
        ("encoder_attn.out_proj", "cross_attn.out"),
        ("encoder_attn_layer_norm", "cross_attn_ln"),
        # MLP (feed-forward)
        ("fc1", "mlp.0"),
        ("fc2", "mlp.2"),
        # Per-block layer norms
        ("self_attn_layer_norm", "attn_ln"),
        ("final_layer_norm", "mlp_ln"),
    )
    # Top-level layer norms: the encoder keeps "ln_post", the decoder uses "ln".
    final_norm_stems = (
        ("encoder.layer_norm", "encoder.ln_post"),
        ("decoder.layer_norm", "decoder.ln_post"),
    )

    def rename_key(key: str) -> str:
        new_key = key

        # Strip only the leading "model." prefix. A plain str.replace would
        # also delete any later "model." inside the key.
        if new_key.startswith("model."):
            new_key = new_key[len("model."):]

        # "{encoder,decoder}.layers" -> "{encoder,decoder}.blocks"
        new_key = new_key.replace("encoder.layers", "encoder.blocks")
        new_key = new_key.replace("decoder.layers", "decoder.blocks")

        # Token/pos embeddings (the positional embedding drops its ".weight")
        new_key = new_key.replace("embed_tokens.weight", "token_embedding.weight")
        new_key = new_key.replace("embed_positions.weight", "positional_embedding")

        for hf_stem, oai_stem in param_stems:
            for suffix in (".weight", ".bias"):
                new_key = new_key.replace(hf_stem + suffix, oai_stem + suffix)

        # Legacy spellings "ln_attn"/"ln_mlp" -> OAI's "attn_ln"/"mlp_ln"
        new_key = new_key.replace("ln_attn", "attn_ln")
        new_key = new_key.replace("ln_mlp", "mlp_ln")

        for hf_stem, oai_stem in final_norm_stems:
            for suffix in (".weight", ".bias"):
                new_key = new_key.replace(hf_stem + suffix, oai_stem + suffix)
        new_key = new_key.replace("decoder.ln_post.", "decoder.ln.")

        return new_key

    openai_sd = OrderedDict(
        (rename_key(old_key), val) for old_key, val in hf_model.state_dict().items()
    )

    # The output projection is dropped: the OpenAI-style state dict has no
    # entry for it. (Alternative if a loader expects it: rename it instead,
    # e.g. to "decoder.proj.weight".)
    openai_sd.pop("proj_out.weight", None)

    return {
        "model_state_dict": openai_sd,
        "dims": dims,
        "name": model_size,
        # If it's the .en variant, set this to False; otherwise True
        "multilingual": True
    }


In [1]:
# Run directory and the checkpoint inside it that gets converted below.
# NOTE: this is a different directory from `checkpoints_path` used for
# training above (".../data/models/whisper-large-v2-pa-4").
checkpoint_path = "/spiral_hdd_2/workspace/naren/models/whisper-large-v2-pa-5"
best_checkpoint_path = "/spiral_hdd_2/workspace/naren/models/whisper-large-v2-pa-5/checkpoint-best"


In [None]:
# Load the fine-tuned HF checkpoint that will be converted.
# NOTE(review): WhisperForConditionalGeneration is not imported in the nearby
# cells - presumably imported earlier in the notebook; confirm after a kernel restart.
model = WhisperForConditionalGeneration.from_pretrained(best_checkpoint_path)

In [None]:
# Convert
print("Converting HF model to OpenAI-like format...")
# Returns a plain dict (state dict + dims + metadata), not a model object.
whisper_model = convert_hf_to_openai_checkpoint(model, model_size="large-v2")

# Save the converted checkpoint
# (written next to the HF checkpoints, inside the run directory)
converted_ckpt_path = f"{checkpoint_path}/best_checkpoint.pt"
torch.save(whisper_model, converted_ckpt_path)
print(f"Saved converted checkpoint to {converted_ckpt_path}")

## Transcription

In [11]:
# Single benchmark recording used for the transcription experiments below.
test_audio_path = data_path / "Audio" / "AniBook Videos" / "Abdul_Kalam,_Missile_Man_Punjabi.wav"

# Read the samples together with the file's own sampling rate (no resampling happens here).
test_audio, sampling_rate = sf.read(test_audio_path)
# test_input_features = processor(test_audio, sampling_rate=sampling_rate, return_tensors="pt").input_features.to(model.device)

In [None]:
# Read the ground-truth SRT file and print each subtitle with its timing
# (nothing is stored; this is for inspection only)
import pysrt

test_srt_path = data_path / "Ground Truth SRT" / "AniBook SRT" / "Abdul_Kalam,_Missile_Man_Punjabi.srt"

srt_file = pysrt.open(test_srt_path)
for subtitle in srt_file:
    print(f"Start: {subtitle.start}, End: {subtitle.end}, Text: {subtitle.text}")

# Load the Whisper model and run inference

In [None]:
import whisper

# Load the converted fine-tuned checkpoint from disk. Swap in the commented
# line below to use the stock "large-v2" weights instead.
converted_ckpt_path = "/spiral_hdd_2/workspace/naren/models/whisper-large-v2-pa-5/best_checkpoint.pt"
whisper_model = whisper.load_model(converted_ckpt_path)
# whisper_model = whisper.load_model("large-v2")

In [None]:
import numpy as np

# Cast to float32 before handing the array to Whisper.
# NOTE(review): the array is later passed to transcribe() without its
# sampling rate; Whisper presumably expects 16 kHz mono input - confirm the
# WAV file matches, since no resampling or down-mixing is done in this notebook.
test_audio = test_audio.astype(np.float32)
total_duration = len(test_audio) / sampling_rate
print(f"Total duration of the audio file: {total_duration} seconds")

In [47]:
# Create clip timestamps at 10 second intervals, with each timestamp repeated
# to create pairs for start/end times up to 290 seconds
# clip_timestamps = []
# for i in range(0, 290, 10):
#     clip_timestamps.extend([i, i+10])
# print(clip_timestamps)


In [None]:
import numpy as np

# Transcribe using whisper model
# The Punjabi prompt is passed as `initial_prompt` to condition the decoder.
initial_prompt = "ਸਤ ਸ੍ਰੀ ਅਕਾਲ, ਇਹ ਇੱਕ ਪੰਜਾਬੀ ਆਡੀਓ ਫਾਇਲ ਹੈ, ਇਸ ਦਾ ਸਮੱਗਰੀ ਹੇਠ ਲਿਖੀ ਹੈ।"
# clip_timestamps = "0,10,10,20,30"
# total duration of the audio file

# Variant that restricts decoding to explicit clips (currently disabled):
# result = whisper_model.transcribe(test_audio, language="pa", word_timestamps=True, initial_prompt=initial_prompt, clip_timestamps=clip_timestamps)
result = whisper_model.transcribe(test_audio, language="pa", word_timestamps=True, initial_prompt=initial_prompt)
print(result["text"])

In [None]:
# Keep the full predicted text for the WER computation below, and display it.
pred_transcript = result["text"]
result["text"]

In [None]:

# Print every predicted segment with its start/end time.
for segment in result["segments"]:
    print(f"Start: {segment['start']}, End: {segment['end']}, Text: {segment['text']}")

In [13]:
GT_text_path = "/home/arjun/naren/alignment/data/Punjabi/Text/AniBook Text/Abdul_Kalam,_Missile_Man_Punjabi.txt"
# The transcript is Gurmukhi text; decode it as UTF-8 explicitly instead of
# relying on the platform's locale encoding.
with open(GT_text_path, "r", encoding="utf-8") as f:
    gt_text = f.read()

# Join the transcript's lines into one space-separated string.
gt_transcript = " ".join(gt_text.split("\n"))

In [None]:
import jiwer
# Word-level error breakdown of the prediction against the ground truth.
measures = jiwer.compute_measures(gt_transcript, pred_transcript)
gt_word_count = len(gt_transcript.split())
pred_word_count = len(pred_transcript.split())
report_lines = [
    f"WER: {measures['wer']}",
    f"Insertions: {measures['insertions']}",
    f"Deletions: {measures['deletions']}",
    f"Substitutions: {measures['substitutions']}",
    f"Ground Truth Total Words: {gt_word_count}",
    f"Predicted Total Words: {pred_word_count}",
]
print("\n".join(report_lines))

In [None]:
from pydub import AudioSegment
import numpy as np
import whisper

# Load your Whisper model
# NOTE: this loads the stock "large-v2" weights and rebinds `whisper_model`,
# replacing the converted fine-tuned checkpoint loaded in an earlier cell.
whisper_model = whisper.load_model("large-v2")

# Load your audio file using PyDub
audio = AudioSegment.from_file(test_audio_path)

# Chunking parameters, in milliseconds: a new chunk starts every 10 s and
# each chunk is 13 s long, so consecutive chunks overlap by 3 s.
chunk_duration_ms = 10000  # 10 seconds (stride between chunk starts)
chunk_duration_ms_overlap = 13000  # 13 seconds (length of each chunk)
chunks = [audio[i:i + chunk_duration_ms_overlap] for i in range(0, len(audio), chunk_duration_ms)]

# Convert each chunk to numpy array
def audio_to_numpy(audio_segment):
    """Return *audio_segment* as a float32 waveform in [-1, 1] for Whisper.

    The segment is first converted to 16 kHz, mono, 16-bit samples. The array
    is passed to ``transcribe()`` without any sample-rate information, and the
    ``2 ** 15`` scaling below is only correct for 16-bit samples. For input
    that is already 16 kHz / mono / 16-bit these conversions change nothing.

    Args:
        audio_segment: a pydub ``AudioSegment`` (or any object offering the
            same ``set_frame_rate`` / ``set_channels`` / ``set_sample_width``
            / ``get_array_of_samples`` methods).

    Returns:
        1-D ``numpy.ndarray`` of dtype float32.
    """
    audio_segment = (
        audio_segment
        .set_frame_rate(16000)   # sample rate Whisper works at
        .set_channels(1)         # down-mix so samples are not channel-interleaved
        .set_sample_width(2)     # 16-bit, matching the 2 ** 15 normalisation
    )
    samples = np.array(audio_segment.get_array_of_samples())
    return samples.astype(np.float32) / (2 ** 15)  # Normalize to [-1, 1]

chunk_arrays = [audio_to_numpy(chunk) for chunk in chunks]

# Transcribe each chunk
outputs = []
# Same Punjabi prompt as the full-file transcription. It now ends with the
# Gurmukhi danda ("।") instead of the Arabic-script full stop ("۔"), so the
# prompt stays in a single script.
initial_prompt = "ਸਤ ਸ੍ਰੀ ਅਕਾਲ, ਇਹ ਇੱਕ ਪੰਜਾਬੀ ਆਡੀਓ ਫਾਇਲ ਹੈ, ਇਸ ਦਾ ਸਮੱਗਰੀ ਹੇਠ ਲਿਖੀ ਹੈ।"
for chunk in chunk_arrays:
    # Each chunk is transcribed independently, so the printed segment times
    # are relative to the start of that chunk, not to the whole recording.
    result = whisper_model.transcribe(chunk, language="pa", word_timestamps=True, initial_prompt=initial_prompt)
    outputs.append(result)
    for segment in result["segments"]:
        print(f"Start: {segment['start']}, End: {segment['end']}, Text: {segment['text']}")
    print("-"*100)

In [None]:
# Re-print the stored per-chunk results, one separator line per chunk.
# Each chunk was transcribed on its own, so the times restart for every chunk.
for output in outputs:
    for segment in output["segments"]:
        print(f"Start: {segment['start']}, End: {segment['end']}, Text: {segment['text']}")
    print("-"*100)

In [None]:
# Read and display the ground-truth transcript. It is Gurmukhi text, so
# decode it as UTF-8 explicitly instead of using the locale encoding.
with open(data_path / "Text" / "AniBook Text" / "Abdul_Kalam,_Missile_Man_Punjabi.txt", "r", encoding="utf-8") as f:
    gt_text = f.read()
    print(gt_text)

In [None]:
# pred_transcript = result["text"]
# Collapse every whitespace run (including newlines) to a single space.
gt_transcript = " ".join(gt_text.split())

In [None]:
gt_transcript

In [16]:
pred_transcript = "ਮਿਜ਼ਾਇਲ ਮੈਂ ਅਦਣੀ ਮਿਜ਼ਾਇਲ ਨੂੰ ਲੰਜ ਕਰਨ ਵਿੱਚ ਹੋ ਰਹੀਆਂ ਵਾਰਵਾਰ ਅਸਫ਼ਲਦਾਵਾਂ ਅਤੇ ਦੇਰੀ ਨੇ ਪਰੈਸ ਵਿੱਚ ਕਾਫ਼ੀ ਗੱਲਾਂ ਕਰਵਾ ਦਿੱਤੀਆਂ ਸਨ ਰਾਕੇਟ ਵਿਗਿਆਨ ਵਿੱਚ ਅਜੇ ਹੀਆਂ ਦੇਰੀਆਂ ਆਮ ਹੈਂ ਪਰ ਦੇਸ਼ ਸਮਝਣ ਦੇ ਮੋੜ ਵਿੱਚ ਨਹੀਂ ਸੀ ਸਾਡੀਆਂ ਮੁਸ਼ਕਿਲਾਂ ਸਿਰਫ਼ ਸਾਡੀਆਂ ਨੇ ਇਹਨਾਂ ਦੇਰਿਆਂ ਦਾ ਆਪਣੇ ਅਨੂਸਾਰ ਮਤਲਬ ਕੱਢੇ ਲਿਆ ਸੀ ਹਤੇ ਕੁੱਝ ਅਨੋਖੇ ਵਿਕਲਬ ਸੁਝਾਏ। ਅਮੂਲ ਨੇ ਇਹ ਵੀ ਇੱਕ ਘਾਟੂਣ ਬਣਾਇਆ ਸੀ ਜਿਸ ਵਿੱਚ ਸੁਝਾਓ ਦਿੱਤਾ ਸੀ ਕਿ ਅਗਨੀ ਨੂੰ ਇਹਦਨ ਵੱਜੋਂ ਉਨਾਂ ਦਾ ਮੱਖਣ ਵਰਤਣਾ ਚਾਹੀਦਾ ਹੈ ਪਰਿਸ ਦੇਰੀ ਘਾਰਨ ਗੁੱਸੇ ਵਿੱਚ ਸੀ ਅਤੇ ਮੋਰੀ ਤਰ੍ਹਾਂ ਨਰਾਸ਼ ਹੋ ਗਏ ਸੀ। ਮੈਨੂੰ ਕਹੀ ਮਾਕੇ ਯਾਦ ਆਏ ਜਦੋਂ ਮੇਰੀ ਟੀਮ ਦੇ ਲੀਡਰਾਂ ਨੇ ਸਾਡਾ ਮਨੋਬਲੋ ਅਧਾਇਆ ਸੀ ਅਦੇ ਮੈਂ ਵੀ ਅਜਿਹਾ ਕਰਨ ਬਾਰੇ ਸੋਚਿਆ ਮੈਂ ਇੱਕ ਮੀਟਿੰਗ ਲਾਈ ਅਤੇ ਦੋ ਹਜ਼ਾਰ ਮੈਂਬਰਾਂ ਦੀ ਆਪਣੀ ਟੀਮ ਨੂੰ ਸਮੋਦਿਤ ਕੀਤਾ ਸਾਨੂੰ ਇੱਕ ਵਧੀਆ ਮਾਕਾ ਦਿੱਤਾ ਗਨਾਲ ਉਹਨੀਆਂ ਹੀ ਵੱਡੀਆਂ ਚਾਣਾਉਤੀਆਂ ਵੀ ਆਉਂਦੀਆਂ ਹੈਂ, ਅਸੀਂ ਹਾਰ ਨਹੀਂ ਮੰਨ ਸਕਦੇ, ਸਾਡਾ ਦੇਸ਼ ਸਾਡੇ ਵੱਲੋਂ ਕਿਸੀ ਵੀ ਚੀਜ਼ ਦੀ ਕਮੀ ਦਾ ਹੱਕਦਾਰ ਨਹੀਂ ਹੈ। ਮੈਂ ਆਪਣੇ ਲੋਕਾਂ ਨੂੰ ਇਹ ਕਹਿ ਕੇ ਆਪਣੀ ਗੱਲ ਖ਼ਤਮ ਕੀਤੀ, ਮੈਂ ਤੁਹਾਨੂੰ ਵਾਦਾ ਕਰਦਾ ਹਾਂ, ਕਿ ਜਾਡੇ ਅਗਲੇ ਟਰਾਇਲ ਵਿੇਰੀ ਟੀਮ ਨੂੰ ਆਪਣੀ ਉਰਜਾ ਮੁੜ ਮਿਲ ਗਈ ਸੀ, ਆਪਣੇ ਜੋਸ਼ ਨੂੰ ਮੁੜ ਤਾਜਾ ਕਰਦਿਆਂ ਹੁਣਾਂ ਨੇ ਬਹੁਤ ਧਿਆਨ ਲੱਗਾ ਕਿ ਅਤੇ ਇੱਛਾ ਸ਼ਕਤੀ ਨਾਲ ਕੰਮ ਕੀਤਾ। ਆਖਰ ਘਾਰ ਬਾਈ ਮਾਈ ਉੱਨੀ ਸੋਂਰਾਨਵੇਂ ਨੂੰ ਲੰਜ ਦਹਿ ਕਰ ਦਿੱਦਾ ਗਿਆ ਸੀ। ਆਰਮੀ ਸਟਾਫ਼ ਤੇ ਮੁੱਖੀ ਅਤੇ ਰੱਖਿਆ ਮੰਦਰੀ ਵਰਗੇ ਵੱਛਲੀ ਰਾਤ ਸਾਡੇ ਵਿਚੋਂ ਕੁੱਝ ਲੋਕ ਸਮੁੰਦਰੀ ਕੰਢੇ ਉੱਤੇ ਸਹਿਰ ਕਰਨ ਗਏ ਹੋਏ ਸਨ ਜੇ ਉਹ ਪੂਰੇ ਜਨ ਨਾਲ ਜਗਮਗਾ ਰਿਹਾ ਸੀ ਲਹਿਰਾਂ ਗਰਜ਼ ਕਿਉਠੀਆਂ ਅਤੇ ਚੱਟਾਂ ਨਾਲ ਉੱਤੇ ਆ ਕੇ ਡਿੱਗੀਆਂ ਕਿ ਅਸੀਂ ਕੱਲ੍ਹ ਅਗਣੀ ਲਾਂਜ ਕਰਨ ਵਿੱਚ ਸਫ਼ਲ ਹੋਵਾਂਗੇ ਇਹ ਸਵਾਲ ਹਰ ਕਿਸੇ ਦੇ ਦਤਿਆਰ ਨਹੀਂ ਸੀ। ਆਖਰ ਘਾਰ ਰੱਖਿਆ ਮਨਤਰੀ ਨੇ ਚੁੱਪੀ ਤੋੜੀ ਅਤੇ ਮੈਨੂੰ ਪੁੱਛਿਆ। ਕਲਾਮ ਕੱਲ੍ਹ ਅਸੀਂ ਅਗਣੀ ਦੀ ਸਫ਼ਲ ਦਾ ਕਿਵੇਂ ਮਨਾਈਏ ਕਿ ਤੁਹਾਡੇ ਦਿਲਦੀ ਕੋਈ ਇੱਛਾ ਹੈ ਇਹ ਇੱਕ ਸਰਲ ਸਵਾਲ�ਂ ਸੀ? 
ਫਿਰ ਮੈਂ ਇਹਨੂੰ ਜਵਾਬ ਮਿਲਿਆ। ਸਾਨੂੰ ਸਾਡੇ ਖੁੱਝ ਕੇ ਇੰਦਰ ਵਿੱਚ ਲਗਾਉਣ ਲਈ ਇੱਕ ਇਲੱਕ ਕਿਪੂਟੇ ਚਾਹੀਦੇ ਹਨ। ਮੈਂ ਕਿਹਾ ਮਨਤਰੀ ਜੀ ਦਾ ਚਿਹਰਾ ਦੋਸਤਾ ਨਾ ਚਮਕ ਨਾਲ ਚਮਕ ਉਠਿਆ ਤੁਸੀਂ ਅਗਣੀ ਲਈ ਤਰਤੀ ਮਾਂ ਦਾ ਅਸ਼ੀਰਤੇ ਅਗਨੀ ਲਈ ਤਰਤੀਮਾਂ ਦਾ ਅਸ਼੍ਰਵਾਰ ਲੈ ਰਹੇ ਹੋ। ਅਸੀਂ ਕੱਲ੍ਹ ਜ਼ਰੂਲ ਸਫ਼ਲ ਹੋਵਾਂਗੇ, ਹੁਣਾਂ ਨੇ ਨਿੱਕੀ ਪਵਿਕਵਾਣੀ ਕੀਤੀ। ਅਗਲੇ ਦਿਨ ਸਵੇਰ ਇਹ ਸੱਤ ਦੱਸ ਦੇ ਅਗਨੀ ਨੇ ਉਡਾਨ ਪਰੀ, ਇਹ ਬਿਲਕੁਲ ਸਹੀ ਲਾਂਜ ਸੀ। ਇਹ ਰਾਤ ਦੀ ਡਰਾਉਣੇ ਨੀਨ ਤੋਂ ਬਾਹਦ, ਉੱਜਵਾਲ ਅਤੇ ਤਾਜੀ ਸਵੇਰ ਵਿੱਖਵੱਖ ਕਾਰਜ ਕਿੰਨਰਾਂ ਤੇ ਪੰਜ ਸਾਲਾਂ ਦੇ ਨਿਰਨ ਤਰ ਕੰਮ ਦੇ ਨਤੀਜੇ ਦਾ ਫਲ ਆਖਰ ਕਾਰ ਮਿਲ ਗਿਆ ਸੀ। ਸਿਰਫ਼ ਛੇਸੋ ਸਕਿੰਟਾਂ ਦੀ ਸ਼ਾਨਦਾਰ ਰੁਡਾਨ ਨੇ ਸਾਡੀ ਸਾਰੀ ਥਕਾਵੱਟ ਦੂਰ ਕਰ ਦਿੱਤੀ ਸੀ। ਕਈ ਸਾਲਾਂ ਦੀ ਮਿਹਨਾਂ ਦਾ ਇੰਨਾਂ ਸ਼ਾਨਦਾਰ ਫਲ ਮਿਲਿਆ ਸੀ। ਇਹ ਮੇਰੀ ਜ਼਼ਾਨਗੀ ਦੇ ਸਭ ਤੋਂ ਮਾਹਾਨ ਪਲਾ ਵਿੱਚੋਂ ਇੱਕ ਸੀ ਅਤੇ ਰਹੇਗਾ।"

In [None]:
metric.compute(predictions=[pred_transcript], references=[gt_transcript])

In [None]:
import jiwer

# compute the number of insertions, deletions, substitutions
measures = jiwer.compute_measures(gt_transcript, pred_transcript)
gt_word_count = len(gt_transcript.split())
pred_word_count = len(pred_transcript.split())
report_lines = [
    f"WER: {measures['wer']}",
    f"Insertions: {measures['insertions']}",
    f"Deletions: {measures['deletions']}",
    f"Substitutions: {measures['substitutions']}",
    f"Ground Truth Total Words: {gt_word_count}",
    f"Predicted Total Words: {pred_word_count}",
]
print("\n".join(report_lines))

In [8]:
def preprocess_text(text: str) -> str:
    """Collapse whitespace and delete a fixed set of punctuation characters.

    Args:
        text: raw transcript text.

    Returns:
        The text as a single string (not a list of words): whitespace runs
        are reduced to one space first, then every character in the
        punctuation class below is removed. Removing a punctuation mark that
        stands between two spaces leaves both spaces in place.
    """
    # str.split() without arguments already trims and splits on any
    # whitespace run, so no per-word strip() is needed.
    text = " ".join(text.split())
    # remove all punctuations including brackets
    text = re.sub(r'[.,।?!♪”‘॥\…\-\(\)\[\]\{\}]', '', text)
    # text = re.sub(r'[^\w\s]|♪', '', text)
    return text

def remove_text_in_brackets(text: str) -> str:
    """Drop every square-bracketed span from *text*, brackets included.

    Nested brackets are handled by tracking the nesting depth: characters
    are kept only while the depth is zero. Two edge cases follow from that:

    * an opening bracket that is never closed discards everything after it;
    * a closing bracket with no matching opener is kept as ordinary text.

    The result has leading and trailing whitespace stripped.
    """
    depth = 0
    kept = []

    for ch in text:
        if ch == '[':
            depth += 1
        elif ch == ']' and depth > 0:
            depth -= 1
        elif depth == 0:
            kept.append(ch)

    return ''.join(kept).strip()

In [9]:
# Suffix strings with a leading space.
# NOTE(review): presumably the per-source directory-name suffixes
# (e.g. "AniBook Text", "AniBook SRT", "AniBook Videos") - confirm where they are used.
txt = " Text"
srt = " SRT"
vid = " Videos"

In [10]:
import re
from typing import List

def preprocess_text(text: str) -> str:
    """Collapse whitespace and remove punctuation, keeping combining marks.

    The previous ``re.sub(r'[^\\w\\s]', '', text)`` also deleted every
    combining mark, because Python's ``\\w`` only matches alphanumeric
    characters and ``_``. In Gurmukhi that strips the vowel signs, bindi,
    tippi, addak and virama, corrupting each word before WER is computed.
    This version keeps exactly what the regex kept and additionally keeps
    characters in the Unicode "Mark" categories (Mn/Mc/Me).

    Args:
        text: raw transcript text.

    Returns:
        The cleaned text as a single string (not a list of words).
    """
    # Imported locally so this function does not depend on the cell's imports.
    import unicodedata

    text = " ".join(text.split())
    # remove punctuations
    return "".join(
        ch
        for ch in text
        if ch.isalnum()            # what \w matches ...
        or ch == "_"               # ... including the underscore
        or ch.isspace()            # what \s matches
        or unicodedata.category(ch).startswith("M")  # combining marks
    )

In [None]:
# Compute per-file WER statistics for the benchmark stories and write them to CSV.
gt_transcripts = {}
filenames = []
segment_paths = data_path / "Results" / "Segments_26_01"
for dir_path in transcript_path.glob("*"):
    for file_path in dir_path.glob("*.txt"):
        if any(benchmark_name in file_path.name for benchmark_name in benchmark_names):
            print(file_path.name)
            # Transcripts are Gurmukhi text: decode as UTF-8 explicitly.
            with open(file_path, "r", encoding="utf-8") as f:
                gt_transcript = f.read()
            # Bracketed annotations must be removed BEFORE punctuation is
            # stripped. Once preprocess_text has deleted the "[" and "]"
            # characters, remove_text_in_brackets has nothing to match and the
            # annotation text would stay in the reference.
            gt_transcript = preprocess_text(remove_text_in_brackets(gt_transcript))
            gt_transcripts[file_path.name] = gt_transcript
            filenames.append(file_path.stem)

filenames = sorted(filenames)

pred_transcripts = {}

print("-"*100)
for file_path in segment_paths.glob("*.json"):
    if any(benchmark_name in file_path.name for benchmark_name in benchmark_names):
        print(file_path.name)
        with open(file_path, "r", encoding="utf-8") as f:
            result = json.load(f)
        pred_transcripts[file_path.name] = preprocess_text(result["text"])

print("-"*100)
# One row per benchmark file, collected first and turned into a DataFrame at
# the end, so every column is created up front with a consistent dtype.
rows = []
for base_name in filenames:
    # Both transcripts were fully normalised when they were loaded above.
    gt_transcript = gt_transcripts[base_name + ".txt"]
    pred_transcript = pred_transcripts[base_name + ".json"]
    measures = jiwer.compute_measures(gt_transcript, pred_transcript)
    gt_word_count = len(gt_transcript.split())
    pred_word_count = len(pred_transcript.split())
    print(f"Base Name: {base_name}")
    print(f"WER: {measures['wer']}")
    print(f"Insertions: {measures['insertions']}")
    print(f"Deletions: {measures['deletions']}")
    print(f"Substitutions: {measures['substitutions']}")
    print(f"Ground Truth Total Words: {gt_word_count}")
    print(f"Predicted Total Words: {pred_word_count}")
    print("-"*100)
    rows.append({
        "Base Name": base_name,
        # WER as a percentage, rounded to 2 decimal places
        "WER": round(measures["wer"]*100, 2),
        "Insertions": measures["insertions"],
        "Deletions": measures["deletions"],
        "Substitutions": measures["substitutions"],
        "Predicted Total Words": pred_word_count,
        "GT Total Words": gt_word_count,
    })

jiwer_df = pd.DataFrame(
    rows,
    columns=[
        "Base Name",
        "WER",
        "Insertions",
        "Deletions",
        "Substitutions",
        "Predicted Total Words",
        "GT Total Words",
    ],
)

jiwer_df.to_csv(data_path / "Results" / "Jiwer_results.csv", index=False)




In [None]:
jiwer_df.head(13)