# 2.0.0: Trait map correlation with sPlot sparse grids

On a global scale, sPlot is one of the best benchmarks we have for evaluating the accuracy of trait extrapolations. A simple way to evaluate the quality of our models is to compute Pearson's correlation coefficient between the extrapolated trait values and the corresponding gridded sPlot trait values. However, because our rasters use geographic (latitude/longitude) coordinates, grid cells of equal angular size cover less area the closer they are to the poles, so each value should be weighted by its grid cell's actual area on Earth.

## Imports and config

In [1]:
from pathlib import Path
from typing import Iterable

import pandas as pd
import rioxarray as riox
import statsmodels.api as sm
import xarray as xr
from pyproj import Proj
from shapely.geometry import shape

from src.conf.conf import get_config
from src.conf.environment import log

# Project-wide configuration. The cells below read data directories, the PFT,
# and grid resolutions (e.g. cfg.model_res, cfg.target_resolution) from it.
cfg = get_config()

## Define latitude weights

In [11]:
def get_lat_area(lat: int | float, resolution: int | float) -> float:
    """Return the area, in km², of a square grid cell centred on latitude ``lat``.

    The cell spans ``resolution`` degrees in both longitude and latitude. It is
    placed at longitude 0 and projected into an Albers equal-area projection whose
    standard parallels are the cell's northern and southern edges. The shapely
    area of the projected polygon (m²) is then converted to km².
    """
    half = resolution / 2
    north = lat + half
    south = lat - half

    # Corners traced from the north-west corner around the cell; the first
    # corner is repeated at the end to close the ring.
    lons = [0, resolution, resolution, 0, 0]
    lats = [north, north, south, south, north]

    # Equal-area projection tailored to this cell: standard parallels on the
    # cell edges, origin at the cell centre.
    albers = Proj(
        f"+proj=aea +lat_1={north} +lat_2={south} "
        f"+lat_0={lat} +lon_0={half}"
    )
    xs, ys = albers(lons, lats)  # pylint: disable=unpacking-non-sequence

    ring = list(zip(xs, ys))
    area_m2 = shape({"type": "Polygon", "coordinates": [ring]}).area
    return area_m2 / 1000000


def lat_weights(lat_unique: Iterable[int | float], resolution: int | float) -> dict:
    """Calculate weights for each latitude band based on area of grid cells."""
    weights = {}

    for j in lat_unique:
        weights[j] = get_lat_area(j, resolution)

    # Normalize the weights by the maximum area
    max_area = max(weights.values())
    weights = {k: v / max_area for k, v in weights.items()}

    return weights

## Calculate weighted $r$

In [3]:
def weighted_pearson_r(df: pd.DataFrame, weights: dict) -> float:
    """Calculate the weighted Pearson correlation between the first two columns of ``df``.

    Each row is weighted by looking up the value of its ``"y"`` index level in
    ``weights``. The caller's DataFrame is left unmodified.

    Args:
        df: DataFrame whose first two columns hold the paired values to
            correlate, and whose index has a level named ``"y"``.
        weights: Mapping from every ``"y"`` value present in ``df`` to its weight.

    Returns:
        The weighted Pearson correlation coefficient.

    Raises:
        ValueError: If ``df`` has fewer than two columns, or if any ``"y"``
            value in the index has no entry in ``weights``.
    """
    n_cols = df.shape[1]
    if n_cols < 2:
        raise ValueError(f"Expected at least two columns to correlate, got {n_cols}.")

    # One weight per row, kept in a separate array instead of a new df column
    # so the input frame is not mutated and can never leak into the data slice.
    row_weights = df.index.get_level_values("y").map(weights)
    if row_weights.isna().any():
        # A missing key would otherwise become NaN and silently yield r = NaN.
        raise ValueError("Some 'y' values in the index have no corresponding weight.")

    model = sm.stats.DescrStatsW(
        df.iloc[:, :2], weights=row_weights.to_numpy(dtype=float)
    )
    return model.corrcoef[0, 1]

## Load the data

In [4]:
# Per-trait rasters from both sources, sorted by their integer trait id.
# sPlot stems: the id is the integer after the last "X".
# Extrapolation stems: the id is the integer after the last "X" in the part
# before the first "_".
# The two lists are meant to be paired by position, so both sort keys must
# yield the same trait order.
splot_fns = sorted(
    Path(
        cfg.interim_dir,
        cfg.splot.interim.dir,
        cfg.splot.interim.traits,
        cfg.PFT,
        cfg.model_res,
    ).glob("*.tif"),
    key=lambda fn: int(fn.stem.split("X")[-1]),
)
extrap_fns = sorted(
    Path(
        cfg.processed.dir,
        cfg.PFT,
        cfg.model_res,
        cfg.datasets.Y.use,
        cfg.processed.predict_dir,
    ).glob("*.tif"),
    key=lambda fn: int(fn.stem.split("_")[0].split("X")[-1]),
)

In [12]:
def _load_band1_frame(raster_fn: Path, column: str) -> pd.DataFrame:
    """Read band 1 of ``raster_fn`` into a single-column DataFrame without NaN rows.

    The raster is opened in a ``with`` block, so its file handle is released
    once the values have been copied into the DataFrame.
    """
    with riox.open_rasterio(raster_fn) as raster:
        return (
            raster.sel(band=1)
            .to_dataframe(name=column)
            .drop(columns=["band", "spatial_ref"])
            .dropna()
        )


# Only the first trait pair is evaluated here (note the ``[:1]`` slices).
for splot_fn, extrap_fn in zip(splot_fns[:1], extrap_fns[:1]):
    # The two lists are paired purely by position. A missing or extra file in
    # either directory would silently pair two different traits, so check ids.
    splot_trait = int(splot_fn.stem.split("X")[-1])
    extrap_trait = int(extrap_fn.stem.split("_")[0].split("X")[-1])
    if splot_trait != extrap_trait:
        raise ValueError(
            f"Trait mismatch: {splot_fn.name} (X{splot_trait}) is paired with "
            f"{extrap_fn.name} (X{extrap_trait})."
        )

    log.info("Loading and filtering data...")
    splot = _load_band1_frame(splot_fn, f"splot_{splot_fn.stem}")
    extrap = _load_band1_frame(extrap_fn, f"extrap_{extrap_fn.stem}")

    log.info("Joining dataframes...")
    # Keep only cells that have both an sPlot value and an extrapolated value.
    df = splot.join(extrap, how="inner")

    # Distinct latitudes of the joined cells. ``Index.unique()`` already returns
    # an iterable Index; ``.values`` is an attribute and must not be called.
    lat_unique = df.index.get_level_values("y").unique()

    log.info("Calculating weights...")
    # NOTE(review): files were selected with ``cfg.model_res`` but weights use
    # ``cfg.target_resolution``. Confirm both describe the same grid spacing
    # (in degrees).
    weights = lat_weights(lat_unique, cfg.target_resolution)

    log.info("Calculating weighted Pearson correlation coefficient...")
    r = weighted_pearson_r(df, weights)
    log.info(f"Weighted Pearson correlation coefficient: {r}")

2024-06-17 14:43:21 UTC - src.conf.environment - INFO - Loading and filtering data...


2024-06-17 14:43:48 UTC - src.conf.environment - INFO - Joining dataframes...
2024-06-17 14:44:20 UTC - src.conf.environment - INFO - Calculating weights...
2024-06-17 14:44:23 UTC - src.conf.environment - INFO - Calculating weighted Pearson correlation coefficient...
2024-06-17 14:44:23 UTC - src.conf.environment - INFO - Weighted Pearson correlation coefficient: 0.5984358563948847
