In [None]:
from pathlib import Path
from typing import Union

import matplotlib.pyplot as plt
import numpy as np
import h5py

from topostats.plottingfuncs import Colormap

# Build the colormap object once; `cmap` is used as the default colour map
# argument of the plotting helpers defined later in this notebook.
colormap = Colormap()
cmap = colormap.get_cmap()

In [None]:
def plot(
    image: np.ndarray, title: str = None, vmin: float = -8, vmax: float = 8, cmap=cmap, figsize=(10, 10), cbar=False
):
    """Display a single 2D array as an image in its own figure.

    Parameters
    ----------
    image: np.ndarray
        Array passed to ``imshow``.
    title: str
        Optional axes title; nothing is set when ``None``.
    vmin, vmax: float
        Lower and upper colour limits.
    cmap
        Colormap handed to ``imshow``.
    figsize: tuple
        Figure size forwarded to ``plt.subplots``.
    cbar: bool
        When true, a colour bar is attached to the axes.
    """
    figure, axis = plt.subplots(figsize=figsize)
    rendered = axis.imshow(image, vmin=vmin, vmax=vmax, cmap=cmap)
    if cbar:
        figure.colorbar(rendered, ax=axis)
    if title is not None:
        axis.set_title(title)
    plt.show()

In [None]:
# Files whose pixel-to-nm scaling is larger than this value are skipped when grains are loaded.
MAX_PX_TO_NM = 0.59
# Number of pixels of padding added around each grain's bounding box when it is cropped.
BBOX_PAD = 5

In [None]:
def plot_images(images: list, masks: list, px_to_nms: list, width=5, cmap=cmap, vmin=-8, vmax=8):
    """Plot every grain as three adjacent panels: image, mask, and mask overlaid on the image.

    Grains are laid out ``width`` per row, so the figure has ``width * 3`` columns.

    Parameters
    ----------
    images: list
        2D image arrays, one per grain.
    masks: list
        Mask arrays paired with ``images`` by position.
    px_to_nms: list
        Values shown in each grain's title, indexed by grain position.
    width: int
        Number of grains per row.
    cmap
        Colormap used for the image panels.
    vmin, vmax: float
        Colour limits used for the image panels.
    """
    num_images = len(images)
    if num_images == 0:
        # A grid with zero rows cannot be created, so report and return instead of raising.
        print("No images to plot")
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps `ax` two-dimensional even when there is only a single row of grains.
    fig, ax = plt.subplots(num_rows, width * 3, figsize=(30, 30), squeeze=False)
    for i, (image, mask) in enumerate(zip(images, masks)):
        row, position = divmod(i, width)
        col = position * 3
        ax[row, col].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, col].axis("off")
        ax[row, col + 1].imshow(mask, cmap="binary")
        ax[row, col].set_title(f"p_to_nm: {px_to_nms[i]}")
        ax[row, col + 2].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, col + 2].imshow(mask, cmap="viridis", alpha=0.2)
    fig.tight_layout()
    plt.show()


SAMPLE_TYPE = "OT2_REL"
# NOTE(review): machine-specific absolute path - point this at your own processed output directory.
on_rel = Path(f"/Users/sylvi/topo_data/hariborings/testing_all_unbound_data/output_{SAMPLE_TYPE}/processed/")
assert on_rel.exists()
# Grab all .topostats files. Sorted so that the grains selected below do not depend on the
# (arbitrary) order in which the file system lists the directory.
on_files = sorted(on_rel.glob("*.topostats"))

# file = on_files[1]

# Running count of grains collected so far; also the key of the next entry in `grain_dict`.
grains_processed = 0
# Stop collecting once this many grains have been gathered.
stop_at_grain = 200
plotting = False

# Maps a running grain number to its cropped image, cropped boolean mask and pixel-to-nm scaling.
grain_dict = {}

for file in on_files:
    print(file)
    # Load file
    with h5py.File(file, "r") as f:
        print(f.keys())
        image = f["image"][:]
        grain_masks = f["grain_masks"]["above"][:]
        p_to_nm = f["pixel_to_nm_scaling"][()]

    # Skip images whose scaling exceeds the configured maximum
    if p_to_nm > MAX_PX_TO_NM:
        continue

    # Plot image and mask side by side
    if plotting:
        fig, ax = plt.subplots(1, 2, figsize=(20, 10))
        ax[0].imshow(image, cmap=cmap, vmin=-8, vmax=8)
        ax[0].set_title("image")
        ax[1].imshow(grain_masks, cmap="gray")
        ax[1].set_title("grain_masks")
        plt.suptitle(f"pixel to nm scaling: {p_to_nm}")
        fig.tight_layout()
        plt.show()

    # Process the grains: every positive label value in the mask is treated as one grain
    for grain in range(1, grain_masks.max() + 1):
        if grains_processed == stop_at_grain:
            break
        # Get the bounding box of the grain
        grain_mask_fullsize = grain_masks == grain
        grain_bbox = np.argwhere(grain_mask_fullsize)
        if grain_bbox.size == 0:
            # This label value does not occur in the mask, so there is nothing to crop.
            continue
        minr, minc = grain_bbox.min(axis=0)
        maxr, maxc = grain_bbox.max(axis=0)
        # Add padding to the bounding box. `maxr`/`maxc` are inclusive pixel indices while a slice
        # end is exclusive, hence the extra +1 so both sides receive BBOX_PAD pixels of padding.
        minr = max(0, minr - BBOX_PAD)
        minc = max(0, minc - BBOX_PAD)
        maxr = min(grain_mask_fullsize.shape[0], maxr + 1 + BBOX_PAD)
        maxc = min(grain_mask_fullsize.shape[1], maxc + 1 + BBOX_PAD)

        # Get the crop of grain image
        grain_image = image[minr:maxr, minc:maxc]
        grain_mask = grain_mask_fullsize[minr:maxr, minc:maxc]

        if plotting:
            fig, ax = plt.subplots(1, 3, figsize=(20, 10))
            ax[0].imshow(grain_mask, cmap="gray")
            ax[0].set_title("grain mask")
            ax[1].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
            ax[1].set_title("grain image")
            ax[2].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
            ax[2].imshow(grain_mask, cmap="gray", alpha=0.2)
            plt.show()

        grain_dict[grains_processed] = {
            "image": grain_image,
            "mask": grain_mask,
            "p_to_nm": p_to_nm,
        }

        grains_processed += 1

    # The inner loop only breaks out of the current file; stop reading further files too.
    if grains_processed == stop_at_grain:
        break

# Plot the grains
images = [grain_dict[i]["image"] for i in range(grains_processed)]
masks = [grain_dict[i]["mask"] for i in range(grains_processed)]
px_to_nms = [grain_dict[i]["p_to_nm"] for i in range(grains_processed)]
plot_images(images, masks, px_to_nms)

In [None]:
# Clean up the masks
from skimage.morphology import binary_dilation, binary_erosion

DILATION_PASS = 2
ERODE_PASS = 2

# Same keys as `grain_dict`; each mask has been dilated DILATION_PASS times and then eroded
# ERODE_PASS times, while the image and scaling are carried over untouched.
dilated_grain_dict = {}

for index, grain_data in grain_dict.items():
    grain_mask = grain_data["mask"]

    # Apply every dilation pass first, then every erosion pass.
    for morphology_step, num_passes in ((binary_dilation, DILATION_PASS), (binary_erosion, ERODE_PASS)):
        for _ in range(num_passes):
            grain_mask = morphology_step(grain_mask)

    dilated_grain_dict[index] = {
        "image": grain_data["image"],
        "mask": grain_mask,
        "p_to_nm": grain_data["p_to_nm"],
    }

# Show the cleaned masks in the same order the grains were collected.
cleaned_grains = [dilated_grain_dict[i] for i in range(grains_processed)]
plot_images(
    [grain["image"] for grain in cleaned_grains],
    [grain["mask"] for grain in cleaned_grains],
    [grain["p_to_nm"] for grain in cleaned_grains],
)


from skimage.measure import label, regionprops

LOWER_AREA_BOUND = 100
UPPER_AREA_BOUND = 10000

# Grains that survive the checks below: exactly two connected background regions and a
# foreground pixel count within [LOWER_AREA_BOUND, UPPER_AREA_BOUND].
removed_anomaly_grain_dict = {}
for index, grain_data in dilated_grain_dict.items():
    grain_mask = grain_data["mask"]

    # Label the connected regions of the background (mask == 0) and count them
    labelled_background = label(grain_mask == 0)
    background_props = regionprops(labelled_background)
    num_background_regions = len(background_props)

    if num_background_regions < 2:
        print(f"Grain {index} has too few background regions")
        continue
    if num_background_regions >= 3:
        print(f"Grain {index} has too many background regions")
        continue

    # Check the size of the foreground
    foreground_area = grain_mask.sum()
    if foreground_area < LOWER_AREA_BOUND:
        print(f"Grain {index} has too small foreground area")
        continue
    if foreground_area > UPPER_AREA_BOUND:
        print(f"Grain {index} has too large foreground area")
        continue

    removed_anomaly_grain_dict[index] = grain_data

kept_grains = list(removed_anomaly_grain_dict.values())
plot_images(
    [grain["image"] for grain in kept_grains],
    [grain["mask"] for grain in kept_grains],
    [grain["p_to_nm"] for grain in kept_grains],
)

In [None]:
# Skeletonise using standard skeletonise


def plot_images(images: list, masks: list, px_to_nms: list, skeletons: list, width=5, cmap=cmap, vmin=-8, vmax=8):
    """Plot every grain as three adjacent panels: image, skeleton, and skeleton overlaid on the image.

    Grains are laid out ``width`` per row, so the figure has ``width * 3`` columns.

    Parameters
    ----------
    images: list
        2D image arrays, one per grain.
    masks: list
        Mask arrays paired with ``images`` by position; they are iterated but not drawn.
    px_to_nms: list
        Values shown in each grain's title, indexed by grain position.
    skeletons: list
        Skeleton arrays, one per grain.
    width: int
        Number of grains per row.
    cmap
        Colormap used for the image panels.
    vmin, vmax: float
        Colour limits used for the image panels.
    """
    num_images = len(images)
    if num_images == 0:
        # A grid with zero rows cannot be created, so report and return instead of raising.
        print("No images to plot")
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps `ax` two-dimensional even when there is only a single row of grains.
    fig, ax = plt.subplots(num_rows, width * 3, figsize=(30, 30), squeeze=False)
    for i, (image, mask, skeleton) in enumerate(zip(images, masks, skeletons)):
        row, position = divmod(i, width)
        col = position * 3
        ax[row, col].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, col].axis("off")
        ax[row, col + 1].imshow(skeleton, cmap="binary")
        ax[row, col].set_title(f"p_to_nm: {px_to_nms[i]}")
        ax[row, col + 2].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, col + 2].imshow(skeleton, cmap="viridis", alpha=0.2)
    fig.tight_layout()
    plt.show()


from scipy.ndimage import convolve


def convolve_skelly(skeleton) -> np.ndarray:
    """Convolves the skeleton with a 3x3 ones kernel to produce an array
    of the skeleton as 1, endpoints as 2, and nodes as 3.

    The kernel includes the centre pixel, so a skeleton pixel with two neighbours sums to 3
    (relabelled 1), one with a single neighbour sums to 2 (endpoint) and one with three or more
    neighbours sums to more than 3 (relabelled 3). An isolated skeleton pixel sums to 1 and is
    therefore also reported as 1.

    Parameters
    ----------
    skeleton: np.ndarray
        Single pixel thick binary trace(s) within an array.

    Returns
    -------
    np.ndarray
        The skeleton (=1) with endpoints (=2), and crossings (=3) highlighted.
    """
    # mode="constant" pads with zeros. The scipy default ("reflect") mirrors the array at its
    # edges, which double-counts skeleton pixels lying on the border and misclassifies them.
    conv = convolve(skeleton.astype(np.int32), np.ones((3, 3)), mode="constant", cval=0)
    conv[skeleton == 0] = 0  # remove non-skeleton points
    conv[conv == 3] = 1  # skelly = 1
    conv[conv > 3] = 3  # nodes = 3
    return conv


# Imported once here rather than on every loop iteration.
from skimage.morphology import skeletonize

plotting = False
# Grains whose skeleton contains only pixels with exactly two neighbours, i.e. skeletons that
# can be walked pixel by pixel without ambiguity.
paths_grain_dict = {}

for index, grain_data in removed_anomaly_grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    p_to_nm = grain_data["p_to_nm"]

    # Skeletonise
    skeleton = skeletonize(grain_mask)

    if plotting:
        fig, ax = plt.subplots(1, 4, figsize=(20, 10))
        ax[0].imshow(grain_mask, cmap="gray")
        ax[0].set_title("grain mask")
        ax[1].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax[1].set_title("grain image")
        ax[2].imshow(skeleton, cmap="gray")
        ax[2].set_title("skeleton")
        ax[3].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax[3].imshow(skeleton, cmap="viridis", alpha=0.2)
        plt.show()

    # Ignore any grains whose skeleton is not a simple loop. convolve_skelly labels ordinary
    # skeleton pixels 1, endpoints 2 and nodes 3, so any value above 1 means the skeleton has
    # either an endpoint (a pixel with one neighbour) or a branch (more than two neighbours).
    convolved_skelly = convolve_skelly(skeleton)

    if np.max(convolved_skelly) > 1:
        print(f"Grain {index} has a branch or an endpoint")
        continue

    paths_grain_dict[index] = {
        "image": grain_image,
        "mask": grain_mask,
        "skeleton": skeleton,
        "p_to_nm": p_to_nm,
    }

plot_images(
    [paths_grain_dict[i]["image"] for i in paths_grain_dict],
    [paths_grain_dict[i]["mask"] for i in paths_grain_dict],
    [paths_grain_dict[i]["p_to_nm"] for i in paths_grain_dict],
    [paths_grain_dict[i]["skeleton"] for i in paths_grain_dict],
)

In [None]:
# Trace the skeleton


def plot_images(
    images: list, masks: list, px_to_nms: list, skeletons: list, traces: list, width=5, cmap=cmap, vmin=-8, vmax=8
):
    """Plot every grain as three adjacent panels: image, skeleton, and skeleton plus trace on the image.

    Grains are laid out ``width`` per row, so the figure has ``width * 3`` columns.

    Parameters
    ----------
    images: list
        2D image arrays, one per grain.
    masks: list
        Mask arrays paired with ``images`` by position; they are iterated but not drawn.
    px_to_nms: list
        Values shown in each grain's title, indexed by grain position.
    skeletons: list
        Skeleton arrays, one per grain.
    traces: list
        Nx2 arrays of (row, column) coordinates, one per grain, drawn as a red line.
    width: int
        Number of grains per row.
    cmap
        Colormap used for the image panels.
    vmin, vmax: float
        Colour limits used for the image panels.
    """
    num_images = len(images)
    if num_images == 0:
        # A grid with zero rows cannot be created, so report and return instead of raising.
        print("No images to plot")
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps `ax` two-dimensional even when there is only a single row of grains.
    fig, ax = plt.subplots(num_rows, width * 3, figsize=(30, 30), squeeze=False)
    for i, (image, mask, skeleton) in enumerate(zip(images, masks, skeletons)):
        row, position = divmod(i, width)
        col = position * 3
        ax[row, col].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, col].axis("off")
        ax[row, col + 1].imshow(skeleton, cmap="binary")
        ax[row, col].set_title(f"p_to_nm: {px_to_nms[i]}")
        ax[row, col + 2].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, col + 2].imshow(skeleton, cmap="viridis", alpha=0.2)
        # Traces are stored as (row, column); matplotlib expects x (column) first.
        ax[row, col + 2].plot(traces[i][:, 1], traces[i][:, 0], "r")
    fig.tight_layout()
    plt.show()


plotting = False
# For each grain: the ordered (row, column) trace along its skeleton and the trace length in nm.
trace_grain_dict = {}

for index, grain_data in paths_grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    skeleton = grain_data["skeleton"]
    p_to_nm = grain_data["p_to_nm"]

    # Trace the skeleton
    skeleton_points = np.argwhere(skeleton)
    # Find the start point
    start_point = skeleton_points[0]
    # Each point should not have more than 2 neighbours so we can trace by finding the next point
    # and removing the current point from the skeleton

    skeleton_history = skeleton.copy()
    trace = [start_point]
    current_point = start_point
    skeleton_history[current_point[0], current_point[1]] = 0

    # print(f"len of skeleton points: {len(skeleton_points)}")
    for iteration in range(len(skeleton_points) - 1):
        # Clamp the window start at 0: a negative slice start would wrap around and produce an
        # empty window for points lying on the first row or column of the crop.
        window_origin = np.maximum(current_point - 1, 0)
        neighbourhood = skeleton_history[
            window_origin[0] : current_point[0] + 2, window_origin[1] : current_point[1] + 2
        ]
        num_neighbours = np.sum(neighbourhood)
        # On the first step both neighbours of the start point are still present; afterwards
        # only one unvisited neighbour may remain.
        if iteration > 0 and num_neighbours > 1:
            raise ValueError(f"More than 1 neighbour for iteration {iteration}")
        if num_neighbours == 0:
            raise ValueError(f"No neighbours for iteration {iteration}")
        # Convert the position inside the window back to coordinates in the full crop
        next_point_coords = window_origin + np.argwhere(neighbourhood)[0]
        trace.append(next_point_coords)
        current_point = next_point_coords
        skeleton_history[current_point[0], current_point[1]] = 0
    trace = np.array(trace)

    # Image of the skeleton whose pixel values increase along the trace (offset by 20 so the
    # first points remain distinguishable from the zero background).
    ordered_skeleton = np.zeros_like(skeleton).astype(int)
    for point_index, point in enumerate(trace):
        ordered_skeleton[point[0], point[1]] = point_index + 20

    if plotting:
        fig, ax = plt.subplots(1, 3, figsize=(20, 10))
        ax[0].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax[0].set_title("grain image")
        ax[1].imshow(ordered_skeleton, cmap="viridis")
        ax[1].set_title("ordered skeleton")
        ax[2].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax[2].imshow(ordered_skeleton, cmap="viridis", alpha=0.2)
        ax[2].plot(trace[:, 1], trace[:, 0], "r")
        plt.show()

    # Calculate pixel trace length in nm
    # NOTE(review): this sums the steps between consecutive trace points only; the step from the
    # last point back to the first is not included - confirm whether that is intended.
    pixel_trace_length = np.sum(np.linalg.norm(np.diff(trace, axis=0), axis=1)) * p_to_nm

    trace_grain_dict[index] = {
        "image": grain_image,
        "mask": grain_mask,
        "skeleton": ordered_skeleton,
        "trace": trace,
        "p_to_nm": p_to_nm,
        "pixel_trace_length": pixel_trace_length,
    }


plot_images(
    [trace_grain_dict[i]["image"] for i in trace_grain_dict],
    [trace_grain_dict[i]["mask"] for i in trace_grain_dict],
    [trace_grain_dict[i]["p_to_nm"] for i in trace_grain_dict],
    [trace_grain_dict[i]["skeleton"] for i in trace_grain_dict],
    [trace_grain_dict[i]["trace"] for i in trace_grain_dict],
)

# Plot kde of trace lengths
import seaborn as sns

trace_lengths = [trace_grain_dict[i]["pixel_trace_length"] for i in trace_grain_dict]
sns.kdeplot(trace_lengths)
plt.xlim(0, 60)
plt.xlabel("Trace length (nm)")
plt.title(f"Length of pixel traces in nm for {len(trace_lengths)} grains")
plt.show()

In [None]:
# Get height traces from the skeletons


def plot_images(
    images: list, masks: list, px_to_nms: list, traces: list, height_traces: list, width=5, cmap=cmap, vmin=-8, vmax=8
):
    """Plot every grain as three adjacent panels: image, image with trace, and the height profile.

    Grains are laid out ``width`` per row, so the figure has ``width * 3`` columns.

    Parameters
    ----------
    images: list
        2D image arrays, one per grain.
    masks: list
        Mask arrays paired with ``images`` by position; they are iterated but not drawn.
    px_to_nms: list
        Numbers shown (to two decimals) in each grain's title, indexed by grain position.
    traces: list
        Nx2 arrays of (row, column) coordinates, one per grain, drawn as a red line.
    height_traces: list
        1D sequences of heights, one per grain, plotted against their index.
    width: int
        Number of grains per row.
    cmap
        Colormap used for the image panels.
    vmin, vmax: float
        Colour limits used for the image panels.
    """
    num_images = len(images)
    if num_images == 0:
        # A grid with zero rows cannot be created, so report and return instead of raising.
        print("No images to plot")
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps `ax` two-dimensional even when there is only a single row of grains.
    fig, ax = plt.subplots(num_rows, width * 3, figsize=(30, 30), squeeze=False)
    for i, (image, mask, height_trace) in enumerate(zip(images, masks, height_traces)):
        row, position = divmod(i, width)
        col = position * 3
        ax[row, col].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, col].axis("off")
        ax[row, col].set_title(f"p_to_nm: {px_to_nms[i]:.2f}")
        ax[row, col + 1].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        # Traces are stored as (row, column); matplotlib expects x (column) first.
        ax[row, col + 1].plot(traces[i][:, 1], traces[i][:, 0], "r")
        ax[row, col + 2].plot(height_trace)
    fig.tight_layout()
    plt.show()


plotting = False
# For each traced grain: the image heights sampled at every point of its trace.
height_trace_grain_dict = {}


for index, grain_data in trace_grain_dict.items():
    grain_image, grain_mask = grain_data["image"], grain_data["mask"]
    p_to_nm, trace = grain_data["p_to_nm"], grain_data["trace"]

    # Get the height trace: trace columns are (row, column) indices into the grain image
    height_trace = grain_image[trace[:, 0], trace[:, 1]]

    if plotting:
        fig, ax = plt.subplots(1, 2, figsize=(40, 10))
        image_axis, profile_axis = ax[0], ax[1]
        num_segments = len(height_trace) - 1

        # Left panel: the trace over the image, each segment coloured by its position along the trace
        image_axis.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        colours = np.linspace(0, 1, len(trace))
        for i in range(num_segments):
            # Slices end at i + 2 so each segment holds two points (the end of a slice is exclusive)
            segment = trace[i : i + 2]
            image_axis.plot(segment[:, 1], segment[:, 0], color=plt.cm.viridis(colours[i]))
        image_axis.set_title("grain image")

        # Right panel: the height profile using the same colouring
        colours = np.linspace(0, 1, len(height_trace))
        xs = np.arange(len(height_trace))
        for i in range(num_segments):
            profile_axis.plot(xs[i : i + 2], height_trace[i : i + 2], color=plt.cm.viridis(colours[i]))
        profile_axis.set_ylim(1, 3.5)
        plt.show()

    height_trace_grain_dict[index] = dict(
        image=grain_image,
        mask=grain_mask,
        trace=trace,
        height_trace=height_trace,
        p_to_nm=p_to_nm,
    )

height_trace_grains = list(height_trace_grain_dict.values())
plot_images(
    [grain["image"] for grain in height_trace_grains],
    [grain["mask"] for grain in height_trace_grains],
    [grain["p_to_nm"] for grain in height_trace_grains],
    [grain["trace"] for grain in height_trace_grains],
    [grain["height_trace"] for grain in height_trace_grains],
)

In [None]:
# Curvature analysis
from scipy.interpolate import splprep, splev, UnivariateSpline
import matplotlib.gridspec as gridspec


def plot_colour_line_2d_based_on_value_on_axis(ax: plt.Axes, points: np.ndarray, values: np.ndarray, cmap="viridis_r"):
    """Draw a 2D path on ``ax`` segment by segment, colouring each segment by a value.

    Parameters
    ----------
    ax: plt.Axes
        Axes to draw on.
    points: np.ndarray
        Nx2 array; column 0 is plotted as x and column 1 as y.
    values: np.ndarray
        One value per point. Values are rescaled to [0, 1] and segment ``i`` takes the colour
        of ``values[i]``.
    cmap
        Colormap name (or colormap object) used to turn the rescaled values into colours.
    """
    if len(points) < 2:
        # Fewer than two points means there is no segment to draw.
        return
    values = np.asarray(values, dtype=float)
    lowest = values.min()
    value_range = values.max() - lowest
    if value_range == 0:
        # All values are equal: avoid dividing by zero and use a single colour.
        normalised_values = np.zeros_like(values)
    else:
        normalised_values = (values - lowest) / value_range
    # plt.get_cmap works across Matplotlib versions; plt.cm.get_cmap was removed in 3.9.
    colour_map = plt.get_cmap(cmap)
    for i in range(len(points) - 1):
        ax.plot(points[i : i + 2, 0], points[i : i + 2, 1], color=colour_map(normalised_values[i]))


def plot_colour_line_1d_based_on_value_on_axis(
    ax: plt.Axes, values: np.ndarray, hline: Union[None, float] = None, cmap="viridis_r"
):
    """Plot a 1D series against its index on ``ax``, colouring each segment by its value.

    Parameters
    ----------
    ax: plt.Axes
        Axes to draw on.
    values: np.ndarray
        Series to plot. Values are rescaled to [0, 1] and segment ``i`` takes the colour of
        ``values[i]``.
    hline: Union[None, float]
        When given, a solid black horizontal line is drawn at this y value.
    cmap
        Colormap name (or colormap object) used to turn the rescaled values into colours.
    """
    values = np.asarray(values, dtype=float)
    if len(values) >= 2:
        lowest = values.min()
        value_range = values.max() - lowest
        if value_range == 0:
            # All values are equal: avoid dividing by zero and use a single colour.
            normalised_values = np.zeros_like(values)
        else:
            normalised_values = (values - lowest) / value_range
        # plt.get_cmap works across Matplotlib versions; plt.cm.get_cmap was removed in 3.9.
        colour_map = plt.get_cmap(cmap)
        xs = np.arange(len(values))
        for i in range(len(values) - 1):
            ax.plot(xs[i : i + 2], values[i : i + 2], color=colour_map(normalised_values[i]))
    if hline is not None:
        ax.axhline(hline, color="k", linestyle="-")


def interpolate_points_spline(points: np.ndarray, num_points: Union[int, None] = None, smoothing: float = 0.0):
    """Interpolate a set of points using a periodic (closed) parametric spline.

    Parameters
    ----------
    points: np.ndarray
        Nx2 Numpy array of coordinates for the points. The array is not modified.
    num_points: int
        The number of points to return following the calculated spline. Defaults to the number
        of input points.
    smoothing: float
        Smoothing condition passed to ``splprep`` as ``s``; 0.0 asks for an interpolating spline.

    Returns
    -------
    interpolated_points: np.ndarray
        An Ix2 Numpy array of coordinates of the interpolated points, where I is the number of points
        specified.
    """

    if num_points is None:
        num_points = points.shape[0]

    # With per=1, splprep overwrites the last coordinate with the first one *in place*. Passing
    # `points.T` (a view) would therefore silently modify the caller's array, so fit a private copy.
    # NOTE(review): the last input point is still replaced by the first on that copy, so it does
    # not take part in the fit - confirm this is acceptable for the traces used here.
    coordinates = np.array(points, dtype=float).T
    tck, u = splprep(coordinates, u=None, s=smoothing, per=1)
    # Sample the spline evenly over the full range of its parameter
    u_new = np.linspace(u.min(), u.max(), num_points)
    x_new, y_new = splev(u_new, tck, der=0)
    interpolated_points = np.array((x_new, y_new)).T
    return interpolated_points


def calculate_curvature_from_points(x_points, y_points, error=0.1, k=4):
    """Fit a spline to x and y separately and return the signed curvature at every point.

    Parameters
    ----------
    x_points: np.ndarray
        1D array of x coordinates.
    y_points: np.ndarray
        1D array of y coordinates, same length as ``x_points``.
    error: float
        Every point is weighted by ``1 / sqrt(error)`` in the spline fit.
    k: int
        Degree of the splines. The default of 4 is used so that the second derivative needed
        for the curvature is smooth.

    Returns
    -------
    tuple
        ``(curvatures, spline_x, spline_y)``: the curvature and the fitted x and y values, each
        evaluated at the index of every input point.

    Raises
    ------
    ValueError
        If ``x_points`` and ``y_points`` have different lengths.
    """
    # Both coordinate arrays must describe the same number of points
    if not x_points.shape[0] == y_points.shape[0]:
        raise ValueError(
            f"x_points and y_points must have the same number of points. x_points has {x_points.shape[0]} points and y_points has {y_points.shape[0]} points."
        )

    # The point index serves as the monotonically increasing independent variable for both fits.
    t = np.arange(x_points.shape[0])
    weight_values = 1 / np.sqrt(error * np.ones_like(x_points))

    splines = [UnivariateSpline(t, coordinate, k=k, w=weight_values) for coordinate in (x_points, y_points)]
    fx, fy = splines

    spline_x, spline_y = fx(t), fy(t)

    # First and second derivatives of each coordinate with respect to the index
    dx, dx2 = (fx.derivative(order)(t) for order in (1, 2))
    dy, dy2 = (fy.derivative(order)(t) for order in (1, 2))

    # Signed curvature of a parametric curve: (x'y'' - y'x'') / (x'^2 + y'^2)^(3/2)
    numerator = dx * dy2 - dy * dx2
    denominator = np.power(dx**2 + dy**2, 3 / 2)
    curvatures = numerator / denominator
    return curvatures, spline_x, spline_y


def calculate_curvature_periodic_boundary(x_points, y_points, error=0.1, periods=2, k=4):
    """Take a set of points that form a loop and calculate the curvature. Uses periodic boundary conditions, so
    the first and last points are connected. This reduces the error in the curvature calculation at the
    boundaries.

    The points are repeated ``periods`` times on each side (``2 * periods + 1`` copies in total), the
    curvature is calculated for the whole sequence, and the values of the centre copy are returned.

    Parameters
    ----------
    x_points: np.ndarray
        1D numpy array of x coordinates of the points.
    y_points: np.ndarray
        1D numpy array of y coordinates of the points.
    error: float
        Error in the points. Used to weight the points in the spline calculation.
    periods: int
        Number of times to repeat the points either side of the original points to reduce the error at the
        boundaries.
    k: int
        Degree of the spline, forwarded to ``calculate_curvature_from_points``.

    Returns
    -------
    tuple
        ``(curvature, spline_x, spline_y)``: 1D numpy arrays with the curvature and the fitted x and y
        values for each of the original points.

    Raises
    ------
    ValueError
        If ``x_points`` and ``y_points`` have different lengths, or ``periods`` is negative.
    """

    # Check that the number of points is the same for both x and y
    if x_points.shape[0] != y_points.shape[0]:
        raise ValueError(
            f"x_points and y_points must have the same number of points. x_points has"
            f" {x_points.shape[0]} points and y_points has {y_points.shape[0]} points."
        )
    if periods < 0:
        raise ValueError(f"periods must be non-negative, got {periods}.")

    num_points = x_points.shape[0]

    # Repeat the points either side of the original points to reduce the error at the boundaries
    num_copies = 2 * periods + 1
    extended_points_x = np.tile(x_points, num_copies)
    extended_points_y = np.tile(y_points, num_copies)

    # Calculate the curvature
    extended_curvature, spline_x, spline_y = calculate_curvature_from_points(
        extended_points_x, extended_points_y, error=error, k=k
    )

    # Return only the centre copy: it has `periods` full copies on each side, so it is the one
    # furthest from the ends of the extended sequence.
    start = num_points * periods
    stop = start + num_points
    return (
        extended_curvature[start:stop],
        spline_x[start:stop],
        spline_y[start:stop],
    )


def turn_path_into_pixel_map(array: np.ndarray):
    """Convert a path of (possibly fractional) coordinates into a thinned integer pixel path.

    Coordinates are truncated to integers and consecutive repeats are dropped. An interior
    coordinate is kept only when the coordinate that follows it does not touch (8-neighbourhood)
    the most recently kept coordinate; otherwise it is redundant and skipped.

    Parameters
    ----------
    array: np.ndarray
        Nx2 array of path coordinates.

    Returns
    -------
    pixel_map: np.ndarray
        Square integer array with side ``int(max(array) + 1)``; kept pixels, plus the first and
        last coordinate of the path, are set to 1.
    pixelated_path: np.ndarray
        Kx2 integer array of the kept interior coordinates in path order.
        NOTE(review): the first and last coordinates are marked in ``pixel_map`` but are not
        part of ``pixelated_path`` - confirm this is intended.
    """
    side_length = int(np.max(array) + 1)
    pixel_map = np.zeros((side_length, side_length), dtype=int)

    def check_is_touching(coordinate, original_coordinate):
        # Two pixels touch when they differ by at most one step along both axes
        row_gap = np.abs(coordinate[0] - original_coordinate[0])
        col_gap = np.abs(coordinate[1] - original_coordinate[1])
        return bool(row_gap <= 1 and col_gap <= 1)

    # Truncate to integer pixels and drop coordinates equal to their predecessor
    truncated = np.array(array, dtype=int)
    integer_array = np.array(
        [
            coordinate
            for position, coordinate in enumerate(truncated)
            if position == 0 or not np.array_equal(coordinate, truncated[position - 1])
        ]
    )

    kept_coordinates = []
    last_coordinate = None
    final_index = len(integer_array) - 1
    for index, coordinate in enumerate(integer_array):
        if index == 0 or index == final_index:
            # The two ends of the path are always marked on the map
            pixel_map[coordinate[0], coordinate[1]] = 1
            last_coordinate = coordinate
            continue

        # Skip this pixel when the next one already touches the last kept pixel
        if check_is_touching(integer_array[index + 1], last_coordinate):
            continue

        # Add the coordinate to the pixel map and the pixelated path
        pixel_map[int(coordinate[0]), int(coordinate[1])] = 1
        kept_coordinates.append(coordinate)
        last_coordinate = coordinate

    pixelated_path = np.array(kept_coordinates, dtype=int).reshape(-1, 2)
    return pixel_map, pixelated_path


def defect_stats(height_trace: np.ndarray, threshold: float):
    """Find the contiguous runs of values below ``threshold`` and summarise them.

    Parameters
    ----------
    height_trace: np.ndarray
        1D sequence of values to scan.
    threshold: float
        Values strictly below this threshold belong to a defect region.

    Returns
    -------
    dict
        ``defect_number``: the number of regions found.
        ``defect_largest_region``: the region with the largest area (the first one on ties), or
        ``None`` when there are no regions.
        ``defect_regions``: list of region dictionaries with keys ``start`` (first index below the
        threshold), ``end`` (first index after the region, exclusive), ``area`` (sum of
        ``threshold - value`` over the region), ``deepest_point_index`` and ``defect_threshold``.
    """
    regions = []
    in_region = False

    def close_region(end_index):
        # Record the region that is currently open, ending (exclusively) at `end_index`
        regions.append(
            {
                "start": region_start,
                "end": end_index,
                "area": area,
                "deepest_point_index": deepest_point_index,
                "defect_threshold": threshold,
            }
        )

    for index, value in enumerate(height_trace):
        if value < threshold:
            if not in_region:
                # Start region
                region_start = index
                area = 0
                deepest_point = value
                deepest_point_index = index
                in_region = True
            elif value < deepest_point:
                deepest_point = value
                deepest_point_index = index
            # Every sample of the region contributes its depth, including the first one
            area += threshold - value
        elif in_region:
            close_region(index)
            in_region = False

    # A region still open when the trace ends would otherwise be lost
    if in_region:
        close_region(len(height_trace))

    # Number of defects
    number_of_defects = len(regions)

    # Find the largest region (max returns the first of equally large regions)
    largest_region_below_threshold = max(regions, key=lambda region: region["area"], default=None)

    return {
        "defect_number": number_of_defects,
        "defect_largest_region": largest_region_below_threshold,
        "defect_regions": regions,
    }


plotting = False
# For each grain: the smoothed spline path, its curvature and the number of curvature defects.
curvature_grain_dict = {}

for index, grain_data in height_trace_grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    p_to_nm = grain_data["p_to_nm"]
    trace = grain_data["trace"]
    height_trace = grain_data["height_trace"]

    interpolated_points = interpolate_points_spline(trace, num_points=200, smoothing=0.0)

    # The trace is stored as (row, column), so column 1 is passed as x and column 0 as y
    curvature, spline_x, spline_y = calculate_curvature_periodic_boundary(
        interpolated_points[:, 1], interpolated_points[:, 0], error=0.1, periods=2, k=5
    )

    # Flip curvatures so that positive is convex up
    curvature = -curvature

    spline_points = np.array((spline_x, spline_y)).T

    # Get pixel spline trace
    spline_pixel_trace_img, spline_pixelated_path = turn_path_into_pixel_map(spline_points)

    # Get the height trace from the spline pixelated path
    # (path columns are (x, y), so the image is indexed [y, x])
    spline_pixelated_path_heights = grain_image[spline_pixelated_path[:, 1], spline_pixelated_path[:, 0]]

    # Regions where the curvature drops below the threshold are counted as defects
    defect_stats_dict = defect_stats(curvature, threshold=-0.15)

    if plotting:
        # Plot scatter with colour as curvature
        # fig, ax = plt.subplots(1, 3, figsize=(20, 10))

        fig = plt.figure(figsize=(12, 12))
        gs = gridspec.GridSpec(4, 2, figure=fig)

        ax0 = fig.add_subplot(gs[0, 0])
        ax1 = fig.add_subplot(gs[0, 1])
        ax2 = fig.add_subplot(gs[1, 0])
        ax3 = fig.add_subplot(gs[1, 1])
        ax4 = fig.add_subplot(gs[2, :])
        ax5 = fig.add_subplot(gs[3, :])

        ax0.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax0.plot(trace[:, 1], trace[:, 0], "r")
        ax0.set_title("Pixelated Path")
        ax1.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        # Plot spline pixelated path including the first point to close the loop
        ax1.plot(
            np.append(spline_pixelated_path[:, 0], spline_pixelated_path[0, 0]),
            np.append(spline_pixelated_path[:, 1], spline_pixelated_path[0, 1]),
            "r",
        )
        ax1.set_title("Spline Pixel Path")
        ax2.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        plot_colour_line_2d_based_on_value_on_axis(ax2, spline_points, curvature)
        ax2.set_title("Spline Path with Curvature")
        # Plot spline pixelated trace heights overlaid on the grain image
        ax3.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        plot_colour_line_2d_based_on_value_on_axis(ax3, spline_pixelated_path, spline_pixelated_path_heights)
        ax3.set_title("Spline Pixellated Path with Heights")
        plot_colour_line_1d_based_on_value_on_axis(ax4, curvature, hline=0)
        ax4.set_title("Curvature")
        ax4.set_ylim(-0.5, 0.5)
        # Add red regions over defects
        for region in defect_stats_dict["defect_regions"]:
            ax4.axvspan(region["start"], region["end"], color="r", alpha=0.5)
        # Add a black line at the deepest point of the largest defect. There is no largest
        # region when no defect was selected, so only draw the line when one exists.
        largest_defect = defect_stats_dict["defect_largest_region"]
        if largest_defect is not None:
            ax4.axvline(largest_defect["deepest_point_index"], color="k", linestyle="--", alpha=0.5)
        plot_colour_line_1d_based_on_value_on_axis(ax5, spline_pixelated_path_heights)
        ax5.set_ylim(0.0, 4.0)
        ax5.set_title("Spline Pixellated Path Heights")
        fig.tight_layout()
        plt.show()

    curvature_grain_dict[index] = {
        "image": grain_image,
        "mask": grain_mask,
        "trace": trace,
        "height_trace": height_trace,
        "spline_points": spline_points,
        "curvature": curvature,
        "num_curvature_defects": defect_stats_dict["defect_number"],
        "p_to_nm": p_to_nm,
    }

# Plot kde of number of defects
num_defects = [curvature_grain_dict[i]["num_curvature_defects"] for i in curvature_grain_dict]
sns.kdeplot(num_defects)
plt.xlabel("Number of defects")
plt.title(f"Number of defects for {SAMPLE_TYPE} {len(num_defects)} grains")
plt.show()