<a href="https://colab.research.google.com/github/Dimildizio/DS_course/blob/main/Neural_networks/Multimodal/stt_tts_converter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# STT-TTS converter

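Transcribe speech with Whisper (speech-to-text), then re-narrate the transcript in a clear Kokoro voice (text-to-speech) — in case you ever need to deal with someone's terrible accent.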

## Install dependencies and build the app

In [1]:

%%capture


from IPython.display import clear_output, display, HTML
display(HTML('''<style>
div.cell.code_cell.rendered.selected div.input {
    display: none;
}
</style>'''))


# Install required packages
!pip install transformers soundfile librosa sounddevice kokoro gradio

import numpy as np
import torch
import gradio as gr
import tempfile
import os
import io
import scipy.io.wavfile as wav
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from kokoro import KModel, KPipeline


# STT Class
class WhisperSTT:
    """Speech-to-text wrapper around a Hugging Face Whisper checkpoint.

    Loads the processor and model once, moves the model to CUDA when
    available, and transcribes audio files.
    """

    def __init__(self, model_name="openai/whisper-tiny"):
        """Load the Whisper processor and model.

        Args:
            model_name: Hugging Face model id to load. Defaults to the
                original hard-coded "openai/whisper-tiny".
        """
        print("Loading Whisper model...")
        self.processor = WhisperProcessor.from_pretrained(model_name)
        self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)
        # Audio is resampled to this rate before feature extraction (Whisper
        # models are trained on 16 kHz input).
        self.sample_rate = 16000
        print(f"Model loaded! Using device: {self.device}")

    def transcribe_from_file(self, audio_file, language="en"):
        """Transcribe an audio file to text.

        Args:
            audio_file: Path to an audio file readable by librosa.
            language: Language to force during decoding (default "en").

        Returns:
            The decoded transcription. On any failure inside the try block,
            returns a string starting with "Error processing audio file:"
            instead of raising — callers must check for that prefix.

        Note:
            The Whisper processor pads/truncates features to a single 30 s
            window by default, so longer recordings are likely cut off —
            confirm for the installed transformers version.
        """
        import librosa
        try:
            # Resample to 16 kHz and downmix to mono as the processor expects.
            audio, _ = librosa.load(audio_file, sr=self.sample_rate, mono=True)
            processor_output = self.processor(
                audio, sampling_rate=self.sample_rate,
                return_tensors="pt", return_attention_mask=True
            )
            input_features = processor_output.input_features.to(self.device)
            attention_mask = processor_output.attention_mask.to(self.device)
            # language/task replace the deprecated forced_decoder_ids argument.
            predicted_ids = self.model.generate(
                input_features, attention_mask=attention_mask,
                language=language, task="transcribe", max_length=128
            )
            return self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        except Exception as e:
            return f"Error processing audio file: {str(e)}"

# TTS Class
class KokoroTTS:
    """Text-to-speech wrapper around the Kokoro model.

    One CPU-resident KModel is shared by two text-to-phoneme pipelines
    (American and British English); `voice_map` maps short UI keys to
    Kokoro voice names.
    """

    def __init__(self):
        print("Loading Kokoro TTS model...")
        self.model = KModel().to('cpu').eval()
        self.sample_rate = 24000
        # Pipelines are built with model=False; synthesis goes through
        # self.model explicitly in generate_audio.
        self.pipelines = {
            'a': KPipeline(lang_code='a', model=False),  # American English
            'b': KPipeline(lang_code='b', model=False)   # British English
        }
        # Override the phoneme spelling of the word "kokoro" per accent.
        self.pipelines['a'].g2p.lexicon.golds['kokoro'] = 'kˈOkəɹO'
        self.pipelines['b'].g2p.lexicon.golds['kokoro'] = 'kˈQkəɹQ'
        # UI key -> Kokoro voice name. The first letter of each voice name is
        # used below as the key into self.pipelines ('a' or 'b').
        self.voice_map = {
            'a': 'af_bella',  # default female
            'b': 'bm_lewis',  # default male
            'c': 'am_michael',
            'd': 'bf_emma',
            'e': 'af_heart'
        }
        # Load every voice up front so the first request does not pay for it.
        for voice in set(self.voice_map.values()):
            self.pipelines[voice[0]].load_voice(voice)
        print("Kokoro TTS model loaded!")

    def generate_audio(self, text, voice_type='a', speed=1):
        """Synthesize `text` with the voice selected by `voice_type`.

        The pipeline may split long text into several phoneme chunks; every
        chunk is synthesized and the pieces are concatenated in order.

        Args:
            text: Text to speak.
            voice_type: Key into `voice_map`; unknown keys fall back to 'af_bella'.
            speed: Speech-rate multiplier passed to the pipeline and model.

        Returns:
            `(samples, sample_rate)` with float32 samples, or `(None, None)` if
            no audio was produced (e.g. empty text) or an error occurred.
        """
        try:
            voice = self.voice_map.get(voice_type, 'af_bella')
            pipeline = self.pipelines[voice[0]]
            voice_pack = pipeline.load_voice(voice)
            chunks = []
            for _, ps, _ in pipeline(text, voice, speed=speed):
                # The voice pack is indexed by phoneme count of the chunk.
                ref_s = voice_pack[len(ps) - 1]
                audio = self.model(ps, ref_s, speed=speed)
                chunks.append(audio.numpy().astype(np.float32))
            if not chunks:
                # Previously this fell through and returned a bare None,
                # which callers could not unpack.
                return None, None
            return np.concatenate(chunks), self.sample_rate
        except Exception as e:
            print(f"Error generating audio: {e}")
            return None, None

# Main Pipeline
class VoicePipeline:
    """Chain speech-to-text (WhisperSTT) with text-to-speech (KokoroTTS).

    Both public methods return a `(text, audio)` pair for a Gradio Textbox
    and a `gr.Audio(type="numpy")` output, where `audio` is
    `(sample_rate, samples)` or `None` when nothing could be synthesized.
    """

    # Prefix used by WhisperSTT.transcribe_from_file for its in-band errors.
    _STT_ERROR_PREFIX = "Error processing audio file"

    def __init__(self):
        self.stt = WhisperSTT()
        self.tts = KokoroTTS()

    @staticmethod
    def _voice_key(voice_selection):
        """Return the voice key from a dropdown label such as "a: af_bella"."""
        if not voice_selection:
            return 'a'
        return voice_selection.split(':')[0].strip()

    def _synthesize(self, text, voice_selection):
        """Run TTS; return `(sample_rate, samples)` or None if no audio was made."""
        result = self.tts.generate_audio(text, self._voice_key(voice_selection))
        # TTS reports failure as (None, None); also tolerate a bare None.
        if not result or result[0] is None:
            return None
        audio_output, output_sample_rate = result
        return output_sample_rate, audio_output

    def process_audio_file(self, audio_file, voice_selection):
        """Transcribe `audio_file` and narrate the transcript with the chosen voice."""
        if audio_file is None:
            return "No audio file uploaded", None
        try:
            transcription = self.stt.transcribe_from_file(audio_file)
            # Never narrate an STT error message or an empty transcript.
            if transcription.startswith(self._STT_ERROR_PREFIX):
                return transcription, None
            if not transcription.strip():
                return "No speech detected in the audio", None
            return transcription, self._synthesize(transcription, voice_selection)
        except Exception as e:
            return f"Error processing audio file: {str(e)}", None

    def process_text_input(self, text, voice_selection):
        """Narrate `text` directly with the chosen voice."""
        # Whitespace-only input would yield no audio, so reject it up front.
        if not text or not text.strip():
            return "No text provided", None
        try:
            return text, self._synthesize(text, voice_selection)
        except Exception as e:
            return f"Error processing text: {str(e)}", None

# Create Gradio Interface
def create_app():
    """Build (but do not launch) the Gradio UI with file, microphone and text tabs.

    Creating the VoicePipeline loads both models and prints progress messages;
    clear_output() then wipes those messages from the notebook cell.
    """
    pipeline = VoicePipeline()
    voice_choices = [f"{key}: {name}" for key, name in pipeline.tts.voice_map.items()]
    default_voice = voice_choices[0]
    clear_output()

    def voice_picker():
        # Same dropdown on every tab, preselecting the first voice.
        return gr.Dropdown(choices=voice_choices, value=default_voice, label="Select TTS Voice")

    def audio_tab(tab_title, blurb, source, input_label, button_text):
        # The upload and microphone tabs share one layout; only the source differs.
        with gr.Tab(tab_title):
            gr.Markdown(blurb)
            with gr.Row():
                with gr.Column():
                    recording = gr.Audio(sources=[source], type="filepath", label=input_label)
                    picker = voice_picker()
                    run_button = gr.Button(button_text)
                with gr.Column():
                    transcript_box = gr.Textbox(label="Transcription")
                    speech_player = gr.Audio(label="Generated Speech", type="numpy")
            run_button.click(
                fn=pipeline.process_audio_file,
                inputs=[recording, picker],
                outputs=[transcript_box, speech_player],
            )

    with gr.Blocks(title="Voice Conversion Pipeline") as app:
        gr.Markdown("# Voice Conversion Pipeline")
        gr.Markdown("### Convert your voice to text and narrate it using a different voice")

        audio_tab(
            "Audio File Input",
            "Upload an audio file to be transcribed and narrated",
            "upload",
            "Upload Audio File",
            "Process Audio File",
        )
        audio_tab(
            "Microphone Input",
            "Use your microphone to record audio, which will be transcribed and narrated",
            "microphone",
            "Microphone Input",
            "Process Recording",
        )

        with gr.Tab("Text Input"):
            gr.Markdown("Enter text directly to be narrated")
            with gr.Row():
                with gr.Column():
                    typed_text = gr.Textbox(lines=5, label="Text Input")
                    text_picker = voice_picker()
                    speak_button = gr.Button("Generate Speech")
                with gr.Column():
                    echoed_text = gr.Textbox(label="Text")
                    spoken_audio = gr.Audio(label="Generated Speech", type="numpy")
            speak_button.click(
                fn=pipeline.process_text_input,
                inputs=[typed_text, text_picker],
                outputs=[echoed_text, spoken_audio],
            )

    return app


## Run App

In [2]:
# Build the UI (this loads both models), then serve it.
# share=True publishes a temporary public *.gradio.live URL: anyone who has
# the link can use this app. inline=False keeps it out of the cell output.
app = create_app()
app.launch(inline=False, share=True)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://67338680de83d6e769.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
