In [4]:
import sys
# Add the parent of the current working directory to the module search path.
# NOTE(review): presumably needed so the project-local `conf` module imported
# below can be resolved when this runs from a sub-directory — confirm.
sys.path.append('../')

In [5]:
import numpy as np
import tensorflow_datasets as tfds
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from pathlib import Path
from conf import BASE_DIR

# Class counts shared by the dataset registry below.
NUM_CLASSES_10 = 10
NUM_CLASSES_100 = 100
NUM_CLASSES_43 = 43  # not referenced by any DATASETS entry in this section

# Common parent of every per-dataset output directory.
_TRAIN_TEST_ROOT = BASE_DIR / 'train_test_data'

# Registry: TFDS dataset name -> (number of classes, output directory).
# Insertion order is the order in which the datasets get processed.
DATASETS = {
    tfds_name: (class_count, _TRAIN_TEST_ROOT / folder)
    for tfds_name, class_count, folder in (
        ('cifar10', NUM_CLASSES_10, 'Cifar'),
        ('mnist', NUM_CLASSES_10, 'Mnist'),
        ('cifar100', NUM_CLASSES_100, 'Cifar100'),
        ('fashion_mnist', NUM_CLASSES_10, 'FashionMnist'),
        ('svhn_cropped', NUM_CLASSES_10, 'SVHN'),
    )
}

def _split_to_arrays(split):
    """Materialise one TFDS split as stacked ``(images, labels)`` NumPy arrays."""
    images, labels = zip(
        *((sample['image'], sample['label']) for sample in tfds.as_numpy(split))
    )
    return np.stack(images), np.stack(labels)


def _leading_fraction(features, labels, fraction):
    """Keep the first ``fraction`` of the samples; keep all if ``fraction >= 1``."""
    if fraction < 1:
        keep = int(len(features) * fraction)
        features, labels = features[:keep], labels[:keep]
    return features, labels


def process_and_save_dataset(dataset_name, num_classes, data_dir, fraction=0.5,
                             *, download_dir='../.data'):
    """Load a TFDS dataset, subsample it, and save train/test/valid ``.npy`` files.

    The ``train`` and ``test`` splits are each truncated to their leading
    ``fraction`` of samples. The truncated test split is then divided in half
    (fixed seed) into test and validation sets. Labels are one-hot encoded
    with ``num_classes`` columns; images are stored as loaded.

    Args:
        dataset_name: Name handed to ``tfds.load``.
        num_classes: Width of the one-hot label encoding.
        data_dir: Directory receiving ``x_train.npy``, ``y_train.npy``,
            ``x_test.npy``, ``y_test.npy``, ``x_valid.npy`` and
            ``y_valid.npy``. Created (with parents) if missing.
        fraction: Share of each split to keep, counted from the start of the
            split. Values ``>= 1`` keep everything.
        download_dir: Passed to ``tfds.load`` as its ``data_dir``. The default
            is a relative path, so it depends on the current working
            directory.

    Raises:
        ValueError: If ``fraction`` is not greater than 0.
    """
    # Reject bad fractions before touching disk or network. A negative value
    # would otherwise turn into a negative slice bound and silently keep the
    # wrong portion of the data. Written as `not >` so NaN is rejected too.
    if not fraction > 0:
        raise ValueError(f"fraction must be greater than 0, got {fraction!r}")

    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)

    dataset = tfds.load(dataset_name, data_dir=download_dir)

    x_train, y_train = _leading_fraction(
        *_split_to_arrays(dataset['train']), fraction)
    x_test, y_test = _leading_fraction(
        *_split_to_arrays(dataset['test']), fraction)

    # Carve the validation set out of the (already reduced) test data; the
    # fixed random_state keeps the partition reproducible across runs.
    x_test, x_val, y_test, y_val = train_test_split(
        x_test, y_test, test_size=0.5, random_state=42)

    # Convert integer labels to one-hot vectors.
    y_train = to_categorical(y_train, num_classes=num_classes)
    y_test = to_categorical(y_test, num_classes=num_classes)
    y_val = to_categorical(y_val, num_classes=num_classes)

    outputs = {
        'x_train.npy': x_train,
        'y_train.npy': y_train,
        'x_test.npy': x_test,
        'y_test.npy': y_test,
        'x_valid.npy': x_val,
        'y_valid.npy': y_val,
    }
    for filename, arr in outputs.items():
        np.save(data_dir / filename, arr)

# Script entry point: export every dataset registered in DATASETS, in order.
if __name__ == "__main__":
    for dataset_name, spec in DATASETS.items():
        class_count, output_dir = spec
        process_and_save_dataset(dataset_name, class_count, output_dir)
        print(f"Processed and saved {dataset_name} dataset")


Processed and saved cifar10 dataset
Processed and saved mnist dataset
Processed and saved cifar100 dataset
Processed and saved fashion_mnist dataset
Processed and saved svhn_cropped dataset
