In [None]:
from pathlib import Path
import re
import pickle

import numpy as np
import shapely
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.ndimage import distance_transform_edt, gaussian_filter1d, binary_fill_holes
from skimage.measure import label, regionprops
from skimage.graph import route_through_array
from skimage.morphology import binary_dilation, binary_erosion, skeletonize

from topostats.plottingfuncs import Colormap

# Shared colour map for the grain image plots, obtained from the TopoStats Colormap helper.
colormap = Colormap()
CMAP = colormap.get_cmap()

# Lower / upper colour-scale limits handed to ``imshow`` as ``vmin`` / ``vmax`` when
# plotting grain images (presumably heights in nm -- TODO confirm the units).
VMIN = 0
VMAX = 4

In [None]:
# --- Run configuration -------------------------------------------------------
SAMPLE_TYPE = "ON_SC"
DATE = "2024-03-22"
# NOTE(review): DATA_DIR starts with "/users" but SAVE_DIR with "/Users". That only resolves to the
# same tree on a case-insensitive filesystem -- confirm which spelling is intended.
DATA_DIR = Path(f"/users/sylvi/topo_data/hariborings/extracted_grains/cas9_{SAMPLE_TYPE}/{DATE}")
SAVE_DIR = Path(f"/Users/sylvi/topo_data/hariborings/processed_grains/cas9_{SAMPLE_TYPE}/{DATE}")
MAX_PX_TO_NM = 10.0
PLOT_RESULTS = True
# Upper bound on the number of grains kept in `grain_dicts_sample`.
MAX_GRAIN_NUMBER = 100

# --- Load the extracted grains ------------------------------------------------
FILE_PATH = DATA_DIR / "grain_dict.pkl"
# SECURITY: unpickling executes arbitrary code from the file; only load pickles you produced yourself.
with open(FILE_PATH, "rb") as f:
    grain_dicts = pickle.load(f)

print(f"number of images for sample type [{SAMPLE_TYPE}] : {len(grain_dicts.keys())}")

# Cut off the number of grains: keep the first MAX_GRAIN_NUMBER entries in dictionary order.
# (The previous `i > MAX_GRAIN_NUMBER` test let MAX_GRAIN_NUMBER + 1 grains through.)
grain_dicts_sample = dict(list(grain_dicts.items())[:MAX_GRAIN_NUMBER])

print(f"number of images in sample for {SAMPLE_TYPE} : {len(grain_dicts_sample.keys())}")

# NOTE(review): the cells below iterate over `grain_dicts` (the full set), not over this sample, and
# `grain_dict` is re-bound as a loop variable later on -- so the cut-off currently has no effect
# downstream. Confirm whether `grain_dicts = grain_dicts_sample` was intended here.
grain_dict = grain_dicts_sample

In [None]:
def plot_images(
    images: list, masks: list, grain_indexes: list, px_to_nms: list, width=5, VMIN=VMIN, VMAX=VMAX, cmap=CMAP
):
    """Show every grain image next to its mask on a grid of subplots.

    Parameters
    ----------
    images: list
        2D arrays drawn with ``imshow`` using ``cmap``, ``VMIN`` and ``VMAX``.
    masks: list
        Mask arrays; each is cast to ``int`` and drawn to the right of its image.
    grain_indexes: list
        Identifiers written into the image titles.
    px_to_nms: list
        Scaling values written into the image titles.
    width: int
        Number of image/mask pairs per grid row.
    VMIN, VMAX: float
        Colour-scale limits for the images.
    cmap
        Colormap used for the images.
    """
    num_images = len(images)
    num_images_in_batch = 2
    # Ceiling division gives exactly the rows needed; keep at least one so subplots() gets a valid grid.
    num_rows = max(1, (num_images + width - 1) // width)
    # squeeze=False keeps `axes` two-dimensional even for a single row, so [row, column] indexing
    # below never fails.
    fig, axes = plt.subplots(
        num_rows, width * num_images_in_batch, figsize=(width * 4, num_rows * 4), squeeze=False
    )
    # Hide all frames first so grid cells without a grain stay blank.
    for grid_ax in axes.ravel():
        grid_ax.axis("off")

    for i, (image, mask, grain_index, p_to_nm) in enumerate(zip(images, masks, grain_indexes, px_to_nms)):
        row, pair = divmod(i, width)
        # Plot image
        im_ax = axes[row, pair * num_images_in_batch]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        im_ax.set_title(f"Grain {grain_index} {p_to_nm} p/nm")
        # Plot mask
        mask_ax = axes[row, pair * num_images_in_batch + 1]
        mask_ax.imshow(mask.astype(int))

    fig.tight_layout()
    plt.show()


# Overview of every loaded grain (image + predicted mask) before any processing.
if PLOT_RESULTS:
    grain_indexes = list(grain_dicts)
    images = [grain_dicts[grain_index]["image"] for grain_index in grain_indexes]
    masks = [grain_dicts[grain_index]["predicted_mask"] for grain_index in grain_indexes]
    px_to_nms = [grain_dicts[grain_index]["p_to_nm"] for grain_index in grain_indexes]
    plot_images(images, masks, grain_indexes, px_to_nms)

# Pathfinding

In [None]:
def flip_if_anticlockwise(trace: np.ndarray):
    """Reverse a 2D trace when its summed turning direction is positive.

    For each consecutive pair of unit-length segments the scalar cross product
    ``a[0] * b[1] - a[1] * b[0]`` is computed and all of them are summed. A positive sum
    reverses the point order; a negative or zero sum (e.g. a straight line) leaves the trace
    untouched. Which visual sense of rotation "positive" corresponds to depends on the
    axis order of the coordinates (the original inline notes call a positive sum "clockwise").

    Parameters
    ----------
    trace: np.ndarray
        Array of shape (N, 2) holding the ordered points of the trace. Integer input is accepted.

    Returns
    -------
    np.ndarray
        ``trace`` itself when no flip is needed, otherwise ``trace`` reversed along axis 0.
    """
    # Fewer than three points means there is no turn to measure.
    if len(trace) < 3:
        return trace

    # Work on a float copy so integer traces do not break the normalisation.
    segments = np.diff(np.asarray(trace, dtype=float), axis=0)
    lengths = np.linalg.norm(segments, axis=1)

    # Repeated points give zero-length segments; dropping them avoids 0/0 = NaN, which would
    # otherwise poison the sum and silently disable the flip.
    nonzero = lengths > 0
    unit_segments = segments[nonzero] / lengths[nonzero][:, None]

    # Scalar 2D cross product of each segment with its successor (written out by hand because
    # np.cross on 2-element vectors is deprecated in recent NumPy versions).
    cross_products = (
        unit_segments[:-1, 0] * unit_segments[1:, 1] - unit_segments[:-1, 1] * unit_segments[1:, 0]
    )
    cross_sum = float(np.sum(cross_products))

    if cross_sum > 0:
        # Reverse the trace
        trace = np.flip(trace, axis=0)

    return trace


def plot_images(
    images: list,
    masks: list,
    grain_indexes: list,
    px_to_nms: list,
    paths: list,
    width=5,
    VMIN=VMIN,
    VMAX=VMAX,
    cmap=CMAP,
):
    """Show every grain image next to its mask, with the grain's path drawn on the mask.

    Parameters
    ----------
    images: list
        2D arrays drawn with ``imshow`` using ``cmap``, ``VMIN`` and ``VMAX``.
    masks: list
        Mask arrays; each is cast to ``int`` and drawn to the right of its image.
    grain_indexes: list
        Identifiers written into the image titles.
    px_to_nms: list
        Scaling values written into the image titles.
    paths: list
        One (N, 2) array per grain; column 0 is plotted on the vertical axis and column 1 on
        the horizontal axis, in white, over the mask.
    width: int
        Number of image/mask pairs per grid row.
    VMIN, VMAX: float
        Colour-scale limits for the images.
    cmap
        Colormap used for the images.
    """
    num_images = len(images)
    num_images_in_batch = 2
    # Ceiling division gives exactly the rows needed; keep at least one so subplots() gets a valid grid.
    num_rows = max(1, (num_images + width - 1) // width)
    # squeeze=False keeps `axes` two-dimensional even for a single row, so [row, column] indexing
    # below never fails.
    fig, axes = plt.subplots(
        num_rows, width * num_images_in_batch, figsize=(width * 4, num_rows * 4), squeeze=False
    )
    # Hide all frames first so grid cells without a grain stay blank.
    for grid_ax in axes.ravel():
        grid_ax.axis("off")

    for i, (image, mask, grain_index, p_to_nm, path) in enumerate(zip(images, masks, grain_indexes, px_to_nms, paths)):
        row, pair = divmod(i, width)
        # Plot image
        im_ax = axes[row, pair * num_images_in_batch]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        im_ax.set_title(f"Grain {grain_index} {p_to_nm} p/nm")
        # Plot mask with the path on top
        mask_ax = axes[row, pair * num_images_in_batch + 1]
        mask_ax.imshow(mask.astype(int))
        mask_ax.plot(path[:, 1], path[:, 0], color="white")

    fig.tight_layout()
    plt.show()


# For every grain, route a path between its first two intersection regions that stays as far
# from the mask edge as possible, and store it under the "path" key.
grain_dict_paths = {}

for index, grain_dict in grain_dicts.items():
    mask = grain_dict["predicted_mask"]

    # Distance of each mask pixel to the nearest background pixel; pixels with mask value 2 are
    # zeroed so the route avoids them (a later cell treats value 2 as the "gem" region).
    distance_transform = distance_transform_edt(mask > 0)
    distance_transform[mask == 2] = 0

    # Connected components of the intersection labels; the first two are used as start and end.
    intersection_regions = regionprops(label(grain_dict["intersection_labels"]))
    if len(intersection_regions) < 2:
        # Previously this raised IndexError and aborted the whole loop.
        print(f"Index {index} skipped: found {len(intersection_regions)} intersection region(s), need 2")
        continue

    # Within each region pick the pixel with the largest distance-transform value.
    endpoints = []
    for region in intersection_regions[:2]:
        rows, cols = region.coords[:, 0], region.coords[:, 1]
        deepest = np.argmax(distance_transform[rows, cols])
        endpoints.append((int(rows[deepest]), int(cols[deepest])))
    start_point, end_point = endpoints

    # Turn "far from the edge" into "cheap to cross". Pixels that end up with the maximum cost
    # (distance-transform value 0, i.e. background and the zeroed value-2 pixels) get a large
    # fixed penalty.
    inverted_distance_transform = np.max(distance_transform) - distance_transform
    inverted_distance_transform[inverted_distance_transform == np.max(inverted_distance_transform)] = 1000

    route, _ = route_through_array(inverted_distance_transform, start_point, end_point)
    route = flip_if_anticlockwise(np.array(route, dtype=float))

    # Store a shallow copy so re-running this cell never mutates the loaded `grain_dicts`.
    grain_dict_paths[index] = {**grain_dict, "path": route}

plot_images(
    [entry["image"] for entry in grain_dict_paths.values()],
    [entry["predicted_mask"] for entry in grain_dict_paths.values()],
    list(grain_dict_paths),
    [entry["p_to_nm"] for entry in grain_dict_paths.values()],
    [entry["path"] for entry in grain_dict_paths.values()],
)

# Pooling

In [None]:
# Pool the traces


def plot_images(
    images: list,
    masks: list,
    grain_indexes: list,
    px_to_nms: list,
    paths: list,
    width=5,
    VMIN=VMIN,
    VMAX=VMAX,
    cmap=CMAP,
):
    """Show every grain image next to its mask, with the grain's path drawn on the mask.

    Parameters
    ----------
    images: list
        2D arrays drawn with ``imshow`` using ``cmap``, ``VMIN`` and ``VMAX``.
    masks: list
        Mask arrays; each is cast to ``int`` and drawn to the right of its image.
    grain_indexes: list
        Identifiers written into the image titles.
    px_to_nms: list
        Scaling values written into the image titles.
    paths: list
        One (N, 2) array per grain; column 0 is plotted on the vertical axis and column 1 on
        the horizontal axis, in white, over the mask.
    width: int
        Number of image/mask pairs per grid row.
    VMIN, VMAX: float
        Colour-scale limits for the images.
    cmap
        Colormap used for the images.
    """
    num_images = len(images)
    num_images_in_batch = 2
    # Ceiling division gives exactly the rows needed; keep at least one so subplots() gets a valid grid.
    num_rows = max(1, (num_images + width - 1) // width)
    # squeeze=False keeps `axes` two-dimensional even for a single row, so [row, column] indexing
    # below never fails.
    fig, axes = plt.subplots(
        num_rows, width * num_images_in_batch, figsize=(width * 4, num_rows * 4), squeeze=False
    )
    # Hide all frames first so grid cells without a grain stay blank.
    for grid_ax in axes.ravel():
        grid_ax.axis("off")

    for i, (image, mask, grain_index, p_to_nm, path) in enumerate(zip(images, masks, grain_indexes, px_to_nms, paths)):
        row, pair = divmod(i, width)
        # Plot image
        im_ax = axes[row, pair * num_images_in_batch]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        im_ax.set_title(f"Grain {grain_index} {p_to_nm} p/nm")
        # Plot mask with the path on top
        mask_ax = axes[row, pair * num_images_in_batch + 1]
        mask_ax.imshow(mask.astype(int))
        mask_ax.plot(path[:, 1], path[:, 0], color="white")

    fig.tight_layout()
    plt.show()


# Smooth every path with a moving average and store the result under "pooled_path".
pooled_path_grain_dict = {}
# Number of points to pool
n = 6
# Each window spans `half_window` points before the centre and `half_window - 1` after it,
# i.e. path[centre - half_window : centre + half_window]. (For odd `n` this uses n - 1 points;
# the previous index arithmetic wrapped round to path[-1] in that case.)
half_window = n // 2
for index, grain_dict in grain_dict_paths.items():
    path = grain_dict["path"]

    window_means = [
        path[centre - half_window : centre + half_window].mean(axis=0)
        for centre in range(half_window, len(path) - half_window)
    ]
    # reshape keeps the array (0, 2)-shaped when the path is shorter than the window, so the
    # concatenation below still works and the pooled path is just the two end points.
    window_means = np.array(window_means).reshape(-1, path.shape[1])

    # Add start and end points to the path
    pooled_path = np.concatenate([path[:1], window_means, path[-1:]])

    pooled_path = flip_if_anticlockwise(pooled_path)

    # Store a shallow copy so re-running this cell never mutates `grain_dict_paths`.
    pooled_path_grain_dict[index] = {**grain_dict, "pooled_path": pooled_path}

plot_images(
    [entry["image"] for entry in pooled_path_grain_dict.values()],
    [entry["predicted_mask"] for entry in pooled_path_grain_dict.values()],
    list(pooled_path_grain_dict),
    [entry["p_to_nm"] for entry in pooled_path_grain_dict.values()],
    [entry["pooled_path"] for entry in pooled_path_grain_dict.values()],
)

In [None]:
# Calculate binding angle using the vector of the last n nanometres of the path


def angle_diff_signed(v1: np.ndarray, v2: np.ndarray):
    """Calculate the signed angle difference between two 2D vectors.

    The magnitude is the angle between the vectors (0 to pi). The sign is negative when the
    scalar cross product ``v1[0] * v2[1] - v1[1] * v2[0]`` is positive, and positive otherwise.

    Parameters
    ----------
    v1: np.ndarray
        The first vector (two components).
    v2: np.ndarray
        The second vector (two components).

    Returns
    -------
    float
        The signed angle difference between the two vectors, in radians. The result is NaN if
        either vector has zero length.
    """
    # Unsigned angle from the normalised dot product; the clip guards arccos against values
    # that drift just outside [-1, 1] through rounding.
    cos_theta = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
    angle = np.arccos(np.clip(cos_theta, -1.0, 1.0))

    # Scalar 2D cross product, written out because np.cross on 2-element vectors is deprecated
    # in recent NumPy versions.
    cross = v1[0] * v2[1] - v1[1] * v2[0]

    # A positive cross product flips the sign of the angle.
    if cross > 0:
        angle = -angle

    return angle


def angle_per_nm_non_closed(trace: np.ndarray, p_to_nm: float, plot: bool = False) -> tuple:
    """Calculate the turning angle, and the turning angle per nm, along a non-closed trace.

    For every interior point the signed angle between the incoming segment (previous point to
    this point) and the outgoing segment (this point to next point) is computed with
    ``angle_diff_signed``. It is divided by the incoming segment length multiplied by
    ``p_to_nm``. The first and last points have only one neighbour, so both of their values
    are 0; a trace with fewer than three points therefore yields all zeros.

    Parameters
    ----------
    trace: np.ndarray
        The trace to calculate the angle per nm of, shape (N, 2).
    p_to_nm: float
        The pixel to nm scaling factor.
    plot: bool
        If True, draw the points, the incoming (red) and outgoing (blue) directions and the
        angle in degrees at every point.

    Returns
    -------
    tuple
        ``(angles_per_nm, angle_diffs)``: two arrays of length N holding the angle change per
        nm and the raw angle change (radians) at each point of the trace.
    """
    num_points = len(trace)
    last_index = num_points - 1
    angles_per_nm = np.zeros(num_points)
    angle_diffs = np.zeros(num_points)

    if plot:
        _, ax = plt.subplots(1, 1, figsize=(10, 10))

    for index, point in enumerate(trace):
        if plot:
            ax.scatter(point[1], point[0], c="purple", s=20)

        # The trace is open, so the end points do not wrap round to the other end.
        v_prev = point - trace[index - 1] if index > 0 else None
        v_next = trace[index + 1] - point if index < last_index else None

        if v_prev is not None and v_next is not None:
            angle = angle_diff_signed(v_prev, v_next)
            distance = np.linalg.norm(v_prev) * p_to_nm
            angles_per_nm[index] = angle / distance
            angle_diffs[index] = angle
        else:
            # End point: no turn is defined, the preallocated zeros stay in place.
            angle = 0

        if plot:
            # Draw whichever neighbour directions exist as short arrows of length 0.1.
            for vector, colour in ((v_prev, "r"), (v_next, "b")):
                if vector is None:
                    continue
                scaled = vector / np.linalg.norm(vector) * 0.1
                ax.arrow(
                    point[1], point[0], scaled[1], scaled[0], head_width=0.01, head_length=0.2, fc=colour, ec=colour
                )
            # Write text for the angle
            ax.text(point[1], point[0], f"{np.degrees(angle):.2f}", fontsize=12, color="black")

    if plot:
        ax.plot(trace[:, 1], trace[:, 0], "k")
        plt.show()

    return angles_per_nm, angle_diffs


def plot_images(
    images: list,
    masks: list,
    grain_indexes: list,
    px_to_nms: list,
    paths: list,
    binding_vector_positions: list,
    width=5,
    VMIN=VMIN,
    VMAX=VMAX,
    cmap=CMAP,
):
    """Show every grain image next to its mask, with the grain's path drawn on the mask.

    Parameters
    ----------
    images: list
        2D arrays drawn with ``imshow`` using ``cmap``, ``VMIN`` and ``VMAX``.
    masks: list
        Mask arrays; each is cast to ``int`` and drawn to the right of its image.
    grain_indexes: list
        Identifiers written into the image titles.
    px_to_nms: list
        Scaling values written into the image titles.
    paths: list
        One (N, 2) array per grain; column 0 is plotted on the vertical axis and column 1 on
        the horizontal axis, in white, over the mask.
    binding_vector_positions: list
        One entry per grain. The entries are iterated alongside the other lists (so a shorter
        list limits how many grains are drawn) but nothing is drawn from them yet.
    width: int
        Number of image/mask pairs per grid row.
    VMIN, VMAX: float
        Colour-scale limits for the images.
    cmap
        Colormap used for the images.
    """
    num_images = len(images)
    num_images_in_batch = 2
    # Ceiling division gives exactly the rows needed; keep at least one so subplots() gets a valid grid.
    num_rows = max(1, (num_images + width - 1) // width)
    # squeeze=False keeps `axes` two-dimensional even for a single row, so [row, column] indexing
    # below never fails.
    fig, axes = plt.subplots(
        num_rows, width * num_images_in_batch, figsize=(width * 4, num_rows * 4), squeeze=False
    )
    # Hide all frames first so grid cells without a grain stay blank.
    for grid_ax in axes.ravel():
        grid_ax.axis("off")

    for i, (image, mask, grain_index, p_to_nm, path, binding_vector_position_tuple) in enumerate(
        zip(images, masks, grain_indexes, px_to_nms, paths, binding_vector_positions)
    ):
        row, pair = divmod(i, width)
        # Plot image
        im_ax = axes[row, pair * num_images_in_batch]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        im_ax.set_title(f"Grain {grain_index} {p_to_nm} p/nm")
        # Plot mask with the path on top
        mask_ax = axes[row, pair * num_images_in_batch + 1]
        mask_ax.imshow(mask.astype(int))
        mask_ax.plot(path[:, 1], path[:, 0], color="white")

    fig.tight_layout()
    plt.show()


# Binding angle: resample each pooled path, sum the signed turning angles along it and
# subtract pi. The smoothed per-nm angle profile is stored alongside.
binding_angle_grain_dict = {}

for index, grain_dict in pooled_path_grain_dict.items():
    pooled_path = grain_dict["pooled_path"]
    p_to_nm = grain_dict["p_to_nm"]

    # Sample points every n nm: keep a point only once it lies further than the sampling
    # distance (converted to pixels) from the last point that was kept.
    nm_sample_distance = 1
    pixel_sample_distance = nm_sample_distance / p_to_nm
    nm_sampled_path = []
    for candidate in pooled_path:
        if not nm_sampled_path or np.linalg.norm(candidate - nm_sampled_path[-1]) > pixel_sample_distance:
            nm_sampled_path.append(candidate)
    nm_sampled_path = np.array(nm_sampled_path)

    # Get angle diffs
    angles_per_nm, angle_diffs = angle_per_nm_non_closed(nm_sampled_path, p_to_nm, plot=False)

    # Smooth the angles per nm (Gaussian, sigma = 3 samples)
    angles_per_nm = gaussian_filter1d(angles_per_nm, 3)

    total_angle_diff_over_path = np.sum(angle_diffs)
    angle_diff_between_start_end = total_angle_diff_over_path - np.pi

    # The stored entry is the same dictionary object as in `pooled_path_grain_dict`.
    grain_dict.update(
        {
            "binding_angle": angle_diff_between_start_end,
            "angles_per_nm": angles_per_nm,
        }
    )
    binding_angle_grain_dict[index] = grain_dict

if PLOT_RESULTS:
    # One density plot per summary statistic
    binding_angles = np.array([entry["binding_angle"] for entry in binding_angle_grain_dict.values()])
    mean_angles_per_nm = [np.mean(entry["angles_per_nm"]) for entry in binding_angle_grain_dict.values()]
    std_angles_per_nm = [np.std(entry["angles_per_nm"]) for entry in binding_angle_grain_dict.values()]

    density_plots = (
        (binding_angles, f"Binding angles for {SAMPLE_TYPE} (n = {len(binding_angles)})"),
        (mean_angles_per_nm, f"Mean angles per nm for {SAMPLE_TYPE} (n = {len(mean_angles_per_nm)})"),
        (
            std_angles_per_nm,
            f"Standard deviation of angles per nm for {SAMPLE_TYPE} (n = {len(std_angles_per_nm)})",
        ),
    )
    for values, title in density_plots:
        sns.kdeplot(values)
        plt.title(title)
        plt.show()

In [None]:
# Get connection regions as a proportion of perimeter
#
# For every grain the outline of the filled value-2 ("gem") region is walked pixel by pixel.
# The positions along that walk at which the two intersection regions occur give the shorter
# and longer outline distances between the intersection midpoints, and their ratio.

perimeter_intersection_grain_dict = {}

for index, grain_dict in binding_angle_grain_dict.items():
    try:
        mask = grain_dict["predicted_mask"]
        p_to_nm = grain_dict["p_to_nm"]
        intersection_labels = grain_dict["intersection_labels"]

        # Get perimeter of the gem region
        gem_region = mask == 2
        gem_region = binary_fill_holes(gem_region.astype(bool))

        # Ring of pixels added by one dilation step, i.e. just outside the filled region
        gem_outline = binary_dilation(gem_region) ^ gem_region
        gem_outline[gem_region == 1] = 0

        # Skeletonize to get rid of any rogue corners
        gem_outline = skeletonize(gem_outline, method="lee").astype(bool)
        intersection_labels = label(skeletonize(intersection_labels.astype(bool), method="lee").astype(bool))

        # 1 = plain outline pixel, 2 = intersection pixel
        gem_region_with_intersections = np.copy(gem_outline).astype(int)
        gem_region_with_intersections[intersection_labels > 0] = 2

        # Trace the outline, starting at the first plain outline pixel in row-major order.
        # Visited pixels are overwritten with -1 in `gem_tracing_hist`.
        gem_tracing_hist = np.copy(gem_region_with_intersections)
        plain_outline_pixels = np.argwhere(gem_tracing_hist == 1)
        if len(plain_outline_pixels) == 0:
            raise ValueError("no gem outline pixels found")
        start_point = plain_outline_pixels[0]

        point_coords = start_point
        gem_trace = []
        intersection_regions = []
        in_intersection = False
        intersection_start = None
        len_gem_outline = len(np.argwhere(gem_tracing_hist > 0))

        for tracing_index in range(len_gem_outline - 1):
            # Add point to trace
            gem_trace.append(point_coords)

            # Get neighbours. `neighbours` is a view, so marking the current pixel as visited on
            # the next line also removes it from the 3x3 neighbourhood before it is searched.
            neighbours = gem_tracing_hist[
                point_coords[0] - 1 : point_coords[0] + 2, point_coords[1] - 1 : point_coords[1] + 2
            ]
            gem_tracing_hist[point_coords[0], point_coords[1]] = -1
            # Grab an index of a non zero point in neighbourhood
            nonzero_neighbours = np.argwhere(neighbours > 0)

            if tracing_index == 0:
                # The start pixel has a neighbour on either side; take the first and walk that way.
                point_relative_coords = nonzero_neighbours[0]
            else:
                if len(nonzero_neighbours) > 1:
                    raise ValueError(
                        f"Point {point_coords} has more than one neighbour: {neighbours} {nonzero_neighbours}"
                    )

                point_relative_coords = nonzero_neighbours[0]

            # Neighbourhood indices run 0..2; shift them to offsets -1..1
            point_relative_coords -= 1

            point_coords = point_coords + point_relative_coords
            point_value = gem_tracing_hist[point_coords[0], point_coords[1]]

            # Record the trace positions at which each run of intersection pixels starts and ends.
            # NOTE(review): `point_value` belongs to the *next* pixel (position tracing_index + 1),
            # so `tracing_index - 1` looks one short of the last intersection pixel -- confirm.
            if point_value == 2 and not in_intersection:
                in_intersection = True
                intersection_start = tracing_index + 1
            elif point_value == 2:
                pass
            elif point_value == 1 and in_intersection:
                intersection_end = tracing_index - 1
                intersection_regions.append(
                    {
                        "start": intersection_start,
                        "end": intersection_end,
                        "midpoint": intersection_start + (intersection_end - intersection_start) // 2,
                    }
                )
                in_intersection = False

        # The walk finished while still inside an intersection: close that run.
        if in_intersection:
            intersection_end = tracing_index - 1
            intersection_regions.append(
                {
                    "start": intersection_start,
                    "end": intersection_end,
                    "midpoint": intersection_start + (intersection_end - intersection_start) // 2,
                }
            )
            in_intersection = False

        if len(intersection_regions) < 2:
            raise ValueError(f"Fewer than two intersection regions: {intersection_regions}")
        if len(intersection_regions) > 2:
            raise ValueError(f"More than two intersection regions: {intersection_regions}")

        # Distances between intersection midpoints
        # Inner distance (where the end of the trace is not encountered)
        inner_distance_pixels = intersection_regions[1]["midpoint"] - intersection_regions[0]["midpoint"]
        outer_distance_pixels = len(gem_trace) - inner_distance_pixels
        min_midpoint_distance_pixels = np.min([inner_distance_pixels, outer_distance_pixels])
        max_midpoint_distance_pixels = np.max([inner_distance_pixels, outer_distance_pixels])
        min_midpoint_distance_nm = min_midpoint_distance_pixels * p_to_nm
        max_midpoint_distance_nm = max_midpoint_distance_pixels * p_to_nm
        midpoint_distance_ratio = min_midpoint_distance_pixels / max_midpoint_distance_pixels

        # Store a shallow copy so re-running this cell never mutates `binding_angle_grain_dict`.
        perimeter_intersection_grain_dict[index] = {
            **grain_dict,
            "intersection_min_midpoint_distance": min_midpoint_distance_nm,
            "intersection_max_midpoint_distance": max_midpoint_distance_nm,
            "intersection_midpoint_distance_ratio": midpoint_distance_ratio,
        }

    # Grains that cannot be traced are skipped. The actual exception text is reported because
    # several different conditions above raise ValueError (branching outline, wrong number of
    # intersection regions, empty outline).
    except ValueError as error:
        print(f"Index {index} failed: {error}")

    # Raised when the 3x3 neighbourhood holds no unvisited outline pixel, which typically means
    # the outline touches the image border or has a gap.
    except IndexError as error:
        print(f"Index {index} failed while tracing the gem outline (border contact or gap): {error}")

In [None]:
# Elongation metric


def signed_angle_between_vectors(v1, v2):
    """Calculate the signed angle between two 2D vectors.

    The magnitude is the angle between the vectors (0 to pi). The sign is negative when the
    scalar cross product ``v1[0] * v2[1] - v1[1] * v2[0]`` is positive, and positive otherwise.

    Parameters
    ----------
    v1: np.ndarray
        The first vector (two components).
    v2: np.ndarray
        The second vector (two components).

    Returns
    -------
    float
        The signed angle in radians.

    Raises
    ------
    ValueError
        If every component of either vector is 0.
    """
    # Check that neither vector is 0
    if np.all(v1 == 0) or np.all(v2 == 0):
        raise ValueError("One of the vectors is 0")

    # Cosine of the angle from the normalised dot product; the clip guards arccos against
    # values that drift just outside [-1, 1] through rounding.
    cos_theta = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
    angle = np.arccos(np.clip(cos_theta, -1.0, 1.0))

    # Scalar 2D cross product, written out because np.cross on 2-element vectors is deprecated
    # in recent NumPy versions.
    cross_product = v1[0] * v2[1] - v1[1] * v2[0]

    # If the cross product is positive, then the angle is negative
    if cross_product > 0:
        angle = -angle

    return angle


def rotate_points(points: np.ndarray, angle: float):
    """Right-multiply ``points`` by the 2x2 matrix ``[[cos, -sin], [sin, cos]]`` of ``angle``.

    Parameters
    ----------
    points: np.ndarray
        A single 2-component vector or an (N, 2) array of row vectors.
    angle: float
        Angle in radians.

    Returns
    -------
    np.ndarray
        ``np.dot(points, matrix)``, with the same shape as ``points``.
    """
    cos_angle = np.cos(angle)
    sin_angle = np.sin(angle)
    rotation_matrix = np.array([[cos_angle, -sin_angle], [sin_angle, cos_angle]])
    return np.dot(points, rotation_matrix)


def plot_images_with_paths(
    images: list,
    masks: list,
    paths: list,
    centroids: list,
    midpoints: list,
    elongations: list,
    px_to_nms: list,
    bounding_boxes: list,
    aspect_ratios: list,
    width=5,
    save_dir=None,
):
    """Plot each image next to its mask, annotated with the path and its elongation geometry.

    On every mask panel the path is drawn in red, the centroid as a red dot, the start/end
    midpoint as a green dot joined to the centroid by a green line, and the bounding box in
    orange. The panel title reports elongation, aspect ratio and the px-to-nm value.

    Parameters
    ----------
    images, masks: list
        Arrays drawn with ``imshow`` (viridis) in the left / right panel of each pair.
    paths: list
        One (N, 2) array per grain; column 0 is the vertical and column 1 the horizontal axis.
    centroids, midpoints: list
        One 2-component point per grain, in the same axis order as ``paths``.
    elongations, px_to_nms, aspect_ratios: list
        Per-grain values written into the panel title.
    bounding_boxes: list
        One (M, 2) array of box corners per grain, in the same axis order as ``paths``.
    width: int
        Number of image/mask pairs per grid row.
    save_dir
        If not None, passed unchanged to ``savefig`` as the output file (dpi 500).
    """
    num_images = len(images)
    num_images_in_batch = 2
    # Ceiling division gives exactly the rows needed; keep at least one so subplots() gets a valid grid.
    num_rows = max(1, (num_images + width - 1) // width)
    # squeeze=False keeps `ax` two-dimensional even for a single row, so [row, column] indexing
    # below never fails.
    fig, ax = plt.subplots(
        num_rows,
        width * num_images_in_batch,
        figsize=(width * 2 * num_images_in_batch, num_rows * 4),
        squeeze=False,
    )
    # Hide all frames first so grid cells without a grain stay blank.
    for grid_ax in ax.ravel():
        grid_ax.axis("off")

    for i, (image, mask, path) in enumerate(zip(images, masks, paths)):
        row, pair = divmod(i, width)
        im_ax = ax[row, pair * num_images_in_batch]
        im_ax.imshow(image, cmap="viridis")

        mask_ax = ax[row, pair * num_images_in_batch + 1]
        mask_ax.imshow(mask, cmap="viridis")
        mask_ax.plot(path[:, 1], path[:, 0], "r-")
        mask_ax.scatter(centroids[i][1], centroids[i][0], c="r", s=5)
        mask_ax.scatter(midpoints[i][1], midpoints[i][0], c="g", s=5)
        # Draw a line between midpoint and centroid
        mask_ax.plot(
            [centroids[i][1], midpoints[i][1]],
            [centroids[i][0], midpoints[i][0]],
            "g-",
        )
        # Plot the bounding box
        mask_ax.plot(bounding_boxes[i][:, 1], bounding_boxes[i][:, 0], color="orange")
        mask_ax.set_title(
            f"||| elongation: {elongations[i]:.2f} nm | \naspect ratio: {aspect_ratios[i]:.2f} | \npx_to_nm: {px_to_nms[i]} |||",
            fontsize=10,
        )
    fig.tight_layout()

    if save_dir is not None:
        # Save through the figure itself rather than pyplot's "current figure" state.
        fig.savefig(save_dir, dpi=500)


def align_points_to_vertical(points: np.ndarray, orientation_vector: np.ndarray):
    """Rotate ``points`` by the signed angle between ``orientation_vector`` and ``[1, 0]``.

    Parameters
    ----------
    points: np.ndarray
        Points to rotate, passed to ``rotate_points``.
    orientation_vector: np.ndarray
        2-component vector whose angle to the reference vector ``[1, 0]`` sets the rotation.

    Returns
    -------
    tuple
        ``(rotated_points, rotated_orientation_vector, -angle)``; the negated angle can be
        handed back to ``rotate_points`` to undo the rotation.
    """
    reference_direction = np.array([1, 0])
    rotation = signed_angle_between_vectors(orientation_vector, reference_direction)
    return rotate_points(points, rotation), rotate_points(orientation_vector, rotation), -rotation


# Elongation of each pooled path: distance from the path centroid to the midpoint of its two
# ends, plus a bounding box fitted after aligning that centroid-to-midpoint direction.
image_dict_elongation = {}

for index, grain_dict in perimeter_intersection_grain_dict.items():
    p_to_nm = grain_dict["p_to_nm"]
    pooled_path = grain_dict["pooled_path"]

    # Get the centroid of the path and centre the path on it
    path_centroid = np.mean(pooled_path, axis=0)
    path_points_shifted = pooled_path - path_centroid

    # Get the start and end points of the path
    start_point = pooled_path[0]
    end_point = pooled_path[-1]
    start_point_end_point_midpoint = (start_point + end_point) / 2

    # Distance between midpoint and centroid
    elongation_vector = start_point_end_point_midpoint - path_centroid
    distance = np.linalg.norm(elongation_vector)
    if distance == 0:
        # A zero vector has no direction; signed_angle_between_vectors would raise ValueError
        # and abort the whole loop, so skip this grain instead.
        print(f"Index {index} skipped: end-point midpoint coincides with the path centroid")
        continue

    # Rotate the centred points so the elongation vector lines up with the reference direction
    shifted_rotated_points, _, angle_rad = align_points_to_vertical(path_points_shifted, elongation_vector)

    # Find the bounding box of the rotated points ("x" is column 0, "y" is column 1)
    min_x, min_y = shifted_rotated_points.min(axis=0)
    max_x, max_y = shifted_rotated_points.max(axis=0)
    rotated_bounding_box = np.array([[min_x, min_y], [max_x, min_y], [max_x, max_y], [min_x, max_y], [min_x, min_y]])

    # A box with zero extent along column 1 has no finite aspect ratio.
    extent_x = max_x - min_x
    extent_y = max_y - min_y
    aspect_ratio = extent_x / extent_y if extent_y != 0 else np.inf

    # Un-rotate the bounding box and move it back to the original position
    unrotated_bounding_box = rotate_points(rotated_bounding_box, angle_rad) + path_centroid

    # Store a shallow copy so re-running this cell never mutates `perimeter_intersection_grain_dict`.
    image_dict_elongation[index] = {
        **grain_dict,
        "path_centroid": path_centroid,
        "path_start_point": start_point,
        "path_end_point": end_point,
        "path_start_point_end_point_midpoint": start_point_end_point_midpoint,
        "path_elongation_distance_nm": distance * p_to_nm,
        # NOTE(review): stored in pixels (not multiplied by p_to_nm), unlike the elongation distance.
        "path_bounding_box_length": extent_y,
        "path_bounding_box": unrotated_bounding_box,
        "path_aspect_ratio": aspect_ratio,
    }


if PLOT_RESULTS:
    # Plot the paths and centroids
    elongation_entries = list(image_dict_elongation.values())
    plot_images_with_paths(
        images=[entry["image"] for entry in elongation_entries],
        masks=[entry["predicted_mask"] for entry in elongation_entries],
        # NOTE(review): the raw "path" is drawn here although the centroid and midpoint were
        # computed from "pooled_path" -- confirm this is intended.
        paths=[entry["path"] for entry in elongation_entries],
        centroids=[entry["path_centroid"] for entry in elongation_entries],
        midpoints=[entry["path_start_point_end_point_midpoint"] for entry in elongation_entries],
        elongations=[entry["path_elongation_distance_nm"] for entry in elongation_entries],
        px_to_nms=[entry["p_to_nm"] for entry in elongation_entries],
        bounding_boxes=[entry["path_bounding_box"] for entry in elongation_entries],
        aspect_ratios=[entry["path_aspect_ratio"] for entry in elongation_entries],
        save_dir=None,
    )

In [None]:
def is_clockwise(p_1: tuple, p_2: tuple, p_3: tuple) -> bool:
    """Report whether three points make a clockwise (or straight) turn.

    Parameters
    ----------
    p_1: tuple
        First point of the turn; only its first two components are used.
    p_2: tuple
        Second point of the turn; only its first two components are used.
    p_3: tuple
        Third point of the turn; only its first two components are used.

    Returns
    -------
    boolean
        False when the determinant described below is strictly positive (a counter-clockwise
        turn), True otherwise.
    """
    # Each point becomes a row (first, second, 1). The determinant of the resulting 3x3 matrix
    # is > 0 exactly when the three points turn counter-clockwise.
    homogeneous_rows = [(point[0], point[1], 1) for point in (p_1, p_2, p_3)]
    determinant = np.linalg.det(np.array(homogeneous_rows))
    return not determinant > 0


def convex_hull(edge_points: np.ndarray, debug: bool = False):
    """Build the lower and upper hull chains of a set of 2D points.

    The points are sorted by their first coordinate and then by their second, and swept once from the first
    sorted point to the last. Both returned chains therefore start at the first sorted point and end at the
    last one. The lower chain drops a point whenever ``is_clockwise`` reports a clockwise or collinear turn;
    the upper chain drops a point only on a strictly counter-clockwise turn, so collinear points stay on it.

    Parameters
    ----------
    edge_points: np.ndarray
        Array-like of shape (N, 2) holding one point per row. Column 0 is treated as x and column 1 as y in
        the diagnostic plots.
    debug: bool
        If True, print the sorted points and draw the sorted points and the growing upper hull with
        matplotlib, each followed by ``plt.show()``. Off by default because ``plt.show()`` would otherwise
        flush whatever figure the caller is in the middle of building.

    Returns
    -------
    lower_hull: np.ndarray
        Points of the lower hull chain, in sorted order.
    upper_hull: np.ndarray
        Points of the upper hull chain, in sorted order.
    """
    points = np.asarray(edge_points)
    if points.size == 0:
        # Nothing to sort or sweep; hand back two empty (0, 2) arrays so callers can still slice columns.
        empty_hull = np.empty((0, 2), dtype=points.dtype)
        return empty_hull, empty_hull.copy()

    # ndarray.sort() would sort inside each row rather than ordering the rows, so order whole rows with
    # lexsort instead: primary key is column 0, ties are broken by column 1.
    sorted_order = np.lexsort((points[:, 1], points[:, 0]))
    sorted_edge_points = points[sorted_order]

    if debug:
        print(sorted_edge_points)
        plt.plot(sorted_edge_points[:, 0], sorted_edge_points[:, 1])
        plt.show()

    # The two chains are built separately (rather than as one closed hull) because callers need to walk the
    # upper and lower sides independently.
    upper_hull = []
    lower_hull = []
    for point in sorted_edge_points:
        while len(lower_hull) > 1 and is_clockwise(lower_hull[-2], lower_hull[-1], point):
            lower_hull.pop()
        lower_hull.append(point)
        while len(upper_hull) > 1 and not is_clockwise(upper_hull[-2], upper_hull[-1], point):
            upper_hull.pop()
        upper_hull.append(point)
        if debug:
            # Trace how the upper chain evolves after every point.
            upper_so_far = np.array(upper_hull)
            plt.plot(upper_so_far[:, 0], upper_so_far[:, 1], color="lightblue")

    if debug:
        plt.show()

    return np.array(lower_hull), np.array(upper_hull)


# left_turn = np.array([
#     [5, 5],
#     [8, 5],
#     [9, 4],
# ])

# left_turn = np.flip(left_turn, axis=1)

# right_turn = np.array([
#     [5, 5],
#     [8, 5],
#     [9, 6],
# ])

# plt.plot(left_turn[:, 1], left_turn[:, 0])
# plt.plot(right_turn[:, 1], right_turn[:, 0])
# plt.xlim(0, 10)
# plt.ylim(0, 10)
# plt.show()

# print(is_clockwise(left_turn[0], left_turn[1], left_turn[2]))
# print(is_clockwise(right_turn[0], right_turn[1], right_turn[2]))

# Hand-made outline (first vertex repeated at the end) used to eyeball the hull construction.
shape = np.array([[5, 5], [8, 10], [12, 20], [15, 7], [10, 5], [5, 5]])

# Raw outline: column 0 on the x axis, column 1 on the y axis.
plt.plot(*shape.T)
plt.show()

lower_hull, upper_hull = convex_hull(edge_points=shape)

# Both hull chains in one figure.
plt.plot(*lower_hull.T, color="lightblue")
plt.plot(*upper_hull.T, color="orange")
plt.show()

In [None]:
# Ferets and internal area


def calculate_edges(grain_mask: np.ndarray):
    """Return the coordinates of the outline pixels of a grain mask.

    Holes inside the mask are filled first, so only the outer boundary of the grain is reported. The outline
    is whatever a one-step binary erosion removes from the filled mask.

    Parameters
    ----------
    grain_mask : np.ndarray
        A 2D numpy array image of a grain. Data in the array must be boolean.

    Returns
    -------
    edges : list
        List with one ``[row, column]`` entry per edge pixel, ordered row by row.
    """
    solid_mask = binary_fill_holes(grain_mask)

    # Erode inside a one pixel border of False values (np.pad's default fill) and cut the border off again,
    # so the erosion result does not depend on how the array boundary is treated.
    shrunk_mask = binary_erosion(np.pad(solid_mask, 1))[1:-1, 1:-1]

    # Outline = pixels of the filled mask that did not survive the erosion.
    outline = solid_mask.astype(int) - shrunk_mask.astype(int)

    # FIXME : building a list of lists should be unnecessary, the (N, 2) index array could be returned
    # directly, but callers currently receive a list.
    return [list(coordinate) for coordinate in np.argwhere(outline)]


def get_triangle_height(base_point_1: np.ndarray, base_point_2: np.ndarray, top_point: np.ndarray) -> float:
    """Returns the height of a triangle defined by the input point vectors.

    Parameters
    ----------
    base_point_1: np.ndarray
        a base point of the triangle, eg: [5, 3].

    base_point_2: np.ndarray
        a base point of the triangle, eg: [8, 3].

    top_point: np.ndarray
        the top point of the triangle, defining the height from the line between the two base points, eg: [6,10].

    Returns
    -------
    Float
        The height of the triangle - ie the shortest distance between the top point and the line between the two
        base points. If the two base points coincide the base has zero length and the result is not finite.
    """
    # Accept lists/tuples as well as arrays; floats also keep unsigned integer input from wrapping around.
    base_point_1 = np.asarray(base_point_1, dtype=float)
    base_point_2 = np.asarray(base_point_2, dtype=float)
    top_point = np.asarray(top_point, dtype=float)

    # Height of triangle = A/b = ||AB X AC|| / ||AB||
    a_b = base_point_1 - base_point_2
    a_c = base_point_1 - top_point
    if a_b.shape == (2,) and a_c.shape == (2,):
        # For 2D vectors the cross product has a single (z) component. Compute it directly: calling
        # np.cross with 2-element vectors is deprecated since NumPy 2.0.
        cross_norm = abs(a_b[0] * a_c[1] - a_b[1] * a_c[0])
    else:
        cross_norm = np.linalg.norm(np.cross(a_b, a_c))
    return cross_norm / np.linalg.norm(a_b)


def get_max_min_ferets(edge_points: list):
    """Returns the minimum and maximum feret diameters for a grain.
    These are defined as the smallest and greatest distances between
    a pair of callipers that are rotating around a 2d object, maintaining
    contact at all times.

    As a side effect the upper hull (green) and lower hull (yellow) are drawn on the
    current matplotlib axes, with column 0 of the points used as x and column 1 as y.

    Parameters
    ----------
    edge_points: list
        a list (or (N, 2) array) of the vector positions of the pixels comprising
        the edge of the grain. Eg: [[0, 0], [1, 0], [2, 1]]

    Returns
    -------
    min_feret: float
        the minimum feret diameter of the grain, or None if the hull offers no pair of
        contact points to rotate between (e.g. a single point).
    max_feret: float
        the maximum feret diameter of the grain, or None in the same degenerate case.
    min_feret_triangle: list
        the three hull points that produced ``min_feret``: the first two lie on one
        calliper (the triangle base) and the third is the opposite contact point.
        None if no triangle was evaluated.

    Notes
    -----
    The method starts out by calculating the upper and lower convex hulls using
    an algorithm based on the Graham Scan Algorithm [1]. Using these upper and
    lower hulls, the callipers are simulated as rotating clockwise around the grain.
    We determine the order in which vertices are encountered by comparing the
    gradients of the slopes between vertices. An array of pairs of points that
    are in contact with either calliper at a given time is created in order to
    be able to calculate the maximum feret diameter. The minimum diameter is a
    little tricky, since it won't simply be the shortest distance between two
    contact points, but it will occur somewhere during the rotation around a
    pair of contact points. It turns out that the point will always be such
    that two points are in contact with one calliper while the other calliper
    is in contact with another point. We can use this fact to be sure of finding
    the smallest feret diameter, simply by testing each triangle of 3 contact points
    as we iterate, finding the height of the triangle that is formed between the
    three aforementioned points, as this will be the perpendicular distance between
    the callipers.

    References
    ----------
    [1] Graham, R.L. (1972).
        "An Efficient Algorithm for Determining the Convex Hull of a Finite Planar Set".
        Information Processing Letters. 1 (4): 132-133.
        doi:10.1016/0020-0190(72)90045-2.
    """
    # The hull construction slices columns, so make sure plain lists work as documented.
    edge_points = np.asarray(edge_points)
    lower_hull, upper_hull = convex_hull(edge_points)

    # Overlay both hull chains on whatever axes the caller has active (x = column 0, y = column 1, the same
    # convention convex_hull uses for its own plots).
    plt.plot(upper_hull[:, 0], upper_hull[:, 1], color="green")
    plt.plot(lower_hull[:, 0], lower_hull[:, 1], color="yellow")

    # Walk the callipers around the hull: the upper index moves forwards while the lower index moves
    # backwards, recording the pair of contact vertices before every step.
    contact_points = []
    upper_index = 0
    lower_index = len(lower_hull) - 1
    last_upper_index = len(upper_hull) - 1
    min_feret = None
    min_feret_triangle = None
    while upper_index < last_upper_index or lower_index > 0:
        contact_points.append([lower_hull[lower_index, :], upper_hull[upper_index, :]])

        if upper_index == last_upper_index:
            # The upper hull is exhausted, continue iterating over the lower hull.
            advance_upper = False
        elif lower_index == 0:
            # The lower hull is exhausted, continue iterating over the upper hull.
            advance_upper = True
        else:
            # Compare the gradient of the next upper hull edge with the gradient of the next lower hull edge
            # to decide which vertex the rotating callipers meet first. The comparison is
            # delta upper_y / delta upper_x > delta lower_y / delta lower_x with the denominators multiplied
            # through, so that no division by zero can occur.
            advance_upper = (upper_hull[upper_index + 1, 1] - upper_hull[upper_index, 1]) * (
                lower_hull[lower_index, 0] - lower_hull[lower_index - 1, 0]
            ) > (lower_hull[lower_index, 1] - lower_hull[lower_index - 1, 1]) * (
                upper_hull[upper_index + 1, 0] - upper_hull[upper_index, 0]
            )

        # The edge that was just stepped over lies along one calliper; the vertex on the other hull is the
        # opposite contact point. The triangle is always stored as [base, base, apex].
        if advance_upper:
            upper_index += 1
            triangle = [
                upper_hull[upper_index - 1, :],
                upper_hull[upper_index, :],
                lower_hull[lower_index, :],
            ]
        else:
            lower_index -= 1
            triangle = [
                lower_hull[lower_index + 1, :],
                lower_hull[lower_index, :],
                upper_hull[upper_index, :],
            ]

        # The height of that triangle is the calliper separation for this orientation.
        small_feret = get_triangle_height(np.array(triangle[0]), np.array(triangle[1]), np.array(triangle[2]))
        if min_feret is None or small_feret < min_feret:
            min_feret = small_feret
            min_feret_triangle = triangle

    # The maximum feret is the largest distance between any recorded pair of contact points.
    max_feret = None
    if contact_points:
        contact_pairs = np.array(contact_points)
        separations = np.sqrt(((contact_pairs[:, 0, :] - contact_pairs[:, 1, :]) ** 2).sum(axis=1))
        max_feret = separations.max()

    return min_feret, max_feret, min_feret_triangle


# Calculate min, max feret diameters for the ring masks

image_dict_ferets = {}

for index, image_dict_item in image_dict_elongation.items():
    image = image_dict_item["image"]
    mask = image_dict_item["predicted_mask"]
    p_to_nm = image_dict_item["p_to_nm"]
    pooled_path = image_dict_item["pooled_path"]

    # Image with the pooled path on top (the path's column 1 is used as x and column 0 as y).
    plt.imshow(image)
    plt.plot(pooled_path[:, 1], pooled_path[:, 0], color="white")
    plt.show()

    # Edge pixels of the region where the mask equals 1, as (row, column) pairs.
    edge_points = np.array(calculate_edges(mask == 1))

    edge_visualisation = np.zeros_like(image)
    edge_visualisation[edge_points[:, 0], edge_points[:, 1]] = 1
    plt.imshow(edge_visualisation)
    plt.show()

    # Swap the columns: from here on every point derived from edge_points is (x, y) = (column, row), so it
    # must be plotted with column 0 as x and column 1 as y to line up with the image.
    edge_points = np.flip(edge_points, axis=1)

    lower_hull, upper_hull = convex_hull(edge_points=edge_points)

    print(lower_hull)

    plt.imshow(image)
    plt.plot(lower_hull[:, 0], lower_hull[:, 1])
    plt.show()

    plt.imshow(image)

    min_feret, max_feret, min_feret_triangle = get_max_min_ferets(edge_points=edge_points)
    plt.show()

    # NOTE: this stores the same dictionary object, so the feret entries also appear in image_dict_elongation.
    image_dict_ferets[index] = image_dict_item
    image_dict_ferets[index]["min_feret"] = min_feret * p_to_nm
    image_dict_ferets[index]["max_feret"] = max_feret * p_to_nm

    # Path (red) together with the triangle that defines the minimum feret: base points in yellow, the
    # opposite contact point in aqua.
    plt.imshow(image)
    plt.scatter(np.array(pooled_path)[:, 1], np.array(pooled_path)[:, 0], c="r", s=5)
    if min_feret_triangle is not None:
        triangle_xy = np.array(min_feret_triangle)
        plt.scatter(triangle_xy[:-1, 0], triangle_xy[:-1, 1], c="y", s=20)
        plt.scatter(triangle_xy[-1, 0], triangle_xy[-1, 1], c="aqua", s=20)

        # Repeat the first vertex so the outline is drawn closed.
        min_feret_triangle = np.append(triangle_xy, [triangle_xy[0]], axis=0)
        plt.plot(min_feret_triangle[:, 0], min_feret_triangle[:, 1], "y", linewidth=2)
    plt.show()

    plt.imshow(image, cmap="viridis")
    # plt.imshow(ring_mask, cmap="viridis")
    # plt.show()
    plt.plot(pooled_path[:, 1], pooled_path[:, 0], color="white")
    plt.show()
    print(f"min feret: {min_feret * p_to_nm}")
    print(f"min feret pixels: {min_feret} px_2_nm: {p_to_nm}")

    # Only the first few entries are processed while this cell is being developed.
    if index > 5:
        break

# sns.histplot(
#     [image_dict_ferets[i]["min_feret"] for i in image_dict_ferets], bins="auto", label="min feret"
# )
# sns.histplot(
#     [image_dict_ferets[i]["max_feret"] for i in image_dict_ferets], bins="auto", label="max feret"
# )
# plt.legend()


# Distribution of the feret diameters (already scaled by p_to_nm) over the processed grains.
min_ferets_nm = [image_dict_ferets[i]["min_feret"] for i in image_dict_ferets]
max_ferets_nm = [image_dict_ferets[i]["max_feret"] for i in image_dict_ferets]

sns.kdeplot(min_ferets_nm, label="min feret")
sns.kdeplot(max_ferets_nm, label="max feret")
plt.legend()
plt.show()

print(f"min ferets: {min_ferets_nm}")
print(f"mean min feret: {np.mean(min_ferets_nm)}")