In [None]:
from pathlib import Path
from typing import Union
from datetime import datetime
import pickle

import matplotlib.pyplot as plt
import numpy as np
import h5py

import seaborn as sns


from topostats.plottingfuncs import Colormap

# Colour map from TopoStats, shared by every image plot in this notebook.
colormap = Colormap()
cmap = colormap.get_cmap()

# Colour limits used by the single-grain close-up plots.
VMIN, VMAX = 0, 2.5

In [None]:
# Load the extracted grains for one sample from a pickle file and keep only the first
# MAXIMUM_GRAIN_NUMBER of them.

SAMPLE_TYPE = "OT2_SC"
DATA_LOADED_DAY = "2024-03-11"
TODAY = datetime.now().strftime("%Y-%m-%d")
print(f"Today: {TODAY}")
# NOTE(review): absolute, user-specific paths; consider deriving these from one configurable base directory.
DATA_DIR = Path(f"/Users/sylvi/topo_data/hariborings/extracted_grains/unbound_{SAMPLE_TYPE}/date_{DATA_LOADED_DAY}/")
SAVE_DIR = Path(f"/Users/sylvi/topo_data/hariborings/processed_grains/unbound_{SAMPLE_TYPE}/date_{TODAY}/")
SAVE_DIR.mkdir(exist_ok=True, parents=True)

# Number of grains kept for analysis: the first N in the pickle's insertion order.
MAXIMUM_GRAIN_NUMBER = 210

# Load the data.
# pickle.load can execute arbitrary code - only load files produced by our own extraction pipeline.
FILE_PATH = DATA_DIR / "grain_dict.pkl"
with open(FILE_PATH, "rb") as f:
    grain_dict = pickle.load(f)

print(f"number of images for sample type [{SAMPLE_TYPE}] : {len(grain_dict)}")

# Cut off the number of grains: slicing the (insertion-ordered) items keeps exactly
# MAXIMUM_GRAIN_NUMBER grains.
grain_dict_sample = dict(list(grain_dict.items())[:MAXIMUM_GRAIN_NUMBER])

print(f"number of images in sample for {SAMPLE_TYPE} : {len(grain_dict_sample)}")

grain_dict = grain_dict_sample

In [None]:
# Skeletonise using standard skeletonise
from skimage.morphology import skeletonize
from scipy.ndimage import convolve

plot_results = False


def plot_images(
    images: list,
    masks: list,
    grain_indexes: list,
    px_to_nms: list,
    skeletons: list,
    width=5,
    cmap=cmap,
    vmin=-8,
    vmax=8,
):
    """Plot a grid of grains, three panels per grain: image, skeleton, and image with the skeleton overlaid.

    Parameters
    ----------
    images : list
        Grain height images.
    masks : list
        Grain masks. Not drawn; only paired with the other lists (``zip`` stops at the shortest list).
    grain_indexes : list
        Grain identifiers, shown in the first panel's title.
    px_to_nms : list
        Pixel-to-nanometre factor per grain, shown in the first panel's title.
    skeletons : list
        Skeleton arrays per grain.
    width : int
        Grains per row; the figure has ``3 * width`` columns.
    cmap, vmin, vmax
        Colour map and colour limits for the image panels.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots cannot build a grid with zero rows.
        return
    rows = int(np.ceil(num_images / width))
    # squeeze=False keeps `ax` 2D even when every grain fits on a single row; otherwise
    # plt.subplots returns a 1D array and ax[row, col] raises.
    fig, ax = plt.subplots(rows, width * 3, figsize=(30, 30), squeeze=False)
    for i, (image, _mask, skeleton, grain_index) in enumerate(zip(images, masks, skeletons, grain_indexes)):
        row, col = i // width, (i % width) * 3
        ax[row, col].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, col].axis("off")
        ax[row, col].set_title(f"index : {grain_index} p_to_nm: {px_to_nms[i]}")
        ax[row, col + 1].imshow(skeleton, cmap="binary")
        ax[row, col + 2].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, col + 2].imshow(skeleton, cmap="viridis", alpha=0.2)
    fig.tight_layout()
    plt.show()


def convolve_skelly(skeleton) -> np.ndarray:
    """Convolves the skeleton with a 3x3 ones kernel to produce an array
    of the skeleton as 1, endpoints as 2, and nodes as 3.

    The convolution value at a skeleton pixel is 1 + its number of 8-connected skeleton neighbours:
    2 is an endpoint (one neighbour), 3 is an ordinary trace pixel (two neighbours), and anything
    above 3 is a junction. Isolated pixels (value 1) are also reported as skeleton (=1).

    Parameters
    ----------
    skeleton: np.ndarray
        Single pixel thick binary trace(s) within an array.

    Returns
    -------
    np.ndarray
        The skeleton (=1) with endpoints (=2), and crossings (=3) highlighted; background is 0.
    """
    # mode="constant" (zero padding): scipy's default "reflect" mode mirrors border pixels back into the
    # 3x3 window, counting them twice and turning trace pixels on the array edge into spurious nodes.
    conv = convolve(skeleton.astype(np.int32), np.ones((3, 3)), mode="constant", cval=0)
    conv[skeleton == 0] = 0  # remove non-skeleton points
    conv[conv == 3] = 1  # skelly = 1
    conv[conv > 3] = 3  # nodes = 3
    return conv


plotting = False
paths_grain_dict = {}
# Method for finding the optimal path in the molecule mask
# Options: skeletonize, distance transform, height
# NOTE(review): not read by the skeletonisation loop below; confirm whether any later cell still uses it.
path_method = "distance transform"


def _show_skeletonisation(mask: np.ndarray, image: np.ndarray, skel: np.ndarray) -> None:
    """Debug plot of one grain: mask, image, skeleton, and skeleton overlaid on the image."""
    fig, axes = plt.subplots(1, 4, figsize=(20, 10))
    axes[0].imshow(mask, cmap="gray")
    axes[0].set_title("grain mask")
    axes[1].imshow(image, cmap=cmap, vmin=-8, vmax=8)
    axes[1].set_title("grain image")
    axes[2].imshow(skel, cmap="gray")
    axes[2].set_title("skeleton")
    axes[3].imshow(image, cmap=cmap, vmin=-8, vmax=8)
    axes[3].imshow(skel, cmap="viridis", alpha=0.2)
    plt.show()


for index, grain_data in grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    p_to_nm = grain_data["p_to_nm"]

    # Thin the mask down to a one-pixel-wide skeleton
    skeleton = skeletonize(grain_mask)
    if plotting:
        _show_skeletonisation(grain_mask, grain_image, skeleton)

    # convolve_skelly labels pixels 1 (two neighbours, or isolated), 2 (endpoint) or 3 (junction).
    # Anything above 1 is rejected, so grains with free ends are dropped as well as branched ones.
    convolved_skelly = convolve_skelly(skeleton)
    if convolved_skelly.max() > 1:
        print(f"Grain {index} has a branch")
        continue

    paths_grain_dict[index] = dict(
        image=grain_image,
        mask=grain_mask,
        skeleton=skeleton.astype(bool),
        p_to_nm=p_to_nm,
    )

if plot_results:
    kept_indexes = list(paths_grain_dict)
    plot_images(
        [paths_grain_dict[k]["image"] for k in kept_indexes],
        [paths_grain_dict[k]["mask"] for k in kept_indexes],
        kept_indexes,
        [paths_grain_dict[k]["p_to_nm"] for k in kept_indexes],
        [paths_grain_dict[k]["skeleton"] for k in kept_indexes],
    )

In [None]:
# Trace the skeleton

plot_results = True


def plot_images(
    images: list, masks: list, px_to_nms: list, skeletons: list, traces: list, width=5, cmap=cmap, vmin=-8, vmax=8
):
    """Plot a grid of traced grains, three panels per grain.

    Panels are: the grain image (titled with its pixel-to-nm factor), the skeleton, and the image with
    the skeleton and the trace (red line) overlaid.

    Parameters
    ----------
    images : list
        Grain height images.
    masks : list
        Grain masks. Not drawn; only paired with the other lists (``zip`` stops at the shortest list).
    px_to_nms : list
        Pixel-to-nanometre factor per grain, shown in the first panel's title.
    skeletons : list
        Skeleton arrays per grain.
    traces : list
        (N, 2) arrays of (row, col) trace coordinates per grain.
    width : int
        Grains per row; the figure has ``3 * width`` columns and is 10 inches tall per row.
    cmap, vmin, vmax
        Colour map and colour limits for the image panels.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots cannot build a grid with zero rows.
        return
    rows = int(np.ceil(num_images / width))
    # squeeze=False keeps `ax` 2D even when every grain fits on a single row.
    fig, ax = plt.subplots(rows, width * 3, figsize=(30, rows * 10), squeeze=False)
    for i, (image, _mask, skeleton, trace) in enumerate(zip(images, masks, skeletons, traces)):
        row, col = i // width, (i % width) * 3
        ax[row, col].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, col].axis("off")
        ax[row, col].set_title(f"p_to_nm: {px_to_nms[i]}")
        ax[row, col + 1].imshow(skeleton, cmap="binary")
        ax[row, col + 2].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, col + 2].imshow(skeleton, cmap="viridis", alpha=0.2)
        ax[row, col + 2].plot(trace[:, 1], trace[:, 0], "r")
    fig.tight_layout()
    plt.show()


plotting = False
trace_grain_dict = {}


def _order_closed_skeleton(skeleton: np.ndarray) -> np.ndarray:
    """Walk a one-pixel-thick skeleton and return its pixels in path order.

    Parameters
    ----------
    skeleton : np.ndarray
        2D array containing a single one-pixel-thick trace in which every pixel has at most two
        8-connected neighbours.

    Returns
    -------
    np.ndarray
        (N, 2) integer array of (row, col) coordinates ordered along the trace, starting from the
        first skeleton pixel in raster order.

    Raises
    ------
    ValueError
        If a pixel after the start has more than one unvisited neighbour, or the walk dead-ends
        before every skeleton pixel has been visited.
    """
    skeleton_points = np.argwhere(skeleton)
    # Pad by one pixel so the 3x3 neighbourhood slice never starts at a negative index when the
    # skeleton touches the image border (a negative start silently yields a wrong or empty window).
    unvisited = np.pad(skeleton.astype(bool), 1)
    current_point = skeleton_points[0]
    trace = [current_point]
    unvisited[current_point[0] + 1, current_point[1] + 1] = False

    for iteration in range(len(skeleton_points) - 1):
        # In padded coordinates the 3x3 window centred on (r, c) spans rows r..r+2 and cols c..c+2.
        neighbourhood = unvisited[current_point[0] : current_point[0] + 3, current_point[1] : current_point[1] + 3]
        num_neighbours = np.sum(neighbourhood)
        # The start pixel may have two neighbours (either direction works); later pixels must have one.
        if iteration > 0 and num_neighbours > 1:
            raise ValueError(f"More than 1 neighbour for iteration {iteration}")
        if num_neighbours == 0:
            raise ValueError(f"No neighbours for iteration {iteration}")
        # argwhere gives the offset inside the 3x3 window (0..2); subtract 1 to get the step (-1..1).
        current_point = current_point + np.argwhere(neighbourhood)[0] - 1
        trace.append(current_point)
        unvisited[current_point[0] + 1, current_point[1] + 1] = False

    return np.array(trace)


for index, grain_data in paths_grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    skeleton = grain_data["skeleton"]
    p_to_nm = grain_data["p_to_nm"]

    # An empty mask gives an empty skeleton, which passes the branch filter but cannot be traced.
    if not np.any(skeleton):
        print(f"Grain {index} has an empty skeleton, skipping")
        continue

    # Order the skeleton pixels along the molecule
    trace = _order_closed_skeleton(skeleton)

    # Label every skeleton pixel with its position along the trace. The +20 offset presumably keeps the
    # first point distinguishable from the 0 background in colour-mapped plots - confirm.
    ordered_skeleton = np.zeros_like(skeleton, dtype=int)
    ordered_skeleton[trace[:, 0], trace[:, 1]] = np.arange(len(trace)) + 20

    if plotting:
        fig, ax = plt.subplots(1, 3, figsize=(20, 10))
        ax[0].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax[0].set_title("grain image")
        ax[1].imshow(ordered_skeleton, cmap="viridis")
        ax[1].set_title("ordered skeleton")
        ax[2].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax[2].imshow(ordered_skeleton, cmap="viridis", alpha=0.2)
        ax[2].plot(trace[:, 1], trace[:, 0], "r")
        plt.show()

    # Calculate pixel trace length in nm. Skeletons that passed the branch filter have exactly two
    # neighbours per pixel, i.e. they are closed loops, so include the segment from the last point back
    # to the first.
    closed_trace = np.vstack([trace, trace[:1]])
    pixel_trace_length = np.sum(np.linalg.norm(np.diff(closed_trace, axis=0), axis=1)) * p_to_nm

    trace_grain_dict[index] = {
        "image": grain_image,
        "mask": grain_mask,
        "skeleton": ordered_skeleton,
        "trace": trace,
        "p_to_nm": p_to_nm,
        "pixel_trace_length": pixel_trace_length,
    }

if plot_results:
    plot_images(
        [trace_grain_dict[i]["image"] for i in trace_grain_dict],
        [trace_grain_dict[i]["mask"] for i in trace_grain_dict],
        [trace_grain_dict[i]["p_to_nm"] for i in trace_grain_dict],
        [trace_grain_dict[i]["skeleton"] for i in trace_grain_dict],
        [trace_grain_dict[i]["trace"] for i in trace_grain_dict],
    )

# Distribution (KDE) of pixel-trace lengths, in nm, across every traced grain
trace_lengths = [grain["pixel_trace_length"] for grain in trace_grain_dict.values()]
fig, ax = plt.subplots()
sns.kdeplot(trace_lengths, ax=ax)
ax.set_xlim(0, 60)
ax.set_xlabel("Trace length (nm)")
ax.set_title(f"Length of pixel traces in nm for {len(trace_lengths)} grains")
plt.show()

In [None]:
# Close-up of grain 206: its pixel trace over the image, then the first and second differences of the
# trace coordinates (column 1 = x, column 0 = y) between consecutive points.

grain_index = 206
grain_data = trace_grain_dict[grain_index]
(grain_image, grain_mask, skeleton, trace, p_to_nm, pixel_trace_length) = (
    grain_data[key] for key in ("image", "mask", "skeleton", "trace", "p_to_nm", "pixel_trace_length")
)

# Trace overlay at the default figure size, then enlarged
plt.imshow(grain_image, cmap=cmap, vmin=VMIN, vmax=VMAX)
plt.plot(trace[:, 1], trace[:, 0], "r")
plt.show()

fig, ax = plt.subplots(1, 1, figsize=(20, 20))
ax.imshow(grain_image, cmap=cmap, vmin=VMIN, vmax=VMAX)
ax.plot(trace[:, 1], trace[:, 0], "r")
plt.show()

# Step between consecutive trace points, and the change in that step
derivative = np.diff(trace, axis=0)
second_derivative = np.diff(derivative, axis=0)

for values, titles, title_kwargs, hide_ticks in (
    (derivative, ("x derivative", "y derivative"), {}, False),
    (second_derivative, ("x 2nd derivative", "y 2nd derivative"), {"fontsize": 18}, True),
):
    fig, ax = plt.subplots(2, 1, figsize=(20, 10))
    # Top panel: x (column 1); bottom panel: y (column 0)
    for a, column, title in zip(ax, (1, 0), titles):
        a.plot(values[:, column])
        a.set_title(title, **title_kwargs)
        if hide_ticks:
            a.set_xticks([])
            a.set_yticks([])
    plt.show()

In [None]:
def simple_curvature(xs: np.ndarray, ys: np.ndarray):
    """Calculate the signed curvature at each point of a closed curve.

    Uses the parametric curvature k = (x' y'' - x'' y') / (x'^2 + y'^2)^(3/2), with derivatives taken by
    finite differences (``np.gradient``). The coordinates are tiled three times back to back so the
    differences wrap around the ends, i.e. the points are treated as a closed loop; only the values for
    the middle copy are returned.

    Parameters
    ----------
    xs: np.ndarray
        One dimensional numpy array of x-coordinates of the points
    ys: np.ndarray
        One dimensional numpy array of y-coordinates of the points

    Returns
    -------
    np.ndarray
        One-dimensional numpy array of curvatures, one per input point.
    """
    n_points = xs.shape[0]
    # Three copies: the middle copy sees its true neighbours across the loop's start/end.
    x_tiled = np.tile(np.ravel(xs), 3)
    y_tiled = np.tile(np.ravel(ys), 3)

    dx, dy = np.gradient(x_tiled), np.gradient(y_tiled)
    d2x, d2y = np.gradient(dx), np.gradient(dy)

    curvature = (dx * d2y - d2x * dy) / (dx * dx + dy * dy) ** 1.5
    return curvature[n_points : len(curvature) - n_points]

In [None]:
# Curvature of the raw pixel trace, plotted against position along the trace.
# NOTE(review): relies on `trace` and `pixel_trace_length` left by the grain-206 cell above (hidden state).
curvatures = simple_curvature(xs=trace[:, 1], ys=trace[:, 0])
curvatures_xs = np.linspace(start=0, stop=pixel_trace_length, num=len(curvatures))

fig, ax = plt.subplots()
ax.plot(curvatures_xs, curvatures)
ax.set_xlabel("Position along trace (nm)")
ax.set_ylabel("Curvature")
ax.set_title("Curvature of pixel trace")

In [None]:
def is_in_polygon(polygon: np.ndarray, point: np.ndarray) -> bool:
    """Check if a point is in a polygon using the ray casting algorithm.

    A ray is cast from ``point`` towards +x and every polygon edge it crosses toggles the inside/outside
    state. An edge only counts when the point's y lies in the half-open interval
    ``(min(y_a, y_b), max(y_a, y_b)]``, which skips horizontal edges and avoids double counting at
    shared vertices.

    Parameters
    ----------
    polygon: np.ndarray
        The polygon to check if the point is in, as a sequence of (x, y) vertices.
    point: np.ndarray
        The (x, y) point to check if it is in the polygon.

    Returns
    -------
    bool
        True if the point is in the polygon, False otherwise.
    """
    px, py = point
    inside = False
    # Pair every vertex with the next one, wrapping the last vertex back to the first.
    next_vertices = [*polygon[1:], polygon[0]]
    for (x_a, y_a), (x_b, y_b) in zip(polygon, next_vertices):
        if not min(y_a, y_b) < py <= max(y_a, y_b):
            continue
        if not px <= max(x_a, x_b):
            continue
        # y_a != y_b is guaranteed here, so the division below is safe.
        if x_a == x_b or px <= (py - y_a) * (x_b - x_a) / (y_b - y_a) + x_a:
            inside = not inside
    return inside


def fill_polygon(array: np.ndarray, polygon: np.ndarray, fill_value: float):
    """Fills a polygon within an array with a fill value.

    Polygon vertices are (x, y) pairs, i.e. (column, row) of ``array``. Every integer pixel in the
    polygon's bounding box that ``is_in_polygon`` reports as inside is set to ``fill_value``. The
    bounding box is clipped to the array, so parts of the polygon outside the array are ignored rather
    than wrapping round via negative indices or raising IndexError. ``array`` is modified in place and
    also returned.

    Parameters
    ----------
    array: np.ndarray
        The array to fill the polygon within.
    polygon: np.ndarray
        The polygon to fill within the array, as a sequence of (x, y) vertices. Non-integer
        coordinates are allowed.
    fill_value: float
        The value to fill the polygon with.

    Returns
    -------
    np.ndarray
        The array with the polygon filled (the same object that was passed in).
    """
    polygon = np.asarray(polygon)
    if polygon.size == 0:
        return array

    height, width = array.shape[:2]
    # Integer pixel bounds of the polygon; floor/ceil allow non-integer vertex coordinates.
    min_x, min_y = np.floor(polygon.min(axis=0)).astype(int)
    max_x, max_y = np.ceil(polygon.max(axis=0)).astype(int)
    # Clip to the array so negative coordinates cannot wrap round to the opposite edge.
    min_x, min_y = max(int(min_x), 0), max(int(min_y), 0)
    max_x, max_y = min(int(max_x), width - 1), min(int(max_y), height - 1)

    for y in range(min_y, max_y + 1):
        for x in range(min_x, max_x + 1):
            if is_in_polygon(polygon, np.array([x, y])):
                array[y, x] = fill_value
    return array


# # Test the polygon filling
# polygon = np.array([[0, 10], [3, 25], [28, 15], [10, 0]])
# array = np.zeros((30, 30))
# filled_array = fill_polygon(array, polygon, 1)
# plt.imshow(filled_array)
# plt.plot(np.append(polygon[:, 0], polygon[0, 0]), np.append(polygon[:, 1], polygon[0, 1]), "r")
# plt.show()

In [None]:
# Get height traces from the skeletons

plot_results = False


def plot_images(
    images: list, masks: list, px_to_nms: list, traces: list, height_traces: list, width=5, cmap=cmap, vmin=-8, vmax=8
):
    """Plot a grid of grains, three panels per grain: image, image with trace, and the height profile.

    Parameters
    ----------
    images : list
        Grain height images.
    masks : list
        Grain masks. Not drawn; only paired with the other lists (``zip`` stops at the shortest list).
    px_to_nms : list
        Pixel-to-nanometre factor per grain, shown (2 d.p.) in the first panel's title.
    traces : list
        (N, 2) arrays of (row, col) trace coordinates per grain, drawn in red on the second panel.
    height_traces : list
        Heights sampled along each trace, drawn as a line plot in the third panel.
    width : int
        Grains per row; the figure has ``3 * width`` columns.
    cmap, vmin, vmax
        Colour map and colour limits for the image panels.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots cannot build a grid with zero rows.
        return
    rows = int(np.ceil(num_images / width))
    # squeeze=False keeps `ax` 2D even when every grain fits on a single row.
    fig, ax = plt.subplots(rows, width * 3, figsize=(30, 30), squeeze=False)
    for i, (image, _mask, trace, height_trace) in enumerate(zip(images, masks, traces, height_traces)):
        row, col = i // width, (i % width) * 3
        ax[row, col].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, col].axis("off")
        ax[row, col].set_title(f"p_to_nm: {px_to_nms[i]:.2f}")
        ax[row, col + 1].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, col + 1].plot(trace[:, 1], trace[:, 0], "r")
        ax[row, col + 2].plot(height_trace)
    fig.tight_layout()
    plt.show()


plotting = False
height_trace_grain_dict = {}


for index, grain_data in trace_grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    p_to_nm = grain_data["p_to_nm"]
    trace = grain_data["trace"]

    # Sample the image height at every (row, col) point of the ordered trace
    trace_rows, trace_cols = trace[:, 0], trace[:, 1]
    height_trace = grain_image[trace_rows, trace_cols]

    if plotting:
        fig, ax = plt.subplots(1, 2, figsize=(40, 10))
        ax[0].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax[0].set_title("grain image")
        trace_colours = np.linspace(0, 1, len(trace))
        height_colours = np.linspace(0, 1, len(height_trace))
        positions = np.arange(len(height_trace))
        # Draw one segment at a time so its viridis colour encodes position along the trace; the slice
        # end is seg + 2 because Python slicing excludes the stop index.
        for seg in range(len(height_trace) - 1):
            ax[0].plot(
                trace[seg : seg + 2, 1], trace[seg : seg + 2, 0], color=plt.cm.viridis(trace_colours[seg])
            )
            ax[1].plot(
                positions[seg : seg + 2], height_trace[seg : seg + 2], color=plt.cm.viridis(height_colours[seg])
            )
        ax[1].set_ylim(1, 3.5)
        plt.show()

    height_trace_grain_dict[index] = dict(
        image=grain_image,
        mask=grain_mask,
        trace=trace,
        height_trace=height_trace,
        p_to_nm=p_to_nm,
    )

if plot_results:
    grains = list(height_trace_grain_dict.values())
    plot_images(
        [grain["image"] for grain in grains],
        [grain["mask"] for grain in grains],
        [grain["p_to_nm"] for grain in grains],
        [grain["trace"] for grain in grains],
        [grain["height_trace"] for grain in grains],
    )

In [None]:
# For each trace, pool sets of n pixels

plot_results = False


def plot_images(
    images: list, original_traces: list, pooled_traces: list, px_to_nms: list, width=5, cmap=cmap, vmin=-8, vmax=8
):
    """Plot a grid of grain images with the pixel trace (red) and pooled trace (blue) overlaid.

    Parameters
    ----------
    images : list
        Grain height images.
    original_traces : list
        (N, 2) arrays of (row, col) pixel-trace coordinates per grain.
    pooled_traces : list
        (N, 2) arrays of pooled (row, col) trace coordinates per grain.
    px_to_nms : list
        Pixel-to-nanometre factor per grain; accepted for a uniform call signature but not drawn.
    width : int
        Grains per row.
    cmap, vmin, vmax
        Colour map and colour limits for the images.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots cannot build a grid with zero rows.
        return
    rows = int(np.ceil(num_images / width))
    # squeeze=False keeps `ax` 2D even for a single row (or a single Axes when width == 1).
    fig, ax = plt.subplots(rows, width, figsize=(20, 50), squeeze=False)
    for i, (image, original_trace, pooled_trace) in enumerate(zip(images, original_traces, pooled_traces)):
        panel = ax[i // width, i % width]
        panel.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        panel.plot(original_trace[:, 1], original_trace[:, 0], "r")
        panel.plot(pooled_trace[:, 1], pooled_trace[:, 0], "b")
        panel.axis("off")
    fig.tight_layout()
    plt.show()


plotting = False

pooled_curvature_grain_dict = {}

# Binning size: number of consecutive trace points averaged into each pooled point
n = 6

for index, grain_data in height_trace_grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    p_to_nm = grain_data["p_to_nm"]
    trace = grain_data["trace"]
    height_trace = grain_data["height_trace"]
    # height_trace_grain_dict does not store the trace length, so take this grain's own value from
    # trace_grain_dict rather than the module-level `pixel_trace_length` left over from an earlier cell.
    pixel_trace_length = trace_grain_dict[index]["pixel_trace_length"]

    # Pool the trace points: pooled point i is the mean of trace points i, i + 1, ..., i + n - 1, with
    # indices wrapping past the end back to the start (the trace is a closed loop). The modulo also
    # handles traces shorter than n.
    num_points = len(trace)
    window_indices = (np.arange(num_points)[:, None] + np.arange(n)[None, :]) % num_points
    pooled_trace = trace[window_indices].mean(axis=1)

    if plotting:
        # Plot overlaid on image
        fig, ax = plt.subplots(1, 1, figsize=(20, 10))
        ax.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax.plot(trace[:, 1], trace[:, 0], "r")
        ax.plot(pooled_trace[:, 1], pooled_trace[:, 0], "b")
        ax.set_title("Trace and pooled trace")
        plt.show()

    pooled_curvature_grain_dict[index] = {
        "image": grain_image,
        "mask": grain_mask,
        "trace": trace,
        "pooled_trace": pooled_trace,
        "height_trace": height_trace,
        "pixel_trace_length": pixel_trace_length,
        "p_to_nm": p_to_nm,
    }

if plot_results:
    plot_images(
        [pooled_curvature_grain_dict[i]["image"] for i in pooled_curvature_grain_dict],
        [pooled_curvature_grain_dict[i]["trace"] for i in pooled_curvature_grain_dict],
        [pooled_curvature_grain_dict[i]["pooled_trace"] for i in pooled_curvature_grain_dict],
        [pooled_curvature_grain_dict[i]["p_to_nm"] for i in pooled_curvature_grain_dict],
    )

In [None]:
# Plot grain 206 pooled trace and pixel trace, then the curvature of its pooled trace (raw and smoothed)
from scipy.ndimage import gaussian_filter1d

grain_index = 206
grain_data = pooled_curvature_grain_dict[grain_index]
grain_image = grain_data["image"]
grain_mask = grain_data["mask"]
trace = grain_data["trace"]
pooled_trace = grain_data["pooled_trace"]
p_to_nm = grain_data["p_to_nm"]
height_trace = grain_data["height_trace"]
pixel_trace_length = grain_data["pixel_trace_length"]

# Pixel trace (red) and pooled trace (blue) over the grain image
plt.imshow(grain_image, cmap=cmap, vmin=VMIN, vmax=VMAX)
plt.plot(trace[:, 1], trace[:, 0], "r")
plt.plot(pooled_trace[:, 1], pooled_trace[:, 0], "b")
plt.show()

# Calculate the curvature of the pooled trace
pooled_curvature = simple_curvature(pooled_trace[:, 1], pooled_trace[:, 0])

# Positions along the trace in nm, assuming the pooled points are evenly spaced along its length
pooled_curvature_xs = np.linspace(0, pixel_trace_length, len(pooled_curvature))

plt.plot(pooled_curvature_xs, pooled_curvature)
plt.xlabel("Position along trace (nm)")
plt.ylabel("Curvature")
plt.title("Curvature of pooled pixel trace")
plt.show()

# Close-up of one pooling window: trace points 114-119 (red) and pooled point 114 (blue). With the bin
# size n = 6 used when pooling, pooled point 114 is the mean of exactly those six trace points.
fig, ax = plt.subplots(1, 1, figsize=(20, 20))
# turn axes off
ax.axis("off")
ax.imshow(grain_image, cmap=cmap, vmin=VMIN, vmax=VMAX)
ax.plot(trace[114:120, 1], trace[114:120, 0], "r", marker=".")
ax.plot(pooled_trace[114, 1], pooled_trace[114, 0], "b", marker=".")
plt.show()

# Apply gaussian smoothing to the curvatures. mode="wrap" because the trace is a closed loop: the start
# and end of the curvature profile are neighbours, so the default "reflect" padding would distort both ends.
sigma = 3
smoothed_curvature = gaussian_filter1d(pooled_curvature, sigma=sigma, mode="wrap")

plt.plot(pooled_curvature_xs, smoothed_curvature)
plt.xlabel("Position along trace (nm)")
plt.ylabel("Curvature")
plt.title(f"Smoothed curvature of pooled pixel trace (sigma={sigma})")
plt.show()

In [None]:
# Calculate the spatial derivative of the trace: per-point gradients of the x (column) and y (row)
# coordinates of grain 206, first for the pooled trace and then for the raw pixel trace.

grain_index = 206
grain_data = pooled_curvature_grain_dict[grain_index]
grain_image, grain_mask = grain_data["image"], grain_data["mask"]
trace, pooled_trace = grain_data["trace"], grain_data["pooled_trace"]
p_to_nm, height_trace = grain_data["p_to_nm"], grain_data["height_trace"]
pixel_trace_length = grain_data["pixel_trace_length"]

for coords, title in (
    (pooled_trace, "Spatial derivatives of the pooled trace"),
    (trace, "Spatial derivatives of the standard trace"),
):
    x_derivative = np.gradient(coords[:, 1])
    y_derivative = np.gradient(coords[:, 0])
    plt.plot(x_derivative, label="x derivative")
    plt.plot(y_derivative, label="y derivative")
    plt.legend()
    plt.title(title)
    plt.show()
In [None]:
from skimage.morphology import binary_erosion
from scipy.ndimage import binary_fill_holes


def calculate_edges(grain_mask: np.ndarray):
    """Return the coordinates of the boundary pixels of a grain mask.

    Holes in the mask are filled first, so only the outer boundary is returned. A pixel is on the
    boundary if it belongs to the filled mask but is removed by a single binary erosion.

    Parameters
    ----------
    grain_mask : np.ndarray
        A 2D numpy array image of a grain. Data in the array must be boolean.

    Returns
    -------
    edges : list
        List of ``[row, col]`` coordinates of the edge pixels, in row-major order.
    """
    # Fill any holes so internal hole boundaries are not reported as edges
    filled_grain_mask = binary_fill_holes(grain_mask)

    # Pad by one background pixel so mask pixels on the array border erode correctly, erode by one
    # pixel, then strip the padding again.
    eroded = binary_erosion(np.pad(filled_grain_mask, 1))[1:-1, 1:-1]

    # The edge is the set of mask pixels removed by the erosion.
    edges = np.logical_and(filled_grain_mask, np.logical_not(eroded))
    return np.argwhere(edges).tolist()


def is_clockwise(p_1: tuple, p_2: tuple, p_3: tuple) -> bool:
    """Function to determine if three points make a clockwise or counter-clockwise turn.

    Parameters
    ----------
    p_1: tuple
        First point to be used to calculate turn.
    p_2: tuple
        Second point to be used to calculate turn.
    p_3: tuple
        Third point to be used to calculate turn.

    Returns
    -------
    boolean
        True if the turn p_1 -> p_2 -> p_3 is clockwise or the points are collinear, False if it is
        counter-clockwise.
    """
    # The z-component of (p_2 - p_1) x (p_3 - p_1) equals the determinant of
    # [[x1, y1, 1], [x2, y2, 1], [x3, y3, 1]], which is > 0 for a counter-clockwise turn. Computing it
    # directly is exact for integer coordinates, whereas np.linalg.det goes through a floating-point LU
    # factorisation and can report a tiny non-zero value for collinear points.
    cross = (p_2[0] - p_1[0]) * (p_3[1] - p_1[1]) - (p_2[1] - p_1[1]) * (p_3[0] - p_1[0])
    return not cross > 0


def get_triangle_height(base_point_1: np.array, base_point_2: np.array, top_point: np.array) -> float:
    """Returns the height of a triangle defined by the input point vectors.

    Parameters
    ----------
    base_point_1: np.ndarray
        a base point of the triangle, eg: [5, 3].

    base_point_2: np.ndarray
        a base point of the triangle, eg: [8, 3].

    top_point: np.ndarray
        the top point of the triangle, defining the height from the line between the two base points, eg: [6,10].

    Returns
    -------
    Float
        The height of the triangle - ie the shortest distance between the top point and the line between the two
        base points. If the two base points coincide the base is degenerate, and the distance from
        ``top_point`` to that point is returned instead of dividing by zero.
    """
    base_point_1 = np.asarray(base_point_1, dtype=float)
    base_point_2 = np.asarray(base_point_2, dtype=float)
    top_point = np.asarray(top_point, dtype=float)

    # Height of triangle = A/b = ||AB X AC|| / ||AB||
    a_b = base_point_1 - base_point_2
    a_c = base_point_1 - top_point
    base_length = np.linalg.norm(a_b)
    if base_length == 0:
        # Degenerate base (e.g. duplicate hull points): a NaN here would poison later min() comparisons.
        return float(np.linalg.norm(a_c))
    if a_b.shape[-1] == 2:
        # Scalar 2D cross product, written out because np.cross on 2-element vectors is deprecated in NumPy 2.
        cross_norm = abs(a_b[0] * a_c[1] - a_b[1] * a_c[0])
    else:
        cross_norm = np.linalg.norm(np.cross(a_b, a_c))
    return float(cross_norm / base_length)


def get_max_min_ferets(edge_points: list):
    """Returns the minimum and maximum feret diameters for a grain.
    These are defined as the smallest and greatest distances between
    a pair of callipers that are rotating around a 2d object, maintaining
    contact at all times.

    Parameters
    ----------
    edge_points: list
        a list of the vector positions of the pixels comprising the edge of the
        grain. Eg: [[0, 0], [1, 0], [2, 1]]. An (N, 2) numpy array is also accepted.
        The input is not modified.

    Returns
    -------
    min_feret: float
        the minimum feret diameter of the grain
    max_feret: float
        the maximum feret diameter of the grain
    min_feret_triangle: list
        the three hull vertices (two on one calliper, one on the other) whose triangle height is
        the minimum feret diameter; None if no minimum was found

    Notes
    -----
    The method starts out by calculating the upper and lower convex hulls using
    an algorithm based on the Graham Scan Algorithm [1]. Using these upper and
    lower hulls, the callipers are simulated as rotating clockwise around the grain.
    We determine the order in which vertices are encountered by comparing the
    gradients of the slopes between vertices. An array of pairs of points that
    are in contact with either calliper at a given time is created in order to
    be able to calculate the maximum feret diameter. The minimum diameter is a
    little tricky, since it won't simply be the shortest distance between two
    contact points, but it will occur somewhere during the rotation around a
    pair of contact points. It turns out that the point will always be such
    that two points are in contact with one calliper while the other calliper
    is in contact with another point. We can use this fact to be sure of finding
    the smallest feret diameter, simply by testing each triangle of 3 contact points
    as we iterate, finding the height of the triangle that is formed between the
    three aforementioned points, as this will be the perpendicular distance between
    the callipers.

    References
    ----------
    [1] Graham, R.L. (1972).
        "An Efficient Algorithm for Determining the Convex Hull of a Finite Planar Set".
        Information Processing Letters. 1 (4): 132-133.
        doi:10.1016/0020-0190(72)90045-2.
    """

    min_feret_triangle = None

    # Sort the vectors by their first coordinate and then by their second. Sorting a converted copy leaves the
    # caller's list untouched, and also works for (N, 2) arrays (whose in-place .sort() would sort the two
    # coordinates within each row instead of sorting the rows).
    edge_points = np.array(sorted(np.asarray(edge_points).tolist()))

    # Construct upper and lower hulls for the edge points. Sadly we can't just use the standard hull
    # that graham_scan() returns, since we need to separate the upper and lower hulls.
    upper_hull = []
    lower_hull = []
    for point in edge_points:
        while len(lower_hull) > 1 and is_clockwise(lower_hull[-2], lower_hull[-1], point):
            lower_hull.pop()
        lower_hull.append(point)
        while len(upper_hull) > 1 and not is_clockwise(upper_hull[-2], upper_hull[-1], point):
            upper_hull.pop()
        upper_hull.append(point)

    upper_hull = np.array(upper_hull)
    lower_hull = np.array(lower_hull)

    # Create list of contact vertices for calipers on the antipodal hulls
    contact_points = []
    upper_index = 0
    lower_index = len(lower_hull) - 1
    min_feret = None
    while upper_index < len(upper_hull) - 1 or lower_index > 0:
        contact_points.append([lower_hull[lower_index, :], upper_hull[upper_index, :]])
        # If we have reached the end of the upper hull, continue iterating over the lower hull
        if upper_index == len(upper_hull) - 1:
            lower_index -= 1
            small_feret = get_triangle_height(
                np.array(lower_hull[lower_index + 1, :]),
                np.array(lower_hull[lower_index, :]),
                np.array(upper_hull[upper_index, :]),
            )
            if min_feret is None or small_feret < min_feret:
                min_feret = small_feret
                min_feret_triangle = [
                    lower_hull[lower_index + 1, :],
                    lower_hull[lower_index, :],
                    upper_hull[upper_index, :],
                ]
        # If we have reached the end of the lower hull, continue iterating over the upper hull
        elif lower_index == 0:
            upper_index += 1
            small_feret = get_triangle_height(
                np.array(upper_hull[upper_index - 1, :]),
                np.array(upper_hull[upper_index, :]),
                np.array(lower_hull[lower_index, :]),
            )
            if min_feret is None or small_feret < min_feret:
                min_feret = small_feret
                # The base lies on the upper hull here, so record that triangle (not a lower-hull one).
                min_feret_triangle = [
                    upper_hull[upper_index - 1, :],
                    upper_hull[upper_index, :],
                    lower_hull[lower_index, :],
                ]
        # Check if the gradient of the last point and the proposed next point in the upper hull is greater than the
        # gradient of the two corresponding points in the lower hull, if so, this means that the next point in the
        # upper hull will be encountered before the next point in the lower hull and vice versa.
        # Note that the calculation here for gradients is the simple
        # delta upper_y / delta upper_x > delta lower_y / delta lower_x, however the denominators are multiplied
        # through so that there are no instances of division by zero. The inequality still holds and provides what
        # is needed.
        elif (upper_hull[upper_index + 1, 1] - upper_hull[upper_index, 1]) * (
            lower_hull[lower_index, 0] - lower_hull[lower_index - 1, 0]
        ) > (lower_hull[lower_index, 1] - lower_hull[lower_index - 1, 1]) * (
            upper_hull[upper_index + 1, 0] - upper_hull[upper_index, 0]
        ):
            # If the upper hull is encountered first, increment the iteration index for the upper hull.
            # Also consider the triangle that is made as the two upper hull vertices are colinear with the caliper.
            upper_index += 1
            small_feret = get_triangle_height(
                np.array(upper_hull[upper_index - 1, :]),
                np.array(upper_hull[upper_index, :]),
                np.array(lower_hull[lower_index, :]),
            )
            if min_feret is None or small_feret < min_feret:
                min_feret = small_feret
                min_feret_triangle = [
                    upper_hull[upper_index - 1, :],
                    upper_hull[upper_index, :],
                    lower_hull[lower_index, :],
                ]
        else:
            # The next point in the lower hull will be encountered first, so increment the lower hull iteration index.
            lower_index -= 1
            small_feret = get_triangle_height(
                np.array(lower_hull[lower_index + 1, :]),
                np.array(lower_hull[lower_index, :]),
                np.array(upper_hull[upper_index, :]),
            )

            if min_feret is None or small_feret < min_feret:
                min_feret = small_feret
                min_feret_triangle = [
                    lower_hull[lower_index + 1, :],
                    lower_hull[lower_index, :],
                    upper_hull[upper_index, :],
                ]

    contact_points = np.array(contact_points)

    # The maximum feret diameter is the largest distance between any pair of antipodal contact points
    max_feret = None
    for point_pair in contact_points:
        dist = np.sqrt((point_pair[0, 0] - point_pair[1, 0]) ** 2 + (point_pair[0, 1] - point_pair[1, 1]) ** 2)
        if max_feret is None or max_feret < dist:
            max_feret = dist

    return min_feret, max_feret, min_feret_triangle


def shoelace(points: np.ndarray):
    """Return the unsigned area of a polygon using the shoelace formula.

    Parameters
    ----------
    points: np.ndarray
        Polygon vertices in traversal order, one vertex per row (columns 0 and 1 are used). The polygon is
        closed automatically, so the first vertex should not be repeated at the end.

    Returns
    -------
    float
        The absolute enclosed area, in squared input units.
    """
    # Append the first vertex so the polygon is explicitly closed
    closed = np.vstack([points, points[0]])
    xs = closed[:, 0]
    ys = closed[:, 1]
    # Pair every vertex with its predecessor (cyclically) and take half the absolute cross-product sum
    ys_prev = np.roll(ys, 1)
    xs_prev = np.roll(xs, 1)
    twice_signed_area = np.dot(xs, ys_prev) - np.dot(ys, xs_prev)
    return 0.5 * np.abs(twice_signed_area)

In [None]:
# Measure how "open" each traced molecule is with two shape descriptors:
#   - feret ratio: min feret / max feret of the pooled trace
#   - area / perimeter of the closed pooled trace (stored under the key "perimeter_area_ratio")
openness_grain_dict = {}

for index, grain_data in pooled_curvature_grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    p_to_nm = grain_data["p_to_nm"]
    pooled_trace = grain_data["pooled_trace"]

    # Calculate the minimum and maximum feret diameters
    min_feret, max_feret, min_feret_triangle = get_max_min_ferets(edge_points=np.copy(pooled_trace).tolist())

    # An open molecule will have a feret ratio close to 1, a squished molecule will have a lower feret ratio
    feret_ratio = min_feret / max_feret

    if plotting:
        # Show the trace and the vertices of the minimum-feret triangle over both the image and the mask
        for background in (grain_image, grain_mask):
            plt.imshow(background, cmap=cmap, vmin=-8, vmax=8)
            plt.scatter(pooled_trace[:, 1], pooled_trace[:, 0], c="k", s=20)
            if min_feret_triangle is not None:
                for vertex in min_feret_triangle:
                    plt.scatter(vertex[1], vertex[0], c="red", s=60)
            plt.title(f"grain index: {index}, feret ratio: {feret_ratio:.2f}")
            plt.show()

    # Area of the closed trace (nm^2)
    area = shoelace(pooled_trace) * p_to_nm**2
    # Perimeter of the closed trace (nm). np.roll pairs each point with the next one, including the closing
    # segment from the last point back to the first (np.diff alone would miss that segment).
    closing_steps = np.roll(pooled_trace, -1, axis=0) - pooled_trace
    perimeter = np.sum(np.linalg.norm(closing_steps, axis=1)) * p_to_nm

    # NOTE: despite the name, this is area / perimeter; for a given perimeter it is larger for rounder loops
    perimeter_area_ratio = area / perimeter

    # Store the results. This is the same dict object as pooled_curvature_grain_dict[index], so it is
    # updated in place rather than copied.
    openness_grain_dict[index] = grain_data
    openness_grain_dict[index]["min_feret"] = min_feret
    openness_grain_dict[index]["max_feret"] = max_feret
    openness_grain_dict[index]["feret_ratio"] = feret_ratio
    openness_grain_dict[index]["perimeter_area_ratio"] = perimeter_area_ratio

# Plot openness as a kde
openness_values = [openness_grain_dict[i]["feret_ratio"] for i in openness_grain_dict]
sns.kdeplot(openness_values)
plt.xlabel("Feret ratio (min / max)")
plt.title(f"Feret ratio (min / max) {SAMPLE_TYPE}")
plt.show()

# Plot the area / perimeter ratio
perimeter_area_ratio_values = [openness_grain_dict[i]["perimeter_area_ratio"] for i in openness_grain_dict]
sns.kdeplot(perimeter_area_ratio_values)
plt.xlabel("Area / perimeter (nm)")
plt.title(f"Area / perimeter {SAMPLE_TYPE}")
plt.show()

In [None]:
# Using the pooled points, calculate a proxy for curvature: the change in angle per nm of trace length.


def angle_diff_signed(v1: np.ndarray, v2: np.ndarray) -> float:
    """Calculate the signed angle between two 2D vectors.

    The magnitude is the unsigned angle between ``v1`` and ``v2`` (in ``[0, pi]``). The sign is negative when
    the 2D cross product ``v1 x v2`` is positive, i.e. when ``v2`` is counter-clockwise of ``v1`` with the
    components read as (x, y). It is positive (or zero) otherwise.

    Parameters
    ----------
    v1: np.ndarray
        The first vector (two components).
    v2: np.ndarray
        The second vector (two components).

    Returns
    -------
    float
        The signed angle between the two vectors in radians, in ``[-pi, pi]``.
    """
    # Unsigned angle; clip guards against rounding pushing the cosine just outside [-1, 1]
    cos_angle = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
    angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))

    # z-component of the 2D cross product, written out explicitly because passing 2-element vectors to
    # np.cross is deprecated since NumPy 2.0
    cross = v1[0] * v2[1] - v1[1] * v2[0]

    # Counter-clockwise turns (positive cross product) are reported as negative angles
    if cross > 0:
        angle = -angle

    return angle


# Sanity-check the sign convention (expected output: -90, -45, 45, 180, -135 degrees)
for v1, v2 in [
    (np.array([1, 0]), np.array([0, 1])),
    (np.array([1, 0]), np.array([1, 1])),
    (np.array([1, 0]), np.array([1, -1])),
    (np.array([1, 0]), np.array([-1, 0])),
    (np.array([1, 0]), np.array([-1, 1])),
]:
    print(np.degrees(angle_diff_signed(v1, v2)))


def angle_per_nm(trace: np.ndarray, p_to_nm: float, plot: bool = False) -> tuple[np.ndarray, np.ndarray]:
    """Calculate the signed turning angle, and the turning angle per nm, at each point of a closed trace.

    The trace is treated as a closed loop: the point before ``trace[0]`` is ``trace[-1]`` and the point after
    ``trace[-1]`` is ``trace[0]``. The first point must therefore not be repeated at the end.

    Parameters
    ----------
    trace: np.ndarray
        The ordered trace points, one point per row. When plotting, column 0 is drawn on the y axis and column 1
        on the x axis.
    p_to_nm: float
        The pixel to nm scaling factor.
    plot: bool
        If True, draw every point with its incoming (red) and outgoing (blue) direction and annotate it with
        its turning angle in degrees.

    Returns
    -------
    angles_per_nm: np.ndarray
        For each point, the turning angle divided by the length (nm) of the segment arriving at that point.
    angle_diffs: np.ndarray
        For each point, the signed turning angle in radians, as returned by ``angle_diff_signed``.

    Raises
    ------
    ValueError
        If the first and last points are the same.
    """
    if np.all(trace[0] == trace[-1]):
        raise ValueError("The first and last points are the same")

    if plot:
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))

    num_points = len(trace)
    angles_per_nm = np.zeros(num_points)
    angle_diffs = np.zeros(num_points)

    for index, point in enumerate(trace):
        # Neighbours on the closed loop; for index 0, trace[index - 1] is trace[-1] (the last point)
        v_prev = point - trace[index - 1]
        v_next = trace[(index + 1) % num_points] - point

        angle = angle_diff_signed(v_prev, v_next)

        if plot:
            ax.scatter(point[1], point[0], c="purple", s=20)
            # Direction arrows rescaled to a fixed display length of 0.1
            arrow_prev = v_prev / np.linalg.norm(v_prev) * 0.1
            arrow_next = v_next / np.linalg.norm(v_next) * 0.1
            ax.arrow(
                point[1], point[0], arrow_prev[1], arrow_prev[0], head_width=0.01, head_length=0.2, fc="r", ec="r"
            )
            ax.arrow(
                point[1], point[0], arrow_next[1], arrow_next[0], head_width=0.01, head_length=0.2, fc="b", ec="b"
            )
            ax.text(point[1], point[0], f"{np.degrees(angle):.2f}", fontsize=12, color="black")

        # Normalise by the length of the incoming segment only
        distance = np.linalg.norm(v_prev) * p_to_nm

        angles_per_nm[index] = angle / distance
        angle_diffs[index] = angle

    if plot:
        plt.plot(trace[:, 1], trace[:, 0], "k")
        plt.show()

    return angles_per_nm, angle_diffs


# Test the angle per nm function on simple closed shapes
px_to_nm = 1

# Create an ellipse of points
n = 100
theta = np.linspace(0, 2 * np.pi, n)
a = 3
b = 2
x = a * np.cos(theta)
y = b * np.sin(theta)
ellipse = np.array([x, y]).T
# Remove the last point as it is the same as the first
ellipse = ellipse[:-1]
# Reverse the order of the points
ellipse = ellipse[::-1]

plt.plot(ellipse[:, 0], ellipse[:, 1], ".")
plt.show()

angles_per_nm, angle_diffs = angle_per_nm(ellipse, px_to_nm, plot=True)

plt.plot(angles_per_nm, "-o")
plt.show()


# Generate a set of points on a circle
n = 10
theta = np.linspace(0, 2 * np.pi, n)
# remove the last point as it is the same as the first
theta = theta[:-1]
r = 1
x = r * np.cos(theta)
y = r * np.sin(theta)
circle = np.array([x, y]).T

# Test the angle per nm function on a circle
px_to_nm = 1
angles_per_nm, angle_diffs = angle_per_nm(circle, px_to_nm, plot=True)

# The points form a regular polygon traversed anticlockwise in (x, y). Every vertex therefore turns by the
# same exterior angle 2*pi/N over the same edge length 2*r*sin(pi/N), and angle_diff_signed reports these
# anticlockwise turns as negative angles.
num_circle_points = len(circle)
expected_angle_per_nm = -(2 * np.pi / num_circle_points) / (2 * r * np.sin(np.pi / num_circle_points) * px_to_nm)
assert np.allclose(angles_per_nm, expected_angle_per_nm)

plt.plot(angles_per_nm)
plt.title(f"angle per unit length for a circle of {num_circle_points} points")
plt.show()

print(angles_per_nm)

trace = np.array(
    [
        [2, 2],
        [10, 4],
        [8, 9],
        [6, 7],
        [5, 5],
    ]
)


px_to_nm = 1

angles_per_nm, angle_diffs = angle_per_nm(trace, px_to_nm, plot=True)

angle_diffs_deg = np.degrees(angle_diffs)

plt.plot(trace[:, 1], trace[:, 0], "k")
plt.show()
plt.plot(angle_diffs_deg, "-o")
plt.show()

In [None]:
def flip_if_anticlockwise(trace: np.ndarray):
    """Return a closed trace ordered clockwise, reversing it if it runs anticlockwise.

    The orientation is taken from the sign of the shoelace sum of cross products ``p_i x p_{i+1}`` over every
    edge of the closed loop, including the closing edge from the last point back to the first. Without that
    edge the sum depends on where the origin is and can give the wrong orientation.

    A positive sum means the points run anticlockwise, both when read as (x, y) with y pointing up and when
    read as (row, col) and displayed with rows increasing downwards. In that case the trace is reversed;
    otherwise (negative or zero sum) it is returned unchanged.

    Parameters
    ----------
    trace: np.ndarray
        Ordered points of a closed trace, one point per row. The first point need not be repeated at the end;
        if it is, the extra edge contributes zero.

    Returns
    -------
    np.ndarray
        The trace, reversed along axis 0 if it was anticlockwise.
    """
    points = np.asarray(trace)
    if len(points) == 0:
        return trace

    xs = points[:, 0]
    ys = points[:, 1]
    # Twice the signed area: sum of x_i * y_{i+1} - y_i * x_{i+1} over all edges, wrapping around the end
    cross_sum = np.sum(xs * np.roll(ys, -1) - ys * np.roll(xs, -1))

    if cross_sum > 0:
        # Anticlockwise: reverse the point order
        trace = np.flip(trace, axis=0)

    return trace

In [None]:
def defect_stats(array: np.ndarray, threshold: float):
    """Find contiguous runs of samples above ``threshold`` ("defects") in a cyclic 1D array.

    The array is treated as circular. A run that starts at index 0 and a separate run that reaches the last
    index are merged into one defect that wraps around the end of the array.

    Parameters
    ----------
    array: np.ndarray
        1D array of values, e.g. the smoothed angle change per nm along a closed trace.
    threshold: float
        Samples strictly greater than this value belong to a defect.

    Returns
    -------
    dict
        ``defect_number``
            Number of defects found.
        ``defect_largest_region``
            The region with the largest area (first one on ties), or None if there are no defects.
        ``defect_regions``
            List of region dicts, each with these keys:

            - ``start``: first index of the region.
            - ``end``: index of the first sample after the region, or the last index of the array for a
              region that runs to the end.
            - ``area``: sum of ``value - threshold`` over the region's samples.
            - ``highest_point`` / ``highest_point_index``: peak value and its index.
            - ``defect_threshold``: the threshold used.
            - ``midpoint``: rounded centre index.

            A region that wraps around the end of the array has ``start > end``.
    """
    regions = []
    # Region currently being grown, or None while outside a region
    current = None
    for index, value in enumerate(array):
        if value > threshold:
            excess = value - threshold
            if current is None:
                # Start a new region; its first sample counts towards the area too
                current = {
                    "start": index,
                    "end": None,
                    "area": excess,
                    "highest_point_index": index,
                    "highest_point": value,
                    "defect_threshold": threshold,
                }
            else:
                current["area"] += excess
                if value > current["highest_point"]:
                    current["highest_point"] = value
                    current["highest_point_index"] = index
        elif current is not None:
            # The first sample back at or below the threshold closes the region
            current["end"] = index
            regions.append(current)
            current = None

    # A region still open after the loop runs to the end of the array; its end is recorded as the last index
    reaches_end = current is not None
    if reaches_end:
        current["end"] = len(array) - 1
        regions.append(current)

    # The array is cyclic: merge a region that runs off the end with a *different* region starting at index 0.
    # (A single region covering the whole array must not be merged with itself.)
    if reaches_end and len(regions) > 1 and regions[0]["start"] == 0:
        last_region = regions.pop()
        first_region = regions[0]
        first_region["start"] = last_region["start"]
        first_region["area"] += last_region["area"]
        if last_region["highest_point"] > first_region["highest_point"]:
            first_region["highest_point"] = last_region["highest_point"]
            first_region["highest_point_index"] = last_region["highest_point_index"]

    number_of_defects = len(regions)

    # Largest defect by area (first one wins ties); None when there are no defects
    largest_region = max(regions, key=lambda region: region["area"], default=None)

    # Midpoint of each region. Only merged (wrapping) regions have start > end.
    array_length = len(array)
    for region in regions:
        if region["start"] <= region["end"]:
            region["midpoint"] = int(np.round((region["start"] + region["end"]) / 2))
        else:
            # Shift the start to a negative index so the two ends can be averaged, then wrap back
            shifted_start = region["start"] - array_length
            region["midpoint"] = int(np.round((shifted_start + region["end"]) / 2)) % array_length

    return {
        "defect_number": number_of_defects,
        "defect_largest_region": largest_region,
        "defect_regions": regions,
    }


# Example 1D profiles used to exercise defect_stats (presumably smoothed angle-change-per-nm values copied
# from real traces — confirm provenance).
# testarr1 starts above the threshold, so it has a region touching index 0. It ends below the threshold,
# so that region must not be merged with anything.
testarr1 = np.array(
    [
        0.4002,
        0.4014,
        0.4005,
        0.3933,
        0.3781,
        0.3564,
        0.3328,
        0.3114,
        0.2939,
        0.2791,
        0.2641,
        0.2461,
        0.2239,
        0.1980,
        0.1692,
        0.1383,
        0.1057,
        0.0717,
        0.0374,
        0.0052,
        -0.0221,
        -0.0417,
        -0.0518,
        -0.0514,
        -0.0408,
        -0.0202,
        0.0093,
        0.0460,
        0.0866,
        0.1267,
        0.1617,
        0.1886,
        0.2069,
        0.2191,
        0.2295,
        0.2429,
        0.2617,
        0.2853,
        0.3099,
        0.3295,
        0.3387,
        0.3341,
        0.3153,
        0.2847,
        0.2462,
        0.2042,
        0.1627,
        0.1248,
        0.0924,
        0.0666,
        0.0479,
        0.0362,
        0.0310,
        0.0310,
        0.0343,
        0.0384,
        0.0406,
        0.0390,
        0.0330,
        0.0239,
        0.0135,
        0.0034,
        -0.0058,
        -0.0142,
        -0.0220,
        -0.0283,
        -0.0320,
    ]
)

# testarr2 is above the threshold at both ends, so its first and last regions should be merged into a single
# defect that wraps around the end of the array.
testarr2 = np.array(
    [
        0.2799,
        0.2818,
        0.2845,
        0.2852,
        0.2805,
        0.2674,
        0.2448,
        0.2136,
        0.1773,
        0.1394,
        0.1027,
        0.0681,
        0.0355,
        0.0055,
        -0.0203,
        -0.0392,
        -0.0485,
        -0.0475,
        -0.0378,
        -0.0221,
        -0.0037,
        0.0148,
        0.0320,
        0.0478,
        0.0636,
        0.0819,
        0.1056,
        0.1373,
        0.1780,
        0.2270,
        0.2812,
        0.3358,
        0.3848,
        0.4217,
        0.4408,
        0.4378,
        0.4110,
        0.3628,
        0.2999,
        0.2319,
        0.1692,
        0.1192,
        0.0848,
        0.0643,
        0.0529,
        0.0457,
        0.0390,
        0.0309,
        0.0212,
        0.0113,
        0.0028,
        -0.0025,
        -0.0031,
        0.0015,
        0.0115,
        0.0266,
        0.0459,
        0.0688,
        0.0949,
        0.1244,
        0.1571,
        0.1928,
        0.2301,
        0.2668,
        0.3003,
        0.3282,
        0.3482,
        0.3587,
    ]
)

# Defect threshold for the test arrays; close to np.radians(12) ~= 0.209, the rad/nm threshold used in the analysis loop
threshold = 0.2

test_defect_stats_1 = defect_stats(testarr1, threshold)
print(test_defect_stats_1)

test_defect_stats_2 = defect_stats(testarr2, threshold)
print(test_defect_stats_2)


# Plot each test array with its defect regions shaded, the highest point of each region marked in red and
# the region midpoints as dashed blue lines
for testarr, test_defect_stats in zip([testarr1, testarr2], [test_defect_stats_1, test_defect_stats_2]):
    plt.plot(testarr)
    plt.ylim(-0.5, 0.5)
    plt.axhline(threshold, color="k", linestyle="--")
    plt.axhline(0, color="k", linestyle="-")
    for region in test_defect_stats["defect_regions"]:
        # Only regions that wrap around the end of the array have start > end; a single-sample region at the
        # end of the array has start == end and is not wrapped
        if region["start"] <= region["end"]:
            plt.axvspan(region["start"], region["end"], color="red", alpha=0.3)
        else:
            # Shade the two pieces of a wrapping region, using the length of *this* array
            plt.axvspan(region["start"], len(testarr), color="red", alpha=0.3)
            plt.axvspan(0, region["end"], color="red", alpha=0.3)
        # Identify the highest point
        plt.scatter(region["highest_point_index"], testarr[region["highest_point_index"]], c="r")
        # Plot the midpoint
        plt.axvline(region["midpoint"], color="b", linestyle="--")
    plt.show()

In [None]:
def calculate_real_distance_between_points_in_array(
    array: np.ndarray, indexes_to_calculate_distance_between: np.ndarray, p_to_nm: float
):
    """Calculate path lengths along a closed trace between consecutive marked points.

    Starting at ``indexes_to_calculate_distance_between[0]``, the trace is walked forwards point by point,
    wrapping from the last point back to the first. The accumulated path length is recorded (and reset) each
    time a marked index is reached. The walk stops once every marked index has been reached, including the
    starting one, which is reached again after a full lap.

    Parameters
    ----------
    array: np.ndarray
        Ordered points of the closed trace, one point per row.
    indexes_to_calculate_distance_between: np.ndarray
        Indexes into ``array`` of the points to measure between. Must lie in ``[0, len(array))``.
    p_to_nm: float
        Scaling factor applied to each step length (pixel to nm).

    Returns
    -------
    list
        Path lengths between consecutive marked points in walking order. The last entry is the length from
        the final marked point back to the starting one. Empty if no indexes are given.

    Raises
    ------
    ValueError
        If any index is outside ``[0, len(array))``; such an index could never be reached, so the walk would
        never terminate.
    """
    to_find = np.array(indexes_to_calculate_distance_between, copy=True)
    if to_find.size == 0:
        return []

    num_points = len(array)
    if np.any((to_find < 0) | (to_find >= num_points)):
        raise ValueError(f"All indexes must be in [0, {num_points}), got {to_find.tolist()}")

    current_index = to_find[0]
    current_position = array[current_index]
    distances = []
    distance = 0
    while len(to_find) > 0:
        # Step forwards along the closed trace
        previous_position = current_position
        current_index = (current_index + 1) % num_points
        current_position = array[current_index]

        distance += np.linalg.norm(current_position - previous_position) * p_to_nm

        # Reached a marked point: record the distance since the previous marked point and restart the count
        if current_index in to_find:
            to_find = to_find[to_find != current_index]
            distances.append(distance)
            distance = 0

    return distances


# Test the distance calculation
# A small closed loop; the start point [5, 5] is repeated at index 8, so the final wrap-around step adds zero length
array = np.array(
    [
        [5, 5],
        [5, 6],
        [6, 7],
        [7, 7],
        [8, 6],
        [8, 5],
        [7, 4],
        [6, 4],
        [5, 5],
    ]
)

# Walk starts at index 0 and records the path length to 3, then to 7, then back round to 0
indexes_to_calculate_distance_between = np.array([0, 3, 7])
p_to_nm = 1

distances = calculate_real_distance_between_points_in_array(array, indexes_to_calculate_distance_between, p_to_nm)

print(distances)

plt.plot(array[:, 0], array[:, 1], "o-")
# Mark the start point with a star
plt.scatter(
    array[indexes_to_calculate_distance_between[0], 0],
    array[indexes_to_calculate_distance_between[0], 1],
    c="k",
    s=400,
    marker="*",
)
# Mark the points to calculate the distance between
plt.scatter(
    array[indexes_to_calculate_distance_between, 0], array[indexes_to_calculate_distance_between, 1], c="r", s=100
)
plt.show()

In [None]:
# SIMPLE calculate rate of change of vector angle per nm

from scipy.ndimage import gaussian_filter1d

angle_change_rate_grain_dict = {}
plotting = True
plot_results = True

# Grain indexes to analyse; set to None to analyse every grain in openness_grain_dict
GRAIN_INDEXES_TO_PROCESS = [206]

# A stretch of trace turning faster than this many degrees per nm is counted as a defect
degrees_per_nm_threshold = 12
radians_per_nm_threshld = np.radians(degrees_per_nm_threshold)
# Minimum spacing (nm) between consecutive points of the resampled trace
nm_sample_distance = 0.5
# Standard deviation, in resampled points, of the Gaussian used to smooth the angle change per nm
angle_smoothing_sigma = 3
# Two defects split the trace into two arcs. If the arc lengths differ by less than this fraction of the total
# trace length the molecule is tagged "churro", otherwise "pasty".
pasty_defect_difference_percentage = 0.1

if plotting:
    print(f"radians per nm threshold: {radians_per_nm_threshld}")


def _resample_trace_min_spacing(trace: np.ndarray, min_spacing: float) -> np.ndarray:
    """Keep trace points that are more than ``min_spacing`` (pixels) from the previously kept point.

    Trailing points within ``min_spacing`` of the first point are then dropped. This makes the wrap-around step
    of the closed trace respect the same spacing and prevents the first point from being repeated at the end,
    which angle_per_nm rejects.
    """
    sampled = [trace[0]]
    for point in trace[1:]:
        if np.linalg.norm(point - sampled[-1]) > min_spacing:
            sampled.append(point)
    while len(sampled) > 1 and np.linalg.norm(sampled[-1] - sampled[0]) <= min_spacing:
        sampled.pop()
    return np.array(sampled)


for index, grain_data in openness_grain_dict.items():
    if GRAIN_INDEXES_TO_PROCESS is not None and index not in GRAIN_INDEXES_TO_PROCESS:
        continue

    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    pooled_trace = grain_data["pooled_trace"]
    p_to_nm = grain_data["p_to_nm"]

    # Put every trace in the same orientation so the signs of the turning angles are comparable between grains
    pooled_trace = flip_if_anticlockwise(pooled_trace)

    # Resample the trace so consecutive points are more than nm_sample_distance apart
    pixel_sample_distance = nm_sample_distance / p_to_nm
    nm_sampled_trace = _resample_trace_min_spacing(pooled_trace, pixel_sample_distance)

    angles_per_nm, angle_diffs = angle_per_nm(nm_sampled_trace, p_to_nm, plot=False)
    assert len(angles_per_nm) == len(nm_sampled_trace)

    # Smooth the angles per nm. The trace is a closed loop, so the filter wraps around the ends instead of reflecting.
    angles_per_nm = gaussian_filter1d(angles_per_nm, angle_smoothing_sigma, mode="wrap")

    # Get the defect stats
    defect_stats_dict = defect_stats(angles_per_nm, radians_per_nm_threshld)

    # Calculate the real distance (nm) along the trace between consecutive defect midpoints
    defect_indexes = np.array([region["midpoint"] for region in defect_stats_dict["defect_regions"]])
    defect_distances = None
    if len(defect_indexes) > 1:
        defect_distances = calculate_real_distance_between_points_in_array(nm_sampled_trace, defect_indexes, p_to_nm)
        if plotting:
            print(f"defect distances: {defect_distances}")

    # Tag the molecule by its number of defects
    number_of_defects = defect_stats_dict["defect_number"]
    if number_of_defects == 0:
        tag = "open"
    elif number_of_defects == 1:
        tag = "teardrop"
    elif number_of_defects == 2:
        # Two similar-length arcs between the defects -> churro, otherwise pasty
        if np.abs(defect_distances[0] - defect_distances[1]) < pasty_defect_difference_percentage * np.sum(
            defect_distances
        ):
            tag = "churro"
        else:
            tag = "pasty"
    elif number_of_defects == 3:
        tag = "dorito"
    else:
        # More than three defects does not match any of the shapes above
        tag = "unclassified"

    if plotting:
        # Smoothed angle change per nm with the defect regions shaded and their midpoints marked
        plt.plot(angles_per_nm)
        plt.ylim(-0.5, 0.5)
        plt.axhline(radians_per_nm_threshld, color="k", linestyle="--")
        plt.axhline(0, color="k", linestyle="-")
        for region in defect_stats_dict["defect_regions"]:
            # Only regions wrapping around the end of the array have start > end
            if region["start"] <= region["end"]:
                plt.axvspan(region["start"], region["end"], color="red", alpha=0.3)
            else:
                plt.axvspan(region["start"], len(angles_per_nm), color="red", alpha=0.3)
                plt.axvspan(0, region["end"], color="red", alpha=0.3)
            plt.scatter(region["midpoint"], angles_per_nm[region["midpoint"]], c="lime")
        plt.show()

        # Image with the pooled (red) and resampled (blue) traces and the defect midpoints
        plt.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        plt.plot(pooled_trace[:, 1], pooled_trace[:, 0], "r")
        plt.plot(nm_sampled_trace[:, 1], nm_sampled_trace[:, 0], "b")
        for region in defect_stats_dict["defect_regions"]:
            plt.scatter(nm_sampled_trace[region["midpoint"]][1], nm_sampled_trace[region["midpoint"]][0], c="lime")
        plt.title(f"tag: {tag}")
        plt.show()

        # Annotate each resampled point with its turning angle in degrees (angle_diffs is indexed like
        # nm_sampled_trace, not like pooled_trace)
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
        ax.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax.scatter(nm_sampled_trace[:, 1], nm_sampled_trace[:, 0], c="k", s=10)
        for point, angle_diff in zip(nm_sampled_trace, angle_diffs):
            ax.text(point[1], point[0], f"{np.degrees(angle_diff):.2f}", fontsize=12, color="black")
        plt.show()

    # Store the results. This is the same dict object as openness_grain_dict[index], so it is updated in place.
    angle_change_rate_grain_dict[index] = grain_data
    angle_change_rate_grain_dict[index]["angles_per_nm"] = angles_per_nm
    angle_change_rate_grain_dict[index]["simple_defect_stats"] = defect_stats_dict
    angle_change_rate_grain_dict[index]["simple_defect_distances"] = defect_distances
    angle_change_rate_grain_dict[index]["simple_nm_sampled_trace"] = nm_sampled_trace
    angle_change_rate_grain_dict[index]["simple_tag"] = tag
    angle_change_rate_grain_dict[index]["rad_per_nm_threshold"] = radians_per_nm_threshld
    angle_change_rate_grain_dict[index]["angle_diffs"] = angle_diffs


def plot_images(
    images: list,
    tags: list,
    grain_indexes: list,
    nm_sampled_traces: np.ndarray,
    defect_stats: list,
    curvatures_list: list,
    radians_per_nm_threshld_list: list,
    px_to_nms: list,
    width=5,
    cmap=cmap,
    vmin=VMIN,
    vmax=VMAX,
    title: str = "",
):
    """Plot a grid of grains, two panels per grain.

    For each grain the left panel shows the image with its resampled trace (blue) and defect midpoints (cyan).
    The right panel shows the curvature profile with the zero line and the defect threshold (dashed).

    NOTE: this redefines the ``plot_images`` helper from the first cell, with a different signature.

    Parameters
    ----------
    images: list
        Grain images.
    tags: list
        Shape tag of each grain (shown in the panel title).
    grain_indexes: list
        Grain index of each grain (shown in the panel title).
    nm_sampled_traces: np.ndarray
        Resampled trace of each grain, one point per row (column 0 drawn on y, column 1 on x).
    defect_stats: list
        ``defect_stats`` result dict of each grain; the ``midpoint`` of every region is marked.
    curvatures_list: list
        Curvature profile (e.g. smoothed angle change per nm) of each grain.
    radians_per_nm_threshld_list: list
        Defect threshold of each grain, drawn as a dashed line.
    px_to_nms: list
        Pixel to nm factors; currently unused, kept for call compatibility.
    width: int
        Number of grains per row.
    cmap, vmin, vmax:
        Colour map and limits for the images.
    title: str
        Figure title.
    """
    num_images = len(images)
    # At least one row so that an empty input still produces a (blank) figure
    rows = max(1, int(np.ceil(num_images / width)))
    # squeeze=False keeps ``ax`` two-dimensional even for a single row
    fig, ax = plt.subplots(rows, width * 2, figsize=(30, 10 + 10 * rows), squeeze=False)
    for i, (
        image,
        tag,
        grain_index,
        single_defect_stats,
        nm_sampled_trace,
        curvatures,
        radians_per_nm_threshld,
    ) in enumerate(
        zip(images, tags, grain_indexes, defect_stats, nm_sampled_traces, curvatures_list, radians_per_nm_threshld_list)
    ):
        image_ax = ax[i // width, i % width * 2]
        curvature_ax = ax[i // width, i % width * 2 + 1]

        image_ax.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        # Plot the trace
        image_ax.plot(nm_sampled_trace[:, 1], nm_sampled_trace[:, 0], "b")
        image_ax.axis("off")
        image_ax.set_title(f"grain index: {grain_index} tag: {tag}")
        # Mark the defect midpoints on the trace
        for region in single_defect_stats["defect_regions"]:
            image_ax.scatter(
                nm_sampled_trace[region["midpoint"]][1], nm_sampled_trace[region["midpoint"]][0], c="cyan", s=200
            )

        curvature_ax.plot(curvatures)
        curvature_ax.set_title("curvatures")
        curvature_ax.set_ylim(-0.5, 0.5)
        curvature_ax.axhline(0, color="k", linestyle="-")
        curvature_ax.axhline(radians_per_nm_threshld, color="k", linestyle="--")

    # Hide the panels of unused grid slots
    for i in range(num_images, rows * width):
        ax[i // width, i % width * 2].axis("off")
        ax[i // width, i % width * 2 + 1].axis("off")

    plt.suptitle(title, fontsize=30)
    fig.tight_layout()
    plt.show()


def _plot_grain_panels(grain_indexes: list, title: str = ""):
    """Call plot_images for the given keys of angle_change_rate_grain_dict."""
    grains = [angle_change_rate_grain_dict[i] for i in grain_indexes]
    plot_images(
        [grain["image"] for grain in grains],
        [grain["simple_tag"] for grain in grains],
        list(grain_indexes),
        [grain["simple_nm_sampled_trace"] for grain in grains],
        [grain["simple_defect_stats"] for grain in grains],
        [grain["angles_per_nm"] for grain in grains],
        [grain["rad_per_nm_threshold"] for grain in grains],
        [grain["p_to_nm"] for grain in grains],
        title=title,
    )


if plot_results:
    _plot_grain_panels(list(angle_change_rate_grain_dict))

# Optionally plot only the grains with one tag, e.g. "churro"; None skips this plot
tag_to_plot = None
if tag_to_plot is not None:
    tagged_indexes = [i for i, grain in angle_change_rate_grain_dict.items() if grain["simple_tag"] == tag_to_plot]
    print(tag_to_plot, tagged_indexes)
    _plot_grain_panels(tagged_indexes, title=tag_to_plot)

# Bar chart of tags
tags = [angle_change_rate_grain_dict[i]["simple_tag"] for i in angle_change_rate_grain_dict]
# np.unique already returns the tags in alphabetical order (and copes with an empty list)
unique_tags, counts = np.unique(tags, return_counts=True)
plt.bar(unique_tags, counts)
plt.xlabel("Tag")
plt.ylabel("Number of grains")
plt.title(f"(Angle per nm based) distribution for {SAMPLE_TYPE}")
plt.show()

In [None]:
# Plot one grain with each resampled trace point coloured by its turning angle, then by its smoothed angle
# change per nm
grain_index = 206
grain_data = angle_change_rate_grain_dict[grain_index]
grain_image = grain_data["image"]
grain_mask = grain_data["mask"]
pooled_trace = grain_data["pooled_trace"]
p_to_nm = grain_data["p_to_nm"]
grain_angles_per_nm = grain_data["angles_per_nm"]
defect_stats_dict = grain_data["simple_defect_stats"]
radians_per_nm_threshld = grain_data["rad_per_nm_threshold"]
nm_sampled_trace = grain_data["simple_nm_sampled_trace"]
tag = grain_data["simple_tag"]
grain_angle_diffs = grain_data["angle_diffs"]

# Both arrays have one value per point of nm_sampled_trace
for point_values, colour_label in (
    (grain_angle_diffs, "Turning angle (radians)"),
    (grain_angles_per_nm, "Smoothed angle change per nm (radians / nm)"),
):
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(grain_image, cmap=cmap, vmin=VMIN, vmax=VMAX)
    sc = ax.scatter(nm_sampled_trace[:, 1], nm_sampled_trace[:, 0], c=point_values, s=90, cmap="cool_r")
    plt.colorbar(sc, label=colour_label)
    plt.title(f"grain index: {grain_index}, tag: {tag}")
    plt.show()

In [None]:
grain_index = 206
grain_data = angle_change_rate_grain_dict[grain_index]
grain_image = grain_data["image"]
grain_mask = grain_data["mask"]
trace = grain_data["trace"]
pooled_trace = grain_data["pooled_trace"]
p_to_nm = grain_data["p_to_nm"]
height_trace = grain_data["height_trace"]
pixel_trace_length = grain_data["pixel_trace_length"]
angles_per_nm = grain_data["angles_per_nm"]
# Read the threshold and resampled trace for this grain instead of relying on globals left by other cells
radians_per_nm_threshld = grain_data["rad_per_nm_threshold"]
nm_sampled_trace = grain_data["simple_nm_sampled_trace"]


fig, ax = plt.subplots(1, 1, figsize=(10, 10))
# turn axes off
ax.axis("off")
ax.imshow(grain_image, cmap=cmap, vmin=VMIN, vmax=VMAX)
ax.plot(pooled_trace[:, 1], pooled_trace[:, 0], "b", linewidth=3)
plt.show()


# angles_per_nm has one value per resampled point, so use the cumulative path length (nm) of those points,
# measured from the first one, as the x axis
step_lengths_nm = np.linalg.norm(np.diff(nm_sampled_trace, axis=0), axis=1) * p_to_nm
angles_per_nm_x = np.concatenate([[0.0], np.cumsum(step_lengths_nm)])
plt.plot(angles_per_nm_x, angles_per_nm)
plt.axhline(radians_per_nm_threshld, color="k", linestyle="--")
plt.xlabel("Distance along trace (nm)")
plt.ylabel("Angle change per nm (radians)")
plt.title("Angle change per nm")
plt.show()

In [None]:
def weighted_average_position(array: np.ndarray, clip_min: float) -> float:
    """Calculate the weighted average index of an array, using its clipped values as the weights.

    Parameters
    ----------
    array: np.ndarray
        1D array whose values weight their own indexes.
    clip_min: float
        Values below this are raised to ``clip_min`` before being used as weights, so that e.g. negative values
        cannot pull the average away from (or outside) the array.

    Returns
    -------
    float
        The weighted mean index, ``sum(i * w_i) / sum(w_i)``.

    Raises
    ------
    ZeroDivisionError
        If the clipped weights sum to zero (raised by ``np.average``).
    """
    positions = np.arange(len(array))
    # np.maximum clips without modifying the input. With a float clip_min it also promotes integer arrays to
    # float, so a fractional clip value is not truncated.
    weights = np.maximum(array, clip_min)
    return np.average(positions, weights=weights)


testarr = np.array([0, 1, 3, 5, 10, 10000000, 2, 12, 0])

print(weighted_average_position(testarr, clip_min=0))


def weighted_mean_defect_position(
    defect_angle_shifts: np.ndarray, defect_start_index: int, max_index: int, clip_min: float = 0.0
):
    """Locate a defect at the weighted mean position of its angle shifts within a cyclic array.

    Parameters
    ----------
    defect_angle_shifts: np.ndarray
        Values across the defect, used as weights for their positions (clipped below at ``clip_min``).
    defect_start_index: int
        Index in the full cyclic array of the first value of ``defect_angle_shifts``.
    max_index: int
        Used as the length of the cyclic array: positions at or beyond it are wrapped by subtracting it once.
        NOTE(review): assumes callers pass the number of samples (e.g. ``len(angles_per_nm)``) rather than the
        last valid index — confirm.
    clip_min: float
        Lower clip applied to the weights.

    Returns
    -------
    tuple
        ``(position, position_int)``: the wrapped weighted mean position as a float, and rounded to an int.
    """
    weighted_mean_angle_shift_index = (
        weighted_average_position(defect_angle_shifts, clip_min=clip_min) + defect_start_index
    )
    weighted_mean_angle_shift_index_int = int(np.round(weighted_mean_angle_shift_index))
    # Wrap positions that run past the end of the cyclic array back to the start, using the given length
    # rather than a global array from another cell
    if weighted_mean_angle_shift_index >= max_index:
        weighted_mean_angle_shift_index -= max_index
    if weighted_mean_angle_shift_index_int >= max_index:
        weighted_mean_angle_shift_index_int -= max_index

    return weighted_mean_angle_shift_index, weighted_mean_angle_shift_index_int


def find_distance_window_looped(distances: np.ndarray, window_distance: float, start_index: int):
    """Find a window on a closed loop of points whose length exceeds ``window_distance``.

    The loop's start point is not repeated at its end: ``distances[i]`` is the distance from
    point ``i`` to the next point, and the final entry is the distance from the last point back
    to point 0. Starting at ``start_index``, segment lengths are accumulated forwards (wrapping
    past the last point) until the total is strictly greater than ``window_distance``.

    Parameters
    ----------
    distances: np.ndarray
        Length of each segment of the loop.
    window_distance: float
        Distance the window has to exceed.
    start_index: int
        Index of the point the window starts at.

    Returns
    -------
    tuple
        ``(start_index, end_index, covered_distance)`` where ``end_index`` is the point reached
        once the accumulated distance first exceeds ``window_distance``.

    Raises
    ------
    ValueError
        If the distance has not been exceeded after ``len(distances) + 2`` segments.
    """
    segment_lengths = np.asarray(distances)
    num_segments = len(segment_lengths)
    covered = 0
    end_index = start_index
    # Just over one lap of attempts; the loop either returns or falls through to the error.
    for _ in range(num_segments + 2):
        if covered > window_distance:
            return start_index, end_index, covered
        covered += segment_lengths[end_index]
        end_index = end_index + 1 if end_index < num_segments - 1 else 0
    raise ValueError("Max iterations reached")

In [None]:
def combine_two_defects(
    defect_0: dict,
    defect_1: dict,
    max_trace_index: int,
):
    """Combine two overlapping defects on a closed trace into one defect.

    Parameters
    ----------
    defect_0, defect_1: dict
        Defects with at least an ``"indexes"`` entry listing the trace indexes they cover. If
        either carries ``"maximum_total_angle_shift"``, the value of larger magnitude is kept.
    max_trace_index: int
        Largest valid trace index; index ``max_trace_index`` is adjacent to index 0.

    Returns
    -------
    dict
        ``start_index``, ``end_index``, ``indexes`` (sorted unique union) and ``midpoint``
        (rounded, measured forwards from start to end and wrapping past the end if needed).
    """
    combined_indexes = np.unique(np.append(defect_0["indexes"], defect_1["indexes"]))
    loop_length = max_trace_index + 1
    index_set = set(combined_indexes.tolist())

    spans_seam = 0 in index_set and max_trace_index in index_set
    if spans_seam and len(index_set) < loop_length:
        # The defect wraps past the end of the trace. Its end is the last index of the run that
        # starts at 0; its start is the first index of the run that finishes at max_trace_index.
        end_index = 0
        while end_index + 1 in index_set:
            end_index += 1
        start_index = max_trace_index
        while start_index - 1 in index_set:
            start_index -= 1
    elif spans_seam:
        # Every index is covered, so there is no gap to anchor on: treat it as the whole loop.
        start_index, end_index = 0, max_trace_index
    else:
        start_index = int(combined_indexes.min())
        end_index = int(combined_indexes.max())

    # >= so that a single-index defect keeps its own index as the midpoint instead of being
    # treated as wrapping around the whole loop.
    if end_index >= start_index:
        midpoint = int(np.round((start_index + end_index) / 2))
    else:
        midpoint = int(np.round((start_index + end_index + loop_length) / 2)) % loop_length

    combined_defect = {
        "start_index": start_index,
        "end_index": end_index,
        "indexes": combined_indexes,
        "midpoint": midpoint,
    }

    # Keep the angle shift so merged defects carry the same keys as un-merged ones.
    angle_shifts = [
        defect["maximum_total_angle_shift"]
        for defect in (defect_0, defect_1)
        if "maximum_total_angle_shift" in defect
    ]
    if angle_shifts:
        combined_defect["maximum_total_angle_shift"] = max(angle_shifts, key=abs)

    return combined_defect


def combine_overlapping_defects(defects: list, max_trace_index: int):
    """Combine any overlapping defects in a list of defects where each defect is a dictionary of statistics for the defect.

    Defects start at a starting point and end at an ending point. If two defects share any trace
    index they overlap and are merged with ``combine_two_defects``. This repeats until no two
    defects overlap. Defects can span the start and end of the trace; that case is handled by
    ``combine_two_defects``.

    Parameters
    ----------
    defects: list
        Defect dicts, each with an ``"indexes"`` array of the trace indexes it covers. The list
        passed in is not modified.
    max_trace_index: int
        Largest valid trace index, forwarded to ``combine_two_defects``.

    Returns
    -------
    list
        A new list of defects in which no two share an index. Merged defects are appended
        after the untouched ones.
    """
    merged = list(defects)

    merge_happened = True
    while merge_happened:
        merge_happened = False
        for i in range(len(merged)):
            for j in range(i + 1, len(merged)):
                if np.intersect1d(merged[i]["indexes"], merged[j]["indexes"]).size > 0:
                    combined_defect = combine_two_defects(merged[i], merged[j], max_trace_index)
                    # Delete the later position first so the earlier one is not shifted.
                    del merged[j]
                    del merged[i]
                    merged.append(combined_defect)
                    merge_happened = True
                    break
            if merge_happened:
                # The list has changed, so restart the scan from the beginning.
                break

    return merged

In [None]:
# Debug / visualisation switches for the turn-in-distance defect detection below.
plotting = False  # per-point window diagnostics (one figure per trace point)
plot_results = False  # one summary figure per grain
plot_classifications = False  # grids of grains grouped by classification
vmin, vmax = 0, 2.5  # colour limits for the grain images

# A "defect" is a stretch of trace whose direction at one end of a window just over
# `defect_nm_value` nm long differs from the direction at the other end by more than
# `defect_degrees_value` degrees.
defect_degrees_value = 80
defect_nm_value = 5.0
# Two-defect grains whose inter-defect distances differ by more than this fraction of the
# total trace length are classed as "pasty" rather than "churro".
pasty_distance_deviation_threshold_percentage = 0.08

turn_in_distance_grain_dict = {}

def _tid_window_indexes(start_index: int, end_index: int, num_points: int) -> np.ndarray:
    """Trace indexes from ``start_index`` to ``end_index`` inclusive on a closed trace of
    ``num_points`` points, wrapping past the last point when ``end_index <= start_index``."""
    if end_index > start_index:
        return np.arange(start_index, end_index + 1)
    return np.append(np.arange(start_index, num_points), np.arange(0, end_index + 1))


def _tid_mean_direction(trace: np.ndarray, point_index: int) -> np.ndarray:
    """Mean of the incoming and outgoing segment vectors at ``point_index`` on a closed trace."""
    incoming = trace[point_index] - trace[point_index - 1]  # index -1 wraps to the last point
    outgoing = trace[(point_index + 1) % len(trace)] - trace[point_index]
    return (incoming + outgoing) / 2


def _tid_plot_segment(target, trace: np.ndarray, start_index: int, end_index: int, fmt: str, **kwargs) -> None:
    """Plot the part of a closed trace between two indexes (inclusive) on ``target`` (``plt`` or an Axes)."""
    if end_index > start_index:
        segment = trace[start_index : end_index + 1]
    else:
        segment = np.append(trace[start_index:], trace[: end_index + 1], axis=0)
    target.plot(segment[:, 1], segment[:, 0], fmt, **kwargs)


def _tid_plot_window_debug(image, trace, point_index, window_start_index, window_end_index) -> None:
    """Draw the current point, both ends of its window and the forward direction at each end."""
    plt.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
    plt.scatter(trace[:, 1], trace[:, 0], c="k", s=5)
    plt.scatter(trace[point_index, 1], trace[point_index, 0], c="white", s=40)
    window_ends = ((window_start_index, "b"), (window_end_index, "r"))
    for window_index, colour in window_ends:
        plt.scatter(trace[window_index, 1], trace[window_index, 0], c=colour, s=20)
    for window_index, colour in window_ends:
        forward = trace[(window_index + 1) % len(trace)] - trace[window_index]
        forward = forward / np.linalg.norm(forward)
        plt.arrow(
            trace[window_index, 1],
            trace[window_index, 0],
            forward[1] * 10,
            forward[0] * 10,
            head_width=2,
            head_length=2,
            fc=colour,
            ec=colour,
        )


def _tid_close_defect(start_index, end_index, indexes, angle_shift, num_points) -> dict:
    """Build the record for a finished defect. The midpoint is the truncated mean of its start
    and end indexes, measured forwards along the trace and wrapping past the end when needed."""
    if end_index > start_index:
        midpoint = int(np.mean([start_index, end_index]))
    else:
        midpoint = int(np.mean([start_index, end_index + num_points])) % num_points
    return {
        "start_index": start_index,
        "end_index": end_index,
        "maximum_total_angle_shift": angle_shift,
        "indexes": indexes,
        "midpoint": midpoint,
    }


def _tid_distances_between_midpoints(midpoints, segment_lengths_nm) -> list:
    """Distance along the closed trace from each midpoint (sorted ascending) to the next.

    The last entry wraps from the largest midpoint back to the smallest. A single midpoint
    gives one full lap. ``segment_lengths_nm[k]`` is the length from point k to point k + 1.
    """
    num_points = len(segment_lengths_nm)
    ordered = sorted(midpoints)
    distances_along_trace = []
    for position, start in enumerate(ordered):
        if len(ordered) == 1:
            steps = num_points
        else:
            steps = (ordered[(position + 1) % len(ordered)] - start) % num_points
        distance = 0
        for offset in range(steps):
            distance += segment_lengths_nm[(start + offset) % num_points]
        distances_along_trace.append(distance)
    return distances_along_trace


for index, grain_data in angle_change_rate_grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    p_to_nm = grain_data["p_to_nm"]
    trace = grain_data["trace"]
    pooled_trace = grain_data["pooled_trace"]
    num_points = len(pooled_trace)

    # Segment lengths around the closed trace: entry k is the distance from point k to point
    # k + 1, and the final entry closes the loop back to point 0.
    pooled_trace_with_start_point = np.append(pooled_trace, [pooled_trace[0]], axis=0)
    distances_between_points = np.linalg.norm(np.diff(pooled_trace_with_start_point, axis=0), axis=1)
    distances_between_points_nm = distances_between_points * p_to_nm
    total_distance_nm = np.sum(distances_between_points_nm)

    # Walk the trace. For each point take a window just over `defect_nm_value` nm long and
    # compare the trace direction at its two ends. Consecutive windows that turn by more than
    # `defect_degrees_value` are merged into a single defect.
    defect_angle_threshold = np.radians(defect_degrees_value)
    in_defect = False
    defects = []
    for point_index in range(num_points):
        window_start_index, window_end_index, window_distance = find_distance_window_looped(
            distances_between_points_nm, defect_nm_value, point_index
        )
        if plotting:
            _tid_plot_window_debug(grain_image, pooled_trace, point_index, window_start_index, window_end_index)

        angle_between_start_and_end = angle_diff_signed(
            _tid_mean_direction(pooled_trace, window_start_index),
            _tid_mean_direction(pooled_trace, window_end_index),
        )

        if np.abs(angle_between_start_and_end) > defect_angle_threshold:
            window_indexes = _tid_window_indexes(window_start_index, window_end_index, num_points)
            if not in_defect:
                in_defect = True
                defect_start_index = point_index
                defect_indexes = window_indexes
                # NOTE(review): despite the key name, this is the shift of the window that opened
                # the defect; it is not updated as the defect grows.
                maximum_total_angle_shift = angle_between_start_and_end
                if plotting:
                    _tid_plot_segment(plt, pooled_trace, window_start_index, window_end_index, "g")
                    plt.title(
                        f"defect start index: {point_index}, end index: {window_end_index},({np.degrees(angle_between_start_and_end)} degrees)"
                    )
            else:
                # A dense stretch of points can add several indexes at once, because the window
                # is measured in nm rather than in points.
                defect_indexes = np.unique(np.append(defect_indexes, window_indexes))
                if plotting:
                    _tid_plot_segment(plt, pooled_trace, window_start_index, window_end_index, "lime", linewidth=2)
                    plt.title(
                        f"defect start index: {defect_start_index}, end index: {window_end_index}, total angle shift: {angle_between_start_and_end} ({np.degrees(angle_between_start_and_end)} degrees)"
                    )
        else:
            if plotting:
                plt.title(
                    f"no defect at index: {point_index} window end index: {window_end_index} total angle shift: {angle_between_start_and_end} ({np.degrees(angle_between_start_and_end)} degrees)"
                )
            if in_defect:
                in_defect = False
                defects.append(
                    _tid_close_defect(
                        defect_start_index, window_end_index, defect_indexes, maximum_total_angle_shift, num_points
                    )
                )

        if plotting:
            plt.show()

    # A defect still open when the walk finishes ends at the last window's end index.
    if in_defect:
        in_defect = False
        defects.append(
            _tid_close_defect(
                defect_start_index, window_end_index, defect_indexes, maximum_total_angle_shift, num_points
            )
        )

    # Merge overlapping defects (including ones that wrap past the end of the trace), order them
    # by midpoint and measure the trace distance between consecutive midpoints.
    defects = combine_overlapping_defects(defects, max_trace_index=num_points - 1)
    defects = sorted(defects, key=lambda defect: defect["midpoint"])
    distances = _tid_distances_between_midpoints(
        [defect["midpoint"] for defect in defects], distances_between_points_nm
    )

    if plot_results:
        fig, ax = plt.subplots(1, figsize=(8, 8))
        ax.imshow(grain_image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax.scatter(pooled_trace[:, 1], pooled_trace[:, 0], c="k", s=40)
        for defect in defects:
            _tid_plot_segment(ax, pooled_trace, defect["start_index"], defect["end_index"], "lime", linewidth=4)
            for key, colour, size in (("start_index", "blue", 150), ("end_index", "red", 150), ("midpoint", "cyan", 600)):
                ax.scatter(
                    pooled_trace[defect[key], 1],
                    pooled_trace[defect[key], 0],
                    c=colour,
                    s=size,
                    alpha=1,
                )
        plt.suptitle(
            f"Defects {len(defects)} for grain {index}, total distance: {total_distance_nm:.2f} nm, defect degrees: {defect_degrees_value}, defect nm: {defect_nm_value}, p_to_nm: {p_to_nm:.2f}"
        )
        fig.tight_layout()
        plt.show()

    # Classification:
    #   3 defects -> dorito; 1 defect -> teardrop; 0 or more than 3 -> open.
    #   2 defects -> pasty if the two inter-defect distances differ by more than the threshold
    #   fraction of the total trace length, otherwise churro.
    if len(defects) == 3:
        tag = "dorito"
    elif len(defects) == 2:
        abs_diff_distance = np.abs(distances[0] - distances[1])
        if abs_diff_distance > pasty_distance_deviation_threshold_percentage * total_distance_nm:
            tag = "pasty"
        else:
            tag = "churro"
    elif len(defects) == 1:
        tag = "teardrop"
    else:
        tag = "open"

    # NOTE: `grain_data` is stored by reference, so these keys also appear on the matching entry
    # of `angle_change_rate_grain_dict`.
    turn_in_distance_grain_dict[index] = grain_data
    turn_in_distance_grain_dict[index].update(
        {
            "pooled_trace": pooled_trace,
            "complex_defects": defects,
            "complex_distances_between_defects": distances,
            "complex_tag": tag,
            "total_distance": total_distance_nm,
            "complex_num_defects": len(defects),
            "distances_between_points": distances_between_points,
            # Use this grain's own angle-change profile rather than whichever global
            # `angles_per_nm` an earlier cell happened to leave behind.
            "complex_angles_per_nm": grain_data.get("angles_per_nm"),
        }
    )


def plot_images(
    images: list,
    tags: list,
    traces: list,
    grain_indexes: list,
    defects: list,
    distances_between_defects: list,
    feret_ratios: list,
    area_perimeter_ratios: list,
    px_to_nms: list,
    width=5,
    cmap=cmap,
    vmin=VMIN,
    vmax=VMAX,
    scale_bar_nm=None,
):
    """Plot a grid of classified grains with their traces and defect midpoints.

    NOTE(review): this redefines the ``plot_images`` helper from the skeletonisation cell with a
    different signature; that earlier helper is shadowed from this point on.

    Parameters
    ----------
    images, tags, traces, grain_indexes, defects, distances_between_defects, feret_ratios,
    area_perimeter_ratios, px_to_nms: list
        Per-grain values in matching order. Each trace is indexed as ``trace[:, 0]`` (rows) and
        ``trace[:, 1]`` (columns). Each entry of ``defects`` is a list of defect dicts with a
        ``"midpoint"`` trace index. ``px_to_nms`` are pixel-to-nm factors.
    width: int
        Number of grains per row of the grid.
    cmap, vmin, vmax
        Colour map and colour limits passed to ``imshow``.
    scale_bar_nm: float or None
        If given, draw a white scale bar of this length (nm) in the bottom-right of each image.
    """
    num_images = len(images)
    if num_images == 0:
        return
    rows = int(np.ceil(num_images / width))
    # squeeze=False keeps `ax` two-dimensional even for a single row or a single column.
    fig, ax = plt.subplots(rows, width, figsize=(30, rows * 5), squeeze=False)
    for unused_ax in ax.flat[num_images:]:
        unused_ax.axis("off")

    for i, (
        image,
        tag,
        grain_index,
        trace,
        defect_list,
        defect_distances,
        feret_ratio,
        area_perimeter_ratio,
        px_to_nm,
    ) in enumerate(
        zip(
            images,
            tags,
            grain_indexes,
            traces,
            defects,
            distances_between_defects,
            feret_ratios,
            area_perimeter_ratios,
            px_to_nms,
        )
    ):
        thisax = ax[i // width, i % width]
        thisax.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        thisax.axis("off")
        distances_between_defects_string = ", ".join(f"{distance:.2f}" for distance in defect_distances)
        thisax.set_title(
            f"index: {grain_index} tag: {tag} p_to_nm: {px_to_nm:.2f}\n defect distances: {distances_between_defects_string} total distance: {np.sum(defect_distances):.2f} nm \n feret ratio: {feret_ratio:.2f} area perimeter ratio: {area_perimeter_ratio:.2f}"
        )
        # Colour and marker passed as keywords: a bare "." following a colour string is not a
        # marker argument to `plot`; matplotlib treats it as extra data.
        thisax.plot(trace[:, 1], trace[:, 0], color="black", marker=".")
        for defect in defect_list:
            thisax.scatter(
                trace[defect["midpoint"]][1],
                trace[defect["midpoint"]][0],
                c="cyan",
                s=400,
            )

        if scale_bar_nm is not None:
            # Horizontal bar inset 5% from the bottom-right corner, labelled with its length.
            bar_length_px = scale_bar_nm / px_to_nm
            bar_right_x = image.shape[1] * 0.95
            bar_y = image.shape[0] * 0.95
            thisax.plot([bar_right_x - bar_length_px, bar_right_x], [bar_y, bar_y], "white", linewidth=5)
            thisax.text(
                bar_right_x - 10,
                bar_y - 2,
                f"{scale_bar_nm} nm",
                fontsize=24,
                fontweight="bold",
                color="white",
            )

    fig.tight_layout()
    plt.show()


# Tag order shared by the per-class grids and the bar chart.
unique_tags = ["churro", "pasty", "dorito", "teardrop", "open"]

if plot_classifications:
    for tag_to_plot in unique_tags:
        indexes = [
            i for i in turn_in_distance_grain_dict if turn_in_distance_grain_dict[i]["complex_tag"] == tag_to_plot
        ]
        print(tag_to_plot, indexes)
        if indexes:
            grains = [turn_in_distance_grain_dict[i] for i in indexes]
            plot_images(
                [grain["image"] for grain in grains],
                [grain["complex_tag"] for grain in grains],
                [grain["pooled_trace"] for grain in grains],
                list(indexes),
                [grain["complex_defects"] for grain in grains],
                [grain["complex_distances_between_defects"] for grain in grains],
                [grain["feret_ratio"] for grain in grains],
                [grain["perimeter_area_ratio"] for grain in grains],
                [grain["p_to_nm"] for grain in grains],
            )

# Bar chart of how many grains received each tag.
tags = [turn_in_distance_grain_dict[i]["complex_tag"] for i in turn_in_distance_grain_dict]
counts = [tags.count(tag) for tag in unique_tags]
fig, ax = plt.subplots()
ax.bar(unique_tags, counts)
ax.set_xlabel("Classification")
ax.set_ylabel("Number of grains")
ax.set_title(f"(Turn in distance based) distribution for {SAMPLE_TYPE}")
plt.show()

### SPLINING

In [None]:
from scipy.interpolate import splprep, splev
from mpl_toolkits.axes_grid1 import make_axes_locatable


def _rim_curvature(xs: np.ndarray, ys: np.ndarray):
    """Calculate the curvature of a set of points. Uses the standard curvature definition of the derivative of the
    tangent vector.

    Parameters:
    ----------
    xs: np.ndarray
        One dimensional numpy array of x-coordinates of the points
    ys: np.ndarray
        One dimensional numpy array of y-coordinates of the points
    Returns:
    -------
    np.ndarray
        One-dimensional numpy array of curvatures for the spline.
    """
    extension_length = xs.shape[0]
    xs_extended = np.append(xs, xs)
    xs_extended = np.append(xs_extended, xs)

    # # Plot the extended points
    # plt.plot(xs_extended)
    # plt.title("Extended x")
    # plt.show()

    ys_extended = np.append(ys, ys)
    ys_extended = np.append(ys_extended, ys)
    dx = np.gradient(xs_extended)

    # # Plot the extended dx
    # plt.plot(dx)
    # plt.title("Extended dx")
    # plt.show()

    dy = np.gradient(ys_extended)
    d2x = np.gradient(dx)

    # # Plot the extended d2x
    # plt.plot(d2x)
    # plt.title("Extended d2x")
    # plt.show()

    d2y = np.gradient(dy)
    curv = np.abs(dx * d2y - d2x * dy) / (dx * dx + dy * dy) ** 1.5

    # # Plot the extended curvature
    # plt.plot(curv)
    # plt.title("Extended curvature")
    # plt.show()

    curv = curv[extension_length : (len(curv) - extension_length)]

    return curv


def _interpolate_points_spline(points: np.ndarray, num_points: int, smoothing: float = 0.0):
    """Interpolate a set of points using a spline.

    Parameters
    ----------
    points: np.ndarray
        Nx2 Numpy array of coordinates for the points.
    num_points: int
        The number of points to return following the calculated spline.

    Returns
    -------
    interpolated_points: np.ndarray
        An Ix2 Numpy array of coordinates of the interpolated points, where I is the number of points
        specified.
    """

    x, y = splprep(points.T, u=None, s=smoothing, per=1)
    x_spline = np.linspace(y.min(), y.max(), num_points)
    x_new, y_new = splev(x_spline, x, der=0)
    interpolated_points = np.array((x_new, y_new)).T
    return interpolated_points


def interpolate_spline_and_get_curvature(points: np.ndarray, interpolation_number: int, smoothing: float = 0.0):
    """Fit a periodic spline through a closed loop of points and return the curvature along it.

    The spline is sampled at ``interpolation_number`` times as many points as were supplied. The curvature
    is computed on these spline samples rather than on the raw points, which reduces anomalies.

    Parameters
    ----------
    points: np.ndarray
        Nx2 Numpy array of coordinates for the points, one point per row.
    interpolation_number: int
        Number of interpolated points per input point. Eg: for a set of 10 points and an interpolation
        number of 2, the spline is sampled at 20 points.
    smoothing: float
        Smoothing factor passed through to ``scipy.interpolate.splprep`` as ``s``; 0.0 forces the spline
        through every input point.

    Returns
    -------
    interpolated_curvatures: np.ndarray
        One-dimensional Numpy array of curvatures, one per interpolated point.
    interpolated_points: np.ndarray
        (interpolation_number * N)x2 Numpy array of points sampled from the spline.
    """
    sample_count = points.shape[0] * interpolation_number
    spline_points = _interpolate_points_spline(points=points, num_points=sample_count, smoothing=smoothing)

    # One coordinate array per column of the sampled points.
    first_coords, second_coords = spline_points.T
    curvatures = _rim_curvature(first_coords, second_coords)

    return curvatures, spline_points


def visualise_curvature_pixel_image(
    curvatures: np.ndarray,
    points: np.ndarray,
    image_size: int = 100,
    title: str = "",
    figsize=(12, 12),
):
    """Visualise the curvature of a set of points using a pixel heightmap image.

    Each point is mapped onto a pixel of an ``image_size`` x ``image_size`` image, centred on the centroid of
    the points, and that pixel is set to the point's curvature. Points landing on the same pixel overwrite one
    another (last one wins); pixels that no point lands on stay at 0.

    Parameters
    ----------
    curvatures: np.ndarray
        Numpy Nx1 array of curvatures for the points.
    points: np.ndarray
        Numpy Nx2 array of coordinates for the points.
    image_size: int
        Width and height in pixels of the rendered image.
    title: str
        Title for the plot.
    figsize: tuple
        Matplotlib figure size.

    Returns
    -------
    None
    """
    points = np.asarray(points, dtype=float)
    curv_img = np.zeros((image_size, image_size))

    if points.shape[0] > 0:
        image_centre = np.array(curv_img.shape) / 2
        centroid = points.mean(axis=0)
        # Scale by the largest offset from the centroid (along either axis), not by the raw coordinate values.
        # The point furthest from the centroid lands 35% of the image size from the centre, so the loop spans
        # ~70% of the image. Every index stays in bounds whatever the position or sign of the input coordinates.
        max_offset = np.max(np.abs(points - centroid))
        scaling_factor = (image_size * 0.35) / max_offset if max_offset > 0 else 1.0
        pixel_indices = (image_centre + (points - centroid) * scaling_factor).astype(int)
        for (row, col), curvature in zip(pixel_indices, curvatures):
            curv_img[row, col] = curvature

    fig, ax = plt.subplots(figsize=figsize)
    # Transpose + flip: the first coordinate runs left-to-right and the second runs bottom-to-top.
    im = ax.imshow(np.flipud(curv_img.T), cmap="rainbow")
    # Colour bar in its own axes, sized to match the image height.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax, orientation="vertical")
    ax.set_title(title)
    plt.show()


def visualise_curvature_scatter(curvatures: np.ndarray, points: np.ndarray, title: str = ""):
    """Visualise the curvature of a set of points using a scatter plot with colours of the markers
    representing the curvatures of the points.

    Parameters
    ----------
    curvatures: np.ndarray
        Curvature value for each point, used as the marker colour.
    points: np.ndarray
        Nx2 Numpy array of coordinates; column 0 is plotted on the horizontal axis and column 1 on the
        vertical axis.
    title: str
        Title for the plot.

    Returns
    -------
    None
    """
    scatter_plot = plt.scatter(points[:, 0], points[:, 1], c=curvatures, cmap="rainbow", s=0.5)
    # Keep a handle on the plotting axes so the colour bar's axes cannot be picked up by mistake below.
    ax = plt.gca()
    ax.set_title(title)
    plt.colorbar(scatter_plot)
    # Aspect ratio of 1 so the loop's shape is not distorted by the axis limits.
    ax.set_aspect("equal")
    # Flip the vertical axis so its values increase downwards.
    ax.invert_yaxis()
    plt.show()


def simple_curvature(points):
    """Signed curvature at each point of an Nx2 array of points, using plain (non-periodic) finite differences.

    Computes (x'' y' - x' y'') / (x'^2 + y'^2)^(3/2) with derivatives from ``np.gradient``. With this sign
    convention, a counter-clockwise circle of radius r gives -1/r. The first and last two values use
    one-sided differences; callers wanting a closed loop should pad the points with wrap-around copies.
    """
    first_x, first_y = (np.gradient(coord) for coord in points.T)
    second_x = np.gradient(first_x)
    second_y = np.gradient(first_y)
    numerator = second_x * first_y - first_x * second_y
    denominator = (first_x**2 + first_y**2) ** 1.5
    return numerator / denominator

In [None]:
print(f"indexes available: {list(turn_in_distance_grain_dict.keys())}")

# Grains and spline smoothing factors to inspect.
grain_indexes_to_plot = [219]
smoothing_values = [30.0]
# Set to True to also draw the per-point curvature visualisations for each spline.
SHOW_CURVATURE_VISUALISATIONS = False

for grain_index in grain_indexes_to_plot:
    if grain_index not in turn_in_distance_grain_dict:
        print(f"grain index {grain_index} not found, skipping")
        continue
    print(f"grain index: {grain_index}")

    grain_data = turn_in_distance_grain_dict[grain_index]
    grain_image = grain_data["image"]
    pooled_trace = grain_data["pooled_trace"]
    trace = grain_data["trace"]

    # NOTE(review): lowercase `vmin` / `vmax` are not defined in this cell and must come from an earlier cell;
    # the config cell at the top of the notebook defines `VMIN` / `VMAX` — confirm the intended colour range.
    for smoothing in smoothing_values:
        # Spline sampled at 10 points per trace point, with the curvature along it
        interpolated_curvatures, interpolated_points = interpolate_spline_and_get_curvature(
            trace, interpolation_number=10, smoothing=smoothing
        )

        # Raw trace over the grain image
        plt.imshow(grain_image, cmap=cmap, vmin=vmin, vmax=vmax)
        plt.plot(trace[:, 1], trace[:, 0], "red", linewidth=2)
        plt.show()

        if SHOW_CURVATURE_VISUALISATIONS:
            visualise_curvature_pixel_image(interpolated_curvatures, interpolated_points, title="Spline curvature")
            visualise_curvature_scatter(interpolated_curvatures, interpolated_points, title="Spline curvature")

        # Spline over the grain image
        fig, ax = plt.subplots(1, figsize=(8, 8))
        ax.imshow(grain_image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax.plot(interpolated_points[:, 1], interpolated_points[:, 0], "black", linewidth=3)
        ax.set_title(f"Spline curvature (smoothing: {smoothing})")
        fig.tight_layout()
        plt.show()

        # Pooled trace over the grain image
        plt.imshow(grain_image, cmap=cmap, vmin=vmin, vmax=vmax)
        plt.plot(pooled_trace[:, 1], pooled_trace[:, 0], "blue", linewidth=2)
        plt.title("pooled trace")
        plt.show()

        # Densely sampled spline (1000 points) compared with the raw trace
        dense_spline_points = _interpolate_points_spline(points=trace, num_points=1000, smoothing=smoothing)
        fig, ax = plt.subplots(1, figsize=(10, 10))
        ax.imshow(grain_image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax.plot(dense_spline_points[:, 1], dense_spline_points[:, 0], "black", linewidth=3)
        ax.plot(trace[:, 1], trace[:, 0], "red", linewidth=3)
        ax.set_title("Original trace (red) and spline (black)")
        plt.show()

In [None]:
# Sanity-check `simple_curvature` on an analytic closed curve: an ellipse with semi-axes `a` and `b`.
# (A circle is the special case a == b, whose curvature magnitude is constant at 1 / radius.)

# ellipse
a = 10
b = 8
n_samples = 30
theta = np.linspace(0, 2 * np.pi, n_samples)
x = a * np.cos(theta)
y = b * np.sin(theta)
points = np.array((x, y)).T

# remove the last point since it is the same as the first
points = points[:-1]
n_unique = points.shape[0]

# copy points to either side so np.gradient sees wrap-around neighbours for the middle copy
points = np.concatenate([points, points, points], axis=0)

# calculate the curvature
curvatures = simple_curvature(points)

# Keep the middle copy plus the first value of the next copy. That gives one curvature per entry of `theta`;
# the extra value closes the loop, matching the repeated end point of `x` / `y`.
curvatures = curvatures[n_unique : 2 * n_unique + 1]

# plot the ellipse
fig, ax = plt.subplots(1, figsize=(8, 8))
ax.plot(x, y, linewidth=3)
ax.set_aspect("equal")
ax.set_title("Ellipse", fontsize=32)
ax.set_xticks([-10, 0, 10])
ax.set_yticks([-10, 0, 10])
# set tick font size
ax.tick_params(axis="both", which="major", labelsize=20)
plt.show()

# plot curvature; simple_curvature's sign convention gives negative values for this counter-clockwise ellipse
fig, ax = plt.subplots(1, figsize=(8, 2))
ax.plot(curvatures, linewidth=3)
ax.set_ylim(-0.2, 0)
ax.set_xticks([0, n_samples // 2, n_samples])
# set tick font size
ax.tick_params(axis="both", which="major", labelsize=14)
ax.set_xlabel("Position", fontsize=18)
ax.set_ylabel("Curvature", fontsize=18)
ax.set_title("Curvature along the ellipse", fontsize=32)
plt.show()