In [None]:
import os
import pickle
import shutil
import sys
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from topostats.io import LoadScans
from topostats.measure.curvature import \
    discrete_angle_difference_per_nm_circular

# Run timestamp of the form "D-YYYY-MM-DD-T-HH-MM" (local time), echoed for the log.
today = f"{datetime.today():D-%Y-%m-%d-T-%H-%M}"
print(today)

In [None]:
# Root of the beaks dataset. Set the BEAKS_DATA_DIR environment variable to run this notebook on
# another machine; the default is the original author's local path.
base_dir = Path(os.environ.get("BEAKS_DATA_DIR", "/Users/sylvi/topo_data/beaks"))
assert base_dir.exists(), f"data directory not found: {base_dir}"
beak_topo_data_dir = base_dir / "output-beaks-topostats-unet-good"
assert beak_topo_data_dir.exists(), f"TopoStats output directory not found: {beak_topo_data_dir}"

hummingbird_dir = beak_topo_data_dir / "hummingbird/processed/"
assert hummingbird_dir.exists(), f"hummingbird directory not found: {hummingbird_dir}"
magpie_dir = beak_topo_data_dir / "magpie/processed/"
assert magpie_dir.exists(), f"magpie directory not found: {magpie_dir}"

# Grab the .topostats files in both directories. Path.glob yields files in an arbitrary,
# filesystem-dependent order, so sort them: positional picks such as all_files[1] must refer to
# the same file on every run and every machine.
hummingbird_files = sorted(hummingbird_dir.glob("*.topostats"))
magpie_files = sorted(magpie_dir.glob("*.topostats"))
all_files = hummingbird_files + magpie_files
print(f"Found {len(all_files)} files")

In [None]:
def interpolate_between_two_points(point1, point2, distance):
    """Return the point ``distance`` along the straight line from ``point1`` towards ``point2``.

    Parameters
    ----------
    point1, point2 : array-like
        Start and end points with the same number of coordinates (numpy arrays, lists or tuples).
    distance : float
        Non-negative distance to travel from ``point1``, in the same units as the coordinates.

    Returns
    -------
    numpy.ndarray
        The interpolated point. It equals ``point2`` (up to floating-point rounding) when
        ``distance`` is the full distance between the points.

    Raises
    ------
    ValueError
        If ``distance`` is negative, or greater than the distance between the two points.
    """
    point1 = np.asarray(point1)
    point2 = np.asarray(point2)
    if distance < 0:
        raise ValueError("distance must be non-negative")
    segment = point2 - point1
    distance_between_points = np.linalg.norm(segment)
    if distance_between_points < distance:
        raise ValueError("distance between points is less than the desired interval")
    # A zero-length segment (only reachable here with distance == 0) would otherwise give 0/0 = NaN.
    if distance_between_points == 0:
        return point1 + 0.0 * segment
    proportion = distance / distance_between_points
    return point1 + proportion * segment

In [None]:
# Inspect a single scan: the second entry of the combined hummingbird + magpie file list.
file = all_files[1]

# Read it with TopoStats. NOTE(review): channel="dummy" looks like a placeholder value for
# .topostats input — confirm.
loadscans = LoadScans([file], channel="dummy")
loadscans.get_data()
img_dict = loadscans.img_dict
# Results are looked up by the file's stem.
data = img_dict[file.stem]
print(data.keys())

# Full image plus the per-grain splining results stored under the "above" key.
image, splining_data = data["image"], data["splining"]["above"]

plt.imshow(image)
plt.show()

# Spacing between resampled spline points (loop-invariant, so defined once here).
# NOTE(review): this is in the units of `spline_coords` (pixels?), while the curvature function is
# named "per_nm" — confirm the two agree.
interval = 5
# Reference curvature level; drawn as a dashed line on the final curvature plot of each molecule.
curvature_threshold = 0.35


def _normalise_to_unit_range(values):
    """Linearly rescale ``values`` to [0, 1].

    A constant input maps to all zeros instead of dividing by zero. Dividing by zero would yield
    NaN colour components, which matplotlib rejects. An empty input is returned unchanged (as floats).
    """
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return values
    value_min = values.min()
    value_range = values.max() - value_min
    if value_range == 0:
        return np.zeros_like(values)
    return (values - value_min) / value_range


def _curvature_colours(normalised_curvatures):
    """Map each normalised curvature c in [0, 1] to the RGB row (c, 1 - c, c)."""
    c = np.asarray(normalised_curvatures, dtype=float)
    return np.column_stack([c, 1 - c, c])


def _resample_at_fixed_interval(coords, step):
    """Walk along ``coords`` emitting points spaced ``step`` apart in straight-line distance.

    Starts at ``coords[0]``. For each following spline point, steps towards it in increments of
    ``step`` until it is closer than ``step``, then moves on to the next spline point. Returns an
    array of the resampled points; an input with fewer than two points is returned as-is (copied).

    Raises ValueError for a non-positive ``step``, which would otherwise never terminate.
    """
    if step <= 0:
        raise ValueError("step must be positive")
    coords = np.asarray(coords)
    if len(coords) < 2:
        return coords.copy()
    resampled = [coords[0]]
    next_index = 1
    while next_index < len(coords):
        current_point = resampled[-1]
        target_point = coords[next_index]
        if np.linalg.norm(target_point - current_point) < step:
            # Too close for a full step: move on to the next spline point.
            next_index += 1
            continue
        resampled.append(interpolate_between_two_points(current_point, target_point, step))
    return np.array(resampled)


for grain_index, grain_splining_data in splining_data.items():
    for molecule_index, molecule_splining_data in grain_splining_data.items():
        molecule_label = f"grain {grain_index}, molecule {molecule_index}"
        splined_coords = np.asarray(molecule_splining_data["spline_coords"])
        bbox = molecule_splining_data["bbox"]
        curvatures = data["grain_curvature_stats"]["above"][grain_index][molecule_index]

        if len(splined_coords) < 2:
            print(f"{molecule_label}: fewer than two spline points, skipping")
            continue
        # Curvatures are indexed per spline point below; report a length mismatch instead of
        # crashing on an out-of-range index.
        if len(curvatures) != len(splined_coords):
            print(f"{molecule_label}: {len(curvatures)} curvatures vs {len(splined_coords)} spline points")
        # Normalise once per molecule (not once per plotted point).
        normalised_curvatures = _normalise_to_unit_range(curvatures)

        # Cropped grain with the spline drawn segment by segment, grey level = curvature.
        # NOTE(review): spline coords are drawn without a bbox offset, so they appear to be
        # relative to the crop — confirm.
        grain_image = image[bbox[0] : bbox[2], bbox[1] : bbox[3]]
        fig, ax = plt.subplots()
        ax.imshow(grain_image)
        n_segments = min(len(splined_coords) - 1, len(curvatures))
        for segment_index in range(n_segments):
            start, end = splined_coords[segment_index], splined_coords[segment_index + 1]
            grey = normalised_curvatures[segment_index]
            ax.plot([start[1], end[1]], [start[0], end[0]], c=(grey, grey, grey))
        ax.set_title(f"{molecule_label}: spline coloured by curvature")
        plt.show()

        fig, ax = plt.subplots()
        ax.plot(curvatures)
        ax.set(title=f"{molecule_label}: curvature along spline", xlabel="spline point index", ylabel="curvature")
        plt.show()

        # Zoomed view of every spline point, coloured by its curvature.
        n_points = min(len(splined_coords), len(curvatures))
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.scatter(
            splined_coords[:n_points, 1],
            splined_coords[:n_points, 0],
            c=_curvature_colours(normalised_curvatures[:n_points]),
            marker=".",
            s=1,
        )
        ax.invert_yaxis()
        ax.set_aspect("equal", adjustable="box")
        ax.set_title(f"{molecule_label}: spline points")
        plt.show()

        # Resample the spline to points at fixed intervals and recompute curvature on those.
        resampled_points = _resample_at_fixed_interval(splined_coords, interval)
        if len(resampled_points) < 3:
            # Turning angles need at least three points.
            print(f"{molecule_label}: only {len(resampled_points)} resampled points, skipping")
            continue
        resampled_curvatures = discrete_angle_difference_per_nm_circular(resampled_points)
        normalised_resampled_curvatures = _normalise_to_unit_range(resampled_curvatures)

        n_resampled = min(len(resampled_points), len(resampled_curvatures))
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.scatter(
            resampled_points[:n_resampled, 1],
            resampled_points[:n_resampled, 0],
            c=_curvature_colours(normalised_resampled_curvatures[:n_resampled]),
            marker="o",
            s=25,
        )
        ax.invert_yaxis()
        ax.set_aspect("equal", adjustable="box")
        ax.set_title(f"{molecule_label}: resampled points (interval {interval})")
        plt.show()

        fig, ax = plt.subplots()
        ax.plot(resampled_curvatures, label="discrete curvatures")
        ax.plot(np.abs(resampled_curvatures), label="discrete absolute curvatures")
        ax.axhline(curvature_threshold, color="k", linestyle="--", label="curvature threshold")
        ax.set(title=f"{molecule_label}: resampled curvature", xlabel="resampled point index")
        ax.legend()
        ax.set_ylim(-0.1, 0.5)
        plt.show()