# Fill in missing values using temporal nearest-neighbour (NN) interpolation

### Visualise results using VODCA

#### Gaps in the climatology over the selected timespan are used to seed the interpolation algorithm.

In [None]:
import warnings
from datetime import datetime

import iris
import iris.coord_categorisation
import numpy as np
from dateutil.relativedelta import relativedelta
from tqdm.auto import tqdm

from wildfires.analysis import cube_plotting
from wildfires.data import VODCA, Datasets
from wildfires.logging_config import enable_logging
from wildfires.utils import box_mask, get_land_mask, match_shape

enable_logging("jupyter")
# Silence warnings whose messages match these patterns.
warnings.filterwarnings("ignore", ".*converting a masked element to nan.*")
warnings.filterwarnings("ignore", ".*Collapsing a non-contiguous coordinate.*")

# Restrict the VODCA Ku-band data to the period of interest and make sure the
# cube carries a "month_number" coordinate to select calendar months by.
vodca = Datasets(VODCA()).select_variables("VOD Ku-band")
source = vodca.dataset.copy(deep=True)
source.limit_months(datetime(2010, 1, 1), datetime(2015, 4, 1))
if not source.cube.coords("month_number"):
    iris.coord_categorisation.add_month_number(source.cube, "time")

# The spatial masks do not depend on the month, so build them once instead of
# once per loop iteration.
n_lon = vodca.cubes[0].shape[-1]
land_mask = get_land_mask(n_lon=n_lon)
# Box spanning 60° S to 90° N at all longitudes, i.e. everything south of
# 60° S is ignored.
latitude_mask = box_mask(lats=(-60, 90), lons=(-180, 180), n_lon=n_lon)

# For each calendar month, the per-cell mean of the combined mask along axis 0
# (presumably time), i.e. the fraction of samples that are missing.
source_masks = {}
for month_number in range(1, 13):
    single_months = source.cube.extract(iris.Constraint(month_number=month_number))

    # `getmaskarray` always returns a full boolean array, even when nothing is
    # masked and `.mask` would be the scalar `nomask`.
    raw_mask = np.ma.getmaskarray(single_months.data)

    # Combine with the land mask and the latitude box. A new array is created
    # (no `&=`) so the mask of the extracted cube's data is left untouched.
    raw_mask = (
        raw_mask
        & match_shape(land_mask, raw_mask.shape)
        & match_shape(latitude_mask, raw_mask.shape)
    )

    source_masks[month_number] = np.mean(raw_mask, axis=0)
    _ = cube_plotting(
        source_masks[month_number], title=f"Source Mask Month {month_number}"
    )

In [None]:
# Set up the data.
target = vodca.dataset.copy(deep=True)
source = vodca.dataset.copy(deep=True)

# Set up the source mask (where to interpolate).

# Interpolate if a larger fraction of data than this is missing for a given month.
threshold = 0.6
interpolate_masks = {
    month_number: mask_frac > threshold
    for month_number, mask_frac in source_masks.items()
}

# Start and end date of the final data (inclusive).
target_timespan = (datetime(2010, 1, 1), datetime(2015, 4, 1))

# Number of months allowed to look forward or backward for valid samples.
n_months = 3
# The source data is padded by `n_months` on both sides of the target timespan
# so that every target month has a full window of neighbouring months.
source_timespan = (
    target_timespan[0] - relativedelta(months=n_months),
    target_timespan[1] + relativedelta(months=n_months),
)

# Discard unneeded months.
target.limit_months(*target_timespan)
source.limit_months(*source_timespan)

target_data = target.cube.data
source_data = source.cube.data

# Number of time steps to fill. This is taken from the data itself rather than
# from the difference between the end and start month: that difference is one
# short of the number of months in an inclusive timespan, which would leave the
# final month without interpolation.
target_months = target_data.shape[0]

# Sanity check: source index `i + n_months` must correspond to target index `i`,
# which requires exactly `n_months` extra time steps on either side.
if source_data.shape[0] != target_months + 2 * n_months:
    raise ValueError(
        f"Expected {target_months + 2 * n_months} source time steps "
        f"({target_months} target months padded by {n_months} on each side), "
        f"but got {source_data.shape[0]}."
    )


def temporal_nn(source_data, target_index, interpolate_masks, month_number, n_months):
    """Fill selected grid cells using the temporally nearest valid samples.

    For every cell selected by ``interpolate_masks[month_number]``, the samples
    ``d`` steps before and after the window centre are inspected for
    ``d = 1, ..., n_months``. The smallest ``d`` at which at least one of the
    two samples is unmasked determines the result, which is the mean of the
    unmasked sample(s) at that distance.

    Args:
        source_data: Array whose first axis is indexed by time step and whose
            remaining axes match the interpolation mask. The window
            ``[target_index, target_index + 2 * n_months]`` along the first
            axis is used, with its centre taken as the month being filled.
        target_index (int): Start of the window along the first axis.
        interpolate_masks (dict): Maps month numbers to boolean arrays marking
            the cells to fill.
        month_number: Key into ``interpolate_masks``.
        n_months (int): Maximum number of steps to look backward and forward.

    Returns:
        numpy.ma.MaskedArray: 1D array with one entry per selected cell, in the
        order given by ``np.where``. Entries for which no unmasked sample was
        found within ``n_months`` steps remain masked.

    """
    interpolate_mask = interpolate_masks[month_number]
    n_interp = interpolate_mask.sum()

    # Start fully masked; assigning a value below unmasks that entry.
    monthly_target_data = np.ma.MaskedArray(np.empty(n_interp), mask=True)

    for i, indices in enumerate(
        tqdm(
            zip(*np.where(interpolate_mask)), total=n_interp, leave=False, disable=False
        )
    ):
        adjacent_data = source_data[
            (slice(target_index, target_index + 2 * n_months + 1), *indices)
        ]
        # Always a full boolean array, even if the data carries no mask at all
        # (in which case `.mask` would be a scalar that cannot be indexed).
        adjacent_mask = np.ma.getmaskarray(adjacent_data)

        # Try to find at least one match in the fewest months possible.
        for d in range(1, n_months + 1):
            before, after = n_months - d, n_months + d
            if adjacent_mask[before] and adjacent_mask[after]:
                # All data is masked, so there is no valid data to choose from.
                continue
            # Fill in the missing element. Selecting both samples as an array
            # keeps their mask, so the mean only uses the unmasked value(s).
            monthly_target_data[i] = adjacent_data[[before, after]].mean()
            # Stop looking for matches.
            break
    return monthly_target_data


# Walk through the time steps to fill in chronological order, tracking the
# calendar date alongside the index.
current = target_timespan[0]
for target_index in tqdm(range(target_months)):
    # The calendar month selects which interpolation mask applies.
    month_number = current.month
    filled_values = temporal_nn(
        source_data=source_data,
        target_index=target_index,
        interpolate_masks=interpolate_masks,
        month_number=month_number,
        n_months=n_months,
    )
    # Only the cells flagged for interpolation are overwritten.
    target_data[target_index][interpolate_masks[month_number]] = filled_values
    current = current + relativedelta(months=1)

In [None]:
# Spot check: plot the time step at index 1 of the target data after filling.
cube_plotting(target_data[1])

In [None]:
# Wrap the filled data in a dataset limited to the target timespan.
interpolated = vodca.dataset.copy()
interpolated.limit_months(*target_timespan)
interpolated.cube.data = target_data
# Later cells select calendar months via the "month_number" coordinate. Add it
# if it is missing, using the same guard as for the source cube in the first
# cell, since that coordinate was only added to a separate deep copy.
if not interpolated.cube.coords("month_number"):
    iris.coord_categorisation.add_month_number(interpolated.cube, "time")

In [None]:
# The spatial masks do not depend on the month, so build them once instead of
# once per loop iteration.
n_lon = vodca.cubes[0].shape[-1]
land_mask = get_land_mask(n_lon=n_lon)
# Box spanning 60° S to 90° N at all longitudes, i.e. everything south of
# 60° S is ignored.
latitude_mask = box_mask(lats=(-60, 90), lons=(-180, 180), n_lon=n_lon)

for month_number in range(1, 13):
    single_months = interpolated.cube.extract(
        iris.Constraint(month_number=month_number)
    )

    # `getmaskarray` always returns a full boolean array, even when nothing is
    # masked and `.mask` would be the scalar `nomask`.
    raw_mask = np.ma.getmaskarray(single_months.data)

    # Combine with the land mask and the latitude box. A new array is created
    # (no `&=`) so the mask of the extracted cube's data is left untouched.
    raw_mask = (
        raw_mask
        & match_shape(land_mask, raw_mask.shape)
        & match_shape(latitude_mask, raw_mask.shape)
    )

    # Plot the cells whose combined mask is set in every sample of this month.
    _ = cube_plotting(
        np.isclose(np.mean(raw_mask, axis=0), 1),
        title=f"Interpolated Mask Month {month_number}",
    )

In [None]:
# NOTE(review): this cell repeats the preceding cell verbatim (same plots, same
# titles) - confirm whether both are needed.

# The spatial masks do not depend on the month, so build them once instead of
# once per loop iteration.
n_lon = vodca.cubes[0].shape[-1]
land_mask = get_land_mask(n_lon=n_lon)
# Box spanning 60° S to 90° N at all longitudes, i.e. everything south of
# 60° S is ignored.
latitude_mask = box_mask(lats=(-60, 90), lons=(-180, 180), n_lon=n_lon)

for month_number in range(1, 13):
    single_months = interpolated.cube.extract(
        iris.Constraint(month_number=month_number)
    )

    # `getmaskarray` always returns a full boolean array, even when nothing is
    # masked and `.mask` would be the scalar `nomask`.
    raw_mask = np.ma.getmaskarray(single_months.data)

    # Combine with the land mask and the latitude box. A new array is created
    # (no `&=`) so the mask of the extracted cube's data is left untouched.
    raw_mask = (
        raw_mask
        & match_shape(land_mask, raw_mask.shape)
        & match_shape(latitude_mask, raw_mask.shape)
    )

    # Plot the cells whose combined mask is set in every sample of this month.
    _ = cube_plotting(
        np.isclose(np.mean(raw_mask, axis=0), 1),
        title=f"Interpolated Mask Month {month_number}",
    )