In [1]:
import argparse
import ast
import json
import os
import re
import time
from collections import defaultdict, Counter

import boto3
import faiss
import numpy as np
import requests
import sounddevice as sd
import whisper
from botocore.config import Config
from openai import OpenAI
from sentence_transformers import SentenceTransformer
from transformers import pipeline

  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (


In [2]:
# AWS credentials are read from the environment rather than written into the
# notebook, so they cannot be committed or shared by accident. When a variable
# is unset its value is None, and boto3 then falls back to its own default
# credential chain.
AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")
AWS_REGION = os.environ.get("AWS_REGION", "us-west-2")  # Change to your preferred region

# Configure the Comprehend Medical client
comprehend_client = boto3.client(
    service_name="comprehendmedical",
    region_name=AWS_REGION,
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
)

In [3]:
def record_audio(duration: float = 5.0, fs: int = 16000) -> np.ndarray:
    """
    Capture single-channel float32 audio from the default input device.

    Args:
        duration: Length of the recording in seconds.
        fs: Sampling rate in Hz.

    Returns:
        The recorded samples with singleton axes squeezed out.
    """
    print(f"Recording {duration} seconds of audio (fs={fs})...")
    n_frames = int(duration * fs)
    buffer = sd.rec(n_frames, samplerate=fs, channels=1, dtype='float32')
    sd.wait()  # block here until the recording has finished
    samples = np.squeeze(buffer)
    print("Recording complete.")
    return samples

class SpeechTranscriber:
    """
    Thin wrapper around a Whisper ASR model that turns audio into text.
    """
    def __init__(self, model_name: str = "base"):
        # Load the requested Whisper model once; every transcribe() call reuses it.
        self.model = whisper.load_model(model_name)

    def transcribe(self, audio: np.ndarray) -> str:
        """Run the model on `audio` and return its "text" field, stripped ("" if absent)."""
        # The audio is handed to the model as-is; fp16 is explicitly disabled.
        output = self.model.transcribe(audio, fp16=False)
        transcript = output.get("text", "")
        return transcript.strip()

In [4]:
def extract_medical_entities(text, client):
    """
    Extract symptoms and other medical entities using Amazon Comprehend Medical.

    Args:
        text: Free text to analyse.
        client: Object exposing ``detect_entities_v2(Text=...)``.

    Returns:
        dict with the keys
          "symptoms":         texts of MEDICAL_CONDITION entities that are not negated,
          "negated_symptoms": texts of MEDICAL_CONDITION entities carrying a NEGATION trait,
          "other_keywords":   (category, text) tuples for every other entity,
          "raw_response":     the unmodified service response.
        For blank input the service is not called and "raw_response" is
        ``{"Entities": []}``.
    """
    # Blank text cannot contain entities, so skip the API call entirely.
    if not text or not text.strip():
        return {
            "symptoms": [],
            "negated_symptoms": [],
            "other_keywords": [],
            "raw_response": {"Entities": []},
        }

    response = client.detect_entities_v2(Text=text)

    symptoms = []
    negated_symptoms = []
    other_keywords = []

    for entity in response.get("Entities", []):
        category = entity.get("Category")
        text_value = entity.get("Text")

        if category != "MEDICAL_CONDITION":
            other_keywords.append((category, text_value))
            continue

        # A condition the patient explicitly denies ("no fever") is tagged with a
        # NEGATION trait; treating it as a reported symptom would skew diagnosis.
        trait_names = {trait.get("Name") for trait in (entity.get("Traits") or [])}
        if "NEGATION" in trait_names:
            negated_symptoms.append(text_value)
        else:
            symptoms.append(text_value)

    return {
        "symptoms": symptoms,
        "negated_symptoms": negated_symptoms,
        "other_keywords": other_keywords,
        "raw_response": response
    }

In [5]:
def run_rag_query(transcript, model, index_path, meta_path, jsonl_path, top_k=10):
    """
    Retrieve the passages most similar to `transcript` from a FAISS index.

    Args:
        transcript: Query text to embed.
        model: Embedding model exposing ``encode(list_of_str)``.
        index_path: Path of the FAISS index file.
        meta_path: Not read. The metadata was previously loaded but never
            used; the parameter is kept so existing callers keep working.
        jsonl_path: JSON Lines file with one record per line, each holding a
            "text" field. Assumes record i corresponds to vector i of the
            index — TODO confirm against the index-building code.
        top_k: Number of neighbours requested from the index.

    Returns:
        List of the "text" fields of the hits, in the order the index returned them.
    """
    index = faiss.read_index(index_path)
    with open(jsonl_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
        records = [json.loads(line) for line in f if line.strip()]

    # Embed the query as a contiguous float32 matrix and L2-normalise it in place.
    query_emb = np.ascontiguousarray(model.encode([transcript]), dtype="float32")
    faiss.normalize_L2(query_emb)

    _, neighbour_ids = index.search(query_emb, top_k)

    retrieved_texts = []
    for idx in neighbour_ids[0]:
        idx = int(idx)
        # FAISS pads the result with -1 when fewer than top_k vectors exist;
        # without this check records[-1] would silently return the last record.
        if idx < 0 or idx >= len(records):
            continue
        retrieved_texts.append(records[idx]["text"])

    return retrieved_texts

In [6]:
def generate_rag_response(openai_client, text, retrieved_context, model="YOUR-FINETUNEDMODEL"):
    """
    Ask the chat model for the three most likely diagnoses.

    Args:
        openai_client: Client exposing ``chat.completions.create(...)``.
        text: Patient-reported symptoms (primary evidence in the prompt).
        retrieved_context: Iterable of retrieved passages (secondary evidence);
            None is treated as no context.
        model: Name of the chat model to call.

    Returns:
        The model's reply with surrounding whitespace removed, or "" when the
        reply carries no text content.
    """
    # Number the retrieved passages so they are distinguishable in the prompt.
    context_chunks = "\n\n".join(
        f"Section {i+1}: {ctx}" for i, ctx in enumerate(retrieved_context or [])
    )

    # NOTE(review): several sentences below are concatenated without a separating
    # space (e.g. "...Infermedica API.Only return..."). The text is left untouched
    # because the default model is a fine-tuned one that may have been trained on
    # this exact wording — confirm before editing.
    system_prompt = (
        "You are a medical diagnostic assistant tasked with analyzing patient-reported symptoms. "
        "Your objective is to determine the three most likely medical diagnoses ranked in order of likelihood. Only use disease names that are consistent with standard clinical terminology. If available, prefer the same naming used in the Infermedica API."
        "Only return disease names that match standard diagnostic terminology (as used in the Infermedica API)."
        "Avoid umbrella categories, vague conditions, or generalized infections. Use precise medical terms suitable for clinical documentation."
        "Each diagnosis must be supported by the symptoms provided.\n\n"
        "You will be provided two types of information:\n"
        "- Patient-Reported Symptoms (Primary)\n"
        "- Retrieved Medical Context (Secondary)\n\n"
        "Focus mainly on the Patient-Reported Symptoms. Use the retrieved context only if it clearly supports or refutes a diagnosis.\n\n"
        "Respond strictly in the following format:\n\n"
        "Disease: [Disease Name]\n"
        "Disease: [Disease Name]\n"
        "Disease: [Disease Name]\n"
        "Do not add any explanations, commentary, or extra text."
    )

    user_prompt = f"""You are provided two types of information:

Patient-Reported Symptoms:
{text}

Retrieved Medical Context:
{context_chunks}

---
Based on this information, provide the diagnoses and missing symptoms as per the specified format."""

    response = openai_client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0.5
    )

    # The message content can be None; fall back to an empty string rather
    # than failing on None.strip().
    content = response.choices[0].message.content
    return (content or "").strip()

In [7]:
def infermedica_diagnose(symptom_names, sex, age, top_k):
    """
    Uses Infermedica API to return top diagnoses based on given symptoms.

    Args:
        symptom_names: Free-text description of the symptoms, sent to /parse.
        sex: Patient sex, forwarded to /diagnosis.
        age: Patient age in years.
        top_k: Maximum number of conditions to return.

    Returns:
        List of ``{"name": str, "probability": float}`` dicts (probability
        rounded to 3 decimals), at most `top_k` long and in the order the API
        returned them. Empty when the text is blank or no symptom is recognised.

    Raises:
        requests.HTTPError: when either API call returns an error status.
        requests.Timeout: when either API call exceeds the timeout.
    """
    # Nothing to parse: skip both HTTP round-trips.
    if not symptom_names or not symptom_names.strip():
        print("No symptom text supplied; skipping Infermedica.")
        return []

    # Credentials come from the environment; the placeholders are only a
    # fallback so an unset variable shows up as an authentication error.
    app_id = os.environ.get("INFERMEDICA_APP_ID", "YOUR-API-ID")
    app_key = os.environ.get("INFERMEDICA_APP_KEY", "YOUR-API-KEY")
    api_url = "https://api.infermedica.com/v3"
    timeout_seconds = 30  # without a timeout a stalled connection blocks forever

    headers = {
        "App-Id": app_id,
        "App-Key": app_key,
        "Content-Type": "application/json",
        "Accept": "application/json"
    }
    age_payload = {"value": age, "unit": "year"}

    # Step 1: Use /parse to extract symptoms
    parse_payload = {
        "text": symptom_names,
        "age": age_payload,
        "include_tokens": False
    }
    parse_resp = requests.post(
        f"{api_url}/parse", headers=headers, json=parse_payload, timeout=timeout_seconds
    )
    parse_resp.raise_for_status()
    mentions = parse_resp.json().get("mentions", [])

    if not mentions:
        print("No symptoms detected by Infermedica.")
        return []

    evidence = [{"id": m["id"], "choice_id": m["choice_id"]} for m in mentions]

    # Step 2: Run diagnosis
    diagnosis_payload = {
        "sex": sex,
        "age": age_payload,
        "evidence": evidence
    }
    diag_resp = requests.post(
        f"{api_url}/diagnosis", headers=headers, json=diagnosis_payload, timeout=timeout_seconds
    )
    diag_resp.raise_for_status()
    conditions = diag_resp.json().get("conditions", [])

    # Step 3: Return top results
    return [{
        "name": c["name"],
        "probability": round(c["probability"], 3)
    } for c in conditions[:top_k]]

In [8]:
def extract_disease_names_from_chatgpt(response_text):
    """
    Pull the names that follow each "Disease:" label out of a model reply.

    Returns the captured names in order of appearance, whitespace-trimmed,
    with empty captures dropped.
    """
    label_re = re.compile(r"Disease:\s*(.+)")
    names = []
    for hit in label_re.finditer(response_text):
        candidate = hit.group(1).strip()
        if candidate:
            names.append(candidate)
    return names

def _parse_disease_list(raw_text):
    """Turn the model's reply into a list of names without ever executing it."""
    text = raw_text.strip()

    # Models often wrap JSON in a Markdown code fence; unwrap it first.
    fenced = re.match(r"^```[\w-]*\s*(.*?)\s*```$", text, re.DOTALL)
    if fenced:
        text = fenced.group(1).strip()

    if not text.startswith("["):
        # Not list-shaped: treat every non-empty line as one name.
        return [line.strip() for line in text.split("\n") if line.strip()]

    try:
        parsed = json.loads(text)
    except ValueError:
        # Python-style list with single quotes. literal_eval only accepts
        # literals, so an injected expression raises instead of running.
        parsed = ast.literal_eval(text)

    if not isinstance(parsed, (list, tuple)):
        raise ValueError("expected a list of disease names")
    return [item.strip() for item in parsed if isinstance(item, str) and item.strip()]


def normalize_disease_names_chatgpt(disease_names, openai_client, model="your-Model"):
    """
    Uses ChatGPT to normalize a list of disease names by grouping and standardizing them.

    Args:
        disease_names: Names to normalize.
        openai_client: Client exposing ``chat.completions.create(...)``.
        model: Name of the chat model to call.

    Returns:
        The list of normalized names parsed from the reply. If the reply cannot
        be parsed, `disease_names` is returned unchanged. An empty input returns
        [] without calling the model.
    """
    if not disease_names:
        return []

    prompt = (
        "You are a medical assistant. I will give you a list of disease names which may include synonyms, "
        "overlapping names, or general/specific variants (e.g., 'Hepatitis', 'Hepatitis A', 'Hepatitis B').\n\n"
        "Please return a list of the canonical or grouped disease names. Map all similar diseases to the same canonical form.\n"
        "Only return a JSON list of unique, normalized disease names in the best standard medical terminology you can use.\n\n"
        f"Here is the list:\n{disease_names}"
    )

    response = openai_client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful medical assistant."},
            {"role": "user", "content": prompt}
        ],
        temperature=0
    )

    try:
        # SECURITY: the reply is untrusted text. It used to be passed to eval(),
        # which would execute anything the model (or a prompt injection) returned.
        return _parse_disease_list(response.choices[0].message.content)
    except Exception as e:
        print("ChatGPT normalization failed:", e)
        return disease_names

def _canonical_name(name, candidates):
    """Pick the candidate that best represents `name`; fall back to `name` itself."""
    lowered = name.lower()
    # An exact (case-insensitive) match wins, so "Migraine" is not captured by
    # "Migraine without aura" merely because that one comes first in the list.
    for candidate in candidates:
        if candidate.lower() == lowered:
            return candidate
    # Otherwise accept the first candidate that contains, or is contained in, the name.
    for candidate in candidates:
        cand_lowered = candidate.lower()
        if lowered in cand_lowered or cand_lowered in lowered:
            return candidate
    return name


def merge_and_rank_diagnoses(chatgpt_diseases, kg_diseases, openai_client, chatgpt_model="your-model"):
    """
    Merges diagnoses from ChatGPT and Infermedica, normalizing using ChatGPT itself.

    Args:
        chatgpt_diseases: Disease names proposed by ChatGPT.
        kg_diseases: Dicts with a "name" and an optional "probability".
        openai_client: Client forwarded to ``normalize_disease_names_chatgpt``.
        chatgpt_model: Model name forwarded to the normalization call.

    Returns:
        Up to three ``{"name": str, "score": float}`` dicts, highest score
        first. Each ChatGPT mention adds 0.2 to a disease's score; each
        Infermedica entry adds its probability. Returns [] when there are no
        names at all, without calling the model.
    """
    chatgpt_vote = 0.2  # weight of one ChatGPT mention

    chatgpt_names = [d.strip() for d in chatgpt_diseases if d and d.strip()]
    kg_entries = [(d['name'].strip(), d.get("probability", 0)) for d in kg_diseases]

    # De-duplicate while keeping first-seen order, so the normalization prompt
    # is identical between runs (set() ordering varies across processes).
    all_diseases = list(dict.fromkeys(chatgpt_names + [name for name, _ in kg_entries]))
    if not all_diseases:
        return []

    # Ask ChatGPT to normalize/group the names
    normalized_list = normalize_disease_names_chatgpt(all_diseases, openai_client, model=chatgpt_model)
    candidates = [n for n in normalized_list if isinstance(n, str) and n.strip()]

    # original name -> normalized name
    norm_map = {name: _canonical_name(name, candidates) for name in all_diseases}

    # Scoring. Lookups use the stripped names, i.e. the same keys norm_map was built from.
    score_map = defaultdict(float)
    for name in chatgpt_names:
        score_map[norm_map.get(name, name)] += chatgpt_vote
    for name, probability in kg_entries:
        score_map[norm_map.get(name, name)] += probability

    # Sort and return top 3
    ranked = sorted(score_map.items(), key=lambda item: item[1], reverse=True)[:3]
    return [{"name": name, "score": round(score, 3)} for name, score in ranked]

def get_missing_symptoms_chatgpt(openai_client, reported_symptoms, diagnoses, model="your-model"):
    """
    Ask ChatGPT which typical symptoms of each disease the patient has not reported.

    Args:
        openai_client: Client exposing ``chat.completions.create(...)``.
        reported_symptoms: Symptom strings already reported by the patient
            (None is treated as an empty list).
        diagnoses: Disease names to analyse.
        model: Name of the chat model to call.

    Returns:
        The model's reply, stripped; the prompt asks for
        "Disease: ...\\nMissing Symptoms: ..." blocks. Returns "" when there
        are no diagnoses (the model is not called) or when the reply has no
        text content.
    """
    # Without any disease there is nothing to compare against; skip the call.
    if not diagnoses:
        return ""

    system_prompt = (
        "You are a clinical assistant. Based on a list of patient-reported symptoms and given diagnoses, "
        "identify important symptoms that are commonly associated with each disease but are missing from the patient's report.\n\n"
        "Do NOT include symptoms that are already present. Only list symptoms that are expected but not mentioned.\n"
        "Format your output like this:\n\n"
        "Disease: [Name]\nMissing Symptoms: [comma-separated list]\n\n"
        "Avoid repeating symptoms. Be clinically accurate but concise. No extra text or explanation."
    )

    user_prompt = (
        f"Patient symptoms:\n{', '.join(reported_symptoms or [])}\n\n"
        f"Diseases to analyze:\n{', '.join(diagnoses)}"
    )

    response = openai_client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0
    )

    # The message content can be None; avoid calling .strip() on it.
    content = response.choices[0].message.content
    return (content or "").strip()



# Sentiment-analysis pipeline used by filter_positive_qa. The model and
# revision are pinned to the ones the un-parameterised call resolved to (see the
# library's "No model was supplied, defaulted to ..." notice), so the classifier
# cannot change silently between runs.
_sentiment_analyzer = pipeline(
    "sentiment-analysis",
    model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
    revision="714eb0f",
)

def filter_positive_qa(questions, answers, threshold=0.7):
    """
    Keep only the question/answer pairs whose answer reads as affirmative.

    `questions` and `answers` are paired positionally. Each answer is scored by
    the module-level sentiment analyzer; a pair is kept when the top label is
    POSITIVE and its score is at least `threshold`.

    Returns:
        qa_block (str): the kept pairs rendered as "Q: …\nA: …", joined by
        newlines ("" when nothing is kept).
    """
    def _is_affirmative(answer):
        top = _sentiment_analyzer(answer)[0]
        return top["label"] == "POSITIVE" and top["score"] >= threshold

    kept_lines = [
        f"Q: {question}\nA: {answer}"
        for question, answer in zip(questions, answers)
        if _is_affirmative(answer)
    ]
    return "\n".join(kept_lines)


No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision 714eb0f (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english).
Using a pipeline without specifying a model name and revision in production is not recommended.
Device set to use cuda:0


In [12]:
def _run_diagnosis_pass(transcript, sex, age, retrieval_model, openai_client,
                        index_path, meta_path, jsonl_path, verbose=False):
    """Run one extraction -> Infermedica -> retrieval -> ChatGPT -> merge pass.

    Returns a ``(symptom_keywords, merged_diagnoses)`` tuple. With
    ``verbose=True`` the intermediate results are printed as well.
    """
    extracted = extract_medical_entities(transcript, comprehend_client)
    symptom_keywords = extracted["symptoms"]

    # Retrieval uses the extracted symptoms when there are any, else the full text.
    if symptom_keywords:
        query_text = " ".join(symptom_keywords)
        if verbose:
            print("\n Using extracted symptoms for retrieval:")
            for s in symptom_keywords:
                print("•", s)
    else:
        query_text = transcript
        if verbose:
            print("\n No symptoms detected. Using full transcript for retrieval.")

    # Infermedica Knowledge graph
    kg_disease_predictions = infermedica_diagnose(
        "I am experiencing " + ", ".join(symptom_keywords),
        sex=sex,
        age=age,
        top_k=3
    )
    if verbose:
        print("\n Infermedica Diagnosis:\n")
        for disease_info in kg_disease_predictions:
            print(f"Disease: {disease_info['name']} (Probability: {disease_info['probability']})")

    retrieved = run_rag_query(
        transcript=query_text,
        model=retrieval_model,
        index_path=index_path,
        meta_path=meta_path,
        jsonl_path=jsonl_path,
        top_k=5
    )

    if verbose:
        print("\n Generating final response from ChatGPT...")
    gpt_final_response = generate_rag_response(openai_client, text=query_text, retrieved_context=retrieved)
    if verbose:
        print("\n ChatGPT Diagnosis:\n")
        print(gpt_final_response)

    merged_diagnoses = merge_and_rank_diagnoses(
        chatgpt_diseases=extract_disease_names_from_chatgpt(gpt_final_response),
        kg_diseases=kg_disease_predictions,
        openai_client=openai_client
    )
    return symptom_keywords, merged_diagnoses


def _generate_followup_questions(openai_client, transcript, missing_text, model):
    """Ask the chat model for yes/no follow-up questions; return them one per list item."""
    # Built from explicit pieces so the whitespace sent to the model is unambiguous.
    followup_prompt = (
        "\n"
        "    You are a medical diagnostic assistant.\n"
        "    Transcript so far:\n"
        f"    {transcript}\n"
        "    \n"
        "    Based on this output:\n"
        f"    {missing_text}\n"
        "    \n"
        "    Generate 3 concise yes/no questions to confirm these missing symptoms with the patient. "
        "The goal of these questions is to narrow down which diagnosis is most accurate.\n"
        "    "
    )
    resp = openai_client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a precise medical assistant."},
            {"role": "user",   "content": followup_prompt}
        ]
    )
    content = resp.choices[0].message.content or ""
    lines = [q.strip() for q in content.split("\n") if q.strip()]
    # Drop preamble such as "Here are three questions:" so it is not put to the
    # patient as if it were a question; keep every line if none has a "?".
    questions = [line for line in lines if "?" in line]
    return questions or lines


def main():
    """
    Interactive diagnosis session.

    Reads the patient's sex and age, obtains a transcript (a built-in sample,
    or a live recording when ``--record`` is passed), runs one diagnosis pass,
    then asks up to three rounds of follow-up questions, re-running the pass
    whenever the patient confirms something. Prints the final ranked diagnoses.

    Raises:
        ValueError: when the sex is not "male"/"female" or the age is not a
            non-negative integer.
    """
    sex = input("Enter patient sex (male/female): ").strip().lower()
    if sex not in ("male", "female"):
        # Fail before any paid API call is made with a value the prompt does not offer.
        raise ValueError(f"Unsupported sex {sex!r}; expected 'male' or 'female'.")
    age = int(input("Enter patient age in years: ").strip())
    if age < 0:
        raise ValueError(f"Age must be non-negative, got {age}.")

    JSONL_PATH    = "rag-dataset.jsonl"
    FAISS_INDEX   = "rag-index.faiss"
    META_PATH     = "rag-metadata.json"
    followup_model = "your-model"
    max_followup_rounds = 3

    retrieval_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    # The key is read from the environment; the placeholder is only a fallback
    # so an unset variable surfaces as an authentication error.
    openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "YOUR-API-KEY"))

    parser = argparse.ArgumentParser(
        description="Record from microphone and transcribe via Whisper"
    )
    parser.add_argument(
        "--model", type=str, default="base",
        choices=["tiny", "base", "small", "medium", "large"],
        help="Whisper model size"
    )
    parser.add_argument(
        "--duration", type=float, default=10.0,
        help="Recording duration in seconds"
    )
    parser.add_argument(
        "--fs", type=int, default=16000,
        help="Sampling rate (Hz)"
    )
    parser.add_argument(
        "--record", action="store_true",
        help="Record and transcribe live audio instead of using the sample transcript"
    )
    # ignore any unexpected args (e.g. Jupyter flags)
    args, _ = parser.parse_known_args()

    if args.record:
        audio = record_audio(duration=args.duration, fs=args.fs)
        transcribed_text = SpeechTranscriber(model_name=args.model).transcribe(audio)
    else:
        # Default: the canned sample transcript.
        transcribed_text = "I have been having headaches for a while now. They are usually on the left side of my head and are very painful. I also get nausea, vomiting, and sensitivity to light and sound. I have tried taking over-the-counter pain relievers, but they don't seem to help much."

    print("\n--- Transcript ---")
    print(transcribed_text)

    # Initial diagnosis pass (prints its intermediate results).
    symptom_keywords, merged_diagnoses = _run_diagnosis_pass(
        transcribed_text, sex, age, retrieval_model, openai_client,
        index_path=FAISS_INDEX, meta_path=META_PATH, jsonl_path=JSONL_PATH,
        verbose=True
    )

    missing_text = get_missing_symptoms_chatgpt(
        openai_client,
        reported_symptoms=symptom_keywords,
        diagnoses=[d["name"] for d in merged_diagnoses]
    )
    print("\n Final Diagnosis and Missing Symptoms:\n")
    print(missing_text)

    # --- Follow-up loop: up to `max_followup_rounds` iterations ---
    running_transcript = transcribed_text

    for iteration in range(1, max_followup_rounds + 1):
        print(f"\n--- Iteration {iteration} ---")

        # 1. Missing symptoms. Round 1 reuses the result computed just above
        #    (same inputs, temperature 0) instead of paying for the call twice.
        if iteration > 1:
            missing_text = get_missing_symptoms_chatgpt(
                openai_client,
                reported_symptoms=symptom_keywords,
                diagnoses=[d["name"] for d in merged_diagnoses]
            )
        print("\nMissing Symptoms and Diseases:\n", missing_text)

        if not missing_text.strip():
            print("No missing symptoms; ending loop.")
            break

        # 2. Generate follow-up questions to confirm missing symptoms
        questions = _generate_followup_questions(
            openai_client, running_transcript, missing_text, followup_model
        )

        # 3. Display and collect patient answers
        answers = []
        print("\nFollow-up Questions:")
        for q in questions:
            answers.append(input(f"{q}\n> ").strip())

        # 4. Keep only the confirmed Q&A pairs
        qa_block = filter_positive_qa(questions, answers, threshold=0.7)
        if not qa_block.strip():
            print("No positive answers — skipping re-diagnosis steps")
            break
        running_transcript += "\n\n" + qa_block

        # 5. Re-run the whole pass on the extended transcript
        symptom_keywords, merged_diagnoses = _run_diagnosis_pass(
            running_transcript, sex, age, retrieval_model, openai_client,
            index_path=FAISS_INDEX, meta_path=META_PATH, jsonl_path=JSONL_PATH
        )

    # 6. Print final diagnoses
    print("\n=== Final Diagnoses ===")
    for d in merged_diagnoses:
        score = d.get("score", d.get("probability", "N/A"))
        print(f"{d['name']} (score: {score})")


# Entry point: start the interactive session when this code is run directly.
if __name__ == "__main__":
    main()
    # Note: main() has no return statement; an earlier variant returned
    # (transcribed_text, symptom_keywords, merged_diagnoses) for inspection.



# transcribed_text, symptom_keywords, merged_diagnoses = main()


Enter patient sex (male/female): male
Enter patient age in years: 30

--- Transcript ---
I have been having headaches for a while now. They are usually on the left side of my head and are very painful. I also get nausea, vomiting, and sensitivity to light and sound. I have tried taking over-the-counter pain relievers, but they don't seem to help much.

 Using extracted symptoms for retrieval:
• headaches
• nausea
• vomiting
• sensitivity to light and sound

 Infermedica Diagnosis:

Disease: Migraine (Probability: 0.429)
Disease: Viral gastroenteritis (Probability: 0.121)
Disease: Tension-type headaches (Probability: 0.083)

 Generating final response from ChatGPT...

 ChatGPT Diagnosis:

Disease: migraine without aura
Disease: vestibular migraine
Disease: tension-type headache

 Final Diagnosis and Missing Symptoms:

Disease: Migraine
Missing Symptoms: throbbing or pulsating pain, unilateral (one-sided) pain, aggravated by physical activity, visual aura (for some types)

Disease: Migra