In [23]:
!pip install datasets



In [24]:
# ==========================================
# LLM-based Disease Predictor using Symptoms
# Dataset: QuyenAnhDE/Diseases_Symptoms
# ==========================================

import random
import re
import warnings
import openai
import pandas as pd
from collections import defaultdict
from IPython.display import display
import ipywidgets as widgets
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics import accuracy_score, classification_report
from transformers import pipeline
from datasets import load_dataset

warnings.filterwarnings("ignore")

openai.api_key = "<KEY>"  # Replace with your actual key


In [25]:
# Mount google drive
from google.colab import drive
drive.mount('/content/drive')

path = '/content/drive/MyDrive/Colab Notebooks/AIH/HRP/'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [26]:
# =============================
# 1. LOAD AND PREPARE DATASETS
# =============================

# Load dataset from Hugging Face
dataset = load_dataset("QuyenAnhDE/Diseases_Symptoms")
df = dataset['train'].to_pandas()

# Check for missing values
print("Missing values:\n", df.isnull().sum())


Repo card metadata block was not found. Setting CardData to empty.


Missing values:
 Code          0
Name          0
Symptoms      0
Treatments    1
dtype: int64


In [27]:
df.shape

(400, 4)

In [28]:
df.head()

Unnamed: 0,Code,Name,Symptoms,Treatments
0,1,Panic disorder,"Palpitations, Sweating, Trembling, Shortness o...","Antidepressant medications, Cognitive Behavior..."
1,2,Vocal cord polyp,"Hoarseness, Vocal Changes, Vocal Fatigue","Voice Rest, Speech Therapy, Surgical Removal"
2,3,Turner syndrome,"Short stature, Gonadal dysgenesis, Webbed neck...","Growth hormone therapy, Estrogen replacement t..."
3,4,Cryptorchidism,"Absence or undescended testicle(s), empty scro...",Observation and monitoring (in cases of mild o...
4,5,Ethylene glycol poisoning-1,"Nausea, vomiting, abdominal pain, General mala...","Supportive Measures, Gastric Decontamination, ..."


In [29]:
# =============================
# 2. CLEANING & NORMALIZATION
# =============================

def clean_symptom(symptom):
    """Normalize symptom text."""
    return symptom.strip().lower()

In [30]:
# =============================
# 3. BUILD DISEASE INFORMATION
# =============================

disease_info = {}
for _, row in df.iterrows():
    disease = row['Name'].strip().lower()
    disease_info[disease] = {
        "Description": f"A condition known as {row['Name']}.",
        "Treatments": row.get("Treatments", "")
    }

In [31]:
# =============================
# 4. BUILD DISEASE-SYMPTOM MAP
# =============================

disease_symptoms = defaultdict(set)
all_symptom_set = set()

for _, row in df.iterrows():
    disease = row['Name'].strip().lower()
    if isinstance(row['Symptoms'], str):
        symptoms = [clean_symptom(s) for s in row['Symptoms'].split(',')]
        disease_symptoms[disease].update(symptoms)
        all_symptom_set.update(symptoms)

all_symptoms = sorted(list(all_symptom_set))

In [32]:
# =============================
# 5. SYMPTOM EMBEDDING MATCHING
# =============================

model = SentenceTransformer('all-MiniLM-L6-v2')
symptom_embeddings = model.encode(all_symptoms, convert_to_tensor=True)

def get_top_symptoms(user_input, top_k=5, similarity_threshold=0.45):
    """Return top matching symptoms based on semantic similarity."""
    if isinstance(user_input, list):
        user_input = " ".join(user_input)

    input_embedding = model.encode(user_input, convert_to_tensor=True)
    hits = util.semantic_search(input_embedding, symptom_embeddings, top_k=top_k)[0]
    return [all_symptoms[hit['corpus_id']] for hit in hits if hit['score'] >= similarity_threshold]

def match_disease(user_symptoms, weighted=True, top_k=5):
    """Match diseases based on symptom overlap."""
    disease_scores = []
    for disease, symptoms in disease_symptoms.items():
        matches = set(user_symptoms).intersection(symptoms)
        if not matches:
            continue
        score = len(matches) if weighted else len(matches) / len(symptoms)
        disease_scores.append((disease, score))
    return sorted(disease_scores, key=lambda x: x[1], reverse=True)[:top_k]

# Debug function to test matching
def debug_disease_match(user_input):
    symptoms = get_top_symptoms(user_input)
    print("Matched Symptoms:", symptoms)
    print("Top Diseases (weighted):", match_disease(symptoms, weighted=True))
    print("Top Diseases (raw):", match_disease(symptoms, weighted=False))

debug_disease_match("I'm dealing with fatigue, weight loss, excessive hunger, and polyuria.")

Matched Symptoms: ['extreme fatigue', 'increased hunger', 'proteinuria', 'muscle weakness or fatigue', 'increased need to urinate']
Top Diseases (weighted): [('chronic fatigue syndrome', 1), ('preeclampsia', 1), ('diabetic kidney disease', 1), ('hyperkalemia', 1), ('type 2 diabetes', 1)]
Top Diseases (raw): [('chronic fatigue syndrome', 0.3333333333333333), ('type 2 diabetes', 0.3333333333333333), ('hyperkalemia', 0.25), ('preeclampsia', 0.2), ('diabetic kidney disease', 0.16666666666666666)]


In [33]:
# =============================
# 6. LLM RESPONSE GENERATION
# =============================

def generate_response(user_input, llm_choice="OpenAI"):
    matched_symptoms = get_top_symptoms(user_input)
    top_diseases = match_disease(matched_symptoms)
    if not matched_symptoms:
        return "I couldn't detect any specific symptoms. Please describe clearly."

    if not top_diseases:
        return "I couldn't identify a possible condition. Please consult a doctor."

    context = ""
    for disease, score in top_diseases[:3]:
        desc = disease_info.get(disease, {}).get("Description", "No description available.")
        treatments = disease_info.get(disease, {}).get("Treatments", [])
        context += (
            f"\nDisease: {disease.title()}"
            f"\nMatch Score: {score:.2f}"
            f"\nDescription: {desc}"
            f"\Treatments: {treatments}\n"
        )

    prompt = f"""The user reports the following symptoms: \"{user_input}\"

    Here are some disease candidates based on symptom matching:
    {context}

    Using this information, provide a medical diagnosis response in the following format:

    Most Probable Diagnosis: <Disease Name>
    Description: <Description of the disease>
    Treatments:
    - <Treatments 1>
    - <Treatments 2>

    Other Possible Diagnoses:
      Diagnosis: <Alt Disease 1>
      Description: <Description of Alt Disease 1>
      Treatments:
        - <Treatments A1>
        - <Treatments A2>

      Diagnosis: <Alt Disease 2>
      Description: <Description of Alt Disease 2>
      Treatments:
        - <Treatments B1>
        - <Treatments B2>

    Only respond using this format. Do not add explanations or vary the structure."""

    if llm_choice.lower() == "openai":
        response = openai.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}]
        )
        return response.choices[0].message.content
    else:
        local_llm = pipeline("text-generation", model="tiiuae/falcon-7b-instruct", max_new_tokens=256)
        return local_llm(prompt)[0]['generated_text']

In [34]:
print(generate_response("I'm dealing with seizures, skin abnormalities, and developmental delays."))

Most Probable Diagnosis: Tuberous Sclerosis
Description: A condition known as Tuberous Sclerosis.
Treatments:
- Medications (anti-seizure, behavioral, etc.)
- Surgery (for tumor removal)
- Therapy (physical, occupational, speech)

Other Possible Diagnoses:
  Diagnosis: Ethylene Glycol Poisoning-2
  Description: A condition known as Ethylene glycol poisoning-2.
  Treatments:
    - Blood tests
    - Supportive Measures

  Diagnosis: Subdural Hemorrhage
  Description: A condition known as Subdural hemorrhage.
  Treatments:
    - Immediate medical attention
    - Monitoring of vital signs and neurological status


In [35]:
# ================================
# TESTING EXAMPLES
# ================================

print(generate_response("I have hoarseness, voice changes, vocal fatigue")) # Vocal Cord Polyp
print("\n-----------------------------------\n")
print(generate_response("I’ve been itchy skin particularly on the hands and feet")) # Atelectasis

Most Probable Diagnosis: Vocal Cord Polyp
Description: A condition known as Vocal cord polyp.
Treatments:
- Voice Rest
- Speech Therapy

Other Possible Diagnoses:
Diagnosis: Thoracic Aortic Aneurysm
Description: A condition known as Thoracic Aortic Aneurysm.
Treatments:
- Regular monitoring and imaging tests to assess aneurysm size and growth
- Lifestyle modifications (e.g., blood pressure control, avoiding smoking, managing cholesterol levels)

Diagnosis: Esophageal Cancer
Description: A condition known as Esophageal Cancer.
Treatments:
- Surgery
- Chemotherapy

-----------------------------------

Most Probable Diagnosis: Seborrheic Dermatitis
Description: A condition known as Seborrheic Dermatitis
Treatments:
- Topical antifungal creams or ointments
- Medicated shampoos
- Corticosteroid creams

Other Possible Diagnoses:
  Diagnosis: Hemorrhoids
  Description: A condition known as Hemorrhoids
  Treatments:
  - Increasing fiber intake
  - Warm sitz baths

  Diagnosis: Gestational Chol

In [36]:
# ===========================================
# 7. PREDICTION EVALUATION USING TEST DATA
# ===========================================

def extract_predicted_disease(response_text):
    """Extract most probable disease from response."""
    match = re.search(r"Most Probable Diagnosis: ([\w\s'()\-]+)(?=\n|$)", response_text, re.UNICODE)
    if match:
        return match.group(1).strip()
    return "Unknown"

def predict_disease(symptoms):
    """Predict disease using LLM and extract diagnosis."""
    response = generate_response(symptoms)
    return extract_predicted_disease(response)

# Example evaluation (assumes test CSV exists with 'Prompt' and 'Disease')
df_test = pd.read_csv(path + "hf_test_prompts.csv")
df_test.fillna("", inplace=True)
df_test['Predicted'] = df_test['Prompt'].apply(predict_disease)
df_test['Disease'] = df_test['Disease'].str.lower()
df_test['Predicted'] = df_test['Predicted'].str.lower()

# Evaluate performance by comparing predicted and actual diseases
accuracy = accuracy_score(df_test['Disease'], df_test['Predicted'])
print(f"Accuracy: {accuracy:.2f}")

print("\nPredictions vs Actuals:")
display(df_test[['Disease', 'Predicted']])

Accuracy: 0.52

Predictions vs Actuals:


Unnamed: 0,Disease,Predicted
0,bladder disorder,urinary tract infection (uti)
1,acute sinusitis,acute sinusitis
2,parkinson disease,extrapyramidal effect of drugs
3,cornea infection,corneal disorder
4,fibromyalgia,fibromyalgia
5,extrapyramidal effect of drugs,extrapyramidal effect of drugs
6,postpartum depression,postpartum depression
7,hashimoto thyroiditis,sjögren's syndrome
8,scleroderma,scleroderma
9,chronic migraine,ethylene glycol poisoning-2


In [None]:
# ================================
# 8. BASIC INTERACTIVE UI
# ================================

# Widgets for simple UI
from IPython.display import Markdown

symptom_input = widgets.Textarea(
    description='Enter Symptoms:',
    placeholder='e.g. I have fatigue, irregular sugar level, excessive hunger, increased appetite. What could be possible disease?',
    layout=widgets.Layout(width='600px', height='100px')
)

submit_button = widgets.Button(description="Check Disease")
output_area = widgets.Output()

def on_button_click(b):
    with output_area:
        output_area.clear_output()
        user_input = symptom_input.value.strip()
        if user_input:
            response = generate_response(user_input)
            print(response)
            #display(Markdown(f"**Diagnosis Result:**\n\n{response}"))
        else:
            print("Please enter symptoms.")

submit_button.on_click(on_button_click)

# Display the UI
display(symptom_input, submit_button, output_area)

Textarea(value='', description='Enter Symptoms:', layout=Layout(height='100px', width='600px'), placeholder='e…

Button(description='Check Disease', style=ButtonStyle())

Output()