In [None]:
%load_ext autoreload
%autoreload 2


In [None]:
# %load_ext cudf.pandas
# import pandas as pd
# print(pd)


In [None]:
# Optimizations
# GDAL optimizations
import multiprocessing as mp
import os

# Use all but two of the available cores for GDAL, but never fewer than one:
# on machines with one or two cores ``cpu_count - 2`` would be zero or negative.
cpu_count: int = mp.cpu_count()
num_cores: int = max(1, cpu_count - 2)
os.environ["GDAL_NUM_THREADS"] = str(num_cores)
# GDAL raster block cache size (small plain numbers are read by GDAL as megabytes).
os.environ["GDAL_CACHEMAX"] = "1024"


## Libraries

In [None]:
# Imports
from pathlib import Path
import numpy as np
import riskmapjnr as rmj


## Dask Processing Instance


In [None]:
import dask
from dask.distributed import LocalCluster, Client

# Start a local Dask cluster with 8 workers, 2 threads per worker and a
# "2GB" memory limit.
# NOTE(review): the worker count is hard-coded and independent of ``num_cores``
# computed above — confirm it suits the target machine.
cluster = LocalCluster(n_workers=8, threads_per_worker=2, memory_limit="2GB")
# Connect a client to that cluster.
geospatial_client = Client(cluster)
# Bare expression so the notebook displays the client summary as cell output.
geospatial_client


## Set user parameters

In [None]:
# Name of the project sub-folder created under the data folder.
# NOTE(review): extract_subsystem() further down looks for a path component
# named "test" by default — keep both in sync when renaming the project.
project_name = "test"


In [None]:
# Three reference years: years[0]-years[1] is the calibration period and
# years[1]-years[2] the validation period (see the period dictionaries below).
years = [2015, 2020, 2024]
# NOTE(review): not referenced anywhere else in the visible notebook.
tree_cover_threshold = 10
forest_source = "gfc"  # forest data source tag; expected values: "gfc" or "tmf"


In [None]:
# Coarse-grid cell sizes (in pixels) at which the predictions are evaluated.
coarse_grid_cell_size_pixels: list[int] = [300]
# Model names; matched as case-insensitive substrings of the folder names
# found in the project folder.
models_to_compare: list[str] = ["rmj_bm", "rmj_mw", "far_rf", "far_icar"]
# Period names used to select prediction / deforestation-rate files by name.
periods: list[str] = ["calibration", "validation"]


In [None]:
# coarse_grid_cell_size_pixels = [50, 100]
# models_to_compare = ["rmj_bm", "rmj_mw", "rf", "icar", "glm", "user"]


## Connect folders

In [None]:
# The root folder is the parent of the current working directory.
root_folder: Path = Path.cwd().parent
# Data folder at <root>/data; created on first run, left untouched otherwise.
downloads_folder: Path = root_folder / "data"
downloads_folder.mkdir(parents=True, exist_ok=True)


In [None]:
# Per-project layout: <downloads>/<project>/data and <downloads>/<project>/evaluation
project_folder = downloads_folder / project_name
processed_data_folder = project_folder / "data"
evaluation_folder = project_folder / "evaluation"

# Create the whole tree up front; folders that already exist are left alone.
for _folder in (project_folder, processed_data_folder, evaluation_folder):
    _folder.mkdir(parents=True, exist_ok=True)


## Select predictions files

In [None]:
def list_folders(directory):
    """
    List the sub-directories found directly inside a directory.

    Parameters:
        directory (str or Path): The directory whose immediate children are inspected.

    Returns:
        list: One ``Path`` object per sub-directory (full paths, not bare
              names), sorted so the result is reproducible across runs.
              If the directory does not exist or cannot be read, an empty
              list is returned and an error message is printed.
    """
    try:
        # Path.iterdir() yields entries in arbitrary order, hence the sort.
        return sorted(entry for entry in Path(directory).iterdir() if entry.is_dir())
    except FileNotFoundError:
        print(f"The directory {directory} does not exist.")
        return []
    except Exception as e:
        # Best-effort helper: report the problem and fall back to "nothing found".
        print(f"An error occurred: {e}")
        return []


In [None]:
def filter_folders(input_folders, filter_words, exclude_words=None):
    """
    Select folders by words contained in their name (case-insensitive).

    A folder is kept when its name contains at least one of ``filter_words``
    and none of ``exclude_words``. Matching is done on substrings.

    Parameters:
        input_folders (list): Path-like objects exposing a ``name`` attribute.
        filter_words (list): Words of which at least one must occur in the folder name.
        exclude_words (list, optional): Words that must not occur in the folder name. Defaults to None.
    Returns:
        list: The folders that passed both tests, in their original order.
    """
    wanted = [word.lower() for word in filter_words]
    unwanted = [word.lower() for word in (exclude_words or [])]

    kept = []
    for folder in input_folders:
        # Must match at least one of the wanted words ...
        if not any(token in folder.name.lower() for token in wanted):
            continue
        # ... and none of the unwanted ones.
        if any(token in folder.name.lower() for token in unwanted):
            continue
        kept.append(folder)
    return kept


In [None]:
def list_files_by_extension(folder_path, file_extensions, recursive=False):
    """
    Collect the files of a folder whose extension is one of ``file_extensions``.

    Extensions are compared case-insensitively. Problems (path is not a
    directory, unreadable folder, ...) are printed rather than raised, and
    whatever was collected up to that point is returned.

    Parameters:
    folder_path (str or Path): Folder to search.
    file_extensions (list of str): Extensions to keep, dot included (e.g., ['.shp', '.tif']).
    recursive (bool): When True, sub-directories are searched as well.
    Returns:
    list: Matching file paths, as strings.
    """
    found = []
    try:
        root = Path(folder_path)
        if not root.is_dir():
            print(f"The provided path '{root}' is not a directory.")
            return found

        for child in root.iterdir():
            if child.is_file():
                suffix = child.suffix.lower()
                if any(suffix == ext.lower() for ext in file_extensions):
                    found.append(str(child))
                continue
            if recursive and child.is_dir():
                # Descend into the sub-directory and merge its matches
                found += list_files_by_extension(child, file_extensions, recursive)
    except Exception as e:
        print(f"An error occurred: {e}")
    return found


In [None]:
def filter_files(input_files, filter_words, exclude_words=None, include_all=True):
    """
    Select files by words contained in their file name (case-insensitive).

    Only the final path component (the file name) is inspected; matching is
    done on substrings.

    Parameters:
        input_files (list): File paths to be filtered.
        filter_words (list): Words looked for in each file name.
        exclude_words (list, optional): Words that must not occur in the file name. Defaults to None.
        include_all (bool, optional): If True, every filter word must occur in the file name. If False, one matching filter word is enough. Defaults to True.
    Returns:
        list: The files that passed both tests, in their original order.
    """
    required = [word.lower() for word in filter_words]
    banned = [word.lower() for word in (exclude_words or [])]

    # all() -> every word must be present; any() -> a single match suffices
    quantifier = all if include_all else any

    selected = []
    for file in input_files:
        if not quantifier(word in Path(file).name.lower() for word in required):
            continue
        if any(word in Path(file).name.lower() for word in banned):
            continue
        selected.append(file)
    return selected


In [None]:
def filter_out_ipynb_checkpoints(input_files):
    """
    Drop files that live inside a Jupyter checkpoint folder.

    A file is discarded when one of its path components is exactly
    '.ipynb_checkpoints'.

    Parameters:
        input_files (list): File paths to be filtered.
    Returns:
        list: The remaining files, in their original order.
    """
    kept = []
    for candidate in input_files:
        if ".ipynb_checkpoints" in Path(candidate).parts:
            continue
        kept.append(candidate)
    return kept


In [None]:
# Discover which of the requested models have a folder inside the project folder.
directory_path = project_folder
folders = list_folders(directory_path)
# Keep folders whose name contains one of the requested model names and drop
# those whose name contains "data" (substring match, so this also covers "data_raw").
available_models = filter_folders(folders, models_to_compare, ["data", "data_raw"])
print("Models_available:", available_models)


In [None]:
folders = list_folders(project_folder)
available_models = filter_folders(folders, models_to_compare, ["data", "data_raw"])


def _prediction_rasters(model_folder):
    """Return the .tif files of a model folder whose name mentions one of the periods."""
    rasters = list_files_by_extension(model_folder, [".tif"], True)
    return filter_files(rasters, periods, None, False)


def _defrate_tables(model_folder):
    """Return the 'defrate' .csv files of a model folder for the periods, checkpoints excluded."""
    tables = list_files_by_extension(model_folder, [".csv"], True)
    for_periods = filter_files(tables, periods, None, False)
    defrate_only = filter_files(for_periods, ["defrate"])
    return filter_out_ipynb_checkpoints(defrate_only)


# One list per model folder; the lists are flattened in the next cell.
available_prediction_files = [_prediction_rasters(folder) for folder in available_models]
available_defrate_files = [_defrate_tables(folder) for folder in available_models]


In [None]:
# Flatten the per-model lists of lists into single flat lists of file paths.
# NOTE: this cell is not idempotent — re-running it on an already flattened,
# non-empty list raises TypeError (a list cannot be concatenated with a str).
available_prediction_files = sum(available_prediction_files, [])
available_defrate_files = sum(available_defrate_files, [])


In [None]:
# Lookup tables keyed by (identifier, subsystem, period), each value a list of paths:
#   dict1 -> prediction rasters, dict2 -> deforestation-rate tables
dict1 = {}
dict2 = {}


# Function to extract the subsystem/model part from path
def extract_subsystem(path, project_dir="test"):
    """Extract the subsystem (model folder) identifier from a path.

    The expected layout is ``.../<project_dir>/<subsystem>/...``: the
    subsystem is the path component that directly follows the first
    component named ``project_dir``.

    Parameters:
        path (str or Path): The path to inspect.
        project_dir (str, optional): Name of the project folder used as the
            anchor. Defaults to "test", the value that used to be hard-coded,
            so existing calls behave as before; pass the actual project
            name for any other project.

    Returns:
        str or None: The component following ``project_dir`` (e.g. "rmj_bm"),
        or None when ``project_dir`` is not part of the path or is its last
        component.
    """
    parts = Path(path).parts
    try:
        anchor = parts.index(project_dir)
    except ValueError:
        # The project folder does not appear in this path
        return None
    if anchor + 1 < len(parts):
        return parts[anchor + 1]
    return None


# Function to extract the period (validation/calibration/historical) from path
def extract_period(path):
    """Extract the period identifier from a path.

    Parameters:
        path (str or Path): The path to inspect.

    Returns:
        str or None: The first path component that is exactly "validation",
        "calibration" or "historical" (case-sensitive), or None when the
        path contains none of them.
    """
    known_periods = ("validation", "calibration", "historical")
    for part in Path(path).parts:
        if part in known_periods:
            return part
    return None


def _index_by_model_key(paths, prefix, index):
    """Append each path to ``index`` under its (identifier, subsystem, period) key.

    The identifier is the file stem with ``prefix`` removed. The prefix is
    only stripped when the stem actually starts with it, so a file that lacks
    the prefix keeps its full stem instead of losing its first characters.
    """
    for path in paths:
        identifier = Path(path).stem.removeprefix(prefix)
        key = (identifier, extract_subsystem(path), extract_period(path))
        index.setdefault(key, []).append(path)


# Prediction rasters are named "prob_<identifier>", deforestation-rate tables
# "defrate_cat_<identifier>"; both are grouped under the same kind of key so
# that they can be paired afterwards.
_index_by_model_key(available_prediction_files, "prob_", dict1)
_index_by_model_key(available_defrate_files, "defrate_cat_", dict2)


# Final matching dictionary: (identifier, subsystem, period) -> (tiff_path, csv_path).
# Each key holds a single pair, so when several rasters or tables share a key
# the last one of each list is the one that is kept.
models_dict = {
    key: (dict1[key][-1], dict2[key][-1])
    for key in dict1
    if key in dict2 and dict1[key] and dict2[key]
}


In [None]:
# Show every matched (raster, table) pair together with the key it is stored under.
print("Final matching dictionary with all attributes:")
for (identifier, subsystem, period), (tiff_path, csv_path) in models_dict.items():
    print(
        f"Identifier '{identifier}', Subsystem '{subsystem}', Period '{period}':",
        f"  TIFF: {tiff_path}",
        f"  CSV:  {csv_path}",
        "",
        sep="\n",
    )


## Select forest cover change file

In [None]:
# List the raster files (.tif / .tiff) located directly in the processed data
# folder; sub-folders are not searched because ``recursive`` keeps its default (False).
input_raster_files = list_files_by_extension(processed_data_folder, [".tiff", ".tif"])


In [None]:
# The forest cover change raster must carry all of these words in its file
# name, and none of the words used by the distance / edge rasters.
_fcc_keywords = ["forest", "loss", forest_source] + [str(num) for num in years]
_fcc_candidates = filter_files(input_raster_files, _fcc_keywords, ["distance", "edge"])
if not _fcc_candidates:
    # Fail with an explicit message instead of an IndexError on the empty list
    raise FileNotFoundError(
        f"No forest cover change raster with {_fcc_keywords} in its name "
        f"was found in {processed_data_folder}"
    )
forest_change_file = _fcc_candidates[0]


## Period dictionaries

In [None]:
def _make_period(name, start_year, end_year, defor_value):
    """Build the parameter dictionary describing one period."""
    return {
        "period": name,
        "initial_year": start_year,
        "final_year": end_year,
        "defor_value": defor_value,
        "time_interval": end_year - start_year,
    }


# NOTE(review): "defor_value" is not read anywhere else in the visible
# notebook; the values are kept exactly as they were.
calibration_dict = _make_period("calibration", years[0], years[1], 1)
validation_dict = _make_period("validation", years[1], years[2], 1)
historical_dict = _make_period("historical", years[0], years[2], [1, 2])


In [None]:
# Build the main dictionary: period name -> parameter dictionary of that period
period_dict = {
    entry["period"]: entry
    for entry in (calibration_dict, validation_dict, historical_dict)
}


## Compare models

In [None]:
import os
import numpy as np
import pandas as pd
import xarray as xr
import rioxarray as rxr
import dask
from dask import delayed
import matplotlib.pyplot as plt


def validation_udef_arp_distributed(
    fcc_file,
    time_interval,
    riskmap_file,
    tab_file_defor,
    period="calibration",
    csize_coarse_grid=300,
    indices_file_pred="indices.csv",
    tab_file_pred="pred_obs.csv",
    fig_file_pred="pred_obs.png",
    figsize=(6.4, 6.4),
    dpi=100,
    verbose=True,
):
    """Validation of the deforestation risk map with distributed processing.

    This function computes the observed and predicted deforestation (in ha)
    for the calibration, validation or historical period. Both rasters are
    opened lazily with rioxarray using chunks of ``csize_coarse_grid`` pixels,
    and each chunk (one coarse grid cell) is processed as a dask delayed task.
    The function writes a .csv file with the per-cell data, a .csv file with
    the accuracy indices and a plot comparing predictions vs observations.

    :param fcc_file: Input raster file of forest cover change at three dates (123).
        1: first period deforestation, 2: second period deforestation, 3: remaining forest
        at the end of the second period. No data value must be 0 (zero).

    :param time_interval: Duration (in years) of the period.

    :param riskmap_file: Input raster file with categories of spatial deforestation
        risk at the beginning of the period.

    :param tab_file_defor: Path to the .csv input file with estimates of deforestation
        density (in ha/pixel/yr) for each category of deforestation risk. Must
        contain the columns ``cat`` and ``defor_dens``.

    :param period: Either "calibration" (from t1 to t2), "validation" (from t2 to t3),
        or "historical" (from t1 to t3). Any other value is treated as "historical".

    :param csize_coarse_grid: Spatial cell size in number of pixels. Must correspond to a
        distance < 10 km. Default to 300 corresponding to 9 km for a 30 m resolution raster.

    :param indices_file_pred: Path to the .csv output file with the accuracy indices.

    :param tab_file_pred: Path to the .csv output file with validation data.

    :param fig_file_pred: Path to the .png output file for the predictions vs observations plot.

    :param figsize: Figure size.

    :param dpi: Resolution for output image.

    :param verbose: Logical. Whether to print messages or not. Default to True.

    :return: A dictionary. With RMSE, wRMSE, MedAE, and R2: root mean squared error,
        weighted root mean squared error (both in ha), median absolute error (in ha),
        and R-square respectively for the deforestation predictions, ncell: the number
        of grid cells with forest cover > 0 at the beginning of the period,
        csize_coarse_grid: the coarse grid cell size in number of pixels,
        csize_coarse_grid_ha: the coarse grid cell size in ha.

    :raises ValueError: If no grid cell has forest cover at the beginning of the period.
    """

    # ==============================================================
    # Input data - Using rioxarray for improved geospatial handling
    # ==============================================================

    # Open raster files lazily; one dask chunk corresponds to one coarse grid cell
    chunks = {"x": csize_coarse_grid, "y": csize_coarse_grid}
    fcc_ds = rxr.open_rasterio(fcc_file, chunks=chunks)
    defor_cat_ds = rxr.open_rasterio(riskmap_file, chunks=chunks)

    # Deforestation density per risk category, scaled once to the whole period
    defor_dens_per_cat = pd.read_csv(tab_file_defor)
    cat = defor_dens_per_cat["cat"].values
    defor_dens_period = defor_dens_per_cat["defor_dens"].values * time_interval

    # Pixel area (in unit square, eg. meter square)
    # NOTE(review): these two attributes are presumably absent from most
    # rasters, in which case the geotransform branch below is the one used.
    pix_area = fcc_ds.attrs.get("pixel_size", 1) * abs(
        fcc_ds.attrs.get("pixel_height", 1)
    )

    # If we don't have pixel size info from the raster, calculate it
    if pix_area == 1:
        try:
            gt = fcc_ds.rio.transform()
            # Affine coefficients: index 0 is the pixel width, index 4 the pixel height
            pix_area = abs(gt[0] * gt[4])
        except Exception:
            pix_area = 900  # Default to 30m x 30m pixels

    # Get the data arrays (squeeze to remove the single-band dimension)
    fcc_data = fcc_ds.squeeze()
    defor_cat_data = defor_cat_ds.squeeze()

    # ==============================================================
    # Distributed Processing using dask delayed functions
    # ==============================================================

    @delayed
    def process_chunk(
        fcc_chunk,
        defor_cat_chunk,
        cat,
        defor_dens_period,
        period,
        pix_area,
    ):
        """Compute observed and predicted deforestation for one grid cell."""
        # Observed forest at the start of the period and deforestation during it
        if period == "calibration":
            forest_mask = fcc_chunk > 0
            defor_mask = fcc_chunk == 1
        elif period == "validation":
            forest_mask = fcc_chunk > 1
            defor_mask = fcc_chunk == 2
        else:  # historical
            forest_mask = fcc_chunk > 0
            defor_mask = (fcc_chunk == 1) | (fcc_chunk == 2)
        nfor_obs = int(forest_mask.sum().values)
        ndefor_obs = int(defor_mask.sum().values)

        # Number of pixels of each risk category in this chunk
        cat_counts = pd.Series(defor_cat_chunk.values.ravel()).value_counts()

        # Predicted deforestation in ha. The density is already expressed in
        # ha per pixel (ha/pixel/yr times the period length), so it must not
        # be rescaled by the pixel area a second time.
        ndefor_pred_ha = 0.0
        for cat_val, count in cat_counts.items():
            if cat_val in cat:
                idx = np.where(cat == cat_val)[0][0]
                ndefor_pred_ha += count * defor_dens_period[idx]

        return {
            "nfor_obs": nfor_obs,
            "ndefor_obs": ndefor_obs,
            "nfor_obs_ha": nfor_obs * pix_area / 10000,
            "ndefor_obs_ha": ndefor_obs * pix_area / 10000,
            "ndefor_pred_ha": ndefor_pred_ha,
        }

    # Pixel offsets of the chunk boundaries along y and x (cumulative chunk sizes)
    y_edges = np.concatenate(([0], np.cumsum(fcc_data.chunks[0])))
    x_edges = np.concatenate(([0], np.cumsum(fcc_data.chunks[1])))

    # One delayed task per chunk; dask distributes them across the workers
    tasks = []
    for y_start, y_end in zip(y_edges[:-1], y_edges[1:]):
        for x_start, x_end in zip(x_edges[:-1], x_edges[1:]):
            window = {
                "y": slice(int(y_start), int(y_end)),
                "x": slice(int(x_start), int(x_end)),
            }
            tasks.append(
                process_chunk(
                    fcc_data.isel(window),
                    defor_cat_data.isel(window),
                    cat,
                    defor_dens_period,
                    period,
                    pix_area,
                )
            )

    # Compute all tasks in parallel
    if verbose:
        print("Processing chunks in parallel...")

    results = dask.compute(*tasks)

    # ==============================================================
    # Combine results and compute validation metrics
    # ==============================================================

    # Only keep cells with forest cover at the beginning of the period
    df = pd.DataFrame([result for result in results if result["nfor_obs"] > 0])

    # If no cells with forest cover, return error
    if df.empty:
        raise ValueError(
            "No grid cells with forest cover found. Please decrease the spatial cell size 'csize_coarse_grid'."
        )

    ncell = len(df)

    # Cell size in ha
    csize_coarse_grid_ha = round(
        csize_coarse_grid * csize_coarse_grid * pix_area / 10000, 2
    )

    # Export the table of results
    df.to_csv(tab_file_pred, sep=",", header=True, index=False, index_label=False)

    # Prediction error
    error_pred = df["ndefor_pred_ha"] - df["ndefor_obs_ha"]

    # Compute RMSE
    squared_error = error_pred**2
    RMSE = round(np.sqrt(np.mean(squared_error)), 2)

    # Compute wRMSE (cells weighted by their share of the forest area)
    w = df["nfor_obs_ha"] / df["nfor_obs_ha"].sum()
    wRMSE = round(np.sqrt(np.sum(squared_error * w)), 2)

    # Compute MedAE
    MedAE = round(np.median(np.absolute(error_pred)), 2)

    # R square: squared correlation coefficient, undefined for a single cell
    if ncell > 1:
        r = np.corrcoef(df["ndefor_pred_ha"], df["ndefor_obs_ha"])[0, 1]
        r_square = round(r**2, 2)
    else:
        r_square = np.nan

    # Identify model from file
    # NOTE(review): drops the first 5 and the last 7 characters of the file
    # name — presumably written for names like "prob_<model>_t1.tif"; confirm.
    model_basename = os.path.basename(riskmap_file)
    model_name = model_basename[5:-7]

    # Plot title
    title = (
        "{0} model, {1} period\n"
        "Predicted vs. observed deforestation "
        "in {2} ha grid cells."
    )
    title = title.format(model_name, period, csize_coarse_grid_ha)

    # Points for identity line: overall min and max of observed/predicted values
    obs_pred = df[["ndefor_obs_ha", "ndefor_pred_ha"]].to_numpy()
    p = [obs_pred.min(), obs_pred.max()]

    # Plot predictions vs. observations
    fig = plt.figure(figsize=figsize, dpi=dpi)
    ax = plt.subplot(111)
    ax.set_box_aspect(1)
    plt.scatter(
        df["ndefor_obs_ha"], df["ndefor_pred_ha"], color=None, marker="o", edgecolor="k"
    )
    plt.plot(p, p, "r--")
    plt.title(title)
    plt.xlabel("Observed deforestation (ha)")
    plt.ylabel("Predicted deforestation (ha)")
    # Text indices and ncell, anchored at the top-left of the data range
    t = "MedAE = {0:.2f} ha\nR2 = {1:.2f}\nn = {2:d}"
    t = t.format(MedAE, r_square, ncell)
    plt.text(0, p[1], t, ha="left", va="top")
    fig.savefig(fig_file_pred)
    plt.close(fig)

    # Results
    indices = {
        "RMSE": RMSE,
        "wRMSE": wRMSE,
        "MedAE": MedAE,
        "R2": r_square,
        "ncell": ncell,
        "csize_coarse_grid": csize_coarse_grid,
        "csize_coarse_grid_ha": csize_coarse_grid_ha,
    }
    indices_df = pd.DataFrame([indices])
    indices_df.to_csv(
        indices_file_pred, sep=",", header=True, index=False, index_label=False
    )

    return indices


In [None]:
"""xarray + dask optimized version of `validation_udef_arp` with edge padding.

Enhancements:
- Reads the rasters with rioxarray (assumed to be installed).
- Handles partial tiles at the raster edges by padding the arrays before coarsening.
- Aggregates data efficiently using Dask-backed xarray operations.

This version avoids explicit loops over grid cells (the only remaining loop
runs over the risk categories), uses block-wise aggregation, and pads the
edges to ensure all data is included in the coarse-grid computations.
"""

from typing import Dict
import os

import numpy as np
import pandas as pd
import xarray as xr
import dask.array as da
import matplotlib.pyplot as plt
import rioxarray  # assume installed


def validation_udef_arp_xr(
    fcc_file: str,
    time_interval: float,
    riskmap_file: str,
    tab_file_defor: str,
    period: str = "calibration",
    csize_coarse_grid: int = 300,
    indices_file_pred: str = "indices.csv",
    tab_file_pred: str = "pred_obs.csv",
    fig_file_pred: str = "pred_obs.png",
    figsize=(6.4, 6.4),
    dpi=100,
    verbose: bool = True,
) -> Dict:
    """xarray/dask optimized validation function with padding edge handling.

    Observed forest cover, observed deforestation and predicted deforestation
    are aggregated on a coarse grid of ``csize_coarse_grid`` x
    ``csize_coarse_grid`` pixels with ``coarsen``. Rasters whose size is not a
    multiple of the cell size are padded at the bottom/right edge first, so
    partial cells are kept. The per-cell table, the accuracy indices and a
    predicted-vs-observed plot are written to disk.

    :param fcc_file: Forest cover change raster. Values > 0 are forest at the
        first date, 1 is deforestation in the first period and 2 deforestation
        in the second period.
    :param time_interval: Duration (in years) of the period.
    :param riskmap_file: Raster with the deforestation risk categories.
    :param tab_file_defor: .csv file with the columns ``cat`` and
        ``defor_dens``; the density is multiplied by ``time_interval`` and by
        the number of pixels of the category to obtain predicted hectares.
    :param period: "calibration", "validation" or "historical"; any other
        value is treated as "historical".
    :param csize_coarse_grid: Coarse grid cell size in number of pixels.
    :param indices_file_pred: Path to the .csv output file with the indices.
    :param tab_file_pred: Path to the .csv output file with the per-cell data.
    :param fig_file_pred: Path to the .png output file with the plot.
    :param figsize: Figure size.
    :param dpi: Resolution for the output image.
    :param verbose: Whether to print progress messages.

    :return: A dictionary with RMSE, wRMSE, MedAE, R2 (NaN when fewer than two
        cells are available), ncell (number of cells with forest cover),
        csize_coarse_grid and csize_coarse_grid_ha.
    """

    # -------------------------
    # Read density table
    # -------------------------
    defor_tab = pd.read_csv(tab_file_defor)
    cats = defor_tab["cat"].values
    defor_dens_period = defor_tab["defor_dens"].values * time_interval

    # -------------------------
    # Open rasters with rioxarray + dask chunks
    # -------------------------
    chunks = {"x": csize_coarse_grid, "y": csize_coarse_grid}

    fcc = rioxarray.open_rasterio(fcc_file, chunks=chunks).squeeze()
    risk = rioxarray.open_rasterio(riskmap_file, chunks=chunks).squeeze()

    # squeeze() only removes a length-1 band axis; for multi-band rasters use band 0
    if "band" in fcc.dims:
        fcc = fcc.isel(band=0)
    if "band" in risk.dims:
        risk = risk.isel(band=0)

    # Pixel area from the affine transform (a: pixel width, e: pixel height)
    transform = fcc.rio.transform()
    pix_area = abs(transform.a) * abs(transform.e)
    ha_per_pixel = pix_area / 10000.0

    csize_coarse_grid_ha = round(
        csize_coarse_grid * csize_coarse_grid * pix_area / 10000.0, 2
    )

    # -------------------------
    # Create observed forest and deforestation masks
    # -------------------------
    if period == "calibration":
        forest_mask = fcc > 0
        defor_mask = fcc == 1
    elif period == "validation":
        forest_mask = fcc > 1
        defor_mask = fcc == 2
    else:  # historical
        forest_mask = fcc > 0
        # isin() is evaluated block-wise on the dask array; wrapping np.isin
        # in apply_ufunc(vectorize=True) would call it once per pixel.
        defor_mask = fcc.isin([1, 2])

    forest_int = forest_mask.astype(int)
    defor_int = defor_mask.astype(int)

    # -------------------------
    # Pad to make dimensions divisible by csize_coarse_grid
    # -------------------------
    ny = forest_int.sizes["y"]
    nx = forest_int.sizes["x"]
    pad_y = (csize_coarse_grid - ny % csize_coarse_grid) % csize_coarse_grid
    pad_x = (csize_coarse_grid - nx % csize_coarse_grid) % csize_coarse_grid
    # Ensure risk can accept NaNs (used as the "no category" padding value)
    if np.issubdtype(risk.dtype, np.integer):
        risk = risk.astype(float)
    if pad_y > 0 or pad_x > 0:
        pad_width = {"y": (0, pad_y), "x": (0, pad_x)}
        forest_int = forest_int.pad(pad_width, mode="constant", constant_values=0)
        defor_int = defor_int.pad(pad_width, mode="constant", constant_values=0)
        risk = risk.pad(pad_width, mode="constant", constant_values=np.nan)

    # -------------------------
    # Aggregate using coarsen (safe with divisible dims)
    # -------------------------
    window = {"y": csize_coarse_grid, "x": csize_coarse_grid}

    forest_coarse = forest_int.coarsen(window).sum()
    defor_coarse = defor_int.coarsen(window).sum()

    # Number of pixels of each risk category per coarse cell
    cat_counts = xr.concat(
        [(risk == c).astype(int).coarsen(window).sum() for c in cats],
        dim="cat",
    ).assign_coords(cat=("cat", cats))

    # Predicted deforestation: pixel counts times the density over the period
    dens_da = xr.DataArray(defor_dens_period, coords={"cat": cats}, dims=("cat",))
    pred_ha_coarse = (cat_counts * dens_da).sum(dim="cat")

    # -------------------------
    # Compute results (trigger computation)
    # -------------------------
    if verbose:
        print("Computing aggregated arrays with padding...")

    # A single compute() on the dataset lets dask share the raster reads
    # between the three variables instead of evaluating each graph separately.
    coarse = xr.Dataset(
        {
            "nfor_obs_ha": forest_coarse * ha_per_pixel,
            "ndefor_obs_ha": defor_coarse * ha_per_pixel,
            "ndefor_pred_ha": pred_ha_coarse,
        }
    ).compute()

    df = coarse.to_dataframe().reset_index().dropna()

    # Keep cells with forest cover only; copy so that the columns added below
    # are written to an independent frame and not to a view of the full table.
    df = df[df["nfor_obs_ha"] > 0].copy()
    ncell = df.shape[0]

    # Back-convert hectares to pixel counts for the output table
    df["nfor_obs"] = (df["nfor_obs_ha"] * 10000.0 / pix_area).round().astype(int)
    df["ndefor_obs"] = (df["ndefor_obs_ha"] * 10000.0 / pix_area).round().astype(int)

    df_out = df[
        [
            "y",
            "x",
            "nfor_obs",
            "ndefor_obs",
            "nfor_obs_ha",
            "ndefor_obs_ha",
            "ndefor_pred_ha",
        ]
    ]
    df_out.to_csv(tab_file_pred, index=False)

    # -------------------------
    # Compute metrics
    # -------------------------
    error_pred = df_out["ndefor_pred_ha"] - df_out["ndefor_obs_ha"]
    squared_error = error_pred**2
    RMSE = round(np.sqrt(np.mean(squared_error)), 2)

    # Weighted RMSE: cells weighted by their share of the forest area
    w = df_out["nfor_obs_ha"] / df_out["nfor_obs_ha"].sum()
    wRMSE = round(np.sqrt(np.sum(squared_error * w)), 2)
    MedAE = round(np.median(np.abs(error_pred)), 2)

    # The correlation coefficient is undefined with fewer than two cells
    if ncell > 1:
        r = np.corrcoef(df_out["ndefor_pred_ha"], df_out["ndefor_obs_ha"])[0, 1]
        r_square = round(float(r**2), 2)
    else:
        r_square = np.nan

    # -------------------------
    # Plot
    # -------------------------
    title = f"Predicted vs Observed Deforestation ({period}) in {csize_coarse_grid_ha} ha cells"
    pmin = min(df_out[["ndefor_obs_ha", "ndefor_pred_ha"]].min().min(), 0)
    pmax = df_out[["ndefor_obs_ha", "ndefor_pred_ha"]].max().max()

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    ax.scatter(
        df_out["ndefor_obs_ha"], df_out["ndefor_pred_ha"], marker="o", edgecolor="k"
    )
    # Identity line: perfect predictions fall on it
    ax.plot([pmin, pmax], [pmin, pmax], "r--")
    ax.set_xlabel("Observed deforestation (ha)")
    ax.set_ylabel("Predicted deforestation (ha)")
    ax.set_title(title)

    t = f"MedAE = {MedAE:.2f} ha\nR2 = {r_square}\nn = {ncell:d}"
    ax.text(0, pmax, t, ha="left", va="top")

    fig.savefig(fig_file_pred)
    plt.close(fig)

    # -------------------------
    # Save indices
    # -------------------------
    indices = {
        "RMSE": RMSE,
        "wRMSE": wRMSE,
        "MedAE": MedAE,
        "R2": r_square,
        "ncell": int(ncell),
        "csize_coarse_grid": csize_coarse_grid,
        "csize_coarse_grid_ha": csize_coarse_grid_ha,
    }
    pd.DataFrame([indices]).to_csv(indices_file_pred, index=False)

    return indices


In [None]:
import forestatrisk as far


def compare_models(
    fcc_file,
    csizes_val,
    val_periods,
    val_models,
    available_prediction_files,
    available_defrate_files,
    period_dict,
):
    """Validate every model for every period and coarse-grid cell size.

    For each (cell size, period, model) combination, the first prediction
    raster and the first deforestation-rate table whose file names contain
    both the model and the period are passed to ``far.validation_udef_arp``.
    The outputs are written to ``evaluation_folder / <period>``.

    NOTE: a later cell defines another ``compare_models`` with a different
    signature, which replaces this one once that cell has been run.

    Parameters:
        fcc_file: Forest cover change raster.
        csizes_val (list): Coarse-grid cell sizes, in pixels.
        val_periods (list): Period names (keys of ``period_dict``).
        val_models (list): Model names looked for in the file names.
        available_prediction_files (list): Candidate risk map rasters.
        available_defrate_files (list): Candidate deforestation-rate tables.
        period_dict (dict): Period name -> dictionary holding "time_interval".
    """
    for cell_size in csizes_val:
        for period_name in val_periods:
            out_dir = evaluation_folder / period_name
            out_dir.mkdir(parents=True, exist_ok=True)
            for model_name in val_models:
                wanted = [model_name, period_name]
                risk_raster = filter_files(
                    available_prediction_files, wanted, None, True
                )[0]
                rate_table = filter_files(
                    available_defrate_files, wanted, None, True
                )[0]
                tag = f"{model_name}_{period_name}_{cell_size}"
                # validation_udef_arp_xr was left as a commented-out
                # alternative to this call in the original code.
                far.validation_udef_arp(
                    fcc_file=fcc_file,
                    period=period_name,
                    time_interval=period_dict[period_name]["time_interval"],
                    riskmap_file=risk_raster,
                    tab_file_defor=rate_table,
                    csize_coarse_grid=cell_size,
                    indices_file_pred=out_dir / f"indices_{tag}.csv",
                    tab_file_pred=out_dir / f"pred_obs_{tag}.csv",
                    fig_file_pred=out_dir / f"pred_obs_{tag}.png",
                    verbose=False,
                )


In [None]:
import forestatrisk as far
from typing import Dict


def compare_models(
    fcc_file: Path,
    csizes_val: list[int],
    models_dict: Dict,
    period_dict: Dict,
):
    """Validate every matched model output for each coarse-grid cell size.

    Each entry of ``models_dict`` maps ``(identifier, subsystem, period)`` to
    a ``(risk map raster, deforestation-rate table)`` pair, which is passed to
    ``far.validation_udef_arp``. The outputs are written to
    ``evaluation_folder / <period>`` and are named after the identifier and
    the cell size only (the subsystem is not part of the file names).

    Parameters:
        fcc_file: Forest cover change raster.
        csizes_val: Coarse-grid cell sizes, in pixels.
        models_dict: Matched prediction / deforestation-rate files.
        period_dict: Period name -> dictionary holding "time_interval".
    """
    for cell_size in csizes_val:
        for (identifier, _subsystem, period_name), pair in models_dict.items():
            risk_raster, rate_table = pair
            out_dir = evaluation_folder / period_name
            out_dir.mkdir(parents=True, exist_ok=True)
            tag = f"{identifier}_{cell_size}"
            # validation_udef_arp_xr was left as a commented-out alternative
            # to this call in the original code.
            far.validation_udef_arp(
                fcc_file=fcc_file,
                period=period_name,
                time_interval=period_dict[period_name]["time_interval"],
                riskmap_file=risk_raster,
                tab_file_defor=rate_table,
                csize_coarse_grid=cell_size,
                indices_file_pred=out_dir / f"indices_{tag}.csv",
                tab_file_pred=out_dir / f"pred_obs_{tag}.csv",
                fig_file_pred=out_dir / f"pred_obs_{tag}.png",
                verbose=False,
            )


In [None]:
# Run the validation for every matched (raster, table) pair and every coarse
# grid size. This uses the models_dict-based compare_models defined just above.
compare_models(
    forest_change_file,
    coarse_grid_cell_size_pixels,
    models_dict,
    period_dict,
)


## Join all the indices 

In [None]:
# Collect every .csv file below the evaluation folder (recursive search) ...
evaluation_csv_files = list_files_by_extension(evaluation_folder, [".csv"], True)
# ... keep those whose file name contains "indices" ...
# NOTE: an "indices_all.csv" left in the evaluation folder by a previous run
# also matches this filter.
indices_csv_files = filter_files(evaluation_csv_files, ["indices"], None, False)
# ... and drop anything stored inside a .ipynb_checkpoints folder.
indices_csv_files_clean = filter_out_ipynb_checkpoints(indices_csv_files)


In [None]:
import pandas as pd


from pathlib import Path


def extract_info_from_filename(filepath):
    """
    Extracts period and model from a given filename.

    The file stem is split on underscores and is expected to look like
    ``indices_<model>_<period>_<number>``, where ``<model>`` may itself
    contain underscores (e.g. ``indices_mw_11_validation_300``).

    Args:
        filepath (str or Path): The path to the file.

    Returns:
        tuple: A tuple containing (period, model). When the stem holds no
        numeric part, the last part is returned as the period and the part
        before it as the model (an empty string if there is none).
    """
    parts = Path(filepath).stem.split("_")

    # Scan from the right for the numeric part; the period sits just before it
    # and everything between the leading 'indices' and the period is the model.
    for i in range(len(parts) - 1, 0, -1):
        if parts[i].isdigit():
            return parts[i - 1], "_".join(parts[1 : i - 1])

    # No numeric part found: fall back to the last two parts. A stem without
    # any underscore has a single part, so there is no model to report.
    period = parts[-1]
    model = parts[-2] if len(parts) > 1 else ""
    return period, model


def combine_model_results(indices_files_list, output_folder=None):
    """Combine model results for comparison.

    Reads every indices file, tags its rows with the model and period encoded
    in the file name, and writes the sorted, concatenated table to
    ``indices_all.csv``.

    Args:
        indices_files_list (list): Paths to the individual indices .csv files.
            Paths that are not existing files are skipped, and so is any file
            named ``indices_all.csv`` (the output of a previous run), so that
            re-running does not feed the combined table back into itself.
        output_folder (str or Path, optional): Folder receiving
            ``indices_all.csv``. Defaults to the notebook-level
            ``evaluation_folder``.

    Raises:
        ValueError: If none of the given paths is a usable indices file.
    """
    target_dir = Path(evaluation_folder if output_folder is None else output_folder)
    output_file = target_dir / "indices_all.csv"

    frames = []
    for file in indices_files_list:
        file_path = Path(file)
        if not file_path.is_file() or file_path.name == output_file.name:
            continue
        period, model = extract_info_from_filename(file_path)
        df = pd.read_csv(file_path)
        df["model"] = model
        df["period"] = period
        frames.append(df)

    if not frames:
        raise ValueError("No indices files found to combine.")

    # Concatenate once, after all files have been read
    indices = pd.concat(frames, axis=0, ignore_index=True)
    # sort_values returns a new frame, so the result has to be kept
    indices = indices.sort_values(by=["csize_coarse_grid", "period", "model"])
    indices = indices[
        [
            "csize_coarse_grid",
            "csize_coarse_grid_ha",
            "ncell",
            "period",
            "model",
            "MedAE",
            "R2",
            "RMSE",
            "wRMSE",
        ]
    ]
    indices.to_csv(
        output_file,
        sep=",",
        header=True,
        index=False,
        index_label=False,
    )


In [None]:
combine_model_results(indices_csv_files_clean)
