<a href="https://colab.research.google.com/github/Vishwam2609/Voice-based-Medical-Support-using-LLM/blob/Patient-Symptom-Checker/Final_13_03.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Install required libraries quietly
!pip install transformers -q
!pip install pydub -q
!pip install ipywidgets -q
!pip install nltk -q
!pip install torchaudio -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m40.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m65.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m34.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m39.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [1]:
# =============================
# CONFIGURATION & IMPORTS
# =============================

import os
import re
import logging
from base64 import b64decode
from io import BytesIO

import torch
import torchaudio
import ipywidgets as widgets
from IPython.display import Javascript, display, clear_output, Audio
from google.colab import output
from pydub import AudioSegment
from pydub.effects import normalize

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("Starting Symptom Guide Application")

# Suppress transformers logging
logging.getLogger("transformers").setLevel(logging.ERROR)

# Load configuration parameters
HUGGING_FACE_TOKEN = os.getenv("HUGGING_FACE_TOKEN", "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")
LLM_MODEL_NAME = os.getenv("LLM_MODEL_NAME", "ContactDoctor/Bio-Medical-Llama-3-2-1B-CoT-012025")
WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL_NAME", "openai/whisper-small")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# =============================
# AUDIO HANDLING MODULE
# =============================

class AudioHandler:
    """Handles audio recording and transcription using Whisper."""
    def __init__(self):
        self.feature_extractor = None
        self.tokenizer = None
        self.asr_model = None

    def load_model(self):
        if self.asr_model is None:
            try:
                logger.info("Loading Whisper model for transcription")
                self.feature_extractor = WhisperFeatureExtractor.from_pretrained(WHISPER_MODEL_NAME)
                self.tokenizer = WhisperTokenizer.from_pretrained(WHISPER_MODEL_NAME)
                self.asr_model = WhisperForConditionalGeneration.from_pretrained(
                    WHISPER_MODEL_NAME,
                    torch_dtype=torch.float16 if DEVICE=="cuda" else torch.float32
                ).to(DEVICE)
                self.tokenizer.pad_token = self.tokenizer.eos_token
                logger.info("Whisper model loaded")
            except Exception as e:
                logger.error("Error loading Whisper model", exc_info=True)
                raise e

    def transcribe_audio(self, audio_file):
        if self.asr_model is None:
            self.load_model()
        try:
            logger.info(f"Transcribing audio file: {audio_file}")
            waveform, sample_rate = torchaudio.load(audio_file)
            if sample_rate != 16000:
                waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
            input_features = self.feature_extractor(
                waveform.squeeze().numpy(),
                sampling_rate=16000,
                return_tensors="pt"
            ).input_features.to(DEVICE)
            if DEVICE=="cuda":
                input_features = input_features.half()
            predicted_ids = self.asr_model.generate(input_features, language='en')
            transcription = self.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            logger.info("Transcription successful")
            return transcription
        except Exception as e:
            logger.error("Transcription error", exc_info=True)
            return f"Error: Could not transcribe audio. {e}"

# =============================
# LLM HANDLING MODULE
# =============================

class LLMHandler:
    """Handles interactions with the LLM for generating guidance."""
    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.pipeline = None

    def load_model(self):
        if self.pipeline is None:
            try:
                logger.info("Loading LLM model")
                self.tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME, token=HUGGING_FACE_TOKEN)
                self.model = AutoModelForCausalLM.from_pretrained(
                    LLM_MODEL_NAME,
                    token=HUGGING_FACE_TOKEN,
                    torch_dtype=torch.float16 if DEVICE=="cuda" else torch.float32
                )
                self.pipeline = pipeline(
                    "text-generation",
                    model=self.model,
                    tokenizer=self.tokenizer,
                    device=0 if DEVICE=="cuda" else -1
                )
                logger.info("LLM model loaded")
            except Exception as e:
                logger.error("Error loading LLM model", exc_info=True)
                raise e

    def generate_text(self, prompt, max_new_tokens, num_beams, temperature, repetition_penalty, early_stopping=True):
        if self.pipeline is None:
            self.load_model()
        try:
            logger.info("Generating text from LLM")
            result = self.pipeline(
                prompt,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams,
                early_stopping=early_stopping,
                temperature=temperature,
                repetition_penalty=repetition_penalty
            )[0]['generated_text']
            logger.info("Text generation complete")
            return result
        except Exception as e:
            logger.error("LLM generation error", exc_info=True)
            return ""

    def generate_causes(self, detailed_context):
        prompt = (
            "You are a compassionate doctor. Analyze the following patient information and provide three distinct, original possible causes for the condition. "
            "Your causes should be based on medical reasoning and not simply repeat the patient's symptoms.\n"
            "Format your answer as follows (include only your answer):\n"
            "Cause 1: <your answer>\n"
            "Cause 2: <your answer>\n"
            "Cause 3: <your answer>\n\n"
            f"Patient Information:\n{detailed_context}\n\nAnswer:\n"
        )
        for _ in range(5):
            result = self.generate_text(prompt, 125, 5, 0.7, 1.2)
            if "Answer:" in result:
                result = result.split("Answer:", 1)[1].strip()
            causes = [line.split(":",1)[1].strip() for line in result.splitlines()
                      if line.lower().startswith("cause") and len(line.split(":",1)) > 1 and line.split(":",1)[1].strip()]
            if len(causes) == 3:
                return causes
        return causes

    def generate_help(self, detailed_context):
        prompt = (
            "You are a compassionate doctor. Analyze the following patient information and provide three distinct, original warnings for when the patient should seek immediate medical help. "
            "Your warnings should be based on medical reasoning and not simply repeat the patient's symptoms.\n"
            "Format your answer as follows (include only your answer):\n"
            "Warning 1: <your answer>\n"
            "Warning 2: <your answer>\n"
            "Warning 3: <your answer>\n\n"
            f"Patient Information:\n{detailed_context}\n\nAnswer:\n"
        )
        for _ in range(5):
            result = self.generate_text(prompt, 125, 5, 0.7, 1.2)
            if "Answer:" in result:
                result = result.split("Answer:", 1)[1].strip()
            warnings = [line.split(":",1)[1].strip() for line in result.splitlines()
                        if line.lower().startswith("warning") and len(line.split(":",1)) > 1 and line.split(":",1)[1].strip()]
            if len(warnings) == 3:
                return warnings
        return warnings

    def generate_tips(self, detailed_context):
        prompt = (
            "You are a compassionate doctor. Analyze the following patient information and provide three distinct, original tips to help the patient feel better. "
            "Your tips should be based on medical reasoning and not simply repeat the patient's symptoms.\n"
            "Format your answer as follows (include only your answer):\n"
            "Tip 1: <your answer>\n"
            "Tip 2: <your answer>\n"
            "Tip 3: <your answer>\n\n"
            f"Patient Information:\n{detailed_context}\n\nAnswer:\n"
        )
        for _ in range(5):
            result = self.generate_text(prompt, 125, 5, 0.7, 1.2)
            # Debug print can help you see the raw output if needed:
            # print("Raw result:", result)
            if "Answer:" in result:
                result = result.split("Answer:", 1)[1].strip()
            # Use regex to capture any line that starts with "Tip" followed by a digit and a colon
            tips = re.findall(r"Tip\s*\d:\s*(.+)", result)
            # If tips were output as bullet points instead, try an alternative extraction
            if not tips:
                tips = re.findall(r"-\s*(.+)", result)
            # Remove stray asterisks from each tip
            tips = [tip.replace("**", "").strip() for tip in tips]
            if len(tips) == 3:
                return tips
        return tips

    def generate_seriousness(self, detailed_context):
        prompt = (
            "You are a compassionate doctor. Analyze the following patient information and provide three distinct, original statements assessing the seriousness of the patient's condition. "
            "Your statements should be based on medical reasoning and not simply repeat the patient's symptoms.\n"
            "Format your answer as follows (include only your answer):\n"
            "Seriousness 1: <your answer>\n"
            "Seriousness 2: <your answer>\n"
            "Seriousness 3: <your answer>\n\n"
            f"Patient Information:\n{detailed_context}\n\nAnswer:\n"
        )
        for _ in range(5):
            result = self.generate_text(prompt, 125, 5, 0.7, 1.2)
            if "Answer:" in result:
                result = result.split("Answer:", 1)[1].strip()
            seriousness = [line.split(":",1)[1].strip() for line in result.splitlines()
                           if line.lower().startswith("seriousness") and len(line.split(":",1)) > 1 and line.split(":",1)[1].strip()]
            if len(seriousness) == 3:
                return seriousness
        return seriousness

    def generate_concise_guideline(self, detailed_context):
        prompt = (
            "You are a compassionate doctor. Analyze the following patient information and provide three distinct, original home care guidelines for the patient. "
            "Your guidelines must be based on medical reasoning and should not merely repeat the patient's symptoms.\n\n"
            "IMPORTANT: Do not include any extra headings, bullet points, markdown formatting, or commentary. "
            "Only output exactly three lines in the following format:\n"
            "Guideline 1: <your answer>\n"
            "Guideline 2: <your answer>\n"
            "Guideline 3: <your answer>\n\n"
            f"Patient Information:\n{detailed_context}\n\nAnswer:\n"
        )
        for _ in range(5):
            result = self.generate_text(prompt, 125, 5, 0.7, 1.0)
            # Try splitting based on our expected output header if it appears
            if "Answer:" in result:
                result = result.split("Answer:", 1)[1].strip()
            # First, try to extract using the expected format:
            guidelines = re.findall(r"Guideline\s*1:\s*(.+)\nGuideline\s*2:\s*(.+)\nGuideline\s*3:\s*(.+)", result)
            if guidelines:
                return list(guidelines[0])
            # Fallback: check for lines starting with "Guideline" separately:
            guidelines = re.findall(r"Guideline\s*\d:\s*(.+)", result)
            guidelines = [g.replace("**", "").strip() for g in guidelines]
            if len(guidelines) == 3:
                return guidelines
        return guidelines

    def generate_full_response(self, detailed_context):
        causes = self.generate_causes(detailed_context)
        tips = self.generate_tips(detailed_context)
        seriousness = self.generate_seriousness(detailed_context)
        help_signs = self.generate_help(detailed_context)
        return causes, tips, seriousness, help_signs

# =============================
# UI HANDLING MODULE
# =============================

class UIHandler:
    """Handles UI interactions and audio recording."""
    def __init__(self, audio_handler, llm_handler, symptom_guide):
        self.audio_handler = audio_handler
        self.llm_handler = llm_handler
        self.symptom_guide = symptom_guide
        self.inject_js()

    def inject_js(self):
        RECORD_JS = """
        window.audioRecorder = {
          recorder: null,
          audioChunks: [],
          start: async function() {
            const stream = await navigator.mediaDevices.getUserMedia({audio:true});
            this.recorder = new MediaRecorder(stream);
            this.audioChunks = [];
            this.recorder.ondataavailable = event => {
              this.audioChunks.push(event.data);
            };
            this.recorder.start();
            return "Recording started";
          },
          stop: async function() {
            return new Promise(resolve => {
              this.recorder.onstop = () => {
                let blob = new Blob(this.audioChunks, { type: 'audio/wav' });
                this.audioChunks = [];
                const reader = new FileReader();
                reader.readAsDataURL(blob);
                reader.onloadend = function() {
                  let base64data = reader.result;
                  resolve(base64data);
                };
              };
              this.recorder.stop();
            });
          }
        };
        """
        display(Javascript(RECORD_JS))

    def start_recording(self):
        clear_output(wait=True)
        self.inject_js()
        output.eval_js('window.audioRecorder.start()')
        logger.info("Recording started")
        stop_button = widgets.Button(description="Stop Recording")
        stop_button.on_click(self.stop_recording)
        display(stop_button)

    def stop_recording(self, b):
        try:
            logger.info("Stopping recording")
            audio_data = output.eval_js('window.audioRecorder.stop()')
            binary = b64decode(audio_data.split(',')[1])
            audio = AudioSegment.from_file(BytesIO(binary))
            audio = normalize(audio)
            file_path = "patient_input.wav"
            audio.export(file_path, format="wav")
            logger.info(f"Audio exported to {file_path}")
            clear_output(wait=True)
            display(widgets.HTML("<h3>Review your recording:</h3>"))
            display(Audio(file_path))
            transcription = self.audio_handler.transcribe_audio(file_path)
            self.symptom_guide.patient_symptom_text = transcription
            correction_input = widgets.Textarea(value=transcription, description='Corrected Symptoms:')
            confirm_button = widgets.Button(description="Confirm Correction")
            retry_button = widgets.Button(description="Retry Recording")
            button_box = widgets.HBox([confirm_button, retry_button],
                                      layout=widgets.Layout(justify_content='space-between', width='50%'))
            confirm_button.on_click(lambda x: self.symptom_guide.process_corrected_symptoms(correction_input.value))
            retry_button.on_click(lambda x: self.symptom_guide.re_record())
            display(correction_input, button_box)
        except Exception as e:
            logger.error("Recording error", exc_info=True)
            clear_output(wait=True)
            display(widgets.HTML(f"<h3>Error: {e}</h3>"))
            retry_button = widgets.Button(description="Retry Recording")
            retry_button.on_click(lambda x: self.symptom_guide.re_record())
            display(retry_button)

# =============================
# MAIN APPLICATION MODULE
# =============================

class SymptomGuide:
    """
    Implements the complete workflow:
      1. Voice Input & Transcription.
      2. Initial LLM Analysis to extract the key symptom.
      3. Predefined Follow-Up Questions based on the key symptom.
      4. Context Enrichment (using only the key symptom and follow-up Q&A).
      5. Section-wise Output Generation in a strict yet friendly doctor tone.
    """
    def __init__(self):
        self.patient_symptom_text = ""
        self.key_symptom = ""
        self.followup_answers = []
        self.followup_questions = []
        self.current_question_index = 0
        self.audio_handler = AudioHandler()
        self.llm_handler = LLMHandler()
        self.ui_handler = UIHandler(self.audio_handler, self.llm_handler, self)

    def process_corrected_symptoms(self, corrected_text):
        clear_output(wait=True)
        logger.info("Processing corrected symptoms")
        self.patient_symptom_text = corrected_text
        # Extract key symptom from corrected text.
        self.key_symptom = self.extract_key_symptom(self.patient_symptom_text)
        # Ask follow-up questions based on the key symptom.
        self.ask_followup_questions()

    def ask_followup_questions(self):
        clear_output(wait=True)
        # Choose follow-up questions based on the extracted key symptom.
        text_lower = self.key_symptom.lower()
        if "fever" in text_lower:
            category = "fever"
        elif "cough" in text_lower:
            category = "cough"
        elif "headache" in text_lower:
            category = "headache"
        elif "back pain" in text_lower:
            category = "back pain"
        elif "toothache" in text_lower:
            category = "toothache"
        else:
            category = "general"
        questions_dict = {
            "fever": [
                "At what time of day did you first notice your fever?",
                "Is your fever constant or intermittent?",
                "Are there any additional symptoms accompanying your fever?"
            ],
            "cough": [
                "Is your cough dry or productive?",
                "How long have you been experiencing your cough?",
                "Do you experience difficulty breathing with your cough?"
            ],
            "headache": [
                "How severe is your headache?",
                "Is your headache localized or widespread?",
                "Do you experience any other symptoms such as nausea?"
            ],
            "back pain": [
                "Is your back pain in the upper or lower region?",
                "Did your back pain start suddenly or gradually?",
                "Does movement worsen your back pain?"
            ],
            "toothache": [
                "Is your toothache sharp or dull?",
                "Is the pain constant or intermittent?",
                "Does the pain extend to your jaw?"
            ],
            "general": [
                "Please provide any additional details about your symptoms."
            ]
        }
        self.followup_questions = questions_dict.get(category, questions_dict["general"])
        self.followup_answers = []
        self.current_question_index = 0
        self.ask_next_followup_question()

    def ask_next_followup_question(self):
        clear_output(wait=True)
        if self.current_question_index >= len(self.followup_questions):
            self.process_followup_answers()
            return
        question = self.followup_questions[self.current_question_index]
        question_label = widgets.HTML(f"<h3>{question}</h3>")
        answer_input = widgets.Text(value="", placeholder="Type your answer here...")
        next_button = widgets.Button(description="Next")
        def on_next(b):
            self.followup_answers.append(answer_input.value.strip())
            self.current_question_index += 1
            self.ask_next_followup_question()
        next_button.on_click(on_next)
        display(question_label, answer_input, next_button)

    def process_followup_answers(self):
        # Build detailed context using only the key symptom and follow-up Q&A.
        followup_qas = "\n".join([f"Q: {q}\nA: {a}" for q, a in zip(self.followup_questions, self.followup_answers)])
        detailed_context = (
            f"Key Symptom: {self.key_symptom}\n"
            f"Follow-up Q&A:\n{followup_qas}"
        )
        logger.info(f"Detailed context:\n{detailed_context}")
        # Generate section-wise output.
        causes, tips, seriousness, help_signs = self.llm_handler.generate_full_response(detailed_context)
        final_answer = self.parse_llm_response(self.key_symptom, causes, tips, seriousness, help_signs)
        full_response = self.format_text_response(self.patient_symptom_text, final_answer)
        clear_output(wait=True)
        print(full_response)
        guideline = self.llm_handler.generate_concise_guideline(detailed_context)
        print("**Concise Guideline for You:**")
        if guideline:
            for item in guideline:
                print(f"  - {item}")
        else:
            print("No guidelines were generated. Please try again.")
        print("\n----------------------------------------")
        print("**Important:** *This advice is informational only. Please see a doctor if your symptoms worsen.*")
        self.cleanup_temp_files()
        re_record_button = widgets.Button(description="Re-Record")
        re_record_button.on_click(lambda x: self.re_record())
        display(re_record_button)

    def extract_key_symptom(self, text):
        prompt = (
            "Based on the following patient information, extract the primary symptom that is the main reason for seeking medical advice. "
            "Provide your answer as one concise sentence without extra commentary.\n"
            f"{text}\n\nAnswer:\n"
        )
        extraction = self.llm_handler.generate_text(prompt, 50, 3, 0.7, 1.2)
        if "Answer:" in extraction:
            extraction = extraction.split("Answer:", 1)[1].strip()
        return extraction

    def parse_llm_response(self, key_symptom, causes, tips, seriousness, help_signs):
        output_text = f"**Extracted Key Symptom:** {key_symptom}\n\n"
        output_text += "**What Might Be Causing This:**\n"
        for cause in causes:
            output_text += f"  - {cause}\n"
        output_text += "\n**Tips to Feel Better:**\n"
        for tip in tips:
            output_text += f"  - {tip}\n"
        output_text += "\n**How Serious It Is:**\n"
        for statement in seriousness:
            output_text += f"  - {statement}\n"
        output_text += "\n**When to Get Help Right Away:**\n"
        for sign in help_signs:
            output_text += f"  - {sign}\n"
        return output_text

    def format_text_response(self, original_input, final_answer):
        return f"**Symptom Guide for You**\n----------------------------------------\n**You Reported:** {original_input}\n\n{final_answer}"

    def cleanup_temp_files(self):
        for file in ["patient_input.wav"]:
            if os.path.exists(file):
                try:
                    os.remove(file)
                    logger.info(f"Removed temporary file: {file}")
                except Exception as e:
                    logger.error(f"Error removing temporary file {file}", exc_info=True)

    def re_record(self):
        logger.info("Re-record requested. Clearing data.")
        self.clear_all_data()
        self.start()

    def clear_all_data(self):
        self.patient_symptom_text = ""
        self.key_symptom = ""
        self.followup_answers = []
        self.followup_questions = []
        self.current_question_index = 0

    def start(self):
        clear_output(wait=True)
        start_button = widgets.Button(description="Start Recording")
        start_button.on_click(lambda x: self.ui_handler.start_recording())
        display(start_button)
        print("Click 'Start Recording' to begin.")

# =============================
# RUN THE APPLICATION
# =============================

symptom_guide_app = SymptomGuide()
symptom_guide_app.start()

**Symptom Guide for You**
----------------------------------------
**You Reported:** I have fever.

**Extracted Key Symptom:** The primary symptom that is the main reason for seeking medical advice is **fever**.

**What Might Be Causing This:**
  - It's possible that the cause of the fever could be related to an infection. Given that the fever is constant and accompanied by a cough, it might suggest something like a respiratory infection, such as pneumonia, which is a common cause of fever in adults.
  - Considering the constant nature of the fever and the presence of a cough, another possibility could be an upper respiratory tract infection. This type of infection often starts with a fever and can be accompanied by a cough, making it a plausible explanation.
  - Although the fever is constant, the fact that there are no additional symptoms like

**Tips to Feel Better:**
  - When it comes to managing a fever, especially in children, it's crucial to consider their age and overall health

Button(description='Re-Record', style=ButtonStyle())