In [1]:
import tensorflow as tf
import os

# Download the image dataset and locate its training directory (one subfolder per class)

In [2]:
# Download (and unpack) the cats-vs-dogs archive into the local Keras cache.
dataset_url = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
path_to_zip = tf.keras.utils.get_file('cats_and_dogs_filtered.zip', origin=dataset_url, extract=True)

# Where the archive ends up depends on the Keras release, so probe the known
# layouts instead of assuming one:
#   1. get_file() returned the extraction directory itself,
#   2. get_file() returned the archive path and the contents were unpacked into
#      a sibling "cats_and_dogs_filtered_extracted" folder,
#   3. the contents were unpacked directly next to the archive (older layout).
# NOTE(review): layouts 1 and 2 are assumed to be the Keras 3 behaviour — confirm
# against the installed Keras version.
_cache_dir = os.path.dirname(path_to_zip)
_train_dir_candidates = [
    os.path.join(path_to_zip, 'cats_and_dogs_filtered', 'train'),
    os.path.join(_cache_dir, 'cats_and_dogs_filtered_extracted', 'cats_and_dogs_filtered', 'train'),
    os.path.join(_cache_dir, 'cats_and_dogs_filtered', 'train'),
]
# Fall back to the last (original) candidate so a missing directory surfaces
# as a clear error when the dataset is loaded.
data_dir = next(
    (candidate for candidate in _train_dir_candidates if os.path.isdir(candidate)),
    _train_dir_candidates[-1],
)  # Use train set

Downloading data from https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
[1m68606236/68606236[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 0us/step


# Function to create split datasets

In [3]:
def create_split_datasets(data_dir, img_size=(160, 160), batch_size=32, val_split=0.2, test_split=0.1, seed=123):
    """Build train / validation / test image datasets from a class-per-subfolder directory.

    The directory is read twice with ``image_dataset_from_directory``: once for
    the "training" subset and once for the "validation" subset, the latter
    holding ``val_split + test_split`` of the files. That hold-out subset is
    then divided, by whole batches, into a validation part and a test part.

    Args:
        data_dir: Directory containing one subfolder per class.
        img_size: ``(height, width)`` every image is resized to.
        batch_size: Number of images per batch.
        val_split: Fraction of all files intended for validation.
        test_split: Fraction of all files intended for testing.
        seed: Integer seed shared by both directory reads so the training and
            hold-out subsets are drawn from the same shuffled file list.

    Returns:
        A ``(train_dataset, val_dataset, test_dataset)`` tuple of batched
        datasets. Because the hold-out subset is divided by batches, the
        validation/test proportions are approximate; the validation part is
        empty when ``int(val_fraction * n_batches)`` is 0.

    Raises:
        ValueError: If either fraction is negative, or their sum is not
            strictly between 0 and 1.

    Note:
        The hold-out subset is cached in memory and read once in full inside
        this function, so memory use grows with the number of hold-out images.
    """
    holdout_split = val_split + test_split
    if val_split < 0 or test_split < 0 or not 0 < holdout_split < 1:
        raise ValueError(
            "val_split and test_split must be non-negative and sum to a value "
            f"strictly between 0 and 1; got val_split={val_split}, test_split={test_split}"
        )

    # Both reads must share validation_split and seed; that is what keeps the
    # "training" and "validation" subsets disjoint.
    loader_kwargs = dict(
        image_size=img_size,
        batch_size=batch_size,
        validation_split=holdout_split,   # Reserve both val + test initially
        seed=seed,
    )
    train_dataset = tf.keras.utils.image_dataset_from_directory(
        data_dir, subset="training", **loader_kwargs
    )
    holdout_dataset = tf.keras.utils.image_dataset_from_directory(
        data_dir, subset="validation", **loader_kwargs
    )
    holdout_batches = len(holdout_dataset)

    # image_dataset_from_directory shuffles by default and the resulting
    # tf.data pipeline reshuffles on every pass, so take()/skip() on it would
    # select different images each epoch and let validation and test overlap.
    # Caching freezes one order; the cache is only finalized after a complete
    # pass, hence the warm-up loop.
    holdout_dataset = holdout_dataset.cache()
    for _ in holdout_dataset:
        pass

    val_batches = int(val_split / holdout_split * holdout_batches)  # Split into val and test

    val_dataset = holdout_dataset.take(val_batches)    # First part = validation set
    test_dataset = holdout_dataset.skip(val_batches)   # Remaining part = test set

    return train_dataset, val_dataset, test_dataset

# Generate datasets

In [6]:
# Use the directory resolved in the download cell when it exists. Otherwise fall
# back to the "<name>_extracted" layout inside the default Keras cache, expressed
# relative to the user's home directory rather than a machine-specific path.
_extracted_train_dir = os.path.join(
    os.path.expanduser("~"), ".keras", "datasets",
    "cats_and_dogs_filtered_extracted", "cats_and_dogs_filtered", "train",
)
train_dir = data_dir if os.path.isdir(data_dir) else _extracted_train_dir
train_ds, val_ds, test_ds = create_split_datasets(train_dir)

Found 2000 files belonging to 2 classes.
Using 1400 files for training.
Found 2000 files belonging to 2 classes.
Using 600 files for validation.


# Print dataset sizes

In [7]:
# Report how many batches each split contains, one line per split.
print(
    f"Train batches: {len(train_ds)}",
    f"Validation batches: {len(val_ds)}",
    f"Test batches: {len(test_ds)}",
    sep="\n",
)

Train batches: 44
Validation batches: 12
Test batches: 7
