# Augment and Split into Training/Validation and Testing

Only the training/validation dataset will be augmented.

The testing dataset is a holdout dataset: it will **not** be used to train or validate the models, so that they are never exposed to it during training.

It is used only at the end, to compare how the various models respond to new, unseen data.

## This program will do the following steps in order:

1. Load the entire `brain_tumor_dataset`

2. Split the dataset into `training_and_validation_dataset` and `testing_dataset`

3. Augment only the `training_and_validation_dataset` to make extra copies with random transformations applied

4. Save the `training_and_validation_dataset` and `testing_dataset` into their respective directories

In [11]:
import tensorflow as tf
import os
from keras.utils import image_dataset_from_directory, split_dataset
from keras.models import Sequential
from keras.layers import RandomFlip, RandomRotation, RandomZoom, RandomContrast
from tensorflow.image import encode_png

In [22]:
"""
Constants and Parameters
"""
IMAGE_SIZE = (150, 150)  # (height, width) every image is resized to when loaded
BATCH_SIZE = 1           # One image per batch, so each dataset element holds a single image

AUGMENTATION_COPIES = 3  # How many augmented copies per image (added to the original)

# NOTE: Keras' RandomRotation `factor` is a fraction of a full turn (2*pi),
# NOT an angle in radians, so convert from degrees by dividing by 360.
MAX_ROTATION_DEGREES = 5
MAX_ROTATION = MAX_ROTATION_DEGREES / 360  # ~0.0139 of a full turn = up to +/-5 degrees

# Maximum fractional zoom (+/-5%). It is used only as the *width* zoom factor in
# the augmentation cell, so it slightly changes the aspect ratio; keep it small.
MAX_ZOOM = 0.05
MAX_CONTRAST = 0.2       # Contrast factor is drawn from [1 - 0.2, 1 + 0.2]

INPUT_DIRECTORY = "brain_tumor_dataset"
TRAINING_DIRECTORY = "training_and_validation_dataset" # (Augmented)
TESTING_DIRECTORY = "testing_dataset" # (Not augmented) NEVER used in training - only used in evaluation

TRAINING_SPLIT = 0.8 # 80% train & validation, 20% test

In [35]:
"""
Load images
"""
# Labels are integer class indices assigned from the (alphanumerically sorted)
# sub-directory names of INPUT_DIRECTORY.
dataset = image_dataset_from_directory(
    INPUT_DIRECTORY,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
)

# save_dataset() maps label 1 -> "yes" and every other label -> "no", which is
# only correct when the classes are exactly ["no", "yes"]. Fail loudly here
# instead of silently writing mislabeled images later.
if list(dataset.class_names) != ["no", "yes"]:
    raise ValueError(
        f"Expected classes ['no', 'yes'] in {INPUT_DIRECTORY!r}, "
        f"found {list(dataset.class_names)}"
    )

Found 223 files belonging to 2 classes.


In [36]:
"""
Split the dataset into the training and testing portions
"""
# Shuffle, then put a TRAINING_SPLIT fraction of the samples on the left
# (training & validation) and the remainder on the right (holdout testing).
# NOTE(review): no seed is passed here or to image_dataset_from_directory, so
# every run of this notebook produces a different holdout split.
training_dataset, testing_dataset = split_dataset(
    dataset,
    left_size=TRAINING_SPLIT,
    shuffle=True,
)

In [37]:
"""
Count the number of images in each dataset
"""
def count_images(dataset):
    """Return how many images *dataset* contains in total.

    *dataset* yields ``(images, labels)`` pairs; the length of every
    ``images`` batch (its first dimension) is summed. Labels are ignored.
    """
    return sum(len(images) for images, _labels in dataset)

# Sanity check: the two split portions should add up to the original total.
for caption, split_part in (
    ("Total dataset count before split:    ", dataset),
    ("Training & Validation dataset count: ", training_dataset),
    ("Testing dataset count:               ", testing_dataset),
):
    print(f"{caption}{count_images(split_part)}")

Total dataset count before split:    223
Training & Validation dataset count: 178
Testing dataset count:               45


In [38]:
"""
Define augmentation function
"""
# Random transformation pipeline applied to each augmented copy of an image.
data_augmentation = Sequential([
    RandomFlip(mode="horizontal"),
    # `factor` is a fraction of a full turn (2*pi), not an angle in radians.
    RandomRotation(factor=MAX_ROTATION, fill_mode="nearest"),
    # Zoom the width only (height factor fixed at 0). This deliberately allows
    # a small horizontal stretch, so MAX_ZOOM must stay small.
    RandomZoom(
        height_factor=(0, 0),
        width_factor=(-MAX_ZOOM, MAX_ZOOM),
        fill_mode="nearest",
    ),
    RandomContrast(factor=MAX_CONTRAST),
])

def augment_image(image, label):
    """Return ``(randomly augmented image, label)``; the label is passed through unchanged.

    ``training=True`` is required: Keras' random preprocessing layers leave
    their input unchanged in inference mode.
    """
    augmented = data_augmentation(image, training=True)
    return augmented, label

def augment_dataset(dataset, copies=None):
    """Return *dataset* followed by ``copies`` randomly augmented versions of it.

    Args:
        dataset: dataset of ``(image, label)`` elements to augment.
        copies: number of augmented copies to append. Defaults to
            AUGMENTATION_COPIES when None; 0 (or less) returns *dataset* as is.

    Returns:
        A dataset containing the original elements first, then each augmented
        copy in turn, i.e. (1 + copies) times as many elements.

    Note: ``augment_image`` runs inside ``Dataset.map``, so it is re-executed
    on every pass over the result; each iteration yields new random images.
    """
    if copies is None:
        copies = AUGMENTATION_COPIES

    # Start with the original (un-augmented) dataset.
    total_dataset = dataset
    for _ in range(copies):
        # A distinct name keeps the `dataset` parameter from being shadowed.
        augmented_copy = dataset.map(augment_image)
        total_dataset = total_dataset.concatenate(augmented_copy)

    return total_dataset

In [39]:
"""
Augment images
"""
# Append AUGMENTATION_COPIES randomly transformed copies of every training
# image. The holdout testing_dataset is deliberately left untouched.
augmented_training_dataset = augment_dataset(dataset=training_dataset)

In [41]:
"""
Count images after augmentation
"""
# Expected: count after == count before * (1 + AUGMENTATION_COPIES).
print(f"Augmentation copies per image:                           {AUGMENTATION_COPIES}")
count_before = count_images(training_dataset)
print(f"Training & Validation dataset count before augmentation: {count_before}")
count_after = count_images(augmented_training_dataset)
print(f"Training & Validation dataset count after augmentation:  {count_after}")

Augmentation copies per image:                           3
Training & Validation dataset count before augmentation: 178
Training & Validation dataset count after augmentation:  712


In [43]:
"""
Save dataset
"""
def save_dataset(dataset, directory):
    """Write every image in *dataset* as a PNG under ``directory/<class_name>/``.

    Files are numbered ``00000.png``, ``00001.png``, ... in iteration order,
    with one counter shared by both classes. Label 1 is written to the "yes"
    sub-directory and any other label to "no"; this assumes the source classes
    were ``["no", "yes"]`` (image_dataset_from_directory's alphanumeric order).

    Args:
        dataset: iterable of ``(images, labels)`` batches.
        directory: output root directory; created if it does not exist.

    Prints the number of images written and the per-class counts.
    """
    class_count = {
        "yes": 0,
        "no": 0,
    }

    # Create both class folders up front (not only the root). Re-loading this
    # directory with image_dataset_from_directory then always finds both
    # classes and keeps the "no" -> 0 / "yes" -> 1 mapping, even if one class
    # receives no images.
    for class_name in class_count:
        os.makedirs(os.path.join(directory, class_name), exist_ok=True)

    total_images = 0
    for batch, labels in dataset:
        for image, label in zip(batch, labels):
            # Pixel values are floats after resizing/augmentation. Round to the
            # nearest integer (a plain cast truncates) and saturate to [0, 255],
            # so out-of-range values are clamped rather than left to tf.cast's
            # unspecified overflow behaviour.
            image = tf.saturate_cast(tf.round(image), tf.uint8)

            output_filename = f"{total_images:05d}.png"

            class_name = "yes" if int(label.numpy()) == 1 else "no"
            output_path = os.path.join(directory, class_name, output_filename)

            tf.io.write_file(output_path, encode_png(image))

            total_images += 1
            class_count[class_name] += 1

    print(f"Saved {total_images} images to {directory} ({class_count})")

# Only the training & validation data is augmented; the holdout test set is
# written exactly as it was split off.
for split_data, output_directory in (
    (augmented_training_dataset, TRAINING_DIRECTORY),
    (testing_dataset, TESTING_DIRECTORY),
):
    save_dataset(split_data, output_directory)

Saved 712 images to training_and_validation_dataset ({'yes': 484, 'no': 228})
Saved 45 images to testing_dataset ({'yes': 29, 'no': 16})
