In [14]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import random
import os, shutil
from tqdm import tqdm
import pickle

from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
from keras.layers import GlobalAveragePooling2D, Dense, Dropout, Flatten
from keras.models import Model, Sequential
from keras import optimizers
from keras.callbacks import ModelCheckpoint, EarlyStopping

import numpy as np
import pandas as pd

# 预处理

In [15]:
# Path template for a dataset split: data_dir.format('train') -> './train/'.
data_dir = './{}/'
# Class count = number of entries directly under the training directory.
nb_class = len(os.listdir(data_dir.format('train')))
# data[split][class_name] -> list of file names found in ./<split>/<class_name>
data = {}
for i in ['train', 'valid']:
    data[i] = {x: os.listdir(data_dir.format(i)+x) for x in os.listdir(data_dir.format(i))}
# Count each split over its OWN class folders. The validation count used to
# iterate data['train'].keys(), which raises KeyError (or silently skips
# classes) whenever the two splits do not contain identical folder names.
nb_train_samples = sum(len(files) for files in data['train'].values())
nb_valid_samples = sum(len(files) for files in data['valid'].values())

## 图像变换

暂时参考以下博文为蓝本：https://zhuanlan.zhihu.com/p/26693647

In [16]:
# Random augmentations applied to training images only.
train_augmentation = dict(
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
)

# One ImageDataGenerator per split; every split runs the ResNet50
# preprocess_input function, and only 'train' is augmented.
# Key order (train, valid, test) is kept because later cells iterate this dict.
datagen = {
    'train': image.ImageDataGenerator(
        preprocessing_function=preprocess_input,
        **train_augmentation
    ),
    'valid': image.ImageDataGenerator(preprocessing_function=preprocess_input),
    'test': image.ImageDataGenerator(preprocessing_function=preprocess_input),
}

用了`preprocess_input()`就不需要`rescale`参数了。

https://stackoverflow.com/questions/47555829/preprocess-input-method-in-keras

In [17]:
# Input resolution fed to the network and mini-batch size.
im_width, im_height = 224, 224
batch_size = 64

# Build one directory iterator per split, in the same order as `datagen`.
# NOTE(review): the recorded output reports 0 images for the 'test' split;
# presumably ./test/ has no per-class subfolders — confirm the layout.
generator = {}
for split_name, split_datagen in datagen.items():
    generator[split_name] = split_datagen.flow_from_directory(
        data_dir.format(split_name),
        target_size=(im_width, im_height),
        batch_size=batch_size,
        seed=0,
        class_mode='binary',
    )

Found 16662 images belonging to 2 classes.
Found 8208 images belonging to 2 classes.
Found 0 images belonging to 0 classes.


## 载入模型

载入模型并排除顶部的全连接层。

In [18]:
# ImageNet-pretrained ResNet50 without its fully connected top, so that a
# custom classifier head can be stacked on the convolutional features.
model_base = ResNet50(
    include_top=False,
    weights='imagenet',
    input_shape=(im_width, im_height, 3),
)



In [19]:
# Print the layer-by-layer summary of the ResNet50 base (output truncated below).
model_base.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            (None, 224, 224, 3)  0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           input_2[0][0]                    
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 112, 112, 64) 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 112, 112, 64) 256         conv1[0][0]                      
__________________________________________________________________________________________________
activation

添加自己的层：

In [20]:
# Full network: ResNet50 base, flattened features, then a small dense
# classifier head ending in a single sigmoid unit for binary output.
model = Sequential([
    model_base,
    Flatten(),
    Dense(1024, activation='relu'),
    Dropout(0.5),
    Dense(50, activation='relu'),
    Dense(1, activation='sigmoid'),
])

In [21]:
# Print the summary of the combined model (ResNet50 base + dense head).
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
resnet50 (Model)             (None, 7, 7, 2048)        23587712  
_________________________________________________________________
flatten_2 (Flatten)          (None, 100352)            0         
_________________________________________________________________
dense_4 (Dense)              (None, 1024)              102761472 
_________________________________________________________________
dropout_2 (Dropout)          (None, 1024)              0         
_________________________________________________________________
dense_5 (Dense)              (None, 50)                51250     
_________________________________________________________________
dense_6 (Dense)              (None, 1)                 51        
Total params: 126,400,485
Trainable params: 126,347,365
Non-trainable params: 53,120
_________________________________________________________

查看冻结 `model_base` 前后可训练权重（trainable weights）的数量：

In [22]:
# Count trainable weight tensors, freeze the ResNet50 base, then count again
# to confirm that only the new head remains trainable.
n_trainable_before = len(model.trainable_weights)
print('Number of trainable weights befor freezing the model_base:', n_trainable_before)

model_base.trainable = False

n_trainable_after = len(model.trainable_weights)
print('Number of trainable weights after freezing the model_base:', n_trainable_after)

Number of trainable weights befor freezing the model_base: 218
Number of trainable weights after freezing the model_base: 6


编译模型：

In [23]:
# Compile after freezing the base: SGD with momentum at a small learning
# rate, binary cross-entropy loss, accuracy as the reported metric.
lr = 0.0001
sgd = optimizers.SGD(lr=lr, momentum=0.9)
model.compile(
    optimizer=sgd,
    loss="binary_crossentropy",
    metrics=["accuracy"],
)

训练模型：

In [None]:
# NOTE(review): the training call below is commented out, but later cells
# (the checkpoint and the plotting cells) read `history`. On a fresh kernel
# `history` is undefined unless this call is re-enabled.
# history = model.fit_generator(generator['train'],
#                               steps_per_epoch=nb_train_samples // batch_size,
#                               epochs=20,
#                               validation_data=generator['valid'],
#                               validation_steps=nb_valid_samples // batch_size)

保存模型checkpoint：

In [51]:
# Persist the weights only, then the complete model, as HDF5 files.
# NOTE(review): the weights file name is spelled 'wieghts'; it is left
# unchanged because other code may load this exact path — confirm before renaming.
model.save_weights('model_binary_wieghts.h5')
model.save('model_binary.h5')

In [68]:
# NOTE(review): disabled legacy training call. It references `model_final`,
# `checkpoint` (as a callback) and `early`, none of which are defined in the
# visible notebook; `samples_per_epoch` / `nb_val_samples` look like older
# Keras argument names — TODO confirm or delete this cell.
# epochs = 60
# model_final.fit_generator(
#     generator['train'],
#     samples_per_epoch = nb_train_samples,
#     epochs = epochs,
#     validation_data = generator['valid'],
#     nb_val_samples = nb_valid_samples,
#     callbacks = [checkpoint, early]
# )

自己设计的检查点数据：

In [None]:
# Bundle the model and its training history into one pickle file, then read
# it back to verify the round trip.
# Fixes: the dump/load calls used undefined names (`ex_files`, `ef`) instead
# of the checkpoint dict and the open file handle, and the load path was
# misspelled ('checkpint.pth'), so it never matched the file just written.
checkpoint_path = 'checkpoint.pth'

checkpoint = {'model': model
             ,'history': history}
with open(checkpoint_path, 'wb') as pth:
    pickle.dump(checkpoint, pth)

# SECURITY: pickle.load executes arbitrary code; only load checkpoint files
# that this notebook produced itself.
# NOTE(review): whether a Keras model/History object pickles cleanly depends
# on the installed Keras version — confirm, or rely on model.save() instead.
with open(checkpoint_path, 'rb') as pth:
    checkpoint = pickle.load(pth)
model = checkpoint['model']
history = checkpoint['history']

可视化：

In [None]:
# Get the per-epoch metric lists from the history object.
acc, val_acc, loss, val_loss = (
    history.history[key] for key in ('acc', 'val_acc', 'loss', 'val_loss')
)

# Epochs are numbered from 1 on the x axis.
epochs = range(1, len(acc) + 1)

# (title, training series, training label, validation series, validation label)
panels = [
    ('Training and Validation accurarcy',
     acc, 'Training accurarcy', val_acc, 'Validation accurarcy'),
    ('Training and Validation loss',
     loss, 'Training loss', val_loss, 'Validation loss'),
]

for panel_index, (title, train_values, train_label, val_values, val_label) in enumerate(panels):
    # The first panel draws on the implicit current figure; each later panel
    # opens a new figure of its own.
    if panel_index > 0:
        plt.figure()
    plt.plot(epochs, train_values, 'b', label=train_label)
    plt.plot(epochs, val_values, 'r', label=val_label)
    plt.title(title)
    plt.legend()

plt.show()

# 预测

In [None]:
# Predict the class of one test image with the binary classifier.
# Fixes:
#   * `data_dir.format('test')[0]` indexed the path STRING and produced '.',
#     not an image file; an actual file inside the test directory is used now.
#   * `decode_predictions` only understands 1000-class ImageNet outputs; this
#     model ends in a single sigmoid unit, so the score is decoded manually.
test_dir = data_dir.format('test')
# assumes the test images are stored directly inside ./test/ — TODO confirm
image_suffixes = ('.jpg', '.jpeg', '.png', '.bmp')
test_files = sorted(
    name for name in os.listdir(test_dir)
    if name.lower().endswith(image_suffixes)
)
if not test_files:
    raise FileNotFoundError('No image files found directly under ' + test_dir)

img_path = os.path.join(test_dir, test_files[0])
img = image.load_img(img_path, target_size=(im_width, im_height))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)  # add the batch dimension expected by predict()
x = preprocess_input(x)

preds = model.predict(x)
# With class_mode='binary' the sigmoid output is the score for class index 1.
prob = float(preds[0][0])
# Invert the generator's {class_name: index} mapping to recover the label.
index_to_class = {index: name for name, index in generator['train'].class_indices.items()}
label = index_to_class[int(prob >= 0.5)]
print('Predicted:', label, prob)

++++++++++++++++++++++++++前方不通行++++++++++++++++++++++++++++++++++

In [None]:
def plot_training(history):
    """Plot training/validation accuracy and loss curves in two figures.

    Parameters
    ----------
    history : object
        Anything exposing a ``history`` dict of per-epoch metric lists, such
        as the object returned by ``model.fit_generator``. Accuracy may be
        logged under ``'acc'``/``'val_acc'`` or ``'accuracy'``/``'val_accuracy'``.

    Raises
    ------
    KeyError
        If a required metric is missing from ``history.history``.
    """
    logs = history.history

    def _metric(*candidates):
        # Return the first metric series present under any candidate key.
        for key in candidates:
            if key in logs:
                return logs[key]
        raise KeyError('none of {} found in history'.format(candidates))

    acc = _metric('acc', 'accuracy')
    val_acc = _metric('val_acc', 'val_accuracy')
    loss = _metric('loss')
    val_loss = _metric('val_loss')
    epochs = range(len(acc))

    # Labels and a legend make the two curves distinguishable when skimmed.
    plt.plot(epochs, acc, 'r.', label='Training accuracy')
    plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
    plt.title('Training and validation accuracy')
    plt.xlabel('Epoch')
    plt.legend()

    plt.figure()
    plt.plot(epochs, loss, 'r.', label='Training loss')
    plt.plot(epochs, val_loss, 'r-', label='Validation loss')
    plt.title('Training and validation loss')
    plt.xlabel('Epoch')
    plt.legend()

    plt.show()

# 参考资料

+ https://zhuanlan.zhihu.com/p/26693647
+ https://medium.com/@14prakash/transfer-learning-using-keras-d804b2e04ef8