## Initialisation

In [None]:
import logging
import os
import sys
import warnings
from collections import namedtuple
from itertools import product

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from joblib import Memory
from loguru import logger as loguru_logger
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score, train_test_split
from tqdm import tqdm

import wildfires.analysis
from alepython import ale_plot
from wildfires.analysis import *
from wildfires.dask_cx1 import get_client
from wildfires.data import *
from wildfires.logging_config import enable_logging
from wildfires.qstat import get_ncpus

# Configure loguru: enable records from the "alepython" package, drop the
# previously registered handlers and log to stderr at WARNING and above only.
loguru_logger.enable("alepython")
loguru_logger.remove()
loguru_logger.add(sys.stderr, level="WARNING")

logger = logging.getLogger(__name__)

enable_logging("jupyter")

# NOTE(review): "~" is passed through literally here - presumably FigureSaver
# expands the home directory itself; confirm before relying on this path.
figure_saver = FigureSaver(
    directories=os.path.join("~", "tmp", "time_lags_pdp_ale"), debug=True
)

# Silence known, noisy warnings. Each pattern is a regular expression that
# `warnings` matches against the start of the warning message, so every
# pattern is wrapped in ".*" on both sides. (A bare trailing "*" would merely
# quantify the final letter of the pattern instead of acting as a wildcard.)
warnings.filterwarnings("ignore", ".*Collapsing a non-contiguous coordinate.*")
warnings.filterwarnings("ignore", ".*DEFAULT_SPHERICAL_EARTH_RADIUS.*")
warnings.filterwarnings("ignore", ".*guessing contiguous bounds.*")

# Plotting defaults shared by the figures in this notebook.
normal_coast_linewidth = 0.5
mpl.rc("figure", figsize=(14, 6))
mpl.rc("font", size=9.0)

# Cache object used to decorate the expensive model fitting further down.
memory = get_memory("analysis_time_lags_pdp_ale", verbose=100)

## Creating the Data Structures used for Fitting

In [None]:
# Temporal shifts (in months) requested for the lagged variables. This single
# definition is passed to both `get_data` calls below so that the two datasets
# cannot drift apart.
shift_months = [1, 3, 6, 12, 24]

# Variables used for the standard shifted dataset.
selection_variables = (
    "VOD Ku-band -3 Month",
    # "SIF",  # Fix regridding!!
    "VOD Ku-band -1 Month",
    "Dry Day Period -3 Month",
    "FAPAR",
    "pftHerb",
    "LAI -1 Month",
    "popd",
    "Dry Day Period -24 Month",
    "pftCrop",
    "FAPAR -1 Month",
    "FAPAR -24 Month",
    "Max Temp",
    "Dry Day Period -6 Month",
    "VOD Ku-band -6 Month",
)

# Extended selection: everything above plus five additional variables.
ext_selection_variables = selection_variables + (
    "Dry Day Period -1 Month",
    "FAPAR -6 Month",
    "ShrubAll",
    "SWI(1)",
    "TreeAll",
)

# Data for the standard variable selection ("s_" prefix).
(
    s_endog_data,
    s_exog_data,
    s_master_mask,
    s_filled_datasets,
    s_masked_datasets,
    s_land_mask,
) = wildfires.analysis.time_lags.get_data(
    shift_months=shift_months, selection_variables=selection_variables
)

# Data for the extended variable selection ("e_s_" prefix).
(
    e_s_endog_data,
    e_s_exog_data,
    e_s_master_mask,
    e_s_filled_datasets,
    e_s_masked_datasets,
    e_s_land_mask,
) = wildfires.analysis.time_lags.get_data(
    shift_months=shift_months, selection_variables=ext_selection_variables
)

## Cached Model Fitting

In [None]:
# Immutable record bundling the train/test split, the R2 scores on both
# subsets and the fitted model itself.
ModelResults = namedtuple(
    "ModelResults",
    field_names=[
        "X_train",
        "X_test",
        "y_train",
        "y_test",
        "r2_test",
        "r2_train",
        "model",
    ],
)


@memory.cache()
def get_ext_shifted_model():
    """Fit a random forest regressor on the extended shifted dataset.

    The module-level `e_s_exog_data` and `e_s_endog_data` are shuffled and
    split 70/30 into training and test subsets (`random_state=1`). A 20-tree
    `RandomForestRegressor` is fitted on the training subset and then scored
    on both subsets.

    NOTE(review): the function takes no arguments, so the cache presumably
    cannot detect changes to the global input data - confirm, and clear the
    cache whenever the data selection changes.

    Returns:
        ModelResults: A namedtuple with the fields 'X_train', 'X_test',
        'y_train', 'y_test', 'r2_test', 'r2_train', and 'model'.

    """
    # Reproducible, shuffled train/test partition of the data.
    partition = train_test_split(
        e_s_exog_data, e_s_endog_data, test_size=0.3, shuffle=True, random_state=1
    )
    X_train, X_test, y_train, y_test = partition

    # Build and fit the forest using the number of CPUs reported for the job.
    forest = RandomForestRegressor(
        n_estimators=20, max_depth=None, n_jobs=get_ncpus(), random_state=1
    )
    forest.fit(X_train, y_train)

    # Keyword arguments are evaluated in order: test score first, then train.
    return ModelResults(
        X_train=X_train,
        X_test=X_test,
        y_train=y_train,
        y_test=y_test,
        r2_test=forest.score(X_test, y_test),
        r2_train=forest.score(X_train, y_train),
        model=forest,
    )


# Obtain the fitted model and its associated data via the cached function.
model_results = get_ext_shifted_model()
# Take advantage of all (but no more) cores available to our job by resetting
# the parallelism on the returned estimator.
available_cpus = get_ncpus()
model_results.model.n_jobs = available_cpus

In [None]:
# Partial dependence plots for every feature in the training data, drawn
# inside the figure saver context under the "pdp" sub-directory.
pdp_training_data = model_results.X_train
pdp_feature_names = pdp_training_data.columns

with figure_saver(pdp_feature_names, sub_directory="pdp"):
    figs_axes = partial_dependence_plot(
        model_results.model,
        pdp_training_data,
        pdp_feature_names,
        grid_resolution=40,
        plot_range=False,
        log_x_scale=("Dry Day Period", "popd"),
        single_plots=True,
        predicted_name="Burned Area",
    )

In [None]:
# Columns whose name contains one of these (case-insensitive) substrings are
# plotted with `log="x"`; all other columns get `log=None`.
ale_log_x_features = ("dry day period", "popd",)

for column in tqdm(model_results.X_train.columns, desc="Calculating ALE plots"):
    with figure_saver(column, sub_directory="ale"):
        column_lowered = column.lower()
        use_log_x = any(
            feature in column_lowered for feature in ale_log_x_features
        )
        ale_plot(
            model_results.model,
            model_results.X_train,
            column,
            bins=40,
            monte_carlo=True,
            monte_carlo_rep=30,
            monte_carlo_ratio=0.1,
            verbose=False,
            log="x" if use_log_x else None,
        )