In [5]:
# Task 5 (Option A): evaluate k files sampled from the intersection of gold & predicted

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple, Set

LABELS = ("ADR", "Drug", "Disease", "Symptom")

@dataclass(frozen=True)
class Span:
    label: str
    start: int
    end: int

# ---------- Parsers (same as Task 3 strict) ----------
def parse_ann_file(path: Path) -> List[Span]:
    """Parse BRAT .ann file (gold or predicted). Keeps labels in LABELS and splits multi-ranges."""
    spans: List[Span] = []
    if not path.exists():
        return spans
    for line in path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line or line.startswith("#") or not line.startswith("T"):
            continue
        parts = line.split("\t")
        if len(parts) < 2:
            continue
        head = parts[1]  # "LABEL s e[; s e]*"
        bits = head.split()
        if not bits:
            continue
        label = bits[0]
        if label not in LABELS:
            continue
        # collect all number pairs after the label (handles discontiguous spans)
        nums = [int(x) for x in bits[1:] if x.isdigit()]
        for i in range(0, len(nums), 2):
            if i + 1 >= len(nums):
                break
            s, e = nums[i], nums[i+1]
            if e > s:
                spans.append(Span(label, s, e))
    return sorted(set(spans), key=lambda x: (x.start, x.end, x.label))

def to_set(spans: List[Span], label: str | None = None) -> Set[Tuple[int, int, str]]:
    if label is None:
        return {(s.start, s.end, s.label) for s in spans}
    return {(s.start, s.end, s.label) for s in spans if s.label == label}

# ---------- Metrics ----------
def prf1(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    p = tp / (tp + fp) if (tp + fp) else 0.0
    r = tp / (tp + fn) if (tp + fn) else 0.0
    f = (2 * p * r / (p + r)) if (p + r) else 0.0
    return p, r, f

# ---------- Single-file strict evaluation ----------
def eval_file_strict(base_name: str, gold_dir: Path, pred_dir: Path) -> Dict:
    """Strict character-offset match on one base name (no extension)."""
    gold = parse_ann_file(gold_dir / f"{base_name}.ann")
    pred = parse_ann_file(pred_dir / f"{base_name}.ann")

    g_all, p_all = to_set(gold), to_set(pred)
    tp_all = len(g_all & p_all)
    fp_all = len(p_all - g_all)
    fn_all = len(g_all - p_all)
    P_all, R_all, F_all = prf1(tp_all, fp_all, fn_all)

    per_label = {}
    for lab in LABELS:
        g = to_set(gold, lab)
        p = to_set(pred, lab)
        tp = len(g & p)
        fp = len(p - g)
        fn = len(g - p)
        P, R, F = prf1(tp, fp, fn)
        per_label[lab] = dict(tp=tp, fp=fp, fn=fn, P=P, R=R, F=F, gold=len(g), pred=len(p))

    return dict(
        base=base_name,
        overall=dict(tp=tp_all, fp=fp_all, fn=fn_all, P=P_all, R=R_all, F=F_all,
                     gold=len(g_all), pred=len(p_all)),
        per_label=per_label
    )

# ---------- Intersection sampler & aggregator ----------
def evaluate_from_intersection(
    gold_dir: Path,
    pred_dir: Path,
    k: int = 50,
    seed: int = 42,
):
    """
    Sample up to k basenames from the intersection of gold_dir and pred_dir .ann files.
    Runs strict evaluation per file and prints micro + per-label micro aggregates.
    """
    import random

    gold_bases = {p.stem for p in gold_dir.glob("*.ann")}
    pred_bases = {p.stem for p in pred_dir.glob("*.ann")}
    both = sorted(gold_bases & pred_bases)
    if not both:
        raise RuntimeError(f"No overlapping files between {gold_dir} and {pred_dir}.")

    rnd = random.Random(seed)
    picked = rnd.sample(both, k=min(k, len(both)))
    print(f"Evaluating {len(picked)} files from gold∩pred (available overlap={len(both)}). Seed={seed}")
    # aggregate
    totals_overall = dict(tp=0, fp=0, fn=0, gold=0, pred=0)
    totals_label = {lab: dict(tp=0, fp=0, fn=0, gold=0, pred=0) for lab in LABELS}

    for i, base in enumerate(picked, 1):
        res = eval_file_strict(base, gold_dir, pred_dir)
        o = res["overall"]
        print(f"  [{i:02d}] {base}: P={o['P']:.3f} R={o['R']:.3f} F1={o['F']:.3f}  (gold={o['gold']}, pred={o['pred']})")

        totals_overall["tp"]   += o["tp"]
        totals_overall["fp"]   += o["fp"]
        totals_overall["fn"]   += o["fn"]
        totals_overall["gold"] += o["gold"]
        totals_overall["pred"] += o["pred"]

        for lab in LABELS:
            pl = res["per_label"][lab]
            t = totals_label[lab]
            t["tp"] += pl["tp"]; t["fp"] += pl["fp"]; t["fn"] += pl["fn"]
            t["gold"] += pl["gold"]; t["pred"] += pl["pred"]

    # micro overall
    P, R, F = prf1(totals_overall["tp"], totals_overall["fp"], totals_overall["fn"])
    print("\n--- Micro (all labels) over evaluated files ---")
    print(f"Files evaluated: {len(picked)}")
    print(f"Gold spans: {totals_overall['gold']} | Pred spans: {totals_overall['pred']}")
    print(f"TP: {totals_overall['tp']} | FP: {totals_overall['fp']} | FN: {totals_overall['fn']}")
    print(f"Micro Precision: {P:.4f}")
    print(f"Micro Recall   : {R:.4f}")
    print(f"Micro F1       : {F:.4f}")

    # per-label micro
    per_label_scores = {}
    print("\n--- Per-label micro ---")
    for lab in LABELS:
        t = totals_label[lab]
        p, r, f = prf1(t["tp"], t["fp"], t["fn"])
        per_label_scores[lab] = dict(P=p, R=r, F=f, **t)
        print(f"{lab:<8} P={p:.4f}  R={r:.4f}  F1={f:.4f}  "
              f"(tp={t['tp']}, fp={t['fp']}, fn={t['fn']}, gold={t['gold']}, pred={t['pred']})")

    return {
        "files_evaluated": picked,
        "micro": {"P": P, "R": R, "F": F, **totals_overall},
        "per_label": per_label_scores,
    }


In [15]:
from pathlib import Path
import random

# ---- paths (edit as needed) ----
gold_dir = Path("/Users/anjalikulkarni/Desktop/Assignment1/CADEC-lPWNPfjE-/data/cadec/original")
pred_dir = Path("/Users/anjalikulkarni/Desktop/Assignment1/predicted")

k = 50
seed = 42

# 1) get all predicted .ann files
all_pred = sorted([p for p in pred_dir.iterdir() if p.suffix.lower() == ".ann"])
if not all_pred:
    raise FileNotFoundError(f"No .ann predictions found in {pred_dir}")

rnd = random.Random(seed)
sample = rnd.sample(all_pred, k=min(k, len(all_pred)))

print(f"Sampled {len(sample)} prediction files out of {len(all_pred)} available.")

# 2) evaluate each sampled file
totals_overall = dict(tp=0, fp=0, fn=0, gold=0, pred=0)
totals_label = {lab: dict(tp=0, fp=0, fn=0, gold=0, pred=0) for lab in LABELS}

print(f"\nEvaluating {len(sample)} files from predicted/ (seed={seed})")
for i, pred_path in enumerate(sample, 1):
    base = pred_path.stem
    gold_path = gold_dir / f"{base}.ann"
    if not gold_path.exists():
        print(f"  [{i:02d}] SKIP {base} — missing gold: {gold_path}")
        continue

    res = eval_file_strict(base, gold_dir, pred_dir)
    o = res["overall"]
    print(f"  [{i:02d}] {base}: P={o['P']:.3f} R={o['R']:.3f} F1={o['F']:.3f}  (gold={o['gold']}, pred={o['pred']})")

    totals_overall["tp"]   += o["tp"]
    totals_overall["fp"]   += o["fp"]
    totals_overall["fn"]   += o["fn"]
    totals_overall["gold"] += o["gold"]
    totals_overall["pred"] += o["pred"]

    for lab in LABELS:
        pl = res["per_label"][lab]
        t = totals_label[lab]
        t["tp"] += pl["tp"]; t["fp"] += pl["fp"]; t["fn"] += pl["fn"]
        t["gold"] += pl["gold"]; t["pred"] += pl["pred"]

# 3) micro overall
P = totals_overall["tp"] / (totals_overall["tp"] + totals_overall["fp"]) if (totals_overall["tp"] + totals_overall["fp"]) else 0.0
R = totals_overall["tp"] / (totals_overall["tp"] + totals_overall["fn"]) if (totals_overall["tp"] + totals_overall["fn"]) else 0.0
F = (2 * P * R / (P + R)) if (P + R) else 0.0

print("\n--- Micro (all labels) over evaluated files ---")
print(f"Files evaluated: {len(sample)}")
print(f"Gold spans: {totals_overall['gold']} | Pred spans: {totals_overall['pred']}")
print(f"TP: {totals_overall['tp']} | FP: {totals_overall['fp']} | FN: {totals_overall['fn']}")
print(f"Micro Precision: {P:.4f}")
print(f"Micro Recall   : {R:.4f}")
print(f"Micro F1       : {F:.4f}")

print("\n--- Per-label micro ---")
for lab in LABELS:
    t = totals_label[lab]
    p = t["tp"] / (t["tp"] + t["fp"]) if (t["tp"] + t["fp"]) else 0.0
    r = t["tp"] / (t["tp"] + t["fn"]) if (t["tp"] + t["fn"]) else 0.0
    f = (2 * p * r / (p + r)) if (p + r) else 0.0
    print(f"{lab:<8} P={p:.4f}  R={r:.4f}  F1={f:.4f}  "
          f"(tp={t['tp']}, fp={t['fp']}, fn={t['fn']}, gold={t['gold']}, pred={t['pred']})")


Sampled 50 prediction files out of 103 available.

Evaluating 50 files from predicted/ (seed=42)
  [01] PENNSAID.4: P=0.000 R=0.000 F1=0.000  (gold=1, pred=4)
  [02] ARTHROTEC.21: P=0.250 R=0.250 F1=0.250  (gold=4, pred=4)
  [03] ARTHROTEC.11: P=0.500 R=0.500 F1=0.500  (gold=2, pred=2)
  [04] VOLTAREN.46: P=0.182 R=0.167 F1=0.174  (gold=12, pred=11)
  [05] ARTHROTEC.5: P=0.333 R=1.000 F1=0.500  (gold=1, pred=3)
  [06] ARTHROTEC.46: P=0.000 R=0.000 F1=0.000  (gold=5, pred=5)
  [07] ARTHROTEC.43: P=0.333 R=0.400 F1=0.364  (gold=5, pred=6)
  [08] ARTHROTEC.24: P=0.400 R=0.286 F1=0.333  (gold=7, pred=5)
  [09] ZIPSOR.2: P=0.000 R=0.000 F1=0.000  (gold=2, pred=4)
  [10] ARTHROTEC.20: P=0.167 R=0.111 F1=0.133  (gold=9, pred=6)
  [11] VOLTAREN-XR.2: P=0.167 R=0.125 F1=0.143  (gold=8, pred=6)
  [12] LIPITOR.3: P=0.500 R=0.059 F1=0.105  (gold=17, pred=2)
  [13] ARTHROTEC.19: P=0.600 R=0.600 F1=0.600  (gold=5, pred=5)
  [14] LIPITOR.9: P=0.375 R=0.333 F1=0.353  (gold=9, pred=8)
  [15] ARTHROTEC.