In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from typing import Callable
import SimpleITK as sitk
from torch.utils.data import DataLoader, Dataset
from os.path import abspath
from pathlib import Path
from monai import transforms as mt
import matplotlib.pyplot as plt
from monai.data import CacheDataset

  from .autonotebook import tqdm as notebook_tqdm


## Data Loading

In [2]:
def preprocess(
    image: torch.Tensor,
    label: torch.Tensor | None = None,
    crop_size: tuple[int, ...] = (28, 28, 28)
) -> tuple[torch.Tensor, torch.Tensor]:
    
    # Normalize the image (Z-score normalization)
    image = (image - image.mean() ) / image.std()
    
    # Add a channel dimension to the image
    image = image.unsqueeze(0)

    # Random crop
    crop_origin = [0, 0, 0]
    for dim in range(3):  # Remember, image.shape = [Ch, X, Y, Z ]
        max_value = image.shape[dim+1] - crop_size[dim]
        crop_origin[dim] = torch.randint(0, max_value, (1,)).item()
        
    image = image[
        :,
        crop_origin[0]:crop_origin[0] + crop_size[0],
        crop_origin[1]:crop_origin[1] + crop_size[1],
        crop_origin[2]:crop_origin[2] + crop_size[2],
    ]
    
    # Add a channel dimension to the label
    if label is not None:
        label = label.unsqueeze(0)

        label = label[
            :,
            crop_origin[0]:crop_origin[0] + crop_size[0],
            crop_origin[1]:crop_origin[1] + crop_size[1],
            crop_origin[2]:crop_origin[2] + crop_size[2],
        ]
    
    return image, label

In [3]:
def collect_samples(root: Path, is_test: bool = False) -> list[dict[str, Path]]:
    """
    Collects the samples from the WMH dataset.

    Parameters
    ----------
    root : Path
        The root directory of the dataset.
    is_test : bool
        Whether to collect the test set or the training set.

    Returns
    -------
    list[dict[str, Path]]
        A list of dictionaries containing the image and label paths.
    """
  
    samples = []
    # Iterate through each subject directory
    for subject_dir in root.iterdir():
        if subject_dir.is_dir():  # Check if it's a directory
            flair_path = subject_dir / "FLAIR.nii.gz"
            t1_path = subject_dir / "T1.nii.gz"
            WMH_path = subject_dir / "wmh.nii.gz"
            
            # Check if all required files exist
            if flair_path.exists() and t1_path.exists() and WMH_path.exists():
                sample = {
                    "flair": flair_path,
                    "t1": t1_path,
                    "WMH": WMH_path
                }
                samples.append(sample)
            else:
                print(f"Missing files in {subject_dir}: "
                      f"{'FLAIR' if not flair_path.exists() else ''} "
                      f"{'T1' if not t1_path.exists() else ''} "
                      f"{'wmh' if not WMH_path.exists() else ''}")

    return samples

In [8]:
# TODO - What does this do? 
# is this data augmentation?
# why is the transform different for train and validation?

train_transforms = mt.Compose([
    mt.LoadImaged(keys=["image", "label"]),
    mt.EnsureChannelFirstd(keys=["image", "label"]),
    mt.NormalizeIntensityd(keys=["image"]),
    mt.RandSpatialCropd(keys=["image", "label"], roi_size=[28, 28, 28]),
    # You can add more transforms here! See
    # https://docs.monai.io/en/stable/transforms.html#dictionary-transforms
])

val_transforms = mt.Compose([
    mt.LoadImaged(keys=["image", "label"]),
    mt.EnsureChannelFirstd(keys=["image", "label"]),
    mt.NormalizeIntensityd(keys=["image"]),
])


# Actually running loading 
# TODO - Create a data folder and put the wmh data there in accordance with this path
root = Path(abspath(""))
wmh_path = root.joinpath("data/WMH")
sample_paths = [item for item in wmh_path.iterdir() if item.is_dir()]
samples = []

for sample_path in sample_paths:
    samples.append(collect_samples(sample_path))

# Train and test split
# TODO - Implement shuffeling of the dataset 
# TODO - also make a test dataset
# Maybe think about balancing the dataset aswell between singapore, amsterdam and utrecht. 
# if metadata about patient health status, gender ect is availible we should maybe also balance based on that
train_samples = samples[:int(len(samples) * 0.8)]
val_samples = samples[int(len(samples) * 0.8):]

# You might get an error here, make sure you install the required extra dependencies
# https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
train_ds = CacheDataset(data=train_samples, transform= train_transforms)
val_ds   = CacheDataset(data=val_samples,   transform= val_transforms)

Loading dataset:   0%|          | 0/2 [00:00<?, ?it/s]


RuntimeError: applying transform <monai.transforms.io.dictionary.LoadImaged object at 0x167517490>

In [6]:
def load_nifti_image(image_path: Path):
    """Load a NIfTI image using SimpleITK."""
    image = sitk.ReadImage(str(image_path))
    return sitk.GetArrayFromImage(image)  # Convert to NumPy array

def visualize_sample(sample):
    """Visualize FLAIR, T1, and WMH images from a sample."""
    flair_image = load_nifti_image(sample['flair'])
    t1_image = load_nifti_image(sample['t1'])
    label_image = load_nifti_image(sample['WMH'])

    # Select a slice to visualize (e.g., the middle slice)
    slice_index = flair_image.shape[0] // 2

    # Create subplots
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # FLAIR Image
    axes[0].imshow(flair_image[slice_index, :, :], cmap='gray')
    axes[0].set_title('FLAIR Image')
    axes[0].axis('off')

    # T1 Image
    axes[1].imshow(t1_image[slice_index, :, :], cmap='gray')
    axes[1].set_title('T1 Image')
    axes[1].axis('off')

    # WMH Label
    axes[2].imshow(label_image[slice_index, :, :], cmap='gray')
    axes[2].set_title('WMH')
    axes[2].axis('off')

    plt.show()

# Example usage
# root = Path(r"./WMH/WMH/Amsterdam").resolve()
# root2 = Path(r"./WMH/WMH/Singapore").resolve()
# root3 = Path(r"./WMH/WMH/Utrecht").resolve()
# samples = collect_samples(root) + collect_samples(root2) + collect_samples(root3)

# Visualize the first sample if available
if samples:
    visualize_sample(samples[0])  # Visualize the first sample
else:
    print("No samples found to visualize.")
len(samples)

TypeError: list indices must be integers or slices, not str

## Model setup

In [None]:
# Class holding the double convolution
class DoubleConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        use_norm: bool = True,
    ) -> None:
        super().__init__()
        
        # For now we are just considering 2d input
        # Thus expected input dimensions is [batch_size, channels, height, width] 
        # Or when not using batches [channels, height, width]
        # Or in the convention of pytorch: (N, C, H, W) or (C, H, W)

        # nn.Identity just return it's input so it's used as a replacement for normalization if normalization is not used
        # TO DO: find out what batchnorm does exactly
        # TO DO: Find out how exactly relu works
        conv = nn.Conv2d
        norm = nn.BatchNorm2d if use_norm else nn.Identity 
        activation_function = nn.ReLU

        layers = [
            conv(in_channels, out_channels, 3),
            norm(out_channels),
            activation_function(inplace=True),
            conv(out_channels, out_channels, 3),
            norm(out_channels),
            activation_function(inplace=True)    
        ]

        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)

In [None]:
# Class holding the downsampling
class DownSample(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int
    ) -> None:

        super().__init__()

        # with kernel_size 2 and stride 2 the dimensions will be halved
        self.downsample = nn.MaxPool2d(
            kernel_size=2,
            stride=2
        )

        def forward(self, x):
            return self.downsample(x)

In [None]:
# Class holding the upsampling
class UpSample(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int
    ) -> None:

        super().__init__()

        # with kernel_size 2 and stride 2 the dimensions will be doubled
        self.upsample = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2
        )

        def forward(self, x):
            return self.upsample(x)

In [None]:
# Class holding the entire model:
class SegmentUnet(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        depth: int = 3,
        first_layer_channel_count: int = 64,
        use_norm: bool = True
    ) -> None:
        super().__init__()
        
        # Define the input layer
        layers = [
            DoubleConv(
                in_channels, 
                first_layer_channel_count, 
                use_norm
            ),
            DownSample(
                in_channels=first_layer_channel_count,
                out_channels=first_layer_channel_count*2
            )
        ]

        current_channels = first_layer_channel_count*2
        # Encoder path
        for _ in range(depth-1): # minus one to account for input layer
            layers.append(DoubleConv(in_channels=current_channels, out_channels=current_channels, use_norm=use_norm))
            layers.append(DownSample(in_channels=current_channels, out_channels=out_channels*2))
            current_channels = current_channels*2 # double channel count each encoder block
        
        # Middle layer
        layers.append(DoubleConv(in_channels=current_channels, out_channels=current_channels, use_norm=use_norm))
        
        # Decoder path
        for _ in range(depth-1): # minus one to account for output layer
            # Keep in mind here that the double conv layers here gets both the output  
            # of upsample concatanated with the skip conncention.
            # So number of channels is doubled
            # We control this by concatanation in the second dimension 
            # the convention of pytorch: (N, C, H, W) or (C, H, W)
            
            layers.append(UpSample(in_channels=current_channels*2, out_channels=current_channels))
            layers.append(DoubleConv(in_channels=current_channels, out_channels=current_channels, use_norm=use_norm))
            current_channels = current_channels // 2

        # Output layer
        layers.append(UpSample(in_channels=current_channels*2, out_channels=current_channels))
        layers.append(DoubleConv(in_channels=current_channels, out_channels=out_channels, use_norm=use_norm))

        # Concatanate layers together
        self.layers = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        pass
        #for layer in self.layers[0:depth]

## Model training

In [None]:
class ModelTraining():
    def __init__(
        self,
        n_epochs: int,
        batch_size: int,
        learning_rate: float,
        loss_fn: Callable,
        samples: list
    ) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SegmentUNet(in_channels=1, out_channels=1, depth=3).to(self.device)
        self.loss_fn = loss_fn
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        self.train_ds, self.train_dl, self.val_ds, self.val_dl = self._split_samples(samples)
    
    def run_training_loop(self):
        for epoch in (prog_bar := tqdm(range(self.n_epochs), desc="Training", unit="epoch", total=self.n_epochs, position=0)):
            prog_bar.set_description(f"Training Loop")
            train_losses = self._loop_train()
            prog_bar.set_postfix({"Training loss": sum(train_losses) / len(train_losses)})
            prog_bar.set_description(f"Validation Loop")
            val_losses = self._loop_validate()
            prog_bar.set_postfix({"Training loss": sum(train_losses) / len(train_losses), "Validation loss": sum(val_losses) / len(val_losses)})
        
    def _loop_train(self):
        self.model.train()
        train_losses = []

        for i, (image, label) in tqdm(enumerate(self.train_dl), total=len(self.train_dl), desc="Training", unit="batch", position=1, leave=False):
            
            image, label = image.to(self.device), label.to(self.device)

            self.optimizer.zero_grad() # Clear gradients
            output = self.model(...) # Model forward pass
            loss = loss_fn(output, label)  # Compute loss
            loss.backward()  # Backpropagate loss
            self.optimizer.step()  # Update model weights

            train_losses.append(loss.item()) # Append training loss for this batch

        return train_losses

    def _loop_validate(self):
        self.model.eval() # We set the model in evaluation mode
        val_losses = []
        for i, (image, label) in tqdm(enumerate(self.val_dl), total=len(self.val_dl), desc="Validation", unit="batch", position=1, leave=False):
            image, label = image.to(self.device), label.to(self.device)

            with torch.no_grad():
                output = self.model(image)
                loss = loss_fn(output, label)
            
            val_losses.append(loss.item())

        return val_losses

    def _split_samples(self, samples): #TODO - adjust this function to our data
        train_samples = samples[0 : int(len(samples) * 0.8)]
        val_samples = samples[int(len(samples) * 0.8): ]

        train_ds = MedicalDecathlonDataset(train_samples) #TODO - change to our own dataset class
        train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True) #TODO - change to our DataLoader class

        val_ds = MedicalDecathlonDataset(val_samples) #TODO - change to our own dataset class
        val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False) #TODO - change to our DataLoader class
        return train_ds, train_dl, val_ds, val_dl


In [None]:
loss_functions = [nn.BCEWithLogitsLoss()] #TODO - add different loss functions to evaluate
for loss_fn in loss_functions:
    Model_training = ModelTraining(
        n_epochs = 10,
        batch_size = 4,
        learning_rate = 1e-3,
        loss_fn = loss_fn,
        samples = []
        )

    Model_training.run_training_loop()