In [1]:
#=============================================================================
# Modules
#=============================================================================

# Import modules
import os
import gzip

import numpy as np

from one_parameter_plotting import Rips_Filtration
from multiparameter_landscape_plotting import Rips_Filter_Bifiltration

In [2]:
#=============================================================================
# Functions
#=============================================================================

def loadImages(file_path:str):
    """
    function that returns the loaded images from a .gz file

    The file is expected to be a gzip-compressed IDX3 unsigned-byte file
    (the format used by the MNIST / Fashion MNIST downloads). Its 16-byte
    header is read as four big-endian 32-bit integers: magic number, image
    count, rows and columns. The rows/columns values from the header set the
    shape of each image, so the loader is not tied to 28x28 inputs.

    Args:
        file_path (str): file path for the .gz image file to be loaded

    Returns:
        data (np array): uint8 array of shape (n_images, rows, cols)

    Raises:
        ValueError: if the file is shorter than the 16-byte header, if the
            magic number is not the IDX3 unsigned-byte value (2051), or
            (raised by numpy) if the pixel data does not divide evenly into
            rows x cols images
    """
    header_size = 16
    idx3_ubyte_magic = 2051  # 0x00000803: unsigned byte data, 3 dimensions

    with gzip.open(file_path, 'rb') as f:
        raw = f.read()

    if len(raw) < header_size:
        raise ValueError(
            f"{file_path}: expected at least {header_size} header bytes, got {len(raw)}"
        )

    # Header fields are stored big-endian, hence the explicit '>u4' dtype
    magic, _count, rows, cols = (
        int(field) for field in np.frombuffer(raw, dtype='>u4', count=4)
    )
    if magic != idx3_ubyte_magic:
        raise ValueError(
            f"{file_path}: not an IDX3 unsigned-byte image file (magic number {magic})"
        )

    # Everything after the header is pixel data, one unsigned byte per pixel
    data = np.frombuffer(raw, np.uint8, offset=header_size)

    # The image count is inferred from the data length (-1) rather than taken
    # from the header, so the count field is not enforced here
    return data.reshape(-1, rows, cols)

def loadLabels(file_path:str):
    """
    Read the label values stored in a gzip-compressed label file.

    Args:
        file_path (str): file path for the .gz label file to be loaded

    Returns:
        data (np array): 1-D uint8 array holding every byte of the
            decompressed file after the first 8
    """
    # The first 8 bytes are header (magic number etc.), not label data
    header_size = 8

    with gzip.open(file_path, mode='rb') as label_file:
        raw_bytes = label_file.read()

    # One unsigned byte per label
    labels = np.frombuffer(raw_bytes, dtype=np.uint8, offset=header_size)
    return labels

In [5]:
#=============================================================================
# Variables
#=============================================================================

# Pixel normalisation value (largest value an unsigned 8-bit pixel can take)
pixels = 255.0

# Size of image dimensions for Fashion MNIST dataset
imageDimensions = 28

# Relative locations of the raw input data and of this notebook's outputs
dataPath = "../../data/raw/"
outputPath = "../../outputs/initial-test/"

In [6]:
 ## Paths to the downloaded data files
trainImagesPath = os.path.join(dataPath, "train-images-idx3-ubyte.gz")
testImagesPath  = os.path.join(dataPath, "t10k-images-idx3-ubyte.gz")
trainLabelsPath = os.path.join(dataPath, "train-labels-idx1-ubyte.gz")
testLabelsPath  = os.path.join(dataPath, "t10k-labels-idx1-ubyte.gz")

## Read the image stacks and their labels from disk
trainImages = loadImages(trainImagesPath)
testImages  = loadImages(testImagesPath)
trainLabels = loadLabels(trainLabelsPath)
testLabels  = loadLabels(testLabelsPath)


#==========================================================================
# Data pre-processing
#==========================================================================

# Flatten every image into a single row, then scale by the normalisation
# value; dividing by the float `pixels` turns the result into floats
trainImages = trainImages.reshape(len(trainImages), -1) / pixels
testImages = testImages.reshape(len(testImages), -1) / pixels

In [9]:
# First pre-processed training sample, still a flat vector
x = trainImages[0]

# reshape returns a new array rather than changing `x` in place, so the 2-D
# version is kept under its own name; `x` itself stays flat
xImage = x.reshape(imageDimensions, imageDimensions)

x.shape

(784,)