In [None]:
import os
import sys
import argparse
import queue
import threading
import time
import logging
from logging.handlers import RotatingFileHandler
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import csv
import numpy as np
import sounddevice as sd
import soundfile as sf
import webrtcvad
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from TTS.api import TTS
from dotenv import load_dotenv


In [None]:
# ----------------------
# Defaults & constants
# ----------------------
# Pull variables from a .env file (when one is present) into the process
# environment so the Hugging Face token does not have to be hard-coded here.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")  # None when the variable is not set
DEFAULT_MODEL = "openai/whisper-medium"  # model id handed to from_pretrained()
DEFAULT_SAMPLE_RATE = 16000  # microphone capture rate in Hz
DEFAULT_FRAME_MS = 20  # length of one captured / VAD-classified frame, in ms
DEFAULT_VAD_AGGR = 2  # aggressiveness value passed to webrtcvad.Vad()
DEFAULT_MAX_SILENCE_FRAMES = 12  # non-speech frames that end an utterance (12 x 20 ms = ~240ms)
TRANSCRIPT_CSV = "transcripts.csv"  # default transcript file (relative to the working directory)
LOG_FILE = "realtime_translator.log"  # default rotating log file



# Filler/short outputs to ignore.
# Entries are compared against the lower-cased transcript after surrounding
# spaces and ".!,?" have been stripped, so they must be lower-case themselves.
IGNORE_SET = {"thank you", "thanks", "ok", "okay", "hmm", "mm", "mhm", "yeah", "no", "nah"}


In [None]:
def setup_logger(logfile: str = LOG_FILE, level=logging.INFO):
    """Configure and return the ``realtime_translator`` logger.

    The logger writes every record to the console and to a size-rotated file
    (5 MiB per file, 3 backups).

    The function is safe to call repeatedly, which matters in a notebook where
    the defining cell is often re-run: handlers left over from a previous call
    are closed and replaced instead of being stacked, so each record is emitted
    once per destination and no stale file handle stays open.

    Args:
        logfile: Path of the rotating log file.
        level: Logging threshold applied to the logger.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger("realtime_translator")
    logger.setLevel(level)
    fmt = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    # logging.getLogger() returns the same object for the same name, so a
    # second call would otherwise add a second pair of handlers.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    ch = logging.StreamHandler()
    ch.setFormatter(fmt)
    logger.addHandler(ch)

    fh = RotatingFileHandler(logfile, maxBytes=5 * 1024 * 1024, backupCount=3)
    fh.setFormatter(fmt)
    logger.addHandler(fh)

    return logger

# Module-level logger used by the functions and classes defined below; built
# with the default log file and INFO level.
logger = setup_logger()

def parse_args(**overrides):
    """Build the settings object consumed by ``RealtimeTranslator``.

    This is the notebook stand-in for a command-line parser: it performs no
    parsing and simply returns an object whose attributes carry the defaults.
    Individual settings can be replaced by keyword, e.g.
    ``parse_args(lang="fr", no_tts=True)``; calling it with no arguments
    yields exactly the defaults.

    Args:
        **overrides: Optional ``setting=value`` pairs. Only the names defined
            on ``Args`` below are accepted.

    Returns:
        An ``Args`` instance exposing ``model``, ``sample_rate``, ``frame_ms``,
        ``vad_aggr``, ``lang``, ``task``, ``max_silence_frames``,
        ``num_beams``, ``output_csv`` and ``no_tts``.

    Raises:
        TypeError: If an override names a setting that does not exist, so a
            typo cannot silently leave the default in place.
    """
    class Args:
        model = DEFAULT_MODEL
        sample_rate = DEFAULT_SAMPLE_RATE
        frame_ms = DEFAULT_FRAME_MS
        vad_aggr = DEFAULT_VAD_AGGR
        lang = None
        task = "translate"
        max_silence_frames = DEFAULT_MAX_SILENCE_FRAMES
        num_beams = 5
        output_csv = TRANSCRIPT_CSV
        no_tts = False

    known = {name for name in vars(Args) if not name.startswith("_")}
    unknown = sorted(set(overrides) - known)
    if unknown:
        raise TypeError(
            "parse_args() got unexpected setting(s): " + ", ".join(unknown)
        )

    args = Args()
    # Overrides are stored on the instance; the class-level defaults stay intact.
    for name, value in overrides.items():
        setattr(args, name, value)
    return args


In [None]:
def save_transcript(csv_path, timestamp, input_lang, task, text):
    """Append one transcript row to a CSV file, creating the file if needed.

    A header row is written first whenever the file is empty at the moment of
    the append. That covers both a brand-new file and a file that already
    exists but has no content yet (for example one that was pre-created or
    truncated), which a plain "does the path exist?" check would leave without
    a header.

    Args:
        csv_path: Destination CSV path.
        timestamp: Value for the ``timestamp`` column.
        input_lang: Source language; ``None`` or empty is stored as ``""``.
        task: Value for the ``task`` column.
        text: Transcribed or translated text.
    """
    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        # Position explicitly at the end and read the offset from the open
        # handle: offset 0 means there is nothing in the file yet.
        f.seek(0, os.SEEK_END)
        write_header = f.tell() == 0

        w = csv.writer(f)
        if write_header:
            w.writerow(["timestamp", "input_lang", "task", "text"])
        w.writerow([timestamp, input_lang or "", task, text])


In [None]:
class RealtimeTranslator:
    """Microphone -> VAD -> Whisper -> CSV (and optional spoken output) loop.

    Audio is captured as 16-bit mono frames, classified frame by frame with
    WebRTC VAD and accumulated into an utterance. Once enough consecutive
    non-speech frames follow speech, the utterance is pre-emphasised, run
    through Whisper, filtered for filler phrases, logged, appended to the
    transcript CSV and, unless TTS is disabled, synthesised and played back on
    a worker thread.
    """

    # NOTE(review): sounddevice's play()/wait() convenience helpers work on
    # module-level playback state, so overlapping calls from the two pool
    # threads would interfere with each other. Playback is therefore
    # serialised process-wide; synthesis itself may still overlap.
    _playback_lock = threading.Lock()

    def __init__(self, args):
        """Load the models and prepare the capture pipeline.

        Args:
            args: Settings object providing ``model``, ``sample_rate``,
                ``frame_ms``, ``vad_aggr``, ``lang``, ``task``,
                ``max_silence_frames``, ``num_beams``, ``no_tts`` and (used
                later by ``run``) ``output_csv``.
        """
        self.args = args
        self.sample_rate = args.sample_rate
        self.frame_ms = args.frame_ms
        # Number of samples in one capture block / VAD frame.
        self.chunk_samples = int(self.sample_rate * self.frame_ms / 1000)
        self.vad = webrtcvad.Vad(args.vad_aggr)
        self.q = queue.Queue()
        self.running = threading.Event()
        self.running.set()
        self.silence_threshold = args.max_silence_frames

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
        logger.info(f"Using device={self.device} dtype={self.dtype}")

        logger.info("Loading Whisper model and processor...")
        # NOTE(review): recent transformers releases prefer `token=` over
        # `use_auth_token=`; kept as-is because the installed version is not
        # visible from here - confirm before switching.
        self.processor = WhisperProcessor.from_pretrained(args.model, use_auth_token=HF_TOKEN)
        self.model = WhisperForConditionalGeneration.from_pretrained(
            args.model, torch_dtype=self.dtype, use_auth_token=HF_TOKEN
        )
        self.model.to(self.device)
        self.model.eval()

        # Keyword arguments forwarded to every model.generate() call.
        self.gen_kw = dict(
            task=args.task,
            language=args.lang,
            num_beams=args.num_beams,
            temperature=0.0,
            no_repeat_ngram_size=3,
            min_length=4,
            length_penalty=1.0,
            suppress_tokens=[],
        )

        self.no_tts = args.no_tts
        self.tts = None
        if not self.no_tts:
            logger.info("Loading TTS model...")
            try:
                self.tts = TTS(
                    model_name="tts_models/multilingual/multi-dataset/your_tts",
                    progress_bar=False, gpu=(self.device == "cuda")
                )
            except Exception:
                # Speech output is optional: fall back to text-only operation.
                logger.exception("Failed to load TTS model - continuing in no-tts mode.")
                self.no_tts = True
                self.tts = None

        # `speakers` may be missing or None; treat both as "no speakers".
        speakers = []
        if self.tts is not None:
            speakers = getattr(self.tts, "speakers", None) or []
        self.default_speaker = speakers[0] if len(speakers) > 0 else None
        logger.info(f"Default speaker: {self.default_speaker}")

        self.executor = ThreadPoolExecutor(max_workers=2)
        self._warmup()

    def _warmup(self):
        """Run one tiny generation so the first real utterance is not slowed by setup.

        Failures are logged as a warning and otherwise ignored.
        """
        try:
            dummy = np.zeros((1600,), dtype=np.float32)
            inputs = self.processor(dummy, sampling_rate=self.sample_rate, return_tensors="pt")
            input_features = inputs.input_features.to(self.device, dtype=self.dtype)
            # `max_length` counts the decoder prompt as well and gen_kw asks
            # for min_length=4, so "generate a single token" is expressed with
            # max_new_tokens and a relaxed min_length instead of max_length=1.
            warm_kw = {**self.gen_kw, "min_length": 0}
            with torch.no_grad():
                self.model.generate(input_features, max_new_tokens=1, **warm_kw)
            logger.info("Warmup complete.")
        except Exception as e:
            logger.warning("Warmup failed: %s", e)

    def audio_callback(self, indata, frames, time_info, status):
        """Stream callback: copy the captured block into the frame queue."""
        if status:
            logger.debug("Input status: %s", status)
        self.q.put(bytes(indata))

    @staticmethod
    def _frame_to_float32(fr):
        """Decode raw int16 PCM bytes into a CPU float32 tensor scaled by 1/32768."""
        arr = np.frombuffer(fr, dtype=np.int16).astype(np.float32) / 32768.0
        return torch.from_numpy(arr)

    def bytes_to_tensor(self, fr):
        """Decode raw int16 PCM bytes into a tensor on ``self.device`` with ``self.dtype``."""
        return self._frame_to_float32(fr).to(self.device, dtype=self.dtype)

    def preemphasis_torch(self, x: torch.Tensor, coeff: float = 0.97):
        """Apply ``y[0] = x[0]``, ``y[n] = x[n] - coeff * x[n-1]``; empty input is returned as-is."""
        if x.numel() == 0:
            return x
        return torch.cat([x[:1], x[1:] - coeff * x[:-1]])

    def tts_playback_worker(self, text, out_path="output.wav"):
        """Synthesise ``text`` to ``out_path``, play it, then delete the file.

        Runs on an executor thread. Any failure is logged and swallowed so a
        speech problem never stops transcription; the temporary file is
        removed on every path, including failures.
        """
        try:
            tts_lang = "en" if self.args.task == "translate" else None
            self.tts.tts_to_file(text=text, file_path=out_path,
                                 speaker=self.default_speaker, language=tts_lang)
            data, sr = sf.read(out_path)
            with self._playback_lock:
                sd.play(data, sr)
                sd.wait()
        except Exception:
            logger.exception("TTS/playback failed for text: %s", text)
        finally:
            try:
                os.remove(out_path)
            except FileNotFoundError:
                pass  # synthesis failed before the file was created
            except OSError:
                logger.warning("Could not remove temporary file: %s", out_path)

    @staticmethod
    def _is_filler(text):
        """Return True for transcripts in IGNORE_SET or a single token shorter than 3 chars."""
        tlow = text.lower().strip(" .!,?")
        return tlow in IGNORE_SET or (len(tlow.split()) <= 1 and len(tlow) < 3)

    def _transcribe(self, speech):
        """Run Whisper on a 1-D waveform tensor and return the stripped decoded text."""
        inputs = self.processor(
            speech.cpu().numpy(),
            sampling_rate=self.sample_rate,
            return_tensors="pt"
        )
        input_features = inputs.input_features.to(self.device, dtype=self.dtype)
        with torch.no_grad():
            generated_ids = self.model.generate(input_features, **self.gen_kw)
        return self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

    def _process_utterance(self, speech):
        """Transcribe one finished utterance, then log, store and optionally speak it."""
        text = self._transcribe(self.preemphasis_torch(speech))
        if not text or self._is_filler(text):
            return

        ts = datetime.utcnow().isoformat()
        logger.info("Result [%s]: %s", ts, text)

        save_transcript(self.args.output_csv, ts, self.args.lang, self.args.task, text)

        if not self.no_tts and self.tts is not None:
            # Millisecond timestamp keeps file names of queued utterances apart.
            outpath = f"output_{int(time.time()*1000)}.wav"
            self.executor.submit(self.tts_playback_worker, text, outpath)

    def run(self):
        """Capture audio and process utterances until stopped, then shut down.

        The loop ends when ``self.running`` is cleared, on a keyboard
        interrupt (the notebook Stop button), or on an unexpected exception,
        which is logged. ``shutdown()`` runs in every case.
        """
        logger.info("Starting real-time loop. Press Stop to interrupt.")
        ring = []          # frames of the utterance currently being collected
        silence_count = 0  # consecutive non-speech frames since the last speech
        try:
            with sd.RawInputStream(
                samplerate=self.sample_rate,
                blocksize=self.chunk_samples,
                dtype="int16",
                channels=1,
                callback=self.audio_callback,
            ):
                while self.running.is_set():
                    try:
                        # Short timeout so a cleared `running` flag is noticed.
                        frame = self.q.get(timeout=0.2)
                    except queue.Empty:
                        continue

                    # Frames are buffered on the CPU in float32: the feature
                    # extractor consumes a NumPy array anyway, so only the
                    # extracted features are moved to the model device.
                    if self.vad.is_speech(frame, self.sample_rate):
                        ring.append(self._frame_to_float32(frame))
                        silence_count = 0
                        continue

                    if not ring:
                        continue  # silence outside of an utterance

                    silence_count += 1
                    if silence_count < self.silence_threshold:
                        # Short pause inside an utterance: keep collecting.
                        ring.append(self._frame_to_float32(frame))
                        continue

                    speech = torch.cat(ring)
                    ring.clear()
                    silence_count = 0
                    self._process_utterance(speech)

        except KeyboardInterrupt:
            # The notebook Stop button arrives here; end quietly.
            logger.info("Interrupted by user.")
        except Exception:
            logger.exception("Unhandled exception in main loop.")
        finally:
            self.shutdown()

    def shutdown(self):
        """Stop the loop and wait for queued TTS jobs to finish."""
        logger.info("Shutting down tasks...")
        self.running.clear()
        self.executor.shutdown(wait=True)
        logger.info("Shutdown complete..")


In [None]:
# Build the default settings, construct the translator (its constructor loads
# the Whisper and TTS models and runs the warm-up pass), then enter the
# blocking capture loop. run() calls shutdown() itself when the loop ends.
args = parse_args()
rt = RealtimeTranslator(args)
rt.run()