In [None]:
import warnings

import cartopy.crs as ccrs
import iris
import matplotlib.pyplot as plt
import numpy as np

from wildfires.analysis import cube_plotting
from wildfires.data import ESA_CCI_Landcover_PFT, Ext_ESA_CCI_Landcover_PFT
from wildfires.logging_config import enable_logging
from wildfires.utils import get_land_mask, match_shape

# Suppress warnings whose message starts with the pattern below (presumably
# raised when a non-contiguous time coordinate is collapsed during plotting).
warnings.filterwarnings(
    action="ignore",
    message="Collapsing a non-contiguous coordinate.*time'.",
)

# Configure the package's logging in its "jupyter" mode.
enable_logging(mode="jupyter")

In [None]:
def mask_water(cube):
    """Mask the non-land cells of ``cube`` in place and return the same cube.

    If ``cube.data`` carries no ``mask`` attribute it is first replaced by a
    masked array with an all-False mask. The inverse of the land mask,
    reshaped to ``cube.shape`` by ``match_shape``, is then OR-ed into the
    existing mask, so cells that were already masked stay masked.

    Note: the cube is modified in place; callers that need the unmasked data
    afterwards must pass a copy. Assumes ``get_land_mask()`` is True over
    land -- TODO confirm against ``wildfires.utils``.

    Args:
        cube: Object exposing ``data`` (a numpy array) and ``shape``.

    Returns:
        The same ``cube`` object, with the updated mask.

    Raises:
        AssertionError: If ``cube.data`` is not a numpy array.
    """
    values = cube.data
    assert isinstance(values, np.ndarray)

    if not hasattr(values, "mask"):
        # Promote plain arrays so that there is a mask to update below.
        empty_mask = np.zeros_like(values, dtype=np.bool_)
        cube.data = np.ma.MaskedArray(values, mask=empty_mask)

    not_land = ~match_shape(get_land_mask(), cube.shape)
    cube.data.mask |= not_land
    return cube

In [None]:
# Instantiate the original and the extended land-cover PFT datasets, then
# display both as the cell output.
lc, ext_lc = ESA_CCI_Landcover_PFT(), Ext_ESA_CCI_Landcover_PFT()
lc, ext_lc

In [None]:
# Inspect the cubes held by the original dataset (shown as the cell output).
lc.cubes

In [None]:
# Inspect the cubes held by the extended dataset (shown as the cell output).
ext_lc.cubes

In [None]:
# Figure-level keyword arguments forwarded to every `plt.subplots` call below.
fig_kwargs = {"figsize": (12, 3.5), "dpi": 200}

### Comparing average PFTs (note the different temporal extents)

In [None]:
# One two-panel figure per raw PFT variable: `lc` on the left, `ext_lc` on the
# right, both on a Robinson projection.
for pft in lc.variable_names("raw"):
    fig, axes = plt.subplots(
        nrows=1,
        ncols=2,
        subplot_kw={"projection": ccrs.Robinson()},
        **fig_kwargs,
    )
    for lc_dataset, ax in zip((lc, ext_lc), axes):
        ax.set_title(lc_dataset.name)
        # Pick out this PFT's cube; `mask_water` masks it in place before it
        # is handed to the plotting helper.
        pft_cube = lc_dataset.cubes.extract_cube(iris.Constraint(pft))
        cube_plotting(mask_water(pft_cube), ax=ax, title=pft)

### Comparing maximum |temporal differences| per location and PFT

In [None]:
# One two-panel figure per raw PFT variable showing, for every location, the
# largest absolute change between consecutive slices along axis 0.
for pft in lc.variable_names("raw"):
    fig, axes = plt.subplots(
        nrows=1,
        ncols=2,
        subplot_kw={"projection": ccrs.Robinson()},
        **fig_kwargs,
    )
    for lc_dataset, ax in zip((lc, ext_lc), axes):
        pft_cube = lc_dataset.cubes.extract_cube(iris.Constraint(pft))

        # |x[t + 1] - x[t]| along the leading axis, then the maximum over it.
        step_changes = np.abs(np.diff(pft_cube.data, axis=0))
        max_step_change = np.max(step_changes, axis=0)

        # The first slice serves as a template cube for the reduced data.
        pft_cube = pft_cube[0].copy(data=max_step_change)

        ax.set_title(lc_dataset.name)
        cube_plotting(mask_water(pft_cube), ax=ax, title=pft)

### Explicitly compare differences between the datasets

In [None]:
# Plot options shared by both difference maps: a diverging colormap that is
# symmetric about zero, with no per-axes title from the plotting helper.
plot_kwargs = dict(title="", cmap="RdBu_r", cmap_midpoint=0, cmap_symmetric=True)

for pft in lc.variable_names("raw"):
    # Select the *current* PFT from both datasets so that the plotted data
    # matches the PFT named in the figure title.
    cube = lc.cubes.extract_cube(iris.Constraint(pft))
    # Truncate the extended dataset to the number of leading-axis (time) steps
    # available in the original dataset.
    ext_cube = ext_lc.cubes.extract_cube(iris.Constraint(pft))[: cube.shape[0]]

    # Ensure the time coordinates are aligned.
    assert (
        ext_cube.coord("time").cell(0).point.year
        == cube.coord("time").cell(0).point.year
    )
    # The element-wise subtraction below requires identical shapes.
    assert ext_cube.shape == cube.shape, (ext_cube.shape, cube.shape)

    # Compute the per-step differences once and reuse them for all statistics.
    diffs = ext_cube.data - cube.data

    max_diffs = np.max(diffs, axis=0)
    min_diffs = np.min(diffs, axis=0)
    # Signed difference of largest magnitude at each location: take the
    # minimum where it is further from zero than the maximum.
    max_abs_diffs = np.where(-min_diffs > max_diffs, min_diffs, max_diffs)

    mean_diffs = np.mean(diffs, axis=0)

    fig, axes = plt.subplots(
        1, 2, subplot_kw=dict(projection=ccrs.Robinson()), **fig_kwargs
    )

    fig.suptitle(f"{pft} ({ext_lc.name} - {lc.name})")

    # `cube[0]` is used as a template cube for the reduced data.
    axes[0].set_title("Mean Diffs")
    cube_plotting(mask_water(cube[0].copy(data=mean_diffs)), ax=axes[0], **plot_kwargs)

    axes[1].set_title("Max Abs Diffs")
    cube_plotting(
        mask_water(cube[0].copy(data=max_abs_diffs)), ax=axes[1], **plot_kwargs
    )