# U-Net Segmentation 🧠

In this notebook, we’ll implement **U-Net**, a convolutional neural network architecture for **semantic segmentation**, originally developed for **biomedical image segmentation**.

U-Net performs **pixel-wise classification** (it predicts a class label for every pixel) by combining local spatial detail with contextual semantic information through a **U-shaped encoder-decoder structure** with skip connections.

We’ll cover:
- U-Net architecture overview
- Building U-Net using TensorFlow/Keras
- Training on the Oxford-IIIT Pet dataset
- Visualizing segmentation results


## 1. Imports and Setup

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
import numpy as np

## 2. Load and Prepare Dataset

We’ll use the **Oxford-IIIT Pet Dataset** available through TensorFlow Datasets for demonstration.

In [None]:
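import tensorflow_datasets as tfds

# Download (on first run) and load the dataset together with its metadata
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

IMG_SIZE = 128  # must match the U-Net input size defined below

def normalize(input_image, input_mask):
    input_image = tf.cast(input_image, tf.float32) / 255.0  # scale pixels to [0, 1]
    input_mask -= 1  # classes: 1, 2, 3 → 0, 1, 2
    return input_image, input_mask

def load_image(datapoint):
    # Images come in different sizes, so resize them before batching
    input_image = tf.image.resize(datapoint['image'], (IMG_SIZE, IMG_SIZE))
    # Nearest-neighbour resizing keeps mask values as valid integer class IDs
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (IMG_SIZE, IMG_SIZE), method='nearest')
    return normalize(input_image, input_mask)

train = dataset['train'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train.cache().shuffle(1000).batch(16).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)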

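The `input_mask -= 1` step matters because `sparse_categorical_crossentropy` expects integer labels in the range [0, num_classes), i.e. valid indices into the model's 3-channel softmax output. A tiny NumPy sketch on a made-up 3×3 mask (hypothetical values, not taken from the dataset) shows the shift:

```python
import numpy as np

# Hypothetical 3x3 mask using the dataset's 1-based labels {1, 2, 3}
raw_mask = np.array([[1, 1, 2],
                     [1, 3, 2],
                     [2, 2, 2]], dtype=np.uint8)

NUM_CLASSES = 3
shifted = raw_mask - 1  # same operation as `input_mask -= 1` in normalize()

# Before the shift, label 3 points past the last of the 3 output channels;
# after it, every label is a valid channel index
print(raw_mask.max(), shifted.max())  # 3 2
```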
## 3. Visualize a Sample

In [None]:
for image, mask in train.take(1):
    plt.figure(figsize=(10,5))
    plt.subplot(1,2,1)
    plt.imshow(image)
    plt.title('Input Image')
    plt.axis('off')
    plt.subplot(1,2,2)
    plt.imshow(tf.squeeze(mask))
    plt.title('Segmentation Mask')
    plt.axis('off')
    plt.show()

## 4. Define U-Net Architecture

The U-Net consists of:
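- **Encoder (contracting path)**: a series of Conv + MaxPool blocks that extract features while halving the spatial resolution at each level.
- **Bottleneck**: the deepest, lowest-resolution stage, capturing high-level features.
- **Decoder (expanding path)**: transposed convolutions that upsample the feature maps, each followed by concatenation with the matching encoder output (a *skip connection*) to recover spatial detail.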

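One way to see how the three levels fit together is to trace feature-map shapes. The pure-Python helper below (`unet_shapes` is a hypothetical name, not part of Keras) mirrors the layer choices in `unet_model()`: 'same' padding, 2×2 max pooling, and stride-2 transposed convolutions. It also shows why the input height and width must be divisible by 2³ = 8 for this three-level model; otherwise the upsampled maps would not line up with the skip connections they are concatenated with.

```python
def unet_shapes(size=128, base_filters=64, depth=3):
    """Trace (height, width, channels) through the U-Net in this notebook."""
    if size % (2 ** depth) != 0:
        raise ValueError(f"input size {size} must be divisible by {2 ** depth}")
    encoder, s, f = [], size, base_filters
    for _ in range(depth):
        encoder.append((s, s, f))  # output of the two 3x3 convs (kept as skip connection)
        s, f = s // 2, f * 2       # max pooling halves resolution; next level doubles filters
    bottleneck = (s, s, f)
    decoder = []
    for skip in reversed(encoder):
        s, f = s * 2, f // 2       # transposed conv doubles resolution, halves filters
        assert (s, s) == skip[:2]  # spatial sizes must match for concatenation
        decoder.append((s, s, f))  # concat gives 2*f channels; two 3x3 convs bring it back to f
    return encoder, bottleneck, decoder

enc, bott, dec = unet_shapes()
print(enc)   # [(128, 128, 64), (64, 64, 128), (32, 32, 256)]
print(bott)  # (16, 16, 512)
print(dec)   # [(32, 32, 256), (64, 64, 128), (128, 128, 64)]
```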
In [None]:
def unet_model(input_size=(128,128,3)):
    inputs = layers.Input(input_size)

    # Encoder
    c1 = layers.Conv2D(64, 3, activation='relu', padding='same')(inputs)
    c1 = layers.Conv2D(64, 3, activation='relu', padding='same')(c1)
    p1 = layers.MaxPooling2D((2,2))(c1)

    c2 = layers.Conv2D(128, 3, activation='relu', padding='same')(p1)
    c2 = layers.Conv2D(128, 3, activation='relu', padding='same')(c2)
    p2 = layers.MaxPooling2D((2,2))(c2)

    c3 = layers.Conv2D(256, 3, activation='relu', padding='same')(p2)
    c3 = layers.Conv2D(256, 3, activation='relu', padding='same')(c3)
    p3 = layers.MaxPooling2D((2,2))(c3)

    # Bottleneck
    b = layers.Conv2D(512, 3, activation='relu', padding='same')(p3)
    b = layers.Conv2D(512, 3, activation='relu', padding='same')(b)

    # Decoder
    u1 = layers.Conv2DTranspose(256, 2, strides=(2,2), padding='same')(b)
    u1 = layers.concatenate([u1, c3])
    c4 = layers.Conv2D(256, 3, activation='relu', padding='same')(u1)
    c4 = layers.Conv2D(256, 3, activation='relu', padding='same')(c4)

    u2 = layers.Conv2DTranspose(128, 2, strides=(2,2), padding='same')(c4)
    u2 = layers.concatenate([u2, c2])
    c5 = layers.Conv2D(128, 3, activation='relu', padding='same')(u2)
    c5 = layers.Conv2D(128, 3, activation='relu', padding='same')(c5)

    u3 = layers.Conv2DTranspose(64, 2, strides=(2,2), padding='same')(c5)
    u3 = layers.concatenate([u3, c1])
    c6 = layers.Conv2D(64, 3, activation='relu', padding='same')(u3)
    c6 = layers.Conv2D(64, 3, activation='relu', padding='same')(c6)

    outputs = layers.Conv2D(3, (1,1), activation='softmax')(c6)

    model = models.Model(inputs=[inputs], outputs=[outputs])
    return model

model = unet_model()
model.summary()

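As a sanity check on `model.summary()`, the parameter count can be worked out by hand: a k×k convolution from `c_in` to `c_out` channels has k·k·c_in·c_out weights plus c_out biases, and pooling and concatenation add no parameters. The sketch below (function names are our own) applies that rule to the layers defined in `unet_model()`; assuming the layers match exactly, its result should equal the "Total params" line in the summary.

```python
def conv_params(k, c_in, c_out):
    # k*k*c_in weights per filter, c_out filters, plus one bias per filter
    return k * k * c_in * c_out + c_out

def unet_param_count(in_ch=3, num_classes=3, filters=(64, 128, 256), bottleneck=512):
    total, c = 0, in_ch
    for f in filters:  # encoder: two 3x3 convs per level
        total += conv_params(3, c, f) + conv_params(3, f, f)
        c = f
    # bottleneck: two 3x3 convs
    total += conv_params(3, c, bottleneck) + conv_params(3, bottleneck, bottleneck)
    c = bottleneck
    for f in reversed(filters):  # decoder
        total += conv_params(2, c, f)      # 2x2 transposed conv (same weight count rule)
        total += conv_params(3, 2 * f, f)  # 3x3 conv on concat(upsampled, skip)
        total += conv_params(3, f, f)
        c = f
    total += conv_params(1, c, num_classes)  # 1x1 softmax classifier
    return total

print(unet_param_count())  # 7697475
```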
## 5. Compile and Train the Model

In [None]:
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

EPOCHS = 3  # small for demo
steps_per_epoch = info.splits['train'].num_examples // 16

history = model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=steps_per_epoch)

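With integer masks and a softmax output, `sparse_categorical_crossentropy` is essentially ordinary cross-entropy computed at every pixel and then averaged, and `'accuracy'` is per-pixel accuracy. The NumPy sketch below illustrates that idea on a made-up 2×2 example; it is not Keras's implementation, which also handles batching and reduction options.

```python
import numpy as np

def pixel_sparse_ce(probs, labels, eps=1e-7):
    """Mean per-pixel cross-entropy.

    probs: (H, W, C) softmax outputs; labels: (H, W) integer class IDs.
    """
    # Probability the model assigned to the true class at each pixel
    p_true = np.take_along_axis(probs, labels[..., None], axis=-1)[..., 0]
    return float(np.mean(-np.log(np.clip(p_true, eps, 1.0))))

def pixel_accuracy(probs, labels):
    # Fraction of pixels whose most likely class equals the label
    return float(np.mean(np.argmax(probs, axis=-1) == labels))

# Hypothetical 2x2 image, 3 classes, completely uncertain (uniform) predictions
probs = np.full((2, 2, 3), 1 / 3)
labels = np.array([[0, 1], [2, 0]])
print(pixel_sparse_ce(probs, labels))  # about 1.0986, i.e. log(3)
print(pixel_accuracy(probs, labels))   # 0.5 (argmax breaks ties toward class 0)
```

The uniform case is a useful reference point: an untrained 3-class model should start with a loss near log(3).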
## 6. Visualize Predictions

In [None]:
def display_sample(display_list):
    plt.figure(figsize=(10,10))
    titles = ['Input Image', 'True Mask', 'Predicted Mask']
    for i in range(len(display_list)):
        plt.subplot(1,3,i+1)
        plt.title(titles[i])
        plt.imshow(tf.keras.utils.array_to_img(display_list[i]))
        plt.axis('off')
    plt.show()

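for image, mask in train.take(1):
    pred_mask = model.predict(tf.expand_dims(image, axis=0), verbose=0)  # shape (1, 128, 128, 3)
    # Pick the most likely class per pixel, then add a channel axis:
    # array_to_img expects a rank-3 (H, W, C) array
    pred_mask = tf.argmax(pred_mask, axis=-1)[..., tf.newaxis][0]
    display_sample([image, mask, pred_mask])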

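Turning the model's output into a displayable mask takes two steps: `argmax` over the class axis picks one class per pixel, and because `array_to_img` expects a rank-3 (H, W, C) array, a trailing channel axis has to be added back. A NumPy sketch with a hypothetical random output (variable names are illustrative only):

```python
import numpy as np

# Hypothetical model output: batch of 1, 4x4 pixels, 3 class scores per pixel
rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 4, 4, 3))
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax over classes

class_map = np.argmax(probs, axis=-1)     # (1, 4, 4): one class ID per pixel
mask_hwc = class_map[..., np.newaxis][0]  # (4, 4, 1): rank 3, as array_to_img expects
print(class_map.shape, mask_hwc.shape)    # (1, 4, 4) (4, 4, 1)
```

Since softmax is monotonic, taking the argmax of the raw scores would give the same class map.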
## ✅ Summary

In this notebook, we learned how to:
- Build the **U-Net** architecture from scratch.
- Train it for semantic segmentation tasks.
- Visualize segmentation masks.

**Next Steps:**
- Use a pretrained backbone (e.g., ResNet, EfficientNet) as the encoder for better results.
- Train with **Dice loss** and evaluate with **IoU / Dice metrics**, which reflect segmentation quality better than pixel accuracy.
- Try **data augmentation** to improve generalization.
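Dice and IoU both measure the overlap between predicted and true regions for each class, unlike pixel accuracy, which a model can inflate by predicting the dominant class everywhere. The NumPy sketch below assumes integer masks like those in this notebook: `dice_and_iou` scores argmax predictions, while `soft_dice_loss` works on softmax probabilities so it stays differentiable (a Keras version would apply the same formula with `tf` ops). The function names and the `eps` smoothing term are our own choices, not a library API; with this smoothing, a class absent from both masks scores 1.

```python
import numpy as np

def dice_and_iou(y_true, y_pred, num_classes, eps=1e-7):
    """Per-class Dice coefficient and IoU for integer label masks of equal shape."""
    dice, iou = [], []
    for c in range(num_classes):
        t, p = (y_true == c), (y_pred == c)
        inter = np.logical_and(t, p).sum()
        union = np.logical_or(t, p).sum()
        dice.append((2 * inter + eps) / (t.sum() + p.sum() + eps))
        iou.append((inter + eps) / (union + eps))
    return np.array(dice), np.array(iou)

def soft_dice_loss(probs, y_true, eps=1e-7):
    """1 - mean soft Dice; probs: (H, W, C) softmax, y_true: (H, W) integer labels."""
    one_hot = np.eye(probs.shape[-1])[y_true]  # (H, W, C)
    inter = (probs * one_hot).sum(axis=(0, 1))
    denom = probs.sum(axis=(0, 1)) + one_hot.sum(axis=(0, 1))
    return float(1 - np.mean((2 * inter + eps) / (denom + eps)))

# Hypothetical 3x3 ground truth and prediction
y_true = np.array([[0, 0, 1], [0, 1, 1], [2, 2, 2]])
y_pred = np.array([[0, 1, 1], [0, 1, 1], [2, 2, 0]])
dice, iou = dice_and_iou(y_true, y_pred, num_classes=3)
# Dice per class: 2/3, 6/7, 4/5; IoU per class: 1/2, 3/4, 2/3
print(dice, iou)
```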