<a href="https://colab.research.google.com/github/Janak-Sh/2021-naamii-rl-practical/blob/main/end_to_end_speech.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
"""
Fixed Speech Translation with Proper Error Handling
English ↔ Persian (Farsi)

This version includes fixes for common Whisper errors:
- Proper audio preprocessing
- Device handling
- Memory management
- Input validation
"""

import torch
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    VitsModel,
    VitsTokenizer
)
import soundfile as sf
import numpy as np
from scipy.io import wavfile
from scipy import signal
import warnings
import time
# Globally ignore every warning issued through Python's `warnings` module for the
# rest of the session.
# NOTE(review): this also hides potentially useful deprecation/runtime warnings;
# consider narrowing it by category or module once the noisy source is known.
warnings.filterwarnings('ignore')


class SpeechTranslatorFixed:
    """English <-> Persian speech-to-speech translator.

    Pipeline: Whisper (speech recognition) -> NLLB-200 (text translation)
    -> MMS-TTS / VITS (speech synthesis). Every model is loaded once in
    ``__init__`` and moved to ``device``; the public methods print progress
    and timing information as they run.
    """

    # Sample rate (Hz) that audio is converted to before Whisper sees it.
    TARGET_SAMPLE_RATE = 16000
    # Shortest clip (seconds) accepted by validate_and_preprocess_audio.
    MIN_DURATION_S = 0.1
    # Cap on the total Whisper decoder length passed to generate().
    WHISPER_MAX_LENGTH = 448
    # Cap on the NLLB output length passed to generate().
    NLLB_MAX_LENGTH = 512
    NLLB_MODEL_NAME = "facebook/nllb-200-distilled-600M"
    TTS_ENG_MODEL_NAME = "facebook/mms-tts-eng"
    TTS_FAS_MODEL_NAME = "facebook/mms-tts-fas"

    def __init__(self,
                 whisper_model_name="openai/whisper-large-v3",
                 device="cuda" if torch.cuda.is_available() else "cpu",
                 use_fp16=True):
        """Load Whisper, NLLB and both MMS-TTS voices onto ``device``.

        Args:
            whisper_model_name: Hugging Face id of the Whisper checkpoint.
            device: Torch device spec ("cpu", "cuda", "cuda:1", or a
                ``torch.device``). Defaults to "cuda" when available.
            use_fp16: Load Whisper in float16. Only honoured on CUDA devices;
                NLLB and the TTS models are loaded in their default dtype.
        """
        self.device = device
        is_cuda = self._is_cuda(device)
        # "cuda:N" and torch.device("cuda") count as GPU too, not just "cuda".
        self.use_fp16 = use_fp16 and is_cuda

        print(f"🔧 Using device: {self.device}")
        if is_cuda:
            print(f"   GPU: {torch.cuda.get_device_name(device)}")
            props = torch.cuda.get_device_properties(device)
            print(f"   VRAM: {props.total_memory / 1e9:.2f} GB")
            if self.use_fp16:
                print("   Using FP16 for memory efficiency")
        print()

        # --- Whisper (ASR) ---
        print(f"📥 Loading Whisper: {whisper_model_name}...")
        start = time.time()
        self.whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
        # Half precision roughly halves Whisper's GPU memory footprint.
        load_kwargs = {"torch_dtype": torch.float16} if self.use_fp16 else {}
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
            whisper_model_name, **load_kwargs
        ).to(self.device)
        print(f"   ✅ Loaded in {time.time() - start:.2f}s")

        # --- NLLB (text translation) ---
        print("📥 Loading NLLB...")
        start = time.time()
        self.nllb_tokenizer = AutoTokenizer.from_pretrained(self.NLLB_MODEL_NAME)
        self.nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
            self.NLLB_MODEL_NAME
        ).to(self.device)
        print(f"   ✅ Loaded in {time.time() - start:.2f}s")

        # --- MMS-TTS (speech synthesis), one voice per output language ---
        print("📥 Loading TTS models...")
        start = time.time()
        self.tts_eng_model, self.tts_eng_tokenizer = self._load_tts(self.TTS_ENG_MODEL_NAME)
        self.tts_fas_model, self.tts_fas_tokenizer = self._load_tts(self.TTS_FAS_MODEL_NAME)
        print(f"   ✅ Loaded in {time.time() - start:.2f}s\n")

    @staticmethod
    def _is_cuda(device):
        """Return True if ``device`` (str or torch.device) names a CUDA device."""
        return torch.device(device).type == "cuda"

    def _load_tts(self, model_name):
        """Load one MMS-TTS (VITS) model onto ``self.device`` plus its tokenizer."""
        model = VitsModel.from_pretrained(model_name).to(self.device)
        tokenizer = VitsTokenizer.from_pretrained(model_name)
        return model, tokenizer

    def validate_and_preprocess_audio(self, audio_path):
        """Load an audio file and turn it into Whisper-ready 16 kHz mono float32.

        Args:
            audio_path: Path to any file format soundfile can read.

        Returns:
            (audio, sample_rate, duration_seconds) where ``duration_seconds`` is
            measured on the original file, before resampling.

        Raises:
            ValueError: if the file cannot be read, is empty, is shorter than
                MIN_DURATION_S, or contains NaN/Inf samples.
        """
        try:
            audio, sr = sf.read(audio_path)
        except Exception as e:
            raise ValueError(f"Failed to load audio file: {e}") from e
        return self._preprocess_waveform(audio, sr)

    @classmethod
    def _preprocess_waveform(cls, audio, sr):
        """Validate an in-memory waveform and convert it to 16 kHz mono float32.

        ``audio`` is a 1-D (mono) or 2-D (frames x channels) array; ``sr`` is
        its integer sample rate. Returns the same triple as
        validate_and_preprocess_audio.
        """
        audio = np.asarray(audio)
        # size (not len) also catches frames-with-zero-channels arrays.
        if audio.size == 0:
            raise ValueError("Audio file is empty!")

        duration = audio.shape[0] / sr
        if duration < cls.MIN_DURATION_S:
            raise ValueError(
                f"Audio too short: {duration:.2f}s (minimum {cls.MIN_DURATION_S}s)"
            )

        # Whisper expects a single channel: average stereo/multichannel input.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)

        audio = audio.astype(np.float32)

        if np.isnan(audio).any():
            raise ValueError("Audio contains NaN values!")
        if np.isinf(audio).any():
            raise ValueError("Audio contains Inf values!")

        # Peak-normalise only when the signal exceeds the [-1, 1] range.
        max_val = np.abs(audio).max()
        if max_val > 1.0:
            audio = audio / max_val

        if sr != cls.TARGET_SAMPLE_RATE:
            # Polyphase resampling avoids the wrap-around edge artifacts of the
            # FFT-based signal.resample on non-periodic audio.
            audio = signal.resample_poly(
                audio, cls.TARGET_SAMPLE_RATE, int(sr)
            ).astype(np.float32, copy=False)
            sr = cls.TARGET_SAMPLE_RATE

        return audio, sr, duration

    def transcribe_audio(self, audio_path, source_language="en"):
        """Transcribe an audio file with Whisper.

        Args:
            audio_path: Path to the input audio file.
            source_language: "en" for English; any other value selects Persian ("fa").

        Returns:
            (transcription, seconds_taken). The transcription has surrounding
            whitespace stripped.
        """
        start_time = time.time()
        print(f"🎧 Transcribing: {audio_path} ({source_language})")

        audio, sr, duration = self.validate_and_preprocess_audio(audio_path)
        print(f"   Duration: {duration:.2f}s")

        # The feature extractor pads/truncates every input to a fixed window
        # (30 s for Whisper), so anything past it never reaches the model.
        window_s = getattr(self.whisper_processor.feature_extractor, "chunk_length", 30)
        if duration > window_s:
            print(f"   ⚠️  Audio is longer than {window_s}s; "
                  f"only the first {window_s}s will be transcribed")

        input_features = self.whisper_processor(
            audio,
            sampling_rate=sr,
            return_tensors="pt"
        ).input_features
        # Match the model's dtype (float16 when loaded with use_fp16).
        input_features = input_features.to(self.device, dtype=self.whisper_model.dtype)

        language_code = "en" if source_language == "en" else "fa"

        with torch.no_grad():
            # language/task replace the deprecated forced_decoder_ids mechanism.
            predicted_ids = self.whisper_model.generate(
                input_features,
                language=language_code,
                task="transcribe",
                max_length=self.WHISPER_MAX_LENGTH,
            )

        transcription = self.whisper_processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True
        )[0].strip()

        time_taken = time.time() - start_time
        print(f"   ✅ Transcription: {transcription}")
        print(f"   ⏱️  Time: {time_taken:.2f}s (RTF: {time_taken/duration:.2f}x)\n")

        return transcription, time_taken

    def _nllb_lang_id(self, lang_code):
        """Return the NLLB vocabulary id for a language code such as "pes_Arab".

        Raises:
            ValueError: if the code is not in the tokenizer vocabulary (the
                tokenizer would otherwise silently map it to the unknown token).
        """
        token_id = self.nllb_tokenizer.convert_tokens_to_ids(lang_code)
        if token_id is None or token_id == self.nllb_tokenizer.unk_token_id:
            raise ValueError(f"Unknown NLLB language code: {lang_code!r}")
        return token_id

    def translate_text(self, text, source_lang="eng_Latn", target_lang="pes_Arab"):
        """Translate ``text`` with NLLB.

        Args:
            text: Source-language text.
            source_lang / target_lang: NLLB language codes (e.g. "eng_Latn",
                "pes_Arab").

        Returns:
            (translation, seconds_taken).

        Raises:
            ValueError: if either language code is unknown to the tokenizer.
        """
        start_time = time.time()
        print(f"🔄 Translating: {text}")

        # Validate both codes before touching tokenizer state.
        self._nllb_lang_id(source_lang)
        forced_bos_token_id = self._nllb_lang_id(target_lang)

        # src_lang controls which source-language token the tokenizer adds.
        self.nllb_tokenizer.src_lang = source_lang
        inputs = self.nllb_tokenizer(
            text,
            return_tensors="pt",
            padding=True
        ).to(self.device)

        with torch.no_grad():
            generated_tokens = self.nllb_model.generate(
                **inputs,
                # Forcing the first generated token selects the output language.
                forced_bos_token_id=forced_bos_token_id,
                max_length=self.NLLB_MAX_LENGTH,
            )

        translation = self.nllb_tokenizer.batch_decode(
            generated_tokens,
            skip_special_tokens=True
        )[0]

        time_taken = time.time() - start_time
        print(f"   ✅ Translation: {translation}")
        print(f"   ⏱️  Time: {time_taken:.2f}s\n")

        return translation, time_taken

    def synthesize_speech(self, text, output_path, target_language="en"):
        """Synthesize ``text`` with MMS-TTS and write it to ``output_path`` as WAV.

        Args:
            text: Text to speak; must contain non-whitespace characters.
            output_path: Destination .wav path (written at the model's sampling rate).
            target_language: "en" selects the English voice; anything else Persian.

        Returns:
            (output_path, seconds_taken).

        Raises:
            ValueError: if ``text`` is empty or whitespace-only.
        """
        if not text or not text.strip():
            raise ValueError("Cannot synthesize empty text")

        start_time = time.time()
        print(f"🔊 Synthesizing: {text}")

        if target_language == "en":
            tts_model = self.tts_eng_model
            tts_tokenizer = self.tts_eng_tokenizer
        else:
            tts_model = self.tts_fas_model
            tts_tokenizer = self.tts_fas_tokenizer

        inputs = tts_tokenizer(text, return_tensors="pt").to(self.device)

        with torch.no_grad():
            output = tts_model(**inputs)

        # First (only) item of the batch, moved to host memory for writing.
        waveform = output.waveform[0].cpu().numpy()
        rate = tts_model.config.sampling_rate

        wavfile.write(output_path, rate=rate, data=waveform)

        duration = len(waveform) / rate
        time_taken = time.time() - start_time
        print(f"   ✅ Saved: {output_path}")
        print(f"   Duration: {duration:.2f}s")
        print(f"   ⏱️  Time: {time_taken:.2f}s\n")

        return output_path, time_taken

    def create_12s_sample(self, text, output_path, language="en"):
        """Synthesize a sample clip from ``text`` and report its real duration.

        The "12s" in the name is nominal: the clip length is whatever the TTS
        model produces for ``text``. The file is re-read from disk to verify
        what was actually written.

        Returns:
            (output_path, duration_seconds).
        """
        print(f"🎵 Creating 12s sample: {language}")
        path, _ = self.synthesize_speech(text, output_path, language)

        audio, sr = sf.read(path)
        duration = len(audio) / sr
        print(f"   ✅ Sample duration: {duration:.2f}s\n")

        return path, duration

    def _run_pipeline(self, input_path, output_path, *, banner, asr_language,
                      src_lang, tgt_lang, tts_language):
        """Run ASR -> MT -> TTS for one direction and return a result dict.

        The dict has keys "transcription", "translation", "output_audio" and
        "timing" (with "asr", "mt", "tts", "total" in seconds).

        Raises:
            ValueError: if ASR returns an empty transcription.
        """
        total_start = time.time()
        print("="*70)
        print(banner)
        print("="*70 + "\n")

        transcription, asr_time = self.transcribe_audio(input_path, asr_language)
        if not transcription:
            raise ValueError("ASR produced an empty transcription; nothing to translate")
        translation, mt_time = self.translate_text(transcription, src_lang, tgt_lang)
        output, tts_time = self.synthesize_speech(translation, output_path, tts_language)

        total_time = time.time() - total_start

        print("="*70)
        print("✅ COMPLETE!")
        print("="*70)
        print(f"⏱️  ASR: {asr_time:.2f}s | MT: {mt_time:.2f}s | TTS: {tts_time:.2f}s")
        print(f"   TOTAL: {total_time:.2f}s")
        print("="*70 + "\n")

        return {
            "transcription": transcription,
            "translation": translation,
            "output_audio": output,
            "timing": {
                "asr": asr_time,
                "mt": mt_time,
                "tts": tts_time,
                "total": total_time
            }
        }

    def translate_speech_en_to_fa(self, input_path, output_path):
        """Translate an English speech file into a Persian speech file.

        Returns the result dict described in _run_pipeline.
        """
        return self._run_pipeline(
            input_path, output_path,
            banner="🇺🇸 ➡️ 🇮🇷  ENGLISH → PERSIAN",
            asr_language="en", src_lang="eng_Latn",
            tgt_lang="pes_Arab", tts_language="fa",
        )

    def translate_speech_fa_to_en(self, input_path, output_path):
        """Translate a Persian speech file into an English speech file.

        Returns the result dict described in _run_pipeline.
        """
        return self._run_pipeline(
            input_path, output_path,
            banner="🇮🇷 ➡️ 🇺🇸  PERSIAN → ENGLISH",
            asr_language="fa", src_lang="pes_Arab",
            tgt_lang="eng_Latn", tts_language="en",
        )




In [2]:

    # Build the full ASR -> MT -> TTS pipeline once; later cells reuse `translator`.
    print("🚀 Initializing (with FP16 for memory efficiency)...\n")
    translator = SpeechTranslatorFixed(
        use_fp16=True,  # FP16 Whisper weights to save GPU memory (ignored on CPU)
        whisper_model_name="openai/whisper-large-v3",
    )


🚀 Initializing (with FP16 for memory efficiency)...

🔧 Using device: cuda
   GPU: Tesla T4
   VRAM: 15.64 GB
   Using FP16 for memory efficiency

📥 Loading Whisper: openai/whisper-large-v3...


preprocessor_config.json:   0%|          | 0.00/340 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

normalizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]



special_tokens_map.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

Loading weights:   0%|          | 0/1259 [00:00<?, ?it/s]

generation_config.json: 0.00B [00:00, ?B/s]

   ✅ Loaded in 30.64s
📥 Loading NLLB...


config.json:   0%|          | 0.00/846 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/564 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.3M [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.46G [00:00<?, ?B/s]

Loading weights:   0%|          | 0/512 [00:00<?, ?it/s]



generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.46G [00:00<?, ?B/s]

   ✅ Loaded in 37.91s
📥 Loading TTS models...


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/145M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/762 [00:00<?, ?it/s]

tokenizer_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/413 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/47.0 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/145M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/762 [00:00<?, ?it/s]

tokenizer_config.json:   0%|          | 0.00/288 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/517 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

   ✅ Loaded in 46.73s



In [3]:

# English sample text. The "12s" in the file names is nominal: the spoken length
# depends on the text and the TTS voice (the recorded run below produced ~18.8s).
text_en = (
    "Good morning everyone. Welcome to our artificial intelligence presentation. "
    "Today we will discuss the latest developments in machine learning and natural "
    "language processing. These technologies are transforming how we interact with "
    "computers and are opening up new possibilities for the future."
)

# Persian sample text (the recorded run below produced ~21.2s of audio).
text_fa = (
    "سلام و درود بر همه شما. امروز می‌خواهیم در مورد پیشرفت‌های جدید در زمینه هوش مصنوعی "
    "صحبت کنیم. این فناوری‌ها در حال تغییر دادن نحوه تعامل ما با رایانه‌ها هستند و "
    "امکانات جدیدی را برای آینده فراهم می‌کنند. یادگیری ماشین و پردازش زبان طبیعی "
    "دو حوزه مهم در این زمینه محسوب می‌شوند."
)

# MMS-TTS (VITS) synthesis is non-deterministic; fixing the seed makes the
# generated sample audio reproducible across runs on the same setup.
torch.manual_seed(0)

# Create samples
print("="*70)
print("CREATING SAMPLES")
print("="*70 + "\n")

translator.create_12s_sample(text_en, "sample_en_12s.wav", "en")
translator.create_12s_sample(text_fa, "sample_fa_12s.wav", "fa")

# Round-trip each synthesized sample through the opposite-direction pipeline.
result1 = translator.translate_speech_en_to_fa(
    "sample_en_12s.wav",
    "output_fa_12s.wav"
)

result2 = translator.translate_speech_fa_to_en(
    "sample_fa_12s.wav",
    "output_en_12s.wav"
)

# Summary
print("="*70)
print("📊 SUMMARY")
print("="*70)
print(f"EN→FA: {result1['timing']['total']:.2f}s")
print(f"FA→EN: {result2['timing']['total']:.2f}s")
print("="*70)




CREATING SAMPLES

🎵 Creating 12s sample: en
🔊 Synthesizing: Good morning everyone. Welcome to our artificial intelligence presentation. Today we will discuss the latest developments in machine learning and natural language processing. These technologies are transforming how we interact with computers and are opening up new possibilities for the future.
   ✅ Saved: sample_en_12s.wav
   Duration: 18.77s
   ⏱️  Time: 2.59s

   ✅ Sample duration: 18.77s

🎵 Creating 12s sample: fa
🔊 Synthesizing: سلام و درود بر همه شما. امروز می‌خواهیم در مورد پیشرفت‌های جدید در زمینه هوش مصنوعی صحبت کنیم. این فناوری‌ها در حال تغییر دادن نحوه تعامل ما با رایانه‌ها هستند و امکانات جدیدی را برای آینده فراهم می‌کنند. یادگیری ماشین و پردازش زبان طبیعی دو حوزه مهم در این زمینه محسوب می‌شوند.


Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of the `task` and `language` flags/config options.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


   ✅ Saved: sample_fa_12s.wav
   Duration: 21.15s
   ⏱️  Time: 0.46s

   ✅ Sample duration: 21.15s

🇺🇸 ➡️ 🇮🇷  ENGLISH → PERSIAN

🎧 Transcribing: sample_en_12s.wav (en)
   Duration: 18.77s


A custom logits processor of type <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> will take precedence. Please check the docstring of <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> to see related `.generate()` flags.
A custom logits processor of type <class 'transformers.generation.logits_process.SuppressTokensAtBeginLogitsProcessor'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.logits_process.SuppressTokensAtBeginLogitsProcessor'> will take precedence. Please check the docstring of <class 'transformers.generation.logits_process.SuppressTokensAtBeginLogitsProcessor'> to see related `.generate()` flags.


   ✅ Transcription:  Good morning everyone. Welcome to our artificial intelligence presentation. Today we will discuss the latest developments in matching learning and natural language. Processing these technologies are transforming how we interact with computers and are opening up new possibilities for the future.
   ⏱️  Time: 2.04s (RTF: 0.11x)

🔄 Translating:  Good morning everyone. Welcome to our artificial intelligence presentation. Today we will discuss the latest developments in matching learning and natural language. Processing these technologies are transforming how we interact with computers and are opening up new possibilities for the future.
   ✅ Translation: سلام به همه. به نمایشگاه هوش مصنوعی خوش آمدید. امروز ما در مورد آخرین پیشرفت های یادگیری تطابق و زبان طبیعی صحبت خواهیم کرد. پردازش این فناوری ها نحوه تعامل ما با کامپیوتر را تغییر می دهد و فرصت های جدیدی را برای آینده باز می کند.
   ⏱️  Time: 0.92s

🔊 Synthesizing: سلام به همه. به نمایشگاه هوش مصنوعی خوش آمدید. امروز 

In [6]:
from IPython.display import Audio

# Render an inline player for the Persian speech written by the EN→FA run above.
fa_output_path = "output_fa_12s.wav"
Audio(fa_output_path)
