In [None]:
def evaluate_model(model, X_test, y_test, model_name, plot_curves=True,
                   threshold=0.5, batch_size=32):
    """Score a model's binary predictions against ``y_test`` and report metrics.

    Runs ``model.predict`` on ``X_test``, binarizes the returned scores at
    ``threshold``, flattens predictions and ground truth to 1-D, and prints
    precision, recall, F1, IoU (Jaccard), Dice, ROC AUC and PR AUC.

    Args:
        model: Object exposing ``predict(X, batch_size=..., verbose=...)`` that
            returns a NumPy array of scores (presumably a Keras model with a
            sigmoid output -- confirm against the caller).
        X_test: Inputs forwarded unchanged to ``model.predict``.
        y_test: Ground-truth NumPy array with the same number of elements as
            the prediction; it is flattened and cast to ``int32`` for the
            scikit-learn metrics.
        model_name: Label used in printed output and plot titles.
        plot_curves: When true, draw the ROC and precision-recall curves via
            ``plot_single_model_curves``.
        threshold: Scores strictly greater than this value count as positive.
            Only the thresholded metrics (precision, recall, F1, IoU, Dice)
            depend on it; the ROC/PR curves use the raw scores.
        batch_size: Batch size forwarded to ``model.predict``.

    Returns:
        dict with the scalar metrics (``precision``, ``recall``, ``f1_score``,
        ``iou``, ``dice``, ``roc_auc``, ``pr_auc``), the raw and binarized
        predictions (``y_pred``, ``y_pred_binary``), the curve arrays (``fpr``,
        ``tpr``, ``precision_curve``, ``recall_curve``) and ``model_name``.
    """
    print(f"\nEvaluating {model_name}...")
    print("-" * 40)

    # Predict
    y_pred = model.predict(X_test, batch_size=batch_size, verbose=1)
    y_pred_binary = (y_pred > threshold).astype(np.float32)

    # Flatten for sklearn metrics
    y_true_flat = y_test.reshape(-1).astype(np.int32)
    y_pred_flat = y_pred_binary.reshape(-1).astype(np.int32)
    y_pred_prob_flat = y_pred.reshape(-1)

    # Thresholded metrics; zero_division=0 yields 0 instead of a warning
    # when a denominator is empty (e.g. no positive predictions).
    precision = precision_score(y_true_flat, y_pred_flat, zero_division=0)
    recall = recall_score(y_true_flat, y_pred_flat, zero_division=0)
    f1 = f1_score(y_true_flat, y_pred_flat, zero_division=0)
    iou = jaccard_score(y_true_flat, y_pred_flat, zero_division=0)

    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")
    print(f"IoU: {iou:.4f}")

    # Dice is computed on the unflattened arrays; convert the result to a
    # plain float once so printing and the returned value use the same object.
    dice = float(
        dice_coefficient(
            tf.convert_to_tensor(y_test),
            tf.convert_to_tensor(y_pred_binary),
        )
    )
    print(f"Dice Coefficient: {dice:.4f}")

    # ROC Curve (threshold-free: uses the raw scores)
    fpr, tpr, _ = roc_curve(y_true_flat, y_pred_prob_flat)
    roc_auc = auc(fpr, tpr)

    # Precision-Recall Curve; area is the trapezoidal integral over recall
    precision_curve, recall_curve, _ = precision_recall_curve(y_true_flat, y_pred_prob_flat)
    pr_auc = auc(recall_curve, precision_curve)

    print(f"ROC AUC: {roc_auc:.4f}")
    print(f"PR AUC: {pr_auc:.4f}")

    # Plot curves if requested
    if plot_curves:
        plot_single_model_curves(
            y_true_flat, y_pred_prob_flat, model_name,
            fpr, tpr, roc_auc, recall_curve, precision_curve, pr_auc,
        )

    return {
        'model_name': model_name,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'iou': iou,
        'dice': dice,
        'roc_auc': roc_auc,
        'pr_auc': pr_auc,
        'y_pred': y_pred,
        'y_pred_binary': y_pred_binary,
        'fpr': fpr,
        'tpr': tpr,
        'precision_curve': precision_curve,
        'recall_curve': recall_curve
    }

def plot_single_model_curves(y_true, y_pred_prob, model_name, fpr, tpr, roc_auc, recall_curve, precision_curve, pr_auc):
    """Show one model's ROC and precision-recall curves side by side.

    The curves are drawn from the precomputed ``fpr``/``tpr`` and
    ``recall_curve``/``precision_curve`` arrays; ``roc_auc`` and ``pr_auc``
    appear in the legends. ``y_true`` and ``y_pred_prob`` are accepted for
    interface compatibility but are not read here.
    """
    _, (roc_ax, pr_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: ROC curve with the chance diagonal as reference
    roc_ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
    roc_ax.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')

    # Right panel: precision as a function of recall
    pr_ax.plot(recall_curve, precision_curve, color='red', lw=2, label=f'PR curve (AUC = {pr_auc:.3f})')

    # Both panels share limits and grid; only labels, title and legend differ
    panel_settings = (
        (roc_ax, 'False Positive Rate', 'True Positive Rate',
         f'{model_name} - ROC Curve', "lower right"),
        (pr_ax, 'Recall', 'Precision',
         f'{model_name} - Precision-Recall Curve', "lower left"),
    )
    for ax, x_label, y_label, title, legend_loc in panel_settings:
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend(loc=legend_loc)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

def plot_roc_curves(all_metrics, y_test):
    """Plot ROC curves for multiple models on one figure.

    Args:
        all_metrics: Iterable of dicts, each providing ``'model_name'`` (used
            as the legend label) and ``'y_pred'`` (array of predicted scores
            with as many elements as ``y_test``).
        y_test: Ground-truth NumPy array; flattened and cast to ``int32``
            before being passed to ``roc_curve``.
    """
    plt.figure(figsize=(10, 6))

    # The ground truth is identical for every model, so flatten and cast it
    # once instead of copying the whole array on each iteration.
    y_true_flat = y_test.reshape(-1).astype(np.int32)

    for metrics in all_metrics:
        y_pred_prob_flat = metrics['y_pred'].reshape(-1)

        fpr, tpr, _ = roc_curve(y_true_flat, y_pred_prob_flat)
        roc_auc = auc(fpr, tpr)

        plt.plot(fpr, tpr, label=f"{metrics['model_name']} (AUC = {roc_auc:.3f})", linewidth=2)

    # Chance-level diagonal for reference
    plt.plot([0, 1], [0, 1], 'k--', alpha=0.5, linewidth=1)
    plt.xlabel('False Positive Rate', fontsize=12)
    plt.ylabel('True Positive Rate', fontsize=12)
    plt.title('ROC Curves - Model Comparison', fontsize=14)
    plt.legend(loc='lower right', fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.show()

def plot_pr_curves(all_metrics, y_test):
    """Plot Precision-Recall curves for multiple models on one figure.

    Args:
        all_metrics: Iterable of dicts, each providing ``'model_name'`` (used
            as the legend label) and ``'y_pred'`` (array of predicted scores
            with as many elements as ``y_test``).
        y_test: Ground-truth NumPy array; flattened and cast to ``int32``
            before being passed to ``precision_recall_curve``.
    """
    plt.figure(figsize=(10, 6))

    # The ground truth is identical for every model, so flatten and cast it
    # once instead of copying the whole array on each iteration.
    y_true_flat = y_test.reshape(-1).astype(np.int32)

    for metrics in all_metrics:
        y_pred_prob_flat = metrics['y_pred'].reshape(-1)

        precision, recall, _ = precision_recall_curve(y_true_flat, y_pred_prob_flat)
        # Area under the PR curve, integrated over recall (trapezoidal rule)
        pr_auc = auc(recall, precision)

        plt.plot(recall, precision, label=f"{metrics['model_name']} (AUC = {pr_auc:.3f})", linewidth=2)

    plt.xlabel('Recall', fontsize=12)
    plt.ylabel('Precision', fontsize=12)
    plt.title('Precision-Recall Curves - Model Comparison', fontsize=14)
    plt.legend(loc='lower left', fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.show()