# In [None]:  (notebook cell marker, commented out so the file parses as Python)

"""
AEI downscaling — Stefan-faithful allocator with robust pixel alignment
- Replicate Stefan's 10-step logic exactly (GMIA gating + per-step ceilings).
- Terminates with a mixed allocator (1/2 proportional, 1/2 equal) at the step
  where the remaining target fits.
- Enforces LAND ceilings per pixel *every step* and again at the end (paranoia).
- Uses GMIA grid as reference; optional prebuilt UNITMASK raster.
- Writes rich per-unit debug and an audit with HYDE consistency checks.

Debug columns added:
- step_total  : sum of step_1..step_10
- pixel_total : sum of the allocated vector that actually landed on pixels
- clamp_loss  : step_total - pixel_total (0 unless LAND clamped something)
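- terminating_step : 1..10 = step at which the remaining target fitted; 0 if no
  step terminated (nothing was placed, or all ten steps were filled completely
  and part of the target is still left over)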
"""

import os
import time
import random
import numpy as np
import pandas as pd
import geopandas as gpd
import rasterio
from rasterio.enums import Resampling
from rasterio.warp import reproject
from rasterio.features import rasterize

# ----------------------- USER SETTINGS -----------------------
# ROOT_DIR: base folder; every input and output path below is joined onto it.
# NOTE(review): the value has no drive colon ("D/..." rather than "D:/..."), so it
# is a path relative to the current working directory — confirm this is intended.
ROOT_DIR =  r"D/data/AEI_2020"

### SCENARIOS CONFIG ###
# Scenario name -> which HYDE variant folder to read ("hyde_variant") and which
# LAND mask to cap against ("consistency": "stats" -> LAND_STATS_NAME,
# "hyde" -> LAND_HYDE_NAME).
SCENARIO_CONFIG = {
    "base":      {"hyde_variant": "base",  "consistency": "stats"},
    "low":       {"hyde_variant": "low",   "consistency": "stats"},
    "upper":     {"hyde_variant": "upper", "consistency": "stats"},
    "hyde_base": {"hyde_variant": "base",  "consistency": "hyde"},
    "hyde_low":  {"hyde_variant": "low",   "consistency": "hyde"},
    "hyde_upper":{"hyde_variant": "upper", "consistency": "hyde"},
}

# Choose which of the six to run (each name must be a key of SCENARIO_CONFIG)
scenarios = ["base", "low", "upper", "hyde_base", "hyde_low", "hyde_upper"]


# Years to process; each year needs its own reference raster, shapefile,
# AEI statistics file and HYDE cropland/pasture rasters.
years = [2000, 2005, 2010, 2015, 2020]

# "STRICT" (per-unit only) or "GLOBAL_SPILL" (optional second pass that tries to
# place per-unit leftovers into spare LAND capacity anywhere on the grid)
BALANCE_MODE = "STRICT"

# If you have a pre-built, GMIA-aligned integer unitmask raster set this True.
# Otherwise polygons are rasterized (center-of-pixel, valid geometries).
USE_UNITMASK_RASTER = False

# Local folder chains (same relative structure as Drive version)
F_GMIA    = ["input", "Meier"]              # Meier_<year>.tif (reference grid, read by get_gmia)
F_STEFAN  = ["input", "Stefan"]             # LAND_GMIA.tif, LAND_HYDE.tif (+ optional UNITMASK[_<year>].tif)
F_HYDE    = ["input", "hyde"]               # <variant>/<var><year>AD.tif, variant = base/low/upper
F_AEI_CSV = ["input","AEI","Global","New AEI GLobal","Final AEI","FInal"]  # AEI_<year>_Global.xlsx (Excel, despite the _CSV name)
F_SHP     = ["input","AEI","Shapefile","new Shapefile","Piyushs"]          # AEI_<year>_2.shp (+ sidecar files)

OUT_DIR_NAME = "Output/Global/Meier"  # "/"-separated, relative to ROOT_DIR; filenames carry scenario info

HYDE_TO_AEI = 100.0      # km² -> ha (multiplier applied to the HYDE cropland/pasture rasters)
DEFAULT_SRC_CRS = "EPSG:4326"  # fallback CRS for source rasters that carry none

# Names of the two land masks, following Stefan logic:
LAND_STATS_NAME = "LAND_GMIA.tif"  # "consistent to irrigation statistics"
LAND_HYDE_NAME  = "LAND_HYDE.tif"  # "consistent to HYDE"

# ----------------------- LOCAL PATH HELPERS -----------------------

def _join_root(*parts):
    """Return the path formed by appending ``parts`` to ROOT_DIR."""
    segments = (ROOT_DIR,) + parts
    return os.path.join(*segments)

def _ensure_dir(path):
    """Create directory ``path`` (with parents) if it is missing; return ``path``."""
    os.makedirs(name=path, exist_ok=True)
    return path

def get_out_folder():
    """Resolve OUT_DIR_NAME below ROOT_DIR, create it when absent, and return it."""
    folder_parts = OUT_DIR_NAME.split("/")
    # _ensure_dir hands back the very path it was given.
    return _ensure_dir(_join_root(*folder_parts))

def get_gmia(year):
    """
    Return the path of the reference raster ``Meier_<year>.tif`` in the F_GMIA folder.

    Raises FileNotFoundError when the file does not exist.
    """
    gmia_file = _join_root(*F_GMIA, f"Meier_{year}.tif")
    if os.path.exists(gmia_file):
        return gmia_file
    raise FileNotFoundError(f"Missing GMIA raster: {gmia_file}")

def get_land(consistency):
    """
    Return the path of the LAND ceiling raster for the given consistency mode.

    consistency: "stats" → GMIA land (LAND_GMIA.tif)
                 "hyde"  → HYDE land (LAND_HYDE.tif)

    Raises:
      ValueError         if `consistency` is neither "stats" nor "hyde".
      FileNotFoundError  if the selected raster does not exist.
    """
    # Reject unknown modes explicitly: previously any value other than "stats"
    # (including typos) silently selected the HYDE land mask.
    if consistency not in ("stats", "hyde"):
        raise ValueError(
            f"Unknown consistency '{consistency}'. Must be 'stats' or 'hyde'."
        )
    name = LAND_STATS_NAME if consistency == "stats" else LAND_HYDE_NAME
    path = _join_root(*F_STEFAN, name)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing LAND raster: {path}")
    return path

def get_unitmask(year):
    """
    Return the unit-mask raster for `year` from the F_STEFAN folder.

    The year-specific ``UNITMASK_<year>.tif`` wins over the generic
    ``UNITMASK.tif``; FileNotFoundError is raised when neither exists.
    """
    yearly = _join_root(*F_STEFAN, f"UNITMASK_{year}.tif")
    generic = _join_root(*F_STEFAN, "UNITMASK.tif")
    for candidate in (yearly, generic):
        if os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(f"No UNITMASK found for {year}: {yearly} or {generic}")

def get_hyde(hyde_variant, year, var):
    """
    Return the path ``<F_HYDE>/<hyde_variant>/<var><year>AD.tif`` under ROOT_DIR.

    hyde_variant: expected to be "base", "low" or "upper"
    var: expected to be "cropland" or "pasture"
    Raises FileNotFoundError when the file does not exist.
    """
    filename = f"{var}{year}AD.tif"
    hyde_file = _join_root(*F_HYDE, hyde_variant, filename)
    if os.path.exists(hyde_file):
        return hyde_file
    raise FileNotFoundError(f"Missing HYDE raster: {hyde_file}")

def get_shp(year):
    """
    Return the path of ``AEI_<year>_2.shp`` in ROOT_DIR / F_SHP.

    Only the .shp file is checked; the reader is expected to pick up the
    sidecar files (.dbf/.shx/.prj) that share the same base name.
    Raises FileNotFoundError when the .shp file is missing.
    """
    shp_file = _join_root(*F_SHP, f"AEI_{year}_2" + ".shp")
    if os.path.exists(shp_file):
        return shp_file
    raise FileNotFoundError(f"Shapefile not found: {shp_file}")

def get_aei_stats(year):
    """
    Return the path of the AEI statistics workbook ``AEI_<year>_Global.xlsx``
    in ROOT_DIR / F_AEI_CSV; raise FileNotFoundError when it is missing.
    """
    stats_file = _join_root(*F_AEI_CSV, f"AEI_{year}_Global.xlsx")
    if os.path.exists(stats_file):
        return stats_file
    raise FileNotFoundError(f"AEI stats file not found: {stats_file}")

# ----------------------- ALIGNMENT HELPERS -----------------------
def _same_grid(meta_a, meta_b):
    """Tell whether two raster metadata dicts share crs, transform, height and width."""
    grid_keys = ("crs", "transform", "height", "width")
    # Generator keeps the original short-circuit order: later keys are not
    # looked up once an earlier one differs.
    return all(meta_a[key] == meta_b[key] for key in grid_keys)

def read_as_aligned_array(src_path, ref_transform, ref_crs, ref_shape, resampling=Resampling.nearest):
    """
    Read band 1 of `src_path` as float32 on the reference grid.

    Non-finite and negative cells are replaced by 0 both before and after
    resampling. When the source already has the reference shape, CRS and
    transform the cleaned band is returned as-is; otherwise it is reprojected
    onto (`ref_transform`, `ref_crs`, `ref_shape`) with `resampling`.
    A source without CRS is treated as `ref_crs`, or DEFAULT_SRC_CRS if that
    is empty too.
    """
    def _sanitize(values):
        # zero out NaN/inf first, then clip negatives
        values[~np.isfinite(values)] = 0.0
        return np.maximum(values, 0.0)

    with rasterio.open(src_path) as src:
        band = _sanitize(src.read(1).astype(np.float32))
        band_crs = src.crs
        band_transform = src.transform

    if band_crs is None:
        band_crs = ref_crs or DEFAULT_SRC_CRS

    on_ref_grid = (
        band.shape == ref_shape
        and band_crs == ref_crs
        and band_transform == ref_transform
    )
    if on_ref_grid:
        return band

    warped = np.zeros(ref_shape, np.float32)
    reproject(
        band,
        warped,
        src_transform=band_transform,
        dst_transform=ref_transform,
        src_crs=band_crs,
        dst_crs=ref_crs,
        resampling=resampling,
    )
    return _sanitize(warped)

def build_unit_id_raster_from_polygons(gdf, id_field, transform, shape, crs):
    """
    Burn the integer ids in column `id_field` of `gdf` into an int32 raster.

    Geometries are reprojected to `crs` and repaired first (shapely's
    make_valid when that works, otherwise a zero-width buffer). Pixels not
    covered by any polygon get 0. With all_touched=False only pixels selected
    by rasterio's default centre-based rule are burned.
    """
    units = gdf.to_crs(crs).copy()
    try:
        from shapely import make_valid
        units["geometry"] = units["geometry"].map(make_valid)
    except Exception:
        # best-effort repair when make_valid is unavailable or fails
        units["geometry"] = units["geometry"].buffer(0)

    burn_pairs = list(zip(units.geometry, (int(value) for value in units[id_field])))
    # Center-of-pixel semantics (Stefan-faithful)
    return rasterize(
        shapes=burn_pairs,
        out_shape=shape,
        transform=transform,
        fill=0,
        all_touched=False,
        dtype="int32",
    )

# ----------------------- ALLOCATION CORE -----------------------
def _distribute_mixed(total, alloc, ceil, eps=1e-6):
    """
    Spread `total` over pixels with the mixed rule of the terminating step.

    Every round splits what is still pending into two halves: one shared in
    proportion to the current per-pixel amount (alloc + what was added so
    far), one shared equally, both restricted to pixels that still have more
    than `eps` of headroom below `ceil`. Grants are capped at the headroom and
    the rounds repeat until the pending amount is <= eps or no pixel has room.

    Returns (increment as float32 array, amount that could not be placed).
    """
    base = np.asarray(alloc, dtype=np.float64)
    cap = np.asarray(ceil, dtype=np.float64)

    added = np.zeros_like(base)
    pending = float(total)

    while pending > eps:
        current = base + added
        headroom = np.maximum(cap - current, 0.0)
        open_px = headroom > eps
        n_open = int(open_px.sum())
        if n_open == 0:
            break

        half = pending * 0.5

        # Proportional half: weights are the current amounts of the open pixels.
        weights = np.where(open_px, current, 0.0)
        weight_sum = weights.sum()
        share = np.zeros_like(headroom)
        if weight_sum > 0:
            share[open_px] = half * (weights[open_px] / weight_sum)

        # Equal half on top of the proportional one.
        share[open_px] += half / n_open

        grant = np.minimum(headroom, share)
        placed = float(np.sum(grant, dtype=np.float64))

        if placed <= eps:
            # Nothing meaningful moved: retry with a purely equal split.
            even = np.zeros_like(headroom)
            even[open_px] = pending / n_open
            grant = np.minimum(headroom, even)
            placed = float(np.sum(grant, dtype=np.float64))
            if placed <= eps:
                break

        added += grant
        pending -= placed

    return added.astype(np.float32), float(max(pending, 0.0))

def stefan_allocate_unit(target, irri_t1, crop, pasture, land_cap):
    """
    Allocate `target` over the pixels of one unit using a 10-step ceiling ladder.

    Steps 1-7 only admit pixels with irri_t1 > 0, step 8 only pixels with
    irri_t1 == 0, steps 9-10 admit every pixel. Each step's ceiling is
    additionally capped by `land_cap`. A step whose free room is smaller than
    the remaining target is filled completely; the first step whose room can
    hold the remaining target distributes it with `_distribute_mixed` and
    ends the loop.

    Args:
      target:   amount to place for the unit.
      irri_t1:  per-pixel irrigated area of the reference (GMIA) raster.
      crop:     per-pixel cropland.
      pasture:  per-pixel pasture.
      land_cap: per-pixel upper bound that must never be exceeded.
      (all arrays are 1-D-like and of equal length; converted to float64)

    Returns:
      alloc (float32[n]), steps_used (10), steps_avail (10), leftover,
      term_step (int 0..10), step_total (float), pixel_total (float)

      term_step is 0 when no step terminated the loop, i.e. nothing was
      requested/placed or all ten steps were filled and target is left over.
    """
    irri_t1 = np.asarray(irri_t1, dtype=np.float64)
    crop = np.asarray(crop, dtype=np.float64)
    pasture = np.asarray(pasture, dtype=np.float64)
    land_cap = np.asarray(land_cap, dtype=np.float64)

    n = irri_t1.size
    # `not (target > 0)` instead of `target <= 0`: a NaN target fails every
    # comparison, so it used to slip past the guard, never satisfy
    # `left <= avail`, and fill the whole unit up to its LAND capacity.
    if n == 0 or not (target > 0):
        return np.zeros(n, np.float32), [0.0] * 10, [0.0] * 10, float(target), 0, 0.0, 0.0

    agri = crop + pasture
    gmia_mask = irri_t1 > 0

    # Gated ABSOLUTE ceilings, one per step (LAND cap is applied in the loop).
    ceilings = [
        np.where(gmia_mask, np.minimum(irri_t1, crop), 0.0),      # S1  (GMIA>0)
        np.where(gmia_mask, np.minimum(irri_t1, agri), 0.0),      # S2  (GMIA>0)
        np.where(gmia_mask, np.minimum(irri_t1, land_cap), 0.0),  # S3  (GMIA>0)
        np.where(gmia_mask, irri_t1, 0.0),                        # S4  (GMIA>0)
        np.where(gmia_mask, np.maximum(irri_t1, crop), 0.0),      # S5  (GMIA>0)
        np.where(gmia_mask, np.maximum(irri_t1, agri), 0.0),      # S6  (GMIA>0)
        np.where(gmia_mask, land_cap, 0.0),                       # S7  (GMIA>0)
        np.where(irri_t1 == 0, crop, 0.0),                        # S8  (GMIA==0 only)
        np.maximum(irri_t1, agri),                                # S9  (all pixels)
        np.maximum(irri_t1, land_cap),                            # S10 (all pixels)
    ]

    alloc = np.zeros(n, np.float64)
    steps_used, steps_avail = [], []
    term_step = 0
    left = float(target)

    for step_no, step_ceiling in enumerate(ceilings, start=1):
        # enforce LAND at this step too
        ceil_i = np.minimum(step_ceiling, land_cap)

        room = np.maximum(ceil_i - alloc, 0.0)
        avail = float(np.sum(room, dtype=np.float64))
        steps_avail.append(avail)

        if avail <= 1e-9:
            steps_used.append(0.0)
            continue

        if left <= avail + 1e-9:
            # Terminating step: the remainder fits, spread it with the mixed rule.
            inc, rem = _distribute_mixed(left, alloc, ceil_i)
            alloc += inc
            steps_used.append(float(np.sum(inc, dtype=np.float64)))
            left = float(rem)
            term_step = step_no
            # inc comes back as float32 and may round above the cap: clamp again
            alloc = np.minimum(alloc, land_cap)
            break

        # Not enough room at this step: fill it completely and move on.
        alloc += room
        steps_used.append(avail)
        left -= avail
        # clamp to LAND after each step
        alloc = np.minimum(alloc, land_cap)

    # Pad both per-step lists to exactly 10 entries (loop may have ended early).
    steps_used += [0.0] * (10 - len(steps_used))
    steps_avail += [0.0] * (10 - len(steps_avail))

    step_total = float(np.sum(np.array(steps_used, dtype=np.float64)))
    pixel_total = float(np.sum(alloc, dtype=np.float64))

    return (alloc.astype(np.float32), steps_used, steps_avail,
            float(max(left, 0.0)), term_step, step_total, pixel_total)

# ----------------------- AUDIT HELPERS -----------------------
def _audit_and_write(out_dir, yr, scen, mosaic, unit_id, df, land_full, crop_full, past_full, meta, irri_t1_full):
    """
    Write ``audit_<yr>_<scen>.csv`` into `out_dir`.

    The file holds one row per unit of `df` (target vs. allocated sum inside
    the unit's pixels, pixel count, LAND / cropland / cropland+pasture sums,
    number of pixels with irri_t1 > 0), a final ``__GLOBAL__`` row with the
    grid-wide figures, then a blank line and a one-row table with the total
    amount by which `mosaic` exceeds LAND, cropland and cropland+pasture.
    A unit is flagged ok when |allocated - target| <= 1e-3.

    `meta` is accepted for call compatibility but not used here.
    """
    tol = 1e-3

    def _total(values):
        # float64 accumulation regardless of the array dtype
        return float(np.sum(values, dtype=np.float64))

    irrigated = irri_t1_full > 0
    target_all = float(np.sum(df["AEI"].to_numpy(dtype=np.float64)))
    allocated_all = 0.0
    rows = []

    for _, unit in df.iterrows():
        uid = int(unit["unit_id"])
        target = float(unit["AEI"])
        inside = (unit_id == uid)
        allocated = _total(mosaic[inside])

        rows.append({
            "unit_code": unit["unit_code"],
            "unit_id": uid,
            "target_AEI": target,
            "allocated_AEI": allocated,
            "diff": allocated - target,
            "ok_unit": abs(allocated - target) <= tol,
            "n_pix": int(inside.sum()),
            "land_sum": _total(land_full[inside]),
            "crop_sum": _total(crop_full[inside]),
            "agri_sum": _total(crop_full[inside] + past_full[inside]),
            "gmia_pos": int((inside & irrigated).sum()),
        })
        allocated_all += allocated

    rows.append({
        "unit_code": "__GLOBAL__",
        "unit_id": -1,
        "target_AEI": target_all,
        "allocated_AEI": allocated_all,
        "diff": allocated_all - target_all,
        "ok_unit": abs(allocated_all - target_all) <= tol,
        "n_pix": int((unit_id > 0).sum()),
        "land_sum": _total(land_full),
        "crop_sum": _total(crop_full),
        "agri_sum": _total(crop_full + past_full),
        "gmia_pos": int(irrigated.sum()),
    })

    # Grid-wide excess of the allocation over each ceiling (0 when respected).
    excess = pd.DataFrame([{
        "land_excess_ha": _total(np.maximum(mosaic - land_full, 0.0)),
        "crop_excess_ha": _total(np.maximum(mosaic - crop_full, 0.0)),
        "agri_excess_ha": _total(np.maximum(mosaic - (crop_full + past_full), 0.0)),
    }])

    audit_file = os.path.join(out_dir, f"audit_{yr}_{scen}.csv")
    # Write audit plus a one-row summary of global violations
    with open(audit_file, "w", newline="") as f:
        pd.DataFrame(rows).to_csv(f, index=False)
        f.write("\n")
        excess.to_csv(f, index=False)

# ----------------------- MAIN -----------------------
def _load_unit_table(yr, ref_crs):
    """
    Build the per-unit table for year `yr`: unit polygons joined with AEI stats.

    Reads the shapefile and the statistics workbook, reprojects the polygons
    to `ref_crs`, joins both on the stripped string `unit_code`, keeps units
    with AEI > 0 and assigns `unit_id` = 1..n in row order.

    Raises ValueError when no unit with AEI > 0 survives the join.
    """
    gdf = gpd.read_file(get_shp(yr)).to_crs(ref_crs)
    gdf["unit_code"] = gdf["unit_code"].astype(str).str.strip()
    stats = pd.read_excel(get_aei_stats(yr))
    stats["unit_code"] = stats["unit_code"].astype(str).str.strip()

    df = gdf.merge(stats, on="unit_code", how="inner")
    df = df[df["AEI"].astype(float) > 0].reset_index(drop=True)
    if df.empty:
        # Fail with a clear message instead of an obscure error further down.
        raise ValueError(
            f"No units with AEI > 0 for {yr}: check that 'unit_code' values "
            "match between the shapefile and the AEI statistics."
        )

    # NOTE(review): if a unit_code occurs on several rows, all of them receive
    # the id of the last occurrence and the allocation loop visits the unit
    # once per row — confirm unit_code is unique in both inputs.
    code_to_id = {uc: i + 1 for i, uc in enumerate(df["unit_code"].tolist())}
    df["unit_id"] = df["unit_code"].map(code_to_id).astype(int)
    return df


def run():
    """
    Run the downscaling for every configured scenario and year.

    For each (scenario, year): load the reference grid, the LAND ceiling, the
    unit ids and the HYDE cropland/pasture rasters; allocate every unit's AEI
    target with `stefan_allocate_unit`; optionally spill leftovers globally
    (BALANCE_MODE == "GLOBAL_SPILL"); clamp to LAND; then write the AEI raster,
    the per-unit debug CSV and the audit CSV into the output folder.

    Raises ValueError for a scenario name missing from SCENARIO_CONFIG
    (checked before any work starts) and propagates FileNotFoundError from
    the path helpers when an input is missing.
    """
    print("ROOT_DIR:", ROOT_DIR)
    out_dir = get_out_folder()
    print("Output dir:", out_dir)
    print("Balance mode:", BALANCE_MODE, "| Scenarios:", ", ".join(scenarios))

    # Validate all scenario names up front so a typo does not surface only
    # after the earlier scenarios have already been processed.
    for scen in scenarios:
        if scen not in SCENARIO_CONFIG:
            raise ValueError(f"Unknown scenario '{scen}'. Must be one of {list(SCENARIO_CONFIG.keys())}")

    for scen in scenarios:
        cfg = SCENARIO_CONFIG[scen]
        hyde_variant = cfg["hyde_variant"]   # base/low/upper
        consistency  = cfg["consistency"]    # "stats" or "hyde"
        # Columns appended to every debug row and to the debug summary.
        run_tags = {
            "scenario": scen,
            "hyde_variant": hyde_variant,
            "consistency": consistency,
        }

        for yr in years:
            print(f"\n▶ Year {yr} | Scenario {scen} "
                  f"(HYDE='{hyde_variant}', consistency='{consistency}')")

            # Reference grid (GMIA)
            gmia_path = get_gmia(yr)
            with rasterio.open(gmia_path) as rref:
                ref_meta = rref.meta.copy()
                ref_transform, ref_crs = rref.transform, rref.crs
                ref_shape = (rref.height, rref.width)
                irri_t1_full = rref.read(1).astype(np.float32)
            irri_t1_full[~np.isfinite(irri_t1_full)] = 0.0
            irri_t1_full = np.maximum(irri_t1_full, 0.0)

            # LAND depends on consistency (GMIA vs HYDE)
            land_full = read_as_aligned_array(
                get_land(consistency),
                ref_transform, ref_crs, ref_shape, Resampling.nearest
            )

            # ----- Unit table and unit IDs -----
            # The table is needed in every branch, so it is loaded once and
            # outside the try: a broken shapefile/stats file is then reported
            # as such instead of as a unitmask problem.
            df = _load_unit_table(yr, ref_crs)

            unit_id = None
            if USE_UNITMASK_RASTER:
                try:
                    um_path = get_unitmask(yr)
                    with rasterio.open(um_path) as um:
                        um_ids = um.read(1).astype(np.int32)
                        um_meta = um.meta.copy()
                    if not _same_grid(um_meta, ref_meta):
                        raise RuntimeError("Unitmask raster is not exactly aligned to GMIA.")
                except Exception as e:
                    # Deliberate best-effort: any problem with the prebuilt
                    # mask falls back to rasterizing the polygons.
                    print("  ! Unitmask raster not found or misaligned:", e)
                    print("  • Falling back to polygon rasterization (center-of-pixel)")
                else:
                    # NOTE(review): df["unit_id"] is assigned by row order in
                    # _load_unit_table; the prebuilt raster must use the same
                    # ids for the per-unit masks to match — confirm.
                    unit_id = um_ids
                    print("  • Using unitmask raster (aligned)")

            if unit_id is None:
                unit_id = build_unit_id_raster_from_polygons(
                    df[["unit_id", "geometry"]], "unit_id",
                    ref_transform, ref_shape, ref_crs
                )

            # ----- HYDE inputs aligned to GMIA grid -----
            crop_full = read_as_aligned_array(
                get_hyde(hyde_variant, yr, "cropland"),
                ref_transform, ref_crs, ref_shape, Resampling.nearest
            ) * HYDE_TO_AEI
            past_full = read_as_aligned_array(
                get_hyde(hyde_variant, yr, "pasture"),
                ref_transform, ref_crs, ref_shape, Resampling.nearest
            ) * HYDE_TO_AEI

            # Pass 1 — per-unit allocation
            mosaic = np.zeros(ref_shape, np.float32)
            # LAND capacity still free per pixel; shrinks as units are placed.
            remaining = land_full.copy().astype(np.float64)
            debug_rows, leftovers = [], []

            for _, row in df.iterrows():
                uid = int(row["unit_id"])
                tgt = float(row["AEI"])
                m = (unit_id == uid)
                if not np.any(m):
                    # Unit has no pixel on the grid: the whole target is leftover.
                    debug_rows.append({
                        "unit_code": row["unit_code"], "unit_id": uid,
                        "target_AEI": tgt, "allocated_AEI": 0.0, "diff": -tgt,
                        "leftover_after_unit": tgt, "mode": "no_pixels",
                        "n_pix": 0, "land_sum": 0.0, "crop_sum": 0.0,
                        "agri_sum": 0.0, "gmia_pos": 0,
                        "terminating_step": 0,
                        "step_total": 0.0,
                        "pixel_total": 0.0,
                        "clamp_loss": 0.0,
                        **run_tags,
                    })
                    leftovers.append((uid, tgt))
                    continue

                idx = np.where(m)
                irri = irri_t1_full[idx]
                crop = crop_full[idx]
                past = past_full[idx]
                cap  = np.maximum(remaining[idx], 0.0)

                alloc_vec, used_vec, av_vec, left, term_step, step_total, pixel_total = \
                    stefan_allocate_unit(tgt, irri, crop, past, cap)

                mosaic[idx] += alloc_vec
                # reduce remaining LAND only within this unit's pixels
                remaining[idx] = np.maximum(remaining[idx] - alloc_vec.astype(np.float64), 0.0)

                alloc_sum = float(np.sum(alloc_vec, dtype=np.float64))
                land_sum  = float(np.sum(land_full[idx], dtype=np.float64))
                crop_sum  = float(np.sum(crop_full[idx], dtype=np.float64))
                agri_sum  = float(np.sum(crop_full[idx] + past_full[idx], dtype=np.float64))

                debug_rows.append({
                    "unit_code": row["unit_code"], "unit_id": uid,
                    "target_AEI": tgt, "allocated_AEI": alloc_sum,
                    "diff": alloc_sum - tgt, "leftover_after_unit": float(left),
                    **{f"step_{i+1}": float(used_vec[i]) for i in range(10)},
                    **{f"avail_{i+1}": float(av_vec[i]) for i in range(10)},
                    "n_pix": int(m.sum()),
                    "land_sum": land_sum,
                    "crop_sum": crop_sum,
                    "agri_sum": agri_sum,
                    "gmia_pos": int((irri > 0).sum()),
                    "terminating_step": int(term_step),
                    "step_total": float(step_total),
                    "pixel_total": float(pixel_total),
                    "clamp_loss": float(step_total - pixel_total),
                    **run_tags,
                })

                if left > 1e-6:
                    leftovers.append((uid, float(left)))

            # Pass 2 — optional global spill
            total_left = float(sum(l for _, l in leftovers))
            if BALANCE_MODE.upper() == "GLOBAL_SPILL" and total_left > 1e-6:
                print(f"  • Global spill: trying to place {total_left:,.1f} ha into spare capacity")

                def spill_into(mask, left):
                    """Place `left` into free LAND capacity inside `mask`; return what remains."""
                    if left <= 1e-6:
                        return 0.0
                    sel = mask & (remaining > 1e-9)
                    if not np.any(sel):
                        return left
                    alloc_sub = mosaic[sel].astype(np.float64)
                    # ceiling = what is already there + what LAND still allows
                    ceil_sub  = (mosaic[sel].astype(np.float64) + remaining[sel])
                    inc, rem = _distribute_mixed(left, alloc_sub, ceil_sub)
                    if inc.sum() > 0:
                        mosaic[sel] += inc
                        remaining[sel] = np.maximum(remaining[sel] - inc.astype(np.float64), 0.0)
                    return float(rem)

                # Pixels with GMIA irrigation first, the rest afterwards.
                gmia_pos  = (irri_t1_full > 0)
                gmia_zero = (irri_t1_full == 0)
                total_left = spill_into(gmia_pos, total_left)
                total_left = spill_into(gmia_zero, total_left)
                print(f"  • Spill leftover after pass: {total_left:,.1f} ha")

            # Final (paranoia) clamp to LAND
            mosaic = np.minimum(mosaic.astype(np.float64), land_full.astype(np.float64)).astype(np.float32)

            # ----- Write outputs -----
            # Raster: exactly one band is written, so force count=1 instead of
            # inheriting the band count of the reference raster.
            out_ras = os.path.join(out_dir, f"AEI_{yr}_{scen}.tif")
            meta = {**ref_meta, "dtype": rasterio.float32, "nodata": 0, "count": 1}
            with rasterio.open(out_ras, "w", **meta) as dst:
                dst.write(mosaic.astype(np.float32), 1)

            # Debug CSV: per-unit rows, blank line, one-row totals table
            dbg = pd.DataFrame(debug_rows)
            sum_target = float(np.sum(dbg["target_AEI"].to_numpy(dtype=np.float64)))
            sum_allocated = float(np.sum(dbg["allocated_AEI"].to_numpy(dtype=np.float64)))
            dbg_totals = {
                "sum_target": sum_target,
                "sum_allocated": sum_allocated,
                "sum_diff": float(sum_allocated - sum_target),
                "sum_step_total": float(np.sum(dbg["step_total"].to_numpy(dtype=np.float64))),
                "sum_pixel_total": float(np.sum(dbg["pixel_total"].to_numpy(dtype=np.float64))),
                "sum_clamp_loss": float(np.sum(dbg["clamp_loss"].to_numpy(dtype=np.float64))),
                "spill_leftover_after_pass2": 0.0 if BALANCE_MODE.upper()=="STRICT" else float(total_left),
                **run_tags,
            }
            summ = pd.DataFrame([dbg_totals])
            dbg_path = os.path.join(out_dir, f"debug_{yr}_{scen}.csv")
            with open(dbg_path, "w", newline="") as f:
                dbg.to_csv(f, index=False)
                f.write("\n")
                summ.to_csv(f, index=False)

            # Audit (HYDE & stats)
            _audit_and_write(out_dir, yr, scen, mosaic, unit_id, df,
                             land_full, crop_full, past_full, meta, irri_t1_full)

            print(f"  ✓ Wrote AEI_{yr}_{scen}.tif, debug_{yr}_{scen}.csv, audit_{yr}_{scen}.csv")

# Script entry point: process all configured scenarios and years when the file
# is executed directly (nothing runs on import).
if __name__ == "__main__":
    run()
