In [1]:
Core ideas:
- Knowledge Representation: diseases, symptoms, risk factors, tests, red flags
- Logic programming via declarative IF-THEN rules with certainty factors
- Inference: rule matching over patient facts (+ lightweight backward chaining)
- Explainability: rule audit trail, per-disease score breakdown, red flags
"""

from dataclasses import dataclass, field
from typing import Callable, Dict, List, Tuple, Any, Optional
import math
import sys

# -----------------------------
# Utilities
# -----------------------------
def yesno(s: str) -> bool:
    return str(s).strip().lower() in {"y", "yes", "1", "true", "t"}

def safe_float(v: Any, default: float = float("nan")) -> float:
    try:
        return float(v)
    except Exception:
        return default

def sigmoid(x: float) -> float:
    # Smooths numeric features into (0,1) for certainty adjustments
    return 1 / (1 + math.exp(-x))

# -----------------------------
# Knowledge Representation
# -----------------------------
@dataclass
class Rule:
    """IF conditions -> THEN (conclusion, weight)."""
    name: str
    conditions: List[Callable[[Dict[str, Any]], bool]]
    conclusion: Tuple[str, float]  # (disease_id, certainty_weight 0..1)
    why: str  # human explanation for the rule

@dataclass
class Disease:
    id: str
    label: str
    summary: str
    typical_symptoms: List[str]
    risk_factors: List[str] = field(default_factory=list)
    recommended_tests: List[str] = field(default_factory=list)
    red_flags: List[str] = field(default_factory=list)

@dataclass
class KnowledgeBase:
    diseases: Dict[str, Disease]
    rules: List[Rule]

# -----------------------------
# Domain Knowledge Base (editable, declarative)
# NOTE: This is a starter KB. Expand with your local guidelines & SME review.
# -----------------------------
def build_kb() -> KnowledgeBase:
    diseases = {
        "common_cold": Disease(
            id="common_cold",
            label="Common Cold (Viral URTI)",
            summary="Self-limiting viral infection with runny nose, sneezing, sore throat, mild fever.",
            typical_symptoms=["runny_nose", "sore_throat", "sneezing", "mild_fever", "cough"],
            risk_factors=["recent_cold_contact"],
            recommended_tests=["No tests typically required; symptomatic management."],
            red_flags=[]
        ),
        "influenza": Disease(
            id="influenza",
            label="Influenza (Flu)",
            summary="Acute onset fever, myalgia, headache, cough; seasonal spikes.",
            typical_symptoms=["fever", "myalgia", "headache", "dry_cough", "fatigue"],
            risk_factors=["seasonal_outbreak", "no_flu_vaccine", "elderly", "chronic_disease"],
            recommended_tests=["Rapid influenza test (if available)", "Consider CBC if severe"],
            red_flags=["dyspnea", "hypoxia", "confusion", "chest_pain"]
        ),
        "pneumonia": Disease(
            id="pneumonia",
            label="Community-Acquired Pneumonia",
            summary="Infectious consolidation; cough, fever, pleuritic chest pain, dyspnea.",
            typical_symptoms=["fever", "productive_cough", "chest_pain", "dyspnea"],
            risk_factors=["elderly", "smoker", "COPD", "recent_influenza"],
            recommended_tests=["Chest X-ray", "SpO2", "CBC", "CRP"],
            red_flags=["SpO2<92", "resp_rate>30", "confusion", "SBP<90"]
        ),
        "dengue": Disease(
            id="dengue",
            label="Dengue Fever",
            summary="Acute febrile illness with headache, retro-orbital pain, myalgia, rash; warning signs around day 3–7.",
            typical_symptoms=["fever", "headache", "rash", "myalgia", "retro_orbital_pain"],
            risk_factors=["mosquito_exposure", "outbreak_area"],
            recommended_tests=["NS1 (day 1–5)", "IgM/IgG (after day 5)", "CBC for platelet, hematocrit"],
            red_flags=["bleeding", "persistent_vomiting", "severe_abdominal_pain", "mucosal_bleed", "lethargy"]
        ),
        "t2dm": Disease(
            id="t2dm",
            label="Type 2 Diabetes Mellitus",
            summary="Metabolic disease with polyuria, polydipsia, weight changes; risk with obesity, family history.",
            typical_symptoms=["polyuria", "polydipsia", "polyphagia", "fatigue", "blurred_vision"],
            risk_factors=["obesity", "family_history_dm", "sedentary", "gestational_dm_history"],
            recommended_tests=["Fasting plasma glucose", "HbA1c", "OGTT (if needed)"],
            red_flags=["DKA_signs (abdominal_pain, vomiting, rapid_breathing)"]
        ),
    }

    # Helper condition builders
    def has(symptom: str):
        return lambda facts: bool(facts.get("symptoms", {}).get(symptom, False))

    def nothas(symptom: str):
        return lambda facts: not bool(facts.get("symptoms", {}).get(symptom, False))

    def risk(r: str):
        return lambda facts: bool(facts.get("risks", {}).get(r, False))

    def age_ge(threshold: int):
        return lambda facts: safe_float(facts.get("age", 0)) >= threshold

    def temp_ge(celsius: float):
        return lambda facts: safe_float(facts.get("vitals", {}).get("temp_c", float("nan"))) >= celsius

    def spo2_lt(pct: float):
        return lambda facts: safe_float(facts.get("vitals", {}).get("spo2", float("nan"))) < pct

    def resp_rate_gt(n: int):
        return lambda facts: safe_float(facts.get("vitals", {}).get("resp_rate", float("nan"))) > n

    def day_of_illness_between(a: int, b: int):
        return lambda facts: a <= safe_float(facts.get("timeline", {}).get("day_of_illness", float("nan"))) <= b

    def lab_lt(name: str, threshold: float):
        return lambda facts: safe_float(facts.get("labs", {}).get(name, float("nan"))) < threshold

    def lab_gt(name: str, threshold: float):
        return lambda facts: safe_float(facts.get("labs", {}).get(name, float("nan"))) > threshold

    rules: List[Rule] = [
        # Common cold
        Rule(
            name="cold_core",
            conditions=[has("runny_nose"), has("sore_throat"), has("sneezing")],
            conclusion=("common_cold", 0.65),
            why="Runny nose + sore throat + sneezing are classic viral URTI features."
        ),
        Rule(
            name="cold_mild_fever",
            conditions=[has("runny_nose"), has("cough"), has("mild_fever"), nothas("myalgia")],
            conclusion=("common_cold", 0.35),
            why="Mild fever with coryzal symptoms and cough, without systemic myalgia, suggests common cold."
        ),

        # Influenza
        Rule(
            name="flu_core",
            conditions=[has("fever"), has("myalgia"), has("headache"), has("dry_cough")],
            conclusion=("influenza", 0.7),
            why="Abrupt fever + myalgia + headache + dry cough are hallmark of influenza."
        ),
        Rule(
            name="flu_season_risk",
            conditions=[has("fever"), has("cough"), risk("seasonal_outbreak")],
            conclusion=("influenza", 0.3),
            why="Cough with fever during seasonal outbreak increases odds of influenza."
        ),

        # Pneumonia
        Rule(
            name="pna_core",
            conditions=[has("fever"), has("productive_cough"), has("chest_pain")],
            conclusion=("pneumonia", 0.55),
            why="Fever + productive cough + pleuritic chest pain are typical for pneumonia."
        ),
        Rule(
            name="pna_severity",
            conditions=[spo2_lt(92), resp_rate_gt(30)],
            conclusion=("pneumonia", 0.45),
            why="Hypoxia (SpO2<92) and tachypnea (>30/min) indicate pneumonia severity."
        ),

        # Dengue
        Rule(
            name="dengue_core",
            conditions=[has("fever"), has("headache"), has("myalgia"), risk("outbreak_area")],
            conclusion=("dengue", 0.6),
            why="Acute fever + headache + myalgia in outbreak area is suggestive of dengue."
        ),
        Rule(
            name="dengue_retroorbital",
            conditions=[has("retro_orbital_pain")],
            conclusion=("dengue", 0.25),
            why="Retro-orbital pain increases diagnostic suspicion of dengue."
        ),
        Rule(
            name="dengue_warning_platelet",
            conditions=[day_of_illness_between(3, 7), lab_lt("platelet", 100_000)],
            conclusion=("dengue", 0.35),
            why="Day 3–7 with thrombocytopenia is compatible with dengue warning phase."
        ),

        # Type 2 DM
        Rule(
            name="t2dm_core",
            conditions=[has("polyuria"), has("polydipsia")],
            conclusion=("t2dm", 0.55),
            why="Polyuria + polydipsia are classic hyperglycemia features."
        ),
        Rule(
            name="t2dm_risk",
            conditions=[risk("obesity"), risk("family_history_dm")],
            conclusion=("t2dm", 0.35),
            why="Obesity and family history substantially increase T2DM risk."
        ),
        Rule(
            name="t2dm_hba1c",
            conditions=[lab_gt("hba1c", 6.4)],
            conclusion=("t2dm", 0.6),
            why="HbA1c ≥ 6.5% is diagnostic for diabetes (adults)."
        ),
    ]

    return KnowledgeBase(diseases=diseases, rules=rules)

# -----------------------------
# Inference Engine
# -----------------------------
class InferenceEngine:
    def __init__(self, kb: KnowledgeBase):
        self.kb = kb

    def evaluate(self, facts: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run rule matching over facts. Accumulate per-disease scores and why-trail.
        Returns a report dict containing:
          - scores: dict[disease_id] -> score (0..1)
          - explanations: dict[disease_id] -> list of (rule_name, why, weight)
          - red_flags_triggered: list[str]
          - recommendations: dict[disease_id] -> recommended tests
        """
        scores: Dict[str, float] = {}
        explanations: Dict[str, List[Tuple[str, str, float]]] = {}
        fired_rules: List[str] = []

        # Fire rules
        for rule in self.kb.rules:
            try:
                if all(cond(facts) for cond in rule.conditions):
                    d_id, w = rule.conclusion
                    scores[d_id] = scores.get(d_id, 0.0) + w
                    explanations.setdefault(d_id, []).append((rule.name, rule.why, w))
                    fired_rules.append(rule.name)
            except Exception as e:
                # Safety: ignore malformed condition exceptions
                continue

        # Clip and normalize to 0..1 using 1 - exp(-x) for diminishing returns
        for d_id, raw in list(scores.items()):
            scores[d_id] = 1 - math.exp(-raw)  # smooth cap < 1.0

        # Red flags detection
        triggered_red_flags: List[str] = []
        for d in self.kb.diseases.values():
            for rf in d.red_flags:
                # RF may be semantic patterns; we look in symptoms, vitals, labs
                if rf.startswith("SpO2<"):
                    thr = safe_float(rf.replace("SpO2<", ""))
                    val = safe_float(facts.get("vitals", {}).get("spo2", float("nan")))
                    if not math.isnan(val) and val < thr:
                        triggered_red_flags.append(f"{d.label}: SpO2<{thr} (actual {val})")
                elif rf.startswith("resp_rate>"):
                    thr = safe_float(rf.replace("resp_rate>", ""))
                    val = safe_float(facts.get("vitals", {}).get("resp_rate", float("nan")))
                    if not math.isnan(val) and val > thr:
                        triggered_red_flags.append(f"{d.label}: RR>{thr} (actual {val})")
                else:
                    # Symptom-style red flags
                    if facts.get("symptoms", {}).get(rf, False):
                        triggered_red_flags.append(f"{d.label}: {rf}")

        # Recommendations
        recommendations: Dict[str, List[str]] = {}
        for d_id, d in self.kb.diseases.items():
            if scores.get(d_id, 0) > 0:
                recommendations[d_id] = d.recommended_tests

        # Compile summary ranked list
        ranking = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)

        return {
            "scores": scores,
            "ranking": ranking,
            "explanations": explanations,
            "red_flags_triggered": triggered_red_flags,
            "recommendations": recommendations,
            "fired_rules": fired_rules,
        }

# -----------------------------
# Simple Backward-Chaining Query (optional)
# -----------------------------
def supports_disease(engine: InferenceEngine, facts: Dict[str, Any], disease_id: str) -> Dict[str, Any]:
    """
    Show which rules could support a given disease and whether their conditions hold.
    """
    details = []
    for r in engine.kb.rules:
        if r.conclusion[0] == disease_id:
            cond_results = []
            ok = True
            for c in r.conditions:
                try:
                    res = c(facts)
                except Exception:
                    res = False
                cond_results.append(res)
                ok = ok and res
            details.append({
                "rule": r.name, "all_met": ok, "conditions_met": cond_results, "why": r.why, "weight": r.conclusion[1]
            })
    return {"disease": disease_id, "rules": details}

# -----------------------------
# Interactive CLI
# -----------------------------
def prompt_patient_facts() -> Dict[str, Any]:
    print("\n=== Medical Diagnosis Expert System ===")
    print("Enter basic details (press Enter to skip):")
    age = input("Age (years): ").strip()
    sex = input("Sex (M/F): ").strip().lower()
    day = input("Day of illness (integer): ").strip()

    print("\nVitals (optional):")
    temp_c = input("Temperature (°C): ").strip()
    spo2 = input("SpO2 (%): ").strip()
    resp_rate = input("Respiratory rate (/min): ").strip()

    def ask_sym(name: str) -> bool:
        return yesno(input(f"Symptom: {name.replace('_',' ')}? (y/n): "))

    print("\nMark symptoms (y/n). Leave blank as 'n':")
    symptom_list = [
        "fever","mild_fever","myalgia","headache","dry_cough","cough","productive_cough",
        "runny_nose","sore_throat","sneezing","chest_pain","dyspnea",
        "rash","retro_orbital_pain","fatigue",
        "polyuria","polydipsia","polyphagia","blurred_vision",
        # Red-flag style:
        "bleeding","persistent_vomiting","severe_abdominal_pain","mucosal_bleed","lethargy","confusion"
    ]
    symptoms = {s: ask_sym(s) for s in symptom_list}

    print("\nRisk factors / context (y/n):")
    risk_list = [
        "seasonal_outbreak","no_flu_vaccine","elderly","chronic_disease","smoker","COPD",
        "recent_influenza","mosquito_exposure","outbreak_area","obesity","family_history_dm",
        "sedentary","gestational_dm_history","recent_cold_contact"
    ]
    risks = {r: yesno(input(f"Risk: {r.replace('_',' ')}? (y/n): ")) for r in risk_list}

    print("\nLabs (optional; Enter to skip):")
    platelet = input("Platelet count (/µL): ").strip()
    hba1c = input("HbA1c (%): ").strip()

    facts = {
        "age": safe_float(age) if age else None,
        "sex": sex if sex in {"m", "f"} else None,
        "timeline": {"day_of_illness": int(day) if day.isdigit() else None},
        "vitals": {
            "temp_c": safe_float(temp_c) if temp_c else None,
            "spo2": safe_float(spo2) if spo2 else None,
            "resp_rate": safe_float(resp_rate) if resp_rate else None,
        },
        "symptoms": symptoms,
        "risks": risks,
        "labs": {
            "platelet": safe_float(platelet) if platelet else None,
            "hba1c": safe_float(hba1c) if hba1c else None,
        }
    }
    # Convenience: set elderly flag if age>=65
    if (facts.get("age") or 0) >= 65:
        facts["risks"]["elderly"] = True
    return facts

# -----------------------------
# Reporting
# -----------------------------
def print_report(kb: KnowledgeBase, report: Dict[str, Any], top_k: int = 5) -> None:
    diseases = kb.diseases
    scores = report["scores"]
    ranking = report["ranking"]
    explanations = report["explanations"]
    red_flags = report["red_flags_triggered"]
    recommendations = report["recommendations"]

    print("\n=== Differential Diagnosis (Top Matches) ===")
    if not ranking:
        print("No rules fired. Consider entering more symptoms/history/labs.")
        return

    for i, (d_id, score) in enumerate(ranking[:top_k], start=1):
        d = diseases[d_id]
        print(f"\n{i}. {d.label}  | Score: {score:.2f}")
        print(f"   Summary: {d.summary}")
        if d.risk_factors:
            print(f"   Typical risks: {', '.join(d.risk_factors)}")
        if d.typical_symptoms:
            print(f"   Typical symptoms: {', '.join(d.typical_symptoms)}")

        # Why
        if explanations.get(d_id):
            print("   Why (fired rules):")
            for (rname, why, w) in explanations[d_id]:
                print(f"     - [{rname}] +{w:.2f}: {why}")

        # Recommendations
        if recommendations.get(d_id):
            print("   Suggested next steps / tests:")
            for t in recommendations[d_id]:
                print(f"     - {t}")

    if red_flags:
        print("\n=== RED FLAGS DETECTED (Prioritize urgent evaluation) ===")
        for rf in red_flags:
            print(f" - {rf}")

    print("\nNote: Scores are heuristic certainty estimates from rules (0–1, capped).")
    print("Clinical judgment & local guidelines are essential. This is an educational tool.")

# -----------------------------
# Demo / Program Entry
# -----------------------------
def demo_cases(engine: InferenceEngine):
    print("\n>>> Running quick demos")
    # Flu-like demo
    facts1 = {
        "age": 30, "sex": "f",
        "timeline": {"day_of_illness": 2},
        "vitals": {"temp_c": 39.0, "spo2": 97, "resp_rate": 18},
        "symptoms": {"fever": True, "myalgia": True, "headache": True, "dry_cough": True},
        "risks": {"seasonal_outbreak": True},
        "labs": {}
    }
    r1 = engine.evaluate(facts1)
    print_report(engine.kb, r1, top_k=3)

    # Dengue demo (warning phase)
    facts2 = {
        "age": 24, "sex": "m",
        "timeline": {"day_of_illness": 4},
        "vitals": {"temp_c": 38.3, "spo2": 98, "resp_rate": 16},
        "symptoms": {"fever": True, "headache": True, "myalgia": True, "retro_orbital_pain": True},
        "risks": {"outbreak_area": True, "mosquito_exposure": True},
        "labs": {"platelet": 85000}
    }
    r2 = engine.evaluate(facts2)
    print_report(engine.kb, r2, top_k=3)

def main():
    kb = build_kb()
    engine = InferenceEngine(kb)

    if len(sys.argv) > 1 and sys.argv[1] == "--demo":
        demo_cases(engine)
        return

    # Interactive flow
    facts = prompt_patient_facts()
    report = engine.evaluate(facts)
    print_report(kb, report, top_k=5)

    # Optional: inspect support for a specific disease
    q = input("\nInspect rules for a specific disease id (e.g., dengue, pneumonia, influenza, common_cold, t2dm), or Enter to skip: ").strip()
    if q and q in kb.diseases:
        detail = supports_disease(engine, facts, q)
        print(f"\nRules for {kb.diseases[q].label}:")
        for r in detail["rules"]:
            met = "✅" if r["all_met"] else "—"
            print(f" - {r['rule']} ({met}) weight {r['weight']:.2f}: {r['why']}")

if __name__ == "__main__":
    main()



=== Medical Diagnosis Expert System ===
Enter basic details (press Enter to skip):


Age (years):  25
Sex (M/F):  f
Day of illness (integer):  3



Vitals (optional):


Temperature (°C):  102
SpO2 (%):  96
Respiratory rate (/min):  90



Mark symptoms (y/n). Leave blank as 'n':


Symptom: fever? (y/n):  y
Symptom: mild fever? (y/n):  n
Symptom: myalgia? (y/n):  n
Symptom: headache? (y/n):  y
Symptom: dry cough? (y/n):  y
Symptom: cough? (y/n):  y
Symptom: productive cough? (y/n):  n
Symptom: runny nose? (y/n):  y
Symptom: sore throat? (y/n):  y
Symptom: sneezing? (y/n):  y
Symptom: chest pain? (y/n):  y
Symptom: dyspnea? (y/n):  n
Symptom: rash? (y/n):  y
Symptom: retro orbital pain? (y/n):  n
Symptom: fatigue? (y/n):  y
Symptom: polyuria? (y/n):  n
Symptom: polydipsia? (y/n):  n
Symptom: polyphagia? (y/n):  n
Symptom: blurred vision? (y/n):  n
Symptom: bleeding? (y/n):  n
Symptom: persistent vomiting? (y/n):  n
Symptom: severe abdominal pain? (y/n):  n
Symptom: mucosal bleed? (y/n):  n
Symptom: lethargy? (y/n):  n
Symptom: confusion? (y/n):  y



Risk factors / context (y/n):


Risk: seasonal outbreak? (y/n):  y
Risk: no flu vaccine? (y/n):  y
Risk: elderly? (y/n):  n
Risk: chronic disease? (y/n):  n
Risk: smoker? (y/n):  n
Risk: COPD? (y/n):  n
Risk: recent influenza? (y/n):  n
Risk: mosquito exposure? (y/n):  y
Risk: outbreak area? (y/n):  n
Risk: obesity? (y/n):  n
Risk: family history dm? (y/n):  n
Risk: sedentary? (y/n):  n
Risk: gestational dm history? (y/n):  n
Risk: recent cold contact? (y/n):  n



Labs (optional; Enter to skip):


Platelet count (/µL):  170k
HbA1c (%):  



=== Differential Diagnosis (Top Matches) ===

1. Common Cold (Viral URTI)  | Score: 0.48
   Summary: Self-limiting viral infection with runny nose, sneezing, sore throat, mild fever.
   Typical risks: recent_cold_contact
   Typical symptoms: runny_nose, sore_throat, sneezing, mild_fever, cough
   Why (fired rules):
     - [cold_core] +0.65: Runny nose + sore throat + sneezing are classic viral URTI features.
   Suggested next steps / tests:
     - No tests typically required; symptomatic management.

2. Influenza (Flu)  | Score: 0.26
   Summary: Acute onset fever, myalgia, headache, cough; seasonal spikes.
   Typical risks: seasonal_outbreak, no_flu_vaccine, elderly, chronic_disease
   Typical symptoms: fever, myalgia, headache, dry_cough, fatigue
   Why (fired rules):
     - [flu_season_risk] +0.30: Cough with fever during seasonal outbreak increases odds of influenza.
   Suggested next steps / tests:
     - Rapid influenza test (if available)
     - Consider CBC if severe

=== RED F


Inspect rules for a specific disease id (e.g., dengue, pneumonia, influenza, common_cold, t2dm), or Enter to skip:  dengue



Rules for Dengue Fever:
 - dengue_core (—) weight 0.60: Acute fever + headache + myalgia in outbreak area is suggestive of dengue.
 - dengue_retroorbital (—) weight 0.25: Retro-orbital pain increases diagnostic suspicion of dengue.
