# Oxford Pets Image Segmentation

### Table of Contents
1. [Setup and Imports](#setup)
2. [Dataset Preparation](#dataset-prep)
3. [Model Definition](#model-def)
4. [Training](#training)
5. [Visualization](#visualization)


## 1. Setup and Imports <a id="setup"></a>
This section handles the imports and configuration of paths and parameters.


In [30]:
import os
import requests
import tarfile

# Function to download and extract dataset
def download_and_extract(url, download_path, extract_path):
    # Download
    response = requests.get(url, stream=True)
    with open(download_path, "wb") as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)
    # Extract
    with tarfile.open(download_path, "r:gz") as tar:
        tar.extractall(path=extract_path)

# Paths
os.makedirs("dataset", exist_ok=True)
download_and_extract(
    "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
    "dataset/images.tar.gz",
    "dataset/" 
)
download_and_extract(
    "https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz",
    "dataset/annotations.tar.gz",
    "dataset/"
)


Downloading https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz...


ConnectionError: HTTPSConnectionPool(host='www.robots.ox.ac.uk', port=443): Max retries exceeded with url: /~vgg/data/pets/data/images.tar.gz (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x15d745d80>: Failed to resolve 'www.robots.ox.ac.uk' ([Errno 8] nodename nor servname provided, or not known)"))

## 2. Dataset Preparation <a id="dataset-prep"></a>
In this section, we load the datasets using the helper function from `data.py`.


In [25]:
import os

# Define paths
input_dir = "dataset/images/"               # Path to input images
target_dir = "dataset/annotations/trimaps/" # Path to segmentation masks

# Define parameters
img_size = (160, 160)
batch_size = 16  # Adjust this value
num_classes = 3


Deleting 0 hidden files...
Hidden files deleted.


NameError: name 'os' is not defined

## 3. What does one image look like <a id="image"></a>

In [None]:
from IPython.display import Image, display
from keras.utils import load_img
from PIL import ImageOps

# Display input image #7
display(Image(filename=input_img_paths[9]))

# Display auto-contrast version of corresponding target (per-pixel categories)
img = ImageOps.autocontrast(load_img(target_img_paths[9]))
display(img)

## 3. Model Definition <a id="model-def"></a>
This section defines a U-Net-like model for image segmentation.


In [27]:
import keras
import numpy as np
from tensorflow import data as tf_data
from tensorflow import image as tf_image
from tensorflow import io as tf_io

# Define paths
input_dir = "dataset/images/"
target_dir = "dataset/annotations/trimaps/"
img_size = (160, 160)
num_classes = 3
batch_size = 16


# Prepare datasets using image paths - Pass the image paths directly
train_dataset, valid_dataset = prepare_datasets(
    input_img_paths, target_img_paths, img_size, batch_size
)

# Define U-Net-like model
def get_unet_model(img_size, num_classes):
    inputs = tf.keras.Input(shape=img_size + (3,))
    
    # Downsampling
    x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    previous_block_activation = x
    
    for filters in [64, 128, 256]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(previous_block_activation)
        x = layers.add([x, residual])
        previous_block_activation = x

    # Upsampling
    for filters in [256, 128, 64, 32]:
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.UpSampling2D(2)(x)
        residual = layers.UpSampling2D(2)(previous_block_activation)
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])
        previous_block_activation = x

    outputs = layers.Conv2D(num_classes, 3, activation="softmax", padding="same")(x)
    return Model(inputs, outputs)

# Load the model
model = get_unet_model(img_size, num_classes)

# Compile and train the model
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss="sparse_categorical_crossentropy")

model.fit(train_dataset, epochs=2, validation_data=valid_dataset)


FileNotFoundError: [Errno 2] No such file or directory: 'dataset/images/'

## 5. Visualization <a id="visualization"></a>
Visualize some predictions from the validation dataset.


In [None]:
# Function to display prediction masks
def display_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = tf.expand_dims(pred_mask, axis=-1)
    return pred_mask[0]

# Display predictions
for images, masks in valid_dataset.take(1):
    predictions = model.predict(images)
    plt.figure(figsize=(10, 10))
    
    # Plot input image
    plt.subplot(1, 3, 1)
    plt.title("Input Image")
    plt.imshow(images[0])
    
    # Plot true mask
    plt.subplot(1, 3, 2)
    plt.title("True Mask")
    plt.imshow(tf.squeeze(masks[0]))
    
    # Plot predicted mask
    plt.subplot(1, 3, 3)
    plt.title("Predicted Mask")
    plt.imshow(tf.squeeze(display_mask(predictions)))
    plt.show()
