In [None]:
import numpy as np
import matplotlib.pyplot as plt
from  matplotlib.image import imread
from pathlib import Path
from skimage.io import imread
from PIL import Image
from typing import Union, Tuple

#------ALL FUNCTIONS NEEDED------


# Wiener filter
def _window_sums(padded, window_size, rows, cols):
    """Return the sum of every window_size x window_size window of ``padded``.

    Uses a summed-area table so each window sum costs four lookups instead of
    a pass over the whole window. Entry ``[i, j]`` of the result is the sum of
    ``padded[i:i + window_size, j:j + window_size]``.
    """
    # table[a, b] holds the sum of padded[:a, :b]; the extra zero row/column
    # makes the four-corner formula work without special-casing the borders.
    table = np.zeros((padded.shape[0] + 1, padded.shape[1] + 1), dtype=np.float64)
    table[1:, 1:] = padded.cumsum(axis=0).cumsum(axis=1)
    w = window_size
    return (table[w:w + rows, w:w + cols]
            - table[:rows, w:w + cols]
            - table[w:w + rows, :cols]
            + table[:rows, :cols])


def local_wiener_filter(image, window_size=5, noise_variance=None):
    """
    Apply a local adaptive Wiener filter to a grayscale image.

    For every pixel the mean and variance of the surrounding
    ``window_size`` x ``window_size`` neighbourhood (reflect-padded at the
    borders) are computed. Where the local variance exceeds the noise
    variance, the pixel is pulled towards the local mean by the factor
    ``(local_var - noise_variance) / local_var``; everywhere else it is
    replaced by the local mean. The result is min-max normalised to [0, 1].

    Parameters:
    - image: 2D array of the grayscale image (any numeric dtype; it is
      converted to float64 internally so integer inputs are not truncated).
    - window_size: size of the square window (positive, normally odd integer).
    - noise_variance: estimated variance of the noise; if None, the global
      variance of the image is used.

    Returns:
    - filtered: the Wiener-filtered image as a 2D float64 array scaled to
      [0, 1]. If the filtered image is constant, an all-zero array is returned.

    Raises:
    - ValueError: if ``image`` is not 2D or ``window_size`` is smaller than 1.
    """
    img = np.asarray(image, dtype=np.float64)
    if img.ndim != 2:
        raise ValueError("image must be a 2D array")
    if window_size < 1:
        raise ValueError("window_size must be a positive integer")
    if img.size == 0:
        return img.copy()

    # Estimate global noise variance if not provided
    if noise_variance is None:
        noise_variance = img.var()

    # Pad the image to handle borders. The global mean is subtracted first so
    # that the E[x^2] - E[x]^2 variance formula below loses less precision.
    pad = window_size // 2
    global_mean = img.mean()
    centered = np.pad(img, pad, mode='reflect') - global_mean

    rows, cols = img.shape
    area = float(window_size * window_size)
    centered_mean = _window_sums(centered, window_size, rows, cols) / area
    mean_of_squares = _window_sums(centered * centered, window_size, rows, cols) / area

    local_mean = centered_mean + global_mean
    # Rounding can push a true zero variance slightly negative; clamp it.
    local_var = np.maximum(mean_of_squares - centered_mean ** 2, 0.0)

    # Compute Wiener filter response: gain stays 0 (pure local mean) wherever
    # the local variance does not exceed the noise variance.
    signal = (local_var > noise_variance) & (local_var > 0)
    gain = np.zeros_like(local_var)
    gain[signal] = (local_var[signal] - noise_variance) / local_var[signal]
    filtered = local_mean + gain * (img - local_mean)

    # Min-max normalisation; a constant result has no range to stretch.
    lowest = filtered.min()
    span = filtered.max() - lowest
    if span == 0:
        return np.zeros_like(filtered)
    return (filtered - lowest) / span

# Histogram helper for Otsu thresholding
def custom_histogram(image: np.ndarray, nbins: int = 256) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute the histogram and bin centers of a grayscale image after
    stretching its intensities to the [0, 255] range.

    The image minimum is mapped to 0 and the maximum to 255 before binning,
    so the histogram always spans the full range regardless of the input
    dtype or value range. A constant image (no intensity range) is mapped
    entirely to 0 and therefore lands in the first bin.

    Args:
        image (np.ndarray): Input image of grayscale values (any numeric dtype).
        nbins (int): Number of bins for the histogram (default: 256).

    Returns:
        hist (np.ndarray): Array of histogram frequencies for each bin.
        bin_centers (np.ndarray): Array of bin center values on the [0, 255] scale.

    Raises:
        ValueError: If the image contains no pixels.
    """
    # Work in float64 so the subtraction below cannot overflow or wrap for
    # narrow / signed integer dtypes.
    values = np.asarray(image, dtype=np.float64).ravel()
    if values.size == 0:
        raise ValueError("image must contain at least one pixel")

    # Determine the minimum and maximum pixel intensity in the image
    img_min, img_max = values.min(), values.max()
    span = img_max - img_min

    # Normalize the image intensities to the range [0, 255]
    if span > 0:
        image_scaled = (values - img_min) / span * 255
    else:
        # Constant image: avoid 0/0 and place every pixel at intensity 0.
        image_scaled = np.zeros_like(values)

    # Compute the histogram of the scaled image within [0, 255]
    hist, bin_edges = np.histogram(image_scaled, bins=nbins, range=(0, 255))

    # Compute bin centers as the average of adjacent bin edges
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    return hist, bin_centers

# Otsu thresholding
def apply_global_otsu(image: np.ndarray) -> np.ndarray:
    """
    Segment a grayscale image with a global Otsu threshold.

    The threshold is chosen on a 256-bin histogram of the image stretched to
    [0, 255] (see ``custom_histogram``) by maximising the between-class
    variance, then mapped back to the original intensity range and applied
    to the image.

    Args:
        image (np.ndarray): Input image as an array of grayscale values.

    Returns:
        np.ndarray: uint8 mask with the same shape as ``image``; 1 where the
        pixel is strictly greater than the Otsu threshold, 0 elsewhere.
        A constant image has no two classes to separate and yields an
        all-zero mask.
    """
    img_min, img_max = image.min(), image.max()
    if img_max == img_min:
        # No intensity range: nothing can lie above the threshold.
        return np.zeros(image.shape, dtype=np.uint8)

    # Compute histogram and bin centers on the stretched [0, 255] scale
    hist, bin_centers = custom_histogram(image, nbins=256)

    # Normalize histogram to obtain probability distribution p(k)
    p = hist.astype(np.float64) / hist.sum()

    # Cumulative class probabilities: omega0 from the dark end, omega1 from the bright end
    omega0 = np.cumsum(p)
    omega1 = np.cumsum(p[::-1])[::-1]

    # Cumulative first moments; dividing by omega gives the class means
    moment0 = np.cumsum(p * bin_centers)
    moment1 = np.cumsum((p * bin_centers)[::-1])[::-1]

    # Between-class variance for each split between bin k and bin k + 1.
    # Splits that leave one class empty get variance 0 instead of 0/0 = NaN.
    w0, w1 = omega0[:-1], omega1[1:]
    valid = (w0 > 0) & (w1 > 0)
    mean_diff = np.zeros_like(w0)
    mean_diff[valid] = moment0[:-1][valid] / w0[valid] - moment1[1:][valid] / w1[valid]
    sigma_b_squared = w0 * w1 * mean_diff ** 2

    # Find the threshold index t maximizing the between-class variance
    t_idx = int(np.argmax(sigma_b_squared))
    t_scaled = bin_centers[t_idx]

    # Rescale threshold back to the original intensity range. Python floats
    # are used so narrow integer dtypes cannot overflow in the subtraction.
    lo, hi = float(img_min), float(img_max)
    t_original = t_scaled / 255 * (hi - lo) + lo

    return (image > t_original).astype(np.uint8)

# Dice score
def dice_score(otsu_img, otsu_gt):
    """
    Compute the Dice coefficient between a predicted and a reference mask.

    Both inputs are flattened; an element counts as foreground when it
    equals 1. If the two inputs differ in length, the longer one is cropped
    to the length of the shorter one (keeping the leading elements) before
    comparison.

    Parameters:
    - otsu_img: predicted binary mask (array-like of 0/1 values).
    - otsu_gt: ground-truth binary mask (array-like of 0/1 values).

    Returns:
    - Dice score ``2 * |A ∩ B| / (|A| + |B|)`` as a float in [0, 1];
      1.0 when neither mask contains any foreground.
    """
    pred = np.asarray(otsu_img).ravel()
    truth = np.asarray(otsu_gt).ravel()

    # Crop to the common length, starting at the first element so that the
    # two masks stay aligned.
    common = min(pred.size, truth.size)
    pred_fg = pred[:common] == 1
    truth_fg = truth[:common] == 1

    sum_img = int(np.count_nonzero(pred_fg))
    sum_gt = int(np.count_nonzero(truth_fg))

    # Two empty masks agree perfectly; also avoids dividing by zero.
    if sum_img + sum_gt == 0:
        return 1.0

    positive_overlap = int(np.count_nonzero(pred_fg & truth_fg))
    return 2 * positive_overlap / (sum_img + sum_gt)

# Process single image and its groundtruth
def process_single(img_path: Path, gt_path: Path, window_size: int = 201) -> float:
    """
    Read one image and its ground truth, segment the image and return the Dice score.

    Pipeline: scale the image to 8 bit, subtract the locally Wiener-filtered
    image, binarize with global Otsu thresholding, invert the mask if more
    than half of it is foreground, and compare against the binarized ground
    truth.

    Args:
        img_path (Path): Path of the input image.
        gt_path (Path): Path of the corresponding ground-truth image.
        window_size (int): Window size passed to ``local_wiener_filter``
            (default: 201).

    Returns:
        float: Dice score between the segmentation and the ground truth.
    """
    # Read the image and scale it to the 8-bit range
    img = imread(img_path, as_gray=True)
    peak = img.max()
    if peak > 0:
        img_scaled = ((img / peak) * 255).astype(np.uint8)
    else:
        # All-zero image: avoid dividing by zero.
        img_scaled = np.zeros(img.shape, dtype=np.uint8)

    # Read the ground truth and binarize it: every non-zero pixel is
    # foreground. Comparing directly (instead of rescaling by the maximum and
    # truncating to uint8) keeps small label values from rounding to background.
    gt = imread(gt_path, as_gray=True)
    gt = (gt != 0).astype(np.uint8)

    # Wiener filter, background estimation and removal
    # NOTE(review): local_wiener_filter returns a min-max normalised image in
    # [0, 1] while img_scaled is in [0, 255], so this subtraction changes the
    # image only slightly — confirm that this is the intended background removal.
    img_filtered = img_scaled - local_wiener_filter(img_scaled, window_size=window_size)

    # Otsu
    binary1 = apply_global_otsu(img_filtered)

    # Invert if necessary (presumably the foreground is expected to be the
    # minority class). `1 - mask` keeps the values in {0, 1}; bitwise `~` on
    # a uint8 mask would yield 254/255 and never match foreground == 1.
    if np.mean(binary1) > 0.5:
        binary1 = 1 - binary1

    # Compute Dice score
    return dice_score(binary1.flatten(), gt.flatten())

# -------------------------------------------------------------------
# Main routine: iterate over all image/ground-truth pairs and collect the Dice scores
# -------------------------------------------------------------------


def _evaluate_dataset(img_dir, gt_dir, img_pattern, gt_name_template):
    """Compute the Dice score of every image in ``img_dir`` that has a ground truth.

    Args:
        img_dir (Path): Directory containing the input images.
        gt_dir (Path): Directory containing the ground-truth images.
        img_pattern (str): Glob pattern selecting the input images.
        gt_name_template (str): ``str.format`` template with an ``{idx}``
            placeholder that yields the ground-truth file name for an image
            index (the part of the image stem after the last '-').

    Returns:
        np.ndarray: 1D array with one Dice score per processed image, in
        sorted file-name order. Images without a ground truth are skipped.
    """
    scores = []
    for img_file in sorted(img_dir.glob(img_pattern)):
        # Extract index to read the corresponding gt
        idx = img_file.stem.split('-')[-1]
        gt_file = gt_dir / gt_name_template.format(idx=idx)

        if not gt_file.exists():
            print(f" ! Kein Ground-Truth für {img_file.name} ! überspringe …")
            continue

        score = process_single(img_file, gt_file)
        scores.append(score)
        print(f"{img_file.name:<20} Dice = {score:.4f}")

    # Vector containing all dice scores
    scores = np.array(scores, dtype=float)
    # An empty directory would make .mean() emit a RuntimeWarning; report NaN directly.
    mean_score = scores.mean() if scores.size else float("nan")
    print("\nDONE ⇢ Mean Dice:", mean_score)
    return scores


# data set: NIH3T3
img_dir = Path("data/NIH3T3/img")
gt_dir  = Path("data/NIH3T3/gt")

dice_scores_NIH3T3 = _evaluate_dataset(img_dir, gt_dir, "dna-*.png", "{idx}.png")
np.save('NIH3T3_wiener_dice_scores', dice_scores_NIH3T3)


# data set: N2DL-HeLa
img_dir = Path("data/N2DL-HeLa/img")
gt_dir  = Path("data/N2DL-HeLa/gt")

dice_scores_N2DL_HeLa = _evaluate_dataset(img_dir, gt_dir, "t-*.tif", "man_seg{idx}.tif")
np.save('N2DL-HeLa_wiener_dice_scores', dice_scores_N2DL_HeLa)


# dataset: N2DH-GOWT1
img_dir = Path("data/N2DH-GOWT1/img")
gt_dir  = Path("data/N2DH-GOWT1/gt")

dice_scores_N2DH_GOWT1 = _evaluate_dataset(img_dir, gt_dir, "t-*.tif", "man_seg{idx}.tif")
np.save('N2DH-GOWT1_wiener_dice_scores', dice_scores_N2DH_GOWT1)

dna-0.png            Dice = 0.8944
dna-1.png            Dice = 0.8845
dna-26.png           Dice = 0.8134


KeyboardInterrupt: 