In [None]:
from pathlib import Path
from warnings import filterwarnings

import iris
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from tqdm.auto import tqdm
from wildfires.data import homogenise_time_coordinate
from wildfires.utils import reorder_cube_coord

# Suppress all warnings whose message matches "divide by zero" for the whole session.
# NOTE(review): the calculation that emits these is not visible here; confirm the
# filter is intentional rather than hiding a real problem in the data.
filterwarnings("ignore", ".*divide by zero.*")

In [None]:
# Directory containing the JULES monthly output files analysed in this notebook.
source_dir = Path("~/JULES_output/jules_output5").expanduser()
# Fail early with the offending path, rather than a bare AssertionError.
assert source_dir.is_dir(), f"JULES output directory not found: {source_dir}"

In [None]:
# Load every monthly output file, pass the cubes through the project helper
# homogenise_time_coordinate (presumably giving them a shared time coordinate),
# then concatenate the per-file cubes.
cubes = iris.load(str(source_dir / "*Monthly*.nc"))
cubes = homogenise_time_coordinate(cubes).concatenate()

In [None]:
# Display the loaded CubeList for inspection.
cubes

In [None]:
# Select the leaf area index (LAI) cube by name and display it. Later cells use
# `lai.data` directly, so a single cube (not a CubeList) is expected here.
# NOTE(review): newer iris releases replace `extract_strict` with `extract_cube`;
# confirm against the installed iris version.
lai = cubes.extract_strict("Gridbox leaf area index")
lai

In [None]:
# Select the gridbox-total carbon in decomposable plant material (DPM) cube by
# name and display it. Later cells use `dpm.data` directly, so a single cube is
# expected here.
# NOTE(review): newer iris releases replace `extract_strict` with `extract_cube`;
# confirm against the installed iris version.
dpm = cubes.extract_strict("C in decomposable plant material, gridbox total")
dpm

In [None]:
# Distribution of all DPM values. np.ma.compressed flattens the data and drops
# any masked points (histogramming converts with np.asarray, which would
# otherwise bin the raw values hidden under the mask). For a plain ndarray it is
# equivalent to .ravel().
fig, ax = plt.subplots()
ax.hist(np.ma.compressed(dpm.data))
_ = ax.set(xlabel=f"DPM ({dpm.units})", ylabel="Count")

In [None]:
# Distribution of all LAI values. np.ma.compressed flattens the data and drops
# any masked points (histogramming converts with np.asarray, which would
# otherwise bin the raw values hidden under the mask). For a plain ndarray it is
# equivalent to .ravel().
fig, ax = plt.subplots()
ax.hist(np.ma.compressed(lai.data))
_ = ax.set(xlabel=f"LAI ({lai.units})", ylabel="Count")

In [None]:
# LAI is identical for every lag, so flatten it once outside the loop.
lai_data = lai.data.ravel()
for i in range(12):
    # np.roll shifts DPM forward by i steps along axis 0 (presumably time; the
    # output is monthly), so LAI at step t is paired with DPM from step t - i.
    # The roll wraps around: the first i steps are paired with DPM from the end
    # of the record.
    dpm_data = np.roll(dpm.data, i, axis=0).ravel()

    fig, ax = plt.subplots()
    hexbins = ax.hexbin(dpm_data, lai_data, bins="log")
    ax.set_xlabel(f"DPM {i}")
    ax.set_ylabel("LAI")
    _ = fig.colorbar(hexbins, ax=ax)

In [None]:
def get_climatologies(scubes):
    """Compute the mean annual cycle (monthly climatology) of each cube.

    Args:
        scubes: Iterable of iris cubes. Each cube needs a "time" coordinate
            unless it already has a "month_number" coordinate.

    Returns:
        iris.cube.CubeList: One cube per input, averaged (MEAN) over each month
        number, with the "month_number" coordinate in increasing order.

    Note:
        A "month_number" coordinate is added *in place* to every input cube
        that does not already have one.
    """
    # `import iris` alone does not guarantee these submodules are loaded, so
    # import them explicitly instead of relying on another module having done so.
    import iris.analysis
    import iris.coord_categorisation
    import iris.cube

    ccubes = iris.cube.CubeList()
    for cube in tqdm(scubes):
        if not cube.coords("month_number"):
            iris.coord_categorisation.add_month_number(cube, "time")
        ccube = cube.aggregated_by("month_number", iris.analysis.MEAN)

        # The aggregated month axis is not guaranteed to be ordered 1..12, so
        # reorder it to increase monotonically when it is not already sorted.
        sort_indices = np.argsort(ccube.coord("month_number").points)
        if not np.array_equal(sort_indices, np.arange(len(sort_indices))):
            ccube = reorder_cube_coord(
                ccube, sort_indices, name="month_number", promote=True
            )
        ccubes.append(ccube)
    return ccubes

In [None]:
# Monthly-mean climatologies (mean annual cycle) of DPM and LAI.
# Side effect: get_climatologies adds a "month_number" coordinate to `dpm` and
# `lai` in place if they do not already have one.
dpm_clim, lai_clim = get_climatologies([dpm, lai])

In [None]:
# Climatological annual cycle at a single, arbitrarily chosen grid point:
# index 100 along the last axis (presumably the land-point axis; TODO confirm).
for name, clim in (("DPM", dpm_clim), ("LAI", lai_clim)):
    fig, ax = plt.subplots()
    ax.plot(clim[:, ..., 100].data)
    # Loop body, so no stray Text repr is displayed after the cell.
    ax.set(title=name, xlabel="Month index", ylabel=f"{name} ({clim.units})")

In [None]:
shifts = np.arange(12)

# np.corrcoef ignores masks (it reads the raw values underneath), and `.data` on
# a plain ndarray is a memoryview with no `.ravel()`. So work on the underlying
# values and pair only points that are unmasked and finite in both fields.
dpm_clim_data = dpm_clim.data
dpm_values = np.ma.getdata(dpm_clim_data)
dpm_invalid = np.ma.getmaskarray(dpm_clim_data) | ~np.isfinite(dpm_values)
lai_clim_data = lai_clim.data
lai_values = np.ma.getdata(lai_clim_data)
lai_invalid = np.ma.getmaskarray(lai_clim_data) | ~np.isfinite(lai_values)

corrs = []
for shift in shifts:
    # Rolling by -shift along axis 0 pairs DPM in month m with LAI from month
    # m + shift, wrapping around the climatology's month axis.
    valid = ~(dpm_invalid | np.roll(lai_invalid, -shift, axis=0))
    shifted_lai = np.roll(lai_values, -shift, axis=0)
    corrs.append(np.corrcoef(dpm_values[valid], shifted_lai[valid])[0, 1])

fig, ax = plt.subplots()
ax.plot(shifts, corrs)
ax.set_ylabel("DPM & shift-LAI Correlation")
_ = ax.set_xlabel("Shift (month)")

In [None]:
def min_max_scale(data):
    """Min-max scale ``data`` to [0, 1] independently along axis 0.

    Each series ``data[:, j, k, ...]`` is shifted so its minimum is 0 and then
    divided by its shifted maximum, so it spans [0, 1].

    Args:
        data: Array (plain or masked) to scale along axis 0. Any number of
            trailing dimensions is supported.

    Returns:
        Array with the same shape as ``data``. Series that are constant along
        axis 0 have zero range: plain arrays give NaN there (0 / 0), while
        numpy masked arrays mask those entries instead.
    """
    # keepdims=True keeps the reduced axis, so the statistics broadcast back
    # against `data` for any rank. For (n, 1, m) input this is identical to the
    # previous reshape(1, 1, -1), which misbroadcast for other shapes.
    mins = np.min(data, axis=0, keepdims=True)
    shifted = data - mins
    maxs = np.max(shifted, axis=0, keepdims=True)
    return shifted / maxs

In [None]:
# Same grid point (index 100 along the last axis), but with each point's annual
# cycle min-max scaled to [0, 1] so the shapes of the DPM and LAI cycles can be
# compared on a common scale.
for name, clim in (("DPM", dpm_clim), ("LAI", lai_clim)):
    fig, ax = plt.subplots()
    ax.plot(min_max_scale(clim.data)[:, ..., 100])
    # Loop body, so no stray Text repr is displayed after the cell.
    ax.set(title=name, xlabel="Month index", ylabel=f"Min-max scaled {name}")

In [None]:
shifts = np.arange(12)

# Scale once, outside the loop (the scaling does not depend on the shift).
dpm_scaled = min_max_scale(dpm_clim.data)
lai_scaled = min_max_scale(lai_clim.data)

# np.corrcoef ignores masks, `.data` on a plain ndarray is a memoryview with no
# `.ravel()`, and points with a constant annual cycle scale to NaN (0 / 0), which
# would make every correlation NaN. So pair only points that are unmasked and
# finite in both fields.
dpm_values = np.ma.getdata(dpm_scaled)
dpm_invalid = np.ma.getmaskarray(dpm_scaled) | ~np.isfinite(dpm_values)
lai_values = np.ma.getdata(lai_scaled)
lai_invalid = np.ma.getmaskarray(lai_scaled) | ~np.isfinite(lai_values)

corrs = []
for shift in shifts:
    # Rolling by -shift along axis 0 pairs DPM in month m with LAI from month
    # m + shift, wrapping around the climatology's month axis.
    valid = ~(dpm_invalid | np.roll(lai_invalid, -shift, axis=0))
    shifted_lai = np.roll(lai_values, -shift, axis=0)
    corrs.append(np.corrcoef(dpm_values[valid], shifted_lai[valid])[0, 1])

fig, ax = plt.subplots()
ax.plot(shifts, corrs)
ax.set_ylabel("DPM & shift-LAI Correlation")
_ = ax.set_xlabel("Shift (month)")