In [7]:
import os, sys
import tensorflow as tf
import numpy as np
from sklearn.metrics import roc_curve, precision_recall_curve, auc
import matplotlib.pyplot as plt
from tqdm import tqdm
from tensorflow.keras.models import load_model

In [2]:
# Make the sibling project folders importable from this notebook.
# The membership check keeps the cell idempotent: re-running it does not
# append the same directory to sys.path a second time.
for _module_dir in ('data_processing', 'models'):
    _module_path = os.path.abspath(os.path.join('..', _module_dir))
    if _module_path not in sys.path:
        sys.path.append(_module_path)

In [None]:
from contrastive_preprocessing import test_contrastive_dataset

### Evaluates the contrastive models on the test dataset by computing similarity scores, generating ROC and PR curves, and calculating performance metrics at a given threshold


In [4]:
def _threshold_metrics(labels, scores, threshold):
    """Return accuracy, precision, recall and F1 for ``scores >= threshold`` against 0/1 ``labels``."""
    predictions = (scores >= threshold).astype(int)

    tp = np.sum((predictions == 1) & (labels == 1))
    fp = np.sum((predictions == 1) & (labels == 0))
    tn = np.sum((predictions == 0) & (labels == 0))
    fn = np.sum((predictions == 0) & (labels == 1))

    return {
        'accuracy': (tp + tn) / len(labels),
        'precision': tp / (tp + fp) if (tp + fp) > 0 else 0,
        'recall': tp / (tp + fn) if (tp + fn) > 0 else 0,
        # 2TP / (2TP + FP + FN) equals 2PR / (P + R) whenever TP > 0, but it
        # never evaluates 0 / 0 (NaN) when there are no true positives.
        'f1': 2 * tp / (2 * tp + fp + fn) if tp > 0 else 0,
    }


def _plot_evaluation_curves(fpr, tpr, roc_auc, recall, precision, pr_auc,
                            similarity_scores, labels, threshold):
    """Draw the ROC curve, the PR curve and the per-class similarity histograms side by side."""
    fig, (roc_ax, pr_ax, hist_ax) = plt.subplots(1, 3, figsize=(15, 5))

    # ROC curve, with the chance diagonal for reference
    roc_ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
    roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    roc_ax.set_xlim([0.0, 1.0])
    roc_ax.set_ylim([0.0, 1.05])
    roc_ax.set_xlabel('False Positive Rate')
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.set_title('Test ROC Curve')
    roc_ax.legend(loc="lower right")

    # Precision-Recall curve
    pr_ax.plot(recall, precision, color='blue', lw=2, label=f'PR curve (AUC = {pr_auc:.2f})')
    pr_ax.set_xlim([0.0, 1.0])
    pr_ax.set_ylim([0.0, 1.05])
    pr_ax.set_xlabel('Recall')
    pr_ax.set_ylabel('Precision')
    pr_ax.set_title('Test Precision-Recall Curve')
    pr_ax.legend(loc="lower right")

    # Similarity distributions per class, with the decision threshold marked
    hist_ax.hist(similarity_scores[labels == 1], bins=50, alpha=0.5, label='Same Identity', density=True)
    hist_ax.hist(similarity_scores[labels == 0], bins=50, alpha=0.5, label='Different Identity', density=True)
    hist_ax.axvline(x=threshold, color='r', linestyle='--', label='Threshold')
    hist_ax.set_xlabel('Similarity Score')
    hist_ax.set_ylabel('Density')
    hist_ax.set_title('Test Similarity Distributions')
    hist_ax.legend()

    fig.tight_layout()
    plt.show()


def evaluate_on_test(
        model,
        test_dataset,
        threshold=0.75,
        num_test_steps=2000
):
    """Evaluate a contrastive (distance-output) model on a test dataset.

    For up to ``num_test_steps`` batches the model's predicted distances are
    collected, converted to similarity scores with ``1 / (1 + distance)``, and
    used to compute ROC / PR curves plus confusion-matrix metrics at
    ``threshold``. The curves and the per-class similarity histograms are
    plotted and the metrics are printed.

    Args:
        model: Object whose ``predict([anchor, comparison], verbose=0)`` returns
            one distance per pair.
        test_dataset: Object whose ``take(n)`` yields
            ``((anchor_img, comparison_img), labels)`` batches, where ``labels``
            exposes ``.numpy()``. Label 1 is treated as "same identity" and
            0 as "different identity".
        threshold: Similarity score at or above which a pair is predicted
            positive.
        num_test_steps: Maximum number of batches taken from ``test_dataset``.

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall``, ``f1``,
        ``roc_auc`` and ``pr_auc``.

    Raises:
        ValueError: If the dataset yields no samples.
    """
    print("Computing similarities from test dataset")

    distance_batches = []
    label_batches = []

    for (anchor_img, comparison_img), labels in tqdm(test_dataset.take(num_test_steps)):
        distances = model.predict([anchor_img, comparison_img], verbose=0)
        distance_batches.append(np.asarray(distances).ravel())
        # Flatten labels too: (batch, 1)-shaped labels would otherwise broadcast
        # against the 1-D predictions and corrupt the confusion-matrix counts.
        label_batches.append(np.asarray(labels.numpy()).ravel())

    if not distance_batches:
        raise ValueError("test_dataset yielded no batches; nothing to evaluate")

    all_distances = np.concatenate(distance_batches)
    all_labels = np.concatenate(label_batches)

    if all_labels.size == 0:
        raise ValueError("test_dataset yielded no samples; nothing to evaluate")

    # Convert distances to similarity scores (0-1 range, assuming non-negative distances)
    similarity_scores = 1 / (1 + all_distances)

    # Compute ROC and PR curves using similarity scores
    fpr, tpr, _ = roc_curve(all_labels, similarity_scores)
    precision, recall, _ = precision_recall_curve(all_labels, similarity_scores)
    roc_auc = auc(fpr, tpr)
    pr_auc = auc(recall, precision)

    # Metrics at the chosen similarity threshold, followed by the curve areas
    test_metrics = _threshold_metrics(all_labels, similarity_scores, threshold)
    test_metrics['roc_auc'] = roc_auc
    test_metrics['pr_auc'] = pr_auc

    _plot_evaluation_curves(
        fpr, tpr, roc_auc,
        recall, precision, pr_auc,
        similarity_scores, all_labels, threshold,
    )

    print("\nTest Results:")
    for metric, value in test_metrics.items():
        print(f"{metric}: {value:.3f}")

    return test_metrics

### Example usage

In [None]:
# Load the saved contrastive v3 Siamese network without compiling it on load,
# then compile it with default settings.
contrastive_v3_path = '../results/siamese/contrastive_v3/contrastive_v3.h5'
v3 = load_model(contrastive_v3_path, compile=False)
v3.compile()

In [None]:
# Evaluate the v3 model on the contrastive test set with the default threshold and step count.
test_metrics = evaluate_on_test(
    model=v3,
    test_dataset=test_contrastive_dataset,
)