In [37]:

import os
import re
import argparse
from pathlib import Path
from glob import glob
import numpy as np
import pandas as pd
from scipy.stats import linregress
from matplotlib import pyplot as plt
from matplotlib.colors import Normalize
import seaborn as sns
import sys
sys.path.append('/home/as2114/code/PlantTraitNet/src')
from utils.benchmark_conf import get_benchmark_config
from utils.benchmark_utils import read_trait_map, global_grid_df, lat_weights

# Module-level benchmark configuration object returned by the project's
# get_benchmark_config() helper (imported from utils.benchmark_conf above).
cfg = get_benchmark_config()
# -------------------- METRICS --------------------
def weighted_pearsonr(x, y, w):
    """Return the weighted Pearson correlation coefficient of ``x`` and ``y``.

    Both means, the covariance and the two variances are computed with the
    weights ``w``. The weight normalisation cancels between numerator and
    denominator, so plain weighted sums are used. No guard is applied for a
    zero denominator (constant ``x`` or ``y``).
    """
    dev_x = x - np.average(x, weights=w)
    dev_y = y - np.average(y, weights=w)
    numerator = np.sum(w * dev_x * dev_y)
    denominator = np.sqrt(np.sum(w * dev_x ** 2) * np.sum(w * dev_y ** 2))
    return numerator / denominator


def compute_weighted_metrics(df, latwts, pred_col, true_col, log_scale=False):
    """Compute weighted error/agreement metrics between prediction and ground truth.

    Args:
        df: DataFrame with a two-level index; the levels are (re)named
            ``("y", "x")`` on an internal copy, the caller's frame is untouched.
        latwts: Mapping passed to ``Index.map`` that turns each ``y`` index
            value into a weight. Rows whose ``y`` has no weight are dropped.
        pred_col: Name of the column holding predicted values.
        true_col: Name of the column holding reference values.
        log_scale: If True, both columns are transformed with ``np.log1p``
            before any metric is computed.

    Returns:
        dict with keys ``n``, ``mae``, ``nmae``, ``rmse``, ``nrmse``, ``r2``,
        ``slope`` and ``pearson_r``. Every value is NaN when no row survives
        the NaN filtering. Metrics that are mathematically undefined (zero
        quantile range, or reference values without any variance) are NaN
        instead of ``inf`` or a raised exception.

    Note:
        ``slope`` is the ordinary (unweighted) least-squares slope from
        ``scipy.stats.linregress``; all other metrics use the weights.
    """
    metric_names = ["n", "mae", "nmae", "rmse", "nrmse", "r2", "slope", "pearson_r"]

    df = df.copy()
    if list(df.index.names) != ["y", "x"]:
        df.index = df.index.set_names(["y", "x"])

    df["weight"] = df.index.get_level_values("y").map(latwts)
    df = df.dropna(subset=["weight", pred_col, true_col])

    if df.empty:
        return dict.fromkeys(metric_names, np.nan)

    y_true, y_pred, w = df[true_col].values, df[pred_col].values, df["weight"].values
    if log_scale:
        y_true, y_pred = np.log1p(y_true), np.log1p(y_pred)

    err = y_pred - y_true
    mae = np.average(np.abs(err), weights=w)
    rmse = np.sqrt(np.average(err ** 2, weights=w))

    # Normalise by the 1%-99% quantile range of the reference values. A zero
    # range makes the normalised errors undefined, so report NaN rather than
    # dividing by zero.
    rng = np.quantile(y_true, 0.99) - np.quantile(y_true, 0.01)
    if rng > 0:
        nmae, nrmse = mae / rng, rmse / rng
    else:
        nmae, nrmse = np.nan, np.nan

    ss_res = np.sum(w * err ** 2)
    ss_tot = np.sum(w * (y_true - np.average(y_true, weights=w)) ** 2)
    has_variance = ss_tot > 0
    r2 = 1 - ss_res / ss_tot if has_variance else np.nan

    # linregress raises ValueError when every x value is identical, so only
    # call it when the reference values actually vary.
    if y_true.size > 1 and y_true.max() > y_true.min():
        slope = linregress(y_true, y_pred)[0]
    else:
        slope = np.nan

    # With zero weighted variance the correlation is 0/0; skip the call and
    # return NaN directly (same value, without the runtime warning).
    r = weighted_pearsonr(y_true, y_pred, w) if has_variance else np.nan

    return dict(n=len(df), mae=mae, nmae=nmae, rmse=rmse, nrmse=nrmse, r2=r2, slope=slope, pearson_r=r)


def plothexbin(df, true_col, pred_col, latwts, ax, label=None,
               scaleaxis=True, allmetrics=True, n_min=1, log_scale=False, title=''):
    """Draw a hexbin density plot of predicted vs. observed values on ``ax``.

    Args:
        df: Data to plot; an empty frame prints a warning and returns ``ax``
            untouched.
        true_col: Column plotted on the x axis.
        pred_col: Column plotted on the y axis.
        latwts: Weights forwarded to ``compute_weighted_metrics`` (they affect
            the metrics text only, not the hexbin counts).
        ax: Matplotlib axes to draw on.
        label: Optional trait label used for the axis captions.
        scaleaxis: If True, use one shared range for both axes and draw the
            1:1 line in red.
        allmetrics: If True, overlay a text box with R², nMAE and R.
        n_min: Accepted for call compatibility; not used inside this function.
        log_scale: If True, plot (and score) ``log1p`` of both columns.
        title: Axes title.

    Returns:
        The axes object that was passed in.
    """
    if df.empty:
        print("⚠️ Empty DataFrame passed to plothexbin")
        return ax

    stats = compute_weighted_metrics(df, latwts, pred_col, true_col, log_scale=log_scale)

    observed = df[true_col].values
    predicted = df[pred_col].values
    if log_scale:
        observed = np.log1p(observed)
        predicted = np.log1p(predicted)

    ax.set_aspect('equal', adjustable='box')

    extent = None
    if scaleaxis:
        upper = max(observed.max(), predicted.max())
        lower = min(observed.min(), predicted.min())
        ax.set_xlim(lower, upper)
        ax.set_ylim(lower, upper)
        # 1:1 reference line
        ax.plot([lower, upper], [lower, upper], color='red', lw=2)
        extent = [lower, upper, lower, upper]

    hexes = ax.hexbin(observed, predicted, gridsize=60, cmap='plasma', mincnt=1, extent=extent,
                      norm=Normalize(vmin=0, vmax=1))
    # Rescale bin counts to [0, 1] so they match the fixed colour normalisation.
    raw_counts = hexes.get_array()
    hexes.set_array(raw_counts / raw_counts.max())

    if label:
        ax.set_xlabel(f"Observed {label}")
        ax.set_ylabel(f"Predicted {label}")

    if allmetrics:
        summary = f"R²: {stats['r2']:.2f}\nnMAE: {stats['nmae']:.2f}\nR: {stats['pearson_r']:.2f}"
        box_style = dict(boxstyle='round,pad=0.3', edgecolor='black', facecolor='white')
        ax.text(0.05, 0.95, summary, transform=ax.transAxes,
                fontsize=12, verticalalignment='top', bbox=box_style)

    ax.set_title(title)
    return ax


# -------------------- PREPARE DATA --------------------
def prepare_data(pred, normval, valmeta, traits):
    """Assemble one prediction table with uncertainties and coordinates.

    The result is the column-wise concatenation of ``pred``, the
    ``<trait>_uncertainty`` columns of ``normval`` whose trait is listed in
    ``traits``, and the ``Longitude``/``Latitude`` columns of ``valmeta``.
    Afterwards, per trait:

    * ``leafarea`` (case-insensitive): ``_true`` and ``_pred`` are multiplied
      by 100.
    * ``leaf_n`` (case-insensitive): ``_true``/``_pred`` are renamed to
      ``LeafN_true``/``LeafN_pred``.

    Returns:
        The combined DataFrame.
    """
    marker = "_uncertainty"
    uncertainty_cols = []
    for col in normval.columns:
        if col.endswith(marker) and col.partition(marker)[0] in traits:
            uncertainty_cols.append(col)

    pieces = [pred, normval[uncertainty_cols], valmeta[['Longitude', 'Latitude']]]
    pred = pd.concat(pieces, axis=1)

    for trait in traits:
        lowered = trait.lower()
        if lowered == "leafarea":
            for kind in ("true", "pred"):
                pred[f"{trait}_{kind}"] *= 100
        if lowered == "leaf_n":
            renames = {f"{trait}_true": "LeafN_true", f"{trait}_pred": "LeafN_pred"}
            pred = pred.rename(columns=renames)
    return pred


# -------------------- SPLOT BENCHMARKING --------------------
def splot_benchmarking(pred, cfg, traits, out_dir, n_min=20):
    """Benchmark gridded observations and predictions against sPlot trait maps.

    For each trait the function:

    1. reads the sPlot map via ``read_trait_map`` and turns it into a frame,
    2. grids ``<trait>_true`` and ``<trait>_pred`` from ``pred`` with
       ``global_grid_df`` (``res=1``, mean statistic, ``n_min`` threshold),
    3. inner-joins the three tables on the grid index and skips the trait if
       nothing remains,
    4. writes the joined table to
       ``<out_dir>/combined_data/<trait>_splot_vs_model_1deg_nmin<n_min>.parquet``,
    5. saves a two-panel hexbin figure under ``<out_dir>/plots/nmin<n_min>/``,
    6. computes log-scale weighted metrics of the prediction against sPlot.

    Args:
        pred: Point-level table with ``Longitude``, ``Latitude`` and the
            ``<trait>_true`` / ``<trait>_pred`` columns.
        cfg: Accepted for call compatibility; not used inside this function.
        traits: Iterable of trait names.
        out_dir: Output directory (string or path).
        n_min: Threshold forwarded to ``global_grid_df``; also embedded in the
            output file names.

    Returns:
        List with one metrics dict per processed trait. When non-empty it is
        also written to ``<out_dir>/splot_metrics.csv``.
    """
    out_dir = Path(out_dir)
    results = []

    for t in traits:
        splot_col, true_col, pred_col = f"{t}_splot", f"{t}_true", f"{t}_pred"

        splot_map = read_trait_map(t, "splot", band=1)
        splot_df = splot_map.to_dataframe(name=t)
        splot_df = splot_df.drop(columns=["band", "spatial_ref"]).dropna()

        # Grid observed ("true") and predicted values with identical settings.
        grids = {}
        for kind in ("true", "pred"):
            grids[kind] = global_grid_df(pred, f"{t}_{kind}", lon="Longitude", lat="Latitude",
                                         res=1, stats=["mean"], n_min=n_min)
        grid_true, grid_pred = grids["true"], grids["pred"]

        print(f"sPlot df shape for {t}: {splot_df.shape}")
        print(f"Grid true shape for {t}: {grid_true.shape}")
        print(f"Grid pred shape for {t}: {grid_pred.shape}")

        # Join on grid index (y, x)
        merged = grid_true.join(grid_pred, lsuffix="_true", rsuffix="_pred", how="inner")
        merged = merged.join(splot_df, how="inner")

        print(f"Merged df shape for {t}: {merged.shape}")

        merged = merged.rename(columns={t: splot_col, "mean_true": true_col, "mean_pred": pred_col})
        merged = merged[[splot_col, true_col, pred_col]].dropna()

        if merged.empty:
            continue

        # Save combined data
        combined_path = out_dir / "combined_data" / f"{t}_splot_vs_model_1deg_nmin{n_min}.parquet"
        combined_path.parent.mkdir(parents=True, exist_ok=True)
        merged.reset_index().to_parquet(combined_path)
        print(f"Saved combined data for {t} at {combined_path}")

        unique_lats = merged.index.get_level_values("y").unique().values
        lat_wts = lat_weights(unique_lats, 1)

        # Plotting: left panel = gridded observations, right panel = model output.
        panels = [(true_col, "sPlot vs TRY6 mean"), (pred_col, "sPlot vs Model")]
        with sns.plotting_context("notebook"), sns.axes_style("ticks"):
            fig, ax = plt.subplots(1, 2, figsize=(20, 10))

            for panel_ax, (column, panel_title) in zip(ax, panels):
                plothexbin(merged, splot_col, column, lat_wts, panel_ax,
                           label=t, title=panel_title, log_scale=True)

            fig_path = out_dir / "plots" / f"nmin{n_min}" / f"{t}_vs_splot_1deg_nmin{n_min}.png"
            fig_path.parent.mkdir(parents=True, exist_ok=True)
            plt.savefig(fig_path, dpi=300)
            plt.close(fig)

        # Compute metrics
        metrics = compute_weighted_metrics(merged, lat_wts, pred_col, splot_col, log_scale=True)
        metrics["trait"] = t
        results.append(metrics)

    if results:
        pd.DataFrame(results).to_csv(out_dir / "splot_metrics.csv", index=False)
    return results


In [None]:
import pandas as pd

# Running inference.py creates two files: <>.csv, which contains trait predictions in the
# original scale, and normalized_<>.csv, which contains trait predictions with normalized values.

# === File Paths ===
mega_val_path = '<>/results_benchmark_data.csv'
meganorm_path = '<>/results_normalized_benchmark_data.csv'
megameta_path = '<>/benchmark_data.csv'  # metadata provided in the dataset repository

val_path = '<>/results_val.csv'
normval_path = '<>/results_normalized_val.csv'
valmeta_path = '<>/val.csv'

# Uncertainty columns are dropped from the original-scale predictions; prepare_data()
# takes the uncertainty columns from the normalized tables instead.
uncertainty_cols = ['Height_uncertainty', 'LeafArea_uncertainty', 'SLA_uncertainty', 'Leaf_N_uncertainty']
# Metadata files may spell the coordinate columns in lowercase; normalise the names
# BEFORE selecting them, otherwise the selection fails on lowercase headers.
coord_renames = {'longitude': 'Longitude', 'latitude': 'Latitude'}
coord_cols = ['Longitude', 'Latitude']

# === Load CSVs ===
megaval = pd.read_csv(mega_val_path).drop(columns=uncertainty_cols)
meganormval = pd.read_csv(meganorm_path)
# The benchmark metadata has to be loaded as well: it is printed and concatenated below.
megavalmeta = pd.read_csv(megameta_path).rename(columns=coord_renames)[coord_cols]

val = pd.read_csv(val_path).drop(columns=uncertainty_cols)
normval = pd.read_csv(normval_path)
valmeta = pd.read_csv(valmeta_path).rename(columns=coord_renames)[coord_cols]

print(megaval.shape, meganormval.shape, megavalmeta.shape)
print(val.shape, normval.shape, valmeta.shape)

# === Concatenate Predictions, Normals, and Metadata ===
# Row order must stay benchmark-first, validation-second in all three tables, and the
# index is reset so that prepare_data() can align them column-wise.
full_pred = pd.concat([megaval, val], axis=0).reset_index(drop=True)
full_norm = pd.concat([meganormval, normval], axis=0).reset_index(drop=True)
full_meta = pd.concat([megavalmeta, valmeta], axis=0).reset_index(drop=True)

print("After concatenation:")

print(full_pred.shape, full_norm.shape, full_meta.shape)

(298501, 8) (298501, 12) (298501, 2)
(84710, 8) (84710, 12) (84710, 2)
After concatenation:
(383211, 8) (383211, 12) (383211, 2)


In [None]:
from utils.benchmark_conf import get_benchmark_config

cfg = get_benchmark_config()
traits = ['Height', 'LeafArea', 'SLA', 'Leaf_N']
# NOTE(review): prepare_data() receives 'LeafN' while the benchmark below uses 'Leaf_N'.
# With this list prepare_data() neither renames Leaf_N_* to LeafN_* nor keeps the
# Leaf_N_uncertainty column. Confirm which spelling read_trait_map() expects before
# unifying the two trait lists.
pred = prepare_data(pred=full_pred, normval=full_norm, valmeta=full_meta,
                    traits=['Height', 'LeafArea', 'SLA', 'LeafN'])

splot_benchmarking(pred, cfg, traits=traits, out_dir='./output', n_min=20)