In [1]:
# Cell 1: Imports

import os
import sys
import json
import yaml
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import cv2
from PIL import Image
import matplotlib.pyplot as plt
from scipy.ndimage import label as scipy_label
from scipy.optimize import linear_sum_assignment

import albumentations as A
from albumentations.pytorch import ToTensorV2

# Add CIPS-Net to path
sys.path.insert(0, 'CIPS-Net')
from models.cips_net import CIPSNet

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

print("\n" + "=" * 60)
print("CIPS-Net PanNuke Evaluation with Panoptic Quality")
print("=" * 60)

Using device: cuda
GPU: NVIDIA RTX A5000

CIPS-Net PanNuke Evaluation with Panoptic Quality


In [2]:
# Cell 2: Configuration

# ============================================================================
# EVALUATION CONFIGURATION
# ============================================================================

# Point to your trained experiment directory
# Update this path to match your training output
EXPERIMENT_DIR = "results/cipsnet_pannuke_cv3_balanced_VIT_B_16_distil_bert_uncased_20260103_201916"  # <-- UPDATE THIS

# If experiment dir doesn't exist, list available experiments
if not os.path.exists(EXPERIMENT_DIR):
    print("‚ö†Ô∏è  Experiment directory not found!")
    print("\nAvailable experiments in results/:")
    if os.path.exists("results"):
        for exp in sorted(os.listdir("results")):
            if exp.startswith("cipsnet_pannuke"):
                print(f"  - results/{exp}")
    print("\nüëÜ Update EXPERIMENT_DIR above with one of these paths")
else:
    print(f"‚úì Using experiment: {EXPERIMENT_DIR}")

# Load config from training
config_path = f"{EXPERIMENT_DIR}/config.yaml"
if os.path.exists(config_path):
    with open(config_path, 'r') as f:
        CONFIG = yaml.safe_load(f)
    print(f"‚úì Loaded config from {config_path}")
else:
    # Default config if not found
    CONFIG = {
        'dataset_path': 'PanNuke_Preprocess',
        'num_folds': 3,
        'img_size': 224,
        'class_names': ['Neoplastic', 'Inflammatory', 'Connective_Soft_tissue', 'Dead', 'Epithelial'],
        'num_classes': 5,
        'img_encoder': 'vit_b_16',
        'text_encoder': 'distilbert-base-uncased',
        'embed_dim': 768,
        'mean': [0.485, 0.456, 0.406],
        'std': [0.229, 0.224, 0.225],
        'batch_size': 16,
    }
    print("‚ö†Ô∏è  Using default config (config.yaml not found)")

# Evaluation settings
EVAL_CONFIG = {
    'batch_size': 8,  # Lower batch size for evaluation
    'num_workers': 0,
    'iou_threshold': 0.5,  # PQ matching threshold (standard)
}

# Output directory for evaluation results
EVAL_OUTPUT_DIR = f"{EXPERIMENT_DIR}/evaluation"
os.makedirs(EVAL_OUTPUT_DIR, exist_ok=True)

print(f"\nEvaluation output: {EVAL_OUTPUT_DIR}")
print(f"Classes: {CONFIG['class_names']}")

‚úì Using experiment: results/cipsnet_pannuke_cv3_balanced_VIT_B_16_distil_bert_uncased_20260103_201916
‚úì Loaded config from results/cipsnet_pannuke_cv3_balanced_VIT_B_16_distil_bert_uncased_20260103_201916/config.yaml

Evaluation output: results/cipsnet_pannuke_cv3_balanced_VIT_B_16_distil_bert_uncased_20260103_201916/evaluation
Classes: ['Neoplastic', 'Inflammatory', 'Connective_Soft_tissue', 'Dead', 'Epithelial']


## 1. Dataset and Data Loading

In [12]:
# Cell 3: Dataset Class for Evaluation (Same as Training)

class PanNukeEvalDataset(Dataset):
    """
    Dataset for PanNuke evaluation - same format as training.
    
    - Images: PNG files in images/fold{n}/
    - Masks: NPZ files in masks/fold{n}/ containing binary masks per class
    - Annotations: annotations.csv with image_id, fold, classes_present, instruction
    """
    
    def __init__(self, data_root, folds, transform=None, img_size=224, class_names=None):
        self.data_root = data_root
        self.folds = folds if isinstance(folds, list) else [folds]
        self.transform = transform
        self.img_size = img_size
        self.class_names = class_names or CONFIG['class_names']
        self.num_classes = len(self.class_names)
        
        # Load annotations
        annotations = pd.read_csv(os.path.join(data_root, 'annotations.csv'))
        self.df = annotations[annotations['fold'].isin(self.folds)].reset_index(drop=True)
        
        print(f"Loaded {len(self.df)} samples from folds {self.folds}")
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_id = row['image_id']
        fold = row['fold']
        
        # Load image (same as training)
        img_path = os.path.join(self.data_root, 'images', f'fold{fold}', f'{image_id}.png')
        image = np.array(Image.open(img_path).convert('RGB'))
        
        # Load masks (same as training)
        mask_path = os.path.join(self.data_root, 'masks', f'fold{fold}', f'{image_id}.npz')
        mask_data = np.load(mask_path)
        masks = mask_data['masks']  # [H, W, num_classes] binary masks
        
        # Create class index mask from binary masks (same as training)
        # Priority: later classes override earlier ones if overlapping
        class_index_mask = np.zeros((masks.shape[0], masks.shape[1]), dtype=np.int64)
        for c in range(self.num_classes):
            class_index_mask[masks[:, :, c] > 0] = c
        
        # Create instance mask from binary masks for PQ computation
        # Each connected component gets a unique ID
        instance_mask = self._create_instance_mask(masks)
        
        # Store original semantic mask before transforms (for GT in PQ)
        original_semantic = class_index_mask.copy()
        original_instance = instance_mask.copy()
        
        # Apply augmentations (only resize & normalize for eval)
        if self.transform:
            augmented = self.transform(image=image, mask=class_index_mask)
            image = augmented['image']
            class_index_mask = augmented['mask']
        
        # Convert to tensors (same as training)
        if isinstance(image, np.ndarray):
            image = torch.from_numpy(image).permute(2, 0, 1).float()
        elif isinstance(image, torch.Tensor) and image.ndim == 3 and image.shape[-1] == 3:
            image = image.permute(2, 0, 1).float()
        
        if isinstance(class_index_mask, np.ndarray):
            class_index_mask = torch.from_numpy(class_index_mask.astype(np.int64)).long()
        else:
            class_index_mask = class_index_mask.long()
        
        # Get instruction (same as training)
        instruction = row['instruction'] if pd.notna(row['instruction']) else "Segment all tissue types."
        
        return {
            'image': image,
            'mask': class_index_mask,
            'semantic_mask': original_semantic,  # Original resolution for PQ
            'instance_mask': original_instance,  # For PQ computation
            'binary_masks': masks,  # Original [H,W,C] binary masks
            'instruction': instruction,
            'image_id': image_id,
            'fold': fold
        }
    
    def _create_instance_mask(self, binary_masks):
        """Create instance mask from binary masks using connected components."""
        H, W, C = binary_masks.shape
        instance_mask = np.zeros((H, W), dtype=np.int32)
        instance_id = 1
        
        for c in range(C):
            class_mask = binary_masks[:, :, c].astype(np.uint8)
            if class_mask.sum() > 0:
                num_labels, labels = cv2.connectedComponents(class_mask, connectivity=8)
                for label_id in range(1, num_labels):
                    instance_mask[labels == label_id] = instance_id
                    instance_id += 1
        
        return instance_mask


def get_eval_transforms(img_size, mean, std):
    """Get evaluation transforms (no augmentation, same as training val)."""
    return A.Compose([
        A.Resize(img_size, img_size),
        A.Normalize(mean=mean, std=std),
    ])


def eval_collate_fn(batch):
    """Custom collate function for evaluation."""
    images = torch.stack([item['image'] for item in batch])
    masks = torch.stack([item['mask'] for item in batch])
    semantic_masks = [item['semantic_mask'] for item in batch]  # List of numpy arrays
    instance_masks = [item['instance_mask'] for item in batch]  # List of numpy arrays
    binary_masks = [item['binary_masks'] for item in batch]  # List of numpy arrays
    instructions = [item['instruction'] for item in batch]
    image_ids = [item['image_id'] for item in batch]
    
    return {
        'image': images,
        'mask': masks,
        'semantic_mask': semantic_masks,
        'instance_mask': instance_masks,
        'binary_masks': binary_masks,
        'instruction': instructions,
        'image_id': image_ids
    }


print("‚úì Dataset and transforms defined for evaluation (matching training format)")

‚úì Dataset and transforms defined for evaluation (matching training format)


## 2. Official PanNuke Panoptic Quality Functions

These functions are adapted from the official PanNuke evaluation code (`stats_utils.py`).

In [4]:
# Cell 4: Official PanNuke PQ Functions (from stats_utils.py)

def remap_label(pred, by_size=False):
    """
    Rename all instance IDs such that the ID is contiguous (i.e. 1, 2, 3, ...)
    
    Args:
        pred: Input instance map (numpy array)
        by_size: If True, relabel instances by their size (largest first)
    
    Returns:
        Remapped instance map
    """
    pred_id = list(np.unique(pred))
    if 0 in pred_id:
        pred_id.remove(0)
    if len(pred_id) == 0:
        return pred  # No instances
    
    if by_size:
        pred_size = []
        for inst_id in pred_id:
            size = (pred == inst_id).sum()
            pred_size.append(size)
        # Sort by size (descending)
        pair_list = zip(pred_id, pred_size)
        pair_list = sorted(pair_list, key=lambda x: x[1], reverse=True)
        pred_id, _ = zip(*pair_list)
    
    new_pred = np.zeros(pred.shape, np.int32)
    for idx, inst_id in enumerate(pred_id):
        new_pred[pred == inst_id] = idx + 1
    
    return new_pred


def get_fast_pq(true, pred, match_iou=0.5):
    """
    Compute Panoptic Quality (PQ) for instance segmentation.
    
    PQ = DQ √ó SQ
    - DQ (Detection Quality) = TP / (TP + 0.5*FP + 0.5*FN)
    - SQ (Segmentation Quality) = Average IoU of matched (TP) pairs
    
    Uses Hungarian algorithm for optimal 1-to-1 matching between
    ground truth and predicted instances.
    
    Args:
        true: Ground truth instance map (H, W) - each unique value is an instance
        pred: Predicted instance map (H, W) - each unique value is an instance
        match_iou: IoU threshold for considering a match (default 0.5)
    
    Returns:
        [DQ, SQ, PQ] as numpy array
    """
    assert match_iou >= 0.0, "match_iou must be >= 0.0"
    
    true = np.copy(true)
    pred = np.copy(pred)
    
    # Remap labels to be contiguous
    true = remap_label(true)
    pred = remap_label(pred)
    
    # Get unique instance IDs (excluding background 0)
    true_id_list = list(np.unique(true))
    pred_id_list = list(np.unique(pred))
    
    if 0 in true_id_list:
        true_id_list.remove(0)
    if 0 in pred_id_list:
        pred_id_list.remove(0)
    
    # Edge case: no instances
    if len(true_id_list) == 0 and len(pred_id_list) == 0:
        return np.array([1.0, 1.0, 1.0])  # Perfect score (nothing to predict, nothing predicted)
    if len(true_id_list) == 0:
        return np.array([0.0, 0.0, 0.0])  # All FP
    if len(pred_id_list) == 0:
        return np.array([0.0, 0.0, 0.0])  # All FN
    
    # Compute pairwise IoU matrix
    num_true = len(true_id_list)
    num_pred = len(pred_id_list)
    
    # Create IoU matrix
    pairwise_iou = np.zeros((num_true, num_pred), dtype=np.float64)
    
    for t_idx, t_id in enumerate(true_id_list):
        true_mask = (true == t_id)
        for p_idx, p_id in enumerate(pred_id_list):
            pred_mask = (pred == p_id)
            
            intersection = np.logical_and(true_mask, pred_mask).sum()
            union = np.logical_or(true_mask, pred_mask).sum()
            
            if union > 0:
                pairwise_iou[t_idx, p_idx] = intersection / union
    
    # Hungarian matching to find optimal 1-to-1 assignment
    # We want to maximize IoU, but linear_sum_assignment minimizes cost
    # So we use negative IoU as cost
    
    if num_true <= num_pred:
        # More predictions than GT - match GT to predictions
        cost_matrix = -pairwise_iou
        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        matched_iou = pairwise_iou[row_ind, col_ind]
    else:
        # More GT than predictions - match predictions to GT
        cost_matrix = -pairwise_iou.T
        col_ind, row_ind = linear_sum_assignment(cost_matrix)
        matched_iou = pairwise_iou[row_ind, col_ind]
    
    # Filter matches by IoU threshold
    valid_matches = matched_iou >= match_iou
    
    # Count TP, FP, FN
    tp = valid_matches.sum()
    fp = num_pred - tp  # Predictions not matched to GT
    fn = num_true - tp  # GT not matched to predictions
    
    # Compute metrics
    # DQ = TP / (TP + 0.5*FP + 0.5*FN)
    dq = tp / (tp + 0.5 * fp + 0.5 * fn + 1e-8)
    
    # SQ = average IoU of TPs
    if tp > 0:
        sq = matched_iou[valid_matches].sum() / tp
    else:
        sq = 0.0
    
    # PQ = DQ √ó SQ
    pq = dq * sq
    
    return np.array([dq, sq, pq])


print("‚úì Official PanNuke PQ functions defined:")
print("  - remap_label(): Relabel instances to contiguous IDs")
print("  - get_fast_pq(): Compute DQ, SQ, PQ with Hungarian matching")

‚úì Official PanNuke PQ functions defined:
  - remap_label(): Relabel instances to contiguous IDs
  - get_fast_pq(): Compute DQ, SQ, PQ with Hungarian matching


## 3. Instance Extraction from Semantic Predictions

In [5]:
# Cell 5: Instance Extraction from Semantic Segmentation

def semantic_to_instance(semantic_mask, num_classes):
    """
    Convert semantic segmentation to instance segmentation using connected components.
    
    For each class, find connected components and assign unique instance IDs.
    
    Args:
        semantic_mask: (H, W) array with class labels (0 to num_classes-1)
        num_classes: Number of classes
    
    Returns:
        instance_mask: (H, W) array with unique instance IDs
        class_instance_map: Dict mapping instance_id -> class_id
    """
    H, W = semantic_mask.shape
    instance_mask = np.zeros((H, W), dtype=np.int32)
    class_instance_map = {}  # instance_id -> class_id
    
    instance_id = 1
    
    for class_id in range(num_classes):
        # Get binary mask for this class
        class_mask = (semantic_mask == class_id).astype(np.uint8)
        
        if class_mask.sum() == 0:
            continue
        
        # Find connected components
        num_labels, labels = cv2.connectedComponents(class_mask, connectivity=8)
        
        # Assign unique instance IDs (skip background label 0)
        for label_id in range(1, num_labels):
            instance_mask[labels == label_id] = instance_id
            class_instance_map[instance_id] = class_id
            instance_id += 1
    
    return instance_mask, class_instance_map


def get_class_instance_mask(instance_mask, class_instance_map, target_class):
    """
    Extract instance mask for a specific class.
    
    Args:
        instance_mask: Full instance mask (H, W)
        class_instance_map: Dict mapping instance_id -> class_id
        target_class: Class ID to extract
    
    Returns:
        Class-specific instance mask (H, W) with relabeled instances
    """
    class_mask = np.zeros_like(instance_mask)
    new_id = 1
    
    for inst_id, class_id in class_instance_map.items():
        if class_id == target_class:
            class_mask[instance_mask == inst_id] = new_id
            new_id += 1
    
    return class_mask


def compute_pq_metrics(pred_semantic, gt_instance, num_classes, class_names, iou_threshold=0.5):
    """
    Compute bPQ and mPQ from semantic predictions and instance ground truth.
    
    Args:
        pred_semantic: Predicted semantic mask (H, W) with class labels
        gt_instance: Ground truth instance mask (H, W) with unique instance IDs
        num_classes: Number of classes
        class_names: List of class names
        iou_threshold: IoU threshold for PQ matching
    
    Returns:
        Dictionary with bPQ, mPQ, and per-class PQ values
    """
    # Convert prediction to instances
    pred_instance, pred_class_map = semantic_to_instance(pred_semantic, num_classes)
    
    # Get GT class mapping (we need to know which class each GT instance belongs to)
    # Assuming GT instance mask has class info encoded or we derive from overlap
    gt_class_map = {}
    for inst_id in np.unique(gt_instance):
        if inst_id == 0:
            continue
        # Find most common class in the region (from semantic gt)
        # For PanNuke, we can derive from the semantic mask
        # Here we assume gt_instance encodes class info already
        gt_class_map[inst_id] = 0  # Will be overridden below
    
    # Per-class PQ computation
    pq_per_class = []
    dq_per_class = []
    sq_per_class = []
    
    for class_id in range(num_classes):
        # Get class-specific instance masks
        pred_class_inst = get_class_instance_mask(pred_instance, pred_class_map, class_id)
        
        # For GT, we need to extract instances of this class
        # This requires knowing the class of each GT instance
        # Simplified: assume we have a way to filter GT by class
        gt_class_inst = np.zeros_like(gt_instance)  # Placeholder
        
        # Compute PQ for this class
        if pred_class_inst.max() == 0 and gt_class_inst.max() == 0:
            # No instances of this class - skip or count as perfect
            pq_per_class.append(1.0)
            dq_per_class.append(1.0)
            sq_per_class.append(1.0)
        else:
            dq, sq, pq = get_fast_pq(gt_class_inst, pred_class_inst, match_iou=iou_threshold)
            pq_per_class.append(pq)
            dq_per_class.append(dq)
            sq_per_class.append(sq)
    
    # Binary PQ (treat all nuclei as one class)
    pred_binary = (pred_instance > 0).astype(np.int32)
    pred_binary = remap_label(pred_binary)
    
    gt_binary = (gt_instance > 0).astype(np.int32)
    gt_binary = remap_label(gt_binary)
    
    b_dq, b_sq, b_pq = get_fast_pq(gt_binary, pred_binary, match_iou=iou_threshold)
    
    # Compile results
    metrics = {
        'bPQ': b_pq,
        'bDQ': b_dq,
        'bSQ': b_sq,
        'mPQ': np.mean(pq_per_class),
        'mDQ': np.mean(dq_per_class),
        'mSQ': np.mean(sq_per_class),
    }
    
    for i, name in enumerate(class_names):
        metrics[f'pq_{name}'] = pq_per_class[i]
        metrics[f'dq_{name}'] = dq_per_class[i]
        metrics[f'sq_{name}'] = sq_per_class[i]
    
    return metrics


print("‚úì Instance extraction functions defined:")
print("  - semantic_to_instance(): Convert semantic ‚Üí instance using connected components")
print("  - get_class_instance_mask(): Extract instances for a specific class")
print("  - compute_pq_metrics(): Compute bPQ, mPQ from predictions")

‚úì Instance extraction functions defined:
  - semantic_to_instance(): Convert semantic ‚Üí instance using connected components
  - get_class_instance_mask(): Extract instances for a specific class
  - compute_pq_metrics(): Compute bPQ, mPQ from predictions


## 4. Fold Evaluation Function

In [13]:
# Cell 6: Evaluation Function for a Single Fold

def evaluate_fold(fold_idx, test_fold, config, experiment_dir, eval_config):
    """
    Evaluate a single fold: load model, run predictions, compute PQ metrics.
    
    Args:
        fold_idx: Fold index (1, 2, or 3)
        test_fold: Which fold to use as test data
        config: Training configuration
        experiment_dir: Directory containing trained models
        eval_config: Evaluation configuration
    
    Returns:
        Dictionary with PQ metrics for this fold
    """
    print("\n" + "=" * 70)
    print(f"EVALUATING FOLD {fold_idx}: Testing on fold {test_fold}")
    print("=" * 70)
    
    # Load model checkpoint
    model_path = f"{experiment_dir}/fold{fold_idx}/best_model.pth"
    if not os.path.exists(model_path):
        print(f"‚ö†Ô∏è  Model not found: {model_path}")
        return None
    
    # weights_only=False for PyTorch 2.6+ compatibility (contains numpy arrays)
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    print(f"‚úì Loaded model from: {model_path}")
    print(f"  Best epoch: {checkpoint.get('epoch', 'N/A')}")
    print(f"  Best Dice: {checkpoint.get('best_dice', 'N/A'):.4f}")
    
    # Initialize model
    model = CIPSNet(
        img_encoder_name=config['img_encoder'],
        text_encoder_name=config['text_encoder'],
        embed_dim=config['embed_dim'],
        num_classes=config['num_classes'],
        img_size=config['img_size'],
        pretrained=False  # We're loading weights
    ).to(device)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    # Create test dataset (matching training format)
    test_transform = get_eval_transforms(
        config['img_size'], 
        config['mean'], 
        config['std']
    )
    
    test_dataset = PanNukeEvalDataset(
        data_root=config['dataset_path'],
        folds=[test_fold],
        transform=test_transform,
        img_size=config['img_size'],
        class_names=config['class_names']
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=eval_config['batch_size'],
        shuffle=False,
        num_workers=eval_config['num_workers'],
        collate_fn=eval_collate_fn,
        pin_memory=True
    )
    
    print(f"Test samples: {len(test_dataset)}")
    
    # Collect predictions and compute PQ
    all_pq_binary = []
    all_pq_per_class = {name: [] for name in config['class_names']}
    
    with torch.no_grad():
        pbar = tqdm(test_loader, desc=f'Evaluating Fold {fold_idx}')
        
        for batch in pbar:
            images = batch['image'].to(device)
            gt_instance_list = batch['instance_mask']  # List of numpy arrays
            gt_binary_masks_list = batch['binary_masks']  # List of [H,W,C] arrays
            instructions = batch['instruction']
            
            # Forward pass
            outputs = model(images, instructions)
            logits = outputs['masks'][:, :config['num_classes'], :, :]
            
            # Resize predictions to original size if needed
            # Note: GT is at original resolution, predictions at model resolution
            pred_semantic = torch.argmax(logits, dim=1).cpu().numpy()
            
            # Process each sample in batch
            batch_size = images.shape[0]
            for i in range(batch_size):
                pred_sem = pred_semantic[i]  # [H_model, W_model]
                gt_inst = gt_instance_list[i]  # [H_orig, W_orig]
                gt_binary = gt_binary_masks_list[i]  # [H_orig, W_orig, C]
                
                # Resize prediction to original GT size for fair comparison
                H_orig, W_orig = gt_inst.shape
                pred_sem_resized = cv2.resize(
                    pred_sem.astype(np.uint8), 
                    (W_orig, H_orig), 
                    interpolation=cv2.INTER_NEAREST
                )
                
                # Convert prediction to instance mask
                pred_inst, pred_class_map = semantic_to_instance(pred_sem_resized, config['num_classes'])
                
                # Create GT class map from binary masks
                gt_class_map = {}
                gt_inst_copy = gt_inst.copy()
                instance_id = 1
                for c in range(config['num_classes']):
                    class_mask = gt_binary[:, :, c].astype(np.uint8)
                    if class_mask.sum() > 0:
                        num_labels, labels = cv2.connectedComponents(class_mask, connectivity=8)
                        for label_id in range(1, num_labels):
                            gt_class_map[instance_id] = c
                            instance_id += 1
                
                # Binary PQ (all nuclei as one class)
                dq, sq, pq = get_fast_pq(gt_inst, pred_inst, match_iou=eval_config['iou_threshold'])
                all_pq_binary.append(pq)
                
                # Per-class PQ
                for class_id, class_name in enumerate(config['class_names']):
                    pred_class_inst = get_class_instance_mask(pred_inst, pred_class_map, class_id)
                    gt_class_inst = get_class_instance_mask(gt_inst, gt_class_map, class_id)
                    
                    # Only compute if at least one has instances
                    if pred_class_inst.max() > 0 or gt_class_inst.max() > 0:
                        dq_c, sq_c, pq_c = get_fast_pq(gt_class_inst, pred_class_inst, 
                                                        match_iou=eval_config['iou_threshold'])
                        all_pq_per_class[class_name].append(pq_c)
    
    # Aggregate metrics
    metrics = {
        'fold': fold_idx,
        'test_fold': test_fold,
        'bPQ': np.mean(all_pq_binary) if all_pq_binary else 0.0,
        'bPQ_std': np.std(all_pq_binary) if all_pq_binary else 0.0,
    }
    
    # Per-class PQ
    pq_per_class_mean = []
    for class_name in config['class_names']:
        values = all_pq_per_class[class_name]
        if values:
            metrics[f'pq_{class_name}'] = np.mean(values)
            metrics[f'pq_{class_name}_std'] = np.std(values)
            pq_per_class_mean.append(np.mean(values))
        else:
            metrics[f'pq_{class_name}'] = 0.0
            metrics[f'pq_{class_name}_std'] = 0.0
            pq_per_class_mean.append(0.0)
    
    # mPQ is average across classes
    metrics['mPQ'] = np.mean(pq_per_class_mean)
    
    # Print results
    print(f"\nüìä Fold {fold_idx} Results:")
    print(f"  bPQ (Binary PQ): {metrics['bPQ']:.4f} ¬± {metrics['bPQ_std']:.4f}")
    print(f"  mPQ (Multi-class PQ): {metrics['mPQ']:.4f}")
    print("  Per-class PQ:")
    for class_name in config['class_names']:
        print(f"    {class_name}: {metrics[f'pq_{class_name}']:.4f}")
    
    # Cleanup
    del model
    torch.cuda.empty_cache()
    
    return metrics


print("‚úì Fold evaluation function defined")

‚úì Fold evaluation function defined


## 5. Run 3-Fold Cross-Validation Evaluation

In [14]:
# Cell 7: Run 3-Fold Cross-Validation Evaluation

print("=" * 70)
print("STARTING 3-FOLD PANOPTIC QUALITY EVALUATION")
print("=" * 70)

# Cross-validation splits (same as training)
cv_splits = [
    {'fold_idx': 1, 'test_fold': 1},  # Model trained on [2,3], test on 1
    {'fold_idx': 2, 'test_fold': 2},  # Model trained on [1,3], test on 2
    {'fold_idx': 3, 'test_fold': 3},  # Model trained on [1,2], test on 3
]

# Check if experiment directory exists
if not os.path.exists(EXPERIMENT_DIR):
    print(f"\n‚ùå ERROR: Experiment directory not found: {EXPERIMENT_DIR}")
    print("Please update EXPERIMENT_DIR in Cell 2 to point to your training results.")
else:
    # Store results for all folds
    all_fold_results = []
    
    for split in cv_splits:
        fold_metrics = evaluate_fold(
            fold_idx=split['fold_idx'],
            test_fold=split['test_fold'],
            config=CONFIG,
            experiment_dir=EXPERIMENT_DIR,
            eval_config=EVAL_CONFIG
        )
        
        if fold_metrics is not None:
            all_fold_results.append(fold_metrics)
    
    print("\n" + "=" * 70)
    print("3-FOLD EVALUATION COMPLETE!")
    print("=" * 70)

STARTING 3-FOLD PANOPTIC QUALITY EVALUATION

EVALUATING FOLD 1: Testing on fold 1
‚úì Loaded model from: results/cipsnet_pannuke_cv3_balanced_VIT_B_16_distil_bert_uncased_20260103_201916/fold1/best_model.pth
  Best epoch: 47
  Best Dice: 0.6201
Loaded 2656 samples from folds [1]
Test samples: 2656


Evaluating Fold 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 332/332 [04:02<00:00,  1.37it/s]



üìä Fold 1 Results:
  bPQ (Binary PQ): 0.1503 ¬± 0.1613
  mPQ (Multi-class PQ): 0.1392
  Per-class PQ:
    Neoplastic: 0.0000
    Inflammatory: 0.2729
    Connective_Soft_tissue: 0.1483
    Dead: 0.1418
    Epithelial: 0.1330

EVALUATING FOLD 2: Testing on fold 2
‚úì Loaded model from: results/cipsnet_pannuke_cv3_balanced_VIT_B_16_distil_bert_uncased_20260103_201916/fold2/best_model.pth
  Best epoch: 49
  Best Dice: 0.6070
Loaded 2523 samples from folds [2]
Test samples: 2523


Evaluating Fold 2:  93%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé| 294/316 [02:45<00:12,  1.77it/s]


KeyboardInterrupt: 

## 6. Aggregate Results and Save

In [None]:
# Cell 8: Aggregate Results and Save

if len(all_fold_results) > 0:
    print("\n" + "=" * 70)
    print("PANOPTIC QUALITY SUMMARY (3-Fold Cross-Validation)")
    print("=" * 70)
    
    # Create results DataFrame
    results_df = pd.DataFrame(all_fold_results)
    
    print("\nPer-Fold Results:")
    print(results_df.to_string(index=False))
    
    # Aggregate metrics
    print("\n" + "-" * 60)
    print("AGGREGATED RESULTS (Mean ¬± Std across 3 folds):")
    print("-" * 60)
    
    # Binary PQ
    bpq_mean = results_df['bPQ'].mean()
    bpq_std = results_df['bPQ'].std()
    print(f"\n  bPQ (Binary Panoptic Quality): {bpq_mean:.4f} ¬± {bpq_std:.4f}")
    
    # Multi-class PQ
    mpq_mean = results_df['mPQ'].mean()
    mpq_std = results_df['mPQ'].std()
    print(f"  mPQ (Multi-class Panoptic Quality): {mpq_mean:.4f} ¬± {mpq_std:.4f}")
    
    # Per-class PQ
    print("\n  Per-Class PQ (Mean ¬± Std):")
    for class_name in CONFIG['class_names']:
        col = f'pq_{class_name}'
        if col in results_df.columns:
            mean_val = results_df[col].mean()
            std_val = results_df[col].std()
            print(f"    {class_name}: {mean_val:.4f} ¬± {std_val:.4f}")
    
    # Save results
    results_df.to_csv(f"{EVAL_OUTPUT_DIR}/pq_per_fold.csv", index=False)
    
    # Save aggregated summary
    summary = {
        'bPQ_mean': float(bpq_mean),
        'bPQ_std': float(bpq_std),
        'mPQ_mean': float(mpq_mean),
        'mPQ_std': float(mpq_std),
    }
    for class_name in CONFIG['class_names']:
        col = f'pq_{class_name}'
        if col in results_df.columns:
            summary[f'{col}_mean'] = float(results_df[col].mean())
            summary[f'{col}_std'] = float(results_df[col].std())
    
    with open(f"{EVAL_OUTPUT_DIR}/pq_summary.json", 'w') as f:
        json.dump(summary, f, indent=2)
    
    print(f"\n‚úì Results saved to: {EVAL_OUTPUT_DIR}")
    print(f"  - pq_per_fold.csv: Per-fold PQ metrics")
    print(f"  - pq_summary.json: Aggregated summary")
else:
    print("‚ö†Ô∏è  No evaluation results to aggregate.")

## 7. Visualization

In [None]:
# Cell 9: Visualization of PQ Results

if len(all_fold_results) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Plot 1: bPQ and mPQ per fold
    ax = axes[0]
    folds = [r['fold'] for r in all_fold_results]
    bpq_values = [r['bPQ'] for r in all_fold_results]
    mpq_values = [r['mPQ'] for r in all_fold_results]
    
    x = np.arange(len(folds))
    width = 0.35
    
    bars1 = ax.bar(x - width/2, bpq_values, width, label='bPQ', color='#1f77b4', alpha=0.8)
    bars2 = ax.bar(x + width/2, mpq_values, width, label='mPQ', color='#ff7f0e', alpha=0.8)
    
    # Add value labels
    for bar, val in zip(bars1, bpq_values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
                f'{val:.3f}', ha='center', va='bottom', fontsize=9)
    for bar, val in zip(bars2, mpq_values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
                f'{val:.3f}', ha='center', va='bottom', fontsize=9)
    
    ax.set_xlabel('Fold')
    ax.set_ylabel('Panoptic Quality')
    ax.set_title('bPQ and mPQ per Fold')
    ax.set_xticks(x)
    ax.set_xticklabels([f'Fold {f}' for f in folds])
    ax.legend()
    ax.set_ylim(0, 1.0)
    ax.grid(True, alpha=0.3, axis='y')
    
    # Plot 2: Per-class PQ (averaged across folds)
    ax = axes[1]
    class_names = CONFIG['class_names']
    pq_means = []
    pq_stds = []
    
    for class_name in class_names:
        values = [r[f'pq_{class_name}'] for r in all_fold_results]
        pq_means.append(np.mean(values))
        pq_stds.append(np.std(values))
    
    x = np.arange(len(class_names))
    bars = ax.bar(x, pq_means, yerr=pq_stds, capsize=5, color='#2ca02c', alpha=0.8)
    
    for bar, mean in zip(bars, pq_means):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02, 
                f'{mean:.3f}', ha='center', va='bottom', fontsize=9)
    
    ax.set_xlabel('Class')
    ax.set_ylabel('Panoptic Quality')
    ax.set_title('Per-Class PQ (Mean ¬± Std across folds)')
    ax.set_xticks(x)
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.set_ylim(0, 1.0)
    ax.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.savefig(f"{EVAL_OUTPUT_DIR}/pq_visualization.png", dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"\n‚úì Visualization saved to: {EVAL_OUTPUT_DIR}/pq_visualization.png")
else:
    print("‚ö†Ô∏è  No results to visualize.")

In [None]:
# Cell 10: Final Summary Table (Paper-Ready Format)

if len(all_fold_results) > 0:
    print("\n" + "=" * 70)
    print("FINAL PANOPTIC QUALITY RESULTS (Paper-Ready Format)")
    print("=" * 70)
    
    # Create summary table
    summary_rows = []
    
    # bPQ
    bpq_mean = results_df['bPQ'].mean()
    bpq_std = results_df['bPQ'].std()
    summary_rows.append({
        'Metric': 'bPQ (Binary)',
        'Mean': f'{bpq_mean:.4f}',
        'Std': f'{bpq_std:.4f}',
        'Result': f'{bpq_mean:.4f} ¬± {bpq_std:.4f}'
    })
    
    # mPQ
    mpq_mean = results_df['mPQ'].mean()
    mpq_std = results_df['mPQ'].std()
    summary_rows.append({
        'Metric': 'mPQ (Multi-class)',
        'Mean': f'{mpq_mean:.4f}',
        'Std': f'{mpq_std:.4f}',
        'Result': f'{mpq_mean:.4f} ¬± {mpq_std:.4f}'
    })
    
    # Per-class PQ
    for class_name in CONFIG['class_names']:
        col = f'pq_{class_name}'
        if col in results_df.columns:
            mean_val = results_df[col].mean()
            std_val = results_df[col].std()
            summary_rows.append({
                'Metric': f'PQ ({class_name})',
                'Mean': f'{mean_val:.4f}',
                'Std': f'{std_val:.4f}',
                'Result': f'{mean_val:.4f} ¬± {std_val:.4f}'
            })
    
    summary_table = pd.DataFrame(summary_rows)
    print("\n" + summary_table.to_string(index=False))
    
    # Save final summary
    summary_table.to_csv(f"{EVAL_OUTPUT_DIR}/final_pq_results.csv", index=False)
    
    print("\n" + "=" * 70)
    print("COMPARISON WITH PANNUKE PAPER BASELINES")
    print("=" * 70)
    print("\nReference values from PanNuke paper (different methods):")
    print("  - Micro-Net: bPQ=0.3866, mPQ=0.2291")
    print("  - DIST:      bPQ=0.4108, mPQ=0.2464")
    print("  - Mask-RCNN: bPQ=0.4011, mPQ=0.2484")
    print("  - HoVer-Net: bPQ=0.4724, mPQ=0.2958")
    print(f"\n  Your CIPS-Net: bPQ={bpq_mean:.4f}, mPQ={mpq_mean:.4f}")
    
    print("\n" + "=" * 70)
    print(f"All evaluation results saved to: {EVAL_OUTPUT_DIR}")
    print("=" * 70)
else:
    print("‚ö†Ô∏è  No results available for summary.")