## DICOM JPEG Decompression Support

The RSNA Pulmonary Embolism dataset contains JPEG-compressed DICOM images.
To enable correct decoding of pixel data, we install `pylibjpeg` and
`pylibjpeg-libjpeg`, which are required by `pydicom` for decompression.


In [1]:
# Install DICOM JPEG decoders (required for RSNA dataset)
!pip install -q pylibjpeg pylibjpeg-libjpeg

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m36.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import torch
import os, cv2
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from glob import glob
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    roc_auc_score, roc_curve,
    confusion_matrix, accuracy_score,
    precision_recall_curve, average_precision_score
)
import pydicom
import timm

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)




Using device: cuda


In [3]:
DATA_ROOT = "/kaggle/input/rsna-str-pulmonary-embolism-detection"
IMG_ROOT = os.path.join(DATA_ROOT, "train")
LABELS_PATH = os.path.join(DATA_ROOT, "train.csv")

df = pd.read_csv(LABELS_PATH)

# Study-level label
df["label"] = (df["negative_exam_for_pe"] == 0).astype(int)

print("Total rows:", len(df))


Total rows: 1790594


In [4]:
df.head()

Unnamed: 0,StudyInstanceUID,SeriesInstanceUID,SOPInstanceUID,pe_present_on_image,negative_exam_for_pe,qa_motion,qa_contrast,flow_artifact,rv_lv_ratio_gte_1,rv_lv_ratio_lt_1,leftsided_pe,chronic_pe,true_filling_defect_not_pe,rightsided_pe,acute_and_chronic_pe,central_pe,indeterminate,label
0,6897fa9de148,2bfbb7fd2e8b,c0f3cb036d06,0,0,0,0,0,0,1,1,0,0,1,0,0,0,1
1,6897fa9de148,2bfbb7fd2e8b,f57ffd3883b6,0,0,0,0,0,0,1,1,0,0,1,0,0,0,1
2,6897fa9de148,2bfbb7fd2e8b,41220fda34a3,0,0,0,0,0,0,1,1,0,0,1,0,0,0,1
3,6897fa9de148,2bfbb7fd2e8b,13b685b4b14f,0,0,0,0,0,0,1,1,0,0,1,0,0,0,1
4,6897fa9de148,2bfbb7fd2e8b,be0b7524ffb4,0,0,0,0,0,0,1,1,0,0,1,0,0,0,1


In [5]:
def series_has_images(study, series):
    p = os.path.join(IMG_ROOT, study, series)
    return os.path.exists(p) and len(glob(p+"/*.dcm")) > 0

series_df = df.groupby("SeriesInstanceUID")["label"].first().reset_index()

train_ids, val_ids = train_test_split(
    series_df.SeriesInstanceUID,
    test_size=0.2,
    stratify=series_df.label,
    random_state=42
)

train_df = df[df.SeriesInstanceUID.isin(train_ids)]
val_df = df[df.SeriesInstanceUID.isin(val_ids)]

print("Train series:", train_df.SeriesInstanceUID.nunique())
print("Val series:", val_df.SeriesInstanceUID.nunique())


Train series: 5823
Val series: 1456


In [6]:
import pydicom

In [7]:
sample = train_df.iloc[0]
sample_path = glob(
    os.path.join(
        IMG_ROOT,
        sample.StudyInstanceUID,
        sample.SeriesInstanceUID,
        "*.dcm"
    )
)[0]

dcm = pydicom.dcmread(sample_path)
raw = dcm.pixel_array.astype(np.float32)
raw = raw * dcm.RescaleSlope + dcm.RescaleIntercept
win = window_ct(raw)

plt.figure(figsize=(12,4))
plt.subplot(1,3,1); plt.imshow(raw,cmap="gray"); plt.title("Raw CT (HU)")
plt.subplot(1,3,2); plt.imshow(win,cmap="gray"); plt.title("Windowed CT")
plt.subplot(1,3,3); plt.hist(raw.flatten(), bins=200); plt.title("HU Histogram")
plt.tight_layout(); plt.show()


NameError: name 'window_ct' is not defined

In [None]:
def window_ct(img, level=100, width=700):
    low = level - width // 2
    high = level + width // 2
    img = np.clip(img, low, high)
    img = (img - low) / (high - low)
    return img


In [None]:
class RSNADataset(Dataset):
    def __init__(self, df, stack=4, train=True):
        self.groups = df.groupby("SeriesInstanceUID")
        self.series_ids = list(self.groups.groups.keys())
        self.stack = stack
        self.train = train

    def __len__(self):
        return len(self.series_ids)

    def __getitem__(self, idx):
        sid = self.series_ids[idx]
        g = self.groups.get_group(sid)

        study_uid = g.StudyInstanceUID.iloc[0]
        files = glob(os.path.join(IMG_ROOT, study_uid, sid, "*.dcm"))

        slices = []
        for f in files:
            dcm = pydicom.dcmread(f)
            img = dcm.pixel_array.astype(np.float32)
            img = img * dcm.RescaleSlope + dcm.RescaleIntercept
            z = float(dcm.ImagePositionPatient[2])
            img = window_ct(img)
            img = cv2.resize(img, (224,224))
            slices.append((z, img))

        slices = [s[1] for s in sorted(slices, key=lambda x: x[0])]
        n = len(slices)

        if n < 2*self.stack + 1:
            center = n // 2
            idxs = [center] * (2*self.stack + 1)
        else:
            center = (
                np.random.randint(self.stack, n-self.stack)
                if self.train else n//2
            )
            idxs = range(center-self.stack, center+self.stack+1)

        x = torch.tensor(np.stack([slices[i] for i in idxs])).unsqueeze(1)
        y = torch.tensor(g.label.iloc[0], dtype=torch.float32)

        return x, y


In [None]:
train_loader = DataLoader(
    RSNADataset(train_df, stack=4, train=True),
    batch_size=4,
    shuffle=True,
    num_workers=2
)

val_loader = DataLoader(
    RSNADataset(val_df, stack=4, train=False),
    batch_size=4,
    shuffle=False,
    num_workers=2
)


In [None]:
class UnifiedPEModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.encoder = timm.create_model(
            "efficientnet_b2",
            pretrained=True,
            in_chans=1,
            features_only=True
        )

        C = self.encoder.feature_info[-1]["num_chs"]

        self.slice_attn = nn.Sequential(
            nn.Linear(C, C//2),
            nn.ReLU(),
            nn.Linear(C//2, 1)
        )

        self.cls_head = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(C, 1)
        )

        self.seg_head = nn.Sequential(
            nn.Conv2d(C, C//2, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(C//2, 1, 1)
        )

        self.feature_maps = None
        self.feature_grads = {}

    def save_grad(self, idx):
        def hook(grad):
            self.feature_grads[idx] = grad
        return hook

    def forward(self, x):

        B,S,C,H,W = x.shape
        x = x.view(B*S,C,H,W)

        feats = self.encoder(x)

        self.feature_maps = feats
        self.feature_grads = {}

        if torch.is_grad_enabled():
            for i,f in enumerate(feats):
                f.register_hook(self.save_grad(i))

        feat = feats[-1].view(B,S,feats[-1].shape[1],
                              feats[-1].shape[2],
                              feats[-1].shape[3])

        pooled = feat.mean(dim=(3,4))
        attn = torch.softmax(self.slice_attn(pooled),dim=1)
        feat = (feat*attn.unsqueeze(-1).unsqueeze(-1)).sum(dim=1)

        cls = self.cls_head(
            F.adaptive_avg_pool2d(feat,1).flatten(1)
        ).squeeze(1)

        seg = self.seg_head(feat)
        seg_up = F.interpolate(seg,(224,224),mode="bilinear")

        return cls, seg_up


In [None]:
class MultiScaleGradCAM:
    def __init__(self, model, scales=(1,2,3)):
        self.model = model
        self.scales = scales

    def generate(self, x):
        self.model.zero_grad()

        cls, _ = self.model(x)
        cls.mean().backward(retain_graph=True)

        cams = []

        for i in self.scales:
            act = self.model.feature_maps[i]
            grad = self.model.feature_grads[i]

            w = grad.mean(dim=(2,3), keepdim=True)
            cam = F.relu((w * act).sum(1, keepdim=True))
            cam = F.interpolate(cam, (224,224))
            cams.append(cam)

        cam = torch.mean(torch.stack(cams), 0)
        return cam / (cam.max() + 1e-8)


In [None]:
num_pos = (train_df.label == 1).sum()
num_neg = (train_df.label == 0).sum()
pos_weight = torch.tensor([num_neg/num_pos]).to(DEVICE)

cls_loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
seg_loss_fn = nn.BCEWithLogitsLoss()

def dice_loss(pred, target, smooth=1e-6):
    pred = torch.sigmoid(pred)
    inter = (pred * target).sum(dim=(1,2,3))
    union = pred.sum(dim=(1,2,3)) + target.sum(dim=(1,2,3))
    return 1 - ((2*inter+smooth)/(union+smooth)).mean()

model = UnifiedPEModel().to(DEVICE)
cam_gen = MultiScaleGradCAM(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
scaler = torch.amp.GradScaler("cuda")


In [None]:
EPOCHS = 50

for epoch in range(EPOCHS):

    model.train()
    total = 0

    if epoch < 5:
        seg_weight = 0.0
    elif epoch < 15:
        seg_weight = 0.05
    else:
        seg_weight = 0.1

    print(f"\nEpoch {epoch+1}/{EPOCHS} | Seg Weight: {seg_weight}")

    for x,y in tqdm(train_loader):
        x,y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()

        with torch.amp.autocast("cuda"):
            cls, seg = model(x)
            cls_loss = cls_loss_fn(cls,y)

            if seg_weight > 0:
                cam = cam_gen.generate(x)
                B,S = x.shape[0],x.shape[1]
                cam = cam.view(B,S,1,224,224).mean(1)

                cam_min = cam.view(B,-1).min(dim=1)[0].view(B,1,1,1)
                cam_max = cam.view(B,-1).max(dim=1)[0].view(B,1,1,1)
                cam_n = (cam-cam_min)/(cam_max-cam_min+1e-8)

                with torch.no_grad():
                    pseudo_mask = (cam_n>0.35).float()

                pe = y.view(-1,1,1,1)
                seg_loss = seg_loss_fn(seg*pe,pseudo_mask*pe) + \
                           dice_loss(seg*pe,pseudo_mask*pe)

                loss = cls_loss + seg_weight*seg_loss
            else:
                loss = cls_loss

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total += loss.item()

    scheduler.step()
    print("Epoch Loss:", total/len(train_loader))


In [None]:
model.eval()
y_true, y_prob = [], []

with torch.no_grad():
    for x, y in tqdm(val_loader):
        x = x.to(DEVICE)
        cls, _ = model(x)

        y_true.extend(y.numpy())
        y_prob.extend(torch.sigmoid(cls).cpu().numpy())

y_true = np.array(y_true)
y_prob = np.array(y_prob)

# AUROC
auc = roc_auc_score(y_true, y_prob)
print("Validation AUROC:", auc)

# Default threshold 0.5
y_pred = (y_prob > 0.5).astype(int)

cm = confusion_matrix(y_true, y_pred)
tn, fp, fn, tp = cm.ravel()

print("Accuracy:", accuracy_score(y_true, y_pred))
print("Sensitivity:", tp/(tp+fn+1e-8))
print("Specificity:", tn/(tn+fp+1e-8))


In [None]:
fpr, tpr, _ = roc_curve(y_true, y_prob)

plt.figure(figsize=(6,5))
plt.plot(fpr, tpr, label=f"AUC = {auc:.3f}")
plt.plot([0,1], [0,1], '--')
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend()
plt.grid(True)
plt.show()


In [None]:
plt.figure(figsize=(5,4))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.show()


In [None]:
precision, recall, _ = precision_recall_curve(y_true, y_prob)
ap = average_precision_score(y_true, y_prob)

plt.figure(figsize=(6,5))
plt.plot(recall, precision, label=f"AP = {ap:.3f}")
plt.xlabel("Recall (Sensitivity)")
plt.ylabel("Precision")
plt.title("Precision–Recall Curve")
plt.legend()
plt.grid(True)
plt.show()


In [None]:
val_losses = []

model.eval()
with torch.no_grad():
    total = 0
    for x, y in val_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        cls, _ = model(x)
        total += cls_loss_fn(cls, y).item()

val_losses.append(total / len(val_loader))

plt.figure(figsize=(6,5))
plt.plot(train_losses, label="Training Loss")
plt.plot(val_losses, label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training vs Validation Loss")
plt.legend()
plt.grid(True)
plt.show()


In [None]:
x, y = next(iter(val_loader))
x = x.to(DEVICE)

model.eval()

cls, seg = model(x)
cam = cam_gen.generate(x)

with torch.no_grad():
    seg_out = torch.sigmoid(seg)

B, S = x.shape[0], x.shape[1]
cam = cam.view(B, S, 1, 224, 224).mean(1)

plt.figure(figsize=(14,4))

plt.subplot(1,3,1)
plt.imshow(x[0,2,0].cpu(), cmap="gray")
plt.title("CT")

plt.subplot(1,3,2)
plt.imshow(cam[0,0].detach().cpu(), cmap="jet")
plt.title("Pseudo Mask (CAM)")

plt.subplot(1,3,3)
plt.imshow(seg_out[0,0].cpu(), cmap="jet")
plt.title("Weak Segmentation")

plt.show()


In [None]:
thresholds = np.linspace(0, 1, 50)

sens_list, spec_list, acc_list = [], [], []

for t in thresholds:
    preds = (y_prob > t).astype(int)
    tn, fp, fn, tp = confusion_matrix(y_true, preds).ravel()

    sens_list.append(tp/(tp+fn+1e-8))
    spec_list.append(tn/(tn+fp+1e-8))
    acc_list.append((tp+tn)/(tp+tn+fp+fn))

plt.figure(figsize=(7,5))
plt.plot(thresholds, sens_list, label="Sensitivity")
plt.plot(thresholds, spec_list, label="Specificity")
plt.plot(thresholds, acc_list, label="Accuracy")

plt.xlabel("Decision Threshold")
plt.ylabel("Metric Value")
plt.title("Threshold Sensitivity Analysis")
plt.legend()
plt.grid(True)
plt.show()
