## Initialisation

In [None]:
import logging
import os
import re
import sys
import warnings
from collections import namedtuple
from functools import reduce
from itertools import combinations
from operator import mul

import cloudpickle
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import shap
from joblib import Memory, Parallel, delayed
from loguru import logger as loguru_logger
from matplotlib.patches import Rectangle
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score, train_test_split
from tqdm import tqdm

import wildfires.analysis
from alepython import ale_plot
from alepython.ale import _second_order_ale_quant
from wildfires.analysis import *
from wildfires.dask_cx1 import get_parallel_backend
from wildfires.data import *
from wildfires.logging_config import enable_logging
from wildfires.qstat import get_ncpus
from wildfires.utils import *

# Route loguru output: enable messages from the `alepython` package, call
# `remove()` without a handler id to drop previously registered loguru
# handlers, then register a single stderr handler at WARNING level.
loguru_logger.enable("alepython")
loguru_logger.remove()
loguru_logger.add(sys.stderr, level="WARNING")

# Standard-library logger for this notebook.
logger = logging.getLogger(__name__)

# Project logging configuration from `wildfires.logging_config`, in its
# "jupyter" mode.
enable_logging("jupyter")

# Ignore known, noisy warnings whose message matches these regular expressions.
# NOTE(review): in the last two patterns the trailing `*` only quantifies the
# final character; `.*` was presumably intended - confirm.
warnings.filterwarnings("ignore", ".*Collapsing a non-contiguous coordinate.*")
warnings.filterwarnings("ignore", ".*DEFAULT_SPHERICAL_EARTH_RADIUS*")
warnings.filterwarnings("ignore", ".*guessing contiguous bounds*")

# Plot defaults. `normal_coast_linewidth` is not referenced by any cell visible
# in this notebook.
normal_coast_linewidth = 0.5
mpl.rc("figure", figsize=(14, 6))
mpl.rc("font", size=9.0)

# Name shared by the figure output directory and the cache locations below.
save_name = "analysis_time_lags_xgboost"

# NOTE(review): "~" is passed literally here; presumably `FigureSaver` expands
# the user directory itself - TODO confirm.
figure_saver = FigureSaver(directories=os.path.join("~", "tmp", save_name), debug=True,)
memory = get_memory(save_name, verbose=100)
# Directory handed to `SimpleCache` when the model cache is created further down.
CACHE_DIR = os.path.join(DATA_DIR, ".pickle", save_name)

## Overwrite wildfires get_data with our own personalised version

In [None]:
from get_time_lag_data import get_data

In [None]:
# Keyword arguments for `set_xscale` / `set_yscale`: a symmetric-log scale with
# a linear threshold of 0.01 and minor ticks at the multiples 2 to 9.
value = "symlog"
linthres = 1e-2
subs = list(range(2, 10))
log_xscale_kwargs = {"value": value, "linthreshx": linthres, "subsx": subs}
log_yscale_kwargs = {"value": value, "linthreshy": linthres, "subsy": subs}

# Name fragments compared case-insensitively against feature names; a matching
# feature gets the symlog scale on its axis.
log_vars = (
    "dry day period", "popd", "agb tree",
    "cape x precip", "lai", "shruball",
    "pftherb", "pftcrop", "treeall",
)

In [None]:
def ale_2d(predictor, train_set, features, bins=40, coverage=1):
    """Plot, save and return the second-order ALE of a pair of features.

    The ALE is computed by `alepython`'s `_second_order_ale_quant`, interpolated
    onto a regular 50x50 grid spanning the bin centres and drawn as filled
    contours. Cells holding enough samples are hatched and masked cells are
    outlined. The figure is saved via `figure_saver` under the name
    `"__".join(features)` in the "2d_ale_low" sub-directory.

    Args:
        predictor: Prediction callable forwarded to `_second_order_ale_quant`.
        train_set: Data forwarded to `_second_order_ale_quant`. Must support
            `.shape` and row slicing (presumably a pandas DataFrame - TODO
            confirm).
        features: The two feature names. Also used for the axis labels, the
            title and the name of the saved figure.
        bins (int): Number of bins forwarded to `_second_order_ale_quant`.
        coverage (float): If below 1, only the leading `coverage` fraction of
            the rows of `train_set` is used.

    Returns:
        tuple: `(ale, quantiles_list, samples_grid)` as returned by
        `_second_order_ale_quant`.

    """
    if coverage < 1:
        # This should be ok if `train_set` is randomised, as it usually is.
        train_set = train_set[: int(train_set.shape[0] * coverage)]
    ale, quantiles_list, samples_grid = _second_order_ale_quant(
        predictor, train_set, features, bins=bins, return_samples_grid=True
    )
    fig, ax = plt.subplots()
    centres_list = [get_centres(quantiles) for quantiles in quantiles_list]
    n_x, n_y = 50, 50
    x = np.linspace(centres_list[0][0], centres_list[0][-1], n_x)
    y = np.linspace(centres_list[1][0], centres_list[1][-1], n_y)

    X, Y = np.meshgrid(x, y, indexing="xy")
    # NOTE(review): `interp2d` is deprecated in recent SciPy releases (and
    # removed in 1.14) - confirm the SciPy version this notebook is pinned to.
    ale_interp = scipy.interpolate.interp2d(centres_list[0], centres_list[1], ale.T)
    CF = ax.contourf(X, Y, ale_interp(x, y), cmap="bwr", levels=30, alpha=0.7)

    # Do not autoscale, so that boxes at the edges (contourf only plots the bin
    # centres, not their edges) don't enlarge the plot. Such boxes include markings for
    # invalid cells, or hatched boxes for valid cells.
    plt.autoscale(False)

    def add_cell_patches(selected_cells, **patch_kwargs):
        """Draw one rectangle over every bin (i, j) flagged in `selected_cells`."""
        x_edges, y_edges = quantiles_list[0], quantiles_list[1]
        for i, j in zip(*np.where(selected_cells)):
            ax.add_patch(
                Rectangle(
                    [x_edges[i], y_edges[j]],
                    x_edges[i + 1] - x_edges[i],
                    y_edges[j + 1] - y_edges[j],
                    **patch_kwargs,
                )
            )

    # Add hatching for the significant cells. These have at least `min_samples`
    # samples. This is calculated as the number of samples in each bin if
    # everything was equally distributed, divided by 10.
    min_samples = (train_set.shape[0] / reduce(mul, map(len, centres_list))) / 10
    add_cell_patches(
        samples_grid >= min_samples, linewidth=0, fill=None, hatch=".", alpha=0.4
    )

    if np.any(ale.mask):
        # Add rectangles to indicate cells without samples.
        add_cell_patches(
            ale.mask, linewidth=1, edgecolor="k", facecolor="none", alpha=0.4
        )

    fig.colorbar(CF, format="%.0e")
    ax.set_xlabel(features[0])
    ax.set_ylabel(features[1])
    nbins_str = "x".join([str(len(centres)) for centres in centres_list])
    ax.set_title(
        f"Second-order ALE of features {features[0]} and {features[1]}\n"
        f"Bins: {nbins_str} (Hatching: Sig., Boxes: Invalid)"
    )

    # Use a symlog axis for features whose name contains one of `log_vars`.
    if any(log_var.lower() in features[0].lower() for log_var in log_vars):
        ax.set_xscale(**log_xscale_kwargs)
    if any(log_var.lower() in features[1].lower() for log_var in log_vars):
        ax.set_yscale(**log_yscale_kwargs)

    # Name the file after the plotted features. This previously joined the
    # undefined name `columns`, which only worked if the caller happened to
    # define a global of that name.
    figure_saver.save_figure(fig, "__".join(features), sub_directory="2d_ale_low")
    return ale, quantiles_list, samples_grid

## Creating the Data Structures used for Fitting

In [None]:
# Temporal shifts requested from `get_data` (presumably in months, matching the
# "-N Month" suffixes in the column names handled further down - TODO confirm).
shift_months = [1, 3, 6, 9, 12, 18, 24]

# Earlier, currently disabled variable selections, kept for reference.
# selection_variables = (
#     "VOD Ku-band -3 Month",
#     # "SIF",  # Fix regridding!!
#     "VOD Ku-band -1 Month",
#     "Dry Day Period -3 Month",
#     "FAPAR",
#     "pftHerb",
#     "LAI -1 Month",
#     "popd",
#     "Dry Day Period -24 Month",
#     "pftCrop",
#     "FAPAR -1 Month",
#     "FAPAR -24 Month",
#     "Max Temp",
#     "Dry Day Period -6 Month",
#     "VOD Ku-band -6 Month",
# )

# ext_selection_variables = selection_variables + (
#     "Dry Day Period -1 Month",
#     "FAPAR -6 Month",
#     "ShrubAll",
#     "SWI(1)",
#     "TreeAll",
# )
from ipdb import launch_ipdb_on_exception

# Debugging aid: open an ipdb session if `get_data` raises.
with launch_ipdb_on_exception():
    # `get_data` is the local version imported from `get_time_lag_data`.
    # `selection_variables=None` presumably disables variable selection - TODO
    # confirm against `get_time_lag_data.get_data`.
    (
        e_s_endog_data,
        e_s_exog_data,
        e_s_master_mask,
        e_s_filled_datasets,
        e_s_masked_datasets,
        e_s_land_mask,
    ) = get_data(shift_months=shift_months, selection_variables=None)

### Offset data from 12 or more months before the current month in order to ease analysis (interpretability).
We are interested in the trends in these properties, not their absolute values, therefore we subtract a recent 'seasonal cycle' analogue.
This hopefully avoids capturing the same relationships for a variable and its 12 month counterpart due to their high correlation.

In [None]:
# Replace every column lagged by 12 or more months with its difference from the
# column at the same lag modulo 12 (the unshifted column when that is 0). The
# original long-lag columns are removed afterwards.
to_delete = []
for column in e_s_exog_data:
    match = re.search(r"-\d{1,2}", column)
    if match is None:
        # No "-N" offset in the name: not a shifted column.
        continue

    start, end = match.span()
    original_offset = int(match.group())
    if original_offset > -12:
        # Only shift months that are 12 or more months before the current month.
        continue

    # Offset of the comparison column, e.g. -18 -> -6 and -12 / -24 -> 0.
    comp = -(-original_offset % 12)

    # Text around the offset, without the spaces that separate it from the offset.
    prefix = column[: start - 1]
    suffix = column[end + 1 :]

    # Change the string to reflect the shift.
    new_column = f"{prefix} {original_offset} - {comp} {suffix}"
    comp_column = prefix if comp == 0 else f"{prefix} {comp} {suffix}"

    print(column, comp_column)
    e_s_exog_data[new_column] = e_s_exog_data[column] - e_s_exog_data[comp_column]
    to_delete.append(column)

for column in to_delete:
    del e_s_exog_data[column]

## Cached Model Fitting
If anything regarding the data above changes, the cached model is stale and has to be refreshed by clearing the cache that wraps the fitting function (`model_cache.clear()`).

In [None]:
from dask.distributed import *

# Create a Dask client with a single worker that uses `get_ncpus()` threads.
client = Client(n_workers=1, threads_per_worker=get_ncpus())
# Bare expression so that the notebook displays the client as the cell output.
client

In [None]:
import dask_xgboost
import xgboost as xgb
from dask import array as darray
from dask import dataframe as dd
from dask_ml.xgboost import XGBRegressor
from sklearn.metrics import mean_squared_error, r2_score

from wildfires.dask_cx1 import *

# Bundle of the train/test split, the scores and the fitted model, in the field
# order given here.
ModelResults = namedtuple(
    "ModelResults",
    "X_train X_test y_train y_test r2_test r2_train model mse_train mse_test",
)

# Cache used to decorate `get_time_lags_model`; it is given `CACHE_DIR` as its
# cache directory.
model_cache = SimpleCache("xgboost_model", cache_dir=CACHE_DIR)

# XXX: The cache is cleared every time this cell runs, so the decorated
# function is presumably always re-executed rather than loaded from the cache.
# NOTE(review): remove this call once the model setup is settled.
model_cache.clear()


def get_darray(numpy_array, chunk_nrows):
    """Wrap a 1D or 2D array in a dask array that is chunked along its rows.

    Args:
        numpy_array: Array with shape (m,) or (m, n).
        chunk_nrows (int): Number of rows per chunk. For 2D input, the second
            axis is chunked with -1.

    Returns:
        The result of `darray.from_array` for the given array and chunks.

    Raises:
        ValueError: If the array is neither 1D nor 2D.

    """
    n_dims = len(numpy_array.shape)
    if n_dims not in (1, 2):
        raise ValueError("Expected (m,) or (m, n) array.")
    chunks = (chunk_nrows,) if n_dims == 1 else (chunk_nrows, -1)
    return darray.from_array(numpy_array, chunks=chunks)


@model_cache
def get_time_lags_model():
    """Get an XGBoost model trained on the extended shifted data.

    Only the first 10000 rows of `e_s_exog_data` and `e_s_endog_data` are used.
    They are split 70/30 into train and test sets, converted to row-chunked dask
    arrays and used to fit an `xgb.dask.DaskXGBRegressor` with early stopping.

    Note:
        The test split is also the evaluation set used for early stopping, so
        the reported test scores do not come from fully held-out data.

    Returns:
        ModelResults: A namedtuple with the fields 'X_train', 'X_test', 'y_train',
        'y_test', 'r2_test', 'r2_train', 'model' (the fitted booster),
        'mse_train', and 'mse_test'.

    """
    # Restrict the data to a leading subset of rows.
    N = slice(0, 10000)

    # Split the data.
    X_train, X_test, y_train, y_test = train_test_split(
        e_s_exog_data[N], e_s_endog_data[N], random_state=1, shuffle=True, test_size=0.3
    )
    # Define and train the model.
    # client = get_client()

    # Used to define the chunk size.
    n_cores = 20

    # Aim for 3 chunks per core, but never fewer than 20 rows per chunk.
    chunk_nrows = max((int(X_train.shape[0] / (n_cores * 3)), 20))

    da_X_train = get_darray(X_train.values, chunk_nrows)
    da_y_train = get_darray(y_train.values, chunk_nrows)

    da_X_test = get_darray(X_test.values, chunk_nrows)
    da_y_test = get_darray(y_test.values, chunk_nrows)

    regressor = xgb.dask.DaskXGBRegressor(n_estimators=10000, max_depth=10,)
    #     regressor.set_params(**
    #         {
    #             # 'tree_method': 'hist',
    #             # 'grow_policy': 'lossguide',
    #             # 'eta': 0.3,
    #             # 'min_child_weight': 1,
    #             # 'subsample': 1,
    #             # 'colsample_bytree': 1,
    #         }),
    # Apparently this is optional - a 'global client' is used if not given.
    regressor.client = client

    #
    # Need to combat overfitting!!!
    #
    # The dask-ml version might support early stopping!
    # https://ml.dask.org/modules/generated/dask_ml.xgboost.XGBRegressor.html
    regressor.fit(
        da_X_train,
        da_y_train,
        early_stopping_rounds=10,
        verbose=True,
        eval_set=[(da_X_test, da_y_test)],
    )

    # The underlying booster is what gets stored in the results.
    bst = regressor.get_booster()

    y_test_pred = regressor.predict(da_X_test)
    y_train_pred = regressor.predict(da_X_train)

    # scikit-learn metrics expect `(y_true, y_pred)`. The order matters for R2,
    # which is not symmetric in its arguments.
    mse_test = mean_squared_error(y_test, y_test_pred)
    mse_train = mean_squared_error(y_train, y_train_pred)

    r2_test = r2_score(y_test, y_test_pred)
    r2_train = r2_score(y_train, y_train_pred)

    return ModelResults(
        X_train, X_test, y_train, y_test, r2_test, r2_train, bst, mse_train, mse_test
    )


# Run the cache-decorated fitting function and keep its `ModelResults`.
model_results = get_time_lags_model()

## R2 Scores

In [None]:
# Report the stored train / test scores, one line per score.
for label, field in (
    ("R2 train:", "r2_train"),
    ("R2 test:", "r2_test"),
    ("mse train:", "mse_train"),
    ("mse test:", "mse_test"),
):
    print(label, getattr(model_results, field))

In [None]:
# Predict on the training data, passed as a dask DataFrame with 100 partitions.
# NOTE(review): `model_results.model` is the object returned by
# `regressor.get_booster()` in `get_time_lags_model`; confirm that its
# `predict` accepts a dask DataFrame and returns something with `.compute()`.
y_train_pred = model_results.model.predict(
    dd.from_pandas(model_results.X_train, npartitions=100)
).compute()

In [None]:
# Scatter the first 100 training targets against their predictions.
n_shown = 100
observed = model_results.y_train[:n_shown]
predicted = y_train_pred[:n_shown]
plt.plot(observed, predicted, marker="o", linestyle="")
# Assign to `_` so that the notebook does not echo the return value.
_ = plt.gca().set(xlabel="Train", ylabel="Predicted")

## SHAP Values
Using dask to parallelise the SHAP value calculations is sometimes possible, but VERY unstable, just like local backends (e.g. loky).

In [None]:
# SHAP values for the first 100 training rows, using the last 100 training rows
# as the explainer's `data`.
background = model_results.X_train[-100:]
explainer = shap.TreeExplainer(model_results.model, data=background)
with Time("100, 100 vals"):
    explained = model_results.X_train[:100]
    shap_values = explainer.shap_values(explained)
shap.summary_plot(shap_values, explained)

In [None]:
# SHAP values for the first 100 training rows, using the last 200 training rows
# as the explainer's `data`.
background = model_results.X_train[-200:]
explainer = shap.TreeExplainer(model_results.model, data=background)
with Time("200, 100 vals"):
    explained = model_results.X_train[:100]
    shap_values = explainer.shap_values(explained)
shap.summary_plot(shap_values, explained)

In [None]:
# SHAP values for the first 100 training rows, using the last 400 training rows
# as the explainer's `data`.
background = model_results.X_train[-400:]
explainer = shap.TreeExplainer(model_results.model, data=background)
with Time("400, 100 vals"):
    explained = model_results.X_train[:100]
    shap_values = explainer.shap_values(explained)
shap.summary_plot(shap_values, explained)

In [None]:
# SHAP values for the first 500 training rows, using the last 400 training rows
# as the explainer's `data`.
background = model_results.X_train[-400:]
explainer = shap.TreeExplainer(model_results.model, data=background)
with Time("400, 500 vals"):
    explained = model_results.X_train[:500]
    shap_values = explainer.shap_values(explained)
shap.summary_plot(shap_values, explained)

In [None]:
# SHAP values for the first 2000 training rows, using the last 500 training
# rows as the explainer's `data`. The plot is drawn in the next cell.
background = model_results.X_train[-500:]
explainer = shap.TreeExplainer(model_results.model, data=background)
with Time("500, 2000 vals"):
    explained = model_results.X_train[:2000]
    shap_values = explainer.shap_values(explained)

In [None]:
# Summary plot of the SHAP values from the previous cell (first 2000 training
# rows), showing up to 100 features.
shap.summary_plot(shap_values, model_results.X_train[:2000], max_display=100)

In [None]:
# SHAP values for the first N training rows, without passing `data` to the
# explainer.
N = 100
explainer = shap.TreeExplainer(model_results.model)
with Time(f"None, {N} vals"):
    explained = model_results.X_train[:N]
    shap_values = explainer.shap_values(explained)
shap.summary_plot(shap_values, explained)

### SHAP Interaction Values

In [None]:
# SHAP interaction values for the first N training rows.
explainer = shap.TreeExplainer(model_results.model)
N = 10
with Time(f"shap interaction, {N}"):
    shap_interaction_values = explainer.shap_interaction_values(
        model_results.X_train[:N]
    )
# Plot the interaction values computed above. This previously passed
# `shap_values`, which is left over from an earlier cell and was computed for a
# different number of rows than the N rows plotted here.
shap.summary_plot(shap_interaction_values, model_results.X_train[:N])