In [10]:
class_names_to_ids = {'cardboard': 0, 'glass': 1, 'metal': 2, 'paper':3, 'plastic':4, 'trash':5}



In [11]:
import os
data_dir = 'dataset/'
output_path = 'list.txt'
fd = open(output_path, 'w')
for class_name in class_names_to_ids.keys():
    images_list = os.listdir(data_dir + class_name)
    for image_name in images_list:
        fd.write('{}/{} {}\n'.format(class_name, image_name, class_names_to_ids[class_name]))
fd.close()

In [12]:
# 随机选取样本做训练集和测试集
import random
_NUM_VALIDATION = 505
_RANDOM_SEED = 0
list_path = 'list.txt'
train_list_path = 'list_train.txt'
val_list_path = 'list_val.txt'
fd = open(list_path)
lines = fd.readlines()
fd.close()
random.seed(_RANDOM_SEED)
random.shuffle(lines)
fd = open(train_list_path, 'w')
for line in lines[_NUM_VALIDATION:]:
    fd.write(line)
fd.close()
fd = open(val_list_path, 'w')
for line in lines[:_NUM_VALIDATION]:
    fd.write(line)
fd.close()

In [13]:
def get_train_test_data(list_file):
    list_train = open(list_file)
    x_train = []
    y_train = []
    for line in list_train.readlines():
        x_train.append(line.strip()[:-2])
        y_train.append(int(line.strip()[-1]))
        #print(line.strip())
    return x_train, y_train
x_train, y_train = get_train_test_data('list_train.txt')
x_test, y_test = get_train_test_data('list_val.txt')


In [14]:
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input

def process_train_test_data(x_path):
    images = []
    for image_path in x_path:
        img_load = image.load_img('dataset/'+image_path)
        img = image.img_to_array(img_load)
        img = preprocess_input(img)
        images.append(img)
    return images
train_images = process_train_test_data(x_train)
test_images = process_train_test_data(x_test)

In [15]:
from keras.layers import Activation, Dense
from tensorflow.keras import Model

In [16]:
from keras.applications.inception_resnet_v2 import InceptionResNetV2
base_model = InceptionResNetV2(include_top=False, pooling='avg')
outputs = Dense(6, activation='softmax')(base_model.output)
model = Model(base_model.inputs, outputs)

In [17]:
from keras.callbacks import ModelCheckpoint
import numpy as np
from tensorflow.keras.utils import to_categorical
from keras.metrics import top_k_categorical_accuracy

In [19]:
# 设置ModelCheckpoint，按照验证集的准确率进行保存
save_dir='train_model'
filepath="model_{epoch:02d}-{val_accuracy:.2f}.hdf5"
checkpoint = ModelCheckpoint(os.path.join(save_dir, filepath), monitor='val_accuracy',verbose=1, 
                            save_best_only=True)
# 模型设置
def acc_top3(y_true, y_pred):
    return top_k_categorical_accuracy(y_true, y_pred, k=3)
  
def acc_top5(y_true, y_pred):
    return top_k_categorical_accuracy(y_true, y_pred, k=5)

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy', acc_top3, acc_top5])
# 模型训练
model.fit(np.array(train_images), to_categorical(y_train),
          batch_size=8,
          epochs=5,
          shuffle=True,
          validation_data=(np.array(test_images), to_categorical(y_test)),
          callbacks=[checkpoint])

Epoch 1/5
Epoch 1: val_accuracy improved from -inf to 0.17030, saving model to train_model\model_01-0.17.hdf5
Epoch 2/5
Epoch 2: val_accuracy improved from 0.17030 to 0.17624, saving model to train_model\model_02-0.18.hdf5
Epoch 3/5
Epoch 3: val_accuracy improved from 0.17624 to 0.22970, saving model to train_model\model_03-0.23.hdf5
Epoch 4/5
Epoch 4: val_accuracy did not improve from 0.22970
Epoch 5/5
Epoch 5: val_accuracy did not improve from 0.22970


<keras.callbacks.History at 0x299134a6550>

In [20]:
print(test_images)

[array([[[151.061   , 138.22101 , 131.32    ],
        [151.061   , 138.22101 , 131.32    ],
        [151.061   , 138.22101 , 131.32    ],
        ...,
        [ 80.061   ,  82.221   ,  84.32    ],
        [ 80.061   ,  82.221   ,  84.32    ],
        [ 80.061   ,  82.221   ,  84.32    ]],

       [[151.061   , 138.22101 , 131.32    ],
        [151.061   , 138.22101 , 131.32    ],
        [151.061   , 138.22101 , 131.32    ],
        ...,
        [ 80.061   ,  82.221   ,  84.32    ],
        [ 80.061   ,  82.221   ,  84.32    ],
        [ 80.061   ,  82.221   ,  84.32    ]],

       [[151.061   , 138.22101 , 131.32    ],
        [151.061   , 138.22101 , 131.32    ],
        [151.061   , 138.22101 , 131.32    ],
        ...,
        [ 80.061   ,  82.221   ,  84.32    ],
        [ 80.061   ,  82.221   ,  84.32    ],
        [ 80.061   ,  82.221   ,  84.32    ]],

       ...,

       [[108.061   , 115.221   , 119.32    ],
        [108.061   , 115.221   , 119.32    ],
        [108.061   , 

In [21]:
print(np.array(test_images))

[[[[151.061      138.22101    131.32      ]
   [151.061      138.22101    131.32      ]
   [151.061      138.22101    131.32      ]
   ...
   [ 80.061       82.221       84.32      ]
   [ 80.061       82.221       84.32      ]
   [ 80.061       82.221       84.32      ]]

  [[151.061      138.22101    131.32      ]
   [151.061      138.22101    131.32      ]
   [151.061      138.22101    131.32      ]
   ...
   [ 80.061       82.221       84.32      ]
   [ 80.061       82.221       84.32      ]
   [ 80.061       82.221       84.32      ]]

  [[151.061      138.22101    131.32      ]
   [151.061      138.22101    131.32      ]
   [151.061      138.22101    131.32      ]
   ...
   [ 80.061       82.221       84.32      ]
   [ 80.061       82.221       84.32      ]
   [ 80.061       82.221       84.32      ]]

  ...

  [[108.061      115.221      119.32      ]
   [108.061      115.221      119.32      ]
   [108.061      115.221      119.32      ]
   ...
   [ 44.060997    48.221       50.3

In [22]:
# 加载指定模型
model.load_weights('train_model/model_03-0.23.hdf5')
# 直接使用predict方法进行预测
y_pred = model.predict(np.array(test_images))

In [23]:
y_pred

array([[0.00000000e+00, 2.20754724e-29, 1.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 3.80304373e-27],
       [3.81986856e-01, 7.01894867e-04, 6.58475934e-03, 1.00093475e-05,
        5.18780398e-05, 6.10664487e-01],
       [7.54209280e-01, 2.14092829e-03, 2.09121290e-03, 3.91501264e-04,
        2.88283569e-04, 2.40878791e-01],
       ...,
       [3.15510479e-05, 1.42450415e-04, 9.58114207e-01, 1.17592322e-07,
        3.44438831e-07, 4.17114198e-02],
       [0.00000000e+00, 0.00000000e+00, 1.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 1.40174150e-37],
       [0.00000000e+00, 0.00000000e+00, 1.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 0.00000000e+00]], dtype=float32)