In [None]:
from pathlib import Path
import re
import pickle
from datetime import datetime

import numpy as np
import shapely
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.ndimage import distance_transform_edt, gaussian_filter1d, binary_fill_holes
from skimage.measure import label, regionprops
from skimage.graph import route_through_array
from skimage.morphology import binary_dilation, binary_erosion, skeletonize

from topostats.plottingfuncs import Colormap

# Shared colour map for every height-image panel, taken from topostats' Colormap helper.
colormap = Colormap()
CMAP = colormap.get_cmap()

# Lower/upper colour-scale limits handed to imshow (vmin/vmax) by the plotting helpers below.
# NOTE(review): presumably in the images' height units (nm?) — confirm.
VMIN = 0
VMAX = 4

In [None]:
# --- Run configuration ---------------------------------------------------------------------------
SAMPLE_TYPE = "OT2_SC"
DATE_TO_READ_FROM = "2024-03-22"
TODAY_DATE = datetime.now().strftime("%Y-%m-%d")
# Input: grains extracted on DATE_TO_READ_FROM. Output: a directory stamped with today's date.
# The input root is spelled "/Users" to match every other path in this notebook (it was "/users",
# which only resolves on a case-insensitive file system).
DATA_DIR = Path(f"/Users/sylvi/topo_data/hariborings/extracted_grains/cas9_{SAMPLE_TYPE}/{DATE_TO_READ_FROM}")
SAVE_DIR = Path(f"/Users/sylvi/topo_data/hariborings/processed_grains/cas9_{SAMPLE_TYPE}/{TODAY_DATE}")
SAVE_DIR.mkdir(exist_ok=True, parents=True)
MAX_PX_TO_NM = 10.0
PLOT_RESULTS = False
MAX_GRAIN_NUMBER = 500

# --- Load the extracted grains -------------------------------------------------------------------
# SECURITY: pickle.load executes arbitrary code embedded in the file; only read pickles written by
# the trusted extraction step.
FILE_PATH = DATA_DIR / "grain_dict.pkl"
with open(FILE_PATH, "rb") as f:
    grain_dicts = pickle.load(f)

print(f"number of images for sample type [{SAMPLE_TYPE}] : {len(grain_dicts)}")

# Cut off the number of grains: keep at most MAX_GRAIN_NUMBER entries, in dictionary order.
# (The previous `i > MAX_GRAIN_NUMBER` test let MAX_GRAIN_NUMBER + 1 grains through.)
grain_dicts_sample = dict(list(grain_dicts.items())[:MAX_GRAIN_NUMBER])

print(f"number of images in sample for {SAMPLE_TYPE} : {len(grain_dicts_sample)}")

# NOTE(review): the cells below iterate `grain_dicts`, not this capped sample, and reuse the name
# `grain_dict` as a loop variable — confirm whether the cap is meant to apply downstream.
grain_dict = grain_dicts_sample

In [None]:
def plot_images(
    images: list, masks: list, grain_indexes: list, px_to_nms: list, width=5, VMIN=VMIN, VMAX=VMAX, cmap=CMAP
):
    """Show each grain's image next to its mask in a grid of paired panels.

    Parameters
    ----------
    images: list
        Grain images, drawn with ``cmap`` and the ``VMIN``/``VMAX`` colour limits.
    masks: list
        Mask arrays, one per image, drawn as integers.
    grain_indexes: list
        Identifiers used in each image panel's title.
    px_to_nms: list
        Per-grain scaling values shown in each image panel's title.
    width: int
        Number of image/mask pairs per row.
    VMIN, VMAX: float
        Colour-scale limits for the image panels.
    cmap
        Colour map for the image panels.
    """
    num_images = len(images)
    num_images_in_batch = 2
    # Ceiling division so a full last row does not add an empty one; always at least one row.
    num_rows = max(1, -(-num_images // width))
    # squeeze=False keeps `axes` two-dimensional even for a single row, so axes[row, col] is valid.
    fig, axes = plt.subplots(
        num_rows, width * num_images_in_batch, figsize=(width * 4, num_rows * 4), squeeze=False
    )
    # Hide every frame up front so panels without a grain stay blank.
    for panel in axes.ravel():
        panel.axis("off")

    for i, (image, mask, grain_index, p_to_nm) in enumerate(zip(images, masks, grain_indexes, px_to_nms)):
        row, pair = divmod(i, width)
        # Plot image
        im_ax = axes[row, pair * num_images_in_batch]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        im_ax.set_title(f"Grain {grain_index} {p_to_nm} p/nm")
        # Plot mask
        mask_ax = axes[row, pair * num_images_in_batch + 1]
        mask_ax.imshow(mask.astype(int))

    fig.tight_layout()
    plt.show()


if PLOT_RESULTS:
    # Build the parallel per-grain lists in dictionary order and hand them to the grid plotter.
    grain_indexes = list(grain_dicts)
    images = [grain["image"] for grain in grain_dicts.values()]
    masks = [grain["predicted_mask"] for grain in grain_dicts.values()]
    px_to_nms = [grain["p_to_nm"] for grain in grain_dicts.values()]
    plot_images(images, masks, grain_indexes, px_to_nms)

# Pathfinding

In [None]:
def flip_if_anticlockwise(trace: np.ndarray):
    """Reverse the point order of a 2-D trace when its summed turning direction is positive.

    For every pair of consecutive segments the 2-D cross product of their unit direction vectors
    is accumulated. A positive total reverses the trace; a negative or zero total (for example a
    straight line) leaves it untouched.

    NOTE(review): the original comments label a positive sum "clockwise" although the function
    name says "anticlockwise". Callers plot column 0 as the vertical axis, so the apparent
    handedness depends on the plotting convention — confirm which is intended.

    Parameters
    ----------
    trace: np.ndarray
        Ordered N x 2 points. Integer input is accepted; the computation is done in floating point.

    Returns
    -------
    np.ndarray
        ``trace`` itself, or ``np.flip(trace, axis=0)`` when the summed cross product is positive.
    """
    points = np.asarray(trace, dtype=float)
    if len(points) < 3:
        # Fewer than two segments: there is no turn to measure.
        return trace

    segments = np.diff(points, axis=0)
    lengths = np.linalg.norm(segments, axis=1)
    # Repeated points give zero-length segments; dividing by their length would yield NaN and
    # silently disable the flip, so they are skipped.
    moving = lengths > 0
    directions = segments[moving] / lengths[moving][:, None]
    if len(directions) < 2:
        return trace

    # Explicit scalar cross product of consecutive directions (np.cross on 2-D vectors is
    # deprecated in NumPy 2.0).
    cross_products = directions[:-1, 0] * directions[1:, 1] - directions[:-1, 1] * directions[1:, 0]
    cross_sum = cross_products.sum()

    if cross_sum > 0:
        # Reverse the trace
        trace = np.flip(trace, axis=0)

    return trace


def plot_images(
    images: list,
    masks: list,
    grain_indexes: list,
    px_to_nms: list,
    paths: list,
    width=5,
    VMIN=VMIN,
    VMAX=VMAX,
    cmap=CMAP,
):
    """Show each grain's image next to its mask, with the grain's path drawn over the mask.

    Parameters
    ----------
    images: list
        Grain images, drawn with ``cmap`` and the ``VMIN``/``VMAX`` colour limits.
    masks: list
        Mask arrays, one per image, drawn as integers.
    grain_indexes: list
        Identifiers used in each image panel's title.
    px_to_nms: list
        Per-grain scaling values shown in each image panel's title.
    paths: list
        N x 2 arrays; column 1 is plotted horizontally and column 0 vertically.
    width: int
        Number of image/mask pairs per row.
    VMIN, VMAX: float
        Colour-scale limits for the image panels.
    cmap
        Colour map for the image panels.
    """
    num_images = len(images)
    num_images_in_batch = 2
    # Ceiling division so a full last row does not add an empty one; always at least one row.
    num_rows = max(1, -(-num_images // width))
    # squeeze=False keeps `axes` two-dimensional even for a single row, so axes[row, col] is valid.
    fig, axes = plt.subplots(
        num_rows, width * num_images_in_batch, figsize=(width * 4, num_rows * 4), squeeze=False
    )
    # Hide every frame up front so panels without a grain stay blank.
    for panel in axes.ravel():
        panel.axis("off")

    for i, (image, mask, grain_index, p_to_nm, path) in enumerate(zip(images, masks, grain_indexes, px_to_nms, paths)):
        row, pair = divmod(i, width)
        # Plot image
        im_ax = axes[row, pair * num_images_in_batch]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        im_ax.set_title(f"Grain {grain_index} {p_to_nm} p/nm")
        # Plot mask with the path on top
        mask_ax = axes[row, pair * num_images_in_batch + 1]
        mask_ax.imshow(mask.astype(int))
        mask_ax.plot(path[:, 1], path[:, 0], color="white")

    fig.tight_layout()
    plt.show()


grain_dict_paths = {}

for index, grain_dict in grain_dicts.items():
    image = grain_dict["image"]
    mask = grain_dict["predicted_mask"]
    p_to_nm = grain_dict["p_to_nm"]

    # Distance of every mask pixel to the nearest background pixel, with class-2 pixels zeroed
    # so the route avoids them (class 2 is treated as the gem region later in this notebook).
    distance_transform = distance_transform_edt(mask > 0)
    distance_transform[mask == 2] = 0

    # The route runs between the first two labelled intersection regions; within each region
    # the pixel with the largest distance-transform value is used as the endpoint.
    intersection_labels = label(grain_dict["intersection_labels"])
    intersection_regions = regionprops(intersection_labels)
    endpoint_regions = (intersection_regions[0], intersection_regions[1])
    start_point, end_point = (
        tuple(region.coords[np.argmax(distance_transform[region.coords[:, 0], region.coords[:, 1]])])
        for region in endpoint_regions
    )

    # Turn depth into cost: deep pixels are cheap, and pixels at the maximum cost (distance
    # transform of zero) are made prohibitively expensive.
    inverted_distance_transform = distance_transform.max() - distance_transform
    inverted_distance_transform[inverted_distance_transform == inverted_distance_transform.max()] = 1000

    route, weight = route_through_array(inverted_distance_transform, start_point, end_point)
    route = flip_if_anticlockwise(np.array(route).astype(float))

    # The grain's own dictionary is extended in place and registered under the same key.
    grain_dict["path"] = route
    grain_dict_paths[index] = grain_dict

plot_images(
    [grain["image"] for grain in grain_dict_paths.values()],
    [grain["predicted_mask"] for grain in grain_dict_paths.values()],
    list(grain_dict_paths),
    [grain["p_to_nm"] for grain in grain_dict_paths.values()],
    [grain["path"] for grain in grain_dict_paths.values()],
)

# Pooling

In [None]:
# Pool the traces
def plot_images(
    images: list,
    masks: list,
    grain_indexes: list,
    px_to_nms: list,
    paths: list,
    width=5,
    VMIN=VMIN,
    VMAX=VMAX,
    cmap=CMAP,
):
    """Show each grain's image next to its mask, with the grain's path drawn over the mask.

    Parameters
    ----------
    images: list
        Grain images, drawn with ``cmap`` and the ``VMIN``/``VMAX`` colour limits.
    masks: list
        Mask arrays, one per image, drawn as integers.
    grain_indexes: list
        Identifiers used in each image panel's title.
    px_to_nms: list
        Per-grain scaling values shown in each image panel's title.
    paths: list
        N x 2 arrays; column 1 is plotted horizontally and column 0 vertically.
    width: int
        Number of image/mask pairs per row.
    VMIN, VMAX: float
        Colour-scale limits for the image panels.
    cmap
        Colour map for the image panels.
    """
    num_images = len(images)
    num_images_in_batch = 2
    # Ceiling division so a full last row does not add an empty one; always at least one row.
    num_rows = max(1, -(-num_images // width))
    # squeeze=False keeps `axes` two-dimensional even for a single row, so axes[row, col] is valid.
    fig, axes = plt.subplots(
        num_rows, width * num_images_in_batch, figsize=(width * 4, num_rows * 4), squeeze=False
    )
    # Hide every frame up front so panels without a grain stay blank.
    for panel in axes.ravel():
        panel.axis("off")

    for i, (image, mask, grain_index, p_to_nm, path) in enumerate(zip(images, masks, grain_indexes, px_to_nms, paths)):
        row, pair = divmod(i, width)
        # Plot image
        im_ax = axes[row, pair * num_images_in_batch]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        im_ax.set_title(f"Grain {grain_index} {p_to_nm} p/nm")
        # Plot mask with the path on top
        mask_ax = axes[row, pair * num_images_in_batch + 1]
        mask_ax.imshow(mask.astype(int))
        mask_ax.plot(path[:, 1], path[:, 0], color="white")

    fig.tight_layout()
    plt.show()


pooled_path_grain_dict = {}
# Number of points to pool
n = 6
# Each pooled point averages the window path[centre - n // 2 : centre + n // 2].
# NOTE(review): n is assumed to be even; confirm before changing it to an odd value.
half_window = n // 2
for index, grain_dict in grain_dict_paths.items():
    path = grain_dict["path"]

    # Moving average over the interior of the path. The window never runs past the end of the
    # path, so no bounds check is needed inside it.
    pooled_points = [
        np.mean(path[centre - half_window : centre + half_window], axis=0)
        for centre in range(half_window, len(path) - half_window)
    ]
    # reshape keeps an N x 2 layout even when the path is too short to pool any window;
    # otherwise the empty array is one-dimensional and the concatenation below fails.
    pooled_path = np.array(pooled_points, dtype=float).reshape(-1, path.shape[1])

    # Add start and end points to the path so the pooled trace still reaches both ends
    pooled_path = np.concatenate([path[:1], pooled_path, path[-1:]])

    pooled_path = flip_if_anticlockwise(pooled_path)

    # The grain's own dictionary is extended in place and registered under the same key.
    grain_dict["pooled_path"] = pooled_path
    pooled_path_grain_dict[index] = grain_dict

plot_images(
    [pooled_path_grain_dict[i]["image"] for i in pooled_path_grain_dict],
    [pooled_path_grain_dict[i]["predicted_mask"] for i in pooled_path_grain_dict],
    [i for i in pooled_path_grain_dict],
    [pooled_path_grain_dict[i]["p_to_nm"] for i in pooled_path_grain_dict],
    [pooled_path_grain_dict[i]["pooled_path"] for i in pooled_path_grain_dict],
)

# Ferets - based on path not mask

In [None]:
# Ferets and internal area
from topostats.measure.feret import get_feret_from_mask

# Calculate min, max feret diameters for the ring masks

image_dict_ferets = {}

for index, image_dict_item in pooled_path_grain_dict.items():
    image = image_dict_item["image"]
    mask = image_dict_item["predicted_mask"]
    p_to_nm = image_dict_item["p_to_nm"]
    pooled_path = image_dict_item["pooled_path"]
    path = image_dict_item["path"]

    # Rasterise the (un-pooled) path into a binary mask shaped like the grain mask, so the Feret
    # diameters describe the path rather than the segmented ring.
    path_mask = np.zeros_like(mask)
    path_rows = path[:, 0].astype(int)
    path_cols = path[:, 1].astype(int)
    path_mask[path_rows, path_cols] = 1
    results = get_feret_from_mask(mask_im=path_mask)

    min_feret, max_feret = results["min_feret"], results["max_feret"]
    # The coordinates are not stored; they are only read for ad-hoc overlay plots.
    min_feret_coords, max_feret_coords = results["min_feret_coords"], results["max_feret_coords"]

    # Store the diameters scaled by p_to_nm on the grain's own dictionary.
    image_dict_item["min_feret"] = min_feret * p_to_nm
    image_dict_item["max_feret"] = max_feret * p_to_nm
    image_dict_ferets[index] = image_dict_item


# NOTE(review): this plot is always drawn; it is not gated on PLOT_RESULTS.
if True:
    for feret_key, feret_label in (("min_feret", "min feret"), ("max_feret", "max feret")):
        sns.kdeplot([grain[feret_key] for grain in image_dict_ferets.values()], label=feret_label)
    plt.legend()
    plt.title(f"Feret diameters for {SAMPLE_TYPE} (n = {len(image_dict_ferets)})")
    plt.show()

# Save the grains with ferets and path

In [None]:
# save the ferets
today_date = datetime.now().strftime("%Y-%m-%d")
FERETS_SAVE_PATH = Path(f"/Users/sylvi/topo_data/hariborings/extracted_grains/cas9_{SAMPLE_TYPE}/date_{today_date}")
# This dated directory differs from SAVE_DIR and is not created anywhere else, so make sure it
# exists before opening the file for writing.
FERETS_SAVE_PATH.mkdir(parents=True, exist_ok=True)
file_path = FERETS_SAVE_PATH / "ferets_dict.pkl"
print(f"Saving ferets to {file_path}")
with open(file_path, "wb") as f:
    pickle.dump(image_dict_ferets, f)

In [None]:
# Calculate binding angle using the vector of the last n nanometres of the path


def angle_diff_signed(v1: np.ndarray, v2: np.ndarray):
    """Calculate the signed angle difference between two 2-D vectors.

    The magnitude is the angle between the vectors, in radians within [0, pi]. The sign is taken
    from the 2-D cross product ``v1[0] * v2[1] - v1[1] * v2[0]``: the result is negative when that
    cross product is positive and non-negative otherwise.

    Parameters
    ----------
    v1: np.ndarray
        The first vector (two components).
    v2: np.ndarray
        The second vector (two components).

    Returns
    -------
    float
        The signed angle difference between the two vectors, in radians. A zero-length input
        gives NaN because the cosine is 0 / 0.
    """
    v1 = np.asarray(v1, dtype=float)
    v2 = np.asarray(v2, dtype=float)

    # Unsigned angle from the normalised dot product; clipping guards against rounding that would
    # push the cosine just outside [-1, 1].
    cos_angle = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
    angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))

    # Scalar 2-D cross product written out explicitly (np.cross on 2-D vectors is deprecated in
    # NumPy 2.0).
    cross = v1[0] * v2[1] - v1[1] * v2[0]

    # A positive cross product flips the sign of the angle.
    if cross > 0:
        angle = -angle

    return angle


def angle_per_nm_non_closed(trace: np.ndarray, p_to_nm: float, plot: bool = False) -> tuple:
    """Calculate the angle per nm of a non-closed trace.

    For every interior point the signed turning angle between the incoming and outgoing segments
    is computed with ``angle_diff_signed`` and divided by the length of the incoming segment
    multiplied by ``p_to_nm``. The first and last points have no turn defined on an open trace,
    so both of their outputs are zero.

    Parameters
    ----------
    trace: np.ndarray
        Ordered N x 2 points of the trace.
    p_to_nm: float
        The pixel to nm scaling factor (multiplies pixel distances).
    plot: bool
        If True, draw the points, the incoming (red) and outgoing (blue) direction arrows and the
        angle in degrees at every point.

    Returns
    -------
    tuple
        ``(angles_per_nm, angle_diffs)``: two arrays of length N holding the angle change per nm
        and the raw signed angle (radians) at each point.
    """
    trace = np.asarray(trace, dtype=float)
    num_points = len(trace)

    angles_per_nm = np.zeros(num_points)
    angle_diffs = np.zeros(num_points)

    if plot:
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))

    for index, point in enumerate(trace):
        # Vectors from the previous point and to the next one. The indices wrap around at both
        # ends; the wrapped vectors are only used for the arrows drawn when plotting.
        v_prev = point - trace[index - 1]
        v_next = trace[(index + 1) % num_points] - point

        if index == 0 or index == num_points - 1:
            # Endpoints: written as plain zeros rather than 0 / distance, so coincident first and
            # last points cannot inject NaN.
            angle = 0.0
            angle_per_nm = 0.0
        else:
            angle = angle_diff_signed(v_prev, v_next)
            distance = np.linalg.norm(v_prev) * p_to_nm
            angle_per_nm = angle / distance

        if plot:
            ax.scatter(point[1], point[0], c="purple", s=20)
            # Fixed-length (0.1) direction arrows; zero-length vectors have no direction to draw.
            for vector, colour in ((v_prev, "r"), (v_next, "b")):
                length = np.linalg.norm(vector)
                if length > 0:
                    arrow = vector / length * 0.1
                    ax.arrow(
                        point[1], point[0], arrow[1], arrow[0], head_width=0.01, head_length=0.2, fc=colour, ec=colour
                    )
            # Write text for the angle
            ax.text(point[1], point[0], f"{np.degrees(angle):.2f}", fontsize=12, color="black")

        angles_per_nm[index] = angle_per_nm
        angle_diffs[index] = angle

    if plot:
        ax.plot(trace[:, 1], trace[:, 0], "k")
        plt.show()

    return angles_per_nm, angle_diffs


def plot_images(
    images: list,
    masks: list,
    grain_indexes: list,
    px_to_nms: list,
    paths: list,
    binding_vector_positions: list,
    width=5,
    VMIN=VMIN,
    VMAX=VMAX,
    cmap=CMAP,
):
    """Show each grain's image next to its mask, with the grain's path drawn over the mask.

    Parameters
    ----------
    images: list
        Grain images, drawn with ``cmap`` and the ``VMIN``/``VMAX`` colour limits.
    masks: list
        Mask arrays, one per image, drawn as integers.
    grain_indexes: list
        Identifiers used in each image panel's title.
    px_to_nms: list
        Per-grain scaling values shown in each image panel's title.
    paths: list
        N x 2 arrays; column 1 is plotted horizontally and column 0 vertically.
    binding_vector_positions: list
        One entry per grain. It is iterated alongside the other lists but nothing is drawn from
        it yet.
    width: int
        Number of image/mask pairs per row.
    VMIN, VMAX: float
        Colour-scale limits for the image panels.
    cmap
        Colour map for the image panels.
    """
    num_images = len(images)
    num_images_in_batch = 2
    # Ceiling division so a full last row does not add an empty one; always at least one row.
    num_rows = max(1, -(-num_images // width))
    # squeeze=False keeps `axes` two-dimensional even for a single row, so axes[row, col] is valid.
    fig, axes = plt.subplots(
        num_rows, width * num_images_in_batch, figsize=(width * 4, num_rows * 4), squeeze=False
    )
    # Hide every frame up front so panels without a grain stay blank.
    for panel in axes.ravel():
        panel.axis("off")

    for i, (image, mask, grain_index, p_to_nm, path, binding_vector_position_tuple) in enumerate(
        zip(images, masks, grain_indexes, px_to_nms, paths, binding_vector_positions)
    ):
        row, pair = divmod(i, width)
        # Plot image
        im_ax = axes[row, pair * num_images_in_batch]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        im_ax.set_title(f"Grain {grain_index} {p_to_nm} p/nm")
        # Plot mask with the path on top
        mask_ax = axes[row, pair * num_images_in_batch + 1]
        mask_ax.imshow(mask.astype(int))
        mask_ax.plot(path[:, 1], path[:, 0], color="white")

    fig.tight_layout()
    plt.show()


binding_angle_grain_dict = {}

for index, grain_dict in image_dict_ferets.items():
    image = grain_dict["image"]
    pooled_path = grain_dict["pooled_path"]
    predicted_mask = grain_dict["predicted_mask"]
    p_to_nm = grain_dict["p_to_nm"]
    intersection_labels = grain_dict["intersection_labels"]

    # Thin the pooled path: always keep the first point, then keep a point only when it lies more
    # than nm_sample_distance (converted to pixels) from the last point that was kept.
    nm_sample_distance = 1
    pixel_sample_distance = nm_sample_distance / p_to_nm
    nm_sampled_path = []
    for candidate_point in pooled_path:
        if not nm_sampled_path or np.linalg.norm(candidate_point - nm_sampled_path[-1]) > pixel_sample_distance:
            nm_sampled_path.append(candidate_point)
    nm_sampled_path = np.array(nm_sampled_path)

    # Per-point turning angles along the thinned path.
    angles_per_nm, angle_diffs = angle_per_nm_non_closed(nm_sampled_path, p_to_nm, plot=False)

    # Gaussian smoothing (sigma = 3 samples) of the angle-per-nm profile.
    angles_per_nm = gaussian_filter1d(angles_per_nm, 3)

    # The binding angle is the total turning along the path minus pi.
    total_angle_diff_over_path = np.sum(angle_diffs)
    angle_diff_between_start_end = total_angle_diff_over_path - np.pi

    # The grain's own dictionary is extended in place and registered under the same key.
    grain_dict["binding_angle"] = angle_diff_between_start_end
    grain_dict["angles_per_nm"] = angles_per_nm
    binding_angle_grain_dict[index] = grain_dict

if PLOT_RESULTS:
    # One KDE each for the binding angle and for the mean / spread of the smoothed profiles.
    smoothed_profiles = [grain["angles_per_nm"] for grain in binding_angle_grain_dict.values()]
    binding_angles = np.array([grain["binding_angle"] for grain in binding_angle_grain_dict.values()])
    mean_angles_per_nm = [np.mean(profile) for profile in smoothed_profiles]
    std_angles_per_nm = [np.std(profile) for profile in smoothed_profiles]

    for kde_values, kde_title in (
        (binding_angles, f"Binding angles for {SAMPLE_TYPE} (n = {len(binding_angles)})"),
        (mean_angles_per_nm, f"Mean angles per nm for {SAMPLE_TYPE} (n = {len(mean_angles_per_nm)})"),
        (
            std_angles_per_nm,
            f"Standard deviation of angles per nm for {SAMPLE_TYPE} (n = {len(std_angles_per_nm)})",
        ),
    ):
        sns.kdeplot(kde_values)
        plt.title(kde_title)
        plt.show()

In [None]:
# Get connection regions as a proportion of perimeter

perimeter_intersection_grain_dict = {}

for index, grain_dict in binding_angle_grain_dict.items():
    try:
        mask = grain_dict["predicted_mask"]
        p_to_nm = grain_dict["p_to_nm"]
        intersection_labels = grain_dict["intersection_labels"]

        # Get perimeter of the gem region (mask class 2), with interior holes filled first
        gem_region = mask == 2
        gem_region = binary_fill_holes(gem_region.astype(bool))

        # Dilated region XOR region leaves the ring of pixels just outside the gem
        gem_outline = binary_dilation(gem_region) ^ gem_region
        gem_outline[gem_region == 1] = 0

        # Skeletonize to get rid of any rogue corners
        gem_outline = skeletonize(gem_outline, method="lee").astype(bool)
        intersection_labels = label(skeletonize(intersection_labels.astype(bool), method="lee").astype(bool))

        # Outline pixels are 1; skeletonised intersection pixels are 2
        gem_region_with_intersections = np.copy(gem_outline).astype(int)
        gem_region_with_intersections[intersection_labels > 0] = 2

        # Trace the outline, starting at the first outline pixel that is not part of an intersection.
        # Visited pixels are overwritten with -1 so the walk cannot turn back on itself.
        gem_tracing_hist = np.copy(gem_region_with_intersections)
        start_point = np.array(np.where(gem_tracing_hist == 1))[:, 0]

        point_coords = start_point
        gem_trace = []
        intersection_regions = []
        in_intersection = False
        intersection_start = None
        len_gem_outline = len(np.argwhere(gem_tracing_hist > 0))

        for tracing_index in range(len_gem_outline - 1):
            # Add point to trace
            gem_trace.append(point_coords)

            # 3x3 neighbourhood of the current point (a view, so the -1 written below shows in it)
            neighbours = gem_tracing_hist[
                point_coords[0] - 1 : point_coords[0] + 2, point_coords[1] - 1 : point_coords[1] + 2
            ]
            gem_tracing_hist[point_coords[0], point_coords[1]] = -1
            # Unvisited outline / intersection pixels in the neighbourhood
            nonzero_neighbours = np.argwhere(neighbours > 0)

            # Only the very first step may choose between several neighbours (it fixes the walking
            # direction); afterwards more than one candidate means the outline is not a clean
            # one-pixel-wide curve.
            if tracing_index > 0 and len(nonzero_neighbours) > 1:
                raise ValueError(
                    f"Point {point_coords} has more than one neighbour: {neighbours} {nonzero_neighbours}"
                )

            # Neighbourhood indices run 0..2; subtracting 1 turns them into offsets of -1..1.
            # An empty neighbour list raises IndexError here, which is handled below.
            point_relative_coords = nonzero_neighbours[0] - 1

            point_coords = point_coords + point_relative_coords
            point_value = gem_tracing_hist[point_coords[0], point_coords[1]]

            # Track runs of intersection pixels (value 2) along the trace.
            # NOTE(review): the start is recorded as tracing_index + 1 but the end as
            # tracing_index - 1 — confirm the intended index convention.
            if point_value == 2 and not in_intersection:
                in_intersection = True
                intersection_start = tracing_index + 1
            elif point_value == 2:
                pass
            elif point_value == 1 and in_intersection:
                intersection_end = tracing_index - 1
                intersection_regions.append(
                    {
                        "start": intersection_start,
                        "end": intersection_end,
                        "midpoint": intersection_start + (intersection_end - intersection_start) // 2,
                    }
                )
                in_intersection = False

        # Close a run that is still open when the trace ends
        if in_intersection:
            intersection_end = tracing_index - 1
            intersection_regions.append(
                {
                    "start": intersection_start,
                    "end": intersection_end,
                    "midpoint": intersection_start + (intersection_end - intersection_start) // 2,
                }
            )
            in_intersection = False

        if len(intersection_regions) < 2:
            raise ValueError(f"Fewer than two intersection regions: {intersection_regions}")
        if len(intersection_regions) > 2:
            raise ValueError(f"More than two intersection regions: {intersection_regions}")

        # Distances between intersection midpoints
        # Inner distance (where the end of the trace is not encountered)
        inner_distance_pixels = intersection_regions[1]["midpoint"] - intersection_regions[0]["midpoint"]
        outer_distance_pixels = len(gem_trace) - inner_distance_pixels
        min_midpoint_distance_pixels = np.min([inner_distance_pixels, outer_distance_pixels])
        max_midpoint_distance_pixels = np.max([inner_distance_pixels, outer_distance_pixels])
        min_midpoint_distance_nm = min_midpoint_distance_pixels * p_to_nm
        max_midpoint_distance_nm = max_midpoint_distance_pixels * p_to_nm
        midpoint_distance_ratio = min_midpoint_distance_pixels / max_midpoint_distance_pixels

        # The grain's own dictionary is extended in place and registered under the same key.
        perimeter_intersection_grain_dict[index] = grain_dict
        perimeter_intersection_grain_dict[index]["intersection_min_midpoint_distance"] = min_midpoint_distance_nm
        perimeter_intersection_grain_dict[index]["intersection_max_midpoint_distance"] = max_midpoint_distance_nm
        perimeter_intersection_grain_dict[index]["intersection_midpoint_distance_ratio"] = midpoint_distance_ratio

    # Sometimes the dilation edge finder leaves a non-skeletonized outline and so tracing breaks.
    # The same exception type also reports a wrong number of intersection regions, so the actual
    # message is printed instead of a single assumed cause. Failed grains are skipped.
    except ValueError as error:
        print(f"Index {index} failed due to non-skeletonised gem outline or unexpected intersection count: {error}")

    # Raised when the walk has no neighbour to step to or indexes outside the image.
    except IndexError as error:
        print(f"Index {index} failed due to gem outline being outside image bounds: {error}")

In [None]:
# Elongation metric


def signed_angle_between_vectors(v1, v2):
    """Return the signed angle in radians between two non-zero 2-D vectors.

    The magnitude is the angle between the vectors, within [0, pi]. The sign is taken from the
    2-D cross product ``v1[0] * v2[1] - v1[1] * v2[0]``: the result is negative when that cross
    product is positive and non-negative otherwise.

    Parameters
    ----------
    v1, v2: array-like
        Two-component vectors.

    Returns
    -------
    float
        The signed angle between the vectors, in radians.

    Raises
    ------
    ValueError
        If either vector has all components equal to zero.
    """
    # Converting first makes the zero check work for plain lists as well as arrays.
    v1 = np.asarray(v1, dtype=float)
    v2 = np.asarray(v2, dtype=float)

    # Check that neither vector is 0
    if np.all(v1 == 0) or np.all(v2 == 0):
        raise ValueError("One of the vectors is 0")

    # Cosine of the angle from the normalised dot product; clipping guards against rounding that
    # would push it just outside [-1, 1].
    cos_theta = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
    angle = np.arccos(np.clip(cos_theta, -1.0, 1.0))

    # Scalar 2-D cross product written out explicitly (np.cross on 2-D vectors is deprecated in
    # NumPy 2.0).
    cross_product = v1[0] * v2[1] - v1[1] * v2[0]

    # If the cross product is positive, then the angle is negative
    if cross_product > 0:
        angle = -angle

    return angle


def rotate_points(points: np.ndarray, angle: float):
    """Right-multiply ``points`` by the 2 x 2 matrix [[cos, -sin], [sin, cos]] built from ``angle``.

    ``points`` may be a single two-component vector or an N x 2 array of row vectors; ``angle``
    is in radians. Calling the function again with ``-angle`` undoes the transformation.
    """
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rotation = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
    return points @ rotation


def plot_images_with_paths(
    images: list,
    masks: list,
    paths: list,
    centroids: list,
    midpoints: list,
    elongations: list,
    px_to_nms: list,
    bounding_boxes: list,
    aspect_ratios: list,
    width=5,
    save_dir=None,
):
    """Plot each image beside its mask, annotated with the path and its elongation geometry.

    On the mask panel the path is drawn in red, the path centroid as a red dot, the start/end
    midpoint as a green dot joined to the centroid by a green line, and the bounding box in
    orange. The panel title reports elongation, aspect ratio and the px_to_nm value.

    Parameters
    ----------
    images, masks, paths: list
        Per-grain image, mask and N x 2 path arrays; these three set how many grains are drawn.
    centroids, midpoints: list
        Per-grain two-component points, indexed the same way as ``images``.
    elongations, aspect_ratios, px_to_nms: list
        Per-grain scalars shown in the title.
    bounding_boxes: list
        Per-grain M x 2 arrays of box corners, drawn as a polyline.
    width: int
        Number of image/mask pairs per row.
    save_dir
        If not None, passed to ``plt.savefig`` as the output target (dpi=500).
    """
    num_images = len(images)
    num_images_in_batch = 2
    # Ceiling division so a full last row does not add an empty one; always at least one row.
    num_rows = max(1, -(-num_images // width))
    # squeeze=False keeps `ax` two-dimensional even for a single row, so ax[row, col] is valid.
    fig, ax = plt.subplots(
        num_rows,
        width * num_images_in_batch,
        figsize=(width * 2 * num_images_in_batch, num_rows * 4),
        squeeze=False,
    )
    # Hide every frame up front so panels without a grain stay blank.
    for panel in ax.ravel():
        panel.axis("off")

    for i, (image, mask, path) in enumerate(zip(images, masks, paths)):
        row, pair = divmod(i, width)
        im_ax = ax[row, pair * num_images_in_batch]
        im_ax.imshow(image, cmap="viridis")

        mask_ax = ax[row, pair * num_images_in_batch + 1]
        mask_ax.imshow(mask, cmap="viridis")
        mask_ax.plot(path[:, 1], path[:, 0], "r-")
        mask_ax.scatter(centroids[i][1], centroids[i][0], c="r", s=5)
        mask_ax.scatter(midpoints[i][1], midpoints[i][0], c="g", s=5)
        # Draw a line between midpoint and centroid
        mask_ax.plot(
            [centroids[i][1], midpoints[i][1]],
            [centroids[i][0], midpoints[i][0]],
            "g-",
        )
        # Plot the bounding box
        mask_ax.plot(bounding_boxes[i][:, 1], bounding_boxes[i][:, 0], color="orange")
        mask_ax.set_title(
            f"||| elongation: {elongations[i]:.2f} nm | \naspect ratio: {aspect_ratios[i]:.2f} | \npx_to_nm: {px_to_nms[i]} |||",
            fontsize=10,
        )
    fig.tight_layout()

    if save_dir is not None:
        plt.savefig(save_dir, dpi=500)


def align_points_to_vertical(points: np.ndarray, orientation_vector: np.ndarray):
    """Rotate ``points`` and ``orientation_vector`` by the signed angle between the orientation
    vector and the reference vector [1, 0].

    Returns the rotated points, the rotated orientation vector, and the negated angle, which can
    be passed to ``rotate_points`` to undo the rotation.
    """
    angle_to_vertical = signed_angle_between_vectors(orientation_vector, np.array([1, 0]))
    return (
        rotate_points(points, angle_to_vertical),
        rotate_points(orientation_vector, angle_to_vertical),
        -angle_to_vertical,
    )


image_dict_elongation = {}

for index, grain_dict in perimeter_intersection_grain_dict.items():
    image = grain_dict["image"]
    mask = grain_dict["predicted_mask"]
    p_to_nm = grain_dict["p_to_nm"]
    pooled_path = grain_dict["pooled_path"]

    # Centre the pooled path on its centroid
    path_centroid = np.mean(pooled_path, axis=0)
    path_points_shifted = pooled_path - path_centroid

    # Elongation vector: from the path centroid to the midpoint between the path's two ends
    start_point, end_point = pooled_path[0], pooled_path[-1]
    start_point_end_point_midpoint = (start_point + end_point) / 2
    elongation_vector = start_point_end_point_midpoint - path_centroid
    distance = np.linalg.norm(elongation_vector)

    # Segment from midpoint to centroid, in image and centroid-relative coordinates
    # (only needed when drawing a debug overlay)
    elongation_vector_points = np.array([start_point_end_point_midpoint, path_centroid])
    shifted_elongation_vector_points = elongation_vector_points - path_centroid

    # Rotate the centred path so the elongation vector lines up with the reference vector
    shifted_rotated_path_points, shifted_rotated_elongation_vector, angle_rad = align_points_to_vertical(
        path_points_shifted, elongation_vector
    )
    shifted_rotated_points = shifted_rotated_path_points
    shifted_rotated_elongation_vector_points = np.array([[0, 0], shifted_rotated_elongation_vector])

    # Axis-aligned bounding box of the rotated path, as a closed polyline
    min_y, min_x = shifted_rotated_points.min(axis=0)
    max_y, max_x = shifted_rotated_points.max(axis=0)
    rotated_bounding_box = np.array([[min_y, min_x], [max_y, min_x], [max_y, max_x], [min_y, max_x], [min_y, min_x]])

    # Box extents scaled by p_to_nm; length runs along column 0, width along column 1
    bounding_box_width = (max_x - min_x) * p_to_nm
    bounding_box_length = (max_y - min_y) * p_to_nm
    aspect_ratio = bounding_box_length / bounding_box_width

    # Rotate back and restore the centroid offset to return to image coordinates
    unrotated_bounding_box = rotate_points(rotated_bounding_box, angle_rad) + path_centroid
    unrotated_points = rotate_points(shifted_rotated_points, angle_rad) + path_centroid
    unrotated_elongation_vector_points = (
        rotate_points(shifted_rotated_elongation_vector_points, angle_rad) + path_centroid
    )

    # The grain's own dictionary is extended in place and registered under the same key.
    # "path_bounding_box_length" is stored unscaled; "..._height" and "..._width" are scaled.
    grain_dict.update(
        {
            "path_centroid": path_centroid,
            "path_start_point": start_point,
            "path_end_point": end_point,
            "path_start_point_end_point_midpoint": start_point_end_point_midpoint,
            "path_elongation_distance_nm": distance * p_to_nm,
            "path_bounding_box_length": max_y - min_y,
            "path_bounding_box": unrotated_bounding_box,
            "path_aspect_ratio_length_over_width": aspect_ratio,
            "path_bounding_box_height": bounding_box_length,
            "path_bounding_box_width": bounding_box_width,
        }
    )
    image_dict_elongation[index] = grain_dict


if PLOT_RESULTS:
    # Plot the paths and centroids
    elongation_grains = list(image_dict_elongation.values())
    plot_images_with_paths(
        images=[grain["image"] for grain in elongation_grains],
        masks=[grain["predicted_mask"] for grain in elongation_grains],
        paths=[grain["path"] for grain in elongation_grains],
        centroids=[grain["path_centroid"] for grain in elongation_grains],
        midpoints=[grain["path_start_point_end_point_midpoint"] for grain in elongation_grains],
        elongations=[grain["path_elongation_distance_nm"] for grain in elongation_grains],
        px_to_nms=[grain["p_to_nm"] for grain in elongation_grains],
        bounding_boxes=[grain["path_bounding_box"] for grain in elongation_grains],
        # The aspect ratio is stored under "path_aspect_ratio_length_over_width"; the shorter
        # key "path_aspect_ratio" is never written and raised KeyError here.
        aspect_ratios=[grain["path_aspect_ratio_length_over_width"] for grain in elongation_grains],
        save_dir=None,
    )

## Curvature - Turn in distance metric

In [None]:
def weighted_average_position(array: np.ndarray, clip_min: float) -> float:
    """Calculate the weighted average position in an array.

    Each element of ``array`` is used as the weight of its own index, so the result is
    the (fractional) index around which the values are centred.

    Parameters
    ----------
    array: np.ndarray
        The array to calculate the weighted average position of. It is not modified.
    clip_min: float
        Lower bound applied to the weights before averaging; any value below
        ``clip_min`` is raised to ``clip_min``.

    Returns
    -------
    float
        The weighted average position in the array.

    Raises
    ------
    ZeroDivisionError
        Raised by ``np.average`` if the clipped weights sum to zero.
    """
    # Work on a float copy so a fractional clip_min is not truncated for integer input.
    weights = np.maximum(np.asarray(array, dtype=float), clip_min)
    positions = np.arange(len(weights))
    # Use the clipped weights (the previous version computed them but averaged with the raw array).
    weighted_average = np.average(positions, weights=weights)
    return weighted_average


# Quick sanity check: one overwhelmingly large weight should pull the result towards its index (5).
testarr = np.array([0, 1, 3, 5, 10, 10000000, 2, 12, 0])

print(weighted_average_position(array=testarr, clip_min=0))


def weighted_mean_defect_position(
    defect_angle_shifts: np.ndarray, defect_start_index: int, max_index: int, clip_min: float = 0.0
):
    """Locate the weighted mean position of a defect on a looped trace.

    The angle shifts of the defect are used as weights for their own positions, the
    resulting offset is added to ``defect_start_index`` and the result is wrapped back
    onto the trace if it runs past the last point.

    Parameters
    ----------
    defect_angle_shifts: np.ndarray
        Angle shift at each point of the defect, in trace order.
    defect_start_index: int
        Index in the trace of the first point of the defect.
    max_index: int
        Largest valid index of the trace, i.e. the trace has ``max_index + 1`` points.
        NOTE(review): earlier versions ignored this argument and read the global
        ``angles_per_nm`` instead; confirm callers pass ``len(trace) - 1``.
    clip_min: float
        Lower bound applied to the angle shifts before they are used as weights.

    Returns
    -------
    tuple
        The weighted mean index as a float and the same index rounded to an int, both
        wrapped into the range of the trace.
    """
    trace_length = max_index + 1

    weighted_mean_angle_shift_index = (
        weighted_average_position(defect_angle_shifts, clip_min=clip_min) + defect_start_index
    )
    # Round before wrapping so that a value rounding up to trace_length wraps to 0.
    weighted_mean_angle_shift_index_int = int(np.round(weighted_mean_angle_shift_index))
    if weighted_mean_angle_shift_index >= trace_length:
        weighted_mean_angle_shift_index -= trace_length
    if weighted_mean_angle_shift_index_int >= trace_length:
        weighted_mean_angle_shift_index_int -= trace_length

    return weighted_mean_angle_shift_index, weighted_mean_angle_shift_index_int


def find_distance_window_looped(distances: np.ndarray, window_distance: float, start_index: int):
    """Grow a window along a closed loop of points until it covers more than a given distance.

    ``distances[i]`` is the distance from point ``i`` to the next point, with the last
    entry being the distance from the final point back to the first one (the start
    point is not repeated at the end of the point list).

    Parameters
    ----------
    distances: np.ndarray
        Distance from each point to its successor around the loop.
    window_distance: float
        Distance the window has to exceed.
    start_index: int
        Index of the point at which the window starts.

    Returns
    -------
    tuple
        ``(window_start_index, window_end_index, window_current_distance)`` where the
        end index has wrapped to the front of the array if the window passed the last point.

    Raises
    ------
    ValueError
        With the message "Max iterations reached" if the window is still too short
        after ``len(distances) + 2`` segments have been added.
    """
    segment_lengths = np.asarray(distances)
    last_index = len(segment_lengths) - 1

    covered = 0
    end_candidate = start_index
    # The pass budget (number of points + 2) mirrors the original safety counter exactly.
    for _ in range(len(segment_lengths) + 2):
        if covered > window_distance:
            return start_index, end_candidate, covered
        # Add the segment leaving the current end point, then step forward with wrap-around.
        covered += segment_lengths[end_candidate]
        end_candidate = end_candidate + 1 if end_candidate < last_index else 0

    raise ValueError("Max iterations reached")

In [None]:
def combine_two_defects(
    defect_0: dict,
    defect_1: dict,
    max_trace_index: int,
    p_to_nm: float,
):
    """Combine two defects on a looped trace into one defect.

    Parameters
    ----------
    defect_0: dict
        First defect; only its ``"indexes"`` entry is read.
    defect_1: dict
        Second defect; only its ``"indexes"`` entry is read.
    max_trace_index: int
        Largest valid index of the trace (``len(trace) - 1``).
    p_to_nm: float
        Factor the midpoint index is multiplied by to give ``"midpoint_nm"``.

    Returns
    -------
    dict
        A new defect with the keys ``"start_index"``, ``"end_index"``, ``"indexes"``
        (sorted union of both inputs), ``"midpoint_index"`` and ``"midpoint_nm"``.
        Any other keys of the input defects are not carried over.
    """
    trace_length = max_trace_index + 1

    # Combine the defects
    combined_indexes = np.unique(np.append(defect_0["indexes"], defect_1["indexes"]))
    # Set membership keeps the start/end search linear instead of scanning the array each time.
    index_set = set(combined_indexes.tolist())

    # Find the start and end points of the combined indexes
    # Check if the combined indexes span the end of the array
    if max_trace_index in index_set and 0 in index_set:
        # The end is the last index of the run that begins at 0 ...
        end_index = 0
        while end_index < max_trace_index and end_index + 1 in index_set:
            end_index += 1
        if end_index == max_trace_index:
            # The defect covers the whole loop, so there is no gap to anchor a start on.
            # Previously start_index was never assigned here (UnboundLocalError).
            start_index = 0
        else:
            # ... and the start is the first index of the run that finishes at max_trace_index.
            start_index = max_trace_index
            while start_index > 0 and start_index - 1 in index_set:
                start_index -= 1
    else:
        start_index = int(np.min(combined_indexes))
        end_index = int(np.max(combined_indexes))

    # Calculate the midpoint of the defect
    # `>=` so that a single-index defect is its own midpoint rather than being treated as wrapped.
    if end_index >= start_index:
        midpoint = int(np.round((start_index + end_index) / 2))
    else:
        # Wrapped defect: unroll the end past the end of the trace, average, then wrap back.
        midpoint = int(np.round((start_index + end_index + trace_length) / 2)) % trace_length

    combined_defect = {
        "start_index": start_index,
        "end_index": end_index,
        "indexes": combined_indexes,
        "midpoint_index": midpoint,
        "midpoint_nm": midpoint * p_to_nm,
    }

    return combined_defect


def combine_overlapping_defects(defects: list, max_trace_index: int, p_to_nm: float):
    """Merge defects that share trace points until no two defects overlap.

    Each defect is a dictionary of statistics with an ``"indexes"`` entry listing the
    trace points it covers. Two defects overlap when they share at least one index;
    such a pair is replaced by the result of ``combine_two_defects``. Defects may span
    the end and start of the looped trace, which ``combine_two_defects`` deals with.

    The list is modified in place (merged defects are removed, the combined defect is
    appended) and the same list object is returned.
    """

    def _first_overlapping_pair():
        # Overlap is symmetric, so checking each unordered pair once is sufficient.
        for first in range(len(defects)):
            for second in range(first + 1, len(defects)):
                shared = np.intersect1d(defects[first]["indexes"], defects[second]["indexes"])
                if len(shared) > 0:
                    return first, second
        return None

    # Merge one pair at a time and rescan, since a merged defect may overlap further defects.
    overlapping = _first_overlapping_pair()
    while overlapping is not None:
        first, second = overlapping
        merged = combine_two_defects(defects[first], defects[second], max_trace_index, p_to_nm=p_to_nm)
        # Remove the later entry first so the earlier position is still valid.
        del defects[second]
        del defects[first]
        defects.append(merged)
        overlapping = _first_overlapping_pair()

    return defects

# Height traces of pooled path

In [None]:
# Sample the image height underneath every point of each grain's pooled path

image_dict_heights = {}

for index, grain_dict in image_dict_elongation.items():
    pooled_path = grain_dict["pooled_path"]
    image = grain_dict["image"]

    # Round each path coordinate to the nearest pixel and read the image value there
    heights = [image[int(np.round(point[0])), int(np.round(point[1]))] for point in pooled_path]

    # The entry is the same dict object as in image_dict_elongation, so "heights" is added in place
    grain_dict["heights"] = np.array(heights)
    image_dict_heights[index] = grain_dict

In [None]:
# Turn in distance metric: settings

# Figure switches: `plotting` draws one figure per trace point, `plot_results` one figure per grain.
# `plot_classifications` is not read in this cell - presumably used further down; verify.
plotting = plot_results = plot_classifications = True
# Colour scale limits for the figures drawn in this cell
vmin, vmax = 0, 2.5


# Define what a defect is by any section that turns d degrees in n nm
defect_degrees_value, defect_nm_value = 75, 10.0
# pasty_distance_deviation_threshold_nm = 3.0
# Fraction of the total trace length (0.08 = 8%) by which the two defect spacings must differ for a "pasty"
pasty_distance_deviation_threshold_percentage = 0.08

# Results per grain, keyed by grain index
turn_in_distance_grain_dict = dict()

# Classify every grain by counting sharp bends ("defects") along its closed pooled trace.
# A window covering more than `defect_nm_value` nm is slid along the trace one point at a time;
# a defect is a run of consecutive windows whose first and last segment directions differ by
# more than `defect_degrees_value` degrees. Overlapping defects are merged, the distances between
# defect midpoints are measured along the trace, and the grain is tagged from the defect count.
for index, grain_data in image_dict_heights.items():
    try:
        print(f"grain {index}")

        # NOTE: `grain_mask` and `path` are read here but not used further down in this loop.
        grain_image = grain_data["image"]
        grain_mask = grain_data["predicted_mask"]
        p_to_nm = grain_data["p_to_nm"]
        path = grain_data["path"]
        pooled_trace = grain_data["pooled_path"]
        height_trace = grain_data["heights"]

        # Calculate distances between points. The first point is appended to the end so the
        # closing segment (last point -> first point) is included.
        pooled_trace_with_start_point = np.append(pooled_trace, [pooled_trace[0]], axis=0)
        distances_between_points = np.linalg.norm(np.diff(pooled_trace_with_start_point, axis=0), axis=1)
        distances_between_points_nm = distances_between_points * p_to_nm
        total_distance_nm = np.sum(distances_between_points_nm)

        # Detect defects
        in_defect = False
        defects = []
        maximum_total_angle_shift = 0
        # NOTE: maximum_angle_shift_index is initialised but never updated or read below.
        maximum_angle_shift_index = 0
        # to detect when the start of the window is > 0
        started_tracing = False
        # `point` is a 1-tuple (zip of a single iterable) and is not used; only point_index matters.
        for point_index, (point) in enumerate(zip(pooled_trace)):
            if plotting:
                # Debug figure: whole trace in black, current point in white
                plt.imshow(grain_image, cmap=CMAP, vmin=vmin, vmax=vmax)
                plt.scatter(pooled_trace[:, 1], pooled_trace[:, 0], c="k", s=20)
                plt.scatter(
                    pooled_trace[point_index, 1],
                    pooled_trace[point_index, 0],
                    c="white",
                    s=40,
                )

            # Get window. window_start_index is point_index; window_end_index may have wrapped
            # past the end of the trace. Raises ValueError("Max iterations reached") if the trace
            # is too short for the window, which is handled at the bottom of this loop body.
            window_start_index, window_end_index, window_distance = find_distance_window_looped(
                distances_between_points_nm, defect_nm_value, point_index
            )

            if not started_tracing:
                # Continue until the window start is past index 0 (i.e. point 0 itself is skipped)
                if window_start_index > 0:
                    started_tracing = True
                else:
                    continue

            # Break when the window reaches the end. Identified by the window end index being less than the start index
            if started_tracing:
                if window_end_index < window_start_index:
                    break

            if plotting:
                # Window start in blue, window end in red
                plt.scatter(
                    pooled_trace[window_start_index, 1],
                    pooled_trace[window_start_index, 0],
                    c="b",
                    s=20,
                )
                plt.scatter(
                    pooled_trace[window_end_index, 1],
                    pooled_trace[window_end_index, 0],
                    c="r",
                    s=20,
                )

            # An earlier variant averaged the incoming and outgoing segment at each end of the
            # window; the current version uses a single segment at each end instead.

            # Calculate the vector as the vector between the starting point of the window and the next point
            # NOTE(review): relies on the wrap break above to keep window_start_index + 1 in range.
            start_vector = pooled_trace[window_start_index + 1] - pooled_trace[window_start_index]

            # Calculate the vector as the penultimate point and the last point vector
            end_vector = pooled_trace[window_end_index] - pooled_trace[window_end_index - 1]

            # Signed angle (radians, judging by the np.radians / np.degrees usage below) between the two segments
            angle_between_start_and_end = signed_angle_between_vectors(start_vector, end_vector)

            # See if the angle shift is greater than the defect_degrees_value
            if np.abs(angle_between_start_and_end) > np.radians(defect_degrees_value):
                if not in_defect:
                    # First window over the threshold: open a new defect
                    print(f"@@@ DEFECT START at index: {point_index} window end index: {window_end_index}")
                    print(f"vectors: {start_vector}, {end_vector}")
                    if plotting:
                        # Plot between start and end index following the pooled trace
                        if window_end_index > window_start_index:
                            plt.plot(
                                pooled_trace[window_start_index : window_end_index + 1, 1],
                                pooled_trace[window_start_index : window_end_index + 1, 0],
                                "g",
                            )
                        else:
                            # Wrapped window: join the tail of the trace to its head
                            plt.plot(
                                np.append(
                                    pooled_trace[window_start_index:, 1],
                                    pooled_trace[: window_end_index + 1, 1],
                                ),
                                np.append(
                                    pooled_trace[window_start_index:, 0],
                                    pooled_trace[: window_end_index + 1, 0],
                                ),
                                "g",
                            )
                        plt.title(
                            f"defect start index: {point_index}, end index: {window_end_index},({np.degrees(angle_between_start_and_end)} degrees)"
                        )
                    in_defect = True
                    defect_start_index = point_index
                    if window_end_index > window_start_index:
                        defect_indexes = np.arange(window_start_index, window_end_index + 1)
                    else:
                        defect_indexes = np.append(
                            np.arange(window_start_index, len(pooled_trace)), np.arange(0, window_end_index + 1)
                        )
                    # NOTE(review): despite the name this is the angle of the first window of the
                    # defect; it is not updated while the defect continues.
                    maximum_total_angle_shift = angle_between_start_and_end
                else:
                    # Add the new point(s) to the defect indexes. Note there may be more than one point added due to a high density of points in the window since we are sampling every n nm rather than n points
                    # Careful of the case where the window wraps around
                    if window_end_index > window_start_index:
                        defect_indexes = np.append(defect_indexes, np.arange(window_start_index, window_end_index + 1))
                    else:
                        defect_indexes = np.append(
                            defect_indexes,
                            np.append(
                                np.arange(window_start_index, len(pooled_trace)), np.arange(0, window_end_index + 1)
                            ),
                        )
                    # Ensure each index is unique
                    defect_indexes = np.unique(defect_indexes)

                    # Plot between start and end index following the pooled trace
                    if plotting:
                        if window_end_index > window_start_index:
                            plt.plot(
                                pooled_trace[window_start_index : window_end_index + 1, 1],
                                pooled_trace[window_start_index : window_end_index + 1, 0],
                                "lime",
                                linewidth=2,
                            )
                        else:
                            plt.plot(
                                np.append(
                                    pooled_trace[window_start_index:, 1],
                                    pooled_trace[: window_end_index + 1, 1],
                                ),
                                np.append(
                                    pooled_trace[window_start_index:, 0],
                                    pooled_trace[: window_end_index + 1, 0],
                                ),
                                "g",
                            )
                        plt.title(
                            f"defect start index: {defect_start_index}, end index: {window_end_index}, total angle shift: {angle_between_start_and_end:.1f} ({np.degrees(angle_between_start_and_end):.1f} degrees)"
                        )

            else:
                if plotting:
                    plt.title(
                        f"no defect at index: {point_index} window end index: {window_end_index}\ntotal angle shift: {angle_between_start_and_end:.1f} ({np.degrees(angle_between_start_and_end):.1f} degrees)\n threshold: {defect_degrees_value} degrees in {defect_nm_value} nm"
                    )
                if in_defect:
                    # First window back under the threshold: close the open defect
                    in_defect = False
                    defect_end_index = window_end_index

                    # Calculate the midpoint of the indexes
                    if defect_end_index > defect_start_index:
                        defect_midpoint = int(np.mean([defect_start_index, defect_end_index]))
                    else:
                        # Wrapped defect: unroll the end index, average, then wrap back onto the trace
                        defect_midpoint = int(np.mean([defect_start_index, defect_end_index + len(pooled_trace)]))
                        if defect_midpoint >= len(pooled_trace):
                            defect_midpoint -= len(pooled_trace)

                    # NOTE(review): "midpoint_nm" is the midpoint index times p_to_nm, not a
                    # distance summed along the trace - confirm this is what consumers expect.
                    defects.append(
                        {
                            "start_index": defect_start_index,
                            "end_index": defect_end_index,
                            "maximum_total_angle_shift": maximum_total_angle_shift,
                            "indexes": defect_indexes,
                            "midpoint_index": defect_midpoint,
                            "midpoint_nm": defect_midpoint * p_to_nm,
                        }
                    )

            # NOTE: called for every processed point regardless of the `plotting` flag.
            plt.show()

        # A defect still open when the loop over the points finished is closed here, using the
        # end index of the last window that was computed.
        if in_defect:
            in_defect = False
            defect_end_index = window_end_index

            # Calculate the midpoint of the indexes
            if defect_end_index > defect_start_index:
                defect_midpoint = int(np.mean([defect_start_index, defect_end_index]))
            else:
                defect_midpoint = int(np.mean([defect_start_index, defect_end_index + len(pooled_trace)]))
                if defect_midpoint >= len(pooled_trace):
                    defect_midpoint -= len(pooled_trace)

            defects.append(
                {
                    "start_index": defect_start_index,
                    "end_index": defect_end_index,
                    "maximum_total_angle_shift": maximum_total_angle_shift,
                    "indexes": defect_indexes,
                    "midpoint_index": defect_midpoint,
                    "midpoint_nm": defect_midpoint * p_to_nm,
                }
            )

        # Combine any overlapping regions (the `defects` list is modified in place and returned)
        combined_defects = combine_overlapping_defects(defects, max_trace_index=len(pooled_trace) - 1, p_to_nm=p_to_nm)
        defects = combined_defects

        if len(defects) > 0:
            # Sort the list
            print(f"@@@ DEFECTS: {defects}")
            defects = sorted(defects, key=lambda x: x["midpoint_index"])
            # Calculate the distances between each defect along the trace.
            # Walk point by point from the first midpoint, wrapping at the end of the trace, and
            # record the accumulated distance each time a midpoint is reached. The first midpoint
            # is reached last (after a full lap), so the final entry is the distance from the last
            # defect back to the first.
            to_find = [defect["midpoint_index"] for defect in defects]
            # NOTE: found_original is never read.
            found_original = False
            current_index = to_find[0]
            current_position = pooled_trace[current_index]
            previous_position = current_position
            distances = []
            distance = 0
            while len(to_find) > 0:
                # Update old position
                previous_position = current_position
                # Increment the current position along the trace
                current_index += 1
                if current_index >= len(pooled_trace):
                    current_index -= len(pooled_trace)
                current_position = pooled_trace[current_index]

                # Increment the distance
                distance += np.linalg.norm(current_position - previous_position) * p_to_nm

                # Check if the current index is in the to_find list
                if current_index in to_find:
                    # Get rid of the current index from the to_find list
                    to_find.remove(current_index)
                    # Store the distance
                    distances.append(distance)
                    # Reset the distance
                    distance = 0
        else:
            distances = []

        if plot_results:
            if len(defects) > 0:
                print(f"Defects for grain {index}: {len(defects)}")
                # One panel per defect
                fig, ax = plt.subplots(1, len(defects), figsize=(10 * len(defects), 10))
                for defect_index, defect in enumerate(defects):
                    # plt.subplots returns a single Axes rather than an array when there is one panel
                    if len(defects) == 1:
                        thisax = ax
                    else:
                        thisax = ax[defect_index]
                    # Plot the defect
                    thisax.imshow(grain_image, cmap=CMAP, vmin=vmin, vmax=vmax)
                    # Plot a horizontal line at the top left of the image starting at (2, 2) with a length equal to the nm distance threshold for defect
                    thisax.plot([2, 2 + defect_nm_value / p_to_nm], [2, 2], "r")
                    # Plot the points
                    thisax.scatter(pooled_trace[:, 1], pooled_trace[:, 0], c="k", s=20)
                    # Plot the start point of the entire trace
                    thisax.scatter(pooled_trace[0, 1], pooled_trace[0, 0], c="pink", s=60)
                    # Plot the start and end index
                    thisax.scatter(
                        pooled_trace[defect["start_index"], 1],
                        pooled_trace[defect["start_index"], 0],
                        c="blue",
                        s=60,
                        alpha=1,
                    )
                    thisax.scatter(
                        pooled_trace[defect["end_index"], 1],
                        pooled_trace[defect["end_index"], 0],
                        c="red",
                        s=60,
                        alpha=1,
                    )
                    # Plot the midpoint index of the defect
                    thisax.scatter(
                        pooled_trace[defect["midpoint_index"], 1],
                        pooled_trace[defect["midpoint_index"], 0],
                        c="yellow",
                        s=100,
                        alpha=1,
                    )

                    # Plot a line between the start and end index following the pooled trace
                    if defect["end_index"] > defect["start_index"]:
                        thisax.plot(
                            pooled_trace[defect["start_index"] : defect["end_index"] + 1, 1],
                            pooled_trace[defect["start_index"] : defect["end_index"] + 1, 0],
                            "lime",
                            linewidth=2,
                        )
                    else:
                        # Wrapped defect: join the tail of the trace to its head
                        thisax.plot(
                            np.append(
                                pooled_trace[defect["start_index"] :, 1],
                                pooled_trace[: defect["end_index"] + 1, 1],
                            ),
                            np.append(
                                pooled_trace[defect["start_index"] :, 0],
                                pooled_trace[: defect["end_index"] + 1, 0],
                            ),
                            "lime",
                            linewidth=2,
                        )

                plt.suptitle(
                    f"Defects {len(defects)} for grain {index}, total distance: {total_distance_nm:.2f} nm, defect degrees: {defect_degrees_value}, defect nm: {defect_nm_value}, p_to_nm: {p_to_nm:.2f}"
                )
                fig.tight_layout()
                plt.show()
            else:
                print(f"No defects for grain {index}")

        # Classification
        # If there are 3 defects then it is a dorito
        # If there are 2 defects then it is either a churro or pasty
        # If the two defects have significantly different distances then it's a pasty, else it's a churro
        # If there is exactly 1 defect then it is a teardrop
        # Any other count (0, or more than 3) is tagged "open"

        if len(defects) == 3:
            tag = "dorito"
        elif len(defects) == 2:
            abs_diff_distance = np.abs(distances[0] - distances[1])
            # "Significantly different" = differing by more than the given fraction of the total trace length
            if abs_diff_distance > pasty_distance_deviation_threshold_percentage * total_distance_nm:
                tag = "pasty"
            else:
                tag = "churro"
        elif len(defects) == 1:
            tag = "teardrop"
        else:
            tag = "open"

        # The stored entry is the same dict object as grain_data, so these keys are added to it in place.
        # "total_distance" is in nm; "distances_between_points" is the unscaled array (not multiplied by p_to_nm).
        turn_in_distance_grain_dict[index] = grain_data
        turn_in_distance_grain_dict[index]["pooled_trace"] = pooled_trace
        turn_in_distance_grain_dict[index]["complex_defects"] = defects
        turn_in_distance_grain_dict[index]["complex_distances_between_defects"] = distances
        turn_in_distance_grain_dict[index]["complex_tag"] = tag
        turn_in_distance_grain_dict[index]["total_distance"] = total_distance_nm
        turn_in_distance_grain_dict[index]["complex_num_defects"] = len(defects)
        turn_in_distance_grain_dict[index]["distances_between_points"] = distances_between_points
        turn_in_distance_grain_dict[index]["height_trace"] = height_trace

    except ValueError as e:
        # Only the window-search failure from find_distance_window_looped is tolerated: the grain
        # is skipped and gets no entry in turn_in_distance_grain_dict. Any other ValueError propagates.
        if "Max iterations reached" in str(e):
            print(f"Max iterations reached for grain {index}, skipping")
            continue
        else:
            raise e


def plot_images(
    images: list,
    tags: list,
    traces: list,
    grain_indexes: list,
    defects: list,
    distances_between_defects: list,
    px_to_nms: list,
    width=5,
    cmap=CMAP,
    vmin=0,
    vmax=4,
):
    """Show a grid of grain images overlaid with their trace, defect markers and a 10 nm scale bar.

    Parameters
    ----------
    images : list
        One 2D image array per grain, drawn with ``imshow``.
    tags : list
        One classification tag per grain, shown in the panel title.
    traces : list
        One array per grain; column 0 is plotted on the y axis and column 1 on the x axis.
    grain_indexes : list
        One identifier per grain, shown in the panel title.
    defects : list
        One list of defect dicts per grain. Each dict must provide ``"midpoint_index"`` (an index into
        that grain's trace) and ``"midpoint_nm"`` (assumed to be a scalar position in nm).
    distances_between_defects : list
        One iterable of distances per grain; their sum is reported in the title as the total distance.
    px_to_nms : list
        One scale factor per grain. The scale bar is drawn ``10 / px_to_nms[i]`` pixels long, i.e. the
        factor is treated as nanometres per pixel.
    width : int
        Number of panels per row.
    cmap
        Colormap passed to ``imshow``.
    vmin, vmax
        Colour-scale limits passed to ``imshow``.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``. Nothing is drawn when ``images`` is empty.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots cannot build a grid with zero rows, so there is nothing to draw.
        return

    rows = int(np.ceil(num_images / width))
    # squeeze=False keeps `axes` two-dimensional for every rows/width combination, including 1x1.
    fig, axes = plt.subplots(rows, width, figsize=(30, rows * 5), squeeze=False)
    # Hide every panel's frame up front so unused trailing panels do not show empty axes.
    for panel in axes.ravel():
        panel.axis("off")

    per_grain = zip(images, tags, grain_indexes, traces, defects, distances_between_defects)
    for i, (image, tag, grain_index, trace, defect_dict, defect_distances) in enumerate(per_grain):
        px_to_nm = px_to_nms[i]
        thisax = axes[i // width, i % width]
        thisax.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)

        # Convert to plain floats so the title reads "[12.3, 45.6]" regardless of numpy's scalar repr.
        defect_positions = [float(np.round(d["midpoint_nm"], 1)) for d in defect_dict]
        thisax.set_title(
            f"index: {grain_index} tag: {tag} p_to_nm: {px_to_nm:.2f}\n defect positions: {defect_positions} nm total distance: {np.sum(defect_distances):.2f} nm"
        )

        thisax.plot(trace[:, 1], trace[:, 0], "green")
        for defect in defect_dict:
            defect_point = trace[defect["midpoint_index"]]
            thisax.scatter(defect_point[1], defect_point[0], c="cyan", s=400)

        # Scale bar: a horizontal 10 nm line inset 5% from the bottom-right corner of the image.
        line_length_nm = 10
        line_length_pixels = line_length_nm / px_to_nm
        offset_from_right = image.shape[1] * 0.05
        offset_from_bottom = image.shape[0] * 0.05
        line_right_point_x = image.shape[1] - offset_from_right
        line_left_point_x = line_right_point_x - line_length_pixels
        line_y = image.shape[0] - offset_from_bottom
        thisax.plot(
            [line_left_point_x, line_right_point_x],
            [line_y, line_y],
            "white",
            linewidth=5,
        )
        # Label the scale bar just above the line.
        thisax.text(
            line_right_point_x - 10,
            line_y - 2,
            f"{line_length_nm} nm",
            fontsize=24,
            fontweight="bold",
            color="white",
        )
    fig.tight_layout()
    plt.show()


# plot_images(
#     [turn_in_distance_grain_dict[i]["image"] for i in turn_in_distance_grain_dict],
#     [turn_in_distance_grain_dict[i]["tag"] for i in turn_in_distance_grain_dict],
#     [i for i in turn_in_distance_grain_dict],
#     [turn_in_distance_grain_dict[i]["p_to_nm"] for i in turn_in_distance_grain_dict],
# )

# Classification labels, in the order they are plotted and counted.
unique_tags = ["churro", "pasty", "dorito", "teardrop", "open"]

# Optionally show one image gallery per classification label.
if plot_classifications:
    for tag_to_plot in unique_tags:
        indexes = [
            grain_index
            for grain_index, grain_entry in turn_in_distance_grain_dict.items()
            if grain_entry["complex_tag"] == tag_to_plot
        ]
        print(tag_to_plot, indexes)
        if not indexes:
            continue
        tagged_grains = [turn_in_distance_grain_dict[grain_index] for grain_index in indexes]
        plot_images(
            [grain["image"] for grain in tagged_grains],
            [grain["complex_tag"] for grain in tagged_grains],
            [grain["pooled_trace"] for grain in tagged_grains],
            list(indexes),
            [grain["complex_defects"] for grain in tagged_grains],
            [grain["complex_distances_between_defects"] for grain in tagged_grains],
            [grain["p_to_nm"] for grain in tagged_grains],
        )

# Bar chart of how many grains received each label.
tags = [grain_entry["complex_tag"] for grain_entry in turn_in_distance_grain_dict.values()]
counts = list(map(tags.count, unique_tags))
plt.bar(unique_tags, counts)
plt.title(f"(Turn in distance based) distribution for {SAMPLE_TYPE}")
plt.show()

In [None]:
# Per-grain collections: each gets one inner list (or one value) per grain that has at least one defect.
defect_positions_nm = []
# Despite the name, these are stored as midpoint_nm / total_distance, i.e. a fraction, not 0-100.
defect_positions_percentage = []
defect_heights = []
mean_pooled_trace_height_values = []
defect_height_difference_from_means = []

num_grains_with_defects = 0

for index, grain_data in turn_in_distance_grain_dict.items():
    grain_image = grain_data["image"]
    # NOTE(review): elsewhere in this notebook defect["midpoint_index"] indexes the "pooled_trace"
    # entry; here it indexes "pooled_path". Confirm both keys refer to the same trace.
    pooled_trace = grain_data["pooled_path"]
    height_trace = grain_data["heights"]
    complex_defects = grain_data["complex_defects"]
    total_distance_nm = grain_data["total_distance"]

    # Grains without defects contribute nothing to any of the distributions.
    if len(complex_defects) == 0:
        continue

    num_grains_with_defects += 1

    # The mean trace height is constant for the grain, so compute it once rather than per defect.
    mean_trace_height = np.mean(height_trace)

    sub_defect_positions_nm = []
    sub_defect_positions_percentage = []
    sub_defect_heights = []
    sub_defect_height_difference_from_means = []

    for defect in complex_defects:
        # Sample the image height at the nearest pixel to the defect's midpoint on the trace.
        defect_point = pooled_trace[defect["midpoint_index"]]
        defect_height = grain_image[int(np.round(defect_point[0])), int(np.round(defect_point[1]))]

        sub_defect_positions_nm.append(np.round(defect["midpoint_nm"], 2))
        sub_defect_positions_percentage.append(np.round(defect["midpoint_nm"] / total_distance_nm, 2))
        sub_defect_heights.append(defect_height)
        sub_defect_height_difference_from_means.append(defect_height - mean_trace_height)

    defect_positions_nm.append(sub_defect_positions_nm)
    defect_positions_percentage.append(sub_defect_positions_percentage)
    defect_heights.append(sub_defect_heights)
    defect_height_difference_from_means.append(sub_defect_height_difference_from_means)

    mean_pooled_trace_height_values.append(mean_trace_height)

def _plot_hist_and_kde(values, title, xlabel, xlim=None, vline=None, vline_label=None):
    """Show a 20-bin histogram of ``values`` and then a filled KDE, each as its own figure.

    Parameters
    ----------
    values : list
        Flat list of numbers to plot.
    title : str
        Title used for both figures.
    xlabel : str
        x-axis label used for both figures. The y-axis is "Count" (histogram) or "Density" (KDE).
    xlim : tuple, optional
        ``(left, right)`` x-axis limits. Left to matplotlib's autoscaling when ``None``.
    vline : float, optional
        x position of a dashed black vertical reference line. No line is drawn when ``None``.
    vline_label : str, optional
        Legend label for the reference line. A legend is only drawn when this is given.
    """
    for draw_kind, ylabel in (("hist", "Count"), ("kde", "Density")):
        fig, ax = plt.subplots()
        if draw_kind == "hist":
            ax.hist(values, bins=20)
        else:
            sns.kdeplot(values, fill=True, ax=ax)
        if vline is not None:
            line_kwargs = {"color": "black", "linestyle": "--"}
            if vline_label is not None:
                line_kwargs["label"] = vline_label
            ax.axvline(vline, **line_kwargs)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        if xlim is not None:
            ax.set_xlim(*xlim)
        if vline is not None and vline_label is not None:
            ax.legend()
        plt.show()


# Defect positions in nm (histogram, then KDE)
defect_positions_nm_flat = [item for sublist in defect_positions_nm for item in sublist]
_plot_hist_and_kde(
    defect_positions_nm_flat,
    title=f"nm Curvature defect positions from protein for {SAMPLE_TYPE} n defects = {len(defect_positions_nm_flat)} n grains = {num_grains_with_defects}",
    xlabel="Defect position (nm)",
    xlim=(0, 40),
)

# Defect positions relative to the total trace length (histogram, then KDE)
# NOTE(review): the values are fractions in [0, 1] (see the x-limits) although the labels say "%".
defect_positions_percentage_flat = [item for sublist in defect_positions_percentage for item in sublist]
_plot_hist_and_kde(
    defect_positions_percentage_flat,
    title=f"% Curvature defect positions from protein for {SAMPLE_TYPE} n defects = {len(defect_positions_percentage_flat)} n grains = {num_grains_with_defects}",
    xlabel="Defect position (%)",
    xlim=(0, 1),
)

# Defect heights, with a dashed line at the mean of the per-grain mean trace heights
defect_heights_flat = [item for sublist in defect_heights for item in sublist]
mean_height_value = np.mean(mean_pooled_trace_height_values)
_plot_hist_and_kde(
    defect_heights_flat,
    title=f"Defect heights for {SAMPLE_TYPE} n-defects = {len(defect_heights_flat)} n-grains = {num_grains_with_defects}",
    xlabel="Defect height",
    vline=mean_height_value,
    vline_label="mean trace heights",
)

# Difference between each defect's height and its grain's mean trace height, with a dashed line at zero
defect_height_difference_from_means_flat = [item for sublist in defect_height_difference_from_means for item in sublist]
_plot_hist_and_kde(
    defect_height_difference_from_means_flat,
    title=f"Defect height difference from mean for {SAMPLE_TYPE} n-defects = {len(defect_height_difference_from_means_flat)} n-grains = {num_grains_with_defects}",
    xlabel="Defect height difference from mean",
    vline=0,
)

In [None]:
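# Save the dictionaries (currently disabled: the pickle dump below is commented out).
# NOTE(review): the commented-out dump names `binding_angle_grain_dict`, whereas the cells above
# build `turn_in_distance_grain_dict` - confirm which dictionary is intended before re-enabling.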

# with open(SAVE_DIR / f"{SAMPLE_TYPE}_dict.pkl", "wb") as f:
#     pickle.dump(binding_angle_grain_dict, f)