# Alignment Logic

In [49]:
from typing import List, Tuple, Dict
import numpy as np
from transformers import AutoModel, AutoTokenizer
from dataclasses import dataclass
from Levenshtein import distance as levenshtein_distance
from phonetics import metaphone
import pysrt
from pathlib import Path
import json
import pandas as pd
from copy import deepcopy

In [50]:
@dataclass(init=True, repr=True, eq=True)
class WordTimestamp:
    """One word together with the time span in which it is spoken."""

    # The word text.
    word: str
    # Start and end of the word (printed with two decimals elsewhere; presumably seconds — TODO confirm).
    start: float
    end: float

@dataclass(init=True, repr=True, eq=True)
class TimeAlignedSentence:
    """A predicted sentence with its overall time span and per-word timings."""

    # Sentence text as produced by the prediction.
    sentence: str
    # Time span of the whole sentence (presumably seconds — TODO confirm).
    start_time: float
    end_time: float
    # Per-word timestamps for this sentence.
    words: List[WordTimestamp]


@dataclass
class AlignmentResult:
    """Pairing of a ground-truth text span with the predicted sentence it was aligned to.

    NOTE: a later cell imports WordTimestamp / TimeAlignedSentence /
    AlignmentResult from scripts.utils, which replaces these notebook
    definitions at runtime; keep the two in sync.
    """

    # Space-separated ground-truth words assigned to the predicted sentence.
    # The aligner builds this with ' '.join(...), so it is a str, not a list.
    ground_truth_sentence: str
    # The aligned predicted sentence; the display code also handles None.
    predicted_sentence: TimeAlignedSentence | None
    # Mean per-word similarity of the aligned window.
    similarity_score: float


In [51]:
import sys
sys.path.append('..')

from scripts.utils import (
    read_srt_file,
    load_segments_from_json,
    read_ground_truth_file,
    srt_to_text,
    seconds_to_srt_time,
    fix_json_file,
    generate_srt_file_output,
    combine_repeated_words,
    WordTimestamp,
    TimeAlignedSentence,
    AlignmentResult
)

In [52]:
class EnhancedTranscriptAligner:
    """Align a ground-truth transcript with time-stamped predicted sentences.

    Each predicted sentence is scored against every window of the same number
    of consecutive ground-truth words that has not already been claimed by an
    earlier sentence.  The best-scoring window is accepted when its mean word
    similarity reaches ``similarity_threshold``.  Word similarity mixes a
    metaphone match with a normalised Levenshtein similarity (see
    ``combined_similarity``).
    """

    def __init__(self, similarity_threshold: float = 0.6):
        """
        Args:
            similarity_threshold: Minimum mean word similarity a window needs
                to be accepted as the alignment of a predicted sentence.
        """
        # The embedding model used by get_word_embeddings() is not loaded;
        # uncomment these lines (or assign self.tokenizer / self.model)
        # before calling it.
        # self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        # self.model = AutoModel.from_pretrained('bert-base-uncased')
        self.similarity_threshold = similarity_threshold

    def preprocess_gt_text(self, gt_text: str) -> str:
        """Collapse every whitespace run (including newlines) in *gt_text* to a single space."""
        return " ".join(gt_text.split())

    def get_word_embeddings(self, words: List[str]) -> np.ndarray:
        """Return one embedding per word: the mean of the model's last hidden state.

        Requires ``self.tokenizer`` and ``self.model``, which ``__init__`` does
        not load.

        Raises:
            AttributeError: if the tokenizer/model have not been assigned.
        """
        if not (hasattr(self, "tokenizer") and hasattr(self, "model")):
            raise AttributeError(
                "get_word_embeddings() needs self.tokenizer and self.model; "
                "they are not loaded in __init__"
            )
        inputs = self.tokenizer(words, return_tensors='pt', padding=True, truncation=True)
        outputs = self.model(**inputs)
        # NOTE(review): this mean runs over every position, padding included,
        # so shorter words in a padded batch are diluted — consider masking.
        embeddings = outputs.last_hidden_state.mean(dim=1).detach().numpy()
        return embeddings

    def cosine_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """Cosine similarity of two vectors; 0.0 when either vector has zero norm."""
        denominator = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
        if denominator == 0:
            # Avoid a division-by-zero NaN for all-zero vectors.
            return 0.0
        return np.dot(embedding1, embedding2) / denominator

    def phonetic_similarity(self, word1: str, word2: str) -> float:
        """Return 1.0 when both words have the same metaphone code, else 0.0."""
        # NOTE(review): if metaphone yields the same (e.g. empty) code for
        # non-Latin script, every pair scores 1.0 here — verify for Gurmukhi.
        return 1.0 if metaphone(word1) == metaphone(word2) else 0.0

    def combined_similarity(self, gt_word: str, pred_word: str) -> float:
        """Score two words in [0, 1]: 0.5 * phonetic match + 0.5 * normalised Levenshtein similarity."""
        lev_dist = levenshtein_distance(gt_word, pred_word)
        phonetic_sim = self.phonetic_similarity(gt_word, pred_word)

        # Normalise the edit distance by the longer word to get a similarity.
        max_len = max(len(gt_word), len(pred_word))
        lev_sim = 1 - (lev_dist / max_len) if max_len > 0 else 0

        return 0.5 * phonetic_sim + 0.5 * lev_sim

    def dtw_align_with_embeddings(
        self,
        ground_truth: str,
        predicted_sentences: List[TimeAlignedSentence]
    ) -> Tuple[List[AlignmentResult], set, List[TimeAlignedSentence]]:
        """Greedily align each predicted sentence to a window of ground-truth words.

        Despite its name, this is not DTW and uses no embeddings.  Sentences
        are processed in list order.  For each one, every window of
        ``len(words)`` consecutive ground-truth words that does not touch an
        already-matched ground-truth word is scored with the mean
        ``combined_similarity`` of the position-wise word pairs.  The best
        window (earliest on ties) is accepted if its score reaches
        ``self.similarity_threshold``.  Its words are then reserved, so no later
        sentence can reuse them.

        Args:
            ground_truth: Full ground-truth transcript.
            predicted_sentences: Predicted sentences to align.

        Returns:
            ``(alignments, matched_word_indices, updated_pred_sentences)``:
            one AlignmentResult per accepted sentence, the set of reserved
            ground-truth word indices, and new sentences whose ``words`` carry
            the ground-truth words with the predicted word timings.
        """
        gt_words = self.preprocess_gt_text(ground_truth).split()

        best_alignment = []
        matched_word_indices = set()
        updated_pred_sentences = []

        for pred_sentence in predicted_sentences:
            # NOTE(review): only the prediction is lower-cased; ground-truth
            # words are compared as-is.
            pred_words = pred_sentence.sentence.lower().split()
            if not pred_words:
                continue
            window_size = len(pred_words)

            best_start = None
            best_score = -np.inf
            for k in range(len(gt_words) - window_size + 1):
                # Never score a window that reuses a ground-truth word already
                # given to an earlier sentence.
                if not matched_word_indices.isdisjoint(range(k, k + window_size)):
                    continue
                score = sum(
                    self.combined_similarity(gt_words[k + m], pred_words[m])
                    for m in range(window_size)
                ) / window_size
                if score > best_score:  # strict '>' keeps the earliest window on ties
                    best_score, best_start = score, k

            if best_start is None or best_score < self.similarity_threshold:
                continue

            aligned_indices = range(best_start, best_start + window_size)
            matched_word_indices.update(aligned_indices)
            aligned_words = [gt_words[idx] for idx in aligned_indices]

            best_alignment.append(AlignmentResult(
                ground_truth_sentence=' '.join(aligned_words),
                predicted_sentence=pred_sentence,
                similarity_score=best_score,
            ))

            # Pair each ground-truth word with the timing of the predicted word
            # at the same position.  When the sentence carries fewer word
            # timestamps than whitespace tokens, fall back to the sentence span
            # instead of raising IndexError.
            pred_word_times = pred_sentence.words
            aligned_word_timestamps = []
            for m, gt_word in enumerate(aligned_words):
                if m < len(pred_word_times):
                    start, end = pred_word_times[m].start, pred_word_times[m].end
                else:
                    start, end = pred_sentence.start_time, pred_sentence.end_time
                aligned_word_timestamps.append(WordTimestamp(word=gt_word, start=start, end=end))

            updated_pred_sentences.append(TimeAlignedSentence(
                sentence=pred_sentence.sentence,
                start_time=pred_sentence.start_time,
                end_time=pred_sentence.end_time,
                words=aligned_word_timestamps,
            ))

        return best_alignment, matched_word_indices, updated_pred_sentences

In [53]:
# Root of the Punjabi dataset and its four sub-folders.
data_path = Path("../data/Punjabi")
audio_path, transcript_path, srt_path, results_path = (
    data_path / sub_dir for sub_dir in ("Audio", "Text", "Ground Truth SRT", "Results")
)

In [None]:
# Load the benchmark story list.  The CSV's first column is read as the index
# and then discarded in favour of a fresh 0..n-1 index.
benchmark_df = (
    pd.read_csv(data_path / "benchmark_list.csv", index_col=0)
    .reset_index(drop=True)
)
benchmark_df.head()  # notebook preview

In [None]:
# Distinct story names, in order of first appearance; files containing one of
# these names go into the validation split.
benchmark_paths = benchmark_df["Story Name"].drop_duplicates().tolist()
benchmark_paths

In [None]:
# Column-oriented containers for each split; every key starts with its own empty list.
base_lists = {
    column: []
    for column in ("audio_path", "transcript_path", "srt_path", "transcript", "sampling_rate", "array")
}
# Deep-copy per split so "train" and "val" never share list objects.
dataset_dict = {split_name: deepcopy(base_lists) for split_name in ("train", "val")}
dataset_dict

In [None]:
# Register every .wav file under each sub-folder of the audio tree.  The
# transcript and SRT paths are derived from the same folder/file names,
# replacing "Videos" by "Text" / "SRT" in the folder name.  A file goes to
# the "val" split when its name contains any benchmark story name, otherwise
# to "train".  Path.glob yields entries in filesystem order, so both levels
# are sorted to make the split contents deterministic across machines.
for dir_path in sorted(p for p in audio_path.glob("*") if p.is_dir()):
    for file_path in sorted(dir_path.glob("*.wav")):
        print(file_path.name)
        is_benchmark = any(benchmark_path in file_path.name for benchmark_path in benchmark_paths)
        dataset_type = "val" if is_benchmark else "train"
        print(dataset_type)

        file_audio_path = str(file_path)
        # with_suffix only swaps the final extension (robust to case and to
        # ".wav" appearing elsewhere in the name).
        file_transcript_path = (
            transcript_path / dir_path.name.replace("Videos", "Text") / file_path.with_suffix(".txt").name
        )
        file_srt_path = srt_path / dir_path.name.replace("Videos", "SRT") / file_path.with_suffix(".srt").name

        split = dataset_dict[dataset_type]
        split["audio_path"].append(file_audio_path)
        split["transcript_path"].append(file_transcript_path)
        split["srt_path"].append(file_srt_path)
        # Audio decoding and transcript loading are intentionally disabled, so
        # the "array", "sampling_rate" and "transcript" lists stay empty.

In [None]:
inference_data = {}
segments_path = results_path / "Segments"

# For every validation file, collect the ground-truth SRT path, the predicted
# segments (Results/Segments/<name>.json) and the ground-truth transcript text.
for file_transcript_path, file_srt_path in zip(
    dataset_dict["val"]["transcript_path"],
    dataset_dict["val"]["srt_path"],
):
    base_name = Path(file_transcript_path).name.replace(".txt", "")
    json_path = segments_path / Path(file_transcript_path).name.replace(".txt", ".json")
    inference_data[base_name] = {"srt_path": file_srt_path}

    print(f"Loading input predicted json file {json_path} ...")
    pred_srt_sentences = load_segments_from_json(json_path)
    inference_data[base_name]["pred_srt_sentences"] = pred_srt_sentences

    print(f"Loading ground truth text file {file_transcript_path} ...")
    ground_truth_text = read_ground_truth_file(file_transcript_path)
    inference_data[base_name]["ground_truth_text"] = ground_truth_text

In [None]:
print("Loading the alignment logic ...\n")
# Single aligner instance (default settings) shared by the alignment cells below.
aligner = EnhancedTranscriptAligner()

In [115]:
# Post processing the results

def post_process_results(
    results: "List[AlignmentResult]",
    ground_truth_text: str,
    matched_word_indices: set,
    updated_pred_sentences: "List[TimeAlignedSentence]",
) -> "Tuple[List[AlignmentResult], List[TimeAlignedSentence]]":
    """Repair gaps and overlaps between consecutive aligned segments.

    Two independent passes are made.  Both mutate their inputs in place, and
    the same list objects are returned.

    1. Sentence level (``results``): for each pair of consecutive results
       whose predicted sentences touch in time (previous ``end_time`` equals
       next ``start_time``) but whose ground-truth texts are not contiguous in
       the ground truth, look for a run of 1-5 adjacent, still-unaligned
       ground-truth words that bridges them and append it to the previous
       result.  If no bridge exists, drop from the previous result the
       longest word suffix that is repeated as a prefix of the next result.
    2. Word level (``updated_pred_sentences``): for each pair of consecutive
       sentences, locate both word sequences in the ground truth.  If
       ground-truth words lie between them, prepend those words to the next
       sentence, reusing the timing of its first word.

    Args:
        results: Alignment results in predicted-sentence order; each
            ``ground_truth_sentence`` is a space-separated string.
        ground_truth_text: Full ground-truth transcript (any whitespace).
        matched_word_indices: Indices into ``ground_truth_text.split()``
            already claimed by the aligner.  Not modified.
        updated_pred_sentences: Sentences whose ``words`` hold ground-truth
            words with predicted timings.

    Returns:
        ``(results, updated_pred_sentences)`` — the same objects, updated.
    """
    gt_words = ground_truth_text.split()
    unaligned_indices = set(range(len(gt_words))) - set(matched_word_indices)
    _repair_sentence_boundaries(results, gt_words, unaligned_indices)
    _fill_word_level_gaps(updated_pred_sentences, gt_words)
    return results, updated_pred_sentences


def _contains_phrase(padded_text: str, *parts: str) -> bool:
    """True if the space-joined non-empty *parts* occur as whole words in *padded_text* (' '-padded)."""
    phrase = " ".join(part for part in parts if part)
    return bool(phrase) and f" {phrase} " in padded_text


def _find_bridge(prev_text, curr_text, gt_words, unaligned_indices, padded_text, max_words=5):
    """Return the shortest (then earliest) run of adjacent unaligned indices joining prev and curr, or None."""
    ordered = sorted(unaligned_indices)
    for size in range(1, min(max_words, len(ordered)) + 1):
        for start in range(len(ordered) - size + 1):
            run = ordered[start:start + size]
            if run[-1] - run[0] != size - 1:
                continue  # indices are not adjacent in the ground truth
            bridge_text = " ".join(gt_words[i] for i in run)
            if _contains_phrase(padded_text, prev_text, bridge_text, curr_text):
                return run
    return None


def _suffix_prefix_overlap(prev_words, curr_words) -> int:
    """Length of the longest suffix of *prev_words* that is also a prefix of *curr_words*."""
    for size in range(min(len(prev_words), len(curr_words)), 0, -1):
        if prev_words[-size:] == curr_words[:size]:
            return size
    return 0


def _repair_sentence_boundaries(results, gt_words, unaligned_indices):
    """Sentence-level pass of post_process_results (mutates results and unaligned_indices)."""
    padded_text = f" {' '.join(gt_words)} "
    for prev_result, curr_result in zip(results, results[1:]):
        prev_pred = prev_result.predicted_sentence
        curr_pred = curr_result.predicted_sentence
        if not (prev_pred and curr_pred) or prev_pred.end_time != curr_pred.start_time:
            continue

        prev_text = prev_result.ground_truth_sentence
        curr_text = curr_result.ground_truth_sentence
        if _contains_phrase(padded_text, prev_text, curr_text):
            continue  # already contiguous in the ground truth

        bridge = _find_bridge(prev_text, curr_text, gt_words, unaligned_indices, padded_text)
        if bridge is not None:
            bridge_text = " ".join(gt_words[i] for i in bridge)
            prev_result.ground_truth_sentence = f"{prev_text} {bridge_text}" if prev_text else bridge_text
            # Remove exactly the bridged positions (not the first occurrence of each word).
            unaligned_indices.difference_update(bridge)
            continue

        # No bridge: the segments may overlap, so trim the repeated words from
        # the previous segment only (a value-based removal would also delete
        # common words that legitimately occur in both).
        prev_words = prev_text.split()
        overlap = _suffix_prefix_overlap(prev_words, curr_text.split())
        if overlap:
            prev_result.ground_truth_sentence = " ".join(prev_words[:-overlap])


def _find_unused_span(gt_words, tokens, used_indices):
    """Start of the first occurrence of *tokens* in *gt_words* whose whole span avoids *used_indices*, or None."""
    size = len(tokens)
    for start in range(len(gt_words) - size + 1):
        if gt_words[start:start + size] == tokens and used_indices.isdisjoint(range(start, start + size)):
            return start
    return None


def _fill_word_level_gaps(sentences, gt_words):
    """Word-level pass of post_process_results (prepends gap words to sentence.words in place)."""
    used_indices = set()
    for prev_sent, curr_sent in zip(sentences, sentences[1:]):
        # Skip pairs lacking word timestamps instead of reusing stale state.
        if not (prev_sent.words and curr_sent.words):
            continue
        prev_tokens = [w.word for w in prev_sent.words]
        curr_tokens = [w.word for w in curr_sent.words]

        prev_start = _find_unused_span(gt_words, prev_tokens, used_indices)
        if prev_start is None:
            continue
        prev_end = prev_start + len(prev_tokens)  # exclusive
        used_indices.update(range(prev_start, prev_end))

        curr_start = _find_unused_span(gt_words, curr_tokens, used_indices)
        # curr_start == prev_end: ideal continuation.  curr_start < prev_end:
        # the next sentence was only found *before* the previous one; its span
        # cannot overlap the reserved previous span, so nothing is changed.
        if curr_start is None or curr_start <= prev_end:
            continue

        template = curr_sent.words[0]
        # Build gap words with the same type as the existing word timestamps,
        # all sharing the timing of the next sentence's first word.
        make_word = type(template)
        curr_sent.words[:0] = [
            make_word(word=word, start=template.start, end=template.end)
            for word in gt_words[prev_end:curr_start]
        ]

In [None]:
inference_results = {}

# Align every validation file, then post-process, keeping both stages' outputs.
for base_name in inference_data:
    ground_truth_text = inference_data[base_name]["ground_truth_text"]
    pred_srt_sentences = inference_data[base_name]["pred_srt_sentences"]

    print(f"Aligning for {base_name} ...")
    results, matched_word_indices, updated_pred_sentences = aligner.dtw_align_with_embeddings(
        ground_truth_text, pred_srt_sentences
    )
    inference_results[base_name] = {
        "results": results,
        "matched_word_indices": matched_word_indices,
        "updated_pred_sentences": updated_pred_sentences,
    }

    print("Post processing the results ...\n")
    post_processed_results, post_processed_pred_sentences = post_process_results(
        results, ground_truth_text, matched_word_indices, updated_pred_sentences
    )
    inference_results[base_name].update(
        post_processed_results=post_processed_results,
        post_processed_pred_sentences=post_processed_pred_sentences,
    )


In [None]:
alignment_word_level_path = results_path / "Alignment" / "Word_level"
alignment_word_level_path.mkdir(parents=True, exist_ok=True)

alignment_sentence_level_path = results_path / "Alignment" / "Sentence_level"
alignment_sentence_level_path.mkdir(parents=True, exist_ok=True)

# Column widths for the console table; the header uses the same widths as the rows.
TEXT_COLUMN_WIDTH = 50
TIME_COLUMN_WIDTH = 11
table_header = (
    f"{'Ground Truth':<{TEXT_COLUMN_WIDTH}} | {'Predicted':<{TEXT_COLUMN_WIDTH}} | "
    f"{'Time':<{TIME_COLUMN_WIDTH}} | Similarity"
)
table_rule = "-" * len(table_header)

for base_name in inference_results:
    print(f"Alignment Results for {base_name} ...")
    print(table_rule)
    print(table_header)
    print(table_rule)

    for result in inference_results[base_name]["post_processed_results"]:
        predicted = result.predicted_sentence
        if predicted:
            time_span = f"{predicted.start_time:.2f}-{predicted.end_time:.2f}"
            print(f"{result.ground_truth_sentence:<{TEXT_COLUMN_WIDTH}} | "
                  f"{predicted.sentence:<{TEXT_COLUMN_WIDTH}} | "
                  f"{time_span:<{TIME_COLUMN_WIDTH}} | "
                  f"{result.similarity_score:.3f}")
            print(table_rule)
        else:
            print(f"{result.ground_truth_sentence:<{TEXT_COLUMN_WIDTH}} | "
                  f"{'':{TEXT_COLUMN_WIDTH}} | {'':{TIME_COLUMN_WIDTH}} | "
                  f"{result.similarity_score:.3f}")

    word_level_srt_path = alignment_word_level_path / f"{base_name}.srt"
    sentence_level_srt_path = alignment_sentence_level_path / f"{base_name}.srt"
    # post_process_results updates the alignment objects in place and returns
    # the same lists, so these are the post-processed alignments.
    generate_srt_file_output(
        inference_results[base_name]["post_processed_results"],
        inference_results[base_name]["post_processed_pred_sentences"],
        word_level_srt_path,
        sentence_level_srt_path,
    )
    print(f"Aligned GT transcript saved as {word_level_srt_path} and {sentence_level_srt_path}")

In [118]:

def combine_repeated_words(srt_file_path, output_file_path):
    """Merge runs of consecutive subtitles that carry the same (stripped) text.

    Each run becomes a single entry spanning from the first entry's start to
    the last entry's end.  An empty input produces an empty output file.
    NOTE: this notebook definition replaces the ``combine_repeated_words``
    imported from scripts.utils earlier in the notebook.

    Args:
        srt_file_path: Input SRT file.
        output_file_path: Where the merged SRT is written.
    """
    from itertools import groupby

    subs = pysrt.open(srt_file_path)
    combined_subs = pysrt.SubRipFile()

    # groupby only groups *adjacent* equal keys, which is exactly the
    # "repeated word" case; non-adjacent repeats stay separate.
    for text, run in groupby(subs, key=lambda sub: sub.text.strip()):
        items = list(run)
        combined_subs.append(pysrt.SubRipItem(
            index=len(combined_subs) + 1,
            start=items[0].start,
            end=items[-1].end,
            text=text,
        ))

    combined_subs.save(output_file_path)
    print(f"Combined SRT saved at {output_file_path}")


In [119]:
# for base_name in inference_results:
#     # Example usage
#     srt_file_path = alignment_word_level_path / f"{base_name}.srt"  # Replace with your input SRT file path
#     output_file_path = alignment_word_level_path / f"{base_name}_final.srt"  # Replace with your desired output SRT file path
#     combine_repeated_words(srt_file_path, output_file_path)

In [120]:
# Goal: given a word-level SRT file and a sentence-level SRT file, match each
# sentence against runs of consecutive (combined) words, then measure how far
# the matched words' timestamps (first word start / last word end) deviate
# from the sentence-level timestamps, and summarise this as a Subtitle Error
# Rate (SER) on a 0-100 scale.

import pysrt


def SER_srt_files(word_srt_path, sentence_srt_path, *, max_window_shifts=5,
                  grace_seconds=0.1, outlier_seconds=5.0):
    """Compute a Subtitle Error Rate (SER) between a word-level and a sentence-level SRT.

    Sentences of *sentence_srt_path* are matched in order against runs of
    consecutive entries of *word_srt_path* (presumably one word per entry).
    For each sentence, the search starts at the first word entry not yet
    consumed by an earlier sentence.  Windows of the sentence's word count,
    shifted by 0..max_window_shifts entries, are compared with the
    whitespace-normalised sentence text.
    - On a match, the search position moves past the matched words, and the
      start-time deviation (seconds) minus *grace_seconds*, floored at 0, is
      recorded unless it exceeds *outlier_seconds*.
    - If every window fails, the search position skips past the last window
      tried.
    - If the word list runs out, the position is left unchanged.

    SER = 100 * (sum of recorded deviations) / (span of the sentence SRT in
    seconds).  The value is not clamped, so it can exceed 100.  Unmatched
    sentences contribute no error.

    Args:
        word_srt_path: Path to the SRT file with word-level timestamps.
        sentence_srt_path: Path to the SRT file with sentences.
        max_window_shifts: Extra window offsets tried per sentence (5 -> 6 windows).
        grace_seconds: Start-time deviation tolerated without penalty.
        outlier_seconds: Deviations above this are ignored as outliers.

    Returns:
        The SER as a float (also printed), or None if a file could not be opened.
    """
    try:
        word_subs = pysrt.open(word_srt_path)
        sentence_subs = pysrt.open(sentence_srt_path)
    except Exception as e:
        print(f"Error opening SRT files: {e}")
        return None

    word_texts = [sub.text for sub in word_subs]
    errors = []
    min_ts = float('inf')
    max_ts = float('-inf')
    cursor = 0  # first word entry not yet consumed by an earlier sentence

    for sentence_sub in sentence_subs:
        sentence_words = sentence_sub.text.split()
        sentence_text = " ".join(sentence_words)
        n_words = len(sentence_words)
        start_time = None

        if n_words:
            for shift in range(max_window_shifts + 1):
                begin = cursor + shift
                end = begin + n_words
                if end > len(word_texts):
                    break  # not enough words left: keep the cursor where it is
                joined_text = " ".join(word_texts[begin:end])
                if joined_text == sentence_text:
                    print(f"Matched: {joined_text}\n")
                    start_time = word_subs[begin].start.ordinal / 1000
                    cursor = end
                    break
            else:
                # Every window failed: skip past the last window that was tried.
                cursor += max_window_shifts + n_words

        # SubRipTime.ordinal is in milliseconds.
        min_ts = min(sentence_sub.start.ordinal, min_ts)
        max_ts = max(sentence_sub.end.ordinal, max_ts)

        if start_time is not None:
            deviation = max(0, abs(start_time - sentence_sub.start.ordinal / 1000) - grace_seconds)
            if deviation <= outlier_seconds:
                errors.append(deviation)

    total_length = (max_ts - min_ts) / 1000 if max_ts > min_ts else 1  # avoid division by zero

    ser = sum(errors) / total_length * 100
    print(f"Subtitle Error Rate (SER): {ser:.3f}% \n")
    return ser


In [None]:
# Compute the SER for every aligned story: the predicted word-level SRT written
# above vs. the ground-truth sentence-level SRT.  (A leftover `break` used to
# stop this after the first story.)
for base_name in inference_results:
    print(f"Calculating SER for {base_name} ...")
    pred_aligned_srt_path = alignment_word_level_path / f"{base_name}.srt"
    gt_aligned_srt_path = inference_data[base_name]["srt_path"]
    SER_srt_files(pred_aligned_srt_path, gt_aligned_srt_path)