In [None]:
import os                       
import torch                    
from pydub import AudioSegment, silence
from pyannote.audio import Pipeline  
import torchaudio
from dotenv import load_dotenv
import csv
from pathlib import Path
from pydub.silence import detect_silence
from tenacity import retry, stop_after_attempt, wait_exponential
from tqdm import tqdm
import google.generativeai as genai
import pandas as pd
import time
from pyannote.audio import Inference
import numpy as np
import whisper

In [None]:
# Segmentation parameters. pydub measures lengths in milliseconds and
# silence levels in dBFS.
MIN_LEN = 5_000           # shortest clip split_segment may emit (ms)
MAX_LEN = 8_000           # preferred longest clip before searching for a pause (ms)
MIN_SILENCE_LEN = 600     # a pause must last at least this long to count (ms)
SILENCE_THRESH = -35      # anything quieter than this counts as silence (dBFS)
max_extra_percent = 0.15  # NOTE(review): not referenced in the visible cells; split_segment uses 0.2
PADDING_MS = 200          # NOTE(review): not referenced in the visible cells

In [None]:
         load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=HF_TOKEN
)

In [None]:
# Gemini client setup. Fails fast when the key is missing.
# NOTE(review): in the visible cells `model` is only used by the commented-out
# Gemini transcription cell; the active `transcription` uses Whisper.
api_key = os.getenv("GEMINI_API_KEY")
if api_key:
    genai.configure(api_key=api_key)
else:
    raise EnvironmentError(
        "GEMINI_API_KEY not found. Please set it as an environment variable."
    )
model = genai.GenerativeModel(model_name="models/gemini-2.5-flash")
print(" AudioTranscriber initialized successfully.")

In [None]:
whisper_model = whisper.load_model(name="medium")  # Whisper checkpoint used by `transcription`

In [None]:
# pyannote speaker-embedding model (run on CPU), called by get_embedding().
device = torch.device("cpu")
embedding_model = Inference(
    "pyannote/embedding",
    device=device,
    use_auth_token=HF_TOKEN,
)
# Placeholder only: the embeddings-loading cell rebinds this name.
speaker_embeddings_db = {}

In [None]:
# def split_audio_fixed_length(filepath, chunk_minutes=2, out_dir="Final_Output"):
#     podcast_name = Path(filepath).stem 
#     podcast_dir = os.path.join(out_dir, podcast_name)
#     os.makedirs(podcast_dir, exist_ok=True)
#     audio = AudioSegment.from_wav(filepath)
#     chunk_length = chunk_minutes * 60 * 1000
#     chunk_files = []
#     start_ms = 0
#     chunk_idx = 1
#     while start_ms < len(audio):
#         end_ms = min(start_ms + chunk_length, len(audio))
#         silences = detect_silence(
#             audio[start_ms:end_ms],
#             min_silence_len=MIN_SILENCE_LEN,
#             silence_thresh=SILENCE_THRESH
#         )
#         if silences:
#             last_silence_end = silences[-1][1]
#             end_ms = start_ms + last_silence_end
#         chunk = audio[start_ms:end_ms]
#         chunk_name = f"chunk{chunk_idx:02d}.wav"
#         chunk_path = os.path.join(podcast_dir, chunk_name)
#         chunk.export(chunk_path, format="wav")
#         chunk_files.append(chunk_path)
#         print(f" Saved chunk: {chunk_path}")
#         start_ms = end_ms
#         chunk_idx += 1
#     return chunk_files

In [None]:
def split_segment(segment_audio, min_len=MIN_LEN, max_len=MAX_LEN, extra_ratio=0.2):
    """Split a pydub segment into consecutive chunks, preferring cuts at pauses.

    For each chunk starting at ``start`` (all positions in ms):
      1. Search ``[start, start + max_len)`` for silences; if any are found,
         cut at the end of the last one.
      2. Otherwise widen the window by ``int(max_len * extra_ratio)`` and cut
         at the end of the first silence found there, or at the end of the
         widened window if there is none.
      3. If the resulting chunk is shorter than ``min_len``, extend it to
         ``min_len`` (or to the end of the audio, whichever comes first).

    Args:
        segment_audio: pydub ``AudioSegment`` to split.
        min_len: Minimum chunk length in ms (default ``MIN_LEN``).
        max_len: Preferred maximum chunk length in ms (default ``MAX_LEN``).
        extra_ratio: Fraction of ``max_len`` a chunk may overrun when no
            silence is found inside ``max_len``. Defaults to 0.2, the value
            previously hard-coded here. NOTE(review): the config cell defines
            ``max_extra_percent = 0.15``, which is not used by this function.

    Returns:
        List of contiguous ``AudioSegment`` slices covering the whole input
        (empty for empty input). The last chunk may be shorter than
        ``min_len`` when the audio runs out.
    """
    def _find_silences(chunk):
        # pydub returns [start, end] ranges relative to ``chunk``.
        return silence.detect_silence(
            chunk,
            min_silence_len=MIN_SILENCE_LEN,
            silence_thresh=SILENCE_THRESH,
        )

    segments = []
    audio_len = len(segment_audio)
    start = 0
    while start < audio_len:
        search_end = min(start + max_len, audio_len)
        silences = _find_silences(segment_audio[start:search_end])
        if silences:
            # Cut at the end of the last pause inside the preferred window.
            cut = start + silences[-1][1]
        else:
            # No pause yet: allow a bounded overrun and take its first pause.
            extra_allowed = int(max_len * extra_ratio)
            search_end_extra = min(start + max_len + extra_allowed, audio_len)
            silences_extra = _find_silences(segment_audio[start:search_end_extra])
            cut = start + silences_extra[0][1] if silences_extra else search_end_extra
        if cut - start < min_len:
            # Chunk too short: extend it to min_len, bounded by the audio end.
            cut = min(start + min_len, audio_len)
        segments.append(segment_audio[start:cut])
        start = cut
    return segments

In [None]:
# @retry(stop=stop_after_attempt(3),
#        wait=wait_exponential(multiplier=1, min=4, max=10))
# def transcription(audio_file_path, txt_output_dir):
#     """
#     Transcribes a WAV file, saves transcription to .txt,
#     and returns the transcription text.
#     """
#     audio_path = Path(audio_file_path)
#     txt_output_dir = Path(txt_output_dir)
#     txt_output_dir.mkdir(parents=True, exist_ok=True)
#     txt_file_path = txt_output_dir / f"{audio_path.stem}.txt"
#     if txt_file_path.exists():
#         return txt_file_path.read_text(encoding="utf-8").strip()
#     try:
#         prompt = """
#         Please transcribe this Arabic audio accurately.The audio contains Egyptian Arabic dialect speech.

#         Instructions:
#         1. Transcribe exactly what is spoken in Arabic script.
#         2. Do not add any diacritics (tashkeel).
#         3. Include natural speech patterns and colloquialisms.
#         4. If any English text is spoken, map its characters to Arabic(e.g., 'K' -> 'كي').
#         5. Return only the transcribed text, with no additional explanations.
#         """
#         audio_part = {
#             "mime_type": "audio/wav",
#             "data": audio_path.read_bytes()
#         }
#         response = model.generate_content([prompt, audio_part])
#         text = response.text.strip()
#         txt_file_path.write_text(text, encoding="utf-8")
#         return text
#     except Exception as e:
#         print(f"Transcription failed for {audio_path.name}: {e}")
#         return "[transcription_failed]"

In [None]:
def transcription(audio_file_path, language=None):
    """Transcribe an audio file with Whisper and cache the text next to it.

    The transcript is written to the audio path with its suffix replaced by
    ``.txt``. If that file already exists it is returned as-is and Whisper is
    not called.

    Args:
        audio_file_path: Path (``str`` or ``Path``) of the audio file.
        language: Passed to ``whisper_model.transcribe``. ``None`` (the
            default, matching previous behavior) lets Whisper auto-detect.

    Returns:
        ``Path`` of the transcript file on success (or cache hit), or ``None``
        if transcription failed. Previously a failure still returned the path
        of a ``.txt`` that was never written, so callers could not tell
        failures apart and the dataset CSV pointed at missing files.
    """
    audio_path = Path(audio_file_path)
    txt_file_path = audio_path.with_suffix(".txt")
    if txt_file_path.exists():
        return txt_file_path
    try:
        result = whisper_model.transcribe(str(audio_path), language=language)
        txt_file_path.write_text(result["text"].strip(), encoding="utf-8")
        return txt_file_path
    except Exception as e:
        # Best-effort: one bad clip should not abort the whole dataset run.
        print(f"Transcription failed for {audio_path.name}: {e}")
        return None

In [None]:
def get_embedding(segment_audio):
    """Compute speaker embeddings for an in-memory pydub segment.

    The segment's integer PCM samples are scaled to floats in [-1, 1) and
    down-mixed to mono. They are then passed to ``embedding_model`` as an
    in-memory ``{"waveform", "sample_rate"}`` dict. Previously the raw integer
    values (e.g. +/-32768 for 16-bit audio) were fed in unscaled.
    NOTE(review): embeddings saved to ``speaker_embeddings.pt`` before this
    change were computed from unscaled samples; regenerate them if matching
    looks off.

    Args:
        segment_audio: pydub ``AudioSegment``.

    Returns:
        ``torch.Tensor`` built from ``embedding_model(...).data``. Callers
        average it over dim 0, so it is presumably (num_windows, dim) —
        TODO confirm.
    """
    samples = np.asarray(segment_audio.get_array_of_samples(), dtype=np.float32)
    # Scale integer PCM to [-1, 1), the float range waveform tensors normally use.
    samples /= np.float32(segment_audio.max_possible_amplitude)
    if segment_audio.channels > 1:
        # Samples are interleaved per frame; average channels to get mono.
        samples = samples.reshape(-1, segment_audio.channels).mean(axis=1)
    waveform = torch.tensor(samples).unsqueeze(0)  # (1, num_samples)
    emb = embedding_model({
        "waveform": waveform,
        "sample_rate": segment_audio.frame_rate,
    })
    return torch.tensor(emb.data)

In [None]:
def assign_speaker(embedding, speaker_embeddings, threshold=0.65):
    """Return the key of ``speaker_embeddings`` most similar to ``embedding``.

    ``embedding`` is averaged over dim 0 and compared to every reference
    vector by cosine similarity. The highest-scoring key wins; on ties, the
    key that comes first in dict order wins.

    Args:
        embedding: Tensor averaged over its first dimension before comparison
            (presumably (num_windows, dim) from ``get_embedding`` — TODO confirm).
        speaker_embeddings: Mapping of speaker name to reference tensor
            (presumably 1-D, as built by the processing cell).
        threshold: NOTE(review): currently unused. The best match is returned
            however low its similarity is.

    Returns:
        The best-matching key.

    Raises:
        ValueError: if ``speaker_embeddings`` is empty (raised by ``max``).
    """
    query = embedding.mean(dim=0).unsqueeze(0)
    scored = [
        (name, torch.nn.functional.cosine_similarity(query, reference.unsqueeze(0)).item())
        for name, reference in speaker_embeddings.items()
    ]
    best_name, _best_score = max(scored, key=lambda pair: pair[1])
    return best_name

In [None]:
# Reuse reference speaker embeddings saved by an earlier run, if present;
# otherwise start empty so the processing cell bootstraps them.
# NOTE(review): torch.load unpickles the file — only load files this notebook wrote.
embeddings_file = "speaker_embeddings.pt"
if not os.path.exists(embeddings_file):
    speaker_embeddings_db = {}
    print("No existing embeddings found, will create new ones.")
else:
    speaker_embeddings_db = torch.load(embeddings_file)
    print(" Loaded existing speaker embeddings.")

In [None]:
audio_folder = "Audio_folder"
output_folder = "Dataset"
os.makedirs(output_folder, exist_ok=True)


def _diarization_role(label):
    """Map a pyannote diarization label to ``"host"`` or ``"guest"``.

    NOTE(review): assumes ``SPEAKER_00`` is always the host — confirm per podcast.
    """
    return "host" if label == "SPEAKER_00" else "guest"


def _turn_audio(audio, turn):
    """Slice pydub ``audio`` (ms-indexed) to a diarization turn given in seconds."""
    return audio[int(turn.start * 1000):int(turn.end * 1000)]


def _bootstrap_speaker_embeddings(audio, tracks):
    """Build and save reference host/guest embeddings from one file's diarization.

    Each turn is split with ``split_segment``. Each piece's embedding is
    averaged over dim 0, and each role's reference is the mean over its pieces.
    Fills ``speaker_embeddings_db`` in place and saves it to ``embeddings_file``.
    """
    print("Creating speaker embeddings for first time...")
    speaker_audio = {"host": [], "guest": []}
    for turn, _, speaker in tracks:
        speaker_audio[_diarization_role(speaker)].extend(split_segment(_turn_audio(audio, turn)))
    for spk, segments in speaker_audio.items():
        if not segments:
            # torch.stack on an empty list would fail with an opaque message.
            raise RuntimeError(
                f"No diarized audio found for '{spk}'; cannot build its reference embedding."
            )
        embeddings = [get_embedding(seg).mean(dim=0) for seg in segments]
        speaker_embeddings_db[spk] = torch.stack(embeddings).mean(dim=0)
    torch.save(speaker_embeddings_db, embeddings_file)
    print(" Embeddings for first podcast initialized and saved")


def _process_podcast(file_name, records):
    """Diarize one WAV file, then export, transcribe and record its clips.

    Appends one ``{"Speaker", "Audio_file", "Transcription"}`` dict per
    exported clip to ``records``.
    """
    print(f"\nProcessing file: {file_name}")
    podcast_name = Path(file_name).stem
    file_path = os.path.join(audio_folder, file_name)
    podcast_dir = os.path.join(output_folder, podcast_name)
    os.makedirs(podcast_dir, exist_ok=True)
    audio = AudioSegment.from_wav(file_path)
    diarization = pipeline(file_path, num_speakers=2)
    # Materialize once: the tracks are iterated several times below.
    tracks = list(diarization.itertracks(yield_label=True))
    print(f"Diarization done | Number of segments: {len(tracks)}")

    if not speaker_embeddings_db:
        _bootstrap_speaker_embeddings(audio, tracks)

    # Turns are grouped by diarization label (host turns first, then guest).
    # Each turn's clips are labelled by embedding similarity to the references.
    speaker_turns = {"host": [], "guest": []}
    for turn, _, speaker in tracks:
        speaker_turns[_diarization_role(speaker)].append(turn)

    speaker_counter = {"host": 0, "guest": 0}
    # Per-speaker tail shorter than MIN_LEN, prepended to that speaker's next
    # turn. A tail still pending when the file ends is dropped.
    pending_segment = {"host": None, "guest": None}
    for turns in speaker_turns.values():
        for turn in turns:
            segment_audio = _turn_audio(audio, turn)
            speaker_name = assign_speaker(get_embedding(segment_audio), speaker_embeddings_db)
            if pending_segment[speaker_name] is not None:
                segment_audio = pending_segment[speaker_name] + segment_audio
            pending_segment[speaker_name] = None
            sub_segments = split_segment(segment_audio)
            last_index = len(sub_segments) - 1
            for i, sub in enumerate(sub_segments):
                if i == last_index and len(sub) < MIN_LEN:
                    # Too short to keep on its own: carry it over.
                    pending_segment[speaker_name] = sub
                    continue
                speaker_counter[speaker_name] += 1
                clip_no = speaker_counter[speaker_name]
                audio_file_name = f"{podcast_name}_{speaker_name}_{clip_no:02d}.wav"
                audio_file_path = os.path.join(podcast_dir, audio_file_name)
                # pydub's export returns the open file handle; close it explicitly.
                sub.export(audio_file_path, format="wav").close()
                print(f"Segment {clip_no:02d} | Speaker: {speaker_name} | Duration: {len(sub)/1000:.2f}s")
                transcription_file_path = transcription(audio_file_path)
                print(f"Transcription done | File: {audio_file_name} | Transcription path: {transcription_file_path}")
                records.append({
                    "Speaker": speaker_name,
                    "Audio_file": audio_file_path,
                    "Transcription": transcription_file_path,
                })
    print(f" Finished processing file: {file_name}")


# Rows are collected as dicts and turned into a DataFrame once; growing a
# DataFrame with pd.concat per row is quadratic.
records = []
try:
    # sorted(): deterministic order. When no embeddings were loaded, the
    # first WAV processed is the one that bootstraps them.
    for file_name in sorted(os.listdir(audio_folder)):
        if file_name.lower().endswith(".wav"):
            _process_podcast(file_name, records)
finally:
    # Built even after an error/interrupt so rows collected so far are kept.
    Final_Data = pd.DataFrame(records, columns=["Speaker", "Audio_file", "Transcription"])

In [None]:
# Merge this run's rows with any previously saved dataset and write it back.
csv_file = os.path.join(output_folder, "Dataset.csv")
if os.path.exists(csv_file):
    old_data = pd.read_csv(csv_file)
    Final_Data = pd.concat([old_data, Final_Data], ignore_index=True)

# Clip file names are built from (podcast, speaker, counter), so re-processing
# a podcast or re-running this cell re-adds rows already in the CSV. Keep only
# the newest row per clip so the cell is idempotent.
Final_Data = Final_Data.drop_duplicates(subset="Audio_file", keep="last", ignore_index=True)

Final_Data.to_csv(csv_file, index=False, encoding="utf-8")
print("Process Done and data saved.")