In [None]:
import os
import pickle
import shutil
import sys
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from scipy.interpolate import interp1d

from topostats.io import LoadScans
from topostats.unet_masking import make_bounding_box_square, pad_bounding_box
from topostats.measure.curvature import (
    discrete_angle_difference_per_nm_circular,
    discrete_angle_difference_per_nm_linear,
)

# Timestamp label for this run, e.g. "D-2024-05-01-T-14-30"; printed for reference.
today = f"{datetime.today():D-%Y-%m-%d-T-%H-%M}"
print(today)

In [None]:
# Data locations. NOTE(review): machine-specific absolute path; adjust for your setup.
base_dir = Path("/Users/sylvi/topo_data/beaks")
assert base_dir.exists()
beak_topo_data_dir = base_dir / "output-beaks-topostats-unet-good"
assert beak_topo_data_dir.exists()

# One processed-output directory per sample.
hummingbird_dir = beak_topo_data_dir.joinpath("hummingbird", "processed")
assert hummingbird_dir.exists()
magpie_dir = beak_topo_data_dir.joinpath("magpie", "processed")
assert magpie_dir.exists()

# Collect every .topostats file: hummingbird files first, then magpie.
# NOTE(review): Path.glob yields entries in arbitrary (filesystem) order. The running
# grain indices assigned in the next cells, and the hard-coded grains_with_beaks list,
# depend on this order.
hummingbird_files = hummingbird_dir.glob("*.topostats")
magpie_files = magpie_dir.glob("*.topostats")
all_files = [*hummingbird_files, *magpie_files]
print(f"Found {len(all_files)} files")

In [None]:
def interpolate_between_two_points(point1, point2, distance):
    """Return the point lying ``distance`` along the straight line from ``point1`` to ``point2``.

    Parameters
    ----------
    point1, point2 : array-like
        Start and end points (same shape), supporting vector subtraction.
    distance : float
        How far from ``point1`` to travel towards ``point2``.

    Returns
    -------
    The interpolated point, same shape as the inputs.

    Raises
    ------
    ValueError
        If ``point2`` is closer to ``point1`` than ``distance``.
    """
    step = point2 - point1
    gap = np.linalg.norm(step)
    if gap < distance:
        raise ValueError("distance between points is less than the desired interval")
    return point1 + (distance / gap) * step

## Construct the grain dictionary

In [None]:
# Build a flat dictionary of grains from every loaded file. Keys are a running integer
# `grain_index` that continues across files, so a grain's index depends on the order of
# `all_files`. Each value holds per-molecule trace data (ordered coords, heights,
# distances, spline coords) plus a square, padded crop of the image and the grain mask.
grains_dictionary = {}

# NOTE(review): channel="dummy" presumably is not used when loading .topostats files; confirm against LoadScans.
loadscans = LoadScans(all_files, channel="dummy")
loadscans.get_data()
img_dict = loadscans.img_dict

# padding passed to pad_bounding_box; presumably pixels per side (TODO confirm)
bbox_padding = 10
grain_index = 0
for filename, file_data in img_dict.items():
    # print(f"getting data from {filename}")
    image = file_data["image"]
    # only the "above" direction's traces are used
    ordered_trace_data = file_data["ordered_traces"]["above"]
    for current_grain_index, grain_ordered_trace_data in ordered_trace_data.items():
        # print(f"  grain {current_grain_index}")
        grains_dictionary[grain_index] = {}
        grains_dictionary[grain_index]["molecule_data"] = {}
        for current_molecule_index, molecule_ordered_trace_data in grain_ordered_trace_data.items():
            molecule_data = {}
            molecule_data["ordered_coords"] = molecule_ordered_trace_data["ordered_coords"]
            molecule_data["heights"] = molecule_ordered_trace_data["heights"]
            molecule_data["distances"] = molecule_ordered_trace_data["distances"]
            bbox = molecule_ordered_trace_data["bbox"]
            # stored by reference, so the "spline_coords" key added below also lands in grains_dictionary
            grains_dictionary[grain_index]["molecule_data"][current_molecule_index] = molecule_data

            splining_coords = file_data["splining"]["above"][current_grain_index][current_molecule_index][
                "spline_coords"
            ]
            molecule_data["spline_coords"] = splining_coords

            # print(molecule_ordered_trace_data.keys())
        # NOTE(review): `bbox` is whichever molecule was iterated last; this assumes every molecule
        # in a grain carries the same (grain) bbox. A grain with no molecules would reuse the
        # previous grain's bbox, or raise NameError if it is the very first grain. Confirm.
        bbox_square = make_bounding_box_square(bbox[0], bbox[1], bbox[2], bbox[3], image.shape)
        bbox_padded = pad_bounding_box(
            bbox_square[0], bbox_square[1], bbox_square[2], bbox_square[3], image.shape, padding=bbox_padding
        )
        # Offset of the padded bbox origin relative to the original bbox origin (column, row).
        # The analysis cell subtracts these from the spline coordinates to express them in the
        # padded crop's frame.
        added_left = bbox_padded[1] - bbox[1]
        added_top = bbox_padded[0] - bbox[0]

        # bbox_padded is indexed as (row_start, col_start, row_end, col_end), per the slicing below
        image_crop = image[
            bbox_padded[0] : bbox_padded[2],
            bbox_padded[1] : bbox_padded[3],
        ]
        # same mask for every grain of this file (loop-invariant; re-read per grain)
        full_grain_mask = file_data["grain_masks"]["above"]
        grains_dictionary[grain_index]["image"] = image_crop
        grains_dictionary[grain_index]["full_image"] = image
        grains_dictionary[grain_index]["bbox"] = bbox_padded
        grains_dictionary[grain_index]["added_left"] = added_left
        grains_dictionary[grain_index]["added_top"] = added_top
        grains_dictionary[grain_index]["padding"] = bbox_padding
        # later indexed as mask[:, :, channel], so presumably (H, W, C); TODO confirm
        mask_crop = full_grain_mask[
            bbox_padded[0] : bbox_padded[2],
            bbox_padded[1] : bbox_padded[3],
        ]
        grains_dictionary[grain_index]["mask"] = mask_crop
        grains_dictionary[grain_index]["filename"] = file_data["filename"]
        grains_dictionary[grain_index]["pixel_to_nm_scaling"] = file_data["pixel_to_nm_scaling"]
        grain_index += 1

# for grain_index, grain_data in grains_dictionary.items():
#     print(f"grain {grain_index}")
#     print(grain_data["filename"])
#     print(grain_data["pixel_to_nm_scaling"])
#     image = grain_data["image"]
#     plt.imshow(image)
#     for molecule_index, molecule_data in grain_data["molecule_data"].items():
#         ordered_coords = molecule_data["ordered_coords"]
#         plt.plot(ordered_coords[:, 1], ordered_coords[:, 0], "r")
#     plt.show()

#     mask = grain_data["mask"][:, :, 1]
#     plt.imshow(mask)
#     plt.show()

# total number of grains collected across all files
print(len(grains_dictionary))

In [None]:
# Indices into grains_dictionary of grains showing beaks (presumably picked by eye).
# These depend on the file order established when building grains_dictionary.
grains_with_beaks = [7, 13, 16, 26, 28, 30, 34, 36, 37, 38, 43, 45, 54, 55, 58, 65, 66, 71, 72, 79, 82, 89, 91, 92, 93]

# Subset of grains_dictionary restricted to the beaked grains, in the listed order.
grains_with_beaks_dictionary = {beak_index: grains_dictionary[beak_index] for beak_index in grains_with_beaks}

In [None]:
def resample_points_regular_interval(points: npt.NDArray, interval: float):
    """Resample an ordered polyline so consecutive output points are ``interval`` apart.

    Starting from ``points[0]``, original points closer than ``interval`` to the most
    recently emitted point are skipped. Once an original point at least ``interval`` away
    is found, a new point is placed ``interval`` along the straight line towards it.
    Distances are straight-line (chord) distances, not arc length along the original
    polyline. A trailing remainder shorter than ``interval`` is dropped.

    If the last emitted point lies within ``0.5 * interval`` of the first (the trace closes
    on itself), the last point is removed so a closed loop has no near-duplicate. This is
    only done when more than one point was emitted.

    Parameters
    ----------
    points : npt.NDArray
        Ordered coordinates, one point per row.
    interval : float
        Desired spacing between output points, in the same units as ``points``.

    Returns
    -------
    npt.NDArray
        Resampled points, one per row. Inputs with fewer than two points are returned as a copy.

    Raises
    ------
    ValueError
        If ``interval`` is not strictly positive (it would otherwise loop forever).
    """
    if interval <= 0:
        raise ValueError(f"interval must be positive, got {interval}")
    points = np.asarray(points)
    if len(points) < 2:
        return np.array(points)

    resampled_points = [points[0]]
    next_index = 1
    while next_index < len(points):
        current_point = resampled_points[-1]
        next_original_point = points[next_index]
        distance_to_next = np.linalg.norm(next_original_point - current_point)
        # too close: skip ahead to the next original point
        if distance_to_next < interval:
            next_index += 1
            continue
        # Same arithmetic as interpolate_between_two_points(current_point, next_original_point, interval);
        # inlined because distance_to_next >= interval is already guaranteed here.
        proportion = interval / distance_to_next
        resampled_points.append(current_point + proportion * (next_original_point - current_point))

    # Drop a near-duplicate closing point, but never the only point.
    if len(resampled_points) > 1 and np.linalg.norm(resampled_points[0] - resampled_points[-1]) < 0.5 * interval:
        resampled_points.pop()

    return np.array(resampled_points)

In [None]:
# For each hand-picked grain: resample every molecule's spline at a fixed spacing, compute
# discrete curvature, and highlight high-curvature points. Around moderately curved points,
# draw a locally interpolated trace with outward-pointing bisectors as a first step towards
# detecting beaks.

# --- analysis parameters (identical for every grain / molecule) ---
mask_thick_threshold = 2.0  # image values above this are shown as the "thick" overlay
min_splined_coords_num_points = 10  # ignore erroneously small splines
interval_nm = 5.0  # spacing of the resampled trace
circular_endpoint_distance_threshold = 2.0 * interval_nm  # endpoints closer than this => circular trace
curvature_threshold = 0.25  # |curvature| above this is marked as a defect (large red dot)
possible_beak_curvature_threshold = 0.2  # |curvature| above this is explored as a possible beak
possible_beak_points_interpolation_number = 20  # points in each local interpolated trace
slight_defect_explore_range_points = 3  # resampled neighbours taken either side of a possible beak
bisector_plot_length_px = 10  # drawn length of each outward bisector
curvature_cmap = plt.get_cmap("bwr")


def _wrapped_neighbourhood(points, centre_index, half_width):
    """Return the 2 * half_width + 1 rows of ``points`` centred on ``centre_index``, wrapping at the ends."""
    indices = np.arange(centre_index - half_width, centre_index + half_width + 1) % len(points)
    return points[indices]


def _interpolate_polyline(points, num_points):
    """Cubic-interpolate ``points`` at ``num_points`` evenly spaced fractions of their cumulative length.

    Returns None when consecutive points repeat: the normalised length parameter would then
    contain duplicate values, which cubic interp1d rejects.
    """
    segment_lengths = np.sqrt(np.sum(np.diff(points, axis=0) ** 2, axis=1))
    if np.any(segment_lengths == 0):
        return None
    distance = np.concatenate(([0.0], np.cumsum(segment_lengths)))
    interpolator = interp1d(distance / distance[-1], points, kind="cubic", axis=0)
    return interpolator(np.linspace(0, 1, num_points))


def _outward_bisectors(curve):
    """Unit bisectors at the interior points of ``curve``, pointing away from the bend.

    Returns (interior_points, directions). Points whose neighbours are (numerically) opposite,
    i.e. locally straight with an undefined bisector, or coincident, are dropped rather than
    producing NaNs.
    """
    centre = curve[1:-1]
    with np.errstate(invalid="ignore", divide="ignore"):
        to_prev = curve[:-2] - centre
        to_prev = to_prev / np.linalg.norm(to_prev, axis=1, keepdims=True)
        to_next = curve[2:] - centre
        to_next = to_next / np.linalg.norm(to_next, axis=1, keepdims=True)
        bisector = (to_prev + to_next) / 2
        lengths = np.linalg.norm(bisector, axis=1)
        valid = lengths > 1e-12
    # flip the bisector so it faces away from the turn
    return centre[valid], -bisector[valid] / lengths[valid, None]


for grain_index, grain_data in grains_with_beaks_dictionary.items():
    print(f"grain {grain_index}")

    image = grain_data["image"]
    mask_thick = image > mask_thick_threshold
    bbox = grain_data["bbox"]
    p2nm = grain_data["pixel_to_nm_scaling"]
    bbox_padding = grain_data["padding"]
    bbox_added_left = grain_data["added_left"]
    bbox_added_top = grain_data["added_top"]

    for molecule_index, molecule_data in grain_data["molecule_data"].items():
        # Spline coords are presumably relative to the original (unpadded) bbox. Subtracting the
        # stored offsets re-expresses them relative to the padded crop shown below.
        spline_px = molecule_data["spline_coords"].copy()
        spline_px[:, 0] -= bbox_added_top
        spline_px[:, 1] -= bbox_added_left
        spline_nm = spline_px * p2nm

        if len(spline_nm) < min_splined_coords_num_points:
            print(
                f"Skipping molecule {molecule_index} in grain {grain_index} as it has too few points ({len(spline_nm)} < {min_splined_coords_num_points})"
            )
            continue

        # resample the spline to get points at fixed intervals
        resampled_points_nm = resample_points_regular_interval(points=spline_nm, interval=interval_nm)
        resampled_points_px = resampled_points_nm / p2nm
        num_resampled = len(resampled_points_px)
        if num_resampled < 3:
            print(
                f"Skipping molecule {molecule_index} in grain {grain_index} as its resampled trace has only {num_resampled} points"
            )
            continue

        is_circular = bool(
            np.linalg.norm(resampled_points_nm[0] - resampled_points_nm[-1]) < circular_endpoint_distance_threshold
        )
        if is_circular:
            resampled_curvatures = discrete_angle_difference_per_nm_circular(resampled_points_nm)
        else:
            resampled_curvatures = discrete_angle_difference_per_nm_linear(resampled_points_nm)
        abs_curvatures = np.abs(resampled_curvatures)
        curvature_defects = np.flatnonzero(abs_curvatures > curvature_threshold)

        # --- molecule plot: image, thick-region overlay, resampled trace ---
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.imshow(image)
        ax.imshow(mask_thick, alpha=0.5)
        ax.plot(resampled_points_px[:, 1], resampled_points_px[:, 0], "r")
        # colour each resampled point by signed curvature, mapping [-0.5, 0.5] onto the colormap
        for (point_row, point_col), curvature in zip(resampled_points_px, resampled_curvatures):
            ax.plot(point_col, point_row, "o", c=curvature_cmap(curvature + 0.5), markersize=5)
        for defect_row, defect_col in resampled_points_px[curvature_defects]:
            ax.plot(defect_col, defect_row, "o", c="r", markersize=10)

        # --- explore regions of moderate curvature for possible beaks ---
        slight_curvature_defects = np.flatnonzero(abs_curvatures > possible_beak_curvature_threshold)
        for slight_defect_index in slight_curvature_defects:
            slight_row, slight_col = resampled_points_px[slight_defect_index]
            # mark each possible beak with a translucent blue circle
            ax.plot(slight_col, slight_row, "o", c="b", markersize=10, alpha=0.5)

            # Linear traces: the neighbourhood (with one point of margin either side) must stay in
            # bounds. Circular traces wrap around instead, but need enough points not to repeat any.
            if not is_circular:
                if slight_defect_index - slight_defect_explore_range_points - 1 < 0:
                    print(
                        f"slight defect index {slight_defect_index} - {slight_defect_explore_range_points + 1} < 0 in a noncircular trace, skipping"
                    )
                    continue
                if slight_defect_index + slight_defect_explore_range_points + 1 >= num_resampled:
                    print(
                        f"slight defect index {slight_defect_index} + {slight_defect_explore_range_points + 1} >= {num_resampled} in a noncircular trace, skipping"
                    )
                    continue
            elif num_resampled < 2 * slight_defect_explore_range_points + 1:
                print(
                    f"only {num_resampled} resampled points in circular trace, too few to explore slight defect index {slight_defect_index}, skipping"
                )
                continue

            print(f"slight defect index: {slight_defect_index}")

            possible_beak_points = _wrapped_neighbourhood(
                resampled_points_px, slight_defect_index, slight_defect_explore_range_points
            )
            interpolated_possible_beak_points = _interpolate_polyline(
                possible_beak_points, possible_beak_points_interpolation_number
            )
            if interpolated_possible_beak_points is None:
                print(f"repeated points around slight defect index {slight_defect_index}, skipping")
                continue
            ax.plot(interpolated_possible_beak_points[:, 1], interpolated_possible_beak_points[:, 0], "cyan")

            # draw the outward bisector at each interior interpolated point
            for point, direction in zip(*_outward_bisectors(interpolated_possible_beak_points)):
                tip = point + direction * bisector_plot_length_px
                ax.plot([point[1], tip[1]], [point[0], tip[0]], "g", alpha=0.5)

            # TODO: walk outward along each bisector (e.g. in ~0.2 px steps) until leaving the grain
            # mask (grain_data["mask"]) to measure how far a possible beak protrudes.

        ax.set_aspect("equal", adjustable="box")
        ax.set_title(
            f"grain {grain_index} molecule {molecule_index} bbox: {bbox} padding: {bbox_padding} circular: {is_circular}"
        )
        plt.show()
        plt.close(fig)

        # --- curvature profile ---
        fig, ax = plt.subplots()
        ax.plot(resampled_curvatures, label="discrete curvatures")
        ax.plot(abs_curvatures, label="discrete absolute curvatures")
        ax.legend()
        # NOTE: the lower limit clips signed curvatures below -0.1
        ax.set_ylim(-0.1, 0.5)
        ax.set_title(f"grain {grain_index} molecule {molecule_index} curvatures")
        plt.show()
        plt.close(fig)

In [None]:
# Demo: parametric interpolation of a 2-D polyline with interp1d.
# Interpolating y as a function of x fails when the curve doubles back in x, so both
# coordinates are instead interpolated against the normalised cumulative distance s in [0, 1].
import numpy as np
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt

# Define some points as an (n_points, n_dims) array:
points = np.array([[0, 1, 8, 2, 2], [1, 0, 6, 7, 2]]).T
print(points.shape)

# Linear length along the line, with the initial distance of 0 prepended, normalised to [0, 1]
segment_lengths = np.sqrt(np.sum(np.diff(points, axis=0) ** 2, axis=1))
distance = np.concatenate(([0.0], np.cumsum(segment_lengths)))
distance /= distance[-1]

# Interpolate with different methods at evenly spaced values of s
interpolation_methods = ["slinear", "quadratic", "cubic"]
alpha = np.linspace(0, 1, 10)

interpolated_points = {
    method: interp1d(distance, points, kind=method, axis=0)(alpha) for method in interpolation_methods
}

# Graph:
fig, ax = plt.subplots(figsize=(7, 7))
for method_name, curve in interpolated_points.items():
    ax.plot(*curve.T, "-", label=method_name)
ax.plot(*points.T, "ok", label="original points")
ax.axis("equal")
ax.legend()
ax.set(xlabel="x", ylabel="y", title="Parametric interpolation of a 2-D polyline")
plt.show()