# 0.1.4: Build GBIF trait maps

The final step before training models on Earth observation (EO) data is to link the TRY trait data with the GBIF species observations and then grid the result. This gives us trait rasters that align with our EO data.

## Imports and config

In [42]:
from pathlib import Path

import dask.dataframe as dd
from dask.distributed import Client, LocalCluster
import dask.config as dask_config
import dask_geopandas as dgpd
import geopandas as gpd
import numpy as np
import pandas as pd
import rioxarray as riox
from src.conf.conf import get_config
from src.conf.environment import log
from src.utils.raster_utils import create_sample_raster, xr_to_raster

%load_ext autoreload
%autoreload 2

# Display all columns when printing a pandas DataFrame
pd.set_option("display.max_columns", None)

cfg = get_config()

Due to the size of the GBIF data, we're going to need to use Dask in order to keep memory usage low as well as to parallelize the merging and spatial gridding operations. The settings below (`n_workers` and `memory_limit`, in particular) are specific to the machine being used during this exercise.

In [43]:
cluster = LocalCluster(n_workers=50, memory_limit="24GB", dashboard_address=":39143")
client = Client(cluster)
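
If you are adapting these settings to a different machine, one rough approach is to reserve some memory for the operating system and split the rest evenly among the workers. The helper below is a hypothetical sketch: the function name, the `reserve_frac` headroom, and the one-thread-per-worker default are assumptions for illustration and are not part of this project. It only computes the arguments and leaves cluster creation to you.

```python
def suggest_cluster_settings(
    total_mem_gib: float,
    n_cores: int,
    threads_per_worker: int = 1,
    reserve_frac: float = 0.25,
) -> dict:
    """Hypothetical heuristic: split usable memory evenly across workers."""
    if not 0 <= reserve_frac < 1:
        raise ValueError("reserve_frac must be in [0, 1)")

    # One worker per group of `threads_per_worker` cores, but never fewer than one
    n_workers = max(1, n_cores // threads_per_worker)

    # Keep a fraction of RAM free for the OS and the notebook kernel itself
    usable_gib = total_mem_gib * (1 - reserve_frac)
    mem_per_worker = usable_gib / n_workers

    return {
        "n_workers": n_workers,
        "threads_per_worker": threads_per_worker,
        "memory_limit": f"{mem_per_worker:.2f}GiB",
    }


# Example for an assumed machine with 64 GiB of RAM and 16 cores
settings = suggest_cluster_settings(total_mem_gib=64, n_cores=16, threads_per_worker=2)
print(settings)
```

The resulting dictionary uses keyword names that `LocalCluster` accepts, so it could be passed along as `LocalCluster(**settings)`. Treat the numbers as a starting point and adjust them after watching the Dask dashboard.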

## Load GBIF and filter by PFT

Let's load the GBIF data, select all three PFTs, and set species name as the index to make merging with the TRY data faster.

**Note:** Processing the entire GBIF dataset, as done below, may be infeasible for some machines. If this is the case, simply select a single PFT for the `filter_pft` call, and also consider using `DataFrame.sample(frac=<fraction of the data>)` to use only a subsample of the data.

In [44]:
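def filter_pft(df: pd.DataFrame, pft_set: str, pft_col: str = "pft") -> pd.DataFrame:
    # `pft_set` is an underscore-delimited list of PFTs, e.g. "Shrub_Tree_Grass"
    pfts = pft_set.split("_")
    # Reject the whole set if any single entry is not a recognized PFT
    if not all(pft in ["Shrub", "Tree", "Grass"] for pft in pfts):
        raise ValueError(f"Invalid PFT designation: {pft_set}")

    return df[df[pft_col].isin(pfts)]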

npartitions = 90

gbif = (
    dd.read_parquet(Path(cfg.gbif.interim.dir, cfg.gbif.interim.subsampled))
    .repartition(npartitions=npartitions)
    .pipe(filter_pft, "Shrub_Tree_Grass")
    .set_index("speciesname")
)
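
To make the note above concrete, here is a minimal sketch of the reduced variant (a single PFT plus a random subsample) run on a toy pandas DataFrame. The toy rows and the compact stand-in for `filter_pft` are illustrative assumptions. With Dask the same two calls apply, since `dd.DataFrame.sample` also takes `frac`.

```python
import pandas as pd

VALID_PFTS = {"Shrub", "Tree", "Grass"}


def filter_pft(df: pd.DataFrame, pft_set: str, pft_col: str = "pft") -> pd.DataFrame:
    """Compact stand-in for the notebook's `filter_pft`, for illustration only."""
    pfts = pft_set.split("_")
    if not set(pfts) <= VALID_PFTS:
        raise ValueError(f"Invalid PFT designation: {pft_set}")
    return df[df[pft_col].isin(pfts)]


# Toy occurrence table (made-up rows, not real GBIF records)
toy = pd.DataFrame(
    {
        "speciesname": ["sp_a", "sp_b", "sp_c", "sp_d", "sp_e", "sp_f"],
        "pft": ["Tree", "Tree", "Shrub", "Grass", "Tree", "Tree"],
    }
)

# Keep only trees, then draw a reproducible 50% subsample of what remains
subset = filter_pft(toy, "Tree").sample(frac=0.5, random_state=42)
print(subset)
```

Filtering before sampling keeps the subsample fraction relative to the PFT of interest rather than to the full dataset.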

## Load TRY filtered mean trait data

In [45]:
mn_traits = (
    dd.read_parquet(Path(cfg.trydb.interim.dir, cfg.trydb.interim.filtered))
    .repartition(npartitions=npartitions)
    .set_index("speciesname")
)

## Link mean trait values with GBIF data

Because we set species name as the index on both DataFrames, we can simply perform an inner join, called on the GBIF DataFrame, to merge the traits with the citizen science species occurrences.

In [46]:
merged = gbif.join(mn_traits, how="inner").reset_index().drop(columns=["pft"])

In [None]:
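# Share of GBIF species that have a match in the TRY trait data. `merged` had its
# index reset above, so species names now live in the "speciesname" column.
n_matched = merged["speciesname"].nunique().compute()
n_gbif = gbif.index.nunique().compute()
print(f"Pct matched species: {n_matched / n_gbif:.2%}")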
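
The behavior of the index-based inner join and the match rate can be checked on a small pandas example. This is only a sketch: the species rows and trait values below are invented for illustration and are not taken from GBIF or TRY.

```python
import pandas as pd

# Toy occurrences: one row per observation, indexed by species name
occurrences = pd.DataFrame(
    {
        "speciesname": ["Quercus robur", "Quercus robur", "Poa annua", "Salix alba"],
        "decimallatitude": [50.1, 50.3, 48.9, 51.2],
        "decimallongitude": [8.6, 8.7, 2.3, 4.4],
    }
).set_index("speciesname")

# Toy mean traits: one row per species (values are placeholders)
traits = pd.DataFrame(
    {
        "speciesname": ["Quercus robur", "Poa annua", "Fagus sylvatica"],
        "X4": [0.56, 0.30, 0.59],
    }
).set_index("speciesname")

# Inner join on the shared index: species missing from either side are dropped
merged_toy = occurrences.join(traits, how="inner").reset_index()

# Fraction of occurrence species that found a trait match
n_matched = merged_toy["speciesname"].nunique()
n_total = occurrences.index.nunique()
match_rate = n_matched / n_total
print(f"Pct matched species: {match_rate:.2%}")
```

Each occurrence row inherits its species' mean trait value, so a species observed many times contributes many identical trait values to the grid cells it falls in.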

## Rasterize merged trait values

### Grid the matched trait data

In [47]:
def global_grid_df(
    df: dd.DataFrame,
    col: str,
    lon: str = "decimallongitude",
    lat: str = "decimallatitude",
    res: int | float = 0.5,
) -> dd.DataFrame:
    """
    Group and aggregate a DataFrame by latitude and longitude coordinates to create a
    gridded DataFrame.

    Parameters:
        df (dd.DataFrame): The input DataFrame.
        col (str): The column to aggregate.
        lon (str, optional): The column name for longitude coordinates. Defaults to
            "decimallongitude".
        lat (str, optional): The column name for latitude coordinates. Defaults to
            "decimallatitude".
        res (int | float, optional): The resolution of the grid. Defaults to 0.5.

    Returns:
        dd.DataFrame: The gridded DataFrame with aggregated statistics.

    """
    stat_funcs = [
        "mean",
        "std",
        "median",
        lambda x: x.quantile(0.05, interpolation="nearest"),
        lambda x: x.quantile(0.95, interpolation="nearest"),
    ]
    
    stat_names = ["mean", "std", "median", "q05", "q95"]

    # Calculate the bin for each row directly. This may be very slightly less accurate
    # than creating x and y bins and using `pd.cut`, but it has the benefit of being
    # significantly more performant.
    df["y"] = (df[lat] + 90) // res * res - 90 + res / 2
    df["x"] = (df[lon] + 180) // res * res - 180 + res / 2

    gridded_df = (
        df.drop(columns=[lat, lon])
        .groupby(["y", "x"], observed=False)[[col]]
        .agg(stat_funcs)
    )

    gridded_df.columns = stat_names

    return gridded_df

In [74]:
def grid_df_to_raster(df: pd.DataFrame, res: int | float, name: str) -> None:
    """
    Converts a grid DataFrame to a raster file.

    Args:
        df (pd.DataFrame): The grid DataFrame to convert.
        res (int | float): The resolution of the raster.
        name (str): The name of the raster file.

    Returns:
        None
    """
    rast = create_sample_raster(resolution=res)
    ds = df.to_xarray()
    ds = ds.rio.write_crs(rast.rio.crs)
    ds = ds.rio.reproject_match(rast)

    for var in ds.data_vars:
        nodata = ds[var].attrs["_FillValue"]
        ds[var] = ds[var].where(ds[var] != nodata, np.nan)
        ds[var] = ds[var].rio.write_nodata(-32767.0, encoded=True)
    
    ds.attrs["long_name"] = list(ds.data_vars)
    ds.attrs["trait"] = name

    xr_to_raster(ds, f"{name}.tif")

Now let's grid the data for the first trait, "X4" (stem specific density).

`global_grid_df` assigns each point observation to the centroid of its grid cell (at the desired resolution), and then calculates the mean, standard deviation, median, and 5th and 95th percentiles of the trait values in each grid cell.
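
As a small worked example of the binning arithmetic, the sketch below applies the same floor-division expressions used in `global_grid_df` to three made-up points at a 0.5° resolution. It uses plain pandas and only two statistics to keep the output easy to read.

```python
import pandas as pd

res = 0.5

# Three invented observations; the first two share a grid cell
obs = pd.DataFrame(
    {
        "decimallatitude": [10.26, 10.40, 10.74],
        "decimallongitude": [-73.9, -73.6, -73.9],
        "X4": [1.0, 3.0, 5.0],
    }
)

# Shift to a non-negative range, floor to the cell's lower edge, shift back,
# then add half a cell to land on the cell centroid
obs["y"] = (obs["decimallatitude"] + 90) // res * res - 90 + res / 2
obs["x"] = (obs["decimallongitude"] + 180) // res * res - 180 + res / 2

# Aggregate the trait values that fall in the same cell
gridded = obs.groupby(["y", "x"])["X4"].agg(["mean", "count"])
print(gridded)
```

The points at latitudes 10.26 and 10.40 both map to the centroid 10.25, while 10.74 maps to 10.75, so the result has two cells.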

In [48]:
cols = [col for col in merged.columns if col.startswith("X")]
grid_data = global_grid_df(merged, cols[0], res=0.01).compute()



Next, let's fill a raster with the gridded data and save it to file.

Note that in `grid_df_to_raster` we first generate a reference raster at the desired resolution and then match the `xarray.Dataset` created from our gridded DataFrame to that reference raster. This matters because small floating-point differences can leave the rasterized DataFrame's coordinates subtly different from those of our EO predictor data. If we match all training data to the same reference raster, we can ensure that the coordinates agree exactly.
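
The idea of aligning to a reference grid can be illustrated with NumPy alone. The sketch below is not what `reproject_match` does internally. It simply builds an assumed set of reference cell centers, computes centroids with the binning formula, and snaps each centroid to the reference value at the nearest index, so that both sources end up with identical coordinate values.

```python
import numpy as np

res = 0.01

# Assumed reference grid: longitude cell centers built from integer indices
n_cols = int(round(360 / res))
ref_x = -180 + res / 2 + np.arange(n_cols) * res

# Centroids computed with the floor-division formula; these may differ from
# the reference values in the last few floating-point digits
lon = np.array([-73.987, 8.682, 151.209])
binned_x = (lon + 180) // res * res - 180 + res / 2

# Snap to the reference grid by converting each centroid to a column index
idx = np.rint((binned_x - ref_x[0]) / res).astype(int)
snapped_x = ref_x[idx]

print(idx)
print(np.max(np.abs(binned_x - snapped_x)))
```

After snapping, every coordinate is an element of `ref_x`, so exact equality comparisons and coordinate-based joins against other data on the same reference grid behave predictably.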

In [76]:
grid_df_to_raster(grid_data, 0.01, "X4")

And lastly, let's shut down our Dask cluster.

In [41]:
client.close()
cluster.close()