In [2]:
import numpy as np
np.set_printoptions(edgeitems=3)
np.core.arrayprint._line_width = 999

from mlxtend.data import loadlocal_mnist
import pathlib
from PIL import Image
import os
import readmat
import h5py

In [None]:
# TODO: set this to the directory that contains the raw dataset folders read by
# the process_* functions below ('mnist', 'mnistm', 'synnum', 'emnist', 'SVHN', 'usps').
# Left empty, pathlib treats the path as the current working directory.
data_path_str = ''
# Base path that every process_* function joins its dataset folder name onto.
data_path = pathlib.Path(data_path_str)

In [3]:
# Included inline here for convenience; importing this code from a module instead can cause errors.
import pickle

class DigitImage:
    """A single digit sample: pixel data paired with its class label."""

    def __init__(self, image, label):
        # image: pixel data exactly as supplied by the caller
        # label: digit class belonging to that image
        self.image, self.label = image, label

class DigitDataset:
    """Container holding the train and test portions of one dataset."""

    def __init__(self, train_set, test_set):
        # Both arguments are stored as given; no copy or validation is made.
        self.train_set, self.test_set = train_set, test_set


def persist_dataset_to_pickle(dataset: DigitDataset, path):
    """Serialize ``dataset`` to the file at ``path`` using pickle.

    Args:
        dataset: the DigitDataset to write.
        path: destination file; it is opened in binary write mode, so an
            existing file at that location is overwritten.
    """
    with open(path, 'wb') as sink:
        # Highest protocol available in the running interpreter.
        pickle.dump(dataset, sink, protocol=pickle.HIGHEST_PROTOCOL)

def load_dataset_from_pickle(path) -> DigitDataset:
    """Read back an object previously written with pickle.

    Args:
        path: file to open in binary read mode.

    Returns:
        Whatever object the pickle file contains (expected to be a DigitDataset).
    """
    # NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
    with open(path, 'rb') as source:
        dataset = pickle.load(source)
    return dataset


In [29]:
def extract_mnist_imgs(mnist_set):
    """Reshape every entry of ``mnist_set`` into a 28x28 array.

    Args:
        mnist_set: iterable of flat pixel vectors with 784 values each.

    Returns:
        list with one (28, 28) array per entry, in input order.
    """
    images = []
    for flat_pixels in mnist_set:
        images.append(np.reshape(flat_pixels, (28, 28)))
    return images

def extract_usps_imgs(usps_set):
    """Scale ``usps_set`` by 255, cast to uint8 and reshape each row to 16x16.

    Args:
        usps_set: array of flat pixel vectors with 256 values each
            (presumably floats scaled to [0, 1] -- TODO confirm).

    Returns:
        list with one (16, 16) uint8 array per row, in input order.
    """
    # astype truncates the fractional part rather than rounding.
    scaled = (usps_set * 255).astype(np.uint8)
    images = []
    for flat_pixels in scaled:
        images.append(np.reshape(flat_pixels, (16, 16)))
    return images

def process_mnistm():
    """Load MNIST-M from the image folder tree under ``data_path/'mnistm'``.

    The expected layout is ``mnistm/<split>/<label>/<image file>``, where
    ``<split>`` is ``train`` or ``test`` and ``<label>`` is a directory name
    that parses as an integer.

    Directory listings are sorted so the sample order is the same on every
    run (``os.listdir`` returns entries in arbitrary order).

    Returns:
        DigitDataset whose train and test lists hold one DigitImage per file,
        with the decoded image as a NumPy array and the label as an int.

    Raises:
        Exception: if ``mnistm`` contains an entry other than ``train`` or
            ``test``. This is checked before any image of that entry is read.
        ValueError: if a label directory name is not an integer.
    """
    mnistm_path = pathlib.Path(data_path/'mnistm')
    samples_by_split = {'train': [], 'test': []}

    for split_name in sorted(os.listdir(mnistm_path)):
        # Fail fast on an unexpected split instead of after decoding an image.
        if split_name not in samples_by_split:
            raise Exception(f'Set not defined: {split_name}')
        split_samples = samples_by_split[split_name]

        for label in sorted(os.listdir(mnistm_path/split_name)):
            digit = int(label)
            label_dir = mnistm_path/split_name/label
            for image_name in sorted(os.listdir(label_dir)):
                # The context manager closes the file once the pixels are copied out.
                with Image.open(label_dir/image_name) as image_file:
                    pixels = np.asarray(image_file)
                split_samples.append(DigitImage(pixels, digit))

    return DigitDataset(samples_by_split['train'], samples_by_split['test'])

def process_mnist():
    """Load MNIST from the IDX files under ``data_path/'mnist'``.

    Returns:
        DigitDataset whose train and test lists hold one DigitImage per
        sample: a 28x28 image paired with the label read from the label file.
    """
    mnist_dir = pathlib.Path(data_path/'mnist')

    def _load_split(images_file, labels_file):
        # Read one image/label file pair and wrap every sample in a DigitImage.
        flat_pixels, digits = loadlocal_mnist(
            images_path=mnist_dir/images_file,
            labels_path=mnist_dir/labels_file)
        return [
            DigitImage(pixels, digit)
            for pixels, digit in zip(extract_mnist_imgs(flat_pixels), digits)
        ]

    train_samples = _load_split('train-images.idx3-ubyte', 'train-labels.idx1-ubyte')
    test_samples = _load_split('t10k-images.idx3-ubyte', 't10k-labels.idx1-ubyte')
    return DigitDataset(train_samples, test_samples)



def process_synnum():
    """Load the SynNum dataset from the .mat files under ``data_path/'synnum'``.

    Each .mat file is read through ``readmat.loadmat``; its ``'X'`` entry holds
    the images and its ``'y'`` entry the labels.

    Returns:
        DigitDataset whose train and test lists hold one DigitImage per sample.

    Raises:
        AssertionError: if the number of images and labels disagree.
    """
    synnum_dir = pathlib.Path(data_path/'synnum')
    train_mat = readmat.loadmat(synnum_dir/'synth_train_32x32.mat')
    test_mat = readmat.loadmat(synnum_dir/'synth_test_32x32.mat')
    # make sure data and label dimension adds up
    assert train_mat['y'].shape[-1] == train_mat['X'].shape[-1]

    # Move axis 3 to the front so that indexing along axis 0 yields one image
    # (presumably (32, 32, 3, N) -> (N, 32, 32, 3) -- TODO confirm).
    train_pixels = np.moveaxis(train_mat['X'], 3, 0)
    test_pixels = np.moveaxis(test_mat['X'], 3, 0)

    assert train_pixels.shape[0] == train_mat['y'].shape[0]
    assert test_pixels.shape[0] == test_mat['y'].shape[0]

    train_samples = [
        DigitImage(pixels, digit)
        for pixels, digit in zip(train_pixels, train_mat['y'])
    ]
    test_samples = [
        DigitImage(pixels, digit)
        for pixels, digit in zip(test_pixels, test_mat['y'])
    ]
    return DigitDataset(train_samples, test_samples)


def process_emnist():
    """Load EMNIST digits from ``data_path/'emnist'/'emnist-digits.mat'``.

    The file is read through ``readmat.loadmat``; images and labels are taken
    from ``['dataset'][<split>]['images']`` and ``['dataset'][<split>]['labels']``.

    Returns:
        DigitDataset whose train and test lists hold one DigitImage per
        sample, each image reshaped to 28x28.

    Raises:
        AssertionError: if the number of images and labels disagree.
    """
    emnist_dir = pathlib.Path(data_path/'emnist')
    mat = readmat.loadmat(emnist_dir/'emnist-digits.mat')
    train_split = mat['dataset']['train']
    test_split = mat['dataset']['test']

    # NOTE(review): EMNIST .mat images are commonly reported to be stored
    # column-major (transposed relative to MNIST) -- confirm whether a
    # transpose is needed after this reshape.
    train_pixels = [np.asarray(flat).reshape(28, 28) for flat in train_split['images']]
    test_pixels = [np.asarray(flat).reshape(28, 28) for flat in test_split['images']]

    assert len(train_pixels) == len(train_split['labels'])
    assert len(test_pixels) == len(test_split['labels'])

    train_samples = [
        DigitImage(pixels, digit)
        for pixels, digit in zip(train_pixels, train_split['labels'])
    ]
    test_samples = [
        DigitImage(pixels, digit)
        for pixels, digit in zip(test_pixels, test_split['labels'])
    ]
    return DigitDataset(train_samples, test_samples)

def process_svhn():
    """Load SVHN from the .mat files under ``data_path/'SVHN'``.

    Each .mat file is read through ``readmat.loadmat``; its ``'X'`` entry holds
    the images and its ``'y'`` entry the labels. Labels equal to 10 are
    rewritten to 0 (the files encode the digit 0 as 10).

    Returns:
        DigitDataset whose train and test lists hold one DigitImage per sample.

    Raises:
        AssertionError: if the number of images and labels disagree. This is
            checked before any label is touched, so a mismatch is reported
            directly instead of surfacing as an IndexError.
    """
    svhn_path = pathlib.Path(data_path/'SVHN')
    svhn_train = readmat.loadmat(svhn_path/'train_32x32.mat')
    svhn_test = readmat.loadmat(svhn_path/'test_32x32.mat')

    # Move axis 3 to the front so that indexing along axis 0 yields one image.
    train_imgs = np.moveaxis(svhn_train['X'], 3, 0)
    test_imgs = np.moveaxis(svhn_test['X'], 3, 0)

    # Work on copies so the loaded arrays are left untouched.
    train_labels = np.array(svhn_train['y'], copy=True)
    test_labels = np.array(svhn_test['y'], copy=True)

    assert train_imgs.shape[0] == train_labels.shape[0], \
        'SVHN train: image and label counts differ'
    assert test_imgs.shape[0] == test_labels.shape[0], \
        'SVHN test: image and label counts differ'

    # labels consider 0 to be 10; remap with one vectorized masked assignment
    train_labels[train_labels == 10] = 0
    test_labels[test_labels == 10] = 0

    train_list = [DigitImage(img, label) for img, label in zip(train_imgs, train_labels)]
    test_list = [DigitImage(img, label) for img, label in zip(test_imgs, test_labels)]

    return DigitDataset(train_list, test_list)

def process_usps():
    """Load USPS from the HDF5 file ``data_path/'usps'/'usps.h5'``.

    The file is expected to contain ``train`` and ``test`` groups, each with a
    ``data`` and a ``target`` dataset.

    Returns:
        DigitDataset whose train and test lists hold one DigitImage per
        sample, each image converted by ``extract_usps_imgs``.
    """
    usps_dir = pathlib.Path(data_path/'usps')
    h5_file = usps_dir/'usps.h5'

    with h5py.File(h5_file, 'r') as hf:
        # [:] copies each dataset into memory so it stays usable after the file closes.
        train_group = hf.get('train')
        train_pixels = train_group.get('data')[:]
        train_digits = train_group.get('target')[:]
        test_group = hf.get('test')
        test_pixels = test_group.get('data')[:]
        test_digits = test_group.get('target')[:]

    train_samples = [
        DigitImage(pixels, digit)
        for pixels, digit in zip(extract_usps_imgs(train_pixels), train_digits)
    ]
    test_samples = [
        DigitImage(pixels, digit)
        for pixels, digit in zip(extract_usps_imgs(test_pixels), test_digits)
    ]
    return DigitDataset(train_samples, test_samples)

In [5]:
# Build the MNIST dataset from the files under data_path/'mnist'.
mnist = process_mnist()
print('Done')

Done


In [66]:
# Build the MNIST-M dataset from the image folders under data_path/'mnistm'.
mnist_m = process_mnistm()
print('Done')

Done


In [78]:
# Build the SynNum dataset from the .mat files under data_path/'synnum'.
synnum = process_synnum()
print('Done')

Done


In [79]:
# Build the EMNIST digits dataset from data_path/'emnist'/'emnist-digits.mat'.
emnist = process_emnist()
print('Done')

Done


In [30]:
# Build the SVHN dataset from the .mat files under data_path/'SVHN'.
svhn = process_svhn()
print('Done')

Done


In [71]:
# Build the USPS dataset from data_path/'usps'/'usps.h5'.
usps = process_usps()
print('Done')

Done


In [84]:
save_location_path = pathlib.Path('../../../data')

# Write every processed dataset to its own pickle file under save_location_path.
for pickle_name, processed in (
    ('mnist.pkl', mnist),
    ('mnist_m.pkl', mnist_m),
    ('synnum.pkl', synnum),
    ('emnist.pkl', emnist),
    ('svhn.pkl', svhn),
    ('usps.pkl', usps),
):
    persist_dataset_to_pickle(processed, save_location_path/pickle_name)
