I decided to share this too, even though the model still needs more practice with RF (Random Forest). This saves the results as images.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import zarr
import psutil
import joblib
import multiprocessing
import pandas as pd
import glob
from datetime import datetime
import re
import logging
from sklearn.metrics import precision_score, recall_score, f1_score
from rasterio.transform import from_bounds
import rasterio
from rasterio.crs import CRS
from rasterio.transform import from_bounds
import os


# Finnish coordinate system (ETRS89 / TM35FIN)
# Define using WKT string
# NOTE: `crs` is rebound by the PROJ-string assignment just below, and again by
# the later `crs = CRS.from_epsg(3067)` statement in this file, so this
# WKT-derived value is not the one in effect once the whole module has run.
crs = CRS.from_wkt("""
PROJCS["ETRS89 / TM35FIN",
    GEOGCS["ETRS89",
        DATUM["European_Terrestrial_Reference_System_1989",
            SPHEROID["GRS 1980",6378137,298.257222101]],
        PRIMEM["Greenwich",0],
        UNIT["degree",0.0174532925199433]],
    PROJECTION["Transverse_Mercator"],
    PARAMETER["latitude_of_origin",0],
    PARAMETER["central_meridian",27],
    PARAMETER["scale_factor",0.9996],
    PARAMETER["false_easting",500000],
    PARAMETER["false_northing",0],
    UNIT["metre",1]]
""")

# Or using PROJ string
# NOTE(review): a UTM zone 35 / GRS80 definition, presumably intended as an
# equivalent of the WKT above -- confirm. This value is also replaced by the
# later EPSG-based assignment.
crs = CRS.from_string("+proj=utm +zone=35 +ellps=GRS80 +towgs84=0,0,0,0,0,0,0 +units=m +no_defs +type=crs")


# --- Logging ------------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# --- Input data and compute budget --------------------------------------------
# Use "zones_data_2.zarr" instead for the training/practice store.
zarr_file = "zones_data_11-21.zarr"
# Never use more than four worker CPUs, even on larger machines.
total_cpus = min(multiprocessing.cpu_count(), 4)

# --- Directories --------------------------------------------------------------
input_path = "../02_Results/"  # holds the trained models (and earlier results)
output_path = "../02_Results/1204_2/"  # where results of this run are written
models_dir = os.path.join(input_path, "models")
cache_dir = os.path.join(input_path, "joblib_cache")

# --- Zones and features -------------------------------------------------------
# zone_11 .. zone_21; the training/practice store uses zone_1 .. zone_10.
zones_to_process = [f"zone_{number}" for number in range(11, 22)]

selected_features = [
    'row_idx', 'col_idx', 'impoundment_amplified', 'zone_id', 'skyview_gabor',
    'impoundment_raw', 'conic_mean', 'hpmf_raw', 'skyview_raw', 'hpmf_f',
    'slope_channels',
]

# --- Zone lookup tables (zone_1 .. zone_21) -----------------------------------
# Zone name -> numeric ID.
zone_name_to_id = dict((f"zone_{number}", number) for number in range(1, 22))
# Numeric ID -> zone name (inverse of the table above).
zone_id_to_name = {zone_id: name for name, zone_id in zone_name_to_id.items()}
# Every zone gets its own dict with the same default pixel extent.
zone_boundaries = {
    name: {"upper_left": (0, 0), "lower_right": (5000, 5000)}
    for name in zone_name_to_id
}

# Define CRS (Coordinate Reference System) for output files
# This is the last module-level assignment to `crs`, so it replaces the
# WKT/PROJ-based values assigned earlier in the file.
crs = CRS.from_epsg(3067)  # Finnish coordinate system, update as needed

def save_tif(output_path, data, transform, crs, dtype='float32'):
    """Save a 2D array as a single-band, LZW-compressed GeoTIFF.

    Args:
        output_path: Destination file path. Its parent directory is created
            when the path names one.
        data: Array whose ``shape`` unpacks to ``(height, width)``.
        transform: Georeferencing transform passed through to ``rasterio.open``.
        crs: Coordinate reference system passed through to ``rasterio.open``.
        dtype: Band data type declared for the output file.

    Raises:
        ValueError: If ``data.shape`` does not unpack into exactly two values.
    """
    height, width = data.shape

    # os.path.dirname() returns '' for a bare filename, and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when one is named.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with rasterio.open(
        output_path,
        'w',
        driver='GTiff',
        height=height,
        width=width,
        count=1,
        dtype=dtype,
        crs=crs,
        transform=transform,
        compress='lzw'
    ) as dst:
        # Band indices are 1-based; this writes the only band.
        dst.write(data, 1)

    logger.info(f"Saved GeoTIFF: {output_path}")

def print_memory_usage():
    """Log the resident set size (RSS) of the current process in MB."""
    rss_bytes = psutil.Process().memory_info().rss
    rss_megabytes = rss_bytes / 1024 / 1024
    logger.info(f"Memory usage: {rss_megabytes:.2f} MB")

def setup_joblib_cache(cache_dir):
    """Create ``cache_dir`` if needed and return a quiet ``joblib.Memory`` on it."""
    os.makedirs(cache_dir, exist_ok=True)
    return joblib.Memory(cache_dir, verbose=0)

def fix_zarr_file_features(zarr_file):
    """Backfill ``zone_id``, ``row_idx`` and ``col_idx`` in every zone group.

    The store is opened read/write and modified in place. Only groups whose
    name starts with ``zone_`` are touched, and only datasets that are absent
    are created. All generated arrays are flat and sized for a 5000 x 5000
    grid (25,000,000 values).

    Args:
        zarr_file: Path of the zarr store to repair.
    """
    logger.info(f"Fixing zarr file by adding missing features: {zarr_file}")

    store = zarr.open(zarr_file, mode='r+')  # read/write, store must exist
    zone_keys = [key for key in store.keys() if key.startswith('zone_')]

    for zone_name in zone_keys:
        group = store[zone_name]

        if 'zone_id' not in group:
            try:
                # Prefer the module-level lookup table; otherwise parse the
                # (1-based) numeric suffix out of the group name.
                zone_id = (
                    zone_name_to_id[zone_name]
                    if zone_name in zone_name_to_id
                    else int(zone_name.split('_')[1])
                )
                logger.info(f"Adding zone_id feature to {zone_name}: {zone_id}")
                constant_ids = np.full(25000000, zone_id, dtype=np.int32)
                group.create_dataset('zone_id', data=constant_ids)
            except (IndexError, ValueError) as e:
                logger.warning(f"Could not add zone_id to {zone_name}: {e}")

        needs_rows = 'row_idx' not in group
        needs_cols = 'col_idx' not in group
        if not (needs_rows or needs_cols):
            continue

        logger.info(f"Adding spatial indices to {zone_name}")
        n_rows, n_cols = 5000, 5000  # standard grid size

        if needs_rows:
            # 0,0,...,0,1,1,...: each row number repeated once per column.
            flat_rows = np.repeat(np.arange(n_rows), n_cols)
            group.create_dataset('row_idx', data=flat_rows, dtype=np.int32)

        if needs_cols:
            # 0,1,...,n_cols-1 repeated once per row.
            flat_cols = np.tile(np.arange(n_cols), n_rows)
            group.create_dataset('col_idx', data=flat_cols, dtype=np.int32)

        logger.info(f"Added spatial indices to {zone_name}")

    logger.info(f"Finished checking/fixing features in {zarr_file}")

def fix_load_zone_data(zarr_file, zone_name, selected_features):
    """Read the requested feature arrays of one zone into a DataFrame.

    Features absent from the zone group are skipped with a warning, except
    ``zone_id``, which is filled in from ``zone_name_to_id`` (``-1`` when the
    zone name is unknown). Any failure is logged with its traceback and
    reported to the caller as an empty DataFrame rather than an exception.

    Args:
        zarr_file: Path of the zarr store, opened read-only.
        zone_name: Name of the group to read.
        selected_features: Names of the arrays to load.

    Returns:
        A DataFrame with one column per loaded feature, or an empty DataFrame
        when the zone is missing, has none of the features, or loading fails.
    """
    logger.info(f"Loading and fixing data for {zone_name}")

    try:
        store = zarr.open(zarr_file, mode='r')

        if zone_name not in store:
            logger.warning(f"Zone {zone_name} not found in zarr file")
            return pd.DataFrame()

        group = store[zone_name]

        # Pull every requested array that exists fully into memory.
        columns = {}
        for feature in selected_features:
            if feature not in group:
                logger.warning(f"Feature {feature} not found in zone {zone_name}")
                continue
            columns[feature] = group[feature][:]

        if not columns:
            logger.warning(f"No feature data found for zone {zone_name}")
            return pd.DataFrame()

        frame = pd.DataFrame(columns)

        # zone_id is the only missing feature with a default value.
        if 'zone_id' in selected_features and 'zone_id' not in frame.columns:
            frame['zone_id'] = zone_name_to_id.get(zone_name, -1)

        return frame

    except Exception as e:
        logger.error(f"Error loading data for zone {zone_name}: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return pd.DataFrame()

def load_models(models_dir):
    """Load every ``rf_model_*.joblib`` file found in ``models_dir``.

    Files are loaded in sorted path order. A file that fails to load is
    logged and skipped.

    Args:
        models_dir: Directory to search.

    Returns:
        The list of loaded models, or ``None`` when no file matches or none
        could be loaded.
    """
    pattern = os.path.join(models_dir, "rf_model_*.joblib")
    candidate_paths = sorted(glob.glob(pattern))
    if not candidate_paths:
        logger.error(f"No model files found in {models_dir}")
        return None

    logger.info(f"Loading {len(candidate_paths)} models...")
    loaded = []
    for path in candidate_paths:
        try:
            loaded.append(joblib.load(path))
            logger.info(f"Loaded model: {os.path.basename(path)}")
        except Exception as e:
            logger.error(f"Error loading model {path}: {str(e)}")

    if loaded:
        return loaded

    logger.error("No models could be loaded successfully")
    return None

def make_parallel_predictions(models, X):
    """Average the class probabilities predicted by an ensemble of models.

    Despite the name, the models are evaluated one after another. A model
    whose ``predict_proba`` raises is logged and left out of the average, so
    the result is the mean over the models that actually produced output
    (dividing by the total model count would bias probabilities towards 0).

    Args:
        models: Sequence of fitted estimators exposing ``predict_proba`` and,
            optionally, ``classes_``.
        X: Feature matrix; only ``len(X)`` and ``predict_proba(X)`` are used.

    Returns:
        Array of shape ``(len(X), n_classes)``. ``n_classes`` comes from the
        first model's ``classes_`` (3 when that attribute is absent). All
        zeros when ``models`` is empty or every model failed.
    """
    if not models:
        logger.error("No models provided for prediction")
        return np.zeros((len(X), 3))

    logger.info(f"Making parallel predictions with {len(models)} models")

    n_classes = len(models[0].classes_) if hasattr(models[0], 'classes_') else 3
    probability_sum = np.zeros((len(X), n_classes))

    succeeded = 0
    for position, model in enumerate(models, start=1):
        try:
            logger.info(f"Running predictions with model {position}/{len(models)}")
            probability_sum += model.predict_proba(X)
        except Exception as e:
            logger.error(f"Error making predictions with model {position}: {str(e)}")
        else:
            # Count only models whose output really entered the sum.
            succeeded += 1

    if succeeded == 0:
        logger.error("All models failed to predict; returning zero probabilities")
        return probability_sum

    if succeeded < len(models):
        logger.warning(f"Averaging over {succeeded} of {len(models)} models")

    probability_sum /= succeeded
    return probability_sum

def reshape_predictions_for_visualization(y_prob, spatial_indices, spatial_shape):
    """Scatter per-sample class probabilities onto a 2D grid.

    The mapping is done with NumPy indexing instead of a per-sample Python
    loop. Samples whose (row, col) falls outside ``spatial_shape`` are
    ignored; when a cell is listed more than once the last sample wins.

    Args:
        y_prob: Array of shape ``(n_samples, n_classes)``. With three or more
            columns, column 1 is read as streams and column 2 as ditches.
            NOTE(review): with fewer columns, column 0 is read as streams and
            column 1 (if any) as ditches -- confirm this matches the models'
            class order.
        spatial_indices: ``(row, col)`` pairs, one per sample, aligned with
            the rows of ``y_prob``.
        spatial_shape: ``(height, width)`` of the output grids.

    Returns:
        ``{'raw': {'probabilities': y_prob}, 'spatial': {...}}`` where
        ``spatial`` holds float grids ``prob_streams``, ``prob_ditches``,
        ``prob_combined`` (element-wise maximum) and int8 grids
        ``pred_streams``, ``pred_ditches`` (0/1 at threshold 0.5) and
        ``pred_combined`` (0 = none, 1 = stream, 2 = ditch; stream wins when
        both exceed the threshold).

    Raises:
        IndexError: If there are more index pairs than probability rows.
    """
    logger.info("Reshaping predictions for visualization")

    n_columns = y_prob.shape[1]
    if n_columns >= 3:
        prob_streams = y_prob[:, 1]
        prob_ditches = y_prob[:, 2]
    else:
        prob_streams = y_prob[:, 0] if n_columns > 0 else np.zeros(y_prob.shape[0])
        prob_ditches = y_prob[:, 1] if n_columns > 1 else np.zeros(y_prob.shape[0])

    threshold = 0.5

    prob_streams_spatial = np.zeros(spatial_shape)
    prob_ditches_spatial = np.zeros(spatial_shape)
    prob_combined_spatial = np.zeros(spatial_shape)
    pred_streams_spatial = np.zeros(spatial_shape, dtype=np.int8)
    pred_ditches_spatial = np.zeros(spatial_shape, dtype=np.int8)
    pred_combined_spatial = np.zeros(spatial_shape, dtype=np.int8)

    indices = np.asarray(spatial_indices)
    if indices.size == 0:
        # Keep the column slicing below valid for an empty input.
        indices = indices.reshape(0, 2)

    n_points = len(indices)
    if n_points > len(prob_streams):
        raise IndexError(
            f"{n_points} spatial indices given for {len(prob_streams)} predictions"
        )

    rows, cols = indices[:, 0], indices[:, 1]
    in_bounds = (
        (rows >= 0) & (rows < spatial_shape[0])
        & (cols >= 0) & (cols < spatial_shape[1])
    )
    rows = rows[in_bounds].astype(np.intp)
    cols = cols[in_bounds].astype(np.intp)

    # Extra probability rows beyond the index list are ignored.
    streams = np.asarray(prob_streams)[:n_points][in_bounds]
    ditches = np.asarray(prob_ditches)[:n_points][in_bounds]
    is_stream = streams > threshold
    is_ditch = ditches > threshold

    prob_streams_spatial[rows, cols] = streams
    prob_ditches_spatial[rows, cols] = ditches
    prob_combined_spatial[rows, cols] = np.maximum(streams, ditches)

    pred_streams_spatial[rows, cols] = is_stream
    pred_ditches_spatial[rows, cols] = is_ditch
    # Stream takes priority over ditch when both exceed the threshold.
    pred_combined_spatial[rows, cols] = np.where(is_stream, 1, np.where(is_ditch, 2, 0))

    return {
        'raw': {
            'probabilities': y_prob
        },
        'spatial': {
            'prob_streams': prob_streams_spatial,
            'prob_ditches': prob_ditches_spatial,
            'prob_combined': prob_combined_spatial,
            'pred_streams': pred_streams_spatial,
            'pred_ditches': pred_ditches_spatial,
            'pred_combined': pred_combined_spatial
        }
    }

def create_ground_truth_visualization(zone_data, spatial_shape):
    """Rasterize the ``label_3m`` column into stream and ditch masks.

    The masks are filled with NumPy indexing instead of a per-row Python
    loop. Rows whose (row_idx, col_idx) lies outside ``spatial_shape`` are
    ignored.

    Args:
        zone_data: DataFrame with ``row_idx`` and ``col_idx`` columns and,
            optionally, ``label_3m`` (1 = stream, 2 = ditch; other values are
            left as background).
        spatial_shape: ``(height, width)`` of the output masks.

    Returns:
        ``{'streams': mask, 'ditches': mask}`` with int8 0/1 arrays. Both are
        all zeros (and a warning is logged) when ``label_3m`` is absent.
    """
    logger.info("Creating ground truth visualization")

    ground_truth_streams = np.zeros(spatial_shape, dtype=np.int8)
    ground_truth_ditches = np.zeros(spatial_shape, dtype=np.int8)

    if 'label_3m' in zone_data.columns:
        rows = np.asarray(zone_data['row_idx'].values)
        cols = np.asarray(zone_data['col_idx'].values)
        labels = np.asarray(zone_data['label_3m'].values)

        in_bounds = (
            (rows >= 0) & (rows < spatial_shape[0])
            & (cols >= 0) & (cols < spatial_shape[1])
        )

        for label_value, mask in ((1, ground_truth_streams), (2, ground_truth_ditches)):
            selected = in_bounds & (labels == label_value)
            mask[rows[selected].astype(np.intp), cols[selected].astype(np.intp)] = 1
    else:
        logger.warning("No ground truth labels found in zone data")

    return {
        'streams': ground_truth_streams,
        'ditches': ground_truth_ditches
    }

def plot_predictions(predictions, title_prefix="Predicted Water Features", output_path="."):
    """Render the probability and class maps of one prediction set to a PNG.

    The figure has three panels: stream probability, ditch probability and
    the combined class map. A panel whose layer is missing from
    ``predictions['spatial']`` is left empty. The file name is derived from
    ``title_prefix`` (lower-cased, unsafe characters replaced by ``_``).

    Args:
        predictions: Dict with a ``'spatial'`` entry holding the 2D layers.
        title_prefix: Figure title; also the basis of the output file name.
        output_path: Directory for the PNG; created if missing.

    Returns:
        None. An invalid ``predictions`` structure is logged and nothing is
        drawn or saved.
    """
    logger.info(f"Plotting predictions with title prefix: {title_prefix}")

    has_spatial = isinstance(predictions, dict) and 'spatial' in predictions
    if not has_spatial:
        logger.error("Invalid predictions structure")
        return

    layers = predictions['spatial']
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Continuous probability panels share the same drawing recipe.
    probability_panels = (
        (axes[0], 'prob_streams', plt.cm.Blues, 'Stream Probability'),
        (axes[1], 'prob_ditches', plt.cm.Greens, 'Ditch Probability'),
    )
    for panel, layer_key, colormap, panel_title in probability_panels:
        if layer_key not in layers:
            continue
        image = panel.imshow(layers[layer_key], cmap=colormap, vmin=0, vmax=1)
        panel.set_title(panel_title)
        plt.colorbar(image, ax=panel, fraction=0.046, pad=0.04)

    if 'pred_combined' in layers:
        # Discrete classes: 0 -> white, 1 -> blue, 2 -> green.
        class_cmap = plt.cm.colors.ListedColormap(['white', 'blue', 'green'])
        class_norm = plt.cm.colors.BoundaryNorm([-0.5, 0.5, 1.5, 2.5], class_cmap.N)

        class_image = axes[2].imshow(layers['pred_combined'], cmap=class_cmap, norm=class_norm)
        axes[2].set_title('Predicted Classes')
        class_bar = plt.colorbar(class_image, ax=axes[2], ticks=[0, 1, 2], fraction=0.046, pad=0.04)
        class_bar.set_ticklabels(['None', 'Stream', 'Ditch'])

    for panel in axes:
        panel.set_xticks([])
        panel.set_yticks([])

    plt.suptitle(f"{title_prefix}", fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.96])  # leave room for the suptitle

    os.makedirs(output_path, exist_ok=True)

    file_stem = re.sub(r'[^\w\-_\.]', '_', title_prefix.lower().replace(' ', '_'))
    output_file = os.path.join(output_path, f"{file_stem}.png")

    # A failed save is logged, not raised, so the figure is still closed.
    try:
        logger.info(f"Saving figure to: {output_file}")
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
        logger.info(f"Figure saved successfully")
    except Exception as e:
        logger.error(f"Error saving figure: {e}")

    plt.close(fig)

def compare_predictions_with_ground_truth(predictions, ground_truth, title_prefix="Comparison with Ground Truth", output_path="."):
    """Save a 2 x 3 figure contrasting ground truth (top) with predictions (bottom).

    Columns are streams, ditches and the combined class map (0 = none,
    1 = stream, 2 = ditch; ditch overrides stream when a combined map is
    built from the two binary layers). A panel whose data is unavailable
    shows a text placeholder instead. The PNG name is derived from
    ``title_prefix``.

    Args:
        predictions: Dict with a ``'spatial'`` entry holding ``pred_streams``,
            ``pred_ditches`` and/or ``pred_combined``.
        ground_truth: Dict with ``'streams'`` and ``'ditches'`` masks.
        title_prefix: Figure title; also the basis of the output file name.
        output_path: Directory for the PNG; created if missing.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    stream_cmap = plt.cm.Blues
    ditch_cmap = plt.cm.Greens
    make_cmap = plt.cm.colors.ListedColormap
    make_norm = plt.cm.colors.BoundaryNorm
    panel_names = ('Streams', 'Ditches', 'Combined')

    def draw_class_map(ax, class_map, title):
        # Three-colour categorical map with a labelled colour bar.
        cmap = make_cmap(['white', 'blue', 'green'])
        norm = make_norm([-0.5, 0.5, 1.5, 2.5], cmap.N)
        image = ax.imshow(class_map, cmap=cmap, norm=norm)
        ax.set_title(title)
        bar = plt.colorbar(image, ax=ax, ticks=[0, 1, 2], fraction=0.046, pad=0.04)
        bar.set_ticklabels(['None', 'Stream', 'Ditch'])

    def draw_placeholder(ax, message, title):
        # Centered text for a panel that has nothing to show.
        ax.text(0.5, 0.5, message,
                horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes)
        ax.set_title(title)

    def merge_layers(stream_mask, ditch_mask):
        # Ditches are written last, so they override streams on overlap.
        merged = np.zeros_like(stream_mask, dtype=np.int8)
        merged[stream_mask == 1] = 1
        merged[ditch_mask == 1] = 2
        return merged

    # ---- Top row: ground truth ----
    truth_ok = (
        isinstance(ground_truth, dict)
        and 'streams' in ground_truth
        and 'ditches' in ground_truth
    )
    if truth_ok:
        axes[0, 0].imshow(ground_truth['streams'], cmap=stream_cmap)
        axes[0, 0].set_title('Ground Truth - Streams')

        axes[0, 1].imshow(ground_truth['ditches'], cmap=ditch_cmap)
        axes[0, 1].set_title('Ground Truth - Ditches')

        draw_class_map(
            axes[0, 2],
            merge_layers(ground_truth['streams'], ground_truth['ditches']),
            'Ground Truth - Combined',
        )
    else:
        logger.warning("Ground truth data is not in the expected format. Expected 'streams' and 'ditches' keys.")
        for column, name in enumerate(panel_names):
            draw_placeholder(axes[0, column], 'Ground Truth Not Available', f"Ground Truth - {name}")

    # ---- Bottom row: predictions ----
    if not (isinstance(predictions, dict) and 'spatial' in predictions):
        logger.warning("Predictions data is not in the expected format. Expected 'spatial' key with prediction data.")
        for column, name in enumerate(panel_names):
            draw_placeholder(axes[1, column], 'Predictions Not Available', f"Predicted - {name}")
    else:
        layers = predictions['spatial']
        has_streams = 'pred_streams' in layers
        has_ditches = 'pred_ditches' in layers

        if has_streams:
            axes[1, 0].imshow(layers['pred_streams'], cmap=stream_cmap)
            axes[1, 0].set_title('Predicted - Streams')
        else:
            logger.warning("Missing 'pred_streams' in predictions")
            draw_placeholder(axes[1, 0], 'Predictions Not Available', 'Predicted - Streams')

        if has_ditches:
            axes[1, 1].imshow(layers['pred_ditches'], cmap=ditch_cmap)
            axes[1, 1].set_title('Predicted - Ditches')
        else:
            logger.warning("Missing 'pred_ditches' in predictions")
            draw_placeholder(axes[1, 1], 'Predictions Not Available', 'Predicted - Ditches')

        # Use the stored combined map, or derive it from the binary layers.
        if 'pred_combined' in layers:
            combined = layers['pred_combined']
        elif has_streams and has_ditches:
            combined = merge_layers(layers['pred_streams'], layers['pred_ditches'])
        else:
            combined = None

        if combined is None:
            logger.warning("Could not create combined prediction visualization")
            draw_placeholder(axes[1, 2], 'Predictions Not Available', 'Predicted - Combined')
        else:
            draw_class_map(axes[1, 2], combined, 'Predicted - Combined')

    for panel in axes.flatten():
        panel.set_xticks([])
        panel.set_yticks([])

    plt.suptitle(f"{title_prefix}", fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.96])  # leave room for the suptitle

    os.makedirs(output_path, exist_ok=True)

    file_stem = re.sub(r'[^\w\-_\.]', '_', title_prefix.lower().replace(' ', '_'))
    output_file = os.path.join(output_path, f"{file_stem}.png")

    # A failed save is logged, not raised, so the figure is still closed.
    try:
        logger.info(f"Saving figure to: {output_file}")
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
        logger.info(f"Figure saved successfully")
    except Exception as e:
        logger.error(f"Error saving figure: {e}")

    plt.close(fig)

def predict_for_new_zone(zarr_file, zone_name, models_dir, selected_features, output_path):
    """Run the saved models on one zone and write the resulting figures.

    Steps: load the zone's features, backfill ``zone_id`` / ``row_idx`` /
    ``col_idx`` if absent, align the columns with what the first model was
    trained on (features missing from the data are filled with 0), average
    the models' probabilities, map them onto a 5000 x 5000 grid and plot.
    A comparison figure is added when the data contains ``label_3m``.

    Args:
        zarr_file: Path of the zarr store holding the zone.
        zone_name: Name of the zone group to process.
        models_dir: Directory searched for saved models.
        selected_features: Feature names to load from the zone.
        output_path: Directory that receives the PNG figures.

    Returns:
        The predictions dict produced by
        ``reshape_predictions_for_visualization``, or ``None`` when the zone
        has no data or no model could be loaded.
    """
    logger.info(f"Predict for new zone {zone_name} using models from {models_dir}")

    zone_data = fix_load_zone_data(zarr_file, zone_name, selected_features)
    if zone_data is None or len(zone_data) == 0:
        logger.warning(f"No data available for zone {zone_name}")
        return None

    spatial_shape = (5000, 5000)
    n_rows, n_cols = spatial_shape

    if 'zone_id' not in zone_data.columns:
        logger.info(f"Adding missing zone_id feature for {zone_name}")
        zone_data['zone_id'] = zone_name_to_id.get(zone_name, -1)  # -1: unknown zone name

    # The generated indices assume the rows of zone_data are the grid in
    # row-major order; pandas raises ValueError on a length mismatch.
    if 'row_idx' not in zone_data.columns:
        logger.info(f"Adding missing row_idx feature for {zone_name}")
        zone_data['row_idx'] = np.repeat(np.arange(n_rows), n_cols)

    if 'col_idx' not in zone_data.columns:
        logger.info(f"Adding missing col_idx feature for {zone_name}")
        zone_data['col_idx'] = np.tile(np.arange(n_cols), n_rows)

    feature_cols = [col for col in selected_features if col in zone_data.columns]
    logger.info(f"Using features: {feature_cols}")
    # Explicit copy: columns may be added below, and assigning into a slice
    # of zone_data triggers pandas' SettingWithCopyWarning.
    X = zone_data[feature_cols].copy()

    # Both index columns are guaranteed to exist at this point.
    spatial_indices = zone_data[['row_idx', 'col_idx']].values

    models = load_models(models_dir)
    if not models:
        logger.error(f"No models could be loaded from {models_dir}")
        return None

    # feature_names_in_ may be absent (e.g. a model fitted on a plain array);
    # in that case fall back to the columns that were loaded.
    expected = getattr(models[0], 'feature_names_in_', None)
    if expected is None:
        logger.warning("Model does not expose feature_names_in_; using loaded feature columns as-is")
        feature_names = X.columns.tolist()
    else:
        feature_names = np.asarray(expected).tolist()
    logger.info(f"Model expects: {feature_names}")
    logger.info(f"Data has: {X.columns.tolist()}")

    missing_features = set(feature_names) - set(X.columns)
    if missing_features:
        logger.warning(f"Missing features in data: {missing_features}")
        for feature in missing_features:
            X[feature] = 0

    # Keep only the expected features, in the order the model expects.
    X = X[feature_names]

    final_probabilities = make_parallel_predictions(models, X)
    if final_probabilities.shape[1] < 3:
        logger.warning(f"Unexpected probability shape: {final_probabilities.shape}. Using all columns.")

    predictions = reshape_predictions_for_visualization(final_probabilities, spatial_indices, spatial_shape)

    ground_truth = None
    if 'label_3m' in zone_data.columns:
        ground_truth = create_ground_truth_visualization(zone_data, spatial_shape)

    plot_predictions(predictions,
                     title_prefix=f"Predicted Water Features - {zone_name}",
                     output_path=output_path)

    if ground_truth is not None:
        compare_predictions_with_ground_truth(predictions, ground_truth,
                                              title_prefix=f"Comparison with Ground Truth - {zone_name}",
                                              output_path=output_path)

    return predictions

def transform_predictions(predictions):
    """Reduce per-zone predictions to the three layers needed for export.

    Args:
        predictions: Mapping of zone name to a prediction dict.

    Returns:
        A dict with one entry per zone that has a ``'spatial'`` key. Each
        entry holds ``prob_streams``, ``prob_ditches`` and ``pred_combined``
        (``None`` for a layer the zone lacks). Zones without ``'spatial'``
        are omitted.
    """
    export_layers = ('prob_streams', 'prob_ditches', 'pred_combined')
    return {
        zone_name: {layer: zone_data['spatial'].get(layer) for layer in export_layers}
        for zone_name, zone_data in predictions.items()
        if 'spatial' in zone_data
    }

def calculate_spatial_metrics(predictions, ground_truth):
    """Score binary prediction masks against ground-truth masks.

    Precision, recall and F1 (each with ``zero_division=0``) are computed on
    the flattened masks for streams, for ditches, and for "any water
    feature" (stream or ditch).

    Args:
        predictions: Dict whose ``['spatial']`` entry holds ``pred_streams``
            and ``pred_ditches``.
        ground_truth: Dict with ``'streams'`` and ``'ditches'`` masks.

    Returns:
        ``{'streams': m, 'ditches': m, 'combined': m}`` where each ``m`` is a
        dict with ``'precision'``, ``'recall'`` and ``'f1'``.
    """
    def score(truth_mask, predicted_mask):
        # The sklearn metrics expect 1D label vectors.
        y_true = truth_mask.flatten()
        y_pred = predicted_mask.flatten()
        return {
            'precision': precision_score(y_true, y_pred, zero_division=0),
            'recall': recall_score(y_true, y_pred, zero_division=0),
            'f1': f1_score(y_true, y_pred, zero_division=0),
        }

    spatial = predictions['spatial']
    pred_streams, pred_ditches = spatial['pred_streams'], spatial['pred_ditches']
    gt_streams, gt_ditches = ground_truth['streams'], ground_truth['ditches']

    stream_metrics = score(gt_streams, pred_streams)
    ditch_metrics = score(gt_ditches, pred_ditches)

    # A pixel counts as "water" when either class marks it.
    any_predicted = (pred_streams + pred_ditches > 0).astype(np.int8)
    any_truth = (gt_streams + gt_ditches > 0).astype(np.int8)
    combined_metrics = score(any_truth, any_predicted)

    return {
        'streams': stream_metrics,
        'ditches': ditch_metrics,
        'combined': combined_metrics
    }

def setup_zarr_output(output_path, zone_names=None):
    """Create the zarr store layout used for prediction results.

    The store is opened with ``mode="w"`` and receives descriptive root
    attributes plus the groups ``predictions/raw`` and
    ``predictions/spatial``, with one sub-group of ``spatial`` per zone name.

    Args:
        output_path: Path of the zarr store to create. Its parent directory
            is created when the path names one.
        zone_names: Zone names to pre-create under ``predictions/spatial``;
            ``None`` means no zones.

    Returns:
        Tuple ``(root, predictions_group, raw_group, spatial_group,
        compression_params)`` where ``compression_params`` is a dict holding
        the ``compressor`` keyword for later dataset creation.
    """
    compression_params = dict(
        compressor=zarr.Blosc(cname='zstd', clevel=3, shuffle=zarr.Blosc.SHUFFLE)
    )

    # os.path.dirname() returns '' for a bare store name, and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when one is named.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    root = zarr.open(output_path, mode="w")

    root.attrs['description'] = 'Water feature predictions from Random Forest models'
    root.attrs['creation_date'] = datetime.now().isoformat()
    root.attrs['model_type'] = 'Random Forest with SMOTE'

    predictions_group = root.create_group("predictions")
    predictions_group.attrs['class_labels'] = ['background', 'stream', 'ditch']

    raw_group = predictions_group.create_group("raw")
    spatial_group = predictions_group.create_group("spatial")

    if zone_names is None:
        zone_names = []

    root.attrs['zone_names'] = zone_names

    for zone in zone_names:
        spatial_group.create_group(zone)

    return root, predictions_group, raw_group, spatial_group, compression_params

def process_and_visualize_zone(zarr_file, zone_name, selected_features, output_path, 
                               trained_models, X_test, y_test, test_indices, 
                               zarr_output, compression_params):
    """Placeholder for per-zone processing.

    Currently only logs the zone name; every other argument is ignored.

    Returns:
        None.
    """
    # Stub: no processing is implemented yet, so the result is always None.
    logger.info(f"Processing and visualizing zone {zone_name}")

def process_selected_zones(zarr_file, selected_features, output_path, trained_models, X_test, y_test, test_indices, target_zones=None):
    """Run per-zone processing and collect the results in one Zarr store.

    Creates ``<output_path>/prediction_results_smote.zarr`` (opened with
    ``mode="w"``), writes the test data under ``predictions/raw`` and one group
    per zone under ``predictions/spatial``.

    Args:
        zarr_file: Input Zarr file, forwarded to ``process_and_visualize_zone``.
        selected_features: Forwarded to ``process_and_visualize_zone``.
        output_path: Directory in which the output store is created.
        trained_models: Forwarded to ``process_and_visualize_zone``.
        X_test: Test features (DataFrame or array-like). Stored as
            ``raw/X_test`` when not ``None``.
        y_test: Test labels (Series or array-like). Stored as ``raw/y_test``
            cast to int8.
        test_indices: Optional indices, stored as ``raw/test_indices``.
        target_zones: Zone names to process; defaults to ``zone_1`` ..
            ``zone_10``.

    Returns:
        Path of the Zarr store that was written.
    """
    if target_zones is None:
        target_zones = [f"zone_{i}" for i in range(1, 11)]
    else:
        # Materialise once: the names are iterated and also stored as attributes.
        target_zones = list(target_zones)

    output_zarr_file = os.path.join(output_path, "prediction_results_smote.zarr")

    # Zone boundaries are optional module-level data; resolve them once instead
    # of probing globals() on every iteration.
    known_boundaries = globals().get('zone_boundaries') or {}

    with zarr.open_group(output_zarr_file, mode="w") as root:
        predictions_group = root.create_group("predictions")
        predictions_group.attrs['class_labels'] = ['background', 'stream', 'ditch']

        raw_group = predictions_group.create_group("raw")
        spatial_group = predictions_group.create_group("spatial")
        root.attrs['description'] = 'Water feature predictions from Random Forest models'
        root.attrs['creation_date'] = datetime.now().isoformat()
        root.attrs['model_type'] = 'Random Forest with SMOTE'
        root.attrs['zone_names'] = target_zones

        compression_params = {
            'compressor': zarr.Blosc(cname='zstd', clevel=3, shuffle=zarr.Blosc.SHUFFLE)
        }

        # np.asarray accepts pandas objects and plain arrays alike, so y_test
        # and X_test are handled the same way regardless of container type.
        raw_group.create_dataset("y_test", data=np.asarray(y_test).astype(np.int8),
                                 chunks=True, **compression_params)

        if X_test is not None:
            raw_group.create_dataset("X_test", data=np.asarray(X_test),
                                     chunks=True, **compression_params)

        if test_indices is not None:
            raw_group.create_dataset("test_indices", data=test_indices,
                                     chunks=True, **compression_params)

        all_zone_details = {}
        all_zone_bounds = {}

        for zone_name in target_zones:
            logger.info(f"Processing {zone_name}")

            # A repeated zone name replaces the earlier group rather than failing.
            if zone_name in spatial_group:
                del spatial_group[zone_name]

            zone_group = spatial_group.create_group(zone_name)

            try:
                zone_number = int(zone_name.split('_')[1])
            except (IndexError, ValueError):
                logger.error(f"Invalid zone name: {zone_name}")
                zone_number = None

            zone_metadata = {
                'name': zone_name,
                'number': zone_number,
                'processed_date': datetime.now().isoformat()
            }

            zone_group.attrs.update(zone_metadata)

            all_zone_details[zone_name] = zone_metadata

            predictions = process_and_visualize_zone(
                zarr_file=zarr_file,
                zone_name=zone_name,
                selected_features=selected_features,
                output_path=output_path,
                trained_models=trained_models,
                X_test=X_test,
                y_test=y_test,
                test_indices=test_indices,
                zarr_output=(root, predictions_group, raw_group, spatial_group),
                compression_params=compression_params
            )

            if predictions and 'spatial' in predictions:
                spatial_predictions = predictions['spatial']
                for pred_type in ['pred_streams', 'pred_ditches', 'pred_combined',
                                  'prob_streams', 'prob_ditches', 'prob_combined']:
                    if pred_type in spatial_predictions:
                        # Probabilities are stored as float32, class labels as int8.
                        zone_group.create_dataset(
                            pred_type,
                            data=spatial_predictions[pred_type],
                            dtype=np.float32 if 'prob' in pred_type else np.int8,
                            chunks=True,
                            **compression_params
                        )

                zone_group.attrs['processed_date'] = datetime.now().isoformat()
                zone_group.attrs['zone_name'] = zone_name

                zone_boundary = known_boundaries.get(zone_name, {})
                if zone_boundary:
                    zone_group.attrs['upper_left'] = zone_boundary.get('upper_left')
                    zone_group.attrs['lower_right'] = zone_boundary.get('lower_right')
                    all_zone_bounds[zone_name] = zone_boundary

            logger.info(f"Saved data for {zone_name}")

        predictions_group.attrs['processed_zones'] = target_zones

        if all_zone_bounds:
            predictions_group.attrs['zone_boundaries'] = all_zone_bounds

        root.attrs['zone_details'] = all_zone_details
        logger.info(f"Saved prediction results to Zarr file: {output_zarr_file}")

    return output_zarr_file

def _classify_from_probabilities(prob_streams, prob_ditches):
    """Return a uint8 class map (0=background, 1=stream, 2=ditch) by argmax."""
    grid_height, grid_width = prob_streams.shape
    probs = np.zeros((3, grid_height, grid_width), dtype=np.float32)
    # Background probability is whatever the two water classes leave over.
    probs[0] = 1.0 - (prob_streams + prob_ditches)
    probs[1] = prob_streams
    probs[2] = prob_ditches
    return np.argmax(probs, axis=0).astype(np.uint8)


def _write_zone_rasters(zone_name, data, bounds, grid_shape, output_path):
    """Write the stream/ditch probability and classification GeoTIFFs for one zone."""
    grid_height, grid_width = grid_shape
    ul_x, ul_y = bounds["upper_left"]
    lr_x, lr_y = bounds["lower_right"]

    # from_bounds expects (west, south, east, north, width, height).
    transform = from_bounds(ul_x, lr_y, lr_x, ul_y, grid_width, grid_height)

    # Probability maps are saved as float32.
    save_tif(os.path.join(output_path, f"{zone_name}_stream_prob.tif"),
             data["prob_streams"].astype(np.float32),
             transform, crs, dtype='float32')

    save_tif(os.path.join(output_path, f"{zone_name}_ditch_prob.tif"),
             data["prob_ditches"].astype(np.float32),
             transform, crs, dtype='float32')

    # Prefer a precomputed class map; otherwise derive one from the probabilities.
    if "pred_combined" in data:
        classification = data["pred_combined"].astype(np.uint8)
    else:
        classification = _classify_from_probabilities(data["prob_streams"],
                                                      data["prob_ditches"])

    save_tif(os.path.join(output_path, f"{zone_name}_classification.tif"),
             classification, transform, crs, dtype='uint8')


def export_predictions(predictions, zone_boundaries, zone_name_to_id, output_path):
    """Export predictions to GeoTIFF files.

    For each exported zone three files are written to ``output_path``:
    ``<zone>_stream_prob.tif``, ``<zone>_ditch_prob.tif`` and
    ``<zone>_classification.tif``.

    Args:
        predictions: Dictionary with zone names as keys and prediction data as values,
                    or dictionary with a 'spatial' key containing prediction data.
        zone_boundaries: Dictionary mapping zone names to boundary coordinates
                    (``"upper_left"`` and ``"lower_right"`` pairs).
        zone_name_to_id: Dictionary mapping zone names to numeric IDs.
        output_path: Directory where GeoTIFF files will be saved.

    Returns:
        None. Problems are logged and the affected zone is skipped.
    """

    if not predictions:
        logger.error("No predictions data to export!")
        return

    logger.info(f"Found {len(predictions)} prediction entries to process")

    # Single-zone form: everything lives under the 'spatial' key.
    if 'spatial' in predictions:
        logger.info("Found 'spatial' key in predictions - using spatial prediction data")

        spatial_data = predictions['spatial']
        required_keys = ['prob_streams', 'prob_ditches']
        if not all(key in spatial_data for key in required_keys):
            logger.error(f"Spatial data missing required fields. Available keys: {list(spatial_data.keys())}")
            return

        # Build the reverse lookup from the argument instead of relying on an
        # undefined/global ``zone_id_to_name``.
        zone_id_to_name = {zid: name for name, zid in (zone_name_to_id or {}).items()}

        # Use the zone ID carried in the data when present; an unknown or
        # missing ID falls back to zone_1 below.
        zone_id = spatial_data.get('zone_id', 0)
        zone_name = zone_id_to_name.get(zone_id)

        if zone_name is None:
            logger.warning(f"Invalid zone_id {zone_id} in spatial data, defaulting to zone_1")
            zone_name = "zone_1"

        logger.info(f"Processing spatial predictions for {zone_name}")

        if zone_name not in zone_boundaries:
            logger.error(f"No boundaries found for {zone_name}")
            return

        try:
            grid_shape = grid_height, grid_width = spatial_data["prob_streams"].shape
            logger.info(f"Array shape: {grid_height}x{grid_width}")
        except (AttributeError, TypeError, ValueError) as e:
            # Not an array, or not exactly two-dimensional.
            logger.error(f"Error getting array shape: {str(e)}")
            return

        _write_zone_rasters(zone_name, spatial_data, zone_boundaries[zone_name],
                            grid_shape, output_path)

        logger.info(f"✅ Exported spatial predictions for {zone_name}")
        return

    # Multi-zone form: one entry per zone name.
    logger.info("Processing zone-by-zone predictions")
    success_count = 0

    for zone_name, zone_data in predictions.items():
        if zone_name not in zone_boundaries:
            logger.warning(f"⚠️ No boundaries for zone: {zone_name}")
            continue

        if not isinstance(zone_data, dict):
            logger.warning(f"⚠️ Zone data for {zone_name} is not a dictionary")
            continue

        if "prob_streams" not in zone_data or "prob_ditches" not in zone_data:
            logger.warning(f"⚠️ Missing probability data for {zone_name}")
            logger.info(f"Available keys: {list(zone_data.keys())}")
            continue

        try:
            grid_shape = grid_height, grid_width = zone_data["prob_streams"].shape
            logger.info(f"Processing {zone_name}: array shape = {grid_height}x{grid_width}")
        except (AttributeError, TypeError, ValueError) as e:
            # Not an array, or not exactly two-dimensional.
            logger.error(f"Error getting array shape for {zone_name}: {str(e)}")
            continue

        _write_zone_rasters(zone_name, zone_data, zone_boundaries[zone_name],
                            grid_shape, output_path)

        success_count += 1
        logger.info(f"✅ Exported {zone_name}")

    logger.info(f"Export complete. Successfully processed {success_count} zones.")

# Main execution (notebook convenience): export predictions that an earlier
# cell left in memory. On a fresh run `predictions` does not exist yet, so the
# export is skipped here instead of failing with a NameError; main() performs
# the regular export after the zones have been processed.
if "predictions" in globals():
    logger.info("Starting export process...")
    try:
        export_predictions(predictions, zone_boundaries, zone_name_to_id, output_path)
    except Exception as e:
        # Top-level boundary: record the failure with its traceback and let
        # the rest of the module load.
        logger.exception(f"Error during export: {str(e)}")
    else:
        logger.info("Export process completed successfully")
else:
    logger.info("No in-memory predictions found; skipping standalone export")

def save_predictions_to_zarr(predictions, output_path, zone_boundaries, zone_name_to_id):
    """Save per-zone prediction arrays to a Zarr store.

    Store layout:
        ``zones/<zone_name>/``: ``prob_streams``, ``prob_ditches``,
            ``classification`` and any optional arrays present in the input.
        ``metadata``: ``crs`` and ``created`` attributes.
        ``metadata/boundaries/<zone_name>``: ``upper_left`` / ``lower_right``.

    Args:
        predictions: Dictionary mapping zone names to dictionaries of arrays.
            ``prob_streams`` and ``prob_ditches`` are required per zone;
            ``pred_streams``, ``pred_ditches``, ``pred_combined`` and
            ``prob_combined`` are stored when present.
        output_path: Path of the Zarr store; opened with ``mode='w'``.
        zone_boundaries: Dictionary mapping zone names to dictionaries with
            ``upper_left`` and ``lower_right`` entries. Zones absent from this
            mapping are skipped.
        zone_name_to_id: Dictionary mapping zone names to numeric IDs.

    Returns:
        None. A zone that fails to save is logged and skipped.

    NOTE(review): the optional ``pred_*`` arrays are written as float32 here,
    whereas ``process_selected_zones`` writes them as int8 - confirm which
    dtype downstream readers expect.
    """
    logger = logging.getLogger(__name__)
    logger.info(f"Saving predictions to Zarr store: {output_path}")

    compression_params = {
        'compressor': zarr.Blosc(cname='zstd', clevel=3)
    }
    chunk_shape = (1250, 1250)
    optional_keys = ("pred_streams", "pred_ditches", "pred_combined", "prob_combined")

    root = zarr.open(output_path, mode='w')
    zones_group = root.create_group('zones')

    metadata = root.create_group('metadata')
    metadata.attrs['crs'] = "EPSG:3067"
    metadata.attrs['created'] = str(np.datetime64('now'))

    # Zone boundaries are kept as attributes of one small group per zone.
    boundaries_group = metadata.create_group('boundaries')
    for zone_name, bounds in zone_boundaries.items():
        zone_bounds = boundaries_group.create_group(zone_name)
        zone_bounds.attrs['upper_left'] = bounds['upper_left']
        zone_bounds.attrs['lower_right'] = bounds['lower_right']

    logger.info(f"Saving zone-by-zone predictions to Zarr")
    success_count = 0

    for zone_name, zone_data in predictions.items():
        # Skip non-zone entries or zones without known boundaries.
        if zone_name not in zone_boundaries:
            continue

        if not isinstance(zone_data, dict):
            logger.warning(f"⚠️ Zone data for {zone_name} is not a dictionary")
            continue

        if "prob_streams" not in zone_data or "prob_ditches" not in zone_data:
            logger.warning(f"⚠️ Missing probability data for {zone_name}")
            continue

        try:
            zone_group = zones_group.create_group(zone_name)
            zone_group.attrs['zone_id'] = zone_name_to_id.get(zone_name)

            # Required probability maps.
            for key in ("prob_streams", "prob_ditches"):
                zone_group.create_dataset(key,
                                          data=zone_data[key].astype(np.float32),
                                          chunks=chunk_shape,
                                          **compression_params)

            # Optional arrays, stored only when the zone provides them.
            for key in optional_keys:
                if key in zone_data:
                    zone_group.create_dataset(key,
                                              data=zone_data[key],
                                              dtype=np.float32,
                                              chunks=chunk_shape,
                                              **compression_params)

            # Use the supplied class map, or derive one by argmax over
            # (background, stream, ditch) probabilities.
            if "pred_combined" in zone_data:
                classification = zone_data["pred_combined"].astype(np.uint8)
            else:
                grid_height, grid_width = zone_data["prob_streams"].shape
                probs = np.zeros((3, grid_height, grid_width), dtype=np.float32)
                probs[0] = 1.0 - (zone_data["prob_streams"] + zone_data["prob_ditches"])
                probs[1] = zone_data["prob_streams"]
                probs[2] = zone_data["prob_ditches"]
                classification = np.argmax(probs, axis=0).astype(np.uint8)

            zone_group.create_dataset('classification',
                                      data=classification,
                                      chunks=chunk_shape,
                                      **compression_params)

            success_count += 1
            logger.info(f"✅ Saved {zone_name} to Zarr")

        except Exception as e:
            # Best effort per zone: log with traceback and continue with the rest.
            logger.exception(f"Error saving {zone_name} to Zarr: {str(e)}")

    # Report the outcome instead of repeating the "Saving..." message.
    logger.info(f"Saved {success_count} zone(s) to Zarr store: {output_path}")

def main():
    """Predict every configured zone, then export the results.

    Steps: create the output/model directories, set up the joblib cache, fix
    the input Zarr file's features, run ``predict_for_new_zone`` for each entry
    of ``zones_to_process``, and finally export the collected predictions to
    GeoTIFF and to ``<output_path>/predictions.zarr``.

    NOTE(review): the GeoTIFF export receives the untransformed predictions
    while the Zarr export receives the output of ``transform_predictions`` -
    confirm that this asymmetry is intended.
    """
    import traceback

    # Directories and memory cache.
    for directory in (output_path, models_dir):
        os.makedirs(directory, exist_ok=True)
    memory = setup_joblib_cache(cache_dir)

    # Fix the zarr file features if needed.
    fix_zarr_file_features(zarr_file)

    zone_results = {}
    processed_count = 0

    for zone_name in zones_to_process:
        logger.info(f"Processing zone: {zone_name}")

        zone_predictions = predict_for_new_zone(
            zarr_file=zarr_file,
            zone_name=zone_name,
            models_dir=models_dir,
            selected_features=selected_features,
            output_path=output_path
        )
        print_memory_usage()

        if not zone_predictions:
            logger.warning(f"Failed to process zone: {zone_name}")
            continue

        zone_results[zone_name] = zone_predictions
        logger.info(f"Successfully processed zone: {zone_name}")
        processed_count += 1

    # Nothing usable came out of the loop: report and stop.
    if not (processed_count > 0 and zone_results):
        logger.warning("No zones were successfully processed. Nothing to export.")
        logger.info("Processing complete.")
        return

    logger.info(f"Exporting {len(zone_results)} zone predictions...")

    transformed_predictions = transform_predictions(zone_results)

    # GeoTIFF export.
    export_predictions(zone_results, zone_boundaries, zone_name_to_id, output_path)

    # Zarr export, using the transformed data.
    zarr_output_path = os.path.join(output_path, "predictions.zarr")
    try:
        save_predictions_to_zarr(transformed_predictions, zarr_output_path, zone_boundaries, zone_name_to_id)
        logger.info(f"Successfully exported {processed_count} zones.")
    except Exception as e:
        logger.error(f"Error saving to Zarr: {str(e)}")
        logger.error(traceback.format_exc())
        logger.warning("Export process encountered errors.")

    logger.info("Processing complete.")


if __name__ == "__main__":
    main()

2025-04-12 22:02:11,240 - __main__ - INFO - Starting export process...
2025-04-12 22:02:11,240 - __main__ - ERROR - Error during export: name 'predictions' is not defined
2025-04-12 22:02:11,240 - __main__ - ERROR - Traceback (most recent call last):
  File "C:\Users\OMISTAJA\AppData\Local\Temp\ipykernel_11712\3165423215.py", line 977, in <module>
    export_predictions(predictions, zone_boundaries, zone_name_to_id, output_path)
NameError: name 'predictions' is not defined

2025-04-12 22:02:11,240 - __main__ - INFO - Exporting predictions to GeoTIFF in ../02_Results/1204/
2025-04-12 22:02:11,240 - __main__ - INFO - Fixing zarr file by adding missing features: zones_data_2.zarr
2025-04-12 22:02:11,341 - __main__ - INFO - Finished checking/fixing features in zones_data_2.zarr
2025-04-12 22:02:11,342 - __main__ - INFO - Processing zone: zone_1
2025-04-12 22:02:11,342 - __main__ - INFO - Predict for new zone zone_1 using models from ../02_Results/models
2025-04-12 22:02:11,342 - __main__ -