# Model Check: Is it good?
Let's quickly see how well the trained model performs on the held-out test data.

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, classification_report
import kagglehub
import pathlib

# 1. Load Data & Model
# Fetch the dataset through kagglehub; the returned value is used below as its local root directory.
path = kagglehub.dataset_download("birdy654/cifake-real-and-ai-generated-synthetic-images")

# Build the evaluation dataset from the `test` sub-directory.
# shuffle=False keeps the iteration order fixed; the later cells iterate the
# dataset more than once and need labels and predictions to line up.
test_ds = tf.keras.utils.image_dataset_from_directory(
    pathlib.Path(path).joinpath('test'),
    batch_size=32,
    image_size=(32, 32),
    shuffle=False,
)
class_names = test_ds.class_names

# Trained classifier saved in HDF5 format next to the notebook.
model = tf.keras.models.load_model('ai_detector_model.h5')
print("Loaded model & dataset.")

In [None]:
# 2. Get Predictions
# Ground-truth labels, collected batch by batch. This pass and the predict()
# pass below each iterate test_ds separately, so they only line up because the
# dataset was created with shuffle=False.
y_true = np.concatenate([batch_labels for _, batch_labels in test_ds], axis=0)
y_pred_probs = model.predict(test_ds, verbose=0)

# Convert raw model outputs into class indices.
if y_pred_probs.ndim == 1 or y_pred_probs.shape[-1] == 1:
    # Single-output head: argmax over one column would return 0 for every
    # sample, so threshold the score instead.
    # NOTE(review): 0.5 assumes the model ends in a sigmoid (probability of
    # class index 1) — confirm against the saved model.
    y_pred = (y_pred_probs.reshape(-1) > 0.5).astype(int)
else:
    # One score per class: pick the highest-scoring class.
    y_pred = np.argmax(y_pred_probs, axis=1)

print(classification_report(y_true, y_pred, target_names=class_names))

In [None]:
# 3. Visual Confusion Matrix
# Plot true vs. predicted labels, with the dataset's class names on the axes.
ConfusionMatrixDisplay.from_predictions(
    y_true,
    y_pred,
    cmap='Blues',
    display_labels=class_names,
)
plt.gca().set_title("Confusion Matrix")
plt.show()

In [None]:
# 4. Metrics Visualization (Precision, Recall, F1 per class)
# Grouped bar chart: one group per class, one bar per metric.
report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)
metrics = ['precision', 'recall', 'f1-score']
x = np.arange(len(class_names))  # left edge position of each class group
width = 0.25  # width of a single bar

plt.figure(figsize=(10, 6))
for i, metric in enumerate(metrics):
    # Scores for this metric in class_names order; each metric's bars are
    # shifted right by i bar-widths so they sit side by side.
    values = [report[cls][metric] for cls in class_names]
    plt.bar(x + i * width, values, width, label=metric)

plt.xlabel('Classes')
plt.ylabel('Score')
plt.title('Detailed Performance Metrics')
# Centre each tick under its group of bars. The offset is derived from
# len(metrics) so labels stay centred if the metric list changes; for the
# three metrics above it equals x + width.
plt.xticks(x + width * (len(metrics) - 1) / 2, class_names)
plt.legend()
plt.ylim(0, 1.1)  # upper limit slightly above the maximum possible score of 1
plt.show()

In [None]:
# 5. Quick Visual Check (First batch)
# Show up to nine images from the first test batch with true and predicted
# class names; the title is green when they match and red otherwise.
# NOTE(review): test_ds is created with shuffle=False, so this first batch may
# contain images from a single class only — confirm that is acceptable here.
images, labels = next(iter(test_ds))
label_ids = labels.numpy()  # plain NumPy integers for list indexing and comparison

probs = model.predict(images, verbose=0)
if probs.ndim == 1 or probs.shape[-1] == 1:
    # Single-output head: argmax over one column would always give 0.
    # NOTE(review): 0.5 assumes a sigmoid output — confirm against the saved model.
    preds = (probs.reshape(-1) > 0.5).astype(int)
else:
    preds = np.argmax(probs, axis=1)

# The 3x3 grid holds at most nine images; a smaller batch must not index past its end.
n_show = min(9, len(images))

plt.figure(figsize=(10, 10))
for i in range(n_show):
    plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    true_idx, pred_idx = int(label_ids[i]), int(preds[i])
    plt.title(
        f"True: {class_names[true_idx]}\nPred: {class_names[pred_idx]}",
        color=('green' if true_idx == pred_idx else 'red'),
    )
    plt.axis("off")
plt.show()