# In [None]:
# ============================================================
# ENHANCED CTRANSPATH PIPELINE WITH ADVANCED EVALUATION METRICS
# Multi-Seed Stability | Uncertainty Quantification | Heterogeneity Analysis
# Patch Quality Control | Adaptive Sampling | Cross-Encoder Fusion
# ============================================================

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import numpy as np
import pandas as pd
import openslide
import torch
import torchvision.transforms as transforms
from PIL import Image
from skimage.filters import threshold_otsu, laplace, gaussian
from skimage.morphology import (remove_small_objects, binary_dilation, binary_erosion, disk)
from skimage.segmentation import watershed
from skimage.color import rgb2hsv, rgb2gray
from skimage.measure import regionprops, label
from scipy.ndimage import distance_transform_edt, maximum_filter
from scipy.spatial.distance import pdist, squareform
from scipy import stats
from scipy.cluster.hierarchy import linkage, fcluster
import json
from datetime import datetime
import warnings
import timm
import traceback
import cv2
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

warnings.filterwarnings("ignore")

# ===============================
# REPRODUCIBILITY
# ===============================
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)
    # Deterministic kernels alone are not enough: cuDNN's autotuner
    # (benchmark mode) may select different algorithms between runs, so it
    # is disabled as well for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# ===============================
# CONFIG
# ===============================
# The defaults point at the original author's machine; set WSI_SVS_DIR,
# WSI_CTRANSPATH_WEIGHTS and/or WSI_OUTPUT_DIR to run elsewhere without
# editing the script.
SVS_DIR = os.environ.get(
    "WSI_SVS_DIR",
    r"C:\Users\Shahinur\Downloads\PKG_Dataset\PKG - Brain-Mets-Lung-MRI-Path-Segs_histopathology images\data",
)
CTRANSPATH_WEIGHTS = os.environ.get(
    "WSI_CTRANSPATH_WEIGHTS", r"D:\paper\weights\ctranspath.pth"
)
OUTPUT_DIR = os.environ.get("WSI_OUTPUT_DIR", "ENHANCED_WSI_FRAMEWORK")
FIGURES_DIR = f"{OUTPUT_DIR}/evaluation_figures"
# parents=True lets OUTPUT_DIR be a nested path that does not exist yet.
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
Path(FIGURES_DIR).mkdir(parents=True, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Enhanced configuration
MULTI_SEED_SAMPLING = [42, 123, 456, 789, 1011]  # 5 seeds for stability
PATCH_SIZE = 224
TARGET_PATCHES = 1000

# Matplotlib settings: screen at 100 dpi, saved figures at 300 dpi.
sns.set_style("whitegrid")
plt.rcParams['figure.dpi'] = 100
plt.rcParams['savefig.dpi'] = 300

print("="*80)
print("ENHANCED WSI FEATURE EXTRACTION FRAMEWORK")
print("Features: Multi-Seed Stability | Uncertainty Quantification | Heterogeneity")
print("="*80)
print(f"Device: {DEVICE}")
print(f"Multi-seed sampling: {len(MULTI_SEED_SAMPLING)} seeds")
print(f"Output: {OUTPUT_DIR}\n")

def log_msg(m):
    """Print *m* and append a timestamped copy to ``OUTPUT_DIR/progress.log``.

    The file is written as UTF-8 explicitly: the pipeline's messages contain
    emoji and arrows (✅ ❌ ⚠️ →), which the default locale encoding on
    Windows (cp1252) cannot encode. Disk logging stays best-effort: an
    ``OSError`` (missing or read-only directory, locked file) is ignored so
    that a logging problem never aborts slide processing.
    """
    print(m)
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    try:
        with open(f"{OUTPUT_DIR}/progress.log", 'a', encoding='utf-8') as f:
            f.write(f"[{timestamp}] {m}\n")
    except OSError:
        # Best-effort: failing to write the log must not stop the pipeline.
        pass

# ============================================================
# ADVANCED PATCH QUALITY CONTROL MODULE (NOVELTY 1)
# ============================================================
class AdvancedPatchQC:
    """
    Multi-factor patch quality scoring combining:
    - Tissue content
    - Focus quality
    - Artifact detection
    - Information content
    - Representativeness

    Each component score lies in [0, 1] (assuming a boolean tissue mask), and
    ``compute_quality_score`` returns their weighted sum. The weights below
    sum to 1.0.
    """

    def __init__(self):
        self.weights = {
            'tissue': 0.25,
            'focus': 0.20,
            'artifact': 0.20,
            'information': 0.20,
            'representativeness': 0.15
        }

    @staticmethod
    def _canny_edges(gray):
        """Canny edge map (uint8, edges = 255) of a grey image valued in [0, 1]."""
        return cv2.Canny((gray * 255).astype(np.uint8), 50, 150)

    @staticmethod
    def _edge_fraction(edges):
        """Fraction of pixels marked as edges in an edge map.

        ``cv2.Canny`` marks edge pixels with 255, so ``edges.sum() / edges.size``
        would be 255x the real fraction. Non-zero pixels are counted instead.
        Returns 0.0 for an empty map.
        """
        if edges.size == 0:
            return 0.0
        return np.count_nonzero(edges) / edges.size

    def tissue_score(self, patch, mask):
        """Fraction of tissue pixels in ``mask`` (``patch`` is not used).

        Returns 0.0 for an empty mask instead of dividing by zero.
        """
        if mask.size == 0:
            return 0.0
        return mask.sum() / mask.size

    def focus_score(self, patch):
        """Blur score from the variance of the Laplacian, clipped to [0, 1].

        ``rgb2gray`` returns intensities in [0, 1], but the normalising
        constant (variance 100 treated as "sharp") is the customary threshold
        for 8-bit [0, 255] images. The grey image is therefore rescaled to
        that range first; otherwise the variance is ~255**2 times too small and
        every patch scores ~0.
        """
        gray = rgb2gray(patch) * 255.0
        lap_var = laplace(gray).var()
        return min(1.0, lap_var / 100.0)

    def artifact_score(self, patch):
        """Penalise pen marks and fold-like edge density (1.0 = clean)."""
        hsv = rgb2hsv(patch)

        # Pen mark detection: strongly saturated blue/green hues.
        pen_mask = (hsv[:, :, 0] > 0.4) & (hsv[:, :, 0] < 0.7) & (hsv[:, :, 1] > 0.5)
        pen_ratio = pen_mask.sum() / pen_mask.size

        # Folds create many edges. The edge penalty saturates once 10% of
        # the pixels are edges.
        edge_density = self._edge_fraction(self._canny_edges(rgb2gray(patch)))

        artifact_penalty = pen_ratio * 0.5 + min(edge_density / 0.1, 1.0) * 0.5
        return max(0, 1.0 - artifact_penalty)

    def information_score(self, patch):
        """Blend of normalised grey-level entropy (70%) and edge density (30%)."""
        gray = rgb2gray(patch)

        # Histogram entropy over fixed [0, 1] bins, normalised by the maximum
        # entropy log(32).
        hist, _ = np.histogram(gray, bins=32, range=(0, 1))
        hist = hist / (hist.sum() + 1e-8)
        entropy = -np.sum(hist * np.log(hist + 1e-8))
        entropy_norm = entropy / np.log(32)

        # The edge term grows linearly and saturates at 5% edge pixels.
        edge_density = self._edge_fraction(self._canny_edges(gray))
        edge_score = min(edge_density / 0.05, 1.0)

        return 0.7 * entropy_norm + 0.3 * edge_score

    def representativeness_score(self, patch, slide_color_centroid):
        """Similarity exp(-d / 50) between the patch mean RGB and the slide centroid.

        Returns a neutral 0.5 when no slide centroid is supplied.
        """
        if slide_color_centroid is None:
            return 0.5

        patch_centroid = patch.reshape(-1, 3).mean(axis=0)
        distance = np.linalg.norm(patch_centroid - slide_color_centroid)
        return np.exp(-distance / 50.0)

    def compute_quality_score(self, patch, mask, slide_color_centroid=None):
        """Return ``(weighted_total, per_component_scores)`` for one patch."""
        scores = {
            'tissue': self.tissue_score(patch, mask),
            'focus': self.focus_score(patch),
            'artifact': self.artifact_score(patch),
            'information': self.information_score(patch),
            'representativeness': self.representativeness_score(patch, slide_color_centroid)
        }

        total_score = sum(scores[k] * self.weights[k] for k in scores)

        return total_score, scores

# ============================================================
# ADAPTIVE SAMPLING MODULE (NOVELTY 2)
# ============================================================
class AdaptiveSampler:
    """
    Content-aware adaptive sampling:
    1. Initial random sampling
    2. Cluster into groups
    3. Sample proportionally to cluster importance

    NOTE: step 3 is currently approximated by further uniform random sampling
    that avoids already-sampled locations (see ``_proportional_sample``).
    Cluster importance is computed but does not yet weight the draw.

    Randomness comes from NumPy's global RNG (``np.random.shuffle``), so
    callers control reproducibility with ``np.random.seed``.
    """

    def __init__(self, n_initial=200, n_final=1000, n_clusters=5):
        self.n_initial = n_initial
        self.n_final = n_final
        self.n_clusters = n_clusters

    def sample(self, slide_path, tissue_mask, patch_size=224):
        """Sample up to ``n_final`` tissue patches from a slide.

        Args:
            slide_path: Path to a slide readable by OpenSlide.
            tissue_mask: 2-D tissue mask at any resolution (e.g. a thumbnail).
                It is mapped proportionally onto the sampling level.
            patch_size: Side length in pixels of the square patches.

        Returns:
            ``(patches, coords)``: lists of RGB uint8 arrays and their
            ``(x, y)`` top-left corners in sampling-level pixels.
        """
        log_msg(f"  Adaptive sampling: {self.n_initial} initial → {self.n_final} final")

        # Step 1: initial uniform random sampling.
        patches, coords = self._uniform_sample(
            slide_path, tissue_mask, self.n_initial, patch_size
        )

        # Too little tissue to cluster meaningfully: return what was found.
        if len(patches) < self.n_initial // 2 or len(patches) < self.n_clusters:
            return patches, coords

        # Step 2: cluster on mean patch colour.
        color_features = np.array([p.reshape(-1, 3).mean(axis=0) for p in patches])
        kmeans = KMeans(n_clusters=self.n_clusters, random_state=42, n_init=10)
        cluster_labels = kmeans.fit_predict(color_features)

        # Step 3: cluster importance.
        cluster_importance = self._compute_cluster_importance(
            color_features, cluster_labels, kmeans.cluster_centers_
        )

        # Step 4: top up to n_final, skipping locations already sampled.
        remaining = self.n_final - len(patches)
        if remaining > 0:
            extra_patches, extra_coords = self._proportional_sample(
                slide_path, tissue_mask, cluster_importance,
                coords, remaining, patch_size
            )
            patches.extend(extra_patches)
            coords.extend(extra_coords)

        return patches[:self.n_final], coords[:self.n_final]

    @staticmethod
    def _candidate_origins(tissue_mask, level_w, level_h, patch_size):
        """Map tissue pixels of a (downsampled) mask to patch origins at a level.

        Each tissue pixel is scaled by ``level / mask`` per axis. The patch is
        centred on the middle of that mask cell, and origins whose patch would
        leave the level are dropped. No full-resolution mask is materialised.
        Note that ``np.nonzero`` is proportional to the mask size, so pass a
        downsampled mask.

        Returns:
            ``(K, 2)`` int64 array of ``(x, y)`` top-left corners.
        """
        mask_h, mask_w = tissue_mask.shape[:2]
        rows, cols = np.nonzero(tissue_mask)
        if rows.size == 0 or mask_h == 0 or mask_w == 0:
            return np.empty((0, 2), dtype=np.int64)

        scale_y = level_h / mask_h
        scale_x = level_w / mask_w
        centre_y = (rows + 0.5) * scale_y
        centre_x = (cols + 0.5) * scale_x
        x0 = np.floor(centre_x - patch_size / 2).astype(np.int64)
        y0 = np.floor(centre_y - patch_size / 2).astype(np.int64)

        inside = (
            (x0 >= 0) & (y0 >= 0)
            & (x0 + patch_size <= level_w) & (y0 + patch_size <= level_h)
        )
        return np.stack([x0[inside], y0[inside]], axis=1)

    def _uniform_sample(self, slide_path, tissue_mask, n_samples, patch_size,
                        exclude=None):
        """Uniformly sample up to ``n_samples`` non-background tissue patches.

        Args:
            exclude: Optional iterable of ``(x, y)`` origins to skip (e.g.
                patches already taken by an earlier pass).

        Returns:
            ``(patches, coords)`` with ``coords`` in sampling-level pixels.
        """
        patches, coords = [], []
        sl = openslide.OpenSlide(slide_path)
        try:
            lv = sl.get_best_level_for_downsample(1)
            ds = sl.level_downsamples[lv]
            w, h = sl.level_dimensions[lv]

            origins = self._candidate_origins(tissue_mask, w, h, patch_size)
            if len(origins) == 0:
                return patches, coords

            np.random.shuffle(origins)
            seen = set(exclude) if exclude else set()
            max_attempts = min(len(origins), n_samples * 10)

            for x, y in origins[:max_attempts]:
                if len(patches) >= n_samples:
                    break
                x, y = int(x), int(y)
                if (x, y) in seen:
                    continue
                try:
                    region = sl.read_region(
                        (int(x * ds), int(y * ds)), lv, (patch_size, patch_size)
                    )
                except (openslide.OpenSlideError, OSError):
                    # Unreadable region (e.g. corrupt tile): skip it.
                    continue
                patch = np.array(region.convert("RGB"))

                # Reject near-white background patches.
                if np.mean(patch) < 220:
                    patches.append(patch)
                    coords.append((x, y))
                    seen.add((x, y))
        finally:
            # Always release the slide handle, even if reading fails.
            sl.close()

        return patches, coords

    def _compute_cluster_importance(self, features, labels, centers):
        """Score each cluster by size, spread and distance to other clusters.

        Returns a length-``n_clusters`` array normalised to sum to ~1. Empty
        clusters score 0.
        """
        importance = np.zeros(self.n_clusters)
        n_total = len(labels)

        for i in range(self.n_clusters):
            members = labels == i
            n_members = members.sum()
            if n_members == 0:
                continue

            # Factor 1: cluster size (coverage).
            size_score = n_members / n_total

            # Factor 2: within-cluster variance (heterogeneity).
            variance_score = np.var(features[members], axis=0).mean()

            # Factor 3: mean distance to the OTHER cluster centres (uniqueness).
            # The self-distance is excluded from the mean, not just zeroed.
            others = np.delete(centers, i, axis=0)
            uniqueness_score = (
                np.linalg.norm(others - centers[i], axis=1).mean() if len(others) else 0.0
            )

            importance[i] = 0.4 * size_score + 0.3 * variance_score + 0.3 * uniqueness_score

        return importance / (importance.sum() + 1e-8)

    def _proportional_sample(self, slide_path, tissue_mask, importance,
                            existing_coords, n_samples, patch_size):
        """Draw ``n_samples`` further patches, avoiding ``existing_coords``.

        Simplification: this is uniform sampling over tissue and does not yet
        weight clusters by ``importance``.
        """
        return self._uniform_sample(
            slide_path, tissue_mask, n_samples, patch_size, exclude=existing_coords
        )

# ============================================================
# MULTI-SEED STABILITY ANALYZER (NOVELTY 3)
# ============================================================
class StabilityAnalyzer:
    """
    Uncertainty-aware feature extraction:
    - Multi-seed sampling
    - Per-slide confidence intervals
    - Unstable slide detection
    """

    def __init__(self, seeds=None):
        # Default to a copy of the module-level seed list rather than sharing
        # the mutable list object as a default argument.
        self.seeds = list(MULTI_SEED_SAMPLING) if seeds is None else seeds

    def compute_stability(self, embeddings_list):
        """
        Compute stability across multiple samplings

        Args:
            embeddings_list: List of equal-length 1-D slide embeddings, one
                per sampling seed.

        Returns:
            stability_score: Mean pairwise cosine similarity (Python float)
            confidence_score: 1 / (1 + variance of those similarities)
                (Python float)

        With fewer than two embeddings there is nothing to compare, and
        ``(1.0, 1.0)`` is returned. The results are plain Python floats
        (computed in float64), so they stay JSON-serialisable even when the
        embeddings are float32.
        """
        if len(embeddings_list) < 2:
            return 1.0, 1.0

        emb = np.asarray(embeddings_list, dtype=np.float64)
        norms = np.linalg.norm(emb, axis=1)
        # Same formula as a pairwise loop: dot / (|a| * |b| + 1e-8).
        cosine = (emb @ emb.T) / (np.outer(norms, norms) + 1e-8)
        upper_rows, upper_cols = np.triu_indices(len(emb), k=1)
        similarities = cosine[upper_rows, upper_cols]

        stability_score = float(np.mean(similarities))
        confidence_score = float(1.0 / (1.0 + np.var(similarities)))
        return stability_score, confidence_score

    def flag_unstable_slides(self, stability_scores, threshold=0.85):
        """Return ``(slide_id, score)`` pairs below ``threshold``, least stable first.

        Scores are converted to Python floats for JSON serialisation.
        """
        unstable = [
            (slide_id, float(score))
            for slide_id, score in stability_scores.items()
            if score < threshold
        ]
        return sorted(unstable, key=lambda item: item[1])

# ============================================================
# HETEROGENEITY QUANTIFICATION (NOVELTY 4)
# ============================================================
class HeterogeneityAnalyzer:
    """
    Quantify slide-level heterogeneity:
    - Intra-slide diversity
    - Cluster count estimation
    - Correlation with clinical metadata (NOTE(review): not implemented in
      this class)
    """

    @staticmethod
    def _elbow_k(inertias, k_start=2):
        """Pick the elbow of a K-means inertia curve for k = k_start, k_start+1, ...

        Uses the point of maximum curvature (largest second difference). The
        largest first difference cannot be used: for the usual convex,
        decreasing inertia curve it is always the first step, which would pin
        the estimate to ``k_start``. Returns ``k_start`` when fewer than three
        inertias are available.
        """
        if len(inertias) < 3:
            return k_start
        curvature = np.diff(np.asarray(inertias, dtype=float), n=2)
        # curvature[j] is centred on k = k_start + j + 1.
        return int(k_start + 1 + np.argmax(curvature))

    def compute_heterogeneity(self, patch_features):
        """
        Compute heterogeneity score

        Args:
            patch_features: (N_patches, feature_dim)

        Returns:
            heterogeneity_score: Aggregated score
            metrics: Dict of individual metrics

        Fewer than 10 patches returns ``(0.0, {})``.
        """
        patch_features = np.asarray(patch_features)
        if len(patch_features) < 10:
            return 0.0, {}
        n_samples, n_dims = patch_features.shape

        # Metric 1: pairwise cosine distance.
        distances = pdist(patch_features, metric='cosine')
        mean_distance = np.mean(distances)
        std_distance = np.std(distances)

        # Metric 2: estimated number of clusters (elbow of the inertia curve).
        k_start = 2
        k_range = range(k_start, min(10, n_samples // 10))
        inertias = [
            KMeans(n_clusters=k, random_state=42, n_init=10).fit(patch_features).inertia_
            for k in k_range
        ]
        elbow_k = self._elbow_k(inertias, k_start=k_start)

        # Metric 3: silhouette of the elbow clustering. silhouette_score needs
        # 2 <= n_labels <= n_samples - 1, and duplicate features can make
        # K-means return fewer distinct clusters than requested.
        silhouette = 0.0
        if elbow_k < n_samples:
            kmeans = KMeans(n_clusters=elbow_k, random_state=42, n_init=10)
            labels = kmeans.fit_predict(patch_features)
            if 2 <= len(np.unique(labels)) < n_samples:
                silhouette = silhouette_score(patch_features, labels, metric='cosine')

        # Metric 4: entropy of the PCA explained-variance spectrum. PCA cannot
        # have more components than min(n_samples, n_features).
        pca = PCA(n_components=min(50, n_samples, n_dims))
        pca.fit(patch_features)
        variance_ratios = pca.explained_variance_ratio_
        entropy = -np.sum(variance_ratios * np.log(variance_ratios + 1e-8))

        metrics = {
            'mean_distance': float(mean_distance),
            'std_distance': float(std_distance),
            'estimated_clusters': int(elbow_k),
            'silhouette': float(silhouette),
            'feature_entropy': float(entropy)
        }

        # Higher distance, more clusters, lower silhouette and higher
        # spectral entropy all mean higher heterogeneity.
        heterogeneity_score = (
            0.3 * mean_distance +
            0.2 * (elbow_k / 10.0) +
            0.3 * (1.0 - silhouette) +
            0.2 * (entropy / 4.0)
        )

        return float(heterogeneity_score), metrics

# ============================================================
# NUCLEUS SEGMENTATION (Enhanced)
# ============================================================
class NucleusSegmenter:
    """Classical (non-learned) nucleus segmentation plus morphology statistics."""

    def extract_hematoxylin(self, rgb):
        """Return a min-max normalised uint8 map of blue-channel optical density.

        NOTE(review): blue-channel OD is used as a crude hematoxylin proxy
        rather than stain deconvolution (e.g. ``skimage.color.rgb2hed``).
        Confirm this is intended.
        """
        rgb = np.clip(rgb, 1, 255)/255.0
        od = -np.log(rgb + 1e-6)
        h = od[:, :, 2]
        return ((h - h.min()) / (h.max() - h.min() + 1e-8) * 255).astype(np.uint8)

    def segment(self, rgb):
        """Segment nuclei and return an integer label image (0 = background)."""
        h = self.extract_hematoxylin(rgb)
        h = gaussian(h, 1.0, preserve_range=True).astype(np.uint8)
        bin_ = cv2.adaptiveThreshold(h, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                     cv2.THRESH_BINARY, 11, 2)
        bin_ = remove_small_objects(bin_.astype(bool), 20)
        # Dilation then erosion (a closing) smooths object boundaries.
        bin_ = binary_dilation(bin_, disk(1))
        bin_ = binary_erosion(bin_, disk(1))
        dist = distance_transform_edt(bin_)
        maxima = maximum_filter(dist, footprint=np.ones((5, 5)))
        # Seeds are local maxima of the distance transform restricted to the
        # foreground. Without the mask, the background (dist == 0 == its own
        # local max) would become one large spurious marker.
        markers = label((dist == maxima) & bin_)
        return watershed(-dist, markers, mask=bin_)

    def features(self, labels, rgb):
        """Summarise the morphology of nuclei with 80 < area < 8000 pixels.

        ``rgb`` is currently unused. Returns ``None`` when no region passes
        the size filter.
        """
        props = regionprops(labels)
        if not props:
            return None

        valid_props = [p for p in props if 80 < p.area < 8000]
        if not valid_props:
            return None

        areas = np.array([p.area for p in valid_props])
        perimeters = np.array([p.perimeter for p in valid_props])
        # 4*pi*A / P^2: 1 for a perfect circle, smaller for irregular shapes.
        circularities = 4 * np.pi * areas / (perimeters ** 2 + 1e-8)
        eccentricities = np.array([p.eccentricity for p in valid_props])
        solidities = np.array([p.solidity for p in valid_props])

        feats = {
            "nuc_count": len(valid_props),
            "nuc_area_mean": areas.mean(),
            "nuc_area_std": areas.std(),
            "nuc_area_p25": np.percentile(areas, 25),
            "nuc_area_p75": np.percentile(areas, 75),
            "nuc_perimeter_mean": perimeters.mean(),
            "nuc_perimeter_std": perimeters.std(),
            "nuc_circularity_mean": circularities.mean(),
            "nuc_circularity_std": circularities.std(),
            "nuc_eccentricity_mean": eccentricities.mean(),
            "nuc_solidity_mean": solidities.mean(),
        }

        return feats

# ============================================================
# TEXTURE & ARCHITECTURE FEATURES
# ============================================================
class TextureFeatures:
    """Global grey-level texture and mean HSV colour descriptors for one tile."""

    @staticmethod
    def extract(rgb):
        """Return entropy/contrast of the grey image and mean H, S, V.

        The grey-level histogram uses fixed bins over [0, 1] (the output range
        of ``rgb2gray``), the same binning as the patch-QC information score.
        With data-dependent bin edges, a nearly uniform tile (e.g. grey values
        0.90-0.91) would be stretched across all 32 bins and report spuriously
        high entropy, making values incomparable across tiles.
        """
        g = rgb2gray(rgb)

        counts, _ = np.histogram(g, bins=32, range=(0.0, 1.0))
        arch = {
            "arch_entropy": stats.entropy(counts + 1e-8),
            "arch_contrast": np.std(g),
        }

        hsv = rgb2hsv(rgb)
        # NOTE(review): hue is circular (0 and 1 are the same colour), so a
        # linear mean can mislead for reddish tiles. Consider a circular mean.
        color = {
            "color_h_mean": hsv[:,:,0].mean(),
            "color_s_mean": hsv[:,:,1].mean(),
            "color_v_mean": hsv[:,:,2].mean(),
        }

        return {**arch, **color}

# ============================================================
# CTRANSPATH EXTRACTOR
# ============================================================
class CTransPathExtractor:
    """Patch embedder built on timm's ``swin_tiny_patch4_window7_224`` backbone.

    The constructor tries to load CTransPath weights from ``weights_path``. On
    any failure it falls back to timm's pretrained Swin-Tiny weights.
    """

    def __init__(self, weights_path=CTRANSPATH_WEIGHTS):
        log_msg("  Loading CTransPath model...")

        try:
            if os.path.exists(weights_path):
                checkpoint = torch.load(weights_path, map_location='cpu')
                if 'model' in checkpoint:
                    state_dict = checkpoint['model']
                elif 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']
                else:
                    state_dict = checkpoint

                self.model = timm.create_model(
                    "swin_tiny_patch4_window7_224",
                    pretrained=False,
                    num_classes=0,
                    global_pool='avg'
                )
                incompatible = self.model.load_state_dict(state_dict, strict=False)
                # strict=False silently skips keys that do not match the timm
                # layout. Report them, so a largely unloaded (randomly
                # initialised) backbone does not pass as "loaded".
                if incompatible.missing_keys or incompatible.unexpected_keys:
                    log_msg(
                        f"    ⚠️ Checkpoint key mismatch: "
                        f"{len(incompatible.missing_keys)} missing, "
                        f"{len(incompatible.unexpected_keys)} unexpected"
                    )
                log_msg("    ✅ CTransPath checkpoint loaded")
            else:
                raise FileNotFoundError(f"Weights not found: {weights_path}")
        except Exception as e:
            log_msg(f"    ⚠️ Failed to load checkpoint: {e}")
            log_msg("    ✅ Using pretrained Swin-Tiny fallback")
            self.model = timm.create_model(
                "swin_tiny_patch4_window7_224",
                pretrained=True,
                num_classes=0,
                global_pool='avg'
            )

        self.model = self.model.to(DEVICE).eval()

        self.tf = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def extract(self, tiles, batch_size=64):
        """Embed RGB tiles in batches.

        Args:
            tiles: Sequence of HxWx3 uint8 arrays.
            batch_size: Tiles per forward pass (values < 1 are treated as 1).

        Returns:
            ``(N_ok, D)`` array of embeddings for the tiles that could be
            processed, in input order, or ``None`` if none could. Tiles that
            fail preprocessing, and batches whose forward pass fails, are
            logged and skipped.
        """
        batch_size = max(1, int(batch_size))
        feats = []
        for start in range(0, len(tiles), batch_size):
            # Transform lazily per batch to keep memory bounded.
            batch = []
            for t in tiles[start:start + batch_size]:
                try:
                    batch.append(self.tf(Image.fromarray(t)))
                except Exception as e:
                    log_msg(f"    ⚠️ CTransPath extraction failed: {e}")
            if not batch:
                continue
            try:
                with torch.no_grad():
                    out = self.model(torch.stack(batch).to(DEVICE))
            except Exception as e:
                log_msg(f"    ⚠️ CTransPath extraction failed: {e}")
                continue
            feats.append(out.reshape(len(batch), -1).cpu().numpy())

        if not feats:
            return None

        return np.concatenate(feats, axis=0)

# ============================================================
# ENHANCED EVALUATION METRICS
# ============================================================
class EnhancedMetrics:
    """
    Comprehensive evaluation metrics:
    - Stability (multi-seed)
    - Redundancy (PCA effective dimension)
    - Separability (silhouette)
    - Robustness (perturbation tests)
    """

    @staticmethod
    def compute_redundancy(features, variance_threshold=0.95):
        """
        Redundancy via PCA effective dimension

        Args:
            features: (N, D) feature matrix.
            variance_threshold: Cumulative explained-variance fraction that
                defines the effective dimension (default 0.95).

        Returns:
            effective_dim: # components reaching ``variance_threshold`` (int)
            redundancy_ratio: 1 - effective_dim / D (float)

        Fewer than two rows returns ``(0, 0.0)``. Zero-variance input
        (explained-variance ratios undefined) returns ``(0, 1.0)``. If rounding
        keeps the cumulative ratio just below the threshold, all components
        are counted.
        """
        if len(features) < 2:
            return 0, 0.0

        pca = PCA()
        pca.fit(features)

        cumsum = np.cumsum(pca.explained_variance_ratio_)
        if not np.all(np.isfinite(cumsum)):
            return 0, 1.0

        reached = cumsum >= variance_threshold
        effective_dim = int(np.argmax(reached)) + 1 if reached.any() else len(cumsum)

        total_dim = features.shape[1]
        redundancy_ratio = float(1.0 - (effective_dim / total_dim))

        return effective_dim, redundancy_ratio

    @staticmethod
    def compute_separability(features, labels):
        """
        Separability using silhouette score

        Args:
            features: (N, D) feature matrix
            labels: (N,) label array

        Returns:
            silhouette: Silhouette score [-1, 1] (float). Returns 0.0 when it
                is undefined: fewer than 2 labels, fewer than 10 samples, or
                as many labels as samples.
        """
        labels = np.asarray(labels)
        n_labels = len(np.unique(labels))
        if n_labels < 2 or len(features) < 10 or n_labels >= len(features):
            return 0.0

        return float(silhouette_score(features, labels, metric='cosine'))

# ============================================================
# MAIN PIPELINE
# ============================================================
def _compute_tissue_mask(slide_path, thumb_factor=32):
    """Return a boolean tissue mask computed on a downsampled slide thumbnail.

    Tissue is HSV saturation above Otsu's threshold, with objects smaller
    than 50 pixels removed. The thumbnail is requested at roughly
    ``1/thumb_factor`` of the sampling level's size, to keep memory bounded.
    """
    sl = openslide.OpenSlide(slide_path)
    try:
        lv = sl.get_best_level_for_downsample(1)
        w, h = sl.level_dimensions[lv]
        thumb_size = (max(1, w // thumb_factor), max(1, h // thumb_factor))
        thumbnail_arr = np.array(sl.get_thumbnail(thumb_size))
    finally:
        # Release the slide even if reading the thumbnail fails.
        sl.close()

    hsv = rgb2hsv(thumbnail_arr)
    tissue_mask = hsv[:, :, 1] > threshold_otsu(hsv[:, :, 1])
    return remove_small_objects(tissue_mask, min_size=50)


def _add_tile_summary(row, per_tile_feats):
    """Add ``<name>_mean`` / ``<name>_std`` columns to ``row`` for each per-tile feature."""
    if not per_tile_feats:
        return
    df = pd.DataFrame(per_tile_feats)
    for c in df.columns:
        row[f"{c}_mean"] = df[c].mean()
        row[f"{c}_std"] = df[c].std()


def main(max_slides=10):
    """Run the enhanced feature-extraction pipeline on the ``.svs`` slides in SVS_DIR.

    For each of the first ``max_slides`` slides after a reproducible shuffle:
      1. build a thumbnail-resolution tissue mask;
      2. sample and embed patches once per seed in MULTI_SEED_SAMPLING, and
         score the agreement of the resulting slide embeddings (stability);
      3. on the first seed's patches, compute nucleus, texture,
         heterogeneity, mean-CTransPath and PCA-redundancy features.
    Writes ``enhanced_features.csv``, ``stability_analysis.json`` and
    ``qc_enhanced.csv`` under OUTPUT_DIR, and generates evaluation figures
    when at least one slide succeeds.

    Args:
        max_slides: Number of slides to process (default 10, the original
            demo limit).
    """
    # Sort before shuffling: os.listdir order is filesystem-dependent, so
    # shuffling it directly is not reproducible even with a fixed RNG seed.
    files = sorted(f for f in os.listdir(SVS_DIR) if f.lower().endswith('.svs'))

    if not files:
        log_msg("❌ No SVS files found!")
        return

    np.random.shuffle(files)
    selected = files[:max_slides]

    # Initialize modules
    adaptive_sampler = AdaptiveSampler(n_initial=200, n_final=TARGET_PATCHES, n_clusters=5)
    stability_analyzer = StabilityAnalyzer()
    heterogeneity_analyzer = HeterogeneityAnalyzer()
    nuc = NucleusSegmenter()
    tex = TextureFeatures()
    ctrans = CTransPathExtractor(CTRANSPATH_WEIGHTS)
    metrics = EnhancedMetrics()

    log_msg("\n" + "="*80)
    log_msg("STEP 1: ENHANCED FEATURE EXTRACTION")
    log_msg("="*80 + "\n")

    all_rows = []
    qc_rows = []
    stability_scores = {}
    heterogeneity_scores = {}

    for i, fn in enumerate(selected, 1):
        try:
            log_msg(f"[{i}/{len(selected)}] {fn}")

            slide_path = os.path.join(SVS_DIR, fn)
            tissue_mask = _compute_tissue_mask(slide_path)

            # MULTI-SEED SAMPLING for stability
            all_seed_embeddings = []
            # Sampling is driven by np.random.seed(seed), so seed 0's tiles
            # and features are kept for the full analysis below instead of
            # resampling and re-embedding them.
            first_tiles, first_features = None, None

            for seed_idx, seed in enumerate(MULTI_SEED_SAMPLING):
                np.random.seed(seed)

                log_msg(f"  Seed {seed_idx+1}/{len(MULTI_SEED_SAMPLING)}: {seed}")

                tiles, _ = adaptive_sampler.sample(slide_path, tissue_mask, PATCH_SIZE)

                if len(tiles) < 50:
                    log_msg(f"  ❌ Insufficient tiles for seed {seed}")
                    continue

                patch_features = ctrans.extract(tiles)
                if patch_features is None:
                    continue

                if seed_idx == 0:
                    first_tiles, first_features = tiles, patch_features

                # Aggregate to slide-level
                all_seed_embeddings.append(patch_features.mean(axis=0))

            if len(all_seed_embeddings) < 2:
                log_msg(f"  ❌ Failed to extract features across seeds")
                qc_rows.append({'slide': fn, 'status': 'fail', 'reason': 'multi-seed failure'})
                continue

            stability, confidence = stability_analyzer.compute_stability(all_seed_embeddings)
            # Plain floats keep the JSON dump below serialisable (NumPy
            # float32 scalars are not).
            stability, confidence = float(stability), float(confidence)
            stability_scores[fn] = stability

            log_msg(f"  ✅ Stability: {stability:.4f}, Confidence: {confidence:.4f}")

            if first_tiles is None:
                # Seed 0 was skipped above (too few tiles or failed
                # extraction); resample it so the analysis still uses it.
                np.random.seed(MULTI_SEED_SAMPLING[0])
                tiles, _ = adaptive_sampler.sample(slide_path, tissue_mask, PATCH_SIZE)
                patch_features = ctrans.extract(tiles)
            else:
                tiles, patch_features = first_tiles, first_features

            row = {"slide": fn, "stability": stability, "confidence": confidence}

            # Nucleus morphology (first 100 tiles for speed)
            nuc_feats = []
            for t in tiles[:100]:
                f = nuc.features(nuc.segment(t), t)
                if f:
                    nuc_feats.append(f)
            _add_tile_summary(row, nuc_feats)

            # Texture features (first 100 tiles)
            _add_tile_summary(row, [tex.extract(t) for t in tiles[:100]])

            # Defaults, so that a failed CTransPath pass neither raises
            # NameError nor reuses the previous slide's values below.
            heterogeneity = float('nan')
            redundancy = float('nan')

            if patch_features is not None:
                heterogeneity, het_metrics = heterogeneity_analyzer.compute_heterogeneity(patch_features)
                heterogeneity_scores[fn] = heterogeneity
                row['heterogeneity'] = heterogeneity
                for k, v in het_metrics.items():
                    row[f'het_{k}'] = v

                for j, x in enumerate(patch_features.mean(axis=0)):
                    row[f"ctrans_mean_{j}"] = float(x)

                eff_dim, redundancy = metrics.compute_redundancy(patch_features)
                row['pca_effective_dim'] = eff_dim
                row['redundancy_ratio'] = redundancy

            all_rows.append(row)
            qc_rows.append({'slide': fn, 'status': 'ok', 'tiles': len(tiles),
                           'stability': stability, 'heterogeneity': heterogeneity})

            log_msg(f"  ✅ Complete | Heterogeneity: {heterogeneity:.3f} | Redundancy: {redundancy:.3f}")

        except Exception as e:
            # Per-slide boundary: record the failure and continue with the rest.
            log_msg(f"  ❌ Error: {e}")
            traceback.print_exc()
            qc_rows.append({'slide': fn, 'status': 'fail', 'reason': str(e)})
            continue

    # Save results
    log_msg("\n" + "="*80)
    log_msg("SAVING ENHANCED RESULTS")
    log_msg("="*80 + "\n")

    if all_rows:
        df = pd.DataFrame(all_rows)
        df.to_csv(f"{OUTPUT_DIR}/enhanced_features.csv", index=False)
        log_msg(f"✅ Features saved: {len(df)} slides × {len(df.columns)} features")

        with open(f"{OUTPUT_DIR}/stability_analysis.json", 'w') as f:
            json.dump({
                'stability_scores': stability_scores,
                'heterogeneity_scores': heterogeneity_scores,
                'unstable_slides': stability_analyzer.flag_unstable_slides(stability_scores),
                'mean_stability': float(np.mean(list(stability_scores.values()))),
                'std_stability': float(np.std(list(stability_scores.values())))
            }, f, indent=2)
        log_msg(f"✅ Stability analysis saved")

        generate_evaluation_figures(df, stability_scores, heterogeneity_scores)

    # Save QC
    qc_df = pd.DataFrame(qc_rows)
    qc_df.to_csv(f"{OUTPUT_DIR}/qc_enhanced.csv", index=False)
    log_msg(f"✅ QC saved: {OUTPUT_DIR}/qc_enhanced.csv")

    log_msg("\n" + "="*80)
    log_msg("✅ ENHANCED PIPELINE COMPLETE")
    log_msg(f"✅ Output directory: {OUTPUT_DIR}")
    log_msg("="*80 + "\n")

# ============================================================
# EVALUATION FIGURE GENERATION
# ============================================================
def _plot_score_histogram(ax, values, color, mean_color, xlabel, title):
    """Draw a 20-bin histogram of ``values`` with a dashed mean line on ``ax``.

    The mean line and legend are skipped when ``values`` is empty, so an empty
    score mapping produces an empty panel instead of a NaN line and an
    empty-legend warning.
    """
    ax.hist(values, bins=20, color=color, alpha=0.7, edgecolor='black')
    if len(values) > 0:
        mean_val = np.mean(values)
        ax.axvline(mean_val, color=mean_color, linestyle='--', linewidth=2,
                   label=f'Mean: {mean_val:.3f}')
        ax.legend()
    ax.set_xlabel(xlabel, fontsize=11, fontweight='bold')
    ax.set_ylabel('Frequency', fontsize=11, fontweight='bold')
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3)


def generate_evaluation_figures(df, stability_scores, heterogeneity_scores):
    """Generate comprehensive evaluation figures into FIGURES_DIR.

    Writes:
      * ``01_enhanced_metrics.png`` -- 2x2 panel: stability histogram,
        heterogeneity histogram, stability-vs-heterogeneity scatter (with the
        Pearson correlation when it is defined), and a histogram of
        ``df['pca_effective_dim']`` when that column exists.
      * ``02_confidence_analysis.png`` -- per-slide bar chart of
        ``df['confidence']``, only when that column exists.

    Args:
        df: Per-slide feature table (pandas DataFrame).
        stability_scores: Mapping slide -> stability score.
        heterogeneity_scores: Mapping slide -> heterogeneity score.
    """
    log_msg("\nGenerating evaluation figures...")
    # FIGURES_DIR is a subdirectory of OUTPUT_DIR; create it before saving.
    Path(FIGURES_DIR).mkdir(parents=True, exist_ok=True)

    # Figure 1: Stability / heterogeneity / redundancy overview
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Panels A and B: score distributions
    _plot_score_histogram(axes[0, 0], list(stability_scores.values()),
                          'steelblue', 'red', 'Stability Score',
                          'A. Multi-Seed Stability Distribution')
    _plot_score_histogram(axes[0, 1], list(heterogeneity_scores.values()),
                          'coral', 'darkred', 'Heterogeneity Score',
                          'B. Slide-Level Heterogeneity')

    # Panel C: Stability vs Heterogeneity, restricted to slides present in both
    ax = axes[1, 0]
    common_slides = [s for s in stability_scores if s in heterogeneity_scores]
    stab = [stability_scores[s] for s in common_slides]
    het = [heterogeneity_scores[s] for s in common_slides]
    ax.scatter(het, stab, s=100, alpha=0.6, c='purple', edgecolors='black')

    # Correlation: require > 2 points, and skip constant series where r is undefined
    if len(stab) > 2 and np.std(het) > 0 and np.std(stab) > 0:
        corr = np.corrcoef(het, stab)[0, 1]
        ax.text(0.05, 0.95, f'Correlation: {corr:.3f}',
                transform=ax.transAxes, fontsize=10, va='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

    ax.set_xlabel('Heterogeneity Score', fontsize=11, fontweight='bold')
    ax.set_ylabel('Stability Score', fontsize=11, fontweight='bold')
    ax.set_title('C. Stability vs Heterogeneity', fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3)

    # Panel D: PCA Effective Dimension
    ax = axes[1, 1]
    if 'pca_effective_dim' in df.columns:
        dims = df['pca_effective_dim'].dropna()
        ax.hist(dims, bins=15, color='gold', alpha=0.7, edgecolor='black')
        ax.axvline(dims.mean(), color='darkgreen', linestyle='--', linewidth=2,
                   label=f'Mean: {dims.mean():.1f}')
        ax.set_xlabel('Effective Dimensions', fontsize=11, fontweight='bold')
        ax.set_ylabel('Frequency', fontsize=11, fontweight='bold')
        ax.set_title('D. Feature Redundancy (PCA)', fontsize=12, fontweight='bold')
        ax.legend()
    else:
        ax.text(0.5, 0.5, 'No PCA data', ha='center', va='center',
                fontsize=14, transform=ax.transAxes)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f"{FIGURES_DIR}/01_enhanced_metrics.png", dpi=300, bbox_inches='tight')
    plt.close(fig)

    log_msg(f"  ✅ Saved: 01_enhanced_metrics.png")

    # Figure 2: Confidence Analysis. The figure is only created when there is
    # data to draw, so no empty, never-closed figure is left open.
    if 'confidence' in df.columns:
        fig, ax = plt.subplots(figsize=(12, 6))
        conf_values = df['confidence'].dropna()
        # >0.9 green, >0.8 orange, otherwise red
        colors = ['green' if c > 0.9 else 'orange' if c > 0.8 else 'red' for c in conf_values]

        ax.bar(range(len(conf_values)), conf_values, color=colors, alpha=0.7, edgecolor='black')
        ax.axhline(0.9, color='green', linestyle='--', linewidth=2, alpha=0.5, label='High confidence')
        ax.axhline(0.8, color='orange', linestyle='--', linewidth=2, alpha=0.5, label='Medium confidence')

        ax.set_xlabel('Slide Index', fontsize=12, fontweight='bold')
        ax.set_ylabel('Confidence Score', fontsize=12, fontweight='bold')
        ax.set_title('Per-Slide Confidence Scores (Uncertainty Quantification)',
                     fontsize=13, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        plt.savefig(f"{FIGURES_DIR}/02_confidence_analysis.png", dpi=300, bbox_inches='tight')
        plt.close(fig)

        log_msg(f"  ✅ Saved: 02_confidence_analysis.png")

    log_msg("✅ Evaluation figures complete\n")

# Script entry point. Inside a Jupyter cell __name__ is also "__main__", so
# executing the cell runs main() immediately.
if __name__ == "__main__":
    main()

ENHANCED WSI FEATURE EXTRACTION FRAMEWORK
Features: Multi-Seed Stability | Uncertainty Quantification | Heterogeneity
Device: cpu
Multi-seed sampling: 5 seeds
Output: ENHANCED_WSI_FRAMEWORK

  Loading CTransPath model...
    ✅ CTransPath checkpoint loaded

STEP 1: ENHANCED FEATURE EXTRACTION

[1/10] YG_P8W7SBCME4VH_wsi.svs
  Seed 1/5: 42
  Adaptive sampling: 200 initial → 1000 final
  ❌ Insufficient tiles for seed 42
  Seed 2/5: 123
  Adaptive sampling: 200 initial → 1000 final
  ❌ Insufficient tiles for seed 123
  Seed 3/5: 456
  Adaptive sampling: 200 initial → 1000 final
  ❌ Insufficient tiles for seed 456
  Seed 4/5: 789
  Adaptive sampling: 200 initial → 1000 final
  ❌ Insufficient tiles for seed 789
  Seed 5/5: 1011
  Adaptive sampling: 200 initial → 1000 final
  ❌ Insufficient tiles for seed 1011
  ❌ Failed to extract features across seeds
[2/10] YG_3OAF908JG3XG_wsi.svs
  Seed 1/5: 42
  Adaptive sampling: 200 initial → 1000 final
  ❌ Insufficient tiles for seed 42
  Seed 2/5: 12

In [None]:
# ============================================================
# WSI-FEATUREQC: COMPLETE FEATURE EXTRACTION & EVALUATION FRAMEWORK
# Based on: "A Modular Framework for Robust Feature Extraction from WSI"
# 
# Components:
# 1. Data-driven parameter optimization (5 methods)
# 2. Multi-seed stability analysis
# 3. Advanced patch quality control
# 4. Multiple encoder support (ResNet50, DINO, CTransPath)
# 5. Comprehensive evaluation metrics (Stability, Redundancy, Separability, Robustness)
# 6. Framework comparison (vs CLAM, Naive CNN, Radiomics)
# 7. Complete visualization suite (20+ figures)
# ============================================================

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import numpy as np
import pandas as pd
import openslide
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
from skimage.filters import threshold_otsu, laplace, gaussian
from skimage.morphology import (remove_small_objects, binary_dilation, 
                                 binary_erosion, disk, binary_closing)
from skimage.segmentation import watershed
from skimage.color import rgb2hsv, rgb2gray
from skimage.measure import regionprops, label, shannon_entropy
from scipy.ndimage import distance_transform_edt, maximum_filter
from scipy.spatial.distance import pdist, squareform, cosine
from scipy import stats
from scipy.stats import wilcoxon
import json
from datetime import datetime
import warnings
import timm
import traceback
import cv2
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score, davies_bouldin_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import LeaveOneOut
from sklearn.cluster import KMeans
import h5py
from tqdm import tqdm

# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

# ===============================
# GLOBAL CONFIG
# ===============================
# Seed the global NumPy and PyTorch RNGs (CPU and, if present, all CUDA
# devices). Python's built-in `random` module is not seeded here.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)
    # NOTE(review): cudnn.benchmark is not explicitly disabled here; confirm
    # whether bit-exact GPU reproducibility is required.
    torch.backends.cudnn.deterministic = True

# Paths (machine-specific absolute Windows paths; adjust per environment)
SVS_DIR = r"C:\Users\Shahinur\Downloads\PKG_Dataset\PKG - Brain-Mets-Lung-MRI-Path-Segs_histopathology images\data"
CTRANSPATH_WEIGHTS = r"D:\paper\weights\ctranspath.pth"
OUTPUT_DIR = "WSI_FEATUREQC_COMPLETE"
FIGURES_DIR = f"{OUTPUT_DIR}/figures"
FEATURES_DIR = f"{OUTPUT_DIR}/features"

# Create directories. OUTPUT_DIR comes first in the list, so the nested
# subdirectories can be created without parents=True.
for d in [OUTPUT_DIR, FIGURES_DIR, FEATURES_DIR]:
    Path(d).mkdir(exist_ok=True)

# Use CUDA when available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Multi-seed sampling for stability
RANDOM_SEEDS = [42, 123, 456, 789, 1011]

# Plotting: seaborn grid style; 100 dpi on screen, 300 dpi when saving.
sns.set_style("whitegrid")
plt.rcParams['figure.dpi'] = 100
plt.rcParams['savefig.dpi'] = 300

# Startup banner
print("="*80)
print("WSI-FEATUREQC: COMPLETE FEATURE EXTRACTION & EVALUATION FRAMEWORK")
print("="*80)
print(f"Device: {DEVICE}")
print(f"Multi-seed sampling: {len(RANDOM_SEEDS)} seeds")
print(f"Output: {OUTPUT_DIR}\n")

def log_msg(msg):
    """Print ``msg`` with a timestamp and append the same line to ``OUTPUT_DIR/pipeline.log``.

    Writing the log file is best-effort: if it cannot be written (missing
    directory, permissions, ...) the message is still printed and the
    pipeline continues.

    Args:
        msg: Message text; may contain non-ASCII characters such as emoji.
    """
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    line = f"[{timestamp}] {msg}"
    print(line)
    try:
        # Explicit UTF-8: messages contain emoji (✅, ⚠️, ❌) that the default
        # Windows code page cannot encode. Previously that UnicodeEncodeError
        # was swallowed by a bare `except`, silently dropping those lines.
        with open(f"{OUTPUT_DIR}/pipeline.log", 'a', encoding='utf-8') as f:
            f.write(f"{line}\n")
    except (OSError, UnicodeError):
        # Deliberately best-effort; console output above already happened.
        pass


# ============================================================
# PART 1: DATA-DRIVEN PARAMETER OPTIMIZER
# ============================================================

class MultiMethodOptimizer:
    """
    Optimize preprocessing parameters using 5 complementary methods:
    1. Elbow method for tile count
    2. Youden's J for blur threshold
    3. Multi-method consensus (percentile / Otsu / mean-std) for tissue threshold
    4. Bootstrap for parameter stability
    5. Per-channel RGB mean/std statistics for stain normalization

    Every method scans a few of the calibration slides on a regular,
    non-overlapping tile grid at the pyramid level closest to downsample 1,
    skips mostly-white tiles, and returns a fixed default when too few usable
    tiles are found.

    Args:
        calibration_slides: Sequence of slide paths readable by OpenSlide.
        n_samples: Stored as ``self.n_samples``; not read by any method of
            this class.

    Attributes:
        results: Per-method summary dicts written by the optimize_*/compute_*
            methods and serialized by ``save_results``.
        calibration_data: Raw measurements collected during optimization.
    """

    def __init__(self, calibration_slides, n_samples=300):
        self.slides = calibration_slides
        self.n_samples = n_samples
        self.results = {}
        self.calibration_data = {
            'tile_counts': [],
            'tile_variances': [],
            'blur_scores': [],
            'tissue_percentages': [],
            'bootstrap_samples': [],
            'stain_means': [],
            'stain_stds': []
        }

    # ------------------------------------------------------------------
    # Tile-level helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _iter_grid_tiles(slide, tile_size):
        """Yield tiles from an already-open slide on a non-overlapping grid, row by row.

        Tiles are read at ``slide.get_best_level_for_downsample(1)`` and yielded
        as RGB numpy arrays. Only tiles lying fully inside that level are
        produced, including the last full row/column.
        """
        level = slide.get_best_level_for_downsample(1)
        downsample = slide.level_downsamples[level]
        width, height = slide.level_dimensions[level]
        # `+ 1` so a tile starting exactly at `height - tile_size`
        # (or `width - tile_size`) is still included.
        for y in range(0, height - tile_size + 1, tile_size):
            for x in range(0, width - tile_size + 1, tile_size):
                region = slide.read_region(
                    (int(x * downsample), int(y * downsample)),  # level-0 coordinates
                    level,
                    (tile_size, tile_size),
                )
                yield np.array(region.convert("RGB"))

    @staticmethod
    def _mean_variance_curve(tile_counts, variances):
        """Average the variance values that share the same tile count.

        Args:
            tile_counts: Tile counts, possibly repeated (one curve per slide).
            variances: Variance measured at the matching ``tile_counts`` entry.

        Returns:
            ``(counts, mean_variances)`` numpy arrays sorted by count (both
            empty when no data was given).
        """
        grouped = {}
        for n, var in zip(tile_counts, variances):
            grouped.setdefault(n, []).append(var)
        counts = sorted(grouped)
        return np.array(counts), np.array([np.mean(grouped[n]) for n in counts])

    def _is_background(self, tile):
        """Check if tile is mostly background (mean intensity above 220)."""
        return np.mean(tile) > 220

    def _compute_blur(self, tile):
        """Compute a sharpness score: Laplacian variance, plus 10x the mean
        gradient magnitude when the Laplacian variance is below 10."""
        gray = rgb2gray(tile)
        lap_var = laplace(gray).var()
        grad_y, grad_x = np.gradient(gray)  # compute the gradient once
        grad = np.sqrt(grad_y**2 + grad_x**2).mean()
        return lap_var + (grad * 10 if lap_var < 10 else 0)

    def _compute_tissue_mask(self, tile):
        """Generate a tissue mask: Otsu on mean intensity (fixed 200 for flat
        tiles), drop objects < 500 px, then dilate with a radius-3 disk."""
        gray = np.mean(tile, axis=2)
        threshold = threshold_otsu(gray) if gray.std() > 1 else 200
        mask = gray < threshold
        mask = remove_small_objects(mask, min_size=500)
        mask = binary_dilation(mask, disk(3))
        return mask

    # ------------------------------------------------------------------
    # Method 1
    # ------------------------------------------------------------------
    def optimize_tile_count(self, tile_size=224, max_tiles=250):
        """Method 1: Elbow method for optimal tile count.

        For each of the first 3 slides, up to ``max_tiles`` foreground tiles
        with >= 10% tissue are collected, and the pixel-wise variance of the
        mean grayscale tile is measured after 25, 50, ... tiles. The per-slide
        curves are averaged per tile count, the elbow is located on the
        averaged curve, and the result is scaled by 4 and clamped to [500, 2000].

        Returns:
            int: Optimal tile count (1000 when too little data is available).
        """
        log_msg("Optimizing tile count using Elbow method...")

        tile_counts, variances = [], []

        for slide_path in self.slides[:3]:  # Use first 3 slides
            try:
                tiles = []
                with openslide.OpenSlide(slide_path) as slide:
                    for tile in self._iter_grid_tiles(slide, tile_size):
                        if len(tiles) >= max_tiles:
                            break
                        if self._is_background(tile):
                            continue
                        mask = self._compute_tissue_mask(tile)
                        if mask.sum() / mask.size >= 0.1:
                            tiles.append(rgb2gray(tile).flatten())

                if len(tiles) < 50:
                    continue

                # Compute variance at different tile counts
                tiles_array = np.array(tiles)
                for n in range(25, min(max_tiles, len(tiles)) + 1, 25):
                    variances.append(np.var(np.mean(tiles_array[:n], axis=0)))
                    tile_counts.append(n)

            except Exception as e:
                log_msg(f"  Warning: {e}")
                continue

        # The raw lists hold one curve per slide back to back; average them per
        # tile count so the derivative is never taken across a slide boundary.
        counts, mean_variances = self._mean_variance_curve(tile_counts, variances)
        if len(counts) < 3:
            log_msg("  Insufficient data, using default: 1000")
            return 1000

        # Store calibration data (raw, per-slide curves concatenated)
        self.calibration_data['tile_counts'] = tile_counts
        self.calibration_data['tile_variances'] = variances

        # Find elbow using second derivative.
        # NOTE(review): the textbook elbow is the point of maximum curvature
        # (argmax of the second derivative); argmin(|f''|) selects the
        # flattest-curvature point instead. Kept as-is - confirm intent.
        second_deriv = np.gradient(np.gradient(mean_variances))
        optimal_idx = int(np.argmin(np.abs(second_deriv)))
        elbow_point = int(counts[optimal_idx])
        optimal = max(500, min(elbow_point * 4, 2000))  # Scale to realistic range

        self.results['tile_count'] = {
            'optimal': optimal,
            'method': 'elbow',
            'elbow_point': elbow_point
        }

        log_msg(f"  ✅ Optimal tile count: {optimal}")
        return optimal

    # ------------------------------------------------------------------
    # Method 2
    # ------------------------------------------------------------------
    def optimize_blur_threshold(self, tile_size=224):
        """Method 2: Youden's J for blur threshold.

        Collects up to 500 (blur score, tissue ratio) pairs from foreground
        tiles of the first 4 slides. Tiles with < 5% tissue form the
        "background" group and tiles with >= 30% tissue the "good tissue"
        group. The 1st-19th percentiles of the blur scores are tried as
        thresholds, and the one maximising J = sensitivity + specificity - 1 is
        kept (background below the threshold, good tissue at or above it). If
        either group has fewer than 10 tiles, the 5th percentile is used.

        Returns:
            float: Blur threshold (50.0 when fewer than 100 tiles were scored).
        """
        log_msg("Optimizing blur threshold using Youden's J...")

        max_scores = 500
        blur_scores = []
        tissue_ratios = []

        for slide_path in self.slides[:4]:
            if len(blur_scores) >= max_scores:
                break
            try:
                with openslide.OpenSlide(slide_path) as slide:
                    for tile in self._iter_grid_tiles(slide, tile_size):
                        if len(blur_scores) >= max_scores:
                            break
                        if self._is_background(tile):
                            continue
                        # Compute both values before appending so an exception
                        # cannot leave the two lists with different lengths.
                        blur = self._compute_blur(tile)
                        mask = self._compute_tissue_mask(tile)
                        blur_scores.append(blur)
                        tissue_ratios.append(mask.sum() / mask.size)

            except Exception as e:
                log_msg(f"  Warning: {e}")
                continue

        if len(blur_scores) < 100:
            log_msg("  Insufficient data, using default: 50")
            return 50.0

        # Store calibration data
        self.calibration_data['blur_scores'] = blur_scores

        # Youden's J optimization
        blur_array = np.array(blur_scores)
        tissue_array = np.array(tissue_ratios)

        # Define background and good tissue
        background = tissue_array < 0.05
        good_tissue = tissue_array >= 0.3

        if background.sum() < 10 or good_tissue.sum() < 10:
            optimal = float(np.percentile(blur_array, 5))
        else:
            # Test thresholds
            test_thresholds = np.percentile(blur_array, np.arange(1, 20))
            j_scores = []

            for threshold in test_thresholds:
                sensitivity = (blur_array[background] < threshold).sum() / (background.sum() + 1e-8)
                specificity = (blur_array[good_tissue] >= threshold).sum() / (good_tissue.sum() + 1e-8)
                j_scores.append(sensitivity + specificity - 1)

            optimal = float(test_thresholds[np.argmax(j_scores)])

        self.results['blur_threshold'] = {
            'optimal': optimal,
            'method': 'youden_j'
        }

        log_msg(f"  ✅ Blur threshold: {optimal:.2f}")
        return optimal

    # ------------------------------------------------------------------
    # Method 3
    # ------------------------------------------------------------------
    def optimize_tissue_threshold(self, tile_size=224):
        """Method 3: Multi-method consensus for tissue threshold.

        Collects up to 600 tissue ratios from foreground tiles of the first 5
        slides. The threshold is the median of three estimates, clamped to
        [0.3, 0.7]:

        - A: the 15th percentile;
        - B: the Otsu threshold of the ratio distribution (50 bins);
        - C: mean - std, floored at 0.1.

        Returns:
            float: Tissue threshold (0.5 when fewer than 100 tiles were found).
        """
        log_msg("Optimizing tissue threshold using multi-method consensus...")

        max_ratios = 600
        tissue_ratios = []

        for slide_path in self.slides[:5]:
            if len(tissue_ratios) >= max_ratios:
                break
            try:
                with openslide.OpenSlide(slide_path) as slide:
                    for tile in self._iter_grid_tiles(slide, tile_size):
                        if len(tissue_ratios) >= max_ratios:
                            break
                        if not self._is_background(tile):
                            mask = self._compute_tissue_mask(tile)
                            tissue_ratios.append(mask.sum() / mask.size)

            except Exception as e:
                log_msg(f"  Warning: {e}")
                continue

        if len(tissue_ratios) < 100:
            log_msg("  Insufficient data, using default: 0.5")
            return 0.5

        # Store calibration data
        self.calibration_data['tissue_percentages'] = tissue_ratios

        tissue_array = np.array(tissue_ratios)

        # Method A: 15th percentile
        method_a = float(np.percentile(tissue_array, 15))

        # Method B: Otsu threshold on the tissue-ratio distribution (50 bins);
        # a constant distribution has no split, so use its single value.
        if tissue_array.max() > tissue_array.min():
            method_b = float(threshold_otsu(tissue_array, nbins=50))
        else:
            method_b = float(tissue_array[0])

        # Method C: Mean - std
        method_c = float(max(0.1, tissue_array.mean() - tissue_array.std()))

        # Consensus
        consensus = float(np.median([method_a, method_b, method_c]))
        consensus = max(0.3, min(consensus, 0.7))

        self.results['tissue_threshold'] = {
            'optimal': float(consensus),
            'method_a': method_a,
            'method_b': method_b,
            'method_c': method_c,
            'method': 'consensus'
        }

        log_msg(f"  ✅ Tissue threshold: {consensus:.3f}")
        return consensus

    # ------------------------------------------------------------------
    # Method 4
    # ------------------------------------------------------------------
    def compute_bootstrap_stability(self, tile_size=224, n_iterations=50):
        """Method 4: Bootstrap for parameter stability.

        Collects up to 200 blur scores from the first 2 slides and bootstraps
        their 5th percentile ``n_iterations`` times using NumPy's global RNG.

        Returns:
            tuple: (mean, std) of the bootstrapped percentile; (50.0, 0.0) when
            fewer than 50 scores were collected.
        """
        log_msg("Computing bootstrap stability...")

        max_scores = 200
        blur_scores = []

        for slide_path in self.slides[:2]:
            if len(blur_scores) >= max_scores:
                break
            try:
                with openslide.OpenSlide(slide_path) as slide:
                    for tile in self._iter_grid_tiles(slide, tile_size):
                        if len(blur_scores) >= max_scores:
                            break
                        if not self._is_background(tile):
                            blur_scores.append(self._compute_blur(tile))

            except Exception as e:
                log_msg(f"  Warning: {e}")
                continue

        if len(blur_scores) < 50:
            return 50.0, 0.0

        # Bootstrap resampling
        blur_array = np.array(blur_scores)
        bootstrap_samples = [
            np.percentile(np.random.choice(blur_array, len(blur_array), replace=True), 5)
            for _ in range(n_iterations)
        ]

        mean_val = np.mean(bootstrap_samples)
        std_val = np.std(bootstrap_samples)

        self.calibration_data['bootstrap_samples'] = bootstrap_samples

        self.results['bootstrap'] = {
            'mean': float(mean_val),
            'std': float(std_val),
            'cv': float(std_val / mean_val) if mean_val > 0 else 0
        }

        log_msg(f"  ✅ Bootstrap: {mean_val:.2f} ± {std_val:.2f}")
        return mean_val, std_val

    # ------------------------------------------------------------------
    # Method 5
    # ------------------------------------------------------------------
    def compute_stain_statistics(self, tile_size=224):
        """Method 5: Stain normalization targets.

        Averages the per-tile RGB channel mean and std (on [0, 1] scale) over up
        to 200 tiles with >= 30% tissue from the first 3 slides.

        Returns:
            tuple: (means, stds) as length-3 numpy arrays; fixed default values
            when fewer than 20 tiles qualify.
        """
        log_msg("Computing stain normalization statistics...")

        max_tiles = 200
        tiles = []

        for slide_path in self.slides[:3]:
            if len(tiles) >= max_tiles:
                break
            try:
                with openslide.OpenSlide(slide_path) as slide:
                    for tile in self._iter_grid_tiles(slide, tile_size):
                        if len(tiles) >= max_tiles:
                            break
                        if self._is_background(tile):
                            continue
                        mask = self._compute_tissue_mask(tile)
                        if mask.sum() / mask.size >= 0.3:
                            tiles.append(tile.astype(np.float32) / 255.0)

            except Exception as e:
                log_msg(f"  Warning: {e}")
                continue

        if len(tiles) < 20:
            # Default H&E values
            means = np.array([0.7, 0.55, 0.65])
            stds = np.array([0.15, 0.15, 0.15])
        else:
            means = np.mean([t.mean(axis=(0, 1)) for t in tiles], axis=0)
            stds = np.mean([t.std(axis=(0, 1)) for t in tiles], axis=0)

            self.calibration_data['stain_means'] = means.tolist()
            self.calibration_data['stain_stds'] = stds.tolist()

        self.results['stain_normalization'] = {
            'means': means.tolist(),
            'stds': stds.tolist()
        }

        log_msg(f"  ✅ Stain stats: RGB means = {means.round(3)}")
        return means, stds

    # ------------------------------------------------------------------
    # Orchestration / persistence
    # ------------------------------------------------------------------
    def run_full_optimization(self, tile_size=224):
        """Run all 5 optimization methods and return the resulting parameter dict.

        The bootstrap step is run for its side effect on ``self.results`` /
        ``self.calibration_data``; its values are not part of the returned dict.
        """
        log_msg("\n" + "="*80)
        log_msg("RUNNING FULL PARAMETER OPTIMIZATION")
        log_msg("="*80 + "\n")

        n_tiles = self.optimize_tile_count(tile_size)
        blur_th = self.optimize_blur_threshold(tile_size)
        tissue_th = self.optimize_tissue_threshold(tile_size)
        self.compute_bootstrap_stability(tile_size)
        stain_mean, stain_std = self.compute_stain_statistics(tile_size)

        return {
            'tile_size': tile_size,
            'n_tiles': n_tiles,
            'blur_threshold': blur_th,
            'tissue_threshold': tissue_th,
            'stain_means': stain_mean.tolist(),
            'stain_stds': stain_std.tolist()
        }

    def save_results(self, output_dir):
        """Write ``optimization_results.json`` and ``calibration_data.json`` into ``output_dir``."""
        # Save optimization results
        with open(f"{output_dir}/optimization_results.json", 'w') as f:
            json.dump({
                'timestamp': datetime.now().isoformat(),
                'random_seed': RANDOM_SEED,
                **self.results
            }, f, indent=2)

        # Save calibration data, converting numpy arrays to plain lists
        calib_data_serializable = {}
        for key, value in self.calibration_data.items():
            if isinstance(value, np.ndarray):
                calib_data_serializable[key] = value.tolist()
            elif isinstance(value, list) and len(value) > 0 and isinstance(value[0], np.ndarray):
                calib_data_serializable[key] = [v.tolist() for v in value]
            else:
                calib_data_serializable[key] = value

        with open(f"{output_dir}/calibration_data.json", 'w') as f:
            json.dump(calib_data_serializable, f, indent=2)

        log_msg("✅ Optimization results saved")


# ============================================================
# PART 2: ADVANCED PATCH QUALITY CONTROL
# ============================================================

class AdvancedPatchQC:
    """
    Multi-factor patch quality assessment:
    - Tissue content
    - Focus quality (blur)
    - Artifact detection (pen marks, folds)
    - Information content (entropy, edges)
    - Representativeness (color distance to a slide-level centroid)

    Args:
        weights: Optional mapping from the five factor names ('tissue',
            'focus', 'artifact', 'information', 'representativeness') to
            weights for the final weighted sum. A falsy value selects the
            built-in defaults.
    """

    def __init__(self, weights=None):
        self.weights = weights or {
            'tissue': 0.25,
            'focus': 0.20,
            'artifact': 0.20,
            'information': 0.20,
            'representativeness': 0.15
        }

    @staticmethod
    def _edge_fraction(gray):
        """Return the fraction of pixels Canny (thresholds 50/150) marks as edges.

        ``gray`` is assumed to be a float grayscale image in [0, 1].
        """
        edges = cv2.Canny((gray * 255).astype(np.uint8), 50, 150)
        # Canny marks edge pixels with 255 rather than 1; summing the raw values
        # would overstate the density 255x and saturate every score using it.
        return np.count_nonzero(edges) / edges.size

    def tissue_score(self, patch, mask):
        """Tissue content score: fraction of True pixels in ``mask``."""
        return mask.sum() / mask.size

    def focus_score(self, patch):
        """Focus quality (inverse blur): Laplacian variance / 100, capped at 1.

        NOTE(review): rgb2gray yields values in [0, 1] for integer images, so the
        Laplacian variance is typically far below 100 - confirm this scale.
        """
        gray = rgb2gray(patch)
        lap_var = laplace(gray).var()
        return min(1.0, lap_var / 100.0)

    def artifact_score(self, patch):
        """Artifact score in [0, 1] (1 = clean) from pen marks and fold-like edges.

        Pen marks: hue in (0.4, 0.7) with saturation > 0.5. Folds: Canny edge
        fraction, saturating at 10%. Each contributes up to 0.5 penalty.
        """
        hsv = rgb2hsv(patch)

        # Pen mark detection (blue/green markers)
        pen_mask = ((hsv[:, :, 0] > 0.4) & (hsv[:, :, 0] < 0.7)) & (hsv[:, :, 1] > 0.5)
        pen_ratio = pen_mask.sum() / pen_mask.size

        # Fold detection (high edge density)
        edge_density = self._edge_fraction(rgb2gray(patch))

        artifact_penalty = pen_ratio * 0.5 + min(edge_density / 0.1, 1.0) * 0.5
        return max(0, 1.0 - artifact_penalty)

    def information_score(self, patch):
        """Information content: 0.7 * normalized 32-bin entropy + 0.3 * edge score.

        The edge score is the Canny edge fraction scaled so 5% edges maps to 1.
        """
        gray = rgb2gray(patch)

        # Shannon entropy of a 32-bin grayscale histogram, normalized by log(32)
        hist, _ = np.histogram(gray, bins=32, range=(0, 1))
        hist = hist / (hist.sum() + 1e-8)
        entropy = -np.sum(hist * np.log(hist + 1e-8))
        entropy_norm = entropy / np.log(32)

        # Edge density
        edge_score = min(self._edge_fraction(gray) / 0.05, 1.0)

        return 0.7 * entropy_norm + 0.3 * edge_score

    def representativeness_score(self, patch, slide_color_centroid):
        """Color representativeness: exp(-d / 0.3), where d is the distance between
        the patch's mean RGB (scaled by 1/255) and ``slide_color_centroid``.

        Returns 0.5 when no centroid is given.
        """
        if slide_color_centroid is None:
            return 0.5

        patch_centroid = patch.reshape(-1, 3).mean(axis=0) / 255.0
        distance = np.linalg.norm(patch_centroid - slide_color_centroid)

        # Closer = more representative
        return np.exp(-distance / 0.3)

    def compute_quality_score(self, patch, mask, slide_color_centroid=None):
        """Aggregate quality score.

        Returns:
            tuple: (weighted total, dict of the five individual factor scores).
        """
        scores = {
            'tissue': self.tissue_score(patch, mask),
            'focus': self.focus_score(patch),
            'artifact': self.artifact_score(patch),
            'information': self.information_score(patch),
            'representativeness': self.representativeness_score(patch, slide_color_centroid)
        }

        # Weighted sum
        total_score = sum(scores[k] * self.weights[k] for k in scores)

        return total_score, scores


# ============================================================
# PART 3: ENCODER WRAPPERS
# ============================================================

class ResNet50Encoder:
    """ResNet-50 ImageNet pretrained (baseline) feature extractor.

    Args:
        device: Torch device the model is moved to.

    Attributes:
        model: ResNet-50 without its final fully-connected layer.
        transform: 224x224 resize + ImageNet mean/std normalization.
        feat_dim: Output feature dimension (2048).
    """

    def __init__(self, device='cuda'):
        log_msg("  Loading ResNet-50 (ImageNet)...")
        self.device = device
        self.name = "ResNet50-ImageNet"

        # torchvision >= 0.13 replaced `pretrained=True` with the `weights=` enum
        # (IMAGENET1K_V1 is the same checkpoint); fall back on older releases.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            resnet = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            resnet = models.resnet50(pretrained=True)

        # Drop the final fc layer; the remaining global average pool outputs (N, 2048, 1, 1).
        self.model = nn.Sequential(*list(resnet.children())[:-1])
        self.model = self.model.to(device).eval()

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        self.feat_dim = 2048
        log_msg("    ✅ ResNet-50 loaded")

    def extract_features(self, patches, batch_size=32):
        """Extract features from patches.

        Args:
            patches: Sequence of image arrays accepted by ``Image.fromarray``
                (presumably HxWx3 uint8 RGB - confirm at call sites).
            batch_size: Number of patches per forward pass.

        Returns:
            ``(n_ok, feat_dim)`` numpy array, or None when ``patches`` is empty
            or every batch failed. NOTE: a failing batch is logged and skipped,
            so ``n_ok`` can be smaller than ``len(patches)`` and rows then no
            longer align one-to-one with the input patches.
        """
        features = []

        with torch.no_grad():
            for start in range(0, len(patches), batch_size):
                batch = patches[start:start + batch_size]

                try:
                    batch_tensor = torch.stack([
                        self.transform(Image.fromarray(p)) for p in batch
                    ]).to(self.device)
                    # (N, C, 1, 1) -> (N, C); unlike squeeze(), flatten keeps the batch axis.
                    feat = torch.flatten(self.model(batch_tensor), start_dim=1)
                    features.append(feat.cpu().numpy())
                except Exception as e:
                    log_msg(f"    ⚠️ Batch extraction failed: {e}")
                    continue

        if not features:
            return None

        return np.vstack(features)


class CTransPathEncoder:
    """CTransPath domain-specific encoder (Swin-Tiny backbone built with timm).

    The checkpoint is loaded non-strictly. The number of missing/unexpected keys
    is logged so a partial load is visible. If the file is absent or unreadable,
    or if none of its keys match the model, an ImageNet-pretrained Swin-Tiny is
    used instead. In that case ``self.name`` is changed so downstream reports do
    not attribute ImageNet features to CTransPath.
    """

    def __init__(self, weights_path, device='cuda'):
        log_msg("  Loading CTransPath...")
        self.device = device
        self.name = "CTransPath"

        try:
            if not os.path.exists(weights_path):
                raise FileNotFoundError(f"Weights not found: {weights_path}")

            checkpoint = torch.load(weights_path, map_location='cpu')
            state_dict = checkpoint.get('model', checkpoint.get('state_dict', checkpoint))
            # Drop a DataParallel/DDP "module." prefix so keys line up with timm's names.
            state_dict = {
                (key[len('module.'):] if key.startswith('module.') else key): value
                for key, value in state_dict.items()
            }

            self.model = timm.create_model(
                "swin_tiny_patch4_window7_224",
                pretrained=False,
                num_classes=0,
                global_pool='avg'
            )
            # strict=False silently ignores missing/unexpected keys (shape
            # mismatches still raise), so check what was actually loaded.
            load_result = self.model.load_state_dict(state_dict, strict=False)
            n_model_keys = len(self.model.state_dict())
            n_missing = len(load_result.missing_keys)
            n_unexpected = len(load_result.unexpected_keys)
            if n_missing >= n_model_keys:
                raise RuntimeError("no checkpoint keys match the Swin-Tiny model")
            if n_missing or n_unexpected:
                log_msg(
                    f"    ⚠️ Partial checkpoint load: {n_missing} missing / "
                    f"{n_unexpected} unexpected keys "
                    f"(missing e.g. {load_result.missing_keys[:3]})"
                )
            log_msg("    ✅ CTransPath loaded from checkpoint")
        except Exception as e:
            log_msg(f"    ⚠️ Checkpoint load failed: {e}")
            log_msg("    ✅ Using Swin-Tiny fallback")
            # Keep results honest: these are ImageNet Swin-T features, not CTransPath.
            self.name = "CTransPath (Swin-T ImageNet fallback)"
            self.model = timm.create_model(
                "swin_tiny_patch4_window7_224",
                pretrained=True,
                num_classes=0,
                global_pool='avg'
            )

        self.model = self.model.to(device).eval()

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Swin-Tiny with num_classes=0 and global average pooling.
        self.feat_dim = 768

    def extract_features(self, patches, batch_size=32):
        """Embed RGB uint8 patches batch by batch.

        Failed batches are logged and skipped. Returns the stacked features of
        the successful batches, or ``None`` if every batch failed.
        """
        features = []

        for i in range(0, len(patches), batch_size):
            batch = patches[i:i+batch_size]

            try:
                batch_tensor = torch.stack([
                    self.transform(Image.fromarray(p)) for p in batch
                ]).to(self.device)

                with torch.no_grad():
                    feat = self.model(batch_tensor)
                    features.append(feat.cpu().numpy())
            except Exception as e:
                log_msg(f"    ⚠️ Batch extraction failed: {e}")
                continue

        if not features:
            return None

        return np.vstack(features)


# ============================================================
# PART 4: FEATURE QUALITY EVALUATORS
# ============================================================

class StabilityEvaluator:
    """Measure how reproducible each slide's embedding is across sampling seeds."""

    @staticmethod
    def compute_stability(embeddings_dict, slide_ids):
        """
        Summarise pairwise cosine similarity of a slide's per-seed embeddings.

        Args:
            embeddings_dict: {seed: {slide_id: embedding}}
            slide_ids: List of slide IDs

        Returns:
            DataFrame with one row per slide that has at least two embeddings.
            Columns: slide_id, mean/std/min/max of the pairwise similarities,
            confidence = 1 / (1 + variance of the similarities), n_samplings.
        """
        rows = []
        for slide_id in slide_ids:
            per_seed = [
                by_slide[slide_id]
                for by_slide in embeddings_dict.values()
                if slide_id in by_slide
            ]
            n = len(per_seed)
            if n < 2:
                continue

            # Every unordered pair (i < j) of the slide's embeddings.
            sims = [
                1 - cosine(per_seed[a], per_seed[b])
                for a in range(n)
                for b in range(a + 1, n)
            ]

            rows.append({
                'slide_id': slide_id,
                'mean_similarity': np.mean(sims),
                'std_similarity': np.std(sims),
                'min_similarity': np.min(sims),
                'max_similarity': np.max(sims),
                'confidence': 1.0 / (1.0 + np.var(sims)),
                'n_samplings': n,
            })

        return pd.DataFrame(rows)


class RedundancyEvaluator:
    """Evaluate feature redundancy"""

    @staticmethod
    def compute_redundancy(features, corr_threshold=0.8):
        """
        Quantify how redundant the dimensions of an embedding matrix are.

        Args:
            features: (N_slides, feat_dim) array-like
            corr_threshold: absolute Pearson correlation above which a pair of
                dimensions counts as redundant (default 0.8).

        Returns:
            dict with redundancy metrics. Values are plain Python float/int, so
            the dict can be passed straight to ``json.dump``. With fewer than
            two rows or two columns the metrics are undefined and are returned
            as NaN (ratios) / None (component counts).
        """
        features = np.asarray(features)
        if features.ndim != 2 or features.shape[0] < 2 or features.shape[1] < 2:
            nan = float('nan')
            return {
                'redundancy_ratio': nan,
                'mean_abs_correlation': nan,
                'pca_n_components_90': None,
                'pca_n_components_95': None,
                'effective_dim_ratio_90': nan,
                'effective_dim_ratio_95': nan
            }

        # Correlation analysis. Constant dimensions yield NaN correlations; they
        # are excluded rather than turning the mean into NaN.
        with np.errstate(divide='ignore', invalid='ignore'):
            corr_matrix = np.corrcoef(features.T)
        abs_corr = np.abs(corr_matrix[np.triu_indices_from(corr_matrix, k=1)])
        abs_corr = abs_corr[np.isfinite(abs_corr)]
        if abs_corr.size:
            redundancy_ratio = float(np.mean(abs_corr > corr_threshold))
            mean_abs_correlation = float(np.mean(abs_corr))
        else:
            redundancy_ratio = mean_abs_correlation = float('nan')

        # PCA effective dimension: components needed to reach 90% / 95% variance.
        pca = PCA()
        pca.fit(features)
        cumsum_var = np.cumsum(pca.explained_variance_ratio_)
        # int(...) turns np.int64 into a JSON-serialisable Python int.
        n_90 = int(np.argmax(cumsum_var >= 0.90) + 1)
        n_95 = int(np.argmax(cumsum_var >= 0.95) + 1)
        n_dims = features.shape[1]

        return {
            'redundancy_ratio': redundancy_ratio,
            'mean_abs_correlation': mean_abs_correlation,
            'pca_n_components_90': n_90,
            'pca_n_components_95': n_95,
            'effective_dim_ratio_90': n_90 / n_dims,
            'effective_dim_ratio_95': n_95 / n_dims
        }


class SeparabilityEvaluator:
    """Evaluate feature separability"""

    @staticmethod
    def compute_separability(features, labels, n_neighbors=5):
        """
        Args:
            features: (N, feat_dim)
            labels: (N,) categorical labels (list, array or Series; NaN = unlabeled)
            n_neighbors: k for the leave-one-out kNN accuracy (default 5)

        Returns:
            dict with separability metrics, or None when the metrics are
            undefined: fewer than 10 labeled samples, fewer than 2 classes, or
            as many classes as samples (silhouette / Davies-Bouldin require
            2 <= n_classes <= n_samples - 1).
        """
        features = np.asarray(features)
        if not isinstance(labels, np.ndarray):
            # Route through pandas so mixed str/NaN lists keep NaN as a real
            # missing value (np.asarray would turn it into the string 'nan').
            labels = pd.Series(list(labels)).to_numpy()

        # Remove NaN labels
        valid_idx = ~pd.isna(labels)
        features = features[valid_idx]
        labels = labels[valid_idx]

        n_samples = len(features)
        n_classes = len(np.unique(labels))
        if n_classes < 2 or n_samples < 10 or n_classes >= n_samples:
            return None

        # Silhouette score
        silhouette = silhouette_score(features, labels, metric='cosine')

        # Davies-Bouldin index
        davies_bouldin = davies_bouldin_score(features, labels)

        # kNN leave-one-out (each fold trains on n_samples - 1 points)
        knn = KNeighborsClassifier(n_neighbors=min(n_neighbors, n_samples - 1))
        loo = LeaveOneOut()

        predictions, ground_truth = [], []
        for train_idx, test_idx in loo.split(features):
            knn.fit(features[train_idx], labels[train_idx])
            pred = knn.predict(features[test_idx])
            predictions.append(pred[0])
            ground_truth.append(labels[test_idx][0])

        accuracy = np.mean(np.array(predictions) == np.array(ground_truth))

        return {
            'silhouette_score': silhouette,
            'davies_bouldin_index': davies_bouldin,
            'knn_loo_accuracy': accuracy,
            'n_classes': n_classes
        }


class HeterogeneityAnalyzer:
    """Quantify slide-level heterogeneity"""

    @staticmethod
    def _elbow_k(inertias, k_start=2):
        """Return the k at the point of maximum curvature of a KMeans inertia curve.

        ``inertias[i]`` is the inertia for ``k = k_start + i``. The elbow is the
        interior point with the largest second difference, i.e. where the gain
        from adding clusters drops off most sharply. Fewer than three points
        carry no curvature information, so ``k_start`` is returned.
        """
        if len(inertias) < 3:
            return k_start
        second_diff = np.diff(np.asarray(inertias, dtype=float), n=2)
        # second_diff[i] is centred on inertias[i + 1], i.e. k = k_start + i + 1.
        return int(np.argmax(second_diff)) + k_start + 1

    @staticmethod
    def compute_heterogeneity(patch_features):
        """
        Args:
            patch_features: (N_patches, feat_dim)

        Returns:
            heterogeneity_score, metrics_dict. Returns (0.0, {}) when there are
            fewer than 10 patches.
        """
        if len(patch_features) < 10:
            return 0.0, {}
        patch_features = np.asarray(patch_features)
        n_patches, n_dims = patch_features.shape

        # Pairwise distances
        distances = pdist(patch_features, metric='cosine')
        mean_distance = np.mean(distances)
        std_distance = np.std(distances)

        # Estimate cluster count from the elbow of the inertia curve
        k_range = range(2, min(10, n_patches // 10))
        inertias = [
            KMeans(n_clusters=k, random_state=42, n_init=10).fit(patch_features).inertia_
            for k in k_range
        ]
        elbow_k = HeterogeneityAnalyzer._elbow_k(inertias)

        # Silhouette for estimated clusters (undefined if KMeans collapses to one label)
        silhouette = 0.0
        if elbow_k < n_patches:
            kmeans = KMeans(n_clusters=elbow_k, random_state=42, n_init=10)
            labels = kmeans.fit_predict(patch_features)
            if len(np.unique(labels)) >= 2:
                silhouette = silhouette_score(patch_features, labels, metric='cosine')

        # Feature entropy. PCA needs n_components <= min(n_samples, n_features).
        pca = PCA(n_components=min(50, n_patches, n_dims))
        pca.fit(patch_features)
        variance_ratios = pca.explained_variance_ratio_
        entropy = -np.sum(variance_ratios * np.log(variance_ratios + 1e-8))

        metrics = {
            'mean_distance': float(mean_distance),
            'std_distance': float(std_distance),
            'estimated_clusters': int(elbow_k),
            'silhouette': float(silhouette),
            'feature_entropy': float(entropy)
        }

        # Aggregate score
        heterogeneity_score = (
            0.3 * mean_distance +
            0.2 * (elbow_k / 10.0) +
            0.3 * (1.0 - silhouette) +
            0.2 * (entropy / 4.0)
        )

        return float(heterogeneity_score), metrics


# ============================================================
# PART 5: MAIN PIPELINE ORCHESTRATOR
# ============================================================

class WSIFeatureQC:
    """Main pipeline orchestrator.

    Typical order: ``run_optimization`` (calibrate patch-sampling parameters)
    -> ``initialize_encoders`` -> ``run_pipeline`` (multi-seed slide embeddings
    per encoder) -> ``evaluate_features`` (stability / redundancy / separability).
    """

    # Ratio between level dimensions and the thumbnail used for the tissue mask.
    THUMB_SCALE = 32
    # A seed's slide embedding is only kept if at least this many patches passed QC.
    MIN_PATCHES_PER_SEED = 100

    def __init__(self, config):
        self.config = config
        self.optimized_params = None
        self.encoders = {}
        self.qc = AdvancedPatchQC()

    @staticmethod
    def _json_default(obj):
        """``json.dump`` fallback that converts NumPy scalars and arrays."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def initialize_encoders(self):
        """Initialize all encoders (CTransPath only if its weights file exists)."""
        log_msg("\n" + "="*80)
        log_msg("INITIALIZING ENCODERS")
        log_msg("="*80 + "\n")

        # ResNet-50
        self.encoders['resnet50'] = ResNet50Encoder(device=DEVICE)

        # CTransPath
        if os.path.exists(CTRANSPATH_WEIGHTS):
            self.encoders['ctranspath'] = CTransPathEncoder(
                CTRANSPATH_WEIGHTS,
                device=DEVICE
            )
        else:
            log_msg(f"  ⚠️ CTransPath weights not found, skipping: {CTRANSPATH_WEIGHTS}")

        log_msg(f"\n✅ Initialized {len(self.encoders)} encoders")

    def run_optimization(self, calibration_slides):
        """Run parameter optimization and persist the chosen parameters as JSON."""
        log_msg("\n" + "="*80)
        log_msg("STEP 1: PARAMETER OPTIMIZATION")
        log_msg("="*80 + "\n")

        optimizer = MultiMethodOptimizer(calibration_slides)
        self.optimized_params = optimizer.run_full_optimization()
        optimizer.save_results(OUTPUT_DIR)

        # Save params (the optimizer may return NumPy scalars)
        with open(f"{OUTPUT_DIR}/optimized_params.json", 'w') as f:
            json.dump(self.optimized_params, f, indent=2, default=self._json_default)

        log_msg("\n✅ Optimization complete")
        return self.optimized_params

    @staticmethod
    def _thumbnail_tissue_mask(thumbnail_arr):
        """Binary tissue mask from an RGB thumbnail (Otsu on HSV saturation)."""
        saturation = rgb2hsv(thumbnail_arr)[:, :, 1]
        tissue_mask = saturation > threshold_otsu(saturation)
        tissue_mask = remove_small_objects(tissue_mask, min_size=50)
        return binary_closing(tissue_mask, disk(3))

    @staticmethod
    def _passes_quality(patch, tissue_th, blur_th):
        """True if a patch is not background, has enough tissue and is sharp enough."""
        if np.mean(patch) > 220:
            return False

        # Tissue mask
        gray = np.mean(patch, axis=2)
        mask_thresh = threshold_otsu(gray) if gray.std() > 1 else 200
        mask = remove_small_objects(gray < mask_thresh, 500)
        mask = binary_dilation(mask, disk(3))
        if mask.sum() / mask.size < tissue_th:
            return False

        # Blur check: variance of the Laplacian
        return laplace(rgb2gray(patch)).var() >= blur_th

    def extract_patches(self, slide_path, seed=42):
        """Extract high-quality patches from a slide.

        Tissue locations are taken from a thumbnail mask, shuffled with ``seed``
        (global NumPy RNG), and read as ``tile_size`` tiles until ``n_tiles``
        pass QC or ``10 * n_tiles`` candidates are exhausted. Unreadable tiles
        are skipped. The slide handle is always closed.

        Raises:
            RuntimeError: if ``run_optimization`` has not been called yet.
        """
        np.random.seed(seed)

        if self.optimized_params is None:
            raise RuntimeError("run_optimization() must be called before extract_patches()")
        tile_size = self.optimized_params['tile_size']
        n_tiles = self.optimized_params['n_tiles']
        blur_th = self.optimized_params['blur_threshold']
        tissue_th = self.optimized_params['tissue_threshold']

        slide = openslide.OpenSlide(slide_path)
        try:
            level = slide.get_best_level_for_downsample(1)
            downsample = slide.level_downsamples[level]
            width, height = slide.level_dimensions[level]

            # Generate tissue mask at thumbnail resolution
            scale = self.THUMB_SCALE
            thumbnail = slide.get_thumbnail((width // scale, height // scale))
            tissue_mask = self._thumbnail_tissue_mask(np.array(thumbnail))

            # Thumbnail-space (row, col) coordinates of tissue pixels
            tissue_coords = np.argwhere(tissue_mask > 0)
            if len(tissue_coords) == 0:
                return []
            np.random.shuffle(tissue_coords)

            patches = []
            max_attempts = min(len(tissue_coords), n_tiles * 10)
            for y_thumb, x_thumb in tissue_coords[:max_attempts]:
                if len(patches) >= n_tiles:
                    break

                y, x = y_thumb * scale, x_thumb * scale
                if y + tile_size > height or x + tile_size > width:
                    continue

                try:
                    patch = np.array(slide.read_region(
                        (int(x * downsample), int(y * downsample)),
                        level,
                        (tile_size, tile_size)
                    ).convert("RGB"))
                    if self._passes_quality(patch, tissue_th, blur_th):
                        patches.append(patch)
                except Exception:
                    # Best-effort sampling: unreadable or degenerate tiles are skipped.
                    continue

            return patches
        finally:
            slide.close()

    def _seed_embeddings(self, slide_path, slide_id, encoder_name, seeds):
        """Return ``[(seed, slide_embedding), ...]`` for seeds that produced one.

        The slide embedding is the concatenated per-dimension mean and std of
        the patch features. Keeping the seed alongside each embedding preserves
        the seed -> embedding mapping when some seeds fail.
        """
        log_msg(f"  Processing {slide_id} with {encoder_name}...")
        encoder = self.encoders[encoder_name]
        pairs = []

        for seed in seeds:
            patches = self.extract_patches(slide_path, seed=seed)

            if len(patches) < self.MIN_PATCHES_PER_SEED:
                log_msg(f"    ⚠️ Seed {seed}: Only {len(patches)} patches")
                continue

            features = encoder.extract_features(patches)
            if features is None:
                continue

            slide_emb = np.concatenate([
                features.mean(axis=0),
                features.std(axis=0)
            ])
            pairs.append((seed, slide_emb))

        return pairs

    def process_slide(self, slide_path, slide_id, encoder_name, seeds=RANDOM_SEEDS):
        """Process a single slide with multi-seed sampling.

        Returns the list of per-seed slide embeddings (successful seeds only),
        or None if fewer than two seeds succeeded.
        """
        pairs = self._seed_embeddings(slide_path, slide_id, encoder_name, seeds)
        if len(pairs) < 2:
            return None
        return [emb for _, emb in pairs]

    def run_pipeline(self, slide_paths, slide_ids):
        """Run feature extraction for every encoder.

        Returns:
            {encoder_name: {seed: {slide_id: slide_embedding}}}
        """
        log_msg("\n" + "="*80)
        log_msg("STEP 2: FEATURE EXTRACTION")
        log_msg("="*80 + "\n")

        all_results = {}

        for encoder_name in self.encoders:
            log_msg(f"\nEncoder: {encoder_name}")

            embeddings_by_seed = {seed: {} for seed in RANDOM_SEEDS}

            for slide_path, slide_id in zip(slide_paths, slide_ids):
                pairs = self._seed_embeddings(
                    slide_path, slide_id, encoder_name, RANDOM_SEEDS
                )
                if len(pairs) < 2:
                    continue

                # Store each embedding under the seed that actually produced it.
                for seed, emb in pairs:
                    embeddings_by_seed[seed][slide_id] = emb

            all_results[encoder_name] = embeddings_by_seed

        return all_results

    def evaluate_features(self, all_results, slide_ids, labels=None):
        """Evaluate feature quality for each encoder.

        Args:
            all_results: output of ``run_pipeline``.
            slide_ids: slide IDs in evaluation order.
            labels: optional labels, either positionally aligned with
                ``slide_ids`` or a ``{slide_id: label}`` dict. They are
                re-aligned to the slides that actually have an embedding.

        Returns:
            {encoder_name: {'stability': DataFrame, 'redundancy': dict or None,
            'separability': dict or None}}
        """
        log_msg("\n" + "="*80)
        log_msg("STEP 3: FEATURE QUALITY EVALUATION")
        log_msg("="*80 + "\n")

        evaluation_results = {}

        for encoder_name, embeddings_by_seed in all_results.items():
            log_msg(f"\nEvaluating {encoder_name}...")

            # Stability
            stability_df = StabilityEvaluator.compute_stability(
                embeddings_by_seed, slide_ids
            )

            # Redundancy (use first seed). Keep the slide indices so labels can
            # be aligned to exactly the rows present in features_matrix.
            first_embs = embeddings_by_seed.get(RANDOM_SEEDS[0], {})
            kept_idx = [i for i, sid in enumerate(slide_ids) if sid in first_embs]
            features_matrix = np.array([first_embs[slide_ids[i]] for i in kept_idx])

            redundancy = None
            if len(kept_idx) >= 2:
                redundancy = RedundancyEvaluator.compute_redundancy(features_matrix)

            # Separability (if labels provided)
            separability = None
            if labels is not None and kept_idx:
                if isinstance(labels, dict):
                    aligned_labels = np.asarray(
                        [labels.get(slide_ids[i], np.nan) for i in kept_idx],
                        dtype=object
                    )
                else:
                    aligned_labels = np.asarray(labels)[kept_idx]
                separability = SeparabilityEvaluator.compute_separability(
                    features_matrix, aligned_labels
                )

            evaluation_results[encoder_name] = {
                'stability': stability_df,
                'redundancy': redundancy,
                'separability': separability
            }

            if stability_df.empty:
                log_msg("  ⚠️ Stability: no slide has embeddings from ≥2 seeds")
            else:
                log_msg(f"  ✅ Stability: {stability_df['mean_similarity'].mean():.3f}")
            if redundancy:
                log_msg(f"  ✅ Redundancy: {redundancy['redundancy_ratio']:.3f}")
            else:
                log_msg("  ⚠️ Redundancy: fewer than 2 slides embedded")
            if separability:
                log_msg(f"  ✅ Separability: {separability['silhouette_score']:.3f}")

        return evaluation_results


# ============================================================
# MAIN EXECUTION
# ============================================================

def main():
    """Main execution function.

    Splits the slides into a calibration subset (10%, at least one slide),
    used only for parameter optimisation, and a held-out subset of up to 10
    slides, used for feature extraction and evaluation. Results are written to
    ``OUTPUT_DIR/evaluation_results.json``.
    """

    def _to_jsonable(obj):
        # json.dump fallback for NumPy scalars/arrays in the metric dicts.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    # Get slide paths. os.listdir order is filesystem-dependent, so sort first
    # to make the seeded shuffle below reproducible.
    slide_files = sorted(f for f in os.listdir(SVS_DIR) if f.lower().endswith('.svs'))

    if not slide_files:
        log_msg("❌ No SVS files found!")
        return

    np.random.shuffle(slide_files)

    # Split: 10% calibration, held-out remainder for processing
    n_calib = max(1, int(0.1 * len(slide_files)))
    calib_files = slide_files[:n_calib]
    proc_files = slide_files[n_calib:n_calib + 10]  # first 10 held-out slides for demo
    if not proc_files:
        log_msg("⚠️ No held-out slides after calibration split; evaluating calibration slides")
        proc_files = calib_files[:10]

    calib_paths = [os.path.join(SVS_DIR, f) for f in calib_files]
    proc_paths = [os.path.join(SVS_DIR, f) for f in proc_files]
    proc_ids = [Path(f).stem for f in proc_files]

    log_msg(f"Calibration slides: {n_calib}")
    log_msg(f"Processing slides: {len(proc_files)}")

    # Initialize pipeline
    pipeline = WSIFeatureQC(config={})

    # Step 1: Optimize parameters
    pipeline.run_optimization(calib_paths)

    # Step 2: Initialize encoders
    pipeline.initialize_encoders()

    # Step 3: Extract features
    all_results = pipeline.run_pipeline(proc_paths, proc_ids)

    # Step 4: Evaluate
    evaluation_results = pipeline.evaluate_features(all_results, proc_ids)

    # Save results
    log_msg("\n" + "="*80)
    log_msg("SAVING RESULTS")
    log_msg("="*80 + "\n")

    # Convert DataFrames to dicts for JSON serialization (before opening the
    # file, so a conversion error cannot leave a truncated file behind).
    eval_serializable = {
        enc_name: {
            'stability': results['stability'].to_dict('records'),
            'redundancy': results['redundancy'],
            'separability': results['separability']
        }
        for enc_name, results in evaluation_results.items()
    }
    with open(f"{OUTPUT_DIR}/evaluation_results.json", 'w') as f:
        json.dump(eval_serializable, f, indent=2, default=_to_jsonable)

    log_msg("✅ Results saved")

    log_msg("\n" + "="*80)
    log_msg("✅ PIPELINE COMPLETE")
    log_msg("="*80 + "\n")


# ============================================================
# PART 6: BASELINE IMPLEMENTATIONS FOR COMPARISON
# ============================================================

class NaiveCNNBaseline:
    """
    Baseline 1: Naive CNN approach
    - Uniform tiling (no quality control)
    - ResNet-50 features
    - Simple mean pooling
    """

    def __init__(self, device='cuda'):
        self.name = "Naive-CNN"
        self.device = device

        resnet = models.resnet50(pretrained=True)
        self.model = nn.Sequential(*list(resnet.children())[:-1])
        self.model = self.model.to(device).eval()

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def extract_patches_naive(self, slide_path, n_patches=1000, tile_size=224):
        """Naive uniform tiling - NO quality control.

        Scans non-overlapping tiles in raster order from the top-left corner and
        keeps the first ``n_patches`` that are not near-white (mean < 240).
        Unreadable tiles are skipped. The slide handle is always closed.
        """
        slide = openslide.OpenSlide(slide_path)
        try:
            level = slide.get_best_level_for_downsample(1)
            downsample = slide.level_downsamples[level]
            width, height = slide.level_dimensions[level]

            patches = []
            step = tile_size

            # "+ 1" so a tile that ends exactly on the image border is included.
            for y in range(0, height - tile_size + 1, step):
                for x in range(0, width - tile_size + 1, step):
                    if len(patches) >= n_patches:
                        return patches

                    try:
                        patch = np.array(slide.read_region(
                            (int(x*downsample), int(y*downsample)),
                            level,
                            (tile_size, tile_size)
                        ).convert("RGB"))
                    except Exception:
                        continue

                    # Only reject pure white background
                    if np.mean(patch) < 240:
                        patches.append(patch)

            return patches
        finally:
            slide.close()

    def process_slide(self, slide_path, seed=42):
        """Extract features - single seed only.

        Returns the mean-pooled ResNet-50 feature vector of the slide, or None
        when fewer than 100 patches were found.
        """
        # Tiling is deterministic; the seed is set only for parity with other methods.
        np.random.seed(seed)

        patches = self.extract_patches_naive(slide_path)

        if len(patches) < 100:
            return None

        # Extract features
        features = []
        for i in range(0, len(patches), 32):
            batch = patches[i:i+32]
            batch_tensor = torch.stack([
                self.transform(Image.fromarray(p)) for p in batch
            ]).to(self.device)

            with torch.no_grad():
                feat = self.model(batch_tensor).squeeze(-1).squeeze(-1)
                features.append(feat.cpu().numpy())

        features = np.vstack(features)

        # Simple mean pooling (no std)
        return features.mean(axis=0)


class AttentionMILBaseline:
    """
    Baseline 2: Attention-MIL approach (simplified CLAM-style)
    - ResNet-50 features
    - Attention-weighted aggregation
    - Requires labels for training
    """

    def __init__(self, device='cuda', feat_dim=2048, hidden_dim=256):
        self.name = "Attention-MIL"
        self.device = device

        # Feature extractor
        resnet = models.resnet50(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1])
        self.feature_extractor = self.feature_extractor.to(device).eval()

        # Attention network
        self.attention = nn.Sequential(
            nn.Linear(feat_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        ).to(device)

        # Bag-level classifier trained jointly with the attention network. The
        # pooled ResNet features are post-ReLU (non-negative), so using their
        # mean directly as the logit could never produce a negative prediction.
        self.classifier = nn.Linear(feat_dim, 1).to(device)

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        self.trained = False
        # Tiler created lazily once and reused, instead of one per slide.
        self._patch_sampler = None

    def _sampler(self):
        """Shared NaiveCNNBaseline, used only for its uniform tiling."""
        if self._patch_sampler is None:
            self._patch_sampler = NaiveCNNBaseline(self.device)
        return self._patch_sampler

    def extract_patch_features(self, patches):
        """Extract patch-level ResNet-50 features, shape (N, 2048)."""
        features = []
        for i in range(0, len(patches), 32):
            batch = patches[i:i+32]
            batch_tensor = torch.stack([
                self.transform(Image.fromarray(p)) for p in batch
            ]).to(self.device)

            with torch.no_grad():
                feat = self.feature_extractor(batch_tensor).squeeze(-1).squeeze(-1)
                features.append(feat.cpu().numpy())

        return np.vstack(features)

    def train_attention(self, slide_features_list, labels, epochs=10):
        """Train attention weights and the bag classifier on binary slide labels.

        Args:
            slide_features_list: per-slide (N_patches, feat_dim) patch features.
            labels: per-slide 0/1 labels, aligned with ``slide_features_list``.
            epochs: number of passes over the slides (one update per slide).
        """
        log_msg("  Training attention weights...")

        params = list(self.attention.parameters()) + list(self.classifier.parameters())
        optimizer = torch.optim.Adam(params, lr=1e-4)
        criterion = nn.BCEWithLogitsLoss()

        for epoch in range(epochs):
            total_loss = 0
            for features, label in zip(slide_features_list, labels):
                features_tensor = torch.FloatTensor(features).to(self.device)
                label_tensor = torch.FloatTensor([label]).to(self.device)

                # Attention scores -> softmax over patches
                attention_scores = self.attention(features_tensor)
                attention_weights = torch.softmax(attention_scores, dim=0)

                # Weighted aggregation
                slide_emb = (features_tensor * attention_weights).sum(dim=0)

                # Binary classification logit, shape (1,)
                logit = self.classifier(slide_emb)
                loss = criterion(logit, label_tensor)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            if (epoch + 1) % 5 == 0:
                log_msg(f"    Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}")

        self.trained = True
        log_msg("  ✅ Attention training complete")

    def process_slide(self, slide_path, seed=42):
        """Extract features with attention pooling.

        Uses learned attention weights if ``train_attention`` has run, otherwise
        uniform weights (i.e. mean pooling). Returns None when fewer than 100
        patches were found.
        """
        if not self.trained:
            log_msg("  ⚠️ Attention not trained, using uniform weights")

        # Use naive patch extraction
        patches = self._sampler().extract_patches_naive(slide_path)

        if len(patches) < 100:
            return None

        # Extract features
        features = self.extract_patch_features(patches)
        features_tensor = torch.FloatTensor(features).to(self.device)

        # Apply attention
        with torch.no_grad():
            if self.trained:
                attention_scores = self.attention(features_tensor)
                attention_weights = torch.softmax(attention_scores, dim=0)
            else:
                attention_weights = torch.ones(len(features), 1).to(self.device) / len(features)

            slide_emb = (features_tensor * attention_weights).sum(dim=0)

        return slide_emb.cpu().numpy()


class RadiomicsBaseline:
    """
    Baseline 3: Classical radiomics features
    - Hand-crafted texture features
    - Shape features
    - Intensity statistics
    """

    def __init__(self):
        self.name = "Radiomics"
        # Tiler created lazily once and reused across slides.
        self._patch_sampler = None

    def _sampler(self):
        """Shared NaiveCNNBaseline, used only for its uniform tiling.

        The device is chosen from what is available, so this baseline also runs
        on CPU-only machines.
        """
        if self._patch_sampler is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self._patch_sampler = NaiveCNNBaseline(device)
        return self._patch_sampler

    def extract_texture_features(self, patch):
        """Extract hand-crafted texture features from an RGB patch.

        Returns a dict in fixed insertion order. ``process_slide`` relies on
        that order when it flattens the values.
        """
        gray = rgb2gray(patch)

        features = {}

        # Intensity statistics
        features['intensity_mean'] = gray.mean()
        features['intensity_std'] = gray.std()
        features['intensity_min'] = gray.min()
        features['intensity_max'] = gray.max()
        features['intensity_range'] = gray.max() - gray.min()

        # Histogram features
        hist, _ = np.histogram(gray, bins=32)
        hist = hist / (hist.sum() + 1e-8)
        features['entropy'] = -np.sum(hist * np.log(hist + 1e-8))
        features['uniformity'] = np.sum(hist ** 2)

        # GLCM-inspired features (simplified): horizontal differences
        h_diff = np.abs(np.diff(gray, axis=1))
        features['contrast'] = h_diff.mean()
        features['homogeneity'] = 1.0 / (1.0 + h_diff.mean())

        # Gradient features
        gy, gx = np.gradient(gray)
        gradient_mag = np.sqrt(gx**2 + gy**2)
        features['gradient_mean'] = gradient_mag.mean()
        features['gradient_std'] = gradient_mag.std()

        # Edge features: Canny marks edges with 255, so count pixels instead of
        # summing values to get a true fraction in [0, 1].
        edges = cv2.Canny((gray * 255).astype(np.uint8), 50, 150)
        features['edge_density'] = np.count_nonzero(edges) / edges.size

        return features

    def process_slide(self, slide_path, seed=42):
        """Extract radiomics features.

        Returns the mean and std (concatenated) of the per-patch features, or
        None when fewer than 100 patches were found.
        """
        # Use naive patch extraction
        patches = self._sampler().extract_patches_naive(slide_path, n_patches=500)

        if len(patches) < 100:
            return None

        # Extract features from all patches
        all_features = [
            list(self.extract_texture_features(patch).values())
            for patch in patches
        ]

        # Aggregate: mean and std
        features_array = np.array(all_features)
        slide_features = np.concatenate([
            features_array.mean(axis=0),
            features_array.std(axis=0)
        ])

        return slide_features


# ============================================================
# PART 7: FRAMEWORK COMPARISON ENGINE
# ============================================================

class FrameworkComparison:
    """
    Compare the proposed framework against baseline feature extractors.

    Baselines (built in ``__init__``):
    - Naive CNN
    - Attention-MIL
    - Radiomics

    Metrics produced by ``compute_comparison_metrics``:
    - Stability (mean/std of ``mean_similarity``; needs more than one seed)
    - Redundancy (redundancy ratio, PCA components for 95% variance)
    - Separability (silhouette score and kNN accuracy; needs labels)

    Robustness and runtime are not measured by this class.
    """

    # Every metrics dict carries all of these keys (None = not computed), so
    # the table and summary code downstream never hit a KeyError.
    _METRIC_KEYS = (
        'stability_mean',
        'stability_std',
        'redundancy_ratio',
        'pca_effective_dim_95',
        'silhouette_score',
        'knn_accuracy',
    )

    # Matrix-based metrics (redundancy, separability) need more slides than this.
    _MIN_SLIDES_FOR_MATRIX_METRICS = 5

    # (metric key, column header) pairs for the paper comparison table.
    _TABLE_COLUMNS = (
        ('stability_mean', 'Stability ↑'),
        ('redundancy_ratio', 'Redundancy ↓'),
        ('silhouette_score', 'Separability ↑'),
        ('knn_accuracy', 'kNN Acc ↑'),
    )

    def __init__(self, your_framework, device='cuda'):
        """
        Args:
            your_framework: Pipeline object exposing ``encoders`` and, after
                extraction, ``all_results`` (encoder name -> seed ->
                {slide_id: embedding}).
            device: Device string forwarded to the CNN-based baselines.
        """
        self.your_framework = your_framework
        self.device = device

        # Initialize baselines
        self.baselines = {
            'naive_cnn': NaiveCNNBaseline(device),
            'attention_mil': AttentionMILBaseline(device),
            'radiomics': RadiomicsBaseline()
        }

        self.comparison_results = {}

    @staticmethod
    def _as_float(value):
        """Return ``value`` as a built-in float (None passes through)."""
        return None if value is None else float(value)

    @staticmethod
    def _as_int(value):
        """Return ``value`` as a built-in int (None passes through)."""
        return None if value is None else int(value)

    @staticmethod
    def _fmt_metric(value):
        """Format a metric to 3 decimals, or "N/A" when missing/NaN."""
        if value is None:
            return "N/A"
        value = float(value)
        if np.isnan(value):
            return "N/A"
        return f"{value:.3f}"

    def extract_baseline_features(self, slide_paths, slide_ids, baseline_name,
                                   n_seeds=1):
        """Run one baseline over all slides for the first ``n_seeds`` seeds.

        Returns:
            dict mapping seed -> {slide_id: embedding}. Slides for which the
            baseline returned None, or raised, are omitted.
        """
        log_msg(f"\n{'='*80}")
        log_msg(f"BASELINE: {baseline_name.upper()}")
        log_msg(f"{'='*80}\n")

        baseline = self.baselines[baseline_name]
        seeds = RANDOM_SEEDS[:n_seeds]
        embeddings_by_seed = {seed: {} for seed in seeds}

        for slide_path, slide_id in zip(slide_paths, slide_ids):
            log_msg(f"  Processing {slide_id}...")

            for seed in seeds:
                try:
                    emb = baseline.process_slide(slide_path, seed=seed)
                except Exception as exc:
                    # One unreadable slide must not abort the whole comparison.
                    # Log it and treat it like a slide with too few patches.
                    log_msg(f"    ⚠️ {baseline_name} failed on {slide_id} "
                            f"(seed {seed}): {exc}")
                    continue

                if emb is not None:
                    embeddings_by_seed[seed][slide_id] = emb

        return embeddings_by_seed

    def compute_comparison_metrics(self, embeddings_dict, slide_ids, labels=None):
        """Compute stability, redundancy and separability for one method.

        Args:
            embeddings_dict: seed -> {slide_id: embedding}.
            slide_ids: Slide ids in evaluation order.
            labels: Optional per-slide labels, assumed positionally aligned
                with ``slide_ids`` (NOTE(review): confirm with callers).

        Returns:
            dict with every key in ``_METRIC_KEYS``; a metric that could not be
            computed is None. Values are built-in Python numbers (JSON-safe).
        """
        metrics = dict.fromkeys(self._METRIC_KEYS)
        if not embeddings_dict:
            return metrics

        # 1. Stability (only meaningful with multiple seeds)
        if len(embeddings_dict) > 1:
            stability_df = StabilityEvaluator.compute_stability(
                embeddings_dict, slide_ids
            )
            metrics['stability_mean'] = self._as_float(stability_df['mean_similarity'].mean())
            metrics['stability_std'] = self._as_float(stability_df['mean_similarity'].std())

        # 2. Redundancy, on the first seed's embeddings. Track which slides are
        # present so labels can be filtered to the same rows.
        first_embeddings = embeddings_dict[next(iter(embeddings_dict))]
        kept_idx = [i for i, sid in enumerate(slide_ids) if sid in first_embeddings]
        features_matrix = np.array([first_embeddings[slide_ids[i]] for i in kept_idx])

        if len(features_matrix) <= self._MIN_SLIDES_FOR_MATRIX_METRICS:
            return metrics

        redundancy = RedundancyEvaluator.compute_redundancy(features_matrix)
        metrics['redundancy_ratio'] = self._as_float(redundancy['redundancy_ratio'])
        metrics['pca_effective_dim_95'] = self._as_int(redundancy['pca_n_components_95'])

        # 3. Separability (if labels provided)
        if labels is not None:
            if len(kept_idx) == len(slide_ids):
                kept_labels = labels
            else:
                # Some slides were dropped: keep labels row-aligned with features.
                label_list = list(labels)
                kept_labels = [label_list[i] for i in kept_idx]

            separability = SeparabilityEvaluator.compute_separability(
                features_matrix, kept_labels
            )
            if separability:
                metrics['silhouette_score'] = self._as_float(separability.get('silhouette_score'))
                metrics['knn_accuracy'] = self._as_float(separability.get('knn_loo_accuracy'))

        return metrics

    def run_comparison(self, slide_paths, slide_ids, labels=None):
        """Evaluate the framework's encoders and every baseline.

        Returns:
            dict method name -> metrics dict (also stored on
            ``self.comparison_results``).
        """
        log_msg("\n" + "="*80)
        log_msg("FRAMEWORK COMPARISON")
        log_msg("="*80 + "\n")

        results = {}

        # 1. Your framework (embeddings already extracted, multi-seed)
        log_msg("Evaluating YOUR FRAMEWORK (multi-seed)...")
        framework_results = getattr(self.your_framework, 'all_results', None)
        if framework_results is None:
            log_msg("  ⚠️ your_framework.all_results not set; skipping framework metrics")
        else:
            for encoder_name in self.your_framework.encoders:
                if encoder_name not in framework_results:
                    continue
                results[f'yours_{encoder_name}'] = self.compute_comparison_metrics(
                    framework_results[encoder_name], slide_ids, labels
                )

        # 2-4. Baselines, single seed each (insertion order of self.baselines).
        # Note: attention-MIL would need labels to train attention properly.
        for baseline_name in self.baselines:
            embeddings = self.extract_baseline_features(
                slide_paths, slide_ids, baseline_name, n_seeds=1
            )
            results[baseline_name] = self.compute_comparison_metrics(
                embeddings, slide_ids, labels
            )

        self.comparison_results = results
        return results

    def generate_comparison_table(self):
        """Print the comparison table and save it to OUTPUT_DIR as CSV."""
        log_msg("\n" + "="*80)
        log_msg("COMPARISON TABLE")
        log_msg("="*80 + "\n")

        columns = ['Method'] + [header for _, header in self._TABLE_COLUMNS]
        rows = [
            {
                'Method': method_name,
                **{
                    header: self._fmt_metric(metrics.get(key))
                    for key, header in self._TABLE_COLUMNS
                },
            }
            for method_name, metrics in self.comparison_results.items()
        ]

        # Explicit columns keep the header stable even with no results.
        df = pd.DataFrame(rows, columns=columns)

        print(df.to_string(index=False))

        df.to_csv(f"{OUTPUT_DIR}/comparison_table.csv", index=False)
        log_msg(f"\n✅ Comparison table saved to {OUTPUT_DIR}/comparison_table.csv")

        return df

    def statistical_significance_test(self):
        """Placeholder for paired significance tests (Wilcoxon signed-rank).

        Not implemented: it would need per-slide scores, which
        ``compute_comparison_metrics`` does not keep. It only logs a reminder.
        """
        log_msg("\n" + "="*80)
        log_msg("STATISTICAL SIGNIFICANCE TESTS")
        log_msg("="*80 + "\n")

        log_msg("Statistical tests require per-slide scores")
        log_msg("Implement Wilcoxon signed-rank test for paired comparison")

        # TODO: Implement proper statistical testing


# ============================================================
# PART 8: UPDATED MAIN EXECUTION WITH COMPARISON
# ============================================================

def _comparison_json_default(obj):
    """JSON fallback for NumPy scalars/arrays, which ``json`` cannot encode."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _log_best_method(comparison_results, metric, label, *, higher_is_better=True):
    """Log the method with the best non-None ``metric``, or N/A if none has one."""
    candidates = [
        (name, metrics.get(metric))
        for name, metrics in comparison_results.items()
        if metrics.get(metric) is not None
    ]
    if not candidates:
        log_msg(f"{label}: N/A")
        return
    pick = max if higher_is_better else min
    name, value = pick(candidates, key=lambda item: item[1])
    log_msg(f"{label}: {name} = {value:.3f}")


def main_with_comparison(max_slides=10):
    """Main execution with baseline comparison.

    Splits the slides in SVS_DIR into a calibration set (~10%, at least one)
    and a disjoint processing set of at most ``max_slides`` slides. It runs the
    framework, runs the baselines, and writes JSON/CSV results to OUTPUT_DIR.
    """
    # Sort before the seeded shuffle so the split depends only on the RNG
    # seed, not on the filesystem's directory-listing order.
    slide_files = sorted(f for f in os.listdir(SVS_DIR) if f.lower().endswith('.svs'))

    if not slide_files:
        log_msg("❌ No SVS files found!")
        return

    np.random.shuffle(slide_files)

    # Split: ~10% calibration, then process slides from the REST, so that
    # parameters are not tuned on the very slides being evaluated.
    n_calib = max(1, int(0.1 * len(slide_files)))
    calib_files = slide_files[:n_calib]
    proc_files = slide_files[n_calib:n_calib + max_slides]

    if not proc_files:
        log_msg("❌ Not enough slides: all were used for calibration!")
        return

    calib_paths = [os.path.join(SVS_DIR, f) for f in calib_files]
    proc_paths = [os.path.join(SVS_DIR, f) for f in proc_files]
    proc_ids = [Path(f).stem for f in proc_files]

    log_msg(f"Calibration slides: {n_calib}")
    log_msg(f"Processing slides: {len(proc_files)}")

    # ========================================
    # YOUR FRAMEWORK
    # ========================================
    pipeline = WSIFeatureQC(config={})

    # Step 1: Optimize parameters (the returned value is not used below)
    pipeline.run_optimization(calib_paths)

    # Step 2: Initialize encoders
    pipeline.initialize_encoders()

    # Step 3: Extract features
    all_results = pipeline.run_pipeline(proc_paths, proc_ids)
    pipeline.all_results = all_results  # read by FrameworkComparison.run_comparison

    # Step 4: Evaluate
    evaluation_results = pipeline.evaluate_features(all_results, proc_ids)

    # ========================================
    # BASELINE COMPARISON
    # ========================================
    comparator = FrameworkComparison(pipeline, device=DEVICE)
    comparison_results = comparator.run_comparison(proc_paths, proc_ids, labels=None)
    comparator.generate_comparison_table()
    comparator.statistical_significance_test()

    # ========================================
    # SAVE ALL RESULTS
    # ========================================
    log_msg("\n" + "="*80)
    log_msg("SAVING RESULTS")
    log_msg("="*80 + "\n")

    eval_serializable = {
        enc_name: {
            'stability': res['stability'].to_dict('records'),
            'redundancy': res['redundancy'],
            'separability': res['separability'],
        }
        for enc_name, res in evaluation_results.items()
    }

    # Serialise BEFORE opening the files so an encoding error cannot leave a
    # truncated JSON file behind. NumPy values are converted to built-ins.
    eval_json = json.dumps(eval_serializable, indent=2, default=_comparison_json_default)
    comparison_json = json.dumps(comparison_results, indent=2, default=_comparison_json_default)

    with open(f"{OUTPUT_DIR}/evaluation_results.json", 'w', encoding='utf-8') as f:
        f.write(eval_json)

    with open(f"{OUTPUT_DIR}/comparison_results.json", 'w', encoding='utf-8') as f:
        f.write(comparison_json)

    log_msg("✅ All results saved")

    # ========================================
    # PRINT SUMMARY
    # ========================================
    log_msg("\n" + "="*80)
    log_msg("✅ COMPLETE PIPELINE WITH COMPARISON FINISHED")
    log_msg("="*80 + "\n")

    log_msg("KEY FINDINGS:")
    log_msg("-" * 80)

    _log_best_method(comparison_results, 'stability_mean', "Best Stability")
    _log_best_method(comparison_results, 'redundancy_ratio', "Best Redundancy (lowest)",
                     higher_is_better=False)
    _log_best_method(comparison_results, 'silhouette_score', "Best Separability")

    log_msg("\n" + "="*80 + "\n")


if __name__ == "__main__":
    main_with_comparison()

  from .autonotebook import tqdm as notebook_tqdm


WSI-FEATUREQC: COMPLETE FEATURE EXTRACTION & EVALUATION FRAMEWORK
Device: cpu
Multi-seed sampling: 5 seeds
Output: WSI_FEATUREQC_COMPLETE

[2026-01-19 14:11:58] Calibration slides: 11
[2026-01-19 14:11:58] Processing slides: 10
[2026-01-19 14:11:58] 
[2026-01-19 14:11:58] STEP 1: PARAMETER OPTIMIZATION

[2026-01-19 14:11:58] 
[2026-01-19 14:11:58] RUNNING FULL PARAMETER OPTIMIZATION

[2026-01-19 14:11:58] Optimizing tile count using Elbow method...
[2026-01-19 14:12:41]   ✅ Optimal tile count: 500
[2026-01-19 14:12:41] Optimizing blur threshold using Youden's J...
[2026-01-19 14:12:55]   ✅ Blur threshold: 0.23
[2026-01-19 14:12:55] Optimizing tissue threshold using multi-method consensus...
[2026-01-19 14:13:06]   ✅ Tissue threshold: 0.300
[2026-01-19 14:13:06] Computing bootstrap stability...
[2026-01-19 14:13:12]   ✅ Bootstrap: 0.15 ± 0.01
[2026-01-19 14:13:12] Computing stain normalization statistics...
[2026-01-19 14:13:20]   ✅ Stain stats: RGB means = [0.779 0.665 0.877]
[2026-01-