In [1]:
pip install tensorflow



In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
import numpy as np

# Fetch the Oxford-IIIT Pet dataset together with its metadata object.
dataset, info = tfds.load(
    name='oxford_iiit_pet',
    with_info=True,
)

# Preprocessing function
def preprocess_data(data):
    """Turn one raw dataset record into an (image, mask) training pair.

    Args:
        data: Mapping with an 'image' entry and a 'segmentation_mask' entry.

    Returns:
        A tuple ``(image, mask)``: the image resized to 128x128 as float32
        scaled by 1/255, and the mask resized to 128x128 as int32 with the
        trailing channel axis removed and every label shifted down by one.
    """
    target_size = (128, 128)

    resized_image = tf.image.resize(data['image'], target_size)
    # Nearest-neighbour keeps mask values as valid class ids (no blending).
    resized_mask = tf.image.resize(
        data['segmentation_mask'], target_size, method='nearest'
    )

    image = tf.cast(resized_image, tf.float32) / 255.0
    # Original labels are 1, 2, 3; shift them to 0, 1, 2.
    shifted_mask = tf.cast(resized_mask, tf.int32) - 1
    mask = tf.squeeze(shifted_mask, axis=-1)  # Shape: (128, 128)
    return image, mask

# Prepare subsets and batch
BATCH_SIZE = 16
# Using a subset of the data for faster training in a demo.
TRAIN_SAMPLES = 1000
VAL_SAMPLES = 200
TEST_SAMPLES = 200

# Training and validation must not share examples, otherwise the validation
# metrics are inflated by data leakage. Carve two disjoint, fixed slices out of
# the 'train' split: the first TRAIN_SAMPLES records for training and the next
# VAL_SAMPLES records for validation.
# NOTE(review): disjointness relies on the split being read in a stable order
# (tfds.load was called without shuffle_files) - confirm if that call changes.
# Slicing happens before map() so discarded records are never preprocessed.
train_dataset = (
    dataset['train']
    .take(TRAIN_SAMPLES)
    .map(preprocess_data)
    # Shuffle only the already-selected subset; the buffer covers all of it,
    # so every epoch sees the same examples in a fresh order.
    .shuffle(TRAIN_SAMPLES)
    .batch(BATCH_SIZE)
)
val_dataset = (
    dataset['train']
    .skip(TRAIN_SAMPLES)
    .take(VAL_SAMPLES)
    .map(preprocess_data)
    .batch(BATCH_SIZE)
)
test_dataset = (
    dataset['test']
    .map(preprocess_data)
    .take(TEST_SAMPLES)
    .batch(BATCH_SIZE)
)

# Encoder block
def encoder_block(inputs, filters):
    """Apply two 3x3 ReLU convolutions, then a 2x2 max-pool.

    Args:
        inputs: Input tensor for the block.
        filters: Number of filters used by both convolutions.

    Returns:
        A tuple ``(features, pooled)``: the convolution output (used as the
        skip connection) and its max-pooled version (fed to the next stage).
    """
    features = inputs
    for _ in range(2):
        features = layers.Conv2D(
            filters, 3, padding='same', activation='relu'
        )(features)
    pooled = layers.MaxPooling2D((2, 2))(features)
    return features, pooled

# Bridge
def bridge(inputs, filters):
    """Apply two stacked 3x3 ReLU convolutions without any pooling.

    Args:
        inputs: Input tensor for the block.
        filters: Number of filters used by both convolutions.

    Returns:
        The output tensor of the second convolution.
    """
    out = inputs
    for _ in range(2):
        out = layers.Conv2D(filters, 3, padding='same', activation='relu')(out)
    return out

# Decoder block
def decoder_block(inputs, skip, filters):
    """Upsample, merge with a skip connection, then refine with two convs.

    Args:
        inputs: Tensor coming from the previous (coarser) stage.
        skip: Skip-connection tensor concatenated with the upsampled input.
        filters: Number of filters for the transposed conv and both convs.

    Returns:
        The output tensor of the second 3x3 convolution.
    """
    # Stride-2 transposed convolution doubles the spatial resolution.
    upsampled = layers.Conv2DTranspose(
        filters, 2, strides=2, padding='same'
    )(inputs)
    merged = layers.Concatenate()([upsampled, skip])

    refined = merged
    for _ in range(2):
        refined = layers.Conv2D(
            filters, 3, padding='same', activation='relu'
        )(refined)
    return refined

# Assemble the U-Net: four encoder stages, a bridge, four decoder stages.
inputs = layers.Input(shape=(128, 128, 3))

# Contracting path: each stage returns (skip tensor, pooled tensor).
s1, p1 = encoder_block(inputs, filters=64)   # skip 128x128x64,  pooled 64x64x64
s2, p2 = encoder_block(p1, filters=128)      # skip 64x64x128,   pooled 32x32x128
s3, p3 = encoder_block(p2, filters=256)      # skip 32x32x256,   pooled 16x16x256
s4, p4 = encoder_block(p3, filters=512)      # skip 16x16x512,   pooled 8x8x512

# Bottleneck between the two paths.
b = bridge(p4, filters=1024)                 # 8x8x1024

# Expanding path: each stage consumes the matching encoder skip tensor.
d1 = decoder_block(b, skip=s4, filters=512)   # 16x16x512
d2 = decoder_block(d1, skip=s3, filters=256)  # 32x32x256
d3 = decoder_block(d2, skip=s2, filters=128)  # 64x64x128
d4 = decoder_block(d3, skip=s1, filters=64)   # 128x128x64

# Per-pixel classification head: 1x1 convolution with softmax over 3 classes.
outputs = layers.Conv2D(filters=3, kernel_size=1, activation='softmax')(d4)  # 128x128x3
model = models.Model(inputs=inputs, outputs=outputs)

# Configure training: Adam optimizer, sparse categorical cross-entropy on the
# integer masks, and accuracy as the tracked metric.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Fit for 10 epochs, evaluating on the validation set after each epoch.
history = model.fit(
    x=train_dataset,
    epochs=10,
    validation_data=val_dataset,
)

# Plot training history: one figure per tracked metric, each showing the
# training curve and the matching validation curve.
for metric_key, metric_label in (('loss', 'Loss'), ('accuracy', 'Accuracy')):
    plt.figure(figsize=(8, 5))
    plt.plot(history.history[metric_key], label=f'Training {metric_label}')
    plt.plot(
        history.history[f'val_{metric_key}'],
        label=f'Validation {metric_label}',
    )
    plt.title(f'{metric_label} Over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel(metric_label)
    plt.legend()
    plt.show()

# IoU function
def compute_iou(true_masks, pred_masks, num_classes):
    """Return the mean intersection-over-union across classes.

    Args:
        true_masks: Array-like of integer class labels (ground truth).
        pred_masks: Array-like of predicted class labels with the same shape
            as ``true_masks``.
        num_classes: Number of classes; labels ``0 .. num_classes - 1`` are
            scored.

    Returns:
        The mean of the per-class IoU values. A class that appears in neither
        mask (empty union) contributes an IoU of 1.0.

    Raises:
        ValueError: If ``num_classes`` is less than 1 or the two masks do not
            have the same shape.
    """
    if num_classes < 1:
        raise ValueError(f'num_classes must be >= 1, got {num_classes}')

    # Coerce to ndarrays: with plain lists `masks == c` is a single scalar
    # False, which would make every class look "absent" and score 1.0.
    true_masks = np.asarray(true_masks)
    pred_masks = np.asarray(pred_masks)
    if true_masks.shape != pred_masks.shape:
        # Without this check, broadcastable shapes would be compared silently.
        raise ValueError(
            f'shape mismatch: true_masks {true_masks.shape} '
            f'vs pred_masks {pred_masks.shape}'
        )

    ious = []
    for c in range(num_classes):
        true_c = true_masks == c
        pred_c = pred_masks == c
        intersection = np.logical_and(true_c, pred_c).sum()
        union = np.logical_or(true_c, pred_c).sum()
        ious.append(intersection / union if union > 0 else 1.0)
    return np.mean(ious)

# Test predictions and visualization
for images, masks in test_dataset.take(1):
    predictions = model.predict(images)
    # Collapse the per-class probabilities to the most likely class per pixel.
    pred_masks = tf.argmax(predictions, axis=-1).numpy()
    true_masks = masks.numpy()
    # Compute IoU for this batch
    iou = compute_iou(true_masks, pred_masks, 3)
    print(f'Mean IoU: {iou:.4f}')
    # Visualize up to 5 samples; cap at the batch size so a batch with fewer
    # than 5 elements does not index past the end.
    for i in range(min(5, len(true_masks))):
        panels = (
            (images[i], 'Image', None),
            (true_masks[i], 'True  Mask', 'jet'),
            (pred_masks[i], 'Predicted Mask', 'jet'),
        )
        plt.figure(figsize=(15, 5))
        for position, (panel, title, cmap) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.imshow(panel, cmap=cmap)
            plt.title(title)
            plt.axis('off')
        plt.show()