In [2]:
import os
from torch.utils.data import Dataset
from PIL import Image
import clip

class CLIPChestXrayDataset(Dataset):
    def __init__(self, image_dir, report_dir, preprocess, tokenizer, max_tokens=77):
        self.image_dir = image_dir
        self.report_dir = report_dir
        self.preprocess = preprocess
        self.tokenizer = tokenizer
        self.max_tokens = max_tokens

        # Match filenames
        self.image_files = sorted(os.listdir(image_dir))
        self.report_files = sorted(os.listdir(report_dir))

        # Extract common IDs (e.g., "CXR123_IM-xxxx-xxxx" from "123.xml")
        self.pairs = []
        for img_file in self.image_files:
            base_id = img_file.split("_")[0].replace("CXR", "")
            report_file = base_id + ".xml"
            if os.path.exists(os.path.join(report_dir, report_file)):
                self.pairs.append((img_file, report_file))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        img_file, report_file = self.pairs[idx]

        # Load image
        image_path = os.path.join(self.image_dir, img_file)
        image = self.preprocess(Image.open(image_path).convert("RGB"))

        # Load report text
        report_path = os.path.join(self.report_dir, report_file)
        text = extract_report_text(report_path)

        # Truncate or clean if needed
        tokenized = self.tokenizer([text], truncate=True)[0]

        return image, tokenized

# You can use this external function from before
def extract_report_text(xml_file):
    import xml.etree.ElementTree as ET
    tree = ET.parse(xml_file)
    root = tree.getroot()

    findings, impression = "", ""
    for child in root.findall(".//AbstractText"):
        label = child.get("Label", "").lower()
        if "impression" in label and child.text:
            impression = child.text.strip()
        elif "findings" in label and child.text:
            findings = child.text.strip()

    # Clean 'XXXX', prioritize impression
    full_text = f"{impression}. {findings}"
    cleaned = " ".join(word for word in full_text.split() if "XXXX" not in word)
    return cleaned[:512]  # Truncate long reports for safety


In [3]:
import torch
import clip
from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

train_dataset = CLIPChestXrayDataset(
    image_dir="C:\\Users\\2003j\\Downloads\\into_to_ml\\chest_reports\\split_data\\train\\images",
    report_dir="C:\\Users\\2003j\\Downloads\\into_to_ml\\chest_reports\\split_data\\train\\reports",
    preprocess=preprocess,
    tokenizer=clip.tokenize
)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

In [5]:
import torch.nn.functional as F
from tqdm import tqdm

def train_clip_model(train_loader, model, device, epochs=5, lr=1e-5):
    model.train()
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    for epoch in range(epochs):
        total_loss = 0
        progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        
        for batch in progress:
            images, texts = batch
            images = images.to(device)
            texts = texts.to(device)

            optimizer.zero_grad()

            # Get CLIP features
            image_features = model.encode_image(images)
            text_features = model.encode_text(texts)

            # Normalize features
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Compute cosine similarity
            logits_per_image = image_features @ text_features.t()
            logits_per_text = text_features @ image_features.t()

            # Labels are just indices (like [0, 1, 2, ..., batch_size - 1])
            labels = torch.arange(len(images)).to(device)

            # Cross entropy loss (both directions)
            loss_i = F.cross_entropy(logits_per_image, labels)
            loss_t = F.cross_entropy(logits_per_text, labels)
            loss = (loss_i + loss_t) / 2

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            progress.set_postfix(loss=loss.item())

        print(f"✅ Epoch {epoch+1} completed - Avg Loss: {total_loss / len(train_loader):.4f}")

    return model


In [6]:
fine_tuned_model = train_clip_model(train_loader, model, device=device)


Epoch 1/5: 100%|████████████████████████████████████████████████████████| 371/371 [1:42:55<00:00, 16.65s/it, loss=2.63]


✅ Epoch 1 completed - Avg Loss: 2.7438


Epoch 2/5: 100%|█████████████████████████████████████████████████████████| 371/371 [8:01:16<00:00, 77.84s/it, loss=2.5]


✅ Epoch 2 completed - Avg Loss: 2.6657


Epoch 3/5: 100%|████████████████████████████████████████████████████████| 371/371 [3:19:01<00:00, 32.19s/it, loss=2.32]


✅ Epoch 3 completed - Avg Loss: 2.5855


Epoch 4/5: 100%|████████████████████████████████████████████████████████| 371/371 [9:38:25<00:00, 93.55s/it, loss=2.56]


✅ Epoch 4 completed - Avg Loss: 2.5013


Epoch 5/5: 100%|██████████████████████████████████████████████████████████| 371/371 [43:00<00:00,  6.96s/it, loss=2.18]

✅ Epoch 5 completed - Avg Loss: 2.4144





In [11]:
# Test Dataset Setup
test_dataset = CLIPChestXrayDataset(
    image_dir="C:\\Users\\2003j\\Downloads\\into_to_ml\\chest_reports\\split_data\\test\\images",
    report_dir="C:\\Users\\2003j\\Downloads\\into_to_ml\\chest_reports\\split_data\\test\\reports",
    preprocess=preprocess,
    tokenizer=clip.tokenize
)

# Load all data into memory (small test set, so OK)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)


NameError: name 'CLIPChestXrayDataset' is not defined

In [10]:
import torch.nn.functional as F

def evaluate_clip_model(model, test_loader, device):
    model.eval()
    top1_correct = 0
    top3_correct = 0
    total = 0

    all_report_features = []

    print("🔍 Encoding all reports...")
    # Step 1: Encode all reports in test set (once)
    with torch.no_grad():
        for _, tokenized_text in test_loader:
            tokenized_text = tokenized_text.to(device)
            text_features = model.encode_text(tokenized_text)
            text_features = F.normalize(text_features, dim=-1)
            all_report_features.append(text_features)

    all_report_features = torch.cat(all_report_features, dim=0)  # Shape: (N, D)

    print("🖼 Matching images with reports...")
    with torch.no_grad():
        for idx, (image, tokenized_text) in enumerate(test_loader):
            image = image.to(device)
            image_features = model.encode_image(image)
            image_features = F.normalize(image_features, dim=-1)

            # Compute similarity with all report features
            similarity = (image_features @ all_report_features.T).squeeze(0)

            # Top-k matching indices
            top_k = similarity.topk(3).indices

            total += 1
            if idx == top_k[0].item():
                top1_correct += 1
            if idx in top_k:
                top3_correct += 1

    print(f"\n📊 Total Samples: {total}")
    print(f"🎯 Top-1 Accuracy: {top1_correct / total:.2f}")
    print(f"🔁 Top-3 Accuracy: {top3_correct / total:.2f}")


In [11]:
evaluate_clip_model(fine_tuned_model, test_loader, device)


🔍 Encoding all reports...
🖼 Matching images with reports...

📊 Total Samples: 743
🎯 Top-1 Accuracy: 0.00
🔁 Top-3 Accuracy: 0.01


In [4]:
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt

def inspect_clip_predictions(model, test_loader, device, num_samples=5):
    model.eval()

    image_list = []
    text_list = []

    print("🔍 Encoding reports...")
    with torch.no_grad():
        for image, tokenized in test_loader:
            image_list.append(image.to(device))
            text_list.append(tokenized.to(device))
            if len(image_list) >= num_samples:
                break

        # Encode all sampled reports
        all_text_features = torch.cat([
            F.normalize(model.encode_text(text), dim=-1)
            for text in text_list
        ], dim=0)

        # Encode all sampled images
        all_image_features = torch.cat([
            F.normalize(model.encode_image(img), dim=-1)
            for img in image_list
        ], dim=0)

    print("🔎 Matching images with reports...\n")
    for idx, image_feature in enumerate(all_image_features):
        sim_scores = (image_feature @ all_text_features.T).squeeze(0)
        top_indices = sim_scores.topk(3).indices.tolist()

        # Load the original image
        img_tensor = image_list[idx].cpu().squeeze().permute(1, 2, 0).numpy()
        plt.imshow(img_tensor, cmap="gray")
        plt.title(f"Test Image #{idx + 1}")
        plt.axis("off")
        plt.show()

        print(f"📷 Image #{idx+1}")
        print(f"✅ Ground Truth Report:\n{test_loader.dataset[idx][1][:200]}...\n")
        print("🔮 Top Predictions:")
        for rank, i in enumerate(top_indices):
            pred_text = test_loader.dataset[i][1]
            print(f"  🔹 Top-{rank+1}: {pred_text[:200]}...\n")
        print("-" * 80)


In [10]:
inspect_clip_predictions(fine_tuned_model, test_loader, device, num_samples=5)


NameError: name 'fine_tuned_model' is not defined