In [5]:
import cv2
import numpy as np
import time
import os
from tensorflow.keras.models import load_model
from tabulate import tabulate  # Import table-printing library
from scipy import stats

# Class labels, each mapped to the test-set directory holding its images
labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
TEST_ROOT = './Dataset/cropped/Testing'
label_dict = {label: f'{TEST_ROOT}/{label}' for label in labels}

# Trained model checkpoints to benchmark, in reporting order
model_files = [
    'bestresnet4.h5',
    'bestvgg19-2.h5',
    'bestXception2.h5',
    'bestinception2.h5',
    'bestMobileNetV2.h5',
    'bestNASNetLarge.h5',
]

# Square input side length (pixels) for each model: 200 by default,
# with NASNetLarge as the single exception at 331
model_input_sizes = {model_file: 200 for model_file in model_files}
model_input_sizes['bestNASNetLarge.h5'] = 331

# Image preprocessing function
def preprocess_image(image_path, image_size):
    """Load one image and apply the preprocessing used before inference.

    Steps: read as a single grayscale channel, denoise with a bilateral
    filter, map to a pseudo-colour image with the BONE colormap, then resize
    to a square ``image_size`` x ``image_size``.

    Args:
        image_path: Path to an image file readable by OpenCV.
        image_size: Target side length in pixels.

    Returns:
        The preprocessed image as returned by ``cv2.resize``. After
        ``applyColorMap`` it has 3 colour channels (OpenCV BGR order).

    Raises:
        OSError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)  # load as grayscale
    if image is None:
        # cv2.imread reports failure by returning None instead of raising;
        # without this check the problem surfaces later as an opaque cv2.error.
        raise OSError(f"Could not read image: {image_path}")
    image = cv2.bilateralFilter(image, 2, 50, 50)  # denoise
    image = cv2.applyColorMap(image, cv2.COLORMAP_BONE)  # pseudo-colour mapping
    image = cv2.resize(image, (image_size, image_size))  # resize to model input size
    return image

# Get all image paths
def get_image_paths(label_dict, extensions=('.jpg', '.jpeg', '.png')):
    """Collect image file paths from every directory listed in ``label_dict``.

    Args:
        label_dict: Mapping of class label -> directory. Each directory is
            walked recursively; the label is only used in warning messages.
        extensions: Accepted file suffixes, compared case-insensitively so
            e.g. ``.JPG`` files are not silently dropped.

    Returns:
        List of image paths, grouped in ``label_dict`` order. Subdirectories
        and files are visited in sorted order so repeated runs see the same
        sequence.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    image_paths = []
    for label, path in label_dict.items():
        if not os.path.isdir(path):
            # os.walk yields nothing for a missing directory; make that visible
            # instead of silently benchmarking on fewer (or zero) images.
            print(f"Warning: directory for label '{label}' not found: {path}")
            continue
        for root, dirs, files in os.walk(path):
            dirs.sort()  # in-place sort fixes os.walk's descent order
            for file in sorted(files):
                if file.lower().endswith(suffixes):
                    image_paths.append(os.path.join(root, file))
    return image_paths

# Count the prediction time of each model
def calculate_prediction_times(model_files, model_input_sizes, image_paths):
    """Measure single-image inference latency of each model.

    Each model is loaded from its file and warmed up with one untimed
    prediction. It is then timed on every image in ``image_paths`` at batch
    size 1. Images that fail to preprocess are reported and skipped.

    Args:
        model_files: Model file paths passed to ``load_model``.
        model_input_sizes: Mapping of model file -> square input side length.
        image_paths: Image files to run through every model.

    Returns:
        Mapping of model file -> list of per-image prediction times in
        milliseconds, in ``image_paths`` order (skipped images omitted).
    """
    # Local import: tensorflow is already a dependency of this notebook
    # (load_model comes from tensorflow.keras).
    from tensorflow.keras import backend as keras_backend

    prediction_times = {model_file: [] for model_file in model_files}

    for model_file in model_files:
        model = load_model(model_file)
        image_size = model_input_sizes[model_file]
        warmed_up = False

        for image_path in image_paths:
            try:
                image = preprocess_image(image_path, image_size)
            except Exception as e:
                print(f"Error processing image {image_path}: {e}")
                continue

            # applyColorMap already gives a 3-channel image, so only the batch
            # axis is missing. The previous extra channel axis produced a 5-D
            # tensor that presumably worked only because Keras squeezes a
            # trailing size-1 axis.
            batch = np.expand_dims(image, axis=0)

            if not warmed_up:
                # One untimed call per model is enough to build the predict
                # function; repeating it for every image doubled the run time.
                model.predict(batch, verbose=0)
                warmed_up = True

            # verbose=0 keeps Keras' per-call progress bar out of the output and
            # out of the timed region; perf_counter is the monotonic,
            # high-resolution clock intended for interval timing.
            start_time = time.perf_counter()
            model.predict(batch, verbose=0)
            elapsed_ms = (time.perf_counter() - start_time) * 1000  # to milliseconds
            prediction_times[model_file].append(elapsed_ms)

        # `del` only drops this Python reference; clear_session also resets the
        # global Keras state that otherwise accumulates across loaded models.
        del model
        keras_backend.clear_session()

    return prediction_times

# Calculate statistics
def calculate_statistics(data):
    """Summarise per-model timing samples.

    Args:
        data: Mapping of model name -> sequence of prediction times (ms).

    Returns:
        Mapping of model name -> dict with keys 'min_time', 'max_time',
        'mean_time', 'q1_time' (25th percentile), 'median_time' and
        'q3_time' (75th percentile). A model with no samples (e.g. every
        image failed to preprocess, or no images were found) gets NaN for
        every statistic instead of aborting the whole report.
    """
    stat_keys = ('min_time', 'max_time', 'mean_time', 'q1_time', 'median_time', 'q3_time')
    statistics = {}
    for model_file, times in data.items():
        if len(times) == 0:
            # np.min / np.max raise ValueError on empty input.
            statistics[model_file] = dict.fromkeys(stat_keys, float('nan'))
            continue
        statistics[model_file] = {
            'min_time': np.min(times),
            'max_time': np.max(times),
            'mean_time': np.mean(times),
            'q1_time': np.percentile(times, 25),
            'median_time': np.median(times),
            'q3_time': np.percentile(times, 75),
        }
    return statistics

# Run the benchmark: gather test images, time every model, then summarise
image_paths = get_image_paths(label_dict)
prediction_times = calculate_prediction_times(model_files, model_input_sizes, image_paths)
statistics = calculate_statistics(prediction_times)

# One row per model: its name followed by the six timing statistics (ms)
stat_columns = ('min_time', 'max_time', 'mean_time', 'q1_time', 'median_time', 'q3_time')
table_data = [
    [model_file, *(stats_data[column] for column in stat_columns)]
    for model_file, stats_data in statistics.items()
]

# Print result table
headers = ['Model', 'Min Time (ms)', 'Max Time (ms)', 'Mean Time (ms)',
           'Q1 Time (ms)', 'Median Time (ms)', 'Q3 Time (ms)']
print(tabulate(table_data, headers=headers, tablefmt='grid'))


+--------------------+-----------------+-----------------+------------------+----------------+--------------------+----------------+
| Model              |   Min Time (ms) |   Max Time (ms) |   Mean Time (ms) |   Q1 Time (ms) |   Median Time (ms) |   Q3 Time (ms) |
| bestresnet4.h5     |         76.4611 |         249.069 |          90.0858 |        84.7486 |            91.0397 |        95.0003 |
+--------------------+-----------------+-----------------+------------------+----------------+--------------------+----------------+
| bestvgg19-2.h5     |         97.8785 |         301.134 |         128.627  |       121.032  |           128.944  |       136.457  |
+--------------------+-----------------+-----------------+------------------+----------------+--------------------+----------------+
| bestXception2.h5   |         73.9999 |         254.095 |          84.9209 |        82.7417 |            84.6989 |        86.5954 |
+--------------------+-----------------+-----------------+-----------

In [2]:
# !pip install tabulate


Collecting tabulate
  Downloading tabulate-0.9.0-py3-none-any.whl (35 kB)
Installing collected packages: tabulate
Successfully installed tabulate-0.9.0


