<a href="https://colab.research.google.com/github/abrarelidrisi/MRI-Segmentation/blob/main/experiment_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**TASK 1:** Defining our segmentation pipeline: normalization, Gaussian blurring, Otsu thresholding, Canny edge detection, and K-means segmentation using a dictionary-based temporal-mask approach.

In [None]:
#Importing the libraries used throughout the notebook (the MRI scans and their segmented ground-truth images are loaded in a later cell)
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from scipy.io import loadmat
import sys
import matplotlib.pyplot as plt
from skimage.morphology import remove_small_objects, remove_small_holes, binary_dilation
from scipy.ndimage import binary_dilation, binary_fill_holes
from skimage.feature import canny
import scipy.ndimage as ndi
from sklearn.cluster import KMeans
!pip install pandas
import pandas as pd
from skimage.filters import threshold_multiotsu
from scipy.ndimage import gaussian_filter

In [2]:
#Since edge detection algorithms such as Canny are applied later, we first min-max normalize the data to [0, 1] so the algorithms work on a consistent intensity scale (note that min-max scaling itself is sensitive to extreme outliers)
def normalizing(T1):
    """Min-max scale ``T1`` to the range [0, 1].

    Later steps assume this range, e.g. the 2D Otsu path multiplies by 255
    before casting to uint8.

    Args:
        T1: image or volume (any shape) of numeric intensities.

    Returns:
        Array of the same shape with the minimum mapped to 0 and the maximum
        to 1. A constant input (max == min) yields an all-zero float array
        instead of the NaNs a 0/0 division would produce.
    """
    T1 = np.asarray(T1)
    lowest = np.min(T1)
    value_range = np.max(T1) - lowest
    if value_range == 0:
        # Nothing to stretch: every voxel is equal, so map all of them to 0.
        return np.zeros(T1.shape, dtype=float)
    return (T1 - lowest) / value_range


In [3]:
#We now start the segmentation process with Gaussian smoothing. scipy's gaussian_filter is used for 3D volumes because it smooths along all three axes; 2D slices use OpenCV's GaussianBlur
def smoothing(slice_normalized, sigma=1):
    """Denoise an image or volume with a Gaussian blur.

    Args:
        slice_normalized: 2D slice or 3D volume (typically the output of
            ``normalizing``).
        sigma: standard deviation for the 3D path.

    Returns:
        Smoothed array of the same shape.

    3D input goes through scipy's n-dimensional ``gaussian_filter`` with
    ``sigma``. Any other input goes through OpenCV's ``GaussianBlur`` with a
    fixed 5x5 kernel and sigmaX=0, which lets OpenCV derive sigma from the
    kernel size. NOTE(review): ``sigma`` is therefore ignored on the 2D path.
    """
    is_volume = slice_normalized.ndim == 3
    if not is_volume:
        return cv2.GaussianBlur(slice_normalized, (5, 5), 0)
    return gaussian_filter(slice_normalized, sigma=sigma)

In [4]:
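#For 3D volumes, apply scikit-image's multi-level Otsu thresholding (threshold_multiotsu), which splits the intensities into several classes instead of a single foreground/background cut; 2D slices use OpenCV's binary Otsu.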
def otsu_thresholding(slice_smoothed):
    """Split an image or volume into intensity classes with Otsu's method.

    Args:
        slice_smoothed: smoothed 2D slice or 3D volume. The 2D path scales by
            255 before casting to uint8, so it presumably expects values in
            [0, 1] (TODO confirm for other inputs).

    Returns:
        3D input: integer class index per voxel. Multi-level Otsu thresholds
            (skimage default class count) are applied with ``np.digitize``.
        2D input: uint8 binary image with values 0 / 255 from OpenCV's Otsu.
    """
    if slice_smoothed.ndim != 3:
        as_uint8 = (slice_smoothed * 255).astype(np.uint8)
        _, binary = cv2.threshold(as_uint8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        return binary

    cut_points = threshold_multiotsu(slice_smoothed)
    return np.digitize(slice_smoothed, bins=cut_points)

In [5]:
# Showcasing the outcome of applying otsu on our smoothed slice
def showcase_thresholded(thresholded, slice_smoothed):
    """Weight the smoothed image by the thresholding result.

    The result is the element-wise product. Pixels where ``thresholded`` is
    0 become 0; the others are scaled by their threshold value (255 on the
    2D path, the class index on the 3D path).
    """
    masked = slice_smoothed * thresholded
    return masked

In [6]:
#We apply Canny edge detection. Given the nature of our ground truth, morphological post-processing (dilation, hole filling and small-object removal) improved the accuracy.
def _canny_region_mask(image_2d, sigma, min_size):
    """Turn one 2D image into a filled, de-speckled region mask via Canny edges."""
    mask = canny(image_2d, sigma=sigma)
    # Thicken the edges, helping nearly-closed contours close before hole filling
    mask = ndi.binary_dilation(mask)
    # Closed contours become solid regions
    mask = ndi.binary_fill_holes(mask)
    return remove_small_objects(mask, min_size=min_size)


def applying_canny(thresholded_img, sigma=2, min_size=100):
    """Detect Canny edges and convert them into filled region masks.

    Each 2D image (or each slice along axis 2 of a 3D volume) is processed
    in four steps:
    1. Canny edge detection.
    2. One binary dilation.
    3. Hole filling.
    4. Removal of connected components smaller than ``min_size`` pixels.

    Args:
        thresholded_img: 2D image or 3D volume with slices along the last axis.
        sigma: Gaussian sigma used inside Canny. It now applies to both the
            2D and 3D paths; the 2D path previously hard-coded 2, which is
            still the default.
        min_size: smallest connected component kept, in pixels. It now
            applies to both paths; the 2D default is unchanged (100).

    Returns:
        Boolean mask with the same shape as ``thresholded_img``.
    """
    if thresholded_img.ndim != 3:
        return _canny_region_mask(thresholded_img, sigma, min_size)

    edges = np.zeros_like(thresholded_img, dtype=bool)
    for i in range(thresholded_img.shape[2]):
        # Volume slices are cast to float64 before Canny; 2D input is passed as-is
        edges[:, :, i] = _canny_region_mask(
            thresholded_img[:, :, i].astype(np.float64), sigma, min_size
        )
    return edges


In [7]:
#Showcasing the effect of applying Canny edge detection to our image
def showcasing_canny(slice_normalized, edges):
    """Keep ``slice_normalized`` only where the Canny region mask is set.

    Returns the element-wise product, so pixels outside ``edges`` become 0.
    """
    masked_slice = np.multiply(slice_normalized, edges)
    return masked_slice

In [8]:
#Define k-means segmentation so we can apply it to the image
def kmeans_segmentation(image, n_clusters=6):
    """Cluster pixel intensities with K-means (``random_state=0``).

    Args:
        image: array of any shape; each element is one sample with a single
            feature (its intensity).
        n_clusters: number of K-means clusters.

    Returns:
        labels: int array shaped like ``image`` holding each pixel's cluster id.
        order: ``np.argsort`` of the (n_clusters, 1) centroid array along axis
            0. Row k holds the id of the cluster with the k-th smallest
            centroid.
    """
    intensities = image.reshape(-1, 1)
    model = KMeans(n_clusters=n_clusters, random_state=0)
    flat_labels = model.fit_predict(intensities)
    centroid_rank = np.argsort(model.cluster_centers_, axis=0)
    return flat_labels.reshape(image.shape), centroid_rank

In [9]:
#Creating a temporal_mask dictionary so each cluster can be mapped efficiently. This step increased the segmentation accuracy dramatically: each mask is 1 for the pixels of one cluster and 0 elsewhere, and combining the masks reduces noise.
def temporal_mask_creation(cluster_labels, order):
    """Build one binary mask per cluster, keyed by the cluster's centroid rank.

    Args:
        cluster_labels: cluster id per pixel (from ``kmeans_segmentation``).
        order: cluster ids sorted by ascending centroid. Both a flat array
            and the (n_clusters, 1) array from ``kmeans_segmentation`` work.

    Returns:
        Dict mapping ``str(rank)`` ("0" = smallest centroid) to an int mask
        that is 1 where ``cluster_labels`` equals that cluster and 0
        elsewhere. One entry per element of ``order``; previously exactly six
        entries were hard-coded.
    """
    return {
        str(rank): (cluster_labels == cluster_id).astype(int)
        for rank, cluster_id in enumerate(order)
    }

In [10]:
#Defining functions to help plot our masks and to combine the masks into our final segmented output
def plot_masks(temporal_masks, rows=1, cols=2, slice_id=None):
    """Draw each temporal mask in its own subplot of a new figure.

    Args:
        temporal_masks: dict mapping a label string to a binary mask (2D, or
            3D with slices along axis 2).
        rows, cols: subplot grid. Masks beyond rows*cols are not drawn
            (``zip`` stops at the shorter sequence).
        slice_id: for 3D masks, the index along axis 2 to display. When None,
            slice 0 is shown. Ignored for 2D masks.

    The figure is laid out but not shown; the caller decides when to call
    ``plt.show()``.
    """
    # squeeze=False keeps `axes` a 2D array even for a 1x1 grid
    fig, axes = plt.subplots(rows, cols, figsize=(12, 6), squeeze=False)
    fig.suptitle('Temporal Masks')

    for ax, (key, mask) in zip(axes.flat, temporal_masks.items()):
        if mask.ndim > 2:
            # Pick exactly one 2D slice of a volume mask. Slicing once with the
            # requested index means slice_id no longer indexes an already-2D array.
            mask = mask[:, :, 0 if slice_id is None else slice_id]
        ax.imshow(mask, cmap='gray')
        ax.set_title(f'Mask: {key}')
        ax.axis('off')

    plt.tight_layout()


def temporal_masks2final_segmented_mask(temporal_masks, labels=range(6), slice_id=None):
    """Merge per-label binary masks into a single integer label image.

    Args:
        temporal_masks: dict keyed by ``str(label)``. Each value is a mask in
            which 1 marks that label's pixels. The "0" entry sets the output
            shape and dtype.
        labels: labels to paint, in order. Later labels overwrite earlier ones
            where masks overlap.
        slice_id: when given, masks are 3D and only slice ``slice_id`` along
            axis 2 is merged.

    Returns:
        Array filled with 0 and, wherever a mask equals 1, that mask's label.
    """
    reference = temporal_masks["0"]
    if slice_id is not None:
        reference = reference[:, :, 0]
    segmented_labels = np.zeros_like(reference)

    for lbl in labels:
        current = temporal_masks[str(lbl)]
        if slice_id is not None:
            current = current[:, :, slice_id]
        segmented_labels[current == 1] = lbl

    return segmented_labels

In [11]:
#function to showcase our final segmented image
def showcase_segmented_image(temporal_masks):
    """Collapse the temporal masks into one label image.

    Delegates to ``temporal_masks2final_segmented_mask`` with its default
    labels and no slice selection.
    """
    return temporal_masks2final_segmented_mask(temporal_masks)

In [12]:
#Defining our Dice metrics
def dice_coefficient(pred, truth):
    """Dice overlap between the foreground (value > 0) of ``pred`` and ``truth``.

    Multi-class label maps are binarised, so this scores foreground versus
    background only, not per-class agreement.

    Returns:
        2*|P & T| / (|P| + |T|), in [0, 1]. Returns 1.0 when neither input has
        any foreground (perfect agreement), instead of the NaN from 0/0.
    """
    pred = np.asarray(pred) > 0
    truth = np.asarray(truth) > 0
    total = pred.sum() + truth.sum()
    if total == 0:
        return 1.0
    intersection = np.logical_and(pred, truth).sum()
    return 2 * intersection / total

In [13]:
#Defining our Jaccard metrics
def jaccard_index(pred, truth):
    """Jaccard index (IoU) between the foreground (value > 0) of ``pred`` and ``truth``.

    Multi-class label maps are binarised, as in ``dice_coefficient``.

    Returns:
        |P & T| / |P | T|, in [0, 1]. Returns 1.0 when neither input has any
        foreground, instead of the NaN from 0/0.
    """
    pred = np.asarray(pred) > 0
    truth = np.asarray(truth) > 0
    intersection = np.logical_and(pred, truth).sum()
    # For bool arrays NumPy's `pred + truth` is already a logical OR, so the
    # union must not subtract the intersection again. Doing so yields the
    # symmetric difference and scores that can exceed 1.
    union = np.logical_or(pred, truth).sum()
    if union == 0:
        return 1.0
    return intersection / union

In [14]:
#Defining our accuracy metrics
def pixel_accuracy(pred, truth):
    """Fraction of pixels whose foreground/background status agrees.

    Both inputs are binarised with ``> 0`` before comparison. The count of
    agreeing pixels is divided by ``truth.size``.
    """
    agreement = (pred > 0) == (truth > 0)
    return np.sum(agreement) / truth.size

In [15]:
#Showcasing the performance of our segmentation algorithm vs. the ground truth
def _label_percentages(label_map, num_classes):
    """Percentage of pixels in ``label_map`` carrying each label 0..num_classes-1."""
    values, counts = np.unique(label_map, return_counts=True)
    histogram = np.zeros(num_classes)
    # Label values are used directly as indices into the histogram
    histogram[values] = counts
    return histogram / np.sum(histogram) * 100


def compare_label_distributions(ground_truth, prediction, num_classes):
    """Plot per-label pixel percentages of ground truth and prediction side by side.

    Args:
        ground_truth: integer label image/volume.
        prediction: integer label image/volume from the segmentation.
        num_classes: number of bars. Labels must be integers below this
            value, because they index the histogram directly.

    Shows the bar chart with ``plt.show()``.
    """
    gt_percentage = _label_percentages(ground_truth, num_classes)
    pred_percentage = _label_percentages(prediction, num_classes)

    positions = np.arange(num_classes)
    bar_width = 0.4

    _, ax = plt.subplots(figsize=(10, 6))
    ax.bar(positions - 0.2, gt_percentage, width=bar_width, label='Ground Truth (%)', color='blue')
    ax.bar(positions + 0.2, pred_percentage, width=bar_width, label='Prediction (%)', color='orange')

    ax.set_xticks(positions)
    ax.set_xticklabels([f"Label {i}" for i in range(num_classes)])
    ax.set_xlabel("Class Labels")
    ax.set_ylabel("Percentage of Pixels (%)")
    ax.set_title("Label Distribution: Ground Truth vs Prediction")
    ax.legend()

    plt.tight_layout()
    plt.show()

**TASK 1: 2D Segmentation: Now we start the 2D segmentation process, calling all the functions above and showcasing the results.**

In [16]:
# Load the MRI volume (T1) and its ground-truth segmentation (label) from the MATLAB file
data = loadmat('Brain.mat')
T1, label = data['T1'], data['label']

In [42]:
#Defining a segmentation function that applies all the steps
def full_segmentation_2d(slice_idx):
    """Run the 2D pipeline on one slice of the global ``T1`` and score it against ``label``.

    Pipeline:
    1. Volume-wide min-max normalisation.
    2. Gaussian smoothing.
    3. Otsu thresholding.
    4. Canny region mask, computed on the raw Otsu output.
    5. Mask applied to the normalised slice.
    6. 6-cluster K-means, then a label image rebuilt from the
       centroid-ranked temporal masks.

    Args:
        slice_idx: index along the last axis of ``T1`` / ``label``.

    Returns:
        (temporal_masks, segmented_img, dice_score, jaccard_score,
        accuracy_score, ground_truth, images, titles, colormaps). The last
        three are parallel lists describing each intermediate image, for
        ``showcase_in_grid``.
    """
    ground_truth = label[..., slice_idx]
    stages = []  # (image, title, colormap) for each visualised step, in order

    # Normalise the whole volume first so every slice shares one [0, 1] scale
    slice_normalized = normalizing(T1)[..., slice_idx]
    stages.append((slice_normalized, "Normalized Slice", "gray"))

    slice_smoothed = smoothing(slice_normalized)
    stages.append((slice_smoothed, "Smoothed Slice", "gray"))

    thresholded = otsu_thresholding(slice_smoothed)
    stages.append((thresholded, "Otsu Thresholding", "gray"))

    thresholded_img = showcase_thresholded(thresholded, slice_smoothed)
    stages.append((thresholded_img, "Effect of thresholding on Image", "gray"))

    # Canny runs on the raw Otsu output, not on the intensity-weighted image.
    # Option 3 (no morphological post-processing) replaces this with:
    #     edges = canny(thresholded_img, sigma=5)
    edges = applying_canny(thresholded)
    stages.append((edges, "Canny Edges", "gray"))

    img_thresholded = showcasing_canny(slice_normalized, edges)
    stages.append((img_thresholded, "Effect of Canny on Image", "gray"))

    # K-means on the Canny-masked slice, then labels rebuilt from the
    # centroid-ranked temporal masks (label 0 = cluster with the smallest centroid).
    # Option 2 (no temporal masks) uses the raw K-means labels instead:
    #     segmented_img = KMeans(n_clusters=6, random_state=0).fit_predict(
    #         img_thresholded.flatten().reshape(-1, 1)).reshape(img_thresholded.shape)
    cluster_labels, order = kmeans_segmentation(img_thresholded, n_clusters=6)
    temporal_masks = temporal_mask_creation(cluster_labels, order)
    segmented_img = showcase_segmented_image(temporal_masks)
    stages.append((segmented_img, "Segmented Image", "viridis"))

    dice_score = dice_coefficient(segmented_img, ground_truth)
    jaccard_score = jaccard_index(segmented_img, ground_truth)
    accuracy_score = pixel_accuracy(segmented_img, ground_truth)

    images, titles, colormaps = (list(column) for column in zip(*stages))
    return (temporal_masks, segmented_img, dice_score, jaccard_score,
            accuracy_score, ground_truth, images, titles, colormaps)

In [43]:
#Defining a function to showcase all segmentation process in a grid
def showcase_in_grid(images, titles, rows=2, cols=4, colormaps=None):
    """Display images in a rows x cols grid and show the figure.

    Args:
        images: 2D arrays to draw, in row-major grid order.
        titles: one title per image (``zip`` stops at the shorter list).
        rows, cols: grid shape. It must have at least as many cells as
            images drawn.
        colormaps: optional per-image colormap names. None entries use
            matplotlib's default colormap; when omitted, every image is gray.
    """
    # squeeze=False keeps `axes` a 2D array even for a 1x1 grid, where
    # plt.subplots would otherwise return a bare Axes with no .flatten().
    fig, axes = plt.subplots(rows, cols, figsize=(20, 10), squeeze=False)
    axes = axes.flatten()

    for i, (img, title) in enumerate(zip(images, titles)):
        cmap = colormaps[i] if colormaps else 'gray'
        axes[i].imshow(img, cmap=cmap)
        axes[i].set_title(title)
        axes[i].axis('off')

    # Hide the frames of any cells left empty
    for ax in axes[len(images):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
#Loop through the slices, showcasing each final segmented slice and its corresponding scores
# One metrics record per slice; turned into a table by display_metrics_table below
metrics_2D = []
for slice_idx in range(T1.shape[2]):
  # Run the complete 2D pipeline on this slice
  temporal_masks, segmented_img, dice_score, jaccard_score, accuracy_score, ground_truth, images, titles, colormaps= full_segmentation_2d(slice_idx)
  print(f"Slice {slice_idx}")
  print(f"Dice Coefficient: {dice_score:.3f}")
  print(f"Jaccard Index: {jaccard_score:.3f}")
  print(f"Pixel Accuracy: {accuracy_score:.3f}")

  metrics_2D.append({
              "Slice Index": slice_idx,
              "Dice Coefficient": dice_score,
              "Jaccard Index": jaccard_score,
              "Pixel Accuracy": accuracy_score
          })

  # plot_masks builds a 2x4 figure but does not show it. The plt.imshow /
  # plt.title / plt.colorbar calls below create no new figure, so they
  # presumably draw into that figure's current (last) subplot, and
  # plt.show() displays masks and segmented output together.
  # NOTE(review): confirm this layout is intended.
  plot_masks(temporal_masks, rows=2, cols=4)
  plt.imshow(segmented_img)
  plt.title('Segmented Output')
  plt.colorbar(ticks=range(6), label="Labels (0-5)")
  plt.show()
  print("---------------------------------")



In [None]:
#Loop through the slices, showcasing every processing step and the segmented result
for slice_idx in range(T1.shape[2]):
    print(f"Slice {slice_idx}")
    # Recompute the 2D pipeline for this slice; only the visualisation lists are used here
    (temporal_masks, segmented_img, dice_score, jaccard_score, accuracy_score,
     ground_truth, images, titles, colormaps) = full_segmentation_2d(slice_idx)
    showcase_in_grid(images, titles, rows=2, cols=4, colormaps=colormaps)

**Task 2: Comparing the results using Dice, Jaccard, and pixel-accuracy metrics.**

In [None]:
def display_metrics_table(metrics, num_slices, csv_path="segmentation_metrics.csv"):
    """Print a table of segmentation metrics and save it as CSV.

    Args:
        metrics: list of dicts, one table row each (e.g. the per-slice
            records built in the 2D loop).
        num_slices: accepted for call compatibility; not used.
        csv_path: destination of the CSV export. The default is the original
            fixed file name, so existing calls behave as before. Pass a
            distinct path to stop different runs (e.g. 2D and 3D) from
            overwriting each other's file.
    """
    metrics_df = pd.DataFrame(metrics)

    print("\nSegmentation Metrics for All Slices:\n")
    print(metrics_df)
    metrics_df.to_csv(csv_path, index=False)

num_slices = T1.shape[2]  # number of slices along T1's third axis
display_metrics_table(metrics=metrics_2D, num_slices=num_slices)

In [None]:
# For every slice, re-run the 2D pipeline and compare its label histogram with the ground truth
for slice_idx in range(num_slices):
    (temporal_masks, segmented_img, dice_score, jaccard_score, accuracy_score,
     ground_truth, images, titles, colormaps) = full_segmentation_2d(slice_idx)
    compare_label_distributions(ground_truth, segmented_img, num_classes=6)



**TASK 3: 3D Segmentation Process**

In [23]:
#Defining a segmentation function that applies all the steps to the whole 3D volume
def full_segmentation_3d(slice):
    """Segment the whole global ``T1`` volume in 3D, visualise the steps and score it.

    Reads the module-level ``T1`` (image volume) and ``label`` (ground truth).

    Args:
        slice: not used; the function always processes the global ``T1``.
            Kept so existing calls keep working. NOTE(review): the name also
            shadows the ``slice`` builtin inside this function.

    Returns:
        (temporal_masks, segmented_img, dice_score, jaccard_score, accuracy_score)

    Side effects:
        Shows a grid with the middle slice (along axis 2) of every
        intermediate volume, then the ground-truth vs. prediction
        label-distribution chart.
    """
    ground_truth = label

    images = []
    titles = []
    colormaps = []  # one colormap per displayed image

    def _middle(volume):
        # Representative 2D view of a volume: its central slice along axis 2
        return volume[:, :, volume.shape[2] // 2]

    # Step 1: min-max normalisation of the whole volume to [0, 1]
    slice_normalized = normalizing(T1)
    images.append(_middle(slice_normalized))
    titles.append("Normalized Slice")
    colormaps.append("gray")

    # Step 2: 3D Gaussian smoothing
    slice_smoothed = smoothing(slice_normalized)
    images.append(_middle(slice_smoothed))
    titles.append("Gaussian Smoothed Slice")
    colormaps.append("gray")

    # Step 3: multi-level Otsu; the class index is multiplied by the smoothed intensities
    thresholded = otsu_thresholding(slice_smoothed)
    thresholded_img = showcase_thresholded(thresholded, slice_smoothed)
    images.append(_middle(thresholded_img))
    titles.append("Otsu Thresholded Image")
    colormaps.append("gray")  # canonical name, as used for every other grayscale panel

    # Step 4: slice-wise Canny region masks
    edges = applying_canny(thresholded_img)
    images.append(_middle(edges))
    titles.append("Canny Edges")
    colormaps.append("gray")

    # Step 5: normalised volume restricted to the Canny regions
    img_thresholded = showcasing_canny(slice_normalized, edges)
    images.append(_middle(img_thresholded))
    titles.append("Masked Edges")
    colormaps.append(None)  # None -> matplotlib's default colormap

    # Step 6: K-means segmentation and temporal-mask label image.
    # NOTE(review): the 2D pipeline clusters the Canny-masked image
    # (img_thresholded), whereas this clusters the Otsu-weighted volume
    # (thresholded_img), so steps 4-5 only affect the display here. Confirm intended.
    cluster_labels, order = kmeans_segmentation(thresholded_img, n_clusters=6)
    temporal_masks = temporal_mask_creation(cluster_labels, order)
    segmented_img = showcase_segmented_image(temporal_masks)
    images.append(_middle(segmented_img))
    titles.append("Segmented Image")
    colormaps.append("viridis")

    showcase_in_grid(images, titles, rows=2, cols=4, colormaps=colormaps)

    # Step 7: label distributions, shown after the grid
    compare_label_distributions(ground_truth, segmented_img, num_classes=6)

    # Step 8: foreground metrics over the whole volume
    dice_score = dice_coefficient(segmented_img, ground_truth)
    jaccard_score = jaccard_index(segmented_img, ground_truth)
    accuracy_score = pixel_accuracy(segmented_img, ground_truth)

    return temporal_masks, segmented_img, dice_score, jaccard_score, accuracy_score

In [24]:
#Plotting grid to showcase all results
def showcase_in_grid_3d(images, titles, rows=2, cols=4, colormaps=None):
    """Display a mix of 2D arrays and matplotlib Figures in a rows x cols grid.

    Args:
        images: entries in row-major grid order. ``np.ndarray`` entries are
            drawn with ``imshow``. ``plt.Figure`` entries are rendered and
            their RGBA pixels shown as an image. Any other entry leaves its
            cell blank.
        titles: one title per entry (``zip`` stops at the shorter list).
        rows, cols: grid shape. It must have at least as many cells as
            entries drawn.
        colormaps: optional per-entry colormap names for array entries
            (default gray).

    Figure entries are assumed to use an Agg-based canvas (the only kind
    exposing ``buffer_rgba``). TODO confirm for non-inline backends.
    """
    # squeeze=False keeps `axes` a 2D array even for a 1x1 grid
    fig, axes = plt.subplots(rows, cols, figsize=(20, 10), squeeze=False)
    axes = axes.flatten()

    for i, (img, title) in enumerate(zip(images, titles)):
        ax = axes[i]
        if isinstance(img, np.ndarray):
            cmap = colormaps[i] if colormaps else 'gray'
            ax.imshow(img, cmap=cmap)
        elif isinstance(img, plt.Figure):
            img.canvas.draw()
            # Copy the rendered pixels via the public canvas accessor
            ax.imshow(np.array(img.canvas.buffer_rgba()))
        else:
            # Unsupported entry type: leave the cell empty
            ax.axis('off')
            continue
        ax.set_title(title)
        ax.axis('off')

    # Hide the frames of any cells left empty
    for ax in axes[len(images):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
#Showcasing results
# Run the full 3D pipeline on the T1 volume; this also shows the step grid and
# the label-distribution chart. T1 is passed directly rather than through a
# global named `slice`, which would shadow Python's built-in slice().
temporal_masks, segmented_img, dice_score, jaccard_score, accuracy_score = full_segmentation_3d(T1)

# The 3D run yields one score for the whole volume, so there is no single slice
# index to record (slice_idx left over from the 2D loops would be misleading).
metrics_3D = [{
    "Slice Index": "All (3D volume)",
    "Dice Coefficient": dice_score,
    "Jaccard Index": jaccard_score,
    "Pixel Accuracy": accuracy_score,
}]

# Metrics
print("-------------- Metrics -----------------")
print(f"Dice Coefficient: {dice_score:.3f}")
print(f"Jaccard Index: {jaccard_score:.3f}")
print(f"Pixel Accuracy: {accuracy_score:.3f}")
print("----------------------------------------")

# Showcasing the middle slice (along axis 2) of the segmented volume
plt.imshow(segmented_img[:, :, segmented_img.shape[2] // 2])
plt.title('Segmented Output')
plt.colorbar(ticks=range(6), label="Labels (0-5)")
plt.show()

# Showcasing the temporal masks that hold the cluster labels (slice 0 of each mask volume)
plot_masks(temporal_masks, rows=2, cols=4)

**Comparing the performance of our 2D and 3D segmentation processes**

In [None]:
# Total slices (display_metrics_table accepts this value but does not use it)
num_slices = T1.shape[2]

# 3D metrics table. NOTE: this also writes segmentation_metrics.csv,
# overwriting the 2D table saved by the earlier metrics cell.
display_metrics_table(metrics_3D, num_slices)

# Average the per-slice 2D scores so they can be read against the single 3D score above
print("\nMean 2D Metrics over All Slices:\n")
print(pd.DataFrame(metrics_2D).drop(columns="Slice Index").mean())