In [None]:
import ee
import pandas as pd
import numpy as np
import time
import os

# ============================================================
# 0. CONFIG
# ============================================================
# Earth Engine Cloud project passed to ee.Initialize
PROJECT_ID = "nth-gasket-479617-q5"   # <-- PUT YOUR EE PROJECT ID HERE

# ---- Files, one group per split ----------------------------
#   *_CSV          : original input table
#   *_OUT_XLSX     : final Excel export
#   *_PROGRESS_CSV : CSV checkpoint used for autosave / resume

# Train
TRAIN_CSV = "original_data/wildfire_train.csv"
TRAIN_OUT_XLSX = "original_data/wildfire_train_with_ndvi.xlsx"
TRAIN_PROGRESS_CSV = "original_data/wildfire_train_with_ndvi_progress.csv"

# Val
VAL_CSV = "original_data/wildfire_val.csv"
VAL_OUT_XLSX = "original_data/wildfire_val_with_ndvi.xlsx"
VAL_PROGRESS_CSV = "original_data/wildfire_val_with_ndvi_progress.csv"

# Test
TEST_CSV = "original_data/wildfire_test.csv"
TEST_OUT_XLSX = "original_data/wildfire_test_with_ndvi.xlsx"
TEST_PROGRESS_CSV = "original_data/wildfire_test_with_ndvi_progress.csv"

# ---- Column names in the tables ----------------------------
LAT_COL = "latitude"
LON_COL = "longitude"
DATE_COL = "datetime"
NDVI_COL = "ndvi_modis"   # column the script fills in

# ---- Batching / checkpointing ------------------------------
# Upper bound on the number of points sent in one Earth Engine batch
CHUNK_SIZE = 2000

# Write the progress CSV once at least this many rows were updated
# since the previous save
SAVE_EVERY = 200

# ---- Imagery -----------------------------------------------
# MODIS collection ID
MODIS_IC = "MODIS/061/MOD13Q1"

# ============================================================
# 1. Earth Engine init
# ============================================================
# Initialize the Earth Engine client library for PROJECT_ID.
# This is a module-level statement, so it executes as soon as the
# cell/module is run, i.e. before any of the helpers below are called.
ee.Initialize(project=PROJECT_ID)


# ============================================================
# 2. Helper: load (with resume) train/val/test DataFrame
# ============================================================
def load_split_with_resume(csv_path, out_xlsx_path, progress_csv_path):
    """
    Load one split, preferring the most recent saved state.

    The first existing source wins, checked in this order:
      1) progress_csv_path (CSV checkpoint)
      2) out_xlsx_path     (Excel export)
      3) csv_path          (original data)

    The returned DataFrame has DATE_COL rewritten as 'YYYY-MM-DD'
    strings and is guaranteed to contain NDVI_COL (created as all-NaN
    when the loaded table does not have it yet).
    """
    # Pick the reader, the path and the log line in one place, then load.
    if os.path.exists(progress_csv_path):
        message = f"Found progress CSV: {progress_csv_path}, resuming from it."
        reader, source = pd.read_csv, progress_csv_path
    elif os.path.exists(out_xlsx_path):
        message = f"Found Excel file: {out_xlsx_path}, resuming from it."
        reader, source = pd.read_excel, out_xlsx_path
    else:
        message = f"Loading original CSV: {csv_path}"
        reader, source = pd.read_csv, csv_path

    print(message)
    frame = reader(source)

    # Dates are parsed and then stored back as plain day strings
    parsed_dates = pd.to_datetime(frame[DATE_COL])
    frame[DATE_COL] = parsed_dates.dt.strftime("%Y-%m-%d")

    # A fresh table has no NDVI column yet; NaN marks "still to fetch"
    if NDVI_COL not in frame.columns:
        frame[NDVI_COL] = np.nan

    return frame



# ============================================================
# 3. Helper: get MODIS NDVI image for a given date window
# ============================================================
def get_modis_image_for_date(date_str):
    """
    Given a 'YYYY-MM-DD' date string, return the earliest MOD13Q1 image
    whose start time falls in the window [date - 8 days, date + 8 days).

    Returns:
      - an ee.Image (a lazy, server-side handle) when the window holds
        at least one image;
      - None when the window is empty.

    If the availability check itself fails (e.g. a transient request
    error), the problem is printed and the lazy image handle is returned
    anyway, so that the caller's own error handling decides what to do.
    """
    date = ee.Date(date_str)
    start = date.advance(-8, "day")
    end = date.advance(8, "day")

    collection = (
        ee.ImageCollection(MODIS_IC)
        .filterDate(start, end)
        .sort("system:time_start")
    )

    # collection.first() is only a client-side proxy: it is never the
    # Python value None, even when the filtered collection is empty.
    # To honour the "None if not found" contract, the collection size
    # has to be fetched explicitly.
    try:
        image_count = collection.size().getInfo()
    except Exception as e:
        # Best effort: fall back to handing out the lazy handle.
        print(f"Could not check MODIS availability for date {date_str}: {e}")
        return collection.first()

    if image_count == 0:
        return None

    return collection.first()


# ============================================================
# 4. Helper: fetch NDVI for a batch of rows (single date)
# ============================================================
def fetch_ndvi_batch_for_indices(df, indices, max_retries=3):
    """
    For a subset of rows (given by indices) that share a single date,
    build an EE FeatureCollection and sample NDVI in one call.

    df: DataFrame containing LAT_COL, LON_COL and DATE_COL
    indices: index labels of df to sample (labels must convert to int,
             because they travel to Earth Engine as the 'row_id' property)
    max_retries: how many times the Earth Engine request is attempted

    Returns: dict {row_index: ndvi_value or None}
      A value is None when no image was found, when the row's
      coordinates are not finite, when the request failed on every
      attempt, or when the response contains no sample for that row.

    Raises: ValueError if the selected rows contain more than one date.
    """
    if len(indices) == 0:
        return {}

    sub = df.loc[indices]

    # All rows in this batch must share the same date
    unique_dates = sub[DATE_COL].unique()
    if len(unique_dates) != 1:
        raise ValueError("Batch must contain a single unique date.")
    date_str = unique_dates[0]

    # Every requested row starts out as "no value"; only rows that come
    # back from Earth Engine are overwritten below.
    ndvi_map = {idx: None for idx in indices}

    image = get_modis_image_for_date(date_str)
    if image is None:
        print(f"No MODIS image found around date {date_str}.")
        return ndvi_map

    # Build FeatureCollection with row_id property.
    # Rows with NaN/inf coordinates are left out: they cannot form a
    # valid point and would otherwise make the request for the whole
    # batch fail.
    features = []
    for idx, lon, lat in zip(sub.index, sub[LON_COL], sub[LAT_COL]):
        lon = float(lon)
        lat = float(lat)
        if not (np.isfinite(lon) and np.isfinite(lat)):
            continue
        geom = ee.Geometry.Point([lon, lat])
        features.append(ee.Feature(geom, {"row_id": int(idx)}))

    if not features:
        return ndvi_map

    fc = ee.FeatureCollection(features)

    # Sample only the NDVI band: it is the only property read from the
    # response, so the other bands would just inflate the payload.
    # NOTE(review): sampleRegions is understood to skip a point when any
    # sampled band is masked there; restricting to NDVI keeps a masked
    # auxiliary band from discarding an otherwise valid NDVI sample.
    sample = image.select("NDVI").sampleRegions(
        collection=fc,
        scale=250,
        properties=["row_id"],
        geometries=False
    )

    # Retrieve results with retries
    for attempt in range(max_retries):
        try:
            result = sample.getInfo()
            break
        except Exception as e:
            print(f"sampleRegions failed (attempt {attempt+1}) "
                  f"for date {date_str}: {e}")
            # Pause only when another attempt will actually follow
            if attempt + 1 < max_retries:
                time.sleep(2.0)
    else:
        # All retries failed
        return ndvi_map

    for feat in result.get("features", []):
        props = feat.get("properties", {})
        row_id = props.get("row_id")
        if row_id is not None:
            ndvi_map[int(row_id)] = props.get("NDVI")

    return ndvi_map


# ============================================================
# 5. Main: process one split (train, val or test)
# ============================================================
def process_split(name, df, out_xlsx_path, progress_csv_path):
    """
    Fill NDVI_COL for every row of one split that still lacks a value,
    then persist the result.

    name: label of the split, e.g. 'train', 'val' or 'test' (for printing)
    df: DataFrame with columns [LAT_COL, LON_COL, DATE_COL, NDVI_COL];
        it is modified in place
    out_xlsx_path: Excel file written once at the very end (best effort)
    progress_csv_path: CSV written on every autosave, on interruption,
        and at the final save

    Returns None. If the processing loop is interrupted by any exception
    (including KeyboardInterrupt), rows updated since the last autosave
    are written to progress_csv_path and the exception is re-raised.
    """
    print("\n============================")
    print(f"Processing split: {name}")
    print("============================")

    total_rows = len(df)
    # Consider only rows where NDVI is still missing
    pending_mask = df[NDVI_COL].isna()
    pending_indices = df[pending_mask].index.to_list()

    print(f"Total rows: {total_rows}")
    print(f"Rows needing NDVI: {len(pending_indices)}")

    if not pending_indices:
        print("Nothing to do, NDVI already filled for this split.")
        return

    # Group remaining rows by date to batch EE calls
    pending_df = df.loc[pending_indices]
    groups_by_date = pending_df.groupby(DATE_COL).groups  # {date: [idx1, idx2, ...]}

    updated_since_save = 0
    processed_rows = 0
    total_pending = len(pending_indices)

    try:
        for date_str, idx_list in groups_by_date.items():
            # Process this date in chunks of CHUNK_SIZE
            idx_list = list(idx_list)
            for i in range(0, len(idx_list), CHUNK_SIZE):
                chunk_indices = idx_list[i : i + CHUNK_SIZE]

                print(
                    f"[{name}] Date {date_str}, "
                    f"rows {processed_rows+1}–{processed_rows+len(chunk_indices)} "
                    f"of {total_pending}"
                )

                ndvi_map = fetch_ndvi_batch_for_indices(df, chunk_indices)

                # Update DataFrame
                for ridx, ndvi_val in ndvi_map.items():
                    df.at[ridx, NDVI_COL] = ndvi_val
                    updated_since_save += 1
                    processed_rows += 1

                # Autosave
                if updated_since_save >= SAVE_EVERY:
                    print("Autosaving progress (CSV) ...")
                    df.to_csv(progress_csv_path, index=False)
                    updated_since_save = 0
    except BaseException:
        # Without this, a crash or Ctrl-C would throw away every value
        # fetched since the last autosave. Save, then let the original
        # exception propagate unchanged.
        if updated_since_save:
            print("Interrupted; saving partial progress (CSV) ...")
            df.to_csv(progress_csv_path, index=False)
        raise

    # Final save
    print(f"Final save for {name}")
    # Always save CSV (fast, reliable)
    df.to_csv(progress_csv_path, index=False)

    # Try Excel as final nice output
    try:
        print(f"Writing Excel to {out_xlsx_path} ...")
        df.to_excel(out_xlsx_path, index=False)
    except Exception as e:
        # Deliberately non-fatal: the CSV above already holds everything
        print(f"Excel save failed: {e}")
        print("But CSV progress is fully saved; you can convert to Excel later if needed.")



# ============================================================
# 6. Run for TRAIN + VAL + TEST
# ============================================================
if __name__ == "__main__":
    # Each split is loaded (resuming from saved state when available)
    # and then completed. The frames stay bound to train_df / val_df /
    # test_df at module level after the run.

    # Train
    train_df = load_split_with_resume(
        csv_path=TRAIN_CSV,
        out_xlsx_path=TRAIN_OUT_XLSX,
        progress_csv_path=TRAIN_PROGRESS_CSV,
    )
    process_split(
        name="train",
        df=train_df,
        out_xlsx_path=TRAIN_OUT_XLSX,
        progress_csv_path=TRAIN_PROGRESS_CSV,
    )

    # Val
    val_df = load_split_with_resume(
        csv_path=VAL_CSV,
        out_xlsx_path=VAL_OUT_XLSX,
        progress_csv_path=VAL_PROGRESS_CSV,
    )
    process_split(
        name="val",
        df=val_df,
        out_xlsx_path=VAL_OUT_XLSX,
        progress_csv_path=VAL_PROGRESS_CSV,
    )

    # Test
    test_df = load_split_with_resume(
        csv_path=TEST_CSV,
        out_xlsx_path=TEST_OUT_XLSX,
        progress_csv_path=TEST_PROGRESS_CSV,
    )
    process_split(
        name="test",
        df=test_df,
        out_xlsx_path=TEST_OUT_XLSX,
        progress_csv_path=TEST_PROGRESS_CSV,
    )

    print("\nAll done – NDVI added to train + val + test.")

