In [40]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import random
import os, shutil
from tqdm import tqdm
import pickle

from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
from keras.layers import GlobalAveragePooling2D, Dense, Dropout, Flatten
from keras.models import Model, Sequential
from keras import optimizers
from keras.callbacks import ModelCheckpoint, EarlyStopping

import numpy as np
import pandas as pd

# 预处理

In [2]:
# Path template for the dataset splits; '{}' is replaced by 'train' or 'valid'.
data_dir = './{}/'

# data[split][class_name] -> list of file names found in that class folder.
# Only sub-directories are treated as classes, so stray files in a split
# folder (e.g. '.DS_Store') neither crash the listing nor inflate nb_class.
data = {}
for split in ['train', 'valid']:
    split_dir = data_dir.format(split)
    data[split] = {
        class_name: os.listdir(os.path.join(split_dir, class_name))
        for class_name in sorted(os.listdir(split_dir))
        if os.path.isdir(os.path.join(split_dir, class_name))
    }

nb_class = len(data['train'])
# Each split is counted from its own class folders (the validation count used
# to iterate over the *training* class names).
nb_train_samples = sum(len(files) for files in data['train'].values())
nb_valid_samples = sum(len(files) for files in data['valid'].values())

## 图像变换

暂时参考以下博文为蓝本：https://zhuanlan.zhihu.com/p/26693647

In [3]:
# One ImageDataGenerator per split, keyed by split name.
datagen = {}

# Training images: ResNet50 preprocessing plus random geometric augmentation
# (rotation, shifts, shear, zoom and flips along both axes).
datagen['train'] = image.ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
)

# Validation images: ResNet50 preprocessing only, no augmentation.
datagen['valid'] = image.ImageDataGenerator(
    preprocessing_function=preprocess_input,
)

用了`preprocess_input()`就不需要`rescale`参数了。

https://stackoverflow.com/questions/47555829/preprocess-input-method-in-keras

In [77]:
# Input resolution fed to the network and the mini-batch size.
im_width, im_height = 224, 224
batch_size = 64

# Build one directory iterator per split from the matching ImageDataGenerator.
# class_mode='binary' yields one scalar label per image.
# NOTE(review): Keras documents target_size as (height, width); the order used
# here only works unchanged because the image is square - confirm before
# switching to a non-square size (the model's input_shape uses the same order).
generator = {
    split: split_datagen.flow_from_directory(
        data_dir.format(split),
        target_size=(im_width, im_height),
        batch_size=batch_size,
        seed=0,
        class_mode='binary',
    )
    for split, split_datagen in datagen.items()
}

Found 16662 images belonging to 2 classes.
Found 8208 images belonging to 2 classes.


## 载入模型

载入模型并排除顶部的全连接层。

In [54]:
# Convolutional base: ResNet50 initialised with ImageNet weights and built
# without its top layers (include_top=False), so it outputs feature maps that
# a custom head can be attached to.
model_base = ResNet50(weights='imagenet', include_top=False, input_shape = (im_width, im_height, 3))



In [55]:
# Print the layer-by-layer structure of the convolutional base.
model_base.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            (None, 224, 224, 3)  0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           input_4[0][0]                    
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 112, 112, 64) 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 112, 112, 64) 256         conv1[0][0]                      
__________________________________________________________________________________________________
activation

activation_157 (Activation)     (None, 56, 56, 256)  0           add_51[0][0]                     
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, 28, 28, 128)  32896       activation_157[0][0]             
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_158 (Activation)     (None, 28, 28, 128)  0           bn3a_branch2a[0][0]              
__________________________________________________________________________________________________
res3a_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_158[0][0]             
__________________________________________________________________________________________________
bn3a_branc

__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3d_branch2c[0][0]             
__________________________________________________________________________________________________
add_55 (Add)                    (None, 28, 28, 512)  0           bn3d_branch2c[0][0]              
                                                                 activation_166[0][0]             
__________________________________________________________________________________________________
activation_169 (Activation)     (None, 28, 28, 512)  0           add_55[0][0]                     
__________________________________________________________________________________________________
res4a_branch2a (Conv2D)         (None, 14, 14, 256)  131328      activation_169[0][0]             
__________________________________________________________________________________________________
bn4a_branc

__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_180 (Activation)     (None, 14, 14, 256)  0           bn4d_branch2b[0][0]              
__________________________________________________________________________________________________
res4d_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_180[0][0]             
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4d_branch2c[0][0]             
__________________________________________________________________________________________________
add_59 (Add)                    (None, 14, 14, 1024) 0           bn4d_branch2c[0][0]              
          

__________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizati (None, 7, 7, 512)    2048        res5b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_191 (Activation)     (None, 7, 7, 512)    0           bn5b_branch2a[0][0]              
__________________________________________________________________________________________________
res5b_branch2b (Conv2D)         (None, 7, 7, 512)    2359808     activation_191[0][0]             
__________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizati (None, 7, 7, 512)    2048        res5b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_192 (Activation)     (None, 7, 7, 512)    0           bn5b_branch2b[0][0]              
__________

冻结不需要训练的层：

In [56]:
# Mark every layer of the convolutional base as non-trainable so that only the
# layers added on top of it are updated during training.
for base_layer in model_base.layers:
    setattr(base_layer, 'trainable', False)

**TBD**：参考资料中的做法似乎只冻结前 5 层，原因待查；上面的代码与花卉项目一样，冻结了 `model_base` 的所有层。

添加自己的层：

In [62]:
# Full classifier: the ResNet50 base followed by a small fully connected head
# that ends in a single sigmoid unit (binary output).
model = Sequential([
    model_base,
    Flatten(),                       # feature maps -> one vector per image
    Dense(1024, activation='relu'),
    Dropout(0.5),
    Dense(50, activation='relu'),
    Dense(1, activation='sigmoid'),  # probability of the positive class
])

# model.fit(X_data, Y_labels) # here your labels have shape (data_size,).

In [63]:
# Print the structure and parameter counts of the assembled model.
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
resnet50 (Model)             (None, 7, 7, 2048)        23587712  
_________________________________________________________________
flatten_3 (Flatten)          (None, 100352)            0         
_________________________________________________________________
dense_28 (Dense)             (None, 1024)              102761472 
_________________________________________________________________
dropout_10 (Dropout)         (None, 1024)              0         
_________________________________________________________________
dense_29 (Dense)             (None, 50)                51250     
_________________________________________________________________
dense_30 (Dense)             (None, 1)                 51        
Total params: 126,400,485
Trainable params: 102,812,773
Non-trainable params: 23,587,712
_____________________________________________________

In [66]:
# Sanity check: count the trainable weight tensors before and after setting
# model_base.trainable = False.
# NOTE(review): the recorded output shows 6 both times, presumably because the
# layers of model_base were already frozen one by one in an earlier cell, so
# only the Dense layers of the head contribute (kernel + bias each) - confirm.
print('Number of trainable weights befor freezing the model_base:', len(model.trainable_weights))
model_base.trainable = False
print('Number of trainable weights after freezing the model_base:', len(model.trainable_weights))

Number of trainable weights befor freezing the model_base: 6
Number of trainable weights after freezing the model_base: 6


编译模型：

In [73]:
# Small learning rate with momentum SGD, binary cross-entropy loss to match
# the single sigmoid output, and accuracy tracked as the metric.
lr = 0.0001
sgd = optimizers.SGD(lr=lr, momentum=0.9)
model.compile(
    optimizer=sgd,
    loss="binary_crossentropy",
    metrics=["accuracy"],
)

训练模型：

In [None]:
# Number of full batches per epoch for each split (any partial last batch is
# dropped by the integer division).
train_steps = nb_train_samples // batch_size
valid_steps = nb_valid_samples // batch_size

# Train for 20 epochs, evaluating on the validation generator after each one;
# the returned History object keeps the per-epoch metrics.
history = model.fit_generator(
    generator['train'],
    steps_per_epoch=train_steps,
    epochs=20,
    validation_data=generator['valid'],
    validation_steps=valid_steps,
)

Epoch 1/20


保存模型checkpoint：

In [51]:
# Persist the trained model twice: weights only, and the complete model.
# NOTE(review): 'wieghts' in the first file name looks like a typo of
# 'weights'; it is left unchanged because the name is what ends up on disk -
# confirm nothing else relies on it before renaming.
model.save_weights('model_binary_wieghts.h5')
model.save('model_binary.h5')

In [68]:
# Disabled earlier version of the training call, kept for reference only.
# NOTE(review): `model_final`, `checkpoint` (used as a callback) and `early`
# are not defined in the visible cells, and `samples_per_epoch` /
# `nb_val_samples` look like older Keras argument names - confirm before
# re-enabling.
# epochs = 60
# model_final.fit_generator(
#     generator['train'],
#     samples_per_epoch = nb_train_samples,
#     epochs = epochs,
#     validation_data = generator['valid'],
#     nb_val_samples = nb_valid_samples,
#     callbacks = [checkpoint, early]
# )

自己设计的检查点数据：

In [None]:
# Bundle the trained model and its training history into one checkpoint dict,
# write it to disk with pickle, then read it back.
# NOTE(review): whether a Keras model / History object can be pickled depends
# on the installed Keras version - if this raises, persist the model with
# model.save() and pickle only history.history instead.
checkpoint = {'model': model,
              'history': history}

# Dump the checkpoint dict itself into the handle opened here (the previous
# code referenced the undefined names `ex_files` and `ef`).
with open('checkpoint.pth', 'wb') as pth:
    pickle.dump(checkpoint, pth)

# Read back from the same file name that was just written (it was misspelled
# 'checkpint.pth' before). pickle.load can execute arbitrary code, so only
# load checkpoint files you created yourself.
with open('checkpoint.pth', 'rb') as pth:
    checkpoint = pickle.load(pth)

model = checkpoint['model']
history = checkpoint['history']

可视化：

In [None]:
# Get the per-epoch metrics from the History object.
# Depending on the Keras version, metrics=["accuracy"] is logged either as
# 'acc'/'val_acc' or as 'accuracy'/'val_accuracy', so accept both spellings.
history_dict = history.history
acc_key = 'acc' if 'acc' in history_dict else 'accuracy'
acc = history_dict[acc_key]
val_acc = history_dict['val_' + acc_key]
loss = history_dict['loss']
val_loss = history_dict['val_loss']

# Epochs are numbered from 1 on the x axis.
epochs = range(1, len(acc) + 1)

# Training and validation accuracy
plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and Validation accuracy')
plt.xlabel('Epoch')
plt.legend()

plt.figure()
# Training and validation loss
plt.plot(epochs, loss, 'b', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and Validation loss')
plt.xlabel('Epoch')
plt.legend()

plt.show()

++++++++++++++++++++++++++前方不通行++++++++++++++++++++++++++++++++++

In [None]:
def plot_training(history):
    """Plot training/validation accuracy and loss curves, one figure each.

    Args:
        history: object with a ``history`` dict attribute (such as the value
            returned by ``fit_generator``) holding per-epoch lists under
            'loss', 'val_loss' and either 'acc'/'val_acc' or
            'accuracy'/'val_accuracy'.

    Raises:
        KeyError: if neither accuracy spelling, or a loss key, is present.
    """
    logs = history.history
    # The accuracy key name depends on the Keras version; accept both.
    acc_key = 'acc' if 'acc' in logs else 'accuracy'
    acc = logs[acc_key]
    val_acc = logs['val_' + acc_key]
    loss = logs['loss']
    val_loss = logs['val_loss']
    epochs = range(len(acc))

    # Training values are drawn as dots, validation values as a line; the
    # labels and legend make the two same-coloured series distinguishable.
    plt.plot(epochs, acc, 'r.', label='Training accuracy')
    plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
    plt.title('Training and validation accuracy')
    plt.legend()

    plt.figure()
    plt.plot(epochs, loss, 'r.', label='Training loss')
    plt.plot(epochs, val_loss, 'r-', label='Validation loss')
    plt.title('Training and validation loss')
    plt.legend()

    plt.show()

# 参考资料

+ https://zhuanlan.zhihu.com/p/26693647
+ https://medium.com/@14prakash/transfer-learning-using-keras-d804b2e04ef8