In [None]:
from pathlib import Path
from typing import Union
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import h5py


from topostats.plottingfuncs import Colormap

# Build the colormap once; `cmap` is the default colormap argument of the plotting helpers in this notebook.
# NOTE(review): presumably Colormap() without arguments yields the TopoStats default colormap - confirm.
colormap = Colormap()
cmap = colormap.get_cmap()

In [None]:
def plot(
    image: np.ndarray, title: str = None, vmin: float = -8, vmax: float = 8, cmap=cmap, figsize=(10, 10), cbar=False
):
    """Show a single 2D image in its own figure.

    Parameters
    ----------
    image: np.ndarray
        Array passed to ``imshow``.
    title: str
        Optional axes title; nothing is set when ``None``.
    vmin: float
        Lower limit of the colour scale.
    vmax: float
        Upper limit of the colour scale.
    cmap
        Colormap passed to ``imshow``.
    figsize: tuple
        Figure size passed to ``plt.subplots``.
    cbar: bool
        Whether to attach a colourbar to the axes.
    """
    figure, axes = plt.subplots(figsize=figsize)
    rendered = axes.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
    if title is not None:
        axes.set_title(title)
    if cbar:
        figure.colorbar(rendered, ax=axes)
    plt.show()

In [None]:
# Files whose pixel-to-nm scaling exceeds this value are skipped when grains are loaded.
MAX_PX_TO_NM = 0.59
# Number of pixels added around a grain's bounding box before the grain is cropped.
BBOX_PAD = 5

In [None]:
def plot_images(images: list, masks: list, px_to_nms: list, grain_indexes: list, width=5, cmap=cmap, vmin=-8, vmax=8):
    """Plot a grid of grains, three panels per grain: image, mask, and mask overlaid on the image.

    Parameters
    ----------
    images: list
        Grain image crops.
    masks: list
        Grain masks, one per image.
    px_to_nms: list
        Pixel-to-nm scaling of each grain (shown in the panel title).
    grain_indexes: list
        Grain identifiers (shown in the panel title).
    width: int
        Number of grains per row; each grain occupies three adjacent axes.
    cmap
        Colormap used for the image panels.
    vmin, vmax: float
        Colour scale limits for the image panels.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots cannot build a grid with zero rows.
        print("No images to plot")
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps `ax` two-dimensional even when there is a single row,
    # so the [row, column] indexing below works for any number of images.
    fig, ax = plt.subplots(num_rows, width * 3, figsize=(30, 30), squeeze=False)
    for i, (image, mask, grain_index) in enumerate(zip(images, masks, grain_indexes)):
        row, first_col = i // width, i % width * 3
        ax[row, first_col].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, first_col].axis("off")
        ax[row, first_col + 1].imshow(mask, cmap="binary")
        ax[row, first_col].set_title(f"grain: {grain_index} p_to_nm: {px_to_nms[i]}")
        ax[row, first_col + 2].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, first_col + 2].imshow(mask, cmap="viridis", alpha=0.2)
    fig.tight_layout()
    plt.show()


# Load every .topostats file of the sample, crop each labelled grain (with padding)
# and collect the crops in `grain_dict`, keyed by a running grain counter.
today = datetime.now().strftime("%Y-%m-%d")

SAMPLE_TYPE = "OT1_REL"
# NOTE(review): absolute, user-specific paths - the notebook only runs on a machine with this layout.
on_rel = Path(f"/Users/sylvi/topo_data/hariborings/testing_all_unbound_data/output_{SAMPLE_TYPE}/processed/")
SAVE_DIR = Path(f"/Users/sylvi/topo_data/hariborings/processed_grains/unbound_{SAMPLE_TYPE}/date_{today}")
SAVE_DIR.mkdir(exist_ok=True, parents=True)
assert SAVE_DIR.exists()
assert on_rel.exists()
# Grab all .topostats files
on_files = list(on_rel.glob("*.topostats"))

# file = on_files[1]

grains_processed = 0
# Stop once this many grains have been collected (across all files).
stop_at_grain = 200
plotting = False

grain_dict = {}

for file in on_files:
    print(file)
    # Load file
    with h5py.File(file, "r") as f:
        print(f.keys())
        image = f["image"][:]
        grain_masks = f["grain_masks"]["above"][:]
        p_to_nm = f["pixel_to_nm_scaling"][()]

    # Skip scans whose pixels are too coarse.
    if p_to_nm > MAX_PX_TO_NM:
        continue

    # Plot image and mask side by side
    if plotting:
        fig, ax = plt.subplots(1, 2, figsize=(20, 10))
        ax[0].imshow(image, cmap=cmap, vmin=-8, vmax=8)
        ax[0].set_title("image")
        ax[1].imshow(grain_masks, cmap="gray")
        ax[1].set_title("grain_masks")
        plt.suptitle(f"pixel to nm scaling: {p_to_nm}")
        fig.tight_layout()
        plt.show()

    # Process the grains: label values 1..max are treated as individual grains.
    for grain in range(1, grain_masks.max() + 1):
        if grains_processed == stop_at_grain:
            break
        # Get the bounding box of the grain
        grain_mask_fullsize = grain_masks == grain
        grain_bbox = np.argwhere(grain_mask_fullsize)
        if grain_bbox.size == 0:
            # This label value does not occur in the mask; min()/max() of an empty array would raise.
            continue
        minr, minc = grain_bbox.min(axis=0)
        maxr, maxc = grain_bbox.max(axis=0)
        # Add padding to the bounding box
        minr = max(0, minr - BBOX_PAD)
        minc = max(0, minc - BBOX_PAD)
        # maxr/maxc are inclusive pixel indices but are used as exclusive slice ends below,
        # hence the +1 so that the padding is BBOX_PAD pixels on every side.
        maxr = min(grain_mask_fullsize.shape[0], maxr + BBOX_PAD + 1)
        maxc = min(grain_mask_fullsize.shape[1], maxc + BBOX_PAD + 1)

        # Get the crop of grain image
        grain_image = image[minr:maxr, minc:maxc]
        grain_mask = grain_mask_fullsize[minr:maxr, minc:maxc]

        if plotting:
            fig, ax = plt.subplots(1, 3, figsize=(20, 10))
            ax[0].imshow(grain_mask, cmap="gray")
            ax[0].set_title("grain mask")
            ax[1].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
            ax[1].set_title("grain image")
            ax[2].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
            ax[2].imshow(grain_mask, cmap="gray", alpha=0.2)
            plt.show()

        grain_dict[grains_processed] = {
            "image": grain_image,
            "mask": grain_mask,
            "p_to_nm": p_to_nm,
        }

        grains_processed += 1

    # The inner break only leaves the grain loop; stop reading further files as well.
    if grains_processed == stop_at_grain:
        break

# Plot the grains
images = [grain_dict[i]["image"] for i in range(grains_processed)]
masks = [grain_dict[i]["mask"] for i in range(grains_processed)]
px_to_nms = [grain_dict[i]["p_to_nm"] for i in range(grains_processed)]
grain_indexes = list(range(grains_processed))
plot_images(
    images,
    masks,
    px_to_nms,
    grain_indexes,
)

In [None]:
# Clean up the masks
from skimage.morphology import binary_dilation, binary_erosion

# Number of binary dilation passes, followed by the same kind of loop for erosion.
DILATION_PASS = 2
ERODE_PASS = 2

dilated_grain_dict = {}

for index, grain_data in grain_dict.items():
    cleaned_mask = grain_data["mask"]
    # All dilation passes run first, then all erosion passes.
    for morph_op, num_passes in ((binary_dilation, DILATION_PASS), (binary_erosion, ERODE_PASS)):
        for _ in range(num_passes):
            cleaned_mask = morph_op(cleaned_mask)

    dilated_grain_dict[index] = {
        "image": grain_data["image"],
        "mask": cleaned_mask,
        "p_to_nm": grain_data["p_to_nm"],
    }

grain_order = range(grains_processed)
plot_images(
    [dilated_grain_dict[i]["image"] for i in grain_order],
    [dilated_grain_dict[i]["mask"] for i in grain_order],
    [dilated_grain_dict[i]["p_to_nm"] for i in grain_order],
    list(grain_order),
)


from skimage.measure import label, regionprops

# Accepted range for the number of foreground pixels in a grain mask.
LOWER_AREA_BOUND = 100
UPPER_AREA_BOUND = 10000

removed_anomaly_grain_dict = {}
for index, grain_data in dilated_grain_dict.items():
    grain_mask = grain_data["mask"]

    # Count the connected regions of the background (mask == 0).
    # Only grains with exactly two background regions are kept
    # (presumably the outside plus one enclosed hole, i.e. a single loop - confirm).
    background_props = regionprops(label(grain_mask == 0))
    num_background_regions = len(background_props)

    if num_background_regions < 2:
        print(f"Grain {index} has too few background regions")
        continue
    if num_background_regions >= 3:
        print(f"Grain {index} has too many background regions")
        continue

    # Check the size of the foreground
    foreground_area = grain_mask.sum()
    if foreground_area < LOWER_AREA_BOUND:
        print(f"Grain {index} has too small foreground area")
        continue
    if foreground_area > UPPER_AREA_BOUND:
        print(f"Grain {index} has too large foreground area")
        continue

    removed_anomaly_grain_dict[index] = grain_data

kept_grains = list(removed_anomaly_grain_dict.values())
plot_images(
    [kept["image"] for kept in kept_grains],
    [kept["mask"] for kept in kept_grains],
    [kept["p_to_nm"] for kept in kept_grains],
    list(removed_anomaly_grain_dict),
)

In [None]:
# Skeletonise using standard skeletonise
from skimage.morphology import skeletonize
from scipy.ndimage import convolve


def plot_images(images: list, masks: list, px_to_nms: list, skeletons: list, width=5, cmap=cmap, vmin=-8, vmax=8):
    """Plot a grid of grains, three panels per grain: image, skeleton, and skeleton overlaid on the image.

    Parameters
    ----------
    images: list
        Grain image crops.
    masks: list
        Grain masks, one per image (only used to pair up the inputs; not drawn).
    px_to_nms: list
        Pixel-to-nm scaling of each grain (shown in the panel title).
    skeletons: list
        Skeleton array of each grain.
    width: int
        Number of grains per row; each grain occupies three adjacent axes.
    cmap
        Colormap used for the image panels.
    vmin, vmax: float
        Colour scale limits for the image panels.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots cannot build a grid with zero rows.
        print("No images to plot")
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps `ax` two-dimensional even when there is a single row,
    # so the [row, column] indexing below works for any number of images.
    fig, ax = plt.subplots(num_rows, width * 3, figsize=(30, 30), squeeze=False)
    for i, (image, mask, skeleton) in enumerate(zip(images, masks, skeletons)):
        row, first_col = i // width, i % width * 3
        ax[row, first_col].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, first_col].axis("off")
        ax[row, first_col + 1].imshow(skeleton, cmap="binary")
        ax[row, first_col].set_title(f"p_to_nm: {px_to_nms[i]}")
        ax[row, first_col + 2].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, first_col + 2].imshow(skeleton, cmap="viridis", alpha=0.2)
    fig.tight_layout()
    plt.show()


def convolve_skelly(skeleton) -> np.ndarray:
    """Classify skeleton pixels by the number of skeleton pixels in their 3x3 neighbourhood.

    The skeleton is convolved with a 3x3 kernel of ones, so every skeleton pixel
    receives the count of itself plus its 8-connected neighbours. Pixels outside
    the array are treated as background (zero), so a skeleton that touches the
    array border is not miscounted.

    Parameters
    ----------
    skeleton: np.ndarray
        Single pixel thick binary trace(s) within an array.

    Returns
    -------
    np.ndarray
        Integer array of the same shape as ``skeleton``: 0 for background, 1 for
        ordinary skeleton pixels (exactly two neighbours) and for isolated pixels,
        2 for endpoints (exactly one neighbour) and 3 for nodes (three or more neighbours).
    """
    # mode="constant" pads with zeros; the default ("reflect") would mirror border
    # pixels into the neighbourhood and inflate the counts along the array edge.
    conv = convolve(skeleton.astype(np.int32), np.ones((3, 3)), mode="constant", cval=0)
    conv[skeleton == 0] = 0  # remove non-skeleton points
    conv[conv == 3] = 1  # self + two neighbours: plain skeleton pixel
    conv[conv > 3] = 3  # self + three or more neighbours: node
    return conv


plotting = False
paths_grain_dict = {}
# Method for finding the optimal path in the molecule mask
# Options: skeletonize, distance transform, height
# (set here but not read anywhere in this cell)
path_method = "distance transform"

for index, grain_data in removed_anomaly_grain_dict.items():
    grain_image, grain_mask, p_to_nm = (grain_data[key] for key in ("image", "mask", "p_to_nm"))

    # Skeletonise
    skeleton = skeletonize(grain_mask)

    if plotting:
        fig, ax = plt.subplots(1, 4, figsize=(20, 10))
        mask_ax, image_ax, skeleton_ax, overlay_ax = ax
        mask_ax.imshow(grain_mask, cmap="gray")
        mask_ax.set_title("grain mask")
        image_ax.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        image_ax.set_title("grain image")
        skeleton_ax.imshow(skeleton, cmap="gray")
        skeleton_ax.set_title("skeleton")
        overlay_ax.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        overlay_ax.imshow(skeleton, cmap="viridis", alpha=0.2)
        plt.show()

    # Keep only skeletons in which every pixel is classified as 1 by convolve_skelly.
    # Any value above 1 (an endpoint = 2 or a node = 3) rejects the grain.
    convolved_skelly = convolve_skelly(skeleton)
    has_branch = np.max(convolved_skelly) > 1
    if has_branch:
        print(f"Grain {index} has a branch")
        continue

    paths_grain_dict[index] = {
        "image": grain_image,
        "mask": grain_mask,
        "skeleton": skeleton,
        "p_to_nm": p_to_nm,
    }

accepted_grains = list(paths_grain_dict.values())
plot_images(
    [accepted["image"] for accepted in accepted_grains],
    [accepted["mask"] for accepted in accepted_grains],
    [accepted["p_to_nm"] for accepted in accepted_grains],
    [accepted["skeleton"] for accepted in accepted_grains],
)

In [None]:
# Trace the skeleton


def plot_images(
    images: list, masks: list, px_to_nms: list, skeletons: list, traces: list, width=5, cmap=cmap, vmin=-8, vmax=8
):
    """Plot a grid of grains, three panels per grain: image, skeleton, and image with skeleton and trace.

    Parameters
    ----------
    images: list
        Grain image crops.
    masks: list
        Grain masks, one per image (only used to pair up the inputs; not drawn).
    px_to_nms: list
        Pixel-to-nm scaling of each grain (shown in the panel title).
    skeletons: list
        Skeleton array of each grain.
    traces: list
        One 2-column array per grain; column 0 is plotted on the y axis and column 1 on the x axis.
    width: int
        Number of grains per row; each grain occupies three adjacent axes.
    cmap
        Colormap used for the image panels.
    vmin, vmax: float
        Colour scale limits for the image panels.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots cannot build a grid with zero rows.
        print("No images to plot")
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps `ax` two-dimensional even when there is a single row,
    # so the [row, column] indexing below works for any number of images.
    fig, ax = plt.subplots(num_rows, width * 3, figsize=(30, 30), squeeze=False)
    for i, (image, mask, skeleton) in enumerate(zip(images, masks, skeletons)):
        row, first_col = i // width, i % width * 3
        ax[row, first_col].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, first_col].axis("off")
        ax[row, first_col + 1].imshow(skeleton, cmap="binary")
        ax[row, first_col].set_title(f"p_to_nm: {px_to_nms[i]}")
        ax[row, first_col + 2].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, first_col + 2].imshow(skeleton, cmap="viridis", alpha=0.2)
        ax[row, first_col + 2].plot(traces[i][:, 1], traces[i][:, 0], "r")
    fig.tight_layout()
    plt.show()


plotting = False
trace_grain_dict = {}

for index, grain_data in paths_grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    skeleton = grain_data["skeleton"]
    p_to_nm = grain_data["p_to_nm"]

    # Trace the skeleton: order its pixels by walking from neighbour to neighbour.
    skeleton_points = np.argwhere(skeleton)
    if len(skeleton_points) == 0:
        # Nothing to trace; indexing the first point would raise.
        print(f"Grain {index} has an empty skeleton")
        continue
    # Find the start point (first skeleton pixel in row-major order)
    start_point = skeleton_points[0]
    # Each point should not have more than 2 neighbours so we can trace by finding the next point
    # and removing the current point from the skeleton

    skeleton_history = skeleton.copy()
    trace = [start_point]
    current_point = start_point
    skeleton_history[current_point[0], current_point[1]] = 0

    # print(f"len of skeleton points: {len(skeleton_points)}")
    for iteration in range(len(skeleton_points) - 1):
        # Clamp the window start at 0: a negative slice start would wrap around and
        # yield an empty window for points in the first row or column.
        row_start = max(current_point[0] - 1, 0)
        col_start = max(current_point[1] - 1, 0)
        neighbourhood = skeleton_history[row_start : current_point[0] + 2, col_start : current_point[1] + 2]
        num_neighbours = np.sum(neighbourhood)
        # On the first step both directions are still open, so two neighbours are allowed.
        if iteration > 0 and num_neighbours > 1:
            raise ValueError(f"More than 1 neighbour for iteration {iteration}")
        if num_neighbours == 0:
            raise ValueError(f"No neighbours for iteration {iteration}")
        next_point = np.argwhere(neighbourhood)[0]
        # Window-relative offset -> coordinates in the grain crop.
        next_point_coords = np.array([row_start, col_start]) + next_point
        trace.append(next_point_coords)
        current_point = next_point_coords
        skeleton_history[current_point[0], current_point[1]] = 0
    trace = np.array(trace)

    # Image of the skeleton whose pixel values encode the visiting order.
    # NOTE(review): the +20 offset presumably keeps the first points distinguishable from the 0 background - confirm.
    ordered_skeleton = np.zeros_like(skeleton).astype(int)
    for point_index, point in enumerate(trace):
        ordered_skeleton[point[0], point[1]] = point_index + 20

    if plotting:
        fig, ax = plt.subplots(1, 3, figsize=(20, 10))
        ax[0].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax[0].set_title("grain image")
        ax[1].imshow(ordered_skeleton, cmap="viridis")
        ax[1].set_title("ordered skeleton")
        ax[2].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax[2].imshow(ordered_skeleton, cmap="viridis", alpha=0.2)
        ax[2].plot(trace[:, 1], trace[:, 0], "r")
        plt.show()

    # Calculate pixel trace length in nm: sum of the distances between consecutive trace points.
    # NOTE(review): the segment from the last point back to the first is not included - confirm this is
    # intended if the skeletons are closed loops.
    pixel_trace_length = np.sum(np.linalg.norm(np.diff(trace, axis=0), axis=1)) * p_to_nm

    trace_grain_dict[index] = {
        "image": grain_image,
        "mask": grain_mask,
        "skeleton": ordered_skeleton,
        "trace": trace,
        "p_to_nm": p_to_nm,
        "pixel_trace_length": pixel_trace_length,
    }


plot_images(
    [trace_grain_dict[i]["image"] for i in trace_grain_dict],
    [trace_grain_dict[i]["mask"] for i in trace_grain_dict],
    [trace_grain_dict[i]["p_to_nm"] for i in trace_grain_dict],
    [trace_grain_dict[i]["skeleton"] for i in trace_grain_dict],
    [trace_grain_dict[i]["trace"] for i in trace_grain_dict],
)

# Plot kde of trace lengths
import seaborn as sns

# Kernel density estimate of the traced lengths of all grains.
trace_lengths = [traced["pixel_trace_length"] for traced in trace_grain_dict.values()]
kde_ax = sns.kdeplot(trace_lengths)
kde_ax.set_xlim(0, 60)
kde_ax.set_xlabel("Trace length (nm)")
kde_ax.set_title(f"Length of pixel traces in nm for {len(trace_lengths)} grains")
plt.show()

In [None]:
def is_in_polygon(polygon: np.ndarray, point: np.ndarray) -> bool:
    """Check if a point is in a polygon using the ray casting algorithm.

    Every polygon edge (including the closing edge from the last vertex back to
    the first) is tested; each edge crossed by the ray flips the result.

    Parameters
    ----------
    polygon: np.ndarray
        Sequence of (x, y) vertices of the polygon.
    point: np.ndarray
        The (x, y) point to test.

    Returns
    -------
    bool
        True if the point is in the polygon, False otherwise.
    """
    x, y = point
    inside = False
    # Start with the closing edge: last vertex -> first vertex.
    prev_x, prev_y = polygon[-1]
    for cur_x, cur_y in polygon:
        # The strict/non-strict pair means horizontal edges never match, so the
        # division below cannot be by zero.
        if min(prev_y, cur_y) < y <= max(prev_y, cur_y) and x <= max(prev_x, cur_x):
            if prev_x == cur_x:
                inside = not inside
            else:
                # x coordinate at which the edge crosses the horizontal line through the point.
                xinters = (y - prev_y) * (cur_x - prev_x) / (cur_y - prev_y) + prev_x
                if x <= xinters:
                    inside = not inside
        prev_x, prev_y = cur_x, cur_y
    return inside


def fill_polygon(array: np.ndarray, polygon: np.ndarray, fill_value: float):
    """Fills a polygon within an array with a fill value.

    The array is modified in place and also returned. Only pixels that lie inside
    the array are written: parts of the polygon outside the array are ignored
    instead of wrapping around through negative indices. Vertices may be
    non-integer; the scanned bounding box is rounded outwards.

    Parameters
    ----------
    array: np.ndarray
        The 2D array to fill the polygon within, indexed as ``array[y, x]``.
    polygon: np.ndarray
        Sequence of (x, y) vertices of the polygon.
    fill_value: float
        The value to fill the polygon with.

    Returns
    -------
    np.ndarray
        The array with the polygon filled.
    """
    polygon = np.asarray(polygon)
    if polygon.size == 0:
        # No vertices, nothing to fill.
        return array
    minx, miny = np.floor(np.min(polygon, axis=0)).astype(int)
    maxx, maxy = np.ceil(np.max(polygon, axis=0)).astype(int)
    # Clip the bounding box to the array so that every index below is valid and non-negative.
    first_row, last_row = max(miny, 0), min(maxy, array.shape[0] - 1)
    first_col, last_col = max(minx, 0), min(maxx, array.shape[1] - 1)
    for y in range(first_row, last_row + 1):
        for x in range(first_col, last_col + 1):
            if is_in_polygon(polygon, np.array([x, y])):
                array[y, x] = fill_value
    return array


# # Test the polygon filling
# polygon = np.array([[0, 10], [3, 25], [28, 15], [10, 0]])
# array = np.zeros((30, 30))
# filled_array = fill_polygon(array, polygon, 1)
# plt.imshow(filled_array)
# plt.plot(np.append(polygon[:, 0], polygon[0, 0]), np.append(polygon[:, 1], polygon[0, 1]), "r")
# plt.show()

In [None]:
# # Improve the tracing using pathfinding
# from scipy.ndimage import distance_transform_edt
# from skimage.graph import route_through_array

# # Options: height, distance_transform
# method = "distance_transform"

# plotting = True
# improved_path_grain_dict = {}

# for index, grain_data in trace_grain_dict.items():
#     grain_image = grain_data["image"]
#     grain_mask = grain_data["mask"]
#     skeleton = grain_data["skeleton"]
#     trace = grain_data["trace"]
#     p_to_nm = grain_data["p_to_nm"]

#     if method == "distance_transform":
#         # Get the distance transform of the grain mask
#         distance_transform = distance_transform_edt(grain_mask)
#         # Invert the distance transform
#         distance_transform_inverted_show = np.max(distance_transform) - distance_transform
#         distance_transform_inverted = np.max(distance_transform) - distance_transform
#         # Find the minimum value pixel in the inverted distance transform
#         min_point = np.unravel_index(np.argmin(distance_transform_inverted), distance_transform_inverted.shape)

#         # Closest trace point to the minimum value pixel
#         p0_index = np.argmin(np.linalg.norm(trace - min_point, axis=1))
#         p0 = trace[p0_index]

#         # Get the next point
#         if p0_index == len(trace) - 1:
#             p1 = trace[0]
#         else:
#             p1 = trace[p0_index + 1]

#         v1 = p1 - p0
#         v2 = np.array([v1[1], -v1[0]])

#         # Find the rectangle points to fill in
#         # Follow the vector v2 from p0 until the grain mask stops being 1

#         # Get r0
#         grain_mask_intersecting = True
#         r00 = p0
#         while grain_mask_intersecting:
#             r00 = r00 + v2
#             if grain_mask[int(r00[0]), int(r00[1])] == 0:
#                 grain_mask_intersecting = False
#         grain_mask_intersecting = True
#         r01 = p0
#         while grain_mask_intersecting:
#             r01 = r01 - v2
#             if grain_mask[int(r01[0]), int(r01[1])] == 0:
#                 grain_mask_intersecting = False
#         grain_mask_intersecting = True
#         r10 = p1
#         while grain_mask_intersecting:
#             r10 = r10 + v2
#             if grain_mask[int(r10[0]), int(r10[1])] == 0:
#                 grain_mask_intersecting = False
#         grain_mask_intersecting = True
#         r11 = p1
#         while grain_mask_intersecting:
#             r11 = r11 - v2
#             if grain_mask[int(r11[0]), int(r11[1])] == 0:
#                 grain_mask_intersecting = False

#         # Fill in the rectangle
#         polygon = np.array([r00, r01, r11, r10])
#         # Flip x and y in polygon
#         polygon = np.array([polygon[:, 1], polygon[:, 0]]).T
#         filled_array = fill_polygon(distance_transform_inverted, polygon, np.max(distance_transform_inverted))
#         filled_array_show = fill_polygon(
#             distance_transform_inverted_show, polygon, np.max(distance_transform_inverted_show)
#         )

#         # Find the points in the distance transform that should be set to really high
#         # to prevent pathfinding through the grain
#         # plt.imshow(filled_array)
#         # plt.colorbar()
#         # plt.show()

#         filled_array[filled_array == np.max(filled_array)] = 1000

#         # plt.imshow(filled_array)
#         # plt.colorbar()
#         # plt.show()

#         # Get the targets for pathfinding, one skeleton point ahead of p1 and one skeleton point behind p0
#         t0_index = p0_index - 1
#         t0 = trace[t0_index]
#         # Ensure the index is within the trace
#         if p0_index + 2 >= len(trace):
#             t1_index = p0_index + 2 - len(trace)
#         else:
#             t1_index = p0_index + 2
#         t1 = trace[t1_index]

#         # Now pathfind between t0 and t1 using the inverted distance transform
#         improved_path, cost = route_through_array(filled_array, t0, t1)

#         improved_path = np.array(improved_path)

#         # Create a visualisation of the improved path
#         improved_path_array = np.zeros_like(distance_transform_inverted)
#         for point in improved_path:
#             improved_path_array[point[0], point[1]] = 1

#         if plotting:
#             fig, ax = plt.subplots(1, 7, figsize=(30, 10))
#             ax[0].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
#             ax[0].set_title("grain image")
#             ax[1].imshow(grain_mask, cmap="gray")
#             ax[1].set_title("grain mask")
#             ax[2].imshow(distance_transform_inverted_show, cmap=cmap)
#             ax[2].scatter(min_point[1], min_point[0], c="r", s=10)
#             ax[2].set_title("distance transform")
#             ax[3].imshow(skeleton, cmap="gray")
#             ax[3].plot([p0[1], p1[1]], [p0[0], p1[0]], "r")
#             ax[3].plot([r00[1], r01[1], r11[1], r10[1], r00[1]], [r00[0], r01[0], r11[0], r10[0], r00[0]], "r")
#             ax[3].scatter(t0[1], t0[0], c="g", s=10)
#             ax[3].scatter(t1[1], t1[0], c="b", s=10)
#             ax[3].set_title("traced skeleton")
#             ax[4].imshow(grain_mask, cmap="gray")
#             ax[4].imshow(improved_path_array, cmap="viridis", alpha=0.2)
#             ax[4].plot(improved_path[:, 1], improved_path[:, 0], "r")
#             ax[4].set_title("improved path")
#             ax[5].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
#             ax[5].plot(improved_path[:, 1], improved_path[:, 0], "r")
#             ax[5].set_title("distance transform trace")
#             ax[6].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
#             ax[6].plot(trace[:, 1], trace[:, 0], "r")
#             ax[6].set_title("skeleton trace")
#             plt.show()

#     if index > 50:
#         break

In [None]:
# Get height traces from the skeletons


def plot_images(
    images: list, masks: list, px_to_nms: list, traces: list, height_traces: list, width=5, cmap=cmap, vmin=-8, vmax=8
):
    """Plot a grid of grains, three panels per grain: image, image with trace, and height profile.

    Parameters
    ----------
    images: list
        Grain image crops.
    masks: list
        Grain masks, one per image (only used to pair up the inputs; not drawn).
    px_to_nms: list
        Pixel-to-nm scaling of each grain (shown in the panel title with two decimals).
    traces: list
        One 2-column array per grain; column 0 is plotted on the y axis and column 1 on the x axis.
    height_traces: list
        One sequence of values per grain, plotted against its index.
    width: int
        Number of grains per row; each grain occupies three adjacent axes.
    cmap
        Colormap used for the image panels.
    vmin, vmax: float
        Colour scale limits for the image panels.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots cannot build a grid with zero rows.
        print("No images to plot")
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps `ax` two-dimensional even when there is a single row,
    # so the [row, column] indexing below works for any number of images.
    fig, ax = plt.subplots(num_rows, width * 3, figsize=(30, 30), squeeze=False)
    for i, (image, mask, height_trace) in enumerate(zip(images, masks, height_traces)):
        row, first_col = i // width, i % width * 3
        ax[row, first_col].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, first_col].axis("off")
        ax[row, first_col].set_title(f"p_to_nm: {px_to_nms[i]:.2f}")
        ax[row, first_col + 1].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, first_col + 1].plot(traces[i][:, 1], traces[i][:, 0], "r")
        ax[row, first_col + 2].plot(height_trace)
    fig.tight_layout()
    plt.show()


plotting = False
height_trace_grain_dict = {}


for index, grain_data in trace_grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    p_to_nm = grain_data["p_to_nm"]
    trace = grain_data["trace"]

    # Sample the image at every trace point (column 0 = row index, column 1 = column index).
    trace_rows, trace_cols = trace[:, 0], trace[:, 1]
    height_trace = grain_image[trace_rows, trace_cols]

    if plotting:
        fig, ax = plt.subplots(1, 2, figsize=(40, 10))
        image_ax, profile_ax = ax
        image_ax.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        # One colour per trace point so that position along the trace can be matched between panels.
        colours = np.linspace(0, 1, len(trace))
        num_segments = len(height_trace) - 1
        for i in range(num_segments):
            image_ax.plot(trace[i : i + 2, 1], trace[i : i + 2, 0], color=plt.cm.viridis(colours[i]))
        image_ax.set_title("grain image")
        colours = np.linspace(0, 1, len(height_trace))
        xs = np.arange(len(height_trace))
        for i in range(num_segments):
            # Slices end at i + 2 so that each segment contains two points (slice ends are exclusive)
            profile_ax.plot(xs[i : i + 2], height_trace[i : i + 2], color=plt.cm.viridis(colours[i]))
        profile_ax.set_ylim(1, 3.5)
        plt.show()

    height_trace_grain_dict[index] = {
        "image": grain_image,
        "mask": grain_mask,
        "trace": trace,
        "height_trace": height_trace,
        "p_to_nm": p_to_nm,
    }

height_traced_grains = list(height_trace_grain_dict.values())
plot_images(
    [entry["image"] for entry in height_traced_grains],
    [entry["mask"] for entry in height_traced_grains],
    [entry["p_to_nm"] for entry in height_traced_grains],
    [entry["trace"] for entry in height_traced_grains],
    [entry["height_trace"] for entry in height_traced_grains],
)

In [None]:
# For each trace, pool sets of n pixels


def plot_images(
    images: list, original_traces: list, pooled_traces: list, px_to_nms: list, width=5, cmap=cmap, vmin=-8, vmax=8
):
    """Plot a grid of grain images, each overlaid with its original trace (red) and pooled trace (blue).

    Parameters
    ----------
    images: list
        Grain image crops.
    original_traces: list
        One 2-column array per grain; column 0 is plotted on the y axis and column 1 on the x axis.
    pooled_traces: list
        Pooled counterpart of each original trace, in the same layout.
    px_to_nms: list
        Pixel-to-nm scaling of each grain (accepted but not drawn).
    width: int
        Number of grains per row.
    cmap
        Colormap used for the images.
    vmin, vmax: float
        Colour scale limits for the images.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots cannot build a grid with zero rows.
        print("No images to plot")
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps `ax` two-dimensional even when there is a single row (or a single column),
    # so the [row, column] indexing below works for any number of images.
    fig, ax = plt.subplots(num_rows, width, figsize=(20, 50), squeeze=False)
    for i, (image, original_trace, pooled_trace) in enumerate(zip(images, original_traces, pooled_traces)):
        panel = ax[i // width, i % width]
        panel.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        panel.plot(original_trace[:, 1], original_trace[:, 0], "r")
        panel.plot(pooled_trace[:, 1], pooled_trace[:, 0], "b")
        panel.axis("off")
    fig.tight_layout()
    plt.show()


plotting = False

pooled_curvature_grain_dict = {}

# Binning size: number of consecutive trace points averaged into one pooled point.
n = 6

for index, grain_data in height_trace_grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    p_to_nm = grain_data["p_to_nm"]
    trace = grain_data["trace"]
    height_trace = grain_data["height_trace"]

    # Pool the trace points: pooled point i is the mean of trace points i .. i + n - 1.
    # Indices past the end wrap around to the start of the trace; the modulo keeps
    # the index valid even when the trace has fewer than n points.
    num_points = len(trace)
    window_offsets = np.arange(n)
    pooled_trace = np.array([trace[(i + window_offsets) % num_points].mean(axis=0) for i in range(num_points)])

    if plotting:
        # Plot overlaid on image
        fig, ax = plt.subplots(1, 1, figsize=(20, 10))
        ax.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax.plot(trace[:, 1], trace[:, 0], "r")
        ax.plot(pooled_trace[:, 1], pooled_trace[:, 0], "b")
        ax.set_title("Trace and pooled trace")
        # Show each figure as it is produced, as the other cells do.
        plt.show()

    pooled_curvature_grain_dict[index] = {
        "image": grain_image,
        "mask": grain_mask,
        "trace": trace,
        "pooled_trace": pooled_trace,
        "height_trace": height_trace,
        "p_to_nm": p_to_nm,
    }

plot_images(
    [pooled_curvature_grain_dict[i]["image"] for i in pooled_curvature_grain_dict],
    [pooled_curvature_grain_dict[i]["trace"] for i in pooled_curvature_grain_dict],
    [pooled_curvature_grain_dict[i]["pooled_trace"] for i in pooled_curvature_grain_dict],
    [pooled_curvature_grain_dict[i]["p_to_nm"] for i in pooled_curvature_grain_dict],
)

In [None]:
from skimage.morphology import binary_erosion
from scipy.ndimage import binary_fill_holes


def calculate_edges(grain_mask: np.ndarray):
    """Return the coordinates of the outer edge pixels of a grain mask.

    Holes in the mask are filled first. The edge is then the set of pixels that
    belong to the filled mask but are removed by a single binary erosion.

    Parameters
    ----------
    grain_mask : np.ndarray
        A 2D numpy array image of a grain. Data in the array must be boolean.

    Returns
    -------
    edges : list
        List containing the coordinates of the edges of the grain, one
        ``[row, column]`` list per edge pixel.
    """
    solid_mask = binary_fill_holes(grain_mask)

    # Erode a zero-padded copy by one pixel, then strip the padding again so the
    # result lines up with the unpadded mask.
    interior = binary_erosion(np.pad(solid_mask, 1))[1:-1, 1:-1]

    # Pixels present in the filled mask but missing after erosion form the edge.
    boundary = solid_mask.astype(int) - interior.astype(int)

    # FIXME : a list of lists is kept for now; the coordinate array itself could be returned instead.
    return [list(coordinate) for coordinate in np.argwhere(boundary)]


def is_clockwise(p_1: tuple, p_2: tuple, p_3: tuple) -> bool:
    """Tell whether walking p_1 -> p_2 -> p_3 makes a clockwise turn.

    Parameters
    ----------
    p_1: tuple
        First point to be used to calculate turn.
    p_2: tuple
        Second point to be used to calculate turn.
    p_3: tuple
        Third point to be used to calculate turn.

    Returns
    -------
    boolean
        Indicator of whether turn is clockwise. Anything that is not a strictly
        positive determinant (including a zero determinant) counts as clockwise.
    """
    # Each row is one point in homogeneous form (x, y, 1). A positive determinant
    # of this matrix means the turn is counter-clockwise.
    homogeneous_points = np.array([(point[0], point[1], 1) for point in (p_1, p_2, p_3)])
    determinant = np.linalg.det(homogeneous_points)
    if determinant > 0:
        return False
    return True


def get_triangle_height(base_point_1: np.array, base_point_2: np.array, top_point: np.array) -> float:
    """Returns the height of a triangle defined by the input point vectors.

    Parameters
    ----------
    base_point_1: np.ndarray
        a base point of the triangle, eg: [5, 3].

    base_point_2: np.ndarray
        a base point of the triangle, eg: [8, 3].

    top_point: np.ndarray
        the top point of the triangle, defining the height from the line between the two base points, eg: [6,10].

    Returns
    -------
    Float
        The height of the triangle - ie the shortest distance between the top point and the line between the two
        base points. If the two base points coincide, the distance between the base point and the top point
        is returned.
    """
    first_base = np.asarray(base_point_1, dtype=float)
    # Height of triangle = A/b = ||AB X AC|| / ||AB||
    a_b = first_base - np.asarray(base_point_2, dtype=float)
    a_c = first_base - np.asarray(top_point, dtype=float)
    base_length = np.linalg.norm(a_b)
    if base_length == 0:
        # Degenerate base: there is no line to measure from, avoid dividing by zero.
        return np.linalg.norm(a_c)
    if a_b.shape == (2,):
        # np.cross on 2-element vectors is deprecated in NumPy 2.0, so the scalar
        # 2D cross product is computed directly.
        cross_norm = abs(a_b[0] * a_c[1] - a_b[1] * a_c[0])
    else:
        cross_norm = np.linalg.norm(np.cross(a_b, a_c))
    return cross_norm / base_length


def get_max_min_ferets(edge_points: list):
    """Return the minimum and maximum feret diameters for a grain.

    These are defined as the smallest and greatest distances between
    a pair of callipers that are rotating around a 2d object, maintaining
    contact at all times.

    Parameters
    ----------
    edge_points: list
        A list (or array) of the vector positions of the pixels comprising the edge of the
        grain. Eg: [[0, 0], [1, 0], [2, 1]]. The input is not modified.

    Returns
    -------
    min_feret: float
        The minimum feret diameter of the grain, in the same units as ``edge_points``.
        None if fewer than two hull steps are available (eg empty or single-point input).
    max_feret: float
        The maximum feret diameter of the grain, in the same units as ``edge_points``.
        None under the same conditions as ``min_feret``.
    min_feret_triangle: list
        The three points [base_point_1, base_point_2, opposite_point] whose triangle height gave
        ``min_feret``: the first two lie on one calliper, the third on the other. None if no triangle
        was evaluated.

    Notes
    -----
    The method starts out by calculating the upper and lower convex hulls using
    an algorithm based on the Graham Scan Algorithm [1]. Using these upper and
    lower hulls, the callipers are simulated as rotating clockwise around the grain.
    We determine the order in which vertices are encountered by comparing the
    gradients of the slopes between vertices. An array of pairs of points that
    are in contact with either calliper at a given time is created in order to
    be able to calculate the maximum feret diameter. The minimum diameter is a
    little tricky, since it won't simply be the shortest distance between two
    contact points, but it will occur somewhere during the rotation around a
    pair of contact points. It turns out that the point will always be such
    that two points are in contact with one calliper while the other calliper
    is in contact with another point. We can use this fact to be sure of finding
    the smallest feret diameter, simply by testing each triangle of 3 contact points
    as we iterate, finding the height of the triangle that is formed between the
    three aforementioned points, as this will be the perpendicular distance between
    the callipers.

    References
    ----------
    [1] Graham, R.L. (1972).
        "An Efficient Algorithm for Determining the Convex Hull of a Finite Planar Set".
        Information Processing Letters. 1 (4): 132-133.
        doi:10.1016/0020-0190(72)90045-2.
    """
    # Sort the vectors by their first coordinate and then by their second coordinate. Sorting a
    # converted copy leaves the caller's list untouched and also allows an array to be passed in.
    sorted_points = np.array(sorted(np.asarray(edge_points).tolist()))

    # Construct upper and lower hulls for the edge points. The two hulls are needed separately so
    # that the callipers can be walked along opposite sides of the shape.
    upper_hull = []
    lower_hull = []
    for point in sorted_points:
        while len(lower_hull) > 1 and is_clockwise(lower_hull[-2], lower_hull[-1], point):
            lower_hull.pop()
        lower_hull.append(point)
        while len(upper_hull) > 1 and not is_clockwise(upper_hull[-2], upper_hull[-1], point):
            upper_hull.pop()
        upper_hull.append(point)

    upper_hull = np.array(upper_hull)
    lower_hull = np.array(lower_hull)

    # Walk the callipers: the upper hull is traversed forwards, the lower hull backwards.
    contact_points = []
    upper_index = 0
    lower_index = len(lower_hull) - 1
    min_feret = None
    min_feret_triangle = None
    while upper_index < len(upper_hull) - 1 or lower_index > 0:
        contact_points.append([lower_hull[lower_index], upper_hull[upper_index]])

        if upper_index == len(upper_hull) - 1:
            # End of the upper hull reached: only the lower hull can still advance.
            advance_upper = False
        elif lower_index == 0:
            # End of the lower hull reached: only the upper hull can still advance.
            advance_upper = True
        else:
            # Compare the gradient of the next upper-hull edge with that of the next lower-hull edge to
            # decide which vertex the rotating callipers meet first. This is
            # delta upper_y / delta upper_x > delta lower_y / delta lower_x with the denominators
            # multiplied through so that no division by zero can occur.
            advance_upper = (upper_hull[upper_index + 1, 1] - upper_hull[upper_index, 1]) * (
                lower_hull[lower_index, 0] - lower_hull[lower_index - 1, 0]
            ) > (lower_hull[lower_index, 1] - lower_hull[lower_index - 1, 1]) * (
                upper_hull[upper_index + 1, 0] - upper_hull[upper_index, 0]
            )

        if advance_upper:
            # Two upper-hull vertices are collinear with one calliper; the other touches the lower hull.
            upper_index += 1
            triangle = [upper_hull[upper_index - 1], upper_hull[upper_index], lower_hull[lower_index]]
        else:
            # Two lower-hull vertices are collinear with one calliper; the other touches the upper hull.
            lower_index -= 1
            triangle = [lower_hull[lower_index + 1], lower_hull[lower_index], upper_hull[upper_index]]

        small_feret = get_triangle_height(*triangle)
        if min_feret is None or small_feret < min_feret:
            min_feret = small_feret
            # Always record the triangle that actually produced the current minimum.
            min_feret_triangle = triangle

    # The maximum feret diameter is the largest distance between any recorded pair of contact points.
    max_feret = None
    if contact_points:
        contact_pairs = np.array(contact_points)
        separations = np.sqrt(
            (contact_pairs[:, 0, 0] - contact_pairs[:, 1, 0]) ** 2
            + (contact_pairs[:, 0, 1] - contact_pairs[:, 1, 1]) ** 2
        )
        max_feret = separations.max()

    return min_feret, max_feret, min_feret_triangle


def shoelace(points: np.ndarray):
    """Area of the polygon whose vertices are ``points``, via the shoelace formula.

    The polygon is closed by repeating the first vertex at the end; the first two columns of
    ``points`` are used as the coordinates. The result is the unsigned area.
    """
    # Close the polygon by appending a copy of the first vertex
    closed = np.vstack((points, points[0]))
    col_0, col_1 = closed[:, 0], closed[:, 1]

    # Shift each coordinate column by one position (last element wraps to the front)
    col_0_shifted = np.concatenate((col_0[-1:], col_0[:-1]))
    col_1_shifted = np.concatenate((col_1[-1:], col_1[:-1]))

    signed_double_area = np.dot(col_0, col_1_shifted) - np.dot(col_1, col_0_shifted)
    return 0.5 * np.abs(signed_double_area)

In [None]:
openness_grain_dict = {}

for index, grain_data in pooled_curvature_grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    p_to_nm = grain_data["p_to_nm"]
    pooled_trace = grain_data["pooled_trace"]

    # Calculate the minimum and maximum feret diameters (a list copy is passed so the trace is not altered)
    min_feret, max_feret, min_feret_triangle = get_max_min_ferets(edge_points=np.copy(pooled_trace).tolist())

    # An open molecule will have a feret ratio of 1, a squished molecule will have a lower feret ratio
    feret_ratio = min_feret / max_feret

    if plotting:
        # Same overlay on the height image and on the mask: the trace in black and the three points of
        # the minimum-feret triangle in red.
        for background in (grain_image, grain_mask):
            plt.imshow(background, cmap=cmap, vmin=-8, vmax=8)
            plt.scatter(pooled_trace[:, 1], pooled_trace[:, 0], c="k", s=20)
            for vertex in min_feret_triangle:
                plt.scatter(vertex[1], vertex[0], c="red", s=60)
            plt.title(f"grain index: {index}, feret ratio: {feret_ratio:.2f}")
            plt.show()

    # Calculate area divided by perimeter
    # Area enclosed by the trace (shoelace() closes the polygon itself)
    area = shoelace(pooled_trace) * p_to_nm**2
    # Perimeter of the trace. The trace is treated as a closed loop, so the segment joining the last
    # point back to the first is included to be consistent with the closed-polygon area above.
    closed_segments = np.diff(pooled_trace, axis=0, append=pooled_trace[:1])
    perimeter = np.sum(np.linalg.norm(closed_segments, axis=1)) * p_to_nm

    # NOTE: despite the name, this is area / perimeter (larger for more open shapes).
    perimeter_area_ratio = area / perimeter

    # Copy the grain's data into a new dict (so the entries of pooled_curvature_grain_dict are not
    # modified) and add the feret measurements and the area / perimeter ratio.
    openness_grain_dict[index] = {
        **grain_data,
        "min_feret": min_feret,
        "max_feret": max_feret,
        "feret_ratio": feret_ratio,
        "perimeter_area_ratio": perimeter_area_ratio,
    }

# Plot openness as a kde
openness_values = [grain["feret_ratio"] for grain in openness_grain_dict.values()]
sns.kdeplot(openness_values)
plt.title(f"Feret ratio (min / max) {SAMPLE_TYPE}")
plt.show()

# Plot perimeter area ratio
perimeter_area_ratio_values = [grain["perimeter_area_ratio"] for grain in openness_grain_dict.values()]
sns.kdeplot(perimeter_area_ratio_values)
plt.title(f"Area / perimeter {SAMPLE_TYPE}")
plt.show()

In [None]:
# Using the pooled points, calculate a proxy for curvature: the change in angle per nm of trace length.


def angle_diff_signed(v1: np.ndarray, v2: np.ndarray):
    """Calculate the signed angle difference between two 2D vectors.

    Parameters
    ----------
    v1: np.ndarray
        The first (old) vector, with two components.
    v2: np.ndarray
        The second (new) vector, with two components.

    Returns
    -------
    float
        The angle between the vectors in radians, in the range [-pi, pi]. The magnitude is the unsigned
        angle; the value is negated when the scalar cross product v1[0] * v2[1] - v1[1] * v2[0] is
        positive. Anti-parallel vectors give +pi. A zero-length input vector gives nan.
    """
    v1 = np.asarray(v1)
    v2 = np.asarray(v2)

    # Unsigned angle between the vectors; clip guards against rounding pushing the cosine outside [-1, 1]
    cos_angle = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
    angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))

    # Scalar (z-component) cross product, written out explicitly because np.cross() on 2-element
    # vectors is deprecated since NumPy 2.0.
    cross = v1[0] * v2[1] - v1[1] * v2[0]

    # A positive cross product is treated as the new vector being clockwise from the old vector
    if cross > 0:
        angle = -angle

    return angle

    # """Calculate the angle difference between two vectors.

    # Parameters
    # ----------
    # v1: np.ndarray
    #     The first vector.
    # v2: np.ndarray
    #     The second vector.

    # Returns
    # -------
    # float
    #     The angle difference between the two vectors.
    # """
    # v1_u = v1 / np.linalg.norm(v1)
    # v2_u = v2 / np.linalg.norm(v2)
    # return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))


# Test the angle diff function: print the signed angle (in degrees) from a fixed reference vector
# to each of a handful of directions.
v1 = np.array([1, 0])
for v2 in (
    np.array([0, 1]),
    np.array([1, 1]),
    np.array([1, -1]),
    np.array([-1, 0]),
    np.array([-1, 1]),
):
    print(np.degrees(angle_diff_signed(v1, v2)))


def angle_per_nm(trace: np.ndarray, p_to_nm: float, plot: bool = False) -> tuple:
    """Calculate the angle change per nm at every point of a closed trace.

    The trace is treated as a loop: the point before the first point is the last point, and the point
    after the last point is the first point.

    Parameters
    ----------
    trace: np.ndarray
        The trace to calculate the angle per nm of, one point per row. The first point must not be
        repeated at the end.
    p_to_nm: float
        The pixel to nm scaling factor.
    plot: bool
        If True, draw the points, the incoming (red) and outgoing (blue) direction arrows and the angle
        at each point. Note that inside this function the name shadows any module-level ``plot``.

    Returns
    -------
    angles_per_nm: np.ndarray
        For each point, the signed angle between the incoming and outgoing segments divided by the
        length (in nm) of the incoming segment.
    angle_diffs: np.ndarray
        For each point, the signed angle (radians) between the incoming and outgoing segments.

    Raises
    ------
    ValueError
        If the first and last points of the trace are the same.
    """
    trace = np.asarray(trace)

    # A repeated end point would give a zero-length segment across the wrap-around
    if np.all(trace[0] == trace[-1]):
        raise ValueError("The first and last points are the same")

    if plot:
        _, ax = plt.subplots(1, 1, figsize=(10, 10))

    num_points = len(trace)
    angles_per_nm = np.zeros(num_points)
    angle_diffs = np.zeros(num_points)

    for index, point in enumerate(trace):
        # Vectors from the previous point and to the next point, wrapping around at both ends.
        # (index - 1 is -1 for the first point, which already selects the last point.)
        v_prev = point - trace[index - 1]
        v_next = trace[(index + 1) % num_points] - point

        angle = angle_diff_signed(v_prev, v_next)

        if plot:
            ax.scatter(point[1], point[0], c="purple", s=20)

            # Direction arrows of fixed length 0.1 for display only
            arrow_prev = v_prev / np.linalg.norm(v_prev) * 0.1
            arrow_next = v_next / np.linalg.norm(v_next) * 0.1
            ax.arrow(
                point[1], point[0], arrow_prev[1], arrow_prev[0], head_width=0.01, head_length=0.2, fc="r", ec="r"
            )
            ax.arrow(
                point[1], point[0], arrow_next[1], arrow_next[0], head_width=0.01, head_length=0.2, fc="b", ec="b"
            )
            # Write text for the angle
            ax.text(point[1], point[0], f"{np.degrees(angle):.2f}", fontsize=12, color="black")

        # Length of the incoming segment in nm
        distance = np.linalg.norm(v_prev) * p_to_nm

        angles_per_nm[index] = angle / distance
        angle_diffs[index] = angle

    if plot:
        ax.plot(trace[:, 1], trace[:, 0], "k")
        plt.show()

    return angles_per_nm, angle_diffs


# # Test the angle per nm function

# trace = np.array(
#     [
#         [0, 0],
#         [1, 0],
#         [1, 1],
#         [2, 1],
#         [4, 5],
#         [6, 2],
#     ]
# )

# px_to_nm = 1

# # Create an ellipse of points
# n = 100
# theta = np.linspace(0, 2 * np.pi, n)
# a = 3
# b = 2
# x = a * np.cos(theta)
# y = b * np.sin(theta)
# ellipse = np.array([x, y]).T
# # Remove the last point as it is the same as the first
# ellipse = ellipse[:-1]
# # Reverse the order of the points
# ellipse = ellipse[::-1]

# plt.plot(ellipse[:, 0], ellipse[:, 1], ".")
# plt.show()

# angles_per_nm, angle_diffs = angle_per_nm(ellipse, px_to_nm, plot=True)

# plt.plot(angles_per_nm, "-o")
# plt.show()


# # Generate a set of points on a circle
# n = 10
# theta = np.linspace(0, 2 * np.pi, n)
# # remove the last point as it is the same as the first
# theta = theta[:-1]
# r = 1
# x = r * np.cos(theta)
# y = r * np.sin(theta)
# circle = np.array([x, y]).T

# # Test the angle per nm function on a circle
# px_to_nm = 1
# angles_per_nm, angle_diffs = angle_per_nm(circle, px_to_nm, plot=True)

# plt.plot(angles_per_nm)
# plt.title(f"angle per unit length for a circle of {n} points")

# print(angles_per_nm)

In [None]:
def flip_if_anticlockwise(trace: np.ndarray):
    """Give a closed trace a consistent winding direction, reversing its point order if needed.

    The winding is judged from the sum of the cross products of consecutive position vectors taken
    around the whole closed loop (twice the signed area of the polygon), including the term that joins
    the last point back to the first. Without that closing term the sign would depend on where the
    trace starts and how far it is from the origin.

    Parameters
    ----------
    trace: np.ndarray
        The trace, one 2D point per row, treated as a closed loop.

    Returns
    -------
    np.ndarray
        The trace reversed along axis 0 if the cross-product sum is positive, otherwise the trace
        unchanged. Traces with fewer than three points enclose no area and are returned unchanged.
        NOTE(review): whether a positive sum looks clockwise or anticlockwise on screen depends on the
        axis convention of the coordinates (eg row/column image coordinates) - confirm against the plots.
    """
    if len(trace) < 3:
        return trace

    points = np.asarray(trace)
    following = np.roll(points, -1, axis=0)
    # Scalar cross product of each point with its successor (wrapping last -> first), written out
    # explicitly because np.cross() on 2-element vectors is deprecated since NumPy 2.0.
    cross_sum = np.sum(points[:, 0] * following[:, 1] - points[:, 1] * following[:, 0])

    if cross_sum > 0:
        # Reverse the trace
        return np.flip(points, axis=0)
    return trace

In [None]:
def defect_stats(array: np.ndarray, threshold: float):
    """Find the contiguous regions of a circular 1D signal that lie above a threshold.

    The signal is treated as circular: if the array both starts and finishes inside a region, the two
    pieces are merged into one region whose "start" is greater than its "end".

    Parameters
    ----------
    array: np.ndarray
        The 1D signal to search.
    threshold: float
        Values strictly greater than this belong to a region ("defect").

    Returns
    -------
    dict
        "defect_number": the number of regions found.
        "defect_largest_region": the region dict with the largest area, or None if there are no regions.
        "defect_regions": list of region dicts, each with the keys
            "start": index of the first sample of the region,
            "end": index of the first sample after the region that is not above the threshold, or the
                last index of the array if the region runs to the end of the array,
            "area": sum of (value - threshold) over the samples of the region,
            "highest_point_index" / "highest_point": position and value of the region's maximum,
            "defect_threshold": the threshold used,
            "midpoint": index midway between "start" and "end" (wrapping around the array if needed).
    """
    regions = []
    current = None
    # Collect every contiguous run of samples above the threshold
    for index, value in enumerate(array):
        if value > threshold:
            if current is None:
                # Start a region; its first sample contributes to the area too
                current = {
                    "start": index,
                    "end": index,
                    "area": value - threshold,
                    "highest_point_index": index,
                    "highest_point": value,
                    "defect_threshold": threshold,
                }
            else:
                current["area"] += value - threshold
                if value > current["highest_point"]:
                    current["highest_point"] = value
                    current["highest_point_index"] = index
        elif current is not None:
            current["end"] = index
            regions.append(current)
            current = None

    # A region still open after the loop runs up to the final sample of the array
    last_region_reaches_end = current is not None
    if last_region_reaches_end:
        current["end"] = len(array) - 1
        regions.append(current)

    # If the array starts and finishes inside (different) regions, they are one region that wraps
    # around the end of the array. Requiring two distinct regions stops a single region covering
    # the whole array from being merged with itself and discarded.
    if last_region_reaches_end and len(regions) > 1 and regions[0]["start"] == 0:
        last_region = regions.pop()
        first_region = regions[0]
        first_region["start"] = last_region["start"]
        first_region["area"] += last_region["area"]
        if last_region["highest_point"] > first_region["highest_point"]:
            first_region["highest_point"] = last_region["highest_point"]
            first_region["highest_point_index"] = last_region["highest_point_index"]

    # Number of defects
    number_of_defects = len(regions)

    # Find the region with the largest area above the threshold (first one wins on ties)
    largest_region = max(regions, key=lambda region: region["area"]) if regions else None

    # Find the midpoint of each region
    for region in regions:
        if region["start"] <= region["end"]:
            region["midpoint"] = int(np.round((region["start"] + region["end"]) / 2))
        else:
            # Wrapped region: use a negative start index so the two indexes can be averaged
            temp_startpoint = region["start"] - len(array)
            region["midpoint"] = int(np.round((temp_startpoint + region["end"]) / 2)) % len(array)

    return {
        "defect_number": number_of_defects,
        "defect_largest_region": largest_region,
        "defect_regions": regions,
    }


# Sample signal for trying out defect_stats(): the first samples exceed the 0.2 threshold set below
# while the final samples are below it.
testarr1 = np.array(
    [
        0.4002,
        0.4014,
        0.4005,
        0.3933,
        0.3781,
        0.3564,
        0.3328,
        0.3114,
        0.2939,
        0.2791,
        0.2641,
        0.2461,
        0.2239,
        0.1980,
        0.1692,
        0.1383,
        0.1057,
        0.0717,
        0.0374,
        0.0052,
        -0.0221,
        -0.0417,
        -0.0518,
        -0.0514,
        -0.0408,
        -0.0202,
        0.0093,
        0.0460,
        0.0866,
        0.1267,
        0.1617,
        0.1886,
        0.2069,
        0.2191,
        0.2295,
        0.2429,
        0.2617,
        0.2853,
        0.3099,
        0.3295,
        0.3387,
        0.3341,
        0.3153,
        0.2847,
        0.2462,
        0.2042,
        0.1627,
        0.1248,
        0.0924,
        0.0666,
        0.0479,
        0.0362,
        0.0310,
        0.0310,
        0.0343,
        0.0384,
        0.0406,
        0.0390,
        0.0330,
        0.0239,
        0.0135,
        0.0034,
        -0.0058,
        -0.0142,
        -0.0220,
        -0.0283,
        -0.0320,
    ]
)

# Second sample signal for defect_stats(): both the first and the last samples exceed the 0.2
# threshold set below, ie the case handled by defect_stats()'s merging of first and last regions.
testarr2 = np.array(
    [
        0.2799,
        0.2818,
        0.2845,
        0.2852,
        0.2805,
        0.2674,
        0.2448,
        0.2136,
        0.1773,
        0.1394,
        0.1027,
        0.0681,
        0.0355,
        0.0055,
        -0.0203,
        -0.0392,
        -0.0485,
        -0.0475,
        -0.0378,
        -0.0221,
        -0.0037,
        0.0148,
        0.0320,
        0.0478,
        0.0636,
        0.0819,
        0.1056,
        0.1373,
        0.1780,
        0.2270,
        0.2812,
        0.3358,
        0.3848,
        0.4217,
        0.4408,
        0.4378,
        0.4110,
        0.3628,
        0.2999,
        0.2319,
        0.1692,
        0.1192,
        0.0848,
        0.0643,
        0.0529,
        0.0457,
        0.0390,
        0.0309,
        0.0212,
        0.0113,
        0.0028,
        -0.0025,
        -0.0031,
        0.0015,
        0.0115,
        0.0266,
        0.0459,
        0.0688,
        0.0949,
        0.1244,
        0.1571,
        0.1928,
        0.2301,
        0.2668,
        0.3003,
        0.3282,
        0.3482,
        0.3587,
    ]
)

threshold = 0.2

test_defect_stats_1 = defect_stats(testarr1, threshold)
print(test_defect_stats_1)

test_defect_stats_2 = defect_stats(testarr2, threshold)
print(test_defect_stats_2)


# Plot each test signal with its threshold, the detected regions shaded in red, the highest point of
# each region as a red dot and the region midpoint as a dashed blue line.
for testarr, test_defect_stats in zip([testarr1, testarr2], [test_defect_stats_1, test_defect_stats_2]):
    plt.plot(testarr)
    plt.ylim(-0.5, 0.5)
    plt.axhline(threshold, color="k", linestyle="--")
    plt.axhline(0, color="k", linestyle="-")
    for region in test_defect_stats["defect_regions"]:
        if region["start"] <= region["end"]:
            plt.axvspan(region["start"], region["end"], color="red", alpha=0.3)
        else:
            # Region wraps around the end of the array: shade both pieces. The length of the array
            # being plotted is used, since the two test arrays need not be the same length.
            plt.axvspan(region["start"], len(testarr), color="red", alpha=0.3)
            plt.axvspan(0, region["end"], color="red", alpha=0.3)
        # Identify the highest point
        plt.scatter(region["highest_point_index"], testarr[region["highest_point_index"]], c="r")
        # Plot the midpoint
        plt.axvline(region["midpoint"], color="b", linestyle="--")
    plt.show()

In [None]:
def calculate_real_distance_between_points_in_array(
    array: np.ndarray, indexes_to_calculate_distance_between: np.ndarray, p_to_nm: float
):
    """Measure the path length along a closed trace between successive marked points.

    Starting at the first marked index, the trace is walked forwards (wrapping from the last point
    back to the first) and the distance travelled is recorded each time another marked index is
    reached. The walk finishes on returning to the starting index, so the distances cover one full
    lap of the trace.

    Parameters
    ----------
    array: np.ndarray
        The trace, one point per row, treated as a closed loop.
    indexes_to_calculate_distance_between: np.ndarray
        Row indexes of ``array`` marking the points of interest. The first entry is the starting point.
        Repeated indexes are counted once.
    p_to_nm: float
        The pixel to nm scaling factor applied to every segment length.

    Returns
    -------
    list
        One distance per distinct marked index, in the order the indexes are encountered when walking
        forwards from the starting index; the final entry is the distance back to the starting index.
        An empty list if no indexes are given.

    Raises
    ------
    ValueError
        If any index lies outside ``0 <= index < len(array)`` (such an index could never be reached by
        the walk).
    """
    points = np.asarray(array)
    marked = np.asarray(indexes_to_calculate_distance_between)
    if marked.size == 0:
        return []

    num_points = len(points)
    if np.any(marked < 0) or np.any(marked >= num_points):
        raise ValueError(f"indexes must lie in the range 0 to {num_points - 1}, got {marked.tolist()}")

    remaining = {int(marked_index) for marked_index in marked}
    current_index = int(marked[0])
    current_position = points[current_index]
    distances = []
    distance = 0
    # Every valid index is visited within one lap, so this loop runs at most len(array) times
    while remaining:
        previous_position = current_position
        # Step to the next point along the trace, wrapping around at the end
        current_index = (current_index + 1) % num_points
        current_position = points[current_index]

        distance += np.linalg.norm(current_position - previous_position) * p_to_nm

        if current_index in remaining:
            remaining.discard(current_index)
            # Store the distance travelled since the previous marked point and start again from zero
            distances.append(distance)
            distance = 0

    return distances


# # Test the distance calculation
# array = np.array(
#     [
#         [5, 5],
#         [5, 6],
#         [6, 7],
#         [7, 7],
#         [8, 6],
#         [8, 5],
#         [7, 4],
#         [6, 4],
#         [5, 5],
#     ]
# )

# indexes_to_calculate_distance_between = np.array([0, 3, 7])
# p_to_nm = 1

# distances = calculate_real_distance_between_points_in_array(array, indexes_to_calculate_distance_between, p_to_nm)

# print(distances)

# plt.plot(array[:, 0], array[:, 1], "o-")
# # Mark the start point with a star
# plt.scatter(
#     array[indexes_to_calculate_distance_between[0], 0],
#     array[indexes_to_calculate_distance_between[0], 1],
#     c="k",
#     s=400,
#     marker="*",
# )
# # Mark the points to calculate the distance between
# plt.scatter(
#     array[indexes_to_calculate_distance_between, 0], array[indexes_to_calculate_distance_between, 1], c="r", s=100
# )
# plt.show()

In [None]:
# SIMPLE calculate rate of change of vector angle per nm

from scipy.ndimage import gaussian_filter1d

angle_change_rate_grain_dict = {}
plotting = False

# Analysis parameters (the same for every grain, so defined once before the loop)
# Degrees per nm threshold above which a stretch of trace counts as a defect
degrees_per_nm_threshold = 12
radians_per_nm_threshld = np.radians(degrees_per_nm_threshold)
# Minimum spacing (in nm) between the points kept when resampling the trace
nm_sample_distance = 0.5
# Two defect-to-defect distances count as "similar" if they differ by less than this fraction of
# their sum
pasty_defect_difference_percentage = 0.1

for index, grain_data in openness_grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    pooled_trace = grain_data["pooled_trace"]
    p_to_nm = grain_data["p_to_nm"]

    if plotting:
        print(f"radians per nm threshold: {radians_per_nm_threshld}")

    pooled_trace = flip_if_anticlockwise(pooled_trace)

    # Thin the pooled trace: always keep the first point, then keep a point only once it is more than
    # the sample distance away from the last point that was kept.
    pixel_sample_distance = nm_sample_distance / p_to_nm
    nm_sampled_trace = [pooled_trace[0]]
    for candidate in pooled_trace[1:]:
        if np.linalg.norm(candidate - nm_sampled_trace[-1]) > pixel_sample_distance:
            nm_sampled_trace.append(candidate)
    nm_sampled_trace = np.array(nm_sampled_trace)

    angles_per_nm, angle_diffs = angle_per_nm(nm_sampled_trace, p_to_nm, plot=False)
    assert len(angles_per_nm) == len(nm_sampled_trace)

    # Smooth the angles per nm. The trace is a closed loop (angle_per_nm() and defect_stats() both wrap
    # around the ends), so the filter wraps around too instead of reflecting at the array ends.
    angles_per_nm = gaussian_filter1d(angles_per_nm, 3, mode="wrap")

    # Get the defect stats
    defect_stats_dict = defect_stats(angles_per_nm, radians_per_nm_threshld)

    # Calculate the real distance between the defects
    defect_indexes = np.array([region["midpoint"] for region in defect_stats_dict["defect_regions"]])
    defect_distances = None
    if len(defect_indexes) > 1:
        defect_distances = calculate_real_distance_between_points_in_array(nm_sampled_trace, defect_indexes, p_to_nm)
        if plotting:
            print(f"defect distances: {defect_distances}")

    # Tag the molecules
    # If 3 defects, then dorito
    # If 2 defects, then churro or pasty
    if defect_stats_dict["defect_number"] == 3:
        tag = "dorito"
    elif defect_stats_dict["defect_number"] == 2:
        # If the two defect distances are similar then churro, else pasty
        # Threshold as a fraction of the total length of the trace
        if np.abs(defect_distances[0] - defect_distances[1]) < pasty_defect_difference_percentage * np.sum(
            defect_distances
        ):
            tag = "churro"
        else:
            tag = "pasty"
    elif defect_stats_dict["defect_number"] == 1:
        tag = "teardrop"
    else:
        tag = "open"

    if plotting:
        plt.plot(angles_per_nm)
        plt.ylim(-0.5, 0.5)
        plt.axhline(radians_per_nm_threshld, color="k", linestyle="--")
        plt.axhline(0, color="k", linestyle="-")
        for region in defect_stats_dict["defect_regions"]:
            if region["start"] <= region["end"]:
                plt.axvspan(region["start"], region["end"], color="red", alpha=0.3)
            else:
                # Region wraps around the end of the array: shade both pieces
                plt.axvspan(region["start"], len(angles_per_nm), color="red", alpha=0.3)
                plt.axvspan(0, region["end"], color="red", alpha=0.3)
            # Plot the midpoint
            plt.scatter(region["midpoint"], angles_per_nm[region["midpoint"]], c="lime")
        plt.show()

        plt.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        plt.plot(pooled_trace[:, 1], pooled_trace[:, 0], "r")
        plt.plot(nm_sampled_trace[:, 1], nm_sampled_trace[:, 0], "b")
        # For each defect mark its midpoint
        for region in defect_stats_dict["defect_regions"]:
            plt.scatter(nm_sampled_trace[region["midpoint"]][1], nm_sampled_trace[region["midpoint"]][0], c="lime")
        plt.title(f"tag: {tag}")
        plt.show()

    # Store a new dict per grain (so the entries of openness_grain_dict are not modified) holding the
    # incoming data plus the results of this stage.
    angle_change_rate_grain_dict[index] = {
        **grain_data,
        "angles_per_nm": angles_per_nm,
        "simple_defect_stats": defect_stats_dict,
        "simple_defect_distances": defect_distances,
        "simple_nm_sampled_trace": nm_sampled_trace,
        "simple_tag": tag,
    }

    # angle_change_rate_grain_dict[index] = {
    #     "image": grain_image,
    #     "mask": grain_mask,
    #     "pooled_trace": pooled_trace,
    #     "angles_per_nm": angles_per_nm,
    #     "p_to_nm": p_to_nm,
    #     "defect_stats": defect_stats_dict,
    #     "defect_distances": defect_distances,
    #     "nm_sampled_trace": nm_sampled_trace,
    #     "tag": tag,
    # }


def plot_images(
    images: list,
    tags: list,
    grain_indexes: list,
    nm_sampled_traces: np.ndarray,
    defect_stats: list,
    px_to_nms: list,
    width=5,
    cmap=cmap,
    vmin=-8,
    vmax=8,
    title: str = "",
):
    """Show a grid of grain images with their sampled trace and defect midpoints overlaid.

    Parameters
    ----------
    images: list
        Grain images, one per panel.
    tags: list
        Tag shown in each panel title.
    grain_indexes: list
        Grain index shown in each panel title.
    nm_sampled_traces: np.ndarray
        One trace per grain; column 0 is plotted on the y axis and column 1 on the x axis.
    defect_stats: list
        One dictionary per grain with a "defect_regions" list; each region's "midpoint" is an index
        into that grain's trace.
    px_to_nms: list
        Accepted for call compatibility; not used when drawing.
    width: int
        Number of panels per row.
    cmap, vmin, vmax
        Passed to ``imshow`` for every panel.
    title: str
        Figure-level title.
    """
    n_rows = np.ceil(len(images) / width).astype(int)
    # A single row would come back from plt.subplots as a 1-D array, which breaks the 2-D indexing
    # below, so a second (empty) row is requested in that case.
    n_rows = 2 if n_rows == 1 else n_rows
    fig, axes = plt.subplots(n_rows, width, figsize=(30, 10 + 10 * n_rows))

    panels = zip(images, tags, grain_indexes, defect_stats, nm_sampled_traces)
    for position, (image, tag, grain_index, grain_defects, sampled_trace) in enumerate(panels):
        row, column = divmod(position, width)
        panel = axes[row, column]

        panel.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        # Overlay the trace (traces are stored row/col, so columns are swapped for x/y)
        panel.plot(sampled_trace[:, 1], sampled_trace[:, 0], "b")
        panel.axis("off")
        panel.set_title(f"grain index: {grain_index} tag: {tag}")

        # Mark the midpoint of every defect region on the trace
        for region in grain_defects["defect_regions"]:
            midpoint = sampled_trace[region["midpoint"]]
            panel.scatter(midpoint[1], midpoint[0], c="lime")

    plt.suptitle(title, fontsize=30)
    fig.tight_layout()
    plt.show()


# Overview grid of every grain with its simple (angle-per-nm based) classification
plot_images(
    images=[grain["image"] for grain in angle_change_rate_grain_dict.values()],
    tags=[grain["simple_tag"] for grain in angle_change_rate_grain_dict.values()],
    grain_indexes=list(angle_change_rate_grain_dict),
    nm_sampled_traces=[grain["simple_nm_sampled_trace"] for grain in angle_change_rate_grain_dict.values()],
    defect_stats=[grain["simple_defect_stats"] for grain in angle_change_rate_grain_dict.values()],
    px_to_nms=[grain["p_to_nm"] for grain in angle_change_rate_grain_dict.values()],
)

# Plot just the churros
# tag_to_plot = "unclassified"
# indexes = [i for i in angle_change_rate_grain_dict if angle_change_rate_grain_dict[i]["tag"] == tag_to_plot]
# print(tag_to_plot, indexes)
# plot_images(
#     [angle_change_rate_grain_dict[i]["image"] for i in indexes],
#     [angle_change_rate_grain_dict[i]["tag"] for i in indexes],
#     [i for i in indexes],
#     [angle_change_rate_grain_dict[i]["nm_sampled_trace"] for i in indexes],
#     [angle_change_rate_grain_dict[i]["defect_stats"] for i in indexes],
#     [angle_change_rate_grain_dict[i]["p_to_nm"] for i in indexes],
#     title=tag_to_plot,
# )

# Bar chart of how many grains received each simple tag
tags = [grain["simple_tag"] for grain in angle_change_rate_grain_dict.values()]
# Count each distinct tag, then order the (tag, count) pairs alphabetically by tag
unique_tags, counts = zip(*sorted(zip(*np.unique(tags, return_counts=True))))
plt.bar(unique_tags, counts)
plt.title(f"(Angle per nm based) distribution for {SAMPLE_TYPE}")
plt.show()

In [None]:
def weighted_average_position(array: np.ndarray, clip_min: float) -> float:
    """Calculate the weighted average position (index) in an array.

    The value at index ``i`` is used as the weight of position ``i``, after every value below
    ``clip_min`` has been raised to ``clip_min``. The input array is not modified.

    Parameters
    ----------
    array: np.ndarray
        1-D array of weights to calculate the weighted average position of.
    clip_min: float
        Lower bound applied to the weights before averaging (e.g. 0 to ignore negative values).

    Returns
    -------
    float
        The weighted average position in the array. If the clipped weights of a non-empty array sum
        to zero (so no weighted average exists), the unweighted mean position is returned instead.
    """
    positions = np.arange(len(array))
    # Work in float so that a fractional clip_min is not truncated when the input is an integer array.
    weights = np.maximum(np.asarray(array, dtype=float), clip_min)
    if len(positions) > 0 and np.sum(weights) == 0:
        # np.average cannot normalise weights that sum to zero; fall back to the centre of the array.
        return float(np.mean(positions))
    # Use the clipped weights (not the raw array) so that clip_min actually takes effect.
    weighted_average = np.average(positions, weights=weights)
    return weighted_average


# Smoke test: the dominant weight sits at position 5, so the printed value should be very close to 5
testarr = np.array([0, 1, 3, 5, 10, 10_000_000, 2, 12, 0])

print(weighted_average_position(array=testarr, clip_min=0))


def weighted_mean_defect_position(
    defect_angle_shifts: np.ndarray, defect_start_index: int, max_index: int, clip_min: float = 0.0
):
    """Locate the angle-shift-weighted centre of a defect as an index along a looped trace.

    Parameters
    ----------
    defect_angle_shifts: np.ndarray
        Angle shifts of the points belonging to the defect, in trace order starting at the defect start.
    defect_start_index: int
        Trace index of the first point of the defect; added to the position found within the defect.
    max_index: int
        Number of points in the trace. Indexes at or beyond this value wrap back to the start.
    clip_min: float
        Lower bound applied to the angle shifts before they are used as weights.

    Returns
    -------
    tuple
        ``(weighted_mean_angle_shift_index, weighted_mean_angle_shift_index_int)``: the fractional
        trace index and the same index rounded to the nearest integer, each wrapped once by ``max_index``.
    """
    weighted_mean_angle_shift_index = (
        weighted_average_position(defect_angle_shifts, clip_min=clip_min) + defect_start_index
    )
    weighted_mean_angle_shift_index_int = int(np.round(weighted_mean_angle_shift_index))
    # Wrap using the trace length that was passed in, rather than a notebook-level global.
    # The float and the rounded index are wrapped independently because rounding can push only
    # the integer version past the end of the trace.
    if weighted_mean_angle_shift_index >= max_index:
        weighted_mean_angle_shift_index -= max_index
    if weighted_mean_angle_shift_index_int >= max_index:
        weighted_mean_angle_shift_index_int -= max_index

    return weighted_mean_angle_shift_index, weighted_mean_angle_shift_index_int


def find_distance_window_looped(distances: np.ndarray, window_distance: float, start_index: int):
    """Find a window that covers a certain distance in a set of points in a loop (but the start point is not repeated at the end)
    where the distances between the points are known including the distance between the end point and the start point.
    """

    distances = np.copy(distances)

    window_current_distance = 0
    window_start_index = start_index
    proposed_end_index = start_index
    found_end = False
    max_iterations = len(distances)
    iterations = 0
    while not found_end:
        # print(f"start index: {window_start_index}, proposed end index: {proposed_end_index}, distance so far: {window_current_distance}")
        # Check if distance is greater than the requested distance
        if window_current_distance > window_distance:
            # print(f"found end at index: {proposed_end_index}")
            found_end = True
            window_end_index = proposed_end_index
            return window_start_index, window_end_index, window_current_distance
        else:
            # Increment the window distance
            # print(f"adding distance: {distances[proposed_end_index]}")
            window_current_distance += distances[proposed_end_index]
            # Increment the proposed end index
            if proposed_end_index < len(distances) - 1:
                proposed_end_index += 1
            else:
                proposed_end_index = 0

        # Safety for infinite loop
        if iterations > max_iterations:
            raise ValueError("Max iterations reached")
        iterations += 1

In [None]:
def combine_two_defects(
    defect_0: dict,
    defect_1: dict,
    max_trace_index: int,
    angle_diffs: np.ndarray,
):
    """Combine two defects into one defect.

    Parameters
    ----------
    defect_0, defect_1: dict
        Defect dictionaries; only their "indexes" entries (trace indexes covered by each defect) are read.
    max_trace_index: int
        Number of points in the looped trace; used to detect and resolve wrap-around.
    angle_diffs: np.ndarray
        Angle shift at each trace point.

    Returns
    -------
    dict
        A defect dictionary with the keys "start_index", "end_index", "maximum_total_angle_shift"
        (sum of the angle shifts over the combined defect), "maximum_angle_shift",
        "maximum_angle_shift_index", "indexes", "weighted_mean_angle_shift_index" and
        "weighted_mean_angle_shift_index_int".
    """

    # Union of the trace indexes covered by either defect
    combined_indexes = np.unique(np.append(defect_0["indexes"], defect_1["indexes"]))

    # Find the start and end points of the combined indexes
    if max_trace_index - 1 in combined_indexes and 0 in combined_indexes:
        # The combined defect spans the end of the array. The start is the index with no covered
        # index before it and the end is the index with no covered index after it.
        # Defaults cover the case where every trace index belongs to the defect, in which
        # case the backward search below finds no gap.
        start_index = 0
        end_index = max_trace_index - 1
        # To find end index, count forward from 0 until there is a number missing
        for candidate_end_index in range(max_trace_index):
            if candidate_end_index + 1 not in combined_indexes:
                end_index = candidate_end_index
                break
        # To find start index, count backward from the end until there is a number missing
        for candidate_start_index in range(max_trace_index - 1, 0, -1):
            if candidate_start_index - 1 not in combined_indexes:
                start_index = candidate_start_index
                break
    else:
        start_index = np.min(combined_indexes)
        end_index = np.max(combined_indexes)

    # Angle shifts over the combined defect, in trace order.
    # NOTE(review): the slice stops before end_index (end exclusive); confirm that is intended.
    if end_index > start_index:
        defect_angle_shifts = angle_diffs[start_index:end_index]
    else:
        defect_angle_shifts = np.append(angle_diffs[start_index:], angle_diffs[:end_index])

    # Total angle shift over the defect
    total_angle_shift = np.sum(defect_angle_shifts)

    # Maximum angle shift and the trace index of the point where it occurs
    local_maximum_angle_shift_index = np.argmax(defect_angle_shifts)
    maximum_angle_shift = defect_angle_shifts[local_maximum_angle_shift_index]
    maximum_angle_shift_index = start_index + local_maximum_angle_shift_index
    if maximum_angle_shift_index >= max_trace_index:
        # Only reachable when the defect wraps past the end of the trace
        maximum_angle_shift_index -= max_trace_index

    # Weighted mean angle shift index
    (
        weighted_mean_angle_shift_index,
        weighted_mean_angle_shift_index_int,
    ) = weighted_mean_defect_position(
        defect_angle_shifts=defect_angle_shifts,
        defect_start_index=start_index,
        max_index=max_trace_index,
    )

    combined_defect = {
        "start_index": start_index,
        "end_index": end_index,
        "maximum_total_angle_shift": total_angle_shift,
        "maximum_angle_shift": maximum_angle_shift,
        "maximum_angle_shift_index": maximum_angle_shift_index,
        "indexes": combined_indexes,
        "weighted_mean_angle_shift_index": weighted_mean_angle_shift_index,
        "weighted_mean_angle_shift_index_int": weighted_mean_angle_shift_index_int,
    }

    return combined_defect


def combine_overlapping_defects(
    defects: list,
    trace_length: Union[int, None] = None,
    trace_angle_diffs: Union[np.ndarray, None] = None,
):
    """Combine any overlapping defects in a list of defects where each defect is a dictionary of statistics.

    Two defects overlap when their "indexes" arrays share at least one trace index. Overlapping pairs
    are merged with ``combine_two_defects`` until no two remaining defects overlap. Defects can span
    the start and end of the (looped) trace.

    Parameters
    ----------
    defects: list
        Defect dictionaries, each with an "indexes" entry. The list is modified in place.
    trace_length: int, optional
        Number of points in the trace. When omitted, ``len(angles_per_nm)`` is read from the notebook
        namespace at the moment a merge is needed.
    trace_angle_diffs: np.ndarray, optional
        Angle shift at each trace point. When omitted, the notebook-level ``angle_diffs`` is used.

    Returns
    -------
    list
        The same list object, with every group of overlapping defects replaced by one combined defect
        (combined defects are appended at the end of the list).
    """

    # Merge one pair at a time and restart the scan, since the list changes after every merge
    defects_were_combined = True
    while defects_were_combined:
        defects_were_combined = False
        for i in range(len(defects)):
            # Overlap is symmetric, so only pairs with i < j need to be checked
            for j in range(i + 1, len(defects)):
                if len(np.intersect1d(defects[i]["indexes"], defects[j]["indexes"])) == 0:
                    continue
                # Resolve the trace description lazily so lists without overlaps never touch globals
                if trace_length is None:
                    trace_length = len(angles_per_nm)
                if trace_angle_diffs is None:
                    trace_angle_diffs = angle_diffs
                combined_defect = combine_two_defects(defects[i], defects[j], trace_length, trace_angle_diffs)
                # Remove the later defect first so the earlier one's position is unaffected
                defects.pop(j)
                defects.pop(i)
                defects.append(combined_defect)
                defects_were_combined = True
                break
            if defects_were_combined:
                break

    return defects

# Complex defect detection

In [None]:
# Complex defect detection: slide a fixed-length window (in nm) around each grain's looped trace and
# flag a defect wherever the trace turns by more than a threshold angle within that window. The number
# of defects and the spacing between them are then used to tag the grain.

# Show one figure per trace point illustrating the sliding window (very verbose)
plotting = True
# Show one summary figure per grain with a panel for each detected defect
plot_results = False
# Restrict processing to a single grain while debugging; set to None to process every grain
debug_grain_index = 74

# Define what a defect is by any section that turns d degrees in n nm
defect_degrees_value = 80
defect_nm_value = 5.0
# pasty_distance_deviation_threshold_nm = 3.0
# Applied as a fraction of the total trace length (0.08 == 8 %), despite the name
pasty_distance_deviation_threshold_percentage = 0.08

turn_in_distance_grain_dict = {}

for index, grain_data in angle_change_rate_grain_dict.items():
    if debug_grain_index is not None and index != debug_grain_index:
        continue

    # print(f"grain index: {index}")
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    p_to_nm = grain_data["p_to_nm"]
    trace = grain_data["trace"]
    pooled_trace = grain_data["pooled_trace"]

    # Check if the trace is clockwise or anticlockwise by summing the cross products of the vectors
    pooled_trace = flip_if_anticlockwise(pooled_trace)

    n = 1
    pooled_trace_every_nth_point = pooled_trace[::n]
    # Create a copy where the first point is appended to the end to calculate the distances between the last and first point
    if np.array_equal(pooled_trace_every_nth_point[0], pooled_trace_every_nth_point[-1]):
        # The trace already repeats its first point: keep the closed copy for distances and drop the
        # duplicate from the working trace
        pooled_trace_every_nth_point_extra_start = np.copy(pooled_trace_every_nth_point)
        pooled_trace_every_nth_point = pooled_trace_every_nth_point[:-1]
    else:
        pooled_trace_every_nth_point_extra_start = np.copy(pooled_trace_every_nth_point)
        pooled_trace_every_nth_point_extra_start = np.append(
            pooled_trace_every_nth_point_extra_start, [pooled_trace_every_nth_point[0]], axis=0
        )

    # distances_between_points[i] is the distance from point i to point i + 1 (the last entry closes the loop)
    distances_between_points = np.linalg.norm(np.diff(pooled_trace_every_nth_point_extra_start, axis=0), axis=1)
    total_distance = np.sum(distances_between_points)

    distances_between_points_nm = distances_between_points * p_to_nm
    total_distance_nm = total_distance * p_to_nm

    # print(f"len of distances ends included: {len(distances_between_points)}, total distance: {total_distance_nm} nm")

    # print(f"len of points: {len(pooled_trace_every_nth_point)}")

    # plt.plot(distances_between_points)
    # plt.title(f"Distances between pooled points for grain {index}")
    # plt.show()

    angles_per_nm, angle_diffs = angle_per_nm(pooled_trace_every_nth_point, p_to_nm)

    print(f"total angle change: {np.degrees(np.sum(angle_diffs))}")

    # plt.plot(angle_diffs)
    # plt.show()

    # print(f"angle diffs: {angle_diffs}")

    # Detect if at any point more than defect_degrees_value is turned in defect_nm_value
    assert len(angles_per_nm) == len(distances_between_points)
    assert len(angles_per_nm) == len(pooled_trace_every_nth_point)

    in_defect = False
    defects = []
    maximum_total_angle_shift = 0
    maximum_angle_shift_index = 0
    for point_index, (point, angle_shift, distance) in enumerate(
        zip(pooled_trace_every_nth_point, angles_per_nm, distances_between_points)
    ):
        print(f"point index: {point_index}, angle shift: {angle_shift}, distance: {distance}")

        if plotting:
            plt.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
            plt.scatter(pooled_trace_every_nth_point[:, 1], pooled_trace_every_nth_point[:, 0], c="k", s=20)
            plt.scatter(
                pooled_trace_every_nth_point[point_index, 1],
                pooled_trace_every_nth_point[point_index, 0],
                c="white",
                s=40,
            )

        # Get window
        window_start_index, window_end_index, window_distance = find_distance_window_looped(
            distances_between_points_nm, defect_nm_value, point_index
        )
        # print(
        #     f"  window start index: {window_start_index}, window end index: {window_end_index}, window distance: {window_distance}"
        # )

        if plotting:
            plt.scatter(
                pooled_trace_every_nth_point[window_start_index, 1],
                pooled_trace_every_nth_point[window_start_index, 0],
                c="b",
                s=20,
            )
            plt.scatter(
                pooled_trace_every_nth_point[window_end_index, 1],
                pooled_trace_every_nth_point[window_end_index, 0],
                c="r",
                s=20,
            )

        # See if the angle shift is greater than the defect_degrees_value
        # Calculate total angle change by summing the angle shifts in the window
        total_angle_shift = 0
        # End index might be lower than start index if the window wraps around
        if window_end_index > window_start_index:
            window_angle_shifts = angle_diffs[window_start_index : window_end_index + 1]
        else:
            window_angle_shifts = np.append(angle_diffs[window_start_index:], angle_diffs[: window_end_index + 1])
            # print(f"window end index: {window_end_index}, window start index: {window_start_index}, angle_diffs: {angle_diffs}")

        total_angle_shift = np.sum(window_angle_shifts)
        # print(f"total angle shift: {total_angle_shift} ({np.degrees(total_angle_shift)} degrees)")

        if np.abs(total_angle_shift) > np.radians(defect_degrees_value):
            # print("window angle > defect_degrees_value")
            # print(f"currently in defect: {in_defect}")
            if not in_defect:
                # print(f"@@@ DEFECT START at index: {point_index} window end index: {window_end_index}")
                if plotting:
                    # Plot between start and end index following the pooled trace
                    if window_end_index > window_start_index:
                        plt.plot(
                            pooled_trace_every_nth_point[window_start_index : window_end_index + 1, 1],
                            pooled_trace_every_nth_point[window_start_index : window_end_index + 1, 0],
                            "g",
                        )
                    else:
                        plt.plot(
                            np.append(
                                pooled_trace_every_nth_point[window_start_index:, 1],
                                pooled_trace_every_nth_point[: window_end_index + 1, 1],
                            ),
                            np.append(
                                pooled_trace_every_nth_point[window_start_index:, 0],
                                pooled_trace_every_nth_point[: window_end_index + 1, 0],
                            ),
                            "g",
                        )
                    plt.title(
                        f"defect start index: {point_index}, end index: {window_end_index}, window angle shifts: {window_angle_shifts} total angle shift: {total_angle_shift} ({np.degrees(total_angle_shift)} degrees)"
                    )
                # A new defect starts at this point: seed its statistics from the current window
                in_defect = True
                defect_start_index = point_index
                if window_end_index > window_start_index:
                    defect_indexes = np.arange(window_start_index, window_end_index + 1)
                else:
                    defect_indexes = np.append(
                        np.arange(window_start_index, len(angles_per_nm)), np.arange(0, window_end_index + 1)
                    )
                defect_shifts = window_angle_shifts
                maximum_total_angle_shift = total_angle_shift

                # print(f"defect indexes: {defect_indexes}")
                # print(f"defect shifts: {defect_shifts}")
                # print(
                #     f"initial total angle shift: {maximum_total_angle_shift} ({np.degrees(maximum_total_angle_shift)} degrees)"
                # )
                maximum_angle_shift = np.max(window_angle_shifts)
                maximum_angle_shift_index = window_start_index + np.argmax(window_angle_shifts)
                if maximum_angle_shift_index >= len(angles_per_nm):
                    maximum_angle_shift_index -= len(angles_per_nm)
                # print(f"starting defect, max shift: {maximum_angle_shift}, index: {maximum_angle_shift_index}")
            else:
                # Add the new point(s) to the defect indexes. Note there may be more than one point added due to a high density of points in the window since we are sampling every n nm rather than n points
                # Careful of the case where the window wraps around
                if window_end_index > window_start_index:
                    defect_indexes = np.append(defect_indexes, np.arange(window_start_index, window_end_index + 1))
                else:
                    defect_indexes = np.append(
                        defect_indexes,
                        np.append(
                            np.arange(window_start_index, len(angles_per_nm)), np.arange(0, window_end_index + 1)
                        ),
                    )
                # Ensure each index is unique
                defect_indexes = np.unique(defect_indexes)
                # Add the new angle shift to the defect shifts
                defect_shifts = np.append(defect_shifts, angle_diffs[window_end_index])
                # Check if the maximum angle shift is a new maximum
                window_maximum = np.max(window_angle_shifts)
                if window_maximum > maximum_angle_shift:
                    maximum_angle_shift = window_maximum
                    window_maximum_angle_shift_index = np.argmax(window_angle_shifts)
                    maximum_angle_shift_index = window_start_index + window_maximum_angle_shift_index
                    if maximum_angle_shift_index >= len(angles_per_nm):
                        maximum_angle_shift_index -= len(angles_per_nm)
                    # print(f"window maximum is new maximum: {maximum_angle_shift} at index: {maximum_angle_shift_index}")
                # Check if the total angle shift is greater than the current maximum
                if np.abs(total_angle_shift) > np.abs(maximum_total_angle_shift):
                    maximum_total_angle_shift = total_angle_shift
                    # Store the index of the maximum angle shift
                    # (maximum_total_angle_shift_index is not read again within this cell)
                    if window_end_index > window_start_index:
                        maximum_total_angle_shift_index = window_start_index + np.argmax(window_angle_shifts)
                    else:
                        maximum_total_angle_shift_index = np.argmax(window_angle_shifts)
                # Plot between start and end index following the pooled trace
                if plotting:
                    if window_end_index > window_start_index:
                        plt.plot(
                            pooled_trace_every_nth_point[window_start_index : window_end_index + 1, 1],
                            pooled_trace_every_nth_point[window_start_index : window_end_index + 1, 0],
                            "g",
                        )
                    else:
                        plt.plot(
                            np.append(
                                pooled_trace_every_nth_point[window_start_index:, 1],
                                pooled_trace_every_nth_point[: window_end_index + 1, 1],
                            ),
                            np.append(
                                pooled_trace_every_nth_point[window_start_index:, 0],
                                pooled_trace_every_nth_point[: window_end_index + 1, 0],
                            ),
                            "g",
                        )
                    plt.title(
                        f"defect start index: {defect_start_index}, end index: {window_end_index}, window angle shifts: {window_angle_shifts} max shift: {maximum_angle_shift}, index: {maximum_angle_shift_index} total angle shift: {total_angle_shift} ({np.degrees(total_angle_shift)} degrees)"
                    )
            # print(f"defect indexes: {defect_indexes}")

        else:
            if plotting:
                plt.title(
                    f"no defect at index: {point_index} window end index: {window_end_index} window angle shifts: {window_angle_shifts} total angle shift: {total_angle_shift} ({np.degrees(total_angle_shift)} degrees)"
                )
            if in_defect:
                # The window no longer exceeds the threshold: close the current defect and record it
                in_defect = False
                defect_end_index = window_end_index
                # print(
                #     f"@@@ DEFECT DONE: start index: {defect_start_index}, end index: {defect_end_index}, window angle shifts: {window_angle_shifts} max shift: {maximum_angle_shift}, index: {maximum_angle_shift_index}"
                # )
                # print(f"defect indexes: {defect_indexes}")
                # print(f"defect shifts: {defect_shifts}")

                weighted_mean_angle_shift_index, weighted_mean_angle_shift_index_int = weighted_mean_defect_position(
                    defect_angle_shifts=defect_shifts,
                    defect_start_index=defect_start_index,
                    max_index=len(angles_per_nm),
                )

                defects.append(
                    {
                        "start_index": defect_start_index,
                        "end_index": defect_end_index,
                        "maximum_total_angle_shift": maximum_total_angle_shift,
                        "maximum_angle_shift": maximum_angle_shift,
                        "maximum_angle_shift_index": maximum_angle_shift_index,
                        "indexes": defect_indexes,
                        "weighted_mean_angle_shift_index": weighted_mean_angle_shift_index,
                        "weighted_mean_angle_shift_index_int": weighted_mean_angle_shift_index_int,
                    }
                )

        # Only flush a figure when one was actually drawn for this point
        if plotting:
            plt.show()

    # A defect still open after the last point is closed using the final window
    if in_defect:
        in_defect = False
        defect_end_index = window_end_index
        # print(f"@@@ DEFECT DONE: start index: {defect_start_index}, end index: {defect_end_index} max shift: {maximum_angle_shift}, index: {maximum_angle_shift_index}")

        weighted_mean_angle_shift_index, weighted_mean_angle_shift_index_int = weighted_mean_defect_position(
            defect_angle_shifts=defect_shifts,
            defect_start_index=defect_start_index,
            max_index=len(angles_per_nm),
        )
        defects.append(
            {
                "start_index": defect_start_index,
                "end_index": defect_end_index,
                "maximum_total_angle_shift": maximum_total_angle_shift,
                "maximum_angle_shift": maximum_angle_shift,
                "maximum_angle_shift_index": maximum_angle_shift_index,
                "indexes": defect_indexes,
                "weighted_mean_angle_shift_index": weighted_mean_angle_shift_index,
                "weighted_mean_angle_shift_index_int": weighted_mean_angle_shift_index_int,
            }
        )

    # Combine any overlapping regions
    # for i, defect in enumerate(defects):
    #     print(f"defect {i}: {defect}")

    combined_defects = combine_overlapping_defects(defects)

    # for i, defect in enumerate(combined_defects):
    #     print(f"combined defect {i}: {defect}")

    defects = combined_defects

    if len(defects) > 0:
        # For each defect's weighted mean angle shift index int, calculate the distance to the next defect's weighted mean angle shift index int
        # Get a list of each defect's weighted mean angle shift index int
        weighted_mean_angle_shift_indexes_int = [defect["weighted_mean_angle_shift_index_int"] for defect in defects]
        # Sort the list
        weighted_mean_angle_shift_indexes_int.sort()
        # print(f"sorted weighted mean angle shift indexes int: {weighted_mean_angle_shift_indexes_int}")
        # Calculate the distances between each defect along the trace
        # Walk once around the trace starting from the first defect; each time another defect (and
        # finally the first one again) is reached, record the distance walked since the previous one
        to_find = np.copy(weighted_mean_angle_shift_indexes_int)
        found_original = False
        original = to_find[0]
        current_index = original
        current_position = pooled_trace_every_nth_point[current_index]
        previous_position = current_position
        distances = []
        distance = 0
        while len(to_find) > 0:
            # Update old position
            previous_position = current_position
            # Increment the current position along the trace
            current_index += 1
            if current_index >= len(pooled_trace_every_nth_point):
                current_index -= len(pooled_trace_every_nth_point)
            current_position = pooled_trace_every_nth_point[current_index]

            # Increment the distance
            distance += np.linalg.norm(current_position - previous_position) * p_to_nm

            # Check if the current index is in the to_find list
            if current_index in to_find:
                # Get rid of the current index from the to_find list
                to_find = to_find[to_find != current_index]
                # Store the distance
                distances.append(distance)
                # Reset the distance
                distance = 0
    else:
        distances = []

    # print(f"distances between defects: {distances}")

    # plt.subplots cannot build a grid with zero columns, so skip the summary when nothing was found
    if plot_results and len(defects) > 0:
        # squeeze=False keeps `ax` two-dimensional even when there is only one defect
        fig, ax = plt.subplots(1, len(defects), figsize=(10 * len(defects), 10), squeeze=False)
        for defect_index, defect in enumerate(defects):
            thisax = ax[0, defect_index]
            # Plot the defect
            thisax.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
            # Plot a horizontal line at the top left of the image starting at (2, 2) with a length equal to the nm distance threshold for defect
            thisax.plot([2, 2 + defect_nm_value / p_to_nm], [2, 2], "r")
            # Plot the points
            thisax.scatter(pooled_trace_every_nth_point[:, 1], pooled_trace_every_nth_point[:, 0], c="k", s=20)
            # Plot the start point of the entire trace
            thisax.scatter(pooled_trace_every_nth_point[0, 1], pooled_trace_every_nth_point[0, 0], c="pink", s=60)
            # Write the angle shift at each point
            for i, point in enumerate(pooled_trace_every_nth_point):
                thisax.text(
                    point[1], point[0], f"{int(np.round(np.degrees(angle_diffs[i])))}", fontsize=14, color="white"
                )
            # Plot the start and end index
            thisax.scatter(
                pooled_trace_every_nth_point[defect["start_index"], 1],
                pooled_trace_every_nth_point[defect["start_index"], 0],
                c="blue",
                s=60,
                alpha=1,
            )
            thisax.scatter(
                pooled_trace_every_nth_point[defect["end_index"], 1],
                pooled_trace_every_nth_point[defect["end_index"], 0],
                c="red",
                s=60,
                alpha=1,
            )
            # Plot the maximum angle shift index
            thisax.scatter(
                pooled_trace_every_nth_point[defect["maximum_angle_shift_index"], 1],
                pooled_trace_every_nth_point[defect["maximum_angle_shift_index"], 0],
                c="green",
                s=100,
                alpha=1,
            )
            # Plot the weighted mean angle shift index
            thisax.scatter(
                pooled_trace_every_nth_point[defect["weighted_mean_angle_shift_index_int"], 1],
                pooled_trace_every_nth_point[defect["weighted_mean_angle_shift_index_int"], 0],
                c="yellow",
                s=100,
                alpha=1,
            )

            # Plot a line between the start and end index following the pooled trace
            if defect["end_index"] > defect["start_index"]:
                thisax.plot(
                    pooled_trace_every_nth_point[defect["start_index"] : defect["end_index"] + 1, 1],
                    pooled_trace_every_nth_point[defect["start_index"] : defect["end_index"] + 1, 0],
                    "b",
                )
            else:
                thisax.plot(
                    np.append(
                        pooled_trace_every_nth_point[defect["start_index"] :, 1],
                        pooled_trace_every_nth_point[: defect["end_index"] + 1, 1],
                    ),
                    np.append(
                        pooled_trace_every_nth_point[defect["start_index"] :, 0],
                        pooled_trace_every_nth_point[: defect["end_index"] + 1, 0],
                    ),
                    "b",
                )

        plt.suptitle(
            f"Defects {len(defects)} for grain {index}, total distance: {total_distance_nm:.2f} nm, defect degrees: {defect_degrees_value}, defect nm: {defect_nm_value}, p_to_nm: {p_to_nm:.2f}"
        )
        fig.tight_layout()
        plt.show()

    # Classification
    # If there are 3 defects then it is a dorito
    # If there are 2 defects then it is either a churro or pasty
    # If the two defects have significantly different distances then it's a pasty, else it's a churro
    # If there is exactly 1 defect then it is a teardrop
    # Any other number of defects is tagged as open

    if len(defects) == 3:
        tag = "dorito"
    elif len(defects) == 2:
        # NOTE(review): if both defects share the same weighted mean index, `distances` holds a single
        # entry and distances[1] raises IndexError - confirm whether that case can occur.
        abs_diff_distance = np.abs(distances[0] - distances[1])
        # print(f"index : {index} abs diff distance: {abs_diff_distance} total distance: {total_distance_nm} pasty threshold: {pasty_distance_deviation_threshold_percentage * total_distance_nm}")
        if abs_diff_distance > pasty_distance_deviation_threshold_percentage * total_distance_nm:
            tag = "pasty"
        else:
            tag = "churro"
    elif len(defects) == 1:
        tag = "teardrop"
    else:
        tag = "open"

    # Note: this stores the same dictionary object that lives in angle_change_rate_grain_dict, so the
    # complex_* keys below are added to that entry as well
    turn_in_distance_grain_dict[index] = grain_data
    turn_in_distance_grain_dict[index]["complex_defects"] = defects
    turn_in_distance_grain_dict[index]["complex_distances_between_defects"] = distances
    turn_in_distance_grain_dict[index]["complex_tag"] = tag
    turn_in_distance_grain_dict[index]["total_distance"] = total_distance_nm
    turn_in_distance_grain_dict[index]["complex_num_defects"] = len(defects)
    turn_in_distance_grain_dict[index]["distances_between_points"] = distances_between_points
    turn_in_distance_grain_dict[index]["complex_angles_per_nm"] = angles_per_nm

def plot_images(
    images: list,
    tags: list,
    traces: list,
    grain_indexes: list,
    defects: list,
    distances_between_defects: list,
    feret_ratios: list,
    area_perimeter_ratios: list,
    px_to_nms: list,
    width=5,
    cmap=cmap,
    vmin=-8,
    vmax=8,
):
    """Plot a grid of grain images, each overlaid with its trace and defect positions.

    All list arguments are parallel: element ``i`` of each list describes the same grain.

    Parameters
    ----------
    images: list
        Images to show with ``imshow``, one per grain.
    tags: list
        Classification tag of each grain (shown in the panel title).
    traces: list
        Per-grain trace arrays. Column 1 is plotted on the x axis and column 0 on the y axis.
    grain_indexes: list
        Identifier of each grain (shown in the panel title).
    defects: list
        Per-grain iterables of defect dictionaries. Each dictionary must provide
        ``"weighted_mean_angle_shift_index_int"``, an index into the grain's trace.
    distances_between_defects: list
        Per-grain iterables of distances between defects (listed and summed in the title).
    feret_ratios: list
        Per-grain feret ratio (shown in the title).
    area_perimeter_ratios: list
        Per-grain area / perimeter ratio (shown in the title).
    px_to_nms: list
        Per-grain pixel to nanometre scaling factor (shown in the title).
    width: int
        Number of panels per row.
    cmap, vmin, vmax
        Passed to ``imshow`` for every panel.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``. Nothing is drawn for empty input.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots cannot create a grid with zero rows, so there is nothing to draw.
        return

    rows = int(np.ceil(num_images / width))
    # squeeze=False always yields a 2D array of axes, so single-row and single-column
    # grids are indexed exactly like any other grid.
    fig, ax = plt.subplots(rows, width, figsize=(30, 50), squeeze=False)
    flat_axes = ax.ravel()  # row-major: panel i sits at row i // width, column i % width

    for i, (
        image,
        tag,
        grain_index,
        trace,
        defect_dict,
        defect_distances,
        feret_ratio,
        area_perimeter_ratio,
    ) in enumerate(
        zip(
            images, tags, grain_indexes, traces, defects, distances_between_defects, feret_ratios, area_perimeter_ratios
        )
    ):
        thisax = flat_axes[i]
        thisax.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        thisax.axis("off")
        distances_between_defects_string = ", ".join([f"{distance:.2f}" for distance in defect_distances])
        thisax.set_title(
            f"index: {grain_index} tag: {tag} p_to_nm: {px_to_nms[i]:.2f}\n defect distances: {distances_between_defects_string} total distance: {np.sum(defect_distances):.2f} nm \n feret ratio: {feret_ratio:.2f} area perimeter ratio: {area_perimeter_ratio:.2f}"
        )
        thisax.plot(trace[:, 1], trace[:, 0], "green")
        # Mark every defect at its position along the trace
        for defect in defect_dict:
            defect_position = trace[defect["weighted_mean_angle_shift_index_int"]]
            thisax.scatter(defect_position[1], defect_position[0], c="white")

    # Blank out the grid cells left over when the last row is only partially filled
    for spare_ax in flat_axes[num_images:]:
        spare_ax.axis("off")

    fig.tight_layout()
    plt.show()


# plot_images(
#     [turn_in_distance_grain_dict[i]["image"] for i in turn_in_distance_grain_dict],
#     [turn_in_distance_grain_dict[i]["tag"] for i in turn_in_distance_grain_dict],
#     [i for i in turn_in_distance_grain_dict],
#     [turn_in_distance_grain_dict[i]["p_to_nm"] for i in turn_in_distance_grain_dict],
# )

# Show one grid of grains per classification tag (tags with no grains are only printed)
for tag_to_plot in ["churro", "pasty", "dorito", "teardrop", "open"]:
    indexes = [
        grain_key
        for grain_key, grain_entry in turn_in_distance_grain_dict.items()
        if grain_entry["complex_tag"] == tag_to_plot
    ]
    print(tag_to_plot, indexes)
    if not indexes:
        continue
    tagged_entries = [turn_in_distance_grain_dict[grain_key] for grain_key in indexes]
    plot_images(
        images=[grain_entry["image"] for grain_entry in tagged_entries],
        tags=[grain_entry["complex_tag"] for grain_entry in tagged_entries],
        traces=[grain_entry["pooled_trace"] for grain_entry in tagged_entries],
        grain_indexes=list(indexes),
        defects=[grain_entry["complex_defects"] for grain_entry in tagged_entries],
        distances_between_defects=[grain_entry["complex_distances_between_defects"] for grain_entry in tagged_entries],
        feret_ratios=[grain_entry["feret_ratio"] for grain_entry in tagged_entries],
        area_perimeter_ratios=[grain_entry["perimeter_area_ratio"] for grain_entry in tagged_entries],
        px_to_nms=[grain_entry["p_to_nm"] for grain_entry in tagged_entries],
    )

# Bar chart of how many grains received each tag
tags = [grain_entry["complex_tag"] for grain_entry in turn_in_distance_grain_dict.values()]
unique_tags = ["churro", "pasty", "dorito", "teardrop", "open"]
counts = [tags.count(tag) for tag in unique_tags]
plt.bar(unique_tags, counts)
plt.title(f"(Turn in distance based) distribution for {SAMPLE_TYPE}")
plt.show()

In [None]:
# from scipy.optimize import minimize

# def circularness(points: np.ndarray, initial_radius: float):
#     """Measure how circular a set of points are by fitting a circle to the points and measuring the residuals and return
#     the center and radius of the circle along with the circularness metric which is the mean distance of the points from
#     the circle divided by the radius of the circle.
#     """

#     x0 = [np.mean(points[:, 0]), np.mean(points[:, 1]), initial_radius]

#     def objective(x):
#         # Objective function to minimize, x[0] is the x coordinate of the center, x[1] is the y coordinate of the center, x[2] is the radius
#         # The sum of squared distances of the points from the circle
#         return np.sum((np.sqrt((points[:, 0] - x[0])**2 + (points[:, 1] - x[1])**2) - x[2])**2)

#     # Minimize the objective function
#     res = minimize(objective, x0, method="nelder-mead")

#     centre_x = res.x[0]
#     centre_y = res.x[1]
#     radius = res.x[2]

#     # squared residuals
#     squared_residuals = (np.sqrt((points[:, 0] - centre_x)**2 + (points[:, 1] - centre_y)**2) - radius)**2
#     sum_squared_residuals = np.sum(squared_residuals)
#     print(f"sum squared residuals: {sum_squared_residuals}")
#     # circularness metric
#     circularness_metric = np.mean(squared_residuals)

#     return centre_x, centre_y, radius, circularness_metric


# # Test the circularness function using a circle
# # Create a circle
# theta = np.linspace(0, 2 * np.pi, 100)
# radius = 100
# center_x = 200
# center_y = 200
# circle_x = center_x + radius * np.cos(theta)
# circle_y = center_y + radius * np.sin(theta)

# # Add noise to the circle
# circle_x += np.random.normal(0, 10, len(circle_x))
# circle_y += np.random.normal(0, 10, len(circle_y))

# # plot the circle
# plt.scatter(circle_x, circle_y)
# plt.show()

# # Fit a circle to the points
# center_x, center_y, radius, circularness_metric = circularness(np.array([circle_x, circle_y]).T, initial_radius=radius)
# print(f"center_x: {center_x:.2f}, center_y: {center_y:.2f}, radius: {radius:.2f}, circularness_metric: {circularness_metric}")

# # plot the circle
# plt.scatter(circle_x, circle_y)
# plt.scatter(center_x, center_y, c="red")
# plt.plot(center_x + radius * np.cos(theta), center_y + radius * np.sin(theta))
# plt.show()

In [None]:
# Save the results
import pickle

filename = f"{SAMPLE_TYPE}_turn_in_distance_grain_dict.pkl"
output_path = SAVE_DIR / filename

# Serialise the whole per-grain results dictionary into the dated save directory
with output_path.open("wb") as f:
    pickle.dump(turn_in_distance_grain_dict, f)

print(f"saved {filename} to {SAVE_DIR}")

In [None]:
# Deliberately raise SystemExit so that execution stops at this cell.
# NOTE(review): presumably here to keep "Run All" from continuing into the cells below - confirm intent.
raise SystemExit

In [None]:
# Curvature analysis
from scipy.interpolate import splprep, splev, UnivariateSpline
import matplotlib.gridspec as gridspec


def plot_colour_line_2d_based_on_value_on_axis(ax: plt.Axes, points: np.ndarray, values: np.ndarray, cmap="viridis_r"):
    """Draw a 2D path on ``ax`` as line segments coloured by a per-point value.

    Parameters
    ----------
    ax: plt.Axes
        Axes to draw on.
    points: np.ndarray
        Nx2 array. Column 0 is plotted on the x axis and column 1 on the y axis.
    values: np.ndarray
        One value per point. The segment from point ``i`` to point ``i + 1`` takes the colour of
        ``values[i]`` after min-max normalisation to the range 0 to 1.
    cmap: str
        Name of the matplotlib colormap used to turn normalised values into colours.
    """
    if len(points) < 2:
        # No segment can be drawn from fewer than two points
        return

    values = np.asarray(values)
    lowest = values.min()
    value_range = values.max() - lowest
    if value_range == 0:
        # Constant values would give 0 / 0 = NaN, which the colormap renders as its "bad" colour;
        # map them all to the bottom of the colormap instead.
        normalised_values = np.zeros(values.shape, dtype=float)
    else:
        normalised_values = (values - lowest) / value_range

    # Look the colormap up once. plt.cm.get_cmap has been removed from recent matplotlib releases,
    # plt.get_cmap is the supported equivalent.
    colour_map = plt.get_cmap(cmap)
    for i in range(len(points) - 1):
        ax.plot(points[i : i + 2, 0], points[i : i + 2, 1], color=colour_map(normalised_values[i]))


def plot_colour_line_1d_based_on_value_on_axis(
    ax: plt.Axes, values: np.ndarray, hline: Union[None, float] = None, cmap="viridis_r"
):
    """Plot ``values`` against their index on ``ax`` as line segments coloured by value.

    Parameters
    ----------
    ax: plt.Axes
        Axes to draw on.
    values: np.ndarray
        1D array of values. The segment from sample ``i`` to sample ``i + 1`` takes the colour of
        ``values[i]`` after min-max normalisation to the range 0 to 1.
    hline: Union[None, float]
        If given, a solid black horizontal line is drawn at this y value.
    cmap: str
        Name of the matplotlib colormap used to turn normalised values into colours.
    """
    values = np.asarray(values)
    if len(values) >= 2:
        lowest = values.min()
        value_range = values.max() - lowest
        if value_range == 0:
            # Constant values would give 0 / 0 = NaN, which the colormap renders as its "bad"
            # colour; map them all to the bottom of the colormap instead.
            normalised_values = np.zeros(values.shape, dtype=float)
        else:
            normalised_values = (values - lowest) / value_range

        # Look the colormap up once. plt.cm.get_cmap has been removed from recent matplotlib
        # releases, plt.get_cmap is the supported equivalent.
        colour_map = plt.get_cmap(cmap)
        xs = np.arange(len(values))
        for i in range(len(values) - 1):
            ax.plot(xs[i : i + 2], values[i : i + 2], color=colour_map(normalised_values[i]))

    if hline is not None:
        ax.axhline(hline, color="k", linestyle="-")


def interpolate_points_spline(points: np.ndarray, num_points: Union[int, None] = None, smoothing: float = 0.0):
    """Resample a closed path by fitting a periodic parametric spline through it.

    Parameters
    ----------
    points: np.ndarray
        Nx2 Numpy array of coordinates for the points.
    num_points: Union[int, None]
        How many points to sample along the fitted spline. Defaults to the number of input points.
    smoothing: float
        Smoothing condition handed to ``splprep`` as ``s``.

    Returns
    -------
    interpolated_points: np.ndarray
        An Ix2 Numpy array of coordinates sampled from the spline, where I is ``num_points``. The
        samples are evenly spaced in the spline parameter, from its smallest to its largest value
        inclusive.
    """
    sample_count = points.shape[0] if num_points is None else num_points

    # splprep returns the spline representation (tck) and the parameter value (u) assigned to
    # each input point; per=1 requests a periodic (closed) curve.
    tck, u = splprep(points.T, u=None, s=smoothing, per=1)
    sample_parameters = np.linspace(u.min(), u.max(), sample_count)
    first_coordinate, second_coordinate = splev(sample_parameters, tck, der=0)
    return np.column_stack((first_coordinate, second_coordinate))


def calculate_curvature_from_points(x_points, y_points, error=0.1, k=4):
    """Calculate the signed curvature at every point of a path.

    The x and y coordinates are each fitted with a weighted ``UnivariateSpline`` against a dummy
    parameter ``t = 0, 1, 2, ...``. The curvature is then computed from the first and second
    derivatives of both splines.

    Parameters
    ----------
    x_points: np.ndarray
        1D numpy array of x coordinates.
    y_points: np.ndarray
        1D numpy array of y coordinates, the same length as ``x_points``.
    error: float
        Error in the points. Every point gets the weight ``1 / sqrt(error)``, so a larger error
        gives less weight to the data.
    k: int
        Order of the splines. 1 is linear, 2 is quadratic, 3 is cubic, etc. The default of 4 keeps
        the spline smooth enough to differentiate twice.

    Returns
    -------
    tuple
        ``(curvatures, spline_x, spline_y)``: the curvature at each point and the fitted spline
        coordinates evaluated at each point.

    Raises
    ------
    ValueError
        If ``x_points`` and ``y_points`` differ in length.
    """
    num_x = x_points.shape[0]
    num_y = y_points.shape[0]
    if num_x != num_y:
        raise ValueError(
            f"x_points and y_points must have the same number of points. x_points has {num_x} points and y_points has {num_y} points."
        )

    # t is the independent variable that monotonically increases along the path
    t = np.arange(num_x)
    weight_values = 1 / np.sqrt(error * np.ones_like(x_points))

    def _fit_axis(coordinates):
        # Fit one coordinate against t; return the fitted values plus first and second derivatives
        spline = UnivariateSpline(t, coordinates, k=k, w=weight_values)
        return spline(t), spline.derivative(1)(t), spline.derivative(2)(t)

    spline_x, dx, dx2 = _fit_axis(x_points)
    spline_y, dy, dy2 = _fit_axis(y_points)

    # Signed curvature of a parametric curve: (x'y'' - y'x'') / (x'^2 + y'^2)^(3/2)
    speed_squared = dx**2 + dy**2
    curvatures = (dx * dy2 - dy * dx2) / np.power(speed_squared, 3 / 2)
    return curvatures, spline_x, spline_y


def calculate_curvature_periodic_boundary(x_points, y_points, error=0.1, periods=2, k=4):
    """Take a set of points that form a loop and calculate the curvature. Uses periodic boundary conditions, so
    the first and last points are connected. This reduces the error in the curvature calculation at the
    boundaries.

    The points are repeated ``periods`` times on either side of the original points, the curvature of
    the whole extended sequence is calculated, and only the values for the centre copy are returned.

    Parameters
    ----------
    x_points: np.ndarray
        1D numpy array of x coordinates of the points.
    y_points: np.ndarray
        1D numpy array of y coordinates of the points.
    error: float
        Error in the points. Used to weight the points in the spline calculation.
    periods: int
        Number of times to repeat the points either side of the original points to reduce the error at the
        boundaries.
    k: int
        Order of the spline, passed through to ``calculate_curvature_from_points``.

    Returns
    -------
    tuple
        ``(curvature, spline_x, spline_y)``: 1D numpy arrays holding the curvature and the fitted
        spline coordinates for each original point.

    Raises
    ------
    ValueError
        If ``x_points`` and ``y_points`` differ in length.
    """

    # Check that the number of points is the same for both x and y
    if x_points.shape[0] != y_points.shape[0]:
        raise ValueError(
            f"x_points and y_points must have the same number of points. x_points has"
            f" {x_points.shape[0]} points and y_points has {y_points.shape[0]} points."
        )

    num_points = x_points.shape[0]

    # `periods` copies before + the original + `periods` copies after
    num_copies = 2 * periods + 1
    extended_points_x = np.tile(x_points, num_copies)
    extended_points_y = np.tile(y_points, num_copies)

    # Calculate the curvature
    extended_curvature, spline_x, spline_y = calculate_curvature_from_points(
        extended_points_x, extended_points_y, error=error, k=k
    )

    # Return only the centre copy (copy number `periods`, counting from zero). It is the one with
    # `periods` repeats on both sides. The previous slice based on int(periods / 2) selected the
    # very first copy when periods=1, i.e. exactly the boundary region this function exists to avoid.
    centre = slice(num_points * periods, num_points * (periods + 1))
    return (
        extended_curvature[centre],
        spline_x[centre],
        spline_y[centre],
    )


def turn_path_into_pixel_map(array: np.ndarray):
    """Convert a path of floating point coordinates into a thinned, pixelated path.

    Coordinates are converted to integers and consecutive duplicates are dropped. An interior
    coordinate is then skipped whenever the coordinate after it already touches (is within one
    pixel in both directions of) the most recently kept coordinate.

    Parameters
    ----------
    array: np.ndarray
        Nx2 array of path coordinates.

    Returns
    -------
    pixel_map: np.ndarray
        Square integer array with side ``int(max(array) + 1)``. Kept coordinates, plus the first
        and last coordinates of the path, are set to 1.
    pixelated_path: np.ndarray
        Mx2 integer array of the kept interior coordinates, in path order. The first and last
        coordinates of the path are marked in ``pixel_map`` but are not included here.
    """
    side_length = int(np.max(array) + 1)
    pixel_map = np.zeros((side_length, side_length), dtype=int)

    def check_is_touching(coordinate, original_coordinate):
        # Touching means at most one pixel apart along both axes
        first_gap = np.abs(coordinate[0] - original_coordinate[0])
        second_gap = np.abs(coordinate[1] - original_coordinate[1])
        return bool(first_gap <= 1 and second_gap <= 1)

    # Convert the array to integers and remove consecutive duplicates
    truncated = np.array(array, dtype=int)
    deduplicated = [
        coordinate
        for position, coordinate in enumerate(truncated)
        if position == 0 or not np.array_equal(coordinate, truncated[position - 1])
    ]
    integer_array = np.array(deduplicated)

    final_position = len(integer_array) - 1
    kept_coordinates = []
    last_coordinate = None
    for position, coordinate in enumerate(integer_array):
        if position == 0 or position == final_position:
            # Endpoints are always marked on the map but never added to the path
            pixel_map[coordinate[0], coordinate[1]] = 1
            last_coordinate = coordinate
            continue

        # If the coordinate after this one touches the last kept coordinate, this pixel is
        # redundant for a one pixel thick path, so skip it
        if check_is_touching(integer_array[position + 1], last_coordinate):
            continue

        pixel_map[int(coordinate[0]), int(coordinate[1])] = 1
        kept_coordinates.append(coordinate)
        last_coordinate = coordinate

    pixelated_path = np.array(kept_coordinates, dtype=int).reshape(-1, 2)
    return pixel_map, pixelated_path


def defect_stats(height_trace: np.ndarray, threshold: float):
    """Find the contiguous regions of a 1D trace that lie below a threshold.

    Parameters
    ----------
    height_trace: np.ndarray
        1D sequence of values to scan.
    threshold: float
        A sample belongs to a defect region when its value is strictly below this threshold.

    Returns
    -------
    dict
        ``"defect_number"``: number of regions found.
        ``"defect_largest_region"``: the region with the largest area (the first one on ties), or
        ``None`` when there are no regions.
        ``"defect_regions"``: list of region dictionaries in trace order, each holding ``"start"``
        (index of the first sample in the region), ``"end"`` (index one past the last sample in the
        region), ``"area"`` (sum of ``threshold - value`` over every sample in the region),
        ``"deepest_point_index"`` (index of the lowest sample) and ``"defect_threshold"``.
    """
    regions = []
    open_region = None
    deepest_point = None

    for index, value in enumerate(height_trace):
        if value < threshold:
            if open_region is None:
                # Start region
                open_region = {
                    "start": index,
                    "end": None,
                    "area": 0,
                    "deepest_point_index": index,
                    "defect_threshold": threshold,
                }
                deepest_point = value
            elif value < deepest_point:
                deepest_point = value
                open_region["deepest_point_index"] = index
            # Every sample of the region contributes to its area, including the first one, so
            # a single-sample region has a non-zero area and can be picked as the largest region.
            open_region["area"] += threshold - value
        elif open_region is not None:
            open_region["end"] = index
            regions.append(open_region)
            open_region = None

    # A region still open at the end of the trace is closed there rather than silently dropped
    if open_region is not None:
        open_region["end"] = len(height_trace)
        regions.append(open_region)

    # Find the largest region (None when there are no regions)
    largest_region_below_threshold = max(regions, key=lambda region: region["area"], default=None)

    return {
        "defect_number": len(regions),
        "defect_largest_region": largest_region_below_threshold,
        "defect_regions": regions,
    }


# Curvature analysis of every grain in pooled_curvature_grain_dict.
# plotting: draw a diagnostic figure per grain.
# stop_plotting_after: only grains whose index is below this value are plotted.
# curvature_grain_dict: results of this cell, keyed by grain index.
plotting = True
stop_plotting_after = 70
curvature_grain_dict = {}

for index, grain_data in pooled_curvature_grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    p_to_nm = grain_data["p_to_nm"]
    trace = grain_data["trace"]
    pooled_trace = grain_data["pooled_trace"]
    height_trace = grain_data["height_trace"]

    # Resample the pooled trace to 200 points along a periodic spline (no smoothing)
    interpolated_points = interpolate_points_spline(pooled_trace, num_points=200, smoothing=0.0)

    # Error is the error in the points. Used to weight the points in the spline calculation.
    # Low error means the spline will pass through the points, high error means the spline will
    # be smooth and not pass through the points.

    # Increase smoothing (error) in images that are higher resolution ie lower pixel to nm ratio
    error = 0.01 / p_to_nm

    # Column 1 of the interpolated points is passed as x and column 0 as y
    curvature, spline_x, spline_y = calculate_curvature_periodic_boundary(
        interpolated_points[:, 1], interpolated_points[:, 0], error=error, periods=2, k=5
    )

    # Flip the sign of the curvature, so the negative threshold used for defect_stats below
    # applies to the flipped values.
    # NOTE(review): presumably a sign convention tied to the trace direction - confirm.
    curvature = -curvature

    spline_points = np.array((spline_x, spline_y)).T

    # Get pixel spline trace
    spline_pixel_trace_img, spline_pixelated_path = turn_path_into_pixel_map(spline_points)

    # Get the height trace from the spline pixelated path
    # (path column 1 indexes image rows, path column 0 indexes image columns)
    spline_pixelated_path_heights = grain_image[spline_pixelated_path[:, 1], spline_pixelated_path[:, 0]]

    # Defects are the contiguous stretches where the (sign-flipped) curvature is below -0.1
    defect_stats_dict = defect_stats(curvature, threshold=-0.1)

    if plotting:
        # NOTE(review): this comparison assumes the grain index is numeric - confirm.
        if index < stop_plotting_after:
            # Plot scatter with colour as curvature
            # fig, ax = plt.subplots(1, 3, figsize=(20, 10))

            # 5 x 2 grid: three rows of paired image panels, then two full-width line plots
            fig = plt.figure(figsize=(12, 12))
            gs = gridspec.GridSpec(5, 2, figure=fig)

            ax0 = fig.add_subplot(gs[0, 0])
            ax1 = fig.add_subplot(gs[0, 1])
            ax2 = fig.add_subplot(gs[1, 0])
            ax3 = fig.add_subplot(gs[1, 1])
            ax4 = fig.add_subplot(gs[2, 0])
            ax5 = fig.add_subplot(gs[2, 1])
            ax6 = fig.add_subplot(gs[3, :])
            ax7 = fig.add_subplot(gs[4, :])

            ax0.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
            ax0.set_title("Grain Image")
            ax1.imshow(grain_mask, cmap="gray")
            ax1.set_title("Grain Mask")

            ax2.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
            ax2.plot(trace[:, 1], trace[:, 0], "r")
            ax2.plot(pooled_trace[:, 1], pooled_trace[:, 0], "b")
            ax2.legend(["Trace", "Pooled Trace"])
            ax2.set_title("Pixelated Path")
            ax3.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
            # Plot spline pixelated path including the first point to close the loop
            ax3.plot(
                np.append(spline_pixelated_path[:, 0], spline_pixelated_path[0, 0]),
                np.append(spline_pixelated_path[:, 1], spline_pixelated_path[0, 1]),
                "r",
            )
            ax3.set_title("Spline Pixel Path")
            ax4.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
            plot_colour_line_2d_based_on_value_on_axis(ax4, spline_points, curvature)
            ax4.set_title("Spline Path with Curvature")
            # Plot spline pixelated trace heights overlaid on the grain image
            ax5.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
            plot_colour_line_2d_based_on_value_on_axis(ax5, spline_pixelated_path, spline_pixelated_path_heights)
            ax5.set_title("Spline Pixellated Path with Heights")
            plot_colour_line_1d_based_on_value_on_axis(ax6, curvature, hline=0)
            ax6.set_title("Curvature")
            ax6.set_ylim(-0.5, 0.5)
            # Add red regions over defects
            for region in defect_stats_dict["defect_regions"]:
                ax6.axvspan(region["start"], region["end"], color="r", alpha=0.5)
            # Add a black line at the deepest point of the largest defect
            # But only if there are defects
            if defect_stats_dict["defect_number"] > 0:
                ax6.axvline(
                    defect_stats_dict["defect_largest_region"]["deepest_point_index"],
                    color="k",
                    linestyle="--",
                    alpha=0.5,
                )
            plot_colour_line_1d_based_on_value_on_axis(ax7, spline_pixelated_path_heights)
            ax7.set_ylim(0.0, 4.0)
            ax7.set_title("Spline Pixellated Path Heights")
            fig.tight_layout()
            plt.show()

    # Store the per-grain results (the pooled trace is not carried forward)
    curvature_grain_dict[index] = {
        "image": grain_image,
        "mask": grain_mask,
        "trace": trace,
        "height_trace": height_trace,
        "spline_points": spline_points,
        "curvature": curvature,
        "num_curvature_defects": defect_stats_dict["defect_number"],
        "p_to_nm": p_to_nm,
    }

# Number of curvature defects found in each analysed grain
num_defects = [grain_entry["num_curvature_defects"] for grain_entry in curvature_grain_dict.values()]
num_defects_title = f"Number of defects for {SAMPLE_TYPE} {len(num_defects)} grains"

# Smoothed (kernel density estimate) view of the distribution
sns.kdeplot(num_defects)
plt.xlabel("Number of defects")
plt.title(num_defects_title)
plt.show()

# Exact counts as a bar chart: how many grains have each number of defects
from collections import Counter

num_defects_counter = Counter(num_defects)
plt.bar(num_defects_counter.keys(), num_defects_counter.values())
plt.xlabel("Number of defects")
plt.ylabel("Number of grains")
plt.title(num_defects_title)
plt.show()

In [None]:
# Kernel density estimate of the trace lengths in nm (number of trace points times p_to_nm)
trace_lengths = [
    len(grain_entry["trace"]) * grain_entry["p_to_nm"] for grain_entry in curvature_grain_dict.values()
]
num_traces = len(trace_lengths)

sns.kdeplot(trace_lengths)
plt.xlabel("Trace length (nm)")
plt.title(f"Trace length for {SAMPLE_TYPE} {num_traces} grains n={num_traces}")
# Start the x axis at zero and leave 10% headroom past the longest trace
plt.xlim(0, np.max(trace_lengths) * 1.1)
plt.show()

### Labelling grains for CNN classification

In [None]:
# Manually label each image with a tag: churro, dorito, pasty or teardrop
from IPython.display import clear_output
import pickle

# Either load previously saved manual tags, or show each grain in turn and ask for a tag
tagged_grain_dict = {}

TAGGED_GRAINS_FILENAME = f"{SAMPLE_TYPE}_tagged_grains.pkl"

# Set to False to (re)label the grains by hand instead of loading the saved tags
ALREADY_LABELLED = True
LABELLED_DICTIONARY_PATH = Path(
    f"/Users/sylvi/topo_data/hariborings/dna_manual_tags/{SAMPLE_TYPE}/{TAGGED_GRAINS_FILENAME}"
)

# Key the user types -> tag stored for the grain
TAG_CHOICES = {"1": "churro", "2": "dorito", "3": "pasty", "4": "teardrop"}

if ALREADY_LABELLED:
    # Load the already labelled dictionary
    # NOTE: pickle.load can execute arbitrary code - only load files that you created yourself.
    with open(LABELLED_DICTIONARY_PATH, "rb") as f:
        tagged_grain_dict = pickle.load(f)
else:
    # Spell out the accepted answers in the prompt: only the keys of TAG_CHOICES (or "exit") are
    # accepted, typing a tag name is not.
    choices_prompt = ", ".join(f"{key}={tag_name}" for key, tag_name in TAG_CHOICES.items())

    # Manually label the grains
    for index, grain_data in curvature_grain_dict.items():
        grain_image = grain_data["image"]
        grain_mask = grain_data["mask"]
        trace = grain_data["trace"]
        height_trace = grain_data["height_trace"]
        p_to_nm = grain_data["p_to_nm"]
        curvature = grain_data["curvature"]
        num_curvature_defects = grain_data["num_curvature_defects"]

        # Show the grain next to its curvature profile to help with the decision
        fig, ax = plt.subplots(1, 2, figsize=(20, 10))
        ax[0].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax[0].set_title(f"Grain {index} - {SAMPLE_TYPE}")
        ax[1].plot(curvature)
        ax[1].set_title(f"Curvature - {num_curvature_defects} defects")
        plt.show()

        # Keep asking until a valid key is entered; "exit" aborts the labelling session
        tag = None
        while tag is None:
            response = input(f"Tag grain {index} - {SAMPLE_TYPE} ({choices_prompt}, or exit to stop): ").strip()
            if response == "exit":
                raise ValueError("Exiting")
            tag = TAG_CHOICES.get(response)

        tagged_grain_dict[index] = {
            "image": grain_image,
            "mask": grain_mask,
            "trace": trace,
            "height_trace": height_trace,
            "curvature": curvature,
            "num_curvature_defects": num_curvature_defects,
            "p_to_nm": p_to_nm,
            "tag": tag,
        }

        # Remove the figure and prompt before showing the next grain
        clear_output()

In [None]:
# Display the tagged grains, plotting them in grids


def plot_images(images: list, masks: list, px_to_nms: list, grain_indexes: list, width=5, cmap=cmap, vmin=-8, vmax=8):
    """Plot a grid of grains, each shown as an image panel with its mask panel to the right.

    All list arguments are parallel: element ``i`` of each list describes the same grain.

    Parameters
    ----------
    images: list
        Images to show with ``imshow``, one per grain.
    masks: list
        Masks to show with ``imshow`` using the "binary" colormap, one per grain.
    px_to_nms: list
        Per-grain pixel to nanometre scaling factor (shown in the panel title).
    grain_indexes: list
        Identifier of each grain (shown in the panel title).
    width: int
        Number of grains per row. Each grain takes two columns of the grid.
    cmap, vmin, vmax
        Passed to ``imshow`` for the image panels.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``. Nothing is drawn for empty input.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots cannot create a grid with zero rows, so there is nothing to draw.
        return

    rows = int(np.ceil(num_images / width))
    # squeeze=False always yields a 2D array of axes. Without it a single row comes back as a 1D
    # array and the two-index lookups below raise IndexError.
    fig, ax = plt.subplots(rows, width * 2, figsize=(30, 30), squeeze=False)

    for i, (image, mask, grain_index) in enumerate(zip(images, masks, grain_indexes)):
        row, column = divmod(i, width)
        image_ax = ax[row, column * 2]
        mask_ax = ax[row, column * 2 + 1]
        image_ax.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        image_ax.axis("off")
        image_ax.set_title(f"grain index: {grain_index} p_to_nm: {px_to_nms[i]}")
        mask_ax.imshow(mask, cmap="binary")

    # Blank out the grid cells left over when the last row is only partially filled
    # (in row-major order grain i occupies flat positions 2i and 2i + 1)
    for spare_ax in ax.ravel()[num_images * 2 :]:
        spare_ax.axis("off")

    fig.tight_layout()
    plt.show()


classes = np.unique([grain_entry["tag"] for grain_entry in tagged_grain_dict.values()])

# Grain indexes grouped by their manual tag
churro_indexes = [grain_key for grain_key, grain_entry in tagged_grain_dict.items() if grain_entry["tag"] == "churro"]
dorito_indexes = [grain_key for grain_key, grain_entry in tagged_grain_dict.items() if grain_entry["tag"] == "dorito"]
pasty_indexes = [grain_key for grain_key, grain_entry in tagged_grain_dict.items() if grain_entry["tag"] == "pasty"]

# Plot the churros, then the doritos, then the pasties
for tag_indexes in (churro_indexes, dorito_indexes, pasty_indexes):
    plot_images(
        [tagged_grain_dict[grain_key]["image"] for grain_key in tag_indexes],
        [tagged_grain_dict[grain_key]["mask"] for grain_key in tag_indexes],
        [tagged_grain_dict[grain_key]["p_to_nm"] for grain_key in tag_indexes],
        tag_indexes,
    )

In [None]:
# Convert a single tagged grain to a different tag
# Step 1: set grain_index to the grain to inspect and display its image.
grain_index = 23
plt.imshow(tagged_grain_dict[grain_index]["image"], cmap=cmap, vmin=-8, vmax=8)
# Step 2: uncomment the line below (with the desired tag) and re-run the cell to change the tag.
# tagged_grain_dict[grain_index]["tag"] = "dorito"

In [None]:
# Save the manually tagged grains for later use as a pickle
import pickle

TAGGED_GRAIN_SAVE_DIR = Path(f"/Users/sylvi/topo_data/hariborings/dna_manual_tags/{SAMPLE_TYPE}/")
TAGGED_GRAIN_SAVE_DIR.mkdir(exist_ok=True, parents=True)

# Write the whole tagged dictionary to a single pickle inside the save directory
tagged_grains_path = TAGGED_GRAIN_SAVE_DIR / f"{SAMPLE_TYPE}_tagged_grains.pkl"
with tagged_grains_path.open("wb") as f:
    pickle.dump(tagged_grain_dict, f)

In [None]:
# Write every tagged grain image to its own file, named image_{index}_{tag}.npy

TAGGED_GRAIN_SAVE_IMAGES_DIR = TAGGED_GRAIN_SAVE_DIR / "images"
TAGGED_GRAIN_SAVE_IMAGES_DIR.mkdir(exist_ok=True, parents=True)

for index, grain_data in tagged_grain_dict.items():
    tag, image = grain_data["tag"], grain_data["image"]
    image_path = TAGGED_GRAIN_SAVE_IMAGES_DIR / f"image_{index}_{tag}.npy"
    np.save(image_path, image)