# 2.0.4: Assessing the effect of sample weights on model performance

Previously we generated sample weights based on the source of our training data (sPlot vs. GBIF) to address GBIF's disproportionate size in the training pool. Below we perform some rudimentary tests to examine how including sample weights during model training affects model performance.

## Imports and config

In [1]:
from autogluon.tabular import TabularPredictor
import pandas as pd
import dask.dataframe as dd

from src.conf.conf import get_config
from src.conf.environment import log
from src.utils.autogluon_utils import get_best_model_ag
from src.utils.dataset_utils import get_models_dir, get_predict_dir, get_train_fn, get_weights_fn, get_trait_maps_dir
from src.utils.df_utils import grid_df_to_raster
from src.utils.raster_utils import open_raster
from src.utils.spatial_utils import weighted_pearson_r, lat_weights
from src.utils.spatial_utils import lat_weights


cfg = get_config()

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Load model

We will load the model that was trained **with** sample weights. Due to time constraints, it was trained on only a small fraction of the data.

In [2]:
models_dir = get_models_dir(cfg) / "debug"

# Pick the first trait-model *directory* under the debug folder.
# - sorted(): raw glob order is filesystem-dependent, so sorting makes the choice
#   reproducible across machines.
# - is_dir() filter: skips any number of stray files, not just one.
model_fn = next((p for p in sorted(models_dir.glob("*")) if p.is_dir()), None)
if model_fn is None:
    raise FileNotFoundError(f"No trait model directories found in {models_dir}")

# Resolve the best AutoGluon run inside the trait directory and load its predictor.
model = get_best_model_ag(model_fn)
print("Using model: ", model)
predictor = TabularPredictor.load(str(model))

Using model:  models/Shrub_Tree_Grass/001/splot_gbif/autogluon/debug/X11_mean/high_20240802_080920


In [3]:
# Validation scores of every model in the trained ensemble.
leaderboard = predictor.leaderboard(silent=True)
leaderboard

Unnamed: 0,model,score_val,eval_metric,pred_time_val,fit_time,pred_time_val_marginal,fit_time_marginal,stack_level,can_infer,fit_order
0,LightGBMXT_BAG_L1,-6.72649,root_mean_squared_error,39.58563,1387.272156,39.58563,1387.272156,1,True,1
1,WeightedEnsemble_L2,-6.72649,root_mean_squared_error,39.630754,1387.82867,0.045125,0.556515,2,True,3
2,WeightedEnsemble_L3,-6.72649,root_mean_squared_error,39.631998,1387.96595,0.046368,0.693794,3,True,6
3,LightGBM_BAG_L2,-6.742515,root_mean_squared_error,51.854211,1852.947598,5.637207,208.342802,2,True,5
4,LightGBMXT_BAG_L2,-6.74532,root_mean_squared_error,60.750352,2115.974113,14.533347,471.369317,2,True,4
5,LightGBM_BAG_L1,-6.831939,root_mean_squared_error,6.631375,257.33264,6.631375,257.33264,1,True,2


## Load inference data

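We will now load the inference data. Since our first concern is whether sample weights improve correlation with the sparse sPlot trait data, we only load inference data for grid cells where we know we also have sPlot observations. Note that these features are read from the training table (`get_train_fn`), so the correlations below are not a strictly out-of-sample estimate.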

In [4]:
# Grid cells with a sample weight of exactly 1 are the sPlot cells. Keep only their
# coordinates so the merge below adds no extra columns. The variable is deliberately
# not called `weights`, because that name is used later for latitude weights.
splot_cells = (
    pd.read_parquet(get_weights_fn(cfg))
    .reset_index()
    .loc[lambda df: df["weights"] == 1, ["x", "y"]]
)
feats = dd.read_parquet(get_train_fn(cfg))

# Inner-join the training features onto the sPlot cells. The small pandas frame is
# wrapped as a single-partition dask frame for the merge.
inference = (
    dd.merge(
        feats,
        dd.from_pandas(splot_cells, npartitions=1),
        on=["x", "y"],
        how="inner",
    )
    .compute()
    .reset_index(drop=True)
)
assert not inference.empty, "no training rows matched the sPlot cells"

In [5]:
# Peek at the first rows of the merged feature/trait table.
inference.head(n=5)

Unnamed: 0,x,y,ETH_GlobalCanopyHeightSD_2020_v1,ETH_GlobalCanopyHeight_2020_v1,sur_refl_b01_2001-2024_m10_mean,sur_refl_b01_2001-2024_m11_mean,sur_refl_b01_2001-2024_m12_mean,sur_refl_b01_2001-2024_m1_mean,sur_refl_b01_2001-2024_m2_mean,sur_refl_b01_2001-2024_m3_mean,...,X224_mean,X237_mean,X281_mean,X282_mean,X289_mean,X1080_mean,X3112_mean,X3113_mean,X3114_mean,X3120_mean
0,-175.345,-21.085,8.0,11.0,417.0,458.0,427.0,365.0,420.0,305.0,...,3.107818,13.240777,65.44162,555.043266,1538.228803,2005.007386,5243.341098,6803.984017,13709.344606,3.35982
1,-175.335,-21.105,9.0,15.0,545.0,564.0,667.0,575.0,530.0,465.0,...,2.840519,4.320474,41.47536,420.203514,939.41114,2959.057511,4245.471626,3023.614783,2894.802195,3.333049
2,-175.315,-21.165,10.0,14.0,431.0,567.0,468.0,482.0,434.0,286.0,...,3.258127,10.030351,54.316275,534.522312,1276.222195,2943.860841,4755.826745,4371.802828,7896.880415,3.755575
3,-175.305,-21.145,7.0,12.0,505.0,515.0,571.0,465.0,459.0,454.0,...,3.339594,11.461425,63.559805,492.36107,1299.387527,2077.620478,5787.809759,6588.912343,11914.602533,3.559118
4,-175.295,-21.155,7.0,7.0,511.0,570.0,567.0,480.0,457.0,436.0,...,3.150806,7.828414,51.288306,417.311839,932.916827,2347.639714,5849.237594,5938.183091,10485.350612,3.523358


## Make prediction

In [6]:
# Predict for every inference row, then re-attach the coordinates and index the
# result by (y, x) for rasterization in the next cell.
pred_values = predictor.predict(inference, as_pandas=True)
prediction = (
    pd.concat([inference[["x", "y"]], pred_values], axis=1)
    .set_index(["y", "x"])
)

## Rasterize prediction

In [7]:
# Write the prediction to <predict_dir>/debug/<trait dir name>.tif.
out_dir = get_predict_dir(cfg) / "debug"
out_dir.mkdir(parents=True, exist_ok=True)

out_fn = out_dir / "{}.tif".format(model.parent.stem)
raster = grid_df_to_raster(prediction, res=cfg.target_resolution, out=out_fn)

## Compare weighted vs unweighted predictions

Load **original sparse sPlot traits**.

In [8]:
# Extra columns left over by `.to_dataframe()` on the rasters; dropped before joining.
dropcols = "band spatial_ref".split()

In [9]:
# The sPlot trait map is named by the part of the model directory name before the
# first underscore (e.g. "X11" for "X11_mean").
trait_id = model_fn.stem.split("_")[0]
splot_fn = get_trait_maps_dir(cfg, "splot") / f"{trait_id}.tif"

splot_da = open_raster(splot_fn).sel(band=cfg.datasets.Y.trait_stat)
splot = splot_da.to_dataframe(name="splot").drop(columns=dropcols).dropna()

Load prediction made with **unweighted samples**.

In [10]:
# Prediction from the model trained without sample weights (band 1 of its raster).
unweighted_da = open_raster(get_predict_dir(cfg) / f"{model_fn.stem}.tif").sel(band=1)
unweighted = (
    unweighted_da.to_dataframe(name="unweighted").drop(columns=dropcols).dropna()
)

Load prediction made with **weighted samples**.

In [11]:
# Prediction written above from the sample-weighted model (band 1 of its raster).
weighted_da = open_raster(out_fn).sel(band=1)
weighted = weighted_da.to_dataframe(name="weighted").drop(columns=dropcols).dropna()

### Calculate sPlot correlations using Pearson's *r*.

Join the sPlot observations with each set of predictions.

In [12]:
# Restrict both comparisons to the same grid cells: those with an sPlot observation
# AND both predictions. Joining each prediction separately would let r be computed
# over different cell sets for the two models, making the correlations incomparable.
splot_both = splot.join(unweighted, how="inner").join(weighted, how="inner")
splot_unweighted = splot_both[["splot", "unweighted"]]
splot_weighted = splot_both[["splot", "weighted"]]

Calculate latitude weights.

In [13]:
# Distinct latitudes ("y" index level) of the joined sPlot/prediction frame.
# NOTE(review): derived from splot_unweighted only but also applied to
# splot_weighted — assumes both frames cover the same latitudes; confirm.
lat_unique = splot_unweighted.index.unique(level="y")
weights = lat_weights(lat_unique, cfg.target_resolution)

Finally, calculate and compare the latitude-weighted Pearson correlation coefficients of the unweighted- and weighted-sample models.

In [14]:
# Latitude-weighted Pearson r between sPlot observations and each model's prediction.
r_unweighted, r_weighted = (
    weighted_pearson_r(frame, weights) for frame in (splot_unweighted, splot_weighted)
)

for label, r in (("unweighted", r_unweighted), ("weighted", r_weighted)):
    print(f"r ({label}): {r}")

r (unweighted): 0.6013233545903542
r (weighted): 0.6573193892105225


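## Compare other performance metrics

Error metrics (RMSE, MSE, MAE, median absolute error, normalized RMSE) are reported negated, following AutoGluon's higher-is-better convention, so values closer to zero are better.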

In [15]:
# Evaluation summary of the corresponding non-debug model (trained without sample
# weights). The file must hold exactly two rows, mean then std; `.assign` raises
# otherwise.
unweighted_perf = (
    pd.read_csv(
        get_best_model_ag(get_models_dir(cfg) / model_fn.name) / cfg.train.eval_results
    )
    .assign(idx=["unw_mean", "unw_std"])
    .set_index("idx")
)

# Same summary for the weighted (debug) model.
# NOTE(review): this CSV is read with index_col=0 while the unweighted one is not —
# presumably the two files were written differently; confirm before unifying.
weighted_perf = (
    pd.read_csv(model / cfg.train.eval_results, index_col=0)
    .reset_index(drop=True)
    .assign(idx=["w_mean", "w_std"])
    .set_index("idx")
)

# Interleave rows as unw_mean, w_mean, unw_std, w_std by sorting on the statistic
# suffix. A stable sort guarantees the unweighted row stays before the weighted one.
pd.concat([unweighted_perf, weighted_perf]).sort_index(
    key=lambda idx: idx.str.split("_").str[1], kind="stable"
)

Unnamed: 0_level_0,root_mean_squared_error,mean_squared_error,mean_absolute_error,r2,pearsonr,median_absolute_error,norm_root_mean_squared_error
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
unw_mean,-6.559402,-43.043923,-4.826946,0.246041,0.496308,-3.716623,-0.111703
w_mean,-5.905897,-34.911147,-4.306712,0.262391,0.466459,-3.276779,-0.100575
unw_std,0.142945,1.868891,0.127109,0.025733,0.026194,0.133616,0.002434
w_std,0.187178,2.233136,0.143684,0.029128,0.018593,0.139431,0.003188
