From 5ee04f7c82c79a6c88d9162083d6d732fdef9bb4 Mon Sep 17 00:00:00 2001 From: aidiary Date: Thu, 16 Feb 2017 11:10:39 +0900 Subject: [PATCH] Add a finetuning result with vgg16 preprocess --- vgg16/17flowers/finetuning_with_preprocess.py | 121 ++++++++++++++++++ vgg16/17flowers/predict_with_preprocess.py | 68 ++++++++++ .../history_finetuning_with_preprocess.txt | 51 ++++++++ 3 files changed, 240 insertions(+) create mode 100644 vgg16/17flowers/finetuning_with_preprocess.py create mode 100644 vgg16/17flowers/predict_with_preprocess.py create mode 100644 vgg16/17flowers/results/history_finetuning_with_preprocess.txt diff --git a/vgg16/17flowers/finetuning_with_preprocess.py b/vgg16/17flowers/finetuning_with_preprocess.py new file mode 100644 index 0000000..ce90161 --- /dev/null +++ b/vgg16/17flowers/finetuning_with_preprocess.py @@ -0,0 +1,121 @@ +import os +from keras.applications.vgg16 import VGG16 +from keras.preprocessing.image import ImageDataGenerator +from keras.models import Sequential, Model +from keras.layers import Input, Activation, Dropout, Flatten, Dense +from keras.preprocessing.image import ImageDataGenerator +from keras import optimizers +import numpy as np +from smallcnn import save_history + + +classes = ['Tulip', 'Snowdrop', 'LilyValley', 'Bluebell', 'Crocus', + 'Iris', 'Tigerlily', 'Daffodil', 'Fritillary', 'Sunflower', + 'Daisy', 'ColtsFoot', 'Dandelion', 'Cowslip', 'Buttercup', + 'Windflower', 'Pansy'] + +batch_size = 32 +nb_classes = len(classes) + +img_rows, img_cols = 150, 150 +channels = 3 + +train_data_dir = 'train_images' +validation_data_dir = 'test_images' + +nb_train_samples = 1190 +nb_val_samples = 170 +nb_epoch = 50 + +result_dir = 'results' +if not os.path.exists(result_dir): + os.mkdir(result_dir) + + +# keras.applications.imagenet_utilsのxは4Dテンソルなので +# 3Dテンソル版を作成 +def preprocess_input(x): + """Preprocesses a tensor encoding a batch of images. + # Arguments + x: input Numpy tensor, 3D. + # Returns + Preprocessed tensor. + """ + # 'RGB'->'BGR' + x = x[:, :, ::-1] + # Zero-center by mean pixel + x[:, :, 0] -= 103.939 + x[:, :, 1] -= 116.779 + x[:, :, 2] -= 123.68 + return x + + +if __name__ == '__main__': + # VGG16モデルと学習済み重みをロード + # Fully-connected層(FC)はいらないのでinclude_top=False) + input_tensor = Input(shape=(img_rows, img_cols, 3)) + vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor) + # vgg16.summary() + + # FC層を構築 + # Flattenへの入力指定はバッチ数を除く + top_model = Sequential() + top_model.add(Flatten(input_shape=vgg16.output_shape[1:])) + top_model.add(Dense(256, activation='relu')) + top_model.add(Dropout(0.5)) + top_model.add(Dense(nb_classes, activation='softmax')) + + # 学習済みのFC層の重みをロード + # top_model.load_weights(os.path.join(result_dir, 'bottleneck_fc_model.h5')) + + # VGG16とFCを接続 + model = Model(input=vgg16.input, output=top_model(vgg16.output)) + + # 最後のconv層の直前までの層をfreeze + for layer in model.layers[:15]: + layer.trainable = False + + # Fine-tuningのときはSGDの方がよい? + model.compile(loss='categorical_crossentropy', + optimizer=optimizers.SGD(lr=1e-4, momentum=0.9), + metrics=['accuracy']) + + # preprocessing_functionはKeras1.2.2から + # VGG16を使うための画像前処理を適用する + train_datagen = ImageDataGenerator( + preprocessing_function=preprocess_input, + shear_range=0.2, + zoom_range=0.2, + horizontal_flip=True) + + test_datagen = ImageDataGenerator( + preprocessing_function=preprocess_input) + + train_generator = train_datagen.flow_from_directory( + train_data_dir, + target_size=(img_rows, img_cols), + color_mode='rgb', + classes=classes, + class_mode='categorical', + batch_size=batch_size, + shuffle=True) + + validation_generator = test_datagen.flow_from_directory( + validation_data_dir, + target_size=(img_rows, img_cols), + color_mode='rgb', + classes=classes, + class_mode='categorical', + batch_size=batch_size, + shuffle=True) + + # Fine-tuning + history = model.fit_generator( + train_generator, + samples_per_epoch=nb_train_samples, + nb_epoch=nb_epoch, + validation_data=validation_generator, + nb_val_samples=nb_val_samples) + + model.save_weights(os.path.join(result_dir, 'finetuning.h5')) + save_history(history, os.path.join(result_dir, 'history_finetuning.txt')) diff --git a/vgg16/17flowers/predict_with_preprocess.py b/vgg16/17flowers/predict_with_preprocess.py new file mode 100644 index 0000000..1fb572d --- /dev/null +++ b/vgg16/17flowers/predict_with_preprocess.py @@ -0,0 +1,68 @@ +import os +import sys +from keras.applications.vgg16 import VGG16, preprocess_input +from keras.models import Sequential, Model +from keras.layers import Input, Activation, Dropout, Flatten, Dense +from keras.preprocessing import image +import numpy as np + +if len(sys.argv) != 2: + print("usage: python predict.py [filename]") + sys.exit(1) + +filename = sys.argv[1] +print('input:', filename) + +result_dir = 'results' + +classes = ['Tulip', 'Snowdrop', 'LilyValley', 'Bluebell', 'Crocus', + 'Iris', 'Tigerlily', 'Daffodil', 'Fritillary', 'Sunflower', + 'Daisy', 'ColtsFoot', 'Dandelion', 'Cowslip', 'Buttercup', + 'Windflower', 'Pansy'] +nb_classes = len(classes) + +img_height, img_width = 150, 150 +channels = 3 + +# VGG16 +input_tensor = Input(shape=(img_height, img_width, channels)) +vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor) + +# FC +fc = Sequential() +fc.add(Flatten(input_shape=vgg16.output_shape[1:])) +fc.add(Dense(256, activation='relu')) +fc.add(Dropout(0.5)) +fc.add(Dense(nb_classes, activation='softmax')) + +# VGG16とFCを接続 +model = Model(input=vgg16.input, output=fc(vgg16.output)) + +# 学習済みの重みをロード +model.load_weights(os.path.join(result_dir, 'finetuning.h5')) + +model.compile(loss='categorical_crossentropy', + optimizer='adam', + metrics=['accuracy']) +# model.summary() + +# 画像を読み込んで4次元テンソルへ変換 +img = image.load_img(filename, target_size=(img_height, img_width)) +x = image.img_to_array(img) +x = np.expand_dims(x, axis=0) + +x = preprocess_input(x) + +# print(x) +# print(x.shape) + +# クラスを予測 +# 入力は1枚の画像なので[0]のみ +pred = model.predict(x)[0] + +# 予測確率が高いトップ5を出力 +top = 5 +top_indices = pred.argsort()[-top:][::-1] +result = [(classes[i], pred[i]) for i in top_indices] +for x in result: + print(x) diff --git a/vgg16/17flowers/results/history_finetuning_with_preprocess.txt b/vgg16/17flowers/results/history_finetuning_with_preprocess.txt new file mode 100644 index 0000000..aa42d52 --- /dev/null +++ b/vgg16/17flowers/results/history_finetuning_with_preprocess.txt @@ -0,0 +1,51 @@ +epoch loss acc val_loss val_acc +0 7.799596 0.075630 2.736852 0.111765 +1 2.715135 0.121008 2.639145 0.147059 +2 2.672502 0.126050 2.605250 0.135294 +3 2.636434 0.144538 2.720314 0.100000 +4 2.620954 0.144538 2.598778 0.152941 +5 2.633619 0.158824 2.664991 0.170588 +6 2.565990 0.167227 2.603898 0.170588 +7 2.532105 0.164706 2.537722 0.211765 +8 2.465628 0.213445 2.537234 0.217647 +9 2.390184 0.227731 2.445597 0.241176 +10 2.352276 0.266387 2.288540 0.323529 +11 2.266398 0.280672 2.207796 0.305882 +12 2.216889 0.305042 2.345270 0.270588 +13 2.154504 0.321008 2.094557 0.388235 +14 2.131545 0.352101 2.194597 0.352941 +15 2.048808 0.363025 1.965545 0.370588 +16 1.977085 0.395798 2.100380 0.388235 +17 2.005733 0.380672 1.857513 0.405882 +18 1.978694 0.390756 2.166514 0.388235 +19 1.879812 0.411765 1.966580 0.394118 +20 1.800871 0.429412 1.782561 0.452941 +21 1.916352 0.403361 2.020662 0.423529 +22 1.874563 0.426891 2.082303 0.376471 +23 1.749871 0.453782 1.941729 0.411765 +24 1.697299 0.453782 1.885577 0.441176 +25 1.742076 0.454622 1.989393 0.382353 +26 1.680414 0.457983 1.846914 0.441176 +27 1.668524 0.486555 2.165333 0.400000 +28 1.657829 0.478992 1.858192 0.429412 +29 1.635422 0.487395 1.790263 0.429412 +30 1.626568 0.473109 1.866771 0.441176 +31 1.695085 0.447899 1.817949 0.470588 +32 1.555944 0.509244 1.605009 0.482353 +33 1.532633 0.519328 1.579341 0.511765 +34 1.448442 0.541176 1.486748 0.529412 +35 1.321206 0.573950 1.686954 0.505882 +36 1.224674 0.598319 1.280517 0.623529 +37 1.128079 0.626891 1.160887 0.629412 +38 0.980099 0.675630 0.969792 0.700000 +39 0.928305 0.692437 1.127778 0.652941 +40 0.821076 0.730252 0.972618 0.682353 +41 0.772046 0.725210 0.902807 0.711765 +42 0.681999 0.773950 0.764624 0.711765 +43 0.637164 0.779832 0.815621 0.723529 +44 0.593890 0.798319 0.655362 0.794118 +45 0.518867 0.820168 0.602647 0.764706 +46 0.573130 0.804202 0.756010 0.752941 +47 0.476230 0.839496 0.739997 0.758824 +48 0.413541 0.859664 0.826336 0.758824 +49 0.402393 0.864706 0.481998 0.829412