In [None]:
import os
import pickle
import shutil
import sys
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from topostats.io import LoadScans
from topostats.measure.curvature import (
    discrete_angle_difference_per_nm_circular,
    discrete_angle_difference_per_nm_linear,
)

# Timestamp label for this run, e.g. "D-2024-01-31-T-14-05" (date, then 24h hour-minute).
today = f"{datetime.now():D-%Y-%m-%d-T-%H-%M}"
print(today)

In [None]:
# Data locations. The dataset root defaults to a hard-coded local path; set the
# BEAKS_DATA_DIR environment variable to point at a copy elsewhere instead of editing it.
base_dir = Path(os.environ.get("BEAKS_DATA_DIR", "/Users/sylvi/topo_data/beaks"))
assert base_dir.exists(), f"data directory not found: {base_dir}"
beak_topo_data_dir = base_dir / "output-beaks-topostats-unet-good"
assert beak_topo_data_dir.exists(), f"TopoStats output directory not found: {beak_topo_data_dir}"

hummingbird_dir = beak_topo_data_dir / "hummingbird/processed/"
assert hummingbird_dir.exists(), f"directory not found: {hummingbird_dir}"
magpie_dir = beak_topo_data_dir / "magpie/processed/"
assert magpie_dir.exists(), f"directory not found: {magpie_dir}"

# .topostats files from both directories. Path.glob yields entries in filesystem order,
# so sort them to make the file order (and anything numbered from it downstream)
# reproducible across machines and runs.
hummingbird_files = sorted(hummingbird_dir.glob("*.topostats"))
magpie_files = sorted(magpie_dir.glob("*.topostats"))
all_files = hummingbird_files + magpie_files
print(f"Found {len(all_files)} files")

In [None]:
def interpolate_between_two_points(point1, point2, distance):
    """Return the point ``distance`` along the straight line from ``point1`` towards ``point2``.

    Parameters
    ----------
    point1, point2 : array-like
        Start and end coordinates of the same shape. Lists/tuples are accepted as well as arrays.
    distance : float
        How far to move from ``point1`` towards ``point2``, in the units of the coordinates.

    Returns
    -------
    np.ndarray
        Float array ``point1 + (distance / |point2 - point1|) * (point2 - point1)``. If the two
        points coincide (only possible when ``distance <= 0``) a copy of ``point1`` is returned.

    Raises
    ------
    ValueError
        If ``distance`` is greater than the distance between the two points.
    """
    start = np.asarray(point1, dtype=float)
    end = np.asarray(point2, dtype=float)
    direction = end - start
    distance_between_points = np.linalg.norm(direction)
    if distance_between_points < distance:
        raise ValueError("distance between points is less than the desired interval")
    if distance_between_points == 0:
        # No direction to move in; avoid the 0/0 division (which would produce NaNs).
        return start.copy()
    proportion = distance / distance_between_points
    return start + proportion * direction

In [None]:
# Build a flat per-grain dictionary across all files. For each grain it holds the ordered-trace
# data of its molecules, the image and grain-mask crops around it, and per-file metadata.
# Keys are a running integer over all files (not the per-file grain numbers).
grains_dictionary = {}

loadscans = LoadScans(all_files, channel="dummy")
loadscans.get_data()
img_dict = loadscans.img_dict

grain_index = 0
for filename, file_data in img_dict.items():
    print(f"getting data from {filename}")
    image = file_data["image"]
    ordered_trace_data = file_data["ordered_traces"]["above"]
    for current_grain_index, grain_ordered_trace_data in ordered_trace_data.items():
        print(f"  grain {current_grain_index}")
        molecule_data_by_index = {}
        bbox = None
        for molecule_index, molecule_ordered_trace_data in grain_ordered_trace_data.items():
            molecule_data = {
                key: molecule_ordered_trace_data[key] for key in ("ordered_coords", "heights", "distances")
            }
            molecule_data_by_index[molecule_index] = molecule_data
            # The crops below use the bbox of the last molecule in the grain
            # (presumably all molecules of a grain share one bbox - TODO confirm).
            bbox = molecule_ordered_trace_data["bbox"]
        if bbox is None:
            # Without any molecule there is no bbox to crop with; skip the grain rather than
            # reuse a bbox left over from the previous grain.
            print(f"  skipping grain {current_grain_index}: no molecules")
            continue
        full_grain_mask = file_data["grain_masks"]["above"]
        grains_dictionary[grain_index] = {
            "molecule_data": molecule_data_by_index,
            "image": image[bbox[0] : bbox[2], bbox[1] : bbox[3]],
            "bbox": bbox,
            "mask": full_grain_mask[bbox[0] : bbox[2], bbox[1] : bbox[3]],
            "filename": file_data["filename"],
            "pixel_to_nm_scaling": file_data["pixel_to_nm_scaling"],
        }
        grain_index += 1

print(len(grains_dictionary))

In [None]:
# Curvature-defect detection. For every file: resample each molecule's spline at a fixed
# spacing, compute the discrete curvature at every resampled point, and flag points whose
# absolute curvature exceeds a threshold. Flagged points, plus a box around each molecule
# that has any, are drawn over the whole image.

# Set to True to also show, per molecule, a zoomed view and its curvature profile.
individual_plotting = False
# Ignore erroneously small splines: molecules with fewer spline points are skipped.
min_splined_coords_num_points = 10
# Spacing between resampled points, in the units of "spline_coords". They are drawn directly
# over image crops, so presumably pixels - TODO confirm.
interval = 5
# Absolute curvature above which a resampled point is flagged as a defect.
curvature_threshold = 0.30


def resample_at_fixed_interval(coords, spacing):
    """Resample an ordered trace so consecutive output points are ``spacing`` apart.

    Starting from ``coords[0]``, repeatedly steps ``spacing`` along the straight line towards the
    next trace point that is at least ``spacing`` away from the last output point (straight-line
    distance), skipping trace points closer than that. A remainder shorter than ``spacing`` at
    the end of the trace is dropped.

    Parameters
    ----------
    coords : np.ndarray
        Ordered trace points, one (row, col) point per row.
    spacing : float
        Distance between consecutive output points, in the units of ``coords``.

    Returns
    -------
    np.ndarray
        Resampled points, one per row; the first row is ``coords[0]``.

    Raises
    ------
    ValueError
        If ``spacing`` is not positive (stepping would never advance and loop forever).
    """
    if spacing <= 0:
        raise ValueError(f"spacing must be positive, got {spacing}")
    resampled = [coords[0]]
    next_index = 1
    while next_index < len(coords):
        current_point = resampled[-1]
        next_point = coords[next_index]
        if np.linalg.norm(next_point - current_point) < spacing:
            # Next trace point is less than one step away: move on along the trace.
            next_index += 1
            continue
        resampled.append(interpolate_between_two_points(current_point, next_point, spacing))
    return np.array(resampled)


def _plot_molecule_curvature(grain_image, resampled_points, resampled_curvatures, curvature_defects, threshold):
    """Show one molecule's resampled trace coloured by curvature, then its curvature profile.

    Points use the "bwr" colormap with curvature -0.5..0.5 mapped onto 0..1 (values outside map
    to the end colours); defect points are overdrawn as large red markers.
    """
    cmap = plt.get_cmap("bwr")
    _, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(grain_image)
    for point_index, point in enumerate(resampled_points):
        ax.plot(point[1], point[0], "o", c=cmap(resampled_curvatures[point_index] + 0.5), markersize=5)
    for defect_index in curvature_defects:
        defect_point = resampled_points[defect_index]
        ax.plot(defect_point[1], defect_point[0], "o", c="r", markersize=10)
    # No invert_yaxis(): imshow already puts row 0 at the top, which keeps this view
    # oriented the same way as the whole-image overview.
    ax.set_aspect("equal", adjustable="box")
    ax.set_title("resampled trace coloured by curvature")
    plt.show()

    _, ax = plt.subplots()
    ax.plot(resampled_curvatures, label="discrete curvatures")
    ax.plot(np.abs(resampled_curvatures), label="discrete absolute curvatures")
    ax.axhline(threshold, color="r", linestyle="--", label="defect threshold")
    ax.set_xlabel("resampled point index")
    ax.set_ylabel("discrete curvature")
    ax.set_ylim(-0.1, 0.5)
    ax.legend()
    plt.show()


for file in all_files:
    # Load one file at a time.
    loadscans = LoadScans([file], channel="dummy")
    loadscans.get_data()
    img_dict = loadscans.img_dict
    data = img_dict[file.stem]

    image = data["image"]
    splining_data = data["splining"]["above"]

    # (bbox, resampled points, defect indices) for every molecule with at least one defect.
    molecules_with_defects = []

    for grain_index, grain_splining_data in splining_data.items():
        for molecule_index, molecule_splining_data in grain_splining_data.items():
            splined_coords = molecule_splining_data["spline_coords"]

            if len(splined_coords) < min_splined_coords_num_points:
                print(
                    f"Skipping molecule {molecule_index} in grain {grain_index} as it has too few points ({len(splined_coords)} < {min_splined_coords_num_points})"
                )
                continue

            # Spline coordinates are relative to this bbox within the full image.
            bbox = molecule_splining_data["bbox"]

            resampled_points = resample_at_fixed_interval(splined_coords, interval)

            # NOTE(review): the function name says "per nm" but resampled_points are not scaled
            # by pixel_to_nm_scaling, so the threshold is effectively in per-pixel units - confirm.
            resampled_curvatures = discrete_angle_difference_per_nm_linear(resampled_points)
            curvature_defects = np.where(np.abs(resampled_curvatures) > curvature_threshold)[0]

            if individual_plotting:
                grain_image = image[bbox[0] : bbox[2], bbox[1] : bbox[3]]
                _plot_molecule_curvature(
                    grain_image, resampled_points, resampled_curvatures, curvature_defects, curvature_threshold
                )

            if len(curvature_defects) > 0:
                molecules_with_defects.append((bbox, resampled_points, curvature_defects))

    # Build the overview only after any per-molecule figures have been shown, so their
    # plt.show() calls cannot display/close it half-drawn or redirect plotting elsewhere.
    _, overview_ax = plt.subplots()
    overview_ax.imshow(image)
    overview_ax.set_title(file.stem)
    for defect_bbox, defect_trace, defect_indices in molecules_with_defects:
        defect_points = defect_trace[defect_indices]
        # Shift bbox-relative coordinates into whole-image coordinates.
        overview_ax.plot(
            defect_points[:, 1] + defect_bbox[1],
            defect_points[:, 0] + defect_bbox[0],
            "o",
            c="red",
            markersize=2,
            alpha=1,
        )
        # Red box around the molecule's bbox.
        overview_ax.plot(
            [defect_bbox[1], defect_bbox[3], defect_bbox[3], defect_bbox[1], defect_bbox[1]],
            [defect_bbox[0], defect_bbox[0], defect_bbox[2], defect_bbox[2], defect_bbox[0]],
            "r-",
            alpha=0.5,
        )
    plt.show()