In [None]:
from pathlib import Path
import pickle

import numpy as np
import matplotlib.pyplot as plt
import h5py
from skimage.morphology import skeletonize, label
from skimage.measure import regionprops
from scipy.ndimage import distance_transform_edt, binary_fill_holes, binary_dilation
from skimage.graph import route_through_array

import tensorflow as tf

from topostats.plottingfuncs import Colormap
from topostats.measure.feret import get_feret_from_mask
from topostats.grain_finding_haribo_unet import predict_unet, load_model, iou, predict_unet_multiclass, mean_iou

# Build the colormap once; `cmap` is passed as `cmap=` to the imshow calls in the cells below.
colormap = Colormap()
cmap = colormap.get_cmap()

In [None]:
# --- Data and output locations -------------------------------------------------
# NOTE(review): these are absolute, machine-specific paths; the notebook only runs
# where these directories exist.
UNBOUND_SC_DATA_DIR = Path("/Users/sylvi/topo_data/hariborings/testing_all_unbound_data/output_ON_SC/processed/")
UNBOUND_REL_DATA_DIR = Path("/Users/sylvi/topo_data/hariborings/testing_all_unbound_data/output_ON_REL/processed/")
BOUND_DATA_DIR = Path(
    "/Volumes/shared-2/pyne_group/Shared/AFM_Data/Cas9_Minicircles/Analysis_all/DNA_Cas9/output_justboundcas9/ON_SC/processed/"
)
SAVE_DIR = Path("/Volumes/shared-2/pyne_group/Shared/Papers/cas9_minicircles/si_figure")

# Fail early, naming the path that is missing, before any model is loaded.
assert UNBOUND_SC_DATA_DIR.exists(), f"Unbound SC data directory not found: {UNBOUND_SC_DATA_DIR}"
assert UNBOUND_REL_DATA_DIR.exists(), f"Unbound REL data directory not found: {UNBOUND_REL_DATA_DIR}"
assert Path(
    "/Volumes/shared-2/pyne_group/"
).exists(), "/Volumes/shared-2/pyne_group/ not found (is the shared volume mounted?)"
assert BOUND_DATA_DIR.exists(), f"Bound data directory not found: {BOUND_DATA_DIR}"
assert SAVE_DIR.exists(), f"Figure output directory not found: {SAVE_DIR}"

# --- Segmentation models -------------------------------------------------------
unbound_model_name = "haribonet_dna_only_single_class_extra_doritos_2024-02-26_23-51-29_image-size-256x256_epochs-45_batch-size-25_learning-rate-0.001.h5"
UNBOUND_MODEL_PATH = Path(
    f"/Users/sylvi/topo_data/hariborings/saved_models/dna_only_extra_doritos/{unbound_model_name}"
)
# MODEL_PATH is bound to the same path as BOUND_MODEL_PATH; the alias is kept in case
# other cells reference it.
BOUND_MODEL_PATH = MODEL_PATH = Path(
    "/Users/sylvi/topo_data/hariborings/saved_models/haribonet_multiclass_improved_norm_big_95_bridging_v1_2024-01-17_10-58-46.h5"
)
assert UNBOUND_MODEL_PATH.exists(), f"Unbound model file not found: {UNBOUND_MODEL_PATH}"
assert BOUND_MODEL_PATH.exists(), f"Bound model file not found: {BOUND_MODEL_PATH}"

# The metric functions are handed to load_model by name via custom_objects
# (presumably required to deserialise the saved models - confirm).
unbound_model = load_model(UNBOUND_MODEL_PATH, custom_objects={"iou": iou})
bound_model = load_model(BOUND_MODEL_PATH, custom_objects={"mean_iou": mean_iou, "iou": iou})

# Colour-scale limits passed as vmin/vmax to every imshow call in this notebook
# (they are used for the bound image as well, despite the name).
VMIN_UNBOUND = -4
VMAX_UNBOUND = 4

In [None]:
# Per-sample results keyed by sample name ("unbound_sc", "unbound_rel", "bound").
# The tracing cell fills each entry with the cropped image, predicted mask, trace and
# pixel-to-nm scaling; the Feret cell iterates over the entries.
image_dictionary = {}

# Unbound SC

In [None]:
# Gather every .topostats file in the unbound SC directory and select one by index.
# NOTE(review): glob order is filesystem-dependent, so `file_index` may pick a different
# file on another machine - check the printed path.
unbound_sc_files = [path for path in UNBOUND_SC_DATA_DIR.glob("*.topostats")]
print(f"Found {len(unbound_sc_files)} files")
file_index = 40
unbound_sc_file = unbound_sc_files[file_index]

# Read the image array and the pixel scaling out of the HDF5 container.
with h5py.File(unbound_sc_file, mode="r") as h5_file:
    image = h5_file["image"][:]
    p_to_nm = h5_file["pixel_to_nm_scaling"][()]
print(f"file: {file_index} : {unbound_sc_file}")
print(f"p_to_nm: {p_to_nm}")

# Display settings shared by every height-image plot in this cell.
imshow_kwargs = {"cmap": cmap, "vmin": VMIN_UNBOUND, "vmax": VMAX_UNBOUND}

# Whole-image overview.
fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(image, **imshow_kwargs)
plt.show()

# Square crop around one grain: side length given in nm and converted to pixels
# (assumes p_to_nm is nm per pixel - confirm), top-left corner given in pixels.
crop_size_nm = 28
crop_size = int(crop_size_nm / p_to_nm)
crop_x = 210
crop_y = 290
crop_rows = slice(crop_y, crop_y + crop_size)
crop_cols = slice(crop_x, crop_x + crop_size)
unbound_sc_crop = image[crop_rows, crop_cols]

fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(unbound_sc_crop, **imshow_kwargs)
plt.show()

# Segment the crop with the single-class unbound model.
unbound_sc_prediction = predict_unet(
    image=unbound_sc_crop,
    model=unbound_model,
    confidence=0.5,
    model_image_size=256,
    normalisation_set_range=(-1, 5),
    image_output_dir=None,
    filename=None,
)

# Overlay the prediction on the crop; zero-valued prediction pixels are masked out so
# the height image shows through there.
unbound_sc_overlay = np.ma.masked_where(unbound_sc_prediction == 0, unbound_sc_prediction)
plt.imshow(unbound_sc_crop, **imshow_kwargs)
plt.imshow(unbound_sc_overlay, cmap="gray", alpha=0.5)
plt.show()

# Unbound REL

In [None]:
# Gather every .topostats file in the unbound REL directory and select one by index.
# NOTE(review): glob order is filesystem-dependent, so `file_index` may pick a different
# file on another machine - check the printed path.
unbound_rel_files = [path for path in UNBOUND_REL_DATA_DIR.glob("*.topostats")]
print(f"Found {len(unbound_rel_files)} files")
file_index = 2
unbound_rel_file = unbound_rel_files[file_index]

# Read the image array and the pixel scaling out of the HDF5 container.
with h5py.File(unbound_rel_file, mode="r") as h5_file:
    image = h5_file["image"][:]
    p_to_nm = h5_file["pixel_to_nm_scaling"][()]
print(f"file: {file_index} : {unbound_rel_file}")
print(f"p_to_nm: {p_to_nm}")

# Display settings shared by every height-image plot in this cell.
imshow_kwargs = {"cmap": cmap, "vmin": VMIN_UNBOUND, "vmax": VMAX_UNBOUND}

# Whole-image overview.
fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(image, **imshow_kwargs)
plt.show()

# Square crop around one grain: side length given in nm and converted to pixels
# (assumes p_to_nm is nm per pixel - confirm), top-left corner given in pixels.
crop_size_nm = 28
crop_size = int(crop_size_nm / p_to_nm)
crop_x = 160
crop_y = 50
crop_rows = slice(crop_y, crop_y + crop_size)
crop_cols = slice(crop_x, crop_x + crop_size)
unbound_rel_crop = image[crop_rows, crop_cols]

fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(unbound_rel_crop, **imshow_kwargs)
plt.show()

# Segment the crop with the single-class unbound model.
unbound_rel_prediction = predict_unet(
    image=unbound_rel_crop,
    model=unbound_model,
    confidence=0.5,
    model_image_size=256,
    normalisation_set_range=(-1, 5),
    image_output_dir=None,
    filename=None,
)

# Overlay the prediction on the crop; zero-valued prediction pixels are masked out so
# the height image shows through there.
unbound_rel_overlay = np.ma.masked_where(unbound_rel_prediction == 0, unbound_rel_prediction)
plt.imshow(unbound_rel_crop, **imshow_kwargs)
plt.imshow(unbound_rel_overlay, cmap="gray", alpha=0.5)
plt.show()

# Bound

In [None]:
# Gather every .topostats file in the bound directory and select one by index.
# NOTE(review): glob order is filesystem-dependent, so `file_index` may pick a different
# file on another machine - check the printed path.
bound_files = [path for path in BOUND_DATA_DIR.glob("*.topostats")]
print(f"Found {len(bound_files)} files")
file_index = 3
bound_file = bound_files[file_index]

# Read the image array and the pixel scaling out of the HDF5 container.
with h5py.File(bound_file, mode="r") as h5_file:
    image = h5_file["image"][:]
    p_to_nm = h5_file["pixel_to_nm_scaling"][()]
print(f"file: {file_index} : {bound_file}")
print(f"p_to_nm: {p_to_nm}")

# Display settings shared by every height-image plot in this cell.
imshow_kwargs = {"cmap": cmap, "vmin": VMIN_UNBOUND, "vmax": VMAX_UNBOUND}

# Whole-image overview.
fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(image, **imshow_kwargs)
plt.show()

# Square crop around one grain: side length given in nm and converted to pixels
# (assumes p_to_nm is nm per pixel - confirm), top-left corner given in pixels.
crop_size_nm = 28
crop_size = int(crop_size_nm / p_to_nm)
crop_x = 345
crop_y = 60
crop_rows = slice(crop_y, crop_y + crop_size)
crop_cols = slice(crop_x, crop_x + crop_size)
bound_crop = image[crop_rows, crop_cols]

fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(bound_crop, **imshow_kwargs)
plt.show()

# Segment the crop with the multi-class bound model.
bound_prediction = predict_unet_multiclass(
    image=bound_crop,
    model=bound_model,
    confidence=0.5,
    model_image_size=256,
    normalisation_set_range=(-1, 8),
    image_output_dir=None,
    filename="",
    image_index=0,
    quiet=True,
    IMAGE_SAVE_DIR=None,
)

# Overlay the class map on the crop; zero-valued prediction pixels are masked out and
# interpolation is disabled so the class values are drawn unblended.
bound_overlay = np.ma.masked_where(bound_prediction == 0, bound_prediction)
plt.imshow(bound_crop, **imshow_kwargs)
plt.imshow(bound_overlay, cmap="viridis", alpha=0.5, interpolation="none")
plt.show()

# Trace the grain

In [None]:
def _read_pixel_to_nm_scaling(topostats_file):
    """Return the ``pixel_to_nm_scaling`` value stored in a .topostats (HDF5) file."""
    with h5py.File(topostats_file, "r") as topostats_h5:
        return topostats_h5["pixel_to_nm_scaling"][()]


def _closest_pixel_to_centroid(region):
    """Return ``(pixel, distance)`` for the pixel of ``region`` nearest to its centroid.

    ``region`` is a regionprops entry. Ties go to the first pixel in ``region.coords``.
    """
    centroid = np.array(region.centroid)
    closest_pixel = None
    min_distance = np.inf
    for pixel in region.coords:
        distance = np.linalg.norm(np.array(pixel) - centroid)
        if distance < min_distance:
            min_distance = distance
            closest_pixel = pixel
    return closest_pixel, min_distance


# --- Unbound SC: skeletonise the predicted mask ---------------------------------
unbound_sc_unordered_trace = skeletonize(unbound_sc_prediction > 0)

# Plot image with trace overlay
fig, ax = plt.subplots(1, 1, figsize=(7, 7))
ax.imshow(unbound_sc_crop, cmap=cmap, vmin=VMIN_UNBOUND, vmax=VMAX_UNBOUND)
ax.imshow(np.ma.masked_where(unbound_sc_unordered_trace == 0, unbound_sc_unordered_trace), cmap="gray", alpha=0.5)
plt.show()

# The global `p_to_nm` holds the scaling of whichever file was loaded last, so the
# scaling is re-read from the file this crop came from.
image_dictionary["unbound_sc"] = {
    "image": unbound_sc_crop,
    "mask": unbound_sc_prediction,
    "trace": unbound_sc_unordered_trace,
    "p_to_nm": _read_pixel_to_nm_scaling(unbound_sc_file),
}

# --- Unbound REL: skeletonise the predicted mask --------------------------------
unbound_rel_unordered_trace = skeletonize(unbound_rel_prediction > 0)

# Plot image with trace overlay
fig, ax = plt.subplots(1, 1, figsize=(7, 7))
ax.imshow(unbound_rel_crop, cmap=cmap, vmin=VMIN_UNBOUND, vmax=VMAX_UNBOUND)
ax.imshow(np.ma.masked_where(unbound_rel_unordered_trace == 0, unbound_rel_unordered_trace), cmap="gray", alpha=0.5)
plt.show()

image_dictionary["unbound_rel"] = {
    "image": unbound_rel_crop,
    "mask": unbound_rel_prediction,
    "trace": unbound_rel_unordered_trace,
    "p_to_nm": _read_pixel_to_nm_scaling(unbound_rel_file),
}

# --- Bound: route a path through the ring between the two ring/gem contacts ------
# Distance to the nearest zero-valued prediction pixel, zeroed on the gem (value 2)
# so that only ring pixels keep a positive distance.
distance_transform = distance_transform_edt(bound_prediction > 0)
distance_transform[bound_prediction == 2] = 0
ring_mask = bound_prediction == 1

# remove all but largest ring region
ring_labels = label(ring_mask)
ring_props = regionprops(ring_labels)
if not ring_props:
    raise ValueError("bound_prediction contains no ring (value 1) pixels; cannot trace the bound grain")
ring_areas = [region.area for region in ring_props]
largest_ring_index = np.argmax(ring_areas)
largest_ring_label = ring_props[largest_ring_index].label
largest_ring_mask = ring_labels == largest_ring_label
ring_mask = largest_ring_mask

# Ring pixels that touch the gem once it is dilated by one step: the places where the
# ring meets the gem.
gem_mask = bound_prediction == 2
dilated_gem_mask = binary_dilation(gem_mask)
intersection = dilated_gem_mask & ring_mask
intersection_labels = label(intersection)
intersection_props = regionprops(intersection_labels)
num_connecting_regions = len(intersection_props)
fig, ax = plt.subplots(1, 1, figsize=(7, 7))
ax.imshow(bound_crop, cmap=cmap, vmin=VMIN_UNBOUND, vmax=VMAX_UNBOUND)
ax.imshow(np.ma.masked_where(intersection == 0, intersection), cmap="gray", alpha=0.5)
# Show the contact regions before the check below so a failure can be inspected.
plt.show()
print(f"Number of connecting regions: {num_connecting_regions}")
assert (
    num_connecting_regions == 2
), f"Expected exactly 2 ring/gem connecting regions, found {num_connecting_regions}"

region_0 = intersection_props[0]
region_1 = intersection_props[1]

# Look up the distance-transform value at each pixel of the two contact regions
# (region.coords rows are indexed as [row, col]).
region_0_distance_transform_values = distance_transform[region_0.coords[:, 0], region_0.coords[:, 1]]
region_1_distance_transform_values = distance_transform[region_1.coords[:, 0], region_1.coords[:, 1]]
region_0_max_distance_transform_value = np.max(region_0_distance_transform_values)
region_1_max_distance_transform_value = np.max(region_1_distance_transform_values)
region_0_max_distance_transform_value_index = np.argmax(region_0_distance_transform_values)
region_1_max_distance_transform_value_index = np.argmax(region_1_distance_transform_values)
region_0_max_distance_transform_value_coords = region_0.coords[region_0_max_distance_transform_value_index]
region_1_max_distance_transform_value_coords = region_1.coords[region_1_max_distance_transform_value_index]

# For the start and end points, use the points on the intersection labels that are closest to the centroids
region_0_centroid = region_0.centroid
region_1_centroid = region_1.centroid
region_0_closest_to_centroid, region_0_min_distance_to_centroid = _closest_pixel_to_centroid(region_0)
region_1_closest_to_centroid, region_1_min_distance_to_centroid = _closest_pixel_to_centroid(region_1)

start_point = (region_0_closest_to_centroid[0], region_0_closest_to_centroid[1])
end_point = (region_1_closest_to_centroid[0], region_1_closest_to_centroid[1])

# Alternative endpoints (pixels with the largest distance-transform value in each region):
# start_point = (region_0_max_distance_transform_value_coords[0], region_0_max_distance_transform_value_coords[1])
# end_point = (region_1_max_distance_transform_value_coords[0], region_1_max_distance_transform_value_coords[1])

# Cost map: cheap where the distance transform is large, and a fixed high cost of 1000
# wherever the distance transform is zero.
max_distance = np.max(distance_transform)
inverted_distance_transform = max_distance - distance_transform
inverted_distance_transform[inverted_distance_transform == max_distance] = 1000

route, weight = route_through_array(inverted_distance_transform, start_point, end_point)
route = np.array(route)

# Make the route into a binary mask
route_mask = np.zeros_like(bound_prediction)
route_mask[route[:, 0], route[:, 1]] = 1

fig, ax = plt.subplots(1, 1, figsize=(7, 7))
ax.imshow(bound_crop, cmap=cmap, vmin=VMIN_UNBOUND, vmax=VMAX_UNBOUND)
ax.imshow(np.ma.masked_where(route_mask == 0, route_mask), cmap="gray", alpha=0.5)
plt.show()

image_dictionary["bound"] = {
    "image": bound_crop,
    "mask": bound_prediction,
    "feret_mask": ring_mask,
    "trace": route_mask,
    "p_to_nm": _read_pixel_to_nm_scaling(bound_file),
}

# Min feret

In [None]:
# For every sample in image_dictionary: save the image, mask-overlay and trace-overlay
# figures, then measure the min/max Feret diameters on the largest connected component
# of the trace and save a figure with both diameters drawn on it.
# NOTE(review): the "bound" entry also stores a "feret_mask" that is not read here -
# confirm whether the Feret measurement should use it instead of the trace.
FIGSIZE = 7
# Margins are trimmed at save time via bbox_inches / pad_inches.
savefig_kwargs = {"dpi": 600, "bbox_inches": "tight", "pad_inches": 0}
min_feret_colour = "#00FFFF"
max_feret_colour = "#FF69B4"

for index, (key, grain_dict) in enumerate(image_dictionary.items()):
    print(f"index: {index}")

    image = grain_dict["image"]
    mask = grain_dict["mask"]
    trace_mask = grain_dict["trace"]
    p_to_nm = grain_dict["p_to_nm"]

    # Plot the image by itself
    fig, image_ax = plt.subplots(1, 1, figsize=(FIGSIZE, FIGSIZE))
    image_ax.imshow(image, cmap=cmap, vmin=VMIN_UNBOUND, vmax=VMAX_UNBOUND)
    image_ax.axis("off")
    fig.savefig(SAVE_DIR / f"{key}_image.png", **savefig_kwargs)
    plt.show()
    plt.close(fig)

    # Plot the mask overlay
    fig, mask_ax = plt.subplots(1, 1, figsize=(FIGSIZE, FIGSIZE))
    mask_ax.imshow(image, cmap=cmap, vmin=VMIN_UNBOUND, vmax=VMAX_UNBOUND)
    mask_ax.imshow(np.ma.masked_where(mask == 0, mask), cmap="viridis", alpha=0.5)
    mask_ax.axis("off")
    fig.savefig(SAVE_DIR / f"{key}_mask.png", **savefig_kwargs)
    plt.show()
    plt.close(fig)

    # Plot the trace overlay
    fig, trace_ax = plt.subplots(1, 1, figsize=(FIGSIZE, FIGSIZE))
    trace_ax.imshow(image, cmap=cmap, vmin=VMIN_UNBOUND, vmax=VMAX_UNBOUND)
    trace_ax.imshow(np.ma.masked_where(trace_mask == 0, trace_mask), cmap="gray", alpha=0.5)
    trace_ax.axis("off")
    fig.savefig(SAVE_DIR / f"{key}_trace.png", **savefig_kwargs)
    plt.show()
    plt.close(fig)

    # Use only the largest connected region. Work on a copy so the array stored in
    # image_dictionary is left untouched and re-running this cell gives the same figures.
    label_image = label(trace_mask)
    props = regionprops(label_image)
    if not props:
        raise ValueError(f"Trace mask for '{key}' is empty; cannot measure Feret diameters")
    largest_region = max(props, key=lambda region: region.area)
    largest_trace_mask = trace_mask.copy()
    largest_trace_mask[label_image != largest_region.label] = 0

    results = get_feret_from_mask(largest_trace_mask)

    # Scale the Feret lengths by the sample's pixel-to-nm factor
    # (assumes get_feret_from_mask returns lengths in pixels - confirm).
    min_feret = results["min_feret"] * p_to_nm
    max_feret = results["max_feret"] * p_to_nm

    min_feret_coords = results["min_feret_coords"]
    max_feret_coords = results["max_feret_coords"]

    print(f"min feret: {min_feret}")
    print(f"max feret: {max_feret}")

    print(min_feret_coords)

    # Feret overlay: coords are indexed [point, (row, col)], so column 1 is plotted on
    # x and column 0 on y.
    fig, feret_ax = plt.subplots(1, 1, figsize=(FIGSIZE, FIGSIZE))
    feret_ax.imshow(image, cmap=cmap, vmin=VMIN_UNBOUND, vmax=VMAX_UNBOUND)
    feret_ax.imshow(np.ma.masked_where(largest_trace_mask == 0, largest_trace_mask), cmap="gray", alpha=0.5)
    feret_lines = (
        (min_feret_coords, min_feret_colour, "min feret (min width)"),
        (max_feret_coords, max_feret_colour, "max feret (max width)"),
    )
    for feret_coords, feret_colour, _ in feret_lines:
        for point_index in (0, 1):
            feret_ax.scatter(
                feret_coords[point_index, 1], feret_coords[point_index, 0], color=feret_colour, s=200
            )
    for feret_coords, feret_colour, feret_label in feret_lines:
        feret_ax.plot(
            [feret_coords[0, 1], feret_coords[1, 1]],
            [feret_coords[0, 0], feret_coords[1, 0]],
            color=feret_colour,
            linewidth=5,
            linestyle="--",
            label=feret_label,
        )
    feret_ax.legend(fontsize=16)
    feret_ax.axis("off")
    # Save before showing, like the other figures in this loop.
    fig.savefig(SAVE_DIR / f"{key}_feret.png", **savefig_kwargs)
    plt.show()
    plt.close(fig)