In [None]:
from pathlib import Path
from typing import List, Tuple, Union, Dict
import re
from datetime import datetime
import pickle

import numpy as np
import h5py
import matplotlib.pyplot as plt
import tensorflow as tf
from skimage.morphology import binary_dilation, skeletonize
from skimage.measure import label, regionprops
from skimage.color import label2rgb
from skimage.graph import route_through_array
from sklearn.cluster import KMeans
from scipy.ndimage import distance_transform_edt
import seaborn as sns
from skimage.morphology import binary_erosion
from scipy.ndimage import binary_fill_holes
from skimage.feature import canny
from scipy.interpolate import splprep, splev

from topostats.grain_finding_haribo_unet import (
    predict_unet,
    load_model,
    predict_unet_multiclass_and_get_angle,
    mean_iou,
    iou,
    predict_unet_multiclass,
)

from IPython.display import clear_output

from topostats.plottingfuncs import Colormap

# Shared display settings for every height image in this notebook: the TopoStats
# colormap, shown over a fixed [VMIN, VMAX] range so images are visually comparable.
colormap = Colormap()
CMAP = colormap.get_cmap()

VMIN, VMAX = 0, 4

# Get grain crops

In [None]:
# ---- Analysis configuration -------------------------------------------------
# Sample sub-directory to process under FLATTENED_IMAGE_DIR.
SAMPLE = "OT2_SC"
# Images whose pixel-to-nm scaling is larger than this are skipped when loading.
MAX_P_TO_NM = 10.0
# Toggle the diagnostic galleries further down.
PLOT_RESULTS = True

# Generous first-pass segmentation used to find candidate grains.
SIMPLE_HEIGHT_THRESHOLD = 1.0  # pixels with height above this are foreground
SIMPLE_CROP_PADDING = 5  # pixels added on every side of a region's bounding box
# (min, max) region area in *pixels* (compared against regionprops' `area`).
SIMPLE_AREA_THRESHOLDS = (500, 100_000_000_000)

FLATTENED_IMAGE_DIR = Path(
    "/Volumes/shared/pyne_group/Shared/AFM_Data/Cas9_Minicircles/Analysis_all/DNA_Cas9/output_justboundcas9"
)
assert FLATTENED_IMAGE_DIR.exists()

In [None]:
# Get data files. Sorting makes the integer image indices below reproducible:
# Path.glob yields files in a filesystem-dependent (unspecified) order.
sample_data_dir = FLATTENED_IMAGE_DIR / SAMPLE / "processed"
data_files = sorted(sample_data_dir.glob("*.topostats"))
print(f"num files: {len(data_files)}")

# Load the flattened image and its pixel-to-nm scaling from each HDF5 (.topostats) file.
# Files whose scaling exceeds MAX_P_TO_NM are skipped, so the integer keys of
# flattened_image_dict can have gaps — always iterate over its keys, never range(len()).
flattened_image_dict = {}

for image_index, data_file in enumerate(data_files):
    with h5py.File(data_file, "r") as f:
        # Read the scalar scaling first so skipped images are never loaded into memory.
        p_to_nm = f["pixel_to_nm_scaling"][()]
        if p_to_nm > MAX_P_TO_NM:
            continue
        flattened_image = f["image"][:]
    flattened_image_dict[image_index] = {
        "image": flattened_image,
        "p_to_nm": p_to_nm,
        "file_path": data_file,
    }
    print(f"Loaded image {data_file}")

In [None]:
# For each image, threshold generously and crop the regions.
#
# grain_crop_dictionaries: one entry per vetted region, keyed by a running integer
#   grain index (the source image index is stored in the entry as "image_index").
# masked_image_dictionaries: one entry per image, keyed by image index.

grain_crop_dictionaries = {}
masked_image_dictionaries = {}
for image_index, image_dict in flattened_image_dict.items():
    image = image_dict["image"]
    p_to_nm = image_dict["p_to_nm"]
    file_path = image_dict["file_path"]

    # Threshold the image: everything above SIMPLE_HEIGHT_THRESHOLD is foreground
    simple_mask = image > SIMPLE_HEIGHT_THRESHOLD

    # Label connected foreground regions; rejected regions are zeroed out below
    labelled_simple_mask = label(simple_mask)

    vetted_bounding_boxes = []

    for region in regionprops(labelled_simple_mask):
        # Vet on size. NOTE: region.area is in pixels, so the thresholds are pixel counts.
        if region.area < SIMPLE_AREA_THRESHOLDS[0] or region.area > SIMPLE_AREA_THRESHOLDS[1]:
            labelled_simple_mask[labelled_simple_mask == region.label] = 0
            continue

        # Pad the bounding box on every side ...
        minr, minc, maxr, maxc = region.bbox
        minr -= SIMPLE_CROP_PADDING
        minc -= SIMPLE_CROP_PADDING
        maxr += SIMPLE_CROP_PADDING
        maxc += SIMPLE_CROP_PADDING

        # ... then make it square by extending the shorter side (down / right only)
        maxr = max(maxr, maxc - minc + minr)
        maxc = max(maxc, maxr - minr + minc)

        # Skip regions whose padded square crop would extend past the image edge
        if minr < 0 or minc < 0 or maxr > image.shape[0] or maxc > image.shape[1]:
            labelled_simple_mask[labelled_simple_mask == region.label] = 0
            continue

        # Crop the image and the (unvetted) threshold mask; the padding may include
        # pixels of neighbouring regions.
        cropped_image = image[minr:maxr, minc:maxc]
        cropped_mask = simple_mask[minr:maxr, minc:maxc]

        # Key each crop by a new running grain index so that several grains from the
        # same image do not overwrite one another.
        grain_index = len(grain_crop_dictionaries)
        grain_crop_dictionaries[grain_index] = {
            "cropped_image": cropped_image,
            "cropped_mask": cropped_mask,
            "region": region,
            "p_to_nm": p_to_nm,
            "file_path": file_path,
            "image_index": image_index,
            "minr": minr,
            "minc": minc,
            "maxr": maxr,
            "maxc": maxc,
        }

        vetted_bounding_boxes.append((minr, minc, maxr, maxc))

    # Save the per-image masks: "simple_mask" is the raw threshold, "vetted_mask"
    # keeps only the regions that passed the size and bounds checks above.
    masked_image_dictionaries[image_index] = {
        "image": image,
        "p_to_nm": p_to_nm,
        "file_path": file_path,
        "simple_mask": simple_mask,
        "vetted_mask": labelled_simple_mask > 0,
        "vetted_bounding_boxes": vetted_bounding_boxes,
    }

print(f"found {len(grain_crop_dictionaries)} cropped grains")

In [None]:
# Plot a random selection of (up to) 10 images next to their simple masks, with each
# vetted crop box outlined in red on the image.

NUM_IMAGES_TO_PLOT = 10
PLOT_SAMPLE_SEED = 0  # fixed so the same images are shown on every run

# Sample from the real keys: images skipped at load time leave gaps in the integer
# indices, so drawing from range(len(...)) could pick an index that does not exist.
available_indices = list(masked_image_dictionaries)
max_index = len(available_indices)
num_to_plot = min(NUM_IMAGES_TO_PLOT, max_index)

if num_to_plot == 0:
    print("No images to plot")
else:
    rng = np.random.default_rng(PLOT_SAMPLE_SEED)
    indices = rng.choice(available_indices, size=num_to_plot, replace=False)

    # squeeze=False keeps `axes` 2-D even when only one image is plotted
    fig, axes = plt.subplots(num_to_plot, 2, figsize=(10, 5 * num_to_plot), squeeze=False)
    # Plot images, simple masks and bounding boxes
    for i, index in enumerate(indices):
        image = masked_image_dictionaries[index]["image"]
        simple_mask = masked_image_dictionaries[index]["simple_mask"]
        vetted_bounding_boxes = masked_image_dictionaries[index]["vetted_bounding_boxes"]

        axes[i, 0].imshow(image, cmap=CMAP, vmin=VMIN, vmax=VMAX)
        axes[i, 0].set_title(f"Image {index}")

        axes[i, 1].imshow(simple_mask)
        axes[i, 1].set_title(f"Simple mask {index}")

        for minr, minc, maxr, maxc in vetted_bounding_boxes:
            rect = plt.Rectangle((minc, minr), maxc - minc, maxr - minr, fill=False, edgecolor="red")
            axes[i, 0].add_patch(rect)

    plt.tight_layout()
    plt.show()

In [None]:
# Plot a gallery of all the cropped grains


def plot_images(
    images: list,
    masks: list,
    grain_indexes: list,
    px_to_nms: list,
    file_paths: list,
    width=3,
    VMIN=VMIN,
    VMAX=VMAX,
    cmap=CMAP,
):
    """Show a gallery of (image, mask) pairs, ``width`` pairs per row.

    Parameters
    ----------
    images, masks, grain_indexes, px_to_nms, file_paths : list
        Parallel lists, one item per grain. ``file_paths`` items must have a ``.stem``
        (e.g. ``pathlib.Path``); it is shown in the title together with the grain index
        and the pixel-to-nm scaling.
    width : int
        Number of (image, mask) pairs per row.
    VMIN, VMAX : float
        Display range for the images.
    cmap :
        Colormap used for the images.
    """
    num_images = len(images)
    num_images_in_batch = 2  # one image axis + one mask axis per grain
    # Ceil-divide so there is no spare empty row, but always create at least one row.
    num_rows = max(1, (num_images + width - 1) // width)
    # squeeze=False keeps `axes` 2-D even for a single row, so axes[row, col] always works.
    fig, axes = plt.subplots(
        num_rows, width * num_images_in_batch, figsize=(width * 20, num_rows * 8), squeeze=False
    )
    # Hide all axes up front so unused slots in the last row stay blank.
    for ax in axes.flat:
        ax.axis("off")

    for i, (image, mask, grain_index, p_to_nm, file_path) in enumerate(
        zip(images, masks, grain_indexes, px_to_nms, file_paths)
    ):
        row, slot = divmod(i, width)
        # Plot image
        im_ax = axes[row, slot * num_images_in_batch]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        im_ax.set_title(f"Grain {grain_index} {p_to_nm:.2f} p/nm \n {file_path.stem}", fontsize=20)
        # Plot mask
        mask_ax = axes[row, slot * num_images_in_batch + 1]
        mask_ax.imshow(mask.astype(int))

    fig.tight_layout()
    plt.show()


# Plot the gallery of every automatically cropped grain with its threshold mask
plot_images(
    images=[crop["cropped_image"] for crop in grain_crop_dictionaries.values()],
    masks=[crop["cropped_mask"] for crop in grain_crop_dictionaries.values()],
    grain_indexes=list(grain_crop_dictionaries),
    px_to_nms=[crop["p_to_nm"] for crop in grain_crop_dictionaries.values()],
    file_paths=[crop["file_path"] for crop in grain_crop_dictionaries.values()],
)

# Load model

In [None]:
# The model file is kept next to the notebook because of loading times.
MODEL_PATH = Path("./haribonet_multiclass_improved_norm_big_95_bridging_v1_2024-01-17_10-58-46.h5")
# Confidence passed to predict_unet_multiclass when turning outputs into a mask.
MODEL_CONFIDENCE = 0.5
# The custom metrics must be registered so the saved model can be deserialised.
model = load_model(
    model_path=MODEL_PATH,
    custom_objects=dict(iou=iou, mean_iou=mean_iou),
)

# Outputs are grouped into a per-date directory (YYYY-MM-DD).
today = datetime.now().strftime("%Y-%m-%d")
IMAGE_SAVE_DIR = Path(f"/Users/sylvi/topo_data/hariborings/extracted_grains/cas9_{SAMPLE}/{today}/")
IMAGE_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Use automatic crops for segmentation

In [None]:
# Segment every automatic crop with the U-Net, keeping the prediction next to the
# crop and its pixel-to-nm scaling. Keys match grain_crop_dictionaries.
grain_dicts = {}

for grain_index, grain_crop_dict in grain_crop_dictionaries.items():
    image, p_to_nm = grain_crop_dict["cropped_image"], grain_crop_dict["p_to_nm"]

    predicted_mask = predict_unet_multiclass(
        image=image,
        model=model,
        confidence=MODEL_CONFIDENCE,
        model_image_size=256,
        image_output_dir=IMAGE_SAVE_DIR,
        filename="test",
        image_index=grain_index,
        quiet=True,
        IMAGE_SAVE_DIR=IMAGE_SAVE_DIR,
        normalisation_set_range=(-1, 8),
    )

    grain_dicts[grain_index] = dict(image=image, predicted_mask=predicted_mask, p_to_nm=p_to_nm)

print(f"Number of images: {len(grain_dicts)}")

In [None]:
# # Plot the crops
# for grain_index, grain_dict in grain_dicts.items():
#     image = grain_dict["image"]
#     predicted_mask = grain_dict["predicted_mask"]
#     p_to_nm = grain_dict["p_to_nm"]

#     plt.imshow(image, cmap=CMAP, vmin=VMIN, vmax=VMAX)
#     plt.show()
#     plt.imshow(predicted_mask)
#     plt.show()

# Get existing crops from file (alternative to the automatic crops above; currently commented out)

In [None]:
# CROPPED_IMAGE_DIR = Path(f"/Users/sylvi/topo_data/hariborings/cas9_crops_p2nm/{SAMPLE}_p2nm")
# assert CROPPED_IMAGE_DIR.exists()
# image_files = list(CROPPED_IMAGE_DIR.glob("*.npy"))
# image_files = sorted(image_files, key=lambda x: float(re.findall(r"\d+\.\d+", x.name)[0]))
# print(f"Found {len(image_files)} images")

# grain_dicts = {}

# for index, image_file in enumerate(image_files):
#     image = np.load(image_file)
#     p_to_nm = float(re.findall(r"\d+\.\d+", image_file.name)[0])

#     if p_to_nm > MAX_P_TO_NM:
#         continue

#     predicted_mask = predict_unet_multiclass(
#         image=image,
#         model=model,
#         confidence=MODEL_CONFIDENCE,
#         model_image_size=256,
#         image_output_dir=IMAGE_SAVE_DIR,
#         filename="test",
#         image_index=index,
#         quiet=True,
#         IMAGE_SAVE_DIR=IMAGE_SAVE_DIR,
#         normalisation_set_range=(-1, 8),
#     )

#     grain_dicts[index] = {
#         "image": image,
#         "predicted_mask": predicted_mask,
#         "p_to_nm": p_to_nm,
#     }

# clear_output()
# print(f"Number of images: {len(grain_dicts)}")

In [None]:
def plot_images(
    images: list,
    masks: list,
    grain_indexes: list,
    px_to_nms: list,
    width=5,
    VMIN=VMIN,
    VMAX=VMAX,
    cmap=CMAP,
    file_paths: list = None,
):
    """Show a gallery of (image, mask) pairs, ``width`` pairs per row.

    Parameters
    ----------
    images, masks, grain_indexes, px_to_nms : list
        Parallel lists, one item per grain; the title shows the grain index and the
        pixel-to-nm scaling.
    width : int
        Number of (image, mask) pairs per row.
    VMIN, VMAX : float
        Display range for the images.
    cmap :
        Colormap used for the images.
    file_paths : list, optional
        Parallel list of paths (objects with ``.stem``); when given, each stem is added
        to the title. Accepting it keeps the gallery call made earlier in the notebook
        working with this definition too.
    """
    num_images = len(images)
    num_images_in_batch = 2  # one image axis + one mask axis per grain
    # Ceil-divide so there is no spare empty row, but always create at least one row.
    num_rows = max(1, (num_images + width - 1) // width)
    # squeeze=False keeps `axes` 2-D even for a single row, so axes[row, col] always works.
    fig, axes = plt.subplots(
        num_rows, width * num_images_in_batch, figsize=(width * 4, num_rows * 4), squeeze=False
    )
    # Hide all axes up front so unused slots in the last row stay blank.
    for ax in axes.flat:
        ax.axis("off")

    if file_paths is None:
        file_paths = [None] * num_images

    for i, (image, mask, grain_index, p_to_nm, file_path) in enumerate(
        zip(images, masks, grain_indexes, px_to_nms, file_paths)
    ):
        row, slot = divmod(i, width)
        # Plot image
        im_ax = axes[row, slot * num_images_in_batch]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        title = f"Grain {grain_index} {p_to_nm} p/nm"
        if file_path is not None:
            title += f"\n{file_path.stem}"
        im_ax.set_title(title)
        # Plot mask
        mask_ax = axes[row, slot * num_images_in_batch + 1]
        mask_ax.imshow(mask.astype(int))

    fig.tight_layout()
    plt.show()


if PLOT_RESULTS:
    # Gallery of every crop next to its raw U-Net prediction
    grain_indexes = list(grain_dicts)
    images = [grain_dicts[key]["image"] for key in grain_indexes]
    masks = [grain_dicts[key]["predicted_mask"] for key in grain_indexes]
    px_to_nms = [grain_dicts[key]["p_to_nm"] for key in grain_indexes]
    plot_images(images, masks, grain_indexes, px_to_nms)

### Vet grains based on the number of ring and gem regions

In [None]:
def check_ring_and_mask_exists(
    combined_predicted_mask: np.ndarray, min_ring_pixels: int = 40, min_gem_pixels: int = 40
) -> bool:
    """Check that a predicted mask contains enough ring and gem pixels.

    Parameters
    ----------
    combined_predicted_mask : np.ndarray
        Class mask in which pixels equal to 1 are ring and pixels equal to 2 are gem.
    min_ring_pixels : int
        Minimum number of ring pixels required (inclusive).
    min_gem_pixels : int
        Minimum number of gem pixels required (inclusive).

    Returns
    -------
    bool
        True if both the ring and the gem pixel counts reach their minimum.
    """
    num_ring_pixels = np.count_nonzero(combined_predicted_mask == 1)
    num_gem_pixels = np.count_nonzero(combined_predicted_mask == 2)
    return bool(num_ring_pixels >= min_ring_pixels and num_gem_pixels >= min_gem_pixels)


def turn_small_gem_regions_into_ring(
    combined_predicted_mask: np.ndarray, small_gem_dilation_strength: int = 5
) -> np.ndarray:
    """Relabel small gem regions that lie next to the ring as ring.

    The largest connected gem region (class 2) is left untouched. Every other gem
    region is dilated ``small_gem_dilation_strength`` times; if that dilated footprint
    touches any ring pixel (class 1), the whole footprint is set to ring — except for
    pixels of the largest gem region, which are never overwritten.

    Parameters
    ----------
    combined_predicted_mask : np.ndarray
        Class mask (1 = ring, 2 = gem). Not modified.
    small_gem_dilation_strength : int
        Number of successive binary dilations applied to each small gem region.

    Returns
    -------
    np.ndarray
        A new mask with the qualifying small gem regions relabelled. If there are no
        gem pixels, an unchanged copy is returned.
    """
    # Work on a copy so the caller's array (e.g. one stored in a results dict) is unchanged
    combined_predicted_mask = combined_predicted_mask.copy()

    gem_labels = label(combined_predicted_mask == 2)
    gem_regions = regionprops(gem_labels)
    if not gem_regions:
        return combined_predicted_mask

    # Find largest gem region (first one wins on ties, as with np.argmax)
    largest_gem_region = max(gem_regions, key=lambda region: region.area)
    largest_gem_mask = gem_labels == largest_gem_region.label

    # For all other regions, check if they touch a ring region
    for region in gem_regions:
        if region.label == largest_gem_region.label:
            continue
        dilated_region_mask = gem_labels == region.label
        for _ in range(small_gem_dilation_strength):
            dilated_region_mask = binary_dilation(dilated_region_mask)
        # The ring mask is re-read every iteration because earlier regions may already
        # have been turned into ring.
        if np.any(dilated_region_mask & (combined_predicted_mask == 1)):
            # The dilated footprint can reach into the main gem; excluding it prevents
            # the main gem from being shrunk or split by this relabelling.
            combined_predicted_mask[dilated_region_mask & ~largest_gem_mask] = 1

    return combined_predicted_mask


def remove_all_but_largest_ring_region(combined_predicted_mask: np.ndarray) -> np.ndarray:
    """Keep only the largest connected ring region; other ring pixels become background.

    Parameters
    ----------
    combined_predicted_mask : np.ndarray
        Class mask in which pixels equal to 1 are ring. Not modified.

    Returns
    -------
    np.ndarray
        A new mask where every ring pixel outside the largest connected ring region is
        set to 0. If there are no ring pixels, an unchanged copy is returned.
    """
    # Work on a copy so the caller's array (e.g. one stored in a results dict) is unchanged
    combined_predicted_mask = combined_predicted_mask.copy()

    ring_labels = label(combined_predicted_mask == 1)
    ring_regions = regionprops(ring_labels)
    if not ring_regions:
        return combined_predicted_mask

    # Largest region by pixel area (first one wins on ties, as with np.argmax)
    largest_ring_region = max(ring_regions, key=lambda region: region.area)
    # Every other labelled ring region is turned to background in one vectorised step
    combined_predicted_mask[(ring_labels > 0) & (ring_labels != largest_ring_region.label)] = 0

    return combined_predicted_mask


def get_number_of_connection_points(combined_predicted_mask: np.ndarray):
    """Count the separate places where the gem touches the ring.

    The gem mask (pixels equal to 2) is grown by one binary dilation and intersected
    with the ring mask (pixels equal to 1); each connected component of that overlap
    counts as one connection region.

    Parameters
    ----------
    combined_predicted_mask : np.ndarray
        Class mask (1 = ring, 2 = gem). Not modified.

    Returns
    -------
    tuple
        ``(num_connection_regions, intersection_labels)`` — the number of connected
        overlap regions and the labelled overlap image.
    """
    gem_dilation_strength = 1
    grown_gem = combined_predicted_mask == 2
    for _ in range(gem_dilation_strength):
        grown_gem = binary_dilation(grown_gem)

    touching = grown_gem & (combined_predicted_mask == 1)

    # label() can report the component count directly, equal to len(regionprops(...))
    intersection_labels, num_connection_regions = label(touching, return_num=True)
    return num_connection_regions, intersection_labels


# Clean up each predicted mask and keep only grains whose gem and ring touch in
# exactly two separate places.
vetted_grain_dict = {}
failed_indexes = []
for index, grain_dict in grain_dicts.items():
    image = grain_dict["image"]
    # Work on a copy: the cleanup helpers may write into the array they are given,
    # which would otherwise alter grain_dicts and make re-running this cell non-idempotent.
    predicted_mask = grain_dict["predicted_mask"].copy()
    p_to_nm = grain_dict["p_to_nm"]

    # Reject grains without enough ring and gem pixels
    if not check_ring_and_mask_exists(predicted_mask):
        failed_indexes.append(index)
        continue

    predicted_mask = turn_small_gem_regions_into_ring(predicted_mask)
    predicted_mask = remove_all_but_largest_ring_region(predicted_mask)

    num_connection_regions, intersection_labels = get_number_of_connection_points(predicted_mask)

    # Exactly two gem-ring contact regions are required
    if num_connection_regions != 2:
        failed_indexes.append(index)
        continue

    vetted_grain_dict[index] = {
        "image": image,
        "predicted_mask": predicted_mask,
        "p_to_nm": p_to_nm,
        "intersection_labels": intersection_labels,
    }

print(f"Number of vetted grains: {len(vetted_grain_dict)}")

if PLOT_RESULTS:
    # Grains rejected by the vetting above, shown with the masks stored in grain_dicts
    print(f"Failed indexes: {failed_indexes}")
    failed_entries = [grain_dicts[key] for key in failed_indexes]
    failed_images = [entry["image"] for entry in failed_entries]
    failed_masks = [entry["predicted_mask"] for entry in failed_entries]
    failed_grain_indexes = list(failed_indexes)
    failed_px_to_nms = [entry["p_to_nm"] for entry in failed_entries]
    plot_images(failed_images, failed_masks, failed_grain_indexes, failed_px_to_nms)

    # Grains that passed, shown with their cleaned-up masks
    grain_indexes = list(vetted_grain_dict)
    images = [vetted_grain_dict[key]["image"] for key in grain_indexes]
    masks = [vetted_grain_dict[key]["predicted_mask"] for key in grain_indexes]
    px_to_nms = [vetted_grain_dict[key]["p_to_nm"] for key in grain_indexes]
    plot_images(images, masks, grain_indexes, px_to_nms)

# Save the vetted grains (currently commented out)

In [None]:
# with open(IMAGE_SAVE_DIR / "grain_dict.pkl", "wb") as f:
#     pickle.dump(vetted_grain_dict, f)
#     print(f"saved grain_dict.pkl to {IMAGE_SAVE_DIR}")