In [None]:
from pathlib import Path

from topostats.io import LoadScans, read_yaml
from topostats.plottingfuncs import Colormap
from topostats.filters import Filters
from topostats.grains import Grains

import numpy as np
import matplotlib.pyplot as plt


# Build the TopoStats colormap once; `cmap` is passed to the imshow calls in the cells below.
colormap = Colormap()
cmap = colormap.get_cmap()

In [None]:
# NOTE(review): absolute, machine-specific path - the notebook only runs where this file exists.
FILE_PATH = Path("/Users/sylvi/Downloads/20230823_TAF_IR_minicircle_1.5ng_CaptureFile.0_00016.spm")
assert FILE_PATH.exists()

# Load the "Height Sensor" channel of the scan file.
loadscans = LoadScans([FILE_PATH], channel="Height Sensor")
loadscans.get_data()

# The loaded data is looked up by the file stem (file name without its extension).
data = loadscans.img_dict[str(FILE_PATH.stem)]

print(data.keys())

image = data["image_original"]

# Show the image as loaded, before any filtering.
fig, ax = plt.subplots()
im = ax.imshow(image, cmap=cmap)
fig.colorbar(im)
plt.show()

# Take the filter settings from the package's default configuration file.
config = read_yaml("../topostats/default_config.yaml")
filter_config = config["filter"]
# The remaining entries are unpacked as keyword arguments into Filters below;
# presumably "run" is removed because Filters does not accept it - TODO confirm.
filter_config.pop("run")
print(filter_config)

filters = Filters(
    image=image, filename=FILE_PATH.stem, pixel_to_nm_scaling=data["pixel_to_nm_scaling"], **filter_config
)
filters.filter_image()

# The "gaussian_filtered" entry of the filter outputs is what the rest of the notebook
# calls the flattened image.
flattened_image = filters.images["gaussian_filtered"]

fig, ax = plt.subplots()
im = ax.imshow(flattened_image, cmap=cmap)
fig.colorbar(im)
plt.show()

In [None]:
# Get the grain

# Pixels of the flattened image above this value form the binary mask.
threshold = 1.0
binary_mask = flattened_image > threshold

plt.imshow(binary_mask)

# Get the largest two regions
from skimage.measure import label, regionprops

labelled_mask = label(binary_mask)
mask_regions = regionprops(labelled_mask)

# The rest of the cell indexes largest_regions[0] and [1], so fail early with a clear message.
if len(mask_regions) < 2:
    raise ValueError(f"Expected at least 2 regions above the threshold, found {len(mask_regions)}")

region_sizes = [region.area for region in mask_regions]
region_sizes.sort(reverse=True)
# Rank the regions themselves rather than testing membership of their area in the top-two
# sizes: with tied areas the membership test can return more than two regions and the
# first two of those are not necessarily the largest. The two selected regions are then
# put back in label order, which is the order the membership test used to produce.
largest_regions = sorted(
    sorted(mask_regions, key=lambda region: region.area, reverse=True)[:2],
    key=lambda region: region.label,
)

# Outline the two selected regions on the mask.
fig, ax = plt.subplots()
ax.imshow(binary_mask, cmap=cmap)
for region in largest_regions:
    minr, minc, maxr, maxc = region.bbox
    rect = plt.Rectangle((minc, minr), maxc - minc, maxr - minr, fill=False, edgecolor="red", linewidth=2)
    ax.add_patch(rect)
plt.show()

# One boolean mask per selected region.
region_1_mask = labelled_mask == largest_regions[0].label
region_2_mask = labelled_mask == largest_regions[1].label

fig, ax = plt.subplots(1, 2)
ax[0].imshow(region_1_mask, cmap=cmap)
ax[1].imshow(region_2_mask, cmap=cmap)
plt.show()

# get a crop of the mask and image for the grains
padding = 10
minr_0, minc_0, maxr_0, maxc_0 = largest_regions[0].bbox
minr_1, minc_1, maxr_1, maxc_1 = largest_regions[1].bbox
# Clamp the lower bounds at 0: a negative slice start counts from the far edge of the
# array, which would give an empty or wrong crop for a grain closer than `padding`
# pixels to the top or left border. Upper bounds past the array size are already
# truncated by slicing.
crop_rows_0 = slice(max(minr_0 - padding, 0), maxr_0 + padding)
crop_cols_0 = slice(max(minc_0 - padding, 0), maxc_0 + padding)
crop_rows_1 = slice(max(minr_1 - padding, 0), maxr_1 + padding)
crop_cols_1 = slice(max(minc_1 - padding, 0), maxc_1 + padding)
# NOTE(review): the crops are taken from `image` (data["image_original"]) although the
# mask was derived from `flattened_image` - confirm that the unfiltered image is intended.
image_crop_0 = image[crop_rows_0, crop_cols_0]
mask_crop_0 = region_1_mask[crop_rows_0, crop_cols_0]
image_crop_1 = image[crop_rows_1, crop_cols_1]
mask_crop_1 = region_2_mask[crop_rows_1, crop_cols_1]

# Top row: first grain (image, mask); bottom row: second grain.
fig, ax = plt.subplots(2, 2)
ax[0, 0].imshow(image_crop_0, cmap=cmap)
ax[0, 1].imshow(mask_crop_0, cmap=cmap)
ax[1, 0].imshow(image_crop_1, cmap=cmap)
ax[1, 1].imshow(mask_crop_1, cmap=cmap)
plt.show()

In [None]:
# process the grains
from skimage.morphology import skeletonize
from scipy.ndimage import convolve


def convolve_skelly(skeleton) -> np.ndarray:
    """Label skeleton pixels by how many skeleton pixels lie in their 3x3 neighbourhood.

    The skeleton is convolved with a 3x3 kernel of ones, so every skeleton pixel
    receives the count of skeleton pixels in its neighbourhood (itself included).
    Counts are then mapped to labels: a count of 3 (two neighbours) becomes 1,
    a count of 2 (one neighbour, i.e. an endpoint) stays 2, and any count above 3
    becomes 3 (a node / crossing). An isolated pixel (count 1) stays 1.

    Parameters
    ----------
    skeleton: np.ndarray
        Single pixel thick binary trace(s) within an array.

    Returns
    -------
    np.ndarray
        Integer array of the same shape: 0 for background, 1 for ordinary skeleton
        pixels, 2 for endpoints and 3 for nodes.
    """
    # Zero padding instead of the default "reflect" boundary: reflection mirrors border
    # pixels back into the window, so skeleton pixels lying on the array edge would be
    # counted twice and be mislabelled as nodes.
    conv = convolve(skeleton.astype(np.int32), np.ones((3, 3)), mode="constant", cval=0)
    conv[skeleton == 0] = 0  # remove non-skeleton points
    conv[conv == 3] = 1  # skelly = 1
    conv[conv > 3] = 3  # nodes = 3
    return conv


# Reduce each grain mask to a skeleton.
skelly_0 = skeletonize(mask_crop_0.astype(bool), method="lee")
skelly_1 = skeletonize(mask_crop_1.astype(bool), method="lee")

# The labelled output of convolve_skelly is immediately cast back to bool, so only a
# boolean skeleton mask is kept here (the endpoint / node labels are discarded).
skelly_0 = convolve_skelly(skelly_0).astype(bool)
skelly_1 = convolve_skelly(skelly_1).astype(bool)

print(np.max(skelly_0), np.max(skelly_1))

# Bundle the per-grain inputs so both grains go through the same loop below.
grain_0_dict = {
    "grain_image": image_crop_0,
    "skeleton": skelly_0,
    "p_to_nm": data["pixel_to_nm_scaling"],
    "mask": mask_crop_0,
}

grain_1_dict = {
    "grain_image": image_crop_1,
    "skeleton": skelly_1,
    "p_to_nm": data["pixel_to_nm_scaling"],
    "mask": mask_crop_1,
}

grain_data_dict = {
    0: grain_0_dict,
    1: grain_1_dict,
}

# Output of this cell: per-grain ordered trace and its length.
skelly_grain_dict = {}
plotting = False

for index, grain_data in grain_data_dict.items():
    grain_image = grain_data["grain_image"]
    grain_mask = grain_data["mask"]
    skeleton = grain_data["skeleton"]
    pixel_to_nm_scaling = grain_data["p_to_nm"]

    # Trace the skeleton
    skeleton_points = np.argwhere(skeleton.astype(bool))
    # Find the start point (the first skeleton pixel returned by argwhere)
    start_point = skeleton_points[0]
    # Each point should not have more than 2 neighbours so we can trace by finding the next point
    # and removing the current point from the skeleton

    # Working copy: visited pixels are cleared here so the walk cannot step backwards.
    skeleton_history = skeleton.copy()
    trace = [start_point]
    current_point = start_point
    skeleton_history[current_point[0], current_point[1]] = 0

    # print(f"len of skeleton points: {len(skeleton_points)}")
    for iteration in range(len(skeleton_points) - 1):
        # 3x3 window of unvisited skeleton pixels around the current point.
        # NOTE(review): for a point in row 0 or column 0 the slice start is -1, which
        # makes the window empty - the walk assumes the skeleton keeps clear of the border.
        neighbourhood = skeleton_history[
            current_point[0] - 1 : current_point[0] + 2, current_point[1] - 1 : current_point[1] + 2
        ]
        # The start point may see two unvisited neighbours (both directions around the
        # skeleton); any later point with more than one means the skeleton branches.
        if iteration > 0 and np.sum(neighbourhood) > 1:
            raise ValueError(f"More than 1 neighbour for iteration {iteration}")
        if np.sum(neighbourhood) == 0:
            raise ValueError(f"No neighbours for iteration {iteration}")
        # Step to the first unvisited neighbour; window coordinates are offset by -1
        # to convert them back to array coordinates.
        next_point = np.argwhere(neighbourhood)[0]
        next_point_coords = current_point + next_point - 1
        trace.append(next_point_coords)
        current_point = next_point_coords
        skeleton_history[current_point[0], current_point[1]] = 0
    trace = np.array(trace)

    # Image of the skeleton in which each pixel holds its position along the trace plus 20
    # (presumably so the first points stand out from the zero background when plotted).
    ordered_skeleton = np.zeros_like(skeleton).astype(int)
    for point_index, point in enumerate(trace):
        ordered_skeleton[point[0], point[1]] = point_index + 20

    if plotting:
        fig, ax = plt.subplots(1, 3, figsize=(20, 10))
        ax[0].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax[0].set_title("grain image")
        ax[1].imshow(ordered_skeleton, cmap="viridis")
        ax[1].set_title("ordered skeleton")
        ax[2].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax[2].imshow(ordered_skeleton, cmap="viridis", alpha=0.2)
        ax[2].plot(trace[:, 1], trace[:, 0], "r")
        plt.show()

    # Calculate pixel trace length in nm: sum of the step lengths between consecutive trace
    # points times the pixel size. The step from the last point back to the first is not included.
    pixel_trace_length = np.sum(np.linalg.norm(np.diff(trace, axis=0), axis=1)) * data["pixel_to_nm_scaling"]

    skelly_grain_dict[index] = {
        "grain_image": grain_image,
        "grain_mask": grain_mask,
        "skeleton": ordered_skeleton,
        "trace": trace,
        "pixel_trace_length": pixel_trace_length,
        "p_to_nm": pixel_to_nm_scaling,
    }

# Summary plot per grain: image with the trace overlaid (left) and the ordered skeleton (right).
for index, grain in skelly_grain_dict.items():
    fig, ax = plt.subplots(1, 2, figsize=(20, 10))
    ax[0].imshow(grain["grain_image"], cmap=cmap)
    ax[0].set_title("grain image")
    ax[1].imshow(grain["skeleton"], cmap="viridis")
    ax[0].plot(grain["trace"][:, 1], grain["trace"][:, 0], "lime")
    ax[1].set_title(f"grain {index} skeleton")
    plt.show()

In [None]:
# Set to True to draw the per-grain debug figures inside the loop.
plotting = False
# Output of this cell: per-grain trace plus the image values sampled along it.
height_trace_grain_dict = {}

for index, grain_data in skelly_grain_dict.items():
    grain_image = grain_data["grain_image"]
    grain_mask = grain_data["grain_mask"]
    p_to_nm = grain_data["p_to_nm"]
    trace = grain_data["trace"]

    # Get the height trace: the image value at every (row, column) point of the trace
    height_trace = grain_image[trace[:, 0], trace[:, 1]]

    if plotting:
        fig, ax = plt.subplots(1, 2, figsize=(40, 10))
        ax[0].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        # Draw the trace segment by segment so the colour encodes the position along it.
        colours = np.linspace(0, 1, len(trace))
        for i in range(len(height_trace) - 1):
            ax[0].plot(trace[i : i + 2, 1], trace[i : i + 2, 0], color=plt.cm.viridis(colours[i]))
        ax[0].set_title("grain image")
        # Same colour coding for the height profile so both panels can be matched up.
        colours = np.linspace(0, 1, len(height_trace))
        xs = np.arange(len(height_trace))
        for i in range(len(height_trace) - 1):
            # Include +2 to get the next point since python slicing is exclusive
            ax[1].plot(xs[i : i + 2], height_trace[i : i + 2], color=plt.cm.viridis(colours[i]))
        ax[1].set_ylim(1, 3.5)
        plt.show()

    height_trace_grain_dict[index] = {
        "image": grain_image,
        "mask": grain_mask,
        "trace": trace,
        "height_trace": height_trace,
        "p_to_nm": p_to_nm,
    }

# Plot the height traces: image with the trace overlaid (left) and the sampled values (right).
for index, grain in height_trace_grain_dict.items():
    fig, ax = plt.subplots(1, 2, figsize=(20, 10))
    ax[0].imshow(grain["image"], cmap=cmap)
    ax[0].plot(grain["trace"][:, 1], grain["trace"][:, 0], "lime")
    ax[0].set_title(f"grain {index} image")
    ax[1].plot(grain["height_trace"])
    ax[1].set_title(f"grain {index} height trace")
    plt.show()

In [None]:
def plot_images(
    images: list, original_traces: list, pooled_traces: list, px_to_nms: list, width=5, cmap=cmap, vmin=-8, vmax=8
):
    """Show a grid of images, each overlaid with its original and pooled trace.

    Parameters
    ----------
    images: list
        Images to show, one per grid panel.
    original_traces: list
        One (N, 2) array of (row, column) points per image, drawn in red.
    pooled_traces: list
        One (N, 2) array of (row, column) points per image, drawn in blue.
    px_to_nms: list
        Accepted for interface compatibility; not used by this function.
    width: int
        Number of panels per row.
    cmap
        Colormap passed to ``imshow``.
    vmin, vmax: float
        Colour scale limits passed to ``imshow``.
    """
    num_images = len(images)
    num_rows = np.ceil(num_images / width).astype(int) + 1
    # squeeze=False keeps the axes array two-dimensional for every grid shape; without it
    # a single-column grid (width == 1) yields a 1-D array and the [row, column] indexing fails.
    fig, axes = plt.subplots(num_rows, width, figsize=(20, 50), squeeze=False)
    for i, (image, original_trace, pooled_trace) in enumerate(zip(images, original_traces, pooled_traces)):
        panel = axes[i // width, i % width]
        panel.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        # Traces are stored as (row, column), so column is x and row is y.
        panel.plot(original_trace[:, 1], original_trace[:, 0], "r")
        panel.plot(pooled_trace[:, 1], pooled_trace[:, 0], "b")
        panel.axis("off")
    fig.tight_layout()
    plt.show()


# Set to True to draw the per-grain debug figure inside the loop.
plotting = False

# Output of this cell: per-grain trace plus its pooled (locally averaged) version.
pooled_curvature_grain_dict = {}

# Binning size
# Bin based on the pixel to nm scaling so it bins every n nm
n_nm = 4
# Number of consecutive trace points averaged per bin. At least one point is always used:
# with a pixel size above n_nm the integer division would give 0, i.e. empty bins and NaN means.
n = max(1, int(n_nm / data["pixel_to_nm_scaling"]))

for index, grain_data in height_trace_grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    p_to_nm = grain_data["p_to_nm"]
    trace = grain_data["trace"]
    height_trace = grain_data["height_trace"]

    # Pool the trace points: point i of the pooled trace is the mean of trace points
    # i .. i + n - 1. Indices are taken modulo the trace length so windows that run off
    # the end continue from the start, also when n exceeds the trace length.
    window_offsets = np.arange(n)
    pooled_trace = np.array(
        [np.mean(trace[(i + window_offsets) % len(trace)], axis=0) for i in range(len(trace))]
    )

    if plotting:
        # Plot overlaid on image
        fig, ax = plt.subplots(1, 1, figsize=(20, 10))
        ax.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax.plot(trace[:, 1], trace[:, 0], "r")
        ax.plot(pooled_trace[:, 1], pooled_trace[:, 0], "b")
        ax.set_title("Trace and pooled trace")
        plt.show()

    pooled_curvature_grain_dict[index] = {
        "image": grain_image,
        "mask": grain_mask,
        "trace": trace,
        "pooled_trace": pooled_trace,
        "height_trace": height_trace,
        "p_to_nm": p_to_nm,
    }

# Side-by-side comparison of the skeleton trace and the pooled trace for each grain.
for index, grain in pooled_curvature_grain_dict.items():
    fig, ax = plt.subplots(1, 2, figsize=(20, 10))
    ax[0].imshow(grain["image"], cmap=cmap)
    ax[0].plot(grain["trace"][:, 1], grain["trace"][:, 0], "lime")
    ax[0].set_title("standard skeleton trace")
    ax[1].imshow(grain["image"], cmap=cmap)
    ax[1].plot(grain["pooled_trace"][:, 1], grain["pooled_trace"][:, 0], "lime")
    ax[1].set_title(f"pooled trace ({n_nm} nm bins)")
    plt.show()

In [None]:
def angle_diff_signed(v1: np.ndarray, v2: np.ndarray):
    """Calculate the signed angle difference between two vectors.

    Parameters
    ----------
    v1: np.ndarray
        The first vector.
    v2: np.ndarray
        The second vector.

    Returns
    -------
    float
        The signed angle difference between the two vectors.
    """

    # Calculate if the new vector is clockwise or anticlockwise from the old vector

    # Calculate the angle between the vectors
    angle = np.arccos(np.clip(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)), -1.0, 1.0))

    # Calculate the cross product
    cross = np.cross(v1, v2)

    # If the cross product is positive, the new vector is clockwise from the old vector
    if cross > 0:
        angle = -angle

    return angle


def angle_per_nm(trace: np.ndarray, p_to_nm: float, plot: bool = False) -> tuple:
    """Calculate the turning angle, and the turning angle per nm, at every point of a looped trace.

    The trace is treated as a closed loop: the point before the first one is the last
    point, and the point after the last one is the first point. At each point the signed
    angle between the incoming and outgoing segment is computed with ``angle_diff_signed``
    and divided by the length of the incoming segment in nm.

    Parameters
    ----------
    trace: np.ndarray
        (N, 2) array of (row, column) points. The first point must not be repeated at the end.
    p_to_nm: float
        The pixel to nm scaling factor.
    plot: bool
        If True, draw the points, the two unit direction vectors at each point and the
        angle in degrees.

    Returns
    -------
    tuple
        ``(angles_per_nm, angle_diffs)``: two arrays of length N holding, for each point,
        the angle change per nm and the angle change in radians.

    Raises
    ------
    ValueError
        If the first and last points of the trace are identical.
    """

    # Check if the first point is the same as the last point: a repeated point would give
    # a zero-length closing segment.
    if np.all(trace[0] == trace[-1]):
        raise ValueError("The first and last points are the same")

    if plot:
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))

    num_points = len(trace)
    angles_per_nm = np.zeros(num_points)
    angle_diffs = np.zeros(num_points)

    for index, point in enumerate(trace):
        # Vectors from the previous point and to the next point. Index -1 already refers to
        # the last point, and the modulo wraps the last point round to the first, so no
        # special cases are needed for the two ends of the array.
        v_prev = point - trace[index - 1]
        v_next = trace[(index + 1) % num_points] - point

        angle = angle_diff_signed(v_prev, v_next)

        if plot:
            ax.scatter(point[1], point[0], c="purple", s=20)
            # Short fixed-length arrows so the direction is visible regardless of segment length.
            norm_v_prev = v_prev / np.linalg.norm(v_prev) * 0.1
            norm_v_next = v_next / np.linalg.norm(v_next) * 0.1
            ax.arrow(
                point[1], point[0], norm_v_prev[1], norm_v_prev[0], head_width=0.01, head_length=0.2, fc="r", ec="r"
            )
            ax.arrow(
                point[1], point[0], norm_v_next[1], norm_v_next[0], head_width=0.01, head_length=0.2, fc="b", ec="b"
            )
            # Write text for the angle
            ax.text(point[1], point[0], f"{np.degrees(angle):.2f}", fontsize=12, color="black")

        # Length of the incoming segment in nm.
        distance = np.linalg.norm(v_prev) * p_to_nm

        angles_per_nm[index] = angle / distance
        angle_diffs[index] = angle

    if plot:
        ax.plot(trace[:, 1], trace[:, 0], "k")
        plt.show()

    return angles_per_nm, angle_diffs


def flip_if_anticlockwise(trace: np.ndarray):
    """Return the looped trace in a canonical direction, reversing it if necessary.

    The direction is judged from the sum of the 2-D cross products of consecutive points
    taken around the whole closed loop (twice the signed area enclosed by the trace).
    The trace is reversed when that sum is positive and returned unchanged otherwise.

    NOTE(review): the original comments call a positive sum "clockwise", which does not
    match the function name - confirm which direction is meant to be kept.

    Parameters
    ----------
    trace: np.ndarray
        (N, 2) array of points forming a closed loop (first point not repeated at the end).

    Returns
    -------
    np.ndarray
        The trace, reversed along its first axis if the cross-product sum is positive.
    """
    points = np.asarray(trace)
    if len(points) == 0:
        return trace

    # Pair every point with its successor, including the closing pair last -> first.
    # Leaving the closing pair out makes the sum depend on where along the loop the trace
    # happens to start, so the same loop could be flipped or not depending on its start point.
    following = np.roll(points, -1, axis=0)
    cross_sum = np.sum(points[:, 0] * following[:, 1] - points[:, 1] * following[:, 0])

    if cross_sum > 0:
        # Reverse the trace
        trace = np.flip(trace, axis=0)

    return trace


def defect_stats(array: np.ndarray, threshold: float):
    """Find the runs of values above a threshold in a looped 1-D signal.

    The array is treated as circular: a run that reaches the last element and a run that
    starts at the first element are merged into a single region that wraps around.

    Parameters
    ----------
    array: np.ndarray
        1-D array of values.
    threshold: float
        Values strictly greater than this belong to a region.

    Returns
    -------
    dict
        ``defect_number``: number of regions.
        ``defect_largest_region``: the region with the largest area, or None if there are none.
        ``defect_regions``: list of region dictionaries with the keys
        ``start`` (index of the first value in the region),
        ``end`` (index of the first value after the region, or the last index of the array
        for a region that reaches the end of the array),
        ``area`` (sum of ``value - threshold`` over the region),
        ``highest_point_index`` / ``highest_point`` (position and value of the maximum),
        ``defect_threshold`` and ``midpoint`` (rounded mean of start and end, taking the
        wrap-around into account). A wrapped region has ``start > end``.
    """
    num_points = len(array)
    regions = []
    current_region = None

    for index, value in enumerate(array):
        if value > threshold:
            if current_region is None:
                # Start region
                current_region = {
                    "start": index,
                    "end": index,
                    "area": 0,
                    "highest_point_index": index,
                    "highest_point": value,
                    "defect_threshold": threshold,
                }
            # Area above the threshold, counted for every sample of the region including
            # the first one, so it is positive and comparable between regions.
            current_region["area"] += value - threshold
            if value > current_region["highest_point"]:
                current_region["highest_point"] = value
                current_region["highest_point_index"] = index
        elif current_region is not None:
            current_region["end"] = index
            regions.append(current_region)
            current_region = None

    # A region still open here runs up to and including the last element.
    last_region_reaches_end = current_region is not None
    if last_region_reaches_end:
        current_region["end"] = num_points - 1
        regions.append(current_region)

    # Merge across the wrap-around only when the last region really includes the last
    # element, the first region starts at the first element, and they are two different
    # regions (a single region covering the whole array must not be merged with itself).
    if last_region_reaches_end and len(regions) > 1 and regions[0]["start"] == 0:
        last_region = regions.pop()
        first_region = regions[0]
        first_region["start"] = last_region["start"]
        first_region["area"] += last_region["area"]
        if last_region["highest_point"] > first_region["highest_point"]:
            first_region["highest_point"] = last_region["highest_point"]
            first_region["highest_point_index"] = last_region["highest_point_index"]

    # Number of defects
    number_of_defects = len(regions)

    # Find the largest region
    largest_region = None
    largest_area = 0
    for region in regions:
        if region["area"] > largest_area:
            largest_area = region["area"]
            largest_region = region

    # Find the midpoint of each region
    for region in regions:
        if region["start"] <= region["end"]:
            region["midpoint"] = int(np.round((region["start"] + region["end"]) / 2))
        else:
            # Wrapped region: express the start as a negative index so start and end can
            # be averaged, then map the result back into the valid index range.
            temp_startpoint = region["start"] - num_points
            region["midpoint"] = int(np.round((temp_startpoint + region["end"]) / 2)) % num_points

    return {
        "defect_number": number_of_defects,
        "defect_largest_region": largest_region,
        "defect_regions": regions,
    }


def calculate_real_distance_between_points_in_array(
    array: np.ndarray, indexes_to_calculate_distance_between: np.ndarray, p_to_nm: float
):
    """Measure the path length along a looped trace between successive marked points.

    Starting at the first given index, the trace is walked forwards point by point
    (wrapping from the last point to the first). Every time a marked index is reached,
    the distance walked since the previous marked index is recorded. The walk ends when
    it arrives back at the starting index, so the last entry is the distance that closes
    the loop.

    Parameters
    ----------
    array: np.ndarray
        (N, 2) array of trace points.
    indexes_to_calculate_distance_between: np.ndarray
        Indices into ``array`` of the marked points.
    p_to_nm: float
        The pixel to nm scaling factor applied to every step length.

    Returns
    -------
    list
        One distance per distinct marked index, in the order the points are met when
        walking forwards from the first given index. Empty if no indices are given.

    Raises
    ------
    ValueError
        If any index lies outside ``0 .. len(array) - 1``.
    """
    to_find = np.copy(indexes_to_calculate_distance_between)
    if len(to_find) == 0:
        return []
    # The walk only ever visits indices 0 .. len(array) - 1, so any other index would
    # never be found and the loop below would never terminate.
    if np.any(to_find < 0) or np.any(to_find >= len(array)):
        raise ValueError("Indexes must lie within the bounds of the array")

    current_index = to_find[0]
    current_position = array[current_index]
    distances = []
    distance = 0
    while len(to_find) > 0:
        # Update old position
        previous_position = current_position
        # Increment the current position along the trace, wrapping at the end
        current_index = (current_index + 1) % len(array)
        current_position = array[current_index]

        # Increment the distance
        distance += np.linalg.norm(current_position - previous_position) * p_to_nm

        # Check if the current index is in the to_find list
        if current_index in to_find:
            # Get rid of the current index from the to_find list
            to_find = to_find[to_find != current_index]
            # Store the distance
            distances.append(distance)
            # Reset the distance
            distance = 0

    return distances

In [None]:
# SIMPLE calculate rate of change of vector angle per nm

from scipy.ndimage import gaussian_filter1d

# Output of this cell: per-grain angle change per nm along the resampled pooled trace,
# plus the regions where it exceeds the threshold ("defects").
angle_change_rate_grain_dict = {}
plotting = True

for index, grain_data in pooled_curvature_grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    pooled_trace = grain_data["pooled_trace"]
    p_to_nm = grain_data["p_to_nm"]

    # Degrees per nm threshold
    degrees_per_nm_threshold = 12
    radians_per_nm_threshld = np.radians(degrees_per_nm_threshold)
    if plotting:
        print(f"radians per nm threshold: {radians_per_nm_threshld}")

    pooled_trace = flip_if_anticlockwise(pooled_trace)

    # Thin the pooled trace so that consecutive kept points are more than
    # nm_sample_distance nm apart.
    nm_sample_distance = 0.5
    nm_sampled_trace = []
    pixel_sample_distance = nm_sample_distance / p_to_nm
    # Starting at zero, for each point in pooled trace, see if the distance from the last point is greater than the sample distance and if not skip until it is
    # NOTE(review): the gap between the last kept point and the first one is not checked,
    # so the closing segment of the loop can be shorter than the sample distance.
    for i in range(len(pooled_trace)):
        if i == 0:
            nm_sampled_trace.append(pooled_trace[i])
        else:
            # If the distance between the current point and the last point in the sampled trace is greater than the sample distance, add the current point to the sampled trace
            if np.linalg.norm(pooled_trace[i] - nm_sampled_trace[-1]) > pixel_sample_distance:
                nm_sampled_trace.append(pooled_trace[i])

    nm_sampled_trace = np.array(nm_sampled_trace)

    angles_per_nm, angle_diffs = angle_per_nm(nm_sampled_trace, p_to_nm, plot=False)

    # Smooth the angles per nm. The trace is a closed loop, so the first and last samples
    # are neighbours: "wrap" makes the filter use the true neighbours across the ends
    # instead of a mirrored copy of the signal.
    angles_per_nm = gaussian_filter1d(angles_per_nm, 3, mode="wrap")

    # Get the defect stats
    defect_stats_dict = defect_stats(angles_per_nm, radians_per_nm_threshld)

    # Print angles per nm with commas between so it can be copied into a spreadsheet
    # print(",".join([f"{angle:.4f}" for angle in angles_per_nm]))

    # Calculate the real distance between the defects
    defect_indexes = np.array([region["midpoint"] for region in defect_stats_dict["defect_regions"]])
    defect_distances = None
    if len(defect_indexes) > 1:
        defect_distances = calculate_real_distance_between_points_in_array(nm_sampled_trace, defect_indexes, p_to_nm)
        if plotting:
            print(f"defect distances: {defect_distances}")

    if plotting:
        plt.plot(angles_per_nm)
        plt.ylim(-0.5, 0.5)
        plt.axhline(radians_per_nm_threshld, color="k", linestyle="--")
        plt.axhline(0, color="k", linestyle="-")
        for region in defect_stats_dict["defect_regions"]:
            # Only a region whose start lies after its end wraps around the end of the array;
            # a one-sample region (start == end) is a normal, non-wrapping region.
            if region["start"] <= region["end"]:
                plt.axvspan(region["start"], region["end"], color="red", alpha=0.3)
            else:
                plt.axvspan(region["start"], len(angles_per_nm), color="red", alpha=0.3)
                plt.axvspan(0, region["end"], color="red", alpha=0.3)
            # Plot the midpoint
            plt.scatter(region["midpoint"], angles_per_nm[region["midpoint"]], c="lime")
        plt.show()

        assert len(angles_per_nm) == len(nm_sampled_trace)

        plt.imshow(grain_image, cmap=cmap)
        plt.plot(pooled_trace[:, 1], pooled_trace[:, 0], "r")
        plt.plot(nm_sampled_trace[:, 1], nm_sampled_trace[:, 0], "b")
        # For each defect mark its midpoint
        for region in defect_stats_dict["defect_regions"]:
            plt.scatter(nm_sampled_trace[region["midpoint"]][1], nm_sampled_trace[region["midpoint"]][0], c="lime")

        plt.show()

    angle_change_rate_grain_dict[index] = {
        "image": grain_image,
        "mask": grain_mask,
        "pooled_trace": pooled_trace,
        "angles_per_nm": angles_per_nm,
        "p_to_nm": p_to_nm,
        "defect_stats": defect_stats_dict,
        "defect_distances": defect_distances,
        "nm_sampled_trace": nm_sampled_trace,
    }

In [None]:
def weighted_average_position(array: np.ndarray, clip_min: float) -> float:
    """Calculate the weighted average position in an array.

    Each index of the array is weighted by the value stored at that index, after values
    below ``clip_min`` have been raised to ``clip_min``.

    Parameters
    ----------
    array: np.ndarray
        The array to calculate the weighted average position of.
    clip_min: float
        Lower bound applied to the values before they are used as weights.

    Returns
    -------
    float
        The weighted average position (a fractional index) in the array.

    Raises
    ------
    ZeroDivisionError
        Raised by ``np.average`` if the clipped weights sum to zero.
    """
    positions = np.arange(len(array))
    # The clipped values - not the raw array - are the weights; otherwise clip_min
    # would have no effect on the result.
    weights = np.clip(np.asarray(array, dtype=float), clip_min, None)
    weighted_average = np.average(positions, weights=weights)
    return weighted_average

In [None]:
def find_distance_window_looped(distances: np.ndarray, window_distance: float, start_index: int):
    """Grow a window around a loop of points until it spans more than a requested distance.

    ``distances[i]`` is the distance from point ``i`` to the next point of the loop; the
    final entry is the distance from the last point back to the first (the start point is
    not repeated at the end). Beginning at ``start_index``, distances are added one at a
    time, wrapping from the last entry to the first, until the running total exceeds
    ``window_distance``.

    Returns
    -------
    tuple
        ``(window_start_index, window_end_index, window_current_distance)`` where the end
        index is the index following the last distance that was added.

    Raises
    ------
    ValueError
        If the requested distance has not been exceeded within the iteration limit.
    """
    segment_lengths = np.array(distances)
    num_segments = len(segment_lengths)

    travelled = 0
    end_index = start_index
    # The loop may add at most num_segments + 2 distances before giving up.
    for _ in range(num_segments + 2):
        if travelled > window_distance:
            return start_index, end_index, travelled
        travelled += segment_lengths[end_index]
        # Advance to the next segment, wrapping back to the first after the last one.
        end_index = end_index + 1 if end_index < num_segments - 1 else 0

    # Safety for infinite loop
    raise ValueError("Max iterations reached")


def weighted_mean_defect_position(
    defect_angle_shifts: np.ndarray, defect_start_index: int, max_index: int, clip_min: float = 0.0
) -> tuple:
    """Locate the angle-shift-weighted centre of a defect as an index into the looped trace.

    The weighted average position within ``defect_angle_shifts`` (as computed by ``weighted_average_position``) is
    offset by ``defect_start_index`` to give a position along the trace. Positions at or beyond ``max_index`` are
    wrapped back by one trace length, since the trace is a closed loop.

    Parameters
    ----------
    defect_angle_shifts : np.ndarray
        Angle shifts of the points belonging to the defect, in trace order.
    defect_start_index : int
        Trace index corresponding to the first entry of ``defect_angle_shifts``.
    max_index : int
        Number of points in the trace; indexes greater than or equal to this wrap around.
    clip_min : float
        Forwarded to ``weighted_average_position``.

    Returns
    -------
    tuple
        ``(weighted_mean_angle_shift_index, weighted_mean_angle_shift_index_int)``: the fractional trace index and
        the same position rounded to the nearest integer index.
    """
    weighted_mean_angle_shift_index = (
        weighted_average_position(defect_angle_shifts, clip_min=clip_min) + defect_start_index
    )
    # Round before wrapping so that e.g. 9.6 on a 10 point trace lands on index 0 rather than index 10.
    weighted_mean_angle_shift_index_int = int(np.round(weighted_mean_angle_shift_index))
    # Wrap against the caller-supplied trace length rather than a notebook-level global.
    if weighted_mean_angle_shift_index >= max_index:
        weighted_mean_angle_shift_index -= max_index
    if weighted_mean_angle_shift_index_int >= max_index:
        weighted_mean_angle_shift_index_int -= max_index

    return weighted_mean_angle_shift_index, weighted_mean_angle_shift_index_int


# Settings for detecting curvature defects along each grain's pooled trace

# When True, a debug figure of the current window is drawn for every point of every trace (slow)
plotting = False
# When True, one summary figure per grain is drawn with a panel for each detected defect
plot_results = True

# Define what a defect is by any section that turns d degrees in n nm
# Total turn (degrees) a window must exceed in absolute value to be counted as part of a defect
defect_degrees_value = 75
# Length (nm) of trace that each sliding window has to exceed
defect_nm_value = 5.5
# For grains with two defects: if the two distances between the defects differ by more than this (nm) the grain is
# tagged "pasty", otherwise "churro"
pasty_distance_deviation_threshold_nm = 3.0

# Per-grain results of the defect detection, keyed by grain index
turn_in_distance_grain_dict = {}

def _wrapped_slice(values, start: int, end: int):
    """Return ``values[start:end]`` of a circular sequence, continuing from index 0 when ``end <= start``."""
    if end > start:
        return values[start:end]
    return np.append(values[start:], values[:end])


def _wrapped_path(points: np.ndarray, start: int, end: int) -> np.ndarray:
    """Return the rows of ``points`` from ``start`` to ``end`` inclusive, wrapping past the last row if needed."""
    if end > start:
        return points[start : end + 1]
    return np.append(points[start:], points[: end + 1], axis=0)


def _build_defect_record(
    start_index,
    end_index,
    total_angle_shift,
    maximum_angle_shift,
    maximum_angle_shift_index,
    indexes,
    angle_shifts,
    num_points,
) -> dict:
    """Assemble the dictionary describing one defect, including its weighted mean position along the trace."""
    weighted_index, weighted_index_int = weighted_mean_defect_position(
        defect_angle_shifts=angle_shifts,
        defect_start_index=start_index,
        max_index=num_points,
    )
    return {
        "start_index": start_index,
        "end_index": end_index,
        "maximum_total_angle_shift": total_angle_shift,
        "maximum_angle_shift": maximum_angle_shift,
        "maximum_angle_shift_index": maximum_angle_shift_index,
        "indexes": indexes,
        "weighted_mean_angle_shift_index": weighted_index,
        "weighted_mean_angle_shift_index_int": weighted_index_int,
    }


for index, grain_data in pooled_curvature_grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    p_to_nm = grain_data["p_to_nm"]
    trace = grain_data["trace"]
    pooled_trace = grain_data["pooled_trace"]

    # Give every trace the same winding direction. The sign of the summed 2D cross products of consecutive points
    # is used as the orientation test; a positive sum is treated as clockwise and the trace is reversed.
    # The cross product is written out explicitly because np.cross on 2D vectors is deprecated in NumPy 2.0.
    cross_sum = np.sum(pooled_trace[:-1, 0] * pooled_trace[1:, 1] - pooled_trace[:-1, 1] * pooled_trace[1:, 0])
    if cross_sum > 0:
        pooled_trace = np.flip(pooled_trace, axis=0)

    # Thin the trace by keeping every nth point
    n = 2
    pooled_trace_every_nth_point = pooled_trace[::n]
    # Build a closed copy (first point repeated at the end) for measuring segment lengths, and make sure the open
    # version used everywhere else does not contain the start point twice.
    pooled_trace_every_nth_point_extra_start = np.copy(pooled_trace_every_nth_point)
    if np.array_equal(pooled_trace_every_nth_point[0], pooled_trace_every_nth_point[-1]):
        pooled_trace_every_nth_point = pooled_trace_every_nth_point[:-1]
    else:
        pooled_trace_every_nth_point_extra_start = np.append(
            pooled_trace_every_nth_point_extra_start, [pooled_trace_every_nth_point[0]], axis=0
        )

    # One distance per point: entry i is the segment from point i to the next point around the loop
    distances_between_points = np.linalg.norm(np.diff(pooled_trace_every_nth_point_extra_start, axis=0), axis=1)
    total_distance = np.sum(distances_between_points)

    distances_between_points_nm = distances_between_points * p_to_nm
    total_distance_nm = total_distance * p_to_nm

    angles_per_nm, angle_diffs = angle_per_nm(pooled_trace_every_nth_point, p_to_nm)

    assert len(angles_per_nm) == len(distances_between_points)
    assert len(angles_per_nm) == len(pooled_trace_every_nth_point)
    num_points = len(angles_per_nm)

    # ------------------------------------------------------------------------------------------------------------
    # Slide a window of at least defect_nm_value nm around the loop. Consecutive windows whose total turn exceeds
    # defect_degrees_value are grouped into one defect.
    # ------------------------------------------------------------------------------------------------------------
    in_defect = False
    defects = []
    maximum_total_angle_shift = 0
    maximum_angle_shift_index = 0
    for point_index in range(len(pooled_trace_every_nth_point)):
        window_start_index, window_end_index, window_distance = find_distance_window_looped(
            distances_between_points_nm, defect_nm_value, point_index
        )

        # Total turn inside the window; the end index is lower than the start index when the window wraps around
        window_angle_shifts = _wrapped_slice(angle_diffs, window_start_index, window_end_index)
        total_angle_shift = np.sum(window_angle_shifts)

        # Remember what happened at this point so the optional debug figure can be titled accordingly
        window_state = "none"
        if np.abs(total_angle_shift) > np.radians(defect_degrees_value):
            if not in_defect:
                # A new defect starts with this window
                window_state = "start"
                in_defect = True
                defect_start_index = point_index
                if window_end_index > window_start_index:
                    defect_indexes = np.arange(window_start_index, window_end_index)
                else:
                    defect_indexes = np.append(
                        np.arange(window_start_index, num_points), np.arange(0, window_end_index)
                    )
                defect_shifts = window_angle_shifts
                maximum_total_angle_shift = total_angle_shift

                maximum_angle_shift = np.max(window_angle_shifts)
                maximum_angle_shift_index = window_start_index + np.argmax(window_angle_shifts)
                if maximum_angle_shift_index >= num_points:
                    maximum_angle_shift_index -= num_points
            else:
                # The current defect continues: extend it by the point at the end of this window
                # NOTE(review): the window slice excludes window_end_index, yet that index is what gets appended
                # here - confirm this is the intended point.
                window_state = "continue"
                defect_indexes = np.append(defect_indexes, window_end_index)
                defect_shifts = np.append(defect_shifts, angle_diffs[window_end_index])
                # Track the largest single angle shift seen in the defect and where it is
                window_maximum = np.max(window_angle_shifts)
                if window_maximum > maximum_angle_shift:
                    maximum_angle_shift = window_maximum
                    maximum_angle_shift_index = window_start_index + np.argmax(window_angle_shifts)
                    if maximum_angle_shift_index >= num_points:
                        maximum_angle_shift_index -= num_points
                # Track the largest total turn of any window in the defect
                if np.abs(total_angle_shift) > np.abs(maximum_total_angle_shift):
                    maximum_total_angle_shift = total_angle_shift
        elif in_defect:
            # First window below the threshold after a defect: close the defect
            in_defect = False
            defect_end_index = window_end_index
            defects.append(
                _build_defect_record(
                    defect_start_index,
                    defect_end_index,
                    maximum_total_angle_shift,
                    maximum_angle_shift,
                    maximum_angle_shift_index,
                    defect_indexes,
                    defect_shifts,
                    num_points,
                )
            )

        # Optional per-point debug figure. Everything, including plt.show(), is guarded by the flag so that no
        # figure work is done for every point when debugging is switched off.
        if plotting:
            plt.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
            plt.scatter(pooled_trace_every_nth_point[:, 1], pooled_trace_every_nth_point[:, 0], c="k", s=20)
            plt.scatter(
                pooled_trace_every_nth_point[point_index, 1],
                pooled_trace_every_nth_point[point_index, 0],
                c="white",
                s=40,
            )
            plt.scatter(
                pooled_trace_every_nth_point[window_start_index, 1],
                pooled_trace_every_nth_point[window_start_index, 0],
                c="b",
                s=20,
            )
            plt.scatter(
                pooled_trace_every_nth_point[window_end_index, 1],
                pooled_trace_every_nth_point[window_end_index, 0],
                c="r",
                s=20,
            )
            if window_state != "none":
                # Highlight the window along the pooled trace
                window_path = _wrapped_path(pooled_trace_every_nth_point, window_start_index, window_end_index)
                plt.plot(window_path[:, 1], window_path[:, 0], "g")
            if window_state == "start":
                plt.title(
                    f"defect start index: {point_index}, end index: {window_end_index}, window angle shifts: {window_angle_shifts} total angle shift: {total_angle_shift} ({np.degrees(total_angle_shift)} degrees)"
                )
            elif window_state == "continue":
                plt.title(
                    f"defect start index: {defect_start_index}, end index: {window_end_index}, window angle shifts: {window_angle_shifts} max shift: {maximum_angle_shift}, index: {maximum_angle_shift_index} total angle shift: {total_angle_shift} ({np.degrees(total_angle_shift)} degrees)"
                )
            else:
                plt.title(
                    f"no defect at index: {point_index} window end index: {window_end_index} window angle shifts: {window_angle_shifts} total angle shift: {total_angle_shift} ({np.degrees(total_angle_shift)} degrees)"
                )
            plt.show()

    # A defect still open after the last point is closed with the last window's end index
    if in_defect:
        in_defect = False
        defect_end_index = window_end_index
        defects.append(
            _build_defect_record(
                defect_start_index,
                defect_end_index,
                maximum_total_angle_shift,
                maximum_angle_shift,
                maximum_angle_shift_index,
                defect_indexes,
                defect_shifts,
                num_points,
            )
        )

    # ------------------------------------------------------------------------------------------------------------
    # Combine any overlapping defects (typically a defect at the end of the trace that wraps onto one at the start).
    # The list is rescanned from the beginning after every merge so that it is never modified while it is being
    # iterated and every comparison uses the up-to-date merged defect.
    # ------------------------------------------------------------------------------------------------------------
    merged_any = True
    while merged_any:
        merged_any = False
        for i in range(len(defects)):
            for j in range(i + 1, len(defects)):
                if len(np.intersect1d(defects[i]["indexes"], defects[j]["indexes"])) == 0:
                    continue

                combined_indexes = np.unique(np.append(defects[i]["indexes"], defects[j]["indexes"]))
                # Check if the combined indexes span the end of the array
                if num_points - 1 in combined_indexes and 0 in combined_indexes:
                    # To find the end index, count forward from 0 until the following index is missing
                    for candidate_end_index in range(num_points):
                        if candidate_end_index + 1 not in combined_indexes:
                            end_index = candidate_end_index
                            break
                    # To find the start index, count backward from the end until the preceding index is missing
                    for candidate_start_index in range(num_points - 1, 0, -1):
                        if candidate_start_index - 1 not in combined_indexes:
                            start_index = candidate_start_index
                            break
                    else:
                        # Every index belongs to the merged defect, so it covers the whole loop
                        start_index = 0
                else:
                    start_index = np.min(combined_indexes)
                    end_index = np.max(combined_indexes)

                # Recompute the summary values over the merged span
                merged_angle_shifts = _wrapped_slice(angle_diffs, start_index, end_index)
                local_maximum_angle_shift_index = np.argmax(merged_angle_shifts)
                merged_maximum_index = start_index + local_maximum_angle_shift_index
                if merged_maximum_index >= num_points:
                    merged_maximum_index -= num_points

                defects[i] = _build_defect_record(
                    start_index,
                    end_index,
                    np.sum(merged_angle_shifts),
                    merged_angle_shifts[local_maximum_angle_shift_index],
                    merged_maximum_index,
                    combined_indexes,
                    merged_angle_shifts,
                    num_points,
                )
                defects.pop(j)
                merged_any = True
                break
            if merged_any:
                break

    # ------------------------------------------------------------------------------------------------------------
    # Distance along the trace from each defect's weighted mean position to the next one, walking once around the
    # loop starting from the defect with the lowest index.
    # ------------------------------------------------------------------------------------------------------------
    weighted_mean_angle_shift_indexes_int = sorted(defect["weighted_mean_angle_shift_index_int"] for defect in defects)
    distances = []
    if len(weighted_mean_angle_shift_indexes_int) > 0:
        to_find = np.copy(weighted_mean_angle_shift_indexes_int)
        current_index = to_find[0]
        current_position = pooled_trace_every_nth_point[current_index]
        distance = 0
        # One full lap visits every valid index, so the walk is bounded by the number of points; this also
        # guarantees termination if a defect position is not a valid trace index.
        for _ in range(len(pooled_trace_every_nth_point)):
            if len(to_find) == 0:
                break
            previous_position = current_position
            # Step to the next point around the loop
            current_index += 1
            if current_index >= len(pooled_trace_every_nth_point):
                current_index -= len(pooled_trace_every_nth_point)
            current_position = pooled_trace_every_nth_point[current_index]

            distance += np.linalg.norm(current_position - previous_position) * p_to_nm

            if current_index in to_find:
                # Reached a defect: record the distance travelled since the previous one and start again
                to_find = to_find[to_find != current_index]
                distances.append(distance)
                distance = 0

    # ------------------------------------------------------------------------------------------------------------
    # Summary figure: one panel per defect. Skipped for grains without defects (a zero-column grid is not valid).
    # ------------------------------------------------------------------------------------------------------------
    if plot_results and len(defects) > 0:
        # squeeze=False always returns a 2D array of axes, so a single defect needs no special case
        fig, ax = plt.subplots(1, len(defects), figsize=(10 * len(defects), 10), squeeze=False)
        for defect_index, defect in enumerate(defects):
            thisax = ax[0, defect_index]
            thisax.imshow(grain_image, cmap=cmap, vmin=None, vmax=None)
            # Scale bar near the top left of the image, starting at (2, 2), as long as the defect window distance
            thisax.plot(
                [2, 2 + defect_nm_value / p_to_nm], [2, 2], "r", label="curvature defect window length (to scale)"
            )
            # Plot the points
            thisax.scatter(
                pooled_trace_every_nth_point[:, 1],
                pooled_trace_every_nth_point[:, 0],
                c="k",
                s=20,
                label="pooled trace",
            )
            # Plot the start point of the entire trace
            thisax.scatter(
                pooled_trace_every_nth_point[0, 1], pooled_trace_every_nth_point[0, 0], c="pink", s=60, label="start"
            )
            # Write the angle shift (in degrees) at each point
            for i, point in enumerate(pooled_trace_every_nth_point):
                thisax.text(
                    point[1],
                    point[0],
                    f"{int(np.round(np.degrees(angle_diffs[i])))}",
                    fontsize=14,
                    color="white",
                    label="angle shifts",
                )
            # Plot the start and end index
            thisax.scatter(
                pooled_trace_every_nth_point[defect["start_index"], 1],
                pooled_trace_every_nth_point[defect["start_index"], 0],
                c="blue",
                s=60,
                alpha=1,
                label="start of window for defect",
            )
            thisax.scatter(
                pooled_trace_every_nth_point[defect["end_index"], 1],
                pooled_trace_every_nth_point[defect["end_index"], 0],
                c="red",
                s=60,
                alpha=1,
                label="end of window for defect",
            )
            # Plot the maximum angle shift index
            thisax.scatter(
                pooled_trace_every_nth_point[defect["maximum_angle_shift_index"], 1],
                pooled_trace_every_nth_point[defect["maximum_angle_shift_index"], 0],
                c="green",
                s=100,
                alpha=1,
                label="maximum angle shift position within window",
            )
            # Plot the weighted mean angle shift index
            thisax.scatter(
                pooled_trace_every_nth_point[defect["weighted_mean_angle_shift_index_int"], 1],
                pooled_trace_every_nth_point[defect["weighted_mean_angle_shift_index_int"], 0],
                c="yellow",
                s=100,
                alpha=1,
                label="weighted mean angle shift position within window",
            )

            # Plot a line between the start and end index following the pooled trace
            defect_path = _wrapped_path(pooled_trace_every_nth_point, defect["start_index"], defect["end_index"])
            thisax.plot(defect_path[:, 1], defect_path[:, 0], "b")
            thisax.set_title(f"defect {defect_index}")

        plt.suptitle(
            f"Defects {len(defects)} for grain {index}, total distance: {total_distance_nm:.2f} nm, defect degrees: {defect_degrees_value}, defect nm: {defect_nm_value}, p_to_nm: {p_to_nm:.2f}"
        )
        fig.legend()
        plt.show()

    # ------------------------------------------------------------------------------------------------------------
    # Classification
    # If there are 3 defects then it is a dorito
    # If there are 2 defects then it is either a churro or pasty
    # If the two defects have significantly different distances then it's a pasty, else it's a churro
    # If it has fewer than 2 or more than 3, then it is unclassified
    # ------------------------------------------------------------------------------------------------------------
    # Reset for every grain so a tag can never leak over from the previous grain
    tag = "unclassified"
    if len(defects) == 3:
        tag = "dorito"
    elif len(defects) == 2 and len(distances) >= 2:
        # Fewer than two distances are recorded when both defects share the same weighted mean position; such a
        # grain is left unclassified.
        if np.abs(distances[0] - distances[1]) > pasty_distance_deviation_threshold_nm:
            tag = "pasty"
        else:
            tag = "churro"

    turn_in_distance_grain_dict[index] = {
        "image": grain_image,
        "mask": grain_mask,
        "trace": trace,
        "pooled_trace": pooled_trace,
        "angles_per_nm": angles_per_nm,
        "angle_diffs": angle_diffs,
        "distances_between_points": distances_between_points,
        "total_distance": total_distance,
        "defects": defects,
        "num_defects": len(defects),
        "distances_between_defects": distances,
        "p_to_nm": p_to_nm,
        "tag": tag,
    }


def plot_images(images: list, grain_indexes: list, px_to_nms: list, width=5, cmap=cmap, vmin=None, vmax=None):
    """Show a list of images on a grid of subplots, ``width`` images per row.

    Parameters
    ----------
    images : list
        Images to display, one per panel, filled row by row.
    grain_indexes : list
        Grain index for each image. Only used to pair up with ``images``; if it is shorter than ``images`` the
        surplus images are not shown.
    px_to_nms : list
        Pixel to nanometre scaling for each image. Currently unused.
    width : int
        Number of panels per row.
    cmap
        Colormap passed to ``imshow``.
    vmin, vmax
        Colour scale limits passed to ``imshow``.
    """
    num_images = len(images)
    # Always create at least one row so that an empty list gives an empty figure instead of an error
    num_rows = max(1, int(np.ceil(num_images / width)))
    # squeeze=False keeps the axes in a 2D array even for a single row, so [row, column] indexing always works
    fig, ax = plt.subplots(num_rows, width, figsize=(30, 30), squeeze=False)
    # Hide the frame of every panel, including the unused ones in the last row
    for panel in ax.flat:
        panel.axis("off")
    for i, (image, grain_index) in enumerate(zip(images, grain_indexes)):
        ax[i // width, i % width].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
    fig.tight_layout()
    plt.show()