In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import h5py

from topostats.grain_finding_haribo_unet import predict_unet, predict_unet_multiclass, load_model, mean_iou, iou
from topostats.plottingfuncs import Colormap

# Build the colormap object once; CMAP is passed as `cmap=` to the image plots below.
# NOTE(review): Colormap() is called with no arguments, so this is whatever the
# TopoStats default colormap is — confirm that is the intended one.
colormap = Colormap()
CMAP = colormap.get_cmap()

# Cas9 Bound

In [None]:
# Sample set to analyse; used as a sub-directory name inside the Cas9-bound output folder.
SAMPLE_TYPE = "ON_SC"
CAS9_DATA_DIR = Path(
    f"/Volumes/shared/pyne_group/Shared/AFM_Data/Cas9_Minicircles/Analysis_all/DNA_Cas9/output_justboundcas9/{SAMPLE_TYPE}/processed/"
)
# Raise an explicit, descriptive error instead of a bare `assert`: asserts are removed
# when Python runs with -O and do not say which path is missing.
if not CAS9_DATA_DIR.exists():
    raise FileNotFoundError(f"Cas9 data directory not found: {CAS9_DATA_DIR}")

# Plain string: the path has no placeholders, so no f-prefix is needed.
MODEL_PATH = Path(
    "/Volumes/shared/pyne_group/Shared/AFM_Data/Cas9_Minicircles/deep_learning/saved_models/haribonet_multiclass_improved_norm_big_95_bridging_v1_2024-01-17_10-58-46.h5"
)
if not MODEL_PATH.exists():
    raise FileNotFoundError(f"Cas9 model file not found: {MODEL_PATH}")

# Load the model. `mean_iou` is supplied through custom_objects so the loader can
# resolve it by name (presumably it was a metric when the model was saved — TODO confirm).
cas9_model = load_model(model_path=MODEL_PATH, custom_objects={"mean_iou": mean_iou})

# Collect every .topostats file in the data directory.
# NOTE(review): glob() does not guarantee an ordering, so the integer index used to pick
# an image in the next cell is only stable on the same filesystem. Consider sorted(...)
# and re-checking the chosen indices.
files = list(CAS9_DATA_DIR.glob("*.topostats"))
print(f"Num images: {len(files)}")

In [None]:
# Position of the file to inspect within `files`.
IMAGE_INDEX = 3
file = files[IMAGE_INDEX]

# .topostats files are opened with h5py; read the whole "image" dataset into memory
# so the array stays usable after the file is closed.
with h5py.File(file, mode="r") as f:
    image = f["image"][:]

# Show the full, uncropped image using the shared colormap.
image_fig, image_ax = plt.subplots()
image_ax.imshow(image, cmap=CMAP)
plt.show()

In [None]:
# Cut a square window out of the full image.
crop_size = 150  # side length of the window, in array elements
crop_x = 10  # first column of the window
crop_y = 260  # first row of the window

# The first array axis is indexed with y and the second with x.
crop_rows = slice(crop_y, crop_y + crop_size)
crop_cols = slice(crop_x, crop_x + crop_size)
image_cropped = image[crop_rows, crop_cols]
print(f"max: {np.max(image_cropped)}, min: {np.min(image_cropped)}")

# Fixed colour limits so the crop is displayed on a known scale.
crop_fig, crop_ax = plt.subplots()
crop_ax.imshow(image_cropped, cmap=CMAP, vmin=-1, vmax=5)
plt.show()

In [None]:
# Try to segment the image with a height threshold
threshold_start = 0.0
threshold_end = 5.0
threshold_step = 0.2

# Candidate thresholds on the half-open interval [threshold_start, threshold_end).
thresholds = np.arange(threshold_start, threshold_end, threshold_step)
# One boolean mask per threshold: True where the cropped image exceeds it.
thresholded_images = [image_cropped > threshold for threshold in thresholds]

# Plot the thresholded images
cols = 5
# Round up so that a partially filled last row is still drawn. Floor division
# silently dropped the trailing masks whenever the count was not a multiple of
# `cols`, and produced zero rows when there were fewer than `cols` masks.
rows = int(np.ceil(len(thresholded_images) / cols))
# squeeze=False keeps `axes` a 2D array regardless of the grid shape.
fig, axes = plt.subplots(rows, cols, figsize=(cols * 5, rows * 5), squeeze=False)
for i, ax in enumerate(axes.flat):
    if i < len(thresholded_images):
        ax.imshow(thresholded_images[i], cmap="gray")
        ax.set_title(f"Threshold: {thresholds[i]:.2f}")
    else:
        # Blank out grid cells that have no mask to show.
        ax.axis("off")
plt.show()

In [None]:
# Run the multiclass U-Net on the cropped Cas9 image.
# The arguments are gathered in one mapping so the settings can be read at a glance.
# NOTE(review): the meaning of each option is defined by predict_unet_multiclass in
# topostats.grain_finding_haribo_unet — check there before changing values.
cas9_prediction_kwargs = {
    "image": image_cropped,
    "model": cas9_model,
    "confidence": 0.5,
    "model_image_size": 256,
    "image_output_dir": None,
    "filename": None,
    "IMAGE_SAVE_DIR": None,
    "image_index": 0,
}
mask = predict_unet_multiclass(**cas9_prediction_kwargs)

# Display the predicted mask.
mask_fig, mask_ax = plt.subplots()
mask_ax.imshow(mask, cmap="inferno")
plt.show()

# Unbound DNA

In [None]:
# Sample set to analyse; used as a sub-directory name inside the DNA-only output folder.
# Note that this rebinds the SAMPLE_TYPE set by the Cas9 section above.
SAMPLE_TYPE = "OT1_SC"
UNBOUND_DATA_DIR = Path(
    f"/Volumes/shared/pyne_group/Shared/AFM_Data/Cas9_Minicircles/Analysis_all/DNA_only/output_noCas9_reprocessed/output_noCas9/{SAMPLE_TYPE}/processed/processed/"
)
# Raise an explicit, descriptive error instead of a bare `assert`: asserts are removed
# when Python runs with -O and do not say which path is missing.
if not UNBOUND_DATA_DIR.exists():
    raise FileNotFoundError(f"Unbound DNA data directory not found: {UNBOUND_DATA_DIR}")

# Plain string: the path has no placeholders, so no f-prefix is needed.
UNBOUND_MODEL_PATH = Path(
    "/Volumes/shared/pyne_group/Shared/AFM_Data/Cas9_Minicircles/deep_learning/saved_models/haribonet_dna_only_single_class_extra_doritos_2024-02-26_23-51-29_image-size-256x256_epochs-45_batch-size-25_learning-rate-0.001.h5"
)
if not UNBOUND_MODEL_PATH.exists():
    raise FileNotFoundError(f"Unbound DNA model file not found: {UNBOUND_MODEL_PATH}")

# Load the model. `iou` is supplied through custom_objects so the loader can resolve
# it by name (presumably it was a metric when the model was saved — TODO confirm).
unbound_model = load_model(model_path=UNBOUND_MODEL_PATH, custom_objects={"iou": iou})

# Collect every .topostats file in the data directory.
# NOTE(review): glob() does not guarantee an ordering, so the integer index used to pick
# an image in the next cell is only stable on the same filesystem. Consider sorted(...)
# and re-checking the chosen indices.
unbound_files = list(UNBOUND_DATA_DIR.glob("*.topostats"))
print(f"Num images: {len(unbound_files)}")

In [None]:
# Browsing helper, left disabled: uncomment to step through a range of images
# when looking for a good example to analyse.
# for UNBOUND_IMAGE_INDEX in range(35, 54):
#     unbound_file = unbound_files[UNBOUND_IMAGE_INDEX]

#     # Load the image via hdf5
#     with h5py.File(unbound_file, "r") as f:
#         unbound_image = f["image"][:]

#     plt.imshow(unbound_image, cmap=CMAP, vmin=-3, vmax=4)
#     plt.title(f"Image index: {UNBOUND_IMAGE_INDEX}")
#     plt.show()

# Position of the file to inspect within `unbound_files`.
UNBOUND_IMAGE_INDEX = 48
unbound_file = unbound_files[UNBOUND_IMAGE_INDEX]

# .topostats files are opened with h5py; read the whole "image" dataset into memory
# so the array stays usable after the file is closed.
with h5py.File(unbound_file, mode="r") as f:
    unbound_image = f["image"][:]

# Show the full image with fixed colour limits.
unbound_fig, unbound_ax = plt.subplots()
unbound_ax.imshow(unbound_image, cmap=CMAP, vmin=-3, vmax=4)
plt.show()

# Images of interest (sample type: image index: why it is useful):
# ON_SC: 68: good zoom in with gap
# OT1_SC: 7: good close edges of DNA
# OT1_SC: 48: good close images with helix

In [None]:
# Cut a square window out of the full unbound-DNA image.
unbound_crop_size = 130  # side length of the window, in array elements
unbound_crop_x = 190  # first column of the window
unbound_crop_y = 160  # first row of the window

# The first array axis is indexed with y and the second with x.
unbound_crop_rows = slice(unbound_crop_y, unbound_crop_y + unbound_crop_size)
unbound_crop_cols = slice(unbound_crop_x, unbound_crop_x + unbound_crop_size)
unbound_image_cropped = unbound_image[unbound_crop_rows, unbound_crop_cols]

# vmin/vmax are left as None, so the colour limits are taken from the data.
unbound_crop_fig, unbound_crop_ax = plt.subplots()
unbound_crop_ax.imshow(unbound_image_cropped, cmap=CMAP, vmin=None, vmax=None)
plt.show()

In [None]:
# Try to segment the image with a height threshold
threshold_start = 0.0
threshold_end = 3.0
threshold_step = 0.2

# Candidate thresholds on the half-open interval [threshold_start, threshold_end).
thresholds = np.arange(threshold_start, threshold_end, threshold_step)
# One boolean mask per threshold: True where the cropped image exceeds it.
thresholded_images = [unbound_image_cropped > threshold for threshold in thresholds]

# Plot the thresholded images
cols = 5
# Round up so that a partially filled last row is still drawn. Floor division
# silently dropped the trailing masks whenever the count was not a multiple of
# `cols`, and produced zero rows when there were fewer than `cols` masks.
rows = int(np.ceil(len(thresholded_images) / cols))
# squeeze=False keeps `axes` a 2D array regardless of the grid shape.
fig, axes = plt.subplots(rows, cols, figsize=(cols * 5, rows * 5), squeeze=False)
for i, ax in enumerate(axes.flat):
    if i < len(thresholded_images):
        ax.imshow(thresholded_images[i], cmap="gray")
        ax.set_title(f"Threshold: {thresholds[i]:.2f}")
    else:
        # Blank out grid cells that have no mask to show.
        ax.axis("off")
plt.show()

In [None]:
# Run the single-class U-Net on the cropped unbound-DNA image.
# The arguments are gathered in one mapping so the settings can be read at a glance.
# NOTE(review): the meaning of each option is defined by predict_unet in
# topostats.grain_finding_haribo_unet — check there before changing values.
unbound_prediction_kwargs = {
    "image": unbound_image_cropped,
    "model": unbound_model,
    "model_image_size": 256,
    "confidence": 0.5,
    "image_output_dir": None,
    "filename": None,
    "normalisation_set_range": [None, None],
}
unbound_mask = predict_unet(**unbound_prediction_kwargs)

# Display the predicted mask.
unbound_mask_fig, unbound_mask_ax = plt.subplots()
unbound_mask_ax.imshow(unbound_mask, cmap="inferno")
plt.show()