<a href="https://colab.research.google.com/github/aqibjaved28/TH22--R_IW_CI_PW-/blob/main/R_IW_CI_PW.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Libraries

In [12]:
import sys

# Root of the cloned RICI repository; prepending it lets `main.*` and `losses.*`
# (imported in a later cell) resolve from the checkout.
RICI_ROOT = '/content/RICI/'
# Guard the insert so re-running this cell does not keep adding duplicate entries.
if RICI_ROOT not in sys.path:
    sys.path.insert(0, RICI_ROOT)

Install requirements

In [None]:
!pip install monai
!pip3 install torch torchvision torchaudio

Import libraries

In [None]:
import nibabel as nib
import torch
import matplotlib.pyplot as plt
%matplotlib inline

from main.ICI_loss import ICILoss
from main.RegOutputsLabels import RegOutputsLabels
from main import tools

from monai.losses import DiceLoss
from monai.losses import FocalLoss
from losses.benchmark import MAX_SEG_PIXEL, MAX_SEG_INSTANCE, MAX_SEG_CENTER, MAX_SEG_FDR, norm


import seaborn as sns
import numpy as np
import random

from monai.losses import DiceLoss

# Report the PyTorch build and the visible CUDA devices.
# current_device() / get_device_name(0) raise when no CUDA device is present
# (e.g. a CPU-only runtime), so only call them when CUDA is available.
print(torch.__version__)
print(torch.cuda.device_count())
if torch.cuda.is_available():
    print(torch.cuda.current_device())
    print(torch.cuda.device(0))
    print(torch.cuda.get_device_name(0))
else:
    print("CUDA is not available; running on CPU.")

## Load Data

# EDA, Visualization, and Pre-processing

In [None]:
# Data: the label volume and the model-output volume, both stored as NIfTI files.
label_img = nib.load("/content/RICI/image_data/label.nii.gz")
output_img = nib.load("/content/RICI/image_data/output.nii.gz")

# Voxel arrays for analysis (nibabel's get_fdata returns floating-point data).
label_data, output_data = (img.get_fdata() for img in (label_img, output_img))

# Volume dimensions and intensity extremes, reused by later cells.
label_shape, output_shape = label_data.shape, output_data.shape
label_min, label_max = label_data.min(), label_data.max()
output_min, output_max = output_data.min(), output_data.max()

for _row in (
    ("Label Image Shape:", label_shape),
    ("Output Image Shape:", output_shape),
    ("Label Intensity Range: Min =", label_min, "Max =", label_max),
    ("Output Intensity Range: Min =", output_min, "Max =", output_max),
):
    print(*_row)

In [None]:
# Shape, voxel-to-world affine, and dtype of both volumes, grouped by property.
_volumes = (("Label", label_img, label_data), ("Output", output_img, output_data))
for _name, _img, _arr in _volumes:
    print(f"{_name} Data Shape:", _arr.shape)
for _name, _img, _arr in _volumes:
    print(f"{_name} Affine Matrix:\n", _img.affine)
for _name, _img, _arr in _volumes:
    print(f"{_name} Data Type:", _arr.dtype)


In [None]:
# Data Summary Statistics
def _volume_stats(volume):
    """Return min/max/mean/std of `volume`, plus 'sum' = number of voxels > 0."""
    return {
        'min': np.min(volume),
        'max': np.max(volume),
        'mean': np.mean(volume),
        'std': np.std(volume),
        'sum': np.sum(volume > 0),  # count of non-zero (lesion) voxels, not a sum of values
    }


label_stats = _volume_stats(label_data)
output_stats = _volume_stats(output_data)

print("\nLabel Data Stats:", label_stats)
print("Output Data Stats:", output_stats)


In [None]:
# Visualize MRI and Segmentation Slices
def plot_slices(data, title, slices=(30, 50, 70)):
    """Display slices of a 3-D volume side by side in grayscale.

    Args:
        data: volume indexed as ``data[:, :, k]``; slices are taken along axis 2.
        title: prefix used in every panel title.
        slices: indices along axis 2 to show. The default is a tuple, so no mutable
            default object is shared between calls.

    Raises:
        ValueError: if ``slices`` is empty.
        IndexError: if a requested index lies outside axis 2 of ``data``.
    """
    slices = list(slices)
    if not slices:
        raise ValueError("slices must contain at least one index")
    depth = data.shape[2]
    out_of_range = [k for k in slices if not -depth <= k < depth]
    if out_of_range:
        raise IndexError(f"slice indices {out_of_range} out of range for axis 2 of size {depth}")

    # squeeze=False always yields a 2-D array of Axes, so this works even for a single
    # slice (plt.subplots(1, 1) would otherwise return a bare, non-indexable Axes).
    fig, axs = plt.subplots(1, len(slices), figsize=(15, 5), squeeze=False)
    for ax, slice_idx in zip(axs[0], slices):
        ax.imshow(data[:, :, slice_idx], cmap='gray')
        ax.set_title(f"{title} - Slice {slice_idx}")
        ax.axis('off')
    plt.show()

plot_slices(label_data, "Label Data")
plot_slices(output_data, "Output Data")

In [None]:
# Overlay Segmentation on MRI Output
def plot_overlay_slices(label_data, output_data, slices=(30, 50, 70)):
    """Show output-volume slices in grayscale with the label mask tinted red on top.

    Args:
        label_data: label volume indexed as ``label_data[:, :, k]``. Voxels equal to 0
            are treated as background and left untinted.
        output_data: volume drawn underneath, indexed the same way.
        slices: indices along axis 2 to display (tuple default avoids a shared
            mutable default).

    Raises:
        ValueError: if ``slices`` is empty.
        IndexError: if an index lies outside axis 2 of either volume.
    """
    slices = list(slices)
    if not slices:
        raise ValueError("slices must contain at least one index")
    for name, volume in (("label_data", label_data), ("output_data", output_data)):
        depth = volume.shape[2]
        out_of_range = [k for k in slices if not -depth <= k < depth]
        if out_of_range:
            raise IndexError(
                f"slice indices {out_of_range} out of range for axis 2 of {name} (size {depth})"
            )

    # Fix the overlay's colour limits. Otherwise a binary mask has one visible value,
    # vmin == vmax, and matplotlib maps every lesion voxel to the palest red.
    label_vmax = float(np.max(label_data))
    if label_vmax <= 0:
        label_vmax = 1.0

    fig, axs = plt.subplots(1, len(slices), figsize=(15, 5), squeeze=False)
    for ax, slice_idx in zip(axs[0], slices):
        ax.imshow(output_data[:, :, slice_idx], cmap='gray')
        label_slice = label_data[:, :, slice_idx]
        # Mask background voxels. Masked pixels use the colormap's "bad" colour, which
        # is transparent by default, so the grayscale image is not washed out. Unmasked
        # zeros would be drawn in the near-white low end of 'Reds'.
        ax.imshow(
            np.ma.masked_where(label_slice == 0, label_slice),
            cmap='Reds', alpha=0.5, vmin=0, vmax=label_vmax,
        )
        ax.set_title(f"Overlay - Slice {slice_idx}")
        ax.axis('off')
    plt.show()

plot_overlay_slices(label_data, output_data)

In [None]:
# Lesion Volume and Distribution Analysis
# Lesion "volume" per slice = number of voxels > 0 in data[:, :, k] (assumes 3-D
# volumes, as the rest of this notebook does). Summing the boolean mask over the two
# in-plane axes handles every slice in one vectorized numpy reduction instead of
# one Python-level call per slice.
lesion_volumes = (label_data > 0).sum(axis=(0, 1)).tolist()
output_volumes = (output_data > 0).sum(axis=(0, 1)).tolist()

# Plot lesion volume across slices (label: solid line, model output: dashed line).
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(lesion_volumes, label="Label Lesion Volume")
ax.plot(output_volumes, label="Output Lesion Volume", linestyle='--')
ax.set_xlabel("Slice Index")
ax.set_ylabel("Lesion Volume (voxel count)")
ax.set_title("Lesion Volume Across Slices")
ax.legend()
plt.show()

In [None]:
# Voxel Intensity Distribution
# Histogram + KDE of every voxel value in each volume, drawn on one shared set of axes.
# NOTE(review): a KDE over a fully flattened 3-D volume can be slow for large scans.
fig, ax = plt.subplots(figsize=(12, 6))
for volume, colour, caption in (
    (label_data, 'red', 'Label Data'),
    (output_data, 'blue', 'Output Data'),
):
    sns.histplot(volume.flatten(), color=colour, kde=True, label=caption, ax=ax)
ax.set_xlabel("Voxel Intensity")
ax.set_ylabel("Frequency")
ax.set_title("Voxel Intensity Distribution in Label and Output Data")
ax.legend()
plt.show()

In [None]:
# 3D Lesion Density Map (Mean Lesion Across Slices)
# For each in-plane position: the fraction of slices along axis 2 where label > 0.
lesion_presence = label_data > 0
mean_lesion_map = lesion_presence.mean(axis=2)

fig, ax = plt.subplots(figsize=(8, 8))
heatmap = ax.imshow(mean_lesion_map, cmap='hot')
ax.set_title("Mean Lesion Map Across Slices")
fig.colorbar(heatmap, ax=ax, label="Lesion Presence Frequency")
plt.show()

In [None]:
# Visualize sample slices from the middle of each volume (middle index along axis 2).
mid_slice_label = label_data.shape[2] // 2
mid_slice_output = output_data.shape[2] // 2

fig, (ax_label, ax_output) = plt.subplots(1, 2, figsize=(10, 5))

ax_label.imshow(label_data[:, :, mid_slice_label], cmap='gray')
ax_label.set_title('Label - Middle Slice')

ax_output.imshow(output_data[:, :, mid_slice_output], cmap='gray')
ax_output.set_title('Output - Middle Slice')

plt.show()

In [None]:
# Re-display the intensity extremes computed when the volumes were loaded.
for _row in (
    ("Label Intensity Range: Min =", label_min, "Max =", label_max),
    ("Output Intensity Range: Min =", output_min, "Max =", output_max),
):
    print(*_row)

In [None]:
# Lesion Intensity Analysis
# Split the output volume's voxel values by the label: voxels whose label is > 0
# (lesion) versus voxels whose label is exactly 0 (background), and compare the
# two distributions.
lesion_intensity = output_data[label_data > 0]
non_lesion_intensity = output_data[label_data == 0]

fig, ax = plt.subplots(figsize=(10, 5))
for values, colour, caption in (
    (lesion_intensity, 'red', 'Lesion Intensity'),
    (non_lesion_intensity, 'blue', 'Non-Lesion Intensity'),
):
    sns.histplot(values, color=colour, kde=True, label=caption, ax=ax)
ax.legend()
ax.set_xlabel("Voxel Intensity")
ax.set_ylabel("Frequency")
ax.set_title("Lesion vs Non-Lesion Intensity Distribution")
plt.show()


## Regularized IW, CI, and PW Losses

The regularized instance-wise (IW), center-of-instance (CI), and pixel-wise (PW) loss functions are designed to improve segmentation outcomes. This section explains how to instantiate and use the `ICILoss` class, and how to combine it with pixel-wise segmentation losses such as Dice loss and Focal loss from the MONAI core library.

In [None]:
# MONAI base losses. The flags below tell each loss not to one-hot the target and not
# to apply sigmoid/softmax to the prediction; inputs are used as given.
_dice_kwargs = dict(to_onehot_y=False, sigmoid=False, softmax=False)
loss_dice = DiceLoss(**_dice_kwargs)
loss_dice_center = DiceLoss(**_dice_kwargs)

# Focal-loss alternatives. They are constructed here but not passed to ICILoss below.
_focal_kwargs = dict(to_onehot_y=False, use_softmax=False)
loss_focal = FocalLoss(**_focal_kwargs)
loss_focal_center = FocalLoss(**_focal_kwargs)

# ICILoss configuration (semantics of each knob are defined by main.ICI_loss; the
# comments below are descriptive hints only — confirm against that module).
activation = "none"
num_out_chn = 1
object_chn = 1
mul_too_many = 50
centroid_offset = 4
num_iterations = 350
max_false_detections = 50  # passed to ICILoss as max_cc_out
rate_instead_number = False

ici_loss_function = ICILoss(
    # Base loss for each term: pixel-wise, per-instance, and centre-of-instance.
    # NOTE(review): the instance term reuses loss_dice_center (same object as the
    # centre term) rather than loss_dice — confirm this is intended.
    loss_function_pixel=loss_dice,
    loss_function_instance=loss_dice_center,
    loss_function_center=loss_dice_center,
    activation=activation,
    num_out_chn=num_out_chn,
    object_chn=object_chn,
    mul_too_many=mul_too_many,
    max_cc_out=max_false_detections,
    num_iterations=num_iterations,
    centroid_offset=centroid_offset,
    rate_instead_number=rate_instead_number,
    instance_wise_loss_no_tp=True,
)

ici_loss_function.print_parameters()

In [None]:
# Example output/label pair used to demonstrate the regularized losses.
# NOTE(review): these paths point at an "ICI-loss" checkout, while the EDA above reads
# /content/RICI/image_data — confirm both exist in the runtime and which is intended.
output_file_path = "/content/ICI-loss/example_blobs/output-0.nii.gz"
label_file_path = "/content/ICI-loss/example_blobs/label-0.nii.gz"

# RegOutputsLabels (from main.RegOutputsLabels) produces the scaled output/label
# tensors fed to ici_loss_function; n=10 is forwarded as-is.
processor = RegOutputsLabels(output_file_path, label_file_path)
r_outputs, r_labels = processor.compute_scaled_tensors(n=10)

for _caption, _tensor in (
    ("Regularized Outputs Tensor:", r_outputs),
    ("Regularized Labels Tensor:", r_labels),
):
    print(_caption, _tensor)

In [None]:
seg_pixel, seg_instance, seg_center, seg_fdr, cc_falsed, cc_missed = ici_loss_function(
    r_outputs,
    r_labels,
)


def _regularize(value, upper, caption):
    """Return `value` if it is not below 0 and not above `upper`.

    Otherwise, replace it with `norm(upper)`, print it under `caption`, and return that.
    NOTE(review): the replacement is computed from `upper` alone, so it does not depend
    on `value` (any autograd history of `value` is not carried over) — confirm intended.
    """
    if value < 0 or value > upper:
        value = norm(upper)
        print(f"{caption}: {value:.4f}")
    return value


seg_instance = _regularize(seg_instance, MAX_SEG_INSTANCE, "RIW loss")
seg_center = _regularize(seg_center, MAX_SEG_CENTER, "RCI loss")
seg_pixel = _regularize(seg_pixel, MAX_SEG_PIXEL, "RPW loss")
seg_fdr = _regularize(seg_fdr, MAX_SEG_FDR, "FDR")

# Print the final Regularized metrics
print(f"Final Regularized Metrics - RIW loss: {seg_instance:.4f}, RCI loss: {seg_center:.4f}, RPW loss: {seg_pixel:.4f}, FDR: {seg_fdr:.4f}")

In [None]:
# Recompute the losses. This overwrites any values replaced in the previous cell.
# Then report whether each loss is attached to the autograd graph.
seg_pixel, seg_instance, seg_center, seg_fdr, cc_falsed, cc_missed = ici_loss_function(r_outputs, r_labels)

print("\nRIW, RCI , and RPW")
for _tag, _loss in (("riw", seg_instance), ("rci", seg_center), ("rpw", seg_pixel)):
    print(f"{_tag} loss:", _loss, " - with gradients? ", _loss.requires_grad)
print("DONE!")

In [None]:
# Show cc_falsed (printed as "num false") from the latest ici_loss_function call,
# and whether it carries gradients.
print("\nFDR")
print(
    "num false:",
    cc_falsed,
    " - with gradients? ",
    cc_falsed.requires_grad,
)
print("DONE!")

In [None]:
# Show cc_missed (printed as "num missed") from the latest ici_loss_function call,
# and whether it carries gradients.
print("\nNumber of missed instances")
print(
    "num missed:",
    cc_missed,
    " - with gradients? ",
    cc_missed.requires_grad,
)
print("DONE!")

# Cost Analysis

In [None]:
import tracemalloc
import time

# Time and memory profile of a single ici_loss_function call.
# NOTE: tracemalloc only tracks allocations made through Python's memory allocator.
# Tensor storage allocated by PyTorch's CPU/CUDA allocators is not counted, so the KB
# figures below under-report the real footprint. When a GPU is present, the CUDA peak
# is reported separately.
cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.cuda.reset_peak_memory_stats()

tracemalloc.start()  # start tracking Python-level allocations
try:
    if cuda_available:
        torch.cuda.synchronize()  # don't bill previously queued GPU work to this call
    # perf_counter is a monotonic high-resolution clock, unlike time.time().
    start_time = time.perf_counter()

    seg_pixel, seg_instance, seg_center, seg_fdr, cc_falsed, cc_missed = ici_loss_function(
        r_outputs,
        r_labels,
    )

    if cuda_available:
        torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait before stopping the clock
    end_time = time.perf_counter()

    # Snapshot the traced memory before tracing is turned off.
    current, peak = tracemalloc.get_traced_memory()
finally:
    # Always stop tracing, even if the loss call raises; it slows down later cells.
    tracemalloc.stop()

print(f"Time taken for ici_loss_function: {end_time - start_time} seconds")

percentage_usage = (current / peak) * 100 if peak != 0 else 0

print(f"Current memory usage: {current / 1024:.2f} KB")
print(f"Peak memory usage: {peak / 1024:.2f} KB")
print(f"Percentage of peak usage: {percentage_usage:.2f}%")
if cuda_available:
    print(f"Peak CUDA memory allocated: {torch.cuda.max_memory_allocated() / 1024:.2f} KB")