In [4]:
# Notebook code: conta [UH]/[UM] SOLO per i transcript di WhisperX_nyrahealth
# e stampa statistiche split per diagnosi (AD vs CN) separando TRAIN e TEST.
# Niente CSV: solo print.

from pathlib import Path
import re
import pandas as pd
import numpy as np

# --------------- CONFIG ---------------
# Directories searched recursively for *.txt transcripts. A directory whose
# path contains a "transcripts_test" component is treated as the test subset;
# every other directory counts as train.
TRANSCRIPT_DIRS = [
    Path("transcripts/WhisperX_nyrahealth"),
    Path("transcripts_test/WhisperX_nyrahealth"),
]
# CSV label files; each is loaded only if it exists on disk.
TRAIN_LABELS = Path("Adresso21/ADReSSo21-diagnosis-train/ADReSSo21/diagnosis/train/adresso-train-mmse-scores.csv")
TEST_LABELS  = Path("Adresso21/label_test_task1.csv")  # AD/CN labels for the test set
# --------------------------------------

# ---- load and normalize labels -> (ID, Dx, subset) ----
def _normalize_labels_df(df: pd.DataFrame, subset: str) -> pd.DataFrame:
    """Reduce a raw label table to the columns ``ID``, ``Dx`` and ``subset``.

    The ID and diagnosis columns are located by case-insensitive name.
    IDs are whitespace-stripped, lower-cased and freed of a trailing
    ``.wav`` (in any casing). Diagnoses are stripped and mapped to
    ``"AD"``/``"CN"`` when recognised; unrecognised values are kept as found
    (stripped) so the caller can filter them out.

    Args:
        df: Raw label table, e.g. as returned by ``pd.read_csv``.
        subset: Value written to the ``subset`` column of every row.

    Returns:
        A new DataFrame with columns ``ID``, ``Dx``, ``subset`` and the same
        index and row order as ``df``.

    Raises:
        ValueError: If no ID column or no diagnosis column can be found.
    """
    # Lower-cased (and stripped) column name -> original column label.
    cols = {str(c).strip().lower(): c for c in df.columns}

    def _first_present(candidates):
        # First candidate name that exists in the table, or None.
        return next((cols[name] for name in candidates if name in cols), None)

    id_col = _first_present(("id", "adressfname", "adressfilename", "filename", "name"))
    if id_col is None:
        raise ValueError(f"Colonna ID non trovata nelle label: {df.columns.tolist()}")

    dx_col = _first_present(("dx", "diagnosis", "label", "class"))
    if dx_col is None:
        raise ValueError(f"Colonna diagnosi non trovata nelle label: {df.columns.tolist()}")

    # Lower-case BEFORE removing the suffix so ".WAV"/".Wav" are stripped too;
    # otherwise such IDs would never match a lower-cased transcript stem.
    ids = (
        df[id_col]
        .astype(str)
        .str.strip()
        .str.lower()
        .str.replace(r"\.wav$", "", regex=True)
    )

    canonical = {"ad": "AD", "cn": "CN", "probablead": "AD", "control": "CN"}
    dx_clean = df[dx_col].astype(str).str.strip()
    # Unknown labels map to NaN and are then restored from the stripped input.
    dx = dx_clean.str.lower().map(canonical).fillna(dx_clean)

    out = pd.DataFrame({"ID": ids, "Dx": dx})
    out["subset"] = subset
    return out[["ID", "Dx", "subset"]]

# Load whichever label files are present (train first, then test) and tag
# every row with the subset it came from.
label_frames = [
    _normalize_labels_df(pd.read_csv(csv_path), split_name)
    for csv_path, split_name in ((TRAIN_LABELS, "train"), (TEST_LABELS, "test"))
    if csv_path.exists()
]
if len(label_frames) == 0:
    raise FileNotFoundError("File di label non trovati. Controlla TRAIN_LABELS e TEST_LABELS.")

# One row per (ID, subset); the first occurrence wins.
labels = pd.concat(label_frames, ignore_index=True)
labels = labels.drop_duplicates(subset=["ID","subset"])

# ---- regexes for the [UH]/[UM] tags ----
# Case-insensitive, so "[uh]" / "[Um]" are counted as well.
pat_UH = re.compile(r"\[UH\]", re.IGNORECASE)
pat_UM = re.compile(r"\[UM\]", re.IGNORECASE)

# ---- scan ONLY the WhisperX_nyrahealth transcripts ----
# Builds one row per transcript file: subset, lower-cased file stem, tag counts.
rows = []
for base in TRANSCRIPT_DIRS:
    if not base.exists():
        continue
    # The subset is inferred from the directory name, not from the labels.
    subset = "test" if "transcripts_test" in base.parts else "train"
    # sorted() keeps the row order reproducible across runs and filesystems.
    for p in sorted(base.rglob("*.txt")):
        # errors="ignore" drops undecodable bytes instead of raising, so no
        # fallback read is needed; genuine I/O errors are left to propagate.
        txt = p.read_text(encoding="utf-8", errors="ignore")
        rows.append({
            "subset": subset,
            "file_id": p.stem.lower(),
            "UH": len(pat_UH.findall(txt)),
            "UM": len(pat_UM.findall(txt)),
        })

if not rows:
    raise RuntimeError("Nessun .txt trovato sotto transcripts*/WhisperX_nyrahealth.")

df = pd.DataFrame(rows)

# ---- join with the labels and keep only AD/CN rows ----
# Left join so that transcripts without a label survive long enough to be
# counted; they are reported and then dropped by the AD/CN filter below.
df_lab = df.merge(labels, left_on=["file_id","subset"], right_on=["ID","subset"], how="left")
has_dx = df_lab["Dx"].isin(["AD","CN"])
n_unlabeled = int((~has_dx).sum())
if n_unlabeled:
    # Make silent data loss visible: an ID mismatch between transcripts and
    # label files would otherwise only show up as a smaller n in the stats.
    print(f"[ATTENZIONE] {n_unlabeled} transcript senza etichetta AD/CN esclusi dalle statistiche.")
df_lab = df_lab[has_dx].copy()

# ---- printing: per-text mean/std and totals, split by subset & Dx ----
def _print_stats(sub, subset_name):
    """Print per-file mean/std and summed [UH]/[UM] counts for CN and AD.

    Args:
        sub: DataFrame with columns ``Dx``, ``UH`` and ``UM``, one row per
            transcript.
        subset_name: Name shown (upper-cased) in the section headers.

    The sample standard deviation (``ddof=1``) is reported, or 0.0 when a
    diagnosis group holds a single file. Groups with no rows print a
    "no file" line in the first section and are skipped in the totals.
    """
    diagnoses = ("CN", "AD")

    def _mean_and_std(values, count):
        # Sample std is undefined for one observation; report 0.0 instead.
        spread = values.std(ddof=1) if count > 1 else 0.0
        return values.mean(), spread

    print(f"== {subset_name.upper()} | Statistiche per testo (WhisperX_nyrahealth) ==")
    for dx in diagnoses:
        group = sub.loc[sub["Dx"] == dx]
        n = len(group)
        if n == 0:
            print(f"{dx}: nessun file.")
        else:
            mean_uh, std_uh = _mean_and_std(group["UH"], n)
            mean_um, std_um = _mean_and_std(group["UM"], n)
            print(f"{dx}: n={n} | [UH] mean={mean_uh:.3f} std={std_uh:.3f} | [UM] mean={mean_um:.3f} std={std_um:.3f}")
    print()

    print(f"== {subset_name.upper()} | Totali per diagnosi ==")
    totals = sub.groupby("Dx")[["UH","UM"]].sum()
    for dx in diagnoses:
        if dx not in totals.index:
            continue
        uh = int(totals.at[dx, "UH"])
        um = int(totals.at[dx, "UM"])
        print(f"{dx}: UH={uh}  UM={um}  (UH+UM={uh+um})")
    print()

# Per-subset report: train first, then test.
for split_name in ("train", "test"):
    _print_stats(df_lab[df_lab["subset"] == split_name], split_name)

# (optional) combined train+test totals per diagnosis
comb = df_lab.groupby("Dx")[["UH","UM"]].sum()
print("== COMBINATI (train+test) | Totali per diagnosi ==")
for dx in ("CN", "AD"):
    if dx not in comb.index:
        continue
    uh = int(comb.at[dx, "UH"])
    um = int(comb.at[dx, "UM"])
    print(f"{dx}: UH={uh}  UM={um}  (UH+UM={uh+um})")


== TRAIN | Statistiche per testo (WhisperX_nyrahealth) ==
CN: n=79 | [UH] mean=2.241 std=2.371 | [UM] mean=1.354 std=1.783
AD: n=87 | [UH] mean=2.690 std=3.258 | [UM] mean=0.483 std=1.088

== TRAIN | Totali per diagnosi ==
CN: UH=177  UM=107  (UH+UM=284)
AD: UH=234  UM=42  (UH+UM=276)

== TEST | Statistiche per testo (WhisperX_nyrahealth) ==
CN: n=36 | [UH] mean=2.528 std=3.308 | [UM] mean=0.750 std=1.251
AD: n=35 | [UH] mean=2.000 std=2.288 | [UM] mean=0.200 std=0.473

== TEST | Totali per diagnosi ==
CN: UH=91  UM=27  (UH+UM=118)
AD: UH=70  UM=7  (UH+UM=77)

== COMBINATI (train+test) | Totali per diagnosi ==
CN: UH=268  UM=134  (UH+UM=402)
AD: UH=304  UM=49  (UH+UM=353)
