In [None]:
from keras.utils import normalize
import os
import cv2
from PIL import Image
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
import seaborn as sns
from unet import unet_model
import random
from pathlib import Path
import logging
import itertools

LOGGER = logging.getLogger()
from datetime import datetime

from skimage.morphology import binary_erosion
from skimage.morphology import skeletonize
from skimage.filters import hessian
from skimage.feature import hessian_matrix, hessian_matrix_eigvals
from skimage.morphology import label
from skimage.measure import regionprops
from skimage.color import label2rgb

# TopoStats needs to be >= version 2.1.0
from topostats import io

import tensorflow as tf
from sklearn.model_selection import train_test_split

In [None]:
# Check that TensorFlow can see a GPU. `tf.test.gpu_device_name()` returns an
# empty string when no GPU is available, which is easy to overlook in the cell
# output, so warn explicitly: training can still run on the CPU, just slower.
gpu_device_name = tf.test.gpu_device_name()
if not gpu_device_name:
    print("WARNING: TensorFlow cannot see a GPU; training will run on the CPU.")
gpu_device_name

In [None]:
# Seed every random number generator this notebook uses so that
# Restart & Run All reproduces the same results: NumPy (dataset-check index),
# TensorFlow, and the stdlib `random` module, which picks the example
# images shown by the visual-check cells.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# Paths and image size. The training-data root can be overridden with the
# CATS_DATA_DIR environment variable, so the notebook runs on another machine
# without editing this cell. The default is the original author's location.
DATA_DIR = Path(os.environ.get("CATS_DATA_DIR", "/Users/sylvi/topo_data/cats/training_data"))
# Training inputs (sigma = 1 edge-detection images, per the directory name).
IMAGE_DIR = DATA_DIR / "images_edge_detection_lower_training_sigma_1"
# Training labels; paired with IMAGE_DIR files by sorted file name.
MASK_DIR = DATA_DIR / "images_edge_detection_lower_labels_sigma_1"
MODEL_SAVE_DIR = Path("./catsnet/saved_models/")
# Every image and mask is resized to SIZE x SIZE pixels before training.
SIZE = 512

In [None]:
def detect_ridges(gray, sigma=1.0):
    """Per-pixel largest eigenvalue of the Hessian of ``gray``.

    ``sigma`` is the Gaussian scale forwarded to ``hessian_matrix``. Of the
    two eigenvalue images returned by ``hessian_matrix_eigvals`` only the
    first (the larger, per the skimage docs) is returned; the second is
    discarded.
    """
    hessian_elements = hessian_matrix(gray, sigma=sigma, order="rc")
    largest, _smallest = hessian_matrix_eigvals(hessian_elements)
    return largest

In [None]:
def _is_png(file_name: str) -> bool:
    """Return True if the final extension of ``file_name`` is ``.png``.

    Only the last suffix is inspected, so names with extra dots
    (e.g. ``scan.0_00001.png``) are recognised correctly.
    """
    return Path(file_name).suffix == ".png"


def _load_resized_grayscale(path: Path, resample=None) -> np.ndarray:
    """Read ``path`` as 8-bit grayscale and resize it to SIZE x SIZE.

    ``resample`` is passed to ``PIL.Image.resize``; ``None`` keeps Pillow's
    default filter. Raises ``IOError`` when OpenCV cannot read the file,
    because ``cv2.imread`` returns ``None`` instead of raising.
    """
    raw = cv2.imread(str(path), 0)
    if raw is None:
        raise IOError(f"could not read image file: {path}")
    pil_image = Image.fromarray(raw)
    if resample is None:
        pil_image = pil_image.resize((SIZE, SIZE))
    else:
        pil_image = pil_image.resize((SIZE, SIZE), resample)
    return np.array(pil_image)


def _normalise_to_unit_range(image: np.ndarray) -> np.ndarray:
    """Shift and scale ``image`` into [0, 1]; a constant image maps to all zeros."""
    shifted = image - np.min(image)
    peak = np.max(shifted)
    if peak == 0:
        # Avoid 0 / 0 -> NaN for blank images.
        return shifted.astype(np.float64)
    return shifted / peak


# Images and masks are paired by position after sorting, so both directories
# must hold the same number of PNG files with matching sort order.
image_names = sorted(name for name in os.listdir(IMAGE_DIR) if _is_png(name))
mask_names = sorted(name for name in os.listdir(MASK_DIR) if _is_png(name))
print(f" -- images ({len(image_names)}) --")
print(image_names)
print(f"-- masks ({len(mask_names)}) --")
print(mask_names)
assert len(image_names) == len(mask_names), (
    f"{len(image_names)} images but {len(mask_names)} masks - they cannot be paired"
)

# Input images: resized with Pillow's default filter, then normalised to [0, 1].
image_dataset = [_normalise_to_unit_range(_load_resized_grayscale(IMAGE_DIR / name)) for name in image_names]

# Masks: resized with NEAREST so the labels stay binary. An interpolating
# filter blurs the label edges, and the bool cast would then turn every
# non-zero blurred pixel into True, thickening the labels.
mask_dataset = [_load_resized_grayscale(MASK_DIR / name, Image.NEAREST).astype(bool) for name in mask_names]

# Flip/rotation augmentation is applied later, after the train/test split,
# so augmented copies of test images never end up in the training set.

In [None]:
# Check the data has been loaded correctly by showing one random image/mask pair.

print(f"image dataset size: {len(image_dataset)}")
print(f"mask dataset size: {len(mask_dataset)}")

# np.random.randint's upper bound is exclusive, so pass len() to make every
# index (including the last) selectable. The old `len - 1` bound never picked
# the last image and raised ValueError when there was only one image.
index = np.random.randint(len(image_dataset))
print(f"index: {index}")

plt.imshow(image_dataset[index])
plt.show()
print(f"img dataset | min: {np.min(image_dataset)} max: {np.max(image_dataset)}")
print(np.shape(image_dataset[index]))

plt.imshow(mask_dataset[index])
plt.show()
print(f"unique: {np.unique(mask_dataset[index])}")
print(f"shape: {np.shape(mask_dataset[index])}")

In [None]:
# The model is built with an explicit channel dimension
# (IMG_CHANNELS = image_dataset.shape[3] further down), so turn the lists of
# (SIZE, SIZE) arrays into (N, SIZE, SIZE, 1) arrays.
# The axis is only added when it is missing, so re-running this cell is safe.
# Previously every run added another axis.
image_dataset = np.asarray(image_dataset)
mask_dataset = np.asarray(mask_dataset)
if image_dataset.ndim == 3:
    image_dataset = np.expand_dims(image_dataset, 3)
if mask_dataset.ndim == 3:
    mask_dataset = np.expand_dims(mask_dataset, 3)

print(image_dataset.shape)
print(mask_dataset.shape)
print(f"image dataset min, max: {np.min(image_dataset), np.max(image_dataset)}")
print(f"mask unique values: {np.unique(mask_dataset)}")

In [None]:
# Hold out 10% of the un-augmented data as the test set; SEED fixes the split.
X_train, X_test, y_train, y_test = train_test_split(
    image_dataset,
    mask_dataset,
    test_size=0.1,
    random_state=SEED,
)
for split_array in (X_train, X_test, y_train, y_test):
    print(split_array.shape)

In [None]:
def rotate_image(image: np.ndarray, angle_90_multiple: int):
    """Rotate ``image`` by ``angle_90_multiple`` quarter turns in its first two axes.

    The rotation direction is ``np.rot90``'s: counter-clockwise for positive counts.
    """
    quarter_turns = angle_90_multiple
    return np.rot90(m=image, k=quarter_turns)


def flip_image(image: np.ndarray):
    """Mirror ``image`` top-to-bottom by reversing its first axis."""
    as_array = np.asanyarray(image)
    if as_array.ndim < 1:
        raise ValueError("Input must be >= 1-d.")
    return as_array[::-1, ...]


def augment_image(image: np.ndarray):
    """Return the 8 flip/rotation variants of a 2-D ``image`` as a list.

    Previously this was a stub that always returned an empty list. The order
    matches ``augment_image_set``: the original, its up-down flip, the flip
    rotated by 1, 2 and 3 quarter turns, then the original rotated by 1, 2
    and 3 quarter turns. These are the 8 symmetries of a square, so a square
    input yields 8 arrays of the same shape.
    """
    flipped = np.flipud(image)
    images = [image, flipped]
    images.extend(np.rot90(flipped, k) for k in (1, 2, 3))
    images.extend(np.rot90(image, k) for k in (1, 2, 3))
    return images


def augment_image_set(training_set: np.ndarray):
    """Expand a batch of images 8x with every flip/rotation of the square.

    Parameters
    ----------
    training_set : np.ndarray
        Array of shape (N, H, W, C). H must equal W so that the 90-degree
        rotations keep the shape.

    Returns
    -------
    np.ndarray
        float64 array of shape (N * 8, H, W, C). The 8 variants of sample ``i``
        occupy rows ``8 * i`` .. ``8 * i + 7`` in this order: original, up-down
        flip, flip rotated 1/2/3 quarter turns, original rotated 1/2/3 quarter
        turns. All channels are transformed. Previously only channel 0 was
        filled and any further channels were left as zeros.
    """
    # Flip/rotate in the (H, W) plane of every sample at once; axis 0 is the batch.
    flipped = np.flip(training_set, axis=1)
    variants = [training_set, flipped]
    variants += [np.rot90(flipped, k, axes=(1, 2)) for k in (1, 2, 3)]
    variants += [np.rot90(training_set, k, axes=(1, 2)) for k in (1, 2, 3)]
    # Stack on a new axis 1 -> (N, 8, H, W, C), then fold it into the batch axis
    # so each sample's 8 variants stay contiguous, as in the original layout.
    n_samples, height, width, channels = training_set.shape
    stacked = np.stack(variants, axis=1)
    # float64 output as before (masks come in as bool and leave as 0.0 / 1.0).
    return stacked.reshape(n_samples * 8, height, width, channels).astype(np.float64)

In [None]:
# Augment the training split only (never the test split).
# Guard against re-running this cell: a second run would augment the
# already-augmented set again (64x). Straight after the split, X_train is
# smaller than the full dataset because part of it is held out for testing;
# after augmentation it is 8x that share and therefore larger.
assert len(X_train) < len(image_dataset), (
    "X_train looks already augmented - re-run the train/test split cell first"
)
print(X_train.shape)
print(y_train.shape)

X_train = augment_image_set(X_train)
y_train = augment_image_set(y_train)

print(X_train.shape)
print(y_train.shape)

In [None]:
# Sanity-check value ranges after augmentation: the inputs were normalised to
# [0, 1] and the masks should only contain the two label values.
for summary in (np.max(X_train), np.min(X_train), np.unique(y_train)):
    print(summary)

In [None]:
# Visual check: show one random augmented training image next to its mask.
sample_idx = random.randint(0, len(X_train) - 1)
print(f"image number: {sample_idx} / {len(X_train)}")
fig, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(12, 6))
image_ax.imshow(X_train[sample_idx].reshape(SIZE, SIZE), cmap="gray")
mask_ax.imshow(y_train[sample_idx].reshape(SIZE, SIZE), cmap="gray")
plt.show()

In [None]:
# The U-Net's input shape is taken from the prepared data: (height, width,
# channels) of image_dataset, i.e. axes 1-3 after the batch axis.
IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS = (image_dataset.shape[axis] for axis in (1, 2, 3))


def get_model():
    """Build a fresh U-Net for IMG_HEIGHT x IMG_WIDTH x IMG_CHANNELS inputs."""
    return unet_model(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)


model = get_model()

In [None]:
# Train the model. BATCH_SIZE and EPOCHS are used both for fit() and in the
# saved file name, so the name always describes how the model was trained.
# (Previously fit() hard-coded 4 and 50, so editing the constants only changed the name.)
BATCH_SIZE = 4
EPOCHS = 50
# NOTE(review): with shuffle=False and augment_image_set's layout (8 consecutive
# rows per source image), every batch of 4 comes from a single image. Consider
# shuffle=True for more varied batches.
history = model.fit(
    X_train,
    y_train,
    batch_size=BATCH_SIZE,
    verbose=1,
    epochs=EPOCHS,
    validation_data=(X_test, y_test),
    shuffle=False,
)

# Save the model with the date and training parameters in the name.
dt_string = datetime.now().strftime("%Y%m%d_%H-%M-%S")
# Make sure the target directory exists before writing the file.
MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)
filename = str(MODEL_SAVE_DIR / f"{dt_string}_cats_{SIZE}_b{BATCH_SIZE}_e{EPOCHS}_hessian_lower_1_0.hdf5")
print(f"saving file: {filename}")
model.save(filename)

## Load model

In [None]:
# Load a previously saved model.
# This used to run unconditionally, so Run All silently replaced the model
# trained above with a different file before evaluation. The saved model is
# now loaded only when requested, or when this session has no model yet
# (e.g. the training cell was skipped).
LOAD_SAVED_MODEL = False
SAVED_MODEL_FILE = MODEL_SAVE_DIR / "20230815_17-50-23_cats_512_b4_e50_hessian_lower_4_0.hdf5"
# SAVED_MODEL_FILE = MODEL_SAVE_DIR / "20230811_14-26-19_cats.hdf5"
if LOAD_SAVED_MODEL or "model" not in globals():
    print(f"loading model: {SAVED_MODEL_FILE}")
    model = tf.keras.models.load_model(SAVED_MODEL_FILE)
model.summary()

In [None]:
# Evaluate on the held-out test set. Assumes the model was compiled with a
# single extra metric (accuracy), so evaluate() returns [loss, accuracy].
test_loss, acc = model.evaluate(X_test, y_test)
accuracy_percent = acc * 100.0
print("Accuracy = ", accuracy_percent, "%")

In [None]:
def _plot_training_curve(metric: str, short_name: str, long_name: str) -> None:
    """Plot the per-epoch training and validation values of ``metric`` from ``history``."""
    train_values = history.history[metric]
    val_values = history.history[f"val_{metric}"]
    epochs = range(1, len(train_values) + 1)
    fig, ax = plt.subplots()
    ax.plot(epochs, train_values, "y", label=f"Training {short_name}")
    ax.plot(epochs, val_values, "r", label=f"Validation {short_name}")
    ax.set_title(f"Training and validation {long_name}")
    ax.set_xlabel("Epochs")
    ax.set_ylabel(long_name.capitalize())
    ax.legend()
    plt.show()


# `history` only exists when the model was trained in this session, not when it was loaded.
if "history" in globals():
    _plot_training_curve("loss", "loss", "loss")
    _plot_training_curve("accuracy", "acc", "accuracy")
else:
    print("No training history in this session (model was loaded, not trained) - skipping curves.")

In [None]:
# Note that for semantic segmentation, accuracy is not the correct metric:
# when most pixels are background, predicting background everywhere already
# scores highly.

# Intersection over union (IoU) of the thresholded predictions and the
# ground-truth masks, pooled over the whole test set.
IOU_THRESHOLD = 0.5  # probability cutoff for a pixel to count as predicted foreground
y_pred = model.predict(X_test)
y_pred_thresholded = y_pred > IOU_THRESHOLD

intersection = np.logical_and(y_test, y_pred_thresholded)
union = np.logical_or(y_test, y_pred_thresholded)
union_size = np.sum(union)
# With no foreground in either masks or predictions the IoU is 0 / 0.
# Report that as a perfect match instead of nan.
iou_score = np.sum(intersection) / union_size if union_size > 0 else 1.0
print(f"IoU score: {iou_score}")

In [None]:
# See how the model predicts one random image from the test set.

# Probability cutoff for this visualisation. NOTE(review): it differs from
# the 0.5 used for the IoU above - confirm which value is intended.
threshold = 0.15

test_img_number = random.randint(0, len(X_test) - 1)
print(f"test image number: {test_img_number} / {len(X_test)}")
test_img = X_test[test_img_number]
ground_truth = y_test[test_img_number].reshape(SIZE, SIZE)
print(f"ground truth shape: {ground_truth.shape}")

# Histogram of the input's pixel values. ravel() matters: given a 2-D array,
# plt.hist draws one histogram per column (SIZE overlaid datasets).
fig, ax = plt.subplots()
ax.hist(test_img[:, :, 0].ravel())
ax.set_title("Test image pixel values")
plt.show()

# Add the batch axis the model expects; keep only the first channel, as before.
test_img_input = test_img[np.newaxis, :, :, :1]
prediction = (model.predict(test_img_input)[0, :, :, 0] > threshold).astype(np.uint8)

fig, axes = plt.subplots(1, 3, figsize=(16, 8))
panels = ((test_img[:, :, 0], "Test image"), (ground_truth, "Testing label"), (prediction, "Prediction"))
for ax, (panel, title) in zip(axes, panels):
    ax.imshow(panel, cmap="gray")
    ax.set_title(title)
plt.show()

# Prediction overlaid on the input; pixels predicted as background are transparent.
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(test_img[:, :, 0])
ax.imshow(np.ma.masked_where(prediction == 0, prediction))
plt.show()

# Signed difference: +1 = foreground only in the ground truth, -1 = only in
# the prediction, 0 = agreement. The prediction used to be inverted (== 0)
# before this subtraction, so the plot compared the label with the predicted
# background. vmin/vmax pin the colours even if a value is absent.
print(f"prediction shape: {prediction.shape}")
difference = ground_truth.astype(int) - prediction.astype(int)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(difference, vmin=-1, vmax=1)
ax.set_title("prediction difference - yellow: ground truth only, purple: prediction only, teal: agreement")
plt.show()

# Label and colour the connected regions of predicted background
# (the areas separated by the predicted foreground).
regions = prediction == 0
labelled = label(regions)
coloured = label2rgb(labelled)
fig, ax = plt.subplots()
ax.imshow(coloured)
ax.set_title("connected regions of predicted background")
plt.show()

In [None]:
# Try a completely new image

# FROM PNG
# another_image_file = './dl_data/to_be_labelled/20230207_MeO_perov_10um.0_00001.png'
# another_image_file = str(
#     Path(
#         "/Users/sylvi/topo_data/textured-silicon/flattened_images/tapp_ref_perovskite_si_fraunhofer_5um.0_00000_sylvia_freq_split_1DFFT_filter.png"
#     )
# )
another_image_file = str(
    Path("/Users/sylvi/topo_data/cats/training_data/images_edge_detection_lower_all_sigma_1/training_image_11.png")
)
image = cv2.imread(another_image_file, 0)
# cv2.imread returns None instead of raising when the file cannot be read.
if image is None:
    raise IOError(f"could not read image file: {another_image_file}")
print(f"original shape: {image.shape}")

# Probability cutoff for this prediction. NOTE(review): much lower than the
# 0.5 used for the test-set IoU - confirm this value.
threshold = 0.01

# FROM SPM
# another_image_file = Path('/home/sylvia/Desktop/unseen_data/spm_images/20230512-Me41-100NPBSpin-512-5um.0_00000.spm')
# another_image_file = Path('/Users/sylvi/topo_data/perovskite/perovskite_images/Me/20230512-Me41-100NPBSpin-512-5um.0_00000.spm')
# another_image_file = Path('/Users/sylvi/topo_data/perovskite/perovskite_images/AR/AR123_926F_FACsPbI3_evap_5um.0_00002.spm')
# another_image_file = Path('/Users/sylvi/topo_data/perovskite/perovskite_images/AR/AR115_25_BA_MAPbI3_10um.ibw')

# loadscan = io.LoadScans(img_paths = [another_image_file], channel='HeightRetrace')
# loadscan = io.LoadScans(img_paths = [another_image_file], channel='Height')
# loadscan.get_data()
# image, pixel_to_nm = loadscan.load_ibw()
# image, pixel_to_nm = loadscan.load_spm()

# image = perov_flatten.flatten_image(image, order=3, plot_steps=True)

# image = detect_ridges(image, sigma=1.0)

# Preprocess exactly as the training images were: resize to SIZE x SIZE
# first, then normalise to [0, 1]. (This cell used to normalise first and
# resize the float image, which differs from the training pipeline.)
print(f"image shape: {image.shape}")
image = np.array(Image.fromarray(image).resize((SIZE, SIZE)))
print(f"image shape: {image.shape}")

sns.kdeplot(image.flatten())
plt.show()
print(f"min: {np.min(image)} max: {np.max(image)}")
image = image - np.min(image)
peak = np.max(image)
# A constant image would otherwise divide 0 / 0 and become NaN.
image = image / peak if peak > 0 else image.astype(np.float64)
print(f"min: {np.min(image)} max: {np.max(image)}")
sns.kdeplot(image.flatten())
plt.show()

# The model was trained on arrays shaped (N, SIZE, SIZE, 1): add a batch axis
# in front and a channel axis at the end -> (1, SIZE, SIZE, 1).
model_input = image[np.newaxis, :, :, np.newaxis]
print(f"to predict shape: {model_input.shape}")

# Get prediction
prediction = (model.predict(model_input)[0, :, :, 0] > threshold).astype(np.uint8)
print(f"prediction unique vals: {np.unique(prediction)}")
# prediction = skeletonize(prediction)
plt.imshow(image)
plt.show()
plt.imshow(prediction)
plt.show()

# # Pot prediction
# fig, ax = plt.subplots(figsize=(20, 20))
# ax.imshow(image)
# masked = np.ma.masked_where(prediction.astype(int) == 0, prediction)
# ax.imshow(masked)
# ax.set_title("prediction")
# plt.show()

# labelled = label(prediction == 0, connectivity=1)
# coloured = label2rgb(labelled)
# fig, ax = plt.subplots()
# ax.imshow(coloured)
# plt.show()

# fig, ax = plt.subplots(figsize=(14, 14))
# test_mask = np.zeros(image.shape)
# for j in range(coloured.shape[0]):
#     for i in range(coloured.shape[1]):
#         if np.array_equal(coloured[j, i], np.array([0, 0, 0])):
#             test_mask[j, i] = 1


# overlay = np.zeros(image.shape)

# test_masked = np.ma.masked_where(test_mask == 1, overlay)
# ax.imshow(image)
# ax.imshow(test_masked, alpha=0.1, cmap="binary")
# ax.set_title("overlay")
# plt.show()

# # Plot skeletonised segmentation
# skeleton_prediction = skeletonize(prediction)
# print(f"unique skeleton prediction: {np.unique(skeleton_prediction)}")
# plt.imshow(skeleton_prediction)
# plt.show()

# fig, ax = plt.subplots(figsize=(20, 20))
# ax.imshow(image)
# ax.set_title("flattened image")
# plt.show()

# fig, ax = plt.subplots(figsize=(20, 20))
# ax.imshow(image)
# masked = np.ma.masked_where(skeleton_prediction == 0, skeleton_prediction.astype(int))
# ax.imshow(masked)
# plt.title("skeletonized prediction - if cannot see skeleton, increase image size")
# plt.show()

# fig, ax = plt.subplots(4, 2, figsize=(10, 22))
# section_size = 40
# vmin = np.min(image)
# vmax = np.max(image)
# for index in range(ax.shape[0]):
#     y = np.random.randint(0, image.shape[0] - section_size)
#     x = np.random.randint(0, image.shape[1] - section_size)
#     img_section = image[y : y + section_size + 1, x : x + section_size + 1]
#     mask_section = masked[y : y + section_size + 1, x : x + section_size + 1]
#     ax[index, 0].imshow(img_section, vmin=vmin, vmax=vmax)
#     ax[index, 1].imshow(img_section, vmin=vmin, vmax=vmax)
#     ax[index, 1].imshow(mask_section)
#     ax[index, 0].set_title(f"coords: x: {x}, y: {y}")
#     ax[index, 1].set_title(f"coords: x: {x}, y: {y}")
# fig.suptitle(f"file: {another_image_file} | section size:{section_size}")
# fig.tight_layout()