<h1>Utilise le CNN v2 et affiche les résultats</h1>

In [1]:
"""
    Use the CNN v2 model in order to make predictions on files in data_png_cnn2/all_png_test
    from: https://keras.io/examples/vision/image_classification_from_scratch/
"""

import sys
import datetime
import numpy as np
import pandas as pd
from os import listdir
from os.path import isfile, join
from keras.preprocessing import image
from keras.models import load_model
from keras.utils import np_utils 


# Add the parent folder to sys.path so the project-local `utils` module can be imported.
# Path.cwd() is used instead of os.getcwd(): `os` is never imported in this notebook,
# so the original call raised NameError on a fresh kernel.
from pathlib import Path
path = Path.cwd()
str_parent = str(path.parent.absolute())
if str_parent not in sys.path:  # avoid appending duplicates when the cell is re-run
    sys.path.append(str_parent)
import utils



In [2]:
# Run timestamp, input/output locations and the image size fed to the model.
now = datetime.datetime.now()

image_size = (333, 216)
png_test_folder = '../../_data_png_cnn2/all_png_test/'
model_folder = '../../_classifier_cnn2/global_cnn.h5'  # path to the saved .h5 model file

# Load the trained model from disk.
model = load_model(model_folder)

# Result table (file, score) plus the global counters updated during evaluation.
df_result = pd.DataFrame(columns=['file', 'score'])
indiceFile = 0
nbPredictionOK = 0

# Class names in the model's output order, as given by train_ds.class_names at training time.
list_machine = ['ToyCar', 'ToyConveyor', 'fan', 'pump', 'slider', 'valve']
# Per-machine accuracy statistics; the comprehension creates an independent counter dict per machine.
dict_machine = {machine: {'err': 0, 'ok': 0, 'accuracy': 0} for machine in list_machine}



In [3]:
def predict_one_image(nameFilePngTotest):
    """Classify one test spectrogram PNG and update the running accuracy counters.

    The file name is expected to look like ``normal_id_06_00000451_pump.png``.
    The first underscore-separated field is the ground-truth label
    ('normal' or 'anomaly'). The fifth field, minus its 4-character '.png'
    suffix, is the machine type, which must be one of ``list_machine``.

    A prediction counts as correct when the model's score for the file's REAL
    machine class is above ``cutoff`` and the sound is 'normal', or is at
    most ``cutoff`` and the sound is 'anomaly'.

    Reads the notebook globals ``png_test_folder``, ``image_size``, ``model``,
    ``list_machine`` and ``cutoff`` (``cutoff`` is defined in the evaluation
    cell). Mutates ``nbPredictionOK`` and the per-machine ``dict_machine`` stats.

    Args:
        nameFilePngTotest: file name (not a full path) of a PNG inside ``png_test_folder``.

    Raises:
        IndexError: if the file name has fewer than five '_'-separated fields.
        ValueError: if the parsed machine type is not in ``list_machine``.
    """
    global nbPredictionOK  # rebound below; dict_machine/list_machine are only read or mutated

    arrName = nameFilePngTotest.split("_")  # normal_id_06_00000451_pump.png
    classPrefixReal = arrName[0]  # 'normal' or 'anomaly'
    classNameReal = arrName[4][:-4]  # 'pump' / [:-4] removes the '.png'

    test_image = image.load_img(png_test_folder + nameFilePngTotest, target_size=image_size)
    img_array = image.img_to_array(test_image)
    # Bug fix: batch the array produced by img_to_array. The original passed the raw
    # PIL image `test_image` here, silently discarding `img_array`.
    img_array = np.expand_dims(img_array, axis=0)
    # One image per call, so keep only the first (and only) row of scores.
    arr_predictions = model.predict(img_array)[0]

    # Index of the file's REAL machine class (not the argmax of the prediction).
    indiceClassReal = list_machine.index(classNameReal)
    scoreReal = arr_predictions[indiceClassReal]
    predictionOK = (
        (classPrefixReal == 'normal' and scoreReal > cutoff)
        or (classPrefixReal == 'anomaly' and scoreReal <= cutoff)
    )

    # Count ok / err for this machine.
    stats = dict_machine[classNameReal]
    if predictionOK:
        nbPredictionOK += 1
        stats['ok'] += 1
    else:
        stats['err'] += 1

    # One counter was just incremented, so total >= 1 and the division is safe.
    total = stats['ok'] + stats['err']
    stats['accuracy'] = round(stats['ok'] / total, 4)

In [4]:
# Decision rules:
#   a sound is considered normal if the model gives its real machine class a score > cutoff
#   a sound is considered an anomaly if that score is <= cutoff
#   NOTE(review): if the model outputs softmax probabilities (not shown here) and cutoff >= 0.5,
#   "score > cutoff" also implies the real class is the predicted (argmax) class.
# Then count the correct predictions.

cutoff = 0.7

# Reset the running counters so re-running this cell does not accumulate into the
# results of a previous run (they are otherwise only initialized in the config cell).
indiceFile = 0
nbPredictionOK = 0
for stats in dict_machine.values():
    stats.update(err=0, ok=0, accuracy=0)

# Browse test files. Keep only .png files so the progress total matches what is processed.
# Sort them for a deterministic processing order, since os.listdir order is arbitrary.
wavfiles = sorted(
    f for f in listdir(png_test_folder)
    if isfile(join(png_test_folder, f)) and f.endswith('.png')
)
nbWavs = len(wavfiles)
print('nbWavs: ', nbWavs)
for nameFilePngTotest in wavfiles:
    predict_one_image(nameFilePngTotest)
    indiceFile += 1
    if indiceFile % 100 == 0:
        print('indiceFile: ', indiceFile, ' / ', nbWavs)

# Guard against an empty test folder (the original raised ZeroDivisionError).
accuracy = nbPredictionOK / indiceFile if indiceFile else 0.0
print('indiceFile: ', indiceFile, ' / ', nbWavs, ' nbPredictionOK: ', nbPredictionOK, ' accuracy: ', accuracy)
print('dict_machine: ', dict_machine)

nbWavs:  6812
indiceFile:  100  /  6812
indiceFile:  200  /  6812
indiceFile:  300  /  6812
indiceFile:  400  /  6812
indiceFile:  500  /  6812
indiceFile:  600  /  6812
indiceFile:  700  /  6812
indiceFile:  800  /  6812
indiceFile:  900  /  6812
indiceFile:  1000  /  6812
indiceFile:  1100  /  6812
indiceFile:  1200  /  6812
indiceFile:  1300  /  6812
indiceFile:  1400  /  6812
indiceFile:  1500  /  6812
indiceFile:  1600  /  6812
indiceFile:  1700  /  6812
indiceFile:  1800  /  6812
indiceFile:  1900  /  6812
indiceFile:  2000  /  6812
indiceFile:  2100  /  6812
indiceFile:  2200  /  6812
indiceFile:  2300  /  6812
indiceFile:  2400  /  6812
indiceFile:  2500  /  6812
indiceFile:  2600  /  6812
indiceFile:  2700  /  6812
indiceFile:  2800  /  6812
indiceFile:  2900  /  6812
indiceFile:  3000  /  6812
indiceFile:  3100  /  6812
indiceFile:  3200  /  6812
indiceFile:  3300  /  6812
indiceFile:  3400  /  6812
indiceFile:  3500  /  6812
indiceFile:  3600  /  6812
indiceFile:  3700  /  6