<a href="https://colab.research.google.com/github/Jaderfonseca/Medical-Diagnostics-with-Bayesian-Reasoning/blob/main/diagnostics_bayes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# --- Setup (cell 1) ---
import numpy as np
import math
from pathlib import Path
import matplotlib.pyplot as plt

# Reproducibility: a single seeded generator shared by every simulation cell below
rng = np.random.default_rng(seed=42)

# Output folder for saved figures (created on first run, reused afterwards)
FIG_DIR = Path("figures")
FIG_DIR.mkdir(exist_ok=True, parents=True)


In [None]:
# --- Core Bayes helpers (cell 2) ---
def odds(p: float) -> float:
    """Convert a probability to odds, ``p / (1 - p)``.

    Args:
        p: Probability in the closed interval [0, 1].

    Returns:
        The odds in favour of the event: ``0.0`` for ``p == 0`` and
        ``math.inf`` for ``p == 1`` (certainty has unbounded odds).

    Raises:
        ValueError: If ``p`` is outside [0, 1] or is NaN.
    """
    # The chained comparison is False for NaN as well, so NaN is rejected too.
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"probability must be in [0, 1], got {p!r}")
    if p == 1.0:
        return math.inf
    return p / (1 - p)

def prob_from_odds(o: float) -> float:
    """Convert odds to a probability, ``o / (1 + o)``.

    Infinite odds map to probability 1.0. The naive formula would give
    ``inf / inf == nan`` there, which breaks the round trip for ``p == 1``.

    Args:
        o: Non-negative odds; ``math.inf`` is allowed.

    Returns:
        Probability in [0, 1]. A NaN input (an undefined posterior, e.g. the
        product ``0 * inf``) is returned as NaN.

    Raises:
        ValueError: If ``o`` is negative.
    """
    if math.isnan(o):
        return math.nan  # keep "undefined" undefined rather than guessing
    if o < 0:
        raise ValueError(f"odds must be non-negative, got {o!r}")
    if math.isinf(o):
        return 1.0
    return o / (1 + o)

def likelihood_ratios(sens: float, spec: float):
    """Return ``(LR+, LR-)`` from sensitivity and specificity.

    LR+ = sens / (1 - spec): the factor a positive result applies to prior odds.
    LR- = (1 - sens) / spec: the factor a negative result applies to prior odds.

    Boundary tests no longer raise ``ZeroDivisionError``. A zero denominator
    gives ``math.inf`` when the numerator is positive (e.g. a perfectly
    specific test has LR+ = inf) and NaN for the undefined 0/0 case.

    Args:
        sens: Sensitivity, P(test + | disease), in [0, 1].
        spec: Specificity, P(test - | no disease), in [0, 1].

    Raises:
        ValueError: If ``sens`` or ``spec`` is outside [0, 1] or is NaN.
    """
    for label, value in (("sens", sens), ("spec", spec)):
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"{label} must be in [0, 1], got {value!r}")

    def _ratio(num: float, den: float) -> float:
        # x / 0 is decisive (+inf) for x > 0; 0 / 0 carries no information.
        if den == 0:
            return math.inf if num > 0 else math.nan
        return num / den

    lr_pos = _ratio(sens, 1 - spec)
    lr_neg = _ratio(1 - sens, spec)
    return lr_pos, lr_neg

def posterior_probs_from_test(sens: float, spec: float, prevalence: float):
    """
    Apply Bayes' theorem in odds form: posterior odds = prior odds x LR.

    The prior odds come from ``odds(prevalence)``. The likelihood ratios come
    from ``likelihood_ratios(sens, spec)``. Each posterior is converted back
    with ``prob_from_odds``.

    Returns:
        A 2-tuple ``(P(disease | test +), P(disease | test -))``. Behaviour at
        degenerate inputs (e.g. ``spec == 1`` or ``prevalence == 1``) is
        whatever those helpers produce.
    """
    base_odds = odds(prevalence)
    lr_positive, lr_negative = likelihood_ratios(sens, spec)
    return tuple(prob_from_odds(base_odds * lr) for lr in (lr_positive, lr_negative))


In [None]:
# --- Simulation utilities (cell 3) ---
def simulate_cohort(n: int, sens: float, spec: float, prevalence: float, rng=None):
    """Simulate ``n`` patients taking a binary diagnostic test.

    Each patient independently has the disease with probability
    ``prevalence``. Diseased patients test positive with probability ``sens``;
    healthy patients test positive with probability ``1 - spec``.

    Args:
        n: Cohort size. 0 is allowed and yields an empty cohort.
        sens: Sensitivity, P(test + | disease).
        spec: Specificity, P(test - | no disease).
        prevalence: P(disease) in the cohort.
        rng: NumPy ``Generator`` to draw from. A fresh, unseeded one is
            created when ``None``.

    Returns:
        A dict with integer counts ``"n"``, ``"TP"``, ``"FP"``, ``"FN"``,
        ``"TN"`` and float metrics ``"PPV"``, ``"NPV"``, ``"ACC"``. A metric
        is NaN when its denominator is zero (including ``ACC`` for ``n == 0``).
    """
    if rng is None:
        rng = np.random.default_rng()

    # NOTE: three separate size-n draws, in exactly this order, so results for
    # a given seed are reproducible. Do not merge or reorder them.
    has_disease = rng.random(n) < prevalence
    # Outcome each patient would get if diseased / if healthy.
    test_pos_if_d = rng.random(n) < sens
    test_pos_if_nd = rng.random(n) < (1 - spec)
    test_positive = np.where(has_disease, test_pos_if_d, test_pos_if_nd)

    test_negative = ~test_positive
    healthy = ~has_disease

    # Confusion counts
    tp = int(np.count_nonzero(test_positive & has_disease))
    fp = int(np.count_nonzero(test_positive & healthy))
    fn = int(np.count_nonzero(test_negative & has_disease))
    tn = int(np.count_nonzero(test_negative & healthy))

    # Metrics (NaN instead of ZeroDivisionError for empty denominators)
    nan = float("nan")
    ppv = tp / (tp + fp) if (tp + fp) > 0 else nan
    npv = tn / (tn + fn) if (tn + fn) > 0 else nan
    acc = (tp + tn) / n if n > 0 else nan

    return {
        "n": n, "TP": tp, "FP": fp, "FN": fn, "TN": tn,
        "PPV": ppv, "NPV": npv, "ACC": acc
    }


In [None]:
# --- Plot helpers (cell 4) ---
def plot_prior_posterior(prior_p: float, post_p_pos: float, post_p_neg: float, title: str, outfile: Path):
    """Save a bar chart of the prior vs. the posteriors after a positive/negative test.

    Args:
        prior_p: Prior probability of disease.
        post_p_pos: Posterior probability of disease after a positive test.
        post_p_neg: Posterior probability of disease after a negative test.
        title: Chart title.
        outfile: Destination image path. Missing parent folders are created.

    The figure is always closed, even if saving fails, so repeated calls in a
    long-running notebook session do not accumulate open figures.
    """
    labels = ["Prior", "Posterior (Test +)", "Posterior (Test -)"]
    values = [prior_p, post_p_pos, post_p_neg]

    outfile = Path(outfile)
    outfile.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        ax.bar(labels, values)
        ax.set_ylim(0, 1)  # probabilities: fixed axis keeps charts comparable
        ax.set_ylabel("Probability")
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(outfile, dpi=200)
    finally:
        plt.close(fig)

def plot_fp_fn_counts(tp: int, fp: int, fn: int, tn: int, title: str, outfile: Path):
    """Save a bar chart of the TP, FP, FN and TN counts from a simulated cohort.

    Args:
        tp, fp, fn, tn: Confusion-matrix counts.
        title: Chart title.
        outfile: Destination image path. Missing parent folders are created.

    The figure is always closed, even if saving fails, so repeated calls in a
    long-running notebook session do not accumulate open figures.
    """
    labels = ["TP", "FP", "FN", "TN"]
    values = [tp, fp, fn, tn]

    outfile = Path(outfile)
    outfile.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        ax.bar(labels, values)
        ax.set_ylabel("Count (per cohort)")
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(outfile, dpi=200)
    finally:
        plt.close(fig)


In [None]:
# --- Case 1: Mammogram (cell 5) ---
N = 1000
prev1 = 0.01   # 1%  prior probability of disease, used as the Bayes prior
sens1 = 0.85   # 85% P(test + | disease)
spec1 = 0.90   # 90% P(test - | no disease)

# Theoretical posterior probabilities (Bayes)
post_pos_1, post_neg_1 = posterior_probs_from_test(sens1, spec1, prev1)

# Simulation: one synthetic cohort drawn from the shared seeded generator
res1 = simulate_cohort(N, sens1, spec1, prev1, rng=rng)

# Save figures
plot_prior_posterior(
    prior_p=prev1,
    post_p_pos=post_pos_1,
    post_p_neg=post_neg_1,
    title="Mammogram (40–49): Prior vs Posterior",
    outfile=FIG_DIR / "mammogram_prior_posterior.png"
)

plot_fp_fn_counts(
    tp=res1["TP"], fp=res1["FP"], fn=res1["FN"], tn=res1["TN"],
    # Take n from the simulated cohort so the title stays correct if N changes.
    title=f"Mammogram (40–49): TP/FP/FN/TN (n={res1['n']})",
    outfile=FIG_DIR / "mammogram_confusion_counts.png"
)

# Notebook cell output: simulated results and theoretical posteriors
res1, post_pos_1, post_neg_1


In [None]:
# --- Case 2: COVID-19 Rapid Antigen (cell 6) ---
N = 1000
prev2 = 0.20   # 20% prior probability of disease, used as the Bayes prior
sens2 = 0.80   # 80% P(test + | disease)
spec2 = 0.98   # 98% P(test - | no disease)

# Theoretical posterior probabilities (Bayes)
post_pos_2, post_neg_2 = posterior_probs_from_test(sens2, spec2, prev2)

# Simulation: continues drawing from the same shared seeded generator
res2 = simulate_cohort(N, sens2, spec2, prev2, rng=rng)

# Save figures
plot_prior_posterior(
    prior_p=prev2,
    post_p_pos=post_pos_2,
    post_p_neg=post_neg_2,
    title="COVID-19 Rapid Test: Prior vs Posterior",
    outfile=FIG_DIR / "covid_prior_posterior.png"
)

plot_fp_fn_counts(
    tp=res2["TP"], fp=res2["FP"], fn=res2["FN"], tn=res2["TN"],
    # Take n from the simulated cohort so the title stays correct if N changes.
    title=f"COVID-19 Rapid Test: TP/FP/FN/TN (n={res2['n']})",
    outfile=FIG_DIR / "covid_confusion_counts.png"
)

# Notebook cell output: simulated results and theoretical posteriors
res2, post_pos_2, post_neg_2


In [None]:
# --- nice printouts (cell 7) ---
def pretty_print_case(name, prevalence, sensitivity, specificity, res, post_pos, post_neg):
    """Print a human-readable summary of one diagnostic-test case.

    The summary shows the test inputs, the simulated confusion counts and
    metrics read from ``res`` (keys "n", "TP", "FP", "FN", "TN", "PPV",
    "NPV", "ACC"), and the theoretical Bayes posteriors. Fractions are shown
    as percentages; a NaN metric prints as ``nan%``. The output ends with a
    blank line.
    """
    report = [
        f"=== {name} ===",
        f"Prevalence: {prevalence:.2%} | Sensitivity: {sensitivity:.1%} | Specificity: {specificity:.1%}",
        f"Simulated counts (n={res['n']}): TP={res['TP']}  FP={res['FP']}  FN={res['FN']}  TN={res['TN']}",
        f"PPV ~ {res['PPV']:.2%} | NPV ~ {res['NPV']:.2%} | Accuracy ~ {res['ACC']:.2%}",
        f"Bayes posterior P(disease | Test +): {post_pos:.2%}",
        f"Bayes posterior P(disease | Test -): {post_neg:.2%}",
    ]
    # One line per entry, then the trailing blank separator line.
    print(*report, sep="\n", end="\n\n")

# Summarise both worked examples with the same formatter, in order.
for _case in (
    ("Case 1 — Mammogram (40–49)", prev1, sens1, spec1, res1, post_pos_1, post_neg_1),
    ("Case 2 — COVID-19 Rapid Antigen", prev2, sens2, spec2, res2, post_pos_2, post_neg_2),
):
    pretty_print_case(*_case)

print("Figures saved to:", FIG_DIR.resolve())
