# Chapter 8: Telling Things Apart - Image Segmentation

## 1️⃣ Chapter Overview

In the previous chapters, we focused on **Classification** (assigning a single label to an entire image). In this chapter, we tackle a much more granular task: **Semantic Segmentation**.

Semantic Segmentation involves classifying **every single pixel** in an image. Instead of saying "This image contains a cat", we say "These particular pixels belong to the cat, and the rest belong to the background."

We will implement the legendary **U-Net** architecture. It was originally designed for biomedical image segmentation and has since become one of the most widely used baselines for segmentation in general, because it trains well even on small datasets and produces precise object boundaries.

### Key Concepts:
* **Semantic Segmentation:** Pixel-level classification.
* **Upsampling:** How to increase the spatial resolution of feature maps (Transposed Convolutions).
* **U-Net Architecture:** A symmetric Encoder-Decoder network with skip connections.
* **IoU (Intersection over Union):** The standard metric for evaluating segmentation accuracy.

---

## 2️⃣ Theoretical Explanation

### 2.1 Classification vs. Segmentation
* **Classification:** Output is a vector of probabilities (e.g., `[0.1, 0.9, 0.0]`). Spatial information is lost.
* **Segmentation:** Output is a **mask** with the same height and width as the input image and one channel of class scores per pixel (e.g., $128 \times 128 \times K$ for $K$ classes).
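A minimal NumPy sketch makes the shape difference concrete. Random scores stand in for real network outputs (purely illustrative); `argmax` over the class axis then turns the per-pixel scores into a mask of class indices.

```python
import numpy as np

rng = np.random.default_rng(0)
K = 3  # number of classes (pet, background, border in this chapter)

# Classification: one probability vector for the whole image
class_scores = rng.random(K)
class_probs = class_scores / class_scores.sum()
print(class_probs.shape)  # (3,)

# Segmentation: one probability vector per pixel
seg_scores = rng.random((128, 128, K))
seg_probs = seg_scores / seg_scores.sum(axis=-1, keepdims=True)
mask = np.argmax(seg_probs, axis=-1)  # class index for every pixel

print(seg_probs.shape)  # (128, 128, 3)
print(mask.shape)       # (128, 128)
```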

### 2.2 The Challenge: Resolution
Standard CNNs (like VGG or ResNet) downsample with max pooling or strided convolutions to shrink the feature maps and enlarge the receptive field. After five 2× downsampling stages:
* Input: $256 \times 256$
* Bottleneck: $8 \times 8$

For segmentation, we need an output that is $256 \times 256$. How do we get back up from $8 \times 8$? 
We need **Upsampling**.
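Assuming five 2×2 pooling stages (as in a VGG-style backbone), each of which halves the height and width, the arithmetic behind 256 → 8 is simply $256 / 2^5$. The helper name `pooled_sizes` is illustrative.

```python
def pooled_sizes(size: int, num_pools: int) -> list[int]:
    """Spatial size after each 2x2, stride-2 pooling step (halves H and W)."""
    sizes = [size]
    for _ in range(num_pools):
        size //= 2
        sizes.append(size)
    return sizes

down = pooled_sizes(256, 5)
print(down)  # [256, 128, 64, 32, 16, 8]

# Undoing five halvings takes five 2x upsampling steps
print(down[-1] * 2 ** 5)  # 256
```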

### 2.3 Upsampling Techniques
1.  **Nearest Neighbor / Bilinear Interpolation:** Simple mathematical resizing. No learning involved.
2.  **Transposed Convolution (Conv2DTranspose):** Often called "Deconvolution", although it does not invert a convolution. It uses learnable filters to expand the feature map, so the network learns *how* to fill in the details.
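To see what "learning how to fill in the details" means, here is a single-channel NumPy sketch of a stride-2 transposed convolution (no padding, no bias; the helper names are illustrative, not Keras APIs). Each input pixel stamps a scaled copy of the kernel onto a larger output grid. With a 2×2 all-ones kernel this reduces exactly to nearest-neighbor upsampling; any other kernel, which a `Conv2DTranspose` layer would learn during training, produces a different fill pattern.

```python
import numpy as np

def nearest_upsample(x: np.ndarray, factor: int = 2) -> np.ndarray:
    """Nearest-neighbor upsampling: copy each pixel into a factor x factor block."""
    return np.repeat(np.repeat(x, factor, axis=0), factor, axis=1)

def transposed_conv2d(x: np.ndarray, kernel: np.ndarray, stride: int = 2) -> np.ndarray:
    """Single-channel transposed convolution (no padding, no bias).

    Each input pixel "stamps" the kernel, scaled by the pixel value, onto the
    output; overlapping stamps are summed.
    """
    h, w = x.shape
    k = kernel.shape[0]
    out = np.zeros(((h - 1) * stride + k, (w - 1) * stride + k))
    for i in range(h):
        for j in range(w):
            out[i * stride:i * stride + k, j * stride:j * stride + k] += x[i, j] * kernel
    return out

x = np.array([[1.0, 2.0],
              [3.0, 4.0]])

# A 2x2 all-ones kernel with stride 2 reproduces nearest-neighbor upsampling
ones_kernel = np.ones((2, 2))
print(np.array_equal(transposed_conv2d(x, ones_kernel), nearest_upsample(x)))  # True

# A different (in practice, learned) kernel fills each 2x2 block with a non-constant pattern
learned_kernel = np.array([[1.0, 0.5],
                           [0.5, 0.25]])
print(transposed_conv2d(x, learned_kernel))
```

Keras's `Conv2DTranspose` does the same across many input and output channels, and with `padding='same'` the output is exactly `strides` times the input size. With the 3×3 kernels used later in this chapter, neighboring stamps overlap and are summed.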

### 2.4 The U-Net Architecture
U-Net consists of two paths:
1.  **Contracting Path (Encoder):** A standard CNN that captures context but reduces resolution.
2.  **Expanding Path (Decoder):** Uses Transposed Convolutions to restore resolution.

**The Secret Sauce: Skip Connections.** 
Upsampling from a low-resolution bottleneck cannot bring back fine details (like the whiskers of a cat) that pooling has already discarded. U-Net therefore concatenates the high-resolution feature maps from the Encoder directly onto the corresponding layers in the Decoder. This gives the Decoder a "template" of fine details to paint over.
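In code, a skip connection is just a concatenation along the channel axis. This NumPy sketch uses shapes taken from the 128×128 U-Net built in this chapter (batch dimension omitted, zeros and ones as placeholder features) to show that the channel counts add up while height and width must already match.

```python
import numpy as np

decoder_upsampled = np.zeros((32, 32, 256))  # Conv2DTranspose output from the 16x16 bottleneck
encoder_skip = np.ones((32, 32, 256))        # high-resolution features saved on the way down

merged = np.concatenate([decoder_upsampled, encoder_skip], axis=-1)
print(merged.shape)  # (32, 32, 512)

# Concatenation is along channels only, so height and width must agree
try:
    np.concatenate([np.zeros((16, 16, 256)), encoder_skip], axis=-1)
except ValueError as err:
    print("Spatial mismatch:", err)
```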

## 3️⃣ Setup and Data Loading

We will use the **Oxford-IIIT Pet Dataset**, a standard benchmark for segmentation. It contains images of pets and their pixel-wise masks (1: Pet, 2: Background, 3: Border).

We will map these to 3 classes:
* Class 0: Pet
* Class 1: Background
* Class 2: Border (Outline)
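Two preprocessing details matter for masks: they must be resized with nearest-neighbor interpolation, and the labels must be shifted to start at 0. The NumPy sketch below uses a made-up 1-D strip of labels, with `np.interp` standing in for bilinear resizing along one axis and `np.repeat` standing in for nearest-neighbor upscaling (both are simplifications of what `tf.image.resize` does in 2-D).

```python
import numpy as np

# Made-up 1-D strip of raw Oxford-IIIT Pet labels: 1 = pet, 3 = border
raw_row = np.array([1, 1, 3, 3], dtype=np.uint8)

# Linear interpolation (what bilinear resizing does along each axis) blends labels
positions = np.linspace(0, len(raw_row) - 1, num=8)
blended = np.interp(positions, np.arange(len(raw_row)), raw_row)
print(blended)            # fractional values appear between 1 and 3
print(np.rint(blended))   # rounding them yields 2 (background) at the pet/border edge

# Nearest-neighbor resizing only ever copies existing labels
nearest = np.repeat(raw_row, 2)
print(nearest)            # [1 1 1 1 3 3 3 3]

# Shift 1/2/3 to 0/1/2 so each label indexes a softmax output channel directly
shifted = nearest - 1
print(shifted)            # [0 0 0 0 2 2 2 2]
```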

```python
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, Concatenate, Dropout
from tensorflow.keras.models import Model

# 1. Download Dataset
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

# 2. Image Processing Functions
def normalize(input_image, input_mask):
    # Normalize image to [0, 1]
    input_image = tf.cast(input_image, tf.float32) / 255.0
    # Masks are 1, 2, 3 in the dataset; shift to 0, 1, 2 for training
    input_mask -= 1
    return input_image, input_mask

def load_image(datapoint):
    input_image = tf.image.resize(datapoint['image'], (128, 128))
    # Masks hold class labels, so resize with nearest neighbor: the default bilinear
    # method would blend neighboring labels into invalid fractional values (e.g. 1.5)
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128),
                                 method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask

# 3. Build Pipeline
TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

train_images = dataset['train'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
test_images = dataset['test'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

train_batches = (train_images
                 .cache()
                 .shuffle(BUFFER_SIZE)
                 .batch(BATCH_SIZE)
                 .repeat()
                 .prefetch(buffer_size=tf.data.AUTOTUNE))
test_batches = test_images.batch(BATCH_SIZE)

# Report the actual per-example shapes instead of hard-coding them
image_spec, mask_spec = train_images.element_spec
print(f"Data loaded. Input image shape: {image_spec.shape}. Mask shape: {mask_spec.shape}")
```

### 3.1 Visualizing the Data
Let's look at an image and its corresponding ground truth mask.

```python
def display(display_list):
    plt.figure(figsize=(15, 5))
    title = ['Input Image', 'True Mask', 'Predicted Mask']

    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i + 1)
        plt.title(title[i])
        # array_to_img rescales values to 0-255: 3-channel images show in color,
        # 1-channel masks are drawn by matplotlib with its default colormap
        plt.imshow(tf.keras.utils.array_to_img(display_list[i]))
        plt.axis('off')
    plt.show()

# Keep one example (sample_image, sample_mask) for reuse later in the chapter
for image, mask in train_batches.take(1):
    sample_image, sample_mask = image[0], mask[0]
    display([sample_image, sample_mask])
```

## 4️⃣ Building the U-Net Architecture

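We will build a compact U-Net by hand so the data flow is easy to follow. It has three pooling levels and a 512-filter bottleneck (the original U-Net uses four levels and 1024 filters at the bottom).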
* **Encoder:** `Conv2D` -> `MaxPool`.
* **Decoder:** `Conv2DTranspose` -> `Concatenate` (Skip Connection) -> `Conv2D`.
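Before writing the Keras code, it can help to trace the tensor shapes by hand. The pure-Python sketch below assumes the layout used in this chapter (128×128 RGB input, 64 base filters doubling per level, three pooling levels, 3 output classes, `'same'` padding everywhere); `unet_shape_trace` is a hypothetical helper, not a Keras API.

```python
def unet_shape_trace(size=128, in_channels=3, base_filters=64, depth=3, num_classes=3):
    """List (stage, height, width, channels) for a 'same'-padded U-Net."""
    trace = [("input", size, size, in_channels)]
    filters, skips = base_filters, []
    for level in range(1, depth + 1):
        trace.append((f"encoder {level}", size, size, filters))
        skips.append((size, filters))   # saved for the skip connection
        size //= 2                      # MaxPooling2D halves H and W
        filters *= 2                    # next block doubles the filters
    trace.append(("bottleneck", size, size, filters))
    for level in range(depth, 0, -1):
        size *= 2                       # Conv2DTranspose(strides=2) doubles H and W
        filters //= 2
        skip_size, skip_filters = skips.pop()
        if skip_size != size:
            raise ValueError("skip connection needs matching height and width")
        trace.append((f"concat {level}", size, size, filters + skip_filters))
        trace.append((f"decoder {level}", size, size, filters))
    trace.append(("output", size, size, num_classes))  # 1x1 conv + softmax
    return trace

for stage in unet_shape_trace():
    print(stage)
```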

```python
def unet_model(output_channels: int):
    inputs = Input(shape=[128, 128, 3])

    # --- Encoder (Downsampling) ---
    # We save the output of each block to use in the skip connections later

    # Block 1: 128x128x64
    x = Conv2D(64, 3, activation='relu', padding='same')(inputs)
    x = Conv2D(64, 3, activation='relu', padding='same')(x)
    skip1 = x
    x = MaxPooling2D()(x)  # -> 64x64

    # Block 2: 64x64x128
    x = Conv2D(128, 3, activation='relu', padding='same')(x)
    x = Conv2D(128, 3, activation='relu', padding='same')(x)
    skip2 = x
    x = MaxPooling2D()(x)  # -> 32x32

    # Block 3: 32x32x256
    x = Conv2D(256, 3, activation='relu', padding='same')(x)
    x = Conv2D(256, 3, activation='relu', padding='same')(x)
    skip3 = x
    x = MaxPooling2D()(x)  # -> 16x16

    # --- Bottleneck: 16x16x512 ---
    x = Conv2D(512, 3, activation='relu', padding='same')(x)
    x = Conv2D(512, 3, activation='relu', padding='same')(x)
    x = Dropout(0.5)(x)

    # --- Decoder (Upsampling) ---

    # Block 3 (Up): 16x16 -> 32x32
    x = Conv2DTranspose(256, 3, strides=2, padding='same')(x)
    x = Concatenate()([x, skip3])  # SKIP CONNECTION: 256 + 256 = 512 channels
    x = Conv2D(256, 3, activation='relu', padding='same')(x)
    x = Conv2D(256, 3, activation='relu', padding='same')(x)

    # Block 2 (Up): 32x32 -> 64x64
    x = Conv2DTranspose(128, 3, strides=2, padding='same')(x)
    x = Concatenate()([x, skip2])  # SKIP CONNECTION: 128 + 128 = 256 channels
    x = Conv2D(128, 3, activation='relu', padding='same')(x)
    x = Conv2D(128, 3, activation='relu', padding='same')(x)

    # Block 1 (Up): 64x64 -> 128x128
    x = Conv2DTranspose(64, 3, strides=2, padding='same')(x)
    x = Concatenate()([x, skip1])  # SKIP CONNECTION: 64 + 64 = 128 channels
    x = Conv2D(64, 3, activation='relu', padding='same')(x)
    x = Conv2D(64, 3, activation='relu', padding='same')(x)

    # --- Output Layer ---
    # A 1x1 conv with 'output_channels' filters (3 classes here) gives one
    # softmax score per class at every pixel: 128x128x3
    outputs = Conv2D(output_channels, 1, activation='softmax')(x)

    return Model(inputs=inputs, outputs=outputs)

model = unet_model(output_channels=3)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Note: The summary is huge, so we just print the parameter count
print(f"Model created with {model.count_params():,} parameters.")
```
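As a cross-check on the printed parameter count, the number can be derived by hand: a convolution (regular or transposed) with a $k \times k$ kernel has $k \cdot k \cdot C_{in} \cdot C_{out}$ weights plus $C_{out}$ biases. The pure-Python sketch below applies that rule to the layer sizes in `unet_model`; the helper functions are illustrative, and we expect (but do not demonstrate here) that the total agrees with `model.count_params()`.

```python
def conv_params(kernel: int, c_in: int, c_out: int) -> int:
    """Weights (kernel x kernel x c_in x c_out) plus one bias per output filter.

    The same formula applies to Conv2D and Conv2DTranspose.
    """
    return kernel * kernel * c_in * c_out + c_out

def unet_param_count(in_channels=3, base=64, depth=3, num_classes=3) -> int:
    total, c_in, filters, skips = 0, in_channels, base, []
    for _ in range(depth):                      # encoder: two 3x3 convs per block
        total += conv_params(3, c_in, filters) + conv_params(3, filters, filters)
        skips.append(filters)
        c_in, filters = filters, filters * 2
    # Bottleneck: two 3x3 convs (Dropout has no parameters)
    total += conv_params(3, c_in, filters) + conv_params(3, filters, filters)
    c_in = filters
    for _ in range(depth):                      # decoder
        filters //= 2
        total += conv_params(3, c_in, filters)  # Conv2DTranspose with a 3x3 kernel
        concat_channels = filters + skips.pop() # skip connection widens the input
        total += conv_params(3, concat_channels, filters) + conv_params(3, filters, filters)
        c_in = filters
    total += conv_params(1, c_in, num_classes)  # 1x1 output conv
    return total

print(f"{unet_param_count():,}")  # 8,557,635
```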

## 5️⃣ Training and Prediction

We now train the model. Segmentation models are expensive to train, so we run only five epochs: enough to see the loss fall and the masks take shape, though not enough for full convergence.

```python
def create_mask(pred_mask):
    # pred_mask shape: (batch, 128, 128, 3)
    # Take argmax across the last (channel) axis to get the class index (0, 1, 2)
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    # Return the mask for the first image in the batch
    return pred_mask[0]

# Callback that shows the prediction for sample_image at the end of every epoch
class DisplayCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        pred = self.model.predict(sample_image[tf.newaxis, ...], verbose=0)
        display([sample_image, sample_mask, create_mask(pred)])
        print(f"Sample prediction after epoch {epoch + 1}")

EPOCHS = 5
VAL_SUBSPLITS = 5
# Validate on roughly one fifth of the test set each epoch to save time
VALIDATION_STEPS = info.splits['test'].num_examples // BATCH_SIZE // VAL_SUBSPLITS

model_history = model.fit(train_batches,
                          epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_batches,
                          callbacks=[DisplayCallback()])
```

## 6️⃣ Evaluation: Making Predictions

Let's visualize the results. We will take images from the test set, run them through the U-Net, and compare the predicted mask with the ground truth.

```python
def show_predictions(dataset=None, num=1):
    if dataset is not None:
        # Predict on whole batches, but display only the first image of each
        for image, mask in dataset.take(num):
            pred_mask = model.predict(image, verbose=0)
            display([image[0], mask[0], create_mask(pred_mask)])
    else:
        # Show prediction for the sample image loaded earlier
        pred_mask = model.predict(sample_image[tf.newaxis, ...], verbose=0)
        display([sample_image, sample_mask, create_mask(pred_mask)])

# Show predictions on Test Data
show_predictions(test_batches, num=3)
```

## 7️⃣ Metric: Intersection over Union (IoU)

Pixel accuracy can be misleading in segmentation: if 90% of an image is background, a model that predicts "all background" scores 90% accuracy while finding no pet at all.

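**IoU** measures the overlap between the predicted mask and the true mask. For a class $c$, let $P_c$ be the set of pixels predicted as $c$ and $T_c$ the set of pixels labeled $c$:
$$\text{IoU}_c = \frac{|P_c \cap T_c|}{|P_c \cup T_c|} = \frac{\text{Area of Overlap}}{\text{Area of Union}}$$
**Mean IoU** averages $\text{IoU}_c$ over the classes, so a class that covers few pixels counts as much as the background.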
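The IoU formula can be computed per class from a confusion matrix and then averaged. The NumPy sketch below does exactly that and reproduces the "all background" scenario from the paragraph above on a made-up 10×10 label map. `per_class_iou` is an illustrative helper; skipping classes that appear in neither map (via `np.nanmean`) is one common convention.

```python
import numpy as np

def per_class_iou(y_true: np.ndarray, y_pred: np.ndarray, num_classes: int) -> np.ndarray:
    """IoU for each class from two integer label maps of the same shape.

    Classes that appear in neither map get NaN (0 / 0) so they can be skipped.
    """
    # Confusion matrix: rows = true class, columns = predicted class
    cm = np.bincount(num_classes * y_true.ravel() + y_pred.ravel(),
                     minlength=num_classes ** 2).reshape(num_classes, num_classes)
    intersection = np.diag(cm)                              # pixels with true == pred == c
    union = cm.sum(axis=0) + cm.sum(axis=1) - intersection  # predicted c + true c - overlap
    with np.errstate(invalid="ignore", divide="ignore"):
        return intersection / union

# Made-up 10x10 label map: top row is pet (class 0), the rest is background (class 1)
y_true = np.ones((10, 10), dtype=np.int64)
y_true[0, :] = 0
y_pred = np.ones_like(y_true)  # a useless model that predicts "all background"

accuracy = (y_true == y_pred).mean()
ious = per_class_iou(y_true, y_pred, num_classes=3)
mean_iou = np.nanmean(ious)  # average over classes that actually occur

print(f"accuracy = {accuracy:.2f}")   # 0.90
print("per-class IoU:", ious)         # pet 0.0, background 0.9, border NaN (absent)
print(f"mean IoU = {mean_iou:.2f}")   # 0.45
```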

```python
class SparseMeanIoU(tf.keras.metrics.MeanIoU):
    """MeanIoU that accepts softmax probabilities as y_pred.

    The built-in MeanIoU expects class indices, so we argmax the predictions first.
    """
    def __init__(self, num_classes, name=None, dtype=None):
        super().__init__(num_classes=num_classes, name=name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight)

# Recompile to include IoU. The trained weights are kept; the optimizer state is
# reset, which does not matter here because we only evaluate.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy', SparseMeanIoU(num_classes=3)])

# Evaluate on the first 5 test batches (5 x 64 = 320 images) for speed
results = model.evaluate(test_batches, steps=5)
print(f"Test Accuracy: {results[1]:.4f}")
print(f"Test Mean IoU: {results[2]:.4f}")
```

## 8️⃣ Chapter Summary

In this chapter, we performed **Semantic Segmentation** on the Oxford-IIIT Pet Dataset.

* **Architecture:** We built **U-Net**, a widely used architecture for segmentation. We saw how it uses an Encoder to capture context and a Decoder with **Skip Connections** to recover precise spatial details.
* **Upsampling:** We used `Conv2DTranspose` to learn how to resize feature maps, rather than just stretching them with fixed interpolation.
* **Metrics:** We saw that pixel accuracy can be misleading when one class dominates the image, and implemented **Mean IoU** for a more robust evaluation of segmentation quality.

This concludes Part 2 of the book (Computer Vision). In the next chapter, we enter the world of **Natural Language Processing (NLP)**, starting with Sentiment Analysis.