## Initialisation

In [None]:
import logging
import os
import re
import sys
import warnings
from collections import namedtuple
from functools import reduce
from itertools import combinations
from operator import mul

import cloudpickle
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import shap
from joblib import Memory, Parallel, delayed
from loguru import logger as loguru_logger
from matplotlib.patches import Rectangle
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score, train_test_split
from tqdm import tqdm

import wildfires.analysis
from alepython import ale_plot
from alepython.ale import _second_order_ale_quant
from wildfires.analysis import *
from wildfires.dask_cx1 import *
from wildfires.data import *
from wildfires.logging_config import enable_logging
from wildfires.qstat import get_ncpus
from wildfires.utils import *

# Enable alepython's loguru output, but replace loguru's default handler with
# a single stderr sink that only shows WARNING and above.
loguru_logger.enable("alepython")
loguru_logger.remove()
loguru_logger.add(sys.stderr, level="WARNING")

logger = logging.getLogger(__name__)

enable_logging("jupyter")

# Silence known warnings (presumably emitted by iris/cartopy - TODO confirm).
# The message argument is a regex matched against the *start* of the warning
# text, hence the leading ".*"; the trailing ".*" makes the intent explicit
# (the previous trailing "S*" / "s*" only made the last letter optional).
warnings.filterwarnings("ignore", ".*Collapsing a non-contiguous coordinate.*")
warnings.filterwarnings("ignore", ".*DEFAULT_SPHERICAL_EARTH_RADIUS.*")
warnings.filterwarnings("ignore", ".*guessing contiguous bounds.*")

# Presumably the default coastline line width for map plots (not referenced in
# this cell).
normal_coast_linewidth = 0.5
mpl.rc("figure", figsize=(14, 6))
mpl.rc("font", size=9.0)

save_name = "analysis_variable_diagnostics"

# Expand "~" explicitly so figures are written below the user's home directory
# rather than into a literal "~" directory relative to the working directory
# (harmless if FigureSaver also expands it itself).
figure_saver = FigureSaver(
    directories=os.path.expanduser(os.path.join("~", "tmp", save_name)), debug=True,
)
memory = get_memory(save_name, verbose=100)
CACHE_DIR = os.path.join(DATA_DIR, ".pickle", save_name)

### Load the customized `get_data()` function for this experiment.

In [None]:
from get_lags_rf_cross_val_data import get_data

## Creating the Data Structures used for Fitting

In [None]:
# Month shifts requested from the experiment-specific `get_data()`; the
# shifted variables later appear as "<name> -<shift> Month" columns.
shift_months = [1, 3, 6, 9, 12, 18, 24]

# `get_data()` returns six objects, unpacked here in order (meanings inferred
# from the names - TODO confirm against get_lags_rf_cross_val_data).
[
    e_s_endog_data,  # presumably the target variable
    e_s_exog_data,  # predictor columns (transformed in place below)
    e_s_master_mask,
    e_s_filled_datasets,
    e_s_masked_datasets,
    e_s_land_mask,
] = get_data(
    shift_months=shift_months,
    selection_variables=None,
)

### Offset data from 12 or more months before the current month to make the analysis easier to interpret.

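We are interested in the trends in these properties rather than their absolute values; therefore, for every variable lagged by 12 or more months we subtract a recent 'seasonal cycle' analogue, namely the same variable at the equivalent lag within the most recent year (e.g. the -18 month value minus the -6 month value, or the -12 month value minus the current value).
This should help avoid capturing the same relationship twice for a variable and its 12-month counterpart, since the two are highly correlated.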

In [None]:
# XXX: aside: explore simple shifts vs. mean over antecedent period as well.
_SHIFT_PATTERN = re.compile(r"-\d{1,2}")


def _offset_long_lag_columns(data):
    """Replace lags of 12+ months by their difference to a recent analogue, in place.

    For every column whose name contains a shift token ``-N`` with N >= 12,
    a new column is added holding the column minus the same variable shifted
    by ``-(N % 12)`` months (or minus the un-shifted variable when N is a
    multiple of 12), and the original column is deleted. For example::

        "Dry Day Period -18 Month" -> "Dry Day Period -18 - -6 Month"
            = "Dry Day Period -18 Month" - "Dry Day Period -6 Month"

    Parameters
    ----------
    data : pandas.DataFrame
        Data with lagged columns named like ``"<name> -<N> Month"``.

    Returns
    -------
    list of str
        The original column names that were replaced and deleted.

    Raises
    ------
    KeyError
        If the comparison column for a long lag is not present in `data`.
    """
    replaced = []
    # Iterate over a snapshot of the names: columns are inserted below, and
    # their names also contain a "-N" token.
    for column in list(data.columns):
        match = _SHIFT_PATTERN.search(column)
        if not match:
            continue
        original_offset = int(match.group())
        if original_offset > -12:
            # Only shift months that are 12 or more months before the current month.
            continue
        comp = -(-original_offset % 12)
        # Strip the separators around the shift token instead of assuming
        # exactly one character on each side (which broke names where the
        # token is at the start or end of the column name).
        prefix = column[: match.start()].rstrip()
        suffix = column[match.end() :].lstrip()
        new_column = " ".join(
            part for part in (prefix, f"{original_offset} - {comp}", suffix) if part
        )
        if comp == 0:
            # A whole number of years back: compare against the current month.
            comp_column = prefix
        else:
            comp_column = " ".join(part for part in (prefix, f"{comp}", suffix) if part)
        print(column, comp_column)
        data[new_column] = data[column] - data[comp_column]
        replaced.append(column)
    for column in replaced:
        del data[column]
    return replaced


to_delete = _offset_long_lag_columns(e_s_exog_data)

## Mapping

In [None]:
# Map FAPAR restricted by constraints on FAPAR and Dry Day Period; the tuples
# are presumably (lower, upper) bounds with None meaning unbounded - TODO confirm.
# Use the shared coastline width constant instead of repeating the literal.
constrained_map_plot(
    {"FAPAR": (0.39, None), "Dry Day Period": (20, None)},
    e_s_exog_data,
    e_s_master_mask,
    plot_variable="FAPAR",
    coastline_kwargs={"linewidth": normal_coast_linewidth},
)

In [None]:
# Map AGB Tree restricted by constraints on the offset (-18 minus -6 month)
# Dry Day Period and on AGB Tree itself; the tuples are presumably
# (lower, upper) bounds with None meaning unbounded - TODO confirm.
# Use the shared coastline width constant instead of repeating the literal.
constrained_map_plot(
    {"Dry Day Period -18 - -6 Month": (22, None), "AGB Tree": (0.9, 20)},
    e_s_exog_data,
    e_s_master_mask,
    plot_variable="AGB Tree",
    coastline_kwargs={"linewidth": normal_coast_linewidth},
)

## Correlation Plot

In [None]:
# Correlation plot of every column except the last 12, then list the
# columns that were left out of the plot.
X_corr = e_s_exog_data
included_columns, excluded_columns = X_corr.columns[:-12], X_corr.columns[-12:]
with figure_saver("corr_plot_with_sif"):
    corr_plot(X_corr[included_columns], fig_kwargs=dict(figsize=(12, 8)))
print("Excluded columns:", excluded_columns)

In [None]:
# Correlation plot over all columns of X_corr.
all_columns = X_corr.columns
with figure_saver("corr_plot_full"):
    corr_plot(X_corr[all_columns], fig_kwargs=dict(figsize=(12, 8)))