In [None]:
from datetime import datetime
from itertools import product
from pathlib import Path
from warnings import filterwarnings

import cf_units
import iris
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from alepython import ale_plot
from joblib import Memory
from numba import jit, njit
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from wildfires.analysis import corr_plot, cube_plotting
from wildfires.data import MonthlyDataset, homogenise_cube_attributes
from wildfires.utils import ensure_datetime

# Silence warnings whose message matches "divide by zero".
filterwarnings("ignore", ".*divide by zero.*")
# On-disk joblib cache stored in ./.cache; used below through `@memory.cache`.
memory = Memory(".cache", verbose=0)

In [None]:
# Training and validation test splitting.
# The same keyword arguments are reused wherever data or indices are split, so
# that every split uses the same seed and test fraction.
train_test_split_kwargs = dict(random_state=1, shuffle=True, test_size=0.3)

# Specify common RF (training) params.
# NOTE(review): `n_splits` is not used in the visible cells - presumably meant for
# cross-validation; confirm before removing.
n_splits = 5

default_param_dict = {"random_state": 1, "bootstrap": True}

# Random forest hyperparameters used for the model fitted below.
param_dict = {
    **default_param_dict,
    "ccp_alpha": 2e-9,
    "max_depth": 15,
    # Consider all features at every split. This is what "auto" meant for
    # regressors, but "auto" is deprecated and rejected by newer scikit-learn
    # releases, whereas the float 1.0 is accepted by old and new versions alike.
    "max_features": 1.0,
    "min_samples_leaf": 4,
    "min_samples_split": 2,
    "n_estimators": 100,
}

In [None]:
def get_mm_indices(master_mask):
    """Split the flat indices of the unmasked `master_mask` elements.

    Args:
        master_mask (array): Boolean mask, True where elements are masked.

    Returns:
        tuple: `(valid, train, val)` - the flat indices of all unmasked elements,
            followed by their training and validation subsets as produced by
            `train_test_split()` with `train_test_split_kwargs`.

    """
    valid = np.flatnonzero(~master_mask.ravel())
    train, val = train_test_split(valid, **train_test_split_kwargs)
    return valid, train, val


def get_mm_data(x, master_mask, kind):
    """Place training or validation samples back onto the `master_mask` grid.

    Args:
        x (array-like): Values of the selected samples, in the same order as the
            matching indices returned by `get_mm_indices()`.
        master_mask (array): Boolean mask, True where elements are masked.
        kind ({'train', 'val'}): Which subset `x` belongs to.

    Returns:
        numpy.ma.MaskedArray: float64 array with the shape of `master_mask`. It is
            masked everywhere except at the locations of the chosen subset, which
            contain `x`.

    Raises:
        ValueError: If `kind` is neither 'train' nor 'val'.

    """
    # Validate first so that an invalid `kind` fails before any work is done.
    if kind not in ("train", "val"):
        raise ValueError(f"Unknown kind: {kind}")

    _, train_indices, val_indices = get_mm_indices(master_mask)
    indices = train_indices if kind == "train" else val_indices

    # Fill a flat, fully masked array that owns its data and mask, and reshape
    # afterwards. Assigning through `masked.ravel()[...]` instead only works if
    # `ravel()` returns a view sharing both data and mask with the original.
    flat = np.ma.MaskedArray(
        np.zeros(master_mask.size, dtype=np.float64),
        mask=np.ones(master_mask.size, dtype=np.bool_),
    )
    # Assignment unmasks the selected elements (the mask is not hard).
    flat[indices] = x
    return flat.reshape(master_mask.shape)

In [None]:
@njit
def isclose(a, b, atol=1e-4):
    """Return whether `a` and `b` differ by less than `atol` (absolute tolerance)."""
    difference = np.abs(a - b)
    return difference < atol


# Sanity checks (these also trigger compilation of the function).
assert isclose(1, 1)
assert isclose(1, 1 + 1e-5)
assert not isclose(1, 1 + 1e-5, atol=1e-6)

In [None]:
@njit
def find_gridpoint(land_lat, land_lon, grid_lats, grid_lons):
    """Return the `(lat_i, lon_i)` grid indices matching one land coordinate.

    The first grid latitude and longitude lying within the `isclose()` tolerance
    of the given land coordinate are selected.

    Raises:
        RuntimeError: If no grid point matches both coordinates.

    """
    for lat_i in range(len(grid_lats)):
        if not isclose(land_lat, grid_lats[lat_i]):
            # Latitude does not match, so no longitude in this row can match.
            continue
        for lon_i in range(len(grid_lons)):
            if isclose(land_lon, grid_lons[lon_i]):
                return lat_i, lon_i
    raise RuntimeError("Matching gridpoint not found.")

In [None]:
%%time

# Smoke tests for `find_gridpoint()`; the timing includes compilation on first call.
# Integer coordinates on a 3x3 grid; (0, 0) sits at index (1, 1).
print(find_gridpoint(0, 0, np.array([-1, 0, 1]), np.array([-1, 0, 1])))
# Float coordinates with differently sized latitude and longitude axes.
print(
    find_gridpoint(
        0.5, 0.5, np.array([-1, 0, 0.5, 1]), np.array([-1, 0, 0.25, 0.5, 0.75, 1])
    )
)
# Larger regular grid generated with `np.linspace`.
print(
    find_gridpoint(
        0,
        0,
        np.linspace(-90, 90, 100, endpoint=False),
        np.linspace(-90, 90, 100, endpoint=False),
    ),
)

In [None]:
@memory.cache
def get_grid_mask(mask, orig_lats, orig_lons, grid_lats, grid_lons):
    """Mark the grid cells that correspond to the given land points.

    Args:
        mask (array): Boolean array whose last two dimensions are (lat, lon). It
            is modified in place and also returned.
        orig_lats, orig_lons (array): Latitude and longitude of each land point.
        grid_lats, grid_lons (array): Coordinates of the target grid.

    Returns:
        array: `mask`, set to True at every grid cell matched by a land point.

    Note:
        This probably relies on the contiguity structure of the arrays.

    """
    # XXX: Would it not be simpler and almost equally robust to infer the index
    # from the coordinate spacing instead, i.e.
    # np.rint((lats - lats[0]) / (lats[1] - lats[0])), as in jules.py?
    land_points = zip(tqdm(orig_lats, desc="Land gridpoint"), orig_lons)
    for land_lat, land_lon in land_points:
        row, col = find_gridpoint(land_lat, land_lon, grid_lats, grid_lons)
        mask[..., row, col] = True
    return mask

In [None]:
def cube_1d_to_2d(cube):
    """Convert JULES output on 1D grid to 2D grid.

    The last dimension of `cube` is treated as a list of land points, each with a
    latitude and a longitude. These points are placed onto a regular global grid
    (latitudes starting at -90, longitudes starting at 0) whose spacing is
    inferred from the coordinate values. Grid cells without data stay masked.

    Args:
        cube: Cube (presumably an `iris.cube.Cube`) with "latitude" and
            "longitude" coordinates.

    Returns:
        iris.cube.Cube: Cube holding the squeezed, gridded data with latitude and
            longitude dimension coordinates on its last two dimensions and the
            metadata of `cube`. No other coordinates are attached.

    Raises:
        AssertionError: If the inferred spacing does not evenly divide the globe.
        ValueError: If the squeezed cube data has no dimensions.

    """
    land_grid_coord = -1  # The last axis is associated with the spatial domain.

    # The remaining code hard-codes the last axis (e.g. `cube.shape[:-1]` below).
    assert land_grid_coord == -1

    lat_coord = cube.coord("latitude")
    lon_coord = cube.coord("longitude")

    orig_lats = lat_coord.points.data.ravel()
    orig_lons = lon_coord.points.data.ravel()

    # `np.unique` returns sorted values, so index 1 is the second-smallest distinct
    # difference between sorted coordinates. Presumably the smallest is 0 (several
    # land points share a coordinate value), making this the grid spacing.
    # NOTE(review): this picks the wrong value if no coordinate value repeats.
    lat_step = np.unique(np.diff(np.sort(orig_lats)))[1]
    lon_step = np.unique(np.diff(np.sort(orig_lons)))[1]

    # Use the latitude and longitude steps from above to determine the number of
    # latitude and longitude steps.
    n_lat = round(180 / lat_step)
    n_lon = round(360 / lon_step)

    # Ensure that these represent a regular grid.
    assert np.isclose(n_lat * lat_step, 180)
    assert np.isclose(n_lon * lon_step, 360)

    # Create a grid of ..., lat, lon points to match the shape of the given cube.
    new_shape = tuple(list(cube.shape[:land_grid_coord]) + [n_lat, n_lon])

    # Now convert the 1D data to the 2D array created above, using a mask.
    mask = np.zeros((n_lat, n_lon), dtype=np.bool_)

    grid_lats = np.linspace(-90, 90, n_lat, endpoint=False)
    grid_lons = np.linspace(0, 360, n_lon, endpoint=False)

    # True where a land point exists (result is cached on disk by joblib).
    mask = get_grid_mask(mask, orig_lats, orig_lons, grid_lats, grid_lons)

    # Boolean-mask assignment fills the True cells in row-major order, so the code
    # below assumes the land points are stored in that same order - TODO confirm.
    if len(np.squeeze(cube.data).shape) == 1:
        # Simply assign based on the mask.
        new_data = np.ma.MaskedArray(np.zeros_like(mask, dtype=np.float64), mask=True)
        new_data[mask] = np.squeeze(cube.data)
    elif len(np.squeeze(cube.data).shape) > 1:
        # Iterate over earlier dimensions.
        new_data = np.ma.MaskedArray(np.zeros(new_shape, dtype=np.float64), mask=True)
        for indices in product(*(range(l) for l in cube.shape[:-1])):
            sel = (*indices, slice(None))
            # `new_data[sel]` is a (lat, lon) slice; this relies on it being a view
            # so that the assignment reaches `new_data` (data and mask).
            new_data[sel][mask] = cube.data[sel]
    else:
        raise ValueError(f"Invalid cube shape {cube.shape}")

    # XXX: Assumes more than 1 lat, lon coord, and destroys other dimensions.
    new_data = np.squeeze(new_data)

    # After squeezing, latitude and longitude are the last two dimensions.
    n_dim = len(new_data.shape)
    lat_dim = n_dim - 2
    lon_dim = n_dim - 1

    new_cube = iris.cube.Cube(
        new_data,
        dim_coords_and_dims=[
            (
                iris.coords.DimCoord(
                    grid_lats, standard_name="latitude", units="degrees"
                ),
                lat_dim,
            ),
            (
                iris.coords.DimCoord(
                    grid_lons, standard_name="longitude", units="degrees"
                ),
                lon_dim,
            ),
        ],
    )

    # Carry over names, units and attributes from the original cube.
    new_cube.metadata = cube.metadata
    return new_cube

In [None]:
# NOTE(review): identical to the `memory` object created in the import cell at the
# top of the notebook; this cell looks redundant.
memory = Memory(".cache", verbose=0)

In [None]:
# Absolute, machine-specific location of the JULES output files.
jules_gws_dir = Path("/gws/nopw/j04/jules")
source_dir = jules_gws_dir / "stephanemangeon/FireMIP_fixed_clim"
# Fail early if the data directory is not available on this machine.
assert source_dir.is_dir()

In [None]:
# File name relative to `source_dir`.
data_file = "FireMIP.inferno.fixed_clim.Monthly.2013.nc"  # Contains aggregated data.

In [None]:
# Load all cubes contained in the data file.
cubes = iris.load(str(source_dir / data_file))

In [None]:
# Display a summary of the loaded cubes.
cubes

In [None]:
# Plot every cube on the 2D grid, using index 0 of its first dimension only.
for cube in tqdm(cubes, desc="Plotting cubes (single timeslice)"):
    cube_2d = cube_1d_to_2d(cube[0])

    assert len(cube_2d.shape) >= 2

    # Plot one map per combination of leading (non-spatial) indices. Plotting
    # must happen inside this loop - otherwise only the last combination would be
    # drawn. For a purely 2D cube, `product()` yields a single empty tuple, so the
    # whole cube is selected.
    for indices in product(*(range(l) for l in cube_2d.shape[:-2])):
        sel = (*indices, slice(None), slice(None))
        try:
            fig = cube_plotting(cube_2d[sel])
        except Exception as e:
            # Best effort: report the failure and continue with the other plots.
            print("cube:", str(cube))
            print("Error:", e)

In [None]:
# Select the burnt area cube by name.
ba_cube = cubes.extract_strict(iris.Constraint(name="PFT burnt area fraction"))

In [None]:
# Grid a single slice of the burnt area cube: index 10 along the first dimension,
# then index 0 along the next.
# NOTE(review): presumably (time, PFT, land point) - confirm the dimension order.
ba_2d = cube_1d_to_2d(ba_cube[10][0])
# Additionally mask NaN values.
ba_2d.data.mask |= np.isnan(ba_2d.data)
fig = cube_plotting(ba_2d)

In [None]:
# Unique long names of the loaded cubes, in sorted order.
sorted({cube.long_name for cube in cubes})

In [None]:
# Names of the cubes that are extracted, plotted and processed in the cells below.
cube_names = [
    "Gridbox precipitation rate",
    "Gridbox soil carbon (total)",
    "Gridbox soil carbon in each pool (DPM,RPM,bio,hum)",
    "Gridbox surface evapotranspiration from soil moisture store",
    "Gridbox surface temperature",
    "Gridbox unfrozen moisture content of each soil layer as a fraction of saturation",
    "PFT burnt area fraction",
    "PFT gross primary productivity",
    "PFT leaf area index",
    "PFT net primary productivity",
    "PFT soil moisture availability factor (beta)",
    "PFT total carbon content of the vegetation at the end of model timestep.",
]

In [None]:
# Locations for the time series plots as (latitude, longitude) pairs.
# Longitudes are given on a 0-360 scale, like the grid built in `cube_1d_to_2d()`,
# hence `360 - 119.67` for a western longitude.
lat_lon_dict = {
    "UK": (51.5, 0),
    "Uganda": (2.36, 32.51),
    "Durban": (-29.13, 31),
    "Cape Town": (-33.57, 19.28),
    "Yosemite": (37.7, 360 - 119.67),
}

In [None]:
# Number of final time steps shown in the time series plots below.
n_months = 25

In [None]:
# Legend labels for the individual soil carbon pools.
soil_carbon_pools = ("DPM", "RPM", "bio", "hum")
# NOTE(review): `pfts` is not used in the visible cells.
pfts = list(range(9))

### Plotting with individual PFTs

In [None]:
# Convert every variable to the 2D grid once up front: the conversion does not
# depend on the location, so repeating it for each figure only wastes time.
var_cubes = []
for var in cube_names:
    ext_cube = cubes.extract_strict(iris.Constraint(name=var))
    var_cube = cube_1d_to_2d(ext_cube)[-n_months:]
    # The gridded cube carries no time coordinate, so attach the matching one.
    var_cube.add_dim_coord(ext_cube.coord("time")[-n_months:], 0)

    print(var_cube.long_name, var_cube.shape)
    var_cubes.append(var_cube)

# One figure per location, with one panel per variable.
for location, (lat, lon) in lat_lon_dict.items():
    fig, axes = plt.subplots(
        len(cube_names),
        1,
        figsize=(10, 3.5 * len(cube_names)),
        sharex=True,
        constrained_layout=True,
        dpi=120,
    )
    fig.suptitle(location)

    for ax, var_cube in zip(axes, var_cubes):
        ax.set_title(f"{var_cube.long_name} ({var_cube.units})")

        # Indices of the grid cell closest to the requested location.
        lat_i = np.argmin(np.abs(var_cube.coord("latitude").points - lat))
        lon_i = np.argmin(np.abs(var_cube.coord("longitude").points - lon))

        if len(var_cube.shape) == 3:
            plot_data_list = [var_cube[:, lat_i, lon_i].data.copy()]
        elif len(var_cube.shape) == 4:
            # One time series per index of the extra dimension.
            plot_data_list = [
                var_cube[:, i, lat_i, lon_i].data.copy()
                for i in range(var_cube.shape[1])
            ]
        else:
            raise ValueError(
                f"Expected a 3D or 4D cube, got shape {var_cube.shape}."
            )

        # The time axis and label scheme are shared by all lines in this panel.
        times = [
            ensure_datetime(var_cube.coord("time").cell(time_i).point)
            for time_i in range(var_cube.shape[0])
        ]
        is_pool_var = (
            "soil carbon" in var_cube.long_name and "pool" in var_cube.long_name
        )

        for i, plot_data in enumerate(plot_data_list):
            # Plot anomalies relative to the mean of each series.
            plot_data -= np.mean(plot_data)
            label = soil_carbon_pools[i] if is_pool_var else str(i)
            ax.plot(times, plot_data, label=label, marker=".")
        ax.legend(loc="upper left", bbox_to_anchor=(1, 1))

### Plotting with mean over individual PFTs and pools

In [None]:
# One panel per variable, with one line per location.
fig, axes = plt.subplots(
    len(cube_names),
    1,
    figsize=(9, 3 * len(cube_names)),
    sharex=True,
    constrained_layout=True,
    dpi=120,
)

for ax, var in zip(axes, cube_names):
    ext_cube = cubes.extract_strict(iris.Constraint(name=var))
    if "generic" in [coord.name() for coord in ext_cube.coords()]:
        # Average over carbon pools or PFTs
        ext_cube = ext_cube.collapsed("generic", iris.analysis.MEAN)

    var_cube = cube_1d_to_2d(ext_cube)[-n_months:]
    # The gridded cube carries no time coordinate, so attach the matching one.
    var_cube.add_dim_coord(ext_cube.coord("time")[-n_months:], 0)

    print(var_cube.long_name, var_cube.shape)

    # After averaging, only (time, lat, lon) should remain.
    assert len(var_cube.shape) == 3

    ax.set_title(f"{var_cube.long_name} ({var_cube.units})")

    # These do not depend on the location, so compute them once per variable.
    times = [
        ensure_datetime(var_cube.coord("time").cell(time_i).point)
        for time_i in range(var_cube.shape[0])
    ]
    grid_lats = var_cube.coord("latitude").points
    grid_lons = var_cube.coord("longitude").points

    for location, (lat, lon) in lat_lon_dict.items():
        # Indices of the grid cell closest to the requested location.
        lat_i = np.argmin(np.abs(grid_lats - lat))
        lon_i = np.argmin(np.abs(grid_lons - lon))

        plot_data = var_cube[:, lat_i, lon_i].data.copy()
        # Plot anomalies relative to the mean of the series.
        plot_data -= np.mean(plot_data)

        ax.plot(times, plot_data, label=location, marker=".")
    ax.legend(loc="upper left", bbox_to_anchor=(1, 1))

In [None]:
proc_months = 49

ref_unit = cf_units.Unit("seconds since 1900-1-1", calendar="365_day")

# Build one gridded cube per variable, covering the time steps selected by
# `[-proc_months:-1]` (the final time step is excluded).
processed_cubes = iris.cube.CubeList([])
for var in tqdm(cube_names, desc="Processing cubes"):
    ext_cube = cubes.extract_strict(iris.Constraint(name=var))
    coord_names = [coord.name() for coord in ext_cube.coords()]
    if "generic" in coord_names:
        # Average over carbon pools or PFTs
        ext_cube = ext_cube.collapsed("generic", iris.analysis.MEAN)

    # Prepare the time coordinate that will be attached to the gridded cube.
    time_coord = ext_cube.coord("time")[-proc_months:-1]
    assert time_coord.units == ref_unit
    time_coord.bounds = None
    # Sometimes ~5 minutes may be missing to get to the next day, so shift the
    # points forward by 5 minutes (the unit is seconds).
    time_coord.points = time_coord.points + 5 * 60

    var_cube = cube_1d_to_2d(ext_cube)[-proc_months:-1]
    var_cube.add_dim_coord(time_coord, 0)

    processed_cubes.append(var_cube)

In [None]:
# Display a summary of the processed cubes.
processed_cubes

In [None]:
proc_insts = []

for proc_cube in tqdm(processed_cubes):
    # Create a new Dataset for each cube.
    # `type(name, bases, namespace)` builds a one-off `MonthlyDataset` subclass named
    # after the cube (spaces removed). Its no-op `__init__` replaces the parent
    # constructor, so the cube can be attached manually afterwards.
    proc_inst = type(
        proc_cube.long_name.replace(" ", ""),
        (MonthlyDataset,),
        {
            "__init__": lambda self: None,
            "frequency": "monthly",
        },
    )()
    proc_inst.cubes = iris.cube.CubeList([proc_cube])

    # Store the climatology dataset computed over the dataset's full time range.
    proc_insts.append(
        proc_inst.get_climatology_dataset(proc_inst.min_time, proc_inst.max_time)
    )

In [None]:
# Predictor cubes: every non-burnt-area variable plus shifted copies of it.
shifted_proc_cubes = []

# The burnt area cube is kept separately.
ba_cube = None

for proc_inst in proc_insts:
    if "burnt area" in proc_inst.cube.long_name:
        ba_cube = proc_inst.cube
        continue
    shifted_proc_cubes.append(proc_inst.cube)
    for shift in (1, 3, 6, 9):
        # XXX: Overly simplistic np.roll() implementation!!
        # The data are rolled along the first axis, wrapping around at the end
        # (presumably the month axis of the climatology - confirm).
        c2 = proc_inst.cube.copy()
        c2.data = np.roll(proc_inst.cube.data, shift, axis=0)
        # Tag the name with the shift so that each copy gets a distinct name.
        c2.long_name = c2.long_name + f"({shift})"
        c2.var_name = None
        c2.short_name = None

        shifted_proc_cubes.append(c2)

# The burnt area cube must have been found above.
assert ba_cube is not None

In [None]:
# Compare a few consecutive entries of `shifted_proc_cubes` at one grid cell.
plt.figure(figsize=(5, 3), dpi=150)
for i in np.array([0, 1, 2]) + 35:
    # Select the cube before using it - printing first would refer to whatever
    # `cube` happened to be left over from an earlier cell.
    cube = shifted_proc_cubes[i]
    print(cube.coord("month_number").points)
    plt.plot(
        cube.coord("month_number").points,
        cube.data[:, 72, 110],
        label=cube.name(),
        alpha=0.9,
        marker="x",
    )
plt.legend()

In [None]:
# Mask of the burnt area data; its unmasked elements define the samples used below.
master_mask = ba_cube.data.mask

In [None]:
# Flatten the unmasked elements into the target series and the predictor table.
valid = ~master_mask

endog_data = pd.Series(ba_cube.data.data[valid], name="burnt area")

# One column per predictor cube, keyed by its long name.
exog_dict = {cube.long_name: cube.data.data[valid] for cube in shifted_proc_cubes}

exog_data = pd.DataFrame(exog_dict)

In [None]:
# Display the target series.
endog_data

In [None]:
# Display the predictor table.
exog_data

In [None]:
# Substring replacements used to shorten the predictor (column) names.
# `shorten_columns()` applies them one after another in the order given here, so
# later entries operate on the result of earlier ones.
shorten_mapping = {
    "Gridbox": "",
    "precipitation": "precip",
    "soil carbon in each pool": "soil pool carbon",
    "surface": "surf",
    "evapotranspiration": "evapot",
    "from soil moisture store": "soil moist",
    "temperature": "temp",
    "unfrozen moisture content of each soil layer as a fraction of saturation": "unfrozen moist soil layer / sat",
    "gross primary productivity": "gpp",
    "net primary productivity": "npp",
    "soil moisture availability factor (beta)": "soil moist avail fact",
    "total carbon content of the vegetation at the end of model timestep": "total veg C end timestep",
}

In [None]:
def shorten_columns(df):
    """Shorten the column names of `df` using `shorten_mapping`.

    All replacements in `shorten_mapping` are applied to each column name in turn,
    after which surrounding whitespace is stripped.

    Args:
        df (pandas.DataFrame): Frame whose columns are renamed in place.

    Returns:
        pandas.DataFrame: The same `df` object, with shortened column names.

    """

    def _shorten(name):
        for old, new in shorten_mapping.items():
            name = name.replace(old, new)
        return name.strip()

    df.columns = [_shorten(col) for col in df.columns]
    return df

In [None]:
# Shorten the predictor names. This renames the columns of `exog_data` in place;
# the returned frame is displayed as the cell output.
shorten_columns(exog_data)

### Rescale the BA data so it has higher magnitudes

This seems to be required by the RF algorithm in order to make any predictions at all.

In [None]:
# Scale the target in place so that its maximum becomes 1.
endog_data /= np.max(endog_data)

In [None]:
# Split with the same keyword arguments that `get_mm_indices()` uses.
# NOTE(review): mapping predictions back onto the grid with `get_mm_data()` relies
# on both splits selecting the same samples - presumably guaranteed by the shared
# `random_state` and equal number of samples; confirm.
X_train, X_test, y_train, y_test = train_test_split(
    exog_data, endog_data, **train_test_split_kwargs
)

In [None]:
# Fit the random forest with the shared hyperparameters, using 3 parallel jobs.
model = RandomForestRegressor(**param_dict, n_jobs=3)
model.fit(X_train, y_train)

In [None]:
# Predictions on the training and validation samples.
y_train_pred = model.predict(X_train)
y_test_pred = model.predict(X_test)

In [None]:
# `r2_score` expects `(y_true, y_pred)`. The score is not symmetric in its
# arguments, so the observations have to come first.
print("train:", r2_score(y_train, y_train_pred))
print("val:", r2_score(y_test, y_test_pred))

In [None]:
# Display the (shortened) predictor names.
exog_data.columns

In [None]:
# Distribution of the validation targets, shown on log-log axes.
# NOTE(review): the bins are computed on a linear scale before the axes are
# switched to log - confirm this is the intended binning.
plt.hist(y_test.values)
plt.xscale("log")
plt.yscale("log")

In [None]:
# Map of the validation targets, placed back onto the grid.
fig = cube_plotting(
    get_mm_data(y_test.values, master_mask, "val"),
    boundaries=[0, 1e-3, 1e-2, 0.05, 0.2, 0.5],
    cmap="inferno",
    fig=plt.figure(figsize=(7, 3.2), dpi=150),
    colorbar_kwargs=dict(label="burnt area (scaled)"),
    title="Validation Observations",
    extend='neither',
)

In [None]:
# Map of the validation predictions, using the same boundaries and colormap as the
# observations map so that the two can be compared directly.
fig = cube_plotting(
    get_mm_data(y_test_pred, master_mask, "val"),
    boundaries=[0, 1e-3, 1e-2, 0.05, 0.2, 0.5],
    cmap="inferno",
    fig=plt.figure(figsize=(7, 3.2), dpi=150),
    colorbar_kwargs=dict(label="burnt area (scaled)"),
    title="Validation Predictions",
    extend='neither',
)

In [None]:
# Density of samples in (surface temperature, burnt area) space, log-scaled counts.
plt.figure(dpi=250)
plt.hexbin(exog_data["surf temp"], endog_data, bins="log")
plt.xlabel('surface temperature')
# Assign to `_` to suppress the text repr in the cell output.
_ = plt.ylabel('BA')

In [None]:
# Observed vs. predicted burnt area for the training samples (log-scaled counts).
plt.figure(dpi=250)
plt.hexbin(y_train, y_train_pred, bins="log")
plt.xlabel("observed BA (train)")
_ = plt.ylabel("predicted BA (train)")

In [None]:
# Observed vs. predicted burnt area for the validation samples (log-scaled counts).
plt.figure(dpi=250)
plt.hexbin(y_test, y_test_pred, bins="log")
plt.xlabel("observed BA (test)")
_ = plt.ylabel("predicted BA (test)")

### Gini importances

In [None]:
# Feature importances of every individual tree: one row per tree, one column per
# feature.
ind_trees_gini = pd.DataFrame(
    [tree.feature_importances_ for tree in model],
    columns=X_train.columns,
)
# Order the columns by decreasing mean importance across trees.
mean_importances = ind_trees_gini.mean().sort_values(ascending=False)
ind_trees_gini = ind_trees_gini.reindex(mean_importances.index, axis=1)

fig, ax = plt.subplots(1, 1, figsize=(10, 4), dpi=170)

# Number of most important features to show.
# NOTE(review): the following cell repeats this code with N_col = 30; a shared
# helper taking N_col as a parameter would avoid the duplication.
N_col = 18

sns.boxplot(data=ind_trees_gini.iloc[:, :N_col], ax=ax)
ax.set(
    # title="Gini Importances",
    ylabel="Gini Importance\n"
)
_ = plt.setp(ax.xaxis.get_majorticklabels(), rotation=60, ha="right")

In [None]:
# Same plot as in the previous cell, but showing more features (N_col = 30).
# Feature importances of every individual tree: one row per tree, one column per
# feature.
ind_trees_gini = pd.DataFrame(
    [tree.feature_importances_ for tree in model],
    columns=X_train.columns,
)
# Order the columns by decreasing mean importance across trees.
mean_importances = ind_trees_gini.mean().sort_values(ascending=False)
ind_trees_gini = ind_trees_gini.reindex(mean_importances.index, axis=1)

fig, ax = plt.subplots(1, 1, figsize=(10, 4), dpi=170)

# Number of most important features to show.
N_col = 30

sns.boxplot(data=ind_trees_gini.iloc[:, :N_col], ax=ax)
ax.set(
    # title="Gini Importances",
    ylabel="Gini Importance\n"
)
_ = plt.setp(ax.xaxis.get_majorticklabels(), rotation=60, ha="right")

In [None]:
# One ALE plot per feature, computed on the training data. Each call creates a new
# figure.
# NOTE(review): with ~60 features this opens many figures at once; consider
# closing them after display if memory becomes an issue.
for feature in tqdm(X_train.columns, desc="1D ALE plots"):
    fig, axes = ale_plot(
        model,
        X_train,
        feature,
        bins=20,
        fig=plt.figure(figsize=(5.5, 2), dpi=150),
        quantile_axis=True,
        monte_carlo=True,
        monte_carlo_rep=10,
        monte_carlo_ratio=0.2,
    )
    # Rotate the tick labels of the ALE axis to keep them readable.
    plt.setp(axes["ale"].xaxis.get_majorticklabels(), rotation=60, ha="right")