In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import image
import logging
logging.getLogger('absl').setLevel(logging.ERROR)

In [2]:

# Per-architecture settings: (preprocess_input, base model class, (height, width)).
# Keeping the input size next to each architecture avoids a separate special
# case; Xception is the only 299x299 model in this table.
_MODEL_CONFIG = {
    'VGG19': (tf.keras.applications.vgg19.preprocess_input, tf.keras.applications.VGG19, (224, 224)),
    'ResNet50': (tf.keras.applications.resnet50.preprocess_input, tf.keras.applications.ResNet50, (224, 224)),
    'VGG16': (tf.keras.applications.vgg16.preprocess_input, tf.keras.applications.VGG16, (224, 224)),
    'MobileNetV2': (tf.keras.applications.mobilenet_v2.preprocess_input, tf.keras.applications.MobileNetV2, (224, 224)),
    'Xception': (tf.keras.applications.xception.preprocess_input, tf.keras.applications.Xception, (299, 299)),
    'EfficientNetB0': (tf.keras.applications.efficientnet.preprocess_input, tf.keras.applications.EfficientNetB0, (224, 224)),
    'DenseNet121': (tf.keras.applications.densenet.preprocess_input, tf.keras.applications.DenseNet121, (224, 224)),
}

# Feature extractors keyed by architecture name. Building one loads ImageNet
# weights, so it is done once per architecture rather than once per image.
_FEATURE_EXTRACTOR_CACHE = {}
# Head models keyed by absolute path -> (file mtime, model), so a re-saved .h5
# file is reloaded automatically instead of serving a stale cached model.
_HEAD_MODEL_CACHE = {}


def _get_feature_extractor(model_name):
    """Return the cached ImageNet backbone + global average pooling for `model_name`."""
    if model_name not in _FEATURE_EXTRACTOR_CACHE:
        _, base_model_class, (img_h, img_w) = _MODEL_CONFIG[model_name]
        base_model = base_model_class(weights='imagenet', include_top=False, input_shape=(img_h, img_w, 3))
        _FEATURE_EXTRACTOR_CACHE[model_name] = tf.keras.Model(
            inputs=base_model.input,
            outputs=tf.keras.layers.GlobalAveragePooling2D()(base_model.output),
        )
    return _FEATURE_EXTRACTOR_CACHE[model_name]


def _get_head_model(model_path):
    """Return the custom head model saved at `model_path`, reloading it if the file changed."""
    key = os.path.abspath(model_path)
    mtime = os.path.getmtime(model_path)
    cached = _HEAD_MODEL_CACHE.get(key)
    if cached is None or cached[0] != mtime:
        cached = (mtime, tf.keras.models.load_model(model_path))
        _HEAD_MODEL_CACHE[key] = cached
    return cached[1]


def predict_image(model_path, image_path, threshold=0.5):
    """Classify one image with a saved custom-head model.

    The backbone architecture is taken from the model file name prefix (the
    text before the first '_', e.g. 'VGG16_custom_head_model.h5' -> 'VGG16').
    The image is resized and preprocessed for that architecture and passed
    through the ImageNet backbone with global average pooling. The resulting
    feature vector is then scored by the head model loaded from `model_path`.
    Backbones and head models are cached between calls.

    Args:
        model_path: Path to the saved head model (.h5 file).
        image_path: Path to the image file to classify.
        threshold: Score above which the image is labelled 'overriding'.
            The default of 0.5 matches the previously hard-coded value.

    Returns:
        Tuple (model_name, image_file_name, label). `label` is 'overriding'
        when the head's first output exceeds `threshold`, otherwise 'oblique'.

    Raises:
        KeyError: if the file name prefix is not a supported architecture.
    """
    model_name = os.path.basename(model_path).split('_')[0]
    if model_name not in _MODEL_CONFIG:
        raise KeyError(
            f"Unsupported architecture {model_name!r} (from {model_path!r}); "
            f"expected one of {sorted(_MODEL_CONFIG)}"
        )
    preprocess_input, _, (img_h, img_w) = _MODEL_CONFIG[model_name]

    feature_extractor = _get_feature_extractor(model_name)
    custom_head_model = _get_head_model(model_path)

    # Load the image at the backbone's input size and add a batch axis.
    img = image.load_img(image_path, target_size=(img_h, img_w))
    img_batch = np.expand_dims(image.img_to_array(img), axis=0)
    features = feature_extractor.predict(preprocess_input(img_batch), verbose=0)
    prediction = custom_head_model.predict(features, verbose=0)

    # NOTE(review): assumes the head has a single sigmoid output where
    # class 0 = 'oblique' and class 1 = 'overriding'. Confirm against training.
    label = 'overriding' if prediction[0][0] > threshold else 'oblique'
    return (model_name, os.path.basename(image_path), label)


In [3]:
# Batch evaluation: run every saved head model on every test image.
# The base folder can be overridden with the GP_BASE_FOLDER environment
# variable so the notebook also runs outside the original author's machine.
Base_Folder = os.environ.get(
    'GP_BASE_FOLDER',
    'D:/Learning/University of sadat/Grade 4/Semester 2/06- Graduation Project/Coding/',
)
# The trailing '' keeps a trailing separator, as the original string paths had.
Models_folder = os.path.join(Base_Folder, '05- Saved Models', '')
Images_Folder = os.path.join(Base_Folder, 'Test Data', '')

# Extensions accepted as test images. Anything else in the folder (e.g.
# Thumbs.db, desktop.ini, sub-folders) would make image loading fail.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

# Listings are sorted so the result order does not depend on the filesystem.
models = [
    os.path.join(Models_folder, file)
    for file in sorted(os.listdir(Models_folder))
    # Require the dot: a name merely ending in "h5" is not a model file.
    if file.lower().endswith('.h5')
]

test_images = [
    file
    for file in sorted(os.listdir(Images_Folder))
    if file.lower().endswith(IMAGE_EXTENSIONS)
    and os.path.isfile(os.path.join(Images_Folder, file))
]

Results = [
    predict_image(model, os.path.join(Images_Folder, img))
    for img in test_images
    for model in models
]

print(Results)

for model_name,img_name,answer  in Results:
    print(f"the prediction result of the \033[91m{model_name}\033[0m model for this image {img_name} is \033[92m {answer}\033[0m class ")









[('DenseNet121', 'o1.png', 'oblique'), ('EfficientNetB0', 'o1.png', 'oblique'), ('MobileNetV2', 'o1.png', 'oblique'), ('ResNet50', 'o1.png', 'oblique'), ('VGG16', 'o1.png', 'oblique'), ('VGG19', 'o1.png', 'oblique'), ('Xception', 'o1.png', 'oblique'), ('DenseNet121', 'o2.png', 'oblique'), ('EfficientNetB0', 'o2.png', 'oblique'), ('MobileNetV2', 'o2.png', 'oblique'), ('ResNet50', 'o2.png', 'oblique'), ('VGG16', 'o2.png', 'oblique'), ('VGG19', 'o2.png', 'oblique'), ('Xception', 'o2.png', 'oblique'), ('DenseNet121', 'v1.png', 'overriding'), ('EfficientNetB0', 'v1.png', 'overriding'), ('MobileNetV2', 'v1.png', 'oblique'), ('ResNet50', 'v1.png', 'overriding'), ('VGG16', 'v1.png', 'overriding'), ('VGG19', 'v1.png', 'oblique'), ('Xception', 'v1.png', 'overriding'), ('DenseNet121', 'v2.png', 'overriding'), ('EfficientNetB0', 'v2.png', 'overriding'), ('MobileNetV2', 'v2.png', 'overriding'), ('ResNet50', 'v2.png', 'overriding'), ('VGG16', 'v2.png', 'overriding'), ('VGG19', 'v2.png', 'overriding'