#**Libraries / path definition**

In [None]:
!pip install torchmetrics

In [None]:
import os
import re
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import tifffile
from PIL import Image
import torch
import torch.nn.functional as F
import sys
import torchmetrics
from skimage import io
from torchmetrics.classification import MulticlassJaccardIndex
import random

In [None]:
# Root of the project on the mounted Google Drive.
project_dir = '/content/gdrive/MyDrive/'

# Dataset / sample selection and segmentation settings.
# All data are stored as TIFF files (or an equivalent format). The dataset
# folder is expected under <project_dir>/data/ with this layout:
# <dataset_name>
# |_<sample_name>            (e.g. Sample1)
# |  |_images
# |  |  |_image1.tiff
# |  |  |_image2.tiff
# |  |  |_...
# |  |_masks
# |     |_mask1.tiff
# |     |_mask2.tiff
# |     |_...
# |_...
dataset_name = "some_dataset"
sample_name = "sample3"
num_classes = 3          # number of segmentation classes; also used as k for k-means
crop_size = (560, 560)   # (height, width) of the centered evaluation crop

######

# Make modules under <project_dir>/code/DinoV2/ importable from later cells.
sys.path.append(os.path.join(project_dir, "code/DinoV2/"))

# <project_dir>/data, and inside it the folder of the sample to segment.
data_directory = os.path.join(project_dir, 'data')
input_directory = os.path.join(data_directory, dataset_name, sample_name)

# **Various functions**




In [None]:
def k_means(image, k=3, attempts=10):
    """Segment an image by k-means clustering of its pixel intensities.

    Every element of ``image`` is treated as a one-feature sample. Each pixel
    is then replaced by the centre of the cluster it was assigned to, so the
    result contains at most ``k`` distinct values.

    Args:
        image: NumPy array of intensities (any shape).
        k: Number of clusters.
        attempts: How many times ``cv2.kmeans`` is run with different random
            initial centres; OpenCV keeps the most compact labelling.

    Returns:
        Array with the same shape and dtype as ``image`` holding, per pixel,
        the centre of its cluster. For integer dtypes the centres are clipped
        to the dtype's range and truncated toward zero (identical to the
        previous ``np.uint8`` cast for 8-bit input). Float images keep float
        centres.
    """
    # cv2.kmeans wants float32 samples laid out as (n_samples, n_features).
    pixel_vals = np.float32(image.reshape((-1, 1)))

    # Stop after 100 iterations or once the centres move by less than 0.2.
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)

    _, labels, centers = cv2.kmeans(pixel_vals, k, None, criteria, attempts, cv2.KMEANS_RANDOM_CENTERS)

    # Cast the centres back to the input dtype instead of a hard-coded uint8.
    # uint8 truncated float images in [0, 1] to 0/1 and wrapped 16-bit
    # intensities, merging distinct clusters into the same value.
    out_dtype = image.dtype
    if np.issubdtype(out_dtype, np.integer):
        info = np.iinfo(out_dtype)
        centers = np.clip(centers, info.min, info.max)
    centers = centers.astype(out_dtype)

    # Look up each pixel's centre and restore the original image shape.
    segmented_data = centers[labels.flatten()]
    return segmented_data.reshape(image.shape)

# Define the Non-Local Means filter function
def non_local_means_filter(image, h=10, templateWindowSize=7, searchWindowSize=21):
    """Denoise an image with OpenCV's fast Non-Local Means filter.

    Thin wrapper around ``cv2.fastNlMeansDenoising``. ``dst`` is passed as
    ``None`` so OpenCV allocates and returns a new output array.

    Args:
        image: Image handed straight to OpenCV. NOTE(review): this overload
            of ``fastNlMeansDenoising`` expects 8-bit input per the OpenCV
            docs; float images would likely be rejected — confirm for your data.
        h: Filter strength; larger values remove more noise but also detail.
        templateWindowSize: Side length (pixels) of the patch compared
            between pixels.
        searchWindowSize: Side length (pixels) of the window searched for
            similar patches.

    Returns:
        The denoised image produced by OpenCV.
    """
    denoised = cv2.fastNlMeansDenoising(
        image,
        None,
        h,
        templateWindowSize,
        searchWindowSize,
    )
    return denoised

# Define a function to center crop the image
def center_crop(image, crop_size):
    """Return the central ``crop_size`` window of ``image``.

    Args:
        image: Array whose first two axes are height and width. Any trailing
            axes (e.g. channels) are kept untouched.
        crop_size: ``(crop_height, crop_width)``, or a single integer for a
            square crop.

    Returns:
        The cropped region. For NumPy arrays this is a basic slice, i.e. a
        view sharing memory with ``image``. When the leftover margin is odd,
        the extra row/column is dropped from the bottom/right.

    Raises:
        ValueError: If a crop dimension is smaller than 1 or larger than the
            corresponding image dimension. A crop equal to the image size is
            allowed and returns the whole image.
    """
    if np.ndim(crop_size) == 0:
        # Scalar → square crop.
        crop_size = (crop_size, crop_size)
    crop_height, crop_width = crop_size
    height, width = image.shape[:2]

    if crop_height < 1 or crop_width < 1:
        raise ValueError(f"Crop size must be positive, got {tuple(crop_size)}")
    if height < crop_height or width < crop_width:
        raise ValueError(
            f"Crop size {tuple(crop_size)} must not exceed the image size {(height, width)}"
        )

    top = (height - crop_height) // 2
    left = (width - crop_width) // 2
    return image[top:top + crop_height, left:left + crop_width]

def mapping(target, pred, num_classes):
    """Relabel a ground-truth mask and a prediction to contiguous class ids.

    The distinct values of each array are sorted and every value is replaced
    by its rank (smallest value -> 0, next -> 1, ...). ``target`` and ``pred``
    are ranked independently, so classes are paired by intensity order, not
    by their raw values. NOTE(review): if one array lacks a class the other
    has, the ranks shift and the pairing is off — confirm this is acceptable.

    Unlike the previous implementation, values beyond the first
    ``num_classes`` are no longer left with their raw value. A raw value
    could collide with a valid id, e.g. 1.0 in a float mask silently merging
    into class 1. Such values now get ranks >= ``num_classes``, which the
    Jaccard metric rejects explicitly.

    Args:
        target: Ground-truth label array.
        pred: Predicted label array.
        num_classes: Expected number of classes. The relabelling itself does
            not need it; it is kept for interface compatibility.

    Returns:
        ``(target_ids, pred_ids)``: new arrays with the same shape and dtype
        as the inputs. The inputs themselves are no longer modified in place.
    """
    def _to_ranks(values):
        # Replace each element by the index of its value among the sorted uniques.
        values = np.asarray(values)
        _, ranks = np.unique(values, return_inverse=True)
        # reshape: the shape of the inverse array differs across NumPy versions.
        return ranks.reshape(values.shape).astype(values.dtype)

    return _to_ranks(target), _to_ranks(pred)


def process_images(input_directory, crop_size, num_classes):
    """Segment every image of a sample with k-means and score it against its mask.

    For each image/mask pair the image is:
    - denoised with Non-Local Means;
    - clustered with ``k_means`` (``k = num_classes``);
    - center-cropped to ``crop_size``, together with the prediction and the mask.

    Both label maps are then relabelled with ``mapping`` and scored with
    torchmetrics' multiclass Jaccard index (mIoU).

    Files are paired by sorting both folders on the first 4-digit run in the
    file name (0 if none), with the file name as tie-breaker.

    Args:
        input_directory: Sample folder containing ``images/`` and ``masks/``.
        crop_size: Crop passed to ``center_crop``.
        num_classes: Number of k-means clusters and of Jaccard classes.

    Returns:
        ``(average_iou, results, list_iou)``:
        - ``average_iou``: mean mIoU over successfully scored pairs, 0 if none.
        - ``results``: one ``(cropped_image, cropped_mask, cropped_pred)``
          tuple per scored pair.
        - ``list_iou``: the per-pair scores, in the same order.
    """
    image_dir = os.path.join(input_directory, 'images')
    mask_dir = os.path.join(input_directory, 'masks')

    def list_files(directory):
        # Regular files only; sub-directories are ignored.
        return [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]

    def sort_key(file_name):
        # The name breaks ties (e.g. several files without a 4-digit run) so the
        # pairing does not depend on the arbitrary order of os.listdir.
        match = re.search(r'\d{4}', file_name)
        return (int(match.group()) if match else 0, file_name)

    image_files = sorted(list_files(image_dir), key=sort_key)
    mask_files = sorted(list_files(mask_dir), key=sort_key)

    if len(image_files) != len(mask_files):
        # zip() stops at the shorter list, so the trailing files are skipped.
        print(f"Warning: found {len(image_files)} images but {len(mask_files)} masks; "
              f"only the first {min(len(image_files), len(mask_files))} pairs are processed.")
    combined = list(zip(image_files, mask_files))

    jaccard = torchmetrics.classification.MulticlassJaccardIndex(num_classes=num_classes)

    results = []
    list_iou = []

    for img_file, mask_file in tqdm(combined, desc="Loading images and masks", total=len(combined)):
        # NOTE(review): with as_gray=True, skimage converts colour files to
        # float64, which cv2.fastNlMeansDenoising would likely reject — confirm
        # the TIFFs are single-channel 8-bit.
        img = io.imread(os.path.join(image_dir, img_file), as_gray=True)
        mask = io.imread(os.path.join(mask_dir, mask_file), as_gray=True)

        filtered_image = non_local_means_filter(img, h=15, templateWindowSize=7, searchWindowSize=21)
        k_means_segmented = k_means(filtered_image, k=num_classes)

        cropped_scanner = center_crop(img, crop_size)
        cropped_pred = center_crop(k_means_segmented, crop_size)
        cropped_mask = center_crop(mask, crop_size)

        cropped_mask, cropped_pred = mapping(cropped_mask, cropped_pred, num_classes)

        # Integer label maps with a leading batch axis of size 1.
        ground_truth_mask_tensor = torch.tensor(cropped_mask, dtype=torch.uint8).unsqueeze(0)
        kmeans_mask_tensor = torch.tensor(cropped_pred, dtype=torch.uint8).unsqueeze(0)

        try:
            # Calling the metric returns the score of this pair only; its
            # accumulated internal state is never read.
            miou = jaccard(kmeans_mask_tensor, ground_truth_mask_tensor).item()
        except Exception as e:
            # Best effort: skip pairs the metric rejects (e.g. ids >= num_classes).
            print(f"Error calculating mIoU for {img_file} / {mask_file}: {e}")
            continue

        list_iou.append(miou)
        results.append((cropped_scanner, cropped_mask, cropped_pred))

    if list_iou:
        average_iou = sum(list_iou) / len(list_iou)
    else:
        average_iou = 0
        print("No valid image-mask pairs processed.")

    return average_iou, results, list_iou

def display_random_set(image_mask_pred_list):
    """Plot one randomly chosen (image, ground-truth mask, prediction) triple.

    The three arrays are shown side by side in grayscale.

    Args:
        image_mask_pred_list: Sequence of ``(image, ground_truth_mask,
            pred_mask)`` tuples, e.g. the ``results`` returned by
            ``process_images``.

    Returns:
        None. If the sequence is empty, a message is printed and nothing is
        plotted, instead of ``random.choice`` raising ``IndexError``.
    """
    if len(image_mask_pred_list) == 0:
        print("No image/mask/prediction sets to display.")
        return None

    img, ground_truth_mask, pred_mask = random.choice(image_mask_pred_list)
    _, ax = plt.subplots(1, 3, figsize=(15, 5))
    # The prediction comes from k-means clustering (k_means), not fuzzy c-means.
    panels = (
        (img, 'Image'),
        (ground_truth_mask, 'Ground Truth Mask'),
        (pred_mask, 'K-means segmented'),
    )
    for axis, (data, title) in zip(ax, panels):
        axis.imshow(data, cmap='gray')
        axis.set_title(title)
    plt.show()
    return None

# **Segmentation**

In [None]:
# Denoise, cluster and score every image/mask pair of the selected sample.
average_iou, results, iou = process_images(input_directory, crop_size, num_classes)
print("Average IoU: {:.4f}".format(average_iou))

# **Display**

In [None]:
# Show one randomly picked (image, ground truth, prediction) triple.
display_random_set(image_mask_pred_list=results)