In [None]:
from functools import reduce
from operator import add

In [None]:
from pathlib import Path
from warnings import filterwarnings

import iris
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from jules_output_analysis.utils import cube_1d_to_2d
from pygam import LinearGAM, s
from tqdm import tqdm
from tqdm.auto import tqdm
from wildfires.analysis import cube_plotting
from wildfires.data import homogenise_time_coordinate
from wildfires.utils import reorder_cube_coord

filterwarnings("ignore", ".*divide by zero.*")

In [None]:
def get_climatologies(scubes):
    """Compute a monthly-mean climatology for each cube.

    Every cube is averaged (``iris.analysis.MEAN``) over all time points that
    share the same ``month_number``, giving one slice per month present.

    Args:
        scubes: Iterable of iris cubes. Cubes lacking a ``month_number``
            coordinate must have a ``time`` coordinate to derive it from.

    Returns:
        iris.cube.CubeList: One climatology cube per input cube, in input
        order, with the ``month_number`` coordinate in ascending order.

    Note:
        A ``month_number`` coordinate is added *in place* to any input cube
        that does not already have one.
    """
    # `iris.coord_categorisation` is a submodule that is not guaranteed to be
    # loaded by `import iris` alone; import it explicitly so this function
    # works on a fresh kernel.
    import iris.coord_categorisation

    ccubes = iris.cube.CubeList()
    for cube in tqdm(scubes):
        if not cube.coords("month_number"):
            iris.coord_categorisation.add_month_number(cube, "time")
        ccube = cube.aggregated_by("month_number", iris.analysis.MEAN)

        sort_indices = np.argsort(ccube.coord("month_number").points)
        if not np.array_equal(sort_indices, np.arange(len(sort_indices))):
            # aggregated_by may not yield month numbers in ascending order
            # (e.g. when the series does not start in January); reorder so
            # month numbers increase monotonically.
            ccube = reorder_cube_coord(
                ccube, sort_indices, name="month_number", promote=True
            )
        ccubes.append(ccube)
    return ccubes

In [None]:
# Directory holding the JULES run's NetCDF output (relative to the user's home).
source_dir = Path("~/JULES_output/jules_output5").expanduser()
# Fail early, with the offending path in the message, if the output is missing.
assert source_dir.is_dir(), f"JULES output directory not found: {source_dir}"

In [None]:
# Load every monthly NetCDF output file, pass the cubes through
# `homogenise_time_coordinate` (presumably aligning their time coordinates so
# they can be joined), then concatenate pieces of the same variable.
monthly_files = str(source_dir / "*Monthly*.nc")
cubes = homogenise_time_coordinate(iris.load(monthly_files)).concatenate()

In [None]:
# Display the loaded, concatenated cubes for inspection.
cubes

In [None]:
# Names of all loaded cubes, used to pick the variables selected below.
[loaded_cube.name() for loaded_cube in cubes]

In [None]:
# Variables whose normalised seasonal cycles are compared in the next cell.
names = [
    "Gridbox mean burnt area fraction",
    "Gridbox precipitation rate",
    "Gridbox effective radiative temperature (assuming emissivity=1)",
    "C in decomposable plant material, gridbox total",
    "Gridbox Absorbed Photosynthetically Active Radiation",
    "Gridbox leaf area index",
    "Gridbox unfrozen soil moisture as fraction of saturation",
]
name_filter = iris.Constraint(cube_func=lambda candidate: candidate.name() in names)
scubes = cubes.extract(name_filter)
# The number of matched cubes must equal the number of requested names.
assert len(names) == len(scubes), scubes
ccubes = get_climatologies(scubes)
ccubes

In [None]:
# Seasonal cycle of each climatology at one sample location (index 0 on the
# second axis, 100 on the third), scaled by its own maximum so variables with
# very different units share one set of axes.
fig, ax = plt.subplots()
for cube in ccubes:
    pdata = cube[:, 0, 100].data
    # NOTE(review): dividing by the maximum is only meaningful for series with
    # a positive maximum — confirm for the chosen location.
    ax.plot(pdata / np.max(pdata), label=cube.var_name)
ax.set(
    title="Normalised monthly climatologies at a sample location",
    xlabel="Month (index along month_number)",
    ylabel="Value / series maximum",
)
_ = ax.legend()

In [None]:
# 2-D (gridded) version of the first climatology cube. Later cells use it only
# for its latitude/longitude coordinates and its mask of valid cells.
cube_2d = cube_1d_to_2d(
    ccubes[0],
    temporal_dim="month_number",
)
cube_2d

In [None]:
from wildfires.data import GFEDv4, regrid

# GFED4 burnt area as a climatology spanning the dataset's full time range.
gfed = GFEDv4()
gfed_clim = gfed.get_climatology_dataset(gfed.min_time, gfed.max_time)

# Regrid the observations onto the JULES latitude/longitude grid.
source_cube = gfed_clim.cube
target_latitudes = cube_2d.coord("latitude")
target_longitudes = cube_2d.coord("longitude")
gfed_clim_cube = regrid(
    source_cube, new_latitudes=target_latitudes, new_longitudes=target_longitudes
)
gfed_clim_cube

In [None]:
# Sample the regridded GFED climatology at the same cells, in the same order,
# as the 1-D JULES burnt-area climatology. That cube serves as a template for
# the coordinates; GFED's metadata is then copied over.
jules_ba_clim = ccubes.extract_strict("Gridbox mean burnt area fraction")
# Cells that are unmasked in the 2-D JULES grid. getmaskarray always returns a
# full boolean array, even when the data carry no mask at all.
valid_cells = ~np.ma.getmaskarray(cube_2d.data)
# Boolean indexing flattens in C order (month first, then the spatial axes).
# GFED's own mask, if any, is discarded here — NOTE(review): confirm the raw
# values underneath masked GFED cells are usable.
gfed_valid_values = np.ma.getdata(gfed_clim_cube.data)[valid_cells]
gfed_clim_cube_1d = jules_ba_clim.copy(
    # Cube.copy(data=...) requires exactly the template's shape, so reshape to
    # it rather than hard-coding (12, 1, -1).
    data=gfed_valid_values.reshape(jules_ba_clim.shape)
)
gfed_clim_cube_1d.metadata = gfed_clim_cube.metadata
gfed_clim_cube_1d

In [None]:
# Visual comparison of the JULES (INFERNO) and GFED4 burnt-area climatologies
# on the 2-D grid, using matching colour boundaries.
lims = np.array([0, 0.001, 0.01, 0.02, 0.05, 0.08])
figsize = (10, 4)

inferno_ba_2d = cube_1d_to_2d(
    ccubes.extract_strict("Gridbox mean burnt area fraction"),
    temporal_dim="month_number",
)
# NOTE(review): 8e6 is an unexplained scale factor applied only to the JULES
# boundaries, presumably reconciling a units difference with GFED's burnt
# area fraction — confirm and give it a named constant.
_ = cube_plotting(
    inferno_ba_2d, boundaries=lims / 8e6, fig=plt.figure(figsize=figsize)
)

gfed_ba_2d = cube_1d_to_2d(gfed_clim_cube_1d, temporal_dim="month_number")
_ = cube_plotting(gfed_ba_2d, boundaries=lims, fig=plt.figure(figsize=figsize))

In [None]:
def proc_names(names, source_cubes=None):
    """Select cubes by name and return their monthly climatologies.

    Args:
        names: Sequence of cube names (as returned by ``cube.name()``).
        source_cubes: CubeList to select from. Defaults to the notebook-level
            ``cubes`` loaded from the JULES output directory.

    Returns:
        iris.cube.CubeList: Climatologies from ``get_climatologies``, ordered
        to follow ``names`` so positional indexing is predictable.

    Raises:
        AssertionError: If the number of matching cubes differs from
            ``len(names)``.
    """
    if source_cubes is None:
        source_cubes = cubes
    scubes = source_cubes.extract(
        iris.Constraint(cube_func=lambda cube: cube.name() in names)
    )
    assert len(names) == len(scubes), scubes
    # `extract` returns matches in source (file) order; reorder them to follow
    # the order in which the names were requested.
    position = list(names).index
    ordered = iris.cube.CubeList(
        sorted(scubes, key=lambda cube: position(cube.name()))
    )
    return get_climatologies(ordered)

In [None]:
# Regression targets: the JULES/INFERNO burnt-area climatology and the GFED4
# climatology sampled at the same points.
(ba_inferno_y_cube,) = proc_names(["Gridbox mean burnt area fraction"])
gfed_y_cube = gfed_clim_cube_1d

In [None]:
# Flatten both targets into 1-D series (C order: month first, then points) so
# they line up row-for-row with the predictor matrix built the same way below.
# np.ma.getdata returns the raw values for masked and plain arrays alike; any
# mask is discarded — NOTE(review): confirm the targets have no masked points.
ba_inferno_y = pd.Series(
    np.ma.getdata(ba_inferno_y_cube.data).ravel(), name=ba_inferno_y_cube.name()
)
gfed_y = pd.Series(np.ma.getdata(gfed_y_cube.data).ravel(), name=gfed_y_cube.name())

In [None]:
# Predictor variables for the GAMs below: one column per variable, one row per
# (month, point) pair, flattened in the same C order as the target series.
names = [
    "Gridbox precipitation rate",
    "Gridbox unfrozen soil moisture as fraction of saturation",
    "Gridbox effective radiative temperature (assuming emissivity=1)",
    "C in decomposable plant material, gridbox total",
    "Gridbox Absorbed Photosynthetically Active Radiation",
    # Currently excluded from the predictors: "Gridbox leaf area index".
]
X_cubes = proc_names(names)
X = pd.DataFrame(
    # np.ma.getdata works for masked and plain arrays alike (mask discarded).
    np.column_stack([np.ma.getdata(cube.data).ravel() for cube in X_cubes]),
    columns=[cube.name() for cube in X_cubes],
)
X

In [None]:
# Additive model with one smoothing spline per predictor column, fitted to the
# JULES/INFERNO burnt area; then plot each term's partial dependence.
gam = LinearGAM(reduce(add, (s(i) for i in range(X.shape[1])))).fit(X, ba_inferno_y)

gam.summary()

for i, term in enumerate(gam.terms):
    if term.isintercept:
        continue

    XX = gam.generate_X_grid(term=i)
    pdep, confi = gam.partial_dependence(term=i, X=XX, width=0.95)

    # Label via the term's own feature index rather than its position in
    # gam.terms, so the label cannot drift from the plotted column.
    feature_name = X.columns[term.feature]
    fig, ax = plt.subplots()
    ax.plot(XX[:, term.feature], pdep)
    # Dashed red lines: the bounds of the width=0.95 interval.
    ax.plot(XX[:, term.feature], confi, c="r", ls="--")
    ax.set(
        title=feature_name,
        xlabel=feature_name,
        ylabel=f"Partial dependence ({ba_inferno_y.name})",
    )
    plt.show()
    plt.close(fig)

In [None]:
# Same additive model (one smoothing spline per predictor column), fitted to
# the GFED4 burnt area; then plot each term's partial dependence.
gam = LinearGAM(reduce(add, (s(i) for i in range(X.shape[1])))).fit(X, gfed_y)

gam.summary()

for i, term in enumerate(gam.terms):
    if term.isintercept:
        continue

    XX = gam.generate_X_grid(term=i)
    pdep, confi = gam.partial_dependence(term=i, X=XX, width=0.95)

    # Label via the term's own feature index rather than its position in
    # gam.terms, so the label cannot drift from the plotted column.
    feature_name = X.columns[term.feature]
    fig, ax = plt.subplots()
    ax.plot(XX[:, term.feature], pdep)
    # Dashed red lines: the bounds of the width=0.95 interval.
    ax.plot(XX[:, term.feature], confi, c="r", ls="--")
    ax.set(
        title=feature_name,
        xlabel=feature_name,
        ylabel=f"Partial dependence ({gfed_y.name})",
    )
    plt.show()
    plt.close(fig)