# 2.0.0: Coefficient of Variation from cross-validation

We can measure model performance using several metrics. To assess model reliability in terms of the reference data, we can calculate average $RMSE$ and $R^2$ values across all folds. To provide further indication of model robustness, however, we can also calculate the pixel-wise coefficient of variation (CoV), or the ratio of the standard deviation  to the mean  across all folds, defined as:

$$CoV = \frac{\sigma}{\mu}$$

This, paired with Meyer and Pebesma's Area of Applicability (2022), can give us a better idea of where our extrapolated trait maps are more or less reliable.

## Imports and config

In [1]:
from pathlib import Path

import joblib
import pandas as pd
from autogluon.tabular import TabularPredictor
from tqdm import trange

from src.conf.conf import get_config
from src.conf.environment import log
from src.utils.autogluon_utils import get_best_model_ag

cfg = get_config()

## Load CV models

### Get best base predictor

First we load the AutoGluon `TabularPredictor` in order to retrieve the individual models from each cross-validation fold. But there's a catch: to get the best performance at inference time, our final models are actually ensemble models which don't contain sub-models for each CV fold. Instead, we'll need to identify the best-performing base model from the ensemble and generate trait predictions from each of its CV fold sub-models. This should then provide a fairly conservative CoV which the ensemble model should actually outperform slightly.

In [2]:
predictor = TabularPredictor.load(
    "models/Shrub_Tree_Grass/001/splot_gbif/autogluon/X11_mean/good_20240607_220933"
)

In [3]:
predictor.leaderboard()

Unnamed: 0,model,score_val,eval_metric,pred_time_val,fit_time,pred_time_val_marginal,fit_time_marginal,stack_level,can_infer,fit_order
0,WeightedEnsemble_L2,-6.585418,root_mean_squared_error,579.164635,6507.011832,0.031631,0.452263,2,True,5
1,WeightedEnsemble_L3,-6.585418,root_mean_squared_error,579.164829,6507.021219,0.031826,0.461649,3,True,6
2,LightGBMXT_BAG_L1,-6.587197,root_mean_squared_error,480.394939,2224.356663,480.394939,2224.356663,1,True,1
3,LightGBM_BAG_L1,-6.598851,root_mean_squared_error,93.548172,725.426548,93.548172,725.426548,1,True,2
4,ExtraTreesMSE_BAG_L1,-6.626949,root_mean_squared_error,5.189893,3556.776358,5.189893,3556.776358,1,True,4
5,CatBoost_BAG_L1,-6.687462,root_mean_squared_error,0.478468,128.84223,0.478468,128.84223,1,True,3
6,WeightedEnsemble_L3_FULL,,root_mean_squared_error,,1110.229789,,0.461649,3,True,12
7,WeightedEnsemble_L2_FULL,,root_mean_squared_error,,1110.220403,,0.452263,2,True,11
8,LightGBM_BAG_L1_FULL,,root_mean_squared_error,,138.430075,,138.430075,1,True,8
9,LightGBMXT_BAG_L1_FULL,,root_mean_squared_error,,574.082538,,574.082538,1,True,7


Get the best predictor that is not an ensemble model (i.e. `stack_level == 1`).

In [4]:
best_model = (
    predictor.leaderboard(refit_full=False)
    .pipe(lambda df: df[df["stack_level"] == 1])
    .pipe(lambda df: df.loc[df["score_val"].idxmax()])
    .model
)

cv_models_dir = Path(predictor.path, "models", str(best_model))

## Load inference data

In [5]:
predict_fn: Path = (
    Path(cfg.train.dir)
    / cfg.eo_data.predict.dir
    / cfg.model_res
    / cfg.eo_data.predict.filename
)

data = pd.read_parquet(predict_fn)
xy = data[["x", "y"]]
data = data.drop(columns=["x", "y"])

In [7]:
BATCHES = 1

# Calculate batch size
batch_size = len(data) // BATCHES + (len(data) % BATCHES > 0)

# Initialize an empty list to store batch predictions
covs = []

# Predict in batches
log.info("Predicting in batches...")
for i in trange(0, len(data), batch_size):
    batch = data.iloc[i : i + batch_size]
    batch_predictions = []

    for submodel in cv_models_dir.iterdir():
        if not submodel.stem.startswith("S1"):
            continue
        log.info("Predicting with %s", submodel.stem)
        sub_predictor = joblib.load(str(submodel / "model.pkl"))
        batch_predictions.append(sub_predictor.predict(batch))

    log.info("Calculating coefficient of variation for batch %s...", i)
    # Calculate coefficient of variation across all submodel predictions
    batch_cov = pd.concat(batch_predictions).std(axis=1) / pd.concat(
        batch_predictions
    ).mean(axis=1)
    covs.append(batch_cov)

# Concatenate all batch predictions
full_cov = pd.concat(covs)

2024-06-21 11:03:00 CEST - src.conf.environment - INFO - Predicting in batches...
  0%|          | 0/1 [00:00<?, ?it/s]2024-06-21 11:03:00 CEST - src.conf.environment - INFO - Predicting with S1F4


2024-06-21 11:14:58 CEST - src.conf.environment - INFO - Predicting with S1F8
2024-06-21 11:27:12 CEST - src.conf.environment - INFO - Predicting with S1F5
2024-06-21 11:38:07 CEST - src.conf.environment - INFO - Predicting with S1F10
2024-06-21 11:50:36 CEST - src.conf.environment - INFO - Predicting with S1F2
2024-06-21 12:02:55 CEST - src.conf.environment - INFO - Predicting with S1F1
2024-06-21 12:13:22 CEST - src.conf.environment - INFO - Predicting with S1F3
2024-06-21 12:26:26 CEST - src.conf.environment - INFO - Predicting with S1F6
2024-06-21 12:36:03 CEST - src.conf.environment - INFO - Predicting with S1F7
2024-06-21 12:49:13 CEST - src.conf.environment - INFO - Predicting with S1F9
2024-06-21 13:02:13 CEST - src.conf.environment - INFO - Calculating coefficient of variation for batch 0...
  0%|          | 0/1 [1:59:12<?, ?it/s]


TypeError: cannot concatenate object of type '<class 'numpy.ndarray'>'; only Series and DataFrame objs are valid

In [23]:
batch_prediction_dfs = []
for i, batch_prediction in enumerate(batch_predictions):
    batch_prediction_dfs.append(pd.DataFrame(batch_prediction, columns=[f"prediction_{i}"]))

full_predictions = pd.concat(batch_prediction_dfs, axis=1)
cov = full_predictions.std(axis=1) / full_predictions.mean(axis=1)

In [24]:
cov

0            0.016514
1            0.014311
2            0.023106
3            0.019038
4            0.022063
               ...   
134187201    0.022066
134187202    0.033093
134187203    0.019086
134187204    0.009998
134187205    0.012896
Length: 134187206, dtype: float32

In [None]:
# TODO: #9 Plot CoV