In [None]:
from pathlib import Path
from typing import List, Tuple, Union, Optional, Dict
import re

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from skimage.morphology import binary_dilation, label
from skimage.measure import regionprops
from skimage.color import label2rgb
from skimage.graph import route_through_array
from sklearn.cluster import KMeans
from scipy.ndimage import distance_transform_edt
import seaborn as sns
from skimage.morphology import binary_erosion
from scipy.ndimage import binary_fill_holes
from skimage.feature import canny

from topostats.grain_finding_haribo_unet import (
    predict_unet,
    load_model,
    predict_unet_multiclass_and_get_angle,
    mean_iou,
    predict_unet_multiclass,
)

from topostats.haribonet_process_grain_bound import process_grain

In [None]:
# Load the trained multiclass U-Net. ``mean_iou`` is handed over as a custom object
# (presumably so the metric stored with the saved model can be resolved on load).
model_path = "./haribonet_multiclass_improved_norm_big_95_bridging_v1_2024-01-17_10-58-46.h5"
unet_custom_objects = {"mean_iou": mean_iou}
print(f"Loading Unet model: {model_path}")
model = load_model(model_path=model_path, custom_objects=unet_custom_objects)
print(f"Loaded Unet model: {model_path}")

In [None]:
# Get ring images.
# The default location is a machine-specific absolute path; set the
# HARIBO_DATA_DIR environment variable to point the notebook elsewhere.
import os

DATA_DIR = Path(
    os.environ.get(
        "HARIBO_DATA_DIR",
        "/Users/sylvi/topo_data/hariborings/cas9_crops_p2nm/OT2_SC_p2nm/",
    )
)
IMAGE_SAVE_DIR = DATA_DIR / "output_plots"
IMAGE_SAVE_DIR.mkdir(exist_ok=True)

image_files = sorted(DATA_DIR.glob("*.npy"))
print(f"num images: {len(image_files)}")
# Fail here with a clear message rather than with an IndexError in the next cell.
if not image_files:
    raise FileNotFoundError(f"No .npy files found in {DATA_DIR}")

In [None]:
# Look at the first file only. Its pixel-to-nm scale is the first decimal number
# in the file name (it follows the image number).
img_file = image_files[0]
print(f"Processing image: {img_file}")
decimal_numbers_in_name = re.findall(r"\d+\.\d+", img_file.name)
p_to_nm = float(decimal_numbers_in_name[0])
print(f"p_to_nm: {p_to_nm}")

In [None]:
# Predict a multiclass mask for every image whose p_to_nm scale is below 0.59
# (the outputs saved later are labelled "high_res_only"). Images, masks and
# scales are kept in three parallel lists.
predicted_masks = []
images = []
all_p_2_nm = []

for index, image_file in enumerate(image_files):
    print(f"Processing {image_file}")
    image = np.load(image_file)
    p_to_nm = float(re.findall(r"\d+\.\d+", image_file.name)[0])

    # Guard clause: skip files at or above the scale threshold.
    if not (p_to_nm < 0.59):
        continue

    predicted_mask = predict_unet_multiclass(
        model=model,
        image=image,
        confidence=0.5,
        model_image_size=256,
        image_output_dir=IMAGE_SAVE_DIR,
        IMAGE_SAVE_DIR=IMAGE_SAVE_DIR,
        filename="test",
        image_index=index,
    )
    images.append(image)
    predicted_masks.append(predicted_mask)
    all_p_2_nm.append(p_to_nm)

print(f"len images: {len(images)}")

In [None]:
# Plot gallery of images and predicted masks
def plot_images(images: list, masks: list, width=5, figsize=(20, 30)):
    """Show each image next to its mask in a grid of ``width`` image/mask pairs per row.

    Parameters
    ----------
    images : list
        Images to show (left column of each pair).
    masks : list
        Masks to show (right column of each pair); ``zip`` stops at the shorter list.
    width : int
        Number of image/mask pairs per row, so the grid has ``2 * width`` columns.
    figsize : tuple
        Matplotlib figure size in inches.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots rejects a grid with zero rows; there is nothing to draw.
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps ``ax`` two-dimensional even for a single row, so the
    # [row, column] indexing below never hits an IndexError.
    fig, ax = plt.subplots(num_rows, width * 2, figsize=figsize, squeeze=False)
    # Hide every axis, including trailing cells of the last row that stay empty.
    for axis in ax.flat:
        axis.axis("off")
    for i, (image, mask) in enumerate(zip(images, masks)):
        row, col = divmod(i, width)
        ax[row, 2 * col].imshow(image, cmap="viridis")
        ax[row, 2 * col + 1].imshow(mask, cmap="viridis")
    fig.tight_layout()


plot_images(
    images=images,
    masks=predicted_masks,
)

In [None]:
# Check there are enough ring and gem pixels


def check_ring_and_mask_exists(image, combined_predicted_mask, min_ring_pixels=40, min_gem_pixels=40):
    """Validate that a predicted mask contains enough ring and gem pixels.

    Parameters
    ----------
    image : np.ndarray
        Source image. Not used; kept so existing callers keep working.
    combined_predicted_mask : np.ndarray
        Multiclass mask in which value 1 marks ring pixels and value 2 marks gem pixels.
    min_ring_pixels : int
        Minimum number of ring pixels required (default 40).
    min_gem_pixels : int
        Minimum number of gem pixels required (default 40).

    Returns
    -------
    np.ndarray
        ``combined_predicted_mask`` itself, unchanged.

    Raises
    ------
    ValueError
        If the ring or the gem class has fewer pixels than its threshold.
    """
    if np.count_nonzero(combined_predicted_mask == 1) < min_ring_pixels:
        raise ValueError(f"Ring pixels < {min_ring_pixels}")
    if np.count_nonzero(combined_predicted_mask == 2) < min_gem_pixels:
        raise ValueError(f"Gem pixels < {min_gem_pixels}")
    return combined_predicted_mask


# Keep predictions with enough ring and gem pixels, together with their image
# and scale. Rejected image/mask pairs are collected so they can be inspected.
removed_images_large_enough_ring_gem = []
removed_masks_large_enough_ring_gem = []
predicted_masks_large_enough_ring_gem = []
images_large_enough_ring_gem = []
p_2_nm_large_enough_ring_gem = []

for index, (image, mask, p_2_nm) in enumerate(zip(images, predicted_masks, all_p_2_nm)):
    try:
        mask = check_ring_and_mask_exists(image, mask)
    except ValueError as e:
        print(f"Skipping image {index}: {e}")
        removed_images_large_enough_ring_gem.append(image)
        removed_masks_large_enough_ring_gem.append(mask)
        continue
    predicted_masks_large_enough_ring_gem.append(mask)
    images_large_enough_ring_gem.append(image)
    p_2_nm_large_enough_ring_gem.append(p_2_nm)

# Show every rejected image with its mask overlaid.
for rejected_image, rejected_mask in zip(removed_images_large_enough_ring_gem, removed_masks_large_enough_ring_gem):
    plt.imshow(rejected_image, cmap="viridis")
    plt.imshow(rejected_mask, alpha=0.5, cmap="viridis")
    plt.show()

In [None]:
def turn_small_gem_regions_into_ring(image, combined_predicted_mask):
    """Relabel small gem fragments that sit next to the ring as ring.

    The largest gem component (mask value 2) is left alone. Every other gem
    component is dilated 5 times with ``binary_dilation``. If the dilated area
    overlaps ring pixels (value 1), the whole dilated area is set to 1.

    Parameters
    ----------
    image : np.ndarray
        Source image. Not used; kept for interface compatibility.
    combined_predicted_mask : np.ndarray
        Multiclass mask (1 = ring, 2 = gem). Modified in place.

    Returns
    -------
    np.ndarray
        The same, possibly modified, ``combined_predicted_mask`` array.
    """
    small_gem_dilation_strength = 5

    gem_labels = label(combined_predicted_mask == 2)
    gem_regions = regionprops(gem_labels)
    gem_areas = [region.area for region in gem_regions]
    print(f"gem_areas: {gem_areas}")
    if len(gem_regions) < 2:
        # No gem at all (np.argmax would raise on an empty list) or only the main
        # gem: there are no small fragments to relabel.
        return combined_predicted_mask
    largest_gem_label = gem_regions[int(np.argmax(gem_areas))].label

    for region in gem_regions:
        if region.label == largest_gem_label:
            continue
        dilated_region_mask = gem_labels == region.label
        for _ in range(small_gem_dilation_strength):
            dilated_region_mask = binary_dilation(dilated_region_mask)
        # Recomputed each pass: fragments converted earlier in this loop count as
        # ring for the fragments that follow.
        predicted_ring_mask = combined_predicted_mask == 1
        if np.any(dilated_region_mask & predicted_ring_mask):
            # NOTE(review): the *dilated* area is written, so background and any
            # gem pixels within the dilation radius also become ring. Confirm this
            # gap-bridging is intended.
            combined_predicted_mask[dilated_region_mask] = 1

    return combined_predicted_mask


# Apply the small-gem clean-up to every retained mask. A copy is passed because
# the function edits its input in place. The pixel size is taken from the
# *filtered* scale list so it stays aligned with the images. all_p_2_nm still
# contains entries for images rejected by the size check.
predicted_masks_small_gem_regions_into_ring = []
images_small_gem_regions_into_ring = []
p_2_nm_small_gem_regions_into_ring = []
for index, (image, mask, p_2_nm) in enumerate(
    zip(images_large_enough_ring_gem, predicted_masks_large_enough_ring_gem, p_2_nm_large_enough_ring_gem)
):
    print(f"processing image {index}")
    combined_predicted_mask = turn_small_gem_regions_into_ring(
        image=image,
        combined_predicted_mask=mask.copy(),
    )
    predicted_masks_small_gem_regions_into_ring.append(combined_predicted_mask)
    images_small_gem_regions_into_ring.append(image)
    p_2_nm_small_gem_regions_into_ring.append(p_2_nm)

plot_images(
    images=images_small_gem_regions_into_ring,
    masks=predicted_masks_small_gem_regions_into_ring,
)

In [None]:
# Remove all but the largest ring region. Connectivity should not include diagonals.
# Output lists (images, masks, scales) for this stage, filled in the loop below.
images_largest_ring, masks_largest_ring, p_2_nm_largest_ring = [], [], []


def remove_all_but_largest_ring_region(image, combined_predicted_mask):
    """Keep only the largest ring component; other ring pixels become background (0).

    Ring components are labelled with ``connectivity=1``, so ring pixels that
    touch only diagonally belong to different components.

    Parameters
    ----------
    image : np.ndarray
        Source image. Not used; kept for interface compatibility.
    combined_predicted_mask : np.ndarray
        Multiclass mask (1 = ring). Modified in place.

    Returns
    -------
    np.ndarray
        The same, possibly modified, array. If it contains no ring pixels it is
        returned unchanged.
    """
    ring_labels = label(combined_predicted_mask == 1, connectivity=1)
    ring_regions = regionprops(ring_labels)
    if not ring_regions:
        # Nothing to keep or remove (np.argmax on an empty list would raise).
        return combined_predicted_mask
    # max() returns the first region with the largest area, matching np.argmax on ties.
    largest_label = max(ring_regions, key=lambda region: region.area).label
    # Every ring pixel outside the largest component is cleared in a single step.
    combined_predicted_mask[(ring_labels > 0) & (ring_labels != largest_label)] = 0
    return combined_predicted_mask


# Walk the three parallel lists together rather than indexing the scale list.
for image, mask, p_2_nm in zip(
    images_small_gem_regions_into_ring,
    predicted_masks_small_gem_regions_into_ring,
    p_2_nm_small_gem_regions_into_ring,
):
    combined_predicted_mask = remove_all_but_largest_ring_region(
        image=image,
        combined_predicted_mask=mask.copy(),
    )
    masks_largest_ring.append(combined_predicted_mask)
    images_largest_ring.append(image)
    p_2_nm_largest_ring.append(p_2_nm)

plot_images(
    images=images_largest_ring,
    masks=masks_largest_ring,
)

In [None]:
# function to determine the number of connection points between ring and gem


def get_number_of_connection_points(image, combined_mask):
    """Count the separate places where the ring touches the gem.

    The gem mask (value 2) is grown by one ``binary_dilation`` step and
    intersected with the ring mask (value 1). Each connected component of that
    intersection counts as one connection point.

    Parameters
    ----------
    image : np.ndarray
        Source image. Not used; kept for interface compatibility.
    combined_mask : np.ndarray
        Multiclass mask (1 = ring, 2 = gem).

    Returns
    -------
    tuple
        ``(num_connection_points, intersection_labels)``, where
        ``intersection_labels`` is the labelled intersection image.
    """
    gem_dilation_strength = 1
    grown_gem = combined_mask == 2
    for _ in range(gem_dilation_strength):
        grown_gem = binary_dilation(grown_gem)

    # Ring pixels that fall inside the grown gem.
    contact_pixels = grown_gem & (combined_mask == 1)
    intersection_labels = label(contact_pixels)
    return len(regionprops(intersection_labels)), intersection_labels


# def find_middle_of_connection_points(image, combined_mask, intersection_labels):
#     # Find the middle of the connection points by finding the pixel with the shortest distance to the centroid
#     intersection_regions = regionprops(intersection_labels)
#     region_0 = intersection_regions[0]
#     region_1 = intersection_regions[1]
#     region_0_centroid = region_0.centroid
#     region_1_centroid = region_1.centroid
#     region_0_distances_to_centroid = []
#     region_1_distances_to_centroid = []
#     for pixel in region_0.coords:
#         region_0_distances_to_centroid.append(
#             np.linalg.norm(np.array(pixel) - np.array(region_0_centroid))
#         )
#     for pixel in region_1.coords:
#         region_1_distances_to_centroid.append(
#             np.linalg.norm(np.array(pixel) - np.array(region_1_centroid))
#         )
#     region_0_distances_to_centroid = np.array(region_0_distances_to_centroid)
#     region_1_distances_to_centroid = np.array(region_1_distances_to_centroid)
#     region_0_min_distance_to_centroid = np.min(region_0_distances_to_centroid)
#     region_1_min_distance_to_centroid = np.min(region_1_distances_to_centroid)
#     region_0_min_distance_to_centroid_index = np.argmin(
#         region_0_distances_to_centroid
#     )
#     region_1_min_distance_to_centroid_index = np.argmin(
#         region_1_distances_to_centroid
#     )
#     region_0_min_distance_to_centroid_pixel = region_0.coords[
#         region_0_min_distance_to_centroid_index
#     ]
#     region_1_min_distance_to_centroid_pixel = region_1.coords[
#         region_1_min_distance_to_centroid_index
#     ]
#     region_0_min_distance_to_centroid_pixel = np.array(
#         region_0_min_distance_to_centroid_pixel
#     )
#     region_1_min_distance_to_centroid_pixel = np.array(
#         region_1_min_distance_to_centroid_pixel
#     )

#     return (region_0_min_distance_to_centroid_pixel, region_1_min_distance_to_centroid_pixel)


# Keep only masks whose ring meets the gem at exactly two places. The rest are
# collected and plotted for inspection.
images_2_connection_points = []
masks_2_connection_points = []
p_2_nm_2_connection_points = []
images_not_2_connection_points = []
masks_not_2_connection_points = []
all_intersection_labels = []
for image, mask, p_2_nm in zip(images_largest_ring, masks_largest_ring, p_2_nm_largest_ring):
    num_connection_points, intersection_labels = get_number_of_connection_points(
        image=image,
        combined_mask=mask,
    )
    if num_connection_points == 2:
        images_2_connection_points.append(image)
        masks_2_connection_points.append(mask)
        all_intersection_labels.append(intersection_labels)
        p_2_nm_2_connection_points.append(p_2_nm)
    else:
        images_not_2_connection_points.append(image)
        masks_not_2_connection_points.append(mask)

print(f"num failed: {len(images_not_2_connection_points)}")
print(f"num passed: {len(images_2_connection_points)}")

plot_images(
    images=images_not_2_connection_points,
    masks=masks_not_2_connection_points,
)

# plot_images(
#     images=images_2_connection_points,
#     masks=masks_2_connection_points,
# )


def plot_images_with_overlays(images: list, masks: list, overlays: list, width=5, figsize=(20, 30)):
    """Plot image/mask pairs in a grid, drawing ``overlay > 0`` semi-transparently over each mask.

    Parameters
    ----------
    images, masks, overlays : list
        Parallel lists; ``zip`` stops at the shortest one.
    width : int
        Number of image/mask pairs per row, so the grid has ``2 * width`` columns.
    figsize : tuple
        Matplotlib figure size in inches.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots rejects a grid with zero rows; there is nothing to draw.
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps ``ax`` two-dimensional even for a single row.
    fig, ax = plt.subplots(num_rows, width * 2, figsize=figsize, squeeze=False)
    for axis in ax.flat:
        axis.axis("off")
    for i, (image, mask, overlay) in enumerate(zip(images, masks, overlays)):
        row, col = divmod(i, width)
        ax[row, 2 * col].imshow(image, cmap="viridis")
        ax[row, 2 * col + 1].imshow(mask, cmap="viridis")
        ax[row, 2 * col + 1].imshow(overlay > 0, alpha=0.5, cmap="viridis")
    fig.tight_layout()


plot_images_with_overlays(
    images=images_2_connection_points,
    masks=masks_2_connection_points,
    overlays=all_intersection_labels,
)

# image = images_largest_ring[0]
# mask = masks_largest_ring[0]

# num_connection_points, intersection_labels = get_number_of_connection_points(
#     image=image,
#     combined_mask=mask,
# )

# connecting_midpoints = find_middle_of_connection_points(
#     image=image,
#     combined_mask=mask,
#     intersection_labels=intersection_labels,
# )

# plt.imshow(image, cmap="viridis")
# plt.imshow(mask, alpha=0.5, cmap="viridis")
# plt.scatter(connecting_midpoints[0][1], connecting_midpoints[0][0], c="r", s=5)
# plt.scatter(connecting_midpoints[1][1], connecting_midpoints[1][0], c="b", s=5)
# plt.show()

In [None]:
# # Get distance transform for the masks
# index = 3
# mask = masks_2_connection_points[index]
# intersection_labels = all_intersection_labels[index]
# plt.imshow(mask)
# plt.imshow(intersection_labels > 0, alpha=0.5)
# plt.show()

def _deepest_pixel(region, distance_map):
    """Return (row, col) of the pixel in ``region`` with the largest ``distance_map`` value (first on ties)."""
    coords = region.coords
    values = distance_map[coords[:, 0], coords[:, 1]]
    row, col = coords[int(np.argmax(values))]
    return (row, col)


# For each mask, trace a path through the ring between its two connection points.
# The cost map is the inverted distance transform of the ring, so cheap pixels lie
# along the middle of the ring. Background and gem pixels get a cost of 1000.
paths = []

for image, mask, intersection_labels in zip(
    images_2_connection_points, masks_2_connection_points, all_intersection_labels
):
    # Distance of every ring/gem pixel to the background, with gem pixels zeroed
    # so that only the ring carries distance values.
    distance_transform = distance_transform_edt(mask > 0)
    distance_transform[mask == 2] = 0

    # Start and end at the pixel of each connection region that lies deepest in
    # the ring. ``intersection_labels`` is already a labelled image, so it is
    # passed to regionprops directly.
    region_0, region_1 = regionprops(intersection_labels)[:2]
    start_point = _deepest_pixel(region_0, distance_transform)
    end_point = _deepest_pixel(region_1, distance_transform)

    inverted_distance_transform = np.max(distance_transform) - distance_transform
    # The maximum of the inverted map corresponds to zero distance (background and
    # gem). Make those pixels very expensive so the route stays on the ring.
    inverted_distance_transform[inverted_distance_transform == np.max(inverted_distance_transform)] = 1000

    # Cheapest route between the two points, weighted by the inverted distance transform.
    route, _route_cost = route_through_array(inverted_distance_transform, start_point, end_point)
    paths.append(np.array(route))


def plot_images_with_paths(images: list, masks: list, paths: list, width=5, figsize=(20, 30)):
    """Plot image/mask pairs in a grid, drawing each path as a red line over its mask.

    Parameters
    ----------
    images, masks : list
        Parallel lists of images and masks.
    paths : list
        Arrays of (row, col) coordinates, one per image; ``zip`` stops at the shortest list.
    width : int
        Number of image/mask pairs per row, so the grid has ``2 * width`` columns.
    figsize : tuple
        Matplotlib figure size in inches.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots rejects a grid with zero rows; there is nothing to draw.
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps ``ax`` two-dimensional even for a single row.
    fig, ax = plt.subplots(num_rows, width * 2, figsize=figsize, squeeze=False)
    for axis in ax.flat:
        axis.axis("off")
    for i, (image, mask, path) in enumerate(zip(images, masks, paths)):
        row, col = divmod(i, width)
        ax[row, 2 * col].imshow(image, cmap="viridis")
        ax[row, 2 * col + 1].imshow(mask, cmap="viridis")
        # plot() takes x (column) first, then y (row).
        ax[row, 2 * col + 1].plot(path[:, 1], path[:, 0], "r-")
    fig.tight_layout()


plot_images_with_paths(
    images=images_2_connection_points,
    masks=masks_2_connection_points,
    paths=paths,
)

In [None]:
# Calculate start and end angles for the paths


def calculate_start_and_end_angles(path, num_averaging_points=5):
    """Estimate the direction in which a path leaves each of its two endpoints.

    At the start, the direction is the mean of the vectors from ``path[0]`` to
    ``path[1:num_averaging_points]``. At the end, it is the mean of the vectors
    from ``path[-1]`` to ``path[-num_averaging_points:-1]``. Both vectors point
    *into* the path. Each angle is ``arctan2(row, col)`` of its mean vector.

    Parameters
    ----------
    path : array-like
        (N, 2) sequence of (row, col) coordinates with N >= 2.
    num_averaging_points : int
        Window size at each end, counting the endpoint itself, so
        ``num_averaging_points - 1`` neighbours are averaged (default 5). Must be >= 2.

    Returns
    -------
    tuple
        ``(angle_difference, start_angle, end_angle, start_mean_vector,
        end_mean_vector, start_origin, end_origin)``. ``angle_difference`` is the
        absolute difference of the two angles folded into [0, 180] degrees.
        ``start_angle`` and ``end_angle`` are in radians.

    Raises
    ------
    ValueError
        If ``path`` has fewer than 2 points or ``num_averaging_points`` < 2. In
        either case a mean vector would be empty and the result NaN.
    """
    path = np.asarray(path)
    if len(path) < 2:
        raise ValueError(f"path must contain at least 2 points, got {len(path)}")
    if num_averaging_points < 2:
        raise ValueError(f"num_averaging_points must be >= 2, got {num_averaging_points}")

    start_origin = path[0]
    end_origin = path[-1]

    start_points = path[1:num_averaging_points]
    end_points = path[-num_averaging_points:-1]

    # Mean vector from each endpoint towards its neighbours along the path.
    start_mean_vector = np.mean(start_points - start_origin, axis=0)
    end_mean_vector = np.mean(end_points - end_origin, axis=0)

    # Coordinates are (row, col), so arctan2 takes the row component first.
    start_angle = np.arctan2(start_mean_vector[0], start_mean_vector[1])
    end_angle = np.arctan2(end_mean_vector[0], end_mean_vector[1])

    # Fold the raw difference (0..2*pi) onto the smaller angle between the directions.
    angle_difference = np.abs(start_angle - end_angle)
    if angle_difference > np.pi:
        angle_difference = 2 * np.pi - angle_difference

    # convert to degrees
    angle_difference = angle_difference * 180 / np.pi

    return (
        angle_difference,
        start_angle,
        end_angle,
        start_mean_vector,
        end_mean_vector,
        start_origin,
        end_origin,
    )


# index = 22
# path = paths[index]
# image = images_2_connection_points[index]
# mask = masks_2_connection_points[index]

# (
#     angle_difference,
#     start_angle,
#     end_angle,
#     start_mean_vector,
#     end_mean_vector,
#     start_origin,
#     end_origin,
# ) = calculate_start_and_end_angles(path)

# # Plot the start and end points and vectors
# plt.imshow(image)
# plt.imshow(mask, alpha=0.5)
# plt.scatter(start_origin[1], start_origin[0], c="r", s=5)
# plt.scatter(end_origin[1], end_origin[0], c="b", s=5)
# # plot path
# plt.plot(path[:, 1], path[:, 0], "r-")
# plt.quiver(
#     start_origin[1],
#     start_origin[0],
#     start_mean_vector[1],
#     -start_mean_vector[0],
#     color="r",
#     scale=10,
# )
# plt.quiver(
#     end_origin[1],
#     end_origin[0],
#     end_mean_vector[1],
#     -end_mean_vector[0],
#     color="b",
#     scale=10,
# )
# plt.show()

# print(f"angle difference: {angle_difference}")


def plot_images_with_paths_and_angles(
    images: list,
    masks: list,
    paths: list,
    angles: list,
    start_angles: list,
    end_angles: list,
    start_origins: list,
    end_origins: list,
    width=5,
    figsize=(20, 30),
):
    """Plot image/mask pairs with each path, its end-direction arrows and its angle.

    Parameters
    ----------
    images, masks, paths : list
        Parallel lists; each path is an array of (row, col) coordinates.
    angles : list
        Angle difference in degrees for each path, shown in the panel title.
    start_angles, end_angles : list
        Direction at each end in radians, as ``arctan2(row, col)``.
    start_origins, end_origins : list
        (row, col) endpoints the arrows are drawn from.
    width : int
        Number of image/mask pairs per row, so the grid has ``2 * width`` columns.
    figsize : tuple
        Matplotlib figure size in inches.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots rejects a grid with zero rows; there is nothing to draw.
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps ``ax`` two-dimensional even for a single row.
    fig, ax = plt.subplots(num_rows, width * 2, figsize=figsize, squeeze=False)
    for axis in ax.flat:
        axis.axis("off")
    for i, (
        image,
        mask,
        path,
        angle,
        start_angle,
        end_angle,
        start_origin,
        end_origin,
    ) in enumerate(zip(images, masks, paths, angles, start_angles, end_angles, start_origins, end_origins)):
        row, col = divmod(i, width)
        ax[row, 2 * col].imshow(image, cmap="viridis")
        mask_ax = ax[row, 2 * col + 1]
        mask_ax.imshow(mask, cmap="viridis")
        mask_ax.plot(path[:, 1], path[:, 0], "r-")
        # Angles are arctan2(row, col). The vertical component is negated,
        # presumably because the image row axis points downwards on screen.
        mask_ax.quiver(
            start_origin[1],
            start_origin[0],
            np.cos(start_angle),
            -np.sin(start_angle),
            color="r",
            scale=3,
        )
        mask_ax.quiver(
            end_origin[1],
            end_origin[0],
            np.cos(end_angle),
            -np.sin(end_angle),
            color="b",
            scale=3,
        )
        mask_ax.set_title(f"angle: {angle:.2f} degrees")
    fig.tight_layout()


# Measure the end-to-end direction change of every traced path.
angle_results = [
    calculate_start_and_end_angles(path)
    for _image, _mask, path in zip(images_2_connection_points, masks_2_connection_points, paths)
]
# Result tuple layout: (difference, start_angle, end_angle, start_vec, end_vec, start_origin, end_origin)
angle_differences = [result[0] for result in angle_results]
start_angles = [result[1] for result in angle_results]
end_angles = [result[2] for result in angle_results]
start_origins = [result[5] for result in angle_results]
end_origins = [result[6] for result in angle_results]

plot_images_with_paths_and_angles(
    images=images_2_connection_points,
    masks=masks_2_connection_points,
    paths=paths,
    angles=angle_differences,
    start_angles=start_angles,
    end_angles=end_angles,
    start_origins=start_origins,
    end_origins=end_origins,
)

In [None]:
# Distribution of the angle differences as a histogram and as a KDE. The raw
# values are then saved to IMAGE_SAVE_DIR.
hist_fig, hist_ax = plt.subplots()
sns.histplot(angle_differences, bins="auto", ax=hist_ax)
plt.show()

kde_fig, kde_ax = plt.subplots()
sns.kdeplot(angle_differences, ax=kde_ax)
plt.show()

np.save(IMAGE_SAVE_DIR / "angle_differences_high_res_only.npy", angle_differences)

In [None]:
def calculate_edges(grain_mask: np.ndarray):
    """Return the coordinates of the edge pixels of a grain mask.

    Holes in the mask are filled first. The edge is the set of pixels of the
    filled mask that disappear after a one-pixel binary erosion.

    Parameters
    ----------
    grain_mask : np.ndarray
        A 2D boolean numpy array image of a grain.

    Returns
    -------
    edges : list
        List of ``[row, col]`` coordinate lists, one per edge pixel, in row-major
        order. A plain list is returned because callers such as
        ``get_max_min_ferets`` sort it.
    """
    # Fill any holes
    filled_grain_mask = binary_fill_holes(grain_mask)

    # Pad by one pixel before eroding so the mask's outer boundary is treated as
    # background, then strip the padding again.
    padded = np.pad(filled_grain_mask, 1)
    eroded = binary_erosion(padded)[1:-1, 1:-1]

    # Edge pixels: present in the filled mask but removed by the erosion.
    edges = filled_grain_mask & ~eroded

    # np.argwhere gives the same (row, col) pairs, in the same order, as
    # np.transpose(edges.nonzero()).
    return [list(coordinate) for coordinate in np.argwhere(edges)]


def is_clockwise(p_1: tuple, p_2: tuple, p_3: tuple) -> bool:
    """Function to determine if three points make a clockwise or counter-clockwise turn.

    Parameters
    ----------
    p_1: tuple
        First point to be used to calculate turn.
    p_2: tuple
        Second point to be used to calculate turn.
    p_3: tuple
        Third point to be used to calculate turn.

    Returns
    -------
    boolean
        True if the turn is clockwise or the points are collinear; False if it is
        counter-clockwise.

    Notes
    -----
    The orientation is the sign of the 2D cross product (p_2 - p_1) x (p_3 - p_1).
    This equals the determinant of [[x1, y1, 1], [x2, y2, 1], [x3, y3, 1]]. A
    positive value means a counter-clockwise turn. Computing the cross product
    directly keeps integer pixel coordinates exact. ``np.linalg.det`` uses a
    floating-point LU factorisation, and its round-off can give collinear points
    a tiny non-zero determinant of either sign.
    """
    cross = (p_2[0] - p_1[0]) * (p_3[1] - p_1[1]) - (p_2[1] - p_1[1]) * (p_3[0] - p_1[0])
    return not cross > 0


def get_triangle_height(base_point_1: np.ndarray, base_point_2: np.ndarray, top_point: np.ndarray) -> float:
    """Returns the height of a triangle defined by the input point vectors.

    Parameters
    ----------
    base_point_1: np.ndarray
        a base point of the triangle, eg: [5, 3].

    base_point_2: np.ndarray
        a base point of the triangle, eg: [8, 3]. Must differ from ``base_point_1``.

    top_point: np.ndarray
        the top point of the triangle, defining the height from the line between the two base points, eg: [6,10].

    Returns
    -------
    Float
        The height of the triangle - ie the shortest distance between the top point and the line between the two
        base points.
    """
    # Height of triangle = A/b = ||AB X AC|| / ||AB||
    a_b = np.asarray(base_point_1, dtype=float) - np.asarray(base_point_2, dtype=float)
    a_c = np.asarray(base_point_1, dtype=float) - np.asarray(top_point, dtype=float)
    # z-component of the 2D cross product, written out explicitly: np.cross on
    # 2-element vectors is deprecated since NumPy 2.0.
    cross_z = a_b[0] * a_c[1] - a_b[1] * a_c[0]
    return abs(cross_z) / np.linalg.norm(a_b)


def get_max_min_ferets(edge_points: list):
    """Returns the minimum and maximum feret diameters for a grain.
    These are defined as the smallest and greatest distances between
    a pair of callipers that are rotating around a 2d object, maintaining
    contact at all times.

    Parameters
    ----------
    edge_points: list
        a list of the vector positions of the pixels comprising the edge of the
        grain. Eg: [[0, 0], [1, 0], [2, 1]]. The list itself is not modified.
    Returns
    -------
    min_feret: float
        the minimum feret diameter of the grain, or ``None`` if ``edge_points``
        has fewer than two points
    max_feret: float
        the maximum feret diameter of the grain, or ``None`` if ``edge_points``
        has fewer than two points

    Notes
    -----
    The method starts out by calculating the upper and lower convex hulls using
    an algorithm based on the Graham Scan Algorithm [1]. Using these upper and
    lower hulls, the callipers are simulated as rotating clockwise around the grain.
    We determine the order in which vertices are encountered by comparing the
    gradients of the slopes between vertices. An array of pairs of points that
    are in contact with either calliper at a given time is created in order to
    be able to calculate the maximum feret diameter. The minimum diameter is a
    little tricky, since it won't simply be the shortest distance between two
    contact points, but it will occur somewhere during the rotation around a
    pair of contact points. It turns out that the point will always be such
    that two points are in contact with one calliper while the other calliper
    is in contact with another point. We can use this fact to be sure of finding
    the smallest feret diameter, simply by testing each triangle of 3 contact points
    as we iterate, finding the height of the triangle that is formed between the
    three aforementioned points, as this will be the perpendicular distance between
    the callipers.

    References
    ----------
    [1] Graham, R.L. (1972).
        "An Efficient Algorithm for Determining the Convex Hull of a Finite Planar Set".
        Information Processing Letters. 1 (4): 132-133.
        doi:10.1016/0020-0190(72)90045-2.
    """

    # Sort the points by their first coordinate and then by their second. ``sorted`` returns a new list, so
    # (unlike ``list.sort``) the caller's list is not reordered as a side effect.
    edge_points = np.array(sorted(edge_points))

    # Construct upper and lower hulls for the sorted edge points. The two hulls are kept separate because the
    # calliper walk below runs over each of them.
    upper_hull = []
    lower_hull = []
    for point in edge_points:
        while len(lower_hull) > 1 and is_clockwise(lower_hull[-2], lower_hull[-1], point):
            lower_hull.pop()
        lower_hull.append(point)
        while len(upper_hull) > 1 and not is_clockwise(upper_hull[-2], upper_hull[-1], point):
            upper_hull.pop()
        upper_hull.append(point)

    upper_hull = np.array(upper_hull)
    lower_hull = np.array(lower_hull)

    # Create list of contact vertices for calipers on the antipodal hulls
    contact_points = []
    upper_index = 0
    lower_index = len(lower_hull) - 1
    min_feret = None
    while upper_index < len(upper_hull) - 1 or lower_index > 0:
        contact_points.append([lower_hull[lower_index, :], upper_hull[upper_index, :]])
        # If we have reached the end of the upper hull, continue iterating over the lower hull
        if upper_index == len(upper_hull) - 1:
            lower_index -= 1
            small_feret = get_triangle_height(
                np.array(lower_hull[lower_index + 1, :]),
                np.array(lower_hull[lower_index, :]),
                np.array(upper_hull[upper_index, :]),
            )
            if min_feret is None or small_feret < min_feret:
                min_feret = small_feret
        # If we have reached the end of the lower hull, continue iterating over the upper hull
        elif lower_index == 0:
            upper_index += 1
            small_feret = get_triangle_height(
                np.array(upper_hull[upper_index - 1, :]),
                np.array(upper_hull[upper_index, :]),
                np.array(lower_hull[lower_index, :]),
            )
            if min_feret is None or small_feret < min_feret:
                min_feret = small_feret
        # Check if the gradient of the last point and the proposed next point in the upper hull is greater than the
        # gradient of the two corresponding points in the lower hull. If so, the next point in the upper hull will be
        # encountered before the next point in the lower hull, and vice versa.
        # The comparison is delta upper_y / delta upper_x > delta lower_y / delta lower_x with the denominators
        # multiplied through, so there is no division by zero; the inequality still holds as needed.
        elif (upper_hull[upper_index + 1, 1] - upper_hull[upper_index, 1]) * (
            lower_hull[lower_index, 0] - lower_hull[lower_index - 1, 0]
        ) > (lower_hull[lower_index, 1] - lower_hull[lower_index - 1, 1]) * (
            upper_hull[upper_index + 1, 0] - upper_hull[upper_index, 0]
        ):
            # If the upper hull is encountered first, increment the iteration index for the upper hull.
            # Also consider the triangle that is made as the two upper hull vertices are collinear with the caliper.
            upper_index += 1
            small_feret = get_triangle_height(
                np.array(upper_hull[upper_index - 1, :]),
                np.array(upper_hull[upper_index, :]),
                np.array(lower_hull[lower_index, :]),
            )
            if min_feret is None or small_feret < min_feret:
                min_feret = small_feret
        else:
            # The next point in the lower hull will be encountered first, so increment the lower hull iteration index.
            lower_index -= 1
            small_feret = get_triangle_height(
                np.array(lower_hull[lower_index + 1, :]),
                np.array(lower_hull[lower_index, :]),
                np.array(upper_hull[upper_index, :]),
            )

            if min_feret is None or small_feret < min_feret:
                min_feret = small_feret

    contact_points = np.array(contact_points)

    # The maximum feret is the largest distance between any pair of contact points
    max_feret = None
    for point_pair in contact_points:
        dist = np.sqrt((point_pair[0, 0] - point_pair[1, 0]) ** 2 + (point_pair[0, 1] - point_pair[1, 1]) ** 2)
        if max_feret is None or max_feret < dist:
            max_feret = dist

    return min_feret, max_feret

In [None]:
# Min and max Feret diameters of every ring mask, each multiplied by that image's
# p_2_nm scale (pixel -> nm, going by the name), then plotted and saved.
min_ferets = []
max_ferets = []

for _image, mask, p_2_nm in zip(images_2_connection_points, masks_2_connection_points, p_2_nm_2_connection_points):
    ring_mask = mask == 1
    edge_points = calculate_edges(ring_mask)
    min_feret, max_feret = get_max_min_ferets(edge_points)
    min_ferets.append(min_feret * p_2_nm)
    max_ferets.append(max_feret * p_2_nm)

feret_fig, feret_ax = plt.subplots()
sns.histplot(min_ferets, bins="auto", label="min feret", ax=feret_ax)
sns.histplot(max_ferets, bins="auto", label="max feret", ax=feret_ax)
feret_ax.legend()


np.save(IMAGE_SAVE_DIR / "min_ferets_high_res_only.npy", min_ferets)
np.save(IMAGE_SAVE_DIR / "max_ferets_high_res_only.npy", max_ferets)