# Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image

import keras
from keras.models import Model, load_model
from keras.layers import Conv2D, MaxPooling2D, Input, Conv2DTranspose, Concatenate, BatchNormalization, UpSampling2D
from keras.layers import  Dropout, Activation
from keras.optimizers import Adam, SGD
from keras.layers import LeakyReLU
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from keras import backend as K
from keras.utils import plot_model
import tensorflow as tf

import glob
import random
import cv2
from random import shuffle

In [None]:
# Tested with Keras 3.8.0
# and TensorFlow 2.18.0

In [None]:
# Mount Google Drive so that the files under PROJECT_PATH (/content/drive/...) are reachable.
from google.colab import drive
drive.mount('/content/drive/')

## Params

In [None]:
# Project root on Google Drive (the model is saved under it). It ends with '/',
# so every path below is appended to it WITHOUT a leading slash.
PROJECT_PATH = '/content/drive/MyDrive/Colab Notebooks/landmark_segmentation/'

# Training tiles and their masks (the mask of `name.ext` is `name_mask.png`).
IMG_PATH = PROJECT_PATH + 'datasets/unet/512_50/images/'
TARGET_PATH = PROJECT_PATH + 'datasets/unet/512_50/targets/'

# Weight files. The former '/models/...' suffix produced '.../landmark_segmentation//models/...'.
# MODEL_WEIGTH_PATH keeps its historical spelling because other cells refer to it.
MODEL_WEIGTH_PATH = PROJECT_PATH + 'models/unet.weights.h5'  # final (last-epoch) weights
MODEL_BEST_WEIGHTS_PATH = PROJECT_PATH + 'models/unet_best.weights.h5'  # best checkpoint

# Held-out test set: images, ground-truth masks and prediction output folder.
IMG_PATH_TEST = PROJECT_PATH + 'datasets/test_set/pign_z14/image/'
MSK_PATH_TEST = PROJECT_PATH + 'datasets/test_set/pign_z14/mask/'
PRED_PATH = PROJECT_PATH + 'datasets/test_set/pign_z14/pred_unet/'

# (width, height) used to resize tiles and masks; also the model input size.
TILE_SIZE = (512, 512)

BATCH_SIZE = 32

# Fraction of the files used for training; the remainder is the validation set.
TRAIN_TEST_SPLIT = 0.90

## Train / valid split

In [None]:
# Reproducible train/validation split. Sort first, because os.listdir order is
# arbitrary. Then shuffle with a fixed seed, so that re-running the notebook
# (e.g. in a new session before reloading saved weights) yields the same partition.
SPLIT_SEED = 42
all_files = sorted(os.listdir(IMG_PATH))
random.Random(SPLIT_SEED).shuffle(all_files)

split = int(TRAIN_TEST_SPLIT * len(all_files))

# The first TRAIN_TEST_SPLIT fraction is used for training, the rest for validation.
train_files = all_files[:split]
valid_files = all_files[split:]

In [None]:
# Report how many files ended up on each side of the split.
print(f"Train files: {len(train_files)}\nValid files: {len(valid_files)}")

## Generators

In [None]:
def image_generator(files, batch_size=32, sz=(256, 256), img_path=None, target_path=None):
    '''
    Yield endless random batches of (image, mask) pairs for training or validation.

    Each batch is drawn WITH replacement from ``files`` (``np.random.choice``),
    so a batch may contain the same file several times; the generator never stops.

    Args:
        files (list): Image filenames relative to ``img_path``. The mask of
            ``name.ext`` is read from ``target_path + 'name_mask.png'``.
        batch_size (int): Number of samples per batch.
        sz (tuple): (width, height) passed to PIL ``resize`` for images and masks.
        img_path (str | None): Image directory prefix; defaults to the global IMG_PATH.
        target_path (str | None): Mask directory prefix; defaults to the global TARGET_PATH.

    Yields:
        tuple: A batch of image and mask data.
            - batch_x: float32 array (batch, H, W, 3), RGB scaled to [0, 1].
            - batch_y: float32 array (batch, H, W, 1) with values in {0, 1}.
    '''
    # Resolved at call time (not definition time) so the notebook globals can still change.
    img_dir = IMG_PATH if img_path is None else img_path
    mask_dir = TARGET_PATH if target_path is None else target_path

    while True:
        # Extract a random batch (with replacement).
        batch = np.random.choice(files, size=batch_size)

        batch_x = []
        batch_y = []
        for f in batch:
            # `with` closes each file immediately instead of leaking one handle per sample.
            # convert("RGB") drops any alpha channel (like the former [:, :, 0:3])
            # and also accepts grayscale / palette images.
            with Image.open(img_dir + f) as im:
                raw = np.asarray(im.convert("RGB").resize(sz))
            batch_x.append(raw)

            # splitext copes with extensions that are not 3 characters long (e.g. '.jpeg').
            stem = os.path.splitext(f)[0]
            with Image.open(mask_dir + f'{stem}_mask.png') as im:
                mask = np.asarray(im.convert("L").resize(sz))  # L = grayscale, drops transparency
            # Binarise with a threshold rather than `mask == 255`: resize() interpolates,
            # so intermediate values (e.g. 128) used to reach the loss unchanged, which is an
            # invalid binary cross-entropy target. Assumes masks are stored as 0/255.
            batch_y.append(mask >= 128)

        batch_x = np.array(batch_x, dtype=np.float32) / 255.
        batch_y = np.expand_dims(np.array(batch_y, dtype=np.float32), 3)  # add channel dimension at the end

        yield (batch_x, batch_y)


In [None]:
# Endless random-batch generators: one over the training files, one over the
# validation files (passed to fit() as validation_data).
train_generator = image_generator(
    train_files,
    batch_size=BATCH_SIZE,
    sz=TILE_SIZE,
)
test_generator = image_generator(
    valid_files,
    batch_size=BATCH_SIZE,
    sz=TILE_SIZE,
)

In [None]:
# Draw one training batch to sanity-check the generator output.
x, y= next(train_generator)

In [None]:
# Show the first sample of the batch side by side: image | mask | image * mask.
img = x[0]
msk = np.repeat(y[0], 3, axis=-1)  # (H, W, 1) -> (H, W, 3) to match the image channels
plt.axis('off')
plt.imshow(np.hstack([img, msk, img * msk]))

## IoU

In [None]:
def mean_iou(y_true, y_pred):
    """Foreground intersection-over-union, pooled over the whole batch.

    Despite the name this is not a per-image mean: intersection and union are
    counted over every pixel of the batch before dividing.

    Args:
        y_true: Ground-truth masks indexed as [batch, H, W, channel]; channel 0 is used
            and binarised at 0.5.
        y_pred: Predicted probabilities with the same layout; binarised at 0.5.

    Returns:
        Scalar float32 tensor; 1.0 when neither truth nor prediction has any foreground.
    """
    # Cast first: comparing an integer tensor with 0.5 would fail on dtype conversion.
    # Binarising y_true with the same threshold as y_pred makes intersection and union
    # count the same notion of "foreground". Previously the intersection required
    # y_true == 1 while the union counted any non-zero y_true.
    yt0 = tf.cast(y_true[:, :, :, 0], 'float32') > 0.5
    yp0 = tf.cast(y_pred[:, :, :, 0], 'float32') > 0.5
    inter = tf.math.count_nonzero(tf.logical_and(yt0, yp0))
    union = tf.math.count_nonzero(tf.logical_or(yt0, yp0))
    # Empty union (nothing to find, nothing predicted) counts as a perfect score.
    iou = tf.where(tf.equal(union, 0), 1., tf.cast(inter / union, 'float32'))
    return iou

# Model

In [None]:
def unet(sz=(256, 256, 3)):
    """Build and compile the U-Net used for binary segmentation.

    Architecture (layer graph unchanged, so previously saved weights still load):
      * encoder: 6 levels of two 3x3 ReLU convolutions + 2x2 max-pooling,
        with 8, 16, 32, 64, 128, 256 filters;
      * bottleneck: two 3x3 convolutions with 512 filters;
      * decoder: 6 transposed-convolution upsamplings (64, 32, 16, 8, 4, 2 filters),
        each concatenated with the matching encoder output, separated by pairs of
        3x3 convolutions with 256, 128, 64, 32, 16 filters;
      * head: two 3x3 convolutions (16 filters) and a 1x1 sigmoid convolution.

    Args:
        sz (tuple): Input shape (height, width, channels). Height and width must be
            multiples of 64 (six 2x poolings) or None.

    Returns:
        keras Model compiled with RMSprop, binary cross-entropy and ``mean_iou``.

    Raises:
        ValueError: if a known spatial dimension is not a multiple of 64. Without
            this check, Keras fails later with an obscure Concatenate shape error.
    """
    n_levels = 6             # number of 2x2 max-poolings in the encoder
    factor = 2 ** n_levels   # height/width must be divisible by this
    for dim in sz[:2]:
        if dim is not None and dim % factor != 0:
            raise ValueError(
                f"unet: input height and width must be multiples of {factor}, got {tuple(sz[:2])}"
            )

    x = Input(sz)
    inputs = x

    # Down-sampling: keep each level's output for the skip connections.
    f = 8
    layers = []
    for _ in range(n_levels):
        x = Conv2D(f, 3, activation='relu', padding='same')(x)
        x = Conv2D(f, 3, activation='relu', padding='same')(x)
        layers.append(x)
        x = MaxPooling2D()(x)
        f = f * 2

    # Bottleneck, then the first upsampling joined with the deepest skip connection.
    ff2 = 64  # filters of the transposed convolutions, halved at each decoder level
    j = len(layers) - 1
    x = Conv2D(f, 3, activation='relu', padding='same')(x)
    x = Conv2D(f, 3, activation='relu', padding='same')(x)
    x = Conv2DTranspose(ff2, 2, strides=(2, 2), padding='same')(x)
    x = Concatenate(axis=3)([x, layers[j]])
    j = j - 1

    # Up-sampling: convolutions, 2x transposed convolution, concatenate with the skip.
    for _ in range(n_levels - 1):
        ff2 = ff2 // 2
        f = f // 2
        x = Conv2D(f, 3, activation='relu', padding='same')(x)
        x = Conv2D(f, 3, activation='relu', padding='same')(x)
        x = Conv2DTranspose(ff2, 2, strides=(2, 2), padding='same')(x)
        x = Concatenate(axis=3)([x, layers[j]])
        j = j - 1

    # Classification head: one per-pixel foreground probability.
    x = Conv2D(f, 3, activation='relu', padding='same')(x)
    x = Conv2D(f, 3, activation='relu', padding='same')(x)
    outputs = Conv2D(1, 1, activation='sigmoid')(x)

    # Model creation.
    model = Model(inputs=[inputs], outputs=[outputs])
    model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=[mean_iou])

    return model

In [None]:
# Build and compile the U-Net for RGB inputs of size TILE_SIZE.
model = unet(sz=(TILE_SIZE[0], TILE_SIZE[1], 3))

In [None]:
model.summary()  # print each layer with its output shape and parameter count

# Callbacks

Callbacks that display sample predictions after each epoch and save the weights whenever the validation loss improves (after a warm-up number of epochs).

In [None]:
#def build_callbacks():
#        checkpointer = ModelCheckpoint(filepath=MODEL_WEIGTH_PATH, verbose=0, save_best_only=True, save_weights_only=True)
#        callbacks = [checkpointer, PlotLearning(), CheckPoint]
#       return callbacks

# Keras callback that logs training metrics and plots sample predictions after each epoch
class PlotLearning(keras.callbacks.Callback):
    """Record per-epoch metrics and show a few sample predictions after each epoch.

    At the end of every epoch the callback stores loss / val_loss / mean_iou /
    val_mean_iou and prints them. It then picks ``n_samples`` random files and
    plots [input | predicted probability map | input * map] for each.

    Args:
        files: Filenames to sample from; defaults to the global ``valid_files``.
        img_dir: Directory prefix for ``files``; defaults to the global ``IMG_PATH``.
        tile_size: (width, height) given to PIL ``resize``; defaults to the global ``TILE_SIZE``.
        n_samples: Number of predictions displayed per epoch.

    Defaults are looked up at the end of each epoch, so ``PlotLearning()`` keeps
    following the notebook globals.
    """

    def __init__(self, files=None, img_dir=None, tile_size=None, n_samples=3):
        super().__init__()
        self.files = files
        self.img_dir = img_dir
        self.tile_size = tile_size
        self.n_samples = n_samples

    def on_train_begin(self, logs=None):
        # Reset the history at the start of every fit() call.
        self.i = 0
        self.x = []
        self.losses = []
        self.val_losses = []
        self.acc = []       # training mean_iou per epoch
        self.val_acc = []   # validation mean_iou per epoch
        self.logs = []

    def on_epoch_end(self, epoch, logs=None):
        # `None` default instead of a shared mutable `{}`.
        logs = logs if logs is not None else {}
        self.logs.append(logs)
        self.x.append(self.i)
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))
        self.acc.append(logs.get('mean_iou'))
        self.val_acc.append(logs.get('val_mean_iou'))
        self.i += 1

        # Print info.
        print('i=',self.i,'loss=',logs.get('loss'),'val_loss=',logs.get('val_loss'),'mean_iou=',logs.get('mean_iou'),'val_mean_iou=',logs.get('val_mean_iou'))

        files = valid_files if self.files is None else self.files
        img_dir = IMG_PATH if self.img_dir is None else self.img_dir
        tile_size = TILE_SIZE if self.tile_size is None else self.tile_size

        # Display sample predictions.
        for _ in range(self.n_samples):
            name = np.random.choice(files)
            # Close the file right away; convert("RGB") drops alpha and accepts grayscale input.
            with Image.open(img_dir + name) as im:
                raw = np.asarray(im.convert("RGB").resize(tile_size)) / 255.

            # Predict with the model being trained (set by Keras), not a notebook global.
            pred = self.model.predict(np.expand_dims(raw, 0), verbose=0)

            # Grey probability map -> 3 channels so it can be concatenated with the image.
            msk = np.stack((pred.squeeze(),) * 3, axis=-1)

            # Show the image, the mask and the segmented image.
            plt.axis('off')
            plt.imshow(np.concatenate([raw, msk, raw * msk], axis=1))
            plt.show()



In [None]:
class CheckpointCallback(keras.callbacks.Callback):  # Callback for delayed checkpoint saving
    """Save the model weights whenever val_loss improves, starting at a given epoch.

    Args:
        filepath: Destination of ``model.save_weights``; its folder is created if missing.
        start_saving_epoch: First 1-based epoch at which saving is allowed. The best
            val_loss is only tracked from that epoch on.
    """

    def __init__(self, filepath, start_saving_epoch=20, **kwargs):
        super().__init__(**kwargs)
        self.filepath = filepath
        self.start_saving_epoch = start_saving_epoch
        self.best_val_loss = float('inf')  # best validation loss seen since start_saving_epoch
        self._warned_missing = False

    def on_epoch_end(self, epoch, logs=None):
        current_val_loss = (logs or {}).get('val_loss')
        # val_loss is absent when fit() runs without validation data. Comparing None with
        # a float would raise TypeError, so warn once and skip.
        if current_val_loss is None:
            if not self._warned_missing:
                print("CheckpointCallback: 'val_loss' missing from logs, no checkpoint saved.")
                self._warned_missing = True
            return
        if epoch + 1 >= self.start_saving_epoch and current_val_loss < self.best_val_loss:
            self.best_val_loss = current_val_loss
            directory = os.path.dirname(self.filepath)
            if directory:
                os.makedirs(directory, exist_ok=True)
            self.model.save_weights(self.filepath)
            print(f"Saving weights at epoch {epoch + 1} with val_loss: {current_val_loss}")

# Training

In [None]:
import time

In [None]:
# Keras needs at least one step per epoch; a split smaller than BATCH_SIZE would give 0.
train_steps = max(1, len(train_files) // BATCH_SIZE)
test_steps = max(1, len(valid_files) // BATCH_SIZE)

# perf_counter is monotonic, so the measured duration is unaffected by clock changes.
start_time = time.perf_counter()

history = model.fit(train_generator,
                    epochs=150,
                    steps_per_epoch=train_steps,
                    validation_data=test_generator,
                    validation_steps=test_steps,
                    callbacks=[PlotLearning(), CheckpointCallback(MODEL_BEST_WEIGHTS_PATH, start_saving_epoch=10)],
                    verbose=0)  # progress is printed by PlotLearning instead

end_time = time.perf_counter()
training_time = end_time - start_time
print(f"Temps d'entraînement : {training_time / 60:.2f} minutes")

# Save model


In [None]:
# Save the final (last-epoch) weights; the best checkpoint is written separately by
# CheckpointCallback. Make sure the target folder exists before writing.
os.makedirs(os.path.dirname(MODEL_WEIGTH_PATH), exist_ok=True)
model.save_weights(MODEL_WEIGTH_PATH)

## Curves

In [None]:
# print history keys (names of the metrics recorded by model.fit)
print(history.history.keys())

In [None]:
# Plot training and validation loss per epoch from the fit() history.
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('U-net loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
# Start the y axis at 0 but let matplotlib pick the top. The former fixed 0.6 cap hid
# the first epochs, whose binary cross-entropy can exceed 0.6 (about 0.69 for a 0.5 prediction).
plt.ylim(bottom=0)
plt.show()

In [None]:
# Training vs. validation IoU per epoch, labelled directly on each curve.
plt.plot(history.history['mean_iou'], label='Train')
plt.plot(history.history['val_mean_iou'], label='Validation')
plt.xlabel('Epoch')
plt.ylabel('mean_iou')
plt.title('U-net mean_iou')
plt.legend(loc='upper left')
plt.show()

# Testing on benchmark

### Tiling helpers

In [None]:
# Cut tiles from evaluation images
def extract_tiles(image: "Image.Image", tile_size=512, n_rows=3, n_cols=5):
    """Cut an image into a regular grid of (possibly overlapping) square tiles.

    Tile offsets are spread evenly so that the first tile touches the top/left edge
    and the last one ends exactly on the bottom/right edge.

    Args:
        image: PIL image to cut (anything with ``.size`` = (width, height) and ``.crop``).
        tile_size: Side length of each square tile, in pixels.
        n_rows: Number of tile rows (>= 1).
        n_cols: Number of tile columns (>= 1).

    Returns:
        (tiles, positions): cropped tiles in row-major order and their top-left
        ``(x, y)`` offsets as Python ints.

    Raises:
        ValueError: if ``n_rows`` or ``n_cols`` is smaller than 1 (a single row or
            column used to raise ZeroDivisionError).
    """
    if n_rows < 1 or n_cols < 1:
        raise ValueError(f"n_rows and n_cols must be >= 1, got {n_rows} and {n_cols}")

    width, height = image.size
    # Round evenly spaced offsets instead of flooring an integer stride. The floored
    # stride left an uncovered strip on the right/bottom whenever the free space did
    # not divide evenly. max(0, ...) avoids negative offsets on images smaller than a
    # tile (PIL's crop then zero-pads the tile).
    xs = [int(v) for v in np.linspace(0, max(0, width - tile_size), n_cols).round()]
    ys = [int(v) for v in np.linspace(0, max(0, height - tile_size), n_rows).round()]

    tiles = []
    positions = []
    for y in ys:
        for x in xs:
            tiles.append(image.crop((x, y, x + tile_size, y + tile_size)))
            positions.append((x, y))

    return tiles, positions

In [None]:
# Reconstruct image from predictions
def reconstruct_mask(tile_preds, positions, final_size=(890, 1920), tile_size=512):
    """Stitch per-tile predictions back into one full-size probability map.

    Overlapping tiles are averaged pixel by pixel; pixels covered by no tile stay 0.

    Args:
        tile_preds: Iterable of 2-D prediction arrays (h, w), one per tile.
        positions: Matching non-negative top-left (x, y) offsets, as returned by
            ``extract_tiles``.
        final_size: (height, width) of the output map. The default only fits
            1920x890 images; pass the real image size for anything else.
        tile_size: Unused; kept for backward compatibility. The tile extent comes
            from each prediction's shape.

    Returns:
        float32 array of shape ``final_size`` with the averaged predictions.
    """
    pred_sum = np.zeros(final_size, dtype=np.float32)
    pred_count = np.zeros(final_size, dtype=np.float32)
    height, width = final_size

    for pred, (x, y) in zip(tile_preds, positions):
        pred = np.asarray(pred)
        # Clip tiles that overhang the canvas (e.g. zero-padded tiles cut from an
        # image smaller than the tile) instead of failing on a shape mismatch.
        h = min(pred.shape[0], height - y)
        w = min(pred.shape[1], width - x)
        if h <= 0 or w <= 0:
            continue
        pred_sum[y:y + h, x:x + w] += pred[:h, :w]
        pred_count[y:y + h, x:x + w] += 1.0

    # Per-pixel mean; the floor on the count avoids dividing by zero where no tile landed.
    averaged_pred = pred_sum / np.maximum(pred_count, 1e-6)

    return averaged_pred  # same (height, width) as final_size

### Make predictions

In [None]:
# Load model: rebuild the same architecture (weights only fit an identical graph)
# and restore the best checkpoint written by CheckpointCallback during training.
new_model = unet(sz=(TILE_SIZE[0], TILE_SIZE[1], 3))
new_model.load_weights(MODEL_BEST_WEIGHTS_PATH)

In [None]:
# For each image of the evaluation set:
#   1. cut it into overlapping tiles;
#   2. predict every tile;
#   3. stitch the probabilities back together at the image's real size;
#   4. save a binary 0/255 mask to PRED_FOLDER + 'pred/<name>_unet_mask.png'.
EVAL_PATH = PROJECT_PATH + 'eval/base_images/'
PRED_FOLDER = PROJECT_PATH + 'eval/pred_unet/'
os.makedirs(PRED_FOLDER + "pred/", exist_ok=True)

for f in sorted(os.listdir(EVAL_PATH)):
    if not os.path.isfile(EVAL_PATH + f):
        continue  # skip sub-folders such as .ipynb_checkpoints
    stem = os.path.splitext(f)[0]  # works for extensions of any length

    # Load the image (convert("RGB") drops alpha, accepts grayscale) and close the file.
    with Image.open(EVAL_PATH + f) as im:
        raw = im.convert("RGB")
    tiles, positions = extract_tiles(raw, tile_size=TILE_SIZE[0])

    # Predict all tiles of the image in one call instead of one predict() per tile.
    batch = np.stack([np.asarray(tile, dtype=np.float32) / 255. for tile in tiles])
    preds = new_model.predict(batch, verbose=0)  # (n_tiles, H, W, 1)
    tile_preds = [p[:, :, 0] for p in preds]     # keep only (H, W) per tile

    # Stitch at the real image size. PIL gives (width, height); numpy wants (height, width).
    proba = reconstruct_mask(tile_preds, positions, final_size=(raw.height, raw.width), tile_size=TILE_SIZE[0])
    proba = (proba * 255).astype(np.uint8)  # 2-D probabilities scaled to 0..255

    # Binarise at 128 and save as an 8-bit grayscale PNG.
    msk = np.where(proba >= 128, 255, 0).astype(np.uint8)
    msk = Image.fromarray(msk, mode="L")
    msk.save(PRED_FOLDER + "pred/" + stem + '_unet_mask.png')

In [None]:
# Combine: overlay the last predicted mask on its source image, display
# [image | mask | masked image] and save the masked image.
# Uses `raw` (PIL image), `msk` (PIL 'L' mask) and `f` (filename) left over from the
# last iteration of the prediction loop; the mask has the same size as the image.
img_arr = np.asarray(raw.convert("RGB"), dtype=np.float32) / 255.
mask_arr = np.asarray(msk, dtype=np.float32) / 255.   # 0/255 mask -> 0/1
mask_rgb = np.stack((mask_arr,) * 3, axis=-1)         # match the image's 3 channels

combined = img_arr * mask_rgb
concat = np.concatenate([img_arr, mask_rgb, combined], axis=1)
plt.axis('off')
plt.imshow(concat)
plt.show()

# Save the combination.
# NOTE(review): PRED_COMBINED_PATH was never defined in this notebook, so this
# location is an assumption; adjust as needed.
PRED_COMBINED_PATH = PRED_FOLDER + "combined/"
os.makedirs(PRED_COMBINED_PATH, exist_ok=True)
im_combined = Image.fromarray((combined * 255).astype(np.uint8))
im_combined.save(PRED_COMBINED_PATH + f)


# References


1.   http://deeplearning.net/tutorial/unet.html
2.   https://github.com/ldenoue/keras-unet

