# In [None]:
# ============================================================
# COMPLETE Q1-READY HISTOLOGY PIPELINE - PRODUCTION VERSION
# LUNIT ATOM + INTERPRETABLE FEATURES + ROBUST OPTIMIZATION
# ALL SLIDES PROCESSED (NO DATA LOSS)
# ============================================================

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import numpy as np
import pandas as pd
import openslide
import torch
import torchvision.transforms as transforms
from PIL import Image
from skimage.filters import threshold_otsu, laplace, gaussian
from skimage.morphology import remove_small_objects, binary_dilation, disk
from skimage.color import rgb2hsv, rgb2gray
from skimage.measure import regionprops, label
from skimage.feature import graycomatrix, graycoprops
from scipy import stats
from sklearn.metrics import roc_curve, auc
import json
from datetime import datetime
import warnings
import matplotlib.pyplot as plt
from pathlib import Path
import timm
import traceback

warnings.filterwarnings("ignore")

# ===============================
# REPRODUCIBILITY
# ===============================
RANDOM_SEED = 42

# Seed every RNG the pipeline draws from (torch CPU, numpy, and CUDA if present).
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(RANDOM_SEED)

# ===============================
# CONFIGURATION
# ===============================
# Output locations are created up front so later writes cannot fail on a
# missing directory.
OUTPUT_DIR = "histology_q1_production_final"
SVS_DIR = r"C:\Users\Shahinur\Downloads\PKG_Dataset\PKG - Brain-Mets-Lung-MRI-Path-Segs_histopathology images\data"
Path(OUTPUT_DIR).mkdir(exist_ok=True)
(Path(OUTPUT_DIR) / "figures").mkdir(exist_ok=True)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Start-up banner.
print("=" * 80)
print("Q1-READY: LUNIT ATOM + INTERPRETABLE FEATURES (PRODUCTION)")
print("=" * 80)
print(
    f"Device: {DEVICE}",
    f"Seed: {RANDOM_SEED}",
    f"Output: {OUTPUT_DIR}\n",
    sep="\n",
)

def log_msg(m):
    """Print ``m`` and append it, timestamped, to ``{OUTPUT_DIR}/progress.log``.

    The file write is best-effort: if the log file cannot be opened or
    written, the message is still printed and the failure is ignored so that
    logging never aborts the pipeline. No locking is performed, so concurrent
    writers are not synchronized.

    Args:
        m: message to log; formatted with ``str()`` semantics by the f-string.
    """
    print(m)
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    try:
        # Explicit UTF-8: the pipeline logs non-ASCII symbols, which the
        # platform default encoding (e.g. cp1252 on Windows) cannot encode.
        with open(f"{OUTPUT_DIR}/progress.log", 'a', encoding='utf-8') as f:
            f.write(f"[{stamp}] {m}\n")
    except (OSError, UnicodeError):
        # Deliberately ignored: console output above already carried the message.
        pass

# ===============================
# OPTIMIZER - FIXED & ROBUST
# ===============================
class Optimizer:
    """Derive tiling parameters from a small set of calibration slides.

    Every public method samples ``sz`` x ``sz`` RGB tiles from a prefix of
    ``self.slides``, derives one parameter from them, stores a summary under
    a key of ``self.results`` (on the non-fallback path) and returns the
    value. ``save`` writes ``self.results`` to ``optimization.json``.

    Tissue fraction is always the share of *pixels* flagged by ``_mask``
    (a value in [0, 1]); the mask is 2-D, so it must not be divided by the
    size of the 3-channel tile array.
    """

    def __init__(self, slides, n=300):
        """Remember the calibration slides.

        Args:
            slides: sequence of slide paths accepted by ``openslide.OpenSlide``.
            n: stored as ``self.n``; no method of this class reads it.
        """
        self.slides = slides
        self.n = n
        self.results = {}

    # ------------------------------------------------------------------
    # Tile-level measurements
    # ------------------------------------------------------------------
    def _bg(self, t):
        """Return True when the tile's mean intensity exceeds 220 (background)."""
        return np.mean(t) > 220

    def _blur(self, t):
        """Return the Laplacian variance of the tile, boosted when it is low.

        When the variance is below 10, ten times the mean gradient magnitude
        is added so low-contrast tiles are not scored as pure blur.
        """
        g = rgb2gray(t)
        v = laplace(g).var()
        if v >= 10:
            return v
        # Compute the gradient once and reuse both components.
        g_rows, g_cols = np.gradient(g)
        return v + np.sqrt(g_rows**2 + g_cols**2).mean() * 10

    def _mask(self, t):
        """Return a 2-D boolean tissue mask for an RGB tile.

        Pixels darker than an Otsu threshold of the channel mean count as
        tissue (a fixed threshold of 200 is used for near-constant tiles);
        objects smaller than 500 pixels are dropped and the result is dilated
        with a radius-3 disk.
        """
        g = np.mean(t, 2)
        th = threshold_otsu(g) if g.std() > 1 else 200
        m = g < th
        m = remove_small_objects(m, 500)
        m = binary_dilation(m, disk(3))
        return m

    def _tissue_frac(self, t):
        """Return the fraction of the tile's pixels covered by the tissue mask."""
        # mean() of the 2-D mask == tissue pixels / total pixels, in [0, 1].
        return float(self._mask(t).mean())

    def _foreground_tissue_frac(self, t):
        """Return the tissue fraction of a non-background tile, else None."""
        if self._bg(t):
            return None
        return self._tissue_frac(t)

    def _foreground_blur(self, t):
        """Return the blur score of a non-background tile, else None."""
        if self._bg(t):
            return None
        return self._blur(t)

    # ------------------------------------------------------------------
    # Slide scanning
    # ------------------------------------------------------------------
    @staticmethod
    def _grid_origins(w, h, sz):
        """Yield ``(x, y)`` of every ``sz``-square tile that fits in ``w`` x ``h``.

        Origins are produced row by row. The ``+ 1`` keeps the last tile
        whose far edge coincides exactly with the image border.
        """
        for y in range(0, h - sz + 1, sz):
            for x in range(0, w - sz + 1, sz):
                yield x, y

    def _sample_tiles(self, paths, sz, measure, limit):
        """Collect up to ``limit`` measurements from tiles of the given slides.

        Args:
            paths: slide paths, visited in order.
            sz: tile edge length in pixels at the chosen level.
            measure: callable ``tile -> value``; a ``None`` result means
                "skip this tile".
            limit: maximum number of values collected across all ``paths``.

        Returns:
            List of the non-None values returned by ``measure``.

        A failure while opening or reading a slide is logged and that slide
        is abandoned; values collected before the failure are kept. The slide
        handle is closed by the context manager even when an error occurs.
        """
        out = []
        for p in paths:
            if len(out) >= limit:
                break
            try:
                with openslide.OpenSlide(p) as sl:
                    lv = sl.get_best_level_for_downsample(1)
                    ds = sl.level_downsamples[lv]
                    w, h = sl.level_dimensions[lv]
                    for x, y in self._grid_origins(w, h, sz):
                        if len(out) >= limit:
                            break
                        # read_region takes level-0 coordinates, hence the scaling by ds.
                        region = sl.read_region((int(x * ds), int(y * ds)), lv, (sz, sz))
                        val = measure(np.array(region.convert("RGB")))
                        if val is not None:
                            out.append(val)
            except Exception as e:
                log_msg(f"  ‚ö†Ô∏è Slide error: {e}")
        return out

    # ------------------------------------------------------------------
    # Pure-numpy statistics helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _youden_scores(pos, neg, ths):
        """Return Youden's J per threshold for the rule ``value < th => positive``.

        Args:
            pos: 1-D values of the positive class (must be non-empty).
            neg: 1-D values of the negative class (must be non-empty).
            ths: 1-D candidate thresholds.

        Returns:
            Array ``sensitivity + specificity - 1`` with one entry per threshold,
            where sensitivity is the share of ``pos`` below the threshold and
            specificity the share of ``neg`` at or above it.
        """
        pos = np.asarray(pos, dtype=float)[:, None]
        neg = np.asarray(neg, dtype=float)[:, None]
        ths = np.asarray(ths, dtype=float)[None, :]
        sensitivity = (pos < ths).mean(axis=0)
        specificity = (neg >= ths).mean(axis=0)
        return sensitivity + specificity - 1.0

    @staticmethod
    def _otsu_threshold(values, bins=50, default=0.3):
        """Return the histogram split that maximises between-class variance.

        Class means are computed from bin centres (one per histogram count);
        the returned threshold is the bin edge at the split. ``default`` is
        returned when no split has two non-empty sides.
        """
        hist, edges = np.histogram(values, bins=bins)
        centers = (edges[:-1] + edges[1:]) / 2
        total = hist.sum()
        best_var, best_th = 0.0, default
        for i in range(len(hist) - 1):
            n0 = hist[:i + 1].sum()
            n1 = total - n0
            if n0 == 0 or n1 == 0:
                continue
            m0 = np.average(centers[:i + 1], weights=hist[:i + 1])
            m1 = np.average(centers[i + 1:], weights=hist[i + 1:])
            var = (n0 / total) * (n1 / total) * (m0 - m1) ** 2
            if var > best_var:
                best_var, best_th = var, edges[i + 1]
        return float(best_th)

    @staticmethod
    def _gap_threshold(values, lo, hi):
        """Return the value just below the widest gap starting in ``[lo, hi]``.

        Returns None when no sorted value (other than the last) lies in range.
        """
        ordered = np.sort(values)
        gaps = np.diff(ordered)
        candidates = np.flatnonzero((ordered[:-1] >= lo) & (ordered[:-1] <= hi))
        if candidates.size == 0:
            return None
        widest = candidates[np.argmax(gaps[candidates])]
        return float(ordered[widest])

    @staticmethod
    def _valley_threshold(values, lo, hi, bins=30):
        """Return the left edge of the deepest local minimum of a smoothed histogram.

        The histogram is smoothed with a 3-bin moving average; only strict
        local minima whose bin edge lies in ``[lo, hi]`` are considered.
        Returns None when there is none.
        """
        hist, edges = np.histogram(values, bins=bins)
        smoothed = np.convolve(hist, np.ones(3) / 3, mode='same')
        best = None
        for i in range(1, len(smoothed) - 1):
            is_valley = smoothed[i] < smoothed[i - 1] and smoothed[i] < smoothed[i + 1]
            if is_valley and lo <= edges[i] <= hi:
                if best is None or smoothed[i] < smoothed[best]:
                    best = i
        return None if best is None else float(edges[best])

    # ------------------------------------------------------------------
    # Calibration methods
    # ------------------------------------------------------------------
    def elbow(self, sz, mx=250):
        """Pick a tile count from the curvature of the mean-tile variance curve.

        For each of the first three slides, up to ``mx`` tissue tiles
        (non-background, tissue fraction >= 0.1) are collected; the variance
        of the mean grayscale tile is recorded for n = 25, 50, ... tiles.
        The count with the smallest absolute second derivative is returned,
        clamped to [50, 200]. Returns 100 when fewer than three points exist.

        Args:
            sz: tile edge length in pixels.
            mx: per-slide cap on collected tiles.

        Returns:
            int tile count.
        """
        log_msg("METHOD 1: Elbow (Tile Count)")
        counts, variances = [], []

        def gray_if_tissue(t):
            if self._bg(t) or self._tissue_frac(t) < 0.1:
                return None
            return rgb2gray(t).flatten()

        for p in self.slides[:3]:
            tiles = self._sample_tiles([p], sz, gray_if_tissue, mx)
            if len(tiles) < 50:
                continue
            ta = np.array(tiles)
            for n in range(25, min(mx, len(ta)) + 1, 25):
                variances.append(np.var(np.mean(ta[:n], 0)))
                counts.append(n)

        if len(counts) < 3:
            log_msg("  ‚ö†Ô∏è Insufficient data, using default: 100")
            return 100

        # NOTE(review): points from different slides are concatenated into one
        # series before differentiating, so derivatives also span slide
        # boundaries — confirm this is intended.
        counts, variances = np.array(counts), np.array(variances)
        d2 = np.gradient(np.gradient(variances))
        opt = max(50, min(int(counts[np.argmin(np.abs(d2))]), 200))

        self.results['elbow'] = {'optimal': opt, 'samples': len(counts)}
        log_msg(f"‚úÖ Optimal tiles: {opt}")
        return opt

    def youden(self, sz):
        """Choose a blur threshold that best separates empty from tissue tiles.

        Up to 500 non-background tiles from the first four slides are scored
        for blur and tissue fraction. Tiles with fraction < 0.05 form the
        positive ("empty") class and those with fraction >= 0.3 the negative
        class; the 1st..19th percentiles of blur are tried as thresholds and
        the one maximising Youden's J is returned. Falls back to the 5th
        percentile of blur when either class has fewer than 10 tiles, and to
        0.1 when fewer than 100 tiles were sampled.

        Returns:
            float blur threshold.
        """
        log_msg("METHOD 2: Youden's J (Blur)")

        def blur_and_tissue(t):
            # Measured together so the two series can never go out of step.
            if self._bg(t):
                return None
            return self._blur(t), self._tissue_frac(t)

        pairs = self._sample_tiles(self.slides[:4], sz, blur_and_tissue, 500)

        if len(pairs) < 100:
            log_msg("  ‚ö†Ô∏è Insufficient data, using default: 0.1")
            return 0.1

        ba = np.array([b for b, _ in pairs])
        ta = np.array([f for _, f in pairs])
        emp, tis = ta < 0.05, ta >= 0.3

        if emp.sum() < 10 or tis.sum() < 10:
            # Fallback to percentile
            opt = float(np.percentile(ba, 5))
            log_msg(f"‚úÖ Blur threshold (percentile): {opt:.4f}")
            return opt

        ths = np.percentile(ba, np.arange(1, 20, 1))
        js = self._youden_scores(ba[emp], ba[tis], ths)
        best = int(np.argmax(js))

        opt = float(ths[best])
        self.results['youden'] = {'optimal': opt, 'j': float(js[best])}
        log_msg(f"‚úÖ Blur threshold: {opt:.4f}")
        return opt

    def tissue_threshold_robust(self, sz):
        """Estimate a tissue-fraction threshold as the median of four estimators.

        Up to 600 non-background tiles from the first five slides are used.
        The estimators are: (A) the 25th percentile, (B) Otsu on the
        histogram, (C) the widest gap starting in [0.2, 0.6], (D) the deepest
        smoothed-histogram valley in [0.2, 0.6]; C and D fall back to A when
        they find nothing. The median is clamped to [0.25, 0.65].

        Returns:
            float threshold, or None when fewer than 100 tiles were sampled.
        """
        log_msg("METHOD 3: Tissue Threshold (Multi-Method)")

        fracs = self._sample_tiles(self.slides[:5], sz, self._foreground_tissue_frac, 600)

        if len(fracs) < 100:
            log_msg("  ‚ùå CRITICAL: Insufficient data for tissue threshold")
            return None

        ta = np.array(fracs)
        p10, p25, p50, p75 = (float(v) for v in np.percentile(ta, [10, 25, 50, 75]))

        log_msg(f"  üìä Tissue % distribution:")
        log_msg(f"     Samples: {len(ta)}")
        log_msg(f"     Mean: {ta.mean():.3f}, Std: {ta.std():.3f}")
        log_msg(f"     Min: {ta.min():.3f}, Max: {ta.max():.3f}")
        log_msg(f"     P10: {p10:.3f}, P50: {p50:.3f}")

        # METHOD A: lower quartile of the observed tissue fractions.
        method_a = p25
        log_msg(f"  Method A (P25): {method_a:.3f}")

        # METHOD B: Otsu split of the tissue-fraction histogram.
        method_b = self._otsu_threshold(ta, bins=50)
        log_msg(f"  Method B (Otsu): {method_b:.3f}")

        # METHOD C: widest gap between consecutive sorted fractions.
        method_c = self._gap_threshold(ta, 0.2, 0.6)
        if method_c is None:
            method_c = method_a
            log_msg(f"  Method C (Gap): No gap found, using P25")
        else:
            log_msg(f"  Method C (Gap): {method_c:.3f}")

        # METHOD D: valley between the two modes of the smoothed histogram.
        method_d = self._valley_threshold(ta, 0.2, 0.6, bins=30)
        if method_d is None:
            method_d = method_a
            log_msg(f"  Method D (Mixture): No minimum, using P25")
        else:
            log_msg(f"  Method D (Mixture): {method_d:.3f}")

        # CONSENSUS: the median is insensitive to a single outlying estimator.
        consensus = float(np.median([method_a, method_b, method_c, method_d]))

        # Clamp to reasonable range
        consensus = max(0.25, min(consensus, 0.65))

        log_msg(f"\n  üìä Multi-Method Results:")
        log_msg(f"     A (P25): {method_a:.3f}")
        log_msg(f"     B (Otsu): {method_b:.3f}")
        log_msg(f"     C (Gap): {method_c:.3f}")
        log_msg(f"     D (Mixture): {method_d:.3f}")
        log_msg(f"  üéØ Consensus (median): {consensus:.3f}")

        self.results['tissue_threshold'] = {
            'optimal': consensus,
            'method_a_p25': method_a,
            'method_b_otsu': method_b,
            'method_c_gap': method_c,
            'method_d_mixture': method_d,
            'samples': len(ta),
            'distribution': {
                'mean': float(ta.mean()),
                'std': float(ta.std()),
                'p10': p10,
                'p25': p25,
                'p50': p50,
                'p75': p75,
            },
        }

        log_msg(f"‚úÖ Tissue threshold: {consensus:.2f} (robust multi-method)")
        return consensus

    def roc(self, sz, default=0.25):
        """Choose a tissue-fraction threshold via ROC analysis and Youden's J.

        Up to 800 non-background tiles from the first five slides are scored.
        Tiles at or above the 75th percentile are labelled tissue-rich (1),
        tiles below the 25th percentile background-heavy (0), and the rest are
        dropped. The ROC threshold maximising ``tpr - fpr`` is clamped to
        [0.10, 0.75] and returned.

        Fallbacks: with fewer than 100 tiles the 25th percentile of what was
        sampled is returned (``default`` when nothing was sampled at all);
        with too few labelled tiles the median of all fractions is returned.

        Args:
            sz: tile edge length in pixels.
            default: value returned when no tile could be sampled.

        Returns:
            float threshold.
        """
        log_msg("METHOD 3: ROC (Tissue Threshold)")

        fracs = self._sample_tiles(self.slides[:5], sz, self._foreground_tissue_frac, 800)

        if len(fracs) < 100:
            log_msg(f"  ‚ö†Ô∏è Insufficient data for ROC, using percentile-based fallback")
            # np.percentile is undefined on an empty sample, so guard it.
            consensus = float(np.percentile(fracs, 25)) if fracs else float(default)
            self.results['roc'] = {'method': 'fallback_percentile', 'optimal': consensus}
            log_msg(f"‚úÖ Tissue threshold: {consensus:.3f} (percentile fallback)")
            return consensus

        ta_all = np.array(fracs)

        # ADAPTIVE LABELING: quartiles guarantee the cut-offs track the data.
        p25 = float(np.percentile(ta_all, 25))
        p75 = float(np.percentile(ta_all, 75))

        log_msg(f"  üìä Tissue % distribution (initial):")
        log_msg(f"     Samples: {len(ta_all)}")
        log_msg(f"     P25: {p25:.4f}, P50: {np.percentile(ta_all, 50):.4f}, P75: {p75:.4f}")
        log_msg(f"     Min: {ta_all.min():.4f}, Max: {ta_all.max():.4f}")

        # 1 = tissue-rich, 0 = background-heavy, -1 = ambiguous middle range.
        labels_all = np.where(ta_all >= p75, 1, np.where(ta_all < p25, 0, -1))
        keep = labels_all != -1
        ta = ta_all[keep]
        la = labels_all[keep]

        n_tissue = (la == 1).sum()
        n_background = (la == 0).sum()

        if len(ta) < 50 or n_tissue < 10 or n_background < 10:
            log_msg(f"  ‚ö†Ô∏è Insufficient balanced data ({len(ta)} total, {n_tissue} tissue, {n_background} background)")
            log_msg(f"  Using simple median-based threshold instead")
            consensus = float(np.median(ta_all))
            self.results['roc'] = {'method': 'fallback_median', 'optimal': consensus}
            log_msg(f"‚úÖ Tissue threshold: {consensus:.3f} (median fallback)")
            return consensus

        # NOTE(review): labels are derived from the same fractions used as
        # scores, so the classes are separable by construction — confirm the
        # AUC reported below is meant as a sanity value only.
        fpr, tpr, thresholds = roc_curve(la, ta)

        j_scores = tpr - fpr
        optimal_idx = np.argmax(j_scores)
        optimal_threshold = float(thresholds[optimal_idx])
        optimal_j = float(j_scores[optimal_idx])
        roc_auc = float(auc(fpr, tpr))

        # Clamp to reasonable range
        optimal_threshold = max(0.10, min(optimal_threshold, 0.75))

        log_msg(f"  üìä ROC Analysis (adaptive labeling):")
        log_msg(f"     Labeled samples: {len(ta)} (tissue: {n_tissue}, background: {n_background})")
        log_msg(f"     P25 threshold: {p25:.4f}, P75 threshold: {p75:.4f}")
        log_msg(f"     AUC: {roc_auc:.4f}")
        log_msg(f"     Optimal threshold: {optimal_threshold:.4f}")
        log_msg(f"     Youden's J: {optimal_j:.4f}")
        log_msg(f"     Sensitivity: {tpr[optimal_idx]:.4f}")
        log_msg(f"     Specificity: {1-fpr[optimal_idx]:.4f}")

        self.results['roc'] = {
            'optimal': optimal_threshold,
            'j_statistic': optimal_j,
            'auc': roc_auc,
            'sensitivity': float(tpr[optimal_idx]),
            'specificity': float(1 - fpr[optimal_idx]),
            'p25': p25,
            'p75': p75,
            'samples': len(ta),
            'samples_total': len(ta_all),
        }

        log_msg(f"‚úÖ Tissue threshold: {optimal_threshold:.3f} (ROC-optimized with adaptive labeling)")
        return optimal_threshold

    def bootstrap(self, sz, n=50):
        """Bootstrap the 5th percentile of blur scores.

        Up to 200 non-background tiles from the first two slides are scored;
        the 5th percentile is recomputed on ``n`` resamples drawn with
        replacement from numpy's global RNG.

        Returns:
            ``(mean, std)`` of the resampled percentiles, or ``(0.1, 0.0)``
            when fewer than 50 tiles were sampled.
        """
        log_msg("METHOD 4: Bootstrap")
        blurs = self._sample_tiles(self.slides[:2], sz, self._foreground_blur, 200)

        if len(blurs) < 50:
            log_msg("  ‚ö†Ô∏è Insufficient data for bootstrap")
            return 0.1, 0.0

        ba = np.array(blurs)
        bs = [np.percentile(np.random.choice(ba, len(ba), True), 5) for _ in range(n)]
        mu, std = float(np.mean(bs)), float(np.std(bs))

        self.results['bootstrap'] = {'mean': mu, 'std': std}
        log_msg(f"‚úÖ Bootstrap: {mu:.4f}¬±{std:.4f}")
        return mu, std

    def entropy(self, sz):
        """Compute per-channel colour statistics used as stain targets.

        Up to 200 tissue tiles (non-background, tissue fraction >= 0.3) from
        the first three slides are scaled to [0, 1]; the per-tile channel
        means and standard deviations are averaged. Fixed defaults are used
        when fewer than 20 tiles were sampled.

        Returns:
            ``(means, stds)`` as length-3 numpy arrays (R, G, B order of the tile).
        """
        log_msg("METHOD 5: Entropy (Stain)")

        def unit_rgb_if_tissue(t):
            if self._bg(t) or self._tissue_frac(t) < 0.3:
                return None
            return t.astype(np.float32) / 255

        tiles = self._sample_tiles(self.slides[:3], sz, unit_rgb_if_tissue, 200)

        if len(tiles) < 20:
            log_msg("  ‚ö†Ô∏è Insufficient data, using defaults")
            m, s = np.array([0.75, 0.55, 0.45]), np.array([0.15, 0.15, 0.15])
        else:
            ms = [t.mean((0, 1)) for t in tiles]
            ss = [t.std((0, 1)) for t in tiles]
            m, s = np.mean(ms, 0), np.mean(ss, 0)

        self.results['entropy'] = {'means': m.tolist(), 'stds': s.tolist()}
        log_msg(f"‚úÖ Stain: means={m.round(3)}")
        return m, s

    def save(self):
        """Write ``self.results`` plus timestamp and seed to ``optimization.json``.

        Failures are logged rather than raised so calibration output is not
        lost to a write error.
        """
        try:
            with open(f"{OUTPUT_DIR}/optimization.json", 'w', encoding='utf-8') as f:
                json.dump({
                    'timestamp': datetime.now().isoformat(),
                    'seed': RANDOM_SEED,
                    **self.results
                }, f, indent=2)
            log_msg(f"‚úÖ Optimization results saved\n")
        except Exception as e:
            log_msg(f"‚ö†Ô∏è Could not save optimization: {e}")

# ===============================
# INTERPRETABLE FEATURES
# ===============================
class InterpExtractor:
    """Compute hand-crafted, human-readable features for one RGB tile.

    ``extract`` returns a flat dict with six ``nuc_*`` keys, two ``arch_*``
    keys and three ``tex_*`` keys; on any failure every value is 0.
    """

    _NUC_KEYS = ('cnt', 'area_m', 'area_s', 'dens', 'circ', 'sol')
    _ARCH_KEYS = ('arch_org', 'arch_uni')
    _TEX_PROPS = ('contrast', 'homogeneity', 'energy')

    @classmethod
    def _zero_features(cls):
        """Return the full feature dict with every value set to 0."""
        zeros = {f'nuc_{k}': 0 for k in cls._NUC_KEYS}
        zeros.update({k: 0 for k in cls._ARCH_KEYS})
        zeros.update({f'tex_{p}': 0 for p in cls._TEX_PROPS})
        return zeros

    def nuclear(self, t):
        """Return morphology statistics of dark connected regions in the tile.

        Regions are the connected components of pixels darker than 0.8 x the
        Otsu threshold of the grayscale tile. Returns count, mean/std area,
        count per pixel, mean circularity and mean solidity; all zeros when
        no region is found.
        """
        g = rgb2gray(t)
        try:
            b = g < threshold_otsu(g) * 0.8
        except Exception:
            # Fixed cut-off of 100 on a 0-255 scale, expressed on the 0-1
            # scale of the grayscale image produced above.
            b = g < 100 / 255

        regions = regionprops(label(b))

        if not regions:
            return {f'nuc_{k}': 0 for k in self._NUC_KEYS}

        areas = np.array([r.area for r in regions])
        # 4*pi*A/P^2; regions with zero perimeter (single pixels) would
        # otherwise yield ~1e9 and swamp the mean, so cap at the theoretical
        # maximum of 1.
        circ = np.clip(
            np.array([4 * np.pi * r.area / (r.perimeter**2 + 1e-8) for r in regions]),
            0.0,
            1.0,
        )
        solidity = np.array([r.solidity for r in regions])

        return {
            'nuc_cnt': len(regions),
            'nuc_area_m': areas.mean(),
            'nuc_area_s': areas.std(),
            'nuc_dens': len(regions) / b.size,
            'nuc_circ': circ.mean(),
            'nuc_sol': solidity.mean()
        }

    def arch(self, t):
        """Return the mean and std of grayscale variance over 20x20 windows.

        Windows start every 20 pixels at offsets strictly below
        ``dimension - 20``; both values are 0 when the tile is too small to
        hold any window.
        """
        g = rgb2gray(t)

        # Local variance per non-overlapping 20x20 window.
        vs = [np.var(g[i:i + 20, j:j + 20])
              for i in range(0, g.shape[0] - 20, 20)
              for j in range(0, g.shape[1] - 20, 20)]

        return {
            'arch_org': np.mean(vs) if vs else 0,
            'arch_uni': np.std(vs) if vs else 0
        }

    def texture(self, t):
        """Return GLCM contrast, homogeneity and energy (distance 1, angle 0).

        All three values are 0 when the co-occurrence computation fails.
        """
        g = (rgb2gray(t) * 255).astype(np.uint8)

        try:
            glcm = graycomatrix(g, [1], [0], 256, symmetric=True, normed=True)
            return {f'tex_{p}': float(graycoprops(glcm, p)[0, 0]) for p in self._TEX_PROPS}
        except Exception:
            return {f'tex_{p}': 0 for p in self._TEX_PROPS}

    def extract(self, t):
        """Return nuclear, architectural and texture features merged in one dict.

        Never raises: any error in a sub-extractor yields the all-zero dict.
        """
        try:
            return {**self.nuclear(t), **self.arch(t), **self.texture(t)}
        except Exception:
            return self._zero_features()

# ===============================
# ATOM EXTRACTOR (CTRANSPATH)
# ===============================
class ATOMExtractor:
    """Deep feature extractor built on a timm ResNet-50 with external weights.

    The constructor builds a headless ``resnet50`` (global average pooling)
    and loads ``weights/ctranspath.pth`` into it. ``extract`` embeds tiles one
    at a time and aggregates the per-tile vectors into slide-level statistics.

    NOTE(review): the checkpoint is named after CTransPath, which to my
    knowledge is a Swin-Transformer-based model rather than a ResNet-50 —
    confirm the weights actually match this backbone. The constructor now
    refuses to continue when no parameter of the checkpoint matches.
    """

    def __init__(self):
        """Build the backbone, load the checkpoint and prepare the transform.

        Raises:
            RuntimeError: when none of the model's parameters are present in
                the checkpoint (the network would be randomly initialised).
            Exception: any error from model creation or checkpoint loading is
                logged and re-raised.
        """
        log_msg("Loading Ctranspath...")
        try:
            model_path = "weights/ctranspath.pth"

            self.model = timm.create_model(
                'resnet50',
                pretrained=False,
                num_classes=0,
                global_pool='avg'
            ).to(DEVICE)

            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            checkpoint = torch.load(model_path, map_location=DEVICE)

            # Checkpoints may wrap the weights under 'model' or 'state_dict'.
            if isinstance(checkpoint, dict) and 'model' in checkpoint:
                state_dict = checkpoint['model']
            elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint

            # Strip only a leading 'module.' (added by DataParallel), never an
            # occurrence in the middle of a key.
            state_dict = {k.removeprefix('module.'): v for k, v in state_dict.items()}

            # strict=False tolerates extra/missing keys, so inspect the
            # outcome instead of assuming the weights were applied.
            outcome = self.model.load_state_dict(state_dict, strict=False)
            n_expected = len(self.model.state_dict())
            n_loaded = n_expected - len(outcome.missing_keys)
            if n_loaded == 0:
                raise RuntimeError(
                    f"no checkpoint tensor matches the resnet50 backbone "
                    f"({len(outcome.unexpected_keys)} unexpected keys in {model_path})"
                )
            if outcome.missing_keys or outcome.unexpected_keys:
                log_msg(
                    f"  Partial checkpoint load: {n_loaded}/{n_expected} tensors loaded, "
                    f"{len(outcome.missing_keys)} missing, "
                    f"{len(outcome.unexpected_keys)} unexpected"
                )
            self.model.eval()

            log_msg("‚úÖ Ctranspath loaded (2048D)\n")
        except Exception as e:
            log_msg(f"‚ùå Ctranspath loading failed: {e}")
            raise

        # ImageNet channel statistics.
        self.tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def extract(self, tiles, sz=224):
        """Embed tiles and aggregate them into slide-level feature vectors.

        Args:
            tiles: sequence of HxWx3 uint8 arrays; tiles not already
                ``sz`` x ``sz`` are resized.
            sz: model input edge length in pixels.

        Returns:
            Dict with per-feature mean (``atom_m``), std (``atom_s``), max
            (``atom_mx``), min (``atom_mn``) and median (``atom_md``) arrays,
            or None when there are no tiles or no tile could be embedded.

        Tiles whose more than 10% of features lie beyond 5 standard
        deviations are dropped, provided there are more than 10 tiles and
        fewer than half of them would be removed.
        """
        if tiles is None or len(tiles) == 0:
            return None

        fs = []
        n_failed, first_error = 0, None
        log_msg(f"  Extracting ATOM features from {len(tiles)} tiles...")

        for i, t in enumerate(tiles):
            try:
                img = Image.fromarray(t)
                if img.size != (sz, sz):
                    img = img.resize((sz, sz))

                x = self.tf(img).unsqueeze(0).to(DEVICE)

                with torch.no_grad():
                    fs.append(self.model(x).squeeze().cpu().numpy())

                if (i + 1) % 50 == 0:
                    print(f"    {i+1}/{len(tiles)}", end='\r')

            except Exception as e:
                # Keep going, but remember the failure so it can be reported.
                n_failed += 1
                if first_error is None:
                    first_error = e
                continue

        if n_failed:
            log_msg(f"  Skipped {n_failed} tiles that failed to embed (first error: {first_error})")

        if not fs:
            log_msg(f"  ‚ùå No features extracted")
            return None

        fs = np.array(fs)
        log_msg(f"  ‚úì Extracted {len(fs)} tile features")

        if len(fs) > 10:  # too few tiles give meaningless z-scores
            try:
                mean_feat = fs.mean(0)
                std_feat = fs.std(0)

                # Constant features would divide by zero; treat their spread as 1.
                std_feat = np.where(std_feat < 1e-6, 1.0, std_feat)

                z = np.abs((fs - mean_feat) / std_feat)

                # A tile is an outlier only when >10% of its features exceed |z| = 5.
                outlier_mask = (z > 5).sum(axis=1) > (z.shape[1] * 0.1)
                num_outliers = outlier_mask.sum()

                if num_outliers > 0 and num_outliers < len(fs) * 0.5:
                    fs = fs[~outlier_mask]
                    log_msg(f"  üîç Removed {num_outliers} outlier tiles")
                elif num_outliers >= len(fs) * 0.5:
                    log_msg(f"  ‚ö†Ô∏è Too many outliers ({num_outliers}), keeping all tiles")

            except Exception as e:
                log_msg(f"  ‚ö†Ô∏è Outlier detection failed: {e}, keeping all tiles")
        else:
            log_msg(f"  ‚ö†Ô∏è Too few tiles for outlier removal, keeping all")

        if len(fs) == 0:
            log_msg(f"  ‚ùå All tiles removed as outliers")
            return None

        log_msg(f"  ‚úÖ Final: {len(fs)} tiles")

        # fs is guaranteed non-empty here, so the reductions are well defined.
        try:
            return {
                'atom_m': fs.mean(0),
                'atom_s': fs.std(0),
                'atom_mx': fs.max(0),
                'atom_mn': fs.min(0),
                'atom_md': np.median(fs, 0)
            }
        except Exception as e:
            log_msg(f"  ‚ùå Aggregation error: {e}")
            return None

# ===============================
# MAIN PIPELINE
# ===============================
def _collect_tiles(sl, sz, n_tiles, tiss_th, blur_th, blur_fn):
    """Scan an open slide in raster order and return tiles that pass QC.

    Args:
        sl: Open slide object exposing ``get_best_level_for_downsample``,
            ``level_downsamples``, ``level_dimensions`` and ``read_region``.
        sz: Tile edge length in pixels at the chosen level.
        n_tiles: Maximum number of tiles to collect; scanning stops as soon
            as this many tiles have been accepted.
        tiss_th: Minimum fraction of mask pixels required to keep a tile.
        blur_th: Minimum blur score required to keep a tile.
        blur_fn: Callable mapping an RGB tile array to a blur score.

    Returns:
        List of RGB tile arrays (at most ``n_tiles`` of them).
    """
    lv = sl.get_best_level_for_downsample(1)
    ds = sl.level_downsamples[lv]
    w, h = sl.level_dimensions[lv]

    tiles = []
    for y in range(0, h - sz, sz):
        for x in range(0, w - sz, sz):
            # Stop before reading another region once the quota is met.
            if len(tiles) >= n_tiles:
                return tiles

            # read_region takes level-0 coordinates, hence the scaling by ds.
            t = np.array(sl.read_region(
                (int(x * ds), int(y * ds)),
                lv,
                (sz, sz)
            ).convert("RGB"))

            # QC 1: mostly-white tiles are treated as background.
            if np.mean(t) > 220:
                continue

            # QC 2: tissue fraction.
            # NOTE(review): skimage's rgb2gray returns floats in [0, 1] for
            # uint8 RGB input, so `g.std() > 1` looks like it can never be
            # true and `g < 200` would flag every pixel as tissue. Left
            # unchanged because tiss_th is calibrated elsewhere
            # (Optimizer.roc) -- confirm the intended scale there first.
            g = rgb2gray(t)
            m = g < threshold_otsu(g) if g.std() > 1 else g < 200
            if m.sum() / m.size < tiss_th:
                continue

            # QC 3: reject blurry tiles.
            if blur_fn(t) < blur_th:
                continue

            tiles.append(t)

    return tiles


def main():
    """Calibrate QC parameters on a slide subset, then featurize all slides.

    Steps:
      1. List ``.svs`` files in ``SVS_DIR`` (aborts if fewer than 10).
      2. Shuffle them with the seeded NumPy RNG and run the ``Optimizer``
         on the first ``n_calib`` slides to obtain the tile count, blur
         threshold, tissue threshold and stain statistics.
      3. Write the chosen parameters to ``params.json``.
      4. For every slide, collect QC-passing tiles, aggregate interpretable
         features (mean/std per column) and, if the ATOM model loaded,
         the ATOM feature vectors.
      5. Write ``interpretable.csv``, ``atom.csv`` and ``qc.csv`` (with a
         checkpoint every 10th slide index that succeeds) and log a summary.

    Returns:
        None. All results are written under ``OUTPUT_DIR``.
    """
    # os.listdir returns entries in arbitrary order; sort first so the seeded
    # shuffle below selects the same calibration subset on every run/machine.
    files = sorted(f for f in os.listdir(SVS_DIR) if f.lower().endswith('.svs'))

    if len(files) < 10:
        log_msg("‚ùå Need ‚â•10 slides for calibration")
        return

    log_msg(f"Found {len(files)} SVS files")

    # CALIBRATION: Use 25% of slides (for robust optimization)
    # For 111 slides: 111 // 4 = 27 slides
    np.random.shuffle(files)
    # Target 20-30 slides; with fewer than 20 files the slice simply takes all.
    n_calib = max(20, min(30, len(files) // 4))
    cal_paths = [os.path.join(SVS_DIR, f) for f in files[:n_calib]]

    # PROCESSING: Use ALL slides (NO DATA LOSS)
    proc_files = files

    log_msg(f"\n{'='*80}")
    log_msg("STEP 1: OPTIMIZATION (CALIBRATION)")
    log_msg(f"{'='*80}")
    log_msg(f"Calibration slides: {len(cal_paths)} ({len(cal_paths)/len(files)*100:.1f}%)")
    log_msg(f"Processing slides: {len(proc_files)} (ALL - no data loss)\n")

    # Run optimization
    opt = Optimizer(cal_paths, 300)
    sz = 224
    n_tiles = opt.elbow(sz)
    blur_th = opt.youden(sz)
    tiss_th = opt.roc(sz)
    boot_m, boot_s = opt.bootstrap(sz)
    stain_m, stain_s = opt.entropy(sz)
    opt.save()

    # Save parameters
    params = {
        'tile_sz': sz,
        'n_tiles': n_tiles,
        'blur_th': blur_th,
        'tiss_th': tiss_th,
        'stain_m': stain_m.tolist(),
        'stain_s': stain_s.tolist(),
        'seed': RANDOM_SEED,
        'calibration_slides': len(cal_paths),
        'calibration_percentage': len(cal_paths) / len(proc_files) * 100,
        'processing_slides': len(proc_files)
    }

    with open(f"{OUTPUT_DIR}/params.json", 'w') as f:
        json.dump(params, f, indent=2)

    log_msg(f"\n{'='*80}")
    log_msg("STEP 2: FEATURE EXTRACTION")
    log_msg(f"{'='*80}\n")

    # Initialize extractors
    interp = InterpExtractor()

    # ATOM is optional: fall back to interpretable features if it cannot load.
    # Catch Exception (not a bare except) so Ctrl-C / SystemExit still abort,
    # and record why loading failed.
    try:
        atom = ATOMExtractor()
    except Exception as e:
        log_msg("‚ö†Ô∏è ATOM loading failed, continuing with interpretable features only")
        log_msg(f"  ATOM load error: {e}")
        atom = None

    # Storage
    interp_res, atom_res, qc = [], [], []

    # A slide needs at least half the target tile count, and never zero tiles.
    min_tiles = max(1, n_tiles // 2)

    # Process all slides
    for i, fn in enumerate(proc_files, 1):
        log_msg(f"\n[{i}/{len(proc_files)}] {fn}")

        try:
            sl = openslide.OpenSlide(os.path.join(SVS_DIR, fn))
            try:
                tiles = _collect_tiles(sl, sz, n_tiles, tiss_th, blur_th, opt._blur)
            finally:
                # Release the slide handle even if reading a region failed.
                sl.close()

            # Check minimum tiles
            if len(tiles) < min_tiles:
                log_msg(f"  ‚ùå Insufficient tiles: {len(tiles)}")
                qc.append({'slide': fn, 'status': 'fail', 'reason': 'insufficient_tiles', 'tiles': len(tiles)})
                continue

            # Extract interpretable features: one row per tile, then
            # collapse each column to its mean (_m) and std (_s).
            ifs = [interp.extract(t) for t in tiles]
            idf = pd.DataFrame(ifs)

            iagg = {'slide': fn}
            for c in idf.columns:
                iagg[f'{c}_m'] = idf[c].mean()
                iagg[f'{c}_s'] = idf[c].std()

            interp_res.append(iagg)

            # Extract ATOM features; a failure here does not fail the slide.
            if atom:
                try:
                    af = atom.extract(tiles, sz)
                    if af:
                        # Flatten each aggregate vector into scalar columns.
                        aagg = {'slide': fn}
                        for k, v in af.items():
                            for j, x in enumerate(v):
                                aagg[f'{k}_{j}'] = float(x)
                        atom_res.append(aagg)
                    else:
                        log_msg(f"  ‚ö†Ô∏è ATOM extraction returned None")
                except Exception as e:
                    log_msg(f"  ‚ö†Ô∏è ATOM extraction failed: {e}")
                    traceback.print_exc()

            log_msg(f"  ‚úÖ Success: {len(tiles)} tiles")
            qc.append({'slide': fn, 'status': 'ok', 'tiles': len(tiles)})

            # Periodic save
            if i % 10 == 0:
                pd.DataFrame(interp_res).to_csv(f"{OUTPUT_DIR}/interpretable.csv", index=False)
                if atom_res:
                    pd.DataFrame(atom_res).to_csv(f"{OUTPUT_DIR}/atom.csv", index=False)
                pd.DataFrame(qc).to_csv(f"{OUTPUT_DIR}/qc.csv", index=False)
                log_msg(f"  üíæ Checkpoint saved ({i} slides processed)")

        except Exception as e:
            # Per-slide boundary: record the failure and move on to the next slide.
            log_msg(f"  ‚ùå Error: {e}")
            traceback.print_exc()
            qc.append({'slide': fn, 'status': 'fail', 'reason': str(e), 'tiles': 0})

    # Final save
    log_msg(f"\n{'='*80}")
    log_msg("FINAL SAVE")
    log_msg(f"{'='*80}")

    if interp_res:
        pd.DataFrame(interp_res).to_csv(f"{OUTPUT_DIR}/interpretable.csv", index=False)
        log_msg(f"‚úÖ Interpretable features: {len(interp_res)} slides")

    if atom_res:
        pd.DataFrame(atom_res).to_csv(f"{OUTPUT_DIR}/atom.csv", index=False)
        log_msg(f"‚úÖ ATOM features: {len(atom_res)} slides")

    pd.DataFrame(qc).to_csv(f"{OUTPUT_DIR}/qc.csv", index=False)
    log_msg(f"‚úÖ QC report saved")

    # Summary (qc has one entry per processed slide, so it is never empty here)
    qc_df = pd.DataFrame(qc)
    success = (qc_df['status']=='ok').sum()
    failed = (qc_df['status']=='fail').sum()

    log_msg(f"\n{'='*80}")
    log_msg("PIPELINE COMPLETED")
    log_msg(f"{'='*80}")
    log_msg(f"‚úÖ Successful: {success}/{len(qc_df)} ({success/len(qc_df)*100:.1f}%)")
    log_msg(f"‚ùå Failed: {failed}/{len(qc_df)}")
    log_msg(f"\nOutput files:")
    log_msg(f"  - {OUTPUT_DIR}/interpretable.csv")
    log_msg(f"  - {OUTPUT_DIR}/atom.csv")
    log_msg(f"  - {OUTPUT_DIR}/qc.csv")
    log_msg(f"  - {OUTPUT_DIR}/params.json")
    log_msg(f"  - {OUTPUT_DIR}/optimization.json")

# Script entry point: run the full pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

Q1-READY: LUNIT ATOM + INTERPRETABLE FEATURES (PRODUCTION)
Device: cpu
Seed: 42
Output: histology_q1_production_final

Found 111 SVS files

STEP 1: OPTIMIZATION (CALIBRATION)
Calibration slides: 27 (24.3%)
Processing slides: 111 (ALL - no data loss)

METHOD 1: Elbow (Tile Count)
