In [1]:
# Quote the version specifiers: unquoted, the shell parses ">" as output
# redirection, so torch>=2.1.0 would install an unconstrained torch and send
# pip's stdout into files named "=2.1.0" / "=4.39.0".
!pip install -q \
  "torch>=2.1.0" \
  "transformers>=4.39.0" \
  appdirs \
  jsonpickle \
  filelock \
  h5py \
  nltk \
  dotmap \
  pytest


In [2]:
# Pinned to the radgraph release this notebook's results were produced with.
!pip install -q radgraph==0.1.18

Collecting radgraph
  Downloading radgraph-0.1.18.tar.gz (587 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/588.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━[0m [32m307.2/588.0 kB[0m [31m9.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m588.0/588.0 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: radgraph
  Building wheel for radgraph (setup.py) ... [?25l[?25hdone
  Created wheel for radgraph: filename=radgraph-0.1.18-py3-none-any.whl size=812635 sha256=2e293d289b3d26e970c220839ec84b4615fbde84abcdf9fbaca061c6f174eb6a
  Stored in directory: /root/.cache/pip/wheels/fb/3c/fb/214f5d5cdab2a0f9f0904fd81d7fd1134404100b4444554df8
Successfully built radgraph
Installing collected packages: radgraph
Successfully installed radgraph-0.1.18


In [3]:
import json
from radgraph import get_radgraph_processed_annotations, RadGraph

  return datetime.utcnow().replace(tzinfo=utc)


In [4]:
import pandas as pd
import ast
import re


In [5]:
def clean_text(x):
    """Render a RadGraph annotation field as plain text.

    Falsy values (None, "", [], 0, ...) become "". A non-empty list is joined
    with ", ". Anything else is converted with str().
    """
    if isinstance(x, list) and x:
        return ", ".join(x)
    return str(x) if x else ""


def annotation_to_sentence(annotation):
    """Turn one RadGraph processed annotation into a short English sentence.

    Reads the "observation", "located_at", "suggestive_of" and "tags" keys of
    ``annotation``; missing keys are treated as empty. Only the first tag is
    used, with underscores replaced by spaces.

    * Tag "definitely absent" with an observation: "No <obs>[ in the <loc>]."
    * Otherwise: "<Obs>[ in the <loc>][ suggestive of <sug>]."
      (an absent tag without an observation also takes this path).

    Returns:
        The sentence, or "" when no field has content.
    """
    obs = clean_text(annotation.get("observation"))
    loc = clean_text(annotation.get("located_at"))
    sug = clean_text(annotation.get("suggestive_of"))

    # Guard against an empty or None tag list (indexing [0] would raise) and
    # a None first tag; all of these mean "no tag".
    tags = annotation.get("tags") or [""]
    tag = (tags[0] or "").replace("_", " ")

    # ABSENT findings are negated; location is kept, suggestive_of is ignored.
    if tag == "definitely absent" and obs:
        return f"No {obs} in the {loc}." if loc else f"No {obs}."

    # PRESENT (or uncertain / unknown) findings.
    parts = []
    if obs:
        parts.append(obs.capitalize())
    if loc:
        parts.append(f"in the {loc}")
    if sug:
        parts.append(f"suggestive of {sug}")

    sentence = " ".join(parts).strip()
    if sentence and not sentence.endswith("."):
        sentence += "."
    return sentence


In [7]:

# Evaluation slice of output.csv (rows 33333..33733, 401 reports). One constant
# pair keeps the report column and the concept column aligned.
# NOTE(review): assumes column 0 is report text and column 1 a Python-literal
# list of ground-truth concept strings — confirm against the CSV schema.
EVAL_START, EVAL_STOP = 33333, 33734

df = pd.read_csv("/content/output.csv")
eval_rows = df.iloc[EVAL_START:EVAL_STOP]

clean = (
    eval_rows.iloc[:, 0]
      .astype(str)
      .str.replace(r'^FINDINGS:\s*', '', regex=True)
      .str.replace(r"\s+", " ", regex=True)
      .str.strip()
      .reset_index(drop=True)
)

concepts = eval_rows.iloc[:, 1].reset_index(drop=True)

eval_df = pd.DataFrame({"report": clean, "concepts": concepts})

model_type = "modern-radgraph-xl"
radgraph = RadGraph(model_type=model_type)
reports = eval_df["report"].astype(str).tolist()

pred_concepts = []
for report in reports:
    # Annotate a single-report batch, then flatten to processed annotations.
    annotations = radgraph([report])
    processed = get_radgraph_processed_annotations(annotations)

    sents = [
        s
        for s in (annotation_to_sentence(ann) for ann in processed["processed_annotations"])
        if s
    ]
    # dict.fromkeys de-duplicates while preserving first-seen order.
    pred_concepts.append(list(dict.fromkeys(sents)))

# Ground-truth concept lists are stored as Python literals; missing cells -> [].
gt_concepts = [
    ast.literal_eval(x) if pd.notna(x) else []
    for x in eval_df["concepts"].tolist()
]

# Downstream metrics zip these lists together; a length mismatch would be
# silently truncated there.
assert len(pred_concepts) == len(gt_concepts) == len(reports)


Using device: cpu


  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)


In [8]:
def canonicalize_concept(s: str, min_words: int = 2) -> str:
    """Normalize a concept string so GT and predicted phrasings compare equal.

    Steps:
    1. Lowercase, collapse whitespace and strip trailing periods.
    2. Map "top - normal", "within normal limits" and "unremarkable" to "normal".
    3. Rewrite the "... in the <loc>" templates into GT word order:
         "no effusion in the pleural" -> "no pleural effusion"
         "calcified in the aorta"     -> "calcified aorta"

    Args:
        s: Concept text; non-str values are converted with str().
        min_words: Results with fewer words than this are returned as "".
            The default of 2 drops single-word concepts, which are usually
            observations with no location. Note that this also drops genuine
            one-word findings (e.g. "cardiomegaly") from both GT and
            predictions. Pass 1 to keep them.

    Returns:
        The canonical string, or "" if it was filtered out.
    """
    s = str(s).strip().lower()
    s = re.sub(r"\s+", " ", s)
    s = s.rstrip(".")
    # Collapse the different "normal" phrasings onto one word.
    s = s.replace("top - normal", "normal")
    s = s.replace("within normal limits", "normal")
    s = s.replace("unremarkable", "normal")

    # "no effusion in the pleural" -> "no pleural effusion"
    m = re.match(r"^no (.+?) in the (.+)$", s)
    if m:
        obs, loc = m.group(1), m.group(2)
        s = f"no {loc} {obs}"

    # "calcified in the aorta" -> "calcified aorta".
    # This also runs on the output of the rewrite above, but only changes it
    # if another " in the " is still present after reordering.
    m = re.match(r"^(.+?) in the (.+)$", s)
    if m:
        obs, loc = m.group(1), m.group(2)
        s = f"{obs} {loc}"

    s = re.sub(r"\s+", " ", s).strip()

    # Drop too-short leftovers (empty strings always fall below min_words >= 1).
    if len(s.split()) < min_words:
        return ""

    return s

# Canonicalize every report's concept list, dropping entries that
# canonicalize to "" (filtered-out or empty concepts).
gt_norm = [
    [concept for concept in map(canonicalize_concept, xs) if concept]
    for xs in gt_concepts
]
pred_norm = [
    [concept for concept in map(canonicalize_concept, xs) if concept]
    for xs in pred_concepts
]



In [30]:
# Spot-check hand-picked reports: raw text, its index, then the GT concept
# list above the predicted concept list.
# NOTE(review): this writes raw report text into the saved notebook output —
# check the dataset's data-use terms before sharing.
SPOT_CHECK_IDX = (34, 43, 25, 38, 22, 79, 223, 185, 113, 83, 67)
for idx in SPOT_CHECK_IDX:
    gt_view, pred_view = gt_norm[idx], pred_norm[idx]
    print(reports[idx])
    print(f"{idx}  :")
    print(f"{gt_view} \n {pred_view} \n")


The lungs are hyperexpanded, but clear. There is no pleural abnormality. The cardiac and mediastinal silhouettes are unremarkable. Multiple rib deformities with callus formation is again seen.
34  :
['hyperexpanded lungs', 'clear lungs', 'no pleural abnormality', 'normal cardiac silhouette', 'normal mediastinal silhouette', 'rib deformities with callus formation'] 
 ['hyperexpanded lungs', 'clear lungs', 'no pleural abnormality', 'normal cardiac silhouettes, mediastinal silhouettes', 'multiple deformities rib', 'callus formation'] 

The patient is status post median sternotomy and CABG. The heart size is normal. The mediastinal and hilar contours are unremarkable. The pulmonary vasculature is normal in the lungs are clear. No focal consolidation, pleural effusion or pneumothorax is visualized. There are no acute osseous abnormalities. Partially imaged is cervical spinal fusion hardware.
43  :
['status post median sternotomy and cabg', 'normal heart size', 'normal mediastinal and hilar 

In [11]:
def similarity1(a: str, b: str) -> float:
    """Jaccard similarity |A∩B| / |A∪B| of the lowercased whitespace tokens.

    Returns 1.0 when both strings have no tokens and 0.0 when exactly one
    of them has none.
    """
    left, right = (set(text.lower().split()) for text in (a, b))
    if not (left or right):
        return 1.0
    if not (left and right):
        return 0.0
    return len(left & right) / len(left | right)


In [12]:
def similarity2(a: str, b: str) -> float:
    """Sørensen–Dice similarity 2|A∩B| / (|A|+|B|) of lowercased token sets.

    Tokens are produced by whitespace splitting after lowercasing.

    Returns:
        1.0 when both strings have no tokens, 0.0 when exactly one has none,
        otherwise a value in [0, 1].
    """
    a_tokens = set(a.lower().split())
    b_tokens = set(b.lower().split())

    if not a_tokens and not b_tokens:
        return 1.0
    if not a_tokens or not b_tokens:
        return 0.0

    # Both sets are non-empty here, so the denominator is positive.
    return (2 * len(a_tokens & b_tokens)) / (len(a_tokens) + len(b_tokens))


In [13]:
def fuzzy_counts_one(gt_list, pred_list, threshold, similarity=None):
    """Greedily match one report's predicted concepts to its GT concepts.

    Predictions are processed in order. Each is paired with the still-unused
    GT concept it scores highest against (strictly positive score), and the
    pair counts as a true positive if that score is >= threshold. Each GT
    concept is used at most once. Matching is greedy, not an optimal
    assignment, so results can depend on prediction order.

    Args:
        gt_list: Ground-truth concept strings; blank entries are ignored.
        pred_list: Predicted concept strings; blank entries are ignored.
        threshold: Minimum similarity for a match.
        similarity: Optional callable (pred, gt) -> float. Defaults to
            similarity2 (token Dice). Pass e.g. similarity1 for Jaccard.

    Returns:
        (TP, FP, FN): matched predictions, unmatched predictions, and
        unmatched GT concepts.
    """
    # Resolved lazily so the default keeps tracking the notebook's similarity2.
    sim = similarity2 if similarity is None else similarity

    gt = [g.strip().lower() for g in gt_list if g.strip()]
    pr = [p.strip().lower() for p in pred_list if p.strip()]

    used_gt = set()
    tp = 0
    for p in pr:
        best_j, best_score = None, 0.0
        for j, g in enumerate(gt):
            if j in used_gt:
                continue
            score = sim(p, g)
            if score > best_score:
                best_j, best_score = j, score

        if best_j is not None and best_score >= threshold:
            tp += 1
            used_gt.add(best_j)

    return tp, len(pr) - tp, len(gt) - tp


In [14]:
def fuzzy_prf(gt_norm, pred_norm, threshold):
    """Micro-averaged fuzzy-match precision, recall and F1 over paired reports.

    Per-report (TP, FP, FN) from fuzzy_counts_one are summed over
    zip(gt_norm, pred_norm), so extra items in the longer list are ignored.
    The metrics are then computed once from the totals. Any metric whose
    denominator is zero is reported as 0.

    Returns:
        (precision, recall, f1, (TP, FP, FN))
    """
    per_report = [
        fuzzy_counts_one(gt, pr, threshold) for gt, pr in zip(gt_norm, pred_norm)
    ]
    TP = sum(t for t, _, _ in per_report)
    FP = sum(f for _, f, _ in per_report)
    FN = sum(n for _, _, n in per_report)

    predicted, actual = TP + FP, TP + FN
    precision = TP / predicted if predicted else 0
    recall = TP / actual if actual else 0
    f1 = (2 * precision * recall / (precision + recall)) if precision + recall else 0

    return precision, recall, f1, (TP, FP, FN)


In [15]:
# Sweep the fuzzy-match threshold; each line prints: threshold precision recall F1.
FUZZY_THRESHOLDS = (0.7, 0.75, 0.8, 0.85, 0.9)
for th in FUZZY_THRESHOLDS:
    P, R, F1, counts = fuzzy_prf(gt_norm, pred_norm, th)
    print(th, P, R, F1)

0.7 0.6802120141342756 0.6787729196050776 0.6794917049064595
0.75 0.6646643109540636 0.6632581100141044 0.6639604659371692
0.8 0.6353356890459364 0.6339915373765868 0.6346629015178257
0.85 0.5565371024734982 0.5553596614950634 0.5559477585598306
0.9 0.49469964664310956 0.4936530324400564 0.494175785386516


## Exact-match evaluation (no fuzzy matching)

In [16]:
# Exact-match micro P/R/F1. Each report's concepts are compared as sets, so a
# concept repeated within one report is counted once.
_report_sets = [(set(g_list), set(p_list)) for g_list, p_list in zip(gt_norm, pred_norm)]
TP = sum(len(g & p) for g, p in _report_sets)
FP = sum(len(p - g) for g, p in _report_sets)
FN = sum(len(g - p) for g, p in _report_sets)

precision = TP / (TP + FP) if (TP + FP) else 0
recall = TP / (TP + FN) if (TP + FN) else 0
f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0

print("TP FP FN:", TP, FP, FN)
for label, value in (("Precision:", precision), ("Recall:", recall), ("F1:", f1)):
    print(label, value)


TP FP FN: 1125 1705 1711
Precision: 0.39752650176678445
Recall: 0.3966854724964739
F1: 0.3971055418284504


In [17]:
# Fraction of reports whose canonical concept set equals the GT set exactly
# (a report with no concepts on either side counts as correct).
subset_correct = sum(
    set(g_list) == set(p_list) for g_list, p_list in zip(gt_norm, pred_norm)
)

subset_accuracy = subset_correct / len(gt_norm) if len(gt_norm) else 0
print("Subset accuracy:", subset_accuracy)



Subset accuracy: 0.00997506234413965


In [18]:
# Hamming loss over a label space of every distinct concept string seen in GT
# or predictions. L is the size of that union (2102 in the recorded run), while
# each report carries only a handful of concepts, so the loss is tiny by
# construction — read it alongside P/R/F1 rather than on its own.
label_set = sorted(set(x for xs in gt_norm for x in xs) | set(x for xs in pred_norm for x in xs))
L = len(label_set)
N = len(gt_norm)

# A label mismatches for a report iff it is in exactly one of the two sets,
# i.e. it lies in the symmetric difference. label_set contains every such label
# by construction, so summing |g ^ p| equals the per-label XOR count without
# scanning all L labels per report. This also equals FP + FN of the exact-match
# evaluation.
mismatches = sum(
    len(set(g_list) ^ set(p_list)) for g_list, p_list in zip(gt_norm, pred_norm)
)

hamming_loss = mismatches / (N * L) if (N * L) else 0
print("Hamming loss:", hamming_loss)
print("N (reports):", N, "L (labels):", L)


Hamming loss: 0.004052665671691371
N (reports): 401 L (labels): 2102
