### Run this code to generate confusion matrices for each dentist and each class, against the generated ground truth

In [None]:
# CSV holding the inferred ground truth that every dentist is scored against.
ground_truth_loc = "F:/Github Repos/study/study_code/processed_data/task_1_inferred_ground_truth.csv"

# CSV holding the dentists' annotations; OUT_DIR is created next to this file.
dentists_loc = "F:/Github Repos/study/study_code/processed_data/master_copy_task_1_users_v1_with_manual_corrections.csv"

import os
import ast
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix



# One sub-folder per annotator is written below this directory.
OUT_DIR = os.path.join(os.path.dirname(dentists_loc), "confusion_matrices_by_annotator")
os.makedirs(OUT_DIR, exist_ok=True)


# Per-tooth annotation fields that get a confusion matrix each.
KEYS = ["mesial", "distal", "mesial_pls", "distal_pls", "furcation", "arr_left", "arr_right"]

# Key families; should_count_pair applies a different missing-value rule to each.
MESIAL_DISTAL_KEYS = {"mesial", "distal"}
PLS_KEYS = {"mesial_pls", "distal_pls"}
ARR_KEYS = {"arr_left", "arr_right"}
FURCATION_KEYS = {"furcation"}

# Label order used for the rows/columns of the confusion matrices.
MESIAL_DISTAL_CLASSES = ["Healthy", "Mild", "Moderate", "Severe"]
BINARY_CLASSES = [False, True]


def to_binary_md(v):
    """
    Collapse a mesial/distal grade into a binary flag.

    Returns False only for the exact label "Healthy"; every other value
    (including None and differently-cased spellings) maps to True.
    """
    return not (v == "Healthy")

def safe_literal_eval(x):
    """
    Best-effort conversion of a CSV cell into a Python literal.

    Returns:
        - ``x`` unchanged when it is already a dict or list;
        - ``None`` for missing values (None/NaN), blank strings, and the
          tokens "none"/"nan" (case-insensitive);
        - the result of ``ast.literal_eval`` on the stripped string form;
        - ``None`` when the text is not a valid Python literal.
    """
    # Containers must be recognised before pd.isna: pd.isna(list) returns an
    # element-wise array, and using that array in `if` raises for any list
    # with more than one element.
    if isinstance(x, (dict, list)):
        return x
    if pd.isna(x):
        return None

    text = str(x).strip()
    if text == "" or text.lower() in {"none", "nan"}:
        return None

    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # These are the failure modes literal_eval documents for malformed
        # input; such cells are treated as missing rather than aborting.
        return None

def normalise_none_token(v):
    """
    Map the various spellings of "missing" onto ``None``.

    Treated as missing: ``None`` itself, a float NaN, and strings that are
    empty or equal to "none"/"nan" once stripped and lower-cased.
    Any other value is returned unchanged.
    """
    nan_float = isinstance(v, float) and np.isnan(v)
    blank_text = isinstance(v, str) and v.strip().lower() in {"none", "nan", ""}
    return None if (v is None or nan_float or blank_text) else v

def normalise_boolish(v):
    """
    Coerce boolean-like values to ``bool`` where the intent is unambiguous.

    Missing values (as defined by normalise_none_token) become ``None``.
    Real bools pass through, the integers 0/1 are converted, and a fixed set
    of textual spellings is recognised case-insensitively. Anything else is
    returned as-is so that unexpected labels remain visible downstream.
    """
    value = normalise_none_token(v)
    if value is None or isinstance(value, bool):
        return value

    if isinstance(value, (int, np.integer)) and value in (0, 1):
        return bool(value)

    if isinstance(value, str):
        token = value.strip().lower()
        truthy = {"true", "t", "1", "yes", "y", "present", "pos", "positive"}
        falsy = {"false", "f", "0", "no", "n", "absent", "neg", "negative"}
        if token in truthy:
            return True
        if token in falsy:
            return False

    return value

def drop_conf_keys(d):
    """
    Return a copy of dict ``d`` without entries whose key contains "_conf".

    Keys are compared via their ``str()`` form. Non-dict inputs are returned
    unchanged.
    """
    if isinstance(d, dict):
        kept = {}
        for key, value in d.items():
            if "_conf" in str(key):
                continue
            kept[key] = value
        return kept
    return d

def extract_tooth_entries(obj):
    """
    Flatten one annotation cell into a list of ``(tooth_num, payload)`` pairs.

    ``obj`` is first parsed with safe_literal_eval. Accepted layouts:
      * a dict with a "tooth_annotations" entry (that entry is unwrapped first);
      * a list of per-tooth dicts, each carrying "tooth_num";
      * a single per-tooth dict carrying "tooth_num";
      * a dict keyed by tooth number whose values are per-tooth dicts.

    Keys containing "_conf" are removed from every payload, digit-only string
    tooth numbers are converted to int, and entries without a usable tooth
    number (or that are not dicts) are dropped. Anything else yields [].
    """

    def _as_tooth_num(raw):
        # Digit-only strings become ints; other values are left untouched.
        if isinstance(raw, str) and raw.strip().isdigit():
            return int(raw.strip())
        return raw

    parsed = safe_literal_eval(obj)
    if parsed is None:
        return []

    if isinstance(parsed, dict) and "tooth_annotations" in parsed:
        parsed = parsed["tooth_annotations"]

    if isinstance(parsed, list):
        entries = []
        for record in parsed:
            if not isinstance(record, dict):
                continue
            record = drop_conf_keys(record)
            tooth = normalise_none_token(record.get("tooth_num", None))
            if tooth is not None:
                entries.append((_as_tooth_num(tooth), record))
        return entries

    if not isinstance(parsed, dict):
        return []

    if "tooth_num" in parsed:
        payload = drop_conf_keys(parsed)
        tooth = normalise_none_token(payload.get("tooth_num", None))
        return [] if tooth is None else [(_as_tooth_num(tooth), payload)]

    # Remaining layout: {tooth_num: per-tooth dict, ...}
    return [
        (_as_tooth_num(key), drop_conf_keys(value))
        for key, value in parsed.items()
        if isinstance(value, dict)
    ]


def should_count_pair(gt_val, dent_val, key):
    """
    Decide whether a (ground truth, dentist) pair is counted for ``key``.

    Returns ``(ok, gt, dent)``. When ``ok`` is False the other two items are
    None; otherwise they are the values to feed into the confusion matrix.

    Rules per key family:
      * mesial/distal: always counted; a missing side becomes "Healthy";
      * furcation: skipped when the ground truth is missing; a missing
        dentist value becomes False;
      * ARR / PLS: always counted; a missing side becomes False;
      * any other key: counted only when both sides are present.
    """
    gt = normalise_none_token(gt_val)
    dent = normalise_none_token(dent_val)

    if key in MESIAL_DISTAL_KEYS:
        return (
            True,
            "Healthy" if gt is None else gt,
            "Healthy" if dent is None else dent,
        )

    if key in FURCATION_KEYS:
        gt_flag = None if gt is None else normalise_boolish(gt)
        if gt_flag is None:
            return False, None, None
        dent_flag = normalise_boolish(dent)
        return True, gt_flag, (False if dent_flag is None else dent_flag)

    if key in ARR_KEYS or key in PLS_KEYS:
        gt_flag = normalise_boolish(gt)
        dent_flag = normalise_boolish(dent)
        return (
            True,
            False if gt_flag is None else gt_flag,
            False if dent_flag is None else dent_flag,
        )

    if gt is None or dent is None:
        return False, None, None
    return True, gt, dent


def make_confusion_df(y_true, y_pred, labels):
    """
    Build a labelled confusion-matrix DataFrame.

    The counts come from sklearn's ``confusion_matrix`` restricted to
    ``labels``; rows are named ``GT_<label>`` and columns ``PRED_<label>``.
    """
    row_names = [f"GT_{label}" for label in labels]
    col_names = [f"PRED_{label}" for label in labels]
    counts = confusion_matrix(y_true, y_pred, labels=labels)
    return pd.DataFrame(counts, index=row_names, columns=col_names)


# Load both inputs, then fail fast if a required column is absent.
gt_raw = pd.read_csv(ground_truth_loc)
dent_raw = pd.read_csv(dentists_loc)

# Ground truth: needs the annotation payload and an image identifier
# ("image_name" is preferred over "image").
if "tooth_annotations" not in gt_raw.columns:
    raise ValueError("Ground truth CSV must contain 'tooth_annotations'.")
gt_image_col = next(
    (name for name in ("image_name", "image") if name in gt_raw.columns),
    None,
)
if gt_image_col is None:
    raise ValueError("Ground truth CSV must contain 'image_name' or 'image'.")

# Dentists: needs the annotator id, an image identifier ("image" is preferred
# over "image_name") and the annotation payload.
if "user_id" not in dent_raw.columns:
    raise ValueError("Dentists CSV must contain 'user_id'.")
dent_image_col = next(
    (name for name in ("image", "image_name") if name in dent_raw.columns),
    None,
)
if dent_image_col is None:
    raise ValueError("Dentists CSV must contain 'image' or 'image_name'.")
if "annotations" not in dent_raw.columns:
    raise ValueError("Dentists CSV must contain 'annotations'.")


# Ground truth in long form: one row per (image, tooth) with one column per key.
gt_rows = [
    {
        "image_name": record[gt_image_col],
        "tooth_num": tooth,
        **{k: normalise_none_token(entry.get(k, None)) for k in KEYS},
    }
    for _, record in gt_raw.iterrows()
    for tooth, entry in extract_tooth_entries(record["tooth_annotations"])
]

gt_long = pd.DataFrame(gt_rows)
if gt_long.empty:
    raise ValueError("Parsed GT is empty. Check 'tooth_annotations' format.")
# A tooth may appear once per image; the first occurrence wins.
gt_long = gt_long.drop_duplicates(subset=["image_name", "tooth_num"], keep="first")


# Dentist annotations in long form: one row per (user, image, tooth).
dent_rows = [
    {
        "user_id": record["user_id"],
        "image_name": record[dent_image_col],
        "tooth_num": tooth,
        **{k: normalise_none_token(entry.get(k, None)) for k in KEYS},
    }
    for _, record in dent_raw.iterrows()
    for tooth, entry in extract_tooth_entries(record["annotations"])
]

dent_long = pd.DataFrame(dent_rows)
if dent_long.empty:
    raise ValueError("Parsed dentists is empty. Check 'annotations' format.")


# Score every annotator against the ground truth: one confusion matrix per
# key (plus a binarised one for mesial/distal), written under OUT_DIR/user_<id>.
all_users = sorted(dent_long["user_id"].dropna().unique())
summary_rows = []

# Note written when a (user, key) combination yields nothing to count.
# It restates the rules implemented by should_count_pair.
no_counts_note = (
    "No valid pairs to count under rules.\n"
    "Mesial/Distal: None treated as Healthy.\n"
    "ARR+PLS: None treated as False (not skipped).\n"
    "Furcation: skipped when ground truth is None; dentist None treated as False.\n"
    "Keys containing '_conf' are ignored.\n"
)

for user in all_users:
    user_dent = dent_long[dent_long["user_id"] == user].copy()

    # Left join keeps every ground-truth tooth; teeth this user did not
    # annotate come through as NaN and are resolved by should_count_pair.
    user_df = gt_long.merge(
        user_dent,
        on=["image_name", "tooth_num"],
        suffixes=("_gt", "_dent"),
        how="left",
    )

    user_out_dir = os.path.join(OUT_DIR, f"user_{user}")
    os.makedirs(user_out_dir, exist_ok=True)

    for key in KEYS:
        y_true, y_pred = [], []
        pairs = zip(user_df[f"{key}_gt"].tolist(), user_df[f"{key}_dent"].tolist())
        for gt_val, dent_val in pairs:
            ok, gt_v, dent_v = should_count_pair(gt_val, dent_val, key)
            if ok:
                y_true.append(gt_v)
                y_pred.append(dent_v)

        n_used = len(y_true)
        if n_used == 0:
            with open(os.path.join(user_out_dir, f"{key}_NO_COUNTS.txt"), "w", encoding="utf-8") as f:
                f.write(no_counts_note)
            summary_rows.append({"user_id": user, "key": key, "n_used": 0})
            continue

        # Start from the canonical label order for the key family and append
        # any unexpected labels that actually occur. Sorting by str() keeps
        # this safe when the unexpected labels are of mixed types.
        if key in MESIAL_DISTAL_KEYS:
            base_labels = MESIAL_DISTAL_CLASSES
        elif key in ARR_KEYS or key in PLS_KEYS or key in FURCATION_KEYS:
            base_labels = BINARY_CLASSES
        else:
            base_labels = []
        observed = set(y_true) | set(y_pred)
        labels = list(base_labels) + sorted(observed - set(base_labels), key=str)

        cm_df = make_confusion_df(y_true, y_pred, labels=labels)
        cm_df.to_csv(os.path.join(user_out_dir, f"{key}_confusion_matrix.csv"), index=True)

        if key in MESIAL_DISTAL_KEYS:
            # Healthy-vs-anything-else view of the same pairs.
            y_true_bin = [to_binary_md(v) for v in y_true]
            y_pred_bin = [to_binary_md(v) for v in y_pred]

            bin_labels = BINARY_CLASSES

            cm_bin_df = make_confusion_df(
                y_true_bin,
                y_pred_bin,
                labels=bin_labels,
            )

            cm_bin_df.to_csv(
                os.path.join(user_out_dir, f"{key}_binary_confusion_matrix.csv"),
                index=True,
            )

        summary_rows.append({
            "user_id": user,
            "key": key,
            "n_used": n_used,
            "labels": "|".join(map(str, labels)),
        })

summary_df = pd.DataFrame(summary_rows).sort_values(["user_id", "key"])
summary_df.to_csv(os.path.join(OUT_DIR, "SUMMARY_counts_and_labels.csv"), index=False)

print(f"Done. Wrote per-user confusion matrices to: {OUT_DIR}")


Done. Wrote per-user confusion matrices to: F:/Github Repos/study/study_code/processed_data\confusion_matrices_by_annotator


### run this code to evaluate dentists

In [None]:
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Tuple, List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt



def _normalise_label(s: str) -> str:
    s = str(s)
    if s.startswith("GT_"):
        return s[3:]
    if s.startswith("PRED_"):
        return s[5:]
    return s


def load_confusion_csv(csv_path: Path) -> pd.DataFrame:
    """
    Read one confusion-matrix CSV into an integer DataFrame.

    The first column is taken as the ground-truth labels (often named
    'Unnamed: 0') and becomes the index ``GT_LABEL``; the remaining columns
    are the predicted labels. ``GT_`` / ``PRED_`` prefixes are stripped from
    both axes.

    Raises:
        ValueError: if a cell is not numeric or any count is negative.
    """
    raw = pd.read_csv(csv_path)

    first_col = raw.columns[0]
    table = raw.rename(columns={first_col: "GT_LABEL"})
    table["GT_LABEL"] = table["GT_LABEL"].map(_normalise_label)

    table = table.rename(
        columns={c: _normalise_label(c) for c in table.columns if c != "GT_LABEL"}
    )

    table = table.set_index("GT_LABEL")
    table = table.apply(pd.to_numeric, errors="raise").astype(int)

    if (table.values < 0).any():
        raise ValueError(f"Negative counts found in {csv_path}")
    return table



@dataclass
class BinaryCounts:
    """The four cell counts of a 2x2 confusion matrix (ground truth vs. prediction)."""
    tn: int  # ground truth negative, predicted negative
    fp: int  # ground truth negative, predicted positive
    fn: int  # ground truth positive, predicted negative
    tp: int  # ground truth positive, predicted positive


def _safe_div(num: float, den: float) -> float:
    return float(num) / float(den) if den != 0 else np.nan


def binary_counts_from_cm(cm: pd.DataFrame,
                          negative_label: str = "False",
                          positive_label: str = "True") -> BinaryCounts:
    """
    Read TN/FP/FN/TP out of a confusion matrix with GT rows and PRED columns.

    Labels are matched on their ``str()`` form: first exactly, then
    case-insensitively.

    Raises:
        KeyError: if either label cannot be found on the rows or the columns.
    """

    def _resolve(axis_labels, wanted: str):
        # Map the textual form of each label back to the actual label object.
        by_text = {str(label): label for label in axis_labels}
        if wanted in by_text:
            return by_text[wanted]
        wanted_lower = wanted.lower()
        for text, label in by_text.items():
            if text.lower() == wanted_lower:
                return label
        raise KeyError(f"Could not find label '{wanted}' in {list(by_text.keys())}")

    neg_row = _resolve(cm.index, negative_label)
    pos_row = _resolve(cm.index, positive_label)
    neg_col = _resolve(cm.columns, negative_label)
    pos_col = _resolve(cm.columns, positive_label)

    return BinaryCounts(
        tn=int(cm.loc[neg_row, neg_col]),
        fp=int(cm.loc[neg_row, pos_col]),
        fn=int(cm.loc[pos_row, neg_col]),
        tp=int(cm.loc[pos_row, pos_col]),
    )


def metrics_from_binary_counts(b: BinaryCounts) -> Dict[str, float]:
    """
    Derive the standard binary classification metrics from TN/FP/FN/TP.

    Every ratio goes through ``_safe_div``, so a zero denominator yields NaN.
    Balanced accuracy is the NaN-ignoring mean of sensitivity and specificity.
    The key order of the returned dict determines the column order of the
    CSVs written from it.
    """
    tp, tn, fp, fn = b.tp, b.tn, b.fp, b.fn
    total = tn + fp + fn + tp

    sensitivity = _safe_div(tp, tp + fn)
    specificity = _safe_div(tn, tn + fp)
    accuracy = _safe_div(tp + tn, total)

    # Matthews correlation coefficient: (TP*TN - FP*FN) / sqrt(product of marginals)
    mcc_denominator = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))

    return {
        "n": float(total),
        "tp": float(tp), "tn": float(tn), "fp": float(fp), "fn": float(fn),
        "sensitivity": sensitivity,
        "specificity": specificity,
        "precision_ppv": _safe_div(tp, tp + fp),
        "npv": _safe_div(tn, tn + fn),
        "f1": _safe_div(2 * tp, 2 * tp + fp + fn),
        "accuracy": accuracy,
        "balanced_accuracy": np.nanmean([sensitivity, specificity]),
        "mcc": _safe_div((tp * tn) - (fp * fn), mcc_denominator),

        # requested rates
        "false_negative_rate": 1.0 - sensitivity,
        "false_positive_rate": 1.0 - specificity,

        # kept for reference (but no longer used for plotting)
        "misdiagnosis_rate": 1.0 - accuracy,
    }


def metrics_multiclass(cm: pd.DataFrame) -> Dict[str, float]:
    """
    Macro-averaged one-vs-rest metrics for a multiclass confusion matrix.

    The diagonal is taken as the true positives, so rows and columns are
    expected to list the same classes in the same order.
    Per-class precision, recall, F1 and specificity are averaged while
    ignoring NaNs. Additionally:
      false_negative_rate = 1 - macro_recall
      false_positive_rate = 1 - macro_specificity
    """
    counts = cm.values.astype(float)
    total = counts.sum()

    tp = np.diag(counts)
    predicted = counts.sum(axis=0)   # column sums: how often each class was predicted
    actual = counts.sum(axis=1)      # row sums: how often each class is the ground truth

    fp = predicted - tp
    fn = actual - tp
    tn = total - tp - fp - fn

    precision = np.array([_safe_div(t, p) for t, p in zip(tp, predicted)], dtype=float)
    recall = np.array([_safe_div(t, a) for t, a in zip(tp, actual)], dtype=float)
    f1 = np.array(
        [_safe_div(2 * p * r, p + r) for p, r in zip(precision, recall)],
        dtype=float,
    )
    specificity = np.array(
        [_safe_div(neg, neg + false_pos) for neg, false_pos in zip(tn, fp)],
        dtype=float,
    )

    macro_recall = np.nanmean(recall)
    macro_specificity = np.nanmean(specificity)

    return {
        "n": float(total),
        "accuracy": _safe_div(np.trace(counts), total),
        "macro_precision": np.nanmean(precision),
        "macro_recall": macro_recall,
        "macro_f1": np.nanmean(f1),
        "balanced_accuracy": macro_recall,  # mean recall across classes

        # include specificity as a metric (macro one-vs-rest)
        "macro_specificity": macro_specificity,

        # requested rates for heatmaps
        "false_negative_rate": 1.0 - macro_recall,
        "false_positive_rate": 1.0 - macro_specificity,
    }


def bootstrap_binary_cis(b: BinaryCounts, n_boot: int = 2000, seed: int = 0) -> Dict[str, Tuple[float, float]]:
    """
    Percentile bootstrap confidence intervals for the binary metrics.

    The four cells (TN, FP, FN, TP) are resampled ``n_boot`` times from a
    multinomial with the observed cell proportions; each metric's interval is
    the 2.5th-97.5th percentile of its resampled values, ignoring NaNs.
    Returns an empty dict when the matrix holds no observations.
    """
    tracked = [
        "false_negative_rate",
        "false_positive_rate",
        "sensitivity",
        "specificity",
        "balanced_accuracy",
        "f1",
        "mcc",
    ]

    rng = np.random.default_rng(seed)
    cells = np.array([b.tn, b.fp, b.fn, b.tp], dtype=int)
    total = cells.sum()
    if total == 0:
        return {}

    probs = cells / total
    draws = []
    for _ in range(n_boot):
        tn, fp, fn, tp = (int(c) for c in rng.multinomial(total, probs))
        resampled = metrics_from_binary_counts(BinaryCounts(tn=tn, fp=fp, fn=fn, tp=tp))
        draws.append([resampled[name] for name in tracked])

    samples = np.asarray(draws, dtype=float)
    intervals = {}
    for column, name in enumerate(tracked):
        low, high = np.nanpercentile(samples[:, column], [2.5, 97.5])
        intervals[name] = (float(low), float(high))
    return intervals



def find_annotator_folders(root: Path) -> List[Path]:
    """List the immediate sub-directories of ``root`` whose name does not start with a dot."""
    folders = []
    for entry in root.iterdir():
        if entry.name.startswith(".") or not entry.is_dir():
            continue
        folders.append(entry)
    return folders


def collect_confusion_files(folder: Path, pattern: str = "*confusion_matrix.csv") -> List[Path]:
    """Recursively collect the files under ``folder`` matching ``pattern``, in sorted path order."""
    matches = list(folder.rglob(pattern))
    matches.sort()
    return matches


def infer_task_name(csv_path: Path) -> str:
    """
    Derive the task name from a confusion-matrix file name.

    Takes the file stem, removes a trailing ``_confusion_matrix`` (or, failing
    that, ``confusion_matrix``) and strips leading/trailing underscores.
    """
    stem = csv_path.stem
    if stem.endswith("_confusion_matrix"):
        stem = stem[: -len("_confusion_matrix")]
    elif stem.endswith("confusion_matrix"):
        stem = stem[: -len("confusion_matrix")]
    return stem.strip("_")


def aggregate_confusions(confusions: List[pd.DataFrame]) -> pd.DataFrame:
    """
    Add up confusion matrices whose label sets may differ.

    The result spans the sorted union of all row labels and the sorted union
    of all column labels; cells absent from an input contribute zero.
    An empty input list yields an empty DataFrame.
    """
    if not confusions:
        return pd.DataFrame()

    row_labels, col_labels = set(), set()
    for matrix in confusions:
        row_labels.update(matrix.index)
        col_labels.update(matrix.columns)

    total = pd.DataFrame(
        0,
        index=sorted(row_labels),
        columns=sorted(col_labels),
        dtype=int,
    )
    for matrix in confusions:
        total.loc[matrix.index, matrix.columns] += matrix.astype(int)
    return total



def plot_metric_heatmap(per_group_dentist: pd.DataFrame,
                        metric: str,
                        out_path: Path,
                        title: str,
                        group_order: List[str],
                        group_labels: List[str],
                        value_fmt: str = "{:.3f}") -> None:
    """
    Save a heatmap of one metric with per-cell numeric annotations.

    Rows: dentists (values of ``dentist_id``, sorted)
    Cols: groups, in the order given by ``group_order``

    Args:
        per_group_dentist: table with ``dentist_id``, ``group`` and ``metric`` columns.
        metric: name of the column to plot.
        out_path: image file to write (saved at 200 dpi).
        title: figure title.
        group_order: groups to include, in column order.
        group_labels: tick labels for the columns, parallel to ``group_order``.
        value_fmt: format string for the cell annotations; NaN cells show "NA".
    """
    df = per_group_dentist[per_group_dentist["group"].isin(group_order)]
    # Pivot on plain labels rather than a Categorical: grouping on a
    # Categorical triggers pandas' `observed` FutureWarning, and the column
    # order is fixed by the reindex below anyway.
    df = df.assign(group=df["group"].astype(object))

    pivot = df.pivot_table(index="dentist_id", columns="group", values=metric, aggfunc="first")
    pivot = pivot.reindex(columns=group_order).sort_index()

    data = pivot.values.astype(float)

    fig = plt.figure(figsize=(10, 6), constrained_layout=True)
    ax = fig.add_subplot(111)

    im = ax.imshow(data, aspect="auto")

    ax.set_yticks(np.arange(pivot.shape[0]))
    ax.set_yticklabels(pivot.index)

    ax.set_xticks(np.arange(pivot.shape[1]))
    ax.set_xticklabels(group_labels, rotation=0)

    ax.set_xlabel("Task", labelpad=10)
    ax.set_ylabel("Dentist", labelpad=10)
    ax.set_title(title)

    fig.colorbar(im, ax=ax)

    # Write the numeric value into every cell.
    n_rows, n_cols = pivot.shape
    for i in range(n_rows):
        for j in range(n_cols):
            v = data[i, j]
            txt = "NA" if np.isnan(v) else value_fmt.format(v)
            ax.text(j, i, txt, ha="center", va="center")

    fig.savefig(out_path, dpi=200)
    plt.close(fig)



# Input: the per-annotator confusion-matrix folders; output: metrics CSVs and plots.
root = Path("F:/Github Repos/study/study_code/processed_data/confusion_matrices_by_annotator")
out_dir = Path("F:/Github Repos/study/study_code/processed_data/results_dentists")
# Bootstrap settings passed to bootstrap_binary_cis.
n_boot = 2000
seed = 0

out_dir.mkdir(parents=True, exist_ok=True)
plots_dir = out_dir / "plots"
plots_dir.mkdir(parents=True, exist_ok=True)

# Evaluation groups: each lists the task names (as produced by infer_task_name)
# whose confusion matrices are summed per dentist, and how the result is scored.
GROUPS = {
    "mesial":    {"tasks": ["mesial_binary"],                    "type": "binary"},
    "distal":    {"tasks": ["distal_binary"],                    "type": "binary"},
    "PBL":       {"tasks": ["mesial_binary", "distal_binary"],   "type": "binary"},  # mesial+distal combined (binary)
    "ARR":       {"tasks": ["arr_left", "arr_right"],            "type": "binary"},
    "Furcation": {"tasks": ["furcation"],                        "type": "binary"},
    "PLS":       {"tasks": ["mesial_pls", "distal_pls"],         "type": "binary"},
}

# Anonymise annotators: folders sorted by name become D1, D2, ...
annotator_folders = sorted(find_annotator_folders(root), key=lambda p: p.name)
if not annotator_folders:
    raise RuntimeError(f"No annotator subfolders found under: {root}")

dentist_map = {
    folder.name: f"D{rank}"
    for rank, folder in enumerate(annotator_folders, start=1)
}

# Read every confusion matrix once; the parsed matrix travels with its metadata.
all_rows = [
    {
        "dentist_raw": folder.name,
        "dentist_id": dentist_map[folder.name],
        "task": infer_task_name(csv_path),
        "file": str(csv_path),
        "cm": load_confusion_csv(csv_path),
    }
    for folder in annotator_folders
    for csv_path in collect_confusion_files(folder)
]

all_df = pd.DataFrame(all_rows)
if all_df.empty:
    raise RuntimeError(f"No confusion_matrix.csv files found under: {root}")

# Per-group outputs
# per_group_file_rows: one metrics row per (group, confusion-matrix file).
# per_group_dentist_rows: one metrics row per (group, dentist), computed on the
# sum of that dentist's matrices for the group's tasks.
per_group_file_rows = []
per_group_dentist_rows = []

for group_name, spec in GROUPS.items():
    tasks = set(spec["tasks"])
    group_type = spec["type"]

    group_out = out_dir / group_name
    group_out.mkdir(parents=True, exist_ok=True)

    # Only the files whose inferred task belongs to this group.
    gdf = all_df[all_df["task"].isin(tasks)].copy()

    # --- per-file metrics ---
    for _, r in gdf.iterrows():
        cm = r["cm"]

        if group_type == "binary":
            # A binary task must be exactly 2x2; extra labels indicate unexpected values upstream.
            if cm.shape != (2, 2):
                raise ValueError(f"[{group_name}] Expected binary 2x2 but got {cm.shape} in {r['file']}")
            b = binary_counts_from_cm(cm, "False", "True")
            m = metrics_from_binary_counts(b)
        else:
            m = metrics_multiclass(cm)

        per_group_file_rows.append({
            "group": group_name,
            "dentist_raw": r["dentist_raw"],
            "dentist_id": r["dentist_id"],
            "task": r["task"],
            "file": r["file"],
            **m
        })

    # --- per-dentist metrics on the aggregated matrix ---
    for dentist_raw, sub in gdf.groupby("dentist_raw"):
        dentist_id = dentist_map[dentist_raw]
        cms = list(sub["cm"].values)
        agg = aggregate_confusions(cms)

        if agg.empty:
            continue

        if group_type == "binary":
            if agg.shape != (2, 2):
                raise ValueError(f"[{group_name}] Expected aggregated binary 2x2 but got {agg.shape} for {dentist_raw}")
            b = binary_counts_from_cm(agg, "False", "True")
            m = metrics_from_binary_counts(b)
            # Confidence intervals are only computed for binary groups; the same seed is used for every dentist.
            cis = bootstrap_binary_cis(b, n_boot=n_boot, seed=seed)

            row = {"group": group_name, "dentist_raw": dentist_raw, "dentist_id": dentist_id, **m}
            for k, (lo, hi) in cis.items():
                row[f"{k}_ci_low"] = lo
                row[f"{k}_ci_high"] = hi
            per_group_dentist_rows.append(row)
        else:
            m = metrics_multiclass(agg)
            per_group_dentist_rows.append({"group": group_name, "dentist_raw": dentist_raw, "dentist_id": dentist_id, **m})

    # Save group CSVs
    per_file_g = pd.DataFrame([r for r in per_group_file_rows if r["group"] == group_name])
    per_dent_g = pd.DataFrame([r for r in per_group_dentist_rows if r["group"] == group_name])

    per_file_g.to_csv(group_out / f"per_file_metrics_{group_name}.csv", index=False)
    per_dent_g.to_csv(group_out / f"per_dentist_metrics_{group_name}.csv", index=False)

# Combined tables across all groups; per_dentist is reused for the heatmaps below.
per_file = pd.DataFrame(per_group_file_rows)
per_dentist = pd.DataFrame(per_group_dentist_rows)
per_file.to_csv(out_dir / "per_file_metrics_ALLGROUPS.csv", index=False)
per_dentist.to_csv(out_dir / "per_dentist_metrics_ALLGROUPS.csv", index=False)


# Heatmaps cover the combined groups only (mesial/distal are shown jointly as PBL).
plot_groups = ["PBL", "ARR", "PLS", "Furcation"]
plot_labels = ["PBL (mesial+distal, binary)", "ARR", "PLS", "Furcation"]

plot_df = per_dentist[per_dentist["group"].isin(plot_groups)].copy()

# (metric column, output file name, figure title)
heatmap_specs = (
    (
        "false_negative_rate",
        "heatmap_false_negative_rate.png",
        "False Negative Rate (1 - recall)",
    ),
    (
        "false_positive_rate",
        "heatmap_false_positive_rate.png",
        "False Positive Rate (1 - specificity)",
    ),
)
for metric_name, file_name, heading in heatmap_specs:
    plot_metric_heatmap(
        per_group_dentist=plot_df,
        metric=metric_name,
        out_path=plots_dir / file_name,
        title=heading,
        group_order=plot_groups,
        group_labels=plot_labels,
    )

print(
    "Done. Outputs written to:",
    f"  {out_dir}",
    "Plots written to:",
    f"  {plots_dir}",
    sep="\n",
)


  pivot = df.pivot_table(index="dentist_id", columns="group", values=metric, aggfunc="first")
  pivot = df.pivot_table(index="dentist_id", columns="group", values=metric, aggfunc="first")


Done. Outputs written to:
  F:\Github Repos\study\study_code\processed_data\results_dentists
Plots written to:
  F:\Github Repos\study\study_code\processed_data\results_dentists\plots


### generate latex tables

In [None]:


from pathlib import Path
import re
import numpy as np
import pandas as pd


# Folder holding the evaluation results and the combined per-dentist CSV read below.
out_dir = Path(r"F:/Github Repos/study/study_code/processed_data/results_dentists")
results_path = out_dir / "per_dentist_metrics_ALLGROUPS.csv"

# Order in which the evaluation groups appear in the table, and their display names.
GROUP_ORDER = ["mesial", "distal", "PBL", "ARR", "PLS", "Furcation"]
GROUP_LABELS = {
    "mesial": "Mesial (binary)",
    "distal": "Distal (binary)",
    "PBL": "PBL (mesial+distal, binary)",
    "ARR": "ARR",
    "PLS": "PLS",
    "Furcation": "Furcation",
}

# (column in the results table, column header in the LaTeX table)
METRICS = [
    ("macro_precision", "Precision"),
    ("macro_recall", "Recall"),
    ("macro_f1", "F1"),
    ("macro_specificity", "Specificity"),
    ("false_negative_rate", "FNR"),
    ("false_positive_rate", "FPR"),
]

# Inserted into \caption{} verbatim (it contains LaTeX markup, so it is not escaped).
CAPTION = (
    "Observer-wise results across evaluation types. Metrics are computed from aggregated confusion matrices per observer. "
    "For binary tasks, macro metrics are set to their binary equivalents. Row \\textbf{A} reports the mean across observers "
    "within each evaluation type."
)
LABEL = "tab:observer_eval_single"


def _fmt(x: float, ndp: int = 3) -> str:
    if x is None or (isinstance(x, float) and np.isnan(x)):
        return "NA"
    return f"{float(x):.{ndp}f}"

def _latex_escape(s: str) -> str:
    s = str(s)
    return (s.replace("\\", r"\textbackslash{}")
             .replace("&", r"\&")
             .replace("%", r"\%")
             .replace("$", r"\$")
             .replace("#", r"\#")
             .replace("_", r"\_")
             .replace("{", r"\{")
             .replace("}", r"\}")
             .replace("~", r"\textasciitilde{}")
             .replace("^", r"\textasciicircum{}"))

def _observer_sort_key(s: str) -> int:
    m = re.match(r"D(\d+)$", str(s))
    return int(m.group(1)) if m else 10**9

def _is_binary_group(g: str) -> bool:
    return g in {"mesial", "distal", "PBL", "ARR", "PLS", "Furcation"}

def _autofind_results_csv(root: Path) -> Path:
    cands = list(root.rglob("per_dentist_metrics_ALLGROUPS.csv"))
    if cands:
        return cands[0]
    cands = list(root.rglob("*per_dentist_metrics*ALLGROUPS*.csv"))
    if cands:
        return cands[0]
    raise FileNotFoundError(
        f"Could not find per_dentist_metrics_ALLGROUPS.csv under: {root}\n"
        "Set `results_path` to the correct CSV."
    )

def _ensure_macro_columns_for_binary(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` in which the macro_* columns exist for every row.

    Missing macro_* columns are created as NaN. For rows belonging to a
    binary group, empty macro cells are filled from the binary equivalent:
      macro_precision    <- precision_ppv
      macro_recall       <- sensitivity
      macro_f1           <- f1
      macro_specificity  <- specificity

    If the rate columns are absent they are derived:
      false_negative_rate <- 1 - macro_recall
      false_positive_rate <- 1 - macro_specificity
    """
    macro_from_binary = (
        ("macro_precision", "precision_ppv"),
        ("macro_recall", "sensitivity"),
        ("macro_f1", "f1"),
        ("macro_specificity", "specificity"),
    )
    rate_from_macro = (
        ("false_negative_rate", "macro_recall"),
        ("false_positive_rate", "macro_specificity"),
    )

    out = df.copy()
    binary_rows = out["group"].apply(_is_binary_group)

    # Create every macro column first so the fill step can rely on them.
    for target, _source in macro_from_binary:
        if target not in out.columns:
            out[target] = np.nan

    for target, source in macro_from_binary:
        if source not in out.columns:
            continue
        todo = binary_rows & out[target].isna()
        out.loc[todo, target] = pd.to_numeric(out.loc[todo, source], errors="coerce")

    for rate, base in rate_from_macro:
        if rate not in out.columns:
            out[rate] = 1.0 - pd.to_numeric(out[base], errors="coerce")

    return out


# Fall back to searching the results folder when the expected file is absent.
if not results_path.exists():
    results_path = _autofind_results_csv(out_dir)

df = pd.read_csv(results_path)

required = {"dentist_id", "group"}
missing = required - set(df.columns)
if missing:
    raise ValueError(f"Missing required columns in results CSV: {missing}")

# Keep only the groups that appear in the table, as an ordered categorical.
df = df[df["group"].isin(GROUP_ORDER)].copy()
df["group"] = pd.Categorical(df["group"], categories=GROUP_ORDER, ordered=True)

df = _ensure_macro_columns_for_binary(df)

keep_cols = ["dentist_id", "group"] + [c for c, _ in METRICS]
df = df[keep_cols].copy()
df["dentist_id"] = df["dentist_id"].astype(str)

# Table rows per group: the observers in D1, D2, ... order, then the average row "A".
observers = sorted(df["dentist_id"].unique().tolist(), key=_observer_sort_key)
row_index = observers + ["A"]

# One synthetic observer "A" per non-empty group holding the mean of each metric.
avg_rows = []
for g in GROUP_ORDER:
    members = df[df["group"] == g]
    if members.empty:
        continue
    avg_rows.append({
        "group": g,
        "dentist_id": "A",
        **{col: pd.to_numeric(members[col], errors="coerce").mean() for col, _ in METRICS},
    })

df = pd.concat([df, pd.DataFrame(avg_rows)], ignore_index=True)


# Assemble the LaTeX table: a preamble, one multirow block per non-empty
# group separated by \midrule, and a closing \bottomrule.
metric_disp = [disp for _, disp in METRICS]
n_metrics = len(metric_disp)

colspec = "l l|" + "c" * n_metrics

lines = [
    r"\begin{table}[!htbp]",
    r"\centering",
    r"\scriptsize",
    r"\begin{adjustbox}{center, max width=\paperwidth}",
    r"\begin{tabular}{" + colspec + r"}",
    r"\toprule",
    r"\multicolumn{" + str(2 + n_metrics) + r"}{c}{Observer Evaluation}\\",
    r"\midrule",
]

# Header row
hdr = ["Evaluation", "Observer"] + metric_disp
lines.append(" & ".join(map(_latex_escape, hdr)) + r" \\")
lines.append(r"\midrule")

# Render each group into its own list of lines first. Separators are placed
# afterwards, so the table always ends with \bottomrule even when the last
# entry of GROUP_ORDER has no rows.
group_blocks = []
for g in GROUP_ORDER:
    sub = df[df["group"] == g]
    if sub.empty:
        continue

    sub = sub.set_index("dentist_id")
    group_rows = []
    for obs in row_index:
        if obs not in sub.index:
            # Observer has no result for this group.
            vals = ["NA"] * n_metrics
        else:
            vals = [
                _fmt(pd.to_numeric(sub.loc[obs, col], errors="coerce"))
                for col, _disp in METRICS
            ]
        group_rows.append((obs, vals))

    eval_name = GROUP_LABELS.get(g, g)

    # First row carries the multirow evaluation label; the rest leave it blank.
    first_obs, first_vals = group_rows[0]
    block = [
        r"\multirow{" + str(len(group_rows)) + r"}{*}{" + _latex_escape(eval_name) + r"} & "
        + _latex_escape(first_obs) + " & " + " & ".join(first_vals) + r" \\"
    ]
    for obs, vals in group_rows[1:]:
        block.append(r" & " + _latex_escape(obs) + " & " + " & ".join(vals) + r" \\")
    group_blocks.append(block)

for position, block in enumerate(group_blocks):
    if position > 0:
        lines.append(r"\midrule")
    lines.extend(block)
lines.append(r"\bottomrule")

lines.append(r"\end{tabular}")
lines.append(r"\end{adjustbox}")
lines.append(r"\caption{" + CAPTION + r"}")
lines.append(r"\label{" + LABEL + r"}")
lines.append(r"\end{table}")

print("\n".join(lines))

# Required LaTeX packages:
# \usepackage{booktabs}
# \usepackage{multirow}
# \usepackage{adjustbox}



\begin{table}[!htbp]
\centering
\scriptsize
\begin{adjustbox}{center, max width=\paperwidth}
\begin{tabular}{l l|cccccc}
\toprule
\multicolumn{8}{c}{Observer Evaluation}\\
\midrule
Evaluation & Observer & Precision & Recall & F1 & Specificity & FNR & FPR \\
\midrule
\multirow{8}{*}{Mesial (binary)} & D1 & 1.000 & 0.444 & 0.615 & 1.000 & 0.556 & 0.000 \\
 & D2 & 1.000 & 0.889 & 0.941 & 1.000 & 0.111 & 0.000 \\
 & D3 & 0.722 & 0.722 & 0.722 & 0.545 & 0.278 & 0.455 \\
 & D4 & 1.000 & 0.167 & 0.286 & 1.000 & 0.833 & 0.000 \\
 & D5 & 0.846 & 0.611 & 0.710 & 0.818 & 0.389 & 0.182 \\
 & D6 & 1.000 & 0.500 & 0.667 & 1.000 & 0.500 & 0.000 \\
 & D7 & 1.000 & 0.944 & 0.971 & 1.000 & 0.056 & 0.000 \\
 & A & 0.938 & 0.611 & 0.702 & 0.909 & 0.389 & 0.091 \\
\midrule
\multirow{8}{*}{Distal (binary)} & D1 & 1.000 & 0.500 & 0.667 & 1.000 & 0.500 & 0.000 \\
 & D2 & 0.929 & 0.650 & 0.765 & 0.889 & 0.350 & 0.111 \\
 & D3 & 0.778 & 0.700 & 0.737 & 0.556 & 0.300 & 0.444 \\
 & D4 & 1.000 & 0.200 & 0.333 & 1.

### run this code to evaluate models

############################################################

############################################################

############################################################

############################################################

############################################################

In [None]:
import os
import ast
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix

# Attached files in this environment
# Input CSV holding the ground-truth tooth annotations (one row per image).
ground_truth_loc = "F:/Github Repos/study/study_code/processed_data/task_1_inferred_ground_truth.csv"

# Input CSV holding the results to be scored against the ground truth.
results_loc = "F:/Github Repos/study/study_code/processed_data/object_detection/object_detection_results.csv"

# Confusion matrices are written to a sub-folder next to the results CSV;
# the folder is created as soon as this cell runs.
OUT_DIR = os.path.join(os.path.dirname(results_loc), "confusion_matrices")
os.makedirs(OUT_DIR, exist_ok=True)

# Per-tooth fields that are evaluated; one confusion matrix is written per key.
KEYS = ["mesial", "distal", "mesial_pls", "distal_pls", "furcation", "arr_left", "arr_right"]

# Key families: should_count_pair() applies a different missing-value rule to each.
MESIAL_DISTAL_KEYS = {"mesial", "distal"}
PLS_KEYS = {"mesial_pls", "distal_pls"}
ARR_KEYS = {"arr_left", "arr_right"}
FURCATION_KEYS = {"furcation"}

# Label order used for the rows/columns of the confusion matrices.
MESIAL_DISTAL_CLASSES = ["Healthy", "Mild", "Moderate", "Severe"]
BINARY_CLASSES = [False, True]


def to_binary_md(v):
    """
    Collapse a mesial/distal grade into a binary flag.

    "Healthy" maps to False; every other value maps to True.
    """
    if v == "Healthy":
        return False
    return True


def safe_literal_eval(x):
    """
    Parse a CSV cell into a Python object, returning None when that is not possible.

    - dict / list inputs are returned unchanged (already parsed).
    - None, NaN-like scalars, and the strings "", "none", "nan" (any case,
      surrounding whitespace ignored) yield None.
    - Anything else is stringified and handed to ast.literal_eval; text that is
      not a valid Python literal yields None instead of raising.
    """
    # Containers must be recognised BEFORE pd.isna(): pd.isna(list) returns an
    # element-wise array whose truth value is ambiguous, so `if pd.isna(x)`
    # raises ValueError for any list holding two or more items.
    if isinstance(x, (dict, list)):
        return x
    if x is None or pd.isna(x):
        return None

    s = str(x).strip()
    if s == "" or s.lower() in {"none", "nan"}:
        return None

    try:
        return ast.literal_eval(s)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # The failure modes documented for ast.literal_eval on malformed input.
        return None


def normalise_none_token(v):
    """
    Map the various spellings of "missing" to None; return anything else unchanged.

    Treated as missing: None, a float NaN, and strings that are empty or read
    "none" / "nan" once stripped and lower-cased.
    """
    if v is None:
        return None
    is_nan_float = isinstance(v, float) and np.isnan(v)
    is_blank_text = isinstance(v, str) and v.strip().lower() in {"none", "nan", ""}
    return None if (is_nan_float or is_blank_text) else v


def normalise_boolish(v):
    """
    Coerce a boolean-like value to a real bool where it is recognised.

    Returns None for missing values (as decided by normalise_none_token),
    True/False for bools, the integers 0/1 and the known yes/no style strings,
    and the (normalised) input unchanged for anything else.
    """
    truthy = {"true", "t", "1", "yes", "y", "present", "pos", "positive"}
    falsy = {"false", "f", "0", "no", "n", "absent", "neg", "negative"}

    value = normalise_none_token(v)
    if value is None:
        return None
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, np.integer)) and value in (0, 1):
        return bool(value)
    if isinstance(value, str):
        token = value.strip().lower()
        if token in truthy:
            return True
        if token in falsy:
            return False
    return value  # keep unexpected labels visible


def drop_conf_keys(d):
    """
    Return a copy of dict *d* without any key whose text contains "_conf".

    Non-dict inputs are returned unchanged.
    """
    if isinstance(d, dict):
        return {key: val for key, val in d.items() if "_conf" not in str(key)}
    return d


def extract_tooth_entries(obj):
    """
    Flatten one annotation cell into a list of (tooth_num, payload_dict) pairs.

    Supported shapes (after safe_literal_eval):
      - list of dicts, each carrying 'tooth_num'
      - dict with a 'tooth_annotations' entry holding one of the other shapes
      - single dict carrying 'tooth_num'
      - dict keyed by tooth number, whose values are the payload dicts

    Keys containing '_conf' are removed from every payload. Tooth numbers that
    are digit-only strings are converted to int. Anything unparseable yields [].
    """
    def as_int_if_digits(value):
        # "11" -> 11; every other value is passed through untouched.
        if isinstance(value, str) and value.strip().isdigit():
            return int(value.strip())
        return value

    parsed = safe_literal_eval(obj)
    if parsed is None:
        return []
    if isinstance(parsed, dict) and "tooth_annotations" in parsed:
        parsed = parsed["tooth_annotations"]

    if isinstance(parsed, list):
        entries = []
        for record in parsed:
            if not isinstance(record, dict):
                continue
            record = drop_conf_keys(record)
            tooth = normalise_none_token(record.get("tooth_num", None))
            if tooth is not None:
                entries.append((as_int_if_digits(tooth), record))
        return entries

    if not isinstance(parsed, dict):
        return []

    if "tooth_num" in parsed:
        payload = drop_conf_keys(parsed)
        tooth = normalise_none_token(payload.get("tooth_num", None))
        return [] if tooth is None else [(as_int_if_digits(tooth), payload)]

    # Remaining case: mapping of tooth number -> payload dict.
    return [
        (as_int_if_digits(key), drop_conf_keys(value))
        for key, value in parsed.items()
        if isinstance(value, dict)
    ]


def should_count_pair(gt_val, pred_val, key):
    """
    Decide whether a (ground truth, prediction) pair enters the confusion matrix.

    Returns (ok, gt, pred). When ok is False the pair is skipped and gt/pred
    are None. Missing-value rules per key family:
      - mesial/distal: a missing value on either side counts as "Healthy".
      - furcation: a missing ground truth skips the pair; a missing prediction
        counts as False.
      - ARR / PLS: a missing value on either side counts as False.
      - any other key: the pair is skipped when either side is missing.
    """
    gt = normalise_none_token(gt_val)
    pred = normalise_none_token(pred_val)

    if key in MESIAL_DISTAL_KEYS:
        return True, ("Healthy" if gt is None else gt), ("Healthy" if pred is None else pred)

    if key in FURCATION_KEYS:
        if gt is None:
            return False, None, None
        gt_flag = normalise_boolish(gt)
        if gt_flag is None:
            return False, None, None
        pred_flag = normalise_boolish(pred)
        return True, gt_flag, (False if pred_flag is None else pred_flag)

    if key in ARR_KEYS or key in PLS_KEYS:
        gt_flag = normalise_boolish(gt)
        pred_flag = normalise_boolish(pred)
        return (
            True,
            False if gt_flag is None else gt_flag,
            False if pred_flag is None else pred_flag,
        )

    if gt is None or pred is None:
        return False, None, None
    return True, gt, pred


def make_confusion_df(y_true, y_pred, labels):
    """
    Build a labelled confusion matrix.

    Rows are ground-truth classes ("GT_<label>"), columns are predicted classes
    ("PRED_<label>"), both in the order given by *labels*.
    """
    row_names = [f"GT_{label}" for label in labels]
    col_names = [f"PRED_{label}" for label in labels]
    counts = confusion_matrix(y_true, y_pred, labels=labels)
    return pd.DataFrame(counts, index=row_names, columns=col_names)



# Load both CSVs and work out which columns hold the image id and the annotations.
gt_raw = pd.read_csv(ground_truth_loc)
pred_raw = pd.read_csv(results_loc)

# Ground truth requirements
if "tooth_annotations" not in gt_raw.columns:
    raise ValueError("Ground truth CSV must contain 'tooth_annotations'.")
# The image id column may be called 'image_name' or 'image'; 'image_name' wins if both exist.
gt_image_col = "image_name" if "image_name" in gt_raw.columns else ("image" if "image" in gt_raw.columns else None)
if gt_image_col is None:
    raise ValueError("Ground truth CSV must contain 'image_name' or 'image'.")

# Results requirements: first matching column wins, so 'annotations' is preferred
# over 'tooth_annotations'.
pred_annot_col = None
for c in ["annotations", "tooth_annotations"]:
    if c in pred_raw.columns:
        pred_annot_col = c
        break
if pred_annot_col is None:
    raise ValueError("Results CSV must contain 'annotations' or 'tooth_annotations'.")

pred_image_col = "image_name" if "image_name" in pred_raw.columns else ("image" if "image" in pred_raw.columns else None)
if pred_image_col is None:
    raise ValueError("Results CSV must contain 'image_name' or 'image'.")



# Explode the ground truth into long format: one row per (image, tooth),
# with one column per entry of KEYS (missing tokens normalised to None).
gt_rows = []
for _, r in gt_raw.iterrows():
    img = r[gt_image_col]
    for tooth_num, payload in extract_tooth_entries(r["tooth_annotations"]):
        row = {"image_name": img, "tooth_num": tooth_num}
        for k in KEYS:
            row[k] = normalise_none_token(payload.get(k, None))
        gt_rows.append(row)

gt_long = pd.DataFrame(gt_rows)
if gt_long.empty:
    raise ValueError("Parsed GT is empty. Check 'tooth_annotations' format.")
# Keep only the first ground-truth row for each (image, tooth).
gt_long = gt_long.drop_duplicates(subset=["image_name", "tooth_num"], keep="first")



# Same explosion for the results file.
# NOTE(review): unlike gt_long, pred_long is not de-duplicated; if the results
# contain the same (image, tooth) twice, the later left-merge repeats that
# ground-truth row and it is counted more than once - confirm this is intended.
pred_rows = []
for _, r in pred_raw.iterrows():
    img = r[pred_image_col]
    for tooth_num, payload in extract_tooth_entries(r[pred_annot_col]):
        out = {"image_name": img, "tooth_num": tooth_num}
        for k in KEYS:
            out[k] = normalise_none_token(payload.get(k, None))
        pred_rows.append(out)

pred_long = pd.DataFrame(pred_rows)
if pred_long.empty:
    raise ValueError("Parsed results is empty. Check predictions annotation column format.")



# Left join: every ground-truth tooth is kept; teeth with no matching result row
# get NaN in the *_pred columns, which should_count_pair() then handles per key family.
# NOTE(review): the join relies on tooth_num having the same type on both sides
# (extract_tooth_entries only converts digit-only strings to int) - confirm.
df = gt_long.merge(
    pred_long,
    on=["image_name", "tooth_num"],
    suffixes=("_gt", "_pred"),
    how="left",
)

summary_rows = []

# One confusion matrix (plus a binarised one for mesial/distal) per key.
for key in KEYS:
    gt_col = f"{key}_gt"
    pred_col = f"{key}_pred"

    # Collect the pairs that count for this key, after missing-value handling.
    y_true, y_pred = [], []
    for gt_val, pred_val in zip(df[gt_col].tolist(), df[pred_col].tolist()):
        ok, gt_v, pred_v = should_count_pair(gt_val, pred_val, key)
        if not ok:
            continue
        y_true.append(gt_v)
        y_pred.append(pred_v)

    n_used = len(y_true)
    if n_used == 0:
        # Nothing to count: leave a marker file explaining the rules instead of a matrix.
        with open(os.path.join(OUT_DIR, f"{key}_NO_COUNTS.txt"), "w", encoding="utf-8") as f:
            f.write(
                "No valid pairs to count under rules.\n"
                "Mesial/Distal: None treated as Healthy.\n"
                "ARR+PLS: None treated as False (not skipped).\n"
                "Furcation: GT None skipped; missing pred treated as False.\n"
                "Keys containing '_conf' are ignored.\n"
            )
        summary_rows.append({"key": key, "n_used": 0})
        continue

    # Fixed label order first; any unexpected label seen in the data is appended
    # so that it shows up in the matrix rather than being dropped.
    if key in MESIAL_DISTAL_KEYS:
        labels = MESIAL_DISTAL_CLASSES
        extra = sorted((set(y_true) | set(y_pred)) - set(labels))
        labels = labels + extra
    elif key in ARR_KEYS or key in PLS_KEYS or key in FURCATION_KEYS:
        labels = BINARY_CLASSES
        extra = sorted((set(y_true) | set(y_pred)) - set(labels), key=lambda x: str(x))
        labels = labels + extra
    else:
        labels = sorted(set(y_true) | set(y_pred), key=lambda x: str(x))

    cm_df = make_confusion_df(y_true, y_pred, labels=labels)
    cm_df.to_csv(os.path.join(OUT_DIR, f"{key}_confusion_matrix.csv"), index=True)

    # Mesial/distal additionally get a Healthy-vs-not-Healthy 2x2 matrix.
    if key in MESIAL_DISTAL_KEYS:
        y_true_bin = [to_binary_md(v) for v in y_true]
        y_pred_bin = [to_binary_md(v) for v in y_pred]
        cm_bin_df = make_confusion_df(y_true_bin, y_pred_bin, labels=BINARY_CLASSES)
        cm_bin_df.to_csv(os.path.join(OUT_DIR, f"{key}_binary_confusion_matrix.csv"), index=True)

    summary_rows.append({
        "key": key,
        "n_used": n_used,
        "labels": "|".join(map(str, labels)),
    })

# Per-key bookkeeping: how many pairs were used and which labels the matrix has.
summary_df = pd.DataFrame(summary_rows).sort_values(["key"])
summary_df.to_csv(os.path.join(OUT_DIR, "SUMMARY_counts_and_labels.csv"), index=False)

print(f"Done. Wrote confusion matrices to: {OUT_DIR}")


Done. Wrote confusion matrices to: F:/Github Repos/study/study_code/processed_data/object_detection\confusion_matrices


In [None]:
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Tuple, List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt




# Folder scanned for *confusion_matrix.csv files.
root = Path("F:/Github Repos/study/study_code/processed_data/object_detection/confusion_matrices")
# Folder that receives the metric tables and plots.
out_dir = Path("F:/Github Repos/study/study_code/processed_data/object_detection/results_dentists")
# Number of bootstrap replicates and RNG seed passed to bootstrap_binary_cis().
n_boot = 2000
seed = 0





def _normalise_label(s: str) -> str:
    s = str(s)
    if s.startswith("GT_"):
        return s[3:]
    if s.startswith("PRED_"):
        return s[5:]
    return s


def load_confusion_csv(csv_path: Path) -> pd.DataFrame:
    """
    Read a confusion matrix CSV into an integer DataFrame.

    The first column holds the ground-truth labels (often 'Unnamed: 0'); the
    remaining columns are the predicted labels. "GT_" / "PRED_" prefixes are
    stripped from both axes.

    Returns a frame indexed by GT label with one column per PRED label.
    Raises ValueError if any count is negative; non-numeric cells make
    pd.to_numeric raise.
    """
    table = pd.read_csv(csv_path)
    table = table.rename(columns={table.columns[0]: "GT_LABEL"})
    table["GT_LABEL"] = table["GT_LABEL"].map(_normalise_label)

    table = table.rename(
        columns={name: _normalise_label(name) for name in table.columns if name != "GT_LABEL"}
    )

    counts = table.set_index("GT_LABEL").apply(pd.to_numeric, errors="raise").astype(int)
    if (counts.values < 0).any():
        raise ValueError(f"Negative counts found in {csv_path}")
    return counts


@dataclass
class BinaryCounts:
    """The four cell counts of a 2x2 confusion matrix (ground truth vs. prediction)."""

    tn: int  # ground truth negative, predicted negative
    fp: int  # ground truth negative, predicted positive
    fn: int  # ground truth positive, predicted negative
    tp: int  # ground truth positive, predicted positive


def _safe_div(num: float, den: float) -> float:
    return float(num) / float(den) if den != 0 else np.nan


def binary_counts_from_cm(
    cm: pd.DataFrame,
    negative_label: str = "False",
    positive_label: str = "True",
) -> BinaryCounts:
    """
    Read TN/FP/FN/TP out of a confusion matrix with GT rows and PRED columns.

    Labels are matched on their string form: an exact match is preferred, then
    a case-insensitive one. Raises KeyError if a label is absent from an axis.
    """
    def resolve(axis_labels, wanted: str):
        # Map each label's text to the original label object used for .loc lookups.
        by_text = {str(label): label for label in axis_labels}
        if wanted in by_text:
            return by_text[wanted]
        wanted_lower = wanted.lower()
        for text, label in by_text.items():
            if text.lower() == wanted_lower:
                return label
        raise KeyError(f"Could not find label '{wanted}' in {list(by_text.keys())}")

    neg_row = resolve(cm.index, negative_label)
    pos_row = resolve(cm.index, positive_label)
    neg_col = resolve(cm.columns, negative_label)
    pos_col = resolve(cm.columns, positive_label)

    return BinaryCounts(
        tn=int(cm.loc[neg_row, neg_col]),
        fp=int(cm.loc[neg_row, pos_col]),
        fn=int(cm.loc[pos_row, neg_col]),
        tp=int(cm.loc[pos_row, pos_col]),
    )


def metrics_from_binary_counts(b: BinaryCounts) -> Dict[str, float]:
    """
    Derive the standard binary classification metrics from TN/FP/FN/TP.

    Every ratio is NaN when its denominator is zero. Balanced accuracy is the
    NaN-ignoring mean of sensitivity and specificity. The returned dict also
    carries the raw counts (as floats) and the total n.
    """
    tp, tn, fp, fn = b.tp, b.tn, b.fp, b.fn
    total = tn + fp + fn + tp

    sensitivity = _safe_div(tp, tp + fn)
    specificity = _safe_div(tn, tn + fp)
    accuracy = _safe_div(tp + tn, total)

    # Matthews correlation coefficient; NaN when any marginal total is zero.
    mcc = _safe_div(
        (tp * tn) - (fp * fn),
        np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)),
    )

    return {
        "n": float(total),
        "tp": float(tp), "tn": float(tn), "fp": float(fp), "fn": float(fn),
        "sensitivity": sensitivity,
        "specificity": specificity,
        "precision_ppv": _safe_div(tp, tp + fp),
        "npv": _safe_div(tn, tn + fn),
        "f1": _safe_div(2 * tp, 2 * tp + fp + fn),
        "accuracy": accuracy,
        "balanced_accuracy": np.nanmean([sensitivity, specificity]),
        "mcc": mcc,
        "false_negative_rate": 1.0 - sensitivity,
        "false_positive_rate": 1.0 - specificity,
        "misdiagnosis_rate": 1.0 - accuracy,
    }


def metrics_multiclass(cm: pd.DataFrame) -> Dict[str, float]:
    """
    Compute macro-averaged one-vs-rest metrics from a square confusion matrix.

    Per-class precision, recall, F1 and specificity are averaged while ignoring
    classes whose value is undefined (NaN). Also reported:
      false_negative_rate = 1 - macro_recall
      false_positive_rate = 1 - macro_specificity
    and balanced_accuracy, which equals macro_recall.
    """
    counts = cm.values.astype(float)
    total = counts.sum()

    hits = np.diag(counts)
    predicted_totals = counts.sum(axis=0)
    actual_totals = counts.sum(axis=1)

    false_pos = predicted_totals - hits
    false_neg = actual_totals - hits
    true_neg = total - hits - false_pos - false_neg

    precision = np.array([_safe_div(h, p) for h, p in zip(hits, predicted_totals)], dtype=float)
    recall = np.array([_safe_div(h, a) for h, a in zip(hits, actual_totals)], dtype=float)
    f1 = np.array([_safe_div(2 * p * r, p + r) for p, r in zip(precision, recall)], dtype=float)
    specificity = np.array(
        [_safe_div(t, t + f) for t, f in zip(true_neg, false_pos)], dtype=float
    )

    macro_precision = np.nanmean(precision)
    macro_recall = np.nanmean(recall)
    macro_f1 = np.nanmean(f1)
    macro_specificity = np.nanmean(specificity)

    return {
        "n": float(total),
        "accuracy": _safe_div(np.trace(counts), total),
        "macro_precision": macro_precision,
        "macro_recall": macro_recall,
        "macro_f1": macro_f1,
        "balanced_accuracy": macro_recall,
        "macro_specificity": macro_specificity,
        "false_negative_rate": 1.0 - macro_recall,
        "false_positive_rate": 1.0 - macro_specificity,
    }


def bootstrap_binary_cis(b: BinaryCounts, n_boot: int = 2000, seed: int = 0) -> Dict[str, Tuple[float, float]]:
    """
    Percentile bootstrap 95% intervals for the binary metrics.

    The four cells (tn, fp, fn, tp) are resampled from a multinomial with the
    observed cell proportions, n_boot times, and the metrics are recomputed on
    every replicate.

    Returns {metric_name: (low, high)} using the 2.5th / 97.5th percentiles of
    the replicates in which the metric is defined. A metric that is undefined
    (NaN) in every replicate gets (nan, nan). Returns {} when the matrix is
    empty or n_boot is not positive.
    """
    metric_names = (
        "false_negative_rate",
        "false_positive_rate",
        "sensitivity",
        "specificity",
        "balanced_accuracy",
        "f1",
        "mcc",
    )

    rng = np.random.default_rng(seed)
    counts = np.array([b.tn, b.fp, b.fn, b.tp], dtype=int)
    n = counts.sum()
    # Without observations or replicates there is nothing to resample; the
    # original code indexed an empty array when n_boot was 0.
    if n == 0 or n_boot <= 0:
        return {}

    p = counts / n
    replicates = np.empty((n_boot, len(metric_names)), dtype=float)
    for row in range(n_boot):
        tn, fp, fn, tp = (int(c) for c in rng.multinomial(n, p))
        m = metrics_from_binary_counts(BinaryCounts(tn=tn, fp=fp, fn=fn, tp=tp))
        replicates[row] = [m[name] for name in metric_names]

    cis = {}
    for i, name in enumerate(metric_names):
        column = replicates[:, i]
        defined = column[~np.isnan(column)]
        if defined.size == 0:
            # E.g. no positive ground-truth cases at all: sensitivity is NaN in
            # every replicate. Report that explicitly instead of letting
            # np.nanpercentile emit an "All-NaN slice" RuntimeWarning.
            cis[name] = (float("nan"), float("nan"))
            continue
        lo, hi = np.percentile(defined, [2.5, 97.5])
        cis[name] = (float(lo), float(hi))
    return cis


def infer_task_name(csv_path: Path) -> str:
    """
    Return the file stem with a trailing "confusion_matrix" marker removed.

    Leading and trailing underscores are stripped from the result, so
    "D1_mesial_confusion_matrix.csv" yields "D1_mesial".
    """
    marker = "confusion_matrix"
    stem = csv_path.stem
    if stem.endswith(marker):
        stem = stem[: len(stem) - len(marker)]
    return stem.strip("_")


def aggregate_confusions(confusions: List[pd.DataFrame]) -> pd.DataFrame:
    """
    Add up confusion matrices whose label sets may differ.

    The result covers the sorted union of all row labels and of all column
    labels; cells a matrix does not have contribute zero. An empty input list
    gives an empty DataFrame.
    """
    if not confusions:
        return pd.DataFrame()

    row_labels = sorted({label for matrix in confusions for label in matrix.index})
    col_labels = sorted({label for matrix in confusions for label in matrix.columns})

    total = pd.DataFrame(0, index=row_labels, columns=col_labels, dtype=int)
    for matrix in confusions:
        total.loc[matrix.index, matrix.columns] += matrix.astype(int)
    return total


def plot_metric_heatmap(per_group_dentist: pd.DataFrame,
                        metric: str,
                        out_path: Path,
                        title: str,
                        group_order: List[str],
                        group_labels: List[str],
                        value_fmt: str = "{:.3f}") -> None:
    """
    Save a heatmap of one metric with the numeric value printed in every cell.

    Rows: dentists (D1, D2, ...), sorted by id.
    Cols: groups, in the order given by group_order (e.g. PBL, ARR, PLS, Furcation).

    Parameters:
      per_group_dentist: table with 'group', 'dentist_id' and the metric column.
      metric:            name of the column to plot.
      out_path:          image file to write.
      title:             figure title.
      group_order:       groups to include, left to right.
      group_labels:      tick labels for the columns, parallel to group_order.
      value_fmt:         format string for the per-cell annotation; NaN prints as "NA".
    """
    df = per_group_dentist.copy()
    df = df[df["group"].isin(group_order)].copy()
    df["group"] = pd.Categorical(df["group"], categories=group_order, ordered=True)

    # 'group' is categorical, so pandas wants `observed` spelled out; leaving it
    # implicit raises a FutureWarning. observed=False is the behaviour the code
    # has had so far, and the reindex below pins the column set either way.
    pivot = df.pivot_table(
        index="dentist_id", columns="group", values=metric, aggfunc="first", observed=False
    )
    pivot = pivot.reindex(columns=group_order).sort_index()

    data = pivot.values.astype(float)

    fig = plt.figure(figsize=(10, 6), constrained_layout=True)
    ax = fig.add_subplot(111)

    im = ax.imshow(data, aspect="auto")

    ax.set_yticks(np.arange(pivot.shape[0]))
    ax.set_yticklabels(pivot.index)

    ax.set_xticks(np.arange(pivot.shape[1]))
    ax.set_xticklabels(group_labels, rotation=0)

    ax.set_xlabel("Task", labelpad=10)
    ax.set_ylabel("Dentist", labelpad=10)
    ax.set_title(title)

    fig.colorbar(im, ax=ax)

    # Annotate every cell; missing combinations are shown as "NA".
    for i in range(pivot.shape[0]):
        for j in range(pivot.shape[1]):
            v = data[i, j]
            txt = "NA" if np.isnan(v) else value_fmt.format(v)
            ax.text(j, i, txt, ha="center", va="center")

    fig.savefig(out_path, dpi=200)
    plt.close(fig)



# Create the output folders up front.
out_dir.mkdir(parents=True, exist_ok=True)
plots_dir = out_dir / "plots"
plots_dir.mkdir(parents=True, exist_ok=True)

# Evaluation groups: each group pools the confusion matrices of its tasks.
# A task may belong to several groups (mesial_binary is in both "mesial" and "PBL").
GROUPS = {
    "mesial":    {"tasks": ["mesial_binary"],                  "type": "binary"},
    "distal":    {"tasks": ["distal_binary"],                  "type": "binary"},
    "PBL":       {"tasks": ["mesial_binary", "distal_binary"], "type": "binary"},
    "ARR":       {"tasks": ["arr_left", "arr_right"],          "type": "binary"},
    "Furcation": {"tasks": ["furcation"],                      "type": "binary"},
    "PLS":       {"tasks": ["mesial_pls", "distal_pls"],       "type": "binary"},
}

# All task names, longest first, so that filename matching tries the most
# specific task name before any shorter one.
KNOWN_TASKS = sorted({t for spec in GROUPS.values() for t in spec["tasks"]}, key=len, reverse=True)


def collect_confusion_files_flat(folder: Path, pattern: str = "*confusion_matrix.csv") -> List[Path]:
    """Return the paths directly inside *folder* that match *pattern*, sorted."""
    matches = list(folder.glob(pattern))
    matches.sort()
    return matches


def infer_dentist_and_task(csv_path: Path, known_tasks: List[str]) -> Tuple[str, str]:
    """
    Split a confusion-matrix filename into (dentist_raw, task).

    The stem (without the confusion_matrix suffix) must either equal a known
    task or end with "<separator><task>", where the separator is "__", "_" or
    "-". Whatever precedes the task is the dentist; 'ALL' is used if nothing does.
    Tasks are tried in the order given, so the first match wins.

    Examples:
      D1_mesial_binary_confusion_matrix.csv -> dentist_raw=D1, task=mesial_binary
      mesial_binary_confusion_matrix.csv    -> dentist_raw=ALL, task=mesial_binary

    Raises ValueError when no known task matches.
    """
    stem = infer_task_name(csv_path)

    for task in known_tasks:
        if stem == task:
            return "ALL", task
        for separator in ("__", "_", "-"):
            tail = separator + task
            if not stem.endswith(tail):
                continue
            prefix = stem[: len(stem) - len(tail)].strip("_-")
            return (prefix or "ALL"), task

    raise ValueError(
        f"Could not infer task from filename '{csv_path.name}'. Base='{stem}'. "
        f"Expected it to end with one of known tasks: {known_tasks}"
    )


# Load every recognised confusion matrix under `root` into one table
# (one row per file: dentist, task, path, matrix).
all_rows = []
csv_files = collect_confusion_files_flat(root)

if not csv_files:
    raise RuntimeError(f"No confusion_matrix.csv files found directly under: {root}")

skipped = []
dentist_raws_seen = set()

for csv_path in csv_files:
    # Files whose name does not end in a configured task are skipped, not fatal.
    try:
        dentist_raw, task = infer_dentist_and_task(csv_path, KNOWN_TASKS)
    except ValueError:
        skipped.append(csv_path.name)
        continue

    cm = load_confusion_csv(csv_path)

    dentist_raws_seen.add(dentist_raw)
    all_rows.append({
        "dentist_raw": dentist_raw,
        "task": task,
        "file": str(csv_path),
        "cm": cm,
    })

all_df = pd.DataFrame(all_rows)
if all_df.empty:
    raise RuntimeError(
        f"No recognised confusion matrix files found under: {root}\n"
        f"Skipped: {skipped}"
    )


# Anonymise: raw dentist names are numbered D1, D2, ... in sorted order.
# NOTE(review): files without a dentist prefix get the raw name 'ALL', which is
# numbered like any other name (it becomes 'D1' when it is the only one) - confirm
# that is the intended display id.
dentist_raws = sorted(dentist_raws_seen)
dentist_map = {raw: f"D{i+1}" for i, raw in enumerate(dentist_raws)}
all_df["dentist_id"] = all_df["dentist_raw"].map(dentist_map)

if skipped:
    print("Skipped files (not part of configured tasks):")
    for s in skipped:
        print(f"  - {s}")


# Metrics are produced at two levels for every group:
#   per file    - one row per confusion matrix file
#   per dentist - one row per dentist, from the sum of that dentist's matrices
per_group_file_rows = []
per_group_dentist_rows = []

for group_name, spec in GROUPS.items():
    tasks = set(spec["tasks"])
    group_type = spec["type"]

    group_out = out_dir / group_name
    group_out.mkdir(parents=True, exist_ok=True)

    gdf = all_df[all_df["task"].isin(tasks)].copy()

    # Per-file metrics.
    for _, r in gdf.iterrows():
        cm = r["cm"]

        if group_type == "binary":
            if cm.shape != (2, 2):
                raise ValueError(f"[{group_name}] Expected binary 2x2 but got {cm.shape} in {r['file']}")
            b = binary_counts_from_cm(cm, "False", "True")
            m = metrics_from_binary_counts(b)
        else:
            m = metrics_multiclass(cm)

        per_group_file_rows.append({
            "group": group_name,
            "dentist_raw": r["dentist_raw"],
            "dentist_id": r["dentist_id"],
            "task": r["task"],
            "file": r["file"],
            **m
        })

    # Per-dentist metrics on the pooled (summed) confusion matrix of the group.
    for dentist_raw, sub in gdf.groupby("dentist_raw"):
        dentist_id = dentist_map[dentist_raw]
        cms = list(sub["cm"].values)
        agg = aggregate_confusions(cms)

        if agg.empty:
            continue

        if group_type == "binary":
            if agg.shape != (2, 2):
                raise ValueError(f"[{group_name}] Expected aggregated binary 2x2 but got {agg.shape} for {dentist_raw}")
            b = binary_counts_from_cm(agg, "False", "True")
            m = metrics_from_binary_counts(b)
            # Bootstrap intervals are only computed for binary groups.
            cis = bootstrap_binary_cis(b, n_boot=n_boot, seed=seed)

            row = {"group": group_name, "dentist_raw": dentist_raw, "dentist_id": dentist_id, **m}
            for k, (lo, hi) in cis.items():
                row[f"{k}_ci_low"] = lo
                row[f"{k}_ci_high"] = hi
            per_group_dentist_rows.append(row)
        else:
            m = metrics_multiclass(agg)
            per_group_dentist_rows.append({"group": group_name, "dentist_raw": dentist_raw, "dentist_id": dentist_id, **m})

    # Write this group's own tables into its sub-folder.
    per_file_g = pd.DataFrame([r for r in per_group_file_rows if r["group"] == group_name])
    per_dent_g = pd.DataFrame([r for r in per_group_dentist_rows if r["group"] == group_name])

    per_file_g.to_csv(group_out / f"per_file_metrics_{group_name}.csv", index=False)
    per_dent_g.to_csv(group_out / f"per_dentist_metrics_{group_name}.csv", index=False)

# Save global combined tables
per_file = pd.DataFrame(per_group_file_rows)
per_dentist = pd.DataFrame(per_group_dentist_rows)
per_file.to_csv(out_dir / "per_file_metrics_ALLGROUPS.csv", index=False)
per_dentist.to_csv(out_dir / "per_dentist_metrics_ALLGROUPS.csv", index=False)

# Plots
# Only the pooled groups are plotted; "mesial" and "distal" on their own are left out.
# plot_labels must stay parallel to plot_groups.
plot_groups = ["PBL", "ARR", "PLS", "Furcation"]
plot_labels = ["PBL (mesial+distal, binary)", "ARR", "PLS", "Furcation"]

plot_df = per_dentist[per_dentist["group"].isin(plot_groups)].copy()

plot_metric_heatmap(
    per_group_dentist=plot_df,
    metric="false_negative_rate",
    out_path=plots_dir / "heatmap_false_negative_rate.png",
    title="False Negative Rate (1 - recall)",
    group_order=plot_groups,
    group_labels=plot_labels,
)

plot_metric_heatmap(
    per_group_dentist=plot_df,
    metric="false_positive_rate",
    out_path=plots_dir / "heatmap_false_positive_rate.png",
    title="False Positive Rate (1 - specificity)",
    group_order=plot_groups,
    group_labels=plot_labels,
)

print("Done. Outputs written to:")
print(f"  {out_dir}")
print("Plots written to:")
print(f"  {plots_dir}")


Skipped files (not part of configured tasks):
  - distal_confusion_matrix.csv
  - mesial_confusion_matrix.csv


  return _nanquantile_unchecked(
  pivot = df.pivot_table(index="dentist_id", columns="group", values=metric, aggfunc="first")
  pivot = df.pivot_table(index="dentist_id", columns="group", values=metric, aggfunc="first")


Done. Outputs written to:
  F:\Github Repos\study\study_code\processed_data\object_detection\results_dentists
Plots written to:
  F:\Github Repos\study\study_code\processed_data\object_detection\results_dentists\plots


In [None]:

from pathlib import Path
import re
import numpy as np
import pandas as pd


# Folder holding the metrics CSV that the LaTeX table is built from.
# NOTE(review): this points at keypoint/results_dentists_multiclass, whereas the
# preceding cell writes to object_detection/results_dentists - confirm which
# results this table is meant to report.
out_dir = Path(r"F:/Github Repos/study/study_code/processed_data/keypoint/results_dentists_multiclass")
# results_path = out_dir / "per_task_metrics.csv"
results_path = out_dir / "per_dentist_metrics_ALLGROUPS.csv"



# Row order of the table, and the display name of each group.
GROUP_ORDER = ["mesial", "distal", "PBL", "ARR", "PLS", "Furcation"]
GROUP_LABELS = {
    "mesial": "Mesial (binary)",
    "distal": "Distal (binary)",
    "PBL": "PBL (mesial+distal, binary)",
    "ARR": "ARR",
    "PLS": "PLS",
    "Furcation": "Furcation",
}


# Table columns as (CSV column name, header text), in display order.
METRICS = [
    ("macro_precision", "Precision"),
    ("macro_recall", "Recall"),
    ("macro_f1", "F1"),
    ("macro_specificity", "Specificity"),
    ("false_negative_rate", "FNR"),
    ("false_positive_rate", "FPR"),
]

# Caption and label emitted verbatim into the LaTeX table.
CAPTION = (
    "Single-system results across evaluation types. Metrics are computed from aggregated confusion matrices "
    "per evaluation type. For binary tasks, macro metrics are set to their binary equivalents."
)
LABEL = "tab:single_system_eval"


def _fmt(x: float, ndp: int = 3) -> str:
    if x is None or (isinstance(x, float) and np.isnan(x)):
        return "NA"
    return f"{float(x):.{ndp}f}"


def _latex_escape(s: str) -> str:
    s = str(s)
    return (s.replace("\\", r"\textbackslash{}")
             .replace("&", r"\&")
             .replace("%", r"\%")
             .replace("$", r"\$")
             .replace("#", r"\#")
             .replace("_", r"\_")
             .replace("{", r"\{")
             .replace("}", r"\}")
             .replace("~", r"\textasciitilde{}")
             .replace("^", r"\textasciicircum{}"))


def _is_binary_group(g: str) -> bool:
    return g in {"mesial", "distal", "PBL", "ARR", "PLS", "Furcation"}


def _autofind_results_csv(root: Path) -> Path:
    cands = list(root.rglob("per_group_metrics.csv"))
    if cands:
        return cands[0]
    cands = list(root.rglob("*per_group_metrics*.csv"))
    if cands:
        return cands[0]
    raise FileNotFoundError(
        f"Could not find per_group_metrics.csv under: {root}\n"
        "Set `results_path` to the correct CSV."
    )


def _ensure_macro_columns_for_binary_single(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of *df* that has the macro_* and FNR/FPR columns the table needs.

    For rows of binary groups, empty macro columns are filled from their
    binary counterparts:
      macro_precision    <- precision_ppv
      macro_recall       <- sensitivity
      macro_f1           <- f1
      macro_specificity  <- specificity

    If the rate columns are absent altogether they are derived:
      false_negative_rate <- 1 - macro_recall
      false_positive_rate <- 1 - macro_specificity
    """
    macro_sources = {
        "macro_precision": "precision_ppv",
        "macro_recall": "sensitivity",
        "macro_f1": "f1",
        "macro_specificity": "specificity",
    }
    derived_rates = (
        ("false_negative_rate", "macro_recall"),
        ("false_positive_rate", "macro_specificity"),
    )

    out = df.copy()
    binary_rows = out["group"].apply(_is_binary_group)

    # Make sure every macro column exists before trying to fill it.
    for target in macro_sources:
        if target not in out.columns:
            out[target] = np.nan

    # Only fill cells that are still empty, and only for binary groups.
    for target, source in macro_sources.items():
        if source not in out.columns:
            continue
        needs_fill = binary_rows & out[target].isna()
        out.loc[needs_fill, target] = pd.to_numeric(out.loc[needs_fill, source], errors="coerce")

    for rate, complement in derived_rates:
        if rate not in out.columns:
            out[rate] = 1.0 - pd.to_numeric(out[complement], errors="coerce")

    return out



# Build and print the "Single-system Evaluation" LaTeX table from the metrics CSV.

# Fall back to searching out_dir when the configured CSV is not there.
if not results_path.exists():
    results_path = _autofind_results_csv(out_dir)

df = pd.read_csv(results_path)

required = {"group"}
missing = required - set(df.columns)
if missing:
    raise ValueError(f"Missing required columns in results CSV: {missing}")

# Keep only the groups that appear in the table, ordered as in GROUP_ORDER.
df = df[df["group"].isin(GROUP_ORDER)].copy()
df["group"] = pd.Categorical(df["group"], categories=GROUP_ORDER, ordered=True)

df = _ensure_macro_columns_for_binary_single(df)

# Any metric column the CSV lacks is added as NaN so that it prints as "NA".
keep_cols = ["group"] + [c for c, _ in METRICS]
for c in keep_cols:
    if c not in df.columns:
        df[c] = np.nan
df = df[keep_cols].copy()

# Collapse to one row per group. 'group' is categorical, so `observed` is passed
# explicitly: leaving it implicit raises a pandas FutureWarning. observed=False
# is the behaviour this code has had so far.
# NOTE(review): .first() takes the first non-null value of each column within a
# group; if the CSV holds several rows per group (e.g. one per dentist) only
# the first is reported - confirm a single system per group is expected.
df = (df.sort_values(["group"])
        .groupby("group", as_index=False, observed=False)
        .first())

metric_disp = [disp for _, disp in METRICS]
n_metrics = len(metric_disp)

# One left-aligned name column followed by one centred column per metric.
colspec = "l|" + "".join(["c"] * n_metrics)

lines = []
lines.append(r"\begin{table}[!htbp]")
lines.append(r"\centering")
lines.append(r"\scriptsize")
lines.append(r"\begin{adjustbox}{center, max width=\paperwidth}")
lines.append(r"\begin{tabular}{" + colspec + r"}")
lines.append(r"\toprule")
lines.append(r"\multicolumn{" + str(1 + n_metrics) + r"}{c}{Single-system Evaluation}\\")
lines.append(r"\midrule")

# Header row
hdr = ["Evaluation"] + metric_disp
lines.append(" & ".join(map(_latex_escape, hdr)) + r" \\")
lines.append(r"\midrule")

# Body, one row per evaluation type
for gi, g in enumerate(GROUP_ORDER):
    sub = df[df["group"] == g].copy()
    eval_name = GROUP_LABELS.get(g, g)

    if sub.empty:
        vals = ["NA"] * n_metrics
    else:
        row = sub.iloc[0]
        vals = []
        for col, _disp in METRICS:
            vals.append(_fmt(pd.to_numeric(row[col], errors="coerce")))

    lines.append(_latex_escape(eval_name) + " & " + " & ".join(vals) + r" \\")

    # Rows are separated by \midrule; the last one closes the table body.
    if gi == len(GROUP_ORDER) - 1:
        lines.append(r"\bottomrule")
    else:
        lines.append(r"\midrule")

lines.append(r"\end{tabular}")
lines.append(r"\end{adjustbox}")
lines.append(r"\caption{" + CAPTION + r"}")
lines.append(r"\label{" + LABEL + r"}")
lines.append(r"\end{table}")

print("\n".join(lines))

# Required LaTeX packages:
# \usepackage{booktabs}
# \usepackage{adjustbox}


\begin{table}[!htbp]
\centering
\scriptsize
\begin{adjustbox}{center, max width=\paperwidth}
\begin{tabular}{l|cccccc}
\toprule
\multicolumn{7}{c}{Single-system Evaluation}\\
\midrule
Evaluation & Precision & Recall & F1 & Specificity & FNR & FPR \\
\midrule
Mesial (binary) & 1.000 & 0.667 & 0.800 & 1.000 & 0.333 & 0.000 \\
\midrule
Distal (binary) & 0.889 & 0.400 & 0.552 & 0.889 & 0.600 & 0.111 \\
\midrule
PBL (mesial+distal, binary) & 0.280 & 0.299 & 0.299 & 0.782 & 0.701 & 0.218 \\
\midrule
ARR & 1.000 & 0.375 & 0.545 & 1.000 & 0.625 & 0.000 \\
\midrule
PLS & 0.600 & 0.600 & 0.600 & 0.962 & 0.400 & 0.038 \\
\midrule
Furcation & 0.286 & 1.000 & 0.444 & 0.444 & 0.000 & 0.556 \\
\bottomrule
\end{tabular}
\end{adjustbox}
\caption{Single-system results across evaluation types. Metrics are computed from aggregated confusion matrices per evaluation type. For binary tasks, macro metrics are set to their binary equivalents.}
\label{tab:single_system_eval}
\end{table}


  .groupby("group", as_index=False)


In [None]:
### MULTI CLASS PBL


from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Tuple, List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# Folder scanned for confusion matrix CSVs (keypoint results).
root = Path("F:/Github Repos/study/study_code/processed_data/keypoint/confusion_matrices")
# Folder that receives the outputs of this cell.
out_dir = Path("F:/Github Repos/study/study_code/processed_data/keypoint/results_dentists_multiclass")
# Bootstrap settings: number of replicates and RNG seed.
n_boot = 2000
seed = 0



def _normalise_label(s: str) -> str:
    s = str(s)
    if s.startswith("GT_"):
        return s[3:]
    if s.startswith("PRED_"):
        return s[5:]
    return s


def load_confusion_csv(csv_path: Path) -> pd.DataFrame:
    """
    Read a confusion matrix CSV whose first column holds the GT labels
    (often named 'Unnamed: 0') and whose remaining columns are predicted labels.

    'GT_' / 'PRED_' prefixes are removed from the row and column labels.
    Returns a DataFrame with index=GT labels, columns=PRED labels, values=int counts.
    Raises ValueError if any count is negative; non-numeric cells make
    pd.to_numeric raise.
    """
    raw = pd.read_csv(csv_path)

    # Whatever the first column is called, it carries the ground-truth labels.
    raw = raw.rename(columns={raw.columns[0]: "GT_LABEL"})
    raw["GT_LABEL"] = raw["GT_LABEL"].map(_normalise_label)

    # Every other column is a predicted label.
    raw = raw.rename(
        columns={name: _normalise_label(name) for name in raw.columns if name != "GT_LABEL"}
    )

    counts = raw.set_index("GT_LABEL").apply(pd.to_numeric, errors="raise").astype(int)

    if (counts.values < 0).any():
        raise ValueError(f"Negative counts found in {csv_path}")

    return counts


@dataclass
class BinaryCounts:
    """Cell counts of a 2x2 confusion matrix (ground-truth rows vs. predicted columns)."""
    tn: int  # GT negative, predicted negative
    fp: int  # GT negative, predicted positive
    fn: int  # GT positive, predicted negative
    tp: int  # GT positive, predicted positive


def _safe_div(num: float, den: float) -> float:
    return float(num) / float(den) if den != 0 else np.nan


def binary_counts_from_cm(
    cm: pd.DataFrame,
    negative_label: str = "False",
    positive_label: str = "True",
) -> BinaryCounts:
    """
    Read TN/FP/FN/TP out of a confusion matrix with GT rows and PRED columns.

    Labels are matched on their string form: an exact match is preferred,
    otherwise the first case-insensitive match is used.
    Raises KeyError when a requested label is present on neither basis.
    """

    def resolve(axis_labels, wanted: str):
        # Map the string form of each label back to the original label object.
        by_text = {str(label): label for label in axis_labels}
        if wanted in by_text:
            return by_text[wanted]
        wanted_folded = wanted.lower()
        for text, label in by_text.items():
            if text.lower() == wanted_folded:
                return label
        raise KeyError(f"Could not find label '{wanted}' in {list(by_text.keys())}")

    gt_neg = resolve(cm.index, negative_label)
    gt_pos = resolve(cm.index, positive_label)
    pred_neg = resolve(cm.columns, negative_label)
    pred_pos = resolve(cm.columns, positive_label)

    return BinaryCounts(
        tn=int(cm.loc[gt_neg, pred_neg]),
        fp=int(cm.loc[gt_neg, pred_pos]),
        fn=int(cm.loc[gt_pos, pred_neg]),
        tp=int(cm.loc[gt_pos, pred_pos]),
    )


def metrics_from_binary_counts(b: BinaryCounts) -> Dict[str, float]:
    """
    Derive the standard binary classification metrics from TN/FP/FN/TP counts.

    Every ratio goes through _safe_div, so a metric whose denominator is zero
    is reported as NaN rather than raising. Balanced accuracy is the NaN-aware
    mean of sensitivity and specificity, so it falls back to whichever of the
    two is defined.

    Returns a dict with the raw counts ("n", "tp", "tn", "fp", "fn") as floats
    and the metrics "sensitivity", "specificity", "precision_ppv", "npv",
    "f1", "accuracy", "balanced_accuracy", "mcc", "false_negative_rate"
    (1 - sensitivity), "false_positive_rate" (1 - specificity) and
    "misdiagnosis_rate" (1 - accuracy).
    """
    tn, fp, fn, tp = b.tn, b.fp, b.fn, b.tp
    n = tn + fp + fn + tp

    sens = _safe_div(tp, tp + fn)
    spec = _safe_div(tn, tn + fp)
    ppv = _safe_div(tp, tp + fp)
    npv = _safe_div(tn, tn + fn)
    f1 = _safe_div(2 * tp, 2 * tp + fp + fn)
    acc = _safe_div(tp + tn, n)
    bal_acc = np.nanmean([sens, spec])

    mcc_num = (tp * tn) - (fp * fn)
    # Multiply the four marginals as floats: an integer product larger than
    # 64 bits cannot be handled by np.sqrt (TypeError for Python ints, silent
    # wrap-around for NumPy ints). For small counts the result is unchanged.
    mcc_den = np.sqrt(
        float(tp + fp) * float(tp + fn) * float(tn + fp) * float(tn + fn)
    )
    mcc = _safe_div(mcc_num, mcc_den)

    fnr = 1.0 - sens
    fpr = 1.0 - spec
    misdx = 1.0 - acc

    return {
        "n": float(n),
        "tp": float(tp), "tn": float(tn), "fp": float(fp), "fn": float(fn),
        "sensitivity": sens,
        "specificity": spec,
        "precision_ppv": ppv,
        "npv": npv,
        "f1": f1,
        "accuracy": acc,
        "balanced_accuracy": bal_acc,
        "mcc": mcc,
        "false_negative_rate": fnr,
        "false_positive_rate": fpr,
        "misdiagnosis_rate": misdx,
    }


def metrics_multiclass(cm: pd.DataFrame) -> Dict[str, float]:
    """
    Multiclass metrics with one-vs-rest macro precision/recall/F1 AND macro-specificity.

    `cm` has GT labels as rows and predicted labels as columns. If the row and
    column labels (compared by their string form) are not identical and in the
    same order, the matrix is first aligned onto the union of both label sets,
    with missing cells filled with 0, so that the diagonal really holds the
    correctly classified counts. A matrix that is already aligned is used as is.

    Per-class ratios with a zero denominator are NaN and are ignored by the
    macro averages.

    Also returns:
      false_negative_rate = 1 - macro_recall
      false_positive_rate = 1 - macro_specificity

    Raises ValueError (from pandas) if alignment is needed but a label occurs
    more than once on an axis.
    """
    row_labels = [str(label) for label in cm.index]
    col_labels = [str(label) for label in cm.columns]
    if row_labels != col_labels:
        # Union of labels, keeping first-seen order (rows first, then columns).
        labels = list(dict.fromkeys(row_labels + col_labels))
        aligned = cm.copy()
        aligned.index = row_labels
        aligned.columns = col_labels
        cm = aligned.reindex(index=labels, columns=labels, fill_value=0)

    def ratios(nums, dens) -> np.ndarray:
        # Element-wise num / den, NaN wherever the denominator is zero.
        return np.array(
            [float(a) / float(d) if d != 0 else np.nan for a, d in zip(nums, dens)],
            dtype=float,
        )

    cmv = cm.values.astype(float)
    n = cmv.sum()
    acc = float(np.trace(cmv)) / float(n) if n != 0 else np.nan

    tp = np.diag(cmv)
    pred_sum = cmv.sum(axis=0)
    gt_sum = cmv.sum(axis=1)

    # One-vs-rest counts for every class.
    fp = pred_sum - tp
    fn = gt_sum - tp
    tn = n - tp - fp - fn

    prec = ratios(tp, pred_sum)
    rec = ratios(tp, gt_sum)
    f1 = ratios(2 * prec * rec, prec + rec)
    spec = ratios(tn, tn + fp)

    macro_prec = np.nanmean(prec)
    macro_rec = np.nanmean(rec)
    macro_f1 = np.nanmean(f1)
    macro_spec = np.nanmean(spec)

    return {
        "n": float(n),
        "accuracy": acc,
        "macro_precision": macro_prec,
        "macro_recall": macro_rec,
        "macro_f1": macro_f1,
        "balanced_accuracy": macro_rec,
        "macro_specificity": macro_spec,
        "false_negative_rate": 1.0 - macro_rec,
        "false_positive_rate": 1.0 - macro_spec,
    }


def bootstrap_binary_cis(b: BinaryCounts, n_boot: int = 2000, seed: int = 0) -> Dict[str, Tuple[float, float]]:
    """
    Percentile bootstrap confidence intervals for the binary metrics.

    The four cells (TN, FP, FN, TP) are resampled n_boot times from a
    multinomial with the observed cell proportions; each resample is scored
    with metrics_from_binary_counts. For every metric the 2.5th and 97.5th
    percentiles (NaN-aware) are returned as (low, high).
    Returns an empty dict when all four counts are zero.
    """
    metric_names = [
        "false_negative_rate",
        "false_positive_rate",
        "sensitivity",
        "specificity",
        "balanced_accuracy",
        "f1",
        "mcc",
    ]

    rng = np.random.default_rng(seed)
    cell_counts = np.array([b.tn, b.fp, b.fn, b.tp], dtype=int)
    total = cell_counts.sum()
    if total == 0:
        return {}

    cell_probs = cell_counts / total
    draws = []
    for _ in range(n_boot):
        tn_s, fp_s, fn_s, tp_s = (int(x) for x in rng.multinomial(total, cell_probs))
        resampled = metrics_from_binary_counts(
            BinaryCounts(tn=tn_s, fp=fp_s, fn=fn_s, tp=tp_s)
        )
        draws.append([resampled[name] for name in metric_names])

    table = np.asarray(draws, dtype=float)
    intervals = {}
    for col_idx, name in enumerate(metric_names):
        low, high = np.nanpercentile(table[:, col_idx], [2.5, 97.5])
        intervals[name] = (float(low), float(high))
    return intervals


def infer_task_name(csv_path: Path) -> str:
    """
    Return the file stem with a trailing 'confusion_matrix' marker removed and
    any leading/trailing underscores stripped.
    """
    stem = csv_path.stem
    marker = "confusion_matrix"
    # The separator underscore in front of the marker, if any, is removed by
    # the final strip, so a single suffix check covers both spellings.
    if stem.endswith(marker):
        stem = stem[: len(stem) - len(marker)]
    return stem.strip("_")


def aggregate_confusions(confusions: List[pd.DataFrame]) -> pd.DataFrame:
    """
    Add up confusion matrices whose label sets may differ.

    The result is indexed by the sorted union of all row labels and has the
    sorted union of all column labels as columns; cells that no input covers
    stay 0. An empty input list yields an empty DataFrame.
    """
    if not confusions:
        return pd.DataFrame()

    row_labels = set()
    col_labels = set()
    for matrix in confusions:
        row_labels.update(matrix.index)
        col_labels.update(matrix.columns)

    total = pd.DataFrame(0, index=sorted(row_labels), columns=sorted(col_labels), dtype=int)
    for matrix in confusions:
        # Add each matrix into the cells addressed by its own labels.
        total.loc[matrix.index, matrix.columns] += matrix.astype(int)
    return total


def plot_metric_heatmap(per_group_dentist: pd.DataFrame,
                        metric: str,
                        out_path: Path,
                        title: str,
                        group_order: List[str],
                        group_labels: List[str],
                        value_fmt: str = "{:.3f}") -> None:
    """
    Save a heatmap of one metric with per-cell numeric annotations.

    Rows: dentists (the "dentist_id" column, sorted).
    Cols: the groups listed in `group_order`, in that order; rows of
          `per_group_dentist` belonging to other groups are ignored.

    Parameters:
      per_group_dentist: table with "group", "dentist_id" and `metric` columns.
      metric:            name of the column to plot.
      out_path:          image file to write (saved with dpi=200).
      title:             figure title.
      group_order:       group names selecting and ordering the columns.
      group_labels:      x tick labels, one per entry of `group_order`.
      value_fmt:         format string for the cell annotations; NaN cells
                         are annotated with "NA".
    """
    df = per_group_dentist[per_group_dentist["group"].isin(group_order)].copy()
    df["group"] = pd.Categorical(df["group"], categories=group_order, ordered=True)

    # "group" is categorical, so state `observed` explicitly: relying on the
    # implicit default is deprecated in pandas and emits a FutureWarning.
    pivot = df.pivot_table(
        index="dentist_id",
        columns="group",
        values=metric,
        aggfunc="first",
        observed=False,
    )
    # Guarantee one column per requested group (NaN where there is no data).
    pivot = pivot.reindex(columns=group_order).sort_index()

    data = pivot.values.astype(float)

    fig = plt.figure(figsize=(10, 6), constrained_layout=True)
    try:
        ax = fig.add_subplot(111)

        im = ax.imshow(data, aspect="auto")

        ax.set_yticks(np.arange(pivot.shape[0]))
        ax.set_yticklabels(pivot.index)

        ax.set_xticks(np.arange(pivot.shape[1]))
        ax.set_xticklabels(group_labels, rotation=0)

        ax.set_xlabel("Task", labelpad=10)
        ax.set_ylabel("Dentist", labelpad=10)
        ax.set_title(title)

        fig.colorbar(im, ax=ax)

        for i in range(pivot.shape[0]):
            for j in range(pivot.shape[1]):
                v = data[i, j]
                txt = "NA" if np.isnan(v) else value_fmt.format(v)
                ax.text(j, i, txt, ha="center", va="center")

        fig.savefig(out_path, dpi=200)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


# Create the output folder and its "plots" sub-folder up front.
out_dir.mkdir(parents=True, exist_ok=True)
plots_dir = out_dir / "plots"
plots_dir.mkdir(parents=True, exist_ok=True)


# Evaluation groups. For each group:
#   "tasks": task names (as they appear in the confusion-matrix file names)
#            whose matrices are summed per dentist;
#   "type":  "binary" (2x2 matrix with False/True labels, bootstrap CIs) or
#            "multiclass" (macro one-vs-rest metrics, no CIs).
GROUPS = {
    "mesial":    {"tasks": ["mesial_binary"],          "type": "binary"},
    "distal":    {"tasks": ["distal_binary"],          "type": "binary"},
    "PBL":       {"tasks": ["mesial", "distal"],       "type": "multiclass"},  # multiclass mesial + distal matrices
    "ARR":       {"tasks": ["arr_left", "arr_right"],  "type": "binary"},
    "Furcation": {"tasks": ["furcation"],              "type": "binary"},
    "PLS":       {"tasks": ["mesial_pls", "distal_pls"], "type": "binary"},
}

# All configured task names, longest first, so that a longer task name is
# always tried before a shorter one when matching file names.
KNOWN_TASKS = sorted({t for spec in GROUPS.values() for t in spec["tasks"]}, key=len, reverse=True)


def collect_confusion_files_flat(folder: Path, pattern: str = "*confusion_matrix.csv") -> List[Path]:
    """Return the paths directly inside `folder` that match `pattern`, sorted."""
    matches = list(folder.glob(pattern))
    matches.sort()
    return matches


def infer_dentist_and_task(csv_path: Path, known_tasks: List[str]) -> Tuple[str, str]:
    """
    Split a confusion-matrix file name into (dentist, task).

    The file name (minus its confusion_matrix suffix) must either equal a known
    task or end with '_<task>' / '-<task>'. Tasks are tried in the order given.
    The dentist is whatever precedes the task, with surrounding '_' and '-'
    removed, or 'ALL' if nothing is left.
    Raises ValueError when no known task matches.
    """
    base = infer_task_name(csv_path)

    for task in known_tasks:
        if base == task:
            return "ALL", task
        # A doubled underscore separator is covered by the single-underscore
        # check, because the leftover separator characters are stripped below.
        if base.endswith(("_" + task, "-" + task)):
            prefix = base[: len(base) - len(task)].strip("_-")
            return (prefix or "ALL"), task

    raise ValueError(
        f"Could not infer task from filename '{csv_path.name}'. Base='{base}'. "
        f"Expected it to end with one of known tasks: {known_tasks}"
    )



# Load every confusion-matrix CSV found directly under `root` and tag it with
# the dentist and task inferred from its file name.
all_rows = []
csv_files = collect_confusion_files_flat(root)

if not csv_files:
    raise RuntimeError(f"No confusion_matrix.csv files found directly under: {root}")

skipped = []
dentist_raws_seen = set()

for csv_path in csv_files:
    try:
        dentist_raw, task = infer_dentist_and_task(csv_path, KNOWN_TASKS)
    except ValueError:
        # File name does not end in a configured task: remember it and move on.
        skipped.append(csv_path.name)
        continue

    cm = load_confusion_csv(csv_path)

    dentist_raws_seen.add(dentist_raw)
    all_rows.append({
        "dentist_raw": dentist_raw,
        "task": task,
        "file": str(csv_path),
        "cm": cm,
    })

all_df = pd.DataFrame(all_rows)
if all_df.empty:
    raise RuntimeError(
        f"No recognised confusion matrix files found under: {root}\n"
        f"Skipped: {skipped}"
    )

# Replace the raw dentist identifiers with D1, D2, ... numbered in sorted order
# of the raw identifier.
dentist_raws = sorted(dentist_raws_seen)
dentist_map = {raw: f"D{i+1}" for i, raw in enumerate(dentist_raws)}
all_df["dentist_id"] = all_df["dentist_raw"].map(dentist_map)

if skipped:
    print("Skipped files (not part of configured tasks):")
    for s in skipped:
        print(f"  - {s}")



# Metric rows accumulated across all groups:
#   per_group_file_rows    - one row per (group, confusion-matrix file)
#   per_group_dentist_rows - one row per (group, dentist), computed from the
#                            sum of that dentist's matrices within the group
per_group_file_rows = []
per_group_dentist_rows = []

for group_name, spec in GROUPS.items():
    tasks = set(spec["tasks"])
    group_type = spec["type"]

    group_out = out_dir / group_name
    group_out.mkdir(parents=True, exist_ok=True)

    # All loaded matrices whose task belongs to this group.
    gdf = all_df[all_df["task"].isin(tasks)].copy()

    # Per-file metrics
    for _, r in gdf.iterrows():
        cm = r["cm"]

        if group_type == "binary":
            if cm.shape != (2, 2):
                raise ValueError(f"[{group_name}] Expected binary 2x2 but got {cm.shape} in {r['file']}")
            b = binary_counts_from_cm(cm, "False", "True")
            m = metrics_from_binary_counts(b)
        else:
            m = metrics_multiclass(cm)

        per_group_file_rows.append({
            "group": group_name,
            "dentist_raw": r["dentist_raw"],
            "dentist_id": r["dentist_id"],
            "task": r["task"],
            "file": r["file"],
            **m
        })

    # Per-dentist aggregated metrics
    for dentist_raw, sub in gdf.groupby("dentist_raw"):
        dentist_id = dentist_map[dentist_raw]
        cms = list(sub["cm"].values)
        # Sum this dentist's matrices over all tasks of the group.
        agg = aggregate_confusions(cms)

        if agg.empty:
            continue

        if group_type == "binary":
            if agg.shape != (2, 2):
                raise ValueError(f"[{group_name}] Expected aggregated binary 2x2 but got {agg.shape} for {dentist_raw}")
            b = binary_counts_from_cm(agg, "False", "True")
            m = metrics_from_binary_counts(b)
            # Bootstrap confidence intervals are only computed for binary groups.
            cis = bootstrap_binary_cis(b, n_boot=n_boot, seed=seed)

            row = {"group": group_name, "dentist_raw": dentist_raw, "dentist_id": dentist_id, **m}
            for k, (lo, hi) in cis.items():
                row[f"{k}_ci_low"] = lo
                row[f"{k}_ci_high"] = hi
            per_group_dentist_rows.append(row)
        else:
            m = metrics_multiclass(agg)
            per_group_dentist_rows.append({"group": group_name, "dentist_raw": dentist_raw, "dentist_id": dentist_id, **m})

    # Save group CSVs
    per_file_g = pd.DataFrame([r for r in per_group_file_rows if r["group"] == group_name])
    per_dent_g = pd.DataFrame([r for r in per_group_dentist_rows if r["group"] == group_name])

    per_file_g.to_csv(group_out / f"per_file_metrics_{group_name}.csv", index=False)
    per_dent_g.to_csv(group_out / f"per_dentist_metrics_{group_name}.csv", index=False)

# Save global combined tables
per_file = pd.DataFrame(per_group_file_rows)
per_dentist = pd.DataFrame(per_group_dentist_rows)
per_file.to_csv(out_dir / "per_file_metrics_ALLGROUPS.csv", index=False)
per_dentist.to_csv(out_dir / "per_dentist_metrics_ALLGROUPS.csv", index=False)

# Groups shown as heatmap columns (the separate "mesial" and "distal" binary
# groups are left out) and the tick label used for each of them.
plot_groups = ["PBL", "ARR", "PLS", "Furcation"]
plot_labels = ["PBL (mesial+distal, multiclass)", "ARR", "PLS", "Furcation"]  # PBL is evaluated as multiclass here

plot_df = per_dentist[per_dentist["group"].isin(plot_groups)].copy()

# One heatmap per metric: dentists as rows, groups as columns.
plot_metric_heatmap(
    per_group_dentist=plot_df,
    metric="false_negative_rate",
    out_path=plots_dir / "heatmap_false_negative_rate.png",
    title="False Negative Rate (1 - recall)",
    group_order=plot_groups,
    group_labels=plot_labels,
)

plot_metric_heatmap(
    per_group_dentist=plot_df,
    metric="false_positive_rate",
    out_path=plots_dir / "heatmap_false_positive_rate.png",
    title="False Positive Rate (1 - specificity)",
    group_order=plot_groups,
    group_labels=plot_labels,
)

print("Done. Outputs written to:")
print(f"  {out_dir}")
print("Plots written to:")
print(f"  {plots_dir}")



  pivot = df.pivot_table(index="dentist_id", columns="group", values=metric, aggfunc="first")
  pivot = df.pivot_table(index="dentist_id", columns="group", values=metric, aggfunc="first")


Done. Outputs written to:
  F:\Github Repos\study\study_code\processed_data\keypoint\results_dentists_multiclass
Plots written to:
  F:\Github Repos\study\study_code\processed_data\keypoint\results_dentists_multiclass\plots
