#| default_exp ai

In [None]:
import os
import logging
import json

import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import StreamingResponse, HTMLResponse
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    AlgoOptions,
    SileroVadOptions,
    audio_to_bytes,
)
import torch
import subprocess
from faster_whisper import WhisperModel

def get_device(force_cpu=False):
    """Pick the torch device string to run on.

    Args:
        force_cpu: When true, skip accelerator detection and return "cpu".

    Returns:
        "cuda" if CUDA is available, otherwise "mps" if the MPS backend is
        available, otherwise "cpu".
    """
    if not force_cpu:
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            # Clear the MPS allocator cache before handing the device out.
            torch.mps.empty_cache()
            return "mps"
    return "cpu"
    
def get_torch_and_np_dtypes(device, use_bfloat16=False):
    """Return the (torch dtype, numpy dtype) pair to use on ``device``.

    Args:
        device: Device string; "cuda" and "mps" select half precision,
            anything else selects single precision.
        use_bfloat16: On "cuda"/"mps", choose ``torch.bfloat16`` instead of
            ``torch.float16``. Ignored for other devices.

    Returns:
        A ``(torch_dtype, np_dtype)`` tuple. The numpy dtype is ``np.float16``
        on accelerators (also when bfloat16 is requested) and ``np.float32``
        otherwise.
    """
    if device not in ("cuda", "mps"):
        return torch.float32, np.float32

    half_dtype = torch.bfloat16 if use_bfloat16 else torch.float16
    return half_dtype, np.float16

def cuda_version_check():
    """Report the CUDA version and the name of GPU 0.

    The version is taken from ``nvcc --version`` when the tool is on PATH and
    its output contains a ``release X.Y`` token; otherwise PyTorch's built-in
    ``torch.version.cuda`` is used.

    Returns:
        ``(cuda_version, device_name)`` when CUDA is available, otherwise
        ``(None, None)``.
    """
    import re  # local import keeps this block self-contained

    if not torch.cuda.is_available():
        return None, None

    try:
        nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode(
            errors="replace"
        )
    except (OSError, subprocess.SubprocessError):
        # nvcc missing or exited non-zero: fall through to PyTorch's version.
        nvcc_output = ""

    # Positional splitting is fragile: depending on the toolkit, the last
    # tokens are "release 12.1, V12.1.105" or a trailing "Build cuda_..." line.
    # Match the "release X.Y" token explicitly instead.
    release = re.search(r"release\s+(\d+(?:\.\d+)*)", nvcc_output)
    cuda_version = release.group(1) if release else torch.version.cuda

    device_name = torch.cuda.get_device_name(0)
    return cuda_version, device_name


In [None]:


# Load settings from a local .env file (if any) before reading the environment.
load_dotenv()

logger = logging.getLogger(__name__)


UI_MODE = os.getenv("UI_MODE", "fastapi").lower() # gradio | fastapi
UI_TYPE = os.getenv("UI_TYPE", "base").lower() # base | screen
APP_MODE = os.getenv("APP_MODE", "local").lower() # local | deployed
# TURN_PROVIDER = os.getenv("TURN_PROVIDER", "hf-cloudflare") # Not needed for local development

# MODEL_ID is handed to faster_whisper.WhisperModel, which expects a model
# size name (e.g. "base"), a local directory, or a CTranslate2-converted Hub
# repo. The transformers checkpoint "openai/whisper-base" is not in that
# format, so the default is the equivalent faster-whisper size name.
MODEL_ID = os.getenv("MODEL_ID", "base")  # Use base model for faster processing
# faster-whisper takes language codes ("en", "fr", ...), not language names.
LANGUAGE = os.getenv("LANGUAGE", "en")

logger.info(f"""
    --------------------------------------
    Configuration (environment variables):
    - UI_MODE: {UI_MODE}
    - UI_TYPE: {UI_TYPE}
    - APP_MODE: {APP_MODE}

    - MODEL_ID: {MODEL_ID}
    - LANGUAGE: {LANGUAGE}
    --------------------------------------
""")

# Initialize faster_whisper model for better performance
device = get_device(force_cpu=False)
# faster-whisper runs on CTranslate2, which supports "cpu" and "cuda" but not
# Apple's "mps" backend, so any non-CUDA device is mapped to CPU inference.
whisper_device = "cuda" if device == "cuda" else "cpu"
if whisper_device != device:
    logger.warning(
        "Device %r is not supported by faster-whisper; using %r instead",
        device,
        whisper_device,
    )
transcribe_model = WhisperModel(
    MODEL_ID,
    device=whisper_device,
    # Use quantized model for CPU, float16 for GPU
    compute_type="float16" if whisper_device == "cuda" else "int8",
    download_root=None,
    local_files_only=False,
)

async def transcribe(audio: tuple[int, np.ndarray]):
    """Transcribe one captured utterance and yield the text.

    Args:
        audio: ``(sample_rate, samples)`` pair as delivered by the stream
            handler.

    Yields:
        A single ``AdditionalOutputs`` wrapping the stripped transcript text
        (all recognised segments joined with single spaces).
    """
    import io  # local import: only needed to wrap the encoded audio

    sample_rate, audio_array = audio
    logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")

    # Serialise the frame with fastrtc's helper, then expose it as a binary
    # file object: WhisperModel.transcribe accepts a path, a file-like object
    # or a waveform array, not a raw bytes value.
    audio_bytes = audio_to_bytes(audio)
    audio_file = io.BytesIO(audio_bytes)

    # Use faster_whisper for transcription
    segments, _ = transcribe_model.transcribe(
        audio_file,
        language=LANGUAGE,
        beam_size=1,  # Faster with beam_size=1
        condition_on_previous_text=False,  # Faster for real-time
    )

    # Collect all segments into text
    transcript_text = " ".join(segment.text.strip() for segment in segments)
    yield AdditionalOutputs(transcript_text.strip())


logger.info("Initializing FastRTC stream")

# Pause-detection tuning, lowered from the library defaults for snappier
# turn-taking. Defaults quoted below are the ones this file was written against.
_pause_options = AlgoOptions(
    # Seconds of audio per chunk handed to the VAD model (default 0.6).
    audio_chunk_duration=0.3,
    # Seconds of speech in a chunk that mark the start of talking (default 0.2).
    started_talking_threshold=0.05,
    # After talking has started, a chunk with less speech than this marks the
    # end of the utterance (default 0.1).
    speech_threshold=0.05,
    # Upper bound on continuous speech before the handler fires even without
    # a detected pause (default -inf).
    max_continuous_speech_s=10,
)

# Silero VAD tuning.
_vad_options = SileroVadOptions(
    # Probability above which audio counts as speech (default 0.5).
    threshold=0.4,
    # Final speech chunks shorter than this are discarded (default 250).
    min_speech_duration_ms=150,
    # Longer speech is split at the last silence over 100 ms, or just before
    # this limit (default float('inf')); applies inside the VAD algorithm.
    max_speech_duration_s=8,
    # Silence required at the end of a speech chunk before it is cut (default 2000).
    min_silence_duration_ms=200,
    # VAD window size; 512, 1024 or 1536 at a 16 kHz sample rate (default 1024).
    window_size_samples=512,
    # Padding added to each side of a final speech chunk (default 400).
    speech_pad_ms=100,
)

_pause_handler = ReplyOnPause(
    transcribe,
    algo_options=_pause_options,
    model_options=_vad_options,
)

stream = Stream(
    handler=_pause_handler,
    modality="audio",
    # Mode options: "send-receive" is bidirectional (the default), "send" is
    # client to server only, "receive" is server to client only.
    mode="send",
    additional_outputs=[gr.Textbox(label="Transcript")],
    # Append each new transcript fragment to the text already shown.
    additional_outputs_handler=lambda current, new: current + " " + new,
    rtc_configuration=None,  # Not needed for local development
)

# HTTP application: static assets plus the routes FastRTC registers itself.
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
stream.mount(app)

@app.get("/")
async def index():
    """Serve the UI page selected by ``UI_TYPE`` with the RTC config injected.

    The ``__INJECTED_RTC_CONFIG__`` placeholder in the HTML is replaced by the
    JSON-encoded RTC configuration (currently always ``null``).

    Returns:
        An ``HTMLResponse`` containing the rendered page.
    """
    pages = {
        "base": "static/index.html",
        "screen": "static/index-screen.html",
    }
    page_path = pages.get(UI_TYPE)
    if page_path is None:
        # An unrecognised UI_TYPE previously left the page content unbound and
        # failed the request; serve the base UI instead.
        logger.warning("Unknown UI_TYPE %r; falling back to 'base'", UI_TYPE)
        page_path = pages["base"]

    # Context manager closes the file handle; the response is sent as UTF-8,
    # so read the page with the same encoding.
    with open(page_path, encoding="utf-8") as page_file:
        html_content = page_file.read()

    rtc_configuration = None  # Not needed for local development
    logger.info(f"RTC configuration: {rtc_configuration}")
    html_content = html_content.replace("__INJECTED_RTC_CONFIG__", json.dumps(rtc_configuration))
    return HTMLResponse(content=html_content)

@app.get("/transcript")
def _(webrtc_id: str):
    """Stream transcripts for one WebRTC connection as server-sent events.

    Args:
        webrtc_id: Identifier of the connection whose outputs to follow.

    Returns:
        A ``StreamingResponse`` of ``text/event-stream`` emitting one
        ``output`` event per transcript.
    """
    logger.debug(f"New transcript stream request for webrtc_id: {webrtc_id}")

    async def output_stream():
        try:
            async for output in stream.output_stream(webrtc_id):
                transcript = output.args[0]
                logger.debug(f"Sending transcript for {webrtc_id}: {transcript[:50]}...")
                # In the SSE wire format a blank line ends the event, so a
                # transcript with embedded line breaks must be sent as one
                # "data:" field per line. Single-line transcripts produce
                # exactly the same bytes as before.
                lines = transcript.replace("\r\n", "\n").replace("\r", "\n").split("\n")
                data_fields = "".join(f"data: {line}\n" for line in lines)
                yield f"event: output\n{data_fields}\n"
        except Exception as e:
            logger.error(f"Error in transcript stream for {webrtc_id}: {str(e)}")
            raise

    return StreamingResponse(output_stream(), media_type="text/event-stream")


if __name__ == "__main__":
    # Script entry point: launch either the Gradio UI or the FastAPI server,
    # depending on UI_MODE.

    server_name = os.getenv("SERVER_NAME", "localhost")
    # Environment variables are strings; both launchers need an integer port.
    port = int(os.getenv("PORT", "7860"))

    if UI_MODE == "gradio":
        logger.info("Launching Gradio UI")
        stream.ui.launch(
            server_port=port,
            server_name=server_name,
            ssl_verify=False,
            debug=True
        )
    else:
        import uvicorn
        logger.info("Launching FastAPI server")
        uvicorn.run(app, host=server_name, port=port)