In [11]:
from __future__ import print_function
try: 
    from urllib.request import urlretrieve 
except ImportError: 
    from urllib import urlretrieve
import sys
import gzip
import shutil
import os
import struct
import numpy as np
import pickle

In [12]:
def loadData(src, cimg):
    """Download a gzip-compressed MNIST IDX3 image file and decode its pixels.

    The archive at ``src`` is fetched into the scratch file ``./delete.me``,
    its header is validated, and the scratch file is removed afterwards
    whether or not decoding succeeded.

    Args:
        src: URL of the gzip-compressed IDX3 image file.
        cimg: Number of images the file is expected to contain.

    Returns:
        A writable ``numpy.uint8`` array of shape ``(cimg, 784)``; each row is
        one 28x28 image flattened in the order the bytes appear in the file.

    Raises:
        ValueError: If the magic number, the image count, the image
            dimensions or the amount of pixel data do not match what is
            expected.
    """
    tmp_path = './delete.me'
    print ('Downloading ' + src)
    try:
        # Download inside the try block so a partially written scratch file
        # is cleaned up as well when the transfer fails.
        urlretrieve(src, tmp_path)
        print ('Done.')
        with gzip.open(tmp_path) as gz:
            # Every IDX header field is a big-endian 32-bit integer; reading
            # them with an explicit '>' keeps the check correct on any host.
            magic = struct.unpack('>I', gz.read(4))[0]
            if magic != 0x00000803:
                raise ValueError('Invalid file: unexpected magic number.')
            # Read number of entries.
            n = struct.unpack('>I', gz.read(4))[0]
            if n != cimg:
                raise ValueError('Invalid file: expected {0} entries.'.format(cimg))
            crow = struct.unpack('>I', gz.read(4))[0]
            ccol = struct.unpack('>I', gz.read(4))[0]
            if crow != 28 or ccol != 28:
                raise ValueError('Invalid file: expected 28 rows/cols per image.')
            # Read data.
            expected = cimg * crow * ccol
            raw = gz.read(expected)
            if len(raw) != expected:
                raise ValueError(
                    'Invalid file: expected {0} bytes of pixel data, got {1}.'.format(
                        expected, len(raw)))
            # frombuffer yields a read-only view of `raw`; copy so callers get
            # a writable array, as the deprecated np.fromstring used to return.
            res = np.frombuffer(raw, dtype=np.uint8).copy()
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return res.reshape((cimg, crow * ccol))

def loadLabels(src, cimg):
    """Download a gzip-compressed MNIST IDX1 label file and decode its labels.

    The archive at ``src`` is fetched into the scratch file ``./delete.me``,
    its header is validated, and the scratch file is removed afterwards
    whether or not decoding succeeded.

    Args:
        src: URL of the gzip-compressed IDX1 label file.
        cimg: Number of labels the file is expected to contain.

    Returns:
        A writable ``numpy.uint8`` array of shape ``(cimg, 1)`` holding one
        label byte per row.

    Raises:
        ValueError: If the magic number, the label count or the amount of
            label data do not match what is expected.
    """
    tmp_path = './delete.me'
    print ('Downloading ' + src)
    try:
        # Download inside the try block so a partially written scratch file
        # is cleaned up as well when the transfer fails.
        urlretrieve(src, tmp_path)
        print ('Done.')
        with gzip.open(tmp_path) as gz:
            # Every IDX header field is a big-endian 32-bit integer; reading
            # them with an explicit '>' keeps the check correct on any host.
            magic = struct.unpack('>I', gz.read(4))[0]
            if magic != 0x00000801:
                raise ValueError('Invalid file: unexpected magic number.')
            # Read number of entries.
            n = struct.unpack('>I', gz.read(4))[0]
            if n != cimg:
                raise ValueError('Invalid file: expected {0} rows.'.format(cimg))
            # Read labels.
            raw = gz.read(cimg)
            if len(raw) != cimg:
                raise ValueError(
                    'Invalid file: expected {0} bytes of label data, got {1}.'.format(
                        cimg, len(raw)))
            # frombuffer yields a read-only view of `raw`; copy so callers get
            # a writable array, as the deprecated np.fromstring used to return.
            res = np.frombuffer(raw, dtype=np.uint8).copy()
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return res.reshape((cimg, 1))

def load(dataSrc, labelsSrc, cimg):
    """Fetch images and labels and join them column-wise into one array.

    Args:
        dataSrc: URL passed to ``loadData`` for the image file.
        labelsSrc: URL passed to ``loadLabels`` for the label file.
        cimg: Number of samples both files are expected to contain.

    Returns:
        The horizontal stack of the ``loadData`` result followed by the
        ``loadLabels`` result, so the label ends up in the last column.
    """
    # Images are fetched first, then labels, then both are stacked side by side.
    pixels = loadData(dataSrc, cimg)
    targets = loadLabels(labelsSrc, cimg)
    combined = np.hstack((pixels, targets))
    return combined




def savetxt(filename, ndarray):
    """Write samples to ``filename`` as ``|labels ... |features ...`` lines.

    Each row of ``ndarray`` becomes one line. The last element of the row is
    used as an index into ten one-hot label strings; the remaining elements
    are written space-separated as the features.

    Args:
        filename: Path of the text file to create (opened with mode ``'w'``).
        ndarray: Iterable of 1-D numpy arrays whose last element is the label
            index.
    """
    # Row i of the identity matrix, rendered as text, is the one-hot code of i.
    one_hot = np.eye(10, dtype=np.uint).astype(str)
    label_texts = [' '.join(bits) for bits in one_hot]
    with open(filename, 'w') as out:
        for sample in ndarray:
            tokens = sample.astype(str)
            features = ' '.join(tokens[:-1])
            line = '|labels {} |features {}\n'.format(label_texts[sample[-1]], features)
            out.write(line)

In [14]:

if __name__ == "__main__":
    # Resolve relative paths against the script's own directory when run as a
    # script. Inside a notebook kernel `__file__` is undefined, so the current
    # working directory is kept. (The previous code passed the string literal
    # '__file__', which made the chdir a no-op everywhere.)
    try:
        os.chdir(os.path.dirname(os.path.abspath(__file__)))
    except NameError:
        pass

    # open(..., 'wb') does not create missing directories, so make sure the
    # output directory exists before any download starts.
    out_dir = os.path.join('..', 'Data')
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    # Training split: 60000 samples, pickled as one array whose last column
    # is the label.
    train = load('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', 60000)
    # NOTE: the message says "text file" but a pickle is written; the wording
    # is kept so it matches the recorded cell output.
    print ('Writing train text file...')
    print(train[:2])
    print(train.shape)
    with open(os.path.join(out_dir, 'mnist_train.pkl'), 'wb') as f:
        pickle.dump(train, f)

    print ('Done.')

    # Test split: 10000 samples, same layout.
    test = load('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', 10000)
    print ('Writing test text file...')
    print(test[:2])
    print(test.shape)
    with open(os.path.join(out_dir, 'mnist_test.pkl'), 'wb') as f:
        pickle.dump(test, f)
    print ('Done.')


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Done.




Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Done.




Writing train text file...
[[0 0 0 ... 0 0 5]
 [0 0 0 ... 0 0 0]]
(60000, 785)
Done.
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Done.
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Done.
Writing test text file...
[[0 0 0 ... 0 0 7]
 [0 0 0 ... 0 0 2]]
(10000, 785)
Done.
