In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Conv2DTranspose
from tensorflow.keras.callbacks import EarlyStopping

import matplotlib.pyplot as plt



# One call seeds Python's `random`, NumPy and TensorFlow, so dataset shuffling
# and weight initialisation repeat across "Restart & Run All".
# (GPU kernels may still be nondeterministic.)
SEED = 13
tf.keras.utils.set_random_seed(SEED)

In [None]:
# Load Oxford-IIIT Pet plus its DatasetInfo metadata. as_supervised=False keeps
# every example as a feature dict, which normalize_input needs in order to read
# both 'image' and 'segmentation_mask'.
data, info = tfds.load(
    'oxford_iiit_pet',
    as_supervised=False,
    with_info=True,
)

In [None]:
# With no `split=` argument, tfds.load returns a dict keyed by split name.
train_dataset, test_dataset = data['train'], data['test']

In [None]:
# Dataset metadata. A bare trailing expression gives the notebook's rich display
# instead of a plain print().
info

In [None]:
IMG_HEIGHT = 128
IMG_WIDTH = 128
BATCH_SIZE = 16

def normalize_input(example):
    """Convert one raw tfds example into an (image, mask) pair for training.

    Args:
        example: feature dict produced by tfds with 'image' and
            'segmentation_mask' entries.

    Returns:
        image: float32 tensor resized to (IMG_HEIGHT, IMG_WIDTH) and divided
            by 255 (presumably mapping uint8 pixels into [0, 1] -- confirm).
        mask: uint8 tensor resized with nearest-neighbour sampling, so label
            values are never interpolated, then shifted down by one
            (presumably turning 1-based trimap labels into the 0-based ids
            expected by the 3-class model -- confirm).
    """
    target_size = (IMG_HEIGHT, IMG_WIDTH)

    resized_image = tf.image.resize(example['image'], target_size)
    scaled_image = tf.cast(resized_image, tf.float32) / 255.0

    resized_mask = tf.image.resize(
        example['segmentation_mask'], target_size, method='nearest'
    )
    zero_based_mask = tf.cast(resized_mask, tf.uint8) - 1

    return scaled_image, zero_based_mask

In [None]:
# Training pipeline: preprocess in parallel, cache the preprocessed examples,
# shuffle through a 100-element buffer, batch, and prefetch so input
# preparation overlaps with training.
train_ds = (
    train_dataset
    .map(normalize_input, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(100)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Evaluation pipeline: identical preprocessing, no caching and no shuffling.
test_ds = (
    test_dataset
    .map(normalize_input, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Sanity check: the first image of one training batch next to its mask
# (channel 0 of the mask tensor).
for image, mask in train_ds.take(1):
    fig, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(10, 5))

    image_ax.imshow(image[0])
    image_ax.set_title("Image")

    mask_ax.imshow(mask[0, :, :, 0])
    mask_ax.set_title("Mask")

    plt.show()


In [None]:
def unet_model(input_size=(128, 128, 3), num_classes=3, base_filters=64, depth=4):
    """Build a U-Net that outputs per-pixel class probabilities.

    Encoder: `depth` stages. Each stage applies two 3x3 ReLU convolutions
    ('same' padding) and then a 2x2 max-pool. The filter count starts at
    `base_filters` and doubles each stage.

    Bottleneck: two convolutions with `base_filters * 2**depth` filters.

    Decoder: mirrors the encoder. Each stage does the following:
      - upsample with a stride-2 2x2 transposed convolution;
      - concatenate the matching encoder output (skip connection);
      - apply two more 3x3 ReLU convolutions.

    A final 1x1 convolution with softmax produces `num_classes` channels. The
    defaults (base_filters=64, depth=4) reproduce the classic 64-128-256-512-1024
    U-Net, and layers are created in the same order as the original hand-written
    version.

    Args:
        input_size: (height, width, channels) of the input images. Height and
            width must be divisible by 2**depth; None (variable size) is allowed.
        num_classes: number of output channels / classes.
        base_filters: number of filters in the first encoder stage.
        depth: number of down-/up-sampling stages.

    Returns:
        An uncompiled tf.keras Model.

    Raises:
        ValueError: if `depth` is negative, or if a spatial dimension is not
            divisible by 2**depth. Max-pooling floors odd sizes, so the skip
            concatenations would otherwise fail with an opaque shape mismatch.
    """
    if depth < 0:
        raise ValueError(f"depth must be >= 0, got {depth}")
    downsample_factor = 2 ** depth
    for dim in input_size[:2]:
        if dim is not None and dim % downsample_factor != 0:
            raise ValueError(
                f"input height/width must be divisible by {downsample_factor} "
                f"for a depth-{depth} U-Net, got input_size={input_size}"
            )

    def conv_block(x, filters):
        # Two same-padded 3x3 ReLU convolutions; spatial size is unchanged.
        x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        return layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)

    inputs = tf.keras.Input(shape=input_size)

    # Encoder (downsampling): keep each stage's pre-pool output for the skips.
    skips = []
    x = inputs
    for level in range(depth):
        x = conv_block(x, base_filters * 2 ** level)
        skips.append(x)
        x = layers.MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = conv_block(x, base_filters * 2 ** depth)

    # Decoder (upsampling): consume the skips from deepest to shallowest.
    for level in reversed(range(depth)):
        filters = base_filters * 2 ** level
        x = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = layers.concatenate([x, skips[level]])
        x = conv_block(x, filters)

    # Output layer: softmax over classes at every pixel.
    outputs = layers.Conv2D(num_classes, (1, 1), activation='softmax')(x)

    return Model(inputs=inputs, outputs=outputs)

In [None]:
# Stop once validation accuracy has not improved for 3 epochs, and roll back to
# the best epoch's weights.
# NOTE(review): test_ds doubles as the validation set, so early stopping picks
# weights using the same data later used for evaluation. Consider holding out
# part of the training split (e.g. tfds split='train[:90%]') for validation.
callback = EarlyStopping(monitor='val_accuracy', patience=3, restore_best_weights=True)

# Derive the input shape from the preprocessing constants so the model cannot
# drift out of sync with normalize_input's resize target.
model = unet_model(input_size=(IMG_HEIGHT, IMG_WIDTH, 3), num_classes=3)
model.compile(
    optimizer='adam',
    # The model ends in softmax, so the loss receives probabilities (the default
    # from_logits=False). Masks are integer class ids, hence the "sparse" loss.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)

# Keep the per-epoch loss/metric History rather than leaving its repr as output.
history = model.fit(train_ds, validation_data=test_ds, epochs=50, callbacks=[callback])

In [None]:
# First test example. take(1) yields exactly one batch, so no loop/break needed.
image_batch, mask_batch = next(iter(test_ds.take(1)))
image = image_batch[0]
mask = mask_batch[0]

def display_image_and_mask(image, mask, pred_mask=None, num_classes=3):
    """Plot an image beside its ground-truth mask and, optionally, a predicted mask.

    Args:
        image: image tensor/array accepted by plt.imshow.
        mask: integer class mask of shape (H, W) or (H, W, 1); a trailing
            singleton channel is dropped before plotting.
        pred_mask: optional mask in the same format as `mask`; when given, a
            third panel shows it.
        num_classes: pins the colour scale to [0, num_classes - 1], so a class
            gets the same colour in every mask panel regardless of which
            classes happen to be present.
    """
    mask_panels = [("Mask", mask)]
    if pred_mask is not None:
        mask_panels.append(("Predicted mask", pred_mask))

    n_panels = 1 + len(mask_panels)
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5))

    axes[0].set_title("Image")
    axes[0].imshow(image)
    axes[0].axis('off')

    for ax, (title, label_map) in zip(axes[1:], mask_panels):
        # Drop a trailing size-1 channel axis so imshow receives a 2-D label map
        # (matches the explicit [..., 0] indexing used in the training preview).
        if len(label_map.shape) == 3:
            label_map = tf.squeeze(label_map, axis=-1)
        ax.set_title(title)
        ax.imshow(label_map, cmap='jet', vmin=0, vmax=num_classes - 1)
        ax.axis('off')

    plt.show()

# Compare against the trained model: argmax over the softmax channel gives a
# per-pixel class id for the single-image batch.
pred_probs = model.predict(image[tf.newaxis, ...], verbose=0)
pred_mask = tf.argmax(pred_probs, axis=-1)[0]

display_image_and_mask(image, mask, pred_mask)