In [1]:
import os, sys
import numpy as np
from keras.applications.vgg16 import VGG16
from keras.models import Sequential, Model
from keras.layers import Input, Activation, Dropout, Flatten, Dense
from keras.preprocessing import image
from keras import optimizers

In [2]:
# 分類するクラス
classes = ['1', '2', '3', '4', '5', '6', '7', 
          '8', '9', '10', '11']

# それぞれ　egg_type_list の['Field', 'Undiscovered', 'Bug', 'Amorphous', 'Dragon', 'Fairy', 'Mineral', 
#                     'Flying', 'Grass', 'Human-Like', 'Monster', 'Water']　に対応


nb_classes = len(classes)

img_width, img_height = 150, 150

In [3]:
result_dir = 'results'

test_data_dir = 'AllDataSet/test'

In [6]:
def model_load():
    # VGG16, FC層は不要なので include_top=False
    input_tensor = Input(shape=(img_width, img_height, 3))
    vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor)

    # FC層の作成
    top_model = Sequential()
    top_model.add(Flatten(input_shape=vgg16.output_shape[1:]))
    top_model.add(Dense(256, activation='relu'))
    top_model.add(Dropout(0.5))
    top_model.add(Dense(nb_classes, activation='softmax'))

    # VGG16とFC層を結合してモデルを作成
    model = Model(inputs=vgg16.input, outputs=top_model(vgg16.output))

    # 学習済みの重みをロード
    model.load_weights(os.path.join(result_dir, 'finetuning.h5'))

    # 多クラス分類を指定
    model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.SGD(lr=1e-3, momentum=0.9),
              metrics=['accuracy'])

    return model

In [7]:
# モデルのロード
model = model_load()

# テスト用画像取得
test_imagelist = os.listdir(test_data_dir)

for test_image in test_imagelist:
    filename = os.path.join(test_data_dir, test_image)
    img = image.load_img(filename, target_size=(img_width, img_height))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    # 学習時に正規化してるので、ここでも正規化
    x = x / 255
    pred = model.predict(x)[0]

    # 予測確率が高いトップを出力
    # 今回は最も似ているクラスのみ出力したいので1にしているが、上位n個を表示させることも可能。
    top = 1
    top_indices = pred.argsort()[-top:][::-1]
    result = [(classes[i], pred[i]) for i in top_indices]
    print('file name is', test_image)
    print(result)
    print('=======================================')

file name is 837.png
[('6', 0.8204572)]
file name is 823.png
[('7', 0.5491222)]
file name is 822.png
[('7', 0.64107335)]
file name is 836.png
[('1', 0.7355311)]
file name is 820.png
[('6', 0.35117975)]
file name is 834.png
[('8', 0.24975431)]
file name is 835.png
[('8', 0.30046633)]
file name is 821.png
[('7', 0.8814871)]
file name is 819.png
[('1', 0.21848005)]
file name is 825.png
[('2', 0.41873986)]
file name is 831.png
[('7', 0.41336685)]
file name is 830.png
[('8', 0.4293204)]
file name is 824.png
[('2', 0.26073867)]
file name is 818.png
[('1', 0.46625343)]
file name is 832.png
[('7', 0.5722554)]
file name is 826.png
[('2', 0.37002388)]
file name is 827.png
[('2', 0.27858317)]
file name is 833.png
[('8', 0.27782497)]
file name is 883.png
[('1', 0.65585744)]
file name is 854.png
[('3', 0.8867558)]
file name is 840.png
[('2', 0.16744718)]
file name is 868.png
[('3', 0.56317496)]
file name is 869.png
[('3', 0.52969795)]
file name is 841.png
[('1', 0.23133639)]
file name is 855.png
[(