In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import tensorflow_io as tfio
import numpy as np
import os
import tifffile
import math
import tensorflow_io as tfio
import matplotlib.pyplot as plt
import cv2

In [2]:
def double_conv_block(x, n_filters):
    """Apply two stacked 3x3 convolutions to ``x`` and return the result.

    Both convolutions produce ``n_filters`` channels and use "same" padding,
    a ReLU activation and He-normal kernel initialisation.
    """
    conv_kwargs = dict(padding="same", activation="relu", kernel_initializer="he_normal")
    for _ in range(2):
        # A new layer is created on every pass, so the two convolutions
        # have independent weights.
        x = layers.Conv2D(n_filters, 3, **conv_kwargs)(x)
    return x

def downsample_block(x, n_filters):
    """Run one encoder stage on ``x``.

    Returns:
        A ``(features, pooled)`` pair. ``features`` is the output of
        ``double_conv_block`` (the caller keeps it for the skip connection);
        ``pooled`` is that output after 2x2 max pooling and dropout (rate 0.3).
    """
    features = double_conv_block(x, n_filters)
    pooled = layers.MaxPool2D(2)(features)
    pooled = layers.Dropout(0.3)(pooled)
    return features, pooled

def attention_gate(g, s, num_filters):
    """Re-weight the skip features ``s`` with a mask derived from ``g`` and ``s``.

    Each input goes through a 3x3 convolution (``num_filters`` channels,
    "same" padding) followed by batch normalisation. The two projections are
    summed, passed through ReLU, another 3x3 convolution and a sigmoid. The
    resulting mask is multiplied element-wise with ``s`` and returned.

    ``g`` and ``s`` therefore need shapes that can be added after projection,
    and ``s`` must be multipliable with a ``num_filters``-channel mask.
    """
    def project(tensor):
        # Conv -> BatchNorm projection shared by both branches (separate weights).
        projected = layers.Conv2D(num_filters, 3, padding="same")(tensor)
        return layers.BatchNormalization()(projected)

    gate_branch = project(g)
    skip_branch = project(s)

    mask = layers.Activation("relu")(gate_branch + skip_branch)
    mask = layers.Conv2D(num_filters, 3, padding="same")(mask)
    mask = layers.Activation("sigmoid")(mask)

    return mask * s
    
def upsample_block(x, conv_features, n_filters):
    """Run one decoder stage.

    ``x`` is upsampled by a stride-2 transposed convolution, the encoder
    features ``conv_features`` are filtered through ``attention_gate`` using
    the upsampled tensor as gating signal, and both are concatenated. Dropout
    (rate 0.3) and ``double_conv_block`` are applied to the merged tensor.
    """
    upsampled = layers.Conv2DTranspose(n_filters, 3, 2, padding="same")(x)
    gated_skip = attention_gate(upsampled, conv_features, n_filters)
    merged = layers.concatenate([upsampled, gated_skip])
    merged = layers.Dropout(0.3)(merged)
    return double_conv_block(merged, n_filters)

def build_unet_model(input_shape=(256, 256, 3)):
    """Build the attention U-Net as a Keras functional model.

    The encoder runs four ``downsample_block`` stages (64, 128, 256, 512
    filters), the bottleneck is a ``double_conv_block`` with 1024 filters and
    the decoder mirrors the encoder with four ``upsample_block`` stages, each
    fed with the matching encoder features. A final 1x1 convolution with a
    linear activation produces a single output channel.

    Args:
        input_shape: ``(height, width, channels)`` of one input image.
            Defaults to ``(256, 256, 3)``, the previously hard-coded value.
            Height/width may be ``None`` for variable-size inputs. Static
            values must be multiples of 16: the four 2x poolings followed by
            four 2x upsamplings only return to the skip-connection sizes when
            no pooling step has to round down.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``"U-Net"``.

    Raises:
        ValueError: if a static height or width is not a multiple of 16.
    """
    encoder_filters = (64, 128, 256, 512)
    required_multiple = 2 ** len(encoder_filters)
    for dim in input_shape[:2]:
        if dim is not None and dim % required_multiple:
            raise ValueError(
                f"input height and width must be multiples of {required_multiple}, "
                f"got input_shape={tuple(input_shape)}"
            )

    inputs = layers.Input(shape=tuple(input_shape))

    # Encoder (contracting path): keep each stage's features for the decoder.
    skip_features = []
    x = inputs
    for n_filters in encoder_filters:
        features, x = downsample_block(x, n_filters)
        skip_features.append(features)

    # Bottleneck
    x = double_conv_block(x, 2 * encoder_filters[-1])

    # Decoder (expanding path): deepest skip connection is consumed first.
    for n_filters, features in zip(reversed(encoder_filters), reversed(skip_features)):
        x = upsample_block(x, features, n_filters)

    outputs = layers.Conv2D(1, (1, 1), padding="same", activation="linear")(x)
    return tf.keras.Model(inputs, outputs, name="U-Net")

In [3]:
# Recreate the architecture, then restore the trained parameters from the
# checkpoint file (resolved relative to the current working directory).
weights_path = 'best_model.h5'
unet_model = build_unet_model()
unet_model.load_weights(weights_path)

I0000 00:00:1741625872.315879  324515 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1741625872.341280  324515 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1741625872.341408  324515 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1741625872.344965  324515 cuda_executor.cc:1015] successful NUMA node read from SysFS ha

In [4]:
# Attach optimizer, loss and metrics to the restored model.
# NOTE(review): 'accuracy', 'precision' and 'recall' are classification
# metrics, while the model ends in a linear activation and is trained with
# an MSE loss - confirm these metrics are meaningful for this output.
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
tracked_metrics = ['accuracy', 'precision', 'recall', 'mse']
unet_model.compile(optimizer=adam_optimizer, loss='mse', metrics=tracked_metrics)

In [16]:
def exportPatches(img_data, patch_size, overlap):
    """Cut one image into square tiles that overlap by ``overlap`` pixels.

    Args:
        img_data: a single image of shape ``(height, width, channels)``
            (NumPy array or tensor).
        patch_size: side length of each square tile, in pixels.
        overlap: number of pixels shared by two neighbouring tiles along each
            axis. Must satisfy ``0 <= overlap < patch_size``. The step between
            tile origins is ``patch_size - overlap``.

    Returns:
        A tensor of shape ``(n_tiles, patch_size, patch_size, channels)``.
        Tiles are ordered row by row over the patch grid. With ``'SAME'``
        padding, tiles that reach past the image border are zero-filled there.

    Raises:
        ValueError: if ``overlap`` is outside ``[0, patch_size)`` or
            ``img_data`` is not rank 3.
    """
    if not 0 <= overlap < patch_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < patch_size, "
            f"got overlap={overlap}, patch_size={patch_size}"
        )

    img_data = tf.convert_to_tensor(img_data)
    rank = img_data.shape.rank
    if rank is not None and rank != 3:
        raise ValueError(
            f"img_data must have shape (height, width, channels), got rank {rank}"
        )

    # `overlap` is the shared width, not the step. Passing it directly as the
    # stride would place tiles only `overlap` pixels apart and produce far
    # more (and far more redundant) tiles than intended.
    stride = patch_size - overlap

    image_patches = tf.image.extract_patches(
        images=img_data[tf.newaxis, ...],
        sizes=[1, patch_size, patch_size, 1],
        strides=[1, stride, stride, 1],
        rates=[1, 1, 1, 1],  # sample pixels consecutively
        padding='SAME',
    )

    # extract_patches returns (1, grid_rows, grid_cols, patch_size*patch_size*channels).
    # One row-major reshape yields the same tile order as iterating the grid
    # row by row and stacking, without a Python loop.
    channels = img_data.shape[-1]
    if channels is None:
        channels = tf.shape(img_data)[-1]
    return tf.reshape(image_patches, [-1, patch_size, patch_size, channels])

In [None]:
def read_image(img_path):
    """Read the TIFF file at ``img_path`` and return its first three channels.

    The file is decoded with ``tfio.experimental.image.decode_tiff``; any
    channel beyond the third is dropped.
    """
    raw_bytes = tf.io.read_file(img_path)
    decoded = tfio.experimental.image.decode_tiff(raw_bytes)
    return decoded[:, :, :3]

def prepare_image(img):
    """Convert ``img`` to float32 with ``tf.image.convert_image_dtype``.

    For integer inputs this conversion also rescales the values to [0, 1].
    """
    return tf.image.convert_image_dtype(img, "float32")


# Batch inference: tile every TIFF in `input_folder`, run the model on the
# tiles and write the re-assembled prediction under the same file name in
# `output_folder`.
input_folder = '/home/khietdang/Documents/khiet/treeRing/input'
output_folder = '/home/khietdang/Documents/khiet/treeRing/predictions'
patch_size = 256
overlap = patch_size - 196  # = 60
batch_size = 8

# tifffile.imwrite does not create missing directories.
os.makedirs(output_folder, exist_ok=True)

# Sorted for a reproducible processing order. Sub-directories and files
# without a TIFF extension are skipped because tifffile.imread cannot read them.
tiff_names = sorted(
    name for name in os.listdir(input_folder)
    if name.lower().endswith(('.tif', '.tiff'))
    and os.path.isfile(os.path.join(input_folder, name))
)

for im_name in tiff_names:
    full_path = os.path.join(input_folder, im_name)
    im_data = tifffile.imread(full_path)

    # NOTE(review): the pixels go to the model exactly as read from disk;
    # prepare_image() (float32 conversion) is not applied on this path.
    # Confirm this matches the preprocessing used during training.
    shape = im_data.shape
    tiles = exportPatches(im_data, patch_size=patch_size, overlap=overlap)

    # Drop only the trailing channel axis. A bare np.squeeze would also drop
    # the batch axis when an image yields a single tile.
    predictions = np.squeeze(unet_model.predict(tiles, batch_size=batch_size), axis=-1)

    # NOTE(review): ImageTiler2D is not defined in the cells visible here;
    # it must already be defined or imported for this loop to run.
    tiles_manager = ImageTiler2D(patch_size, overlap, shape[:2])
    probabilities = tiles_manager.tiles_to_image(predictions)

    tifffile.imwrite(os.path.join(output_folder, im_name), probabilities)


I0000 00:00:1741626100.506932  324645 service.cc:146] XLA service 0x77491c002870 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1741626100.506959  324645 service.cc:154]   StreamExecutor device (0): NVIDIA RTX 4000 Ada Generation Laptop GPU, Compute Capability 8.9
2025-03-10 18:01:40.517424: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-03-10 18:01:40.598584: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8900
2025-03-10 18:01:41.804969: W external/local_tsl/tsl/framework/bfc_allocator.cc:291] Allocator (GPU_0_bfc) ran out of memory trying to allocate 17.17GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2025-03-10 18:01:42.748866: W external/local_tsl/tsl/framework/bfc_a

KeyboardInterrupt: 

In [None]:




# Visual check of the tiling: load one image, cut it into patches and show
# the patches laid out in the grid they were taken from.
image_file = '/home/khietdang/Documents/khiet/treeRing/input/4 E 1 m_8µm_x50.tif'
output_file = '/home/khietdang/Documents/khiet/treeRing/masks/4 E 1 m_8µm_x50.tif'
image_path_dataset = tf.data.Dataset.from_tensor_slices([image_file])
image_dataset = (
    image_path_dataset
    .map(read_image, num_parallel_calls=tf.data.AUTOTUNE)
    .map(prepare_image, num_parallel_calls=tf.data.AUTOTUNE)
)

ksize_rows = 256
ksize_cols = 256
strides_rows = 196
strides_cols = 196

# Size of the sliding window
ksizes = [1, ksize_rows, ksize_cols, 1]

# Distance between the origins of two consecutive patches
strides = [1, strides_rows, strides_cols, 1]

# Sample pixels consecutively. See
# http://stackoverflow.com/questions/40731433/understanding-tf-extract-image-patches-for-extracting-patches-from-an-image
rates = [1, 1, 1, 1]

# Padding algorithm to use
padding = 'SAME'  # or 'VALID'

# The dataset holds a single image; add the batch axis extract_patches expects.
image = next(iter(image_dataset.take(1)))
image_patches = tf.image.extract_patches(
    images=image[tf.newaxis, ...],
    sizes=ksizes,
    strides=strides,
    rates=rates,
    padding=padding,
)

# Axis 1 of the result indexes patch rows (vertical position) and axis 2
# indexes patch columns (horizontal position).
n_patch_rows = image_patches.shape[1]
n_patch_cols = image_patches.shape[2]
print(n_patch_rows, n_patch_cols)

# figsize is (width, height): width follows the number of patch columns.
fig, axes = plt.subplots(
    n_patch_rows,
    n_patch_cols,
    figsize=(n_patch_cols, n_patch_rows),
    squeeze=False,  # keep `axes` two-dimensional even for a 1xN or Nx1 grid
)
for row in range(n_patch_rows):
    for col in range(n_patch_cols):
        patch = tf.reshape(image_patches[0, row, col], [ksize_rows, ksize_cols, 3])
        ax = axes[row, col]
        ax.imshow(patch)
        ax.axis('off')

# Must run after the axes exist, otherwise there is nothing to lay out.
fig.tight_layout()
plt.show()
