Permalink
Switch branches/tags
Nothing to show
Find file Copy path
aabe380 Jan 6, 2017
1 contributor

Users who have contributed to this file

84 lines (67 sloc) 2.54 KB
import os
from keras.models import Sequential
from keras.layers import Convolution2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
from keras.preprocessing.image import ImageDataGenerator
# Number of training epochs used for fit_generator in the __main__ section.
nb_epoch = 50
# Directory that receives the trained weights and the training-history file.
result_dir = 'results'
# Create the output directory up front. makedirs(exist_ok=True) avoids the
# check-then-create race of exists()/mkdir() and also creates missing parents.
os.makedirs(result_dir, exist_ok=True)
def save_history(history, result_file):
    """Write per-epoch training metrics to a tab-separated text file.

    Args:
        history: the object returned by Keras ``fit``/``fit_generator``.
            Only its ``history`` dict attribute is read. The dict must hold
            ``'loss'`` and ``'val_loss'`` lists, plus accuracy lists under
            either ``'acc'``/``'val_acc'`` or ``'accuracy'``/``'val_accuracy'``.
        result_file: path of the output file. An existing file is overwritten.

    The file starts with the header ``epoch loss acc val_loss val_acc``
    (tab-separated). One row per epoch follows, holding the 0-based epoch
    index and the four metrics formatted with ``%f``.
    """
    metrics = history.history
    # Some Keras releases name the accuracy metric 'accuracy' instead of
    # 'acc'. Accept either spelling; the written header stays the same.
    acc_key = 'acc' if 'acc' in metrics else 'accuracy'
    val_acc_key = 'val_acc' if 'val_acc' in metrics else 'val_accuracy'
    rows = zip(metrics['loss'], metrics[acc_key],
               metrics['val_loss'], metrics[val_acc_key])
    with open(result_file, "w") as fp:
        fp.write("epoch\tloss\tacc\tval_loss\tval_acc\n")
        for epoch, (loss, acc, val_loss, val_acc) in enumerate(rows):
            fp.write("%d\t%f\t%f\t%f\t%f\n" % (epoch, loss, acc, val_loss, val_acc))
def main():
    """Build, train and persist the small CNN binary classifier.

    Training images come from ``data/train`` and validation images from
    ``data/validation``, both read via ``flow_from_directory``. The model
    trains for ``nb_epoch`` epochs. Afterwards the weights are written to
    ``smallcnn.h5`` and the metric history to ``history_smallcnn.txt``,
    both inside ``result_dir``.
    """
    # Build the model: three conv -> ReLU -> 2x2 max-pool stages, then a
    # 64-unit dense layer with dropout and a single sigmoid output unit.
    # NOTE(review): input_shape=(150, 150, 3) puts channels last; confirm it
    # matches the backend's configured image ordering.
    model = Sequential([
        Convolution2D(32, 3, 3, input_shape=(150, 150, 3)),
        Activation('relu'),
        MaxPooling2D(pool_size=(2, 2)),

        Convolution2D(32, 3, 3),
        Activation('relu'),
        MaxPooling2D(pool_size=(2, 2)),

        Convolution2D(64, 3, 3),
        Activation('relu'),
        MaxPooling2D(pool_size=(2, 2)),

        Flatten(),
        Dense(64),
        Activation('relu'),
        Dropout(0.5),
        Dense(1),
        Activation('sigmoid'),
    ])
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

    # Create the generators that yield training and validation batches.
    # Both read 150x150 images in batches of 32 with binary labels. Only the
    # training stream is augmented; both streams rescale pixels to [0, 1].
    flow_options = dict(target_size=(150, 150), batch_size=32, class_mode='binary')
    augmenting_datagen = ImageDataGenerator(
        rescale=1.0 / 255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
    )
    plain_datagen = ImageDataGenerator(rescale=1.0 / 255)
    train_stream = augmenting_datagen.flow_from_directory('data/train', **flow_options)
    validation_stream = plain_datagen.flow_from_directory('data/validation', **flow_options)

    # Train the model.
    # NOTE(review): samples_per_epoch / nb_epoch / nb_val_samples look like
    # Keras 1.x argument names. Keras 2's steps_per_epoch counts batches, not
    # samples, so porting needs care. Also, 2000 is not a multiple of the
    # batch size 32; confirm the intended epoch size.
    history = model.fit_generator(
        train_stream,
        samples_per_epoch=2000,
        nb_epoch=nb_epoch,
        validation_data=validation_stream,
        nb_val_samples=800,
    )

    # Save the results: weights only (no architecture), plus the metric log.
    model.save_weights(os.path.join(result_dir, 'smallcnn.h5'))
    save_history(history, os.path.join(result_dir, 'history_smallcnn.txt'))


if __name__ == '__main__':
    main()