# In [None]:
import sqlite3
import numpy as np
import pandas as pd
from scipy.stats import gaussian_kde
import matplotlib.cm as cm

from ipyleaflet import Map, CircleMarker, LayerGroup, WidgetControl, basemaps
from ipywidgets import Button, Layout, HTML, RadioButtons, IntSlider, Accordion, VBox, HBox
from IPython.display import display, HTML as DispHTML
import matplotlib.colors as mcolors

# Inject CSS tweaks for the map's widget controls.
# NOTE(review): these selectors are not scoped to the map, so e.g. the
# .widget-radio-box, .widget-subarea and .container rules also apply to every
# other widget / element matching them in this notebook.
display(DispHTML("""
<style>
  /* Hide horizontal overflow on widget controls */
  .leaflet-control.widget-control {
    overflow-x: hidden !important;
  }
  /* Ensure top-right controls overlay other elements */
  .leaflet-top.leaflet-right .leaflet-control.widget-control {
    z-index: 999 !important;
  }

  /* Ensure vertical scroll on radio buttons */
  .widget-radio-box {
    overflow-y: auto !important;
    max-height: 300px !important;
  }
  /* Drop padding around widget sub-areas */
  .widget-area .widget-subarea {
    padding: 0 !important;
  }
  /* Full-width, margin-free container (presumably the notebook page container) */
  .container {
    width: 100% !important;
    margin: 0 !important;
  }
</style>
"""))

INITIAL_CENTER = (37.8, -96.0)  # (lat, lon) the map opens on and "Reset View" returns to
INITIAL_ZOOM = 4                # zoom level paired with INITIAL_CENTER
MIN_RADIUS = 5                  # CircleMarker radius for the smallest fire of a year
MAX_RADIUS = 15                 # CircleMarker radius for the largest fire of a year

# Colormap used to color markers by normalized fire size.
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# the `matplotlib.colormaps` registry (3.5+) is the supported replacement.
# Fall back to the old call only on Matplotlib versions without the registry.
try:
    from matplotlib import colormaps as _colormaps
    CMAP = _colormaps["YlOrRd"]
except ImportError:
    CMAP = cm.get_cmap("YlOrRd")

def load_data(csv_path='https://raw.githubusercontent.com/AllenJin0818/heatmap/refs/heads/main/wildfires.csv', n=3000, random_state=None):
    """Load the wildfire CSV and return a random sample of fires that have coordinates.

    Parameters
    ----------
    csv_path : str or file-like
        Anything ``pandas.read_csv`` accepts; defaults to the hosted wildfire CSV.
    n : int
        Number of rows to sample. Rows without LATITUDE/LONGITUDE are dropped
        *before* sampling, so the result has exactly ``n`` rows whenever that
        many located rows exist. If fewer exist, all of them are returned
        instead of raising ``ValueError``.
    random_state : int, numpy Generator, or None
        Forwarded to ``DataFrame.sample`` for reproducible draws. None (the
        default) keeps the original non-deterministic behavior.

    Returns
    -------
    pandas.DataFrame
        Columns FIRE_YEAR, DISCOVERY_DOY, CONT_DOY, FIRE_SIZE, LATITUDE,
        LONGITUDE, DAY_TO_CONT, with a fresh 0..k-1 index.
    """
    cols = ["FIRE_YEAR", "DISCOVERY_DOY", "CONT_DOY", "FIRE_SIZE", "LATITUDE", "LONGITUDE"]
    # Parse only the columns we use. usecols returns them in file order,
    # so re-select to keep the documented column order.
    df = pd.read_csv(csv_path, usecols=cols)[cols]
    # Drop unlocated rows first so they don't eat into the sample size.
    df = df.dropna(subset=["LATITUDE", "LONGITUDE"])
    df = df.sample(n=min(n, len(df)), random_state=random_state)
    # NOTE(review): *_DOY are presumably day-of-year values; a fire contained
    # in a later calendar year than it was discovered yields a negative value here.
    df["DAY_TO_CONT"] = df["CONT_DOY"] - df["DISCOVERY_DOY"]
    return df.reset_index(drop=True)


# The base map: opens on INITIAL_CENTER / INITIAL_ZOOM with zoom limited to 3..10.
m = Map(
    center=INITIAL_CENTER,
    zoom=INITIAL_ZOOM,
    min_zoom=3,
    max_zoom=10,
    basemap=basemaps.CartoDB.Positron,
    scroll_wheel_zoom=True,
    layout=Layout(width='100%', height='90vh'),
)

# Top-left badge; its HTML content is written by update_counter.
counter = HTML()
m.add_control(WidgetControl(widget=counter, position='topleft'))

# Per-year state, keyed by str(FIRE_YEAR) and repopulated by plot_markers:
#   colored_year_layers -> LayerGroup of colored markers (selected year)
#   neutral_year_layers -> LayerGroup of grey markers (look-back years)
#   year_counts         -> number of sampled fires in that year
colored_year_layers, neutral_year_layers, year_counts = {}, {}, {}

def _normalized_sizes(sizes):
    """Min-max scale a FIRE_SIZE Series into a float ndarray in [0, 1].

    Missing sizes count as 0. If all values are equal (or there are none),
    every entry is 0.0, which avoids a 0/0 division.
    """
    values = sizes.fillna(0).to_numpy(dtype=float)
    if values.size == 0:
        return values
    lo, hi = values.min(), values.max()
    if hi == lo:
        return np.zeros_like(values)
    return (values - lo) / (hi - lo)


def plot_markers(df):
    """Rebuild the per-year marker layers from *df* (nothing is drawn here).

    First detaches every previously built year layer from the map. Then, for
    each FIRE_YEAR in *df* (ascending), it builds two LayerGroups with one
    CircleMarker per row:
      * a colored group (fill from CMAP, opacity 0.9) for the selected year;
      * a grey, translucent group (opacity 0.2) for look-back years.
    Marker radius (MIN_RADIUS..MAX_RADIUS) and color both follow FIRE_SIZE,
    min-max normalized within that year.

    Side effects: clears and repopulates the module-level dicts
    colored_year_layers, neutral_year_layers and year_counts, keyed by str(year).
    """
    for grp in [*colored_year_layers.values(), *neutral_year_layers.values()]:
        # remove_layer raises for a layer that isn't currently on the map
        # (most year layers aren't). Check membership explicitly instead of a
        # bare `except: pass`, which would also hide real errors and Ctrl-C.
        if grp in m.layers:
            m.remove_layer(grp)
    colored_year_layers.clear()
    neutral_year_layers.clear()
    year_counts.clear()

    # groupby sorts keys by default, so years come out ascending, and it
    # avoids a full boolean mask over df per year.
    for year, subset in df.groupby("FIRE_YEAR"):
        key = str(year)
        year_counts[key] = len(subset)

        snorm = _normalized_sizes(subset["FIRE_SIZE"])
        radii = MIN_RADIUS + snorm * (MAX_RADIUS - MIN_RADIUS)
        # One vectorized colormap call -> (k, 4) RGBA rows -> hex strings.
        colors = [mcolors.to_hex(rgba) for rgba in CMAP(snorm)]

        grp_col = LayerGroup(name=key)
        grp_neu = LayerGroup(name=key)
        coords = subset[["LATITUDE", "LONGITUDE"]].to_numpy()
        for (lat, lon), r, color in zip(coords, radii, colors):
            grp_col.add_layer(CircleMarker(
                location=(lat, lon), radius=int(r),
                fill=True, fill_color=color,
                fill_opacity=0.9, stroke=False
            ))
            grp_neu.add_layer(CircleMarker(
                location=(lat, lon), radius=int(r),
                fill=True, fill_color="#888888",
                fill_opacity=0.2, stroke=False
            ))

        colored_year_layers[key] = grp_col
        neutral_year_layers[key] = grp_neu

# Initial sample: build every year's layers up front (update_layers draws them).
df0 = load_data()
plot_markers(df0)
years = sorted(df0["FIRE_YEAR"].unique())

# Year picker: radio buttons in a fixed-size scroll box, wrapped in an
# accordion that starts collapsed. The earliest sampled year is preselected.
year_selector = RadioButtons(
    options=years,
    value=years[0],
    style={'description_width': '0px'}
)
year_selector.layout = Layout(
    height='290px',
    overflow='auto',
    width='180px'
)

scroll_box = VBox([year_selector], layout=Layout(
    width='200px',
    height='300px',
    # Two-value shorthand: overflow-x hidden, overflow-y auto. The separate
    # overflow_x / overflow_y Layout traits were removed in ipywidgets 8, where
    # passing them is not recognized and the settings are dropped.
    overflow='hidden auto',
    border='1px solid lightgray',
    padding='2px'
))

year_acc = Accordion(children=[scroll_box], layout=Layout(width='200px'))
year_acc.set_title(0, 'Year')
year_acc.selected_index = None  # start collapsed

# "Years back" slider: how many earlier years to overlay in grey. Capped at 5
# and at the number of earlier years present in the sample.
years_back_slider = IntSlider(
    value=0, min=0, max=min(5, len(years)-1), step=1,
    description='', readout=False,
    style={'description_width': '0px'},
    layout=Layout(width='150px')
)
back_label = HTML("<div style='text-align:center; margin-top:4px'><b>Years Back: 0</b></div>")

slider_box = VBox([
    HBox([years_back_slider], layout=Layout(justify_content='center')),
    back_label
])
control_box = VBox([year_acc, slider_box], layout=Layout(width='200px'))
m.add_control(WidgetControl(widget=control_box, position='topright'))

def update_counter():
    """Rewrite the 'Total Fires' badge for the current selection.

    Adds up year_counts for the selected year and the years_back_slider.value
    years before it (years without data contribute 0), then renders the sum
    into the `counter` HTML widget.
    """
    selected_year = int(year_selector.value)
    n_back = years_back_slider.value
    shown_years = range(selected_year - n_back, selected_year + 1)
    total = sum(year_counts.get(str(y), 0) for y in shown_years)

    badge_style = (
        "background:white; padding:6px 10px; "
        "border-radius:4px; box-shadow:0 1px 3px rgba(0,0,0,0.2); "
        "font-size:14px;"
    )
    counter.value = f"<div style='{badge_style}'><b>Total Fires:</b> {total}</div>"

def update_layers(*_):
    """Redraw the map for the current year / years-back selection.

    Registered as a traitlets observer (the change dict arrives in ``*_`` and
    is ignored) and also called directly. Detaches every year layer, adds the
    colored layer of the selected year, adds the grey layer of each of the
    previous ``years_back_slider.value`` years that has data, then refreshes
    the fire counter and the "Years Back" label.
    """
    for grp in [*colored_year_layers.values(), *neutral_year_layers.values()]:
        # remove_layer raises for layers that aren't on the map; check
        # membership instead of a bare `except: pass`, which would also
        # swallow genuine errors and KeyboardInterrupt.
        if grp in m.layers:
            m.remove_layer(grp)

    selected = year_selector.value
    if selected is None:
        # Nothing selected (e.g. the options list is empty): leave the map cleared.
        return

    base = str(selected)
    base_layer = colored_year_layers.get(base)
    if base_layer is not None:
        m.add_layer(base_layer)
    for i in range(1, years_back_slider.value + 1):
        prev = str(int(base) - i)
        if prev in neutral_year_layers:
            m.add_layer(neutral_year_layers[prev])

    update_counter()
    back_label.value = f"<div style='text-align:center; margin-top:4px'><b>Years Back: {years_back_slider.value}</b></div>"

# Redraw whenever either control's value changes.
for _control in (year_selector, years_back_slider):
    _control.observe(update_layers, names='value')

update_layers()  # initial draw: first year, zero years back

btn = Button(
    description="New Random Sample",
    layout=Layout(width='180px'),
)
def on_random(_):
    """'New Random Sample' button callback.

    Draws a fresh sample with load_data(), rebuilds the year layers, points
    the year selector at the new sample's years (earliest selected), resets
    the look-back slider to 0 with max capped at 5 or the number of earlier
    years, and redraws the map.
    """
    df_new = load_data()
    plot_markers(df_new)  # also detaches the old layers from the map
    new_years = sorted(df_new["FIRE_YEAR"].unique())
    year_selector.options = new_years
    year_selector.value   = new_years[0]
    years_back_slider.max = min(5, len(new_years)-1)
    years_back_slider.value = 0
    # The assignments above trigger update_layers only when a value actually
    # changes. If the new sample starts in the same year and the slider was
    # already 0, no observer fires and the map would stay empty (plot_markers
    # removed every layer), so always redraw explicitly.
    update_layers()
btn.on_click(on_random)
m.add_control(WidgetControl(widget=btn, position='bottomleft'))


def _reset_view(_):
    """'Reset View' button callback: restore INITIAL_CENTER and INITIAL_ZOOM."""
    # A named function instead of a lambda that builds a throwaway tuple of
    # setattr() calls; the assignment order (center, then zoom) is unchanged.
    m.center = INITIAL_CENTER
    m.zoom = INITIAL_ZOOM


reset_btn = Button(description="Reset View", layout=Layout(width='180px'))
reset_btn.on_click(_reset_view)
m.add_control(WidgetControl(widget=reset_btn, position='bottomleft'))

display(m)