## Initialisation

In [None]:
import logging
import os
import re
import sys
import warnings
from collections import namedtuple
from functools import reduce
from itertools import combinations
from operator import mul

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import seaborn as sns
from joblib import Memory, Parallel, delayed
from loguru import logger as loguru_logger
from matplotlib.patches import Rectangle
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score, train_test_split
from tqdm import tqdm

import wildfires.analysis
from alepython import ale_plot
from alepython.ale import _second_order_ale_quant
from wildfires.analysis import *
from wildfires.dask_cx1 import get_parallel_backend
from wildfires.data import *
from wildfires.logging_config import enable_logging
from wildfires.qstat import get_ncpus
from wildfires.utils import *

# Route loguru records from `alepython` to stderr, showing only WARNING and above.
# `remove()` without arguments drops all previously registered loguru handlers
# before the new stderr sink is added.
loguru_logger.enable("alepython")
loguru_logger.remove()
loguru_logger.add(sys.stderr, level="WARNING")

logger = logging.getLogger(__name__)

# Project helper - presumably configures the standard `logging` module for
# notebook use (TODO confirm).
enable_logging("jupyter")

# Silence known noisy warnings. The second argument is a regular expression that is
# matched against the start of the warning message.
# NOTE(review): the trailing `*` in the last two patterns quantifies the preceding
# character instead of acting as a wildcard; the patterns still match the intended
# messages.
warnings.filterwarnings("ignore", ".*Collapsing a non-contiguous coordinate.*")
warnings.filterwarnings("ignore", ".*DEFAULT_SPHERICAL_EARTH_RADIUS*")
warnings.filterwarnings("ignore", ".*guessing contiguous bounds*")

# Global plotting defaults.
normal_coast_linewidth = 0.5
mpl.rc("figure", figsize=(14, 6))
mpl.rc("font", size=9.0)

# Helper used throughout the notebook to save figures below ~/tmp.
figure_saver = FigureSaver(
    directories=os.path.join("~", "tmp", "analysis_time_lags_explain_pdp_ale"),
    debug=True,
)
# Cache handle and directory for the cached computations further below.
memory = get_memory("analysis_time_lags_explain_pdp_ale", verbose=100)
CACHE_DIR = os.path.join(DATA_DIR, ".pickle", "time_lags_explain_pdp_ale")

## Overwrite wildfires get_data with our own personalised version

In [None]:
from get_time_lag_data import get_data

In [None]:
# Keyword arguments for switching an axis to a symmetric log scale. The x and y
# variants differ only in the axis-specific keyword names.
value = "symlog"
linthres = 1e-2
subs = list(range(2, 10))
log_xscale_kwargs = {"value": value, "linthreshx": linthres, "subsx": subs}
log_yscale_kwargs = {"value": value, "linthreshy": linthres, "subsy": subs}

# A feature is plotted on a log scale if any of these (lower-case) substrings
# occurs in its lower-cased name.
log_vars = (
    "dry day period",
    "popd",
    "agb tree",
    "cape x precip",
    "lai",
    "shruball",
    "pftherb",
    "pftcrop",
    "treeall",
)

In [None]:
def ale_2d(predictor, train_set, features, bins=40, coverage=1):
    """Plot and save the second-order ALE of two features.

    Args:
        predictor: Prediction callable forwarded to `_second_order_ale_quant`.
        train_set: Training data. If `coverage < 1`, only the leading
            `int(train_set.shape[0] * coverage)` rows are used.
        features: The two feature (column) names to analyse. They are used for
            the axis labels, the title and the name of the saved figure.
        bins (int): Number of bins forwarded to `_second_order_ale_quant`.
        coverage (float): Fraction of (leading) rows of `train_set` to use.

    Returns:
        tuple: `(ale, quantiles_list, samples_grid)` as returned by
        `_second_order_ale_quant`.

    """
    if coverage < 1:
        # This should be ok if `train_set` is randomised, as it usually is.
        train_set = train_set[: int(train_set.shape[0] * coverage)]
    ale, quantiles_list, samples_grid = _second_order_ale_quant(
        predictor, train_set, features, bins=bins, return_samples_grid=True
    )
    fig, ax = plt.subplots()
    centres_list = [get_centres(quantiles) for quantiles in quantiles_list]
    n_x, n_y = 50, 50
    x = np.linspace(centres_list[0][0], centres_list[0][-1], n_x)
    y = np.linspace(centres_list[1][0], centres_list[1][-1], n_y)

    X, Y = np.meshgrid(x, y, indexing="xy")
    # NOTE(review): `interp2d` is deprecated since SciPy 1.10 and removed in 1.14;
    # this call needs porting before upgrading SciPy.
    ale_interp = scipy.interpolate.interp2d(centres_list[0], centres_list[1], ale.T)
    CF = ax.contourf(X, Y, ale_interp(x, y), cmap="bwr", levels=30, alpha=0.7)

    # Do not autoscale, so that boxes at the edges (contourf only plots the bin
    # centres, not their edges) don't enlarge the plot. Such boxes include markings for
    # invalid cells, or hatched boxes for valid cells.
    plt.autoscale(False)

    def cell_patch(i, j, **patch_kwargs):
        """Return a rectangle covering the quantile bin with indices (i, j)."""
        x0, x1 = quantiles_list[0][i], quantiles_list[0][i + 1]
        y0, y1 = quantiles_list[1][j], quantiles_list[1][j + 1]
        return Rectangle([x0, y0], x1 - x0, y1 - y0, **patch_kwargs)

    # Add hatching for the significant cells. These have at least `min_samples`
    # samples. By default, calculate this as the number of samples in each bin if
    # everything was equally distributed, divided by 10.
    min_samples = (train_set.shape[0] / reduce(mul, map(len, centres_list))) / 10
    for i, j in zip(*np.where(samples_grid >= min_samples)):
        ax.add_patch(cell_patch(i, j, linewidth=0, fill=None, hatch=".", alpha=0.4))

    if np.any(ale.mask):
        # Add rectangles to indicate cells without samples.
        for i, j in zip(*np.where(ale.mask)):
            ax.add_patch(
                cell_patch(
                    i, j, linewidth=1, edgecolor="k", facecolor="none", alpha=0.4
                )
            )

    fig.colorbar(CF, format="%.0e")
    ax.set_xlabel(features[0])
    ax.set_ylabel(features[1])
    nbins_str = "x".join([str(len(centres)) for centres in centres_list])
    ax.set_title(
        f"Second-order ALE of features {features[0]} and {features[1]}\n"
        f"Bins: {nbins_str} (Hatching: Sig., Boxes: Invalid)"
    )

    if any(log_var.lower() in features[0].lower() for log_var in log_vars):
        ax.set_xscale(**log_xscale_kwargs)
    if any(log_var.lower() in features[1].lower() for log_var in log_vars):
        ax.set_yscale(**log_yscale_kwargs)

    # Name the saved figure after the features actually plotted, rather than after
    # whatever a module-level `columns` variable happens to contain.
    figure_saver.save_figure(fig, "__".join(features), sub_directory="2d_ale_low")
    return ale, quantiles_list, samples_grid

## Creating the Data Structures used for Fitting

In [None]:
# Temporal offsets (in months) requested from `get_data` below.
shift_months = [1, 3, 6, 9, 12, 18, 24]

# Previously used variable selections, kept for reference. `selection_variables=None`
# is passed below instead.
# selection_variables = (
#     "VOD Ku-band -3 Month",
#     # "SIF",  # Fix regridding!!
#     "VOD Ku-band -1 Month",
#     "Dry Day Period -3 Month",
#     "FAPAR",
#     "pftHerb",
#     "LAI -1 Month",
#     "popd",
#     "Dry Day Period -24 Month",
#     "pftCrop",
#     "FAPAR -1 Month",
#     "FAPAR -24 Month",
#     "Max Temp",
#     "Dry Day Period -6 Month",
#     "VOD Ku-band -6 Month",
# )

# ext_selection_variables = selection_variables + (
#     "Dry Day Period -1 Month",
#     "FAPAR -6 Month",
#     "ShrubAll",
#     "SWI(1)",
#     "TreeAll",
# )
from ipdb import launch_ipdb_on_exception

# Drop into the ipdb post-mortem debugger if loading the data raises.
with launch_ipdb_on_exception():
    # The `e_s_` prefix presumably stands for "extended shifted" (cf. the model
    # docstring further below) - TODO confirm. `e_s_exog_data` holds the predictor
    # columns and `e_s_endog_data` the target, judging by their later use in
    # `train_test_split`.
    (
        e_s_endog_data,
        e_s_exog_data,
        e_s_master_mask,
        e_s_filled_datasets,
        e_s_masked_datasets,
        e_s_land_mask,
    ) = get_data(shift_months=shift_months, selection_variables=None)

### Offset data from 12 or more months before the current month in order to ease analysis (interpretability).
We are interested in the trends in these properties, not their absolute values; therefore, we subtract a recent 'seasonal cycle' analogue.
This hopefully avoids capturing the same relationships for a variable and its 12 month counterpart due to their high correlation.

In [None]:
# Replace every column shifted by 12 or more months with its difference to the
# column shifted by the same offset modulo 12 months, e.g.
# "FAPAR -18 Month" -> "FAPAR -18 - -6 Month" = "FAPAR -18 Month" - "FAPAR -6 Month", and
# "FAPAR -24 Month" -> "FAPAR -24 - 0 Month"  = "FAPAR -24 Month" - "FAPAR".

# Matches the (negative) month offset embedded in a shifted column name, e.g. "-12".
offset_pattern = re.compile(r"-\d{1,2}")

to_delete = []
# Iterate over a snapshot of the column labels, because new columns are added to
# `e_s_exog_data` inside the loop.
for column in list(e_s_exog_data):
    match = offset_pattern.search(column)
    if match is None:
        # Not a shifted variable.
        continue
    span = match.span()
    original_offset = int(match.group())
    if original_offset > -12:
        # Only shift months that are 12 or more months before the current month.
        continue
    # Offset of the comparison column: same offset modulo 12 months (0 means the
    # unshifted variable).
    comp = -(-original_offset % 12)

    # The offset is separated from the rest of the name by single spaces, which
    # are excluded here and re-inserted by `" ".join`.
    prefix = column[: span[0] - 1]
    suffix = column[span[1] + 1 :]

    # Change the string to reflect the shift.
    new_column = " ".join((prefix, f"{original_offset} - {comp}", suffix))
    if comp == 0:
        comp_column = prefix
    else:
        comp_column = " ".join((prefix, f"{comp}", suffix))
    print(column, comp_column)
    e_s_exog_data[new_column] = e_s_exog_data[column] - e_s_exog_data[comp_column]
    to_delete.append(column)

# Only remove the original columns at the very end, since they may still be needed
# as comparison columns while the loop above runs.
for column in to_delete:
    del e_s_exog_data[column]

## Cached Model Fitting
If anything about the data preparation above changes, the cache has to be refreshed using memory.clear()!

In [None]:
# Bundles the train/test split, both R2 scores and the fitted estimator.
ModelResults = namedtuple(
    "ModelResults",
    ["X_train", "X_test", "y_train", "y_test", "r2_test", "r2_train", "model"],
)

model_cache = SimpleCache("rf_model", cache_dir=CACHE_DIR)


@model_cache
def get_time_lags_model():
    """Train a random forest on the extended shifted data.

    The data is split 70/30 into training and test sets (shuffled, fixed seed)
    before a 200-tree forest is fitted on the training portion.

    Returns:
        ModelResults: A namedtuple with the fields 'X_train', 'X_test', 'y_train',
        'y_test', 'r2_test', 'r2_train', and 'model'.

    """
    split = train_test_split(
        e_s_exog_data, e_s_endog_data, test_size=0.3, shuffle=True, random_state=1
    )
    X_train, X_test, y_train, y_test = split

    forest = RandomForestRegressor(
        n_estimators=200, max_depth=None, n_jobs=get_ncpus(), random_state=1,
    )
    forest.fit(X_train, y_train)

    return ModelResults(
        X_train=X_train,
        X_test=X_test,
        y_train=y_train,
        y_test=y_test,
        r2_test=forest.score(X_test, y_test),
        r2_train=forest.score(X_train, y_train),
        model=forest,
    )


model_results = get_time_lags_model()
# The (possibly cached) model may carry a different `n_jobs`, so make sure it uses
# all cores available to the current job.
model_results.model.n_jobs = get_ncpus()

## R2 Scores

In [None]:
# Report the coefficients of determination stored alongside the model.
print("R2 train:", model_results.r2_train)
print("R2 test:", model_results.r2_test)

## Aggregated Feature Importances - Scikit-learn Gini Importances (eli5.explain_weights(_df) does the exact same thing)

In [None]:
# Forest-level feature importances labelled with the training column names and
# ordered from most to least important.
gini_mean = pd.Series(
    data=model_results.model.feature_importances_,
    index=model_results.X_train.columns,
)
gini_mean = gini_mean.sort_values(ascending=False)

## ELI5 Permutation Importances (PFI)

In [None]:
import cloudpickle
import eli5
from eli5.sklearn import PermutationImportance

from wildfires.dask_cx1 import get_parallel_backend

perm_importance_cache = SimpleCache(
    "perm_importance", cache_dir=CACHE_DIR, pickler=cloudpickle
)

# Does not seem to work with the dask parallel backend - it gets bypassed and every available core on the machine is used up
# if attempted.


@perm_importance_cache
def get_perm_importance():
    """Fit eli5's `PermutationImportance` for the trained model on the training data.

    Returns:
        The fitted `PermutationImportance` instance.

    """
    # Imported locally because `parallel_backend` is not among the names this
    # notebook imports explicitly from joblib, so the call below would otherwise
    # depend on one of the star imports happening to provide it.
    from joblib import parallel_backend

    with parallel_backend("threading", n_jobs=get_ncpus()):
        return PermutationImportance(model_results.model).fit(
            model_results.X_train, model_results.y_train
        )


perm_importance = get_perm_importance()
# Tabulate the importances using the training column names as feature names.
perm_df = eli5.explain_weights_df(
    perm_importance, feature_names=list(model_results.X_train.columns)
)

### Brute Force LOCO (leave one column out) by retraining the model with the relevant column(s) removed

In [None]:
import sklearn.base
from sklearn.metrics import mean_squared_error

from wildfires.dask_cx1 import *


def simple_loco(est, X_train, y_train, leave_out=()):
    """Refit a fresh copy of `est` without certain columns and score it.

    Args:
        est: Estimator object with `fit()` and `predict()` methods. It is cloned,
            so the given instance is not fitted or modified.
        X_train (pandas DataFrame): DataFrame containing the training data. A copy
            is modified, the original is left untouched.
        y_train (pandas Series or array-like): Target data.
        leave_out (iterable of column names): Column names to exclude.

    Returns:
        mse: Mean squared error of the training set predictions.

    """
    # Unfitted estimator with the same parameters.
    fresh_est = sklearn.base.clone(est)

    # Remove the requested columns from a private copy of the data.
    reduced_X = X_train.copy()
    for name in leave_out:
        del reduced_X[name]

    fresh_est.fit(reduced_X, y_train)
    return mean_squared_error(y_true=y_train, y_pred=fresh_est.predict(reduced_X))


loco_cache = SimpleCache("loco_mses", cache_dir=CACHE_DIR)


@loco_cache
def get_loco_mses():
    """Compute training MSEs with each single column left out in turn.

    Returns:
        tuple: `(leave_out_columns, mse_values)`, where `leave_out_columns[i]` is
        the list of columns left out for `mse_values[i]`. The first entry is the
        empty list, i.e. the baseline using all columns.

    """
    # Baseline (nothing left out) first, followed by one entry per column.
    leave_out_columns = [[]] + [[name] for name in model_results.X_train.columns]

    unfitted_model = sklearn.base.clone(model_results.model)

    with get_parallel_backend(fallback=False):
        tasks = (
            delayed(simple_loco)(
                unfitted_model, model_results.X_train, model_results.y_train, left_out
            )
            for left_out in tqdm(leave_out_columns, desc="Prefetch LOCO columns")
        )
        mse_values = Parallel(verbose=10)(tasks)
    return leave_out_columns, mse_values


leave_out_columns, mse_values = get_loco_mses()

In [None]:
from warnings import warn

# Turn the raw LOCO MSE values into importances: the increase in training MSE
# relative to the baseline (first entry, no columns left out).
mse_values = np.asarray(mse_values)
# The baseline must come first, since it is split off by position below.
assert leave_out_columns[0] == []
# Label each entry by its left-out column(s). Only the labels after "baseline" are
# used for the index below.
loco_columns = ["baseline"] + ["_".join(columns) for columns in leave_out_columns[1:]]
baseline_mse = mse_values[0]
loco_importances = pd.Series(
    mse_values[1:] - baseline_mse, index=loco_columns[1:]
).sort_values(ascending=False)

# A negative importance means that the model fitted without that column had a lower
# training MSE than the baseline.
if np.any(loco_importances < 0):
    warn("MSE values without some features were lower than baseline.")

### Comparing the three measures - Gini vs PFI vs LOCO

In [None]:
# Side-by-side table of the three importance measures. Each measure contributes a
# (Feature, Importance) column pair that is sorted by that measure's own ranking,
# so the rows are NOT aligned by feature across measures.
comp_import_df = pd.DataFrame(
    np.hstack(
        (
            gini_mean.index.values[:, np.newaxis],
            gini_mean.values[:, np.newaxis],
            perm_df["feature"].values[:, np.newaxis],
            perm_df["weight"].values[:, np.newaxis],
            loco_importances.index.values[:, np.newaxis],
            loco_importances.values[:, np.newaxis],
        )
    ),
    columns=[["Gini"] * 2 + ["PFI"] * 2 + ["LOCO"] * 2, ["Feature", "Importance"] * 3],
)

fig, axes = plt.subplots(1, 3, figsize=(23, 15))
for ax, measure in zip(axes, ("Gini", "PFI", "LOCO")):
    # Reverse the order so that the most important feature ends up at the top of
    # the horizontal bar chart.
    features = list(comp_import_df[(measure, "Feature")])[::-1]
    # The stacked table has object dtype, so convert to a float array. `np.array`
    # copies, and the normalisation below is done out-of-place, so the raw
    # importances stored in `comp_import_df` are left unmodified.
    importances = np.array(comp_import_df[(measure, "Importance")], dtype=np.float64)
    importances = importances[::-1]
    importances = importances / np.sum(importances)
    ax.set_title(measure)
    ax.barh(
        range(len(features)),
        importances,
        align="center",
        color=sns.color_palette("husl", len(features), desat=0.5)[::-1],
    )
    ax.set_yticks(range(len(features)))
    ax.set_yticklabels(features)
    ax.set_xlabel(f"Relative {measure} Importance")
    ax.margins(y=0.008, tight=True)
plt.subplots_adjust(wspace=0.45)

## Individual Tree Importances - Gini vs PFI

In [None]:
# Two stacked box plots sharing the x axis (one box per feature): the spread of the
# importances over the individual trees (top) and over the permutation importance
# results (bottom).
fig, (ax, ax2) = plt.subplots(2, 1, sharex=True, figsize=(28, 14))

# Gini values.
# One row per tree of the forest, one column per feature.
ind_trees_gini = pd.DataFrame(
    [tree.feature_importances_ for tree in model_results.model],
    columns=model_results.X_train.columns,
)
# Order the features by their mean importance across trees, largest first.
mean_importances = ind_trees_gini.mean().sort_values(ascending=False)
ind_trees_gini = ind_trees_gini.reindex(mean_importances.index, axis=1)
sns.boxplot(data=ind_trees_gini, ax=ax)
ax.set(title="Gini Importances", ylabel="Gini Importance (MSE)")
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")

# PFI values.
# NOTE(review): assumes `results_` holds one row of per-feature values per
# permutation run, with columns in training column order - confirm against eli5.
pfi_ind = pd.DataFrame(perm_importance.results_, columns=model_results.X_train.columns)

# Re-index according to the same ordering as for the Gini importances!
pfi_ind = pfi_ind.reindex(mean_importances.index, axis=1)

sns.boxplot(data=pfi_ind, ax=ax2)
ax2.set(title="PFI Importances", ylabel="PFI Importance")
# Assign to `_` to suppress the notebook output of the returned tick labels.
_ = ax2.set_xticklabels(ax2.get_xticklabels(), rotation=45, ha="right")

for _ax in (ax, ax2):
    _ax.grid(which="major", alpha=0.3)

## Correlation Plot

In [None]:
from functools import partial

import matplotlib.colors as colors


class MidpointNormalize(colors.Normalize):
    """Piecewise-linear normalisation that maps `midpoint` to the colormap centre.

    Values in `[vmin, midpoint]` are mapped linearly onto `[0, 0.5]` and values in
    `[midpoint, vmax]` onto `[0.5, 1]`. Values outside of `[vmin, vmax]` are mapped
    to 0 and 1 respectively.

    Args:
        *args: Positional arguments forwarded to `matplotlib.colors.Normalize`.
        midpoint (float): Data value that should correspond to the centre of the
            colormap.
        **kwargs: Keyword arguments forwarded to `matplotlib.colors.Normalize`.

    """

    def __init__(self, *args, midpoint=None, **kwargs):
        self.midpoint = midpoint
        super().__init__(*args, **kwargs)

    def __call__(self, value, clip=None):
        """Normalise `value`, preserving the mask of masked input.

        `vmin` and `vmax` have to be set (e.g. by autoscaling) beforehand. The
        `clip` argument is accepted for API compatibility only and is ignored.

        """
        # Simple mapping between the color range halves and the data halves.
        x, y = [self.vmin, self.midpoint, self.vmax], [0, 0.5, 1]
        # `np.interp` returns a plain array, so the mask of masked input has to be
        # carried over explicitly.
        normalised = np.interp(np.ma.getdata(value), x, y)
        return np.ma.masked_array(normalised, mask=np.ma.getmask(value))


def corr_plot(exog_data):
    """Plot the correlation matrix of the columns of `exog_data` as a heatmap.

    The diagonal is masked, and the colormap is centred on a correlation of 0.
    Column names are passed through `map_name` before being used as tick labels.

    Args:
        exog_data (pandas DataFrame): Data whose column correlations are plotted.

    """

    def trim(string, n=10, cont_str="..."):
        """Shorten `string` to at most `n` characters, ending in `cont_str`."""
        if len(string) <= n:
            return string
        return string[: n - len(cont_str)] + cont_str

    columns = [map_name(name) for name in exog_data.columns]
    n = len(columns)

    fig, ax = plt.subplots(figsize=(20, 15))

    # Ignore diagonals, since they will all be 1 anyway!
    corr_values = exog_data.corr().values
    corr_arr = np.ma.MaskedArray(
        corr_values, mask=np.eye(corr_values.shape[0], dtype=bool)
    )

    im = ax.matshow(
        corr_arr,
        interpolation="none",
        cmap="RdYlBu_r",
        norm=MidpointNormalize(midpoint=0.0),
    )
    fig.colorbar(im, pad=0.03, shrink=0.95, aspect=20)

    tick_positions = np.arange(n)
    # Labels along the top are shortened, those along the side are not.
    ax.set_xticks(tick_positions)
    ax.set_xticklabels([trim(label, n=15) for label in columns])
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(columns)

    # Activate ticks on top of axes.
    ax.tick_params(axis="x", bottom=False, top=True, labelbottom=False, labeltop=True)

    # Rotate and align top ticklabels
    top_labels = [tick.label2 for tick in ax.xaxis.get_major_ticks()]
    plt.setp(
        top_labels, rotation=45, ha="left", va="center", rotation_mode="anchor",
    )
    fig.tight_layout()


# Correlations between the predictor columns of the training data.
corr_plot(model_results.X_train)

## PDP Plots 

In [None]:
from pdpbox import info_plots, pdp

# Let the forest use all CPUs available to this job for the predictions below.
model_results.model.n_jobs = get_ncpus()

In [None]:
# Partial dependence of the model on "Dry Day Period", computed on the first 10000
# training rows only.
pdp_dry_day_period = pdp.pdp_isolate(
    model=model_results.model,
    dataset=model_results.X_train[:10000],
    model_features=model_results.X_train.columns,
    feature="Dry Day Period",
)
# Plot without clustering of the individual lines (`cluster=False`).
# NOTE(review): `n_cluster_centers` presumably has no effect while
# `cluster=False` - confirm against the pdpbox documentation.
fig, axes = pdp.pdp_plot(
    pdp_dry_day_period,
    "Dry Day Period",
    plot_lines=True,
    frac_to_plot=1,
    x_quantile=True,
    cluster=False,
    n_cluster_centers=20,
    show_percentile=True,
    plot_pts_dist=True,
)

In [None]:
# Same partial dependence result as in the previous cell, now with the individual
# lines clustered (`cluster=True`) into 40 cluster centres.
fig, axes = pdp.pdp_plot(
    pdp_dry_day_period,
    "Dry Day Period",
    plot_lines=True,
    frac_to_plot=1,
    x_quantile=True,
    cluster=True,
    n_cluster_centers=40,
    show_percentile=True,
    plot_pts_dist=True,
)

In [None]:
# Two-feature partial dependence (interaction) of FAPAR and Dry Day Period on a
# 10x10 grid spanning the 5th to 95th percentiles of both features.
# `model_features` is a required argument of `pdp_interact`, and is passed in the
# same way as for `pdp_isolate` above.
inter_fapar_dry_day = pdp.pdp_interact(
    model=model_results.model,
    dataset=model_results.X_train,
    model_features=model_results.X_train.columns,
    features=["FAPAR", "Dry Day Period"],
    num_grid_points=(10, 10),
    percentile_ranges=[(5, 95), (5, 95)],
)

## Worldwide

In [None]:
def save_ale_plot_1d(model, X_train, column):
    """Create a first-order ALE plot for `column` and save it via `figure_saver`.

    Args:
        model: Fitted model passed on to `ale_plot`.
        X_train: Training data passed on to `ale_plot`.
        column (str): Name of the feature to plot. Also used as the figure name.

    """
    with figure_saver(column, sub_directory="ale"):
        ale_plot(
            model,
            X_train,
            column,
            bins=40,
            monte_carlo=False,  # XXX: !!!
            monte_carlo_rep=100,
            monte_carlo_ratio=0.01,
            plot_quantiles=False,
        )
        ale_ax = plt.gcf().axes[0]
        # Mark the individual points of the most recently added line.
        ale_ax.lines[-1].set_marker(".")

        # Use a log x axis for the features listed in `log_vars`.
        column_lower = column.lower()
        if any(log_var.lower() in column_lower for log_var in log_vars):
            ale_ax.set_xscale(**log_xscale_kwargs)


# Create the 1D ALE plots, using a Dask cluster if one is available and falling
# back to a plain serial loop otherwise.
target_func = save_ale_plot_1d

# Arguments shared by every call of `target_func`.
# NOTE(review): only the first 100 training rows are used here - this looks like a
# debugging leftover, confirm before relying on the resulting plots.
model_params = (model_results.model, model_results.X_train[:100])

with get_parallel_backend(fallback="none") as (backend, client):
    if client is not None:
        print("Using Dask", client)
        # A Dask scheduler was found, so we need to scatter large pieces of data (if any).

        model_params = [client.scatter(param, broadcast=True) for param in model_params]

        def func(param_iter):
            # Transpose the iterable of argument tuples into one sequence per
            # positional argument, as expected by `client.map`, then wait for and
            # collect all results.
            return client.gather(client.map(target_func, *list(zip(*(param_iter)))))

    else:
        print("Not using any backend")

        def func(param_iter):
            # Serial fallback: call the target function for each argument tuple.
            return [target_func(*params) for params in param_iter]

    # NOTE(review): the `if column == "lightning"` filter restricts plotting to a
    # single column - presumably another debugging leftover, confirm.
    func(
        (*model_params, column)
        for column in tqdm(model_results.X_train.columns, desc="ALE plotting")
        if column == "lightning"
    )

### 2D ALE interaction plots

In [None]:
# Fraction of the (leading) training samples used for the 2D ALE plots.
coverage = 0.02


def plot_ale_and_get_importance(columns, model, train_set):
    """Plot the 2D ALE of two features and return its peak-to-peak value.

    Only 'significant' cells are taken into account, i.e. cells containing more
    than a tenth of the samples expected per cell if the samples were distributed
    equally across all cells.

    Args:
        columns: The two feature (column) names passed to `ale_2d`.
        model: Fitted model. Its `n_jobs` attribute is set to the number of
            available CPUs, and its `predict` method is passed to `ale_2d`.
        train_set: Training data passed to `ale_2d`, which uses the leading
            `coverage` fraction of it.

    Returns:
        The difference between the maximum and minimum ALE values over the
        significant cells, or None if this could not be computed (e.g. because
        there are no significant cells).

    """
    model.n_jobs = get_ncpus()
    ale, quantiles_list, samples_grid = ale_2d(
        model.predict, train_set, columns, bins=20, coverage=coverage,
    )
    # `ale_2d` only bins the leading `coverage` fraction of `train_set`, so the
    # expected number of samples per cell has to be based on that subset as well.
    # Otherwise the threshold is too high by a factor of `1 / coverage`.
    n_samples = train_set.shape[0]
    if coverage < 1:
        n_samples = int(n_samples * coverage)
    n_cells = reduce(mul, (len(quantiles) - 1 for quantiles in quantiles_list))
    min_samples = (n_samples / n_cells) / 10

    try:
        significant = ale[samples_grid > min_samples]
        if significant.size == 0:
            return None
        return np.ma.max(significant) - np.ma.min(significant)
    except (ValueError, IndexError):
        # Best effort: a failing selection or reduction for one feature pair must
        # not abort the (long) loop over all pairs. Unlike a bare `except`, this
        # does not swallow KeyboardInterrupt.
        return None

In [None]:
# Create a 2D ALE plot for every unordered pair of features and record the value
# returned by `plot_ale_and_get_importance` (possibly None) for each pair.
ptp_values = {}
columns_list = list(combinations(model_results.X_train.columns, 2))
# NOTE(review): the loop variable `columns` is a module-level name. Verify that no
# function called from within the loop reads it as a global before renaming it.
for columns in tqdm(columns_list, desc="Calculating 2D ALE plots"):
    ptp_values[columns] = plot_ale_and_get_importance(
        columns, model_results.model, model_results.X_train
    )

In [None]:
# Ignore and count None values, then plot a histogram of the ptp values.
filtered_columns_list = []
filtered_ptp_values = []
for columns, ptp in ptp_values.items():
    if ptp is not None:
        filtered_columns_list.append(columns)
        filtered_ptp_values.append(ptp)

np.asarray([ptp for ptp in ptp_values if ptp is not None])
_ = plt.hist(filtered_ptp_values, bins=20)

In [None]:
# Rank the feature pairs by their ptp value (largest first) and show the top 20.
pdp_results = pd.Series(filtered_ptp_values, index=filtered_columns_list).sort_values(
    ascending=False
)
print(pdp_results.head(20))

## Subset the original DataFrame to analyse specific regions only

In [None]:
def subset_dataframe(data, original_mask, additional_mask, suffix=""):
    """Select the rows of `data` that are also unmasked in an additional mask.

    Args:
        data (pandas.core.frame.DataFrame): Data to select from, with one row per
            unmasked (False) element of `original_mask`.
        original_mask (array-like): Mask that was used to transform the gridded
            data into the column representation in `data`. False where data was
            selected.
        additional_mask (array-like): Mask applied in addition to `original_mask`
            after matching its shape to that of `original_mask`. False where data
            should be selected.
        suffix (str): Suffix to add to column labels. A leading space is added if
            `suffix` is non-empty and does not already start with one.

    Returns:
        pandas.core.frame.DataFrame: Selected data.

    """
    additional_mask = match_shape(additional_mask, original_mask.shape)
    if suffix and suffix[0] != " ":
        suffix = " " + suffix

    # Elements present in `data`, and the subset of those that is to be kept.
    unmasked = ~original_mask
    selection = unmasked & (~additional_mask)

    new_data = {}
    for column in tqdm(data.columns, desc="Selecting data"):
        # Scatter the column back onto a blank grid. Only elements in `unmasked`
        # are written, and only (a subset of) those are read back again.
        grid = np.empty_like(original_mask, dtype=np.float64)
        grid[unmasked] = data[column]
        new_data[column + suffix] = grid[selection]
    return pd.DataFrame(new_data)

### SE Asia

In [None]:
# Create new mask.
region_mask = ~box_mask(lats=(-10, 10), lons=(95, 150))
cube_plotting(region_mask)

# XXX: This only allows subsetting the original data (both training and test!) since otherwise the original mask does not apply.

# Apply new mask.
sub_X = subset_dataframe(e_s_exog_data, e_s_master_mask, region_mask, "SE ASIA")
print("Original size:", e_s_exog_data.shape)
print("Selected size:", sub_X.shape)

# Plot ALE plots for only this region.
for column in tqdm(sub_X.columns, desc="Calculating ALE plots"):
    with figure_saver(column, sub_directory="ale_se_asia"):
        ale_plot(
            model_results.model,
            sub_X,
            column,
            bins=40,
            monte_carlo=True,
            monte_carlo_rep=30,
            monte_carlo_ratio=0.1,
            verbose=False,
            log="x"
            if any(
                feature.lower() in column.lower()
                for feature in ("dry day period", "popd",)
            )
            else None,
        )

### Brazilian Amazon

In [None]:
# Create new mask.
region_mask = ~box_mask(lats=(-15, 1), lons=(-72, -46))
cube_plotting(region_mask)

# XXX: This only allows subsetting the original data (both training and test!) since otherwise the original mask does not apply.

# Apply new mask.
sub_X = subset_dataframe(e_s_exog_data, e_s_master_mask, region_mask, "BRAZ AMAZ")
print("Original size:", e_s_exog_data.shape)
print("Selected size:", sub_X.shape)

# Plot ALE plots for only this region.
for column in tqdm(sub_X.columns, desc="Calculating ALE plots"):
    with figure_saver(column, sub_directory="ale_braz_amaz"):
        ale_plot(
            model_results.model,
            sub_X,
            column,
            bins=40,
            monte_carlo=True,
            monte_carlo_rep=30,
            monte_carlo_ratio=0.1,
            verbose=False,
            log="x"
            if any(
                feature.lower() in column.lower()
                for feature in ("dry day period", "popd",)
            )
            else None,
        )

### Europe

In [None]:
# Create new mask.
region_mask = ~box_mask(lats=(33, 73), lons=(-11, 29))
cube_plotting(region_mask)

# XXX: This only allows subsetting the original data (both training and test!) since otherwise the original mask does not apply.

# Apply new mask.
sub_X = subset_dataframe(e_s_exog_data, e_s_master_mask, region_mask, "EUROPE")
print("Original size:", e_s_exog_data.shape)
print("Selected size:", sub_X.shape)

# Plot ALE plots for only this region.
for column in tqdm(sub_X.columns, desc="Calculating ALE plots"):
    with figure_saver(column, sub_directory="ale_europe"):
        ale_plot(
            model_results.model,
            sub_X,
            column,
            bins=40,
            monte_carlo=True,
            monte_carlo_rep=30,
            monte_carlo_ratio=0.1,
            verbose=False,
            log="x"
            if any(
                feature.lower() in column.lower()
                for feature in ("dry day period", "popd",)
            )
            else None,
        )