In [3]:
import sys
import os
sys.path.append(os.path.abspath("../src"))

In [None]:
"""
cloud_detection.py - Memory Efficient Cloud Detection with Classical ML
Uses stratified sampling, optimized feature engineering, and ensemble techniques
"""

import os
import gc
import json
import joblib
import warnings
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import thread_map
from functools import partial
from concurrent.futures import ThreadPoolExecutor, as_completed
from scipy import ndimage
from scipy.ndimage import uniform_filter, minimum_filter, maximum_filter
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import SGDClassifier
from sklearn.svm import LinearSVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix, roc_auc_score, classification_report
from sklearn.feature_selection import SelectFromModel
from sklearn.preprocessing import StandardScaler
from skimage.util import view_as_windows
from skimage.exposure import equalize_hist
from rasterio.errors import NotGeoreferencedWarning

warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)

# ============================================================================
# Configuration
# ============================================================================
class Config:
    """Central configuration namespace for the cloud-detection pipeline.

    Groups output paths, sampling limits, training settings, feature
    engineering switches and estimator keyword arguments as class attributes.
    Executing the class body also creates the output directories listed
    below, so importing/running this cell touches the filesystem.
    """
    # Paths
    PROCESSED_DATA = Path("../data/processed")
    OUTPUT_PATH = Path("../outputs")
    MODEL_PATH = OUTPUT_PATH / "models"
    FEATURES_PATH = OUTPUT_PATH / "features"
    
    # Create directories
    # NOTE: this loop runs once, when the class body is executed, and its loop
    # variable `path` remains afterwards as a class attribute.
    for path in [MODEL_PATH, FEATURES_PATH, OUTPUT_PATH / "logs", OUTPUT_PATH / "predictions"]:
        os.makedirs(path, exist_ok=True)
    
    # Data processing
    TILE_SIZE = 128  # Process images in smaller tiles to save memory
    MAX_SAMPLES_PER_IMAGE = 10000  # Cap samples from each image
    SAMPLES_PER_CATEGORY = 100000  # Balance samples across categories
    PRECOMPUTE_FEATURES = True  # Store features on disk
    
    # Training
    BATCH_SIZE = 5000  # Pixels per batch
    N_EPOCHS = 15
    VAL_SPLIT = 0.2
    TEST_SPLIT = 0.1
    CROSS_VAL_FOLDS = 3  # Number of cross-validation folds
    
    # Early stopping
    # Training stops after EARLY_STOP_PATIENCE consecutive epochs in which the
    # validation F1 fails to beat the best value by more than EARLY_STOP_DELTA.
    EARLY_STOP_PATIENCE = 4
    EARLY_STOP_DELTA = 0.001
    
    # Feature engineering
    SPECTRAL_INDICES = True
    SPATIAL_WINDOW_SIZES = [3, 5, 7]  # Multi-scale spatial features
    TEXTURE_WINDOW = 7
    FEATURE_SELECTION = True  # Use feature selection
    MAX_FEATURES = 30  # Max number of features to keep
    
    # Model parameters
    # Keyword arguments unpacked into SGDClassifier by
    # CloudDetectionModel._create_ensemble (which may replace 'class_weight').
    SGD_PARAMS = {
        'loss': 'modified_huber',
        'penalty': 'elasticnet',
        'alpha': 0.0005,
        'l1_ratio': 0.15,
        'learning_rate': 'adaptive',
        'eta0': 0.02,
        'max_iter': 1,
        'class_weight': 'balanced',
        'n_jobs': -1
    }
    
    # Keyword arguments unpacked into RandomForestClassifier by
    # CloudDetectionModel._create_ensemble (which may replace 'class_weight').
    RF_PARAMS = {
        'n_estimators': 50,
        'max_depth': 10,
        'min_samples_split': 10,
        'min_samples_leaf': 4,
        'max_features': 'sqrt',
        'bootstrap': True,
        'class_weight': 'balanced',
        'n_jobs': -1
    }
    
    # Keyword arguments unpacked into LinearSVC by
    # CloudDetectionModel._create_ensemble (which may replace 'class_weight').
    SVM_PARAMS = {
        'penalty': 'l2',
        'loss': 'squared_hinge',
        'dual': False,
        'C': 0.1,
        'class_weight': 'balanced',
        'max_iter': 1000
    }

    # Post-processing
    # NOTE(review): these three settings are not read by any code visible in
    # this part of the file - presumably consumed by a later prediction step.
    POST_PROCESS = True
    MORPHOLOGY_SIZE = 3
    PROBABILITY_THRESHOLD = 0.4  # Lower than 0.5 to catch more clouds

# Shared instance; the functions and classes below read settings through it.
config = Config()

# ============================================================================
# Utilities
# ============================================================================
def load_and_normalize_tiff(path):
    """Load a 4-band TIFF and rescale every band independently to [0, 1].

    Each band is stretched between its 1st and 99th percentile (robust to
    outliers) and clipped to [0, 1].

    Args:
        path: Location of the raster, in any form accepted by ``rasterio.open``.

    Returns:
        A float32 array of shape (4, H, W), or ``None`` if the file cannot be
        read or does not have exactly four bands. Failures are reported with
        ``print`` rather than raised, because callers treat ``None`` as
        "skip this image".
    """
    try:
        import rasterio
        with rasterio.open(path) as src:
            # src.read() keeps the raster's native dtype, which for imagery is
            # commonly an integer type. Cast to float first: writing the
            # [0, 1] result back into an integer array would truncate nearly
            # every pixel to 0.
            image = src.read().astype(np.float32)

        # Handle different band counts (caught below and reported as a failure)
        if image.shape[0] != 4:
            raise ValueError(f"Expected 4 bands, got {image.shape[0]}")

        # Scale each band independently to [0, 1]
        for i in range(image.shape[0]):
            band = image[i]
            min_val, max_val = np.percentile(band, (1, 99))  # Robust scaling
            # The small epsilon keeps constant bands from dividing by zero.
            image[i] = np.clip((band - min_val) / (max_val - min_val + 1e-8), 0, 1)

        return image
    except Exception as e:
        print(f"Error loading {path}: {e}")
        return None

def load_mask(path):
    """Read the first band of the raster at ``path`` as a uint8 array.

    Returns ``None`` (after printing the reason) when the raster cannot be
    opened or read.
    """
    try:
        import rasterio
        with rasterio.open(path) as dataset:
            first_band = dataset.read(1)
            return first_band.astype(np.uint8)
    except Exception as err:
        print(f"Error loading mask {path}: {err}")
    return None

def generate_tiles(image, tile_size=128, overlap=0):
    """Split an image into square tiles for memory-efficient processing.

    Works for both multi-band arrays shaped (bands, H, W) and single-band
    arrays shaped (H, W); tiling is always applied to the last two axes.
    The final tile along each axis is shifted back so that it ends exactly at
    the image border instead of running past it.

    Args:
        image: Array whose last two axes are the spatial dimensions.
        tile_size: Edge length of each tile in pixels.
        overlap: Number of pixels shared by neighbouring tiles.

    Returns:
        ``(tiles, coords)`` where ``tiles`` is a list of array views and
        ``coords`` the matching list of ``(y, x)`` top-left corners. An image
        smaller than ``tile_size`` along an axis yields tiles covering that
        whole axis, anchored at 0.

    Raises:
        ValueError: If ``overlap`` is not smaller than ``tile_size`` (the
            stride would be zero or negative).
    """
    if overlap >= tile_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than tile_size ({tile_size})"
        )

    h, w = image.shape[-2:]
    step = tile_size - overlap

    def _starts(extent):
        # Tile origins along one axis, with the last one pulled back inside
        # the image. Clamping at 0 matters for images smaller than a tile:
        # a negative origin would be read as "count from the end" when slicing.
        starts = []
        for start in range(0, extent - overlap, step):
            if start + tile_size > extent:
                start = max(extent - tile_size, 0)
            starts.append(start)
        return starts

    row_starts = _starts(h)
    col_starts = _starts(w)

    tiles = []
    coords = []
    for y in row_starts:
        for x in col_starts:
            # Ellipsis keeps any leading band axis untouched.
            tiles.append(image[..., y:y + tile_size, x:x + tile_size])
            coords.append((y, x))

    return tiles, coords

# ============================================================================
# Feature Engineering
# ============================================================================
def calculate_spectral_indices(image):
    """Derive per-pixel spectral indices from a 4-band image.

    Bands are read along axis 0 in the order blue, green, red, NIR, converted
    to float32 and have NaNs replaced by 0 before use.

    Returns:
        dict mapping index name ('ndvi', 'ndwi', 'evi', 'vari', 'sr', 'si',
        'ndbi', in that order) to an array clipped to [-1, 1].

    Raises:
        ValueError: If the image does not have exactly four bands.
    """
    planes = [
        np.nan_to_num(image[k].astype('float32'), nan=0.0)
        for k in range(image.shape[0])
    ]
    if len(planes) != 4:
        raise ValueError(f"Expected 4 bands, got {len(planes)}")
    blue, green, red, nir = planes

    eps = 1e-6  # keeps denominators away from zero

    def _finite(values):
        # Replace NaNs produced by degenerate ratios with 0.
        return np.nan_to_num(values, nan=0.0)

    with np.errstate(divide='ignore', invalid='ignore'):
        raw = {
            # Normalized Difference Vegetation Index
            'ndvi': _finite((nir - red) / (nir + red + eps)),
            # Normalized Difference Water Index
            'ndwi': _finite((green - nir) / (green + nir + eps)),
            # Enhanced Vegetation Index
            'evi': _finite(2.5 * (nir - red) / (nir + 6 * red - 7.5 * blue + 1 + eps)),
            # Visible Atmospherically Resistant Index
            'vari': _finite((green - red) / (green + red - blue + eps)),
            # Simple Ratio
            'sr': _finite(nir / (red + eps)),
            # Shadow index
            'si': _finite((1 - blue) * (1 - green) * (1 - red)),
            # Stored under the name NDBI; computed from NIR and green
            'ndbi': _finite((nir - green) / (nir + green + eps)),
        }

    # Clip extreme values so every index lives in [-1, 1]
    return {name: np.clip(values, -1, 1) for name, values in raw.items()}

def calculate_spatial_features(band, window_size):
    """Compute neighbourhood statistics for a single 2-D band.

    Returns an array with one extra trailing axis of length 6 holding, in
    order: local mean, local standard deviation, local minimum, local maximum,
    local range (max - min) and Sobel gradient magnitude. The gradient does
    not depend on ``window_size``.
    """
    # First and second moments over the window give the local variance.
    local_mean = uniform_filter(band, size=window_size)
    local_mean_sq = uniform_filter(band**2, size=window_size)
    # Rounding can push the variance slightly below zero; floor it at 0.
    local_std = np.sqrt(np.maximum(local_mean_sq - local_mean**2, 0.0))

    local_min = minimum_filter(band, size=window_size)
    local_max = maximum_filter(band, size=window_size)

    # Approximate gradient magnitude from the two Sobel responses.
    grad_rows = ndimage.sobel(band, axis=0)
    grad_cols = ndimage.sobel(band, axis=1)
    gradient = np.sqrt(grad_rows**2 + grad_cols**2)

    layers = (local_mean, local_std, local_min, local_max,
              local_max - local_min, gradient)
    return np.stack(layers, axis=-1)

def calculate_texture_features(band, window_size):
    """Compute three cheap GLCM-style texture descriptors for a 2-D band.

    Returns an array with a trailing axis of length 3 holding, in order: the
    windowed mean of ``band * log(band)`` (entropy-like term), the windowed
    mean of squared deviation from the local mean (contrast-like term) and
    the reciprocal of the local value range (homogeneity-like term).
    """
    # Entropy-like term; the tiny offset avoids log(0).
    entropy_like = uniform_filter(band * np.log(band + 1e-10), size=window_size)

    # Contrast-like term: smoothed squared deviation from the local mean.
    local_mean = uniform_filter(band, size=window_size)
    contrast = uniform_filter((band - local_mean)**2, size=window_size)

    # Homogeneity-like term: flat neighbourhoods give large values.
    lowest = minimum_filter(band, size=window_size)
    highest = maximum_filter(band, size=window_size)
    spread = highest - lowest + 1e-6
    homogeneity = 1.0 / spread

    return np.stack([entropy_like, contrast, homogeneity], axis=-1)

def extract_features(image, indices=None):
    """Build the per-pixel feature matrix for a 4-band (C, H, W) image.

    Column layout of the returned float32 array of shape (H*W, n_features):
      1. the raw bands;
      2. one column per spectral index, in dict order (computed here when
         ``indices`` is None and ``config.SPECTRAL_INDICES`` is enabled);
      3. for each band, for each window in ``config.SPATIAL_WINDOW_SIZES``,
         the six outputs of ``calculate_spatial_features``;
      4. for each band, the three outputs of ``calculate_texture_features``
         at ``config.TEXTURE_WINDOW``.
    NaN and infinite values are replaced by 0 before returning.

    Raises:
        ValueError: If the image does not have exactly four bands.
    """
    if image.shape[0] != 4:
        raise ValueError(f"Expected 4 bands, got {image.shape[0]}")

    # Spectral indices are computed on the band-first layout.
    if indices is None and config.SPECTRAL_INDICES:
        indices = calculate_spectral_indices(image)
    index_maps = indices if indices else {}

    # Pixel-major layout (H, W, C) makes the flattening below straightforward.
    pixels = np.transpose(image, (1, 2, 0))
    h, w, c = pixels.shape
    n_pixels = h * w

    spatial_width = 6   # mean, std, min, max, range, edge
    texture_width = 3   # entropy, variance, homogeneity
    windows = config.SPATIAL_WINDOW_SIZES

    n_columns = (
        c
        + len(index_maps)
        + c * spatial_width * len(windows)
        + c * texture_width
    )
    features = np.empty((n_pixels, n_columns), dtype='float32')

    # 1. Raw bands
    features[:, 0:c] = pixels.reshape(n_pixels, c)
    cursor = c

    # 2. Spectral indices
    for index_values in index_maps.values():
        features[:, cursor] = index_values.ravel()
        cursor += 1

    # 3. Multi-scale spatial features, band-major then window
    for band_idx in range(c):
        plane = pixels[:, :, band_idx]
        for window in windows:
            block = calculate_spatial_features(plane, window)
            features[:, cursor:cursor + spatial_width] = block.reshape(n_pixels, spatial_width)
            cursor += spatial_width

    # 4. Texture features
    for band_idx in range(c):
        plane = pixels[:, :, band_idx]
        block = calculate_texture_features(plane, config.TEXTURE_WINDOW)
        features[:, cursor:cursor + texture_width] = block.reshape(n_pixels, texture_width)
        cursor += texture_width

    # Clean up any remaining NaNs or infs
    return np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)

# ============================================================================
# Data Pipeline
# ============================================================================
class CloudDataset:
    """Turn (image_path, mask_path, category) triples into pixel features/labels.

    In 'train' mode each tile is sub-sampled with class balancing and the
    combined data is served in shuffled batches; in any other mode the data
    is yielded as a single (optionally capped) array pair. Extracted features
    can be cached on disk, one ``.npz`` file per image.
    """

    def __init__(self, image_pairs, mode='train', feature_cache_dir=None):
        """Store the pairs and create the cache directory if one is given.

        Args:
            image_pairs: Sequence of (image_path, mask_path, category) triples.
            mode: 'train' enables sampling and batching; other values do not.
            feature_cache_dir: Directory for cached features, or None to
                disable caching.
        """
        self.image_pairs = image_pairs
        self.mode = mode
        self.feature_cache_dir = feature_cache_dir
        if feature_cache_dir:
            os.makedirs(feature_cache_dir, exist_ok=True)
        self.class_weights = None

    def _get_cache_path(self, img_path):
        """Return the cache file path for ``img_path``, or None without a cache dir."""
        if not self.feature_cache_dir:
            return None
        return Path(self.feature_cache_dir) / f"{img_path.stem}_features.npz"

    def _stratified_sampling(self, mask, n_samples, random_state=None):
        """Draw up to ``n_samples`` flat pixel indices, balanced between classes.

        Half of the budget goes to cloud pixels (mask == 1) and the rest to
        clear pixels (mask == 0). When one class has too few pixels, the other
        class fills the remaining budget. The result never exceeds
        ``n_samples`` entries.

        Args:
            mask: Label array; only values 0 and 1 are sampled.
            n_samples: Maximum number of indices to return.
            random_state: Seed for reproducible draws; None uses the global
                numpy random state.

        Returns:
            int32 array of indices into ``mask.ravel()`` (cloud first, then clear).

        Raises:
            ValueError: If ``mask`` is not a numpy array.
        """
        if not isinstance(mask, np.ndarray):
            raise ValueError("Mask must be a numpy array")
        if n_samples <= 0:
            return np.array([], dtype=np.int32)

        rng = np.random.RandomState(random_state) if random_state is not None else np.random

        flat_mask = mask.ravel()
        cloud_idx = np.flatnonzero(flat_mask == 1)
        clear_idx = np.flatnonzero(flat_mask == 0)

        # Split the budget evenly. The cloud share must not be raised above
        # half of n_samples, otherwise small budgets return more indices than
        # requested and contain no clear pixels at all.
        n_cloud = min(len(cloud_idx), int(n_samples * 0.5))
        n_clear = min(len(clear_idx), n_samples - n_cloud)

        # If clear pixels are scarce, let cloud pixels use the leftover budget.
        if n_cloud + n_clear < n_samples:
            n_cloud = min(len(cloud_idx), n_samples - n_clear)

        sampled_cloud = np.array([], dtype=np.int32)
        sampled_clear = np.array([], dtype=np.int32)

        if n_cloud > 0:
            sampled_cloud = rng.choice(cloud_idx, size=n_cloud, replace=False).astype(np.int32)
        if n_clear > 0:
            sampled_clear = rng.choice(clear_idx, size=n_clear, replace=False).astype(np.int32)

        return np.concatenate([sampled_cloud, sampled_clear]).astype(np.int32)

    def process_image_pair(self, img_path, mask_path, category, random_state=None):
        """Extract features and labels for one image/mask pair, using the cache.

        Args:
            img_path: Path object of the image (its stem names the cache file).
            mask_path: Path of the matching mask.
            category: Accepted for interface compatibility; not used here.
            random_state: Seed forwarded to the per-tile sampler in 'train' mode.

        Returns:
            ``(features, labels)`` arrays, or ``(None, None)`` when loading
            fails, the shapes disagree, or no data is produced.
        """
        cache_path = self._get_cache_path(img_path)
        use_cache = cache_path is not None and config.PRECOMPUTE_FEATURES

        # Try to load from cache first
        if use_cache and cache_path.exists():
            try:
                # The context manager closes the archive's file handle; the
                # arrays are fully read when indexed, so they outlive it.
                with np.load(cache_path) as cached:
                    return cached['features'], cached['labels']
            except Exception as e:
                # Fall through and recompute from the source rasters.
                print(f"Cache read error for {img_path.name}: {e}")

        image = load_and_normalize_tiff(img_path)
        if image is None:
            return None, None

        mask = load_mask(mask_path)
        if mask is None:
            return None, None

        # Tiles are paired positionally, so the rasters must cover the same grid.
        if mask.shape != image.shape[1:]:
            print(
                f"Shape mismatch for {img_path.name}: "
                f"image {image.shape[1:]} vs mask {mask.shape}"
            )
            return None, None

        # Split into tiles for memory efficiency
        img_tiles, _ = generate_tiles(image, config.TILE_SIZE)
        mask_tiles, _ = generate_tiles(mask, config.TILE_SIZE)

        all_features = []
        all_labels = []

        for img_tile, mask_tile in zip(img_tiles, mask_tiles):
            # Calculate indices once per tile and hand them to the extractor
            indices = calculate_spectral_indices(img_tile)
            features = extract_features(img_tile, indices)
            labels = mask_tile.ravel()

            if self.mode == 'train':
                # Spread the per-image budget evenly over the tiles.
                max_samples = min(config.MAX_SAMPLES_PER_IMAGE // len(img_tiles), len(labels))
                if max_samples > 0:
                    idx = self._stratified_sampling(mask_tile, max_samples, random_state)
                    if len(idx) > 0:
                        features = features[idx]
                        labels = labels[idx]

            if len(features) > 0 and len(labels) > 0:
                all_features.append(features)
                all_labels.append(labels)

        if not all_features:
            return None, None

        features = np.vstack(all_features)
        labels = np.hstack(all_labels)

        if use_cache:
            try:
                np.savez_compressed(cache_path, features=features, labels=labels)
            except Exception as e:
                # Caching is best-effort; the computed data is still returned.
                print(f"Cache write error for {img_path.name}: {e}")

        return features, labels

    def batch_generator(self, random_state=None):
        """Yield (features, labels) batches built from every image pair.

        All pairs are processed (in a thread pool) before the first batch is
        yielded. In 'train' mode the pooled data is shuffled and yielded in
        chunks of ``config.BATCH_SIZE``; otherwise one array pair is yielded,
        randomly capped at ``10 * config.MAX_SAMPLES_PER_IMAGE`` rows.
        Nothing is yielded when no pair produces data.

        Args:
            random_state: Seed for sampling, subsampling and shuffling; None
                uses the global numpy random state.
        """
        def process_wrapper(args):
            img_path, mask_path, category = args
            return self.process_image_pair(img_path, mask_path, category, random_state=random_state)

        # os.cpu_count() may return None when the count cannot be determined.
        max_workers = min(os.cpu_count() or 1, 16)
        results = thread_map(process_wrapper,
                             self.image_pairs,
                             max_workers=max_workers,
                             desc=f"Processing {self.mode} data")

        all_features = []
        all_labels = []
        for features, labels in results:
            if features is not None and labels is not None:
                all_features.append(features)
                all_labels.append(labels)

        if not all_features:
            print(f"No valid features found in {self.mode} dataset!")
            return

        X = np.vstack(all_features)
        y = np.hstack(all_labels)

        rng = np.random.RandomState(random_state) if random_state is not None else np.random

        if self.mode != 'train':
            # Still cap maximum size for very large datasets; honour the seed
            # so evaluation subsets are reproducible.
            cap = config.MAX_SAMPLES_PER_IMAGE * 10
            if len(y) > cap:
                idx = rng.choice(len(y), cap, replace=False)
                X = X[idx]
                y = y[idx]

            yield X, y
            return

        # For training, shuffle and batch
        indices = np.arange(len(y))
        rng.shuffle(indices)

        for start_idx in range(0, len(indices), config.BATCH_SIZE):
            batch_indices = indices[start_idx:start_idx + config.BATCH_SIZE]
            yield X[batch_indices], y[batch_indices]

    def calculate_class_weights(self):
        """Compute inverse-frequency class weights from all masks.

        Counts pixels labelled 0 and 1 over every mask that loads, then sets
        ``weight[c] = total / (2 * count[c])``. Falls back to equal weights
        (with a printed warning) if either class is absent.

        Returns:
            dict ``{0: weight, 1: weight}``, also stored on ``self.class_weights``.
        """
        class_counts = {0: 0, 1: 0}

        for _, mask_path, _ in self.image_pairs:
            mask = load_mask(mask_path)
            if mask is None:
                continue
            unique, counts = np.unique(mask, return_counts=True)
            for cls, cnt in zip(unique, counts):
                label = int(cls)
                if label in class_counts:
                    # Plain ints keep the resulting weights as Python floats.
                    class_counts[label] += int(cnt)

        total = sum(class_counts.values())
        if total > 0 and class_counts[0] > 0 and class_counts[1] > 0:
            self.class_weights = {
                0: total / (2 * class_counts[0]),
                1: total / (2 * class_counts[1])
            }
        else:
            print("Warning: Unable to calculate class weights, using default")
            self.class_weights = {0: 1.0, 1: 1.0}

        return self.class_weights

# ============================================================================
# Model Pipeline
# ============================================================================
class CloudDetectionModel:
    """Soft-voting ensemble (SGD, random forest, calibrated linear SVM) with
    standard scaling and optional random-forest-based feature selection.

    Attributes:
        model: The fitted ensemble, or None before training/loading.
        scaler: StandardScaler fitted on the first training batch.
        feature_selector: SelectFromModel instance, or None when feature
            selection was not used.
        selected_features: Indices of the retained feature columns.
        feature_importances: Importances from the selection forest.
    """

    def __init__(self):
        self.model = None
        self.feature_selector = None
        self.scaler = None
        self.selected_features = None
        self.feature_importances = None

    def _create_ensemble(self, class_weights=None):
        """Create the voting ensemble of classifiers.

        Args:
            class_weights: Optional ``{label: weight}`` dict applied to every
                base estimator. When None, each estimator keeps the
                ``class_weight`` configured in its parameter dict instead of
                having it silently replaced by None.
        """
        def _with_weights(base_params):
            params = dict(base_params)
            if class_weights is not None:
                params['class_weight'] = class_weights
            return params

        sgd = SGDClassifier(**_with_weights(config.SGD_PARAMS))
        rf = RandomForestClassifier(**_with_weights(config.RF_PARAMS))

        # LinearSVC has no probability output; calibration adds one so it can
        # take part in soft voting.
        svm = CalibratedClassifierCV(
            LinearSVC(**_with_weights(config.SVM_PARAMS)),
            cv=3
        )

        return VotingClassifier(
            estimators=[
                ('sgd', sgd),
                ('rf', rf),
                ('svm', svm)
            ],
            voting='soft',  # Use probability estimates
            n_jobs=-1
        )

    def select_features(self, X, y):
        """Fit a random forest on (X, y) and keep the most important features.

        At most ``config.MAX_FEATURES`` columns are retained. Stores the
        selector, the selected column indices and the importances on the
        instance.

        Returns:
            ``X`` reduced to the selected columns.
        """
        print("Performing feature selection...")

        rf_selector = RandomForestClassifier(
            n_estimators=50,
            max_depth=10,
            n_jobs=-1,
            random_state=42
        )
        rf_selector.fit(X, y)

        self.feature_importances = rf_selector.feature_importances_

        # threshold=-inf disables the importance cut-off so that max_features
        # alone decides how many columns survive.
        selector = SelectFromModel(
            rf_selector,
            threshold=-np.inf,
            prefit=True,
            max_features=config.MAX_FEATURES
        )

        X_selected = selector.transform(X)
        self.feature_selector = selector
        self.selected_features = selector.get_support(indices=True)

        print(f"Selected {len(self.selected_features)} features out of {X.shape[1]}")
        return X_selected

    def _preprocess(self, X):
        """Apply the fitted scaler and, if one was fitted, the feature selector."""
        X = self.scaler.transform(X)
        if self.feature_selector is not None:
            X = self.feature_selector.transform(X)
        return X

    def fit(self, train_gen_func, val_gen_func=None, class_weights=None):
        """Train the model with early stopping on validation F1.

        The scaler and the optional feature selector are fitted on the first
        training batch only. Each epoch then consumes a fresh generator: models
        exposing ``partial_fit`` are updated batch by batch, others are refitted
        on all batches of the epoch at once. The best model (by validation F1)
        is saved to ``best_model.joblib`` and the last one to
        ``final_model.joblib`` under ``config.MODEL_PATH``.

        Args:
            train_gen_func: Zero-argument callable returning an iterable of
                ``(X, y)`` training batches.
            val_gen_func: Optional callable of the same form for validation.
            class_weights: Optional class-weight dict for the base estimators.

        Returns:
            List of per-epoch validation metric dicts (empty without validation).

        Raises:
            ValueError: If ``train_gen_func`` yields no batches.
        """
        print("Starting model training...")

        # Get initial batch of data for feature selection and scaling
        first_batch = next(iter(train_gen_func()), None)
        if first_batch is None:
            raise ValueError("Training generator yielded no batches; cannot fit the model")
        X_train, y_train = first_batch

        print(f"Initial data shape: {X_train.shape}")

        self.scaler = StandardScaler()
        X_train = self.scaler.fit_transform(X_train)

        # Drop any selector left over from a previous fit so preprocessing
        # always matches the current configuration.
        self.feature_selector = None
        self.selected_features = None
        if config.FEATURE_SELECTION:
            X_train = self.select_features(X_train, y_train)

        print("Creating model...")
        self.model = self._create_ensemble(class_weights)
        incremental = hasattr(self.model, 'partial_fit')

        best_f1 = 0
        no_improvement_count = 0
        val_metrics_history = []

        for epoch in range(config.N_EPOCHS):
            print(f"\nEpoch {epoch + 1}/{config.N_EPOCHS}")

            # Training phase
            batch_count = 0
            X_parts, y_parts = [], []
            for X_batch, y_batch in tqdm(train_gen_func(), desc="Training"):
                X_batch = self._preprocess(X_batch)

                if incremental:
                    self.model.partial_fit(X_batch, y_batch, classes=[0, 1])
                else:
                    # Collect and stack once: stacking inside the loop would
                    # re-copy all earlier batches on every iteration.
                    X_parts.append(X_batch)
                    y_parts.append(y_batch)

                batch_count += 1

            # For non-incremental models, fit on accumulated data
            if not incremental and batch_count > 0:
                X_accumulated = np.vstack(X_parts)
                y_accumulated = np.hstack(y_parts)
                del X_parts, y_parts
                print(f"Fitting model on {len(y_accumulated)} samples...")
                self.model.fit(X_accumulated, y_accumulated)

            print(f"Trained on {batch_count} batches")

            # Validation phase
            if val_gen_func:
                val_preds, val_true, val_proba = [], [], []
                has_proba = hasattr(self.model, 'predict_proba')

                for X_val, y_val in val_gen_func():
                    X_val = self._preprocess(X_val)

                    val_preds.extend(self.model.predict(X_val))
                    val_true.extend(y_val)

                    if has_proba:
                        val_proba.extend(self.model.predict_proba(X_val)[:, 1])

                if val_true:
                    metrics = calculate_metrics(val_true, val_preds, val_proba if val_proba else None)
                    val_metrics_history.append(metrics)

                    current_f1 = metrics['f1']
                    print(f"Val F1: {current_f1:.4f} | Best: {best_f1:.4f}")

                    # Early stopping check
                    if current_f1 > best_f1 + config.EARLY_STOP_DELTA:
                        best_f1 = current_f1
                        no_improvement_count = 0
                        self.save(config.MODEL_PATH / "best_model.joblib")
                        print("↑ New best model saved ↑")
                    else:
                        no_improvement_count += 1
                        print(f"No improvement ({no_improvement_count}/{config.EARLY_STOP_PATIENCE})")

                        if no_improvement_count >= config.EARLY_STOP_PATIENCE:
                            print(f"Early stopping triggered at epoch {epoch+1}!")
                            break

        # Save final model
        self.save(config.MODEL_PATH / "final_model.joblib")
        return val_metrics_history

    def predict(self, X):
        """Return binary predictions for raw (unscaled) feature rows.

        Uses the ensemble's own decision rule; no custom probability
        threshold is applied here.
        """
        return self.model.predict(self._preprocess(X))

    def predict_proba(self, X):
        """Return the positive-class probability for raw feature rows.

        Falls back to the binary predictions cast to float when the model has
        no ``predict_proba``.
        """
        if hasattr(self.model, 'predict_proba'):
            return self.model.predict_proba(self._preprocess(X))[:, 1]
        # Fall back to binary predictions
        return self.predict(X).astype(float)

    def save(self, path):
        """Persist the model and its preprocessing components with joblib."""
        model_data = {
            'model': self.model,
            'scaler': self.scaler,
            'feature_selector': self.feature_selector,
            'selected_features': self.selected_features,
            'feature_importances': self.feature_importances
        }
        joblib.dump(model_data, path)
        print(f"Model saved to {path}")

    @classmethod
    def load(cls, path):
        """Load a model previously written by ``save``.

        SECURITY NOTE: joblib files are pickle-based and can execute arbitrary
        code on load - only open files from a trusted source.
        """
        model_data = joblib.load(path)

        model = cls()
        model.model = model_data['model']
        model.scaler = model_data['scaler']
        model.feature_selector = model_data['feature_selector']
        model.selected_features = model_data['selected_features']
        model.feature_importances = model_data['feature_importances']
        return model

# ============================================================================
# Evaluation and Metrics
# ============================================================================
def calculate_metrics(y_true, y_pred, y_proba=None):
    """Collect classification metrics for a set of predictions.

    The returned dict always has 'accuracy', 'precision', 'recall', 'f1',
    'confusion_matrix' (nested list) and 'classification_report' (dict).
    'roc_auc' is added only when probabilities are supplied and ``y_true``
    contains more than one class; if its computation fails, 0.5 is stored.
    """
    metrics = {}
    metrics['accuracy'] = accuracy_score(y_true, y_pred)
    metrics['precision'] = precision_score(y_true, y_pred, zero_division=0)
    metrics['recall'] = recall_score(y_true, y_pred, zero_division=0)
    metrics['f1'] = f1_score(y_true, y_pred, zero_division=0)
    metrics['confusion_matrix'] = confusion_matrix(y_true, y_pred).tolist()
    metrics['classification_report'] = classification_report(y_true, y_pred, output_dict=True)

    # ROC AUC is undefined without scores or with a single class present.
    if y_proba is None or len(np.unique(y_true)) <= 1:
        return metrics

    try:
        auc = roc_auc_score(y_true, y_proba)
    except Exception:
        auc = 0.5  # Default value if calculation fails
    metrics['roc_auc'] = auc

    return metrics

def save_metrics(metrics, name, output_dir):
    """Write ``metrics`` as indented JSON to ``<output_dir>/logs/<name>_metrics.json``.

    The ``logs`` directory is created when missing, and the destination path
    is printed once the file has been written.
    """
    logs_root = output_dir / "logs"
    os.makedirs(logs_root, exist_ok=True)

    destination = logs_root / f"{name}_metrics.json"
    with open(destination, 'w') as handle:
        json.dump(metrics, handle, indent=4)

    print(f"Saved {name} metrics to {destination}")

def plot_feature_importance(model, output_path=None):
    """Draw a horizontal bar chart of the (up to) 20 largest feature importances.

    Does nothing unless ``model`` has both ``feature_importances`` and
    ``selected_features`` set. The figure is saved to ``output_path`` when
    one is given, otherwise shown; it is closed afterwards either way.
    """
    if model.feature_importances is None or model.selected_features is None:
        return

    plt.figure(figsize=(12, 8))

    importances = model.feature_importances
    top = np.argsort(importances)[-20:]  # ascending order, largest last
    positions = range(len(top))

    plt.barh(positions, importances[top])
    plt.yticks(positions, [f"Feature {i}" for i in top])
    plt.xlabel('Feature Importance')
    plt.title('Top 20 Important Features')

    if not output_path:
        plt.show()
    else:
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Feature importance plot saved to {output_path}")

    plt.close()

# ============================================================================
# Cross-Validation
# ============================================================================
def cross_validate(image_pairs, n_folds=5, random_state=42):
    """Run stratified k-fold cross-validation over image pairs.

    Folds are stratified on the category of each pair. For every fold a fresh
    ``CloudDetectionModel`` is trained; when training returns a validation
    history, the entry with the highest F1 is kept and the fold's model is
    saved as ``model_fold_<k>.joblib`` under ``config.MODEL_PATH``.

    Args:
        image_pairs: Sequence of ``(image_path, mask_path, category)`` triples.
        n_folds: Number of folds.
        random_state: Seed for the fold shuffling.

    Returns:
        List with one metrics dict per fold that produced a validation
        history. Each dict is a copy of that fold's best validation metrics
        plus a 1-based ``'fold'`` entry naming the fold (and therefore the
        saved model file) it belongs to.
    """
    print(f"Starting {n_folds}-fold cross-validation...")

    # Prepare folds
    kf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=random_state)

    # Extract categories for stratification; ids are assigned in sorted order
    # so the mapping does not depend on set iteration order.
    categories = [category for _, _, category in image_pairs]
    category_ids = {cat: i for i, cat in enumerate(sorted(set(categories)))}
    stratify_values = [category_ids[cat] for cat in categories]

    cv_metrics = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(image_pairs, stratify_values)):
        fold_number = fold + 1
        print(f"\nFold {fold_number}/{n_folds}")

        # Split data
        train_pairs = [image_pairs[i] for i in train_idx]
        val_pairs = [image_pairs[i] for i in val_idx]

        # Create datasets, each with its own per-fold feature cache directory
        train_ds = CloudDataset(
            train_pairs,
            mode='train',
            feature_cache_dir=config.FEATURES_PATH / f"fold_{fold_number}_train"
        )
        val_ds = CloudDataset(
            val_pairs,
            mode='val',
            feature_cache_dir=config.FEATURES_PATH / f"fold_{fold_number}_val"
        )

        # Calculate class weights
        class_weights = train_ds.calculate_class_weights()

        # Initialize and train model
        model = CloudDetectionModel()
        val_history = model.fit(
            train_ds.batch_generator,
            val_ds.batch_generator,
            class_weights
        )

        # Get best validation metrics
        best_metrics = max(val_history, key=lambda x: x['f1']) if val_history else None
        if best_metrics:
            print(f"Fold {fold_number} best F1: {best_metrics['f1']:.4f}")

            # Store a copy tagged with its fold number: a fold without a
            # validation history is skipped, so list position alone does not
            # identify the fold.
            cv_metrics.append(dict(best_metrics, fold=fold_number))

            # Save fold-specific model
            model.save(config.MODEL_PATH / f"model_fold_{fold_number}.joblib")

    # Calculate average metrics
    if cv_metrics:
        avg_metrics = {
            key: float(np.mean([m[key] for m in cv_metrics]))
            for key in ('accuracy', 'precision', 'recall', 'f1')
        }

        # ROC AUC may be missing from individual folds, so average only the
        # folds that report it instead of trusting the first fold's keys.
        auc_values = [m['roc_auc'] for m in cv_metrics if 'roc_auc' in m]
        if auc_values:
            avg_metrics['roc_auc'] = float(np.mean(auc_values))

        print("\nCross-Validation Results:")
        print(f"Avg Accuracy: {avg_metrics['accuracy']:.4f}")
        print(f"Avg Precision: {avg_metrics['precision']:.4f}")
        print(f"Avg Recall: {avg_metrics['recall']:.4f}")
        print(f"Avg F1 Score: {avg_metrics['f1']:.4f}")
        if 'roc_auc' in avg_metrics:
            print(f"Avg ROC AUC: {avg_metrics['roc_auc']:.4f}")

        # Save CV metrics
        all_cv_data = {
            'folds': cv_metrics,
            'average': avg_metrics
        }
        save_metrics(all_cv_data, 'cross_validation', config.OUTPUT_PATH)

    return cv_metrics

# ============================================================================
# Prediction and Visualization
# ============================================================================
def plot_image_and_mask(image, mask, output_path=None):
    """Plot a satellite image, its cloud mask and a red overlay side by side.

    Args:
        image: Array that is either channel-first ``(C, H, W)`` with 3 or 4
            channels, or already channel-last ``(H, W, C)``. Only the first
            three channels are displayed.
        mask: 2-D cloud mask matching the image height and width.
        output_path: When given, the figure is saved there; otherwise it is
            shown interactively.
    """
    # Convert image from CHW to HWC format if needed
    if image.ndim == 3 and image.shape[0] in (3, 4):
        rgb_image = np.transpose(image[:3], (1, 2, 0))
    else:
        rgb_image = image

    # Keep exactly the first three channels everywhere: a 4-band HWC input
    # would otherwise be blended with 4 channels and rendered as RGBA.
    rgb_image = rgb_image[:, :, :3]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        # Plot original image (RGB bands)
        axes[0].imshow(np.clip(rgb_image, 0, 1))
        axes[0].set_title("RGB Image")
        axes[0].axis('off')

        # Plot mask
        axes[1].imshow(mask, cmap='gray')
        axes[1].set_title("Cloud Mask")
        axes[1].axis('off')

        # Red overlay for clouds; built as float so the 0.7 intensity is not
        # truncated when the image has an integer dtype.
        overlay_mask = np.zeros(rgb_image.shape, dtype=np.float32)
        overlay_mask[:, :, 0] = np.asarray(mask, dtype=np.float32) * 0.7

        # Blend mask with image
        alpha = 0.5
        blended = (1 - alpha) * rgb_image + alpha * overlay_mask
        axes[2].imshow(np.clip(blended, 0, 1))
        axes[2].set_title("Cloud Overlay")
        axes[2].axis('off')

        plt.tight_layout()

        if output_path:
            plt.savefig(output_path, dpi=300, bbox_inches='tight')
            print(f"Visualization saved to {output_path}")
        else:
            plt.show()
    finally:
        # Release the figure even when drawing or saving fails
        plt.close(fig)

def predict_full_image(model, image_path, output_path=None):
    """Predict cloud mask for a full image with tiling.

    The image is split into ``config.TILE_SIZE`` tiles, per-pixel cloud
    probabilities are predicted tile by tile, stitched into one probability
    map, thresholded with ``config.PROBABILITY_THRESHOLD`` and optionally
    cleaned with a morphological opening followed by a closing.

    Args:
        model: Object whose ``predict_proba(features)`` returns one
            probability per feature row (assumed: one row per tile pixel).
        image_path: Path of the image to load.
        output_path: When given, a visualization is saved there.

    Returns:
        ``uint8`` array of shape ``(H, W)`` with 1 for cloud and 0 otherwise,
        or ``None`` when the image could not be loaded.
    """
    # Load image
    image = load_and_normalize_tiff(image_path)
    if image is None:
        print(f"Failed to load image: {image_path}")
        return None

    # Process image in tiles to save memory
    tile_size = config.TILE_SIZE
    img_tiles, coords = generate_tiles(image, tile_size)
    h, w = image.shape[1], image.shape[2]

    # Initialize full mask
    full_mask = np.zeros((h, w), dtype=np.float32)

    # Process each tile
    for tile, (y, x) in zip(img_tiles, coords):
        # Extract features
        indices = calculate_spectral_indices(tile)
        features = extract_features(tile, indices)

        # Get probabilities
        proba = np.asarray(model.predict_proba(features))

        # Reshape to tile shape
        if proba.size == tile_size * tile_size:
            tile_mask = proba.reshape(tile_size, tile_size)
        else:
            # Partial border tile: assumes the tile is channel-first like the
            # full image, so its last two axes are the spatial ones.
            tile_mask = proba.reshape(tile.shape[-2], tile.shape[-1])

        # Place in full mask, cropping whatever extends past the image border
        # (a full-size tile on the border would otherwise fail to broadcast).
        y_end = min(y + tile_mask.shape[0], h)
        x_end = min(x + tile_mask.shape[1], w)
        full_mask[y:y_end, x:x_end] = tile_mask[:y_end - y, :x_end - x]

    # Apply threshold and post-processing
    binary_mask = (full_mask > config.PROBABILITY_THRESHOLD).astype(np.uint8)

    if config.POST_PROCESS:
        structure = np.ones((config.MORPHOLOGY_SIZE, config.MORPHOLOGY_SIZE))

        # Remove small isolated pixels (noise)
        binary_mask = ndimage.binary_opening(binary_mask, structure=structure)

        # Fill small holes
        binary_mask = ndimage.binary_closing(binary_mask, structure=structure)

        # The morphology routines return boolean arrays; convert back so the
        # result dtype is the same with and without post-processing.
        binary_mask = binary_mask.astype(np.uint8)

    # Visualize and save if needed
    if output_path:
        plot_image_and_mask(image, binary_mask, output_path)

    return binary_mask

# ============================================================================
# Main Execution
# ============================================================================
def _find_image_pairs(seed=42):
    """Collect ``(image_path, mask_path, category)`` triples from the processed data folders.

    Only images with a same-named mask file are kept. When
    ``config.SAMPLES_PER_CATEGORY`` is positive, each category is subsampled
    to at most that many pairs using a generator seeded with ``seed``.
    """
    rng = np.random.default_rng(seed)
    image_pairs = []
    for category in ['cloud_free', 'partially_clouded', 'fully_clouded']:
        img_dir = config.PROCESSED_DATA / "data" / category
        mask_dir = config.PROCESSED_DATA / "masks" / category

        if not img_dir.exists() or not mask_dir.exists():
            print(f"Warning: Directory not found: {img_dir} or {mask_dir}")
            continue

        # Sorted so the listing does not depend on filesystem ordering
        matched = [
            (img_file, mask_dir / img_file.name, category)
            for img_file in sorted(img_dir.glob('*.tif'))
            if (mask_dir / img_file.name).exists()
        ]

        # Limit samples per category if needed; the seeded permutation keeps
        # the selection reproducible, like the fixed-seed splits that follow.
        if config.SAMPLES_PER_CATEGORY > 0:
            order = rng.permutation(len(matched))[:config.SAMPLES_PER_CATEGORY]
            matched = [matched[i] for i in order]

        image_pairs.extend(matched)

    return image_pairs


def _pick_example_images(test_pairs, max_examples=3):
    """Return up to ``max_examples`` ``(image_path, category)`` entries from the test split, one per category."""
    examples = []
    seen_categories = set()
    for img_path, _, category in test_pairs:
        if category in seen_categories:
            continue
        seen_categories.add(category)
        examples.append((img_path, category))
        if len(examples) >= max_examples:
            break
    return examples


def main():
    """Run the whole pipeline.

    Steps: find image/mask pairs, split them into train/validation/test,
    build the datasets, compute class weights, obtain a model (best
    cross-validation fold when ``config.CROSS_VAL_FOLDS > 1``, otherwise a
    single training run), evaluate on the test set, plot feature importances
    and write example predictions for images from the test split.
    """
    print("Cloud Detection Pipeline")
    print("=======================")

    # 1. Find and list image pairs
    print("Finding image pairs...")
    image_pairs = _find_image_pairs()

    print(f"Found {len(image_pairs)} image pairs")
    if not image_pairs:
        print("No image pairs found. Check your data paths.")
        return

    # 2. Split data
    print("\nSplitting datasets...")
    train_pairs, val_test_pairs = train_test_split(
        image_pairs,
        test_size=config.VAL_SPLIT + config.TEST_SPLIT,
        random_state=42
    )

    val_pairs, test_pairs = train_test_split(
        val_test_pairs,
        test_size=config.TEST_SPLIT / (config.VAL_SPLIT + config.TEST_SPLIT),
        random_state=42
    )

    print(f"Train: {len(train_pairs)}, Validation: {len(val_pairs)}, Test: {len(test_pairs)}")

    # 3. Create datasets with caching
    train_ds = CloudDataset(
        train_pairs,
        mode='train',
        feature_cache_dir=config.FEATURES_PATH / "train"
    )

    val_ds = CloudDataset(
        val_pairs,
        mode='val',
        feature_cache_dir=config.FEATURES_PATH / "val"
    )

    test_ds = CloudDataset(
        test_pairs,
        mode='test',
        feature_cache_dir=config.FEATURES_PATH / "test"
    )

    # 4. Calculate class weights
    print("\nCalculating class weights...")
    class_weights = train_ds.calculate_class_weights()
    print(f"Class weights: {class_weights}")

    # 5. Run cross-validation or train single model
    model = None
    if config.CROSS_VAL_FOLDS > 1:
        print("\nRunning cross-validation...")
        cv_metrics = cross_validate(train_pairs, config.CROSS_VAL_FOLDS)

        if cv_metrics:
            # Find best fold (first one wins ties). Metrics may carry an
            # explicit 'fold' number; fall back to the 1-based list position
            # when they do not.
            position, best_entry = max(
                enumerate(cv_metrics, 1),
                key=lambda item: item[1]['f1']
            )
            best_fold = best_entry.get('fold', position)
            best_model_path = config.MODEL_PATH / f"model_fold_{best_fold}.joblib"

            print(f"Loading best model from {best_model_path}")
            model = CloudDetectionModel.load(best_model_path)
        else:
            # Without fold metrics there is no fold model to load from this
            # run, so fall through to a regular training run.
            print("Cross-validation produced no metrics; training a single model instead.")

    if model is None:
        # Train a single model
        print("\nTraining model...")
        model = CloudDetectionModel()
        model.fit(
            train_ds.batch_generator,
            val_ds.batch_generator,
            class_weights
        )

    # 6. Final evaluation on test set
    print("\nEvaluating on test set...")
    test_preds, test_true, test_proba = [], [], []

    for X_test, y_test in test_ds.batch_generator():
        test_preds.extend(model.predict(X_test))
        test_true.extend(y_test)
        test_proba.extend(model.predict_proba(X_test))

    if test_true:
        test_metrics = calculate_metrics(test_true, test_preds, test_proba)
        print("\nTest Set Results:")
        print(f"Accuracy: {test_metrics['accuracy']:.4f}")
        print(f"Precision: {test_metrics['precision']:.4f}")
        print(f"Recall: {test_metrics['recall']:.4f}")
        print(f"F1 Score: {test_metrics['f1']:.4f}")
        if 'roc_auc' in test_metrics:
            print(f"ROC AUC: {test_metrics['roc_auc']:.4f}")

        # Save test metrics
        save_metrics(test_metrics, 'test', config.OUTPUT_PATH)

    # 7. Plot feature importance
    plot_feature_importance(model, config.OUTPUT_PATH / "logs" / "feature_importance.png")

    # 8. Example predictions: use images that belong to the test split (one
    # per category where possible) rather than whichever file happens to be
    # listed first in a category directory.
    print("\nGenerating example predictions...")
    for img_path, category in _pick_example_images(test_pairs):
        print(f"Predicting for {img_path.name} ({category})...")
        output_path = config.OUTPUT_PATH / "predictions" / f"pred_{img_path.stem}.png"
        predict_full_image(model, img_path, output_path)

# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

Cloud Detection Pipeline
Finding image pairs...
Found 10574 image pairs

Splitting datasets...
Train: 7401, Validation: 2115, Test: 1058

Calculating class weights...
Class weights: {0: 1.2817360952698484, 1: 0.8198010191837211}

Running cross-validation...
Starting 3-fold cross-validation...

Fold 1/3
Starting model training...


Processing train data:   0%|          | 0/4934 [00:00<?, ?it/s]