## Use the trained model to create probabilities and to predict

These zones were not used in training. Python 3.11.9 was used here.

- The code computes the class probabilities.
- The hard-class predictions are unusable: the predicted classes all come out as "1", because the probabilities of streams and ditches are so close to each other. That is why the probabilities have to be kept as float32 or float64.
- The results are exported to GeoTIFFs.

It is faster to create the code with AI.

## Extracting probabilities

The post-processing did not work, and I did not have time to finish it, so I did the post-processing in QGIS instead.

In [3]:
import numpy as np
import pandas as pd
import zarr
import joblib
import os
import gc
import psutil
import time
import matplotlib.pyplot as plt
import multiprocessing
from sklearn.metrics import (
    accuracy_score, recall_score, precision_score, f1_score, 
    confusion_matrix, cohen_kappa_score
)
import logging
from functools import partial

# Configure root logging at INFO with timestamped records, then take a
# module-level logger for this cell.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Memory management constants
MAX_MEMORY_GB = 250  # Maximum available RAM in GB
MEMORY_THRESHOLD = 0.85  # Fraction of MAX_MEMORY_GB considered safe to use
# Derived from MEMORY_THRESHOLD (instead of repeating 0.85) so the two cannot drift apart.
MAX_SAFE_MEMORY_BYTES = MAX_MEMORY_GB * MEMORY_THRESHOLD * 1024 ** 3  # Safe memory limit in bytes
BATCH_SIZE = 1000000  # Upper bound on the number of rows handled per processing batch

def get_memory_usage():
    """Return the resident set size (RSS) of the current process, in bytes."""
    return psutil.Process(os.getpid()).memory_info().rss

def log_memory_usage(message="Current memory usage"):
    """Log this process's resident memory (in GB and bytes) after *message*.

    Returns the measured resident set size in bytes so callers can reuse it.
    """
    rss_bytes = get_memory_usage()
    logger.info(f"{message}: {rss_bytes / 1024 ** 3:.2f} GB ({rss_bytes} bytes)")
    return rss_bytes

def estimate_required_memory(df_size, n_models, n_classes=3, n_features=10, bytes_per_value=8):
    """Estimate the bytes needed to predict on ``df_size`` rows.

    The estimate adds up four rough components and then a 20% safety buffer:

    1. DataFrame overhead: ~100 bytes per row.
    2. Feature matrix: ``n_features`` values per row.
    3. Per-model probability arrays: ``n_classes`` values per row per model.
    4. Averaged probabilities: ``n_classes`` values per row.

    Parameters
    ----------
    df_size : int
        Number of rows (pixels).
    n_models : int
        Number of models whose predictions are held in memory.
    n_classes : int, default 3
        Number of probability columns per prediction.
    n_features : int, default 10
        Approximate number of feature columns.
    bytes_per_value : int, default 8
        Size of one numeric value (8 for float64).

    Returns
    -------
    float
        Estimated memory requirement in bytes.
    """
    df_mem = df_size * 100  # Approximate DataFrame overhead per row
    features_mem = df_size * n_features * bytes_per_value
    pred_mem = df_size * n_classes * n_models * bytes_per_value
    avg_mem = df_size * n_classes * bytes_per_value

    # Add 20% buffer
    return (df_mem + features_mem + pred_mem + avg_mem) * 1.2

def generate_spatial_indices(spatial_shape, batch_size=None):
    """Return row-major (row, col) index arrays covering a whole 2-D grid.

    Parameters
    ----------
    spatial_shape : sequence of two ints
        ``(height, width)`` of the grid.
    batch_size : int, optional
        Number of pixels converted per step; it bounds the size of the
        temporary ``arange`` buffer. Defaults to the module-level ``BATCH_SIZE``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(row_indices, col_indices)``, both int32 with length
        ``height * width``; flat pixel ``k`` maps to ``(k // width, k % width)``.
    """
    height, width = int(spatial_shape[0]), int(spatial_shape[1])
    total_pixels = height * width
    row_indices = np.zeros(total_pixels, dtype=np.int32)
    col_indices = np.zeros(total_pixels, dtype=np.int32)
    if total_pixels == 0:
        # Empty grid: nothing to fill (and a zero step would divide by zero).
        return row_indices, col_indices

    step = BATCH_SIZE if batch_size is None else batch_size
    step = max(1, min(step, total_pixels))

    # Process in batches to bound the temporary index buffer.
    for start in range(0, total_pixels, step):
        stop = min(start + step, total_pixels)
        # divmod gives the row (quotient) and column (remainder) in one pass.
        rows, cols = np.divmod(np.arange(start, stop), width)
        row_indices[start:stop] = rows
        col_indices[start:stop] = cols

    return row_indices, col_indices

def load_zone_data(zarr_file, zone, selected_features, max_pixels=None):
    """
    Load one zone's pixels and selected features into a DataFrame.

    Parameters
    ----------
    zarr_file : str
        Path to the Zarr store; ``zone`` must be a group inside it.
    zone : str
        Zone group name such as ``"zone_1"``. The numeric suffix minus one
        is stored in the ``zone_id`` column.
    selected_features : list of str
        Arrays to load from the zone group. Metadata names (row/col
        indices, label, zone id, spatial shape) are skipped because they are
        filled in separately.
    max_pixels : int or None
        If given (and non-zero) and the zone holds more pixels than this, a
        random subset of ``max_pixels`` pixels is loaded instead of all.

    Returns
    -------
    pandas.DataFrame or None
        One row per pixel. None if the zone has no ``label_3m`` array or
        loading fails (the error is logged with its traceback).
    """
    try:
        log_memory_usage(f"Memory before loading {zone}")
        root = zarr.open(zarr_file, mode='r')
        zone_data = root[zone]

        if 'label_3m' not in zone_data.keys():
            logger.warning(f"Zone {zone} does not have label_3m, skipping")
            return None

        # Read stored pixel indices once; they feed the shape fallback,
        # the sampling and the output columns.
        has_indices = 'row_idx' in zone_data and 'col_idx' in zone_data
        if has_indices:
            row_idx = zone_data['row_idx'][:]
            col_idx = zone_data['col_idx'][:]

        # Get spatial dimensions
        if 'spatial_shape' in zone_data.attrs:
            spatial_shape = zone_data.attrs['spatial_shape']
        elif has_indices:
            spatial_shape = (np.max(row_idx) + 1, np.max(col_idx) + 1)
        else:
            spatial_shape = (5000, 5000)  # Default

        logger.info(f"Spatial shape for {zone}: {spatial_shape}")
        total_pixels = spatial_shape[0] * spatial_shape[1]
        # Sample against the pixels that actually exist: the stored index
        # count can be smaller than the full grid.
        n_available = len(row_idx) if has_indices else total_pixels

        sampled = bool(max_pixels) and n_available > max_pixels
        if sampled:
            logger.info(f"Limiting {zone} to {max_pixels} pixels (from {n_available})")
            sample_indices = np.random.choice(n_available, max_pixels, replace=False)
            if has_indices:
                row_idx = row_idx[sample_indices]
                col_idx = col_idx[sample_indices]
            else:
                # Vectorised flat index -> (row, col) for a row-major grid.
                row_idx, col_idx = np.divmod(sample_indices, spatial_shape[1])
            labels = zone_data['label_3m'][:][sample_indices]
            effective_size = max_pixels
        else:
            # Use all data
            if not has_indices:
                row_idx, col_idx = generate_spatial_indices(spatial_shape)
            labels = zone_data['label_3m'][:]
            effective_size = len(labels)

        # Build the columns in a dict first; the DataFrame is created once at the end.
        zone_dict = {}
        try:
            zone_id = int(zone.split('_')[1]) - 1
            zone_dict['zone_id'] = np.full(effective_size, zone_id, dtype=np.int32)
        except (IndexError, ValueError):
            logger.warning(f"Could not extract zone ID from {zone}")
            zone_dict['zone_id'] = np.zeros(effective_size, dtype=np.int32)

        zone_dict['row_idx'] = row_idx
        zone_dict['col_idx'] = col_idx
        zone_dict['label_3m'] = labels

        # Add metadata
        zone_dict['spatial_shape_x'] = np.full(effective_size, spatial_shape[0], dtype=np.int32)
        zone_dict['spatial_shape_y'] = np.full(effective_size, spatial_shape[1], dtype=np.int32)

        metadata_columns = {'row_idx', 'col_idx', 'label_3m', 'zone_id', 'spatial_shape_x', 'spatial_shape_y'}

        # Load each feature one by one, cleaning up after each to limit peak memory
        for feature in selected_features:
            if feature in metadata_columns:
                continue

            if feature in zone_data.keys():
                if sampled:
                    # Sample the feature data using the same indices as the labels
                    zone_dict[feature] = zone_data[feature][:][sample_indices]
                else:
                    zone_dict[feature] = zone_data[feature][:]
            else:
                logger.warning(f"Feature {feature} not found in {zone}")
                zone_dict[feature] = np.zeros_like(labels)

            gc.collect()
            log_memory_usage(f"After loading feature {feature}")

        df = pd.DataFrame(zone_dict)
        log_memory_usage(f"After creating DataFrame for {zone}")
        logger.info(f"Loaded data from {zone}, shape: {df.shape}")

        return df

    except Exception as e:
        logger.error(f"Error loading data from {zone}: {e}", exc_info=True)
        return None
    finally:
        # Force garbage collection
        gc.collect()

def batch_predict(model, X_batch):
    """Return one model's class probabilities for *X_batch*.

    A thin wrapper around ``model.predict_proba``, kept as a single seam
    for timing and swapping the prediction call.
    """
    proba = model.predict_proba(X_batch)
    return proba

def _class_slots(classes, n_cols):
    """Return the class label (0..2) of each probability column, or None if unusable."""
    if classes is None or n_cols == 0 or len(classes) != n_cols:
        return None
    slots = []
    for label in classes:
        if not isinstance(label, (int, np.integer)) or not 0 <= label <= 2:
            return None
        slots.append(int(label))
    return slots if len(set(slots)) == n_cols else None


def _align_to_three_classes(prob, classes, model_number):
    """Return ``prob`` as an (n_samples, 3) array whose column k holds class k.

    If ``classes`` (the model's ``classes_``) gives one distinct integer label in
    0..2 per column, columns are placed by label. This matters when a model never
    saw some class, e.g. ``classes_ == [0, 2]``. Otherwise the positional fallback
    is used: 3 columns as-is, 2 columns padded with a zero last column, 1 column
    (or a 1-D array) placed in the last column.
    """
    if prob.ndim not in (1, 2):
        raise ValueError(f"Model {model_number} returned an unexpected shape: {prob.shape}")
    prob_2d = prob.reshape(-1, 1) if prob.ndim == 1 else prob
    n_rows, n_cols = prob_2d.shape

    slots = _class_slots(classes, n_cols)
    if slots is not None:
        if slots == [0, 1, 2]:
            return prob_2d
        aligned = np.zeros((n_rows, 3), dtype=np.result_type(prob_2d.dtype, np.float64))
        aligned[:, slots] = prob_2d
        return aligned

    if n_cols == 3:
        return prob_2d
    if n_cols == 2:
        return np.hstack([prob_2d, np.zeros((n_rows, 1))])
    if n_cols == 1:
        return np.hstack([np.zeros((n_rows, 2)), prob_2d])
    raise ValueError(f"Unexpected probability shape from model {model_number}: {prob.shape}")


def process_predictions_batch(batch_idx, X_batch, rf_models):
    """Average the class probabilities of all models for one batch of samples.

    Parameters
    ----------
    batch_idx : int
        Zero-based sub-batch index (used only for logging).
    X_batch : array-like of shape (n_samples, n_features)
        Feature matrix passed to each model's ``predict_proba``.
    rf_models : sequence
        Fitted models exposing ``predict_proba``. If a model has ``classes_``,
        it is used to place its probability columns.

    Returns
    -------
    numpy.ndarray of shape (n_samples, 3)
        Mean probability per class over all models.

    Raises
    ------
    ValueError
        If ``rf_models`` is empty or a model returns an unsupported shape.
    """
    logger.info(f"Processing sub-batch {batch_idx+1} with {X_batch.shape[0]} samples")
    if not rf_models:
        raise ValueError("No models supplied for prediction")

    # Models run one after another; a running sum keeps only one extra
    # (n_samples, 3) array alive instead of one per model.
    prob_sum = None
    for i, model in enumerate(rf_models):
        start_time = time.time()
        prob = batch_predict(model, X_batch)
        elapsed = time.time() - start_time

        fixed_prob = _align_to_three_classes(prob, getattr(model, "classes_", None), i + 1)
        logger.info(f"Model {i+1} prediction took {elapsed:.2f}s, shape: {fixed_prob.shape}")

        # Non in-place add, so a model's own output array is never modified.
        prob_sum = fixed_prob if prob_sum is None else prob_sum + fixed_prob

    return prob_sum / len(rf_models)

def _estimate_batch_pixels(root, zones):
    """Sum a per-zone pixel-count estimate over the ``zones`` present in the open Zarr ``root``."""
    total_estimated_pixels = 0
    for zone in zones:
        if zone not in root:
            continue
        zone_data = root[zone]
        if 'spatial_shape' in zone_data.attrs:
            shape = zone_data.attrs['spatial_shape']
            pixels = shape[0] * shape[1]
        elif 'row_idx' in zone_data:
            pixels = len(zone_data['row_idx'])
        else:
            pixels = 5000 * 5000  # Default assumption
        total_estimated_pixels += pixels
        logger.info(f"Zone {zone} estimated pixels: {pixels}")
    return total_estimated_pixels


def _scatter_to_grid(probs, row_idx, col_idx, spatial_shape):
    """Place per-pixel probability rows at their (row, col) on a zero-filled float32 grid.

    Returns an array of shape ``(height, width, probs.shape[1])``, or None when
    any index lies outside ``spatial_shape``. For a complete grid stored in
    row-major order, the result equals ``probs.reshape(height, width, -1)``.
    """
    height, width = int(spatial_shape[0]), int(spatial_shape[1])
    row_idx = np.asarray(row_idx).astype(np.int64, copy=False)
    col_idx = np.asarray(col_idx).astype(np.int64, copy=False)
    if len(row_idx) and (
        row_idx.min() < 0 or row_idx.max() >= height
        or col_idx.min() < 0 or col_idx.max() >= width
    ):
        return None
    grid = np.zeros((height, width, probs.shape[1]), dtype=np.float32)
    grid[row_idx, col_idx] = probs
    return grid


def predict_for_batch(batch_idx, zones, rf_models, zarr_file, selected_features, results_dir):
    """Predict stream/ditch probabilities for a batch of zones and save them to Zarr.

    Output goes to ``results_dir/probabilities_zones_batch_{batch_idx+1}.zarr`` as a
    float32 ``probabilities`` array with two channels (model classes 1 and 2).

    - Unsampled single-zone batch: each pixel is placed at its own (row, col) on a
      ``(height, width, 2)`` grid; cells with no prediction stay 0.
    - Pixels sampled to fit in memory, several zones in one batch, or
      out-of-range indices: the array is saved flat as ``(n_pixels, 2)``.

    Returns
    -------
    bool
        True on success, False otherwise. Errors are logged, not raised.
    """
    logger.info(f"Generating probabilities for batch {batch_idx+1}")
    log_memory_usage("Starting batch prediction")

    try:
        if not rf_models:
            logger.error(f"No models available for batch {batch_idx+1}")
            return False

        # Estimate zone sizes and decide whether sampling is needed
        root = zarr.open(zarr_file, mode='r')
        total_estimated_pixels = _estimate_batch_pixels(root, zones)

        estimated_memory = estimate_required_memory(total_estimated_pixels, len(rf_models))
        logger.info(f"Estimated memory requirement: {estimated_memory / (1024**3):.2f} GB")

        memory_per_pixel = estimate_required_memory(1, len(rf_models))
        available_mem = MAX_SAFE_MEMORY_BYTES - get_memory_usage()

        max_pixels_per_zone = None
        if estimated_memory > available_mem:
            safe_pixel_count = int(available_mem / (memory_per_pixel * max(1, len(zones))))
            if safe_pixel_count < 1:
                # The budget is already exhausted; a zero or negative sample size is unusable.
                logger.error(f"Not enough memory available to process batch {batch_idx+1}")
                return False
            max_pixels_per_zone = safe_pixel_count
            logger.warning(f"Memory limits exceeded. Limiting to {max_pixels_per_zone} pixels per zone")

        # Load data with potential sampling
        zone_data_list = []
        for zone in zones:
            df = load_zone_data(zarr_file, zone, selected_features, max_pixels_per_zone)
            if df is not None:
                zone_data_list.append(df)

            # Check memory after each zone
            if get_memory_usage() > MAX_SAFE_MEMORY_BYTES * 0.9:
                logger.warning("Memory usage approaching limit. Forcing garbage collection.")
                gc.collect()

        if not zone_data_list:
            logger.warning(f"No valid data loaded for batch {batch_idx+1}")
            return False

        n_loaded_zones = len(zone_data_list)
        logger.info(f"Combining {n_loaded_zones} zone dataframes")
        combined_df = pd.concat(zone_data_list, ignore_index=True)
        del zone_data_list
        gc.collect()

        n_samples = combined_df.shape[0]
        if n_samples == 0:
            logger.warning(f"No pixels to predict for batch {batch_idx+1}")
            return False

        # Spatial shape is the same on every row of a zone; take it from the first row
        spatial_shape = (
            int(combined_df['spatial_shape_x'].iloc[0]),
            int(combined_df['spatial_shape_y'].iloc[0]),
        )
        logger.info(f"Using spatial shape for prediction: {spatial_shape}")

        # Model input columns: everything selected except the shape metadata
        feature_cols = [f for f in selected_features if f not in ['spatial_shape_x', 'spatial_shape_y']]

        # Use half of the remaining budget for headroom, with at least one sample per batch.
        available_memory = MAX_SAFE_MEMORY_BYTES - get_memory_usage()
        optimal_batch_size = max(1, min(int(available_memory / memory_per_pixel / 2), n_samples, BATCH_SIZE))
        logger.info(f"Processing {n_samples} samples in batches of {optimal_batch_size}")

        n_batches = (n_samples + optimal_batch_size - 1) // optimal_batch_size
        averaged_results = []

        for i in range(n_batches):
            start_idx = i * optimal_batch_size
            end_idx = min((i + 1) * optimal_batch_size, n_samples)

            logger.info(f"Processing batch {i+1}/{n_batches}, samples {start_idx}-{end_idx}")

            X_batch = combined_df.iloc[start_idx:end_idx][feature_cols].values
            log_memory_usage(f"After extracting batch {i+1}")

            batch_result = process_predictions_batch(i, X_batch, rf_models)
            log_memory_usage(f"After processing batch {i+1}")

            # Extract class 1 and 2 probabilities (streams and ditches)
            averaged_results.append(batch_result[:, 1:3])

            del X_batch, batch_result
            gc.collect()

        logger.info("Combining batch results")
        y_prob_final = np.vstack(averaged_results)
        log_memory_usage("After combining all batches")

        # Keep pixel positions so predictions can be placed on the grid by index,
        # without assuming the rows are in row-major order.
        row_idx = combined_df['row_idx'].to_numpy()
        col_idx = combined_df['col_idx'].to_numpy()
        del combined_df
        gc.collect()

        if max_pixels_per_zone is not None:
            # Sampled pixels cannot fill a grid - save a flat array
            logger.info("Data was sampled - saving flat probability array")
            final_probabilities = y_prob_final.astype(np.float32)
        elif n_loaded_zones > 1:
            # Row/col indices of different zones overlap, so they cannot share one grid.
            logger.warning("Multiple zones in batch - saving flat probability array")
            final_probabilities = y_prob_final.astype(np.float32)
        else:
            expected_size = spatial_shape[0] * spatial_shape[1] * 2  # height * width * 2 classes
            actual_size = y_prob_final.size
            if actual_size != expected_size:
                logger.warning(f"Size mismatch. Expected: {expected_size}, Actual: {actual_size}")

            final_probabilities = _scatter_to_grid(y_prob_final, row_idx, col_idx, spatial_shape)
            if final_probabilities is None:
                logger.error("Pixel indices fall outside the spatial shape - saving flat probability array")
                final_probabilities = y_prob_final.astype(np.float32)
        del row_idx, col_idx

        # Save probabilities
        prob_file = os.path.join(results_dir, f"probabilities_zones_batch_{batch_idx+1}.zarr")
        logger.info(f"Saving results to {prob_file}")

        zarr_group = zarr.open(prob_file, mode="w")
        zarr_group.create_dataset(
            "probabilities",
            data=final_probabilities,
            dtype=np.float32,
            chunks=True  # Let zarr determine optimal chunk size
        )
        zarr_group.attrs["zones"] = zones
        zarr_group.attrs["spatial_shape"] = spatial_shape
        zarr_group.attrs["prediction_date"] = str(pd.Timestamp.now())
        zarr_group.attrs["sampled_data"] = max_pixels_per_zone is not None

        if max_pixels_per_zone is not None:
            zarr_group.attrs["max_pixels_per_zone"] = max_pixels_per_zone

        logger.info(f"Saved probabilities for batch {batch_idx+1}")
        log_memory_usage("After saving results")

        del y_prob_final, final_probabilities, averaged_results
        gc.collect()

        return True

    except Exception as e:
        logger.error(f"Error in prediction for batch {batch_idx+1}: {e}", exc_info=True)
        return False
    finally:
        # Force garbage collection
        gc.collect()

def main(results_dir="../02_Results", zarr_file="zones_data_2.zarr", zone_batches=None):
    """Load every ``*.joblib`` model and write probability stores for each zone batch.

    Parameters
    ----------
    results_dir : str
        Output directory. Models are read from ``results_dir/models``.
    zarr_file : str
        Input Zarr store holding the zone groups.
    zone_batches : list of list of str, optional
        Zones processed together. Defaults to one zone per batch for the
        held-out zones, which keeps peak memory bounded.
    """
    log_memory_usage("Starting execution")

    models_dir = os.path.join(results_dir, "models")
    os.makedirs(results_dir, exist_ok=True)

    if zone_batches is None:
        zone_batches = [
            ["zone_1"],
            ["zone_3"],
            ["zone_4"],
            ["zone_6"],
            ["zone_7"],
            ["zone_8"],
            ["zone_10"]
        ]

    # Define features
    selected_features = [
        'col_idx', 'row_idx', 'impoundment_amplified', 'zone_id', 'skyview_gabor',
        'impoundment_raw', 'conic_mean', 'hpmf_raw', 'skyview_raw', 'hpmf_f', 'slope_channels',
        'spatial_shape_x', 'spatial_shape_y'  # Store dimensions as separate columns
    ]

    logger.info("Loading trained models...")
    # os.listdir order is arbitrary; sort so the model order (and therefore the
    # floating-point averaging order) is reproducible between runs.
    model_paths = sorted(
        os.path.join(models_dir, f) for f in os.listdir(models_dir) if f.endswith(".joblib")
    )

    rf_models = []
    for model_path in model_paths:
        try:
            # NOTE: joblib.load unpickles the file - only load model files you trust.
            rf_models.append(joblib.load(model_path))
            logger.info(f"Loaded model from {model_path}")
        except Exception as e:
            logger.error(f"Failed to load model {model_path}: {e}")

    logger.info(f"Loaded {len(rf_models)} models")
    log_memory_usage("After loading models")

    if not rf_models:
        logger.error("No models could be loaded; nothing to predict.")
        return

    logger.info("Generating probabilities for all batches...")
    prediction_results = []

    for i, zones in enumerate(zone_batches):
        try:
            logger.info(f"Processing batch {i+1}/{len(zone_batches)}: {zones}")
            prediction_success = predict_for_batch(i, zones, rf_models, zarr_file, selected_features, results_dir)
            prediction_results.append(prediction_success)

            # Force garbage collection between batches
            gc.collect()
            log_memory_usage(f"After processing batch {i+1}")
        except Exception as e:
            logger.error(f"Error processing batch {i+1}: {e}", exc_info=True)
            prediction_results.append(False)

    # Final status report
    successful_batches = sum(prediction_results)
    logger.info(f"Successfully processed {successful_batches}/{len(zone_batches)} batches")

    if not any(prediction_results):
        logger.error("Failed to generate probabilities for any batch.")
    else:
        logger.info("Processing complete. Results saved.")


if __name__ == "__main__":
    main()

2025-04-02 10:26:42,947 - __main__ - INFO - Starting execution: 1.20 GB (1286557696 bytes)
2025-04-02 10:26:42,948 - __main__ - INFO - Loading trained models...


2025-04-02 10:26:43,184 - __main__ - INFO - Loaded model from ../02_Results/models/rf_model_1.joblib
2025-04-02 10:26:43,195 - __main__ - INFO - Loaded model from ../02_Results/models/rf_model_8.joblib
2025-04-02 10:26:43,206 - __main__ - INFO - Loaded model from ../02_Results/models/rf_model_7.joblib
2025-04-02 10:26:43,434 - __main__ - INFO - Loaded model from ../02_Results/models/rf_model_0.joblib
2025-04-02 10:26:43,702 - __main__ - INFO - Loaded model from ../02_Results/models/rf_model_2.joblib
2025-04-02 10:26:43,713 - __main__ - INFO - Loaded model from ../02_Results/models/rf_model_5.joblib
2025-04-02 10:26:43,774 - __main__ - INFO - Loaded model from ../02_Results/models/rf_model_6.joblib
2025-04-02 10:26:43,910 - __main__ - INFO - Loaded model from ../02_Results/models/rf_model_3.joblib
2025-04-02 10:26:43,921 - __main__ - INFO - Loaded model from ../02_Results/models/rf_model_4.joblib
2025-04-02 10:26:43,922 - __main__ - INFO - Loaded 9 models
2025-04-02 10:26:43,922 - __mai

### Probabilities to GeoTIFFs

In [9]:
import os
import logging
import numpy as np
import zarr
import rasterio
from rasterio.transform import from_bounds

# Root logging at INFO with timestamped records (a no-op if logging was
# already configured by an earlier cell), plus this cell's logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Zone extents in EPSG:3067 metres: (x, y) of the upper-left and lower-right corners.
# Every zone is a 2500 m x 2500 m tile. The former 6862159.999999999 values were
# floating-point residue of 6862160 and have been made exact.
zone_boundaries = {
    "zone_1": {"upper_left": (377982, 6854660), "lower_right": (380482, 6852160)},
    "zone_2": {"upper_left": (377982, 6857160), "lower_right": (380482, 6854660)},
    "zone_3": {"upper_left": (380482, 6857160), "lower_right": (382982, 6854660)},
    "zone_4": {"upper_left": (375482, 6859660), "lower_right": (377982, 6857160)},
    "zone_5": {"upper_left": (377982, 6859660), "lower_right": (380482, 6857160)},
    "zone_6": {"upper_left": (380482, 6859660), "lower_right": (382982, 6857160)},
    "zone_7": {"upper_left": (375482, 6862160), "lower_right": (377982, 6859660)},
    "zone_8": {"upper_left": (377982, 6862160), "lower_right": (380482, 6859660)},
    "zone_9": {"upper_left": (380482, 6862160), "lower_right": (382982, 6859660)},
    "zone_10": {"upper_left": (372982, 6864660), "lower_right": (375482, 6862160)},
    "zone_11": {"upper_left": (375482, 6864660), "lower_right": (377982, 6862160)},
    "zone_12": {"upper_left": (377982, 6864660), "lower_right": (380482, 6862160)},
    "zone_13": {"upper_left": (370482, 6867160), "lower_right": (372982, 6864660)},
    "zone_14": {"upper_left": (372982, 6867160), "lower_right": (375482, 6864660)},
    "zone_15": {"upper_left": (375482, 6867160), "lower_right": (377982, 6864660)},
    "zone_16": {"upper_left": (377982, 6867160), "lower_right": (380482, 6864660)},
    "zone_17": {"upper_left": (370482, 6869660), "lower_right": (372982, 6867160)},
    "zone_18": {"upper_left": (372982, 6869660), "lower_right": (375482, 6867160)},
    "zone_19": {"upper_left": (375482, 6869660), "lower_right": (377982, 6867160)},
    "zone_20": {"upper_left": (372982, 6872160), "lower_right": (375482, 6869660)},
    "zone_21": {"upper_left": (375482, 6872160), "lower_right": (377982, 6869660)}
}

def _write_single_band_geotiff(path, data, transform, nodata):
    """Write a 2-D array as a one-band EPSG:3067 GeoTIFF at ``path``."""
    meta = {
        'driver': 'GTiff',
        'height': data.shape[0],
        'width': data.shape[1],
        'count': 1,
        'dtype': str(data.dtype),
        'crs': 'EPSG:3067',  # Finnish ETRS-TM35FIN coordinate system
        'transform': transform,
        'nodata': nodata
    }
    with rasterio.open(path, 'w', **meta) as dst:
        dst.write(data, 1)


def export_zarr_to_geotiff_by_batch(zarr_file, output_path, batch_number, zone_batches=None):
    """
    Export the probabilities of one zone batch from a Zarr store to GeoTIFFs.

    Writes, per zone:

    - ``{zone}_class{k}.tif``: float probability of channel ``k - 1``, one
      file per channel.
    - ``{zone}_classification.tif``: uint8 label ``argmax + 1`` over the
      channels. Channel 0 holds model class 1 (streams) and channel 1 model
      class 2 (ditches), so the value equals the model class label. Value 0
      is never written and serves as nodata.

    Parameters
    -----------
    zarr_file : str
        Path to the zarr file containing the probabilities results
    output_path : str
        Directory to save the GeoTIFF files
    batch_number : int
        1-based batch number to process (1 to len(zone_batches))
    zone_batches : list of list of str, optional
        Zones of each batch. Defaults to the seven single-zone batches used
        by the prediction step. Only the first zone of a batch is exported.

    Returns
    -------
    bool
        True if the files were written, False otherwise (errors are logged).
    """
    if zone_batches is None:
        zone_batches = [
            ["zone_1"],
            ["zone_3"],
            ["zone_4"],
            ["zone_6"],
            ["zone_7"],
            ["zone_8"],
            ["zone_10"]
        ]

    if batch_number < 1 or batch_number > len(zone_batches):
        logger.error(f"Invalid batch number: {batch_number}. Should be between 1 and {len(zone_batches)}")
        return False

    target_zones = zone_batches[batch_number - 1]
    logger.info(f"Processing batch {batch_number} with zones: {target_zones}")
    if not target_zones:
        logger.error(f"Batch {batch_number} contains no zones")
        return False

    os.makedirs(output_path, exist_ok=True)

    # Each batch file contains data for a single zone
    zone_name = target_zones[0]

    try:
        root = zarr.open(zarr_file, mode="r")

        if "probabilities" not in root:
            logger.error("No 'probabilities' array found in Zarr file")
            return False

        probabilities = root["probabilities"][:]
        logger.info(f"Loaded probabilities array with shape {probabilities.shape}")

        # Standard grid size for all zones (5000x5000)
        grid_width, grid_height = 5000, 5000

        if zone_name not in zone_boundaries:
            logger.error(f"No coordinates found for {zone_name}, skipping")
            return False

        boundaries = zone_boundaries[zone_name]
        ul_x, ul_y = boundaries["upper_left"]
        lr_x, lr_y = boundaries["lower_right"]

        # A flat (sampled) array is 2-D and cannot be georeferenced, so the
        # rank is checked as well as the grid size.
        if probabilities.ndim != 3 or probabilities.shape[:2] != (grid_height, grid_width):
            logger.error(f"Unexpected data shape: {probabilities.shape}, expected ({grid_height}, {grid_width}, 2)")
            return False

        transform = from_bounds(ul_x, lr_y, lr_x, ul_y, grid_width, grid_height)

        for channel in range(probabilities.shape[2]):
            class_filename = os.path.join(output_path, f"{zone_name}_class{channel+1}.tif")
            # NOTE(review): nodata=0 makes GIS tools hide pixels whose probability
            # is exactly 0.0, which is also a valid probability - confirm intended.
            _write_single_band_geotiff(class_filename, probabilities[:, :, channel], transform, nodata=0)
            logger.info(f"Saved classification map for {zone_name} (channel {channel+1}) to {class_filename}")

        # Highest-probability class. argmax alone would give 0 for channel 0,
        # which collides with nodata=0 and would hide every stream-dominant
        # pixel; +1 turns it into the class label (1 or 2).
        combined_data = (np.argmax(probabilities, axis=2) + 1).astype(np.uint8)
        combined_filename = os.path.join(output_path, f"{zone_name}_classification.tif")
        _write_single_band_geotiff(combined_filename, combined_data, transform, nodata=0)
        logger.info(f"Saved combined classification map for {zone_name} to {combined_filename}")

        return True

    except Exception as e:
        logger.exception(f"Error exporting from Zarr to GeoTIFF: {e}")
        return False

if __name__ == "__main__":
    # Probability stores written by the prediction step: one per zone batch,
    # named probabilities_zones_batch_{n}.zarr.
    results_root = "../02_Results"

    # Path for output GeoTIFFs
    output_path = "../02_Results/geotiff_results"

    total_exports = 7

    logger.info("Starting export from Zarr files")

    batch_results = []
    for batch_number in range(1, total_exports + 1):
        logger.info(f"Processing batch {batch_number}")
        zarr_file = f"{results_root}/probabilities_zones_batch_{batch_number}.zarr"
        batch_results.append(export_zarr_to_geotiff_by_batch(zarr_file, output_path, batch_number))

    successful_exports = sum(1 for ok in batch_results if ok)

    # Report overall results
    if successful_exports == total_exports:
        logger.info("All exports completed successfully")
    elif successful_exports > 0:
        logger.warning(f"Partial success: {successful_exports}/{total_exports} batches exported successfully")
    else:
        logger.error("All exports failed, check the logs for details")

2025-04-02 11:10:37,358 - __main__ - INFO - Starting export from Zarr files
2025-04-02 11:10:37,359 - __main__ - INFO - Processing batch 1
2025-04-02 11:10:37,359 - __main__ - INFO - Processing batch 1 with zones: ['zone_1']
2025-04-02 11:10:39,537 - __main__ - INFO - Loaded probabilities array with shape (5000, 5000, 2)
2025-04-02 11:10:40,288 - __main__ - INFO - Saved classification map for zone_1 (channel 1) to ../02_Results/geotiff_results/zone_1_class1.tif
2025-04-02 11:10:40,411 - __main__ - INFO - Saved classification map for zone_1 (channel 2) to ../02_Results/geotiff_results/zone_1_class2.tif
2025-04-02 11:10:41,002 - __main__ - INFO - Saved combined classification map for zone_1 to ../02_Results/geotiff_results/zone_1_classification.tif
2025-04-02 11:10:41,023 - __main__ - INFO - Processing batch 2
2025-04-02 11:10:41,024 - __main__ - INFO - Processing batch 2 with zones: ['zone_3']
2025-04-02 11:10:41,745 - __main__ - INFO - Loaded probabilities array with shape (5000, 5000,