# In [3]:
from sklearn.metrics import classification_report
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, f1_score

def create_confusion_matrix(true_labels, predicted_labels, class_names):
    """Show the confusion matrix of the given labels as an annotated heatmap.

    Args:
        true_labels: Ground-truth labels, passed to
            ``sklearn.metrics.confusion_matrix`` as the first argument.
        predicted_labels: Predicted labels, passed as the second argument.
        class_names: Tick labels used on both heatmap axes.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    matrix = confusion_matrix(true_labels, predicted_labels)

    heatmap_options = {
        "annot": True,  # write the value inside every cell
        "fmt": "d",     # integer number format for the annotations
        "cmap": "Blues",
        "xticklabels": class_names,
        "yticklabels": class_names,
    }

    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, **heatmap_options)

    # Axis captions and title are independent of each other.
    plt.title('Confusion Matrix')
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.show()

def calculate_metrics(model):
    """Predict on the module-level ``test_data`` and print a classification report.

    Args:
        model: Object whose ``predict(test_data)`` returns the predictions.

    Returns:
        A ``(predictions, true_labels)`` tuple: the raw, unmodified output of
        ``model.predict`` and the labels extracted from ``test_data`` as a list.
    """
    predictions = model.predict(test_data)

    # Keep only the label of each (features, label) pair, undo batching and
    # materialize the result as a list of numpy values.
    # NOTE(review): this re-iterates ``test_data``; if the dataset is shuffled
    # on each pass, label order may not match prediction order -- confirm.
    label_dataset = test_data.map(lambda _features, label: label).unbatch()
    true_labels = list(tfds.as_numpy(label_dataset))

    # classification_report needs discrete class labels. When the model emits
    # one score per class for every sample, reduce each row to the index of
    # its highest score; otherwise pass the predictions through untouched.
    scores = np.asarray(predictions)
    if scores.ndim > 1 and scores.shape[-1] > 1:
        predicted_labels = scores.argmax(axis=-1)
    else:
        predicted_labels = predictions

    print(classification_report(true_labels, predicted_labels))

    # The raw predictions are returned (not the argmax) so callers can still
    # inspect the per-class scores.
    return predictions, true_labels

def find_most_wrong_predictions(true_labels, predictions):
    """Print the ten misclassified samples with the highest predicted score.

    Args:
        true_labels: Indexable collection holding the true label of each sample.
        predictions: Sequence of per-sample score rows; the predicted label of
            a sample is the argmax of its row.

    Returns:
        None. One line per reported sample is written to stdout.
    """
    mistakes = []

    for position, scores in enumerate(predictions):
        expected = true_labels[position]
        guessed = np.argmax(scores)
        confidence = scores[guessed]

        if expected != guessed:
            mistakes.append((position, expected, guessed, confidence))

    # Most confident mistakes first; only the top ten are reported.
    worst_first = sorted(mistakes, key=lambda entry: entry[3], reverse=True)

    for position, expected, guessed, confidence in worst_first[:10]:
        print(f"Index: {position}, True Label: {expected}, Predicted Label: {guessed}, Probability: {confidence}")


def plot_f1_scores(true_labels, predicted_labels, class_names):
    """Show a bar chart with one F1-score bar per class.

    Args:
        true_labels: Ground-truth labels, passed to ``sklearn.metrics.f1_score``
            as the first argument.
        predicted_labels: Predicted labels, passed as the second argument.
        class_names: Bar labels placed along the x-axis.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # average=None asks f1_score for the individual per-class scores rather
    # than a single averaged number.
    per_class_f1 = f1_score(true_labels, predicted_labels, average=None)

    plt.figure(figsize=(10, 6))
    plt.bar(class_names, per_class_f1)

    plt.title('F1-Scores by Class')
    plt.ylabel('F1-Score')
    plt.xlabel('Class')
    plt.xticks(rotation=90)  # vertical tick labels
    plt.show()

def visualize_predictions(test_images, true_labels, predicted_labels, class_names,
                          prediction_probabilities, num_to_show=5):
    """Show randomly chosen test images with their predicted and true classes.

    Args:
        test_images: Sized, indexable collection of images accepted by
            ``plt.imshow``.
        true_labels: Per-sample true labels; each is used as an index into
            ``class_names``.
        predicted_labels: Per-sample predicted labels; each is used as an index
            into ``class_names``.
        class_names: Sequence mapping a label to its display name.
        prediction_probabilities: Per-sample value formatted with ``:.2f`` in
            the subplot title, so each entry must be a single number.
        num_to_show: Maximum number of images to display. Defaults to 5.

    Returns:
        None. The figure is displayed with ``plt.show()``; nothing is drawn
        when there are no images or ``num_to_show`` is not positive.
    """
    num_samples = len(test_images)

    # Sampling without replacement cannot draw more items than exist, so clamp
    # the request instead of letting np.random.choice raise ValueError.
    count = min(num_to_show, num_samples)
    if count <= 0:
        return

    sample_indices = np.random.choice(num_samples, size=count, replace=False)

    # Three inches of width per image keeps the 15x6 figure for five images.
    plt.figure(figsize=(3 * count, 6))
    for slot, idx in enumerate(sample_indices, start=1):
        plt.subplot(1, count, slot)
        plt.imshow(test_images[idx])
        plt.title(f"Prediction: {class_names[predicted_labels[idx]]}\n"
                  f"Probability: {prediction_probabilities[idx]:.2f}\n"
                  f"Ground Truth: {class_names[true_labels[idx]]}")
    plt.show()