<a href="https://colab.research.google.com/github/Lucs1590/USeS-BPCA/blob/main/notebooks/u_net_bpca.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# U-net-like with Oxford-IIIT Pet Dataset

## Imports

In [None]:
import os
import glob


import numpy as np
import pandas as pd
import tensorflow as tf

from keras.models import load_model

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

import xplique
from xplique.plots import plot_attributions
from xplique.utils_functions.segmentation import get_connected_zone, get_in_out_border, get_common_border
from xplique.metrics import Deletion, MuFidelity, Insertion, AverageStability
from xplique.plots.metrics import barplot
from xplique.attributions import (Saliency, GradientInput, IntegratedGradients, SmoothGrad, VarGrad, SquareGrad,
                                  Occlusion, Rise, SobolAttributionMethod, HsicAttributionMethod)

from xplique.plots import plot_attributions

import tensorflow_datasets as tfds

In [None]:
# Ask TensorFlow's C++ backend to silence its log output.
# NOTE(review): TensorFlow is already imported in the cell above; this variable is
# normally set before the import to be sure it takes effect — confirm it works here.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Fix the NumPy and TensorFlow global seeds so random operations are repeatable.
np.random.seed(77)
tf.random.set_seed(77)

## Constant Variables

In [None]:
# Spatial size every image and mask is resized to (see `resize`).
HEIGHT, WIDTH = 256, 256
# Number of segmentation classes. `run_patch_segment` names the class indices
# 0, 1, 2 as 'Pet', 'Fundo' (background) and 'Contorno' (boundary).
NUM_CLASSES = 3  # pet (foreground), background, boundary
# Batch size used for the evaluation datasets and for the Xplique explainers/metrics.
BATCH_SIZE = 64

## Dataset
Download the dataset and apply the preprocessing transformations.


In [None]:
# Load the Oxford-IIIT Pet dataset (any 3.x.x version) together with its metadata.
# `shuffle_files=False` keeps the example order fixed: later cells carve the
# validation and test subsets out of the "test" split with `take`/`skip`, and a
# file order that can change between iterations would let those subsets overlap
# and make the reported numbers non-reproducible.
# NOTE(review): `data_dir` is a machine-specific absolute path — adjust it for your setup.
dataset, info = tfds.load(
    'oxford_iiit_pet:3.*.*',
    with_info=True,
    shuffle_files=False,
    data_dir='/home/hinton/brito/datasets/'
)

print(info)


In [None]:
# Map each label index (as a string) to a one-element list holding the label name.
classes_dict = dict(
    (str(index), [name])
    for index, name in enumerate(info.features['label'].names)
)
print(classes_dict)

In [None]:
def resize(input_image, input_mask):
    """Resize an image and its mask to ``(HEIGHT, WIDTH)``.

    Both tensors are resized with nearest-neighbour interpolation.

    Args:
        input_image: image tensor to resize.
        input_mask: segmentation mask tensor to resize.

    Returns:
        Tuple ``(image, mask)`` holding the resized tensors.
    """
    target_size = (HEIGHT, WIDTH)
    resized_image = tf.image.resize(input_image, target_size, method="nearest")
    resized_mask = tf.image.resize(input_mask, target_size, method="nearest")
    return resized_image, resized_mask

In [None]:
def normalize(input_image, input_mask):
    """Scale the image by 1/255 and shift the mask labels down by one.

    Args:
        input_image: image tensor; it is cast to ``float32`` and divided by 255.
        input_mask: mask tensor; every value is decreased by 1.

    Returns:
        Tuple ``(image, mask)`` with the scaled image and the shifted mask.
    """
    input_mask -= 1
    scaled_image = tf.cast(input_image, tf.float32) / 255.0
    return scaled_image, input_mask

In [None]:
def load_image_test(datapoint):
    """Turn one dataset record into a preprocessed ``(image, mask)`` pair.

    Args:
        datapoint: mapping that provides the ``"image"`` and
            ``"segmentation_mask"`` entries.

    Returns:
        Tuple ``(image, mask)`` after `resize` followed by `normalize`.
    """
    resized_image, resized_mask = resize(
        datapoint["image"],
        datapoint["segmentation_mask"],
    )
    return normalize(resized_image, resized_mask)

In [None]:
# Keep the first 1632 records of the "test" split and preprocess them in parallel.
test_dataset = (
    dataset["test"]
    .take(1632)
    .map(load_image_test, num_parallel_calls=tf.data.AUTOTUNE)
)

In [None]:
# The first 963 records become the validation batches; the next 669 the test batches.
validation_batches = (
    test_dataset
    .take(963)
    .batch(BATCH_SIZE)
)
test_batches = (
    test_dataset
    .skip(963)
    .take(669)
    .batch(BATCH_SIZE)
)

In [None]:
def display(display_list, name=None):
    """Show the given images side by side, optionally saving the figure.

    The panels are titled, in order, "Imagem de Entrada", "Máscara" and
    "Máscara Predita", so at most three items are supported.

    Args:
        display_list: sequence of image-like arrays/tensors to draw in one row.
        name: optional output path without extension; when truthy the figure
            is written to ``f"{name}.png"`` before being shown.
    """
    titles = ["Imagem de Entrada", "Máscara", "Máscara Predita"]
    n_panels = len(display_list)

    plt.figure(figsize=(15, 15))
    for position, item in enumerate(display_list, start=1):
        plt.subplot(1, n_panels, position)
        plt.title(titles[position - 1])
        plt.imshow(tf.keras.utils.array_to_img(item))
        plt.axis("off")

    if name:
        plt.savefig(
            f"{name}.png",
            format="png",
            dpi=300,
            bbox_inches='tight',
            pad_inches=0.0,
            transparent=True,
        )

    plt.show()

## BPCA

In [None]:
class BPCAPooling(tf.keras.layers.Layer):
    """Block-wise PCA (BPCA) pooling layer.

    Each sample is cut into ``pool_size x pool_size`` patches, the patch values
    are standardised, and they are projected onto the leading principal
    component(s) obtained through an SVD.

    Args:
        pool_size: side length of the square patches.
        stride: step between consecutive patches.
        n_components: number of principal components kept per patch row.
            NOTE(review): the final reshape only matches the element count for
            ``n_components=1`` — confirm before using other values.
        expected_shape: ``(height, width, channels)`` of one input sample.
            It is required by `bpca_pooling`, which unpacks it.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, pool_size=2, stride=2, n_components=1, expected_shape=None, **kwargs):
        super(BPCAPooling, self).__init__(**kwargs)
        self.pool_size = pool_size
        self.stride = stride
        self.n_components = n_components
        self.expected_shape = expected_shape

        # `sizes` / `strides` arguments in the layout tf.image.extract_patches expects.
        self.patch_size = [1, self.pool_size, self.pool_size, 1]
        self.strides = [1, self.stride, self.stride, 1]

    def build(self, input_shape):
        """No weights are created; only defer to the base implementation."""
        super(BPCAPooling, self).build(input_shape)

    def get_config(self):
        """Return the constructor arguments so Keras can save, clone and reload the layer."""
        config = super(BPCAPooling, self).get_config()
        config.update({
            'pool_size': self.pool_size,
            'stride': self.stride,
            'n_components': self.n_components,
            # Stored as a plain list (or None) so the config stays serialisable.
            'expected_shape': (
                None if self.expected_shape is None else list(self.expected_shape)
            ),
        })
        return config

    @tf.function
    def bpca_pooling(self, feature_map):
        """Pool one sample and return a ``(h // pool_size, w // pool_size, c)`` tensor.

        Args:
            feature_map: a single sample; it is reshaped to ``[1, h, w, c]``
                using ``expected_shape``.
        """
        # Compute the region of interest
        h, w, c = self.expected_shape  # block_height, block_width, block_channels
        d = c // (self.pool_size * self.pool_size)  # block_depth

        # Create blocks (patches)
        data = tf.reshape(feature_map, [1, h, w, c])
        patches = tf.image.extract_patches(
            images=data,
            sizes=self.patch_size,
            strides=self.strides,
            rates=[1, 1, 1, 1],
            padding='VALID'
        )
        # One row per group of pool_size * pool_size values.
        patches = tf.reshape(
            patches,
            [h*w*d, self.pool_size * self.pool_size]
        )

        # Normalize the data by subtracting the mean and dividing by the standard deviation
        mean = tf.reduce_mean(patches, axis=0)
        std = tf.math.reduce_std(patches, axis=0)
        patches = (patches - mean) / std
        # A zero standard deviation yields NaN above; replace those entries with 0.
        patches = tf.where(tf.math.is_nan(patches), 0.0, patches)

        # Perform the Singular Value Decomposition (SVD) on the data
        _, _, v = tf.linalg.svd(patches)

        # Extract the first n principal components from the matrix v
        pca_components = v[:, :self.n_components]

        # Perform the PCA transformation on the data
        transformed_patches = tf.matmul(patches, pca_components)

        return tf.reshape(transformed_patches, [h // self.pool_size, w // self.pool_size, c])

    def call(self, inputs):
        """Apply `bpca_pooling` to every sample of the batch."""
        pooled = tf.vectorized_map(self.bpca_pooling, inputs)
        return pooled

## Metrics

In [None]:
def mean_iou(y_true, y_pred):
    """Intersection over union of the thresholded targets and predictions.

    Both inputs are binarised with ``> 0.5`` and a single IoU is computed over
    all elements.
    NOTE(review): despite its name this is not a per-class mean IoU — confirm
    that the binary formulation is the intended one for the 3-class masks.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor, broadcastable against ``y_true``.

    Returns:
        Scalar ``float64`` tensor. When the union is empty (no positive element
        in either input) the result is 1.0 instead of NaN, which matches what
        the smoothed `dice_coefficient` returns for that case.
    """
    y_true = tf.cast(y_true > 0.5, tf.int32)
    y_pred = tf.cast(y_pred > 0.5, tf.int32)

    # Integer sums are exact; float64 is the dtype int32 true division produces.
    intersection = tf.cast(tf.reduce_sum(y_true * y_pred), tf.float64)
    total = tf.cast(tf.reduce_sum(y_true) + tf.reduce_sum(y_pred), tf.float64)
    union = total - intersection

    # divide_no_nan avoids the 0/0 -> NaN that would poison the averaged metric.
    iou = tf.math.divide_no_nan(intersection, union)
    return tf.where(union > 0, iou, tf.ones_like(iou))

In [None]:
def dice_coefficient(y_true, y_pred):
    """Smoothed Dice coefficient of the thresholded targets and predictions.

    Both inputs are binarised with ``> 0.5``; a smoothing term of 1.0 is added
    to numerator and denominator to avoid a division by zero.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor, broadcastable against ``y_true``.

    Returns:
        Scalar ``float32`` tensor.
    """
    smooth = 1.0
    truth = tf.cast(y_true > 0.5, tf.float32)
    prediction = tf.cast(y_pred > 0.5, tf.float32)

    overlap = tf.reduce_sum(truth * prediction)
    total = tf.reduce_sum(truth) + tf.reduce_sum(prediction)
    return (2.0 * overlap + smooth) / (total + smooth)

In [None]:
def pixel_accuracy(y_true, y_pred):
    """Fraction of pixels whose predicted label equals the ground-truth label.

    When ``y_true`` is a label map with a trailing axis of size 1 and ``y_pred``
    has the same rank but several entries on its last axis, ``y_pred`` is
    treated as per-class scores and reduced with an arg-max first (the same
    reduction `create_mask` applies). Comparing raw scores with integer labels
    element-wise, as was done before, practically never matches.
    In every other case the two tensors are compared as given.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels, or per-class scores on the last axis.

    Returns:
        Scalar ``float32`` tensor in ``[0, 1]``.
    """
    y_true = tf.convert_to_tensor(y_true)
    y_pred = tf.convert_to_tensor(y_pred)

    true_rank = y_true.shape.rank
    pred_rank = y_pred.shape.rank
    if true_rank is not None and true_rank == pred_rank and true_rank > 0:
        scores_on_last_axis = y_pred.shape[-1] not in (None, 1)
        if y_true.shape[-1] == 1 and scores_on_last_axis:
            # Per-class scores -> label map, keeping the trailing axis of the mask.
            y_pred = tf.argmax(y_pred, axis=-1)[..., tf.newaxis]

    y_true = tf.cast(y_true, tf.dtypes.float64)
    y_pred = tf.cast(y_pred, tf.dtypes.float64)
    return tf.reduce_mean(tf.cast(
        tf.equal(y_true, y_pred),
        tf.float32
    ))

## Plots

In [None]:
def plot_metrics(model_history, output_dir):
    """Plot the training curves found in a history table and save them as PNGs.

    For each of ``loss``, ``accuracy``, ``mean_iou`` and ``dice_coefficient``
    present as a column, a new figure is drawn with the training curve and,
    when available, the matching ``val_<metric>`` curve. The figure is saved
    to ``f"{output_dir}<metric>.png"``.

    Args:
        model_history: table with a ``columns`` attribute and column access by
            name (for example the DataFrame read from the training CSV log).
        output_dir: prefix prepended to each file name; include the trailing
            path separator when it is a directory.
    """
    # metric column -> (training legend label, validation legend label)
    curves = {
        'loss': ("Loss", "Loss de Validação"),
        'accuracy': ("Acurácia", "Acurácia de Validação"),
        'mean_iou': ("MeanIoU", "MeanIoU de Validação"),
        'dice_coefficient': ("DiceCoefficient", "DiceCoefficient de Validação"),
    }

    for metric, (train_label, val_label) in curves.items():
        if metric not in model_history.columns:
            continue

        # Always start a fresh figure so a curve never lands on an unrelated plot.
        plt.figure()
        plt.plot(model_history[metric])
        legend_labels = [train_label]

        # Tolerate histories that were logged without validation data.
        val_column = f"val_{metric}"
        if val_column in model_history.columns:
            plt.plot(model_history[val_column])
            legend_labels.append(val_label)

        plt.title(metric)
        plt.legend(legend_labels)
        plt.savefig(f"{output_dir}{metric}.png", dpi=300, format="png")

In [None]:
def create_mask(pred_mask):
    """Convert a batch of per-class predictions into the first sample's label map.

    Args:
        pred_mask: batched predictions with the class scores on the last axis.

    Returns:
        The arg-max over the last axis with a new trailing axis appended, for
        the first element of the batch only.
    """
    class_ids = tf.argmax(pred_mask, axis=-1)
    return class_ids[..., tf.newaxis][0]

## XAI

In [None]:
class ModelWrapper(tf.keras.Model):
    """Thin ``tf.keras.Model`` wrapper that forwards calls to the wrapped model.

    The wrapper is a TensorFlow model, so the wrapped model is stored as given.
    ``model.eval()`` is a PyTorch-only API that Keras models do not provide,
    and calling it here raised ``AttributeError``.

    Args:
        model: the model to wrap.
    """

    def __init__(self, model):
        super(ModelWrapper, self).__init__()
        self.model = model

    def __call__(self, torch_inputs):
        """Run the wrapped model on ``torch_inputs`` and return its output.

        This method should change depending on the model. The parameter name
        is kept for backward compatibility.
        """
        return self.model(torch_inputs)


In [None]:
def run_patch_segment(image, model, output_dir):
    """Predict a segmentation for one image and save it as a colour overlay.

    Args:
        image: a single image (no batch axis).
        model: model whose ``predict`` returns per-class scores for a batch.
        output_dir: prefix for the output file; the figure is written to
            ``f"{output_dir}overlayed_segmentation.png"``.

    Returns:
        The predicted label map produced by `create_mask`.
    """
    categories = ['Pet', 'Fundo', 'Contorno']
    alpha = 0.6
    colormap = np.asarray(plt.get_cmap('tab20').colors)

    # Predict on a batch of one and reduce the scores to a label map.
    pred_seg = create_mask(model.predict(image[tf.newaxis, ...]))

    # RGB canvas with the same height and width as the label map.
    color_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3))

    # One colour and one legend entry per category; zip stops after the last category.
    handles = []
    for class_id, (label, color) in enumerate(zip(categories, colormap)):
        color_seg[pred_seg[:, :, 0] == class_id] = color
        handles.append(mpatches.Patch(color=color, label=label))

    # Draw the image, the semi-transparent colour map on top, then legend and grid.
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    plt.imshow(color_seg, alpha=alpha)
    plt.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid()
    plt.savefig(
        f"{output_dir}overlayed_segmentation.png",
        format="png",
        dpi=300,
        bbox_inches='tight',
        pad_inches=0.0,
        transparent=True,
    )

    return pred_seg

In [None]:
def plot_segmentation_zone(image, pred_seg, model, output_dir):
    """Explain the zone around pixel (250, 250) with HSIC and save the plot.

    NOTE(review): the notebook passes the arg-maxed label map returned by
    `run_patch_segment` as ``pred_seg``, whereas `run_border_segmentations`
    gives `get_connected_zone` the raw per-class predictions — confirm which
    form Xplique expects.

    Args:
        image: a single image (no batch axis).
        pred_seg: segmentation handed to `get_connected_zone`.
        model: model to explain.
        output_dir: prefix for the output file; the figure is written to
            ``f"{output_dir}overlayed_segmentation_zone.png"``.
    """
    alpha = 0.6
    batched_image = image[tf.newaxis]

    # Target zone: the connected region that contains the chosen pixel.
    zone = get_connected_zone(pred_seg, coordinates=(250, 250))
    pet_zone_targets = zone[tf.newaxis]

    # Attribution of that zone with the HSIC method.
    hsic_explainer = HsicAttributionMethod(
        model,
        operator=xplique.Tasks.SEMANTIC_SEGMENTATION,
        nb_design=750,
        grid_size=12,
        batch_size=BATCH_SIZE,
    )
    explanation = hsic_explainer.explain(batched_image, pet_zone_targets)

    # Blend the zone mask into the image so the explained area is visible.
    pet_mask = tf.cast(pet_zone_targets != 0, tf.float32)
    image_with_mask = (1 - alpha) * image + alpha * pet_mask

    plot_attributions(
        explanation,
        image_with_mask,
        img_size=4.,
        cmap='jet',
        alpha=0.3,
        absolute_value=False,
        clip_percentile=0.5,
    )
    plt.savefig(
        f"{output_dir}overlayed_segmentation_zone.png",
        format="png",
        dpi=300,
        bbox_inches='tight',
        pad_inches=0.0,
        transparent=True,
    )

In [None]:
def run_border_segmentations(image, predictions, model, output_dir,
                             pet_coordinates=(200, 50),
                             background_coordinates=(200, 50)):
    """Explain zones and borders of one prediction with several attribution methods.

    Five targets are built from ``predictions``: the pet zone, the background
    zone, the pet border, the background border and the border shared by both.
    Every explainer is run once on the five targets, and its figure is saved to
    ``f"{output_dir}<ExplainerName>.png"``.

    Args:
        image: a single image (no batch axis).
        predictions: model output for that image; must have three axes.
        model: model to explain.
        output_dir: prefix for the saved figures.
        pet_coordinates: a pixel inside the pet zone.
        background_coordinates: a pixel inside the background zone.
            NOTE(review): both defaults are the same point, so both zones
            coincide unless the caller overrides one — confirm the intended
            background pixel.

    Returns:
        Tuple ``(explanations, inputs, targets)``: a dict mapping explainer
        class names to explanations, the tiled inputs and the stacked targets.
    """
    assert len(predictions.shape) == 3

    # compute the `targets` parameter to explain the specified zones
    pet_zone_predictions = get_connected_zone(predictions, pet_coordinates)
    background_zone_predictions = get_connected_zone(
        predictions, background_coordinates)

    # compute the `targets` parameter to explain the border of the specified zones
    pet_borders_predictions = get_in_out_border(pet_zone_predictions)
    background_borders_predictions = get_in_out_border(
        background_zone_predictions)

    # The common border is computed from the two *border* masks, not a zone mask.
    common_border_predictions = get_common_border(
        pet_borders_predictions,
        background_borders_predictions
    )

    # Zone, zone, border, border, common border. The fourth entry used to repeat
    # the background zone, so the background border was never explained.
    targets = tf.stack([
        pet_zone_predictions,
        background_zone_predictions,
        pet_borders_predictions,
        background_borders_predictions,
        common_border_predictions
    ])

    # One copy of the image per target so each method is called once per image.
    inputs = tf.tile(image[tf.newaxis], (targets.shape[0], 1, 1, 1))

    # add the zone mask to the image to visualize
    mask_alpha = 0.5
    masks = tf.expand_dims(tf.cast(tf.reduce_any(
        targets != 0, axis=-1), tf.float32), -1)
    images_with_masks = (1 - mask_alpha) * inputs + mask_alpha * masks

    explainers = {
        Saliency: {},
        GradientInput: {},
        IntegratedGradients: {"steps": 20},
        SmoothGrad: {"nb_samples": 50, "noise": 0.75},
        VarGrad: {"nb_samples": 50, "noise": 0.75},
        SquareGrad: {"nb_samples": 100, "noise": 0.5},
        Occlusion: {"patch_size": 40, "patch_stride": 10, "occlusion_value": 0},
        Rise: {"nb_samples": 4000, "grid_size": 13},
        SobolAttributionMethod: {"nb_design": 32, "grid_size": 13},
        HsicAttributionMethod: {"nb_design": 1500, "grid_size": 13}
    }

    explanations = {}
    for explainer_class, params in explainers.items():
        tf.keras.backend.clear_session()
        plt.clf()
        print(explainer_class.__name__)

        # instantiate the explainer
        explainer = explainer_class(
            model,
            operator=xplique.Tasks.SEMANTIC_SEGMENTATION,
            batch_size=BATCH_SIZE,
            **params
        )

        # compute explanations
        explanation = explainer(inputs, targets)

        # show explanations for a method
        plot_attributions(
            explanation,
            images_with_masks,
            img_size=4.,
            cols=inputs.shape[0],
            cmap='jet',
            alpha=0.3,
            absolute_value=False,
            clip_percentile=0.5
        )
        # Save before showing: once shown, the figure may be released and the
        # saved file would be blank.
        plt.savefig(f"{output_dir}{explainer_class.__name__}.png", format="png",
                    dpi=300, bbox_inches='tight', pad_inches=0.0, transparent=True)
        plt.show()

        # keep explanations in memory for metrics
        explanations[explainer_class.__name__] = explanation

    return explanations, inputs, targets

In [None]:
def plot_xai_metrics(explanations, inputs, targets, model, output_dir):
    """Score every explanation with Deletion, MuFidelity and Insertion and plot a bar chart.

    Only the first three inputs/targets/explanations are evaluated. The chart
    is saved to ``f"{output_dir}barplot.png"``.

    Args:
        explanations: dict mapping a method name to its explanations.
        inputs: inputs the explanations were computed on.
        targets: targets the explanations were computed for.
        model: the explained model.
        output_dir: prefix for the saved figure.
    """
    metrics = {}
    explanations_metrics = {
        Deletion: {"baseline_mode": 0, "steps": 10, "max_percentage_perturbed": 0.5},
        MuFidelity: {"baseline_mode": 0, "nb_samples": 5, "subset_percent": 0.2, "grid_size": 13},
        Insertion: {"baseline_mode": 0, "steps": 10,
                    "max_percentage_perturbed": 0.5}
    }
    for metric_class, params in explanations_metrics.items():
        tf.keras.backend.clear_session()
        plt.clf()

        # instantiate the metric
        metric = metric_class(
            model,
            np.array(inputs[:3]),
            np.array(targets[:3]),
            operator=xplique.Tasks.SEMANTIC_SEGMENTATION,
            activation="softmax",
            batch_size=BATCH_SIZE,
            **params
        )

        # score every method's explanations with this metric
        metrics[metric_class.__name__] = {
            method: metric(explanation[:3])
            for method, explanation in explanations.items()
        }

    # `ascending` is a flag: pass a real boolean instead of the string "True".
    barplot(metrics, sort_metric="Deletion", ascending=True)
    # Save before showing: once shown, the figure may be released and the
    # saved file would be blank.
    plt.savefig(f"{output_dir}barplot.png", format="png",
                dpi=300, bbox_inches='tight', pad_inches=0.0, transparent=True)
    plt.show()

## Model Selection and Tests

In [None]:
# Evaluate every saved model, plot its training history and explain a few test images.
# NOTE(review): both paths are machine-specific absolute paths — adjust them for your setup.
MODELS_PATH = "/home/hinton/brito/models/"
OUTPUTS_PATH = "/home/hinton/brito/outputs/"

os.makedirs(OUTPUTS_PATH, exist_ok=True)

# Keep the first three batches of the test subset; only the first image of each
# batch is displayed and explained below.
images = []
masks = []
for image, mask in test_batches.take(3):
    images.append(image)
    masks.append(mask)

# Sorted so the models are processed in the same order on every run.
for model_path in sorted(glob.glob(f'{MODELS_PATH}*.h5')):
    model_file = os.path.basename(model_path)
    model_name = os.path.splitext(model_file)[0]

    model = load_model(
        model_path,
        custom_objects={
            'mean_iou': mean_iou,
            'dice_coefficient': dice_coefficient,
            'pixel_accuracy': pixel_accuracy,
            'BPCAPooling': BPCAPooling,
        },
    )
    # The training history is expected next to the model, with a .csv extension.
    model_history = pd.read_csv(f'{MODELS_PATH}{model_name}.csv')

    print(f"Model: {model_file}")
    output_dir = f'{OUTPUTS_PATH}{model_name}/'
    os.makedirs(output_dir, exist_ok=True)

    # Evaluate once and adapt the report to the number of values returned,
    # instead of running a second full evaluation when unpacking fails.
    scores = model.evaluate(validation_batches)
    if not isinstance(scores, (list, tuple)):
        scores = [scores]
    if len(scores) == 4:
        print(f"Loss: {scores[0]}, Accuracy: {scores[1]}, Mean IoU: {scores[2]}, Dice Coefficient: {scores[3]}")
    elif len(scores) == 2:
        print(f"Loss: {scores[0]}, Accuracy: {scores[1]}")
    else:
        print(", ".join(
            f"{metric_name}: {value}"
            for metric_name, value in zip(model.metrics_names, scores)
        ))
    del scores

    plot_metrics(model_history, output_dir)

    for i, (image, mask) in enumerate(zip(images, masks)):
        pred_mask = model.predict(image)
        display([image[0], mask[0], create_mask(pred_mask)], name=f"{output_dir}image_{i}")
        segmentation_predict = run_patch_segment(image[0], model, f"{output_dir}image_{i}_")
        plot_segmentation_zone(image[0], segmentation_predict, model, f"{output_dir}image_{i}_")
        explanations, inputs, targets = run_border_segmentations(image[0], pred_mask[0], model, f"{output_dir}image_{i}_")
        plot_xai_metrics(explanations, inputs, targets, model, f"{output_dir}image_{i}_")

        # Release the per-image tensors and figures before the next iteration.
        del segmentation_predict, explanations, inputs, targets, image, mask, pred_mask
        plt.clf()
        plt.close('all')

    # Release the model before loading the next one.
    del model, model_history, output_dir
    tf.keras.backend.clear_session()
    tf.compat.v1.reset_default_graph()

How to read the XAI metrics computed above:

- Deletion: lower is better
- MuFidelity: higher is better
- Insertion: higher is better
- AverageStability: lower is better (imported, but not computed by `plot_xai_metrics` in this notebook)

In [None]:
# zip OUTPUTS_PATH without losing the folder structure and data
import zipfile

def zipdir(path, ziph):
    """Add every file below ``path`` to an open zip archive.

    Entries are stored under their path relative to ``path``, so the folder
    structure is kept without leaking the absolute location on disk. If the
    archive file itself lives below ``path`` it is skipped, so it is never
    added to itself.

    Args:
        path: directory to walk recursively.
        ziph: open, writable ``zipfile.ZipFile`` handle.
    """
    # ZipFile.filename is None when the archive was opened on a file object.
    archive_name = getattr(ziph, "filename", None)
    archive_path = os.path.abspath(archive_name) if archive_name else None

    for root, _dirs, files in os.walk(path):
        for file in files:
            file_path = os.path.join(root, file)
            if archive_path is not None and os.path.abspath(file_path) == archive_path:
                continue
            ziph.write(file_path, arcname=os.path.relpath(file_path, path))

# Archive OUTPUTS_PATH into OUTPUTS_PATH/outputs.zip. The `with` block closes the
# archive even when adding a file fails, so no half-open handle is left behind.
with zipfile.ZipFile(f'{OUTPUTS_PATH}outputs.zip', 'w', zipfile.ZIP_DEFLATED) as zipf:
    zipdir(OUTPUTS_PATH, zipf)


In [None]:
import shutil
# Zip OUTPUTS_PATH without losing the folder structure and data.
# NOTE(review): the archive is created inside the directory being archived, and the
# previous cell already wrote `outputs.zip` there — confirm both archives are wanted.

shutil.make_archive(f'{OUTPUTS_PATH}outputs2', 'zip', OUTPUTS_PATH)