In [1]:
import logging
from pathlib import Path

import geopandas as gpd
import numpy as np
import rioxarray
import xarray as xr
from affine import Affine
from geopandas.geodataframe import GeoDataFrame
from rasterio.features import rasterize
from rasterio.transform import Affine
from sklearn.ensemble import RandomForestClassifier

from training_raster_clipper.core.logging import log_info
from training_raster_clipper.core.models import TrainingConfiguration
from training_raster_clipper.core.visualization import (
    plot_array,
    plot_geodataframe,
    plot_rgb_data_array,
)
from training_raster_clipper.custom_types import (
    BandNameType,
    ClassificationResult,
    ClassifiedSamples,
    FeatureClassNameToId,
    PolygonMask,
    ResolutionType,
)

In [None]:
# See https://stackoverflow.com/questions/18786912/get-output-from-the-logging-module-in-ipython-notebook
# Attach a stderr handler to the root logger (no-op if one is already configured)
# and let INFO records through, so INFO-level log output is displayed in the notebook.
# This replaces a `logging.info("test")` call that only served to trigger the
# implicit `basicConfig()` and left a stray "test" line in the output.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [None]:
# All paths are resolved against the project root, assumed to be the parent of the
# notebook's working directory (the output paths already relied on this).
# Replaces machine-specific absolute paths (C:/Users/...) so the notebook runs anywhere.
PROJECT_ROOT = Path(".").resolve().parent
SOLUTION_DIR = PROJECT_ROOT / "resources" / "solution"
GENERATED_DIR = PROJECT_ROOT / "generated"

raster_input_path = SOLUTION_DIR / "example_sentinel_files" / "SENTINEL.SAFE"
polygons_input_path = SOLUTION_DIR / "polygons.geojson"

assert raster_input_path.is_dir(), f"Sentinel-2 product folder not found: {raster_input_path}"
assert polygons_input_path.is_file(), f"Polygons file not found: {polygons_input_path}"

# Outputs are written at the end of the notebook; make sure their folder exists.
GENERATED_DIR.mkdir(parents=True, exist_ok=True)

config = TrainingConfiguration(
    verbose=True,
    show_plots=True,
    resolution=60,
    band_names=("B04", "B03", "B02", "B8A"),  # red, green, blue, narrow NIR
    raster_input_path=raster_input_path,
    polygons_input_path=polygons_input_path,
    csv_output_path=GENERATED_DIR / "classified_points.csv",
    raster_output_path=GENERATED_DIR / "sklearn_raster.tiff",
    implementation_name="abdellah",
)
config


In [None]:
# Expose the configuration fields as plain module-level names for the cells below.
verbose, show_plots = config.verbose, config.show_plots
resolution, band_names = config.resolution, config.band_names

raster_input_path, polygons_input_path = (
    config.raster_input_path,
    config.polygons_input_path,
)
csv_output_path, raster_output_path = (
    config.csv_output_path,
    config.raster_output_path,
)


### (1) Load a GeoJSON file with `geopandas`


In [None]:
def load_feature_polygons(input_path: Path, epsg: int = 32631) -> GeoDataFrame:
    """Load classified training polygons and reproject them.

    Args:
        input_path (Path): Vector file readable by `geopandas.read_file` (e.g. GeoJSON).
        epsg (int, optional): EPSG code of the target CRS. Defaults to 32631
            (WGS 84 / UTM zone 31N), the zone of the example tile T31TCJ. It should
            match the CRS of the Sentinel-2 raster the polygons are later burnt into.

    Returns:
        GeoDataFrame: The polygons, reprojected to ``EPSG:{epsg}``.
    """
    return gpd.read_file(input_path).to_crs(epsg=epsg)


In [None]:

polygons = load_feature_polygons(polygons_input_path)

if verbose:
    log_info(polygons, "polygons")

if show_plots:
    # The function name doubles as the plot title.
    plot_geodataframe(polygons, load_feature_polygons.__name__)


### (2) Load a Sentinel-2 raster with `rioxarray`


In [None]:
def _find_band_file(
    sentinel_product_location: Path,
    resolution: ResolutionType,
    band_name: BandNameType,
) -> Path:
    """Locate the image file of one band at one resolution inside a .SAFE folder.

    Raises:
        FileNotFoundError: If no file matches the band/resolution pattern.
    """
    pattern = f"GRANULE/*/IMG_DATA/R{resolution}m/*_{band_name}_*"
    # glob order is filesystem-dependent: sort so the pick is deterministic.
    candidates = sorted(sentinel_product_location.glob(pattern))
    if not candidates:
        raise FileNotFoundError(
            f"No {band_name} file at {resolution} m in {sentinel_product_location} "
            f"(pattern: {pattern})"
        )
    return candidates[0]


def load_sentinel_data(
    sentinel_product_location: Path,
    resolution: ResolutionType,
    band_names: tuple[BandNameType, ...],
    add_offset: int = -1000,
    quantification_value: int = 10000,
) -> xr.DataArray:
    """Loads the requested bands of a Sentinel-2 product as reflectance values.

    Example input path: `S2A_MSIL2A_20221116T105321_N0400_R051_T31TCJ_20221116T170958.SAFE`

    Each band is read from ``GRANULE/*/IMG_DATA/R{resolution}m/`` and the bands are
    stacked along a ``band`` dimension whose coordinate holds the band names.
    Raw digital numbers are converted with ``(DN + add_offset) / quantification_value``.

    Args:
        sentinel_product_location (Path): Location of the .SAFE folder containing a Sentinel-2 product.
        resolution (ResolutionType): Resolution in metres, selects the ``R{resolution}m`` folder.
        band_names (tuple[BandNameType, ...]): Bands to load, in output order
            (duplicate names are loaded once).
        add_offset (int, optional): Radiometric offset added to the raw values. Defaults to -1000.
        quantification_value (int, optional): Divisor applied after the offset. Defaults to 10000.
            The defaults follow the L2A convention from processing baseline 04.00 onward
            (``N0400`` in the example name); older products use an offset of 0, so check
            the product metadata.

    Raises:
        FileNotFoundError: If a requested band has no file at the requested resolution.
        TypeError: If a band file does not open as a single ``xr.DataArray``.

    Returns:
        xr.DataArray: A float64 DataArray of dims ``(band, y, x)``, one layer per requested band.
    """
    band_rasters = []
    # dict.fromkeys keeps the requested order while dropping duplicate names.
    for band_name in dict.fromkeys(band_names):
        band_path = _find_band_file(sentinel_product_location, resolution, band_name)
        raster = rioxarray.open_rasterio(band_path)
        if not isinstance(raster, xr.DataArray):
            raise TypeError(
                f"Expected a single DataArray for band {band_name} ({band_path}), "
                f"got {type(raster).__name__}"
            )
        # Replace the file's own band index with the band name.
        band_rasters.append(raster.assign_coords(band=[band_name]))

    stacked = xr.concat(band_rasters, dim="band")

    # Cast before applying the negative offset: the raw values are read as integers
    # (presumably unsigned), where `+ add_offset` would wrap around or, with
    # NumPy >= 2, raise OverflowError. `fillna` returns a new array, so re-assign it.
    stacked = stacked.astype(np.float64).fillna(0.0)

    return (stacked + add_offset) / quantification_value


In [None]:
rasters = load_sentinel_data(
    sentinel_product_location=raster_input_path,
    resolution=resolution,
    band_names=band_names,
)

if verbose:
    log_info(rasters, "rasters")

if show_plots:
    # The function name doubles as the plot title.
    plot_rgb_data_array(rasters, load_sentinel_data.__name__)


### (3) Rasterize the polygons


In [None]:
def rasterize_geojson(
    data_array: xr.DataArray,
    training_classes: GeoDataFrame,
) -> tuple[PolygonMask, FeatureClassNameToId]:
    """Burns a set of vectorial polygons to a raster.

    See https://gis.stackexchange.com/questions/316626/rasterio-features-rasterize

    Each distinct value of the ``class`` column gets an id 1..N, in order of first
    occurrence in the GeoDataFrame. Every polygon is burnt with the id of its class.
    Pixels outside all polygons stay 0, and where polygons overlap the later one wins
    (rasterio's default merge). When both CRSs are known, the polygons are first
    reprojected to the raster's CRS so that they line up with the pixel grid.

    Args:
        data_array (xr.DataArray): The Sentinel raster, from which the transform, CRS and shape are taken.
        training_classes (GeoDataFrame): The input set of classified multipolygons to burn,
            with a ``class`` column.

    Returns:
        tuple[PolygonMask, FeatureClassNameToId]: A uint8 mask covering the same grid
        (y, x) as ``data_array``, and the class name -> id mapping used to burn it.

    Raises:
        ValueError: If there are more than 255 distinct classes (ids must fit in uint8).
    """
    raster_crs = data_array.rio.crs
    if raster_crs is not None and training_classes.crs is not None:
        training_classes = training_classes.to_crs(raster_crs)

    class_names = training_classes["class"]
    # One id per distinct class (not per row), so that several polygons of the same
    # class share a label. 0 is reserved for "no polygon".
    mapping = {
        class_name: class_id
        for class_id, class_name in enumerate(dict.fromkeys(class_names), start=1)
    }
    if len(mapping) > np.iinfo(np.uint8).max:
        raise ValueError(
            f"{len(mapping)} classes found, at most {np.iinfo(np.uint8).max} fit in a uint8 mask"
        )

    out_shape = data_array.rio.shape
    if not mapping:
        # Nothing to burn: rasterize() does not accept an empty shape list on every rasterio version.
        return np.zeros(out_shape, dtype=np.uint8), mapping

    shapes = [
        (geometry, mapping[class_name])
        for geometry, class_name in zip(training_classes.geometry, class_names)
    ]

    rasterized: PolygonMask = rasterize(
        shapes,
        out_shape=out_shape,
        transform=data_array.rio.transform(),
        dtype=np.uint8,
    )

    return rasterized, mapping

In [None]:
# Preview of the raw (mask, class-name -> id mapping) tuple; the next cell
# recomputes it and keeps both parts. TODO(review): redundant, consider removing.
rasterize_geojson(data_array=rasters, training_classes=polygons)

In [None]:
burnt_polygons, mapping = rasterize_geojson(
    data_array=rasters, training_classes=polygons
)

if verbose:
    log_info(burnt_polygons, "burnt_polygons")
    log_info(mapping, "mapping")

if show_plots:
    # The function name doubles as the plot title.
    plot_array(burnt_polygons, rasterize_geojson.__name__)


### (4) Intersect the Sentinel-2 raster with polygons


In [None]:
def produce_clips(
    data_array: xr.DataArray, burnt_polygons: PolygonMask, mapping: FeatureClassNameToId
) -> ClassifiedSamples:
    """Extract the band values of every pixel covered by a classified polygon.

    Args:
        data_array (xr.DataArray): Band stack with ``y`` and ``x`` dimensions
            (all bands are extracted, not only RGB).
        burnt_polygons (PolygonMask): Rasterized classified multipolygons, same
            (y, x) shape as ``data_array``; 0 means "no polygon".
        mapping (FeatureClassNameToId): Class name -> id mapping. Currently unused,
            kept for interface compatibility.

    Raises:
        ValueError: If the mask shape does not match the raster's (y, x) shape.

    Returns:
        ClassifiedSamples: A Dataset over the stacked pixel dimension ``z`` (y, x),
        restricted to covered pixels, with a ``reflectance`` variable (band values)
        and a ``feature_id`` variable (burnt id of each pixel).
    """
    raster_shape = (data_array.sizes["y"], data_array.sizes["x"])
    if tuple(burnt_polygons.shape) != raster_shape:
        raise ValueError(
            f"Mask shape {tuple(burnt_polygons.shape)} does not match raster (y, x) shape {raster_shape}"
        )

    # Pixels flattened row by row (y outer, x inner) into a new `z` dimension.
    reflectance = data_array.stack(z=("y", "x"))

    # C-order flattening of the (y, x) mask follows the same order as the stack above.
    feature_id = xr.DataArray(burnt_polygons.reshape(-1), dims="z")

    classified = xr.Dataset({"reflectance": reflectance, "feature_id": feature_id})

    # Keep only pixels covered by some polygon.
    return classified.sel(z=feature_id != 0)

    


In [None]:
# Preview of the extracted samples; the next cell recomputes and stores them.
# TODO(review): redundant, consider removing.
produce_clips(data_array=rasters, burnt_polygons=burnt_polygons, mapping=mapping)

In [None]:
classified_rgb_rows = produce_clips(
    data_array=rasters, burnt_polygons=burnt_polygons, mapping=mapping
)

if verbose:
    log_info(classified_rgb_rows, "classified_rgb_rows")


### (5) Persist the intersection to a CSV


In [None]:
def persist_to_csv(
    classified_rgb_rows: ClassifiedSamples,
    csv_output_path: Path,
) -> None:
    """Write the classified samples to a semicolon-separated CSV file.

    One row per sample: one column per band (named after the ``band`` coordinate),
    followed by a ``feature_id`` column. Pixel coordinates are not written.

    Args:
        classified_rgb_rows (ClassifiedSamples): Dataset with a ``reflectance`` variable
            of dims (band, z) and a ``feature_id`` variable of dim z.
        csv_output_path (Path): Destination file; missing parent folders are created.
    """
    csv_output_path = Path(csv_output_path)

    samples = classified_rgb_rows["reflectance"].T.to_pandas()
    # Index-aligned assignment: both sides are indexed by the stacked pixel dimension `z`.
    samples["feature_id"] = classified_rgb_rows["feature_id"].to_series()

    # `to_csv` does not create missing parent folders.
    csv_output_path.parent.mkdir(parents=True, exist_ok=True)
    samples.to_csv(csv_output_path, index=False, sep=";")


In [None]:
persist_to_csv(
    classified_rgb_rows=classified_rgb_rows,
    csv_output_path=csv_output_path,
)
log_info("Written CSV output " + str(csv_output_path))


### (6) Train a machine learning model


In [None]:
def classify_sentinel_data(
    rasters: xr.DataArray, classified_rgb_rows: ClassifiedSamples, random_state=None
) -> ClassificationResult:
    """Train a random forest on the labelled pixels and classify the whole raster.

    Args:
        rasters (xr.DataArray): Band stack of dims (band, y, x) to classify.
        classified_rgb_rows (ClassifiedSamples): Training samples with a ``reflectance``
            variable of dims (band, z) and a ``feature_id`` label per sample.
        random_state (int | None, optional): Seed forwarded to ``RandomForestClassifier``
            for reproducible results. Defaults to None (unseeded, the previous behavior).

    Returns:
        ClassificationResult: Predicted id per pixel, dims (y, x). It carries the raster's
        y/x coordinates and ``spatial_ref`` coordinate so it can be georeferenced on export.
    """
    model = RandomForestClassifier(random_state=random_state)

    # scikit-learn expects (n_samples, n_features): samples as rows, bands as columns.
    train = classified_rgb_rows["reflectance"].T
    labels = classified_rgb_rows["feature_id"]
    model.fit(train, labels)

    # Predict every pixel, flattened in the same (y, x) order as the training samples.
    predicted = model.predict(rasters.stack(z=("y", "x")).T)

    # Reuse the raster's spatial coordinates (y, x and the `spatial_ref` CRS/transform).
    # A bare DataArray would get dims dim_0/dim_1, which `rio.to_raster` cannot georeference.
    template = rasters.isel(band=0, drop=True)
    return xr.DataArray(
        predicted.reshape(template.shape),
        coords=template.coords,
        dims=template.dims,
    )


In [None]:
classification_result = classify_sentinel_data(
    rasters=rasters, classified_rgb_rows=classified_rgb_rows
)

if verbose:
    log_info(classification_result, "classification_result")

if show_plots:
    # The function name doubles as the plot title.
    plot_array(classification_result, classify_sentinel_data.__name__)


### (7) Export the classification raster result


In [None]:
def persist_classification_to_raster(
    raster_output_path: Path, classification_result: ClassificationResult
) -> None:
    """Write the classification result to a raster file through rioxarray.

    Args:
        raster_output_path (Path): Destination file (e.g. ``.tiff``); missing parent
            folders are created.
        classification_result (ClassificationResult): Classified raster. It must
            expose spatial (y/x) dimensions for ``rio.to_raster`` to accept it.
    """
    raster_output_path = Path(raster_output_path)
    # rasterio does not create missing parent folders.
    raster_output_path.parent.mkdir(parents=True, exist_ok=True)
    classification_result.rio.to_raster(raster_output_path)


In [None]:
persist_classification_to_raster(raster_output_path, classification_result)
# Report the raster path that was actually written (previously logged the CSV path).
log_info(f"Written Classified Raster to {raster_output_path}")

# --

log_info("Congratulations, you reached the end of the tutorial!")
