In [None]:
from topostats.plottingfuncs import Colormap

import numpy as np
import matplotlib.pyplot as plt
import h5py

from skimage.morphology import skeletonize
from skimage.measure import label, regionprops

colormap = Colormap()
cmap = colormap.get_cmap()

from scipy.interpolate import splev, splprep


def _trace_closed_skeleton(skeleton):
    """Order the pixels of a single one-pixel-wide skeleton path.

    Starting from the first skeleton pixel in row-major order, repeatedly step to
    the (single) unvisited 8-connected neighbour until every skeleton pixel has
    been visited.

    Parameters
    ----------
    skeleton : np.ndarray
        2D array that is non-zero on skeleton pixels.

    Returns
    -------
    np.ndarray
        ``(N, 2)`` integer array of (row, col) coordinates in trace order.

    Raises
    ------
    ValueError
        If the skeleton is empty, if a pixel other than the start has more than
        one unvisited neighbour (a branch), or if a pixel has no unvisited
        neighbour before all pixels are visited (the path is disconnected).
    """
    points = np.argwhere(skeleton)
    if len(points) == 0:
        raise ValueError("Skeleton is empty")

    # Pad by one pixel so the 3x3 neighbourhood of a pixel on the array edge
    # never uses a negative (wrap-around) slice index. Coordinates inside this
    # function are in padded space and are shifted back by one on return.
    remaining = np.pad(skeleton.astype(bool), 1)
    current = points[0] + 1
    remaining[current[0], current[1]] = False
    trace = [current]

    for iteration in range(len(points) - 1):
        neighbourhood = remaining[current[0] - 1 : current[0] + 2, current[1] - 1 : current[1] + 2]
        n_neighbours = np.count_nonzero(neighbourhood)
        # The start pixel of a closed loop has two neighbours (one in each
        # direction); every later pixel must have exactly one unvisited neighbour.
        if iteration > 0 and n_neighbours > 1:
            raise ValueError(f"More than 1 neighbour for iteration {iteration}")
        if n_neighbours == 0:
            raise ValueError(f"No neighbours for iteration {iteration}")
        # Offset of the first unvisited neighbour within the 3x3 window, re-centred on current.
        current = current + np.argwhere(neighbourhood)[0] - 1
        remaining[current[0], current[1]] = False
        trace.append(current)

    return np.array(trace) - 1


# load flattened minicircles
# Only the reads happen inside the ``with`` so the HDF5 file is closed before the
# processing and interactive plotting below.
with h5py.File("../tests/resources/file.topostats", "r") as f:
    print(f.keys())

    image = f["image"][:]
    mask = f["grain_masks"]["above"][:]

# Choose a grain and build a boolean mask of it.
grain_num = 7
grain_mask = mask == grain_num

labelled_grain_mask = label(grain_mask)
grain_regionprops = regionprops(labelled_grain_mask)
if not grain_regionprops:
    raise ValueError(f"Grain {grain_num} is not present in the mask")

# Crop around the first connected region of the grain with some padding. The
# bounds are clamped to the array: a negative start index would wrap around and
# yield an empty or wrong crop.
n_rows, n_cols = grain_mask.shape
padding = 5
minr, minc, maxr, maxc = grain_regionprops[0].bbox
minr = max(minr - padding, 0)
minc = max(minc - padding, 0)
maxr = min(maxr + padding, n_rows)
maxc = min(maxc + padding, n_cols)

# Make the crop square by extending the shorter side. It may still be truncated
# at the array edge.
side = max(maxr - minr, maxc - minc)
maxr = min(minr + side, n_rows)
maxc = min(minc + side, n_cols)

cropped_grain = grain_mask[minr:maxr, minc:maxc]
cropped_image = image[minr:maxr, minc:maxc]

# Downsample the data by strided sampling (every ``downsample_factor``-th pixel, no averaging).
downsample_factor = 4
cropped_image = cropped_image[::downsample_factor, ::downsample_factor]
cropped_grain = cropped_grain[::downsample_factor, ::downsample_factor]

# Convert the boolean mask to uint8 for skeletonize. astype also returns a
# contiguous copy of the strided view.
cropped_grain = cropped_grain.astype(np.uint8)

# Skeletonize the grain mask.
skeleton = skeletonize(cropped_grain, method="lee")
plt.imshow(skeleton.astype(bool), cmap=cmap)
plt.show()

plt.imshow(cropped_image, cmap=cmap)
plt.show()

# (row, col) coordinates of every skeleton pixel, used for the scatter overlay below.
skeleton_coords = np.argwhere(skeleton)

# Order the skeleton pixels along the loop.
trace = _trace_closed_skeleton(skeleton)

# Paint each traced pixel with its position along the trace. The offset of 20
# keeps even the first traced pixel distinct from the zero background.
ordered_skeleton = np.zeros(skeleton.shape, dtype=int)
ordered_skeleton[trace[:, 0], trace[:, 1]] = np.arange(len(trace)) + 20

fig, ax = plt.subplots(1, 3, figsize=(20, 10))
ax[0].imshow(cropped_image, cmap=cmap, vmin=-8, vmax=8)
ax[0].set_title("grain image")
ax[1].imshow(ordered_skeleton, cmap="viridis")
ax[1].set_title("ordered skeleton")
ax[2].imshow(cropped_image, cmap=cmap, vmin=-8, vmax=8)
ax[2].imshow(ordered_skeleton, cmap="viridis", alpha=0.2)
ax[2].plot(trace[:, 1], trace[:, 0], "r")
plt.show()

# splprep with per=1 treats the data as periodic with period x[m-1] - x[0] and
# ignores the last point. Close the loop explicitly by repeating the start point
# so that every traced pixel contributes to the spline.
closed_trace = np.vstack([trace, trace[:1]])

for smoothing in [0.0, 3.0, 10.0]:
    # Fit a closed (periodic) spline through the trace, in (x=col, y=row) order.
    tck, u = splprep([closed_trace[:, 1], closed_trace[:, 0]], s=smoothing, per=1)
    u_new = np.linspace(u.min(), u.max(), 1000)
    x_new, y_new = splev(u_new, tck)

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(cropped_image, cmap=cmap, vmin=-3, vmax=4)
    ax.scatter(skeleton_coords[:, 1], skeleton_coords[:, 0], c="k", s=100)
    ax.plot(x_new, y_new, c="teal", linewidth=5)
    # turn off axes completely
    ax.axis("off")
    # remove white border
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.show()