In [4]:
import os
import glob
from datetime import date, timedelta
import json

import numpy as np
import pandas as pd
import rasterio
from rasterio.transform import rowcol
from pyproj import Transformer

from sklearn.linear_model import LogisticRegression

# ---------------- CONFIG ----------------

FWI_DIR   = "fwi/fwi_tifs_2025_09_12_2025_10_10/"
ECMWF_DIR = "ecmwf/ecmwf_on_fwi_grid/"
CIFFC_FILE = "ciffc_wildfires_20251107.csv"  # CSV or JSON supported (chosen by the ".csv" extension)

# Time config
SEQ_LEN          = 7                   # lookback window length in days (window ends on the target day)
DATA_START_DATE  = date(2025, 9, 12)   # index 0 in FWI/ECMWF series
PRED_START_DATE  = date(2025, 10, 2)
PRED_END_DATE    = date(2025, 10, 8)   # inclusive

# Train on historical days only (avoid peeking into prediction horizon)
TRAIN_END_DATE   = PRED_START_DATE - timedelta(days=1)

# Output folder is derived from the prediction range so it cannot go stale when
# the dates above change (currently "outputs/logreg_fire_prob_20251002_1008").
OUT_DIR = f"outputs/logreg_fire_prob_{PRED_START_DATE:%Y%m%d}_{PRED_END_DATE:%m%d}"

# CIFFC fields (UPDATE TO MATCH YOUR FILE)
CIFFC_DATE_FIELD = "field_situation_report_date"
CIFFC_LAT_FIELD  = "field_latitude"
CIFFC_LON_FIELD  = "field_longitude"
CIFFC_CRS = "EPSG:4326"  # geographic lat/lon; reprojected with always_xy=True, i.e. (lon, lat) order

# Logistic regression sampling (to keep the dataset manageable)
N_SAMPLES_TOTAL = 200_000      # approximate total sampled cells across all training days
RANDOM_SEED = 42

# Normalization constants used for features
FWI_NORM_DENOM = 30.0          # typical FWI scaling for normalization
DRYNESS_MIN_K = 5.0            # dew-point depression (T2m - Td2m) at or below 5 K maps to 0
DRYNESS_RANGE_K = 15.0         # 5-20 K mapped linearly into [0, 1]


In [5]:
def read_stack_from_tifs(tif_paths):
    """Read multi-band GeoTIFFs and stack them along a new leading time axis.

    Args:
        tif_paths: iterable of raster paths, one per time step, in time order.

    Returns:
        ndarray of shape [T, C, H, W], where C is the band count of each file.

    Raises:
        ValueError: from ``np.stack`` if ``tif_paths`` is empty or the rasters
            do not all share the same [C, H, W] shape.
    """
    def _all_bands(path):
        # Each file is opened, fully read as [C, H, W], and closed before the next one.
        with rasterio.open(path) as dataset:
            return dataset.read()

    return np.stack([_all_bands(path) for path in tif_paths], axis=0)


def read_singleband_stack(tif_paths):
    """Read the first band of each GeoTIFF and stack the bands along time.

    Args:
        tif_paths: iterable of raster paths, one per time step, in time order.

    Returns:
        ndarray of shape [T, H, W] (band 1 of every file).

    Raises:
        ValueError: from ``np.stack`` if ``tif_paths`` is empty or the bands
            differ in shape.
    """
    def _first_band(path):
        with rasterio.open(path) as dataset:
            return dataset.read(1)  # [H, W]

    layers = list(map(_first_band, tif_paths))
    return np.stack(layers, axis=0)


def save_probability_map(prob_2d, ref_tif, out_path):
    """Save a 2-D probability map as a single-band float32 GeoTIFF.

    The raster profile of ``ref_tif`` is copied, so the output keeps its grid and
    georeferencing. Driver, band count, dtype and compression are then overridden.

    Args:
        prob_2d: array-like [H, W] of values in [0, 1]. H and W should match ref_tif.
        ref_tif: path of the raster whose profile/georeference is reused.
        out_path: destination path. Its parent directory is created if it is
            missing. A bare filename writes to the current directory.
    """
    out_dir = os.path.dirname(out_path)
    # dirname() is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when the path names one.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with rasterio.open(ref_tif) as src:
        profile = src.profile.copy()
    profile.update(
        driver="GTiff",
        count=1,
        dtype="float32",
        compress="lzw",
    )
    with rasterio.open(out_path, "w", **profile) as dst:
        # np.asarray also accepts nested lists; a float32 array passes through without a copy.
        dst.write(np.asarray(prob_2d, dtype=np.float32), 1)


def load_ciffc_by_day(path: str):
    """
    Load a CIFFC file (CSV or JSON) and return {date -> DataFrame(rows for that date)}.

    - CSV (path ends with ".csv", case-insensitive): uses pandas.read_csv
    - anything else is parsed as JSON: either {"rows": [...]} or a list of row dicts

    Rows are dropped, and their number printed, if:
      - CIFFC_DATE_FIELD cannot be parsed as a date, or
      - CIFFC_LAT_FIELD / CIFFC_LON_FIELD cannot be converted to numbers.
    This keeps one malformed record from aborting the whole load. In the returned
    frames the latitude/longitude columns are numeric (float), and an extra
    "ciffc_date" column holds the datetime.date used as the grouping key.

    Raises:
        ValueError: unsupported JSON layout, or a required CIFFC column is missing.
    """
    if path.lower().endswith(".csv"):
        df = pd.read_csv(path)
    else:
        with open(path, "r", encoding="utf-8", errors="ignore") as f:
            data = json.load(f)
        if isinstance(data, dict) and "rows" in data:
            df = pd.DataFrame(data["rows"])
        elif isinstance(data, list):
            df = pd.DataFrame(data)
        else:
            raise ValueError("Unsupported CIFFC JSON format. Expected dict with 'rows' or list of rows.")

    for col in [CIFFC_DATE_FIELD, CIFFC_LAT_FIELD, CIFFC_LON_FIELD]:
        if col not in df.columns:
            raise ValueError(
                f"Expected CIFFC field '{col}' not found. "
                f"Available columns: {list(df.columns)}"
            )

    # Coerce instead of raising: bad values become NaT/NaN and are filtered below.
    # JSON exports in particular may store coordinates as strings.
    dates = pd.to_datetime(df[CIFFC_DATE_FIELD], errors="coerce")
    lats = pd.to_numeric(df[CIFFC_LAT_FIELD], errors="coerce")
    lons = pd.to_numeric(df[CIFFC_LON_FIELD], errors="coerce")
    valid = dates.notna() & lats.notna() & lons.notna()

    n_dropped = int((~valid).sum())
    if n_dropped:
        print(
            f"load_ciffc_by_day: dropped {n_dropped} of {len(df)} rows "
            "with unparseable date or coordinates."
        )

    df = df.loc[valid].assign(
        **{
            CIFFC_LAT_FIELD: lats[valid],
            CIFFC_LON_FIELD: lons[valid],
            "ciffc_date": dates[valid].dt.date,
        }
    )

    grouped = {d: g.reset_index(drop=True) for d, g in df.groupby("ciffc_date")}
    return grouped


def rasterize_ciffc_per_day(ciffc_by_day, date_list, ref_tif_path, ciffc_crs=CIFFC_CRS):
    """
    Produce [N_days, H, W] CIFFC rasters aligned to the grid of ref_tif_path.

    Cell [k, row, col] = 1.0 if any fire record dated date_list[k] falls in grid
    cell (row, col), else 0.0. Points are reprojected from ciffc_crs to the
    reference raster's CRS in (lon, lat) / (x, y) axis order (always_xy=True).

    Some records are skipped instead of crashing the whole rasterization:
      - records whose coordinates are missing or non-numeric;
      - records that do not reproject to finite values (NaN/inf cannot be turned
        into a row/col index).
    Points that fall outside the grid are ignored.

    Args:
        ciffc_by_day: {datetime.date -> DataFrame} with CIFFC_LON_FIELD / CIFFC_LAT_FIELD columns.
        date_list: dates to rasterize, one output slice per entry.
        ref_tif_path: raster that defines the grid (height, width, transform, CRS).
        ciffc_crs: CRS of the CIFFC coordinates.

    Returns:
        float32 ndarray [len(date_list), H, W] of 0/1 values.
    """
    with rasterio.open(ref_tif_path) as src:
        H = src.height
        W = src.width
        transform = src.transform
        fwi_crs = src.crs

    transformer = Transformer.from_crs(ciffc_crs, fwi_crs, always_xy=True)

    out = np.zeros((len(date_list), H, W), dtype=np.float32)

    for idx, d in enumerate(date_list):
        g = ciffc_by_day.get(d)
        if g is None:
            continue

        lons = pd.to_numeric(g[CIFFC_LON_FIELD], errors="coerce").to_numpy(dtype=np.float64)
        lats = pd.to_numeric(g[CIFFC_LAT_FIELD], errors="coerce").to_numpy(dtype=np.float64)
        has_coords = np.isfinite(lons) & np.isfinite(lats)
        if not has_coords.any():
            continue

        xs, ys = transformer.transform(lons[has_coords], lats[has_coords])
        xs = np.atleast_1d(np.asarray(xs, dtype=np.float64))
        ys = np.atleast_1d(np.asarray(ys, dtype=np.float64))
        projected_ok = np.isfinite(xs) & np.isfinite(ys)

        for x, y in zip(xs[projected_ok], ys[projected_ok]):
            row, col = rowcol(transform, x, y)
            if 0 <= row < H and 0 <= col < W:
                out[idx, row, col] = 1.0

    return out  # [N_days, H, W]


def window_feature_maps(
    fwi_window,     # [T, C_fwi, H, W]
    d2m_window,     # [T, H, W]
    t2m_window,     # [T, H, W]
    ciffc_window,   # [T, H, W]
    fwi_band_index=0,
    include_last_day_fire=False,
):
    """
    Compute per-cell features from a T-day window whose LAST day is the target day.

      - fwi_max_norm: max FWI over the window divided by FWI_NORM_DENOM,
                      clipped to [0, 5] to cap extreme outliers
      - dryness_norm: max (T2m - Td2m) over the window, mapped linearly from
                      [DRYNESS_MIN_K, DRYNESS_MIN_K + DRYNESS_RANGE_K] into [0, 1]
      - recent_fire:  1 if any fire was recorded on the lookback days BEFORE the
                      last window day, else 0

    In this notebook, the CIFFC raster of the last window day is also the label.
    That holds both for training (label = ciffc_daily[idx_end]) and for
    prediction (target_date = idx_end). Including that day in recent_fire would
    make the feature >= the label for every sample, i.e. target leakage.
    Set include_last_day_fire=True only if you deliberately want same-day fire
    observations as an input.

    Non-finite FWI values and dew-point depressions are replaced by 0 before
    taking the max.

    Args:
        fwi_window: [T, C_fwi, H, W] FWI stack.
        d2m_window, t2m_window: [T, H, W] dew-point / 2 m temperature stacks
            (in the same units, since they are subtracted).
        ciffc_window: [T, H, W] daily fire-occurrence rasters (> 0 means fire).
        fwi_band_index: band of fwi_window to use.
        include_last_day_fire: also use fires on the last window day for recent_fire.

    Returns:
        Three [H, W] float32 arrays: fwi_max_norm, dryness_norm, recent_fire.

    Raises:
        ValueError: if fwi_window is not 4-D or fwi_band_index is out of range.
    """
    _, n_bands, _, _ = fwi_window.shape
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not -n_bands <= fwi_band_index < n_bands:
        raise ValueError("FWI band index out of range.")

    # FWI max
    fwi_band = fwi_window[:, fwi_band_index, :, :].astype(np.float32)
    fwi_band = np.nan_to_num(fwi_band, nan=0.0, posinf=0.0, neginf=0.0)
    fwi_max = np.max(fwi_band, axis=0)  # [H,W]
    fwi_max_norm = np.clip(fwi_max / float(FWI_NORM_DENOM), 0.0, 5.0)  # cap extreme outliers

    # Dryness max (T2m - Td2m)
    dew_dep = (t2m_window - d2m_window).astype(np.float32)
    dew_dep = np.nan_to_num(dew_dep, nan=0.0, posinf=0.0, neginf=0.0)
    dew_dep_max = np.max(dew_dep, axis=0)
    dryness_norm = (dew_dep_max - float(DRYNESS_MIN_K)) / float(DRYNESS_RANGE_K)
    dryness_norm = np.clip(dryness_norm, 0.0, 1.0)

    # Recent fire: lookback days only (the last day is the target/label day by default).
    fire_days = ciffc_window if include_last_day_fire else ciffc_window[:-1]
    if fire_days.shape[0] == 0:
        # A 1-day window has no lookback days, so no recent fire can be known.
        recent_fire = np.zeros(ciffc_window.shape[1:], dtype=np.float32)
    else:
        recent_fire = (np.max(fire_days, axis=0) > 0).astype(np.float32)

    return fwi_max_norm.astype(np.float32), dryness_norm.astype(np.float32), recent_fire.astype(np.float32)


def fit_logistic_regression_from_samples(
    fwi_paths, d2m_paths, t2m_paths, ciffc_daily, date_list,
    train_end_date, seq_len=7, fwi_band_index=0,
    n_samples_total=200_000, random_seed=42,
):
    """
    Build a sampled training dataset and fit sklearn LogisticRegression.

    Label definition:
      y_{t, i} = 1 if CIFFC indicates a fire in cell i on the target day t.

    Features:
      x_{t, i} derived from the seq_len-day lookback window ending at t.
      Columns, in order: fwi_max_norm, dryness_norm, recent_fire, as returned
      by window_feature_maps.

    Args:
      fwi_paths, d2m_paths, t2m_paths: per-day raster paths. Position k is
          assumed to correspond to date_list[k].
      ciffc_daily: [N_days, H, W] fire-occurrence rasters aligned with date_list.
      date_list: assumed consecutive daily dates. date_list[0] is index 0 of
          every per-day series.
      train_end_date: latest target day allowed in training (inclusive).
      seq_len: lookback window length in days.
      fwi_band_index: band of the FWI rasters to use.
      n_samples_total: approximate total number of sampled cells over all days.
      random_seed: seed for the sampling RNG.

    Returns:
      The fitted LogisticRegression.

    Raises:
      ValueError: if no target day has a full window on or before train_end_date.
    """
    rng = np.random.default_rng(random_seed)

    # Training target days must have a full seq_len window inside the series and
    # be <= train_end_date. Offsets are taken from date_list[0] (the series start
    # this function was given) rather than a global start date, so the indices
    # always agree with date_list.
    if len(date_list) > 0:
        idx_last = min((train_end_date - date_list[0]).days, len(date_list) - 1)
    else:
        idx_last = -1
    idx_first = seq_len - 1
    train_indices = list(range(idx_first, idx_last + 1))
    if not train_indices:
        raise ValueError("No valid training days. Check TRAIN_END_DATE and DATA_START_DATE/SEQ_LEN.")

    n_days = len(train_indices)
    per_day = max(1, n_samples_total // n_days)

    X_list = []
    y_list = []

    for idx_end in train_indices:
        idx_start = idx_end - (seq_len - 1)

        # read windows
        fwi_window = read_stack_from_tifs(fwi_paths[idx_start:idx_end + 1])
        d2m_window = read_singleband_stack(d2m_paths[idx_start:idx_end + 1])
        t2m_window = read_singleband_stack(t2m_paths[idx_start:idx_end + 1])
        ciffc_window = ciffc_daily[idx_start:idx_end + 1, :, :]

        fwi_norm, dry_norm, recent_fire = window_feature_maps(
            fwi_window, d2m_window, t2m_window, ciffc_window, fwi_band_index=fwi_band_index
        )

        # label is fire on the target day idx_end
        y_map = ciffc_daily[idx_end, :, :].astype(np.int32)

        fwi_flat = fwi_norm.ravel()
        dry_flat = dry_norm.ravel()
        fire_flat = recent_fire.ravel()
        y_flat = y_map.ravel()

        pos = np.flatnonzero(y_flat == 1)
        neg = np.flatnonzero(y_flat == 0)

        # Aim for ~50/50 per day when possible (class imbalance is severe otherwise).
        # Fire cells are usually far fewer than per_day // 2, so negatives fill the rest.
        n_pos = min(len(pos), per_day // 2)
        n_neg = per_day - n_pos

        if n_pos > 0:
            pos_sel = rng.choice(pos, size=n_pos, replace=False)
        else:
            pos_sel = np.array([], dtype=np.int64)

        # Negatives are drawn without replacement when enough exist, otherwise with
        # replacement. A day with no negative cells at all contributes positives
        # only, because rng.choice on an empty array raises.
        if len(neg) > 0:
            neg_sel = rng.choice(neg, size=n_neg, replace=len(neg) < n_neg)
        else:
            neg_sel = np.array([], dtype=np.int64)

        sel = np.concatenate([pos_sel, neg_sel])
        rng.shuffle(sel)

        X = np.column_stack([fwi_flat[sel], dry_flat[sel], fire_flat[sel]]).astype(np.float32)
        y = y_flat[sel].astype(np.int32)

        X_list.append(X)
        y_list.append(y)

        d = date_list[idx_end]
        print(f"Training samples for {d}: pos={int(y.sum())}, neg={len(y)-int(y.sum())}")

    X_all = np.vstack(X_list)
    y_all = np.concatenate(y_list)

    print(f"Total training samples: {len(y_all)} (pos={int(y_all.sum())}, neg={len(y_all)-int(y_all.sum())})")

    # NOTE(review): class_weight="balanced" reweights the classes inversely to
    # their frequency in y_all. The predicted probabilities therefore reflect a
    # balanced prior, not the observed fire rate. Confirm this is intended before
    # reading the output maps as calibrated probabilities.
    model = LogisticRegression(
        solver="lbfgs",
        max_iter=300,
        class_weight="balanced",
    )
    model.fit(X_all, y_all)

    return model


In [6]:
# 1) Locate input rasters. Paths are sorted by filename; this assumes filename
#    order is chronological order, with one raster per day.
fwi_paths = sorted(glob.glob(os.path.join(FWI_DIR, "*.tif")))
if len(fwi_paths) == 0:
    raise RuntimeError(f"No FWI tifs found in {FWI_DIR}")
num_days = len(fwi_paths)
print(f"Found {num_days} FWI rasters.")

# ECMWF dew-point (d2m) and 2 m temperature (t2m) rasters, one glob pattern each
d2m_paths, t2m_paths = (
    sorted(glob.glob(os.path.join(ECMWF_DIR, pattern)))
    for pattern in ("*d2m*.tif", "*t2m*.tif")
)
if not d2m_paths:
    raise RuntimeError(f"No ECMWF d2m (dewpoint) tifs found in {ECMWF_DIR}")
if not t2m_paths:
    raise RuntimeError(f"No ECMWF t2m (temperature) tifs found in {ECMWF_DIR}")
# All three series must hold exactly one raster per day: the three counts must be equal.
if len({num_days, len(d2m_paths), len(t2m_paths)}) != 1:
    raise RuntimeError(
        f"Mismatch in daily counts: FWI={num_days}, d2m={len(d2m_paths)}, t2m={len(t2m_paths)}. "
        "They must match and be time-aligned."
    )
print(f"Found {len(d2m_paths)} d2m rasters and {len(t2m_paths)} t2m rasters.")

# 2) Build date index: position k of every per-day series is DATA_START_DATE + k days
date_list = [DATA_START_DATE + timedelta(days=offset) for offset in range(num_days)]

# 3) Load CIFFC points and rasterize onto the FWI grid (daily)
print(f"Loading CIFFC file: {CIFFC_FILE}")
ciffc_by_day = load_ciffc_by_day(CIFFC_FILE)
print(f"CIFFC records for {len(ciffc_by_day)} unique dates.")

ciffc_daily = rasterize_ciffc_per_day(
    ciffc_by_day=ciffc_by_day,
    date_list=date_list,
    ref_tif_path=fwi_paths[0],
)

# 4) Grid info, taken from the first FWI raster
with rasterio.open(fwi_paths[0]) as src:
    C_fwi, H, W = src.count, src.height, src.width
print(f"FWI channels: {C_fwi}, grid size: H={H}, W={W}")

os.makedirs(OUT_DIR, exist_ok=True)

# 5) Fit logistic regression model from sampled training data
logreg_model = fit_logistic_regression_from_samples(
    fwi_paths=fwi_paths,
    d2m_paths=d2m_paths,
    t2m_paths=t2m_paths,
    ciffc_daily=ciffc_daily,
    date_list=date_list,
    train_end_date=TRAIN_END_DATE,
    seq_len=SEQ_LEN,
    fwi_band_index=0,
    n_samples_total=N_SAMPLES_TOTAL,
    random_seed=RANDOM_SEED,
)

print("Fitted LogisticRegression coefficients:")
print("  intercept:", logreg_model.intercept_)
print("  coef:", logreg_model.coef_)


Found 29 FWI rasters.
Found 29 d2m rasters and 29 t2m rasters.
Loading CIFFC file: ciffc_wildfires_20251107.json
CIFFC records for 250 unique dates.
FWI channels: 1, grid size: H=2281, W=2709
Training samples for 2025-09-18: pos=25, neg=14260
Training samples for 2025-09-19: pos=13, neg=14272
Training samples for 2025-09-20: pos=37, neg=14248
Training samples for 2025-09-21: pos=20, neg=14265
Training samples for 2025-09-22: pos=17, neg=14268
Training samples for 2025-09-23: pos=20, neg=14265
Training samples for 2025-09-24: pos=14, neg=14271
Training samples for 2025-09-25: pos=16, neg=14269
Training samples for 2025-09-26: pos=10, neg=14275
Training samples for 2025-09-27: pos=21, neg=14264
Training samples for 2025-09-28: pos=25, neg=14260
Training samples for 2025-09-29: pos=22, neg=14263
Training samples for 2025-09-30: pos=15, neg=14270
Training samples for 2025-10-01: pos=23, neg=14262
Total training samples: 199990 (pos=278, neg=199712)
Fitted LogisticRegression coefficients:
 

In [7]:
# 6) Build the inclusive list of target dates PRED_START_DATE..PRED_END_DATE
target_dates = [
    PRED_START_DATE + timedelta(days=k)
    for k in range((PRED_END_DATE - PRED_START_DATE).days + 1)
]

print("Target dates:", target_dates)

# 7) Loop over target dates and generate probability maps
chunk_size = 250_000  # prediction chunk size for memory control

for target_date in target_dates:
    # Window = SEQ_LEN consecutive daily indices ending on the target day.
    idx_end = (target_date - DATA_START_DATE).days
    idx_start = idx_end - SEQ_LEN + 1

    window_in_range = idx_start >= 0 and idx_end < num_days
    if not window_in_range:
        print(f"Skipping {target_date}: window [{idx_start}, {idx_end}] out of range.")
        continue

    print(f"{target_date}: using window indices [{idx_start}, {idx_end}]")

    # Read the window (FWI, then d2m, then t2m) and compute feature maps
    days = slice(idx_start, idx_end + 1)
    fwi_norm, dry_norm, recent_fire = window_feature_maps(
        fwi_window=read_stack_from_tifs(fwi_paths[days]),
        d2m_window=read_singleband_stack(d2m_paths[days]),
        t2m_window=read_singleband_stack(t2m_paths[days]),
        ciffc_window=ciffc_daily[days],
        fwi_band_index=0,
    )

    # Predict P(fire) cell by cell, in chunks to bound the feature-matrix size
    feature_columns = (fwi_norm.ravel(), dry_norm.ravel(), recent_fire.ravel())
    n_cells = feature_columns[0].shape[0]

    prob_flat = np.empty(n_cells, dtype=np.float32)
    for s in range(0, n_cells, chunk_size):
        rows = slice(s, min(n_cells, s + chunk_size))
        X_chunk = np.column_stack([column[rows] for column in feature_columns]).astype(np.float32)
        prob_flat[rows] = logreg_model.predict_proba(X_chunk)[:, 1].astype(np.float32)

    prob = prob_flat.reshape((H, W))

    # Save GeoTIFF aligned to the FWI raster of the target day
    out_path = os.path.join(OUT_DIR, f"fire_prob_{target_date.strftime('%Y%m%d')}.tif")
    save_probability_map(prob, fwi_paths[idx_end], out_path)

    print(f"Saved {out_path}")


Target dates: [datetime.date(2025, 10, 2), datetime.date(2025, 10, 3), datetime.date(2025, 10, 4), datetime.date(2025, 10, 5), datetime.date(2025, 10, 6), datetime.date(2025, 10, 7), datetime.date(2025, 10, 8)]
2025-10-02: using window indices [14, 20]
Saved outputs/logreg_fire_prob_20251002_1008/fire_prob_20251002.tif
2025-10-03: using window indices [15, 21]
Saved outputs/logreg_fire_prob_20251002_1008/fire_prob_20251003.tif
2025-10-04: using window indices [16, 22]
Saved outputs/logreg_fire_prob_20251002_1008/fire_prob_20251004.tif
2025-10-05: using window indices [17, 23]
Saved outputs/logreg_fire_prob_20251002_1008/fire_prob_20251005.tif
2025-10-06: using window indices [18, 24]
Saved outputs/logreg_fire_prob_20251002_1008/fire_prob_20251006.tif
2025-10-07: using window indices [19, 25]
Saved outputs/logreg_fire_prob_20251002_1008/fire_prob_20251007.tif
2025-10-08: using window indices [20, 26]
Saved outputs/logreg_fire_prob_20251002_1008/fire_prob_20251008.tif
