# ALE plots for the RF model with unrestricted maximum tree depth

## Initialisation

In [None]:
import logging
import os
import re
import sys
import warnings
from collections import namedtuple
from functools import reduce
from itertools import combinations
from operator import mul

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
from joblib import Memory
from loguru import logger as loguru_logger
from matplotlib.patches import Rectangle
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score, train_test_split
from tqdm import tqdm

import wildfires.analysis
from alepython import ale_plot
from alepython.ale import _second_order_ale_quant
from wildfires.analysis import *
from wildfires.dask_cx1 import get_client
from wildfires.data import *
from wildfires.logging_config import enable_logging
from wildfires.qstat import get_ncpus
from wildfires.utils import *

# Enable log messages from the alepython library (via loguru), replacing loguru's
# default handler with one that only shows WARNING and above on stderr.
loguru_logger.enable("alepython")
loguru_logger.remove()
loguru_logger.add(sys.stderr, level="WARNING")

logger = logging.getLogger(__name__)

enable_logging("jupyter")

# Silence known warnings. These messages are regular expressions matched against
# the start of the warning text, hence the leading and trailing `.*` (a bare
# trailing `*` would only repeat the preceding character).
warnings.filterwarnings("ignore", ".*Collapsing a non-contiguous coordinate.*")
warnings.filterwarnings("ignore", ".*DEFAULT_SPHERICAL_EARTH_RADIUS.*")
warnings.filterwarnings("ignore", ".*guessing contiguous bounds.*")

normal_coast_linewidth = 0.5
mpl.rc("figure", figsize=(14, 6))
mpl.rc("font", size=9.0)

# NOTE(review): the figure directory and the joblib cache are both named
# "analysis_multiple_time_lags_model". If another notebook uses the same names,
# saved figures and cached results may overwrite each other - confirm this is
# intended for the unrestricted max-depth model.
figure_saver = FigureSaver(
    directories=os.path.join("~", "tmp", "analysis_multiple_time_lags_model"),
    debug=True,
)
memory = get_memory("analysis_multiple_time_lags_model", verbose=100)

In [None]:
# Symmetric-log axis settings, applied to features whose names contain one of the
# entries of `log_vars`. Values with magnitude below `linthres` are shown linearly.
# NOTE(review): `linthreshx`/`subsx` and `linthreshy`/`subsy` are the keyword names
# used by matplotlib before 3.3 (later renamed to `linthresh`/`subs`); confirm
# they are accepted by the installed matplotlib version.
value = "symlog"
linthres = 1e-2
subs = list(range(2, 10))
log_xscale_kwargs = {"value": value, "linthreshx": linthres, "subsx": subs}
log_yscale_kwargs = {"value": value, "linthreshy": linthres, "subsy": subs}

# Lower-case name fragments of features plotted on a symlog axis.
log_vars = tuple(
    [
        "dry day period",
        "popd",
        "agb tree",
        "cape x precip",
        "lai",
        "shruball",
        "pftherb",
        "pftcrop",
        "treeall",
    ]
)

In [None]:
def ale_2d(predictor, train_set, features, bins=40, coverage=1):
    """Plot, save and return the second-order ALE of two features.

    The ALE is computed by `_second_order_ale_quant` on quantile bins. It is then
    linearly interpolated between the bin centres for a filled contour plot.
    Cells with enough samples are hatched, and cells without samples are boxed.
    The figure is saved via `figure_saver` in the "2d_ale" sub-directory, under
    the feature names joined by "__".

    Args:
        predictor (callable): Prediction function, e.g. `model.predict`.
        train_set (pandas.DataFrame): Data used to compute the ALE.
        features (sequence of str): The two column names to analyse.
        bins (int): Number of bins, passed to `_second_order_ale_quant`.
        coverage (float): Fraction of the rows of `train_set` to use. Values below
            1 select the first rows only. This is representative only if
            `train_set` is already shuffled.

    Returns:
        tuple: `(ale, quantiles_list, samples_grid)` as returned by
        `_second_order_ale_quant`.

    """
    # Imported explicitly: `import scipy` alone does not guarantee that the
    # `scipy.interpolate` sub-package has been loaded.
    from scipy.interpolate import RegularGridInterpolator

    if coverage < 1:
        # This should be ok if `train_set` is randomised, as it usually is.
        train_set = train_set[: int(train_set.shape[0] * coverage)]
    ale, quantiles_list, samples_grid = _second_order_ale_quant(
        predictor, train_set, features, bins=bins, return_samples_grid=True
    )
    fig, ax = plt.subplots()
    centres_list = [get_centres(quantiles) for quantiles in quantiles_list]
    n_x, n_y = 50, 50
    x = np.linspace(centres_list[0][0], centres_list[0][-1], n_x)
    y = np.linspace(centres_list[1][0], centres_list[1][-1], n_y)

    X, Y = np.meshgrid(x, y, indexing="xy")
    # Bilinear interpolation of the ALE (defined at the bin centres, with shape
    # (len(centres_list[0]), len(centres_list[1]))) onto the (n_y, n_x) plotting
    # grid. This replaces `scipy.interpolate.interp2d`, which is deprecated and
    # was removed in SciPy 1.14. As with interp2d, the mask of `ale` is not used
    # here; cells without samples are marked with boxes below instead.
    ale_interp = RegularGridInterpolator(
        (centres_list[0], centres_list[1]),
        np.ma.getdata(ale),
        method="linear",
        bounds_error=False,
        fill_value=None,
    )
    ale_grid = ale_interp(np.stack((X, Y), axis=-1))
    CF = ax.contourf(X, Y, ale_grid, cmap="bwr", levels=30, alpha=0.7)

    # Do not autoscale, so that boxes at the edges (contourf only plots the bin
    # centres, not their edges) don't enlarge the plot. Such boxes include markings for
    # invalid cells, or hatched boxes for valid cells.
    ax.autoscale(False)

    def bin_rectangle(i, j, **kwargs):
        """Return a Rectangle covering bin `i` of feature 0 and bin `j` of feature 1."""
        x0, x1 = quantiles_list[0][i], quantiles_list[0][i + 1]
        y0, y1 = quantiles_list[1][j], quantiles_list[1][j + 1]
        return Rectangle([x0, y0], x1 - x0, y1 - y0, **kwargs)

    # Add hatching for the significant cells. These have at least `min_samples` samples.
    # This is the number of samples in each bin if everything was equally
    # distributed, divided by 10.
    min_samples = (train_set.shape[0] / reduce(mul, map(len, centres_list))) / 10
    for i, j in zip(*np.where(samples_grid >= min_samples)):
        ax.add_patch(
            bin_rectangle(i, j, linewidth=0, fill=None, hatch=".", alpha=0.4)
        )

    if np.any(ale.mask):
        # Add rectangles to indicate cells without samples.
        for i, j in zip(*np.where(ale.mask)):
            ax.add_patch(
                bin_rectangle(
                    i, j, linewidth=1, edgecolor="k", facecolor="none", alpha=0.4
                )
            )
    fig.colorbar(CF, format="%.0e")
    ax.set_xlabel(features[0])
    ax.set_ylabel(features[1])
    nbins_str = "x".join([str(len(centres)) for centres in centres_list])
    ax.set_title(
        f"Second-order ALE of features {features[0]} and {features[1]}\n"
        f"Bins: {nbins_str} (Hatching: Sig., Boxes: Invalid)"
    )

    if any(log_var.lower() in features[0].lower() for log_var in log_vars):
        ax.set_xscale(**log_xscale_kwargs)
    if any(log_var.lower() in features[1].lower() for log_var in log_vars):
        ax.set_yscale(**log_yscale_kwargs)
    figure_saver.save_figure(fig, "__".join(features), sub_directory="2d_ale")
    return ale, quantiles_list, samples_grid

## Creating the Data Structures used for Fitting

In [None]:
shift_months = [1, 3, 6, 9, 12, 18, 24]

# Alternative, hand-picked variable selections, kept for reference. The call
# below passes `selection_variables=None` instead (presumably meaning that all
# available variables are used - confirm in `get_data`).
# selection_variables = (
#     "VOD Ku-band -3 Month",
#     # "SIF",  # Fix regridding!!
#     "VOD Ku-band -1 Month",
#     "Dry Day Period -3 Month",
#     "FAPAR",
#     "pftHerb",
#     "LAI -1 Month",
#     "popd",
#     "Dry Day Period -24 Month",
#     "pftCrop",
#     "FAPAR -1 Month",
#     "FAPAR -24 Month",
#     "Max Temp",
#     "Dry Day Period -6 Month",
#     "VOD Ku-band -6 Month",
# )

# ext_selection_variables = selection_variables + (
#     "Dry Day Period -1 Month",
#     "FAPAR -6 Month",
#     "ShrubAll",
#     "SWI(1)",
#     "TreeAll",
# )

# ipdb is only a debugging aid: it opens a post-mortem debugger if data loading
# fails. Fall back to a no-op context manager when it is not installed, rather
# than making it a hard requirement for running this notebook.
try:
    from ipdb import launch_ipdb_on_exception
except ImportError:
    from contextlib import nullcontext as launch_ipdb_on_exception

with launch_ipdb_on_exception():
    (
        e_s_endog_data,
        e_s_exog_data,
        e_s_master_mask,
        e_s_filled_datasets,
        e_s_masked_datasets,
        e_s_land_mask,
    ) = wildfires.analysis.time_lags.get_data(
        shift_months=shift_months, selection_variables=None
    )

### Offset data from 12 or more months before the current month to ease analysis.
We are interested in the trends in these properties, not their absolute values, so we subtract a recent 'seasonal cycle' analogue: the same variable at the same point in the annual cycle within the last 12 months (e.g. the -18 Month value minus the -6 Month value, or the -12 Month value minus the current value).
This hopefully avoids capturing the same relationships for a variable and its 12-month-lagged counterpart, which are highly correlated.

In [None]:
# Replace every lagged column "<name> -<n> <rest>" whose lag is 12 months or more
# by its difference from the same variable at the same point in the annual cycle
# within the last year. For example, "X -18 Month" becomes
# "X -18 - -6 Month" = "X -18 Month" - "X -6 Month", and "X -12 Month" becomes
# "X -12 - 0 Month" = "X -12 Month" - "X".
# NOTE: not idempotent - re-running this cell reprocesses the new columns.
to_delete = []
# Iterate over a snapshot of the labels, since new columns are added in the loop.
for column in list(e_s_exog_data.columns):
    match = re.search(r"-\d{1,2}", column)
    if not match:
        continue
    start, end = match.span()
    original_offset = int(column[start:end])
    if original_offset > -12:
        # Only shift months that are 12 or more months before the current month.
        continue
    # Lag within the last year at the same point in the annual cycle (0 or negative).
    comp = -(-original_offset % 12)
    # Text before/after the lag, without the separating spaces.
    prefix = column[: start - 1]
    suffix_text = column[end + 1 :]
    new_column = f"{prefix} {original_offset} - {comp} {suffix_text}"
    comp_column = prefix if comp == 0 else f"{prefix} {comp} {suffix_text}"
    print(column, comp_column)
    e_s_exog_data[new_column] = e_s_exog_data[column] - e_s_exog_data[comp_column]
    to_delete.append(column)
for column in to_delete:
    del e_s_exog_data[column]

## Cached Model Fitting
If anything about the data above changes, the cache must be cleared manually using `memory.clear()`! The cached function takes no arguments, so the cache is not invalidated by changes to the data.

In [None]:
ModelResults = namedtuple(
    "ModelResults",
    ["X_train", "X_test", "y_train", "y_test", "r2_test", "r2_train", "model"],
)


@memory.cache()
def get_time_lags_model():
    """Fit (or load from the joblib cache) a RF model on the extended shifted data.

    The data are split 70/30 into training and test sets with a fixed random
    state. A 200-tree random forest with unrestricted depth is then fitted.
    The cache does not depend on the data, because the data are read from globals.

    Returns:
        ModelResults: A namedtuple with the fields 'X_train', 'X_test', 'y_train',
        'y_test', 'r2_test', 'r2_train', and 'model'.

    """
    X_train, X_test, y_train, y_test = train_test_split(
        e_s_exog_data, e_s_endog_data, test_size=0.3, shuffle=True, random_state=1
    )
    forest = RandomForestRegressor(
        n_estimators=200, max_depth=None, n_jobs=get_ncpus(), random_state=1
    )
    forest.fit(X_train, y_train)
    return ModelResults(
        X_train=X_train,
        X_test=X_test,
        y_train=y_train,
        y_test=y_test,
        r2_test=forest.score(X_test, y_test),
        r2_train=forest.score(X_train, y_train),
        model=forest,
    )


model_results = get_time_lags_model()
# A cached model keeps the `n_jobs` it was fitted with; use all cores of this job.
model_results.model.n_jobs = get_ncpus()

In [None]:
# Impurity-based importances of the forest (scikit-learn's `feature_importances_`),
# with their standard deviation across the individual trees.
importances = model_results.model.feature_importances_
std = np.std(
    np.array(
        [estimator.feature_importances_ for estimator in model_results.model.estimators_]
    ),
    axis=0,
)

importances_df = pd.DataFrame(
    {
        "Name": model_results.X_train.columns.values,
        "Importance": importances,
        "Importance STD": std,
        # Relative spread; NaN (with a RuntimeWarning) for zero-importance features.
        "Ratio": np.asarray(std) / np.asarray(importances),
    }
)
table = importances_df.sort_values("Importance", ascending=False).to_string(
    index=False, float_format="{:0.3f}".format, line_width=200
)
print(f"\n{table}")

## Worldwide

### 2D ALE interaction plots

In [None]:
from dask.distributed import Client

# Connect to an already-running dask scheduler. The default address is specific to
# one session; set DASK_SCHEDULER_ADDRESS to use another one without editing code.
client = Client(os.environ.get("DASK_SCHEDULER_ADDRESS", "localhost:35494"))

In [None]:
# Send the fitted model and its training data to every worker up front. The
# returned futures are then passed to the ALE tasks instead of the objects.
model_f, model_X = (
    client.scatter(obj, broadcast=True)
    for obj in (model_results.model, model_results.X_train)
)

In [None]:
coverage = 0.5


def plot_ale_and_get_importance(columns, model, train_set):
    """Plot the 2D ALE of a pair of columns and return its range as an importance.

    This is intended to run as a dask task. Logging and matplotlib settings are
    therefore re-applied in the worker process.

    Args:
        columns (tuple of str): The two feature names.
        model: Fitted model providing `predict`; its `n_jobs` is set to all cores.
        train_set (pandas.DataFrame): Data used to compute the ALE. Only the first
            `coverage` fraction of rows is used, as in `ale_2d`.

    Returns:
        float or None: Range (max - min) of the second-order ALE over the cells
        with at least `min_samples` samples. These are the same cells that
        `ale_2d` hatches as significant. Returns None if no such unmasked cell
        exists or if the calculation failed; failures are logged.

    """
    logger = logging.getLogger(__name__)
    enable_logging("jupyter")
    mpl.rc("figure", figsize=(14, 6))
    mpl.rc("font", size=9.0)
    # Figures created here are only saved (by `ale_2d`), never shown; close them
    # afterwards so they do not accumulate over many tasks in the worker.
    existing_figures = set(plt.get_fignums())
    try:
        model.n_jobs = get_ncpus()
        ale, quantiles_list, samples_grid = ale_2d(
            model.predict, train_set, columns, bins=20, coverage=coverage,
        )
        # `ale_2d` only counts samples from the first `coverage` fraction of the
        # rows, so the expected-samples-per-bin threshold must use that count too.
        n_samples = train_set.shape[0]
        if coverage < 1:
            n_samples = int(n_samples * coverage)
        n_bins = reduce(mul, map(lambda quantiles: len(quantiles) - 1, quantiles_list))
        min_samples = (n_samples / n_bins) / 10
        significant = ale[samples_grid >= min_samples]
        if np.ma.count(significant) == 0:
            logger.warning(
                "No valid 2D ALE cells with enough samples for columns %s.", columns
            )
            return None
        return np.ma.max(significant) - np.ma.min(significant)
    except Exception:
        logger.exception(
            f"Something went wrong with 2D ALE plotting for columns {columns}."
        )
        return None
    finally:
        for number in set(plt.get_fignums()) - existing_figures:
            plt.close(number)

In [None]:
# Compute the 2D ALE plot and its ptp value for every pair of features on the
# dask cluster. `gather` returns the results in the order of `columns_list`.
columns_list = list(combinations(model_results.X_train.columns, 2))
ptp_values = client.gather(
    client.map(
        plot_ale_and_get_importance,
        columns_list,
        [model_f] * len(columns_list),
        [model_X] * len(columns_list),
    )
)

In [None]:
# Separate failed pairs (None, or a fully masked result) from valid ones, report
# the failures, then plot a histogram of the valid ptp (max - min) ALE values.
filtered_columns_list = []
filtered_ptp_values = []
invalid_columns = []
for columns, ptp in zip(columns_list, ptp_values):
    if ptp is None or ptp is np.ma.masked:
        invalid_columns.append(columns)
    else:
        filtered_columns_list.append(columns)
        filtered_ptp_values.append(ptp)

print("Nr invalid:", len(invalid_columns))
print(pd.Series(invalid_columns))

_ = plt.hist(filtered_ptp_values, bins=20)

In [None]:
# Rank feature pairs by the range of their second-order ALE, largest first.
pdp_results = pd.Series(filtered_ptp_values, index=filtered_columns_list).sort_values(
    ascending=False
)
print(pdp_results.head(20))

In [None]:
# First-order ALE plot of every feature over the full training set, saved in the
# "ale" sub-directory. Features matching `log_vars` get a symlog x-axis.
for column in tqdm(model_results.X_train.columns, desc="Calculating ALE plots"):
    use_log_x = any(feature.lower() in column.lower() for feature in log_vars)
    with figure_saver(column, sub_directory="ale"):
        ale_plot(
            model_results.model,
            model_results.X_train,
            column,
            bins=40,
            monte_carlo=True,
            monte_carlo_rep=100,
            monte_carlo_ratio=0.01,
            plot_quantiles=False,
        )
        axis = plt.gcf().axes[0]
        # Add point markers to the most recently drawn line (presumably the main
        # ALE curve - confirm against `ale_plot`).
        axis.lines[-1].set_marker(".")
        if use_log_x:
            axis.set_xscale(**log_xscale_kwargs)
        if any(feature.lower() in column.lower() for feature in log_vars):
            plt.gcf().axes[0].set_xscale(**log_xscale_kwargs)

## Subset the original DataFrame to analyse specific regions only

In [None]:
def subset_dataframe(data, original_mask, additional_mask, suffix=""):
    """Sub-set results based on an additional mask.

    Args:
        data (pandas.core.frame.DataFrame): Data to select. Each row corresponds to
            one unmasked (False) cell of `original_mask`, in C (row-major) order.
        original_mask (array-like): Original mask that was used to transform the
            data into the column representation in `data`. This mask should be
            False where data should be selected.
        additional_mask (array-like): After conversion of columns in `data` back
            to a lat-lon grid, this mask is used in addition to `original_mask` to
            return a subset of the data to the column format. This mask should be
            False where data should be selected. It is first brought to the shape
            of `original_mask` using `match_shape`.
        suffix (str): Suffix to add to column labels. An empty space will be added
            to the beginning of `suffix` if this is not already present.

    Returns:
        pandas.core.frame.DataFrame: Selected data (as float64) with a fresh
        `RangeIndex` and columns labelled `column + suffix`.

    Raises:
        ValueError: If the number of rows in `data` does not match the number of
            unmasked cells in `original_mask`.

    """
    additional_mask = match_shape(additional_mask, original_mask.shape)
    if suffix and not suffix.startswith(" "):
        suffix = " " + suffix

    valid = ~original_mask
    n_valid = int(np.sum(valid))
    if data.shape[0] != n_valid:
        raise ValueError(
            f"'data' has {data.shape[0]} rows, but 'original_mask' has {n_valid} "
            "unmasked cells."
        )
    # Boolean indexing visits grid cells in C order, which is the order of the rows
    # of `data`. Reading `additional_mask` at the valid cells therefore gives, for
    # every row, whether it is also selected. This avoids rebuilding a full
    # lat-lon grid for every column.
    row_selection = ~np.asarray(additional_mask[valid], dtype=bool)
    selected = data.to_numpy(dtype=np.float64)[row_selection]
    return pd.DataFrame(selected, columns=[column + suffix for column in data.columns])

### SE Asia

In [None]:
# Region: SE Asia, latitudes -10 to 10 and longitudes 95 to 150. The inverted box
# mask is presumably False inside the box, matching the "False where data should
# be selected" convention of `subset_dataframe`.
region_mask = ~box_mask(lats=(-10, 10), lons=(95, 150))
cube_plotting(region_mask)

# XXX: This only allows subsetting the original data (both training and test!)
# since otherwise the original mask does not apply.
sub_X = subset_dataframe(e_s_exog_data, e_s_master_mask, region_mask, "SE ASIA")
print("Original size:", e_s_exog_data.shape)
print("Selected size:", sub_X.shape)

# ALE plots for this region only.
# NOTE(review): the columns of `sub_X` carry the " SE ASIA" suffix and so differ
# from the feature names the model was fitted with; recent scikit-learn versions
# may warn or raise about this in `predict` - confirm with the installed version.
log_x_features = ("dry day period", "popd")
for column in tqdm(sub_X.columns, desc="Calculating ALE plots"):
    use_log_x = any(feature in column.lower() for feature in log_x_features)
    with figure_saver(column, sub_directory="ale_se_asia"):
        ale_plot(
            model_results.model,
            sub_X,
            column,
            bins=40,
            monte_carlo=True,
            monte_carlo_rep=30,
            monte_carlo_ratio=0.1,
            verbose=False,
            log="x" if use_log_x else None,
        )

### Brazilian Amazon

In [None]:
# Region: Brazilian Amazon, latitudes -15 to 1 and longitudes -72 to -46. The
# inverted box mask is presumably False inside the box, matching the "False where
# data should be selected" convention of `subset_dataframe`.
region_mask = ~box_mask(lats=(-15, 1), lons=(-72, -46))
cube_plotting(region_mask)

# XXX: This only allows subsetting the original data (both training and test!)
# since otherwise the original mask does not apply.
sub_X = subset_dataframe(e_s_exog_data, e_s_master_mask, region_mask, "BRAZ AMAZ")
print("Original size:", e_s_exog_data.shape)
print("Selected size:", sub_X.shape)

# ALE plots for this region only.
# NOTE(review): the columns of `sub_X` carry the " BRAZ AMAZ" suffix and so differ
# from the feature names the model was fitted with; recent scikit-learn versions
# may warn or raise about this in `predict` - confirm with the installed version.
log_x_features = ("dry day period", "popd")
for column in tqdm(sub_X.columns, desc="Calculating ALE plots"):
    use_log_x = any(feature in column.lower() for feature in log_x_features)
    with figure_saver(column, sub_directory="ale_braz_amaz"):
        ale_plot(
            model_results.model,
            sub_X,
            column,
            bins=40,
            monte_carlo=True,
            monte_carlo_rep=30,
            monte_carlo_ratio=0.1,
            verbose=False,
            log="x" if use_log_x else None,
        )

### Europe

In [None]:
# Region: Europe, latitudes 33 to 73 and longitudes -11 to 29. The inverted box
# mask is presumably False inside the box, matching the "False where data should
# be selected" convention of `subset_dataframe`.
region_mask = ~box_mask(lats=(33, 73), lons=(-11, 29))
cube_plotting(region_mask)

# XXX: This only allows subsetting the original data (both training and test!)
# since otherwise the original mask does not apply.
sub_X = subset_dataframe(e_s_exog_data, e_s_master_mask, region_mask, "EUROPE")
print("Original size:", e_s_exog_data.shape)
print("Selected size:", sub_X.shape)

# ALE plots for this region only.
# NOTE(review): the columns of `sub_X` carry the " EUROPE" suffix and so differ
# from the feature names the model was fitted with; recent scikit-learn versions
# may warn or raise about this in `predict` - confirm with the installed version.
log_x_features = ("dry day period", "popd")
for column in tqdm(sub_X.columns, desc="Calculating ALE plots"):
    use_log_x = any(feature in column.lower() for feature in log_x_features)
    with figure_saver(column, sub_directory="ale_europe"):
        ale_plot(
            model_results.model,
            sub_X,
            column,
            bins=40,
            monte_carlo=True,
            monte_carlo_rep=30,
            monte_carlo_ratio=0.1,
            verbose=False,
            log="x" if use_log_x else None,
        )