In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import re
import subprocess
import sys
from io import BytesIO
from pathlib import Path

from IPython.display import Audio
from pydub import AudioSegment


In [None]:
# Import the transcription service modules
from tnh_scholar.audio_processing.transcription_service import (
    DiarizationChunker,
    TranscriptionFormatConverter,
    TranscriptionServiceFactory,
)
from tnh_scholar.cli_tools.audio_transcribe.diarize import (
    check_job_status,
    diarize,
    resume_diarization,
)


In [None]:
# Folder that holds the input recordings and receives the generated transcripts.
working_dir = Path.home().joinpath("Desktop", "transcription_wouter")

In [None]:
# Stop early when the expected source recording is missing from the working directory.
audio_file_base_path = working_dir.joinpath("qa_sr_abbess.mp3")
if not audio_file_base_path.exists():
    raise FileNotFoundError("Audio file not found.")

In [None]:
def gen_srt(audio_file_obj, provider="whisper", language=None, local_convert=False):
    """
    Transcribe an audio input and return the SRT-oriented transcription output.

    Args:
        audio_file_obj: Audio input handed unchanged to the transcription service.
        provider: Backend name requested from TranscriptionServiceFactory.
        language: Optional language code. When truthy it is forwarded to the
            service as ``{"language": language}``; otherwise no options are sent.
        local_convert: When True, the audio is transcribed first and the result
            is converted by a local TranscriptionFormatConverter instead of
            asking the service for formatted output.

    Returns:
        The value returned by the local converter (``local_convert=True``) or
        by the service's ``transcribe_to_format`` call (``local_convert=False``).
    """
    output_format = "srt"
    svc = TranscriptionServiceFactory.create_service(provider=provider)

    # Progress information for the notebook reader.
    print(f"Running {output_format.upper()} generation with {provider} service...")
    print(f"Audio file: {audio_file_obj}")

    # Options are only sent when a language was actually requested.
    opts = None
    if language:
        opts = {"language": language}

    # Default path: let the service produce the formatted output itself.
    if not local_convert:
        return svc.transcribe_to_format(
            audio_file_obj,
            format_type=output_format,
            transcription_options=opts,
        )

    # Local path: plain transcription followed by conversion on this side.
    local_converter = TranscriptionFormatConverter()
    raw_transcript = svc.transcribe(audio_file_obj, options=opts)
    return local_converter.convert(raw_transcript)

In [None]:
def process_audio_chunks(
    audio_path, chunks, audio_format=None, language=None, local_convert=False
):
    """
    Transcribe an audio file chunk by chunk and build SRTs on the full-file timeline.

    Each chunk is sliced out of the loaded audio, exported to an in-memory
    file, transcribed with ``gen_srt``, and its SRT timestamps are shifted by
    the chunk's start time so they refer to positions in the original file.

    Args:
        audio_path: Path to the audio file (``str`` or ``pathlib.Path``).
        chunks: Sequence of chunk objects exposing ``start_time`` and
            ``end_time``. These values are used as slice bounds on the loaded
            audio and as the timestamp offset, so they are expected to be in
            milliseconds.
        audio_format: Format name used for the per-chunk export and as the
            extension of the in-memory file name. Defaults to the suffix of
            ``audio_path`` without the leading dot.
        language: Optional language code forwarded to ``gen_srt``.
        local_convert: Forwarded to ``gen_srt``.

    Returns:
        A tuple ``(all_srts, combined_srt)``: the list of per-chunk SRT strings
        with adjusted timestamps, and the single SRT string obtained by
        combining and renumbering them.

    Raises:
        ValueError: If ``audio_format`` is not given and ``audio_path`` has no
            file extension to derive it from.
    """

    if audio_format is None:
        # Path() makes the suffix lookup work for plain string paths as well.
        audio_format = Path(audio_path).suffix[1:]
        if not audio_format:
            raise ValueError(
                f"Cannot infer audio format from '{audio_path}'; "
                "pass audio_format explicitly."
            )
        print(f"Using audio format: {audio_format}")

    # Load the full audio file once; every chunk is sliced from this object.
    print(f"Loading audio file: {audio_path}")
    full_audio = AudioSegment.from_file(audio_path)

    chunk_count = len(chunks)
    all_srts = []

    for i, chunk in enumerate(chunks):
        chunk_duration = chunk.end_time - chunk.start_time
        print(f"Processing chunk {i+1}/{chunk_count}: {chunk.start_time}ms "
              f"to {chunk.end_time}ms")
        print(f"chunk duration: {chunk_duration}")

        # Slice out this chunk's audio.
        chunk_audio = full_audio[chunk.start_time:chunk.end_time]

        # Export to an in-memory file and rewind so it can be read from the start.
        chunk_file = BytesIO()
        chunk_audio.export(chunk_file, format=audio_format)
        chunk_file.seek(0)

        # The transcription backend needs a file name to recognize the format.
        chunk_file.name = f"chunk_{i}.{audio_format}"

        chunk_srt = gen_srt(chunk_file, language=language, local_convert=local_convert)

        # Chunk-local timestamps start at zero; move them onto the full-file timeline.
        all_srts.append(adjust_srt_timestamps(chunk_srt, chunk.start_time))

    # Merge the per-chunk SRTs into one, renumbering entries sequentially.
    combined_srt = combine_srts(all_srts)

    return all_srts, combined_srt

def adjust_srt_timestamps(srt_content, offset_ms):
    """
    Shift every ``start --> end`` timestamp pair in an SRT string by an offset.

    Args:
        srt_content: SRT text whose timestamps use the ``HH:MM:SS,mmm`` form.
        offset_ms: Offset in milliseconds to add to every timestamp. Floats are
            rounded to the nearest whole millisecond; negative values shift
            timestamps earlier.

    Returns:
        The SRT text with all matched timestamp pairs shifted. Text that does
        not match the timestamp-pair pattern is returned unchanged. Timestamps
        that would become negative are clamped to ``00:00:00,000``.
    """
    # The timestamp fields are rendered with integer format codes, which reject
    # floats, so normalize the offset to an int once up front.
    offset_ms = int(round(offset_ms))

    def add_offset_to_timestamp(timestamp_str, offset_ms):
        """Add a millisecond offset to one SRT timestamp string (HH:MM:SS,mmm)."""
        h, m, rest = timestamp_str.split(':')
        s, ms = rest.split(',')

        total_ms = int(h) * 3600000 + int(m) * 60000 + int(s) * 1000 + int(ms) + offset_ms

        # A negative total would render as a malformed timestamp; clamp to zero.
        total_ms = max(total_ms, 0)

        # Break the total back down into hours / minutes / seconds / milliseconds.
        total_s, new_ms = divmod(total_ms, 1000)
        total_m, new_s = divmod(total_s, 60)
        new_h, new_m = divmod(total_m, 60)

        return f"{new_h:02d}:{new_m:02d}:{new_s:02d},{new_ms:03d}"

    # Pattern for SRT timing lines: "<start> --> <end>"
    pattern = r'(\d{2}:\d{2}:\d{2},\d{3}) --> (\d{2}:\d{2}:\d{2},\d{3})'

    def replace_timestamps(match):
        new_start = add_offset_to_timestamp(match.group(1), offset_ms)
        new_end = add_offset_to_timestamp(match.group(2), offset_ms)
        return f"{new_start} --> {new_end}"

    # Replace all timestamp pairs in the SRT
    return re.sub(pattern, replace_timestamps, srt_content)

def combine_srts(srt_list):
    """
    Combine multiple SRT strings into one, renumbering the entries sequentially.

    Line endings are normalized to ``\\n`` first, and entries are split on any
    run of blank (or whitespace-only) lines, so CRLF input and extra blank
    lines between entries do not leave stale index numbers behind.

    Args:
        srt_list: Iterable of SRT strings, in the order they should appear.
            Empty or whitespace-only strings are skipped.

    Returns:
        A single SRT string whose entries are numbered from 1 and separated by
        one blank line. Returns an empty string when there are no entries.
    """
    result = []
    entry_num = 1

    for srt in srt_list:
        # Normalize Windows / old-Mac line endings so blank-line detection works.
        normalized = srt.replace("\r\n", "\n").replace("\r", "\n").strip()
        if not normalized:
            continue

        # Entries are blocks separated by one or more blank lines.
        for entry in re.split(r"\n\s*\n", normalized):
            if not entry.strip():
                continue

            lines = entry.split("\n")

            # The first line of an entry is its index; replace it with the running number.
            lines[0] = str(entry_num)
            entry_num += 1

            result.append("\n".join(lines))

    # Join all entries with a single blank line in between
    return "\n\n".join(result)

In [None]:
# Built from working_dir so the notebook does not depend on one user's absolute home path.
audio_file_path = working_dir / "qa_sr_abbess_wh_sh.mp3"

In [None]:
# Extension without the leading dot; process_audio_chunks derives its default
# audio_format from the path in the same way.
audio_file_path.suffix[1:]

In [None]:
# Run diarization on the audio file and keep whatever diarize() returns.
result = diarize(audio_file_path)

In [None]:
# Rebinds `result`, replacing the value assigned by the diarize() cell above.
# NOTE(review): this appears to be an alternative to the previous cell, with the
# second argument presumably being the ID of an earlier job — confirm, and update
# the ID when working with a different run.
result = resume_diarization(audio_file_path, 'd4d35761-ac95-4ddd-b468-5a7471855219')

In [None]:
# Inspect the diarization result.
result

In [None]:
# NOTE(review): 60 * 1000 presumably expresses 60 seconds in milliseconds for both
# the target and the minimum chunk duration — confirm against DiarizationChunker.
chunker = DiarizationChunker(target_duration=60 * 1000, single_speaker=True, min_chunk_duration=60 * 1000)

In [None]:
# Convert the diarization result into segments using the chunker.
segs = chunker.to_segments(result)

In [None]:
# Inspect the segments.
segs

In [None]:
# Group the segments into chunks; these are what process_audio_chunks iterates over.
chunks = chunker.extract_chunks(segs)

In [None]:
# Number of chunks produced.
len(chunks)

In [None]:
# Inspect a single chunk (the second one).
chunks[1]

In [None]:
# Per-chunk summary: start time, end time, duration_sec, and the speaker of the first segment.
[(ch.start_time, ch.end_time, ch.duration_sec, ch.segments[0].speaker) for ch in chunks]

In [None]:
# Last segment of chunk 1 — compare with the first segment of chunk 2 in the next cell
# to check the boundary between adjacent chunks.
chunks[1].segments[-1]

In [None]:
# First segment of chunk 2 (the other side of the chunk 1 / chunk 2 boundary).
chunks[2].segments[0]

In [None]:
# Transcribe every chunk with language "vi" and local format conversion.
# Returns the list of per-chunk SRTs and the combined, renumbered SRT.
all_srts, combined = process_audio_chunks(audio_file_path, chunks, language="vi", local_convert=True)

In [None]:
# Show the combined SRT.
print(combined)

In [None]:
# Show the SRT of the first chunk only.
print(all_srts[0])

In [None]:
# NOTE(review): repeats the print(combined) cell above.
print(combined)

In [None]:
# NOTE(review): import placed mid-notebook; consider moving it to the import cell at the top.
from tnh_scholar.utils.file_utils import write_str_to_file

# NOTE(review): the output file name does not resemble the input audio file name
# (qa_sr_abbess_wh_sh.mp3) — confirm this is the intended destination.
out_srt = working_dir / "Dharma Talk Br. Phap Hoi (for transcription) 2-bit.srt"
# overwrite=True is passed, so an existing file at out_srt is presumably replaced.
write_str_to_file(out_srt, combined, overwrite=True)