In [None]:
# Import required libraries
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, precision_recall_curve, classification_report, confusion_matrix
import numpy as np

In [None]:
# Evaluation settings: dataset location, model input resolution and batch size.
dataset_dir = 'trashnet/data/dataset-resized'
img_height = 224
img_width = 224
batch_size = 32

# Evaluation images are only scaled by 1/255; no augmentation is configured.
# NOTE(review): validation_split presumably has to match the value used when
# the model was trained so the same images are held out -- confirm.
validation_datagen = ImageDataGenerator(validation_split=0.3, rescale=1.0 / 255)

# Options for streaming the 'validation' subset from the dataset directory.
# shuffle=False keeps the sample order fixed so that model predictions can be
# compared position-by-position with validation_generator.classes.
flow_options = dict(
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation',
    shuffle=False,
)
validation_generator = validation_datagen.flow_from_directory(dataset_dir, **flow_options)

In [None]:
# Restore the trained classifier from disk.
model = tf.keras.models.load_model('garbage_classifier.h5')

# Score the model on the validation subset. The two-way unpacking below
# requires evaluate() to return exactly two values (loss, then one metric).
eval_results = model.evaluate(validation_generator)
validation_loss, validation_accuracy = eval_results
print(f"Validation Loss: {validation_loss:.4f}")
print(f"Validation Accuracy: {validation_accuracy:.4f}")

# Rewind the generator, then collect the per-class scores for every sample
# and reduce them to the highest-scoring class index.
validation_generator.reset()
y_pred = model.predict(validation_generator)
y_pred_classes = np.argmax(y_pred, axis=1)

# Ground-truth class indices and the class names (keys of class_indices, in
# the mapping's iteration order).
y_true = validation_generator.classes
class_labels = list(validation_generator.class_indices)
n_classes = len(class_labels)

# One-hot version of y_true, used for the per-class (one-vs-rest) curves.
y_true_one_hot = tf.keras.utils.to_categorical(y_true, num_classes=n_classes)

In [None]:
# Sanity-check the evaluation outputs before computing any metrics.
#
# This cell used to be a verbatim copy of the previous one: it reloaded the
# model and re-ran evaluate()/predict() only to rebuild the same variables.
# It now reuses those variables and verifies the invariants the metric cells
# rely on, so a mismatch fails here with a clear message instead of surfacing
# later as an indexing error or a misleading plot.

# One score row per validation sample, one score column per class.
assert y_pred.shape == (len(y_true), n_classes), (
    f"Prediction shape {y_pred.shape} does not match "
    f"({len(y_true)} samples, {n_classes} classes)"
)

# Predicted class indices must align one-to-one with the true labels.
assert len(y_pred_classes) == len(y_true), (
    "Number of predicted labels differs from number of true labels"
)

# The one-hot ground truth must have the same layout as the score matrix.
assert y_true_one_hot.shape == y_pred.shape, (
    f"One-hot labels {y_true_one_hot.shape} and predictions {y_pred.shape} differ in shape"
)

In [None]:
# Evaluate against the full set of class indices. Without an explicit label
# list, scikit-learn infers the labels from the values present in
# y_true / y_pred_classes; a class missing from both would then shrink the
# matrix and make target_names=class_labels mismatch the inferred labels.
label_ids = np.arange(len(class_labels))

# Confusion Matrix (rows: true class, columns: predicted class)
cm = confusion_matrix(y_true, y_pred_classes, labels=label_ids)
print("Confusion Matrix:\n", cm)

# Classification Report (includes precision, recall, f1-score)
print(
    "Classification Report:\n",
    classification_report(
        y_true,
        y_pred_classes,
        labels=label_ids,
        target_names=class_labels,
    ),
)

In [None]:
# ROC curve and ROC-AUC for each class (one-vs-rest: column i of the one-hot
# labels against column i of the predicted scores).
fpr = {}
tpr = {}
roc_auc = {}
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_true_one_hot[:, i], y_pred[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Macro-average ROC: interpolate every class curve onto the union of all
# false-positive-rate points, then average the true-positive rates.
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
mean_tpr /= n_classes
macro_auc = auc(all_fpr, mean_tpr)

# One colour per class. The base palette is cycled so that every class gets a
# curve; previously, zipping the classes with a fixed 6-colour list silently
# dropped any class beyond the sixth from the plot.
base_colors = ['blue', 'red', 'green', 'yellow', 'purple', 'orange']
colors = [base_colors[i % len(base_colors)] for i in range(n_classes)]

# Plot ROC Curves
plt.figure(figsize=(10, 8))
for i, color in enumerate(colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=2,
             label=f'ROC curve of class {class_labels[i]} (AUC = {roc_auc[i]:.2f})')
# Diagonal reference line (TPR == FPR).
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.plot(all_fpr, mean_tpr, color='black', lw=2, linestyle='-', label=f'Macro-average ROC curve (AUC = {macro_auc:.2f})')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.grid()
plt.show()

In [None]:
# One-vs-rest precision-recall curve per class, all drawn on a single axes.
# `colors` is the per-class palette defined in the ROC cell.
fig, ax = plt.subplots(figsize=(10, 8))
for i, color in zip(range(n_classes), colors):
    class_truth = y_true_one_hot[:, i]
    class_scores = y_pred[:, i]
    precision, recall, _ = precision_recall_curve(class_truth, class_scores)
    ax.plot(
        recall,
        precision,
        color=color,
        lw=2,
        label=f'Precision-Recall curve of class {class_labels[i]}',
    )

# Axis limits, labels and legend.
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_title('Precision-Recall Curve')
ax.legend(loc="lower left")
ax.grid()
plt.show()

In [None]:
# Compute per-class accuracy (the diagonal of the confusion matrix divided by
# the number of true samples of each class, i.e. the row sums).
# Passing the full label set keeps one row per entry of class_labels even if
# a class never occurs in y_true or y_pred_classes.
label_ids = np.arange(len(class_labels))
cm = confusion_matrix(y_true, y_pred_classes, labels=label_ids)
support = cm.sum(axis=1)

# Guarded division: a class with no true samples gets NaN instead of raising
# a divide-by-zero warning.
per_class_accuracy = np.divide(
    cm.diagonal(),
    support,
    out=np.full(len(class_labels), np.nan),
    where=support > 0,
)
print("Per-Class Accuracy:", dict(zip(class_labels, per_class_accuracy)))

# Plot bar graph (classes without samples are drawn as empty bars).
plt.figure(figsize=(10, 6))
plt.bar(class_labels, np.nan_to_num(per_class_accuracy))
plt.title('Per-Class Validation Accuracy')
plt.xlabel('Class')
plt.ylabel('Accuracy')
plt.ylim(0, 1.0)  # Accuracy between 0 and 1
plt.xticks(rotation=45)
# Annotate each bar with its value, or "n/a" when the class has no samples.
for i, v in enumerate(per_class_accuracy):
    if np.isnan(v):
        plt.text(i, 0, 'n/a', ha='center', va='bottom')
    else:
        plt.text(i, v, f'{v:.2f}', ha='center', va='bottom')
plt.grid(axis='y')
plt.show()