# In [None]:
# Import required libraries
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, precision_score, f1_score
import seaborn as sns
import matplotlib.pyplot as plt

# Load the trained model.
# `model_version` names the entry under ../models/ that holds the saved model;
# point it at another version to evaluate a different export.
model_version = 1
model_path = f"../models/{model_version}"
model = load_model(model_path)

# Load test dataset
# NOTE(review): `test_ds` is not created in this cell; it must already exist
# and yield (images, labels) batches with integer class-index labels, in the
# same format as used during training — confirm before running.
y_true = []
y_pred = []

for images, labels in test_ds:
    # Call the model directly instead of `model.predict`: Keras documents
    # `predict` as not intended for use inside a loop over batches (extra
    # per-call overhead, plus a progress bar printed for every batch).
    predictions = model(images, training=False).numpy()
    y_true.extend(labels.numpy())  # Actual labels
    y_pred.extend(np.argmax(predictions, axis=1))  # Predicted labels (argmax over classes)

# Convert lists to NumPy arrays
y_true = np.array(y_true)
y_pred = np.array(y_pred)

# Human-readable class names; position i is the name reported for class index i
# in the confusion matrix, ROC legend and classification report below.
# NOTE(review): this order must match the integer labels used in training
# (presumably the alphabetical folder order the dataset was built from) —
# confirm. The strings are intentionally kept verbatim, irregular underscores
# included.
class_names = [
    'Tomato_Bacterial_spot',
    'Tomato_Early_blight',
    'Tomato_Late_blight',
    'Tomato_Leaf_Mold',
    'Tomato_Septoria_leaf_spot',
    'Tomato_Spider_mites_Two_spotted_spider_mite',
    'Tomato__Target_Spot',
    'Tomato__Tomato_YellowLeaf__Curl_Virus',
    'Tomato__Tomato_mosaic_virus',
    'Tomato_healthy',
]

# ------------------------
# 1️⃣ Precision and F1 Score
# ------------------------
# Both scores are support-weighted averages of the per-class values
# (average='weighted'), so frequent classes influence them more than rare ones.
precision, f1 = (
    scorer(y_true, y_pred, average='weighted')
    for scorer in (precision_score, f1_score)
)

print(f"🔹 Precision: {precision:.4f}")
print(f"🔹 F1 Score: {f1:.4f}")

# ------------------------
# 2️⃣ Confusion Matrix
# ------------------------
# Pin the label set to every class index so the matrix is always
# len(class_names) x len(class_names). Without `labels=`, sklearn only uses the
# classes that occur in y_true or y_pred; a class missing from both would shrink
# the matrix so it no longer lines up with the tick labels passed below.
cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))

plt.figure(figsize=(10, 7))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted Label")
plt.ylabel("Actual Label")
plt.title("Confusion Matrix")
# The class names are long: rotate the x tick labels and let the layout make
# room for them instead of letting them overlap or get clipped.
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

# ------------------------
# 3️⃣ ROC Curve & AUC
# ------------------------
# Gather labels and class scores together in ONE pass over `test_ds`.
# Calling `model.predict(test_ds)` here would iterate the dataset again,
# independently of the loop that built `y_true`; if the pipeline reshuffles on
# each iteration (e.g. tf.data `shuffle` with its default
# reshuffle_each_iteration=True), the score rows would not line up with the
# labels and every curve/AUC would be computed on mismatched pairs.
roc_labels = []
roc_scores = []
for images, labels in test_ds:
    roc_labels.append(labels.numpy())
    roc_scores.append(model(images, training=False).numpy())

# One row per sample, one column per class.
# NOTE(review): assumes the model ends in a softmax so these are probabilities.
y_pred_prob = np.concatenate(roc_scores)

# Convert labels to one-hot encoding for multi-class (one-vs-rest) ROC curves
y_true_one_hot = tf.keras.utils.to_categorical(
    np.concatenate(roc_labels), num_classes=len(class_names)
)

plt.figure(figsize=(10, 7))

for i, name in enumerate(class_names):
    class_truth = y_true_one_hot[:, i]
    # A ROC curve needs both positive and negative samples for the class;
    # otherwise TPR or FPR is 0/0 and the AUC is NaN, so skip it visibly.
    if class_truth.min() == class_truth.max():
        print(f"⚠️ Skipping ROC for {name}: test set lacks positives or negatives")
        continue
    fpr, tpr, _ = roc_curve(class_truth, y_pred_prob[:, i])
    auc_score = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f'{name} (AUC = {auc_score:.2f})')

plt.plot([0, 1], [0, 1], 'k--')  # Diagonal line (random model)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
# Good curves hug the top-left corner, so put the ten long entries bottom-right.
plt.legend(loc='lower right', fontsize='small')
plt.show()

# ------------------------
# 4️⃣ Classification Report
# ------------------------
# `labels=` keeps the report aligned with `target_names` even when a class is
# absent from both y_true and y_pred; without it sklearn raises a ValueError
# because the number of observed classes no longer matches target_names.
# `zero_division=0` reports 0 for a never-predicted class instead of warning.
print("🔹 Classification Report:\n")
print(classification_report(
    y_true,
    y_pred,
    labels=np.arange(len(class_names)),
    target_names=class_names,
    zero_division=0,
))
