In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Raise matplotlib's "too many open figures" warning limit: the cross-check cell
# further down opens one figure per training sample and keeps them all open.
plt.rcParams.update({"figure.max_open_warning": 200})

In [None]:
def get_local_extrema(data, extrema_type="max"):
    """Locate local maxima or minima in a 1-D sequence.

    The sequence is padded by repeating its first and last values, so the end
    points can themselves be reported as extrema (e.g. for a strictly increasing
    sequence index 0 is a local minimum and the last index a local maximum).

    Args:
        data (array-like): 1-D sequence of numbers.
        extrema_type ({'max', 'min'}): If 'max', find local maxima.
            If 'min', find local minima.

    Returns:
        numpy.ndarray: Boolean array, the same length as `data`, that is True
            where a local maximum (or minimum) is located. Flat runs (plateaus)
            can cause more than one point to be flagged. An empty `data` gives an
            empty array.

    Raises:
        ValueError: If `extrema_type` is not in {'max', 'min'}.

    """
    if extrema_type == "max":
        # The slope sign decreases (e.g. +1 -> -1) at a local maximum.
        op = np.less
    elif extrema_type == "min":
        # The slope sign increases (e.g. -1 -> +1) at a local minimum.
        op = np.greater
    else:
        raise ValueError(f"Unexpected value for extrema_type: {extrema_type}.")

    data = np.asarray(data)
    if data.size == 0:
        # Nothing to pad (data[0] would raise IndexError); there are no extrema.
        return np.zeros(0, dtype=bool)

    # Edge-pad so every original point has a neighbour on both sides.
    padded = np.hstack((data[0], data, data[-1]))
    # Change of slope sign: negative at maxima, positive at minima.
    return op(np.diff(np.sign(np.diff(padded))), 0)


def get_local_maxima(data):
    """Boolean mask that is True at each local maximum of `data`."""
    return get_local_extrema(data, extrema_type="max")


def get_local_minima(data):
    """Boolean mask that is True at each local minimum of `data`."""
    return get_local_extrema(data, extrema_type="min")


def significant_peak(x, diff_threshold=0.4, ptp_threshold=1):
    """Determine the existence of a 'significant' peak.

    This is determined using both the range of the given data and the characteristics
    of its local maxima.

    Args:
        x (array-like): 1-D data to test.
        diff_threshold (float): Only applies if there are at least 2 local
            maxima. The height of each local maximum is measured relative to the
            lowest of its surrounding local minima (the nearest local minimum on
            each side, where one exists) and divided by the height of the global
            maximum, so the global maximum takes the value 1 (other maxima may
            exceed 1). If any other local maximum reaches or exceeds
            `diff_threshold`, the global maximum is not significant.
        ptp_threshold (float): If the range of `x` is lower than `ptp_threshold`, no
            peaks will be marked significant.

    Returns:
        bool: True if there is a significant peak, False otherwise. Data without
            any local maximum (i.e. constant data) never has a significant peak.

    """
    x = np.asarray(x)

    ptp = np.max(x) - np.min(x)
    if ptp < ptp_threshold:
        # If there is not enough variation, there is no significant peak.
        return False

    max_indices = np.flatnonzero(get_local_maxima(x))

    if max_indices.size == 0:
        # Only reachable for constant data with ptp_threshold <= 0: no peak at all.
        return False

    if max_indices.size == 1:
        # If there is only one peak, there is nothing left to do.
        return True

    # If there are multiple peaks, decide whether the minor ones are large
    # relative to the global one. If they are, there is no clearly defined
    # maximum for this sample.
    min_indices = np.flatnonzero(get_local_minima(x))

    # Choose the global maximum among the *flagged* local maxima. The first
    # occurrence of max(x) is not always flagged (e.g. a plateau at the start of
    # `x`), which would make it missing from `heights` below.
    global_max_index = max_indices[np.argmax(x[max_indices])]

    heights = {}
    for max_index in max_indices:
        # Nearest local minimum on each side, where one exists. A local maximum
        # can lack one, e.g. when it ends a plateau at the very start of `x`.
        preceding = min_indices[min_indices < max_index]
        following = min_indices[min_indices > max_index]

        surrounding_minima = []
        if preceding.size:
            surrounding_minima.append(x[preceding[-1]])
        if following.size:
            surrounding_minima.append(x[following[0]])

        # Height relative to the lowest surrounding minimum.
        heights[max_index] = np.max(x[max_index] - np.array(surrounding_minima))

    global_height = heights[global_max_index]

    # The global peak is significant only if every other peak, rescaled by the
    # global peak's height, stays below the threshold.
    return all(
        height / global_height < diff_threshold
        for index, height in heights.items()
        if index != global_max_index
    )

In [None]:
# Decision thresholds used for every sample below (same values as the
# defaults of `significant_peak`).
diff_threshold = 0.4
ptp_threshold = 1

# 50 reproducible random samples of length 5, cubed so values cluster near 0,
# followed by two hand-made samples with very little variation.
random_samples = np.random.RandomState(1).random((50, 5)) ** 3
handmade_samples = np.array(
    [
        [0, 0.02, -0.02, 0.005, -0.007],
        [0, 0.3, -0.02, 0.005, -0.007],
    ]
)
raw_samples = np.vstack((random_samples, handmade_samples))

# Normalise: shift every sample so its minimum is 0, then divide everything by
# the overall mean of the shifted data.
# TODO: This has to be per grid cell too.
shifted_samples = raw_samples - raw_samples.min(axis=1, keepdims=True)
train_samples = shifted_samples / np.mean(shifted_samples)

# The y-limits depend only on the whole data set; compute them once so every
# figure shares the same vertical scale.
y_top = np.max(train_samples)
offset = 0.05 * y_top

for samples in train_samples:

    fig, ax = plt.subplots(1, 1, figsize=(6, 2.5))
    ax.plot(samples, marker="o", linestyle="", label="data")
    ax.grid(linestyle="--", alpha=0.4)
    ax.set_ylim(0 - offset, y_top + offset)

    xs = np.arange(len(samples))
    ptp = np.ptp(samples)

    # Value-weighted mean index of the sample, shown in the title.
    weighted_avg = np.sum(xs * samples) / np.sum(samples)

    sig = significant_peak(
        samples, diff_threshold=diff_threshold, ptp_threshold=ptp_threshold
    )

    if ptp < ptp_threshold:
        ax.legend(loc="best")
        ax.set_title("Accepted: False (ptp too low)")
        assert not sig
        continue

    max_mask = get_local_maxima(samples)
    min_mask = get_local_minima(samples)

    ax.plot(
        xs[max_mask], samples[max_mask], c="C3", marker="o", label="max", linestyle="",
    )
    ax.plot(
        xs[min_mask], samples[min_mask], c="C4", marker="o", label="min", linestyle="",
    )
    # Create the legend only after every labelled series exists; a legend built
    # earlier would not list the "max" and "min" markers.
    ax.legend(loc="best")

    if np.sum(max_mask) == 1:
        # If there is only one peak, there is nothing left to do.
        ax.set_title(f"Accepted: True (only 1 peak) {weighted_avg}")
        assert sig
        continue

    # Re-derive the multi-peak decision independently of `significant_peak`, both
    # to cross-check it and to label the figure with the reason. Significance of
    # a minor peak is its height (peak value - lowest surrounding local minimum)
    # divided by the same height of the global maximum.
    max_indices = np.flatnonzero(max_mask)
    min_indices = np.flatnonzero(min_mask)

    # Pick the global maximum among the flagged local maxima; the first
    # occurrence of the sample maximum is not always flagged (plateaus).
    global_max_index = max_indices[np.argmax(samples[max_indices])]

    agg_diffs = {}

    for max_index in max_indices:
        # Nearest local minimum on each side, where one exists.
        preceding = min_indices[min_indices < max_index]
        following = min_indices[min_indices > max_index]

        local_minima = []
        if preceding.size:
            local_minima.append(samples[preceding[-1]])
        if following.size:
            local_minima.append(samples[following[0]])

        agg_diffs[max_index] = np.max(samples[max_index] - np.array(local_minima))

    max_diff = agg_diffs[global_max_index]

    minor_peaks_small = all(
        diff / max_diff < diff_threshold
        for index, diff in agg_diffs.items()
        if index != global_max_index
    )

    if minor_peaks_small:
        ax.set_title(f"Accepted: True (sig. peak) {weighted_avg}")
        assert sig
    else:
        ax.set_title("Accepted: False (not sig. peak)")
        assert not sig