In [None]:
import numpy as np

import matplotlib.pyplot as plt

def show_most_confused_examples(conf_matrix, y_true, y_pred, images, class_names, num_examples=5):
    """Plot samples from the most frequently confused (true, predicted) class pair.

    The pair is the largest off-diagonal entry of ``conf_matrix``, read as
    ``conf_matrix[true_class, predicted_class]`` (rows = true labels,
    columns = predictions). Ties resolve to the first pair in row-major order
    (``np.argmax`` semantics).

    Args:
        conf_matrix: Square confusion matrix (array-like); it is copied, not
            modified.
        y_true: True class index per sample (array-like).
        y_pred: Predicted class index per sample, aligned with ``y_true``.
        images: Indexable collection of images aligned with ``y_true``;
            2-D images are drawn with a grayscale colormap.
        class_names: Sequence mapping class index to display name.
        num_examples: Maximum number of samples to plot (values <= 0 plot none).

    Returns:
        ``(class_a, class_b, selected_idx)``: the true class, the predicted
        class, and the sample indices that were plotted. If the matrix is
        empty or has no off-diagonal counts (no misclassifications), returns
        ``(None, None, <empty int array>)`` and plots nothing.
    """
    # np.array(..., copy=True) also accepts nested lists, unlike .copy().
    conf_matrix_no_diag = np.array(conf_matrix, copy=True)
    if conf_matrix_no_diag.size == 0:
        print("Confusion matrix is empty; nothing to show.")
        return None, None, np.array([], dtype=int)
    np.fill_diagonal(conf_matrix_no_diag, 0)

    flat_idx = np.argmax(conf_matrix_no_diag)
    # Without this guard a perfect matrix yields argmax 0 -> pair (0, 0), and
    # correctly classified class-0 samples would be reported as "confused".
    if conf_matrix_no_diag.flat[flat_idx] <= 0:
        print("No misclassifications found in the confusion matrix.")
        return None, None, np.array([], dtype=int)
    class_a, class_b = np.unravel_index(flat_idx, conf_matrix_no_diag.shape)

    # asarray makes the comparisons elementwise even when labels arrive as lists.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    misclassified_idx = np.where((y_true == class_a) & (y_pred == class_b))[0]
    # Clamp so a negative count selects nothing instead of dropping the tail.
    selected_idx = misclassified_idx[:max(num_examples, 0)]

    print(f"Most confused classes: {class_names[class_a]} (true) vs {class_names[class_b]} (predicted)")
    print(f"Showing {len(selected_idx)} examples:")

    if len(selected_idx) == 0:
        # Nothing to draw; avoid creating and showing an empty figure.
        return class_a, class_b, selected_idx

    # Size the grid by what is actually shown, not by the requested maximum.
    n_cols = len(selected_idx)
    plt.figure(figsize=(15, 3))
    for i, idx in enumerate(selected_idx):
        plt.subplot(1, n_cols, i + 1)
        plt.imshow(images[idx], cmap='gray' if images[idx].ndim == 2 else None)
        plt.title(f"True: {class_names[class_a]}\nPred: {class_names[class_b]}")
        plt.axis('off')
    plt.show()

    return class_a, class_b, selected_idx