In [1]:
import torch
import pandas as pd
import os
from dotenv import load_dotenv
from pyannote.audio import Pipeline
import whisper # The official OpenAI Whisper library
from pydub import AudioSegment
import numpy as np
import torchaudio

# Revert to the documented import structure
from ctc_forced_aligner import (
    load_audio as ctc_load_audio, # Renamed to avoid conflict with torchaudio.load
    load_alignment_model,
    generate_emissions,
    preprocess_text,
    get_alignments,
    get_spans,
    postprocess_results,
)

# Merge variables from a local .env file (if one exists) into the process
# environment before reading the token.
load_dotenv()

# Hugging Face access token; passed to the pyannote pipeline loader below.
hf_token = os.environ.get("HF_TOKEN")
if hf_token is None:
    raise ValueError("Hugging Face token not found. Please set the HF_TOKEN environment variable.")

# Recording that every later stage (diarization, transcription, alignment) reads.
audio_file = "output1.wav"
audio_file_present = os.path.exists(audio_file)
if not audio_file_present:
    raise FileNotFoundError(f"The audio file was not found at: {audio_file}")

# Pick execution targets for the two model groups. Each tuple is
# (pyannote device, pyannote compute type,
#  whisper/alignment device, whisper/alignment compute type).
if torch.backends.mps.is_available():
    # On Apple MPS only pyannote uses the accelerator; Whisper and the
    # aligner are deliberately kept on the CPU in float32.
    _placement = ("mps", "float16", "cpu", "float32")
elif torch.cuda.is_available():
    _placement = ("cuda", "float16", "cuda", "float16")
else:
    _placement = ("cpu", "float32", "cpu", "float32")

(
    device_pyannote,
    compute_type_pyannote,
    device_whisper_align,
    compute_type_whisper_align,
) = _placement

# Report the chosen configuration.
print(f"Using device for Pyannote: {device_pyannote}")
print(f"Using compute type for Pyannote: {compute_type_pyannote}")
print(f"Using device for Whisper/Alignment: {device_whisper_align}")
print(f"Using compute type for Whisper/Alignment: {compute_type_whisper_align}")

# 1. Speaker diarization with pyannote
print("Starting speaker diarization...")
diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=hf_token
)
diarization_pipeline.to(torch.device(device_pyannote))
diarization = diarization_pipeline(audio_file)
print("Speaker diarization complete.")

# Flatten the diarization tracks into one row per speaker turn; the track
# identifier yielded alongside each turn is not needed here.
data = [
    {'start': turn.start, 'end': turn.end, 'speaker': label}
    for turn, _track, label in diarization.itertracks(yield_label=True)
]
diarization_df = pd.DataFrame(data)

print("\n--- Diarization Result (DataFrame) ---")
print(diarization_df)
print(f"Number of diarization segments: {len(diarization_df)}")
print("--------------------------------------")

# 2. Transcription with OpenAI Whisper
print(f"Loading OpenAI Whisper model on device: {device_whisper_align}...")
model = whisper.load_model("medium", device=device_whisper_align)
print("OpenAI Whisper model loaded.")

print("Transcribing audio with OpenAI Whisper...")
transcription_raw_result = model.transcribe(audio_file, language="th", verbose=False)

# Keep only the fields the later stages read (text plus start/end times);
# a result without a 'segments' key yields an empty list.
transcription_result = {
    'segments': [
        {'text': piece['text'], 'start': piece['start'], 'end': piece['end']}
        for piece in transcription_raw_result.get('segments', [])
    ]
}
print("Transcription complete.")

# print("\n--- Transcription Result (raw Whisper output) ---")
# print(f"Number of transcription segments: {len(transcription_result['segments'])}")
# for i, seg in enumerate(transcription_result['segments']):
#     print(f"  Segment {i+1}: Start={seg['start']:.2f}, End={seg['end']:.2f}, Text='{seg['text'].strip()}'")
# print("--------------------------------------------")


# 3. Forced Alignment with ctc-forced-aligner
# Load the alignment model and its tokenizer on the same device as Whisper.
# Half precision is only requested on CUDA; every other device uses float32.
print("Initializing forced alignment service (using load_alignment_model)...")
try:
    alignment_model, alignment_tokenizer = load_alignment_model(
        device_whisper_align,
        dtype=torch.float16 if device_whisper_align == "cuda" else torch.float32,
    )
    print("Forced alignment model and tokenizer loaded.")
except Exception as e:
    print(f"Error loading alignment model: {e}")
    print("Please ensure the ctc-forced-aligner library is correctly installed from GitHub and dependencies are met.")
    # `sys` is not imported in this file, so the previous `sys.exit(1)` raised
    # NameError instead of exiting. Raising SystemExit directly is what
    # sys.exit does, and chaining keeps the original failure in the traceback.
    raise SystemExit(1) from e

# Collects one entry per transcription segment during forced alignment.
aligned_result_segments = []

print("Loading full audio for alignment...")
try:
    # Using torchaudio.load as it's generally more robust
    full_audio_waveform, sample_rate = torchaudio.load(audio_file)

    # Downmix to mono (keeping a leading channel axis) and move the audio to
    # the target device BEFORE resampling. The resampler below is moved to the
    # same device; applying it to a waveform still on the CPU fails with a
    # device mismatch when device_whisper_align is "cuda".
    full_audio_waveform = full_audio_waveform.mean(dim=0, keepdim=True).to(device_whisper_align)

    if sample_rate != 16000:
        print(f"Resampling audio from {sample_rate}Hz to 16000Hz...")
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000).to(device_whisper_align)
        full_audio_waveform = resampler(full_audio_waveform)
        sample_rate = 16000

    print("Full audio loaded and processed.")
except Exception as e:
    print(f"Error loading/processing audio with torchaudio: {e}")
    # `sys` is not imported in this file, so `sys.exit(1)` raised NameError
    # here; raise SystemExit directly and keep the original cause attached.
    raise SystemExit(1) from e


print("Performing forced alignment on transcription segments...")
# The time axis is the last one, whether the waveform is (channels, samples)
# or already 1-D.
total_samples = full_audio_waveform.shape[-1]

for i, segment in enumerate(transcription_result['segments']):
    segment_text = segment['text']
    segment_start = segment['start']
    segment_end = segment['end']

    # Convert the segment boundaries (seconds) to sample indices, clamped to
    # the audio that actually exists.
    start_sample_abs = max(0, int(segment_start * sample_rate))
    end_sample_abs = min(total_samples, int(segment_end * sample_rate))

    # Stays empty when the segment is skipped or alignment fails, so every
    # transcription segment still produces exactly one output entry.
    aligned_words = []

    if start_sample_abs >= end_sample_abs:
        print(f"    Warning: Skipping empty or invalid audio chunk for segment {i+1}.")
    else:
        try:
            # The ctc-forced-aligner usage example feeds generate_emissions a
            # 1-D waveform with the model's dtype and device, so flatten the
            # mono chunk and cast it before generating emissions.
            current_audio_chunk_tensor = (
                full_audio_waveform[..., start_sample_abs:end_sample_abs]
                .reshape(-1)
                .to(device=alignment_model.device, dtype=alignment_model.dtype)
            )

            # Frame-level emissions for this chunk, plus the frame stride that
            # postprocess_results needs to turn frame indices into seconds.
            emissions, stride = generate_emissions(
                alignment_model,
                current_audio_chunk_tensor,
                batch_size=1  # Process one chunk at a time
            )

            # Romanize/tokenize the text. The library documents the language
            # argument as an ISO 639-3 code, hence "tha" rather than "th".
            tokens_starred, text_starred = preprocess_text(
                segment_text,
                romanize=True,
                language="tha"
            )

            segments, scores, blank_token = get_alignments(
                emissions,
                tokens_starred,
                alignment_tokenizer,
            )

            # Documented call shapes: get_spans(tokens, segments, blank) and
            # postprocess_results(text_starred, spans, stride, scores). The
            # previous calls passed `offset=` / `score_threshold=` keywords
            # that these functions do not accept, so every segment fell into
            # the except branch and ended up without word timings.
            spans = get_spans(tokens_starred, segments, blank_token)
            aligned_words_raw = postprocess_results(text_starred, spans, stride, scores)

            # Word times are relative to the chunk; shift them by the segment
            # start to place them on the full-recording timeline.
            aligned_words = [
                {
                    **word_info,
                    'start': word_info['start'] + segment_start,
                    'end': word_info['end'] + segment_start,
                }
                for word_info in aligned_words_raw
            ]
        except Exception as e:
            # Best effort: keep the segment without word timings instead of
            # aborting the whole run.
            print(f"    Warning: Could not align segment '{segment_text}' (start: {segment_start:.2f}s, end: {segment_end:.2f}s). Error: {e}")
            aligned_words = []

    aligned_result_segments.append({
        'text': segment_text,
        'start': segment_start,
        'end': segment_end,
        'words': aligned_words
    })

aligned_result = {'segments': aligned_result_segments}
print("Alignment complete.")

# print("\n--- Aligned Result ---")
# print(f"Number of aligned segments: {len(aligned_result['segments'])}")
# for i, seg in enumerate(aligned_result['segments']):
#     text_display = "".join([word["word"] for word in seg["words"]]).strip() if "words" in seg and seg["words"] else seg["text"].strip()
#     print(f"  Segment {i+1}: Start={seg['start']:.2f}, End={seg['end']:.2f}, Text='{text_display}'")
#     if "words" in seg and seg["words"]:
#         for word in seg["words"]:
#             print(f"    - Word: '{word['word']}', Start={word['start']:.2f}, End={word['end']:.2f}")
# print("----------------------")


# 4. Assign Speaker Labels
# Each segment gets the speaker whose diarization turn shares the most time
# with it. Segments overlapping no turn are labelled "Unknown"; when two turns
# tie, the one listed first in diarization_df wins.
print("Assigning speakers to transcription segments...")
result_with_speakers = aligned_result.copy()
result_with_speakers['segments'] = []

for segment in aligned_result['segments']:
    seg_from, seg_to = segment['start'], segment['end']

    best_label = "Unknown"
    best_overlap = 0
    for turn in diarization_df.itertuples(index=False):
        # Length of the intersection of the two intervals; zero or negative
        # means they do not overlap, which can never beat best_overlap (>= 0).
        shared = min(seg_to, turn.end) - max(seg_from, turn.start)
        if shared > best_overlap:
            best_overlap = shared
            best_label = turn.speaker

    # The segment dict is annotated in place and then re-used in the result.
    segment['speaker'] = best_label
    result_with_speakers['segments'].append(segment)
print("Speaker assignment complete.")


# 5. Print the final result
# One line per segment: speaker label, start/end in seconds (two decimals),
# then the transcribed text with surrounding whitespace removed.
print("\n--- Speaker-Labeled Transcription ---")
for entry in result_with_speakers["segments"]:
    label = entry.get("speaker", "Unknown Speaker")
    spoken = entry["text"].strip()
    began, finished = entry["start"], entry["end"]
    print(f"[{label}] {began:.2f}s - {finished:.2f}s: {spoken}")
print("--------------------------------------------")

  from .autonotebook import tqdm as notebook_tqdm


Using device for Pyannote: mps
Using compute type for Pyannote: float16
Using device for Whisper/Alignment: cpu
Using compute type for Whisper/Alignment: float32
Starting speaker diarization...
Speaker diarization complete.

--- Diarization Result (DataFrame) ---
         start         end     speaker
0     0.030969    2.646594  SPEAKER_01
1     2.815344    4.131594  SPEAKER_00
2     4.435344    6.544719  SPEAKER_00
3     7.185969   11.674719  SPEAKER_00
4     7.320969    7.725969  SPEAKER_01
5    11.708469   14.070969  SPEAKER_01
6    13.514094   17.395344  SPEAKER_00
7    29.764719   30.372219  SPEAKER_01
8    30.810969   32.532219  SPEAKER_01
9    32.903469   51.330969  SPEAKER_01
10   48.917844   49.373469  SPEAKER_00
11   50.419719   51.246594  SPEAKER_00
12   51.330969   51.381594  SPEAKER_00
13   52.444719   63.430344  SPEAKER_01
14   60.612219   60.797844  SPEAKER_00
15   64.527219   65.404719  SPEAKER_01
16   65.826594   68.307219  SPEAKER_01
17   68.442219   69.707844  SPEAKE

 50%|█████     | 6000/12000 [02:42<02:42, 36.98frames/s]


Transcription complete.
Initializing forced alignment service (using load_alignment_model)...
Forced alignment model and tokenizer loaded.
Loading full audio for alignment...
Full audio loaded and processed.
Performing forced alignment on transcription segments...
Alignment complete.
Assigning speakers to transcription segments...
Speaker assignment complete.

--- Speaker-Labeled Transcription ---
[SPEAKER_01] 0.00s - 2.68s: ดิจิ่งโลกเราก็เคยผ่านการสุนพันศ์
[SPEAKER_00] 2.68s - 4.08s: หิน 10 กิโล
[SPEAKER_00] 4.34s - 6.68s: พุ่งมาด้วยความเร็ว 20 กิโลเมตร
[SPEAKER_00] 7.06s - 8.66s: ต่อวินาที
[SPEAKER_00] 8.66s - 9.88s: พัศพิรุณเนี่ย
[SPEAKER_00] 9.88s - 11.68s: ก่อมน้ำลงมาดับไฟให้นะ
[SPEAKER_01] 11.68s - 14.08s: - อ๋อ พัศพิรุณยังเห็นใจล้าว! ใช่!
[SPEAKER_00] 14.08s - 14.92s: เอ่อ!
[SPEAKER_00] 14.92s - 17.36s: แต่น้ำนั่นน่ะ คือ น้ำกดซันฟิวริก
[SPEAKER_00] 17.36s - 19.82s: อือางน้ำ
[SPEAKER_01] 107.36s - 110.36s: งานที่ดูเนี่ยผมก็จะดูพวกศาลคดี
[SPEAKER_01] 110.36s - 116.36s: แต่ถ้าฟัง p