In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf

def get_logits(model_path, test_dir, image_size=(224, 224), batch_size=32,
               extensions=('.jpg', '.jpeg')):
    """Run one saved Keras model over every JPEG image found under ``test_dir``.

    Images are discovered recursively, decoded as 3-channel JPEGs, resized to
    ``image_size`` and scaled to the [0, 1] range before being fed to the model.

    Args:
        model_path (str): Path to a model file loadable by
            ``tf.keras.models.load_model``.
        test_dir (str): Root directory that is walked recursively for images.
        image_size (tuple): ``(height, width)`` the images are resized to.
        batch_size (int): Number of images per prediction batch.
        extensions (tuple or str): Case-insensitive file suffixes treated as
            JPEG images.

    Returns:
        list: The flattened model outputs as Python floats, ordered by sorted
        image path. Whether these values are true logits or probabilities
        depends on the model's final activation.

    Raises:
        ValueError: If no matching image exists under ``test_dir``.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    # Discover images before loading the model: the walk is cheap, whereas
    # loading a large network only to find nothing to predict on is not.
    # Sorting makes the output order reproducible (os.walk order is arbitrary).
    image_files = sorted(
        os.path.join(root, file)
        for root, _, files in os.walk(test_dir)
        for file in files
        if file.lower().endswith(suffixes)
    )
    if not image_files:
        raise ValueError(
            f"No images with extensions {suffixes} found under: {test_dir}"
        )

    def process_image(path):
        img = tf.io.read_file(path)
        img = tf.image.decode_jpeg(img, channels=3)
        img = tf.image.resize(img, image_size)
        # NOTE(review): every model receives inputs scaled to [0, 1]; confirm
        # this matches the preprocessing each model was trained with.
        img = tf.cast(img, tf.float32) / 255.0
        return img

    # compile=False: only inference is needed, so the optimizer/loss state
    # does not have to be restored.
    model = tf.keras.models.load_model(model_path, compile=False)
    try:
        dataset = (
            tf.data.Dataset.from_tensor_slices(image_files)
            .map(process_image, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(batch_size)
            .prefetch(tf.data.AUTOTUNE)
        )
        predictions = model.predict(dataset, verbose=0)
    finally:
        # Release the model so that calling this function for many models in
        # a row does not keep accumulating memory.
        del model
        tf.keras.backend.clear_session()

    return predictions.flatten().tolist()

def plot_logit_histogram(logits, save_path, model_name):
    """Draw a 50-bin histogram of ``logits``, save it, and return summary stats.

    The figure marks the mean and median with dashed vertical lines and shows
    mean, median and standard deviation in a text box in the top-right corner.

    Args:
        logits: Sequence (or array) of numeric model outputs. Multi-dimensional
            input is flattened before plotting.
        save_path (str): File path the figure is written to (300 dpi).
        model_name (str): Name shown in the plot title.

    Returns:
        dict: ``{'mean': ..., 'median': ..., 'std_dev': ...}`` computed with
        numpy over all values.

    Raises:
        ValueError: If ``logits`` is empty.
    """
    values = np.asarray(logits, dtype=float).ravel()
    if values.size == 0:
        # An empty input would only yield NaN statistics and a blank plot.
        raise ValueError(
            f"Cannot plot histogram for {model_name}: no logits provided"
        )

    mean_logit = np.mean(values)
    median_logit = np.median(values)
    std_dev = np.std(values)

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.hist(values, bins=50, edgecolor='black')

        ax.axvline(mean_logit, color='red', linestyle='dashed', linewidth=2,
                   label=f'Mean: {mean_logit:.3f}')
        ax.axvline(median_logit, color='green', linestyle='dashed', linewidth=2,
                   label=f'Median: {median_logit:.3f}')

        ax.set_title(f'Logit Distribution - {model_name}', fontsize=12)
        ax.set_xlabel('Logit Value', fontsize=10)
        ax.set_ylabel('Frequency', fontsize=10)
        ax.grid(True, alpha=0.3)

        # Statistics text box, positioned in axes coordinates (top-right).
        stats_text = f'Statistics:\nMean: {mean_logit:.3f}\nMedian: {median_logit:.3f}\nStd Dev: {std_dev:.3f}'
        ax.text(0.95, 0.95, stats_text, transform=ax.transAxes,
                verticalalignment='top', horizontalalignment='right',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        ax.legend()
        fig.tight_layout()

        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always close the figure, even if plotting or saving fails, so
        # figures do not pile up when this is called in a loop.
        plt.close(fig)

    return {
        'mean': mean_logit,
        'median': median_logit,
        'std_dev': std_dev
    }

In [2]:
def process_all_models(models_dir, test_dir, output_dir):
    """
    Process all models in a directory and save their histograms

    Every ``.h5`` file (extension matched case-insensitively) found recursively
    under ``models_dir`` is run through ``get_logits`` and
    ``plot_logit_histogram``. A failure in one model is reported and does not
    stop the remaining models. A summary CSV (``model_statistics.csv``) with
    one row per successfully processed model is written to ``output_dir``.

    Args:
        models_dir (str): Directory containing the model files
        test_dir (str): Directory containing test images
        output_dir (str): Directory where histograms will be saved

    Returns:
        None
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Collect model files up front, sorted so the processing order (and the
    # row order of the CSV) is reproducible regardless of os.walk ordering.
    model_paths = sorted(
        os.path.join(root, file)
        for root, _, files in os.walk(models_dir)
        for file in files
        if file.lower().endswith('.h5')
    )
    if not model_paths:
        print(f"No .h5 model files found under: {models_dir}")

    # Statistics of every successfully processed model, keyed by model name
    all_stats = {}
    failed_models = []
    used_names = set()

    for model_path in model_paths:
        model_name = os.path.splitext(os.path.basename(model_path))[0]
        if model_name in used_names:
            # Same file name in another sub-folder: prefix the relative folder
            # so the histogram and stats row do not overwrite the earlier one.
            rel_dir = os.path.relpath(os.path.dirname(model_path), models_dir)
            model_name = f"{rel_dir.replace(os.sep, '_')}_{model_name}"
        used_names.add(model_name)

        print(f"\nProcessing model: {model_name}")

        try:
            # Get logits
            logits = get_logits(model_path, test_dir)

            # Create histogram and save it
            histogram_path = os.path.join(output_dir, f"{model_name}_histogram.png")
            stats = plot_logit_histogram(logits, histogram_path, model_name)

            # Store statistics
            all_stats[model_name] = stats

            print(f"Processed {model_name}:")
            print(f"Mean: {stats['mean']:.3f}")
            print(f"Median: {stats['median']:.3f}")
            print(f"Std Dev: {stats['std_dev']:.3f}")
            print(f"Histogram saved to: {histogram_path}")

        except Exception as e:
            # Deliberately best-effort: one broken model must not abort the
            # whole batch, but it is recorded for the final summary.
            failed_models.append(model_name)
            print(f"Error processing {model_name}: {str(e)}")

    if failed_models:
        print(f"\nFailed models ({len(failed_models)}): {', '.join(failed_models)}")

    # Save summary statistics to CSV
    if all_stats:
        stats_df = pd.DataFrame.from_dict(all_stats, orient='index')
    else:
        # Keep the CSV schema stable even when no model was processed.
        stats_df = pd.DataFrame(columns=['mean', 'median', 'std_dev'])
    stats_path = os.path.join(output_dir, 'model_statistics.csv')
    stats_df.to_csv(stats_path)
    print(f"\nSummary statistics saved to: {stats_path}")

In [3]:
# Folder containing all your .h5 models (searched recursively)
models_dir = r"D:\Kananat\_result\model_to_test\5px"
# Folder containing test images
test_dir = r"D:\Kananat\TF_TMJOA_jpg_x_5px_test_total"
# Where to save histograms and statistics
output_dir = r"D:\Kananat\_result\result_5px\logits"

# Evaluate every model and write one histogram per model plus a summary CSV.
process_all_models(
    models_dir=models_dir,
    test_dir=test_dir,
    output_dir=output_dir,
)


Processing model: DenseNet201_bo40_lr0001
Processed DenseNet201_bo40_lr0001:
Mean: 0.987
Median: 1.000
Std Dev: 0.083
Histogram saved to: D:\Kananat\_result\result_5px\logits\DenseNet201_bo40_lr0001_histogram.png

Processing model: EfficientNetB7_bo40_lr0001
Processed EfficientNetB7_bo40_lr0001:
Mean: 1.000
Median: 1.000
Std Dev: 0.000
Histogram saved to: D:\Kananat\_result\result_5px\logits\EfficientNetB7_bo40_lr0001_histogram.png

Processing model: EfficientNetV2L_bo40_lr0001
Processed EfficientNetV2L_bo40_lr0001:
Mean: 1.000
Median: 1.000
Std Dev: 0.000
Histogram saved to: D:\Kananat\_result\result_5px\logits\EfficientNetV2L_bo40_lr0001_histogram.png

Processing model: InceptionV3_bo20_lr0001
Processed InceptionV3_bo20_lr0001:
Mean: 0.193
Median: 0.168
Std Dev: 0.043
Histogram saved to: D:\Kananat\_result\result_5px\logits\InceptionV3_bo20_lr0001_histogram.png

Processing model: MobileNetV3Large_bo20_lr0001
Processed MobileNetV3Large_bo20_lr0001:
Mean: 0.023
Median: 0.017
Std Dev: 