## Pre-processing the chosen geotiffs before passing to the stratification notebook

Checklist:
- all incoming data using same projection (and if not, reproject)
- align the spatial resolutions
- data cleaning - check for missing values
- normalise/ scale pixel values
- create data cube

some resources to check later:
https://discourse.pangeo.io/t/advice-for-scalable-raster-vector-extraction/4129

In [32]:
%pip install pydantic -q

Note: you may need to restart the kernel to use updated packages.


In [33]:
import os

import numpy as np
import xarray as xr
import geopandas as gpd

import rasterio
from rasterio.warp import calculate_default_transform, reproject
from rasterio.enums import Resampling
import rioxarray
import rio_cogeo

from pydantic import BaseModel, field_validator
from typing import List

import logging

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    handlers=[
                        logging.FileHandler("debug.log"),
                        logging.StreamHandler()
                    ])

In [34]:
class ResamplingMethods:
    NEAREST = "nearest"
    BILINEAR = "bilinear"
    CUBIC = "cubic"

In [35]:
def load_rasters(raster_dir):
    data_arrays = []
    for raster_file in os.listdir(raster_dir):
        if raster_file.endswith('.tiff') or raster_file.endswith('.tif'):
            raster_path = os.path.join(raster_dir, raster_file)
            # Open the raster file as an xarray DataArray with rioxarray
            data_array = rioxarray.open_rasterio(raster_path)
            data_arrays.append(data_array)
    return data_arrays

In [36]:
# NOTE: This function assumes all incoming rasters only have one band. If multi-band rasters are included, it will need to be modified.

def projection_check(data_arrays, target_crs):
    reprojected_data_arrays = []
    for da in data_arrays:
        if da.rio.crs != target_crs:
            # Reproject the data array to the target CRS
            reprojected_da = da.rio.reproject(target_crs)
            reprojected_data_arrays.append(reprojected_da)
        else:
            reprojected_data_arrays.append(da)
    return reprojected_data_arrays
        

In [37]:
def convert_dtypes(data_arrays):
    converted_data_arrays = []
    for da in data_arrays:
        if da.dtype == "uint16":
            # Assume the original NoData value is known, set it as such or detect it
            original_nodata = da.rio.nodata
            input_raster = da.astype("float32")
            # Replace original NoData with NaN in float32
            if original_nodata is not None:
                input_raster = input_raster.where(
                    input_raster != original_nodata, np.nan
                )
            
        input_raster.rio.write_nodata(np.nan, inplace=True)
        converted_data_arrays.append(input_raster)
    return converted_data_arrays

In [38]:
def check_consistency(data_arrays, dst_crs):
    formatted_crs = f"EPSG:{dst_crs}"
    # Check for inconsistent data types
    dtypes = [da.dtype for da in data_arrays]
    unique_dtypes = set(dtypes)
    if len(unique_dtypes) > 1:
        logging.error("Inconsistent data types found: {}".format(unique_dtypes))
    else:
        logging.info("All data arrays have consistent data types.")

    # Convert from uint16 to float32 if needed:
    if "uint16" in unique_dtypes:
        logging.info("Converting uint16 data arrays to float32.")
        data_arrays = convert_dtypes(data_arrays)
    else:
        logging.info("No uint16 data arrays found.")

    # Check for inconsistent CRS
    crs_set = {da.rio.crs for da in data_arrays if da.rio.crs is not None}
    if len(crs_set) > 1:
        logging.error("Inconsistent CRS found: {}".format(crs_set))
        logging.info("Attempting to convert all data arrays to the same data type.")
        return projection_check(data_arrays, formatted_crs)
    else:
        logging.info("All data arrays have a consistent CRS.")

    if crs_set != {formatted_crs}:
        logging.info("Converting all data arrays to the target CRS.")
        return projection_check(data_arrays, formatted_crs)

    # Check for missing values in any of the DataArrays
    missing_values_found = False
    for da in data_arrays:
        if da.isnull().any():
            logging.error("Missing values found in one of the DataArrays.")
            missing_values_found = True
            break
    if not missing_values_found:
        logging.info("No missing values found in any of the data arrays.")

In [39]:
def find_finest_resolution(data_arrays) -> float:
    finest_x_res = float('inf')  # Start with a large number as the finest resolution
    finest_y_res = float('inf')

    for da in data_arrays:
            current_x_res = abs(da.rio.resolution()[0])
            current_y_res = abs(da.rio.resolution()[1])
            # Update the finest resolution if the current one is smaller
            if current_x_res < finest_x_res and current_y_res < finest_y_res:
                finest_width = current_x_res
                finest_height = current_y_res
    return finest_x_res, finest_y_res

In [40]:
def resample_rasters_to_finest_pixel(data_arrays):
    finest_x_res, finest_y_res = find_finest_resolution(data_arrays) #find the smallest pixels in the input arrays to then match the rest to

    print(f"finest res: {finest_x_res}, {finest_y_res}")
    resampled_data_arrays = []

    for da in data_arrays:

        current_x_res = abs(da.rio.resolution()[0])
        current_y_res = abs(da.rio.resolution()[1])


        scale_factor_x = current_x_res / finest_x_res
        scale_factor_y = current_y_res / finest_y_res 

        print(f"scale_factor_x: {scale_factor_x}")
        print(f"scale_factor_y: {scale_factor_y}")

        new_width = int(da.rio.width / scale_factor_x)
        new_height = int(da.rio.height / scale_factor_y)

        print(f"Resampled width pixels: {new_width}")
        print(f"Resampled height pixels: {new_height}")
        
        resampled_da = da.rio.reproject(
            da.rio.crs, 
            shape=(new_height, new_width),
            resampling = Resampling.bilinear)
        resampled_data_arrays.append(resampled_da)

    return resampled_data_arrays

In [41]:
target_crs = 3857

input_raster_dir = "/workspace/notebooks/sandbox/data/stratification/input-rasters"
output_raster_dir = "/workspace/notebooks/sandbox/data/stratification/processed-rasters"

input_rasters = load_rasters(input_raster_dir)


In [42]:
# check that the loaded rasters have the same CRS and dtype and there are no missing data:

checked_rasters = check_consistency(input_rasters, target_crs)


2024-06-24 03:06:54,502 - INFO - All data arrays have consistent data types.
2024-06-24 03:06:54,505 - INFO - No uint16 data arrays found.
2024-06-24 03:06:54,507 - INFO - All data arrays have a consistent CRS.
2024-06-24 03:06:54,509 - INFO - Converting all data arrays to the target CRS.


Selecting Resampling Method

- Nearest Neighbor: Fast and suitable for categorical data.
- Bilinear: Good for continuous data where interpolation between values can be meaningful.
- Cubic: Better for continuous data where a smoother output is desired.

In [43]:
resampled_data_arrays = resample_rasters_to_finest_pixel(checked_rasters)

finest res: inf, inf
scale_factor_x: 0.0
scale_factor_y: 0.0


ZeroDivisionError: float division by zero

In [None]:
for da in resampled_data_arrays:
    print(da.rio.resolution())

In [None]:
for da in checked_rasters:
    print(da.rio.shape)
for da in resampled_data_arrays:
    print(da.rio.shape)