In [None]:
import calendar
import logging
import sys
import warnings

import cartopy.crs as ccrs
import iris
import iris.coord_categorisation
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from iris.time import PartialDateTime
from loguru import logger as loguru_logger
from wildfires.data import Datasets, Ext_MOD15A2H_fPAR
from wildfires.utils import match_shape

from empirical_fire_modelling.logging_config import enable_logging
from empirical_fire_modelling.plotting import cube_plotting, map_figure_saver
from empirical_fire_modelling.utils import tqdm

# Plot styling comes from the matplotlibrc file one directory up.
mpl.rc_file("../matplotlibrc")

# loguru: enable the "alepython" logger, then replace the existing sink(s) with
# a single stderr sink at WARNING level.
loguru_logger.enable("alepython")
loguru_logger.remove()
loguru_logger.add(sys.stderr, level="WARNING")

# Standard-library logging for this notebook, also at WARNING level.
logger = logging.getLogger(__name__)
enable_logging(level="WARNING")

# Silence known, noisy warnings (message regexes are matched by `warnings`).
for _ignored_message in (
    ".*Collapsing a non-contiguous coordinate.*",
    ".*DEFAULT_SPHERICAL_EARTH_RADIUS.*",
    ".*guessing contiguous bounds.*",
    'Setting feature_perturbation = "tree_path_dependent".*',
):
    warnings.filterwarnings("ignore", _ignored_message)

In [None]:
def _persistent_gap_filling(cube, thres=0.5, verbose=False):
    """Fill gaps >= (thres * 100)% of months with minimum value at that location.

    For each calendar month, a location is a "persistent" gap if the fraction
    of that month's samples in which it is masked is at least `thres`. Masked
    samples at persistent locations are overwritten with the minimum over the
    whole time axis at that location. Locations that are masked at every time
    step are never filled.

    This is done in-place.

    Args:
        cube: Cube to fill. Time is assumed to be the leading dimension (all
            reductions and indexing below use axis 0). A "month_number"
            coordinate is added from "time" if it is not already present.
        thres (float): Minimum masked fraction (per calendar month) for a
            location to count as a persistent gap.
        verbose (bool): If True, show a progress bar over the 12 months.

    Returns:
        tuple: `(cube, fill_masks)`, where `cube` is the (modified) input cube
        and `fill_masks` maps each month number in [1, 12] to a list of boolean
        arrays, one per time step belonging to that month (in time order),
        marking the elements that were filled. Months without any samples map
        to an empty list.

    """
    if not cube.coords("month_number"):
        iris.coord_categorisation.add_month_number(cube, "time")

    data = cube.data
    # Snapshot of the original mask as a full boolean array. `getmaskarray`
    # also copes with unmasked data (scalar `nomask` or a plain ndarray), for
    # which `data.mask` would not be usable with `axis=0`.
    invalid = np.ma.getmaskarray(data).copy()

    # Locations which are never valid; there is nothing to fill these with.
    combined_mask = np.all(invalid, axis=0)

    min_data = cube.collapsed("time", iris.analysis.MIN).data

    # Month numbers in [1, 12].
    month_numbers = cube.coord("month_number").points

    fill_masks = {}

    for month_number in tqdm(range(1, 13), desc="Months", disable=not verbose):
        # Index the mask array directly instead of extracting a sub-cube:
        # constraint extraction drops the time dimension when only a single
        # sample matches (and yields None when nothing matches), which would
        # break the per-month reduction along axis 0.
        month_indices = np.where(month_numbers == month_number)[0]

        fill_masks[month_number] = []
        if month_indices.size == 0:
            continue

        missing_frac = invalid[month_indices].sum(axis=0) / month_indices.size
        # Small offset so that fractions sitting exactly at the threshold are
        # not lost to floating point rounding (presumed intent).
        persistent = (missing_frac + 1e-5) >= thres
        persistent[combined_mask] = False

        for month_index in month_indices:
            fill_mask = persistent & invalid[month_index]

            month_data = data[month_index]
            month_data[fill_mask] = min_data[fill_mask]

            # Write back explicitly in case the slice above did not share its
            # mask with the parent array.
            data[month_index] = month_data

            fill_masks[month_number].append(fill_mask)

    return cube, fill_masks

### Test how many times filling is done for each month

In [None]:
# Maps variable name -> {month number -> per-location filled fraction}.
fill_sums_data = {}
for dataset, var_name in [(Ext_MOD15A2H_fPAR(), "FAPAR")]:
    selected_d = Datasets(dataset).select_variables(var_name).dataset
    selected_d.limit_months(
        PartialDateTime(year=2008, month=1), PartialDateTime(year=2015, month=4)
    )
    cube = selected_d.cube
    # Sanity check on the number of time steps.
    assert cube.shape[0] > 10
    # Note that the cube is modified in-place by the gap filling.
    filled, fill_masks = _persistent_gap_filling(cube, verbose=True)
    # Despite the name, this holds the mean of the boolean fill masks, i.e.
    # the fraction of each month's samples that were filled at each location.
    fill_masks_sum = {
        month_number: np.stack(month_masks).mean(axis=0)
        for month_number, month_masks in fill_masks.items()
    }
    fill_sums_data[var_name] = fill_masks_sum

In [None]:
# One figure per variable: a 4 x 3 grid of maps (one per month) showing the
# fraction of samples that were gap-filled, with a single shared colorbar.
# Requires `matplotlib.pyplot` to be imported as `plt` in the imports cell.
for var_name, fill_masks_sum in fill_sums_data.items():
    fig = plt.figure(figsize=(10, 8))
    axes = []

    # Thin horizontal colorbar axes, centred along the bottom of the figure.
    cax_height = 0.01
    cax_width = 0.3
    cax = fig.add_axes([0.5 - cax_width / 2, 0, cax_width, cax_height])

    # The map axes share the remaining height (4 rows) and full width (3 cols).
    height = (1 - cax_height) / 4
    width = 1 / 3
    width_pad = 0.008

    # Axes are created row by row from the top, left to right within a row.
    for y in np.linspace(1, cax_height, 4, endpoint=False):
        for x in np.linspace(0, 1, 3, endpoint=False):
            axes.append(
                fig.add_axes(
                    [x + width_pad, y - height, width - 2 * width_pad, height],
                    projection=ccrs.Robinson(),
                )
            )

    for month_number, ax in zip(fill_masks_sum, axes):
        # Only one panel drives the shared colorbar. With months in order
        # 1..12, month 11 is the middle panel of the bottom row, directly
        # above the colorbar axes.
        if month_number != 11:
            colorbar_kwargs = False
        else:
            colorbar_kwargs = dict(
                label="Filled Fraction (1)",
                cax=cax,
                orientation="horizontal",
                format="%.1f",
            )
        cube_plotting(
            fill_masks_sum[month_number],
            title="",
            ax=ax,
            boundaries=np.linspace(0, 1, 11),
            colorbar_kwargs=colorbar_kwargs,
        )
        ax.set_title(calendar.month_abbr[month_number])
    map_figure_saver.save_figure(fig, f"{var_name}_nr_filled_per_month")