In [20]:
import tensorflow as tf
import tensorflow_datasets as tfds

# Data acquisition & preparation

In [21]:
# Oxford-IIIT Pet dataset: https://www.tensorflow.org/datasets/catalog/oxford_iiit_pet
dataset_name = "oxford_iiit_pet"  # TFDS name of the dataset
# Train / validation / test slices (60% / 20% / 20%). Each entry pools the
# same percentage range of the TFDS "train" and "test" splits.
split = [
    "+".join(f"{part}[{pct}]" for part in ("train", "test"))
    for pct in (":60%", "60%:80%", "80%:")
]

In [22]:
# Oxford Flowers 102 dataset: https://www.tensorflow.org/datasets/catalog/oxford_flowers102
# NOTE: this cell assigns the same names as the Oxford-IIIT Pet cell; if both
# are run, the one executed last determines which dataset gets loaded.
dataset_name = "oxford_flowers102"  # TFDS name of the dataset
# Train / validation / test slices (60% / 20% / 20%). Each entry pools the
# same percentage range of the TFDS "train", "validation" and "test" splits.
split = [
    "+".join(f"{part}[{pct}]" for part in ("train", "validation", "test"))
    for pct in (":60%", "60%:80%", "80%:")
]

In [23]:
# Hyperparameters
#   img_size   - side length (pixels) of the square images produced by preprocessing
#   batch_size - number of images per batch
img_size, batch_size = 64, 64

In [24]:
def preprocess_image(data):
    """Center-crop an example's image to a square, resize it, and scale it to [0, 1].

    Args:
        data: a dataset element with an "image" entry. Assumed to be a
            (height, width, channels) tensor with 0-255 pixel values, as
            decoded by TFDS -- TODO confirm for every dataset used here.

    Returns:
        A float32 tensor resized to ``img_size`` x ``img_size`` with values
        clipped to [0, 1].
    """
    image = data["image"]
    dims = tf.shape(image)
    h, w = dims[0], dims[1]

    # Largest centered square that fits inside the image.
    side = tf.minimum(h, w)
    top = (h - side) // 2
    left = (w - side) // 2
    square = tf.image.crop_to_bounding_box(image, top, left, side, side)

    resized = tf.image.resize(
        tf.cast(square, dtype=tf.float32),
        size=(img_size, img_size),
        antialias=True,
    )
    # Scale to [0, 1]; the clip guards against any values drifting outside
    # that range after resizing.
    return tf.clip_by_value(resized / 255.0, 0.0, 1.0)

In [25]:
def preprocess_dataset(dataset, shuffle_buffer_size=None): # preprocess a dataset
    """Turn a raw TFDS split into shuffled, batched, prefetched image batches.

    Args:
        dataset: a ``tf.data.Dataset`` whose elements are dicts with an "image"
            entry (as returned by ``tfds.load``).
        shuffle_buffer_size: number of individual preprocessed images held in
            the shuffle buffer. Defaults to ``10 * batch_size``.

    Returns:
        A ``tf.data.Dataset`` of batches of exactly ``batch_size`` images
        produced by ``preprocess_image``; a trailing partial batch is dropped.
    """
    if shuffle_buffer_size is None:
        shuffle_buffer_size = 10 * batch_size
    return (
        dataset.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)  # preprocess images
        # Shuffle individual images *before* batching. Shuffling after
        # `.batch()` only permutes whole batches, so which images share a
        # batch would never be randomized by this step.
        .shuffle(shuffle_buffer_size)
        .batch(batch_size, drop_remainder=True)  # fixed-size batches
        .prefetch(buffer_size=tf.data.AUTOTUNE)  # overlap input pipeline with training
    )

def load_data(dataset_name, splits=None): # load dataset from tensorflow datasets with the given name
    """Load a TFDS dataset and return preprocessed train/validation/test pipelines.

    Args:
        dataset_name: TFDS name of the dataset, e.g. "oxford_iiit_pet".
        splits: sequence of exactly three TFDS split strings, in the order
            (train, validation, test). Defaults to the notebook-level ``split``
            configuration, which must match ``dataset_name`` because the
            available split names differ between datasets.

    Returns:
        Tuple ``(train_ds, val_ds, test_ds)`` of datasets produced by
        ``preprocess_dataset``.
    """
    if splits is None:
        splits = split  # fall back to the notebook-level configuration
    train_ds, val_ds, test_ds = tfds.load(dataset_name, split=splits, shuffle_files=True)
    return (
        preprocess_dataset(train_ds),
        preprocess_dataset(val_ds),
        preprocess_dataset(test_ds),
    )

In [26]:
# Build the train / validation / test pipelines for the configured dataset.
train_ds, val_ds, test_ds = load_data(dataset_name=dataset_name)