# In [2]:
# ─────────────────────────────────────────────────────────────────────────────
# 1. Imports
# ─────────────────────────────────────────────────────────────────────────────
import pandas as pd
import numpy as np
import calendar
import ipywidgets as w
from IPython.display import display, HTML, clear_output
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

from worldcereal.utils.refdata import month_diff, get_best_valid_time

HTML("""
<style>
  .widget-label        { font-size:18px !important; white-space:normal !important; }
  .widget-readout      { font-size:18px !important; }
  .widget-slider .ui-slider-handle { font-size:18px !important; }
</style>
""")
plt.rcParams.update({
    "font.family"     : "sans-serif",    
    "font.sans-serif" : ["DejaVu Sans"], 
    "font.size"       : 18
})

# ─────────────────────────────────────────────────────────────────────────────
# 2. Helper functions
# ─────────────────────────────────────────────────────────────────────────────
def month_diff_signed(target, source):
    """Return the shortest signed month offset that moves *source* onto *target*.

    Both arguments are calendar months (1-12). The result lies in -5 … +6:
    shifting ``source`` by the result lands on ``target`` modulo 12. A gap of
    exactly six months is ambiguous and is reported as +6 (forward).

    >>> month_diff_signed(2, 11)
    3
    >>> month_diff_signed(11, 2)
    -3
    """
    diff = (target - source) % 12  # forward distance, 0..11
    # Anything beyond half a year is shorter when walked backwards.
    return diff - 12 if diff > 6 else diff
    
def evaluate(start_date, end_date, valid_time, buffer, num_ts=12):
    """Score every calendar month as the proposed valid-time month of a sample.

    For each proposed month 1-12, a one-row record is passed to
    ``get_best_valid_time`` together with ``buffer`` and ``num_ts``. The record
    holds the extraction start/end, the true valid time, the proposed month and
    the backward/forward month shifts between true and proposed month (from
    ``month_diff``).

    Returns
    -------
    pandas.DataFrame
        One row per proposed month, with these columns:

        - ``proposed_month``: 1-12.
        - ``resulting_valid_time``: whatever ``get_best_valid_time`` returned.
        - ``proposed_month_str``: abbreviated month name.
        - ``acceptable``: True where ``resulting_valid_time`` is not null.
    """
    sample = pd.Series({"start_date": start_date, "end_date": end_date,
                        "valid_time": valid_time})
    records = []
    for proposed in range(1, 13):
        # One Series is reused and updated in place on every pass.
        sample["true_valid_time_month"] = valid_time.month
        sample["proposed_valid_time_month"] = proposed
        true_month = sample["true_valid_time_month"]
        sample["valid_month_shift_backward"] = month_diff(proposed, true_month)
        sample["valid_month_shift_forward"] = month_diff(true_month, proposed)
        outcome = get_best_valid_time(sample, buffer, num_ts)
        records.append([proposed, outcome])

    result = pd.DataFrame(records, columns=["proposed_month", "resulting_valid_time"])
    result["proposed_month_str"] = result["proposed_month"].map(
        lambda month: calendar.month_abbr[month]
    )
    # A null result from get_best_valid_time marks the proposal as rejected.
    result["acceptable"] = result["resulting_valid_time"].notna()
    return result

# ─────────────────────────────────────────────────────────────────────────────
# 2b. Helper to build a DatePicker with ‹/› month-step buttons
# ─────────────────────────────────────────────────────────────────────────────
def date_picker_with_arrows(label_text, init_date):
    """Build a labelled DatePicker flanked by one-month step buttons.

    Row layout, left to right: label, "previous month" button, DatePicker,
    "next month" button.

    Returns
    -------
    tuple
        ``(row, picker)``: the assembled ``HBox`` and the ``DatePicker``
        itself, so callers can observe its value.
    """
    picker = w.DatePicker(value=init_date, description="",
                          layout=w.Layout(width="140px"))

    def step(n_months):
        # Nothing to move while the picker is empty.
        if not picker.value:
            return
        moved = pd.Timestamp(picker.value) + pd.DateOffset(months=n_months)
        picker.value = moved.to_pydatetime()

    back_btn = w.Button(icon="chevron-left", layout=w.Layout(width="32px"))
    fwd_btn = w.Button(icon="chevron-right", layout=w.Layout(width="32px"))
    back_btn.on_click(lambda _btn: step(-1))
    fwd_btn.on_click(lambda _btn: step(1))

    # Fixed label width keeps the pickers of stacked rows aligned.
    caption = w.Label(value=label_text, layout=w.Layout(width="160px"))
    row = w.HBox([caption, back_btn, picker, fwd_btn],
                 layout=w.Layout(align_items="center"))
    return row, picker

# ─────────────────────────────────────────────────────────────────────────────
# 3. Widgets
# ─────────────────────────────────────────────────────────────────────────────
start_box, start_w = date_picker_with_arrows("Extractions start date", pd.Timestamp("2018-08-01"))
end_box,   end_w   = date_picker_with_arrows("Extractions end date",  pd.Timestamp("2019-11-30"))
valid_box, valid_w = date_picker_with_arrows("True valid time",       pd.Timestamp("2019-06-01"))

buffer_w = w.IntSlider(value=2, min=0, max=6, step=1,
                       description="Buffer (months)",
                       style={"description_width":"initial"},
                       layout=w.Layout(width="300px"))

buffer_note = w.HTML(
    "<i>How close (in months) we allow the true valid_time to be to the edges of the user-defined temporal extent (TE).</i>"
)

# ─────────────────────────────────────────────────────────────────────────────
# 4. Visual grouping
# ─────────────────────────────────────────────────────────────────────────────
frame_title = w.HTML("<b>Hypothetical sample from a public dataset: select available extractions start and end dates and valid time (use ‹/› to jump a month)</b>")
date_box = w.VBox([frame_title, start_box, end_box, valid_box],
                  layout=w.Layout(border="1px solid #ccc",
                                  padding="10px", margin="5px 0"))
dates_frame = w.VBox([date_box])

buffer_w = w.IntSlider(value=2,min=0,max=6,step=1,
                       description="Buffer (months)",
                       style={"description_width":"initial"}, layout=w.Layout(width="280px"))
buffer_note = w.HTML("<i>Allowed proximity (in months) of <code>true valid_time</code> to window edges, or TE edges to available extractions.</i>")

date_frame = w.VBox([w.HTML("<b>Select extraction period and valid&nbsp;time</b>"),
                     start_box, end_box, valid_box],
                    layout=w.Layout(border="1px solid #ccc", padding="10px", margin="5px 0"))

buffer_frame = w.VBox([
    w.HTML("<b>Buffer settings</b>"),
    w.VBox([buffer_w, buffer_note])],
    layout=w.Layout(border="1px solid #ccc", padding="10px", margin="5px 0"))

# Radio list of candidate temporal extents, one entry per proposed month.
# Starts empty; redraw() fills the options.
radio_sel = w.RadioButtons(
    options=[],
    layout=w.Layout(width="200px", height="300px", overflow_y="auto"),
    style={"description_width": "0"},
)
radio_sel.add_class("te-table")  # hook for the CSS rules below

# Table-like styling for the radio list. A bare HTML(...) expression that is
# not the last line of a notebook cell is never rendered, so the stylesheet
# has to go through display() to take effect.
display(HTML("""
<style>
/* full-width row */ .te-table .widget-radio-button label{display:block;padding:2px 6px;}
/* stripe */         .te-table .widget-radio-button:nth-child(odd){background:#f7f7f7;}
/* red x */          .te-table .widget-radio-button label:first-child::first-letter{color:#d62728;}
</style>
"""))

te_frame = w.VBox([w.HTML("<b>Which Temporal Extent (TE) to align?</b>"), radio_sel],
                  layout=w.Layout(border="1px solid #ccc", padding="10px", margin="5px 0"))

# Output area that refresh_plot() draws into, wrapped in a bordered frame.
plot_out = w.Output()
plot_frame = w.VBox([plot_out], layout=w.Layout(border="1px solid #ccc",
                                                padding="10px", margin="5px 0"))

# ──────────────────────────────────────────────────────────────
# 5. Plot refresh
# ──────────────────────────────────────────────────────────────
def refresh_plot(df, start, end, vt, buf):
    """Redraw the simulated-NDVI figure and the acceptance caption in ``plot_out``.

    Parameters
    ----------
    df : pandas.DataFrame
        Result of ``evaluate``. Rows are addressed by ``radio_sel.value`` and
        must provide the ``proposed_month`` and ``acceptable`` columns.
    start, end : date-like or None
        Available-extraction period (``datetime.date``, ``datetime.datetime``
        or ``pd.Timestamp``).
    vt : date-like or None
        True valid time of the sample.
    buf : int
        Buffer in months, drawn inside both edges of the proposed TE.

    The selected TE is read from ``radio_sel.value``. When it is ``None``, only
    the sample itself (NDVI, extraction period, valid time) is drawn.
    """
    if start is None or end is None or vt is None:
        with plot_out:
            clear_output()
            print("Choose valid dates: start < valid < end")
        return

    # DatePicker values arrive as datetime.date once edited in the browser, but
    # are datetime/Timestamp when set from Python. Normalise them so nothing
    # below compares a date with a Timestamp (a TypeError in pandas >= 2).
    start, end, vt = pd.Timestamp(start), pd.Timestamp(end), pd.Timestamp(vt)
    sel = radio_sel.value

    with plot_out:
        clear_output()
        fig, ax = plt.subplots(figsize=(10, 3.6))

        # Simulated NDVI: a one-year cosine peaking at the true valid time,
        # padded by two months on either side of the extraction period.
        days = pd.date_range(start - pd.DateOffset(months=2),
                             end + pd.DateOffset(months=2), freq="D")
        ndvi = 0.4 + 0.35 * np.cos(
            (mdates.date2num(days) - mdates.date2num(vt)) / 365.25 * 2 * np.pi
        )
        ax.plot(days, ndvi, color="forestgreen", label="Simulated NDVI")

        # Available extractions, drawn once; a gold border highlights them
        # while a TE is selected.
        if sel is None:
            ax.axvspan(start, end, color="skyblue", alpha=0.20,
                       label="Available extractions")
        else:
            ax.axvspan(start, end, facecolor="skyblue", alpha=0.25,
                       edgecolor="gold", linewidth=5,
                       label="Available extractions")
        ax.axvline(vt, color="forestgreen", ls="--", lw=1.8, label="True valid_time")

        if sel is not None:
            # The proposed TE runs from five months before to six months after
            # the proposed month, reached from the true valid time by the
            # shortest signed month shift.
            shift_m = month_diff_signed(df.loc[sel, "proposed_month"], vt.month)
            new_mid = vt + pd.DateOffset(months=shift_m)
            new_start = new_mid - pd.DateOffset(months=5)
            new_end = new_mid + pd.DateOffset(months=6)

            ax.axvspan(new_start, new_end,
                       facecolor="mediumseagreen", alpha=0.20,
                       edgecolor="mediumseagreen", linewidth=2,
                       label="Proposed TE")

            # Hatched buffer zones inside both TE edges: edgecolor sets the
            # hatch colour, and linewidth=0 suppresses the outline.
            buffer_style = dict(facecolor="mediumseagreen", edgecolor="firebrick",
                                hatch="//", linewidth=0, alpha=0.20)
            ax.axvspan(new_start, new_start + pd.DateOffset(months=buf),
                       label="Buffer", **buffer_style)
            ax.axvspan(new_end - pd.DateOffset(months=buf), new_end,
                       **buffer_style)

            ax.axvline(new_mid, color="firebrick", ls="--", alpha=0.70, lw=1.8,
                       label="Middle of proposed TE")
            # Arrow from the true valid time to the centre of the proposed TE.
            ax.annotate("", xy=(new_mid, 0.95), xytext=(vt, 0.95),
                        arrowprops=dict(arrowstyle="->", lw=2, color="black"))

        # Month ticks every two months.
        ax.xaxis.set_major_locator(mdates.MonthLocator(interval=2))
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%b'))
        ax.tick_params(axis='x', labelsize=10)
        ax.set_ylim(0, 1)
        ax.set_ylabel("NDVI")
        ax.set_title("Simulated NDVI")

        # Legend in a fixed order. Only artists actually drawn are listed, so
        # every handle stays paired with its own label.
        order = [
            "Available extractions", "Proposed TE",
            "Simulated NDVI", "True valid_time",
            "Middle of proposed TE", "Buffer",
        ]
        handles, labels = ax.get_legend_handles_labels()
        lookup = dict(zip(labels, handles))
        shown = [lab for lab in order if lab in lookup]
        ax.legend([lookup[lab] for lab in shown], shown, loc="upper right")

        if sel is None:
            caption = "Select a temporal extent to evaluate the sample."
        else:
            te_label = f"{calendar.month_abbr[new_start.month]}–{calendar.month_abbr[new_end.month]}"
            if df.loc[sel, "acceptable"]:
                caption = f"✓ Sample will be used for the {te_label} period (buffer {buf} months)."
            else:
                # Explain the rejection from the drawn geometry. This is a
                # heuristic: the exact rule lives in get_best_valid_time.
                too_close = (vt <= new_start + pd.DateOffset(months=buf)
                             or vt >= new_end - pd.DateOffset(months=buf))
                if too_close:
                    reason = "true valid_time is too close to the TE edges"
                else:
                    reason = "proposed TE extends beyond available extractions (would add NODATA)"
                caption = f"✗ Sample dismissed for {te_label}: {reason}."

        # Caption below the axes, with a bottom margin reserved for it.
        fig.text(0.01, -0.12, caption, fontsize=11, va="top")
        plt.tight_layout(rect=[0, 0.05, 1, 1])
        plt.show()

# ──────────────────────────────────────────────────────────────
# 6. Master redraw
# ──────────────────────────────────────────────────────────────
# Latest evaluate() result, shared with the radio-button observer.
current_df = {}


def redraw(*_):
    """Re-evaluate all candidate months, then refresh the TE list and the plot.

    Positional arguments are accepted and ignored, so the function can be
    registered directly as an ipywidgets ``observe`` callback.
    """
    raw = (start_w.value, end_w.value, valid_w.value)
    buf = buffer_w.value
    # Picker values may be datetime.date (browser edit) or datetime/Timestamp
    # (set from Python). Comparing those mixed types raises TypeError, so
    # normalise them first.
    s, e, vt = (None if v is None else pd.Timestamp(v) for v in raw)
    if None in (s, e, vt) or e <= s or not (s < vt < e):
        radio_sel.options = []
        plot_out.clear_output()
        with plot_out:
            print("Choose valid dates: start < valid < end")
        return

    df = evaluate(s, e, vt, buf)
    # Publish the new result BEFORE touching the radio widget. Changing its
    # options/value fires the radio observer, which reads current_df["df"].
    current_df["df"] = df

    opts = [
        (f"{row['proposed_month_str']} (✓)" if row['acceptable']
         else f"❌ {row['proposed_month_str']}", idx)
        for idx, row in df.iterrows()
    ]
    keep = radio_sel.value
    radio_sel.options = opts
    # Keep the previous selection if that month is still offered.
    radio_sel.value = keep if keep in [v for _label, v in opts] else None
    refresh_plot(df, s, e, vt, buf)

# Recompute everything whenever a date or the buffer changes.
for wid in (start_w, end_w, valid_w, buffer_w):
    wid.observe(redraw, "value")


def _on_radio_change(_change):
    """Redraw the plot for the newly selected TE using the latest evaluation."""
    df = current_df.get("df")
    dates = (start_w.value, end_w.value, valid_w.value)
    # The radio value also changes while redraw() repopulates the options.
    # That can happen before any evaluation exists or while a date is empty.
    if df is None or any(d is None for d in dates):
        return
    refresh_plot(df, *dates, buffer_w.value)


radio_sel.observe(_on_radio_change, "value")

# Initial render.
redraw()

# ──────────────────────────────────────────────────────────────
# 7. Layout & display
# ──────────────────────────────────────────────────────────────

# Left column: date, buffer and TE selectors. Right column: the plot frame.
left_column = w.VBox(
    [date_frame, buffer_frame, te_frame],
    layout=w.Layout(width="420px"),
)
ui = w.HBox([left_column, plot_frame])
display(
    w.VBox(
        [
            w.HTML("<h2>Temporal-shift checker for sample acceptability</h2>"),
            ui,
        ]
    )
)

# Out: VBox(children=(HTML(value='<h2>Temporal-shift checker for sample acceptability</h2>'), HBox(children=(VBox(chi…