In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
import ipywidgets as widgets
from ipywidgets import Layout
import numpy as np

# Remote CSV sources for the two datasets; the rainfall file is loaded at start-up
rainfall_url = "https://raw.githubusercontent.com/akhi9661/dumps/refs/heads/main/rainfall_bom-nasa_era5_jan-2022_oct-2024.csv"
temperature_url = "https://raw.githubusercontent.com/akhi9661/dumps/refs/heads/main/temp_bom-nasa_era5_jan-2022_oct-2024.csv"
df = pd.read_csv(rainfall_url)
df = df.dropna()  # keep only rows with no missing values

# Helper: Spearman correlation that tolerates missing or insufficient data
def _paired_spearman(x, y):
    """Return Spearman's rho for two index-aligned Series.

    Rows where either value is missing are dropped together, so the two
    samples always stay paired. NaN is returned when fewer than two complete
    pairs remain, since the coefficient is undefined in that case.
    """
    pairs = pd.concat([x, y], axis=1).dropna()
    if len(pairs) < 2:
        return float('nan')
    return spearmanr(pairs.iloc[:, 0], pairs.iloc[:, 1]).correlation


# Function to update the plot based on the widget inputs
def update_plot(site, lag, site_col, var_source, type):
    """Plot BOM values against another source for one site, raw and rolling-summed.

    Two scatter series are drawn on one figure: the original values and their
    rolling sums over ``lag`` consecutive rows. The Spearman correlation of
    each series is written in the top-left corner of the axes.

    Parameters
    ----------
    site
        Value to match in the selected grouping column; only those rows are
        plotted.
    lag : int
        Rolling window length in rows (labelled as days on the plot;
        presumably one row per day in chronological order -- TODO confirm).
    site_col : str
        Key of ``site_column_mapping`` selecting the grouping column.
    var_source : str
        Key of ``rainfall_source_mapping`` / ``tmp_source_mapping`` selecting
        the column compared against BOM.
    type : str
        ``'Rainfall'`` selects the rainfall columns; any other value selects
        the temperature columns. The name shadows the builtin but has to stay,
        because the widget wiring passes it by keyword.

    Notes
    -----
    The rolling sums are computed on a private copy of the selected site's
    rows, so the shared module-level ``df`` is never modified here.
    """
    site_column = site_column_mapping[site_col]

    if type == 'Rainfall':
        var_column = rainfall_source_mapping[var_source]
        bom_column = 'bom_rainfall'
    else:
        var_column = tmp_source_mapping[var_source]
        bom_column = 'tmp_bom'

    # Filter first, then roll: only the selected site is plotted, and the copy
    # keeps the helper columns out of the shared DataFrame.
    subset = df[df[site_column] == site].copy()
    subset['bom_rolling'] = subset[bom_column].rolling(window=lag).sum()
    subset['var_rolling'] = subset[var_column].rolling(window=lag).sum()

    # Create the plot
    plt.figure(figsize=(8, 6))

    # Original scatterplot
    sns.scatterplot(data=subset, x=bom_column, y=var_column, color='red', label='Original', edgecolors='white', linewidth=0.2, s=25, marker='s')

    # Rolling-sum scatterplot
    sns.scatterplot(data=subset, x='bom_rolling', y='var_rolling', color='blue', label=f'Rolling ({lag} days)', edgecolors='white', linewidth=0.2, s=25)

    # Correlations; NaNs (the first lag-1 rolling rows) are removed pairwise
    spearman_orig = _paired_spearman(subset[bom_column], subset[var_column])
    spearman_lagged = _paired_spearman(subset['bom_rolling'], subset['var_rolling'])

    # Display correlations on the plot
    plt.text(0.05, 0.95, f"Spearman Correlation\nOriginal: {spearman_orig:.2f}\nRolling ({lag} days): {spearman_lagged:.2f}",
             transform=plt.gca().transAxes, fontsize=9, verticalalignment='top', bbox=dict(boxstyle="round", alpha=0.05))

    # Add labels, legend, and title
    plt.xlabel(f'BOM {type}')
    plt.ylabel(f'{var_source} {type}')
    plt.title(f'Comparison of {var_source} and BOM {type} at {site}', ha='left', x=0, fontsize=15, fontweight='bold')
    # Legend is anchored outside the axes, to the right of the plot area
    plt.legend(loc='upper right', bbox_to_anchor=(1.35, 0.95), ncol=1, title="Series", labelspacing=0.5)
    plt.tight_layout()
    plt.show()

# Dropdown label -> DataFrame column used to group and filter rows
site_column_mapping = {'Site': 'site', 'BOM Station': 'station'}

# Source label -> rainfall column compared against BOM
rainfall_source_mapping = {'NASA-POWER': 'nasa_rainfall', 'ERA5': 'era5_rainfall'}

# Source label -> temperature column compared against BOM
tmp_source_mapping = {'NASA-POWER': 'tmp_nasa', 'ERA5': 'tmp_era5'}

# Refresh the site/station dropdown after the grouping column (or dataset) changes
def update_site_widget(site_selector_value):
    """Repopulate ``site_widget`` with the distinct values of the chosen column.

    ``site_selector_value`` is a key of ``site_column_mapping``; the mapped
    column of the global ``df`` supplies the new options, and the first of
    them becomes the selected value.
    """
    choices = df[site_column_mapping[site_selector_value]].unique()
    site_widget.options = choices
    site_widget.value = choices[0]  # Default value

# Swap the global dataset when the Rainfall/Temperature toggle changes
def reload_data(type_selected):
    """Load the CSV matching ``type_selected`` into the global ``df``.

    ``'Rainfall'`` loads ``rainfall_url``; any other value loads
    ``temperature_url``. Rows with missing values are dropped, then the
    site/station dropdown is rebuilt from the freshly loaded data.
    """
    global df  # rebinding the module-level DataFrame
    source_url = rainfall_url if type_selected == 'Rainfall' else temperature_url
    df = pd.read_csv(source_url).dropna()

    # Keep the site list in step with the new dataset
    update_site_widget(site_selector.value)

# Widgets
# Each control below supplies one keyword argument of update_plot; the wiring
# is the dict passed to interactive_output at the end of this section.

# Dataset switch: 'Rainfall' or 'Temperature' (no visible description label)
type_selector = widgets.ToggleButtons(
    options=['Rainfall', 'Temperature'],
    value='Rainfall',
    description='',
    style={'description_width': 'initial'},
    layout=Layout(width='300px')
)
# On a value change, reload the matching CSV into the global df and rebuild
# the site list (see reload_data).
# NOTE(review): this observer is registered before interactive_output is
# created; presumably callbacks run in registration order, so the data is
# reloaded before the plot redraws -- confirm against traitlets behaviour.
type_selector.observe(lambda change: reload_data(change.new), names='value')

# Source compared against BOM; the option labels are the keys of
# rainfall_source_mapping and tmp_source_mapping.
var_source_widget = widgets.Dropdown(
    options=['NASA-POWER', 'ERA5'],
    value='NASA-POWER',
    description='Source:',
    style={'description_width': 'initial'},
    layout=Layout(width='300px')
)

# Grouping column selector; the option labels are the keys of
# site_column_mapping.
site_selector = widgets.Dropdown(
    options=['Site', 'BOM Station'],
    value='Site',
    description='Select from:',
    style={'description_width': 'initial'},
    layout=Layout(width='300px')
)
# update_site_widget refers to site_widget, which is only defined below; that
# is fine because the callback runs on a later value change, not here.
site_selector.observe(lambda change: update_site_widget(change.new), names='value')

# Site/station picker, initially filled from the 'site' column, which is the
# column the initial site_selector value 'Site' maps to.
site_widget = widgets.Dropdown(
    options=df['site'].unique(),
    value=df['site'].unique()[0],
    description='Site/Station:',
    style={'description_width': 'initial'},
    layout=Layout(width='300px')
)

# Rolling-window length handed to update_plot as `lag` (1 to 30, starts at 2)
lag_widget = widgets.IntSlider(
    min=1,
    max=30,
    value=2,
    description='Lag:',
    style={'description_width': 'initial'},
    layout=Layout(width='350px')
)

# Output area driven by update_plot. The dict keys must match update_plot's
# parameter names, which is why the parameter called `type` cannot be renamed
# without changing this mapping as well.
interactive_plot = widgets.interactive_output(
    update_plot,
    {'site': site_widget, 'lag': lag_widget, 'site_col': site_selector, 'var_source': var_source_widget, 'type': type_selector}
)

# Layout: dataset toggle on the left, two rows of controls to its right
toggle_column = widgets.VBox(
    children=[type_selector],
    layout=Layout(justify_content='flex-start', align_items='flex-start', gap='20px'),
)
row_1 = widgets.HBox(children=[site_selector, lag_widget])
row_2 = widgets.HBox(children=[site_widget, var_source_widget])
widget_rows = widgets.VBox(children=[row_1, row_2])
final_layout = widgets.HBox(
    children=[toggle_column, widget_rows],
    layout=Layout(justify_content='flex-start', align_items='flex-start', gap='20px'),
)

# Thin horizontal rule separating the controls from the plot
divider = widgets.HTML(value="<hr style='border: 1px solid #ddd; margin-top: 20px;'>")

# Render controls, divider and plot output, in that order
display(final_layout, divider, interactive_plot)