In [1]:
import pandas as pd

# Load the CSV file
df = pd.read_csv("data/symptoms.csv")

# Show the first 5 rows
df.head()

Unnamed: 0,Symptom,Disease,Disease group (English name),Disease group (Sinhala name),Dosha types
0,Fever,Jvara,Endogenous,Nija Roga,Pitta
1,Chills with fever,Jvara,Endogenous,Nija Roga,Vata and Pitta
2,Sweating,Jvara,Endogenous,Nija Roga,Pitta
3,Thirst with fever,Jvara,Endogenous,Nija Roga,Pitta
4,Loss of appetite,Ajeerna,Endogenous,Nija Roga,Kapha


In [2]:
df["Disease group (English name)"].value_counts()

Disease group (English name)
Endogenous               149
Sweat                    113
Hereditary                98
Respiratory               90
thermoregulation          90
Reproductive – female     86
Exogenous                 85
Somatic                   85
Digestive                 84
Seasonal                  82
Musculoskeletal           74
Nervous                   70
Mental                    66
Incurable                 66
Cardio-blood              64
Over-nutrition            53
Natural                   52
External                  46
Urinary                   42
Internal                  41
Reproductive – male       40
Under-nutrition           37
Curable                   35
Congenital                34
Obesity                   34
Middle                    34
Psychosomatic             27
Metabolic                 18
Name: count, dtype: int64

In [3]:
df["Disease group (Sinhala name)"].value_counts()

Disease group (Sinhala name)
Sveda-vāha         203
Nija Roga          149
Sahaja              98
Prāṇavāha           90
Artava-vāha         86
Śārīrika            85
Āgantuka Roga       85
Annavāha            84
Asthi-vāha          74
Majjā-vāha          70
Mānāsika            66
Asādhya             66
Rasa/Rakta-vāha     64
Santarpanajanya     53
Medo-vāha           52
Svabhavaja          52
Ageing              51
Bāhya               46
Mutra-vāha          42
Abhyantara          41
Śukra-vāha          40
Apatarpanajanya     37
Sādhya              35
Madhyama            34
Garbhaja            34
Kālaja              31
Manodaihika         27
Name: count, dtype: int64

In [4]:
df["Disease"].value_counts()

Disease
Shvāsa             49
Atisveda           30
Shotha             24
Sthoulya           22
Prameha            22
                   ..
Mamsagata Vata      1
Rasa Agnimandya     1
Pittaja Kushta      1
Janu Shoola         1
Sphurana            1
Name: count, Length: 618, dtype: int64

In [5]:
df["Dosha types"].value_counts()

Dosha types
Vata               764
Kapha              403
Pitta              382
Vata and Pitta     102
Vata and Kapha      91
Pitta and Kapha     42
Tridosha            11
Name: count, dtype: int64

In [6]:
# Step 1: Load and Explore Data

import pandas as pd

# Load the CSV file (update the path if needed)
df = pd.read_csv("data/symptoms.csv")

# Show the first 5 rows (to get a quick look at the data)
print(df.head())

# Show the number of rows and columns
print("Shape of data:", df.shape)

# Show column names
print("Columns:", df.columns.tolist())

# Check for missing values
print("\nMissing values in each column:\n", df.isnull().sum())

             Symptom  Disease Disease group (English name)  \
0              Fever    Jvara                   Endogenous   
1  Chills with fever    Jvara                   Endogenous   
2           Sweating    Jvara                   Endogenous   
3  Thirst with fever    Jvara                   Endogenous   
4   Loss of appetite  Ajeerna                   Endogenous   

  Disease group (Sinhala name)     Dosha types  
0                    Nija Roga           Pitta  
1                    Nija Roga  Vata and Pitta  
2                    Nija Roga           Pitta  
3                    Nija Roga           Pitta  
4                    Nija Roga           Kapha  
Shape of data: (1795, 5)
Columns: ['Symptom', 'Disease', 'Disease group (English name)', 'Disease group (Sinhala name)', 'Dosha types']

Missing values in each column:
 Symptom                         0
Disease                         0
Disease group (English name)    0
Disease group (Sinhala name)    0
Dosha types                 

In [7]:
# Step 2: Convert symptoms (text) into numbers using TF-IDF

from sklearn.feature_extraction.text import TfidfVectorizer

# Select the text column (symptoms)
X = df["Symptom"]

# Create a TF-IDF Vectorizer
vectorizer = TfidfVectorizer()

# Fit the vectorizer and transform the symptoms
X_vectorized = vectorizer.fit_transform(X)

# Show the shape of the resulting matrix
print("TF-IDF matrix shape:", X_vectorized.shape)

TF-IDF matrix shape: (1795, 1157)


In [8]:
# Step 3: Training and evaluating different models

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

# Features (X) = Symptoms
X = df["Symptom"]

# Labels (y) = Disease group (English name) for now
y = df["Disease group (English name)"]

# Define the models we want to test
models = {
    "Logistic Regression": LogisticRegression(max_iter=1000),
    "Random Forest": RandomForestClassifier(n_estimators=200),
    "Decision Tree": DecisionTreeClassifier()
}

# Train and evaluate each model
for name, clf in models.items():
    pipe = Pipeline([
        ("tfidf", TfidfVectorizer()),  # Step 1: Convert text to numbers
        ("clf", clf)                   # Step 2: Train model
    ])
    pipe.fit(X, y)                      # Train
    y_pred = pipe.predict(X)            # Predict on same data
    acc = accuracy_score(y, y_pred)     # Check accuracy
    print(f"{name}: {acc:.2%}")

Logistic Regression: 68.86%
Random Forest: 99.44%
Decision Tree: 99.44%


In [9]:
# Step 3: Training and evaluating different models

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

# Features (X) = Symptoms
X = df["Symptom"]

# Labels (y) = Disease group (English name) for now
y = df["Disease group (Sinhala name)"]

# Define the models we want to test
models = {
    "Logistic Regression": LogisticRegression(max_iter=1000),
    "Random Forest": RandomForestClassifier(n_estimators=200),
    "Decision Tree": DecisionTreeClassifier()
}

# Train and evaluate each model
for name, clf in models.items():
    pipe = Pipeline([
        ("tfidf", TfidfVectorizer()),  # Step 1: Convert text to numbers
        ("clf", clf)                   # Step 2: Train model
    ])
    pipe.fit(X, y)                      # Train
    y_pred = pipe.predict(X)            # Predict on same data
    acc = accuracy_score(y, y_pred)     # Check accuracy
    print(f"{name}: {acc:.2%}")


Logistic Regression: 69.47%
Random Forest: 99.44%
Decision Tree: 99.44%


In [10]:
# Step 3: Training and evaluating different models

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

# Features (X) = Symptoms
X = df["Symptom"]

# Labels (y) = Disease group (English name) for now
y = df["Dosha types"]

# Define the models we want to test
models = {
    "Logistic Regression": LogisticRegression(max_iter=1000),
    "Random Forest": RandomForestClassifier(n_estimators=200),
    "Decision Tree": DecisionTreeClassifier()
}

# Train and evaluate each model
for name, clf in models.items():
    pipe = Pipeline([
        ("tfidf", TfidfVectorizer()),  # Step 1: Convert text to numbers
        ("clf", clf)                   # Step 2: Train model
    ])
    pipe.fit(X, y)                      # Train
    y_pred = pipe.predict(X)            # Predict on same data
    acc = accuracy_score(y, y_pred)     # Check accuracy
    print(f"{name}: {acc:.2%}")


Logistic Regression: 80.56%
Random Forest: 99.83%
Decision Tree: 99.83%


In [11]:
# ---------- English Group Model ----------
import numpy as np
import joblib
import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

# Load your dataset
df = pd.read_csv("data/symptoms.csv")   # adjust path to your CSV

X_en = df["Symptom"].astype(str)
y_en = df["Disease group (English name)"]

pipeline_en = Pipeline([
    ("tfidf", TfidfVectorizer(
        stop_words="english",
        lowercase=True,
        ngram_range=(1, 2),
        max_features=50_000,
        dtype=np.float32
    )),
    ("clf", LogisticRegression(
        solver="saga",
        max_iter=300,
        C=2.0,
        n_jobs=-1
    ))
])

pipeline_en.fit(X_en, y_en)
joblib.dump(pipeline_en, "disease_prediction_model_en_group.pkl", compress=3)
print("Saved: disease_prediction_model_en_group.pkl")

Saved: disease_prediction_model_en_group.pkl


In [12]:
# ---------- Sinhala Group Model ----------
import numpy as np
import joblib
import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

df = pd.read_csv("data/symptoms.csv")   # adjust if needed

X_si = df["Symptom"].astype(str)
y_si = df["Disease group (Sinhala name)"]   # change if your column is named differently

pipeline_si = Pipeline([
    ("tfidf", TfidfVectorizer(
        lowercase=True,
        ngram_range=(1, 2),
        max_features=50_000,
        dtype=np.float32
    )),
    ("clf", LogisticRegression(
        solver="saga",
        max_iter=300,
        C=2.0,
        n_jobs=-1
    ))
])

pipeline_si.fit(X_si, y_si)
joblib.dump(pipeline_si, "disease_prediction_model_si_group.pkl", compress=3)  # keep same filename you used
print("Saved: disease_prediction_model_si_group.pkl")


Saved: disease_prediction_model_si_group.pkl


In [None]:
# ---------- Disease Model ----------
import numpy as np
import joblib
import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

df = pd.read_csv("data/symptoms.csv")

X_disease = df["Symptom"].astype(str)
y_disease = df["Disease"]

pipeline_disease = Pipeline([
    ("tfidf", TfidfVectorizer(
        stop_words="english",
        lowercase=True,
        ngram_range=(1, 2),
        max_features=50_000,
        dtype=np.float32
    )),
    ("clf", LogisticRegression(
        solver="saga",
        max_iter=300,
        C=2.0,
        n_jobs=-1
    ))
])

pipeline_disease.fit(X_disease, y_disease)
joblib.dump(pipeline_disease, "disease_prediction_model_disease.pkl", compress=3)
print("Saved: ")


Saved: disease_prediction_model_disease.pkl


In [15]:
# ---------- Dosha Model ----------
import numpy as np
import joblib
import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

df = pd.read_csv("data/symptoms.csv")

X_dosha = df["Symptom"].astype(str)
y_dosha = df["Dosha types"]   # adjust if column has a different name

pipeline_dosha = Pipeline([
    ("tfidf", TfidfVectorizer(
        stop_words="english",
        lowercase=True,
        ngram_range=(1, 2),
        max_features=50_000,
        dtype=np.float32
    )),
    ("clf", LogisticRegression(
        solver="saga",
        max_iter=300,
        C=2.0,
        n_jobs=-1
    ))
])

pipeline_dosha.fit(X_dosha, y_dosha)
joblib.dump(pipeline_dosha, "dosha_classification_model.pkl", compress=3)
print("Saved: dosha_classification_model.pkl")


Saved: dosha_classification_model.pkl


In [16]:
%pip install ipywidgets

Note: you may need to restart the kernel to use updated packages.


In [17]:
import ipywidgets as widgets
widgets.IntSlider()

IntSlider(value=0)

In [18]:
# Suggestions (clickable list) + Predict Disease Group (English) from CSV (exact match only)
import pandas as pd, difflib
import ipywidgets as widgets
from IPython.display import display, Markdown

# Load data
df = pd.read_csv("data/symptoms.csv")
SYMPTOMS = sorted(set(df["Symptom"].dropna().astype(str).tolist()))
# map lowercase -> canonical symptom text
SYMPTOM_LOOKUP = {s.strip().lower(): s for s in SYMPTOMS}

# --- Suggestion helper (prefix + substring + fuzzy) ---
def get_suggestions(query, k=12):
    q = (query or "").strip().lower()
    if not q:
        return []
    prefix = [s for s in SYMPTOMS if s.lower().startswith(q)]
    substr = [s for s in SYMPTOMS if q in s.lower() and s not in prefix]
    fuzzy  = difflib.get_close_matches(q, SYMPTOMS, n=k*2, cutoff=0.6)
    fuzzy  = [s for s in fuzzy if s not in prefix and s not in substr]
    # merge + dedup
    out, seen = [], set()
    for s in prefix + substr + fuzzy:
        if s not in seen:
            out.append(s); seen.add(s)
        if len(out) >= k:
            break
    return out

def resolve_exact(text):
    return SYMPTOM_LOOKUP.get((text or "").strip().lower())

# --- Widgets ---
inp = widgets.Text(
    placeholder="Type or pick a symptom…",
    description="Symptom:",
    layout=widgets.Layout(width="80%")
)
# clickable suggestion list (no buttons)
sugg_list = widgets.Select(options=[], rows=6, layout=widgets.Layout(width="80%"))
pred_btn = widgets.Button(
    description="Predict (English Group)",
    button_style="success",   # 'primary', 'success', 'info', 'warning', 'danger' or ''
    layout=widgets.Layout(width="200px", height="50px")
)

out = widgets.Output()

# --- Wire up interactions ---
def on_text_change(change):
    suggs = get_suggestions(change["new"], k=12)
    sugg_list.options = suggs

def on_select_change(change):
    # when user clicks a suggestion, fill the input box
    if change["new"]:
        inp.value = change["new"]

def on_predict(_):
    out.clear_output()
    with out:
        canon = resolve_exact(inp.value)
        if not canon:
            display(Markdown("> ⚠️ Please select a valid symptom from the list or type an exact symptom from the CSV."))
            return
        # exact row from CSV: show ONLY the English disease group
        label = str(df.loc[df["Symptom"] == canon, "Disease group (English name)"].iloc[0])
        display(Markdown(f"### 🧾 Disease Group (English)\n**Symptom:** `{canon}`\n**Result:** `{label}`"))

# Events
inp.observe(on_text_change, names="value")
sugg_list.observe(on_select_change, names="value")
pred_btn.on_click(on_predict)

# Initial render
display(inp, sugg_list, pred_btn, out)


Text(value='', description='Symptom:', layout=Layout(width='80%'), placeholder='Type or pick a symptom…')

Select(layout=Layout(width='80%'), options=(), rows=6, value=None)

Button(button_style='success', description='Predict (English Group)', layout=Layout(height='50px', width='200p…

Output()

In [19]:
# Suggestions (clickable list) + Predict Disease Group (Sinhala) from CSV (exact match only)
import pandas as pd, difflib
import ipywidgets as widgets
from IPython.display import display, Markdown

# Load data
df = pd.read_csv("data/symptoms.csv")
SYMPTOMS = sorted(set(df["Symptom"].dropna().astype(str).tolist()))
SYMPTOM_LOOKUP = {s.strip().lower(): s for s in SYMPTOMS}

def get_suggestions(query, k=12):
    q = (query or "").strip().lower()
    if not q: return []
    prefix = [s for s in SYMPTOMS if s.lower().startswith(q)]
    substr = [s for s in SYMPTOMS if q in s.lower() and s not in prefix]
    fuzzy  = difflib.get_close_matches(q, SYMPTOMS, n=k*2, cutoff=0.6)
    fuzzy  = [s for s in fuzzy if s not in prefix and s not in substr]
    out, seen = [], set()
    for s in prefix + substr + fuzzy:
        if s not in seen:
            out.append(s); seen.add(s)
        if len(out) >= k: break
    return out

def resolve_exact(text):
    return SYMPTOM_LOOKUP.get((text or "").strip().lower())

# Widgets
inp_si = widgets.Text(placeholder="Type or pick a symptom…", description="Symptom:", layout=widgets.Layout(width="80%"))
sugg_list_si = widgets.Select(options=[], rows=6, layout=widgets.Layout(width="80%"))
pred_btn_si = widgets.Button(
    description="Predict (Sinhala Group)",
    button_style="success",   # 'primary', 'success', 'info', 'warning', 'danger' or ''
    layout=widgets.Layout(width="200px", height="50px")
)

out_si = widgets.Output()

def on_text_change_si(change):
    sugg_list_si.options = get_suggestions(change["new"], k=12)

def on_select_change_si(change):
    if change["new"]:
        inp_si.value = change["new"]

def on_predict_si(_):
    out_si.clear_output()
    with out_si:
        canon = resolve_exact(inp_si.value)
        if not canon:
            display(Markdown("> ⚠️ Please select a valid symptom from the list or type an exact symptom from the CSV."))
            return
        label = str(df.loc[df["Symptom"] == canon, "Disease group (Sinhala name)"].iloc[0])
        display(Markdown(f"### 🧾 Diease group(Sinahala)\n**Symptom:** `{canon}`\n**Result:** `{label}`"))

inp_si.observe(on_text_change_si, names="value")
sugg_list_si.observe(on_select_change_si, names="value")
pred_btn_si.on_click(on_predict_si)

display(inp_si, sugg_list_si, pred_btn_si, out_si)


Text(value='', description='Symptom:', layout=Layout(width='80%'), placeholder='Type or pick a symptom…')

Select(layout=Layout(width='80%'), options=(), rows=6, value=None)

Button(button_style='success', description='Predict (Sinhala Group)', layout=Layout(height='50px', width='200p…

Output()

In [20]:
# Suggestions (clickable list) + Predict Disease (Ayurveda) from CSV (exact match only)
import pandas as pd, difflib
import ipywidgets as widgets
from IPython.display import display, Markdown

# Load data
df = pd.read_csv("data/symptoms.csv")
SYMPTOMS = sorted(set(df["Symptom"].dropna().astype(str).tolist()))
SYMPTOM_LOOKUP = {s.strip().lower(): s for s in SYMPTOMS}  # lowercase -> canonical

def get_suggestions(query, k=12):
    q = (query or "").strip().lower()
    if not q:
        return []
    prefix = [s for s in SYMPTOMS if s.lower().startswith(q)]
    substr = [s for s in SYMPTOMS if q in s.lower() and s not in prefix]
    fuzzy  = difflib.get_close_matches(q, SYMPTOMS, n=k*2, cutoff=0.6)
    fuzzy  = [s for s in fuzzy if s not in prefix and s not in substr]
    # merge + dedup
    out, seen = [], set()
    for s in prefix + substr + fuzzy:
        if s not in seen:
            out.append(s); seen.add(s)
        if len(out) >= k:
            break
    return out

def resolve_exact(text):
    return SYMPTOM_LOOKUP.get((text or "").strip().lower())

# Widgets
inp_dis = widgets.Text(
    placeholder="Type or pick a symptom…",
    description="Symptom:",
    layout=widgets.Layout(width="80%")
)
sugg_list_dis = widgets.Select(options=[], rows=6, layout=widgets.Layout(width="80%"))
pred_btn_dis = widgets.Button(
    description="Predict (Disease)",
    button_style="success",   # 'primary', 'success', 'info', 'warning', 'danger' or ''
    layout=widgets.Layout(width="200px", height="50px")
)

out_dis = widgets.Output()

def on_text_change_dis(change):
    sugg_list_dis.options = get_suggestions(change["new"], k=12)

def on_select_change_dis(change):
    if change["new"]:
        inp_dis.value = change["new"]  # clicking suggestion fills input

def on_predict_dis(_):
    out_dis.clear_output()
    with out_dis:
        canon = resolve_exact(inp_dis.value)
        if not canon:
            display(Markdown("> ⚠️ Please select a valid symptom from the list or type an exact symptom from the CSV."))
            return
        # exact row from CSV: show ONLY the Disease (Ayurveda) column
        label = str(df.loc[df["Symptom"] == canon, "Disease"].iloc[0])
        display(Markdown(f"### 🧾 Disease (Ayurveda)\n**Symptom:** `{canon}`\n**Result:** `{label}`"))

# Wire up
inp_dis.observe(on_text_change_dis, names="value")
sugg_list_dis.observe(on_select_change_dis, names="value")
pred_btn_dis.on_click(on_predict_dis)

display(inp_dis, sugg_list_dis, pred_btn_dis, out_dis)


Text(value='', description='Symptom:', layout=Layout(width='80%'), placeholder='Type or pick a symptom…')

Select(layout=Layout(width='80%'), options=(), rows=6, value=None)

Button(button_style='success', description='Predict (Disease)', layout=Layout(height='50px', width='200px'), s…

Output()

In [21]:
# Suggestions (clickable list) + Predict Dosha Type from CSV (exact match only)
import pandas as pd, difflib
import ipywidgets as widgets
from IPython.display import display, Markdown

# Load data
df = pd.read_csv("data/symptoms.csv")
SYMPTOMS = sorted(set(df["Symptom"].dropna().astype(str).tolist()))
SYMPTOM_LOOKUP = {s.strip().lower(): s for s in SYMPTOMS}

def get_suggestions(query, k=12):
    q = (query or "").strip().lower()
    if not q: return []
    prefix = [s for s in SYMPTOMS if s.lower().startswith(q)]
    substr = [s for s in SYMPTOMS if q in s.lower() and s not in prefix]
    fuzzy  = difflib.get_close_matches(q, SYMPTOMS, n=k*2, cutoff=0.6)
    fuzzy  = [s for s in fuzzy if s not in prefix and s not in substr]
    out, seen = [], set()
    for s in prefix + substr + fuzzy:
        if s not in seen:
            out.append(s); seen.add(s)
        if len(out) >= k: break
    return out

def resolve_exact(text):
    return SYMPTOM_LOOKUP.get((text or "").strip().lower())

# Widgets
inp_do = widgets.Text(placeholder="Type or pick a symptom…", description="Symptom:", layout=widgets.Layout(width="80%"))
sugg_list_do = widgets.Select(options=[], rows=6, layout=widgets.Layout(width="80%"))
pred_btn_do = widgets.Button(
    description="Predict (Dosha Type)",
    button_style="success",   # 'primary', 'success', 'info', 'warning', 'danger' or ''
    layout=widgets.Layout(width="200px", height="50px")
)

out_do = widgets.Output()

def on_text_change_do(change):
    sugg_list_do.options = get_suggestions(change["new"], k=12)

def on_select_change_do(change):
    if change["new"]:
        inp_do.value = change["new"]

def on_predict_do(_):
    out_do.clear_output()
    with out_do:
        canon = resolve_exact(inp_do.value)
        if not canon:
            display(Markdown("> ⚠️ Please select a valid symptom from the list or type an exact symptom from the CSV."))
            return
        label = str(df.loc[df["Symptom"] == canon, "Dosha types"].iloc[0])
        display(Markdown(f"### 🧾 Dosha Type\n**Symptom:** `{canon}`\n**Result:** `{label}`"))

inp_do.observe(on_text_change_do, names="value")
sugg_list_do.observe(on_select_change_do, names="value")
pred_btn_do.on_click(on_predict_do)

display(inp_do, sugg_list_do, pred_btn_do, out_do)


Text(value='', description='Symptom:', layout=Layout(width='80%'), placeholder='Type or pick a symptom…')

Select(layout=Layout(width='80%'), options=(), rows=6, value=None)

Button(button_style='success', description='Predict (Dosha Type)', layout=Layout(height='50px', width='200px')…

Output()

In [22]:
# Suggestions (clickable list) + Predict full risk (EN group + Dosha + weights + score + level)
import pandas as pd, difflib
import ipywidgets as widgets
from IPython.display import display, Markdown

# ---------------------------
# Load data
# ---------------------------
df = pd.read_csv("data/symptoms.csv")
SYMPTOMS = sorted(set(df["Symptom"].dropna().astype(str).tolist()))
SYMPTOM_LOOKUP = {s.strip().lower(): s for s in SYMPTOMS}  # lowercase -> canonical

# ---------------------------
# Weights (0–10) aligned to your CSV Disease Groups + Doshas
# ---------------------------
disease_group_weight = {
    "Incurable": 10,
    "Cardio-blood": 9,
    "Respiratory": 9,
    "Mental": 8,
    "Nervous": 8,
    "Digestive": 7,
    "Musculoskeletal": 7,
    "Urinary": 7,
    "Reproductive – female": 7,
    "Reproductive – male": 7,
    "Congenital": 6,
    "Obesity": 6,
    "Over-nutrition": 6,
    "Under-nutrition": 6,
    "Internal": 5,
    "Endogenous": 5,
    "Exogenous": 5,
    "Somatic": 5,
    "Psychosomatic": 5,
    "Seasonal": 4,
    "Sweat": 4,
    "Natural": 4,
    "External": 4,
    "Middle": 4,
    "Hereditary": 3,
    "thermoregulation": 3,
    "Curable": 2,
    "Metabolic": 2,
}

dosha_weight = {
    "Tridosha": 9,
    "Pitta and Kapha": 8,
    "Vata and Pitta": 7,
    "Vata and Kapha": 7,
    "Pitta": 6,
    "Vata": 5,
    "Kapha": 4,
}

W_GROUP = 0.6   # weight in formula for disease group severity
W_DOSHA = 0.4   # weight in formula for dosha modulation

# ---------------------------
# Helpers
# ---------------------------
def resolve_exact(text: str):
    return SYMPTOM_LOOKUP.get((text or "").strip().lower())

def risk_level_from_score(score: float):
    if score < 4:
        return "🟢 Low Risk"
    elif score < 7:
        return "🟠 Medium Risk"
    else:
        return "🔴 High Risk"

def compute_risk_for_symptom(symptom_text: str, w_group=W_GROUP, w_dosha=W_DOSHA):
    """Exact match lookup from CSV. Returns dict with all required fields."""
    canon = resolve_exact(symptom_text)
    if not canon:
        return {"found": False, "message": "Symptom not found. Please select from suggestions or type exact."}

    row = df.loc[df["Symptom"] == canon].iloc[0]
    disease_name = str(row["Disease"])  # Ayurveda disease name column
    group_en = str(row["Disease group (English name)"])
    group_si = str(row["Disease group (Sinhala name)"])
    dosha = str(row["Dosha types"])

    g_w = float(disease_group_weight.get(group_en, 0.0))
    d_w = float(dosha_weight.get(dosha, 0.0))

    score = round(w_group * g_w + w_dosha * d_w, 2)
    level = risk_level_from_score(score)

    return {
        "found": True,
        "symptom": canon,
        "disease_name": disease_name,
        "disease_group_en": group_en,
        "disease_group_si": group_si,
        "dosha": dosha,
        "group_weight": g_w,
        "dosha_weight": d_w,
        "formula": f"Risk = {w_group}×{g_w} + {w_dosha}×{d_w}",
        "risk_score_0_10": score,
        "risk_level": level
    }

def get_suggestions(query, k=12):
    q = (query or "").strip().lower()
    if not q:
        return []
    prefix = [s for s in SYMPTOMS if s.lower().startswith(q)]
    substr = [s for s in SYMPTOMS if q in s.lower() and s not in prefix]
    fuzzy  = difflib.get_close_matches(q, SYMPTOMS, n=k*2, cutoff=0.6)
    fuzzy  = [s for s in fuzzy if s not in prefix and s not in substr]
    # merge + dedup
    out, seen = [], set()
    for s in prefix + substr + fuzzy:
        if s not in seen:
            out.append(s); seen.add(s)
        if len(out) >= k:
            break
    return out

# ---------------------------
# Widgets (styled Predict button)
# ---------------------------
inp = widgets.Text(
    placeholder="Type or pick a symptom…",
    description="Symptom:",
    layout=widgets.Layout(width="80%")
)
sugg_list = widgets.Select(options=[], rows=6, layout=widgets.Layout(width="80%"))

pred_btn = widgets.Button(
    description="Predict Risk",
    button_style="success",
    layout=widgets.Layout(width="220px", height="55px")
)
pred_btn.style.button_color = "#4CAF50"

out = widgets.Output()

# ---------------------------
# Wire up interactions
# ---------------------------
def on_text_change(change):
    sugg_list.options = get_suggestions(change["new"], k=12)

def on_select_change(change):
    if change["new"]:
        inp.value = change["new"]

def on_predict(_):
    out.clear_output()
    with out:
        res = compute_risk_for_symptom(inp.value)
        if not res.get("found"):
            display(Markdown("> ⚠️ Please select a valid symptom from the list or type an exact symptom from the CSV."))
            return

        display(Markdown(
f"""### 🧾 Risk Assessment
**Symptom:** `{res['symptom']}`  
**Disease Name (Ayurveda):** `{res['disease_name']}`  
**Disease Group (English):** `{res['disease_group_en']}`  
**Disease Group (Sinhala):** `{res['disease_group_si']}`  
**Dosha Type:** `{res['dosha']}`  

**Disease Group Weight:** `{res['group_weight']}`  
**Dosha Weight:** `{res['dosha_weight']}`  

**How Risk is Calculated:** `{res['formula']}`  
**Risk Score (0–10):** `{res['risk_score_0_10']}`  
**Risk Level:** **{res['risk_level']}**"""
        ))

# Events
inp.observe(on_text_change, names="value")
sugg_list.observe(on_select_change, names="value")
pred_btn.on_click(on_predict)

# Render
display(inp, sugg_list, pred_btn, out)

Text(value='', description='Symptom:', layout=Layout(width='80%'), placeholder='Type or pick a symptom…')

Select(layout=Layout(width='80%'), options=(), rows=6, value=None)

Button(button_style='success', description='Predict Risk', layout=Layout(height='55px', width='220px'), style=…

Output()