# SwinUNETR Unet

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Install MONAI and other dependencies
!pip install monai==0.9.1
!pip install nibabel
!pip install SimpleITK

# Import necessary libraries
import os
import glob
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import nibabel as nib
import SimpleITK as sitk

import monai
from monai.data import CacheDataset, DataLoader
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityd,
    RandSpatialCropd, RandFlipd, RandRotate90d, SpatialCropd,
    EnsureTyped, MapTransform, Transform
)
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.networks.nets import SwinUNETR  # Changed from AttentionUnet to SwinUNETR
from monai.utils import set_determinism

from google.colab import drive
drive.mount("/content/drive")

# Seed MONAI's determinism helper (seed 0) so repeated runs start from the same state.
set_determinism(seed=0)

# Root folder of the dataset on the mounted Google Drive; get_data_dicts() joins it
# with the split sub-folder names ('Training_Set', 'Validation_Set') used below.
data_dir = '/content/drive/MyDrive/Dataset_Final/'  # Update this path if different

# Function to load data dictionaries
def get_data_dicts(data_dir, set_name):
    """Build the list of ``{'image', 'label'}`` path dictionaries for one split.

    Every entry matching ``<data_dir>/<set_name>/IBSR_*`` is treated as one
    subject whose volume is ``<subject>.nii.gz`` and whose segmentation is
    ``<subject>_seg.nii.gz``. The files are not checked for existence here.

    Args:
        data_dir: Root directory of the dataset.
        set_name: Name of the split sub-folder, e.g. ``'Training_Set'``.

    Returns:
        A list of dicts with keys ``'image'`` and ``'label'``, ordered by
        subject directory name; an empty list when nothing matches.
    """
    # glob() returns matches in arbitrary, filesystem-dependent order. Sorting
    # keeps the sample order (and any per-index output naming) reproducible.
    subject_dirs = sorted(glob.glob(os.path.join(data_dir, set_name, 'IBSR_*')))

    data_dicts = []
    for subject_dir in subject_dirs:
        subject_name = os.path.basename(subject_dir)
        data_dicts.append({
            'image': os.path.join(subject_dir, f'{subject_name}.nii.gz'),
            'label': os.path.join(subject_dir, f'{subject_name}_seg.nii.gz'),
        })

    # Debug: print the shape of the first sample. The shape is taken from the
    # image header so the voxel data is not decoded just to report dimensions.
    if data_dicts:
        first = data_dicts[0]
        print(f"Sample image shape: {nib.load(first['image']).shape}")
        print(f"Sample label shape: {nib.load(first['label']).shape}")

    return data_dicts

# Collect the path dictionaries for both splits: training first, then validation.
train_files, val_files = (
    get_data_dicts(data_dir, split_name)
    for split_name in ('Training_Set', 'Validation_Set')
)

# Custom Transform: CenterSpatialCropd
class CenterSpatialCropd(Transform):
    """Dictionary transform that crops each listed entry around its spatial centre.

    For every key, the centre is computed from the entry's own shape (all axes
    after the first, which is assumed to be the channel axis) and the crop is
    delegated to ``SpatialCropd`` with that centre and ``roi_size``.
    """

    def __init__(self, keys, roi_size):
        # keys: dictionary entries to crop; roi_size: size of the region to keep.
        self.roi_size = roi_size
        self.keys = keys

    def __call__(self, data):
        """Return a copy of ``data`` with every configured key centre-cropped."""
        cropped = dict(data)
        for name in self.keys:
            # Skip axis 0 (assumed channel) and take the midpoint of each spatial axis.
            midpoint = [extent // 2 for extent in cropped[name].shape[1:]]
            cropper = SpatialCropd(keys=[name], roi_size=self.roi_size, roi_center=midpoint)
            cropped = cropper(cropped)
        return cropped

# Training pipeline. The first three steps are deterministic (load, move the channel
# axis first, rescale intensities); the remaining steps add random augmentation.
train_transforms = Compose([
    LoadImaged(keys=['image', 'label']),
    EnsureChannelFirstd(keys=['image', 'label']),
    # Only the image is rescaled, so the label values are left untouched.
    ScaleIntensityd(keys=['image']),
    # Fixed-size 96^3 patch (random_size=False); the size matches img_size given to the model.
    RandSpatialCropd(keys=['image', 'label'], roi_size=[96, 96, 96], random_size=False),
    # Image and label are listed together so both receive the same flip / rotation.
    RandFlipd(keys=['image', 'label'], prob=0.5, spatial_axis=0),
    RandRotate90d(keys=['image', 'label'], prob=0.5, max_k=3),
    EnsureTyped(keys=['image', 'label']),
])

# Validation pipeline: same loading and scaling, but no random steps. The crop is
# taken by the custom CenterSpatialCropd defined above so it is repeatable.
val_transforms = Compose([
    LoadImaged(keys=['image', 'label']),
    EnsureChannelFirstd(keys=['image', 'label']),
    ScaleIntensityd(keys=['image']),
    CenterSpatialCropd(keys=['image', 'label'], roi_size=[96, 96, 96]),  # Deterministic center crop
    EnsureTyped(keys=['image', 'label']),
])

# Worker count shared by the CacheDataset constructors and the DataLoaders below.
num_workers = 4  # Adjust based on your system's capabilities

# Training dataset and loader: one sample per batch, reshuffled by the loader.
train_ds = CacheDataset(
    data=train_files,
    transform=train_transforms,
    cache_rate=1.0,  # Cache all data
    num_workers=num_workers
)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=num_workers)

# Validation dataset and loader: one sample per batch, fixed order (shuffle=False).
val_ds = CacheDataset(
    data=val_files,
    transform=val_transforms,
    cache_rate=1.0,  # Cache all data
    num_workers=num_workers
)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=num_workers)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build MONAI's SwinUNETR and move it to the selected device. A TypeError raised
# while doing so is reported with a hint and then re-raised.
try:
    model = SwinUNETR(
        img_size=[96, 96, 96],        # Should match the ROI size used in transforms
        in_channels=1,
        out_channels=4,               # Number of classes
        feature_size=48,              # Base number of features
        use_checkpoint=False,         # Disable checkpointing if unsupported
    )
    model = model.to(device)
    print("SwinUNETR model instantiated successfully!")
except TypeError as e:
    print("Error during model instantiation: {}".format(e))
    print("Please refer to MONAI v0.9.1 documentation for supported parameters.")
    raise e  # Propagate the failure after informing the user

# Optionally, load pre-trained weights if available
# model.load_state_dict(torch.load('/path/to/pretrained/model.pth'))

# Dice loss configured to one-hot encode the target (to_onehot_y=True) and to apply
# softmax to the network output (softmax=True); Adam optimiser with lr 1e-4.
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Validation metric. Unlike the loss above, DiceMetric is not given any
# one-hot / softmax options here, so the validation code prepares both the
# prediction and the label tensors itself before calling it.
dice_metric = DiceMetric(
    include_background=True,
    reduction="mean",
    get_not_nans=False,
)

# Training loop
max_epochs = 2000
val_interval = 2  # Do validation every 2 epochs
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []  # mean training loss of every epoch
metric_values = []      # mean validation Dice of every validation round

for epoch in range(max_epochs):
    print('-' * 10)
    print('Epoch {}/{}'.format(epoch + 1, max_epochs))

    # ---- training pass over all batches ----
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data['image'].to(device), batch_data['label'].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print('{}/{}, train_loss: {:.4f}'.format(step, len(train_loader), loss.item()))
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print('Epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss))

    # ---- periodic validation and checkpointing ----
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_images = val_data['image'].to(device)
                val_labels = val_data['label'].to(device)
                val_outputs = model(val_images)

                # DiceMetric expects binarised one-hot predictions. Passing raw
                # softmax probabilities yields a "soft" Dice that does not match
                # the Dice of the final segmentation, so take the arg-max class
                # per voxel and one-hot encode it.
                val_pred_onehot = F.one_hot(torch.argmax(val_outputs, dim=1), num_classes=4)
                val_pred_onehot = val_pred_onehot.permute(0, 4, 1, 2, 3).float()  # channel axis second

                # One-hot encode labels the same way: drop the channel axis,
                # encode, then move the class axis to position 1.
                val_labels_onehot = F.one_hot(val_labels.squeeze(1).long(), num_classes=4)
                val_labels_onehot = val_labels_onehot.permute(0, 4, 1, 2, 3).float()

                # Accumulate this sample in the metric's internal buffer
                dice_metric(y_pred=val_pred_onehot, y=val_labels_onehot)

            # Reduce the buffered scores to one mean Dice, then clear the buffer
            metric = dice_metric.aggregate().item()
            dice_metric.reset()
            metric_values.append(metric)

            # Keep only the weights of the best validation round so far
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), 'best_metric_model_SWIN.pth')
                print('Saved new best metric model')
            print('Current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}'.format(
                epoch + 1, metric, best_metric, best_metric_epoch))

print('Training completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch))

# Load the best model for evaluation. The file name must be the one the training
# loop passes to torch.save() ('best_metric_model_SWIN.pth'); loading
# 'best_metric_model_SWINUNETR.pth' failed with FileNotFoundError because no
# file of that name is ever written.
model.load_state_dict(torch.load('best_metric_model_SWIN.pth'))

# Evaluation functions
def compute_dice_coefficient(in1, in2, label=1):
    """Return the Dice overlap of one label between two label arrays.

    Args:
        in1: First label array (compared element-wise with ``label``).
        in2: Second label array, same shape as ``in1``.
        label: Label value whose masks are compared.

    Returns:
        ``2 * |A ∩ B| / (|A| + |B|)`` for the two binary masks, or NaN when the
        label is absent from both arrays (the overlap is undefined).
    """
    mask1 = in1 == label
    mask2 = in2 == label
    volumes = mask1.sum() + mask2.sum()
    if volumes == 0:
        # np.nan is the spelling available on every NumPy release; the np.NaN
        # alias was removed in NumPy 2.0.
        return np.nan
    intersection = np.logical_and(mask1, mask2).sum()
    return 2. * intersection / volumes

def compute_hausdorff_distance(in1, in2, label=1):
    """Return the Hausdorff distance between the masks of one label.

    Args:
        in1: First label array.
        in2: Second label array, same shape as ``in1``.
        label: Label value whose masks are compared.

    Returns:
        The distance reported by SimpleITK's ``HausdorffDistanceImageFilter``,
        or NaN when the label is missing from either array. The images are
        built from bare arrays with no spacing set, so the value is presumably
        expressed in voxel units.
    """
    mask1 = (in1 == label).astype(np.uint8)
    mask2 = (in2 == label).astype(np.uint8)

    # The distance to an empty set is undefined and the underlying ITK filter
    # raises for a mask without foreground. Callers already filter NaN results,
    # so report "undefined" that way instead of aborting the whole evaluation.
    if not mask1.any() or not mask2.any():
        return np.nan

    hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
    hausdorff_distance_filter.Execute(
        sitk.GetImageFromArray(mask1),
        sitk.GetImageFromArray(mask2),
    )
    return hausdorff_distance_filter.GetHausdorffDistance()

def compute_volumetric_difference(in1, in2, label=1):
    """Return the relative volume difference of one label between two arrays.

    Args:
        in1: First label array.
        in2: Second label array.
        label: Label value whose voxel counts are compared.

    Returns:
        ``|V1 - V2| / (V1 + V2)`` where ``V`` is the number of voxels equal to
        ``label``, or NaN when the label is absent from both arrays.
    """
    vol1 = (in1 == label).sum()
    vol2 = (in2 == label).sum()
    total = vol1 + vol2
    if total == 0:
        # np.nan works on every NumPy release; np.NaN was removed in NumPy 2.0.
        return np.nan
    return abs(vol1 - vol2) / total

# Evaluate on validation data: one list of per-sample scores per class (4 classes)
model.eval()
dice_scores, hausdorff_distances, volumetric_differences = (
    {class_id: [] for class_id in range(4)} for _ in range(3)
)

with torch.no_grad():
    for val_data in val_loader:
        val_images = val_data['image'].to(device)
        val_labels = val_data['label'].cpu().numpy()
        outputs = model(val_images)
        # Most probable class per voxel; the class axis (dim 1) is consumed by argmax
        outputs = torch.argmax(F.softmax(outputs, dim=1), dim=1).cpu().numpy()

        # Batch size is 1: take the only prediction and the only label channel
        predicted_volume = outputs[0]
        reference_volume = val_labels[0][0]

        for label in range(4):  # For each class
            dice = compute_dice_coefficient(predicted_volume, reference_volume, label=label)
            hd = compute_hausdorff_distance(predicted_volume, reference_volume, label=label)
            vd = compute_volumetric_difference(predicted_volume, reference_volume, label=label)
            # NaN marks an undefined score; such values are left out of the averages
            for score, per_class in (
                (dice, dice_scores),
                (hd, hausdorff_distances),
                (vd, volumetric_differences),
            ):
                if not np.isnan(score):
                    per_class[label].append(score)

# Compute average metrics per class. A class with no valid scores is reported as
# NaN; np.nan is used because the np.NaN alias was removed in NumPy 2.0.
for label in range(4):
    avg_dice = np.mean(dice_scores[label]) if dice_scores[label] else np.nan
    avg_hd = np.mean(hausdorff_distances[label]) if hausdorff_distances[label] else np.nan
    avg_vd = np.mean(volumetric_differences[label]) if volumetric_differences[label] else np.nan
    print(f"Class {label}:")
    print(f"\tDice Coefficient = {avg_dice:.4f}")
    print(f"\tHausdorff Distance = {avg_hd:.4f}")
    print(f"\tVolumetric Difference = {avg_vd:.4f}")


Collecting monai==0.9.1
  Downloading monai-0.9.1-202207251608-py3-none-any.whl.metadata (7.5 kB)
Downloading monai-0.9.1-202207251608-py3-none-any.whl (990 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 990.7/990.7 kB 13.1 MB/s eta 0:00:00
Installing collected packages: monai
Successfully installed monai-0.9.1
Collecting SimpleITK
  Downloading SimpleITK-2.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.9 kB)
Downloading SimpleITK-2.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.4 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 52.4/52.4 MB 39.5 MB/s eta 0:00:00
Installing collected packages: SimpleITK
Successfully installed SimpleITK-2.4.0
Mounted at /content/drive
Sample image shape: (256, 128, 256, 1)
Sample label shape: (256, 128, 256, 1)
Sample image shape: (256, 128, 256, 1)
Sample label shape: (256, 128, 256, 1)


Loading dataset: 100%|██████████| 10/10 [00:02<00:00,  3.51it/s]
Loading dataset: 100%|██████████| 5/5 [00:01<00:00,  4.09it/s]


SwinUNETR model instantiated successfully!
----------
Epoch 1/2000


  ret = func(*args, **kwargs)
  ret = func(*args, **kwargs)
  ret = func(*args, **kwargs)
  ret = func(*args, **kwargs)
  t = cls([], dtype=storage.dtype, device=storage.device)


1/10, train_loss: 0.8055
2/10, train_loss: 0.7625
3/10, train_loss: 0.7495
4/10, train_loss: 0.6764
5/10, train_loss: 0.6809
6/10, train_loss: 0.6615
7/10, train_loss: 0.7765
8/10, train_loss: 0.6154
9/10, train_loss: 0.8063
10/10, train_loss: 0.5544
Epoch 1 average loss: 0.7089
----------
Epoch 2/2000
1/10, train_loss: 0.7974
2/10, train_loss: 0.5136
3/10, train_loss: 0.8641
4/10, train_loss: 0.6691
5/10, train_loss: 0.4902
6/10, train_loss: 0.4826
7/10, train_loss: 0.6213
8/10, train_loss: 0.5456
9/10, train_loss: 0.5369
10/10, train_loss: 0.6074
Epoch 2 average loss: 0.6128




Streaming output truncated to the last 5000 lines.
Current epoch: 1630 current mean dice: 0.9175 best mean dice: 0.9269 at epoch 1436
----------
Epoch 1631/2000
1/10, train_loss: 0.0552
2/10, train_loss: 0.0495
3/10, train_loss: 0.0612
4/10, train_loss: 0.0952
5/10, train_loss: 0.2849
6/10, train_loss: 0.0568
7/10, train_loss: 0.1108
8/10, train_loss: 0.0691
9/10, train_loss: 0.3426
10/10, train_loss: 0.0856
Epoch 1631 average loss: 0.1211
----------
Epoch 1632/2000
1/10, train_loss: 0.0512
2/10, train_loss: 0.2730
3/10, train_loss: 0.0656
4/10, train_loss: 0.0720
5/10, train_loss: 0.0442
6/10, train_loss: 0.0612
7/10, train_loss: 0.0918
8/10, train_loss: 0.0662
9/10, train_loss: 0.0727
10/10, train_loss: 0.0860
Epoch 1632 average loss: 0.0884
Current epoch: 1632 current mean dice: 0.9214 best mean dice: 0.9269 at epoch 1436
----------
Epoch 1633/2000
1/10, train_loss: 0.0581
2/10, train_loss: 0.0659
3/10, train_loss: 0.3072
4/10, train_loss: 0.0626
5/10, train_loss: 0.05

  model.load_state_dict(torch.load('best_metric_model_SWINUNETR.pth'))


FileNotFoundError: [Errno 2] No such file or directory: 'best_metric_model_SWINUNETR.pth'

In [None]:
# Install MONAI and other dependencies
!pip install monai==0.9.1
!pip install nibabel
!pip install SimpleITK

# Import necessary libraries
import os
import glob
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import nibabel as nib
import SimpleITK as sitk

import monai
from monai.data import CacheDataset, DataLoader
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityd,
    SpatialCropd, EnsureTyped, Transform
)
from monai.networks.nets import SwinUNETR
from monai.utils import set_determinism
from monai.visualize import blend_images

from google.colab import drive
drive.mount("/content/drive")

# Seed MONAI's determinism helper (seed 0) so repeated runs start from the same state.
set_determinism(seed=0)

# Root folder of the dataset on the mounted Google Drive; the split sub-folder
# names are appended by get_data_dicts().
data_dir = '/content/drive/MyDrive/Dataset_Final/'  # Update this path if different

# Function to load data dictionaries
def get_data_dicts(data_dir, set_name):
    """Return the ``{'image', 'label'}`` path dictionaries of one dataset split.

    Each entry matching ``<data_dir>/<set_name>/IBSR_*`` is one subject; its
    volume is expected at ``<subject>.nii.gz`` and its segmentation at
    ``<subject>_seg.nii.gz`` inside that directory (existence is not checked).

    Args:
        data_dir: Root directory of the dataset.
        set_name: Split sub-folder name, e.g. ``'Validation_Set'``.

    Returns:
        List of dicts with keys ``'image'`` and ``'label'``, sorted by subject
        directory; empty when no subject matches.
    """
    data_pattern = os.path.join(data_dir, set_name, 'IBSR_*')
    # glob() gives no ordering guarantee; sort so that the position of a subject
    # in the list (used later to name output files) is stable between runs.
    subject_dirs = sorted(glob.glob(data_pattern))

    return [
        {
            'image': os.path.join(subject_dir, '{}.nii.gz'.format(os.path.basename(subject_dir))),
            'label': os.path.join(subject_dir, '{}_seg.nii.gz'.format(os.path.basename(subject_dir))),
        }
        for subject_dir in subject_dirs
    ]

# Path dictionaries for the two splits (training is listed first, then validation)
train_files = get_data_dicts(data_dir=data_dir, set_name='Training_Set')
val_files = get_data_dicts(data_dir=data_dir, set_name='Validation_Set')

# Custom Transform: CenterSpatialCropd
class CenterSpatialCropd(Transform):
    """Crop every configured dictionary entry to ``roi_size`` around its centre.

    The centre is derived per entry from its shape, ignoring the first axis
    (assumed to be the channel axis), and the cropping itself is performed by
    ``SpatialCropd``.
    """

    def __init__(self, keys, roi_size):
        # Names of the entries to crop and the size of the region to keep.
        self.keys, self.roi_size = keys, roi_size

    def __call__(self, data):
        """Return a copy of ``data`` in which each key has been centre-cropped."""
        sample = dict(data)
        for entry_name in self.keys:
            spatial_extents = sample[entry_name].shape[1:]  # Assuming (C, D, H, W)
            roi_center = [extent // 2 for extent in spatial_extents]
            sample = SpatialCropd(
                keys=[entry_name],
                roi_size=self.roi_size,
                roi_center=roi_center,
            )(sample)
        return sample

# Validation pipeline: load, move the channel axis first, rescale the image
# intensities (labels are not rescaled), take a repeatable centre crop with the
# custom CenterSpatialCropd defined above, then convert with EnsureTyped.
val_transforms = Compose([
    LoadImaged(keys=['image', 'label']),
    EnsureChannelFirstd(keys=['image', 'label']),
    ScaleIntensityd(keys=['image']),
    # 96^3 matches the img_size the model below is built with.
    CenterSpatialCropd(keys=['image', 'label'], roi_size=[96, 96, 96]),
    EnsureTyped(keys=['image', 'label']),
])

# Worker count shared by the CacheDataset constructor and the DataLoader below.
num_workers = 4  # Adjust based on your system's capabilities

# Validation dataset (cache_rate=1.0 caches every item) and its loader:
# one sample per batch, fixed order.
val_ds = CacheDataset(
    data=val_files,
    transform=val_transforms,
    cache_rate=1.0,
    num_workers=num_workers
)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=num_workers)

# Run on the GPU when CUDA is available, otherwise on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build MONAI's SwinUNETR with the same configuration used for training,
# then move it to the selected device
model = SwinUNETR(
    img_size=[96, 96, 96],
    in_channels=1,
    out_channels=4,
    feature_size=48,
    use_checkpoint=False,
)
model = model.to(device)
print("SwinUNETR model instantiated successfully!")

# Load the trained model weights; stop early with a clear error when the file is missing
model_weights_path = '/content/drive/MyDrive/Dataset_Final/best_metric_model_SWIN.pth'  # Update this path
if not os.path.exists(model_weights_path):
    raise FileNotFoundError(f"Model weights not found at {model_weights_path}. Please check the path.")

# map_location places the stored tensors on whichever device was selected above
state_dict = torch.load(model_weights_path, map_location=device)
model.load_state_dict(state_dict)
print(f"Loaded model weights from {model_weights_path}")

# Define evaluation functions
def compute_dice_coefficient(in1, in2, label=1):
    """Compute the Dice coefficient of ``label`` between two label arrays.

    Args:
        in1: First label array.
        in2: Second label array, same shape as ``in1``.
        label: Label value to evaluate.

    Returns:
        Twice the number of voxels carrying ``label`` in both arrays divided by
        the sum of the two per-array counts; NaN when neither array contains
        the label.
    """
    in1 = in1 == label
    in2 = in2 == label
    volumes = in1.sum() + in2.sum()
    if volumes == 0:
        # Undefined overlap. Use np.nan: the np.NaN alias no longer exists in NumPy 2.0.
        return np.nan
    return 2. * np.logical_and(in1, in2).sum() / volumes

def compute_hausdorff_distance(in1, in2, label=1):
    """Compute the Hausdorff distance of ``label`` between two label arrays.

    Args:
        in1: First label array.
        in2: Second label array, same shape as ``in1``.
        label: Label value to evaluate.

    Returns:
        The value from SimpleITK's ``HausdorffDistanceImageFilter`` for the two
        binary masks, or NaN when either mask has no foreground. No spacing is
        attached to the images, so the unit is presumably voxels.
    """
    in1 = (in1 == label).astype(np.uint8)
    in2 = (in2 == label).astype(np.uint8)

    # With an empty mask the distance is undefined and the ITK filter raises.
    # The evaluation loop already skips NaN results, so return NaN rather than
    # letting a single missing class abort the whole run.
    if in1.max(initial=0) == 0 or in2.max(initial=0) == 0:
        return np.nan

    in1_sitk = sitk.GetImageFromArray(in1)
    in2_sitk = sitk.GetImageFromArray(in2)
    hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
    hausdorff_distance_filter.Execute(in1_sitk, in2_sitk)
    return hausdorff_distance_filter.GetHausdorffDistance()

def compute_volumetric_difference(in1, in2, label=1):
    """Compute the relative volume difference of ``label`` between two arrays.

    Args:
        in1: First label array.
        in2: Second label array.
        label: Label value to evaluate.

    Returns:
        The absolute difference of the two voxel counts divided by their sum;
        NaN when the label occurs in neither array.
    """
    in1 = in1 == label
    in2 = in2 == label
    vol1 = in1.sum()
    vol2 = in2.sum()
    if vol1 + vol2 == 0:
        # Undefined ratio. Use np.nan: the np.NaN alias no longer exists in NumPy 2.0.
        return np.nan
    return abs(vol1 - vol2) / (vol1 + vol2)

# Function to resample image to match reference image
def resample_to_match(image, reference_image):
    """Resample ``image`` onto the grid of ``reference_image``.

    Nearest-neighbour interpolation is used, so no new intensity values are
    introduced by interpolation (appropriate for label maps).

    Args:
        image: SimpleITK image to resample.
        reference_image: SimpleITK image whose geometry is passed to the
            resampler via ``SetReferenceImage``.

    Returns:
        The resampled SimpleITK image.
    """
    nearest_neighbor_resampler = sitk.ResampleImageFilter()
    nearest_neighbor_resampler.SetReferenceImage(reference_image)
    nearest_neighbor_resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    return nearest_neighbor_resampler.Execute(image)

# Directory to save the segmented outputs (created if it does not exist yet)
output_dir = '/content/drive/MyDrive/Segmented_Outputs/'  # Update this path as needed
os.makedirs(output_dir, exist_ok=True)

# Evaluate on validation data. Each dict maps a class index (0-3) to the list of
# per-sample scores collected for that class.
model.eval()
dice_scores = {i: [] for i in range(4)}
hausdorff_distances = {i: [] for i in range(4)}
volumetric_differences = {i: [] for i in range(4)}

with torch.no_grad():
    for idx, val_data in enumerate(val_loader):
        val_images = val_data['image'].to(device)
        val_labels = val_data['label'].cpu().numpy()
        outputs = model(val_images)
        # Most probable class per voxel; argmax removes the class axis (dim 1).
        outputs = torch.argmax(F.softmax(outputs, dim=1), dim=1).cpu().numpy()

        # Save the segmented output before resampling for visualization
        # (this "_cropped" file is what visualize_monai() reads back later).
        segmented_output_unresampled = outputs[0].astype(np.uint8)
        unresampled_output_file = os.path.join(output_dir, f"segmented_sample_{idx + 1}_cropped.nii.gz")
        # Create a SimpleITK image
        segmented_image_unresampled = sitk.GetImageFromArray(segmented_output_unresampled)
        # Copy the image information from the cropped image.
        # NOTE(review): both images are built from bare arrays and neither has any
        # origin/spacing/direction set, so this copy presumably changes nothing — confirm.
        cropped_image_sitk = sitk.GetImageFromArray(val_data['image'][0].cpu().numpy()[0])
        segmented_image_unresampled.CopyInformation(cropped_image_sitk)
        # Save the image
        sitk.WriteImage(segmented_image_unresampled, unresampled_output_file)

        # Now proceed with resampling and saving the full-size segmentation.
        # The original file path comes from the metadata dictionary of the batch.
        segmented_image = sitk.GetImageFromArray(outputs[0].astype(np.uint8))
        original_image = sitk.ReadImage(val_data['image_meta_dict']['filename_or_obj'][0])

        # Resample segmented image to match the original image's size.
        # NOTE(review): the prediction covers only the centre crop and carries no
        # geometry of its own, so resampling it onto the original grid may not put
        # it back where the crop was taken; the array axis order handed to
        # GetImageFromArray may also differ from the file's axis order. Verify the
        # full-size output against the source volume before relying on it.
        if segmented_image.GetSize() != original_image.GetSize():
            segmented_image = resample_to_match(segmented_image, original_image)

        # Copy metadata
        segmented_image.CopyInformation(original_image)
        output_file = os.path.join(output_dir, f"segmented_sample_{idx + 1}.nii.gz")
        sitk.WriteImage(segmented_image, output_file)
        print(f"Saved segmented output for sample {idx + 1} to {output_file}")

        # Calculate metrics on the cropped prediction vs. the cropped label
        # (outputs[0]: only batch item; val_labels[0][0]: only batch item, channel 0).
        for label in range(4):
            dice = compute_dice_coefficient(outputs[0], val_labels[0][0], label=label)
            hd = compute_hausdorff_distance(outputs[0], val_labels[0][0], label=label)
            vd = compute_volumetric_difference(outputs[0], val_labels[0][0], label=label)
            # NaN results are skipped so they do not enter the averages.
            if not np.isnan(dice):
                dice_scores[label].append(dice)
            if not np.isnan(hd):
                hausdorff_distances[label].append(hd)
            if not np.isnan(vd):
                volumetric_differences[label].append(vd)
        print(f"Processed {idx + 1}/{len(val_loader)} validation samples.")

# Compute average metrics per class. A class without any valid score is reported
# as NaN; np.nan is used because the np.NaN alias was removed in NumPy 2.0.
print("\n--- Evaluation Metrics on Validation Set ---")
for label in range(4):
    avg_dice = np.mean(dice_scores[label]) if dice_scores[label] else np.nan
    avg_hd = np.mean(hausdorff_distances[label]) if hausdorff_distances[label] else np.nan
    avg_vd = np.mean(volumetric_differences[label]) if volumetric_differences[label] else np.nan
    print(f"Class {label}:")
    print(f"\tDice Coefficient = {avg_dice:.4f}")
    print(f"\tHausdorff Distance = {avg_hd:.4f}")
    print(f"\tVolumetric Difference = {avg_vd:.4f}")

# Visualization using MONAI
def visualize_monai(sample_index, slice_index, val_ds, output_dir):
    """Plot one slice with the ground-truth and the predicted segmentation overlaid.

    The prediction is read from
    ``<output_dir>/segmented_sample_<sample_index + 1>_cropped.nii.gz``, i.e.
    the file the evaluation loop writes for each validation sample.

    Args:
        sample_index: Zero-based index of the sample in ``val_ds``.
        slice_index: Index along the first spatial axis of the slice to show.
        val_ds: Dataset whose items are dicts with ``'image'`` and ``'label'``
            entries laid out channel-first.
        output_dir: Directory containing the saved cropped predictions.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    data = val_ds[sample_index]
    segmented_output_path = os.path.join(output_dir, f"segmented_sample_{sample_index + 1}_cropped.nii.gz")
    segmented_output = nib.load(segmented_output_path).get_fdata()
    # The prediction was written through sitk.GetImageFromArray, which maps a
    # numpy array to an image with the axis order reversed. nibabel returns the
    # on-disk order, so reverse the axes again to line the prediction up with
    # the dataset's image/label arrays.
    segmented_output = np.ascontiguousarray(np.transpose(segmented_output))

    # The dataset items have already been converted by EnsureTyped, so they may
    # be tensors rather than numpy arrays; torch.as_tensor accepts both, whereas
    # torch.from_numpy rejects tensors.
    image_tensor = torch.as_tensor(data['image']).float()
    label_tensor = torch.as_tensor(data['label']).float()
    segmented_output_tensor = torch.as_tensor(segmented_output[np.newaxis, ...]).float()  # Add channel dimension

    # Normalize image to [0, 1] for visualization; a constant image would
    # otherwise divide by zero, so it is shown as all zeros instead.
    lowest, highest = image_tensor.min(), image_tensor.max()
    if highest > lowest:
        image_tensor = (image_tensor - lowest) / (highest - lowest)
    else:
        image_tensor = torch.zeros_like(image_tensor)

    # Select the slice: channel 0, position slice_index on the first spatial axis
    image_slice = image_tensor[0, slice_index, :, :]
    label_slice = label_tensor[0, slice_index, :, :]
    segmented_slice = segmented_output_tensor[0, slice_index, :, :]

    # Blend images (blend_images takes channel-first inputs, hence unsqueeze(0))
    blended_gt = blend_images(image_slice.unsqueeze(0), label_slice.unsqueeze(0), alpha=0.5)
    blended_pred = blend_images(image_slice.unsqueeze(0), segmented_slice.unsqueeze(0), alpha=0.5)

    # Convert blended images to channel-last layout for matplotlib
    blended_gt_np = blended_gt.squeeze().permute(1, 2, 0).numpy()
    blended_pred_np = blended_pred.squeeze().permute(1, 2, 0).numpy()

    # Plot the two overlays side by side
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))

    axs[0].imshow(blended_gt_np)
    axs[0].set_title('Original with Ground Truth')
    axs[0].axis('off')

    axs[1].imshow(blended_pred_np)
    axs[1].set_title('Original with Predicted Segmentation')
    axs[1].axis('off')

    plt.show()

# Example usage:
# sample_index = 0  # Index of the sample in the validation dataset
# slice_index = 48  # Index of the slice to visualize (should be within the D dimension range)

# visualize_monai(sample_index, slice_index, val_ds, output_dir)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Loading dataset: 100%|██████████| 5/5 [00:00<00:00,  6.50it/s]


SwinUNETR model instantiated successfully!


  model.load_state_dict(torch.load(model_weights_path, map_location=device))


Loaded model weights from /content/drive/MyDrive/Dataset_Final/best_metric_model_SWIN.pth
Saved segmented output for sample 1 to /content/drive/MyDrive/Segmented_Outputs/segmented_sample_1.nii.gz
Processed 1/5 validation samples.
Saved segmented output for sample 2 to /content/drive/MyDrive/Segmented_Outputs/segmented_sample_2.nii.gz
Processed 2/5 validation samples.
Saved segmented output for sample 3 to /content/drive/MyDrive/Segmented_Outputs/segmented_sample_3.nii.gz
Processed 3/5 validation samples.
Saved segmented output for sample 4 to /content/drive/MyDrive/Segmented_Outputs/segmented_sample_4.nii.gz
Processed 4/5 validation samples.
Saved segmented output for sample 5 to /content/drive/MyDrive/Segmented_Outputs/segmented_sample_5.nii.gz
Processed 5/5 validation samples.

--- Evaluation Metrics on Validation Set ---
Class 0:
	Dice Coefficient = 0.9490
	Hausdorff Distance = 19.4543
	Volumetric Difference = 0.0205
Class 1:
	Dice Coefficient = 0.9019
	Hausdorff Distance = 19.6174


In [None]:
# Install MONAI and other dependencies
!pip install monai==0.9.1
!pip install nibabel
!pip install SimpleITK

# Import necessary libraries
import os
import glob
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import nibabel as nib
import SimpleITK as sitk

import monai
from monai.data import CacheDataset, DataLoader
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityd,
    RandSpatialCropd, RandFlipd, RandRotate90d,
    EnsureTyped
)
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.utils import set_determinism

from google.colab import drive
drive.mount("/content/drive")

# Seed MONAI's determinism helper (seed 0) so repeated runs start from the same state.
set_determinism(seed=0)

# Root folder of the dataset on the mounted Google Drive; the split sub-folder
# names are appended by get_data_dicts().
data_dir = '/content/drive/MyDrive/Dataset_Final/'  # Update this path if different

# Function to load data dictionaries
def get_data_dicts(data_dir, set_name):
    """List the image/label file pairs of one dataset split.

    A subject is any entry matching ``<data_dir>/<set_name>/IBSR_*``. Its
    volume path is ``<subject>.nii.gz`` and its segmentation path is
    ``<subject>_seg.nii.gz`` inside that entry; the paths are built without
    checking that the files exist.

    Args:
        data_dir: Root directory of the dataset.
        set_name: Split sub-folder name, e.g. ``'Training_Set'``.

    Returns:
        List of ``{'image': path, 'label': path}`` dicts in sorted subject
        order; empty when the split has no matching subject.
    """
    data_dicts = []
    # Sort the matches: glob() order depends on the filesystem, which would make
    # the sample order differ between machines and runs.
    for subject_dir in sorted(glob.glob(os.path.join(data_dir, set_name, 'IBSR_*'))):
        subject_name = os.path.basename(subject_dir)
        img_file = os.path.join(subject_dir, '{}.nii.gz'.format(subject_name))
        seg_file = os.path.join(subject_dir, '{}_seg.nii.gz'.format(subject_name))
        data_dicts.append({'image': img_file, 'label': seg_file})
    return data_dicts

# Collect the path dictionaries for both splits: training first, then validation.
train_files, val_files = [
    get_data_dicts(data_dir, split_name)
    for split_name in ('Training_Set', 'Validation_Set')
]

# Training pipeline: deterministic loading / channel-first / intensity scaling,
# followed by a random fixed-size crop and random flip / 90-degree rotation.
train_transforms = Compose([
    LoadImaged(keys=['image', 'label']),
    EnsureChannelFirstd(keys=['image', 'label']),
    # Only the image is rescaled, so the label values are left untouched.
    ScaleIntensityd(keys=['image']),
    RandSpatialCropd(keys=['image', 'label'], roi_size=(96, 96, 96), random_size=False),
    # Image and label are listed together so both receive the same flip / rotation.
    RandFlipd(keys=['image', 'label'], prob=0.5, spatial_axis=0),
    RandRotate90d(keys=['image', 'label'], prob=0.5, max_k=3),
    EnsureTyped(keys=['image', 'label']),
])

# Validation pipeline: no cropping step, so validation runs on whole volumes.
val_transforms = Compose([
    LoadImaged(keys=['image', 'label']),
    EnsureChannelFirstd(keys=['image', 'label']),
    ScaleIntensityd(keys=['image']),
    EnsureTyped(keys=['image', 'label']),
])

# Worker count shared by the CacheDataset constructors and the DataLoaders below.
num_workers = 4  # Adjust based on your system's capabilities

# Training dataset and loader: one sample per batch, reshuffled by the loader.
train_ds = CacheDataset(
    data=train_files,
    transform=train_transforms,
    cache_rate=1.0,  # Cache all data
    num_workers=num_workers
)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=num_workers)

# Validation dataset and loader: one sample per batch, fixed order (shuffle=False).
val_ds = CacheDataset(
    data=val_files,
    transform=val_transforms,
    cache_rate=1.0,  # Cache all data
    num_workers=num_workers
)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=num_workers)

# Set device: GPU when CUDA is available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the model (using MONAI's UNet).
# `spatial_dims` is the supported name of the first argument; the older
# `dimensions` keyword is deprecated in MONAI and only kept as an alias.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=4,  # Number of classes
    channels=(32, 64, 128, 256, 512),  # feature channels per resolution level
    strides=(2, 2, 2, 2),              # one stride per transition between levels
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

# Optionally, load pre-trained weights if available
# model.load_state_dict(torch.load('/path/to/pretrained/model.pth'))

# Dice loss configured to one-hot encode the target (to_onehot_y=True) and to apply
# softmax to the network output (softmax=True); Adam optimiser with lr 1e-4.
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Validation metric. Unlike the loss above, DiceMetric is not given any
# one-hot / softmax options here, so the validation code prepares both the
# prediction and the label tensors itself before calling it.
dice_metric = DiceMetric(
    include_background=True,
    reduction="mean",
    get_not_nans=False,
)

# Training loop
max_epochs = 1000
val_interval = 2  # Do validation every 2 epochs
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []  # mean training loss of every epoch
metric_values = []      # mean validation Dice of every validation round

for epoch in range(max_epochs):
    print('-' * 10)
    print('Epoch {}/{}'.format(epoch + 1, max_epochs))

    # ---- training pass over all batches ----
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data['image'].to(device), batch_data['label'].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print('{}/{}, train_loss: {:.4f}'.format(step, len(train_loader), loss.item()))
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print('Epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss))

    # ---- periodic validation and checkpointing ----
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_images = val_data['image'].to(device)
                val_labels = val_data['label'].to(device)
                val_outputs = model(val_images)

                # DiceMetric expects binarised one-hot predictions. Feeding it
                # softmax probabilities produces a "soft" Dice that does not
                # match the Dice of the actual segmentation, so pick the
                # arg-max class per voxel and one-hot encode that instead.
                val_pred_onehot = F.one_hot(torch.argmax(val_outputs, dim=1), num_classes=4)
                val_pred_onehot = val_pred_onehot.permute(0, 4, 1, 2, 3).float()  # channel axis second

                # Convert labels to one-hot encoding the same way
                # (.long() because F.one_hot requires integer class indices).
                val_labels_onehot = F.one_hot(val_labels.squeeze(1).long(), num_classes=4)
                val_labels_onehot = val_labels_onehot.permute(0, 4, 1, 2, 3).float()

                # Accumulate this sample in the metric's internal buffer
                dice_metric(y_pred=val_pred_onehot, y=val_labels_onehot)

            # Reduce the buffered scores to one mean Dice, then clear the buffer
            metric = dice_metric.aggregate().item()
            dice_metric.reset()
            metric_values.append(metric)

            # Keep only the weights of the best validation round so far
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), 'best_metric_model.pth')
                print('Saved new best metric model')
            print('Current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}'.format(
                epoch + 1, metric, best_metric, best_metric_epoch))

print('Training completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch))

# Load the best model for evaluation: restores the weights from the same file
# name the training loop passes to torch.save() ('best_metric_model.pth').
model.load_state_dict(torch.load('best_metric_model.pth'))

# Evaluation functions
def compute_dice_coefficient(in1, in2, label=1):
    """Dice coefficient of a single label between two label arrays.

    Args:
        in1: First label array.
        in2: Second label array, same shape as ``in1``.
        label: Label value whose binary masks are compared.

    Returns:
        ``2 * intersection / (count1 + count2)`` of the two masks, or NaN when
        the label is absent from both arrays.
    """
    first_mask, second_mask = in1 == label, in2 == label
    volumes = first_mask.sum() + second_mask.sum()
    # Nothing to compare: the score is undefined. np.nan replaces np.NaN, an
    # alias that NumPy 2.0 removed.
    if volumes == 0:
        return np.nan
    intersection = np.logical_and(first_mask, second_mask).sum()
    return 2. * intersection / volumes

def compute_hausdorff_distance(in1, in2, label=1):
    """Return the Hausdorff distance of class ``label`` between two label maps.

    Args:
        in1: NumPy label array (e.g. the prediction).
        in2: NumPy label array compared against ``in1``.
        label: Class value whose binary masks are compared.

    Returns:
        The value reported by SimpleITK's ``HausdorffDistanceImageFilter``
        for the two binary masks, or ``np.nan`` when ``label`` is absent
        from either input.
    """
    in1 = (in1 == label).astype(np.uint8)
    in2 = (in2 == label).astype(np.uint8)
    # The distance is undefined when either mask has no foreground, and the
    # underlying ITK filter raises in that case. Return NaN like the sibling
    # metrics so callers can filter the value with np.isnan.
    if not in1.any() or not in2.any():
        return np.nan
    # The images are built from bare arrays, so no spacing metadata from the
    # original scans is attached here.
    in1_sitk = sitk.GetImageFromArray(in1)
    in2_sitk = sitk.GetImageFromArray(in2)
    hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
    hausdorff_distance_filter.Execute(in1_sitk, in2_sitk)
    return hausdorff_distance_filter.GetHausdorffDistance()

def compute_volumetric_difference(in1, in2, label=1):
    """Return the relative volume difference of class ``label``.

    Args:
        in1: Label array (e.g. the prediction).
        in2: Label array compared against ``in1``.
        label: Class value whose voxel counts are compared.

    Returns:
        ``|V1 - V2| / (V1 + V2)`` where ``V1`` and ``V2`` are the numbers of
        elements equal to ``label`` in each input; ``np.nan`` when ``label``
        occurs in neither input, since the ratio is undefined there.
    """
    in1 = in1 == label
    in2 = in2 == label
    vol1 = in1.sum()
    vol2 = in2.sum()
    if vol1 + vol2 == 0:
        # Use np.nan: the capitalised alias np.NaN was removed in NumPy 2.0.
        return np.nan
    return abs(vol1 - vol2) / (vol1 + vol2)

# Evaluate on validation data
# Collects per-class Dice, Hausdorff distance and volumetric difference for
# every validation volume, then prints the per-class mean of each metric.
model.eval()
eval_classes = range(4)  # Assuming 4 classes
dice_scores = {i: [] for i in eval_classes}
hausdorff_distances = {i: [] for i in eval_classes}
volumetric_differences = {i: [] for i in eval_classes}

with torch.no_grad():
    for val_data in val_loader:
        val_images = val_data['image'].to(device)
        # Drop the channel axis so each batch item is one label volume
        # (assumes labels are laid out as (B, 1, ...), as the previous
        # val_labels[0][0] indexing implied).
        val_labels = val_data['label'].cpu().numpy()[:, 0]
        # argmax over the class axis; softmax is monotonic, so applying it
        # first would not change the winning class.
        outputs = torch.argmax(model(val_images), dim=1).cpu().numpy()

        # Score every volume in the batch, not only the first one.
        for pred_vol, ref_vol in zip(outputs, val_labels):
            for label in eval_classes:  # For each class
                dice = compute_dice_coefficient(pred_vol, ref_vol, label=label)
                hd = compute_hausdorff_distance(pred_vol, ref_vol, label=label)
                vd = compute_volumetric_difference(pred_vol, ref_vol, label=label)
                # NaN marks an undefined metric; keep it out of the averages.
                if not np.isnan(dice):
                    dice_scores[label].append(dice)
                if not np.isnan(hd):
                    hausdorff_distances[label].append(hd)
                if not np.isnan(vd):
                    volumetric_differences[label].append(vd)

# Compute average metrics per class
for label in eval_classes:
    # np.nan (not np.NaN, removed in NumPy 2.0) when a class was never scored.
    avg_dice = np.mean(dice_scores[label]) if dice_scores[label] else np.nan
    avg_hd = np.mean(hausdorff_distances[label]) if hausdorff_distances[label] else np.nan
    avg_vd = np.mean(volumetric_differences[label]) if volumetric_differences[label] else np.nan
    print(f"Class {label}:")
    print(f"\tDice Coefficient = {avg_dice:.4f}")
    print(f"\tHausdorff Distance = {avg_hd:.4f}")
    print(f"\tVolumetric Difference = {avg_vd:.4f}")


Collecting monai==0.9.1
  Using cached monai-0.9.1-202207251608-py3-none-any.whl.metadata (7.5 kB)
Using cached monai-0.9.1-202207251608-py3-none-any.whl (990 kB)
Installing collected packages: monai
  Attempting uninstall: monai
    Found existing installation: monai 1.4.0
    Uninstalling monai-1.4.0:
      Successfully uninstalled monai-1.4.0
Successfully installed monai-0.9.1
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Loading dataset: 100%|██████████| 10/10 [00:01<00:00,  7.01it/s]
Loading dataset: 100%|██████████| 5/5 [00:00<00:00,  5.78it/s]


----------
Epoch 1/1000


  ret = func(*args, **kwargs)
  ret = func(*args, **kwargs)
  ret = func(*args, **kwargs)
  ret = func(*args, **kwargs)
  t = cls([], dtype=storage.dtype, device=storage.device)


1/10, train_loss: 0.7879
2/10, train_loss: 0.8550
3/10, train_loss: 0.7899
4/10, train_loss: 0.8124
5/10, train_loss: 0.8008
6/10, train_loss: 0.7958
7/10, train_loss: 0.8105
8/10, train_loss: 0.7937
9/10, train_loss: 0.8449
10/10, train_loss: 0.7273
Epoch 1 average loss: 0.8018
----------
Epoch 2/1000
1/10, train_loss: 0.7309
2/10, train_loss: 0.7582
3/10, train_loss: 0.8304
4/10, train_loss: 0.7275
5/10, train_loss: 0.7030
6/10, train_loss: 0.7007
7/10, train_loss: 0.7619
8/10, train_loss: 0.6929
9/10, train_loss: 0.6794
10/10, train_loss: 0.7472
Epoch 2 average loss: 0.7332




[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Current epoch: 630 current mean dice: 0.8353 best mean dice: 0.8711 at epoch 610
----------
Epoch 631/1000
1/10, train_loss: 0.0803
2/10, train_loss: 0.0664
3/10, train_loss: 0.0705
4/10, train_loss: 0.1633
5/10, train_loss: 0.0960
6/10, train_loss: 0.0605
7/10, train_loss: 0.1129
8/10, train_loss: 0.0833
9/10, train_loss: 0.0761
10/10, train_loss: 0.0779
Epoch 631 average loss: 0.0887
----------
Epoch 632/1000
1/10, train_loss: 0.1041
2/10, train_loss: 0.0747
3/10, train_loss: 0.1139
4/10, train_loss: 0.1252
5/10, train_loss: 0.0900
6/10, train_loss: 0.0865
7/10, train_loss: 0.0767
8/10, train_loss: 0.0746
9/10, train_loss: 0.1801
10/10, train_loss: 0.0628
Epoch 632 average loss: 0.0989
Current epoch: 632 current mean dice: 0.8677 best mean dice: 0.8711 at epoch 610
----------
Epoch 633/1000
1/10, train_loss: 0.1693
2/10, train_loss: 0.1838
3/10, train_loss: 0.2310
4/10, train_loss: 0.1108
5/10, train_loss: 0.0603
6/10, 

  model.load_state_dict(torch.load('best_metric_model.pth'))


Class 0:
	Dice Coefficient = 0.9963
	Hausdorff Distance = 23.6909
	Volumetric Difference = 0.0018
Class 1:
	Dice Coefficient = 0.8372
	Hausdorff Distance = 21.4847
	Volumetric Difference = 0.0781
Class 2:
	Dice Coefficient = 0.8845
	Hausdorff Distance = 10.7371
	Volumetric Difference = 0.0325
Class 3:
	Dice Coefficient = 0.8522
	Hausdorff Distance = 11.6964
	Volumetric Difference = 0.1046
