## GEOTifs from predictions

Taking these to QGIS, resolution issues need to be correct and defined.

In [1]:
import os
import logging
import zarr
import rasterio
from rasterio.transform import from_bounds
import numpy as np
import logging
import importlib.util
import sys
import os
import importlib

In [12]:
results = "../02_Results"
spatial_coord = "spatial_coord.py"
zarr_file_path = os.path.join(results, "prediction_results_smote.zarr")

export, probability

In [18]:
def export_zarr_predictions_to_geotiff(zarr_file, zone_id, output_path, zone_name=None):
    """
    Export prediction results from a Zarr file to GeoTIFF format with proper georeferencing.
    
    Parameters:
    -----------
    zarr_file : str
        Path to the zarr file containing the prediction results
    zone_id : int
        The ID of the zone (index in zone_boundaries dictionary)
    output_path : str
        Directory to save the GeoTIFF files
    zone_name : str, optional
        The zone name to use in output files (e.g., "zone_2"). If None, will use "zone_{zone_id+1}"
    """

    logger = logging.getLogger(__name__)
    
    # Import zone_boundaries from spatial_coord.py
    try:
        # Try to import using importlib
        spec = importlib.util.spec_from_file_location("spatial_coord", "spatial_coord.py")
        if spec is None:
            # If not found in current directory, try to find it in parent directories
            import glob
            spatial_coord_files = glob.glob("**/spatial_coord.py", recursive=True)
            if spatial_coord_files:
                spec = importlib.util.spec_from_file_location("spatial_coord", spatial_coord_files[0])
            else:
                raise FileNotFoundError("Could not find spatial_coord.py")
                
        spatial_coord = importlib.util.module_from_spec(spec)
        sys.modules["spatial_coord"] = spatial_coord
        spec.loader.exec_module(spatial_coord)
        
        # Get zone_boundaries from the module
        zone_boundaries = spatial_coord.zone_boundaries
        logger.info(f"Successfully imported zone_boundaries from spatial_coord.py")
    except Exception as e:
        logger.error(f"Error importing zone_boundaries: {e}")
        logger.error("Using a default resolution for GeoTIFF export")
        # Use a fallback approach with a default resolution
        zone_boundaries = None
    
    # Ensure output directory exists
    os.makedirs(output_path, exist_ok=True)
    
    # Determine which zone name to use in output files
    if zone_name is None:
        output_zone_name = f"zone_{zone_id+1}"
    else:
        output_zone_name = zone_name
    
    # Get zone boundaries from the dictionary or use defaults
    if zone_boundaries is not None and zone_id in zone_boundaries:
        # Get boundaries for this zone_id
        ul_x, ul_y = zone_boundaries[zone_id]["upper_left"]
        lr_x, lr_y = zone_boundaries[zone_id]["lower_right"]
        logger.info(f"Using zone boundaries from spatial_coord.py for zone_id {zone_id} / {output_zone_name}")
    else:
        logger.warning(f"Zone ID {zone_id} not found in boundaries dictionary, using default coordinates")
        # Default to a square 2500x2500 meter area around a point in Finland if boundaries are not available
        # These are dummy values and should be replaced with actual coordinates
        ul_x, ul_y = 375000.0, 6865000.0
        lr_x, lr_y = ul_x + 2500, ul_y - 2500
    
    try:
        # Open Zarr file
        root = zarr.open(zarr_file, mode="r")
        
        # Check for predictions group
        if "predictions" not in root:
            logger.error("No 'predictions' group found in Zarr file")
            return False
            
        predictions_group = root["predictions"]
        
        # Check for spatial group
        if "spatial" not in predictions_group:
            logger.error("No 'spatial' group found in predictions group")
            return False
            
        spatial_group = predictions_group["spatial"]
        
        # Look for zone-specific group within spatial
        zarr_zone_name = output_zone_name
        
        # Check if zone exists directly in spatial group
        if zarr_zone_name in spatial_group:
            zone_group = spatial_group[zarr_zone_name]
            
            # Get stream and ditch probability data
            if "prob_streams" in zone_group:
                stream_data = zone_group["prob_streams"][:]
                logger.info(f"Found stream data for {zarr_zone_name} with shape {stream_data.shape}")
            else:
                logger.error(f"No 'prob_streams' array found for {zarr_zone_name}")
                return False
                
            if "prob_ditches" in zone_group:
                ditch_data = zone_group["prob_ditches"][:]
                logger.info(f"Found ditch data for {zarr_zone_name} with shape {ditch_data.shape}")
            else:
                logger.error(f"No 'prob_ditches' array found for {zarr_zone_name}")
                return False
                
            # Calculate dimensions and create transform
            grid_height, grid_width = stream_data.shape
            transform = from_bounds(ul_x, lr_y, lr_x, ul_y, grid_width, grid_height)
            
            # Save stream probability map
            stream_filename = f"{output_path}/{output_zone_name}_stream_prob.tif"
            meta = {
                'driver': 'GTiff',
                'height': grid_height,
                'width': grid_width,
                'count': 1,
                'dtype': str(stream_data.dtype),
                'crs': 'EPSG:3067',  # Finnish ETRS-TM35FIN coordinate system
                'transform': transform,
                'nodata': 0
            }
            
            with rasterio.open(stream_filename, 'w', **meta) as dst:
                dst.write(stream_data, 1)
            logger.info(f"Saved stream probability map for {output_zone_name} to {stream_filename}")
            
            # Save ditch probability map
            ditch_filename = f"{output_path}/{output_zone_name}_ditch_prob.tif"
            meta.update(dtype=str(ditch_data.dtype))
            
            with rasterio.open(ditch_filename, 'w', **meta) as dst:
                dst.write(ditch_data, 1)
            logger.info(f"Saved ditch probability map for {output_zone_name} to {ditch_filename}")
            
            # Check if combined prediction already exists in the Zarr file
            if "pred_combined" in zone_group:
                classification = zone_group["pred_combined"][:]
                logger.info(f"Using existing combined classification from Zarr for {output_zone_name}")
            else:
                # Create a combined classification map
                logger.info(f"Generating combined classification for {output_zone_name}")
                all_probs = np.zeros((3, grid_height, grid_width))
                all_probs[0] = 1.0 - (stream_data + ditch_data)  # Background probability
                all_probs[1] = stream_data  # Stream probability
                all_probs[2] = ditch_data  # Ditch probability
                
                # Get the class with highest probability for each pixel
                classification = np.argmax(all_probs, axis=0).astype(np.uint8)
            
            # Define output filename for classification map
            class_filename = f"{output_path}/{output_zone_name}_classification.tif"
            
            # Update metadata for uint8 datatype
            meta.update(dtype='uint8')
            
            # Write classification GeoTIFF
            with rasterio.open(class_filename, 'w', **meta) as dst:
                dst.write(classification, 1)
            
            logger.info(f"Saved classification map for {output_zone_name} to {class_filename}")
            
            return True
        else:
            logger.error(f"Zone {zarr_zone_name} not found in spatial group")
            return False
    
    except Exception as e:
        logger.error(f"Error exporting predictions from Zarr to GeoTIFF: {e}")
        import traceback
        logger.error(traceback.format_exc())
    
    return False

def export_all_zones_from_zarr(zarr_file, output_path, target_zones=None):
    """
    Export all target zones from a Zarr file to GeoTIFF format.
    
    Parameters:
    -----------
    zarr_file : str
        Path to the zarr file containing the prediction results
    output_path : str
        Directory to save the GeoTIFF files
    target_zones : list, optional
        List of zone names to process (e.g., ["zone_2", "zone_5", "zone_9"]).
        If None, will export all zones available in the spatial_coord.py file.
    """

    logger = logging.getLogger(__name__)
    
    # Ensure output directory exists
    os.makedirs(output_path, exist_ok=True)
    
    # Define specific zone name to ID mapping based on the provided information
    # This is the crucial part to fix the issue
    zone_name_to_id = {
        "zone_1": 0,
        "zone_2": 1,   # zone_2 is index 1
        "zone_3": 2,
        "zone_4": 3,
        "zone_5": 4,   # zone_5 is index 4
        "zone_6": 5,
        "zone_7": 6,
        "zone_8": 7,
        "zone_9": 8,   # zone_9 is index 8
        "zone_10": 9,
        "zone_11": 10,
        "zone_12": 11,
        "zone_13": 12,
        "zone_14": 13,
        "zone_15": 14,
        "zone_16": 15,
        "zone_17": 16,
        "zone_18": 17,
        "zone_19": 18,
        "zone_20": 19,
        "zone_21": 20
    }
    
    # Try to import zone_boundaries to get all available zones if not specified
    try:
        spec = importlib.util.spec_from_file_location("spatial_coord", "spatial_coord.py")
        if spec is None:
            # If not found in current directory, try to find it in parent directories
            import glob
            spatial_coord_files = glob.glob("**/spatial_coord.py", recursive=True)
            if spatial_coord_files:
                spec = importlib.util.spec_from_file_location("spatial_coord", spatial_coord_files[0])
            else:
                raise FileNotFoundError("Could not find spatial_coord.py")
                
        spatial_coord = importlib.util.module_from_spec(spec)
        sys.modules["spatial_coord"] = spatial_coord
        spec.loader.exec_module(spatial_coord)
        
        # Get zone_boundaries from the module
        zone_boundaries = spatial_coord.zone_boundaries
        all_zones = [f"zone_{i+1}" for i in range(len(zone_boundaries))]
        logger.info(f"Found {len(all_zones)} zones in spatial_coord.py")
    except Exception as e:
        logger.error(f"Error importing zone_boundaries: {e}")
        zone_boundaries = None
        all_zones = []
    
    # If target_zones is provided, process only those zones
    zones_to_process = target_zones if target_zones is not None else all_zones
    
    if not zones_to_process:
        logger.error("No zones specified to process and could not find zones in spatial_coord.py")
        return
    
    for zone_name in zones_to_process:
        try:
            # Get the correct zone_id from our mapping
            if zone_name in zone_name_to_id:
                zone_id = zone_name_to_id[zone_name]
                logger.info(f"Processing {zone_name}: using zone_id={zone_id} from explicit mapping")
            else:
                # Fallback to numerical extraction if not in mapping
                zone_id = int(zone_name.split('_')[1]) - 1
                logger.warning(f"Zone {zone_name} not found in explicit mapping, using calculated zone_id={zone_id}")
            
            logger.info(f"Exporting {zone_name} (zone_id={zone_id}) from Zarr to GeoTIFF...")
            success = export_zarr_predictions_to_geotiff(zarr_file, zone_id, output_path, zone_name)
            
            if success:
                logger.info(f"Successfully exported {zone_name} to GeoTIFF")
            else:
                logger.warning(f"Failed to export {zone_name} to GeoTIFF")
                
        except (IndexError, ValueError) as e:
            logger.error(f"Invalid zone name format or error processing {zone_name}: {e}")

In [19]:
# Configure logging
logging.basicConfig(level=logging.INFO, 
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

if __name__ == "__main__":
    # Path to your Zarr file - replace with your actual path
    zarr_file = os.path.join(results, "prediction_results_smote.zarr")
    
    # Path for output GeoTIFFs - replace with your actual path
    output_path = "../02_Results/geotiffs_tests"
    
    # Only use the zones that actually exist in the Zarr file based on tree output
    target_zones = ["zone_2", "zone_5", "zone_9"]
    
    logger.info("Starting export of specific zones from Zarr file")
    
    # Export only the specified zones
    export_all_zones_from_zarr(zarr_file, output_path, target_zones=target_zones)
    
    logger.info("Export completed")

2025-03-31 11:51:13,947 - __main__ - INFO - Starting export of specific zones from Zarr file
2025-03-31 11:51:13,949 - __main__ - INFO - Found 21 zones in spatial_coord.py
2025-03-31 11:51:13,949 - __main__ - INFO - Processing zone_2: using zone_id=1 from explicit mapping
2025-03-31 11:51:13,950 - __main__ - INFO - Exporting zone_2 (zone_id=1) from Zarr to GeoTIFF...
2025-03-31 11:51:13,967 - __main__ - INFO - Successfully imported zone_boundaries from spatial_coord.py
2025-03-31 11:51:13,967 - __main__ - INFO - Using zone boundaries from spatial_coord.py for zone_id 1 / zone_2
2025-03-31 11:51:14,198 - __main__ - INFO - Found stream data for zone_2 with shape (5000, 5000)
2025-03-31 11:51:14,339 - __main__ - INFO - Found ditch data for zone_2 with shape (5000, 5000)
2025-03-31 11:51:14,533 - __main__ - INFO - Saved stream probability map for zone_2 to ../02_Results/geotiffs_tests/zone_2_stream_prob.tif
2025-03-31 11:51:14,718 - __main__ - INFO - Saved ditch probability map for zone_2 