In [None]:
# Run these commands once to install PyTorch and the other dependencies
# conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
# pip install ftfy regex tqdm
# pip install git+https://github.com/openai/CLIP.git

In [7]:
import os
import clip
import torch
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

In [9]:
from torchvision.datasets import CIFAR100

# Load the model
# Pick the best available accelerator: CUDA first, then Apple MPS, then CPU.
# `torch.backends.mps` does not exist on older PyTorch builds (such as the
# 1.7.1 pin in the install cell), so probe for it with getattr instead of
# accessing the attribute directly.
mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif mps_backend is not None and mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# NOTE: the selected device is used as-is. A second assignment used to
# overwrite it with a CUDA-or-CPU choice, which made the MPS branch dead code.
model, preprocess = clip.load('ViT-B/32', device)

# Fetch the CIFAR-100 test split, cached under the user's ~/.cache directory
cifar100 = CIFAR100(
    root=os.path.expanduser("~/.cache"),
    train=False,
    download=True,
)

# Build the model inputs: one preprocessed image plus one text prompt per class
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
prompt_tokens = [clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]
text_inputs = torch.cat(prompt_tokens).to(device)

# Encode both modalities without tracking gradients
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# L2-normalise in place, then turn scaled similarities into a distribution
# over the class prompts and keep the five most likely labels
image_features.div_(image_features.norm(dim=-1, keepdim=True))
text_features.div_(text_features.norm(dim=-1, keepdim=True))
logits = (100.0 * image_features) @ text_features.T
similarity = logits.softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Report label and probability (as a percentage) for each of the top matches
print("\nTop predictions:\n")
for rank in range(len(values)):
    label = cifar100.classes[indices[rank]]
    print(f"{label:>16s}: {100 * values[rank].item():.2f}%")


Top predictions:

           snake: 65.31%
          turtle: 12.29%
    sweet_pepper: 3.83%
          lizard: 1.88%
       crocodile: 1.75%


## Data preprocessing

In [11]:
# Build a robust preprocessing pipeline for chest X-ray reports into CLIP-friendly captions
import re
import json
import pandas as pd

# Read the report table; the cells below use its "text", "id" and "path" columns
df = pd.read_csv(filepath_or_buffer="dataset/df_train_partial_10000.csv")

# -------------------- helper regexes --------------------

# Split after sentence-final punctuation followed by whitespace.
SENTENCE_SPLIT = re.compile(r"(?<=[\.\?\!])\s+")
# Sentences mentioning these body regions are dropped as non-chest content.
# "abdom(?:en|inal)" covers both "abdomen" and "abdominal"; the previous
# "abdomen(al)?" only matched "abdomen" and the misspelling "abdomenal".
NON_CHEST_TERMS = re.compile(r"\b(bowel gas|abdom(?:en|inal)|pelvis|shoulder|humerus|knee|hip)\b", re.I)

# Phrases to drop entirely (cross-modal, temporal, recommendations, uncertainty boilerplate).
# A sentence is discarded as soon as any one of these matches anywhere in it.
DROP_PATTERNS = [
    r"\bcompared to\b.*",                     # comparisons
    r"\bunchanged since\b.*",
    r"\bsince\s+\d{2,4}\b.*",
    r"\bprior\b.*",                           # generic "prior" sentence
    r"\bprevious\b.*",
    r"\bbetter (seen|demonstrated) on\b.*",
    r"\b(read in conjunction|correlate clinically|clinical correlation)\b.*",
    r"\bif (clinically|otherwise) indicated\b.*",
    r"\bfollow[- ]?up\b.*",
    r"\bmanagement recommendations?\b.*",
    r"\bCT\b.*",                              # cross-modal references
    r"\bMRI\b.*",
    r"\bultrasound\b.*",
    r"\bechocardiogra(?:m|phy)\b.*",          # echocardiogram / echocardiography
    r"\bPET[- ]?CT\b.*",
    r"\bfluoroscopy\b.*",
    r"___",                                   # placeholder blanks
    r"\bportable AP\b.*",                     # technical headers
    r"\bAP chest\b.*",
    r"\bPA and lateral\b.*",
]

# Light rewrites / normalizations, applied in order with re.I to every kept sentence.
REWRITES = [
    (r"\bcardiac and mediastinal contours\b", "cardiomediastinal contours"),
    (r"\bPort[\-\s]?A[\-\s]?Cath\b", "Port-A-Cath"),
    (r"\bETT\b", "endotracheal tube"),
    (r"\bNG[-\s]?tube\b|\bnasogastric tube\b", "nasogastric tube"),
    # Consume an existing trailing "line" so "PICC line" is not turned into "PICC line line".
    (r"\bPICC(?:\s+line)?\b", "PICC line"),
    (r"\bCVC\b", "central venous catheter"),
    (r"\bchest[-\s]?port\b|\bmediport\b", "Port-A-Cath"),
    (r"\bcardiomediastinal silhouette\b", "cardiomediastinal contours"),
    (r"\bhemi[-\s]?thorax\b", "hemithorax"),
    # Same text on both sides: under re.I this only normalizes the casing.
    (r"\bno acute cardiopulmonary abnormality\b", "no acute cardiopulmonary abnormality"),
]

# Simple clinical finding extraction with negation detection.
# Maps a tag name to the regex that detects a mention of that finding.
FINDINGS = {
    "pleural_effusion": r"\bpleural (?:fluid|effusions?)\b|\beffusion(s)?\b",
    "pneumothorax": r"\bpneumothorax\b",
    "consolidation": r"\bconsolidation(s)?\b",
    "atelectasis": r"\batelectasis\b|\bplate (?:like|atelectatic)\b",
    "pulmonary_edema": r"\b(edema|oedema)\b",
    "cardiomegaly": r"\bcardiomegaly\b|\benlarged (cardiac|heart)\b",
    "pleural_thickening": r"\bpleural thickening\b",
    "nodular_opacities": r"\bnodular opacit(?:y|ies)\b|\bnodules?\b",
    "pacemaker": r"\bpacemaker\b|\bICD\b",
    "endotracheal_tube": r"\bendotracheal tube\b",
    "nasogastric_tube": r"\bnasogastric tube\b",
    "picc_cvc": r"\bPICC line\b|\bcentral venous catheter\b|\bcentral line\b",
    "port_a_cath": r"\bPort-A-Cath\b",
    # Accept "is" as well as "are": the silhouette -> contours rewrite above
    # turns "silhouette is normal" into "contours is normal".
    "cardiomediastinal_normal": r"\bcardiomediastinal (?:contours|silhouette) (?:is |are )?normal\b|\bheart size (?:is )?normal\b",
}

# Cue words/phrases that mark a following finding as negated.
NEGATION = r"\b(no|without|absent|free of|negative for)\b"

def has_finding(text, pattern):
    """Return True if `pattern` occurs in `text` at least once without a preceding negation.

    For each case-insensitive match of `pattern`, the up-to-50 characters
    immediately before the match are inspected, restricted to the sentence the
    match sits in. The match counts as positive when that window contains no
    NEGATION cue.

    Args:
        text: Report text to scan.
        pattern: Regex source string for the finding.

    Returns:
        True if any occurrence is not negated; False if there is no occurrence
        or every occurrence is negated.
    """
    for m in re.finditer(pattern, text, flags=re.I):
        window = text[max(0, m.start() - 50):m.start()]
        # A negation in an earlier sentence ("No pneumothorax. Pleural effusion ...")
        # must not negate this match, so keep only the text after the last
        # sentence break (same boundary rule as the sentence splitter).
        window = re.split(r"(?<=[.?!])\s+", window)[-1]
        if not re.search(NEGATION, window, flags=re.I):
            return True
    return False

def is_negated(text, pattern):
    """Return True if some occurrence of `pattern` follows a NEGATION cue in the same sentence.

    The cue must be followed by at most 50 characters before the finding, which
    matches the look-back distance used by `has_finding`. Those characters may
    not include a sentence break (".", "?" or "!" followed by whitespace).

    Args:
        text: Report text to scan.
        pattern: Regex source string for the finding; may contain top-level "|".

    Returns:
        True if a negated mention exists, otherwise False.
    """
    # Any character that does not start a sentence break.
    same_sentence_gap = r"(?:(?![.?!]\s).){0,50}"
    # Group the finding pattern: without the group, a top-level "|" inside it
    # would split the whole expression and let a bare, un-negated mention match.
    negated = NEGATION + same_sentence_gap + "(?:" + pattern + ")"
    return bool(re.search(negated, text, flags=re.I))

def clean_sentences(text):
    """Split a report into sentences, discard unwanted ones and normalise the rest.

    Whitespace runs are collapsed to single spaces before splitting with
    SENTENCE_SPLIT. A sentence is discarded when it is empty, mentions a
    NON_CHEST_TERMS region, or matches any DROP_PATTERNS entry. Surviving
    sentences have every REWRITES substitution applied, in order.

    Returns:
        List of cleaned sentence strings, in their original order.
    """
    collapsed = re.sub(r"\s+", " ", text.strip())
    if not collapsed:
        return []

    result = []
    for raw in SENTENCE_SPLIT.split(collapsed):
        sentence = raw.strip()
        unwanted = (
            not sentence
            or NON_CHEST_TERMS.search(sentence) is not None
            or any(re.search(drop, sentence, flags=re.I) for drop in DROP_PATTERNS)
        )
        if unwanted:
            continue
        # Normalise terminology on the sentences we keep
        for pattern, replacement in REWRITES:
            sentence = re.sub(pattern, replacement, sentence, flags=re.I)
        result.append(sentence)
    return result

def compress_to_caption(sents, max_tokens=70):
    """Join sentences into one caption, truncated to `max_tokens` whitespace-separated words.

    Args:
        sents: Iterable of sentence strings.
        max_tokens: Maximum number of whitespace-delimited words to keep.
            NOTE(review): this counts words, not CLIP tokenizer tokens, so
            confirm the result fits the model's context length.

    Returns:
        The space-joined sentences unchanged when they fit. Otherwise the first
        `max_tokens` words terminated by exactly one period, or an empty string
        when no word can be kept.
    """
    caption = " ".join(sents)
    tokens = caption.split()
    if len(tokens) <= max_tokens:
        return caption
    kept = tokens[:max(max_tokens, 0)]
    if not kept:
        # Nothing to keep: return an empty caption rather than a lone ".".
        return ""
    # Strip punctuation already ending the cut so we do not emit "effusion..".
    return " ".join(kept).rstrip(".,;:!?") + "."

def concise_caption(sents, max_tokens=35):
    """Build a short caption that puts key-finding sentences first.

    Sentences mentioning a device or one of the listed findings are moved to
    the front, keeping their relative order. The remaining sentences follow,
    and the result is truncated by `compress_to_caption`.

    Args:
        sents: List of cleaned sentence strings.
        max_tokens: Word budget forwarded to `compress_to_caption`.

    Returns:
        The reordered, truncated caption string.
    """
    # "line" and "tube" are anchored on word boundaries so that words such as
    # "midline" or "baseline" do not promote a sentence as a device mention.
    key_patterns = [
        r"Port-A-Cath", r"pleural", r"pneumothorax", r"consolidation", r"atelectasis",
        r"edema|oedema", r"cardiomediastinal", r"pacemaker", r"\btubes?\b", r"\blines?\b",
        r"nodule|nodular",
    ]
    # One combined case-insensitive regex, built once instead of once per sentence.
    key_regex = re.compile("|".join(f"(?:{k})" for k in key_patterns), flags=re.I)

    priority = [s for s in sents if key_regex.search(s)]
    others = [s for s in sents if not key_regex.search(s)]
    return compress_to_caption(priority + others, max_tokens=max_tokens)

def extract_tags(text):
    """Label each finding in FINDINGS as "present" or "absent" for `text`.

    A finding is "present" when `has_finding` reports it, "absent" when it is
    not present but `is_negated` reports a negated mention, and is left out
    of the result otherwise.

    Returns:
        Dict mapping finding name to "present" or "absent", in FINDINGS order.
    """
    def status(pat):
        found = has_finding(text, pat)
        negated = is_negated(text, pat)
        if found:
            return "present"
        return "absent" if negated else None

    labelled = ((name, status(pat)) for name, pat in FINDINGS.items())
    return {name: label for name, label in labelled if label is not None}

def preprocess_report(text):
    """Run one raw report through the whole pipeline.

    A falsy `text` (None or "") is treated as an empty report.

    Returns:
        Tuple of (long caption capped at 70 words, concise caption capped at
        35 words, JSON string of the finding tags extracted from the cleaned
        sentences).
    """
    sentences = clean_sentences(text if text else "")
    cleaned_text = " ".join(sentences)
    tags_json = json.dumps(extract_tags(cleaned_text), ensure_ascii=False)
    return (
        compress_to_caption(sentences, max_tokens=70),
        concise_caption(sentences, max_tokens=35),
        tags_json,
    )

# Run every report (missing text treated as "") through the pipeline
out = df.copy()
processed = [preprocess_report(report) for report in out["text"].fillna("").tolist()]

# Unpack the (long caption, concise caption, tags JSON) triples column by column
cap_long_list = [long_caption for long_caption, _, _ in processed]
cap_short_list = [short_caption for _, short_caption, _ in processed]
tags_list = [tags_json for _, _, tags_json in processed]

out = out.assign(
    caption=cap_long_list,
    caption_concise=cap_short_list,
    tags_json=tags_list,
)

# Persist the augmented table without the index column
save_path = "dataset/df_train_preprocessed_clip_labels.csv"
out.to_csv(save_path, index=False)

# First 15 rows of the key columns; built here but not displayed by this cell
preview = out.loc[:, ["id", "path", "caption", "caption_concise", "tags_json"]].head(15)

# The cell's displayed value is the output path
save_path


'dataset/df_train_preprocessed_clip_labels.csv'