## Imports

In [1]:
%load_ext autoreload
%autoreload 2

# Standard imports
import glob

# 3rd party imports
import cv2
import matplotlib.pyplot as plt
import numpy as np 
from pprint import pprint
import SimpleITK as sitk
sitk.ProcessObject_SetGlobalWarningDisplay(False)
from scipy import ndimage

## Functions

In [2]:
### Functions for IO operations
def load_channels(filepath: str, idx: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Load the two channels of one section from a glob of TIFF files.

    The sorted file list is expected to hold the channels in consecutive
    pairs (ch1, ch2, ch1, ch2, ...), so section ``idx`` is read from
    positions ``2 * idx`` and ``2 * idx + 1``. The matched paths and basic
    statistics of channel 1 are printed for inspection.

    Args:
        filepath: str, glob pattern matching the TIFF files
        idx: int, index of the section to load (negative values count
            back from the last complete section)

    Returns:
        (curr_ch1, curr_ch2): tuple of np.ndarray, the two channel images

    Raises:
        IndexError: if ``idx`` does not refer to a complete pair of files
    """
    n_channels = 2
    filepaths = sorted(glob.glob(filepath))
    n_slices = len(filepaths) // n_channels
    print(f"Found {n_slices} slices")
    pprint(filepaths)

    # Validate before reading anything so a bad index never half-loads a section.
    if not -n_slices <= idx < n_slices:
        raise IndexError(
            f"Section index {idx} is out of range: found {n_slices} slices matching {filepath!r}"
        )
    if idx < 0:
        idx += n_slices  # Resolve negative indices against complete pairs only

    # Load the image
    file_idx = idx * n_channels  # Channels are stored in consecutive pairs
    ch1_path, ch2_path = filepaths[file_idx], filepaths[file_idx + 1]
    curr_ch1 = read_tif(ch1_path)
    curr_ch2 = read_tif(ch2_path)
    print("\nCh1:", ch1_path)
    print("Ch2:", ch2_path)

    # Check image stats
    print(f"\nChannel shape: {curr_ch1.shape}")
    print(f"Channel dtype: {curr_ch1.dtype}")
    print(f"Channel 1 min: {curr_ch1.min()}")
    print(f"Channel 1 max: {curr_ch1.max()}")
    print(f"Channel 1 mean: {curr_ch1.mean()}")

    return curr_ch1, curr_ch2


def load_3_channels(filepath: str, idx: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Load the three channels of one section from a glob of TIFF files.

    The sorted file list is expected to hold the channels in consecutive
    triples (ch1, ch2, ch3, ch1, ...), so section ``idx`` is read from
    positions ``3 * idx`` to ``3 * idx + 2``. The matched paths and basic
    statistics of channel 1 are printed for inspection.

    Args:
        filepath: str, glob pattern matching the TIFF files
        idx: int, index of the section to load (negative values count
            back from the last complete section)

    Returns:
        (curr_ch1, curr_ch2, curr_ch3): tuple of np.ndarray, the three channel images

    Raises:
        IndexError: if ``idx`` does not refer to a complete triple of files
    """
    n_channels = 3
    filepaths = sorted(glob.glob(filepath))
    n_slices = len(filepaths) // n_channels
    print(f"Found {n_slices} slices")
    pprint(filepaths)

    # Validate before reading anything so a bad index never half-loads a section.
    if not -n_slices <= idx < n_slices:
        raise IndexError(
            f"Section index {idx} is out of range: found {n_slices} slices matching {filepath!r}"
        )
    if idx < 0:
        idx += n_slices  # Resolve negative indices against complete triples only

    # Load the image
    file_idx = idx * n_channels  # Channels are stored in consecutive triples
    ch1_path, ch2_path, ch3_path = filepaths[file_idx:file_idx + n_channels]
    curr_ch1 = read_tif(ch1_path)
    curr_ch2 = read_tif(ch2_path)
    curr_ch3 = read_tif(ch3_path)
    print("\nCh1:", ch1_path)
    print("Ch2:", ch2_path)
    print("Ch3:", ch3_path)

    # Check image stats
    print(f"\nChannel shape: {curr_ch1.shape}")
    print(f"Channel dtype: {curr_ch1.dtype}")
    print(f"Channel 1 min: {curr_ch1.min()}")
    print(f"Channel 1 max: {curr_ch1.max()}")
    print(f"Channel 1 mean: {curr_ch1.mean()}")

    return curr_ch1, curr_ch2, curr_ch3



def read_tif(filepath):
    """
    Load a TIFF file with SimpleITK and hand back its pixels as a NumPy array.

    Args:
        filepath: str, location of the TIFF file on disk

    Returns:
        image: np.ndarray, pixel data as produced by sitk.GetArrayFromImage
    """
    return sitk.GetArrayFromImage(sitk.ReadImage(filepath))


def auto_contrast(data: np.ndarray, alpha: float = None, beta: float = None) -> np.ndarray:
    """
    Linearly rescale an image and convert it to 8 bits.

    Computes ``|data * alpha + beta|`` with ``cv2.convertScaleAbs``, which
    saturates the result to the uint8 range. Adapted from
    https://stackoverflow.com/questions/56905592/automatic-contrast-and-brightness-adjustment-of-a-color-photo-of-a-sheet-of-pape

    Args:
        data: np.ndarray, input image
        alpha: float, gain. If None, it is chosen so that the data range maps
            onto [0, dtype max] for integer images, or onto [0, 255] for
            floating-point images. A constant image gets a gain of 1.0.
        beta: float, offset. If None, it is chosen so that the data minimum
            maps to 0.

    Returns:
        img: np.ndarray, 8-bit rescaled image
    """
    if alpha is None:
        # Integer images map onto their dtype's max; floats have no iinfo,
        # so target the uint8 output range of convertScaleAbs instead.
        if np.issubdtype(data.dtype, np.integer):
            target_max = np.iinfo(data.dtype).max
        else:
            target_max = 255.0
        # Work in Python floats: max - min can overflow for small signed dtypes.
        data_range = float(np.max(data)) - float(np.min(data))
        # Guard against division by zero on constant images.
        alpha = target_max / data_range if data_range > 0 else 1.0
    if beta is None:
        # Cast before negating: negating an unsigned NumPy scalar wraps around
        # (e.g. -np.uint8(10) == 246) instead of giving a negative offset.
        beta = -float(np.min(data)) * alpha
    img = cv2.convertScaleAbs(data.copy(), alpha=alpha, beta=beta)
    return img


def gamma_correction(image: np.ndarray, gamma: float=2.0, min_value=None, max_value=None) -> np.ndarray:
    """
    Apply a power-law (gamma) transform while keeping the image's intensity scale.

    Each pixel is mapped to ``(pixel / max_value) ** gamma * max_value``.

    Args:
        image: np.ndarray, input image (never modified in place)
        gamma: float, exponent; values > 1 darken mid-tones, values < 1 brighten them
        min_value: optional, pixels below this value are set to 0 first
        max_value: optional, pixels above this value are clipped to it, and it is
            used as the normalisation constant. Defaults to the image maximum
            (taken after ``min_value`` has been applied).

    Returns:
        image_enhanced: np.ndarray, gamma corrected image (floating point)
    """
    if min_value is not None or max_value is not None:
        # Work on a copy so the caller's array is left untouched.
        image = image.copy()
    if min_value is not None:
        image[image < min_value] = 0
    if max_value is None:
        max_value = image.max()
    else:
        image[image > max_value] = max_value
    # Normalize the image to the range [0, 1]. With max_value == 0 this would
    # be 0/0 = NaN, so divide by 1 instead; the rescale by max_value (0) below
    # still yields zeros.
    divisor = max_value if max_value != 0 else 1
    image_normalized = image / divisor
    # Apply the exponential transformation
    image_enhanced = np.power(image_normalized, gamma)
    # Rescale the image back to the original intensity range
    return image_enhanced * max_value


def save_figure(image, filename, contours=None, *, dpi=600, figsize=(20, 20), close=False):
    """
    Save a grayscale image, optionally with a red contour overlay, to disk.

    Args:
        image: np.ndarray, input image
        filename: str, path to save the image
        contours: np.ndarray, optional array drawn as thin red contours over the image
        dpi: int, resolution of the saved file (keyword-only)
        figsize: tuple, figure size passed to plt.figure (keyword-only)
        close: bool, close the figure after saving (keyword-only). The default
            False leaves the figure open, e.g. for inline display; pass True in
            loops or scripts so figures do not accumulate in memory.
    """
    fig = plt.figure(figsize=figsize)
    plt.imshow(image, cmap='gray')
    if contours is not None:
        plt.contour(contours, colors='red', linewidths=0.15, alpha=0.35)
    plt.axis('off')
    plt.savefig(filename, dpi=dpi, bbox_inches='tight')
    print(f"Saved figure to {filename}")
    if close:
        plt.close(fig)
    
    
def show(image: np.ndarray, contour: np.ndarray = None,
         image2: np.ndarray = None, contour2: np.ndarray = None, contour_alpha: float = 0.75,
         title: str = "", title2: str = "",
         xlim: tuple[int, int] = None, ylim: tuple[int, int] = None,
         xlim2: tuple[int, int] = None, ylim2: tuple[int, int] = None,
         axis: bool = True,
         figsize: tuple[int, int] = (10, 10)):
    """
    Display one image, or two images side by side, in grayscale.

    Each panel can carry an optional red contour overlay and its own axis
    limits. The figure is shown with ``plt.show()`` and then closed.

    Args:
        image: np.ndarray, first (or only) image
        contour: np.ndarray, optional array drawn as red contours over ``image``
        image2: np.ndarray, optional second image; when given, two panels are drawn
        contour2: np.ndarray, optional contour overlay for ``image2``
        contour_alpha: float, opacity of the contour lines in both panels
        title: str, title of the first panel
        title2: str, title of the second panel
        xlim, ylim: tuple[int, int], optional axis limits for the first panel
        xlim2, ylim2: tuple[int, int], optional axis limits for the second panel
        axis: bool, value passed to plt.axis for every panel (show/hide axes)
        figsize: tuple[int, int], figure size passed to plt.figure
    """
    def _draw_panel(panel_image, panel_contour, panel_title, panel_xlim, panel_ylim):
        # Render one grayscale panel on the current axes with optional overlay and zoom.
        plt.imshow(panel_image, cmap='gray')
        plt.title(panel_title)
        if panel_contour is not None:
            plt.contour(panel_contour, colors='red', linewidths=0.5, alpha=contour_alpha)
        if panel_xlim is not None:
            plt.xlim(panel_xlim)
        if panel_ylim is not None:
            plt.ylim(panel_ylim)
        plt.axis(axis)

    f = plt.figure(figsize=figsize)
    try:
        if image2 is not None:
            # Two images: display them side by side
            plt.subplot(1, 2, 1)
            _draw_panel(image, contour, title, xlim, ylim)
            plt.subplot(1, 2, 2)
            _draw_panel(image2, contour2, title2, xlim2, ylim2)
        else:
            _draw_panel(image, contour, title, xlim, ylim)
        plt.show()
    finally:
        # Release the figure even if drawing fails.
        f.clear()
        plt.close(f)
    
    
def show3(image: np.ndarray, contour: np.ndarray = None,
          image2: np.ndarray = None, contour2: np.ndarray = None,
          image3: np.ndarray = None, contour3: np.ndarray = None,
          contour_alpha: float = 0.75,
          title: str = "", title2: str = "", title3: str = "",
          xlim: tuple[int, int] = None, ylim: tuple[int, int] = None,
          xlim2: tuple[int, int] = None, ylim2: tuple[int, int] = None,
          xlim3: tuple[int, int] = None, ylim3: tuple[int, int] = None,
          axis: bool = True,
          figsize: tuple[int, int] = (20, 10)):
    """
    Display up to three grayscale images in a single row.

    Each panel can carry an optional red contour overlay and its own axis
    limits. A panel whose image is None is left blank (axes hidden). The
    figure is shown with ``plt.show()`` and then closed.

    Args:
        image, image2, image3: np.ndarray, images for panels 1-3
        contour, contour2, contour3: np.ndarray, optional red contour overlays per panel
        contour_alpha: float, opacity of the contour lines in all panels
        title, title2, title3: str, panel titles
        xlim, ylim / xlim2, ylim2 / xlim3, ylim3: tuple[int, int], optional
            per-panel axis limits
        axis: bool, value passed to plt.axis for every drawn panel
        figsize: tuple[int, int], figure size passed to plt.figure
    """
    def _draw_panel(position, panel_image, panel_contour, panel_title, panel_xlim, panel_ylim):
        # Render one grayscale panel at `position` in the 1x3 grid.
        plt.subplot(1, 3, position)
        if panel_image is None:
            # Nothing to draw in this slot; hide the empty axes rather than
            # passing None to imshow.
            plt.axis('off')
            return
        plt.imshow(panel_image, cmap='gray')
        plt.title(panel_title)
        if panel_contour is not None:
            plt.contour(panel_contour, colors='red', linewidths=0.5, alpha=contour_alpha)
        if panel_xlim is not None:
            plt.xlim(panel_xlim)
        if panel_ylim is not None:
            plt.ylim(panel_ylim)
        plt.axis(axis)

    f = plt.figure(figsize=figsize)
    try:
        _draw_panel(1, image, contour, title, xlim, ylim)
        _draw_panel(2, image2, contour2, title2, xlim2, ylim2)
        _draw_panel(3, image3, contour3, title3, xlim3, ylim3)
        plt.show()
    finally:
        # Release the figure even if drawing fails.
        f.clear()
        plt.close(f)

In [3]:
### Functions for vessel detection
import itk
import numpy as np
from skimage.morphology import remove_small_objects, binary_closing, disk, remove_small_holes

# Parameters for vessel detection
#ALPHA = 0.5  # Default 0.5
#BETA = 0.5  # Default 0.5
#GAMMA = 5.0  # Default 5

def detect_vessels(input_image: np.ndarray, min_sigma: float=1.0, max_sigma: float=10.0, num_steps: int=10,
                   alpha=0.5, beta=0.5, gamma=5.0):
    """
    Compute a multi-scale Hessian-based vesselness (objectness) map with ITK.

    Runs ``itk.MultiScaleHessianBasedMeasureImageFilter`` driven by a
    ``HessianToObjectnessMeasureImageFilter`` configured for dark structures
    on a bright background. The response is then rescaled to an unsigned-char
    image and returned as float32.

    Args:
        input_image: np.ndarray, input image; converted to contiguous float32
            before being handed to ITK
        min_sigma: float, smallest Hessian scale searched
        max_sigma: float, largest Hessian scale searched
        num_steps: int, number of scales between min_sigma and max_sigma
            (logarithmically spaced)
        alpha: float, weight for R_A in the objectness measure
        beta: float, weight for R_B in the objectness measure
        gamma: float, weight for S (Frobenius norm of the Hessian)

    Returns:
        segmented_vessels_array: np.ndarray (float32), vesselness map with the
            same shape as the input, in the unsigned-char value range
    """
    # ITK's Python wrapping only instantiates these filters for selected pixel
    # types; converting to float32 gives an itk.F image regardless of the
    # caller's dtype (a float32, C-contiguous input is passed through as-is).
    itk_image = itk.image_from_array(np.ascontiguousarray(input_image, dtype=np.float32))

    ImageType = type(itk_image)
    Dimension = itk_image.GetImageDimension()
    HessianPixelType = itk.SymmetricSecondRankTensor[itk.D, Dimension]
    HessianImageType = itk.Image[HessianPixelType, Dimension]

    objectness_filter = itk.HessianToObjectnessMeasureImageFilter[
        HessianImageType, ImageType
    ].New()
    # Look for dark structures on a bright background (True = bright structures).
    objectness_filter.SetBrightObject(False)
    # Leave the objectness measure unscaled.
    objectness_filter.SetScaleObjectnessMeasure(False)
    # Alpha: weight for R_A (ratio of the smallest eigenvalue that has to be
    # large to the larger ones). Smaller values increase sensitivity to the
    # object dimensionality.
    objectness_filter.SetAlpha(alpha)
    # Beta: weight for R_B (ratio of the largest eigenvalue that has to be
    # small to the larger ones). Smaller values increase sensitivity to the
    # object dimensionality.
    objectness_filter.SetBeta(beta)
    # Gamma: weight for S, the Frobenius norm of the Hessian (second-order
    # structureness). S scales with absolute intensity, so a suitable gamma
    # depends on the input's intensity range.
    objectness_filter.SetGamma(gamma)
    # NOTE(review): ObjectDimension is left at ITK's default; confirm it
    # selects line-like (vessel) structures for this use.

    multi_scale_filter = itk.MultiScaleHessianBasedMeasureImageFilter[
        ImageType, HessianImageType, ImageType
    ].New()
    multi_scale_filter.SetInput(itk_image)
    multi_scale_filter.SetHessianToMeasureFilter(objectness_filter)
    multi_scale_filter.SetSigmaStepMethodToLogarithmic()
    multi_scale_filter.SetSigmaMinimum(min_sigma)
    multi_scale_filter.SetSigmaMaximum(max_sigma)
    multi_scale_filter.SetNumberOfSigmaSteps(num_steps)

    OutputPixelType = itk.UC
    OutputImageType = itk.Image[OutputPixelType, Dimension]

    # Rescale the multi-scale response into an unsigned-char image.
    rescale_filter = itk.RescaleIntensityImageFilter[ImageType, OutputImageType].New()
    rescale_filter.SetInput(multi_scale_filter)
    rescale_filter.Update()

    # Get numpy array; converting uint8 -> float32 yields an independent copy.
    segmented_vessels = rescale_filter.GetOutput()
    segmented_vessels_array = itk.array_view_from_image(segmented_vessels)
    segmented_vessels_array = np.asarray(segmented_vessels_array, dtype=np.float32)
    return segmented_vessels_array


def process_vessels(vessel_image: np.ndarray, thresh: int, min_size: int=10, area_threshold: float=2000, smoothing: int=3):
    """
    Turn a vesselness map into a cleaned-up binary mask.

    Pixels whose value is NOT above ``thresh`` are kept (the comparison is
    inverted). Small connected components are then discarded, small holes
    are filled, and the outline is smoothed with a morphological closing.

    NOTE(review): keeping the low-response side of the threshold is what the
    code does; confirm this inversion is intended.

    Args:
        vessel_image: np.ndarray, vesselness map
        thresh: int, cutoff; pixels with value <= thresh are kept
        min_size: int, components smaller than this are removed
        area_threshold: float, holes up to this area are filled
        smoothing: int, radius of the disk footprint used for closing

    Returns:
        thresholded_vessels: np.ndarray, binary mask
    """
    keep = ~(vessel_image > thresh)

    # Drop tiny specks first, then plug small gaps inside the kept regions.
    keep = remove_small_objects(keep, min_size=min_size)
    keep = remove_small_holes(keep, area_threshold=area_threshold)

    footprint = disk(smoothing)
    return binary_closing(keep, footprint=footprint)


def get_brain_mask(brain_image, area_threshold=300000, min_size=10000):
    """
    Get a binary mask of the brain (foreground) from an image (run before contrast enhancement).

    Uses OpenCV's triangle thresholding, which picks the threshold
    automatically and requires an 8-bit single-channel image.

    Args:
        brain_image: np.ndarray, 8-bit single-channel input image
        area_threshold: int, holes up to this area (in pixels) are filled
        min_size: int, foreground components smaller than this (in pixels) are removed

    Returns:
        mask: np.ndarray of bool, mask of the brain
    """
    # With THRESH_TRIANGLE the threshold is computed; the 0 passed here is ignored.
    _, mask = cv2.threshold(brain_image, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_TRIANGLE)
    mask = remove_small_holes(mask.astype(bool), area_threshold=area_threshold)
    mask = remove_small_objects(mask, min_size=min_size)
    return mask


In [4]:
### Functions for evaluation
import csv
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error
from scipy.spatial.distance import hamming

def dice_coefficient(binary_image1, binary_image2, epsilon=1e-10):
    """
    Dice similarity coefficient 2|A & B| / (|A| + |B|) of two binary images.

    ``epsilon`` is added to both numerator and denominator, so two empty
    masks score 1.0 instead of dividing by zero.

    Parameters:
    - binary_image1: First binary image (numpy array).
    - binary_image2: Second binary image (numpy array).
    - epsilon: Smoothing constant.

    Returns:
    - dice: Dice coefficient.
    """
    overlap = np.sum(binary_image1 * binary_image2)
    combined_size = np.sum(binary_image1) + np.sum(binary_image2)
    numerator = 2. * overlap + epsilon
    denominator = combined_size + epsilon
    return numerator / denominator


def iou(binary_image1, binary_image2, epsilon=1e-10):
    """
    Compute the Intersection over Union (IoU) between two binary images.

    The union is computed as |A| + |B| - |A & B|. This is correct for boolean
    masks, integer 0/1 masks and soft masks in [0, 1] alike, whereas a plain
    ``A + B`` only behaves like a union for boolean arrays.

    Parameters:
    - binary_image1: First binary image (numpy array).
    - binary_image2: Second binary image (numpy array).
    - epsilon: Small constant added to numerator and denominator.

    Returns:
    - iou: IoU; 1.0 when both images are empty.
    """
    image1 = np.asarray(binary_image1, dtype=np.float64)
    image2 = np.asarray(binary_image2, dtype=np.float64)
    intersection = np.sum(image1 * image2)
    union = np.sum(image1) + np.sum(image2) - intersection

    if union == 0:
        # Two empty masks agree perfectly.
        return 1.0

    return (intersection + epsilon) / (union + epsilon)


def precision(binary_image1, binary_image2, epsilon=1e-10):
    """
    Fraction of pixels marked in the first image that are also marked in the second.

    The first image plays the role of the prediction and the second the
    reference: TP = pred & ref, FP = pred & ~ref. ``epsilon`` only enters
    the denominator, so an empty prediction yields 0.0.

    Parameters:
    - binary_image1: First binary image (numpy array), the prediction.
    - binary_image2: Second binary image (numpy array), the reference.
    - epsilon: Small constant guarding against division by zero.

    Returns:
    - precision: Precision.
    """
    predicted, reference = binary_image1, binary_image2
    true_positives = np.sum(predicted * reference)
    false_positives = np.sum(predicted * (1 - reference))
    denominator = true_positives + false_positives + epsilon
    return true_positives / denominator


def recall(binary_image1, binary_image2, epsilon=1e-10):
    """
    Fraction of pixels marked in the second image that are also marked in the first.

    The first image plays the role of the prediction and the second the
    reference: TP = pred & ref, FN = ~pred & ref. ``epsilon`` only enters
    the denominator, so an empty reference yields 0.0.

    Parameters:
    - binary_image1: First binary image (numpy array), the prediction.
    - binary_image2: Second binary image (numpy array), the reference.
    - epsilon: Small constant guarding against division by zero.

    Returns:
    - recall: Recall.
    """
    predicted, reference = binary_image1, binary_image2
    true_positives = np.sum(predicted * reference)
    false_negatives = np.sum((1 - predicted) * reference)
    denominator = true_positives + false_negatives + epsilon
    return true_positives / denominator


def rand_index(binary_image1, binary_image2):
    """
    Share of pixels on which two binary images agree (both on or both off).

    Computed as (TP + TN) / (TP + FP + FN + TN) from the pixel-wise
    confusion counts. Empty inputs are not guarded against.

    Parameters:
    - binary_image1: First binary image (numpy array).
    - binary_image2: Second binary image (numpy array).

    Returns:
    - rand_index: Rand index.
    """
    first, second = binary_image1, binary_image2
    first_off, second_off = 1 - first, 1 - second

    true_positives = np.sum(first * second)
    false_positives = np.sum(first * second_off)
    false_negatives = np.sum(first_off * second)
    true_negatives = np.sum(first_off * second_off)

    agreements = true_positives + true_negatives
    total = true_positives + false_positives + false_negatives + true_negatives
    return agreements / total

## Load data

In [None]:
# IO parameters
filepath = "/media/data/u01/Fig3/M13/*/*.tif"
IDX = 4

# Load the three channels of section IDX and convert each one to float32
curr_ch1, curr_ch2, curr_ch3 = (
    channel.astype(np.float32) for channel in load_3_channels(filepath, IDX)
)

In [None]:
# Per-channel display settings
#   gamma_ch*:          exponent passed to gamma_correction (> 1 darkens mid-tones)
#   contrast_alpha_ch*: gain passed to auto_contrast; larger values give a brighter 8-bit image

# Ch1 settings
gamma_ch1 = 2
contrast_alpha_ch1 = 0.00525  # suggested alternative: 0.0225

# Ch2 settings
gamma_ch2 = 2
contrast_alpha_ch2 = 0.0125  # suggested alternative: 0.125

# Ch3 settings
gamma_ch3 = 2
contrast_alpha_ch3 = 0.0425  # suggested alternative: 0.125


# Raw channels are used as-is for this preview (no gamma correction)
contrast_ch1 = curr_ch1
contrast_ch2 = curr_ch2
contrast_ch3 = curr_ch3


cc_ch1 = auto_contrast(contrast_ch1, alpha=contrast_alpha_ch1)
cc_ch2 = auto_contrast(contrast_ch2, alpha=contrast_alpha_ch2)
cc_ch3 = auto_contrast(contrast_ch3, alpha=contrast_alpha_ch3)


# Brain (foreground) mask computed from a brightened copy of ch1
bg_alpha = 0.25  # gain for the brightened copy (previously noted value: 0.5)
bg_mask = auto_contrast(curr_ch1, alpha=bg_alpha)  # NOTE(review): original note "7" - possibly an earlier alpha; confirm
bg_mask = get_brain_mask(bg_mask, area_threshold=25000)  # NOTE(review): original note "255 default ch0, 150 for ch1" - unclear which parameter it refers to; confirm

show(cc_ch1, bg_mask, title=f"Section {IDX} ch1 contrast", axis=True)
show(cc_ch2, bg_mask, title=f"Section {IDX} ch2 contrast", axis=True)
show(cc_ch3, bg_mask, title=f"Section {IDX} ch3 contrast", axis=True)

In [39]:
# Fields of view (FOVs) to zoom into, as (xlim, ylim) pixel ranges
fovs = [
    # xlim,             ylim
    ( (4000, 5000),     (4000, 5000) ),
    ( (2000, 3000),     (2000, 4000) ),
    ( (2000, 4000),     (4000, 6000) ),
    ( (5000, 6000),     (2000, 4000) ),
    ( (5000, 6000),     (1000, 2000) ),
    ( (5000, 6000),     (4000, 5000) ),
]

# Each ylim is reversed (larger value first), presumably to match imshow's
# downward-pointing y axis so zoomed views stay upright.
(
    (fov1_xlim, fov1_ylim),
    (fov2_xlim, fov2_ylim),
    (fov3_xlim, fov3_ylim),
    (fov4_xlim, fov4_ylim),
    (fov5_xlim, fov5_ylim),
    (fov6_xlim, fov6_ylim),
) = [(fov_x, fov_y[::-1]) for fov_x, fov_y in fovs[:6]]

## Preprocess CH1, CH2 and CH3 (median filtering and auto contrast)

In [40]:
# 5x5 median-filtered copies of the raw channels (used for thresholding later)
curr_ch1_median, curr_ch2_median, curr_ch3_median = (
    ndimage.median_filter(raw, size=5) for raw in (curr_ch1, curr_ch2, curr_ch3)
)

# Brightened 8-bit versions of each channel ...
auto_ch1, auto_ch2, auto_ch3 = (
    auto_contrast(raw, alpha=gain)
    for raw, gain in (
        (curr_ch1, contrast_alpha_ch1),
        (curr_ch2, contrast_alpha_ch2),
        (curr_ch3, contrast_alpha_ch3),
    )
)
# ... and their 5x5 median-filtered counterparts
auto_ch1_median, auto_ch2_median, auto_ch3_median = (
    ndimage.median_filter(bright, size=5) for bright in (auto_ch1, auto_ch2, auto_ch3)
)

## Run thresholding for CH1, CH2 and CH3

In [50]:
# Ch1 settings
gamma_ch1 = 2  # Exponent passed to gamma_correction
contrast_alpha_ch1 = 0.00525  # Gain passed to auto_contrast (suggested alternative: 0.0225)

# Ch2 settings
gamma_ch2 = 2  # Exponent passed to gamma_correction
contrast_alpha_ch2 = 0.0125  # Gain passed to auto_contrast (suggested alternative: 0.125)

# Ch3 settings
gamma_ch3 = 2  # Exponent passed to gamma_correction
contrast_alpha_ch3 = 0.0425  # Gain passed to auto_contrast (suggested alternative: 0.125)


THRESH = 4000   # ch1 threshold; overridden per section below
THRESH2 = 3000  # ch2 threshold; 3000 for all sections except where overridden
THRESH3 = None  # ch3 threshold; must be set per section below
max_value = 20000   # clip level for the ch1 gamma-corrected "alt" image
max_value3 = 2000   # clip level for the ch3 gamma-corrected "alt" image

beta1 = 0.5  # Use beta=0.5 for all of ch1
beta2 = 1.0  # Use beta=1.0 for all of ch2
beta3 = None # Swap between 0.5 and 1.0; set per section below

# Index by index basis
if IDX == 0:
    THRESH = 7500

    THRESH3 = 400
    beta3 = 1.0
elif IDX == 1:
    THRESH = 5000

    THRESH3 = 750
    beta3 = 0.5
elif IDX == 2:
    THRESH = 5000  # 5000

    THRESH3 = 750
    beta3 = 0.5
elif IDX == 3:
    THRESH = 5500  # 5000

    THRESH3 = 750
    beta3 = 0.5
elif IDX == 4:
    THRESH = 7000

    THRESH2 = 4000 # Specific to this index

    THRESH3 = 1000
    beta3 = 0.5
else:
    # Without tuned settings THRESH3/beta3 stay None and the ch3 thresholding
    # (a comparison against THRESH3) would fail later with a less clear TypeError.
    raise ValueError(
        f"No per-section threshold settings for IDX={IDX}; add a branch above."
    )

In [None]:
# Create contrast enhanced images
# cc_ch*: gamma-corrected channel converted to 8 bits with the per-channel gain
cc_ch1 = gamma_correction(curr_ch1, gamma=gamma_ch1)
cc_ch1 = auto_contrast(cc_ch1, alpha=contrast_alpha_ch1)
# cc_ch*_alt: gamma-corrected median-filtered channel, clipped at max_value (kept as float)
cc_ch1_alt = gamma_correction(curr_ch1_median, gamma=gamma_ch1, max_value=max_value)

cc_ch2 = gamma_correction(curr_ch2, gamma=gamma_ch2)
cc_ch2 = auto_contrast(cc_ch2, alpha=contrast_alpha_ch2)

cc_ch3 = gamma_correction(curr_ch3, gamma=gamma_ch3)
cc_ch3 = auto_contrast(cc_ch3, alpha=contrast_alpha_ch3)
cc_ch3_alt = gamma_correction(curr_ch3_median, gamma=gamma_ch3, max_value=max_value3)


print(f"Threshold for ch1: {THRESH}")
print(f"Threshold for ch2: {THRESH2}")
print(f"Threshold for ch3: {THRESH3}")

# Binary masks: each comparison already yields a new boolean array.
# ch1 and ch3 are thresholded on their "alt" images, ch2 on its median-filtered raw channel.
curr_ch1_thresh = cc_ch1_alt > THRESH
curr_ch2_thresh = curr_ch2_median > THRESH2
curr_ch3_thresh = cc_ch3_alt > THRESH3

##########################################
# Per-FOV inspection
##########################################
# For every FOV show (1) the images the masks were computed from, with mask
# contours, (2) the brightened channels alone, and (3) the brightened
# channels with mask contours.
fov_limits = [
    (fov1_xlim, fov1_ylim),
    (fov2_xlim, fov2_ylim),
    (fov3_xlim, fov3_ylim),
    (fov4_xlim, fov4_ylim),
    (fov5_xlim, fov5_ylim),
    (fov6_xlim, fov6_ylim),
]

for fov_number, (fov_xlim, fov_ylim) in enumerate(fov_limits, start=1):
    fov_tag = f"FOV{fov_number}"

    # Panel 2 displays auto_ch2 (the brightened ch2 image), hence "brightened".
    show3(image=cc_ch1_alt, title=f"Section {IDX} ch1 contrast alt {fov_tag}",
          contour=curr_ch1_thresh,
          xlim=fov_xlim, ylim=fov_ylim,
          image2=auto_ch2, title2=f"Section {IDX} ch2 brightened {fov_tag}",
          contour2=curr_ch2_thresh,
          xlim2=fov_xlim, ylim2=fov_ylim,
          image3=cc_ch3_alt, title3=f"Section {IDX} ch3 contrast alt {fov_tag}",
          contour3=curr_ch3_thresh,
          xlim3=fov_xlim, ylim3=fov_ylim)

    show3(image=auto_ch1, title=f"Section {IDX} ch1 brightened {fov_tag}",
          xlim=fov_xlim, ylim=fov_ylim,
          image2=auto_ch2, title2=f"Section {IDX} ch2 brightened {fov_tag}",
          xlim2=fov_xlim, ylim2=fov_ylim,
          image3=auto_ch3, title3=f"Section {IDX} ch3 brightened {fov_tag}",
          xlim3=fov_xlim, ylim3=fov_ylim)

    show3(image=auto_ch1, title=f"Section {IDX} ch1 brightened {fov_tag}",
          contour=curr_ch1_thresh,
          xlim=fov_xlim, ylim=fov_ylim,
          image2=auto_ch2, title2=f"Section {IDX} ch2 brightened {fov_tag}",
          contour2=curr_ch2_thresh,
          xlim2=fov_xlim, ylim2=fov_ylim,
          image3=auto_ch3, title3=f"Section {IDX} ch3 brightened {fov_tag}",
          contour3=curr_ch3_thresh,
          xlim3=fov_xlim, ylim3=fov_ylim)

    print("##########################################")

##########################################################
# FULL SECTION
##########################################################

show3(image=cc_ch1, title=f"Full section {IDX} ch1 contrast",
      image2=cc_ch2, title2=f"Full section {IDX} ch2 contrast",
      image3=cc_ch3, title3=f"Full section {IDX} ch3 contrast",
      figsize=(20, 10), axis=True)

show3(image=auto_ch1, title=f"Full section {IDX} ch1 brightened",
      image2=auto_ch2, title2=f"Full section {IDX} ch2 brightened",
      image3=auto_ch3, title3=f"Full section {IDX} ch3 brightened",
      figsize=(20, 10), axis=True)

show3(image=curr_ch1_thresh, title=f"Full section {IDX} ch1 thresholded",
      image2=curr_ch2_thresh, title2=f"Full section {IDX} ch2 thresholded",
      image3=curr_ch3_thresh, title3=f"Full section {IDX} ch3 thresholded",
      figsize=(20, 10), axis=True)

#thresholded_vessels_ch1 = curr_ch1_thresh
#thresholded_vessels_ch2 = curr_ch2_thresh

## Hessian Filter
https://examples.itk.org/src/nonunit/review/segmentbloodvesselswithmultiscalehessianbasedmeasure/documentation

# CHANNEL 1

In [None]:
# Parameters for vessel detection
sigma_minimum = 1.0  # Smallest Hessian scale (sigma) searched by MultiScaleHessianBasedMeasureImageFilter
sigma_maximum = 10.0  # Largest Hessian scale (sigma) searched
number_of_sigma_steps = 10  # Number of scales between the two (log-spaced inside detect_vessels)

# Parameters for post-processing (passed to process_vessels)
thresh = 230  # Cutoff on the rescaled vesselness map; pixels <= thresh are kept (other values noted: 25, 15)
min_size = 100  # Connected components smaller than this (pixels) are removed
area_threshold = 2000 # Holes up to this area (pixels) are filled
smoothing = 1  # Radius of the disk footprint used for binary closing (process_vessels default: 3)

#############################################################

# Build the ITK input: brightened, median-filtered ch1 restricted to the brain mask
input_image = auto_ch1_median * bg_mask
# Alternative input: the gamma-corrected "alt" ch1 image
#input_image = cc_ch1_alt * bg_mask
input_image = input_image.astype(np.float32)
# NOTE(review): auto_ch1_median presumably holds 0-255 values (auto_contrast output),
# so this scales intensities up to ~65025. The objectness gamma term depends on
# absolute intensity, so confirm this scaling is intended.
input_image *= 255.0

show(input_image, title="CH1: Input image", axis=False)

# Print statistics
print("Input image type:", input_image.dtype)
print("Input image min:", input_image.min())
print("Input image max:", input_image.max())

# Run the vessel detection (alpha/gamma at detect_vessels defaults, beta from the per-channel setting)
segmented_vessels_array = detect_vessels(input_image, sigma_minimum, sigma_maximum, number_of_sigma_steps,
                                         beta=beta1)

# Binarize and clean the vesselness map, then keep only pixels that are inside
# the brain mask and above the ch1 intensity threshold
thresholded_vessels_ch1 = process_vessels(segmented_vessels_array, thresh=thresh, min_size=min_size, area_threshold=area_threshold, smoothing=smoothing)
thresholded_vessels_ch1 = thresholded_vessels_ch1 * bg_mask * curr_ch1_thresh

# Print statistics
print("Vesselness image statistics:")
print("Shape:", segmented_vessels_array.shape)
print("Min:", segmented_vessels_array.min())
print("Max:", segmented_vessels_array.max())
print("Mean:", segmented_vessels_array.mean())
print("Median:", np.median(segmented_vessels_array))
#print("Std:", segmented_vessels_array.std())

print("CHANNEL 1")

# Plot the raw vesselness image next to the final mask
show(image=segmented_vessels_array, title=f"CH1: Vesselness image",
     image2=thresholded_vessels_ch1, title2=f"CH1: Vessel mask",
     axis=False)

# Show the results: mask contours drawn over the input image
show(image=input_image, title="CH1: Input image",
     image2=input_image, title2="Vessel mask contours over input image",
     contour2=thresholded_vessels_ch1, contour_alpha=0.45,
     axis=False)

# Plot the raw vesselness image FOV1
show(image=segmented_vessels_array, title=f"CH1 FOV1: Vesselness image",
     image2=thresholded_vessels_ch1, title2=f"CH1 FOV1: Vessel mask",
     xlim=fov1_xlim, ylim=fov1_ylim, xlim2=fov1_xlim, ylim2=fov1_ylim,
     axis=False)


# FOV1
show(image=input_image, title="CH1 FOV1: Input image",
     image2=input_image, title2="CH1 FOV1: Vessel mask contours over input image",
     contour2=thresholded_vessels_ch1, contour_alpha=0.45,
     xlim=fov1_xlim, ylim=fov1_ylim, xlim2=fov1_xlim, ylim2=fov1_ylim,
     axis=False)

# FOV2
show(image=input_image, title="CH1 FOV2: Input image",
     image2=input_image, title2="CH1 FOV2: Vessel mask contours over input image",
     contour2=thresholded_vessels_ch1, contour_alpha=0.45,
     xlim=fov2_xlim, ylim=fov2_ylim, xlim2=fov2_xlim, ylim2=fov2_ylim,
     axis=False)

# FOV3
show(image=input_image, title="CH1 FOV3: Input image",
     image2=input_image, title2="CH1 FOV3: Vessel mask contours over input image",
     contour2=thresholded_vessels_ch1, contour_alpha=0.45,
     xlim=fov3_xlim, ylim=fov3_ylim, xlim2=fov3_xlim, ylim2=fov3_ylim,
     axis=False)

# FOV4
show(image=input_image, title="CH1 FOV4: Input image",
     image2=input_image, title2="CH1 FOV4: Vessel mask contours over input image",
     contour2=thresholded_vessels_ch1, contour_alpha=0.45,
     xlim=fov4_xlim, ylim=fov4_ylim, xlim2=fov4_xlim, ylim2=fov4_ylim,
     axis=False)

# FOV5
show(image=input_image, title="CH1 FOV5: Input image",
     image2=input_image, title2="CH1 FOV5: Vessel mask contours over input image",
     contour2=thresholded_vessels_ch1, contour_alpha=0.45,
     xlim=fov5_xlim, ylim=fov5_ylim, xlim2=fov5_xlim, ylim2=fov5_ylim,
     axis=False)

# FOV6
show(image=input_image, title="CH1 FOV6: Input image",
     image2=input_image, title2="CH1 FOV6: Vessel mask contours over input image",
     contour2=thresholded_vessels_ch1, contour_alpha=0.45,
     xlim=fov6_xlim, ylim=fov6_ylim, xlim2=fov6_xlim, ylim2=fov6_ylim,
     axis=False)

# Full section
show(image=thresholded_vessels_ch1, title="CH1: Vessel segmentation",
     image2=curr_ch2_thresh, title2="CH2: Thresholded vessels",
     figsize=(20, 10), axis=False)

# CHANNEL 2

In [None]:
# ---------------------------------------------------------------------------
# Channel 2: Hessian-based vessel segmentation
# ---------------------------------------------------------------------------

# Scale range searched by MultiScaleHessianBasedMeasureImageFilter.
sigma_minimum = 1.0
sigma_maximum = 10.0
number_of_sigma_steps = 10

# Post-processing parameters forwarded to process_vessels
# (smoothing 3 was also tried).
thresh = 230           # binarization threshold on the vesselness image
min_size = 100         # smallest object size kept
area_threshold = 2000  # hole-area threshold
smoothing = 1          # smoothing factor for the closing step

# Detector input: median-filtered auto-contrast ch2 inside the brain mask,
# scaled by 255.
input_image = (auto_ch2_median * bg_mask).astype(np.float32)
input_image *= 255.0

show(input_image, title="CH2: Input image", axis=False)

for stat_label, stat_value in (
    ("Input image type:", input_image.dtype),
    ("Input image min:", input_image.min()),
    ("Input image max:", input_image.max()),
):
    print(stat_label, stat_value)

# Vesselness map -> cleaned binary mask, restricted to bg_mask and curr_ch2_thresh.
segmented_vessels_array = detect_vessels(
    input_image, sigma_minimum, sigma_maximum, number_of_sigma_steps, beta=beta2
)
thresholded_vessels_ch2 = (
    process_vessels(
        segmented_vessels_array,
        thresh=thresh,
        min_size=min_size,
        area_threshold=area_threshold,
        smoothing=smoothing,
    )
    * bg_mask
    * curr_ch2_thresh
)

print("Vesselness image statistics:")
for stat_label, stat_value in (
    ("Shape:", segmented_vessels_array.shape),
    ("Min:", segmented_vessels_array.min()),
    ("Max:", segmented_vessels_array.max()),
    ("Mean:", segmented_vessels_array.mean()),
    ("Median:", np.median(segmented_vessels_array)),
):
    print(stat_label, stat_value)

print("CHANNEL 2")

# Whole image: vesselness next to the mask, then mask contours over the input.
show(image=segmented_vessels_array, title="CH2: Vesselness image",
     image2=thresholded_vessels_ch2, title2="CH2: Vessel mask",
     axis=False)
show(image=input_image, title="CH2: Input image",
     image2=input_image, title2="Vessel mask contours over input image",
     contour2=thresholded_vessels_ch2, contour_alpha=0.45,
     axis=False)

# Vesselness vs. mask, zoomed into FOV1.
show(image=segmented_vessels_array, title="CH2 FOV1: Vesselness image",
     image2=thresholded_vessels_ch2, title2="CH2 FOV1: Vessel mask",
     xlim=fov1_xlim, ylim=fov1_ylim, xlim2=fov1_xlim, ylim2=fov1_ylim,
     axis=False)

# Mask contours over the input for each field of view, in order FOV1..FOV6.
fov_limits = [
    (fov1_xlim, fov1_ylim),
    (fov2_xlim, fov2_ylim),
    (fov3_xlim, fov3_ylim),
    (fov4_xlim, fov4_ylim),
    (fov5_xlim, fov5_ylim),
    (fov6_xlim, fov6_ylim),
]
for fov_number, (fov_xlim, fov_ylim) in enumerate(fov_limits, start=1):
    show(image=input_image, title=f"CH2 FOV{fov_number}: Input image",
         image2=input_image,
         title2=f"CH2 FOV{fov_number}: Vessel mask contours over input image",
         contour2=thresholded_vessels_ch2, contour_alpha=0.45,
         xlim=fov_xlim, ylim=fov_ylim, xlim2=fov_xlim, ylim2=fov_ylim,
         axis=False)

# Full section: ch1 and ch2 Hessian masks side by side.
show(image=thresholded_vessels_ch1, title="CH1: Vessel segmentation",
     image2=thresholded_vessels_ch2, title2="CH2: Vessel segmentation",
     figsize=(10, 10), axis=False)

# CHANNEL 3

In [None]:
# ---------------------------------------------------------------------------
# Channel 3: Hessian-based vessel segmentation
# ---------------------------------------------------------------------------

# Scale range searched by MultiScaleHessianBasedMeasureImageFilter.
sigma_minimum = 1.0
sigma_maximum = 10.0
number_of_sigma_steps = 10

# Post-processing parameters forwarded to process_vessels
# (binarization thresholds 25 and 15 and smoothing 3 were also tried).
thresh = 230           # binarization threshold on the vesselness image
min_size = 100         # smallest object size kept
area_threshold = 2000  # hole-area threshold
smoothing = 1          # smoothing factor for the closing step

# Detector input: median-filtered auto-contrast ch3 inside the brain mask,
# scaled by 255.
input_image = (auto_ch3_median * bg_mask).astype(np.float32)
input_image *= 255.0

show(input_image, title="CH3: Input image", axis=False)

for stat_label, stat_value in (
    ("Input image type:", input_image.dtype),
    ("Input image min:", input_image.min()),
    ("Input image max:", input_image.max()),
):
    print(stat_label, stat_value)

# Vesselness map -> cleaned binary mask, restricted to bg_mask and curr_ch3_thresh.
segmented_vessels_array = detect_vessels(
    input_image, sigma_minimum, sigma_maximum, number_of_sigma_steps, beta=beta3
)
thresholded_vessels_ch3 = (
    process_vessels(
        segmented_vessels_array,
        thresh=thresh,
        min_size=min_size,
        area_threshold=area_threshold,
        smoothing=smoothing,
    )
    * bg_mask
    * curr_ch3_thresh
)

print("Vesselness image statistics:")
for stat_label, stat_value in (
    ("Shape:", segmented_vessels_array.shape),
    ("Min:", segmented_vessels_array.min()),
    ("Max:", segmented_vessels_array.max()),
    ("Mean:", segmented_vessels_array.mean()),
    ("Median:", np.median(segmented_vessels_array)),
):
    print(stat_label, stat_value)

print("CHANNEL 3")

# Whole image: vesselness next to the mask, then mask contours over the input.
show(image=segmented_vessels_array, title="CH3: Vesselness image",
     image2=thresholded_vessels_ch3, title2="CH3: Vessel mask",
     axis=False)
show(image=input_image, title="CH3: Input image",
     image2=input_image, title2="Vessel mask contours over input image",
     contour2=thresholded_vessels_ch3, contour_alpha=0.45,
     axis=False)

# Vesselness vs. mask, zoomed into FOV1.
show(image=segmented_vessels_array, title="CH3 FOV1: Vesselness image",
     image2=thresholded_vessels_ch3, title2="CH3 FOV1: Vessel mask",
     xlim=fov1_xlim, ylim=fov1_ylim, xlim2=fov1_xlim, ylim2=fov1_ylim,
     axis=False)

# Mask contours over the input for each field of view, in order FOV1..FOV6.
fov_limits = [
    (fov1_xlim, fov1_ylim),
    (fov2_xlim, fov2_ylim),
    (fov3_xlim, fov3_ylim),
    (fov4_xlim, fov4_ylim),
    (fov5_xlim, fov5_ylim),
    (fov6_xlim, fov6_ylim),
]
for fov_number, (fov_xlim, fov_ylim) in enumerate(fov_limits, start=1):
    show(image=input_image, title=f"CH3 FOV{fov_number}: Input image",
         image2=input_image,
         title2=f"CH3 FOV{fov_number}: Vessel mask contours over input image",
         contour2=thresholded_vessels_ch3, contour_alpha=0.45,
         xlim=fov_xlim, ylim=fov_ylim, xlim2=fov_xlim, ylim2=fov_ylim,
         axis=False)

# Full section: ch3 and ch2 Hessian masks side by side.
show(image=thresholded_vessels_ch3, title="CH3: Vessel segmentation",
     image2=thresholded_vessels_ch2, title2="CH2: Vessel segmentation",
     figsize=(20, 10), axis=False)

# Compute cross-channel statistics

In [12]:
# Exclude one hard-coded rectangle (rows 2000:5000, cols 1000:3000) from all
# three vessel masks before computing statistics: True = keep, False = drop.
# NOTE(review): the region is fixed for every slice — confirm it is intended.
remove_mask = np.ones(curr_ch1.shape, dtype=bool)
remove_mask[2000:5000, 1000:3000] = False

thresholded_vessels_ch1, thresholded_vessels_ch2, thresholded_vessels_ch3 = [
    vessels * remove_mask
    for vessels in (thresholded_vessels_ch1, thresholded_vessels_ch2, thresholded_vessels_ch3)
]

In [None]:
# Zero-based aliases for the three vessel masks (ch0 = CH1, ch1 = CH2, ch2 = CH3)
# plus flattened copies, which the Hamming distance is computed on.
thresholded_ch0 = thresholded_vessels_ch1
thresholded_ch1 = thresholded_vessels_ch2
thresholded_ch2 = thresholded_vessels_ch3

thresh_ch0_flat, thresh_ch1_flat, thresh_ch2_flat = (
    mask.flatten() for mask in (thresholded_ch0, thresholded_ch1, thresholded_ch2)
)

# Per displayed channel number: (mask, flattened mask, Hessian beta, intensity threshold).
channel_inputs = {
    1: (thresholded_ch0, thresh_ch0_flat, beta1, THRESH),
    2: (thresholded_ch1, thresh_ch1_flat, beta2, THRESH2),
    3: (thresholded_ch2, thresh_ch2_flat, beta3, THRESH3),
}

# Compare every channel pair. After the loop, dice_score etc. hold the CH2/CH3 values.
for ch_a, ch_b in ((1, 2), (1, 3), (2, 3)):
    mask_a, flat_a, beta_a, thresh_a = channel_inputs[ch_a]
    mask_b, flat_b, beta_b, thresh_b = channel_inputs[ch_b]

    dice_score = dice_coefficient(mask_a, mask_b)
    iou_score = iou(mask_a, mask_b)  # penalizes both over- and under-segmentation
    precision_score = precision(mask_a, mask_b)
    recall_score = recall(mask_a, mask_b)
    ssim_score = ssim(mask_a, mask_b)
    mse_score = mean_squared_error(mask_a, mask_b)
    hamming_distance = hamming(flat_a, flat_b)
    rand_score = rand_index(mask_a, mask_b)

    print(f"\nMetrics for CH{ch_a} and CH{ch_b}")
    print(f"Beta{ch_a}:", beta_a, f"Beta{ch_b}:", beta_b)
    print(f"Thresh{ch_a}:", thresh_a, f"Thresh{ch_b}:", thresh_b)
    for metric_label, metric_value in (
        ("Dice coefficient:", dice_score),
        ("IoU score:", iou_score),
        ("Precision score:", precision_score),
        ("Recall score:", recall_score),
        ("SSIM score:", ssim_score),
        ("MSE score:", mse_score),
        ("Hamming distance:", hamming_distance),
        ("Rand index:", rand_score),
    ):
        print(metric_label, metric_value)

# Disabled CSV export, kept as the cell's final (displayed) expression.
"""
# Write the solutions to a CSV file
csv_filename = 'stats_enhanced_M13.csv'

# Write to rows
rows = [["Index", "Dice coefficient", "IoU score", "Precision", "Recall", "SSIM", "MSE", "Hamming distance", "Rand index"]]
rows.append([IDX, dice_score, iou_score, precision_score, recall_score, ssim_score, mse_score, hamming_distance, rand_score])

#with open(csv_filename, mode='w', newline='') as file:
#    writer = csv.writer(file)
#    writer.writerows(rows)
"""

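## Run the whole pipeline per slice

`run_test` repeats loading, masking, Hessian vessel segmentation and cross-channel metric reporting for a single slice.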

In [6]:

# Per-slice intensity thresholds used by run_test: idx -> (ch1, ch2, ch3).
_RUN_TEST_THRESHOLDS = {
    0: (7500, 3000, 600),
    1: (6000, 3000, 850),
    2: (6000, 3000, 600),
    3: (6000, 3000, 600),
    4: (7500, 4000, 600),
}
_RUN_TEST_MAX_VALUE_CH1 = 20000  # gamma_correction max_value for the ch1 intensity mask
_RUN_TEST_MAX_VALUE_CH3 = 2000   # gamma_correction max_value for the ch3 intensity mask
_RUN_TEST_GAMMA = 2
_RUN_TEST_CONTRAST_ALPHA = (0.00525, 0.0125, 0.0425)  # auto_contrast alpha for ch1, ch2, ch3
_RUN_TEST_MEDIAN_SIZE = 5


def _segment_channel_vessels(auto_median, bg_mask):
    """Return the post-processed Hessian vessel mask of one channel, restricted to bg_mask.

    The detector input is ``auto_median * bg_mask`` as float32 scaled by 255. The
    detection and post-processing parameters are the same for all channels.
    """
    input_image = (auto_median * bg_mask).astype(np.float32)
    input_image *= 255.0
    vesselness = detect_vessels(input_image, 1.0, 10.0, 10)  # sigma min, sigma max, steps
    vessels = process_vessels(vesselness, thresh=230, min_size=100,
                              area_threshold=2000, smoothing=1)
    return vessels * bg_mask


def _intensity_threshold_mask(channel, max_value, threshold):
    """Boolean mask of pixels whose median-filtered, gamma-corrected value exceeds threshold."""
    smoothed = ndimage.median_filter(channel, size=_RUN_TEST_MEDIAN_SIZE)
    corrected = gamma_correction(smoothed, gamma=_RUN_TEST_GAMMA, max_value=max_value)
    return corrected > threshold


def _mask_pair_metrics(mask_a, mask_b):
    """Return (label, value) overlap metrics between two masks, in print order."""
    return [
        ("Dice coefficient:", dice_coefficient(mask_a, mask_b)),
        # IoU strongly penalizes both over- and under-segmentation.
        ("IoU score:", iou(mask_a, mask_b)),
        ("Precision score:", precision(mask_a, mask_b)),
        ("Recall score:", recall(mask_a, mask_b)),
        ("SSIM score:", ssim(mask_a, mask_b)),
        ("MSE score:", mean_squared_error(mask_a, mask_b)),
        ("Hamming distance:", hamming(mask_a.flatten(), mask_b.flatten())),
        ("Rand index:", rand_index(mask_a, mask_b)),
    ]


def _print_mask_pair_metrics(header, thresh_fields, metrics):
    """Print a metrics section: header line, threshold line, then one line per metric."""
    print(header)
    print(*thresh_fields)
    for label, value in metrics:
        print(label, value)


def run_test(idx, t1=None, t3=None, *, mask_ch1=False, mask_ch3=False):
    """Segment vessels in all three channels of one slice and print pairwise metrics.

    Images are read with ``load_3_channels(filepath, idx)``, where ``filepath`` is
    the module-level glob assigned by the calling cell.

    Args:
        idx: slice index; must be a key of ``_RUN_TEST_THRESHOLDS`` (0-4).
        t1: optional override of the ch1 intensity threshold.
        t3: optional override of the ch3 intensity threshold.
        mask_ch1: if True, the ch1 vessel mask is additionally multiplied by the ch1
            intensity mask (median filter + gamma correction, > threshold). Default
            False keeps the previous behaviour, where t1 only changes the printout.
        mask_ch3: same as ``mask_ch1`` for channel 3 and t3.

    Raises:
        ValueError: if ``idx`` has no configured thresholds. This is checked before
            any image is loaded.
    """
    print()
    print()

    try:
        thresh1, thresh2, thresh3 = _RUN_TEST_THRESHOLDS[idx]
    except KeyError:
        raise ValueError(f"No thresholds configured for slice index {idx!r}") from None
    if t3 is not None:
        thresh3 = t3
    if t1 is not None:
        thresh1 = t1

    ch1, ch2, ch3 = (channel.astype(np.float32) for channel in load_3_channels(filepath, idx))

    # Brain/background mask derived from a strongly brightened ch1.
    bg_mask = get_brain_mask(auto_contrast(ch1, alpha=0.25), area_threshold=25000)

    # Median-filtered auto-contrast images feed the Hessian detector.
    alpha1, alpha2, alpha3 = _RUN_TEST_CONTRAST_ALPHA
    auto_ch1_median = ndimage.median_filter(auto_contrast(ch1, alpha=alpha1), size=_RUN_TEST_MEDIAN_SIZE)
    auto_ch2_median = ndimage.median_filter(auto_contrast(ch2, alpha=alpha2), size=_RUN_TEST_MEDIAN_SIZE)
    auto_ch3_median = ndimage.median_filter(auto_contrast(ch3, alpha=alpha3), size=_RUN_TEST_MEDIAN_SIZE)

    # ch2 is always gated by a plain threshold on its median-filtered raw intensity.
    ch2_thresh = ndimage.median_filter(ch2, size=_RUN_TEST_MEDIAN_SIZE) > thresh2

    print(f"\nThreshold for ch1: {thresh1}")
    print(f"Threshold for ch2: {thresh2}")
    print(f"Threshold for ch3: {thresh3}")

    vessels_ch1 = _segment_channel_vessels(auto_ch1_median, bg_mask)
    if mask_ch1:
        vessels_ch1 = vessels_ch1 * _intensity_threshold_mask(ch1, _RUN_TEST_MAX_VALUE_CH1, thresh1)

    vessels_ch2 = _segment_channel_vessels(auto_ch2_median, bg_mask) * ch2_thresh

    vessels_ch3 = _segment_channel_vessels(auto_ch3_median, bg_mask)
    if mask_ch3:
        vessels_ch3 = vessels_ch3 * _intensity_threshold_mask(ch3, _RUN_TEST_MAX_VALUE_CH3, thresh3)

    metrics_ch1_ch2 = _mask_pair_metrics(vessels_ch1, vessels_ch2)

    print("\nSlice index:", idx)
    if t3 is not None:
        print("Thresh3:", thresh3)
    if t1 is not None:
        print("Thresh1:", thresh1)
    print()

    _print_mask_pair_metrics("\nMetrics for CH1 and CH2",
                             ("Thresh1:", thresh1, "Thresh2:", thresh2),
                             metrics_ch1_ch2)
    _print_mask_pair_metrics("\nMetrics for CH1 and CH3",
                             ("Thresh1:", thresh1, "Thresh3:", thresh3),
                             _mask_pair_metrics(vessels_ch1, vessels_ch3))
    _print_mask_pair_metrics("\nMetrics for CH2 and CH3",
                             ("Thresh2:", thresh2, "Thresh3:", thresh3),
                             _mask_pair_metrics(vessels_ch2, vessels_ch3))
    print()
    print()
    print()


# RUN TESTS

In [None]:
# Rerun every slice with the ch1 threshold forced to 0, then with the ch3
# threshold forced to 0.
# NOTE(review): the original header mentioned "beta = 0.5", but run_test takes
# no beta argument — confirm which beta this run used.
filepath = "/media/data/u01/Fig3/M13/*/*.tif"  # glob read by run_test
#IDX = 1

for channel_label, threshold_override in (("Channel 1", {"t1": 0}), ("Channel 3", {"t3": 0})):
    print(channel_label)
    for slice_idx in range(0, 5):
        print()
        print("Doing slice:", slice_idx)
        print()
        run_test(slice_idx, **threshold_override)

In [None]:
# Same experiment as the previous cell: every slice with t1=0, then with t3=0.
filepath = "/media/data/u01/Fig3/M13/*/*.tif"  # glob read by run_test
#IDX = 1

zero_threshold_runs = {"Channel 1": "t1", "Channel 3": "t3"}
for run_label, threshold_kwarg in zero_threshold_runs.items():
    print(run_label)
    for slice_idx in range(5):
        print(f"\nDoing slice: {slice_idx}\n")
        run_test(slice_idx, **{threshold_kwarg: 0})

In [None]:
# Sweep the ch1 intensity threshold for every slice.
# NOTE(review): check that run_test actually applies the ch1 intensity mask;
# if it does not, t1 only changes the printed threshold and every run is identical.
filepath = "/media/data/u01/Fig3/M13/*/*.tif"  # glob read by run_test
#IDX = 1
THRESH1LIST = [4000, 5000, 6000, 6500, 7000, 7500, 8000, 8500, 9000]

for slice_idx in range(5):
    print(f"\nDoing slice: {slice_idx}\n")
    for ch1_threshold in THRESH1LIST:
        run_test(slice_idx, t1=ch1_threshold)

In [None]:
# Sweep the ch3 intensity threshold for every slice.
# NOTE(review): check that run_test actually applies the ch3 intensity mask;
# if it does not, t3 only changes the printed threshold and every run is identical.
filepath = "/media/data/u01/Fig3/M13/*/*.tif"  # glob read by run_test
#IDX = 1
THRESH3LIST = [200, 300, 400, 500, 600, 750, 850, 1000]

for slice_idx in range(5):
    print(f"\nDoing slice: {slice_idx}\n")
    for ch3_threshold in THRESH3LIST:
        run_test(slice_idx, t3=ch3_threshold)

## Threshold method (batch)

In [None]:
import csv
import os
import warnings

from tqdm import tqdm

data_path = "/media/data/u01/Fig3/M13/*/*.tif"
output_ch1_path = "/media/data/u01/lightsheet/quant-fig3/M13 run 2/segmentation/ch1/"
output_ch2_path = "/media/data/u01/lightsheet/quant-fig3/M13 run 2/segmentation/ch2/"
output_ch3_path = "/media/data/u01/lightsheet/quant-fig3/M13 run 2/segmentation/ch3/"
output_csv_ch1_ch2_path = "/media/data/u01/lightsheet/quant-fig3/M13 run 2/stats_enhanced2_ch1_ch2.csv"
output_csv_ch1_ch3_path = "/media/data/u01/lightsheet/quant-fig3/M13 run 2/stats_enhanced2_ch1_ch3.csv"
output_csv_ch2_ch3_path = "/media/data/u01/lightsheet/quant-fig3/M13 run 2/stats_enhanced2_ch2_ch3.csv"


################################################################################

# Parameters for vessel detection: range of scales in which
# MultiScaleHessianBasedMeasureImageFilter will search for vessels
sigma_minimum = 1.0
sigma_maximum = 10.0
number_of_sigma_steps = 10  # Number of scales to search for vessels

# Parameters for post-processing (passed to process_vessels)
thresh1 = 230  # Threshold for binarization (ch1; also reused for ch3)
thresh2 = 230  # Threshold for binarization (ch2)
min_size1 = 100  # Minimum size of objects to keep (ch1; also reused for ch3)
min_size2 = 100  # Minimum size of objects to keep (ch2)
area_threshold = 2000  # Minimum area of holes to keep
smoothing = 1  # Smoothing factor for closing (3 was also tried)

# Enhancement parameters. They are identical for every slice, so set them once.
gamma_ch1 = 2  # Gamma for ch1 contrast enhancement
contrast_alpha_ch1 = 0.00525  # auto_contrast alpha for ch1 (0.0225 also tried)
gamma_ch2 = 2  # Not used below: ch2 is thresholded without gamma correction
contrast_alpha_ch2 = 0.0125  # auto_contrast alpha for ch2
gamma_ch3 = 2  # Gamma for ch3 contrast enhancement
contrast_alpha_ch3 = 0.0425  # auto_contrast alpha for ch3
max_value = 20000  # gamma_correction max_value for ch1
max_value3 = 2000  # gamma_correction max_value for ch3
beta1 = 0.5  # detect_vessels beta for all ch1 slices
beta2 = 1.0  # detect_vessels beta for all ch2 slices
bg_alpha = 0.25  # auto_contrast alpha used only to build the brain mask

# Default intensity thresholds; SLICE_PARAMS overrides them per slice.
DEFAULT_THRESH = 4000  # ch1, applied to the gamma-corrected median image
DEFAULT_THRESH2 = 3000  # ch2, applied to the median image

# Hand-tuned per-slice settings. ch3 has no default threshold/beta, so every
# slice that is processed must have an entry here.
SLICE_PARAMS = {
    0: {"THRESH": 7500, "THRESH3": 400, "beta3": 1.0},
    1: {"THRESH": 5000, "THRESH3": 750, "beta3": 0.5},
    2: {"THRESH": 5000, "THRESH3": 750, "beta3": 0.5},
    3: {"THRESH": 5500, "THRESH3": 750, "beta3": 0.5},
    4: {"THRESH": 7000, "THRESH2": 4000, "THRESH3": 1000, "beta3": 0.5},
}

METRIC_HEADER = ["Index", "Dice coefficient", "IoU score", "Precision", "Recall", "SSIM", "MSE", "Hamming distance", "Rand index"]


def _pair_metrics(seg_a, seg_b):
    """Compare two segmentation masks.

    Returns the metric values in METRIC_HEADER column order (without the
    leading index): Dice, IoU, precision, recall, SSIM, MSE, Hamming
    distance (computed on the flattened masks) and Rand index.
    """
    return [
        dice_coefficient(seg_a, seg_b),
        iou(seg_a, seg_b),  # Strongly penalizes over-segmentation and under-segmentation
        precision(seg_a, seg_b),
        recall(seg_a, seg_b),
        ssim(seg_a, seg_b),
        mean_squared_error(seg_a, seg_b),
        hamming(seg_a.flatten(), seg_b.flatten()),
        rand_index(seg_a, seg_b),  # Measures how close points are clustered together
    ]


def _scaled_masked_input(img, mask):
    """Apply mask to img, cast to float32 and multiply by 255 for detect_vessels."""
    # NOTE(review): presumably auto_contrast output lies in [0, 1], so this maps to [0, 255]; confirm.
    out = (img * mask).astype(np.float32)
    out *= 255.0
    return out


data_files = sorted(glob.glob(data_path))
if not data_files:
    # Fail loudly instead of overwriting the stats CSVs with header-only files.
    raise FileNotFoundError(f"No files matched {data_path!r}")
if len(data_files) % 3:
    warnings.warn(
        f"{len(data_files)} files is not a multiple of 3 channels; "
        "trailing file(s) are ignored and channel grouping may be misaligned"
    )
num_slices = len(data_files) // 3

# Validate before any processing: a slice without SLICE_PARAMS has no ch3
# threshold. Comparing against None would abort the run halfway, after
# writing some masks but before writing any CSV.
missing_slices = [idx for idx in range(num_slices) if idx not in SLICE_PARAMS]
if missing_slices:
    raise ValueError(f"No SLICE_PARAMS entry for slice(s) {missing_slices}; add THRESH3/beta3 for them")

# Create the mask output directories (the CSVs go in their common parent).
for out_dir in (output_ch1_path, output_ch2_path, output_ch3_path):
    os.makedirs(out_dir, exist_ok=True)

rows_ch1_ch2 = [list(METRIC_HEADER)]
rows_ch1_ch3 = [list(METRIC_HEADER)]
rows_ch2_ch3 = [list(METRIC_HEADER)]

for i in tqdm(range(num_slices)):
    # Load the image channels
    curr_ch1, curr_ch2, curr_ch3 = load_3_channels(data_path, i)
    curr_ch1 = curr_ch1.astype(np.float32)
    curr_ch2 = curr_ch2.astype(np.float32)
    curr_ch3 = curr_ch3.astype(np.float32)

    # Per-slice thresholds (defaults for ch1/ch2, required entries for ch3)
    slice_params = SLICE_PARAMS[i]
    THRESH = slice_params.get("THRESH", DEFAULT_THRESH)
    THRESH2 = slice_params.get("THRESH2", DEFAULT_THRESH2)
    THRESH3 = slice_params["THRESH3"]
    beta3 = slice_params["beta3"]

    # Intensity masks from median-filtered (size 5) channels. ch1 and ch3 are
    # gamma-corrected first; ch2 is thresholded directly.
    curr_ch1_median = ndimage.median_filter(curr_ch1, size=5)
    curr_ch2_median = ndimage.median_filter(curr_ch2, size=5)
    curr_ch3_median = ndimage.median_filter(curr_ch3, size=5)
    cc_ch1_alt = gamma_correction(curr_ch1_median, gamma=gamma_ch1, max_value=max_value)
    cc_ch3_alt = gamma_correction(curr_ch3_median, gamma=gamma_ch3, max_value=max_value3)
    curr_ch1_thresh = cc_ch1_alt > THRESH  # Comparison already yields a boolean mask
    curr_ch2_thresh = curr_ch2_median > THRESH2
    curr_ch3_thresh = cc_ch3_alt > THRESH3

    # Contrast-enhanced, median-filtered channels feed the vessel detector.
    auto_ch1 = auto_contrast(curr_ch1, alpha=contrast_alpha_ch1)
    auto_ch2 = auto_contrast(curr_ch2, alpha=contrast_alpha_ch2)
    auto_ch3 = auto_contrast(curr_ch3, alpha=contrast_alpha_ch3)
    auto_ch1_median = ndimage.median_filter(auto_ch1, size=5)
    auto_ch2_median = ndimage.median_filter(auto_ch2, size=5)
    auto_ch3_median = ndimage.median_filter(auto_ch3, size=5)

    # Brain/background mask is derived from ch1 only and applied to all channels.
    bg_mask = get_brain_mask(auto_contrast(curr_ch1, alpha=bg_alpha), area_threshold=25000)

    input_ch1 = _scaled_masked_input(auto_ch1_median, bg_mask)
    input_ch2 = _scaled_masked_input(auto_ch2_median, bg_mask)
    input_ch3 = _scaled_masked_input(auto_ch3_median, bg_mask)

    # Run the vessel detection
    segmented_vessels_ch1 = detect_vessels(input_ch1, sigma_minimum, sigma_maximum, number_of_sigma_steps, beta=beta1)
    segmented_vessels_ch2 = detect_vessels(input_ch2, sigma_minimum, sigma_maximum, number_of_sigma_steps, beta=beta2)
    segmented_vessels_ch3 = detect_vessels(input_ch3, sigma_minimum, sigma_maximum, number_of_sigma_steps, beta=beta3)

    # Process the thresholded vessels
    thresholded_vessels_ch1 = process_vessels(segmented_vessels_ch1, thresh=thresh1, min_size=min_size1, area_threshold=area_threshold, smoothing=smoothing)
    thresholded_vessels_ch2 = process_vessels(segmented_vessels_ch2, thresh=thresh2, min_size=min_size2, area_threshold=area_threshold, smoothing=smoothing)
    # NOTE(review): ch3 reuses the ch1 post-processing parameters (thresh1, min_size1); confirm this is intended.
    thresholded_vessels_ch3 = process_vessels(segmented_vessels_ch3, thresh=thresh1, min_size=min_size1, area_threshold=area_threshold, smoothing=smoothing)

    # Keep only vessel pixels inside the brain mask and above each channel's intensity threshold.
    thresholded_vessels_ch1 = thresholded_vessels_ch1 * bg_mask * curr_ch1_thresh
    thresholded_vessels_ch2 = thresholded_vessels_ch2 * bg_mask * curr_ch2_thresh
    thresholded_vessels_ch3 = thresholded_vessels_ch3 * bg_mask * curr_ch3_thresh

    # Save each mask as a uint8 TIFF, e.g. .../ch1/ch1_seg_0000.tif
    for seg, out_dir, ch_name in (
        (thresholded_vessels_ch1, output_ch1_path, "ch1"),
        (thresholded_vessels_ch2, output_ch2_path, "ch2"),
        (thresholded_vessels_ch3, output_ch3_path, "ch3"),
    ):
        sitk.WriteImage(
            sitk.GetImageFromArray(seg.astype(np.uint8)),
            os.path.join(out_dir, f"{ch_name}_seg_{i:04d}.tif"),
        )

    # Pairwise agreement statistics between the channel segmentations
    for rows, seg_a, seg_b, pair_name in (
        (rows_ch1_ch2, thresholded_vessels_ch1, thresholded_vessels_ch2, "ch1-ch2"),
        (rows_ch1_ch3, thresholded_vessels_ch1, thresholded_vessels_ch3, "ch1-ch3"),
        (rows_ch2_ch3, thresholded_vessels_ch2, thresholded_vessels_ch3, "ch2-ch3"),
    ):
        rows.append([i] + _pair_metrics(seg_a, seg_b))
        print(f"{pair_name}:", rows[-1])

# newline="" prevents csv from emitting blank lines between rows on Windows.
for csv_path, rows in (
    (output_csv_ch1_ch2_path, rows_ch1_ch2),
    (output_csv_ch2_ch3_path, rows_ch2_ch3),
    (output_csv_ch1_ch3_path, rows_ch1_ch3),
):
    with open(csv_path, mode="w", newline="") as fh:
        csv.writer(fh).writerows(rows)