# In [1]:
#!/usr/bin/env python
# coding=utf-8
'''
@Author: John
@Email: johnjim0816@gmail.com
@Date: 2020-05-21 23:36:58
@LastEditor: John
@LastEditTime: 2020-05-22 07:24:45
@Description: Load the gzip-compressed MNIST IDX image/label files into NumPy arrays.
@Environment: python 3.7.7
'''
import numpy as np
from struct import unpack
import gzip

def __read_image(path):
    '''Read a gzip-compressed IDX3 image file into a 2-D uint8 array.

    The file begins with a 16-byte big-endian header holding four unsigned
    32-bit integers (magic number, image count, rows, columns), followed by
    one unsigned byte per pixel.

    Parameters
    ----------
    path : str or path-like
        Location of the gzip-compressed image file.

    Returns
    ----------
    numpy.ndarray
        ``uint8`` array of shape ``(num, rows * cols)``; every row is one
        flattened image. The array is built with ``np.frombuffer`` on the
        bytes read from the file, so it is not writeable.

    Raises
    ----------
    ValueError
        If the magic number is not 2051 (the IDX3 unsigned-byte marker), or
        if the pixel payload does not match the size announced in the header.
    struct.error
        If the file is too short to contain the 16-byte header.
    '''
    with gzip.open(path, 'rb') as f:
        magic, num, rows, cols = unpack('>4I', f.read(16))
        if magic != 2051:
            raise ValueError(
                'not an IDX3 image file (magic number %d, expected 2051): %s'
                % (magic, path)
            )
        pixels = np.frombuffer(f.read(), dtype=np.uint8)
    # Use the dimensions from the header instead of assuming 28x28 images;
    # reshape raises ValueError when the payload size disagrees with them.
    return pixels.reshape(num, rows * cols)

def __read_label(path):
    '''Read a gzip-compressed IDX1 label file into a 1-D uint8 array.

    The file begins with an 8-byte big-endian header holding two unsigned
    32-bit integers (magic number, label count), followed by one unsigned
    byte per label.

    Parameters
    ----------
    path : str or path-like
        Location of the gzip-compressed label file.

    Returns
    ----------
    numpy.ndarray
        ``uint8`` array of shape ``(num,)``. The array is built with
        ``np.frombuffer`` on the bytes read from the file, so it is not
        writeable.

    Raises
    ----------
    ValueError
        If the magic number is not 2049 (the IDX1 unsigned-byte marker), or
        if the number of labels differs from the count in the header.
    struct.error
        If the file is too short to contain the 8-byte header.
    '''
    with gzip.open(path, 'rb') as f:
        magic, num = unpack('>2I', f.read(8))
        if magic != 2049:
            raise ValueError(
                'not an IDX1 label file (magic number %d, expected 2049): %s'
                % (magic, path)
            )
        lab = np.frombuffer(f.read(), dtype=np.uint8)
    # Fail loudly on truncated or padded files, mirroring the size check the
    # image reader gets for free from reshape.
    if lab.size != num:
        raise ValueError(
            'label file announces %d labels but contains %d: %s'
            % (num, lab.size, path)
        )
    return lab
    
def __normalize_image(image):
    '''Convert pixel data to float32 and divide it by 255.0.

    Parameters
    ----------
    image : numpy.ndarray
        Pixel data; for uint8 input (0..255) the result lies in 0.0~1.0.

    Returns
    ----------
    numpy.ndarray
        New array with the same shape as ``image``; the input is not modified.
    '''
    as_float = image.astype(np.float32)
    return np.true_divide(as_float, 255.0)

def __one_hot_label(label, num_classes=10):
    '''Convert integer class labels to a one-hot encoded matrix.

    Parameters
    ----------
    label : numpy.ndarray
        Integer class indices, one per sample.
    num_classes : int
        Width of the encoding (number of columns). Defaults to 10, the
        number of MNIST digit classes.

    Returns
    ----------
    numpy.ndarray
        Array of shape ``(label.size, num_classes)`` created with
        ``np.zeros`` (float64 by default), where row ``i`` holds a 1 in
        column ``label[i]`` and 0 everywhere else.

    Raises
    ----------
    IndexError
        If a label falls outside the valid column range.
    '''
    lab = np.zeros((label.size, num_classes))
    # One vectorized assignment instead of a Python loop over every sample:
    # pair each row index with its class index. Flatten first so that a
    # column vector of labels does not broadcast against the row indices.
    lab[np.arange(label.size), label.reshape(-1)] = 1
    return lab

def load_mnist(x_train_path, y_train_path, x_test_path, y_test_path, normalize=True, one_hot=True):
    '''Load the MNIST dataset from four gzip-compressed IDX files.

    Parameters
    ----------
    x_train_path, y_train_path :
        Paths of the training image file and training label file.
    x_test_path, y_test_path :
        Paths of the test image file and test label file.
    normalize :
        When True, pixel values are converted to float32 and scaled
        to 0.0~1.0; otherwise the raw uint8 values are returned.
    one_hot :
        When True, labels are returned as one-hot arrays, i.e. arrays
        such as [0,0,1,0,0,0,0,0,0,0]; otherwise the raw class indices
        are returned.
    Returns
    ----------
    (train images, train labels), (test images, test labels)
    '''
    # Read the four files in the same order as before: images, then labels.
    train_images = __read_image(x_train_path)
    test_images = __read_image(x_test_path)
    train_labels = __read_label(y_train_path)
    test_labels = __read_label(y_test_path)

    if normalize:
        train_images = __normalize_image(train_images)
        test_images = __normalize_image(test_images)

    if one_hot:
        train_labels = __one_hot_label(train_labels)
        test_labels = __one_hot_label(test_labels)

    return (train_images, train_labels), (test_images, test_labels)

# Locations of the four gzip-compressed MNIST files, relative to the
# current working directory.
x_train_path = './Mnist/train-images-idx3-ubyte.gz'
y_train_path = './Mnist/train-labels-idx1-ubyte.gz'
x_test_path = './Mnist/t10k-images-idx3-ubyte.gz'
y_test_path = './Mnist/t10k-labels-idx1-ubyte.gz'

# Load both splits with the defaults (normalized pixels, one-hot labels).
(x_train, y_train), (x_test, y_test) = load_mnist(
    x_train_path=x_train_path,
    y_train_path=y_train_path,
    x_test_path=x_test_path,
    y_test_path=y_test_path,
)