In [None]:
import sys
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np


# Default location of the datasets when the notebook runs on a local machine.
dataset_folder = "../datasets"

# When the notebook is executed inside Google Colab the datasets are read from
# Google Drive instead, so the drive is mounted and the folder is redirected.
_on_colab = 'google.colab' in sys.modules
if _on_colab:
    from google.colab import drive

    drive.mount('/content/drive')
    dataset_folder = "/content/drive/My Drive/Msc Project/msc-project-source-code-files-24-25-Stanley-Okwii/datasets"


# Load the ImageNet training split together with the dataset metadata.
print("Loading the ImageNet dataset...")
# Only the training split is loaded at this point; it is used to obtain the
# per-class sample counts. This assumes a local copy of the ImageNet files is
# already available under `dataset_folder`.
(ds_train_full,), ds_info = tfds.load(
    'imagenet2012',
    split=['train'],
    shuffle_files=False, # Important to keep order for consistent splits
    # Read from the resolved dataset folder so that the Google Drive location
    # selected for Colab is honoured (previously the local relative path was
    # hard-coded here and the Colab override had no effect).
    data_dir=dataset_folder,
    as_supervised=True,
    with_info=True
)

# --- Step 1: Count samples per class ---
# The per-class sample counts are needed to perform a stratified split.
# Only the labels matter here, so the split is re-opened with image decoding
# disabled: the image component stays an encoded byte string and is dropped
# immediately, which avoids decoding every JPEG just to read its label.
print("Counting samples per class...")

labels_only = tfds.load(
    'imagenet2012',
    split='train',
    shuffle_files=False,
    data_dir=dataset_folder,
    decoders={'image': tfds.decode.SkipDecoding()},
    as_supervised=True,
).map(lambda _, label: label)

# Create a dictionary to hold the counts for each of the 1000 ImageNet labels.
class_counts = {label: 0 for label in range(1000)}

# Count in batches so the Python loop runs once per few thousand labels rather
# than once per sample.
for label_batch in labels_only.batch(4096).as_numpy_iterator():
    batch_counts = np.bincount(label_batch, minlength=1000)
    for label in np.flatnonzero(batch_counts):
        # Plain ints keep the dictionary keys/values identical in type to the
        # ones created above; an unexpected label still raises KeyError.
        class_counts[int(label)] += int(batch_counts[label])

# --- Step 2: Determine split sizes per class ---
print("Calculating split sizes for 80% train and 20% validation...")

# Training share of each class: 80% of its samples, rounded down.
train_sizes = {
    label: int(np.floor(0.8 * count))
    for label, count in class_counts.items()
}

# Validation share of each class: whatever is left after the training share,
# so the two sizes always add up to the class total.
val_sizes = {
    label: class_counts[label] - train_sizes[label]
    for label in class_counts
}

# --- Step 3: Create the final datasets ---
# The training split is opened a second time, now with file-level shuffling,
# and is used as the source from which the train and validation splits are
# assembled.

# Create a single dataset from the original TFDS split. The input files are
# read in shuffled order for better randomness.
ds_full = tfds.load(
    'imagenet2012',
    split='train',
    shuffle_files=True, # Shuffle the files for better randomness
    # Without an explicit data_dir TFDS looks in its default directory rather
    # than the folder that actually holds the prepared ImageNet files.
    data_dir=dataset_folder,
    as_supervised=True,
)

# Assemble the stratified train/validation split.
# Every sample is visited once, in random order, and is assigned to the
# training list until that class has reached its training quota; afterwards
# samples of that class go to the validation list until its quota is reached.
# NOTE(review): this holds every decoded image of the split in memory at once;
# confirm the machine has enough RAM, or switch to writing sharded files.

# Pull all (image, label) pairs out of the dataset as NumPy values.
all_items = list(ds_full.as_numpy_iterator())
np.random.shuffle(all_items) # Shuffle the combined list

train_list, val_list = [], []

# Running number of samples already assigned per class, for each split.
train_counts = dict.fromkeys(range(1000), 0)
val_counts = dict.fromkeys(range(1000), 0)

for image, label in all_items:
    # The training quota of a class is filled first.
    if train_counts[label] < train_sizes[label]:
        train_list.append((image, label))
        train_counts[label] += 1
        continue

    # Then the validation quota; anything beyond both quotas is dropped.
    if val_counts[label] < val_sizes[label]:
        val_list.append((image, label))
        val_counts[label] += 1
    
print("Converting lists to final datasets...")

# Wrap the in-memory lists as tf.data datasets. The images are taken from the
# TFDS split as-is (no resizing happens anywhere above), so their height and
# width differ from sample to sample; the signature therefore leaves both
# spatial dimensions unspecified. Declaring a fixed 224x224 shape here makes
# the generator fail on the first image of any other size.
train_ds = tf.data.Dataset.from_generator(
    lambda: iter(train_list),
    output_signature=(
        tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
        tf.TensorSpec(shape=(), dtype=tf.int64)
    )
)

val_ds = tf.data.Dataset.from_generator(
    lambda: iter(val_list),
    output_signature=(
        tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
        tf.TensorSpec(shape=(), dtype=tf.int64)
    )
)

# --- Step 4: Verification ---
# Report the overall size of both splits.
print(f"\nTraining dataset size: {len(train_list)} samples")
print(f"Validation dataset size: {len(val_list)} samples")

# Spot-check the class distribution of both splits (first 10 classes).
for heading, counts in (
    ("\nClass distribution in Training set (first 10 classes):", train_counts),
    ("\nClass distribution in Validation set (first 10 classes):", val_counts),
):
    print(heading)
    for label in range(10):
        print(f"  Class {label}: {counts[label]} samples")

Loading the ImageNet dataset...


2025-08-19 02:22:55.642900: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Max
2025-08-19 02:22:55.642950: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 64.00 GB
2025-08-19 02:22:55.642966: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 24.00 GB
2025-08-19 02:22:55.642997: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-08-19 02:22:55.643013: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


Counting samples per class...


2025-08-19 02:33:25.625086: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: CANCELLED: ../datasets/imagenet2012/5.1.0/imagenet2012-train.tfrecord-00026-of-01024; Operation canceled


CancelledError: {{function_node __wrapped__IteratorGetNext_output_types_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} ../datasets/imagenet2012/5.1.0/imagenet2012-train.tfrecord-00026-of-01024; Operation canceled [Op:IteratorGetNext] name: 