In [None]:
from pathlib import Path
from typing import Optional

import h5py
import matplotlib.pyplot as plt
import numpy as np

from topostats.plottingfuncs import Colormap

# Shared TopoStats colormap. ``cmap`` is the default ``cmap=`` argument of the
# plotting helpers in this notebook and is passed straight to ``imshow``.
colormap = Colormap()
cmap = colormap.get_cmap()

In [None]:
def plot(
    image: np.ndarray,
    title: Optional[str] = None,
    vmin: float = -8,
    vmax: float = 8,
    cmap=cmap,
    figsize=(10, 10),
    cbar=False,
):
    """Display a single image with ``imshow`` on a new figure.

    Parameters
    ----------
    image : np.ndarray
        Array to display.
    title : str, optional
        Axes title; no title is set when ``None``.
    vmin, vmax : float
        Colour-scale limits passed to ``imshow``.
    cmap : optional
        Colormap passed to ``imshow``; defaults to the notebook-wide ``cmap``.
    figsize : tuple
        Figure size passed to ``plt.subplots``.
    cbar : bool
        If True, add a colorbar for the displayed image.
    """
    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
    if title is not None:
        ax.set_title(title)
    if cbar:
        fig.colorbar(im, ax=ax)
    plt.show()

In [None]:
# Images whose pixel-to-nm scaling exceeds this value are skipped when loading
# (presumably nm per pixel -- confirm against the .topostats metadata).
MAX_PX_TO_NM: float = 0.59
# Pixels of padding added around each grain's bounding box when cropping.
BBOX_PAD: int = 5

In [None]:
def plot_images(images: list, masks: list, px_to_nms: list, width=5, cmap=cmap, vmin=-8, vmax=8):
    """Show each grain as an (image, mask, overlay) triple on a grid.

    Every grain occupies three adjacent columns:
    - the cropped image, titled with its pixel-to-nm scaling;
    - the mask on its own;
    - the mask drawn semi-transparently over the image.

    ``width`` grains are placed on each row.

    Parameters
    ----------
    images : list
        Grain image crops, drawn with ``cmap``/``vmin``/``vmax``.
    masks : list
        Masks paired one-to-one with ``images``.
    px_to_nms : list
        Pixel-to-nm scaling per grain, shown in the image panel title.
    width : int, optional
        Number of grains per grid row (each uses 3 subplot columns).
    cmap, vmin, vmax : optional
        Colour mapping for the image panels.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots(0, ...) would raise; nothing to draw anyway.
        print("No images to plot")
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps ``ax`` two-dimensional even for a single row; without
    # it ``ax[row, col]`` raises IndexError whenever num_images <= width.
    fig, ax = plt.subplots(num_rows, width * 3, figsize=(30, 30), squeeze=False)
    # Hide axes everywhere, including unused slots in the last row.
    for axis in ax.flat:
        axis.axis("off")
    for i, (image, mask) in enumerate(zip(images, masks)):
        row, col = divmod(i, width)
        col *= 3
        ax[row, col].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, col].set_title(f"p_to_nm: {px_to_nms[i]}")
        ax[row, col + 1].imshow(mask, cmap="binary")
        ax[row, col + 2].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, col + 2].imshow(mask, cmap="viridis", alpha=0.2)
    fig.tight_layout()
    plt.show()


on_rel = Path("/Users/sylvi/topo_data/hariborings/testing_all_unbound_data/output_ON_SC/processed/")
assert on_rel.exists(), f"Processed data directory not found: {on_rel}"
# Grab all .topostats files. Sorted so grain indices are reproducible between
# runs (glob order depends on the filesystem).
on_files = sorted(on_rel.glob("*.topostats"))

grains_processed = 0
stop_at_grain = 200  # maximum number of grains to collect
plotting = False  # set True to plot every image and every grain crop

grain_dict = {}


def _crop_grain(image, grain_masks, grain, pad=BBOX_PAD):
    """Crop ``image`` and the mask of label ``grain`` to its padded bounding box.

    Returns ``(grain_image, grain_mask)`` where ``grain_mask`` is boolean, or
    ``None`` when no pixel carries the label ``grain``.
    """
    grain_mask_fullsize = grain_masks == grain
    coords = np.argwhere(grain_mask_fullsize)
    if coords.size == 0:
        # A label in 1..max may be absent; min/max of an empty array would raise.
        return None
    minr, minc = coords.min(axis=0)
    maxr, maxc = coords.max(axis=0)
    # min/max are inclusive pixel indices; the +1 turns the upper bound into an
    # exclusive slice end so every side gets ``pad`` pixels of padding.
    minr = max(0, minr - pad)
    minc = max(0, minc - pad)
    maxr = min(grain_mask_fullsize.shape[0], maxr + pad + 1)
    maxc = min(grain_mask_fullsize.shape[1], maxc + pad + 1)
    return image[minr:maxr, minc:maxc], grain_mask_fullsize[minr:maxr, minc:maxc]


for file in on_files:
    print(file)
    # Load image, labelled grain masks and scaling from the .topostats (HDF5) file
    with h5py.File(file, "r") as f:
        image = f["image"][:]
        grain_masks = f["grain_masks"]["above"][:]
        p_to_nm = f["pixel_to_nm_scaling"][()]

    # Skip images whose scaling exceeds the configured maximum
    if p_to_nm > MAX_PX_TO_NM:
        continue

    # Plot image and mask side by side
    if plotting:
        fig, ax = plt.subplots(1, 2, figsize=(20, 10))
        ax[0].imshow(image, cmap=cmap, vmin=-8, vmax=8)
        ax[0].set_title("image")
        ax[1].imshow(grain_masks, cmap="gray")
        ax[1].set_title("grain_masks")
        plt.suptitle(f"pixel to nm scaling: {p_to_nm}")
        fig.tight_layout()
        plt.show()

    # Process the grains (labels 1..max; 0 is background)
    for grain in range(1, grain_masks.max() + 1):
        if grains_processed >= stop_at_grain:
            break
        crop = _crop_grain(image, grain_masks, grain)
        if crop is None:
            continue
        grain_image, grain_mask = crop

        if plotting:
            fig, ax = plt.subplots(1, 3, figsize=(20, 10))
            ax[0].imshow(grain_mask, cmap="gray")
            ax[0].set_title("grain mask")
            ax[1].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
            ax[1].set_title("grain image")
            ax[2].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
            ax[2].imshow(grain_mask, cmap="gray", alpha=0.2)
            plt.show()

        grain_dict[grains_processed] = {
            "image": grain_image,
            "mask": grain_mask,
            "p_to_nm": p_to_nm,
        }

        grains_processed += 1

    if grains_processed >= stop_at_grain:
        break

# Plot the grains
images = [grain_dict[i]["image"] for i in range(grains_processed)]
masks = [grain_dict[i]["mask"] for i in range(grains_processed)]
px_to_nms = [grain_dict[i]["p_to_nm"] for i in range(grains_processed)]
plot_images(images, masks, px_to_nms)

In [None]:
# Clean up the masks
from skimage.morphology import binary_dilation, binary_erosion

DILATION_PASS = 2
ERODE_PASS = 2


def _dilate_then_erode(mask):
    """Apply DILATION_PASS binary dilations, then ERODE_PASS binary erosions.

    Dilation followed by erosion is a morphological closing; both steps use
    skimage's default footprint.
    """
    for _ in range(DILATION_PASS):
        mask = binary_dilation(mask)
    for _ in range(ERODE_PASS):
        mask = binary_erosion(mask)
    return mask


# Same keys, images and scalings as grain_dict; only the masks are cleaned.
dilated_grain_dict = {
    index: {
        "image": grain_data["image"],
        "mask": _dilate_then_erode(grain_data["mask"]),
        "p_to_nm": grain_data["p_to_nm"],
    }
    for index, grain_data in grain_dict.items()
}

grain_indices = range(grains_processed)
plot_images(
    [dilated_grain_dict[i]["image"] for i in grain_indices],
    [dilated_grain_dict[i]["mask"] for i in grain_indices],
    [dilated_grain_dict[i]["p_to_nm"] for i in grain_indices],
)


from skimage.measure import label, regionprops

LOWER_AREA_BOUND = 100  # minimum foreground pixel count
UPPER_AREA_BOUND = 10000  # maximum foreground pixel count


def _count_background_regions(mask):
    """Number of connected regions in the background (``mask == 0``) of ``mask``.

    Uses ``skimage.measure.label`` with its default connectivity.
    """
    return len(regionprops(label(mask == 0)))


removed_anomaly_grain_dict = {}
for index, grain_data in dilated_grain_dict.items():
    grain_mask = grain_data["mask"]
    n_background = _count_background_regions(grain_mask)

    # Keep only grains whose background splits into exactly two regions
    # (presumably the outside plus one enclosed hole), then bound the area.
    if n_background < 2:
        print(f"Grain {index} has too few background regions")
        continue
    if n_background >= 3:
        print(f"Grain {index} has too many background regions")
        continue

    foreground_area = grain_mask.sum()
    if foreground_area < LOWER_AREA_BOUND:
        print(f"Grain {index} has too small foreground area")
    elif foreground_area > UPPER_AREA_BOUND:
        print(f"Grain {index} has too large foreground area")
    else:
        removed_anomaly_grain_dict[index] = grain_data

plot_images(
    *[
        [entry[key] for entry in removed_anomaly_grain_dict.values()]
        for key in ("image", "mask", "p_to_nm")
    ]
)

In [None]:
# Skeletonise using standard skeletonise


def plot_images(
    images: list,
    masks: list,
    px_to_nms: list,
    skeletons: Optional[list] = None,
    width=5,
    cmap=cmap,
    vmin=-8,
    vmax=8,
):
    """Show each grain as an (image, overlay, image+overlay) triple on a grid.

    This redefinition shadows the earlier ``plot_images``. ``skeletons``
    therefore defaults to ``None``, in which case the masks are drawn instead,
    so the earlier 3-argument calls keep working if their cells are re-run.

    Parameters
    ----------
    images : list
        Grain image crops, drawn with ``cmap``/``vmin``/``vmax``.
    masks : list
        Grain masks; only drawn when ``skeletons`` is None.
    px_to_nms : list
        Pixel-to-nm scaling per grain, shown in the image panel title.
    skeletons : list, optional
        Skeletons paired with ``images``; drawn in place of the masks when given.
    width : int, optional
        Number of grains per grid row (each uses 3 subplot columns).
    cmap, vmin, vmax : optional
        Colour mapping for the image panels.
    """
    overlays = masks if skeletons is None else skeletons
    num_images = len(images)
    if num_images == 0:
        # plt.subplots(0, ...) would raise; nothing to draw anyway.
        print("No images to plot")
        return
    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps ``ax`` two-dimensional even for a single row; without
    # it ``ax[row, col]`` raises IndexError whenever num_images <= width.
    fig, ax = plt.subplots(num_rows, width * 3, figsize=(30, 30), squeeze=False)
    # Hide axes everywhere, including unused slots in the last row.
    for axis in ax.flat:
        axis.axis("off")
    for i, (image, _mask, overlay) in enumerate(zip(images, masks, overlays)):
        row, col = divmod(i, width)
        col *= 3
        ax[row, col].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, col].set_title(f"p_to_nm: {px_to_nms[i]}")
        ax[row, col + 1].imshow(overlay, cmap="binary")
        ax[row, col + 2].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, col + 2].imshow(overlay, cmap="viridis", alpha=0.2)
    fig.tight_layout()
    plt.show()


from scipy.ndimage import convolve


def convolve_skelly(skeleton) -> np.ndarray:
    """Classify skeleton pixels by counting their 8-connected neighbours.

    The skeleton is convolved with a 3x3 ones kernel, so each skeleton pixel
    receives 1 (itself) plus its number of skeleton neighbours. Pixels outside
    the array are treated as background (``mode="constant"``, ``cval=0``).
    scipy's default ``"reflect"`` mode mirrors border pixels onto themselves,
    which would overcount neighbours for skeleton pixels on the array edge
    (e.g. an endpoint touching the border would be reported as a body pixel).

    Parameters
    ----------
    skeleton: np.ndarray
        Single pixel thick binary trace(s) within an array.

    Returns
    -------
    np.ndarray
        Integer array. It is 0 off the skeleton. On the skeleton it is:
        - 2 for endpoints (one neighbour);
        - 3 for nodes/crossings (three or more neighbours);
        - 1 otherwise (two neighbours, or an isolated pixel with none).
    """
    conv = convolve(skeleton.astype(np.int32), np.ones((3, 3)), mode="constant", cval=0)
    conv[skeleton == 0] = 0  # remove non-skeleton points
    conv[conv == 3] = 1  # two neighbours -> ordinary skeleton pixel
    conv[conv > 3] = 3  # three or more neighbours -> node
    return conv


from skimage.morphology import skeletonize  # hoisted out of the per-grain loop

plotting = False  # set True to plot each grain's mask, image and skeleton
paths_grain_dict = {}

for index, grain_data in removed_anomaly_grain_dict.items():
    grain_image = grain_data["image"]
    grain_mask = grain_data["mask"]
    p_to_nm = grain_data["p_to_nm"]

    skeleton = skeletonize(grain_mask)

    if plotting:
        fig, ax = plt.subplots(1, 4, figsize=(20, 10))
        ax[0].imshow(grain_mask, cmap="gray")
        ax[0].set_title("grain mask")
        ax[1].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax[1].set_title("grain image")
        ax[2].imshow(skeleton, cmap="gray")
        ax[2].set_title("skeleton")
        ax[3].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
        ax[3].imshow(skeleton, cmap="viridis", alpha=0.2)
        plt.show()

    # convolve_skelly marks ordinary path pixels 1, endpoints 2 and nodes
    # (more than 2 neighbours) 3. Grains are kept only if their skeleton has
    # neither nodes nor endpoints; the two rejection reasons are reported
    # separately.
    convolved_skelly = convolve_skelly(skeleton)
    if np.any(convolved_skelly == 3):
        print(f"Grain {index} has a branch")
        continue
    if np.any(convolved_skelly == 2):
        print(f"Grain {index} has an open end")
        continue

    paths_grain_dict[index] = {
        "image": grain_image,
        "mask": grain_mask,
        "skeleton": skeleton,
        "p_to_nm": p_to_nm,
    }

plot_images(
    [paths_grain_dict[i]["image"] for i in paths_grain_dict],
    [paths_grain_dict[i]["mask"] for i in paths_grain_dict],
    [paths_grain_dict[i]["p_to_nm"] for i in paths_grain_dict],
    [paths_grain_dict[i]["skeleton"] for i in paths_grain_dict],
)

In [None]:
# # Skeletonise using topostats
# from topostats.tracing.dnatracing import dnaTrace

# for index, grain_data in removed_anomaly_grain_dict.items():
#     grain_image = grain_data["image"]
#     grain_mask = grain_data["mask"]
#     p_to_nm = grain_data["p_to_nm"]

#     plt.imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
#     plt.show()
#     plt.imshow(grain_mask, cmap="gray")
#     plt.show()

#     dnatrace = dnaTrace(
#         image=grain_image,
#         grain=grain_mask,
#         filename="test",
#         pixel_to_nm_scaling=p_to_nm,
#     )
#     dnatrace.trace_dna()

#     plt.imshow(dnatrace.smoothed_grain_copy, cmap="gray")
#     plt.show()

#     disordered_trace = dnatrace.disordered_trace
#     disordered_trace_image = np.zeros_like(grain_image)
#     disordered_trace_image[disordered_trace[:, 0], disordered_trace[:, 1]] = 1
#     plt.imshow(disordered_trace_image, cmap=cmap, vmin=-8, vmax=8)
#     plt.show()