导入所有需要的库。

In [None]:
import csv
import os

import numpy
from keras.applications.xception import Xception
from keras.callbacks import EarlyStopping
from keras.layers.core import Flatten, Dense, Dropout
from keras.models import Sequential
from keras.models import load_model
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot

%matplotlib inline

设定所需的路径。当前目录 `./` 即本 Jupyter Notebook 文件以及训练集、测试集数据所在的目录。

In [None]:
# Input/output locations, all relative to the notebook's working directory.
# Directory constants keep their trailing '/' because later cells build file
# paths by plain string concatenation.

# Raw training images, and the per-class symlink tree built from them.
TRAIN_DIR = './train/'
TRAIN_GEN_DIR = './train_gen/'
DOGS_DIR = '{}dogs/'.format(TRAIN_GEN_DIR)
CATS_DIR = '{}cats/'.format(TRAIN_GEN_DIR)

# Cached artifacts: extracted Xception features and the trained head model.
IMAGENET_FEATURES = './imagenet_features.npy'
MODEL = './model.h5'

# Test images and the submission CSV written from their predictions.
TEST_DIR = './test1/'
SUBMISSION = './submission.csv'

为 `ImageDataGenerator.flow_from_directory` 按类别创建子文件夹（`dogs/` 和 `cats/`），并用符号链接指向原始训练图片，避免复制数据。

In [None]:
# flow_from_directory() expects one sub-folder per class, but TRAIN_DIR holds
# both classes side by side, told apart only by file name. Mirror the images
# into TRAIN_GEN_DIR/dogs and TRAIN_GEN_DIR/cats with symlinks instead of
# copying them.
train_list = os.listdir(TRAIN_DIR)

train_dogs = [name for name in train_list if 'dog' in name]
train_cats = [name for name in train_list if 'cat' in name]

for class_dir, names in ((DOGS_DIR, train_dogs), (CATS_DIR, train_cats)):
    # exist_ok=True lets the cell be re-run instead of raising FileExistsError.
    os.makedirs(class_dir, exist_ok=True)
    for name in names:
        link = class_dir + name
        # A relative symlink target is resolved from the link's own folder,
        # not from the current directory, so './train/<name>' would dangle
        # inside TRAIN_GEN_DIR/<class>/. Point at the absolute path instead.
        target = os.path.abspath(TRAIN_DIR + name)
        if os.path.islink(link):
            # Replace links left by an earlier run (e.g. dangling relative ones).
            os.remove(link)
        os.symlink(target, link)

In [None]:
idg = ImageDataGenerator(rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        rescale=1./255)
x = idg.flow_from_directory(directory=TRAIN_GEN_DIR, target_size=(299, 299))

In [None]:
base_model = Xception(include_top=False)

features = base_model.predict_generator(x, 150000)

numpy.save(open(IMAGENET_FEATURES, 'w'), features)

del base_model

In [None]:
features = numpy.load(open(IMAGENET_FEATURES))

model = Sequential()
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(2, activation='softmax'))

model.compile(optimizer='nadam', loss='binary_crossentropy', metrics=['accuracy'])

es = EarlyStopping(monitor='val_loss', patience=3)

history = model.fit(x=features, y=y, batch_size=128, epochs=50, callbacks=[es], validation_split=0.2)

In [None]:
# Plot the training vs. validation curves recorded by model.fit().
# Keras < 2.3 logs accuracy as 'acc'/'val_acc'; later versions use
# 'accuracy'/'val_accuracy'. Accept either so the plot works on both.
acc_key = 'acc' if 'acc' in history.history else 'accuracy'

loss = history.history['loss']
val_loss = history.history['val_loss']

acc = history.history[acc_key]
val_acc = history.history['val_' + acc_key]

pyplot.figure(figsize=(8, 6))
pyplot.subplots_adjust(wspace=1, hspace=1)

# One panel per metric: (subplot position, training series, validation series, name).
panels = [
    (211, loss, val_loss, 'Loss'),
    (212, acc, val_acc, 'Accuracy'),
]
for position, train_values, val_values, name in panels:
    pyplot.subplot(position)
    pyplot.plot(train_values, 'blue', label='Training ' + name)
    pyplot.plot(val_values, 'green', label='Validation ' + name)
    pyplot.xlabel('Epochs')
    pyplot.ylabel(name)
    pyplot.title('Xception ' + name + ' Trend')
    pyplot.legend()

pyplot.show()

In [None]:
# Persist the trained head to MODEL and drop the in-memory copy; the
# prediction cell reloads it from disk with load_model(MODEL).
model.save(filepath=MODEL)
del model

In [None]:
# Predict the dog probability for every test image and write the submission CSV.
model = load_model(MODEL)

# The saved head was trained on Xception features, not on raw pixels, so
# test images must first pass through the same base network.
base_model = Xception(include_top=False)

# Same pixel scaling as the training generator, without augmentation.
gen = ImageDataGenerator(rescale=1./255)
# flow_from_directory only collects images from sub-folders of `directory`,
# so point it at TEST_DIR's parent and select TEST_DIR as the only folder.
test_parent, test_folder = os.path.split(os.path.normpath(TEST_DIR))
x_test = gen.flow_from_directory(directory=test_parent or os.curdir,
                                 classes=[test_folder],
                                 target_size=(299, 299),
                                 shuffle=False)

# One pass over the test set, batch by batch, so the large unpooled feature
# maps of the whole test set are never held in memory at once. shuffle=False
# keeps the predictions aligned with x_test.filenames.
steps = (x_test.samples + x_test.batch_size - 1) // x_test.batch_size
batch_predictions = []
for _ in range(steps):
    images, _ = next(x_test)
    batch_predictions.append(model.predict(base_model.predict(images)))
y_test = numpy.concatenate(batch_predictions)
del base_model

# flow_from_directory numbers classes in alphanumeric folder order, so with
# the training folders 'cats' and 'dogs', column 1 is the dog probability.
DOG_COLUMN = 1

# The csv module needs a text-mode file opened with newline='' on Python 3
# ('wb' raises TypeError there); `with` closes the file even on error.
with open(SUBMISSION, 'w', newline='') as csv_file:
    writer = csv.writer(csv_file)
    writer.writerow(['id', 'label'])
    for filename, probs in zip(x_test.filenames, y_test):
        # NOTE(review): assumes test files are named '<id>.<ext>' -- confirm.
        image_id = os.path.splitext(os.path.basename(filename))[0]
        writer.writerow([image_id, probs[DOG_COLUMN]])