In [None]:
import os
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from scipy.interpolate import interp1d

from topostats.measure.curvature import (
    discrete_angle_difference_per_nm_circular,
    discrete_angle_difference_per_nm_linear,
)
from beaks import resample_points_regular_interval

# Next: load the pickled grain dictionary

In [None]:
# Locate the grain dictionary. The base directory can be overridden with the
# BEAKS_BASE_DIR environment variable; it defaults to the original hard-coded path.
base_dir = Path(os.environ.get("BEAKS_BASE_DIR", "/Users/sylvi/topo_data/beaks"))
data_dir = base_dir / "grain-dictionaries"
filename = "grains_with_beaks_D-2025-03-04-T-13-35.pkl"
grains_dictionary_path = data_dir / filename

# Raise explicitly instead of using `assert`, which is stripped under `python -O`.
for required_path in (base_dir, data_dir, grains_dictionary_path):
    if not required_path.exists():
        raise FileNotFoundError(f"Required path does not exist: {required_path}")

# SECURITY: pickle.load can execute arbitrary code -- only load files you trust.
with grains_dictionary_path.open("rb") as f:
    grains_dictionary = pickle.load(f)

print(f"Loaded {len(grains_dictionary)} grains")

In [None]:
# Detect beaks along each molecule trace.
#
# For every molecule in every grain the spline trace is resampled at a fixed
# interval and its discrete curvature is computed. Two kinds of beak are counted:
#   * sharp beaks    - resampled points whose |curvature| exceeds
#                      SHARP_CURVATURE_THRESHOLD;
#   * explored beaks - candidate points (|curvature| above
#                      POSSIBLE_BEAK_CURVATURE_THRESHOLD) where marching outward
#                      from the bend stays inside the thick mask for more than
#                      EXPLORATION_DISTANCE_DEFECT_THRESHOLD_NM.
# Results are stored as stats[grain_index][molecule_index] = {
#     "num_beaks_sharp": int, "num_beaks_explored": int}.
# Molecules skipped for having too few spline points get no entry.

# --- Parameters ---------------------------------------------------------------
# Image values above this threshold form the "thick" mask used for exploration.
MASK_THICK_THRESHOLD = 2.0
# Splines with fewer points than this are treated as erroneous and skipped.
MIN_SPLINED_COORDS_NUM_POINTS = 10
# Spacing (nm) of the regularly resampled trace.
RESAMPLE_INTERVAL_NM = 5.0
# Traces whose endpoints are closer than this (nm) are treated as circular.
CIRCULAR_ENDPOINT_DISTANCE_THRESHOLD_NM = 2.0 * RESAMPLE_INTERVAL_NM
# |curvature| above this marks a sharp beak (units as returned by the topostats
# discrete_angle_difference_per_nm_* functions).
SHARP_CURVATURE_THRESHOLD = 0.3
# |curvature| above this marks a candidate region for mask exploration.
POSSIBLE_BEAK_CURVATURE_THRESHOLD = 0.2
# Number of resampled points taken on each side of a candidate point.
SLIGHT_DEFECT_EXPLORE_RANGE_POINTS = 2
# Number of points the candidate window is re-interpolated to.
POSSIBLE_BEAK_POINTS_INTERPOLATION_NUMBER = 20
# Step length (px) when marching outward through the mask.
EXPLORATION_STEP_PX = 0.5
# An outward exploration longer than this (nm) confirms an explored beak.
EXPLORATION_DISTANCE_DEFECT_THRESHOLD_NM = 7
# Signed curvatures in [-range, range] are mapped onto the full colormap.
CURVATURE_COLOUR_RANGE = 0.5


def _outward_direction(previous_point, point, next_point):
    """Return the unit vector pointing away from the bend at ``point``.

    The mean of the unit vectors towards the two neighbours points into the
    bend, so its negation points outward. Returns None when the direction is
    undefined: a neighbour coincides with ``point``, or the three points are
    collinear and the mean vector vanishes.
    """
    to_previous = previous_point - point
    to_next = next_point - point
    to_previous_norm = np.linalg.norm(to_previous)
    to_next_norm = np.linalg.norm(to_next)
    if to_previous_norm == 0 or to_next_norm == 0:
        return None
    average = (to_previous / to_previous_norm + to_next / to_next_norm) / 2
    average_norm = np.linalg.norm(average)
    if average_norm < 1e-12:
        return None
    return -average / average_norm


def _explore_until_mask_edge(start, direction, thick_mask, step_px):
    """March from ``start`` along ``direction`` until leaving ``thick_mask``.

    Positions are converted to pixel indices by truncation (``astype(int)``).

    Returns:
        ``(distance_px, exit_position_px)`` for the first sampled pixel that is
        outside the mask, where the distance is measured from ``start``. Returns
        None if the image edge is reached first.
    """
    position = start.astype(float)
    step = step_px * direction.astype(float)
    # ``direction`` is a finite unit vector and ``step_px > 0``, so the march
    # always leaves the image eventually and the loop terminates.
    while True:
        position += step
        position_int = position.astype(int)
        if not (0 <= position_int[0] < thick_mask.shape[0]):
            print(f"[exploration] reached edge of image at {position}")
            return None
        if not (0 <= position_int[1] < thick_mask.shape[1]):
            print(f"[exploration] reached edge of image at {position_int}")
            return None
        if not thick_mask[position_int[0], position_int[1]]:
            return np.linalg.norm(position_int - start), position_int


stats = {}

for grain_index, grain_data in grains_dictionary.items():
    print(f"grain {grain_index}")

    image = grain_data["image"]
    mask_thick = image > MASK_THICK_THRESHOLD
    bbox = grain_data["bbox"]
    p2nm = grain_data["pixel_to_nm_scaling"]
    bbox_padding = grain_data["padding"]
    bbox_added_left = grain_data["added_left"]
    bbox_added_top = grain_data["added_top"]

    print(f"p2nm: {p2nm}")

    grain_beak_data = {}

    for molecule_index, molecule_data in grain_data["molecule_data"].items():
        # Subtract the added_top / added_left offsets so the spline coordinates
        # line up with `image`, then convert to nanometres.
        splined_coords_px = molecule_data["spline_coords"].copy()
        splined_coords_px[:, 0] -= bbox_added_top
        splined_coords_px[:, 1] -= bbox_added_left
        splined_coords_nm = splined_coords_px * p2nm

        molecule_beak_data = {"num_beaks_sharp": 0, "num_beaks_explored": 0}

        if len(splined_coords_nm) < MIN_SPLINED_COORDS_NUM_POINTS:
            print(
                f"Skipping molecule {molecule_index} in grain {grain_index} as it has too few points ({len(splined_coords_nm)} < {MIN_SPLINED_COORDS_NUM_POINTS})"
            )
            continue

        # Resample the spline to get points at fixed intervals.
        resampled_points_nm = resample_points_regular_interval(points=splined_coords_nm, interval=RESAMPLE_INTERVAL_NM)
        resampled_points_px = resampled_points_nm / p2nm
        num_resampled = len(resampled_points_px)

        is_circular = bool(
            np.linalg.norm(resampled_points_nm[0] - resampled_points_nm[-1]) < CIRCULAR_ENDPOINT_DISTANCE_THRESHOLD_NM
        )

        if is_circular:
            resampled_curvatures = discrete_angle_difference_per_nm_circular(resampled_points_nm)
        else:
            resampled_curvatures = discrete_angle_difference_per_nm_linear(resampled_points_nm)
        abs_curvatures = np.abs(resampled_curvatures)

        # --- Trace overview plot (pixel coordinates) ---------------------------
        grain_fig, grain_ax = plt.subplots(figsize=(10, 10))
        molecule_figures = [grain_fig]
        grain_ax.imshow(image)
        grain_ax.imshow(mask_thick, alpha=0.5)
        grain_ax.plot(resampled_points_px[:, 1], resampled_points_px[:, 0], "r")
        # Colour each point by signed curvature, with [-range, range] mapped onto [0, 1]
        # of the bwr colormap. Out-of-range values get the colormap's under/over colours.
        curvature_cmap = plt.get_cmap("bwr")
        for resampled_point_index, resampled_point in enumerate(resampled_points_px):
            curvature = resampled_curvatures[resampled_point_index]
            colour = (curvature + CURVATURE_COLOUR_RANGE) / (2 * CURVATURE_COLOUR_RANGE)
            grain_ax.plot(resampled_point[1], resampled_point[0], "o", c=curvature_cmap(colour), markersize=5)

        # --- Sharp beaks: high-curvature points --------------------------------
        curvature_defects = np.where(abs_curvatures > SHARP_CURVATURE_THRESHOLD)[0]
        molecule_beak_data["num_beaks_sharp"] = len(curvature_defects)
        for defect_index in curvature_defects:
            resampled_point = resampled_points_px[defect_index]
            grain_ax.plot(resampled_point[1], resampled_point[0], "o", c="r", markersize=10)

        # --- Explored beaks: probe the mask around moderate-curvature points ---
        slight_curvature_defects = np.where(abs_curvatures > POSSIBLE_BEAK_CURVATURE_THRESHOLD)[0]
        for slight_defect_index in slight_curvature_defects:
            slight_defect_point = resampled_points_px[slight_defect_index]
            # Mark every candidate point with a blue circle.
            grain_ax.plot(slight_defect_point[1], slight_defect_point[0], "o", c="b", markersize=8)

            # An open trace cannot wrap around. Skip candidates whose window, plus
            # one extra point on each side as a safety margin, would run off an end.
            if not is_circular:
                if slight_defect_index - SLIGHT_DEFECT_EXPLORE_RANGE_POINTS - 1 < 0:
                    print(
                        f"slight defect index {slight_defect_index} - {SLIGHT_DEFECT_EXPLORE_RANGE_POINTS + 1} < 0 in a noncircular trace, skipping"
                    )
                    continue
                if slight_defect_index + SLIGHT_DEFECT_EXPLORE_RANGE_POINTS + 1 >= num_resampled:
                    print(
                        f"slight defect index {slight_defect_index} + {SLIGHT_DEFECT_EXPLORE_RANGE_POINTS + 1} >= {num_resampled} in a noncircular trace, skipping"
                    )
                    continue

            print(f"slight defect index: {slight_defect_index}")

            # Window of points centred on the candidate. Indices wrap modulo the
            # trace length, which only matters for circular traces.
            window_indices = [
                (slight_defect_index + offset) % num_resampled
                for offset in range(-SLIGHT_DEFECT_EXPLORE_RANGE_POINTS, SLIGHT_DEFECT_EXPLORE_RANGE_POINTS + 1)
            ]
            possible_beak_points = resampled_points_px[window_indices]

            # Re-interpolate the window with a cubic curve parameterised by
            # normalised cumulative arc length.
            segment_lengths = np.linalg.norm(np.diff(possible_beak_points, axis=0), axis=1)
            if np.any(segment_lengths == 0):
                # Repeated points give duplicate x values, which cubic interp1d rejects.
                print(f"slight defect index {slight_defect_index} has repeated window points, skipping")
                continue
            cumulative_length = np.concatenate(([0.0], np.cumsum(segment_lengths)))
            arc_parameter = cumulative_length / cumulative_length[-1]
            interpolator = interp1d(arc_parameter, possible_beak_points, kind="cubic", axis=0)
            interpolated_possible_beak_points = interpolator(
                np.linspace(0, 1, POSSIBLE_BEAK_POINTS_INTERPOLATION_NUMBER)
            )

            grain_ax.plot(interpolated_possible_beak_points[:, 1], interpolated_possible_beak_points[:, 0], "cyan")

            # From each interior interpolated point, march outward until leaving the
            # thick mask. Keep the interpolated-point index with each distance so the
            # maximum maps back to the correct point, even when some explorations
            # yield nothing.
            exploration_point_indices = []
            exploration_profile = []
            for point_index in range(1, len(interpolated_possible_beak_points) - 1):
                point = interpolated_possible_beak_points[point_index]
                outward = _outward_direction(
                    interpolated_possible_beak_points[point_index - 1],
                    point,
                    interpolated_possible_beak_points[point_index + 1],
                )
                if outward is None:
                    continue
                grain_ax.plot(
                    [point[1], point[1] + outward[1] * 10],
                    [point[0], point[0] + outward[0] * 10],
                    "g",
                    alpha=0.5,
                )
                grain_ax.plot(point[1], point[0], "o", c="yellow", markersize=2)

                exploration_result = _explore_until_mask_edge(point, outward, mask_thick, EXPLORATION_STEP_PX)
                if exploration_result is None:
                    continue
                distance_to_explore_start, exit_position = exploration_result
                exploration_point_indices.append(point_index)
                exploration_profile.append(distance_to_explore_start)
                grain_ax.plot(exit_position[1], exit_position[0], "o", c="orange", markersize=2)

            if not exploration_profile:
                print(f"slight defect index {slight_defect_index}: no exploration reached the mask edge, skipping")
                continue

            exploration_profile = np.array(exploration_profile)
            exploration_profile_nm = exploration_profile * p2nm

            # Use the longest outward exploration as this candidate's beak length.
            largest_profile_index = int(np.argmax(exploration_profile))
            largest_exploration_distance = exploration_profile_nm[largest_profile_index]
            if largest_exploration_distance > EXPLORATION_DISTANCE_DEFECT_THRESHOLD_NM:
                molecule_beak_data["num_beaks_explored"] += 1
                print(
                    f"possible defect confirmed, with distance {largest_exploration_distance} > {EXPLORATION_DISTANCE_DEFECT_THRESHOLD_NM} nm"
                )
                largest_point = interpolated_possible_beak_points[exploration_point_indices[largest_profile_index]]
                grain_ax.plot(largest_point[1], largest_point[0], "o", c="green", markersize=6)

            explore_profile_fig, explore_profile_ax = plt.subplots()
            molecule_figures.append(explore_profile_fig)
            explore_profile_ax.plot(exploration_point_indices, exploration_profile_nm)
            explore_profile_ax.set_xlabel("interpolated point index")
            explore_profile_ax.set_ylabel("distance to mask edge (nm)")
            explore_profile_ax.set_title(
                f"grain {grain_index} molecule {molecule_index} slight defect index {slight_defect_index} exploration profile"
            )

        # Set the aspect on grain_ax explicitly. plt.gca() would return the last
        # exploration-profile axes created above.
        grain_ax.set_aspect("equal", adjustable="box")
        grain_ax.set_title(
            f"grain {grain_index} molecule {molecule_index} bbox: {bbox} padding: {bbox_padding} circular: {is_circular}"
        )
        plt.show()

        # --- Curvature profile ---------------------------------------------------
        curvature_fig, curvature_ax = plt.subplots()
        molecule_figures.append(curvature_fig)
        curvature_ax.plot(resampled_curvatures, label="discrete curvatures")
        curvature_ax.plot(abs_curvatures, label="discrete absolute curvatures")
        curvature_ax.axhline(SHARP_CURVATURE_THRESHOLD, color="r", linestyle="--", linewidth=0.8)
        curvature_ax.axhline(POSSIBLE_BEAK_CURVATURE_THRESHOLD, color="b", linestyle="--", linewidth=0.8)
        curvature_ax.legend()
        curvature_ax.set_ylim(-0.1, 0.5)
        curvature_ax.set_xlabel("resampled point index")
        curvature_ax.set_ylabel("curvature")
        curvature_ax.set_title(f"grain {grain_index} molecule {molecule_index} curvatures")
        plt.show()

        # Close this molecule's figures explicitly so many grains do not accumulate open figures.
        for molecule_figure in molecule_figures:
            plt.close(molecule_figure)

        grain_beak_data[molecule_index] = molecule_beak_data

    stats[grain_index] = grain_beak_data

In [None]:
# Summarise beak counts per grain.
#
# A grain "has" a sharp (or explored) beak if any of its analysed molecules does.
# Its per-grain totals are summed over its molecules. Grains whose molecules were
# all skipped (too few spline points) have an empty entry in `stats`, so they are
# counted as having no beak.

num_grains_total = len(stats)
# grains with at least one sharp beak
num_grains_sharp_beak = 0
# grains with at least one explored beak
num_grains_explored_beak = 0
# grains with neither kind of beak
num_grains_no_beak = 0
# grains with at least one beak of either kind
num_grains_with_any_beak = 0

# per-grain totals, in the iteration order of `stats`
grains_num_sharp_beaks = []
grains_num_explored_beaks = []

for grain_index, grain_beak_stats in stats.items():
    molecule_stats = list(grain_beak_stats.values())

    grain_has_sharp_beak = any(m["num_beaks_sharp"] > 0 for m in molecule_stats)
    grain_has_explored_beak = any(m["num_beaks_explored"] > 0 for m in molecule_stats)

    grains_num_sharp_beaks.append(sum(m["num_beaks_sharp"] for m in molecule_stats))
    grains_num_explored_beaks.append(sum(m["num_beaks_explored"] for m in molecule_stats))

    if grain_has_sharp_beak:
        num_grains_sharp_beak += 1
    if grain_has_explored_beak:
        num_grains_explored_beak += 1
    if grain_has_sharp_beak or grain_has_explored_beak:
        num_grains_with_any_beak += 1
    else:
        num_grains_no_beak += 1

# Every grain falls into exactly one of "any beak" / "no beak".
assert num_grains_with_any_beak + num_grains_no_beak == num_grains_total

In [None]:
# Plot the beak summary:
#   1. a single stacked bar of grains with any beak vs grains with no beak;
#   2. histograms of beaks per grain, one for sharp beaks and one for explored beaks.


def _plot_beak_count_histogram(beak_counts, grain_counts, title):
    """Bar-plot how many grains have each number of beaks.

    Args:
        beak_counts: distinct beaks-per-grain values (x axis).
        grain_counts: number of grains with each of those values (y axis).
        title: axes title.

    Returns:
        The ``(fig, ax)`` pair of the created figure.
    """
    fig, ax = plt.subplots()
    ax.bar(beak_counts, grain_counts)
    ax.set_xticks(beak_counts)
    ax.set_xlabel("beaks per grain")
    ax.set_ylabel("number of grains")
    ax.set_title(title)
    plt.show()
    return fig, ax


bar_x = ["beaks vs no beaks"]
bar_y_1 = [num_grains_with_any_beak]
bar_y_2 = [num_grains_no_beak]

print(num_grains_total)

fig, ax = plt.subplots()
ax.bar(bar_x, bar_y_1, label="beaks")
ax.bar(bar_x, bar_y_2, bottom=bar_y_1, label="no beaks")
ax.set_ylabel("number of grains")
ax.set_title(f"Grains with vs without beaks (total: {num_grains_total})")
ax.legend()
plt.show()

grains_num_sharp_beaks = np.array(grains_num_sharp_beaks)
grains_num_explored_beaks = np.array(grains_num_explored_beaks)

# Distinct beak counts and how many grains have each one.
bar_xs_sharp, bar_ys_sharp = np.unique(grains_num_sharp_beaks, return_counts=True)
bar_xs_explored, bar_ys_explored = np.unique(grains_num_explored_beaks, return_counts=True)

fig, ax = _plot_beak_count_histogram(bar_xs_sharp, bar_ys_sharp, "Number of sharp beaks per grain")
fig, ax = _plot_beak_count_histogram(bar_xs_explored, bar_ys_explored, "Number of explored beaks per grain")
