# Demo of the baseline model

In [None]:
import sys
sys.path.append("../")

import src.paths as PATHS
import src.constants as CONST

import src.data.data_handler as DH
import src.data.config as DATA_CONFIG

import src.model.baseline_model as BM

import geopandas as gpd
import folium
from folium.plugins import MarkerCluster

from pprint import pprint

The data handler needs the following inputs:
1. a configuration (the defaults can be used)
2. the prediction regions ("scope"; a geodataframe) from Luke's data
3. local data for enrichment (a dict of geodataframes); here we only need the "centerline shape" geodataframe
4. erosion data (a geodataframe): where the river bank was at different times
5. an erosion border with the same CRS as the prediction regions

## 1. Gather all the inputs (normally user provided, or translated from user specifications)

### a. Configuration

Configuration drives the parameters of the feature creation. It is a combination of user input (including defaults) and some internal parameters.

In the baseline model case not all parameters are used, e.g. there is no need to define the remote data sources, as we only need the erosion data. Below we list the parameters relevant to the baseline model, even though we use their default values.

In [None]:
# The relevant parameters are spelled out for clarity only: each one is set to
# its default, so the result is equivalent to running
# baseline_configuration = DATA_CONFIG.DataConfiguration()
baseline_parameters = {
    "no_of_points_for_distance_calculation": CONST.DEFAULT_NO_OF_POINTS_FOR_DISTANCE_CALCULATION,
    "prediction_region_id_column_name": CONST.PREDICTION_REGION_ID,
    "timestamp_column_name": CONST.TIMESTAMP,
    "use_only_certain_river_bank_points": CONST.DEFAULT_USE_ONLY_CERTAIN_RIVER_BANK_POINTS,
}

baseline_configuration = DATA_CONFIG.DataConfiguration(**baseline_parameters)

In [None]:
# Print the configuration object to check which parameter values are in effect.
pprint(baseline_configuration)

### b. Prediction regions

Also called "scope" by Luke. A user-defined dataset.

We have the data locally.

In [None]:
# layer of the GeoPackage that holds the prediction regions
scope_layer_name = "vlakken_scope"

# the same GeoPackage is read again when loading the river bank points
luke_geospatial_data = PATHS.DATA_DIR / "all_results_20250121_v2.gpkg"

prediction_regions = gpd.read_file(
    luke_geospatial_data,
    layer=scope_layer_name,
)

In [None]:
# The CRS printed here is the one the other datasets are reprojected to.
print(prediction_regions.crs)
# Summary of the columns (dtypes and non-null counts) of the prediction regions.
prediction_regions.info()

In [None]:
# Preview the first rows of the prediction regions.
prediction_regions.head()

### c. The local geospatial enrichment data

While we don't actually do any enrichment, we need the river centerline to be able to properly determine which points lie beyond the erosion border. Again, a user-provided dataset.

We also have these locally.

In [None]:
centerline_layer_name = "Centreline_River"
etienne_geospatial_data = PATHS.DATA_DIR / "Levering_erosie_data.gpkg"

# read the centerline and reproject it right away, so that all the CRS are aligned
centerline = gpd.read_file(
    etienne_geospatial_data,
    layer=centerline_layer_name,
).to_crs(prediction_regions.crs)

# the local geospatial enrichment data is a dictionary of geodataframes
# TODO: define this constant better, not via an operation
centerline_key = CONST.AggregationOperations.CENTERLINE_SHAPE.value
local_geospatial_data = {centerline_key: centerline}

In [None]:
# Preview the first rows of the reprojected centerline data.
centerline.head()

### d. Erosion data

I.e. the points where the river bank was at different times. Used in model training.

We also have this locally.

In [None]:
riverbank_layer_name = "punten_oever"

# read the river bank points and bring them to the CRS of the prediction regions
river_bank_locations = gpd.read_file(
    luke_geospatial_data,
    layer=riverbank_layer_name,
).to_crs(prediction_regions.crs)

In [None]:
# Show five randomly sampled river bank points (no seed is set, so the
# selection differs between runs).
river_bank_locations.sample(5)

### e. Erosion border

This is a line that has to have the right CRS. Can be user provided or internal.

We have it locally.

In [None]:
fake_erosion_border = PATHS.DATA_DIR / "handdrawn_fake_erosion_border.geojson"

# load the border and reproject it to the CRS of the prediction regions
erosion_border_gdf = gpd.read_file(fake_erosion_border).to_crs(prediction_regions.crs)

# only the geometry of the first feature is used as the erosion border
erosion_border = erosion_border_gdf["geometry"].iloc[0]

In [None]:
# Display the geometry extracted as the erosion border.
erosion_border

In [None]:
# Interactive overview map with one toggleable layer per input dataset.
# Every layer is created with show=False, so enable it in the layer control to see it.
mapa = folium.Map(
    location=[CONST.CENTRE_NL_LAT, CONST.CENTRE_NL_LON],
    zoom_start=CONST.DEFAULT_NL_ZOOM,
    control_scale=True,
)

# scope
fg_scope = folium.FeatureGroup(name="prediction regions (scope)", show=False).add_to(mapa)
folium.GeoJson(prediction_regions["geometry"].to_crs(epsg=CONST.EPSG_WGS84)).add_to(fg_scope)

# river centerline
fg_centerline = folium.FeatureGroup(name="river centerline", show=False).add_to(mapa)
folium.GeoJson(centerline["geometry"].to_crs(epsg=CONST.EPSG_WGS84)).add_to(fg_centerline)

# river bank locations: one circle marker per point, coloured by timestamp
fg_bank = folium.FeatureGroup(name="river bank", show=False).add_to(mapa)
river_bank_wgs84 = river_bank_locations.to_crs(epsg=CONST.EPSG_WGS84)

# NOTE(review): timestamp 3 is drawn in blue and every other timestamp in orange;
# what timestamp 3 stands for is not visible here - confirm and give it a name.
# Only the geometry and timestamp columns are needed, so iterate over those two
# directly instead of using iterrows(), which builds a full Series for every row.
for point, timestamp in zip(river_bank_wgs84["geometry"], river_bank_wgs84[CONST.TIMESTAMP]):
    folium.CircleMarker(
        location=[point.y, point.x],  # folium locations are (lat, lon), i.e. (y, x)
        radius=2,
        color="blue" if timestamp == 3 else "orange",
        opacity=0.5,
    ).add_to(fg_bank)

# erosion border
fg_border = folium.FeatureGroup(name="erosion border", show=False).add_to(mapa)
folium.GeoJson(erosion_border_gdf["geometry"].to_crs(epsg=CONST.EPSG_WGS84)).add_to(fg_border)

folium.LayerControl().add_to(mapa)

mapa

## 2. Create a prediction model

The baseline model averages the changes in the river bank position in time for each region.

In [None]:
# Hand the configuration and all the input datasets gathered in section 1 to the
# data handler.
data_handler = DH.DataHandler(
    config=baseline_configuration,
    prediction_regions=prediction_regions,
    local_data_for_enrichment=local_geospatial_data,
    erosion_data=river_bank_locations,
    erosion_border=erosion_border,
)

# NOTE(review): presumably this fills data_handler.processed_erosion_data, which
# the following cells read - confirm against DataHandler.
data_handler.process_erosion_features()

In [None]:
# Size (rows, columns) and first rows of the processed erosion data that is
# used to train the model.
print(data_handler.processed_erosion_data.shape)
data_handler.processed_erosion_data.head()

In [None]:
# Build the baseline model from the same configuration and the processed erosion
# data produced by the data handler.
baseline_model = BM.BaselineErosionModel(
    config=baseline_configuration,
    training_data=data_handler.processed_erosion_data,
    verbose=True,
)

baseline_model.train()

In [None]:
# Show the model attribute after training.
baseline_model.model

## 3. Run and visualize the prediction

In [None]:
# Predict from the most recent observations only: keep the rows whose timestamp
# index level equals the latest timestamp present in the processed erosion data.
timestamp_level = data_handler.processed_erosion_data.index.get_level_values(CONST.TIMESTAMP)

latest_timestamp = timestamp_level.max()

# Copy after filtering: only the selected rows are duplicated, and the result is
# independent of the training data held by the data handler.
data_for_prediction = data_handler.processed_erosion_data[timestamp_level == latest_timestamp].copy()

In [None]:
# NOTE(review): the unit of prediction_length (e.g. years or timestamp steps) is
# not visible here - confirm against BaselineErosionModel.predict.
prediction = baseline_model.predict(data_for_prediction, prediction_length=10)

prediction