In [1]:
import os
import sys

# Make the parent of the current working directory importable; presumably this is
# where the project-local packages imported below (`dataset`, `CNN`) live — verify.
# Use `sys` directly rather than the undocumented `os.sys` alias, and guard the
# append so re-running this cell does not keep adding duplicate entries.
_project_root = os.path.dirname(os.path.abspath('.'))
if _project_root not in sys.path:
    sys.path.append(_project_root)

TensorFlow中自带了实现数据增强的API，主要分为四类：
- Resizing
- Cropping
- Flipping and Transposing
- Image Adjustments

In [2]:
import tensorflow as tf
import matplotlib.pyplot as plt
from dataset.dataset import load_cifar10
import numpy as np
from CNN.mini_CNN import mini_CNN

# Session configuration: let TensorFlow grow GPU memory on demand
# instead of reserving all of it when the session starts.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))

  from ._conv import register_converters as _register_converters


# 数据准备
这里从CIFAR-10中取出五张图片（下方代码中`batch_size = 5`），模拟训练过程中的一个batch。

In [3]:
# train_data, test_data = load_cifar10(batch_size=64)
# for batch_data, _ in train_data.next_batch():
#     img_batch = batch_data
#     break
# del train_data, test_data

# batch_size = 5
# img_batch = img_batch[:batch_size].reshape(
#     (-1, 3, 32, 32)).transpose((0, 2, 3, 1))

In [4]:
# plt.clf()
# fig, axs = plt.subplots(1, batch_size, figsize=(10, 2))
# for i in range(batch_size):
#     axs[i].imshow(img_batch[i])
# plt.show()

## Resizing
```tf.image.resize_images```同时支持batch输入与单张图片输入，有四种插值方法：BILINEAR、NEAREST_NEIGHBOR、BICUBIC和AREA。除了NEAREST_NEIGHBOR会保持输入的数据类型外，另外三种方法的输出都是```float32```格式，所以为了保证前后格式一致，在缩放图片之前最好先将图片转换成浮点格式。同样的，标准的RGB三通道图片格式为```uint8```，所以在可视化或保存图片时还要再做转换。

不过由于在```CifarData```这个类中我们使用了```MinMaxScaler```这一缩放模式，所以无需担心格式问题。

In [5]:
# with tf.name_scope('img_resize'):
#     img = tf.image.resize_images(img_batch, (48, 48), method=0)

# with tf.Session(config=config) as sess:
#     res = sess.run(img)

# plt.clf()
# fig, axs = plt.subplots(1, batch_size, figsize=(10, 2))
# for i in range(batch_size):
#     axs[i].imshow(res[i])
# plt.show()

## Cropping
- ```tf.image.pad_to_bounding_box(image, offset_height, offset_width, target_height, target_width)```：边缘填充
- ```tf.image.crop_to_bounding_box(image, offset_height, offset_width, target_height, target_width)```：裁剪
- ```tf.image.random_crop(image, size, seed=None, name=None)```：随机裁剪，需指定所有维度

In [6]:
# with tf.name_scope('img_crop'):
#     img = tf.image.random_crop(img_batch, size=(img_batch.shape[0], 25, 25, 3))

# with tf.Session(config=config) as sess:
#     res = sess.run(img)

# plt.clf()
# fig, axs = plt.subplots(1, batch_size, figsize=(10, 2))
# for i in range(batch_size):
#     axs[i].imshow(res[i])
# plt.show()

In [7]:
# with tf.name_scope('img_crop'):
#     img = tf.image.pad_to_bounding_box(img_batch, offset_height=2, offset_width=4,
#                                        target_height=36, target_width=40)

# with tf.Session(config=config) as sess:
#     res = sess.run(img)

# plt.clf()
# fig, axs = plt.subplots(1, batch_size, figsize=(10, 2))
# for i in range(batch_size):
#     axs[i].imshow(res[i])
# plt.show()

## Flipping and Transposing
- ```tf.image.random_flip_up_down(image)```
- ```tf.image.random_flip_left_right(image)```
- ```tf.image.transpose_image(image)```

In [8]:
# with tf.name_scope('img_crop'):
#     img = tf.image.random_flip_up_down(img_batch)

# with tf.Session(config=config) as sess:
#     res = sess.run(img)

# plt.clf()
# fig, axs = plt.subplots(1, batch_size, figsize=(10, 2))
# for i in range(batch_size):
#     axs[i].imshow(res[i])
# plt.show()

## Adjustments
- ```tf.image.adjust_brightness(image, delta, min_value=None, max_value=None)```
- ```tf.image.random_brightness(image, max_delta, seed=None)```
- ```tf.image.adjust_contrast(images, contrast_factor, min_value=None, max_value=None)```
- ```tf.image.random_contrast(image, lower, upper, seed=None)```

In [9]:
# with tf.name_scope('img_crop'):
#     img = tf.image.random_brightness(img_batch, max_delta=0.4)
#     img = tf.clip_by_value(img, 0, 1)

# with tf.Session(config=config) as sess:
#     res = sess.run(img)

# plt.clf()
# fig, axs = plt.subplots(1, batch_size, figsize=(10, 2))
# for i in range(batch_size):
#     axs[i].imshow(res[i])
# plt.show()

In [10]:
# with tf.name_scope('img_crop'):
#     img = tf.image.random_contrast(img_batch, lower=0.2,upper=1.8)
#     img = tf.clip_by_value(img, 0, 1)

# with tf.Session(config=config) as sess:
#     res = sess.run(img)

# plt.clf()
# fig, axs = plt.subplots(1, batch_size, figsize=(10, 2))
# for i in range(batch_size):
#     axs[i].imshow(res[i])
# plt.show()

# 实时增强
做数据增强有两种方法：第一种是在训练之前对所有数据做增强，相当于一次性增大了数据集的总量；另一种方法是在训练阶段进行实时增强，这里演示第二种方法。同时考虑到数据增强会使得神经网络需要更大的迭代次数去学习，这里只使用翻转。

注意，数据增强应该只应用于训练集，即只有训练过程才需要数据增强，测试时应使用原始图片。这需要使用TensorFlow的control flow（```tf.cond```）去实现。

In [11]:
batch_size = 32
train_data, test_data = load_cifar10(batch_size=batch_size)

# Graph inputs: one flattened image per row (`n_features` values) and integer labels.
unit_I = train_data.n_features
X = tf.placeholder(dtype=tf.float32, shape=[None, unit_I])
Y = tf.placeholder(dtype=tf.int64, shape=[None])

# Each flattened row is laid out channel-first (3, 32, 32); tf.image ops work on
# [batch, height, width, channels], so reshape to NCHW and move channels last.
X_img = tf.transpose(
    tf.reshape(X, shape=[-1, 3, 32, 32]),
    perm=[0, 2, 3, 1],
)


def data_aug(X_img):
    """Build the training-time augmentation op for a batch of images.

    Only a random left-right flip is applied here.

    Args:
        X_img: the original image tensor (built above as NHWC).

    Returns:
        The tensor produced by ``tf.image.random_flip_left_right``.
    """
    return tf.image.random_flip_left_right(X_img)


# Run-time switch for augmentation: feed 1 while training, 0 while evaluating.
# tf.cond needs a *scalar tensor* predicate (a plain Python bool is rejected), hence a
# placeholder. It stays int16 so existing `is_training: 1/0` feeds keep working.
# - placeholder_with_default: evaluation code may omit the feed; it then defaults to 0
#   (no augmentation).
# - shape=[]: a non-scalar feed is rejected when fed instead of failing inside tf.cond.
is_training = tf.placeholder_with_default(
    tf.constant(0, dtype=tf.int16), shape=[], name='is_training')
X_img_trans = tf.cond(is_training > 0,
                      true_fn=lambda: data_aug(X_img),
                      false_fn=lambda: X_img)

# Forward pass: the (possibly augmented) image batch goes through the project's mini_CNN.
with tf.name_scope('CNN'):
    logits = mini_CNN(X_img_trans, activation=tf.nn.leaky_relu)

# Loss and metrics. Y holds integer class ids, hence the sparse cross-entropy.
with tf.name_scope('Eval'):
    loss = tf.losses.sparse_softmax_cross_entropy(labels=Y, logits=logits)
    predict = tf.argmax(logits, 1)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(predict, Y), tf.float32))

with tf.name_scope('train_op'):
    lr = 1e-3    # Adam learning rate
    train_op = tf.train.AdamOptimizer(lr).minimize(loss)

# Created after the optimizer so that Adam's own variables are initialized as well.
init = tf.global_variables_initializer()
# Sessions reuse the `config` (GPU allow_growth) built in the setup cell;
# no need to construct an identical ConfigProto again here.



(50000, 3072) (50000,)




(10000, 3072) (10000,)
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use keras.layers.conv2d instead.
Instructions for updating:
Use keras.layers.max_pooling2d instead.
Instructions for updating:
Use keras.layers.flatten instead.
Instructions for updating:
Use keras.layers.dense instead.
Instructions for updating:
Use tf.cast instead.


# 训练网络

In [12]:
LOG_EVERY = 1000     # print training-batch loss/accuracy every N batches
EVAL_EVERY = 5000    # evaluate on the whole test set every N batches
# NOTE(review): the recorded output below stops at epoch 10, so it appears to come
# from a shorter run than `epochs = 20`; re-run the notebook to refresh it.
epochs = 20


def evaluate_accuracy(sess, dataset):
    """Return the sample-weighted accuracy of the current model over `dataset`.

    Iterates every batch yielded by ``dataset.next_batch()`` with augmentation
    disabled (``is_training`` fed as 0). Batch accuracies are weighted by batch
    size so that a smaller final batch (if the loader yields one) does not skew
    the mean. Returns NaN if the dataset yields no batches.
    """
    n_correct, n_seen = 0.0, 0
    for eval_x, eval_y in dataset.next_batch():
        batch_acc = sess.run(accuracy,
                             feed_dict={X: eval_x, Y: eval_y, is_training: 0})
        n_correct += batch_acc * len(eval_y)
        n_seen += len(eval_y)
    return n_correct / n_seen if n_seen else float('nan')


with tf.Session(config=config) as sess:
    sess.run(init)

    batch_cnt = 0    # global batch counter across all epochs (1-based after increment)
    for epoch in range(epochs):
        for batch_data, batch_labels in train_data.next_batch():
            batch_cnt += 1
            loss_val, acc_val, _ = sess.run(
                [loss, accuracy, train_op],
                feed_dict={X: batch_data, Y: batch_labels, is_training: 1})  # augmentation on

            # batch_cnt is already incremented, so test it directly; the previous
            # `(batch_cnt + 1) % N` fired one batch early (at 999, 4999, ...).
            if batch_cnt % LOG_EVERY == 0:
                print('epoch: {}, batch_loss: {}, batch_acc: {}'.format(
                    epoch + 1, loss_val, acc_val))

            if batch_cnt % EVAL_EVERY == 0:
                print('epoch: {}, test_acc: {}'.format(
                    epoch + 1, evaluate_accuracy(sess, test_data)))

    # The periodic check generally does not land on the last step, so also
    # report the accuracy of the fully trained model.
    print('final test_acc: {}'.format(evaluate_accuracy(sess, test_data)))

epoch: 1, batch_loss: 1.5107368230819702, batch_acc: 0.375
epoch: 2, batch_loss: 1.273892879486084, batch_acc: 0.46875
epoch: 2, batch_loss: 1.3984922170639038, batch_acc: 0.59375
epoch: 3, batch_loss: 0.9320375919342041, batch_acc: 0.75
epoch: 4, batch_loss: 1.0769767761230469, batch_acc: 0.5625
epoch: 4, test_acc: 0.651442289352417
epoch: 4, batch_loss: 1.3465605974197388, batch_acc: 0.53125
epoch: 5, batch_loss: 0.6441978216171265, batch_acc: 0.75
epoch: 6, batch_loss: 1.0151033401489258, batch_acc: 0.5625
epoch: 6, batch_loss: 0.7728836536407471, batch_acc: 0.75
epoch: 7, batch_loss: 0.9682464003562927, batch_acc: 0.71875
epoch: 7, test_acc: 0.6898036599159241
epoch: 8, batch_loss: 0.6210789084434509, batch_acc: 0.75
epoch: 8, batch_loss: 0.4844314455986023, batch_acc: 0.8125
epoch: 9, batch_loss: 0.6945263147354126, batch_acc: 0.71875
epoch: 9, batch_loss: 0.7143003940582275, batch_acc: 0.71875
epoch: 10, batch_loss: 0.9262591600418091, batch_acc: 0.71875
epoch: 10, test_acc: 0.71

这里仅应用了左右翻转，测试集ACC就已经有所上升。