In [None]:
import pickle
from collections import Counter
from itertools import combinations
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import numpy.typing as npt
from pydantic import BaseModel, Field
from scipy import ndimage

from topostats.io import LoadScans
from topostats.mask_manipulation import smooth_mask
from topostats.measure.geometry import calculate_mask_width_with_skeleton, calculate_pixel_path_distance
from topostats.plottingfuncs import Colormap
from topostats.tracing.skeletonize import getSkeleton
from topostats.utils import convolve_skeleton

# Shared display settings: the TopoStats colour map plus the colour limits used when showing images with imshow.
colormap = Colormap()
cmap = colormap.get_cmap()
vmin, vmax = -3.0, 4.0


def clear_output():
    """Clear the output area of the current notebook cell via IPython's display machinery."""
    import IPython.display

    IPython.display.clear_output()


def load_data(dir: Path, channel: str = "dummy") -> dict[str, Any]:
    """
    Load every ``.topostats`` file in a directory using TopoStats' ``LoadScans``.

    Parameters
    ----------
    dir : Path
        Directory searched (non-recursively) for ``*.topostats`` files. The name shadows the ``dir`` builtin but is
        kept so existing keyword callers continue to work.
    channel : str
        Channel name passed to ``LoadScans``. Defaults to ``"dummy"``, the value previously hard-coded.

    Returns
    -------
    dict[str, Any]
        ``LoadScans.img_dict`` after ``get_data()`` has been called.
    """
    # Path.glob yields files in filesystem-dependent order; sort so the loaded dictionary order is reproducible.
    files = sorted(dir.glob("*.topostats"))
    loader = LoadScans(files, channel=channel)
    loader.get_data()
    # Clear the cell output after loading (presumably to hide any messages printed by LoadScans).
    clear_output()
    return loader.img_dict

In [None]:
# Directory of pickled intermediate TopoStats processing data (local absolute path; adjust for your machine).
dir_mid_process_files = Path("/Users/sylvi/topo_data/connect-loose-ends/mid-topostats-processing-data-files")
# Sort so files are always loaded (and later processed) in the same order, independent of the filesystem.
files = sorted(dir_mid_process_files.glob("*.pkl"))
if not files:
    print(f"warning: no .pkl files found in {dir_mid_process_files}")

# Map of each pickle's "filename" entry to its full loaded contents.
loaded_files: dict[str, dict] = {}
for file in files:
    # NOTE: pickle.load can execute arbitrary code - only load files from a trusted source.
    with open(file, "rb") as f:
        loaded_file = pickle.load(f)
    filename = loaded_file["filename"]
    if filename in loaded_files:
        # Previously a duplicate silently replaced earlier data; keep that behaviour but make it visible.
        print(f"warning: duplicate filename {filename!r} in {file}; overwriting previously loaded data")
    loaded_files[filename] = loaded_file

In [None]:
class Endpoint(BaseModel):
    """
    A single skeleton endpoint.

    In this notebook endpoints are built from ``np.argwhere`` on the convolved skeleton, so ``position`` holds the
    (row, column) pixel coordinate of the endpoint.
    """

    id: int  # Endpoint index; also used as its key in endpoint dictionaries and as its node id in the pair graph.
    position: tuple[int, int]  # (row, column) pixel coordinate.


class ConnectionGroup(BaseModel):
    """
    A set of endpoints linked to one another through chains of close endpoint pairs.

    In this notebook, instances are created by ``group_endpoints``: one per connected component of the graph whose
    nodes are endpoint ids and whose edges are close endpoint pairs.
    """

    id: int  # Group index (the connected-component number assigned by group_endpoints).
    # Endpoints belonging to this group, keyed by endpoint id.
    endpoints: dict[int, Endpoint] = Field(default_factory=dict)
    # Pairs of endpoint ids considered already connected. NOTE(review): not populated by any visible code yet; the
    # commented-out plan describes these as endpoints joined via the skeleton within a hard connection distance.
    hard_connected_endpoints: list[tuple[int, int]] = Field(default_factory=list)
    # Pairs of endpoint ids in this group that were judged close enough to connect (see group_endpoints' caller).
    close_endpoint_pairs: list[tuple[int, int]] = Field(default_factory=list)

In [None]:
def keep_only_nonrepeated_endpoints(
    potential_pairs: list[list[Any]],
) -> list[list[Any]]:
    """
    Drop every candidate pair that shares an endpoint with any other pair (or with itself).

    Parameters
    ----------
    potential_pairs : list[list[Any]]
        Candidate connections, each ``[endpoint_1, endpoint_2, distance_nm]`` where the endpoints are indexable
        ``(row, column)`` coordinates (e.g. lists, tuples or numpy arrays).

    Returns
    -------
    list[list[Any]]
        The pairs whose two endpoints each appear exactly once across all pairs, in their original order.
        Excluded pairs are reported with ``print``.
    """

    def _endpoint_key(endpoint) -> tuple[int, int]:
        # Normalise to a hashable (row, column) tuple so lists/arrays can be counted.
        return (endpoint[0], endpoint[1])

    # Count endpoint usage in one pass (list.count inside a loop was O(n^2)).
    usage_counts: Counter = Counter()
    for endpoint_1, endpoint_2, _distance_nm in potential_pairs:
        usage_counts[_endpoint_key(endpoint_1)] += 1
        usage_counts[_endpoint_key(endpoint_2)] += 1
    repeated_endpoints = {endpoint for endpoint, count in usage_counts.items() if count > 1}

    pairs_no_repeated_ends: list[list[Any]] = []
    for potential_pair in potential_pairs:
        endpoint_1, endpoint_2, _distance_nm = potential_pair
        if _endpoint_key(endpoint_1) in repeated_endpoints or _endpoint_key(endpoint_2) in repeated_endpoints:
            print(f"excluding pair {endpoint_1}, {endpoint_2} due to repeated endpoints")
            continue
        pairs_no_repeated_ends.append(potential_pair)

    return pairs_no_repeated_ends


def connect_endpoints_with_best_path(
    image: npt.NDArray[np.float32],
    mask: npt.NDArray[np.bool_],
    p2nm: float,
    endpoint_1: tuple[int, int],
    endpoint_2: tuple[int, int],
    endpoint_connection_cost_map_height_maximum: float,
    cost_map_bbox_padding_px: int = 10,
) -> tuple[npt.NDArray[np.int_], float, float]:
    """
    Find the lowest-cost pixel path between two endpoints, favouring raised regions of the image.

    The search is restricted to the bounding box of the two endpoints, padded by ``cost_map_bbox_padding_px`` and
    clamped to the image, to keep pathfinding fast. Heights inside that box are clipped to
    ``[0, endpoint_connection_cost_map_height_maximum]`` and mapped linearly to a cost in ``[0, 1]``:
    pixels at or above the maximum height cost 0 and pixels at or below 0 cost 1.

    Parameters
    ----------
    image : npt.NDArray[np.float32]
        2D height image.
    mask : npt.NDArray[np.bool_]
        Currently unused; kept so existing callers keep working.
    p2nm : float
        Pixel-to-nanometre scaling factor.
    endpoint_1, endpoint_2 : tuple[int, int]
        (row, column) coordinates of the two endpoints in ``image``.
    endpoint_connection_cost_map_height_maximum : float
        Height at (and above) which a pixel costs nothing to traverse. Must be positive.
    cost_map_bbox_padding_px : int
        Padding in pixels added around the endpoints' bounding box for the search region. Defaults to 10.

    Returns
    -------
    tuple[npt.NDArray[np.int_], float, float]
        The path as an ``(N, 2)`` array of (row, column) coordinates in full-image space, the accumulated cost
        reported by ``route_through_array``, and the path length in nanometres.

    Raises
    ------
    ValueError
        If ``endpoint_connection_cost_map_height_maximum`` is not positive (normalisation would divide by zero or
        invert the cost map).
    """
    from skimage.graph import route_through_array

    if endpoint_connection_cost_map_height_maximum <= 0:
        raise ValueError(
            "endpoint_connection_cost_map_height_maximum must be positive, got "
            f"{endpoint_connection_cost_map_height_maximum}"
        )

    # Crop around the two endpoints (clamped to the image bounds) to speed up pathfinding.
    min_y = max(0, min(endpoint_1[0], endpoint_2[0]) - cost_map_bbox_padding_px)
    max_y = min(image.shape[0], max(endpoint_1[0], endpoint_2[0]) + cost_map_bbox_padding_px)
    min_x = max(0, min(endpoint_1[1], endpoint_2[1]) - cost_map_bbox_padding_px)
    max_x = min(image.shape[1], max(endpoint_1[1], endpoint_2[1]) + cost_map_bbox_padding_px)
    image_crop = image[min_y:max_y, min_x:max_x]
    local_endpoint_1 = (endpoint_1[0] - min_y, endpoint_1[1] - min_x)
    local_endpoint_2 = (endpoint_2[0] - min_y, endpoint_2[1] - min_x)

    # Clip heights to [0, max], invert so that high pixels are cheap, and normalise the cost to [0, 1].
    clipped_heights = np.clip(image_crop, a_min=0, a_max=endpoint_connection_cost_map_height_maximum)
    cost_map = (
        endpoint_connection_cost_map_height_maximum - clipped_heights
    ) / endpoint_connection_cost_map_height_maximum

    local_path, cost = route_through_array(
        cost_map,
        start=local_endpoint_1,
        end=local_endpoint_2,
        fully_connected=True,  # allow diagonal moves
    )

    # Shift the path from crop coordinates back to full-image coordinates as an (N, 2) integer array.
    path = np.asarray(local_path) + np.array([min_y, min_x])

    # Path length in nm, taking diagonal steps into account.
    path_distance_nm = calculate_pixel_path_distance(path) * p2nm

    return path, cost, path_distance_nm


def group_endpoints(
    endpoints: dict[int, Endpoint],
    close_pairs: list[tuple[int, int, float]],
    draw_graph: bool = False,
) -> dict[int, ConnectionGroup]:
    """
    Group endpoints into connection groups based on interconnections.

    A graph is built with one node per endpoint id and one edge per close pair; each connected component becomes a
    ``ConnectionGroup``. Endpoints with no close partner therefore form single-endpoint groups.

    Parameters
    ----------
    endpoints : dict[int, Endpoint]
        All endpoints, keyed by endpoint id.
    close_pairs : list[tuple[int, int, float]]
        ``(endpoint_1_id, endpoint_2_id, distance_nm)`` for each pair deemed close enough to connect. The
        distance is not used for grouping.
    draw_graph : bool
        If True, draw the endpoint graph with matplotlib for inspection.

    Returns
    -------
    dict[int, ConnectionGroup]
        Groups keyed by their component index, each holding its endpoints and its close pairs (as id tuples, in the
        order they appear in ``close_pairs``).
    """
    graph = nx.Graph()
    # Nodes are endpoint ids (iterating the dict yields its keys in insertion order).
    graph.add_nodes_from(endpoints)
    for endpoint_1_index, endpoint_2_index, _distance_nm in close_pairs:
        graph.add_edge(endpoint_1_index, endpoint_2_index)

    if draw_graph:
        nx.draw(graph, with_labels=True)
        plt.show()

    components = list(nx.connected_components(graph))

    # Both ends of an edge always share a component, so each close pair can be assigned to its group in one pass
    # by looking up the component of its first endpoint. Appending in iteration order keeps the original ordering.
    group_id_of_endpoint = {
        endpoint_index: group_id for group_id, component in enumerate(components) for endpoint_index in component
    }
    close_pairs_by_group: dict[int, list[tuple[int, int]]] = {group_id: [] for group_id in range(len(components))}
    for endpoint_1_index, endpoint_2_index, _distance_nm in close_pairs:
        close_pairs_by_group[group_id_of_endpoint[endpoint_1_index]].append((endpoint_1_index, endpoint_2_index))

    connection_groups: dict[int, ConnectionGroup] = {}
    for group_id, component in enumerate(components):
        connection_groups[group_id] = ConnectionGroup(
            id=group_id,
            endpoints={endpoint_index: endpoints[endpoint_index] for endpoint_index in component},
            close_endpoint_pairs=close_pairs_by_group[group_id],
        )
    return connection_groups

In [None]:
# ---- Configuration ----
skeletonisation_holearea_min_max = (0, None)
skeletonisation_mask_smoothing_dilation_iterations = 2
skeletonisation_mask_smoothing_gaussian_sigma = 2
skeletonisation_method = "topostats"
skeletonisation_height_bias = 0.6
endpoint_connection_distance_nm = 10
endpoint_connection_cost_map_height_maximum = 3.0
endpoint_hard_connection_distance_nm = 20.0
channel_to_connect_ends = 1  # use the DNA channel of the mask tensor for connecting loose ends.
# Process only this file while developing; set to None to process every loaded file.
# Another file of interest: "20251031_nicked_picoz_8ng_nicl.0_00062"
target_filename = "20251031_nicked_picoz_8ng_nicl.0_00083"
# Set to True to display the raw mask, smoothed mask and skeleton of each processed file.
show_intermediate_plots = False


def _show_binary_image(array: npt.NDArray, title: str) -> None:
    """Display a 2D array in greyscale with a title (debugging aid for intermediate steps)."""
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(array, cmap="gray")
    ax.set_title(title)
    plt.show()


for filename, file_data in loaded_files.items():
    if target_filename is not None and filename != target_filename:
        continue

    print(f"processing file: {filename}")
    p2nm = file_data["pixel_to_nm_scaling"]
    tensor = file_data["full_mask_tensor"]
    image = file_data["image"]

    mask = tensor[:, :, channel_to_connect_ends].astype(bool)
    if show_intermediate_plots:
        _show_binary_image(mask, "Original mask")

    smoothed_mask = smooth_mask(
        filename=filename,
        pixel_to_nm_scaling=p2nm,
        grain=mask,
        gaussian_sigma=skeletonisation_mask_smoothing_gaussian_sigma,
        holearea_min_max=skeletonisation_holearea_min_max,
        dilation_iterations=skeletonisation_mask_smoothing_dilation_iterations,
    )
    if show_intermediate_plots:
        _show_binary_image(smoothed_mask, "Smoothed mask")

    # TODO: check the mask doesn't touch the image edge, as disordered_tracing does? Unsure if needed.

    # Skeletonise, using the configured height bias (previously a hard-coded 0.6 that ignored the config value).
    skeleton = getSkeleton(
        image=image,
        mask=smoothed_mask,
        method=skeletonisation_method,
        height_bias=skeletonisation_height_bias,
    ).get_skeleton()
    if show_intermediate_plots:
        _show_binary_image(skeleton, "skeleton")

    # Mean mask width along the skeleton, used later to dilate connecting paths to a matching thickness.
    mean_mask_width_nm = calculate_mask_width_with_skeleton(
        mask=smoothed_mask,
        skeleton=skeleton,
        pixel_to_nm_scaling=p2nm,
    )
    mean_mask_width_px = mean_mask_width_nm / p2nm

    # Find the skeleton endpoints (pixels with value 2 in the convolved skeleton).
    convolved_skeleton = convolve_skeleton(skeleton=skeleton)
    endpoint_coords = np.argwhere(convolved_skeleton == 2)
    print("endpoints:")
    print(endpoint_coords)

    # Convert numpy integers to plain ints so the pydantic model receives standard Python types.
    endpoints: dict[int, Endpoint] = {
        endpoint_index: Endpoint(id=endpoint_index, position=(int(coord[0]), int(coord[1])))
        for endpoint_index, coord in enumerate(endpoint_coords)
    }

    fig, ax = plt.subplots(figsize=(20, 20))
    ax.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
    ax.imshow(np.ma.masked_where(~mask, mask), cmap="Blues_r", alpha=0.3)
    ax.imshow(
        np.ma.masked_where(~convolved_skeleton.astype(bool), convolved_skeleton),
        cmap="viridis",
        alpha=0.7,
    )
    ax.set_title(f"{filename}: mask and convolved skeleton")
    plt.show()

    # Every unordered pair of endpoints whose straight-line separation is within the connection distance.
    nearby_endpoint_pairs: list[tuple[int, int, float]] = []
    for endpoint_1, endpoint_2 in combinations(endpoints.values(), 2):
        distance_nm = float(
            np.linalg.norm((np.array(endpoint_1.position) - np.array(endpoint_2.position)) * p2nm)
        )
        if distance_nm <= endpoint_connection_distance_nm:
            nearby_endpoint_pairs.append((endpoint_1.id, endpoint_2.id, distance_nm))

    # Group nearby endpoint pairs into connection groups
    connection_groups = group_endpoints(endpoints=endpoints, close_pairs=nearby_endpoint_pairs)
    print("connection groups:")
    print(connection_groups)

    # # Now consider each group and decide how to connect them.
    # # Iterate over each possible pair
    # for group_index, group_of_pairs in enumerate(groups_of_potentially_connected_endpoint_pairs):
    #     print(f"processing group {group_index} of {len(groups_of_potentially_connected_endpoint_pairs)}")

    #     # If there is only one pair in the group, just connect it.
    #     if len(group_of_pairs) == 1:
    #         pair = group_of_pairs[0]
    #         endpoint_1, endpoint_2, distance_nm = pair

    #         path, cost, distance_nm = connect_endpoints_with_best_path(
    #             image=image,
    #             mask=mask,
    #             p2nm=p2nm,
    #             endpoint_1=(endpoint_1[0], endpoint_1[1]),
    #             endpoint_2=(endpoint_2[0], endpoint_2[1]),
    #             endpoint_connection_cost_map_height_maximum=endpoint_connection_cost_map_height_maximum,
    #         )

    #         path_mask = np.zeros_like(mask, dtype=bool)
    #         # Set the path to True
    #         for y, x in path:
    #             path_mask[y, x] = True
    #         # Calculate the dilation iterations needed to reach the mean mask width
    #         dilation_radius = int(np.ceil(mean_mask_width_px / 2))
    #         dilated_path_array = ndimage.binary_dilation(
    #             path_mask,
    #             iterations=dilation_radius,
    #         )

    #         # Add the dilated path to the whole mask
    #         mask = mask | dilated_path_array

    #         # Update the skeleton to include the path
    #         for y, x in path:
    #             skeleton[y, x] = True

    #         continue  # move to next group

    #     # For groups with multiple pairs, we need to find a way to connect all endpoints together.

    #     # find the hard-connected endpoints in this group, defined as the endpoints that are connected via the
    #     # skeleton already, within a configurable hard connection distance.
    #     endpoint_hard_connection_distance_px = int(endpoint_hard_connection_distance_nm / p2nm)

    # #     path_mask = np.zeros_like(mask, dtype=bool)
    # #     # Set the path to True
    # #     for y, x in path:
    # #         path_mask[y, x] = True
    # #     # Calculate the dilation iterations needed to reach the mean mask width
    # #     dilation_radius = int(np.ceil(mean_mask_width_px / 2))
    # #     dilated_path_array = ndimage.binary_dilation(
    # #         path_mask,
    # #         iterations=dilation_radius,
    # #     )

    # #     # Add the dilated path to the whole mask
    # #     mask = mask | dilated_path_array

    # #     # Update the skeleton to include the path
    # #     for y, x in path:
    # #         skeleton[y, x] = True

    # # fig, ax = plt.subplots(figsize=(20, 20))
    # # plt.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
    # # grain_mask_mask = np.ma.masked_where(~mask, mask)
    # # plt.imshow(grain_mask_mask, cmap="Blues_r", alpha=0.3)
    # # skeleton_mask = np.ma.masked_where(~skeleton.astype(bool), skeleton)
    # # plt.imshow(skeleton_mask, cmap="viridis", alpha=0.7)
    # # plt.title("skeleton after connecting loose ends")
    # # plt.show()