# U-Net for Image Segmentation on the Oxford-IIIT Pet Dataset
A simplified U-Net segmentation model built with TensorFlow and Keras, adapted from Andrew Ng's lab. The network predicts a binary mask for each image through a single sigmoid output channel.

In [None]:
!pip install tensorflow tensorflow-datasets matplotlib -q
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np


In [None]:
# Fetch the Oxford-IIIT Pet dataset (any 3.x release) plus its metadata, then
# split it into the train and test partitions used below.
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)
train_dataset, test_dataset = dataset['train'], dataset['test']

def normalize_img(datapoint):
    """Convert one dataset example into an ``(image, mask)`` training pair.

    Args:
        datapoint: A single (unbatched) example dict with ``'image'`` and
            ``'segmentation_mask'`` entries.

    Returns:
        ``image`` cast to float32 and scaled from [0, 255] to [0, 1], and
        ``mask`` as a uint8 tensor whose values are only 0 or 1.
    """
    image = tf.cast(datapoint['image'], tf.float32) / 255.0
    mask = tf.cast(datapoint['segmentation_mask'], tf.uint8)
    # The trimap labels are {1, 2, 3} (1 is the pet); shifting by one gives
    # {0, 1, 2}. The model has a single sigmoid output trained with binary
    # cross-entropy, so targets must be strictly {0, 1}: clamp folds the third
    # trimap value into class 1 (and would also catch a uint8 wrap-around to
    # 255). The clamp bound is cast to the mask dtype so graph-mode tracing in
    # `Dataset.map` does not mix int32 and uint8.
    mask = tf.minimum(mask - 1, tf.cast(1, mask.dtype))
    return image, mask

IMG_SIZE = 128  # must match the spatial size of unet_model's default input
BATCH_SIZE = 32


def _resize_example(image, mask):
    """Resize one (image, mask) pair to IMG_SIZE x IMG_SIZE.

    The pet images come in different sizes and `Dataset.batch` needs equal
    shapes, so this runs *before* batching. TFDS stores the mask with a
    trailing channel axis, (H, W, 1), so no extra axis is added. Nearest-
    neighbour sampling keeps mask labels exact; bilinear would blend them.
    """
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    mask = tf.image.resize(mask, (IMG_SIZE, IMG_SIZE), method='nearest')
    return image, mask


# Cache after resizing so the cache holds small 128x128 examples rather than
# full-resolution float images.
train_dataset = (
    train_dataset.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .map(_resize_example, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(1000)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
test_dataset = (
    test_dataset.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .map(_resize_example, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)


In [None]:
from tensorflow.keras import layers, models

def _double_conv(tensor, filters):
    """Apply two successive 3x3 ReLU convolutions with 'same' padding."""
    for _ in range(2):
        tensor = layers.Conv2D(filters, 3, activation='relu', padding='same')(tensor)
    return tensor


def unet_model(input_size=(128, 128, 3)):
    """Build a three-level U-Net with a single-channel sigmoid output.

    Args:
        input_size: Shape of one input image, ``(height, width, channels)``.
            Height and width should be divisible by 8 (three 2x poolings) so
            the upsampled decoder features line up with the encoder skips.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping images to per-pixel
        probabilities with a single output channel.
    """
    inputs = tf.keras.Input(input_size)

    # Encoder: remember each level's features for the matching decoder skip.
    skip_features = []
    features = inputs
    for filters in (64, 128, 256):
        features = _double_conv(features, filters)
        skip_features.append(features)
        features = layers.MaxPooling2D()(features)

    # Bottleneck.
    features = _double_conv(features, 512)

    # Decoder: upsample, concatenate the skip from the same depth, refine.
    for filters, skip in zip((256, 128, 64), reversed(skip_features)):
        features = layers.UpSampling2D()(features)
        features = layers.Concatenate()([features, skip])
        features = _double_conv(features, filters)

    outputs = layers.Conv2D(1, 1, activation='sigmoid')(features)
    return models.Model(inputs, outputs)

model = unet_model()
# A single sigmoid output channel pairs with per-pixel binary cross-entropy.
model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
model.summary()


In [None]:
IMG_SIZE = 128  # must match the spatial size of unet_model's default input


def resize(input_image, input_mask):
    """Resize an image/mask pair to IMG_SIZE x IMG_SIZE.

    Accepts a single example (rank 3) or a batch (rank 4). The mask must
    already carry its trailing channel axis, as TFDS's ``segmentation_mask``
    does, so no extra axis is added here. The mask is resized with
    nearest-neighbour sampling so its integer labels are not blended into
    fractional values; it keeps its input dtype.
    """
    input_image = tf.image.resize(input_image, (IMG_SIZE, IMG_SIZE))
    input_mask = tf.image.resize(input_mask, (IMG_SIZE, IMG_SIZE), method='nearest')
    return input_image, input_mask


# Build the pipelines from the raw splits: the images come in different sizes,
# so they must be resized *before* batching (Dataset.batch needs equal shapes).
train_dataset = (
    dataset['train']
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .map(resize, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(1000)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)
test_dataset = (
    dataset['test']
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .map(resize, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)


In [None]:
EPOCHS = 5
# Train on the batched training split, reporting metrics on the test split
# after every epoch.
model.fit(
    train_dataset,
    epochs=EPOCHS,
    validation_data=test_dataset,
)


In [None]:
def display(display_list):
    """Plot the given images side by side in one row, then show the figure.

    Panels are titled, in order, 'Input Image', 'True Mask' and
    'Predicted Mask'; passing more than three images raises IndexError.
    """
    plt.figure(figsize=(15, 5))
    panel_titles = ['Input Image', 'True Mask', 'Predicted Mask']
    panel_count = len(display_list)
    for position, panel in enumerate(display_list, start=1):
        plt.subplot(1, panel_count, position)
        plt.title(panel_titles[position - 1])
        plt.imshow(tf.keras.utils.array_to_img(panel))
        plt.axis('off')
    plt.show()

# Run the model on the first test batch and compare three examples visually.
for image, mask in test_dataset.take(1):
    # Rounding the sigmoid probabilities thresholds them at 0.5 -> hard 0/1 mask.
    pred_mask = tf.round(model.predict(image))
    for i in range(3):
        display([image[i], mask[i], pred_mask[i]])
