In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import keras

from topostats.unet_masking import predict_unet, mean_iou, iou_loss
from topostats.io import LoadScans

In [None]:
# --- Model -----------------------------------------------------------------
# NOTE(review): absolute, machine-specific path. Point this at your own copy of
# the trained model before running the notebook on another machine.
model_path = Path("/Users/sylvi/topo_data/eva-rna/test-running-model/eva_rna_model_40_epoch_simple_20250316.keras")
# compile=False: the model is only used for prediction below, so it is loaded
# without compiling. The custom metric/loss are still passed so Keras can
# resolve those names if the saved model refers to them.
unet_model = keras.models.load_model(
    model_path, custom_objects={"mean_iou": mean_iou, "iou_loss": iou_loss}, compile=False
)

# --- Data ------------------------------------------------------------------
# NOTE(review): absolute, machine-specific path (see note on model_path).
data_dir = Path("/Users/sylvi/topo_data/eva-rna/test-running-model/output/processed")
# Path.glob() yields a lazy, single-use iterator in filesystem-dependent order.
# Materialise it into a sorted list so the scans are processed in a stable
# order and `files` can be inspected or reused after loading.
files = sorted(data_dir.glob("*.topostats"))
if not files:
    # Fail loudly: otherwise the loop over loaded scans simply does nothing and
    # a wrong path looks like "no output" instead of an error.
    raise FileNotFoundError(f"No .topostats files found in {data_dir}")

loadscans = LoadScans(
    img_paths=files,
    # NOTE(review): "dummy" looks like a placeholder channel name; presumably the
    # channel is not needed when reading .topostats files -- confirm.
    channel="dummy",
)
# Populates loadscans.img_dict, which the prediction loop iterates over.
loadscans.get_data()

# Crop window applied to every scan; hoisted out of the loop because it does
# not change between files. With the indexing used below, crop_x offsets the
# FIRST array axis (rows) and crop_y the SECOND (columns).
crop_size = 400
crop_x = 80
crop_y = 80

# For each loaded scan: show the full image and the crop side by side, run the
# U-Net on the crop, then show channel 1 of the predicted mask.
for filename, file_data in loadscans.img_dict.items():
    print(f"File: {filename}")
    print(file_data.keys())

    image = file_data["image"]
    crop = image[crop_x : crop_x + crop_size, crop_y : crop_y + crop_size]

    # Slicing past the edge of the array is silently truncated (assuming a
    # NumPy array), so a scan smaller than the crop window yields a smaller or
    # even empty crop. Skip empty crops and flag truncated ones rather than
    # feeding an unexpected shape to the model without notice.
    if crop.size == 0:
        print(f"Skipping {filename}: crop window lies outside image of shape {image.shape}")
        continue
    if crop.shape[:2] != (crop_size, crop_size):
        print(f"Warning: crop for {filename} truncated to {crop.shape[:2]} (requested {(crop_size, crop_size)})")

    # Separate axes for the full image and the crop; drawing both with
    # plt.imshow on the same implicit axes would paint one over the other.
    fig, (ax_full, ax_crop) = plt.subplots(1, 2, figsize=(10, 5))
    ax_full.imshow(image, cmap="gray")
    ax_full.set_title(f"{filename}: full image")
    ax_crop.imshow(crop, cmap="gray")
    ax_crop.set_title(f"{filename}: crop")
    plt.show()

    # Predict the mask using the U-Net model
    # NOTE(review): confidence and the norm bounds are passed straight through
    # to predict_unet; their exact meaning is defined there -- values here are
    # presumably tuned for this dataset, confirm before reusing elsewhere.
    predicted_mask = predict_unet(
        image=crop,
        model=unet_model,
        confidence=0.5,
        model_input_shape=unet_model.input_shape,
        upper_norm_bound=7.0,
        lower_norm_bound=-1.0,
    )

    # Channel 1 of the last axis is treated as the foreground class
    # (assumes channel 0 is background -- TODO confirm against predict_unet).
    predicted_mask_foreground = predicted_mask[:, :, 1]

    fig, ax_mask = plt.subplots()
    ax_mask.imshow(predicted_mask_foreground, cmap="gray")
    ax_mask.set_title(f"{filename}: predicted foreground mask")
    plt.show()