In [None]:
from pathlib import Path
from warnings import filterwarnings

import iris
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from IPython.core.display import HTML, display
from tqdm.auto import tqdm
from wildfires.analysis import cube_plotting
from wildfires.configuration import DATA_DIR
from wildfires.data import regions_GFED
from wildfires.utils import match_shape

from jules_output_analysis.data import (
    dummy_lat_lon_cube,
    frac_weighted_mean,
    get_climatology_cube,
    get_n96e_land_mask,
    load_jules_data,
    regrid_to_n96e,
)
from jules_output_analysis.utils import PFTs, pft_names

# Silence noisy warnings. The final blanket "ignore" filter already suppresses
# every warning, so the message-specific filters only record which warnings
# are expected (divide-by-zero in ratios such as std / mean, iris unit and
# coordinate notices, ...).
for _warning_pattern in (
    ".*divide by zero.*",
    ".*invalid units.*",
    ".*may not be fully.*",
    ".*axes.*",
):
    filterwarnings("ignore", _warning_pattern)
filterwarnings("ignore")

# Relative path: resolved against the notebook's current working directory.
mpl.rc_file("matplotlibrc")

In [None]:
# JULES output directory analysed below. Previously used alternatives:
#   /work/scratch-nopw/alexkr/cru_test_mod_copy_fixed_pft  -> fixed PFT run
#   /work/scratch-nopw/alexkr/cru_test_mod_copy            -> new run
#   /work/scratch-nopw/alexkr/ignition3_1/jules_output     -> old run, 'wrong' FAPAR
#   /work/scratch-nopw/alexkr/newrun
source_dir = Path("/work/scratch-pw/alexkr") / "new-with-antec"
assert source_dir.is_dir()

In [None]:
# Monthly JULES output files to load, one per year (currently only 2000).
years = range(2000, 2001)
file_patterns = [
    str(source_dir / f"JULES-ES.1p0.vn5.4.50.CRUJRA1.365.HYDE33.SPINUP0.Monthly.{yr}.nc")
    for yr in years
]

In [None]:
# Fractional cover of each surface type, read from the files in `file_patterns`.
# This cube is later used to frac-weight the per-PFT variables.
frac_cube = load_jules_data(
    file_patterns,
    "Fractional cover of each surface type",
    single=True,
    frac_cube=None,
    n_pfts=13,
)
frac_cube

In [None]:
# One map per entry along axis 1 of `frac_cube`, at index 0 of axis 0
# (presumably the first time step - TODO confirm axis order).
surface_type_names = pft_names[PFTs.VEG13_ALL]
for type_idx in range(frac_cube.shape[1]):
    cube_plotting(frac_cube[0, type_idx], title=surface_type_names[type_idx])

In [None]:
# Every cube in the year-2000 output file, to inspect which variables exist.
raw_file = source_dir / "JULES-ES.1p0.vn5.4.50.CRUJRA1.365.HYDE33.SPINUP0.Monthly.2000.nc"
raw_cubes = iris.load(str(raw_file))
raw_cubes

In [None]:
target = "Gridbox mean burnt area fraction"
# NOTE: this reads the u-cd730 S3 run, not the run in `source_dir`.
ba_cube = load_jules_data(
    "/work/scratch-pw/katie_b1/u-cd730/JULES-ES.1p0.vn5.4.50.CRUJRA1.365.HYDE33.S3.Monthly.2000.nc",
    target,
    n_pfts=13,
    frac_cube=None,
    single=True,
)

# Per-pixel statistics over the leading (time) axis.
avg_data_2d = np.mean(ba_cube.data, axis=0)
max_data_2d = np.max(ba_cube.data, axis=0)
std_data_2d = np.std(ba_cube.data, axis=0)

fig = cube_plotting(avg_data_2d, title=f"{target} mean")
fig = cube_plotting(max_data_2d, title=f"{target} max")
# The std map used to be computed but never shown; plot it like the other variables.
fig = cube_plotting(std_data_2d, title=f"{target} std")

In [None]:
target = "PFT gross primary productivity"
pft_gpp_path = "/work/scratch-pw/katie_b1/u-cd730/JULES-ES.1p0.vn5.4.50.CRUJRA1.365.HYDE33.S3.Monthly.2000.nc"

# NOTE(review): `frac_cube` was loaded from the SPINUP0 run in `source_dir`,
# whereas this GPP comes from the u-cd730 S3 run - confirm the land-cover
# fractions of the two runs agree before trusting the weighting.
cube_2d = load_jules_data(
    pft_gpp_path,
    target,
    n_pfts=13,
    frac_cube=frac_cube[:12],
    single=True,
)

# An unweighted mean, e.g.
#     cube_2d.collapsed(("time", "pft"), iris.analysis.MEAN)
# is (as expected) markedly different from the proper frac-weighted average,
# so only the weighted version is shown.

# Frac-weighted mean over the natural PFTs (n_pfts=13), first 12 time steps.
agg_data_2d = frac_weighted_mean(cube_2d[:12], frac_cube[:12], n_pfts=13)

avg_data_2d, max_data_2d, std_data_2d = (
    reducer(agg_data_2d, axis=0) for reducer in (np.mean, np.max, np.std)
)

fig = cube_plotting(avg_data_2d, title=f"{target} mean")
fig = cube_plotting(max_data_2d, title=f"{target} max")
fig = cube_plotting(std_data_2d, title=f"{target} std")
fig = cube_plotting(std_data_2d / avg_data_2d, title=f"{target} std / mean")

In [None]:
target = "Gridbox gross primary productivity"
gridbox_gpp_path = "/work/scratch-pw/katie_b1/u-cd730/JULES-ES.1p0.vn5.4.50.CRUJRA1.365.HYDE33.S3.Monthly.2000.nc"

# Gridbox-level variable, loaded without frac weighting (frac_cube=None).
cube_2d = load_jules_data(
    gridbox_gpp_path, target, n_pfts=13, frac_cube=None, single=True
)

# Per-pixel statistics over the leading (time) axis.
for stat_name, reduce_over_time in (("mean", np.mean), ("max", np.max), ("std", np.std)):
    fig = cube_plotting(
        reduce_over_time(cube_2d.data, axis=0), title=f"{target} {stat_name}"
    )

In [None]:
target = "PFT leaf area index"
lai_file = source_dir / "JULES-ES.1p0.vn5.4.50.CRUJRA1.365.HYDE33.SPINUP0.Monthly.2000.nc"

# Per-PFT LAI, reduced over PFTs with `frac_weighted_mean` (first 12 time steps).
lai_pft = load_jules_data(
    str(lai_file),
    target,
    n_pfts=13,
    frac_cube=frac_cube[:12],
    single=True,
)
lai_2d = frac_weighted_mean(lai_pft[:12], frac_cube[:12], n_pfts=13)
del lai_pft  # only the weighted result is needed afterwards

avg_lai_2d = np.mean(lai_2d, axis=0)
max_lai_2d = np.max(lai_2d, axis=0)
std_lai_2d = np.std(lai_2d, axis=0)

for stat_name, stat_map in (("mean", avg_lai_2d), ("max", max_lai_2d), ("std", std_lai_2d)):
    fig = cube_plotting(stat_map, title=f"{target} {stat_name}")

In [None]:
target = "PFT Fraction of Absorbed Photosynthetically Active Radiation"
fapar_file = source_dir / "JULES-ES.1p0.vn5.4.50.CRUJRA1.365.HYDE33.SPINUP0.Monthly.2000.nc"

# Per-PFT FAPAR, reduced over PFTs with `frac_weighted_mean` (first 12 time steps).
fapar_pft = load_jules_data(
    str(fapar_file),
    target,
    n_pfts=13,
    frac_cube=frac_cube[:12],
    single=True,
)
fapar_2d = frac_weighted_mean(fapar_pft[:12], frac_cube[:12], n_pfts=13)
del fapar_pft  # only the weighted result is needed afterwards

avg_fapar_2d = np.mean(fapar_2d, axis=0)
max_fapar_2d = np.max(fapar_2d, axis=0)
std_fapar_2d = np.std(fapar_2d, axis=0)

for stat_name, stat_map in (("mean", avg_fapar_2d), ("max", max_fapar_2d), ("std", std_fapar_2d)):
    fig = cube_plotting(stat_map, title=f"{target} {stat_name}")

In [None]:
target = "PFT fuel build up"

# NOTE(review): unlike the LAI/FAPAR cells (which read the 2000 file), this
# reads the 2005 file. It is still weighted by `frac_cube` (loaded from the
# year-2000 output) and later plotted against the 2000 FAPAR - confirm the
# year is intended (e.g. to skip a spin-up transient) rather than a leftover.
fuel_file = source_dir / "JULES-ES.1p0.vn5.4.50.CRUJRA1.365.HYDE33.SPINUP0.Monthly.2005.nc"

fuel_build_up_pft = load_jules_data(
    str(fuel_file),
    target,
    n_pfts=13,
    frac_cube=frac_cube[:12],
    single=True,
)
fuel_build_up_2d = frac_weighted_mean(fuel_build_up_pft[:12], frac_cube[:12], n_pfts=13)
del fuel_build_up_pft  # only the weighted result is needed afterwards

In [None]:
# Side-by-side maps of the first time step of fuel build-up and FAPAR.
for series_2d, series_name in ((fuel_build_up_2d, "fuel build up"), (fapar_2d, "fapar")):
    fig = cube_plotting(series_2d[0], title=f"{series_name} 0")

### FAPAR and antecedent FAPAR (fuel build-up) for selected pixels (small regions)

In [None]:
# Lower-left corner and width (degrees) of the lat/lon box to average over.
lat = 0
lon = 15
box_size = 5
# Strict inequalities: points exactly on the box edges are excluded.
constraint = iris.Constraint(
    latitude=lambda c: lat < c.point < lat + box_size,
    longitude=lambda c: lon < c.point < lon + box_size,
)


def _box_mean_series(data_3d):
    """Return the unweighted spatial-mean time series of `data_3d` inside `constraint`.

    `data_3d` is wrapped with `dummy_lat_lon_cube`, restricted to the box, and
    collapsed over latitude and longitude (no area weighting).
    """
    return (
        dummy_lat_lon_cube(data_3d)
        .extract(constraint)
        .collapsed(("latitude", "longitude"), iris.analysis.MEAN)
        .data
    )


fig, ax = plt.subplots()
ax.plot(_box_mean_series(fapar_2d), label="fapar")
ax.plot(_box_mean_series(fuel_build_up_2d), label="fuel build up")
ax.legend()
ax.set_xlabel("month index")
ax.set_ylabel("FAPAR")
# Report the whole box (open intervals), not just its corner.
_ = ax.set_title(f"lat: ({lat}, {lat + box_size}), lon: ({lon}, {lon + box_size})")

### Load observed (reference) LAI and FAPAR

In [None]:
target = "Obs. LAI"

# Observed LAI climatology regridded to N96e, with every non-land cell masked.
ref_lai_path = Path(DATA_DIR) / "LAI_climatology.nc"
ref_lai_cube_2d = regrid_to_n96e(iris.load_cube(str(ref_lai_path)))
non_land_cells = match_shape(~get_n96e_land_mask(), ref_lai_cube_2d.shape)
ref_lai_cube_2d.data.mask |= non_land_cells

display(HTML(ref_lai_cube_2d._repr_html_()))

ref_avg_lai_cube_2d = ref_lai_cube_2d.collapsed("time", iris.analysis.MEAN)
ref_max_lai_cube_2d = ref_lai_cube_2d.collapsed("time", iris.analysis.MAX)
# numpy std (ddof=0) over the leading axis, wrapped in a copy of the mean cube.
ref_std_lai_cube_2d = ref_avg_lai_cube_2d.copy(
    data=np.std(ref_lai_cube_2d.data, axis=0)
)

for stat_name, stat_cube in (
    ("mean", ref_avg_lai_cube_2d),
    ("max", ref_max_lai_cube_2d),
    ("std", ref_std_lai_cube_2d),
):
    fig = cube_plotting(stat_cube, title=f"{target} {stat_name}")
fig = cube_plotting(
    ref_std_lai_cube_2d / ref_avg_lai_cube_2d, title=f"{target} std / mean"
)

In [None]:
target = "Obs. FAPAR"

# Observed FAPAR climatology regridded to N96e, with every non-land cell masked.
ref_fapar_path = Path(DATA_DIR) / "FAPAR_climatology.nc"
ref_fapar_cube_2d = regrid_to_n96e(iris.load_cube(str(ref_fapar_path)))
non_land_cells = match_shape(~get_n96e_land_mask(), ref_fapar_cube_2d.shape)
ref_fapar_cube_2d.data.mask |= non_land_cells

display(HTML(ref_fapar_cube_2d._repr_html_()))

ref_avg_fapar_cube_2d = ref_fapar_cube_2d.collapsed("time", iris.analysis.MEAN)
ref_max_fapar_cube_2d = ref_fapar_cube_2d.collapsed("time", iris.analysis.MAX)
# numpy std (ddof=0) over the leading axis, wrapped in a copy of the mean cube.
ref_std_fapar_cube_2d = ref_avg_fapar_cube_2d.copy(
    data=np.std(ref_fapar_cube_2d.data, axis=0)
)

for stat_name, stat_cube in (
    ("mean", ref_avg_fapar_cube_2d),
    ("max", ref_max_fapar_cube_2d),
    ("std", ref_std_fapar_cube_2d),
):
    fig = cube_plotting(stat_cube, title=f"{target} {stat_name}")
fig = cube_plotting(
    ref_std_fapar_cube_2d / ref_avg_fapar_cube_2d, title=f"{target} std / mean"
)

In [None]:
# GFED region map. Its "regions" attribute is used below to look up region
# names by integer id.
regions = regions_GFED()
regions

In [None]:
# Region id -> region name lookup, used for the per-region plot titles.
regions_map = regions.attributes["regions"]
regions_map

In [None]:
# Mask out the oceans (region id 0; only ids 1-14 are analysed below).
is_ocean = regions.data == 0
regions.data = np.ma.MaskedArray(regions.data, mask=is_ocean)

#### Regrid to N96e

In [None]:
# Regrid the GFED region map onto the N96e grid (presumably the JULES output grid).
# NOTE(review): region ids are categorical. Unless `regrid_to_n96e` uses a
# nearest-neighbour / area-majority scheme, regridded values can fall between
# integer ids (e.g. along region borders) and will not match the
# `== region` selections below - confirm the regridding scheme.
n96e_regions = regrid_to_n96e(regions)

In [None]:
# Mask region id 0 (ocean) in the regridded map *in addition to* whatever the
# regridding already masked. Assigning `data == 0` straight to `.mask` replaces
# the existing mask, and the comparison's underlying data at already-masked
# points is not guaranteed to be True - so such cells could be unmasked again.
n96e_regions.data.mask = np.ma.getmaskarray(n96e_regions.data) | np.ma.filled(
    n96e_regions.data == 0, False
)

In [None]:
# Apply land mask.
# Apply land mask: additionally mask every N96e cell that is not land.
is_land = get_n96e_land_mask()
n96e_regions.data.mask |= ~is_land

In [None]:
# Half-integer colour boundaries (0.5 ... 14.5) centre one bin on each region id 1-14.
for region_cube in (regions, n96e_regions):
    fig = cube_plotting(region_cube, boundaries=np.arange(1, 16) - 0.5)

#### Climatology comparison by region, with JULES data averaged (frac-weighted) over PFTs

In [None]:
# Check that the regions selection is working.
# NOTE(review): `pft_avg_cube_2d` and `ref_cube_2d` do not appear to be defined
# anywhere in this notebook, so this check fails if uncommented as-is -
# substitute current data (e.g. `ref_fapar_cube_2d`) first.
# for data_cube in tqdm([pft_avg_cube_2d, ref_cube_2d]):
#     for region in tqdm(range(1, 15)):  # Exclude the ocean.
#         mask = match_shape(n96e_regions.data == region, data_cube.shape)
#         plot_cube = data_cube.copy()
#         plot_cube.data.mask |= ~mask
#         cube_plotting(
#             plot_cube, fig=plt.figure(figsize=(3, 1), dpi=100), title=str(region)
#         )

In [None]:
# Note that the spatial averaging done here is not area weighted!

# The monthly series do not depend on the region, so build them (including the
# JULES climatologies) once instead of once per region inside the loop.
# Each entry: (label, data with the month along the leading axis, linestyle).
climatology_series = [
    ("JULES FAPAR", get_climatology_cube(dummy_lat_lon_cube(fapar_2d)).data, "--"),
    ("OBS FAPAR", ref_fapar_cube_2d.data, "-"),
    ("JULES LAI", get_climatology_cube(dummy_lat_lon_cube(lai_2d)).data, "--"),
    ("OBS LAI", ref_lai_cube_2d.data, "-"),
]

# Styling per variable: (colour, alpha, zorder). FAPAR goes on the left axis,
# LAI on the right (twin) axis, drawn slightly transparent and underneath.
series_styles = {
    "FAPAR": ("C0", 1.0, 2),
    "LAI": ("C1", 0.7, 1),
}

for region in tqdm(range(1, 15)):  # Exclude the ocean.
    fig, ax = plt.subplots(1, 1)
    ax.set_title(regions_map[region])
    ax2 = ax.twinx()
    ax.set_ylabel("FAPAR")
    ax2.set_ylabel("LAI")
    axes_by_variable = {"FAPAR": ax, "LAI": ax2}

    # Boolean (lat, lon) selection of this region; masked region cells
    # (ocean / non-land) are explicitly treated as outside the region.
    in_region = np.ma.filled(n96e_regions.data == region, False)

    handles = []
    for label, data, ls in climatology_series:
        variable = "FAPAR" if "FAPAR" in label else "LAI"
        color, alpha, zorder = series_styles[variable]

        region_mask = match_shape(in_region, data.shape)
        # Build a fresh masked array rather than copying and mutating `.mask`:
        # this works whether or not `data` is already masked and never touches
        # the source data's own mask.
        plot_data = np.ma.masked_array(
            data, mask=np.ma.getmaskarray(data) | ~region_mask
        )

        handles.append(
            axes_by_variable[variable].errorbar(
                x=np.arange(1, plot_data.shape[0] + 1),
                # Unweighted spatial mean / std over the region's pixels per month.
                y=np.mean(plot_data, axis=(1, 2)),
                yerr=np.std(plot_data, axis=(1, 2)),
                capsize=4,
                label=label,
                linestyle=ls,
                color=color,
                alpha=alpha,
                zorder=zorder,
            )
        )
    ax.legend(handles=handles, ncol=1, bbox_to_anchor=(1.1, 1.12), loc="upper left")
    ax.set_xlabel("month")