This notebook is not used operationally or for validation; its only purpose is to give a clear understanding of the core functions of the AA workflow. The outputs and dimensions of each main step can thus be inspected here.

**Import required libraries and functions**

In [1]:
%cd ..

c:\Users\amine.barkaoui\OneDrive - World Food Programme\Documents\GitHub\anticipatory-action


In [2]:
import os
import datetime
import pandas as pd

from config.params import Params

from AA.helper_fns import (
    read_forecasts,
    read_observations,
    aggregate_by_district,
    merge_un_biased_probs,
    merge_probabilities_triggers_dashboard,
)

from hip.analysis.analyses.drought import (
    get_accumulation_periods,
    run_accumulation_index,
    run_gamma_standardization,
    run_bias_correction,
    compute_probabilities,
)

from hip.analysis.aoi.analysis_area import AnalysisArea

**Define parameters**

The `config/{country}_config.yaml` file gathers all the customizable parameters used by the operational script. For example, the *monitoring_year*, the list of districts, or the intensity levels can be defined in that file.

In [3]:
# Load the configuration for a single use case: country ISO3 code, forecast issue month and drought index
params = Params(iso='ZWE', issue=5, index='SPI')

**Read shapefile**

In [4]:
# Define aoi to read datasets using hip-analysis
# The time range starts on 1981-01-01 and ends on 30 June of the year following params.monitoring_year
area = AnalysisArea.from_admin_boundaries(
    iso3=params.iso.upper(),
    admin_level=2,
    resolution=0.25,  # NOTE(review): presumably the grid spacing in degrees — confirm
    datetime_range=f"1981-01-01/{params.monitoring_year + 1}-06-30",
)

# Read the shapefile (base area dataset of the aoi); it is reused later for the district aggregation
gdf = area.get_dataset([area.BASE_AREA_DATASET])
gdf

Unnamed: 0_level_0,geometry,Code,Name,adm1_Code,adm0_Code
Name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Bulawayo,"POLYGON ((28.6712 -20.0163, 28.6072 -19.9615, ...",1010745,Bulawayo,900969,271
Chitungwiza,"POLYGON ((31 -17.9994, 31.0626 -18.0506, 31.11...",1010758,Chitungwiza,900970,271
Epworth,"POLYGON ((31.2031 -17.8863, 31.1303 -17.8654, ...",1010760,Epworth,900970,271
Harare,"POLYGON ((31.2031 -17.8863, 31.2195 -17.8512, ...",1010771,Harare,900970,271
Harare Rural,"POLYGON ((31.1361 -17.9289, 31.125 -17.888, 31...",1010772,Harare Rural,900970,271
...,...,...,...,...,...
Redcliff,"POLYGON ((29.8419 -19.025, 29.8195 -19.0142, 2...",1010810,Redcliff,900978,271
Shurugwi,"POLYGON ((30.4711 -19.8304, 30.4674 -19.8254, ...",1010817,Shurugwi,900978,271
Shurugwi Town,"POLYGON ((30.0317 -19.6013, 29.9984 -19.607, 3...",1010818,Shurugwi Town,900978,271
Zvishavane,"POLYGON ((30.4609 -20.5564, 30.4438 -20.528, 3...",1010826,Zvishavane,900978,271


**Read forecasts**

In [5]:
# When update is set to False, the downscaled dataset is read from a local folder or a s3 bucket. Otherwise, it is directly read from HDC.
# NOTE(review): the year folder in the path is hard-coded to 2022, whereas the observations path uses params.calibration_year — confirm this is intended
forecasts = read_forecasts(
    area,
    params.issue,
    f"{params.data_path}/data/{params.iso}/zarr/2022/{str(params.issue).zfill(2)}/forecasts.zarr",
    update=False,  # switch to True to read directly from HDC
)
forecasts

Unnamed: 0,Array,Chunk
Bytes,1.65 GiB,33.19 MiB
Shape,"(9416, 51, 28, 33)","(9416, 1, 28, 33)"
Dask graph,51 chunks in 1 graph layer,51 chunks in 1 graph layer
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.65 GiB 33.19 MiB Shape (9416, 51, 28, 33) (9416, 1, 28, 33) Dask graph 51 chunks in 1 graph layer Data type float32 numpy.ndarray",9416  1  33  28  51,

Unnamed: 0,Array,Chunk
Bytes,1.65 GiB,33.19 MiB
Shape,"(9416, 51, 28, 33)","(9416, 1, 28, 33)"
Dask graph,51 chunks in 1 graph layer,51 chunks in 1 graph layer
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,257.47 kiB,257.47 kiB
Shape,"(9416,)","(9416,)"
Dask graph,1 chunks in 1 graph layer,1 chunks in 1 graph layer
Data type,,
"Array Chunk Bytes 257.47 kiB 257.47 kiB Shape (9416,) (9416,) Dask graph 1 chunks in 1 graph layer Data type",9416  1,

Unnamed: 0,Array,Chunk
Bytes,257.47 kiB,257.47 kiB
Shape,"(9416,)","(9416,)"
Dask graph,1 chunks in 1 graph layer,1 chunks in 1 graph layer
Data type,,


**Read observations**

In [6]:
# Read the observations from the zarr store saved under the params.calibration_year folder
# (already stored, as the dataset is the same as the one used in the pre-season/analytical script)
observations = read_observations(
    area,
    f"{params.data_path}/data/{params.iso}/zarr/{params.calibration_year}/obs/observations.zarr",
)
observations

Unnamed: 0,Array,Chunk
Bytes,106.84 MiB,7.22 kiB
Shape,"(15156, 28, 33)","(1, 28, 33)"
Dask graph,15156 chunks in 1 graph layer,15156 chunks in 1 graph layer
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 106.84 MiB 7.22 kiB Shape (15156, 28, 33) (1, 28, 33) Dask graph 15156 chunks in 1 graph layer Data type float64 numpy.ndarray",33  28  15156,

Unnamed: 0,Array,Chunk
Bytes,106.84 MiB,7.22 kiB
Shape,"(15156, 28, 33)","(1, 28, 33)"
Dask graph,15156 chunks in 1 graph layer,15156 chunks in 1 graph layer
Data type,float64 numpy.ndarray,float64 numpy.ndarray


**Read pre-computed triggers**

Now that we have all the data we need, let's read the triggers file so that we can merge the probabilities with it once they are computed.

In [7]:
# Read triggers file
# The merged probabilities/triggers file is preferred when it exists on disk
# (presumably written by a previous run — confirm); otherwise fall back to the
# pre-computed triggers of the calibration year.
# The path is resolved once so that the existence check and the read cannot drift apart.
triggers_path = f"{params.data_path}/data/{params.iso}/probs/aa_probabilities_triggers_pilots.csv"
if not os.path.exists(triggers_path):
    triggers_path = f"{params.data_path}/data/{params.iso}/triggers/triggers.spi.dryspell.{params.calibration_year}.pilots.csv"

triggers_df = pd.read_csv(triggers_path)
triggers_df

Unnamed: 0,district,index,category,window,issue_ready,issue_set,trigger_ready,trigger_set,vulnerability,prob_ready,prob_set
0,Beitbridge,DRYSPELL JF,Normal,Window 2,7.0,8.0,0.14,0.30,NRT,,
1,Beitbridge,SPI DJF,Normal,Window 2,7.0,8.0,0.38,0.00,NRT,,
2,Beitbridge,SPI FM,Normal,Window 2,9.0,10.0,0.31,0.35,NRT,,
3,Beitbridge,SPI JFM,Normal,Window 2,10.0,11.0,0.00,0.37,NRT,,
4,Beitbridge,SPI ND,Normal,Window 1,6.0,7.0,0.35,0.11,NRT,0.24,
...,...,...,...,...,...,...,...,...,...,...,...
73,Rushinga,SPI DJ,Normal,Window 2,9.0,10.0,0.23,0.37,NRT,,
74,Rushinga,SPI FM,Normal,Window 2,11.0,12.0,0.34,0.27,NRT,,
75,Rushinga,SPI JF,Normal,Window 2,8.0,9.0,0.08,0.34,NRT,,
76,Rushinga,SPI ND,Normal,Window 1,6.0,7.0,0.30,0.26,NRT,0.29,


**Get accumulation periods covered by the forecasts of the defined issue month**

In [8]:
# Get accumulation periods (DJ, JF, FM, DJF, JFM...)
# The result is a dictionary mapping each period name to its months (see the output below);
# it is built from the forecasts and the season / index-period bounds defined in params
accumulation_periods = get_accumulation_periods(
    forecasts,
    params.start_season,
    params.end_season,
    params.min_index_period,
    params.max_index_period,
)
accumulation_periods

{'MJ': (5, 6), 'ON': (10, 11)}

Here we focus on the pipeline for one indicator (one period) so we select a single element from the above dictionary (November-December using October forecasts).

In [9]:
# Get single use case
# Index 1 picks the second entry of the dictionary above; change the index to study another period
# (the valid range depends on how many periods the issue month covers; 4 was a previously used value)
period_name, period_months = list(accumulation_periods.items())[1]  # position in accumulation_periods
period_name, period_months

('ON', (10, 11))

**Run accumulation (sum for SPI)**

In [10]:
# Remove 1980 season to harmonize observations between different indexes
# Only applied when the issue month is at or after params.end_season:
# observations dated before 1981-10-01 are dropped
if int(params.issue) >= params.end_season:
    observations = observations.where(
        observations.time.dt.date >= datetime.date(1981, 10, 1), drop=True
    )

In [11]:
# Accumulation
# Both inputs are rechunked into a single chunk along time (time=-1) and accumulated
# over period_months with the aggregation defined in params.aggregate
# forecasts=True is only passed for the forecast dataset
accumulation_fc = run_accumulation_index(
    forecasts.chunk(dict(time=-1)), params.aggregate, period_months, forecasts=True
)
accumulation_obs = run_accumulation_index(
    observations.chunk(dict(time=-1)), params.aggregate, period_months
)

In [12]:
accumulation_fc

Unnamed: 0,Array,Chunk
Bytes,7.91 MiB,3.61 kiB
Shape,"(44, 51, 28, 33)","(1, 1, 28, 33)"
Dask graph,2244 chunks in 183 graph layers,2244 chunks in 183 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 7.91 MiB 3.61 kiB Shape (44, 51, 28, 33) (1, 1, 28, 33) Dask graph 2244 chunks in 183 graph layers Data type float32 numpy.ndarray",44  1  33  28  51,

Unnamed: 0,Array,Chunk
Bytes,7.91 MiB,3.61 kiB
Shape,"(44, 51, 28, 33)","(1, 1, 28, 33)"
Dask graph,2244 chunks in 183 graph layers,2244 chunks in 183 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [13]:
accumulation_obs

Unnamed: 0,Array,Chunk
Bytes,295.97 kiB,7.22 kiB
Shape,"(41, 28, 33)","(1, 28, 33)"
Dask graph,41 chunks in 170 graph layers,41 chunks in 170 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 295.97 kiB 7.22 kiB Shape (41, 28, 33) (1, 28, 33) Dask graph 41 chunks in 170 graph layers Data type float64 numpy.ndarray",33  28  41,

Unnamed: 0,Array,Chunk
Bytes,295.97 kiB,7.22 kiB
Shape,"(41, 28, 33)","(1, 28, 33)"
Dask graph,41 chunks in 170 graph layers,41 chunks in 170 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


**Run standardization (SPI)**

In [14]:
# Remove inconsistent observations
# Keep accumulated observations dated from 1979-01-01 through 31 December of the year
# preceding params.monitoring_year
accumulation_obs = accumulation_obs.sel(
    time=slice(datetime.date(1979, 1, 1), datetime.date(params.monitoring_year - 1, 12, 31))
)

In [15]:
# Anomaly
# Standardize the accumulated forecasts and observations
# .load() brings the lazy arrays into memory before the standardization
# NOTE(review): hist_anomaly_start / hist_anomaly_stop are presumably the bounds of the
# historical reference period, and members=True presumably signals an ensemble dimension — confirm
anomaly_fc = run_gamma_standardization(
    accumulation_fc.load(),
    params.hist_anomaly_start,
    params.hist_anomaly_stop,
    members=True,
)
anomaly_obs = run_gamma_standardization(
    accumulation_obs.load(),
    params.hist_anomaly_start,
    params.hist_anomaly_stop,
)

  rfh["time"] = [


In [16]:
anomaly_fc

In [17]:
anomaly_obs

**Run bias correction**

In [18]:
import numpy as np
import pandas as pd
import xarray as xr
from hdc.algo.dekad import Dekad
from numba import jit
from odc.geo.xr import xr_reproject
from xclim import sdba

from hip.analysis.analyses.drought import ENSO_COLD, ENSO_WARM, ENSO_NEUTRAL

In [19]:
def _get_enso_years(year, issue, monitoring_start):
    """
    Return the group of ENSO years that the given year belongs to.

    The ENSO_NEUTRAL, ENSO_WARM and ENSO_COLD lists are checked in that order. When
    ``issue`` is lower than ``monitoring_start``, every list is shifted by one year
    before the membership test, and the shifted list is returned.

    Args:
        year: int, year to look up.
        issue: int, issue month of the forecasts.
        monitoring_start: int, month compared against ``issue`` to decide whether the
            ENSO lists are shifted by one year.

    Returns:
        np.ndarray: years of the matching ENSO group (shifted by +1 when
        ``issue < monitoring_start``).

    Raises:
        ValueError: if ``year`` does not belong to any of the ENSO groups. Returning
            nothing here would silently yield an empty reference sample downstream.
    """
    # TODO: this shift is meant to replicate the R outputs and will be removed in another branch
    shift = 0 if issue >= monitoring_start else 1

    for years in [ENSO_NEUTRAL, ENSO_WARM, ENSO_COLD]:
        candidate_years = np.array(years) + shift
        if year in candidate_years:
            return candidate_years

    raise ValueError(
        f"Year {year} (issue={issue}, monitoring_start={monitoring_start}) "
        "is not listed in ENSO_NEUTRAL, ENSO_WARM or ENSO_COLD."
    )


@jit(nopython=True)
def _get_nn(pos, size):
    """
    Returns a position on a 1D lattice together with its nearest neighbors.

    Args:
        pos: int, The position for which nearest neighbors are to be found.
        size: int, The size of the 1D lattice.

    Returns:
        list: ascending indices made of the position itself and its adjacent
        neighbors that lie inside the lattice.

    Notes:
        - If the position is at an edge of the lattice, the list holds the position and its single neighbor.
        - For all other positions, it holds the position itself and both adjacent neighbors.
        - On a lattice of size 1 there is no neighbor, so only the position is returned
          (instead of an out-of-bounds index).
    """
    nn = [pos]

    # The left neighbor exists unless the position sits on the lower edge
    if pos > 0:
        nn.insert(0, pos - 1)

    # The right neighbor exists unless the position sits on the upper edge
    if pos < size - 1:
        nn.append(pos + 1)

    return nn


def _create_time_coord(n: int):
    """
    Create a time coordinate for reformatted obs for bias correction.
    """
    start = datetime.datetime.strptime("01-01-1900", "%d-%m-%Y")
    date_generated = [start + datetime.timedelta(days=x) for x in range(0, n)]
    return date_generated


@jit(nopython=True)
def _stack_neighbors_in_time(observations_np, num_neighbors):
    """
    Stack the values of each pixel's neighbors along with the pixel itself in a single dimension.

    Args:
        observations_np: np.ndarray, 3D array of observations whose axes are unpacked
            as (latitude, longitude, time).
        num_neighbors: int, number of nearest neighbors to consider. It only sizes the
            last axis of the output (time_size * (num_neighbors + 1)).

    Returns:
        np.ndarray: Reformatted array of observations with neighbors stacked in time dimension,
        of shape (lat_size, lon_size, time_size * (num_neighbors + 1)). Positions that are not
        filled (pixels with fewer neighbors) keep their initial NaN value.

    Notes:
        NOTE(review): the neighborhood itself comes from _get_nn on each axis (up to 3 x 3
        pixels) and does not depend on num_neighbors, so the output buffer only matches an
        interior pixel's window when num_neighbors is 8 — confirm other values are not expected.
    """
    lat_size, lon_size, time_size = observations_np.shape

    # Initialize the NaN-filled array that stores the reformatted observations
    stacked_observation_array = np.full(
        (lat_size, lon_size, time_size * (num_neighbors + 1)), np.nan
    )

    # Iterate over each latitude and longitude index
    for lat in range(lat_size):
        for lon in range(lon_size):
            # Get the indices of the pixel and its nearest neighbors along each axis
            lat_indices = np.array(_get_nn(lat, lat_size))
            lon_indices = np.array(_get_nn(lon, lon_size))

            # Extract the sub-array (window) corresponding to the pixel and its nearest neighbors
            pixel_observations = observations_np[lat_indices][:, lon_indices]

            # Flatten the window time step by time step: for each t, all window pixels are
            # appended (latitude-major), so values of the same time step stay contiguous
            stacked_values = []
            for t in range(time_size):
                for i, _ in enumerate(lat_indices):
                    for j, _ in enumerate(lon_indices):
                        stacked_values.append(pixel_observations[i, j, t])
            # Write the stacked values at the start of the last axis; the remainder stays NaN
            stacked_observation_array[lat, lon, : len(stacked_values)] = stacked_values

    return stacked_observation_array


def _reformat_neighbors(da_observation, num_neighbors):
    """
    Reformat a DataArray of observations by stacking values of each pixel's neighbors in time dimension.

    Args:
        da_observation: xr.DataArray, DataArray containing observations, with
            "latitude", "longitude" and "time" dimensions (in any order).
        num_neighbors: int, number of nearest neighbors to consider.

    Returns:
        xr.DataArray: Reformatted DataArray with neighbors stacked in time dimension for bias correction.
        Its dimensions are (latitude, longitude, time), the latitude/longitude coordinates and the
        attributes are those of the input, and the time coordinate is a synthetic daily axis.
    """
    # The stacking helper unpacks the array shape as (latitude, longitude, time) and the output
    # below is labelled in that order, so enforce it instead of relying on the caller's layout
    ordered_observation = da_observation.transpose("latitude", "longitude", "time")

    stacked_observation_array = _stack_neighbors_in_time(
        ordered_observation.values, num_neighbors
    )

    reformatted_da_observation = xr.DataArray(
        stacked_observation_array,
        dims=["latitude", "longitude", "time"],
        coords=dict(
            latitude=("latitude", da_observation.latitude.values),
            longitude=("longitude", da_observation.longitude.values),
            # The stacked axis no longer maps to real dates: use a synthetic daily axis
            time=_create_time_coord(stacked_observation_array.shape[2]),
        ),
        attrs=da_observation.attrs,
    )

    return reformatted_da_observation

In [20]:
def run_bias_correction(
    forecasts: xr.DataArray,
    observations: xr.DataArray,
    start_monitoring: int,
    year: int = 2022,
    issue: int = None,
    nearest_neighbours: int = 0,
    enso: bool = True,
) -> xr.DataArray:
    """
    Calculate bias-corrected rainfall forecasts for year of interest using Empirical Quantile Mapping method

    Args:
        forecasts: xarray.DataArray, rainfall forecasts dataset
        observations: xarray.DataArray, rainfall observations dataset
        start_monitoring: int, month compared against the issue month when selecting the ENSO years
        year: int, year to bias correct
        issue: int, issue month of forecasts (needed when enso is True, as it is compared to start_monitoring)
        nearest_neighbours: int, number of nearest neighbouring pixels to consider in quantile mapping (only for observations)
        enso: bool, default True, if True then only years of the same enso type as the year to bias-correct are selected in the reference time series

    Returns:
        xarray.DataArray: bias-corrected rainfall forecasts for specified year

    Raises:
        ValueError: if enso is True and no ENSO group can be found for the requested year
    """

    # Get TS to bias correct
    year_to_bc = forecasts.where(forecasts.time.dt.year == year, drop=True)

    # Keep historical data: observations of all other years, and forecasts of the years
    # that remain in those observations
    hist_obs = observations.where(observations.time.dt.year != year, drop=True)
    hist_fc = forecasts.where(
        forecasts.time.dt.year.isin(hist_obs.time.dt.year.values), drop=True
    )

    # Mask ENSO years
    if enso:
        # Resolve the ENSO group once and apply the same selection to both series
        enso_years = _get_enso_years(year, issue, start_monitoring)
        if enso_years is None:
            # Without this guard, the selections below would be left empty
            raise ValueError(
                f"No ENSO group found for year {year} (issue={issue}, "
                f"start_monitoring={start_monitoring}); cannot build the reference series."
            )
        hist_fc = hist_fc.sel(time=hist_fc.time.dt.year.isin(enso_years))
        hist_obs = hist_obs.sel(time=hist_obs.time.dt.year.isin(enso_years))

    # If nearest_neighbours > 0, reformat obs to stack neighbours values in time dimension
    if nearest_neighbours > 0:
        hist_obs = _reformat_neighbors(hist_obs, nearest_neighbours)

    # Reformat time coord
    hist_fc = hist_fc.assign_coords({"time": _create_time_coord(hist_fc.time.size)})

    # Create units attr for QM function
    hist_fc.attrs["units"] = ""
    hist_obs.attrs["units"] = ""
    year_to_bc.attrs["units"] = ""

    # Fit QM model (quantiles 0, 0.1, ..., 1, a single group over the whole time axis)
    QM = sdba.EmpiricalQuantileMapping.train(
        hist_obs, hist_fc, nquantiles=np.arange(0, 1.1, 0.1), group="time"
    )

    # Bias correct forecasts of season of interest
    bc = QM.adjust(year_to_bc, interp="linear", extrapolation="constant")

    return bc

In [25]:
# Bias correction
# This calls the run_bias_correction defined earlier in this notebook, which shadows the
# function of the same name imported from hip.analysis.analyses.drought
# nearest_neighbours=8 stacks the observations of neighbouring pixels into the reference series
index_bc = run_bias_correction(
    anomaly_fc,
    anomaly_obs,
    start_monitoring=params.start_monitoring,
    year=params.monitoring_year,
    issue=int(params.issue),
    nearest_neighbours=8,
    enso=True,
)
display(index_bc)

**Run probabilities**

In [26]:
# Change dryspell sign as we compare values to a negative threshold to get probabilities
# NOTE: the three arrays are negated in place, so re-running this cell flips the sign back;
# run it only once after the bias correction
# NOTE(review): the comparison is case-sensitive against "dryspell" — confirm it matches the
# casing used when building Params (e.g. index='SPI')
if params.index == "dryspell":
    anomaly_fc *= -1
    index_bc *= -1
    anomaly_obs *= -1

In [27]:
# Probabilities without Bias Correction
# Only the forecasts of params.monitoring_year are kept; the probabilities are computed for
# the levels in params.intensity_thresholds and rounded to 2 decimals
probabilities = compute_probabilities(
    anomaly_fc.where(anomaly_fc.time.dt.year == params.monitoring_year, drop=True),
    levels=params.intensity_thresholds,
).round(2)
display(probabilities)

In [28]:
# Probabilities after Bias Correction
# index_bc comes from the bias correction run for params.monitoring_year; the same levels
# and 2-decimal rounding are used as for the non-corrected probabilities
probabilities_bc = compute_probabilities(
    index_bc, levels=params.intensity_thresholds
).round(2)
display(probabilities_bc)

**Admin-2 level aggregation**

In [40]:
# Aggregate by district
# The same aggregation (based on the boundaries in gdf) is applied to the probabilities
# computed without and with bias correction
probs_district = aggregate_by_district(probabilities, gdf, params)
probs_bc_district = aggregate_by_district(probabilities_bc, gdf, params)

# Build a single xarray object merging the probabilities obtained without and with bias
# correction for the selected period
probs_by_district = merge_un_biased_probs(
    probs_district, probs_bc_district, params, period_name
)
display(probs_by_district)

**Dataframe formatting**

In [41]:
# Merge probabilities with triggers
# Two tables are returned: probs_df (probabilities per district/category) and merged_df
# (the triggers table with its prob_ready / prob_set columns); both are displayed below
probs_df, merged_df = merge_probabilities_triggers_dashboard(
    probs_by_district, triggers_df, params, period_name
)

In [42]:
probs_df

Unnamed: 0,district,category,issue,index,prob,aggregation
0,Beitbridge,Moderate,10,SPI ND,0.30,SPI 2
1,Beitbridge,Normal,10,SPI ND,0.41,SPI 2
2,Bikita,Moderate,10,SPI ND,0.31,SPI 2
3,Bikita,Normal,10,SPI ND,0.43,SPI 2
4,Bindura,Moderate,10,SPI ND,0.33,SPI 2
...,...,...,...,...,...,...
125,Zaka,Normal,10,SPI ND,0.42,SPI 2
126,Zvimba,Moderate,10,SPI ND,0.31,SPI 2
127,Zvimba,Normal,10,SPI ND,0.42,SPI 2
128,Zvishavane,Moderate,10,SPI ND,0.32,SPI 2


In [45]:
merged_df

Unnamed: 0,district,index,category,window,issue_ready,issue_set,trigger_ready,trigger_set,vulnerability,prob_ready,prob_set
0,Beitbridge,DRYSPELL JF,Normal,Window 2,7.0,8.0,0.14,0.30,NRT,,
1,Beitbridge,SPI DJF,Normal,Window 2,7.0,8.0,0.38,0.00,NRT,,
2,Beitbridge,SPI FM,Normal,Window 2,9.0,10.0,0.31,0.35,NRT,,
3,Beitbridge,SPI JFM,Normal,Window 2,10.0,11.0,0.00,0.37,NRT,,
4,Beitbridge,SPI ND,Normal,Window 1,6.0,7.0,0.35,0.11,NRT,0.24,
...,...,...,...,...,...,...,...,...,...,...,...
73,Rushinga,SPI DJ,Normal,Window 2,9.0,10.0,0.23,0.37,NRT,,
74,Rushinga,SPI FM,Normal,Window 2,11.0,12.0,0.34,0.27,NRT,,
75,Rushinga,SPI JF,Normal,Window 2,8.0,9.0,0.08,0.34,NRT,,
76,Rushinga,SPI ND,Normal,Window 1,6.0,7.0,0.30,0.26,NRT,0.29,
