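# Fashion-MNIST image export

Writes Fashion-MNIST to `data/` as one `.jpg` file per image, grouped into one sub-folder per class label. There are three splits:

- `test`: all test images.
- `train_all`: all training images.
- `train_partial`: all training images, except Sneaker (7) and Ankle boot (9), which are down-sampled to 500 images each.

The label-to-name mapping is saved as `data/class_names.json`.

# Imports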

In [1]:
import json
import random
import sys
from pathlib import Path
from shutil import rmtree

import numpy as np
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.utils import save_img

# Config

In [2]:
# Every generated artifact (class_names.json and the image splits) is written
# below ./data, relative to the kernel's current working directory.
DATA_FOLDER = Path.cwd().joinpath("data")

# Labels

## Output path

In [3]:
# JSON file holding the label index -> class name mapping.
class_names_file = DATA_FOLDER.joinpath("class_names.json")

## Data creation

In [4]:
# Fashion-MNIST label index -> human-readable class name.
# The list position is the label, so the keys run 0..9 in this order.
class_names = dict(enumerate([
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]))

In [5]:
# On a fresh checkout the data folder does not exist yet (the split folders
# are only created further down), so create it before writing into it.
class_names_file.parent.mkdir(parents=True, exist_ok=True)

# Note: JSON object keys are always strings, so the integer labels are
# stored as "0".."9" in the file.
with open(class_names_file, "w", encoding="utf-8") as file:
    json.dump(class_names, file)

# Data

## Output paths

In [6]:
# Sanity check: show which Python interpreter (and hence which conda env) the
# kernel runs. Relies on `sys` being imported in the imports cell.
Path(sys.executable)

PosixPath('/Users/sofiene.alouini/miniconda3/envs/vertex-ai-demo/bin/python')

In [7]:
# The image splits are regenerated on every run, so images left over from a
# previous run (e.g. one with different sample counts) are removed first.
# Only the generated split folders are wiped; class_names.json, written above
# into the same DATA_FOLDER, is kept.
PROJECT_DIR_NAME = "vertex-ai-demo"
# Safety guard: only delete when DATA_FOLDER sits inside a directory with the
# project's name. `DATA_FOLDER.parents` yields Path objects, so the directory
# *names* are compared; testing the plain string against the Paths never
# matched. NOTE(review): assumes the checkout directory is named
# "vertex-ai-demo"; confirm.
inside_project = any(parent.name == PROJECT_DIR_NAME for parent in DATA_FOLDER.parents)

train_data_folder_partial = DATA_FOLDER / "train_partial"
train_data_folder_all = DATA_FOLDER / "train_all"
test_data_folder = DATA_FOLDER / "test"

for split_folder in (train_data_folder_partial, train_data_folder_all, test_data_folder):
    if inside_project and split_folder.exists():
        rmtree(split_folder, ignore_errors=True)
    # One sub-folder per class label, e.g. data/test/7.
    for class_num in class_names:
        (split_folder / str(class_num)).mkdir(parents=True, exist_ok=True)

## Data creation

In [8]:
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()

# Summarise the array shapes of both splits.
print(f"TRAIN:\nImages: {X_train.shape}\nLabels: {y_train.shape}")
print(f"\nTEST:\nImages: {X_test.shape}\nLabels: {y_test.shape}")

TRAIN:
Images: (60000, 28, 28)
Labels: (60000,)

TEST:
Images: (10000, 28, 28)
Labels: (10000,)


In [9]:
def _select_indices(y: np.ndarray, samples_per_class: dict, seed: int) -> list:
    """Return the indices of ``y`` to export, down-sampling the requested classes.

    Classes are visited in ascending label order. Classes not listed in
    ``samples_per_class`` keep all of their indices.

    Raises:
        ValueError: If a requested count exceeds the number of images of that class.
    """
    if samples_per_class is None:
        return list(range(len(y)))

    # Visit every label present in y plus every requested label, so a request
    # for an absent label is reported instead of silently ignored.
    labels = sorted(set(np.unique(y).tolist()) | set(samples_per_class))
    selected = []
    for class_num in labels:
        class_indices = (y == class_num).nonzero()[0].tolist()
        if class_num in samples_per_class:
            class_samples = samples_per_class[class_num]
            # Compare against the size of *this* class (previously the running
            # total of all collected indices was used).
            if class_samples > len(class_indices):
                raise ValueError(f"There is a total of {len(class_indices)} examples for class {class_num}")
            # A fresh generator seeded per class gives the same selection as
            # re-seeding the global RNG, without mutating global `random` state.
            class_indices = random.Random(seed).sample(class_indices, class_samples)
        selected.extend(class_indices)
    return selected


def create_split(X: np.ndarray, y: np.ndarray, target_folder: Path, samples_per_class: dict = None, seed: int = 42):
    """Save the images in ``X`` as ``target_folder/<label>/<index:05d>.jpg`` files.

    Args:
        X: Image array with one image per entry along the first axis
            (Fashion-MNIST images here are ``(n, 28, 28)``).
        y: Class label of each image in ``X``; must have the same length as ``X``.
        target_folder: Root folder of the split. One sub-folder per label in
            ``y`` is created if missing.
        samples_per_class: Optional ``{label: count}`` mapping. Listed labels
            are randomly down-sampled to ``count`` images; unlisted labels keep
            all of their images. ``None`` exports every image.
        seed: Seed for the down-sampling, so the selection is reproducible.
            The global ``random`` state is not modified.

    Raises:
        ValueError: If ``X`` and ``y`` differ in length, or a requested count
            exceeds the number of images of that label.
    """
    if len(X) != len(y):
        raise ValueError(f"X has {len(X)} images but y has {len(y)} labels")

    indices = _select_indices(y, samples_per_class, seed)

    # Ensure every class folder exists so save_img can write into it.
    for label in np.unique(y).tolist():
        (target_folder / str(label)).mkdir(parents=True, exist_ok=True)

    # Add a trailing channel axis, e.g. (n, 28, 28) -> (n, 28, 28, 1).
    X_ = np.expand_dims(X, axis=-1)
    for idx in indices:
        img_array, img_class = X_[idx], y[idx]
        img_file_path = target_folder / str(img_class) / f"{idx:05d}.jpg"
        save_img(img_file_path, img_array)

In [10]:
# In the partial training split only Sneaker (7) and Ankle boot (9) are
# down-sampled; every other class keeps all of its training images.
PARTIAL_SAMPLES_PER_CLASS = {7: 500, 9: 500}

create_split(X=X_test, y=y_test, target_folder=test_data_folder)
create_split(X=X_train, y=y_train, target_folder=train_data_folder_all)
create_split(
    X=X_train,
    y=y_train,
    target_folder=train_data_folder_partial,
    samples_per_class=PARTIAL_SAMPLES_PER_CLASS,
)