In [None]:
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from scipy.ndimage import distance_transform_edt
from skimage.graph import route_through_array



In [None]:
# Path to the TopoStats HDF5 output being analysed.
# NOTE(review): absolute path on one machine — point this at your own copy
# (ideally via a DATA_DIR constant in a config cell) before re-running.
imagefile = Path("/Users/sylvi/Downloads/20250328_TAF_picoz_5_RA_15mM_mg_3mM_ni.0_00035.topostats")

# Read everything needed while the file is open; `[:]` copies each dataset into
# memory so the arrays remain usable after the `with` block closes the file.
with h5py.File(imagefile, "r") as f:
    print(f.keys())
    tensor = f["grain_tensors"]["above"][:]
    image = f["image"][:]
    spline_data = f["splining"]["above"]["grain_0"]["mol_0"]
    print(spline_data.keys())
    bbox = spline_data["bbox"][:]
    print("bbox:", bbox)
    # Shift the spline by the bounding-box origin (bbox[0], bbox[1]) to get
    # full-image (row, col) coordinates. Out-of-place add so an integer
    # `spline_coords` array is promoted instead of raising a casting error.
    spline = spline_data["spline_coords"][:] + np.asarray(bbox[:2])

# Later cells index spline[:, 0], spline[:, 1], spline[0] and spline[-1].
assert spline.ndim == 2 and spline.shape[1] == 2 and len(spline) > 0, "expected non-empty (N, 2) spline coordinates"


print(tensor.shape)
# Use channel 1 of the grain tensor as the molecule mask.
# NOTE(review): assumes tensor is (H, W, C) with the class of interest at index 1 — confirm.
mask = tensor[:, :, 1]

fig, ax = plt.subplots()
ax.imshow(mask, cmap="gray")
ax.set_title("Grain mask (tensor channel 1)")
plt.show()

fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(image, cmap="afmhot")
# Spline points are (row, col); matplotlib plots (x, y) = (col, row).
ax.plot(spline[:, 1], spline[:, 0], color="blue", linewidth=2, alpha=0.5, label="Spline")
ax.set_title("Image with molecule spline")
ax.legend()
plt.show()


# Endpoints of the spline in (row, col) image coordinates; these seed the pathfinding.
start_point = spline[0]
end_point = spline[-1]

# Euclidean distance transform of the mask: every foreground pixel holds its distance
# to the nearest background pixel (background is 0), so values are largest furthest
# from the mask boundary.
dist_transform = distance_transform_edt(mask == 1)

fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(dist_transform, cmap="gray")
ax.plot(start_point[1], start_point[0], "ro", markersize=10, label="Start Point")
ax.plot(end_point[1], end_point[0], "go", markersize=10, label="End Point")
ax.set_title("Distance transform of the mask with spline endpoints")
ax.legend()
plt.show()


# Pathfind the best path from start to end through the mask, preferring pixels
# far from the mask boundary (high distance-transform values).

print(np.min(dist_transform), np.max(dist_transform))


def _to_pixel(point, shape):
    """Round a float (row, col) coordinate to the nearest in-bounds pixel index.

    Rounding (rather than `int()` truncation) avoids a systematic half-pixel bias;
    clipping keeps a point that rounds onto the array edge from indexing out of bounds.
    """
    row, col = np.rint(np.asarray(point, dtype=float)).astype(int)
    return (int(np.clip(row, 0, shape[0] - 1)), int(np.clip(col, 0, shape[1] - 1)))


def _plot_path_comparison(background, cmap, spline, path, start_point, end_point, title):
    """Overlay the original spline, the routed path and the endpoints on `background`."""
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.imshow(background, cmap=cmap)
    # Points are (row, col); matplotlib plots (x, y) = (col, row).
    ax.plot(spline[:, 1], spline[:, 0], color="blue", linewidth=2, alpha=0.5, label="Original Spline")
    ax.plot(path[:, 1], path[:, 0], color="red", linewidth=2, alpha=0.5, label="Path through Distance Transform")
    ax.scatter(start_point[1], start_point[0], color="red", s=100, label="Start Point")
    ax.scatter(end_point[1], end_point[0], color="green", s=100, label="End Point")
    ax.set_title(title)
    ax.legend()
    plt.show()


# Invert the distance transform so pixels far from the mask edge are cheapest.
# Stored in a new variable so `dist_transform` keeps holding the raw distances.
dist_max = float(np.max(dist_transform))
cost_array = dist_max - dist_transform
# With plain inversion, background pixels (distance 0) cost `dist_max`, only 1 more
# than mask-edge pixels, so the route could cut corners outside the molecule.
# Penalise background above the cost of any path that stays inside the mask
# (at most cost_array.size steps of cost <= sqrt(2) * dist_max each). The penalty is
# finite rather than inf so a route is still found if an endpoint rounds onto background.
cost_array[dist_transform == 0] = 4.0 * cost_array.size * (dist_max + 1.0)

start = _to_pixel(start_point, cost_array.shape)
end = _to_pixel(end_point, cost_array.shape)
path, cost = route_through_array(cost_array, start, end, fully_connected=True)
# (N, 2) array of (row, col) pixel indices along the minimum-cost route.
path = np.array(path)

# Plot the routed path against the original spline on the image ...
_plot_path_comparison(image, "afmhot", spline, path, start_point, end_point,
                      "Routed path vs. original spline (image)")

# ... and on the mask, to check the route stays inside the molecule.
_plot_path_comparison(mask, "gray", spline, path, start_point, end_point,
                      "Routed path vs. original spline (mask)")
