In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cv2

from topostats.io import LoadScans
from topostats.filters import Filters

from skimage.morphology import label
from skimage.measure import regionprops
from skimage.filters.rank import (
    entropy,
    equalize,
    gradient,
    gradient_percentile,
    maximum,
    mean,
    mean_bilateral,
    mean_percentile,
    median,
    minimum,
    modal,
    noise_filter,
    percentile,
    otsu,
    sum_bilateral,
    threshold,
    threshold_percentile,
    windowed_histogram,
)
from skimage.filters import (
    gaussian,
    threshold_local,
    frangi,
    gabor,
    sobel,
    hessian,
    laplace,
    threshold_mean,
    threshold_isodata,
    threshold_li,
    threshold_minimum,
    threshold_multiotsu,
    threshold_niblack,
    threshold_sauvola,
    threshold_triangle,
    threshold_yen,
    try_all_threshold,
    wiener,
)
from skimage.color import label2rgb
from skimage.feature import hessian_matrix, hessian_matrix_eigvals, canny
from skimage.segmentation import active_contour
from scipy.ndimage import (
    distance_transform_edt,
    distance_transform_cdt,
    distance_transform_bf,
    distance_transform_edt,
    distance_transform_bf,
    convolve,
    generic_filter,
)

In [None]:
def plot(image: np.ndarray, title="", figsize=(8, 8), display_range=False):
    """Show a single image in its own matplotlib figure.

    Parameters
    ----------
    image : np.ndarray
        Array handed to ``imshow``.
    title : str
        Text placed above the axes.
    figsize : tuple
        Figure size forwarded to ``plt.subplots``.
    display_range : bool
        When true, the ``(min, max)`` of ``image`` is appended to the title.
    """
    _, axis = plt.subplots(1, 1, figsize=figsize)
    axis.imshow(image)
    heading = title + f" range: {np.min(image), np.max(image)}" if display_range else title
    axis.set_title(heading)
    plt.show()


def plot_gallery(images: list, figsize=(20, 40)):
    """Show every image in ``images`` on a grid that is five columns wide.

    Parameters
    ----------
    images : list
        Images to display; each one is passed to ``imshow``.
    figsize : tuple
        Figure size forwarded to ``plt.subplots``.
    """
    cols = 5
    # Ceiling division: just enough rows to hold every image (at least one row so
    # that an empty list still produces a valid figure).
    rows = max(1, int(np.ceil(len(images) / cols)))
    print(f"images: {len(images)} rows: {rows}, cols: {cols}")
    # squeeze=False keeps `ax` two-dimensional even when there is a single row;
    # otherwise the `ax[row, col]` indexing below fails for fewer than six images.
    fig, ax = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    for index, image in enumerate(images):
        ax[index // cols, index % cols].imshow(image)
    plt.show()


def detect_ridges(gray, sigma=1.0):
    """Return the two Hessian eigenvalue images of a 2-D image.

    Parameters
    ----------
    gray : np.ndarray
        Image passed to ``hessian_matrix``.
    sigma : float
        Gaussian scale used when computing the Hessian.

    Returns
    -------
    tuple
        ``(maxima_ridges, minima_ridges)``: the first and second arrays returned by
        ``hessian_matrix_eigvals``.
    """
    hessian_elements = hessian_matrix(gray, sigma=sigma, order="rc")
    eigenvalue_images = hessian_matrix_eigvals(hessian_elements)
    maxima_ridges, minima_ridges = eigenvalue_images
    return maxima_ridges, minima_ridges

In [None]:
def get_neighbours(point, shape):
    """List the 4-connected neighbours of ``point`` that lie inside ``shape``.

    Parameters
    ----------
    point : tuple
        ``(row, column)`` position.
    shape : tuple
        ``(rows, columns)`` of the array the point belongs to.

    Returns
    -------
    list
        In-bounds neighbours as ``(row, column)`` tuples, in the order
        up, down, left, right.
    """
    row, col = point
    in_bounds = []
    for candidate in ((row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1)):
        n_row, n_col = candidate
        if 0 <= n_row < shape[0] and 0 <= n_col < shape[1]:
            in_bounds.append((n_row, n_col))
    return in_bounds


def calc_zeros_in_range(array: np.ndarray, range: int):
    """Sum ``1 - array`` over a ``range`` x ``range`` window around every pixel.

    For a boolean or 0/1 mask this is the number of zero pixels in the window.
    Positions outside the array contribute nothing (constant padding with 0).

    Parameters
    ----------
    array : np.ndarray
        Input mask; it is cast to ``int`` before being inverted.
    range : int
        Side length of the square window. The name shadows the builtin inside this
        function, but callers pass it by keyword so it is kept.

    Returns
    -------
    np.ndarray
        Integer array of the same shape as ``array`` holding the window sums.
    """
    window = np.ones((range, range))
    zero_indicator = 1 - array.astype(int)
    return convolve(zero_indicator, window, mode="constant", cval=0.0)


def filter_by_percentile(arr, percentile: float, kernel_size=3):
    """Zero every pixel that lies below a percentile of its local neighbourhood.

    Each pixel is compared with the ``percentile``-th percentile of the
    ``kernel_size``-wide window around it (positions outside the array count as
    0.0). Pixels below that value become 0; all others keep their value.

    Parameters
    ----------
    arr : np.ndarray
        Input array.
    percentile : float
        Percentile (0-100) of the window used as the local cut-off.
    kernel_size : int
        Side length of the window along every axis.

    Returns
    -------
    np.ndarray
        Filtered array with the same shape as ``arr``.
    """
    ndim = np.ndim(arr)
    # generic_filter passes each window flattened in C order, with the filtered
    # pixel at offset kernel_size // 2 along every axis. Computing the flat index
    # from that position (instead of kernel_size**2 // 2) is also correct for even
    # kernel sizes and for inputs that are not two-dimensional.
    index_of_middle = np.ravel_multi_index((kernel_size // 2,) * ndim, (kernel_size,) * ndim)

    def func(window):
        centre = window[index_of_middle]
        # Keep the centre pixel only if it reaches the local percentile.
        return 0 if centre < np.percentile(window, percentile) else centre

    return generic_filter(arr, func, size=kernel_size, mode="constant", cval=0.0)

### Load images

In [None]:
FILE_DIR = Path("/Users/sylvi/topo_data/cats/flattened_images_numpy/")
# OUTPUT_DIR = Path("/Users/sylvi/topo_data/cats/flattened_images_numpy/")
CONFIG_DIR = Path("/Users/sylvi/topo_data/cats/catsconf.yaml")
# Path.glob yields files in an arbitrary, filesystem-dependent order. Later cells
# pick images by position (images[6], images[7], ...), so sort the paths to make
# those indices refer to the same file on every run.
files = sorted(FILE_DIR.glob("*.npy"))

# images = []
# for file in files:
# print(file)
# image = cv2.imread(str(file), 0)
# images.append(image)
# plot(image)

# Load every flattened image from disk.
images = []
for file in files:
    image = np.load(file)
    images.append(image)


# One boolean mask per image: pixels above a fixed height threshold.
threshold = 1.8
masks = []
for image in images:
    mask = image > threshold
    masks.append(mask)

In [None]:
# Planned adaptive-threshold procedure (only the threshold sweep is implemented below):
#   1. Start at a generous threshold; threshold, label regions and count holes.
#   2. Lower the threshold; label regions and count holes again.
#   3. If the number of holes has increased, check the size of the holes.
#   4. If a hole is small, ban that hole.

image = images[6]

# Threshold the image at every level from the highest down to the lowest.
thresholded_images = []
for threshold in np.arange(1, 3, 0.2)[::-1]:
    thresholded_images.append(image > threshold)

# Show the raw image first, then each thresholded version in sweep order.
plot(image)
for thresholded_image in thresholded_images:
    plot(thresholded_image)

### Experiment with filters

In [None]:
from skimage.morphology import disk
from scipy.ndimage import laplace as scipy_laplace
import cv2
import numpy as np
from scipy import ndimage as ndi
from skimage.feature import peak_local_max
from skimage import data, img_as_float


class Bot:
    """A position on an image together with a square window around it.

    Parameters
    ----------
    py, px : int
        Row and column of the bot.
    view_range : int
        Number of pixels visible in each direction; the window is
        ``2 * view_range + 1`` pixels wide when it fits inside the image.
    image : np.ndarray, optional
        Image the bot looks at. When omitted, the notebook-level global ``image``
        is read each time it is needed.
    """

    def __init__(self, py, px, view_range: int, image=None) -> None:
        self.view_range = view_range
        self.px = px
        self.py = py
        self._image = image

    def _source_image(self):
        """Return the image given at construction, else the global ``image``."""
        if self._image is not None:
            return self._image
        return image  # notebook-level global

    def move(self, dx: int, dy: int):
        """Shift the bot by ``dx`` columns and ``dy`` rows."""
        self.px += dx
        self.py += dy

    def get_local_area(self):
        """Return the part of the image inside the bot's window.

        The window is clipped at the image border, so the result can be smaller
        than ``2 * view_range + 1`` pixels per side.
        """
        source = self._source_image()
        # Clamp to 0: a negative slice bound would wrap around to the far side of
        # the array instead of stopping at the border.
        row_start = max(self.py - self.view_range, 0)
        col_start = max(self.px - self.view_range, 0)
        row_stop = max(self.py + self.view_range + 1, 0)
        col_stop = max(self.px + self.view_range + 1, 0)
        return source[row_start:row_stop, col_start:col_stop]

    def navigate(self):
        """Placeholder for the movement strategy, which is not written yet."""
        raise NotImplementedError("Bot.navigate has not been implemented yet")

    def display(self):
        """Plot the whole image with the window outlined, next to the window crop."""
        px = self.px
        py = self.py
        view_range = self.view_range
        fig, ax = plt.subplots(1, 2, figsize=(20, 10))
        ax[0].imshow(self._source_image())
        # imshow centres each pixel on integer coordinates, so the outer edge of
        # the window lies half a pixel beyond the centre of its first pixel.
        rect = patches.Rectangle(
            (px - view_range - 0.5, py - view_range - 0.5),
            view_range * 2 + 1,
            view_range * 2 + 1,
            facecolor='none',
            edgecolor='r',
            linewidth=1,
        )
        ax[0].add_patch(rect)
        ax[0].scatter([px], [py], edgecolors='none', c='b', s=10)
        ax[1].imshow(self.get_local_area())
        plt.show()


# Smooth a copy of the image so that single-pixel spikes do not decide the start point.
image = gaussian(images[7].copy(), sigma=1)

# Locate the tallest pixel; unravel_index returns it as (row, column).
highest_point_pos = np.unravel_index(np.argmax(image), image.shape)
ipy, ipx = highest_point_pos

# Drop a bot on the tallest pixel and show what it can see.
bot = Bot(ipy, ipx, view_range=8)
bot.display()


In [None]:
# NOTE(review): `exaggerate_contours` is defined in a later cell of this notebook, so
# that cell has to be executed before this one.
image = images[3]
initial_mask, exaggerated_contours = exaggerate_contours(image, threshold=2.0)
# Despite the variable name this is not a distance transform: it holds, per pixel, the
# number of zero pixels of the exaggerated-contour mask inside a 5x5 window.
distance_transform = calc_zeros_in_range(exaggerated_contours, range=5)
plot(distance_transform, figsize=(20, 20))

### Plot gallery

In [None]:
# One row per image: raw heights | plain threshold mask | exaggerated-contour mask.
# squeeze=False keeps `ax` two-dimensional even when only one image was loaded;
# otherwise the `ax[index, column]` indexing below fails for a single row.
# NOTE(review): `exaggerate_contours` is defined in a later cell of this notebook.
fig, ax = plt.subplots(len(images), 3, figsize=(20, 300), squeeze=False)
for index, (image, mask) in enumerate(zip(images, masks)):
    ax[index, 0].imshow(image)
    initial_mask, exaggerated_contours = exaggerate_contours(image, threshold=2.0)
    ax[index, 1].imshow(initial_mask)
    ax[index, 2].imshow(exaggerated_contours)

# Close-up comparison of one image with its mask from the load cell.
fig, ax = plt.subplots(1, 2, figsize=(15, 30))
ax[0].imshow(images[2])
ax[1].imshow(masks[2])
plt.show()

image = images[2]
mask = masks[2]

### Zoom in on one image

In [None]:
# Crop the third image and its mask to the top-left 512 x 512 pixels.
image = images[2][0:512, 0:512]
mask = masks[2][0:512, 0:512]

# Label the connected regions of the mask.
labelled = label(mask)
plot(labelled, title="labelled")

# Isolate the region that received label 5.
molecule_mask = labelled == 5
plot(molecule_mask, title="molecule mask")

# Heights of that region only; every other pixel stays at 0.
molecule = np.zeros(image.shape)
molecule[molecule_mask] = image[molecule_mask]

plot(molecule)

In [None]:
def dfs(array: np.ndarray, start: tuple, threshold: float):
    """Collect the pixels above ``threshold`` that are connected to ``start``.

    A depth-first search over 4-connected neighbours (see ``get_neighbours``).
    A pixel is expanded only if its own value exceeds ``threshold``.

    Parameters
    ----------
    array : np.ndarray
        2-D array of values to search.
    start : tuple
        ``(row, column)`` starting position. It must be a tuple so that it
        indexes a single element.
    threshold : float
        Only pixels strictly above this value are collected.

    Returns
    -------
    np.ndarray
        Integer array of shape ``(n_points, array.ndim)`` with one row per
        reachable pixel. It has shape ``(0, array.ndim)`` when nothing is
        reachable, so column indexing such as ``result[:, 0]`` always works.
    """
    stack = [start]
    reachable_points = []
    visited = np.zeros_like(array, dtype=bool)
    while stack:
        point = stack.pop()
        if visited[point]:
            continue
        visited[point] = True
        if not array[point] > threshold:
            continue
        reachable_points.append(point)
        stack.extend(
            neighbour for neighbour in get_neighbours(point, array.shape) if not visited[neighbour]
        )
    # The reshape keeps the result two-dimensional even when no pixel qualified.
    return np.asarray(reachable_points, dtype=np.intp).reshape(-1, array.ndim)


threshold = 2.4

visited_points = np.zeros_like(molecule)

# Pick the highest point
highest_point_pos = np.unravel_index(np.argmax(molecule), molecule.shape)
highest_point_value = molecule[highest_point_pos]

print(f"molecule height range: {np.min(molecule), np.max(molecule)}")


# Greedy walk: from the highest point, repeatedly step to the tallest 4-connected
# neighbour that has not been visited yet.
max_steps = 2920
current_point = highest_point_pos
# Mark the start so that the walk cannot step back onto it.
visited_points[current_point[0], current_point[1]] = 1
for i in range(max_steps):
    neighbours = get_neighbours(current_point, molecule.shape)

    # Only unvisited neighbours with a positive height are candidates.
    candidates = [
        neighbour
        for neighbour in neighbours
        if visited_points[neighbour[0], neighbour[1]] == 0 and molecule[neighbour[0], neighbour[1]] > 0
    ]
    if not candidates:
        # Dead end: stop rather than jumping to an arbitrary (possibly visited or
        # zero-height) neighbour.
        break

    # max() returns the first of several equally tall candidates.
    current_point = max(candidates, key=lambda neighbour: molecule[neighbour[0], neighbour[1]])
    visited_points[current_point[0], current_point[1]] = 1


plot(visited_points)
plot(molecule)

# reachable_points = dfs(molecule, highest_point_pos, threshold=threshold)
# display_image = molecule.copy()
# display_image[highest_point_pos] = 10.0
# display_image[reachable_points[:, 0], reachable_points[:, 1]] = 10.0
# fig, ax = plt.subplots(figsize=(10, 10))
# ax.imshow(display_image)
# plt.show()

In [None]:
# Image, mask and isolated molecule side by side; the two height panels share one
# colour scale taken from the full image.
vmin, vmax = np.min(image), np.max(image)
fig, ax = plt.subplots(1, 3, figsize=(20, 10))
ax[0].imshow(image, vmin=vmin, vmax=vmax)
ax[1].imshow(mask)
ax[2].imshow(molecule, vmin=vmin, vmax=vmax)
plt.show()

# Pick the highest point of the molecule and report it.
highest_point_pos = np.unravel_index(np.argmax(molecule), molecule.shape)
highest_point_value = molecule[highest_point_pos]
print(f"highest point value: {highest_point_value} position: {highest_point_pos}")

# Mark it on the image; matplotlib expects (x, y), i.e. (column, row).
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(image)
ax.annotate("highest point", xy=highest_point_pos[::-1])
ax.scatter(*highest_point_pos[::-1], c="red", marker=".")
plt.show()

# adaptive_threshold = highest_point_value
# adaptive_threshold_limit = np.median(molecule)
# adaptive_threshold_step = 0.5

# found_pixels = [highest_point_pos]

# found_pixels_image = np.zeros(image.shape).astype(bool)
# found_pixels_image[highest_point_pos] = True

# while adaptive_threshold > adaptive_threshold_limit:
#     print(f"threshold: {adaptive_threshold}")
#     # Get the points adjacent to the points in the found_pixels list that are not already in the found_pixels list and if they are above the threshold, add the point to the found_pixels list

#     # Get the found points from the binary image
#     found_points = np.argwhere(found_pixels_image == True)

#     # Find points next to the found points that exceed the threshold
#     for point in found_points:
#         y, x = point
#         for j in range(-1, 2):
#             for i in range(-1, 2):
#                 neighbour_y, neighbour_x = (y + j, x + i)
#                 if neighbour_y >= 0 and neighbour_y < image.shape[0] and neighbour_x >= 0 and neighbour_x < image.shape[1]:
#                     if image[neighbour_y, neighbour_x] > adaptive_threshold:
#                         found_pixels_image[neighbour_y, neighbour_x] = True

#     plt.imshow(found_pixels_image)
#     plt.show()

#     adaptive_threshold -= adaptive_threshold_step


# Flood outwards from the molecule's highest point, over `image`, keeping every
# 4-connected pixel above the threshold.
threshold = 2.3
reachable_points = dfs(array=image, start=highest_point_pos, threshold=threshold)

# Overlay: paint the start point and every reachable pixel with a value of 10.0.
display_image = image.copy()
display_image[highest_point_pos] = 10.0
display_image[reachable_points[:, 0], reachable_points[:, 1]] = 10.0
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(display_image)
plt.show()

# Boolean mask of the reachable pixels, and the distance of each of them to the
# nearest pixel that is not in the mask.
reachable_points_image = np.zeros_like(image, dtype=bool)
reachable_points_image[reachable_points[:, 0], reachable_points[:, 1]] = True
plot(reachable_points_image, title="reachable points", figsize=(10, 10))
dists = distance_transform_edt(reachable_points_image)
plot(dists, title="dists", figsize=(20, 20))
plot(dists > 1, title="thresholded dists", figsize=(20, 20))

# Count the background pixels in each 3x3 window, then keep the count only for mask
# pixels whose window holds more than 3 background pixels.
zeros_in_range = calc_zeros_in_range(array=reachable_points_image, range=3)
zeros_in_range[np.invert(reachable_points_image)] = 0
zeros_in_range[zeros_in_range <= 3] = 0
print(f"unique(zeros_in_range): {np.unique(zeros_in_range)}")

# Zero every pixel that is below the 40th percentile of its 21x21 neighbourhood.
filtered = filter_by_percentile(image, percentile=40, kernel_size=21)
plot(filtered, title="local maxima (>40th percentile)", figsize=(10, 10))

# Summary figure comparing the search result with its two trimmed versions.
fig, ax = plt.subplots(2, 2, figsize=(20, 20))
ax[0, 0].imshow(image)
ax[0, 0].set_title("molecule image")
ax[0, 1].imshow(reachable_points_image)
ax[0, 1].set_title("depth first search result")
ax[1, 0].imshow(dists > 1)
ax[1, 0].set_title("dfs - points within 1 pixels of a zero")
# Mask minus the pixels flagged above.
# NOTE(review): the title below mentions "3 pixels" and "2 or more 0s", while the code
# uses a 3x3 window and requires more than 3 zeros - confirm which is intended.
ax[1, 1].imshow(reachable_points_image - zeros_in_range.astype(bool).astype(int))
ax[1, 1].set_title("dfs - points within 3 pixels of 2 or more 0s")

In [None]:
# Work on a copy of the distance transform (`dists`, assigned in an earlier cell) so
# that the original stays untouched.
test = dists.copy()
plot(test)
# plot(test, figsize=(20, 20))


def calc_zeros_in_range(array: np.ndarray, range: int):
    """Sum ``1 - array`` over a ``range`` x ``range`` window around every pixel.

    NOTE(review): this re-defines the function of the same name from an earlier
    cell of this notebook with the same behaviour.

    For a boolean or 0/1 mask the result is the number of zero pixels in the
    window. Positions outside the array contribute nothing (constant padding
    with 0).

    Parameters
    ----------
    array : np.ndarray
        Input mask; it is cast to ``int`` before being inverted.
    range : int
        Side length of the square window (the name shadows the builtin inside this
        function; callers pass it by keyword).

    Returns
    -------
    np.ndarray
        Integer array of the same shape as ``array`` holding the window sums.
    """
    ones_window = np.ones((range, range))
    inverted = 1 - array.astype(int)
    return convolve(inverted, ones_window, mode="constant", cval=0.0)


# # Find all the valss
# vals = test > 5
# inverted_vals = 1 - vals.astype(int)
# distance_to_five = distance_transform_edt(inverted_vals)
# plot(distance_to_five, title="distance_to_five", figsize=(20, 20))

# 1 where the copied distance value is exactly 3 or 4, 2 where it is exactly 5, and
# 0 everywhere else.
vals = np.logical_or(test == 3, test == 4).astype(int)
vals += (test == 5).astype(int) * 2

# NOTE(review): `1 - vals` takes the values 1, 0 and -1. distance_transform_edt treats
# every non-zero element as foreground, so only the pixels with vals == 1 act as
# distance targets; the vals == 2 pixels (-1 here) do not. Confirm this is intended.
inverted_vals = 1 - vals.astype(int)
distance_to_three = distance_transform_edt(inverted_vals)
plot(distance_to_three, title="distance_to_three", figsize=(20, 20))

plot(vals, figsize=(20, 20))


# far_from_5 = distance_to_five > 3
# plot(far_from_5)
# mol_far_from_5 = np.zeros_like(test)
# for j in range(test.shape[0]):
#     for i in range(test.shape[1]):
#         if far_from_5[j, i]:
#             mol_far_from_5[j, i] = test[j, i]
# plot(mol_far_from_5, figsize=(20, 20))


# # Find all the non-zero values that are not within x of a 5
# def calc_vals_in_range(array: np.ndarray, range: int, val: int):
#     all_vals_bool = array == val
#     plot(all_vals_bool)

#     kernel = np.ones((range, range))
#     array_int = all_vals_bool.astype(int)
#     result = convolve(array_int, kernel, mode="constant", cval=0.0)
#     return result

# vals_in_range = calc_vals_in_range(test, range=5, val=5)
# plot(vals_in_range, figsize=(20, 20))

# Select 3s not within 3 of a six and scale up
# plot(vals, figsize=(20, 20))

### Exaggerate contours

**TODO:** Try Otsu thresholding too, and clip heights above 2.5 down to 2.5.

**TODO:** Perhaps, if the IQR is greater than a threshold, take only the top 10% of pixel heights.

In [None]:
# KDE Smoothing

import numpy as np
from scipy.ndimage import gaussian_filter1d
from scipy import stats
from scipy.signal import find_peaks
import seaborn as sns

# Toy sample with two values: seven 0s and eleven 5s.
dist = np.array([0, 0, 0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5])
dist = sorted(dist)

# Gaussian smoothing of the sorted values. gaussian_filter1d returns the dtype of its
# input, so the integers are converted to float first; otherwise the smoothed values
# would be truncated back to whole numbers.
smoothed_dist = gaussian_filter1d(np.asarray(dist, dtype=float), 2)
sns.kdeplot(dist, label="no smoothing")

# Explicit KDE of the raw sample, evaluated on a grid spanning the data.
kde = stats.gaussian_kde(dist, bw_method=1)
x = np.linspace(min(dist), max(dist), len(dist))
y = kde(x)

# Draw the evaluated density itself; passing `x` to sns.kdeplot would instead
# estimate the density of the grid points.
plt.plot(x, y, label="gaussian_kde")

sns.kdeplot(smoothed_dist, label="manually smoothed kde")
plt.legend()

In [None]:
import numpy as np
from scipy import stats
from scipy.signal import find_peaks, peak_widths
import seaborn as sns
from scipy.ndimage.filters import gaussian_filter1d


def detect_bimodal(data, smoothing=1):
    """Decide whether a sample is bimodal and suggest a value separating the modes.

    The sorted sample is smoothed with a 1-D Gaussian filter, a Gaussian KDE is
    fitted to it and evaluated at 100 evenly spaced points between the sample's
    minimum and maximum. Peaks of that curve with a height of at least 0.2 are
    counted.

    Parameters
    ----------
    data : array-like
        Sample values; flattened and converted to float.
    smoothing : float
        Standard deviation of the Gaussian filter applied to the sorted sample.

    Returns
    -------
    tuple
        ``(True, value)`` when exactly two peaks are found, where ``value`` is the
        grid point nearest to the left width-intersection (``left_ips``) of the
        second peak; ``(False, None)`` otherwise.
    """
    # Float conversion matters: gaussian_filter1d returns the dtype of its input, so
    # integer samples would have their smoothed values truncated.
    data = np.sort(np.asarray(data, dtype=float).ravel())

    # A KDE cannot be fitted to fewer than two samples or to a sample without any
    # spread; such data is certainly not bimodal.
    if data.size < 2 or data[0] == data[-1]:
        return False, None

    data = gaussian_filter1d(data, smoothing)

    num_points = 100
    kde = stats.gaussian_kde(data)
    x = np.linspace(np.min(data), np.max(data), num_points)
    y = kde(x)

    # width=0 accepts every peak but makes find_peaks report the interpolated
    # positions where each peak's width is measured ("left_ips" / "right_ips").
    peak_indices, peaks_data = find_peaks(y, height=0.2, width=0)

    if len(peak_indices) != 2:
        return False, None

    second_peak_left = peaks_data["left_ips"][1]
    return True, x[int(round(second_peak_left))]


image = images[6]
mask = np.zeros_like(image)
kernel_size = 51
# Pad with the median so that every pixel, including those at the border, has a full
# window around it.
padded_image = np.pad(image, kernel_size // 2, mode="constant", constant_values=np.median(image))
for j in range(image.shape[0]):
    for i in range(image.shape[1]):
        # The image was padded by kernel_size // 2 on every side, so rows
        # j .. j + kernel_size - 1 of the padded image form the window centred on
        # pixel (j, i). (The previous slice was one row and one column short.)
        local_image = padded_image[j : j + kernel_size, i : i + kernel_size]
        # plot(local_image)
        is_bimodal, threshold = detect_bimodal(local_image.flatten(), smoothing=0.0001)
        # Pixels whose neighbourhood is not bimodal stay 0 in the mask.
        if is_bimodal:
            mask[j, i] = image[j, i] > threshold

        # plot(mask)
        # plot(mask)

In [None]:
# Show the current `image` and `mask` one after the other (when the notebook is run in
# order these come from the sliding-window bimodality cell above).
plot(image)
plot(mask)

In [None]:
# Work on a copy of the seventh image.
image = images[6].copy()
plot(image)

kernel_size = 25
mask = np.zeros_like(image)

# Pad with the median so that border pixels also get a full window.
padded_image = np.pad(image, kernel_size // 2, mode="constant", constant_values=np.median(image))
plot(padded_image)

# Flag every pixel that is a high outlier within its own window: above the upper
# quartile by more than 1.5 times the interquartile range.
for j, i in np.ndindex(*image.shape):
    local_pixels = padded_image[j : j + kernel_size, i : i + kernel_size]
    q1, q3 = np.quantile(local_pixels, q=[0.25, 0.75])
    iqr = q3 - q1
    mask[j, i] = image[j, i] > q3 + 1.5 * iqr

plot(mask)

In [None]:
def filter_by_percentile(arr, percentile: float, kernel_size=3):
    """Zero every pixel that lies below a percentile of its local neighbourhood.

    NOTE(review): a function with this name is also defined in an earlier cell of
    this notebook; this later definition replaces it once the cell has run.

    Each pixel is compared with the ``percentile``-th percentile of the
    ``kernel_size``-wide window around it (positions outside the array count as
    0.0). Pixels below that value become 0; all others keep their value.

    Parameters
    ----------
    arr : np.ndarray
        Input array.
    percentile : float
        Percentile (0-100) of the window used as the local cut-off.
    kernel_size : int
        Side length of the window along every axis.

    Returns
    -------
    np.ndarray
        Filtered array with the same shape as ``arr``.
    """
    ndim = np.ndim(arr)
    # generic_filter passes each window flattened in C order, with the filtered
    # pixel at offset kernel_size // 2 along every axis. Computing the flat index
    # from that position (instead of kernel_size**2 // 2) is also correct for even
    # kernel sizes and for inputs that are not two-dimensional.
    index_of_middle = np.ravel_multi_index((kernel_size // 2,) * ndim, (kernel_size,) * ndim)

    def func(window):
        centre = window[index_of_middle]
        # Keep the centre pixel only if it reaches the local percentile.
        return 0 if centre < np.percentile(window, percentile) else centre

    return generic_filter(arr, func, size=kernel_size, mode="constant", cval=0.0)


def exaggerate_contours(image: np.ndarray, threshold: float):
    """Threshold an image and find the enlarged holes that touch an original hole.

    The image is thresholded twice: once as it is, and once after
    ``filter_by_percentile`` (40th percentile, 15-pixel window) has zeroed the
    locally low pixels. The holes of the second result are labelled, and a labelled
    hole is kept only if it shares at least one pixel with a hole of the first
    result.

    Parameters
    ----------
    image : np.ndarray
        Image to threshold.
    threshold : float
        Value a pixel has to exceed to count as foreground.

    Returns
    -------
    tuple
        ``(initial_mask, kept_holes)``: the boolean mask ``image > threshold``, and
        an array shaped and typed like ``image`` that is 1 inside the kept holes and
        0 elsewhere.
    """
    above_threshold = image > threshold
    holes_before = np.invert(above_threshold)

    # Suppress locally low pixels, threshold again and label the resulting holes.
    suppressed = filter_by_percentile(image, percentile=40, kernel_size=15)
    holes_after = np.invert(suppressed > threshold)
    holes_after_labelled = label(holes_after)

    # Keep a labelled hole only if it overlaps a hole of the plain threshold.
    kept_holes = np.zeros_like(image)
    for region in regionprops(holes_after_labelled):
        rows = region.coords[:, 0]
        cols = region.coords[:, 1]
        if np.any(holes_before[rows, cols]):
            kept_holes[rows, cols] = 1

    return above_threshold, kept_holes

In [None]:
# Crop a sub-region of the current `image` (rows 200-379, columns 280-399).
sub_image = image[200:380, 280:400]
# sub_image = image[250:300, 300:400]
plot(sub_image)
# Zero every pixel below the 40th percentile of its 23x23 neighbourhood.
filtered = filter_by_percentile(sub_image, percentile=40, kernel_size=23)
plot(filtered)

In [None]:
threshold = 2

for image in images:
    mask = image > threshold

    fig, ax = plt.subplots(1, 3, figsize=(30, 10))
    ax[0].imshow(image)
    ax[1].imshow(mask)