In [None]:
from IPython.display import display

from pathlib import Path

from training_raster_clipper.core.models import TrainingConfiguration, TrainingFunctions

from training_raster_clipper.core.logging import log_info

from training_raster_clipper.core.visualization import (
    plot_rgb_data_array,
    plot_array,
    plot_geodataframe,
)

import matplotlib.pyplot as plt

from pathlib import Path

import geopandas as gpd
import numpy as np
import rioxarray
import xarray as xr
from affine import Affine
from geopandas.geodataframe import GeoDataFrame
from rasterio.features import rasterize
from sklearn.ensemble import RandomForestClassifier

from training_raster_clipper.custom_types import (
    BandNameType,
    ClassificationResult,
    ClassifiedSamples,
    FeatureClassNameToId,
    PolygonMask,
    ResolutionType,
)

from geopandas import read_file
from rioxarray import open_rasterio
from rasterio.features import rasterize
import numpy as np
import pandas as pd

# Ask xarray to keep DataArray attributes (attrs) on the results of operations such as
# the arithmetic and `where` calls used in the cells below.
xr.set_options(keep_attrs=True)

In [None]:
# Input data lives outside the repository: adjust PROJECT_DIR for your machine.
PROJECT_DIR = Path("D:/Profils/iatraoui/Desktop/Project")
# Outputs go under `generated/` in the parent of the working directory (presumably the
# repository root when the notebook is run from a subfolder — adjust if not).
REPOSITORY_DIR = Path(".").resolve().parent

raster_input_path = (
    PROJECT_DIR / "S2A_MSIL2A_20230823T104631_N0509_R051_T31TCJ_20230823T170355.SAFE"
)
# Build the absolute path directly. Joining an absolute path onto the working directory is a
# no-op on Windows and produces a meaningless relative path on POSIX.
polygons_input_path = PROJECT_DIR / "QGIS" / "Polygons2.geojson"

config = TrainingConfiguration(
    verbose=True,
    show_plots=True,
    resolution=60,
    band_names=("B04", "B03", "B02", "B8A"),
    raster_input_path=raster_input_path,
    polygons_input_path=polygons_input_path,
    csv_output_path=REPOSITORY_DIR / "generated" / "classified_points.csv",
    raster_output_path=REPOSITORY_DIR / "generated" / "sklearn_raster.tiff",
    implementation_name="eschalk",
)
config

In [None]:
# Expose the configuration fields as notebook-level names used by the cells below.
verbose, show_plots = config.verbose, config.show_plots
resolution, band_names = config.resolution, config.band_names
(
    raster_input_path,
    polygons_input_path,
    csv_output_path,
    raster_output_path,
) = (
    config.raster_input_path,
    config.polygons_input_path,
    config.csv_output_path,
    config.raster_output_path,
)

### (1) Load a GeoJSON file with `geopandas`

In [None]:
def load_feature_polygons(input_path: Path, crs: int = 32631) -> GeoDataFrame:
    """Read the training polygons and reproject them to `crs`.

    Args:
        input_path (Path): Vector file readable by `geopandas.read_file` (e.g. GeoJSON).
        crs (int): Target CRS passed to `GeoDataFrame.to_crs`, as an EPSG code. The default,
            32631, is presumably chosen to match the Sentinel-2 tile used in this notebook
            (T31TCJ) — the polygons must share the raster's CRS to be rasterized onto it.

    Returns:
        GeoDataFrame: The polygons, reprojected to `crs`.
    """
    return read_file(input_path).to_crs(crs)

In [None]:
polygons = load_feature_polygons(
    input_path=polygons_input_path,
)
if verbose:
    log_info(polygons, "polygons")
if show_plots:
    plot_geodataframe(polygons, f"{load_feature_polygons.__name__}")

### (2) Load a Sentinel-2 raster with `rioxarray`

In [None]:
def _find_band_file(
    sentinel_product_location: Path,
    resolution: ResolutionType,
    band_name: BandNameType,
) -> Path:
    """Return the image file of one band at one resolution inside a .SAFE folder.

    Raises:
        FileNotFoundError: If no file in the product matches the band/resolution pattern.
    """
    pattern = f"GRANULE/*/IMG_DATA/R{resolution}m/*_{band_name}_*"
    # Sorted so the choice is deterministic if the pattern ever matches several files.
    matches = sorted(sentinel_product_location.glob(pattern))
    if not matches:
        raise FileNotFoundError(
            f"No {band_name} image at {resolution} m in {sentinel_product_location} "
            f"(glob pattern: {pattern})"
        )
    return matches[0]


def load_sentinel_data(
    sentinel_product_location: Path,
    resolution: ResolutionType,
    band_names: tuple[BandNameType, ...],
) -> xr.DataArray:
    """Load selected bands of a Sentinel-2 product and convert them to reflectance.

    Example input path: `S2A_MSIL2A_20221116T105321_N0400_R051_T31TCJ_20221116T170958.SAFE`

    Args:
        sentinel_product_location (Path): Location of the .SAFE folder containing a Sentinel-2 product.
        resolution (ResolutionType): Resolution in metres; selects the `R{resolution}m` image folder.
        band_names (tuple[BandNameType, ...]): Bands to load (e.g. "B04"); the `band` coordinate
            of the result lists them in this order.

    Returns:
        xr.DataArray: One layer per requested band along the `band` dimension. Pixels whose raw
        value is 0 are NaN; other pixels hold (raw - 1000) / 10000.

    Raises:
        FileNotFoundError: If a requested band has no image file at the requested resolution.
    """
    band_paths = {
        band_name: _find_band_file(sentinel_product_location, resolution, band_name)
        for band_name in band_names
    }
    # Relabel the `band` coordinate with the band name so bands can be selected by name.
    data_arrays_list = [
        open_rasterio(band_paths[band_name]).assign_coords({"band": [band_name]})
        for band_name in band_names
    ]
    rasters_data_array = xr.concat(data_arrays_list, dim="band")
    # Raw value 0 is treated as "no data" and replaced by NaN.
    rasters_data_array = rasters_data_array.where(
        rasters_data_array != 0, np.float32(np.nan)
    )
    # Convert raw values to reflectance: (raw + offset) / quantification value.
    # NOTE(review): the -1000 offset presumably matches recent processing baselines (e.g. the
    # N0509 in the product name); confirm against the product metadata for older products.
    radio_add_offset = -1000
    quantification_value = 10000
    return (rasters_data_array + radio_add_offset) / quantification_value

In [None]:
rasters = load_sentinel_data(
    sentinel_product_location=raster_input_path,
    resolution=resolution,
    band_names=band_names,
)
if verbose:
    log_info(rasters, "rasters")
if show_plots:
    plot_rgb_data_array(rasters, f"{load_sentinel_data.__name__}")

### (3) Rasterize the polygons

In [None]:
def rasterize_geojson(
    data_array: xr.DataArray,
    training_classes: GeoDataFrame,
) -> tuple[PolygonMask, FeatureClassNameToId]:
    """Burns a set of classified polygons into a raster aligned with `data_array`.

    See https://gis.stackexchange.com/questions/316626/rasterio-features-rasterize

    Each polygon is burnt with the integer id of its `class` value, taken from the fixed
    mapping {"WATER": 1, "FOREST": 2, "FARM": 3}. Pixels covered by no polygon are 0.

    Args:
        data_array (xr.DataArray): The Sentinel raster; the shape of one band and its affine
            transform define the output grid.
        training_classes (GeoDataFrame): Polygons to burn, with `geometry` and `class` columns.
            No reprojection is done here, so they are assumed to be in the raster's CRS.

    Returns:
        tuple[PolygonMask, FeatureClassNameToId]: The uint8 mask (same shape as one band of
        `data_array`) and the class-name-to-id mapping used to fill it.

    Raises:
        ValueError: If a `class` value is not a key of the mapping.
    """
    mapping = {"WATER": 1, "FOREST": 2, "FARM": 3}

    out_shape = data_array.isel(band=0, drop=True).shape
    # Affine transform of the raster grid, as exposed by the rioxarray accessor.
    transform = data_array.rio.transform()

    class_ids = training_classes["class"].map(mapping)
    # Unmapped classes would become NaN burn values; fail loudly instead.
    unknown_classes = training_classes.loc[class_ids.isna(), "class"].unique()
    if len(unknown_classes) > 0:
        raise ValueError(
            f"Unknown training classes {list(unknown_classes)}; "
            f"expected one of {list(mapping)}"
        )
    shapes = list(zip(training_classes["geometry"], class_ids.astype("int64")))

    burnt_polygons = rasterize(
        shapes, out_shape=out_shape, transform=transform, dtype=np.uint8
    )

    return (burnt_polygons, mapping)

In [None]:
burnt_polygons, mapping = rasterize_geojson(
    data_array=rasters,
    training_classes=polygons,
)
if verbose:
    log_info(burnt_polygons, "burnt_polygons")
    log_info(mapping, "mapping")
if show_plots:
    plot_array(burnt_polygons, f"{rasterize_geojson.__name__}")

### (4) Intersect the Sentinel-2 raster with polygons

In [None]:
def produce_clips(
    data_array: xr.DataArray, burnt_polygons: PolygonMask, mapping: FeatureClassNameToId
) -> ClassifiedSamples:
    """Extract the band values of every pixel covered by a classified polygon.

    Args:
        data_array (xr.DataArray): Raster with a `band` dimension and `y`/`x` dimensions.
        burnt_polygons (PolygonMask): Class-id mask with the raster's (y, x) shape;
            0 means "no polygon".
        mapping (FeatureClassNameToId): Class-name-to-id mapping. Not used by this function.

    Returns:
        ClassifiedSamples: Dataset over a stacked `z` = (y, x) dimension, restricted to pixels
        whose class id is non-zero, with variables `reflectance` (band values of each pixel)
        and `feature_id` (int64 class id of each pixel).

    Raises:
        ValueError: If the mask shape differs from the raster's (y, x) shape.
    """
    mask = np.asarray(burnt_polygons)
    grid_shape = (data_array.sizes["y"], data_array.sizes["x"])
    # A transposed mask of the same size would otherwise misalign silently.
    if mask.shape != grid_shape:
        raise ValueError(
            f"Mask shape {mask.shape} does not match the raster (y, x) shape {grid_shape}"
        )

    # `stack` orders z with y as the outer and x as the inner level, which matches a
    # row-major flatten of the (y, x) mask.
    stacked = data_array.stack(z=("y", "x"))
    feature_ids = mask.reshape(-1).astype("int64")

    samples = xr.Dataset(
        {
            "reflectance": stacked,
            "feature_id": xr.DataArray(feature_ids, dims="z"),
        }
    )

    # Keep only pixels covered by a polygon.
    return samples.sel(z=samples["feature_id"] != 0)

In [None]:
classified_rgb_rows = produce_clips(
    data_array=rasters,
    burnt_polygons=burnt_polygons,
    mapping=mapping,
)
if verbose:
    log_info(classified_rgb_rows, "classified_rgb_rows")

### (5) Persist the intersection to a CSV

In [None]:
def persist_to_csv(
    classified_rgb_rows: ClassifiedSamples,
    csv_output_path: Path,
) -> None:
    """Write the classified samples to a CSV file at `csv_output_path`.

    The file has one column per band (named after the `band` coordinate values, in that
    order), then a `feature_id` column, plus the DataFrame index. One row per sample.

    Args:
        classified_rgb_rows (ClassifiedSamples): Dataset with a `reflectance` variable over
            `band` and a `feature_id` variable.
        csv_output_path (Path): Destination file; its parent directory is created if missing.
    """
    columns = {
        band_name: classified_rgb_rows.sel(band=band_name)["reflectance"].values
        for band_name in classified_rgb_rows.coords["band"].values
    }
    columns["feature_id"] = classified_rgb_rows.isel(band=0)["feature_id"].values
    samples_dataframe = pd.DataFrame.from_dict(columns, orient="columns")

    output_path = Path(csv_output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    samples_dataframe.to_csv(output_path)

In [None]:
persist_to_csv(
    classified_rgb_rows=classified_rgb_rows,
    csv_output_path=csv_output_path,
)
log_info(f"Written CSV output {csv_output_path}")

### (6) Train a machine learning model

In [None]:
def classify_sentinel_data(
    rasters: xr.DataArray, classified_rgb_rows: ClassifiedSamples
) -> ClassificationResult:
    """Train a random forest on the labelled samples and classify every pixel of `rasters`.

    Args:
        rasters (xr.DataArray): Raster with `band`, `y` and `x` dimensions. Its bands must be
            the same, in the same order, as those of `classified_rgb_rows`.
        classified_rgb_rows (ClassifiedSamples): Training samples with a `reflectance` variable
            over (`band`, `z`) and an integer `feature_id` over `z`.

    Returns:
        ClassificationResult: uint8 (y, x) array on the grid of `rasters`, holding the
        predicted class id of each pixel. Training samples with a NaN in any band are left out
        of the fit, and pixels with a NaN in any band get 0 (the "no polygon" id of the mask).
    """
    rf_model = RandomForestClassifier(
        random_state=0,
        n_jobs=-1,
        n_estimators=10,
        bootstrap=False,
        class_weight="balanced",
    )

    # Samples as rows, bands as columns — transposed by name rather than with a positional `.T`.
    training_data = classified_rgb_rows["reflectance"].transpose("z", "band").values
    training_labels = classified_rgb_rows["feature_id"].values
    valid_training = np.isfinite(training_data).all(axis=1)
    rf_model.fit(training_data[valid_training], training_labels[valid_training])

    # Row-major (y outer, x inner) pixel table, one column per band.
    pixels = rasters.stack(z=("y", "x")).transpose("z", "band").values
    valid_pixels = np.isfinite(pixels).all(axis=1)
    predictions = np.zeros(pixels.shape[0], dtype=np.uint8)
    if valid_pixels.any():
        predictions[valid_pixels] = rf_model.predict(pixels[valid_pixels])

    # Build the output on the original (un-stacked) grid so coordinates keep their order,
    # instead of relying on a stack/unstack round-trip.
    template = rasters.isel(band=0).transpose("y", "x")
    grid = predictions.reshape(template.sizes["y"], template.sizes["x"])
    return template.copy(data=grid)

In [None]:
classification_result = classify_sentinel_data(
    rasters=rasters,
    classified_rgb_rows=classified_rgb_rows,
)
if verbose:
    log_info(classification_result, "classification_result")
if show_plots:
    plot_array(classification_result, f"{classify_sentinel_data.__name__}")

### (7) Export the classification raster result

In [None]:
def persist_classification_to_raster(
    raster_output_path: Path, classification_result: ClassificationResult
) -> None:
    """Write the classification raster to `raster_output_path` with `rio.to_raster`.

    Args:
        raster_output_path (Path): Destination file; its parent directory is created if missing.
        classification_result (ClassificationResult): Georeferenced array with a `rio` accessor.
    """
    output_path = Path(raster_output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    classification_result.rio.to_raster(output_path)

In [None]:
persist_classification_to_raster(
    raster_output_path=raster_output_path,
    classification_result=classification_result,
)
log_info(f"Written Classified Raster to {raster_output_path}")

# --

log_info("Congratulations, you reached the end of the tutorial!")