# Shifted (and relative) data with RuleFit model

## Initialisation

In [None]:
import logging
import os
import re
import sys
import warnings
from collections import namedtuple
from functools import reduce
from itertools import combinations
from operator import mul

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
from joblib import Memory, parallel_backend
from loguru import logger as loguru_logger
from matplotlib.patches import Rectangle
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score, train_test_split
from tqdm import tqdm

import wildfires.analysis
from alepython import ale_plot
from alepython.ale import _second_order_ale_quant
from rulefit import RuleFit
from wildfires.analysis import *
from wildfires.dask_cx1 import get_client
from wildfires.data import *
from wildfires.logging_config import enable_logging
from wildfires.qstat import get_ncpus
from wildfires.utils import *

# Let alepython emit loguru messages, but replace all existing loguru handlers
# with a single stderr handler that only shows WARNING and above.
loguru_logger.enable("alepython")
loguru_logger.remove()
loguru_logger.add(sys.stderr, level="WARNING")

logger = logging.getLogger(__name__)

enable_logging("jupyter")

# Silence known noisy warnings. Each pattern is a regex matched against the start
# of the warning message, so the leading/trailing `.*` allow a substring match.
warnings.filterwarnings("ignore", ".*Collapsing a non-contiguous coordinate.*")
warnings.filterwarnings("ignore", ".*DEFAULT_SPHERICAL_EARTH_RADIUS.*")
warnings.filterwarnings("ignore", ".*guessing contiguous bounds.*")

# NOTE(review): presumably used by map-plotting code in later cells (not visible here).
normal_coast_linewidth = 0.5
mpl.rc("figure", figsize=(14, 6))
mpl.rc("font", size=9.0)

# `os.path.join` does not expand "~". Expand it explicitly so figures go under the
# user's home directory rather than a literal "~" directory.
figure_saver = FigureSaver(
    directories=os.path.expanduser(
        os.path.join("~", "tmp", "analysis_multiple_lags_rulefit")
    ),
    debug=True,
)
memory = get_memory("analysis_multiple_lags_rulefit", verbose=100)

In [None]:
# Shared settings for symmetric-log axes: linear within +/- `linthres`, with
# minor ticks at the 2..9 multiples of each decade.
value = "symlog"
linthres = 1e-2
subs = list(range(2, 10))

# NOTE(review): `linthreshx`/`subsx` and `linthreshy`/`subsy` are the pre-3.3
# Matplotlib keyword spellings. Newer releases use `linthresh`/`subs`; confirm
# against the installed Matplotlib version.
log_xscale_kwargs = {"value": value, "linthreshx": linthres, "subsx": subs}
log_yscale_kwargs = {"value": value, "linthreshy": linthres, "subsy": subs}

# Variables whose axes are drawn with the symlog scale above.
log_vars = tuple(
    [
        "dry day period",
        "popd",
        "agb tree",
        "cape x precip",
        "lai",
        "shruball",
        "pftherb",
        "pftcrop",
        "treeall",
    ]
)

## Creating the Data Structures used for Fitting

In [None]:
from ipdb import launch_ipdb_on_exception

# Month lags for which shifted copies of the predictors are requested.
shift_months = [1, 3, 6, 9, 12, 18, 24]

# Earlier hand-picked predictor subsets, kept for reference only. They are not
# used: `selection_variables=None` is passed to `get_data` below.
# selection_variables = (
#     "VOD Ku-band -3 Month",
#     # "SIF",  # Fix regridding!!
#     "VOD Ku-band -1 Month",
#     "Dry Day Period -3 Month",
#     "FAPAR",
#     "pftHerb",
#     "LAI -1 Month",
#     "popd",
#     "Dry Day Period -24 Month",
#     "pftCrop",
#     "FAPAR -1 Month",
#     "FAPAR -24 Month",
#     "Max Temp",
#     "Dry Day Period -6 Month",
#     "VOD Ku-band -6 Month",
# )

# ext_selection_variables = selection_variables + (
#     "Dry Day Period -1 Month",
#     "FAPAR -6 Month",
#     "ShrubAll",
#     "SWI(1)",
#     "TreeAll",
# )

# Unpack the six objects returned by `get_data` for the requested lags. ipdb's
# context manager opens a post-mortem debugger if anything in here raises.
with launch_ipdb_on_exception():
    (
        e_s_endog_data,
        e_s_exog_data,
        e_s_master_mask,
        e_s_filled_datasets,
        e_s_masked_datasets,
        e_s_land_mask,
    ) = wildfires.analysis.time_lags.get_data(
        selection_variables=None,
        shift_months=shift_months,
    )

### Offset data from 12 or more months before the current month to ease analysis.
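We are interested in the trends in these properties rather than their absolute values, so we subtract a recent 'seasonal cycle' analogue: the same variable at the equivalent month within the past year (or the current month, for lags that are exact multiples of 12).
This should help avoid capturing the same relationship twice for a variable and its 12-month-lagged counterpart, which are highly correlated.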

In [None]:
# Replace every predictor lagged by 12 or more months with its difference from
# the same variable at the equivalent lag within the past year.
# Example: "X -18 Month" becomes "X -18 - -6 Month" = "X -18 Month" - "X -6 Month".
# Lags that are multiples of 12 are differenced against the unlagged column "X".
# Iterate over a snapshot of the labels: new columns are inserted into
# `e_s_exog_data` inside the loop and must never be revisited.
to_delete = []
for column in list(e_s_exog_data.columns):
    match = re.search(r"-\d{1,2}", column)
    if not match:
        continue
    original_offset = int(match.group())
    if original_offset > -12:
        # Only shift months that are 12 or more months before the current month.
        continue
    # Equivalent lag within the last year, e.g. -18 -> -6, -12 / -24 -> 0.
    comp = -(-original_offset % 12)
    start, end = match.span()
    # Text either side of the lag token, minus the single separating spaces.
    prefix = column[: start - 1]
    suffix = column[end + 1 :]
    new_column = " ".join((prefix, f"{original_offset} - {comp}", suffix))
    comp_column = prefix if comp == 0 else " ".join((prefix, f"{comp}", suffix))
    print(column, comp_column)
    e_s_exog_data[new_column] = e_s_exog_data[column] - e_s_exog_data[comp_column]
    to_delete.append(column)

# Drop the original long-lag columns now that their differenced versions exist.
e_s_exog_data.drop(columns=to_delete, inplace=True)

## Fit the RuleFit Model

In [None]:
import pickle

# On-disk cache of the fitted model together with the exact train/test split it
# was fitted on. Delete this file to force a refit.
model_pickle = os.path.join(DATA_DIR, ".pickle", "rulefit_03_05_2020_gb_default")
if os.path.isfile(model_pickle):
    print("Loading files")
    # Restore every name the later cells use (rf, X_train, X_test, y_train, y_test).
    # NOTE: unpickling can execute arbitrary code - only load caches written by
    # this notebook, never files from an untrusted source.
    with open(model_pickle, "rb") as f:
        rf, X_train, X_test, y_train, y_test = pickle.load(f)
else:
    print("Fitting model")
    # Split the data (fixed seed so the split is reproducible across runs).
    X_train, X_test, y_train, y_test = train_test_split(
        e_s_exog_data, e_s_endog_data, random_state=1, shuffle=True, test_size=0.3
    )
    rf = RuleFit()
    # Fit the model, passing the column labels as feature names.
    rf.fit(X_train.values, y_train.values, feature_names=X_train.columns)
    print("Writing files to disk")
    os.makedirs(os.path.dirname(model_pickle), exist_ok=True)
    with open(model_pickle, "wb") as f:
        pickle.dump(
            (rf, X_train, X_test, y_train, y_test), f, pickle.HIGHEST_PROTOCOL
        )

In [None]:
from sklearn.metrics import r2_score

# Coefficient of determination on the training split and the held-out split.
train_pred_y = rf.predict(X_train.values)
train_r2 = r2_score(y_true=y_train, y_pred=train_pred_y)

test_pred_y = rf.predict(X_test.values)
test_r2 = r2_score(y_true=y_test, y_pred=test_pred_y)

for split_label, split_r2 in (("train R2", train_r2), ("test R2", test_r2)):
    print(split_label, split_r2)

In [None]:
# Keep only the rules/terms with a non-zero coefficient. Order them by
# descending support, breaking ties by descending importance.
rules = rf.get_rules()
rules = rules.loc[rules["coef"] != 0].sort_values(
    by=["support", "importance"], ascending=False
)

In [None]:
# Raise pandas' display limits so the whole rule table, including long rule
# strings, is rendered rather than truncated.
pd.set_option("display.max_rows", 900)
pd.set_option("display.max_colwidth", 300)
rules