In [None]:
import os
import pickle
import shutil
import sys
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from topostats.io import LoadScans
from topostats.unet_masking import make_bounding_box_square, pad_bounding_box
from topostats.measure.curvature import (
    discrete_angle_difference_per_nm_circular,
    discrete_angle_difference_per_nm_linear,
)

# Timestamp of this notebook run, formatted as D-<year>-<month>-<day>-T-<hour>-<minute>.
today = datetime.now().strftime("D-%Y-%m-%d-T-%H-%M")
print(today)

In [None]:
# Locate the processed TopoStats output for both samples.
# NOTE(review): this is an absolute, machine-specific path; adjust it when running elsewhere.
base_dir = Path("/Users/sylvi/topo_data/beaks")
assert base_dir.exists(), f"base directory not found: {base_dir}"
beak_topo_data_dir = base_dir / "output-beaks-topostats-unet-good"
assert beak_topo_data_dir.exists(), f"TopoStats output directory not found: {beak_topo_data_dir}"

hummingbird_dir = beak_topo_data_dir / "hummingbird/processed/"
assert hummingbird_dir.exists(), f"hummingbird directory not found: {hummingbird_dir}"
magpie_dir = beak_topo_data_dir / "magpie/processed/"
assert magpie_dir.exists(), f"magpie directory not found: {magpie_dir}"

# Grab files in both directories ending in .topostats.
# glob() gives no ordering guarantee, so sort to keep the file order (and therefore
# the grain numbering derived from it downstream) reproducible between runs.
hummingbird_files = sorted(hummingbird_dir.glob("*.topostats"))
magpie_files = sorted(magpie_dir.glob("*.topostats"))
# merge lists (hummingbird files first, then magpie files)
all_files = hummingbird_files + magpie_files
print(f"Found {len(all_files)} files")

In [None]:
def interpolate_between_two_points(point1, point2, distance):
    """Return the point lying ``distance`` along the straight line from ``point1`` towards ``point2``.

    Parameters
    ----------
    point1 : array-like
        Start point. Lists, tuples and numpy arrays are all accepted.
    point2 : array-like
        End point, with the same number of components as ``point1``.
    distance : float
        Distance from ``point1`` (same units as the coordinates) at which to place the new point.

    Returns
    -------
    numpy.ndarray
        Float array holding the coordinates of the interpolated point.

    Raises
    ------
    ValueError
        If the two points are closer together than ``distance``.
    """
    # Coerce to float arrays so plain sequences work and the arithmetic is always floating point.
    point1 = np.asarray(point1, dtype=float)
    point2 = np.asarray(point2, dtype=float)
    direction = point2 - point1
    distance_between_points = np.linalg.norm(direction)
    if distance_between_points < distance:
        raise ValueError("distance between points is less than the desired interval")
    if distance_between_points == 0:
        # Coincident points: the proportion would be 0/0 (NaN), and the only
        # point on a zero-length segment is the start point itself.
        return point1.copy()
    proportion = distance / distance_between_points
    return point1 + proportion * direction

## construct grain dictionary

In [None]:
# Flatten the per-file TopoStats results into one dictionary keyed by a running
# grain index. Each entry holds the per-molecule traces plus a padded square
# crop of the image and grain mask around the grain.
grains_dictionary = {}

loadscans = LoadScans(all_files, channel="dummy")
loadscans.get_data()
img_dict = loadscans.img_dict

# padding passed to pad_bounding_box when cropping around each grain
bbox_padding = 10
grain_index = 0
for filename, file_data in img_dict.items():
    print(f"getting data from {filename}")
    image = file_data["image"]
    ordered_trace_data = file_data["ordered_traces"]["above"]
    for current_grain_index, grain_ordered_trace_data in ordered_trace_data.items():
        print(f"  grain {current_grain_index}")

        # Reset per grain: the bounding box is read from the molecule entries, so
        # without a reset a grain with no molecules would silently reuse the
        # previous grain's box (or raise NameError on the very first grain).
        bbox = None
        grain_molecule_data = {}
        for current_molecule_index, molecule_ordered_trace_data in grain_ordered_trace_data.items():
            molecule_data = {}
            molecule_data["ordered_coords"] = molecule_ordered_trace_data["ordered_coords"]
            molecule_data["heights"] = molecule_ordered_trace_data["heights"]
            molecule_data["distances"] = molecule_ordered_trace_data["distances"]
            # The bbox of the last molecule iterated is the one used for the grain crop below.
            bbox = molecule_ordered_trace_data["bbox"]

            splining_coords = file_data["splining"]["above"][current_grain_index][current_molecule_index][
                "spline_coords"
            ]
            molecule_data["spline_coords"] = splining_coords
            grain_molecule_data[current_molecule_index] = molecule_data

        if bbox is None:
            print(f"  skipping grain {current_grain_index}: no molecules, so no bounding box")
            continue

        bbox_square = make_bounding_box_square(bbox[0], bbox[1], bbox[2], bbox[3], image.shape)
        bbox_padded = pad_bounding_box(
            bbox_square[0], bbox_square[1], bbox_square[2], bbox_square[3], image.shape, padding=bbox_padding
        )
        # Offset of the padded box origin relative to the original box origin.
        # Index 0 / 1 are used as the row / column start of the crop below.
        added_left = bbox_padded[1] - bbox[1]
        added_top = bbox_padded[0] - bbox[0]

        image_crop = image[
            bbox_padded[0] : bbox_padded[2],
            bbox_padded[1] : bbox_padded[3],
        ]
        full_grain_mask = file_data["grain_masks"]["above"]
        mask_crop = full_grain_mask[
            bbox_padded[0] : bbox_padded[2],
            bbox_padded[1] : bbox_padded[3],
        ]

        grains_dictionary[grain_index] = {
            "molecule_data": grain_molecule_data,
            "image": image_crop,
            "full_image": image,
            "bbox": bbox_padded,
            "added_left": added_left,
            "added_top": added_top,
            "padding": bbox_padding,
            "mask": mask_crop,
            "filename": file_data["filename"],
            "pixel_to_nm_scaling": file_data["pixel_to_nm_scaling"],
        }
        grain_index += 1

# Optional visual sanity check of the extracted grains (uncomment to use):
# for grain_index, grain_data in grains_dictionary.items():
#     print(f"grain {grain_index}")
#     print(grain_data["filename"])
#     print(grain_data["pixel_to_nm_scaling"])
#     image = grain_data["image"]
#     plt.imshow(image)
#     for molecule_index, molecule_data in grain_data["molecule_data"].items():
#         ordered_coords = molecule_data["ordered_coords"]
#         plt.plot(ordered_coords[:, 1], ordered_coords[:, 0], "r")
#     plt.show()

#     mask = grain_data["mask"][:, :, 1]
#     plt.imshow(mask)
#     plt.show()

print(len(grains_dictionary))

In [None]:
# For each grain: resample every molecule's spline at a fixed spacing, compute the
# discrete curvature at the resampled points, and plot the points over the grain
# crop (coloured by curvature, with above-threshold points marked as defects).

# Analysis parameters, defined once instead of on every loop iteration.
# ignore erroneously small splines
min_splined_coords_num_points = 10
# spacing between consecutive resampled points along the spline
interval_nm = 5
# resampled points whose absolute curvature exceeds this value are marked as defects
curvature_threshold = 0.30
curvature_cmap = plt.get_cmap("bwr")

for grain_index, grain_data in grains_dictionary.items():
    print(f"grain {grain_index}")

    full_image = grain_data["full_image"]
    image = grain_data["image"]
    mask = grain_data["mask"]
    bbox = grain_data["bbox"]
    p2nm = grain_data["pixel_to_nm_scaling"]
    bbox_padding = grain_data["padding"]
    bbox_added_left = grain_data["added_left"]
    bbox_added_top = grain_data["added_top"]

    plt.imshow(full_image)
    plt.show()

    for molecule_index, molecule_data in grain_data["molecule_data"].items():
        # Work on a copy: shifting the stored array in place would shift it again
        # every time this cell is re-run and would also alter the loaded data.
        splined_coords_px = np.array(molecule_data["spline_coords"])
        splined_coords_px[:, 0] -= bbox_added_top
        splined_coords_px[:, 1] -= bbox_added_left
        splined_coords_nm = splined_coords_px * p2nm

        if len(splined_coords_nm) < min_splined_coords_num_points:
            print(
                f"Skipping molecule {molecule_index} in grain {grain_index} as it has too few points ({len(splined_coords_nm)} < {min_splined_coords_num_points})"
            )
            continue

        # resample the spline to get points at fixed intervals
        resampled_points_nm = [splined_coords_nm[0]]
        current_spline_index = 1
        while True:
            current_point = resampled_points_nm[-1]
            next_splined_point = splined_coords_nm[current_spline_index]
            distance_to_next_splined_point = np.linalg.norm(next_splined_point - current_point)
            # if the distance to the next splined point is less than the interval, then skip to the next point
            if distance_to_next_splined_point < interval_nm:
                current_spline_index += 1
                if current_spline_index >= len(splined_coords_nm):
                    break
                continue
            # otherwise step exactly one interval from the last resampled point towards that spline point
            new_interpolated_point = interpolate_between_two_points(current_point, next_splined_point, interval_nm)
            resampled_points_nm.append(new_interpolated_point)

        resampled_points_nm = np.array(resampled_points_nm)

        resampled_curvatures = discrete_angle_difference_per_nm_linear(resampled_points_nm)
        curvature_defects = np.where(np.abs(resampled_curvatures) > curvature_threshold)[0]

        # Convert to pixel coordinates once, into a separate array, so the points can be
        # drawn over the image crop. Dividing the rows of resampled_points_nm in place
        # would leave it in pixels and make the defect markers below be rescaled twice.
        resampled_points_px = resampled_points_nm / p2nm

        # plot the molecule
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(image)
        for resampled_point_index, resampled_point in enumerate(resampled_points_px):
            curvature = resampled_curvatures[resampled_point_index]
            # map curvature from [-0.5, 0.5] onto the colormap's [0, 1] range
            colour = (curvature + 0.5) / 1
            ax.plot(
                resampled_point[1],
                resampled_point[0],
                "o",
                c=curvature_cmap(colour),
                markersize=5,
            )

        # larger red markers on the points flagged as curvature defects
        for defect_index in curvature_defects:
            defect_point = resampled_points_px[defect_index]
            ax.plot(
                defect_point[1],
                defect_point[0],
                "o",
                c="r",
                markersize=10,
            )

        ax.set_aspect("equal", adjustable="box")
        ax.set_title(f"grain {grain_index} molecule {molecule_index} bbox: {bbox} padding: {bbox_padding}")
        plt.show()
        # release the figure so long runs do not accumulate open figures
        plt.close(fig)

        # plot curvatures
        curvature_fig, curvature_ax = plt.subplots()
        curvature_ax.plot(resampled_curvatures, label="discrete curvatures")
        curvature_ax.plot(np.abs(resampled_curvatures), label="discrete absolute curvatures")
        curvature_ax.legend()
        curvature_ax.set_ylim(-0.1, 0.5)
        curvature_ax.set_title(f"grain {grain_index} molecule {molecule_index} curvatures")
        plt.show()
        plt.close(curvature_fig)

    # only inspect the first grains; the check runs after plotting, so grain 21 is the last one shown
    if grain_index > 20:
        break