# In [None]:
from fastapi import FastAPI, WebSocket
import whisper
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import os
import tempfile
import socket
import threading

# Initialize FastAPI app
# The route decorators further down register their handlers on this instance.
app = FastAPI()

# Load Whisper model for transcription
# WHISPER_MODEL_PATH overrides which model is loaded; it defaults to "base".
whisper_model_path = os.getenv("WHISPER_MODEL_PATH", "base")
whisper_model = whisper.load_model(whisper_model_path)

# Load locally stored DeepSeek model for content correction
# DEEPSEEK_MODEL_PATH overrides the checkpoint location; it defaults to the
# relative directory "./deepseek-3b". The weights are requested in float16 and
# device placement is left to the library via device_map="auto".
# NOTE(review): the checkpoint is loaded through the seq2seq auto-class —
# confirm that the model stored at this path is an encoder-decoder
# architecture, otherwise from_pretrained is expected to reject it.
deepseek_model_path = os.getenv("DEEPSEEK_MODEL_PATH", "./deepseek-3b")
deepseek_tokenizer = AutoTokenizer.from_pretrained(deepseek_model_path)
deepseek_model = AutoModelForSeq2SeqLM.from_pretrained(deepseek_model_path, torch_dtype=torch.float16, device_map="auto")

def correct_text_local(text: str) -> str:
    """Run the locally loaded DeepSeek model over ``text`` and return its output.

    Args:
        text: The text to correct. Tokenizer truncation is enabled, so
            over-long input is cut rather than rejected.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    # Follow the model's actual device instead of hard-coding "cuda": the model
    # is loaded with device_map="auto", so it may live on CPU (no GPU present)
    # or on a specific GPU index, and the inputs must be placed with it.
    inputs = deepseek_tokenizer(
        text, return_tensors="pt", truncation=True, padding=True
    ).to(deepseek_model.device)
    outputs = deepseek_model.generate(**inputs, max_length=512)
    return deepseek_tokenizer.decode(outputs[0], skip_special_tokens=True)

# Load summarization and translation models from Hugging Face
# Both tokenizer/model pairs are identified by hub name. No device argument is
# passed here and no .to(...) call follows in this section, so the models stay
# wherever from_pretrained places them by default.
# Summarization: facebook/bart-large-cnn.
tokenizer_summarization = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model_summarization = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

# Translation (English -> French): Helsinki-NLP/opus-mt-en-fr.
tokenizer_translation = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
model_translation = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-fr")

@app.websocket("/stream/")
async def stream_audio(websocket: WebSocket):
    """Receive streamed audio over a WebSocket and reply with its transcription.

    Binary frames are appended to a temporary ``.wav`` file until the stream
    ends. The stream ends when the client disconnects, or when receiving or
    writing a frame fails for any other reason (kept as a best-effort
    end-of-stream signal, as before).

    If the client is still connected, the file is transcribed with Whisper and
    ``{"transcription": <text>}`` is sent before the socket is closed. If the
    client has already disconnected, nothing is transcribed because the result
    could not be delivered. The temporary file is removed on every exit path.
    """
    # Imported locally so this handler does not rely on the module-level
    # fastapi import line also listing WebSocketDisconnect.
    from fastapi import WebSocketDisconnect

    await websocket.accept()
    client_connected = True
    temp_audio = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    try:
        # Leaving the ``with`` closes (and therefore flushes) the file before
        # Whisper opens it by name.
        with temp_audio:
            while True:
                try:
                    data = await websocket.receive_bytes()
                    temp_audio.write(data)
                except WebSocketDisconnect:
                    client_connected = False
                    break
                except Exception:
                    # ``Exception`` rather than a bare ``except`` so that task
                    # cancellation and interpreter-exit signals still propagate.
                    break

        if client_connected:
            # NOTE: transcribe() is a blocking call and runs on the event loop.
            result = whisper_model.transcribe(temp_audio.name)
            await websocket.send_json({"transcription": result["text"]})
            await websocket.close()
    finally:
        # delete=False means nothing else cleans this file up.
        os.remove(temp_audio.name)

@app.post("/correct/")
async def correct_text(text: str):
    """Return ``text`` after passing it through the local DeepSeek corrector.

    The response body is ``{"corrected_text": <str>}``.
    """
    return {"corrected_text": correct_text_local(text)}

@app.post("/summarize/")
async def summarize_text(text: str):
    """Summarize ``text`` with the ``facebook/bart-large-cnn`` model.

    Args:
        text: Text to summarize. Tokenizer truncation is enabled, so over-long
            input is cut rather than rejected.

    Returns:
        ``{"summary": <str>}`` holding the first generated sequence, decoded
        with special tokens removed. Generation is deterministic
        (``do_sample=False``) and bounded to 50-150 tokens.
    """
    # Place the inputs on the model's own device. The model is loaded without
    # an explicit device, so forcing "cuda" here would put inputs and weights
    # on different devices (and fail outright on a host without a GPU).
    inputs = tokenizer_summarization(
        text, return_tensors="pt", truncation=True, padding=True
    ).to(model_summarization.device)
    outputs = model_summarization.generate(
        **inputs, max_length=150, min_length=50, do_sample=False
    )
    summary = tokenizer_summarization.decode(outputs[0], skip_special_tokens=True)
    return {"summary": summary}

@app.post("/translate/")
async def translate_text(text: str):
    """Translate ``text`` from English to French with ``Helsinki-NLP/opus-mt-en-fr``.

    Args:
        text: English text to translate. Tokenizer truncation is enabled, so
            over-long input is cut rather than rejected.

    Returns:
        ``{"translated_text": <str>}`` holding the first generated sequence,
        decoded with special tokens removed.
    """
    # Place the inputs on the model's own device. The model is loaded without
    # an explicit device, so forcing "cuda" here would put inputs and weights
    # on different devices (and fail outright on a host without a GPU).
    inputs = tokenizer_translation(
        text, return_tensors="pt", truncation=True, padding=True
    ).to(model_translation.device)
    outputs = model_translation.generate(**inputs, max_length=512)
    translated_text = tokenizer_translation.decode(outputs[0], skip_special_tokens=True)
    return {"translated_text": translated_text}

# Audio Streaming Server
# TCP port that audio_stream_server() binds on all interfaces ("0.0.0.0").
PORT = 5000  # Port for streaming

def audio_stream_server():
    """Accept one TCP client on ``PORT`` and write its byte stream to a temp file.

    Listens on all interfaces, accepts a single connection, and appends every
    received chunk to a ``.wav``-suffixed temporary file until the peer closes
    the connection (``recv`` returns no data) or a ``KeyboardInterrupt`` is
    raised. The file is created with ``delete=False``, so it stays on disk
    after the function returns; its path is printed so the capture can be found.

    Both sockets are closed on every exit path, including failures in
    ``bind``/``listen``/``accept`` or while creating the temporary file.
    """
    # ``with`` releases the listening socket even when setup or accept() raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
        server_socket.bind(("0.0.0.0", PORT))
        server_socket.listen(1)
        print(f"[STREAMING] Server listening on port {PORT}...")

        conn, addr = server_socket.accept()
        with conn:
            print(f"[CONNECTED] Receiving stream from {addr}")

            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
                try:
                    while True:
                        data = conn.recv(1024)
                        if not data:
                            # Empty read: the peer closed its side of the connection.
                            break
                        temp_audio.write(data)
                except KeyboardInterrupt:
                    print("[STOPPED] Streaming stopped.")

            # The file is closed (flushed) here; report where the data went.
            print(f"[SAVED] Stream written to {temp_audio.name}")

# Launch the raw-socket audio receiver on a background thread at import time.
# It is a daemon thread, so it does not keep the interpreter alive at exit.
threading.Thread(
    daemon=True,
    target=audio_stream_server,
).start()