In [None]:
import calendar

import iris
import numpy as np
from iris.time import PartialDateTime
from wildfires.analysis import *

In [None]:
from wildfires.data import *
from wildfires.utils import *

from empirical_fire_modelling.utils import tqdm

In [None]:
def _persistent_gap_filling(cube, thres=0.5, verbose=False):
    """Fill gaps >= (thres * 100)% of months with minimum value at that location.

    For each calendar month, the fraction of that month's time steps at which a
    location is masked is computed. Where this fraction is at least `thres`
    (within a tolerance of 1e-5), every masked sample of that month at that
    location is replaced by the minimum over all time steps at that location.
    Locations that are masked at every time step are never filled.

    This is done in-place. The leading dimension of `cube` is indexed as the
    time dimension.

    Args:
        cube (iris.cube.Cube): Cube to fill. A "month_number" coordinate derived
            from "time" is added if the cube does not already have one.
        thres (float): Minimum fraction of masked samples (per calendar month)
            for a location to be filled.
        verbose (bool): If True, show a progress bar over the months.

    Returns:
        tuple: `(cube, fill_masks)`, where `cube` is the input cube (modified
            in-place) and `fill_masks` maps each month number in [1, 12] to a
            list of boolean arrays, one per time step in that month (in time
            order), marking the elements that were filled. The list is empty
            for months which do not occur in the cube.

    """
    if not cube.coords("month_number"):
        # Import the function itself: `import iris` alone does not guarantee
        # that the submodule is loaded, and `import iris.coord_categorisation`
        # here would turn `iris` into a function-local name.
        from iris.coord_categorisation import add_month_number

        add_month_number(cube, "time")

    data = cube.data

    # Snapshot of the original mask as a full boolean array. `getmaskarray` also
    # handles unmasked data, for which `.mask` is missing or a scalar `nomask`.
    invalid = np.ma.getmaskarray(data).copy()

    # Locations that are never valid.
    combined_mask = np.all(invalid, axis=0)

    # Computed before any filling takes place.
    min_data = cube.collapsed("time", iris.analysis.MIN).data

    # Month numbers in [1, 12].
    month_numbers = cube.coord("month_number").points

    fill_masks = {}

    for month_number in tqdm(range(1, 13), desc="Months", disable=not verbose):
        month_indices = np.where(month_numbers == month_number)[0]

        fill_masks[month_number] = []

        if month_indices.size == 0:
            # This calendar month is absent from the cube.
            continue

        # Fancy indexing keeps the time axis even for a single matching time
        # step, so the reduction is always over time.
        missing_frac = invalid[month_indices].sum(axis=0) / month_indices.size
        persistent = ((missing_frac + 1e-5) >= thres) & ~combined_mask

        for month_index in month_indices:
            fill_mask = persistent & invalid[month_index]

            month_data = data[month_index]
            month_data[fill_mask] = min_data[fill_mask]

            # Write back explicitly so that both data and mask are updated.
            data[month_index] = month_data

            fill_masks[month_number].append(fill_mask)

    return cube, fill_masks

### Test how many times filling is done for each month

In [None]:
# For every dataset / variable pair, apply the persistent gap filling and plot,
# per calendar month, the fraction of that month's time steps that were filled.
for dataset, var_name in [
    (Ext_MOD15A2H_fPAR(), "FAPAR"),
]:
    selected_d = Datasets(dataset).select_variables(var_name).dataset
    selected_d.limit_months(
        PartialDateTime(year=2008, month=1), PartialDateTime(year=2015, month=4)
    )
    cube = selected_d.cube
    # Sanity check on the length of the leading (presumably time) dimension.
    assert cube.shape[0] > 10
    filled, fill_masks = _persistent_gap_filling(cube, verbose=True)

    # Averaging the boolean fill masks of a month over its time steps yields
    # the fraction of samples that were filled at each location.
    fill_masks_sum = {
        month_number: np.stack(month_masks).mean(axis=0)
        for month_number, month_masks in fill_masks.items()
    }
    for month_number, filled_fraction in fill_masks_sum.items():
        cube_plotting(
            filled_fraction,
            title=f"Fraction Nr. filled for {var_name} {calendar.month_abbr[month_number]}",
        )