In [None]:
import math
import os
from pathlib import Path
from warnings import filterwarnings

import cartopy.crs as ccrs
import dask.array as da
import iris
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from iris.coord_categorisation import add_day_of_year, add_month_number
from python_inferno.utils import exponential_average, temporal_nearest_neighbour_interp
from wildfires.configuration import DATA_DIR
from wildfires.utils import match_shape

from jules_output_analysis.data import (
    cube_1d_to_2d,
    get_n96e_land_mask,
    load_lat_lon_coords,
    regrid_to_n96e,
)
from jules_output_analysis.utils import PFTs, pft_acronyms, pft_names

# Silence warnings known to be noisy for this analysis (same order as before).
for _warning_pattern in (
    ".*divide by zero.*",
    ".*invalid units.*",
    ".*may not be fully.*",
    ".*axes.*",
):
    filterwarnings("ignore", _warning_pattern)
del _warning_pattern

# NOTE(review): this blanket filter silences *every* warning, which makes the targeted
# filters above redundant — confirm that hiding all warnings is really intended.
filterwarnings("ignore")

# Load matplotlib rc settings from `matplotlibrc` (resolved relative to the working directory).
mpl.rc_file("matplotlibrc")

In [None]:
# JULES climatology output analysed throughout this notebook. Set the
# JULES_CLIMATOLOGY_FILE environment variable to point elsewhere instead of editing
# the notebook; the default is the previous hard-coded location.
source_file = str(
    Path(os.environ.get("JULES_CLIMATOLOGY_FILE", "~/tmp/climatology5_c.nc")).expanduser()
)
# Latitude/longitude coordinates read from the same file; `load_data` attaches them to
# cubes that do not already carry them.
lat_coord, lon_coord = load_lat_lon_coords(source_file)

In [None]:
def get_closest(x, points):
    """Get the element of `points` that is closest to `x`.

    Parameters
    ----------
    x : scalar
        Target value.
    points : array-like
        Candidate values. Anything `np.asarray` accepts works (lists, 1D or N-D arrays);
        N-D input is searched element-wise, as if flattened.

    Returns
    -------
    The element of `points` minimising ``abs(points - x)``. Ties resolve to the first
    such element in flattened order (``np.argmin`` semantics).

    Raises
    ------
    ValueError
        If `points` is empty.
    """
    # Flatten so the flat index returned by `np.argmin` indexes a single element even for
    # N-D input (indexing the original N-D array with it would select along axis 0).
    flat_points = np.ravel(np.asarray(points))
    if flat_points.size == 0:
        raise ValueError("`points` must contain at least one element.")
    return flat_points[np.argmin(np.abs(flat_points - x))]

In [None]:
# Show the name of every cube stored in `source_file`.
[cube.name() for cube in iris.load_raw(source_file)]

In [None]:
def monthly_climatology(cube):
    """Average `cube` per calendar month, leaving `month_number` as the only dim-0 coord.

    `cube` must already carry a `month_number` coordinate (e.g. added with
    `add_month_number`). The aggregated dimension is assumed to be dimension 0: every
    coordinate on that dimension is dropped and replaced by a `month_number` DimCoord.

    Returns
    -------
    iris.cube.Cube
        The month-wise mean of `cube`.
    """
    climatology = cube.aggregated_by("month_number", iris.analysis.MEAN)

    month_points = climatology.coord("month_number").points
    # The aggregated groups must come out in ascending month order.
    assert (np.sort(month_points) == month_points).all()

    # Build the replacement dimension coordinate before stripping dimension 0.
    month_dim_coord = iris.coords.DimCoord.from_coord(climatology.coord("month_number"))
    for stale_coord in climatology.coords(dimensions=0):
        climatology.remove_coord(stale_coord)
    climatology.add_dim_coord(month_dim_coord, 0)

    assert len(climatology.coords(dimensions=0)) == 1
    return climatology

In [None]:
def daily_climatology(cube):
    """Average `cube` per day of year, leaving `day_of_year` as the only dim-0 coord.

    `cube` must already carry a `day_of_year` coordinate (e.g. added with
    `add_day_of_year`). The aggregated dimension is assumed to be dimension 0.

    NOTE(review): if the time coordinate uses a calendar with leap years, dates after
    Feb 28 map to different day numbers in leap vs. non-leap years, and day 366 only
    exists in leap years — confirm this is acceptable for the climatology.

    Returns
    -------
    iris.cube.Cube
        The day-of-year-wise mean of `cube`.
    """
    daily_cube = cube.aggregated_by("day_of_year", iris.analysis.MEAN)
    doy_coord = daily_cube.coord("day_of_year")

    # Groups must be in ascending day order.
    is_sorted = (np.sort(doy_coord.points) == doy_coord.points).all()
    assert is_sorted

    replacement = iris.coords.DimCoord.from_coord(doy_coord)

    # Strip every coordinate on dimension 0, then re-attach `day_of_year` as the dim coord.
    for existing in daily_cube.coords(dimensions=0):
        daily_cube.remove_coord(existing)
    daily_cube.add_dim_coord(replacement, 0)

    assert len(daily_cube.coords(dimensions=0)) == 1
    return daily_cube

In [None]:
def frac_dim_check(cube_2d):
    """Return True if `cube_2d` has more than 2 dims and its third-from-last axis has length 13.

    13 matches the number of surface types `load_data` takes from `frac`
    (`frac_cube[..., :13, :, :]`), so this presumably detects a per-PFT axis.
    """
    shape = cube_2d.shape
    return bool(len(shape) > 2 and shape[-3] == 13)

In [None]:
def _add_lat_lon_coords(cube):
    """Attach the global `lat_coord`/`lon_coord` over the last two dims of `cube` if absent."""
    ndim = len(cube.shape)
    for coord_name, coord in (("latitude", lat_coord), ("longitude", lon_coord)):
        if not cube.coords(coord_name):
            cube.add_aux_coord(coord, (ndim - 2, ndim - 1))


def _frac_weighted_mean(raw_cube, frac_cube):
    """Collapse axis -3 of `raw_cube`, weighting by the first 13 surface types of `frac_cube`."""
    pft_frac_cube = frac_cube[..., :13, :, :]
    assert pft_frac_cube.shape == raw_cube.shape
    assert frac_cube.shape[-3] == 17

    # Slice once and reuse the lazy array for both numerator and denominator.
    pft_frac = pft_frac_cube.lazy_data()
    # NOTE(review): where all 13 weights are zero this divides by zero (NaN/inf result).
    weighted_mean = da.sum(pft_frac * raw_cube.lazy_data(), axis=-3) / da.sum(
        pft_frac, axis=-3
    )

    # 'Select' the first PFT in order to have a template for the resulting cube, then fill
    # the template with the actual weighted mean.
    # NOTE(review): any scalar coordinate left over from that selection still describes
    # index 0 of axis -3, not the weighted mean — confirm nothing downstream relies on it.
    return raw_cube[..., 0, :, :].copy(data=weighted_mean)


def load_data(name, data, frac_cube=None):
    """Load one JULES variable from `source_file` and store derived cubes in `data`.

    Parameters
    ----------
    name : str
        Variable key; only ``"frac"`` may be loaded without `frac_cube`.
    data : dict
        Must contain ``"var_name"`` (used as the load constraint). Updated in place with
        ``"raw_cube"``, ``"avg_cube"`` (time mean), ``"mon_avg_cube"``, ``"day_avg_cube"``
        and ``"avg_cube_2d"``. When a frac-weighted cube is computed, the same set of keys
        is added with a ``"w_frac_"`` prefix (``"w_frac_raw_cube"``, ...,
        ``"w_frac_avg_cube_2d"``).
    frac_cube : iris.cube.Cube, optional
        Raw `frac` cube whose first 13 surface types weight variables that pass
        `frac_dim_check`.

    Raises
    ------
    RuntimeError
        If `frac_cube` is missing for any variable other than ``"frac"``.

    Uses the globals `source_file`, `lat_coord` and `lon_coord`, and prints the variable
    name plus the shapes of the time-mean cubes before/after 2D conversion.
    """
    if frac_cube is None and name != "frac":
        raise RuntimeError(
            "We should only miss `frac_cube` if 'frac' itself is being loaded."
        )

    raw_cube = iris.load_cube(source_file, constraint=data["var_name"])
    data["raw_cube"] = raw_cube
    add_month_number(raw_cube, "time")
    add_day_of_year(raw_cube, "time")
    _add_lat_lon_coords(raw_cube)

    # NOTE(review): `frac_dim_check` only looks at the length of axis -3, so a variable
    # without a PFT axis whose time axis happens to have length 13 would also match.
    if frac_cube is not None and frac_dim_check(raw_cube):
        data["w_frac_raw_cube"] = _frac_weighted_mean(raw_cube, frac_cube)

    prefixes = ("", "w_frac_") if "w_frac_raw_cube" in data else ("",)

    # Time mean plus monthly and daily climatologies for every available raw cube.
    for prefix in prefixes:
        source = data[f"{prefix}raw_cube"]
        data[f"{prefix}avg_cube"] = source.collapsed("time", iris.analysis.MEAN)
        data[f"mon_{prefix}avg_cube"] = monthly_climatology(source)
        data[f"day_{prefix}avg_cube"] = daily_climatology(source)

    print(name)

    # Convert the time-mean cubes with `cube_1d_to_2d` (only done for the time means).
    for prefix in prefixes:
        avg_name = f"{prefix}avg_cube"
        assert len(data[avg_name].shape) >= 2
        data[f"{avg_name}_2d"] = cube_1d_to_2d(data[avg_name])
        print(data[avg_name].shape, data[f"{avg_name}_2d"].shape)

In [None]:
# JULES variables to analyse. `var_name` is the load constraint used by `load_data`;
# `label` and `name` are descriptive strings.
variables = {
    "pft_lai": {
        "var_name": "PFT leaf area index",
        "label": "LAI (1)",
        "name": "JULES LAI",
    },
    "pft_fapar": {
        "var_name": "PFT Fraction of Absorbed Photosynthetically Active Radiation",
        "label": "FAPAR (1)",
        "name": "JULES FAPAR",
    },
    "frac": {
        "var_name": "Fractional cover of each surface type",
        "label": "1",
        "name": "JULES Frac",
    },
    "pft_gpp": {
        "var_name": "PFT gross primary productivity",
        "label": "GPP",
        "name": "JULES GPP",
    },
    "pft_npp": {
        "var_name": "PFT net primary productivity prior to N limitation",
        "label": "NPP",
        "name": "JULES NPP",
    },
    "ba": {
        "var_name": "Gridbox mean burnt area fraction",
        "label": "BA",
        "name": "JULES BA",
    },
}

# `frac` goes first: `load_data` needs its raw cube for every other variable.
load_data("frac", variables["frac"])

for name, data in variables.items():
    if name != "frac":
        load_data(name, data, variables["frac"]["raw_cube"])

In [None]:
# Inspect the shapes of the array-like entries (cubes) stored for PFT LAI.
for name, val in variables["pft_lai"].items():
    if not hasattr(val, "shape"):
        continue
    print(name, val.shape)

In [None]:
# Observational reference climatologies, keyed by the upper-cased suffix of the matching
# JULES variable name (e.g. "pft_lai" -> "LAI", "ba" -> "BA").
ref_variables = dict(
    LAI=dict(
        filename=Path(DATA_DIR) / "LAI_climatology.nc",
    ),
    FAPAR=dict(
        filename=Path(DATA_DIR) / "FAPAR_climatology.nc",
    ),
    BA=dict(
        filename=Path(DATA_DIR) / "GFED4_climatology.nc",
    ),
)

# The N96e land mask does not depend on the dataset, so invert it once (True = ocean).
n96e_ocean_mask = ~get_n96e_land_mask()

for name, data in ref_variables.items():
    cube_2d = regrid_to_n96e(iris.load_cube(str(data["filename"])))
    # Mask ocean points on top of any existing mask. Rebuilding the masked array (instead
    # of `cube_2d.data.mask |= ...`) also works when the regridded data is a plain
    # ndarray, which has no `.mask` attribute.
    cube_2d.data = np.ma.masked_array(
        cube_2d.data,
        mask=np.ma.getmaskarray(cube_2d.data)
        | match_shape(n96e_ocean_mask, cube_2d.shape),
    )
    data["mon_avg_cube_2d"] = cube_2d

In [None]:
def extract_lat_lon(cube, latitude, longitude, atol=1e-8):
    """Extract the part of `cube` at the given latitude and longitude.

    Parameters
    ----------
    cube : iris.cube.Cube
        Cube with `latitude` and `longitude` coordinates. Callers in this notebook index
        away a dimension first so these coordinates are not multidimensional (presumably
        because iris constraints cannot be applied to multidimensional coordinates).
    latitude, longitude : float
        Target values, typically snapped onto the grid with `get_closest`.
    atol : float, optional
        A point matches when its coordinate lies in the open interval
        ``(value - atol, value + atol)``. Defaults to 1e-8.

    Returns
    -------
    iris.cube.Cube
        The matching sub-cube.

    Raises
    ------
    ValueError
        If no point matches the latitude, or none of those matches the longitude
        (previously this surfaced as an opaque AttributeError/TypeError on ``None``).
    """
    lat_constraint = iris.Constraint(
        latitude=lambda cell: latitude - atol < cell.point < latitude + atol
    )
    lon_constraint = iris.Constraint(
        longitude=lambda cell: longitude - atol < cell.point < longitude + atol
    )

    # `Cube.extract` returns None when nothing matches.
    lat_match = cube.extract(constraint=lat_constraint)
    if lat_match is None:
        raise ValueError(f"No point within {atol} of latitude {latitude}.")

    lat_lon_match = lat_match.extract(constraint=lon_constraint)
    if lat_lon_match is None:
        raise ValueError(
            f"No point within {atol} of longitude {longitude} at latitude {latitude}."
        )
    return lat_lon_match

In [None]:
# Site for the per-PFT time series below: the values in `lat_coord`/`lon_coord` closest
# to (8, 18).
# NOTE(review): latitude and longitude are snapped independently, so the pair is not
# guaranteed to coincide with an actual land point in the output.
latitude = get_closest(8, lat_coord.points.ravel())
longitude = get_closest(18, lon_coord.points.ravel())

# Map showing where the selected grid point lies, titled like the per-site maps below.
fig, ax = plt.subplots(subplot_kw=dict(projection=ccrs.Robinson()))
ax.plot(longitude, latitude, linestyle="", marker="x", transform=ccrs.PlateCarree())
ax.set_global()
ax.coastlines()
ax.set_title(f"lat={latitude}, lon={longitude}")

for name, data in variables.items():
    # Only the PFT-resolved variables are plotted per PFT.
    if "pft" not in name:
        continue

    # (full name, acronym) for each vegetation PFT in `PFTs.VEG13`; one panel each.
    pft_labels = list(zip(pft_names[PFTs.VEG13], pft_acronyms[PFTs.VEG13]))

    fig, axes = plt.subplots(3, 5, figsize=(14, 7))
    axes = axes.ravel()
    # Hide the panels left over after one per PFT, instead of hard-coding the last two
    # (which is only right for exactly 13 PFTs on a 3x5 grid).
    for ax in axes[len(pft_labels) :]:
        ax.axis("off")

    for i, (ax, (pft_name, pft_acr)) in enumerate(zip(axes, pft_labels)):
        ax.set_title(f"{pft_name} ({pft_acr})")
        ax.set_xlabel("day of year")

        # Collect handles from both y-axes so a single legend covers both lines.
        lines = []
        labels = []

        for plot_ax, cube, label, c in (
            (ax, data["day_avg_cube"], name, "C0"),
            # `frac` is drawn against its own (right-hand) y-axis.
            (ax.twinx(), variables["frac"]["day_avg_cube"], "frac", "C1"),
        ):
            plot_ax.plot(
                cube.coord("day_of_year").points,
                # Select PFT `i` and index 0 of the second-to-last dimension, presumably
                # leaving lat/lon attached to the last dimension only.
                extract_lat_lon(cube[..., i, 0, :], latitude, longitude).data,
                linestyle="--",
                marker="x",
                alpha=0.2,
                label=label,
                c=c,
            )
            new_lines, new_labels = plot_ax.get_legend_handles_labels()
            lines.extend(new_lines)
            labels.extend(new_labels)

        ax.legend(lines, labels)

    fig.tight_layout(rect=[0, 0, 1, 0.95])
    fig.suptitle(name)

In [None]:
# JULES output is exported every 4 model timesteps: the antecedent averaging below runs on
# data interpolated to model-timestep resolution, then is subsampled back onto output times.
TIMESTEPS_PER_OUTPUT = 4
# (alpha, legend-label suffix) pairs for the antecedent fuel build-up metric.
ANTEC_ALPHAS = (
    (4.6e-4, ""),
    (1e-4, "*"),
    (1e-3, "+"),
)
# Seconds in a 30-day month, used to rescale the observed BA climatology.
# NOTE(review): the observed monthly values are positioned using 365/12-day months but
# rescaled with 30-day months — confirm which month length (and units) is intended.
SECONDS_PER_MONTH = 30 * 24 * 60 * 60

# Every variable except `frac` gets its own panel; the layout is the same for all sites.
plot_vars = {name: data for name, data in variables.items() if name != "frac"}
N = len(plot_vars)
ncols = math.ceil(N ** 0.5)
nrows = math.ceil(N / ncols)
plot_kwargs = dict(linestyle="--", marker="x", alpha=0.2)

for olat, olon in (
    [8, 18],
    [-18.6, 131.5],
    [-0.3, 20.6],
    [-20, 307],
):
    # NOTE(review): latitude and longitude are snapped independently, so the pair is not
    # guaranteed to be an actual land point; the extraction below then finds no match.
    latitude = get_closest(olat, lat_coord.points.ravel())
    longitude = get_closest(olon, lon_coord.points.ravel())

    # Map showing where the selected grid point lies.
    fig, ax = plt.subplots(subplot_kw=dict(projection=ccrs.Robinson()))
    ax.plot(longitude, latitude, linestyle="", marker="x", transform=ccrs.PlateCarree())
    ax.set_global()
    ax.coastlines()
    ax.set_title(f"lat={latitude}, lon={longitude}")

    # `squeeze=False` keeps `axes` 2D even for a 1x1 grid, so `.ravel()` always works.
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=np.array([2.5, 3]) * np.array([ncols, nrows]),
        squeeze=False,
    )
    axes = axes.ravel()
    # Hide any spare panels beyond the N variables.
    for ax in axes[N:]:
        ax.axis("off")

    for ax, (name, data) in zip(axes, plot_vars.items()):
        ax.set_title(name)

        xs = data["day_avg_cube"].coord("day_of_year").points
        # Prefer the frac-weighted daily climatology where `load_data` computed one.
        points = extract_lat_lon(
            data.get("day_w_frac_avg_cube", data["day_avg_cube"])[..., 0, :],
            latitude,
            longitude,
        ).data

        ax.plot(xs, points, label="JULES", **plot_kwargs)

        # Antecedent fuel build-up metric. The averaging is repeated in order to reach
        # convergence for a more realistic depiction of the averaged parameter.
        if name != "ba":
            for alpha, suffix in ANTEC_ALPHAS:
                antecedent = exponential_average(
                    temporal_nearest_neighbour_interp(points, TIMESTEPS_PER_OUTPUT),
                    alpha,
                    repetitions=10,
                )[::TIMESTEPS_PER_OUTPUT]
                ax.plot(xs, antecedent, label=f"Antec JULES{suffix}", **plot_kwargs)

        # Overlay observations when a reference dataset exists for this variable.
        obs_name = name.split("_")[-1].upper()
        if obs_name in ref_variables:
            obs_scale = SECONDS_PER_MONTH if name == "ba" else 1
            ax.plot(
                # Place each monthly value at the middle of its (mean-length) month.
                (365 / 12) * (0.5 + np.arange(12)),
                extract_lat_lon(
                    ref_variables[obs_name]["mon_avg_cube_2d"],
                    latitude,
                    longitude,
                ).data
                / obs_scale,
                linestyle="-.",
                marker="o",
                alpha=0.6,
                label="Obs",
            )

        ax.legend()

    fig.tight_layout()