In [None]:
from pathlib import Path
from typing import List, Tuple, Union, Optional, Dict
import os
import re

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from skimage.morphology import binary_dilation, label
from skimage.measure import regionprops
from skimage.color import label2rgb
from skimage.graph import route_through_array
from sklearn.cluster import KMeans
from scipy.ndimage import distance_transform_edt
import seaborn as sns
from skimage.morphology import binary_erosion
from scipy.ndimage import binary_fill_holes
from skimage.feature import canny
from scipy.interpolate import splprep, splev

from topostats.grain_finding_haribo_unet import (
    predict_unet,
    load_model,
    predict_unet_multiclass_and_get_angle,
    mean_iou,
    predict_unet_multiclass,
)

from topostats.haribonet_process_grain_bound import process_grain

In [None]:
# Load the trained multiclass U-Net; the custom mean_iou metric is handed to the loader via custom_objects
model_path = "./haribonet_multiclass_improved_norm_big_95_bridging_v1_2024-01-17_10-58-46.h5"
print(f"Loading Unet model: {model_path}")
unet_custom_objects = {"mean_iou": mean_iou}
model = load_model(
    model_path=model_path,
    custom_objects=unet_custom_objects,
)
print(f"Loaded Unet model: {model_path}")

In [None]:
# Get ring images.
# DATA_DIR defaults to the original local path but can be overridden with the
# HARIBO_DATA_DIR environment variable so the notebook can run on other machines.
DATA_DIR = Path(os.environ.get("HARIBO_DATA_DIR", "/Users/sylvi/topo_data/hariborings/cas9_crops_p2nm/OT2_SC_p2nm"))
IMAGE_SAVE_DIR = DATA_DIR / "output_plots"
IMAGE_SAVE_DIR.mkdir(exist_ok=True)

image_files = sorted(DATA_DIR.glob("*.npy"))
print(f"num images: {len(image_files)}")

In [None]:
# Inspect the first file: its pixel-to-nm scale is the first decimal number in the filename
img_file = image_files[0]
print(f"Processing image: {img_file}")
decimal_numbers_in_name = re.findall(r"\d+\.\d+", img_file.name)
p_to_nm = float(decimal_numbers_in_name[0])
print(f"p_to_nm: {p_to_nm}")

In [None]:
# Run the U-Net on every image whose pixel-to-nm scale (parsed from the filename) is below 10.0.
# Results are keyed by the file's position in image_files.
image_dict = {}

for index, image_file in enumerate(image_files):
    print(f"Processing {image_file}")
    image = np.load(image_file)
    p_to_nm = float(re.findall(r"\d+\.\d+", image_file.name)[0])

    # Skip coarse images
    if p_to_nm >= 10.0:
        continue

    predicted_mask = predict_unet_multiclass(
        model=model,
        image=image,
        confidence=0.5,
        model_image_size=256,
        image_output_dir=IMAGE_SAVE_DIR,
        IMAGE_SAVE_DIR=IMAGE_SAVE_DIR,
        filename="test",
        image_index=index,
    )
    image_dict[index] = {"image": image, "predicted_mask": predicted_mask, "p_to_nm": p_to_nm}

print(f"len images: {len(image_dict)}")

In [None]:
# Plot gallery of images and predicted masks
def plot_images(images: list, masks: list, px_to_nms: list, width=5):
    """Plot a gallery of images, each followed by its predicted mask.

    Each image takes two adjacent columns: the image (titled with its pixel-to-nm
    scale) and then its mask. Nothing is plotted when ``images`` is empty.

    Parameters
    ----------
    images : list
        2D arrays displayed with ``imshow``.
    masks : list
        Masks paired index-wise with ``images``.
    px_to_nms : list
        Pixel-to-nm scale for each image, used in the image-panel title.
    width : int
        Number of image/mask pairs per row.
    """
    num_images = len(images)
    if num_images == 0:
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps ``ax`` 2D even when there is only one row, so the
    # ax[row, col] indexing below works for any number of images.
    fig, ax = plt.subplots(num_rows, width * 2, figsize=(20, 30), squeeze=False)
    # Hide every panel up front so unused slots in the last row stay blank
    for axis in ax.flat:
        axis.axis("off")
    for i, (image, mask) in enumerate(zip(images, masks)):
        row, col = divmod(i, width)
        image_ax, mask_ax = ax[row, 2 * col], ax[row, 2 * col + 1]
        image_ax.imshow(image, cmap="viridis")
        image_ax.set_title(f"p_to_nm: {px_to_nms[i]}")
        mask_ax.imshow(mask, cmap="viridis")
    fig.tight_layout()


plot_images(
    images=[image_dict[i]["image"] for i in image_dict],
    masks=[image_dict[i]["predicted_mask"] for i in image_dict],
    px_to_nms=[image_dict[i]["p_to_nm"] for i in image_dict],
)

In [None]:
# Check there is enough ring and gem pixels


def check_ring_and_mask_exists(image, combined_predicted_mask, min_ring_pixels=40, min_gem_pixels=40):
    """Validate that a predicted mask contains enough ring and gem pixels.

    Parameters
    ----------
    image : np.ndarray
        Unused; accepted so the signature matches the other mask-processing helpers.
    combined_predicted_mask : np.ndarray
        Class mask in which 1 marks ring pixels and 2 marks gem pixels.
    min_ring_pixels : int
        Minimum number of ring (class 1) pixels required.
    min_gem_pixels : int
        Minimum number of gem (class 2) pixels required.

    Returns
    -------
    np.ndarray
        ``combined_predicted_mask`` itself, unchanged.

    Raises
    ------
    ValueError
        If the ring or the gem has fewer pixels than its minimum.
    """
    if np.count_nonzero(combined_predicted_mask == 1) < min_ring_pixels:
        raise ValueError(f"Ring pixels < {min_ring_pixels}")
    if np.count_nonzero(combined_predicted_mask == 2) < min_gem_pixels:
        raise ValueError(f"Gem pixels < {min_gem_pixels}")

    return combined_predicted_mask


removed_images_large_enough_ring_gem = []
removed_masks_large_enough_ring_gem = []

image_dict_large_enough_ring_gem = {}

for index, image_dict_item in image_dict.items():
    image = image_dict_item["image"]
    mask = image_dict_item["predicted_mask"]
    p_2_nm = image_dict_item["p_to_nm"]

    # Check if there is sufficient ring and gem pixels
    try:
        mask = check_ring_and_mask_exists(image, mask)
    except ValueError as e:
        print(f"Skipping image {index}: {e}")
        removed_images_large_enough_ring_gem.append(image)
        removed_masks_large_enough_ring_gem.append(mask)
        continue

    # Store a new dict rather than the image_dict entry itself: otherwise later stages
    # that overwrite "predicted_mask" would also rewrite image_dict, and re-running
    # cells would operate on already-processed masks.
    image_dict_large_enough_ring_gem[index] = {**image_dict_item, "predicted_mask": mask}

# Show each rejected image with its mask overlaid
for removed_image, removed_mask in zip(removed_images_large_enough_ring_gem, removed_masks_large_enough_ring_gem):
    plt.imshow(removed_image, cmap="viridis")
    plt.imshow(removed_mask, alpha=0.5, cmap="viridis")
    plt.show()

In [None]:
def turn_small_gem_regions_into_ring(image, combined_predicted_mask, small_gem_dilation_strength=5):
    """Relabel small gem fragments that touch the ring as ring.

    The largest connected gem (class 2) region is treated as the real gem. Every
    other gem region is dilated ``small_gem_dilation_strength`` times. If the
    dilated footprint overlaps ring (class 1) pixels, that footprint is relabelled
    as ring. Pixels of the largest gem region are never overwritten, even when a
    nearby fragment's dilation reaches them.

    Parameters
    ----------
    image : np.ndarray
        Unused; accepted so the signature matches the other mask-processing helpers.
    combined_predicted_mask : np.ndarray
        Class mask (1 = ring, 2 = gem). Modified in place.
    small_gem_dilation_strength : int
        Number of single-pixel binary dilations applied to each small gem region.

    Returns
    -------
    np.ndarray
        The updated mask (the same object that was passed in).
    """
    gem_labels = label(combined_predicted_mask == 2)
    gem_regions = regionprops(gem_labels)
    if not gem_regions:
        # No gem at all: nothing to relabel
        return combined_predicted_mask

    gem_areas = [region.area for region in gem_regions]
    print(f"gem_areas: {gem_areas}")
    largest_gem_label = gem_regions[np.argmax(gem_areas)].label
    largest_gem_mask = gem_labels == largest_gem_label

    for region in gem_regions:
        if region.label == largest_gem_label:
            continue
        dilated_region_mask = gem_labels == region.label
        for _ in range(small_gem_dilation_strength):
            dilated_region_mask = binary_dilation(dilated_region_mask)
        # The ring is re-read each iteration because earlier fragments may already have been merged into it
        touches_ring = np.any(dilated_region_mask & (combined_predicted_mask == 1))
        if touches_ring:
            # Relabel the dilated footprint as ring, but leave the main gem intact
            combined_predicted_mask[dilated_region_mask & ~largest_gem_mask] = 1

    return combined_predicted_mask


image_dict_small_gem_regions_into_ring = {}

for index, image_dict_item in image_dict_large_enough_ring_gem.items():
    image = image_dict_item["image"]
    mask = image_dict_item["predicted_mask"]
    p_2_nm = image_dict_item["p_to_nm"]

    print(f"processing image {index}")
    # The helper edits its mask argument in place, so give it a copy
    combined_predicted_mask = turn_small_gem_regions_into_ring(
        image=image,
        combined_predicted_mask=mask.copy(),
    )

    # New dict per image so the previous stage's entry keeps its own mask
    image_dict_small_gem_regions_into_ring[index] = {**image_dict_item, "predicted_mask": combined_predicted_mask}

plot_images(
    images=[image_dict_small_gem_regions_into_ring[i]["image"] for i in image_dict_small_gem_regions_into_ring],
    masks=[image_dict_small_gem_regions_into_ring[i]["predicted_mask"] for i in image_dict_small_gem_regions_into_ring],
    px_to_nms=[image_dict_small_gem_regions_into_ring[i]["p_to_nm"] for i in image_dict_small_gem_regions_into_ring],
)

In [None]:
# Remove all but largest ring region. Connectivity should not include diagonals

image_dict_largest_ring = {}


def remove_all_but_largest_ring_region(image, combined_predicted_mask):
    """Set every ring (class 1) component except the largest one to 0.

    Components are labelled with ``connectivity=1``, so diagonally touching ring
    pixels belong to separate components. On an area tie the component with the
    lowest label is kept. ``image`` is unused. The mask is modified in place and
    returned.
    """
    ring_labels = label(combined_predicted_mask == 1, connectivity=1)
    keep_label = max(regionprops(ring_labels), key=lambda region: region.area).label
    discard = (ring_labels != 0) & (ring_labels != keep_label)
    combined_predicted_mask[discard] = 0
    return combined_predicted_mask


for index, image_dict_item in image_dict_small_gem_regions_into_ring.items():
    image = image_dict_item["image"]
    mask = image_dict_item["predicted_mask"]
    p_2_nm = image_dict_item["p_to_nm"]

    # The helper edits its mask argument in place, so give it a copy
    combined_predicted_mask = remove_all_but_largest_ring_region(
        image=image,
        combined_predicted_mask=mask.copy(),
    )

    # New dict per image so the previous stage's entry keeps its own mask
    image_dict_largest_ring[index] = {**image_dict_item, "predicted_mask": combined_predicted_mask}

plot_images(
    images=[image_dict_largest_ring[i]["image"] for i in image_dict_largest_ring],
    masks=[image_dict_largest_ring[i]["predicted_mask"] for i in image_dict_largest_ring],
    px_to_nms=[image_dict_largest_ring[i]["p_to_nm"] for i in image_dict_largest_ring],
)

In [None]:
# function to determine the number of connection points between ring and gem


def get_number_of_connection_points(image, combined_mask, gem_dilation_strength=1):
    """Count the separate places where the gem touches the ring.

    The gem mask (class 2) is grown by ``gem_dilation_strength`` single-pixel
    binary dilations. The ring pixels (class 1) covered by the grown gem are the
    contact pixels, and each connected group of contact pixels is one connection
    point.

    Parameters
    ----------
    image : np.ndarray
        Unused; accepted so the signature matches the other mask-processing helpers.
    combined_mask : np.ndarray
        Class mask with 1 = ring and 2 = gem.
    gem_dilation_strength : int
        Number of dilations applied to the gem before intersecting with the ring.

    Returns
    -------
    num_connection_points : int
        Number of connected contact regions.
    intersection_labels : np.ndarray
        Label image of the contact pixels (0 where there is no contact).
    """
    ring_mask = combined_mask == 1
    # Grow the gem so it overlaps the ring pixels bordering it
    dilated_gem_mask = combined_mask == 2
    for _ in range(gem_dilation_strength):
        dilated_gem_mask = binary_dilation(dilated_gem_mask)
    # Ring pixels reached by the grown gem are the contact pixels
    contact_pixels = dilated_gem_mask & ring_mask

    intersection_labels = label(contact_pixels)
    num_connection_points = len(regionprops(intersection_labels))
    return num_connection_points, intersection_labels


# def find_middle_of_connection_points(image, combined_mask, intersection_labels):
#     # Find the middle of the connection points by finding the pixel with the shortest distance to the centroid
#     intersection_regions = regionprops(intersection_labels)
#     region_0 = intersection_regions[0]
#     region_1 = intersection_regions[1]
#     region_0_centroid = region_0.centroid
#     region_1_centroid = region_1.centroid
#     region_0_distances_to_centroid = []
#     region_1_distances_to_centroid = []
#     for pixel in region_0.coords:
#         region_0_distances_to_centroid.append(
#             np.linalg.norm(np.array(pixel) - np.array(region_0_centroid))
#         )
#     for pixel in region_1.coords:
#         region_1_distances_to_centroid.append(
#             np.linalg.norm(np.array(pixel) - np.array(region_1_centroid))
#         )
#     region_0_distances_to_centroid = np.array(region_0_distances_to_centroid)
#     region_1_distances_to_centroid = np.array(region_1_distances_to_centroid)
#     region_0_min_distance_to_centroid = np.min(region_0_distances_to_centroid)
#     region_1_min_distance_to_centroid = np.min(region_1_distances_to_centroid)
#     region_0_min_distance_to_centroid_index = np.argmin(
#         region_0_distances_to_centroid
#     )
#     region_1_min_distance_to_centroid_index = np.argmin(
#         region_1_distances_to_centroid
#     )
#     region_0_min_distance_to_centroid_pixel = region_0.coords[
#         region_0_min_distance_to_centroid_index
#     ]
#     region_1_min_distance_to_centroid_pixel = region_1.coords[
#         region_1_min_distance_to_centroid_index
#     ]
#     region_0_min_distance_to_centroid_pixel = np.array(
#         region_0_min_distance_to_centroid_pixel
#     )
#     region_1_min_distance_to_centroid_pixel = np.array(
#         region_1_min_distance_to_centroid_pixel
#     )

#     return (region_0_min_distance_to_centroid_pixel, region_1_min_distance_to_centroid_pixel)


# Find connecting regions for all images; keep only masks where the gem touches the ring in exactly two places
image_dict_2_connection_points = {}
images_not_2_connection_points = []
masks_not_2_connection_points = []
p_to_nms_not_2_connection_points = []
for index, image_dict_item in image_dict_largest_ring.items():
    image = image_dict_item["image"]
    mask = image_dict_item["predicted_mask"]
    p_2_nm = image_dict_item["p_to_nm"]

    num_connection_points, intersection_labels = get_number_of_connection_points(
        image=image,
        combined_mask=mask,
    )
    if num_connection_points != 2:
        images_not_2_connection_points.append(image)
        masks_not_2_connection_points.append(mask)
        # Track the rejected images' own scales so the gallery titles below are correct
        p_to_nms_not_2_connection_points.append(p_2_nm)
    else:
        # New dict so the intersection labels don't leak into earlier stages' entries
        image_dict_2_connection_points[index] = {**image_dict_item, "intersection_labels": intersection_labels}

print(f"num failed: {len(images_not_2_connection_points)}")
print(f"num passed: {len(image_dict_2_connection_points)}")

# Gallery of rejected masks (skipped when nothing was rejected)
if images_not_2_connection_points:
    plot_images(
        images=images_not_2_connection_points,
        masks=masks_not_2_connection_points,
        px_to_nms=p_to_nms_not_2_connection_points,
    )

# plot_images(
#     images=images_2_connection_points,
#     masks=masks_2_connection_points,
# )


def plot_images_with_overlays(images: list, masks: list, overlays: list, px_to_nms: list, width=5):
    """Plot images next to their masks, with a binary overlay drawn on each mask.

    Nothing is plotted when ``images`` is empty.

    Parameters
    ----------
    images : list
        2D arrays shown in the left panel of each pair.
    masks : list
        Masks shown in the right panel of each pair.
    overlays : list
        Arrays whose non-zero pixels are drawn semi-transparently over the mask.
    px_to_nms : list
        Pixel-to-nm scale per image, used in the left-panel title.
    width : int
        Number of image/mask pairs per row.
    """
    num_images = len(images)
    if num_images == 0:
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps ``ax`` 2D even for a single row of images
    fig, ax = plt.subplots(num_rows, width * 2, figsize=(20, 30), squeeze=False)
    for axis in ax.flat:
        axis.axis("off")
    for i, (image, mask, overlay) in enumerate(zip(images, masks, overlays)):
        row, col = divmod(i, width)
        image_ax, mask_ax = ax[row, 2 * col], ax[row, 2 * col + 1]
        image_ax.imshow(image, cmap="viridis")
        image_ax.set_title(f"p_to_nm: {px_to_nms[i]}")
        mask_ax.imshow(mask, cmap="viridis")
        mask_ax.imshow(overlay > 0, alpha=0.5, cmap="viridis")
    fig.tight_layout()


plot_images_with_overlays(
    images=[image_dict_2_connection_points[i]["image"] for i in image_dict_2_connection_points],
    masks=[image_dict_2_connection_points[i]["predicted_mask"] for i in image_dict_2_connection_points],
    overlays=[image_dict_2_connection_points[i]["intersection_labels"] for i in image_dict_2_connection_points],
    px_to_nms=[image_dict_2_connection_points[i]["p_to_nm"] for i in image_dict_2_connection_points],
)

# image = images_largest_ring[0]
# mask = masks_largest_ring[0]

# num_connection_points, intersection_labels = get_number_of_connection_points(
#     image=image,
#     combined_mask=mask,
# )

# connecting_midpoints = find_middle_of_connection_points(
#     image=image,
#     combined_mask=mask,
#     intersection_labels=intersection_labels,
# )

# plt.imshow(image, cmap="viridis")
# plt.imshow(mask, alpha=0.5, cmap="viridis")
# plt.scatter(connecting_midpoints[0][1], connecting_midpoints[0][0], c="r", s=5)
# plt.scatter(connecting_midpoints[1][1], connecting_midpoints[1][0], c="b", s=5)
# plt.show()

In [None]:
# # Get distance transform for the masks
# index = 3
# mask = masks_2_connection_points[index]
# intersection_labels = all_intersection_labels[index]
# plt.imshow(mask)
# plt.imshow(intersection_labels > 0, alpha=0.5)
# plt.show()

image_dict_paths = {}

for index, image_dict_item in image_dict_2_connection_points.items():
    image = image_dict_item["image"]
    mask = image_dict_item["predicted_mask"]
    p_2_nm = image_dict_item["p_to_nm"]
    intersection_labels = image_dict_item["intersection_labels"]

    # Distance from each ring/gem pixel to the nearest background (0) pixel, with the
    # gem pixels then zeroed; ring pixels far from the background get high values.
    distance_transform = distance_transform_edt(mask > 0)
    distance_transform[mask == 2] = 0

    # Path endpoints: in each of the two connection regions, the pixel with the
    # largest distance-transform value (first such pixel on ties).
    intersection_labels = label(intersection_labels)
    intersection_regions = regionprops(intersection_labels)
    endpoints = []
    for region in intersection_regions[:2]:
        region_coords = region.coords
        region_values = distance_transform[region_coords[:, 0], region_coords[:, 1]]
        best_pixel = region_coords[np.argmax(region_values)]
        endpoints.append((best_pixel[0], best_pixel[1]))
    start_point, end_point = endpoints

    # Invert so pixels far from the background are cheap to traverse, then give the most
    # expensive pixels (distance 0: background and gem) a cost of 1000 to steer the route away.
    inverted_distance_transform = np.max(distance_transform) - distance_transform
    inverted_distance_transform[inverted_distance_transform == np.max(inverted_distance_transform)] = 1000

    # Minimum-cost route between the two endpoints through the cost image
    route, weight = route_through_array(inverted_distance_transform, start_point, end_point)
    route = np.array(route)

    # New dict so the path doesn't leak into earlier stages' entries
    image_dict_paths[index] = {**image_dict_item, "path": route}


def plot_images_with_paths(images: list, masks: list, paths: list, px_to_nms: list, width=5):
    """Plot images next to their masks, with a traced path drawn over each mask.

    Nothing is plotted when ``images`` is empty.

    Parameters
    ----------
    images : list
        2D arrays shown in the left panel of each pair.
    masks : list
        Masks shown in the right panel of each pair.
    paths : list
        (N, 2) arrays of (row, col) coordinates drawn as a red line on the mask.
    px_to_nms : list
        Pixel-to-nm scale per image, used in the left-panel title.
    width : int
        Number of image/mask pairs per row.
    """
    num_images = len(images)
    if num_images == 0:
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps ``ax`` 2D even for a single row of images
    fig, ax = plt.subplots(num_rows, width * 2, figsize=(20, 30), squeeze=False)
    for axis in ax.flat:
        axis.axis("off")
    for i, (image, mask, path) in enumerate(zip(images, masks, paths)):
        row, col = divmod(i, width)
        image_ax, mask_ax = ax[row, 2 * col], ax[row, 2 * col + 1]
        image_ax.imshow(image, cmap="viridis")
        image_ax.set_title(f"p_to_nm: {px_to_nms[i]}")
        mask_ax.imshow(mask, cmap="viridis")
        mask_ax.plot(path[:, 1], path[:, 0], "r-")
    fig.tight_layout()


plot_images_with_paths(
    images=[image_dict_paths[i]["image"] for i in image_dict_paths],
    masks=[image_dict_paths[i]["predicted_mask"] for i in image_dict_paths],
    paths=[image_dict_paths[i]["path"] for i in image_dict_paths],
    px_to_nms=[image_dict_paths[i]["p_to_nm"] for i in image_dict_paths],
)

In [None]:
# Calculate start and end angles for the paths


def calculate_start_and_end_angles(path, num_averaging_points=5):
    """Estimate the direction in which the path leaves each of its two end points.

    At each end, the vectors from the end point to the neighbouring path points
    (``path[1:num_averaging_points]`` at the start and
    ``path[-num_averaging_points:-1]`` at the end) are averaged. Each mean vector's
    direction is reported as ``arctan2(row, col)`` in radians.

    Parameters
    ----------
    path : array-like
        (N, 2) sequence of (row, col) coordinates with N >= 2.
    num_averaging_points : int
        Size of the window at each end, including the end point itself (>= 2).

    Returns
    -------
    tuple
        ``(angle_difference, start_angle, end_angle, start_mean_vector,
        end_mean_vector, start_origin, end_origin)``. ``angle_difference`` is the
        absolute difference between the two angles, wrapped into [0, 180] degrees.
        The two individual angles are in radians.

    Raises
    ------
    ValueError
        If the path has fewer than two points or ``num_averaging_points < 2``, since no
        direction can be estimated (the mean of an empty set would be NaN).
    """
    path = np.asarray(path)
    if len(path) < 2:
        raise ValueError("path must contain at least two points")
    if num_averaging_points < 2:
        raise ValueError("num_averaging_points must be at least 2")

    start_origin = path[0]
    end_origin = path[-1]

    # Neighbouring points after the start / before the end (the origin itself is excluded)
    start_points = path[1:num_averaging_points]
    end_points = path[-num_averaging_points:-1]

    start_mean_vector = np.mean(start_points - start_origin, axis=0)
    end_mean_vector = np.mean(end_points - end_origin, axis=0)

    start_angle = np.arctan2(start_mean_vector[0], start_mean_vector[1])
    end_angle = np.arctan2(end_mean_vector[0], end_mean_vector[1])

    # Smallest angle between the two directions, in [0, pi]
    angle_difference = np.abs(start_angle - end_angle)
    if angle_difference > np.pi:
        angle_difference = 2 * np.pi - angle_difference
    angle_difference = np.degrees(angle_difference)

    return (
        angle_difference,
        start_angle,
        end_angle,
        start_mean_vector,
        end_mean_vector,
        start_origin,
        end_origin,
    )


# index = 22
# path = paths[index]
# image = images_2_connection_points[index]
# mask = masks_2_connection_points[index]

# (
#     angle_difference,
#     start_angle,
#     end_angle,
#     start_mean_vector,
#     end_mean_vector,
#     start_origin,
#     end_origin,
# ) = calculate_start_and_end_angles(path)

# # Plot the start and end points and vectors
# plt.imshow(image)
# plt.imshow(mask, alpha=0.5)
# plt.scatter(start_origin[1], start_origin[0], c="r", s=5)
# plt.scatter(end_origin[1], end_origin[0], c="b", s=5)
# # plot path
# plt.plot(path[:, 1], path[:, 0], "r-")
# plt.quiver(
#     start_origin[1],
#     start_origin[0],
#     start_mean_vector[1],
#     -start_mean_vector[0],
#     color="r",
#     scale=10,
# )
# plt.quiver(
#     end_origin[1],
#     end_origin[0],
#     end_mean_vector[1],
#     -end_mean_vector[0],
#     color="b",
#     scale=10,
# )
# plt.show()

# print(f"angle difference: {angle_difference}")


def plot_images_with_paths_and_angles(
    images: list,
    masks: list,
    paths: list,
    angles: list,
    start_angles: list,
    end_angles: list,
    start_origins: list,
    end_origins: list,
    px_to_nms: list,
    width=5,
):
    """Plot images next to their masks, with the traced path and its end directions.

    The mask panel shows the path as a red line. A red arrow marks the start
    direction and a blue arrow the end direction. The panel is titled with the
    angle between the two directions. Nothing is plotted when ``images`` is empty.

    Parameters
    ----------
    images, masks, paths : list
        Per-image arrays; paths are (N, 2) arrays of (row, col) coordinates.
    angles : list
        Angle difference in degrees for each path.
    start_angles, end_angles : list
        Direction angles in radians, as ``arctan2(row, col)``.
    start_origins, end_origins : list
        (row, col) arrow origins.
    px_to_nms : list
        Pixel-to-nm scale per image, used in the image-panel title.
    width : int
        Number of image/mask pairs per row.
    """
    num_images = len(images)
    if num_images == 0:
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps ``ax`` 2D even for a single row of images
    fig, ax = plt.subplots(num_rows, width * 2, figsize=(20, 40), squeeze=False)
    for axis in ax.flat:
        axis.axis("off")
    panels = zip(images, masks, paths, angles, start_angles, end_angles, start_origins, end_origins)
    for i, (image, mask, path, angle, start_angle, end_angle, start_origin, end_origin) in enumerate(panels):
        row, col = divmod(i, width)
        image_ax, mask_ax = ax[row, 2 * col], ax[row, 2 * col + 1]
        image_ax.imshow(image, cmap="viridis")
        image_ax.set_title(f"p_to_nm: {px_to_nms[i]}")
        mask_ax.imshow(mask, cmap="viridis")
        mask_ax.plot(path[:, 1], path[:, 0], "r-")
        # The row component is negated, presumably to account for imshow's downward-increasing rows
        for origin, direction, colour in ((start_origin, start_angle, "r"), (end_origin, end_angle, "b")):
            mask_ax.quiver(
                origin[1],
                origin[0],
                np.cos(direction),
                -np.sin(direction),
                color=colour,
                scale=3,
            )
        mask_ax.set_title(f"angle: {angle:.2f} degrees")
    fig.tight_layout()


image_dict_angles = {}

for index, image_dict_item in image_dict_paths.items():
    image = image_dict_item["image"]
    mask = image_dict_item["predicted_mask"]
    p_2_nm = image_dict_item["p_to_nm"]
    path = image_dict_item["path"]
    (
        angle_difference,
        start_angle,
        end_angle,
        start_mean_vector,
        end_mean_vector,
        start_origin,
        end_origin,
    ) = calculate_start_and_end_angles(path)

    # New dict per image: adding keys to image_dict_item would also mutate image_dict_paths
    # (and any earlier stage dict that shares the same entry object).
    image_dict_angles[index] = {
        **image_dict_item,
        "angle_difference": angle_difference,
        "start_angle": start_angle,
        "end_angle": end_angle,
        "start_mean_vector": start_mean_vector,
        "end_mean_vector": end_mean_vector,
        "start_origin": start_origin,
        "end_origin": end_origin,
    }

plot_images_with_paths_and_angles(
    images=[image_dict_angles[i]["image"] for i in image_dict_angles],
    masks=[image_dict_angles[i]["predicted_mask"] for i in image_dict_angles],
    paths=[image_dict_angles[i]["path"] for i in image_dict_angles],
    angles=[image_dict_angles[i]["angle_difference"] for i in image_dict_angles],
    start_angles=[image_dict_angles[i]["start_angle"] for i in image_dict_angles],
    end_angles=[image_dict_angles[i]["end_angle"] for i in image_dict_angles],
    start_origins=[image_dict_angles[i]["start_origin"] for i in image_dict_angles],
    end_origins=[image_dict_angles[i]["end_origin"] for i in image_dict_angles],
    px_to_nms=[image_dict_angles[i]["p_to_nm"] for i in image_dict_angles],
)

## Elongation metric

In [None]:
# Calculate the elongation metric


def signed_angle_between_vectors(v1, v2):
    """Return the signed angle in radians between two 2D vectors.

    The magnitude is the unsigned angle in [0, pi]. The result is negative when the
    2D cross product ``v1[0] * v2[1] - v1[1] * v2[0]`` is positive, and non-negative
    otherwise.

    Parameters
    ----------
    v1, v2 : array-like
        Two-element vectors.

    Returns
    -------
    float
        The signed angle in radians.

    Raises
    ------
    ValueError
        If either vector is all zeros.
    """
    v1 = np.asarray(v1, dtype=float)
    v2 = np.asarray(v2, dtype=float)
    if np.all(v1 == 0) or np.all(v2 == 0):
        raise ValueError("One of the vectors is 0")

    cos_theta = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
    # Clip guards against values just outside [-1, 1] from floating-point error
    angle = np.arccos(np.clip(cos_theta, -1.0, 1.0))

    # z-component of the 2D cross product, written out explicitly because np.cross on
    # 2-element vectors is deprecated as of NumPy 2.0
    cross_z = v1[0] * v2[1] - v1[1] * v2[0]
    if cross_z > 0:
        angle = -angle

    return angle


def rotate_points(points: np.ndarray, angle: float):
    """Right-multiply ``points`` (row vectors) by the 2x2 rotation matrix R(angle).

    ``points @ R(angle)`` is the same as applying ``R(-angle)`` to each row vector.
    Calling this again with ``-angle`` undoes the transformation. ``points`` may be
    a single 2-vector or an (N, 2) array.
    """
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    rotation_matrix = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
    return points @ rotation_matrix


def align_points_to_vertical(points: np.ndarray, orientation_vector: np.ndarray):
    """Rotate ``points`` so that ``orientation_vector`` is mapped onto the [1, 0] axis.

    Returns ``(rotated_points, rotated_orientation_vector, angle)``. Passing the
    returned ``angle`` to ``rotate_points`` maps the rotated coordinates back to the
    original frame. Raises ValueError (from ``signed_angle_between_vectors``) when
    ``orientation_vector`` is all zeros.
    """
    target_axis = np.array([1, 0])
    theta = signed_angle_between_vectors(orientation_vector, target_axis)
    return (
        rotate_points(points, theta),
        rotate_points(orientation_vector, theta),
        -theta,
    )


def plot_images_with_paths(
    images: list,
    masks: list,
    paths: list,
    centroids: list,
    midpoints: list,
    elongations: list,
    px_to_nms: list,
    bounding_boxes: list,
    aspect_ratios: list,
    width=5,
    save_dir=None,
):
    """Plot images next to their masks, annotated with the elongation measurement.

    The mask panel shows the path (red), the path centroid (red dot), the
    end-point midpoint (green dot), the line joining them (green) and the
    orientation-aligned bounding box (orange). Nothing is plotted or saved when
    ``images`` is empty.

    Parameters
    ----------
    images, masks, paths : list
        Per-image arrays; paths are (N, 2) arrays of (row, col) coordinates.
    centroids, midpoints : list
        (row, col) points per image.
    elongations : list
        Elongation per image, shown in the title (labelled nm).
    px_to_nms : list
        Pixel-to-nm scale per image.
    bounding_boxes : list
        Closed polygons as (5, 2) arrays of (row, col) corners.
    aspect_ratios : list
        Bounding-box aspect ratio per image.
    width : int
        Number of image/mask pairs per row.
    save_dir : path-like, optional
        Output file path for the figure (despite the name, a file, not a directory).
    """
    num_images = len(images)
    if num_images == 0:
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps ``ax`` 2D even for a single row of images
    fig, ax = plt.subplots(num_rows, width * 2, figsize=(30, 30), squeeze=False)
    for axis in ax.flat:
        axis.axis("off")
    for i, (image, mask, path) in enumerate(zip(images, masks, paths)):
        row, col = divmod(i, width)
        image_ax, mask_ax = ax[row, 2 * col], ax[row, 2 * col + 1]
        image_ax.imshow(image, cmap="viridis")
        mask_ax.imshow(mask, cmap="viridis")
        mask_ax.plot(path[:, 1], path[:, 0], "r-")
        mask_ax.scatter(centroids[i][1], centroids[i][0], c="r", s=5)
        mask_ax.scatter(midpoints[i][1], midpoints[i][0], c="g", s=5)
        # Draw a line between midpoint and centroid
        mask_ax.plot(
            [centroids[i][1], midpoints[i][1]],
            [centroids[i][0], midpoints[i][0]],
            "g-",
        )
        mask_ax.plot(bounding_boxes[i][:, 1], bounding_boxes[i][:, 0], color="orange")
        mask_ax.set_title(
            f"||| elongation: {elongations[i]:.2f} nm | aspect ratio: {aspect_ratios[i]:.2f} | px_to_nm: {px_to_nms[i]} |||",
            fontsize=10,
        )
    fig.tight_layout()

    if save_dir is not None:
        fig.savefig(save_dir, dpi=500)


image_dict_elongation = {}

for index, image_dict_path in image_dict_paths.items():
    image = image_dict_path["image"]
    mask = image_dict_path["predicted_mask"]
    p_2_nm = image_dict_path["p_to_nm"]
    path = image_dict_path["path"]

    # Centre the path on its centroid
    path_centroid = np.mean(path, axis=0)
    path_points_shifted = path - path_centroid

    # Midpoint of the two path ends (the endpoints chosen in the ring-gem connection regions)
    start_point = path[0]
    end_point = path[-1]
    start_point_end_point_midpoint = (start_point + end_point) / 2

    # Elongation: pixel distance from the path centroid to the end-point midpoint
    elongation_vector = start_point_end_point_midpoint - path_centroid
    distance = np.linalg.norm(elongation_vector)

    # Rotate the centred path so the elongation vector lies along axis 0, then take an
    # axis-aligned bounding box in that frame.
    if distance > 0:
        shifted_rotated_points, _, angle_rad = align_points_to_vertical(path_points_shifted, elongation_vector)
    else:
        # A zero vector has no direction (align_points_to_vertical raises ValueError for it),
        # so keep the original orientation instead of aborting the whole loop.
        shifted_rotated_points, angle_rad = path_points_shifted, 0.0

    min_x, min_y = np.min(shifted_rotated_points, axis=0)
    max_x, max_y = np.max(shifted_rotated_points, axis=0)
    rotated_bounding_box = np.array([[min_x, min_y], [max_x, min_y], [max_x, max_y], [min_x, max_y], [min_x, min_y]])

    # Extent along the elongation axis divided by the perpendicular extent
    aspect_ratio = (max_x - min_x) / (max_y - min_y)

    # Map the bounding box back into image coordinates for plotting
    unrotated_bounding_box = rotate_points(rotated_bounding_box, angle_rad) + path_centroid

    # New dict per image so these keys don't leak into image_dict_paths' entries
    image_dict_elongation[index] = {
        **image_dict_path,
        "path_centroid": path_centroid,
        "start_point": start_point,
        "end_point": end_point,
        "start_point_end_point_midpoint": start_point_end_point_midpoint,
        # Scale the pixel distance by the per-image pixel-to-nm factor
        "distance": distance * p_2_nm,
        "rotated_bounding_box": unrotated_bounding_box,
        "aspect_ratio": aspect_ratio,
    }


# Plot the paths and centroids
plot_images_with_paths(
    images=[image_dict_elongation[i]["image"] for i in image_dict_elongation],
    masks=[image_dict_elongation[i]["predicted_mask"] for i in image_dict_elongation],
    paths=[image_dict_elongation[i]["path"] for i in image_dict_elongation],
    centroids=[image_dict_elongation[i]["path_centroid"] for i in image_dict_elongation],
    midpoints=[image_dict_elongation[i]["start_point_end_point_midpoint"] for i in image_dict_elongation],
    elongations=[image_dict_elongation[i]["distance"] for i in image_dict_elongation],
    px_to_nms=[image_dict_elongation[i]["p_to_nm"] for i in image_dict_elongation],
    bounding_boxes=[image_dict_elongation[i]["rotated_bounding_box"] for i in image_dict_elongation],
    aspect_ratios=[image_dict_elongation[i]["aspect_ratio"] for i in image_dict_elongation],
    save_dir=IMAGE_SAVE_DIR / "elongation_images.png",
)

# Save the elongation values to numpy file
elongations = [image_dict_elongation[i]["distance"] for i in image_dict_elongation]
np.save(IMAGE_SAVE_DIR / "elongations.npy", elongations)

# Save the aspect ratios to numpy file
aspect_ratios = [image_dict_elongation[i]["aspect_ratio"] for i in image_dict_elongation]
np.save(IMAGE_SAVE_DIR / "aspect_ratios.npy", aspect_ratios)

In [None]:
# Distribution of the angle between the path's start and end directions, then persist the values
angle_differences = [entry["angle_difference"] for entry in image_dict_angles.values()]

sns.histplot(angle_differences, bins="auto")
plt.show()
sns.kdeplot(angle_differences)
plt.show()

np.save(IMAGE_SAVE_DIR / "angle_differences.npy", angle_differences)

In [None]:
def calculate_edges(grain_mask: np.ndarray):
    """Return the coordinates of the edge pixels of a grain mask.

    Holes in the mask are filled first, so only the outer boundary of the grain contributes edge pixels.
    An edge pixel is one that belongs to the filled mask but does not survive a one-pixel binary erosion.

    Parameters
    ----------
    grain_mask : np.ndarray
        A 2D numpy array image of a grain. Non-zero / True pixels are treated as part of the grain.

    Returns
    -------
    edges : list
        List of ``[row, col]`` coordinates of the edge pixels, in row-major order. An empty list is returned
        for an empty mask. A list (rather than an array) is returned because callers may sort it in place.
    """
    # Fill any holes so that only the outer boundary produces edge pixels
    filled_grain_mask = binary_fill_holes(grain_mask)

    # Pad with background so grain pixels touching the image border are still eroded (and so count as edges),
    # then strip the padding again
    padded = np.pad(filled_grain_mask, 1)
    eroded = binary_erosion(padded)[1:-1, 1:-1]

    # Edge pixels are in the filled mask but not in its erosion
    edges = np.logical_and(filled_grain_mask, np.logical_not(eroded))

    return np.argwhere(edges).tolist()


def is_clockwise(p_1: tuple, p_2: tuple, p_3: tuple) -> bool:
    """Function to determine if three points make a clockwise or counter-clockwise turn.

    Parameters
    ----------
    p_1: tuple
        First point to be used to calculate turn.
    p_2: tuple
        Second point to be used to calculate turn.
    p_3: tuple
        Third point to be used to calculate turn.

    Returns
    -------
    boolean
        True if the turn p_1 -> p_2 -> p_3 is clockwise OR the points are collinear; False if it is
        counter-clockwise. NOTE: "clockwise" here assumes (p[0], p[1]) are (x, y) on y-up axes; the visual
        sense depends on the axis convention of the caller's coordinates.
    """
    # The 2D cross product (p_2 - p_1) x (p_3 - p_1) equals the determinant of
    # [[x1, y1, 1], [x2, y2, 1], [x3, y3, 1]]; a positive value means a counter-clockwise turn.
    # Computing it directly keeps integer pixel coordinates exact, whereas np.linalg.det (an LU factorisation)
    # can return a tiny non-zero value for collinear points and so misclassify them.
    cross = (p_2[0] - p_1[0]) * (p_3[1] - p_1[1]) - (p_2[1] - p_1[1]) * (p_3[0] - p_1[0])
    return not cross > 0


def get_triangle_height(base_point_1: np.ndarray, base_point_2: np.ndarray, top_point: np.ndarray) -> float:
    """Returns the height of a triangle defined by the input point vectors.

    Parameters
    ----------
    base_point_1: np.ndarray
        a base point of the triangle, eg: [5, 3]. Any 2-element array-like is accepted.
    base_point_2: np.ndarray
        a base point of the triangle, eg: [8, 3]. Any 2-element array-like is accepted.
    top_point: np.ndarray
        the top point of the triangle, defining the height from the line between the two base points, eg: [6, 10].

    Returns
    -------
    float
        The height of the triangle, ie the shortest distance between the top point and the line through the two
        base points. If the two base points coincide, the distance from that point to the top point is returned
        instead of the undefined 0/0.
    """
    base_point_1 = np.asarray(base_point_1, dtype=float)
    base_point_2 = np.asarray(base_point_2, dtype=float)
    top_point = np.asarray(top_point, dtype=float)

    a_b = base_point_1 - base_point_2
    a_c = base_point_1 - top_point
    base_length = np.linalg.norm(a_b)
    if base_length == 0:
        # Degenerate base: the "line" is a single point, so use the point-to-point distance
        return float(np.linalg.norm(a_c))

    # Height = 2 * area / base = |AB x AC| / |AB|. The scalar 2D cross product is written out explicitly
    # because np.cross on 2-element vectors is deprecated as of NumPy 2.0.
    cross = a_b[0] * a_c[1] - a_b[1] * a_c[0]
    return float(abs(cross) / base_length)


def get_max_min_ferets(edge_points: list):
    """Returns the minimum and maximum feret diameters for a grain.
    These are defined as the smallest and greatest distances between
    a pair of callipers that are rotating around a 2d object, maintaining
    contact at all times.

    Parameters
    ----------
    edge_points: list
        a list (or Nx2 array) of the vector positions of the pixels comprising the edge of the
        grain. Eg: [[0, 0], [1, 0], [2, 1]]. The input is not modified.

    Returns
    -------
    min_feret: float
        the minimum feret diameter of the grain, or None if fewer than two edge points are given.
    max_feret: float
        the maximum feret diameter of the grain, or None if fewer than two edge points are given.
    min_feret_triangle: list
        the three hull vertices ``[base_1, base_2, opposite]`` whose triangle height is ``min_feret``: the two
        base points are in contact with one calliper and the opposite point with the other. None if fewer than
        two edge points are given.

    Notes
    -----
    The method starts out by calculating the upper and lower convex hulls using
    a monotone-chain variant of the Graham Scan Algorithm [1] (points sorted by their
    first then second coordinate). Using these upper and lower hulls, the callipers are
    simulated as rotating around the grain. We determine the order in which vertices are
    encountered by comparing the gradients of the slopes between vertices. Every pair of
    points that is in contact with the two callipers at a given time is recorded in order
    to calculate the maximum feret diameter. The minimum diameter always occurs when two
    points are in contact with one calliper while the other calliper is in contact with a
    third point, so it is found by taking the height of each such triangle of contact
    points as we iterate; this height is the perpendicular distance between the callipers.

    References
    ----------
    [1] Graham, R.L. (1972).
        "An Efficient Algorithm for Determining the Convex Hull of a Finite Planar Set".
        Information Processing Letters. 1 (4): 132-133.
        doi:10.1016/0020-0190(72)90045-2.
    """
    points = np.asarray(edge_points)
    if points.size == 0:
        return None, None, None

    # Sort by the first coordinate, then by the second (the same order as sorting a list of [x, y] pairs).
    # Done on a copy so the caller's list/array is left untouched.
    points = points[np.lexsort((points[:, 1], points[:, 0]))]

    # Construct separate upper and lower hulls; both start at the first sorted point and end at the last.
    upper_hull = []
    lower_hull = []
    for point in points:
        while len(lower_hull) > 1 and is_clockwise(lower_hull[-2], lower_hull[-1], point):
            lower_hull.pop()
        lower_hull.append(point)
        while len(upper_hull) > 1 and not is_clockwise(upper_hull[-2], upper_hull[-1], point):
            upper_hull.pop()
        upper_hull.append(point)

    upper_hull = np.array(upper_hull)
    lower_hull = np.array(lower_hull)

    # Rotate the callipers: the upper hull is walked forwards and the lower hull backwards, recording every
    # antipodal contact pair and every (two points on one calliper, one point on the other) triangle.
    contact_points = []
    min_feret = None
    min_feret_triangle = None
    upper_index = 0
    lower_index = len(lower_hull) - 1
    while upper_index < len(upper_hull) - 1 or lower_index > 0:
        contact_points.append((lower_hull[lower_index], upper_hull[upper_index]))

        if upper_index == len(upper_hull) - 1:
            # The upper hull is exhausted, so only the lower hull can advance
            advance_upper = False
        elif lower_index == 0:
            # The lower hull is exhausted, so only the upper hull can advance
            advance_upper = True
        else:
            # Advance the upper hull if its next edge has the greater gradient, ie it is encountered first.
            # delta upper_y / delta upper_x > delta lower_y / delta lower_x is compared with the denominators
            # multiplied through, so no division by zero can occur.
            advance_upper = (upper_hull[upper_index + 1, 1] - upper_hull[upper_index, 1]) * (
                lower_hull[lower_index, 0] - lower_hull[lower_index - 1, 0]
            ) > (lower_hull[lower_index, 1] - lower_hull[lower_index - 1, 1]) * (
                upper_hull[upper_index + 1, 0] - upper_hull[upper_index, 0]
            )

        if advance_upper:
            upper_index += 1
            # The two upper hull vertices are now colinear with one calliper
            triangle = [upper_hull[upper_index - 1], upper_hull[upper_index], lower_hull[lower_index]]
        else:
            lower_index -= 1
            # The two lower hull vertices are now colinear with one calliper
            triangle = [lower_hull[lower_index + 1], lower_hull[lower_index], upper_hull[upper_index]]

        small_feret = get_triangle_height(*triangle)
        if min_feret is None or small_feret < min_feret:
            min_feret = small_feret
            min_feret_triangle = triangle

    # The maximum feret is the largest distance between any recorded contact pair
    max_feret = None
    if contact_points:
        pairs = np.array(contact_points, dtype=float)  # shape (n_pairs, 2, 2)
        max_feret = np.max(np.linalg.norm(pairs[:, 0] - pairs[:, 1], axis=1))

    return min_feret, max_feret, min_feret_triangle

In [None]:
# Calculate min, max feret diameters (in nm) for the ring masks

image_dict_ferets = {}

for index, image_dict_item in image_dict_angles.items():
    mask = image_dict_item["predicted_mask"]
    p_2_nm = image_dict_item["p_to_nm"]

    # Pixels with value 1 in the predicted mask are used as the ring
    ring_mask = mask == 1

    edge_points = calculate_edges(ring_mask)
    min_feret, max_feret, _ = get_max_min_ferets(edge_points)
    if min_feret is None or max_feret is None:
        # Too few edge pixels (eg an empty prediction): the ferets are undefined, so skip rather than crash
        print(f"Skipping {index}: ring mask has too few edge points to compute feret diameters")
        continue

    # NOTE: this stores the same dict object as image_dict_angles[index], so the feret keys are added there too
    image_dict_ferets[index] = image_dict_item
    image_dict_ferets[index]["min_feret"] = min_feret * p_2_nm
    image_dict_ferets[index]["max_feret"] = max_feret * p_2_nm

min_ferets = [image_dict_ferets[i]["min_feret"] for i in image_dict_ferets]
max_ferets = [image_dict_ferets[i]["max_feret"] for i in image_dict_ferets]

sns.kdeplot(min_ferets, label="min feret")
sns.kdeplot(max_ferets, label="max feret")
plt.xlabel("Feret diameter (nm)")
plt.legend()
plt.show()

print(f"min ferets: {min_ferets}")
print(f"mean min feret: {np.mean(min_ferets)}")

np.save(IMAGE_SAVE_DIR / "min_ferets_updated.npy", min_ferets)
np.save(IMAGE_SAVE_DIR / "max_ferets_updated.npy", max_ferets)

In [None]:
def _rim_curvature(xs: np.ndarray, ys: np.ndarray, periodic: bool = True):
    """Calculate the curvature of a set of points. Uses the standard curvature definition of the derivative of the
    tangent vector.

    Parameters:
    ----------
    xs: np.ndarray
        One dimensional numpy array of x-coordinates of the points
    ys: np.ndarray
        One dimensional numpy array of y-coordinates of the points
    Returns:
    -------
    np.ndarray
        One-dimensional numpy array of curvatures for the spline.
    """

    extension_length = xs.shape[0]
    if periodic:
        xs_extended = np.append(xs, xs)
        xs_extended = np.append(xs_extended, xs)

        ys_extended = np.append(ys, ys)
        ys_extended = np.append(ys_extended, ys)
        dx = np.gradient(xs_extended)
        dy = np.gradient(ys_extended)
    else:
        dx = np.gradient(xs)
        dy = np.gradient(ys)

    d2x = np.gradient(dx)

    d2y = np.gradient(dy)
    curv = abs((dx * d2y - d2x * dy) / (dx * dx + dy * dy) ** 1.5)

    if periodic:
        curv = curv[extension_length : (len(curv) - extension_length)]

    return curv


def _interpolate_points_spline(points: np.ndarray, num_points: int, smoothing: float = 0.0, periodic: int = 1):
    """Interpolate a set of points using a spline.

    Parameters
    ----------
    points: np.ndarray
        Nx2 Numpy array of coordinates for the points.
    num_points: int
        The number of points to return following the calculated spline.

    Returns
    -------
    interpolated_points: np.ndarray
        An Ix2 Numpy array of coordinates of the interpolated points, where I is the number of points
        specified.
    """

    x, y = splprep(points.T, u=None, s=smoothing, per=periodic)
    x_spline = np.linspace(y.min(), y.max(), num_points)
    x_new, y_new = splev(x_spline, x, der=0)
    interpolated_points = np.array((x_new, y_new)).T
    return interpolated_points


def interpolate_spline_and_get_curvature(
    points: np.ndarray, interpolation_number: int, smoothing: float = 0.0, periodic: int = 1
):
    """Fit a spline through a set of points and return the curvature along it.

    The points are first interpolated with a spline (``_interpolate_points_spline``) to reduce pixel-level
    anomalies. The curvature is then computed on the interpolated points (``_rim_curvature``).

    Parameters
    ----------
    points: np.ndarray
        Nx2 Numpy array (or array-like) of coordinates for the points. It is copied, so the caller's data is
        never modified.
    interpolation_number: int
        Number of interpolation points per point. Eg: for a set of 10 points and 2 interpolation points,
        there will be 20 points in the spline.
    smoothing: float
        Spline smoothing condition, passed through to the spline fit.
    periodic: int
        Non-zero treats the points as a closed loop, both for the spline fit and for the curvature calculation.

    Returns
    -------
    interpolated_curvatures: np.ndarray
        One-dimensional array of length ``interpolation_number * N`` holding the absolute curvature at each
        interpolated point, in inverse units of the input coordinates.
    interpolated_points: np.ndarray
        ``(interpolation_number * N) x 2`` Numpy array of interpolated points generated from the spline of the
        original points.
    """
    # Work on a copy: the spline fit may overwrite the last point of a periodic curve in place
    points = np.array(points)

    # Interpolate the data using cubic splines
    num_points = interpolation_number * points.shape[0]
    interpolated_points = _interpolate_points_spline(
        points=points, num_points=num_points, smoothing=smoothing, periodic=periodic
    )
    x = interpolated_points[:, 0]
    y = interpolated_points[:, 1]

    # Calculate the curvature
    interpolated_curvatures = _rim_curvature(x, y, periodic=bool(periodic))

    return interpolated_curvatures, interpolated_points

In [None]:
# Calculate curvature and ring masks


def calculate_path_length(path: np.ndarray):
    """Calculate the length of a path as the sum of its straight segment lengths.

    The path is treated as an open polyline: no closing segment from the last point back to the first is added.

    Parameters
    ----------
    path: np.ndarray
        Nx2 numpy array (or array-like) of points in the path.

    Returns
    -------
    float
        The length of the path, in the units of the coordinates. 0.0 for paths with fewer than two points.
    """
    points = np.asarray(path, dtype=float)
    if len(points) < 2:
        return 0.0
    # Euclidean length of each consecutive segment, summed
    return float(np.sum(np.linalg.norm(np.diff(points, axis=0), axis=1)))


# Only images with a pixel-to-nm value at or below this are analysed for curvature.
# NOTE(review): hand-tuned threshold — document why this value was chosen.
MAX_P_TO_NM_FOR_CURVATURE = 0.59

# Keep the curvature colour scale consistent across images
cmap = "afmhot"
cmap_max = 0.4
cmap_min = 0.0

image_dict_curvatures = {}

for index, image_dict_item in image_dict_ferets.items():
    path = image_dict_item["path"]
    px_2_nm = image_dict_item["p_to_nm"]

    if px_2_nm > MAX_P_TO_NM_FOR_CURVATURE:
        continue

    # Plot the image and the path: green marks the first path point, blue the last
    plt.imshow(image_dict_item["image"])
    plt.plot(path[:, 1], path[:, 0], "r-")
    plt.scatter(path[0, 1], path[0, 0], c="g", s=50)
    plt.scatter(path[-1, 1], path[-1, 0], c="b", s=50)
    plt.show()

    path_length_pixels = calculate_path_length(path)
    path_length_nm = path_length_pixels * px_2_nm
    print(f"path length: {path_length_nm} nm | {path_length_pixels} pixels")

    # Scale the spline smoothing with resolution so it is comparable across image sizes
    smoothing = 4 * 1 / px_2_nm

    # Calculate the curvature of the path (in 1 / pixel)
    interpolated_curvatures, interpolated_points = interpolate_spline_and_get_curvature(
        points=path, interpolation_number=2, smoothing=smoothing, periodic=1
    )

    # Convert interpolated curvatures from 1 / pixel to 1 / nm
    interpolated_curvatures_nm = interpolated_curvatures * 1 / px_2_nm

    # NOTE: this stores the same dict object as image_dict_ferets[index], so the keys below are added there too
    image_dict_curvatures[index] = image_dict_item
    image_dict_curvatures[index]["curvature"] = interpolated_curvatures_nm
    image_dict_curvatures[index]["path_length_nm"] = path_length_nm
    image_dict_curvatures[index]["smoothing"] = smoothing
    image_dict_curvatures[index]["path_length_pixels"] = path_length_pixels

    # Plot the curvature along the path on top of the image
    plt.imshow(image_dict_item["image"])
    plt.scatter(
        interpolated_points[:, 1],
        interpolated_points[:, 0],
        c=interpolated_curvatures_nm,
        cmap=cmap,
        vmin=cmap_min,
        vmax=cmap_max,
        s=5,
    )
    plt.colorbar()
    plt.show()

    print(f"p_2_nm: {px_2_nm} smoothing: {smoothing}")

    # Plot the curvature profile as a standard line plot, with the x axis labelled in nm along the path
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    ax.plot(interpolated_curvatures_nm)
    ax.set_ylim(0.0, cmap_max)
    ax.set_ylabel("Curvature (absolute, 1/nm)")

    n_curvature_points = len(interpolated_curvatures_nm)
    x_ticks = np.linspace(0, n_curvature_points, 5)
    # Label ticks using the measured (summed Euclidean segment) path length rather than the number of path
    # points, so they agree with the path_length_nm used by the averaging plots
    x_tick_labels = [int((tick / n_curvature_points) * path_length_nm) for tick in x_ticks]
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_tick_labels)
    ax.set_xlabel("Distance along path (nm)")

    plt.show()

In [None]:
# Overlay every ring's curvature profile on one set of axes, with distance along the path in nm on the x axis
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
for index, image_dict_item in image_dict_curvatures.items():
    curvature_profile = image_dict_item["curvature"]
    # Spread the curvature samples evenly from 0 to the path length
    distances_nm = np.linspace(0, image_dict_item["path_length_nm"], len(curvature_profile))
    ax.plot(distances_nm, curvature_profile, label=f"{index}")

ax.set_ylabel("Curvature (absolute, 1/nm)")
ax.set_xlabel("Distance along path (nm)")
ax.legend()
plt.show()

In [None]:
# Average curvature along the path (in nm) over all samples, with the standard deviation as error bars
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
max_nm_length = np.max([image_dict_curvatures[i]["path_length_nm"] for i in image_dict_curvatures])
print(f"max length: {max_nm_length}")

# Bin width (nm). Each bin covers [start, start + interval). The last bin is closed on the right so the final
# point of the longest path is counted.
interval = 0.5
bin_edges_nm = np.append(np.arange(0, max_nm_length, interval), max_nm_length)
n_bins = len(bin_edges_nm) - 1

# For each bin: the mean curvature of every sample that has at least one non-nan value in that bin
sample_means_per_bin = [[] for _ in range(n_bins)]
for image_dict_item in image_dict_curvatures.values():
    curvatures = image_dict_item["curvature"]
    # Position (nm) of each curvature value, assuming the values are evenly spaced along the path
    nm_values = np.linspace(0, image_dict_item["path_length_nm"], len(curvatures))
    bin_ids = np.clip(np.digitize(nm_values, bin_edges_nm) - 1, 0, n_bins - 1)
    for bin_id in range(n_bins):
        in_bin = bin_ids == bin_id
        if not np.any(in_bin):
            continue
        mean_curvature = np.mean(curvatures[in_bin])
        # Skip samples whose curvature is undefined (nan) in this bin
        if np.isnan(mean_curvature):
            continue
        sample_means_per_bin[bin_id].append(mean_curvature)

average_curvatures = [np.mean(means) if means else np.nan for means in sample_means_per_bin]
std_curvatures = [np.std(means) if means else np.nan for means in sample_means_per_bin]
points_in_average = [len(means) for means in sample_means_per_bin]
# Plot each bin at its centre so the x axis is in nm rather than bin index
bin_centres_nm = (bin_edges_nm[:-1] + bin_edges_nm[1:]) / 2

# Average curvature line with gray standard-deviation error bars
ax.errorbar(
    bin_centres_nm,
    average_curvatures,
    yerr=std_curvatures,
    fmt="-",
    markersize=5,
    ecolor="gray",
    elinewidth=2,
    capsize=5,
    label="Average curvature",
)

# Overlay markers coloured by how many samples contributed to each bin's average
scatter_colourmap = "rainbow"
scatter = ax.scatter(
    bin_centres_nm,
    average_curvatures,
    c=points_in_average,
    s=50,
    cmap=scatter_colourmap,
)
cbar = fig.colorbar(scatter, ax=ax)
cbar.set_label("Number of points in average")

ax.set_ylabel("Average curvature (absolute, 1/nm)")
ax.set_xlabel("Distance along path (nm)")
plt.suptitle(
    f"Average curvature along path for each nm position, irrespective of proportional distance along sample. Averaging interval: {interval} nm"
)
# Subtitle
plt.title("Error bars are the standard deviation of the average curvature for each nm position.")

# Save the plot
plt.savefig(IMAGE_SAVE_DIR / "average_curvatures.png")

# Save the average curvatures and standard deviations to numpy files
np.save(IMAGE_SAVE_DIR / "average_curvatures.npy", average_curvatures)
np.save(IMAGE_SAVE_DIR / "std_curvatures.npy", std_curvatures)
np.save(IMAGE_SAVE_DIR / "points_in_average.npy", points_in_average)

In [None]:
# Average curvature as a function of proportional distance (0 to 1) along each path, over all samples.
# Bin width as a fraction of the path length. Each bin covers [start, start + percentage_interval). The last bin
# is closed on the right so the final point of every path (at 1.0) is counted.
percentage_interval = 0.05
bin_edges_proportional = np.append(np.arange(0, 1, percentage_interval), 1.0)
n_proportional_bins = len(bin_edges_proportional) - 1

# For each bin: the mean curvature of every sample that has at least one non-nan value in that bin
sample_means_per_proportional_bin = [[] for _ in range(n_proportional_bins)]
for image_dict_item in image_dict_curvatures.values():
    curvatures = image_dict_item["curvature"]
    # Proportional position (0 to 1) of each curvature value along the path
    percentage_values = np.linspace(0, 1, len(curvatures))
    bin_ids = np.clip(np.digitize(percentage_values, bin_edges_proportional) - 1, 0, n_proportional_bins - 1)
    for bin_id in range(n_proportional_bins):
        in_bin = bin_ids == bin_id
        if not np.any(in_bin):
            continue
        mean_curvature = np.mean(curvatures[in_bin])
        # Skip samples whose curvature is undefined (nan) in this bin
        if np.isnan(mean_curvature):
            continue
        sample_means_per_proportional_bin[bin_id].append(mean_curvature)

average_curvatures_proportional = [
    np.mean(means) if means else np.nan for means in sample_means_per_proportional_bin
]
std_curvatures_proportional = [np.std(means) if means else np.nan for means in sample_means_per_proportional_bin]
points_in_average_proportional = [len(means) for means in sample_means_per_proportional_bin]
# Plot each bin at its centre so the x axis is the proportional distance rather than bin index
bin_centres_proportional = (bin_edges_proportional[:-1] + bin_edges_proportional[1:]) / 2

# Average curvature line with gray standard-deviation error bars
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.errorbar(
    bin_centres_proportional,
    average_curvatures_proportional,
    yerr=std_curvatures_proportional,
    fmt="-",
    markersize=5,
    ecolor="gray",
    elinewidth=2,
    capsize=5,
    label="Average curvature",
)

# Overlay markers coloured by how many samples contributed to each bin's average
scatter_colourmap = "rainbow"
scatter = ax.scatter(
    bin_centres_proportional,
    average_curvatures_proportional,
    c=points_in_average_proportional,
    s=50,
    cmap=scatter_colourmap,
)
cbar = fig.colorbar(scatter, ax=ax)
cbar.set_label("Number of points in average")

ax.set_ylabel("Average curvature (absolute, 1/nm)")
ax.set_xlabel("Proportional distance along path")
plt.suptitle(
    f"Average curvature along path for each proportional distance along sample. Averaging interval: {percentage_interval*100} %"
)
# Subtitle
plt.title("Error bars are the standard deviation of the average curvature for each proportional distance.")

# Save the plot
plt.savefig(IMAGE_SAVE_DIR / "average_curvatures_proportional.png")

# Save the average curvatures and standard deviations to numpy files
np.save(IMAGE_SAVE_DIR / "average_curvatures_proportional.npy", average_curvatures_proportional)
np.save(IMAGE_SAVE_DIR / "std_curvatures_proportional.npy", std_curvatures_proportional)
np.save(IMAGE_SAVE_DIR / "points_in_average_proportional.npy", points_in_average_proportional)