<a href="https://colab.research.google.com/github/Vishwam2609/Voice-based-Medical-Support-using-LLM/blob/Patient-Symptom-Checker/12_03.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install required libraries quietly
!pip install transformers -q
!pip install pydub -q
!pip install ipywidgets -q
!pip install nltk -q
!pip install torchaudio -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m26.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m65.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m44.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m39.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [7]:
# =============================
# CONFIGURATION & IMPORTS
# =============================

import os
import re
import logging
from base64 import b64decode
from io import BytesIO

import torch
import torchaudio
import ipywidgets as widgets
from IPython.display import Javascript, display, clear_output, Audio
from google.colab import output
from pydub import AudioSegment
from pydub.effects import normalize

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("Starting Symptom Guide Application")

# Suppress transformers logging to avoid clutter
logging.getLogger("transformers").setLevel(logging.ERROR)

# Configuration: load parameters from environment variables or use defaults.
# The Hugging Face access token is a secret and must never be hard-coded here:
# provide it through the HUGGING_FACE_TOKEN environment variable (or log in with
# `huggingface-cli login`). An unset/empty variable yields None, which makes
# `from_pretrained` fall back to any locally cached credentials. Gated models such
# as meta-llama/Llama-3.2-1B still require an authorized token.
HUGGING_FACE_TOKEN = os.getenv("HUGGING_FACE_TOKEN") or None
LLM_MODEL_NAME = os.getenv("LLM_MODEL_NAME", "meta-llama/Llama-3.2-1B")
WHISPER_MODEL_NAME = os.getenv("WHISPER_MODEL_NAME", "openai/whisper-small")
# Run models on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# =============================
# AUDIO HANDLING MODULE
# =============================

class AudioHandler:
    """
    Handles audio processing tasks such as lazy loading of the Whisper ASR model and transcription.

    The feature extractor, tokenizer and model stay ``None`` until the first call
    to ``load_model`` (made directly or implicitly by ``transcribe_audio``).
    """

    # Sample rate the Whisper feature extractor is fed with.
    TARGET_SAMPLE_RATE = 16000

    def __init__(self):
        self.feature_extractor = None
        self.tokenizer = None
        self.asr_model = None

    def load_model(self):
        """
        Load the Whisper feature extractor, tokenizer and model if not loaded yet.

        Uses ``WHISPER_MODEL_NAME``; the model is created in float16 when ``DEVICE``
        is ``"cuda"`` (float32 otherwise) and moved to ``DEVICE``.

        Raises:
            Exception: whatever the ``from_pretrained`` loaders raise, after logging it.
        """
        if self.asr_model is not None:
            return
        try:
            logger.info("Lazy loading Whisper model for transcription")
            self.feature_extractor = WhisperFeatureExtractor.from_pretrained(WHISPER_MODEL_NAME)
            self.tokenizer = WhisperTokenizer.from_pretrained(WHISPER_MODEL_NAME)
            self.asr_model = WhisperForConditionalGeneration.from_pretrained(
                WHISPER_MODEL_NAME,
                torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            ).to(DEVICE)
            self.tokenizer.pad_token = self.tokenizer.eos_token
            logger.info("Whisper model loaded successfully")
        except Exception:
            logger.error("Error lazy loading Whisper model", exc_info=True)
            # Bare raise keeps the original traceback intact.
            raise

    def transcribe_audio(self, audio_file):
        """
        Transcribe an audio file to English text with Whisper.

        Multi-channel audio is averaged down to mono and resampled to
        ``TARGET_SAMPLE_RATE`` before feature extraction.

        Args:
            audio_file: anything ``torchaudio.load`` accepts (e.g. a file path).

        Returns:
            The decoded transcription. If transcription fails, a string starting
            with ``"Error: Could not transcribe audio."`` is returned instead of
            raising. Errors from ``load_model`` are not caught here and propagate.
        """
        if self.asr_model is None:
            self.load_model()
        try:
            logger.info("Starting transcription for file: %s", audio_file)
            waveform, sample_rate = torchaudio.load(audio_file)
            # torchaudio returns (channels, frames). Passing a stereo clip through
            # unchanged would be treated as a batch of two utterances, so downmix.
            if waveform.size(0) > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            if sample_rate != self.TARGET_SAMPLE_RATE:
                logger.info("Resampling audio from %s to %s", sample_rate, self.TARGET_SAMPLE_RATE)
                waveform = torchaudio.functional.resample(waveform, sample_rate, self.TARGET_SAMPLE_RATE)
            input_features = self.feature_extractor(
                waveform.squeeze(0).numpy(),
                sampling_rate=self.TARGET_SAMPLE_RATE,
                return_tensors="pt",
            ).input_features
            # Match the model's dtype (float16 on CUDA, float32 on CPU).
            input_features = input_features.to(DEVICE, dtype=self.asr_model.dtype)
            predicted_ids = self.asr_model.generate(input_features, language='en')
            transcription = self.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            logger.info("Transcription completed successfully")
            return transcription
        except Exception as e:
            logger.error("Failed to transcribe audio", exc_info=True)
            return f"Error: Could not transcribe audio. {e}"

# =============================
# LLM HANDLING MODULE
# =============================

class LLMHandler:
    """
    Handles interactions with the large language model (LLM) for generating medical guidance.

    Every ``generate_*`` helper receives the combined symptom context (key symptom
    plus follow-up answers), sends a prompt through a Hugging Face
    ``text-generation`` pipeline and parses labelled lines such as ``"Cause 1: ..."``
    out of the reply. The model is loaded lazily on the first generation request.
    """

    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.pipeline = None

    def load_model(self):
        """
        Load the tokenizer, model and text-generation pipeline if not loaded yet.

        Uses ``LLM_MODEL_NAME`` / ``HUGGING_FACE_TOKEN``; runs on GPU 0 in float16
        when ``DEVICE`` is ``"cuda"``, otherwise on CPU in float32.

        Raises:
            Exception: whatever the Hugging Face loaders raise, after logging it.
        """
        if self.pipeline is not None:
            return
        try:
            logger.info("Lazy loading LLM model")
            self.tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME, token=HUGGING_FACE_TOKEN)
            self.model = AutoModelForCausalLM.from_pretrained(
                LLM_MODEL_NAME,
                token=HUGGING_FACE_TOKEN,
                torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            )
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=0 if DEVICE == "cuda" else -1,
            )
            logger.info("LLM model loaded successfully")
        except Exception:
            logger.error("Error lazy loading LLM model", exc_info=True)
            raise

    def generate_text(self, prompt, max_new_tokens, num_beams, temperature, repetition_penalty,
                      early_stopping=True, do_sample=None):
        """
        Run the generation pipeline on ``prompt``.

        Args:
            prompt: full prompt text.
            max_new_tokens, num_beams, repetition_penalty, early_stopping:
                forwarded to generation unchanged.
            temperature: forwarded to generation; it only has an effect when
                sampling is enabled (via ``do_sample`` or the model's own
                generation config).
            do_sample: when not None, explicitly turns sampling on/off; None
                (default) leaves the model's generation-config default in place.

        Returns:
            The pipeline's ``generated_text`` (prompt followed by the completion,
            which the parsers below rely on), or ``"Error: Could not generate text."``
            if generation fails. Errors while loading the model are raised.
        """
        if self.pipeline is None:
            self.load_model()
        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "num_beams": num_beams,
            "early_stopping": early_stopping,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
        }
        if do_sample is not None:
            generation_kwargs["do_sample"] = do_sample
        try:
            logger.info("Generating text from LLM")
            result = self.pipeline(prompt, **generation_kwargs)[0]['generated_text']
            logger.info("Text generation completed")
            return result
        except Exception:
            logger.error("Error during text generation", exc_info=True)
            return "Error: Could not generate text."

    # ----- response parsing helpers -----

    @staticmethod
    def _answer_section(result):
        """Return the stripped text after the first ``"Answer:"`` marker, or ``result`` unchanged."""
        if "Answer:" in result:
            return result.split("Answer:", 1)[1].strip()
        return result

    @classmethod
    def _extract_labeled(cls, result, label, number_optional=False):
        """
        Collect the non-empty text that follows ``"<label> <n>:"`` at the start of a line.

        Matching is case-insensitive and only looks at the answer section of
        ``result``. With ``number_optional`` the number may be omitted
        (``"<label>:"``). Lines whose text after the colon is blank are skipped.
        """
        number = r"(?:\d+)?" if number_optional else r"\d+"
        pattern = re.compile(rf"(?i)^{re.escape(label)}\s*{number}\s*:\s*(.*)")
        items = []
        for line in cls._answer_section(result).splitlines():
            match = pattern.match(line)
            if match:
                text = match.group(1).strip()
                if text:
                    items.append(text)
        return items

    @staticmethod
    def _pad(items, count, filler):
        """Return exactly ``count`` items: ``items`` truncated, or padded with ``filler``."""
        return (list(items) + [filler] * count)[:count]

    # ----- public generators -----

    def generate_causes(self, combined_symptom):
        """Return exactly three possible causes, padded with ``"No detailed cause provided."``."""
        prompt = (
            f"As a doctor, list three distinct possible causes for a patient experiencing the following condition:\n"
            f"{combined_symptom}\n\n"
            "Your answer must be exactly three non-empty lines in the format:\n"
            "Cause 1: <detailed explanation>\n"
            "Cause 2: <detailed explanation>\n"
            "Cause 3: <detailed explanation>\n"
            "Do not use any placeholder text such as '<explanation>'.\nAnswer:\n"
        )
        result = self.generate_text(prompt, max_new_tokens=250, num_beams=5, temperature=0.8, repetition_penalty=1.2)
        # Drop echoed template placeholders ("<detailed explanation>") and literal
        # copies of the fallback text.
        causes = [
            text for text in self._extract_labeled(result, "cause")
            if "<" not in text and text.lower() != "no detailed cause provided."
        ]
        return self._pad(causes, 3, "No detailed cause provided.")

    def generate_help(self, combined_symptom):
        """Return exactly three warning signs, padded with a generic seek-care warning."""
        prompt = (
            f"As a doctor, review the following patient details:\n{combined_symptom}\n\n"
            "Now, provide three concise and effective warning signs that indicate the need for immediate medical attention. "
            "Each warning should be a brief sentence focusing only on the critical symptoms and a clear instruction to seek care. "
            "Your answer must consist of exactly three non-empty lines, each starting with 'Warning 1:', 'Warning 2:', and 'Warning 3:' respectively, with no extra commentary.\n"
            "Answer:\n"
        )
        result = self.generate_text(prompt, max_new_tokens=180, num_beams=5, temperature=0.4, repetition_penalty=1.2)
        warnings = self._extract_labeled(result, "warning")
        return self._pad(warnings, 3, "Seek immediate care if symptoms worsen.")

    def generate_tips(self, combined_symptom):
        """Return exactly three case-insensitively distinct tips, padded with a placeholder."""
        prompt = (
            f"As a doctor, list three distinct and practical tips for a patient to feel better when experiencing the following condition:\n"
            f"{combined_symptom}\n\n"
            "Your answer must be exactly three non-empty lines, each beginning with 'Tip 1: ', 'Tip 2: ', and 'Tip 3: '.\n"
            "Do not use placeholder text.\nAnswer:\n"
        )
        result = self.generate_text(prompt, max_new_tokens=100, num_beams=5, temperature=0.5, repetition_penalty=1.2)
        seen = set()
        unique_tips = []
        for tip in self._extract_labeled(result, "tip"):
            # Keep the first spelling of each tip; beam search often repeats itself.
            if tip.lower() not in seen:
                seen.add(tip.lower())
                unique_tips.append(tip)
        return self._pad(unique_tips, 3, "Additional guideline not provided.")

    def generate_seriousness(self, combined_symptom):
        """
        Return up to three seriousness points as ``"• "``-prefixed lines joined by newlines.

        Falls back to a single placeholder bullet when nothing could be parsed,
        consistent with the other generators' fallbacks.
        """
        prompt = (
            f"As a doctor, review the following patient details:\n{combined_symptom}\n\n"
            "Now, provide three distinct and concise bullet points that evaluate the seriousness of the condition. "
            "Each bullet point must be a complete sentence that explains a specific potential risk and why that risk is concerning. "
            "Each bullet point must be prefixed with 'Seriousness 1:', 'Seriousness 2:', and 'Seriousness 3:' respectively, with no extra commentary.\n"
            "Answer:\n"
        )
        result = self.generate_text(prompt, max_new_tokens=250, num_beams=5, temperature=0.6, repetition_penalty=1.2)
        points = self._extract_labeled(result, "seriousness", number_optional=True)[:3]
        if not points:
            points = ["No seriousness assessment provided."]
        return "\n".join(f"• {point}" for point in points)

    def generate_concise_guideline(self, combined_symptom):
        """Return exactly five home-care guidelines, padded with ``"No guideline available."``."""
        prompt = (
            f"As a doctor, list exactly five concise and actionable home care guidelines for a patient experiencing the following:\n"
            f"{combined_symptom}\n\n"
            "Answer in exactly five lines, each starting with 'Guideline 1: ', 'Guideline 2: ', 'Guideline 3: ', 'Guideline 4: ', and 'Guideline 5: '.\n"
            "Do not include any extra text.\nAnswer:\n"
        )
        result = self.generate_text(prompt, max_new_tokens=200, num_beams=5, temperature=0.5, repetition_penalty=1.2)
        guidelines = self._extract_labeled(result, "guideline")
        return self._pad(guidelines, 5, "No guideline available.")

    def generate_full_response(self, combined_symptom):
        """Return ``(causes, tips, seriousness, help_signs)`` for ``combined_symptom``."""
        causes = self.generate_causes(combined_symptom)
        tips = self.generate_tips(combined_symptom)
        seriousness = self.generate_seriousness(combined_symptom)
        help_signs = self.generate_help(combined_symptom)
        return causes, tips, seriousness, help_signs

# =============================
# UI HANDLING MODULE
# =============================

class UIHandler:
    """
    Handles the user interface (UI) using ipywidgets and JavaScript for audio recording.

    Recording happens in the browser through a ``window.audioRecorder`` object
    built on the MediaRecorder API; the clip is returned to Python as a base64
    data URL via ``output.eval_js``.
    """

    # Browser-side recorder. `stop()` resolves with a data URL of the recording
    # and releases the microphone tracks acquired in `start()`.
    RECORD_JS = """
    window.audioRecorder = {
      recorder: null,
      stream: null,
      audioChunks: [],
      start: async function() {
        this.stream = await navigator.mediaDevices.getUserMedia({audio: true});
        this.recorder = new MediaRecorder(this.stream);
        this.audioChunks = [];
        this.recorder.ondataavailable = event => {
          this.audioChunks.push(event.data);
        };
        this.recorder.start();
        return "Recording started";
      },
      stop: async function() {
        return new Promise(resolve => {
          this.recorder.onstop = () => {
            // Label the blob with the container MediaRecorder actually produced
            // (browser-dependent, e.g. audio/webm) rather than claiming WAV.
            const blob = new Blob(this.audioChunks, { type: this.recorder.mimeType || 'audio/webm' });
            this.audioChunks = [];
            // Release the microphone so the browser stops capturing.
            if (this.stream) {
              this.stream.getTracks().forEach(track => track.stop());
              this.stream = null;
            }
            const reader = new FileReader();
            reader.onloadend = () => resolve(reader.result);
            reader.readAsDataURL(blob);
          };
          this.recorder.stop();
        });
      }
    };
    """

    def __init__(self, audio_handler, llm_handler, symptom_guide):
        self.audio_handler = audio_handler
        self.llm_handler = llm_handler
        self.symptom_guide = symptom_guide
        self.inject_js()

    def inject_js(self):
        """Define ``window.audioRecorder`` in the notebook front end."""
        display(Javascript(self.RECORD_JS))

    def start_recording(self):
        """Start browser recording and show a "Stop Recording" button."""
        clear_output(wait=True)
        # Re-define window.audioRecorder before each new recording.
        self.inject_js()
        try:
            output.eval_js('window.audioRecorder.start()')
        except Exception as e:
            # e.g. microphone permission denied or no input device.
            logger.error("Error during start_recording", exc_info=True)
            self._show_error_with_retry(e)
            return
        logger.info("Recording started by user")
        stop_button = widgets.Button(description="Stop Recording")
        stop_button.on_click(self.stop_recording)
        display(stop_button)

    def stop_recording(self, b):
        """
        Stop recording, save the clip as ``patient_input.wav``, transcribe it and
        show an editable transcription with Confirm / Retry buttons.

        Args:
            b: the button that fired the callback; it is disabled immediately so
                a second click cannot issue another stop request.
        """
        if b is not None:
            b.disabled = True
        try:
            logger.info("Stopping recording")
            audio_data = output.eval_js('window.audioRecorder.stop()')
            # Data URL shape: "data:<mime>;base64,<payload>".
            binary = b64decode(audio_data.split(',')[1])
            audio = AudioSegment.from_file(BytesIO(binary))
            audio = normalize(audio)
            file_path = "patient_input.wav"
            audio.export(file_path, format="wav")
            logger.info("Audio exported to %s", file_path)
            clear_output(wait=True)
            display(widgets.HTML("<h3>Review your recording:</h3>"))
            display(Audio(file_path))
            transcription = self.audio_handler.transcribe_audio(file_path)
            self.symptom_guide.patient_symptom_text = transcription
            correction_input = widgets.Textarea(value=transcription, description='Corrected Symptoms:')
            confirm_button = widgets.Button(description="Confirm Correction")
            retry_button = widgets.Button(description="Retry Recording")
            button_box = widgets.HBox([confirm_button, retry_button],
                                      layout=widgets.Layout(justify_content='space-between', width='50%'))
            confirm_button.on_click(lambda x: self.symptom_guide.process_corrected_symptoms(correction_input.value))
            retry_button.on_click(lambda x: self.symptom_guide.re_record())
            display(correction_input, button_box)
        except Exception as e:
            logger.error("Error during stop_recording", exc_info=True)
            self._show_error_with_retry(e)

    def _show_error_with_retry(self, error):
        """Replace the output with an HTML-escaped error message and a retry button."""
        import html

        clear_output(wait=True)
        display(widgets.HTML(f"<h3>Error: {html.escape(str(error))}</h3>"))
        retry_button = widgets.Button(description="Retry Recording")
        retry_button.on_click(lambda x: self.symptom_guide.re_record())
        display(retry_button)

# =============================
# MAIN APPLICATION MODULE
# =============================

class SymptomGuide:
    """
    Main application class that manages patient data, UI flow, and integration of audio and LLM modules.

    Flow: ``start`` shows a record button -> ``UIHandler`` records and transcribes
    -> ``process_corrected_symptoms`` -> category-specific follow-up questions ->
    ``process_all_patient_details`` prints the generated guidance.
    """

    # Leading phrases removed (case-insensitively, in this order) to isolate the symptom.
    _LEADING_PHRASES = (
        r"^i have\s+",
        r"^i am having\s+",
        r"^i'm having\s+",
        r"^i feel\s+",
        r"^i am experiencing\s+",
        r"^my\s+",
    )

    # Categories checked in this order; the first keyword contained in the key
    # symptom wins (e.g. "fever and cough" -> "fever"). No match -> "general".
    _CATEGORY_ORDER = ("fever", "cough", "headache", "back pain", "toothache")

    _FOLLOWUP_QUESTIONS = {
        "fever": [
            "At what time of day did you first notice your fever?",
            "Would you describe your fever as constant or intermittent?",
            "Are you experiencing any additional symptoms along with your fever?"
        ],
        "cough": [
            "Can you classify your cough as either dry or productive?",
            "Is your cough acute or chronic?",
            "Do you experience breathlessness when you cough?"
        ],
        "headache": [
            "How would you rate the intensity of your headache as mild, moderate, or severe?",
            "Is your headache localized to one side or generalized across your head?",
            "Are you also experiencing nausea along with your headache?"
        ],
        "back pain": [
            "Can you specify whether your back pain is in your upper or lower back?",
            "Did your back pain begin suddenly or gradually?",
            "Is your back pain aggravated by movement?"
        ],
        "toothache": [
            "How would you describe your toothache: is the pain sharp or dull?",
            "Is the pain from your toothache constant or intermittent?",
            "Does your toothache extend to your jaw?"
        ],
        "general": [
            "Please provide additional details about your symptom."
        ]
    }

    def __init__(self):
        self.patient_symptom_text = ""
        self.key_symptom = ""
        self.followup_answers = []   # Answers, in question order.
        self.followup_questions = [] # Questions for the detected symptom category.
        self.current_question_index = 0
        self.symptom_category = ""
        self.audio_handler = AudioHandler()
        self.llm_handler = LLMHandler()
        self.ui_handler = UIHandler(self.audio_handler, self.llm_handler, self)

    def process_corrected_symptoms(self, corrected_text):
        """Store the user-confirmed transcription and begin the follow-up questions."""
        clear_output(wait=True)
        logger.info("Processing corrected symptoms")
        self.patient_symptom_text = corrected_text
        self.key_symptom = self.extract_key_symptom(corrected_text)
        self.ask_followup_questions()

    def extract_key_symptom(self, text):
        """
        Strip leading phrases like "I have" / "my" and one trailing ``.?!``.

        Returns the lower-cased remainder, or ``text.strip()`` when nothing is left.
        """
        cleaned_text = text.strip()
        for pattern in self._LEADING_PHRASES:
            cleaned_text = re.sub(pattern, "", cleaned_text, flags=re.IGNORECASE)
        cleaned_text = re.sub(r"[.?!]$", "", cleaned_text)
        return cleaned_text.strip().lower() if cleaned_text.strip() else text.strip()

    @classmethod
    def _detect_category(cls, key_symptom):
        """Map a key symptom to a follow-up question category (``"general"`` if none match)."""
        key = key_symptom.lower()
        return next((category for category in cls._CATEGORY_ORDER if category in key), "general")

    def ask_followup_questions(self):
        """Pick the question set for the key symptom and ask the first question."""
        clear_output(wait=True)
        self.symptom_category = self._detect_category(self.key_symptom)
        # Copy so session state never aliases the shared class-level lists.
        self.followup_questions = list(self._FOLLOWUP_QUESTIONS[self.symptom_category])
        self.followup_answers = []
        self.current_question_index = 0
        self.ask_next_followup_question()

    def ask_next_followup_question(self):
        """Show the current question, or process all answers when none remain."""
        clear_output(wait=True)
        if self.current_question_index >= len(self.followup_questions):
            self.process_followup_answers()
            return

        question = self.followup_questions[self.current_question_index]
        question_label = widgets.HTML(f"<h3>{question}</h3>")
        answer_input = widgets.Text(value="", placeholder="Type your answer here...")
        next_button = widgets.Button(description="Next")

        def on_next(b):
            # Disable first so a double click cannot record the same answer twice.
            b.disabled = True
            self.followup_answers.append(answer_input.value.strip())
            self.current_question_index += 1
            self.ask_next_followup_question()

        next_button.on_click(on_next)
        display(question_label, answer_input, next_button)

    def process_followup_answers(self):
        """Join the non-blank follow-up answers and generate the final guidance."""
        # Skipped questions leave empty answers; joining them verbatim would give
        # truthy whitespace and a dangling ". " in the combined symptom.
        followup_text = " ".join(answer for answer in self.followup_answers if answer.strip())
        self.process_all_patient_details(followup_text)

    def process_all_patient_details(self, followup_text):
        """Generate and print the LLM guidance for the key symptom plus follow-up text."""
        clear_output(wait=True)
        combined_symptom = self.key_symptom
        if followup_text:
            combined_symptom += ". " + followup_text
        logger.info("Combined symptom: %s", combined_symptom)
        print("Generating guidance, please wait...")
        try:
            causes, tips, seriousness, help_signs = self.llm_handler.generate_full_response(combined_symptom)
        except Exception as e:
            # LLMHandler.load_model re-raises loading failures; show them instead
            # of leaving the user with a blank cell.
            logger.error("Failed to generate guidance", exc_info=True)
            self._show_error(e)
            return
        final_answer = self.parse_llm_response(combined_symptom, causes, tips, seriousness, help_signs)
        full_response = self.format_text_response(self.patient_symptom_text, final_answer)
        clear_output(wait=True)
        print(full_response)
        guideline = self.llm_handler.generate_concise_guideline(combined_symptom)
        print("**Concise Guideline for You:**")
        if isinstance(guideline, str):
            print(guideline)
        else:
            for item in guideline:
                print(f"  • {item}")
        print("\n----------------------------------------")
        print("**Important:** *This advice is informational only. Please see a doctor if your symptoms worsen.*")
        self.cleanup_temp_files()
        self._show_re_record_button()

    def _show_error(self, error):
        """Replace the output with an HTML-escaped error and a re-record button."""
        import html

        clear_output(wait=True)
        display(widgets.HTML(f"<h3>Error: {html.escape(str(error))}</h3>"))
        self.cleanup_temp_files()
        self._show_re_record_button()

    def _show_re_record_button(self):
        """Display a button that resets the session and starts over."""
        re_record_button = widgets.Button(description="Re-Record")
        re_record_button.on_click(lambda x: self.re_record())
        display(re_record_button)

    def parse_llm_response(self, extracted_symptom, causes, tips, seriousness, help_signs):
        """
        Format the LLM outputs as a Markdown-style text report.

        Args:
            extracted_symptom: symptom text shown at the top.
            causes, tips, help_signs: lists rendered as indented bullets (with a
                fallback bullet when empty).
            seriousness: newline-separated, already-bulleted text; every line is
                indented so multi-line assessments align with the other sections.
        """
        def bullets(items, fallback):
            if not items:
                return f"  • {fallback}\n"
            return "".join(f"  • {item}\n" for item in items)

        if seriousness and seriousness.strip():
            seriousness_block = "".join(f"  {line}\n" for line in seriousness.splitlines())
        else:
            seriousness_block = "  • No seriousness assessment provided.\n"

        return (
            f"**Extracted Symptom:** {extracted_symptom}\n\n"
            "**What Might Be Causing This:**\n"
            + bullets(causes, "No detailed cause provided.")
            + "\n**Tips to Feel Better:**\n"
            + bullets(tips, "No tip provided.")
            + "\n**How Serious It Is:**\n"
            + seriousness_block
            + "\n**When to Get Help Right Away:**\n"
            + bullets(help_signs, "No detailed warning provided.")
        )

    def format_text_response(self, original_input, final_answer):
        """Prefix the report with a header and the patient's original statement."""
        output_text = "**Symptom Guide for You**\n----------------------------------------\n"
        output_text += f"**You Reported:** {original_input}\n"
        output_text += "\n" + final_answer
        return output_text

    def cleanup_temp_files(self):
        """Best-effort removal of the recorded audio file; failures are only logged."""
        for file in ["patient_input.wav"]:
            if os.path.exists(file):
                try:
                    os.remove(file)
                    logger.info("Removed temporary file: %s", file)
                except OSError:
                    logger.error(f"Error removing temporary file {file}", exc_info=True)

    def re_record(self):
        """Discard all session data and return to the start screen."""
        logger.info("User requested to re-record. Clearing data and restarting.")
        self.clear_all_data()
        self.start()

    def clear_all_data(self):
        """Reset all per-session patient state."""
        self.patient_symptom_text = ""
        self.key_symptom = ""
        self.followup_answers = []
        self.followup_questions = []
        self.current_question_index = 0
        self.symptom_category = ""

    def start(self):
        """Show the initial "Start Recording" button."""
        clear_output(wait=True)
        start_button = widgets.Button(description="Start Recording")
        start_button.on_click(lambda x: self.ui_handler.start_recording())
        display(start_button)
        print("Click 'Start Recording' to begin.")
        print("Click 'Start Recording' to begin.")

# =============================
# RUN THE APPLICATION
# =============================

# Create the application (its Whisper and LLM models are loaded lazily on first
# use, not here) and render the "Start Recording" button.
symptom_guide_app = SymptomGuide()
symptom_guide_app.start()

**Symptom Guide for You**
----------------------------------------
**You Reported:** I have fever.

**Extracted Symptom:** fever. Morning Constant Coughing

**What Might Be Causing This:**
  • The patient may have a viral infection such as the common cold or the flu.
  • The patient may have a bacterial infection such as pneumonia or bronchitis.
  • The patient may have an allergic reaction to something in the environment such as pollen or pet dander.

**Tips to Feel Better:**
  • Drink plenty of fluids.
  • Take over-the-counter medications such as acetaminophen (Tylenol) or ibuprofen (Advil, Motrin) to reduce fever and pain.
  • Use a humidifier to add moisture to the air, which can help loosen mucus and make it easier to cough up.

**How Serious It Is:**
  • The patient has a fever. This is concerning because fevers can be a sign of an infection, which can be life-threatening if left untreated.
• The patient has morning coughing. This is concerning because coughing can be a sign of 

Button(description='Re-Record', style=ButtonStyle())