In [29]:
import gzip
import os
import shutil

import numpy as np
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data

In [2]:
# Target directories for the MNIST archives (train and test kept separate).
# These could be exposed through FLAGS so they can be supplied from the
# command line; worth trying next time.
TRAIN_DIR = "../../../data/Tensorflow_formatted/mnist/train"
TEST_DIR = "../../../data/Tensorflow_formatted/mnist/test"

In [38]:
# Download the MNIST archives on first use.
# => Mainly an exercise in basic `os` usage.
def may_download():
    """Download the MNIST archives if needed and return the four archive paths.

    An empty marker file named ``exists`` inside ``TRAIN_DIR`` records that a
    download has already completed. The archives themselves are not used as
    the signal because a later step may decompress them and delete the
    originals (see ``unzip``).

    On a first run the archives are fetched into ``TRAIN_DIR`` and the two
    ``t10k-*`` (test) archives are then moved into ``TEST_DIR``.

    Returns:
        tuple: ``(train_images_path, train_labels_path, test_images_path,
        test_labels_path)``, each the path of a ``.gz`` archive.
    """
    train_images_name = "train-images-idx3-ubyte.gz"
    train_labels_name = "train-labels-idx1-ubyte.gz"
    test_images_name = "t10k-images-idx3-ubyte.gz"
    test_labels_name = "t10k-labels-idx1-ubyte.gz"

    marker_path = os.path.join(TRAIN_DIR, "exists")
    if not os.path.exists(marker_path):
        print("file not exist, try to download the file...")
        # Called only for its side effect of fetching the archives into
        # TRAIN_DIR; the returned dataset object is not needed here.
        input_data.read_data_sets(TRAIN_DIR, one_hot = True)
        # makedirs (rather than mkdir) also creates missing parent directories.
        if not os.path.isdir(TEST_DIR):
            os.makedirs(TEST_DIR)
        # Move the test archives out of the training directory.
        for archive_name in (test_images_name, test_labels_name):
            os.rename(os.path.join(TRAIN_DIR, archive_name),
                      os.path.join(TEST_DIR, archive_name))
        # Create the empty marker file; nothing is written into it.
        with open(marker_path, 'w'):
            print("create exists")
    else:
        print("file exists")

    return (os.path.join(TRAIN_DIR, train_images_name),
            os.path.join(TRAIN_DIR, train_labels_name),
            os.path.join(TEST_DIR, test_images_name),
            os.path.join(TEST_DIR, test_labels_name))

In [60]:
# Fetch the archives on a first run; prints "file exists" when the marker file is already present.
may_download()

file not exist, try to download the file...
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting ../../../data/Tensorflow_formatted/mnist/train/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting ../../../data/Tensorflow_formatted/mnist/train/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting ../../../data/Tensorflow_formatted/mnist/train/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting ../../../data/Tensorflow_formatted/mnist/train/t10k-labels-idx1-ubyte.gz
create exists


('../../../data/Tensorflow_formatted/mnist/train/train-images-idx3-ubyte.gz',
 '../../../data/Tensorflow_formatted/mnist/train/train-labels-idx1-ubyte.gz',
 '../../../data/Tensorflow_formatted/mnist/test/t10k-images-idx3-ubyte.gz',
 '../../../data/Tensorflow_formatted/mnist/test/t10k-labels-idx1-ubyte.gz')

In [35]:
# Decompress a .gz file in place.
# Not very useful for this dataset: the decompressed result is still the raw
# IDX container, header included, so nothing directly usable comes out.
# => Kept as an exercise in using `gzip`.
def unzip(file_path):
    """Decompress a gzip file next to itself and delete the compressed original.

    The output path is ``file_path`` with its last extension removed
    (``foo.bin.gz`` -> ``foo.bin``).

    Args:
        file_path: Path of the gzip-compressed file.

    Returns:
        str: Path of the decompressed file.

    Raises:
        ValueError: If ``file_path`` has no extension. The output path would
            then equal the input path, and writing it would destroy the
            source before it has been read.
    """
    file_name, extension = os.path.splitext(file_path)
    if not extension:
        raise ValueError(
            "cannot derive an output name from %r: no file extension" % (file_path,))
    with gzip.open(file_path, 'rb') as zip_file:
        with open(file_name, 'wb') as t_file:
            # Copy in fixed-size chunks. Iterating a binary stream "by line"
            # splits at arbitrary b'\n' bytes and may buffer a huge chunk.
            shutil.copyfileobj(zip_file, t_file)
    os.remove(file_path)
    return file_name

In [48]:
# 名字前有一个下划线，惯例表示内部方法
# =>numpy的数据读取，和dtype类型，改为big endian
def _read32(bytestream):
    # 设置dtype，更改endian的类型
    dt = np.dtype(np.uint32).newbyteorder('>')
    # 读取前4个字节 并返回. frombuffer返回数组，取首值
    return np.frombuffer(bytestream.read(4), dtype = dt)[0]

In [79]:
# Extraction has to treat image files and label files differently.
# Image file layout: a 4-byte magic number, then the image count, the row
# count and the column count (each 4 bytes), followed by the pixel data.
# Mind the different integer types involved; they are easy to mix up.
def extract_images(file_path):
    """Strip the header from a gzipped MNIST image file and save the raw pixel bytes.

    The output is written next to the archive, at ``file_path`` with its last
    extension removed. The compressed archive is left in place.

    Args:
        file_path: Path of the gzipped image archive.

    Returns:
        str: Path of the file holding the raw pixel bytes.

    Raises:
        ValueError: If the magic number is not 2051, if the archive holds
            fewer pixel bytes than its header announces, or if neither the
            archive nor a previously extracted file exists.
    """
    file_name, _ = os.path.splitext(file_path)
    if not os.path.exists(file_path):
        # No archive: accept an earlier extraction result, otherwise fail.
        if os.path.exists(file_name):
            print("extracted data already exists.")
            return file_name
        raise ValueError("target file not exists")

    with gzip.open(file_path, 'rb') as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:  # 0x00000803
            raise ValueError("Invalid magic number %d in MNIST image file: %s" %(magic, file_path))
        # Convert to Python ints so the size product cannot wrap around in
        # 32-bit numpy arithmetic.
        num_image = int(_read32(bytestream))
        rows = int(_read32(bytestream))
        cols = int(_read32(bytestream))
        expected_size = rows * cols * num_image
        data = bytestream.read(expected_size)

    # A short read means the archive is truncated; do not write a partial file.
    if len(data) != expected_size:
        raise ValueError("Truncated MNIST image file %s: expected %d bytes, got %d"
                         % (file_path, expected_size, len(data)))
    with open(file_name, 'wb') as f:
        f.write(data)
    return file_name

In [80]:
def extract_labels(file_path):
    """Strip the header from a gzipped MNIST label file and save the raw label bytes.

    The header is a 4-byte magic number followed by a 4-byte item count; that
    many bytes are then read as the label data. The output is written next to
    the archive, at ``file_path`` with its last extension removed. The
    compressed archive is left in place.

    Args:
        file_path: Path of the gzipped label archive.

    Returns:
        str: Path of the file holding the raw label bytes.

    Raises:
        ValueError: If the magic number is not 2049, if the archive holds
            fewer label bytes than its header announces, or if neither the
            archive nor a previously extracted file exists.
    """
    file_name, _ = os.path.splitext(file_path)
    if not os.path.exists(file_path):
        # No archive: accept an earlier extraction result, otherwise fail.
        if os.path.exists(file_name):
            print("extracted data already exists.")
            return file_name
        raise ValueError("target file not exists")

    with gzip.open(file_path, 'rb') as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:  # 0x00000801
            raise ValueError("Invalid magic number %d in MNIST labels file: %s" %(magic, file_path))
        num_items = int(_read32(bytestream))
        buf = bytestream.read(num_items)

    # A short read means the archive is truncated; do not write a partial file.
    if len(buf) != num_items:
        raise ValueError("Truncated MNIST labels file %s: expected %d bytes, got %d"
                         % (file_path, num_items, len(buf)))
    with open(file_name, 'wb') as f:
        f.write(buf)
    return file_name

In [72]:
# Extract the training images; the cell result is the path of the written file.
extract_images("../../../data/Tensorflow_formatted/mnist/train/train-images-idx3-ubyte.gz")

'../../../data/Tensorflow_formatted/mnist/train/train-images-idx3-ubyte'

In [73]:
# Extract the training labels; the cell result is the path of the written file.
extract_labels("../../../data/Tensorflow_formatted/mnist/train/train-labels-idx1-ubyte.gz")

'../../../data/Tensorflow_formatted/mnist/train/train-labels-idx1-ubyte'

In [82]:
def prepare_data():
    """Ensure the MNIST archives are present, then extract all four of them.

    Returns:
        tuple: ``(train_images_path, train_labels_path, test_images_path,
        test_labels_path)`` of the extracted (header-less) data files.
    """
    train_images_gz, train_labels_gz, test_images_gz, test_labels_gz = may_download()
    # Extraction order: train images, train labels, test images, test labels.
    return (
        extract_images(train_images_gz),
        extract_labels(train_labels_gz),
        extract_images(test_images_gz),
        extract_labels(test_labels_gz),
    )

In [83]:
# Run the whole pipeline: download if needed, then extract; the cell result is the four extracted paths.
prepare_data()

file exists


('../../../data/Tensorflow_formatted/mnist/train/train-images-idx3-ubyte',
 '../../../data/Tensorflow_formatted/mnist/train/train-labels-idx1-ubyte',
 '../../../data/Tensorflow_formatted/mnist/test/t10k-images-idx3-ubyte',
 '../../../data/Tensorflow_formatted/mnist/test/t10k-labels-idx1-ubyte')