In [None]:
# MONAI Inference on DICOM and NRRD Images Using SimpleITK
"""
This notebook demonstrates how to perform inference on 3D medical images using the MONAI framework.
It includes:
- Loading NRRD and DICOM data using SimpleITK
- Applying MONAI transforms for preprocessing
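- Running a 3D UNet model with pretrained weights downloaded from the Hugging Face Hub
- Visualizing input and output mid-slices

Note: The downloaded weights were trained on four MRI modalities (BraTS-style input), whereas each
volume loaded here has a single channel, so the outputs are not meaningful segmentations.
Use inputs that match the model's training setup (or your own trained model) for actual segmentation.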
"""

import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
import torch
from huggingface_hub import hf_hub_download
from monai.data import DataLoader, Dataset
from monai.networks.nets import UNet
# NOTE: AddChannel was removed in MONAI 1.3 (EnsureChannelFirst replaces it); delete it from
# this import when running on MONAI >= 1.3.
from monai.transforms import (
    AddChannel,
    Compose,
    EnsureChannelFirstd,
    Resize,
    Resized,
    ScaleIntensity,
    ScaleIntensityd,
    ToTensor,
    ToTensord,
)

In [None]:
# Paths to your input files.
# Set the NRRD_FILE_PATH / DICOM_FOLDER_PATH environment variables to point at your data
# without editing the notebook; otherwise the placeholder paths below are used.
nrrd_file_path = os.environ.get("NRRD_FILE_PATH", "your_data/image3d.nrrd")  # Replace with your actual NRRD file path
dicom_folder_path = os.environ.get("DICOM_FOLDER_PATH", "your_data/dicom_series/")  # Replace with your DICOM folder path

In [None]:
# Load NRRD and DICOM using SimpleITK
def _min_max_normalize(array):
    """Cast ``array`` to float32 and rescale it linearly to the range [0, 1].

    A constant-valued array (max == min) would make ``(x - min) / (max - min)`` divide
    by zero and fill the result with NaN, so it is mapped to an all-zero array instead.
    """
    array = np.asarray(array, dtype=np.float32)
    lowest = float(np.min(array))
    value_range = float(np.max(array)) - lowest
    if value_range == 0:
        return np.zeros_like(array)
    return (array - lowest) / value_range


def load_nrrd(path):
    """Load a NRRD file and return it as a min-max normalized float32 NumPy array.

    Parameters
    ----------
    path : str
        Path to the NRRD file; read with ``sitk.ReadImage``.

    Returns
    -------
    numpy.ndarray
        The array from ``sitk.GetArrayFromImage``, cast to float32 and scaled to [0, 1]
        (all zeros if the image is constant).
    """
    image = sitk.ReadImage(path)
    return _min_max_normalize(sitk.GetArrayFromImage(image))


def load_dicom(folder):
    """Load a DICOM series folder and return it as a min-max normalized float32 NumPy array.

    Parameters
    ----------
    folder : str
        Directory containing the DICOM series. The file list comes from
        ``GetGDCMSeriesFileNames`` called without a series ID.

    Returns
    -------
    numpy.ndarray
        The volume as float32, scaled to [0, 1] (all zeros if the volume is constant).

    Raises
    ------
    FileNotFoundError
        If no DICOM series files are found in ``folder``.
    """
    reader = sitk.ImageSeriesReader()
    series_files = reader.GetGDCMSeriesFileNames(folder)
    # An empty file list otherwise surfaces later as an obscure error from Execute().
    if not series_files:
        raise FileNotFoundError(f"No DICOM series found in {folder!r}")
    reader.SetFileNames(series_files)
    image = reader.Execute()
    return _min_max_normalize(sitk.GetArrayFromImage(image))

In [None]:
# Load the data
image_nrrd = load_nrrd(nrrd_file_path)
image_dicom = load_dicom(dicom_folder_path)

# The transform pipeline adds one channel axis and resizes three spatial axes, so each
# volume must be a plain 3D scalar array; fail here with a clear message otherwise.
assert image_nrrd.ndim == 3, f"NRRD volume must be 3D, got shape {image_nrrd.shape}"
assert image_dicom.ndim == 3, f"DICOM volume must be 3D, got shape {image_dicom.shape}"

In [None]:
# Create MONAI transforms.
# Each sample below is a dict ({"image": array}), so the pipeline uses MONAI's dictionary
# transforms (the "d"-suffixed variants), which transform only the listed keys. Plain array
# transforms such as AddChannel would receive the whole dict and fail.
TARGET_SPATIAL_SIZE = (128, 128, 128)  # Resize to a shape compatible with the UNet below

trans_3d = Compose([
    EnsureChannelFirstd(keys="image", channel_dim="no_channel"),  # (D, H, W) -> (1, D, H, W)
    ScaleIntensityd(keys="image"),                                # Normalize intensities to [0, 1]
    Resized(keys="image", spatial_size=TARGET_SPATIAL_SIZE),      # Resize every volume to one size
    ToTensord(keys="image"),                                      # Convert to PyTorch tensor
])


def _first_image_batch(dataset):
    """Return the "image" tensor of the first batch (batch_size=1) drawn from ``dataset``."""
    return next(iter(DataLoader(dataset, batch_size=1)))["image"]


# Wrap in dicts and use MONAI Dataset
nrrd_dict = [{"image": image_nrrd}]
dicom_dict = [{"image": image_dicom}]

dataset_nrrd = Dataset(data=nrrd_dict, transform=trans_3d)
dataset_dicom = Dataset(data=dicom_dict, transform=trans_3d)

tensor_nrrd = _first_image_batch(dataset_nrrd)  # Shape: (1, 1, 128, 128, 128)
tensor_dicom = _first_image_batch(dataset_dicom)

# Batch of one, single channel, resized spatial dims.
expected_shape = (1, 1, *TARGET_SPATIAL_SIZE)
assert tuple(tensor_nrrd.shape) == expected_shape, tuple(tensor_nrrd.shape)
assert tuple(tensor_dicom.shape) == expected_shape, tuple(tensor_dicom.shape)

In [None]:
# Download and load pretrained model from Hugging Face
# NOTE(review): confirm that this repository really ships a MONAI `UNet` state_dict named
# "model.pth" trained with exactly the hyperparameters below; otherwise the download or
# `load_state_dict` fails (missing file, missing/unexpected keys, or shape mismatches).
MODEL_REPO_ID = "wdika/SEG_UNet3D_BraTS2023AdultGlioma"
MODEL_FILENAME = "model.pth"
MODEL_IN_CHANNELS = 4   # Trained with 4 MRI modalities (e.g., T1, T1c, T2, FLAIR)
MODEL_OUT_CHANNELS = 3  # Typically 3 tumor regions in BraTS


def _match_in_channels(batch, in_channels):
    """Return ``batch`` with exactly ``in_channels`` channels along axis 1.

    A single-channel batch is replicated ``in_channels`` times so it can be fed to a
    multi-modality model; the prediction is then only a smoke test, not a meaningful
    segmentation. Any other mismatch raises ``ValueError`` instead of letting the first
    convolution fail with a shape error.
    """
    n_channels = batch.shape[1]
    if n_channels == in_channels:
        return batch
    if n_channels == 1:
        warnings.warn(
            f"Input has 1 channel but the model expects {in_channels}; replicating the "
            "single channel, so the output is not a meaningful segmentation."
        )
        return batch.repeat(1, in_channels, *([1] * (batch.ndim - 2)))
    raise ValueError(f"Input has {n_channels} channels but the model expects {in_channels}.")


# Download the model checkpoint
model_ckpt_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)

# Define the architecture matching the pretrained model
model = UNet(
    spatial_dims=3,  # replaces the deprecated `dimensions` argument
    in_channels=MODEL_IN_CHANNELS,
    out_channels=MODEL_OUT_CHANNELS,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)

# Load pretrained weights. The file comes from a remote repository, so `weights_only=True`
# (torch >= 1.13) restricts unpickling to tensors and plain containers instead of
# arbitrary Python objects.
checkpoint = torch.load(model_ckpt_path, map_location=torch.device("cpu"), weights_only=True)
# Accept both a bare state_dict and a {"state_dict": ...} wrapper.
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
    checkpoint = checkpoint["state_dict"]
model.load_state_dict(checkpoint)
model.eval()

# Inference
with torch.no_grad():
    output_nrrd = model(_match_in_channels(tensor_nrrd, MODEL_IN_CHANNELS))
    output_dicom = model(_match_in_channels(tensor_dicom, MODEL_IN_CHANNELS))

In [None]:
# Convert to numpy for visualization: first batch item, channel 0 of each tensor.
output_nrrd_np = output_nrrd.detach().cpu().numpy()[0, 0]
output_dicom_np = output_dicom.detach().cpu().numpy()[0, 0]
input_nrrd_np = tensor_nrrd.detach().cpu().numpy()[0, 0]
input_dicom_np = tensor_dicom.detach().cpu().numpy()[0, 0]

# Visualization of mid slices along the first spatial axis. The output panels reuse the
# input slice index, which assumes the UNet output keeps the input's spatial size.
slice_index_nrrd = input_nrrd_np.shape[0] // 2
slice_index_dicom = input_dicom_np.shape[0] // 2

# Top row: inputs; bottom row: channel-0 outputs for the same two volumes.
panels = [
    ("NRRD Input (Mid Slice)", input_nrrd_np[slice_index_nrrd]),
    ("DICOM Input (Mid Slice)", input_dicom_np[slice_index_dicom]),
    ("NRRD Output (Mid Slice)", output_nrrd_np[slice_index_nrrd]),
    ("DICOM Output (Mid Slice)", output_dicom_np[slice_index_dicom]),
]

fig, axes = plt.subplots(2, 2, figsize=(8, 8))
for ax, (title, image) in zip(axes.ravel(), panels):
    ax.set_title(title)
    ax.imshow(image, cmap='gray')
    ax.axis('off')

fig.tight_layout()
plt.show()