In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm
from typing import Callable
import SimpleITK as sitk
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
from monai import transforms as mt
import matplotlib.pyplot as plt
from monai.data import CacheDataset

## Data Loading

In [None]:
def preprocess(
    image: torch.Tensor,
    label: torch.Tensor | None = None,
    crop_size: tuple[int, ...] = (28, 28, 28)
) -> tuple[torch.Tensor, torch.Tensor]:
    
    # Normalize the image (Z-score normalization)
    image = (image - image.mean() ) / image.std()
    
    # Add a channel dimension to the image
    image = image.unsqueeze(0)

    # Random crop
    crop_origin = [0, 0, 0]
    for dim in range(3):  # Remember, image.shape = [Ch, X, Y, Z ]
        max_value = image.shape[dim+1] - crop_size[dim]
        crop_origin[dim] = torch.randint(0, max_value, (1,)).item()
        
    image = image[
        :,
        crop_origin[0]:crop_origin[0] + crop_size[0],
        crop_origin[1]:crop_origin[1] + crop_size[1],
        crop_origin[2]:crop_origin[2] + crop_size[2],
    ]
    
    # Add a channel dimension to the label
    if label is not None:
        label = label.unsqueeze(0)

        label = label[
            :,
            crop_origin[0]:crop_origin[0] + crop_size[0],
            crop_origin[1]:crop_origin[1] + crop_size[1],
            crop_origin[2]:crop_origin[2] + crop_size[2],
        ]
    
    return image, label

In [None]:
def collect_samples(root: Path, is_test: bool = False) -> list[dict[str, Path]]:
    """
    Collect FLAIR / T1 / WMH file paths for every subject directory under ``root``.

    Each immediate sub-directory of ``root`` is treated as one subject and must
    contain ``FLAIR.nii.gz``, ``T1.nii.gz`` and ``wmh.nii.gz``. Sub-directories
    missing any of these files are reported on stdout and skipped. Plain files
    directly under ``root`` are ignored.

    Parameters
    ----------
    root : Path
        Directory whose sub-directories are subjects.
    is_test : bool
        Currently unused; kept for backward compatibility. All three files are
        required regardless of its value.

    Returns
    -------
    list[dict[str, Path]]
        One dict per complete subject with keys ``"flair"``, ``"t1"`` and
        ``"WMH"``, in sorted path order so that slicing the list (e.g. into
        train/validation sets) is reproducible across machines.
    """
    samples = []
    # Path.iterdir() yields entries in a filesystem-dependent order; sorting
    # keeps downstream train/val splits stable.
    for subject_dir in sorted(root.iterdir()):
        if not subject_dir.is_dir():
            continue

        paths = {
            "flair": subject_dir / "FLAIR.nii.gz",
            "t1": subject_dir / "T1.nii.gz",
            "WMH": subject_dir / "wmh.nii.gz",
        }
        missing = [path.name for path in paths.values() if not path.exists()]
        if missing:
            print(f"Missing files in {subject_dir}: {', '.join(missing)}")
            continue

        samples.append(paths)

    return samples

class MedicalDecathlonDataset(Dataset):
    def __init__(self, samples: list[tuple[Path, ...]], test: bool = False) -> None:
        self.samples = samples 
        
    def __len__(self) -> int:
        return len(self.samples)
    
    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
        image_path, label_path = self.samples[idx]
        
        image = sitk.ReadImage(image_path)
        image_array = sitk.GetArrayFromImage(image)
        image = torch.tensor(image_array,dtype=torch.float32)  # Convert image to PyTorch tensor and cast it to float

        if label_path is None:
           image, _ = preprocess(image)
           return image

        label = sitk.ReadImage(label_path)
        label_array = sitk.GetArrayFromImage(label)
        label = torch.tensor(label_array,dtype=torch.float32)

        return preprocess(image, label)

In [None]:
# Keys must match the dicts returned by `collect_samples`; the MONAI dictionary
# transforms look each item up by these keys.
IMAGE_KEYS = ["flair", "t1"]
LABEL_KEY = "WMH"
ALL_KEYS = [*IMAGE_KEYS, LABEL_KEY]
CROP_SIZE = [28, 28, 28]
TRAIN_FRACTION = 0.8

train_transforms = mt.Compose([
    mt.LoadImaged(keys=ALL_KEYS),
    mt.EnsureChannelFirstd(keys=ALL_KEYS),
    # Only the intensity images are normalised; the label is left untouched.
    mt.NormalizeIntensityd(keys=IMAGE_KEYS),
    # random_size=False pins the crop to exactly CROP_SIZE. Some MONAI versions
    # default to random_size=True, which samples a larger random ROI and yields
    # variable-sized tensors that cannot be batched.
    mt.RandSpatialCropd(keys=ALL_KEYS, roi_size=CROP_SIZE, random_size=False),
    # You can add more transforms here! See
    # https://docs.monai.io/en/stable/transforms.html#dictionary-transforms
])

val_transforms = mt.Compose([
    mt.LoadImaged(keys=ALL_KEYS),
    mt.EnsureChannelFirstd(keys=ALL_KEYS),
    mt.NormalizeIntensityd(keys=IMAGE_KEYS),
])

# Relative to the notebook's working directory (as in the visualisation cell)
# instead of a machine-specific absolute path.
# NOTE(review): the original pointed at ".../Utrecht/0"; `collect_samples`
# treats each sub-directory of `root` as a subject, so the site directory is
# used here — confirm this is the intended data root.
root = Path(r"./WMH/WMH/Utrecht").resolve()
samples = collect_samples(root)
assert samples, f"No complete FLAIR/T1/WMH samples found under {root}"

n_train = int(len(samples) * TRAIN_FRACTION)
train_samples = samples[:n_train]
val_samples = samples[n_train:]

# You might get an error here, make sure you install the required extra dependencies
# https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
train_ds = CacheDataset(data=train_samples, transform=train_transforms)
val_ds = CacheDataset(data=val_samples, transform=val_transforms)

In [None]:
def load_nifti_image(image_path: Path):
    """Read a NIfTI file with SimpleITK and return its voxel data as a NumPy array.

    The path is converted to ``str`` before being handed to SimpleITK.
    """
    sitk_image = sitk.ReadImage(str(image_path))
    voxels = sitk.GetArrayFromImage(sitk_image)
    return voxels

def visualize_sample(sample, slice_index: int | None = None):
    """Show one slice of a sample's FLAIR, T1 and WMH volumes side by side.

    Parameters
    ----------
    sample : dict
        Sample as returned by ``collect_samples`` (keys "flair", "t1", "WMH").
    slice_index : int | None
        Index along the first array axis to display. Defaults to the middle
        slice of the FLAIR volume (the previous fixed behaviour).

    Raises
    ------
    IndexError
        If ``slice_index`` is out of range for any of the three volumes.
    """
    # (sample key, panel title), in display order.
    panels = [
        ("flair", "FLAIR Image"),
        ("t1", "T1 Image"),
        ("WMH", "WMH"),
    ]
    volumes = [load_nifti_image(sample[key]) for key, _ in panels]

    if slice_index is None:
        slice_index = volumes[0].shape[0] // 2

    # The volumes may not share a shape; fail with a clear message rather than
    # a bare numpy IndexError from the middle of the plotting loop.
    for (key, _), volume in zip(panels, volumes):
        if not -volume.shape[0] <= slice_index < volume.shape[0]:
            raise IndexError(
                f"slice_index {slice_index} out of range for '{key}' "
                f"with {volume.shape[0]} slices"
            )

    fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
    for ax, volume, (_, title) in zip(axes, volumes, panels):
        ax.imshow(volume[slice_index, :, :], cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    plt.show()

# Example usage: gather subjects from all three sites, then preview the first.
root, root2, root3 = (
    Path(f"./WMH/WMH/{site}").resolve()
    for site in ("Amsterdam", "Singapore", "Utrecht")
)
samples = [
    sample
    for site_root in (root, root2, root3)
    for sample in collect_samples(site_root)
]

if not samples:
    print("No samples found to visualize.")
else:
    visualize_sample(samples[0])  # preview the first subject
len(samples)