Import libraries

In [None]:
import tensorflow as tf
import tensorflow_io as tfio
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np
import os

Set the input and output paths (training/test images and masks, model folder).

In [None]:
# Root folder of the tree-ring data. Override it with the TREE_RINGS_DATA_ROOT
# environment variable to run the notebook on another machine; every path below
# is derived from it.
DATA_ROOT = os.environ.get("TREE_RINGS_DATA_ROOT", "/media/baecker/6b38a953-6650-4da5-94d9-57bd718df733/2025/in/2007_tree_rings")
train_input_path = os.path.join(DATA_ROOT, "input_images", "train", "image")
train_mask_path = os.path.join(DATA_ROOT, "input_images", "train", "mask")
test_input_path = os.path.join(DATA_ROOT, "input_images", "test", "image")
test_mask_path = os.path.join(DATA_ROOT, "input_images", "test", "mask")
model_path = os.path.join(DATA_ROOT, "models")

Get the paths of the images and masks

In [None]:
def _list_tifs(folder):
    """Return the full paths of all `.tif` files in `folder`, sorted by file name.

    Sorting matters: `os.listdir` returns entries in arbitrary order, and the image
    and mask lists below are paired by position.
    """
    return sorted(os.path.join(folder, name) for name in os.listdir(folder) if name.endswith(".tif"))

train_input_paths = _list_tifs(train_input_path)
train_mask_paths = _list_tifs(train_mask_path)
print("Input images: " + str(len(train_input_paths)))
print("Input masks: " + str(len(train_mask_paths)))
print("---")
test_input_paths = _list_tifs(test_input_path)
test_mask_paths = _list_tifs(test_mask_path)
print("Test images: " + str(len(test_input_paths)))
print("Test masks: " + str(len(test_mask_paths)))
# Image i is paired with mask i. This assumes that corresponding image and mask
# files sort into the same position (e.g. identical names) - TODO confirm.
train_path_dataset = tf.data.Dataset.from_tensor_slices((train_input_paths, train_mask_paths))
test_path_dataset = tf.data.Dataset.from_tensor_slices((test_input_paths, test_mask_paths))

In [None]:
# Print the first (image path, mask path) pair to check that inputs and masks line up.
for image_path, mask_path in train_path_dataset.take(1):
    print((image_path, mask_path))

We define a function to read image/mask pairs.

In [None]:
def read_images(img_path, segmentation_mask_path):
    """Load one (image, mask) pair of TIFF files.

    Returns the first three channels of the decoded image and the first channel of
    the decoded mask, both with the dtype produced by `decode_tiff`.
    """
    def _decode(path):
        # Read the raw bytes and decode them as TIFF.
        return tfio.experimental.image.decode_tiff(tf.io.read_file(path))

    img = _decode(img_path)[:, :, :3]
    segm_mask = _decode(segmentation_mask_path)[:, :, :1]
    return img, segm_mask

Normalize images and masks.

In [None]:
def prepare_images(img, semg_mask):
    """Turn a decoded (image, mask) pair into float32 training tensors.

    The image goes through `tf.image.convert_image_dtype`, which for integer input
    also rescales the values to [0, 1].

    The mask is binarized: every non-zero pixel becomes 1.0 and every other pixel
    0.0. These {0, 1} targets fit the sigmoid output, the BCE/Dice loss and the
    Precision/Recall metrics.

    Before this change, the mask was rescaled by `convert_image_dtype` AND divided
    by 255 a second time, so ring pixels became about 1/255 (or 1/65025 for 0/1
    masks). The inference THRESHOLD (0.000005) used later in this notebook was
    tuned for models trained on those tiny targets. Models trained with this
    version need a threshold around 0.5 instead.
    """
    img = tf.image.convert_image_dtype(img, tf.float32)
    semg_mask = tf.cast(semg_mask > 0, tf.float32)
    return img, semg_mask

We create a dataset containing pairs of images/masks.

In [None]:
# Decode every (image path, mask path) pair and normalize it, in parallel.
train_dataset = (
    train_path_dataset
    .map(read_images, num_parallel_calls=tf.data.AUTOTUNE)
    .map(prepare_images, num_parallel_calls=tf.data.AUTOTUNE)
)
test_dataset = (
    test_path_dataset
    .map(read_images, num_parallel_calls=tf.data.AUTOTUNE)
    .map(prepare_images, num_parallel_calls=tf.data.AUTOTUNE)
)

Build train and validation batches.

In [None]:
BATCH_SIZE = 16
BUFFER_SIZE = 1000
# Hold out the first 20 % of the (image, mask) pairs for validation.
VALIDATION_SIZE = int(round((len(train_dataset) * 20) / 100))
print("validation data size: " + str(VALIDATION_SIZE))
print("train data size: " + str(len(train_dataset) - VALIDATION_SIZE))
validation_batches = train_dataset.take(VALIDATION_SIZE).batch(BATCH_SIZE)
# Train only on the remaining pairs. Skipping the validation part here is what
# keeps it out of training. Cache the decoded tensors, reshuffle on every
# repetition and repeat indefinitely (fit() bounds an epoch with steps_per_epoch).
train_batches = (
    train_dataset.skip(VALIDATION_SIZE)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)

In [None]:
# Show the element spec of the (image, mask) dataset.
train_dataset.take(count=1)

In [None]:
# Look at the mask tensor of the first (image, mask) pair.
data = train_dataset.take(1)
mask = data.get_single_element()[-1]
mask

Display some random examples of pairs of input tiles and mask tiles.

In [None]:
import matplotlib.pyplot as plt
N = 3
# Show N random (input tile, mask tile) pairs side by side, printing their shapes first.
for image, mask in train_dataset.shuffle(len(train_dataset)).take(N):
    print(image.shape, mask.shape, sep="\n")
    fig, axes = plt.subplots(1, 2)
    axes[0].imshow(image)
    axes[1].imshow(mask)
    plt.show()

Building blocks for the UNet.

In [None]:
def double_conv_block(x, n_filters):
   """Apply two 3x3 'same'-padded Conv2D + ReLU layers (He-normal init) to `x`."""
   for _ in range(2):
      x = layers.Conv2D(n_filters, 3, padding="same", activation="relu", kernel_initializer="he_normal")(x)
   return x

def downsample_block(x, n_filters):
   """One encoder level of the U-Net.

   Returns:
      (features, pooled): `features` is the double-conv output kept for the skip
      connection; `pooled` is the 2x2 max-pooled, 30 %-dropout tensor passed to
      the next level.
   """
   features = double_conv_block(x, n_filters)
   pooled = layers.MaxPool2D(2)(features)
   pooled = layers.Dropout(0.3)(pooled)
   return features, pooled

def upsample_block(x, conv_features, n_filters):
   """One decoder level of the U-Net.

   `x` is upsampled by a stride-2 3x3 transposed convolution, concatenated with the
   matching encoder `conv_features` (skip connection), passed through 30 % dropout
   and then through a double conv block with `n_filters` filters.
   """
   upsampled = layers.Conv2DTranspose(n_filters, 3, 2, padding="same")(x)
   merged = layers.concatenate([upsampled, conv_features])
   merged = layers.Dropout(0.3)(merged)
   return double_conv_block(merged, n_filters)

Function that builds the UNet

In [None]:
def build_unet_model(input_shape=(256, 256, 3), filters=(64, 128, 256, 512)):
   """Build a 2-D U-Net for binary segmentation with the Keras functional API.

   Args:
      input_shape: shape of one input image (height, width, channels). Height and
         width should be divisible by 2 ** len(filters) (16 by default) so the
         upsampled tensors match the skip connections. (None, None, 3) accepts
         variable-size inputs that satisfy this.
      filters: non-empty sequence with the number of filters of each encoder
         level. The bottleneck uses twice the last value, and the decoder mirrors
         the encoder.

   Returns:
      A `tf.keras.Model` named "U-Net" that maps an image to a one-channel
      sigmoid map of the same height and width.
   """
   inputs = layers.Input(shape=input_shape)
   # encoder: contracting path - each level keeps its features for the skip connection
   skips = []
   x = inputs
   for n_filters in filters:
      features, x = downsample_block(x, n_filters)
      skips.append(features)
   # bottleneck
   x = double_conv_block(x, 2 * filters[-1])
   # decoder: expanding path - upsample and merge with the matching encoder level
   for n_filters, features in zip(reversed(filters), reversed(skips)):
      x = upsample_block(x, features, n_filters)
   # outputs: one sigmoid channel per pixel
   outputs = layers.Conv2D(1, (1,1), padding="same", activation="sigmoid")(x)
   return tf.keras.Model(inputs, outputs, name="U-Net")

Build the UNet.

In [None]:
# Instantiate the hand-built U-Net defined above (256x256x3 input, 1-channel sigmoid output).
unet_model = build_unet_model()

In [None]:
from keras_unet_collection import models
# Alternative U-Net from keras_unet_collection; running this cell replaces `unet_model`.
# It has a single sigmoid output channel, so its output matches the 1-channel masks,
# the BCE/Dice loss and the single-channel reshape in the prediction code
# (n_labels=2 + Softmax would produce 2 channels and break all three).
unet_model = models.unet_2d((None, None, 3), [64, 128, 256, 512, 1024], n_labels=1,
                      stack_num_down=2, stack_num_up=1,
                      activation='GELU', output_activation='Sigmoid',
                      batch_norm=True, pool='max', unpool='nearest', name='unet')

In [None]:
# NOTE: not registered with Keras (decorator left disabled).
def dice_loss(y_true, y_pred):
    """Soft Dice loss computed over the whole batch, with +1 smoothing.

    Returns 1 - (2 * sum(y_true * y_pred) + 1) / (sum(y_true) + sum(y_pred) + 1),
    where the sums run over all elements after casting both inputs to float32.
    """
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)
    overlap = tf.reduce_sum(truth * pred)
    total = tf.reduce_sum(truth) + tf.reduce_sum(pred)
    return 1 - (2. * overlap + 1) / (total + 1)

# NOTE: not registered with Keras (decorator left disabled); the returned closure is
# named `bcl`, so a saved compiled model refers to its loss by that name.
def bce_dice_loss(bce_coef=0.5):
    """Return a loss that blends binary cross-entropy and Dice loss.

    loss = bce_coef * BCE(y_true, y_pred) + (1 - bce_coef) * dice_loss(y_true, y_pred)
    """
    dice_coef = 1.0 - bce_coef

    def bcl(y_true, y_pred):
        return bce_coef * tf.keras.losses.binary_crossentropy(y_true, y_pred) + dice_coef * dice_loss(y_true, y_pred)

    return bcl

In [None]:
# Draw the architecture (with tensor shapes) to model.png.
keras.utils.plot_model(unet_model, to_file="model.png", show_shapes=True)
"model.png written"

Compile the model.

In [None]:
# Adam + 30 % BCE / 70 % Dice loss.
# BinaryAccuracy thresholds the sigmoid output before comparing it with the mask;
# plain Accuracy() requires exact equality with continuous predictions and stays ~0.
# name="accuracy" keeps the history keys ('accuracy', 'val_accuracy') used by the plots.
unet_model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=bce_dice_loss(bce_coef=0.3),
                  metrics=[tf.keras.metrics.Precision(),
                           tf.keras.metrics.Recall(),
                           tf.keras.metrics.BinaryAccuracy(name="accuracy")])

In [None]:
NUM_EPOCHS = 100
keras.config.disable_traceback_filtering()
# One epoch = one pass over the training split only (the first VALIDATION_SIZE pairs
# are held out). max(1, ...) avoids steps_per_epoch=0 when the split is smaller
# than one batch.
STEPS_PER_EPOCH = max(1, (len(train_dataset) - VALIDATION_SIZE) // BATCH_SIZE)
VAL_SUBSPLITS = 5
VAL_LENGTH = VALIDATION_SIZE
# Validate on about 1/VAL_SUBSPLITS of the validation batches per epoch (at least one).
VALIDATION_STEPS = max(1, VAL_LENGTH // BATCH_SIZE // VAL_SUBSPLITS)
model_history = unet_model.fit(train_batches,
                              epochs=NUM_EPOCHS,
                              steps_per_epoch=STEPS_PER_EPOCH,
                              validation_steps=VALIDATION_STEPS,
                              validation_data=validation_batches,
                              verbose=2
                              )

Save a model.

In [None]:
import datetime
# Save the model to a time-stamped .keras file in model_path.
date = datetime.datetime.now()
model_file = os.path.join(model_path, f"./unet - {date}.keras")
unet_model.save(model_file)

Save the weights only.

In [None]:
import datetime
# Save only the weights to a time-stamped .weights.h5 file in model_path.
date = datetime.datetime.now()
weights_file = os.path.join(model_path, f"./unet - {date}.weights.h5")
unet_model.save_weights(weights_file)

In [None]:
# Names of the metrics recorded during training (keys used by the plots below).
history_keys = model_history.history.keys()
print(history_keys)

In [None]:
# One figure per metric: training vs. validation value for every epoch.
for metric in ('accuracy', 'loss'):
    plt.plot(model_history.history[metric])
    plt.plot(model_history.history['val_' + metric])
    plt.title('model ' + metric)
    plt.ylabel(metric)
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.show()

Load a saved model

In [None]:
model = "unet - 2025-01-23 11:09:24.145873.keras"
path = os.path.join(model_path, model)
# The model was compiled with the closure `bcl` returned by bce_dice_loss (0.3 as in
# the compile cell). That closure is not registered with Keras, so it is supplied
# by name here to restore the compiled model.
unet_model = keras.models.load_model(path, custom_objects={"bcl": bce_dice_loss(bce_coef=0.3)})

Load the weights only.

In [None]:
# Restore only the weights into the already-built `unet_model` (the architecture must match).
model = 'unet - 2025-01-29 11:27:03.366954.weights.h5'
path = os.path.join(model_path, model)
unet_model.load_weights(filepath=path)

Evaluate model on test data.

In [None]:
# Evaluate on the held-out test pairs (not the training data).
test_batches = test_dataset.batch(BATCH_SIZE)
# return_dict=True labels each value with its metric name instead of relying on
# positional indices (index 1 is precision, not accuracy).
score = unet_model.evaluate(test_batches, verbose=0, return_dict=True)
for metric_name, value in score.items():
    print(f"Test {metric_name}:", value)

# Apply the model to an image.

Functions to read and prepare single images.

In [None]:
def read_image(img_path):
    """Read a TIFF file and return the first three channels of the decoded image."""
    decoded = tfio.experimental.image.decode_tiff(tf.io.read_file(img_path))
    return decoded[:, :, :3]

def prepare_image(img):
    """Convert an image tensor to float32.

    For integer input, `convert_image_dtype` also rescales the values to [0, 1].
    """
    return tf.image.convert_image_dtype(img, "float32")


Create a dataset with patches of one image.

In [None]:
import tensorflow_io as tfio
import matplotlib.pyplot as plt
import cv2

# Image to segment, and where the stitched mask is written further below.
image_file = '/media/baecker/6b38a953-6650-4da5-94d9-57bd718df733/2025/in/2007_tree_rings/input_images/4 E 1 m_8µm_x50.tif'
output_file = '/media/baecker/6b38a953-6650-4da5-94d9-57bd718df733/2025/in/2007_tree_rings/input_images/out/4 E 1 m_8µm_x50.tif'
image_path_dataset = tf.data.Dataset.from_tensor_slices([image_file])
image_dataset = image_path_dataset.map(read_image, num_parallel_calls=tf.data.AUTOTUNE).map(prepare_image, num_parallel_calls=tf.data.AUTOTUNE)

# Size of the sliding window (the hand-built U-Net expects 256x256 inputs).
ksize_rows = 256
ksize_cols = 256
# Step between the top-left corners of consecutive patches; a step smaller than
# the window produces overlapping patches.
strides_rows = 196
strides_cols = 196

# Decode the single image of the dataset (the cell previously used an undefined name `ds`).
image = next(iter(image_dataset))
print(image.shape)

ksizes = [1, ksize_rows, ksize_cols, 1]
strides = [1, strides_rows, strides_cols, 1]
# Sample pixels consecutively inside each window. See
# http://stackoverflow.com/questions/40731433/understanding-tf-extract-image-patches-for-extracting-patches-from-an-image
rates = [1, 1, 1, 1]
# 'SAME' pads the image so the patches cover it completely; 'VALID' drops the
# border that does not fill a whole patch.
padding = 'SAME'

# extract_patches expects a 4-D batch [batch, height, width, channels].
image_patches = tf.image.extract_patches(images=tf.expand_dims(image, 0), sizes=ksizes, strides=strides, rates=rates, padding=padding)


Display the input patches.

In [None]:
# image_patches is [1, n_down, n_across, ksize_rows * ksize_cols * 3]: here `columns`
# counts patches vertically and `rows` counts them horizontally.
columns = image_patches.shape[1]
rows = image_patches.shape[2]

print(columns, rows)
fig = plt.figure(figsize=(columns, rows))
fig.tight_layout()
i = 1
for col in range(columns):
    for row in range(rows):
        patch = tf.reshape(image_patches[0, col, row], [ksize_rows, ksize_cols, 3])
        fig.add_subplot(columns, rows, i)
        plt.axis('off')
        plt.imshow(patch)
        i += 1
plt.show()

Predict the rings on the patches

In [None]:
columns = image_patches.shape[1]
rows = image_patches.shape[2]

print(columns, rows)
# Run every patch through the model with a single predict() call. Flatten the
# [1, columns, rows, ...] grid into [columns * rows, ksize_rows, ksize_cols, 3]
# (row-major, i.e. the same order as nested `for col` / `for row` loops).
patch_batch = tf.reshape(image_patches, [-1, ksize_rows, ksize_cols, 3])
predictions = unet_model.predict(patch_batch, verbose=0)
# One squeezed probability map per patch, in grid order.
results = [np.squeeze(prediction) for prediction in predictions]

In [None]:
plt.figure(figsize=(columns, rows))
# Grid of predicted patches: (n_down, n_across, patch height, patch width).
output = np.array(results).reshape(columns, rows, ksize_rows, ksize_cols)
# Inspect one predicted patch. The indices are clamped so the cell also works on
# images with a smaller patch grid.
img = output[min(11, columns - 1)][min(16, rows - 1)]
img = (img > 0.000005).astype('uint8')
plt.imshow(img)
print(img)

Apply the classifier to a whole folder of images (batch prediction).

Set the input and output folders and the parameters.

In [None]:
# Input and output folders for batch prediction.
INPUT_FOLDER = "/media/baecker/6b38a953-6650-4da5-94d9-57bd718df733/2025/in/2007_tree_rings/input_images/input/"
OUTPUT_FOLDER = "/media/baecker/6b38a953-6650-4da5-94d9-57bd718df733/2025/in/2007_tree_rings/input_images/out/"
# Patch geometry: non-overlapping 256x256 tiles (stride equals patch size), no padding.
PATCH_SIZE = 256
STRIDE_WIDTH = PATCH_SIZE
CHANNELS = 3
PADDING = 'VALID'
# Model outputs above this value become foreground (255) in the written masks.
THRESHOLD = 5e-6

In [None]:
import cv2
# Stitch the predicted patches from the cells above back into one full-size mask
# and write it to `output_file`.
output = np.array(results).reshape(columns, rows, ksize_rows, ksize_cols)
first_image = next(iter(image_dataset))  # decoded once, only for its size
height = first_image.shape[0]
width = first_image.shape[1]
reconstructed = np.zeros((height, width))
print("image shape", reconstructed.shape)
# With padding='SAME', extract_patches pads the image so that the patches cover it,
# putting floor(total_pad / 2) pixels before the image. Patch (col, row) therefore
# starts at (col * strides_rows - pad_top, row * strides_cols - pad_left) in image
# coordinates. For a 'VALID' grid the formula gives 0. This replaces the previous
# np.roll correction, which used swapped axes and wrapped pixels around the border.
pad_top = max((columns - 1) * strides_rows + ksize_rows - height, 0) // 2
pad_left = max((rows - 1) * strides_cols + ksize_cols - width, 0) // 2
for col in range(columns):
    y0 = col * strides_rows - pad_top
    y_start, y_end = max(y0, 0), min(y0 + ksize_rows, height)
    for row in range(rows):
        x0 = row * strides_cols - pad_left
        x_start, x_end = max(x0, 0), min(x0 + ksize_cols, width)
        # Overlapping regions are overwritten by later patches.
        reconstructed[y_start:y_end, x_start:x_end] = output[col, row, y_start - y0:y_end - y0, x_start - x0:x_end - x0]
reconstructed = ((reconstructed > 0.000005).astype('uint8'))*255
if not cv2.imwrite(output_file, reconstructed):
    raise OSError("could not write " + output_file)

Functions to cut an image into patches, predict a mask patch for each input patch, reconstruct the full mask image from the predicted patches, and display a grid of patches.

In [None]:
def createPatches(image_file, patch_size=256, stride_width=196, padding='SAME'):
    """Cut one TIFF image into square patches with `tf.image.extract_patches`.

    Args:
        image_file: path of the TIFF image; it is read with `read_image` and
            converted with `prepare_image`.
        patch_size: side length of each square patch. The default is 256 to match
            the model input and the other helpers (it used to be 255).
        stride_width: step between consecutive patches, in both directions.
        padding: 'SAME' (pad so the patches cover the whole image) or 'VALID'.

    Returns:
        (patches, height, width): `patches` has shape
        [1, n_down, n_across, patch_size * patch_size * channels]; height and width
        are the size of the decoded image in pixels.
    """
    image_path_dataset = tf.data.Dataset.from_tensor_slices([image_file])
    image_dataset = image_path_dataset.map(read_image, num_parallel_calls=tf.data.AUTOTUNE).map(prepare_image, num_parallel_calls=tf.data.AUTOTUNE)
    # Decode the image once and reuse it for the size and the patches.
    image = next(iter(image_dataset))
    height = image.shape[0]
    width = image.shape[1]
    ksizes = [1, patch_size, patch_size, 1]
    strides = [1, stride_width, stride_width, 1]
    rates = [1, 1, 1, 1]
    # extract_patches expects a 4-D batch [batch, height, width, channels].
    patches = tf.image.extract_patches(images=tf.expand_dims(image, 0), sizes=ksizes, strides=strides, rates=rates, padding=padding)
    return patches, height, width

def predictPatches(model, image_patches, patch_size=256, channels=3):
    """Run `model` on every patch of an extract_patches grid.

    Args:
        model: object with a Keras-style `predict(batch, verbose=0)` method that
            maps a [n, patch_size, patch_size, channels] batch to n single-channel
            maps.
        image_patches: tensor or array of shape
            [1, n_down, n_across, patch_size * patch_size * channels], as returned
            by `createPatches`.
        patch_size: side length of each square patch.
        channels: number of channels per input patch.

    Returns:
        numpy array of shape (n_down, n_across, patch_size, patch_size) with the
        model output for each patch.
    """
    columns = image_patches.shape[1]
    rows = image_patches.shape[2]
    if columns * rows == 0:
        return np.zeros((columns, rows, patch_size, patch_size))
    # Flatten the grid into one batch (row-major: the same order as nested
    # `for col` / `for row` loops) so the model is called once, not once per patch.
    batch = np.reshape(np.asarray(image_patches), (columns * rows, patch_size, patch_size, channels))
    predictions = np.asarray(model.predict(batch, verbose=0))
    return predictions.reshape(columns, rows, patch_size, patch_size)

def reconstructFromPatches(patches, original_image_height, original_image_width, patch_size=256, stride_width=196, threshold=0.000005):
    """Stitch a grid of predicted patches into one thresholded mask image.

    Args:
        patches: array of shape (n_down, n_across, patch_size, patch_size), e.g. the
            output of `predictPatches`.
        original_image_height, original_image_width: size of the image the patches
            were cut from.
        patch_size: side length of each square patch.
        stride_width: step between consecutive patches (both directions).
        threshold: values strictly above this become foreground.

    Returns:
        uint8 array of shape (height, width) with 255 for foreground and 0 otherwise.
        Where patches overlap, later patches overwrite earlier ones. Pixels not
        covered by any patch (the right/bottom border of a 'VALID' grid) stay 0.
    """
    height, width = original_image_height, original_image_width
    n_down = patches.shape[0]
    n_across = patches.shape[1]
    # tf.image.extract_patches with padding='SAME' pads the image so that the grid
    # covers it, putting floor(total_pad / 2) pixels before the image. The grid size
    # is enough to recover that offset; for a 'VALID' grid the formula gives 0, so
    # such grids are placed exactly as before.
    pad_top = max((n_down - 1) * stride_width + patch_size - height, 0) // 2
    pad_left = max((n_across - 1) * stride_width + patch_size - width, 0) // 2
    reconstructed = np.zeros((height, width))
    for i in range(n_down):
        y0 = i * stride_width - pad_top
        y_start, y_end = max(y0, 0), min(y0 + patch_size, height)
        for j in range(n_across):
            x0 = j * stride_width - pad_left
            x_start, x_end = max(x0, 0), min(x0 + patch_size, width)
            reconstructed[y_start:y_end, x_start:x_end] = patches[i, j, y_start - y0:y_end - y0, x_start - x0:x_end - x0]
    reconstructed = ((reconstructed > threshold).astype('uint8')) * 255
    return reconstructed


def displayPatches(patches, patch_size=256, channels=3):
    """Show a grid of patches in one matplotlib figure.

    Accepts either the tensor returned by `tf.image.extract_patches`
    ([1, n_down, n_across, patch_size * patch_size * channels]) or an array of
    shape (n_down, n_across, ...), such as the output of `predictPatches`.
    """
    from_tensor = tf.is_tensor(patches)
    grid_shape = patches.shape[1:3] if from_tensor else patches.shape[0:2]
    columns, rows = grid_shape[0], grid_shape[1]
    fig = plt.figure(figsize=(columns, rows))
    fig.tight_layout()
    position = 0
    for col in range(columns):
        for row in range(rows):
            position += 1
            if from_tensor:
                patch = tf.reshape(patches[0, col, row], [patch_size, patch_size, channels])
            else:
                patch = patches[col, row]
            fig.add_subplot(columns, rows, position)
            plt.axis('off')
            plt.imshow(patch)
    plt.show()

In [None]:
import cv2

def predictImage(image_file, model, patch_size=256, stride_width=196, padding="SAME", channels=3, threshold= 0.000005):
    """Predict a full-size 0/255 mask for one image file.

    Cuts the image into patches (`createPatches`), runs `model` on them
    (`predictPatches`) and stitches and thresholds the result
    (`reconstructFromPatches`).
    """
    patches, height, width = createPatches(image_file, patch_size, stride_width, padding)
    probabilities = predictPatches(model, patches, patch_size, channels)
    return reconstructFromPatches(probabilities, height, width, patch_size, stride_width, threshold)

def batchPredict(input_folder, output_folder, patch_size=256, stride_width=196, padding="SAME", channels=3, threshold= 0.000005, model=None):
    """Predict a mask for every `.tif` file in `input_folder` and write it to `output_folder`.

    Each mask is written under the same file name as its input. The output folder
    is created if it does not exist.

    Args:
        model: model used for prediction; defaults to the notebook's global
            `unet_model` (looked up at call time) for backward compatibility.
        Other arguments are passed through to `predictImage`.

    Raises:
        OSError: if a mask could not be written.
    """
    if model is None:
        model = unet_model
    os.makedirs(output_folder, exist_ok=True)
    # List the folder once (sorted) so input and output paths are always consistent.
    names = sorted(name for name in os.listdir(input_folder) if name.endswith(".tif"))
    for counter, name in enumerate(names, start=1):
        input_file = os.path.join(input_folder, name)
        output_file = os.path.join(output_folder, name)
        print("Processing image " + str(counter) + " of " + str(len(names)))
        print("in: ", input_file)
        print("out: ", output_file)
        mask = predictImage(input_file, model, patch_size, stride_width, padding, channels, threshold)
        # cv2.imwrite reports failure only through its return value.
        if not cv2.imwrite(output_file, mask):
            raise OSError("could not write " + output_file)

Test `createPatches` on a single image.

In [None]:
IMAGE_PATH = "/media/baecker/6b38a953-6650-4da5-94d9-57bd718df733/2025/in/2007_tree_rings/input_images/input/T 5 b_8µm_x50.tif"
# Cut the image with the batch-prediction settings and show the resulting patches.
imagePatches, height, width = createPatches(IMAGE_PATH, patch_size=PATCH_SIZE, stride_width=STRIDE_WIDTH, padding=PADDING)
print(height, width)
displayPatches(imagePatches, patch_size=PATCH_SIZE, channels=CHANNELS)

Test `predictPatches`.

In [None]:
# Predict one mask patch per input patch.
output = predictPatches(unet_model, imagePatches, patch_size=PATCH_SIZE, channels=CHANNELS)

In [None]:
# Grid of predicted patches: (n_down, n_across, patch, patch).
print(np.shape(output))

In [None]:
# Show the predicted mask patches in the same grid layout as the input patches.
displayPatches(output, patch_size=PATCH_SIZE, channels=CHANNELS)

Test `reconstructFromPatches`.

In [None]:
# Stitch the predicted patches into one thresholded mask and show it.
reconstructed = reconstructFromPatches(output, height, width, patch_size=PATCH_SIZE, stride_width=STRIDE_WIDTH, threshold=THRESHOLD)
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(reconstructed)

Test `predictImage` on a single image (patching, prediction and reconstruction in one call).

In [None]:
# Same values as before: 256 px patches, stride 256, 'VALID', 3 channels, threshold 0.000005.
mask = predictImage('/media/baecker/6b38a953-6650-4da5-94d9-57bd718df733/2025/in/2007_tree_rings/input_images/input/4 E 1 b_8µm_x50.tif', unet_model,
                    patch_size=PATCH_SIZE, stride_width=STRIDE_WIDTH, padding=PADDING,
                    channels=CHANNELS, threshold=THRESHOLD)
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(mask)

In [None]:
# Compare the container types: predicted patches (numpy) vs. extracted patches (tensor).
for arr in (output, imagePatches):
    print(type(arr), arr.shape, tf.is_tensor(arr))

In [None]:
# Predict masks for every .tif in the 'unused' folder and write them to 'unused/out'.
batchPredict(
    '/media/baecker/6b38a953-6650-4da5-94d9-57bd718df733/2025/in/2007_tree_rings/unused',
    '/media/baecker/6b38a953-6650-4da5-94d9-57bd718df733/2025/in/2007_tree_rings/unused/out',
    patch_size=PATCH_SIZE,
    stride_width=STRIDE_WIDTH,
    padding=PADDING,
    channels=CHANNELS,
    threshold=THRESHOLD,
)