In [None]:
import os
from itertools import combinations
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt

from topostats.io import LoadScans
from topostats.plottingfuncs import Colormap
from topostats.utils import convolve_skeleton

# Shared colour settings for the grain-image plots: fixed display limits plus TopoStats' own colormap.
# NOTE(review): units of vmin/vmax are presumably the image's height units — confirm.
vmin, vmax = -3.0, 4.0
colormap = Colormap()
cmap = colormap.get_cmap()


def clear_output():
    """Clear the output area of the currently executing notebook cell.

    IPython is imported inside the function so it is only required at call time,
    not when this cell is defined.
    """
    import IPython.display

    IPython.display.clear_output()


def load_data(dir: Path) -> dict[str, Any]:
    """Load every ``.topostats`` file found directly inside ``dir``.

    Parameters
    ----------
    dir : Path
        Directory searched (non-recursively) for ``*.topostats`` files. A ``str`` is also accepted.

    Returns
    -------
    dict[str, Any]
        ``LoadScans.img_dict`` after ``get_data()`` has run; presumably keyed by file name — confirm
        against ``LoadScans``.

    Raises
    ------
    FileNotFoundError
        If ``dir`` does not exist or is not a directory. Without this check ``Path.glob`` would simply
        match nothing and an empty result would be returned silently.
    """
    dir = Path(dir)
    if not dir.is_dir():
        raise FileNotFoundError(f"Data directory does not exist or is not a directory: {dir}")
    # Sort so files are always loaded in the same order; raw glob order depends on the filesystem.
    files = sorted(dir.glob("*.topostats"))
    loader = LoadScans(files, channel="dummy")
    loader.get_data()
    # Wipe whatever the loader printed into the cell so only this notebook's own output remains.
    clear_output()
    return loader.img_dict

In [None]:
# Directory holding the `.topostats` files to analyse. Set the TOPOSTATS_DATA_DIR environment variable to
# point elsewhere instead of editing this machine-specific default path.
dir_raw_data = Path(
    os.environ.get("TOPOSTATS_DATA_DIR", "/Users/sylvi/topo_data/connect-loose-ends/data")
).expanduser()
loaded_files = load_data(dir_raw_data)
print(f"Loaded {len(loaded_files)} files from {dir_raw_data}")

In [None]:
# For every grain: take its skeleton, find the skeleton endpoints, list all endpoint pairs by distance (nm)
# and plot the grain with its mask and convolved skeleton overlaid.

# Value that `convolve_skeleton` assigns to endpoint pixels in its output.
ENDPOINT_VALUE = 2


def compute_endpoint_pair_distances(
    endpoints: npt.NDArray[np.int_], p2nm: float
) -> list[tuple[npt.NDArray[np.int_], npt.NDArray[np.int_], float]]:
    """Return every unordered pair of endpoints with its distance in nm, nearest pair first.

    Parameters
    ----------
    endpoints : npt.NDArray[np.int_]
        Endpoint pixel coordinates as returned by ``np.argwhere`` (one row per endpoint).
    p2nm : float
        Pixel-to-nanometre scaling factor.

    Returns
    -------
    list[tuple[npt.NDArray[np.int_], npt.NDArray[np.int_], float]]
        ``(endpoint_1, endpoint_2, distance_nm)`` tuples in ascending order of distance. Empty when
        fewer than two endpoints are given.
    """
    pair_distances = [
        (endpoint_1, endpoint_2, float(np.linalg.norm((endpoint_1 - endpoint_2) * p2nm)))
        for endpoint_1, endpoint_2 in combinations(endpoints, 2)
    ]
    pair_distances.sort(key=lambda pair: pair[2])
    return pair_distances


def plot_grain_overlay(
    grain_image: npt.NDArray, grain_mask: npt.NDArray, convolved_skeleton: npt.NDArray, title: str
) -> None:
    """Show a grain image with its mask and convolved skeleton drawn on top.

    Zero-valued mask and skeleton pixels are masked out so only the grain and its skeleton are overlaid.
    """
    fig, ax = plt.subplots(figsize=(15, 15))
    ax.imshow(grain_image, cmap=cmap, vmin=vmin, vmax=vmax)
    ax.imshow(np.ma.masked_where(grain_mask == 0, grain_mask), cmap="Blues_r", alpha=0.3)
    ax.imshow(np.ma.masked_where(convolved_skeleton == 0, convolved_skeleton), cmap="viridis", alpha=1)
    ax.set_title(title)
    plt.show()
    # Close explicitly so figures do not pile up in memory while looping over many grains.
    plt.close(fig)


for filename, file_data in loaded_files.items():
    print(f"file: {filename}")
    image = file_data["image"]
    # NOTE(review): channel 1 of the "above" grain tensor is used as the grain mask — confirm intended class.
    mask = file_data["grain_tensors"]["above"][:, :, 1]
    p2nm = file_data["pixel_to_nm_scaling"]
    all_disordered_traces = file_data["disordered_traces"]["above"]

    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.set_title(f"File: {filename}")
    plt.show()
    plt.close(fig)

    for grain_index, grain_disordered_traces in all_disordered_traces.items():
        print(f" grain: {grain_index}")
        grain_bbox = grain_disordered_traces["bbox"]
        skeleton = grain_disordered_traces["skeleton"]
        print(f"skeleton unique values: {np.unique(skeleton)}")
        # bbox is used as (min_row, min_col, max_row, max_col) to crop the full image and mask to this grain.
        grain_crop = (slice(grain_bbox[0], grain_bbox[2]), slice(grain_bbox[1], grain_bbox[3]))
        grain_image = image[grain_crop]
        grain_mask = mask[grain_crop]

        # Find the endpoints by convolving the skeleton.
        convolved_skeleton = convolve_skeleton(skeleton)
        endpoints = np.argwhere(convolved_skeleton == ENDPOINT_VALUE)

        print("Endpoint pair distances (nm):")
        for endpoint_1, endpoint_2, distance in compute_endpoint_pair_distances(endpoints, p2nm):
            print(f"  Distance: {distance:.2f} nm between {endpoint_1} and {endpoint_2}")

        plot_grain_overlay(
            grain_image, grain_mask, convolved_skeleton, title=f"File: {filename} | Grain: {grain_index}"
        )