In [None]:
from pathlib import Path
from datetime import datetime
from IPython.display import clear_output
import pickle

import numpy as np
import matplotlib.pyplot as plt
import h5py
from skimage.morphology import label
from skimage.measure import regionprops
import seaborn as sns
import tensorflow as tf

from topostats.plottingfuncs import Colormap
from topostats.grain_finding_haribo_unet import predict_unet, load_model, iou

# Shared TopoStats colour map; used as the default colour map for every height-image plot.
cmap = (colormap := Colormap()).get_cmap()

In [None]:
# Files whose pixel_to_nm_scaling exceeds this are skipped by the extraction loop.
MAX_PX_TO_NM = 10.0
# Padding (nm) added on each side of a grain's bounding box before cropping.
BBOX_PAD_NM = 4
# Absolute threshold on image values for the simple ("dumb") pre-segmentation.
DUMB_THRESHOLD_ABSOLUTE_NM = 1.2
# Connected components whose area (pixels * p_to_nm**2) falls outside [lower, upper] are discarded.
DUMB_LOWER_SIZE_THRESHOLD_NM2, DUMB_UPPER_SIZE_THRESHOLD_NM2 = 100, 350
# Passed to predict_unet as normalisation_set_range.
NORM_LOWER_BOUND, NORM_UPPER_BOUND = -1, 5

In [None]:
def plot_images(
    images: list,
    masks: list,
    px_to_nms: list,
    grain_indexes: list,
    width=5,
    cmap=cmap,
    vmin=-8,
    vmax=8,
    extra_title_datas=None,
    figsize=(30, 30),
):
    """Plot grains as (image | mask | overlay) triplets, ``width`` grains per row.

    Parameters
    ----------
    images : list of 2D arrays
        Height crops, drawn with ``cmap`` clipped to [``vmin``, ``vmax``].
    masks : list of 2D arrays
        Masks aligned with ``images``; drawn on their own (binary colormap) and
        overlaid on the image at 20% opacity.
    px_to_nms : list
        Pixel-to-nm scaling for each grain, shown in its title.
    grain_indexes : list
        Identifier for each grain, shown in its title.
    width : int
        Grains per row. Each grain occupies three adjacent columns.
    cmap, vmin, vmax
        Colour map and colour limits for the height images.
    extra_title_datas : list of str, optional
        Extra text appended as a second title line; defaults to empty strings.
    figsize : tuple
        Figure size in inches.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots(0, ...) raises, and there is nothing to draw anyway.
        print("plot_images: nothing to plot")
        return
    if extra_title_datas is None:
        extra_title_datas = [""] * num_images

    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps ``ax`` 2D even when everything fits on one row, so
    # ax[row, col] indexing works for any number of images.
    fig, ax = plt.subplots(num_rows, width * 3, figsize=figsize, squeeze=False)
    # Hide every axis up front; this also blanks unused slots on the last row.
    for axis in ax.flat:
        axis.axis("off")

    for i, (image, mask, px_to_nm, grain_index, extra_title_data) in enumerate(
        zip(images, masks, px_to_nms, grain_indexes, extra_title_datas)
    ):
        row, first_col = i // width, (i % width) * 3
        image_ax, mask_ax, overlay_ax = ax[row, first_col : first_col + 3]
        image_ax.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        image_ax.set_title(f"grain: {grain_index} p_to_nm: {px_to_nm}\n{extra_title_data}")
        mask_ax.imshow(mask, cmap="binary")
        overlay_ax.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        overlay_ax.imshow(mask, cmap="viridis", alpha=0.2)

    fig.tight_layout()
    plt.show()


today = datetime.now().strftime("%Y-%m-%d")

NEW_NAME_SAMPLE_TYPE = "ON_REL"
# Under the new nomenclature OT1 and OT2 swap names while ON keeps its name.
# Input data is read from a folder named with the OLD name and output is saved
# under the NEW name, so translate new -> old here.
SAMPLE_TYPE_NEW_TO_OLD = {
    "OT1_REL": "OT2_REL",
    "OT2_REL": "OT1_REL",
    "OT1_SC": "OT2_SC",
    "OT2_SC": "OT1_SC",
    "ON_REL": "ON_REL",
    "ON_SC": "ON_SC",
}
if NEW_NAME_SAMPLE_TYPE not in SAMPLE_TYPE_NEW_TO_OLD:
    raise ValueError(f"Unknown NEW_NAME_SAMPLE_TYPE: {NEW_NAME_SAMPLE_TYPE}")
OLD_NAME_SAMPLE_TYPE = SAMPLE_TYPE_NEW_TO_OLD[NEW_NAME_SAMPLE_TYPE]

# NOTE(review): absolute, user-specific paths — make these configurable before sharing.
DATA_DIR = Path(
    f"/Users/sylvi/topo_data/hariborings/testing_all_unbound_data/output_{OLD_NAME_SAMPLE_TYPE}/processed/"
)
SAVE_DIR = Path(
    f"/Users/sylvi/topo_data/hariborings/extracted-grains-new-names-20250814/unbound_{NEW_NAME_SAMPLE_TYPE}/date_{today}/"
)
# Validate the input before creating any output directories.
assert DATA_DIR.exists(), f"Data directory not found: {DATA_DIR}"
SAVE_DIR.mkdir(exist_ok=True, parents=True)
assert SAVE_DIR.exists()
# Grab all .topostats files. glob order is filesystem-dependent, so sort to make
# the file order (and therefore grain numbering) reproducible between runs.
files = sorted(DATA_DIR.glob("*.topostats"))
assert files, f"No .topostats files found in {DATA_DIR}"

model_name = "haribonet_dna_only_single_class_extra_doritos_2024-02-26_23-51-29_image-size-256x256_epochs-45_batch-size-25_learning-rate-0.001.h5"
MODEL_PATH = Path(f"/Users/sylvi/topo_data/hariborings/saved_models/dna_only_extra_doritos/{model_name}")
# Check before loading so a missing file gives a clear message instead of a loader traceback.
assert MODEL_PATH.exists(), f"Model not found: {MODEL_PATH}"
model = load_model(model_path=MODEL_PATH, custom_objects={"iou": iou})

# Explore the data: load one file and show its height image and pixel scaling.
if not files:
    raise FileNotFoundError(f"No .topostats files found in {DATA_DIR}")
# Use the second file as before, but fall back to the first when only one exists.
file = files[min(1, len(files) - 1)]
with h5py.File(file, "r") as f:
    image = f["image"][:]
    p_to_nm = f["pixel_to_nm_scaling"][()]
print(f"Image shape: {image.shape}")
print(f"p_to_nm: {p_to_nm}")

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(image, cmap=cmap, vmin=-8, vmax=8)
ax.set_title(f"{file.name} (p_to_nm: {p_to_nm})")
ax.axis("off")
plt.show()

In [None]:
# Run-wide state for the extraction loop:
#   grains_processed - running count, also the key of the next grain_dict entry
#   stop_at_grain    - cap on how many grains are extracted
grains_processed, stop_at_grain = 0, 100_000
# plotting: per-file / per-grain diagnostic figures; plot_results: all grains at the end.
plotting = plot_results = False

# grain number -> {"image", "dumb_grain_mask", "grain_mask", "p_to_nm"}
grain_dict = dict()

def _expand_span(lo, hi, target, limit):
    """Grow the half-open span [lo, hi) to length ``target`` within [0, limit).

    The span grows about its centre and is then shifted back inside the axis if
    it overhangs either end. The result always contains [lo, hi). It has exactly
    ``target`` length whenever ``target <= limit``; otherwise it is clipped to
    the whole axis. Assumes 0 <= lo <= hi <= limit.
    """
    new_lo = lo - (target - (hi - lo)) // 2
    new_hi = new_lo + target
    if new_lo < 0:
        new_lo, new_hi = 0, target
    if new_hi > limit:
        new_lo, new_hi = limit - target, limit
    return max(0, new_lo), min(limit, new_hi)


def _square_bbox(minr, minc, maxr, maxc, image_shape):
    """Return ``(minr, minc, maxr, maxc)`` with the shorter side grown to match the longer.

    The box stays inside ``image_shape`` and still contains the original box. It
    can only remain non-square when the image is smaller than the box's longer side.
    """
    height, width = maxr - minr, maxc - minc
    if height > width:
        minc, maxc = _expand_span(minc, maxc, height, image_shape[1])
    elif width > height:
        minr, maxr = _expand_span(minr, maxr, width, image_shape[0])
    return minr, minc, maxr, maxc


def _show_image_and_labels(image, labelled, p_to_nm):
    """Open a side-by-side figure of a height image and its labelled mask; return ``(fig, ax)``."""
    fig, ax = plt.subplots(1, 2, figsize=(20, 10))
    ax[0].imshow(image, cmap=cmap, vmin=-8, vmax=8)
    ax[0].set_title("image")
    ax[1].imshow(labelled, cmap="gray")
    ax[1].set_title("grain_masks")
    plt.suptitle(f"pixel to nm scaling: {p_to_nm}")
    return fig, ax


for file in files:
    print(file)
    # Load file
    with h5py.File(file, "r") as f:
        print(f.keys())
        image = f["image"][:]
        p_to_nm = f["pixel_to_nm_scaling"][()]

    if p_to_nm > MAX_PX_TO_NM:
        continue

    # "Dumb" grain masks: absolute height threshold + connected-component labelling.
    grain_masks = label(image > DUMB_THRESHOLD_ABSOLUTE_NM)

    # Components whose area (nm^2) is outside [lower, upper] are removed.
    rejected = [
        grain
        for grain in regionprops(grain_masks)
        if grain.area * p_to_nm**2 > DUMB_UPPER_SIZE_THRESHOLD_NM2
        or grain.area * p_to_nm**2 < DUMB_LOWER_SIZE_THRESHOLD_NM2
    ]
    if plotting:
        # Before removal: outline rejected components in red.
        fig, ax = _show_image_and_labels(image, grain_masks, p_to_nm)
        for grain in rejected:
            minr, minc, maxr, maxc = grain.bbox
            rect = plt.Rectangle((minc, minr), maxc - minc, maxr - minr, fill=False, edgecolor="red", linewidth=2)
            ax[1].add_patch(rect)
    # One vectorised pass instead of a full-image comparison per rejected component.
    grain_masks[np.isin(grain_masks, [grain.label for grain in rejected])] = 0

    if plotting:
        # After removal.
        fig, _ = _show_image_and_labels(image, grain_masks, p_to_nm)
        fig.tight_layout()
        plt.show()

    # Padding in pixels depends only on this file's scaling, so compute it once per file.
    bbox_pad_px = int(BBOX_PAD_NM / p_to_nm)

    for grain in regionprops(grain_masks):
        if grains_processed >= stop_at_grain:
            break
        minr, minc, maxr, maxc = grain.bbox
        minr -= bbox_pad_px
        minc -= bbox_pad_px
        maxr += bbox_pad_px
        maxc += bbox_pad_px

        # Skip grains whose padded box would leave the image.
        if minr < 0 or minc < 0 or maxr > image.shape[0] or maxc > image.shape[1]:
            continue

        # Make the crop square (the U-Net input is square), staying inside the image.
        minr, minc, maxr, maxc = _square_bbox(minr, minc, maxr, maxc, image.shape)

        grain_image = image[minr:maxr, minc:maxc]
        # Crop-then-compare equals compare-then-crop, without a full-image temporary.
        dumb_grain_mask = grain_masks[minr:maxr, minc:maxc] == grain.label

        if plotting:
            fig, ax = plt.subplots(1, 3, figsize=(20, 10))
            ax[0].imshow(dumb_grain_mask, cmap="gray")
            ax[0].set_title("grain mask")
            ax[1].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
            ax[1].set_title("grain image")
            ax[2].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
            ax[2].imshow(dumb_grain_mask, cmap="gray", alpha=0.2)
            plt.show()

        # Predict the mask using the Unet
        grain_mask_pred = predict_unet(
            image=grain_image,
            model=model,
            confidence=0.5,
            model_image_size=256,
            image_output_dir=None,
            filename=file.stem,
            normalisation_set_range=(NORM_LOWER_BOUND, NORM_UPPER_BOUND),
            quiet=True,
        )

        if plotting:
            fig, ax = plt.subplots(1, 3, figsize=(20, 10))
            ax[0].imshow(dumb_grain_mask, cmap="gray")
            ax[0].set_title("grain mask")
            ax[1].imshow(grain_mask_pred, cmap="gray")
            ax[1].set_title("grain mask pred")
            ax[2].imshow(grain_image, cmap=cmap, vmin=-8, vmax=8)
            ax[2].imshow(dumb_grain_mask, cmap="viridis", alpha=0.2)
            plt.show()

        grain_dict[grains_processed] = {
            "image": grain_image,
            "dumb_grain_mask": dumb_grain_mask,
            "grain_mask": grain_mask_pred,
            "p_to_nm": p_to_nm,
        }

        grains_processed += 1

    if grains_processed >= stop_at_grain:
        break

# Clear the per-file log printed by the loop unless diagnostic plots were requested.
if not plotting:
    clear_output()

# Per-grain fields in grain-number order (grain_dict keys run 0..grains_processed-1).
grain_indexes = list(range(grains_processed))
images, masks, px_to_nms = (
    [grain_dict[i][field] for i in grain_indexes]
    for field in ("image", "grain_mask", "p_to_nm")
)

print(f"num grains: {len(images)}")

if plot_results:
    plot_images(images, masks, px_to_nms, grain_indexes)

In [None]:
# Plot a small random subset of the extracted grains.
# The generator is seeded so the same grains are shown on every run, and the
# sample is capped at the number of grains so small datasets don't raise.
SUBSET_SAMPLE_SIZE = 10
subset_rng = np.random.default_rng(seed=0)

if grains_processed == 0:
    print("No grains extracted; nothing to plot.")
else:
    sample_indexes = subset_rng.choice(
        grains_processed, size=min(SUBSET_SAMPLE_SIZE, grains_processed), replace=False
    )
    images = [grain_dict[i]["image"] for i in sample_indexes]
    masks = [grain_dict[i]["grain_mask"] for i in sample_indexes]
    px_to_nms = [grain_dict[i]["p_to_nm"] for i in sample_indexes]
    grain_indexes = sample_indexes

    plot_images(
        images,
        masks,
        px_to_nms,
        grain_indexes,
    )

In [None]:
# Clean up the masks: DILATION_PASS binary dilations followed by ERODE_PASS
# binary erosions (a morphological closing when the two counts are equal).
from skimage.morphology import binary_dilation, binary_erosion

DILATION_PASS = 2
ERODE_PASS = 2

# Foreground-area limits for the anomaly filter; compared against mask.sum(), i.e. in pixels.
LOWER_AREA_BOUND = 70
UPPER_AREA_BOUND = 10000
plot_results = False


def _dilate_then_erode(mask):
    """Return ``mask`` after DILATION_PASS dilations and then ERODE_PASS erosions."""
    for _ in range(DILATION_PASS):
        mask = binary_dilation(mask)
    for _ in range(ERODE_PASS):
        mask = binary_erosion(mask)
    return mask


dilated_grain_dict = {
    index: {
        "image": grain_data["image"],
        "mask": _dilate_then_erode(grain_data["grain_mask"]),
        "p_to_nm": grain_data["p_to_nm"],
    }
    for index, grain_data in grain_dict.items()
}

if plot_results:
    shown_indexes = list(range(grains_processed))
    plot_images(
        [dilated_grain_dict[i]["image"] for i in shown_indexes],
        [dilated_grain_dict[i]["mask"] for i in shown_indexes],
        [dilated_grain_dict[i]["p_to_nm"] for i in shown_indexes],
        shown_indexes,
    )

def _anomaly_reason(mask):
    """Return why a cleaned mask should be rejected, or ``None`` to keep it.

    A mask is kept when its background (``mask == 0``) splits into exactly two
    connected regions — presumably the outside plus the ring's hole — and its
    foreground pixel count lies in [LOWER_AREA_BOUND, UPPER_AREA_BOUND].
    The returned strings are stored as the "reason" of a rejected grain.
    """
    # return_num gives the region count directly; no need to build regionprops.
    _, num_background_regions = label(mask == 0, return_num=True)
    if num_background_regions < 2:
        return "too few background regions"
    if num_background_regions >= 3:
        return "too many background regions"
    foreground_area = mask.sum()
    if foreground_area < LOWER_AREA_BOUND:
        return "too small foreground area"
    if foreground_area > UPPER_AREA_BOUND:
        return "too large foreground area"
    return None


removed_anomaly_grain_dict = {}
bad_grains = {}
for index, grain_data in dilated_grain_dict.items():
    grain_mask = grain_data["mask"]

    reason = _anomaly_reason(grain_mask)
    if reason is not None:
        print(f"Grain {index} has {reason}")
        # Keys: "image", "mask", "p_to_nm" copied from grain_data, plus "reason".
        bad_grains[index] = {**grain_data, "reason": reason}
        continue

    # Keep only the largest foreground region. The checks above guarantee the
    # foreground is non-empty, so max() always has at least one candidate.
    labelled_grain = label(grain_mask)
    largest_region = max(regionprops(labelled_grain), key=lambda prop: prop.area)

    removed_anomaly_grain_dict[index] = {
        "image": grain_data["image"],
        "mask": labelled_grain == largest_region.label,
        "p_to_nm": grain_data["p_to_nm"],
    }

if plot_results:
    plot_images(
        [removed_anomaly_grain_dict[i]["image"] for i in removed_anomaly_grain_dict],
        [removed_anomaly_grain_dict[i]["mask"] for i in removed_anomaly_grain_dict],
        [removed_anomaly_grain_dict[i]["p_to_nm"] for i in removed_anomaly_grain_dict],
        [i for i in removed_anomaly_grain_dict],
    )

# Plot a random sample of the rejected grains, titled with the rejection reason.
# Seeded for reproducibility; capped at the number of rejected grains so small
# (or empty) rejection sets don't raise.
BAD_SAMPLE_SIZE = 10

if not bad_grains:
    print("No grains were rejected; nothing to plot.")
else:
    bad_rng = np.random.default_rng(seed=0)
    bad_grain_indexes = bad_rng.choice(
        list(bad_grains.keys()), size=min(BAD_SAMPLE_SIZE, len(bad_grains)), replace=False
    )
    bad_images = [bad_grains[i]["image"] for i in bad_grain_indexes]
    bad_masks = [bad_grains[i]["mask"] for i in bad_grain_indexes]
    bad_px_to_nms = [bad_grains[i]["p_to_nm"] for i in bad_grain_indexes]
    extra_title_datas = [f"reason: {bad_grains[i]['reason']}" for i in bad_grain_indexes]

    plot_images(
        bad_images,
        bad_masks,
        bad_px_to_nms,
        bad_grain_indexes,
        extra_title_datas=extra_title_datas,
    )

In [None]:
# Compare grain counts before and after the anomaly filter.
num_before, num_after = len(grain_dict), len(removed_anomaly_grain_dict)
print(f"Number of grains before cleaning: {num_before}")
print(f"Number of grains after cleaning: {num_after}")

In [None]:
# Plot a random sample of the cleaned grains.
# Seeded for reproducibility; capped at the number available so datasets with
# fewer than ``sample_size`` clean grains don't raise.
sample_size = 40
clean_grain_indexes = list(removed_anomaly_grain_dict.keys())

if not clean_grain_indexes:
    print("No grains survived cleaning; nothing to plot.")
else:
    clean_rng = np.random.default_rng(seed=0)
    sample_grain_indexes = clean_rng.choice(
        clean_grain_indexes, size=min(sample_size, len(clean_grain_indexes)), replace=False
    )

    plot_images(
        [removed_anomaly_grain_dict[i]["image"] for i in sample_grain_indexes],
        [removed_anomaly_grain_dict[i]["mask"] for i in sample_grain_indexes],
        [removed_anomaly_grain_dict[i]["p_to_nm"] for i in sample_grain_indexes],
        [i for i in sample_grain_indexes],
    )

In [None]:
# Save the cleaned grains (index -> {"image", "mask", "p_to_nm"}) as a pickle.
# NOTE: loading a pickle can execute arbitrary code — only load this file from a trusted location.
file_path = SAVE_DIR / "grain_dict.pkl"
print(f"saving to {file_path}")
with open(file_path, "wb") as f:
    pickle.dump(removed_anomaly_grain_dict, f)
# Report success only once the file has been closed (and therefore flushed).
print(f"saved grain_dict.pkl to {SAVE_DIR}")