In [None]:
import logging
import sys
import warnings
from itertools import product

import matplotlib as mpl
import numpy as np
import pandas as pd
from loguru import logger as loguru_logger

from empirical_fire_modelling import variable
from empirical_fire_modelling.analysis.model_combinations import (
    cached_multiple_combinations,
)
from empirical_fire_modelling.configuration import Experiment, n_splits
from empirical_fire_modelling.data import get_experiment_split_data
from empirical_fire_modelling.logging_config import enable_logging
from empirical_fire_modelling.utils import tqdm

# Apply the plotting defaults stored in the local 'matplotlibrc' file.
mpl.rc_file("matplotlibrc")

# Enable loguru messages from 'alepython', then swap the existing loguru
# handler(s) for a single stderr sink that only reports WARNING and above.
loguru_logger.enable("alepython")
loguru_logger.remove()
loguru_logger.add(sys.stderr, level="WARNING")

# Standard-library logging for this notebook, also limited to WARNING.
logger = logging.getLogger(__name__)
enable_logging(level="WARNING")

# Silence known warnings that would otherwise clutter the cell outputs.
# The patterns are registered in the order listed.
for ignored_message in (
    ".*Collapsing a non-contiguous coordinate.*",
    ".*DEFAULT_SPHERICAL_EARTH_RADIUS.*",
    ".*guessing contiguous bounds.*",
    'Setting feature_perturbation = "tree_path_dependent".*',
):
    warnings.filterwarnings("ignore", ignored_message)

In [None]:
# Load the train / test split that contains every candidate variable.
get_experiment_split_data.check_in_store(Experiment.ALL)
X_train, X_test, y_train, y_test = get_experiment_split_data(Experiment.ALL)

# Temporal shifts to consider; each one has to be a known lag.
shifts = (0, 1, 3, 6, 9)
assert all(shift in variable.lags for shift in shifts)

# One inner tuple per shift, holding every vegetation variable at that shift.
vegetation_factories = variable.feature_categories[variable.Category.VEGETATION]
veg_lags = tuple(
    tuple(factory[shift] for factory in vegetation_factories) for shift in shifts
)

# Every shifted vegetation variable must be present in both data splits.
shifted_veg_features = [feature for lagged in veg_lags for feature in lagged]
assert all(feature in X_train for feature in shifted_veg_features)
assert all(feature in X_test for feature in shifted_veg_features)

# Variables shared by every combination (order is kept as-is).
common_vars = (
    variable.DRY_DAY_PERIOD[0],
    variable.MAX_TEMP[0],
    variable.PFT_CROP[0],
    variable.DRY_DAY_PERIOD[1],
    variable.DRY_DAY_PERIOD[3],
    variable.DRY_DAY_PERIOD[9],
    variable.POPD[0],
    variable.DRY_DAY_PERIOD[6],
    variable.LIGHTNING[0],
    variable.DIURNAL_TEMP_RANGE[0],
)

# Each combination is the common variables followed by one vegetation variable
# chosen for each shift (Cartesian product over the per-shift options).
combinations = [common_vars + veg_choice for veg_choice in product(*veg_lags)]

# 10 common variables + 1 vegetation variable for each of the 5 shifts.
assert all(len(combination) == 15 for combination in combinations)

In [None]:
# Retrieve the cached scores for every variable combination and CV split.

# Only the first and third elements of the split data are needed here; the
# other two are discarded.
get_experiment_split_data.check_in_store(Experiment.ALL)
X_all, _, y, _ = get_experiment_split_data(Experiment.ALL)
combined_scores = cached_multiple_combinations(
    X_all,
    y,
    combinations,
    range(n_splits),
)

In [None]:
# Summarise the cross-validation R2 scores of every variable combination.
#
# `processed_scores` maps a label (the combination's variables that are not in
# `common_vars`, joined by ", ") to the mean and standard deviation of its
# train and test R2 values across the CV splits.
#
# NOTE: relies on `import numpy as np` being present in the import cell.
processed_scores = {}
for variables, combination_data in tqdm(combined_scores.items()):
    # Pool the R2 values from all CV splits of this combination.
    test_r2s = []
    train_r2s = []

    for cv_data in combination_data.values():
        for key, val in cv_data.items():
            # Entries are told apart by a substring of their key.
            if "test_score" in key:
                test_r2s.append(val["r2"])
            elif "train_score" in key:
                train_r2s.append(val["r2"])

    # Exactly one train and one test score is expected per split.
    assert len(test_r2s) == len(train_r2s) == n_splits, (
        f"Expected {n_splits} train and test scores, "
        f"got {len(train_r2s)} train and {len(test_r2s)} test scores."
    )

    # The common variables are shared by all combinations, so only the
    # remaining variables are needed to identify this one.
    label = ", ".join(str(v) for v in variables if v not in common_vars)
    processed_scores[label] = {
        "test_mean": np.mean(test_r2s),
        "test_std": np.std(test_r2s),
        "train_mean": np.mean(train_r2s),
        "train_std": np.std(train_r2s),
    }

In [None]:
# One row per combination label, ordered from highest to lowest mean test R2.
score_df = (
    pd.DataFrame(processed_scores)
    .transpose()
    .sort_values(by="test_mean", ascending=False)
)
score_df

In [None]:
# For each vegetation variable name, gather the mean test R2 of every
# combination whose label contains that name, then compare the resulting
# distributions with a boxplot.
agg_data = {
    veg_name: score_df.loc[
        [veg_name in label for label in score_df.index], "test_mean"
    ]
    for veg_name in ("FAPAR", "LAI", "SIF", "VOD")
}
_ = pd.DataFrame(agg_data).boxplot()