## GEOTifs from predictions

Taking these to QGIS, resolution issues need to be correct and defined.
- I started to use dates with test and results files. 

In [1]:
import os
import logging
import zarr
import rasterio
from rasterio.transform import from_bounds
import numpy as np
import logging
import importlib.util
import sys
import os
import importlib
os.environ["GTIFF_SRS_SOURCE"] = "EPSG"

export, probability

In [2]:
# Configure logging
logging.basicConfig(level=logging.INFO, 
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def export_zarr_predictions_to_geotiff(zarr_file, zone_id, output_path, zone_name=None):
    """
    Export prediction results from a Zarr file to GeoTIFF format with proper georeferencing.
    
    Parameters:
    -----------
    zarr_file : str
        Path to the zarr file containing the prediction results
    zone_id : int
        The ID of the zone (index in zone_boundaries dictionary)
    output_path : str
        Directory to save the GeoTIFF files
    zone_name : str, optional
        The zone name to use in output files (e.g., "zone_2"). If None, will use "zone_{zone_id+1}"
    """
    logger = logging.getLogger(__name__)
    
    # Determine zone name
    if zone_name is None:
        output_zone_name = f"zone_{zone_id+1}"
    else:
        output_zone_name = zone_name
    
    # Import zone_boundaries from spatial_coord.py
    try:
        # Try to import using importlib
        spec = importlib.util.spec_from_file_location("spatial_coord", "spatial_coord.py")
        if spec is None:
            # If not found in current directory, try to find it in parent directories
            import glob
            spatial_coord_files = glob.glob("**/spatial_coord.py", recursive=True)
            if spatial_coord_files:
                spec = importlib.util.spec_from_file_location("spatial_coord", spatial_coord_files[0])
            else:
                raise FileNotFoundError("Could not find spatial_coord.py")
                
        spatial_coord = importlib.util.module_from_spec(spec)
        sys.modules["spatial_coord"] = spatial_coord
        spec.loader.exec_module(spatial_coord)
        
        # Get zone_boundaries from the module
        zone_boundaries = spatial_coord.zone_boundaries
        logger.info(f"Successfully imported zone_boundaries from spatial_coord.py")
    except Exception as e:
        logger.error(f"Error importing zone_boundaries: {e}")
        logger.error("Using a default resolution for GeoTIFF export")
        # Use a fallback approach with a default resolution
        zone_boundaries = None
    
    # Ensure output directory exists
    os.makedirs(output_path, exist_ok=True)
    
    # Get zone boundaries from the dictionary or use defaults
    if zone_boundaries is not None and zone_id in zone_boundaries:
        # Get boundaries for this zone_id
        ul_x, ul_y = zone_boundaries[zone_id]["upper_left"]
        lr_x, lr_y = zone_boundaries[zone_id]["lower_right"]
        logger.info(f"Using zone boundaries from spatial_coord.py for zone_id {zone_id} / {output_zone_name}")
    else:
        logger.warning(f"Zone ID {zone_id} not found in boundaries dictionary, using default coordinates")
        # Default to a square 2500x2500 meter area around a point in Finland if boundaries are not available
        # These are dummy values and should be replaced with actual coordinates
        ul_x, ul_y = 375000.0, 6865000.0
        lr_x, lr_y = ul_x + 2500, ul_y - 2500
    
    try:
        root = zarr.open(zarr_file, mode="r")
        
        # Based on your Zarr tree structure, the data is directly in zones/zone_X
        if "zones" in root:
            zones_group = root["zones"]
            
            # Check if the zone exists
            if output_zone_name in zones_group:
                zone_group = zones_group[output_zone_name]
                
                # Get stream and ditch probability data
                if "prob_streams" in zone_group:
                    stream_data = zone_group["prob_streams"][:]
                    logger.info(f"Found stream data for {output_zone_name} with shape {stream_data.shape}")
                else:
                    logger.error(f"No 'prob_streams' array found for {output_zone_name}")
                    return False
                    
                if "prob_ditches" in zone_group:
                    ditch_data = zone_group["prob_ditches"][:]
                    logger.info(f"Found ditch data for {output_zone_name} with shape {ditch_data.shape}")
                else:
                    logger.error(f"No 'prob_ditches' array found for {output_zone_name}")
                    return False
                
                # Calculate pixel resolution
                height, width = stream_data.shape
                x_res = (lr_x - ul_x) / width
                y_res = (ul_y - lr_y) / height
                
                # Create transform for GeoTIFF
                transform = from_bounds(ul_x, lr_y, lr_x, ul_y, width, height)
                
                # Save stream probabilities as GeoTIFF
                stream_output = os.path.join(output_path, f"{output_zone_name}_stream_prob.tif")
                with rasterio.open(
                    stream_output,
                    'w',
                    driver='GTiff',
                    height=height,
                    width=width,
                    count=1,
                    dtype=stream_data.dtype,
                    crs='+proj=utm +zone=35 +datum=WGS84 +units=m +no_defs',
                    transform=transform,
                ) as dst:
                    dst.write(stream_data, 1)
                logger.info(f"Saved stream probabilities to {stream_output}")
                
                # Save ditch probabilities as GeoTIFF
                ditch_output = os.path.join(output_path, f"{output_zone_name}_ditch_prob.tif")
                with rasterio.open(
                    ditch_output,
                    'w',
                    driver='GTiff',
                    height=height,
                    width=width,
                    count=1,
                    dtype=ditch_data.dtype,
                    crs='+proj=utm +zone=35 +datum=WGS84 +units=m +no_defs',
                    transform=transform,
                ) as dst:
                    dst.write(ditch_data, 1)
                logger.info(f"Saved ditch probabilities to {ditch_output}")
                
                return True
            else:
                logger.error(f"Zone {output_zone_name} not found in zones group")
                return False
        else:
            logger.error("No 'zones' group found in Zarr file")
            return False
            
    except Exception as e:
        logger.error(f"Error exporting predictions from Zarr to GeoTIFF: {e}")
        import traceback
        logger.error(traceback.format_exc())
    
    return False

def export_all_zones_from_zarr(zarr_file, output_path, target_zones=None):
    """
    Export all target zones from a Zarr file to GeoTIFF format.
    
    Parameters:
    -----------
    zarr_file : str
        Path to the zarr file containing the prediction results
    output_path : str
        Directory to save the GeoTIFF files
    target_zones : list, optional
        List of zone names to process (e.g., ["zone_2", "zone_5", "zone_9"]).
        If None, will export all zones available in the spatial_coord.py file.
    """
    logger = logging.getLogger(__name__)
    
    # Ensure output directory exists
    os.makedirs(output_path, exist_ok=True)


    # Define specific zone name to ID mapping based on the provided information
    zone_name_to_id = {
        "zone_1": 0,
        "zone_2": 1,   # zone_2 is index 1
        "zone_3": 2,
        "zone_4": 3,
        "zone_5": 4,   # zone_5 is index 4
        "zone_6": 5,
        "zone_7": 6,
        "zone_8": 7,
        "zone_9": 8,   # zone_9 is index 8
        "zone_10": 9,
        "zone_11": 10,
        "zone_12": 11,
        "zone_13": 12,
        "zone_14": 13,
        "zone_15": 14,
        "zone_16": 15,
        "zone_17": 16,
        "zone_18": 17,
        "zone_19": 18,
        "zone_20": 19,
        "zone_21": 20
    }
    
    # Try to import zone_boundaries to get all available zones if not specified
    try:
        spec = importlib.util.spec_from_file_location("spatial_coord", "spatial_coord.py")
        if spec is None:
            # If not found in current directory, try to find it in parent directories
            import glob
            spatial_coord_files = glob.glob("**/spatial_coord.py", recursive=True)
            if spatial_coord_files:
                spec = importlib.util.spec_from_file_location("spatial_coord", spatial_coord_files[0])
            else:
                raise FileNotFoundError("Could not find spatial_coord.py")
                
        spatial_coord = importlib.util.module_from_spec(spec)
        sys.modules["spatial_coord"] = spatial_coord
        spec.loader.exec_module(spatial_coord)
        
        # Get zone_boundaries from the module
        zone_boundaries = spatial_coord.zone_boundaries
        all_zones = [f"zone_{i+1}" for i in range(len(zone_boundaries))]
        logger.info(f"Found {len(all_zones)} zones in spatial_coord.py")
    except Exception as e:
        logger.error(f"Error importing zone_boundaries: {e}")
        zone_boundaries = None
        all_zones = []
    
    # If target_zones is provided, process only those zones
    zones_to_process = target_zones if target_zones is not None else all_zones
    
    if not zones_to_process:
        logger.error("No zones specified to process and could not find zones in spatial_coord.py")
        return
    
    for zone_name in zones_to_process:
        try:
            # Get the correct zone_id from our mapping
            if zone_name in zone_name_to_id:
                zone_id = zone_name_to_id[zone_name]
                logger.info(f"Processing {zone_name}: using zone_id={zone_id} from explicit mapping")
            else:
                # Fallback to numerical extraction if not in mapping
                zone_id = int(zone_name.split('_')[1]) - 1
                logger.warning(f"Zone {zone_name} not found in explicit mapping, using calculated zone_id={zone_id}")
            
            logger.info(f"Exporting {zone_name} (zone_id={zone_id}) from Zarr to GeoTIFF...")
            success = export_zarr_predictions_to_geotiff(zarr_file, zone_id, output_path, zone_name)
            
            if success:
                logger.info(f"Successfully exported {zone_name} to GeoTIFF")
            else:
                logger.warning(f"Failed to export {zone_name} to GeoTIFF")
                
        except (IndexError, ValueError) as e:
            logger.error(f"Invalid zone name format or error processing {zone_name}: {e}")

In [3]:
if __name__ == "__main__":
    # Use the path variables you provided
    results = "../02_Results/1204/"
    zarr_file = os.path.join(results, "predictions.zarr")  # Assuming this is your Zarr file name
    
    # Path for output GeoTIFFs
    output_path = "../02_Results/geotiffs_1204"
    
    # Ensure output directory exists
    os.makedirs(output_path, exist_ok=True)
    
    # Check if zarr_file exists
    if not os.path.exists(zarr_file):
        logger.error(f"Zarr file path does not exist: {zarr_file}")
        logger.info(f"Checking for Zarr storage without .zarr extension...")
        # Try without .zarr extension
        zarr_file = results
        if not os.path.exists(zarr_file):
            logger.error(f"Zarr directory path also does not exist: {zarr_file}")
            sys.exit(1)
    
    # Only use the zones that actually exist in the Zarr file based on tree output
    target_zones = ["zone_1", "zone_2", "zone_3", "zone_4", "zone_5", 
                    "zone_6", "zone_7", "zone_8", "zone_9", "zone_10"]
    
    logger.info("Starting export of specific zones from Zarr file")
    export_all_zones_from_zarr(zarr_file, output_path, target_zones=target_zones)
    logger.info("Export completed")

2025-04-12 23:09:30,316 - __main__ - INFO - Starting export of specific zones from Zarr file
2025-04-12 23:09:30,319 - __main__ - INFO - Found 21 zones in spatial_coord.py
2025-04-12 23:09:30,319 - __main__ - INFO - Processing zone_1: using zone_id=0 from explicit mapping
2025-04-12 23:09:30,319 - __main__ - INFO - Exporting zone_1 (zone_id=0) from Zarr to GeoTIFF...
2025-04-12 23:09:30,319 - __main__ - INFO - Successfully imported zone_boundaries from spatial_coord.py
2025-04-12 23:09:30,325 - __main__ - INFO - Using zone boundaries from spatial_coord.py for zone_id 0 / zone_1
2025-04-12 23:09:30,569 - __main__ - INFO - Found stream data for zone_1 with shape (5000, 5000)
2025-04-12 23:09:30,746 - __main__ - INFO - Found ditch data for zone_1 with shape (5000, 5000)
2025-04-12 23:09:31,046 - __main__ - INFO - Saved stream probabilities to ../02_Results/geotiffs_1204\zone_1_stream_prob.tif
2025-04-12 23:09:31,314 - __main__ - INFO - Saved ditch probabilities to ../02_Results/geotiffs_1

In [51]:
import zarr

def print_zarr_tree(group, indent=0):
    prefix = " " * indent
    print(f"{prefix}{group.path or '/'}")

    for key, item in group.items():
        if isinstance(item, zarr.hierarchy.Group):
            print_zarr_tree(item, indent=indent + 4)
        elif isinstance(item, zarr.core.Array):
            shape = item.shape
            dtype = item.dtype
            print(f"{' ' * (indent + 4)}{key} {shape} {dtype}")

# Open the store
store = zarr.open('../02_Results/0804/predictions.zarr', mode='r')

# Print the structure
print_zarr_tree(store)


/
    metadata
        metadata/boundaries
            metadata/boundaries/zone_1
            metadata/boundaries/zone_10
            metadata/boundaries/zone_11
            metadata/boundaries/zone_12
            metadata/boundaries/zone_13
            metadata/boundaries/zone_14
            metadata/boundaries/zone_15
            metadata/boundaries/zone_16
            metadata/boundaries/zone_17
            metadata/boundaries/zone_18
            metadata/boundaries/zone_19
            metadata/boundaries/zone_2
            metadata/boundaries/zone_20
            metadata/boundaries/zone_21
            metadata/boundaries/zone_3
            metadata/boundaries/zone_4
            metadata/boundaries/zone_5
            metadata/boundaries/zone_6
            metadata/boundaries/zone_7
            metadata/boundaries/zone_8
            metadata/boundaries/zone_9
    zones
        zones/zone_1
            classification (5000, 5000) uint8
            pred_combined (5000, 5000) float32
  

In [None]:
# Example if zarr in cloud
import zarr
import fsspec

fs = fsspec.filesystem('s3')
mapper = fs.get_mapper('bucket-name/path/to/zarr-store')
store = zarr.open(mapper, mode='r')
store.tree()
