In [None]:
from pathlib import Path
from warnings import filterwarnings

import iris
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from IPython.core.display import HTML, display
from tqdm.auto import tqdm
from wildfires.analysis import cube_plotting
from wildfires.configuration import DATA_DIR
from wildfires.data import regions_GFED
from wildfires.utils import match_shape

from jules_output_analysis.data import (
    cube_1d_to_2d,
    dummy_lat_lon_cube,
    frac_weighted_mean,
    get_climatology_cube,
    get_n96e_land_mask,
    load_lat_lon_coords,
    regrid_to_n96e,
)

# Silence noisy warnings coming from iris / numpy during the analysis.
# NOTE(review): the final blanket "ignore" makes the four targeted filters
# redundant and also hides any genuinely useful warning -- consider dropping it.
for _pattern in (
    ".*divide by zero.*",
    ".*invalid units.*",
    ".*may not be fully.*",
    ".*axes.*",
):
    filterwarnings("ignore", _pattern)
del _pattern
filterwarnings("ignore")
# Shared plotting style, read from the working directory.
mpl.rc_file("matplotlibrc")

In [None]:
# JULES climatology output file to analyse.
source_file = str(Path.home() / "tmp" / "climatology5_c.nc")
# Latitude / longitude coordinates loaded from that same file.
lat_coord, lon_coord = load_lat_lon_coords(source_file)

In [None]:
# Names of every cube stored in the JULES output file (in file order).
cube_names = [raw_cube.name() for raw_cube in iris.load_raw(source_file)]
cube_names

In [None]:
# Load the per-PFT LAI and FAPAR plus the per-surface-type cover fractions.
raw_lai, raw_fapar, raw_frac = (
    iris.load_cube(source_file, constraint=cube_name)
    for cube_name in (
        "PFT leaf area index",
        "PFT Fraction of Absorbed Photosynthetically Active Radiation",
        "Fractional cover of each surface type",
    )
)

# Time-mean of each field.
avg_lai, avg_fapar, avg_frac = (
    raw_cube.collapsed("time", iris.analysis.MEAN)
    for raw_cube in (raw_lai, raw_fapar, raw_frac)
)

# Attach lat/lon to data dims 1 and 2 (presumably the spatial dims after the
# leading PFT / surface-type axis) so the cubes can be expanded to 2D.
for time_mean_cube in (avg_lai, avg_fapar, avg_frac):
    time_mean_cube.add_aux_coord(lat_coord, (1, 2))
    time_mean_cube.add_aux_coord(lon_coord, (1, 2))

avg_lai_2d, avg_fapar_2d, avg_frac_2d = map(
    cube_1d_to_2d, (avg_lai, avg_fapar, avg_frac)
)

avg_lai, avg_fapar, avg_frac

In [None]:
# Inspect the time-mean LAI cube after conversion by cube_1d_to_2d.
avg_lai_2d

In [None]:
# Shape of the 2D time-mean LAI cube; frac_weighted_mean expects the first
# axis to have length 13 (one entry per PFT).
avg_lai_2d.shape

In [None]:
def frac_weighted_mean(cube_2d, frac_2d=None):
    """Average a per-PFT field over the PFT axis, weighted by PFT cover fraction.

    NOTE(review): this definition shadows the ``frac_weighted_mean`` imported
    from ``jules_output_analysis.data`` at the top of the notebook.

    Args:
        cube_2d: Cube (or any object with ``.shape`` and ``.data``) whose first
            axis indexes the 13 PFTs.
        frac_2d: Surface-type fraction cube with 17 entries along its first
            axis, the first 13 of which are the PFTs. Defaults to the global
            ``avg_frac_2d`` for backward compatibility; pass it explicitly to
            avoid depending on notebook state.

    Returns:
        Array over the remaining axes equal to
        ``sum(frac * x, axis=0) / sum(frac, axis=0)`` over the 13 PFTs. Where
        the total PFT fraction is zero the result is undefined (presumably
        nan/inf for plain arrays, masked for masked arrays).

    Raises:
        AssertionError: if either input has an unexpected first-axis length.
    """
    if frac_2d is None:
        # Backward-compatible default: the notebook-level time-mean fractions.
        frac_2d = avg_frac_2d
    assert cube_2d.shape[0] == 13, f"expected 13 PFTs, got {cube_2d.shape[0]}"
    assert (
        frac_2d.shape[0] == 17
    ), f"expected 17 surface types, got {frac_2d.shape[0]}"
    # Only the first 13 surface types are PFTs; slice them once and reuse.
    pft_frac = frac_2d[:13].data
    return np.sum(pft_frac * cube_2d.data, axis=0) / np.sum(pft_frac, axis=0)

In [None]:
# Map of the PFT-fraction-weighted, time-mean JULES LAI.
jules_pft_mean_lai = frac_weighted_mean(avg_lai_2d)
_ = cube_plotting(
    jules_pft_mean_lai,
    title="JULES LAI",
    colorbar_kwargs={"label": "LAI"},
)

In [None]:
# Map of the PFT-fraction-weighted, time-mean JULES FAPAR.
jules_pft_mean_fapar = frac_weighted_mean(avg_fapar_2d)
_ = cube_plotting(
    jules_pft_mean_fapar,
    title="JULES FAPAR",
    colorbar_kwargs={"label": "FAPAR"},
)

### FAPAR and antecedent FAPAR (fuel build-up) for selected pixels (small regions)

In [None]:
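# NOTE: disabled exploratory cell -- it relies on `fapar_2d` and
# `fuel_build_up_2d`, which are not defined in any cell of this notebook.
# lat = 0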
# lon = 15
# constraint = iris.Constraint(
#     latitude=lambda c: lat < c.point < lat + 5,
#     longitude=lambda c: lon < c.point < lon + 5,
# )

# plt.plot(
#     dummy_lat_lon_cube(fapar_2d)
#     .extract(constraint)
#     .collapsed(("latitude", "longitude"), iris.analysis.MEAN)
#     .data,
#     label="fapar",
# )
# plt.plot(
#     dummy_lat_lon_cube(fuel_build_up_2d)
#     .extract(constraint)
#     .collapsed(("latitude", "longitude"), iris.analysis.MEAN)
#     .data,
#     label="fuel build up",
# )
# plt.legend()
# _ = plt.title(f"lat: {lat}, lon: {lon}")

### Load observed (reference) LAI and FAPAR

In [None]:
target = "Obs. LAI"

# Observed LAI climatology, regridded onto the N96e grid.
lai_climatology_path = Path(DATA_DIR) / "LAI_climatology.nc"
ref_lai_cube_2d = regrid_to_n96e(iris.load_cube(str(lai_climatology_path)))
# Additionally mask every cell that is not land on the N96e grid.
lai_non_land_mask = match_shape(~get_n96e_land_mask(), ref_lai_cube_2d.shape)
ref_lai_cube_2d.data.mask |= lai_non_land_mask

display(HTML(ref_lai_cube_2d._repr_html_()))

# Per-pixel statistics over time (np.std assumes time is the leading axis).
ref_avg_lai_cube_2d = ref_lai_cube_2d.collapsed("time", iris.analysis.MEAN)
ref_max_lai_cube_2d = ref_lai_cube_2d.collapsed("time", iris.analysis.MAX)
ref_std_lai_cube_2d = ref_avg_lai_cube_2d.copy(
    data=np.std(ref_lai_cube_2d.data, axis=0)
)

for lai_stat_cube, lai_stat_name in (
    (ref_avg_lai_cube_2d, "mean"),
    (ref_max_lai_cube_2d, "max"),
    (ref_std_lai_cube_2d, "std"),
):
    fig = cube_plotting(lai_stat_cube, title=f"{target} {lai_stat_name}")
# Relative variability: standard deviation divided by the mean.
fig = cube_plotting(
    ref_std_lai_cube_2d / ref_avg_lai_cube_2d, title=f"{target} std / mean"
)

In [None]:
target = "Obs. FAPAR"

# Observed FAPAR climatology, regridded onto the N96e grid.
fapar_climatology_path = Path(DATA_DIR) / "FAPAR_climatology.nc"
ref_fapar_cube_2d = regrid_to_n96e(iris.load_cube(str(fapar_climatology_path)))
# Additionally mask every cell that is not land on the N96e grid.
fapar_non_land_mask = match_shape(~get_n96e_land_mask(), ref_fapar_cube_2d.shape)
ref_fapar_cube_2d.data.mask |= fapar_non_land_mask

display(HTML(ref_fapar_cube_2d._repr_html_()))

# Per-pixel statistics over time (np.std assumes time is the leading axis).
ref_avg_fapar_cube_2d = ref_fapar_cube_2d.collapsed("time", iris.analysis.MEAN)
ref_max_fapar_cube_2d = ref_fapar_cube_2d.collapsed("time", iris.analysis.MAX)
ref_std_fapar_cube_2d = ref_avg_fapar_cube_2d.copy(
    data=np.std(ref_fapar_cube_2d.data, axis=0)
)

for fapar_stat_cube, fapar_stat_name in (
    (ref_avg_fapar_cube_2d, "mean"),
    (ref_max_fapar_cube_2d, "max"),
    (ref_std_fapar_cube_2d, "std"),
):
    fig = cube_plotting(fapar_stat_cube, title=f"{target} {fapar_stat_name}")
# Relative variability: standard deviation divided by the mean.
fig = cube_plotting(
    ref_std_fapar_cube_2d / ref_avg_fapar_cube_2d, title=f"{target} std / mean"
)

In [None]:
# Load the GFED regions cube. Code 0 is treated as ocean further down;
# codes 1-14 are the regions iterated over in the comparison loop.
regions = regions_GFED()
regions

In [None]:
# Region code -> region name mapping stored in the cube attributes
# (used for the per-region plot titles).
regions_map = regions.attributes["regions"]
regions_map

In [None]:
# Mask out the oceans: region code 0 is treated as ocean.
ocean_cells = regions.data == 0
regions.data = np.ma.MaskedArray(regions.data, mask=ocean_cells)

#### Regrid to N96e

In [None]:
# Regrid the region codes onto the N96e grid.
# NOTE(review): region codes are categorical -- confirm regrid_to_n96e does not
# interpolate them (that would create non-integer codes along region borders,
# which the `== region` selections below would silently drop).
n96e_regions = regrid_to_n96e(regions)

In [None]:
# Mask cells whose regridded region code is 0 (ocean).
# NOTE: this *replaces* (rather than extends) any mask the regridder produced.
zero_code_cells = n96e_regions.data == 0
n96e_regions.data.mask = zero_code_cells

In [None]:
# Apply land mask: additionally mask every cell that is not land on N96e.
non_land_cells = ~get_n96e_land_mask()
n96e_regions.data.mask |= non_land_cells

In [None]:
# Region maps on the native grid and after regridding to N96e, with one
# colour bin centred on each integer region code from 1 to 14.
for region_cube in (regions, n96e_regions):
    fig = cube_plotting(region_cube, boundaries=np.arange(1, 16) - 0.5)

#### Climatology comparison by region - with JULES data averaged over PFTs

In [None]:
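# Check that the regions selection is working.
# NOTE: disabled -- `pft_avg_cube_2d` and `ref_cube_2d` are not defined in any
# cell of this notebook.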
# for data_cube in tqdm([pft_avg_cube_2d, ref_cube_2d]):
#     for region in tqdm(range(1, 15)):  # Exclude the ocean.
#         mask = match_shape(n96e_regions.data == region, data_cube.shape)
#         plot_cube = data_cube.copy()
#         plot_cube.data.mask |= ~mask
#         cube_plotting(
#             plot_cube, fig=plt.figure(figsize=(3, 1), dpi=100), title=str(region)
#         )

In [None]:
# Note that the spatial averaging done here is not area weighted!
# NOTE(review): `fapar_2d` and `lai_2d` are not defined in any cell of this
# notebook, so this cell only runs with leftover kernel state. They are
# presumably time-resolved, PFT-averaged JULES FAPAR / LAI fields -- TODO define
# them explicitly earlier in the notebook.

# (label, data indexed as (time, lat, lon), line style, variable).
# None of this depends on the region, so the JULES climatologies are computed
# once here instead of being recomputed for each of the 14 regions.
region_series = [
    ("JULES FAPAR", get_climatology_cube(dummy_lat_lon_cube(fapar_2d)).data, "--", "FAPAR"),
    ("OBS FAPAR", ref_fapar_cube_2d.data, "-", "FAPAR"),
    ("JULES LAI", get_climatology_cube(dummy_lat_lon_cube(lai_2d)).data, "--", "LAI"),
    ("OBS LAI", ref_lai_cube_2d.data, "-", "LAI"),
]

# FAPAR goes on the left y-axis; LAI goes on a twin right y-axis, drawn
# semi-transparent and underneath the FAPAR curves.
var_styles = {
    "FAPAR": dict(color="C0", alpha=1.0, zorder=2),
    "LAI": dict(color="C1", alpha=0.7, zorder=1),
}

for region in tqdm(range(1, 15)):  # Exclude the ocean.
    fig, ax = plt.subplots(1, 1)
    ax.set_title(regions_map[region])
    ax2 = ax.twinx()
    ax.set_ylabel("FAPAR")
    ax2.set_ylabel("LAI")
    var_axes = {"FAPAR": ax, "LAI": ax2}

    handles = []
    for label, data, ls, var in region_series:
        # Restrict to the current region; match_shape presumably broadcasts the
        # 2D region selection over the time axis of `data`.
        mask = match_shape(n96e_regions.data == region, data.shape)
        # Copy so the shared, hoisted arrays are never modified.
        plot_data = data.copy()
        plot_data.mask |= ~mask
        handles.append(
            var_axes[var].errorbar(
                x=np.arange(1, plot_data.shape[0] + 1),
                y=np.mean(plot_data, axis=(1, 2)),
                yerr=np.std(plot_data, axis=(1, 2)),
                capsize=4,
                label=label,
                linestyle=ls,
                **var_styles[var],
            )
        )
    ax.legend(handles=handles, ncol=1, bbox_to_anchor=(1.1, 1.12), loc="upper left")
    ax.set_xlabel("month")

### Global comparison of FAPAR

In [None]:
# NOTE(review): `fapar_2d` is not defined in any cell of this notebook, so this
# cell only runs with leftover kernel state. It must share its shape with
# `ref_fapar_cube_2d` (the masks are combined element-wise) -- TODO define it.
combined_mask = fapar_2d.mask | ref_fapar_cube_2d.data.mask
fig, ax = plt.subplots()
hexbin_fapar = ax.hexbin(
    ref_fapar_cube_2d.data.data[~combined_mask],
    fapar_2d.data[~combined_mask],
    bins="log",
)
fig.colorbar(hexbin_fapar, ax=ax)
ax.set_xlabel("OBS FAPAR")
ax.set_ylabel("JULES FAPAR")
# Draw the 1:1 line across the full autoscaled extent of both axes (rather than
# a fixed 0-10 range), then restore the limits so the line does not rescale
# the view.
xlim = ax.get_xlim()
ylim = ax.get_ylim()
diag_min = min(xlim[0], ylim[0])
diag_max = max(xlim[1], ylim[1])
ax.plot([diag_min, diag_max], [diag_min, diag_max], ls="--", c="C3")
ax.set_xlim(xlim)
_ = ax.set_ylim(ylim)

### Global comparison of LAI

In [None]:
# NOTE(review): `lai_2d` is not defined in any cell of this notebook, so this
# cell only runs with leftover kernel state. It must share its shape with
# `ref_lai_cube_2d` (the masks are combined element-wise) -- TODO define it.
combined_mask = lai_2d.mask | ref_lai_cube_2d.data.mask
fig, ax = plt.subplots()
hexbin_lai = ax.hexbin(
    ref_lai_cube_2d.data.data[~combined_mask], lai_2d.data[~combined_mask], bins="log"
)
fig.colorbar(hexbin_lai, ax=ax)
ax.set_xlabel("OBS LAI")
ax.set_ylabel("JULES LAI")
# Draw the 1:1 line across the full autoscaled extent of both axes (rather than
# a fixed 0-10 range, which LAI can exceed), then restore the limits so the
# line does not rescale the view.
xlim = ax.get_xlim()
ylim = ax.get_ylim()
diag_min = min(xlim[0], ylim[0])
diag_max = max(xlim[1], ylim[1])
ax.plot([diag_min, diag_max], [diag_min, diag_max], ls="--", c="C3")
ax.set_xlim(xlim)
_ = ax.set_ylim(ylim)