In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import tensorflow as tf
from skimage.morphology import binary_dilation, binary_erosion

from evaluate import predict_mask
from plotting import plot_image_mask_prediction
from preprocess import preprocess_image, preprocess_mask
from train import image_data_generator

In [None]:
# Load the trained U-Net and one sample, predict its mask, and plot the
# prediction next to the preprocessed image and the ground-truth mask.
dir_data = Path("../data/all_data/")
assert dir_data.exists(), f"Data directory not found: {dir_data.resolve()}"
path_model = Path("../models/unet_model.keras")
assert path_model.exists(), f"Model file not found: {path_model.resolve()}"
model = tf.keras.models.load_model(path_model)

# Which sample (image_<N>.npy / mask_<N>.npy) to inspect.
SAMPLE_INDEX = 10
# NOTE(review): presumably the model's input size — confirm against the training config.
MODEL_IMAGE_SIZE = (256, 256)
# Preprocessing settings passed to BOTH predict_mask and preprocess_image below.
# Defining them once prevents the two calls from drifting apart, which would make
# the plotted image differ from what the prediction was computed on.
PREPROCESSING_PARAMS = {
    "norm_lower_bound": -1.0,
    "norm_upper_bound": 7.0,
    "filter_channels": ["original"],
    "hessian_component": "minima",
    "hessian_sigma": 4.0,
}

image = np.load(dir_data / f"image_{SAMPLE_INDEX}.npy")
mask = np.load(dir_data / f"mask_{SAMPLE_INDEX}.npy")

predicted_mask, predicted_mask_binary, predicted_mask_flat_discrete, predicted_mask_flat = predict_mask(
    model=model,
    image=image,
    **PREPROCESSING_PARAMS,
)

plt.imshow(predicted_mask_flat_discrete)
plt.title(f"Discrete flat prediction (sample {SAMPLE_INDEX})")
plt.show()

image_preprocessed = preprocess_image(
    image=image,
    model_image_size=MODEL_IMAGE_SIZE,
    **PREPROCESSING_PARAMS,
)

# Keep the raw mask under its own name so re-reading this cell is unambiguous;
# the preprocessed version is what gets compared against the prediction.
mask_preprocessed = preprocess_mask(
    mask=mask,
    model_image_size=MODEL_IMAGE_SIZE,
    output_channels=3,
)

plot_image_mask_prediction(
    image=image_preprocessed,
    mask=mask_preprocessed,
    mask_predicted=predicted_mask,
    mask_predicted_binary=predicted_mask_binary,
    mask_predicted_flat=predicted_mask_flat,
    mask_predicted_flat_discrete=predicted_mask_flat_discrete,
)


In [None]:
print(predicted_mask.shape)

# Collapse the per-class scores along the last axis into a single label image.
# Convert to a NumPy array so skimage, matplotlib and np.zeros_like below all
# operate on plain ndarrays rather than TensorFlow tensors.
prediction_image = np.argmax(np.asarray(predicted_mask), axis=-1)
print(np.unique(prediction_image))
plt.imshow(prediction_image)
plt.title("argmax prediction")
plt.show()

# Morphological clean-up per foreground class: erode to remove small speckles,
# then dilate to restore the surviving regions. Because NUM_DILATE > NUM_ERODE,
# the surviving regions end up slightly larger than before erosion.
NUM_ERODE = 2
NUM_DILATE = 3
FOREGROUND_CLASSES = (1, 2)


def erode_then_dilate(
    channel_mask: npt.NDArray[np.bool_], num_erode: int, num_dilate: int
) -> npt.NDArray[np.bool_]:
    """Return `channel_mask` after `num_erode` binary erosions followed by `num_dilate` dilations.

    Both operations use skimage's default footprint (none is passed).
    """
    cleaned = channel_mask
    for _ in range(num_erode):
        cleaned = binary_erosion(cleaned)
    for _ in range(num_dilate):
        cleaned = binary_dilation(cleaned)
    return cleaned


cleaned_channels = {}
for class_index in FOREGROUND_CLASSES:
    channel_mask = prediction_image == class_index
    cleaned_channels[class_index] = erode_then_dilate(channel_mask, NUM_ERODE, NUM_DILATE)

    fig, (ax_before, ax_after) = plt.subplots(1, 2)
    ax_before.imshow(channel_mask)
    ax_before.set_title(f"class {class_index}: raw")
    ax_after.imshow(cleaned_channels[class_index])
    ax_after.set_title(f"class {class_index}: erode x{NUM_ERODE}, dilate x{NUM_DILATE}")
    plt.show()

# Recombine into a label image. Classes are written in FOREGROUND_CLASSES order,
# so where the dilated masks overlap, the later (higher) class index wins.
final_prediction = np.zeros_like(prediction_image)
for class_index, cleaned in cleaned_channels.items():
    final_prediction[cleaned] = class_index
plt.imshow(final_prediction)
plt.title("cleaned prediction")
plt.show()



In [None]:
# Sanity check of binary_erosion on a 5x5 plus-shaped blob. Expected (assuming
# skimage's default cross-shaped footprint): only the centre pixel survives.
plus_shape = np.zeros((5, 5), dtype=int)
plus_shape[2, 1:4] = 1  # horizontal bar of the plus
plus_shape[1:4, 2] = 1  # vertical bar of the plus

eroded = binary_erosion(plus_shape)

for heading, shown in (("Original:", plus_shape), ("Eroded:", eroded.astype(int))):
    print(heading)
    print(shown)