In [None]:
from pathlib import Path
from string import ascii_lowercase

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
from matplotlib.ticker import FuncFormatter

from wildfires.utils import get_local_maxima, get_local_minima, significant_peak

# Directory that receives every saved figure; created here if missing.
fig_dir = Path("plots")
fig_dir.mkdir(exist_ok=True)

# Crop surrounding whitespace from all saved figures by default.
mpl.rcParams["savefig.bbox"] = "tight"

In [None]:
def visualise_peaks(
    x,
    expected_peaks=(),
    diff_threshold=0.2,
    ptp_threshold=0.2,
    strict=False,
    fig=None,
    ax=None,
    legend=True,
    t_offset=0.2,
    plot_mean=True,
    plot_ptp=True,
    plot_zero=True,
):
    """Plot a 1D series and annotate the peaks reported by `significant_peak`.

    Every value of `x` is drawn as a circle, the `expected_peaks` indices are
    re-drawn in a second colour, and the indices returned by `significant_peak`
    are marked with a cross. For every entry in the returned peak heights a
    double-headed arrow spans the height, a dashed horizontal line marks the
    arrow's far end, and a text box shows the height divided by the largest
    height (solid border if that ratio is >= `diff_threshold`, dashed otherwise).

    Args:
        x: 1D sequence of values.
        expected_peaks: Indices into `x` to highlight as expected peaks.
        diff_threshold: Forwarded to `significant_peak`; also decides the
            border style of the height labels.
        ptp_threshold: Forwarded to `significant_peak`; also sets the spacing
            of the optional ptp lines around the mean.
        strict: Forwarded to `significant_peak`.
        fig: Figure to add a new Axes to when `ax` is not given.
        ax: Axes to draw on. If omitted, a new Axes is added to `fig` (or to a
            newly created figure if `fig` is also omitted).
        legend: Whether to draw a legend on `ax`.
        t_offset: Horizontal shift of the height label for a peak at the first
            (+t_offset) or last (-t_offset) index.
        plot_mean: Draw a horizontal line at the mean of `x`.
        plot_ptp: Draw lines at mean -/+ ptp_threshold / 2.
        plot_zero: Draw a horizontal line at 0.

    Returns:
        The Axes that was drawn on.
    """
    if ax is None:
        if fig is None:
            fig = plt.figure()
        # Add the Axes to `fig` explicitly: `plt.axes()` targets the *current*
        # figure, which is not necessarily the `fig` passed in.
        ax = fig.add_subplot()

    x = np.asarray(x)
    # Horizontal lines run one step past either end of the data; the x-limits
    # captured below are restored at the end so these lines don't widen the view.
    line_start, line_end = -1, len(x)

    ax.plot(x, linestyle="", marker="o", c="C0")

    expected = list(expected_peaks)
    if expected:
        ax.plot(
            expected,
            x[expected],
            linestyle="",
            marker="o",
            c="C2",
            label="expected peak",
        )

    peak_indices, peak_heights = significant_peak(
        x,
        diff_threshold=diff_threshold,
        ptp_threshold=ptp_threshold,
        strict=strict,
        return_peak_heights=True,
    )

    significant = list(peak_indices)
    if significant:
        ax.plot(
            significant,
            x[significant],
            linestyle="",
            marker="x",
            c="C1",
            alpha=1,
            ms=12,
            label="sig peak",
        )

    # Guard against no peak heights (empty `max`) and all-zero heights
    # (division by zero) when normalising by the largest height.
    max_height = max(peak_heights.values(), default=0)
    scale = max_height if max_height != 0 else 1

    xlim = ax.get_xlim()

    labelled = []
    for index, height in peak_heights.items():
        rescaled_height = height / scale
        peak_value = x[index]
        # The arrow runs from the peak towards zero by `height`.
        base = peak_value - np.sign(peak_value) * height
        ax.annotate(
            "",
            xy=[index, peak_value],
            xytext=[index, base],
            arrowprops=dict(arrowstyle="<->"),
        )
        # Draw each distinct reference level only once.
        if not np.any(np.isclose(labelled, base)):
            ax.hlines(base, line_start, line_end, linestyle="--", color="k", alpha=0.4)
            labelled.append(base)

        # Nudge labels at the first / last index inwards.
        if index == 0:
            offset = t_offset
        elif index == len(x) - 1:
            offset = -t_offset
        else:
            offset = 0

        ax.text(
            index + offset,
            peak_value - 0.5 * np.sign(peak_value) * height,
            f"{rescaled_height:0.1f}",
            ha="center",
            va="center",
            bbox=dict(
                boxstyle="round",
                facecolor="w",
                alpha=0.8,
                linestyle="-" if rescaled_height >= diff_threshold else "--",
            ),
        )

    mean = np.mean(x)

    if plot_mean:
        ax.hlines(
            mean, line_start, line_end, linestyle="-.", color="C4", label="mean", alpha=0.5
        )
    if plot_ptp:
        ax.hlines(
            [mean - ptp_threshold / 2, mean + ptp_threshold / 2],
            line_start,
            line_end,
            linestyle="--",
            color="C3",
            label="ptp thres",
            alpha=0.5,
        )
    if plot_zero:
        ax.hlines(0, line_start, line_end, linestyle="--", color="k", alpha=0.4)

    ax.set_xlim(xlim)
    ax.grid(linestyle="--", alpha=0.4)
    if legend:
        ax.legend(loc="best")
    return ax

In [None]:
plot_data = (
    ([1, 0.4, -0.5, 0.1, 0.4], dict(diff_threshold=0.45, ptp_threshold=0.2)),
    ([1, 0.4, -0.5, 0.7, 0.3], dict(diff_threshold=0.6, ptp_threshold=0.2)),
    ([1, 0.4, 0.3, 0.7, -0.5], dict(diff_threshold=0.8, ptp_threshold=0.2)),
    ([1, 0.4, 0.2, 0.6, 0.3], dict(diff_threshold=0.8, ptp_threshold=0.2)),
)
# Same y-range on every panel (sharey is off, so set it explicitly).
y_limits = (-0.56, 1.06)

fig, axes = plt.subplots(1, len(plot_data), sharex=True, sharey=False, figsize=(11, 3))
panel_labels = [f"({letter})" for letter in ascii_lowercase[: len(plot_data)]]
for ax, (x, kwargs), panel_label in zip(axes, plot_data, panel_labels):
    visualise_peaks(
        x,
        ax=ax,
        t_offset=0.15,
        legend=False,
        plot_mean=False,
        plot_ptp=False,
        plot_zero=False,
        **kwargs,
    )
    ax.set_title(r"$\mathrm{h}_{\mathrm{min}}$ = " + f"{kwargs['diff_threshold']:0.2f}")
    ax.text(0, 1.05, panel_label, transform=ax.transAxes, fontsize=11)
    ax.set_ylim(*y_limits)
    ax.tick_params(labelbottom=False)

for ax in axes[1:]:
    ax.tick_params(labelleft=False)

# Label only the zero tick on the first panel. A formatter keyed on the tick
# value stays correct whatever locations the locator picks after `set_ylim`,
# unlike a positional label list passed to `set_yticklabels`.
axes[0].yaxis.set_major_formatter(
    FuncFormatter(lambda value, pos: "0" if np.isclose(value, 0) else "")
)

# Add shared legend.
legend_elements = [
    Line2D([0], [0], c="C0", linestyle="", marker="o", label="data"),
    Line2D([0], [0], c="C1", linestyle="", marker="x", label="sig. peak"),
    Line2D([0], [0], c="k", linestyle="--", label="reference"),
]
fig.tight_layout()
fig.legend(
    handles=legend_elements,
    bbox_to_anchor=(1.015, 0.54, 0.1, 0.1),
    loc="lower right",
    borderaxespad=1.5,
)
fig.savefig((fig_dir / "peak_detection").with_suffix(".pdf"))