#**Libraries / path definition / various functions**

In [None]:
!pip install torchmetrics

In [None]:
import cv2
import sys
import joblib
import numpy as np
from skimage import io, feature
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import shuffle
import torch
import torchmetrics
from PIL import Image
from tqdm import tqdm
import os
import re
from functools import partial
import torch
import torchmetrics
from torchmetrics import JaccardIndex
import random
import matplotlib.pyplot as plt

In [None]:
# Project root directory; every other path below is derived from it
project_dir = '/content/gdrive/MyDrive/'

# Sub-directories of the project root holding the runs and the data
runs_directory = os.path.join(project_dir, 'runs')
data_directory = os.path.join(project_dir, 'data')

# Make the DinoV2 code directory importable by adding it to the module search path
sys.path.append(os.path.join(project_dir, "code/DinoV2/"))

In [None]:
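# The expected data directory structure is as follows
# (folder names must match the paths configured below):
# data
# |_Alhammadi
#    |_sample1
#    |  |_images
#    |  |  |_image_0001.tiff
#    |  |  |_image_0002.tiff
#    |  |  |_...
#    |  |_masks
#    |     |_mask_0001.tiff
#    |     |_mask_0002.tiff
#    |     |_...
#    |_sample2
#    |  |_...
#    |_sample3
#       |_...
# Images and masks are paired by the first 4-digit number found in their file names.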

# Set the input directories for the training sets (training on two rock sets)
image_dir1 = f"{data_directory}/Alhammadi/sample1/images"
mask_dir1 = f"{data_directory}/Alhammadi/sample1/masks"
image_dir2 = f"{data_directory}/Alhammadi/sample2/images"
mask_dir2 = f"{data_directory}/Alhammadi/sample2/masks"

# Set the directory for inference (predicting on one rock set);
# it must contain an 'images' and a 'masks' sub-directory
inference_directory = f"{data_directory}/Alhammadi/sample3/"

# Total number of image/mask pairs to load for training (split evenly between the training sets)
num_samples = 1000

# (height, width) of the centre crop applied to every image. Larger crop will cost more RAM
crop_size = (100, 100)

# Directory where the model will be saved; its name encodes the number of samples and the crop height
save_directory = f"{runs_directory}/RF_{num_samples}_{crop_size[0]}/"

# Mapping from mask pixel values to class labels. Adapt accordingly
mapping = {85: 0, 170: 1, 255: 2}

# Number of classes in the segmentation task
num_classes = 3

In [None]:
def non_local_means_filter(image, h=10, templateWindowSize=7, searchWindowSize=21):
    """
    Denoise an image with OpenCV's Non-Local Means algorithm.

    Thin wrapper around ``cv2.fastNlMeansDenoising``; no destination buffer is supplied.

    Parameters:
        image (ndarray): Image to denoise.
        h (int): Filter strength forwarded to OpenCV.
        templateWindowSize (int): Size in pixels of the template patch.
        searchWindowSize (int): Size in pixels of the window used to compute the weighted average.

    Returns:
        ndarray: Denoised image.
    """
    denoised = cv2.fastNlMeansDenoising(
        image,
        None,  # no pre-allocated destination
        h,
        templateWindowSize,
        searchWindowSize,
    )
    return denoised

def _numeric_key(file_path):
    """Return the first 4-digit number in the file's base name (0 when there is none)."""
    # Only the base name is searched: digits in a parent directory name must not
    # influence the ordering, otherwise every file would share the same key.
    match = re.search(r'\d{4}', os.path.basename(file_path))
    return int(match.group()) if match else 0


def _sorted_file_paths(directory):
    """Return the full paths of the regular files in `directory`, sorted by `_numeric_key`."""
    candidates = (os.path.join(directory, name) for name in os.listdir(directory))
    return sorted((path for path in candidates if os.path.isfile(path)), key=_numeric_key)


def load_images_and_masks(image_dirs, mask_dirs, limit=None, crop_size=None):
    """
    Load images and corresponding masks from directories.

    Within each (image directory, mask directory) couple, files are sorted by the first
    4-digit number in their file name and paired by position. The pairs are shuffled,
    optionally truncated, merged across directories and shuffled again. Every image is
    then read in grayscale, optionally centre-cropped and denoised with
    `non_local_means_filter` (h=15); masks are read and cropped but not denoised.

    Parameters:
        image_dirs (list): List of directories containing images.
        mask_dirs (list): List of directories containing masks (same order as `image_dirs`).
        limit (int, optional): Total number of pairs to load; each directory contributes at
            most ``limit // len(image_dirs)`` pairs.
        crop_size (tuple, optional): (height, width) to which images and masks are centre-cropped.

    Returns:
        tuple: Lists of images and masks (both empty when no pair is found).
    """
    assert len(image_dirs) == len(mask_dirs), "The number of image directories and mask directories must be the same"

    if not image_dirs:
        return [], []

    per_dir_limit = None if limit is None else limit // len(image_dirs)

    all_pairs = []
    for image_dir, mask_dir in zip(image_dirs, mask_dirs):
        image_files = _sorted_file_paths(image_dir)
        mask_files = _sorted_file_paths(mask_dir)

        # zip() silently drops the surplus files; make the mismatch visible because it
        # usually means that images and masks are no longer aligned.
        if len(image_files) != len(mask_files):
            print(
                f"Warning: {len(image_files)} images in {image_dir} but "
                f"{len(mask_files)} masks in {mask_dir}; extra files are ignored."
            )

        # Shuffle the pairs so that the per-directory limit keeps a random subset
        dir_pairs = list(zip(image_files, mask_files))
        random.shuffle(dir_pairs)

        if per_dir_limit is not None:
            dir_pairs = dir_pairs[:per_dir_limit]

        all_pairs.extend(dir_pairs)

    # Shuffle again so that samples from the different directories are interleaved
    random.shuffle(all_pairs)

    images = []
    masks = []

    for img_path, mask_path in tqdm(all_pairs, desc="Loading images and masks", total=len(all_pairs)):
        img = io.imread(img_path, as_gray=True)
        mask = io.imread(mask_path, as_gray=True)

        if crop_size is not None:
            img = center_crop(img, crop_size)
            mask = center_crop(mask, crop_size)

        # NOTE(review): cv2.fastNlMeansDenoising is documented for 8-bit input, while
        # io.imread(as_gray=True) converts colour images to float - confirm that the
        # source images are already single-channel 8-bit.
        img = non_local_means_filter(img, h=15, templateWindowSize=7, searchWindowSize=21)

        images.append(img)
        masks.append(mask)

    return images, masks

def center_crop(image, crop_size):
    """
    Extract the centred window of an image.

    Parameters:
        image (ndarray): Input image; its first two axes are used as (height, width).
        crop_size (tuple): (height, width) of the window to extract.

    Returns:
        ndarray: Cropped image, obtained by slicing the input along its first two axes.

    Raises:
        ValueError: If the requested window exceeds the image along either axis.
    """
    img_h, img_w = image.shape[:2]
    win_h, win_w = crop_size

    fits = img_h >= win_h and img_w >= win_w
    if not fits:
        raise ValueError("Crop size must be smaller than the image size")

    # Floor division: with an odd margin the window sits one pixel closer to the top-left
    row0 = (img_h - win_h) // 2
    col0 = (img_w - win_w) // 2

    return image[row0:row0 + win_h, col0:col0 + win_w]

def train_segmenter(features, masks):
    """
    Fit a Random Forest pixel classifier for image segmentation.

    The feature arrays are stacked into one (samples, n_features) matrix and the masks
    are flattened into the matching label vector; both are shuffled together (fixed
    seed 0) before fitting.

    Parameters:
        features (list): List of feature arrays (last axis = feature dimension).
        masks (list): List of mask arrays.

    Returns:
        RandomForestClassifier: Trained classifier.
    """
    feature_stack = np.array(features)
    label_stack = np.array(masks)

    X, y = shuffle(
        feature_stack.reshape(-1, feature_stack.shape[-1]),
        label_stack.reshape(-1),
        random_state=0,
    )

    forest_params = dict(
        n_estimators=150,
        n_jobs=-1,
        bootstrap=True,
        min_samples_leaf=1,
        min_samples_split=2,
        criterion='entropy',
        max_features=None,
        max_depth=15,
        max_samples=0.1,  # each tree is fitted on a 10% bootstrap sample
    )
    clf = RandomForestClassifier(**forest_params)
    clf.fit(X, y)

    return clf

def extract_features(images, sigma_min=1, sigma_max=16):
    """
    Compute multiscale basic features for every image of a list.

    Intensity and texture features are computed (edge features are disabled) and each
    image's feature volume is flattened to one row per pixel.

    Parameters:
        images (list): List of images.
        sigma_min (int): Minimum sigma value for feature extraction.
        sigma_max (int): Maximum sigma value for feature extraction.

    Returns:
        list: List of feature arrays, one (n_pixels, n_features) array per image.
    """
    def per_pixel_features(img):
        volume = feature.multiscale_basic_features(
            img,
            intensity=True,
            edges=False,
            texture=True,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            channel_axis=None,  # grayscale input: no channel axis
        )
        # Collapse the spatial axes, keep the feature axis
        return volume.reshape(-1, volume.shape[-1])

    return [per_pixel_features(img) for img in tqdm(images, desc='Extracting features')]

def extract_features_single_image(img, sigma_min=1, sigma_max=16):
    """
    Compute multiscale basic features for one image.

    Intensity and texture features are computed (edge features are disabled).

    Parameters:
        img (ndarray): Input image.
        sigma_min (int): Minimum sigma value for feature extraction.
        sigma_max (int): Maximum sigma value for feature extraction.

    Returns:
        ndarray: Flattened feature array with one row per pixel.
    """
    options = {
        'intensity': True,
        'edges': False,
        'texture': True,
        'sigma_min': sigma_min,
        'sigma_max': sigma_max,
        'channel_axis': None,  # grayscale input: no channel axis
    }
    volume = feature.multiscale_basic_features(img, **options)

    # Collapse the spatial axes, keep the feature axis
    n_features = volume.shape[-1]
    return volume.reshape(-1, n_features)

def segment_image(img, clf):
    """
    Segment an image using a trained classifier.

    Per-pixel features are computed with `extract_features_single_image` (default
    sigmas), classified in one `predict` call, and the flat predictions are reshaped
    back to the image shape.

    Parameters:
        img (ndarray): Input image.
        clf (RandomForestClassifier): Trained classifier.

    Returns:
        ndarray: Segmented image, same shape as `img`, holding the predicted label of each pixel.
    """
    # The feature matrix is already 2-D (n_pixels, n_features); re-stacking it with
    # np.vstack would only split it into one array per pixel and copy it back.
    features = extract_features_single_image(img)
    segmented_flat = clf.predict(features)
    return segmented_flat.reshape(img.shape)

def process_images(input_directory, crop_size, clf, mapping, num_classes):
    """
    Segment the images of a dataset and score the predictions against the ground-truth masks.

    Images are read from ``<input_directory>/images`` and masks from
    ``<input_directory>/masks``. Both file lists are sorted by the first 4-digit number
    in the file name and paired by position. Each image is centre-cropped, denoised and
    segmented; each mask is centre-cropped; `mapping` is applied to both the ground
    truth and the prediction before the IoU is computed.

    Parameters:
        input_directory (str): Directory containing the 'images' and 'masks' sub-directories.
        crop_size (tuple): (height, width) to which images and masks are centre-cropped.
        clf (RandomForestClassifier): Trained classifier.
        mapping (dict): Mapping of intensity values to class indices; every value found in
            the ground-truth and predicted masks must be one of its keys.
        num_classes (int): Number of classes.

    Returns:
        tuple: Average IoU over the scored pairs (NaN when no pair was scored) and list of
        tuples (image, ground truth mask, predicted mask).
    """
    jaccard = torchmetrics.classification.MulticlassJaccardIndex(num_classes=num_classes)

    image_dir = os.path.join(input_directory, 'images')
    mask_dir = os.path.join(input_directory, 'masks')

    # Sort key: the first 4-digit number in the file name (0 when there is none)
    def extract_number(file_name):
        match = re.search(r'\d{4}', file_name)
        return int(match.group()) if match else 0

    def sorted_files(directory):
        names = (f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)))
        return sorted(names, key=extract_number)

    image_files = sorted_files(image_dir)
    mask_files = sorted_files(mask_dir)

    # Only positions present in both lists can be scored
    n_pairs = min(len(image_files), len(mask_files))
    if len(image_files) != len(mask_files):
        print(
            f"Warning: {len(image_files)} images but {len(mask_files)} masks in "
            f"{input_directory}; only the first {n_pairs} sorted pairs are evaluated."
        )
    image_files = image_files[:n_pairs]
    mask_files = mask_files[:n_pairs]

    list_segmented_img = []
    image_mask_pred_list = []

    for image_file in tqdm(image_files, desc="Processing images"):
        img_path = os.path.join(image_dir, image_file)
        img = io.imread(img_path, as_gray=True)
        img = center_crop(img, crop_size)
        img = non_local_means_filter(img, h=15, templateWindowSize=7, searchWindowSize=21)

        segmented_img = segment_image(img, clf)
        list_segmented_img.append((img, segmented_img))

    # Initialize total IoU and count
    total_iou = 0
    count = 0

    # Score each prediction against the mask found at the same sorted position
    scoring_jobs = list(zip(image_files, mask_files, list_segmented_img))
    for image_file, mask_file, (img, pred_mask) in tqdm(scoring_jobs, desc="Calculating metrics"):
        # Load the ground truth mask as 8-bit grayscale; the context manager closes the file
        mask_path = os.path.join(mask_dir, mask_file)
        with Image.open(mask_path) as mask_image:
            ground_truth_mask = np.array(mask_image.convert('L'))

        # Crop the mask the same way the image was cropped
        ground_truth_mask = center_crop(ground_truth_mask, crop_size)

        # Convert to PyTorch tensors (copies, so the NumPy arrays keep their raw values)
        ground_truth_mask_tensor = torch.tensor(ground_truth_mask, dtype=torch.int64)
        pred_mask_tensor = torch.tensor(pred_mask, dtype=torch.int64)

        # Translate raw intensity values into class indices
        ground_truth_mask_tensor = ground_truth_mask_tensor.apply_(lambda x: mapping[x])
        pred_mask_tensor = pred_mask_tensor.apply_(lambda x: mapping[x])

        # Ensure the tensors have the same shape
        if ground_truth_mask_tensor.shape != pred_mask_tensor.shape:
            print(f"Skipping image {image_file} due to shape mismatch.")
            continue

        # Calculate IoU using torchmetrics
        mIoU = jaccard(pred_mask_tensor, ground_truth_mask_tensor).item()

        # Update total IoU and count
        total_iou += mIoU
        count += 1

        # Append the tuple (image, mask, prediction) to the list
        image_mask_pred_list.append((img, ground_truth_mask, pred_mask))

    # Average IoU; NaN signals that nothing could be scored instead of dividing by zero
    average_iou = total_iou / count if count else float('nan')

    return average_iou, image_mask_pred_list

def display_random_set(image_mask_pred_list):
    """
    Show one randomly chosen (image, ground truth mask, predicted mask) triplet.

    The three arrays are drawn side by side in grayscale in a single matplotlib figure.

    Parameters:
        image_mask_pred_list (list): List of tuples (image, ground truth mask, predicted mask).
    """
    # Pick one triplet at random and unpack it
    img, ground_truth_mask, pred_mask = random.choice(image_mask_pred_list)

    panels = (
        (img, 'Image'),
        (ground_truth_mask, 'Ground Truth Mask'),
        (pred_mask, 'Prediction Mask'),
    )

    # One axis per panel, left to right
    _fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for axis, (picture, title) in zip(axes, panels):
        axis.imshow(picture, cmap='gray')
        axis.set_title(title)
    plt.show()

# **Training**

In [None]:
# Load, crop and denoise the training images together with their masks
images, masks = load_images_and_masks(
    [image_dir1, image_dir2],
    [mask_dir1, mask_dir2],
    limit=num_samples,
    crop_size=crop_size,
)

# Compute the per-pixel features of the loaded images
features = extract_features(images)

# Fit the Random Forest segmenter on the features and the corresponding masks
clf = train_segmenter(features, masks)

# Create the save directory on first use
if not os.path.exists(save_directory):
    os.makedirs(save_directory)

# Persist the trained classifier; the file name includes the number of samples
joblib.dump(clf, f"{save_directory}clf_{num_samples}.pkl")

# **Inference**

In [None]:
# Reload the classifier saved by the training step
clf = joblib.load(f"{save_directory}clf_{num_samples}.pkl")

# Segment the inference set, compute the average Intersection over Union (IoU)
# and collect the (image, ground truth mask, predicted mask) triplets
iou, list_img = process_images(
    inference_directory,
    crop_size,
    clf,
    mapping,
    num_classes,
)

# Report the average IoU to evaluate the performance of the segmentation
print(f"average iou = {iou}")

# **Display**

In [None]:
# Show one randomly chosen (image, ground truth mask, predicted mask) triplet from the inference results
display_random_set(list_img)