In [None]:
import logging
import sys
import warnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from loguru import logger as loguru_logger

from empirical_fire_modelling.configuration import Experiment
from empirical_fire_modelling.data import get_data, get_experiment_split_data
from empirical_fire_modelling.logging_config import enable_logging
from empirical_fire_modelling.model import get_model
from empirical_fire_modelling.plotting import cube_plotting
from empirical_fire_modelling.utils import get_mm_data

# Apply the Matplotlib settings from the rc file one directory up.
mpl.rc_file("../matplotlibrc")

# loguru: enable messages from "alepython", then swap the existing handlers
# for a single stderr sink that only shows WARNING and above.
loguru_logger.enable("alepython")
loguru_logger.remove()
loguru_logger.add(sys.stderr, level="WARNING")

# Standard-library logging for this notebook, also at WARNING level.
logger = logging.getLogger(__name__)
enable_logging(level="WARNING")

# Silence known, noisy warnings; each entry is a regular expression matched
# against the start of the warning message.
for ignored_message in (
    ".*Collapsing a non-contiguous coordinate.*",
    ".*DEFAULT_SPHERICAL_EARTH_RADIUS.*",
    ".*guessing contiguous bounds.*",
    'Setting feature_perturbation = "tree_path_dependent".*',
):
    warnings.filterwarnings("ignore", ignored_message)


def get_experiment_prediction(experiment, **kwargs):
    """Return the cached model's predictions on the held-out (validation) split.

    Only previously cached artefacts are used: the split data for
    `experiment` is checked against the store before loading, and the model
    is requested with `cache_check=True` before it is actually retrieved.

    Parameters
    ----------
    experiment : Experiment
        Experiment whose split data and fitted model are used.
    **kwargs
        Accepted but not used by this function.

    Returns
    -------
    The output of ``model.predict`` applied to the test features.

    """
    # Verify the split data is cached before loading it (presumably this
    # raises when the data is not in the store - confirm).
    get_experiment_split_data.check_in_store(experiment)
    train_X, test_X, train_y, _ = get_experiment_split_data(experiment)

    # Verify a fitted model is cached before retrieving it, so that no
    # fitting is triggered from this notebook.
    get_model(train_X, train_y, cache_check=True)
    return get_model(train_X, train_y).predict(test_X)

### Get reference data

In [None]:
# Reference data for the CURR experiment; `master_mask` is reused below when
# mapping predictions and observations with `get_mm_data`.
endog_data, exog_data, master_mask, masked_datasets, land_mask = get_data(
    Experiment.CURR
)

### Get predictions

In [None]:
# Validation predictions for the CURR and ALL experiments, each passed through
# `get_mm_data` with the shared `master_mask` and the "val" selection.
curr_pred, all_pred = (
    get_mm_data(get_experiment_prediction(experiment), master_mask, "val")
    for experiment in (Experiment.CURR, Experiment.ALL)
)

### Compare predictions

In [None]:
# Held-out target values of the CURR experiment, processed exactly like the
# predictions so that the arrays can be compared element-wise.
(X_train, X_test, y_train, y_test) = get_experiment_split_data(Experiment.CURR)
obs_ba = get_mm_data(
    y_test.values,
    master_mask,
    "val",
)

In [None]:
# Shapes of the observations and both predictions, displayed side by side.
tuple(data.shape for data in (obs_ba, curr_pred, all_pred))

In [None]:
# Maps of the mean over axis 0 (presumably time - confirm) for the
# observations and for each experiment's predictions. Every figure gets a
# title so the three otherwise similar maps can be told apart.
for title, arr in (
    ("Obs", obs_ba),
    ("Pred(CURR)", curr_pred),
    ("Pred(ALL)", all_pred),
):
    cube_plotting(np.mean(arr, axis=0), log=True, title=title)

### Plot errors

In [None]:
# Maps of the mean signed error (prediction minus observation, averaged over
# axis 0) for each experiment. The title names the experiment so the two
# figures are distinguishable.
for title, arr in (
    ("Pred(CURR) - Obs", curr_pred),
    ("Pred(ALL) - Obs", all_pred),
):
    cube_plotting(np.mean(arr - obs_ba, axis=0), log=True, title=title)

### Compare errors

In [None]:
# Signed errors (prediction minus observation) for the CURR and ALL experiments.
curr_err, all_err = (pred - obs_ba for pred in (curr_pred, all_pred))

In [None]:
# Difference in absolute error between the two experiments, averaged over
# axis 0 and divided by the mean observation (also over axis 0).
abs_err_diff = np.abs(curr_err) - np.abs(all_err)
rel_mean_abs_err_diff = np.mean(abs_err_diff, axis=0) / np.mean(obs_ba, axis=0)

_ = cube_plotting(
    rel_mean_abs_err_diff,
    title="<|Err(CURR)| - |Err(ALL)|> / <Ob>",
    fig=plt.figure(dpi=300),
    boundaries=[-0.5, 0, 0.5],
)

In [None]:
# Histogram of the difference in absolute error between the two experiments
# (mean over axis 0); positive values mean CURR has the larger absolute error.
#
# `plt.hist` does not support masked arrays, so masked entries are dropped
# explicitly (assumes the error arrays may be masked, given the `master_mask`
# passed to `get_mm_data` - confirm). For plain arrays `np.ma.compressed`
# simply flattens, matching the previous `.ravel()`.
plt.hist(
    np.ma.compressed(np.mean(np.abs(curr_err) - np.abs(all_err), axis=0)),
    bins=np.linspace(-0.2, 0.2, 80),
)
# Log-scaled counts keep the sparsely populated tails visible.
plt.yscale("log")
_ = plt.title("<|Err(CURR)| - |Err(ALL)|>")

In [None]:
# Histogram of the negated absolute-error difference, i.e. the mirror image of
# <|Err(CURR)| - |Err(ALL)|>; positive values mean ALL has the larger absolute
# error. The title states the quantity that is actually plotted.
#
# `plt.hist` does not support masked arrays, so masked entries are dropped
# explicitly (assumes the error arrays may be masked, given the `master_mask`
# passed to `get_mm_data` - confirm). For plain arrays `np.ma.compressed`
# simply flattens, matching the previous `.ravel()`.
plt.hist(
    -np.ma.compressed(np.mean(np.abs(curr_err) - np.abs(all_err), axis=0)),
    bins=np.linspace(-0.2, 0.2, 80),
    color="C1",
)
# Log-scaled counts keep the sparsely populated tails visible.
plt.yscale("log")
_ = plt.title("<|Err(ALL)| - |Err(CURR)|>")