In [73]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import random
import os, shutil
from tqdm import tqdm
import pickle

from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
from keras.layers import GlobalAveragePooling2D, Dense, Dropout, Flatten
from keras.models import Model

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# 预处理

In [48]:
data_dir = './{}/'
nb_class = len(os.listdir(data_dir.format('train')))
data = {}
for i in ['train', 'valid']:
    data[i] = {x: os.listdir(data_dir.format(i)+x) for x in os.listdir(data_dir.format(i))}
nb_train_samples = sum([len(data['train'][x]) for x in data['train'].keys()])
nb_valid_samples = sum([len(data['valid'][x]) for x in data['train'].keys()])

## 图像变换

暂时参考以下博文为蓝本：https://zhuanlan.zhihu.com/p/26693647

In [3]:
datagen = {'train': image.ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True
),
           'valid':image.ImageDataGenerator(
    preprocessing_function=preprocess_input
)
          }

In [49]:
im_width, im_height = 224, 224
batch_size = 64

generator = {x: datagen[x].flow_from_directory(
    data_dir.format(x),
    target_size=(im_width, im_height),
    batch_size=batch_size,
) for x in datagen.keys()}

Found 16662 images belonging to 2 classes.
Found 8208 images belonging to 2 classes.


## 载入模型

载入模型并排除顶部的全连接层。

In [51]:
model = ResNet50(weights='imagenet', include_top=False, input_shape = (im_width, im_height, 3))



<keras.engine.training.Model at 0x21403afcd30>

In [71]:
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            (None, 224, 224, 3)  0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           input_4[0][0]                    
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 112, 112, 64) 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 112, 112, 64) 256         conv1[0][0]                      
__________________________________________________________________________________________________
activation

activation_157 (Activation)     (None, 56, 56, 256)  0           add_51[0][0]                     
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, 28, 28, 128)  32896       activation_157[0][0]             
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_158 (Activation)     (None, 28, 28, 128)  0           bn3a_branch2a[0][0]              
__________________________________________________________________________________________________
res3a_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_158[0][0]             
__________________________________________________________________________________________________
bn3a_branc

__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3d_branch2c[0][0]             
__________________________________________________________________________________________________
add_55 (Add)                    (None, 28, 28, 512)  0           bn3d_branch2c[0][0]              
                                                                 activation_166[0][0]             
__________________________________________________________________________________________________
activation_169 (Activation)     (None, 28, 28, 512)  0           add_55[0][0]                     
__________________________________________________________________________________________________
res4a_branch2a (Conv2D)         (None, 14, 14, 256)  131328      activation_169[0][0]             
__________________________________________________________________________________________________
bn4a_branc

__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_180 (Activation)     (None, 14, 14, 256)  0           bn4d_branch2b[0][0]              
__________________________________________________________________________________________________
res4d_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_180[0][0]             
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4d_branch2c[0][0]             
__________________________________________________________________________________________________
add_59 (Add)                    (None, 14, 14, 1024) 0           bn4d_branch2c[0][0]              
          

__________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizati (None, 7, 7, 512)    2048        res5b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_191 (Activation)     (None, 7, 7, 512)    0           bn5b_branch2a[0][0]              
__________________________________________________________________________________________________
res5b_branch2b (Conv2D)         (None, 7, 7, 512)    2359808     activation_191[0][0]             
__________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizati (None, 7, 7, 512)    2048        res5b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_192 (Activation)     (None, 7, 7, 512)    0           bn5b_branch2b[0][0]              
__________

冻结不需要训练的层：

In [72]:
for layer in model.layers[:5]:
    layer.trainable = False

**TBD** ：但是为什么是前5层呢？花卉项目似乎冻结了所有层。

添加自己的层：

In [74]:
x = model.output
x = Flatten()(x)
x = Dense(1024, activation="relu")(x)
x = Dropout(0.5)(x)
x = Dense(1024, activation="relu")(x)
predictions = Dense(1, activation="sigmoid")(x)

形成最终模型：

In [75]:
model_final = Model(input = model.input, output = predictions)

  """Entry point for launching an IPython kernel.


+++++++++++++前方不通行+++++++++++

In [23]:
def add_new_last_layer(model, FC_SIZE, dropout, nb_classes, last_activation):
    """Add last layer to the convnet
    Args:
    base_model: keras model excluding top
    nb_classes: # of classes
    Returns:
    new keras model with last layer
    """
    x = model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(FC_SIZE, activation='relu')(x)
    x = Dropout(dropout)(x)
    predictions = Dense(nb_classes, activation=last_activation)(x) 
    model = Model(input=model.input, output=predictions)
    return model

In [24]:
new_model = add_new_last_layer(model, 1024, 0.3, 2, 'sigmoid')

  


In [25]:
def transfer_learn(new_model, model):
    """Freeze all layers and compile the model"""
    for layer in model.layers:
        layer.trainable = False
    new_model.compile(optimizer='rmsprop',
                      loss='categorical_crossentropy', 
                      metrics=['accuracy'])

In [26]:
transfer_learn(new_model, model)

In [27]:
def finetune(new_model, lr):
    """Freeze the bottom NB_IV3_LAYERS and retrain the remaining top 
      layers.
    note: NB_IV3_LAYERS corresponds to the top 2 inception blocks in 
         the inceptionv3 architecture
    Args:
     model: keras model
    """
    for layer in new_model.layers[:NB_IV3_LAYERS_TO_FREEZE]:
        layer.trainable = False
    for layer in new_model.layers[NB_IV3_LAYERS_TO_FREEZE:]:
        layer.trainable = True
    new_model.compile(optimizer=SGD(lr=lr, momentum=0.9),
                      loss='categorical_crossentropy')

In [28]:
finetune(new_model, 0.0001)

NameError: name 'NB_IV3_LAYERS_TO_FREEZE' is not defined

In [9]:
history = model.fit_generator(
  generator['train'],
  samples_per_epoch=nb_train_samples,
  nb_epoch=nb_epoch,
  validation_data=generator['valid'],
  nb_val_samples=nb_valid_samples,
  class_weight='auto')

model.save(args.output_model_file)

NameError: name 'nb_train_samples' is not defined

In [None]:
def plot_training(history):
    acc = history.history['acc']
    val_acc = history.history['val_acc']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    epochs = range(len(acc))

    plt.plot(epochs, acc, 'r.')
    plt.plot(epochs, val_acc, 'r')
    plt.title('Training and validation accuracy')

    plt.figure()
    plt.plot(epochs, loss, 'r.')
    plt.plot(epochs, val_loss, 'r-')
    plt.title('Training and validation loss')

    plt.show()

# 参考资料

+ https://zhuanlan.zhihu.com/p/26693647
+ https://medium.com/@14prakash/transfer-learning-using-keras-d804b2e04ef8