#**Libraries / path definition**

In [None]:
!pip install torchmetrics

In [None]:
import os
import re
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import tifffile
from PIL import Image
import torch
import torch.nn.functional as F
import sys
import torchmetrics
from skimage import io
from torchmetrics.classification import MulticlassJaccardIndex
import random

In [None]:
# Root directory of the project (code and data live below it).
project_dir = '/content/gdrive/MyDrive/'

# Name of the dataset folder and of the sample inside it that will be segmented.
# All data are in tiff format (or equivalent).
# The expected directory structure below project_dir is as follows:
# data
# |_<dataset_name>
# |  |_<sample_name>
# |  |  |_images
# |  |  |  |_image1.tiff
# |  |  |  |_image2.tiff
# |  |  |  |_...
# |  |  |_masks
# |  |     |_mask1.tiff
# |  |     |_mask2.tiff
# |  |     |_...
# |  |_...
# ...
# Images and masks are paired by sorting both folders on the first
# four-digit number found in each file name.
dataset_name = "some_dataset"
sample_name = "sample3"
# Number of classes used for label remapping and for the Jaccard index.
num_classes = 3
# (height, width) of the central window kept from every image, mask and prediction.
crop_size = (560, 560)
# Arguments for processing the images. Can be modified.
# The user needs to set fixed_values according to the data to be processed
# and the number of classes.
args = {
    'fixed_values': [85, 170, 255],  # output intensity per cluster, assigned in order of increasing cluster centre
    'num_bit': 8,  # forwarded to FCM as image_bit
    'num_cluster': 3,  # number of FCM clusters
    'fuzziness': 2,  # FCM fuzziness exponent m
    'epsilon': 0.005,  # FCM convergence threshold
    'max_iteration': 1000,  # FCM iteration limit
}

######

# Add the DinoV2 code directory to the system path for module imports
sys.path.append(os.path.join(project_dir, "code/DinoV2/"))

# Define the data directory path within the project directory
data_directory = os.path.join(project_dir, 'data')

# Define the input directory path to the dataset for segmentation
input_directory = os.path.join(data_directory, dataset_name, sample_name)

# **Various functions**




In [None]:
class FCM():
    """Fuzzy C-Means clustering of a 2D grayscale image on pixel intensity.

    Typical use: build the instance, call ``form_clusters()``, then read
    ``result`` (segmented image), ``C`` (cluster centres) and ``U``
    (membership matrix, one row per pixel and one column per cluster).
    """

    def __init__(self, image, fixed_values, image_bit, n_clusters, m, epsilon, max_iter):
        """
        Initializes the FCM (Fuzzy C-Means) clustering instance.

        Parameters:
            image (np.ndarray): Input grayscale image (2D).
            fixed_values (list): Output intensities, one per cluster. The
                smallest value is assigned to the cluster with the smallest
                centre, and so on.
            image_bit (int): Bit depth of the image. Stored on the instance;
                the clustering itself does not read it.
            n_clusters (int): Number of clusters.
            m (float): Fuzziness parameter (> 1). The membership update uses
                the exponent 2 / (m - 1), which is undefined for m == 1.
            epsilon (float): Convergence threshold on the summed absolute
                change of the membership matrix between two iterations.
            max_iter (int): Maximum number of iterations, or -1 to iterate
                until convergence only.

        Raises:
            ValueError: If any argument violates the constraints above.
        """
        image = np.asarray(image)
        if image.ndim != 2:
            raise ValueError("<image> needs to be 2D (gray scale image).")
        if n_clusters <= 0 or n_clusters != int(n_clusters):
            raise ValueError("<n_clusters> needs to be positive integer.")
        if m <= 1:
            raise ValueError("<m> needs to be > 1.")
        if epsilon <= 0:
            raise ValueError("<epsilon> needs to be > 0")
        if len(fixed_values) != n_clusters:
            # Otherwise some cluster would have no output value to map to.
            raise ValueError("<fixed_values> needs exactly one value per cluster.")

        self.image = image
        self.image_bit = image_bit
        # The check above accepts integral floats such as 3.0; array shapes need an int.
        self.n_clusters = int(n_clusters)
        self.m = m
        self.epsilon = epsilon
        self.max_iter = max_iter
        self.fixed_values = fixed_values

        self.shape = image.shape
        self.X = image.flatten().astype('float')
        self.numPixels = image.size
        self.C = None

    def initial_U(self):
        """
        Builds the initial membership matrix U.

        Pixels are assigned round-robin: pixel i starts with full membership
        in cluster ``i % n_clusters`` and zero membership elsewhere.

        Returns:
            np.ndarray: The initialized membership matrix, shape
            (numPixels, n_clusters).
        """
        U = np.zeros((self.numPixels, self.n_clusters))
        pixel_idx = np.arange(self.numPixels)
        U[pixel_idx, pixel_idx % self.n_clusters] = 1
        return U

    def update_U(self):
        """
        Updates the membership matrix U based on the current cluster centers C.

        Implements u_ij = d_ij^(-p) / sum_k d_ik^(-p) with p = 2 / (m - 1).
        Distances are first divided by the row minimum, so every term lies in
        [0, 1] and the row sum is at least 1 (no overflow, no 0/0).
        A pixel that lies exactly on one or more centres gets its membership
        shared equally between those centres and zero elsewhere.

        Returns:
            np.ndarray: The updated membership matrix.
        """
        power = 2.0 / (self.m - 1)
        dist = np.abs(self.X[:, None] - self.C[None, :])
        nearest = dist.min(axis=1, keepdims=True)
        exact = dist == 0
        with np.errstate(divide='ignore', invalid='ignore'):
            # 0/0 (pixel exactly on a centre) yields NaN here; fixed on the next line.
            rel = (nearest / dist) ** power
        rel[exact] = 1.0
        return rel / rel.sum(axis=1, keepdims=True)

    def update_C(self):
        """
        Updates the cluster centers C based on the current membership matrix U.

        Each centre is the mean of the pixel values weighted by U ** m.
        A cluster whose total weight is zero keeps its previous centre (or the
        mean pixel value when there is no previous centre) instead of
        becoming NaN.

        Returns:
            np.ndarray: The updated cluster centers.
        """
        weights = self.U ** self.m
        totals = np.sum(weights, axis=0)
        sums = np.dot(self.X, weights)
        empty = totals == 0
        if not empty.any():
            return sums / totals

        previous = getattr(self, 'C', None)
        if previous is None:
            fill = np.mean(self.X) if self.X.size else np.nan
            previous = np.full(self.n_clusters, fill)
        centers = np.array(previous, dtype=float)
        filled = ~empty
        centers[filled] = sums[filled] / totals[filled]
        return centers

    def form_clusters(self):
        """
        Forms clusters by iteratively updating C and U until convergence or
        until the iteration limit is reached, then segments the image.

        At least one iteration is always performed. With ``max_iter == -1``
        only the convergence test stops the loop.

        Sets:
            self.U: Updated membership matrix.
            self.C: Updated cluster centers.
            self.result: Segmented image (via segmentImage).
        """
        unbounded = self.max_iter == -1
        self.C = None
        self.U = self.initial_U()
        iteration = 0
        while True:
            self.C = self.update_C()
            previous_U = self.U
            self.U = self.update_U()
            iteration += 1
            change = np.sum(np.abs(self.U - previous_U))
            # Written as "not >=" so that a NaN change also stops the loop.
            if not change >= self.epsilon:
                break
            if not unbounded and iteration >= self.max_iter:
                break
        self.segmentImage()

    def deFuzzify(self):
        """
        Converts the fuzzy membership matrix U to a crisp classification.

        Returns:
            np.ndarray: Index of the highest-membership cluster for each pixel.
        """
        return np.argmax(self.U, axis=1)

    def map_clusters_to_fixed_values(self, centroids, fixed_values):
        """
        Maps cluster indices to fixed intensity values.

        The cluster with the smallest centre receives the smallest fixed
        value, the next centre the next value, and so on.

        Parameters:
            centroids (np.ndarray): Array of cluster centers.
            fixed_values (list): List of fixed intensity values to map to.

        Returns:
            dict: Mapping from cluster index to fixed intensity value.
        """
        order = np.argsort(centroids)
        ranked_values = sorted(fixed_values)
        return {int(order[rank]): value for rank, value in enumerate(ranked_values)}

    def segmentImage(self):
        """
        Segments the image by de-fuzzifying the membership matrix and mapping clusters to fixed values.

        Sets:
            self.result: The segmented image, same shape as the input.

        Returns:
            np.ndarray: The segmented image.
        """
        labels = self.deFuzzify()
        value_of = self.map_clusters_to_fixed_values(self.C, self.fixed_values)
        # A lookup table indexed by cluster label replaces a per-pixel dict lookup.
        lookup = np.array([value_of[k] for k in range(self.n_clusters)])
        self.result = lookup[labels].reshape(self.shape).astype('int')
        return self.result

# Define the Non-Local Means filter function
def non_local_means_filter(image, h=10, templateWindowSize=7, searchWindowSize=21):
    """Denoise ``image`` with OpenCV's fast non-local means filter.

    Parameters:
        image: Image to denoise, passed to OpenCV as the source.
            NOTE(review): OpenCV documents an 8-bit input here - confirm callers comply.
        h: Filter strength, forwarded unchanged.
        templateWindowSize: Template patch size, forwarded unchanged.
        searchWindowSize: Search window size, forwarded unchanged.

    Returns:
        The denoised image returned by ``cv2.fastNlMeansDenoising``.
    """
    filter_params = (h, templateWindowSize, searchWindowSize)
    # No destination buffer is supplied (None); OpenCV allocates the output.
    denoised = cv2.fastNlMeansDenoising(image, None, *filter_params)
    return denoised

# Define a function to center crop the image
def center_crop(image, crop_size):
    """Return the central ``crop_size`` window of ``image``.

    Parameters:
        image: Array whose first two axes are (height, width).
        crop_size: ``(crop_height, crop_width)`` of the window to keep.

    Returns:
        A slice of ``image`` of the requested size. When the margin is odd,
        the extra row/column is left on the bottom/right side.

    Raises:
        ValueError: If the requested window is larger than the image along
            either axis.
    """
    img_h, img_w = image.shape[:2]
    win_h, win_w = crop_size

    if win_h > img_h or win_w > img_w:
        raise ValueError("Crop size must be smaller than the image size")

    row_start = (img_h - win_h) // 2
    col_start = (img_w - win_w) // 2
    row_stop = row_start + win_h
    col_stop = col_start + win_w
    return image[row_start:row_stop, col_start:col_stop]

def _relabel_by_rank(labels, num_classes):
    """Replace, in place, each distinct value of ``labels`` by its rank (0, 1, ...)."""
    snapshot = np.copy(labels)
    # Compare against the snapshot so that a value written earlier in the loop
    # is never mistaken for an original value and remapped a second time.
    for new_label, old_value in zip(range(num_classes), np.unique(snapshot)):
        if old_value != new_label:
            labels[snapshot == old_value] = new_label
    return labels


def mapping(target, pred, num_classes):
    """Relabel ground truth and prediction to class indices 0..num_classes-1.

    Each array is relabelled independently: its smallest distinct value
    becomes 0, the next one 1, and so on. Arrays that already use
    0..num_classes-1 are left unchanged.

    Parameters:
        target (np.ndarray): Ground-truth label image. Modified in place.
        pred (np.ndarray): Predicted label image. Modified in place.
        num_classes (int): Number of classes; at most this many distinct
            values are relabelled per array.

    Returns:
        tuple: ``(target, pred)``, the same array objects that were passed in.

    Notes:
        - If an array holds more than ``num_classes`` distinct values, the
          largest surplus values are left untouched.
        - Because ranks are computed per array, a class that is absent from
          one array shifts the indices of the classes above it in that
          array only.
    """
    return _relabel_by_rank(target, num_classes), _relabel_by_rank(pred, num_classes)


def process_images(input_directory, crop_size, num_classes, args):
    """Segment every image of a sample with FCM and score it against its mask.

    For each image/mask pair the image is denoised (non-local means),
    clustered with FCM, and the image, mask and prediction are centre
    cropped. Mask and prediction are then relabelled to class indices and
    compared with the multiclass Jaccard index.

    Parameters:
        input_directory (str): Directory containing an ``images`` and a
            ``masks`` sub-directory.
        crop_size (tuple): ``(height, width)`` of the centre crop.
        num_classes (int): Number of classes for relabelling and scoring.
        args (dict): FCM settings; the keys read are ``fixed_values``,
            ``num_bit``, ``num_cluster``, ``fuzziness``, ``epsilon`` and
            ``max_iteration``.

    Returns:
        tuple: ``(average_iou, results)``. ``average_iou`` is the mean of the
        per-pair scores (0 when no pair could be scored). ``results`` is a
        list of ``(cropped image, relabelled mask, relabelled prediction)``
        tuples, one per scored pair.
    """
    image_dir = os.path.join(input_directory, 'images')
    mask_dir = os.path.join(input_directory, 'masks')

    def list_files(directory):
        # Regular files only; sub-directories are ignored.
        return [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]

    def extract_number(file_name):
        # Sort key: first four-digit run in the name, 0 when there is none.
        match = re.search(r'\d{4}', file_name)
        return int(match.group()) if match else 0

    image_files = sorted(list_files(image_dir), key=extract_number)
    mask_files = sorted(list_files(mask_dir), key=extract_number)

    if len(image_files) != len(mask_files):
        # zip() below stops at the shorter list, so say so instead of
        # silently dropping the unpaired files.
        print(f"Warning: {len(image_files)} images but {len(mask_files)} masks; "
              "unpaired files are ignored.")

    combined = list(zip(image_files, mask_files))
    jaccard = torchmetrics.classification.MulticlassJaccardIndex(num_classes=num_classes)

    results = []
    total_iou = 0
    count = 0

    for img_file, mask_file in tqdm(combined, desc="Loading images and masks", total=len(combined)):
        img_path = os.path.join(image_dir, img_file)
        mask_path = os.path.join(mask_dir, mask_file)

        img = io.imread(img_path, as_gray=True)
        mask = io.imread(mask_path, as_gray=True)

        filtered_image = non_local_means_filter(img, h=15, templateWindowSize=7, searchWindowSize=21)

        # Apply FCM clustering to the filtered image
        cluster = FCM(
            filtered_image,
            fixed_values=args['fixed_values'],
            image_bit=args['num_bit'],
            n_clusters=args['num_cluster'],
            m=args['fuzziness'],
            epsilon=args['epsilon'],
            max_iter=args['max_iteration'],
        )
        cluster.form_clusters()
        fcm_segmented = cluster.result

        cropped_scanner = center_crop(img, crop_size)
        cropped_pred = center_crop(fcm_segmented, crop_size)
        cropped_mask = center_crop(mask, crop_size)

        # Bring mask and prediction onto the same 0..num_classes-1 label set.
        cropped_mask, cropped_pred = mapping(cropped_mask, cropped_pred, num_classes)

        # Convert to uint8 tensors and add a leading batch dimension
        ground_truth_mask_tensor = torch.tensor(cropped_mask, dtype=torch.uint8).unsqueeze(0)
        fcm_mask_tensor = torch.tensor(cropped_pred, dtype=torch.uint8).unsqueeze(0)

        # Best effort: a pair that cannot be scored is reported and skipped.
        try:
            mIoU = jaccard(fcm_mask_tensor, ground_truth_mask_tensor).item()
        except Exception as e:
            print(f"Error calculating mIoU: {e}")
            continue

        total_iou += mIoU
        count += 1

        results.append((cropped_scanner, cropped_mask, cropped_pred))

    if count > 0:
        average_iou = total_iou / count
    else:
        average_iou = 0
        print("No valid image-mask pairs processed.")

    return average_iou, results

def display_random_set(image_mask_pred_list):
    """Show one randomly chosen (image, ground truth, prediction) triplet.

    Parameters:
        image_mask_pred_list (list): Tuples of
            ``(image, ground_truth_mask, pred_mask)``.

    Returns:
        None. The three panels are drawn side by side in grayscale. When the
        list is empty a message is printed and nothing is drawn.
    """
    if len(image_mask_pred_list) == 0:
        # random.choice raises IndexError on an empty sequence.
        print("No results to display.")
        return

    img, ground_truth_mask, pred_mask = random.choice(image_mask_pred_list)
    panels = (
        (img, 'Image'),
        (ground_truth_mask, 'Ground Truth Mask'),
        (pred_mask, 'FCM segmented'),
    )

    _, ax = plt.subplots(1, 3, figsize=(15, 5))
    for axis, (data, title) in zip(ax, panels):
        axis.imshow(data, cmap='gray')
        axis.set_title(title)
    plt.show()

# **Segmentation**

In [None]:
# Segment every image/mask pair of the selected sample and report the mean IoU.
average_iou, results = process_images(input_directory, crop_size, num_classes, args)
print(f"Average IoU: {average_iou:.4f}")

# **Display**

In [None]:
# Show one randomly chosen (image, ground truth mask, FCM prediction) triplet.
display_random_set(results)