## Pre-processing the chosen geotiffs before passing to the stratification notebook

Checklist:
- check that all incoming data use the same projection (and if not, reproject)
- align the spatial resolutions
- data cleaning - check for missing values
- normalise/scale pixel values
- create data cube

Some resources to check later:
https://discourse.pangeo.io/t/advice-for-scalable-raster-vector-extraction/4129

In [None]:
import os

import numpy as np
import geopandas as gpd

import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling


In [None]:
def load_rasters(raster_dir):
    """Open every GeoTIFF found directly inside ``raster_dir``.

    Files are matched on a ``.tif`` or ``.tiff`` suffix (case-insensitive)
    and opened in sorted filename order so the result is reproducible
    between runs and platforms. Sub-directories are not searched.

    Args:
        raster_dir: Path of the directory holding the GeoTIFF files.

    Returns:
        A list of open rasterio dataset handles, one per matching file.
        The caller is responsible for closing them.

    Raises:
        OSError: If ``raster_dir`` cannot be listed.
        Exception: Whatever ``rasterio.open`` raises for an unreadable
            file; datasets opened before the failure are closed first.
    """
    tiff_suffixes = ('.tif', '.tiff')
    # os.listdir gives no ordering guarantee, so sort for a stable result.
    raster_paths = [
        os.path.join(raster_dir, file_name)
        for file_name in sorted(os.listdir(raster_dir))
        if file_name.lower().endswith(tiff_suffixes)
    ]

    rasters = []
    try:
        for raster_path in raster_paths:
            rasters.append(rasterio.open(raster_path))
    except Exception:
        # Do not leak the handles that were opened before the failure.
        for opened in rasters:
            opened.close()
        raise
    return rasters

In [None]:
# NOTE: This function assumes all incoming rasters have only one band. If multi-band rasters are included, it will need to be modified.

def projection_check(rasters, target_crs):
    """Return band 1 of each raster, reprojected to ``target_crs`` if needed.

    Rasters whose CRS already equals ``target_crs`` are read as-is. All
    others are warped onto the default grid computed by
    ``calculate_default_transform`` using nearest-neighbour resampling.
    Only band 1 is processed, so multi-band inputs lose their other bands.

    Args:
        rasters: Iterable of open rasterio datasets.
        target_crs: The CRS every output should be expressed in.

    Returns:
        A list of ``(array, metadata)`` tuples in input order. For
        reprojected rasters ``array`` is a 2-D ``float32`` array and
        ``metadata`` is a copy of the source metadata updated with the new
        CRS, transform, width, height, dtype and band count. For rasters
        already in ``target_crs`` the array is whatever ``read(1)`` returns
        and the metadata is the dataset's own ``meta``.
    """
    reprojected_rasters = []
    for raster in rasters:
        if raster.crs == target_crs:
            # Already in the requested CRS: pass the band through untouched.
            reprojected_rasters.append((raster.read(1), raster.meta))
            continue

        transform, width, height = calculate_default_transform(
            raster.crs, target_crs, raster.width, raster.height, *raster.bounds)

        # Destination buffer for the warped band.
        # NOTE(review): cells not covered by the source keep whatever fill
        # value rasterio's reproject applies by default - confirm this is
        # acceptable for rasters that declare no nodata value.
        new_array = np.zeros((height, width), dtype=rasterio.float32)

        reproject(
            source=rasterio.band(raster, 1),
            destination=new_array,
            src_transform=raster.transform,
            src_crs=raster.crs,
            dst_transform=transform,
            dst_crs=target_crs,
            resampling=Resampling.nearest,
        )

        new_raster_metadata = raster.meta.copy()
        new_raster_metadata.update({
            'crs': target_crs,
            'transform': transform,
            'width': width,
            'height': height,
            # Keep the metadata in step with the array actually returned:
            # a single float32 band, whatever the source dtype/count was.
            'dtype': rasterio.float32,
            'count': 1,
        })

        reprojected_rasters.append((new_array, new_raster_metadata))
    return reprojected_rasters
        