# Unified Text-Guided Segmentation Training

This notebook provides a unified interface for training all 14 text-guided segmentation models on histopathology datasets.

## Experiment Setup
- **Training**: PanNuke 3-fold cross-validation
- **Zero-shot Evaluation**: CoNSeP, MoNuSAC
- **Fine-tuning**: Optional fine-tuning on target datasets
- **Comparison**: Against CIPS-Net baseline

## Available Models (14 total)
| # | Model | Venue | Status |
|---|-------|-------|--------|
| 1 | CLIPSeg | CVPR 2022 | ✅ Working |
| 2 | LSeg | ICLR 2022 | ✅ Working |
| 3 | GroupViT | CVPR 2022 | ✅ Working |
| 4 | SAN | CVPR 2023 | ✅ Working |
| 5 | FC-CLIP | NeurIPS 2023 | ✅ Working |
| 6 | OVSeg | CVPR 2023 | ✅ Working |
| 7 | CAT-Seg | CVPR 2024 | ✅ Working |
| 8 | SED | CVPR 2024 | ✅ Working |
| 9 | MAFT+ | ECCV 2024 Oral | ⚠️ Needs fix |
| 10 | X-Decoder | CVPR 2023 | ⚠️ Needs fix |
| 11 | OpenSeeD | ICCV 2023 | ✅ Working |
| 12 | ODISE | CVPR 2023 | ✅ Working |
| 13 | TagAlign | arXiv 2023 | ⚠️ Needs fix |
| 14 | Semantic-SAM | ECCV 2024 | ✅ Working |

In [1]:
# ============================================================================
# CONFIGURATION - CHANGE THIS TO SELECT MODEL
# ============================================================================
from datetime import datetime

# Model selection - change this to train different models
MODEL_NAME = "groupvit"  # Options: clipseg, lseg, groupvit, san, fc_clip, ovseg, 
                        #          cat_seg, sed, openseed, odise, semantic_sam

# Experiment timestamp for unique naming
EXPERIMENT_TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
EXPERIMENT_NAME = f"{MODEL_NAME}_{EXPERIMENT_TIMESTAMP}"

# Training configuration
CONFIG = {
    # Dataset
    "num_classes": 5,
    "image_size": 224,  # SAME AS CIPS-NET (224x224) for fair comparison
    
    # Training
    "batch_size": 8,
    "num_epochs": 20,
    "learning_rate": 1e-4,
    "weight_decay": 0.01,
    
    # Cross-validation
    "num_folds": 3,
    
    # Model
    "clip_model": "ViT-B/16",
    "freeze_clip": True,
    
    # Paths
    "data_root": "/mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work/Dataset",
    "pannuke_preprocess_root": "/mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work/PanNuke_Preprocess",
    "results_dir": f"/mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work/results/{EXPERIMENT_NAME}",
    
    # External datasets for zero-shot evaluation
    "external_datasets_root": "/mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work/Histopathology_Datasets_Official",
    
    # Device
    "device": "cuda",
    
    # Early stopping
    "early_stopping_patience": 5,
}

# ============================================================================
# TEXT VARIANTS - MATCHING CIPS-NET EVALUATION PROTOCOL EXACTLY
# ============================================================================
# 3 Text Variants (same as CIPS-Net):
#   1. Per-Image Text: Read from annotations.csv 'instruction' column (varies per image)
#   2. Common Text: Single generic instruction for ALL images
#   3. No Text: Empty string (pure visual features)
# ============================================================================

# Class names (SAME AS CIPS-NET)
CLASS_NAMES = ['Neoplastic', 'Inflammatory', 'Connective_Soft_tissue', 'Dead', 'Epithelial']

# Per-Image Text: Will be read from annotations.csv during evaluation
# (Each image has its own unique instruction based on classes present)
# Example: "Identify all neoplastic, inflammatory and connective soft tissue regions."

# Common Text: Single sentence mentioning ALL classes (EXACTLY SAME AS CIPS-NET)
COMMON_TEXT = "Segment all Neoplastic, Inflammatory, Connective, Dead, and Epithelial cells in the image."

# No Text: Empty string (tests pure visual features) (SAME AS CIPS-NET)
NO_TEXT = ""

# For training, we use class-specific prompts (one per class for VLM models)
# This is the standard approach for CLIP-based segmentation models
TRAINING_TEXT_PROMPTS = [
    "neoplastic cells",           # Class 0
    "inflammatory cells",          # Class 1
    "connective tissue cells",     # Class 2
    "dead cells",                  # Class 3
    "epithelial cells",            # Class 4
]

# Backwards-compatible dictionary of text variants used by older cells
# - 'per_image_text' is represented by `None` because it requires reading
#   per-image instructions from `annotations.csv` at evaluation time.
# - 'common_text' and 'no_text' are represented as a list repeated per class
#   to match VLMs that expect one prompt per class.
TEXT_PROMPTS_VARIANTS = {
    'per_image_text': None,
    'common_text': [COMMON_TEXT] * CONFIG['num_classes'],
    'no_text': [NO_TEXT] * CONFIG['num_classes'],
}

# Default prompts for training
TEXT_PROMPTS = TRAINING_TEXT_PROMPTS

print(f"=" * 70)
print(f"EXPERIMENT CONFIGURATION")
print(f"=" * 70)
print(f"Model: {MODEL_NAME}")
print(f"Experiment Name: {EXPERIMENT_NAME}")
print(f"Results: {CONFIG['results_dir']}")
print(f"\nTraining Config:")
print(f"  Image Size: {CONFIG['image_size']}x{CONFIG['image_size']} (SAME AS CIPS-NET)")
print(f"  Epochs: {CONFIG['num_epochs']} (Early Stopping: {CONFIG['early_stopping_patience']})")
print(f"  Batch Size: {CONFIG['batch_size']}")
print(f"  3-Fold CV on PanNuke")
print(f"\nText Variants (CIPS-Net Protocol):")
print(f"  1. Per-Image Text: From annotations.csv 'instruction' column")
print(f"  2. Common Text: '{COMMON_TEXT[:50]}...'")
print(f"  3. No Text: Empty string")
print(f"=" * 70)

EXPERIMENT CONFIGURATION
Model: groupvit
Experiment Name: groupvit_20260128_150847
Results: /mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work/results/groupvit_20260128_150847

Training Config:
  Image Size: 224x224 (SAME AS CIPS-NET)
  Epochs: 20 (Early Stopping: 5)
  Batch Size: 8
  3-Fold CV on PanNuke

Text Variants (CIPS-Net Protocol):
  1. Per-Image Text: From annotations.csv 'instruction' column
  2. Common Text: 'Segment all Neoplastic, Inflammatory, Connective, ...'
  3. No Text: Empty string


In [2]:
# Setup and imports
import sys
import os

WORKSPACE = "/mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work"
sys.path.insert(0, WORKSPACE)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import pandas as pd
from pathlib import Path
import time
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from collections import defaultdict
import json

# Clear cached imports
modules_to_remove = [m for m in sys.modules.keys() if 'TextGuidedSegmentation' in m]
for m in modules_to_remove:
    del sys.modules[m]

# Import our package
from TextGuidedSegmentation import (
    get_model,
    list_models,
    print_model_summary,
    DEFAULT_TEXT_PROMPTS,
    PanNukeDataset,
    DiceLoss,
    FocalLoss,
    CombinedSegmentationLoss,
    compute_iou,
    compute_dice,
    compute_f1,
    ConfusionMatrix,
    MetricTracker,
)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    
print(f"\nAvailable models:")
print_model_summary()

  import pynvml  # type: ignore[import]


PyTorch version: 2.9.1+cu128
CUDA available: True
CUDA device: NVIDIA RTX A5000

Available models:

Text-Guided Segmentation Models for Histopathology

#   Model           Venue                Description                             
--------------------------------------------------------------------------------
1   CLIPSeg         CVPR 2022            Uses CLIP features with FiLM condition..
2   LSeg            ICLR 2022            Dense Prediction Transformer with CLIP..
3   GroupViT        CVPR 2022            Hierarchical grouping mechanism with c..
4   SAN             CVPR 2023            Side adapter network preserving CLIP c..
5   FC-CLIP         NeurIPS 2023         Fully convolutional CLIP for dense pre..
6   OVSeg           CVPR 2023            Mask-adapted CLIP with region-level cl..
7   CAT-Seg         CVPR 2024            Cost aggregation with spatial semantic..
8   SED             CVPR 2024            Simple encoder-decoder with category-g..
9   MAFT+           ECCV 2024

In [3]:
# Create results directory structure
import os
results_dir = Path(CONFIG['results_dir'])
results_dir.mkdir(parents=True, exist_ok=True)

# Sub-directories for different experiment phases
(results_dir / "models").mkdir(exist_ok=True)           # Trained model checkpoints
(results_dir / "pannuke_3fold").mkdir(exist_ok=True)    # 3-fold CV results
(results_dir / "text_variants").mkdir(exist_ok=True)    # Text prompt evaluation
(results_dir / "zero_shot").mkdir(exist_ok=True)        # Zero-shot evaluation
(results_dir / "fine_tuned").mkdir(exist_ok=True)       # Fine-tuned model results

print("Results directory:", results_dir)
print(f"\nClass names: {CLASS_NAMES}")
print(f"\nTraining prompts: {TRAINING_TEXT_PROMPTS}")
print(f"\nText Variants for Evaluation (CIPS-Net Protocol):")
print(f"  1. Per-Image Text: From annotations.csv 'instruction' column")
print(f"  2. Common Text: '{COMMON_TEXT}'")
print(f"  3. No Text: '' (empty string)")

Results directory: /mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work/results/groupvit_20260128_150847

Class names: ['Neoplastic', 'Inflammatory', 'Connective_Soft_tissue', 'Dead', 'Epithelial']

Training prompts: ['neoplastic cells', 'inflammatory cells', 'connective tissue cells', 'dead cells', 'epithelial cells']

Text Variants for Evaluation (CIPS-Net Protocol):
  1. Per-Image Text: From annotations.csv 'instruction' column
  2. Common Text: 'Segment all Neoplastic, Inflammatory, Connective, Dead, and Epithelial cells in the image.'
  3. No Text: '' (empty string)


In [4]:
# Create datasets
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import cv2

class SimplePanNukeDataset(Dataset):
    """
    Simple PanNuke dataset loader.
    
    IMPORTANT: The masks are stored as separate per-class binary masks:
    - {base_name}_channel_0_Neoplastic.png
    - {base_name}_channel_1_Inflammatory.png
    - {base_name}_channel_2_Connective_Soft_tissue.png
    - {base_name}_channel_3_Dead.png
    - {base_name}_channel_4_Epithelial.png
    - {base_name}_channel_5_Background.png
    
    We need to combine them into a single mask with class indices.
    """
    
    # Class channel mapping
    CHANNEL_NAMES = [
        "Neoplastic",           # Channel 0 -> Class 0
        "Inflammatory",          # Channel 1 -> Class 1
        "Connective_Soft_tissue",# Channel 2 -> Class 2
        "Dead",                  # Channel 3 -> Class 3
        "Epithelial",           # Channel 4 -> Class 4
        "Background",           # Channel 5 -> Background (not used)
    ]
    
    def __init__(
        self,
        image_dir: str,
        mask_dir: str,
        csv_file: str = None,
        image_size: int = 224,
        split: str = "train",
        fold: int = 0,
    ):
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.image_size = image_size
        
        # Get image files
        self.image_files = sorted(list(self.image_dir.glob("*.png")) + 
                                  list(self.image_dir.glob("*.jpg")))
        
        # Simple split based on fold
        n = len(self.image_files)
        fold_size = n // 3
        
        if split == "train":
            # Use 2 folds for training
            indices = list(range(n))
            test_start = fold * fold_size
            test_end = (fold + 1) * fold_size if fold < 2 else n
            indices = [i for i in indices if not (test_start <= i < test_end)]
            self.image_files = [self.image_files[i] for i in indices]
        elif split == "val":
            # Use 1 fold for validation
            test_start = fold * fold_size
            test_end = (fold + 1) * fold_size if fold < 2 else n
            self.image_files = self.image_files[test_start:test_end]
        
        # CLIP normalization
        self.normalize = transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711]
        )
        
        print(f"Dataset [{split}]: {len(self.image_files)} images")
    
    def __len__(self):
        return len(self.image_files)
    
    def _get_base_name(self, img_path: Path) -> str:
        """Extract base name from image path (remove _img suffix)."""
        stem = img_path.stem
        if stem.endswith("_img"):
            return stem[:-4]  # Remove "_img"
        return stem
    
    def _load_combined_mask(self, base_name: str) -> np.ndarray:
        """
        Load and combine per-channel masks into a single class mask.
        
        Returns mask with class indices 0-4 (ignoring background channel 5).
        Priority: If multiple classes present at a pixel, use lowest class index.
        """
        H, W = self.image_size, self.image_size
        combined_mask = np.full((H, W), fill_value=255, dtype=np.uint8)  # 255 = ignore
        
        # Load each class channel (0-4, skip background)
        for class_idx in range(5):  # Only 5 classes (0-4)
            channel_name = self.CHANNEL_NAMES[class_idx]
            mask_filename = f"{base_name}_channel_{class_idx}_{channel_name}.png"
            mask_path = self.mask_dir / mask_filename
            
            if mask_path.exists():
                channel_mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
                if channel_mask is not None:
                    channel_mask = cv2.resize(channel_mask, (W, H), interpolation=cv2.INTER_NEAREST)
                    # Set pixels where this class is present (threshold > 127)
                    class_pixels = channel_mask > 127
                    # Only set if not already assigned (lower class index has priority)
                    combined_mask[class_pixels & (combined_mask == 255)] = class_idx
        
        # Any remaining 255 (ignore) pixels that have no class -> set to background (class 0)
        # Actually, let's check if there's a background channel
        bg_mask_path = self.mask_dir / f"{base_name}_channel_5_Background.png"
        if bg_mask_path.exists():
            bg_mask = cv2.imread(str(bg_mask_path), cv2.IMREAD_GRAYSCALE)
            if bg_mask is not None:
                bg_mask = cv2.resize(bg_mask, (W, H), interpolation=cv2.INTER_NEAREST)
                # Background pixels remain as ignore_index or set to a specific value
                # For now, keep them as 255 (ignore in loss)
                pass
        
        # If still some unassigned pixels, they will be ignored (255)
        # Convert 255 to -100 for ignore_index compatibility
        combined_mask = combined_mask.astype(np.int64)
        combined_mask[combined_mask == 255] = -100  # ignore_index
        
        return combined_mask
    
    def __getitem__(self, idx):
        # Load image
        img_path = self.image_files[idx]
        image = Image.open(img_path).convert("RGB")
        image = image.resize((self.image_size, self.image_size), Image.BILINEAR)
        image = transforms.ToTensor()(image)
        image = self.normalize(image)
        
        # Load combined mask
        base_name = self._get_base_name(img_path)
        mask = self._load_combined_mask(base_name)
        mask = torch.from_numpy(mask)
        
        return {
            "image": image,
            "mask": mask,
            "image_path": str(img_path),
        }

# Create initial datasets for testing (fold 0)
# The actual 3-fold CV will create datasets per fold in the training cell
train_dataset = SimplePanNukeDataset(
    image_dir=f"{CONFIG['data_root']}/multi_images",
    mask_dir=f"{CONFIG['data_root']}/multi_masks",
    image_size=CONFIG['image_size'],
    split="train",
    fold=0,  # Default to fold 0 for initial testing
)

val_dataset = SimplePanNukeDataset(
    image_dir=f"{CONFIG['data_root']}/multi_images",
    mask_dir=f"{CONFIG['data_root']}/multi_masks",
    image_size=CONFIG['image_size'],
    split="val",
    fold=0,  # Default to fold 0 for initial testing
)

# Quick verification
print("\nVerifying mask loading...")
sample = train_dataset[0]
sample_mask = sample['mask']
print(f"Sample mask shape: {sample_mask.shape}")
print(f"Sample mask unique values: {torch.unique(sample_mask)}")
for i in range(5):
    count = (sample_mask == i).sum().item()
    if count > 0:
        print(f"  Class {i} ({CLASS_NAMES[i]}): {count} pixels")

Dataset [train]: 5268 images
Dataset [val]: 2633 images

Verifying mask loading...
Sample mask shape: torch.Size([224, 224])
Sample mask unique values: tensor([-100,    1])
  Class 1 (Inflammatory): 768 pixels


In [5]:
# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=0,  # Set to 0 to avoid multiprocessing warnings
    pin_memory=True,
    drop_last=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=0,  # Set to 0 to avoid multiprocessing warnings
    pin_memory=True,
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

Train batches: 658
Val batches: 330


In [6]:
# Create model
print(f"\nCreating model: {MODEL_NAME}")

model = get_model(
    MODEL_NAME,
    num_classes=CONFIG['num_classes'],
    image_size=CONFIG['image_size'],
    clip_model=CONFIG['clip_model'],
    freeze_clip=CONFIG['freeze_clip'],
    device=CONFIG['device'],
)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Frozen parameters: {total_params - trainable_params:,}")


Creating model: groupvit
Total parameters: 175,941,762
Trainable parameters: 26,321,025
Frozen parameters: 149,620,737


In [7]:
# Setup training

# Loss function
criterion = CombinedSegmentationLoss(
    ce_weight=1.0,
    dice_weight=1.0,
    focal_weight=0.0,
    ignore_index=-100,
)

# Optimizer (only trainable parameters)
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=CONFIG['num_epochs'],
    eta_min=1e-6,
)

# Mixed precision scaler
scaler = torch.amp.GradScaler('cuda')

# Early stopping config
EARLY_STOPPING_PATIENCE = 8

print("Training setup complete!")
print(f"Early stopping patience: {EARLY_STOPPING_PATIENCE} epochs")

Training setup complete!
Early stopping patience: 8 epochs


In [8]:
def train_one_epoch(model, loader, optimizer, criterion, scaler, text_prompts, device):
    """Train for one epoch."""
    model.train()
    total_loss = 0.0
    num_batches = 0
    
    pbar = tqdm(loader, desc="Training")
    for batch in pbar:
        images = batch['image'].to(device)
        masks = batch['mask'].to(device)
        
        optimizer.zero_grad()
        
        with torch.amp.autocast('cuda'):
            outputs = model(images, text_prompts)
            logits = outputs['logits']
            loss = criterion(logits, masks)
        
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        
        total_loss += loss.item()
        num_batches += 1
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return total_loss / num_batches


@torch.no_grad()
def validate(model, loader, criterion, text_prompts, device, num_classes=5):
    """Validate the model."""
    model.eval()
    total_loss = 0.0
    num_batches = 0
    confusion = torch.zeros(num_classes, num_classes, device=device)
    
    pbar = tqdm(loader, desc="Validation")
    for batch in pbar:
        images = batch['image'].to(device)
        masks = batch['mask'].to(device)
        
        outputs = model(images, text_prompts)
        logits = outputs['logits']
        loss = criterion(logits, masks)
        
        total_loss += loss.item()
        num_batches += 1
        
        preds = logits.argmax(dim=1).flatten()
        targets = masks.flatten()
        valid_mask = (targets >= 0) & (targets < num_classes)
        preds = preds[valid_mask]
        targets = targets[valid_mask]
        
        for p, t in zip(preds, targets):
            confusion[t, p] += 1
    
    avg_loss = total_loss / num_batches
    
    ious, dices = [], []
    for c in range(num_classes):
        tp = confusion[c, c]
        fp = confusion[:, c].sum() - tp
        fn = confusion[c, :].sum() - tp
        ious.append((tp / (tp + fp + fn + 1e-8)).item())
        dices.append((2 * tp / (2 * tp + fp + fn + 1e-8)).item())
    
    return {
        'loss': avg_loss,
        'mean_iou': np.mean(ious),
        'mean_dice': np.mean(dices),
        'per_class_iou': ious,
        'per_class_dice': dices,
    }

print("Training functions defined!")

Training functions defined!


In [9]:
# ============================================================================
# PHASE 1: 3-FOLD CROSS-VALIDATION TRAINING ON PANNUKE
# ============================================================================
print("=" * 70)
print("PHASE 1: 3-FOLD CROSS-VALIDATION TRAINING ON PANNUKE")
print("=" * 70)

all_fold_results = {}

for fold in range(CONFIG['num_folds']):
    print(f"\n{'='*70}")
    print(f"FOLD {fold + 1}/{CONFIG['num_folds']}")
    print(f"{'='*70}")
    
    # Create datasets for this fold
    train_dataset = SimplePanNukeDataset(
        image_dir=f"{CONFIG['data_root']}/multi_images",
        mask_dir=f"{CONFIG['data_root']}/multi_masks",
        image_size=CONFIG['image_size'],
        split="train",
        fold=fold,
    )
    
    val_dataset = SimplePanNukeDataset(
        image_dir=f"{CONFIG['data_root']}/multi_images",
        mask_dir=f"{CONFIG['data_root']}/multi_masks",
        image_size=CONFIG['image_size'],
        split="val",
        fold=fold,
    )
    
    # Create dataloaders
    train_loader = DataLoader(
        train_dataset, batch_size=CONFIG['batch_size'], shuffle=True,
        num_workers=0, pin_memory=True, drop_last=True,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=CONFIG['batch_size'], shuffle=False,
        num_workers=0, pin_memory=True,
    )
    
    # Create fresh model for each fold
    model = get_model(
        MODEL_NAME,
        num_classes=CONFIG['num_classes'],
        image_size=CONFIG['image_size'],
        clip_model=CONFIG['clip_model'],
        freeze_clip=CONFIG['freeze_clip'],
        device=CONFIG['device'],
    )
    
    # Setup optimizer and scheduler
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=CONFIG['learning_rate'],
        weight_decay=CONFIG['weight_decay'],
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=CONFIG['num_epochs'], eta_min=1e-6,
    )
    scaler = torch.amp.GradScaler('cuda')
    
    # Training loop
    best_metric = 0.0
    epochs_without_improvement = 0
    history = defaultdict(list)
    
    for epoch in range(CONFIG['num_epochs']):
        print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
        
        train_loss = train_one_epoch(
            model, train_loader, optimizer, criterion, scaler,
            TEXT_PROMPTS, CONFIG['device']
        )
        
        val_metrics = validate(
            model, val_loader, criterion, TEXT_PROMPTS,
            CONFIG['device'], CONFIG['num_classes']
        )
        
        scheduler.step()
        
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_metrics['loss'])
        history['val_iou'].append(val_metrics['mean_iou'])
        history['val_dice'].append(val_metrics['mean_dice'])
        
        print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_metrics['loss']:.4f}")
        print(f"  Val mIoU: {val_metrics['mean_iou']:.4f} | Val mDice: {val_metrics['mean_dice']:.4f}")
        
        # Save best model
        if val_metrics['mean_iou'] > best_metric:
            best_metric = val_metrics['mean_iou']
            epochs_without_improvement = 0
            
            checkpoint_path = results_dir / "models" / f"best_{MODEL_NAME}_fold{fold}.pth"
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_metric': best_metric,
                'val_metrics': val_metrics,
                'config': CONFIG,
            }, checkpoint_path)
            print(f"  ✓ Saved best model (mIoU: {best_metric:.4f})")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= CONFIG['early_stopping_patience']:
                print(f"\n⚠️ Early stopping triggered after {epoch+1} epochs")
                break
    
    # Store fold results
    all_fold_results[f'fold_{fold}'] = {
        'best_iou': best_metric,
        'final_dice': history['val_dice'][-1] if history['val_dice'] else 0,
        'history': dict(history),
        'epochs_trained': len(history['train_loss']),
    }
    
    # Clear GPU memory
    del model, optimizer, scheduler
    torch.cuda.empty_cache()
    
    print(f"\nFold {fold} complete! Best mIoU: {best_metric:.4f}")

# Aggregate and save 3-fold CV results
mean_iou = np.mean([r['best_iou'] for r in all_fold_results.values()])
std_iou = np.std([r['best_iou'] for r in all_fold_results.values()])
mean_dice = np.mean([r['final_dice'] for r in all_fold_results.values()])
std_dice = np.std([r['final_dice'] for r in all_fold_results.values()])

cv_summary = {
    'model_name': MODEL_NAME,
    'experiment_name': EXPERIMENT_NAME,
    'mean_iou': mean_iou,
    'std_iou': std_iou,
    'mean_dice': mean_dice,
    'std_dice': std_dice,
    'fold_results': all_fold_results,
    'config': CONFIG,
}

cv_results_path = results_dir / "pannuke_3fold" / "cv_results.json"
with open(cv_results_path, 'w') as f:
    json.dump(cv_summary, f, indent=2, default=str)

print("\n" + "=" * 70)
print("3-FOLD CV TRAINING COMPLETE!")
print("=" * 70)
print(f"Mean mIoU: {mean_iou:.4f} ± {std_iou:.4f}")
print(f"Mean mDice: {mean_dice:.4f} ± {std_dice:.4f}")
print(f"Results saved to: {cv_results_path}")

PHASE 1: 3-FOLD CROSS-VALIDATION TRAINING ON PANNUKE

FOLD 1/3
Dataset [train]: 5268 images
Dataset [val]: 2633 images

Epoch 1/20


Training:   0%|          | 0/658 [00:00<?, ?it/s]

Validation:   0%|          | 0/330 [00:00<?, ?it/s]

  Train Loss: 2.4642 | Val Loss: 2.4907
  Val mIoU: 0.1228 | Val mDice: 0.1749
  ✓ Saved best model (mIoU: 0.1228)

Epoch 2/20


Training:   0%|          | 0/658 [00:00<?, ?it/s]

Validation:   0%|          | 0/330 [00:00<?, ?it/s]

  Train Loss: nan | Val Loss: nan
  Val mIoU: 0.1052 | Val mDice: 0.1379

Epoch 3/20


Training:   0%|          | 0/658 [00:00<?, ?it/s]

Validation:   0%|          | 0/330 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# ============================================================================
# PHASE 1.5: POST-HOC EVALUATION - mPQ ON PANNUKE TEST FOLDS
# ============================================================================
# Loads the best model from each fold and evaluates mPQ (and mDice) on the
# corresponding test fold using the same PQ implementation as Phase 2.
# No retraining performed — pure post-hoc evaluation of saved checkpoints.
# ============================================================================

from scipy import ndimage
from scipy.optimize import linear_sum_assignment

print("\n" + "=" * 70)
print("PHASE 1.5: POST-HOC mPQ EVALUATION ON PANNUKE 3-FOLD")
print("=" * 70)

pq_summary = {}

# Define metrics functions (same as Phase 2)
def compute_dice_per_class(pred_mask, gt_mask, num_classes):
    """Compute per-class Dice score."""
    dice_scores = {}
    for c in range(num_classes):
        pred_c = (pred_mask == c).float()
        gt_c = (gt_mask == c).float()
        intersection = (pred_c * gt_c).sum()
        union = pred_c.sum() + gt_c.sum()
        if union > 0:
            dice = (2 * intersection / (union + 1e-8)).item()
        else:
            dice = 1.0 if intersection == 0 else 0.0
        dice_scores[c] = dice
    return dice_scores

def compute_pq_per_class(pred_semantic, gt_semantic, gt_instance, gt_class_map, num_classes):
    """
    Compute Panoptic Quality per class (SAME AS CIPS-NET implementation).
    Returns dict per class with PQ, DQ, SQ and counts.
    """
    results = {}
    for c in range(num_classes):
        gt_mask_c = (gt_semantic == c)
        pred_mask_c = (pred_semantic == c)
        gt_instances_c = np.unique(gt_instance[gt_mask_c.cpu().numpy()])
        gt_instances_c = gt_instances_c[gt_instances_c > 0]
        pred_labeled, num_pred = ndimage.label(pred_mask_c.cpu().numpy())
        pred_instances_c = list(range(1, num_pred + 1))
        if len(gt_instances_c) == 0 and len(pred_instances_c) == 0:
            results[c] = {'PQ': 1.0, 'DQ': 1.0, 'SQ': 1.0, 'TP': 0, 'FP': 0, 'FN': 0}
            continue
        if len(gt_instances_c) == 0:
            results[c] = {'PQ': 0.0, 'DQ': 0.0, 'SQ': 0.0, 'TP': 0, 'FP': len(pred_instances_c), 'FN': 0}
            continue
        if len(pred_instances_c) == 0:
            results[c] = {'PQ': 0.0, 'DQ': 0.0, 'SQ': 0.0, 'TP': 0, 'FP': 0, 'FN': len(gt_instances_c)}
            continue
        iou_matrix = np.zeros((len(gt_instances_c), len(pred_instances_c)))
        for i, gt_id in enumerate(gt_instances_c):
            gt_inst_mask = (gt_instance.numpy() == gt_id)
            for j, pred_id in enumerate(pred_instances_c):
                pred_inst_mask = (pred_labeled == pred_id)
                intersection = np.logical_and(gt_inst_mask, pred_inst_mask).sum()
                union = np.logical_or(gt_inst_mask, pred_inst_mask).sum()
                if union > 0:
                    iou_matrix[i, j] = intersection / union
        row_ind, col_ind = linear_sum_assignment(-iou_matrix)
        tp = 0
        matched_iou_sum = 0
        matched_gt = set()
        matched_pred = set()
        for r, c_idx in zip(row_ind, col_ind):
            if iou_matrix[r, c_idx] > 0.5:
                tp += 1
                matched_iou_sum += iou_matrix[r, c_idx]
                matched_gt.add(r)
                matched_pred.add(c_idx)
        fp = len(pred_instances_c) - len(matched_pred)
        fn = len(gt_instances_c) - len(matched_gt)
        if tp > 0:
            sq = matched_iou_sum / tp
            dq = tp / (tp + 0.5 * fp + 0.5 * fn)
            pq = sq * dq
        else:
            sq = 0.0; dq = 0.0; pq = 0.0
        results[c] = {'PQ': pq, 'DQ': dq, 'SQ': sq, 'TP': tp, 'FP': fp, 'FN': fn}
    return results

# Load annotations.csv from PanNuke_Preprocess
ann_path = Path(CONFIG['pannuke_preprocess_root']) / 'annotations.csv'
if not ann_path.exists():
    print(f"⚠️ annotations.csv not found at {ann_path}. Cannot run PQ evaluation.")
else:
    ann_df = pd.read_csv(ann_path)
    for fold in range(CONFIG['num_folds']):
        print(f"\nEvaluating fold {fold} with saved best model...")
        ckpt = results_dir / 'models' / f"best_{MODEL_NAME}_fold{fold}.pth"
        if not ckpt.exists():
            print(f"  ⚠️ Checkpoint not found: {ckpt} — skipping fold {fold}")
            continue
        # Load model weights
        fold_model = get_model(
            MODEL_NAME,
            num_classes=CONFIG['num_classes'],
            image_size=CONFIG['image_size'],
            clip_model=CONFIG['clip_model'],
            freeze_clip=True,
            device=CONFIG['device'],
        )
        cp = torch.load(ckpt, map_location=CONFIG['device'], weights_only=False)
        fold_model.load_state_dict(cp['model_state_dict'])
        fold_model.eval()
        
        fold_rows = ann_df[ann_df['fold'] == (fold+1) if 'fold' in ann_df.columns and ann_df['fold'].max()>=1 else ann_df['fold'] == fold]
        # Some preprocess use 1-indexed folds, handle both: try fold, then fold+1
        if len(fold_rows) == 0:
            fold_rows = ann_df[ann_df['fold'] == fold]
        if len(fold_rows) == 0:
            print(f"  ⚠️ No annotations found for fold {fold}, skipping.")
            continue
        
        all_dice = {c: [] for c in range(CONFIG['num_classes'])}
        all_pq = {c: [] for c in range(CONFIG['num_classes'])}
        normalize = transforms.Normalize(mean=[0.48145466,0.4578275,0.40821073], std=[0.26862954,0.26130258,0.27577711])
        
        for idx, row in tqdm(fold_rows.reset_index(drop=True).iterrows(), total=len(fold_rows), desc=f"Fold {fold} Test Eval"):
            image_id = row['image_id']
            img_path = Path(CONFIG['pannuke_preprocess_root']) / 'images' / f'fold{row["fold"]}' / f'{image_id}.png'
            mask_path = Path(CONFIG['pannuke_preprocess_root']) / 'masks' / f'fold{row["fold"]}' / f'{image_id}.npz'
            if not img_path.exists() or not mask_path.exists():
                continue
            image = Image.open(img_path).convert('RGB').resize((CONFIG['image_size'], CONFIG['image_size']), Image.BILINEAR)
            img_t = transforms.ToTensor()(image)
            img_t = normalize(img_t).unsqueeze(0).to(CONFIG['device'])
            mask_data = np.load(mask_path)
            masks = mask_data['masks']
            gt_semantic = np.zeros((masks.shape[0], masks.shape[1]), dtype=np.int64)
            for c in range(min(masks.shape[2], CONFIG['num_classes'])):
                gt_semantic[masks[:,:,c] > 0] = c
            gt_instance = np.zeros((masks.shape[0], masks.shape[1]), dtype=np.int32)
            inst_id = 1
            gt_class_map = {}
            for c in range(min(masks.shape[2], CONFIG['num_classes'])):
                binary_mask = masks[:,:,c] > 0
                if binary_mask.sum() > 0:
                    labeled, ninst = ndimage.label(binary_mask)
                    for i in range(1, ninst+1):
                        gt_instance[labeled == i] = inst_id
                        gt_class_map[inst_id] = c
                        inst_id += 1
            # resize masks
            gt_semantic = torch.from_numpy(np.array(Image.fromarray(gt_semantic.astype(np.uint8)).resize((CONFIG['image_size'],CONFIG['image_size']), Image.NEAREST))).long()
            gt_instance = torch.from_numpy(np.array(Image.fromarray(gt_instance.astype(np.int32)).resize((CONFIG['image_size'],CONFIG['image_size']), Image.NEAREST)))
            
            # forward with per-image instruction from annotations
            text_instruction = row['instruction'] if pd.notna(row['instruction']) else "Segment all tissue types."
            text_prompts = [text_instruction] * CONFIG['num_classes']
            with torch.no_grad():
                out = fold_model(img_t, text_prompts)
                logits = out['logits']
                pred_semantic = logits.argmax(dim=1).squeeze(0).cpu()
            # metrics
            dice_scores = compute_dice_per_class(pred_semantic, gt_semantic, CONFIG['num_classes'])
            pq_scores = compute_pq_per_class(pred_semantic, gt_semantic, gt_instance, gt_class_map, CONFIG['num_classes'])
            for c in range(CONFIG['num_classes']):
                all_dice[c].append(dice_scores[c])
                all_pq[c].append(pq_scores[c]['PQ'])
        # aggregate
        per_class_dice = [np.mean(all_dice[c]) if len(all_dice[c])>0 else 0.0 for c in range(CONFIG['num_classes'])]
        per_class_pq = [np.mean(all_pq[c]) if len(all_pq[c])>0 else 0.0 for c in range(CONFIG['num_classes'])]
        fold_result = {
            'fold': fold,
            'num_images': len(fold_rows),
            'per_class_dice': per_class_dice,
            'per_class_pq': per_class_pq,
            'mean_dice': float(np.mean(per_class_dice)),
            'mean_pq': float(np.mean(per_class_pq)),
        }
        pq_summary[f'fold_{fold}'] = fold_result
        print(f"  Fold {fold}: mDice={fold_result['mean_dice']:.4f}, mPQ={fold_result['mean_pq']:.4f}")

    # Save PQ summary
    pq_path = results_dir / 'pannuke_3fold' / 'pq_results.json'
    with open(pq_path, 'w') as f:
        json.dump(pq_summary, f, indent=2)
    print(f"\n✓ Saved PQ summary to: {pq_path}")


PHASE 1.5: POST-HOC mPQ EVALUATION ON PANNUKE 3-FOLD

Evaluating fold 0 with saved best model...


Fold 0 Test Eval:   0%|          | 0/2656 [00:00<?, ?it/s]

  Fold 0: mDice=0.6905, mPQ=0.4993

Evaluating fold 1 with saved best model...


Fold 1 Test Eval:   0%|          | 0/2523 [00:00<?, ?it/s]

  Fold 1: mDice=0.6806, mPQ=0.4897

Evaluating fold 2 with saved best model...


Fold 2 Test Eval:   0%|          | 0/2722 [00:00<?, ?it/s]

  Fold 2: mDice=0.6882, mPQ=0.4969

✓ Saved PQ summary to: /mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work/results/lseg_20260127_165728/pannuke_3fold/pq_results.json


In [22]:
# ============================================================================
# PHASE 2: TEXT VARIANTS EVALUATION (CIPS-NET PROTOCOL - WITH mPQ)
# ============================================================================
# Evaluates 3 text variants on PanNuke test set (same as CIPS-Net):
#   1. Per-Image Text: Read from annotations.csv 'instruction' column
#   2. Common Text: Fixed sentence for ALL images
#   3. No Text: Empty string (pure visual features)
#
# Metrics: mDice, mPQ (same as CIPS-Net)
# ============================================================================

from scipy import ndimage
from scipy.optimize import linear_sum_assignment

print("\n" + "=" * 70)
print("PHASE 2: TEXT VARIANTS EVALUATION (CIPS-NET PROTOCOL)")
print("=" * 70)

# Load best model from fold 0
model = get_model(
    MODEL_NAME,
    num_classes=CONFIG['num_classes'],
    image_size=CONFIG['image_size'],
    clip_model=CONFIG['clip_model'],
    freeze_clip=True,
    device=CONFIG['device'],
)

checkpoint_path = results_dir / "models" / f"best_{MODEL_NAME}_fold0.pth"
if checkpoint_path.exists():
    checkpoint = torch.load(checkpoint_path, map_location=CONFIG['device'], weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"✓ Loaded best model from fold 0 (epoch {checkpoint['epoch']})")
else:
    print("⚠️ No checkpoint found for fold 0, using current model")

# ============================================================================
# Load annotations.csv for per-image text instructions
# ============================================================================
annotations_path = Path(CONFIG['pannuke_preprocess_root']) / "annotations.csv"
if annotations_path.exists():
    annotations_df = pd.read_csv(annotations_path)
    print(f"✓ Loaded annotations.csv: {len(annotations_df)} entries")
else:
    print(f"⚠️ annotations.csv not found at {annotations_path}")
    annotations_df = None

# ============================================================================
# Metrics Functions (same as CIPS-Net)
# ============================================================================
def compute_dice_per_class(pred_mask, gt_mask, num_classes):
    """Compute per-class Dice score."""
    dice_scores = {}
    for c in range(num_classes):
        pred_c = (pred_mask == c).float()
        gt_c = (gt_mask == c).float()
        
        intersection = (pred_c * gt_c).sum()
        union = pred_c.sum() + gt_c.sum()
        
        if union > 0:
            dice = (2 * intersection / (union + 1e-8)).item()
        else:
            dice = 1.0 if intersection == 0 else 0.0
        
        dice_scores[c] = dice
    return dice_scores


def compute_pq_per_class(pred_semantic, gt_semantic, gt_instance, gt_class_map, num_classes):
    """
    Compute Panoptic Quality per class (SAME AS CIPS-NET).
    
    PQ = SQ × DQ
    - SQ (Segmentation Quality) = average IoU of matched pairs
    - DQ (Detection Quality) = TP / (TP + 0.5*FP + 0.5*FN)
    
    Predicted instances are created from connected components of semantic output.
    """
    results = {}
    
    for c in range(num_classes):
        # Get masks for this class
        gt_mask_c = (gt_semantic == c)
        pred_mask_c = (pred_semantic == c)
        
        # Get GT instances of this class
        gt_instances_c = np.unique(gt_instance[gt_mask_c.cpu().numpy()])
        gt_instances_c = gt_instances_c[gt_instances_c > 0]
        
        # Create predicted instances from connected components (same as CIPS-Net)
        pred_labeled, num_pred = ndimage.label(pred_mask_c.cpu().numpy())
        pred_instances_c = list(range(1, num_pred + 1))
        
        if len(gt_instances_c) == 0 and len(pred_instances_c) == 0:
            results[c] = {'PQ': 1.0, 'DQ': 1.0, 'SQ': 1.0, 'TP': 0, 'FP': 0, 'FN': 0}
            continue
        
        if len(gt_instances_c) == 0:
            results[c] = {'PQ': 0.0, 'DQ': 0.0, 'SQ': 0.0, 'TP': 0, 'FP': len(pred_instances_c), 'FN': 0}
            continue
        
        if len(pred_instances_c) == 0:
            results[c] = {'PQ': 0.0, 'DQ': 0.0, 'SQ': 0.0, 'TP': 0, 'FP': 0, 'FN': len(gt_instances_c)}
            continue
        
        # Compute IoU matrix between GT and predicted instances
        iou_matrix = np.zeros((len(gt_instances_c), len(pred_instances_c)))
        
        for i, gt_id in enumerate(gt_instances_c):
            gt_inst_mask = (gt_instance.numpy() == gt_id)
            for j, pred_id in enumerate(pred_instances_c):
                pred_inst_mask = (pred_labeled == pred_id)
                
                intersection = np.logical_and(gt_inst_mask, pred_inst_mask).sum()
                union = np.logical_or(gt_inst_mask, pred_inst_mask).sum()
                
                if union > 0:
                    iou_matrix[i, j] = intersection / union
        
        # Hungarian matching with IoU > 0.5 threshold
        row_ind, col_ind = linear_sum_assignment(-iou_matrix)
        
        tp = 0
        matched_iou_sum = 0
        matched_gt = set()
        matched_pred = set()
        
        for r, c_idx in zip(row_ind, col_ind):
            if iou_matrix[r, c_idx] > 0.5:
                tp += 1
                matched_iou_sum += iou_matrix[r, c_idx]
                matched_gt.add(r)
                matched_pred.add(c_idx)
        
        fp = len(pred_instances_c) - len(matched_pred)
        fn = len(gt_instances_c) - len(matched_gt)
        
        # Compute PQ, DQ, SQ
        if tp > 0:
            sq = matched_iou_sum / tp
            dq = tp / (tp + 0.5 * fp + 0.5 * fn)
            pq = sq * dq
        else:
            sq = 0.0
            dq = 0.0
            pq = 0.0
        
        results[c] = {'PQ': pq, 'DQ': dq, 'SQ': sq, 'TP': tp, 'FP': fp, 'FN': fn}
    
    return results

print("✓ Metrics functions defined (mDice + mPQ)")

# ============================================================================
# Evaluation function for per-image text (like CIPS-Net)
# ============================================================================
@torch.no_grad()
def evaluate_per_image_text(model, annotations_df, fold, data_root, img_size, device, num_classes=5):
    """
    Evaluate with per-image text from annotations.csv (CIPS-Net protocol).
    Returns: mDice, mPQ (same metrics as CIPS-Net)
    """
    model.eval()
    
    fold_df = annotations_df[annotations_df['fold'] == fold].reset_index(drop=True)
    print(f"\n  Testing on fold {fold}: {len(fold_df)} images")
    
    # Accumulators for per-class metrics
    all_dice = {c: [] for c in range(num_classes)}
    all_pq = {c: [] for c in range(num_classes)}
    
    # CLIP normalization
    normalize = transforms.Normalize(
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711]
    )
    
    for idx in tqdm(range(len(fold_df)), desc="Per-Image Text"):
        row = fold_df.iloc[idx]
        image_id = row['image_id']
        
        # Get per-image instruction from annotations.csv
        text_instruction = row['instruction'] if pd.notna(row['instruction']) else "Segment all tissue types."
        
        # Load image
        img_path = Path(data_root) / 'images' / f'fold{fold}' / f'{image_id}.png'
        if not img_path.exists():
            continue
            
        image = Image.open(img_path).convert("RGB")
        image = image.resize((img_size, img_size), Image.BILINEAR)
        image_tensor = transforms.ToTensor()(image)
        image_tensor = normalize(image_tensor).unsqueeze(0).to(device)
        
        # Load mask (npz with per-class binary masks)
        mask_path = Path(data_root) / 'masks' / f'fold{fold}' / f'{image_id}.npz'
        if not mask_path.exists():
            continue
            
        mask_data = np.load(mask_path)
        masks = mask_data['masks']  # [H, W, num_classes]
        
        # Create semantic mask
        gt_semantic = np.zeros((masks.shape[0], masks.shape[1]), dtype=np.int64)
        for c in range(min(masks.shape[2], num_classes)):
            gt_semantic[masks[:, :, c] > 0] = c
        
        # Create instance mask (for PQ calculation)
        gt_instance = np.zeros((masks.shape[0], masks.shape[1]), dtype=np.int32)
        gt_class_map = {}
        inst_id = 1
        for c in range(min(masks.shape[2], num_classes)):
            binary_mask = masks[:, :, c] > 0
            if binary_mask.sum() > 0:
                labeled, num_instances = ndimage.label(binary_mask)
                for i in range(1, num_instances + 1):
                    gt_instance[labeled == i] = inst_id
                    gt_class_map[inst_id] = c
                    inst_id += 1
        
        # Resize masks
        gt_semantic = np.array(Image.fromarray(gt_semantic.astype(np.uint8)).resize(
            (img_size, img_size), Image.NEAREST))
        gt_instance = np.array(Image.fromarray(gt_instance.astype(np.int32)).resize(
            (img_size, img_size), Image.NEAREST))
        
        gt_semantic = torch.from_numpy(gt_semantic).long()
        gt_instance = torch.from_numpy(gt_instance)
        
        # Forward pass with per-image text
        text_prompts = [text_instruction] * num_classes
        outputs = model(image_tensor, text_prompts)
        logits = outputs['logits']
        pred_semantic = logits.argmax(dim=1).squeeze(0).cpu()
        
        # Compute Dice per class
        dice_scores = compute_dice_per_class(pred_semantic, gt_semantic, num_classes)
        for c in range(num_classes):
            all_dice[c].append(dice_scores[c])
        
        # Compute PQ per class
        pq_scores = compute_pq_per_class(pred_semantic, gt_semantic, gt_instance, gt_class_map, num_classes)
        for c in range(num_classes):
            all_pq[c].append(pq_scores[c]['PQ'])
    
    # Aggregate results
    dice_macro = [np.mean(all_dice[c]) for c in range(num_classes)]
    pq_macro = [np.mean(all_pq[c]) for c in range(num_classes)]
    
    return {
        'mean_dice': np.mean(dice_macro),
        'mean_pq': np.mean(pq_macro),
        'per_class_dice': dice_macro,
        'per_class_pq': pq_macro,
    }

# ============================================================================
# Evaluation function for fixed text (Common Text / No Text)
# ============================================================================
@torch.no_grad()
def evaluate_fixed_text(model, annotations_df, fold, data_root, img_size, device, text_instruction, num_classes=5):
    """
    Evaluate with fixed text instruction for all images.
    Returns: mDice, mPQ (same metrics as CIPS-Net)
    """
    model.eval()
    
    fold_df = annotations_df[annotations_df['fold'] == fold].reset_index(drop=True)
    
    all_dice = {c: [] for c in range(num_classes)}
    all_pq = {c: [] for c in range(num_classes)}
    
    normalize = transforms.Normalize(
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711]
    )
    
    variant_name = "Common Text" if text_instruction else "No Text"
    
    for idx in tqdm(range(len(fold_df)), desc=variant_name):
        row = fold_df.iloc[idx]
        image_id = row['image_id']
        
        img_path = Path(data_root) / 'images' / f'fold{fold}' / f'{image_id}.png'
        if not img_path.exists():
            continue
            
        image = Image.open(img_path).convert("RGB")
        image = image.resize((img_size, img_size), Image.BILINEAR)
        image_tensor = transforms.ToTensor()(image)
        image_tensor = normalize(image_tensor).unsqueeze(0).to(device)
        
        mask_path = Path(data_root) / 'masks' / f'fold{fold}' / f'{image_id}.npz'
        if not mask_path.exists():
            continue
            
        mask_data = np.load(mask_path)
        masks = mask_data['masks']
        
        gt_semantic = np.zeros((masks.shape[0], masks.shape[1]), dtype=np.int64)
        for c in range(min(masks.shape[2], num_classes)):
            gt_semantic[masks[:, :, c] > 0] = c
        
        gt_instance = np.zeros((masks.shape[0], masks.shape[1]), dtype=np.int32)
        gt_class_map = {}
        inst_id = 1
        for c in range(min(masks.shape[2], num_classes)):
            binary_mask = masks[:, :, c] > 0
            if binary_mask.sum() > 0:
                labeled, num_instances = ndimage.label(binary_mask)
                for i in range(1, num_instances + 1):
                    gt_instance[labeled == i] = inst_id
                    gt_class_map[inst_id] = c
                    inst_id += 1
        
        gt_semantic = np.array(Image.fromarray(gt_semantic.astype(np.uint8)).resize(
            (img_size, img_size), Image.NEAREST))
        gt_instance = np.array(Image.fromarray(gt_instance.astype(np.int32)).resize(
            (img_size, img_size), Image.NEAREST))
        
        gt_semantic = torch.from_numpy(gt_semantic).long()
        gt_instance = torch.from_numpy(gt_instance)
        
        text_prompts = [text_instruction] * num_classes
        outputs = model(image_tensor, text_prompts)
        logits = outputs['logits']
        pred_semantic = logits.argmax(dim=1).squeeze(0).cpu()
        
        dice_scores = compute_dice_per_class(pred_semantic, gt_semantic, num_classes)
        for c in range(num_classes):
            all_dice[c].append(dice_scores[c])
        
        pq_scores = compute_pq_per_class(pred_semantic, gt_semantic, gt_instance, gt_class_map, num_classes)
        for c in range(num_classes):
            all_pq[c].append(pq_scores[c]['PQ'])
    
    dice_macro = [np.mean(all_dice[c]) for c in range(num_classes)]
    pq_macro = [np.mean(all_pq[c]) for c in range(num_classes)]
    
    return {
        'mean_dice': np.mean(dice_macro),
        'mean_pq': np.mean(pq_macro),
        'per_class_dice': dice_macro,
        'per_class_pq': pq_macro,
    }

# ============================================================================
# Run all 3 text variant evaluations (EXACTLY like CIPS-Net)
# ============================================================================
text_variant_results = {}
all_results_rows = []

test_fold = 1  # Use fold 1 as test (same as CIPS-Net)

if annotations_df is not None:
    # 1. Per-Image Text (from annotations.csv)
    print(f"\n{'='*60}")
    print(f"Variant 1: Per-Image Text (from annotations.csv)")
    print(f"{'='*60}")
    per_image_metrics = evaluate_per_image_text(
        model, annotations_df, test_fold, 
        CONFIG['pannuke_preprocess_root'], CONFIG['image_size'],
        CONFIG['device'], CONFIG['num_classes']
    )
    text_variant_results['per_image_text'] = {
        'display_name': 'Per-Image Text',
        'source': 'annotations.csv instruction column',
        **per_image_metrics
    }
    all_results_rows.append({
        'Dataset': 'PanNuke', 'Variant': 'Per-Image Text', 'Setting': 'Zero-Shot',
        'mDice': per_image_metrics['mean_dice'], 'mPQ': per_image_metrics['mean_pq'],
    })
    print(f"  mDice: {per_image_metrics['mean_dice']:.4f}")
    print(f"  mPQ:   {per_image_metrics['mean_pq']:.4f}")
    
    # 2. Common Text (fixed sentence for all images)
    print(f"\n{'='*60}")
    print(f"Variant 2: Common Text")
    print(f"  Text: '{COMMON_TEXT}'")
    print(f"{'='*60}")
    common_metrics = evaluate_fixed_text(
        model, annotations_df, test_fold,
        CONFIG['pannuke_preprocess_root'], CONFIG['image_size'],
        CONFIG['device'], COMMON_TEXT, CONFIG['num_classes']
    )
    text_variant_results['common_text'] = {
        'display_name': 'Common Text',
        'text': COMMON_TEXT,
        **common_metrics
    }
    all_results_rows.append({
        'Dataset': 'PanNuke', 'Variant': 'Common Text', 'Setting': 'Zero-Shot',
        'mDice': common_metrics['mean_dice'], 'mPQ': common_metrics['mean_pq'],
    })
    print(f"  mDice: {common_metrics['mean_dice']:.4f}")
    print(f"  mPQ:   {common_metrics['mean_pq']:.4f}")
    
    # 3. No Text (empty string)
    print(f"\n{'='*60}")
    print(f"Variant 3: No Text (empty string)")
    print(f"{'='*60}")
    no_text_metrics = evaluate_fixed_text(
        model, annotations_df, test_fold,
        CONFIG['pannuke_preprocess_root'], CONFIG['image_size'],
        CONFIG['device'], NO_TEXT, CONFIG['num_classes']
    )
    text_variant_results['no_text'] = {
        'display_name': 'No Text',
        'text': '',
        **no_text_metrics
    }
    all_results_rows.append({
        'Dataset': 'PanNuke', 'Variant': 'No Text', 'Setting': 'Zero-Shot',
        'mDice': no_text_metrics['mean_dice'], 'mPQ': no_text_metrics['mean_pq'],
    })
    print(f"  mDice: {no_text_metrics['mean_dice']:.4f}")
    print(f"  mPQ:   {no_text_metrics['mean_pq']:.4f}")

# Save results
text_variant_path = results_dir / "text_variants" / "pannuke_text_variants.json"
with open(text_variant_path, 'w') as f:
    json.dump(text_variant_results, f, indent=2)

# ============================================================================
# Summary Table (matching CIPS-Net format exactly)
# ============================================================================
print("\n" + "=" * 70)
print("PANNUKE TEXT VARIANT RESULTS (CIPS-Net Format)")
print("=" * 70)
print(f"{'Dataset':<12} {'Variant':<18} {'Setting':<12} {'mDice':<12} {'mPQ':<12}")
print("-" * 66)
for row in all_results_rows:
    print(f"{row['Dataset']:<12} {row['Variant']:<18} {row['Setting']:<12} "
          f"{row['mDice']:<12.4f} {row['mPQ']:<12.4f}")
print("-" * 66)

# Comparison with CIPS-Net results
print("\n" + "=" * 70)
print("COMPARISON WITH CIPS-NET")
print("=" * 70)
cipsnet_reference = {
    'Per-Image Text': {'mDice': 0.7661, 'mPQ': 0.5356},
    'Common Text': {'mDice': 0.3558, 'mPQ': 0.1091},
    'No Text': {'mDice': 0.6910, 'mPQ': 0.4998},
}
print(f"{'Variant':<18} {'CIPS-Net mDice':<16} {'CIPS-Net mPQ':<14} {f'{MODEL_NAME} mDice':<16} {f'{MODEL_NAME} mPQ':<14}")
print("-" * 78)
for row in all_results_rows:
    variant = row['Variant']
    cips = cipsnet_reference.get(variant, {})
    print(f"{variant:<18} {cips.get('mDice', 'N/A'):<16} {cips.get('mPQ', 'N/A'):<14} {row['mDice']:<16.4f} {row['mPQ']:<14.4f}")

print(f"\n✓ Text variant results saved to: {text_variant_path}")


PHASE 2: TEXT VARIANTS EVALUATION (CIPS-NET PROTOCOL)
✓ Loaded best model from fold 0 (epoch 7)
✓ Loaded annotations.csv: 7901 entries
✓ Metrics functions defined (mDice + mPQ)

Variant 1: Per-Image Text (from annotations.csv)

  Testing on fold 1: 2656 images


Per-Image Text:   0%|          | 0/2656 [00:00<?, ?it/s]

  mDice: 0.6905
  mPQ:   0.4993

Variant 2: Common Text
  Text: 'Segment all Neoplastic, Inflammatory, Connective, Dead, and Epithelial cells in the image.'


Common Text:   0%|          | 0/2656 [00:00<?, ?it/s]

  mDice: 0.6905
  mPQ:   0.4993

Variant 3: No Text (empty string)


No Text:   0%|          | 0/2656 [00:00<?, ?it/s]

  mDice: 0.6905
  mPQ:   0.4993

PANNUKE TEXT VARIANT RESULTS (CIPS-Net Format)
Dataset      Variant            Setting      mDice        mPQ         
------------------------------------------------------------------
PanNuke      Per-Image Text     Zero-Shot    0.6905       0.4993      
PanNuke      Common Text        Zero-Shot    0.6905       0.4993      
PanNuke      No Text            Zero-Shot    0.6905       0.4993      
------------------------------------------------------------------

COMPARISON WITH CIPS-NET
Variant            CIPS-Net mDice   CIPS-Net mPQ   clipseg mDice    clipseg mPQ   
------------------------------------------------------------------------------
Per-Image Text     0.7661           0.5356         0.6905           0.4993        
Common Text        0.3558           0.1091         0.6905           0.4993        
No Text            0.691            0.4998         0.6905           0.4993        

✓ Text variant results saved to: /mnt/e3dbc9b9-6856-470d-84b1-ff

In [29]:
# ============================================================================
# PHASE 3: ZERO-SHOT EVALUATION ON CONSEP AND MONUSAC
# ============================================================================
# Same 3 text variants as PanNuke:
#   1. Per-Image Text
#   2. Common Text
#   3. No Text
# ============================================================================

print("\n" + "=" * 70)
print("PHASE 3: ZERO-SHOT EVALUATION ON EXTERNAL DATASETS")
print("=" * 70)

# External dataset paths (Official datasets)
EXTERNAL_DATASETS_ROOT = Path(CONFIG.get('external_datasets_root', 
    '/mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work/Histopathology_Datasets_Official'))

external_datasets_config = {
    'CoNSeP': {
        'path': EXTERNAL_DATASETS_ROOT / 'CoNSeP',
        # CoNSeP → PanNuke class mapping
        'class_mapping': {
            2: 1,  # Inflammatory → Inflammatory
            3: 4,  # Healthy Epithelial → Epithelial
            4: 0,  # Dysplastic/Malignant → Neoplastic
            5: 2,  # Fibroblast → Connective
        },
    },
    'MoNuSAC': {
        'path': EXTERNAL_DATASETS_ROOT / 'MoNuSAC',
        # MoNuSAC → PanNuke class mapping
        'class_mapping': {
            'Epithelial': 4,   # Epithelial → Epithelial
            'Lymphocyte': 1,   # Lymphocyte → Inflammatory
            'Macrophage': 1,   # Macrophage → Inflammatory
            'Neutrophil': 1,   # Neutrophil → Inflammatory
        },
    },
}

zero_shot_results = {}
all_zero_shot_rows = []

# Use the model loaded from Phase 2
model.eval()

# Display name mapping for variants (ensure available here)
variant_display_names = {
    "per_image_text": "Per-Image Text",
    "common_text": "Common Text",
    "no_text": "No Text",
}

for dataset_name, config in external_datasets_config.items():
    print(f"\n{'='*60}")
    print(f"Zero-Shot Evaluation: {dataset_name}")
    print(f"{'='*60}")
    
    dataset_path = config['path']
    
    # Check if dataset exists
    if not dataset_path.exists():
        print(f"⚠️ Dataset not found at: {dataset_path}")
        print(f"   Skipping {dataset_name}...")
        zero_shot_results[dataset_name] = {'status': 'dataset_not_found', 'path': str(dataset_path)}
        continue
    
    # For now, we'll note that external dataset evaluation requires
    # proper dataset loaders matching the CIPS-Net format
    # This is a placeholder that shows the expected output format
    
    print(f"  Path: {dataset_path}")
    print(f"  Class Mapping: {config['class_mapping']}")
    
    # Placeholder results (actual evaluation would need proper dataset loaders)
    dataset_results = {
        'status': 'requires_dataset_loader',
        'class_mapping': config['class_mapping'],
        'text_variants': {},
    }
    
    # For each text variant
    for variant_name, prompts in TEXT_PROMPTS_VARIANTS.items():
        display_name = variant_display_names[variant_name]
        
        # Placeholder - would need actual dataset loader
        dataset_results['text_variants'][variant_name] = {
            'display_name': display_name,
            'status': 'pending',
        }
        
        all_zero_shot_rows.append({
            'Dataset': dataset_name,
            'Variant': display_name,
            'Setting': 'Zero-Shot',
            'mDice': 'N/A',
            'mIoU': 'N/A',
        })
    
    zero_shot_results[dataset_name] = dataset_results
    print(f"  ⚠️ External dataset loader needed for full evaluation")

# Save zero-shot results
zero_shot_path = results_dir / "zero_shot" / "zero_shot_results.json"
with open(zero_shot_path, 'w') as f:
    json.dump(zero_shot_results, f, indent=2, default=str)

print(f"\n✓ Zero-shot config saved to: {zero_shot_path}")

# ============================================================================
# Summary
# ============================================================================
print("\n" + "=" * 70)
print("ZERO-SHOT EVALUATION STATUS")
print("=" * 70)
for dataset_name, result in zero_shot_results.items():
    status = result.get('status', 'unknown')
    print(f"  {dataset_name}: {status}")

print(f"\nNote: Full zero-shot evaluation on CoNSeP and MoNuSAC requires")
print(f"      proper dataset loaders with class mapping (see CIPS-Net notebook).")


PHASE 3: ZERO-SHOT EVALUATION ON EXTERNAL DATASETS

Zero-Shot Evaluation: CoNSeP
  Path: /mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work/Histopathology_Datasets_Official/CoNSeP
  Class Mapping: {2: 1, 3: 4, 4: 0, 5: 2}
  ⚠️ External dataset loader needed for full evaluation

Zero-Shot Evaluation: MoNuSAC
  Path: /mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work/Histopathology_Datasets_Official/MoNuSAC
  Class Mapping: {'Epithelial': 4, 'Lymphocyte': 1, 'Macrophage': 1, 'Neutrophil': 1}
  ⚠️ External dataset loader needed for full evaluation

✓ Zero-shot config saved to: /mnt/e3dbc9b9-6856-470d-84b1-ff55921cd906/Datasets/Nikhil/Histopathology_Work/results/clipseg_20260126_145532/zero_shot/zero_shot_results.json

ZERO-SHOT EVALUATION STATUS
  CoNSeP: requires_dataset_loader
  MoNuSAC: requires_dataset_loader

Note: Full zero-shot evaluation on CoNSeP and MoNuSAC requires
      proper dataset loaders with class mapping (see CIPS-N

In [30]:
# ============================================================================
# PHASE 4: FINE-TUNING ON CONSEP AND MONUSAC
# ============================================================================
print("\n" + "=" * 70)
print("PHASE 4: FINE-TUNING ON EXTERNAL DATASETS")
print("=" * 70)

fine_tune_results = {}

for dataset_name, paths in external_datasets_config.items():
    print(f"\n{'='*60}")
    print(f"Fine-tuning on {dataset_name}")
    print(f"{'='*60}")
    
    # Check if dataset exists
    if not Path(paths['image_dir']).exists():
        print(f"⚠️ Dataset not found at {paths['image_dir']}")
        fine_tune_results[dataset_name] = {'status': 'dataset_not_found'}
        continue
    
    try:
        # Create train/val split for this dataset
        all_images = sorted(
            list(Path(paths['image_dir']).glob("*.png")) +
            list(Path(paths['image_dir']).glob("*.jpg"))
        )
        
        if len(all_images) < 10:
            print(f"⚠️ Too few images ({len(all_images)}), skipping...")
            fine_tune_results[dataset_name] = {'status': 'too_few_images'}
            continue
        
        # 80/20 train/val split
        n_train = int(len(all_images) * 0.8)
        
        # Create datasets using our SimplePanNukeDataset structure
        # (assuming masks follow same naming convention)
        train_ext_dataset = SimplePanNukeDataset(
            image_dir=paths['image_dir'],
            mask_dir=paths['mask_dir'],
            image_size=CONFIG['image_size'],
            split="train",
            fold=0,
        )
        
        val_ext_dataset = SimplePanNukeDataset(
            image_dir=paths['image_dir'],
            mask_dir=paths['mask_dir'],
            image_size=CONFIG['image_size'],
            split="val",
            fold=0,
        )
        
        train_ext_loader = DataLoader(
            train_ext_dataset, batch_size=CONFIG['batch_size'], shuffle=True,
            num_workers=0, pin_memory=True, drop_last=True,
        )
        
        val_ext_loader = DataLoader(
            val_ext_dataset, batch_size=CONFIG['batch_size'], shuffle=False,
            num_workers=0, pin_memory=True,
        )
        
        print(f"Train: {len(train_ext_dataset)}, Val: {len(val_ext_dataset)}")
        
        # Create fresh model initialized from PanNuke-pretrained weights
        ft_model = get_model(
            MODEL_NAME,
            num_classes=CONFIG['num_classes'],
            image_size=CONFIG['image_size'],
            clip_model=CONFIG['clip_model'],
            freeze_clip=True,
            device=CONFIG['device'],
        )
        
        # Load pretrained weights from PanNuke
        pretrain_path = results_dir / "models" / f"best_{MODEL_NAME}_fold0.pth"
        if pretrain_path.exists():
            checkpoint = torch.load(pretrain_path, map_location=CONFIG['device'])
            ft_model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Initialized from PanNuke pretrained model")
        
        # Fine-tuning with lower learning rate
        ft_optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, ft_model.parameters()),
            lr=CONFIG['learning_rate'] * 0.1,  # Lower LR for fine-tuning
            weight_decay=CONFIG['weight_decay'],
        )
        ft_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            ft_optimizer, T_max=20, eta_min=1e-7,  # Fewer epochs for fine-tuning
        )
        ft_scaler = torch.amp.GradScaler('cuda')
        
        # Fine-tuning loop (fewer epochs)
        best_ft_metric = 0.0
        ft_patience_counter = 0
        ft_history = defaultdict(list)
        
        for epoch in range(20):  # 20 epochs for fine-tuning
            print(f"\n  Epoch {epoch+1}/20")
            
            train_loss = train_one_epoch(
                ft_model, train_ext_loader, ft_optimizer, criterion, ft_scaler,
                TEXT_PROMPTS, CONFIG['device']
            )
            
            val_metrics = validate(
                ft_model, val_ext_loader, criterion, TEXT_PROMPTS,
                CONFIG['device'], CONFIG['num_classes']
            )
            
            ft_scheduler.step()
            
            ft_history['train_loss'].append(train_loss)
            ft_history['val_iou'].append(val_metrics['mean_iou'])
            ft_history['val_dice'].append(val_metrics['mean_dice'])
            
            print(f"    Loss: {train_loss:.4f} | mIoU: {val_metrics['mean_iou']:.4f}")
            
            if val_metrics['mean_iou'] > best_ft_metric:
                best_ft_metric = val_metrics['mean_iou']
                ft_patience_counter = 0
                
                # Save fine-tuned model
                ft_save_path = results_dir / "fine_tuned" / f"best_{MODEL_NAME}_{dataset_name.lower()}.pth"
                torch.save({
                    'model_state_dict': ft_model.state_dict(),
                    'best_metric': best_ft_metric,
                }, ft_save_path)
            else:
                ft_patience_counter += 1
                if ft_patience_counter >= 5:
                    print(f"    Early stopping after {epoch+1} epochs")
                    break
        
        # Evaluate fine-tuned model with all text variants
        ft_model.eval()
        dataset_ft_results = {
            'best_iou': best_ft_metric,
            'text_variants': {},
        }
        
        print(f"\n  Text variant evaluation (fine-tuned):")
        for variant_name, prompts in TEXT_PROMPTS_VARIANTS.items():
            metrics = validate(
                ft_model, val_ext_loader, criterion, prompts,
                CONFIG['device'], CONFIG['num_classes']
            )
            
            dataset_ft_results['text_variants'][variant_name] = {
                'mean_iou': metrics['mean_iou'],
                'mean_dice': metrics['mean_dice'],
            }
            print(f"    {variant_name}: mIoU={metrics['mean_iou']:.4f}")
        
        fine_tune_results[dataset_name] = dataset_ft_results
        
        # Cleanup
        del ft_model, ft_optimizer
        torch.cuda.empty_cache()
        
    except Exception as e:
        print(f"❌ Error fine-tuning on {dataset_name}: {e}")
        import traceback
        traceback.print_exc()
        fine_tune_results[dataset_name] = {'status': 'error', 'message': str(e)}

# Save fine-tuning results
ft_results_path = results_dir / "fine_tuned" / "fine_tuning_results.json"
with open(ft_results_path, 'w') as f:
    json.dump(fine_tune_results, f, indent=2)

print(f"\n✓ Fine-tuning results saved to: {ft_results_path}")


PHASE 4: FINE-TUNING ON EXTERNAL DATASETS

Fine-tuning on CoNSeP


KeyError: 'image_dir'

In [None]:
# ============================================================================
# PHASE 5: COMPREHENSIVE RESULTS SUMMARY (CIPS-NET FORMAT with mPQ)
# ============================================================================
import pandas as pd

print("\n" + "=" * 70)
print("EXPERIMENT RESULTS SUMMARY")
print("=" * 70)
print(f"Model: {MODEL_NAME}")
print(f"Experiment: {EXPERIMENT_NAME}")
print(f"Timestamp: {EXPERIMENT_TIMESTAMP}")
print("=" * 70)

# ============================================================================
# 1. PanNuke 3-Fold Cross-Validation Results
# ============================================================================
print("\n" + "=" * 70)
print("1. PanNuke 3-Fold Cross-Validation Results")
print("=" * 70)

fold_results_summary = []
for fold_num in range(3):
    fold_key = f'fold_{fold_num}'
    if fold_key in all_fold_results:
        metrics = all_fold_results[fold_key]
        fold_results_summary.append({
            'Fold': fold_num,
            'Best mIoU': metrics['best_iou'],
            'Final mDice': metrics['final_dice'],
            'Epochs': metrics['epochs_trained'],
        })
        print(f"  Fold {fold_num}: mIoU = {metrics['best_iou']:.4f}, mDice = {metrics['final_dice']:.4f}")

if fold_results_summary:
    avg_iou = np.mean([r['Best mIoU'] for r in fold_results_summary])
    std_iou = np.std([r['Best mIoU'] for r in fold_results_summary])
    avg_dice = np.mean([r['Final mDice'] for r in fold_results_summary])
    std_dice = np.std([r['Final mDice'] for r in fold_results_summary])
    print(f"\n  Average: mIoU = {avg_iou:.4f} ± {std_iou:.4f}")
    print(f"           mDice = {avg_dice:.4f} ± {std_dice:.4f}")

# ============================================================================
# 2. Text Variants Results (CIPS-Net Format with mPQ)
# ============================================================================
print("\n" + "=" * 70)
print("2. Text Variants Results (CIPS-Net Format)")
print("=" * 70)
print(f"{'Dataset':<12} {'Variant':<18} {'Setting':<12} {'mDice':<12} {'mPQ':<12}")
print("-" * 66)

csv_rows = []

# PanNuke text variants
if 'text_variant_results' in dir() and text_variant_results:
    for variant_name, result in text_variant_results.items():
        display_name = result.get('display_name', variant_name)
        row = {
            'Dataset': 'PanNuke',
            'Variant': display_name,
            'Setting': 'Zero-Shot',
            'mDice': result['mean_dice'],
            'mPQ': result['mean_pq'],
        }
        csv_rows.append(row)
        print(f"{row['Dataset']:<12} {row['Variant']:<18} {row['Setting']:<12} "
              f"{row['mDice']:<12.4f} {row['mPQ']:<12.4f}")

print("-" * 66)

# ============================================================================
# 3. Save Results in CIPS-Net Format
# ============================================================================

# Save CSV (matching CIPS-Net output format)
if csv_rows:
    results_df = pd.DataFrame(csv_rows)
    csv_path = results_dir / "text_variants" / "results_cipsnet_format.csv"
    results_df.to_csv(csv_path, index=False)
    print(f"\n✓ CSV results saved to: {csv_path}")

# Compile all results into summary JSON
summary = {
    'experiment': {
        'name': EXPERIMENT_NAME,
        'model': MODEL_NAME,
        'timestamp': EXPERIMENT_TIMESTAMP,
        'config': CONFIG,
    },
    'pannuke_3fold_cv': {
        'fold_results': all_fold_results if 'all_fold_results' in dir() else {},
        'mean_iou': avg_iou if 'avg_iou' in dir() else None,
        'std_iou': std_iou if 'std_iou' in dir() else None,
        'mean_dice': avg_dice if 'avg_dice' in dir() else None,
        'std_dice': std_dice if 'std_dice' in dir() else None,
    },
    'text_variants': text_variant_results if 'text_variant_results' in dir() else {},
    'zero_shot': zero_shot_results if 'zero_shot_results' in dir() else {},
}

summary_path = results_dir / "full_experiment_summary.json"
with open(summary_path, 'w') as f:
    json.dump(summary, f, indent=2, default=str)

print(f"✓ Full summary saved to: {summary_path}")

# ============================================================================
# 4. Create Markdown Summary (with mPQ)
# ============================================================================
md_summary = f"""# Experiment Results: {EXPERIMENT_NAME}

## Model: {MODEL_NAME}
- **Timestamp**: {EXPERIMENT_TIMESTAMP}
- **Epochs**: {CONFIG['num_epochs']} (Early Stopping: {CONFIG['early_stopping_patience']})
- **Image Size**: {CONFIG['image_size']}
- **Batch Size**: {CONFIG['batch_size']}

## PanNuke 3-Fold Cross-Validation
| Fold | mIoU | mDice | Epochs |
|------|------|-------|--------|
"""

for r in fold_results_summary:
    md_summary += f"| {r['Fold']} | {r['Best mIoU']:.4f} | {r['Final mDice']:.4f} | {r['Epochs']} |\n"

if 'avg_iou' in dir():
    md_summary += f"| **Avg** | **{avg_iou:.4f}±{std_iou:.4f}** | **{avg_dice:.4f}±{std_dice:.4f}** | - |\n"

md_summary += f"""
## Text Variants Evaluation (CIPS-Net Protocol)
| Dataset | Variant | Setting | mDice | mPQ |
|---------|---------|---------|-------|-----|
"""

for row in csv_rows:
    md_summary += f"| {row['Dataset']} | {row['Variant']} | {row['Setting']} | {row['mDice']:.4f} | {row['mPQ']:.4f} |\n"

md_summary += f"""
## Comparison with CIPS-Net (Reference)
| Dataset | Variant | CIPS-Net mDice | CIPS-Net mPQ | {MODEL_NAME} mDice | {MODEL_NAME} mPQ |
|---------|---------|----------------|--------------|-----------------|---------------|
| PanNuke | Per-Image Text | 0.7661 | 0.5356 | {csv_rows[0]['mDice']:.4f} | {csv_rows[0]['mPQ']:.4f} |
| PanNuke | Common Text | 0.3558 | 0.1091 | {csv_rows[1]['mDice']:.4f} | {csv_rows[1]['mPQ']:.4f} |
| PanNuke | No Text | 0.6910 | 0.4998 | {csv_rows[2]['mDice']:.4f} | {csv_rows[2]['mPQ']:.4f} |
"""

md_path = results_dir / "RESULTS.md"
with open(md_path, 'w') as f:
    f.write(md_summary)

print(f"✓ Markdown summary saved to: {md_path}")

# ============================================================================
# Final Summary
# ============================================================================
print("\n" + "=" * 70)
print("EXPERIMENT COMPLETE!")
print("=" * 70)
print(f"\n📁 Results saved to: {results_dir}")
print(f"\nFiles:")
print(f"  • models/best_{MODEL_NAME}_fold*.pth - Model checkpoints")
print(f"  • pannuke_3fold/cv_results.json - 3-fold CV results")
print(f"  • text_variants/results_cipsnet_format.csv - Text variant results (mDice + mPQ)")
print(f"  • RESULTS.md - Markdown summary")
print(f"  • full_experiment_summary.json - Complete results")
print("=" * 70)

## Next Steps

1. **Run all folds**: Change `CONFIG['current_fold']` to 1 and 2 to complete 3-fold CV
2. **Try different models**: Change `MODEL_NAME` to train other models
3. **Compare with CIPS-Net**: Load CIPS-Net results from `results_comparison.txt`
4. **Zero-shot evaluation**: Test trained models on CoNSeP and MoNuSAC without fine-tuning
5. **Fine-tuning**: Optionally fine-tune on target datasets