In [63]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import random
import os, shutil
from tqdm import tqdm
import pickle
import math

from keras.applications.resnet50 import ResNet50
# from keras_applications.resnet import ResNet152
from keras.applications.xception import Xception

from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
from keras.layers import GlobalAveragePooling2D, Dense, Dropout, Flatten
from keras.models import Model, Sequential, load_model
from keras import optimizers
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras import backend, layers, models,utils

import numpy as np
import pandas as pd

# 预处理

In [20]:
data_dir = './{}/'
nb_class = len(os.listdir(data_dir.format('train')))
data = {}
for i in ['train', 'valid']:
    data[i] = {x: os.listdir(data_dir.format(i)+x) for x in os.listdir(data_dir.format(i))}
nb_train_samples = sum([len(data['train'][x]) for x in data['train'].keys()])
nb_valid_samples = sum([len(data['valid'][x]) for x in data['train'].keys()])
nb_test_samples = len(os.listdir(data_dir.format('test')))

In [21]:
print(any([data['valid']['cat'][x] in data['train']['cat'] for x in range(len(data['valid']['cat']))]))
print(any([data['valid']['dog'][x] in data['train']['dog'] for x in range(len(data['valid']['dog']))]))

False
False


## 图像变换

暂时参考以下博文为蓝本：https://zhuanlan.zhihu.com/p/26693647

In [23]:
datagen = {'train': image.ImageDataGenerator(
    preprocessing_function=preprocess_input,
#     rotation_range=30,
#     width_shift_range=0.2,
#     height_shift_range=0.2,
#     shear_range=0.2,
#     zoom_range=0.2,
#     horizontal_flip=True,
#     vertical_flip=True
),
           'valid':image.ImageDataGenerator(
    preprocessing_function=preprocess_input
),
           'test': image.ImageDataGenerator(
    preprocessing_function=preprocess_input
)
          }

用了`preprocess_input()`就不需要`rescale`参数了。

https://stackoverflow.com/questions/47555829/preprocess-input-method-in-keras

In [36]:
im_width, im_height = 224, 224
# im_width, im_height = 299, 299
batch_size = 64
seed = 123

# generator = {x: datagen[x].flow_from_directory(
#     data_dir.format(x),
#     target_size=(im_width, im_height),
#     batch_size=batch_size,
#     seed = 123,
#     class_mode = 'binary',
#     shuffle = True
# ) for x in list(datagen.keys())[:2]}

generator = {}

generator['train'] = datagen['train'].flow_from_directory(data_dir.format('train'),
                                                          target_size=(im_width, im_height),
                                                          batch_size=batch_size,
                                                          seed = seed,
                                                          class_mode = 'binary',
                                                          shuffle = True)
generator['valid'] = datagen['valid'].flow_from_directory(data_dir.format('valid'),
                                                          target_size=(im_width, im_height),
                                                          batch_size=batch_size,
                                                          seed = seed,
                                                          class_mode = 'binary',
                                                          shuffle = True)
generator['test'] = datagen['test'].flow_from_directory(data_dir.format('test'),
                                                          target_size=(im_width, im_height),
                                                          batch_size=batch_size,
                                                          seed = seed,
                                                          class_mode = 'binary',
                                                          shuffle = False)

Found 19896 images belonging to 2 classes.
Found 4974 images belonging to 2 classes.
Found 0 images belonging to 0 classes.


## 载入模型

载入模型并排除顶部的全连接层。

In [66]:
model_base = ResNet50(weights='imagenet', include_top=False, input_shape = (im_width, im_height, 3))
# model_base = ResNet152(include_top = False, weights = 'imagenet', backend = backend, layers = layers, models = models, utils = utils,
#                        input_shape = (im_width, im_height, 3))
# model_base = Xception(weights='imagenet', include_top=False, input_shape = (im_width, im_height, 3))

In [67]:
# model_base.summary()

添加自己的层：

In [68]:
model = Sequential()
model.add(model_base)
# model.add(Flatten())
model.add(GlobalAveragePooling2D())
# model.add(Dense(1024, activation='relu'))
# model.add(Dropout(0.75))
# model.add(Dense(500, activation='relu'))
model.add(Dropout(0.75))
model.add(Dense(1, activation='sigmoid'))

In [69]:
# model.summary()

查看冻结层前后的可训练层数：

In [70]:
print('Number of trainable weights befor freezing the model_base:', len(model.trainable_weights))
model_base.trainable = False
print('Number of trainable weights after freezing the model_base:', len(model.trainable_weights))

Number of trainable weights befor freezing the model_base: 214
Number of trainable weights after freezing the model_base: 2


编译模型：

In [71]:
# lr = 0.0005
# from keras.utils import multi_gpu_model
# model = multi_gpu_model(model_base, gpus=8)
model.compile(loss = "binary_crossentropy", optimizer = optimizers.Adam(), metrics=["accuracy"])

训练模型：

In [None]:
epochs = 10
history = model.fit_generator(generator['train'],
                              steps_per_epoch=math.ceil(nb_train_samples / batch_size),
                              epochs=epochs,
                              validation_data=generator['valid'],
                              validation_steps=math.ceil(nb_valid_samples / batch_size))

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10

保存模型checkpoint：

In [None]:
model.save_weights('model_binary_wieghts.h5')
model.save('model_binary.h5')

# 重新载入模型

In [None]:
model = load_model('model_binary.h5')

可视化：

In [None]:
#get the details form the history object
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(1, len(acc) + 1)

#Train and validation accuracy
plt.figure(figsize = [15, 7])
plt.subplot(1,2,1)
plt.plot(epochs, acc, 'b', label='Training accurarcy')
plt.plot(epochs, val_acc, 'r', label='Validation accurarcy')
plt.title('Training and Validation accurarcy')
plt.legend()

plt.subplot(1,2,2)
#Train and validation loss
plt.plot(epochs, loss, 'b', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and Validation loss')
plt.legend()
plt.savefig('loss and acc.jpg', bbox_inches = 'tight')
plt.show()

# 预测

In [None]:
# test_imgs = os.listdir(data_dir.format('test'))
# img_path = data_dir.format('test') + test_imgs[3]
# # img_path = './test04.jpg'
# img = image.load_img(img_path, target_size=(im_width, im_height))
# x = image.img_to_array(img)
# x = np.expand_dims(x, axis=0)
# x = preprocess_input(x)

# preds = model.predict(x)[0][0]
# print(preds)
# img_show = mpimg.imread(img_path)
# plt.imshow(img_show)
# plt.title('This is a {}'.format('dog' if preds > 0.5 else 'cat'))
# plt.suptitle('probability: {} percent'.format(round(preds, 3)*100) if preds > 0.5 else round(1-preds, 3)*100);

全部预测，制作csv表

In [None]:
test_imgs = os.listdir(data_dir.format('test'))
ids = []
label = []
for i in tqdm(test_imgs):
    img = image.load_img(data_dir.format('test') + i, target_size=(im_width, im_height))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    label.append(model.predict(x)[0][0])
    ids.append(int(i.split('.')[0]))

In [None]:
sub = pd.DataFrame({'id': ids, 'label': label}).sort_values('id',axis = 0, ascending = True)
sub.to_csv('submission.csv', index = False)

In [None]:
pred = model.predict_generator(generator['test'],
                        steps=math.ceil(nb_test_samples / batch_size),
                       verbose = 1)

# 参考资料

+ https://zhuanlan.zhihu.com/p/26693647
+ https://medium.com/@14prakash/transfer-learning-using-keras-d804b2e04ef8
+ https://www.kaggle.com/risingdeveloper/transfer-learning-in-keras-on-dogs-vs-cats