In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import cv2

from topostats.io import LoadScans
from topostats.filters import Filters

from skimage.morphology import label
from skimage.measure import regionprops

from scipy.ndimage import (
    distance_transform_edt,
    distance_transform_cdt,
    distance_transform_bf,
    distance_transform_edt,
    distance_transform_bf,
)

In [None]:
def plot(image: np.ndarray, title="", figsize=(8, 8)):
    """Show a single 2D array in its own matplotlib figure.

    Parameters
    ----------
    image : np.ndarray
        Array handed directly to ``Axes.imshow``.
    title : str
        Title for the axes; empty by default.
    figsize : tuple
        Figure size in inches.
    """
    figure, axis = plt.subplots(figsize=figsize)
    axis.imshow(image)
    axis.set_title(title)
    plt.show()


def plot_gallery(images: list, figsize=(20, 40), cols: int = 5):
    """Show a list of 2D arrays in a grid of subplots, filled row by row.

    Parameters
    ----------
    images : list
        Arrays to display with ``Axes.imshow``.
    figsize : tuple
        Figure size in inches for the whole grid.
    cols : int
        Number of subplot columns (default 5).
    """
    # ceil rather than floor + 1: the old formula added an empty row whenever
    # len(images) was a multiple of cols. Keep at least one row so an empty
    # list still yields a (blank) figure instead of an error.
    rows = max(1, int(np.ceil(len(images) / cols)))
    print(f"images: {len(images)} rows: {rows}, cols: {cols}")
    # squeeze=False keeps `axes` 2D even when rows == 1; otherwise subplots
    # returns a 1D array and axes[row, col] raises IndexError.
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    for index, image in enumerate(images):
        axes[index // cols, index % cols].imshow(image)
    # Hide the grid cells that have no image so they don't show empty axes.
    for index in range(len(images), rows * cols):
        axes[index // cols, index % cols].axis("off")
    plt.show()

In [None]:
# Input data for this notebook.
# NOTE(review): absolute, machine-specific path — adjust DATA_ROOT for your setup.
DATA_ROOT = Path("/Users/sylvi/topo_data/cats")
FILE_DIR = DATA_ROOT / "flattened_images_numpy"
# OUTPUT_DIR = Path("/Users/sylvi/topo_data/cats/flattened_images_numpy/")
# Despite its name, this is the path to a single YAML config file.
CONFIG_DIR = DATA_ROOT / "catsconf.yaml"

# Sort so that list positions (e.g. images[2] further down) refer to the same
# file on every run/machine; Path.glob yields entries in arbitrary order.
files = sorted(FILE_DIR.glob("*.npy"))
if not files:
    raise FileNotFoundError(f"No .npy files found in {FILE_DIR}")

images = [np.load(file) for file in files]


# Foreground masks: True wherever the image value exceeds `threshold`.
threshold = 1.8
masks = [image > threshold for image in images]


# Which image to study, how much of its top-left corner to keep, and which
# labelled region to isolate.
IMAGE_INDEX = 2
CROP_SIZE = 512
MOLECULE_LABEL = 5

fig, ax = plt.subplots(1, 2, figsize=(15, 7.5))
ax[0].imshow(images[IMAGE_INDEX])
ax[0].set_title(f"image {IMAGE_INDEX}")
ax[1].imshow(masks[IMAGE_INDEX])
ax[1].set_title(f"mask {IMAGE_INDEX}")
plt.show()

# Zoom in on the top-left CROP_SIZE x CROP_SIZE corner of image and mask.
image = images[IMAGE_INDEX][:CROP_SIZE, :CROP_SIZE]
mask = masks[IMAGE_INDEX][:CROP_SIZE, :CROP_SIZE]

# Label connected foreground regions of the cropped mask.
labelled = label(mask)
plot(labelled, title="labelled")

# Isolate a single region. Fail loudly if that label does not exist: an
# all-False mask would otherwise make the highest-point search in the next
# cell silently return position (0, 0).
if not 1 <= MOLECULE_LABEL <= labelled.max():
    raise ValueError(
        f"Label {MOLECULE_LABEL} not found; labels run from 1 to {labelled.max()}"
    )
molecule_mask = labelled == MOLECULE_LABEL
plot(molecule_mask, title="molecule mask")

In [None]:
# Get molecule heights: the image values inside the molecule mask, 0 elsewhere.
# zeros_like keeps the image's dtype; the previous np.full(image.shape, 0)
# created an integer array, so float heights were truncated on assignment
# (which could also change the argmax used to find the highest point).
molecule = np.zeros_like(image)
molecule[molecule_mask] = image[molecule_mask]

# Use one colour scale so the full image and the isolated molecule compare directly.
vmin, vmax = np.min(image), np.max(image)
fig, ax = plt.subplots(1, 3, figsize=(20, 10))
ax[0].imshow(image, vmin=vmin, vmax=vmax)
ax[0].set_title("image")
ax[1].imshow(mask)
ax[1].set_title("mask")
ax[2].imshow(molecule, vmin=vmin, vmax=vmax)
ax[2].set_title("molecule heights")
plt.show()

# Locate the tallest pixel of the isolated molecule (seed for the flood fill below).
flat_argmax = np.argmax(molecule)
highest_point_pos = np.unravel_index(flat_argmax, molecule.shape)
highest_point_value = molecule[highest_point_pos]
print(f"highest point value: {highest_point_value} position: {highest_point_pos}")

# Matplotlib takes (x, y) = (column, row), hence the swapped order.
peak_row, peak_col = highest_point_pos
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
ax.annotate("highest point", xy=(peak_col, peak_row))
ax.scatter(peak_col, peak_row, c="red", marker=".")
plt.show()

# adaptive_threshold = highest_point_value
# adaptive_threshold_limit = np.median(molecule)
# adaptive_threshold_step = 0.5

# found_pixels = [highest_point_pos]

# found_pixels_image = np.zeros(image.shape).astype(bool)
# found_pixels_image[highest_point_pos] = True

# while adaptive_threshold > adaptive_threshold_limit:
#     print(f"threshold: {adaptive_threshold}")
#     # Get the points adjacent to the points in the found_pixels list that are not already in the found_pixels list and if they are above the threshold, add the point to the found_pixels list

#     # Get the found points from the binary image
#     found_points = np.argwhere(found_pixels_image == True)

#     # Find points next to the found points that exceed the threshold
#     for point in found_points:
#         y, x = point
#         for j in range(-1, 2):
#             for i in range(-1, 2):
#                 neighbour_y, neighbour_x = (y + j, x + i)
#                 if neighbour_y >= 0 and neighbour_y < image.shape[0] and neighbour_x >= 0 and neighbour_x < image.shape[1]:
#                     if image[neighbour_y, neighbour_x] > adaptive_threshold:
#                         found_pixels_image[neighbour_y, neighbour_x] = True

#     plt.imshow(found_pixels_image)
#     plt.show()

#     adaptive_threshold -= adaptive_threshold_step


def dfs(array: np.ndarray, start: tuple, threshold: float) -> np.ndarray:
    """Flood-fill from ``start`` over 4-connected pixels strictly above ``threshold``.

    Implemented as an iterative depth-first search with an explicit stack.

    Parameters
    ----------
    array : np.ndarray
        2D array of values. Neighbours come from ``get_neighbours``, which
        only handles 2D ``(y, x)`` indices.
    start : tuple
        ``(y, x)`` index of the seed pixel (any length-2 sequence of ints).
    threshold : float
        A pixel is included only if its value is strictly greater than this.

    Returns
    -------
    np.ndarray
        Integer array of shape ``(k, 2)`` with the ``(y, x)`` index of each
        reached pixel, in visiting order. When the seed itself is not above
        ``threshold`` the result has shape ``(0, 2)``.
    """
    # A tuple indexes a single element; a list would trigger fancy indexing.
    start = tuple(int(coord) for coord in start)
    stack = [start]
    reachable_points = []
    visited = np.zeros_like(array, dtype=bool)
    while stack:
        point = stack.pop()
        if visited[point]:
            continue
        visited[point] = True
        if array[point] > threshold:
            reachable_points.append(point)
            stack.extend(
                neighbour
                for neighbour in get_neighbours(point, array.shape)
                if not visited[neighbour]
            )
    if not reachable_points:
        # np.array([]) would be 1D and break callers' points[:, 0] indexing.
        return np.empty((0, 2), dtype=np.intp)
    return np.array(reachable_points)


def get_neighbours(point, shape):
    """Return the in-bounds 4-connected neighbours of a 2D index.

    Parameters
    ----------
    point : tuple
        ``(y, x)`` index.
    shape : tuple
        Array shape; ``shape[0]`` bounds ``y`` and ``shape[1]`` bounds ``x``.

    Returns
    -------
    list of tuple
        Neighbour indices in the order up, down, left, right, omitting any
        that fall outside ``shape``.
    """
    y, x = point
    candidates = ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1))
    return [
        (ny, nx) for ny, nx in candidates if 0 <= ny < shape[0] and 0 <= nx < shape[1]
    ]


# Flood-fill from the molecule's highest point, keeping 4-connected pixels
# whose value is strictly above `threshold`.
threshold = 2.3
reachable_points = dfs(array=image, start=highest_point_pos, threshold=threshold)

# Boolean mask of the flood-filled region (reused for display and the distance transform).
reachable_points_image = np.zeros_like(image, dtype=bool)
reachable_points_image[reachable_points[:, 0], reachable_points[:, 1]] = True

# Paint the seed and the filled region with a constant so they stand out.
# NOTE(review): 10.0 only stands out if it is above the image's value range.
display_image = image.copy()
display_image[highest_point_pos] = 10.0
display_image[reachable_points_image] = 10.0
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(display_image)
ax.set_title(f"flood fill from highest point (threshold {threshold})")
plt.show()

plot(reachable_points_image, title="reachable points", figsize=(10, 10))

# Euclidean distance from each region pixel to the nearest background pixel.
# Pixels with a background pixel directly above/below/left/right have distance
# exactly 1, so `dists > 1` drops that outer layer of the region.
dists = distance_transform_edt(reachable_points_image)
thresholded_dists = dists > 1
plot(dists, title="dists", figsize=(20, 20))
plot(thresholded_dists, title="thresholded dists", figsize=(20, 20))

fig, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].imshow(image)
ax[0].set_title("image")
ax[1].imshow(thresholded_dists)
ax[1].set_title("thresholded dists")
plt.show()

In [None]:
# Scratch check: print every 3x3 neighbourhood offset as "(column offset) (row offset)".
offsets = (-1, 0, 1)
for j in offsets:
    for i in offsets:
        print(i, j)