In [8]:
import os
import logging
import pickle

from pathlib import Path

# Set up logging. DEBUG_LEVEL is parsed with int(), so it must be an integer
# string: unset or "0" selects INFO, any non-zero integer (e.g. "1") selects
# DEBUG. A non-integer value such as "true" raises ValueError here.
DEBUG_LEVEL = bool(int(os.getenv("DEBUG_LEVEL", 0)))

logging.basicConfig(level=logging.DEBUG if DEBUG_LEVEL else logging.INFO)

_LOGGER = logging.getLogger(__name__)

# Dataset to load, chosen via TUTORIAL_TF_IMPORT_DATASET (default "mnist").
# Every supported branch binds the selected Keras dataset module to the common
# name ``tf_dataset`` so the loading code does not depend on which one was
# picked.
IMPORT_DATASET = str(os.environ.get("TUTORIAL_TF_IMPORT_DATASET", "mnist"))

if IMPORT_DATASET == "mnist":
    _LOGGER.info("Selected MNIST Dataset: MNIST handwritten digits dataset.")
    from tensorflow.keras.datasets import mnist as tf_dataset

elif IMPORT_DATASET == "cifar10":
    _LOGGER.info(
        "Selected cifar10 Dataset: CIFAR10 small images classification dataset."
    )
    from tensorflow.keras.datasets import cifar10 as tf_dataset

elif IMPORT_DATASET == "cifar100":
    _LOGGER.info(
        "Selected cifar100 Dataset: CIFAR100 small images classification dataset."
    )
    from tensorflow.keras.datasets import cifar100 as tf_dataset

elif IMPORT_DATASET == "fashion_mnist":
    _LOGGER.info("Selected fashion_mnist Dataset: Fashion-MNIST dataset.")
    from tensorflow.keras.datasets import fashion_mnist as tf_dataset

else:
    # Fail fast: without this branch ``tf_dataset`` would stay unbound and the
    # problem would only surface later as an unrelated-looking NameError.
    raise ValueError(
        "Unsupported TUTORIAL_TF_IMPORT_DATASET {!r}; expected one of: "
        "mnist, cifar10, cifar100, fashion_mnist.".format(IMPORT_DATASET)
    )

def _is_file_downloaded(file_downloaded_path: Path) -> bool:
    """Report whether ``file_downloaded_path`` already exists on disk.

    When the path exists, an INFO message is logged stating that it will be
    skipped.

    Args:
        file_downloaded_path: Location of the file to look for.

    Returns:
        True if the path exists, False otherwise.
    """
    already_present = os.path.exists(file_downloaded_path)
    if not already_present:
        return False

    _LOGGER.info("{} already exists, skipping ...".format(file_downloaded_path))
    return True

ModuleNotFoundError: No module named '__main__.elyra_aidevsecops_tutorial'; '__main__' is not a package

In [None]:
# Set path where to store the dataset files. DESTINATION_DATASET (default
# "data/raw") is resolved relative to the directory two levels above the
# current working directory.
directory_path = Path.cwd().parents[1]
destination_path = directory_path.joinpath(
    str(os.environ.get("DESTINATION_DATASET", "data/raw"))
)

# exist_ok=True already makes this a no-op when the directory exists.
destination_path.mkdir(parents=True, exist_ok=True)

# Load the dataset selected via TUTORIAL_TF_IMPORT_DATASET; ``tf_dataset`` is
# the alias bound by the dataset-selection cell.
(x_train, y_train), (x_test, y_test) = tf_dataset.load_data()

dataset = {
    "xdata.pkl": x_train,
    "ydata.pkl": y_train,
    "xtestdata.pkl": x_test,
    "ytestdata.pkl": y_test,
}

# Pickle each split for the next pipeline step, skipping files that exist.
for data_name, data_file in dataset.items():

    file_downloaded_path = destination_path.joinpath(data_name)

    if _is_file_downloaded(file_downloaded_path):
        continue

    # The context manager closes the file even if pickling raises.
    with open(file_downloaded_path, "wb") as output:
        pickle.dump(data_file, output)
    _LOGGER.info("Stored {}".format(file_downloaded_path))