In [11]:
# merge_data.py
# merge data_splits and event/time_to_event columns from clinical data
import pandas as pd 
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Orientationd,
    Spacingd,
    ResampleToMatchd,
    CropForegroundd,
    ResizeWithPadOrCropd,
    NormalizeIntensityd,
    ConcatItemsd,
    RandAffined,
    RandGaussianNoised,
    RandBiasFieldd,
    RandAdjustContrastd,
    NormalizeIntensityd,
    MapTransform
    
)
from monai.data import Dataset, DataLoader, ITKReader
import numpy as np
import matplotlib.pyplot as plt
import torch
from pathlib import Path
import os

# Read the two input tables: the fold assignment and the clinical records.
data_splits = pd.read_csv("data_split_5fold.csv")
clinical_data = pd.read_csv('clinical_data.csv')

# Keep only the patient id and the two outcome columns, give the outcome
# columns shorter names, then attach the fold assignment. `merge` defaults
# to an inner join, so only patients present in both tables are kept.
df = (
    clinical_data[["patient_id", "BCR", "time_to_follow-up/BCR"]]
    .rename(columns={"BCR": "event", "time_to_follow-up/BCR": "time_to_event"})
    .merge(data_splits, on="patient_id")
)

# Fold 0 is held out as the test set; every other fold is used for training.
train_ids = df.loc[df["fold"] != 0, "patient_id"].tolist()
test_ids = df.loc[df["fold"] == 0, "patient_id"].tolist()

def _build_file_dicts(patient_ids, clinical):
    """Build one MONAI-style sample dict per patient.

    Args:
        patient_ids: iterable of patient identifiers to include, in order.
        clinical: DataFrame with columns "patient_id", "time_to_event" and
            "event".

    Returns:
        A list of dicts, one per entry in ``patient_ids``, holding the patient
        id, the paths of the three MRI volumes and the prostate mask, and the
        outcome values looked up in ``clinical``.
    """
    # Index the table once instead of scanning it with a boolean mask twice
    # per patient. Keeping the first row per patient reproduces the previous
    # `.iloc[0]` behaviour if a patient id ever appears more than once.
    outcomes = clinical.drop_duplicates(subset="patient_id").set_index("patient_id")
    return [
        {
            "id": patient_id,
            "adc":  f"radiology/mpMRI/{patient_id}/{patient_id}_0001_adc.mha",
            "hbv":  f"radiology/mpMRI/{patient_id}/{patient_id}_0001_hbv.mha",
            "t2w":  f"radiology/mpMRI/{patient_id}/{patient_id}_0001_t2w.mha",
            "mask": f"radiology/prostate_mask_t2w/{patient_id}_0001_mask.mha",
            # `.at` returns the scalar with the column's own dtype.
            "time_to_event": outcomes.at[patient_id, "time_to_event"],
            "event": outcomes.at[patient_id, "event"],
        }
        for patient_id in patient_ids
    ]


# Lists of per-patient sample dicts for the training and the test split.
train_files = _build_file_dicts(train_ids, df)
test_files = _build_file_dicts(test_ids, df)

# Voxel spacing and spatial size that every volume is brought to.
target_spacing = (1.0, 1.0, 3.0)
target_shape   = (160, 160, 48)   # pick something reasonable

keys_img = ["adc", "hbv", "t2w"]
keys_all = ["adc", "hbv", "t2w", "mask"]


def _preprocessing_transforms():
    """Return the deterministic load/resample/crop steps shared by both pipelines.

    A fresh list of transform instances is created on every call so the
    training and test pipelines do not share transform objects.
    """
    return [
        LoadImaged(keys=keys_all,
                   reader=ITKReader(),
                   image_only=False),
        EnsureChannelFirstd(keys=keys_all),
        Orientationd(keys=keys_all, axcodes="RAS"),
        # "bilinear" is the documented linear mode for these resampling
        # transforms; on 3-D volumes it performs trilinear interpolation.
        # "trilinear" is not among the documented values.
        Spacingd(
            keys=["t2w"],
            pixdim=target_spacing,
            mode=("bilinear",),
        ),
        # Nearest-neighbour keeps the mask's label values intact.
        Spacingd(
            keys=["mask"],
            pixdim=target_spacing,
            mode=("nearest",),
        ),
        # Put ADC and HBV on the resampled T2W grid.
        ResampleToMatchd(
            keys=["adc", "hbv"],
            key_dst="t2w",
            mode=("bilinear", "bilinear"),
        ),
        # Crop every volume to the mask's bounding box plus a margin.
        CropForegroundd(
            keys=keys_all,
            source_key="mask",
            margin=(16, 16, 2),
        ),
        ResizeWithPadOrCropd(
            keys=keys_all,
            spatial_size=target_shape,
        ),
    ]


def _finalizing_transforms():
    """Return the normalisation and channel-stacking steps shared by both pipelines."""
    return [
        NormalizeIntensityd(
            keys=keys_img,
            nonzero=True,
            channel_wise=True,
        ),
        # Stack the three modalities along the channel axis under "image".
        ConcatItemsd(
            keys=["adc", "hbv", "t2w"],
            name="image",
            dim=0,
        ),
    ]


# Training pipeline: shared preprocessing, random augmentation of the image
# keys only (the mask is not augmented), then normalisation and stacking.
train_transforms = Compose([
    *_preprocessing_transforms(),
    RandAffined(
        keys=keys_img,
        spatial_size=target_shape,
        rotate_range=(0.087, 0.087, 0.087),
        translate_range=(5, 5, 1),
        scale_range=(0.05, 0.05, 0.0),
        mode=("bilinear",) * len(keys_img),
        prob=0.5,
    ),
    RandGaussianNoised(keys=keys_img, prob=0.2, mean=0, std=0.01),
    RandBiasFieldd(keys=keys_img, prob=0.3, coeff_range=(0.0, 0.1)),
    RandAdjustContrastd(keys=keys_img, prob=0.3, gamma=(0.9, 1.1)),
    *_finalizing_transforms(),
])

# Test pipeline: the same steps without any random augmentation.
test_transforms = Compose([
    *_preprocessing_transforms(),
    *_finalizing_transforms(),
])

# Wrap the training samples in a dataset that applies the training pipeline,
# and serve them one shuffled sample at a time from a single worker process.
train_ds = Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(
    train_ds,
    batch_size=1,
    shuffle=True,
    num_workers=1,
)


# Iterate over the whole loader once; after the loop the three names hold
# the contents of the last batch.
for batch in train_loader:
    images, times, events = (
        batch[key] for key in ("image", "time_to_event", "event")
    )

monai.transforms.spatial.dictionary Orientationd.__init__:labels: Current default value of argument `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` was changed in version None from `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` to `labels=None`. Default value changed to None meaning that the transform now uses the 'space' of a meta-tensor, if applicable, to determine appropriate axis labels.


In [12]:
import torch
import torch.nn as nn
from monai.networks.nets import DenseNet121

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# CHANGE THIS to the number of channels you actually stacked:
#   adc + hbv + t2w  ->  in_channels = 3
#   just t2w         ->  in_channels = 1
in_channels = 3

# 3-D DenseNet121 that maps each input volume to a single output value
# (one risk score per patient).
model = DenseNet121(spatial_dims=3, in_channels=in_channels, out_channels=1)
model = model.to(device)

In [None]:
# Smoke test: push a random 5-D batch (2 samples, `in_channels` channels,
# 64 x 128 x 128 voxels) through the model and check the output shape.
x = torch.randn(2, in_channels, 64, 128, 128).to(device)  # fake batch

# Run in eval mode: `torch.no_grad()` only disables gradient tracking, so in
# training mode any batch-norm layers would still fold this random batch into
# their running statistics. The previous mode is restored afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        out = model(x)
finally:
    model.train(was_training)
print(out.shape)   # should be torch.Size([2, 1])

torch.Size([2, 1])
