# Nuclei Segmentation Model: Results Visualization
Compare model predictions against keypoint annotations.
Visualize the original image, cancer mask, ground-truth keypoints, and predicted masks.
Show the best-, average-, and worst-performing examples, ranked by F1 score.

In [None]:
import os
import sys
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import json
import random
from pathlib import Path
from collections import defaultdict

import cv2
import matplotlib.pyplot as plt
from scipy.ndimage import label as scipy_label

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['SM_FRAMEWORK'] = 'tf.keras'

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, UpSampling2D, Add, Multiply, concatenate, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.applications import EfficientNetB4
from tensorflow.keras.preprocessing.image import img_to_array, load_img
import segmentation_models as sm

# Silence TensorFlow's Python-side logger (C++-side logging is configured through
# TF_CPP_MIN_LOG_LEVEL before TensorFlow is imported).
tf.get_logger().setLevel('ERROR')

# Seed NumPy, TensorFlow and Python's RNGs so runs are reproducible.
np.random.seed(42)
tf.random.set_seed(42)
random.seed(42)

# NOTE(review): every path below is an empty placeholder. Path('') resolves to the
# current directory, so MODEL_PATH.exists() reports True for that directory rather
# than for a weights file. Fill these in before running.
BASE_DIR = Path('')
DATA_ROOT = BASE_DIR / ""  # directory with one sub-directory per slide (scanned by load_slide_data)

MODEL_PATH = BASE_DIR / ''  # nuclei segmentation model weights
MODEL_TYPE = 'unet'  # a name containing 'attention' selects build_attention_unet; otherwise sm.Unet is used
BACKBONE = 'seresnet50'  # a name containing 'efficientnet' enables ImageNet mean/std input normalization

TISSUE_MODEL_SPEC_PATH = BASE_DIR / ''  # tissue specialist model weights (used to find the cancer region)
TISSUE_BACKBONE = 'seresnet50'

IMG_HEIGHT, IMG_WIDTH = 512, 512  # model input size
NUM_CLASSES = 4  # nuclei model output classes, in CLASS_NAMES order
TISSUE_NUM_CLASSES = 3  # tissue model output classes

CLASS_NAMES = ['Background', 'Negative', 'Positive', 'Boundaries']
# RGB colour -> nuclei class index; inverted below so masks can be rendered as colours.
CLASS_MAPPING = {
    (255, 255, 255): 0,  # Background (white)
    (112, 112, 225): 1,  # Negative
    (250, 62, 62): 2,    # Positive
    (0, 0, 0): 3,        # Boundaries (black)
}
INVERSE_CLASS_MAPPING = {v: k for k, v in CLASS_MAPPING.items()}

# Keypoint label -> class index the nuclei mask should have at that keypoint.
KEYPOINT_TO_CLASS = {
    'negative': 1, 
    'positive': 2, 
}

# COCO `category_id` in annotations.json -> keypoint label; other ids are ignored.
JSON_CATEGORY_ID_TO_NAME = {
    0: 'negative',
    1: 'positive'
}

# Index of the cancer class in the tissue model's argmax output.
CANCER_CLASS_ID = 1

print(" Imports and configuration loaded")
print(f"  Base directory: {BASE_DIR}")
print(f"  Data root: {DATA_ROOT}")
print(f"  Model path: {MODEL_PATH}")
print(f"  Model exists: {MODEL_PATH.exists()}")

In [None]:
def decode_mask_to_colors(mask):
    """Render a class-index mask as an RGB uint8 image.

    Each class index found in INVERSE_CLASS_MAPPING is painted with its colour.
    Any other index stays (0, 0, 0).

    Args:
        mask: 2-D array of class indices.

    Returns:
        np.ndarray of shape (H, W, 3), dtype uint8.
    """
    height, width = mask.shape[0], mask.shape[1]
    rgb = np.zeros((height, width, 3), dtype=np.uint8)
    for idx in INVERSE_CLASS_MAPPING:
        rgb[mask == idx] = INVERSE_CLASS_MAPPING[idx]
    return rgb

def find_connected_components(mask, class_id):
    """List the connected regions of ``mask`` that equal ``class_id``.

    Connectivity is scipy.ndimage.label's default structuring element.

    Args:
        mask: 2-D array of class indices.
        class_id: class whose pixels are grouped.

    Returns:
        list of dicts, one per region in label order, with:
        'id' (1-based label), 'mask' (bool array, same shape as ``mask``),
        'area' (pixel count) and 'centroid' (np.array([x, y]), i.e. column mean
        then row mean).
    """
    labeled, count = scipy_label((mask == class_id).astype(np.uint8))
    regions = []
    for label_id in range(1, count + 1):
        region = labeled == label_id
        rows, cols = np.nonzero(region)
        regions.append({
            'id': label_id,
            'mask': region,
            'area': region.sum(),
            # (x, y) order so it lines up with keypoint (local_x, local_y).
            'centroid': np.array([cols.mean(), rows.mean()]),
        })
    return regions

def extract_cancer_region_ensemble(image, model_spec):
    """Segment the cancer region with the tissue model and blank everything else.

    Args:
        image: a single image array; callers in this file pass RGB scaled to [0, 1].
        model_spec: model whose ``predict`` takes a batch and returns per-class scores.

    Returns:
        (cancer_mask, cancer_image):
        - ``cancer_mask`` is a uint8 {0, 1} map of pixels whose argmax class is
          CANCER_CLASS_ID.
        - ``cancer_image`` is a copy of ``image`` with every non-cancer pixel set to 1.0.
    """
    batch = image[np.newaxis, ...]
    class_scores = model_spec.predict(batch, verbose=0)[0]
    is_cancer = np.argmax(class_scores, axis=-1) == CANCER_CLASS_ID
    cancer_mask = is_cancer.astype(np.uint8)
    cancer_image = image.copy()
    # 1.0 is white for [0, 1]-scaled images, so non-cancer tissue looks like background.
    cancer_image[~is_cancer] = 1.0
    return cancer_mask, cancer_image

def filter_keypoints_by_cancer_mask(keypoints, cancer_mask):
    """Split keypoints into those inside the cancer mask and the rest.

    A keypoint is kept when its rounded (local_x, local_y) lies within the mask
    bounds and the mask value there is 1.

    Args:
        keypoints: iterable of dicts with 'local_x' and 'local_y'.
        cancer_mask: 2-D array indexed as [row, col].

    Returns:
        (kept, excluded): two lists preserving the input order.
    """
    n_rows, n_cols = cancer_mask.shape[0], cancer_mask.shape[1]
    kept, excluded = [], []
    for kp in keypoints:
        col = int(round(kp['local_x']))
        row = int(round(kp['local_y']))
        # Bounds are checked first so negative indices never wrap around.
        inside = 0 <= row < n_rows and 0 <= col < n_cols and cancer_mask[row, col] == 1
        target = kept if inside else excluded
        target.append(kp)
    return kept, excluded

def remove_boundaries_from_nuclei(nuclei_mask):
    """Return a copy of ``nuclei_mask`` with boundary pixels (class 3) set to background (0).

    The input array is not modified.
    """
    cleaned = nuclei_mask.copy()
    np.putmask(cleaned, cleaned == 3, 0)
    return cleaned

def remove_small_objects_from_mask(nuclei_mask, min_size=50):
    """Remove small connected nuclei (classes 1 and 2) from a class-index mask.

    For each nuclei class separately, connected components with fewer than
    ``min_size`` pixels are set to background (0). Other classes are left untouched.

    Args:
        nuclei_mask: 2-D array of class indices; not modified.
        min_size: minimum component size in pixels to keep. Components of
            exactly ``min_size`` pixels are kept.

    Returns:
        A filtered copy of ``nuclei_mask`` with the same dtype.
    """
    processed_mask = nuclei_mask.copy()

    for class_id in (1, 2):
        labeled_mask, num_objects = scipy_label(processed_mask == class_id)
        if num_objects == 0:
            continue
        # One pass over the labels gives every component's size. This avoids a full-mask
        # comparison per component (previously O(components * pixels)).
        sizes = np.bincount(labeled_mask.ravel(), minlength=num_objects + 1)
        too_small = sizes < min_size
        too_small[0] = False  # label 0 is "not this class", never remove it
        processed_mask[too_small[labeled_mask]] = 0

    return processed_mask

def apply_post_processing(pred_mask, min_object_size=50):
    """Clean a raw predicted class mask.

    Steps: boundary pixels (class 3) become background, then negative/positive
    nuclei smaller than ``min_object_size`` pixels are removed.

    Args:
        pred_mask: 2-D array of predicted class indices; not modified.
        min_object_size: minimum component size in pixels to keep.

    Returns:
        The post-processed mask.
    """
    return remove_small_objects_from_mask(
        remove_boundaries_from_nuclei(pred_mask),
        min_size=min_object_size,
    )

def load_slide_data(data_root):
    """Collect annotated patches from every slide directory under ``data_root``.

    Each immediate sub-directory of ``data_root`` is treated as one slide. A slide
    is used when it has a ``patches_512_30`` directory and a COCO-style
    ``annotations.json``, either in the slide directory or inside ``patches_512_30``.

    Skipping rules:
    - Annotations with an unknown ``category_id`` or fewer than two keypoint values
      are skipped.
    - Files that are not valid UTF-8 JSON are reported and skipped.

    Args:
        data_root: path (str or Path) of the directory containing slide folders.

    Returns:
        list[dict]: one entry per patch that exists on disk and has at least one
        keypoint, shaped as
        ``{'slide_dir', 'patch_path', 'patch_info': {'filename', 'keypoints'}}``.
        Each keypoint is ``{'label', 'local_x', 'local_y'}``.
    """
    data_root = Path(data_root)
    slide_dirs = [d for d in data_root.iterdir() if d.is_dir()]

    all_data = []
    print(f"Scanning {len(slide_dirs)} slide directories...")

    for slide_dir in sorted(slide_dirs):
        annotations_file = slide_dir / "annotations.json"
        patches_dir = slide_dir / "patches_512_30"

        if not annotations_file.exists():
            # Fall back to an annotations file stored alongside the patches.
            if (patches_dir / "annotations.json").exists():
                annotations_file = patches_dir / "annotations.json"
            else:
                continue

        if not patches_dir.exists():
            continue

        try:
            # JSON is UTF-8 by specification; do not depend on the platform default.
            with open(annotations_file, 'r', encoding='utf-8') as f:
                coco_data = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError) as exc:
            # Keep scanning the other slides, but do not hide a corrupt file.
            print(f"  Skipping {annotations_file}: unreadable JSON ({exc})")
            continue

        img_id_to_kps = defaultdict(list)
        for ann in coco_data.get('annotations', []):
            label_name = JSON_CATEGORY_ID_TO_NAME.get(ann.get('category_id'))
            coords = ann.get('keypoints') or []
            # Only the first (x, y) pair is used. Skip unknown categories and malformed
            # annotations instead of aborting the whole scan with KeyError/IndexError.
            if label_name is None or len(coords) < 2:
                continue
            img_id_to_kps[ann['image_id']].append({
                'label': label_name,
                'local_x': coords[0],
                'local_y': coords[1]
            })

        for img_entry in coco_data.get('images', []):
            img_id = img_entry['id']
            file_name = Path(img_entry['file_name']).name
            patch_path = patches_dir / file_name
            keypoints = img_id_to_kps.get(img_id, [])

            if patch_path.exists() and len(keypoints) > 0:
                all_data.append({
                    'slide_dir': slide_dir,
                    'patch_path': patch_path,
                    'patch_info': {
                        'filename': file_name,
                        'keypoints': keypoints
                    }
                })

    print(f"Found {len(all_data)} valid patches.")
    return all_data

print(" Helper functions defined")

In [None]:
# Attention U-Net architecture

def attention_gate(gating_signal, skip_connection, inter_channels):
    """Additive attention gate that re-weights ``skip_connection`` with a learned map.

    Steps:
    1. Project both inputs to ``inter_channels`` with 1x1 Conv2D + BatchNorm.
    2. Sum them and apply ReLU.
    3. Reduce to a single-channel sigmoid map with 1x1 Conv2D + BatchNorm.
    4. Multiply that map with ``skip_connection``.

    Args:
        gating_signal: decoder feature tensor. No resampling is done here, so the
            ``Add`` requires its spatial size to match ``skip_connection``.
        skip_connection: encoder feature tensor to be gated.
        inter_channels: number of filters in the intermediate 1x1 projections.

    Returns:
        The gated skip-connection tensor.
    """
    # NOTE: layer creation order affects Keras' auto-generated layer names and the
    # weight ordering; keep it unchanged so previously saved weights still load.
    theta_g = Conv2D(inter_channels, kernel_size=1, strides=1, padding="same")(gating_signal)
    theta_g = BatchNormalization()(theta_g)
    phi_x = Conv2D(inter_channels, kernel_size=1, strides=1, padding="same")(skip_connection)
    phi_x = BatchNormalization()(phi_x)
    add_xg = Add()([theta_g, phi_x])
    act_xg = Activation("relu")(add_xg)
    # Single-channel attention coefficients in (0, 1).
    psi = Conv2D(1, kernel_size=1, strides=1, padding="same")(act_xg)
    psi = BatchNormalization()(psi)
    psi = Activation("sigmoid")(psi)
    # The 1-channel map is broadcast across the skip connection's channels.
    return Multiply()([skip_connection, psi])

def conv_block(x, filters, kernel_size=3):
    """Apply two (Conv2D -> BatchNorm -> ReLU) stages with ``same`` padding.

    Args:
        x: input tensor.
        filters: number of filters in both convolutions.
        kernel_size: convolution kernel size.

    Returns:
        Output tensor with ``filters`` channels and the same spatial size as ``x``.
    """
    # NOTE: keep layer creation order stable (it determines auto-generated names/weight order).
    x = Conv2D(filters, kernel_size, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = Conv2D(filters, kernel_size, padding="same")(x)
    x = BatchNormalization()(x)
    return Activation("relu")(x)

def decoder_block(x, skip, filters, use_attention=True, dropout_rate=0.1):
    """One U-Net decoder stage: upsample, fuse with the (optionally gated) skip, convolve.

    Args:
        x: lower-resolution decoder tensor; upsampled 2x here.
        skip: encoder tensor. After upsampling, its spatial size must equal ``x``'s
            for the concatenation to succeed.
        filters: channel count for the post-upsampling conv and the conv_block.
        use_attention: gate ``skip`` with attention_gate (using the upsampled ``x``)
            before concatenation.
        dropout_rate: dropout applied after concatenation; 0 disables it.

    Returns:
        Tensor with ``filters`` channels at twice the input resolution of ``x``.
    """
    # NOTE: keep layer creation order stable (it determines auto-generated names/weight order).
    x = UpSampling2D(size=(2, 2))(x)
    # 2x2 conv with no activation, used to mix features after upsampling.
    x = Conv2D(filters, kernel_size=2, padding="same")(x)
    if use_attention:
        skip = attention_gate(gating_signal=x, skip_connection=skip, inter_channels=filters // 2)
    x = concatenate([x, skip], axis=-1)
    if dropout_rate > 0:
        x = Dropout(dropout_rate)(x)
    return conv_block(x, filters)

def build_attention_unet(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3), num_classes=NUM_CLASSES, backbone="efficientnetb4"):
    """Build an Attention U-Net with an EfficientNetB4 encoder.

    The encoder is always EfficientNetB4; ``backbone`` is only embedded in the
    model name. The ImageNet "notop" weights are downloaded (network access) and
    loaded into the encoder on a best-effort basis. If that fails, the encoder
    keeps its random initialization and a warning is printed.

    Args:
        input_shape: (height, width, channels) of the input. The decoder applies
            five 2x upsamplings to the encoder output, so height and width
            presumably need to be divisible by 32.
        num_classes: number of softmax output channels.
        backbone: label used in the model name.

    Returns:
        keras.Model mapping images of ``input_shape`` to per-pixel class probabilities.
    """
    # NOTE: keep layer creation order unchanged; it determines auto-generated layer
    # names and weight ordering used when loading saved weights.
    keras.backend.set_image_data_format("channels_last")
    # Build the encoder for the requested shape instead of a hard-coded 512x512x3,
    # so it matches the `inputs` tensor it is later called on.
    input_layer = Input(shape=input_shape, name="input_layer")
    base_model = EfficientNetB4(include_top=False, weights=None, input_tensor=input_layer)
    try:
        weights_path = tf.keras.utils.get_file("efficientnetb4_notop.h5", "https://storage.googleapis.com/keras-applications/efficientnetb4_notop.h5", cache_subdir="models")
        base_model.load_weights(weights_path, skip_mismatch=True, by_name=True)
    except Exception as exc:
        # Best effort (e.g. offline): fine-tuned weights are normally loaded on top
        # afterwards. Report the failure instead of silently swallowing it; a bare
        # `except:` would also have swallowed KeyboardInterrupt.
        print(f"  Warning: EfficientNetB4 ImageNet weights not loaded ({exc}); encoder keeps random init")
    # Encoder feature maps used as skip connections, shallowest first. The decoder
    # concatenates them after successive 2x upsamplings of the bottleneck, so their
    # resolutions must line up with those stages.
    skip_names = ["block2a_expand_activation", "block3a_expand_activation", "block4a_expand_activation", "block6a_expand_activation"]
    skip_connections = [base_model.get_layer(name).output for name in skip_names]
    encoder_model = Model(inputs=base_model.input, outputs=skip_connections + [base_model.output])
    inputs = Input(shape=input_shape, name="input_layer")
    # training=False keeps the encoder's BatchNorm layers in inference mode.
    all_outputs = encoder_model(inputs, training=False)
    skip_connections = all_outputs[:-1]
    bottleneck = all_outputs[-1]
    skip4, skip3, skip2, skip1 = skip_connections[3], skip_connections[2], skip_connections[1], skip_connections[0]
    dec4 = decoder_block(bottleneck, skip4, 512, use_attention=True)
    dec3 = decoder_block(dec4, skip3, 256, use_attention=True)
    dec2 = decoder_block(dec3, skip2, 128, use_attention=True)
    dec1 = decoder_block(dec2, skip1, 64, use_attention=True)
    final_up = UpSampling2D(size=(2, 2))(dec1)
    final_conv = Conv2D(64, kernel_size=3, padding="same")(final_up)
    final_conv = BatchNormalization()(final_conv)
    final_conv = Activation("relu")(final_conv)
    outputs = Conv2D(num_classes, kernel_size=1, padding="same", activation="softmax")(final_conv)
    return Model(inputs=inputs, outputs=outputs, name=f"attention_unet_{backbone}")

print(" Attention U-Net architecture defined")

In [None]:
def build_and_load_nuclei_model():
    """Build the nuclei segmentation model selected by MODEL_TYPE/BACKBONE and load MODEL_PATH.

    The architecture depends on MODEL_TYPE:
    - a name containing 'attention' uses build_attention_unet;
    - otherwise segmentation_models' Unet is built with encoder_weights=None.

    Weights are loaded strictly first. If that fails, loading is retried by layer
    name with skip_mismatch=True; any layer that still does not match keeps its
    freshly initialized weights.

    Returns:
        The model with weights loaded.
    """
    print(f"Initializing Nuclei Model: {MODEL_TYPE} with {BACKBONE}...")

    if 'attention' in MODEL_TYPE.lower():
        model = build_attention_unet(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3), num_classes=NUM_CLASSES, backbone=BACKBONE)
    else:
        model = sm.Unet(backbone_name=BACKBONE, input_shape=(IMG_HEIGHT, IMG_WIDTH, 3), classes=NUM_CLASSES, activation='softmax', encoder_weights=None)

    print(f"Loading weights from: {MODEL_PATH}")
    try:
        model.load_weights(str(MODEL_PATH))
        print(" Weights loaded successfully")
    except Exception as e:
        print(f"Standard load failed ({e}). Trying with skip_mismatch=True...")
        # tf.keras only honours skip_mismatch together with by_name=True (HDF5);
        # without it this call raises ValueError instead of falling back.
        model.load_weights(str(MODEL_PATH), by_name=True, skip_mismatch=True)
        print(" Weights loaded with skip_mismatch")

    return model

def load_tissue_model():
    """Build the tissue specialist Unet and load TISSUE_MODEL_SPEC_PATH into it.

    Returns:
        A segmentation_models Unet with TISSUE_BACKBONE, TISSUE_NUM_CLASSES softmax
        outputs and (IMG_HEIGHT, IMG_WIDTH, 3) input.
    """
    print(f"Loading Tissue Specialist Model (Backbone: {TISSUE_BACKBONE})...")
    tissue_net = sm.Unet(
        backbone_name=TISSUE_BACKBONE,
        input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
        classes=TISSUE_NUM_CLASSES,
        activation='softmax',
        encoder_weights=None,
    )
    tissue_net.load_weights(str(TISSUE_MODEL_SPEC_PATH))
    print(" Tissue model loaded")
    return tissue_net

# Build both networks once; later cells use the `nuclei_model` and `tissue_model` globals.
# NOTE: keep this order — Keras auto-generates some layer names from a global counter,
# so building the models in a different order can change those names.
print("Loading models...")
nuclei_model = build_and_load_nuclei_model()
tissue_model = load_tissue_model()
print(" All models loaded successfully")

In [None]:
from sklearn.metrics import f1_score

def check_component_has_keypoint(component, keypoints, expected_class, proximity_threshold=3):
    """Return True if any ``expected_class`` keypoint lies on or near the component.

    "Near" means the component has at least one pixel inside the square window of
    half-width ``proximity_threshold`` centred on the keypoint's rounded position.
    The window is clipped to the mask. Keypoints whose rounded position falls
    outside the mask are ignored.

    Args:
        component: dict with a 2-D boolean 'mask' (as from find_connected_components).
        keypoints: iterable of dicts with 'label', 'local_x', 'local_y'.
        expected_class: keypoint label to consider.
        proximity_threshold: half-width of the search window in pixels.
    """
    region = component['mask']
    n_rows, n_cols = region.shape[0], region.shape[1]
    reach = proximity_threshold

    same_class = (kp for kp in keypoints if kp['label'] == expected_class)
    for kp in same_class:
        col = int(round(kp['local_x']))
        row = int(round(kp['local_y']))
        if not (0 <= row < n_rows and 0 <= col < n_cols):
            continue
        if region[row, col]:
            return True
        neighbourhood = region[
            max(0, row - reach):min(n_rows, row + reach + 1),
            max(0, col - reach):min(n_cols, col + reach + 1),
        ]
        if np.any(neighbourhood):
            return True

    return False

def calculate_keypoint_coverage_f1(keypoints, pred_mask):
    """
    Calculate an F1-like score for keypoint coverage of a predicted nuclei mask.

    Metrics:
    - Recall: fraction of keypoints whose rounded pixel in ``pred_mask`` has the
      class expected for the keypoint's label (KEYPOINT_TO_CLASS).
    - Precision: fraction of predicted connected components (one pass per class
      in KEYPOINT_TO_CLASS) that have a same-label keypoint on or near them.
    - F1: harmonic mean of the two; 0.0 when both are 0.

    Args:
        keypoints: list of dicts with 'label', 'local_x', 'local_y'.
        pred_mask: 2-D array of class indices, indexed as [row, col].

    Returns:
        (f1, precision, recall, details). ``details`` always holds 'precision',
        'recall', 'correct_keypoints', 'total_keypoints', 'valid_predicted_nuclei'
        and 'total_predicted_nuclei'. With no keypoints, recall is 0.0 and no
        predicted nucleus can be matched, so all three scores are 0.0.
    """
    n_rows, n_cols = pred_mask.shape[0], pred_mask.shape[1]
    total_keypoints = len(keypoints)

    correct_keypoints = 0
    for kp in keypoints:
        expected_class_id = KEYPOINT_TO_CLASS[kp['label']]
        x_int = int(round(kp['local_x']))
        y_int = int(round(kp['local_y']))
        # Keypoints that round outside the mask cannot be covered; they count as misses.
        if 0 <= x_int < n_cols and 0 <= y_int < n_rows and pred_mask[y_int, x_int] == expected_class_id:
            correct_keypoints += 1

    recall = correct_keypoints / total_keypoints if total_keypoints > 0 else 0.0

    total_predicted_nuclei = 0
    valid_predicted_nuclei = 0
    # Use the same label->class table as recall so both metrics stay in sync.
    for class_name, class_id in KEYPOINT_TO_CLASS.items():
        components = find_connected_components(pred_mask, class_id)
        total_predicted_nuclei += len(components)
        valid_predicted_nuclei += sum(
            1 for component in components
            if check_component_has_keypoint(component, keypoints, class_name)
        )

    precision = valid_predicted_nuclei / total_predicted_nuclei if total_predicted_nuclei > 0 else 0.0

    if precision + recall > 0:
        f1 = 2 * (precision * recall) / (precision + recall)
    else:
        f1 = 0.0

    # Always return the full schema (the original returned {} for empty keypoints,
    # which broke callers that index details['correct_keypoints']).
    details = {
        'precision': precision,
        'recall': recall,
        'correct_keypoints': correct_keypoints,
        'total_keypoints': total_keypoints,
        'valid_predicted_nuclei': valid_predicted_nuclei,
        'total_predicted_nuclei': total_predicted_nuclei
    }

    return f1, precision, recall, details


def evaluate_patch_with_metrics_extended(image_path, keypoints, nuclei_model, tissue_model):
    """Run the full pipeline on one patch and score it against its keypoints.

    Pipeline:
    1. Load the image at 512x512.
    2. Mask non-cancer tissue with the tissue model.
    3. Keep only keypoints inside the cancer mask.
    4. Predict nuclei; apply ImageNet mean/std normalization first when BACKBONE
       is an EfficientNet.
    5. Post-process the prediction and compute keypoint-coverage metrics.

    Returns:
        dict with 'image' (uint8 RGB), 'cancer_mask', 'pred_mask_raw' (argmax
        before post-processing), 'pred_mask' (post-processed), 'keypoints' (the
        in-cancer keypoints), 'f1', 'precision', 'recall' and 'details'.
    """
    rgb = img_to_array(load_img(str(image_path), target_size=(512, 512)))
    cancer_mask, masked_rgb = extract_cancer_region_ensemble(rgb / 255.0, tissue_model)
    kept_keypoints, _ = filter_keypoints_by_cancer_mask(keypoints, cancer_mask)

    model_input = masked_rgb
    if 'efficientnet' in BACKBONE:
        imagenet_mean = np.array([0.485, 0.456, 0.406])
        imagenet_std = np.array([0.229, 0.224, 0.225])
        model_input = (masked_rgb - imagenet_mean) / imagenet_std

    probabilities = nuclei_model.predict(model_input[np.newaxis], verbose=0)[0]
    raw_mask = np.argmax(probabilities, axis=-1)
    clean_mask = apply_post_processing(raw_mask, min_object_size=50)

    f1, precision, recall, details = calculate_keypoint_coverage_f1(kept_keypoints, clean_mask)

    return {
        'image': rgb.astype(np.uint8),
        'cancer_mask': cancer_mask,
        'pred_mask_raw': raw_mask,
        'pred_mask': clean_mask,
        'keypoints': kept_keypoints,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'details': details
    }

print(" Metrics functions defined")

In [None]:
from scipy.stats import pearsonr
from scipy.spatial.distance import cdist

def calculate_spatial_correlation_metrics(keypoints, pred_mask):
    """Compare GT keypoints with predicted nuclei by count, position and pos/neg ratio.

    Args:
        keypoints: list of dicts with 'label', 'local_x', 'local_y'.
        pred_mask: 2-D class-index mask; class 1 components are counted as
            negative nuclei and class 2 components as positive nuclei.

    Returns:
        None if ``keypoints`` is empty. Otherwise a dict with:
        - 'count_metrics': GT keypoint counts and predicted component counts per class.
        - 'spatial_metrics': for each keypoint, the Euclidean distance (pixels) to
          the nearest same-class component centroid, plus mean/median. These are
          None when a class has no keypoints or no predicted components.
        - 'ratio_metrics': positive/negative ratios (denominator floored at 1) and
          their absolute difference.
    """
    if len(keypoints) == 0:
        return None

    labels = [kp['label'] for kp in keypoints]
    gt_negative_count = labels.count('negative')
    gt_positive_count = labels.count('positive')

    components_by_class = {
        'negative': find_connected_components(pred_mask, 1),
        'positive': find_connected_components(pred_mask, 2),
    }
    pred_negative_count = len(components_by_class['negative'])
    pred_positive_count = len(components_by_class['positive'])

    nearest = {'negative': [], 'positive': []}
    for class_name, components in components_by_class.items():
        points = [[kp['local_x'], kp['local_y']] for kp in keypoints if kp['label'] == class_name]
        if not points or not components:
            continue
        # Component centroids are (x, y), the same order as the keypoint coordinates.
        centroids = np.array([comp['centroid'] for comp in components])
        pairwise = cdist(np.array(points), centroids, metric='euclidean')
        nearest[class_name] = pairwise.min(axis=1).tolist()

    neg_dists = nearest['negative']
    pos_dists = nearest['positive']
    gt_ratio = gt_positive_count / max(gt_negative_count, 1)
    pred_ratio = pred_positive_count / max(pred_negative_count, 1)

    return {
        'count_metrics': {
            'gt_negative': gt_negative_count,
            'gt_positive': gt_positive_count,
            'pred_negative': pred_negative_count,
            'pred_positive': pred_positive_count,
            'gt_total': gt_negative_count + gt_positive_count,
            'pred_total': pred_negative_count + pred_positive_count
        },
        'spatial_metrics': {
            'negative_distances': neg_dists,
            'positive_distances': pos_dists,
            'negative_mean_dist': np.mean(neg_dists) if neg_dists else None,
            'negative_median_dist': np.median(neg_dists) if neg_dists else None,
            'positive_mean_dist': np.mean(pos_dists) if pos_dists else None,
            'positive_median_dist': np.median(pos_dists) if pos_dists else None,
        },
        'ratio_metrics': {
            'gt_pos_neg_ratio': gt_ratio,
            'pred_pos_neg_ratio': pred_ratio,
            'ratio_difference': abs(gt_ratio - pred_ratio)
        }
    }

def aggregate_correlation_metrics(all_patch_metrics):
    """
    Aggregate per-patch correlation metrics across all patches.

    Args:
        all_patch_metrics: list of dicts from calculate_spatial_correlation_metrics.

    Returns:
        None for an empty list. Otherwise a dict with:
        - 'negative', 'positive', 'total': Pearson r and p-value between GT and
          predicted counts. Present only when there are at least 2 patches.
        - 'spatial_aggregated': mean/median over all pooled keypoint distances
          (None when there are no distances).
        - 'raw_data': the per-patch count lists and pooled distances.
    """
    if not all_patch_metrics:
        return None

    count_kinds = ('negative', 'positive', 'total')
    gt_series = {kind: [m['count_metrics'][f'gt_{kind}'] for m in all_patch_metrics] for kind in count_kinds}
    pred_series = {kind: [m['count_metrics'][f'pred_{kind}'] for m in all_patch_metrics] for kind in count_kinds}

    results = {}

    for kind in count_kinds:
        # pearsonr needs at least two samples.
        if len(gt_series[kind]) <= 1:
            continue
        r_value, p_value = pearsonr(gt_series[kind], pred_series[kind])
        results[kind] = {
            'pearson': r_value,
            'pearson_p': p_value
        }

    pooled = {
        cls: [d for m in all_patch_metrics for d in m['spatial_metrics'][f'{cls}_distances']]
        for cls in ('negative', 'positive')
    }

    spatial_summary = {}
    for cls in ('negative', 'positive'):
        values = pooled[cls]
        spatial_summary[f'{cls}_mean_dist'] = np.mean(values) if values else None
        spatial_summary[f'{cls}_median_dist'] = np.median(values) if values else None
    results['spatial_aggregated'] = spatial_summary

    results['raw_data'] = {
        'gt_negative': gt_series['negative'],
        'gt_positive': gt_series['positive'],
        'gt_total': gt_series['total'],
        'pred_negative': pred_series['negative'],
        'pred_positive': pred_series['positive'],
        'pred_total': pred_series['total'],
        'negative_distances': pooled['negative'],
        'positive_distances': pooled['positive']
    }

    return results

print(" Correlation metrics functions defined")

In [None]:
def _plot_count_panel(ax, gt_counts, pred_counts, stats, title, point_color, identity_color, fit_color):
    """Draw one GT-vs-predicted count scatter with identity line, linear fit and Pearson title."""
    if len(gt_counts) == 0:
        ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes, fontsize=14)
        ax.set_title(title, fontsize=11)
        return

    scatter_kwargs = {'alpha': 0.6, 's': 50, 'edgecolors': 'k', 'linewidth': 0.5}
    if point_color is not None:
        scatter_kwargs['color'] = point_color
    ax.scatter(gt_counts, pred_counts, **scatter_kwargs)

    max_val = max(max(gt_counts), max(pred_counts))
    ax.plot([0, max_val], [0, max_val], color=identity_color, linestyle='--', linewidth=2, label='Perfect correlation')

    # A least-squares line needs at least two distinct x values. Otherwise polyfit
    # is rank-deficient and the line is meaningless (its RankWarning is silenced globally).
    if len(set(gt_counts)) > 1:
        slope, intercept = np.polyfit(gt_counts, pred_counts, 1)
        xs = np.array([min(gt_counts), max(gt_counts)], dtype=float)
        ax.plot(xs, slope * xs + intercept, color=fit_color, linestyle='-', linewidth=2, alpha=0.7, label='Best fit')

    # Pearson stats are absent when fewer than two patches were aggregated.
    if stats is not None:
        title = f"{title}\nPearson r={stats['pearson']:.3f} (p={stats['pearson_p']:.4f})"

    ax.set_xlabel('Ground Truth Count', fontsize=12)
    ax.set_ylabel('Predicted Count', fontsize=12)
    ax.set_title(title, fontsize=11)
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_count_correlation(correlation_results, save_path=None):
    """
    Plot scatter plots showing correlation between GT and predicted nuclei counts.

    Creates 3 subplots: Negative, Positive, and Total nuclei counts. Each shows
    an identity line, a least-squares fit (when GT counts vary) and the Pearson
    r/p in the title (when available).

    Args:
        correlation_results: dict from aggregate_correlation_metrics.
        save_path: optional path; if given the figure is saved at 300 dpi.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    raw_data = correlation_results['raw_data']

    # (key, base title, scatter colour, identity-line colour, fit-line colour);
    # None keeps matplotlib's default colour for the negative scatter.
    panels = (
        ('negative', 'Negative Nuclei Count Correlation', None, 'red', 'blue'),
        ('positive', 'Positive Nuclei Count Correlation', 'red', 'darkred', 'orange'),
        ('total', 'Total Nuclei Count Correlation', 'green', 'darkgreen', 'lime'),
    )
    for ax, (key, title, point_color, identity_color, fit_color) in zip(axes, panels):
        _plot_count_panel(
            ax,
            raw_data[f'gt_{key}'],
            raw_data[f'pred_{key}'],
            correlation_results.get(key),
            title,
            point_color,
            identity_color,
            fit_color,
        )

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def plot_spatial_distance_distribution(correlation_results, save_path=None):
    """
    Plot histograms of the distance from each keypoint to its nearest predicted nucleus.

    One panel per class (negative, positive). Each shows the histogram plus the
    mean and median distances, or a "no data" note when the class has no distances.

    Args:
        correlation_results: dict from aggregate_correlation_metrics.
        save_path: optional path; if given the figure is saved at 300 dpi.
    """
    fig, (neg_ax, pos_ax) = plt.subplots(1, 2, figsize=(14, 5))

    raw = correlation_results['raw_data']
    summary = correlation_results['spatial_aggregated']

    # (axis, data key, display name, bar colour, mean-line colour)
    panel_specs = (
        (neg_ax, 'negative', 'Negative', 'blue', 'red'),
        (pos_ax, 'positive', 'Positive', 'red', 'darkred'),
    )
    for ax, key, display, bar_color, mean_color in panel_specs:
        distances = raw[f'{key}_distances']
        if not distances:
            ax.text(0.5, 0.5, f'No {key} keypoints', ha='center', va='center', transform=ax.transAxes, fontsize=14)
            ax.set_title(f"{display} Nuclei: No Data", fontsize=11)
            continue

        mean_dist = summary[f'{key}_mean_dist']
        median_dist = summary[f'{key}_median_dist']
        ax.hist(distances, bins=30, color=bar_color, alpha=0.7, edgecolor='black')
        ax.axvline(mean_dist, color=mean_color, linestyle='--', linewidth=2, label=f"Mean: {mean_dist:.2f}px")
        ax.axvline(median_dist, color='orange', linestyle='--', linewidth=2, label=f"Median: {median_dist:.2f}px")
        ax.set_xlabel('Distance (pixels)', fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)
        ax.set_title(f"{display} Nuclei: Distance from Keypoint to Nearest Prediction\n(n={len(distances)} keypoints)", fontsize=11)
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def print_correlation_summary(correlation_results):
    """Print Pearson count correlations and spatial distance statistics.

    Output sections:
    - Count correlations: for each of negative/positive/total present in
      ``correlation_results``, prints r, the p-value and a verbal strength based
      on |r|. Bands: >=0.9 very strong, >=0.7 strong, >=0.5 moderate, >=0.3 weak,
      otherwise very weak.
    - Spatial distances: mean/median keypoint-to-prediction distance for each
      class whose mean is not None.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80
    strength_bands = ((0.9, "Very strong"), (0.7, "Strong"), (0.5, "Moderate"), (0.3, "Weak"))

    print("\n" + heavy_rule)
    print("CORRELATION METRICS SUMMARY")
    print(heavy_rule)

    print("\nCount Correlation (Pearson):")
    print(light_rule)
    for nuclei_type in ('negative', 'positive', 'total'):
        if nuclei_type not in correlation_results:
            continue
        stats = correlation_results[nuclei_type]
        print(f"\n  {nuclei_type.upper()} Nuclei:")
        print(f"    Pearson correlation:  r = {stats['pearson']:.4f}  (p-value: {stats['pearson_p']:.4e})")

        magnitude = abs(stats['pearson'])
        verdict = next((name for cutoff, name in strength_bands if magnitude >= cutoff), "Very weak")
        print(f"    Interpretation: {verdict} correlation")

    print("\n\nSpatial Distance Metrics (Keypoint to Nearest Prediction):")
    print(light_rule)
    spatial = correlation_results['spatial_aggregated']

    for cls in ('negative', 'positive'):
        mean_dist = spatial[f'{cls}_mean_dist']
        if mean_dist is None:
            continue
        median_dist = spatial[f'{cls}_median_dist']
        print(f"\n  {cls.upper()} Nuclei:")
        print(f"    Mean distance:   {mean_dist:.2f} pixels")
        print(f"    Median distance: {median_dist:.2f} pixels")

    print("\n" + heavy_rule)

print(" Correlation visualization functions defined")

In [None]:
def visualize_prediction_with_filtering(result, title_prefix=""):
    """
    Visualize a prediction in 4 panels showing the effect of post-processing.

    Panels:
    1. Original image (input).
    2. Ground truth: cancer region tinted yellow, plus GT keypoints
       (red = positive, blue = negative).
    3. Prediction with filtering cues: green = kept nuclei; orange-red = nuclei
       pixels present in the raw prediction but removed by post-processing.
    4. Final post-processed mask vs GT keypoints (lime 'o' = covered,
       red 'x' = missed), with F1/FP/FN in the title.

    Args:
        result: dict as produced by ``evaluate_patch_with_metrics_extended``.
            - Reads 'image', 'cancer_mask', 'pred_mask', 'keypoints', 'f1' and 'details'.
            - 'pred_mask_raw' is optional and falls back to 'pred_mask'.
            - Missing 'details' entries are treated as 0.
        title_prefix: optional figure-level title.
    """
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    img = result['image']
    cancer_mask = result['cancer_mask']
    pred_mask_raw = result.get('pred_mask_raw', result['pred_mask'])
    pred_mask_final = result['pred_mask']
    keypoints = result['keypoints']
    f1 = result['f1']
    # details may be empty (e.g. when a patch had no keypoints), so never index it blindly.
    details = result.get('details') or {}

    # Panel 1: input image.
    axes[0].imshow(img)
    axes[0].set_title("Input Image", fontsize=11, fontweight='bold')
    axes[0].axis('off')

    # Panel 2: blend yellow into the cancer ROI only (alpha 0.3), then draw GT keypoints.
    panel2 = img.copy()
    roi = cancer_mask == 1
    if np.any(roi):
        alpha = 0.3
        yellow = np.array([255, 255, 0])
        panel2[roi] = (img[roi] * (1 - alpha) + yellow * alpha).astype(np.uint8)

    axes[1].imshow(panel2)
    for kp in keypoints:
        color = 'red' if kp['label'] == 'positive' else 'blue'
        axes[1].plot(kp['local_x'], kp['local_y'], 'o', color=color, markersize=5,
                     markeredgecolor='white', markeredgewidth=0.8)

    gt_pos = sum(1 for kp in keypoints if kp['label'] == 'positive')
    gt_neg = sum(1 for kp in keypoints if kp['label'] == 'negative')
    axes[1].set_title(f"Ground Truth\nCancer ROI + Keypoints (Pos={gt_pos}, Neg={gt_neg})",
                      fontsize=11, fontweight='bold')
    axes[1].axis('off')

    # Panel 3: pixels the raw prediction labelled as nuclei (class 1/2) but that
    # post-processing set to background. This was `(pred_mask_raw = 0)`, a SyntaxError.
    rejected_mask = np.isin(pred_mask_raw, (1, 2)) & (pred_mask_final == 0)

    overlay3 = np.zeros_like(img)
    overlay3[pred_mask_final == 1] = [0, 255, 0]
    overlay3[pred_mask_final == 2] = [0, 200, 0]
    overlay3[rejected_mask] = [255, 100, 0]

    alpha = 0.5
    panel3 = cv2.addWeighted(img, 1 - alpha, overlay3, alpha, 0)
    axes[2].imshow(panel3)

    rejected_count = int(np.count_nonzero(rejected_mask))
    accepted_count_neg = len(find_connected_components(pred_mask_final, 1))
    accepted_count_pos = len(find_connected_components(pred_mask_final, 2))

    axes[2].set_title(f"Prediction with Filtering\n Accepted: {accepted_count_neg + accepted_count_pos} |  Rejected: {rejected_count} px",
                      fontsize=11, fontweight='bold')
    axes[2].axis('off')

    # Panel 4: draw the mask once, even if no keypoint is in bounds. Previously the
    # mask was drawn only inside the keypoint loop.
    axes[3].imshow(decode_mask_to_colors(pred_mask_final))
    n_rows, n_cols = pred_mask_final.shape[0], pred_mask_final.shape[1]

    for kp in keypoints:
        expected_class_id = KEYPOINT_TO_CLASS[kp['label']]
        x_int = int(round(kp['local_x']))
        y_int = int(round(kp['local_y']))
        if not (0 <= x_int < n_cols and 0 <= y_int < n_rows):
            continue

        if pred_mask_final[y_int, x_int] == expected_class_id:
            axes[3].plot(kp['local_x'], kp['local_y'], 'o', color='lime',
                         markersize=6, markeredgecolor='white',
                         markeredgewidth=1.2, alpha=0.9)
        else:
            # 'x' is an unfilled marker: it is drawn only with its edge colour, so a
            # white edge (as used for 'o') would render the miss marker white.
            axes[3].plot(kp['local_x'], kp['local_y'], 'x', color='red',
                         markersize=8, markeredgecolor='red',
                         markeredgewidth=1.2, alpha=0.9)

    total_pred_nuclei = accepted_count_neg + accepted_count_pos
    correctly_covered = details.get('correct_keypoints', 0)
    false_negatives = len(keypoints) - correctly_covered
    false_positives = total_pred_nuclei - details.get('valid_predicted_nuclei', 0)

    axes[3].set_title(f"Final Result - Error Analysis\nF1={f1:.3f} | FP={false_positives} | FN={false_negatives}",
                      fontsize=11, fontweight='bold')
    axes[3].axis('off')

    if title_prefix:
        fig.suptitle(title_prefix, fontsize=14, fontweight='bold', y=0.98)

    plt.tight_layout()
    plt.show()


print(" Visualization function defined")

In [None]:
import pandas as pd
import heapq

def _push_bounded(heap, item, capacity):
    """Keep at most ``capacity`` items in min-heap ``heap``, retaining those with the largest key.

    ``item[0]`` is the ranking key. ``item[1]`` must be unique across pushes so that
    ties on the key never fall through to comparing the remaining payload fields
    (e.g. ``patch_info`` dicts, which are not orderable).
    """
    if capacity <= 0:
        return
    if len(heap) < capacity:
        heapq.heappush(heap, item)
    elif item[0] > heap[0][0]:
        # New item beats the weakest kept item (the heap root): swap it in.
        heapq.heapreplace(heap, item)


def _format_range(values):
    """Return ``'lo - hi'`` formatted with 4 decimals, or ``'n/a'`` for no values."""
    values = list(values)
    if not values:
        return "n/a"
    return f"{min(values):.4f} - {max(values):.4f}"


def evaluate_and_select_examples(min_keypoints=20, n_candidates=10, collect_correlation_metrics=False):
    """
    Memory-efficient evaluation: process patches sequentially and track only
    top candidates for visualization (best, average, worst with many keypoints).

    Each patch image is loaded at 512x512, passed through
    ``extract_cancer_region_ensemble`` (with ``tissue_model``), its keypoints are
    filtered by the resulting cancer mask, and the cancer image is segmented with
    ``nuclei_model`` and scored with ``calculate_keypoint_coverage_f1``.

    Args:
        min_keypoints: Minimum number of cancer-mask-filtered keypoints required
            to consider a patch.
        n_candidates: How many candidates to track per category.
        collect_correlation_metrics: If True, also collect spatial correlation
            metrics for every qualifying patch.

    Returns:
        ``(best_candidates, average_candidates, worst_candidates)``, with
        ``correlation_metrics_list`` appended as a 4th element when
        ``collect_correlation_metrics`` is True.

        * ``best_candidates``: ``(f1, idx, patch_path, patch_info)`` tuples in heap
          order (not sorted).
        * ``worst_candidates``: ``(-f1, idx, patch_path, patch_info)`` tuples in
          heap order (not sorted).
        * ``average_candidates``: metadata dicts whose F1 lies within 0.05 of the
          median, sorted by number of filtered keypoints (descending).

        NOTE: when no patch qualifies, ``(None, None, None)`` is returned regardless
        of ``collect_correlation_metrics``; callers distinguish via length/None checks.
    """
    print("Loading test data...")
    data_list = load_slide_data(DATA_ROOT)

    # Bounded min-heaps keyed on the first tuple element:
    #   best:  key = f1  -> root is the weakest of the kept "best" patches
    #   worst: key = -f1 -> root is the strongest of the kept "worst" patches
    best_candidates = []
    worst_candidates = []

    all_f1_scores = []
    patch_metadata = []
    correlation_metrics_list = [] if collect_correlation_metrics else None

    # ImageNet normalisation constants, only applied for EfficientNet backbones.
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])

    print(f"\nPhase 1: Scanning {len(data_list)} patches for F1 distribution...")

    for i, data in enumerate(data_list):
        if (i + 1) % 50 == 0:
            print(f"  Progress: {i+1}/{len(data_list)} patches processed")

        try:
            img_raw = img_to_array(load_img(str(data['patch_path']), target_size=(512, 512)))
            img_norm = img_raw / 255.0
            cancer_mask, cancer_image = extract_cancer_region_ensemble(img_norm, tissue_model)
            filtered_kps, _ = filter_keypoints_by_cancer_mask(data['patch_info']['keypoints'], cancer_mask)

            if len(filtered_kps) < min_keypoints:
                continue

            # Quick prediction to get F1
            if 'efficientnet' in BACKBONE:
                inp = (cancer_image - imagenet_mean) / imagenet_std
            else:
                inp = cancer_image

            pred = nuclei_model.predict(np.expand_dims(inp, axis=0), verbose=0)[0]
            pred_mask = np.argmax(pred, axis=-1)
            pred_mask_processed = apply_post_processing(pred_mask, min_object_size=50)

            f1, _, _, _ = calculate_keypoint_coverage_f1(filtered_kps, pred_mask_processed)

            if collect_correlation_metrics:
                corr_metrics = calculate_spatial_correlation_metrics(filtered_kps, pred_mask_processed)
                if corr_metrics:
                    correlation_metrics_list.append(corr_metrics)

            patch_metadata.append({
                'f1': f1,
                'patch_path': data['patch_path'],
                'keypoints': data['patch_info']['keypoints'],
                'num_keypoints': len(filtered_kps),
                'filename': data['patch_info']['filename'],
                'slide_dir': data['slide_dir'].name
            })
            all_f1_scores.append(f1)

            # `i` is unique per patch, so F1 ties are broken by scan order and never
            # fall through to comparing patch paths or patch_info dicts.
            _push_bounded(best_candidates, (f1, i, data['patch_path'], data['patch_info']), n_candidates)
            _push_bounded(worst_candidates, (-f1, i, data['patch_path'], data['patch_info']), n_candidates)

        except Exception as e:
            # Best-effort scan: report the failing patch and keep going.
            print(f"  Error processing patch #{i}: {e}")
            continue

    print(f"\n Phase 1 complete: {len(patch_metadata)} patches with >={min_keypoints} keypoints found")

    if not all_f1_scores:
        print("No valid patches found")
        return None, None, None

    median_f1 = np.median(all_f1_scores)
    print(f"\nF1 Score distribution:")
    print(f"  Mean: {np.mean(all_f1_scores):.4f} ± {np.std(all_f1_scores):.4f}")
    print(f"  Median: {median_f1:.4f}")
    print(f"  Range: [{np.min(all_f1_scores):.4f}, {np.max(all_f1_scores):.4f}]")

    # "Average" patches: F1 near the median, preferring those with the most keypoints.
    median_tolerance = 0.05
    average_candidates = [
        meta for meta in patch_metadata
        if abs(meta['f1'] - median_f1) <= median_tolerance
    ]
    average_candidates.sort(key=lambda meta: meta['num_keypoints'], reverse=True)
    average_candidates = average_candidates[:n_candidates]

    print(f"\nFound candidates:")
    print(f"  Best: {len(best_candidates)} (F1 range: {_format_range(x[0] for x in best_candidates)})")
    print(f"  Average: {len(average_candidates)} (F1 around {median_f1:.4f} ± {median_tolerance})")
    print(f"  Worst: {len(worst_candidates)} (F1 range: {_format_range(-x[0] for x in worst_candidates)})")

    if collect_correlation_metrics:
        return best_candidates, average_candidates, worst_candidates, correlation_metrics_list
    return best_candidates, average_candidates, worst_candidates

def visualize_selected_examples_with_filtering(best_candidates, average_candidates, worst_candidates, n_show=3):
    """
    Re-evaluate and visualize pre-selected examples with filtering visualization.

    Args:
        best_candidates: ``(f1, idx, patch_path, patch_info)`` tuples, any order.
            ``None`` is treated as empty.
        average_candidates: Metadata dicts with ``'f1'``, ``'patch_path'`` and
            ``'keypoints'`` keys, shown in the given order. ``None`` is treated as empty.
        worst_candidates: ``(-f1, idx, patch_path, patch_info)`` tuples, any order.
            ``None`` is treated as empty.
        n_show: Maximum number of patches shown per category.
    """
    f1_mismatch_tol = 1e-3

    def evaluate_and_visualize(path, keypoints, title_prefix, f1_expected):
        """Helper to evaluate a specific patch and visualize it with filtering"""
        result = evaluate_patch_with_metrics_extended(path, keypoints, nuclei_model, tissue_model)
        print(f"\n{title_prefix}")
        print(f"  F1: {result['f1']:.4f} | Precision: {result['precision']:.4f} | Recall: {result['recall']:.4f}")
        print(f"  Keypoints: {len(result['keypoints'])}")
        # Selection used a different scoring path than this re-evaluation; surface
        # any disagreement so the "Expected F1" in the title is not silently wrong.
        if abs(result['f1'] - f1_expected) > f1_mismatch_tol:
            print(f"  NOTE: re-evaluated F1 differs from selection F1 ({f1_expected:.4f})")
        visualize_prediction_with_filtering(result, title_prefix)

    best_candidates = best_candidates or []
    average_candidates = average_candidates or []
    worst_candidates = worst_candidates or []

    print("\n" + "="*80)
    print(f"TOP {n_show} BEST PERFORMING PATCHES (High F1, Many Keypoints)")
    print("="*80)

    # Highest F1 first.
    best_sorted = sorted(best_candidates, key=lambda x: x[0], reverse=True)[:n_show]
    for i, (f1, _, path, patch_info) in enumerate(best_sorted, 1):
        title = f"Best #{i} | Expected F1: {f1:.4f}"
        evaluate_and_visualize(path, patch_info['keypoints'], title, f1)

    print("\n" + "="*80)
    print(f"{n_show} AVERAGE PERFORMING PATCHES (Median F1, Many Keypoints)")
    print("="*80)

    for i, meta in enumerate(average_candidates[:n_show], 1):
        title = f"Average #{i} | Expected F1: {meta['f1']:.4f}"
        evaluate_and_visualize(meta['patch_path'], meta['keypoints'], title, meta['f1'])

    print("\n" + "="*80)
    print(f"TOP {n_show} WORST PERFORMING PATCHES (Low F1, Many Keypoints)")
    print("="*80)

    # Keys are -f1, so descending key order == ascending F1 (worst first).
    worst_sorted = sorted(worst_candidates, key=lambda x: x[0], reverse=True)[:n_show]
    for i, (neg_f1, _, path, patch_info) in enumerate(worst_sorted, 1):
        f1 = -neg_f1
        title = f"Worst #{i} | Expected F1: {f1:.4f}"
        evaluate_and_visualize(path, patch_info['keypoints'], title, f1)


print(" Evaluation and showcase functions defined")

In [None]:
# Selection thresholds for the best/average/worst showcase.
SHOWCASE_MIN_KEYPOINTS = 100  # Only consider patches with at least this many keypoints
SHOWCASE_N_CANDIDATES = 10    # Track top-N candidates per category

print("Starting memory-efficient evaluation...")
print(f"This will find best/average/worst patches with many keypoints (≥{SHOWCASE_MIN_KEYPOINTS})")
print("\nNote: Only candidate patches are tracked, not all results stored in memory.\n")

results = evaluate_and_select_examples(
    min_keypoints=SHOWCASE_MIN_KEYPOINTS,
    n_candidates=SHOWCASE_N_CANDIDATES,
    collect_correlation_metrics=False
)

best_candidates, average_candidates, worst_candidates = results

if best_candidates is not None:
    visualize_selected_examples_with_filtering(
        best_candidates,
        average_candidates,
        worst_candidates,
        n_show=3
    )
    print("\n Visualization complete")
else:
    print(f"\n✗ No valid patches found with ≥{SHOWCASE_MIN_KEYPOINTS} keypoints. Try lowering min_keypoints threshold.")

In [None]:
# Run correlation analysis on test set

CORR_MIN_KEYPOINTS = 20  # Include patches with ≥20 keypoints

print("="*80)
print(" CORRELATION ANALYSIS")
print("="*80)
print("\nEvaluating all patches to calculate correlation metrics...")
print("This analyzes the relationship between keypoint annotations and model predictions.\n")

results = evaluate_and_select_examples(
    min_keypoints=CORR_MIN_KEYPOINTS,
    n_candidates=10,
    collect_correlation_metrics=True
)

if len(results) == 4:
    best_candidates, average_candidates, worst_candidates, correlation_metrics_list = results

    print(f"\n Collected correlation metrics from {len(correlation_metrics_list)} patches")

    if not correlation_metrics_list:
        # Avoid calling the aggregation/plot helpers with no data.
        print("No correlation metrics were produced; skipping aggregation and plots.")
    else:
        print("Aggregating results...")

        correlation_results = aggregate_correlation_metrics(correlation_metrics_list)

        print_correlation_summary(correlation_results)

        print("\n\n Generating correlation plots...")
        print("-" * 80)

        print("\nCount Correlation Plots (Scatter plots with best fit lines)")
        plot_count_correlation(correlation_results)

        print("\nSpatial Distance Distribution (Keypoint to Nearest Prediction)")
        plot_spatial_distance_distribution(correlation_results)

        print("\n\n Correlation analysis complete")
        print("="*80)

else:
    # evaluate_and_select_examples returns (None, None, None) when no patch
    # qualified, even though collect_correlation_metrics=True was requested.
    best_candidates, average_candidates, worst_candidates = results
    if best_candidates is None:
        print(f"No valid patches with ≥{CORR_MIN_KEYPOINTS} keypoints found; correlation analysis skipped.")
    else:
        print("Correlation metrics were not collected. Re-run with collect_correlation_metrics=True")

In [None]:
# F1 Score Distribution Analysis

def plot_f1_distribution(patch_metadata, save_path=None):
    """
    Plot the distribution of per-patch F1 scores and print a textual summary.

    Left panel: 30-bin histogram with mean/median lines, a ±1 SD band and a
    statistics box. Right panel: violin plot overlaid with a box plot, plus
    mean/median reference lines and Q1/median/Q3 annotations.

    Args:
        patch_metadata: Iterable of dictionaries, each containing an 'f1' score.
        save_path: Optional path; when given, the figure is saved at 300 dpi
            before being shown.

    Returns:
        None. When ``patch_metadata`` is empty a message is printed and nothing
        is plotted (min/max/percentiles are undefined for no data).
    """
    f1_scores = np.asarray([meta['f1'] for meta in patch_metadata], dtype=float)
    n_patches = len(f1_scores)
    if n_patches == 0:
        print("No F1 scores available; skipping F1 distribution plot.")
        return

    # Summary statistics, computed once and shared by both panels and the report.
    mean_f1 = np.mean(f1_scores)
    median_f1 = np.median(f1_scores)
    std_f1 = np.std(f1_scores)
    min_f1 = np.min(f1_scores)
    max_f1 = np.max(f1_scores)
    q1, q3 = np.percentile(f1_scores, [25, 75])
    iqr = q3 - q1

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # --- Left panel: histogram with summary overlays ---
    ax1 = axes[0]
    ax1.hist(f1_scores, bins=30, color='steelblue', alpha=0.7,
             edgecolor='black', linewidth=1.2)

    ax1.axvline(mean_f1, color='red', linestyle='--', linewidth=2.5,
                label=f'Mean: {mean_f1:.4f}', alpha=0.8)
    ax1.axvline(median_f1, color='orange', linestyle='--', linewidth=2.5,
                label=f'Median: {median_f1:.4f}', alpha=0.8)

    ax1.axvspan(mean_f1 - std_f1, mean_f1 + std_f1, alpha=0.2, color='red',
                label=f'±1 SD: [{mean_f1-std_f1:.4f}, {mean_f1+std_f1:.4f}]')

    ax1.set_xlabel('F1 Score', fontsize=13, fontweight='bold')
    ax1.set_ylabel('Number of Patches', fontsize=13, fontweight='bold')
    ax1.set_title(f'F1 Score Distribution\n(n={n_patches} patches)',
                  fontsize=14, fontweight='bold')
    ax1.legend(loc='upper left', fontsize=11, framealpha=0.9)
    ax1.grid(True, alpha=0.3, axis='y')

    stats_text = f"""Statistics:
Mean:   {mean_f1:.4f}
Median: {median_f1:.4f}
Std:    {std_f1:.4f}
Min:    {min_f1:.4f}
Max:    {max_f1:.4f}

Quartiles:
Q1:     {q1:.4f}
Q3:     {q3:.4f}
IQR:    {iqr:.4f}"""

    props = dict(boxstyle='round', facecolor='lightyellow', alpha=0.9, edgecolor='black', linewidth=1.5)
    ax1.text(0.98, 0.97, stats_text, transform=ax1.transAxes, fontsize=10,
             verticalalignment='top', horizontalalignment='right', bbox=props, fontfamily='monospace')

    # --- Right panel: violin + box plot ---
    ax2 = axes[1]

    parts = ax2.violinplot([f1_scores], positions=[1], widths=0.7,
                           showmeans=True, showmedians=True, showextrema=True)

    # Customize violin plot colors
    for pc in parts['bodies']:
        pc.set_facecolor('steelblue')
        pc.set_alpha(0.6)
        pc.set_edgecolor('black')
        pc.set_linewidth(1.5)

    # Only some line collections exist depending on the show* flags.
    for partname in ('cbars', 'cmins', 'cmaxes', 'cmedians', 'cmeans'):
        if partname in parts:
            vp = parts[partname]
            vp.set_edgecolor('black')
            vp.set_linewidth(2)

    ax2.boxplot([f1_scores], positions=[1], widths=0.3,
                patch_artist=True, showfliers=True,
                boxprops=dict(facecolor='orange', alpha=0.7, edgecolor='black', linewidth=2),
                medianprops=dict(color='red', linewidth=2.5),
                whiskerprops=dict(color='black', linewidth=1.5),
                capprops=dict(color='black', linewidth=1.5),
                flierprops=dict(marker='o', markerfacecolor='red', markersize=6,
                                markeredgecolor='black', alpha=0.5))

    ax2.set_ylabel('F1 Score', fontsize=13, fontweight='bold')
    ax2.set_title('F1 Score Distribution (Violin + Box Plot)', fontsize=14, fontweight='bold')
    ax2.set_xticks([1])
    ax2.set_xticklabels(['All Patches'])
    ax2.grid(True, alpha=0.3, axis='y')
    ax2.set_ylim([0, 1])

    ax2.axhline(mean_f1, color='red', linestyle='--', linewidth=1.5, alpha=0.5, label='Mean')
    ax2.axhline(median_f1, color='orange', linestyle='--', linewidth=1.5, alpha=0.5, label='Median')
    ax2.legend(loc='lower right', fontsize=10)

    ax2.text(1.4, median_f1, f'Median\n{median_f1:.3f}', va='center', fontsize=9, fontweight='bold')
    ax2.text(1.4, q1, f'Q1\n{q1:.3f}', va='center', fontsize=9, color='darkblue')
    ax2.text(1.4, q3, f'Q3\n{q3:.3f}', va='center', fontsize=9, color='darkblue')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f" F1 distribution plot saved to {save_path}")

    plt.show()

    # Print summary to console
    print("\n" + "="*80)
    print(" F1 SCORE DISTRIBUTION SUMMARY")
    print("="*80)
    print(f"\nTotal patches evaluated: {n_patches}")
    print(f"\nCentral Tendency:")
    print(f"  Mean:   {mean_f1:.6f}")
    print(f"  Median: {median_f1:.6f}")
    print(f"\nSpread:")
    print(f"  Std Dev:  {std_f1:.6f}")
    print(f"  Range:    [{min_f1:.6f}, {max_f1:.6f}]")
    print(f"  IQR:      {iqr:.6f}")
    print(f"\nQuartiles:")
    print(f"  Q1 (25%): {q1:.6f}")
    print(f"  Q2 (50%): {median_f1:.6f}")
    print(f"  Q3 (75%): {q3:.6f}")
    print(f"\nPerformance Categories:")
    excellent = int(np.count_nonzero(f1_scores >= 0.9))
    good = int(np.count_nonzero((f1_scores >= 0.8) & (f1_scores < 0.9)))
    fair = int(np.count_nonzero((f1_scores >= 0.7) & (f1_scores < 0.8)))
    poor = int(np.count_nonzero(f1_scores < 0.7))
    print(f"  Excellent (F1 ≥ 0.9): {excellent} ({excellent/n_patches*100:.1f}%)")
    print(f"  Good (0.8 ≤ F1 < 0.9): {good} ({good/n_patches*100:.1f}%)")
    print(f"  Fair (0.7 ≤ F1 < 0.8): {fair} ({fair/n_patches*100:.1f}%)")
    print(f"  Poor (F1 < 0.7):       {poor} ({poor/n_patches*100:.1f}%)")
    print("="*80 + "\n")

print(" F1 distribution visualization function defined")

In [None]:
# Generate F1 Score Distribution Plot

DIST_MIN_KEYPOINTS = 20  # Shared by the candidate search and the distribution scan below

print("Analyzing F1 score distribution across all evaluated patches...")

# NOTE(review): this full evaluation pass only refreshes best/average/worst_candidates;
# the scan below repeats inference for every patch to collect per-patch F1 metadata.
results = evaluate_and_select_examples(
    min_keypoints=DIST_MIN_KEYPOINTS,
    n_candidates=10,
    collect_correlation_metrics=False
)

best_candidates, average_candidates, worst_candidates = results

print("\nScanning all patches to collect F1 scores for distribution analysis...")

all_f1_metadata = []
failed_patches = []  # (index, error message) for patches that raised during the scan
data_list = load_slide_data(DATA_ROOT)

# ImageNet normalisation constants, only applied for EfficientNet backbones.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])

for i, data in enumerate(data_list):
    if (i+1) % 100 == 0:
        print(f"  Progress: {i+1}/{len(data_list)}")

    try:
        img_raw = img_to_array(load_img(str(data['patch_path']), target_size=(512, 512)))
        img_norm = img_raw / 255.0
        cancer_mask, cancer_image = extract_cancer_region_ensemble(img_norm, tissue_model)
        filtered_kps, _ = filter_keypoints_by_cancer_mask(data['patch_info']['keypoints'], cancer_mask)

        if len(filtered_kps) < DIST_MIN_KEYPOINTS:  # Skip patches with too few keypoints
            continue

        if 'efficientnet' in BACKBONE:
            inp = (cancer_image - imagenet_mean) / imagenet_std
        else:
            inp = cancer_image

        pred = nuclei_model.predict(np.expand_dims(inp, axis=0), verbose=0)[0]
        pred_mask = np.argmax(pred, axis=-1)
        pred_mask_processed = apply_post_processing(pred_mask, min_object_size=50)

        f1, precision, recall, _ = calculate_keypoint_coverage_f1(filtered_kps, pred_mask_processed)

        all_f1_metadata.append({
            'f1': f1,
            'precision': precision,
            'recall': recall,
            'filename': data['patch_info']['filename'],
            'num_keypoints': len(filtered_kps)
        })

    except Exception as e:
        # Best-effort scan: keep going, but record the failure instead of hiding it.
        failed_patches.append((i, str(e)))
        continue

print(f"\n Collected F1 scores from {len(all_f1_metadata)} patches")
if failed_patches:
    first_idx, first_err = failed_patches[0]
    print(f"  Skipped {len(failed_patches)} patches due to errors (first: #{first_idx}: {first_err})")

if all_f1_metadata:
    plot_f1_distribution(all_f1_metadata)
    print("\n F1 distribution analysis complete")
else:
    print(f"\n✗ No patches with ≥{DIST_MIN_KEYPOINTS} keypoints were scored; skipping F1 distribution plot.")