<div class="alert alert-block alert-success">
  <h2>RSNA Pneumonia Dataset - Model Selection</h2>
</div>

<div class="alert alert-block alert-info">
    <h2>Import Libraries and Load Data</h2>
</div>

In [2]:
import os
import re
import time
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader

from transformers import AutoImageProcessor, AutoModelForImageClassification
from sklearn.metrics import (
    roc_auc_score, accuracy_score, precision_recall_fscore_support,
    confusion_matrix
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
MODEL_ID = "lxyuan/vit-xray-pneumonia-classification"
LABELS_CSV = "/Users/tanmayswami/Downloads/stage_2_train_labels.csv"
IMAGES_DIR = "/Users/tanmayswami/Downloads/stage_2_train_images"

In [None]:
N_POS = 500  
N_NEG = 500  

BATCH_SIZE = 32
NUM_WORKERS = 2
SEED = 42

PATIENT_ID_REGEX = re.compile(
    r"([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})"
)


In [5]:
!pip install pydicom



In [None]:
import pydicom
import numpy as np
from PIL import Image

def dicom_to_pil(path):
    ds = pydicom.dcmread(path)
    img = ds.pixel_array.astype(np.float32)

    if getattr(ds, "PhotometricInterpretation", "") == "MONOCHROME1":
        img = img.max() - img

    img -= img.min()
    if img.max() > 0:
        img /= img.max()
    img = (img * 255.0).clip(0, 255).astype(np.uint8)

    return Image.fromarray(img).convert("RGB")

In [None]:
def index_images(images_dir):
    rows = []
    p = Path(images_dir)

    for fp in p.glob("*.dcm"): 
        patient_id = fp.stem 
        rows.append({"patientId": patient_id, "path": str(fp)})

    out = pd.DataFrame(rows).drop_duplicates(subset=["patientId"])

    if out.empty:
        raise ValueError(f"No .dcm files found in: {images_dir}")

    return out

img_df = index_images(IMAGES_DIR)
print(img_df.head())
print("Indexed:", len(img_df))

                              patientId  \
0  7be6b4de-afe9-43c0-a581-0f49608c8976   
1  2dcdd159-2889-48d3-a0ce-5c7b1086c49d   
2  d8e66874-305e-4c80-9b75-5e764eb718ff   
3  22f2d3ec-f7ea-4778-850d-bb111590202f   
4  cdaa07d4-4234-4cd2-b9bf-abbf5aed1bb4   

                                                path  
0  /Users/tanmayswami/Downloads/stage_2_train_ima...  
1  /Users/tanmayswami/Downloads/stage_2_train_ima...  
2  /Users/tanmayswami/Downloads/stage_2_train_ima...  
3  /Users/tanmayswami/Downloads/stage_2_train_ima...  
4  /Users/tanmayswami/Downloads/stage_2_train_ima...  
Indexed: 26684


In [8]:
def build_patient_labels(labels_csv):
    df = pd.read_csv(labels_csv)
    if "Target" not in df.columns or "patientId" not in df.columns:
        raise ValueError("Expected columns: patientId, Target in RSNA labels CSV.")
    y = df.groupby("patientId")["Target"].max().reset_index()
    y.rename(columns={"Target": "y"}, inplace=True)
    return y

build_patient_labels(LABELS_CSV)['y'].value_counts(dropna=False)

y
0    20672
1     6012
Name: count, dtype: int64

In [None]:
img_df = index_images(IMAGES_DIR)
y_df = build_patient_labels(LABELS_CSV)
df = img_df.merge(y_df, on="patientId", how="inner")
print(len(img_df), len(y_df), len(df)) 

print(df["y"].value_counts())
print("Pos %:", df["y"].mean())

26684 26684 26684
y
0    20672
1     6012
Name: count, dtype: int64
Pos %: 0.225303552690751


<div class="alert alert-block alert-info">
    <h2>Model 1: vit-xray-pneumonia-classification</h2>
</div>

In [None]:
def get_device():
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

def build_patient_labels(labels_csv):
    # Defining patient-level y = max(Target) across rows.
    df = pd.read_csv(labels_csv)
    if "Target" not in df.columns or "patientId" not in df.columns:
        raise ValueError("Expected columns: patientId, Target in RSNA labels CSV.")
    y = df.groupby("patientId")["Target"].max().reset_index()
    y.rename(columns={"Target": "y"}, inplace=True)
    return y

def index_images(images_dir):
    rows = []
    p = Path(images_dir)

    for fp in p.glob("*.dcm"):
        patient_id = fp.stem 
        rows.append({"patientId": patient_id, "path": str(fp)})

    out = pd.DataFrame(rows).drop_duplicates(subset=["patientId"])

    if out.empty:
        raise ValueError(f"No .dcm files found in: {images_dir}")

    return out

class RSNADataset(Dataset):
    def __init__(self, df, processor):
        self.df = df.reset_index(drop=True)
        self.processor = processor

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        path = row["path"]
        y = int(row["y"])

        if path.lower().endswith(".dcm"):
            img = dicom_to_pil(path)
        else:
            img = Image.open(path).convert("RGB")

        inputs = self.processor(images=img, return_tensors="pt")
        pixel_values = inputs["pixel_values"].squeeze(0)
        return pixel_values, y

def infer_logits(model, loader, device):
    model.eval()
    all_probs = []
    all_y = []
    t0 = time.time()

    with torch.no_grad():
        for pixel_values, y in tqdm(loader, desc="Inference"):
            pixel_values = pixel_values.to(device)
            outputs = model(pixel_values=pixel_values)
            logits = outputs.logits

            probs = torch.softmax(logits, dim=-1).detach().cpu().numpy()
            all_probs.append(probs)
            all_y.append(y.numpy())

    t1 = time.time()
    probs = np.vstack(all_probs)
    y_true = np.concatenate(all_y)

    elapsed = t1 - t0
    ips = len(y_true) / elapsed if elapsed > 0 else float("inf")
    return probs, y_true, elapsed, ips

def main():
    np.random.seed(SEED)

    device = get_device()
    print(f"Device: {device}")

    processor = AutoImageProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForImageClassification.from_pretrained(MODEL_ID).to(device)

    id2label = getattr(model.config, "id2label", None)
    print("Model id2label:", id2label)

    y_df = build_patient_labels(LABELS_CSV)
    img_df = index_images(IMAGES_DIR)

    df = img_df.merge(y_df, on="patientId", how="inner")
    if df.empty:
        raise ValueError("No overlap between images and labels. Check your folder and CSV patientIds.")
    print(f"Matched images with labels: {len(df):,}")

    pos = df[df["y"] == 1]
    neg = df[df["y"] == 0]

    if len(pos) == 0 or len(neg) == 0:
        raise ValueError(f"Need both classes. Found pos={len(pos)}, neg={len(neg)}")

    pos_s = pos.sample(n=min(N_POS, len(pos)), random_state=SEED)
    neg_s = neg.sample(n=min(N_NEG, len(neg)), random_state=SEED)

    eval_df = pd.concat([pos_s, neg_s]).sample(frac=1, random_state=SEED).reset_index(drop=True)
    print(f"Eval set size: {len(eval_df)} (pos={eval_df['y'].sum()}, neg={(eval_df['y']==0).sum()})")

    ds = RSNADataset(eval_df, processor)
    NUM_WORKERS = 0
    loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

    probs, y_true, elapsed, ips = infer_logits(model, loader, device)

    if id2label and len(id2label) == probs.shape[1]:
        labels = {int(k): v.lower() for k, v in id2label.items()}
        pos_idx = None
        for k, v in labels.items():
            if "pneum" in v or "opacity" in v or "lung" in v:
                pos_idx = k
                break
        if pos_idx is None:
            pos_idx = 1 if probs.shape[1] > 1 else 0
    else:
        pos_idx = 1 if probs.shape[1] > 1 else 0

    y_score = probs[:, pos_idx]
    y_pred = (y_score >= 0.8).astype(int)

    try:
        auc = roc_auc_score(y_true, y_score)
    except ValueError:
        auc = float("nan")

    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
    cm = confusion_matrix(y_true, y_pred)

    print("\n--- Results ---")
    print(f"AUROC: {auc:.4f}")
    print(f"Accuracy: {acc:.4f}")
    print(f"Precision: {prec:.4f} | Recall: {rec:.4f} | F1: {f1:.4f}")
    print("Confusion Matrix [[TN FP],[FN TP]]:")
    print(cm)
    print(f"Elapsed: {elapsed:.2f}s | Throughput: {ips:.2f} images/sec | Batch size: {BATCH_SIZE}")
    print(f"Positive class index used: {pos_idx}")

if __name__ == "__main__":
    main()

Device: mps
Model id2label: {0: 'NORMAL', 1: 'PNEUMONIA'}
Matched images with labels: 26,684
Eval set size: 1000 (pos=500, neg=500)


Inference: 100%|██████████| 32/32 [00:24<00:00,  1.33it/s]


--- Results ---
AUROC: 0.8144
Accuracy: 0.6770
Precision: 0.6151 | Recall: 0.9460 | F1: 0.7455
Confusion Matrix [[TN FP],[FN TP]]:
[[204 296]
 [ 27 473]]
Elapsed: 24.12s | Throughput: 41.47 images/sec | Batch size: 32
Positive class index used: 1





In [13]:
%pip install --no-cache-dir --force-reinstall torchvision -f https://download.pytorch.org/whl/torch_stable.html

Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torchvision
  Downloading torchvision-0.25.0-cp310-cp310-macosx_11_0_arm64.whl.metadata (5.4 kB)
Collecting numpy (from torchvision)
  Downloading numpy-2.2.6-cp310-cp310-macosx_14_0_arm64.whl.metadata (62 kB)
Collecting torch==2.10.0 (from torchvision)
  Downloading torch-2.10.0-1-cp310-none-macosx_11_0_arm64.whl.metadata (31 kB)
Collecting pillow!=8.3.*,>=5.3.0 (from torchvision)
  Downloading pillow-12.1.0-cp310-cp310-macosx_11_0_arm64.whl.metadata (8.8 kB)
Collecting filelock (from torch==2.10.0->torchvision)
  Downloading filelock-3.20.3-py3-none-any.whl.metadata (2.1 kB)
Collecting typing-extensions>=4.10.0 (from torch==2.10.0->torchvision)
  Downloading typing_extensions-4.15.0-py3-none-any.whl.metadata (3.3 kB)
Collecting sympy>=1.13.3 (from torch==2.10.0->torchvision)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx>=2.5.1 (from torch==2.10.0->torchvision)
  Download

In [1]:
import torch, torchvision
print(torch.__version__, torchvision.__version__)

2.10.0 0.25.0


In [10]:
import os
import time
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms, models

from sklearn.metrics import (
    roc_auc_score, accuracy_score, precision_recall_fscore_support,
    confusion_matrix, classification_report
)

<div class="alert alert-block alert-info">
    <h2>Model 2: Densenet with transfer learning</h2>
</div>

In [None]:
LABELS_CSV = "/Users/tanmayswami/Downloads/stage_2_train_labels.csv"
IMAGES_DIR = "/Users/tanmayswami/Downloads/stage_2_train_images"

SEED = 42
BATCH_SIZE = 32
NUM_WORKERS = 0       
LR = 1e-4
EPOCHS = 8            
WEIGHT_DECAY = 1e-4

N_POS_EVAL = 500
N_NEG_EVAL = 500

N_POS_TRAIN = 2000
N_NEG_TRAIN = 2000

THRESH = 0.8

DENSENET_SAVE_PATH = "rsna_densenet121_epoch_last.pt"
VIT_SAVE_PATH = "rsna_vit_hf_state_dict.pt" 

def get_device():
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = get_device()
print("Device:", device)

torch.manual_seed(SEED)
np.random.seed(SEED)

class RSNADatasetTorchvision(Dataset):
    def __init__(self, df: pd.DataFrame, tfm: transforms.Compose):
        self.df = df.reset_index(drop=True)
        self.tfm = tfm

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        img = dicom_to_pil(row["path"]) 
        y = int(row["y"])
        x = self.tfm(img)            
        return x, y

def build_rsna_df(labels_csv, images_dir):
    y_df = build_patient_labels(labels_csv)   
    img_df = index_images(images_dir)        
    df = img_df.merge(y_df, on="patientId", how="inner")
    if df.empty:
        raise ValueError("No overlap between images and labels. Check paths.")
    return df

df_all = build_rsna_df(LABELS_CSV, IMAGES_DIR)
print("Total matched:", len(df_all))
print(df_all["y"].value_counts(), "| Pos %:", df_all["y"].mean())

def make_balanced_sample(df, n_pos, n_neg, seed):
    pos = df[df["y"] == 1]
    neg = df[df["y"] == 0]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("Need both classes.")
    pos_s = pos.sample(n=min(n_pos, len(pos)), random_state=seed)
    neg_s = neg.sample(n=min(n_neg, len(neg)), random_state=seed)
    out = pd.concat([pos_s, neg_s]).sample(frac=1, random_state=seed).reset_index(drop=True)
    return out

train_df = make_balanced_sample(df_all, N_POS_TRAIN, N_NEG_TRAIN, SEED)
eval_df  = make_balanced_sample(df_all, N_POS_EVAL,  N_NEG_EVAL,  SEED + 1)

print("Train:", len(train_df), "Eval:", len(eval_df))
print("Train pos:", train_df["y"].sum(), "Eval pos:", eval_df["y"].sum())

train_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])

eval_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])

train_ds = RSNADatasetTorchvision(train_df, train_tfm)
eval_ds  = RSNADatasetTorchvision(eval_df,  eval_tfm)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS)
eval_loader  = DataLoader(eval_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)


def build_densenet121(num_classes=2):
    m = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
    in_features = m.classifier.in_features
    m.classifier = nn.Linear(in_features, num_classes)
    return m

model = build_densenet121(num_classes=2).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0
    n = 0
    for x, y in tqdm(loader, desc="Train", leave=False):
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        bs = y.size(0)
        total_loss += loss.item() * bs
        n += bs
    return total_loss / max(n, 1)

@torch.no_grad()
def infer_probs(model, loader, device):
    model.eval()
    all_probs = []
    all_y = []
    t0 = time.time()
    for x, y in tqdm(loader, desc="Infer", leave=False):
        x = x.to(device)
        logits = model(x)
        probs = torch.softmax(logits, dim=-1).detach().cpu().numpy()
        all_probs.append(probs)
        all_y.append(y.numpy())
    t1 = time.time()
    probs = np.vstack(all_probs)
    y_true = np.concatenate(all_y)
    elapsed = t1 - t0
    ips = len(y_true) / elapsed if elapsed > 0 else float("inf")
    return probs, y_true, elapsed, ips

def evaluate_binary_from_probs(probs, y_true, pos_idx, thresh):
    y_score = probs[:, pos_idx]
    y_pred = (y_score >= thresh).astype(int)

    auc = roc_auc_score(y_true, y_score) if len(np.unique(y_true)) > 1 else float("nan")
    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
    cm = confusion_matrix(y_true, y_pred)

    return {
        "auc": auc, "acc": acc, "prec": prec, "rec": rec, "f1": f1,
        "cm": cm, "y_score": y_score, "y_pred": y_pred
    }

for epoch in range(1, EPOCHS + 1):
    loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
    probs, y_true, elapsed, ips = infer_probs(model, eval_loader, device)
    metrics = evaluate_binary_from_probs(probs, y_true, pos_idx=1, thresh=THRESH)

    print(f"\nEpoch {epoch}/{EPOCHS}")
    print(f"Train loss: {loss:.4f}")
    print(f"Eval AUROC: {metrics['auc']:.4f} | Acc: {metrics['acc']:.4f}")
    print(f"Precision: {metrics['prec']:.4f} | Recall: {metrics['rec']:.4f} | F1: {metrics['f1']:.4f}")
    print("Confusion Matrix [[TN FP],[FN TP]]:")
    print(metrics["cm"])
    print(f"Elapsed: {elapsed:.2f}s | Throughput: {ips:.2f} imgs/s | Batch: {BATCH_SIZE} | Thresh: {THRESH}")

torch.save(
    {
        "model_name": "densenet121",
        "state_dict": model.state_dict(),
        "threshold": THRESH,
        "config": {
            "epochs": EPOCHS,
            "lr": LR,
            "weight_decay": WEIGHT_DECAY,
            "batch_size": BATCH_SIZE,
            "train_pos": N_POS_TRAIN,
            "train_neg": N_NEG_TRAIN,
            "eval_pos": N_POS_EVAL,
            "eval_neg": N_NEG_EVAL,
            "seed": SEED,
        },
    },
    DENSENET_SAVE_PATH
)
print(f"\nSaved DenseNet checkpoint to: {DENSENET_SAVE_PATH}")

def load_densenet_checkpoint(ckpt_path, device):
    ckpt = torch.load(ckpt_path, map_location=device)
    m = build_densenet121(num_classes=2).to(device)
    m.load_state_dict(ckpt["state_dict"])
    m.eval()
    return m

loaded_model = load_densenet_checkpoint(DENSENET_SAVE_PATH, device)
print("\nClassification report (DenseNet @ thresh={}):".format(THRESH))
print(classification_report(y_true, metrics["y_pred"], target_names=["NORMAL", "PNEUMONIA"], zero_division=0))

Device: mps
Total matched: 26684
y
0    20672
1     6012
Name: count, dtype: int64 | Pos %: 0.225303552690751
Train: 4000 Eval: 1000
Train pos: 2000 Eval pos: 500


                                                        


Epoch 1/8
Train loss: 0.5089
Eval AUROC: 0.8658 | Acc: 0.6500
Precision: 0.9076 | Recall: 0.3340 | F1: 0.4883
Confusion Matrix [[TN FP],[FN TP]]:
[[483  17]
 [333 167]]
Elapsed: 12.36s | Throughput: 80.88 imgs/s | Batch: 32 | Thresh: 0.8


                                                        


Epoch 2/8
Train loss: 0.4152
Eval AUROC: 0.8799 | Acc: 0.7270
Precision: 0.9097 | Recall: 0.5040 | F1: 0.6486
Confusion Matrix [[TN FP],[FN TP]]:
[[475  25]
 [248 252]]
Elapsed: 12.22s | Throughput: 81.84 imgs/s | Batch: 32 | Thresh: 0.8


                                                        


Epoch 3/8
Train loss: 0.3461
Eval AUROC: 0.8871 | Acc: 0.7560
Precision: 0.9129 | Recall: 0.5660 | F1: 0.6988
Confusion Matrix [[TN FP],[FN TP]]:
[[473  27]
 [217 283]]
Elapsed: 12.22s | Throughput: 81.84 imgs/s | Batch: 32 | Thresh: 0.8


                                                        


Epoch 4/8
Train loss: 0.2779
Eval AUROC: 0.8840 | Acc: 0.7560
Precision: 0.9076 | Recall: 0.5700 | F1: 0.7002
Confusion Matrix [[TN FP],[FN TP]]:
[[471  29]
 [215 285]]
Elapsed: 12.10s | Throughput: 82.63 imgs/s | Batch: 32 | Thresh: 0.8


                                                        


Epoch 5/8
Train loss: 0.1982
Eval AUROC: 0.8894 | Acc: 0.7940
Precision: 0.8326 | Recall: 0.7360 | F1: 0.7813
Confusion Matrix [[TN FP],[FN TP]]:
[[426  74]
 [132 368]]
Elapsed: 12.10s | Throughput: 82.63 imgs/s | Batch: 32 | Thresh: 0.8


                                                        


Epoch 6/8
Train loss: 0.1336
Eval AUROC: 0.8904 | Acc: 0.7820
Precision: 0.8917 | Recall: 0.6420 | F1: 0.7465
Confusion Matrix [[TN FP],[FN TP]]:
[[461  39]
 [179 321]]
Elapsed: 12.47s | Throughput: 80.18 imgs/s | Batch: 32 | Thresh: 0.8


                                                        


Epoch 7/8
Train loss: 0.1326
Eval AUROC: 0.8733 | Acc: 0.7940
Precision: 0.8063 | Recall: 0.7740 | F1: 0.7898
Confusion Matrix [[TN FP],[FN TP]]:
[[407  93]
 [113 387]]
Elapsed: 12.27s | Throughput: 81.50 imgs/s | Batch: 32 | Thresh: 0.8


                                                        


Epoch 8/8
Train loss: 0.1072
Eval AUROC: 0.8806 | Acc: 0.8130
Precision: 0.8111 | Recall: 0.8160 | F1: 0.8136
Confusion Matrix [[TN FP],[FN TP]]:
[[405  95]
 [ 92 408]]
Elapsed: 12.23s | Throughput: 81.73 imgs/s | Batch: 32 | Thresh: 0.8

Saved DenseNet checkpoint to: rsna_densenet121_epoch_last.pt

Classification report (DenseNet @ thresh=0.8):
              precision    recall  f1-score   support

      NORMAL       0.81      0.81      0.81       500
   PNEUMONIA       0.81      0.82      0.81       500

    accuracy                           0.81      1000
   macro avg       0.81      0.81      0.81      1000
weighted avg       0.81      0.81      0.81      1000

