In [2]:
import torch
import torchaudio
import numpy as np
import tempfile
import soundfile as sf
import noisereduce as nr
import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from pyannote.audio import Pipeline

from dotenv import load_dotenv
load_dotenv()
import os

# Load Whisper
processor = WhisperProcessor.from_pretrained("./whisper-small-en", language="English", task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained("./whisper-small-en").to("cuda:6")
model.generation_config.forced_decoder_ids = None

# Load diarization pipeline (no token needed for this model)
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization@2.1",
        use_auth_token=os.getenv("hugging_face_token"))

# Resample to 16kHz
def resample_audio(waveform, orig_sr):
    if orig_sr != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=16000)
        waveform = resampler(torch.tensor(waveform))
    return waveform.numpy(), 16000

# Denoise audio
def denoise(waveform, sample_rate):
    noise_clip = waveform[:int(0.5 * sample_rate)]
    return nr.reduce_noise(y=waveform, sr=sample_rate, y_noise=noise_clip)

# Save waveform to temp WAV file for diarization
def save_audio_temp(waveform, sample_rate):
    tmpfile = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    sf.write(tmpfile.name, waveform, sample_rate)
    return tmpfile.name

# Transcribe a waveform chunk using Whisper
def transcribe_chunk(chunk_tensor, sample_rate=16000):
    inputs = processor(chunk_tensor.numpy(), sampling_rate=sample_rate, return_tensors="pt", padding=True)
    input_features = inputs["input_features"].to("cuda:6")
    predicted_ids = model.generate(input_features=input_features)
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

# Main diarized transcription function
def diarized_transcribe(audio):
    try:
        sample_rate, waveform = audio

        # Normalize if int16
        if waveform.dtype == "int16":
            waveform = waveform.astype("float32") / 32768.0

        # Denoise and resample
        waveform = denoise(waveform, sample_rate)
        waveform, sample_rate = resample_audio(waveform, sample_rate)
        audio_path = save_audio_temp(waveform, sample_rate)

        # Run diarization
        diarization = pipeline(audio_path)
        full_audio_tensor = torch.tensor(waveform)

        # Build transcript
        transcript = ""
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            start = int(turn.start * sample_rate)
            end = int(turn.end * sample_rate)
            segment = full_audio_tensor[start:end]

            if segment.numel() == 0:
                continue

            text = transcribe_chunk(segment)
            transcript += f"[{speaker} | {turn.start:.1f}s - {turn.end:.1f}s]: {text}\n"

        return transcript.strip()

    except Exception as e:
        return f"Error during processing: {str(e)}"

# Gradio UI
gr.Interface(
    fn=diarized_transcribe,
    inputs=gr.Audio(type="numpy", label="Upload or Record Audio"),
    outputs=gr.Textbox(label="Speaker-attributed Transcription"),
    title="Whisper + Diarization",
    description="Denoises audio, diarizes speakers, and transcribes speech with Whisper.",
    allow_flagging="never",
    live=True
).launch(share=True)


INFO:speechbrain.utils.quirks:Applied quirks (see `speechbrain.utils.quirks`): [disable_jit_profiling, allow_tf32]
INFO:speechbrain.utils.quirks:Excluded quirks specified by the `SB_DISABLE_QUIRKS` environment (comma-separated list): []
INFO:datasets:PyTorch version 2.6.0+cu118 available.
Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.5.1. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../.cache/torch/pyannote/models--pyannote--segmentation/snapshots/c4c8ceafcbb3a7a280c2d357aee9fbc9b0be7f9b/pytorch_model.bin`
INFO:speechbrain.utils.fetching:Fetch hyperparams.yaml: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached


Model was trained with pyannote.audio 0.0.1, yours is 3.3.2. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.6.0+cu118. Bad things might happen unless you revert torch to 1.x.


INFO:speechbrain.utils.fetching:Fetch custom.py: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
INFO:speechbrain.utils.fetching:Fetch embedding_model.ckpt: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
INFO:speechbrain.utils.fetching:Fetch mean_var_norm_emb.ckpt: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
INFO:speechbrain.utils.fetching:Fetch classifier.ckpt: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
INFO:speechbrain.utils.fetching:Fetch label_encoder.txt: Fetching from HuggingFace Hub 'speechbrain/spkrec-ecapa-voxceleb' if not cached
INFO:speechbrain.utils.parameter_transfer:Loading pretrained files for: embedding_model, mean_var_norm_emb, classifier, label_encoder
INFO:httpx:HTTP Request: GET http://127.0.0.1:7860/gradio_api/startup-events "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: HEAD http://127.0.0.1:7860/ "HTTP/1.1 200 OK"


* Running on local URL:  http://127.0.0.1:7860


INFO:httpx:HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: GET https://api.gradio.app/v3/tunnel-request "HTTP/1.1 200 OK"


* Running on public URL: https://67827e642aeb6b7db7.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


INFO:httpx:HTTP Request: HEAD https://67827e642aeb6b7db7.gradio.live "HTTP/1.1 200 OK"




The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/mnt/nvme_disk2/User_data/nb57077k/miniconda3/envs/ax/lib/python3.10/site-packages/uvicorn/protocols/http/h11_impl.py", line 403, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
  File "/mnt/nvme_disk2/User_data/nb57077k/miniconda3/envs/ax/lib/python3.10/site-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
    return await self.app(scope, receive, send)
  File "/mnt/nvme_disk2/User_data/nb57077k/miniconda3/envs/ax/lib/python3.10/site-packages/fastapi/applications.py", line 1054, in __call__
    await super().__call__(scope, receive, send)
  File "/mnt/nvme_disk2/User_data/nb57077k/miniconda3/envs/ax/lib/python3.10/site-packages/starlett