In [None]:
!pip install --upgrade --quiet speechrecognition
!pip install --upgrade --quiet accelerate
!pip install --upgrade --quiet bitsandbytes
!pip install --upgrade --quiet transformers

In [None]:
# PortAudio headers and the system PyAudio package are needed for microphone access.
# -y pre-answers apt's confirmation prompt so the cell cannot stall or abort
# when it is executed without an interactive stdin.
!sudo apt update
!sudo apt install -y portaudio19-dev python3-pyaudio


In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Speech-to-Text Medical Assistant using Google MedGemma
This application integrates speech recognition with MedGemma for medical Q&A.
"""

import os
import sys
import speech_recognition as sr
import pyaudio
import torch
from transformers import BitsAndBytesConfig, pipeline, AutoModelForCausalLM, AutoTokenizer
from IPython.display import Markdown, display
from PIL import Image
import warnings

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

class SpeechToTextMedGemma:
    """
    A comprehensive speech-to-text medical assistant using Google MedGemma.
    Supports both text-only and multimodal (image + text) interactions.
    """

    def __init__(self, model_variant="4b-it", use_quantization=True, is_thinking=False):
        """
        Record the configuration, open the audio handles and load the model.

        Args:
            model_variant (str): Variant suffix, "4b-it" or "27b-text-it";
                it is appended to "google/medgemma-" to form the model id.
            use_quantization (bool): Load the model with 4-bit quantization.
            is_thinking (bool): Enable thinking mode (27B variant only).
        """
        # Plain Colab is recognised by the already-imported `google.colab`
        # module; a set VERTEX_PRODUCT variable rules that environment out.
        vertex_product_set = bool(os.environ.get("VERTEX_PRODUCT"))
        self.google_colab = "google.colab" in sys.modules and not vertex_product_set

        # User-facing configuration
        self.is_thinking = is_thinking
        self.use_quantization = use_quantization
        self.model_variant = model_variant
        self.model_id = f"google/medgemma-{model_variant}"

        # Speech capture objects
        self.recognizer = sr.Recognizer()
        self.microphone = sr.Microphone()

        # Remaining setup steps, run in this fixed order
        for setup_step in (
            self._setup_authentication,
            self._install_dependencies,
            self._setup_model,
        ):
            setup_step()

    def _setup_authentication(self):
        """Make a Hugging Face token available, using the Colab secret when in Colab."""
        if self.google_colab:
            # Plain Colab: copy the notebook secret into the environment
            from google.colab import userdata
            os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
            return

        # Colab Enterprise keeps Hugging Face data under `/content`
        in_colab_enterprise = os.environ.get("VERTEX_PRODUCT") == "COLAB_ENTERPRISE"
        if in_colab_enterprise:
            os.environ["HF_HOME"] = "/content/hf"

        # Prompt for a login only when no token is stored yet
        from huggingface_hub import get_token
        if get_token() is not None:
            return
        from huggingface_hub import notebook_login
        notebook_login()

    def _install_dependencies(self):
        """
        Install or upgrade the Python packages this assistant relies on.

        pip is invoked through the interpreter that is running this code
        (``sys.executable -m pip``) so packages are installed for the active
        environment rather than for whichever ``pip`` is first on PATH.
        The arguments are passed as a list, so no shell is involved.

        Installation is best-effort: a package that fails to install produces
        a warning and the remaining packages are still attempted. Nothing is
        raised and nothing is returned.
        """
        import subprocess  # only needed by this method

        dependencies = [
            "speechrecognition",
            "pyaudio",
            "accelerate",
            "bitsandbytes",
            "transformers",
        ]

        for dep in dependencies:
            command = [sys.executable, "-m", "pip", "install", "--upgrade", "--quiet", dep]
            try:
                result = subprocess.run(command, check=False)
            except OSError as e:
                # The interpreter itself could not be launched
                print(f"Warning: Could not run pip for {dep}: {e}")
                continue

            if result.returncode != 0:
                print(f"Warning: pip install failed for {dep} (exit code {result.returncode})")

    def _setup_model(self):
        """Setup the MedGemma model following the official configuration"""
        print("Setting up MedGemma model...")

        # Check memory requirements for 27B model
        if "27b" in self.model_variant and self.google_colab:
            if not ("A100" in torch.cuda.get_device_name(0) and self.use_quantization):
                raise ValueError(
                    "Runtime has insufficient memory to run the 27B variant. "
                    "Please select an A100 GPU and use 4-bit quantization."
                )

        # Model configuration following official pattern
        self.model_kwargs = dict(
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        if self.use_quantization:
            self.model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)

        # Check if this is a text-only variant
        self.is_text_only = "text" in self.model_variant

        # Setup pipeline for text generation
        self.text_pipe = pipeline(
            "text-generation",
            model=self.model_id,
            model_kwargs=self.model_kwargs,
        )
        self.text_pipe.model.generation_config.do_sample = False

        # Setup multimodal pipeline if not text-only
        if not self.is_text_only:
            try:
                self.image_text_pipe = pipeline(
                    "image-text-to-text",
                    model=self.model_id,
                    model_kwargs=self.model_kwargs,
                )
                self.image_text_pipe.model.generation_config.do_sample = False
            except Exception as e:
                print(f"Warning: Could not setup multimodal pipeline: {e}")
                self.image_text_pipe = None

        print("Model setup complete!")

    def calibrate_microphone(self):
        """Calibrate microphone for ambient noise"""
        print("Calibrating microphone for ambient noise...")
        try:
            with self.microphone as source:
                self.recognizer.adjust_for_ambient_noise(source, duration=1)
            print("Microphone calibrated!")
        except Exception as e:
            print(f"Warning: Could not calibrate microphone: {e}")

    def listen_for_speech(self, timeout=10, phrase_time_limit=None):
        """
        Listen for speech input from microphone

        Args:
            timeout (int): Maximum seconds to wait for speech
            phrase_time_limit (int): Maximum seconds for a phrase

        Returns:
            str: Transcribed text or None if failed
        """
        try:
            print("🎤 Listening for speech...")
            with self.microphone as source:
                # Listen for audio with timeout
                audio = self.recognizer.listen(
                    source,
                    timeout=timeout,
                    phrase_time_limit=phrase_time_limit
                )

            print("🔄 Processing speech...")
            # Convert speech to text using Google's speech recognition
            text = self.recognizer.recognize_google(audio)
            print(f"📝 Transcribed: '{text}'")
            return text

        except sr.WaitTimeoutError:
            print("⏰ No speech detected within timeout period")
            return None
        except sr.UnknownValueError:
            print("❌ Could not understand the speech")
            return None
        except sr.RequestError as e:
            print(f"🚫 Error with speech recognition service: {e}")
            return None

    def generate_text_response(self, user_input, system_instruction="You are a helpful medical assistant."):
        """
        Generate response using MedGemma for text-only input

        Args:
            user_input (str): User's input text
            system_instruction (str): System instruction for the model

        Returns:
            str: Generated response
        """
        # Prepare system instruction for thinking mode (following official pattern)
        role_instruction = system_instruction
        if "27b" in self.model_variant and self.is_thinking:
            system_instruction = f"SYSTEM INSTRUCTION: think silently if needed. {role_instruction}"
            max_new_tokens = 1500
        else:
            system_instruction = role_instruction
            max_new_tokens = 500

        # Prepare messages following official format
        messages = [
            {
                "role": "system",
                "content": system_instruction
            },
            {
                "role": "user",
                "content": user_input
            }
        ]

        print("🧠 Generating response...")

        # Generate response using pipeline (following official pattern)
        output = self.text_pipe(messages, max_new_tokens=max_new_tokens)
        response = output[0]["generated_text"][-1]["content"]

        # Handle thinking mode output (following official pattern)
        thought = None
        if "27b" in self.model_variant and self.is_thinking and "<unused95>" in response:
            thought, response = response.split("<unused95>")
            thought = thought.replace("<unused94>thought\n", "")

        return response, thought

    def generate_multimodal_response(self, user_input, image_path, system_instruction="You are an expert radiologist."):
        """
        Generate response using MedGemma for image + text input

        Args:
            user_input (str): User's text input
            image_path (str): Path to the image file
            system_instruction (str): System instruction for the model

        Returns:
            str: Generated response
        """
        if self.is_text_only or self.image_text_pipe is None:
            return "This model variant does not support image inputs.", None

        try:
            # Load image
            image = Image.open(image_path)

            # Format conversation following official pattern
            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_instruction}]
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": user_input},
                        {"type": "image", "image": image}
                    ]
                }
            ]

            print("🧠 Generating multimodal response...")

            # Generate response using pipeline
            output = self.image_text_pipe(text=messages, max_new_tokens=300)
            response = output[0]["generated_text"][-1]["content"]

            return response, None

        except Exception as e:
            return f"Error processing image: {e}", None

    def run_interactive_session(self, enable_multimodal=False):
        """
        Run an interactive speech-to-text session

        Args:
            enable_multimodal (bool): Whether to enable image input support
        """
        print("🏥 Starting MedGemma Speech-to-Text Medical Assistant")
        print("=" * 60)
        print("Commands:")
        print("- Say 'quit' or 'exit' to end the session")
        if enable_multimodal and not self.is_text_only:
            print("- Type 'image' to upload an image for analysis")
        print("- Press Enter and start speaking for medical questions")
        print("=" * 60)

        # Calibrate microphone
        self.calibrate_microphone()

        while True:
            print(f"\n{'='*20} New Interaction {'='*20}")

            # Check for special commands
            user_command = input("🔹 Press Enter to speak, or type a command: ").strip().lower()

            if user_command in ['quit', 'exit']:
                print("👋 Goodbye! Stay healthy!")
                break

            if user_command == 'image' and enable_multimodal and not self.is_text_only:
                # Handle image input
                image_path = input("📁 Enter image path: ").strip()
                if not os.path.exists(image_path):
                    print("❌ Image file not found!")
                    continue

                # Get speech input for image description
                print("🎤 Now speak your question about the image...")
                transcribed_text = self.listen_for_speech(timeout=15, phrase_time_limit=15)

                if transcribed_text is None:
                    print("❌ No valid speech input received.")
                    continue

                if transcribed_text.lower() in ['quit', 'exit', 'stop']:
                    break

                # Generate multimodal response
                try:
                    response, thought = self.generate_multimodal_response(
                        transcribed_text,
                        image_path,
                        "You are an expert radiologist."
                    )

                    # Display results
                    self._display_results(transcribed_text, response, thought, image_path=image_path)

                except Exception as e:
                    print(f"❌ Error generating response: {e}")

                continue

            # Regular speech input
            transcribed_text = self.listen_for_speech(timeout=15, phrase_time_limit=15)

            if transcribed_text is None:
                print("❌ No valid speech input received. Try again.")
                continue

            # Check if user wants to quit
            if transcribed_text.lower() in ['quit', 'exit', 'stop']:
                print("👋 Goodbye! Stay healthy!")
                break

            # Generate text response
            try:
                response, thought = self.generate_text_response(
                    transcribed_text,
                    "You are a helpful medical assistant."
                )

                # Display results
                self._display_results(transcribed_text, response, thought)

            except Exception as e:
                print(f"❌ Error generating response: {e}")

    def _display_results(self, user_input, response, thought=None, image_path=None):
        """Display the conversation results in a formatted way"""
        print(f"\n🗣️  **User Speech Input:**")
        print(f"   '{user_input}'")

        if image_path:
            print(f"📷 **Image:** {image_path}")

        if thought:
            print(f"\n🤔 **MedGemma Thinking:**")
            print(f"   {thought}")

        print(f"\n🏥 **MedGemma Response:**")
        print("-" * 60)
        print(response)
        print("-" * 60)

    def process_single_speech_input(self, system_instruction="You are a helpful medical assistant."):
        """
        Process a single speech input and return response

        Args:
            system_instruction (str): System instruction for the model

        Returns:
            tuple: (transcribed_text, response, thought)
        """
        print("🎤 Speak now...")

        # Calibrate microphone if not done
        self.calibrate_microphone()

        # Listen for speech
        transcribed_text = self.listen_for_speech(timeout=15, phrase_time_limit=15)

        if transcribed_text is None:
            return None, None, None

        # Generate response
        response, thought = self.generate_text_response(transcribed_text, system_instruction)

        return transcribed_text, response, thought

def main():
    """
    Command-line style entry point for the assistant.

    Builds a SpeechToTextMedGemma instance, prints its configuration, asks
    the user to pick a mode and runs it. Failures while building the
    assistant and failures while running the chosen mode are reported
    separately, so a runtime problem is not mislabelled as a setup problem.
    Exceptions are printed, not re-raised.
    """
    print("🚀 Initializing MedGemma Speech-to-Text Medical Assistant...")
    print("Copyright 2025 Google LLC - Licensed under Apache License 2.0")
    print("-" * 60)

    try:
        # Initialize the system with configuration
        stt_medgemma = SpeechToTextMedGemma(
            model_variant="4b-it",  # Change to "27b-text-it" for text-only 27B model
            use_quantization=True,
            is_thinking=False  # Set to True for 27B variant with thinking mode
        )
    except Exception as e:
        print(f"❌ Error initializing system: {e}")
        print("Please check your setup and try again.")
        return

    try:
        # Display model information
        print(f"✅ Model loaded: {stt_medgemma.model_id}")
        print(f"✅ Text-only variant: {stt_medgemma.is_text_only}")
        print(f"✅ Quantization enabled: {stt_medgemma.use_quantization}")
        print(f"✅ Thinking mode: {stt_medgemma.is_thinking}")

        # Choose execution mode
        print("\nSelect mode:")
        print("1. Interactive session (recommended)")
        print("2. Single question mode")

        choice = input("Enter choice (1 or 2): ").strip()

        if choice == "1":
            # Image support is offered whenever the variant is not text-only
            enable_multimodal = not stt_medgemma.is_text_only
            stt_medgemma.run_interactive_session(enable_multimodal=enable_multimodal)

        elif choice == "2":
            # Single input mode
            transcribed_text, response, thought = stt_medgemma.process_single_speech_input()
            if transcribed_text:
                stt_medgemma._display_results(transcribed_text, response, thought)
            else:
                print("❌ No speech input received.")

        else:
            print("❌ Invalid choice. Please run again.")

    except Exception as e:
        # The model loaded fine; this is a failure of the running session
        print(f"❌ Error during session: {e}")

if __name__ == "__main__":
    main()