<a href="https://colab.research.google.com/github/Codeblockz/gemma-fastRTC-Example/blob/main/GemmaFastRTC_Complete.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Gemma Voice Assistant with FastRTC

This notebook implements a voice-based conversational AI assistant using:
- Gemma 3 model from Hugging Face for text generation
- FastRTC for real-time audio communication
- Speech-to-text and text-to-speech capabilities
- Conversation history tracking

## 1. Installation

First, let's install the required packages:

In [1]:
!pip install torch "fastrtc[vad,stt,tts]" python-dotenv gradio twilio accelerate
# Install the transformers build that adds Gemma 3 support
!pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3


## 2. Imports and Configuration

In [1]:
import os
import base64
import json
import logging
from pathlib import Path
import numpy as np
from typing import List, Dict, Any, Tuple, Optional, Union, Generator
import asyncio
import time

import gradio as gr
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel
from fastrtc import (
    AdditionalOutputs,
    ReplyOnStopWords,
    Stream,
    get_stt_model,
    get_tts_model,
    get_twilio_turn_credentials,
)
from gradio.utils import get_space
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

## 3. Environment Variables

In [2]:
# Set environment variables from Colab secrets (the token is read, not hard-coded)
from google.colab import userdata
os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')  # Hugging Face token with access to Gemma
os.environ["USE_CUDA"] = "true"  # Set to "false" if you don't have a GPU
os.environ["USE_TWILIO"] = "true"
os.environ["TWILIO_ACCOUNT_SID"] = userdata.get('TWILIO_ACCOUNT_SID')
os.environ["TWILIO_AUTH_TOKEN"] = userdata.get('TWILIO_AUTH_TOKEN')
os.environ["SPEAKER_ID"] = "0"
os.environ["MODE"] = "UI"  # Options: "UI" for Gradio, "WEB" for FastAPI web interface

# Load environment variables
load_dotenv()

# Constants
SPEAKER_ID = int(os.environ.get("SPEAKER_ID", "0"))
MAX_CONTEXT_TURNS = 10
MODEL_ID = "google/gemma-3-1b-it"
DEVICE = "cuda" if os.environ.get("USE_CUDA", "true").lower() == "true" else "cpu"
MAX_LENGTH = 256

# Get current directory
curr_dir = Path.cwd()
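
The `DEVICE` constant above treats only the exact string "true" (case-insensitive) as enabled. A minimal sketch of a slightly more forgiving parser follows; `env_flag` and the `TRUTHY` set are hypothetical names introduced here for illustration, not part of the notebook.

```python
import os

# Values treated as "enabled" (an assumption for this sketch)
TRUTHY = {"1", "true", "yes", "on"}


def env_flag(name: str, default: bool = False) -> bool:
    """Read a boolean flag from the environment, falling back to default."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in TRUTHY


os.environ["USE_CUDA"] = "True"
device = "cuda" if env_flag("USE_CUDA", default=True) else "cpu"
print(device)
```

With this helper, values such as "1" or "yes" would also select the GPU, while any other value selects the CPU.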

## 4. Gemma Model Implementation

In [3]:
class GemmaModel:
    """Class to handle Gemma 3 model interactions."""

    def __init__(self, model_id: str = MODEL_ID, device: str = DEVICE):
        self.model_id = model_id
        self.device = device
        self.max_length = MAX_LENGTH

        # Check if HF_TOKEN is set
        if not os.environ.get("HF_TOKEN"):
            logging.warning("HF_TOKEN not found in environment variables. Model loading may fail.")

        try:
            logging.info(f"Loading tokenizer for {model_id}...")
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)

            logging.info(f"Loading model {model_id} to {device}...")
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map=device
            )
            logging.info("Model and tokenizer loaded successfully")
        except Exception as e:
            logging.error(f"Error loading model: {e}")
            raise

    def generate_response(self, user_input: str, conversation_history: Optional[List[Dict[str, str]]] = None) -> str:
        """Generate a reply to user_input given the prior turns (which must not include user_input)."""
        try:
            # Copy the history so the caller's list is not mutated, then add the new user message
            messages = list(conversation_history or [])
            messages.append({"role": "user", "content": user_input})

            # Format messages for the model
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(self.device)

            # Set generation config
            generation_config = GenerationConfig(
                max_new_tokens=self.max_length,
                temperature=0.7,
                top_p=0.9,
                do_sample=True
            )

            # Generate response
            outputs = self.model.generate(
                input_ids,
                generation_config=generation_config
            )

            # Decode only the newly generated tokens. Searching the full transcript for
            # the first "<start_of_turn>model" marker would return an earlier assistant
            # turn whenever the history already contains one.
            new_tokens = outputs[0][input_ids.shape[-1]:]
            decoded_output = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

            # Drop anything after an end-of-turn marker in case it survived decoding
            return decoded_output.split("<end_of_turn>")[0].strip()

        except Exception as e:
            logging.error(f"Error generating response: {e}")
            return "I'm sorry, I encountered an error generating a response."
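
When a decoded transcript contains several turns, the reply to the latest user message is the last model turn, not the first. A minimal sketch of that extraction, assuming the `<start_of_turn>model` and `<end_of_turn>` markers used above; `extract_last_model_turn` is a hypothetical helper, not part of the notebook or of transformers.

```python
START_OF_MODEL_TURN = "<start_of_turn>model"
END_OF_TURN = "<end_of_turn>"


def extract_last_model_turn(decoded: str) -> str:
    """Return the text of the last model turn in a decoded transcript."""
    # rfind locates the final model turn, which holds the newest reply
    start = decoded.rfind(START_OF_MODEL_TURN)
    if start == -1:
        # No marker found: fall back to the whole string
        return decoded.strip()
    start += len(START_OF_MODEL_TURN)
    end = decoded.find(END_OF_TURN, start)
    reply = decoded[start:] if end == -1 else decoded[start:end]
    return reply.strip()


transcript = (
    "<bos><start_of_turn>user\nHi<end_of_turn>\n"
    "<start_of_turn>model\nHello!<end_of_turn>\n"
    "<start_of_turn>user\nWhat is 2 + 2?<end_of_turn>\n"
    "<start_of_turn>model\n4<end_of_turn>"
)
print(extract_last_model_turn(transcript))
```

Using `find` instead of `rfind` on this transcript would return "Hello!", the reply from the earlier turn.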

## 5. Audio Processing Implementation

In [None]:
class AudioProcessor:
    """Class to handle speech-to-text and text-to-speech operations."""

    def __init__(self):
        try:
            logging.info("Loading STT model...")
            self.stt_model = get_stt_model()
            logging.info("STT model loaded successfully")

            logging.info("Loading TTS model...")
            self.tts_model = get_tts_model()
            logging.info("TTS model loaded successfully")
        except Exception as e:
            logging.error(f"Error loading audio models: {e}")
            raise

    def speech_to_text(self, audio: Tuple[int, np.ndarray]) -> str:
        try:
            logging.info("Converting speech to text...")
            sample_rate, audio_array = audio

            # Ensure audio is in the correct format
            if len(audio_array.shape) > 1:
                audio_array = audio_array.squeeze()

            # Convert to text
            text = self.stt_model.stt((sample_rate, audio_array))
            logging.info(f"Speech transcribed: '{text}'")
            return text
        except Exception as e:
            logging.error(f"Error in speech to text conversion: {e}")
            return ""

    def text_to_speech(self, text: str, speaker_id: int = SPEAKER_ID) -> Tuple[int, np.ndarray]:
        try:
            logging.info(f"Converting text to speech: '{text}'")

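            # Generate speech. FastRTC's tts() accepts an optional `options` object
            # rather than a speaker_id keyword, so speaker_id is not forwarded here.
            sample_rate, audio_array = self.tts_model.tts(text)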
            logging.info(f"Text converted to speech. Audio length: {len(audio_array) / sample_rate:.2f}s")

            return sample_rate, audio_array
        except Exception as e:
            logging.error(f"Error in text to speech conversion: {e}")
            # Return a silent audio segment as fallback
            return 16000, np.zeros(16000, dtype=np.float32)

## 6. FastAPI Implementation

In [4]:
# Create a simple HTML template (simplified version)
html_template = """
<!DOCTYPE html>
<html>
<head>
    <title>Gemma Voice Assistant</title>
    <style>
        body { font-family: sans-serif; margin: 0; padding: 20px; }
        .chat-container { height: 400px; overflow-y: auto; border: 1px solid #ddd; margin-bottom: 20px; padding: 10px; }
        .message { margin-bottom: 10px; padding: 8px; border-radius: 5px; }
        .user { background-color: #e3e9f2; margin-left: auto; max-width: 80%; }
        .assistant { background-color: #4a6fa5; color: white; max-width: 80%; }
        .controls { display: flex; flex-direction: column; gap: 10px; }
        .btn { padding: 10px; border: none; border-radius: 5px; background-color: #6c5ce7; color: white; cursor: pointer; }
    </style>
</head>
<body>
    <h1>Gemma Voice Assistant</h1>
    <div class="chat-container" id="chat"></div>
    <div class="controls">
        <button id="mic-btn" class="btn">Start Recording</button>
        <button id="clear-btn" class="btn">Clear Conversation</button>
    </div>
    <p id="status">Ready</p>
    <script>
        // WebRTC configuration
        const rtcConfiguration = __RTC_CONFIGURATION__;

        // Basic functionality - full implementation in the actual HTML file
        const chatContainer = document.getElementById('chat');
        const micButton = document.getElementById('mic-btn');
        const clearButton = document.getElementById('clear-btn');
        const statusElement = document.getElementById('status');

        let webrtcId = null;
        let isRecording = false;
        let chatHistory = [];
        let conversationState = [];

        // Event listeners and WebRTC setup would be implemented here
    </script>
</body>
</html>
"""

# Define Pydantic models for API
class Message(BaseModel):
    role: str
    content: str

class InputData(BaseModel):
    webrtc_id: str
    chatbot: List[Message]
    state: List[Message]
    stop: bool = False

## 7. Main Application Setup

In [5]:
# Cache the Gemma model so it is loaded once rather than on every turn
_gemma_model: Optional[GemmaModel] = None

def get_gemma_model() -> GemmaModel:
    global _gemma_model
    if _gemma_model is None:
        _gemma_model = GemmaModel(device=DEVICE)
    return _gemma_model

# Define response handler
def response_handler(
    audio: Tuple[int, np.ndarray],
    gradio_chatbot: Optional[List[Dict]] = None,
    conversation_state: Optional[List[Dict]] = None,
):
    # Initialize if None
    gradio_chatbot = gradio_chatbot or []
    conversation_state = conversation_state or []

    # Process audio input (speech to text)
    stt_model = get_stt_model()
    text = stt_model.stt(audio)
    logging.info(f"STT in handler: '{text}'")

    # Add user message to UI
    sample_rate, array = audio
    gradio_chatbot.append(
        {"role": "user", "content": gr.Audio((sample_rate, array.squeeze()))}
    )

    # First yield to update UI with user's message
    yield AdditionalOutputs(gradio_chatbot, conversation_state)

    # Limit the history passed to the model. generate_response appends the new
    # user message itself, so it must not be part of the history yet.
    history = conversation_state[-MAX_CONTEXT_TURNS * 2:]

    # Generate response from Gemma
    response_text = get_gemma_model().generate_response(
        text,
        conversation_history=history
    )

    # Record both sides of the turn in the conversation state
    assistant_response = {"role": "assistant", "content": response_text}
    conversation_state = history + [{"role": "user", "content": text}, assistant_response]
    gradio_chatbot.append(assistant_response)

    # Final yield with complete response
    yield AdditionalOutputs(gradio_chatbot, conversation_state)

# Helper function for API
def audio_to_base64(file_path):
    audio_format = "wav"
    with open(file_path, "rb") as audio_file:
        encoded_audio = base64.b64encode(audio_file.read()).decode("utf-8")
    return f"data:audio/{audio_format};base64,{encoded_audio}"
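
To see what `audio_to_base64` produces without a recorded file, the sketch below writes a short silent WAV with the standard-library `wave` module and encodes it as a data URI. `write_silent_wav` and `wav_to_data_uri` are hypothetical names for this illustration; the encoding step mirrors the helper above.

```python
import base64
import os
import tempfile
import wave


def write_silent_wav(path: str, sample_rate: int = 16000, seconds: float = 0.1) -> None:
    """Write a mono 16-bit PCM WAV file containing silence."""
    n_frames = int(sample_rate * seconds)
    with wave.open(path, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 2 bytes per sample = 16-bit audio
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(b"\x00\x00" * n_frames)


def wav_to_data_uri(path: str) -> str:
    """Encode a WAV file as a base64 data URI that a browser can play."""
    with open(path, "rb") as audio_file:
        encoded = base64.b64encode(audio_file.read()).decode("utf-8")
    return f"data:audio/wav;base64,{encoded}"


with tempfile.TemporaryDirectory() as tmp_dir:
    wav_path = os.path.join(tmp_dir, "silence.wav")
    write_silent_wav(wav_path)
    data_uri = wav_to_data_uri(wav_path)

print(data_uri[:40])
```

A browser can assign such a URI directly to the `src` of an `<audio>` element, which is why the helper returns a string rather than raw bytes.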

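The next cell redefines `response_handler`, and this later definition replaces the one above. It relies on module-level `stt_model`, `gemma_model`, and `tts_model` objects and streams the spoken reply back to the caller.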

In [10]:
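def response_handler(
    audio: Tuple[int, np.ndarray],
    gradio_chatbot: List[Dict] | None = None,
    conversation_state: List[Dict] | None = None,
):
    """Transcribe one utterance, reply with Gemma, and stream the spoken reply."""
    # Initialize if None
    gradio_chatbot = gradio_chatbot or []
    conversation_state = conversation_state or []

    # Convert speech to text. stt_model, gemma_model and tts_model are
    # module-level globals that must exist before the stream starts.
    text = stt_model.stt(audio)
    if not text.strip():
        return

    # generate_response appends the user message itself, so pass only the
    # prior turns, limited to the most recent MAX_CONTEXT_TURNS exchanges
    history = conversation_state[-MAX_CONTEXT_TURNS * 2:]
    response_text = gemma_model.generate_response(
        text,
        conversation_history=history
    )

    # Record both sides of the turn so the next call sees alternating roles
    user_message = {"role": "user", "content": text}
    assistant_response = {"role": "assistant", "content": response_text}
    conversation_state.extend([user_message, assistant_response])
    gradio_chatbot.extend([user_message, assistant_response])
    yield AdditionalOutputs(gradio_chatbot, conversation_state)

    # Stream the spoken reply chunk by chunk
    for audio_chunk in tts_model.stream_tts_sync(response_text or ""):
        yield audio_chunk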
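
The handlers above carry the conversation in `conversation_state`, and `MAX_CONTEXT_TURNS` is meant to cap how much of it reaches the model. A minimal sketch of trimming that also keeps the first retained message a user message, on the assumption that the chat template expects roles to alternate starting with the user; `trim_history` is a hypothetical helper, not part of the notebook.

```python
from typing import Dict, List


def trim_history(history: List[Dict[str, str]], max_turns: int) -> List[Dict[str, str]]:
    """Keep at most max_turns user/assistant pairs, starting on a user message."""
    # A slice of [-0:] would return the whole list, so handle zero explicitly
    trimmed = history[-max_turns * 2:] if max_turns > 0 else []
    # Drop leading non-user messages so the roles still alternate
    while trimmed and trimmed[0]["role"] != "user":
        trimmed = trimmed[1:]
    return trimmed


history = [
    {"role": "user", "content": "one"},
    {"role": "assistant", "content": "1"},
    {"role": "user", "content": "two"},
    {"role": "assistant", "content": "2"},
    {"role": "user", "content": "three"},
]
print([m["content"] for m in trim_history(history, max_turns=1)])
```

Here the naive slice would begin with the assistant message "2"; the loop drops it so that only the pending user message remains.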



## 8. FastAPI App Setup

In [7]:
def setup_fastapi_app():
    # Setup FastAPI app
    app = FastAPI()

    # Setup FastRTC stream. ReplyOnStopWords calls the handler only after one of
    # the stop words has been heard. The hidden Textboxes stand in for the
    # chatbot and state values that the handler receives as extra inputs.
    stream = Stream(
        ReplyOnStopWords(
            response_handler,
            stop_words=["stop", "goodbye"],
            input_sample_rate=16000,
        ),
        mode="send-receive",
        modality="audio",
        additional_inputs=[gr.Textbox(visible=False), gr.Textbox(visible=False)],
        additional_outputs=[gr.Textbox(visible=False), gr.Textbox(visible=False)],
        additional_outputs_handler=lambda *a: (a[2], a[3]),
        concurrency_limit=5 if get_space() else None,
        time_limit=90 if get_space() else None,
        rtc_configuration=get_twilio_turn_credentials(),
    )


    # Mount the stream
    stream.mount(app)

    # Define API endpoints
    @app.get("/")
    async def root():
        rtc_config = get_twilio_turn_credentials() if get_space() else None
        html_content = html_template.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
        return HTMLResponse(content=html_content)

    @app.post("/input_hook")
    async def input_hook(data: InputData):
        body = data.model_dump()
        stream.set_input(data.webrtc_id, body["chatbot"], body["state"])

    @app.get("/outputs")
    async def outputs(webrtc_id: str):
        async def output_stream():
            async for output in stream.output_stream(webrtc_id):
                chatbot = output.args[0]
                state = output.args[1]

                if not chatbot or not state or not state[-1]:
                    continue

                data = {
                    "message": state[-1],
                    "audio": audio_to_base64(chatbot[-1]["content"].value["path"])
                    if (chatbot and chatbot[-1]["role"] == "user" and
                        isinstance(chatbot[-1]["content"], gr.Audio) and
                        hasattr(chatbot[-1]["content"], "value"))
                    else None,
                }
                yield f"event: output\ndata: {json.dumps(data)}\n\n"

        return StreamingResponse(output_stream(), media_type="text/event-stream")

    return app, stream
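
The `/outputs` endpoint emits server-sent events of the form `event: output` followed by a JSON `data:` line. A minimal sketch of how a client might parse that text once it has been received; `parse_sse` is a hypothetical helper, and the sample string only imitates the payload shape built above.

```python
import json
from typing import Any, Dict, Iterator, Tuple


def parse_sse(stream_text: str) -> Iterator[Tuple[str, Dict[str, Any]]]:
    """Yield (event_name, payload) pairs from a text/event-stream body."""
    # Events are separated by a blank line
    for block in stream_text.split("\n\n"):
        event_name = "message"  # Default event name when none is given
        data_lines = []
        for line in block.splitlines():
            if line.startswith("event:"):
                event_name = line[len("event:"):].strip()
            elif line.startswith("data:"):
                data_lines.append(line[len("data:"):].strip())
        if data_lines:
            yield event_name, json.loads("\n".join(data_lines))


sample = (
    'event: output\n'
    'data: {"message": {"role": "assistant", "content": "Hi"}, "audio": null}\n\n'
)
for name, payload in parse_sse(sample):
    print(name, payload["message"]["content"])
```

In the browser the same job is normally done by `EventSource`, so a parser like this is mainly useful for testing the endpoint from Python.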

## 9. Launch Application

In [8]:
def launch_app():
    import uvicorn

    # response_handler looks these up as module-level names, so they must be
    # globals rather than locals of launch_app
    global stt_model, tts_model, gemma_model
    stt_model = get_stt_model()
    tts_model = get_tts_model()
    gemma_model = GemmaModel(device=DEVICE)

    app, stream = setup_fastapi_app()
    mode = os.environ.get("MODE", "UI")

    if mode == "UI":
        logging.info("Starting FastRTC with Gradio UI")
        stream.ui.launch(server_port=7860)
    else:
        logging.info("Starting FastRTC with FastAPI")
        # A notebook already runs an event loop; nest_asyncio lets uvicorn start inside it
        import nest_asyncio
        nest_asyncio.apply()
        uvicorn.run(app, host="0.0.0.0", port=7860)

# Uncomment the line below to launch the application
#launch_app()

In [11]:
launch_app()

OSError: Cannot find empty port in range: 7860-7860. You can specify a different port by setting the GRADIO_SERVER_PORT environment variable or passing the `server_port` parameter to `launch()`.
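
The error above means port 7860 was already taken, for example by an earlier launch in the same runtime. A minimal sketch of picking a free port with the standard library before launching; `find_free_port` is a hypothetical helper and the 7860-7870 range is an arbitrary choice for illustration.

```python
import socket


def find_free_port(start: int = 7860, end: int = 7870) -> int:
    """Return the first port in [start, end] that can be bound on localhost."""
    for port in range(start, end + 1):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind(("127.0.0.1", port))
            except OSError:
                continue  # Port is in use, try the next one
            return port
    raise OSError(f"No free port in range {start}-{end}")


port = find_free_port()
print(port)
```

The result could then be passed as `server_port` to `launch()`, as the error message suggests. Another process may still claim the port between the check and the launch, so this reduces the chance of a conflict rather than ruling it out.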

## 10. Usage Example

In [None]:
# Example: Create and test a Gemma model instance
def test_gemma_model():
    # Read the Hugging Face token from Colab secrets
    os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')

    # Use the configured device ("cuda", or "cpu" when USE_CUDA is "false")
    test_device = DEVICE

    print("Initializing Gemma model...")
    model = GemmaModel(device=test_device)

    test_input = "Hello, how are you today?"
    print(f"\nTest input: '{test_input}'")

    print("Generating response...")
    response = model.generate_response(test_input)

    print(f"\nGemma response: '{response}'")

    return response

# Run the Gemma model test
test_gemma_model()

Initializing Gemma model...

Test input: 'Hello, how are you today?'
Generating response...

Gemma response: 'Hello there! I’m doing well, thank you for asking! As an AI, I don’t really experience feelings like humans do, but I’m functioning perfectly and ready to help you with whatever you need. 

How are *you* doing today? 😊 

Do you want to chat about something specific, or perhaps need some help with a task?'



## How to Use

1. Add `HF_TOKEN`, `TWILIO_ACCOUNT_SID`, and `TWILIO_AUTH_TOKEN` to your Colab secrets; the environment variables cell reads them from there
2. Run all cells up to and including the one that defines `launch_app()`
3. Run `launch_app()` to start the application
4. Access the interface in your browser at http://localhost:7860 (the port must not already be in use)

You can choose between two interfaces by setting the `MODE` environment variable before launching:
- "UI" for the Gradio interface
- "WEB" for the FastAPI web interface