In [None]:
import matplotlib.pyplot as plt
import numpy as np

from wildfires.utils import *

# Raise the number of simultaneously open figures matplotlib tolerates before
# emitting its "too many open figures" warning; later cells create figures
# without closing them.
plt.rcParams["figure.max_open_warning"] = 200

In [None]:
# Thresholds handed to `significant_peak` and mirrored by the manual
# re-implementation in the loop below.
diff_threshold = 0.4
ptp_threshold = 1

# Reproducible random series (5 series of 5 points each, fixed seed). Cubing
# values drawn from [0, 1) skews them towards 0.
train_samples = np.random.RandomState(1).random((5, 5)) ** 3

# Append two hand-crafted series that differ only in their second value.
train_samples = np.vstack(
    (
        train_samples,
        np.array([0, 0.02, -0.02, 0.005, -0.007]),
        np.array([0, 0.3, -0.02, 0.005, -0.007]),
    )
)

# Normalise.
train_samples -= np.min(train_samples, axis=1).reshape(-1, 1)
# TODO: This has to be per grid cell too.
train_samples /= np.mean(train_samples)

# Shared y-axis limits for every figure (loop-invariant, so computed once).
y_max = np.max(train_samples)
offset = 0.05 * y_max

# For every series: plot it, classify it manually, and assert that
# `significant_peak` reaches the same verdict. The verdict is shown in the title.
for samples in train_samples:

    fig, ax = plt.subplots(1, 1, figsize=(6, 2.5))
    ax.plot(samples, marker="o", linestyle="", label="data")
    ax.grid(linestyle="--", alpha=0.4)
    ax.set_ylim(0 - offset, y_max + offset)

    xs = np.arange(len(samples))
    ptp = np.ptp(samples)

    # Index-weighted mean position of the series; only shown in the titles of
    # accepted series.
    weighted_avg = np.sum(xs * samples) / np.sum(samples)

    sig = significant_peak(
        samples, diff_threshold=diff_threshold, ptp_threshold=ptp_threshold
    )

    if ptp < ptp_threshold:
        # Nothing else will be drawn on this axis, so finalise the legend now.
        ax.legend(loc="best")
        ax.set_title("Accepted: False (ptp too low)")
        assert not sig
        continue

    # NOTE(review): both helpers come from `wildfires.utils`; they are used
    # here as masks over `samples` (presumably boolean) - confirm.
    max_mask = get_local_maxima(samples)
    min_mask = get_local_minima(samples)

    ax.plot(
        xs[max_mask], samples[max_mask], c="C3", marker="o", label="max", linestyle="",
    )
    ax.plot(
        xs[min_mask], samples[min_mask], c="C4", marker="o", label="min", linestyle="",
    )
    # Build the legend only after every labelled series has been drawn,
    # otherwise the "max" and "min" entries would be missing from it.
    ax.legend(loc="best")

    if np.sum(max_mask) == 1:
        # If there is only one peak, there is nothing left to do.
        ax.set_title(f"Accepted: True (only 1 peak) {weighted_avg}")
        assert sig
        continue

    # If there are multiple peaks, we have to decide if these local maxima
    # are significant. If they are, there is no clearly defined
    # maximum for this sample.

    # Define significance of the minor peaks as the ratio between the
    # difference (peak value - local minima) and (peak value - local minima)
    # for the global maximum.

    min_indices = np.where(min_mask)[0]
    max_indices = np.where(max_mask)[0]

    # Choose the global maximum among the local-maximum indices themselves, so
    # that it is always a key of `agg_diffs` (a plain `samples == max` lookup
    # could return a tied index that is not flagged as a local maximum).
    global_max_index = max_indices[np.argmax(samples[max_indices])]

    agg_diffs = {}

    for max_index in max_indices:
        # Sample value at the local maximum.
        max_value = samples[max_index]

        # Find the surrounding local minima.
        # NOTE(review): the `[-1]` / `[0]` lookups assume `get_local_minima`
        # flags at least one minimum on each relevant side of the peak; an
        # IndexError is raised otherwise - confirm against the helper.
        local_minima = []

        # Find preceding local minima, if any.
        if max_index > 0:
            local_minima.append(
                samples[min_indices[np.where(min_indices < max_index)][-1]]
            )

        # Find following local minima, if any.
        if max_index < (len(samples) - 1):
            local_minima.append(
                samples[min_indices[np.where(min_indices > max_index)][0]]
            )

        local_minima = np.array(local_minima)

        minima_diffs = max_value - local_minima

        # Keep the largest drop from this peak to an adjacent minimum.
        agg_diffs[max_index] = np.max(minima_diffs)

    max_diff = agg_diffs[global_max_index]

    # Rescale using the maximum diff.
    agg_diffs = {index: diff / max_diff for index, diff in agg_diffs.items()}

    # Accept only if every minor peak is small relative to the global one.
    if all(
        diff < diff_threshold
        for index, diff in agg_diffs.items()
        if index != global_max_index
    ):
        ax.set_title(f"Accepted: True (sig. peak) {weighted_avg}")
        assert sig
    else:
        ax.set_title("Accepted: False (not sig. peak)")
        assert not sig

In [None]:
# Toy series: largest value at index 0, lowest at index 2, then a small rise.
x = np.array([1, 0.4, -0.5, 0.1, 0.3])

# Marker-only style shared by the data points and the highlighted points.
dot_style = {"linestyle": "", "marker": "o"}
plt.plot(x, **dot_style)

# Highlight hand-picked indices in colour "C1".
for idx in (0, 2):
    plt.plot(idx, x[idx], c="C1", **dot_style)

# Overlay a large cross on every index yielded by `significant_peak`.
# NOTE(review): the positional arguments (0.5, 0, 0) are defined by
# `wildfires.utils.significant_peak`; presumably the first two are
# diff_threshold and ptp_threshold - confirm, including the third.
cross_style = {"linestyle": "", "marker": "x", "c": "C2", "alpha": 0.6, "ms": 15}
for idx in significant_peak(x, 0.5, 0, 0):
    plt.plot(idx, x[idx], **cross_style)

plt.grid(linestyle="--", alpha=0.4)

In [None]:
# Toy series: largest value at index 0, lowest at index 2, and a second,
# smaller peak at index 3.
x = np.array([1, 0.4, -0.5, 0.7, 0.3])

# Marker-only style shared by the data points and the highlighted points.
dot_style = {"linestyle": "", "marker": "o"}
plt.plot(x, **dot_style)

# Highlight hand-picked indices in colour "C1".
for idx in (0, 2, 3):
    plt.plot(idx, x[idx], c="C1", **dot_style)

# Overlay a large cross on every index yielded by `significant_peak`.
# NOTE(review): the positional arguments (0.5, 0, 0) are defined by
# `wildfires.utils.significant_peak`; presumably the first two are
# diff_threshold and ptp_threshold - confirm, including the third.
cross_style = {"linestyle": "", "marker": "x", "c": "C2", "alpha": 0.6, "ms": 15}
for idx in significant_peak(x, 0.5, 0, 0):
    plt.plot(idx, x[idx], **cross_style)

plt.grid(linestyle="--", alpha=0.4)

In [None]:
# Same toy series as the 0.5-threshold case (peaks at indices 0 and 3, lowest
# value at index 2), here evaluated with a first argument of 0.6.
x = np.array([1, 0.4, -0.5, 0.7, 0.3])

# Marker-only style shared by the data points and the highlighted points.
dot_style = {"linestyle": "", "marker": "o"}
plt.plot(x, **dot_style)

# Highlight hand-picked indices in colour "C1".
for idx in (0, 3):
    plt.plot(idx, x[idx], c="C1", **dot_style)

# Overlay a large cross on every index yielded by `significant_peak`.
# NOTE(review): the positional arguments (0.6, 0, 0) are defined by
# `wildfires.utils.significant_peak`; presumably the first two are
# diff_threshold and ptp_threshold - confirm, including the third.
cross_style = {"linestyle": "", "marker": "x", "c": "C2", "alpha": 0.6, "ms": 15}
for idx in significant_peak(x, 0.6, 0, 0):
    plt.plot(idx, x[idx], **cross_style)

plt.grid(linestyle="--", alpha=0.4)