In [None]:
# ------------------------------
# FastAPI Whisper Spanish-only API with 16kHz Resampling and WAV Audio Saving
# ------------------------------

import whisper
import nest_asyncio
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
import tempfile
import os
import numpy as np
import librosa
from datetime import datetime
import soundfile as sf
import warnings
from pydub import AudioSegment
import io

# Suppress warnings
# NOTE(review): these filters are process-wide — they hide *every*
# FutureWarning and UserWarning from any imported library, not just one noisy
# source. Narrow them (e.g. with `module=`) if diagnostics are needed.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Allow running in Jupyter / notebooks
# (nest_asyncio lets uvicorn.run, called at the bottom of this file, start its
# event loop even when a notebook kernel already has one running.)
nest_asyncio.apply()

# ---------------- Settings ----------------
LANGUAGE = "es"          # Force Spanish (passed as language= to model.transcribe)
MODEL_SIZE = "medium"    # Whisper checkpoint name passed to whisper.load_model
DEVICE = "cpu"           # device passed to whisper.load_model
TARGET_SAMPLE_RATE = 16000  # Whisper requires 16kHz
TARGET_CHANNELS = 1         # Mono audio (NOTE(review): not referenced in the code shown; conversions call set_channels(1) directly)

# Audio saving configuration
# When SAVE_AUDIO_FILES is False the save helpers return None without writing,
# and the audio-listing/playing/clearing endpoints reply that saving is disabled.
SAVE_AUDIO_FILES = True
# A relative path, so it resolves against the process working directory.
AUDIO_SAVE_DIR = "received_audio"
# Created unconditionally (even when saving is disabled) because the FastAPI
# app mounts this directory for static serving.
os.makedirs(AUDIO_SAVE_DIR, exist_ok=True)

# Loaded once at import time; process_audio_file uses this module-level
# `model` for every request.
print("Loading Whisper model...")
model = whisper.load_model(MODEL_SIZE, device=DEVICE)
print("Model loaded!")
print(f"Audio files will be saved to: {os.path.abspath(AUDIO_SAVE_DIR)}")

# ---------------- FastAPI ----------------
app = FastAPI()
# Accept cross-origin requests from any origin, with any method and header
# (presumably so a browser front-end hosted elsewhere can call the API).
# NOTE(review): wide open — together with the unauthenticated
# DELETE /audio_files endpoint, restrict allow_origins before exposing this
# server beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the audio files statically
# Every file in AUDIO_SAVE_DIR is downloadable at /audio/<filename>; this is
# the "url" field reported by GET /audio_files.
app.mount("/audio", StaticFiles(directory=AUDIO_SAVE_DIR), name="audio")

def save_audio_as_wav(audio_np, sample_rate, filename_prefix, transcribed_text):
    """
    Write *audio_np* into AUDIO_SAVE_DIR as a float32 WAV file.

    The file is named "<timestamp>_<filename_prefix>_<slug>.wav". The slug
    is built from the first 30 characters of *transcribed_text*: only
    letters, digits, spaces, '-' and '_' are kept, and spaces become '_'.
    If nothing survives, the slug is "no_text". Samples are cast to float32.
    If any magnitude exceeds 1.0, the signal is divided by its peak before
    writing.

    Returns the bare filename on success. Returns None when saving is
    disabled or any step fails; errors are printed, never raised.
    """
    if SAVE_AUDIO_FILES:
        try:
            stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")

            kept = [ch for ch in transcribed_text[:30] if ch.isalnum() or ch in " -_"]
            slug = "".join(kept).strip().replace(' ', '_') or "no_text"
            wav_name = f"{stamp}_{filename_prefix}_{slug}.wav"

            samples = audio_np if audio_np.dtype == np.float32 else audio_np.astype(np.float32)
            peak = np.max(np.abs(samples))
            if peak > 1.0:
                # Scale down so the loudest sample sits exactly at full scale.
                samples = samples / peak

            sf.write(os.path.join(AUDIO_SAVE_DIR, wav_name), samples, sample_rate)
            print(f"Audio saved: {wav_name}")
            return wav_name
        except Exception as err:
            print(f"Error saving WAV audio: {err}")
    return None

def convert_to_wav_16k_mono(file_path: str):
    """
    Decode an audio file into TARGET_SAMPLE_RATE mono float32 samples.

    pydub (ffmpeg-backed) is tried first because it copes with browser
    containers such as WebM. If that path fails for any reason, librosa is
    used as a fallback.

    Args:
        file_path: Path of the audio file on disk.

    Returns:
        (audio_np, sample_rate): a float32 numpy array and its sample rate.
        On both paths the audio has been resampled to TARGET_SAMPLE_RATE.

    Raises:
        RuntimeError: if both decoders fail. The message starts with
            "Failed to process audio file:" and includes BOTH underlying
            errors (the pydub error is usually the informative one, e.g. a
            missing ffmpeg). The librosa error is chained as __cause__.
    """
    try:
        print("Converting audio with pydub...")

        # Load audio with pydub (handles WebM, MP3, etc.)
        audio = AudioSegment.from_file(file_path)
        print(f"Original: {audio.frame_rate}Hz, {audio.channels} channels, {len(audio)}ms")

        if audio.channels > 1:
            audio = audio.set_channels(1)
            print("Converted to mono")

        if audio.frame_rate != TARGET_SAMPLE_RATE:
            audio = audio.set_frame_rate(TARGET_SAMPLE_RATE)
            print(f"Resampled to {TARGET_SAMPLE_RATE}Hz")

        # Round-trip through an in-memory WAV so soundfile can hand back a
        # numpy array.
        wav_bytes = io.BytesIO()
        audio.export(wav_bytes, format="wav")
        wav_bytes.seek(0)
        audio_np, sr = sf.read(wav_bytes)

        print(f"Converted: {sr}Hz, shape: {audio_np.shape}, dtype: {audio_np.dtype}")
        print(f"Duration: {len(audio_np)/sr:.2f}s")

        # Whisper is fed float32; soundfile's default read dtype may differ.
        if audio_np.dtype != np.float32:
            audio_np = audio_np.astype(np.float32)
            print("Converted dtype to float32")

        return audio_np, sr

    except Exception as pydub_error:
        print(f"Error converting with pydub: {pydub_error}")
        try:
            print("Falling back to librosa...")
            audio_np, original_sample_rate = librosa.load(file_path, sr=None, mono=True)
            print(f"Librosa loaded: {original_sample_rate}Hz, dtype: {audio_np.dtype}")

            if original_sample_rate != TARGET_SAMPLE_RATE:
                audio_np = librosa.resample(audio_np, orig_sr=original_sample_rate, target_sr=TARGET_SAMPLE_RATE)
                print(f"Resampled to {TARGET_SAMPLE_RATE}Hz")

            if audio_np.dtype != np.float32:
                audio_np = audio_np.astype(np.float32)
                print("Converted dtype to float32")

            return audio_np, TARGET_SAMPLE_RATE
        except Exception as librosa_error:
            # Report both failures: the caller surfaces this message to the
            # HTTP client, and the primary (pydub) cause was previously lost.
            raise RuntimeError(
                f"Failed to process audio file: {librosa_error} "
                f"(pydub error: {pydub_error})"
            ) from librosa_error

def prepare_audio_for_whisper(audio_np: np.ndarray) -> np.ndarray:
    """
    Coerce decoded samples into the float32, [-1, 1] form passed to Whisper.

    Args:
        audio_np: Decoded audio samples of any numeric dtype.

    Returns:
        A float32 array. If any sample magnitude exceeds 1.0, the whole
        signal is divided by its peak magnitude; otherwise values are left
        unchanged. An empty array is returned (as float32) rather than
        raising.
    """
    if audio_np.dtype != np.float32:
        audio_np = audio_np.astype(np.float32)
        print("Converted audio dtype to float32")

    # np.max has no identity for zero-size arrays and would raise a confusing
    # ValueError; an empty clip simply has nothing to normalize.
    if audio_np.size:
        peak = np.max(np.abs(audio_np))
        if peak > 1.0:
            audio_np = audio_np / peak
            print("Normalized audio to [-1, 1] range")

    print(f"Final audio shape: {audio_np.shape}, dtype: {audio_np.dtype}")
    return audio_np

def save_uploaded_file_as_wav(file_content: bytes, filename: str, transcribed_text: str):
    """
    Re-encode the raw uploaded bytes as a mono TARGET_SAMPLE_RATE WAV file.

    The client's original bytes are decoded again with pydub. This is
    different from writing the float32 samples the way save_audio_as_wav
    does. The result is downmixed to mono and resampled, so it is *not* a
    byte-for-byte copy of the upload.

    Args:
        file_content: Raw bytes exactly as uploaded.
        filename: Client-supplied upload name. Not used; kept for interface
            compatibility.
        transcribed_text: Transcript used to build a readable file name,
            "<timestamp>_upload_<slug>.wav".

    Returns:
        The bare WAV filename written to AUDIO_SAVE_DIR. Returns None if
        saving is disabled or anything fails; errors are printed, never
        raised.
    """
    if not SAVE_AUDIO_FILES:
        return None

    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")

        # Keep only filename-safe characters from the start of the transcript.
        safe_text = "".join(c for c in transcribed_text[:30] if c.isalnum() or c in (' ', '-', '_')).strip()
        safe_text = safe_text.replace(' ', '_') if safe_text else "no_text"

        wav_filename = f"{timestamp}_upload_{safe_text}.wav"
        wav_filepath = os.path.join(AUDIO_SAVE_DIR, wav_filename)

        audio = AudioSegment.from_file(io.BytesIO(file_content))

        if audio.channels > 1:
            audio = audio.set_channels(1)
        if audio.frame_rate != TARGET_SAMPLE_RATE:
            audio = audio.set_frame_rate(TARGET_SAMPLE_RATE)

        # pydub's export() opens the target path itself and returns that open
        # file handle; close it right away instead of leaking it to the GC.
        audio.export(wav_filepath, format="wav").close()

        print(f"Uploaded file saved as WAV: {wav_filename}")
        return wav_filename

    except Exception as e:
        print(f"Error saving uploaded file as WAV: {e}")
        return None

def process_audio_file(file_path: str, original_file_content: bytes, original_filename: str):
    """
    Decode, transcribe, and (optionally) archive one uploaded audio file.

    Steps:
      1. Decode *file_path* to mono TARGET_SAMPLE_RATE float32
         (convert_to_wav_16k_mono).
      2. Coerce dtype and range for Whisper (prepare_audio_for_whisper).
      3. Transcribe in LANGUAGE with the module-level Whisper `model`.
      4. Archive the processed samples and a re-encode of the raw upload.
         Both save helpers handle their own errors and return None on
         failure, so archiving never fails the request.

    Args:
        file_path: Temp file holding the upload; used for decoding.
        original_file_content: Raw uploaded bytes; re-encoded for archiving.
        original_filename: Client-supplied name, forwarded to the archive
            helper (which, as written, does not use it).

    Returns:
        The stripped transcription text ("" if Whisper returned no text).

    Raises:
        Exception: whatever decoding or transcription raises, re-raised
            after being logged.
    """
    try:
        audio_16k, sample_rate = convert_to_wav_16k_mono(file_path)
        audio_16k = prepare_audio_for_whisper(audio_16k)

        print("Transcribing with Whisper...")
        # Whisper defaults to fp16, which it cannot use on CPU (it falls back
        # to fp32 with a warning); request fp32 explicitly there instead of
        # relying on the global warning filters.
        result = model.transcribe(
            audio_16k,
            language=LANGUAGE,
            task="transcribe",
            verbose=False,
            fp16=(DEVICE != "cpu"),
        )

        transcribed_text = result.get("text", "").strip()
        print(f"Transcription: '{transcribed_text}'")

        # Best-effort archival; the returned filenames are not needed here.
        save_audio_as_wav(audio_16k, sample_rate, "processed", transcribed_text)
        save_uploaded_file_as_wav(original_file_content, original_filename, transcribed_text)

        return transcribed_text

    except Exception as e:
        print(f"Error processing audio file: {e}")
        raise

def _upload_suffix(filename) -> str:
    """
    Return a safe temp-file suffix derived from a client-supplied filename.

    Only the final extension of the basename is kept, and only if it is a
    non-empty run of ASCII letters/digits (e.g. ".webm", ".wav", ".gz" for
    "a.tar.gz"). Anything else falls back to ".webm": a missing filename,
    no extension, path separators or unusual characters.
    """
    extension = os.path.splitext(os.path.basename(filename or ""))[1]
    tail = extension[1:]
    if tail and tail.isascii() and tail.isalnum():
        return extension
    return ".webm"


@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    """
    Transcribe an uploaded audio file to Spanish text.

    The upload is written to a temporary file, because decoding works from a
    path. It is then processed, and the temporary file is removed on both
    the success and the failure path.

    Returns:
        200 {"text": <transcript>} on success,
        500 {"error": <message>} if anything fails.
    """
    # NOTE(review): process_audio_file is CPU-bound and runs directly on the
    # event loop, so requests are transcribed one at a time and the server is
    # unresponsive meanwhile. Offloading to a thread pool would likely need a
    # lock around the shared Whisper model — confirm it is safe for
    # concurrent use before changing this.
    tmp_path = None
    try:
        # UploadFile.filename may be None; the helper also strips any path
        # components or odd characters the client put in the name.
        suffix = _upload_suffix(file.filename)

        content = await file.read()

        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Record the path before writing so cleanup still happens if the
            # write itself fails (delete=False leaves the file behind).
            tmp_path = tmp.name
            tmp.write(content)

        print(f"Processing file: {file.filename} ({len(content)} bytes)")

        text = process_audio_file(tmp_path, content, file.filename)
        return JSONResponse({"text": text})

    except Exception as e:
        print(f"Transcription error: {e}")
        return JSONResponse({"error": str(e)}, status_code=500)

    finally:
        # Best-effort cleanup: report, but never fail the request over it.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError as cleanup_error:
                print(f"Could not remove temp file {tmp_path}: {cleanup_error}")

@app.get("/audio_files")
async def list_audio_files():
    """
    List the saved .wav recordings, sorted by filename in descending order.

    Each entry has the bare filename, its static URL (/audio/...), its
    direct download URL (/play/...), its size in bytes, and its mtime in
    ISO-8601. Names written by this service start with a timestamp, so
    descending name order puts the newest of those first.
    """
    if not SAVE_AUDIO_FILES:
        return JSONResponse({"error": "Audio file saving is disabled"})

    def describe(name):
        info = os.stat(os.path.join(AUDIO_SAVE_DIR, name))
        return {
            "filename": name,
            "url": f"/audio/{name}",
            "play_url": f"/play/{name}",
            "size": info.st_size,
            "modified": datetime.fromtimestamp(info.st_mtime).isoformat(),
        }

    try:
        wav_names = [
            name
            for name in os.listdir(AUDIO_SAVE_DIR)
            if os.path.isfile(os.path.join(AUDIO_SAVE_DIR, name)) and name.endswith('.wav')
        ]
        entries = [describe(name) for name in wav_names]
        entries.sort(key=lambda entry: entry["filename"], reverse=True)
        return JSONResponse({"audio_files": entries})
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)

@app.get("/play/{filename}")
async def play_audio(filename: str):
    """
    Serve one saved .wav file from AUDIO_SAVE_DIR as an audio/wav download.

    *filename* comes straight from the URL, so it must be a bare file name.
    Any value with path components is rejected with 404 rather than being
    resolved outside the directory. On Windows this includes names like
    "..\\secret.wav" or "C:x.wav", where backslashes and drive prefixes are
    path syntax. Directories named "*.wav" are also rejected.

    Returns:
        The file on success; 404 {"error": "File not found"} otherwise.
        If saving is disabled, {"error": "Audio file saving is disabled"}.
    """
    if not SAVE_AUDIO_FILES:
        return JSONResponse({"error": "Audio file saving is disabled"})

    # os.path.basename uses the platform's separators, so any name it would
    # shorten contains path syntax and could escape AUDIO_SAVE_DIR.
    is_bare_name = filename == os.path.basename(filename)
    filepath = os.path.join(AUDIO_SAVE_DIR, filename)

    if is_bare_name and filename.endswith('.wav') and os.path.isfile(filepath):
        return FileResponse(filepath, media_type='audio/wav', filename=filename)
    return JSONResponse({"error": "File not found"}, status_code=404)

@app.delete("/audio_files")
async def clear_audio_files():
    """
    Delete every regular file directly inside AUDIO_SAVE_DIR.

    Subdirectories are skipped. Every regular file is removed, not only .wav
    recordings. The first failure aborts the sweep and is reported as a 500
    with {"error": ...}.
    """
    if not SAVE_AUDIO_FILES:
        return JSONResponse({"error": "Audio file saving is disabled"})

    try:
        candidates = (os.path.join(AUDIO_SAVE_DIR, name) for name in os.listdir(AUDIO_SAVE_DIR))
        for path in candidates:
            if os.path.isfile(path):
                os.remove(path)
        return JSONResponse({"message": f"All audio files cleared from {AUDIO_SAVE_DIR}"})
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)

@app.get("/")
async def root():
    """Info endpoint: service name, running status, and a map of the available routes."""
    endpoints = {
        "transcribe": "POST /transcribe",
        "list_files": "GET /audio_files",
        "play_audio": "GET /play/{filename}",
        "clear_files": "DELETE /audio_files",
    }
    return dict(
        message="Whisper Spanish Transcription API",
        status="running",
        endpoints=endpoints,
    )

# ---------------- Run Server ----------------
if __name__ == "__main__":
    # Startup banner (one line per message), then serve on all interfaces, port 8000.
    print("\n".join([
        "Server starting...",
        f"Audio files will be saved as WAV to: {os.path.abspath(AUDIO_SAVE_DIR)}",
        "Access /audio_files endpoint to list saved files",
        "Access /play/filename.wav to directly play any audio file",
    ]))
    uvicorn.run(app, host="0.0.0.0", port=8000)