In [None]:
from pathlib import Path
from typing import Any

import cmcrameri as cmc  # noqa
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from rasterio import Affine
from rasterio.features import rasterize
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

In [None]:
# Setup Paths
# `strict=True` makes a missing assets directory fail right here.
assets: Path = Path("assets").resolve(strict=True)
# Sort the matches so the same file is selected on every run (glob order is
# filesystem-dependent), and fail with a readable message when nothing matches
# instead of the bare StopIteration that `next()` raises on an empty iterator.
candidates: list[Path] = sorted(assets.glob("*2024*.nc"))
if not candidates:
    raise FileNotFoundError(f"No file matching '*2024*.nc' found in {assets}")
filepath: Path = candidates[0]

# Load Data
ds = xr.open_dataset(filepath)
slivers: gpd.GeoDataFrame = gpd.read_file(assets / "forest-features.geojson")

# Prepare the crs and affine transformation
# The dataset's "spatial_ref" variable carries the CRS as WKT and the
# GeoTransform as a whitespace-separated string of numbers.
spatial_ref: dict[str, Any] = ds["spatial_ref"].attrs
crs = spatial_ref["crs_wkt"]
geotransfrom: list[float] = [float(x) for x in spatial_ref["GeoTransform"].split()]
# Build the affine from the GDAL-ordered coefficients parsed above.
affine = Affine.from_gdal(*geotransfrom)

In [None]:
# Fold the two label columns into one integer code per feature:
# a non-null "forest" contributes 1, a non-null "non-forest" contributes 2
# (so 0 means neither column is set and 3 means both are).
has_forest = slivers["forest"].notna()
has_non_forest = slivers["non-forest"].notna()
slivers["class_id"] = has_forest + (has_non_forest * 2)
slivers

In [None]:
# https://rasterio.readthedocs.io/en/stable/api/rasterio.features.html#rasterio.features.rasterize
# Burn each polygon's class_id into a grid aligned with the dataset.
# `rasterize` consumes (geometry, value) pairs; pixels not covered by any
# polygon keep rasterize's fill value (0 by default per the rasterio docs).
label_shapes = (
    slivers.to_crs(crs)[["geometry", "class_id"]].to_records(index=False).tolist()
)
rst = rasterize(
    shapes=label_shapes,
    # Take the grid shape from the named dimensions rather than slicing
    # ds["Red"].shape, which silently depends on the dimension order of "Red".
    out_shape=(ds.sizes["y"], ds.sizes["x"]),
    transform=affine,
)

# Name the dimensions explicitly instead of relying on them being inferred
# from the order of the coords mapping.
label_cube = xr.DataArray(
    rst,
    dims=("y", "x"),
    coords={"y": ds["y"], "x": ds["x"]},
    name="class_id",
)

# One row per pixel, one column per (variable, quantile) pair. Rows with any
# missing value are dropped before the quantile level is pivoted into columns.
feature_frame = (
    ds.drop_vars("spatial_ref").to_dataframe().dropna().unstack(level="quantile")
)

In [None]:
# Flatten the (variable, quantile) column MultiIndex into names such as
# "Red_P050" (quantile expressed as a zero-padded percentage).
# The guard keeps the cell idempotent: once the columns are flat strings,
# re-running the rename would index into characters and crash.
if feature_frame.columns.nlevels > 1:
    feature_frame.columns = [
        f"{variable}_P{int(np.rint(100 * quantile)):03d}"
        for variable, quantile in feature_frame.columns
    ]

In [None]:
# Flatten the rasterised label grid into a table: one row per pixel with a
# single "class_id" column (the DataArray's name), indexed by the cube's
# y/x coordinates.
label_frame = label_cube.to_dataframe()

In [None]:
# Keep only labelled pixels (class_id 0 is excluded) and attach their features.
# Use an inner join and drop incomplete rows: the default left join keeps
# labelled pixels that have no entry in `feature_frame` and fills every feature
# with NaN, which would then be handed to the classifier as training data.
labelled_pixels = label_frame[label_frame["class_id"] != 0]
train_frame = labelled_pixels.join(feature_frame, how="inner").dropna()

In [None]:
# Separate the training table into the label vector and the predictor matrix.
label_column = "class_id"
y_target = train_frame[label_column]
X_features = train_frame.drop(columns=[label_column])

In [None]:
# Hold out 20 % of the labelled pixels for testing; the fixed seed keeps the
# split repeatable between runs.
split_options = {"test_size": 0.2, "random_state": 42}
X_train, X_test, y_train, y_test = train_test_split(
    X_features, y_target, **split_options
)

In [None]:
# Seed the forest so the fitted model (and every prediction derived from it)
# is reproducible, consistent with the seeded train/test split.
randforest = RandomForestClassifier(random_state=42)
# Per the scikit-learn API, `fit` returns the estimator itself, so
# `randforest_test` refers to the same fitted object as `randforest`.
randforest_test = randforest.fit(X_train, y_train)
# Predicted class ids for the held-out pixels.
randforest_predict = randforest.predict(X_test)

In [None]:
# Stack every band into one array laid out as (x, y, variable, quantile).
# The last two axes are flattened into the feature axis below, so their order
# must reproduce the training columns: `unstack(level="quantile")` produces
# variable-major columns (all quantiles of the first variable, then all
# quantiles of the next, ...) with the quantiles in ascending order.
# Flattening ("quantile", "variable") instead yields quantile-major columns,
# and because the classifier matches a plain array to its features by
# position, it would be fed the bands in the wrong order.
img = (
    ds.drop_vars("spatial_ref")
    .to_dataarray()
    .sortby("quantile")
    .transpose("x", "y", "variable", "quantile")
)

# Reshape the image data
# One row per pixel (x-major, then y), one column per (variable, quantile) pair.
num_of_pixels = img.sizes["x"] * img.sizes["y"]
num_of_bands = img.sizes["variable"] * img.sizes["quantile"]
X_image_data = img.values.reshape(num_of_pixels, num_of_bands)

In [None]:
# Classify every pixel, fold the flat predictions back into the (x, y) grid
# they were flattened from, then flip to (y, x) for display.
n_x, n_y = img.sizes["x"], img.sizes["y"]
randforest_predict_img = randforest.predict(X_image_data).reshape(n_x, n_y).T
plt.imshow(randforest_predict_img)

In [None]:
# Wrap the (y, x) prediction grid in a DataArray that carries the dataset's
# own x/y coordinates.
predicted_forest = xr.DataArray(
    data=randforest_predict_img,
    coords={axis: ds[axis] for axis in ("x", "y")},
    dims=("y", "x"),
)

In [None]:
# Draw the 0.5-quantile Red band, then overlay only the pixels predicted as
# class 1; every other pixel is masked out by `where`.
fig, ax = plt.subplots()
red_median = ds["Red"].sel(quantile=0.5)
red_median.plot.imshow(ax=ax)
forest_only = predicted_forest.where(predicted_forest == 1)
forest_only.plot.imshow(ax=ax, cmap="Greens")
plt.show()

In [None]:
# Plot the 0.5-quantile Red band on its own for comparison.
red_band_p50 = ds["Red"].sel(quantile=0.5)
red_band_p50.plot.imshow()