In [1]:
# TASK 6 — ADR code assignment (string-match vs embedding-match)
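# Jupyter-friendly library cell (no argparse; only standard-library imports at module level —
# rapidfuzz, sentence-transformers and numpy are imported lazily inside the functions that use them)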

from __future__ import annotations

import math, os, re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple, Optional

# -------- Data classes --------
@dataclass(frozen=True)
class Range:
    # One contiguous character span. The parsers in this file only keep ranges
    # with end > start, and the slicing/overlap code treats `end` as exclusive.
    start: int  # offset of the first character (used as text[start:end])
    end: int    # offset one past the last character

@dataclass
class AnnSpan:
    # One entity annotation parsed from a brat "T" line.
    label: str              # ADR, Drug, Disease, Symptom (from original or predicted)
    ranges: List[Range]     # one Range per fragment; several when the span is discontiguous
    text: str               # pulled from raw text

@dataclass
class SctSpan:
    # One mapped-code record parsed from a "TT" line of an sct file.
    code: str               # SNOMED CT code (or MedDRA if your sct dir has those)
    ranges: List[Range]     # character ranges the code applies to
    text: str               # mapped term from sct file (we treat this as "standard_text")

@dataclass
class Joined:
    # Result of pairing one original annotation with its best-overlapping sct span.
    # `code` and `standard_text` are empty strings when no sct span overlaps.
    code: str
    standard_text: str      # sct span text (mapped term)
    label_type: str         # from original ann (ADR/Drug/Disease/Symptom)
    gt_text: str            # from original ann
    gt_ranges: List[Range]  # ranges of the original annotation

# -------- Parsers --------
# Matches one "<start> <end>" offset pair. Running finditer() over a span header
# yields every fragment of a multi-range annotation ("10 20;30 40" -> two matches).
RANGE_RE = re.compile(r"(\d+)\s+(\d+)")

def _read_text(p: Path) -> str:
    """Return the entire contents of *p*, decoded as UTF-8."""
    with p.open("r", encoding="utf-8") as handle:
        return handle.read()

def _surface(text: str, ranges: List[Range]) -> str:
    """
    Return the surface string covered by *ranges* in *text*.

    Each range is sliced as text[start:end], newlines inside a fragment become
    spaces, and the fragments of a discontiguous span are joined with one space.
    """
    fragments = []
    for span in ranges:
        fragment = text[span.start:span.end]
        fragments.append(fragment.replace("\n", " "))
    return " ".join(fragments)

def _parse_ann_spans(path: Path, raw: str, accept_labels: Optional[set[str]] = None) -> List[AnnSpan]:
    """
    Parse brat T-lines of the form  T1<TAB>ADR 10 20;30 40<TAB>text...

    Keeps multi-range spans; the span text is pulled from *raw* using the
    offsets, not from the third column of the annotation line.

    Args:
        path: annotation file; a missing file yields an empty list.
        raw: document text the offsets refer to.
        accept_labels: when given and non-empty, only spans whose label is in
            this set are kept; None (or an empty set) keeps every label.

    Returns:
        One AnnSpan per well-formed T-line, in file order. Malformed lines
        (no tab, blank header, no valid "start end" pair) are skipped.
    """
    spans: List[AnnSpan] = []
    if not path.exists():
        return spans
    for line in path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        # Only T-lines are entity spans; "TT" lines are mapped-code records
        # handled by _parse_sct_spans. Blank lines fail the first test too.
        if not line.startswith("T") or line.startswith("TT"):
            continue
        parts = line.split("\t")
        if len(parts) < 2:
            continue
        # Header is "LABEL s e; s e ...". Splitting once on whitespace keeps the
        # offsets aligned even with stray leading blanks, and a blank header is
        # skipped instead of raising IndexError.
        fields = parts[1].split(None, 1)
        if len(fields) < 2:
            continue
        label, offsets = fields
        if accept_labels and label not in accept_labels:
            continue
        rr: List[Range] = []
        for m in RANGE_RE.finditer(offsets):
            s, e = int(m.group(1)), int(m.group(2))
            if e > s:  # drop empty or inverted ranges
                rr.append(Range(s, e))
        if not rr:
            continue
        spans.append(AnnSpan(label=label, ranges=rr, text=_surface(raw, rr)))
    return spans

def _parse_sct_spans(path: Path, raw: str) -> List[SctSpan]:
    """
    Parse mapped-code TT-lines:  TT1<TAB><CODE> <start> <end>[; ...]<TAB><mapped term>

    Args:
        path: sct annotation file; a missing file yields an empty list.
        raw: document text, used to recover the surface string when the
            third (mapped term) column is missing or blank.

    Returns:
        One SctSpan per well-formed TT-line, in file order. Malformed lines
        (no tab, blank header, no valid "start end" pair) are skipped.
    """
    out: List[SctSpan] = []
    if not path.exists():
        return out
    for line in path.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line.startswith("TT"):
            continue
        parts = line.split("\t")
        if len(parts) < 2:
            continue
        # Header is "<CODE> s e [; s e ...]": first token is the code, the rest
        # carries the offsets. A blank header is skipped instead of raising
        # IndexError.
        fields = parts[1].split(None, 1)
        if len(fields) < 2:
            continue
        code, offsets = fields
        rr: List[Range] = []
        for m in RANGE_RE.finditer(offsets):
            s, e = int(m.group(1)), int(m.group(2))
            if e > s:  # drop empty or inverted ranges
                rr.append(Range(s, e))
        if not rr:
            continue
        mapped = parts[2].strip() if len(parts) >= 3 else ""
        if not mapped:
            # No term in the file: fall back to the text under the offsets.
            mapped = _surface(raw, rr)
        out.append(SctSpan(code=code, ranges=rr, text=mapped))
    return out

# -------- Overlap helpers --------
def _spans_overlap_len(ar: List[Range], br: List[Range]) -> int:
    """
    Return the total number of characters shared between two range lists.

    Every range in *ar* is intersected with every range in *br* and the
    lengths of the non-empty intersections are summed.
    """
    return sum(
        max(0, min(a.end, b.end) - max(a.start, b.start))
        for a in ar
        for b in br
    )

# -------- Join original ↔ sct by max overlap --------
def build_joined(original_spans: List[AnnSpan], sct_spans: List[SctSpan]) -> List[Joined]:
    """
    Pair every original annotation with the sct span it overlaps most.

    Overlap is measured in characters via _spans_overlap_len; on a tie the
    earliest sct span wins. An annotation with no overlapping sct span still
    produces a Joined row, with empty `code` and `standard_text`.
    """
    joined: List[Joined] = []
    for gold in original_spans:
        scored = [
            (_spans_overlap_len(gold.ranges, candidate.ranges), candidate)
            for candidate in sct_spans
        ]
        # max() returns the first maximal element, so ties go to the earliest span.
        top_overlap, top = max(scored, key=lambda pair: pair[0], default=(0, None))
        if top_overlap > 0:
            code, standard_text = top.code, top.text
        else:
            code, standard_text = "", ""
        joined.append(Joined(
            code=code,
            standard_text=standard_text,
            label_type=gold.label,
            gt_text=gold.text,
            gt_ranges=gold.ranges,
        ))
    return joined

# -------- Matching (fuzzy & embedding) --------
def _norm(s: str) -> str:
    """Collapse whitespace runs to single spaces, trim the ends, and casefold."""
    return re.sub(r"\s+", " ", s).strip().casefold()

def fuzzy_score(a: str, b: str) -> float:
    """
    Return a similarity score in 0..100 between two strings.

    Both inputs are normalized with _norm() first, so the score does not
    depend on case or whitespace regardless of which backend is used.
    Uses rapidfuzz's token_set_ratio when rapidfuzz is importable, otherwise
    difflib's SequenceMatcher ratio scaled to 0..100.
    """
    left, right = _norm(a), _norm(b)
    try:
        from rapidfuzz import fuzz
    except ImportError:
        # Optional dependency missing: fall back to the standard library.
        import difflib
        return 100.0 * difflib.SequenceMatcher(None, left, right).ratio()
    return float(fuzz.token_set_ratio(left, right))

def embed_model_loader(model_name: str):
    """
    Load and return a SentenceTransformer for *model_name*.

    Raises:
        RuntimeError: if sentence-transformers cannot be imported or the model
            cannot be constructed; the original exception is attached as
            __cause__.
    """
    try:
        from sentence_transformers import SentenceTransformer
        return SentenceTransformer(model_name)
    except Exception as ex:
        raise RuntimeError(
            f"Embedding model '{model_name}' not available.\n"
            f"Try: pip install sentence-transformers\nDetails: {ex}"
        ) from ex

def embed_vectors(model, texts: List[str]):
    """
    Encode *texts* with *model* and return whatever model.encode returns.

    The call requests normalized embeddings converted to numpy, so that a dot
    product between two rows can serve as their cosine similarity.
    """
    encode_options = {"normalize_embeddings": True, "convert_to_numpy": True}
    return model.encode(texts, **encode_options)

def cosine_sim_matrix(cat_vecs, q_vec):
    """
    Return the similarity of *q_vec* against every row of *cat_vecs*.

    Computed as the matrix product `cat_vecs @ q_vec`; this equals cosine
    similarity only when both inputs are already L2-normalized (the caller's
    responsibility).
    """
    return cat_vecs @ q_vec

# -------- Pretty print helpers --------
def fmt_ranges(rr: List[Range]) -> str:
    """Render ranges as "start-end" pairs separated by ";" (empty list -> "")."""
    pieces = []
    for span in rr:
        pieces.append("{}-{}".format(span.start, span.end))
    return ";".join(pieces)

def print_table(rows: List[List[str]]):
    """
    Print *rows* as a left-aligned, " | "-separated text table.

    The first row is treated as the header and is followed by a dashed
    separator line. Column widths come from the widest cell in each column.
    Rows may have different lengths: widths are computed over the longest
    row, and a short row simply prints fewer cells. Nothing is printed for
    an empty *rows*.
    """
    if not rows:
        return
    cells = [[str(value) for value in row] for row in rows]
    widths = [0] * max(len(row) for row in cells)
    for row in cells:
        for j, cell in enumerate(row):
            widths[j] = max(widths[j], len(cell))
    for i, row in enumerate(cells):
        print(" | ".join(cell.ljust(widths[j]) for j, cell in enumerate(row)))
        if i == 0:
            print("-+-".join("-" * w for w in widths))

# -------- Runner (call from next cell) --------
def run_task6_for_file(
    text_dir: str | Path,
    original_dir: str | Path,
    sct_dir: str | Path,
    predicted_dir: str | Path,
    file_basename: str,                           # e.g., "ARTHROTEC.24" (with or w/o .txt/.ann)
    embed_model_name: Optional[str] = "sentence-transformers/all-MiniLM-L6-v2",
    topn: int = 3,                                # kept for parity; we currently show top-1 per method
    min_fuzzy: float = 60.0,
    min_cos: float = 0.35,
):
    """
    Assign codes to the predicted ADR spans of one document and print a report.

    The original annotations are joined to the sct spans by maximal character
    overlap; the coded ADR rows of that join form the catalog. Every predicted
    ADR span is then matched against the catalog's standard texts twice:
    by fuzzy string score and by embedding cosine similarity (top-1 each).

    Args:
        text_dir: directory holding "<base>.txt".
        original_dir: directory holding the original "<base>.ann".
        sct_dir: directory holding the mapped-code "<base>.ann".
        predicted_dir: directory holding the predicted "<base>.ann".
        file_basename: document name, with or without a .txt/.ann suffix.
        embed_model_name: sentence-transformers model name; None or "" skips
            the embedding match entirely.
        topn: accepted for parity with other tasks; currently unused.
        min_fuzzy: minimum fuzzy score (0..100) for a fuzzy code to be assigned.
        min_cos: minimum cosine similarity for an embedding code to be assigned.

    Raises:
        FileNotFoundError: if the text file does not exist. Missing annotation
            files are treated as empty.

    Returns:
        None; all results are printed.
    """
    base = file_basename
    if base.lower().endswith((".txt", ".ann")):
        base = base[:-4]

    text_path = Path(text_dir) / f"{base}.txt"
    orig_path = Path(original_dir) / f"{base}.ann"
    sct_path  = Path(sct_dir) / f"{base}.ann"
    pred_path = Path(predicted_dir) / f"{base}.ann"

    if not text_path.exists():
        raise FileNotFoundError(text_path)
    raw = _read_text(text_path)

    # 1) Parse
    original_spans   = _parse_ann_spans(orig_path, raw, accept_labels=None)   # all labels
    sct_spans        = _parse_sct_spans(sct_path, raw)
    predicted_spans  = _parse_ann_spans(pred_path, raw, accept_labels={"ADR"})  # only ADR

    # 2) Join original ↔ sct by overlap
    joined = build_joined(original_spans, sct_spans)

    # 3) ADR catalog from joined (code + standardized text)
    catalog = [(j.code, j.standard_text, j.gt_text, j.label_type) for j in joined if j.label_type == "ADR" and j.code]
    if not catalog:
        print("No ADR-coded entries found in sct↔original join for this file.")
    else:
        print(f"[Info] ADR catalog size: {len(catalog)}")

    # 4) Prepare embedding model & vectors for catalog standard_text
    emb_model = None
    emb_cat = None
    if catalog and not embed_model_name:
        # Caller asked for no embeddings: skip cleanly instead of trying to
        # load a model named None and falling into the warning path.
        print("[Info] No embedding model name given; skipping embedding match.")
    elif catalog:
        try:
            emb_model = embed_model_loader(embed_model_name)
            emb_cat = embed_vectors(emb_model, [c[1] for c in catalog])  # embeddings of standard_text
        except Exception as ex:
            # Best effort: the report still shows the fuzzy columns.
            print(f"[WARN] Embedding model unavailable: {ex}")
            emb_model = None
            emb_cat = None

    # 5) Compare for each predicted ADR (top-1 fuzzy & top-1 embedding)
    header = [
        "Pred ADR text",
        "Fuzzy code", "Fuzzy std text", "Fuzzy score",
        "Embed code", "Embed std text", "Cosine"
    ]
    rows = [header]
    agree = 0
    total = 0

    for p in predicted_spans:
        total += 1
        ptxt = p.text

        # (a) fuzzy best; strict ">" keeps the earliest catalog entry on ties
        best_f = (-1.0, "", "")  # (score, code, std_text)
        for code, std_text, _gt_text, _label in catalog:
            sc = fuzzy_score(ptxt, std_text)
            if sc > best_f[0]:
                best_f = (sc, code, std_text)
        fuzzy_code, fuzzy_txt, fuzzy_sc = "", "", 0.0
        if best_f[0] >= min_fuzzy:
            fuzzy_sc, fuzzy_code, fuzzy_txt = best_f[0], best_f[1], best_f[2]

        # (b) embedding best
        embed_code, embed_txt, embed_cos = "", "", 0.0
        if emb_model is not None and emb_cat is not None and len(catalog) > 0:
            v = embed_vectors(emb_model, [ptxt])[0]      # shape (d,)
            sims = cosine_sim_matrix(emb_cat, v)         # shape (N,)
            idx = int(sims.argmax())
            cs = float(sims[idx])
            if cs >= min_cos:
                embed_code, embed_txt, embed_cos = catalog[idx][0], catalog[idx][1], cs

        rows.append([
            ptxt,
            fuzzy_code, fuzzy_txt, f"{fuzzy_sc:.1f}",
            embed_code, embed_txt, f"{embed_cos:.3f}"
        ])

        # Agreement requires both methods to assign a code, and the same one.
        if fuzzy_code and embed_code and (fuzzy_code == embed_code):
            agree += 1

    print("\n=== ADR Code Assignment (Predicted ADRs) ===")
    print_table(rows)

    if total > 0:
        print(f"\nAgreement (fuzzy vs embedding) on assigned code: {agree}/{total} = {agree/total:.2%}")
    else:
        print("\nNo predicted ADR spans found for this file.")

    # 6) Show joined catalog for transparency (coded ADR rows only)
    print("\n=== Joined catalog (original ↔ sct) for this file ===")
    cat_rows = [["Code", "Standard Text (from sct)", "Label", "GT Text", "GT Ranges"]]
    for j in joined:
        if j.label_type != "ADR" or not j.code:
            continue
        cat_rows.append([j.code, j.standard_text, j.label_type, j.gt_text, fmt_ranges(j.gt_ranges)])
    print_table(cat_rows)


In [4]:
# Batch Task 6 runner for a few files

# Basenames to process. Entries may carry a trailing "." or a .txt/.ann suffix;
# _clean_base() below normalizes them before each run.
FILES = ["ARTHROTEC.6.", "ARTHROTEC.7", "LIPITOR.344", "VOLTAREN.10", "ARTHROTEC.76"]

# Paths (edit if yours are different)
# NOTE(review): these are absolute paths on one machine; every other user must edit them.
TEXT_DIR = "/Users/anjalikulkarni/Desktop/Assignment1/CADEC-lPWNPfjE-/data/cadec/text"
ORIG_DIR = "/Users/anjalikulkarni/Desktop/Assignment1/CADEC-lPWNPfjE-/data/cadec/original"
SCT_DIR  = "/Users/anjalikulkarni/Desktop/Assignment1/CADEC-lPWNPfjE-/data/cadec/sct"
PRED_DIR = "/Users/anjalikulkarni/Desktop/Assignment1/predicted"

# Options
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # or set to None to skip embeddings
MIN_FUZZY   = 60.0   # passed as min_fuzzy: lowest fuzzy score (0..100) that still assigns a code
MIN_COS     = 0.35   # passed as min_cos: lowest cosine similarity that still assigns a code

def _clean_base(b: str) -> str:
    """
    Normalize a file basename: trim whitespace, drop trailing dots and one
    optional .txt/.ann suffix (matched case-insensitively).

    Trailing dots are removed both before and after the suffix check, so
    "NAME.txt." and "NAME..txt" both reduce to "NAME".
    """
    b = b.strip().rstrip(".")
    if b.lower().endswith((".txt", ".ann")):
        b = b[:-4]
    return b.rstrip(".")

# Run Task 6 for every configured file. A failure in one file is reported and
# the loop moves on, so a single bad document does not stop the batch.
for raw_name in FILES:
    base = _clean_base(raw_name)
    print(f"\n{'=' * 90}")
    print(f"### Processing: {base}")
    print("=" * 90)
    try:
        run_task6_for_file(
            TEXT_DIR,
            ORIG_DIR,
            SCT_DIR,
            PRED_DIR,
            base,
            embed_model_name=EMBED_MODEL,
            topn=3,
            min_fuzzy=MIN_FUZZY,
            min_cos=MIN_COS,
        )
    except FileNotFoundError as missing:
        # Raised by run_task6_for_file when the .txt file is absent.
        print(f"[SKIP] Missing required file for {base}: {missing}")
    except Exception as failure:
        print(f"[ERROR] {base}: {failure}")



### Processing: ARTHROTEC.6
[Info] ADR catalog size: 9

=== ADR Code Assignment (Predicted ADRs) ===
Pred ADR text    | Fuzzy code | Fuzzy std text            | Fuzzy score | Embed code | Embed std text   | Cosine
-----------------+------------+---------------------------+-------------+------------+------------------+-------
stomach pain     | 271681002  | stomach pain              | 100.0       | 271681002  | stomach pain     | 1.000 
slight nausea    | 422587007  | slight nausea             | 100.0       | 422587007  | slight nausea    | 1.000 
cramps           | 9991008    | abdominal cramps and pain | 100.0       | 9991008    | abdominal cramps | 0.826 
abdominal cramps | 9991008    | abdominal cramps and pain | 100.0       | 9991008    | abdominal cramps | 1.000 
pain relief      |            |                           | 0.0         | 271681002  | stomach pain     | 0.433 

Agreement (fuzzy vs embedding) on assigned code: 4/5 = 80.00%

=== Joined catalog (original ↔ sct) for thi