# In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, precision_recall_curve, ConfusionMatrixDisplay, recall_score
from data_handler import load_datasets
from models.config import MODEL_PATHS
from train import img_history

def plot_confusion_matrix(y_true, y_pred, class_names):
    """
    Plot a confusion matrix using scikit-learn's ``ConfusionMatrixDisplay``.

    Args:
        y_true: Ground-truth labels as integer class indices
            (0 .. len(class_names) - 1).
        y_pred: Predicted labels as integer class indices.
        class_names: Display names, indexed by class index.
    """
    # Pin the label set to every index in ``class_names`` so the matrix keeps
    # one row/column per class even when a class never appears in y_true or
    # y_pred; otherwise the matrix shrinks and no longer lines up with
    # ``display_labels``.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.show()

def plot_roc_curve(y_true, y_pred_prob, class_names):
    """
    Plot the ROC curve for a binary classifier.

    Args:
        y_true: Ground-truth labels, where label 1 is the positive class.
        y_pred_prob: Positive-class scores. Either a 2-D array whose column 1
            holds the positive-class probability (e.g. a two-unit softmax
            output), or a 1-D / single-column array of scores (e.g. a sigmoid
            output).
        class_names: Class names indexed by label; ``class_names[1]`` names
            the positive class in the legend.
    """
    import numpy as np

    probs = np.asarray(y_pred_prob)
    # A single column (or flat vector) already holds the positive-class score;
    # with one column per class, column 1 is the positive class.
    if probs.ndim == 1 or probs.shape[1] == 1:
        scores = probs.reshape(-1)
    else:
        scores = probs[:, 1]

    fpr, tpr, _ = roc_curve(y_true, scores)
    roc_auc = auc(fpr, tpr)
    positive = class_names[1] if class_names and len(class_names) > 1 else 'positive class'

    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve for {positive} (AUC = {roc_auc:.2f})')
    # Diagonal = performance of a random classifier.
    plt.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc='lower right')
    plt.show()

def plot_precision_recall_curve(y_true, y_pred_prob):
    """
    Plot the precision-recall curve for a binary classifier.

    Args:
        y_true: Ground-truth labels, where label 1 is the positive class.
        y_pred_prob: Positive-class scores. Either a 2-D array whose column 1
            holds the positive-class probability (e.g. a two-unit softmax
            output), or a 1-D / single-column array of scores (e.g. a sigmoid
            output).
    """
    import numpy as np
    from sklearn.metrics import average_precision_score

    probs = np.asarray(y_pred_prob)
    # A single column (or flat vector) already holds the positive-class score;
    # with one column per class, column 1 is the positive class.
    if probs.ndim == 1 or probs.shape[1] == 1:
        scores = probs.reshape(-1)
    else:
        scores = probs[:, 1]

    precision, recall, _ = precision_recall_curve(y_true, scores)
    # Average precision gives the curve a single-number summary for the legend.
    avg_precision = average_precision_score(y_true, scores)

    plt.figure(figsize=(8, 6))
    plt.plot(recall, precision, color='blue', lw=2, label=f'PR curve (AP = {avg_precision:.2f})')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc='lower left')
    plt.show()

def plot_loss_curve(history):
    """
    Plot training loss per epoch, plus validation loss when it was recorded.

    Args:
        history: Object with a ``history`` dict mapping metric names to
            per-epoch value lists (e.g. the ``History`` returned by Keras
            ``Model.fit``). ``'loss'`` is required; ``'val_loss'`` is plotted
            only if present.
    """
    metrics = history.history
    train_loss = metrics['loss']

    # Start a fresh figure so the curve never lands on a figure left open elsewhere.
    plt.figure(figsize=(8, 6))
    # Number epochs from 1 to match the 'Epochs' axis label.
    plt.plot(range(1, len(train_loss) + 1), train_loss, label='Training Loss')
    # 'val_loss' is missing when training ran without validation data; skip it
    # rather than raising KeyError.
    val_loss = metrics.get('val_loss')
    if val_loss:
        plt.plot(range(1, len(val_loss) + 1), val_loss, label='Validation Loss')
    plt.title('Loss Curve')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

def plot_accuracy_curve(history):
    """
    Plot training accuracy per epoch, plus validation accuracy when recorded.

    Args:
        history: Object with a ``history`` dict mapping metric names to
            per-epoch value lists (e.g. the ``History`` returned by Keras
            ``Model.fit``). ``'accuracy'`` is required; ``'val_accuracy'`` is
            plotted only if present.
    """
    metrics = history.history
    train_acc = metrics['accuracy']

    # Start a fresh figure so the curve never lands on a figure left open elsewhere.
    plt.figure(figsize=(8, 6))
    # Number epochs from 1 to match the 'Epochs' axis label.
    plt.plot(range(1, len(train_acc) + 1), train_acc, label='Training Accuracy')
    # 'val_accuracy' is missing when training ran without validation data;
    # skip it rather than raising KeyError.
    val_acc = metrics.get('val_accuracy')
    if val_acc:
        plt.plot(range(1, len(val_acc) + 1), val_acc, label='Validation Accuracy')
    plt.title('Accuracy Curve')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.show()

def calculate_specificity(y_true, y_pred, negative_label=0):
    """
    Calculate specificity (true-negative rate) for binary classification.

    Specificity = TN / (TN + FP). A sample is negative when its true label
    equals ``negative_label``. A false positive is a negative sample that was
    predicted as any other label.

    Args:
        y_true: Ground-truth labels (anything ``np.asarray`` accepts, e.g. a
            list, NumPy array or eager tensor).
        y_pred: Predicted labels, same length as ``y_true``.
        negative_label: Label value treated as the negative class.

    Returns:
        Specificity as a float in [0, 1], or ``0`` when there are no negative
        samples (TN + FP == 0).
    """
    import numpy as np

    # Count directly rather than indexing a confusion matrix. If only one class
    # occurs in y_true and y_pred, a matrix built from the observed labels is
    # 1x1, and cm[0, 1] would raise IndexError.
    truth = np.asarray(y_true).reshape(-1)
    pred = np.asarray(y_pred).reshape(-1)
    is_negative = truth == negative_label
    tn = int(np.count_nonzero(is_negative & (pred == negative_label)))
    fp = int(np.count_nonzero(is_negative & (pred != negative_label)))
    return tn / (tn + fp) if (tn + fp) > 0 else 0

def _collect_predictions(model, dataset):
    """
    Run ``model`` over ``dataset`` once; return (probabilities, labels) aligned row-for-row.

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    prob_batches, label_batches = [], []
    # Collect predictions and labels in ONE pass. Predicting in one pass and
    # reading labels in a second pass pairs the wrong rows whenever the dataset
    # order changes between iterations. For example, ``Dataset.shuffle`` with
    # its default ``reshuffle_each_iteration=True`` does this.
    for x_batch, y_batch in dataset:
        prob_batches.append(model.predict_on_batch(x_batch))
        label_batches.append(y_batch)
    if not label_batches:
        raise ValueError("Test dataset is empty; nothing to evaluate.")

    predictions = tf.concat(prob_batches, axis=0).numpy()
    y_true = tf.concat(label_batches, axis=0).numpy()
    # One-hot labels become class indices; (n, 1)-shaped labels are flattened.
    if y_true.ndim > 1 and y_true.shape[-1] > 1:
        y_true = y_true.argmax(axis=-1)
    else:
        y_true = y_true.reshape(-1)
    return predictions, y_true


def evaluate_model(model_path, class_names, history=None):
    """
    Evaluate a saved Keras model on the test split, print metrics and plot results.

    Prints the test loss/accuracy and a classification report. For binary
    problems it also prints specificity. It then plots the confusion matrix,
    the ROC and precision-recall curves (binary problems only), and the
    training loss/accuracy curves.

    Args:
        model_path: Path passed to ``tf.keras.models.load_model``.
        class_names: Class names indexed by integer label; exactly two entries
            means binary classification.
        history: Training history (object with a ``history`` dict) to plot.
            Defaults to ``img_history`` imported from ``train``. That history
            is fixed regardless of ``model_path``, so pass the matching history
            when evaluating other models.

    Raises:
        ValueError: If the test dataset yields no batches.
    """
    if history is None:
        history = img_history

    # Only the test split is used here.
    _, _, test_ds = load_datasets()

    model = tf.keras.models.load_model(model_path)

    # Aggregate loss/accuracy do not depend on batch order.
    loss, accuracy = model.evaluate(test_ds)
    print(f"Model: {model_path}")
    print(f"Test Loss: {loss}")
    print(f"Test Accuracy: {accuracy}")

    predictions, y_true = _collect_predictions(model, test_ds)
    y_pred = predictions.argmax(axis=1)

    # Pass the full label set so target_names still lines up when some class
    # is absent from the test predictions/labels.
    all_labels = list(range(len(class_names)))
    print(classification_report(y_true, y_pred, labels=all_labels, target_names=class_names))

    is_binary = len(class_names) == 2
    if is_binary:
        specificity = calculate_specificity(y_true, y_pred)
        print(f"Specificity: {specificity}")

    plot_confusion_matrix(y_true, y_pred, class_names)

    # The ROC and PR helpers score column 1 as the positive class, which only
    # makes sense for binary problems; precision_recall_curve also rejects
    # multiclass targets.
    if is_binary:
        plot_roc_curve(y_true, predictions, class_names)
        plot_precision_recall_curve(y_true, predictions)

    plot_loss_curve(history)
    plot_accuracy_curve(history)
    
    
    
if __name__ == "__main__":
    CLASS_NAMES = ['real', 'fake']  # TODO: replace with the dataset's actual class names
    # (display name, MODEL_PATHS key) pairs to evaluate, in order.
    # Uncomment entries to evaluate additional models.
    models_to_evaluate = [
        ('Mesonet', 'mesonet'),
        # ('ResNet50', 'resnet50'),
        # ('Xception', 'xception'),
    ]
    for display_name, path_key in models_to_evaluate:
        print(f"Evaluating {display_name} model...")
        evaluate_model(MODEL_PATHS[path_key], CLASS_NAMES)
