# Imports

In [2]:
# misc
import sys
import os

# load/save files
import zipfile
import json

# plot
import matplotlib.pyplot as plt
from PIL import Image

# datascience libs
import numpy as np
import pandas as pd

# Resolve the project root (the parent of the directory holding this code).
try:
    # Plain Python script/module: __file__ is defined.
    path_ = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
except NameError:
    # Jupyter notebook: __file__ does not exist, fall back to the parent of
    # the current working directory.
    path_ = os.path.dirname(os.getcwd())

# Sub-directories of the project root used for data and trained models.
dataset_dir, model_dir = (
    os.path.join(path_, sub_dir) for sub_dir in ("datasets", "models")
)


# Helpers

In [3]:
from time import time
from sys import argv

def _time(f):
    """Decorator that prints the wall-clock duration of every call to *f*.

    The wrapped function is called with the positional and keyword arguments
    it receives, and its return value is passed through unchanged. After each
    successful call a line ``"<name> timed <seconds>"`` is printed. If *f*
    raises, the exception propagates and nothing is printed.
    """
    # Local import so this helper does not depend on a file-level import.
    from functools import wraps

    @wraps(f)  # keep f's __name__ / __doc__ on the wrapper
    def wrapper(*args, **kwargs):
        start = time()
        r = f(*args, **kwargs)
        end = time()
        print("%s timed %f" % (f.__name__, end-start))
        return r
    return wrapper

# Load example

In [61]:
try:
    from types import SimpleNamespace as Namespace
except ImportError:
    from argparse import Namespace

# https://towardsdatascience.com/efficiently-splitting-an-image-into-tiles-in-python-using-numpy-d1bf0dd7b6f7
def array_to_tiled_array(img:np.ndarray, kernel_size:tuple):
    """Split an image into non-overlapping tiles without Python-level loops.

    Args:
        img: array of shape (height, width) or (height, width, channels).
            A 2-D input is treated as having a single channel.
        kernel_size: ``(tile_height, tile_width)``.

    Returns:
        Array of shape ``(n_tiles, tile_height, tile_width, channels)``.
        Tiles are ordered row by row (left to right, then top to bottom).

    Raises:
        ValueError: if the image height/width is not an exact multiple of
            the tile height/width.
    """
    if len(img.shape) == 2:
        img = np.expand_dims(img, axis=-1)
    img_height, img_width, channels = img.shape
    tile_height, tile_width = kernel_size
    if img_height % tile_height or img_width % tile_width:
        raise ValueError(
            "image size %dx%d is not a multiple of tile size %dx%d"
            % (img_height, img_width, tile_height, tile_width)
        )
    # Expose the tile grid as separate axes: (rows, th, cols, tw, channels).
    tiles = img.reshape(img_height // tile_height,
                        tile_height,
                        img_width // tile_width,
                        tile_width,
                        channels)
    # Bring the two grid axes together, then flatten them into one tile axis.
    # The channel axis keeps its real size so multi-channel images stay intact.
    return tiles.swapaxes(1, 2).reshape(-1, tile_height, tile_width, channels)

def emnistload_data_X(path:str, tile_size:tuple=(28, 28)):
    """Load every image tile referenced by a dataset JSON manifest.

    The manifest at *path* is parsed into a namespace object whose ``files``
    attribute lists image file names relative to the manifest's directory.
    Each image is converted to 8-bit grayscale and cut into tiles.

    Args:
        path: path to the JSON manifest.
        tile_size: ``(tile_height, tile_width)`` of one sample; defaults to
            the 28x28 tiles used by the original code.

    Returns:
        float64 array of shape ``(n_tiles, tile_height, tile_width, 1)``;
        ``n_tiles`` is 0 when the manifest lists no files.
    """
    dir_path = os.path.dirname(path)
    with open(path, 'r', encoding='utf-8') as f:
        obj = json.load(f, object_hook = lambda d: Namespace(**d))

    # Start from an empty float64 array so the result keeps the dtype and
    # shape the previous np.append-based implementation produced, even when
    # no image is listed.
    chunks = [np.zeros((0,) + tuple(tile_size) + (1,))]
    for s in obj.files:
        img_path = os.path.join(dir_path, s)
        # Context manager closes the underlying file as soon as it is decoded.
        with Image.open(img_path) as im:
            gray = im.convert('L')
        chunks.append(array_to_tiled_array(np.array(gray, dtype="uint8"), tile_size))
    # Single concatenation instead of re-copying the whole array per file.
    return np.concatenate(chunks, axis=0)

def emnistload_data_y(path:str):
    """Load the labels stored in a dataset JSON manifest.

    The manifest at *path* is parsed into a namespace object; its ``id`` and
    ``bbox`` attributes are converted to ``uint8`` arrays.

    Args:
        path: path to the JSON manifest.

    Returns:
        A two-element list ``[ids, bboxes]`` of ``uint8`` NumPy arrays built
        from the manifest's ``id`` and ``bbox`` entries respectively.
    """
    with open(path, 'r', encoding='utf-8') as f:
        obj = json.load(f, object_hook = lambda d: Namespace(**d))
    return [np.array(obj.id, dtype="uint8"), np.array(obj.bbox, dtype="uint8")]

@_time
def emnist_load_data(dir_path:str):
    """Load the train and test splits found in *dir_path*.

    Reads ``train.json`` and ``test.json`` from *dir_path* and returns
    ``(X_train, y_train), (X_test, y_test)``, where each ``X`` comes from
    ``emnistload_data_X`` and each ``y`` from ``emnistload_data_y``.
    """
    test_json = os.path.join(dir_path, "test.json")
    train_json = os.path.join(dir_path, "train.json")
    manifests = (test_json, train_json)

    # Images are loaded first (test, then train), labels afterwards.
    X_test, X_train = [emnistload_data_X(manifest) for manifest in manifests]
    y_test, y_train = [emnistload_data_y(manifest) for manifest in manifests]

    train_split = (X_train, y_train)
    test_split = (X_test, y_test)
    return train_split, test_split
    
# Alternative loader kept for reference:
#(X_train, y_train), (X_test, y_test) = datasets.mnist.load_data(path="mnist.npz")
(X_train, y_train), (X_test, y_test) = emnist_load_data(
    os.path.join(dataset_dir, "origin-emnist-mnist")
)

# Print the shapes of each split: images, class ids, bounding boxes.
for _X, _y, _labels in (
    (X_train, y_train, ("X_train:", "y_train_id:", "y_train_bbox:")),
    (X_test, y_test, ("X_test:", "y_test_id:", "y_test_bbox:")),
):
    print("")
    print(_labels[0], _X.shape)
    print(_labels[1], _y[0].shape)
    print(_labels[2], _y[1].shape)


emnist_load_data timed 0.024504

X_train: (201, 28, 28, 1)
y_train_id: (201,)
y_train_bbox: (201, 4)

X_test: (201, 28, 28, 1)
y_test_id: (201,)
y_test_bbox: (201, 4)
