In [1]:
from lab12_util import *

DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
DEST_DIRECTORY = 'dataset/cifar10'
DATA_DIRECTORY = DEST_DIRECTORY + '/cifar-10-batches-bin'
IMAGE_HEIGHT = 32
IMAGE_WIDTH = 32
IMAGE_DEPTH = 3
IMAGE_SIZE_CROPPED = 24
BATCH_SIZE = 128
NUM_CLASSES = 10 
LABEL_BYTES = 1
IMAGE_BYTES = 32 * 32 * 3
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000

# download it
maybe_download_and_extract(DEST_DIRECTORY, DATA_URL)

>> Done


In [17]:
from tensorflow.contrib.data import FixedLengthRecordDataset, Iterator

def cifar10_record_distort_parser(record):
    ''' Parse the record into label, cropped and distorted image
    -----
    Args:
        record: 
            a record containing label and image.
    Returns:
        label: 
            the label in the record.
        image: 
            the cropped and distorted image in the record.
    '''
    label_bytes = 1
    image_bytes = IMAGE_HEIGHT * IMAGE_WIDTH * IMAGE_DEPTH
    record_bytes = label_bytes + image_bytes

    # Bytes to Vector
    record_vector = tf.decode_raw(record, tf.uint8)

    label = tf.cast(record_vector[0], tf.int32)
    #label = tf.one_hot(label, NUM_CLASSES, tf.int32)

    depth_major = tf.reshape(record_vector[label_bytes:record_bytes], [IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH])

    reshaped_image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
    distorted_image = tf.random_crop(reshaped_image, [IMAGE_SIZE_CROPPED, IMAGE_SIZE_CROPPED, 3])
    distorted_image = tf.image.random_flip_left_right(distorted_image)
    distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)
    
    image = tf.image.per_image_standardization(distorted_image)
    
    
    return label, image
    


def cifar10_record_crop_parser(record):
    ''' Parse the record into label, cropped image
    -----
    Args:
        record: 
            a record containing label and image.
    Returns:
        label: 
            the label in the record.
        image: 
            the cropped image in the record.
    '''
    label_bytes = 1
    image_bytes = IMAGE_HEIGHT * IMAGE_WIDTH * IMAGE_DEPTH
    record_bytes = label_bytes + image_bytes
    
    # Bytes to Vector
    record_vector = tf.decode_raw(record, tf.uint8)
    
    label = tf.cast(record_vector[0], tf.int32)
    #label = tf.one_hot(label, NUM_CLASSES, tf.int32)
    
    depth_major = tf.reshape(
      record_vector[label_bytes:record_bytes], [IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH])
    
    reshaped_image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
    distorted_image = tf.random_crop(reshaped_image, [IMAGE_SIZE_CROPPED, IMAGE_SIZE_CROPPED, 3])
    
    image = tf.image.per_image_standardization(distorted_image)
    
    
    return label, image


def cifar10_iterator(filenames, batch_size, cifar10_record_parser):
    ''' Create a dataset and return a tf.contrib.data.Iterator 
    which provides a way to extract elements from this dataset.
    -----
    Args:
        filenames: 
            a tensor of filenames.
        batch_size: 
            batch size.
    Returns:
        iterator: 
            an Iterator providing a way to extract elements from the created dataset.
        output_types: 
            the output types of the created dataset.
        output_shapes: 
            the output shapes of the created dataset.
    '''
    label_bytes = 1
    image_bytes = IMAGE_HEIGHT * IMAGE_WIDTH * IMAGE_DEPTH
    record_bytes = label_bytes + image_bytes
    dataset = FixedLengthRecordDataset(filenames, record_bytes)
    
    dataset = dataset.map(cifar10_record_parser)
    
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(10)
    iterator = dataset.make_initializable_iterator()
    
    output_types = dataset.output_types
    output_shapes = dataset.output_shapes


    return iterator, output_types, output_shapes

    

In [18]:
tf.reset_default_graph()

training_files = [os.path.join(DATA_DIRECTORY, 'data_batch_%d.bin' % i) for i in range(1, 6)]
testing_files = [os.path.join(DATA_DIRECTORY, 'test_batch.bin')]

filenames_train = tf.constant(training_files)
filenames_test = tf.constant(testing_files)

iterator_train, types, shapes = cifar10_iterator(filenames_train, BATCH_SIZE, cifar10_record_distort_parser)
iterator_test, _, _ = cifar10_iterator(filenames_test, BATCH_SIZE, cifar10_record_crop_parser)

next_batch = iterator_train.get_next()

# use to handle training and testing
handle = tf.placeholder(tf.string, shape=[])
iterator = Iterator.from_string_handle(handle, types, shapes)
labels_images_pairs = iterator.get_next()

# CNN model
model = CNN_Model(
    batch_size=BATCH_SIZE,
    num_classes=NUM_CLASSES,
    num_training_example=NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN,
    num_epoch_per_decay=350.0,
    init_lr=0.1,
    moving_average_decay=0.9999)

with tf.device('/cpu:0'):
    labels, images = labels_images_pairs
    labels = tf.reshape(labels, [BATCH_SIZE])
    images = tf.reshape(images, [BATCH_SIZE, IMAGE_SIZE_CROPPED, IMAGE_SIZE_CROPPED, IMAGE_DEPTH])
with tf.variable_scope('model'):
    logits = model.inference(images)

# train
global_step = tf.contrib.framework.get_or_create_global_step()
total_loss = model.loss(logits, labels)
train_op = model.train(total_loss, global_step)

# test
top_k_op = tf.nn.in_top_k(logits, labels, 1)

In [15]:
%%time
from datetime import datetime
from tqdm import tqdm

NUM_EPOCH = 10
NUM_BATCH_PER_EPOCH = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN // BATCH_SIZE
ckpt_dir = './model/'

config = tf.ConfigProto(allow_soft_placement=True, gpu_options=tf.GPUOptions(allow_growth=True))

# train
saver = tf.train.Saver()
with tf.Session(config=config) as sess:
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    
    if (ckpt and ckpt.model_checkpoint_path):
        saver.restore(sess, ckpt.model_checkpoint_path)
        # assume the name of checkpoint is like '.../model.ckpt-1000'
        gs = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
        sess.run(tf.assign(global_step, gs))
    else:
        # no checkpoint found
        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        loss = []
        
        print("{}: Start training.".format(datetime.now()))
        
    for i in tqdm(range(NUM_EPOCH)):
        _loss = []
        sess.run(iterator_train.initializer)

        for _ in range(NUM_BATCH_PER_EPOCH):
            lbl, img = sess.run(next_batch)
            l, _ = sess.run([total_loss, train_op], feed_dict={images: img, labels: lbl})
            _loss.append(l)
        loss_this_epoch = np.sum(_loss)
        gs = global_step.eval()
        print('{}: Loss of epoch {}: {}'.format(datetime.now(), gs / NUM_BATCH_PER_EPOCH, loss_this_epoch))
        loss.append(loss_this_epoch)
        saver.save(sess, ckpt_dir + 'model.ckpt', global_step=gs)
    coord.request_stop()
    coord.join(threads)
  
print("{}: Done training.".format(datetime.now()))


  0%|          | 0/10 [00:00<?, ?it/s]

2017-11-12 15:31:28.381171: Start training.


 10%|█         | 1/10 [00:11<01:45, 11.75s/it]

2017-11-12 15:31:40.012861: Loss of epoch 1.0: 1504.51416015625


 20%|██        | 2/10 [00:24<01:36, 12.05s/it]

2017-11-12 15:31:52.367995: Loss of epoch 2.0: 1173.6761474609375


 30%|███       | 3/10 [00:35<01:22, 11.74s/it]

2017-11-12 15:32:03.479606: Loss of epoch 3.0: 955.1044921875


 40%|████      | 4/10 [00:46<01:09, 11.59s/it]

2017-11-12 15:32:14.608683: Loss of epoch 4.0: 796.3057861328125


 50%|█████     | 5/10 [00:57<00:57, 11.55s/it]

2017-11-12 15:32:25.993432: Loss of epoch 5.0: 682.933837890625


 60%|██████    | 6/10 [01:09<00:46, 11.53s/it]

2017-11-12 15:32:37.419351: Loss of epoch 6.0: 599.81005859375


 70%|███████   | 7/10 [01:20<00:34, 11.49s/it]

2017-11-12 15:32:48.691900: Loss of epoch 7.0: 540.1226806640625


 80%|████████  | 8/10 [01:32<00:23, 11.53s/it]

2017-11-12 15:33:00.474425: Loss of epoch 8.0: 492.43048095703125


 90%|█████████ | 9/10 [01:43<00:11, 11.55s/it]

2017-11-12 15:33:12.198794: Loss of epoch 9.0: 459.7913818359375


100%|██████████| 10/10 [01:55<00:00, 11.54s/it]

2017-11-12 15:33:23.625325: Loss of epoch 10.0: 432.41925048828125
2017-11-12 15:33:23.749019: Done training.
CPU times: user 3min 29s, sys: 17.5 s, total: 3min 47s
Wall time: 1min 55s





In [20]:
%%time
next_test = iterator_test.get_next()
variables_to_restore = model.ema.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
    # Restore variables from disk.
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        num_iter = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL // BATCH_SIZE
        total_sample_count = num_iter * BATCH_SIZE
        true_count = 0
        sess.run(iterator_test.initializer)
        for _ in tqdm(range(num_iter)):
            lbl, img = sess.run(next_test)
            predictions = sess.run(top_k_op, feed_dict={images: img, labels: lbl})
            true_count += np.sum(predictions)
        print('{}: Accurarcy: {}/{} = {}'.format(datetime.now(), true_count, total_sample_count,
                                     true_count / total_sample_count))
        coord.request_stop()
        coord.join(threads)
    else:
        print("{}: No model existed.".format(datetime.now()))

INFO:tensorflow:Restoring parameters from ./model/model.ckpt-3900


100%|██████████| 78/78 [00:01<00:00, 50.84it/s]

2017-11-12 15:41:53.104205: Accurarcy: 7255/9984 = 0.7266626602564102
CPU times: user 3.19 s, sys: 244 ms, total: 3.43 s
Wall time: 1.72 s



