<a href="https://colab.research.google.com/github/Lucs1590/USeS-BPCA/blob/main/notebooks/u_net_bpca.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# U-net-like with Oxford-IIIT Pet Dataset

## Imports

In [None]:
import os
import time
import math
from datetime import datetime

import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import plot_model
from keras.optimizers import Adam, SGD
from keras.metrics import MeanIoU, Accuracy, MeanSquaredError, IoU
from keras.models import load_model
from keras.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    CSVLogger,
    TensorBoard,
    ReduceLROnPlateau
)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds


In [None]:
# Set TensorFlow's native log-level variable.
# NOTE(review): tensorflow is already imported in the cell above; confirm the
# variable still takes effect when set at this point.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# A working directory containing 'content' is treated as a Google Colab runtime.
COLAB = 'content' in os.getcwd()
_parent_dir = os.path.join(os.getcwd(), os.pardir)
if COLAB:
    from google.colab import drive
    drive.mount('/content/gdrive')
    RESOURCES_DIR = f'{_parent_dir}/resources/'
else:
    RESOURCES_DIR = f'{(os.path.abspath(_parent_dir))}/resources/'

# Seed NumPy and TensorFlow with the same fixed value.
np.random.seed(77)
tf.random.set_seed(77)


## Constant Variables

In [None]:
# Drop any '..' segment from the path and collapse the doubled separator that
# the removal leaves behind.
RESOURCES_DIR = RESOURCES_DIR.replace('..', '').replace('//', '/')
# makedirs(exist_ok=True) creates missing parents and is a no-op when the
# directory already exists, so no explicit isdir check is needed.
os.makedirs(RESOURCES_DIR, exist_ok=True)
# One log directory per run, named after the current timestamp.
LOGS_DIR = f'{RESOURCES_DIR}logs/{datetime.now().strftime("%Y%m%d-%H%M%S")}'
os.makedirs(LOGS_DIR, exist_ok=True)

# File stem used for the saved model, CSV history and JSON architecture.
MODEL_NAME = 'bpca_unetlike'

BATCH_SIZE = 64
BUFFER_SIZE = 1000  # shuffle buffer size
HEIGHT, WIDTH = 256, 256  # spatial size every image and mask is resized to
NUM_CLASSES = 3  # background, foreground, boundary
NUM_EPOCHS = 500
VAL_SUBSPLITS = 5  # divisor applied when computing the validation steps


In [None]:
# Echo the resolved resources directory as the cell output.
RESOURCES_DIR


## Dataset
Download the dataset and apply the preprocessing transformations to it.


In [None]:
# Load the Oxford-IIIT Pet dataset (any 3.x version) together with its metadata.
# NOTE(review): both data_dir values are hard-coded, machine-specific paths;
# they must exist (or be adjusted) on the machine running this notebook.
dataset, info = tfds.load(
    'oxford_iiit_pet:3.*.*',
    with_info=True,
    shuffle_files=True,
    data_dir='/content/gdrive/MyDrive/Projetos/' if COLAB else '/home/brito/tensorflow_datasets/',
)

print(info)


In [None]:
# Number of training examples and of full batches per epoch.
TRAIN_LENGTH = info.splits["train"].num_examples
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

# Size of the test split minus 669 examples; the batching cell holds back the
# same number (`skip(3000).take(669)`) as the final test set.
TEST_LENTH = info.splits["test"].num_examples - 669
# Validation batches evaluated per epoch: full batches divided by VAL_SUBSPLITS.
VALIDATION_STEPS = TEST_LENTH // BATCH_SIZE // VAL_SUBSPLITS
LEARNING_RATE = 1e-4


In [None]:
def resize(input_image, input_mask):
    """Resize an image and its mask to ``(HEIGHT, WIDTH)``.

    Both tensors are resized with nearest-neighbour interpolation.

    Returns:
        The resized ``(image, mask)`` pair.
    """
    target_size = (HEIGHT, WIDTH)
    resized_image = tf.image.resize(input_image, target_size, method="nearest")
    resized_mask = tf.image.resize(input_mask, target_size, method="nearest")
    return resized_image, resized_mask


In [None]:
def augment(input_image, input_mask):
    """Randomly mirror an image/mask pair horizontally.

    One uniform sample is drawn; when it exceeds 0.5 the image and the mask
    are both flipped left-to-right, otherwise both are returned unchanged.
    """
    should_flip = tf.random.uniform(()) > 0.5
    if should_flip:
        # Flip the pair together so the mask stays aligned with the image.
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)

    return input_image, input_mask


In [None]:
def normalize(input_image, input_mask):
    """Scale the image by 1/255 as float32 and shift the mask labels down by one.

    Returns:
        The ``(scaled_image, shifted_mask)`` pair.
    """
    scaled_image = tf.cast(input_image, tf.float32) / 255.0
    input_mask -= 1  # shift every label value down by one
    return scaled_image, input_mask


In [None]:
def load_image_train(datapoint):
    """Prepare one training example: resize, randomly flip, then normalize.

    Args:
        datapoint: Mapping with the keys ``"image"`` and ``"segmentation_mask"``.

    Returns:
        The processed ``(image, mask)`` pair.
    """
    image, mask = resize(datapoint["image"], datapoint["segmentation_mask"])
    image, mask = augment(image, mask)
    image, mask = normalize(image, mask)
    return image, mask


In [None]:
def load_image_test(datapoint):
    """Prepare one evaluation example: resize and normalize, no augmentation.

    Args:
        datapoint: Mapping with the keys ``"image"`` and ``"segmentation_mask"``.

    Returns:
        The processed ``(image, mask)`` pair.
    """
    image, mask = resize(datapoint["image"], datapoint["segmentation_mask"])
    image, mask = normalize(image, mask)
    return image, mask


In [None]:
# Apply the per-example preprocessing to each split: the train split uses the
# augmenting loader, the test split uses the deterministic one.
train_dataset = dataset["train"].map(
    load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = dataset["test"].map(
    load_image_test, num_parallel_calls=tf.data.AUTOTUNE)

print(train_dataset)


In [None]:
# Training pipeline. Only the deterministic preprocessing (resize + normalize,
# i.e. `load_image_test`) is cached; the random flip from `augment` is applied
# AFTER the cache so that each epoch draws fresh flips. Caching the already
# augmented `train_dataset` would freeze the flips sampled during the first
# pass and replay them for every remaining epoch. A horizontal flip commutes
# with the value scaling and label shift done by `normalize`, so the examples
# produced are otherwise the same.
train_batches = (
    dataset["train"]
    .map(load_image_test, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(BUFFER_SIZE)
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .repeat()
)
train_batches = train_batches.prefetch(buffer_size=tf.data.AUTOTUNE)
# The first 3000 test examples are used for validation, the next 669 for testing.
validation_batches = test_dataset.take(3000).batch(BATCH_SIZE)
test_batches = test_dataset.skip(3000).take(669).batch(BATCH_SIZE)


In [None]:
def display(display_list, name=None):
    """Show the given images side by side and optionally save the figure.

    Args:
        display_list: Sequence of image-like arrays/tensors, titled in order
            "Input Image", "True Mask", "Predicted Mask".
        name: Optional file stem. When truthy, the figure is also written to
            ``f'{name}.png'`` at 300 dpi.
    """
    plt.figure(figsize=(15, 15))

    title = ["Input Image", "True Mask", "Predicted Mask"]

    for i, item in enumerate(display_list):
        plt.subplot(1, len(display_list), i + 1)
        plt.title(title[i])
        plt.imshow(tf.keras.utils.array_to_img(item))
        plt.axis("off")

    # Save BEFORE showing: once plt.show() has rendered the figure, a later
    # savefig call writes out a blank canvas.
    if name:
        plt.savefig(f'{name}.png', dpi=300, format='png')
    plt.show()


In [None]:
# Pull one batch from the training pipeline and pick a random example from it;
# `sample_image` / `sample_mask` are reused later as the default prediction input.
sample_batch = next(iter(train_batches))
random_index = np.random.choice(sample_batch[0].shape[0])
sample_image, sample_mask = sample_batch[0][random_index], sample_batch[1][random_index]
display([sample_image, sample_mask])


## BPCA

In [None]:
class BPCAPooling(tf.keras.layers.Layer):
    """Pooling layer that replaces each block of values by its PCA projection.

    For every example, the feature map is cut into ``pool_size`` x ``pool_size``
    patches, regrouped into rows of ``pool_size * pool_size`` values,
    standardised column-wise, and projected onto the leading right singular
    vector(s) of that matrix.

    Args:
        pool_size: Side length of the extracted patches.
        stride: Step between consecutive patches.
        n_components: Number of leading components kept for the projection.
        expected_shape: ``(height, width, channels)`` of one input feature map.
            When ``None`` the static shape of the incoming tensor is used.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, pool_size=2, stride=2, n_components=1, expected_shape=None, **kwargs):
        super().__init__(**kwargs)
        self.pool_size = pool_size
        self.stride = stride
        self.n_components = n_components
        self.expected_shape = expected_shape

        self.patch_size = [1, self.pool_size, self.pool_size, 1]
        self.strides = [1, self.stride, self.stride, 1]

    def build(self, input_shape):
        super().build(input_shape)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialised.

        Without this override a model containing the layer cannot be exported
        with ``model.save`` / ``model.to_json`` and rebuilt afterwards.
        """
        config = super().get_config()
        config.update({
            'pool_size': self.pool_size,
            'stride': self.stride,
            'n_components': self.n_components,
            # Plain list (or None) so the value is JSON-serialisable.
            'expected_shape': (
                None if self.expected_shape is None else list(self.expected_shape)
            ),
        })
        return config

    @tf.function
    def bpca_pooling(self, feature_map):
        """Pool a single ``(height, width, channels)`` feature map."""
        # Compute the region of interest
        if self.expected_shape is not None:
            h, w, c = self.expected_shape  # block_height, block_width, block_channels
        else:
            # Fall back to the static shape of the incoming feature map.
            h, w, c = feature_map.shape
        d = c // (self.pool_size * self.pool_size)  # block_depth

        # Create blocks (patches)
        data = tf.reshape(feature_map, [1, h, w, c])
        patches = tf.image.extract_patches(
            images=data,
            sizes=self.patch_size,
            strides=self.strides,
            rates=[1, 1, 1, 1],
            padding='VALID'
        )
        # NOTE(review): this regroups the flattened patch values into rows of
        # pool_size**2 consecutive entries; confirm the grouping matches the
        # intended spatial blocks.
        patches = tf.reshape(
            patches,
            [h*w*d, self.pool_size * self.pool_size]
        )

        # Normalize the data by subtracting the mean and dividing by the standard deviation
        mean = tf.reduce_mean(patches, axis=0)
        std = tf.math.reduce_std(patches, axis=0)
        patches = (patches - mean) / std
        # A zero standard deviation yields 0/0 = NaN; map those entries to 0.
        patches = tf.where(tf.math.is_nan(patches), 0.0, patches)

        # Perform the Singular Value Decomposition (SVD) on the data
        _, _, v = tf.linalg.svd(patches)

        # Extract the first n principal components from the matrix v
        pca_components = v[:, :self.n_components]

        # Perform the PCA transformation on the data
        transformed_patches = tf.matmul(patches, pca_components)
        return tf.reshape(transformed_patches, [h // self.pool_size, w // self.pool_size, c])

    def call(self, inputs):
        """Apply ``bpca_pooling`` independently to every example of the batch."""
        pooled = tf.vectorized_map(self.bpca_pooling, inputs)
        return pooled


In [None]:
class GlobalBPCAPooling2D(tf.keras.layers.Layer):
    """Global pooling layer based on a single-component PCA projection.

    Each example is processed on its own and reduced to one value per channel;
    the layer output is reshaped to ``[-1, channels]``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        super().build(input_shape)

    @tf.function
    def bpca_pooling(self, feature_map):
        """Pool one ``(height, width, channels)`` feature map with a full-size window."""
        height, width, channels = feature_map.shape
        window = height  # the pooling window spans the whole feature-map height
        window_spec = [1, window, window, 1]  # used for both patch size and stride

        # Extract the patches and regroup them into rows of window * window values.
        blocks = tf.image.extract_patches(
            images=tf.reshape(feature_map, [1, height, width, channels]),
            sizes=window_spec,
            strides=window_spec,
            rates=[1, 1, 1, 1],
            padding='VALID'
        )
        blocks = tf.reshape(blocks, [-1, window * window])

        # Standardise column-wise; 0/0 results (zero deviation) become 0.
        column_mean = tf.reduce_mean(blocks, axis=0)
        column_std = tf.math.reduce_std(blocks, axis=0)
        standardized = (blocks - column_mean) / column_std
        standardized = tf.where(tf.math.is_nan(standardized), 0.0, standardized)

        # Project the rows onto the first right singular vector.
        _, _, right_vectors = tf.linalg.svd(standardized)
        leading_component = right_vectors[:, :1]
        projected = tf.matmul(standardized, leading_component)

        return tf.reshape(projected, [height // window, width // window, channels])

    def call(self, inputs):
        """Pool every example of the batch and flatten to ``[-1, channels]``."""
        per_example = tf.vectorized_map(self.bpca_pooling, inputs)
        return tf.reshape(per_example, [-1, per_example.shape[-1]])


In [None]:
def calculate_power_inverted_value(x):
    """Return ``64 / sqrt(x)`` truncated to an integer."""
    return int(64 / math.sqrt(x))

def define_crop_values(filters):
    """Return the crop pair associated with a filter count.

    Known filter counts (64, 128, 256, 512) map to fixed crop pairs; any other
    value yields ``(0, 0)``.
    """
    crop_table = (
        (64, (88, 88)),
        (128, (40, 40)),
        (256, (16, 16)),
        (512, (4, 4)),
    )
    for known_filters, crop in crop_table:
        if filters == known_filters:
            return crop
    return (0, 0)

def get_layers_number_unetlike(number):
    """Return the ``(height, width, channels)`` triple for a U-net-like filter count.

    Filter counts 64, 128, 256 and 512 map to fixed triples; any other value
    yields ``(0, 0, 0)``.
    """
    shape_table = (
        (64, (128, 128, 64)),
        (128, (64, 64, 128)),
        (256, (32, 32, 256)),
        (512, (16, 16, 512)),
    )
    for known_number, shape in shape_table:
        if number == known_number:
            return shape
    return (0, 0, 0)

def get_layers_number_unet(number):
    """Return the ``(height, width, channels)`` triple for a U-Net filter count.

    Filter counts 64, 128, 256 and 512 map to fixed triples; any other value
    yields ``(0, 0, 0)``.
    """
    shape_table = (
        (64, (256, 256, 64)),
        (128, (128, 128, 128)),
        (256, (64, 64, 256)),
        (512, (32, 32, 512)),
    )
    for known_number, shape in shape_table:
        if number == known_number:
            return shape
    return (0, 0, 0)

## U-net-like architecture


In [None]:
def get_unetlike_model(img_size, num_classes):
    """Build the U-net-like segmentation model with BPCA pooling.

    The encoder is a strided entry convolution followed by three residual
    blocks of separable convolutions, each downsampled by ``BPCAPooling``. The
    decoder is four residual blocks of transposed convolutions, each followed
    by 2x upsampling. A final 3x3 convolution with softmax produces the
    per-pixel class scores.

    Args:
        img_size: ``(height, width)`` tuple of the input images.
        num_classes: Number of output channels of the final softmax layer.

    Returns:
        An uncompiled ``keras.Model``.
    """
    inputs = keras.Input(shape=img_size + (3,))

    ### [First half of the network: downsampling inputs] ###

    # Entry block
    x = keras.layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    # Blocks 1, 2, 3 are identical apart from the feature depth.
    for filters in [64, 128, 256]:
        x = keras.layers.Activation("relu")(x)
        x = keras.layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = keras.layers.BatchNormalization()(x)

        x = keras.layers.Activation("relu")(x)
        x = keras.layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = keras.layers.BatchNormalization()(x)

        # Take the pooling shape from the tensor itself rather than from a
        # table that is only valid for 256x256 inputs, so `img_size` is
        # actually honoured. For 256x256 inputs the shapes are unchanged.
        x = BPCAPooling(
            pool_size=2,
            stride=2,
            expected_shape=tuple(x.shape[1:]),
        )(x)

        # Project residual
        residual = keras.layers.Conv2D(filters, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = keras.layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    ### [Second half of the network: upsampling inputs] ###

    for filters in [256, 128, 64, 32]:
        x = keras.layers.Activation("relu")(x)
        x = keras.layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = keras.layers.BatchNormalization()(x)

        x = keras.layers.Activation("relu")(x)
        x = keras.layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = keras.layers.BatchNormalization()(x)

        x = keras.layers.UpSampling2D(2)(x)

        # Project residual
        residual = keras.layers.UpSampling2D(2)(previous_block_activation)
        residual = keras.layers.Conv2D(filters, 1, padding="same")(residual)
        x = keras.layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # Add a per-pixel classification layer
    outputs = keras.layers.Conv2D(num_classes, 3, activation="softmax", padding="same")(
        x
    )

    # Define the model
    model = keras.Model(inputs, outputs)
    return model


## U-Net architecture

In [None]:
def get_unet_model(img_size, num_classes):
    """Build a U-Net segmentation model whose encoder uses BPCA pooling.

    Four encoder stages (64, 128, 256, 512 filters) each apply a double
    convolution block followed by ``BPCAPooling``. A 1024-filter bottleneck
    follows, then four decoder stages that upsample with a transposed
    convolution and concatenate the matching encoder features. A final 3x3
    convolution with softmax produces the per-pixel class scores.

    Args:
        img_size: ``(height, width)`` tuple of the input images.
        num_classes: Number of output channels of the final softmax layer.

    Returns:
        An uncompiled ``keras.Model`` named ``'UNet'``.
    """
    def convolution_block(tensor, num_filters):
        """Two rounds of 3x3 convolution, batch normalisation and ReLU."""
        x = layers.Conv2D(num_filters, 3, padding='same')(tensor)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)

        x = layers.Conv2D(num_filters, 3, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)

        return x

    def encoder_block(tensor, num_filters):
        """Return ``(skip_features, pooled_features)`` for one encoder stage."""
        conv_layer = convolution_block(tensor, num_filters)
        # Take the pooling shape from the tensor itself rather than from a
        # table that is only valid for 256x256 inputs, so `img_size` is
        # actually honoured. For 256x256 inputs the shapes are unchanged.
        pooling = BPCAPooling(
            pool_size=2,
            stride=2,
            expected_shape=tuple(conv_layer.shape[1:]),
        )(conv_layer)
        return conv_layer, pooling

    def decoder_block(tensor, skip_features, num_filters):
        """Upsample by 2, concatenate the skip features and convolve."""
        x = layers.Conv2DTranspose(
            num_filters, (2, 2), strides=2, padding='same')(tensor)
        x = layers.Concatenate()([x, skip_features])
        x = convolution_block(x, num_filters)
        return x

    inputs = layers.Input(shape=img_size + (3,))
    filters = [64, 128, 256, 512]
    saved_layers = []  # encoder outputs kept for the skip connections

    pooling = inputs
    for num_filters in filters:
        conv_layer, pooling = encoder_block(pooling, num_filters)
        saved_layers.append(conv_layer)

    conv_block = convolution_block(pooling, 1024)

    deconv_layer = conv_block
    for num_filters in reversed(filters):
        # Pop in reverse order so each decoder stage meets its encoder twin.
        conv_layer = saved_layers.pop()
        deconv_layer = decoder_block(deconv_layer, conv_layer, num_filters)

    outputs = layers.Conv2D(num_classes, 3, padding='same',
                            activation='softmax')(deconv_layer)

    model = keras.Model(inputs, outputs, name='UNet')
    return model


## Model Selection

In [None]:
# Select the network to train. The U-Net variant is active; the U-net-like
# builder is kept below as a disabled alternative.
# model = get_unetlike_model(img_size=(HEIGHT, WIDTH), num_classes=NUM_CLASSES)
model = get_unet_model(img_size=(HEIGHT, WIDTH), num_classes=NUM_CLASSES)
# Disabled alternative: reload a previously trained model and its CSV history.
# NOTE(review): the paths below are machine-specific absolute paths.
# model = load_model("/home/hinton/brito/models/max_unetlike.h5")
# model_history = pd.read_csv('/home/hinton/brito/models/max_unetlike.csv')
# model_history.head()


In [None]:
# Print the layer-by-layer summary of the selected model.
model.summary()


In [None]:
# List every layer with its index, name and trainable flag.
for i, layer in enumerate(model.layers):
    print(i, layer.name, layer.trainable)


In [None]:
# plot_model(
#     model,
#     to_file=f'{RESOURCES_DIR}model.png',
#     show_shapes=True,
#     show_layer_names=True,
#     rankdir='TB'
# )


## Training & Testing

Metrics

In [None]:
def mean_iou(y_true, y_pred):
    """Soft mean intersection-over-union between labels and class scores.

    The model is compiled with a sparse loss, so ``y_true`` normally holds one
    class index per pixel. Multiplying those indices directly with the
    per-class scores is not an intersection, so sparse labels are one-hot
    encoded first.

    Args:
        y_true: Ground truth. Either class indices shaped like
            ``y_pred.shape[:-1]`` (optionally with a trailing axis of size 1),
            or a tensor with the same shape as ``y_pred``, which is used as-is.
        y_pred: Per-class scores with the class axis last.

    Returns:
        Scalar tensor: IoU summed over axes 1 and 2, then averaged over the
        remaining axes.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.convert_to_tensor(y_true)

    same_layout = (
        y_true.shape.rank is not None
        and y_true.shape.rank == y_pred.shape.rank
        and y_true.shape[-1] == y_pred.shape[-1]
    )
    if same_layout:
        y_true = tf.cast(y_true, y_pred.dtype)
    else:
        # Sparse labels: drop any trailing singleton axis, then expand to one
        # channel per class.
        class_ids = tf.cast(tf.reshape(y_true, tf.shape(y_pred)[:-1]), tf.int32)
        y_true = tf.one_hot(
            class_ids, depth=tf.shape(y_pred)[-1], dtype=y_pred.dtype)

    intersection = tf.reduce_sum(y_pred * y_true, axis=(1, 2))
    union = tf.reduce_sum(y_pred + y_true, axis=(1, 2)) - intersection
    return tf.reduce_mean(intersection / union)


In [None]:
def dice_coefficient(y_true, y_pred, smooth=1):
    """Smoothed soft Dice coefficient averaged over the batch.

    The model is compiled with a sparse loss, so ``y_true`` normally holds one
    class index per pixel. Multiplying those indices directly with the
    per-class scores is not an overlap measure, so sparse labels are one-hot
    encoded first.

    Args:
        y_true: Ground truth. Either class indices shaped like
            ``y_pred.shape[:-1]`` (optionally with a trailing axis of size 1),
            or a tensor with the same shape as ``y_pred``, which is used as-is.
        y_pred: Per-class scores shaped ``(batch, height, width, classes)``.
        smooth: Constant added to numerator and denominator.

    Returns:
        Scalar tensor with the mean Dice coefficient of the batch.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.convert_to_tensor(y_true)

    same_layout = (
        y_true.shape.rank is not None
        and y_true.shape.rank == y_pred.shape.rank
        and y_true.shape[-1] == y_pred.shape[-1]
    )
    if same_layout:
        y_true = tf.cast(y_true, y_pred.dtype)
    else:
        # Sparse labels: drop any trailing singleton axis, then expand to one
        # channel per class.
        class_ids = tf.cast(tf.reshape(y_true, tf.shape(y_pred)[:-1]), tf.int32)
        y_true = tf.one_hot(
            class_ids, depth=tf.shape(y_pred)[-1], dtype=y_pred.dtype)

    intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
    union = tf.reduce_sum(y_true, axis=[1, 2, 3]) + tf.reduce_sum(y_pred, axis=[1, 2, 3])
    dice = tf.reduce_mean((2. * intersection + smooth) / (union + smooth), axis=0)
    return dice


In [None]:
def pixel_accuracy(y_true, y_pred):
    """Return the fraction of positions where ``y_true`` equals ``y_pred``.

    The comparison is element-wise.
    NOTE(review): presumably both inputs are expected to hold class labels
    rather than per-class scores; confirm against the caller.
    """
    matches = tf.equal(y_true, y_pred)
    return tf.reduce_mean(tf.cast(matches, tf.float32))


In [None]:
# Adam with the configured learning rate; the SGD line is a disabled alternative.
optimizer = Adam(learning_rate=LEARNING_RATE)
# optimizer = SGD(learning_rate=LEARNING_RATE)

# Compile with a sparse categorical cross-entropy loss. Metrics are reported in
# this order: accuracy, dice_coefficient, mean_iou.
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy", dice_coefficient, mean_iou]
)


Callbacks

In [None]:
# Keep only the model with the highest validation mean IoU seen so far.
model_checkpointer = ModelCheckpoint(
    f'{RESOURCES_DIR}{MODEL_NAME}.h5',
    monitor='val_mean_iou',
    verbose=1,
    save_best_only=True,
    save_weights_only=False,
    mode='max'
)
# Append the per-epoch metrics to a CSV file next to the model.
store_history = CSVLogger(f'{RESOURCES_DIR}{MODEL_NAME}.csv', append=True)
tensorboard_callback = TensorBoard(
    log_dir=LOGS_DIR,
    histogram_freq=1,
    write_graph=True,
    write_images=True,
    update_freq='epoch'
)
# mode must be 'max' here: IoU is "higher is better", and 'auto' does not infer
# that from the name 'val_mean_iou'. Left on 'auto', training would stop (and
# the learning rate would drop) while the metric is still improving.
early_stopping = EarlyStopping(
    monitor='val_mean_iou',
    min_delta=0,
    mode='max',
    verbose=1,
    patience=100
)
learning_rate_reducer = ReduceLROnPlateau(
    monitor='val_mean_iou',
    mode='max',
    verbose=1,
    min_lr=1e-5,
    patience=25
)


TensorBoard

In [None]:
%tensorboard --logdir = {LOGS_DIR}  # typo:ignore
# kill $(lsof -i:6006)


Training

In [None]:
# Time the whole training run with a monotonic high-resolution clock.
start_time = time.perf_counter()
# Request the first GPU for the training run.
with tf.device('/gpu:0'):
    model_history = model.fit(
        train_batches,
        epochs=NUM_EPOCHS,
        # The training pipeline repeats indefinitely, so the epoch length is
        # set explicitly.
        steps_per_epoch=STEPS_PER_EPOCH,
        validation_steps=VALIDATION_STEPS,
        validation_data=validation_batches,
        verbose=1,

        callbacks=[
            model_checkpointer,
            store_history,
            tensorboard_callback,
            learning_rate_reducer,
            early_stopping
        ]
    )
end_time = time.perf_counter()


In [None]:
# Elapsed training time, converted from seconds to minutes.
total_time = (end_time - start_time) / 60
print(f'Time to train: {total_time:.2f}')


In [None]:
# Persist the model as it stands after the last epoch (the best-scoring model
# is written separately by the checkpoint callback).
model.save(f'{RESOURCES_DIR}{MODEL_NAME}_last_epoch.h5')
# Also store the architecture on its own as JSON.
model_json = model.to_json()
with open(f'{RESOURCES_DIR}{MODEL_NAME}.json', "w") as json_file:
    json_file.write(model_json)


Testing

In [None]:
# evaluate() returns the loss followed by the metrics in the order they were
# passed to compile(): accuracy, dice_coefficient, mean_iou.
loss, accuracy, dice, m_iou = model.evaluate(test_batches, verbose=1)
print("Loss:", loss)
print("Accuracy: %.2f%%" % (accuracy * 100))
print("Mean IoU: %.2f%%" % (m_iou * 100))
print("Dice coefficient: %.2f%%" % (dice * 100))


In [None]:
def _plot_history_curve(history_data, metric, label):
    """Plot ``metric`` against ``val_<metric>`` and save it as ``<metric>.png``.

    Args:
        history_data: Mapping-like object (dict or DataFrame) indexed by
            metric name.
        metric: Key of the training curve; the validation curve is read from
            ``f'val_{metric}'``.
        label: Legend label of the training curve.
    """
    plt.figure()
    plt.plot(history_data[metric])
    plt.plot(history_data[f'val_{metric}'])
    plt.title(metric)
    plt.legend([label, f'Validation {label}'])
    plt.savefig(f'{metric}.png', dpi=300, format='png')


# `model_history` is either a DataFrame (history reloaded from CSV) or the
# History object returned by fit(); both expose the curves by metric name.
history_data = (
    model_history
    if isinstance(model_history, pd.DataFrame)
    else model_history.history
)

for metric, label in (
    ('loss', 'Loss'),
    ('accuracy', 'Accuracy'),
    ('mean_iou', 'MeanIoU'),
    ('dice_coefficient', 'DiceCoefficient'),
):
    # Skip metrics that were not recorded for this run.
    if metric in history_data:
        _plot_history_curve(history_data, metric, label)


## Prediction

Image, Ground Truth and Feature Maps

In [None]:
def create_mask(pred_mask):
    """Turn batched per-class scores into a label mask for the first example.

    The arg-max is taken over the last axis, a trailing axis of size 1 is
    appended, and the first element along the leading axis is returned.
    """
    class_ids = tf.argmax(pred_mask, axis=-1)
    return class_ids[..., tf.newaxis][0]


In [None]:
def show_predictions(dataset=None, num=1):
    """Display the input, true mask and predicted mask for some examples.

    Args:
        dataset: Optional batched dataset of ``(image, mask)`` pairs. When
            given, the first example of each of the first ``num`` batches is
            shown and saved as ``prediction<k>.png`` (``k`` = 1..num). When
            ``None``, the module-level ``sample_image`` / ``sample_mask`` pair
            is shown without saving.
        num: Number of batches to take from ``dataset``.
    """
    if dataset is not None:
        # Number the files per figure; using `num` itself would give every
        # figure the same name and keep only the last one.
        for index, (image, mask) in enumerate(dataset.take(num), start=1):
            pred_mask = model.predict(image)
            display(
                [image[0], mask[0], create_mask(pred_mask)],
                name=f'prediction{index}'
            )
    else:
        display([
            sample_image,
            sample_mask,
            create_mask(model.predict(sample_image[tf.newaxis, ...]))
        ])


In [None]:
# Skip the first five test batches, then visualise predictions for the next three.
show_predictions(test_batches.skip(5), 3)


In [None]:
# Summary of the run configuration and final scores.
# NOTE(review): the 'val_*' entries hold the values returned by
# model.evaluate(test_batches), i.e. test-set scores — confirm the key naming.
params = {
    'optimizer': optimizer,
    'val_accuracy': accuracy,
    'val_miou': m_iou,
    'val_dice': dice,
    'val_loss': loss,
    'epochs': NUM_EPOCHS,
    'steps_by_epochs': STEPS_PER_EPOCH,
    'validation_steps': VALIDATION_STEPS,
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
    'time': total_time,  # training time in minutes
}
print(params)


LIME