In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import sys
import os

In [None]:
# Configuration
# Make the parent of the current working directory importable so the
# project-local `src` package resolves. Guarded so that re-running this cell
# does not keep appending duplicate entries to sys.path.
_project_root = os.path.abspath('..')
if _project_root not in sys.path:
    sys.path.append(_project_root)
from src.data_loader import get_data_generators

# Paths are relative to the notebook's working directory.
MODEL_PATH = '../models/cloud_model_best.h5'
DATA_DIR = '../data/raw'
IMG_SIZE = (224, 224)  # image size passed to the data generators below

In [None]:
# 1. Load Data and Model
print("--- Loading Validation Data and Model ---")
_, val_gen = get_data_generators(DATA_DIR, img_size=IMG_SIZE, batch_size=32)

# Predictions are later matched to `val_gen.classes` by position, so the
# validation samples must be served in a fixed, file-ordered sequence. The call
# above does not pass a shuffle option, so shuffling is switched off on the
# returned generator instead.
val_gen.shuffle = False
# Flipping `shuffle` alone does not discard an index order that may already
# have been drawn (it would be reused by `model.predict`). For Keras-style
# iterators (assumed here from `.shuffle` / `.reset()` / `.classes`),
# `on_epoch_end()` rebuilds the index order, which is sequential now that
# shuffle is off. `reset()` then rewinds the batch pointer.
val_gen.on_epoch_end()
val_gen.reset()

print(f"Loading model from: {MODEL_PATH}")
model = tf.keras.models.load_model(MODEL_PATH)

In [None]:
# 2. Generate Predictions
print("--- Generating Predictions ---")
Y_pred_probs = model.predict(val_gen, verbose=0)  # one row of class scores per sample
y_pred = np.argmax(Y_pred_probs, axis=1)           # predicted class index per sample
y_true = val_gen.classes                            # ground-truth class index per sample

# The next cell compares predictions and labels position-by-position; a length
# mismatch would make that comparison fail or become meaningless, so stop here.
assert len(y_pred) == len(y_true), (
    f"Got {len(y_pred)} predictions for {len(y_true)} labels")

# Order the names by their class index so that class_names[i] is the name of
# index i, instead of relying on the dict's iteration order.
class_names = sorted(val_gen.class_indices, key=val_gen.class_indices.get)

In [None]:
# 3. Identify Error Indices
errors = np.where(y_pred != y_true)[0]
print(f"Number of misclassified images: {len(errors)} out of {len(y_true)}.")

Visualization of the top mistakes (the model's most confident misclassifications)

In [None]:
def plot_mistakes(num_to_show=5):
    """
    Visualize the 'worst' mistakes: images where the model was highly
    confident but wrong.

    The misclassified samples (global ``errors`` indices) are ranked by the
    probability the model assigned to its (wrong) predicted class. The top
    ``num_to_show`` are drawn one per row, titled with ground truth,
    prediction and confidence.

    Relies on notebook globals from earlier cells: ``errors``,
    ``Y_pred_probs``, ``y_pred``, ``y_true``, ``class_names`` and ``val_gen``.

    Parameters
    ----------
    num_to_show : int, default 5
        Maximum number of mistakes to plot. Values <= 0 plot nothing.
    """
    if num_to_show <= 0:
        return

    # Confidence of each wrong decision = probability of the predicted class.
    error_confidences = Y_pred_probs[errors, y_pred[errors]]

    # Highest-confidence mistakes first; 'stable' keeps tied entries in
    # ascending sample order.
    ranking = np.argsort(-error_confidences, kind='stable')[:num_to_show]
    top_errors = [(error_confidences[r], errors[r]) for r in ranking]

    if not top_errors:
        print("No errors found! The model is perfect on this dataset.")
        return

    # Size the figure by what is actually plotted, not by the requested count.
    n_rows = len(top_errors)
    fig, axes = plt.subplots(n_rows, 1, figsize=(15, 6 * n_rows), squeeze=False)

    # Take the batch size from the generator instead of hard-coding it, so the
    # index arithmetic stays correct if the loader's batch size changes.
    batch_size = val_gen.batch_size

    for rank, ((conf, img_idx), ax) in enumerate(zip(top_errors, axes[:, 0]), start=1):
        batch_idx, in_batch_idx = divmod(int(img_idx), batch_size)
        # Fetch just the one batch by index (Sequence-style `__getitem__`)
        # rather than resetting and replaying the generator for every image.
        # Valid because shuffling was disabled, so batch order = sample order.
        img = val_gen[batch_idx][0][in_batch_idx]
        true_label = class_names[y_true[img_idx]]
        pred_label = class_names[y_pred[img_idx]]

        # NOTE(review): assumes batch images are in a range imshow renders
        # sensibly (e.g. rescaled to [0, 1]); confirm against get_data_generators.
        ax.imshow(img)
        ax.axis('off')

        # Title with error details
        title_text = (f"MISTAKE #{rank}\n"
                      f"Ground Truth: {true_label} | Prediction: {pred_label}\n"
                      f"Model Confidence: {conf:.2%}")
        ax.set_title(title_text, fontsize=14, color='red', fontweight='bold')

    fig.tight_layout()
    plt.show()

# Run the visualization
plot_mistakes(num_to_show=5)