### Setup

In [None]:
import concurrent.futures
import os
import warnings
from datetime import datetime
from functools import reduce
from pathlib import Path

import iris
import iris.coord_categorisation
import matplotlib as mpl
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.optimize import minimize

from wildfires.analysis import *
from wildfires.data import *
from wildfires.logging_config import enable_logging
from wildfires.qstat import *
from wildfires.utils import *

# `tqdm.auto` is only used when the TQDMAUTO environment variable is set.
if "TQDMAUTO" not in os.environ:
    from tqdm import tqdm
else:
    from tqdm.auto import tqdm

enable_logging("jupyter")
figure_saver = FigureSaver(directories=Path("~") / "tmp" / "interp_comp", debug=True)

# Suppress these known warning messages (registered in this order).
for _ignored_pattern in (
    ".*Collapsing a non-contiguous coordinate.*",
    ".*DEFAULT_SPHERICAL_EARTH_RADIUS.*",
    ".*guessing contiguous bounds.*",
    ".*divide by zero.*",
):
    warnings.filterwarnings("ignore", _ignored_pattern)

mpl.rc("figure", figsize=(12, 6))

In [None]:
# Interpolation setting; also used below to build the names of the temporally
# interpolated variables (e.g. "<var> 3NN").
n_months: int = 3

In [None]:
def harmonic_fit(t, params, period=12):
    """Evaluate a linear trend plus a sum of sinusoidal harmonics.

    The model is

        f(t) = params[0] + params[1] * t
               + sum_j A_j * sin(2 * pi * j * t / period + phi_j),

    with harmonics numbered j = 1, 2, ... and (A_j, phi_j) taken from
    consecutive pairs of `params` starting at index 2.

    Args:
        t (int or array-like): Time index (or indices).
        params (array-like):
            0th - offset
            1st - gradient (per unit of `t`)
            (2j, 2j+1) entries - amplitude and phase of the jth harmonic,
            j >= 1. A trailing unpaired entry is ignored.
        period (float): Period of the fundamental (j = 1) harmonic in units of
            `t`. Defaults to 12, i.e. an annual cycle for a monthly index.

    Returns:
        numpy.ndarray: Fitted function value(s) at `t`, with the shape of `t`
        (0-d for scalar `t`).

    """
    t = np.asarray(t, dtype=np.float64)
    output = np.zeros_like(t, dtype=np.float64)
    output += params[0]
    output += params[1] * t
    for j, (amplitude, phase) in enumerate(zip(params[2::2], params[3::2]), start=1):
        output += amplitude * np.sin((2 * np.pi * j * t / period) + phase)
    return output


def min_fit(x, *args):
    """Objective function used to fit `harmonic_fit` to data.

    Args:
        x (array-like): Fit parameters, laid out as described in `harmonic_fit`.
        args: Exactly two extra positional values, `(ts, fit_data)`: the time
            indices and the corresponding data to fit. If `fit_data` is a
            numpy masked array, masked samples do not contribute to the sum.

    Returns:
        float: Sum of squared residuals (not the mean squared error).

    """
    ts, fit_data = args[0], args[1]
    residuals = fit_data - harmonic_fit(ts, x)
    return np.sum(residuals ** 2.0)


def persistent_gap_filling(cube, combined_mask, thres=0.5):
    """Fill persistent gaps with the minimum value observed at each location.

    For every calendar month, locations where at least `thres` (a fraction,
    i.e. (thres * 100)%) of that month's samples are missing are flagged as
    persistent gaps. The missing samples of that month at flagged locations are
    then replaced by the minimum over the whole time series at that location.

    Args:
        cube (iris.cube.Cube): Cube with a 'time' coordinate on its first axis.
            It is copied, not modified.
        combined_mask (array-like of bool): Spatial mask matching the trailing
            (spatial) dimensions of `cube`; flagged locations are never filled.
        thres (float): Minimum fraction of missing samples for a calendar
            month to count as a persistent gap.

    Returns:
        iris.cube.Cube: Gap-filled copy of `cube`.

    """
    cube = cube.copy()
    if not cube.coords("month_number"):
        iris.coord_categorisation.add_month_number(cube, "time")

    # Snapshot of the originally missing samples. `getmaskarray` also works when
    # the data has no mask array; the copy keeps it independent of the filling.
    invalid = np.ma.getmaskarray(cube.data).copy()
    ignore = np.asarray(combined_mask, dtype=bool)

    min_data = cube.collapsed("time", iris.analysis.MIN).data

    # Month numbers in [1, 12], one per time step.
    month_numbers = cube.coord("month_number").points

    for month_number in tqdm(range(1, 13), desc="Months"):
        # Select the time steps by index rather than via an iris Constraint: an
        # extraction matching a single time step drops the time dimension (so a
        # reduction over axis 0 would hit a spatial axis), and an extraction
        # matching nothing returns None.
        month_indices = np.where(month_numbers == month_number)[0]
        if month_indices.size == 0:
            continue

        missing_frac = np.mean(invalid[month_indices], axis=0)
        # The small offset guards against floating point round-off at `thres`.
        persistent = ((missing_frac + 1e-5) >= thres) & ~ignore

        for month_index in month_indices:
            fill_mask = persistent & invalid[month_index]
            if not np.any(fill_mask):
                continue
            month_data = cube.data[month_index]
            month_data[fill_mask] = min_data[fill_mask]
            cube.data[month_index] = month_data

    return cube


def _season_fill(fill_locs, data, n):
    """Replace masked samples with a fitted season-trend model, per location.

    Args:
        fill_locs: 2D boolean array selecting the (row, column) locations of
            `data` to fill.
        data: Masked array whose first axis is the time index.
        n: Number of `harmonic_fit` parameters to optimise.

    Returns:
        `data`, with missing samples at the selected locations replaced by the
        fitted model values (written through views of `data`).

    """
    time_index = np.arange(data.shape[0])
    rows, cols = np.nonzero(fill_locs)
    for row, col in zip(rows, cols):
        series = data[:, row, col]
        initial_guess = np.zeros(n)
        fit = minimize(min_fit, initial_guess, (time_index, series))
        missing = series.mask
        # Overwrite only the missing entries with the model prediction.
        series[missing] = harmonic_fit(time_index, fit.x)[missing]
    return data


def season_model_filling(cube, n=8, rows_per_chunk=2):
    """Fill remaining gaps by fitting a season-trend model at each location.

    Locations that have some, but not only, valid samples are fitted
    independently with `harmonic_fit` (via `_season_fill`); other locations
    are left untouched. The work is split into chunks of rows (the second
    axis of the data) and distributed over a process pool.

    Args:
        cube (iris.cube.Cube): Cube whose first axis is time. It is copied,
            not modified.
        n (int): Number of `harmonic_fit` parameters to fit.
        rows_per_chunk (int): Number of rows submitted per worker task.

    Returns:
        iris.cube.Cube: Filled copy of `cube`.

    Raises:
        ValueError: If `rows_per_chunk` is smaller than 1.

    """
    if rows_per_chunk < 1:
        raise ValueError(f"rows_per_chunk must be >= 1, got {rows_per_chunk}.")

    cube = cube.copy()

    # `getmaskarray` also works when the data has no mask array at all.
    missing = np.ma.getmaskarray(cube.data)
    # Fill where there is some valid data, but not only valid data, since there
    # would be nothing to fill in the latter case.
    fill_locs = np.any(~missing, axis=0) & np.any(missing, axis=0)

    # Partition the rows of the array in chunks to be processed.
    nrows = cube.shape[1]
    chunk_edges = np.unique(
        np.append(np.arange(0, nrows, rows_per_chunk, dtype=np.int64), nrows)
    )

    with concurrent.futures.ProcessPoolExecutor(max_workers=get_ncpus()) as executor:
        futures = []
        processed_slices = []
        for chunk_s, chunk_e in zip(chunk_edges[:-1], chunk_edges[1:]):
            chunk_slice = slice(chunk_s, chunk_e)
            if not np.any(fill_locs[chunk_slice]):
                # Skip those slices without anything to fill.
                continue
            processed_slices.append(chunk_slice)
            futures.append(
                executor.submit(
                    _season_fill, fill_locs[chunk_slice], cube.data[:, chunk_slice], n
                )
            )

        # Only used to display progress; results are collected in order below.
        for _ in tqdm(
            concurrent.futures.as_completed(futures),
            desc="Season model filling",
            total=len(futures),
        ):
            pass

        for future, chunk_slice in zip(futures, processed_slices):
            cube.data[:, chunk_slice] = future.result()

    return cube

## Create datasets

In [None]:
# Variables analysed throughout this notebook.
variables = ("SWI(1)", "FAPAR", "LAI", "VOD Ku-band", "SIF")

# Instantiate the source datasets and keep only the variables of interest.
datasets = Datasets(
    (
        Copernicus_SWI(),
        MOD15A2H_LAI_fPAR(),
        VODCA(),
        GlobFluo_SIF(),
    )
).select_variables(variables)

### Temporal interpolation

In [None]:
timeperiod = (datetime(2010, 1, 1, 0, 0), datetime(2015, 1, 1, 0, 0))
period_str = f"{timeperiod[0]:%Y-%m} - {timeperiod[1]:%Y-%m}"

# Add a temporally interpolated version of each original dataset. `n_months` is
# passed instead of a literal so the interpolation setting stays in sync with
# the "<var> {n_months}NN" variable names selected further below.
# NOTE(review): assumes the second argument is the setting reflected in those
# names - confirm against `get_temporally_interpolated_dataset`.
orig_datasets = datasets.copy()
for dataset in orig_datasets:
    datasets.add(dataset.get_temporally_interpolated_dataset(timeperiod, n_months))

# Restrict all (original and interpolated) datasets to the analysis period.
for dataset in datasets:
    dataset.limit_months(*timeperiod)

In [None]:
# Display the current collection of datasets as the cell output.
datasets

In [None]:
# Show the datasets using the "pretty" mode of `Datasets.show`.
datasets.show("pretty")

### Calculate climatologies

In [None]:
# Climatology of every selected dataset, spanning that dataset's full time range.
selected_datasets = datasets.select_variables(variables, inplace=False)
climatologies = Datasets(
    [
        selected.get_climatology_dataset(selected.min_time, selected.max_time)
        for selected in tqdm(selected_datasets, desc="Getting climatologies")
    ]
)

### Combined mask

In [None]:
total_masks = []

# Maximum number of missing samples tolerated at a location: up to 6 months for
# each of the 5 complete years, plus 1 for the extra January.
max_missing = 5 * 6 + 1

for var in tqdm(variables, desc="Variable"):
    cube = datasets.select_variables(var, inplace=False).cube.copy()
    # Boolean (time, ...) array of missing samples; `getmaskarray` also works
    # when the data has no mask array at all.
    missing = np.ma.getmaskarray(cube.data)

    # Ignore areas that are always masked, e.g. water.
    ignore_mask = np.all(missing, axis=0)

    # Also ignore those areas with low data availability.
    ignore_mask |= np.sum(missing, axis=0) > max_missing

    total_masks.append(ignore_mask)

In [None]:
# Regrid each variable's ignore mask with `regrid`, then flag a location if it
# is ignored for any variable.
regridded_masks = [regrid(dummy_lat_lon_cube(mask)).data for mask in total_masks]
combined_mask = reduce(np.logical_or, regridded_masks)

### Apply the combined mask to 'fresh' datasets and fill missing data using minima and a season-trend model

In [None]:
# Re-create the datasets from new instances, i.e. without the temporally
# interpolated datasets that were added to `datasets` above.
masked_datasets = Datasets(
    (Copernicus_SWI(), MOD15A2H_LAI_fPAR(), VODCA(), GlobFluo_SIF())
).select_variables(variables)

# Select correct time period and regrid to common grid.
for dataset in masked_datasets:
    dataset.limit_months(*timeperiod)
    dataset.regrid()

# Apply the combined mask.
masked_datasets.apply_masks(combined_mask)

# Retrieve the filled dataset for later comparison.
# NOTE(review): presumably `get_persistent_season_trend_dataset` is the library
# counterpart of `persistent_gap_filling` + `season_model_filling`; the
# comparison cell below checks this.
processed_datasets = Datasets(
    [dataset.get_persistent_season_trend_dataset() for dataset in masked_datasets]
)

### Missing data filling

In [None]:
# Per-variable results of the two filling stages.
gap_filled_cubes = {}
model_filled_cubes = {}

for var in tqdm(variables, desc="Filling variables"):
    # Regrid a copy of the variable's cube and exclude combined-mask locations.
    cube = regrid(datasets.select_variables(var, inplace=False).cube.copy())
    if not cube.coords("month_number"):
        iris.coord_categorisation.add_month_number(cube, "time")
    cube.data.mask |= match_shape(combined_mask, cube.shape)

    # Stage 1: minimum-based filling of persistent gaps.
    gap_filled_cubes[var] = gap_filled_cube = persistent_gap_filling(cube, combined_mask)
    # Stage 2: season-trend model filling of the remaining gaps.
    model_filled_cubes[var] = model_filled_cube = season_model_filling(gap_filled_cube)

### Function comparison - both model-filled datasets should be equivalent

In [None]:
# Compare the notebook's model filling (`model_filled_cubes`) with the output of
# the dataset method (`processed_datasets`) for each variable.
for var in variables:
    orig_cube = masked_datasets.select_variables(var, inplace=False).cube
    old_cube = model_filled_cubes[var]
    # NOTE(review): assumes the filled variable is named "<var> 50P 4k" by the
    # library - confirm.
    new_cube = processed_datasets.select_variables(var + " 50P 4k", inplace=False).cube

    print(var, new_cube == old_cube)
    cube_plotting(old_cube, title=f"{var} old\n{period_str}")
    cube_plotting(new_cube, title=f"{var} new\n{period_str}")
    # Per-location mean absolute difference over time, later divided by the
    # absolute per-location maximum of the unfilled data.
    mean_diffs = np.mean(np.abs(old_cube.data - new_cube.data), axis=0)
    max_vals = np.abs(np.max(orig_cube.data, axis=0))

    cube_plotting(mean_diffs / max_vals, title="Relative Mean |Diffs|")
    # Plot the timeseries at the 20 locations with the largest relative
    # differences (masked entries are sorted as if they were 0).
    for i in (mean_diffs / max_vals).ravel().argsort(fill_value=0)[::-1][:20]:
        plt.figure()
        s = (slice(None), *np.unravel_index(i, old_cube.shape[1:]))
        plt.plot(orig_cube.data[s], label="orig", marker="x")
        plt.plot(old_cube.data[s], label="old")
        plt.plot(new_cube.data[s], label="new")
        plt.legend(loc="best")

### Masks (number of missing samples)

In [None]:
saver = figure_saver(sub_directory="masks")

# For each variable, plot where samples are masked and count masked samples per
# calendar month.
for var in tqdm(variables, desc="Variable"):
    cube = datasets.select_variables(var, inplace=False).cube.copy()
    if not cube.coords("month_number"):
        iris.coord_categorisation.add_month_number(cube, "time")

    # Ignore areas that are always masked, e.g. water.
    ignore_mask = np.all(cube.data.mask, axis=0)

    # Masked-sample indicator (1.0 = masked) for all time steps, hiding the
    # always-masked locations.
    fig = cube_plotting(
        np.ma.MaskedArray(
            cube.data.mask.astype("float64"),
            mask=match_shape(ignore_mask, cube.data.shape),
        ),
        title=f"{var}\n{period_str}",
        boundaries=np.linspace(0, 1, 6),
    )
    saver.save_figure(fig, f"{var} masked samples")

    # Number of masked samples by calendar month.
    counts = {}
    for month_number in tqdm(range(1, 13), desc="Months"):
        extracted = iris.Constraint(month_number=month_number).extract(cube)
        # Masked samples per location for this calendar month.
        ext_mask = np.sum(extracted.data.mask, axis=0)

        fig = cube_plotting(
            np.ma.MaskedArray(ext_mask, mask=match_shape(ignore_mask, ext_mask.shape),),
            title=f"{var} month {month_number}\n{period_str}",
        )
        saver.save_figure(fig, f"{var} month {month_number} masked samples")

        # Total over the locations that are not always masked.
        selected = ext_mask[~match_shape(ignore_mask, ext_mask.shape)]
        counts[month_number] = np.sum(selected)

    with saver(f"{var} masked months"):
        pd.Series(counts).to_frame("masked months").plot.bar(rot=0)

### Ignore locations with too much missing data - resulting masks

In [None]:
period_str = f"{timeperiod[0]:%Y-%m} - {timeperiod[1]:%Y-%m}"

saver = figure_saver(sub_directory="sel masks")

# Per variable: number of missing samples per location and per calendar month,
# ignoring always-masked locations and locations with too many missing samples.
for var in tqdm(variables, desc="Variable"):
    cube = datasets.select_variables(var, inplace=False).cube.copy()
    if not cube.coords("month_number"):
        iris.coord_categorisation.add_month_number(cube, "time")

    # Ignore areas that are always masked, e.g. water.
    ignore_mask = np.all(cube.data.mask, axis=0)

    # Also ignore those areas with low data availability.
    ignore_mask |= np.sum(cube.data.mask, axis=0) > (
        5 * 6 + 1  # Up to 6 months for each of the 5 complete years.  # Extra January.
    )

    # Missing-sample indicator (True = missing), masked at ignored locations.
    nr_inval_cube = cube.copy(
        data=np.ma.MaskedArray(
            cube.data.mask.copy(), mask=match_shape(ignore_mask, cube.shape),
        )
    )

    # Number of missing samples per location.
    nr_inval_cube = nr_inval_cube.collapsed("time", iris.analysis.SUM)

    fig = cube_plotting(
        nr_inval_cube,
        title=f"{var}\n{period_str}",
        colorbar_kwargs={"label": "nr. invalid"},
    )
    saver.save_figure(fig, f"{var} nr masked samples")

    # Number of masked samples by calendar month (non-ignored locations only).
    counts = {}
    for month_number in tqdm(range(1, 13), desc="Months"):
        extracted = iris.Constraint(month_number=month_number).extract(cube)
        ext_mask = np.sum(extracted.data.mask, axis=0)
        selected = ext_mask[~match_shape(ignore_mask, ext_mask.shape)]
        counts[month_number] = np.sum(selected)

    with saver(f"{var} masked months"):
        pd.Series(counts).to_frame("masked months").plot.bar(rot=0)

In [None]:
# Plot combined ignore masks.
fig = cube_plotting(
    combined_mask,
    title=f"Combined Mask\n{period_str}",
    colorbar_kwargs={"label": "masked"},
    boundaries=np.linspace(0, 1, 3),
    fig=plt.figure(figsize=(18, 9)),
)
figure_saver.save_figure(fig, f"combined mask samples")

In [None]:
# Time-mean burned area over the analysis period.
ba_dataset = GFEDv4()
ba_dataset.limit_months(*timeperiod)
mean_ba = ba_dataset.cube.collapsed("time", iris.analysis.MEAN)
# NOTE(review): this *replaces* any existing mask of the mean with the inverted
# land mask (presumably `get_land_mask()` is True over land - confirm); use
# `|=` instead if already-masked values must stay masked.
mean_ba.data.mask = ~get_land_mask()

In [None]:
# Map of mean burned area on logarithmically spaced colour boundaries.
fig = cube_plotting(
    mean_ba,
    title=f"Mean BA\n{period_str}",
    colorbar_kwargs={"label": "BA", "format": "%0.0e"},
    cmap="YlOrRd",
    fig=plt.figure(figsize=(18, 9)),
    boundaries=[1e-5, 1e-4, 1e-3, 1e-2, 1e-1],
    extend="min",
)
figure_saver.save_figure(fig, f"burned area")

In [None]:
# Mean burned area restricted to locations kept by the combined mask.
masked_mean_ba = mean_ba.copy()
# Combine (rather than replace) the existing mask with the combined mask, so
# that locations already masked in `mean_ba` (the non-land mask applied above)
# stay masked.
masked_mean_ba.data.mask = np.ma.getmaskarray(masked_mean_ba.data) | np.asarray(
    combined_mask, dtype=bool
)
fig = cube_plotting(
    masked_mean_ba,
    title=f"Mean BA\n{period_str}",
    colorbar_kwargs={"label": "BA", "format": "%0.0e"},
    cmap="YlOrRd",
    fig=plt.figure(figsize=(18, 9)),
    boundaries=[1e-5, 1e-4, 1e-3, 1e-2, 1e-1],
    extend="min",
)
figure_saver.save_figure(fig, "combined mask burned area")

### Combined mask - missing NH samples

In [None]:
saver = figure_saver(sub_directory="comb mask NH")

# Missing samples in the northern hemisphere (latitude > 0) after applying the
# combined mask.
for var in tqdm(variables, desc="Variable"):
    cube = regrid(datasets.select_variables(var, inplace=False).cube.copy())
    if not cube.coords("month_number"):
        iris.coord_categorisation.add_month_number(cube, "time")

    # Missing-sample indicator (True = missing), masked where the combined
    # mask applies.
    nr_inval_cube = cube.copy(
        data=np.ma.MaskedArray(
            cube.data.mask, mask=match_shape(combined_mask, cube.shape),
        )
    )

    # Strictly positive latitudes only (the equator is excluded).
    nr_inval_cube = iris.Constraint(
        coord_values={"latitude": lambda cell: 0 < cell}
    ).extract(nr_inval_cube)

    fig = cube_plotting(
        nr_inval_cube.collapsed("time", iris.analysis.SUM),
        title=f"{var}\n{period_str}",
        colorbar_kwargs={"label": "nr. invalid"},
    )
    saver.save_figure(fig, f"{var} nr masked samples")

    # Number of missing samples by calendar month.
    counts = {}
    for month_number in tqdm(range(1, 13), desc="Months"):
        extracted = iris.Constraint(month_number=month_number).extract(nr_inval_cube)
        counts[month_number] = np.sum(extracted.data)

    with saver(f"{var} masked months"):
        pd.Series(counts).to_frame("masked months").plot.bar(rot=0)

### Combined mask - missing SH samples

In [None]:
saver = figure_saver(sub_directory="comb mask SH")

# Missing samples in the southern hemisphere (latitude < 0) after applying the
# combined mask.
for var in tqdm(variables, desc="Variable"):
    cube = regrid(datasets.select_variables(var, inplace=False).cube.copy())
    if not cube.coords("month_number"):
        iris.coord_categorisation.add_month_number(cube, "time")

    # Missing-sample indicator (True = missing), masked where the combined
    # mask applies.
    nr_inval_cube = cube.copy(
        data=np.ma.MaskedArray(
            cube.data.mask, mask=match_shape(combined_mask, cube.shape),
        )
    )

    # Strictly negative latitudes only (the equator is excluded).
    nr_inval_cube = iris.Constraint(
        coord_values={"latitude": lambda cell: 0 > cell}
    ).extract(nr_inval_cube)

    fig = cube_plotting(
        nr_inval_cube.collapsed("time", iris.analysis.SUM),
        title=f"{var}\n{period_str}",
        colorbar_kwargs={"label": "nr. invalid"},
    )
    saver.save_figure(fig, f"{var} nr masked samples")

    # Number of missing samples by calendar month.
    counts = {}
    for month_number in tqdm(range(1, 13), desc="Months"):
        extracted = iris.Constraint(month_number=month_number).extract(nr_inval_cube)
        counts[month_number] = np.sum(extracted.data)

    with saver(f"{var} masked months"):
        pd.Series(counts).to_frame("masked months").plot.bar(rot=0)

In [None]:
def minimisation(ts, sel, n=8, plot=False, sel2=None, sel3=None):
    """Fit `harmonic_fit` to a single timeseries and optionally plot the result.

    Args:
        ts (array-like): Time indices of the samples in `sel`.
        sel (array-like): Data to fit; masked samples are ignored by the fit.
        n (int): Number of `harmonic_fit` parameters.
        plot (bool): If True, plot the data, the fitted model and reference
            lines every 12 time steps.
        sel2 (array-like, optional): Raw series, plotted (and summarised in the
            title) if given.
        sel3 (array-like, optional): Season-model-filled series, plotted (and
            summarised in the title) if given.

    Returns:
        float: Root-mean-square error of the fit over the valid samples of `sel`.

    """
    # Execute minimisation.
    res = minimize(min_fit, np.zeros(n), (ts, sel))
    # `min_fit` returns the *sum* of squared residuals over the valid samples;
    # divide by their number to get the mean before taking the root.
    rmse = np.sqrt(res.fun / np.ma.count(sel))

    if plot:
        _, ax = plt.subplots()
        ax.plot(ts, sel, c="C3", label="persistent gap fill")
        if sel2 is not None:
            ax.plot(ts, sel2, c="C0", linestyle="-.", marker="x", label="raw")

        if sel3 is not None:
            ax.plot(
                ts,
                sel3,
                c="C4",
                linestyle="--",
                marker="o",
                zorder=0,
                label="season model fill",
            )

        # Prevent further autoscaling.
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        ax.autoscale(False)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        # Reference lines every 12 steps (0.5, 12.5, ...); year boundaries if
        # `ts` counts months from 1 starting in January (assumed by callers).
        ax.vlines(
            np.arange(0.5, np.max(ts) + 1, 12),
            ymin=ylim[0],
            ymax=ylim[1],
            colors="C3",
            zorder=3,
            linestyle="--",
            alpha=0.4,
        )
        ax.plot(ts, harmonic_fit(ts, res.x), c="C1", linestyle="--")

        # Only summarise the series that were actually provided.
        n_sel_masked = np.sum(np.ma.getmaskarray(sel))
        title_parts = []
        if sel2 is not None:
            n_raw_masked = np.sum(np.ma.getmaskarray(sel2))
            title_parts.append(f"invalid: {n_raw_masked}")
            title_parts.append(f"gap filled: {n_raw_masked - n_sel_masked}")
        if sel3 is not None:
            n_model_filled = n_sel_masked - np.sum(np.ma.getmaskarray(sel3))
            title_parts.append(f"model filled: {n_model_filled}")
        title_parts.append(f"RMSE: {rmse}")
        ax.set_title(", ".join(title_parts))
    return rmse

### Northern masked timeseries

In [None]:
saver = figure_saver(sub_directory="masked_timeseries")

for var in variables:
    cube = regrid(datasets.select_variables(var, inplace=False).cube.copy())

    if not cube.coords("month_number"):
        iris.coord_categorisation.add_month_number(cube, "time")

    cube.data.mask |= match_shape(combined_mask, cube.shape)

    gap_filled_cube = gap_filled_cubes[var]
    model_filled_cube = model_filled_cubes[var]

    # Northern hemisphere only (equator excluded).
    constraint = iris.Constraint(
        coord_values={"latitude": lambda cell: (0 < cell) and (cell < 90)}
    )

    extracted = constraint.extract(cube)
    extracted_filled = constraint.extract(gap_filled_cube)
    model_filled_extracted = constraint.extract(model_filled_cube)

    cube_plotting(extracted)

    # Locations (row, column) with at least one valid sample.
    valid_locs = np.any(~extracted.data.mask, axis=0)
    xi, yi = np.where(valid_locs)

    valid_loc = np.sum(valid_locs)
    print("number of valid locations:", valid_loc)
    if valid_loc == 0:
        # Nothing to plot (and `chance` below would divide by zero).
        continue

    thres = 100  # Approximate number of timeseries to plot.

    # Randomly allow plotting to take place to achieve approx. the intended number.
    chance = thres / valid_loc

    # Seeded generator so the sampled locations are reproducible.
    rng = np.random.RandomState(1)

    latitudes = extracted.coord("latitude").points
    # Latitudes of the rows that contain at least one valid sample.
    valid_lats = latitudes[np.any(valid_locs, axis=1)]

    fig, ax = plt.subplots()

    ax_cb = make_axes_locatable(ax).new_horizontal(size="3%", pad=0.1)
    norm = mpl.colors.Normalize(
        vmin=np.min(valid_lats), vmax=np.max(valid_lats), clip=True
    )
    cb = mpl.colorbar.ColorbarBase(
        ax_cb, cmap=mpl.cm.viridis, orientation="vertical", norm=norm,
    )
    fig.add_axes(ax_cb)

    for i, j in zip(tqdm(xi, desc="Valid positions"), yi):
        if rng.random() > chance:
            # Skip positions at random to only plot the desired number on average.
            continue

        ax.plot(
            range(1, extracted.shape[0] + 1),
            extracted.data[:, i, j],
            c=mpl.cm.viridis(norm(latitudes[i])),
            alpha=0.4,
            zorder=4,
        )

    # Reference lines every 12 time steps (0.5, 12.5, ...), drawn across the
    # current y-range without changing the axis limits.
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    ax.vlines(
        np.arange(0.5, extracted.shape[0] + 1, 12),
        ymin=ylim[0],
        ymax=ylim[1],
        colors="C3",
        zorder=3,
        linestyle="--",
        alpha=0.4,
    )
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    # Plot individual timeseries with seasonal fits (reproducible random order).

    count = 0
    limit = 30

    for i in rng.permutation(xi.size):
        sel = extracted_filled.data[:, xi[i], yi[i]]
        sel2 = extracted.data[:, xi[i], yi[i]]
        sel3 = model_filled_extracted.data[:, xi[i], yi[i]]

        if np.sum(sel2.mask) < 20:
            # Ignore good timeseries.
            continue
        if (np.sum(sel.mask) - np.sum(sel3.mask)) < 1:
            # Ignore timeseries that do not contain model-filled samples.
            continue

        count += 1
        if count > limit:
            break

        ts = np.arange(1, sel.size + 1)
        minimisation(ts, sel, plot=True, sel2=sel2, sel3=sel3)

### Southern masked timeseries

In [None]:
saver = figure_saver(sub_directory="masked_timeseries")

for var in variables:
    cube = regrid(datasets.select_variables(var, inplace=False).cube.copy())

    if not cube.coords("month_number"):
        iris.coord_categorisation.add_month_number(cube, "time")

    cube.data.mask |= match_shape(combined_mask, cube.shape)

    gap_filled_cube = gap_filled_cubes[var]
    model_filled_cube = model_filled_cubes[var]

    # Southern hemisphere only (equator excluded).
    constraint = iris.Constraint(
        coord_values={"latitude": lambda cell: (-90 < cell) and (cell < 0)}
    )

    extracted = constraint.extract(cube)
    extracted_filled = constraint.extract(gap_filled_cube)
    model_filled_extracted = constraint.extract(model_filled_cube)

    cube_plotting(extracted)

    # Locations (row, column) with at least one valid sample.
    valid_locs = np.any(~extracted.data.mask, axis=0)
    xi, yi = np.where(valid_locs)

    valid_loc = np.sum(valid_locs)
    print("number of valid locations:", valid_loc)
    if valid_loc == 0:
        # Nothing to plot (and `chance` below would divide by zero).
        continue

    thres = 100  # Approximate number of timeseries to plot.

    # Randomly allow plotting to take place to achieve approx. the intended number.
    chance = thres / valid_loc

    # Seeded generator so the sampled locations are reproducible.
    rng = np.random.RandomState(1)

    latitudes = extracted.coord("latitude").points
    # Latitudes of the rows that contain at least one valid sample.
    valid_lats = latitudes[np.any(valid_locs, axis=1)]

    fig, ax = plt.subplots()

    ax_cb = make_axes_locatable(ax).new_horizontal(size="3%", pad=0.1)
    norm = mpl.colors.Normalize(
        vmin=np.min(valid_lats), vmax=np.max(valid_lats), clip=True
    )
    cb = mpl.colorbar.ColorbarBase(
        ax_cb, cmap=mpl.cm.viridis, orientation="vertical", norm=norm,
    )
    fig.add_axes(ax_cb)

    for i, j in zip(tqdm(xi, desc="Valid positions"), yi):
        if rng.random() > chance:
            # Skip positions at random to only plot the desired number on average.
            continue

        ax.plot(
            range(1, extracted.shape[0] + 1),
            extracted.data[:, i, j],
            c=mpl.cm.viridis(norm(latitudes[i])),
            alpha=0.4,
            zorder=4,
        )

    # Reference lines every 12 time steps (0.5, 12.5, ...), drawn across the
    # current y-range without changing the axis limits.
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    ax.vlines(
        np.arange(0.5, extracted.shape[0] + 1, 12),
        ymin=ylim[0],
        ymax=ylim[1],
        colors="C3",
        zorder=3,
        linestyle="--",
        alpha=0.4,
    )
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    # Plot individual timeseries with seasonal fits (reproducible random order).

    count = 0
    limit = 30

    for i in rng.permutation(xi.size):
        sel = extracted_filled.data[:, xi[i], yi[i]]
        sel2 = extracted.data[:, xi[i], yi[i]]
        sel3 = model_filled_extracted.data[:, xi[i], yi[i]]

        if np.sum(sel2.mask) < 20:
            # Ignore good timeseries.
            continue
        if (np.sum(sel.mask) - np.sum(sel3.mask)) < 1:
            # Ignore timeseries that do not contain model-filled samples.
            continue

        count += 1
        if count > limit:
            break

        ts = np.arange(1, sel.size + 1)
        minimisation(ts, sel, plot=True, sel2=sel2, sel3=sel3)

### Persistent gaps - 50% or more data missing for a given month

In [None]:
saver = figure_saver(sub_directory="persistent_gaps")

# Two-colour map: blue for values in [0, 0.5), red for [0.5, 1].
cmap = colors.ListedColormap(["blue", "red"])
boundaries = np.linspace(0, 1, 3)
norm = colors.BoundaryNorm(boundaries, cmap.N, clip=True)

for var in tqdm(variables, desc="Variables"):
    cube = regrid(datasets.select_variables(var, inplace=False).cube.copy())

    if not cube.coords("month_number"):
        iris.coord_categorisation.add_month_number(cube, "time")

    # Missing-sample indicator (True = missing), masked where the combined
    # mask applies.
    nr_inval_cube = cube.copy(
        data=np.ma.MaskedArray(
            cube.data.mask.copy(), mask=match_shape(combined_mask, cube.shape),
        )
    )

    # Also hide locations without any missing sample.
    nr_inval_cube.data.mask |= match_shape(np.all(~cube.data.mask, axis=0), cube.shape)

    # Map of locations where at least 50% of a calendar month's samples are
    # missing (the 1e-5 offset guards against round-off at exactly 0.5).
    for month_number in tqdm(range(1, 13), desc="Months"):
        extracted = iris.Constraint(month_number=month_number).extract(nr_inval_cube)
        missing_frac = np.sum(extracted.data, axis=0) / extracted.shape[0]

        fig = cube_plotting(
            (missing_frac + 1e-5) >= 0.5,
            title=f"{var} month {month_number} missing frac >= 0.5\n{period_str}",
            colorbar_kwargs={"label": "missing frac >= 0.5"},
            boundaries=boundaries,
            cmap=cmap,
            norm=norm,
        )
        saver(sub_directory="maps").save_figure(
            fig, f"{var} month {month_number} persistent gaps map"
        )

In [None]:
saver = figure_saver(sub_directory="persistent_gaps")

for var in tqdm(variables, desc="Variables"):
    cube = regrid(datasets.select_variables(var, inplace=False).cube.copy())

    if not cube.coords("month_number"):
        iris.coord_categorisation.add_month_number(cube, "time")

    # Missing-sample indicator (True = missing), masked where the combined
    # mask applies.
    nr_inval_cube = cube.copy(
        data=np.ma.MaskedArray(
            cube.data.mask.copy(), mask=match_shape(combined_mask, cube.shape),
        )
    )

    # Also hide locations without any missing sample.
    nr_inval_cube.data.mask |= match_shape(np.all(~cube.data.mask, axis=0), cube.shape)

    # Histogram of the latitudes of persistent-gap locations per calendar month.
    for month_number in tqdm(range(1, 13), desc="Months"):
        extracted = iris.Constraint(month_number=month_number).extract(nr_inval_cube)
        missing_frac = np.sum(extracted.data, axis=0) / extracted.shape[0]

        # NOTE(review): `np.where` on a masked array uses the underlying data, so
        # masked locations are excluded only if their underlying values compare
        # False; `.filled(False)` would make this explicit.
        xi, yi = np.where((missing_frac + 1e-5) >= 0.5)
        fig, ax = plt.subplots()
        ax.hist(cube.coord("latitude").points[xi], bins=20)
        ax.set_xlim(-90, 90)
        ax.set_title(f"{var} month {month_number} missing >= 50%")

        saver.save_figure(fig, f"{var} month {month_number} persistent gaps latitudes")

### Northern climatological masks (number of missing samples)

In [None]:
period_str = f"{timeperiod[0]:%Y-%m} - {timeperiod[1]:%Y-%m}"

saver = figure_saver(sub_directory="climatology_masks")

# Masked samples in the climatologies north of 40 degrees latitude.
for var in variables:
    cube = iris.Constraint(coord_values={"latitude": lambda cell: 40 < cell}).extract(
        climatologies.select_variables(var, inplace=False).cube.copy()
    )

    if not cube.coords("month_number"):
        iris.coord_categorisation.add_month_number(cube, "time")

    # Ignore areas that are always masked, e.g. water.
    ignore_mask = np.all(cube.data.mask, axis=0)

    # Replace the data by a masked-sample indicator (1.0 = masked), hiding the
    # always-masked locations.
    cube.data = np.ma.MaskedArray(
        cube.data.mask.astype("float64"),
        mask=match_shape(ignore_mask, cube.data.shape),
    )

    # Fraction of masked months per location.
    fig = cube_plotting(
        cube.collapsed("month_number", iris.analysis.MEAN),
        title=f"{var}\n{period_str}",
        boundaries=np.linspace(0, 1, 6),
    )
    saver.save_figure(fig, f"{var} masked samples")

    # Number of masked samples by calendar month.
    counts = {}
    for month_number in range(1, 13):
        extracted = iris.Constraint(month_number=month_number).extract(cube)

        fig = cube_plotting(
            extracted,
            title=f"{var} month {month_number}\n{period_str}",
            boundaries=np.linspace(0, 1, 6),
        )
        saver.save_figure(fig, f"{var} month {month_number} masked samples")

        selected = extracted.data[~extracted.data.mask]
        counts[month_number] = np.sum(selected)

    with saver(f"{var} masked months"):
        pd.Series(counts).to_frame("masked months").plot.bar(rot=0)

### Northern masked climatological timeseries

In [None]:
period_str = f"{timeperiod[0]:%Y-%m} - {timeperiod[1]:%Y-%m}"

saver = figure_saver(sub_directory="masked_climatological_timeseries")

for var in variables:
    cube = iris.Constraint(coord_values={"latitude": lambda cell: 40 < cell}).extract(
        climatologies.select_variables(var, inplace=False).cube.copy()
    )

    if not cube.coords("month_number"):
        iris.coord_categorisation.add_month_number(cube, "time")

    # Ignore areas that are always masked, e.g. water.
    ignore_mask = np.all(cube.data.mask, axis=0)

    # Ignore areas that are never masked.
    ignore_mask |= ~np.any(cube.data.mask, axis=0)

    cube.data.mask |= match_shape(ignore_mask, cube.shape)

    extracted = cube

    cube_plotting(extracted.collapsed("month_number", iris.analysis.MEAN))

    # Locations (row, column) with at least one valid sample.
    valid_locs = np.any(~extracted.data.mask, axis=0)
    valid_loc = np.sum(valid_locs)
    print("number of valid locations:", valid_loc)
    if valid_loc == 0:
        # Nothing to plot (and `chance` below would divide by zero).
        continue

    thres = 20  # Approximate number of timeseries to plot.

    # Randomly allow plotting to take place to achieve approx. the intended number.
    chance = thres / valid_loc

    # Seeded generator so the sampled locations are reproducible.
    rng = np.random.RandomState(1)

    latitudes = extracted.coord("latitude").points
    # Latitudes of the rows that contain at least one valid sample.
    valid_lats = latitudes[np.any(valid_locs, axis=1)]

    # `ax`: raw climatologies; `ax2`: each series divided by its maximum.
    fig, ax = plt.subplots()
    fig2, ax2 = plt.subplots()

    ax_cb = make_axes_locatable(ax).new_horizontal(size="3%", pad=0.1)
    norm = mpl.colors.Normalize(
        vmin=np.min(valid_lats), vmax=np.max(valid_lats), clip=True
    )
    cb = mpl.colorbar.ColorbarBase(
        ax_cb, cmap=mpl.cm.viridis, orientation="vertical", norm=norm,
    )
    fig.add_axes(ax_cb)

    ax_cb2 = make_axes_locatable(ax2).new_horizontal(size="3%", pad=0.1)
    norm2 = mpl.colors.Normalize(
        vmin=np.min(valid_lats), vmax=np.max(valid_lats), clip=True
    )
    cb2 = mpl.colorbar.ColorbarBase(
        ax_cb2, cmap=mpl.cm.viridis, orientation="vertical", norm=norm2,
    )
    fig2.add_axes(ax_cb2)

    alpha = 0.4

    for i in tqdm(range(extracted.shape[1]), desc="Latitudes"):
        for j in range(extracted.shape[2]):
            sel = extracted.data[:, i, j]
            if np.all(sel.mask):
                continue

            if rng.random() > chance:
                # Skip positions at random to only plot the desired number on average.
                continue

            ax.plot(
                range(1, extracted.shape[0] + 1),
                sel,
                c=mpl.cm.viridis(norm(latitudes[i])),
                alpha=alpha,
                zorder=4,
            )

            ax2.plot(
                range(1, extracted.shape[0] + 1),
                sel / np.max(sel),
                c=mpl.cm.viridis(norm2(latitudes[i])),
                alpha=alpha,
                zorder=4,
            )

    # Draw the reference lines on *both* axes, preserving each axis' limits.
    for _ax in (ax, ax2):
        xlim = _ax.get_xlim()
        ylim = _ax.get_ylim()
        _ax.vlines(
            np.arange(0.5, extracted.shape[0] + 1, 12),
            ymin=ylim[0],
            ymax=ylim[1],
            colors="C3",
            zorder=3,
            linestyle="--",
            alpha=0.4,
        )
        _ax.set_xlim(xlim)
        _ax.set_ylim(ylim)

## Compare original to interpolated datasets

In [None]:
period_str = f"{timeperiod[0]:%Y-%m} - {timeperiod[1]:%Y-%m}"

saver = figure_saver(sub_directory="interp_comp")
for var in variables:
    # Name of the temporally interpolated counterpart of `var`.
    interp_var = var + f" {n_months}NN"
    plot_vars = (var, interp_var)
    plot_cubes = [
        datasets.select_variables(plot_var, inplace=False).cube
        for plot_var in plot_vars
    ]

    for plot_var, plot_cube in zip(plot_vars, plot_cubes):

        if not plot_cube.coords("month_number"):
            iris.coord_categorisation.add_month_number(plot_cube, "time")

        fig = cube_plotting(plot_cube, title=f"{plot_var}\n{period_str}")
        saver.save_figure(fig, f"{plot_var} mean")

    # Presumably aligns the original cube's coordinates with the interpolated
    # cube's so the two can be compared directly - confirm `replace_cube_coord`.
    replace_cube_coord(plot_cubes[0], plot_cubes[1].coord("month_number"))
    replace_cube_coord(plot_cubes[0], plot_cubes[1].coord("latitude"))
    replace_cube_coord(plot_cubes[0], plot_cubes[1].coord("longitude"))

    # Time means: index 1 is the interpolated variable, index 0 the original.
    mean_1 = np.mean(plot_cubes[1].data, axis=0)
    mean_0 = np.mean(plot_cubes[0].data, axis=0)

    # Titles use `var`: after the loop above `plot_var` holds the interpolated
    # variable's name.
    fig = cube_plotting(
        mean_1 - mean_0,
        title=f"{var} <interp> - <normal>\n{period_str}",
        cmap_midpoint=0,
        cmap_symmetric=True,
    )
    saver.save_figure(fig, f"{var} interp - normal mean")

    fig = cube_plotting(
        100 * (mean_1 - mean_0) / mean_0,
        title=f"{var} (<interp> - <normal>) / <normal> (%)\n{period_str}",
        cmap_midpoint=0,
        cmap_symmetric=True,
        colorbar_kwargs={"label": "%"},
        log=True,
    )
    saver.save_figure(fig, f"{var} (interp - normal) normalised mean")

In [None]:
# Histogram the per-grid-cell difference (absolute and relative, in %) between
# the means (along axis 0) of each original variable and its interpolated
# counterpart.
# Relies on `timeperiod`, `variables`, `datasets`, `figure_saver`, `n_months`
# and `replace_cube_coord` from earlier cells / star imports.
# Make the submodule dependency explicit instead of relying on another import
# having already loaded `iris.coord_categorisation`.
import iris.coord_categorisation

period_str = f"{timeperiod[0]:%Y-%m} - {timeperiod[1]:%Y-%m}"

saver = figure_saver(sub_directory="interp_comp_hist")

for var in variables:
    interp_var = var + f" {n_months}NN"
    plot_vars = (var, interp_var)
    plot_cubes = [
        datasets.select_variables(plot_var, inplace=False).cube
        for plot_var in plot_vars
    ]

    for plot_cube in plot_cubes:
        if not plot_cube.coords("month_number"):
            iris.coord_categorisation.add_month_number(plot_cube, "time")

    # Copy the interpolated cube's coordinates onto the original cube so both
    # are described identically before they are compared.
    for coord_name in ("month_number", "latitude", "longitude"):
        replace_cube_coord(plot_cubes[0], plot_cubes[1].coord(coord_name))

    mean_1 = np.mean(plot_cubes[1].data, axis=0)
    mean_0 = np.mean(plot_cubes[0].data, axis=0)

    # np.histogram cannot auto-detect a bin range from non-finite data, and
    # the relative difference can be inf/NaN where mean_0 == 0. Drop masked
    # and non-finite entries (masked_invalid keeps any existing mask) so only
    # real values are binned.
    abs_diff_values = np.ma.masked_invalid(mean_1 - mean_0).compressed()
    rel_diff_values = np.ma.masked_invalid(
        100 * (mean_1 - mean_0) / mean_0
    ).compressed()

    # Titles/labels use `var` (the original variable name), matching the
    # saved file names; `plot_var` would be the interpolated variable's name.
    fig, ax = plt.subplots()
    ax.hist(abs_diff_values, bins=200)
    ax.set_title(f"{var} <interp> - <normal>\n{period_str}")
    ax.set_yscale("log")
    ax.set_xlabel(f"{var}")
    saver.save_figure(fig, f"{var} interp - normal hist")

    fig, ax = plt.subplots()
    ax.hist(rel_diff_values, bins=200)
    ax.set_title(f"{var} (<interp> - <normal>) / <normal> (%)\n{period_str}")
    ax.set_yscale("log")
    ax.set_xlabel("%")
    saver.save_figure(fig, f"{var} (interp - normal) normalised hist")

In [None]:
# For each variable, plot the original and interpolated time series at the
# grid cells where the relative difference of their means (along axis 0) is
# largest in magnitude.
# Relies on `timeperiod`, `variables`, `datasets`, `figure_saver`, `n_months`
# and `replace_cube_coord` from earlier cells / star imports.
# Make the submodule dependency explicit instead of relying on another import
# having already loaded `iris.coord_categorisation`.
import iris.coord_categorisation

period_str = f"{timeperiod[0]:%Y-%m} - {timeperiod[1]:%Y-%m}"

saver = figure_saver(sub_directory="interp_comp_timeseries")

# Number of grid cells (largest |relative difference| first) plotted per variable.
n_largest = 20

for var in variables:
    interp_var = var + f" {n_months}NN"
    plot_vars = (var, interp_var)
    plot_cubes = [
        datasets.select_variables(plot_var, inplace=False).cube
        for plot_var in plot_vars
    ]

    for plot_cube in plot_cubes:
        if not plot_cube.coords("month_number"):
            iris.coord_categorisation.add_month_number(plot_cube, "time")

    # Both cubes must share the same grid and time axis. array_equal fails
    # cleanly on a shape mismatch instead of broadcasting.
    for coord in ("time", "latitude", "longitude"):
        assert np.array_equal(
            plot_cubes[0].coord(coord).points, plot_cubes[1].coord(coord).points
        ), coord

    for coord_name in ("month_number", "latitude", "longitude"):
        replace_cube_coord(plot_cubes[0], plot_cubes[1].coord(coord_name))

    mean_1 = np.mean(plot_cubes[1].data, axis=0)
    mean_0 = np.mean(plot_cubes[0].data, axis=0)

    normal_data = plot_cubes[0].data
    interp_data = plot_cubes[1].data

    # Fractional (not %) difference of the means per (lat, lon) cell.
    mean_diff_data = (mean_1 - mean_0) / mean_0

    # Rank cells by flat index instead of looking values back up with
    # `mean_diff_data == diff`. Equal values would always resolve to the
    # first matching cell, and NaN never compares equal, which gives an empty
    # match and an IndexError. Masked and NaN cells are excluded from ranking.
    flat_diff = np.ma.masked_array(mean_diff_data).ravel()
    candidate_indices = np.flatnonzero(
        ~np.ma.getmaskarray(flat_diff) & ~np.isnan(flat_diff.data)
    )
    order = np.argsort(np.abs(flat_diff.data[candidate_indices]), kind="stable")
    for flat_index in candidate_indices[order[-n_largest:]]:
        # Assumes mean_diff_data is 2-D (latitude, longitude), as the
        # coordinate lookups below require.
        lat_index, lon_index = np.unravel_index(flat_index, mean_diff_data.shape)
        diff = flat_diff.data[flat_index]

        lat = plot_cubes[0].coord("latitude").points[lat_index]
        lon = plot_cubes[0].coord("longitude").points[lon_index]

        lat_str = f"{abs(lat):0.1f}" + ("°N" if lat >= 0 else "°S")
        lon_str = f"{abs(lon):0.1f}" + ("°E" if lon >= 0 else "°W")
        loc_str = f"{lat_str}, {lon_str}"

        normal_data_sel = normal_data[:, lat_index, lon_index]
        interp_data_sel = interp_data[:, lat_index, lon_index]

        fig, ax = plt.subplots()
        # `var` is the original variable name; `plot_var` would be the
        # interpolated variable's name here.
        ax.set_title(f"{var}, {diff:0.1f}\n{period_str}, {loc_str}")
        ax.plot(normal_data_sel, marker="o", label="Original")
        ax.plot(interp_data_sel, marker="x", linestyle="--", label="Interpolated")
        ax.legend(loc="best")
        saver.save_figure(fig, f"{var}_{lat:0.1f}_{lon:0.1f}")

In [None]:
# Plot which cells of `mean_diff_data` (as last assigned by an earlier
# comparison cell) are masked. np.ma.getmaskarray always returns a full
# boolean array; `.mask` can be the scalar `nomask`, or missing entirely for a
# plain ndarray.
cube_plotting(np.ma.getmaskarray(mean_diff_data))