In [13]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import math
import os
import random
import time

import numpy as np
import tensorflow as tf

In [17]:
# Command-line flags for the training hyper-parameters and checkpointing.
# NOTE(review): defining the same flag a second time raises (see the
# ArgumentError output recorded for this cell), so this cell only runs cleanly
# once per kernel -- restart the kernel before re-running it.
tf.app.flags.DEFINE_float("learning_rate", 1e-4, "Learning rate.")
tf.app.flags.DEFINE_integer("batch_size", 100, "Batch size.")
# Directory that create_model restores from and train writes checkpoints to.
tf.app.flags.DEFINE_string("train_dir", "./", "train directory")
# Number of training steps between progress reports / checkpoint saves.
tf.app.flags.DEFINE_integer("steps_per_checkpoint", 100, "steps per checkpoint")

ArgumentError: argument --learning_rate: conflicting option string: --learning_rate

In [4]:
# Parsed flag values; create_model and train read batch_size, learning_rate,
# train_dir and steps_per_checkpoint from this object.
FLAGS = tf.app.flags.FLAGS

In [12]:
def read_data(images_path, labels_path):
    """Load an MNIST image file and its label file into a two-list data set.

    Args:
        images_path: path to a gzip-compressed MNIST image file.
        labels_path: path to the matching gzip-compressed MNIST label file.

    Returns:
        ``[images, labels]``: two parallel Python lists, where ``images[i]`` is
        the i-th image array produced by ``_extract_images`` and ``labels[i]``
        is its label.

    Raises:
        ValueError: if either file cannot be read by the extractors, or if the
            two files contain a different number of examples.
    """
    images_data = _extract_images(images_path)
    labels_data = _extract_labels(labels_path)
    # A count mismatch means the files do not belong together; zip() would
    # silently drop the surplus examples instead of reporting it.
    if len(images_data) != len(labels_data):
        raise ValueError(
            "image/label count mismatch: %d images in %s but %d labels in %s"
            % (len(images_data), images_path, len(labels_data), labels_path))
    # Plain lists (not arrays) because get_batch pops examples off them.
    return [list(images_data), list(labels_data)]

In [21]:
def read_data_with_validation(images_path, labels_path, validation_size):
    """Load an MNIST data set and hold out its first examples for validation.

    Args:
        images_path: path to the gzip-compressed MNIST image file.
        labels_path: path to the matching gzip-compressed MNIST label file.
        validation_size: number of examples, taken from the front of the data
            set, to put in the validation split. Must be between 0 and 30% of
            the total number of examples.

    Returns:
        ``(validation_data_set, train_data_set)``, each ``[images, labels]``.

    Raises:
        ValueError: if ``validation_size`` is negative or larger than 30% of
            the data set (previously this surfaced as an UnboundLocalError).
    """
    data_set = read_data(images_path, labels_path)
    images, labels = data_set[0], data_set[1]
    max_validation = len(images) * 0.3
    if validation_size < 0 or validation_size > max_validation:
        raise ValueError(
            "validation_size must be between 0 and %g (30%% of %d examples), got %r"
            % (max_validation, len(images), validation_size))
    # List slices are fresh copies, so the two splits do not share storage.
    validation_data_set = [images[:validation_size], labels[:validation_size]]
    train_data_set = [images[validation_size:], labels[validation_size:]]
    return (validation_data_set, train_data_set)

In [6]:
# 名字前有一个下划线，惯例表示内部方法
# =>numpy的数据读取，和dtype类型，改为big endian
def _read32(bytestream):
    # 设置dtype，更改endian的类型
    dt = np.dtype(np.uint32).newbyteorder('>')
    # 读取前4个字节 并返回. frombuffer返回数组，取首值
    return np.frombuffer(bytestream.read(4), dtype = dt)[0]

In [9]:
# Image and label files have different headers, so each gets its own extractor.
def _extract_images(file_path):
    """Read a gzip-compressed MNIST image file into a uint8 array.

    The header is four big-endian uint32 values -- magic number (2051), image
    count, rows, cols -- followed by the pixel data, one byte per pixel.

    Args:
        file_path: path to the gzip-compressed image file.

    Returns:
        ndarray of dtype uint8 with shape ``(num_images, rows, cols, 1)``.

    Raises:
        ValueError: if the file does not exist, its magic number is wrong, or
            it holds fewer pixel bytes than the header declares.
    """
    if not os.path.exists(file_path):
        raise ValueError("target file does not exists")
    with gzip.open(file_path, 'rb') as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError("Invalid magic number %d in MNIST image file: %s" %(magic, file_path))
        # Convert to Python ints so the size product below cannot wrap
        # around in uint32 arithmetic.
        num_images = int(_read32(bytestream))
        rows = int(_read32(bytestream))
        cols = int(_read32(bytestream))
        expected_bytes = num_images * rows * cols
        buf = bytestream.read(expected_bytes)
        if len(buf) != expected_bytes:
            raise ValueError(
                "Truncated MNIST image file %s: expected %d pixel bytes, got %d"
                % (file_path, expected_bytes, len(buf)))
        data = np.frombuffer(buf, dtype=np.uint8)
        # Trailing axis of size 1 = single (grayscale) channel.
        return data.reshape(num_images, rows, cols, 1)

In [10]:
def _extract_labels(file_path):
    """Read a gzip-compressed MNIST label file into a 1-D uint8 array.

    The header is two big-endian uint32 values -- magic number (2049) and
    item count -- followed by one byte per label.

    Args:
        file_path: path to the gzip-compressed label file.

    Returns:
        ndarray of dtype uint8 with one entry per label.

    Raises:
        ValueError: if the file does not exist, its magic number is wrong, or
            it holds fewer label bytes than the header declares.
    """
    if not os.path.exists(file_path):
        raise ValueError("target file does not exists")
    with gzip.open(file_path, 'rb') as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError("Invalid magic number %d in MNIST labels file: %s" %(magic, file_path))
        num_items = int(_read32(bytestream))
        buf = bytestream.read(num_items)
        # A short read would otherwise silently yield fewer labels than images.
        if len(buf) != num_items:
            raise ValueError(
                "Truncated MNIST labels file %s: expected %d labels, got %d"
                % (file_path, num_items, len(buf)))
        return np.frombuffer(buf, dtype=np.uint8)

In [16]:
def get_batch(data_set, batch_size, shuffle = False):
    """Remove and return a batch of examples from ``data_set``.

    ``data_set`` is ``[images, labels]`` with two parallel Python lists. The
    examples placed in the batch are popped from those lists, so repeated
    calls walk through the data without repeating an example.

    Args:
        data_set: ``[images, labels]``; both lists are mutated.
        batch_size: number of examples to take.
        shuffle: if False, take the last ``batch_size`` examples (the last
            one first); if True, take a random subset of that size.

    Returns:
        ``[images, labels]`` for the batch, with image/label pairs kept
        aligned. When fewer than ``batch_size`` examples remain, ``data_set``
        itself is returned unchanged (nothing is popped).
    """
    images, labels = data_set[0], data_set[1]
    count = len(images)
    if count < batch_size:
        return data_set
    if shuffle:
        # Pop from the highest index down: popping a lower index first would
        # shift the later positions, pairing the wrong image with a label or
        # indexing past the end of the shrunken list.
        indices = sorted(random.sample(range(count), batch_size), reverse=True)
    else:
        # Same order as repeated pop(): last element first.
        indices = range(count - 1, count - 1 - batch_size, -1)
    batch = [[], []]
    for index in indices:
        batch[0].append(images.pop(index))
        batch[1].append(labels.pop(index))
    return batch

In [18]:
def create_model(session):
    """Build the CNN model, restoring its parameters from a checkpoint if one exists.

    Args:
        session: the TensorFlow session the model's variables are run in.

    Returns:
        The ``CNNModel`` instance, with its parameters either restored from
        the latest checkpoint in ``FLAGS.train_dir`` or freshly initialized.
    """
    model = CNNModel(FLAGS.batch_size, FLAGS.learning_rate)
    # Resume from the latest checkpoint when one is present; otherwise start fresh.
    ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        print("Reading model parameters from %s" %(ckpt.model_checkpoint_path))
        model.saver.restore(session, ckpt.model_checkpoint_path)
    else:
        print("Initialize the parameters")
        session.run(tf.global_variables_initializer())
    return model

In [None]:
def train(num_epochs=None):
    """Train the CNN on the MNIST training set, checkpointing periodically.

    Batches of ``FLAGS.batch_size`` shuffled examples are drawn without
    replacement from a per-epoch copy of the training data. Once every example
    has been used, a new epoch starts. Every ``FLAGS.steps_per_checkpoint``
    steps, the window's average step time and perplexity are printed and a
    checkpoint is written to ``FLAGS.train_dir``.

    Args:
        num_epochs: number of passes over the training data. ``None`` (the
            default) keeps training until interrupted, as the original
            ``while True`` loop intended.

    Raises:
        ValueError: if the training data set is empty.
    """
    # prepare_data (defined elsewhere) yields the four MNIST file paths.
    train_images_path, train_labels_path, test_images_path, test_labels_path = prepare_data()

    with tf.Session() as sess:
        # Create (or restore) the model.
        model = create_model(sess)
        # Read data.
        train_data_set = read_data(train_images_path, train_labels_path)
        # NOTE(review): the test set is loaded but never evaluated here.
        test_data_set = read_data(test_images_path, test_labels_path)
        if not train_data_set[0]:
            raise ValueError("training data set is empty")
        # get_batch pops the examples it returns, so each epoch draws from its
        # own copy of the training lists.
        epoch_data_set = [[], []]
        epochs_started = 0
        # Start training.
        step_time, loss = 0.0, 0.0
        current_step = 0
        while True:
            if not epoch_data_set[0]:
                if num_epochs is not None and epochs_started >= num_epochs:
                    break
                epoch_data_set = [list(train_data_set[0]), list(train_data_set[1])]
                epochs_started += 1
            batch_set = get_batch(epoch_data_set, FLAGS.batch_size, True)
            if batch_set is epoch_data_set:
                # Fewer than batch_size examples were left, so get_batch
                # returned the data set itself without popping anything: use
                # it as this epoch's final (partial) batch instead of
                # repeating it forever.
                epoch_data_set = [[], []]
            start_time = time.time()
            # NOTE(review): assumes model.step returns a scalar loss for the batch.
            step_loss = model.step(sess, batch_set[0], batch_set[1])
            # Accumulate per-window averages of step time and loss.
            step_time += (time.time() - start_time) / FLAGS.steps_per_checkpoint
            loss += step_loss / FLAGS.steps_per_checkpoint
            current_step += 1

            # Save the checkpoint, then zero the timer and loss.
            if current_step % FLAGS.steps_per_checkpoint == 0:
                # Report inf for very large losses instead of exponentiating them.
                perplexity = math.exp(float(loss)) if loss < 300 else float("inf")
                print("global step %d learning rate %.4f step_time %.2f perplexity "
                      "%.2f" %(model.global_step.eval(), model.learning_rate.eval(), step_time, perplexity))
                checkpoint_path = os.path.join(FLAGS.train_dir, "runner.ckpt")
                model.saver.save(sess, checkpoint_path, global_step = model.global_step)
                step_time, loss = 0.0, 0.0