In [None]:
"""
NOTE: 
- Work in progress.
- Follow up with Jin on this script.
"""

import zipfile
from pathlib import Path

import numpy as np
import pandas as pd
import geopandas as gpd
from sklearn.neighbors import KDTree          # NEW: for past-burn feature
from shapely.geometry import Point
from libpysal.weights import KNN, lag_spatial
from spreg import ML_Lag 

zip_path = Path("ciffc_202389.zip")

# ----------------------------------------------------------------------
# 1) Load CIFFC data from ZIP and keep key columns (same as before)
# ----------------------------------------------------------------------
# Columns kept from every CSV in the archive.  "src_file" is not read from the
# CSV; it is added below so each row can be traced back to its archive member.
keep_cols = [
    "field_agency_fire_id",
    "field_agency_code",
    "field_status_date",
    "field_fire_size",
    "field_latitude",
    "field_longitude",
    "src_file",
]

rows = []
with zipfile.ZipFile(zip_path, "r") as z:
    for info in z.infolist():
        # Directory entries carry no data; handing one to read_csv would abort
        # the whole load, so skip them.
        if info.is_dir():
            continue
        name = info.filename
        with z.open(info) as f:
            part = pd.read_csv(f)
        part["src_file"] = name
        # A member missing any of the key columns raises KeyError here.
        rows.append(part[keep_cols])

# Stack all members into one frame with a fresh 0..N-1 index.
df = pd.concat(rows, ignore_index=True)

# keep 2023 only
# Unparseable dates become NaT (errors="coerce"); NaT fails both comparisons
# below, so those rows are removed by the same filter.
df["field_status_date"] = pd.to_datetime(df["field_status_date"], errors="coerce", utc=True)
df = df[(df["field_status_date"] >= "2023-01-01") & (df["field_status_date"] < "2024-01-01")]

# one record per fire: first time we saw it
# GroupBy.first() returns the first *non-null value of each column*, so it can
# stitch one "record" together from several different reports of the same fire.
# Instead, locate the row with the earliest status date per fire and take that
# row as a whole.  Rows with a missing fire id are excluded by groupby, and the
# result is ordered by fire id, as it was with the previous groupby.
# NOTE(review): fires are keyed on field_agency_fire_id alone; if ids can repeat
# across agencies, the key should probably include field_agency_code -- confirm.
first_seen_idx = df.groupby("field_agency_fire_id")["field_status_date"].idxmin()
fires23 = df.loc[first_seen_idx.to_numpy()].reset_index(drop=True)

# basic cleaning
# A fire without coordinates cannot be placed on the map, so drop it.
fires23 = fires23.dropna(subset=["field_latitude", "field_longitude"])
fires23["date"] = fires23["field_status_date"].dt.date

# ----------------------------------------------------------------------
# 2) GeoDataFrame and projection (same as before)
# ----------------------------------------------------------------------
# Build lon/lat points in EPSG:4326, then reproject to EPSG:3978.
# NOTE(review): this CRS was labelled "NAD83 / Canada Albers" here; the EPSG
# registry name for 3978 appears to be "NAD83 / Canada Atlas Lambert" -- confirm.
fire_points = gpd.points_from_xy(fires23["field_longitude"], fires23["field_latitude"])
gdf = gpd.GeoDataFrame(fires23, geometry=fire_points, crs="EPSG:4326")
gdf = gdf.to_crs(3978)

# this initial KNN is leftover from original code; harmless to keep
# (gdf is rebuilt in the next section and w_init is not used by the model fit
# further down in this script)
coords_list = [(pt.x, pt.y) for pt in gdf.geometry]
w_init = KNN.from_array(coords_list, k=10)   # each fire linked to its 10 nearest fires
w_init.transform = "r"                       # row-standardise the weights

# ----------------------------------------------------------------------
# 3) Rebuild gdf (template from your original code) and basic fields
# ----------------------------------------------------------------------
# Work on a copy of fires23 so the column edits below do not touch it.
lonlat_points = gpd.points_from_xy(fires23["field_longitude"], fires23["field_latitude"])
gdf = gpd.GeoDataFrame(fires23.copy(), geometry=lonlat_points, crs="EPSG:4326")
# Reproject to EPSG:3978.
# NOTE(review): previously labelled "NAD83 / Canada Albers (meters)"; the EPSG
# registry name for 3978 appears to be "NAD83 / Canada Atlas Lambert" -- confirm.
gdf = gdf.to_crs(3978)

# Coerce types and ensure required cols
# Non-numeric sizes and unparseable dates become NaN/NaT and are dropped below.
gdf["field_fire_size"] = pd.to_numeric(gdf["field_fire_size"], errors="coerce")
status_ts = pd.to_datetime(gdf["field_status_date"], errors="coerce")
gdf["month"] = status_ts.dt.month
if "field_agency_code" not in gdf.columns:
    gdf["field_agency_code"] = "UNK"

# Drop unusable rows
required_cols = ["field_fire_size", "month", "geometry"]
gdf = gdf.dropna(subset=required_cols).copy()
# With the missing months gone, the column can be stored as plain integers.
gdf["month"] = gdf["month"].astype(int)

# ----------------------------------------------------------------------
# 3b) NEW: compute "past_burn_same_spot" feature
#     (total area burned before this fire within a small radius)
# ----------------------------------------------------------------------
# Sort by time so "past" is well-defined.  After the sort, "earlier fire" means
# "smaller row position"; a stable sort keeps fires that share a timestamp in
# their incoming order, so that test is reproducible from run to run.
gdf = gdf.sort_values("field_status_date", kind="stable").reset_index(drop=True)

coords = np.column_stack([gdf.geometry.x.values, gdf.geometry.y.values])
dates = pd.to_datetime(gdf["field_status_date"], errors="coerce")
sizes = gdf["field_fire_size"].to_numpy(dtype=float)

same_spot_radius_m = 5000   # 5 km "same spot" radius (in the units of the projected CRS)
max_days_back = 365         # up to 1 year back (within 2023: earlier in year)

tree = KDTree(coords)
past_burn_same_spot = np.zeros(len(gdf))

# One batched radius query for all fires instead of one tree query per row.
# Element i holds the row positions of every fire within the radius of fire i
# (fire i itself included).
neighbours_by_fire = tree.query_radius(coords, r=same_spot_radius_m)

for i, idx_space in enumerate(neighbours_by_fire):
    # only earlier fires: rows sorted before fire i (this also excludes i itself)
    idx_past = idx_space[idx_space < i]
    if idx_past.size == 0:
        continue

    # Whole days between each earlier fire and fire i, as a plain ndarray so it
    # can be used directly as a boolean mask on idx_past.
    dt_days = (dates.iloc[i] - dates.iloc[idx_past]).dt.days.to_numpy()
    mask_time = (dt_days >= 0) & (dt_days <= max_days_back)

    # Sum of an empty selection is 0.0, which matches the initial value.
    past_burn_same_spot[i] = sizes[idx_past[mask_time]].sum()

gdf["past_burn_same_spot"] = past_burn_same_spot

# ----------------------------------------------------------------------
# 4) Spatial weights (KNN) base setup (same idea as before)
# ----------------------------------------------------------------------
# Projected x/y of every remaining fire as an (n, 2) array; the KNN weights
# in the tuning step are built from this array.
point_x = gdf.geometry.x.values
point_y = gdf.geometry.y.values
coords = np.column_stack([point_x, point_y])

n = len(gdf)
# Stop early when there are too few points to fit a spatial model.
if not n >= 10:
    raise RuntimeError(f"Too few observations after cleaning (n={n}). Need more points for a stable spatial model.")

# ----------------------------------------------------------------------
# 5) Design matrix X (constant + months + past_burn_same_spot) & outcome y
#    NOTE: agency dummies are dropped (prof: "variable being agency not good")
# ----------------------------------------------------------------------
# Month as dummies (one base month omitted)
month_dum = pd.get_dummies(gdf["month"].astype(int), prefix="month", drop_first=True)

# Assemble the regressors column-wise, all as floats and all on gdf's index.
# NOTE(review): spreg's ML_Lag documents `x` as excluding the constant and adds
# its own; confirm whether the explicit "const" column here is still wanted.
const_col = pd.Series(1.0, index=gdf.index, name="const")        # constant
month_cols = month_dum.astype(float)                             # months
past_burn_col = gdf[["past_burn_same_spot"]].astype(float)       # past fires in same spot
# NOTE: no agency_dum here
X_df = pd.concat([const_col, month_cols, past_burn_col], axis=1)

X = X_df.to_numpy(dtype=float)

# Outcome: log(1 + size) as an (n, 1) column vector; log1p tames the heavy tail.
fire_size = gdf["field_fire_size"].to_numpy(dtype=float)
y = np.log1p(fire_size).reshape(-1, 1)
y_name = "log1p_fire_size"
name_x = list(X_df.columns)

# Ensure finiteness; rebuild coords if we dropped rows
# A row is usable only when its outcome and every regressor are finite.
mask = np.isfinite(X).all(axis=1) & np.isfinite(y).ravel()
if not mask.all():
    # Apply the same row filter to every aligned object so they stay in step.
    gdf = gdf.loc[mask].copy()
    X, y, coords = X[mask, :], y[mask, :], coords[mask, :]
    n = len(gdf)

# ----------------------------------------------------------------------
# 6) Fit Spatial Lag (SAR) with hyper-parameter tuning over k
# ----------------------------------------------------------------------
k_list = [4, 6, 8, 10, 12, 16, 20, 25, 30]    # simple hyper-parameter grid for k

# Best candidate so far; the model with the lowest AIC wins.
best_aic = np.inf
best_model = None
best_w = None
best_k = None

# On small samples several grid values are clamped to the same effective k;
# remember which ones were fitted so identical models are not refit/reprinted.
fitted_k = set()

for k_try in k_list:
    k_eff = int(min(k_try, max(1, n - 1)))  # safety: at most n-1 neighbours exist
    if k_eff in fitted_k:
        continue
    fitted_k.add(k_eff)

    # k-nearest-neighbour weights on the projected coordinates, row-standardised.
    w_try = KNN.from_array(coords, k=k_eff)
    w_try.transform = "r"

    model_try = ML_Lag(
        y, X, w=w_try,
        name_y=y_name,
        name_x=name_x,
        name_w=f"knn{k_eff}"
    )

    print(f"\n===== k = {k_eff} =====")
    print(model_try.summary)

    aic_try = float(model_try.aic)
    if aic_try < best_aic:
        best_aic = aic_try
        best_model = model_try
        best_w = w_try
        best_k = k_eff

# Fail here with a clear message rather than later on `None.rho`.
if best_model is None:
    raise RuntimeError("No candidate k produced a model with a usable AIC; cannot select k.")

# Use best model/weights so the diagnostics section stays the same pattern
model = best_model
w = best_w
k = best_k

# ----------------------------------------------------------------------
# 7) Quick diagnostics (same idea as your original code)
# ----------------------------------------------------------------------
# Spatial lag of the outcome under the selected weights, stored per fire.
y_flat = y.ravel()
gdf["w_log1p_size"] = lag_spatial(w, y_flat)

def as_scalar(x):
    """Return *x* as a plain Python float.

    *x* may be a number or any array-like holding exactly one value
    (for example a 1x1 array); size-1 axes are squeezed away first.
    """
    single_value = np.squeeze(np.asarray(x))
    return float(single_value)

# Headline results for the selected model.
print("\nRows used:", len(gdf))
print(f"Best k: {k}")

# Model attributes may come back as arrays; as_scalar reduces them to floats.
rho_hat = as_scalar(model.rho)
print(f"Spatial rho (autoreg parameter): {rho_hat:.4f}")

aic_val = as_scalar(model.aic)
loglik_val = as_scalar(model.logll)
print(f"AIC: {aic_val:.2f}  LogLik: {loglik_val:.2f}")


 There are 23 disconnected components.
  W.__init__(self, neighbors, id_order=ids, **kwargs)



===== k = 4 =====
REGRESSION RESULTS
------------------

SUMMARY OF OUTPUT: MAXIMUM LIKELIHOOD SPATIAL LAG (METHOD = FULL)
-----------------------------------------------------------------
Data set            :     unknown
Weights matrix      :        knn4
Dependent Variable  :log1p_fire_size                Number of Observations:        6439
Mean dependent var  :      1.5459                Number of Variables   :          10
S.D. dependent var  :      2.6289                Degrees of Freedom    :        6429
Pseudo R-squared    :      0.2792
Spatial Pseudo R-squared:  0.0616
Log likelihood      : -14541.4014
Sigma-square ML     :      5.0724                Akaike info criterion :   29102.803
S.E of regression   :      2.2522                Schwarz criterion     :   29170.504

------------------------------------------------------------------------------------
            Variable     Coefficient       Std.Error     z-Statistic     Probability
-----------------------------------------

 There are 4 disconnected components.
  W.__init__(self, neighbors, id_order=ids, **kwargs)



===== k = 6 =====
REGRESSION RESULTS
------------------

SUMMARY OF OUTPUT: MAXIMUM LIKELIHOOD SPATIAL LAG (METHOD = FULL)
-----------------------------------------------------------------
Data set            :     unknown
Weights matrix      :        knn6
Dependent Variable  :log1p_fire_size                Number of Observations:        6439
Mean dependent var  :      1.5459                Number of Variables   :          10
S.D. dependent var  :      2.6289                Degrees of Freedom    :        6429
Pseudo R-squared    :      0.3043
Spatial Pseudo R-squared:  0.0609
Log likelihood      : -14429.0200
Sigma-square ML     :      4.8942                Akaike info criterion :   28878.040
S.E of regression   :      2.2123                Schwarz criterion     :   28945.741

------------------------------------------------------------------------------------
            Variable     Coefficient       Std.Error     z-Statistic     Probability
-----------------------------------------