In [None]:
import random
import string
from collections import defaultdict
from functools import reduce
from itertools import product
from pathlib import Path
from warnings import filterwarnings

import iris
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns
from alepython import ale_plot
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from wildfires.analysis import cube_plotting
from wildfires.data import GFEDv4, MonthlyDataset, homogenise_time_coordinate, regrid

from jules_output_analysis.utils import (
    collapse_cube_dim,
    cube_1d_to_2d,
    get_mm_data,
    param_dict,
    train_test_split_kwargs,
)

# Ignore every warning whose message matches the regex ".*divide by zero.*".
filterwarnings("ignore", ".*divide by zero.*")

In [None]:
# Pass facecolor="w" to the inline backend's figure printing.
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}

In [None]:
# JULES output directory of each experiment, keyed by experiment id.
source_dirs = {
    "old": Path("/work/scratch-nopw/alexkr/ignition3_5/jules_output"),
    "new": Path("/work/scratch-nopw/alexkr/antecedent3/jules_output"),
}
for source_dir in source_dirs.values():
    # Fail early and name the directory that is missing.
    assert source_dir.is_dir(), f"Missing source directory: {source_dir}"

# Run identifier that appears in the output filenames of each experiment.
runs = {
    "old": "SPINUP6",
    "new": "RUN1",
}

# Human-readable experiment labels used in plot titles.
names = {
    "old": "Ignition 3",
    "new": "New Antec (Ig3)",
}

# The per-experiment dictionaries are combined by key / position further down,
# so they must describe exactly the same experiments.
assert source_dirs.keys() == runs.keys() == names.keys()

# PFT names. These used to be assigned to `pfts` and were then immediately
# overwritten by the index list below, so they now live under their own name.
# NOTE(review): 5 names vs. 9 indices - confirm how the two relate.
pft_names = (
    "Broadleaf trees",
    "Needleleaf trees",
    "C3 (temperate) grass",
    "C4 (tropical) grass",
    "Shrubs",
)

soil_carbon_pools = ("DPM", "RPM", "bio", "hum")
# PFT indices (this is the value `pfts` always ended up with).
pfts = list(range(9))

In [None]:
# Load the monthly output files of every experiment, homogenise their time
# coordinates and concatenate them.
# The source directory is looked up by experiment key instead of being paired
# positionally with `runs`, so the result does not depend on both dictionaries
# being declared in the same order.
exp_cubes = {
    key: homogenise_time_coordinate(
        iris.load(str(source_dirs[key] / f"*{run_name}*Monthly*.nc"))
    ).concatenate()
    for key, run_name in runs.items()
}

In [None]:
# Sorted, de-duplicated names of all cubes found in any experiment.
all_cube_names = sorted(
    {cube.name() for exp_cube_list in exp_cubes.values() for cube in exp_cube_list}
)

In [None]:
# Plot the first time step of every cube, for every experiment.
for cube_name in tqdm(all_cube_names, desc="Plotting cubes"):
    for cubes, name in zip(exp_cubes.values(), names.values()):
        for cube in cubes:
            if cube.name() != cube_name:
                continue

            cube_2d = cube_1d_to_2d(cube[0])

            assert len(cube_2d.shape) >= 2

            # Plot one map per combination of leading (non-spatial) indices.
            # Previously the plotting call sat outside this loop, so only the
            # last combination was ever plotted. For a purely 2D cube,
            # `product()` yields a single empty tuple and the whole cube is
            # selected.
            for indices in product(*(range(n) for n in cube_2d.shape[:-2])):
                sel = (*indices, slice(None), slice(None))
                title = f"{cube.name()} ({name})"
                if indices:
                    # Tell the individual slices of one cube apart.
                    title += f" {indices}"
                try:
                    fig = cube_plotting(cube_2d[sel], title=title)
                except Exception as e:
                    # Best effort: report the failure and carry on plotting.
                    print("cube:", str(cube))
                    print("Error:", e)

In [None]:
# Pick the burnt area cube out of each experiment's cube list.
ba_constraint = iris.Constraint(name="Gridbox mean burnt area fraction")
ba_cubes = {
    exp_key: exp_cube_list.extract_strict(ba_constraint)
    for exp_key, exp_cube_list in exp_cubes.items()
}

In [None]:
# Map one slice (`[10][0]`) of the burnt area cube for each experiment.
for name, ba_cube in zip(names.values(), ba_cubes.values()):
    ba_2d = cube_1d_to_2d(ba_cube[10][0])
    # Additionally mask NaN entries.
    ba_2d.data.mask |= np.isnan(ba_2d.data)
    title = f"Mean BA {name}"
    ba_fig = plt.figure(figsize=(6, 3), dpi=200)
    fig = cube_plotting(
        ba_2d,
        fig=ba_fig,
        title=title,
        cmap="inferno",
        boundaries=[0, 1e-11, 1e-10, 1e-9, 4e-9, 2e-8],
        extend="max",
    )

### Processing of cubes

#### Choose cubes to process

In [None]:
# Names of the cubes that are extracted from each experiment for processing.
proc_cube_names = [
    "C in decomposable plant material, gridbox total",
    "C in resistant plant material, gridbox total",
    "Fractional cover of each surface type",
    "Gridbox effective radiative temperature (assuming emissivity=1)",
    "Gridbox gross primary productivity",
    "Gridbox mean burnt area fraction",
    "Gridbox precipitation rate",
    "Gridbox unfrozen soil moisture as fraction of saturation",
    "PFT Fraction of Absorbed Photosynthetically Active Radiation",
    "PFT fuel build up",
]

In [None]:
# Number of trailing time steps kept from each processed cube.
proc_months = 150
exp_processed_cubes = {}

for key, cubes in exp_cubes.items():
    processed_cubes = iris.cube.CubeList([])
    for var in tqdm(proc_cube_names, desc="Processing cubes"):
        try:
            ext_cube = cubes.extract_strict(iris.Constraint(name=var))
        except iris.exceptions.ConstraintMismatchError:
            # This experiment has no (unique) cube of that name - skip it.
            continue

        # Collapse dimension 1 when there are more than 3 dimensions ...
        if len(ext_cube.shape) > 3:
            ext_cube = collapse_cube_dim(ext_cube, 1)
        # ... and once more if exactly 4 dimensions remain.
        if len(ext_cube.shape) == 4:
            ext_cube = collapse_cube_dim(ext_cube, 1)

        assert len(ext_cube.shape) == 3

        # Convert to a 2D grid and keep only the last `proc_months` entries.
        processed_cubes.append(cube_1d_to_2d(ext_cube)[-proc_months:])

    exp_processed_cubes[key] = processed_cubes

In [None]:
# Wrap every processed cube in its own MonthlyDataset subclass and store the
# climatology dataset derived from it.
exp_proc_insts = defaultdict(list)

for key, processed_cubes in exp_processed_cubes.items():
    for proc_cube in tqdm(processed_cubes):
        # A throw-away Dataset class named after the cube, with a no-op
        # constructor.
        dataset_cls = type(
            proc_cube.long_name.replace(" ", ""),
            (MonthlyDataset,),
            {"__init__": lambda self: None, "frequency": "monthly"},
        )
        proc_inst = dataset_cls()
        proc_inst.cubes = iris.cube.CubeList([proc_cube])

        # Circumvent caching by including a random attribute.
        rand_tag = "".join(random.choices(string.ascii_lowercase, k=100))
        proc_inst.cube.attributes["rand"] = rand_tag

        clim_inst = proc_inst.get_climatology_dataset(
            proc_inst.min_time, proc_inst.max_time
        )
        exp_proc_insts[key].append(clim_inst)

In [None]:
# Split each experiment's climatology cubes into the burnt area target
# (`exp_ba_cubes`) and the predictor cubes (`exp_shifted_proc_cubes`). Every
# predictor is stored as-is and additionally in copies rolled along axis 0.
exp_shifted_proc_cubes = defaultdict(list)
exp_ba_cubes = {}

for key, proc_insts in exp_proc_insts.items():
    for proc_inst in proc_insts:
        if "burnt area" in proc_inst.cube.long_name:
            # The target variable is stored separately and never shifted.
            exp_ba_cubes[key] = proc_inst.cube
            continue
        exp_shifted_proc_cubes[key].append(proc_inst.cube)
        for shift in (1, 3, 6, 9):
            # XXX: Overly simplistic np.roll() implementation!!
            c2 = proc_inst.cube.copy()
            c2.data = np.roll(proc_inst.cube.data, shift, axis=0)
            # Tag the name with the shift so that every cube stays unique.
            c2.long_name = c2.long_name + f"({shift})"
            c2.var_name = None
            c2.short_name = None

            exp_shifted_proc_cubes[key].append(c2)

    # A missing burnt area cube used to surface as a bare KeyError from the
    # dictionary lookup inside the assertion; check for membership instead.
    assert key in exp_ba_cubes, f"No burnt area cube found for experiment '{key}'"

In [None]:
# For each experiment, plot the cubes at list positions 35-37 at grid cell
# [72, 110] against month number.
for key, shifted_proc_cubes in exp_shifted_proc_cubes.items():
    plt.figure(figsize=(5, 3), dpi=150)
    for cube in (shifted_proc_cubes[i] for i in range(35, 38)):
        month_numbers = cube.coord("month_number").points
        print(month_numbers)
        plt.plot(
            month_numbers,
            cube.data[:, 72, 110],
            marker="x",
            alpha=0.9,
            label=cube.name(),
        )
    plt.legend()

In [None]:
# Per-experiment mask: already masked, or burnt area above 1e-5.
master_masks = []
for ba_cube in exp_ba_cubes.values():
    ba_data = ba_cube.data
    master_masks.append(ba_data.mask | (ba_data.data > 1e-5))

# A cell is masked if it is masked in any experiment.
master_mask = reduce(np.logical_or, master_masks)
_ = cube_plotting(master_mask, title="Mask", colorbar_kwargs={"label": ""})

In [None]:
# Flatten the unmasked cells of every cube into pandas structures: one Series
# of burnt area (endog) and one DataFrame of predictors (exog) per experiment.
exp_endog_data = {}
exp_exog_data = {}

for key, ba_cube in exp_ba_cubes.items():
    exp_endog_data[key] = pd.Series(
        ba_cube.data.data[~master_mask], name="burnt area"
    )
    exp_exog_data[key] = pd.DataFrame(
        {
            cube.long_name: cube.data.data[~master_mask]
            for cube in exp_shifted_proc_cubes[key]
        }
    )

In [None]:
# Substring replacements (long form -> short form) that `shorten_columns`
# applies, in this order, to every column name.
shorten_mapping = {
    "Gridbox": "",
    "precipitation": "precip",
    "soil carbon in each pool": "soil pool carbon",
    "surface": "surf",
    "evapotranspiration": "evapot",
    "from soil moisture store": "soil moist",
    "temperature": "temp",
    "unfrozen moisture content of each soil layer as a fraction of saturation": "unfrozen moist soil layer / sat",
    "gross primary productivity": "gpp",
    "net primary productivity": "npp",
    "soil moisture availability factor (beta)": "soil moist avail fact",
    "total carbon content of the vegetation at the end of model timestep": "total veg C end timestep",
}

In [None]:
def shorten_columns(df, mapping=None):
    """Shorten the column names of `df` in place.

    Every `old -> new` substring replacement in `mapping` is applied to each
    column name in the mapping's iteration order, after which surrounding
    whitespace is stripped.

    Args:
        df: DataFrame with string column names. Its columns are modified in
            place.
        mapping: Dictionary of substring replacements. Defaults to the
            notebook-level `shorten_mapping`.

    Returns:
        The same `df` object, with renamed columns.

    """
    if mapping is None:
        mapping = shorten_mapping

    new_cols = []
    for col in df.columns:
        for old, new in mapping.items():
            col = col.replace(old, new)
        # Removing a leading word (e.g. "Gridbox") leaves stray whitespace.
        new_cols.append(col.strip())
    df.columns = new_cols
    return df

In [None]:
# Shorten the predictor column names of every experiment (in place).
for exog_frame in exp_exog_data.values():
    shorten_columns(exog_frame)

### Rescale the BA data so it has higher magnitudes

This seems to be required by the RF algorithm in order to make any predictions at all.

In [None]:
# Divide each experiment's burnt area by its own maximum.
rescaled_endog = {}
for key, ba_series in exp_endog_data.items():
    rescaled_endog[key] = ba_series / np.max(ba_series)
exp_endog_data = rescaled_endog

In [None]:
# Train / test split of the predictors and the target, per experiment.
exp_model_data = {}
for key in runs:
    X_train, X_test, y_train, y_test = train_test_split(
        exp_exog_data[key],
        exp_endog_data[key],
        **train_test_split_kwargs,
    )
    exp_model_data[key] = {
        "X_train": X_train,
        "X_test": X_test,
        "y_train": y_train,
        "y_test": y_test,
    }

In [None]:
# Create and fit one random forest per experiment.
exp_models = {}
for key in runs:
    model = RandomForestRegressor(**param_dict)
    model.n_jobs = 3
    train_data = exp_model_data[key]
    model.fit(train_data["X_train"], train_data["y_train"])
    exp_models[key] = model

In [None]:
# Store each model's predictions on its training and test predictors.
for key, model in exp_models.items():
    model_data = exp_model_data[key]
    for split in ("train", "test"):
        model_data[f"y_{split}_pred"] = model.predict(model_data[f"X_{split}"])

In [None]:
# Print the R2 score on the training and the test data of each experiment.
for key in runs:
    print(key)
    model_data = exp_model_data[key]
    for label, split in (("train:", "train"), ("val:", "test")):
        print(
            label,
            r2_score(
                y_true=model_data[f"y_{split}"],
                y_pred=model_data[f"y_{split}_pred"],
            ),
        )

In [None]:
# Map the validation observations of every experiment.
# This used to read `y_test`, i.e. whatever the last iteration of the
# train/test split loop left behind, so only one experiment was shown and the
# plot did not say which. The data now comes from `exp_model_data`.
for key, model_data in exp_model_data.items():
    fig = cube_plotting(
        get_mm_data(model_data["y_test"].values, master_mask, "val"),
        boundaries=[0, 1e-3, 1e-2, 0.05, 0.2, 0.5],
        cmap="inferno",
        fig=plt.figure(figsize=(7, 3.2), dpi=150),
        colorbar_kwargs=dict(label="burnt area (scaled)"),
        title=f"Validation Observations ({names[key]})",
        extend="neither",
    )

In [None]:
# Map the validation predictions of every experiment.
# `y_test_pred` is never defined at notebook level (predictions are only stored
# in `exp_model_data`), so this cell raised a NameError on a fresh kernel.
for key, model_data in exp_model_data.items():
    fig = cube_plotting(
        get_mm_data(model_data["y_test_pred"], master_mask, "val"),
        boundaries=[0, 1e-3, 1e-2, 0.05, 0.2, 0.5],
        cmap="inferno",
        fig=plt.figure(figsize=(7, 3.2), dpi=150),
        colorbar_kwargs=dict(label="burnt area (scaled)"),
        title=f"Validation Predictions ({names[key]})",
        extend="neither",
    )

In [None]:
# Predicted vs. observed burnt area on the training data, per experiment.
# `y_train_pred` is never defined at notebook level (NameError on a fresh
# kernel); the data is read from `exp_model_data` instead.
for key, model_data in exp_model_data.items():
    plt.figure(dpi=250)
    plt.hexbin(model_data["y_train"], model_data["y_train_pred"], bins="log")
    plt.xlabel("observed BA (train)")
    plt.ylabel("predicted BA (train)")
    _ = plt.title(names[key])

In [None]:
# Predicted vs. observed burnt area on the test data, per experiment.
# `y_test_pred` is never defined at notebook level (NameError on a fresh
# kernel); the data is read from `exp_model_data` instead.
for key, model_data in exp_model_data.items():
    plt.figure(dpi=250)
    plt.hexbin(model_data["y_test"], model_data["y_test_pred"], bins="log")
    plt.xlabel("observed BA (test)")
    plt.ylabel("predicted BA (test)")
    _ = plt.title(names[key])

### Gini importances

In [None]:
# Box plot of the per-tree Gini importances, one figure per experiment.
for key, model in exp_models.items():
    feature_names = exp_model_data[key]["X_train"].columns
    # One row per tree, one column per feature.
    ind_trees_gini = pd.DataFrame(
        [tree.feature_importances_ for tree in model], columns=feature_names
    )
    # Order the columns by decreasing mean importance.
    mean_importances = ind_trees_gini.mean().sort_values(ascending=False)
    ind_trees_gini = ind_trees_gini.reindex(mean_importances.index, axis=1)

    # Number of (most important) features shown.
    N_col = 18

    fig, ax = plt.subplots(1, 1, figsize=(10, 4), dpi=170)
    sns.boxplot(data=ind_trees_gini.iloc[:, :N_col], ax=ax)
    ax.set(ylabel="Gini Importance\n")
    ax.set_title(names[key])
    _ = plt.setp(ax.xaxis.get_majorticklabels(), rotation=60, ha="right")

In [None]:
# First-order ALE plot of every feature, for every experiment's model.
for key, model in exp_models.items():
    train_features = exp_model_data[key]["X_train"]
    for feature in tqdm(train_features.columns, desc="1D ALE plots"):
        ale_fig = plt.figure(figsize=(5.5, 2), dpi=150)
        fig, axes = ale_plot(
            model,
            train_features,
            feature,
            bins=20,
            fig=ale_fig,
            quantile_axis=True,
            monte_carlo=True,
            monte_carlo_rep=10,
            monte_carlo_ratio=0.2,
        )
        ale_ax = axes["ale"]
        ale_ax.set_title(f"{feature}\n({names[key]})")
        plt.setp(ale_ax.xaxis.get_majorticklabels(), rotation=60, ha="right")

## GFED4 BA

In [None]:
# Instantiate the GFEDv4 dataset and display it.
gfed_ba = GFEDv4()

gfed_ba

In [None]:
# First and last time point of the first processed cube. `processed_cubes` is
# the variable left bound by the per-experiment processing loops above.
jules_time_coord = processed_cubes[0].coord("time")
jules_time_coord.cell(0).point, jules_time_coord.cell(-1).point

In [None]:
# Limit GFED4 to the period between the first and last time point of the first
# processed cube, then display the dataset.
jules_times = processed_cubes[0].coord("time")
gfed_ba.limit_months(jules_times.cell(0).point, jules_times.cell(-1).point)

gfed_ba

In [None]:
# Request a climatology dataset over the dataset's own [min_time, max_time]
# range and display it.
clim_gfed_ba = gfed_ba.get_climatology_dataset(gfed_ba.min_time, gfed_ba.max_time)
clim_gfed_ba

In [None]:
# Inspect the "month_number" coordinate of the climatology cube.
clim_gfed_ba.cube.coord("month_number")

#### Regrid GFED4 to N96e grid

In [None]:
# Shape of the climatology cube before regridding.
clim_gfed_ba.cube.shape

In [None]:
# Discard the existing latitude bounds and let iris guess them afresh.
gfed_lat_coord = clim_gfed_ba.cube.coord("latitude")
gfed_lat_coord.bounds = None
gfed_lat_coord.guess_bounds()

In [None]:
# Regrid GFED4 onto the grid of the JULES cubes. `shifted_proc_cubes` is the
# loop variable left bound by the shifted-cube plotting loop above.
target_grid_cube = shifted_proc_cubes[0]
reg_gfed_ba_cube = regrid(
    clim_gfed_ba.cube,
    new_latitudes=target_grid_cube.coord("latitude"),
    new_longitudes=target_grid_cube.coord("longitude"),
    area_weighted=True,
    verbose=True,
)
reg_gfed_ba_cube.shape

In [None]:
# Plot the regridded GFED4 burnt area under the shared master mask, divided by
# the number of seconds in 30 days.
seconds_per_30_days = 30 * 24 * 60 * 60

reg_plot_gfed_cube = reg_gfed_ba_cube.copy()
reg_plot_gfed_cube.data.mask = master_mask

gfed_fig = plt.figure(figsize=(6, 3), dpi=200)
fig = cube_plotting(
    reg_plot_gfed_cube / seconds_per_30_days,
    fig=gfed_fig,
    title="Mean GFED4 BA",
    cmap="inferno",
    boundaries=[0, 1e-11, 1e-10, 1e-9, 4e-9, 2e-8],
    extend="max",
    colorbar_kwargs=dict(label=r"$\mathrm{s}^{-1}$"),
)

In [None]:
# Unmasked GFED4 burnt area values as a named Series.
gfed_endog_data = pd.Series(
    reg_gfed_ba_cube.data.data[~master_mask], name="GFED4 burnt area"
)
gfed_endog_data

### Comparing GFED4 and JULES BA

In [None]:
# Pearson correlation between GFED4 and the JULES burnt area of every
# experiment. The original cell referenced `endog_data`, which is never defined
# at notebook level (NameError on a fresh kernel).
{
    names[key]: scipy.stats.pearsonr(gfed_endog_data.values, jules_endog.values)
    for key, jules_endog in exp_endog_data.items()
}

In [None]:
# GFED4 vs. JULES burnt area, one figure per experiment. The original cell
# referenced `endog_data`, which is never defined at notebook level (NameError
# on a fresh kernel).
for key, jules_endog in exp_endog_data.items():
    plt.figure()
    plt.hexbin(gfed_endog_data, jules_endog, bins="log")
    plt.xlabel("GFED4")
    plt.ylabel("JULES BA")
    _ = plt.title(names[key])

In [None]:
# Split the JULES predictors and the GFED4 target into train / test sets.
# The original cell referenced `exog_data`, which is never defined at notebook
# level (NameError on a fresh kernel); predictors are stored per experiment in
# `exp_exog_data`.
# NOTE(review): the "new" experiment's predictors are used here - confirm that
# this is the intended set of drivers for the GFED4 model.
gfed_X_train, gfed_X_test, gfed_y_train, gfed_y_test = train_test_split(
    exp_exog_data["new"], gfed_endog_data, **train_test_split_kwargs
)

In [None]:
# Random forest for the GFED4 target; `n_jobs` is forced to 3 regardless of
# what `param_dict` contains.
gfed_model = RandomForestRegressor(**{**param_dict, "n_jobs": 3})
gfed_model.fit(gfed_X_train, gfed_y_train)

In [None]:
# Predictions of the GFED4 model on its training and test predictors.
gfed_y_train_pred, gfed_y_test_pred = (
    gfed_model.predict(gfed_features)
    for gfed_features in (gfed_X_train, gfed_X_test)
)

In [None]:
# R2 is not symmetric in its arguments: the observations must be passed as
# `y_true` and the predictions as `y_pred`. They were previously swapped.
print("train:", r2_score(y_true=gfed_y_train, y_pred=gfed_y_train_pred))
print("val:", r2_score(y_true=gfed_y_test, y_pred=gfed_y_test_pred))

In [None]:
# Names of the JULES predictors. The original cell referenced `exog_data`,
# which is never defined at notebook level (NameError on a fresh kernel).
# NOTE(review): shows the "new" experiment's predictors - confirm the choice.
exp_exog_data["new"].columns

In [None]:
# Histogram of the GFED4 test targets on log-log axes.
plt.hist(gfed_y_test.values)
for set_axis_scale in (plt.xscale, plt.yscale):
    set_axis_scale("log")

In [None]:
# Map the GFED4 validation observations.
gfed_val_obs = get_mm_data(gfed_y_test.values, master_mask, "val")
gfed_val_obs_fig = plt.figure(figsize=(7, 3.2), dpi=150)
fig = cube_plotting(
    gfed_val_obs,
    fig=gfed_val_obs_fig,
    title="GFED Validation Observations",
    cmap="inferno",
    boundaries=[0, 1e-3, 1e-2, 0.05, 0.2, 0.5],
    extend="neither",
    colorbar_kwargs={"label": "burnt area (scaled)"},
)

In [None]:
# Map the GFED4 model's validation predictions.
gfed_val_pred = get_mm_data(gfed_y_test_pred, master_mask, "val")
gfed_val_pred_fig = plt.figure(figsize=(7, 3.2), dpi=150)
fig = cube_plotting(
    gfed_val_pred,
    fig=gfed_val_pred_fig,
    title="GFED Validation Predictions",
    cmap="inferno",
    boundaries=[0, 1e-3, 1e-2, 0.05, 0.2, 0.5],
    extend="neither",
    colorbar_kwargs={"label": "burnt area (scaled)"},
)

In [None]:
# Predicted vs. observed GFED4 burnt area on the training data.
plt.figure(dpi=250)
plt.hexbin(x=gfed_y_train, y=gfed_y_train_pred, bins="log")
_ = (plt.xlabel("observed BA (train)"), plt.ylabel("predicted BA (train)"))

In [None]:
# Predicted vs. observed GFED4 burnt area on the test data.
plt.figure(dpi=250)
plt.hexbin(x=gfed_y_test, y=gfed_y_test_pred, bins="log")
_ = (plt.xlabel("observed BA (test)"), plt.ylabel("predicted BA (test)"))

### Gini importances

In [None]:
# Box plot of the per-tree Gini importances of the GFED4 model (top 18).
N_col = 18

# One row per tree, one column per feature.
gfed_ind_trees_gini = pd.DataFrame(
    [tree.feature_importances_ for tree in gfed_model], columns=gfed_X_train.columns
)
# Order the columns by decreasing mean importance.
gfed_mean_importances = gfed_ind_trees_gini.mean().sort_values(ascending=False)
gfed_ind_trees_gini = gfed_ind_trees_gini.reindex(gfed_mean_importances.index, axis=1)

fig, ax = plt.subplots(1, 1, figsize=(10, 4), dpi=170)
sns.boxplot(data=gfed_ind_trees_gini.iloc[:, :N_col], ax=ax)
ax.set(ylabel="Gini Importance\n")
_ = plt.setp(ax.xaxis.get_majorticklabels(), rotation=60, ha="right")

In [None]:
# Box plot of the per-tree Gini importances of the GFED4 model (top 30).
N_col = 30

# One row per tree, one column per feature.
gfed_ind_trees_gini = pd.DataFrame(
    [tree.feature_importances_ for tree in gfed_model], columns=gfed_X_train.columns
)
# Order the columns by decreasing mean importance.
gfed_mean_importances = gfed_ind_trees_gini.mean().sort_values(ascending=False)
gfed_ind_trees_gini = gfed_ind_trees_gini.reindex(gfed_mean_importances.index, axis=1)

fig, ax = plt.subplots(1, 1, figsize=(10, 4), dpi=170)
sns.boxplot(data=gfed_ind_trees_gini.iloc[:, :N_col], ax=ax)
ax.set(ylabel="Gini Importance\n")
_ = plt.setp(ax.xaxis.get_majorticklabels(), rotation=60, ha="right")

In [None]:
# First-order ALE plot of every feature of the GFED4 model.
for feature in tqdm(gfed_X_train.columns, desc="1D ALE plots"):
    ale_fig = plt.figure(figsize=(5.5, 2), dpi=150)
    fig, axes = ale_plot(
        gfed_model,
        gfed_X_train,
        feature,
        fig=ale_fig,
        bins=20,
        quantile_axis=True,
        monte_carlo=True,
        monte_carlo_rep=10,
        monte_carlo_ratio=0.2,
    )
    ale_tick_labels = axes["ale"].xaxis.get_majorticklabels()
    plt.setp(ale_tick_labels, rotation=60, ha="right")