In [None]:
from itertools import product
from pathlib import Path
from warnings import filterwarnings

import iris
import iris.coord_categorisation
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from alepython import ale_plot
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from utils import (
    collapse_cube_dim,
    cube_1d_to_2d,
    get_mm_data,
    param_dict,
    train_test_split_kwargs,
)
from wildfires.analysis import cube_plotting
from wildfires.data import MonthlyDataset, homogenise_time_coordinate
from wildfires.utils import ensure_datetime

# Silence warnings whose message mentions "divide by zero" for the whole session.
filterwarnings("ignore", ".*divide by zero.*")

In [None]:
# Directory containing the JULES output files loaded below ("~" is expanded to
# the current user's home directory).
source_dir = Path("~/JULES_output/jules_output5").expanduser()
# Fail early if the data directory does not exist.
assert source_dir.is_dir()

In [None]:
# Load every file matching "*Monthly*.nc" in `source_dir`, pass the loaded cubes
# through `homogenise_time_coordinate` and concatenate the result.
# NOTE(review): presumably the homogenisation is what makes the per-file cubes
# compatible for concatenation - confirm against the wildfires package.
cubes = homogenise_time_coordinate(
    iris.load(str(source_dir / "*Monthly*.nc"))
).concatenate()

In [None]:
# Display a summary of the loaded cubes.
cubes

In [None]:
# Plot the first entry along the leading dimension of every cube (the single
# time slice referred to in the progress-bar label), producing one figure per
# 2D slice.
for cube in tqdm(cubes, desc="Plotting cubes (single timeslice)"):
    cube_2d = cube_1d_to_2d(cube[0])

    assert len(cube_2d.shape) >= 2

    # One selection per combination of leading indices; the trailing two
    # dimensions are always taken whole. For a purely 2D cube `product()`
    # yields a single empty tuple, i.e. exactly one selection covering the
    # entire cube.
    selections = [
        (*indices, slice(None), slice(None))
        for indices in product(*(range(l) for l in cube_2d.shape[:-2]))
    ]

    # Plot every selection. Previously the plotting call sat outside the loop
    # over index combinations, so only the last combination was ever plotted.
    for sel in selections:
        try:
            fig = cube_plotting(cube_2d[sel])
        except Exception as e:
            # Best-effort: report the failing cube and carry on with the rest.
            print("cube:", str(cube))
            print("Error:", e)

In [None]:
# Extract the cube named "Gridbox mean burnt area fraction".
# Note that `ba_cube` is rebound further down to the burnt area climatology cube.
ba_cube = cubes.extract_strict(iris.Constraint(name="Gridbox mean burnt area fraction"))

In [None]:
# Take one slice of the burnt area cube (index 10 along the first dimension,
# then index 0 along the next) and convert it with `cube_1d_to_2d` for plotting.
# NOTE(review): 10 and 0 are unexplained magic numbers - presumably a time index
# followed by a singleton dimension; confirm.
ba_2d = cube_1d_to_2d(ba_cube[10][0])
# Additionally mask NaN values.
ba_2d.data.mask |= np.isnan(ba_2d.data)
fig = cube_plotting(ba_2d)

In [None]:
# Unique long names of the loaded cubes, sorted.
sorted({cube.long_name for cube in cubes})

In [None]:
# Names of the cubes extracted for the time series plots and for the processing
# further down. Commented-out entries are deliberately excluded.
cube_names = [
    "Gridbox precipitation rate",
    "Gridbox soil carbon (total)",
    "Gridbox soil carbon in each pool (DPM,RPM,bio,hum)",
    "Gridbox surface evapotranspiration from soil moisture store",
    # "Gridbox surface temperature",
    "Gridbox effective radiative temperature (assuming emissivity=1)",
    # "Gridbox unfrozen moisture content of each soil layer as a fraction of saturation",
    "Gridbox unfrozen soil moisture as fraction of saturation",
    "PFT burnt area fraction",
    "PFT gross primary productivity",
    "PFT leaf area index",
    # "PFT net primary productivity",
    "NPP (GBM) post N-limitation",
    "PFT net primary productivity prior to N limitation",
    # "PFT total carbon content of the vegetation at the end of model timestep.",
    "C in decomposable plant material",
    "C in decomposable plant material, gridbox total",
    "C in humus",
    "C in resistant plant material",
    "C in resistant plant material, gridbox total",
    "C in soil biomass",
    "C in soil biomass, gridbox total",
    "C in soil humus, gridbox total",
]

In [None]:
# Sample locations as name -> (latitude, longitude). The grid cell nearest to
# each location is used for the time series plots below.
# NOTE(review): the Yosemite entry is written as 360 - 119.67, which suggests a
# 0-360 longitude convention - confirm against the cube coordinates.
lat_lon_dict = {
    "UK": (51.5, 0),
    "Uganda": (2.36, 32.51),
    "Durban": (-29.13, 31),
    "Cape Town": (-33.57, 19.28),
    "Yosemite": (37.7, 360 - 119.67),
}

In [None]:
# Number of trailing entries along the first dimension shown in the time series
# plots (used as `[-n_months:]`).
n_months = 25

In [None]:
# Legend labels used when a cube's second dimension has length 4.
soil_carbon_pools = ("DPM", "RPM", "bio", "hum")
# NOTE(review): `pfts` is not referenced anywhere else in the visible notebook.
pfts = list(range(9))

### Plotting with individual PFTs

In [None]:
# For every sample location, plot the trailing `n_months` entries of each
# selected variable at the nearest grid cell: one figure per location, one
# subplot per variable, and one line per entry along the extra (second)
# dimension where a cube has one.
for (location, (lat, lon)) in lat_lon_dict.items():
    fig, axes = plt.subplots(
        len(cube_names),
        1,
        figsize=(10, 3.5 * len(cube_names)),
        sharex=True,
        constrained_layout=True,
        dpi=120,
    )
    fig.suptitle(location)

    for ax, var in zip(axes, cube_names):
        ext_cube = cubes.extract_strict(iris.Constraint(name=var))
        var_cube = cube_1d_to_2d(ext_cube)[-n_months:]

        print(var_cube.long_name, var_cube.shape)

        ax.set_title(f"{var_cube.long_name} ({var_cube.units})")

        # Indices of the grid cell closest to the requested location.
        lat_i = np.argmin(np.abs(var_cube.coord("latitude").points - lat))
        lon_i = np.argmin(np.abs(var_cube.coord("longitude").points - lon))

        if len(var_cube.shape) == 3:
            # A single series at the selected grid cell.
            plot_data_list = [var_cube[:, lat_i, lon_i].data.copy()]
        elif len(var_cube.shape) == 4:
            # One series per entry along the second dimension.
            plot_data_list = [
                var_cube[:, i, lat_i, lon_i].data.copy()
                for i in range(var_cube.shape[1])
            ]
        else:
            raise ValueError(
                f"Expected a 3D or 4D cube for '{var}', got shape {var_cube.shape}."
            )

        # Default to numeric labels; use descriptive ones where the size of the
        # second dimension identifies what it holds.
        labels = list(map(str, range(len(plot_data_list))))
        if len(var_cube.shape) == 4:
            if var_cube.shape[1] == 4:
                labels = soil_carbon_pools
            elif var_cube.shape[1] == 5:
                labels = ["BT", "NT", "C3", "C4", "SH"]

        # The time axis is identical for every line of this subplot, so it is
        # converted once instead of once per plotted series.
        times = [
            ensure_datetime(var_cube.coord("time").cell(i).point)
            for i in range(var_cube.shape[0])
        ]

        for plot_data, label in zip(plot_data_list, labels):
            # plot_data_mean = np.mean(plot_data)
            # plot_data -= plot_data_mean

            ax.plot(times, plot_data, label=label, marker=".")
        ax.legend(loc="upper left", bbox_to_anchor=(1, 1))

### Plotting with mean over individual PFTs and pools

In [None]:
# fig, axes = plt.subplots(
#     len(cube_names),
#     1,
#     figsize=(9, 3 * len(cube_names)),
#     sharex=True,
#     constrained_layout=True,
#     dpi=120,
# )

# for ax, var in zip(axes, cube_names):
#     ext_cube = cubes.extract_strict(iris.Constraint(name=var))
#     if "generic" in list(coord.name() for coord in ext_cube.coords()):
#         # Average over carbon pools or PFTs
#         ext_cube = ext_cube.collapsed("generic", iris.analysis.MEAN)

#     var_cube = cube_1d_to_2d(ext_cube)[-n_months:]

#     print(var_cube.long_name, var_cube.shape)

#     ax.set_title(f"{var_cube.long_name} ({var_cube.units})")

#     for (location, (lat, lon)) in lat_lon_dict.items():
#         lat_i = np.argmin(np.abs(var_cube.coord("latitude").points - lat))
#         lat_sel = var_cube.coord("latitude").points[lat_i]

#         lon_i = np.argmin(np.abs(var_cube.coord("longitude").points - lon))
#         lon_sel = var_cube.coord("longitude").points[lon_i]

#         assert len(var_cube.shape) == 3

#         plot_data = var_cube[:, lat_i, lon_i].data.copy()

#         # plot_data_mean = np.mean(plot_data)
#         # plot_data -= plot_data_mean

#         ax.plot(
#             [
#                 ensure_datetime(var_cube.coord("time").cell(i).point)
#                 for i in range(var_cube.shape[0])
#             ],
#             plot_data,
#             label=location,
#             marker=".",
#         )
#     ax.legend(loc="upper left", bbox_to_anchor=(1, 1))

### Processing of cubes

In [None]:
# Number of trailing entries along the first dimension kept for each processed cube.
proc_months = 150

# Bring every selected variable down to 3 dimensions, convert it with
# `cube_1d_to_2d` and keep only the trailing `proc_months` entries.
processed_cubes = iris.cube.CubeList([])
for var in tqdm(cube_names, desc="Processing cubes"):
    ext_cube = cubes.extract_strict(iris.Constraint(name=var))

    # Cubes with more than 3 dimensions are reduced via `collapse_cube_dim` on
    # dimension 1.
    # NOTE(review): presumably dimension 1 is the PFT / carbon pool axis and the
    # helper aggregates over it - confirm in utils.
    if len(ext_cube.shape) > 3:
        ext_cube = collapse_cube_dim(ext_cube, 1)

    # If needed, apply the same operation once again.
    if len(ext_cube.shape) == 4:
        ext_cube = collapse_cube_dim(ext_cube, 1)

    # Anything that is not 3D at this point is unexpected.
    assert len(ext_cube.shape) == 3

    var_cube = cube_1d_to_2d(ext_cube)[-proc_months:]
    processed_cubes.append(var_cube)

In [None]:
# Display a summary of the processed cubes.
processed_cubes

#### 'proper' averaging - requires more RAM

In [None]:
# # Raw, shifted datasets.
# raw_proc_insts = []

# for proc_cube in tqdm(processed_cubes):
#     # Create a new Dataset for each cube.
#     proc_inst = type(
#         proc_cube.long_name.replace(" ", ""),
#         (MonthlyDataset,),
#         {
#             "__init__": lambda self: None,
#             "frequency": "monthly",
#         },
#     )()
#     proc_inst.cubes = iris.cube.CubeList([proc_cube])

#     raw_proc_insts.append(proc_inst)

#     if "burnt area" not in proc_inst.cube.long_name:
#         # Shift if applicable, i.e. not BA.
#         for shift in (1, 3, 6, 9):
#             raw_proc_insts.append(
#                 proc_inst.get_temporally_shifted_dataset(months=-shift, deep=False)
#             )

In [None]:
# clim_proc_insts = prepare_selection(Datasets(raw_proc_insts), which="climatology")

In [None]:
# raw_proc_insts.homogenise_masks()
# raw_proc_insts.apply_masks(ba_cube.data.mask)

In [None]:
# Climatology datasets, one per processed cube.
proc_insts = []

for proc_cube in tqdm(processed_cubes):
    # Create a new Dataset for each cube.
    # A `MonthlyDataset` subclass is built on the fly, named after the cube's
    # long name with spaces removed. Its `__init__` is a no-op, so instantiating
    # it does nothing; the cube is attached manually afterwards.
    proc_inst = type(
        proc_cube.long_name.replace(" ", ""),
        (MonthlyDataset,),
        {
            "__init__": lambda self: None,
            "frequency": "monthly",
        },
    )()
    proc_inst.cubes = iris.cube.CubeList([proc_cube])

    # Request the climatology over the dataset's full time range.
    proc_insts.append(
        proc_inst.get_climatology_dataset(proc_inst.min_time, proc_inst.max_time)
    )

In [None]:
# Predictor cubes: every non burnt-area climatology cube plus copies of it
# rolled by 1, 3, 6 and 9 steps along axis 0.
shifted_proc_cubes = []

# The burnt area climatology cube (the target variable); this rebinds the
# `ba_cube` name defined earlier in the notebook.
ba_cube = None

for proc_inst in proc_insts:
    if "burnt area" in proc_inst.cube.long_name:
        # Burnt area is the target, so it is kept aside and never shifted.
        ba_cube = proc_inst.cube
        continue
    shifted_proc_cubes.append(proc_inst.cube)
    for shift in (1, 3, 6, 9):
        # XXX: Overly simplistic np.roll() implementation!!
        # A positive shift moves each value `shift` positions later along axis 0,
        # wrapping around at the end.
        # NOTE(review): presumably axis 0 is the month axis of the climatology,
        # making this a lag of `shift` months - confirm.
        c2 = proc_inst.cube.copy()
        c2.data = np.roll(proc_inst.cube.data, shift, axis=0)
        # Tag the name with the shift so that each cube gets a distinct name.
        c2.long_name = c2.long_name + f"({shift})"
        c2.var_name = None
        c2.short_name = None

        shifted_proc_cubes.append(c2)

# Exactly the burnt area cube must have been found above.
assert ba_cube is not None

In [None]:
# Compare three of the shifted climatology cubes at a single grid cell.
# NOTE(review): the cube offset (35) and the grid indices (72, 110) are
# hard-coded - confirm they still point at the intended variable and location.
plt.figure(figsize=(5, 3), dpi=150)
for i in np.array([0, 1, 2]) + 35:
    # Select the cube first, so that everything below refers to the cube being
    # plotted rather than to a `cube` left over from an earlier cell or from the
    # previous iteration.
    cube = shifted_proc_cubes[i]

    # Make sure the plotted cube itself carries a month number coordinate.
    if not cube.coords("month_number"):
        iris.coord_categorisation.add_month_number(cube, "time")

    print(cube.coord("month_number").points)
    plt.plot(
        cube.coord("month_number").points,
        cube.data[:, 72, 110],
        label=cube.name(),
        alpha=0.9,
        marker="x",
    )
plt.legend()

In [None]:
# Boolean mask of the burnt area data, used to select valid samples below.
# `np.ma.getmaskarray` always returns an array with the same shape as the data,
# whereas `.mask` may be the scalar `nomask` when nothing is masked, which would
# break the `data[~master_mask]` selections.
master_mask = np.ma.getmaskarray(ba_cube.data)

In [None]:
# Target variable: burnt area at the locations not excluded by `master_mask`.
endog_data = pd.Series(ba_cube.data.data[~master_mask], name="burnt area")

# Predictors: one column per (shifted) cube, keyed by its long name. The raw
# underlying arrays are sampled, so only `master_mask` is applied here and the
# cubes' own masks are ignored.
exog_dict = {
    cube.long_name: cube.data.data[~master_mask] for cube in shifted_proc_cubes
}
exog_data = pd.DataFrame(exog_dict)

In [None]:
# Display the target series.
endog_data

In [None]:
# Display the predictor table.
exog_data

In [None]:
# Substring -> replacement pairs used by `shorten_columns` to abbreviate column
# names. Replacements are applied in the order listed here, so an earlier
# replacement changes the text that later keys are matched against.
shorten_mapping = {
    "Gridbox": "",
    "precipitation": "precip",
    "soil carbon in each pool": "soil pool carbon",
    "surface": "surf",
    "evapotranspiration": "evapot",
    "from soil moisture store": "soil moist",
    "temperature": "temp",
    "unfrozen moisture content of each soil layer as a fraction of saturation": "unfrozen moist soil layer / sat",
    "gross primary productivity": "gpp",
    "net primary productivity": "npp",
    "soil moisture availability factor (beta)": "soil moist avail fact",
    "total carbon content of the vegetation at the end of model timestep": "total veg C end timestep",
}

In [None]:
def shorten_columns(df, mapping=None):
    """Abbreviate the column names of `df` in place.

    Every key of `mapping` found in a column name is replaced by its value.
    Replacements are applied in the mapping's iteration order, so an earlier
    replacement changes the text later keys are matched against. The result is
    then stripped of surrounding whitespace.

    Args:
        df: DataFrame whose columns are renamed. It is modified in place.
        mapping: Dict of substring -> replacement. Defaults to the notebook-level
            `shorten_mapping`.

    Returns:
        The same `df` object, for convenient display or chaining.
    """
    if mapping is None:
        mapping = shorten_mapping

    new_cols = []
    for col in df.columns:
        for old, new in mapping.items():
            col = col.replace(old, new)
        # Removing a leading word such as "Gridbox" leaves stray whitespace.
        new_cols.append(col.strip())
    df.columns = new_cols
    return df

In [None]:
# Abbreviate the predictor column names. `exog_data` is renamed in place; the
# returned frame is only used for display here.
shorten_columns(exog_data)

### Rescale the BA data so it has higher magnitudes

This seems to be required by the RF algorithm in order to make any predictions at all.

In [None]:
# Scale the target in place so that its maximum becomes 1.
endog_data /= np.max(endog_data)

In [None]:
# Split predictors and target into training and test subsets; the split options
# come from `train_test_split_kwargs` in the local `utils` module.
X_train, X_test, y_train, y_test = train_test_split(
    exog_data, endog_data, **train_test_split_kwargs
)

In [None]:
# Build the random forest from the shared hyperparameters in `param_dict`,
# forcing 3 parallel jobs regardless of what `param_dict` specifies, then fit
# it to the training data.
model = RandomForestRegressor(**{**param_dict, "n_jobs": 3})
model.fit(X_train, y_train)

In [None]:
# Model predictions for the training and the test predictors, in that order.
y_train_pred, y_test_pred = (
    model.predict(features) for features in (X_train, X_test)
)

In [None]:
# R2 scores on the training and test subsets. `r2_score` expects
# (y_true, y_pred); the score is not symmetric in its arguments, so the
# observations must be passed first.
print("train:", r2_score(y_train, y_train_pred))
print("val:", r2_score(y_test, y_test_pred))

In [None]:
# Display the (shortened) predictor names.
exog_data.columns

In [None]:
# Histogram of the (scaled) test-set target values, drawn on log-log axes.
plt.hist(y_test.values)
plt.xscale("log")
plt.yscale("log")

In [None]:
# Map of the observed (scaled) burnt area for the test samples.
# NOTE(review): presumably `get_mm_data` places the 1D sample values back on the
# grid described by `master_mask`, with "val" selecting the validation subset -
# confirm in utils.
fig = cube_plotting(
    get_mm_data(y_test.values, master_mask, "val"),
    boundaries=[0, 1e-3, 1e-2, 0.05, 0.2, 0.5],
    cmap="inferno",
    fig=plt.figure(figsize=(7, 3.2), dpi=150),
    colorbar_kwargs=dict(label="burnt area (scaled)"),
    title="Validation Observations",
    extend="neither",
)

In [None]:
# Map of the predicted (scaled) burnt area for the test samples, using the same
# boundaries and colour map as the observations plot so the two are comparable.
# NOTE(review): presumably `get_mm_data` places the 1D sample values back on the
# grid described by `master_mask`, with "val" selecting the validation subset -
# confirm in utils.
fig = cube_plotting(
    get_mm_data(y_test_pred, master_mask, "val"),
    boundaries=[0, 1e-3, 1e-2, 0.05, 0.2, 0.5],
    cmap="inferno",
    fig=plt.figure(figsize=(7, 3.2), dpi=150),
    colorbar_kwargs=dict(label="burnt area (scaled)"),
    title="Validation Predictions",
    extend="neither",
)

In [None]:
# Observed vs. predicted burnt area on the training set, as a hexagonal 2D
# histogram with logarithmic colour scaling of the bin counts.
plt.figure(dpi=250)
plt.hexbin(y_train, y_train_pred, bins="log")
plt.xlabel("observed BA (train)")
# Assigning to `_` keeps the returned label object out of the cell output.
_ = plt.ylabel("predicted BA (train)")

In [None]:
# Observed vs. predicted burnt area on the test set, as a hexagonal 2D
# histogram with logarithmic colour scaling of the bin counts.
plt.figure(dpi=250)
plt.hexbin(y_test, y_test_pred, bins="log")
plt.xlabel("observed BA (test)")
# Assigning to `_` keeps the returned label object out of the cell output.
_ = plt.ylabel("predicted BA (test)")

### Gini importances

In [None]:
# Feature importances of every individual estimator in the fitted forest:
# one row per estimator, one column per feature.
ind_trees_gini = pd.DataFrame(
    [tree.feature_importances_ for tree in model],
    columns=X_train.columns,
)
# Order the columns by their mean importance, most important first.
mean_importances = ind_trees_gini.mean().sort_values(ascending=False)
ind_trees_gini = ind_trees_gini.reindex(mean_importances.index, axis=1)

fig, ax = plt.subplots(1, 1, figsize=(10, 4), dpi=170)

# Number of top-ranked features shown in the box plot.
N_col = 18

sns.boxplot(data=ind_trees_gini.iloc[:, :N_col], ax=ax)
ax.set(
    # title="Gini Importances",
    ylabel="Gini Importance\n"
)
# Rotate the feature names so that the long labels do not overlap.
_ = plt.setp(ax.xaxis.get_majorticklabels(), rotation=60, ha="right")

In [None]:
# Feature importances of every individual estimator in the fitted forest:
# one row per estimator, one column per feature.
ind_trees_gini = pd.DataFrame(
    [tree.feature_importances_ for tree in model],
    columns=X_train.columns,
)
# Order the columns by their mean importance, most important first.
mean_importances = ind_trees_gini.mean().sort_values(ascending=False)
ind_trees_gini = ind_trees_gini.reindex(mean_importances.index, axis=1)

fig, ax = plt.subplots(1, 1, figsize=(10, 4), dpi=170)

# Number of top-ranked features shown in the box plot.
N_col = 30

sns.boxplot(data=ind_trees_gini.iloc[:, :N_col], ax=ax)
ax.set(
    # title="Gini Importances",
    ylabel="Gini Importance\n"
)
# Rotate the feature names so that the long labels do not overlap.
_ = plt.setp(ax.xaxis.get_majorticklabels(), rotation=60, ha="right")

In [None]:
# One first-order ALE plot per predictor, each drawn in its own new figure and
# computed from the training data.
# NOTE(review): the Monte Carlo replicas are presumably drawn at random from
# `X_train`; no seed is set in this cell, so the plots may differ between runs -
# confirm against the alepython documentation.
for feature in tqdm(X_train.columns, desc="1D ALE plots"):
    fig, axes = ale_plot(
        model,
        X_train,
        feature,
        bins=20,
        fig=plt.figure(figsize=(5.5, 2), dpi=150),
        quantile_axis=True,
        monte_carlo=True,
        monte_carlo_rep=10,
        monte_carlo_ratio=0.2,
    )
    # Rotate the x tick labels of the ALE axis for readability.
    plt.setp(axes["ale"].xaxis.get_majorticklabels(), rotation=60, ha="right")