In [None]:
import os
import glob

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from natsort import natsorted
import re
import seaborn as sns


# Raise OpenCV's decoded-image size limit so very large slide images can be read.
# The variable must be named OPENCV_IO_MAX_IMAGE_PIXELS (the previous name was
# ignored by OpenCV) and must be set before `import cv2`, because OpenCV reads it
# when the library is loaded.
os.environ["OPENCV_IO_MAX_IMAGE_PIXELS"] = str(2 ** 40)

import cv2


In [None]:
def load_image(path):
    """Read an image file from disk with OpenCV.

    Args:
        path: File path (str or os.PathLike).

    Returns:
        The decoded image as returned by ``cv2.imread`` with its default flags.

    Raises:
        OSError: If OpenCV could not read or decode the file (for example a
            missing file or an unsupported format).
    """
    path = os.fspath(path)
    image = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising; without
    # this check the failure would surface later as an opaque cv2.error.
    if image is None:
        raise OSError(f"cv2.imread could not read image file: {path!r}")
    return image

def convert_to_grayscale(image):
    """Collapse a colour image in OpenCV's BGR channel order into one gray channel."""
    bgr_to_gray = cv2.COLOR_BGR2GRAY
    gray = cv2.cvtColor(image, bgr_to_gray)
    return gray

def threshold_image(grayscale_image, threshold):
    """Binarize a grayscale image with ``cv2.THRESH_BINARY``.

    Args:
        grayscale_image: Single-channel input image.
        threshold: Sequence whose first item is the cut-off value and whose
            second item is the value written to pixels above the cut-off
            (all other pixels become 0).

    Returns:
        The thresholded image (the retval computed by cv2.threshold is discarded).
    """
    cutoff = threshold[0]
    fill_value = threshold[1]
    result = cv2.threshold(grayscale_image, cutoff, fill_value, cv2.THRESH_BINARY)
    return result[1]

def post_process(binary_image, dilation):
    """Morphologically close a binary image (dilate, then erode) to fill small gaps.

    Args:
        binary_image: Binary input image.
        dilation: Side length, in pixels, of the square all-ones structuring
            element used for the closing (despite the name, a full closing is
            applied, not just a dilation).

    Returns:
        The closed image.
    """
    square_kernel = np.ones((dilation, dilation), dtype=np.uint8)
    return cv2.morphologyEx(binary_image, cv2.MORPH_CLOSE, square_kernel)

def extract_contours(binary_image):
    """Find the outermost contours of the non-zero regions in a binary image.

    Only external contours are retrieved (``cv2.RETR_EXTERNAL``), so holes inside
    a region are not returned. Boundary points are compressed with
    ``cv2.CHAIN_APPROX_SIMPLE``.

    Args:
        binary_image: Single-channel binary image.

    Returns:
        The sequence of contours found by ``cv2.findContours``.
    """
    # OpenCV 3.x returns (image, contours, hierarchy) while 2.x and 4.x return
    # (contours, hierarchy); taking the second-to-last item works for every version.
    return cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

def create_binary_mask(image_shape, contours):
    """Rasterize contours into a filled uint8 mask.

    Args:
        image_shape: Shape of the output mask.
        contours: Contours to paint, as produced by ``cv2.findContours``.

    Returns:
        A uint8 array of ``image_shape`` that is 255 inside every contour and 0
        elsewhere.
    """
    canvas = np.zeros(image_shape, dtype=np.uint8)
    # Index -1 draws every contour; FILLED paints the interior, not just the outline.
    fill_value = 255
    cv2.drawContours(canvas, contours, -1, fill_value, thickness=cv2.FILLED)
    return canvas

def compute_jaccard_index(mask1, mask2, smooth=1.0):
    """Smoothed Jaccard index (intersection over union) of two 0-255 masks.

    Pixel values are scaled from 0-255 to 0-1 before comparison. Strictly binary
    (0/255) masks therefore behave as pixel sets; intermediate values act as
    soft memberships.

    Args:
        mask1: Array of values in 0..255.
        mask2: Array of values in 0..255 with the same shape as ``mask1``.
        smooth: Constant added to both numerator and denominator. The default
            of 1.0 reproduces the historical behaviour, where two empty masks
            score 1.0 instead of 0/0. With ``smooth=0`` the plain Jaccard index
            is returned (nan when both masks are empty).

    Returns:
        ``(|A & B| + smooth) / (|A| + |B| - |A & B| + smooth)`` as a float.

    Raises:
        ValueError: If the masks do not have the same shape. Flattening
            mismatched masks of equal size (e.g. a transposed image) would
            otherwise silently compare misaligned pixels.
    """
    mask1 = np.asarray(mask1)
    mask2 = np.asarray(mask2)
    if mask1.shape != mask2.shape:
        raise ValueError(
            f"Masks must have the same shape, got {mask1.shape} and {mask2.shape}"
        )
    # Element-wise arithmetic works on N-D arrays directly; no flatten copy needed.
    a = mask1 / 255  # Convert from 0-255 to 0-1 range
    b = mask2 / 255
    intersection = np.sum(a * b)
    union = np.sum(a) + np.sum(b) - intersection
    return (intersection + smooth) / (union + smooth)

def process_image(path, threshold=(50,255), dilation=25):
    """Build a filled foreground mask from the image file at ``path``.

    Pipeline: load -> grayscale -> binary threshold -> morphological closing ->
    external contours -> filled uint8 mask (0/255) with the shape of the
    closed grayscale image.

    Args:
        path: Image file path.
        threshold: (cut-off, fill value) pair passed to ``threshold_image``.
        dilation: Kernel side length passed to ``post_process``.

    Returns:
        The filled binary mask produced by ``create_binary_mask``.
    """
    binary = threshold_image(convert_to_grayscale(load_image(path)), threshold)
    closed = post_process(binary, dilation)
    outlines = extract_contours(closed)
    return create_binary_mask(closed.shape, outlines)


In [None]:
# Pair each multiplex group with one registered H&E slice and compute, for every
# multiplex slice position within the group ("Set 1".."Set 4"), the Jaccard index
# between the multiplex mask and the H&E mask.

HE_ALIGNED_DIR = "./data/66-4/processed_32/clean_dust_bubbles/registered/elastic registration"
MULTIPLEX_ROOT = "./data/grouped_multiplex"
N_SETS = 4  # slice positions tracked per group; the plotting cell below expects exactly 4

MULTIPLEX_THRESHOLD = (50, 255)
MULTIPLEX_DILATION = 30
# NOTE(review): THRESH_BINARY keeps pixels *brighter* than 240. On a bright-background
# H&E scan that selects background rather than tissue; confirm the registered H&E
# images have a dark background (or invert them) before trusting these scores.
HE_THRESHOLD = (240, 255)
HE_DILATION = 25


def _he_slice_number(path):
    """Return the slice number embedded as `_<digits>_` in an H&E file name, or None."""
    match = re.search(r'_(\d+)_', os.path.basename(path))
    return int(match.group(1)) if match else None


def _multiplex_slice_number(path):
    """Return the slice number of a multiplex file named `<number>.<ext>`."""
    return int(os.path.basename(path).split(".")[0])


# glob/listdir order is filesystem-dependent. Natsort so that the "first matching
# H&E slice" and each slice's position (set) within its group are deterministic.
he_aligned_images = natsorted(glob.glob(os.path.join(HE_ALIGNED_DIR, "*.tif")))
he_slices = [(_he_slice_number(p), p) for p in he_aligned_images]
multiplex_groups = [
    g for g in natsorted(os.listdir(MULTIPLEX_ROOT))
    if os.path.isdir(os.path.join(MULTIPLEX_ROOT, g))
]

he_subgroups = []  # H&E path paired with each group, in group order
results_sets = [[] for _ in range(N_SETS)]

for group in multiplex_groups:
    multiplex_subgroup = natsorted(
        glob.glob(os.path.join(MULTIPLEX_ROOT, group, "registered", "*.tif"))
    )
    if not multiplex_subgroup:
        raise FileNotFoundError(f"No registered multiplex .tif files for group {group!r}")
    slice_numbers = [_multiplex_slice_number(p) for p in multiplex_subgroup]
    min_slice, max_slice = min(slice_numbers), max(slice_numbers)

    # First H&E slice lying strictly between the group's outermost multiplex slices.
    he_path = next(
        (p for n, p in he_slices if n is not None and min_slice < n < max_slice),
        None,
    )
    if he_path is None:
        # Skipping silently would pair every later group with the wrong H&E slice.
        raise ValueError(
            f"No H&E slice strictly between {min_slice} and {max_slice} for group {group!r}"
        )
    he_subgroups.append(he_path)

    # The H&E mask is the same for every slice in the group, so compute it once.
    mask_he = process_image(he_path, threshold=HE_THRESHOLD, dilation=HE_DILATION)
    # Slices beyond the first N_SETS positions are not tracked.
    for position, multiplex_path in enumerate(multiplex_subgroup[:N_SETS]):
        mask_multiplex = process_image(
            multiplex_path, threshold=MULTIPLEX_THRESHOLD, dilation=MULTIPLEX_DILATION
        )
        results_sets[position].append(compute_jaccard_index(mask_multiplex, mask_he))

# Named per-set lists, as consumed by the seaborn plotting cell.
results_set1, results_set2, results_set3, results_set4 = results_sets
data = [results_set1, results_set2, results_set3, results_set4]

plt.boxplot(data)
plt.xticks(range(1, N_SETS + 1), [f"Set {k}" for k in range(1, N_SETS + 1)])
plt.ylabel("Jaccard Index")
plt.show()

In [None]:
# Reshape the four per-set result lists into long format: one row per Jaccard
# value, labelled with the set it belongs to.
per_set_results = [results_set1, results_set2, results_set3, results_set4]
data = {
    'List': [f'Set {k}' for k, values in enumerate(per_set_results, start=1) for _ in values],
    'Values': [value for values in per_set_results for value in values],
}
df = pd.DataFrame(data)

# Whitegrid theme, then a narrow box plot of the Jaccard index per set.
sns.set(style="whitegrid")
fig, ax = plt.subplots(figsize=(6, 4), dpi=600)
sns.boxplot(x='List', y='Values', width=0.4, data=df, ax=ax)
ax.set_ylabel('Jaccard Index')

plt.show()