In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to project code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Import necessary modules
import difflib
import re
from pathlib import Path

import pandas as pd

from IPython.core.display import HTML
from IPython.display import display

from dotenv import load_dotenv

# Import the transcription service modules
from tnh_scholar.audio_processing.transcription_service import (
    TranscriptionServiceFactory,
)

# Load environment variables from .env file
# NOTE(review): presumably this is where the provider API credentials used by
# the transcription services come from — confirm which variables are required.
load_dotenv()


In [None]:
def test_transcription(audio_file_path, provider="whisper", options=None):
    """
    Run a single transcription through the chosen provider and echo its text.

    Args:
        audio_file_path: Path to the audio file
        provider: Provider name handed to the service factory
            (this notebook uses 'whisper' and 'assemblyai')
        options: Optional dict passed straight to ``service.transcribe``;
            an empty dict is used when it is omitted

    Returns:
        The full result returned by the service (its "text" entry is printed)
    """
    transcribe_options = {} if options is None else options
    rule = "-" * 80

    # Build the provider-specific service before reporting anything
    backend = TranscriptionServiceFactory.create_service(provider=provider)

    print(f"Testing transcription with {provider} service...")
    print(f"Audio file: {audio_file_path}")

    transcription = backend.transcribe(audio_file_path, transcribe_options)

    # Show the plain text between two horizontal rules
    print("\nTranscription result:")
    print(rule)
    print(transcription["text"])
    print(rule)

    # Hand back everything so later cells can inspect the other fields
    return transcription

In [None]:
def test_format_generation(audio_file_path, provider="whisper", format_type="srt"):
    """
    Generate a formatted transcription, preview it, and save it to disk.

    The output is written to ``transcription_output_<provider>_<format>.<format>``
    relative to the current working directory.

    Args:
        audio_file_path: Path to the audio file
        provider: 'whisper' or 'assemblyai'
        format_type: 'srt' or 'vtt'

    Returns:
        Formatted transcription as string
    """
    backend = TranscriptionServiceFactory.create_service(provider=provider)
    label = format_type.upper()
    rule = "-" * 80

    print(f"Testing {label} generation with {provider} service...")
    print(f"Audio file: {audio_file_path}")

    formatted_result = backend.transcribe_to_format(
        audio_file_path,
        format_type=format_type,
    )

    # Preview: slicing already copes with outputs shorter than 20 lines
    print(f"\nSample {label} output:")
    print(rule)
    print("\n".join(formatted_result.splitlines()[:20]))
    print("...")
    print(rule)

    # Persist the complete output next to the notebook
    destination = Path(f"transcription_output_{provider}_{format_type}.{format_type}")
    destination.write_text(formatted_result, encoding="utf-8")
    print(f"Full output saved to: {destination}")

    return formatted_result

In [None]:
# Replace with your audio file path
working_dir = Path.home() / "Desktop" / "wouter_video"
audio_file = working_dir / "dt_trim_enh_short.wav"
# `exists` has to be called: the bare bound method is always truthy, so
# `not audio_file.exists` was always False and this guard could never fire.
if not audio_file.exists():
    raise FileNotFoundError("Not found.")

In [None]:
# Test Whisper transcription, passing "vi" as the language option
options = {"language": "vi"}
whisper_result = test_transcription(audio_file, provider="whisper", options=options)



In [None]:
# Examine specific parts of the result
print("\nExamining detailed result:")
# 'language' is indexed directly, so a result without that key raises KeyError.
print(f"Language: {whisper_result['language']}")
# NOTE(review): this looks up 'audio_duration', while the AssemblyAI cell reads
# 'audio_duration_ms'. If the Whisper result uses the '_ms' key too, this line
# always prints 'Not available' — confirm against the transcription service.
print(f"Audio duration: {whisper_result.get('audio_duration', 'Not available')} milliseconds")
# A missing 'words' entry is treated as an empty list, i.e. a count of 0.
print(f"Word count: {len(whisper_result.get('words', []))}")

In [None]:
# Test AssemblyAI transcription
# This rebinds the notebook-level `options` name that the Whisper cell also uses.
# NOTE(review): the keys below are passed through unchanged to
# service.transcribe; they look like AssemblyAI configuration fields —
# confirm their exact meaning against the service implementation.
options = {
    "language_code": "vi", 
    "language_detection": False, 
    "speaker_labels": False, 
}
assemblyai_result = test_transcription(audio_file, provider="assemblyai", options=options)


In [None]:
# Examine specific parts of the result
# Every key below is indexed directly, so a missing key raises KeyError.
print("\nExamining detailed result:")
print(f"Language: {assemblyai_result['language']}")
print(f"Audio duration: {assemblyai_result['audio_duration_ms']} milliseconds")
print(f"Word count: {len(assemblyai_result['words'])}")
print(f"Utterance count: {len(assemblyai_result['utterances'])}")

# Look at the first few utterances if available
if assemblyai_result['utterances']:
    print("\nSample utterances:")
    # The [:3] slice already caps the sample at three utterances, so no
    # index counter or explicit break is needed.
    for utterance in assemblyai_result['utterances'][:3]:
        # Only the first 100 characters of each utterance are shown
        print(f"Speaker {utterance['speaker']}: {utterance['text'][:100]}...")

In [None]:
def replace_html_header(html_string: str) -> str:
    """
    Replaces the <head>...</head> section of an HTML string
    with a new header containing custom CSS for diff display.

    The tags are matched case-insensitively and the opening tag may carry
    attributes (e.g. ``<head profile="...">``). Matching is non-greedy, so if
    the input contains several head sections each one is replaced. Input
    without a head section is returned unchanged.

    Args:
        html_string (str): The original HTML string.

    Returns:
        str: The modified HTML string with the replaced <head> section.
    """
    new_head = """
<head>
    <meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
    <title></title>
    <style type="text/css">
        table.diff { font-family: Courier; border: medium; }
        .diff_header { background-color: #e0e0e0; }
        td.diff_header { text-align: right; }
        .diff_next { background-color: #c0c0c0; }
        .diff_add { background-color: lightgreen; color: black; }
        .diff_chg { background-color: lightyellow; color: black; }
        .diff_sub { background-color: lightcoral; color: black; }
    </style>
</head>
"""

    # `<head\b[^>]*>` accepts attributes on the opening tag; the word boundary
    # keeps other tags that merely start with "head" (such as <header>) from
    # matching. A callable replacement inserts new_head verbatim, whereas a
    # plain replacement string would be scanned for backslash escapes.
    return re.sub(
        r"<head\b[^>]*>.*?</head\s*>",
        lambda _match: new_head,
        html_string,
        flags=re.DOTALL | re.IGNORECASE,
    )

In [None]:
def compare_transcripts_high_contrast(whisper_result, assemblyai_result):
    """
    Show a word-level HTML diff of two transcripts and print match statistics.

    Both transcripts are stripped of punctuation, lowercased and split on
    whitespace; each word is then treated as one "line" of the diff.

    Args:
        whisper_result: Result whose 'text' entry holds the Whisper transcript
        assemblyai_result: Result whose 'text' entry holds the AssemblyAI transcript

    Returns:
        The diff as an HTML string (it is also displayed in the notebook)
    """
    def _tokenize(text):
        # Drop everything that is neither a word character nor whitespace
        return re.sub(r'[^\w\s]', '', text).lower().split()

    def _percentage(part, whole):
        # An empty transcript reports 0% instead of raising ZeroDivisionError
        return part / whole * 100 if whole else 0.0

    whisper_words = _tokenize(whisper_result['text'])
    assemblyai_words = _tokenize(assemblyai_result['text'])

    # Standard difflib side-by-side table, then swap in the high-contrast CSS
    html_diff = difflib.HtmlDiff().make_file(
        whisper_words, assemblyai_words, "Whisper", "AssemblyAI"
    )
    html_diff = replace_html_header(html_diff)

    # Calculate match statistics from the total size of the matching blocks
    matcher = difflib.SequenceMatcher(None, whisper_words, assemblyai_words)
    matching = sum(block.size for block in matcher.get_matching_blocks())

    print(f"Whisper words: {len(whisper_words)}")
    print(f"AssemblyAI words: {len(assemblyai_words)}")
    print(f"Matching words: {matching}")
    print(f"Match percentage: {_percentage(matching, len(whisper_words)):.2f}% (of Whisper)")
    print(f"Match percentage: {_percentage(matching, len(assemblyai_words)):.2f}% (of AssemblyAI)")

    display(HTML(html_diff))
    return html_diff

# Run the comparison; the function displays the diff and prints the
# statistics, and the returned HTML string is kept in `html_result`.
html_result = compare_transcripts_high_contrast(whisper_result, assemblyai_result)

In [None]:
def get_word_duration(row, time_system='assemblyai'):
    """
    Calculate the duration of a word in milliseconds.

    Args:
        row: DataFrame row (or any mapping) containing start and end times
            under the keys '<time_system>_start' and '<time_system>_end'
        time_system: Which time system to use ('assemblyai' or 'whisper')

    Returns:
        Duration in milliseconds (end minus start), or 0 when either key is
        missing or either value is None/NaN
    """
    start_key = f"{time_system}_start"
    end_key = f"{time_system}_end"

    # Make sure the time columns exist at all
    if start_key not in row or end_key not in row:
        return 0

    start, end = row[start_key], row[end_key]

    # The keys can be present but empty (e.g. the Whisper columns of an 'add'
    # row are None). Report that as "no duration" rather than raising on
    # None arithmetic or leaking NaN into later int() conversions.
    if pd.isnull(start) or pd.isnull(end):
        return 0

    return end - start


def adjust_add_timestamps(df, add_row_index):
    """
    Adjust timestamps for an 'add' row using a simple algorithm:
    1. Adjust previous word end time to match its AssemblyAI duration
    2. Set add word start time to be the adjusted end time of previous word
    3. Set add word end time to be the start time of next word

    Side effect: when a previous row exists, its 'whisper_end' value in ``df``
    is overwritten with the computed start time.

    Args:
        df: DataFrame containing the comparison results
        add_row_index: Zero-based position of the 'add' row that needs
            timestamps (rows are read with ``iloc``)

    Returns:
        Tuple of (start_time_ms, end_time_ms) in milliseconds
    """
    add_row = df.iloc[add_row_index]

    # Neighbouring rows (if they exist)
    has_prev_row = add_row_index > 0
    has_next_row = add_row_index < len(df) - 1
    next_row = df.iloc[add_row_index + 1] if has_next_row else None

    # Calculate new start time
    if has_prev_row:
        prev_position = add_row_index - 1
        prev_row = df.iloc[prev_position]
        new_start = prev_row['whisper_start'] + get_word_duration(prev_row, 'assemblyai')

        # Rows are read by position, so translate the position into an index
        # label before the label-based .at write; the two only coincide for a
        # default RangeIndex.
        df.at[df.index[prev_position], 'whisper_end'] = new_start
    elif has_next_row:
        # First word in transcript - use next word start time minus duration
        added_word_duration = get_word_duration(add_row, 'assemblyai')
        new_start = max(0, next_row['whisper_start'] - added_word_duration)
    else:
        # Only row in the table: there is no Whisper neighbour to anchor to
        # (indexing a "next" row here used to raise IndexError). Fall back to
        # the word's own AssemblyAI start, or 0 if that is unavailable.
        own_start = add_row.get('assemblyai_start')
        new_start = 0 if pd.isnull(own_start) else own_start

    # Calculate new end time
    if has_next_row:
        new_end = next_row['whisper_start']
    else:
        # Last word in transcript - use start time plus duration
        new_end = new_start + get_word_duration(add_row, 'assemblyai')

    return int(new_start), int(new_end)

In [None]:
def ms_to_min_sec(ms):
    """Convert milliseconds to 'MM:SS.sss' format.

    Args:
        ms: Time in milliseconds (int or float); None/NaN is allowed

    Returns:
        A string such as '01:05.000', with a leading '-' for negative
        values, or '' when ``ms`` is None/NaN
    """
    if pd.isnull(ms):
        return ""

    # Round to whole milliseconds first so a carry propagates into the
    # minutes (59999.6 ms becomes 01:00.000 instead of 00:60.000).
    total_ms = int(round(float(ms)))
    sign = "-" if total_ms < 0 else ""
    minutes, remainder_ms = divmod(abs(total_ms), 60_000)

    # Width 6 zero-pads the seconds to two integer digits ("05.000"),
    # which is what the MM:SS.sss format promises.
    return f"{sign}{minutes:02d}:{remainder_ms / 1000:06.3f}"

# Column order of the comparison table; passing it to the DataFrame
# constructor also gives an empty comparison a usable schema.
_TIMING_COLUMNS = [
    'status',
    'whisper_idx', 'whisper_word', 'whisper_start', 'whisper_end',
    'assemblyai_idx', 'assemblyai_word', 'assemblyai_start', 'assemblyai_end',
    'time_diff_ms',
]


def _clean_word(text):
    """Lowercase ``text`` and drop characters that are neither word characters nor whitespace."""
    return re.sub(r'[^\w\s]', '', text.lower())


def _whisper_word_entries(words):
    """Turn Whisper word records into {'word', 'start_ms', 'end_ms'} dicts, one per record."""
    return [
        {
            'word': _clean_word(word_data['word']),
            'start_ms': word_data['start_ms'],
            'end_ms': word_data['end_ms'],
        }
        for word_data in words
    ]


def _assemblyai_word_entries(words):
    """Turn AssemblyAI word records into per-word dicts, splitting multi-word records evenly in time."""
    entries = []
    for word_data in words:
        parts = _clean_word(word_data['word']).split()

        if not parts:
            # Punctuation-only record: nothing left to align. Skipping it also
            # avoids dividing the time span by a word count of zero.
            continue

        if len(parts) == 1:
            # Single word - use timing as is
            entries.append({
                'word': parts[0],
                'start_ms': word_data['start_ms'],
                'end_ms': word_data['end_ms'],
            })
            continue

        # Multiple words - give each an equal share of the record's time span
        time_per_word = (word_data['end_ms'] - word_data['start_ms']) / len(parts)
        for offset, part in enumerate(parts):
            start_time = word_data['start_ms'] + offset * time_per_word
            entries.append({
                'word': part,
                'start_ms': int(start_time),
                'end_ms': int(start_time + time_per_word),
            })
    return entries


def _span_fields(words, first, last, prefix):
    """Summarise words[first:last] as '<prefix>_idx/_word/_start/_end' fields (empty values for an empty span)."""
    span = words[first:last]
    if not span:
        return {
            f'{prefix}_idx': None,
            f'{prefix}_word': '',
            f'{prefix}_start': None,
            f'{prefix}_end': None,
        }
    return {
        f'{prefix}_idx': first,
        f'{prefix}_word': " ".join(entry['word'] for entry in span),
        f'{prefix}_start': span[0]['start_ms'],
        f'{prefix}_end': span[-1]['end_ms'],
    }


def _match_percentage(matches, total):
    """Return matches/total as a percentage, or 0.0 when total is zero."""
    return matches / total * 100 if total else 0.0


def compare_transcripts_with_timings(whisper_result, assemblyai_result):
    """
    Align the word lists of two transcripts and tabulate their timings.

    Words are normalised (lowercased, punctuation removed) and aligned with
    ``difflib.SequenceMatcher``. The table has one row per matching word
    ('match') and one consolidated row per run of differing words:
    'change' (both sides differ), 'delete' (Whisper only) and 'add'
    (AssemblyAI only). 'add' rows receive interpolated Whisper timestamps
    from ``adjust_add_timestamps``, which also rewrites the previous row's
    'whisper_end'. Match statistics are printed.

    Args:
        whisper_result: Result whose 'words' entry is a list of dicts with
            'word', 'start_ms' and 'end_ms'
        assemblyai_result: Result with the same 'words' structure; entries
            containing several words are split with evenly divided timing

    Returns:
        DataFrame with the columns in ``_TIMING_COLUMNS``. 'time_diff_ms' is
        AssemblyAI start minus Whisper start wherever both are known.
    """
    whisper_words = _whisper_word_entries(whisper_result['words'])
    assemblyai_words = _assemblyai_word_entries(assemblyai_result['words'])

    # Extract word sequences for diff
    whisper_sequence = [entry['word'] for entry in whisper_words]
    assemblyai_sequence = [entry['word'] for entry in assemblyai_words]

    # NOTE(review): SequenceMatcher's default autojunk heuristic treats very
    # frequent items as junk once a sequence has 200+ elements, which can
    # reduce matches on long transcripts; consider autojunk=False if the
    # alignment looks poor (at the cost of speed).
    matcher = difflib.SequenceMatcher(None, whisper_sequence, assemblyai_sequence)

    status_for_tag = {'replace': 'change', 'delete': 'delete', 'insert': 'add'}

    def _build_row(status, i1, i2, j1, j2):
        # One table row covering whisper_words[i1:i2] and assemblyai_words[j1:j2]
        row = {'status': status}
        row.update(_span_fields(whisper_words, i1, i2, 'whisper'))
        row.update(_span_fields(assemblyai_words, j1, j2, 'assemblyai'))
        both_known = row['whisper_start'] is not None and row['assemblyai_start'] is not None
        row['time_diff_ms'] = (
            row['assemblyai_start'] - row['whisper_start'] if both_known else None
        )
        return row

    # Build data for the table
    table_data = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'equal':
            # Matching words - one row per word pair
            for w_idx, a_idx in zip(range(i1, i2), range(j1, j2)):
                table_data.append(_build_row('match', w_idx, w_idx + 1, a_idx, a_idx + 1))
        else:
            # Differing run - consolidate each side into a single row
            table_data.append(_build_row(status_for_tag[tag], i1, i2, j1, j2))

    df = pd.DataFrame(table_data, columns=_TIMING_COLUMNS)

    # Post-processing: interpolate Whisper timestamps for 'add' rows.
    # adjust_add_timestamps reads rows by position, so track the position
    # separately from the index label used for the .at writes.
    for position, (label, row) in enumerate(df.iterrows()):
        if row['status'] != 'add':
            continue

        start_ms, end_ms = adjust_add_timestamps(df, position)
        df.at[label, 'whisper_start'] = start_ms
        df.at[label, 'whisper_end'] = end_ms

        # Same sign convention as every other row: AssemblyAI minus Whisper
        if pd.notnull(row['assemblyai_start']):
            df.at[label, 'time_diff_ms'] = row['assemblyai_start'] - start_ms

    # Calculate match statistics
    match_count = int((df['status'] == 'match').sum())
    whisper_count = len(whisper_sequence)
    assemblyai_count = len(assemblyai_sequence)

    print(f"Whisper words: {whisper_count}")
    print(f"AssemblyAI words: {assemblyai_count}")
    print(f"Matching words: {match_count}")
    print(f"Match percentage: {_match_percentage(match_count, whisper_count):.2f}% (of Whisper)")
    print(f"Match percentage: {_match_percentage(match_count, assemblyai_count):.2f}% (of AssemblyAI)")

    return df
    
def format_transcript_display(df):
    """
    Produce a string-formatted copy of the comparison table for display.

    Timestamp columns become minutes:seconds strings, index columns become
    integer strings, and 'time_diff_ms' (when present) is replaced by a
    'time_diff' column in seconds. Missing values become empty strings.
    The input frame is left untouched.

    Args:
        df: Raw dataframe from compare_transcripts_with_timings

    Returns:
        Formatted display dataframe
    """
    def as_clock(value):
        return ms_to_min_sec(value) if pd.notnull(value) else ""

    def as_index(value):
        return f"{int(value)}" if pd.notnull(value) else ""

    def as_seconds(value):
        return f"{value/1000:.3f}s" if pd.notnull(value) else ""

    # Column -> formatter, in the order the columns are rewritten
    formatters = {
        'whisper_start': as_clock,
        'whisper_end': as_clock,
        'assemblyai_start': as_clock,
        'assemblyai_end': as_clock,
        'whisper_idx': as_index,
        'assemblyai_idx': as_index,
    }

    # Work on a copy; values are always read from the original frame
    display_df = df.copy()
    for column, formatter in formatters.items():
        display_df[column] = df[column].apply(formatter)

    # Swap the millisecond offset for a human-readable seconds column
    if 'time_diff_ms' in df.columns:
        display_df['time_diff'] = df['time_diff_ms'].apply(as_seconds)
        display_df = display_df.drop('time_diff_ms', axis=1)

    return display_df

def style_transcript_display(display_df):
    """
    Apply styling to the formatted display dataframe.
    
    Args:
        display_df: Formatted dataframe from format_transcript_display
        
    Returns:
        Styled dataframe ready for display
    """
    def style_row(row):
        # Using a color-blind friendly palette
        status_colors = {
            'match': 'background-color: #FFFFFF',  # White
            'add': 'background-color: #CCEBC5; color: black',  # Light green
            'delete': 'background-color: #FBB4AE; color: black',  # Light red/salmon
            'change': 'background-color: #B3CDE3; color: black',  # Light blue
        }
        
        color_style = status_colors.get(row['status'], 'background-color: white')
        return [color_style] * len(row)
    
    return display_df.style.apply(style_row, axis=1)

# Usage example:
# Build the raw comparison table (this also prints the match statistics).
df = compare_transcripts_with_timings(whisper_result, assemblyai_result)
# Convert timestamps, indices and offsets to display strings (works on a copy).
display_df = format_transcript_display(df)
# Colour each row by its status.
styled_df = style_transcript_display(display_df)

# Last expression of the cell, so the notebook renders the styled table.
styled_df

In [None]:
# Keep the 'raw_result' entry of the AssemblyAI result for direct use below.
# NOTE(review): presumably this is the provider SDK's transcript object — confirm.
transcript = assemblyai_result['raw_result']

In [None]:
# Export SRT subtitles from the raw transcript with chars_per_caption=32
# (presumably the maximum caption length — confirm against the provider API).
srt_out = transcript.export_subtitles_srt(chars_per_caption=32)

In [None]:
# Show the complete exported SRT text.
print(srt_out)

In [None]:
from tnh_scholar.utils.file_utils import write_str_to_file


In [None]:
# Fetch the sentences from the raw transcript object.
sent = transcript.get_sentences()

In [None]:
# Display the final sentence.
# NOTE(review): assumes `sent` is a non-empty sequence; confirm what get_sentences() returns.
sent[-1]

In [None]:
# Save the exported SRT text; the relative path resolves against the current
# working directory. overwrite=True presumably allows replacing an existing
# file — confirm in write_str_to_file.
out_path = Path("dharma_talk_br_phap_hoi.srt")
write_str_to_file(out_path, srt_out, overwrite=True)

In [None]:
# Each call below builds a fresh service and calls transcribe_to_format, so
# the audio file is submitted once per requested format.

# Test SRT generation with Whisper
whisper_srt = test_format_generation(audio_file, provider="whisper", format_type="srt")

# Test VTT generation with Whisper
whisper_vtt = test_format_generation(audio_file, provider="whisper", format_type="vtt")

In [None]:
# Each call below builds a fresh service and calls transcribe_to_format, so
# the audio file is submitted once per requested format.

# Test SRT generation with AssemblyAI
assemblyai_srt = test_format_generation(audio_file, provider="assemblyai", format_type="srt")

# Test VTT generation with AssemblyAI
assemblyai_vtt = test_format_generation(audio_file, provider="assemblyai", format_type="vtt")

In [None]:
def compare_transcriptions(whisper_result, assemblyai_result):
    """
    Print a side-by-side summary of the two transcription results.

    Reports character and whitespace-separated word counts of each 'text'
    entry, the first 200 characters of each, and whether the optional
    'words' (Whisper) and 'utterances' (AssemblyAI) entries are non-empty.

    Args:
        whisper_result: Whisper result mapping with a 'text' entry
        assemblyai_result: AssemblyAI result mapping with a 'text' entry

    Returns:
        None; the summary is only printed
    """
    whisper_text = whisper_result["text"]
    assemblyai_text = assemblyai_result["text"]

    print("Comparing transcription results:")
    print("-" * 80)

    # Character counts
    print(f"Whisper text length: {len(whisper_text)} characters")
    print(f"AssemblyAI text length: {len(assemblyai_text)} characters")

    # Whitespace-separated word counts
    print(f"Whisper word count: {len(whisper_text.split())} words")
    print(f"AssemblyAI word count: {len(assemblyai_text.split())} words")

    # Opening excerpt of each transcript
    for heading, text in (
        ("\nWhisper first 200 chars:", whisper_text),
        ("\nAssemblyAI first 200 chars:", assemblyai_text),
    ):
        print(heading)
        print(text[:200])

    # Optional extras each provider may have returned
    has_word_timing = len(whisper_result.get('words', [])) > 0
    has_diarization = len(assemblyai_result.get('utterances', [])) > 0

    print("\nUnique capabilities:")
    print(f"Whisper provides word-level timing: {has_word_timing}")
    print(f"AssemblyAI provides speaker diarization: {has_diarization}")

# Compare the results (prints the summary; nothing is returned)
compare_transcriptions(whisper_result, assemblyai_result)