In [19]:
!pip install tensorflow

import tensorflow as tf
import os

# Path to your dataset
dataset_dir = r"C:\Users\USER\Documents\Thesis Dataset"

# Parameters
img_size = (224, 224)
batch_size = 32
seed = 123



-------------------------------
Step 1: Load the dataset and hold out 20% as the test set (the remaining 80% is the train + validation pool)

-------------------------------

In [20]:
# Seeded 80/20 split of the image folder: the "training" subset is the
# train+val pool, the "validation" subset is held out as the test set.
# Both calls must use the same split fraction and seed, or the two subsets
# come from different file orderings and can overlap. Sharing one dict of
# arguments keeps them in sync.
split_kwargs = dict(
    validation_split=0.2,
    seed=seed,
    image_size=img_size,
    batch_size=batch_size,
)

train_val_ds = tf.keras.utils.image_dataset_from_directory(
    dataset_dir, subset="training", **split_kwargs
)
test_ds = tf.keras.utils.image_dataset_from_directory(
    dataset_dir, subset="validation", **split_kwargs
)

# The held-out test files must never appear in the train+val pool.
assert not set(train_val_ds.file_paths) & set(test_ds.file_paths), "test set overlaps train+val"

Found 15766 files belonging to 3 classes.
Using 12613 files for training.
Found 15766 files belonging to 3 classes.
Using 3153 files for validation.


-------------------------------
Step 2: Split the 80% train+val pool into a training set (60% of all images) and a validation set (20% of all images)

-------------------------------

In [21]:
# Split the 80% train+val pool into 60% train / 20% val (of the full dataset).
#
# The split is done on the file list, not with take()/skip() on the batched
# dataset. image_dataset_from_directory re-shuffles its elements on every pass,
# so take() and skip() would each read a *different* ordering, and some images
# could end up in both train and val. `train_val_ds.file_paths` is the fixed,
# seed-shuffled file order behind `train_val_ds`, so slicing it yields
# disjoint subsets.

# Sizes are counted in batches, exactly as before.
train_val_size = tf.data.experimental.cardinality(train_val_ds).numpy()
train_size = int(train_val_size * 0.75)  # 75% of 80% = 60%
val_size = train_val_size - train_size   # remaining 25% of 80% = 20%


def _load_image(path, label):
    """Read, decode and resize one image file (RGB, bilinear resize to img_size).

    Returns a float32 image of shape (img_size[0], img_size[1], 3) with raw
    pixel values, plus the label unchanged.
    """
    image = tf.io.read_file(path)
    image = tf.image.decode_image(image, channels=3, expand_animations=False)
    image = tf.image.resize(image, img_size, method="bilinear")
    image.set_shape((*img_size, 3))
    return image, label


def _paths_to_dataset(paths):
    """Build a batched (image, int32 label) dataset from files under dataset_dir.

    Each file's label is the index, in train_val_ds.class_names, of the
    top-level class folder that contains it.
    """
    class_index = {name: i for i, name in enumerate(train_val_ds.class_names)}
    labels = [class_index[os.path.relpath(p, dataset_dir).split(os.sep)[0]] for p in paths]
    ds = tf.data.Dataset.from_tensor_slices(
        (tf.constant(paths, dtype=tf.string), tf.constant(labels, dtype=tf.int32))
    )
    return ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE).batch(batch_size)


# Batch-aligned split point: train gets exactly train_size full batches,
# the same number of images take(train_size) produced.
split_at = train_size * batch_size
train_paths = list(train_val_ds.file_paths[:split_at])
val_paths = list(train_val_ds.file_paths[split_at:])
assert not set(train_paths) & set(val_paths), "train and val share files"

train_ds = _paths_to_dataset(train_paths)
val_ds = _paths_to_dataset(val_paths)

-------------------------------
Step 3: Normalize pixel values to [0, 1] and build the input pipelines (cache, shuffle, prefetch)

-------------------------------

In [22]:
# Rescale pixel values from [0, 255] to [0, 1] and finish the input pipelines.
AUTOTUNE = tf.data.AUTOTUNE
normalization_layer = tf.keras.layers.Rescaling(1./255)


def normalize_batch(images, labels):
    """Rescale a batch of images to [0, 1]; labels pass through unchanged."""
    return normalization_layer(images), labels


train_ds = train_ds.map(normalize_batch, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(normalize_batch, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.map(normalize_batch, num_parallel_calls=AUTOTUNE)

# Training pipeline: cache the normalized batches, then unbatch -> shuffle ->
# re-batch so each epoch mixes individual images into new batches. Shuffling
# whole cached batches would keep the same images grouped together every
# epoch. The shuffle is seeded for reproducibility; the order still differs
# from epoch to epoch.
train_ds = (
    train_ds.cache()
    .unbatch()
    .shuffle(1000, seed=seed)
    .batch(batch_size)
    .prefetch(buffer_size=AUTOTUNE)
)
# Evaluation pipelines keep a fixed order: cache + prefetch only.
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)

In [23]:
# Sanity check: report the tensor shapes of the first training batch.
for images, labels in train_ds:
    print("Image batch shape:", images.shape)
    print("Label batch shape:", labels.shape)
    break  # one batch is enough

Image batch shape: (32, 224, 224, 3)
Label batch shape: (32,)
