# Workshop Context and Objectives

This notebook is a hands-on workshop exercise demonstrating how to build a workflow that goes from remote sensing and climate data to the estimation of missing spatio-temporal information using an analogue approach.

## Context

<details>
<summary><b>Details</b></summary>

Suppose we want to run a fully distributed hydrological model over the Volta River Basin. Such models require continuous daily inputs over long time periods, but in many regions observational datasets are incomplete, temporally sparse, or unavailable for recent years. To overcome this limitation, we will use a data-driven analogue method to generate the missing information needed to run the model.

<div style="display:flex; gap:10px;">

<img src="https://raw.githubusercontent.com/LoicGerber/synthetic_data_generation_workshop/52acd5850ad2311637ecf5704099c15e947f975e/isohyets.png" width="35%">

<img src="https://raw.githubusercontent.com/LoicGerber/synthetic_data_generation_workshop/52acd5850ad2311637ecf5704099c15e947f975e/dem.png" width="34.8%">

</div>

In this workshop, we will use GLEAM evapotranspiration data as our target variable. The goal is to learn the relationships between climate conditions and evapotranspiration from this reference dataset, and then generate synthetic images that statistically and spatially resemble GLEAM. To achieve this, we will use ERA5-Land temperature and precipitation as predictor variables, allowing us to reconstruct realistic evapotranspiration fields even when observations are missing.

</details>

## Part 1 - Accessing, downloading, and visualizing data

<details>
<summary><b>Details</b></summary>

In the first part, we focus on data acquisition and preprocessing:
- Access remote sensing and reanalysis data directly from Google Earth Engine (GEE).
- Download Actual Evapotranspiration (AET) from the WaPOR Level 1 product (FAO).
- Download daily climate predictors (precipitation and temperature) from ERA5-Land.
- Spatially subset all datasets to the Volta River Basin using HydroSHEDS geometries.
- Convert GEE ImageCollections into local NumPy arrays for simple Python handling.
- Save and reload processed datasets efficiently using compressed `.npz` files for reproducibility.
- Visualize AET, precipitation, and temperature maps for qualitative inspection.

</details>

## Part 2 – kNN-based analogue modeling: method understanding and validation

<details>
<summary><b>Details</b></summary>

This second part focuses on understanding the kNN-based analogue method through a simple, step-by-step demonstration. Following this illustrative example, the method is applied to reconstruct held-out periods of daily data within a validation framework, allowing for qualitative and quantitative assessment of its performance.
- Load pre-processed NetCDF datasets of ET (target), precipitation, and temperature (predictors) clipped to the Volta Basin.
- Ensure temporal alignment of predictors and target.
- Optionally subset the spatial domain for faster computation.
- Construct covariate features, including lagged predictors (*climate window*), to capture temporal dynamics.
- Split the dataset into training and testing periods based on user-defined date ranges.
- Apply k-nearest neighbors (kNN) regression, where each unobserved ET image is estimated by finding the most similar historical climate “analogues.”
- Visualize predicted versus observed ET maps for qualitative assessment.
- Perform quantitative evaluation using Root Mean Squared Error (RMSE) over space and time.

</details>

## Part 3 – kNN-based analogue modeling: production run for 2021–2025

<details>
<summary><b>Details</b></summary>

In this final part, we move from method validation to a full production run, applying the kNN-based analogue approach to estimate daily ET for the period 2021–2025, using all available historical observations up to 2020 as training data.

Key steps include:
- Prepare training datasets from ET (target) and climate predictors (precipitation and temperature) for all available historical data.
- Apply the kNN analogue method to generate daily ET estimates for 2021–2025.
- Compute _analogue-based uncertainty_ for each generated day, defined as the weighted standard deviation across the k nearest analogues.
- Visualize daily ET reconstructions and their uncertainty.
- Save the production datasets in both NetCDF and compressed `.npz` formats.
- Copy final outputs to Google Drive for long-term storage and reproducibility.

</details>

## Learning Outcomes

<details>
<summary><b>Details</b></summary>

By the end of this workshop, participants will be able to:
1. Access and process geospatial datasets from GEE and local NetCDF files.
2. Handle spatio-temporal data using `NumPy` and `xarray` efficiently.
3. Apply masking and subsetting to focus on a specific river basin.
4. Understand and implement a kNN-based analogue approach for estimating missing or unobserved images.
5. Validate analogue-based reconstructions using visual diagnostics and RMSE.
6. Apply the same method in a production setting to temporally disaggregate remote sensing products.
7. Build reproducible workflows that integrate data acquisition, preprocessing, modeling, validation, and production runs.

</details>

___

# Part 2 – kNN-based analogue modeling: method understanding and validation using pre-downloaded data

In this second part, we apply an analogue method to estimate missing evapotranspiration (ET) images using climate predictors. The core idea is simple:

>For a given day with unknown ET, we search the historical archive for days with similar climate conditions (analogues), and we reconstruct ET by combining their observed ET maps using kNN.

Here, the predictors describe the climatic state of a day (precipitation and temperature), while ET represents the hydrological response. By identifying past days with similar predictor patterns, we assume that their ET spatial patterns are informative analogues for the target day.
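
To make this concrete, here is a minimal NumPy sketch on invented numbers. It assumes that each day is summarised by only two basin-mean predictors (the notebook itself uses full predictor maps) and that the analogue ET maps are combined with a plain average instead of the inverse-distance weights used later; all variable names are illustrative.

```python
import numpy as np

# Toy archive (invented values): 5 historical days, each described by
# 2 predictors [precipitation mm/day, temperature °C] and a 2x2 ET map [mm/day]
predictors_hist = np.array([
    [0.0, 35.0],
    [12.0, 28.0],
    [1.0, 34.0],
    [20.0, 26.0],
    [0.5, 33.0],
])
et_hist = np.array([
    [[1.0, 1.2], [0.8, 1.0]],
    [[3.0, 3.5], [2.8, 3.1]],
    [[1.1, 1.3], [0.9, 1.0]],
    [[3.6, 3.9], [3.2, 3.4]],
    [[1.2, 1.4], [1.0, 1.1]],
])

# Day with unknown ET: only its predictors are available
predictors_query = np.array([0.8, 34.0])

# 1) Similarity = Euclidean distance in predictor space
dist = np.linalg.norm(predictors_hist - predictors_query, axis=1)

# 2) Keep the k most similar days (the analogues), closest first
k = 3
analogue_idx = np.argsort(dist)[:k]

# 3) Combine their ET maps (plain average in this sketch)
et_reconstructed = et_hist[analogue_idx].mean(axis=0)
print(analogue_idx, et_reconstructed)
```

The three dry, hot days (indices 2, 4 and 0) are selected, so the reconstructed map resembles their low ET values rather than those of the two wet days.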

Before applying this method in a production context, we first demonstrate it step by step for a single day, purely for didactic purposes. We then apply the method to reconstruct held-out test periods (by default, the year 2018 and June–August 2020) in a validation setup, allowing for both qualitative and quantitative evaluation.


## Import libraries and define helper functions

> **Note:** these cells must be run, but we will not look at them in detail.

In this first step, we import all Python libraries required throughout the notebook. Key libraries include:
- `NumPy` and `xarray` for numerical data handling
- `matplotlib` for visualization
- `scikit-learn` for kNN regression

In [None]:
# Import required libraries
import numpy as np
from google.colab import files, drive
import matplotlib.pyplot as plt
import xarray as xr
import gdown
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import pairwise_distances
import gc

### Preparing the data

We define a function `prepare_data` to:
- Construct covariates (precipitation and temperature)
- Optionally include lagged values (`time_window`)
- Split the dataset into training and testing sets based on `test_periods`
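
The lagging step can be illustrated on a single toy pixel. This sketch assumes a 5-day series whose values equal the day number, so each row of the result shows directly which day it comes from; it mirrors the `shift`-based loop in `prepare_data` but is not that function, and `toy_pre` is a hypothetical name.

```python
import numpy as np
import xarray as xr

# Toy daily series for one pixel; values equal the day number, so lags are easy to read
toy_pre = xr.DataArray(np.arange(5.0), dims=("time",))
time_window = 2

# Same order as in prepare_data: current day first, then t-1, t-2, ...
feat_list = [toy_pre] + [toy_pre.shift(time=lag) for lag in range(1, time_window + 1)]
features = xr.concat(feat_list, dim="feature")

# shift() fills the first `lag` days with NaN, so the first `time_window` days are dropped
features = features.isel(time=slice(time_window, None))
print(features.values)
```

For the three remaining days (2, 3 and 4), row 0 holds day t, row 1 holds t-1 and row 2 holds t-2, which is the per-variable feature order that `prepare_data` produces.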

In [None]:
def prepare_data(et_ds, pre_ds, tmax_ds, time_window, test_periods):
    """
    Prepare X_train, y_train, X_test, y_test, dates_train, dates_test
    from ET (target) and covariates (pre, tmax).

    Parameters
    ----------
    et_ds : xarray.Dataset
        Dataset containing the target variable (ET).
    pre_ds : xarray.Dataset
        Dataset containing the precipitation predictor.
    tmax_ds : xarray.Dataset
        Dataset containing the maximum temperature predictor.
    time_window : int
        Number of lag days for covariates (e.g., if 2, includes t, t-1, t-2).
    test_periods : list of (start_date, end_date)
        Dates defining test periods.

    Returns
    -------
    X_train, y_train, X_test, y_test, dates_train, dates_test
    """

    # ==========================================================
    # 1) EXTRACT DATAARRAYS FROM DATASETS
    # ==========================================================
    # Assume each dataset contains a single variable
    target_var = list(et_ds.data_vars)[0]
    pre_var    = list(pre_ds.data_vars)[0]
    tmax_var   = list(tmax_ds.data_vars)[0]

    # Extract DataArrays
    et_da   = et_ds[target_var]
    pre_da  = pre_ds[pre_var]
    tmax_da = tmax_ds[tmax_var]

    # Ensure target has a time dimension
    if "time" not in et_da.dims:
        raise ValueError("ET DataArray has no 'time' dimension.")
    time_dim = "time"

    # ==========================================================
    # 2) ALIGN PREDICTORS WITH TARGET TIME AXIS
    # ==========================================================
    # Use ET time as reference
    time = et_da[time_dim]

    # Subset predictors to match ET dates
    pre_da  = pre_da.sel({time_dim: time})
    tmax_da = tmax_da.sel({time_dim: time})

    # ==========================================================
    # 3) BUILD FEATURE STACK (INCLUDING LAGGED VARIABLES)
    # ==========================================================
    feat_list = []

    if time_window == 0:
        # Only current-day predictors
        feat_list.extend([pre_da, tmax_da])

        # Effective ET is unchanged
        et_eff = et_da

    else:
        # Include current day and lagged days
        for da in (pre_da, tmax_da):

            # Current day
            feat_list.append(da)

            # Add lagged predictors
            for lag in range(1, time_window + 1):
                feat_list.append(
                    da.shift({time_dim: lag})
                )

        # Remove first days where lagged data is incomplete
        et_eff = et_da.isel({time_dim: slice(time_window, None)})
        feat_list = [
            da.isel({time_dim: slice(time_window, None)})
            for da in feat_list
        ]

        # Update time coordinate accordingly
        time = et_eff[time_dim]

    # Concatenate predictors along a new "feature" dimension
    features = xr.concat(feat_list, dim="feature")

    # ==========================================================
    # 4) CREATE TRAIN / TEST SPLIT BASED ON DATE PERIODS
    # ==========================================================
    # Ensure time coordinate is datetime
    if not np.issubdtype(time.dtype, np.datetime64):
        raise TypeError("time coordinate is not datetime-like.")

    # Initialize boolean mask for test periods
    test_mask = np.zeros(time.size, dtype=bool)

    # Mark all dates belonging to test periods
    for start, end in test_periods:
        test_mask |= (
            (time >= np.datetime64(start)) &
            (time <= np.datetime64(end))
        )

    # Training mask is inverse
    train_mask = ~test_mask

    # Convert masks to indices
    train_time_idx = np.where(train_mask)[0]
    test_time_idx  = np.where(test_mask)[0]

    # Safety checks
    if train_time_idx.size == 0:
        raise ValueError("No training samples left after applying test periods.")
    if test_time_idx.size == 0:
        raise ValueError("No test samples found for the given test periods.")

    # Subset datasets
    features_train = features.isel({time_dim: train_time_idx})
    features_test  = features.isel({time_dim: test_time_idx})

    et_train = et_eff.isel({time_dim: train_time_idx})
    et_test  = et_eff.isel({time_dim: test_time_idx})

    # ==========================================================
    # 5) CONVERT XARRAY → NUMPY ARRAYS
    # ==========================================================
    # Identify spatial dimensions (everything except time)
    spatial_dims = [d for d in et_eff.dims if d != time_dim]

    # Reorder dimensions for ML-friendly format:
    # (time, features, lat, lon)
    X_train = features_train.transpose(
        time_dim, "feature", *spatial_dims
    ).values

    X_test = features_test.transpose(
        time_dim, "feature", *spatial_dims
    ).values

    # Target arrays: (time, lat, lon)
    y_train = et_train.transpose(
        time_dim, *spatial_dims
    ).values

    y_test = et_test.transpose(
        time_dim, *spatial_dims
    ).values

    # ==========================================================
    # 6) EXTRACT DATE ARRAYS
    # ==========================================================
    dates_train = et_train[time_dim].values
    dates_test  = et_test[time_dim].values

    # ==========================================================
    # RETURN RESULTS
    # ==========================================================
    return X_train, y_train, X_test, y_test, dates_train, dates_test

### kNN estimation

- For each test day, kNN identifies climate analogues in the training set.
- ET is reconstructed from the analogue ET maps.
- Predictions are generated independently for each day.
- For each query day, the `k_neighbors` closest analogues are weighted by similarity (inverse distance).
- The predicted ET is the weighted mean of these analogues.
- The "uncertainty" is the weighted standard deviation of the analogue ET values around the mean.

> **Note:** This uncertainty reflects the spread of the analogue ensemble. It indicates regions or days where the selected analogues are less consistent (larger spread = higher analogue-based uncertainty); it is not a formal statistical confidence interval.
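
The weighting and uncertainty formulas can be checked by hand for a single pixel. The distances and ET values below are invented; the arithmetic follows the same steps as `knn_predict_with_uncertainty`.

```python
import numpy as np

# Hypothetical distances from one query day to its k = 3 analogues
distances = np.array([1.0, 2.0, 4.0])
# ET of the 3 analogues at one pixel [mm/day] (invented values)
et_analogues = np.array([2.0, 3.0, 5.0])

# Inverse-distance weights, normalised to sum to 1
weights = 1.0 / (distances + 1e-12)
weights /= weights.sum()

# Weighted mean = prediction; weighted std = analogue-based uncertainty
et_mean = np.sum(weights * et_analogues)
et_std = np.sqrt(np.sum(weights * (et_analogues - et_mean) ** 2))
print(weights, et_mean, et_std)
```

Here the weights are 4/7, 2/7 and 1/7, the weighted mean is 19/7 ≈ 2.71 mm/day and the weighted standard deviation is √52/7 ≈ 1.03 mm/day: the closest analogue dominates the mean, while the distant 5 mm/day analogue inflates the spread.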

In [None]:
def knn_predict_with_uncertainty(X_train, y_train, X_query, k_neighbors):
    """
    KNN analogue prediction with uncertainty from analogue spread.

    Parameters
    ----------
    X_train : (T_train, n_features, latX, lonX)
        Predictor variables for training period
    y_train : (T_train, latY, lonY)
        Target variable for training period
    X_query : (T_query, n_features, latX, lonX)
        Predictor variables for query period
    k_neighbors : int
        Number of analogues (nearest neighbors)

    Returns
    -------
    y_mean : (T_query, latY, lonY)
        Predicted mean field
    y_std  : (T_query, latY, lonY)
        Analogue-based uncertainty (weighted std)
    """

    # -------------------------------
    # Extract dimensions
    # -------------------------------
    T_train, n_features, n_latX, n_lonX = X_train.shape
    T_query = X_query.shape[0]
    _, n_latY, n_lonY = y_train.shape

    # Total number of spatial pixels
    NX = n_latX * n_lonX
    NY = n_latY * n_lonY

    # ==========================================================
    # X MASKING — remove pixels that contain NaNs in ANY time or feature
    # ==========================================================

    # Combine train + query to ensure consistent valid mask
    X_all = np.concatenate([X_train, X_query], axis=0)

    # Valid pixel mask → True where all values finite across time + features
    valid_X = np.all(np.isfinite(X_all), axis=(0, 1))

    # Flatten spatial mask for vector indexing
    mask_X = valid_X.ravel()

    # Reshape predictors → (time, features, pixels)
    # then keep only valid pixels
    Xtr = X_train.reshape(T_train, n_features, NX)[:, :, mask_X]
    Xq  = X_query.reshape(T_query, n_features, NX)[:, :, mask_X]

    # Flatten feature + pixel dimensions into single vector
    # final shapes:
    #   Xtr_vec = (T_train, n_valid_features)
    #   Xq_vec  = (T_query, n_valid_features)
    Xtr_vec = Xtr.reshape(T_train, -1)
    Xq_vec  = Xq.reshape(T_query,  -1)

    # ==========================================================
    # Y MASKING — remove pixels invalid in training target
    # ==========================================================

    # Valid target pixels across ALL training time
    valid_Y = np.all(np.isfinite(y_train), axis=0)

    # Flatten spatial mask
    mask_Y = valid_Y.ravel()

    # Reshape y_train → (time, pixels) and keep valid ones
    ytr = y_train.reshape(T_train, NY)[:, mask_Y]

    # ==========================================================
    # NEAREST NEIGHBOUR SEARCH
    # ==========================================================

    # Build KNN search structure
    nn = NearestNeighbors(
        n_neighbors=k_neighbors,
        metric="euclidean",
        n_jobs=-1              # use all CPU cores
    )

    # Fit on training predictors
    nn.fit(Xtr_vec)

    # Find k nearest analogues for each query timestep
    distances, indices = nn.kneighbors(Xq_vec)

    # ==========================================================
    # DISTANCE → WEIGHTS
    # ==========================================================

    # Convert distances to inverse-distance weights
    # (small distance → large weight)
    weights = 1.0 / (distances + 1e-12)

    # Normalize weights so each row sums to 1
    weights /= weights.sum(axis=1, keepdims=True)

    # ==========================================================
    # ANALOGUE PREDICTIONS
    # ==========================================================

    # Dimensions
    Tq, k = indices.shape

    # Container for analogue target values
    # shape = (query_time, k_neighbors, valid_pixels)
    y_pred = np.empty((Tq, k, ytr.shape[1]))

    # Gather analogue fields
    for i in range(Tq):
        # indices[i] → indices of k nearest training dates
        y_pred[i] = ytr[indices[i]]

    # ==========================================================
    # WEIGHTED MEAN PREDICTION
    # ==========================================================

    # Weighted average across k analogues
    y_mean_flat = np.sum(
        weights[:, :, None] * y_pred,
        axis=1
    )

    # ==========================================================
    # UNCERTAINTY ESTIMATION
    # ==========================================================

    # Weighted variance across analogues
    y_var_flat = np.sum(
        weights[:, :, None] * (y_pred - y_mean_flat[:, None, :])**2,
        axis=1
    )

    # Standard deviation = uncertainty estimate
    y_std_flat = np.sqrt(y_var_flat)

    # ==========================================================
    # RESTORE FULL SPATIAL GRID (reinsert masked pixels)
    # ==========================================================

    # Initialize full grids with NaNs
    y_mean = np.full((Tq, NY), np.nan)
    y_std  = np.full((Tq, NY), np.nan)

    # Fill valid pixels only
    y_mean[:, mask_Y] = y_mean_flat
    y_std[:,  mask_Y] = y_std_flat

    # Reshape back to spatial maps
    y_mean = y_mean.reshape(Tq, n_latY, n_lonY)
    y_std  = y_std.reshape(Tq, n_latY, n_lonY)

    return y_mean, y_std
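
The neighbour search inside `knn_predict_with_uncertainty` can be tried in isolation on random toy data. This sketch assumes 100 training days, 6 features and a 4×5 grid (all invented, with hypothetical names such as `X_train_toy`), and plants an exact copy of training day 42 as the first query to show that it is returned as its own closest analogue.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)

# Toy predictor cubes: (days, features, lat, lon); 6 features = 2 variables x 3 lags
X_train_toy = rng.normal(size=(100, 6, 4, 5))
X_query_toy = rng.normal(size=(7, 6, 4, 5))
X_query_toy[0] = X_train_toy[42]  # plant an exact copy of training day 42

# Flatten features + pixels into one vector per day
Xtr_vec = X_train_toy.reshape(X_train_toy.shape[0], -1)  # (100, 120)
Xq_vec = X_query_toy.reshape(X_query_toy.shape[0], -1)   # (7, 120)

nn = NearestNeighbors(n_neighbors=5, metric="euclidean").fit(Xtr_vec)
distances, indices = nn.kneighbors(Xq_vec)

# One row per query day, one column per analogue, ordered from closest to farthest
print(distances.shape, indices.shape, indices[0, 0])
```

Because every pixel of every feature becomes one coordinate of the vector, predictors with larger numerical ranges weigh more in the Euclidean distance; `knn_predict_with_uncertainty` does not rescale precipitation and temperature before the search.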

### Visualizing Observed vs Predicted Maps

In [None]:
def plot_maps(y_obs, y_mean, y_std, dates, lon, lat, indices):
    """
    Plot observed ET, predicted ET, error, and uncertainty maps.

    Parameters
    ----------
    y_obs : array (T, lat, lon)
        Observed / reference ET (use None for production runs).
    y_mean : array (T, lat, lon)
        Mean reconstructed ET.
    y_std : array (T, lat, lon)
        Uncertainty (analogue spread).
    dates : array-like (T,)
        Datetime array.
    lon, lat : 1D or 2D arrays
        Spatial coordinates.
    indices : list of int
        Time indices to visualise.
    """

    # ==========================================================
    # PREPARE SPATIAL GRID
    # ==========================================================
    # If coordinates are 1D vectors, convert to 2D meshgrid
    # so they match the shape required by pcolormesh.
    if lon.ndim == 1 and lat.ndim == 1:
        Lon, Lat = np.meshgrid(lon, lat)
    else:
        # Already 2D coordinate grids
        Lon, Lat = lon, lat

    # ==========================================================
    # LOOP THROUGH REQUESTED TIME INDICES
    # ==========================================================
    for idx in indices:

        # Extract prediction and uncertainty for this timestep
        y_pred = y_mean[idx]
        y_unc  = y_std[idx]

        # Convert date to readable format
        date = np.array(dates[idx]).astype("datetime64[D]")

        # ------------------------------------------------------
        # IF REFERENCE DATA EXISTS → compute error + shared scale
        # ------------------------------------------------------
        if y_obs is not None:

            # Reference map
            y_ref = y_obs[idx]

            # Prediction error map
            y_err = y_pred - y_ref

            # Use same color scale for observed + predicted
            vmin  = np.nanmin([y_ref, y_pred])
            vmax  = np.nanmax([y_ref, y_pred])

            # Error scale based on robust percentile
            vmax_err = np.nanpercentile(np.abs(y_err), 95)

        # ------------------------------------------------------
        # IF NO REFERENCE DATA (production mode)
        # ------------------------------------------------------
        else:
            y_ref = None
            y_err = None

            # Scale only based on prediction
            vmin  = np.nanmin(y_pred)
            vmax  = np.nanmax(y_pred)

        # Robust scale for uncertainty
        vmax_std = np.nanpercentile(y_unc, 95)

        # ------------------------------------------------------
        # DETERMINE NUMBER OF PANELS
        # ------------------------------------------------------
        # If reference exists → show 4 panels
        # Otherwise → only prediction + uncertainty
        ncols = 4 if y_obs is not None else 2

        # Create figure
        fig, axes = plt.subplots(
            1, ncols,
            figsize=(4.5 * ncols, 4),
            constrained_layout=True
        )

        # Column pointer for flexible plotting
        col = 0

        # ======================================================
        # PANEL 1 — OBSERVED MAP
        # ======================================================
        if y_obs is not None:

            im0 = axes[col].pcolormesh(
                Lon, Lat, y_ref,
                shading="auto",
                vmin=vmin, vmax=vmax
            )

            axes[col].set_title(f"Observed ET\n{date}")
            axes[col].set_aspect("equal")

            # Colorbar
            c0 = plt.colorbar(im0, ax=axes[col])
            c0.set_label("[mm/day]")

            col += 1

        # ======================================================
        # PANEL 2 — PREDICTED MAP
        # ======================================================
        im1 = axes[col].pcolormesh(
            Lon, Lat, y_pred,
            shading="auto",
            vmin=vmin, vmax=vmax
        )

        # Title depends on mode
        if y_obs is not None:
            axes[col].set_title("Predicted ET")
        else:
            axes[col].set_title(f"Predicted ET\n{date}")

        axes[col].set_aspect("equal")

        # Colorbar
        c1 = plt.colorbar(im1, ax=axes[col])
        c1.set_label("[mm/day]")

        col += 1

        # ======================================================
        # PANEL 3 — ERROR MAP (only if reference exists)
        # ======================================================
        if y_obs is not None:

            im2 = axes[col].pcolormesh(
                Lon, Lat, y_err,
                shading="auto",
                vmin=-vmax_err, vmax=vmax_err,
                cmap="coolwarm"   # diverging colormap for errors
            )

            axes[col].set_title("Error (Pred − Ref)")
            axes[col].set_aspect("equal")

            # Colorbar
            c2 = plt.colorbar(im2, ax=axes[col])
            c2.set_label("[mm/day]")

            col += 1

        # ======================================================
        # FINAL PANEL — UNCERTAINTY MAP
        # ======================================================
        im3 = axes[col].pcolormesh(
            Lon, Lat, y_unc,
            shading="auto",
            vmin=0, vmax=vmax_std,
            cmap="magma"      # sequential colormap for uncertainty
        )

        axes[col].set_title("Uncertainty (analogue spread)")
        axes[col].set_aspect("equal")

        # Colorbar
        c3 = plt.colorbar(im3, ax=axes[col])
        c3.set_label("[mm/day]")

        # ======================================================
        # DISPLAY FIGURE
        # ======================================================
        plt.show()
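
`plot_maps` sets the error colour scale from the 95th percentile of the absolute error rather than its maximum. The toy map below (invented values) has uniform ±0.2 mm/day errors, one 5 mm/day outlier and one NaN pixel, and shows why: the percentile keeps the scale at 0.2, whereas the maximum would wash out every other pixel.

```python
import numpy as np

# Toy 10x10 error map: alternating +/-0.2 mm/day, one outlier, one NaN pixel
checker = np.indices((10, 10)).sum(axis=0) % 2 == 0
y_err = np.where(checker, 0.2, -0.2)
y_err[0, 0] = 5.0      # single outlier pixel
y_err[9, 9] = np.nan   # e.g. a pixel outside the basin

# Symmetric limits from the 95th percentile of |error|, NaNs ignored
vmax_err = np.nanpercentile(np.abs(y_err), 95)
vmin_err = -vmax_err
print(vmin_err, vmax_err, np.nanmax(np.abs(y_err)))
```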

### Quantitative evaluation using RMSE

Here, we compute the Root Mean Squared Error (RMSE):
- $\mathrm{RMSE} = \sqrt{\frac{1}{N}\sum_{i=1}^{N}\left(y_{\mathrm{obs},i} - y_{\mathrm{pred},i}\right)^2}$, where $N$ is the number of valid (non-NaN) pixels.
- It is computed spatially for each time step.
- It can also be aggregated over the entire test period.
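
The formula can be checked on a tiny invented example. This sketch computes the per-step RMSE in one vectorised call (equivalent to the loop in `compute_rmse`) and, as one possible aggregation, a pooled RMSE over all valid pixels and time steps; `obs_toy` and `pred_toy` are hypothetical names.

```python
import numpy as np

# Toy observed and predicted ET cubes: 2 time steps on a 2x2 grid [mm/day]
obs_toy = np.array([[[1.0, 2.0], [3.0, np.nan]],
                    [[2.0, 2.0], [2.0, 2.0]]])
pred_toy = np.array([[[1.0, 2.0], [5.0, 1.0]],
                     [[3.0, 1.0], [3.0, 1.0]]])

# Spatial RMSE per time step; NaN pixels are excluded, so N = number of valid pixels
rmse_t = np.sqrt(np.nanmean((obs_toy - pred_toy) ** 2, axis=(1, 2)))

# Pooled RMSE over all valid pixels and all time steps
rmse_all = np.sqrt(np.nanmean((obs_toy - pred_toy) ** 2))
print(rmse_t, rmse_all)
```

The per-step values are √(4/3) ≈ 1.15 and 1.0 mm/day. The pooled RMSE (√(8/7) ≈ 1.07) is not the average of the per-step values, because it weights every valid pixel equally across time.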

In [None]:
def compute_rmse(y_obs, y_pred, dates, max_gap_days=1):
    """
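    Compute RMSE over space for each time step and split the time axis at date gaps.

    Parameters
    ----------
    y_obs, y_pred : array (T, lat, lon)
        Observed and predicted ET.
    dates : array-like (T,)
        Datetime array matching the first axis of y_obs / y_pred.
    max_gap_days : int
        Largest step (in days) between consecutive dates still treated as continuous.

    Returns
    -------
    rmse : (T,)
        Spatial RMSE at each time step (NaN pixels ignored).
    segments : list of index arrays for continuous date segments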
    """

    # ==========================================================
    # INITIALIZATION
    # ==========================================================
    # Number of time steps
    T = y_obs.shape[0]

    # Array to store RMSE at each timestep
    rmse = np.zeros(T)

    # ==========================================================
    # COMPUTE SPATIAL RMSE PER TIME STEP
    # ==========================================================
    for t in range(T):

        # Difference map between observation and prediction
        diff = y_obs[t] - y_pred[t]

        # RMSE over spatial domain (ignores NaNs)
        rmse[t] = np.sqrt(np.nanmean(diff**2))

    # ==========================================================
    # DETECT TEMPORAL GAPS IN THE DATE SERIES
    # ==========================================================
    # Convert dates to numpy array for vectorized operations
    dates = np.asarray(dates)

    # Compute day differences between consecutive dates
    gaps = np.diff(dates).astype("timedelta64[D]").astype(int)

    # Identify where gaps exceed allowed threshold
    breaks = np.where(gaps > max_gap_days)[0]

    # ==========================================================
    # SPLIT TIME SERIES INTO CONTINUOUS SEGMENTS
    # ==========================================================
    segments = []

    # Start index of current segment
    start = 0

    # Loop over detected breaks
    for b in breaks:

        # Segment runs from current start → break index
        segments.append(np.arange(start, b + 1))

        # Next segment starts after the break
        start = b + 1

    # Add final segment after last break
    segments.append(np.arange(start, T))

    # ==========================================================
    # RETURN RESULTS
    # ==========================================================
    return rmse, segments
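
The segment logic of `compute_rmse` can be reproduced on a short invented date series spanning two test periods. The segments are presumably used so that an RMSE time series can be plotted without a line bridging the gap between periods. This sketch uses `np.split`, which cuts after each break index and yields the same segments as the loop in the function.

```python
import numpy as np

# Toy test dates from two separate periods (shortened for illustration)
toy_dates = np.concatenate([
    np.arange("2018-12-29", "2019-01-01", dtype="datetime64[D]"),  # 3 days
    np.arange("2020-06-01", "2020-06-03", dtype="datetime64[D]"),  # 2 days
])

max_gap_days = 1
gaps = np.diff(toy_dates).astype("timedelta64[D]").astype(int)
breaks = np.where(gaps > max_gap_days)[0]

# Cut after each break index to get the continuous segments
segments = np.split(np.arange(toy_dates.size), breaks + 1)
print([s.tolist() for s in segments])
```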

## 2.1 - Why we switch to pre-downloaded datasets

In the previous section, we demonstrated how to download data directly from Google Earth Engine using a short time window.

**Important practical note:** Downloading large spatial domains or long time periods from GEE can be **very slow** and sometimes unstable in a workshop setting.

For this reason, we now switch to pre-downloaded datasets stored on a shared Google Drive. This allows us to:
- Work with a much larger temporal dataset
- Ensure that all participants use identical data
- Focus on the methodology rather than data access

## 2.2 - Pre-downloaded datasets used in this workshop

In this second part, we use the following datasets:

**Target**:
- GLEAM Actual Evapotranspiration (ET)

GLEAM, the target, has a spatial resolution of 0.25° and is available daily from 1980 to 2020. It provides a long, continuous reference dataset, which makes quantitative validation of the analogue reconstructions possible.

**Predictors**:
- ERA5-Land precipitation (PRE)
- ERA5-Land maximum daily temperature (TMAX)

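The ERA5-Land reanalysis dataset provides data at a spatial resolution of 0.1° from 1950 to the present; here it is used as daily maps. Predictors and target are therefore on different grids. This is not a problem for the analogue method: predictor maps are only compared with other predictor maps, and ET maps are only combined with other ET maps.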

Using ERA5-Land predictors together with GLEAM ET allows us to:
- Build analogues using a long climatological archive
- Reconstruct ET for selected days
- Directly compare reconstructed ET maps against the reference GLEAM ET, enabling objective performance assessment

In [None]:
url_et   = "https://drive.google.com/file/d/1QKJCwt44LdsQNpNBIfFA4-kFs2MTaJRD/view?usp=drive_link"
url_pre  = "https://drive.google.com/file/d/1ZyRFuYSFle5zqPpyoSBXLiyIp5frSCWk/view?usp=drive_link"
url_tmax = "https://drive.google.com/file/d/1Ru0eKa_DrvjA2x_F1MF7EPixge2VXIJh/view?usp=drive_link"

gdown.download(url_et,   output="et.nc",   fuzzy=True)
gdown.download(url_pre,  output="pre.nc",  fuzzy=True)
gdown.download(url_tmax, output="tmax.nc", fuzzy=True)

et   = xr.open_dataset("et.nc")
pre  = xr.open_dataset("pre.nc")
tmax = xr.open_dataset("tmax.nc")

We can interrogate `pre` to see how the dataset is constructed.

In [None]:
pre

We can also visualise the daily precipitation summed over all pixels of the region of interest.

In [None]:
tot_pre = pre["pre"].sum(dim=("longitude", "latitude"))
tot_pre.plot()
plt.show()
del tot_pre
gc.collect();
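
If pixels outside the basin are stored as NaN (as is typical for clipped datasets), xarray reductions on float data skip them by default (`skipna=True`). The toy cube below (invented values, hypothetical name `toy_pre`) compares the pixel sum with the spatial mean: the sum grows with the number of pixels, while the mean gives an average depth in mm/day. Neither is area-weighted.

```python
import numpy as np
import xarray as xr

# Toy 2-day precipitation cube on a 2x2 grid; NaN marks a pixel outside the basin mask
toy_pre = xr.DataArray(
    np.array([[[1.0, 2.0], [np.nan, 3.0]],
              [[0.0, 0.0], [np.nan, 4.0]]]),
    dims=("time", "latitude", "longitude"),
)

# NaNs are skipped, so only the 3 valid pixels contribute
tot = toy_pre.sum(dim=("longitude", "latitude"))
avg = toy_pre.mean(dim=("longitude", "latitude"))
print(tot.values, avg.values)
```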

Or the time series of a single pixel.

In [None]:
ts = pre["pre"].isel(longitude=50, latitude=50)  # longitude and latitude selection based on pixel number, not actual degrees.
ts.plot()
plt.show()
del ts
gc.collect();

## 2.3 - Temporal alignment of predictors and target

We first restrict all datasets to a common time window (here 2000–2020).
This reduces the amount of data loaded in memory while retaining enough temporal variability for analogue selection.

Even after subsetting, small differences in time coordinates may remain (e.g. missing days or different calendar handling). To ensure perfect temporal consistency, precipitation and temperature are explicitly aligned to the ET time axis.
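
With `.sel(time=et.time)`, a date present in ET but missing from a predictor raises a `KeyError` instead of being dropped silently. If you would rather keep only the dates common to all series, one alternative is `xr.align` with an inner join. The sketch below uses invented toy series (`toy_et`, `toy_pre` are hypothetical names) in which the predictor lacks 2000-01-03 and has an extra day.

```python
import numpy as np
import xarray as xr

# Toy daily series: the predictor lacks 2000-01-03 but has an extra day, 2000-01-06
t_et = np.arange("2000-01-01", "2000-01-06", dtype="datetime64[D]").astype("datetime64[ns]")
t_pre = np.array(["2000-01-01", "2000-01-02", "2000-01-04", "2000-01-05", "2000-01-06"],
                 dtype="datetime64[ns]")
toy_et = xr.DataArray(np.arange(5.0), dims=("time",), coords={"time": t_et})
toy_pre = xr.DataArray(10.0 * np.arange(5.0), dims=("time",), coords={"time": t_pre})

# toy_pre.sel(time=toy_et.time) would raise a KeyError here (2000-01-03 is absent).
# An inner join keeps only the dates present in both series.
et_al, pre_al = xr.align(toy_et, toy_pre, join="inner")
print(et_al.time.values, et_al.values, pre_al.values)
```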

In [None]:
t_start = "2000-01-01"
t_end   = "2020-12-31"

et   =   et.sel(time=slice(t_start, t_end))
pre  =  pre.sel(time=slice(t_start, t_end))
tmax = tmax.sel(time=slice(t_start, t_end))

pre  =  pre.sel(time=et.time)
tmax = tmax.sel(time=et.time)

# use the same names for spatial coordinates
pre  =  pre.rename({'longitude': 'lon', 'latitude': 'lat'})
tmax = tmax.rename({'longitude': 'lon', 'latitude': 'lat'})

Inspecting `pre` again shows that the time axis now covers only 2000–2020 and that the spatial coordinates have been renamed to `lon` and `lat`.

In [None]:
pre

## 2.5 - Preparing the data
Now we call the function to prepare our datasets. Here we define the `time_window` length (number of lag days) and the `test_periods` (date ranges held out for validation):

In [None]:
X_train, y_train, X_test, y_test, dates_train, dates_test = prepare_data(
    et_ds         = et,
    pre_ds        = pre,
    tmax_ds       = tmax,
    time_window   = 2,    # <------ NUMBER OF LAG DAYS IN CLIMATE WINDOW (0 = current day only): here t, t-1, t-2
    test_periods  = [
        ("2018-01-01", "2018-12-31"),     # <------ VALIDATION DATES
        ("2020-06-01", "2020-08-31")      # <------ ADDITIONAL VALIDATION DATES
        ]
    )
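
To see how many days the default `test_periods` hold out, the masking logic of `prepare_data` can be replayed on a plain daily date array. This sketch assumes an invented 2017–2020 date range (`toy_time` is a hypothetical name) and does not include the lag trimming that `prepare_data` applies at the start of the record.

```python
import numpy as np

# Daily dates covering 2017-2020 (end date is exclusive in np.arange)
toy_time = np.arange("2017-01-01", "2021-01-01", dtype="datetime64[D]")
test_periods = [("2018-01-01", "2018-12-31"), ("2020-06-01", "2020-08-31")]

# Same logic as prepare_data: OR together one inclusive date range per period
test_mask = np.zeros(toy_time.size, dtype=bool)
for start, end in test_periods:
    test_mask |= (toy_time >= np.datetime64(start)) & (toy_time <= np.datetime64(end))

# 365 days of 2018 + 92 days of June-August 2020 are held out
print(test_mask.sum(), (~test_mask).sum())
```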

## 2.6 - Analogue selection for a single day (illustrative example)

Before applying kNN to the full testing period, we first illustrate the analogue-based reconstruction for one specific day.

This example helps to explicitly visualize:
- The predictor fields for the target day
- The dates of the selected analogues
- The ET maps of these analogue days
- The reconstructed ET map
- The reference (observed) ET map for comparison

<div style="display:flex; gap:10px;">

<img src="https://raw.githubusercontent.com/LoicGerber/synthetic_data_generation_workshop/52acd5850ad2311637ecf5704099c15e947f975e/knn.png" width="75%">

</div>

### 2.6.1 Selecting a target day

We select one day from the test period and treat its ET map as _missing_.
Only the predictors (precipitation and temperature) of that day and of the preceding `time_window` days are used to find analogues.

Conceptually (with `time_window = 2`):
- Input: PRE(t), PRE(t-1), PRE(t-2), TMAX(t), TMAX(t-1), TMAX(t-2)
- Output: reconstructed ET(t)

In [None]:
# --- Select one target day from the test set ---
target_idx = 0            # <--------------------------------- CHANGE TO ANY INDEX IN [0, len(dates_test)-1]

target_date = dates_test[target_idx]
print(f"Target reconstruction date: {np.datetime_as_string(target_date, unit='D')}")

# Extract predictors and reference ET for that day
X_target     = X_test[target_idx:target_idx+1]  # shape (1, n_features, lat, lon)
y_target_ref = y_test[target_idx]               # shape (lat, lon)

### 2.6.2 Visualizing the predictors for the selected day

For the selected day and its two preceding days, we first visualize the predictor fields:
- Precipitation
- Maximum temperature

This step allows us to physically interpret what _similarity_ means in the analogue space and ensures that the climate situation of the target day is well understood before reconstruction.

In [None]:
# --- Extract predictor fields ---
# Feature order (assumes time_window = 2):
# [pre(t-2), pre(t-1), pre(t), tmax(t-2), tmax(t-1), tmax(t)]

pre_t2, pre_t1, pre_t0    = X_target[0, 0], X_target[0, 1], X_target[0, 2]
tmax_t2, tmax_t1, tmax_t0 = X_target[0, 3], X_target[0, 4], X_target[0, 5]

fig, axes = plt.subplots(3, 2, figsize=(8, 10), constrained_layout=True)

plots = [
    (pre_t0,  "Precipitation (t)",   0, 0, "Precipitation [mm/day]"),
    (pre_t1,  "Precipitation (t-1)", 1, 0, "Precipitation [mm/day]"),
    (pre_t2,  "Precipitation (t-2)", 2, 0, "Precipitation [mm/day]"),
    (tmax_t0, "Tmax (t)",            0, 1, "Tmax [°C]"),
    (tmax_t1, "Tmax (t-1)",          1, 1, "Tmax [°C]"),
    (tmax_t2, "Tmax (t-2)",          2, 1, "Tmax [°C]"),
]

for data, title, row, col, cbar_label in plots:
    ax = axes[row, col]
    im = ax.imshow(data, origin="upper")
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])

    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label(cbar_label)

plt.show()

### 2.6.3 Selecting analogue days using kNN

Using the kNN algorithm:
- Each day in the training period is represented as a single high-dimensional vector: its predictor maps (all valid grid cells, variables, and time lags) are flattened, i.e. collapsed into one dimension.
- The similarity between two days is measured by the Euclidean distance between their vectors.
- The k nearest neighbors correspond to the most similar climate situations, i.e. the analogues.

For transparency, we explicitly list:
- The dates of the selected analogue days
- Their distance-based ranking

This step emphasizes that kNN is used purely for analogue selection, not as a black-box regression model.
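To make the distance computation concrete, here is a small NumPy-only sketch on random toy arrays (all sizes are illustrative assumptions). It mirrors the masking and flattening used in the next cell, but replaces `pairwise_distances` with an explicit broadcast, which yields the same Euclidean distances.

```python
import numpy as np

rng = np.random.default_rng(1)
T_train, n_features, n_lat, n_lon = 50, 6, 4, 5

# Toy predictor cubes; one grid cell is set to NaN to mimic a masked area
X_train = rng.random((T_train, n_features, n_lat, n_lon))
X_target = rng.random((1, n_features, n_lat, n_lon))
X_train[:, :, 0, 0] = np.nan

# Keep only grid cells that are finite for every day and every feature
valid = (np.all(np.isfinite(X_train), axis=(0, 1))
         & np.all(np.isfinite(X_target), axis=(0, 1)))

# Flatten (features x valid cells) into one vector per day
train_vec = X_train[:, :, valid].reshape(T_train, -1)
target_vec = X_target[:, :, valid].reshape(1, -1)

# Euclidean distance from the target day to every training day
distances = np.sqrt(((train_vec - target_vec) ** 2).sum(axis=1))

k = 5
analogue_idx = np.argsort(distances)[:k]   # indices of the k closest days
print(train_vec.shape, analogue_idx)
```

Masking before flattening matters: a single NaN cell would otherwise make every distance NaN.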

In [None]:
# --- Flatten predictors exactly as in knn_predict ---
T_train, n_features, n_lat, n_lon = X_train.shape
NX = n_lat * n_lon

# Mask valid X space (same logic as knn_predict)
X_all = np.concatenate([X_train, X_test], axis=0)
valid_X_space = np.all(np.isfinite(X_all), axis=(0, 1))
mask_X_flat = valid_X_space.ravel()

X_train_flat = X_train.reshape(T_train, n_features, NX)[:, :, mask_X_flat]
X_target_flat = X_target.reshape(1, n_features, NX)[:, :, mask_X_flat]

X_train_vec = X_train_flat.reshape(T_train, -1)
X_target_vec = X_target_flat.reshape(1, -1)

# --- Compute distances to all training days ---
distances = pairwise_distances(X_train_vec, X_target_vec, metric="euclidean").ravel()

k = 5           # <-------------------------------------------------------------------- NUMBER OF ANALOGUES TO DISPLAY
analogue_idx = np.argsort(distances)[:k]

analogue_dates = dates_train[analogue_idx]

print("Selected analogue dates:")
for i, d in zip(analogue_idx, analogue_dates):
    print(f"  {np.datetime_as_string(d, unit='D')}  | distance = {distances[i]:.3f}")

# --- FREE RAM ---
del X_all, X_train_flat, X_target_flat, X_train_vec, X_target_vec, valid_X_space, mask_X_flat
gc.collect();

### 2.6.4 Visualizing ET maps of the selected analogues

Next, we plot the ET maps corresponding to the selected analogue dates.

This allows us to verify that:
- The analogue ET patterns are physically plausible
- Spatial structures are consistent across analogue days
- Differences between analogues remain visible, since this spread is what the analogue-based uncertainty measures

In [None]:
fig, axes = plt.subplots(1, k, figsize=(4*k, 4), constrained_layout=True)

# Make sure axes is iterable when k = 1
if k == 1:
    axes = [axes]

for ax, idx, date in zip(axes, analogue_idx, analogue_dates):
    im = ax.imshow(y_train[idx], origin="upper")
    ax.set_title(np.datetime_as_string(date, unit='D'))
    ax.set_xticks([])
    ax.set_yticks([])
    cbar0 = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar0.set_label("ET [mm/day]")

plt.suptitle("ET maps of selected analogue days", fontsize=14)
plt.show()

### 2.6.5 Reconstructing ET for the target day

The ET map for the target day is reconstructed as an inverse-distance-weighted mean of the analogue ET maps, so analogues closer to the target day in predictor space receive larger weights.

In addition to the reconstructed ET, we compute an _analogue-based uncertainty_, defined as the weighted standard deviation of the analogue ET maps around the reconstructed mean. This quantity reflects the level of agreement among the selected analogues.

We then plot:
- The reference (observed) ET map
- The reconstructed ET map
- The reconstruction error map (prediction − reference)
- The analogue-based uncertainty map

This visual comparison shows how the analogue method transfers spatial patterns from historical situations to the target day. In the error map, positive values (red) indicate that the reconstruction overestimates ET, while negative values (blue) indicate underestimation. The uncertainty map highlights regions where the analogues disagree, indicating lower confidence in the reconstruction.
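Written out, with $d_j$ the Euclidean distance between the target day and analogue $j$ ($j = 1, \dots, k$) and $y_j(x)$ the ET of that analogue at grid cell $x$, the weights, reconstruction, and uncertainty computed in the next cell are:

$$
w_j = \frac{1/d_j}{\sum_{l=1}^{k} 1/d_l}, \qquad
\hat{y}(x) = \sum_{j=1}^{k} w_j \, y_j(x), \qquad
\sigma(x) = \sqrt{\sum_{j=1}^{k} w_j \left(y_j(x) - \hat{y}(x)\right)^2}
$$

Since the weights sum to one, $\hat{y}$ is a weighted mean and $\sigma$ the corresponding weighted standard deviation. In the code, a distance of exactly zero is increased by machine epsilon before inverting.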


In [None]:
# --- Distance-based weights (inverse-distance weighting) ---
eps = np.finfo(np.float64).eps
d = distances[analogue_idx].copy()
d[d == 0.0] += eps

weights = 1.0 / d
weights /= weights.sum()

# --- Reconstruct ET ---
y_reconstructed = np.zeros_like(y_target_ref)
for w, idx in zip(weights, analogue_idx):
    y_reconstructed += w * y_train[idx]

# --- Error map ---
error = y_reconstructed - y_target_ref
err_abs_max = np.nanmax(np.abs(error))

# --- Analogue-based uncertainty (weighted spread) ---
y_var = np.zeros_like(y_target_ref)

for w, idx in zip(weights, analogue_idx):
    y_var += w * (y_train[idx] - y_reconstructed) ** 2

y_uncertainty = np.sqrt(y_var)

# --- Plot reference, reconstruction, error, and uncertainty ---
vmin = np.nanmin([y_reconstructed, y_target_ref])
vmax = np.nanmax([y_reconstructed, y_target_ref])

fig, axes = plt.subplots(1, 4, figsize=(20, 4), constrained_layout=True)

# Reference ET
im0 = axes[0].imshow(y_target_ref, origin="upper", vmin=vmin, vmax=vmax)
axes[0].set_title("Reference ET")
axes[0].set_xticks([])
axes[0].set_yticks([])
cbar0 = plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)
cbar0.set_label("ET [mm/day]")

# Reconstructed ET
im1 = axes[1].imshow(y_reconstructed, origin="upper", vmin=vmin, vmax=vmax)
axes[1].set_title("Reconstructed ET")
axes[1].set_xticks([])
axes[1].set_yticks([])
cbar1 = plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
cbar1.set_label("ET [mm/day]")

# Error map
im2 = axes[2].imshow(
    error,
    origin="upper",
    cmap="coolwarm",
    vmin=-err_abs_max,
    vmax=err_abs_max,
)
axes[2].set_title("Reconstruction Error (Pred − Ref)")
axes[2].set_xticks([])
axes[2].set_yticks([])
cbar2 = plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)
cbar2.set_label("Error [mm/day]")

# Analogue-based uncertainty
im3 = axes[3].imshow(
    y_uncertainty,
    origin="upper",
    cmap="magma"
)
axes[3].set_title("Uncertainty (analogue spread)")
axes[3].set_xticks([])
axes[3].set_yticks([])
cbar3 = plt.colorbar(im3, ax=axes[3], fraction=0.046, pad=0.04)
cbar3.set_label("Uncertainty [mm/day]")

plt.show()

### 2.6.6 Quantitative evaluation for the single day

To complement visual inspection, we compute the spatial RMSE between:
- The reconstructed ET map
- The reference ET map

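$\mathrm{RMSE} = \sqrt{\frac{1}{N}\sum_{i=1}^{N}\left(y_{\mathrm{obs},i} - y_{\mathrm{pred},i}\right)^2}$

where $i$ indexes the valid (non-NaN) grid cells and $N$ is their number.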

This provides a single quantitative score summarizing reconstruction accuracy for the example day.

In [None]:
# --- RMSE for the single reconstructed day ---
rmse_single = np.sqrt(
    np.nanmean((y_target_ref - y_reconstructed) ** 2)
)

print(
    f"RMSE for {np.datetime_as_string(target_date, unit='D')}: "
    f"{rmse_single:.3f} mm/day"
)

## 2.7 - kNN prediction for the full test period (2018 and June to August 2020)

After illustrating the analogue method for a single day, we now apply the exact same approach to the entire test period.

The procedure is identical:
- For each test day, kNN identifies climate analogues in the training set.
- ET is reconstructed from the analogue ET maps.
- Predictions are generated independently for each day.

In addition to the mean ET prediction, we also compute an analogue-based uncertainty for each grid cell and day:
- For each query day, the `k_neighbors` closest analogues are weighted by similarity (inverse distance).
- The predicted ET is the weighted mean of these analogues.
- The "uncertainty" is the weighted standard deviation of the analogue ET values around the mean.

> **Note:** This uncertainty reflects the spread of the analogue ensemble, not a formal statistical confidence interval.  
> It highlights regions or days where the selected analogues are less consistent (larger spread = higher analogue-based uncertainty).
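The three bullet points above can be condensed into a few vectorised NumPy lines. The sketch below is hypothetical: `knn_idw_sketch` is not the notebook's `knn_predict_with_uncertainty`, whose implementation may differ. It assumes predictors and ET maps have already been flattened and are NaN-free, and it computes all query-to-training distances at once, which is only practical for small arrays.

```python
import numpy as np

def knn_idw_sketch(X_train, y_train, X_query, k_neighbors=20):
    """Hypothetical analogue predictor: IDW mean and weighted std of k neighbours.

    X_train: (T_train, n_features), y_train: (T_train, n_cells),
    X_query: (T_query, n_features). Returns two arrays of shape (T_query, n_cells).
    """
    # Pairwise Euclidean distances, shape (T_query, T_train)
    d = np.sqrt(((X_query[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=-1))
    # Indices and distances of the k closest training days, shape (T_query, k)
    idx = np.argsort(d, axis=1)[:, :k_neighbors]
    d_k = np.take_along_axis(d, idx, axis=1)
    # Clamp distances to at least machine epsilon (avoids division by zero)
    d_k = np.maximum(d_k, np.finfo(float).eps)
    # Inverse-distance weights, normalised to sum to one for each query day
    w = 1.0 / d_k
    w /= w.sum(axis=1, keepdims=True)
    # Analogue ET values, shape (T_query, k, n_cells)
    y_k = y_train[idx]
    # Weighted mean and weighted standard deviation across the k analogues
    mean = np.einsum("qk,qkc->qc", w, y_k)
    var = np.einsum("qk,qkc->qc", w, (y_k - mean[:, None, :]) ** 2)
    return mean, np.sqrt(var)

# Usage on random toy data
rng = np.random.default_rng(0)
X_tr, y_tr = rng.random((100, 8)), rng.random((100, 30))
X_q = rng.random((5, 8))
mean, std = knn_idw_sketch(X_tr, y_tr, X_q, k_neighbors=20)
print(mean.shape, std.shape)   # (5, 30) (5, 30)
```

With `k_neighbors=1` the prediction is simply the nearest analogue and the spread is zero, which is why a single analogue gives no uncertainty information.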

Here, we call `knn_predict_with_uncertainty` to run the kNN with *k = 20*.

In [None]:
y_test_mean, y_test_std = knn_predict_with_uncertainty(
    X_train     = X_train,
    y_train     = y_train,
    X_query     = X_test,
    k_neighbors = 20
)

## 2.8 - Visualizing Observed vs Predicted Maps

To inspect model performance, we plot the observed ET, predicted ET, prediction error, and analogue-based uncertainty side by side for selected test days.

In [None]:
target_var = list(et.data_vars)[0]
et_da      = et[target_var]
lat        = et_da["lat"].values
lon        = et_da["lon"].values

indices = [0, 50, 100] # <---------------------------------- CHOOSE DAYS TO VISUALISE

plot_maps(
    y_obs   = y_test,
    y_mean  = y_test_mean,
    y_std   = y_test_std,
    dates   = dates_test,
    lon     = lon,
    lat     = lat,
    indices = indices
)

## 2.9 - Quantitative evaluation using RMSE

While visual inspection is useful, quantitative metrics provide a more rigorous assessment.

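Here, we compute the root mean square error (RMSE):

$\mathrm{RMSE} = \sqrt{\frac{1}{N}\sum_{i=1}^{N}\left(y_{\mathrm{obs},i} - y_{\mathrm{pred},i}\right)^2}$

The RMSE is computed spatially for each time step, with $i$ indexing the $N$ valid grid cells, which yields one value per day. These daily values are then averaged over each test period (dashed red line in the plots).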
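A minimal sketch of both steps, assuming `(time, lat, lon)` arrays with NaN outside the basin: `daily_rmse` applies the formula over valid cells for each day, and the hypothetical helper `contiguous_segments` splits the test dates into continuous periods (such as 2018 and summer 2020). The notebook's `compute_rmse` may segment the dates differently.

```python
import numpy as np

def daily_rmse(y_obs, y_pred):
    """Spatial RMSE per time step over finite grid cells; inputs have shape (T, lat, lon)."""
    return np.sqrt(np.nanmean((y_obs - y_pred) ** 2, axis=(1, 2)))

def contiguous_segments(dates):
    """Hypothetical helper: split sorted daily dates wherever a gap > 1 day occurs."""
    gaps = np.diff(dates) > np.timedelta64(1, "D")
    breaks = np.flatnonzero(gaps) + 1
    return np.split(np.arange(len(dates)), breaks)

# Toy example: two test periods separated by a gap, with one masked (NaN) cell
dates = np.concatenate([
    np.arange("2018-12-29", "2019-01-01", dtype="datetime64[D]"),
    np.arange("2020-06-01", "2020-06-03", dtype="datetime64[D]"),
])
y_obs = np.ones((5, 2, 2))
y_obs[:, 0, 0] = np.nan
y_pred = y_obs + 0.5

rmse = daily_rmse(y_obs, y_pred)
segments = contiguous_segments(dates)
print(rmse, [len(s) for s in segments])   # [0.5 0.5 0.5 0.5 0.5] [3, 2]
```

The returned segments are index arrays, so they can be used directly as `rmse[seg]` and `dates[seg]`.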

In [None]:
rmse_test, segments = compute_rmse(y_test, y_test_mean, dates_test)

for seg in segments:
    mean_rmse = np.nanmean(rmse_test[seg])
    year = str(dates_test[seg][0])[:4]
    plt.figure(figsize=(10,4))
    plt.plot(dates_test[seg], rmse_test[seg], '-')
    plt.axhline(mean_rmse, color='red', linestyle='--', label=f"Mean RMSE = {mean_rmse:.2f}")
    plt.legend()
    plt.ylim(0)
    plt.ylabel("RMSE [mm/day]")
    plt.xlabel("Date")
    plt.title(f"RMSE - {year}")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

## 2.10 - Sensitivity to the number of neighbours (*k*)

Up to this point, the k-nearest neighbours (kNN) model has been evaluated using a fixed number of neighbours.  
However, the choice of *k* directly controls the trade-off between local similarity and smoothing:

- Small *k*
  - Strongly local analogues  
  - Potentially sharper spatial patterns  
  - Higher sensitivity to noise and outliers  

- Large *k*
  - More spatial and temporal smoothing  
  - Reduced variance and noise  
  - Possible loss of fine-scale extremes  

To better understand this behaviour, we now **vary the number of neighbours (*k*)** and analyse its impact on:

1. **Reconstructed evapotranspiration (ET) maps**
2. **Associated uncertainty**, estimated from the analogue spread
3. **Temporal RMSE** over the validation period

For each value of *k*, we:
- Reconstruct ET for the validation period
- Compute uncertainty maps from the ensemble of selected analogues
- Visually compare observed, predicted, error, and uncertainty fields
- Evaluate prediction skill using the RMSE time series
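The same trade-off can be explored on toy data outside the notebook's own functions. The sketch below swaps in scikit-learn's `KNeighborsRegressor` with `weights="distance"`, which computes an inverse-distance-weighted mean of the neighbours like the analogue reconstruction, but does not return the analogue spread. The synthetic data and the list of *k* values are illustrative assumptions, so the printed numbers say nothing about the Volta results.

```python
import numpy as np
from sklearn.neighbors import KNeighborsRegressor

rng = np.random.default_rng(42)
n_train, n_test, n_features, n_cells = 300, 60, 6, 50

# Toy "climate -> ET" mapping: smooth nonlinear signal, noisy in the training data
W = rng.normal(size=(n_features, n_cells))
X_tr = rng.normal(size=(n_train, n_features))
X_te = rng.normal(size=(n_test, n_features))
y_tr = np.tanh(X_tr @ W) + 0.3 * rng.normal(size=(n_train, n_cells))
y_te = np.tanh(X_te @ W)

mean_rmse = {}
for k in (1, 5, 20, 50):
    # weights="distance": inverse-distance weighted mean of the k neighbours
    model = KNeighborsRegressor(n_neighbors=k, weights="distance").fit(X_tr, y_tr)
    y_hat = model.predict(X_te)                              # (n_test, n_cells)
    rmse_per_day = np.sqrt(np.mean((y_te - y_hat) ** 2, axis=1))
    mean_rmse[k] = rmse_per_day.mean()
    print(f"k = {k:2d}  mean daily RMSE = {mean_rmse[k]:.3f}")
```

With the notebook's own functions, the equivalent loop would call `knn_predict_with_uncertainty` and `compute_rmse` once per value of *k*.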

In [None]:
y_test_mean, y_test_std = knn_predict_with_uncertainty(
    X_train     = X_train,
    y_train     = y_train,
    X_query     = X_test,
    k_neighbors = 5             # <--------------------------------- CHANGE K HERE
)

indices = [0, 50, 100]          # <--------------------------------- CHANGE MAPS TO VISUALISE HERE

plot_maps(
    y_obs   = y_test,
    y_mean  = y_test_mean,
    y_std   = y_test_std,
    dates   = dates_test,
    lon     = lon,
    lat     = lat,
    indices = indices
)

rmse_test, segments = compute_rmse(y_test, y_test_mean, dates_test)

for seg in segments:
    mean_rmse = np.nanmean(rmse_test[seg])
    year = str(dates_test[seg][0])[:4]
    plt.figure(figsize=(10,4))
    plt.plot(dates_test[seg], rmse_test[seg], '-')
    plt.axhline(mean_rmse, color='red', linestyle='--', label=f"Mean RMSE = {mean_rmse:.2f}")
    plt.legend()
    plt.ylim(0)
    plt.ylabel("RMSE [mm/day]")
    plt.xlabel("Date")
    plt.title(f"RMSE - {year}")
    plt.grid(True)
    plt.tight_layout()
    plt.show()