# Image Segmentation Model Implementation

This notebook implements an image segmentation model using Keras that:
- Works on AMD GPU, NVIDIA GPU, and CPU
- Uses Python 3.13 and FiftyOne for dataset management
- Includes 3+ segmentation classes
- Properly splits data into training and testing sets
- Calculates Dice, Micro-F1, and Macro-F1 metrics
- Visualizes training progress and results
- Presents a confusion matrix and additional metrics

## 1. Setup and Environment Configuration

In [33]:
# !conda create -n tensorflow python=3.10 -c conda-forge -y

# activate & install the rest

!pip install keras matplotlib datasets



In [34]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers, models
from tensorflow import keras
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow_datasets.core.download import DownloadConfig
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    confusion_matrix,
    precision_score,
    recall_score,
    accuracy_score,
    f1_score,
)

### Dataset Splitting & Export
Create train/val/test splits with image-mask pairs:

In [35]:
IMG_SIZE    = (128, 128)
BATCH_SIZE  = 16
AUTOTUNE    = tf.data.AUTOTUNE
TARGET_IDS  = [2, 9, 15]
NUM_CLASSES = len(TARGET_IDS) + 1

CLASS_NAMES = {
    0: "background",
    1: "person",
    2: "bicycle",
    3: "train",
}

In [36]:
from google.colab import drive
drive.mount('/content/drive')
drive_root = "/content/drive/MyDrive/tfds"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [53]:
TARGET_IDS  = [2, 9, 15]             # VOC class IDs for bicycle, chair, person
IMG_SIZE    = (128, 128)
BATCH_SIZE  = 16
AUTOTUNE    = tf.data.AUTOTUNE
NUM_CLASSES = len(TARGET_IDS) + 1    # 0=bg,1..3=our classes

# 1) stream & limit
raw_train = load_dataset(
    "merve/pascal-voc",
    split="train",
    streaming=True,
).take(5000)
raw_val = load_dataset(
    "merve/pascal-voc",
    split="validation",
    streaming=True,
).take(1000)

# 2) filter on the "mask" column (not "segmentation_mask")
def has_target(ex):
    mask = np.array(ex["mask"])
    return any((mask == cls).any() for cls in TARGET_IDS)

train_filt = filter(has_target, raw_train)
val_filt   = filter(has_target, raw_val)

# 3) preprocess uses ex["mask"]
def gen(ds_iter):
    for ex in ds_iter:
        img  = np.array(ex["image"])      # PIL → np
        mask = np.array(ex["mask"])       # PIL → np
        # resize & normalize
        img  = tf.image.resize(img, IMG_SIZE) / 255.0
        mask = tf.image.resize(
            tf.expand_dims(mask, -1), IMG_SIZE, method="nearest"
        )[..., 0]
        # remap to 0..NUM_CLASSES-1
        sem = np.zeros(IMG_SIZE, np.int32)
        for i, cls in enumerate(TARGET_IDS, start=1):
            sem[mask == cls] = i
        yield img, sem

# 4) wrap in tf.data and batch
output_sig = (
    tf.TensorSpec((*IMG_SIZE, 3), tf.float32),
    tf.TensorSpec(IMG_SIZE,      tf.int32),
)
train_ds = tf.data.Dataset.from_generator(lambda: gen(train_filt),
                                          output_signature=output_sig)
val_ds   = tf.data.Dataset.from_generator(lambda: gen(val_filt),
                                          output_signature=output_sig)

train_ds = train_ds.cache().shuffle(500).batch(BATCH_SIZE).prefetch(AUTOTUNE)
val_ds   = val_ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)

README.md:   0%|          | 0.00/2.79k [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/37 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/34 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/37 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/34 [00:00<?, ?it/s]

In [55]:
# 5) Build a UNet with MobileNetV2 backbone
def upsample_block(x, skip, filters):
    x = layers.Conv2DTranspose(filters, 3, strides=2, padding="same")(x)
    x = layers.Concatenate()([x, skip])
    x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
    x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
    return x

def build_unet_mobilenet(input_shape, num_classes):
    base = MobileNetV2(input_shape=input_shape,
                       include_top=False,
                       weights="imagenet")
    layer_names = [
        "block_1_expand_relu",   # 64x64
        "block_3_expand_relu",   # 32x32
        "block_6_expand_relu",   # 16x16
        "block_13_expand_relu",  #  8x8
        "block_16_project",      #  4x4
    ]
    skips = [base.get_layer(name).output for name in layer_names]
    down = Model(inputs=base.input, outputs=skips, name="downstack")
    down.trainable = False

    inputs = layers.Input(shape=input_shape)
    s1, s2, s3, s4, x = down(inputs)

    x = upsample_block(x, s4, 512)
    x = upsample_block(x, s3, 256)
    x = upsample_block(x, s2, 128)
    x = upsample_block(x, s1,  64)

    x = layers.Conv2DTranspose(64, 3, strides=2, padding="same")(x)
    x = layers.Conv2D(64, 3, padding="same", activation="relu")(x)
    outputs = layers.Conv2D(num_classes, 1, activation="softmax")(x)

    return Model(inputs, outputs, name="UNet_MobileNetV2")

model = build_unet_mobilenet((*IMG_SIZE, 3), NUM_CLASSES)
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"],
)

In [56]:
EPOCHS = 30
STEPS_PER_EPOCH   = 100
VALIDATION_STEPS  = 20

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VALIDATION_STEPS,
)

Epoch 1/30


UnknownError: Graph execution error:

Detected at node PyFunc defined at (most recent call last):
<stack traces unavailable>
KeyError: 'mask'
Traceback (most recent call last):

  File "/usr/local/lib/python3.11/dist-packages/tensorflow/python/ops/script_ops.py", line 269, in __call__
    ret = func(*args)
          ^^^^^^^^^^^

  File "/usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^

  File "/usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/from_generator_op.py", line 198, in generator_py_func
    values = next(generator_state.get_iterator(iterator_id))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  File "<ipython-input-53-dcd4a944125d>", line 29, in gen
    for ex in ds_iter:

  File "<ipython-input-53-dcd4a944125d>", line 21, in has_target
    mask = np.array(ex["mask"])
                    ~~^^^^^^^^

KeyError: 'mask'


	 [[{{node PyFunc}}]]
	 [[IteratorGetNext]] [Op:__inference_multi_step_on_iterator_60054]

In [None]:
def display_samples_from_unbatch(model, ds, num=3):
    """
    Pulls `num` raw (image, mask) pairs from `ds` by unbatching and taking,
    then runs them through `model` and plots them side by side.
    """
    # 1) Grab `num` individual examples
    examples = list(ds.unbatch().take(num))  # always restarts from the beginning
    images, true_masks = zip(*examples)      # tuples of length num

    images = np.stack(images, axis=0)        # [num, H, W, 3]
    true_masks = np.stack(true_masks, axis=0)

    # 2) Predict
    logits = model.predict(images)           # [num, H, W, num_classes]
    pred_masks = np.argmax(logits, axis=-1)  # [num, H, W]

    # 3) Plot
    fig, axes = plt.subplots(num, 3, figsize=(12, 4*num))
    for i in range(num):
        ax_img, ax_true, ax_pred = axes[i]

        ax_img.imshow(images[i])
        ax_img.set_title("Image")
        ax_img.axis("off")

        ax_true.imshow(true_masks[i], cmap="tab10", vmin=0, vmax=NUM_CLASSES-1)
        ax_true.set_title("Ground Truth")
        ax_true.axis("off")

        ax_pred.imshow(pred_masks[i], cmap="tab10", vmin=0, vmax=NUM_CLASSES-1)
        ax_pred.set_title("Prediction")
        ax_pred.axis("off")

    # 4) Legend
    import matplotlib.patches as mpatches
    patches = [
        mpatches.Patch(color=plt.cm.tab10(c/10), label=name)
        for c, name in CLASS_NAMES.items()
    ]
    fig.legend(handles=patches, loc="lower center", ncol=len(CLASS_NAMES))
    plt.tight_layout()
    plt.show()

# Usage:
display_samples_from_unbatch(model, val_ds, num=5)

In [None]:
# 3) Preprocess: resize image & remap VOC mask → {0,1,2,3}
def preprocess(ex):
    img  = tf.image.resize(ex["image"], IMG_SIZE) / 255.0
    mask = tf.image.resize(
        tf.expand_dims(ex["segmentation_mask"], -1),
        IMG_SIZE,
        method="nearest"
    )[..., 0]  # shape (H,W)

    sem_mask = tf.zeros(IMG_SIZE, tf.int32)
    for new_idx, cls_id in enumerate(TARGET_IDS, start=1):
        sem_mask = tf.where(tf.equal(mask, cls_id),
                            tf.cast(new_idx, tf.int32),
                            sem_mask)
    return img, sem_mask

train_ds = ds_train.map(preprocess, num_parallel_calls=AUTOTUNE)
val_ds   = ds_val  .map(preprocess, num_parallel_calls=AUTOTUNE)

# 4) Batch & prefetch
train_ds = train_ds.cache().shuffle(500).batch(BATCH_SIZE).prefetch(AUTOTUNE)
val_ds   = val_ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)

### Data Pipeline
Create TensorFlow data pipeline for image-mask pairs: