# Audio Diarization and Transcription Pipeline

This notebook provides a simple workflow to:
1. Perform speaker diarization using PyAnnote
2. Extract speaker-specific audio
3. Transcribe speaker audio with audio-transcribe
4. Convert JSONL to SRT with timeline mapping

Note: Requires the tnh-scholar package to be installed.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Import required libraries
import json
import subprocess
import sys
import tempfile
from io import BytesIO
from pathlib import Path
from typing import List

from IPython.display import Audio, display

from tnh_scholar.audio_processing.diarization import (
    diarize,
    resume_diarization,
)
from tnh_scholar.audio_processing.diarization.audio import AudioHandler
from tnh_scholar.audio_processing.diarization.config import (
    ChunkConfig,
    DiarizationConfig,
    LanguageConfig,
    SpeakerConfig,
)
from tnh_scholar.audio_processing.diarization.models import AugDiarizedSegment
from tnh_scholar.audio_processing.diarization.pyannote_adapter import PyannoteAdapter
from tnh_scholar.audio_processing.diarization.strategies import LanguageProbe, WhisperLanguageDetector
from tnh_scholar.audio_processing.diarization.strategies.speaker_blocker import group_speaker_blocks
from tnh_scholar.audio_processing.diarization.strategies.time_gap import TimeGapChunker
from tnh_scholar.audio_processing.diarization.timeline_mapper import TimelineMapper
from tnh_scholar.audio_processing.diarization.viewer import close_segment_viewer, launch_segment_viewer
from tnh_scholar.audio_processing.timed_object.timed_text import Granularity, TimedText
from tnh_scholar.audio_processing.transcription import patch_whisper_options
from tnh_scholar.audio_processing.transcription.srt_processor import (
    SRTConfig,
    SRTProcessor,
)
from tnh_scholar.audio_processing.transcription.text_segment_builder import TextSegmentBuilder
from tnh_scholar.audio_processing.transcription.transcription_service import (
    TranscriptionResult,
    TranscriptionServiceFactory,
)
from tnh_scholar.audio_processing.utils import (
    get_audio_from_file,
    get_segment_audio,
    play_diarization_segment,
)
from tnh_scholar.utils.file_utils import (
    write_str_to_file,
)


In [None]:
import traceback
import warnings


# Show the call stack alongside every warning
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
    """Print the current call stack, then the formatted warning.

    Takes the same arguments as ``warnings.showwarning``. Output goes to
    ``file`` when it has a ``write`` attribute, otherwise to ``sys.stderr``.
    """
    if hasattr(file, 'write'):
        sink = file
    else:
        sink = sys.stderr
    traceback.print_stack(file=sink)
    formatted = warnings.formatwarning(message, category, filename, lineno, line)
    sink.write(formatted)


# Install the hook so warnings raised from here on include a stack trace.
warnings.showwarning = warn_with_traceback

In [None]:
import logging

from tnh_scholar.logging_config import setup_logging

setup_logging(log_level=logging.DEBUG)

In [None]:
# Configuration - Update these values
# Path to the directory containing audio files.
# The transcript text, SRT and raw diarization JSON are written here as well.

BASE_DIR = Path.home() / "Desktop/tnh-scholar/audio_transcriptions"

# Audio file to process (run this notebook once per file)
AUDIO_FILE_STR = "farm_convo_spencer.flac"

# File handed to diarize(); defaults to the same file that is transcribed.
DIARIZATION_FILE_STR = AUDIO_FILE_STR

# Passed to diarize() as num_speakers.
SPEAKER_COUNT = None # Must be 1, 2 or None. If speakers > 2 use None for best result.

# When False the diarization cell is skipped and the existing
# raw_diarization_results.json is loaded instead.
GENERATE_NEW_DIARIZATION = False

# Forwarded to SpeakerConfig(single_speaker=...).
DIARIZE_SINGLE_SPEAKER = False

# Forwarded to SRTConfig(include_speaker=...).
SRT_INCLUDE_SPEAKER = True

# Language code placed in the transcription options.
LANGUAGE = 'en'

# Chunk sizes are given in seconds and multiplied by 1000 when ChunkConfig is built.
TARGET_CHUNK_TIME = 2 * 60  # seconds

MIN_CHUNK_TIME = 10 # seconds

# Provider name for TranscriptionServiceFactory; 'whisper' selects the
# Whisper option set, any other value selects the *_aai option set.
TRANSCRIBER = "whisper"

# Diarization state flag: set to False when a new diarization run starts and
# back to True when it finishes on the initial call; gates the resume cell.
completed = True


In [None]:
metadata = "" # read_str_from_file(BASE_DIR / "sr_bamboo_metadata.txt")

In [None]:
# Assemble the diarization settings from the notebook-level constants.
# Chunk durations are configured in seconds and scaled by 1000 here
# (presumably to milliseconds - confirm against ChunkConfig).
chunk_cfg = ChunkConfig(
    target_duration=TARGET_CHUNK_TIME * 1000,
    min_duration=MIN_CHUNK_TIME * 1000,
)
speaker_cfg = SpeakerConfig(single_speaker=DIARIZE_SINGLE_SPEAKER)
diarize_config = DiarizationConfig(
    chunk=chunk_cfg,
    speaker=speaker_cfg,
    language=LanguageConfig(),
)

In [None]:
# Set up paths
# Both paths are built from BASE_DIR; they are identical unless
# DIARIZATION_FILE_STR names a different file than AUDIO_FILE_STR.
audio_file_path = BASE_DIR / AUDIO_FILE_STR
diarize_audio_file_path = BASE_DIR / DIARIZATION_FILE_STR

# Extension of the source audio (with the leading dot); passed to
# patch_whisper_options() when the Whisper options are built.
file_ext_str = audio_file_path.suffix

# Fail fast when the audio is missing. Only audio_file_path is checked;
# diarize_audio_file_path is not.
if not audio_file_path.exists():
    raise FileNotFoundError(f"No file found: {audio_file_path}")

# Raw diarization JSON lives next to the diarized audio file. The name is
# fixed, so a new run for another file in the same folder reuses this path.
diarization_results_path = diarize_audio_file_path.parent / "raw_diarization_results.json"

In [None]:
def load_diarization_result(file_path):
    """Load a raw diarization result from a JSON file.

    Args:
        file_path: Location of the JSON file (``str`` or ``Path``).

    Returns:
        The decoded JSON document, exactly as produced by ``json.load``.

    Raises:
        ValueError: If ``file_path`` is empty or ``None``.
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    if not file_path:
        raise ValueError("File_path must be provided.")

    # JSON is UTF-8 by specification; do not depend on the platform default
    # encoding, which would garble non-ASCII content on some systems.
    with open(file_path, encoding="utf-8") as f:
        return json.load(f)

In [None]:
# Run PyAnnote diarization (only when a fresh run was requested)
if GENERATE_NEW_DIARIZATION:
    completed = False
    print(f"Starting diarization for {diarize_audio_file_path}...")
    result = diarize(
        diarize_audio_file_path,
        num_speakers=SPEAKER_COUNT,
        output_path=diarization_results_path,
    )

    # A string return value is the ID of a job that is still running;
    # anything else means the run finished during this call.
    completed = not isinstance(result, str)
    if completed:
        print("Diarization process finished on initial run.")
    else:
        job_id = result
        print(f"Diarization job started with ID: {job_id}")
        print("Wait for completion and then run the next cell with this job ID")

In [None]:
# Only run this if you got a job ID in the previous cell
# Replace with your actual job ID from the previous step
# job_id = "your-job-id-here"  # e.g., "994c79b7-5f32-4715-aa34-33f00e216369"

# NOTE(review): check_job_status is not imported or defined in the cells
# visible here, so this cell raises NameError as written - import it from the
# module that provides it before running. job_id must also exist (either from
# the previous cell or from the commented assignment above).

# Check status

if not completed:
    status = check_job_status(job_id)
    print(f"Current status: {status.get('status', 'unknown')}")

    # Resume if needed
    if status.get('status') != 'succeeded':
        print("Resuming diarization...")
        # Resume against the file the job was started on (the diarization
        # input), which can differ from the file that gets transcribed.
        result = resume_diarization(diarize_audio_file_path, job_id)
        print("Diarization completed")
    else:
        print("Diarization already completed")

In [None]:
print(diarize_config)

In [None]:
base_audio = get_audio_from_file(audio_file_path)

In [None]:
transcription_options_aai = {"language_code": LANGUAGE, "language_detection": False}

In [None]:
ts_service = TranscriptionServiceFactory.create_service(provider=TRANSCRIBER)

In [None]:
# Transcribe the whole audio file in one call.
# NOTE(review): the options passed here are the *_aai set ("language_code",
# "language_detection"), while ts_service is created for TRANSCRIBER, which is
# "whisper" by default - confirm the selected service accepts these keys.
transcript = ts_service.transcribe(audio_file_path, transcription_options_aai)

In [None]:
full_seg = transcript.utterance_timing

In [None]:

assert full_seg is not None
new_seg = TimedText(segments=full_seg.segments, granularity=Granularity.SEGMENT)

In [None]:
assert full_seg is not None
full_out = new_seg.export_text()

In [None]:
print(full_out)

In [None]:
# Write the plain-text transcript next to the audio file.
# audio_file_path already includes BASE_DIR, so only the suffix is swapped;
# joining it onto BASE_DIR again would duplicate the directory whenever
# BASE_DIR is a relative path.
path_out = audio_file_path.with_suffix(".txt")
write_str_to_file(path_out, full_out, overwrite=True)

In [None]:
srt_config = SRTConfig(include_speaker=SRT_INCLUDE_SPEAKER) 
srt_processor = SRTProcessor(srt_config)

In [None]:
srt_out = srt_processor.generate(full_seg)

In [None]:
print(srt_out)

In [None]:
# Load and process the diarization results
# Reads the stored JSON, converts its 'output' entry into segments with the
# Pyannote adapter, and groups those segments into chunks. The names bound
# here (data, segments, chunk_list) are used by the cells that follow.

print(f"Loading diarization results from {diarization_results_path}")
chunker = TimeGapChunker(config=diarize_config)
segment_adapter = PyannoteAdapter(config=diarize_config)
result = load_diarization_result(file_path=diarization_results_path)
# The payload of interest sits under the top-level 'output' key.
data = result['output']
segments = segment_adapter.to_segments(data)
chunk_list = chunker.extract(segments)

# One line per chunk for a quick visual check.
for chunk in chunk_list:
    print(f"  chunk: {chunk}")

In [None]:
diarize_raw = data['diarization']

In [None]:
diarize_raw[0]['start']

In [None]:
data

In [None]:
chunk_list

In [None]:
len(segments)

In [None]:
segments[0]

In [None]:
segments[110]

In [None]:
# Keep only segments longer than four seconds.
long_list = [s for s in segments if s.duration_sec > 4.0]

# Summary rows: (position in long_list, duration, start, end, speaker).
# The position indexes long_list, not the full segments list.
long_list_info = [
    (idx, s.duration_sec, s.start.to_seconds(), s.end.to_seconds(), s.speaker)
    for idx, s in enumerate(long_list)
]
long_list_info

In [None]:
speaker_blocks = group_speaker_blocks(segments, config=diarize_config)

In [None]:
len(speaker_blocks)

In [None]:
pid = launch_segment_viewer(speaker_blocks[:250], audio_file_path)

In [None]:
close_segment_viewer(pid)

In [None]:
[(block.speaker, block.duration) for block in speaker_blocks]

In [None]:
len(long_list)

In [None]:
test_idx = 167
seg = segments[test_idx]
print(seg)
play_diarization_segment(seg, base_audio)

In [None]:
detector = WhisperLanguageDetector()

probe = LanguageProbe(
    config=diarize_config, 
    detector=detector,
)

In [None]:
seg_audio = get_segment_audio(seg, base_audio)

In [None]:
seg_audio

In [None]:

aug_seg = AugDiarizedSegment.from_segment(segments[test_idx], audio=seg_audio)

In [None]:
aug_seg

In [None]:
probe.segment_language(aug_segment=aug_seg)

In [None]:
import concurrent.futures
import time

from openai import RateLimitError

segments_to_probe = long_list

def probe_segment_safe(probe, aug_segment):
    """Run ``probe.segment_language`` without letting errors escape.

    A ``RateLimitError`` on the first attempt triggers one retry after a
    ten-second pause. Any other exception, or any exception on the retry,
    is printed and ``None`` is returned.

    Args:
        probe: Object exposing ``segment_language(aug_segment=...)``.
        aug_segment: Segment passed through to the probe.

    Returns:
        Whatever ``segment_language`` returns, or ``None`` on failure.
    """
    try:
        return probe.segment_language(aug_segment=aug_segment)
    except RateLimitError:
        print("Rate limit hit, sleeping and retrying...")
        time.sleep(10)  # Back off before the single retry
    except Exception as err:
        print(f"Error: {err}")
        return None

    # Only reached after a rate-limit failure: one more attempt, then give up.
    try:
        return probe.segment_language(aug_segment=aug_segment)
    except Exception as err:
        print(f"Failed again: {err}")
    return None

# Example: probe all segments in long_list (or chunk_list, or your own list)
# Keep the pool small: probes can hit RateLimitError, and probe_segment_safe
# retries only once, so a very large pool mostly produces None results.
max_workers = 8  # Adjust based on your rate limit
results = []
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
    futures = [
        executor.submit(
            probe_segment_safe,
            probe,
            AugDiarizedSegment.from_segment(seg, audio=get_segment_audio(seg, base_audio)),
        )
        for seg in segments_to_probe  # Adjust range as needed
    ]
    # Collect in submission order so results[i] belongs to segments_to_probe[i];
    # completion order would make the results impossible to match to segments.
    results = [future.result() for future in futures]

print("Language probe results:", results)

In [None]:
[chunk.accumulated_time for chunk in chunk_list]

In [None]:
[len(chunk.segments) for chunk in chunk_list]

In [None]:
# Build the audio for every chunk; later cells read it from chunk.audio
print("Extracting speaker audio segments to local ByteIO objects")
audio_handler = AudioHandler()
total_chunks = len(chunk_list)

for chunk_number, chunk in enumerate(chunk_list, start=1):
    progress = f"Building chunk {chunk_number} of {total_chunks}"
    print(progress)
    audio_handler.build_audio_chunk(chunk, audio_file=audio_file_path)

In [None]:
len(chunk_list)

In [None]:
chunk_list[0]

In [None]:
chunk_list[0].total_duration_sec

In [None]:
audio_list = [chunk.audio for chunk in chunk_list]

In [None]:
audio_list

In [None]:
aud_chunk = audio_list[0]

In [None]:
aud_chunk

In [None]:
# Play back the first chunk's audio.
# NOTE(review): play_audio_mp4 is not imported or defined in the cells visible
# here, so this raises NameError unless it is provided elsewhere - confirm.
play_audio_mp4(aud_chunk.data)

In [None]:
# Create the transcription service and both provider-specific option sets.
ts_service = TranscriptionServiceFactory.create_service(provider=TRANSCRIBER)

# Whisper: word-level timestamps, with the metadata string as the prompt.
# The options are passed through patch_whisper_options together with the
# source file extension.
transcription_options_whisper = patch_whisper_options(
    {
        "language": LANGUAGE,
        "timestamp_granularities": ["word"],
        "prompt": metadata,
    },
    file_extension=file_ext_str,
)

# Alternative option set used when TRANSCRIBER is not 'whisper'.
transcription_options_aai = {
    "language_code": LANGUAGE,
    "language_detection": False,
}

In [None]:
# Pick the option set that matches the configured provider.
if TRANSCRIBER == 'whisper':
    transcription_options = transcription_options_whisper
else:
    transcription_options = transcription_options_aai

In [None]:
chunks_to_process = chunk_list

In [None]:
# Transcribe every chunk in order, collecting one result per chunk.
transcripts: List[TranscriptionResult] = []
for chunk_number, chunk in enumerate(chunks_to_process, start=1):
    print(f"processing chunk: {chunk_number}")

    # The chunk audio must have been built by the extraction cell first.
    chunk_audio = chunk.audio
    if not chunk_audio:
        raise ValueError("No audio data for chunk.")

    audio_data = chunk_audio.data
    print(f"Running transcript generation with {TRANSCRIBER} service...")
    print(f"Audio file: {audio_data}")

    # `transcript` is left bound to the last result for inspection below.
    transcript = ts_service.transcribe(audio_data, transcription_options)
    print(transcript)
    transcripts.append(transcript)

print("Transcription loop complete.")

In [None]:
transcript.raw_result

In [None]:
len(transcripts)

In [None]:
mapper = TimelineMapper()

In [None]:
# Remap each chunk's word timings with the TimelineMapper (presumably from
# chunk-relative time back to the original audio timeline - confirm).
# Transcripts are paired with chunks_to_process, the list that was actually
# transcribed, so selecting a subset of chunk_list keeps the pairs aligned.
timings = []
for chunk, transcript in zip(chunks_to_process, transcripts):
    tt = transcript.word_timing
    if tt is None:
        # Word-level timing is required for remapping; stop rather than skip.
        raise ValueError("No timed text for words.")
    timings.append(mapper.remap(tt, chunk))

In [None]:
len(timings)

In [None]:
timings[0]

In [None]:
complete_timing = TimedText.merge(timings)

In [None]:
complete_timing

In [None]:
segment_builder = TextSegmentBuilder(max_duration_ms=4*1000, target_characters=42, ignore_speaker=True)

In [None]:
full_seg = segment_builder.create_segments(complete_timing)

In [None]:
full_seg

In [None]:
srt_config = SRTConfig(include_speaker=SRT_INCLUDE_SPEAKER) 
srt_processor = SRTProcessor(srt_config)

In [None]:
srt_out = srt_processor.generate(full_seg)

In [None]:
print(srt_out)

In [None]:
# Replay the first chunk's audio for comparison with the generated SRT.
# NOTE(review): play_audio_mp4 is not imported or defined in the cells visible
# here, so this raises NameError unless it is provided elsewhere - confirm.
play_audio_mp4(aud_chunk.data)

In [None]:
display(Audio(str(audio_file_path)))

In [None]:
test_str = "srt_out"

In [None]:
# Save the SRT beside the audio file as "<audio stem>_<test_str>.srt".
new_ext = ".srt"
new_stem = f"{audio_file_path.stem}_{test_str}"
srt_path = audio_file_path.with_name(f"{new_stem}{new_ext}")

write_str_to_file(srt_path, srt_out, overwrite=True)

# END OF PROCESS PIPE

In [None]:
# Post-processing: Translate the SRT files in BASE_DIR to English
print("\n===== Translating SRT files to English =====")

# Snapshot the listing before looping so files written by this cell are not
# picked up while iterating.
for srt_file in sorted(BASE_DIR.glob("*.srt")):
    # Outputs of a previous run end in "_en"; translating them again would
    # produce "*_en_en.srt" files.
    if srt_file.stem.endswith("_en"):
        continue

    print(f"file: {srt_file}")

    en_srt_file = srt_file.with_name(f"{srt_file.stem}_en.srt")

    # Skip if English version already exists
    if en_srt_file.exists():
        print(f"English SRT already exists: {en_srt_file}")
        continue

    # Run srt-translate. Arguments are passed as a list without a shell, so
    # quotes, spaces or shell metacharacters in file names cannot break or
    # alter the command.
    cmd = ["srt-translate", str(srt_file), "-o", str(en_srt_file), "-t", "en"]
    print(f"Running: {' '.join(cmd)}")

    try:
        subprocess.run(cmd, check=True)
        print(f"Successfully translated: {srt_file} -> {en_srt_file}")
    except subprocess.CalledProcessError as e:
        # Report and move on so one failed file does not stop the batch.
        print(f"Error translating {srt_file}: {e}")

print("===== Translation complete =====")

In [None]:
# --- Settings ---
srt_folder = BASE_DIR  # <-- Change this to your actual folder
srt_processor = SRTProcessor()

# --- Processing Loop ---
# For every SRT file: keep the speaker-labelled original as "<stem>_sp.srt"
# and write a version without speaker labels under the original name.
# The listing is snapshotted first because the loop creates files in the
# same folder.
for srt_file in sorted(srt_folder.glob("*.srt")):
    # Backups created by this cell must not be processed again.
    if srt_file.stem.endswith("_sp"):
        continue

    speaker_file = srt_file.with_stem(f"{srt_file.stem}_sp")

    # A backup already exists, so this file was cleaned on an earlier run.
    # Renaming again would replace the speaker-labelled backup with the
    # cleaned text (or fail on Windows, where rename does not overwrite).
    if speaker_file.exists():
        continue

    # Read original SRT content
    srt_content = srt_file.read_text(encoding="utf-8")

    # Parse to TimedText
    timed_text = srt_processor.parse(srt_content)

    # Re-generate SRT without speaker labels
    cleaned_srt = srt_processor.generate(timed_text, include_speaker=False)

    # Rename original file to *_sp.srt (only after parsing succeeded)
    srt_file.rename(speaker_file)

    # Save cleaned SRT under original filename
    srt_file.write_text(cleaned_srt, encoding="utf-8")

print("Cleaning and renaming completed.")

In [None]:
# # Process each speaker's audio
# for speaker, blocks in mapped_blocks.items():
#     speaker_audio_path = export_dir / f"{speaker}.mp3"
#     speaker_output_dir = export_dir / "audio_transcriptions" / speaker
#     audio_transcribe_output_dir = export_dir / "audio_transcriptions"
#     ensure_directory_exists(speaker_output_dir)
    
#     print(f"\nProcessing {speaker}...")
    
#     # Run audio-transcribe on the speaker's audio file
#     cmd = f"audio-transcribe -f {speaker_audio_path} --output_dir {audio_transcribe_output_dir} --split --transcribe"
#     print(f"Running: {cmd}")
#     subprocess.run(cmd, shell=True, check=True)
    

In [None]:
test_list = []

In [None]:
test_list[-1]