In [None]:
# Imports for this cell: dask provides the dataframe, crococamp the config reader
import dask.dataframe as dd
from crococamp.utils.config import read_config

# Read the demo configuration; its "parquet_folder" entry locates the dataset
config = read_config("../demo/config.yaml")

# Open the parquet folder as a dask dataframe and preview its first rows
ddf = dd.read_parquet(config["parquet_folder"])
ddf.head()

In [None]:
##########################################################################
## Imports
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
from datetime import timedelta
##########################################################################

## Widget & Variable Initialization
# State shared between the callbacks below; filled in by update_filtered_df()
# and update_refvar()
filtered_df = None
min_time = None
max_time = None
total_hours = None
plot_var = None

# Output area that receives the rendered map
output = widgets.Output()

# Observation type selector: one entry per distinct value of the "type" column,
# sorted; FLOAT_TEMPERATURE is preselected when present
type_options = ddf["type"].drop_duplicates().compute().sort_values().tolist()
if "FLOAT_TEMPERATURE" in type_options:
    _initial_type = "FLOAT_TEMPERATURE"
else:
    _initial_type = type_options[0]
type_dropdown = widgets.Dropdown(
    description="Observation type:",
    options=type_options,
    value=_initial_type,
)

# Plotted variable selector: every column except the coordinate/metadata ones;
# "residual" is preselected when present
disallowed_plotvars = ["time", "type", "longitude", "latitude", "vertical"]
refvar_options = [
    column
    for column in ddf.columns.to_list()
    if column not in disallowed_plotvars
]
if "residual" in refvar_options:
    _initial_refvar = "residual"
else:
    _initial_refvar = refvar_options[0]
refvar_dropdown = widgets.Dropdown(
    description="Plotted variable",
    options=refvar_options,
    value=_initial_refvar,
)

# Time-window controls: an hour-based slider plus a free-text override that is
# interpreted by parse_window()
default_window_hours = 4 * 7 * 24  # 4 weeks
window_slider = widgets.IntSlider(
    description='Window (hrs):',
    min=1,
    max=default_window_hours,
    value=default_window_hours,
    step=1,
    continuous_update=False,
    style=dict(description_width='initial'),
)
window_text = widgets.Text(
    description='Override window:',
    placeholder='e.g. 4 weeks, 3 days, 48 hours',
    value='',
    style=dict(description_width='initial'),
)

# Center-time slider; it starts with a single placeholder option because a
# SelectionSlider cannot be created empty (avoids a TraitError)
dummy_time = pd.Timestamp('2000-01-01 00:00:00')
center_slider = widgets.SelectionSlider(
    description='Center time:',
    options=[dummy_time],
    value=dummy_time,
    continuous_update=False,
    style=dict(description_width='initial'),
)

# Range slider holding the (vmin, vmax) color limits used by the map
colorbar_slider = widgets.FloatRangeSlider(
    description='Colorbar limits:',
    min=0,
    max=1,
    value=[0, 1],
    step=0.01,
    continuous_update=False,
    style=dict(description_width='initial'),
)

# Map limits
# The map extent is the data's bounding box grown by `padding` degrees. A limit
# that comes within 10 degrees of the edge of the globe is snapped to that edge.
padding = 5

# Wrap longitudes into [-180, 180) *before* taking min/max so that data stored
# in either the 0..360 or the -180..180 convention gives the same bounding box.
wrapped_lon = (ddf['longitude'] + 180) % 360 - 180
lon_min = wrapped_lon.min().compute() - padding
if lon_min < -170:
    lon_min = -180
lon_max = wrapped_lon.max().compute() + padding
if lon_max > 170:
    lon_max = 180
lat_min = ddf['latitude'].min().compute() - padding
if lat_min < -80:
    lat_min = -90
lat_max = ddf['latitude'].max().compute() + padding
if lat_max > 80:
    lat_max = 90

##########################################################################
# -----------------------------------------------------------------------#
# Methods                                                                #
# -----------------------------------------------------------------------#

def parse_window(text):
    """Parse a human-readable window string to a timedelta.

    Any combination of ``<number> weeks``, ``days``, ``hours`` and ``minutes``
    (singular or plural, integer or decimal) is recognized, e.g. ``"4 weeks"``,
    ``"1 day 12 hours"`` or ``"1.5 days"``. Text without any of these units is
    handed to ``pandas.to_timedelta`` as a fallback.

    Arguments:
    text -- free-form window description

    Returns a ``datetime.timedelta`` (a ``pandas.Timedelta`` when the pandas
    fallback was used), or None when the text is empty, cannot be interpreted,
    or describes a duration too large to represent.
    """
    import re
    text = text.strip().lower()
    if not text:
        return None
    # The number group accepts decimals; an integer-only group would read
    # "1.5 days" as "5 days".
    number = r'(\d*\.?\d+)'
    patterns = [
        (number + r'\s*weeks?', 'weeks'),
        (number + r'\s*days?', 'days'),
        (number + r'\s*hours?', 'hours'),
        (number + r'\s*minutes?', 'minutes'),
    ]
    kwargs = {}
    for pat, unit in patterns:
        m = re.search(pat, text)
        if m:
            kwargs[unit] = float(m.group(1))
    if kwargs:
        try:
            return timedelta(**kwargs)
        except OverflowError:
            # Number too large for a timedelta: treat as unparseable
            return None
    try:
        td = pd.to_timedelta(text)
    except (ValueError, TypeError, OverflowError):
        return None
    # pandas turns strings such as "nat" into NaT, which is not a usable window
    if pd.isna(td):
        return None
    return td

#------------------------------------------------------------------------------#
def get_window_timedelta():
    """Return the active time window as a timedelta.

    The free-text override is used whenever parse_window() can interpret it;
    otherwise the window slider's value, in hours, is used.
    """
    override = parse_window(window_text.value)
    if override is not None:
        return override
    return timedelta(hours=window_slider.value)

#------------------------------------------------------------------------------#
def update_refvar(selected_type):
    """Record which variable is plotted.

    Arguments:
    selected_type -- value stored in the global ``plot_var``
    """
    global plot_var
    plot_var = selected_type

#------------------------------------------------------------------------------#
def update_filtered_df(selected_type):
    """Select the rows of one observation type and refresh the time metadata.

    Arguments:
    selected_type -- value of the "type" column to keep

    Updates the globals ``filtered_df`` (persisted dask dataframe), ``min_time``
    and ``max_time`` (extremes of its "time" column) and ``total_hours`` (whole
    hours between the two).
    """
    global filtered_df, min_time, max_time, total_hours
    type_mask = ddf["type"] == selected_type
    filtered_df = ddf[type_mask].persist()
    times = filtered_df['time']
    min_time = times.min().compute()
    max_time = times.max().compute()
    span_seconds = (max_time - min_time).total_seconds()
    total_hours = int(span_seconds // 3600)

#------------------------------------------------------------------------------#
def update_center_slider(window_td):
    """Update center time slider options based on currently filtered data and window.

    The selectable centers run hourly from ``min_time + window_td / 2`` to
    ``max_time - window_td / 2``. The center that was selected before the
    update is kept, snapped to the nearest of the new options.

    Arguments:
    window_td -- window length as a timedelta
    """
    if min_time is None or max_time is None:
        center_slider.options = [dummy_time]
        center_slider.value = dummy_time
        return
    half_window = window_td / 2
    center_min = min_time + half_window
    center_max = max_time - half_window
    if center_min > center_max:
        # Window is longer than the data span: min_time is the only option
        center_slider.options = [min_time]
        center_slider.value = min_time
        return
    options = pd.date_range(center_min, center_max, freq='1h')
    # Capture the selection first: replacing `options` makes the widget
    # re-select its first entry, which would otherwise discard the user's choice.
    previous = center_slider.value
    center_slider.options = options
    if len(options) == 0:
        return
    try:
        nearest = options.get_indexer([previous], method='nearest')[0]
    except (TypeError, ValueError):
        # e.g. previous value not comparable with the new options
        nearest = 0
    target = options[max(nearest, 0)]
    if center_slider.value != target:
        center_slider.value = target

#------------------------------------------------------------------------------#
def _set_colorbar_bounds(lower, upper):
    """Set colorbar_slider.min/max in an order that never leaves min above max."""
    if lower > colorbar_slider.max:
        # Raise the ceiling first, otherwise the new min would exceed the old max
        colorbar_slider.max = upper
        colorbar_slider.min = lower
    else:
        colorbar_slider.min = lower
        colorbar_slider.max = upper


def update_colorbar_slider():
    """Update the colorbar slider limits and step according to the filtered data
    and selected variable. Uses the current time window.

    The slider range is set to the min/max of ``plot_var`` inside the window,
    the step to 1/100 of that range, and the selected limits to the 1st and
    99th percentiles. When no data is loaded nothing happens; when the limits
    cannot be computed (or are NaN, e.g. for an empty window) the slider falls
    back to the range [0, 1].
    """
    if filtered_df is None or plot_var is None:
        return
    half_window = get_window_timedelta() / 2
    t0 = center_slider.value - half_window
    t1 = center_slider.value + half_window
    df_win = filtered_df[
        (filtered_df['time'] >= t0) &
        (filtered_df['time'] <= t1)
    ]
    try:
        column = df_win[plot_var]
        limits = (
            float(column.min().compute()),
            float(column.max().compute()),
            float(column.quantile(0.01).compute()),
            float(column.quantile(0.99).compute()),
        )
    except Exception:
        # Deliberately broad: this is a best-effort UI update and the selected
        # column may not support these reductions at all.
        limits = None
    if limits is None or any(pd.isna(v) for v in limits):
        _set_colorbar_bounds(0, 1)
        colorbar_slider.step = 0.01
        colorbar_slider.value = [0, 1]
        return
    col_min, col_max, q_low, q_high = limits
    span = col_max - col_min
    _set_colorbar_bounds(col_min, col_max)
    colorbar_slider.step = span / 100. if span > 0 else 0.01
    colorbar_slider.value = [q_low, q_high]

# ----------------------------------------------------------------------- #
def plot_map(center, window_td):
    """Plot the reference map showing mean values of plot_var for each location (lat, lon)
    within the selected time window.

    Arguments:
    center     -- center time for the window
    window_td  -- window timedelta

    The figure is rendered into the ``output`` widget, replacing its previous
    content. Uses cartopy for the map and applies the color limits from
    colorbar_slider. When the window contains no data, a global map with a
    "No data in selected window" message is shown instead.
    """
    with output:
        clear_output(wait=True)
        if filtered_df is None or center is None or window_td is None:
            print("No data selected.")
            return
        half_window = window_td / 2
        t0 = center - half_window
        t1 = center + half_window
        df_win = filtered_df[
            (filtered_df['time'] >= t0) &
            (filtered_df['time'] <= t1)
        ]
        # One row per (latitude, longitude) pair holding the mean of plot_var
        ref_df = df_win.groupby(['latitude', 'longitude'])[plot_var].mean().compute().reset_index()
        fig = plt.figure(figsize=(18, 10))
        ax = plt.axes(projection=ccrs.PlateCarree())
        ax.add_feature(cfeature.COASTLINE, linewidth=0.8)
        ax.add_feature(cfeature.LAND, color='lightgray', alpha=0.5)
        ax.add_feature(cfeature.OCEAN, color='lightblue', alpha=0.3)
        ax.add_feature(cfeature.BORDERS, linewidth=0.5)
        if not ref_df.empty:
            vmin, vmax = colorbar_slider.value
            scatter = ax.scatter(
                ref_df['longitude'],
                ref_df['latitude'],
                s=100,
                alpha=0.7,
                c=ref_df[plot_var],
                vmin=vmin,
                vmax=vmax,
                cmap='cividis',
                label=f'Differences (n={len(ref_df):,})',
                marker='o',
                edgecolors='none',
                transform=ccrs.PlateCarree()
            )
            fig.colorbar(scatter, ax=ax)
            ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
            # The legend is only drawn here: without the scatter there is no
            # labelled artist and matplotlib would emit a warning.
            ax.legend(loc='upper left', bbox_to_anchor=(0.02, 0.98))
        else:
            ax.set_global()
            ax.text(
                0.5,
                0.5,
                'No data in selected window',
                ha='center',
                va='center',
                fontsize=20,
                transform=ax.transAxes
            )
        gl = ax.gridlines(
            draw_labels=True,
            linewidth=1,
            color='gray',
            alpha=0.5,
            linestyle='--'
        )
        gl.top_labels = False
        gl.right_labels = False
        ax.set_title(
            f'Obs relative differences\n({len(ref_df):,} points)\nTime window: {window_td}',
            fontsize=16,
            pad=20
        )
        plt.tight_layout()
        plt.show()

# ----------------------------------------------------------------------- #
def on_type_change(change):
    """Handle a new selection in the observation type dropdown.

    Re-filters the data for the new type, limits the window slider to the new
    time span (at least one hour), rebuilds the center and colorbar sliders and
    redraws the map.
    """
    update_filtered_df(change['new'])
    # The window cannot be longer than the data span of the selected type
    window_slider.max = max(total_hours, 1)
    window_slider.value = min(window_slider.value, window_slider.max)
    window_td = get_window_timedelta()
    update_center_slider(window_td)
    update_colorbar_slider()
    plot_map(center_slider.value, window_td)

# ----------------------------------------------------------------------- #
def on_refvar_change(change):
    """Callback for reference variable dropdown change event.
    Updates reference variable, colorbar, and triggers plot update.

    The window and center sliders are left alone: changing the plotted
    variable does not change the filtered data's time span, and rebuilding the
    center slider would discard the currently selected center time.
    """
    update_refvar(change['new'])
    update_colorbar_slider()
    plot_map(center_slider.value, get_window_timedelta())

# ----------------------------------------------------------------------- #
def on_window_change(change):
    """Handle a change of the window slider or of the override text.

    Rebuilds the center slider for the new window length, refreshes the
    colorbar slider and redraws the map.
    """
    td = get_window_timedelta()
    update_center_slider(td)
    update_colorbar_slider()
    plot_map(center_slider.value, td)

# ----------------------------------------------------------------------- #
def on_center_change(change):
    """Handle a new center time.

    Refreshes the colorbar slider for the shifted window, then redraws the map.
    """
    update_colorbar_slider()
    window_td = get_window_timedelta()
    plot_map(center_slider.value, window_td)

# ----------------------------------------------------------------------- #
def on_colorbar_change(change):
    """Handle new colorbar limits by redrawing the map with them."""
    plot_map(center=center_slider.value, window_td=get_window_timedelta())

##########################################################################
## Widget Observers Setup
# Connect every widget's `value` trait to its callback
for _widget, _callback in (
    (refvar_dropdown, on_refvar_change),
    (type_dropdown, on_type_change),
    (window_slider, on_window_change),
    (window_text, on_window_change),
    (center_slider, on_center_change),
    (colorbar_slider, on_colorbar_change),
):
    _widget.observe(_callback, names='value')

## Initial setup and display
# Populate the shared state and sliders from the dropdowns' initial selections
update_refvar(refvar_dropdown.value)
update_filtered_df(type_dropdown.value)
window_slider.max = max(total_hours, 1)
update_center_slider(get_window_timedelta())
update_colorbar_slider()

# Show the controls stacked above the map output, then draw the first map
_controls = [
    refvar_dropdown,
    type_dropdown,
    window_slider,
    window_text,
    center_slider,
    colorbar_slider,
]
display(widgets.VBox(_controls + [output]))
plot_map(center_slider.value, get_window_timedelta())