# Artificial Neural Networks and Deep Learning - Homework 2


## ⚙️ Import Libraries

In [None]:
!pip install focal-loss # default Keras focal loss does not support class weights 

In [None]:
import os
from datetime import datetime

import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow import keras as tfk
from tensorflow.keras import layers as tfkl

import matplotlib.pyplot as plt

import focal_loss

%matplotlib inline

from sklearn.model_selection import train_test_split

# Seed the NumPy and TensorFlow global random generators with one shared value.
# The same `seed` is reused later for the train/validation split and the model.
seed = 14
np.random.seed(seed)
tf.random.set_seed(seed)

# Report library versions and how many GPUs TensorFlow can see.
print(f"TensorFlow version: {tf.__version__}")
print(f"Keras version: {tfk.__version__}")
print(f"GPU devices: {len(tf.config.list_physical_devices('GPU'))}")

## ⏳ Load and prepare the Data

In [None]:
# Load the dataset archive from the Kaggle input directory.
data = np.load("/kaggle/input/homework2/mars_for_students.npz")

# For every training sample, index 0 along axis 1 is used as the image and
# index 1 as its per-pixel label mask.
training_set = data["training_set"]
X_train = training_set[:, 0]
y_train = training_set[:, 1]

# The test split only provides images (no labels are read from the archive).
X_test = data["test_set"]

print(f"Training X shape: {X_train.shape}")
print(f"Training y shape: {y_train.shape}")
print(f"Test X shape: {X_test.shape}")

In [None]:
# Add color channel and rescale pixels between 0 and 1
# NOTE(review): dividing by 255 assumes the raw pixels are in 0..255 — confirm
# against the dataset description.
X_train = X_train[..., np.newaxis] / 255.0
X_test = X_test[..., np.newaxis] / 255.0

# Per-sample shape (everything except the leading sample axis).
input_shape = X_train.shape[1:]
# Number of distinct label values that occur anywhere in the training masks.
num_classes = len(np.unique(y_train))

print(f"Input shape: {input_shape}")
print(f"Number of classes: {num_classes}")

In [None]:
# Drop the training samples flagged as outliers (images containing aliens).
# The indices refer to positions in the training arrays as they are right
# before this cell runs, so the cell must be executed exactly once.
outliers = [
    62, 79, 125, 139, 142, 147, 152, 156, 170, 210,
    217, 266, 289, 299, 313, 339, 348, 365, 412, 417,
    426, 450, 461, 536, 552, 669, 675, 741, 744, 747,
    799, 802, 808, 820, 821, 849, 863, 890, 909, 942,
    971, 1005, 1057, 1079, 1082, 1092, 1095, 1106, 1119, 1125,
    1177, 1194, 1224, 1247, 1248, 1258, 1261, 1262, 1306, 1324,
    1365, 1370, 1443, 1449, 1508, 1509, 1519, 1551, 1584, 1588,
    1628, 1637, 1693, 1736, 1767, 1768, 1782, 1813, 1816, 1834,
    1889, 1925, 1942, 1975, 1979, 2000, 2002, 2086, 2096, 2110,
    2111, 2151, 2161, 2222, 2235, 2239, 2242, 2301, 2307, 2350,
    2361, 2365, 2372, 2414, 2453, 2522, 2535, 2561, 2609, 2614,
]

# Remove the same rows from images and masks so the pairs stay aligned.
X_train, y_train = (
    np.delete(array, outliers, axis=0) for array in (X_train, y_train)
)

# Print the new shape
print(f"Training X shape after outlier removal: {X_train.shape}")
print(f"Training y shape after outlier removal: {y_train.shape}")
print(f"Test X shape: {X_test.shape}")

In [None]:
# Pixel-level class frequencies over the whole training set.
label_list, counts = np.unique(y_train, return_counts=True)
for pixel_label, pixel_count in zip(label_list, counts):
    print(f"Number of pixels with label {int(pixel_label)}: {pixel_count}")

In [None]:
# Inverse-frequency weight per class (total pixels / pixels of that class).
# Label 0 gets weight 0 so it does not contribute to the weighted focal loss.
class_weights = [
    0 if class_id == 0 else sum(counts) / counts[class_id]
    for class_id in range(num_classes)
]
print(class_weights)

In [None]:
# Image-level class frequencies: how many masks contain at least one pixel of
# each label. NOTE: this rebinds `counts`, which previously held pixel counts.
counts = [0] * num_classes
for mask in y_train:
    for class_value in np.unique(mask):
        counts[int(class_value)] += 1
for class_id in range(num_classes):
    print(f"Number of images containing label {class_id}: {counts[class_id]}")

In [None]:
# Label 4 appears in few images, so the split is stratified on a binary flag:
# 1 when the mask contains at least one label-4 pixel, 0 otherwise.
stratify = np.array([1 if np.any(mask == 4.0) else 0 for mask in y_train])

In [None]:
# Split into training and validation (90-10, as set by val_size), stratified
# on the presence of label 4 so both splits contain the rare class.
val_size = 0.1
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, random_state=seed, test_size=val_size, stratify=stratify
)

In [None]:
# Check imbalance between training and validation
def _count_images_per_label(masks, n_labels):
    """Return, per label id, how many masks contain at least one such pixel."""
    per_label = [0] * n_labels
    for mask in masks:
        for value in np.unique(mask):
            per_label[int(value)] += 1
    return per_label


train_counts = _count_images_per_label(y_train, num_classes)
val_counts = _count_images_per_label(y_val, num_classes)

# A perfectly proportional split with validation fraction `val_size` has
# val/train == val_size / (1 - val_size). Normalising by that value (rather
# than by val_size alone) makes 1.0 the ideal ratio for every label.
ideal_ratio = val_size / (1 - val_size)
ratios = [
    # A label absent from the training split has no meaningful ratio.
    val_counts[i] / train_counts[i] / ideal_ratio if train_counts[i] else float("nan")
    for i in range(num_classes)
]

for i in range(num_classes):
    print(
        f"Ratio of images containing a label in validation vs train compared to ideal split for label {i}: {ratios[i]}"
    )

In [None]:
# Raw label id -> training label id. Currently the identity mapping; edit the
# values to merge or remap classes.
category_map = {
    0: 0,  # background
    1: 1,  # soil
    2: 2,  # bedrock
    3: 3,  # sand
    4: 4,  # big rock
}


def apply_category_mapping(label):
    """
    Remap every label id in `label` through `category_map`.

    The lookup is done with an int32 -> int32 static hash table, so `label`
    is expected to be an int32 tensor. Ids missing from the map become 0.
    """
    source_ids = tf.constant(list(category_map), dtype=tf.int32)
    target_ids = tf.constant(
        [category_map[key] for key in category_map], dtype=tf.int32
    )
    initializer = tf.lookup.KeyValueTensorInitializer(source_ids, target_ids)
    lookup_table = tf.lookup.StaticHashTable(initializer, default_value=0)
    return lookup_table.lookup(label)

In [None]:
@tf.function
def random_flip_h(image, label):
    """Flip image and label left-right together with probability 0.5."""
    should_flip = tf.random.uniform([]) > 0.5

    # One draw decides both tensors, so the mask always stays aligned with
    # its image.
    image, label = tf.cond(
        should_flip,
        lambda: (tf.image.flip_left_right(image), tf.image.flip_left_right(label)),
        lambda: (image, label),
    )
    return image, label


@tf.function
def random_flip_v(image, label):
    """Flip image and label upside-down together with probability 0.5."""
    should_flip = tf.random.uniform([]) > 0.5

    # One draw decides both tensors, so the mask always stays aligned with
    # its image.
    image, label = tf.cond(
        should_flip,
        lambda: (tf.image.flip_up_down(image), tf.image.flip_up_down(label)),
        lambda: (image, label),
    )
    return image, label


@tf.function
def random_brightness(image):
    """
    Randomly brighten the image.

    A value `delta` is drawn uniformly from [0, 1). Only when it is below
    3 * delta_thresh is it added to the image as a brightness offset;
    otherwise the image is returned unchanged.
    """
    delta_thresh = 0.1
    delta = tf.random.uniform([])

    def _brightened():
        return tf.image.adjust_brightness(image, delta=delta)

    def _unchanged():
        return image

    return tf.cond(delta < 3 * delta_thresh, _brightened, _unchanged)


@tf.function
def random_contrast(image):
    """
    Randomly change the image contrast.

    A value `factor` is drawn uniformly from [0, 1). Only when it is below
    factor_thresh / 6 is it used as the contrast factor; otherwise the image
    is returned unchanged.
    NOTE(review): the factors actually applied are therefore all below 1/3,
    i.e. they only ever reduce contrast — confirm this is intended.
    """
    factor_thresh = 2
    factor = tf.random.uniform([])

    def _adjusted():
        return tf.image.adjust_contrast(image, contrast_factor=factor)

    def _unchanged():
        return image

    return tf.cond(factor < factor_thresh / 6, _adjusted, _unchanged)


@tf.function
def augmentation(image, label, seed=None):
    """
    Apply the enabled augmentations to an (image, label) pair.

    Only the random horizontal flip is active. `seed` is accepted for
    interface compatibility but is not used by any active transform.
    """
    # Disabled alternatives, kept for experimentation:
    # image, label = random_flip_v(image, label)
    # image = random_brightness(image)
    # image = random_contrast(image)
    flipped_image, flipped_label = random_flip_h(image, label)
    return flipped_image, flipped_label

In [None]:
def make_dataset(images, labels, batch_size, shuffle=True, augment=False, seed=None):
    """
    Build a batched, prefetched `tf.data.Dataset` from in-memory arrays.

    Args:
        images: Array of images, indexed by sample along the first axis.
        labels: Array of label masks aligned with `images`; a trailing
            channel axis is appended here. Must be int32, because the
            category mapping uses an int32 lookup table.
        batch_size: Number of samples per batch (last batch may be smaller).
        shuffle: Whether to shuffle the samples.
        augment: Whether to apply `augmentation` to every sample.
        seed: Seed for the shuffle; also forwarded to `augmentation`.

    Returns:
        A dataset yielding `(image_batch, label_batch)` tuples.
    """
    # Add an axis to labels
    new_labels = labels[..., np.newaxis]

    # Create the dataset from the in-memory arrays, one element per sample
    dataset = tf.data.Dataset.from_tensor_slices((images, new_labels))

    if shuffle:
        # Uniform shuffling needs a buffer at least as large as the dataset.
        # A buffer of a couple of batches only permutes neighbouring samples.
        # The data is already in memory, so a full-size buffer costs little.
        buffer_size = max(len(images), 1)
        dataset = dataset.shuffle(buffer_size=buffer_size, seed=seed)

    # Apply category mapping
    dataset = dataset.map(
        lambda x, y: (x, apply_category_mapping(y)), num_parallel_calls=tf.data.AUTOTUNE
    )

    if augment:
        dataset = dataset.map(
            lambda x, y: augmentation(x, y, seed=seed),
            num_parallel_calls=tf.data.AUTOTUNE,
        )

    # Batch the data
    dataset = dataset.batch(batch_size, drop_remainder=False)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset

In [None]:
batch_size = 16

# Create the datasets
# Labels are cast to int32 because the category lookup table uses int32 keys.
# NOTE(review): the training dataset is built with shuffle=False and
# augment=False, so every epoch sees the same batches in the same order and
# the augmentation functions are never applied — confirm this is intended.
print("Creating datasets...")
train_dataset = make_dataset(
    X_train,
    y_train.astype("int32"),
    batch_size=batch_size,
    shuffle=False,
    augment=False,
    seed=seed,
)

# Validation data is neither shuffled nor augmented.
val_dataset = make_dataset(
    X_val, y_val.astype("int32"), batch_size=batch_size, shuffle=False
)

print("Datasets created!")

# Check the shape of the data
# Only the first batch is inspected; `input_shape` is refreshed from it.
for images, labels in train_dataset.take(1):
    input_shape = images.shape[1:]
    print(f"\nInput shape: {input_shape}")
    print("Images shape:", images.shape)
    print("Labels shape:", labels.shape)
    print("Labels dtype:", labels.dtype)
    break

## Analyze the data

In [None]:
def create_segmentation_colormap(num_classes):
    """
    Build one color per class by sampling the 'viridis' palette.

    The palette is sampled at `num_classes` evenly spaced positions in
    [0, 1]; row i of the returned array is the color of class i. Viridis is
    used because it is perceptually uniform and colorblind-friendly.
    """
    sample_positions = np.linspace(0, 1, num_classes)
    return plt.cm.viridis(sample_positions)


def apply_colormap(label, colormap=None):
    """
    Turn a label mask into a color image by indexing a colormap.

    Args:
        label: Array of class ids; singleton axes are squeezed away and the
            values are cast to int before indexing.
        colormap: Array whose row i is the color of class i. When omitted,
            a colormap is created that is large enough for the highest
            class id present in `label`.

    Returns:
        Array shaped like the squeezed label plus a trailing color axis.
    """
    # Ensure label has no singleton axes and can be used as an index
    class_ids = np.squeeze(label).astype(int)

    if colormap is None:
        # Size the palette by the largest id, not by the number of distinct
        # ids: a mask holding only classes {0, 4} still needs 5 entries.
        num_classes = int(class_ids.max()) + 1 if class_ids.size else 1
        colormap = create_segmentation_colormap(num_classes)

    # Apply the colormap
    return colormap[class_ids]


def plot_sample_batch(images, labels, num_samples=3):
    """
    Show the first image/label pairs side by side, one pair per row.

    At most `num_samples` pairs are drawn (fewer if `images` is shorter).
    Labels are colored with a colormap sized by the global `num_classes`.
    """
    plt.figure(figsize=(15, 4 * num_samples))
    colormap = create_segmentation_colormap(num_classes)

    rows_to_draw = min(num_samples, len(images))
    for row in range(rows_to_draw):
        left_cell = 2 * row + 1
        right_cell = left_cell + 1

        # Left column: the grayscale input image
        plt.subplot(num_samples, 2, left_cell)
        plt.imshow(images[row], cmap="gray")
        plt.title(f"Image {row+1}")
        plt.axis("off")

        # Right column: the color-coded label mask
        plt.subplot(num_samples, 2, right_cell)
        plt.imshow(apply_colormap(labels[row], colormap))
        plt.title(f"Label {row+1}")
        plt.axis("off")

    plt.tight_layout()
    plt.show()
    plt.close()


# Visualize examples from the training set
# (the first 10 image/mask pairs of the training split, in array order).
print("Visualizing examples from the training set:")
plot_sample_batch(X_train, y_train, num_samples=10)

In [None]:
# Visualize examples from the test set.
# Only the images are shown, one per row: no labels are loaded for the test set.
num_samples = 4
plt.figure(figsize=(15, 2 * num_samples))

# NOTE(review): `colormap` is not used anywhere in this cell.
colormap = create_segmentation_colormap(num_classes)

for j in range(min(num_samples, len(X_test))):
    plt.subplot(num_samples, 1, j + 1)
    plt.imshow(X_test[j], cmap="gray")
    plt.title(f"Image {j}")
    plt.axis("off")

## 🛠️ Define the model

In [None]:
def unet_block(
    input_tensor,
    filters,
    kernel_size=3,
    activation="relu",
    stack=1,
    name="",
    groups=8,
    dilation_rate=1,
):
    """
    Apply `stack` repetitions of Conv2D -> GroupNormalization -> Activation.

    Every convolution uses "same" padding and an L2(1e-3) kernel
    regularizer. Layers are named `<name>conv<i>`, `<name>bn<i>` and
    `<name>activation<i>` with i starting at 1; the "bn" name is historical,
    the normalization layer is GroupNormalization with `groups` groups.

    Returns the transformed tensor.
    """
    x = input_tensor

    for index in range(1, stack + 1):
        convolution = tfkl.Conv2D(
            filters,
            kernel_size=kernel_size,
            padding="same",
            dilation_rate=dilation_rate,
            kernel_regularizer=tfk.regularizers.L2(1e-3),
            name=f"{name}conv{index}",
        )
        normalization = tfkl.GroupNormalization(name=f"{name}bn{index}", groups=groups)
        nonlinearity = tfkl.Activation(activation, name=f"{name}activation{index}")

        x = convolution(x)
        x = normalization(x)
        x = nonlinearity(x)

    return x

In [None]:
# Residual block with configurable parameters (currently unused, but used in
# older models). A modified version of this block, combined with the
# inception block, is used in the final model.
def residual_unet_block(
    x,
    filters,
    kernel_size=3,
    padding="same",
    downsample=False,
    activation="relu",
    stack=2,
    name="residual",
    groups=8,
    dilation_rate=1,
):
    """
    Stack `stack` residual units, optionally followed by 2x2 max pooling.

    Each unit is Conv -> GroupNorm -> Activation -> Conv -> GroupNorm, added
    to a shortcut of the unit input and passed through a final activation.
    When the input channel count differs from `filters`, the shortcut is
    projected with a 1x1 convolution followed by GroupNormalization.
    """
    for s in range(stack):
        shortcut = x
        branch = x

        # Two conv + norm stages; only the first one is followed by an
        # activation (the second is activated after the addition).
        for stage, activate in ((1, True), (2, False)):
            branch = tfkl.Conv2D(
                filters,
                kernel_size,
                padding=padding,
                name=f"{name}_conv{stage}_{s}",
                dilation_rate=dilation_rate,
            )(branch)
            branch = tfkl.GroupNormalization(
                name=f"{name}_bn{stage}_{s}", groups=groups
            )(branch)
            if activate:
                branch = tfkl.Activation(activation, name=f"{name}_act{stage}_{s}")(
                    branch
                )

        # Project the shortcut when its channel count does not match.
        if shortcut.shape[-1] != filters:
            shortcut = tfkl.Conv2D(
                filters, 1, padding=padding, name=f"{name}_proj_{s}"
            )(shortcut)
            shortcut = tfkl.GroupNormalization(
                name=f"{name}_proj_bn_{s}", groups=groups
            )(shortcut)

        merged = tfkl.Add(name=f"{name}_add_{s}")([branch, shortcut])
        x = tfkl.Activation(activation, name=f"{name}_act2_{s}")(merged)

    if not downsample:
        return x
    return tfkl.MaxPooling2D(2, name=f"{name}_pool")(x)

In [None]:
# Define the Inception block with multiple convolution paths and optional
# downsampling. Despite the "_bn" suffix, normalization is done with
# GroupNormalization layers.
def inception_block_bn(
    x,
    filters,
    padding="same",
    downsample=False,
    activation="relu",
    stack=1,
    name="inception",
    groups=8,
):
    """
    Apply `stack` inception modules, optionally followed by 2x2 max pooling.

    Each module runs four parallel paths on its input and concatenates them
    along the channel axis:
      * 1x1 convolution                               (filters // 4 channels)
      * 1x1 reduction (filters // 8) then 3x3 conv    (filters // 4 channels)
      * 1x1 reduction (filters // 12) then 5x5 conv   (filters // 4 channels)
      * 3x3 max pooling (stride 1) then 1x1 conv      (filters // 4 channels)
    Every convolution is followed by GroupNormalization and an activation.
    The output therefore has 4 * (filters // 4) channels.

    The paths let the model learn features at several scales while the 1x1
    reductions keep the parameter count low.

    Args:
        x: Input tensor.
        filters: Target number of output channels.
        padding: Padding mode for every convolution and pooling layer.
        downsample: Whether to finish with a 2x2 max pooling.
        activation: Activation applied after every normalization.
        stack: Number of inception modules applied in sequence.
        name: Prefix for all layer names.
        groups: Requested number of groups for GroupNormalization. For each
            layer it is lowered to the largest common divisor of `groups`
            and that layer's channel count, because GroupNormalization
            rejects a group count that does not divide the channels (the
            reduction branches rarely have a multiple of 8 channels).
            Non-positive values (e.g. -1) are passed through untouched.

    Returns:
        The output tensor of the last module (pooled when `downsample`).
    """
    import math  # local import so this cell does not depend on another one

    def _conv_norm_act(tensor, channels, kernel, conv_tag, tag, s):
        """Conv2D -> GroupNormalization -> Activation with the legacy names."""
        # Never request an empty convolution for very small `filters`.
        channels = max(1, channels)
        # Keep the requested group count whenever it is valid; otherwise fall
        # back to the largest divisor shared with the channel count.
        n_groups = math.gcd(groups, channels) if groups >= 1 else groups
        tensor = tfkl.Conv2D(
            channels, kernel, padding=padding, name=f"{name}_{conv_tag}_{s}"
        )(tensor)
        tensor = tfkl.GroupNormalization(name=f"{name}_bn{tag}_{s}", groups=n_groups)(
            tensor
        )
        return tfkl.Activation(activation, name=f"{name}_act{tag}_{s}")(tensor)

    # Loop through specified stack layers for multiple inception paths.
    for s in range(stack):
        # 1x1 convolution path.
        conv1 = _conv_norm_act(x, filters // 4, 1, "conv1", "1", s)

        # 3x3 convolution path with initial 1x1 reduction layer.
        conv3_reduce = _conv_norm_act(
            x, filters // 8, 1, "conv3_reduce", "3_reduce", s
        )
        conv3 = _conv_norm_act(conv3_reduce, filters // 4, 3, "conv3", "3", s)

        # 5x5 convolution path with initial 1x1 reduction layer.
        conv5_reduce = _conv_norm_act(
            x, filters // 12, 1, "conv5_reduce", "5_reduce", s
        )
        conv5 = _conv_norm_act(conv5_reduce, filters // 4, 5, "conv5", "5", s)

        # Pooling path: stride-1 max pooling (spatial size is preserved)
        # followed by a 1x1 projection to the branch width.
        pool = tfkl.MaxPooling2D(
            3, strides=1, padding=padding, name=f"{name}_pooling_{s}"
        )(x)
        pool_proj = _conv_norm_act(
            pool, filters // 4, 1, "pool_proj", "_pool_proj", s
        )

        # Concatenate all paths to form the final block output.
        x = tfkl.Concatenate(name=f"{name}_concat_{s}")(
            [conv1, conv3, conv5, pool_proj]
        )

    # Apply downsampling if specified.
    if downsample:
        x = tfkl.MaxPooling2D(2, name=f"{name}_pool")(x)
    return x

In [None]:
# Residual block whose convolutional path is an inception block. The intent
# is to combine the benefits of residual and inception blocks, which appear
# to be complementary (their guesses don't perfectly overlap). See the report
# for more details.
def inception_residual_unet(
    x,
    filters,
    padding="same",
    downsample=False,
    activation="relu",
    stack=4,
    inception_stack=1,
    name="residual",
    groups=8,
):
    """
    Stack `stack` residual units built around inception blocks.

    Each unit passes its input through `inception_block_bn` (with
    `inception_stack` modules and no downsampling), adds a shortcut of the
    unit input and applies the activation. The shortcut is projected with a
    1x1 convolution + GroupNormalization when its channel count differs from
    `filters`. A 2x2 max pooling is appended when `downsample` is true.

    NOTE(review): the addition needs the inception output to have exactly
    `filters` channels; the inception block concatenates four branches of
    `filters // 4` channels, so `filters` presumably must be a multiple of 4.
    """
    for s in range(stack):
        shortcut = x

        # Multi-scale path replacing the plain convolutions of a residual unit.
        branch = inception_block_bn(
            x,
            filters,
            padding,
            downsample=False,
            activation=activation,
            stack=inception_stack,
            name=f"{name}_inception_{s}",
            groups=groups,
        )

        # Project the shortcut when its channel count does not match.
        if shortcut.shape[-1] != filters:
            shortcut = tfkl.Conv2D(
                filters, 1, padding=padding, name=f"{name}_proj_{s}"
            )(shortcut)
            shortcut = tfkl.GroupNormalization(
                name=f"{name}_proj_bn_{s}", groups=groups
            )(shortcut)

        merged = tfkl.Add(name=f"{name}_add_{s}")([branch, shortcut])
        x = tfkl.Activation(activation, name=f"{name}_act2_{s}")(merged)

    if not downsample:
        return x
    return tfkl.MaxPooling2D(2, name=f"{name}_pool")(x)

In [None]:
def attention_gate(input_tensor, gating_tensor, inter_channels):
    """
    Re-weight a skip connection with attention computed from a gating signal.

    The skip tensor is projected with a stride-2 1x1 convolution and the
    gating tensor with a stride-1 1x1 convolution; their sum goes through
    ReLU and a 1-channel sigmoid convolution. The resulting map is upsampled
    by 2 and multiplied with the skip tensor.

    NOTE(review): the stride-2 projection implies the gating tensor has half
    the spatial resolution of `input_tensor` — confirm at call sites.
    """
    # Project the skip connection down to the gating resolution.
    skip_projection = tfkl.Conv2D(
        inter_channels, kernel_size=1, strides=2, padding="same"
    )(input_tensor)
    # Project the gating signal (decoder output).
    gate_projection = tfkl.Conv2D(
        inter_channels, kernel_size=1, strides=1, padding="same"
    )(gating_tensor)

    # Combine both projections.
    combined = tfkl.Add()([skip_projection, gate_projection])
    combined = tfkl.Activation("relu")(combined)

    # One attention coefficient in (0, 1) per spatial location.
    attention = tfkl.Conv2D(
        1, kernel_size=1, strides=1, padding="same", activation="sigmoid"
    )(combined)

    # Bring the coefficients back to the skip resolution and apply them.
    attention = tfkl.UpSampling2D()(attention)
    return tfkl.Multiply()([input_tensor, attention])

In [None]:
def get_unet_model(input_shape=(64, 128, 1), num_classes=num_classes, seed=seed):
    """
    Build the U-Net segmentation model.

    The encoder has five levels (32 to 512 filters); each level is a
    `unet_block` followed by spatial dropout and 2x2 max pooling. The
    bottleneck is a two-layer `unet_block` with 1024 filters and dilation
    rate 2. The decoder mirrors the encoder: a stride-2 transposed
    convolution, concatenation with the matching encoder features (taken
    before dropout and pooling) and a `unet_block`. A 1x1 softmax convolution
    yields per-pixel class probabilities.

    Args:
        input_shape: Shape of one input image (height, width, channels).
        num_classes: Number of output classes.
        seed: Value used to reset TensorFlow's global random seed.

    Returns:
        An uncompiled Keras model named "UNet".
    """
    tf.random.set_seed(seed)
    base_filters = 32
    input_layer = tfkl.Input(shape=input_shape, name="input_layer")

    # Per encoder level: (filter multiplier, normalization groups, dropout rate).
    encoder_levels = (
        (1, 1, 0.05),
        (2, 2, 0.1),
        (4, 4, 0.15),
        (8, 8, 0.2),
        (16, 8, 0.25),
    )

    # Downsampling path
    skip_features = []
    x = input_layer
    for level, (multiplier, n_groups, drop_rate) in enumerate(encoder_levels, start=1):
        features = unet_block(
            x, base_filters * multiplier, name=f"down_block{level}_", groups=n_groups
        )
        skip_features.append(features)
        x = tfkl.SpatialDropout2D(drop_rate)(features)
        x = tfkl.MaxPooling2D()(x)

    # Bottleneck
    x = unet_block(
        x, base_filters * 32, name="bottleneck", groups=8, stack=2, dilation_rate=2
    )
    x = tfkl.SpatialDropout2D(0.3)(x)

    # Per decoder level: (filter multiplier, normalization groups).
    decoder_levels = ((16, 8), (8, 8), (4, 4), (2, 2), (1, 1))

    # Upsampling path: the deepest encoder features are consumed first.
    for level, ((multiplier, n_groups), skip) in enumerate(
        zip(decoder_levels, reversed(skip_features)), start=1
    ):
        x = tfkl.Conv2DTranspose(
            base_filters * multiplier, kernel_size=2, strides=2, padding="same"
        )(x)
        x = tfkl.Concatenate()([x, skip])
        x = unet_block(
            x, base_filters * multiplier, name=f"up_block{level}_", groups=n_groups
        )

    # Output Layer
    output_layer = tfkl.Conv2D(
        num_classes,
        kernel_size=1,
        padding="same",
        activation="softmax",
        name="output_layer",
    )(x)

    return tf.keras.Model(inputs=input_layer, outputs=output_layer, name="UNet")

In [None]:
# Define parameters
# Upper bound on training epochs; early stopping normally ends training sooner.
epochs = 1000
# Early-stopping patience in epochs; the LR scheduler uses a third of it.
patience = 45
# Initial learning rate for the AdamW optimizer.
learning_rate = 1e-4

In [None]:
model = get_unet_model()

# Print a detailed summary of the model with expanded nested layers and trainable parameters.
model.summary(expand_nested=True, show_trainable=True)

# Generate a graphical representation of the model architecture.
# Plotting is best-effort: report why it failed instead of hiding every error
# (including KeyboardInterrupt) behind a bare `except`.
try:
    tf.keras.utils.plot_model(model, show_trainable=True, expand_nested=True, dpi=70)
except Exception as exc:
    print(f"Could not plot the model: {exc}")

## Train the model

In [None]:
# Define custom Mean Intersection Over Union metric
@tfk.utils.register_keras_serializable()
class MeanIntersectionOverUnion(tf.keras.metrics.MeanIoU):
    """
    Mean IoU computed from class probabilities, ignoring selected labels.

    Predictions are turned into class ids with an argmax over the last axis,
    and every pixel whose true label is in `labels_to_exclude` is dropped
    before the confusion matrix of the parent metric is updated.
    """

    def __init__(
        self, num_classes, labels_to_exclude=None, name="mean_iou", dtype=None, **kwargs
    ):
        """
        Args:
            num_classes: Number of classes the model predicts.
            labels_to_exclude: Label ids removed before scoring; defaults to [0].
            name: Metric name (the validation value is logged as `val_<name>`).
            dtype: Metric dtype, forwarded to the parent class.
            **kwargs: Extra parent-class arguments. They are accepted so the
                metric can be rebuilt from a config in which the parent's
                `get_config` includes additional constructor arguments.
        """
        super().__init__(num_classes=num_classes, name=name, dtype=dtype, **kwargs)
        if labels_to_exclude is None:
            labels_to_exclude = [0]  # Default to excluding label 0
        # Copy so later changes to the caller's list do not affect the metric.
        self.labels_to_exclude = list(labels_to_exclude)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate one batch; `y_pred` holds per-class scores on the last axis."""
        # Convert predictions to class labels
        y_pred = tf.math.argmax(y_pred, axis=-1)

        # Flatten the tensors
        y_true = tf.reshape(y_true, [-1])
        y_pred = tf.reshape(y_pred, [-1])

        # Apply mask to exclude specified labels
        for label in self.labels_to_exclude:
            mask = tf.not_equal(y_true, label)
            y_true = tf.boolean_mask(y_true, mask)
            y_pred = tf.boolean_mask(y_pred, mask)

        # Update the state
        # NOTE(review): `sample_weight` is forwarded unmasked, so it only
        # works when it is None or broadcastable to the filtered pixels.
        return super().update_state(y_true, y_pred, sample_weight)

    def get_config(self):
        """Return the parent config plus `labels_to_exclude` so it survives saving."""
        config = super().get_config()
        config["labels_to_exclude"] = list(self.labels_to_exclude)
        return config


# Visualization callback
class VizCallback(tf.keras.callbacks.Callback):
    """
    Plot input, ground truth and prediction for one fixed sample.

    The figure is drawn at the end of every `frequency`-th epoch (epochs 0,
    frequency, 2 * frequency, ...), which gives a visual check of how the
    predicted mask evolves during training.
    """

    def __init__(self, image, label, frequency=5):
        """Store the sample (a batch of size one) and the plotting period."""
        super().__init__()
        self.image = image
        self.label = label
        self.frequency = frequency

    def on_epoch_end(self, epoch, logs=None):
        """Draw the three-panel figure when `epoch` is a multiple of `frequency`."""
        if epoch % self.frequency != 0:
            return

        sample_image = self.image
        # Pass the ground truth through the same category mapping as the
        # training labels.
        true_mask = apply_category_mapping(self.label)

        probabilities = self.model.predict(sample_image, verbose=0)
        predicted_mask = tf.math.argmax(probabilities, axis=-1).numpy()

        # One shared colormap so both masks use the same class colors
        colormap = create_segmentation_colormap(num_classes)

        plt.figure(figsize=(16, 4))

        # Input image
        plt.subplot(1, 3, 1)
        plt.imshow(sample_image[0], cmap="gray")
        plt.title("Input Image")
        plt.axis("off")

        # Ground truth
        plt.subplot(1, 3, 2)
        plt.imshow(apply_colormap(true_mask.numpy(), colormap))
        plt.title("Ground Truth Mask")
        plt.axis("off")

        # Prediction
        plt.subplot(1, 3, 3)
        plt.imshow(apply_colormap(predicted_mask[0], colormap))
        plt.title("Predicted Mask")
        plt.axis("off")

        plt.tight_layout()
        plt.show()
        plt.close()

In [None]:
from tensorflow.keras import backend as K


def iou_loss(y_true, y_pred, num_classes=num_classes, smooth=1e-6):
    """
    Soft Intersection-over-Union loss averaged over the non-background classes.

    For every class from 1 to num_classes - 1 the soft IoU is computed over
    all pixels of the batch, using the predicted probabilities as soft
    memberships. Class 0 is ignored.

    :param y_true: Ground truth class ids (not one-hot encoded).
    :param y_pred: Per-class scores on the last axis, indexed by class id.
    :param num_classes: Total number of classes, including class 0.
    :param smooth: Smoothing factor to avoid division by zero.
    :return: 1 - mean soft IoU, as a scalar tensor.
    """
    one_hot_true = tf.one_hot(tf.cast(y_true, tf.int32), num_classes)

    def _soft_iou(class_id):
        # Flatten so batch and spatial positions are pooled together.
        truth = tf.reshape(one_hot_true[..., class_id], [-1])
        scores = tf.reshape(y_pred[..., class_id], [-1])

        overlap = tf.reduce_sum(truth * scores)
        combined = tf.reduce_sum(truth) + tf.reduce_sum(scores) - overlap
        return (overlap + smooth) / (combined + smooth)

    # Start from 1 to exclude the background (class 0).
    per_class_iou = [_soft_iou(class_id) for class_id in range(1, num_classes)]

    return 1 - tf.reduce_mean(tf.stack(per_class_iou))

In [None]:
# Define a custom loss
class CustomLoss(tfk.losses.Loss):
    """
    Convex combination of a class-weighted focal loss and the soft IoU loss.

    loss = (1 - alpha) * focal_loss + alpha * iou_loss

    The focal term uses the module-level `class_weights` and gamma = 3.0.
    """

    def __init__(self, alpha, num_classes=num_classes, name="custom_loss", **kwargs):
        """
        Args:
            alpha: Weight of the IoU term; the focal term gets 1 - alpha.
            num_classes: Number of classes, forwarded to `iou_loss`.
            name: Name of the loss.
            **kwargs: Forwarded to `tf.keras.losses.Loss`.
        """
        super().__init__(name=name, **kwargs)
        self.alpha = alpha
        self.num_classes = num_classes

    def call(self, y_true, y_pred):
        """Return the combined loss for one batch."""
        focal_term = focal_loss.sparse_categorical_focal_loss(
            y_true, y_pred, class_weight=class_weights, gamma=3.0
        )
        # Use the configured class count instead of iou_loss's default, so
        # the constructor argument actually takes effect.
        iou_term = iou_loss(y_true, y_pred, num_classes=self.num_classes)
        return (1 - self.alpha) * focal_term + self.alpha * iou_term

    def get_config(self):
        """Include the constructor arguments so the loss can be re-created."""
        config = super().get_config()
        config.update({"alpha": self.alpha, "num_classes": int(self.num_classes)})
        return config

In [None]:
# Compile the model
# Loss: 20% class-weighted focal loss + 80% soft IoU loss (alpha = 0.8).
# Metrics: pixel accuracy and mean IoU with label 0 excluded; the latter is
# logged as "mean_iou" / "val_mean_iou" and drives early stopping.
print("Compiling model...")
model.compile(
    loss=CustomLoss(0.8),
    optimizer=tfk.optimizers.AdamW(learning_rate),
    metrics=[
        "accuracy",
        MeanIntersectionOverUnion(num_classes=num_classes, labels_to_exclude=[0]),
    ],
)
print("Model compiled!")

In [None]:
# Setup callbacks
# Stop when the validation mean IoU has not improved for `patience` epochs and
# roll back to the best weights seen.
early_stopping = tfk.callbacks.EarlyStopping(
    monitor="val_mean_iou", mode="max", patience=patience, restore_best_weights=True
)

# Use the second sample of the first validation batch for the periodic plots.
image, label = val_dataset.take(1).get_single_element()
viz_callback = VizCallback(image[1:2, ...], label[1:2, ...])

# Divide the learning rate by 10 when the validation loss plateaus. The
# patience must be an integer number of epochs, hence the floor division.
reduce_lr_callback = tfk.callbacks.ReduceLROnPlateau(
    monitor="val_loss", patience=patience // 3, factor=0.1, min_lr=learning_rate * 1e-4
)

callbacks = [early_stopping, viz_callback, reduce_lr_callback]

In [None]:
# Train the model; only the per-epoch metrics dictionary is kept.
history = model.fit(
    train_dataset,
    epochs=epochs,
    validation_data=val_dataset,
    callbacks=callbacks,
    verbose=1,
).history

# Report the best validation mean IoU reached during training, as a
# percentage. This is the maximum over all epochs, not the last epoch's value.
final_val_meanIoU = round(max(history["val_mean_iou"]) * 100, 2)
print(f"Final validation Mean Intersection Over Union: {final_val_meanIoU}%")

In [None]:
# Save the trained model under a timestamped name (yymmdd_HHMMSS).
timestep_str = datetime.now().strftime("%y%m%d_%H%M%S")
model_filename = f"model_{timestep_str}.keras"
model.save(model_filename)
# Drop the in-memory model; it is reloaded from disk in the submission section.
del model

print(f"Model saved to {model_filename}")

## 📊 Prepare Your Submission

In our Kaggle competition, submissions are made as `csv` files. To create a proper `csv` file, you need to flatten your predictions and include an `id` column as the first column of your dataframe. To maintain consistency between your results and our solution, please avoid shuffling the test set. The code below demonstrates how to prepare the `csv` file from your model predictions.




In [None]:
# Reload the saved model for inference only. compile=False skips restoring
# the training configuration, so the custom loss and metric are not needed.
model = tfk.models.load_model(model_filename, compile=False)
print(f"Model loaded from {model_filename}")

In [None]:
# Predict class probabilities for the test images, then keep the most likely
# class per pixel.
preds = model.predict(X_test)
preds = np.argmax(preds, axis=-1)
print(f"Predictions shape: {preds.shape}")

In [None]:
def y_to_df(y) -> pd.DataFrame:
    """
    Flatten segmentation predictions into the Kaggle submission layout.

    Each sample becomes one row: an `id` column (0..n-1) followed by one
    column per pixel, numbered from 0 in flattened order.
    """
    n_samples = len(y)
    submission = pd.DataFrame(y.reshape(n_samples, -1))
    # Put the sample id in front of the pixel columns.
    submission.insert(0, "id", np.arange(n_samples))
    return submission

In [None]:
# Create and download the csv submission file
# Reuse the timestamp embedded in the model filename so that the submission
# can be matched to the model that produced it.
timestep_str = model_filename.replace("model_", "").replace(".keras", "")
submission_filename = f"submission_{timestep_str}.csv"
submission_df = y_to_df(preds)
submission_df.to_csv(submission_filename, index=False)

# Uncomment to download the file when running on Google Colab.
# from google.colab import files
# files.download(submission_filename)

In [None]:
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    confusion_matrix,
)
import seaborn as sns

# Predict class probabilities and keep the most likely class for every pixel
# of the validation set. The metrics below are pixel-level and computed on
# validation data: the test set has no labels available here.
test_predictions = np.argmax(model.predict(X_val), axis=-1).flatten()

# Extract ground truth classes as integers, matching the predictions' type.
test_gt = y_val.astype(int).flatten()

# Fixed list of class ids so that every class gets a row/column in the
# confusion matrix even if it is missing from the validation pixels.
class_ids = list(range(num_classes))

# Calculate and display validation accuracy.
test_accuracy = accuracy_score(test_gt, test_predictions)
print(f"Accuracy score over the validation set: {round(test_accuracy, 4)}")

# Calculate and display validation precision (weighted by class support).
test_precision = precision_score(test_gt, test_predictions, average="weighted")
print(f"Precision score over the validation set: {round(test_precision, 4)}")

# Calculate and display validation recall (weighted by class support).
test_recall = recall_score(test_gt, test_predictions, average="weighted")
print(f"Recall score over the validation set: {round(test_recall, 4)}")

# Calculate and display validation F1 score (weighted by class support).
test_f1 = f1_score(test_gt, test_predictions, average="weighted")
print(f"F1 score over the validation set: {round(test_f1, 4)}")

# Compute the confusion matrix with one row/column per class id, so that its
# shape always matches the tick labels used below.
cm = confusion_matrix(test_gt, test_predictions, labels=class_ids)

# Create the cell annotations from the confusion matrix values.
labels = np.array([f"{num}" for num in cm.flatten()]).reshape(cm.shape)

# Plot the confusion matrix with class labels.
plt.figure(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=labels,
    fmt="",
    xticklabels=class_ids,
    yticklabels=class_ids,
    cmap="Blues",
)
plt.xlabel("Predicted labels")
plt.ylabel("True labels")
plt.show()

#  
<img src="https://airlab.deib.polimi.it/wp-content/uploads/2019/07/airlab-logo-new_cropped.png" width="350">

<img src="https://upload.wikimedia.org/wikipedia/commons/thumb/9/95/Instagram_logo_2022.svg/800px-Instagram_logo_2022.svg.png" width="15"> **Instagram:** https://www.instagram.com/airlab_polimi/

<img src="https://upload.wikimedia.org/wikipedia/commons/thumb/8/81/LinkedIn_icon.svg/2048px-LinkedIn_icon.svg.png" width="15"> **LinkedIn:** https://www.linkedin.com/company/airlab-polimi/
___
Credits: Alberto Archetti 📧 alberto.archetti@polito.it





```
   Copyright 2024 Alberto Archetti

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
```