# Demo: Compare CNN output to benchmark FIM using autoeval process

Prerequisite: Repositories `autoeval-jobs`, `f1-trainer`, and input data.

To run locally:
1. `mkdir /home/pi_7/subcase_4_CNN`
2. Clone `f1-trainer` and `autoeval-jobs` to `/home/pi_7/subcase_4_CNN`
3. `mkdir autoeval-jobs/agreement_maker/data` and `mkdir f1-trainer/examples/data`
4. Place the desired comparison inputs in `f1-trainer/examples/data` as `model_domain.gpkg`, `cnn_prediction.tif`, and `benchmark.tif`
5. In `f1-trainer`, `uv venv .venv` -> `source .venv/bin/activate` -> `uv pip install -r pyproject.toml --extra jupyter`
6. Run notebook

In [None]:
import os
from pathlib import Path

import geopandas as gpd
import numpy as np
import pandas as pd
import rasterio
import rasterio.mask
import xarray as xr
from rasterio.enums import Resampling

In [None]:
# run from pi_7/subcase_4_CNN
# Step two directories up so the relative paths below resolve from there.
os.chdir("../../")
print(f"Working dir changed to {os.getcwd()}")

# Input folder for the CNN/benchmark rasters and output folder for autoeval.
f1_data_path = Path("./f1-trainer/examples/data")
eval_data_path = Path("./autoeval-jobs/agreement_maker/data")
for data_dir in (f1_data_path, eval_data_path):
    # Create the folder (and any missing parents); no error if it exists.
    data_dir.mkdir(parents=True, exist_ok=True)

In [None]:
def _classify_wet_dry(data, threshold, input_nodata, output_nodata):
    """Return a uint8 array: 1 where wet, 0 where dry, output_nodata where no data."""
    # Build both masks from the untouched input so that neither a threshold
    # <= 0 nor a valid cell equal to output_nodata can be misclassified.
    nodata_mask = data == input_nodata
    if np.issubdtype(data.dtype, np.floating):
        # NaN fails every comparison; treat it as no-data rather than dry.
        nodata_mask |= np.isnan(data)

    classified = (data >= threshold).astype(np.uint8)
    classified[nodata_mask] = output_nodata
    return classified


def generate_eval_fim(
    raster: str | Path,
    fim_extent: str | Path,
    output: str | Path,
    threshold: int | float,
    input_nodata: int | float = 1e20,
    output_nodata: int = 255,
) -> None:
    """Generate a FIM autoeval-compatible raster from CNN output

    Clip the input raster to the model domain, classify each cell as wet (1)
    where >= threshold else dry (0), and write the result as a tiled,
    deflate-compressed uint8 raster.

    Parameters
    ----------
    raster : str | Path
        input path
    fim_extent : str | Path
        fim model domain path (vector); the "model_domain" layer is read
    output : str | Path
        output path
    threshold : int | float
        classify wet (1) where >= threshold else dry (0)
    input_nodata : int | float, optional
        no data value in input, by default 1e20
    output_nodata : int, optional
        no data value for output, by default 255; must fit in uint8

    Raises
    ------
    ValueError
        if output_nodata cannot be stored in a uint8 raster
    """
    if not 0 <= output_nodata <= 255:
        raise ValueError(
            f"output_nodata must be within 0-255 for a uint8 raster, got {output_nodata}"
        )

    # Open the raster once for both the profile and the clipped read.
    with rasterio.open(raster, "r") as src:
        profile = src.profile

        print("Reading model extent")
        gdf = gpd.read_file(fim_extent, layer="model_domain")
        gdf = gdf.to_crs(profile["crs"])

        print("Clipping raster to model extent")
        # NOTE(review): mask() fills cells outside the domain with the source
        # raster's nodata value (0 when unset); they only become output_nodata
        # if that fill value equals input_nodata - confirm for the inputs used.
        data, transform = rasterio.mask.mask(
            src, gdf.geometry.tolist(), all_touched=True, crop=True
        )

    print("Re-classifying raster to binary")
    classified = _classify_wet_dry(data, threshold, input_nodata, output_nodata)

    print(f"Writing output to {output}")
    profile.update(
        transform=transform,
        height=classified.shape[1],
        width=classified.shape[2],
        # Use the caller's no-data value so the metadata matches the cells.
        nodata=output_nodata,
        dtype=rasterio.uint8,
        tiled="YES",
        compress="deflate",
        blockxsize=512,
        blockysize=512,
    )
    with rasterio.open(output, "w", **profile) as dst:
        dst.write(classified)


def resample_benchmark_fim(
    benchmark: str | Path,
    output_path: str | Path,
    target_resolution: int | float = 250,
    resampling: "Resampling | None" = None,
) -> None:
    """Helper to resample benchmark FIMs

    Reproject the benchmark onto its own CRS at a new pixel size and write
    the "band_data" variable as a tiled, deflate-compressed raster.

    Parameters
    ----------
    benchmark : str | Path
        path to FIM benchmark
    output_path : str | Path
        output path
    target_resolution : int | float
        output pixel resolution, by default 250
    resampling : Resampling | None, optional
        rasterio resampling method; None keeps the original behaviour
        (Resampling.bilinear)
    """
    # NOTE(review): bilinear interpolation blends neighbouring cell values; if
    # the benchmark is categorical (wet/dry/no-data) a method such as
    # Resampling.nearest may be more appropriate - confirm with the data owner.
    method = Resampling.bilinear if resampling is None else resampling

    # The context manager closes the underlying file once the output is written.
    with xr.open_dataset(
        benchmark, engine="rasterio", chunked_array_type="cubed"
    ) as raster:
        resampled = raster.rio.reproject(
            raster.rio.crs, resolution=target_resolution, resampling=method
        )

        # "compress" is the creation-option keyword (as in generate_eval_fim);
        # the previous "compression" keyword did not match it.
        resampled["band_data"].rio.to_raster(
            output_path, tiled=True, compress="deflate", blockxsize=512, blockysize=512
        )

Generate a binary CNN FIM in benchmark FIM model domain

In [None]:
# Inputs are read from the f1-trainer data folder; the binary CNN FIM is
# written to the agreement-maker data folder.
cnn = f1_data_path / "cnn_prediction.tif"
fim_extent = f1_data_path / "model_domain.gpkg"
cnn_output = eval_data_path / "test_cnn.tif"

# Cells with a CNN value of 45 or more are classified as wet.
generate_eval_fim(cnn, fim_extent, cnn_output, threshold=45)

Resample Benchmark FIM to match CNN output

In [None]:
# may take 5+ minutes
# Resample the benchmark to 250-unit pixels and store it beside the CNN FIM.
benchmark_output = eval_data_path / "test_benchmark.tif"
fim_benchmark = f1_data_path / "benchmark.tif"

resample_benchmark_fim(fim_benchmark, benchmark_output, target_resolution=250)

Run auto-eval make agreement process

In [None]:
%%bash
# Build and start the autoeval containers, run the agreement job inside the
# make-agreement-dev service, then shut the containers down.
cd ~
cd pi_7/subcase_4_CNN/autoeval-jobs
docker compose build
docker compose up -d
# A notebook cell has no terminal, so run the commands non-interactively:
# -T disables pseudo-TTY allocation and bash -c receives the commands
# explicitly instead of relying on an interactive shell reading them.
docker compose exec -T make-agreement-dev bash -c '
python make_agreement.py --fim_type extent --candidate_path /app/data/test_cnn.tif --benchmark_path /app/data/test_benchmark.tif --output_path /app/data/eval_output.tif --metrics_path /app/data/cnn_output_metrics.csv
chmod  777 ./data/eval_output.tif
'
docker compose down


Display metrics

In [None]:
# Load the metrics CSV from the autoeval data folder and preview the first rows.
metrics_csv = eval_data_path / "cnn_output_metrics.csv"
df = pd.read_csv(metrics_csv)
df.head()

Display maps

In [None]:
def _read_first_band(path):
    """Read band 1 of the raster at ``path`` with cells equal to 255 set to NaN."""
    with rasterio.open(path, "r") as dataset:
        band = dataset.read(1)
    return np.where(band == 255, np.nan, band)


# The names below are read by the plotting cell, so they are kept as-is
# (including `eval`, which shadows the builtin within this notebook).
eval = _read_first_band(eval_data_path / "eval_output.tif")
cnn = _read_first_band(cnn_output)
benchmark = _read_first_band(benchmark_output)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# One row of three panels: benchmark, CNN prediction, agreement raster.
f, ax = plt.subplots(1, 3, figsize=(14, 5))
panels = [(benchmark, "Blues"), (cnn, "Blues"), (eval, "viridis")]
for axis, (image, colormap) in zip(ax, panels):
    axis.imshow(image, cmap=colormap)
    axis.set_axis_off()

# legend for benchmark and CNN
cmap = plt.get_cmap("Blues")
rgba = [cmap(0), cmap(0.99)]
legend_elements = [
    Patch(facecolor=colour, edgecolor="black", label=label)
    for colour, label in ((rgba[1], "Wet"), (rgba[0], "Dry"))
]
for axis, title in ((ax[0], "Benchmark"), (ax[1], "CNN")):
    axis.legend(title=title, handles=legend_elements, loc="lower left")

# legend for eval
cmap = plt.get_cmap("viridis")
rgba = [cmap(position) for position in (0, 0.33, 0.66, 0.99)]
eval_labels = ("True Negative", "False Negative", "False Positive", "True Positive")
legend_elements = [
    Patch(facecolor=colour, edgecolor="black", label=label)
    for colour, label in zip(rgba, eval_labels)
]
ax[2].legend(title="Eval", handles=legend_elements, loc="lower left")