# Generative AI for real-time chat transcription

## Requirements



### pyannote installation and requirements

In [59]:
# !pip install pyannote.audio
# !pip install torch
# !pip install torchaudio

In [60]:
import torch,torchaudio
from pyannote.audio import Pipeline
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import tqdm as notebook_tqdm

### whisper installation and requirements

In [61]:
# !pip install -U openai-whisper
# !pip install git+https://github.com/openai/whisper.git
# !sudo apt install ffmpeg
# !pip install setuptools-rust

In [62]:
import whisper

### gemini installations and requirements

In [63]:
# !pip install -q google-generativeai

In [64]:
import google.generativeai as genai
import time # For potential rate limiting

### Embedding installation and requirements

In [65]:
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding

### Re‐identification installation and requirements

In [66]:
from sklearn.metrics.pairwise import cosine_similarity

### flask + ngrok installation and requirements

In [67]:
# !pip install flask pyngrok

### other requirements

In [None]:
import os
import subprocess
import tempfile
import numpy as np

# Hugging Face access token. It is passed as `use_auth_token` when the
# pyannote diarization pipeline and the speaker-embedding model are loaded.
# Fill it in locally; avoid committing a real token with the notebook.
HF_TOKEN = ""

# Google Generative AI API key, handed to `genai.configure(api_key=...)`.
GEMINI_API_KEY = ""

### Instantiate all the models

In [None]:
# Diarization: load the pretrained pyannote speaker-diarization pipeline.
pipeline = Pipeline.from_pretrained(
  "pyannote/speaker-diarization-3.1",  # pretrained pipeline identifier
  use_auth_token=HF_TOKEN) # Hugging Face access token

# Transcription: load the Whisper "turbo" checkpoint.
whisper_model = whisper.load_model("turbo")

# Refinement: configure the Gemini client and create the generative model.
# Alternative model name kept for reference:
# gemini_model_name = "gemini-2.5-flash-preview-04-17"
gemini_model_name = "gemini-2.0-flash"
genai.configure(api_key=GEMINI_API_KEY) # uses the key defined in the cell above
gemini_model = genai.GenerativeModel(gemini_model_name)

# Speaker embeddings (used for re-identification): load the pretrained
# embedding model on the GPU when CUDA is available, otherwise on the CPU.
speaker_embedding_model = PretrainedSpeakerEmbedding(
    "pyannote/embedding",  # pretrained embedding model identifier
    use_auth_token=HF_TOKEN,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)

In [70]:
# Housekeeping cell: run a Python garbage-collection pass, then ask PyTorch to
# empty its CUDA memory cache.
import torch, gc
gc.collect()
torch.cuda.empty_cache()

### Warning handling

In [12]:
# Suppress one specific, noisy warning: UserWarnings whose message starts with
# "The MPEG_LAYER_III subtype is unknown" and that originate from torchaudio's
# soundfile backend module. All other warnings are left untouched.
import warnings
warnings.filterwarnings(
    "ignore",
    message="The MPEG_LAYER_III subtype is unknown",
    category=UserWarning,
    module="torchaudio._backend.soundfile_backend"
)

## ── AUDIO PREPROCESSING ─────────────────────────────────

In [13]:
def preprocess_audio_and_save(input_audio_path: str,
                              output_dir: str = "/content/",
                              output_filename: str = "final_processed_for_models.wav",
                              target_sample_rate: int = 16000) -> str:
    """
    Normalise an audio file to mono WAV at a fixed sample rate and save it.

    Processing steps:
      * A non-WAV input (.mp3, .ogg, .flac, .aac, .m4a, .wma) is first
        converted with ffmpeg into a uniquely named WAV in the system temp
        directory; a WAV input is used directly.
      * The WAV is loaded with torchaudio, averaged down to one channel when it
        has several, and resampled when its rate differs from
        `target_sample_rate`.
      * The result is written to `output_dir/output_filename` (the directory is
        created if missing).
      * The temporary WAV, if one was created, is removed even when loading or
        saving fails.

    Args:
        input_audio_path (str): Path to the input audio file.
        output_dir (str): Directory for the final WAV. `tempfile.gettempdir()`
            is an alternative when the output should be truly temporary.
        output_filename (str): File name of the final WAV.
        target_sample_rate (int, optional): Desired sample rate. Defaults to 16000.

    Returns:
        str: Path to the saved, fully processed WAV file.

    Raises:
        ValueError: If the file extension is not one of the supported formats.
        subprocess.CalledProcessError: If the ffmpeg conversion fails (its
            stderr is printed before the error is re-raised).
    """
    destination = os.path.join(output_dir, output_filename)

    stem, extension = os.path.splitext(input_audio_path)
    extension = extension.lower()

    # Stage 1: make sure we have a WAV file to load.
    source_wav = input_audio_path
    made_temp_wav = False
    if extension != '.wav':
        if extension not in ('.mp3', '.ogg', '.flac', '.aac', '.m4a', '.wma'):
            raise ValueError(f"Unsupported audio file format: {extension} for '{input_audio_path}'")

        # Random suffix keeps concurrent conversions of same-named files apart.
        temp_name = f"{os.path.basename(stem)}_intermediate_{os.urandom(4).hex()}.wav"
        temp_wav = os.path.join(tempfile.gettempdir(), temp_name)
        try:
            subprocess.run(
                ['ffmpeg', '-i', input_audio_path, '-y', temp_wav],
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
        except subprocess.CalledProcessError as err:
            print(f"Error during FFmpeg conversion for '{input_audio_path}': {err.stderr.decode()}")
            raise
        source_wav = temp_wav
        made_temp_wav = True

    # Stage 2: load, downmix, resample, save; always clean up the temp WAV.
    try:
        audio, source_rate = torchaudio.load(source_wav)

        has_multiple_channels = audio.shape[0] > 1
        if has_multiple_channels:
            audio = torch.mean(audio, dim=0, keepdim=True)

        if source_rate != target_sample_rate:
            to_target_rate = torchaudio.transforms.Resample(
                orig_freq=source_rate, new_freq=target_sample_rate
            )
            audio = to_target_rate(audio)

        os.makedirs(output_dir, exist_ok=True)
        torchaudio.save(destination, audio.cpu(), target_sample_rate)
    finally:
        if made_temp_wav and os.path.exists(source_wav):
            try:
                os.remove(source_wav)
            except OSError as err:
                print(f"Error removing intermediate temp WAV '{source_wav}': {err}")

    return destination

## ── DIARIZATION ──────────────────────────────────────────

In [14]:
def diarization_func(audio_file, pipeline):
    """
    Run speaker diarization on an audio file.

    The pipeline is moved to the GPU when CUDA is available and to the CPU
    otherwise, so the function also works on CPU-only machines (the same
    device policy the speaker-embedding model in this notebook uses).

    Args:
        audio_file: Audio input accepted by the pipeline (the callers in this
            notebook pass the path of the preprocessed WAV file).
        pipeline: A loaded diarization pipeline exposing `.to(device)` and
            being callable on `audio_file`.

    Returns:
        Whatever the pipeline returns for `audio_file` (the diarization result).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline.to(device)
    return pipeline(audio_file)

## ── TRANSCRIPTION ───────────────────────────────────────

In [15]:
def transcribe_audio(audio_file, whisper_model):
    """Transcribe `audio_file` with `whisper_model` and return the model's result unchanged."""
    return whisper_model.transcribe(audio_file)

## ── COMBINE DIARIZATION + TRANSCRIPTION ─────────────────

In [16]:
# Flattens a pyannote diarization result into plain dictionaries so the rest
# of the notebook does not have to deal with pyannote objects.

def parse_pyannote_diarization(diarization):
    """
    Convert a diarization result into a list of segment dictionaries.

    `diarization.itertracks(yield_label=True)` yields `(turn, track, speaker)`
    triples; the track element is ignored. Each triple becomes a dictionary
    with keys `'start'` and `'end'` (taken from `turn.start` / `turn.end`) and
    `'speaker'` (the yielded label), in iteration order.
    """
    return [
        {'start': turn.start, 'end': turn.end, 'speaker': label}
        for turn, _track, label in diarization.itertracks(yield_label=True)
    ]

In [17]:
# Scores how well a transcription segment is covered by a diarization segment.
def segment_score(transcript_segment, speaker_segment, threshold=0.5):
    """
    Return the fraction of the transcription segment covered by the speaker segment.

    Both arguments are mappings with numeric `start` and `end` entries. The
    overlap duration is divided by the transcription segment's duration (a
    non-positive duration gives a ratio of 0). Ratios below `threshold` are
    reported as 0.0 so that weak overlaps never count as a match.
    """
    t_start = transcript_segment["start"]
    t_end = transcript_segment["end"]
    s_start = speaker_segment['start']
    s_end = speaker_segment['end']

    # Length of the intersection of the two intervals, clamped at zero.
    shared = min(t_end, s_end) - max(t_start, s_start)
    shared = shared if shared > 0 else 0

    duration = t_end - t_start
    if duration > 0:
        ratio = shared / duration
    else:
        ratio = 0

    if ratio >= threshold:
        return ratio
    return 0.0

In [18]:
def combine_diarization_transcription(diarization, transcription):
    """
    Attribute each transcription segment to a speaker and merge consecutive turns.

    For every entry of `transcription["segments"]`, the diarization segment
    with the highest `segment_score` (strictly greater than 0) determines the
    speaker. Consecutive segments attributed to the same speaker are joined
    with single spaces into one entry.

    A segment with no matching speaker is emitted exactly once, as
    `{"No speaker found for: ": text}`. It also ends the current speaker's run,
    so a speaker resuming afterwards starts a new entry. Unmatched text is
    never additionally buffered under a `None` speaker, which previously
    produced a duplicate `{None: ...}` entry when the transcript ended on
    unmatched segments.

    Args:
        diarization: Diarization result accepted by `parse_pyannote_diarization`.
        transcription: Mapping with a `"segments"` list whose items provide
            `"start"`, `"end"` and `"text"`.

    Returns:
        list[dict]: Single-key dictionaries `{speaker_label: text}` in
        chronological order.
    """
    speaker_texts = []
    parsed_diarization = parse_pyannote_diarization(diarization)

    prev_speaker = None   # speaker whose text is currently being buffered
    buffered_text = ""

    for t_segment in transcription["segments"]:
        # Pick the diarization segment with the best overlap score; on ties
        # the earliest one wins, and a score of 0 never counts as a match.
        max_score = 0
        best_s_segment = None
        for s_segment in parsed_diarization:
            score = segment_score(t_segment, s_segment)
            if score > max_score:
                max_score = score
                best_s_segment = s_segment

        if best_s_segment is None:
            # Close the running speaker turn, then report the orphan segment
            # once and reset the buffer so it is not emitted again later.
            if prev_speaker is not None:
                speaker_texts.append({prev_speaker: buffered_text.strip()})
            prev_speaker = None
            buffered_text = ""
            speaker_texts.append({"No speaker found for: ": t_segment['text']})
            continue

        current_speaker = best_s_segment['speaker']
        if current_speaker == prev_speaker:
            buffered_text += " " + t_segment["text"]
        else:
            if prev_speaker is not None:
                speaker_texts.append({prev_speaker: buffered_text.strip()})
            buffered_text = t_segment["text"]
            prev_speaker = current_speaker

    # Flush the last speaker turn, if any text is pending.
    if prev_speaker is not None and buffered_text:
        speaker_texts.append({prev_speaker: buffered_text.strip()})

    return speaker_texts

## ── Transcription Enhancement ─────────────────────────────────

In [19]:
# Refinement of a single transcription segment through the Gemini API.
def refine_text_with_gemini_api(text_segment, speaker_label, model):
    """
    Ask a Gemini model to clean up one raw transcription segment.

    The prompt instructs the model to fix ASR errors and punctuation while
    keeping the speaker's original language mix (Darija / MSA / French /
    English) and to answer with the refined text only.

    Args:
        text_segment: Raw transcription text for one speaker turn.
        speaker_label: Speaker label, included in the prompt for context.
        model: Object exposing `generate_content(prompt, safety_settings=...)`.

    Returns:
        The stripped refined text. The original `text_segment` is returned
        unchanged when `model` is falsy, when the text is empty or whitespace,
        when no text can be extracted from the response, or when the API call
        raises (the error is printed, not propagated).
    """
    # Nothing to do without a model or without any actual text.
    if not (model and text_segment and text_segment.strip()):
        return text_segment

    # Few-shot prompt; the examples illustrate the expected behaviour for
    # Arabic/Darija/French/English mixes and should be kept representative.
    prompt = f"""You are an expert AI assistant specializing in refining raw speech transcriptions from Moroccan conversations.
These conversations frequently mix Moroccan Darija, Modern Standard Arabic, French, and English.
Your primary directive is to enhance the clarity, readability, and grammatical correctness of the transcription WHILE STRICTLY PRESERVING THE ORIGINAL LANGUAGE(S) USED by the speaker.

**CRITICAL INSTRUCTIONS:**
1.  **NO TRANSLATION:** Absolutely do NOT translate any part of the text into English or any other language if it wasn't originally in that language. If the input is in Arabic/Darija, the output MUST be in Arabic/Darija. If it's a mix (e.g., Darija with French words), the output MUST preserve that exact mix.
2.  **LANGUAGE PRESERVATION:** Maintain the original linguistic blend. Do not replace Darija words with MSA or vice-versa unless it's a clear ASR error of a common word.
3.  **CORRECT ASR ERRORS:** Fix obvious errors from the Automatic Speech Recognition within the original language.
4.  **PUNCTUATION & CAPITALIZATION:** Improve punctuation and capitalization for better readability, following conventions appropriate for the language(s) being used (e.g., Arabic punctuation for Arabic text, French conventions for French text).
5.  **NATURAL FLOW:** Ensure the language sounds natural as if a human wrote it down from the speech.
6.  **NO ADDITIONS/OPINIONS:** Do not add any information or opinions not present in the original text.
7.  **MINIMAL CHANGES IF GOOD:** If a segment is already high quality, return it as is or with only very minor, essential touch-ups.
8.  **OUTPUT ONLY THE REFINED TEXT:** Do not include any of your own commentary, apologies, or explanations in the response. Just the refined segment.

**EXAMPLES OF DESIRED BEHAVIOR (Illustrative - provide your own high-quality examples):**

*   **Example 1 (Darija):**
    *   Speaker: SPEAKER_A
    *   Raw Transcription Segment: "السلام عليكم لباس اش خبارك كلشي مزيان"
    *   Refined Transcription Segment: "السلام عليكم، لباس؟ اش خبارك؟ كلشي مزيان."

*   **Example 2 (Darija/French Mix):**
    *   Speaker: SPEAKER_B
    *   Raw Transcription Segment: "bonjour khouya cv ana ghaya daba on va commencer le travail"
    *   Refined Transcription Segment: "Bonjour khouya, ça va? أنا غاية دابا. On va commencer le travail."

*   **Example 3 (MSA - Modern Standard Arabic):**
    *   Speaker: SPEAKER_C
    *   Raw Transcription Segment: "نود ان نناقش هذا الموضوع الهام في جلستنا اليوم"
    *   Refined Transcription Segment: "نود أن نناقش هذا الموضوع الهام في جلستنا اليوم."

*   **Example 4 (Input is already good):**
    *   Speaker: SPEAKER_D
    *   Raw Transcription Segment: "C'est une très bonne idée, merci."
    *   Refined Transcription Segment: "C'est une très bonne idée, merci."


**TASK:**
Now, apply these instructions to the following segment:

Speaker: {speaker_label}
Raw Transcription Segment: "{text_segment}"

Refined Transcription Segment:
"""

    # Blocking is disabled for every harm category so that ordinary
    # conversational content is not rejected; revisit if safety matters more
    # than recall for your data.
    harm_categories = (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
    unblocked = [
        {"category": name, "threshold": "BLOCK_NONE"} for name in harm_categories
    ]

    try:
        response = model.generate_content(prompt, safety_settings=unblocked)

        # Prefer the `.text` shortcut; otherwise stitch together the parts of
        # the first candidate.
        if hasattr(response, 'text'):
            return response.text.strip()

        if response.candidates and response.candidates[0].content and response.candidates[0].content.parts:
            pieces = [part.text for part in response.candidates[0].content.parts]
            return "".join(pieces).strip()

        print("    Warning: Could not extract text from Gemini response in the expected way.")
        print(f"    Response object: {response}")
        return text_segment  # Fallback
    except Exception as e:
        print(f"    Error during Gemini API call for '{text_segment[:50]}...': {e}")
        # Some API errors carry extra detail on `.response` or `.message`.
        if hasattr(e, 'response') and e.response:
            print(f"    Gemini API Response Error: {e.response}")
        elif hasattr(e, 'message') and e.message:
            print(f"    GenAI Exception message: {e.message}")
        return text_segment  # Fallback

In [20]:
# Refines every entry of a combined speaker/text list through Gemini.
def process_refinement_with_gemini(speaker_texts: list[dict],
                                   gemini_model_instance: genai.GenerativeModel) -> list[dict]:
    """
    Refine each speaker text segment with the Gemini API.

    Args:
        speaker_texts: List of single-entry dictionaries mapping a speaker
                       label to its raw text, e.g.
                       [{'SPEAKER_00': 'Raw text 1'}, {'SPEAKER_01': 'Raw text 2'}].
                       Only the first key/value pair of each dictionary is used.
        gemini_model_instance: Initialized google.generativeai.GenerativeModel.

    Returns:
        A new list in the same format with refined text, in input order; empty
        dictionaries are skipped. When `gemini_model_instance` is falsy the
        input list is returned as-is; when `speaker_texts` is empty an empty
        list is returned.
    """
    if not gemini_model_instance:
        return speaker_texts  # no model available: hand back the raw dialogue
    if not speaker_texts:
        return []

    polished_dialogue = []
    for segment in speaker_texts:
        if not segment:
            continue

        # Each dictionary is expected to hold exactly one speaker -> text pair.
        label, raw_text = next(iter(segment.items()))
        polished_text = refine_text_with_gemini_api(raw_text, label, gemini_model_instance)
        polished_dialogue.append({label: polished_text})

    return polished_dialogue

## ── SPEAKER EMBEDDING (FOR RE‐ID) ────────────────────────

In [21]:
def extract_speaker_embeddings(diarization, audio_path: str, min_duration_sec: float = 0.5) -> dict:
    """
    Compute one speaker embedding per sufficiently long diarization turn.

    The audio at `audio_path` is loaded with torchaudio and averaged to a
    single channel when it has several. For each turn of
    `diarization.itertracks(yield_label=True)` lasting at least
    `min_duration_sec`, the matching slice of samples is passed to the
    module-level `speaker_embedding_model`. Turns that are too short, that map
    to an empty slice, or for which the model raises are skipped (errors are
    printed).

    Returns:
        dict: `"{speaker_label}_{start:.3f}_{end:.3f}"` -> embedding, where the
        embedding is a numpy array (a tensor result is squeezed on dim 0 and
        converted; a multi-dimensional array result is squeezed).
    """
    # NOTE(review): the model receives a 2-D slice (channel, samples),
    # presumably without a batch dimension — confirm the embedding model
    # accepts that shape.
    samples, sample_rate = torchaudio.load(audio_path)
    if samples.shape[0] > 1:
        samples = samples.mean(dim=0, keepdim=True)

    by_segment = {}
    for turn, _, speaker_label in diarization.itertracks(yield_label=True):
        if turn.end - turn.start < min_duration_sec:
            continue  # too short to embed

        first = int(turn.start * sample_rate)
        last = int(turn.end * sample_rate)
        clip = samples[:, first:last]
        if clip.numel() == 0:
            continue

        try:
            vector = speaker_embedding_model(clip.to(speaker_embedding_model.device))
            if isinstance(vector, torch.Tensor):
                vector = vector.squeeze(0).cpu().numpy()
            elif isinstance(vector, np.ndarray) and vector.ndim > 1:
                vector = np.squeeze(vector)
            key = f"{speaker_label}_{turn.start:.3f}_{turn.end:.3f}"
            by_segment[key] = vector
        except Exception as e:
            print(f"Skipping segment due to error: {e}")

    return by_segment

### ── RE‐IDENTIFICATION IMPLEMENTATION ─────────────────────────────

In [22]:
def reidentify_speakers(embeddings: dict, existing_db: dict, threshold: float = 0.5) -> dict:
    """
    Match segment embeddings against known speakers, enrolling new ones as needed.

    Args:
        embeddings: { segment_id: embedding_vec, ... }. Each embedding must be
            a numpy array that is 1-D after squeezing; others are skipped with
            a printed message.
        existing_db: { speaker_id: [list of numpy arrays], ... }. Updated
            in-place: matched embeddings are appended to the matched speaker,
            and unmatched embeddings create a new "SPEAKER_{NN}" entry.
        threshold: Minimum cosine similarity to a speaker centroid for a
            segment to be assigned to that speaker.

    For each embedding (in iteration order) the cosine similarity to the
    centroid of every known speaker is computed. If the best similarity is
    >= threshold the segment is assigned to that speaker and the centroid is
    refreshed; otherwise a new speaker id is allocated, numbered after the
    highest existing "SPEAKER_<digits>" key. A segment is only ever assigned to
    an existing speaker when one was actually found, so an empty database
    always results in new speakers regardless of `threshold`.

    Returns:
        dict: { segment_id: assigned_speaker_id }.
    """
    assignment = {}

    # Precompute centroids for known speakers (speakers with no vectors yet
    # cannot be matched and are left out).
    centroids = {}
    for spk_id, vecs in existing_db.items():
        try:
            if len(vecs) == 0:
                continue
            centroids[spk_id] = np.mean(np.stack(vecs), axis=0)
        except Exception as e:
            print(f"Skipping speaker {spk_id} due to error in centroid computation: {e}")

    # Next free numeric suffix; keys that are not "SPEAKER_<digits>..." strings
    # are ignored rather than breaking the scan.
    used_indices = [
        int(k.split("_")[1])
        for k in existing_db
        if isinstance(k, str) and k.startswith("SPEAKER_") and k.split("_")[1].isdigit()
    ]
    next_idx = max(used_indices, default=-1) + 1

    for seg_id, emb in embeddings.items():
        # Normalise to a 1-D numpy vector or skip the segment.
        if isinstance(emb, np.ndarray) and emb.ndim > 1:
            emb = np.squeeze(emb)
        if not isinstance(emb, np.ndarray) or emb.ndim != 1:
            print(f"Skipping {seg_id}: invalid embedding shape {emb.shape if isinstance(emb, np.ndarray) else type(emb)}")
            continue

        best_spk = None
        best_score = -1.0

        for spk_id, centroid in centroids.items():
            try:
                if isinstance(centroid, np.ndarray) and centroid.ndim > 1:
                    centroid = np.squeeze(centroid)
                score = cosine_similarity(
                    centroid.reshape(1, -1), emb.reshape(1, -1)
                )[0, 0]
                if score > best_score:
                    best_score = score
                    best_spk = spk_id
            except Exception as e:
                print(f"Error computing similarity for {spk_id}: {e}")

        # `best_spk is not None` guards the no-candidate case: without it a
        # threshold <= -1.0 would try to update a non-existent `None` speaker.
        if best_spk is not None and best_score >= threshold:
            assignment[seg_id] = best_spk
            existing_db[best_spk].append(emb)
            centroids[best_spk] = np.mean(np.stack(existing_db[best_spk]), axis=0)
        else:
            new_spk = f"SPEAKER_{next_idx:02d}"
            next_idx += 1
            assignment[seg_id] = new_spk
            existing_db[new_spk] = [emb]
            centroids[new_spk] = emb

    return assignment


In [23]:
def apply_speaker_reid_mapping(combined: list[dict], speaker_map: dict) -> list[dict]:
    """
    Replace diarization speaker labels in `combined` with re-identified speaker IDs.

    Args:
        combined: List of single-entry dictionaries `{speaker_label: text}`;
            only the first key/value pair of each dictionary is used and empty
            dictionaries are skipped.
        speaker_map: `{segment_id: reidentified_speaker_id}` where segment ids
            have the form `"<speaker_label>_<start>_<end>"`.

    For each entry, every segment id belonging to the entry's diarization label
    votes for its re-identified speaker and the most frequent one is used.
    Matching requires the label to be followed by `"_"`, so that e.g.
    "SPEAKER_1" does not pick up the segments of "SPEAKER_10". Ties are broken
    deterministically in favour of the speaker encountered first in
    `speaker_map`. Labels with no matching segment are kept unchanged, and
    entries whose label is `None` or not a string are reported as "UNKNOWN".

    Returns:
        list[dict]: Entries in input order with the replaced speaker labels.
    """
    reid_combined = []
    for entry in combined:
        if not entry:
            continue

        original_speaker, text = next(iter(entry.items()))

        # Fallback for unknown speaker labels
        if not isinstance(original_speaker, str):
            reid_combined.append({"UNKNOWN": text})
            continue

        # Tally the re-identified speakers of this label's segments. A dict
        # keeps first-seen order, which makes the tie-break reproducible.
        prefix = original_speaker + "_"
        votes = {}
        for segment_id, reid_speaker in speaker_map.items():
            if isinstance(segment_id, str) and segment_id.startswith(prefix):
                votes[reid_speaker] = votes.get(reid_speaker, 0) + 1

        assigned_speaker = max(votes, key=votes.get) if votes else original_speaker
        reid_combined.append({assigned_speaker: text})

    return reid_combined


### ── TOP‐LEVEL PIPELINES ───────────────────────────────────

In [24]:
def process_entire_audio(input_audio_path: str, existing_db: dict, temp_dir: str = "/content/") -> dict:
    """
    Run the full pipeline on one audio file and return every intermediate result.

    Steps: preprocess to WAV, diarize, transcribe, combine speakers with text,
    refine the text with Gemini, extract speaker embeddings, re-identify
    speakers against `existing_db` (cosine threshold 0.7), and relabel the
    refined dialogue with the re-identified speaker ids.

    Args:
        input_audio_path: Path to the input audio file.
        existing_db: Speaker database `{speaker_id: [embeddings]}`; updated
            in-place by the re-identification step.
        temp_dir: Directory where the preprocessed WAV is written.

    Returns:
        dict with keys "processed_wav", "diarization", "embeddings",
        "speaker_mapping", "transcription_raw", "combined", "refined" and
        "reid_combined" (the refined dialogue with re-identified speaker
        labels, matching the key returned by the variant without refinement).
    """
    # Step 1: Preprocess and diarize
    processed_wav = preprocess_audio_and_save(input_audio_path, output_dir=temp_dir)
    diar = diarization_func(processed_wav, pipeline)

    # Step 2: Transcribe and combine
    transcription = transcribe_audio(processed_wav, whisper_model)
    combined = combine_diarization_transcription(diar, transcription)

    # Step 3: Refinement with Gemini
    refined = process_refinement_with_gemini(combined, gemini_model)

    # Step 4: Extract embeddings and reidentify speakers
    embeddings = extract_speaker_embeddings(diar, processed_wav)
    speaker_map = reidentify_speakers(embeddings, existing_db, threshold=0.7)

    # Step 5: Apply reidentification to refined text
    reid_combined = apply_speaker_reid_mapping(refined, speaker_map)

    # Step 6: Return all intermediate artefacts plus the relabelled dialogue
    return {
        "processed_wav": processed_wav,
        "diarization": diar,
        "embeddings": embeddings,
        "speaker_mapping": speaker_map,
        "transcription_raw": transcription,
        "combined": combined,
        "refined": refined,
        "reid_combined": reid_combined,
    }

In [25]:
def process_entire_audio_minimal(input_audio_path: str, existing_db: dict, temp_dir: str = "/content/") -> list[dict]:
    """
    Run the full pipeline (including Gemini refinement) and return only the final dialogue.

    The audio is preprocessed into `temp_dir`, diarized, transcribed and
    combined into speaker/text entries; the text is refined with Gemini;
    speakers are re-identified against `existing_db` (updated in-place,
    threshold 0.7). The result is the refined dialogue as a list of
    `{speaker_id: text}` dictionaries carrying the re-identified labels.
    """
    # Audio -> normalised WAV -> speaker turns
    wav_path = preprocess_audio_and_save(input_audio_path, output_dir=temp_dir)
    speaker_turns = diarization_func(wav_path, pipeline)

    # Speech -> text, attributed to speakers
    whisper_result = transcribe_audio(wav_path, whisper_model)
    dialogue = combine_diarization_transcription(speaker_turns, whisper_result)

    # Text clean-up with Gemini
    polished = process_refinement_with_gemini(dialogue, gemini_model)

    # Voice embeddings -> stable speaker identities
    voiceprints = extract_speaker_embeddings(speaker_turns, wav_path)
    relabel = reidentify_speakers(voiceprints, existing_db, threshold=0.7)

    return apply_speaker_reid_mapping(polished, relabel)

In [26]:
def process_entire_audio_without_refinement(input_audio_path: str, existing_db: dict, temp_dir: str = "/content/") -> dict:
    """
    Run the pipeline without the Gemini refinement step and return all artefacts.

    The audio is preprocessed into `temp_dir`, diarized, transcribed and
    combined into speaker/text entries; speakers are then re-identified against
    `existing_db` (updated in-place, threshold 0.7) and the combined dialogue
    is relabelled accordingly.

    Returns:
        dict with keys "processed_wav", "diarization", "embeddings",
        "speaker_mapping", "transcription_raw", "combined" and "reid_combined".
    """
    # Audio -> normalised WAV -> speaker turns
    wav_path = preprocess_audio_and_save(input_audio_path, output_dir=temp_dir)
    speaker_turns = diarization_func(wav_path, pipeline)

    # Speech -> text, attributed to speakers
    whisper_result = transcribe_audio(wav_path, whisper_model)
    dialogue = combine_diarization_transcription(speaker_turns, whisper_result)

    # Voice embeddings -> stable speaker identities
    voiceprints = extract_speaker_embeddings(speaker_turns, wav_path)
    relabel = reidentify_speakers(voiceprints, existing_db, threshold=0.7)

    # Relabel the raw (unrefined) dialogue
    relabelled_dialogue = apply_speaker_reid_mapping(dialogue, relabel)

    result = {
        "processed_wav": wav_path,
        "diarization": speaker_turns,
        "embeddings": voiceprints,
        "speaker_mapping": relabel,
        "transcription_raw": whisper_result,
        "combined": dialogue,
        "reid_combined": relabelled_dialogue,
    }
    return result

In [27]:
def process_entire_audio_minimal_without_refinement(input_audio_path: str, existing_db: dict, temp_dir: str = "/content/") -> list[dict]:
    """Run the pipeline without Gemini refinement and return only the remapped transcript.

    Stages, in order: preprocessing, diarization, Whisper transcription,
    diarization/transcription merge, speaker-embedding extraction, speaker
    re-identification, and remapping of the merged (unrefined) transcript.

    Args:
        input_audio_path: Path of the audio file to process.
        existing_db: Speaker database handed to ``reidentify_speakers``.
            NOTE(review): callers appear to rely on that helper updating this
            dict in place -- confirm against its implementation.
        temp_dir: Directory passed to ``preprocess_audio_and_save`` as
            ``output_dir``.

    Returns:
        Whatever ``apply_speaker_reid_mapping`` produces for the merged
        transcript (annotated as a list of dicts).
    """
    clean_audio = preprocess_audio_and_save(input_audio_path, output_dir=temp_dir)
    turns = diarization_func(clean_audio, pipeline)

    # Merge the Whisper output with the diarized segments.
    whisper_result = transcribe_audio(clean_audio, whisper_model)
    merged = combine_diarization_transcription(turns, whisper_result)

    # Match this file's speakers against the database (similarity threshold 0.7).
    speaker_vectors = extract_speaker_embeddings(turns, clean_audio)
    label_map = reidentify_speakers(speaker_vectors, existing_db, threshold=0.7)

    # The remapping is applied to the merged transcript; no refinement happens here.
    return apply_speaker_reid_mapping(merged, label_map)

## ──Flask Server Code───────────────────────────────────

In [28]:
# from flask import Flask, request, jsonify
# import os
# import pickle

# app = Flask(__name__)
# os.makedirs("/content/uploads", exist_ok=True)

# # ── (A) Load or initialize speaker embedding database ─────────
# DB_PATH = "/content/speaker_db.pkl"
# if os.path.exists(DB_PATH):
#     with open(DB_PATH, "rb") as f:
#         existing_db = pickle.load(f)
# else:
#     existing_db = {}  # { speaker_id: [embedding_vecs...] }

# # ── (B) Root endpoint ─────────
# @app.route("/", methods=["GET"])
# def index():
#     return "✅ L’API is active. Use POST <strong>/process_audio</strong> for Full‐file processing or <strong>/process_audio_chunk</strong> for Real‐time processing."

# # ── (C) Full‐file processing endpoint ─────────────────────────
# @app.route("/process_audio", methods=["POST"])
# def process_audio_endpoint():
#     """
#     Expects multipart/form‐data with a file field named 'file'.
#     Returns JSON with:
#       - speaker_mapping: { segment_id: speaker_id, … }
#       - combined: [ {'SPEAKER_00': text}, … ]
#       - refined:  [ {'SPEAKER_00': text}, … ]
#     """
#     if "file" not in request.files:
#         return jsonify({"error": "No file part"}), 400

#     file = request.files["file"]
#     if file.filename == "":
#         return jsonify({"error": "No selected file"}), 400

#     # 1) Save uploaded file
#     save_path = os.path.join("/content/uploads", file.filename)
#     file.save(save_path)

#     # 2) Run pipeline (this updates existing_db in memory)
#     result = process_entire_audio(save_path, existing_db, temp_dir="/content/uploads")

#     # 3) Persist updated DB
#     with open(DB_PATH, "wb") as f:
#         pickle.dump(existing_db, f)

#     # 4) Build minimal JSON response
#     payload = {
#         "speaker_mapping": result["speaker_mapping"],
#         "combined": result["combined"],
#         "refined": result["refined"]
#     }
#     return jsonify(payload), 200

# # ── (D) Chunked / Real‐time processing endpoint ─────────────
# @app.route('/process_audio_chunk', methods=['POST'])
# def process_audio_chunk_api():
#     """
#     Expects JSON payload { "audio_data_base64": "<base64‐encoded WAV chunk>" }
#     Returns { status: "success", refined_transcription_chunk: [ {speaker: text}, … ] }
#     """
#     global existing_db

#     if not request.json or 'audio_data_base64' not in request.json:
#         return jsonify({"error": "Missing audio_data_base64 in JSON payload"}), 400

#     audio_data_base64 = request.json['audio_data_base64']

#     try:
#         import base64, tempfile
#         # Decode base64 into a temporary file
#         audio_bytes = base64.b64decode(audio_data_base64)
#         temp_input = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
#         temp_input.write(audio_bytes)
#         temp_input_path = temp_input.name
#         temp_input.close()

#         # # 1) Preprocess chunk
#         # processed_path = preprocess_audio_and_save(
#         #     temp_input_path,
#         #     output_dir=tempfile.gettempdir(),
#         #     output_filename=f"processed_{os.path.basename(temp_input_path)}"
#         # )

#         # # 2) Diarize
#         # diar = diarization_func(processed_path, ai_transcriber.pipeline)

#         # # 3) Transcribe
#         # transcription = transcribe_audio(processed_path, ai_transcriber.whisper_model)

#         # # 4) Combine
#         # chunk_speaker_texts = combine_diarization_transcription(diar, transcription)

#         # # 5) (Optional) Re‐identify speakers on this chunk
#         # embeddings = extract_speaker_embeddings(diar, processed_path)
#         # speaker_map = reidentify_speakers(embeddings, existing_db, threshold=0.6)
#         # globalized = apply_speaker_reid_mapping(chunk_speaker_texts, speaker_map)

#         # # 6) Refine
#         # refined_chunk = process_refinement_with_gemini(globalized, ai_transcriber.gemini_model)
#         refined_chunk = process_entire_audio_minimal(temp_input_path, existing_db, temp_dir="/content/uploads")

#         # 7) Update DB persistently
#         with open(DB_PATH, "wb") as f:
#             pickle.dump(existing_db, f)

#         # 8) Clean up temp files
#         os.remove(temp_input_path)

#         return jsonify({
#             "status": "success",
#             "refined_transcription_chunk": refined_chunk
#         }), 200

#     except Exception as e:
#         import traceback
#         print("Error processing audio chunk:")
#         print(traceback.format_exc())
#         # Ensure temp cleanup
#         if 'temp_input_path' in locals() and os.path.exists(temp_input_path):
#             os.remove(temp_input_path)
#         return jsonify({"error": str(e), "trace": traceback.format_exc()}), 500

# # ── (E) Full‐file processing endpoint without refinement ─────────────
# @app.route("/process_audio_no_refine", methods=["POST"])
# def process_audio_no_refine():
#     """
#     Expects multipart/form‐data with a file field named 'file'.
#     Returns JSON with:
#       - speaker_mapping: { segment_id: speaker_id, … }
#       - combined:      [ {'SPEAKER_00': text}, … ]
#       - reid_combined: [ {'SPEAKER_XX': text}, … ]
#     (Everything but Gemini refinement.)
#     """
#     global existing_db

#     if "file" not in request.files:
#         return jsonify({"error": "No file part"}), 400

#     file = request.files["file"]
#     if file.filename == "":
#         return jsonify({"error": "No selected file"}), 400

#     # 1) Save uploaded file
#     save_path = os.path.join("/content/uploads", file.filename)
#     file.save(save_path)

#     try:
#         # 2) Call the “without_refinement” pipeline
#         result = process_entire_audio_without_refinement(save_path, existing_db, temp_dir="/content/uploads")

#         # 3) Persist the updated speaker‐DB
#         with open(DB_PATH, "wb") as f:
#             pickle.dump(existing_db, f)

#         # 4) Build minimal response
#         payload = {
#             "speaker_mapping": result["speaker_mapping"],
#             "combined": result["combined"],
#             "reid_combined": result["reid_combined"]
#         }
#         return jsonify(payload), 200

#     except Exception as e:
#         import traceback
#         print("Error in /process_audio_no_refine:", traceback.format_exc())
#         return jsonify({"error": str(e), "trace": traceback.format_exc()}), 500


# # ── (F) Chunked / Real‐time processing endpoint without refinement ─────────────
# @app.route("/process_audio_minimal_no_refine", methods=["POST"])
# def process_audio_minimal_no_refine():
#     """
#     Expects multipart/form‐data with a file field named 'file'.
#     Returns JSON with:
#       - reid_combined: [ {'SPEAKER_XX': text}, … ]
#     (Minimal pipeline without Gemini refinement.)
#     """
#     global existing_db

#     if "file" not in request.files:
#         return jsonify({"error": "No file part"}), 400

#     file = request.files["file"]
#     if file.filename == "":
#         return jsonify({"error": "No selected file"}), 400

#     # 1) Save uploaded file
#     save_path = os.path.join("/content/uploads", file.filename)
#     file.save(save_path)

#     try:
#         # 2) Call the minimal “without_refinement” pipeline
#         reid_combined = process_entire_audio_minimal_without_refinement(save_path, existing_db, temp_dir="/content/uploads")

#         # 3) Persist the updated speaker‐DB
#         with open(DB_PATH, "wb") as f:
#             pickle.dump(existing_db, f)

#         # 4) Build minimal response
#         payload = {
#             "reid_combined": reid_combined
#         }
#         return jsonify(payload), 200

#     except Exception as e:
#         import traceback
#         print("Error in /process_audio_minimal_no_refine:", traceback.format_exc())
#         return jsonify({"error": str(e), "trace": traceback.format_exc()}), 500


# # ── (G) Health check endpoint ────────────────────────────────
# @app.route("/health", methods=["GET"])
# def health_check():
#     return jsonify({"status": "ok"}), 200


## ──Ngrok deployment───────────────────────────────────

In [29]:
# from threading import Thread
# from pyngrok import ngrok
# import time


# def run_flask():
#     app.run(host="0.0.0.0", port=5000, use_reloader=False)

# # Start Flask in a background thread
# flask_thread = Thread(target=run_flask, daemon=True)
# flask_thread.start()

# # Give Flask a moment to spin up
# time.sleep(1)

# # Ngrok token
# ngrok.set_auth_token("<YOUR_NGROK_AUTHTOKEN>")  # never commit a real token; load it from an environment variable instead

# # Start ngrok tunnel on port 5000
# public_url = ngrok.connect(5000, bind_tls=True)
# print(f"🔗 Public URL: {public_url}")

In [30]:
# # Stop the ngrok tunnel
# ngrok.disconnect(public_url)
# print("Ngrok tunnel has been closed.")

In [31]:
# #Check for running tunnels before starting a new one
# print("Current tunnels:")
# for t in ngrok.get_tunnels():
#     print("-", t.public_url)


In [32]:
# # Kill ALL running tunnels and the background agent process
# ngrok.kill()

# print("✅ ngrok agent and all tunnels have been forcefully closed.")

## ──Tests───────────────────────────────────

### Test Full‐File Upload `/process_audio`

In [33]:
# import requests

# NGROK_URL = "http://abcdef1234.ngrok.io"
# NGROK_URL = "https://107f-34-127-85-46.ngrok-free.app/"
# filepath = "/content/ خاص تكون طارزان عايش فالغابة .. ريم شباط تنتقد معايير الدعم المباشر ووزيرة المالية ترد بالأرقام.mp3"  # Upload or place a test file there

# with open(filepath, "rb") as f:
#     files = { "file": f }
#     r = requests.post(f"{NGROK_URL}/process_audio", files=files)

# print(r.status_code, r.json())

### Test Streaming Chunk `/process_audio_chunk`

In [34]:
# import base64, requests

# NGROK_URL = "http://abcdef1234.ngrok.io"
# NGROK_URL = "https://e4c2-34-143-160-76.ngrok-free.app"
# chunk_path = "/content/some_short_chunk.wav"
# chunk_path = "/content/ خاص تكون طارزان عايش فالغابة .. ريم شباط تنتقد معايير الدعم المباشر ووزيرة المالية ترد بالأرقام.mp3"


# with open(chunk_path, "rb") as f:
#     b64 = base64.b64encode(f.read()).decode("utf-8")

# payload = { "audio_data_base64": b64 }
# r = requests.post(f"{NGROK_URL}/process_audio_chunk", json=payload)
# print(r.status_code, r.json())


In [35]:
# print(requests.get(f"{NGROK_URL}/"))

### Test Full‐file processing without refinement `/process_audio_no_refine`

In [36]:
# import requests

# NGROK_URL = "https://1e79-34-127-85-46.ngrok-free.app/"
# filepath = "/content/ خاص تكون طارزان عايش فالغابة .. ريم شباط تنتقد معايير الدعم المباشر ووزيرة المالية ترد بالأرقام.mp3"  # Upload or place a test file there

# with open(filepath, "rb") as f:
#     files = { "file": f }
#     r = requests.post(f"{NGROK_URL}/process_audio_no_refine", files=files)

# print(r.status_code, r.json())

### Test Streaming Chunk Without Refinement `/process_audio_minimal_no_refine`

In [57]:
# import base64, requests

# NGROK_URL = "https://e4c2-34-143-160-76.ngrok-free.app"

# chunk_path = "/content/ خاص تكون طارزان عايش فالغابة .. ريم شباط تنتقد معايير الدعم المباشر ووزيرة المالية ترد بالأرقام.mp3"


# with open(chunk_path, "rb") as f:
#     b64 = base64.b64encode(f.read()).decode("utf-8")

# payload = { "audio_data_base64": b64 }
# r = requests.post(f"{NGROK_URL}/process_audio_minimal_no_refine", json=payload)
# print(r.status_code, r.json())

### Testing LLMs

In [58]:
# Manual end-to-end run of the refinement path on a single sample recording.
# Every name bound here is a notebook global; `refined` is read by the next cell.
audio_file = "/teamspace/studios/this_studio/my_transcriber_project/data/ خاص تكون طارزان عايش فالغابة .. ريم شباط تنتقد معايير الدعم المباشر ووزيرة المالية ترد بالأرقام.mp3"
# Rebind audio_file to whatever the preprocessing step returns; all later calls use that value.
audio_file = preprocess_audio_and_save(audio_file)
# Diarize with the module-level pyannote `pipeline`.
diarization = diarization_func(audio_file,pipeline)
# Transcribe the same file with the module-level Whisper model.
transcription = transcribe_audio(audio_file,whisper_model)
# Merge the diarization and transcription results.
speaker_texts = combine_diarization_transcription(diarization, transcription)
# Pass the merged transcript through the Gemini refinement step.
refined = process_refinement_with_gemini(speaker_texts,gemini_model)

# Dump the refined result as-is for inspection.
print(refined)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[{'SPEAKER_02': 'سيدة الوزيرة، العديد من الأسر الفقيرة والهشة تقصوا من الدعم حيث تساكنين بوحدهم، وسبق لينا نبهناكم بهذا الأمر ولا مجيب. بحال أرمل، رجل مسن، أو شخص في وضعية إعاقة. ما معنى الحكومة تعطي الدعم وتزيد تقطع لهم 140 درهم ديال الدعم والتضامن، وتزيد تقطع لهم البنكة بين 28 إلى 20 درهم؟ وشي معايير ما أنزل الله بها من سلطان. بحال إلى عندك تعبئة ديال خمسة دراهم، راك لاباس عليك. البوطا، الماء والضوء، لاباس عليك. يعني أسيدة الوزيرة تكون ما عايش في الغابة، عاد باش غيمكّن لك تستافد من هذا الدعم. والطامة الكبرى أن الحالة المستورة كتأخذ 500 درهم ديال الدعم، في حين مواطن بكرامته كتحرموه هذا الدعم الذي لا يسمن ولا يغني من جوع. وما فهمناش كيفاش الحكومة تحيد برامج كتستافد من رعاية ملكية سامية، بحال برنامج "مليون محفظة" لأبناء الأسر الفقيرة، بحجة أنهم كيشدوا الدعم. وكاين أسر حصلت على الدعم لمدة شهر أو لشهرين وطلع لهم المؤشر. زعما واش أسيدة الوزيرة غيكونوا داروا لاباس في شهرين؟ راكم خرجتوا عباد الله تسعى. اتقوا الله في هذه البلاد وفي هذا الشعب. ونقطة أخيرة أسيدة الوزيرة، بغينا توضيح بخصوص التلا

In [39]:
# Print each entry of the refined transcript on its own line as "<key>  :  <value>".
for turn in refined:
    for speaker, text in turn.items():
        print(f"{speaker}  :  {text}")

SPEAKER_02  :  سيدة الوزيرة العديد من الأسر الفقيرة والهشة تقصو من الدعم حيث تساكنين بوحدهم  وصبقين ببهتكم بهذا الأمر ولا موجيب بحل أرمال رجل موسين أو شخص فضعية إعاقة  ما معنى الحكومة تعطي الدعم وتقطعهم 140 درهم دي الأم وتضامن  وتزيد تقطعهم البنكة بين 28 إلى 20 درهم وشي معايير ما أتى الله بها من سلطان  حال عندك تعبئة الخمسة درهم رك لا باس عليك البوطاء الماء والضوء لا باس عليك  يعني سيدة الوزيرة تكون طارزة عايش في الغابة عد بش غيمك لك تسافد من هذا الدعم  والطمى الكبرى أن الحولي المستورة كياخد 500 درهم ديال الدعم  في حين مواطن بكارامتو كتحرموهم هذا الدعم الذي لا يسمينو ولا يغني من جوع  وما فهمناش كيفاش الحكومة تحيد برامج كتحضاب رعاية ملكية السامية  بحال برنامج مليون محفظة لأبناء الأسر الفقيرة بحجة أنهم يشدوا الدعم  وكين أسر حصلت على الدعم لمدة شهر أو لشهرين وطلع لهم المؤشر  زعمال سيدة الوزيرة غيكونوا داروا لبس في شهرين  راكم خرجتوا عباد الله تسعى تقعوا الله فاد البلاد وفاد الشاب  ونقطة آخيرة سيدة الوزيرة بغينا توضيح بخصوص تلاعب وزن البوطاء الكبيرة التي تزادت عشرات درهم  ونقصات بين كيلو و