# Correlation between sPlot and plant trait products


In [None]:
import os
import sys
from pathlib import Path

# Scientific + Data
import numpy as np
import pandas as pd
import xarray as xr
from scipy.stats import linregress

# Raster + Geo
import rasterio
import rioxarray as riox

# Visualization
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import seaborn as sns

# Local utilities
from box import ConfigBox
from utils.benchmark_conf import get_benchmark_config
from utils.benchmark_utils import (
    read_trait_map,
    global_grid_df,
    lat_weights,
    create_sample_raster,
)

# Load the benchmark configuration once; later cells read (and one of them
# overrides) ``cfg.target_resolution`` on this shared object.
cfg = get_benchmark_config()


In [None]:
from pathlib import Path
from typing import Union, Optional
import os
import xarray as xr
import rioxarray as riox


def open_raster(
    filename: Union[str, os.PathLike],
    mask_and_scale: bool = True,
    **kwargs
) -> Union[xr.DataArray, xr.Dataset]:
    """Open a single raster through ``rioxarray.open_rasterio``.

    Args:
        filename: Location of the raster to open.
        mask_and_scale: Passed to ``open_rasterio`` under the same name.
        **kwargs: Any further keyword arguments, forwarded unchanged.

    Returns:
        The object produced by ``open_rasterio``, provided it is not a list.

    Raises:
        ValueError: If ``open_rasterio`` hands back a list.
    """
    options = dict(kwargs, mask_and_scale=mask_and_scale)
    raster = riox.open_rasterio(filename, **options)

    if not isinstance(raster, list):
        return raster

    raise ValueError("Multiple files found.")


def read_product_trait_map(
    trait_id: str,
    path: Path,
    band: Optional[int] = None
) -> Union[xr.DataArray, xr.Dataset]:
    """Open the raster holding a product's map for one trait.

    Args:
        trait_id: Trait identifier; used as the file stem when ``path`` does
            not itself end in ``.tif``.
        path: Either a ``.tif`` file, or a location under which
            ``<trait_id>.tif`` is looked up.
        band: When given, the opened raster is narrowed with
            ``.sel(band=band)``; when ``None`` the raster is returned as
            opened.

    Returns:
        The raster returned by ``open_raster``, optionally band-selected.
    """
    source = Path(path)

    # Anything that is not already a .tif is treated as a container of
    # per-trait files.
    if source.suffix != ".tif":
        source = source / f"{trait_id}.tif"

    raster = open_raster(source)
    return raster if band is None else raster.sel(band=band)


def read_product_tiff_to_df(
    trait: str,
    path: Union[str, Path],
    band: int = 1
) -> "pd.DataFrame":
    """Read one band of a product TIFF into a DataFrame without missing rows.

    Args:
        trait: Trait identifier; also the name given to the value column.
        path: Forwarded to ``read_product_trait_map``.
        band: Band passed to ``read_product_trait_map`` (default ``1``).

    Returns:
        The raster converted with ``to_dataframe(name=trait)``, with the
        ``band`` and ``spatial_ref`` columns removed when present and rows
        containing missing values dropped.
    """
    raster = read_product_trait_map(trait_id=trait, path=path, band=band)

    frame = raster.to_dataframe(name=trait)
    frame = frame.drop(columns=["band", "spatial_ref"], errors="ignore")
    return frame.dropna()


In [None]:

def weighted_pearsonr(x, y, weights):
    """Return the weighted Pearson correlation coefficient of ``x`` and ``y``.

    Args:
        x: Array-like of values (lists are accepted and converted to float
            arrays).
        y: Array-like of values, same length as ``x``.
        weights: Array-like of per-sample weights, same length as ``x``.

    Returns:
        The weighted covariance of ``x`` and ``y`` divided by the square root
        of the product of their weighted variances. ``nan`` is returned when
        that denominator is zero or not a number (for example when one input
        is constant), instead of evaluating 0/0.

    Raises:
        ZeroDivisionError: Propagated from ``np.average`` when the weights
            sum to zero (which includes empty input).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    weights = np.asarray(weights, dtype=float)

    dev_x = x - np.average(x, weights=weights)
    dev_y = y - np.average(y, weights=weights)

    cov_xy = np.sum(weights * dev_x * dev_y)
    var_x = np.sum(weights * dev_x ** 2)
    var_y = np.sum(weights * dev_y ** 2)

    denom = np.sqrt(var_x * var_y)
    # `not denom > 0` also catches a NaN denominator.
    if not denom > 0:
        return np.nan
    return cov_xy / denom

def compute_weighted_metrics(df, latwts, pred_col="product_values", true_col="sPLOT", nmin=1, log_scale=False):
    """Compute weighted agreement metrics between two columns of ``df``.

    Args:
        df: DataFrame with a two-level index; the levels are treated as
            ``y`` and ``x`` (they are named so on an internal copy if they
            are not already). The caller's frame is left untouched.
        latwts: Mapping accepted by ``Index.map`` that turns each ``y`` index
            value into a weight. Rows whose ``y`` has no weight are dropped.
        pred_col: Name of the column holding the predicted values.
        true_col: Name of the column holding the reference values.
        nmin: Accepted for interface compatibility; not used here.
        log_scale: When true, both columns are transformed with ``np.log1p``
            before any metric is computed.

    Returns:
        A dict with the keys ``n``, ``bias``, ``mae``, ``nmae``, ``rmse``,
        ``nrmse``, ``r2``, ``slope``, ``weighted_pearson_r`` and
        ``pearson_r``. When no rows survive the missing-value filter, ``n``
        is 0 and every other entry is ``nan``. ``nmae``/``nrmse`` are
        normalised by the range of the reference values lying strictly
        between their 1 % and 99 % quantiles and are ``nan`` when that range
        is empty or zero. ``slope`` is ``nan`` when fewer than two points
        remain or the reference values are all identical.
    """
    # Copy first so that neither the added column nor the index renaming
    # leaks back into the caller's DataFrame.
    df = df.copy()
    if list(df.index.names) != ["y", "x"]:
        df.index = df.index.set_names(["y", "x"])

    df["weight"] = df.index.get_level_values("y").map(latwts)

    print("df shape", df.shape)

    # Drop rows with missing weights or values
    df = df.dropna(subset=["weight", pred_col, true_col])

    print("df shape after dropna weight", df.shape)

    # Edge case: nothing left after filtering. The keys match the populated
    # result below so callers can index either one the same way.
    if len(df) == 0:
        return {
            "n": 0,
            "bias": np.nan,
            "mae": np.nan,
            "nmae": np.nan,
            "rmse": np.nan,
            "nrmse": np.nan,
            "r2": np.nan,
            "slope": np.nan,
            "weighted_pearson_r": np.nan,
            "pearson_r": np.nan,
        }

    y_true = df[true_col].values
    y_pred = df[pred_col].values
    weights = df["weight"].values

    if log_scale:
        y_true = np.log1p(y_true)
        y_pred = np.log1p(y_pred)

    y_true_mean = np.average(y_true, weights=weights)

    # Weighted error statistics
    errors = y_pred - y_true
    bias = np.average(errors, weights=weights)
    mae = np.average(np.abs(errors), weights=weights)
    rmse = np.sqrt(np.average(errors ** 2, weights=weights))

    # Normalise by the range of the reference values strictly inside the
    # 1 %-99 % quantile band. With very few samples the band can be empty,
    # or collapse to a single value, in which case no range exists.
    q_low, q_high = np.quantile(y_true, [0.01, 0.99])
    trimmed = y_true[(y_true > q_low) & (y_true < q_high)]
    value_range = (trimmed.max() - trimmed.min()) if trimmed.size else 0.0
    if value_range > 0:
        nmae = mae / value_range
        nrmse = rmse / value_range
    else:
        nmae = np.nan
        nrmse = np.nan

    ss_res = np.sum(weights * errors ** 2)
    ss_tot = np.sum(weights * (y_true - y_true_mean) ** 2)
    r2 = 1 - (ss_res / ss_tot) if ss_tot > 0 else np.nan

    # A regression line needs at least two distinct x values.
    if y_true.size > 1 and y_true.min() != y_true.max():
        slope = linregress(y_true, y_pred)[0]
    else:
        slope = np.nan

    weighted_pearson_r = weighted_pearsonr(y_true, y_pred, weights)
    pearson_r = np.corrcoef(y_true, y_pred)[0, 1] if y_true.size > 1 else np.nan

    # Number of distinct (y, x) index entries
    n = df.index.nunique()

    return {
        "n": n,
        "bias": bias,
        "mae": mae,
        "nmae": nmae,
        "rmse": rmse,
        "nrmse": nrmse,
        "r2": r2,
        "slope": slope,
        "weighted_pearson_r": weighted_pearson_r,
        "pearson_r": pearson_r,
    }


In [None]:
def plothexbin(df, true_col, pred_col, latwts, ax, label=None, scaleaxis=True, allmetrics=True, n_min=1, log_scale=False, title='Model Performance'):
    """Draw an observed-vs-predicted hexbin plot on ``ax`` and annotate it.

    Args:
        df: DataFrame holding ``true_col`` and ``pred_col``; it is forwarded
            unchanged to ``compute_weighted_metrics``.
        true_col: Column plotted on the x axis ("Observed").
        pred_col: Column plotted on the y axis ("Predicted").
        latwts: Weights forwarded to ``compute_weighted_metrics``.
        ax: Matplotlib axes to draw on.
        label: Text used in the axis labels; falls back to ``'values'`` when
            empty or ``None``.
        scaleaxis: When true, both axes share the same limits and a red 1:1
            line is drawn.
        allmetrics: When true the annotation box lists every metric,
            otherwise only R².
        n_min: Forwarded to ``compute_weighted_metrics`` as ``nmin``.
        log_scale: When true, both columns are plotted (and the metrics are
            computed) after ``np.log1p``.
        title: Axes title.

    Returns:
        A tuple ``(ax, metrics)`` in every case. ``metrics`` has the keys
        ``R²``, ``nMAE``, ``nRMSE``, ``slope``, ``weighted_r``, ``r`` and
        ``n``. When there is nothing finite to plot, ``ax`` is returned
        untouched, ``n`` is 0 and the other metrics are ``nan``.
    """
    true = df[true_col].to_numpy(dtype=float)
    pred = df[pred_col].to_numpy(dtype=float)

    if log_scale:
        true = np.log1p(true)
        pred = np.log1p(pred)
        label = f'log({label})' if label else 'log(values)'
    elif not label:
        # Avoid rendering the literal text "None" in the axis labels.
        label = 'values'

    # Only finite pairs can be binned; NaN/inf would also make the axis
    # limits below invalid.
    finite = np.isfinite(true) & np.isfinite(pred)
    true = true[finite]
    pred = pred[finite]

    if true.size == 0:
        print(f'x: {true.shape}, y: {pred.shape}')
        # Keep the (ax, metrics) return shape so callers can always unpack.
        return ax, {
            "R²": np.nan,
            "nMAE": np.nan,
            "nRMSE": np.nan,
            "slope": np.nan,
            "weighted_r": np.nan,
            "r": np.nan,
            "n": 0,
        }

    metrics = compute_weighted_metrics(df, latwts, pred_col=pred_col, true_col=true_col, nmin=n_min, log_scale=log_scale)
    r2 = metrics["r2"]
    nmae = metrics["nmae"]
    nrmse = metrics["nrmse"]
    slope = metrics["slope"]
    # .get: an empty-result dict may not carry this key.
    weighted_r = metrics.get("weighted_pearson_r", np.nan)
    r = metrics["pearson_r"]
    n = metrics["n"]

    ax.set_aspect('equal', adjustable='box')
    ax.set_xlabel(f'Observed {label}', fontsize=20, fontweight='bold')
    ax.set_ylabel(f'Predicted {label}', fontsize=20, fontweight='bold')

    if scaleaxis:
        maxval = max(true.max(), pred.max())
        minval = min(true.min(), pred.min())
        ax.set_xlim(minval, maxval)
        ax.set_ylim(minval, maxval)
        ax.plot([minval, maxval], [minval, maxval], color='red', lw=2)
        extent = [minval, maxval, minval, maxval]
    else:
        extent = None

    hb = ax.hexbin(true, pred, gridsize=60, cmap='plasma', mincnt=1, extent=extent, norm=Normalize(vmin=0, vmax=1))
    # Rescale bin counts to [0, 1] so they match the Normalize range above.
    counts = hb.get_array()
    hb.set_array(counts / counts.max())

    if allmetrics:
        annotation = f'R²: {r2:.2f}\nnMAE: {nmae:.2f}\nnRMSE: {nrmse:.2f}\nslope: {slope:.2f}\nweighted_r: {weighted_r:.2f}\nr: {r:.2f}\nn: {n} points'
    else:
        annotation = f'R²: {r2:.2f}'
    ax.text(0.05, 0.95, annotation,
            transform=ax.transAxes, fontsize=14, verticalalignment='top',
            bbox=dict(boxstyle='round,pad=0.3', edgecolor='black', facecolor='white'))

    ax.set_title(title, fontsize=20, fontweight='bold')
    metrics = {
        "R²": r2,
        "nMAE": nmae,
        "nRMSE": nrmse,
        "slope": slope,
        "weighted_r": weighted_r,
        "r": r,
        "n": n
    }

    return ax, metrics


In [None]:
# Configuration for seed directories and trait mapping
SEED_DIRS = {
    "seed0": Path("/path/to.parquet_file for the seed generated after benchamrking against sPlotOpen/data/nmin20"),
    "seed100": Path("/path/to.parquet_file for the seed generated after benchamrking against sPlotOpen/data/nmin20"),
    "seed200": Path("/path/to.parquet_file for the seed generated after benchamrking against sPlotOpen/data/nmin20")
}

# Trait id (as used in the product file names) -> trait name (as used in the
# common-grid parquet file names and column names).
TRAITS = {
    "X18": "Height",
    "X3113": "LeafArea",
    "X14": "LeafN",
    "X3117": "SLA"
}
grid_stat = "mean"
TARGET_RESOLUTION = 1
PRODUCTS = ['moreno', 'wolf', 'schiller', 'butler', 'boonman', 'madani', 'bodegom']

TRAIT_DATA_DIR = Path("/path/to/folder with .tif files from other products")
OUT_BASE_DIR = Path("/path/to.output_dir")
all_stats_path = OUT_BASE_DIR / "splot_corr.csv"
all_stats_path.parent.mkdir(parents=True, exist_ok=True)

# Resume from a previous run's statistics when the CSV already exists.
if all_stats_path.exists():
    all_stats = pd.read_csv(all_stats_path)
else:
    all_stats = pd.DataFrame(columns=["seed", "product","trait", "r2", "nmae", "r", "#grids"])

refrence_raster = create_sample_raster(
    resolution=TARGET_RESOLUTION,
)


def _grid_to_target(frame, value_col):
    """Aggregate ``value_col`` of ``frame`` (with ``lon``/``lat`` columns) onto the TARGET_RESOLUTION grid."""
    gridded = global_grid_df(
        frame[["lon", "lat", value_col]],
        value_col,
        lon="lon",
        lat="lat",
        # Same resolution as the latitude weights and output file names below.
        res=TARGET_RESOLUTION,
        stats=[grid_stat],
        n_min=1,
    )
    return gridded.rename(columns={grid_stat: value_col})


for seed, seed_dir in SEED_DIRS.items():
    print(f"\nProcessing for {seed}...")

    # Load common grids for this seed
    common_grids = {}
    for trait_id, trait_name in TRAITS.items():
        parquet_path = seed_dir / f"{trait_name}_vs_splot_1kmdeg_nmin20.parquet"
        if parquet_path.exists():
            common_grids[trait_id] = pd.read_parquet(parquet_path)
        else:
            print(f"Warning: Missing parquet for {trait_id} at {parquet_path}")

    for product in PRODUCTS:
        print(f"  Processing product: {product}")

        trait_files = sorted(TRAIT_DATA_DIR.glob(f"X*_{product}.tif"))

        for file_path in trait_files:
            trait_id = file_path.stem.split("_")[0]
            if trait_id not in TRAITS:
                continue

            print(f"Processing trait: {trait_id}")

            product_col = f"{trait_id}_product"

            # Load product trait data (the glob above only yields .tif files)
            df = read_product_tiff_to_df(trait_id, file_path)

            df = df.rename(columns={trait_id: product_col})
            print(f"Loaded {trait_id} df before reset: {df}")
            df = df.reset_index()
            df = df.rename(columns={'y': 'lat', 'x': 'lon'})
            print(f"Reset index for {trait_id} df after reste: {df}")

            product_df = _grid_to_target(df, product_col)

            # Load sPlot trait data
            splot_df = (
                read_trait_map(trait_id, "splot", band=1)
                .to_dataframe(name=trait_id)
                .drop(columns=["band", "spatial_ref"])
                .dropna()
            )

            print(f"Loaded {trait_id} sPLOT df before reset: {splot_df}")
            splot_df = splot_df.reset_index().rename(columns={'y': 'lat', 'x': 'lon'})
            splot_df = _grid_to_target(splot_df, trait_id)

            # Filter based on common grid cells. The missing-grid check has to
            # come before anything is called on the frame.
            common_df = common_grids.get(trait_id)
            print(f"Common grid for {trait_id}: {common_df}")
            if common_df is None:
                print(f"      Skipping {trait_id}: no common grid data found.")
                continue

            common_col = f"{TRAITS[trait_id]}_true"
            common_df = _grid_to_target(
                common_df.rename(columns={'y': 'lat', 'x': 'lon'}),
                common_col,
            )
            print(f"Common grid after processing for {trait_id}: {common_df}")
            common_index = common_df.index

            print(f"True df before filtering for {trait_id}: {product_df.shape}")
            print(f"sPLOT df before filtering for {trait_id}: {splot_df.shape}")

            # Restrict BOTH gridded frames to the common cells; the raw
            # product frame has a plain row index and cannot be matched
            # against grid cells.
            product_df = product_df.loc[product_df.index.isin(common_index)]
            splot_df = splot_df.loc[splot_df.index.isin(common_index)]

            print(f"True df after filtering for {trait_id}: {product_df.shape}")
            print(f"sPLOT df after filtering for {trait_id}: {splot_df.shape}")

            splot_vs_true = product_df.join(splot_df, how="inner")
            if splot_vs_true.empty:
                print("      Skipping due to empty join.")
                continue
            print(f"sPLOT vs True df for {trait_id}: {splot_vs_true}")

            splot_vs_true = splot_vs_true.rename(
                columns={f'{trait_id}': "sPLOT", product_col: "product_values"}
            )
            print(f"sPLOT vs True df after renaming for {trait_id}: {splot_vs_true.shape}")

            out_dir = OUT_BASE_DIR / seed / f"{product}"
            combined = splot_vs_true[["sPLOT", "product_values"]]
            combined_path = out_dir / "data" / f"{trait_id}_vs_splot_{TARGET_RESOLUTION}deg_nmin.parquet"
            combined_path.parent.mkdir(parents=True, exist_ok=True)
            combined.reset_index().to_parquet(combined_path)

            lat_wts = lat_weights(
                splot_vs_true.index.get_level_values("y").unique().values,
                TARGET_RESOLUTION,
            )

            # Plot hexbin
            with sns.plotting_context("notebook"), sns.axes_style("ticks"):
                fig, ax = plt.subplots(figsize=(10, 10))
                ax, metrics = plothexbin(
                    splot_vs_true,
                    pred_col="product_values",
                    true_col="sPLOT",
                    latwts=lat_wts,
                    label=trait_id,
                    scaleaxis=True,
                    allmetrics=True,
                    log_scale=True,
                    ax=ax,
                    n_min=1,
                    title=f'{product} ({seed})'
                )
                sns.despine()
                fig_path = out_dir / "plots" / f"{trait_id}_vs_splot_{TARGET_RESOLUTION}deg.png"
                fig_path.parent.mkdir(parents=True, exist_ok=True)
                fig.savefig(fig_path, dpi=300)
                plt.close(fig)

            trait_stats = pd.DataFrame({
                "seed": [seed],
                "product": [product],
                "trait": [trait_id],
                "r2": [metrics["R²"]],
                "nmae": [metrics["nMAE"]],
                "r": [metrics["r"]],
                "#grids": [metrics["n"]]
            })
            all_stats = pd.concat([all_stats, trait_stats], ignore_index=True).drop_duplicates()

        # Persist after every product so a later failure does not lose results.
        print(f"  Saving stats to {all_stats_path}")
        all_stats.to_csv(all_stats_path, index=False)




Processing for seed0...
  Processing product: moreno
Processing trait: X14
Loaded X14 df before reset:              X14_product
y     x                 
 89.5 -64.5    16.121082
      -63.5    16.121082
      -27.5    14.253015
 88.5 -64.5    16.121082
      -63.5    16.121082
...                  ...
-54.5 -69.5    14.473083
      -68.5    14.748342
      -67.5    15.855331
-55.5 -68.5    15.338668
-56.5 -68.5    15.338668

[11027 rows x 1 columns]
Reset index for X14 df after reste:         lat   lon  X14_product
0      89.5 -64.5    16.121082
1      89.5 -63.5    16.121082
2      89.5 -27.5    14.253015
3      88.5 -64.5    16.121082
4      88.5 -63.5    16.121082
...     ...   ...          ...
11022 -54.5 -69.5    14.473083
11023 -54.5 -68.5    14.748342
11024 -54.5 -67.5    15.855331
11025 -55.5 -68.5    15.338668
11026 -56.5 -68.5    15.338668

[11027 rows x 3 columns]
Loaded X14 sPLOT df before reset:                    X14
y     x               
 80.5  19.5  19.984119
 79.5 -7

In [None]:
# Read the four per-trait parquet files in one pass; the order of the paths
# matches the order of the names on the left-hand side.
common_X18, common_X3113, common_X3117, common_X14 = (
    pd.read_parquet(parquet_file)
    for parquet_file in (
        '/path/to/PlantTraitNet output folder/data/nmin10/Height_vs_splot_1kmdeg_nmin10.parquet',
        '/path/to/PlantTraitNet output folder/data/nmin10/LeafArea_vs_splot_1kmdeg_nmin10.parquet',
        '/path/to/PlantTraitNet output folder/data/nmin10/SLA_vs_splot_1kmdeg_nmin10.parquet',
        '/path/to/PlantTraitNet output folder/data/nmin10/LeafN_vs_splot_1kmdeg_nmin10.parquet',
    )
)


In [None]:


# Common-grid frames keyed by trait id; their "y"/"x" columns define the
# grid cells kept for the comparison below.
common_grids = {
    "X18": common_X18,
    "X3113": common_X3113,
    "X14": common_X14,
    "X3117": common_X3117
}


cfg.target_resolution = 1
products = ['moreno', 'wolf', 'schiller', 'butler', 'boonman', 'madani', 'bodegom']

# Settings that do not depend on the product being processed.
trait_data_dir = Path("/path/to/all tif files")
valid_traits = ['X14', 'X3117', 'X3113', 'X18']
n_min = 1

for product in products:
    out_base_dir_path = Path("/path/to/output dir") / f"{product}"
    all_stats_path = out_base_dir_path / f"{product}" / "splot_corr.csv"
    all_stats_path.parent.mkdir(parents=True, exist_ok=True)

    if all_stats_path.exists():
        all_stats = pd.read_csv(all_stats_path)
    else:
        all_stats = pd.DataFrame(columns=["trait", "source", "resolution", "n_min"])

    trait_files = sorted(trait_data_dir.glob(f"X*_{product}.tif"))

    for file_path in trait_files:
        trait_id = file_path.stem.split("_")[0]
        print(f"\nProcessing trait: {trait_id}")
        if trait_id not in valid_traits:
            continue

        # The glob above only yields .tif files.
        df = read_product_tiff_to_df(f'{trait_id}', file_path)

        df = df.rename(columns={trait_id: f"{trait_id}_true"})
        print(f'df:{df}')

        true_col = f"{trait_id}_true"

        splot_df = (
                read_trait_map(trait_id, "splot", band=1)
                .to_dataframe(name=trait_id)
                .drop(columns=["band", "spatial_ref"])
                .dropna()
        )
        print(f'splot df:{splot_df}')
        # Filter based on common grid cells
        print(f'common grids:{common_grids}')
        common_df = common_grids.get(trait_id)
        print(f"common_df for {trait_id}: {common_df}")
        if common_df is None:
            print(f"common dataframe not found, continuing without finding common data points for {trait_id}")
            true_df = df
        else:
            common_index = pd.MultiIndex.from_frame(common_df[["y", "x"]])
            true_df = df.loc[df.index.isin(common_index)]
            splot_df = splot_df.loc[splot_df.index.isin(common_index)]

        print(f'before join, true_df shape: {true_df.shape}, splot_df shape: {splot_df.shape}')

        splot_vs_true = true_df.join(splot_df, how="inner")
        # Nothing to compare or plot when the two sources share no cells.
        if splot_vs_true.empty:
            print("    Skipping due to empty join.")
            continue
        print(f"splot_vs_true : {splot_vs_true}")
        splot_vs_true = splot_vs_true.rename(
                columns={f'{trait_id}': "sPLOT", f'{trait_id}_true': "product_values"}
            )
        print(f"splot_vs_true shape: {splot_vs_true.shape}")
        print(f"splot_vs_true : {splot_vs_true}")

        out_dir = out_base_dir_path / f"{product}"
        combined = splot_vs_true[["sPLOT", "product_values"]]
        combined_path = (
                out_dir / "data" / f"{trait_id}_vs_splot_{cfg.target_resolution}deg_nmin{n_min}.parquet"
        )
        combined_path.parent.mkdir(parents=True, exist_ok=True)
        combined.reset_index().to_parquet(combined_path)

        lat_wts = lat_weights(
                splot_vs_true.index.get_level_values("y").unique().values,
                cfg.target_resolution,
            )

        trait_stats_true = pd.DataFrame({
                "trait": [true_col],
                "source": ["true"],
                "resolution": [cfg.target_resolution],
        })

        all_stats = pd.concat([all_stats, trait_stats_true], ignore_index=True).drop_duplicates()

        with sns.plotting_context("notebook"), sns.axes_style("ticks"):
                fig, ax = plt.subplots(figsize=(10, 10))
                # Only keyword arguments that plothexbin actually defines are
                # passed; it has no `log10_scale` parameter.
                ax, metric = plothexbin(
                    splot_vs_true,
                    pred_col="product_values",
                    true_col="sPLOT",
                    latwts=lat_wts,
                    label=trait_id,
                    scaleaxis=True,
                    allmetrics=True,
                    ax=ax,
                    n_min=1,
                    log_scale=True,
                    title=f'{product}'
                )

                sns.despine()
                fig_path = out_dir / "plots" / f"{trait_id}_vs_splot_{cfg.target_resolution}deg_nmin{n_min}.png"
                fig_path.parent.mkdir(parents=True, exist_ok=True)
                fig.savefig(fig_path, dpi=300)
                plt.close(fig)

    print("\nSaving stats to", all_stats_path)
    all_stats.to_csv(all_stats_path, index=False)



Processing trait: X14
df:              X14_true
y     x               
 89.5 -64.5  16.121082
      -63.5  16.121082
      -27.5  14.253015
 88.5 -64.5  16.121082
      -63.5  16.121082
...                ...
-54.5 -69.5  14.473083
      -68.5  14.748342
      -67.5  15.855331
-55.5 -68.5  15.338668
-56.5 -68.5  15.338668

[11027 rows x 1 columns]
splot df:                   X14
y     x               
 80.5  19.5  19.984119
 79.5 -79.5  20.533623
       12.5  18.944222
       13.5  24.271261
 78.5 -76.5  21.019420
...                ...
-46.5 -73.5  11.822598
      -72.5  14.917747
-47.5 -73.5  12.612290
-54.5 -70.5  12.432774
      -69.5  12.759544

[2171 rows x 1 columns]
common grids:{'X18':         y      x  Height_splot  Height_true  Height_pred
0   -45.5  167.5      0.529262     3.795521     3.977481
1   -44.5  168.5      0.600692     3.440429     2.911857
2   -44.5  169.5      0.542537     2.254321     3.574434
3   -43.5  171.5      0.507069     2.265599     2.620199
4   -43.5 