In [None]:
import sys
import os
sys.path.append(os.path.abspath("../src"))

In [None]:
import os
import numpy as np
from pathlib import Path
from tqdm import tqdm
from scipy import ndimage
from scipy.ndimage import uniform_filter, minimum_filter, maximum_filter
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import (accuracy_score, precision_score, recall_score,confusion_matrix, 
                             f1_score, roc_auc_score, classification_report)
import joblib
import gc
from utils import load_and_normalize_tiff, load_mask
import json
from visualization import plot_image_and_mask

import warnings
from rasterio.errors import NotGeoreferencedWarning
warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)

In [None]:
# ============================================================================
# Configuration
# ============================================================================
class Config:
    """Paths and hyper-parameters shared by every cell of this notebook."""

    # --- Filesystem layout -------------------------------------------------
    PROCESSED_DATA = Path("../data/processed")
    OUTPUT_PATH = Path("../outputs")
    MODEL_PATH = OUTPUT_PATH / "models"
    # Runs once, when the class body is evaluated, so the model folder
    # exists before anything tries to dump a model into it.
    MODEL_PATH.mkdir(parents=True, exist_ok=True)

    # --- Training loop -----------------------------------------------------
    BATCH_SIZE = 3000   # pixels drawn per batch
    N_EPOCHS = 20
    VAL_SPLIT = 0.2     # fraction of image pairs held out for validation
    TEST_SPLIT = 0.1    # fraction of image pairs held out for testing

    # --- Early stopping ----------------------------------------------------
    EARLY_STOP_PATIENCE = 3    # epochs tolerated without improvement
    EARLY_STOP_DELTA = 0.001   # smallest F1 gain that counts as improvement

    # --- Classifier --------------------------------------------------------
    SGD_PARAMS = dict(
        loss='modified_huber',     # smooth, outlier-tolerant loss
        penalty='l2',
        alpha=0.0001,
        learning_rate='adaptive',
        eta0=0.01,
        max_iter=1,
        class_weight=None,         # placeholder; train_model supplies the real weights
    )

    # --- Feature engineering -----------------------------------------------
    SPATIAL_WINDOW = 3   # side of the square window for spatial statistics
    TEXTURE_WINDOW = 5   # 5x5 texture window (not referenced in the visible cells)


config = Config()

In [None]:
# ============================================================================
# Feature Extraction (Optimized)
# ============================================================================
def calculate_spectral_indices(image):
    """Compute NDVI and NDWI for a band-first 4-band image.

    Bands 0, 1 and 3 are read as red, green and near-infrared respectively.
    Any NaN or infinite value in the result is replaced by 0.

    Returns a dict with keys 'ndvi' and 'ndwi', each a float32 array with the
    spatial shape of one band. Raises ValueError unless the input is 3-D with
    exactly four bands on the first axis.
    """
    if image.ndim != 3 or image.shape[0] != 4:
        raise ValueError(f"Unexpected image shape: {image.shape}")

    red, green, nir = (image[band].astype('float32') for band in (0, 1, 3))
    eps = 1e-6  # keeps the denominator away from exactly zero

    def _normalized_difference(first, second):
        # (first - second) / (first + second), with non-finite results zeroed.
        with np.errstate(divide='ignore', invalid='ignore'):
            ratio = (first - second) / (first + second + eps)
        return np.nan_to_num(ratio, nan=0.0, posinf=0.0, neginf=0.0)

    return {
        'ndvi': _normalized_difference(nir, red),
        'ndwi': _normalized_difference(green, nir),
    }

def vectorized_spatial_features(band, window=None):
    """Per-pixel neighbourhood statistics of a single 2-D band.

    Parameters
    ----------
    band : array-like
        One image band. Non-floating input is converted to float32 first,
        because ``uniform_filter`` returns the input dtype (integer means
        would be truncated) and ``band**2`` can overflow narrow integers.
    window : int, optional
        Side length of the square neighbourhood. Defaults to
        ``config.SPATIAL_WINDOW`` when omitted.

    Returns
    -------
    np.ndarray
        Array of shape ``band.shape + (4,)`` holding, in order, the local
        mean, standard deviation, minimum and maximum.

    Notes
    -----
    NaNs are not treated specially here; callers are expected to clean the
    band beforehand.
    """
    if window is None:
        window = config.SPATIAL_WINDOW

    band = np.asarray(band)
    if not np.issubdtype(band.dtype, np.floating):
        band = band.astype('float32')

    mean = uniform_filter(band, size=window)
    mean_sq = uniform_filter(band**2, size=window)
    # E[x^2] - E[x]^2 can dip slightly below zero through rounding; clamp it
    # so the square root stays real.
    variance = np.maximum(mean_sq - mean**2, 0.0)
    std = np.sqrt(variance)

    minima = minimum_filter(band, size=window)
    maxima = maximum_filter(band, size=window)

    return np.stack([mean, std, minima, maxima], axis=-1)

def extract_features(image, indices):
    """Build the per-pixel feature matrix for one 4x512x512 image.

    Column layout of the returned (512*512, 22) float32 array:
      0-3   raw band values
      4     NDVI   (``indices['ndvi']``)
      5     NDWI   (``indices['ndwi']``)
      6-21  four spatial statistics for each of the four bands, as returned
            by ``vectorized_spatial_features``

    Raises ValueError when ``image.shape`` is not exactly (4, 512, 512).
    """
    expected_shape = (4, 512, 512)
    if image.shape != expected_shape:
        raise ValueError(f"Unexpected image shape: {image.shape}")

    n_bands, height, width = expected_shape
    stats_per_band = 4

    # Channel-last view so that flattening puts one pixel on each row.
    hwc = np.transpose(image, (1, 2, 0))

    features = np.empty((height * width, 22), dtype='float32')  # 4 + 2 + 4*4
    features[:, :n_bands] = hwc.reshape(-1, n_bands)
    features[:, 4] = indices['ndvi'].ravel()
    features[:, 5] = indices['ndwi'].ravel()

    for band_idx in range(n_bands):
        first_col = 6 + band_idx * stats_per_band
        band_stats = vectorized_spatial_features(hwc[:, :, band_idx])
        features[:, first_col:first_col + stats_per_band] = band_stats.reshape(-1, stats_per_band)

    return features

In [None]:
# ============================================================================
# Memory-Efficient Data Pipeline
# ============================================================================
class CloudDataset:
    """Streams class-balanced pixel batches from (image, mask, category) triples.

    ``image_pairs`` is an iterable of 3-tuples whose first two items are the
    image path and the mask path; the third item is ignored by this class.
    ``class_weights`` stays ``None`` until ``calculate_class_weights`` runs.
    """

    def __init__(self, image_pairs):
        self.image_pairs = image_pairs
        self.class_weights = None

    def _stratified_sampling(self, mask, n_samples):
        """Draw pixel indices aiming for a 50/50 cloud/clear split.

        Half of ``n_samples`` is requested from pixels labelled 1 and the rest
        from pixels labelled 0, each without replacement. When a class has
        fewer pixels than requested, all of its pixels are taken and the
        shortfall is NOT topped up from the other class, so the result can be
        shorter than ``n_samples``. Pixels with any other label are never
        sampled.

        Returns a 1-D array of indices into ``mask.ravel()``.
        """
        flat_mask = mask.ravel()
        cloud_idx = np.flatnonzero(flat_mask == 1)
        clear_idx = np.flatnonzero(flat_mask == 0)

        n_cloud = int(n_samples * 0.5)
        n_clear = n_samples - n_cloud

        sampled_cloud = np.random.choice(cloud_idx, min(n_cloud, len(cloud_idx)), replace=False)
        sampled_clear = np.random.choice(clear_idx, min(n_clear, len(clear_idx)), replace=False)

        return np.concatenate([sampled_cloud, sampled_clear])

    def batch_generator(self):
        """Yield ``(features, labels)`` batches, one image at a time.

        For every image, ``len(labels) // config.BATCH_SIZE`` batches are
        produced, each an independent random draw from
        ``_stratified_sampling`` (so pixels may repeat across batches).

        NOTE(review): the same generator is used for validation and test data
        in this notebook, so those metrics are computed on resampled, balanced
        pixels rather than on the natural class distribution - confirm this is
        intended.
        """
        for img_path, mask_path, _ in self.image_pairs:
            image = load_and_normalize_tiff(img_path)
            if np.isnan(image).any():
                print(f"Warning: NaNs detected in {img_path}, replacing with 0")
                image = np.nan_to_num(image, nan=0.0)

            mask = load_mask(mask_path)
            indices = calculate_spectral_indices(image)

            features = extract_features(image, indices)
            features = np.nan_to_num(features, nan=0.0)
            labels = mask.ravel()

            # The feature matrix does not change between batches, so validate
            # it once per image instead of rescanning it for every batch.
            assert not np.isinf(features).any(), "Infinite values detected in features"

            if not np.isin(labels, (0, 1)).any():
                # Nothing to sample from; an empty batch would only break the
                # consumer further down.
                print(f"Warning: no pixels labelled 0 or 1 in {mask_path}, skipping")
            else:
                n_batches = len(labels) // config.BATCH_SIZE
                for _batch in range(n_batches):
                    batch_idx = self._stratified_sampling(mask, config.BATCH_SIZE)
                    yield features[batch_idx], labels[batch_idx]

            # Release the per-image arrays before loading the next image.
            del image, mask, labels, indices, features
            gc.collect()

    def calculate_class_weights(self):
        """Compute balanced weights ``total / (2 * count)`` over all masks.

        Only pixels labelled 0 or 1 are counted, matching what
        ``_stratified_sampling`` can draw; other label values are ignored
        instead of raising ``KeyError``. A class with no pixels at all gets
        the neutral weight 1.0 (it cannot influence training anyway) rather
        than an infinite weight or a division error.

        The result is stored in ``self.class_weights`` as plain Python floats.
        """
        class_counts = {0: 0, 1: 0}
        for _, mask_path, _ in self.image_pairs:
            mask = load_mask(mask_path)
            for cls in class_counts:
                class_counts[cls] += int(np.count_nonzero(mask == cls))

        total = sum(class_counts.values())
        weights = {}
        for cls, count in class_counts.items():
            if count == 0:
                print(f"Warning: no pixels of class {cls} found, using weight 1.0")
                weights[cls] = 1.0
            else:
                weights[cls] = total / (2 * count)
        self.class_weights = weights

In [None]:
# ============================================================================
# Metrics Saving
# ============================================================================
def calculate_metrics(y_true, y_pred, y_proba=None):
    """Compute binary-classification metrics for labels in {0, 1}.

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground-truth and predicted labels.
    y_proba : array-like, optional
        Scores for the positive class; when given, ROC AUC is added.

    Returns
    -------
    dict
        Keys 'accuracy', 'precision', 'recall', 'f1' (floats),
        'confusion_matrix' (always a 2x2 nested list, rows/columns ordered
        [0, 1]), 'classification_report' (dict) and, if ``y_proba`` was
        supplied, 'roc_auc'. ROC AUC is undefined when ``y_true`` holds a
        single class; it is reported as NaN in that case instead of letting
        ``roc_auc_score`` raise and abort the run.
    """
    metrics = {
        'accuracy': float(accuracy_score(y_true, y_pred)),
        # zero_division=0 keeps the value sklearn would return anyway (0.0)
        # when a score is undefined, without emitting a warning per call.
        'precision': float(precision_score(y_true, y_pred, zero_division=0)),
        'recall': float(recall_score(y_true, y_pred, zero_division=0)),
        'f1': float(f1_score(y_true, y_pred, zero_division=0)),
        # Fixing the label order keeps the matrix 2x2 even if one class is
        # absent from both y_true and y_pred.
        'confusion_matrix': confusion_matrix(y_true, y_pred, labels=[0, 1]).tolist(),
        'classification_report': classification_report(
            y_true, y_pred, output_dict=True, zero_division=0
        ),
    }

    if y_proba is not None:
        if np.unique(y_true).size < 2:
            print("Warning: only one class present in y_true, ROC AUC is undefined")
            metrics['roc_auc'] = float('nan')
        else:
            metrics['roc_auc'] = float(roc_auc_score(y_true, y_proba))
    return metrics

def save_metrics(metrics, name, output_dir):
    """Write ``metrics`` as JSON to ``<output_dir>/logs/<name>_metrics.json``.

    The ``logs`` folder is created on demand (the original code failed with
    ``FileNotFoundError`` on a fresh output directory). NumPy scalars and
    arrays inside ``metrics`` are converted to plain Python values so that
    ``json.dump`` does not reject them.

    Parameters
    ----------
    metrics : dict
        JSON-compatible structure, optionally containing NumPy values.
    name : str
        Prefix of the output file name.
    output_dir : str or pathlib.Path
        Directory under which the ``logs`` folder lives.
    """
    def _to_builtin(value):
        # Called by json only for objects it cannot serialise itself.
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    log_dir = Path(output_dir) / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)

    metrics_path = log_dir / f"{name}_metrics.json"
    with open(metrics_path, 'w') as f:
        json.dump(metrics, f, indent=4, default=_to_builtin)
    print(f"Saved {name} metrics to {metrics_path}")

In [None]:
# ============================================================================
# Inference & Prediction
# ============================================================================
def predict_full_image(model, image_path, output_path=None):
    """Predict a cloud mask for one full 4x512x512 image.

    Parameters
    ----------
    model : fitted classifier exposing ``predict_proba``
    image_path : path of the image, passed to ``load_and_normalize_tiff``
    output_path : optional
        When truthy, the image and mask are handed to ``plot_image_and_mask``
        together with this path; its parent folder is created if missing.

    Returns
    -------
    np.ndarray
        512x512 mask: class-1 probability thresholded at 0.5 and then passed
        through ``ndimage.binary_closing`` (which returns a boolean array).

    Raises
    ------
    ValueError
        If the loaded image is not shaped (4, 512, 512).
    """
    # Load and preprocess
    image = load_and_normalize_tiff(image_path)
    if image.shape != (4, 512, 512):
        raise ValueError(f"Invalid image shape: {image.shape}")

    # Same NaN cleaning as the training pipeline, so inference sees features
    # built the same way and predict_proba is never handed NaNs.
    if np.isnan(image).any():
        print(f"Warning: NaNs detected in {image_path}, replacing with 0")
        image = np.nan_to_num(image, nan=0.0)

    # Extract features
    indices = calculate_spectral_indices(image)
    features = extract_features(image, indices)
    features = np.nan_to_num(features, nan=0.0)

    # Predict the probability of class 1 for every pixel
    proba = model.predict_proba(features)[:, 1]
    mask = proba.reshape(512, 512)

    # Post-processing
    mask = (mask > 0.5).astype(np.uint8)  # Thresholding
    mask = ndimage.binary_closing(mask)   # Morphological closing fills small gaps

    # Save output
    if output_path:
        # NOTE(review): assumes output_path names a file that
        # plot_image_and_mask writes - confirm against its implementation.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        plot_image_and_mask(image, mask, output_path)
        print(f"Saved prediction to {output_path}")

    return mask

In [None]:
# ============================================================================
# Training & Validation
# ============================================================================
def train_model(train_gen_func, val_gen_func, class_weights):
    """Train an SGD classifier batch by batch with validation-based early stopping.

    Parameters
    ----------
    train_gen_func, val_gen_func : callable
        Zero-argument callables; each call must return a fresh iterable of
        ``(X, y)`` batches. They are called once per epoch.
    class_weights : dict or None
        Used as the classifier's ``class_weight``, replacing the placeholder
        entry in ``config.SGD_PARAMS``.

    Returns
    -------
    (model, all_metrics)
        ``model`` is the classifier as it stands after the last epoch that
        ran, which is not necessarily the best one; the best-scoring
        snapshot is stored in ``SGD_best_model.joblib``.
        ``all_metrics['val']`` holds one metrics dict per validated epoch;
        ``all_metrics['test']`` is left empty by this function.

    Side effects
    ------------
    Writes ``SGD_best_model.joblib`` (on each improvement) and
    ``SGD_final_model.joblib`` under ``config.MODEL_PATH``, and saves the
    collected metrics through ``save_metrics``.
    """
    model = SGDClassifier(**{**config.SGD_PARAMS, 'class_weight': class_weights})
    best_f1 = 0
    all_metrics = {'val': [], 'test': {}}
    no_improvement_count = 0  # Consecutive epochs without sufficient F1 gain

    for epoch in range(config.N_EPOCHS):
        print(f"\nEpoch {epoch + 1}/{config.N_EPOCHS}")

        # Training phase
        batch_count = 0
        for X_batch, y_batch in tqdm(train_gen_func(), desc="Training", unit="batch"):
            model.partial_fit(X_batch, y_batch, classes=[0, 1])
            batch_count += 1

        print(f"Trained on {batch_count} batches")

        # Validation phase. Results are kept as one array per batch and joined
        # once; extending Python lists element by element would create one
        # Python object per validated pixel.
        val_true, val_preds, val_proba = [], [], []
        for X_val, y_val in val_gen_func():
            val_true.append(np.asarray(y_val))
            val_preds.append(model.predict(X_val))
            val_proba.append(model.predict_proba(X_val)[:, 1])

        if not val_true:
            # No validation data this epoch: nothing to score or compare.
            continue

        val_metrics = calculate_metrics(
            np.concatenate(val_true),
            np.concatenate(val_preds),
            np.concatenate(val_proba),
        )
        all_metrics['val'].append(val_metrics)

        current_f1 = val_metrics['f1']
        print(f"Val F1: {current_f1:.4f} | Best: {best_f1:.4f}")

        # Early stopping check
        if current_f1 > best_f1 + config.EARLY_STOP_DELTA:
            best_f1 = current_f1
            no_improvement_count = 0
            joblib.dump(model, config.MODEL_PATH / "SGD_best_model.joblib")
            print("↑ New best model saved ↑")
        else:
            no_improvement_count += 1
            print(f"No improvement ({no_improvement_count}/{config.EARLY_STOP_PATIENCE})")

            if no_improvement_count >= config.EARLY_STOP_PATIENCE:
                print(f"Early stopping triggered at epoch {epoch+1}!")
                break  # Exit epoch loop

    # Save final model and metrics (even if stopped early)
    joblib.dump(model, config.MODEL_PATH / "SGD_final_model.joblib")
    save_metrics(all_metrics, 'training', config.OUTPUT_PATH)

    return model, all_metrics

In [None]:
# ============================================================================
# Main Execution
# ============================================================================
if __name__ == "__main__":
    from itertools import islice

    # Seed NumPy's global generator (used by the pixel sampler) and the
    # dataset splits so that a fresh Restart & Run All picks the same data.
    RANDOM_SEED = 42
    MAX_IMAGES_PER_CATEGORY = 300
    np.random.seed(RANDOM_SEED)

    # Folders written to further down; create them up front so a fresh
    # checkout does not fail after training has already finished.
    for sub_dir in ("logs", "predictions"):
        (config.OUTPUT_PATH / sub_dir).mkdir(parents=True, exist_ok=True)

    # 1. Prepare datasets: pair every image with the same-named mask and keep
    #    at most MAX_IMAGES_PER_CATEGORY pairs per category.
    image_pairs = []
    for category in ['cloud_free', 'partially_clouded', 'fully_clouded']:
        img_dir = config.PROCESSED_DATA / "data" / category
        mask_dir = config.PROCESSED_DATA / "masks" / category

        matched = (
            (img_file, mask_dir / img_file.name, category)
            # glob() order depends on the filesystem; sort so that the cap
            # below always selects the same files.
            for img_file in sorted(img_dir.glob('*.tif'))
            if (mask_dir / img_file.name).exists()
        )

        image_pairs.extend(islice(matched, MAX_IMAGES_PER_CATEGORY))

    if not image_pairs:
        raise FileNotFoundError(
            f"No image/mask pairs found under {config.PROCESSED_DATA}"
        )

    # 2. Split datasets: first carve off validation+test, then divide that part.
    holdout_fraction = config.VAL_SPLIT + config.TEST_SPLIT
    train_pairs, val_test_pairs = train_test_split(
        image_pairs, test_size=holdout_fraction, random_state=RANDOM_SEED
    )
    val_pairs, test_pairs = train_test_split(
        val_test_pairs,
        test_size=config.TEST_SPLIT / holdout_fraction,
        random_state=RANDOM_SEED,
    )

    # 3. Create datasets
    train_ds = CloudDataset(train_pairs)
    val_ds = CloudDataset(val_pairs)
    test_ds = CloudDataset(test_pairs)

    # 4. Calculate class weights (training images only)
    train_ds.calculate_class_weights()

    # 5. Train model
    print("Starting training...")
    model, metrics = train_model(train_ds.batch_generator,
                                 val_ds.batch_generator,
                                 train_ds.class_weights)

    # 6. Final test evaluation
    print("\nFinal Test Evaluation:")
    y_true, y_pred, y_proba = [], [], []
    for X_test, y_test in test_ds.batch_generator():
        y_true.extend(y_test)
        y_pred.extend(model.predict(X_test))
        y_proba.extend(model.predict_proba(X_test)[:, 1])

    test_metrics = calculate_metrics(y_true, y_pred, y_proba)
    save_metrics(test_metrics, 'test', config.OUTPUT_PATH)
    print(classification_report(y_true, y_pred))
    print(f"ROC AUC: {test_metrics['roc_auc']:.4f}")

    # 7. Example inference
    # NOTE(review): this path is relative to the working directory, unlike
    # the "../data/processed" tree used for training - confirm the location.
    example_image = "test/data/101885.tif"
    output_mask = config.OUTPUT_PATH / "predictions" / "predicted_mask_101885.png"
    _ = predict_full_image(model, example_image, output_mask)

In [None]:
# Run inference on further example tiles. One loop replaces five copy-pasted
# cells that differed only in the tile id.
# NOTE(review): these masks are written directly into OUTPUT_PATH, whereas the
# example in the main cell writes into OUTPUT_PATH / "predictions" - confirm
# which location is intended.
EXAMPLE_IMAGE_IDS = ["289430", "327408", "130726", "140804", "262934"]

for image_id in EXAMPLE_IMAGE_IDS:
    example_image = f"test/data/{image_id}.tif"
    output_mask = config.OUTPUT_PATH / f"predicted_mask_{image_id}.png"
    _ = predict_full_image(model, example_image, output_mask)