In [2]:
from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive

# Use the exact path to your JSON. If it's next to the notebook, just the filename is fine.
CLIENT_JSON = "client_secret_Final.json"   # or "/home/endiabe/client_secrets.json"

gauth = GoogleAuth()
gauth.LoadClientConfigFile(CLIENT_JSON)

# Command-line auth -> opens a URL; you paste back the code
gauth.CommandLineAuth()

# Save token so you won't need to sign in every time
gauth.SaveCredentialsFile('token.json')

drive = GoogleDrive(gauth)
print("✅ Google Drive connected")

✅ Google Drive connected


In [3]:
import os
# match your Slurm allocation ("64")
os.environ["OMP_NUM_THREADS"] = "64"        # OpenMP (NumPy/SciPy, some raster ops)
os.environ["MKL_NUM_THREADS"] = "64"        # MKL-backed NumPy / scikit-learn
os.environ["NUMEXPR_NUM_THREADS"] = "64"    # if you use numexpr
os.environ["GDAL_NUM_THREADS"] = "64"       # for internal GDAL threaded ops
os.environ["RASTERIO_NUM_THREADS"] = "64"   # rasterio’s thread pool


In [4]:
# List only folders in Google Drive root
folder_list = drive.ListFile({
    'q': "'root' in parents and mimeType='application/vnd.google-apps.folder' and trashed=false"
}).GetList()

for folder in folder_list:
    print(f"Folder: {folder['title']}  |  ID: {folder['id']}")


Folder: Nigeria_VIs_Env_GrowingSeason  |  ID: 1VYsHaK4CXu3wo_UZIN6Kv44qan8d1smR
Folder: Nigeria_VIs_Env_GrowingSeason  |  ID: 1nUAHXZ2zKpT22jIsCVHJahGDa7-7bCoW
Folder: .ipynb_checkpoints  |  ID: 1veaaOV_OobM00bScEiQiY8pDoJSxFIC0
Folder: Continental Irr Maps  |  ID: 1atiKrbLV7Ifez9AKTuSqCFMsTxkOqOvc
Folder: Moldova_VIs_Env_GrowingSeason  |  ID: 183w10O5hg6pS6ou2gT7LE1N4SI1pSn5B
Folder: Cuba_VIs_Env_IrrigationSeason  |  ID: 1OrL9zVhyRh57oLfvg84ySJhB0GZSCM6m
Folder: Caribbean_VIs_Env_IrrigationSeason  |  ID: 1vsKprJVxyyjHR6zPcfYV1Rc1wkjIw2fn
Folder: GMIE  |  ID: 1lVkH_OShkP5n_9CI4bdJgwFuHwmbx57H
Folder: CLE_National  |  ID: 1pOpf-Zy5la4SKTIcrmLX6tGlxVBejKYX
Folder: AfricaEnv_WA_noPrecip_NoSoil_NoCropMask  |  ID: 1KGWrOMmhnCVvUFr7uh9y2wyjPdLkRFDd
Folder: Europe_VIs_Env2  |  ID: 1Rz21QLD1UnspShqUD7jB6Bfsgjk03Z9e
Folder: China  |  ID: 19Ee1kUar4oEabBObq7GcA-jDR7F29x2_
Folder: Canada  |  ID: 1qXhq3-3-J2YhKjXmMcCOT0exvb12LxhD
Folder: Global_Irr_Maps  |  ID: 1XKYX2EPkPmu1el4Q0FpfZMEkgsWvRuYD
Fo

In [None]:
# -*- coding: utf-8 -*-
"""
Regional RF irrigation model with:
- Farm-aware split via DBSCAN (no leakage)
- AEZ categorical (one-hot)
- OUT-OF-FOLD probability calibration (isotonic by default, Platt optional)
- Regional OOF metrics + artifacts
- Per-country OOF metrics & thresholds saved
- Country predictions on ~30m WGS84 using calibrated probabilities

Now updated to:
  • Read AEZ per-point from Model Training/US_GTPS_per_point.(parquet|csv) when available
  • Prefer per-state AEZ tiles: By Country/<State>/<State>_AEZ.tif (fallback: global AEZ)

Outputs:
  CountryModelPredicted/<Country>_RF_probability_percent.tif
  CountryModelPredicted/<Country>_RF_binary_0_1_cropland.tif
  CountryModelPredicted/<Country>_RF_predictors_count.tif

  Model Training/Regional Models/<Region>/
    - metrics.json
    - confusion.png
    - feature_importance.csv
    - per_country_test_metrics.csv
    - model.joblib

  Model Training/Country Models/<Country>/
    - threshold.json
    - test_metrics.json
    - test_confusion.png
"""
import os, re, json, math, tempfile, joblib, time, random
from collections import deque

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score, average_precision_score, roc_auc_score,
    precision_recall_fscore_support, confusion_matrix, roc_curve,
    precision_recall_curve
)
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.cluster import DBSCAN
from sklearn.linear_model import LogisticRegression
from sklearn.isotonic import IsotonicRegression
import pyproj
from sklearn.base import clone
import rasterio
from rasterio.enums import Resampling as RioResampling
from rasterio.warp import reproject, transform as rio_transform, transform_bounds as rio_transform_bounds
from rasterio.transform import Affine
from rasterio.windows import Window
from rasterio.crs import CRS

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

try:
    from googleapiclient.errors import HttpError
except Exception:
    HttpError = Exception
try:
    from pydrive2.files import ApiRequestError
except Exception:
    ApiRequestError = Exception

os.environ.setdefault("SHAPE_RESTORE_SHX", "YES")

# ------------------------ CONFIG ------------------------
# ------------------------ CONFIG ------------------------
ROOT_FOLDER_ID = os.environ.get("ROOT_FOLDER_ID", "1ReAmJ809FmU-8jLpyyO6VEdmeqDmbaB7")
BY_COUNTRY_NAME = "By Country"
CLE_FOLDER_ID = "1pOpf-Zy5la4SKTIcrmLX6tGlxVBejKYX"
MODEL_FOLDER = "Model Training"
REGIONAL_MODELS_FOLDER = "Regional Models"  # used per country now: region_name == country
COUNTRY_MODELS_FOLDER = "Country Models"
OUTPUT_FOLDER = "CountryModelPredicted_Cropland"

# River Network (continental fallback for distance-to-river)
RIVER_NET_FOLDER_ID = "1ReAmJ809FmU-8jLpyyO6VEdmeqDmbaB7"
RIVER_DIST_PREFIX = "Dist_"
RIVER_DIST_SUFFIX = "_river.tif"

# Country column auto-detection; we copy into this name
COUNTRY_COL = os.environ.get("COUNTRY_COL", "country_hint")

# DBSCAN farm grouping (meters)
DBSCAN_EPS_M = float(os.environ.get("DBSCAN_EPS_M", 45))
DBSCAN_MIN_SAMPLES = int(os.environ.get("DBSCAN_MIN_SAMPLES", 1))

# AEZ (global raster)
AEZ_FILE_ID   = os.environ.get("AEZ_FILE_ID", "1te3nKn8vyt2AECmk8NM_xBgpEFXp84RX")
AEZ_FILE_NAME = os.environ.get("AEZ_FILE_NAME", "AEZ_2020s.tif")
AEZ_COL       = "AEZ"

# Local fallbacks
LOCAL_BASE_DIR = os.path.join("./", "Africa_VIs_Env")
LOCAL_BY_COUNTRY_DIR = os.path.join(LOCAL_BASE_DIR, "By Country")
LOCAL_CLE_DIR = os.path.join(LOCAL_BASE_DIR, "CLE")
LOCAL_MODEL_DIR = os.path.join(LOCAL_BASE_DIR, MODEL_FOLDER)
LOCAL_OUTPUT_DIR = os.path.join(LOCAL_BASE_DIR, OUTPUT_FOLDER)

# GTPS/AEZ cache made earlier
GTPS_CACHE_PARQUET = "Nigeria_GTPS_per_point.parquet"
GTPS_CACHE_CSV     = "Nigeria_GTPS_per_point.csv"

# Drive retry knobs
DRIVE_MAX_RETRIES = int(os.environ.get("DRIVE_MAX_RETRIES", 6))
DRIVE_RETRY_BASE = float(os.environ.get("DRIVE_RETRY_BASE", 0.8))

POINTS_BASENAME = "Africa"  # pooled training points

ALL_PREDICTORS = [
    "NDVI_mean","NDWI_mean","GI_mean",
    "NDVI_max","NDWI_max","GI_max",
    "NDVI_min","NDWI_min","GI_min",
    "elevation","slope",
    # "ET","PET",
    "dist_to_river",
    AEZ_COL,
]
TARGET = "irrigated"
LONCOL = "longitude"
LATCOL = "latitude"


# RF + data
RF_TREES = int(os.environ.get("RF_TREES", 400))
RANDOM_SEED = 42
TEST_FRACTION = 0.30

BORROW_MAX = 2000

# ---- Calibration & thresholds ----
CALIBRATION_METHOD = os.environ.get("CALIBRATION_METHOD", "isotonic").lower()  # "isotonic" | "platt"
THRESHOLD_STRATEGY = os.environ.get("THRESHOLD_STRATEGY", "auto").lower()      # "auto"|"j"|"f1"|"fixed"|"precision_at"|"recall_at"
FIXED_THRESHOLD = float(os.environ.get("FIXED_THRESHOLD", 0.50))
PRECISION_TARGET = os.environ.get("PRECISION_TARGET")
RECALL_TARGET = os.environ.get("RECALL_TARGET")
PRECISION_TARGET = None if PRECISION_TARGET in (None, "", "None") else float(PRECISION_TARGET)
RECALL_TARGET    = None if RECALL_TARGET    in (None, "", "None") else float(RECALL_TARGET)

# Per-country threshold mode: "global" | "country" | "hybrid"
THRESHOLD_MODE = os.environ.get("THRESHOLD_MODE", "hybrid").lower()
MIN_POS_NEG_FOR_COUNTRY = int(os.environ.get("MIN_POS_NEG_FOR_COUNTRY", 30))
MIN_TOTAL_FOR_COUNTRY   = int(os.environ.get("MIN_TOTAL_FOR_COUNTRY", 100))

# Prediction
TILE_SIZE = 1024
RESAMPLING = RioResampling.bilinear
PER_BAND_RESAMPLING = {AEZ_COL: RioResampling.nearest}
STRICT_ALL_BANDS = False
MIN_PREDICTORS = 9

# ------------------ UTILS / HELPERS ---------------------
def _meters_per_degree(lat_deg: float):
    m_per_deg_lat = 111_132.0
    m_per_deg_lon = 111_320.0 * math.cos(math.radians(lat_deg))
    return m_per_deg_lat, m_per_deg_lon

def _target_wgs84_grid_from_bounds(west, south, east, north, pixel_m=30.0):
    mid_lat = 0.5 * (south + north)
    m_lat, m_lon = _meters_per_degree(mid_lat)
    dy = pixel_m / m_lat; dx = pixel_m / m_lon
    width = int(math.ceil((east - west) / dx)); height = int(math.ceil((north - south) / dy))
    transform = Affine(dx, 0, west, 0, -dy, north)
    return height, width, transform, CRS.from_epsg(4326)

def _iter_tiles(H, W, tile):
    for r0 in range(0, H, tile):
        for c0 in range(0, W, tile):
            r1 = min(r0 + tile, H); c1 = min(c0 + tile, W)
            yield Window.from_slices((r0, r1), (c0, c1))

def _tile_count(H, W, tile):
    return ((H + tile - 1) // tile) * ((W + tile - 1) // tile)

def _compute_basic_metrics(y_true, y_prob, thr):
    y_pred = (y_prob >= thr).astype(int)
    acc = accuracy_score(y_true, y_pred)
    pr, rc, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
    try: roc = roc_auc_score(y_true, y_prob)
    except Exception: roc = float("nan")
    try: prauc = average_precision_score(y_true, y_prob)
    except Exception: prauc = float("nan")
    cm = confusion_matrix(y_true, y_pred, labels=[0,1])
    return {
        "accuracy": float(acc),
        "precision": float(pr),
        "recall": float(rc),
        "f1": float(f1),
        "roc_auc": float(roc),
        "pr_auc": float(prauc),
        "tn": int(cm[0,0]), "fp": int(cm[0,1]),
        "fn": int(cm[1,0]), "tp": int(cm[1,1]),
    }, cm

def _save_confusion_png(cm, title, path):
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(cm, cmap="Blues")
    ax.set_xticks([0,1]); ax.set_yticks([0,1])
    ax.set_xlabel("Pred"); ax.set_ylabel("True")
    ax.set_title(title)
    for (r,c),v in np.ndenumerate(cm):
        ax.text(c, r, f"{v}", ha="center", va="center")
    fig.savefig(path, dpi=200, bbox_inches="tight")
    plt.close(fig)

def assign_farm_ids(df: pd.DataFrame, lon_col=LONCOL, lat_col=LATCOL) -> pd.Series:
    if df.empty:
        return pd.Series([], dtype="int64")
    proj = pyproj.Transformer.from_crs(4326, 3857, always_xy=True)
    x, y = proj.transform(df[lon_col].values, df[lat_col].values)
    coords = np.column_stack([x, y])
    db = DBSCAN(eps=DBSCAN_EPS_M, min_samples=DBSCAN_MIN_SAMPLES, metric="euclidean", algorithm="ball_tree")
    labels = db.fit_predict(coords)
    if (labels == -1).any():
        max_lab = labels[labels >= 0].max() if np.any(labels >= 0) else -1
        noise_idx = np.where(labels == -1)[0]
        labels[noise_idx] = np.arange(max_lab + 1, max_lab + 1 + len(noise_idx))
    return pd.Series(labels.astype("int64"), index=df.index, name="farm_id")

def _pick_country_candidate(cols):
    cand = ["country_hint","country","Country","COUNTRY","admin","ADMIN","state","STATE","STATE_NAME","State","province","Province"]
    low = {c.lower(): c for c in cols}
    for c in cand:
        if c.lower() in low:
            return low[c.lower()]
    return None

# ---------------------- DRIVE HELPERS -------------------
def _drive_retry(callable_fn, *args, **kwargs):
    last_err = None
    for i in range(DRIVE_MAX_RETRIES):
        try:
            return callable_fn(*args, **kwargs)
        except (HttpError, ApiRequestError) as e:
            last_err = e
            code = getattr(getattr(e, "resp", None), "status", None)
            if code is None:
                msg = str(e).lower()
                transient = any(k in msg for k in ["internal error","backenderror","rate limit","timeout"])
            else:
                transient = 500 <= int(code) < 600 or int(code) in (403, 429)
            if not transient or i == DRIVE_MAX_RETRIES - 1:
                break
            sleep_s = DRIVE_RETRY_BASE * (2 ** i) + random.random() * 0.2
            print(f"[Drive Retry] attempt {i+1}/{DRIVE_MAX_RETRIES} after error {e}; sleeping {sleep_s:.2f}s")
            time.sleep(sleep_s)
        except Exception as e:
            last_err = e
            break
    raise last_err

def _drive_query(drive, q):
    def _run():
        return drive.ListFile({"q": q, "supportsAllDrives": True, "includeItemsFromAllDrives": True, "maxResults": 1000}).GetList()
    return _drive_retry(_run)

def get_subfolder(drive, parent_id, name):
    if drive is None:
        p = os.path.join(LOCAL_BASE_DIR, name)
        return p if os.path.isdir(p) else None
    q = f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder' and title='{name}'"
    r = _drive_query(drive, q)
    return r[0]["id"] if r else None

def get_or_create_folder(drive, parent_id, name):
    if drive is None:
        base = LOCAL_BASE_DIR if not os.path.isabs(parent_id) else parent_id
        p = os.path.join(base, name) if os.path.isdir(base) else os.path.join(LOCAL_BASE_DIR, name)
        os.makedirs(p, exist_ok=True)
        return p
    q = f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder' and title='{name}'"
    res = _drive_query(drive, q)
    if res:
        return res[0]["id"]
    def _create():
        f = drive.CreateFile({"title": name, "parents": [{"id": parent_id}], "mimeType": "application/vnd.google-apps.folder"})
        f.Upload(); return f["id"]
    return _drive_retry(_create)

def list_files(drive, parent_id):
    if drive is None:
        return sorted([os.path.join(parent_id, p) for p in os.listdir(parent_id)]) if os.path.isdir(parent_id) else []
    q = f"'{parent_id}' in parents and trashed=false"
    try: return _drive_query(drive, q)
    except Exception as e:
        print(f"[Drive Warning] list_files failed for parent {parent_id}: {e}")
        return []

def download_to_temp(drive_file, dst_path):
    def _dl():
        drive_file.GetContentFile(dst_path); return dst_path
    return _drive_retry(_dl)

def upload_path(drive, local_path, parent_id, title=None):
    if drive is None:
        os.makedirs(parent_id, exist_ok=True)
        import shutil
        dst = os.path.join(parent_id, title or os.path.basename(local_path))
        shutil.copy2(local_path, dst); return dst
    def _up():
        f = drive.CreateFile({"title": title or os.path.basename(local_path), "parents": [{"id": parent_id}]})
        f.SetContentFile(local_path); f.Upload(); return f["id"]
    return _drive_retry(_up)

def _drive_walk(drive, start_id, max_depth=4):
    q = deque([(start_id, 0)])
    while q:
        fid, d = q.popleft()
        items = list_files(drive, fid)
        yield fid, items
        if d >= max_depth: continue
        for it in items:
            if isinstance(it, dict) and it.get("mimeType") == "application/vnd.google-apps.folder":
                q.append((it["id"], d + 1))

# ---------------------- AEZ SOURCES ----------------------
def open_aez_path(drive):
    """Global AEZ.tif (fallback if per-country tile missing)."""
    if drive is None:
        p = os.path.join(LOCAL_BASE_DIR, "AEZ", AEZ_FILE_NAME)
        return p if os.path.exists(p) else None

    def _resolve_shortcut(file_obj):
        if file_obj.get("mimeType") == "application/vnd.google-apps.shortcut":
            tgt = file_obj.get("shortcutDetails", {}).get("targetId")
            if tgt:
                g = drive.CreateFile({"id": tgt})
                g.FetchMetadata(fields="title,mimeType,shortcutDetails"); return g
        return file_obj

    folder_id = None
    if AEZ_FILE_ID:
        f = drive.CreateFile({"id": AEZ_FILE_ID})
        f.FetchMetadata(fields="title,mimeType,shortcutDetails"); f = _resolve_shortcut(f)
        mime = f.get("mimeType")
        if mime == "application/vnd.google-apps.folder":
            folder_id = f["id"]
        elif mime and mime.startswith("application/vnd.google-apps."):
            raise RuntimeError(f"AEZ_FILE_ID points to a Google Doc ({mime}), not a TIFF.")
        else:
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".tif").name
            f.GetContentFile(tmp); return tmp

    folder_id = folder_id or os.environ.get("AEZ_FOLDER_ID", None)
    if folder_id:
        q = f"'{folder_id}' in parents and trashed=false and title='{AEZ_FILE_NAME}'"
        cand = _drive_query(drive, q)
        if not cand:
            q_any = f"'{folder_id}' in parents and trashed=false and title contains '.tif'"
            cand = _drive_query(drive, q_any)
            cand = [c for c in cand if c.get("title","") == AEZ_FILE_NAME] or cand
        if not cand:
            raise RuntimeError(f"Could not find {AEZ_FILE_NAME} in folder id {folder_id}.")
        f = drive.CreateFile({"id": cand[0]["id"]})
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".tif").name
        f.GetContentFile(tmp); return tmp

    p = os.path.join(LOCAL_BASE_DIR, "AEZ", AEZ_FILE_NAME)
    return p if os.path.exists(p) else None

# ---------------------- POINTS / CACHE -------------------
def _find_points_vector_drive(drive):
    try: mt = get_subfolder(drive, ROOT_FOLDER_ID, MODEL_FOLDER)
    except Exception: mt = None
    if mt:
        items = list_files(drive, mt); parts=[]; has_shp=False
        for it in items:
            if not isinstance(it, dict): continue
            title = it.get("title",""); low = title.lower()
            if not low.startswith("US."): continue
            if low.endswith(".shp"): has_shp=True
            if re.search(r"\.(shp|dbf|shx|prj|cpg|qpj)$", low): parts.append(it)
        if has_shp and parts:
            tdir = tempfile.mkdtemp()
            for p in parts: download_to_temp(p, os.path.join(tdir, p.get("title","")))
            shp_path = os.path.join(tdir, "US.shp")
            if os.path.exists(shp_path):
                print("Using vector points (shp): Model Training/US.shp"); return shp_path
    roots = [mt or ROOT_FOLDER_ID, ROOT_FOLDER_ID]
    for root in roots:
        for fid, items in _drive_walk(drive, root, max_depth=4):
            stems={}
            for it in items:
                if not isinstance(it, dict): continue
                title = it.get("title",""); m = re.match(rf"(.+)\.(shp|dbf|shx|prj|cpg|qpj)$", title, re.IGNORECASE)
                if not m: continue
                stem = m.group(1)
                if stem.lower() == POINTS_BASENAME.lower():
                    stems.setdefault(stem,[]).append(it)
            for stem, parts in stems.items():
                if any(p.get("title","").lower().endswith(".shp") for p in parts):
                    tdir = tempfile.mkdtemp()
                    for p in parts: download_to_temp(p, os.path.join(tdir, p.get("title","")))
                    shp_path = os.path.join(tdir, f"{os.path.basename(stem)}.shp")
                    if os.path.exists(shp_path):
                        print(f"Using vector points (shp): {os.path.basename(stem)}.shp"); return shp_path
            for it in items:
                if not isinstance(it, dict): continue
                title = it.get("title",""); low = title.lower()
                if it.get("mimeType") == "application/vnd.google-apps.folder": continue
                if low.endswith(".gpkg") and low == f"{POINTS_BASENAME.lower()}.gpkg":
                    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".gpkg").name
                    download_to_temp(it, tmp); print(f"Using vector points (gpkg): {title}"); return tmp
                if low.endswith(".zip") and low == f"{POINTS_BASENAME.lower()}.zip":
                    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".zip").name
                    download_to_temp(it, tmp); print(f"Using vector points (zip): {title}"); return tmp
                if low.endswith(".csv") and low == f"{POINTS_BASENAME.lower()}.csv":
                    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv").name
                    download_to_temp(it, tmp); print(f"Using vector points (csv): {title}"); return tmp
    return None

def _find_points_vector_local():
    search_roots = [os.path.join(LOCAL_BASE_DIR, MODEL_FOLDER), LOCAL_BASE_DIR]
    for root in search_roots:
        if not os.path.isdir(root): continue
        for nm in os.listdir(root):
            low = nm.lower()
            if low == f"{POINTS_BASENAME.lower()}.csv":
                print(f"Using vector points (csv): {nm}"); return os.path.join(root, nm)
            if POINTS_BASENAME.lower() in low and low.endswith(".gpkg"):
                print(f"Using vector points (gpkg): {nm}"); return os.path.join(root, nm)
            if POINTS_BASENAME.lower() in low and low.endswith(".zip"):
                print(f"Using vector points (zip): {nm}"); return os.path.join(root, nm)
    for root in search_roots:
        if not os.path.isdir(root): continue
        for nm in os.listdir(root):
            if nm.lower().endswith(".shp") and POINTS_BASENAME.lower() in nm.lower():
                print(f"Using vector points (shp): {nm}"); return os.path.join(root, nm)
    return None

def _ci_lookup(cols, name):
    name_l = name.lower()
    for c in cols:
        if c.lower() == name_l: return c
    return None

def _round6(x):
    return np.round(pd.to_numeric(x, errors="coerce"), 6)

def _load_gtps_cache_df(drive):
    """Load US_GTPS_per_point.(parquet|csv) from Model Training if present."""
    def _from_drive():
        mt = get_or_create_folder(drive, ROOT_FOLDER_ID, MODEL_FOLDER)
        # prefer parquet
        res = _drive_query(drive, f"'{mt}' in parents and trashed=false and title='{GTPS_CACHE_PARQUET}'")
        if res:
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".parquet").name
            download_to_temp(res[0], tmp)
            try:
                return pd.read_parquet(tmp)
            except Exception:
                pass
        res = _drive_query(drive, f"'{mt}' in parents and trashed=false and title='{GTPS_CACHE_CSV}'")
        if res:
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv").name
            download_to_temp(res[0], tmp)
            return pd.read_csv(tmp)
        return None

    def _from_local():
        p1 = os.path.join(LOCAL_MODEL_DIR, GTPS_CACHE_PARQUET)
        p2 = os.path.join(LOCAL_MODEL_DIR, GTPS_CACHE_CSV)
        if os.path.exists(p1):
            try:
                return pd.read_parquet(p1)
            except Exception:
                pass
        if os.path.exists(p2):
            return pd.read_csv(p2)
        return None

    try:
        return _from_drive() if drive is not None else _from_local()
    except Exception as e:
        print(f"   ! AEZ cache read error: {e}")
        return None

def load_points_df(drive):
    import geopandas as gpd, fiona
    path = _find_points_vector_drive(drive) if drive is not None else _find_points_vector_local()
    if path is None:
        raise RuntimeError("Could not locate US.(shp/zip/gpkg/csv) in 'Model Training' tree.")
    if path.lower().endswith(".csv"):
        df = pd.read_csv(path)
    else:
        with fiona.Env(SHAPE_RESTORE_SHX="YES"):
            gdf = gpd.read_file(path)
        if gdf.crs is not None and gdf.crs.to_epsg() != 4326:
            gdf = gdf.to_crs(4326)
        df = pd.DataFrame(gdf.drop(columns="geometry", errors="ignore"))

    lon_col = _ci_lookup(df.columns, LONCOL); lat_col = _ci_lookup(df.columns, LATCOL)
    if lon_col is None or lat_col is None:
        raise RuntimeError("Points must contain longitude and latitude columns.")
    if lon_col != LONCOL: df[LONCOL] = pd.to_numeric(df[lon_col], errors="coerce")
    if lat_col != LATCOL: df[LATCOL] = pd.to_numeric(df[lat_col], errors="coerce")

    tgt_col = _ci_lookup(df.columns, TARGET)
    if tgt_col is None: raise RuntimeError(f"Points file missing column: {TARGET}")
    if tgt_col != TARGET: df[TARGET] = pd.to_numeric(df[tgt_col], errors="coerce")

    # pre-create rounders for cache join
    df["_lonr"] = _round6(df[LONCOL]); df["_latr"] = _round6(df[LATCOL])

    # copy a country-like column
    if COUNTRY_COL not in df.columns:
        cand = _pick_country_candidate(df.columns)
        df[COUNTRY_COL] = df[cand] if cand else "Unknown"

    for c in [TARGET, LONCOL, LATCOL]:
        if c in df.columns: df[c] = pd.to_numeric(df[c], errors="coerce")
    df = df.dropna(subset=[TARGET, LONCOL, LATCOL]).reset_index(drop=True)
    df[TARGET] = df[TARGET].astype(int)
    df[COUNTRY_COL] = df[COUNTRY_COL].astype(str)

    # Bring AEZ from cache if available
    if AEZ_COL not in df.columns or df[AEZ_COL].isna().all():
        cache = _load_gtps_cache_df(drive)
        if cache is not None:
            # normalize cache cols
            cc = {c.lower(): c for c in cache.columns}
            lonc = cc.get(LONCOL.lower(), None) or cc.get("x") or cc.get("lon") or list(cache.columns)[1]
            latc = cc.get(LATCOL.lower(), None) or cc.get("y") or cc.get("lat") or list(cache.columns)[2]
            aezc = cc.get(AEZ_COL.lower(), None) or cc.get("gtps") or cc.get("aez")
            if aezc is not None:
                cache["_lonr"] = _round6(cache[lonc]); cache["_latr"] = _round6(cache[latc])
                df = df.merge(cache[["_lonr","_latr",aezc]].rename(columns={aezc:AEZ_COL}),
                              on=["_lonr","_latr"], how="left")
                if AEZ_COL in df.columns:
                    df[AEZ_COL] = pd.to_numeric(df[AEZ_COL], errors="coerce")
                    df[AEZ_COL] = np.rint(df[AEZ_COL]).astype("float64")
                    print(f"   • AEZ filled from cache for {np.isfinite(df[AEZ_COL]).sum()}/{len(df)} points")
    return df

# ---------------------- RASTERS -------------------------
def open_country_var_path(drive, country, variable):
    """Any By Country/<Country>/<Country>_<variable>.tif (case-sensitive on names passed in)."""
    if drive is None:
        p = os.path.join(LOCAL_BY_COUNTRY_DIR, country, f"{country}_{variable}.tif")
        return p if os.path.exists(p) else None
    byc = get_subfolder(drive, ROOT_FOLDER_ID, BY_COUNTRY_NAME)
    if not byc: return None
    q = f"'{byc}' in parents and trashed=false and title='{country}' and mimeType='application/vnd.google-apps.folder'"
    res = _drive_query(drive, q)
    if not res: return None
    cid = res[0]["id"]
    q2 = f"'{cid}' in parents and trashed=false and title='{country}_{variable}.tif'"
    files = _drive_query(drive, q2)
    if files:
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".tif").name
        return download_to_temp(files[0], tmp)
    for it in list_files(drive, cid):
        if isinstance(it, dict) and re.match(rf"^{re.escape(country)}_{re.escape(variable)}.*\.tif$", it.get("title",""), re.IGNORECASE):
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".tif").name
            return download_to_temp(it, tmp)
    return None

def open_CroplandNE(drive):
    if drive is None:
        p = os.path.join(LOCAL_CLE_DIR, "Cropland_Nigeria.tif")
        return p if os.path.exists(p) else None
    items = list_files(drive, CLE_FOLDER_ID)
    for it in items:
        if not isinstance(it, dict): continue
        if it.get("title","").lower().startswith("Cropland_Nigeria") and it.get("title","").lower().endswith(".tif"):
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".tif").name
            return download_to_temp(it, tmp)
    return None

def _clip_bounds_from_reference(raster_path):
    with rasterio.open(raster_path) as src:
        b = rio_transform_bounds(src.crs, CRS.from_epsg(4326),
                                 src.bounds.left, src.bounds.bottom, src.bounds.right, src.bounds.top, densify_pts=8)
    return b

def _pick_reference_raster_for_bounds(drive, country):
    # Prefer any available per-country raster (including AEZ tile) to set the grid.
    # IMPORTANT: do NOT fall back to global AEZ here, to avoid predicting outside the country.
    for v in ["CLE"] + ALL_PREDICTORS + [AEZ_COL]:
        if v == AEZ_COL:
            p = open_country_var_path(drive, country, AEZ_COL)
        else:
            p = open_country_var_path(drive, country, v)
        if p is not None:
            return p
    raise RuntimeError(f"No per-country reference raster found for bounds in '{country}'.")

def country_grid_and_mask(drive, country):
    ref_path = _pick_reference_raster_for_bounds(drive, country)
    west, south, east, north = _clip_bounds_from_reference(ref_path)
    H, W, transform, crs = _target_wgs84_grid_from_bounds(west, south, east, north, pixel_m=30.0)

    # --- MUST have cropland mask for this country; otherwise skip prediction ---
    country_cropland = open_country_var_path(drive, country, "CLE")
    if country_cropland is None:
        # This error is caught in run_region_then_countries and will print a debug line
        # then continue to the next country.
        raise RuntimeError(f"No cropland CLE raster found for '{country}'; skipping prediction for this country.")

    cropland = clip_cle_to_grid(drive, H, W, transform, crs, source_path=country_cropland)
    return (H, W, transform, crs, cropland, (west, south, east, north))

def clip_cle_to_grid(drive, out_h, out_w, out_transform, out_crs, source_path=None):
    cle = source_path or open_CroplandNE(drive)
    if cle is None: return np.ones((out_h, out_w), dtype=bool)
    arr = np.full((out_h, out_w), -9999.0, dtype=np.float32)
    with rasterio.open(cle) as src:
        reproject(source=rasterio.band(src, 1), destination=arr,
                  src_transform=src.transform, src_crs=src.crs,
                  dst_transform=out_transform, dst_crs=out_crs,
                  dst_nodata=-9999.0, resampling=RioResampling.nearest)
        nd = src.nodata
    if nd is not None: arr = np.where(np.isclose(arr, nd), np.nan, arr)
    arr = np.where(np.isclose(arr, -9999.0), np.nan, arr)
    return np.isfinite(arr) & (arr > 0.5)

# ---------------- THRESHOLDS -----------------------------
def _thr_by_J(y_true, p):
    y = np.asarray(y_true).astype(int)
    if np.unique(y).size < 2: return FIXED_THRESHOLD
    fpr, tpr, th = roc_curve(y, p)
    if len(th) < 2: return FIXED_THRESHOLD
    j = tpr - fpr
    return float(th[int(np.argmax(j))])

def _thr_by_F1(y_true, p):
    y = np.asarray(y_true).astype(int)
    ths = np.linspace(0.01, 0.99, 99); best_t, best_f1 = 0.5, -1
    for t in ths:
        yhat = (p >= t).astype(int)
        _, _, f1, _ = precision_recall_fscore_support(y, yhat, average="binary", zero_division=0)
        if f1 > best_f1: best_f1, best_t = f1, t
    return float(best_t)

def _thr_precision_at(y_true, p, target):
    if target is None: return _thr_by_J(y_true, p)
    prec, rec, thr = precision_recall_curve(y_true, p)
    cand = thr[prec[1:] >= target]
    if len(cand) > 0:
        return float(np.min(cand))  # smallest threshold achieving target precision
    return _thr_by_F1(y_true, p)

def _thr_recall_at(y_true, p, target):
    if target is None: return _thr_by_J(y_true, p)
    prec, rec, thr = precision_recall_curve(y_true, p)
    cand = thr[rec[1:] >= target]
    if len(cand) > 0:
        return float(np.max(cand))  # largest threshold while keeping recall >= target
    return _thr_by_F1(y_true, p)

def _pick_threshold(y_true, p):
    s = THRESHOLD_STRATEGY.lower()
    if s == "fixed":       return float(FIXED_THRESHOLD)
    if s == "f1":          return _thr_by_F1(y_true, p)
    if s == "j":           return _thr_by_J(y_true, p)
    if s == "precision_at":return _thr_precision_at(y_true, p, PRECISION_TARGET)
    if s == "recall_at":   return _thr_recall_at(y_true, p, RECALL_TARGET)
    if PRECISION_TARGET is not None:
        return _thr_precision_at(y_true, p, PRECISION_TARGET)
    if RECALL_TARGET is not None:
        return _thr_recall_at(y_true, p, RECALL_TARGET)
    return _thr_by_J(y_true, p)

# ---------------- Calibration helpers -------------------
def _fit_calibrator(raw_pos_probs, y, method="isotonic"):
    raw_pos_probs = np.asarray(raw_pos_probs, dtype=float).ravel()
    y = np.asarray(y, dtype=int).ravel()
    if method == "platt":
        lr = LogisticRegression(solver="lbfgs", max_iter=1000)
        lr.fit(raw_pos_probs.reshape(-1,1), y)
        return {"method": "platt", "model": lr}
    ir = IsotonicRegression(y_min=0.0, y_max=1.0, out_of_bounds="clip")
    ir.fit(raw_pos_probs, y)
    return {"method": "isotonic", "model": ir}

def _apply_calibrator(calib, raw_pos_probs):
    raw_pos_probs = np.asarray(raw_pos_probs, dtype=float).ravel()
    if calib is None:
        return raw_pos_probs
    if calib.get("method") == "platt":
        return calib["model"].predict_proba(raw_pos_probs.reshape(-1,1))[:,1]
    return calib["model"].predict(raw_pos_probs)

# ------------- POINT SAMPLING FROM RASTERS --------------
def _sample_raster_at_lonlat(src, lon_arr, lat_arr):
    xs = np.asarray(lon_arr, dtype=float); ys = np.asarray(lat_arr, dtype=float)
    if src.crs and (src.crs.to_epsg() != 4326):
        tx, ty = rio_transform(CRS.from_epsg(4326), src.crs, xs.tolist(), ys.tolist())
        xs = np.asarray(tx, dtype=float); ys = np.asarray(ty, dtype=float)
    out = np.full(xs.shape[0], np.nan, dtype=np.float32)
    xmin, ymin, xmax, ymax = src.bounds.left, src.bounds.bottom, src.bounds.right, src.bounds.top
    inside = (xs >= xmin) & (xs <= xmax) & (ys >= ymin) & (ys <= ymax)
    if not np.any(inside): return out
    idx = np.where(inside)[0]
    coords = list(zip(xs[idx], ys[idx]))
    vals = np.array([v[0] for v in src.sample(coords)], dtype=np.float32)
    if src.nodata is not None and np.isfinite(src.nodata):
        vals = np.where(np.isclose(vals, np.float32(src.nodata)), np.float32(np.nan), vals)
    out[idx] = vals; return out

# ---- TRAINING FILL (num predictors + AEZ) ---------------
def _fill_training_predictors_from_country_rasters(drive, df_all, countries):
    df = df_all.copy()

    # AEZ: if still missing after cache merge, try per-country tiles, else global
    if AEZ_COL not in df.columns or df[AEZ_COL].isna().any():
        miss = df.index if AEZ_COL not in df.columns else df.index[df[AEZ_COL].isna()]
        if len(miss) > 0:
            aez_filled = 0
            # try per-country tiles first
            for c in countries:
                try:
                    ref = open_country_var_path(drive, c, AEZ_COL) or open_aez_path(drive)
                    if not ref: continue
                    w, s, e, n = _clip_bounds_from_reference(ref)
                    idx = df.loc[miss].index[
                        df.loc[miss, LONCOL].between(min(w,e), max(w,e)) &
                        df.loc[miss, LATCOL].between(min(s,n), max(s,n))
                    ]
                    if len(idx) == 0: continue
                    with rasterio.open(ref) as aez_src:
                        vals = _sample_raster_at_lonlat(aez_src, df.loc[idx, LONCOL].values, df.loc[idx, LATCOL].values)
                    df.loc[idx, AEZ_COL] = np.rint(pd.to_numeric(vals, errors="coerce")).astype("float64")
                    aez_filled += int(np.isfinite(df.loc[idx, AEZ_COL]).sum())
                except Exception as ex:
                    print(f"   ! AEZ fill skip for {c}: {ex}")
            if aez_filled > 0:
                print(f"   • AEZ sampled from tiles/global for {aez_filled} points")

    # Per-country numeric predictors from rasters
    for c in countries:
        try:
            ref = open_country_var_path(drive, c, "CLE") or open_country_var_path(drive, c, "NDVI_mean")
            if not ref: continue
            w, s, e, n = _clip_bounds_from_reference(ref)
        except Exception as ex:
            print(f"   ! Skip raster fill for {c}: {ex}"); continue
        idx = df.index[df[LONCOL].between(min(w,e), max(w,e)) & df[LATCOL].between(min(s,n), max(s,n))]
        if len(idx) == 0: continue
        raster_cache = {}; filled_any = False
        for p in ALL_PREDICTORS:
            if p == AEZ_COL: continue
            pth = open_country_var_path(drive, c, p)
            if pth is None: continue
            if p not in raster_cache: raster_cache[p] = rasterio.open(pth)
            vals = _sample_raster_at_lonlat(raster_cache[p], df.loc[idx, LONCOL].values, df.loc[idx, LATCOL].values)
            df.loc[idx, p] = pd.to_numeric(vals, errors="coerce").astype("float64")
            filled_any = True
        if filled_any:
            print(f"   • Training fill from rasters for {c}: {len(idx)} points")
    return df

def list_countries_auto(drive):
    names = set()
    if drive is None:
        if os.path.isdir(LOCAL_BY_COUNTRY_DIR):
            for d in os.listdir(LOCAL_BY_COUNTRY_DIR):
                if os.path.isdir(os.path.join(LOCAL_BY_COUNTRY_DIR, d)): names.add(d)
    else:
        byc = get_subfolder(drive, ROOT_FOLDER_ID, BY_COUNTRY_NAME)
        if byc:
            subs = _drive_query(drive, f"'{byc}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder'")
            for it in subs: names.add(it["title"])
    return sorted(names)

# ----------------- TRAIN REGIONAL MODEL -----------------
def _borrow_neighbors(df_all, base_cols, max_points):
    if COUNTRY_COL not in df_all.columns:
        df_all[COUNTRY_COL] = "Unknown"
    keep_cols = list(dict.fromkeys(base_cols))
    pooled = df_all[keep_cols].dropna(subset=[TARGET, LONCOL, LATCOL]).copy()
    if len(pooled) > max_points:
        pooled = pooled.sample(max_points, random_state=RANDOM_SEED)
    if pooled[TARGET].nunique() < 2:
        other = 1 - int(pooled[TARGET].iloc[0])
        extra = df_all[df_all[TARGET] == other]
        if not extra.empty:
            take = min(1000, len(extra))
            pooled = pd.concat([pooled, extra.sample(take, random_state=RANDOM_SEED)[keep_cols]], ignore_index=True)
    if COUNTRY_COL not in pooled.columns:
        pooled[COUNTRY_COL] = "Unknown"
    pooled[COUNTRY_COL] = pooled[COUNTRY_COL].astype(str)
    return pooled.reset_index(drop=True)

def train_region_model(drive, region_name, df_all, countries):
    print(f"=== Train regional model: {region_name} ===")

    # Fill predictors (uses AEZ cache + per-country AEZ tiles)
    df_all = _fill_training_predictors_from_country_rasters(drive, df_all, countries)

    use_preds = [p for p in ALL_PREDICTORS if p in df_all.columns]
    base_cols = [TARGET, LONCOL, LATCOL, COUNTRY_COL] + use_preds
    df = _borrow_neighbors(df_all, base_cols, BORROW_MAX)

    if len(df) < 200 or df[TARGET].nunique() < 2:
        raise RuntimeError(f"Insufficient pooled points for region training (n={len(df)}, classes={df[TARGET].nunique()}).")

    num_features = [f for f in use_preds if f != AEZ_COL]
    cat_features = [AEZ_COL] if AEZ_COL in use_preds else []

    pre = ColumnTransformer(
        transformers=[
            ("num", make_pipeline(SimpleImputer(strategy="median")), num_features),
            ("cat", make_pipeline(SimpleImputer(strategy="most_frequent"),
                                  OneHotEncoder(handle_unknown="ignore", sparse_output=False)), cat_features),
        ],
        remainder="drop",
        verbose_feature_names_out=False,
    )

    rf = RandomForestClassifier(
        n_estimators=RF_TREES, max_features="sqrt", class_weight="balanced_subsample",
        n_jobs=-1, random_state=RANDOM_SEED, bootstrap=True,
    )
    pipe = Pipeline([("pre", pre), ("rf", rf)])

    # ---------- Group-aware OOF to fit calibrator ----------
    df["farm_id"] = assign_farm_ids(df, lon_col=LONCOL, lat_col=LATCOL)
    n_splits = 5 if 0.19 <= TEST_FRACTION <= 0.21 else max(3, int(round(1.0 / max(TEST_FRACTION, 1e-3))))
    sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_SEED)

    X_all = df[use_preds].copy()
    y_all = df[TARGET].astype(int).values
    g_all = df["farm_id"].values
    countries_all = df[COUNTRY_COL].astype(str).values
    lon_all = df[LONCOL].values
    lat_all = df[LATCOL].values

    # enforce dtypes
    for col in [f for f in use_preds if f != AEZ_COL]:
        X_all[col] = pd.to_numeric(X_all[col], errors="coerce").astype("float64")
    if AEZ_COL in use_preds:
        X_all[AEZ_COL] = pd.to_numeric(X_all[AEZ_COL], errors="coerce")
        X_all[AEZ_COL] = np.rint(X_all[AEZ_COL]).astype("float64")

    oof_raw = np.full(len(X_all), np.nan, dtype=float)
    classes_ = None
    for tr_idx, te_idx in sgkf.split(X_all, y_all, groups=g_all):
        tr_X = X_all.iloc[tr_idx]; tr_y = y_all[tr_idx]
        te_X = X_all.iloc[te_idx]
        pipe_fold = clone(pipe)
        pipe_fold.fit(tr_X, tr_y.astype(int))
        if classes_ is None:
            classes_ = pipe_fold.named_steps["rf"].classes_
        pos_idx_fold = int(np.where(classes_ == 1)[0][0])
        oof_raw[te_idx] = pipe_fold.predict_proba(te_X)[:, pos_idx_fold]

    # fit final model on all data
    pipe.fit(X_all, y_all.astype(int))
    classes_ = pipe.named_steps["rf"].classes_
    pos_idx = int(np.where(classes_ == 1)[0][0])

    # calibrate on OOF
    calib = _fit_calibrator(oof_raw, y_all, CALIBRATION_METHOD)
    oof_cal = _apply_calibrator(calib, oof_raw)

    # ---------- GLOBAL threshold from OOF calibrated ----------
    global_thr = _pick_threshold(y_all, oof_cal)

    # ---------- Regional OOF metrics & confusion ----------
    reg_metrics, reg_cm = _compute_basic_metrics(y_all, oof_cal, global_thr)
    reg_metrics.update({
        "region": region_name,
        "n_oof": int(len(y_all)),
        "threshold_strategy": THRESHOLD_STRATEGY,
        "global_threshold": float(global_thr),
        "calibration": CALIBRATION_METHOD,
    })

    # feature importances from final model
    out_names = list(pipe.named_steps["pre"].get_feature_names_out())
    rf_imp = pipe.named_steps["rf"].feature_importances_
    fi_df = pd.DataFrame({"feature": out_names, "importance": rf_imp}).sort_values("importance", ascending=False)

    # ---------- Prep folders ----------
    if drive is None:
        reg_base = os.path.join(LOCAL_MODEL_DIR, REGIONAL_MODELS_FOLDER, region_name)
        os.makedirs(reg_base, exist_ok=True)
        c_base_root = os.path.join(LOCAL_MODEL_DIR, COUNTRY_MODELS_FOLDER)
        os.makedirs(c_base_root, exist_ok=True)
    else:
        mt = get_or_create_folder(drive, ROOT_FOLDER_ID, MODEL_FOLDER)
        reg = get_or_create_folder(drive, mt, REGIONAL_MODELS_FOLDER)
        reg_base = get_or_create_folder(drive, reg, region_name)
        c_base_root = get_or_create_folder(drive, mt, COUNTRY_MODELS_FOLDER)

    # ---------- Ensure country tags (bbox fallback) ----------
    assigned = countries_all.copy()
    unknown_mask = (assigned == "Unknown") | (assigned == "") | pd.isna(assigned)
    if unknown_mask.any():
        for ctry in countries:
            try:
                ref = open_country_var_path(drive, ctry, "CLE") or open_country_var_path(drive, ctry, AEZ_COL) or open_aez_path(drive)
                if not ref: continue
                w, s, e, n = _clip_bounds_from_reference(ref)
                inside = unknown_mask & (lon_all >= min(w, e)) & (lon_all <= max(w, e)) \
                                   & (lat_all >= min(s, n)) & (lat_all <= max(s, n))
                assigned[inside] = ctry
            except Exception:
                pass
        countries_all = assigned

    # ---------- Per-country thresholds & OOF metrics ----------
    oof_df = pd.DataFrame({
        "country": countries_all.astype(str),
        "y": y_all.astype(int),
        "prob_cal": oof_cal.astype(float),
    })

    per_rows = []
    per_thr = {}

    for ctry in sorted(set(countries)):
        sub = oof_df[oof_df["country"] == ctry]
        n_total = int(len(sub))
        n_pos = int((sub["y"] == 1).sum())
        n_neg = n_total - n_pos

        if (n_total < MIN_TOTAL_FOR_COUNTRY) or (n_pos < MIN_POS_NEG_FOR_COUNTRY) or (n_neg < MIN_POS_NEG_FOR_COUNTRY):
            thr_c = float(global_thr); note = "fallback_global"
        else:
            thr_c = _pick_threshold(sub["y"].values, sub["prob_cal"].values); note = "oof_calibrated"

        per_thr[ctry] = float(thr_c)
        m_c, cm_c = _compute_basic_metrics(sub["y"].values, sub["prob_cal"].values, thr_c)
        m_c.update({
            "region": region_name, "country": ctry, "n_oof": n_total,
            "n_pos": n_pos, "n_neg": n_neg, "threshold_used": float(thr_c),
            "note": note
        })
        per_rows.append(m_c)

        # save country artifacts (metrics + confusion + threshold.json)
        cdir = get_or_create_folder(drive, c_base_root, ctry)
        with tempfile.TemporaryDirectory() as tdir:
            with open(os.path.join(tdir, "test_metrics.json"), "w") as f:
                json.dump(m_c, f, indent=2)
            upload_path(drive, os.path.join(tdir, "test_metrics.json"), cdir, "test_metrics.json")

            png = os.path.join(tdir, "test_confusion.png")
            _save_confusion_png(cm_c, f"{ctry} — OOF Confusion (thr={thr_c:.3f})", png)
            upload_path(drive, png, cdir, "test_confusion.png")

            with open(os.path.join(tdir, "threshold.json"), "w") as f:
                json.dump({
                    "threshold_strategy": THRESHOLD_STRATEGY,
                    "mode": THRESHOLD_MODE,
                    "calibration": CALIBRATION_METHOD,
                    "threshold": float(thr_c),
                    "n_total": n_total, "n_pos": n_pos, "n_neg": n_neg,
                    "note": note
                }, f, indent=2)
            upload_path(drive, os.path.join(tdir, "threshold.json"), cdir, "threshold.json")

    # regional CSV summary for per-country
    if len(per_rows) > 0:
        per_df = pd.DataFrame(per_rows).sort_values("country")
    else:
        per_df = pd.DataFrame(columns=["country","n_oof","precision","recall","f1","roc_auc","pr_auc",
                                       "tn","fp","fn","tp","threshold_used","note"])

    # ---------- Save regional artifacts ----------
    with tempfile.TemporaryDirectory() as tdir:
        with open(os.path.join(tdir, "metrics.json"), "w") as f:
            json.dump(reg_metrics, f, indent=2)
        upload_path(drive, os.path.join(tdir, "metrics.json"), reg_base, "metrics.json")

        fi_df.to_csv(os.path.join(tdir, "feature_importance.csv"), index=False)
        upload_path(drive, os.path.join(tdir, "feature_importance.csv"), reg_base, "feature_importance.csv")

        per_df.to_csv(os.path.join(tdir, "per_country_test_metrics.csv"), index=False)
        upload_path(drive, os.path.join(tdir, "per_country_test_metrics.csv"), reg_base, "per_country_test_metrics.csv")

        png = os.path.join(tdir, "confusion.png")
        _save_confusion_png(reg_cm, f"{region_name} — OOF Confusion ({THRESHOLD_STRATEGY}, thr={global_thr:.3f})", png)
        upload_path(drive, png, reg_base, "confusion.png")

        joblib.dump(
            {
                "pipeline": pipe,
                "predictors": use_preds,
                "positive_class_index": int(pos_idx),
                "calibration": calib,
                "global_threshold": float(global_thr),
                "threshold_strategy": THRESHOLD_STRATEGY,
                "per_country_thresholds": per_thr,
                "num_features": [f for f in use_preds if f != AEZ_COL],
                "cat_features": [AEZ_COL] if AEZ_COL in use_preds else [],
            },
            os.path.join(tdir, "model.joblib"),
        )
        upload_path(drive, os.path.join(tdir, "model.joblib"), reg_base, "model.joblib")

    model = {
        "pipe": pipe,
        "use_preds": use_preds,
        "pos_idx": int(pos_idx),
        "calib": calib,
        "global_thr": float(global_thr),
        "per_thr": per_thr,
        "reg_base": reg_base
    }
    return model


# ----------------- PREDICT ONE COUNTRY ------------------
def predict_country(drive, country, model, thr, grid_pack):
    out_dir = get_or_create_folder(drive, ROOT_FOLDER_ID, OUTPUT_FOLDER) if drive else LOCAL_OUTPUT_DIR
    os.makedirs(out_dir, exist_ok=True)

    pipe   = model["pipe"]
    pos_idx= model["pos_idx"]
    calib  = model["calib"]

    H, W, transform, crs, cropland = grid_pack
    raster_cache = {}

    # AEZ: prefer per-country tile, fallback to global AEZ
    aez_src_path = open_country_var_path(drive, country, AEZ_COL) or open_aez_path(drive)
    if AEZ_COL in model["use_preds"] and aez_src_path is not None:
        raster_cache[AEZ_COL] = rasterio.open(aez_src_path)

    for v in model["use_preds"]:
        if v == AEZ_COL: continue
        p = open_country_var_path(drive, country, v)
        if p is not None: raster_cache[v] = rasterio.open(p)

    base_profile = {
        "driver": "GTiff","height": H,"width": W,"count": 1,"crs": crs,"transform": transform,
        "compress": "LZW","tiled": True,"blockxsize": 512,"blockysize": 512,
    }
    prob_profile = {**base_profile, "dtype": "float32", "nodata": -9999.0}
    bin_profile  = {**base_profile, "dtype": "uint8", "nodata": 255}
    cnt_profile  = {**base_profile, "dtype": "uint8", "nodata": 0}

    prob_name = f"{country}_RF_probability_percent.tif"
    bin_name  = f"{country}_RF_binary_0_1_cropland.tif"
    cnt_name  = f"{country}_RF_predictors_count.tif"

    eligible_total = 0; positive_total = 0

    with tempfile.TemporaryDirectory() as tdir:
        prob_tmp = os.path.join(tdir, prob_name)
        bin_tmp  = os.path.join(tdir, bin_name)
        cnt_tmp  = os.path.join(tdir, cnt_name)
        with rasterio.open(prob_tmp, "w", **prob_profile) as dst_prob, \
             rasterio.open(bin_tmp,  "w", **bin_profile)  as dst_bin,  \
             rasterio.open(cnt_tmp,  "w", **cnt_profile) as dst_cnt:

            for win in tqdm(_iter_tiles(H, W, TILE_SIZE), total=_tile_count(H, W, TILE_SIZE), desc=f"Predict {country}"):
                r0, r1 = int(win.row_off), int(win.row_off + win.height)
                c0, c1 = int(win.col_off), int(win.col_off + win.width)
                tile_h, tile_w = (r1 - r0), (c1 - c0)
                tile_transform = transform * Affine.translation(c0, r0)

                crop_tile = cropland[r0:r1, c0:c1]
                if not crop_tile.any():
                    dst_prob.write(np.full((1, tile_h, tile_w), prob_profile["nodata"], np.float32), window=win)
                    dst_bin.write (np.full((1, tile_h, tile_w), bin_profile["nodata"], np.uint8), window=win)
                    dst_cnt.write (np.zeros((1, tile_h, tile_w), np.uint8), window=win)
                    continue

                stack = np.full((len(model["use_preds"]), tile_h, tile_w), np.nan, dtype=np.float32)
                valid_count = np.zeros((tile_h, tile_w), dtype=np.uint8)

                for j, v in enumerate(model["use_preds"]):
                    src = raster_cache.get(v)
                    if src is None:
                        p = open_country_var_path(drive, country, v) if v != AEZ_COL else (open_country_var_path(drive, country, AEZ_COL) or open_aez_path(drive))
                        if p: src = raster_cache[v] = rasterio.open(p)
                        else: continue
                    dst_arr = np.full((tile_h, tile_w), np.nan, dtype=np.float32)
                    dst_nd  = -9999.0
                    resamp  = PER_BAND_RESAMPLING.get(v, RESAMPLING)
                    try:
                        reproject(source=rasterio.band(src, 1), destination=dst_arr,
                                  src_transform=src.transform, src_crs=src.crs,
                                  dst_transform=tile_transform, dst_crs=crs,
                                  dst_nodata=dst_nd, resampling=resamp)
                    except Exception as e:
                        print(f"      ! reprojection failed for {v}: {e}"); continue
                    if src.nodata is not None:
                        dst_arr = np.where(np.isclose(dst_arr, src.nodata), np.nan, dst_arr)
                    dst_arr = np.where(np.isclose(dst_arr, dst_nd), np.nan, dst_arr)
                    stack[j] = dst_arr; valid_count += np.isfinite(dst_arr)

                req = (len(model["use_preds"]) if STRICT_ALL_BANDS else min(MIN_PREDICTORS, len(model["use_preds"])))
                good = crop_tile & (valid_count >= req)
                prob_tile = np.full((tile_h, tile_w), np.nan, dtype=np.float32)

                if good.any():
                    X = stack.reshape(len(model["use_preds"]), -1).T
                    good_flat = good.reshape(-1)
                    X_good = X[good_flat].astype("float64")
                    X_good_df = pd.DataFrame(X_good, columns=model["use_preds"])

                    # enforce dtypes
                    num_features = [f for f in model["use_preds"] if f != AEZ_COL]
                    for col in num_features:
                        X_good_df[col] = pd.to_numeric(X_good_df[col], errors="coerce").astype("float64")
                    if AEZ_COL in X_good_df.columns:
                        X_good_df[AEZ_COL] = pd.to_numeric(X_good_df[AEZ_COL], errors="coerce")
                        X_good_df[AEZ_COL] = np.rint(X_good_df[AEZ_COL]).astype("float64")

                    prob_raw = pipe.predict_proba(X_good_df)[:, pos_idx].astype(np.float32)
                    prob_vals = _apply_calibrator(calib, prob_raw).astype(np.float32)

                    bin_vals  = (prob_vals >= thr).astype(np.uint8)
                    prob_tile.reshape(-1)[good_flat] = prob_vals

                    mask_arr = np.isfinite(prob_tile)
                    prob_out = np.full((1, tile_h, tile_w), prob_profile["nodata"], np.float32)
                    prob_out[0][mask_arr] = (prob_tile[mask_arr] * 100.0).astype(np.float32)
                    dst_prob.write(prob_out, window=win)

                    bin_out = np.full((1, tile_h, tile_w), bin_profile["nodata"], np.uint8)
                    bin_out[0].reshape(-1)[good_flat] = bin_vals
                    dst_bin.write(bin_out, window=win)

                    eligible_total += int(mask_arr.sum()); positive_total += int(bin_vals.sum())
                else:
                    dst_prob.write(np.full((1, tile_h, tile_w), prob_profile["nodata"], np.float32), window=win)
                    dst_bin.write (np.full((1, tile_h, tile_w), bin_profile["nodata"], np.uint8), window=win)

                dst_cnt.write(valid_count[np.newaxis], window=win)

        upload_path(drive, prob_tmp, out_dir, prob_name)
        upload_path(drive, bin_tmp,  out_dir, bin_name)
        upload_path(drive, cnt_tmp,  out_dir, cnt_name)

    frac = (positive_total / max(1, eligible_total)) * 100.0
    print(f"   • predicted irrigated fraction inside cropland: {frac:.3f}% ({positive_total:,}/{eligible_total:,})")

# ----------------------------- RUNNER -------------------
def run_region_then_countries(use_drive=True, countries=None, region_name="US"):
    if use_drive:
        try:
            drive  # noqa: F821
            _drive = drive
            try: _drive.auth.service.http.timeout = 120
            except Exception: pass
        except NameError:
            raise RuntimeError("PyDrive 'drive' not found. Authenticate and expose 'drive', or set use_drive=False.")
    else:
        _drive = None
        os.makedirs(LOCAL_OUTPUT_DIR, exist_ok=True)

    df_all = load_points_df(_drive)
    for c in [TARGET, LONCOL, LATCOL] + [p for p in ALL_PREDICTORS if p in df_all.columns and p != AEZ_COL]:
        if c in df_all.columns:
            df_all[c] = pd.to_numeric(df_all[c], errors="coerce")
    df_all = df_all.dropna(subset=[TARGET, LONCOL, LATCOL]).reset_index(drop=True)
    df_all[TARGET] = df_all[TARGET].astype(int)

    if not countries:
        countries = [
            "Nigeria",
            # "Sudan","Chad", "Niger"
            # "Florida","Georgia","Arkansas","Kentucky","Louisiana",
            #          "Mississippi","South Carolina","North Carolina"
 
        ]

    model = train_region_model(_drive, region_name, df_all, countries)

    # Predict
    print(f"=== PREDICTION per-country (threshold mode: {THRESHOLD_MODE}) ===")
    for c in countries:
        try:
            H, W, transform, crs, cropland, _ = country_grid_and_mask(_drive, c)
            thr_c = model["per_thr"].get(c, model["global_thr"]) if THRESHOLD_MODE in ("country","hybrid") else model["global_thr"]
            predict_country(_drive, c, model, thr_c, (H, W, transform, crs, cropland))
            print(f"✓ {c}: maps → {OUTPUT_FOLDER}, model → {MODEL_FOLDER}/{REGIONAL_MODELS_FOLDER}/{region_name}, thr={thr_c:.3f}")
        except Exception as e:
            print(f"❌ {c}: {e}")

    print("✅ Done.")

if __name__ == "__main__":
    run_region_then_countries(use_drive=True, countries=None, region_name="Africa")


Using vector points (shp): Africa.shp
   • AEZ filled from cache for 265/15126 points
=== Train regional model: Africa ===
   • AEZ sampled from tiles/global for 14861 points


In [4]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Per-admin AEI-constrained binarization using ONLY the admin shapefile.

For each country probability TIFF:
  - Use unit_code polygons from AEI_2020_2_withAEI.* (AEI in HECTARES).
  - For each unit_code, pick pixels INSIDE (centroid-in) from highest probability
    downward until its AEI quota is met (quota = floor(AEI_ha*10_000 / 900) pixels).
  - Preserve NaNs from the probability input.
  - Write ONE binary 0/1 TIFF per country to:
      Drive → CountryModelPredicted/Probability/Binary/
  - Also write a CSV summary per country with thresholds and counts.

No AEI base raster is used.
"""

import os, re, math, csv, tempfile, warnings, unicodedata
from collections import defaultdict, Counter

import numpy as np
import rasterio
from rasterio.windows import Window
from rasterio.features import rasterize

import fiona
from shapely.geometry import shape, box
from shapely.ops import transform as shp_transform
from shapely.strtree import STRtree
from pyproj import Transformer

# ------------------------- CONFIG -------------------------
ROOT_FOLDER_ID        = os.environ.get("ROOT_FOLDER_ID", "1ReAmJ809FmU-8jLpyyO6VEdmeqDmbaB7")
PARENT_FOLDER_NAME    = "CountryModelPredicted"
PROB_SUBFOLDER_NAME   = "Probability"      # fuzzy-matched (also matches "Porbability")
NATIONAL_AEI_FOLDER   = "National AEI"     # shapefile lives here

# Probability & binning
PIXEL_AREA_M2         = 30.0 * 30.0        # 900 m² per 30 m pixel
TILE                  = 1024               # reduce if memory tight
SCALE                 = 1000               # probability bins: 0..1 → 0..1000

# Shapefile (AEI in HECTARES)
ADMIN_SHP_BASE        = "Copy of AEI_2020_2_with_AEI"  # base name (no extension)
ADMIN_CODE_COL        = "unit_code"
ADMIN_AEI_COLS        = ["AEI_2020", "AEI2020", "AEI"]  # values in HECTARES
ADMIN_CNTRY_COLS      = ["name_cntr", "name_cntr1", "name_admin", "ST_NM"]

# Deterministic tie-breaking
RNG_SEED              = int(os.environ.get("AEI_RNG_SEED", "0"))
_rng = np.random.default_rng(RNG_SEED)
# ----------------------------------------------------------

warnings.filterwarnings("ignore", category=RuntimeWarning)

# ---------------------- Drive helpers ---------------------
def _dq(drive, q):
    return drive.ListFile({
        "q": q,
        "supportsAllDrives": True,
        "includeItemsFromAllDrives": True,
        "maxResults": 1000
    }).GetList()

def list_files(drive, parent_id):
    return _dq(drive, f"'{parent_id}' in parents and trashed=false")

def child_folders(drive, parent_id):
    return _dq(drive, f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder'")

def get_subfolder_exact(drive, parent_id, name):
    res = _dq(drive, f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder' and title='{name}'")
    return res[0]["id"] if res else None

def get_subfolder_fuzzy(drive, parent_id, desired):
    eid = get_subfolder_exact(drive, parent_id, desired)
    if eid: return eid
    key = re.sub(r"\s+", "", desired.lower())
    for f in child_folders(drive, parent_id):
        tkey = re.sub(r"\s+", "", f.get("title","").lower())
        if key in tkey or ("prob" in key and "prob" in tkey):
            return f["id"]
    return None

def _resolve_field(props_sample: dict, candidates, required=False):
    """Fuzzy, case-insensitive resolver for attribute fields."""
    def canon(s): return re.sub(r"[^a-z0-9]+", "", str(s).lower())
    keys = list(props_sample.keys())
    norm_map = {canon(k): k for k in keys}
    for want in candidates:
        w = canon(want)
        if w in norm_map:
            return norm_map[w]
    want_roots = {canon(want) for want in candidates}
    for k in keys:
        ck = canon(k)
        if any(root in ck for root in want_roots):
            return k
    if required:
        raise RuntimeError(f"Required attribute not found. Looked for: {candidates}. Available: {keys}")
    return None

def download_to_temp(drive_file, suffix):
    p = tempfile.NamedTemporaryFile(delete=False, suffix=suffix).name
    drive_file.GetContentFile(p)
    return p

def get_or_create_folder(drive, parent_id, name):
    res = _dq(drive, f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder' and title='{name}'")
    if res: return res[0]["id"]
    nf = drive.CreateFile({"title": name, "parents":[{"id": parent_id}], "mimeType":"application/vnd.google-apps.folder"})
    nf.Upload()
    return nf["id"]

def upload_path(drive, local_path, parent_id, title=None):
    f = drive.CreateFile({"title": title or os.path.basename(local_path), "parents":[{"id": parent_id}]})
    f.SetContentFile(local_path)
    f.Upload()
    return f["id"]

# ---------------------- Name helpers ----------------------
def _norm(s): return re.sub(r"[^a-z0-9]+", "", str(s).lower())
def _canon(s):
    s = unicodedata.normalize("NFKD", str(s)).encode("ascii","ignore").decode()
    return re.sub(r"\s+","", s.lower())

def _extract_country_from_fname(fname):
    # Albania_RF_probability_percent.tif → Albania
    fn = re.sub(r"\s*\(.*\)\.tif(f)?$", ".tif", fname, flags=re.IGNORECASE)
    m = re.match(r"(.+?)_RF_probability_percent", fn, flags=re.IGNORECASE)
    if m: return m.group(1)
    return re.sub(r"\.tif(f)?$", "", fn, flags=re.IGNORECASE)

# ---------------------- Shapefile helpers -----------------
def _download_shapefile_bundle(drive, folder_id, base):
    """
    Download AEI_2020_2_withAEI.* into a single temp directory, ensuring
    all sidecars share the SAME basename so GDAL/Fiona can see attributes.
    Returns a dict of local paths keyed by extension ('.shp', '.dbf', etc).
    """
    exts = [".shp", ".shx", ".dbf", ".prj", ".cpg"]
    items = {}
    for it in list_files(drive, folder_id):
        t = it.get("title", "")
        for e in exts:
            if t.lower() == (base.lower() + e):
                items[e] = it

    if ".shp" not in items or ".dbf" not in items:
        raise FileNotFoundError(f"Missing pieces of {base} shapefile (need at least .shp and .dbf). Found: {sorted(items.keys())}")

    tmpdir = tempfile.mkdtemp(prefix="aei_admin_")
    out = {}
    for e, it in items.items():
        local_path = os.path.join(tmpdir, base + e)  # SAME BASENAME!
        it.GetContentFile(local_path)
        out[e] = local_path

    print(f"[AEI] Shapefile bundle at {tmpdir} → found {sorted(out.keys())}")
    return out


def _read_admins_for_raster(shp_path, raster_crs, raster_bounds):
    """
    Load admin polygons, reproject to raster CRS (no AEI filtering).
    - Reads field names from schema (not the first feature).
    - Honors .cpg encoding when present.
    - Skips null/malformed geometries defensively.
    """
    feats, attrs = [], []
    shp_dir = os.path.dirname(shp_path)
    base = os.path.splitext(os.path.basename(shp_path))[0]
    cpg_path = os.path.join(shp_dir, base + ".cpg")

    # Determine DBF encoding
    encoding = None
    if os.path.exists(cpg_path):
        try:
            with open(cpg_path, "r", encoding="ascii", errors="ignore") as f:
                enc_line = f.read().strip()
                if enc_line:
                    encoding = enc_line
        except Exception:
            pass

    def _open_fiona(enc):
        return fiona.open(shp_path, encoding=enc) if enc else fiona.open(shp_path)

    shp_crs_final = None
    with fiona.Env(SHAPE_RESTORE_SHX='YES'):
        # Try cpg encoding → utf-8 → latin1
        tried = [encoding, "utf-8", "latin1"]
        last_err = None
        for enc in tried:
            try:
                with _open_fiona(enc) as src:
                    shp_crs_local = src.crs_wkt or src.crs
                    props_schema = (src.schema or {}).get("properties", {})
                    field_names = list(props_schema.keys())
                    if not field_names:
                        raise RuntimeError("No attribute fields in schema (DBF not visible).")

                    # Resolve keys against SCHEMA (not a sample feature)
                    dummy_props = {k: None for k in field_names}
                    code_key  = _resolve_field(dummy_props, [ADMIN_CODE_COL], required=True)
                    aei_key   = _resolve_field(dummy_props, ADMIN_AEI_COLS, required=True)  # AEI in hectares
                    cntry_key = _resolve_field(dummy_props, ADMIN_CNTRY_COLS, required=False)

                    # Transform raster bounds to shapefile CRS for coarse prefilter
                    if shp_crs_local:
                        rb_to_shp = Transformer.from_crs(raster_crs, shp_crs_local, always_xy=True)
                        rb_shp = shp_transform(lambda x, y: rb_to_shp.transform(x, y), box(*raster_bounds))
                    else:
                        rb_shp = box(*raster_bounds)

                    # Iterate features
                    for rec in src:
                        gj = rec.get("geometry")
                        if gj is None:
                            continue  # null geometry → skip (defensive)
                        try:
                            g = shape(gj)
                        except Exception:
                            continue
                        if g.is_empty:
                            continue
                        if not g.intersects(rb_shp):
                            continue

                        props = rec.get("properties") or {}
                        try:
                            uc = int(props[code_key])
                            aei_ha = float(props[aei_key])  # hectares
                        except Exception:
                            continue

                        feats.append(g)
                        attrs.append({
                            "unit_code": uc,
                            "aei_ha": aei_ha,
                            "country": str(props.get(cntry_key, "")).strip() if cntry_key else ""
                        })
                # success → keep the CRS we used
                shp_crs_final = shp_crs_local
                break
            except Exception as e:
                last_err = e
                continue

        if last_err and not feats:
            raise RuntimeError(f"Failed to read attributes from shapefile. Tried encodings {tried}. Last error: {last_err}")

    if not feats:
        return [], [], None, {}

    # Reproject to raster CRS for rasterize
    shp_crs = shp_crs_final
    if not shp_crs:
        print("⚠️  Shapefile has no CRS (.prj missing). Assuming raster CRS.")
        shp_crs = raster_crs

    transformer = Transformer.from_crs(shp_crs, raster_crs, always_xy=True)
    geoms_ras = [shp_transform(lambda x, y: transformer.transform(x, y), g) for g in feats]

    tree_ras = STRtree(geoms_ras)
    # IMPORTANT: map by WKB (value identity), not id(...)
    g2i_wkb = {g.wkb: i for i, g in enumerate(geoms_ras)}

    return geoms_ras, attrs, tree_ras, g2i_wkb


# ---- STRtree helper: get candidate indices robustly (Shapely 2 or fallback) ----
def _tree_candidate_indices(tree_ras, tile_poly, geoms_ras, g2i_wkb):
    """
    Return list of indices of geoms that intersect tile_poly.
    Prefer Shapely 2's predicate indices; otherwise map WKBs.
    """
    # Fast path: Shapely 2 can return integer indices with predicate
    try:
        idx = tree_ras.query(tile_poly, predicate="intersects")
        if isinstance(idx, np.ndarray) and np.issubdtype(idx.dtype, np.integer):
            return idx.tolist()
    except TypeError:
        # Older shapely: predicate argument not supported
        pass

    # Fallback: geometry array → map to indices by WKB, then precise intersects
    cand = tree_ras.query(tile_poly)
    if isinstance(cand, np.ndarray):
        cand = cand.tolist()
    out = []
    for g in cand:
        i = g2i_wkb.get(g.wkb, None)
        if i is None:
            # last resort: linear search (rare)
            try:
                i = next(j for j, gg in enumerate(geoms_ras) if gg.equals(g))
            except StopIteration:
                continue
        if geoms_ras[i].intersects(tile_poly):
            out.append(i)
    return out


# ---------------------- Raster helpers -------------------
def _iter_tiles(H, W, tile=TILE):
    for r0 in range(0, H, tile):
        for c0 in range(0, W, tile):
            h = min(tile, H - r0)
            w = min(tile, W - c0)
            yield Window(c0, r0, w, h)

def _read_prob_tile(src, W):
    arr = src.read(1, window=W, out_dtype="float32", masked=True).filled(np.nan)
    finite = np.isfinite(arr)
    if finite.any() and float(np.nanmax(arr[finite])) > 1.5:
        arr[finite] /= 100.0
    np.clip(arr, 0.0, 1.0, out=arr, where=finite)
    return arr

def _tile_bounds(window, transform):
    left, top = transform * (window.col_off, window.row_off)
    right, bottom = transform * (window.col_off + window.width, window.row_off + window.height)
    x0, x1 = sorted([left, right])
    y0, y1 = sorted([bottom, top])
    return (x0, y0, x1, y1)

# ---------------------- Core algorithm -------------------
def aei_binarize_per_admin_from_shapefile(drive):
    """
    Main entry: uses ONLY the admin shapefile with AEI in hectares
    to allocate per-admin pixel quotas and write one binary per country.
    """
    # Locate folders
    cmp_id  = get_subfolder_fuzzy(drive, ROOT_FOLDER_ID, PARENT_FOLDER_NAME)
    if not cmp_id: raise RuntimeError(f"Folder '{PARENT_FOLDER_NAME}' not found under ROOT.")
    prob_id = get_subfolder_fuzzy(drive, cmp_id, PROB_SUBFOLDER_NAME)
    if not prob_id: raise RuntimeError("Probability folder not found (tried fuzzy match).")
    binary_id = get_or_create_folder(drive, prob_id, "Binary")

    aei_folder_id = get_subfolder_fuzzy(drive, ROOT_FOLDER_ID, NATIONAL_AEI_FOLDER)
    if not aei_folder_id: raise RuntimeError("National AEI folder not found at ROOT.")

    # Download admin shapefile bundle
    shp_paths = _download_shapefile_bundle(drive, aei_folder_id, ADMIN_SHP_BASE)
    shp_path  = shp_paths[".shp"]

    # List probability TIFFs
    files = [it for it in list_files(drive, prob_id)
             if isinstance(it, dict)
             and it.get("mimeType") != "application/vnd.google-apps.folder"
             and it.get("title","").lower().endswith((".tif",".tiff"))]
    if not files:
        raise FileNotFoundError("No probability TIFFs in Probability folder.")

    for it in files:
        title   = it.get("title","")
        country = _extract_country_from_fname(title)
        print(f"\n=== {country} (per-admin from shapefile; AEI in hectares) ===")

        rtmp = download_to_temp(it, ".tif")
        with rasterio.open(rtmp) as src:
            H, W = src.height, src.width
            ras_crs = src.crs
            rb = src.bounds
            ras_bounds = (rb.left, rb.bottom, rb.right, rb.top)

            # Read & subset admins, reproject to raster CRS
            geoms_ras, attrs, tree_ras, g2i_wkb = _read_admins_for_raster(shp_path, ras_crs, ras_bounds)
            if not geoms_ras:
                print("  ⚠️  No admin polygons intersect this raster; skipping.")
                continue

            # Optional filter by country name if present (keeps all if missing)
            want = _canon(country)
            keep = [i for i,a in enumerate(attrs) if (not a["country"]) or _canon(a["country"]) == want]
            if keep and len(keep) < len(attrs):
                geoms_ras = [geoms_ras[i] for i in keep]
                attrs     = [attrs[i] for i in keep]
                tree_ras  = STRtree(geoms_ras)
                g2i_wkb   = {g.wkb: i for i, g in enumerate(geoms_ras)}

            # Targets per admin (AEI in HECTARES → m² → pixels); include zeros
            K_map = {}
            for a in attrs:
                aei_m2 = a["aei_ha"] * 10_000.0
                K_map[a["unit_code"]] = int(math.floor(aei_m2 / PIXEL_AREA_M2))  # may be 0

            # PASS 1: per-admin histograms of probability bins (centroid-in)
            hists = defaultdict(Counter)
            for w in _iter_tiles(H, W, TILE):
                prob = _read_prob_tile(src, w)
                valid = np.isfinite(prob)
                if not valid.any():
                    continue

                tile_t = rasterio.windows.transform(w, src.transform)
                tb = _tile_bounds(w, src.transform)
                tile_poly = box(*tb)

                idxs = _tree_candidate_indices(tree_ras, tile_poly, geoms_ras, g2i_wkb)
                if len(idxs) == 0:
                    continue

                shapes = [(geoms_ras[i], attrs[i]["unit_code"]) for i in idxs]
                labels = rasterize(
                    shapes=shapes,
                    out_shape=prob.shape,
                    transform=tile_t,
                    fill=0, dtype="int64",
                    all_touched=False  # centroid-in
                )
                m = valid & (labels != 0)
                if not m.any():
                    continue

                p_int = np.zeros(prob.shape, dtype=np.int32)
                p_int_valid = np.rint(prob[m] * SCALE).astype(np.int32)
                p_int[m] = p_int_valid

                uc = labels[m].ravel()
                pi = p_int[m].ravel()
                for u in np.unique(uc):
                    sel = (uc == u)
                    bc = np.bincount(pi[sel], minlength=SCALE+1)
                    nz = np.nonzero(bc)[0]
                    for b, v in zip(nz, bc[nz]):
                        hists[u][int(b)] += int(v)

            # thresholds per admin (quota 0 => thr=-1, no selection)
            thr_map, need_eq_map = {}, {}
            for u, K in K_map.items():
                total = sum(hists[u].values())
                if K <= 0 or total == 0:
                    thr_map[u] = -1
                    need_eq_map[u] = 0
                    continue
                K = min(K, total)
                cum = 0; gt = 0
                for b in range(SCALE, -1, -1):
                    cnt = int(hists[u].get(b, 0))
                    if cum + cnt >= K:
                        thr_map[u] = b
                        need_eq_map[u] = K - gt
                        break
                    cum += cnt; gt += cnt
            need_eq_left = dict(need_eq_map)

            # PASS 2: write binary output (NaN preserved; default 0; set 1s per admin)
            out_profile = src.profile.copy()
            out_profile.update(
                driver="GTiff",
                height=H, width=W,
                transform=src.transform,
                count=1, dtype="float32", nodata=np.nan,
                compress="LZW", tiled=True, blockxsize=512, blockysize=512,
                BIGTIFF="IF_NEEDED"
            )
            with tempfile.NamedTemporaryFile(delete=False, suffix=".tif") as tmp_out:
                out_local = tmp_out.name

            with rasterio.open(out_local, "w", **out_profile) as dst:
                for w in _iter_tiles(H, W, TILE):
                    prob = _read_prob_tile(src, w)
                    out_tile = np.full(prob.shape, np.nan, dtype=np.float32)

                    valid = np.isfinite(prob)
                    if not valid.any():
                        dst.write(out_tile, 1, window=w)
                        continue

                    tile_t = rasterio.windows.transform(w, src.transform)
                    tb = _tile_bounds(w, src.transform)
                    tile_poly = box(*tb)

                    idxs = _tree_candidate_indices(tree_ras, tile_poly, geoms_ras, g2i_wkb)
                    if len(idxs) == 0:
                        dst.write(out_tile, 1, window=w)
                        continue

                    shapes = [(geoms_ras[i], attrs[i]["unit_code"]) for i in idxs]
                    labels = rasterize(
                        shapes=shapes,
                        out_shape=prob.shape,
                        transform=tile_t,
                        fill=0, dtype="int64",
                        all_touched=False
                    )

                    out_tile[valid] = 0.0  # default: valid-but-not-selected = 0
                    m_all = valid & (labels != 0)
                    if m_all.any():
                        p_int = np.zeros(prob.shape, dtype=np.int32)
                        p_int_valid = np.rint(prob[m_all] * SCALE).astype(np.int32)
                        p_int[m_all] = p_int_valid

                        present = np.unique(labels[m_all])
                        present = [u for u in present if u != 0]  # all units allowed
                        for u in present:
                            u_m = m_all & (labels == u)
                            if not u_m.any():
                                continue
                            t = thr_map.get(u, -1)
                            if t < 0:
                                continue  # quota 0 or no pixels -> stays 0
                            gt_m = u_m & (p_int > t)
                            out_tile[gt_m] = 1.0
                            need = need_eq_left.get(u, 0)
                            if need > 0:
                                eq_m = u_m & (p_int == t) & (out_tile != 1.0)
                                if eq_m.any():
                                    idx = np.flatnonzero(eq_m.ravel())
                                    _rng.shuffle(idx)
                                    take = min(need, idx.size)
                                    sel = idx[:take]
                                    rr, cc = np.unravel_index(sel, eq_m.shape)
                                    out_tile[rr, cc] = 1.0
                                    need_eq_left[u] = need - int(take)

                    dst.write(out_tile, 1, window=w)

            out_name = f"{country}_AEI_binary_0_1.tif"
            upload_path(drive, out_local, binary_id, title=out_name)
            try: os.remove(out_local)
            except: pass

            # CSV summary per admin (includes AEI==0 units)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode="w", newline="") as tmpcsv:
                wcsv = csv.writer(tmpcsv)
                wcsv.writerow(["unit_code","aei_ha","target_pixels","thr_bin","scale","selected_pixels"])
                for a in attrs:
                    u = a["unit_code"]
                    aei_ha = a["aei_ha"]
                    K = int(math.floor((aei_ha * 10_000.0) / PIXEL_AREA_M2))
                    th = int(thr_map.get(u, -1))
                    gt = sum(v for b, v in (hists[u].items() if u in hists else []) if b > th)
                    ties_taken = (need_eq_map.get(u, 0) - need_eq_left.get(u, 0))
                    sel = int(gt + max(0, ties_taken))
                    wcsv.writerow([u, aei_ha, int(K), th, SCALE, sel])
                csv_path = tmpcsv.name
            upload_path(drive, csv_path, binary_id, title=f"{country}_AEI_admin_summary.csv")
            try: os.remove(csv_path)
            except: pass

            print(f"   ✓ Wrote Binary/{out_name}")

        try: os.remove(rtmp)
        except: pass

    print("\n✅ Done (per-admin from shapefile; AEI in hectares; NaNs preserved).")

# ---------------------- CLI ----------------------
if __name__ == "__main__":
    try:
        drive  # noqa: F821
    except NameError:
        raise RuntimeError("PyDrive2 'drive' not found. Authenticate and expose a global `drive` before running.")
    aei_binarize_per_admin_from_shapefile(drive)


[AEI] Shapefile bundle at /tmp/aei_admin_v75mbg_n → found ['.cpg', '.dbf', '.prj', '.shp', '.shx']

=== Nigeria (per-admin from shapefile; AEI in hectares) ===
   ✓ Wrote Binary/Nigeria_AEI_binary_0_1.tif

✅ Done (per-admin from shapefile; AEI in hectares; NaNs preserved).


In [7]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Post-processing for AEI per-admin binary masks.

Step 2 after running `aei_binarize_per_admin_from_shapefile`:

  - Take each *_AEI_binary_0_1.tif in
        ROOT / CountryModelPredicted_Cropland / Probability / Binary
  - Apply a window-based majority filter (on the 0/1 binary mask)
    in a streaming / tile-based fashion.
  - Preserve NaNs from the input.
  - Write smoothed 0/1 binary TIFFs to:
        ROOT / CountryModelPredicted_Cropland / Probability / Binary_MAJ

Notes
-----
- This does NOT change the original AEI thresholding step; it only smooths
  the resulting binary maps, so exact per-admin AEI may change slightly.
- Requires SciPy for `scipy.ndimage.uniform_filter`.
"""

import os
import re
import tempfile
import warnings

import numpy as np
import rasterio
from rasterio.windows import Window
from scipy.ndimage import uniform_filter   # pip install scipy

# ------------------------- CONFIG -------------------------
ROOT_FOLDER_ID      = os.environ.get("ROOT_FOLDER_ID", "1ReAmJ809FmU-8jLpyyO6VEdmeqDmbaB7")
PARENT_FOLDER_NAME  = "CountryModelPredicted"
PROB_SUBFOLDER_NAME = "Probability"   # fuzzy-matched
BINARY_FOLDER_NAME  = "Binary"        # input binaries from step 1
SMOOTH_FOLDER_NAME  = "Binary_MAJ_West_Africa"    # output smoothed binaries

TILE                = 1024            # I/O tile size
KERNEL_PX           = 9               # majority window (odd, in pixels, e.g. 9 ~ 270 m at 30 m res)
MAJ_THR             = 0.50            # majority threshold (>= 50% neighbors == 1)

warnings.filterwarnings("ignore", category=RuntimeWarning)

# ---------------------- Drive helpers ---------------------
def _dq(drive, q):
    return drive.ListFile({
        "q": q,
        "supportsAllDrives": True,
        "includeItemsFromAllDrives": True,
        "maxResults": 1000
    }).GetList()

def list_files(drive, parent_id):
    return _dq(drive, f"'{parent_id}' in parents and trashed=false")

def child_folders(drive, parent_id):
    return _dq(drive, f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder'")

def get_subfolder_exact(drive, parent_id, name):
    res = _dq(drive, f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder' and title='{name}'")
    return res[0]["id"] if res else None

def get_subfolder_fuzzy(drive, parent_id, desired):
    eid = get_subfolder_exact(drive, parent_id, desired)
    if eid:
        return eid
    key = re.sub(r"\s+", "", desired.lower())
    for f in child_folders(drive, parent_id):
        tkey = re.sub(r"\s+", "", f.get("title", "").lower())
        if key in tkey or ("prob" in key and "prob" in tkey):
            return f["id"]
    return None

def get_or_create_folder(drive, parent_id, name):
    res = _dq(drive, f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder' and title='{name}'")
    if res:
        return res[0]["id"]
    nf = drive.CreateFile({"title": name, "parents":[{"id": parent_id}], "mimeType":"application/vnd.google-apps.folder"})
    nf.Upload()
    return nf["id"]

def download_to_temp(drive_file, suffix):
    p = tempfile.NamedTemporaryFile(delete=False, suffix=suffix).name
    drive_file.GetContentFile(p)
    return p

def upload_path(drive, local_path, parent_id, title=None):
    f = drive.CreateFile({"title": title or os.path.basename(local_path), "parents":[{"id": parent_id}]})
    f.SetContentFile(local_path)
    f.Upload()
    return f["id"]

# ---------------------- Raster helpers -------------------
def _iter_tiles(H, W, tile=TILE):
    for r0 in range(0, H, tile):
        for c0 in range(0, W, tile):
            h = min(tile, H - r0)
            w = min(tile, W - c0)
            yield Window(c0, r0, w, h)

# ---------------------- Majority smoothing ---------------

def _smooth_binary_stream(src_path, dst_path, kernel_px=KERNEL_PX, maj_thr=MAJ_THR):
    """
    Read a float32 0/1/NaN binary raster in tiles and write a smoothed
    0/1/NaN raster using a window-based majority filter.

    - Input: float32, nodata = NaN, values 0 or 1 elsewhere.
    - Output: float32, nodata = NaN, values 0 or 1 elsewhere.
    """
    r = kernel_px // 2
    win_area = float(kernel_px * kernel_px)

    with rasterio.open(src_path) as src:
        profile = src.profile.copy()
        # keep same geo/tiling; ensure float32 + NaN nodata
        profile.update(
            dtype="float32",
            count=1,
            nodata=np.nan,
            compress="LZW",
            tiled=True,
            blockxsize=512,
            blockysize=512,
            BIGTIFF="IF_NEEDED"
        )

        with rasterio.open(dst_path, "w", **profile) as dst:
            for _, w in src.block_windows(1):
                # expand window by halo for neighborhood support
                r0 = max(0, w.row_off - r)
                c0 = max(0, w.col_off - r)
                r1 = min(src.height, w.row_off + w.height + r)
                c1 = min(src.width,  w.col_off + w.width  + r)
                if (c1 - c0) <= 0 or (r1 - r0) <= 0:
                    # nothing there
                    dst.write(np.full((w.height, w.width), np.nan, dtype=np.float32), 1, window=w)
                    continue
                win_pad = Window(c0, r0, c1 - c0, r1 - r0)

                a = src.read(1, window=win_pad, masked=True)
                data = a.data
                # valid where not masked AND finite
                valid = (~a.mask) & np.isfinite(data)
                if not valid.any():
                    dst.write(np.full((w.height, w.width), np.nan, dtype=np.float32), 1, window=w)
                    continue

                # current binary: anything > 0.5 treated as 1
                base = valid & (data > 0.5)
                base_f = base.astype(np.float32)
                valf   = valid.astype(np.float32)

                # neighborhood counts via uniform_filter
                sum_ones  = uniform_filter(base_f, size=kernel_px, mode="constant", cval=0.0) * win_area
                cnt_valid = uniform_filter(valf,   size=kernel_px, mode="constant", cval=0.0) * win_area
                frac = np.divide(sum_ones, cnt_valid, out=np.zeros_like(sum_ones), where=(cnt_valid > 0))

                # majority decision (only where we have valid neighbors)
                smoothed = np.zeros_like(base, dtype=bool)
                has_nb = (cnt_valid > 0)
                smoothed[has_nb] = frac[has_nb] >= maj_thr

                # crop back to original tile window
                rs = w.row_off - r0
                cs = w.col_off - c0
                re = rs + w.height
                ce = cs + w.width

                valid_core    = valid[rs:re, cs:ce]
                smooth_core   = smoothed[rs:re, cs:ce]

                out_block = np.full((w.height, w.width), np.nan, dtype=np.float32)
                out_block[valid_core & smooth_core]  = 1.0
                out_block[valid_core & ~smooth_core] = 0.0

                dst.write(out_block, 1, window=w)

# ---------------------- Main driver ----------------------

def smooth_aei_binary_maps(drive):
    """
    Entry point: locate Binary folder, smooth each AEI binary map,
    and write to Binary_MAJ sibling folder.
    """
    # Locate folders
    cmp_id = get_subfolder_fuzzy(drive, ROOT_FOLDER_ID, PARENT_FOLDER_NAME)
    if not cmp_id:
        raise RuntimeError(f"Folder '{PARENT_FOLDER_NAME}' not found under ROOT.")

    prob_id = get_subfolder_fuzzy(drive, cmp_id, PROB_SUBFOLDER_NAME)
    if not prob_id:
        raise RuntimeError("Probability folder not found (tried fuzzy match).")

    binary_id = get_subfolder_fuzzy(drive, prob_id, BINARY_FOLDER_NAME)
    if not binary_id:
        raise RuntimeError("Binary folder (with AEI binaries) not found under Probability.")

    smooth_id = get_or_create_folder(drive, prob_id, SMOOTH_FOLDER_NAME)

    # List binary TIFFs
    files = [
        it for it in list_files(drive, binary_id)
        if isinstance(it, dict)
        and it.get("mimeType") != "application/vnd.google-apps.folder"
        and it.get("title", "").lower().endswith((".tif", ".tiff"))
    ]
    if not files:
        raise FileNotFoundError("No binary TIFFs in Binary folder.")

    print(f"Found {len(files)} binary rasters in Binary/")

    for it in files:
        title = it.get("title", "")
        base  = re.sub(r"\.tif(f)?$", "", title, flags=re.IGNORECASE)
        out_name = f"{base}_MAJ_k{KERNEL_PX}_p{int(MAJ_THR*100)}.tif"

        print(f"\n=== Smoothing {title} → {out_name} ===")
        tmp_in  = download_to_temp(it, ".tif")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".tif") as tmp_out:
            out_local = tmp_out.name

        _smooth_binary_stream(tmp_in, out_local, kernel_px=KERNEL_PX, maj_thr=MAJ_THR)
        upload_path(drive, out_local, smooth_id, title=out_name)

        # cleanup
        try:
            os.remove(tmp_in)
        except Exception:
            pass
        try:
            os.remove(out_local)
        except Exception:
            pass

        print(f"   ✓ Wrote {SMOOTH_FOLDER_NAME}/{out_name}")

    print("\n✅ Done smoothing all AEI binary maps (majority filter).")

# ---------------------- CLI ----------------------
if __name__ == "__main__":
    try:
        drive  # noqa: F821
    except NameError:
        raise RuntimeError("PyDrive2 'drive' not found. Authenticate and expose a global `drive` before running.")
    smooth_aei_binary_maps(drive)


Found 1 binary rasters in Binary/

=== Smoothing Nigeria_AEI_binary_0_1.tif → Nigeria_AEI_binary_0_1_MAJ_k9_p50.tif ===
   ✓ Wrote Binary_MAJ_West_Africa/Nigeria_AEI_binary_0_1_MAJ_k9_p50.tif

✅ Done smoothing all AEI binary maps (majority filter).


In [8]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Merge all per-country AEI binary rasters in Drive → CountryModelPredicted/Probability/Binary/
into a single global binary mosaic (union = max), preserving NaNs.

Output: Merged_AEI_binary_0_1.tif in the same Binary folder.
"""

import os, re, tempfile, warnings
import numpy as np
import rasterio
from rasterio.merge import merge as rio_merge
from rasterio.enums import Resampling
ROOT_FOLDER_ID      = os.environ.get("ROOT_FOLDER_ID", "11ReAmJ809FmU-8jLpyyO6VEdmeqDmbaB7")
# PARENT_FOLDER_NAME  = "CountryModelPredicted"
# PROB_SUBFOLDER_NAME = "Probability"   # fuzzy-matched
# BINARY_FOLDER_NAME  = "Binary"        # input binaries from step 1
# SMOOTH_FOLDER_NAME  = "Binary_MAJ_West_Africa
------------------------- CONFIG -------------------------
ROOT_FOLDER_ID      = os.environ.get("ROOT_FOLDER_ID", "161LpFY8dIGXWrKO6dVNDVQaHDbugYvJA")
PARENT_FOLDER_NAME  = "CountryModelPredicted"
PROB_SUBFOLDER_NAME = "Probability"   # fuzzy-matched (also matches misspellings)
BINARY_SUBFOLDER    = "Binary_MAJ_West_Africa"
OUT_NAME            = "West_Africa_binary.tif"
# ----------------------------------------------------------

warnings.filterwarnings("ignore", category=RuntimeWarning)

# ---------------------- Drive helpers ---------------------
def _dq(drive, q):
    return drive.ListFile({
        "q": q,
        "supportsAllDrives": True,
        "includeItemsFromAllDrives": True,
        "maxResults": 1000
    }).GetList()

def list_files(drive, parent_id):
    return _dq(drive, f"'{parent_id}' in parents and trashed=false")

def child_folders(drive, parent_id):
    return _dq(drive, f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder'")

def get_subfolder_exact(drive, parent_id, name):
    res = _dq(drive, f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder' and title='{name}'")
    return res[0]["id"] if res else None

def get_subfolder_fuzzy(drive, parent_id, desired):
    eid = get_subfolder_exact(drive, parent_id, desired)
    if eid: return eid
    key = re.sub(r"\s+", "", desired.lower())
    for f in child_folders(drive, parent_id):
        tkey = re.sub(r"\s+", "", f.get("title","").lower())
        if key in tkey or ("prob" in key and "prob" in tkey):
            return f["id"]
    return None

def get_or_create_folder(drive, parent_id, name):
    res = _dq(drive, f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder' and title='{name}'")
    if res: return res[0]["id"]
    nf = drive.CreateFile({"title": name, "parents":[{"id": parent_id}], "mimeType":"application/vnd.google-apps.folder"})
    nf.Upload()
    return nf["id"]

def download_to_temp(drive_file, suffix):
    p = tempfile.NamedTemporaryFile(delete=False, suffix=suffix).name
    drive_file.GetContentFile(p)
    return p

def upload_path(drive, local_path, parent_id, title=None):
    f = drive.CreateFile({"title": title or os.path.basename(local_path), "parents":[{"id": parent_id}]})
    f.SetContentFile(local_path)
    f.Upload()
    return f["id"]

# ---------------------- Merge helpers ---------------------
def _try_merge_with_max(srcs):
    """Preferred path (newer rasterio): method='max', nodata=np.nan."""
    return rio_merge(
        srcs,
        nodata=np.nan,
        dtype="float32",
        precision=7,
        resampling=Resampling.nearest,
        method="max",
    )

def _try_merge_basic(srcs):
    """Older rasterio: no 'method'. We'll nanmax ourselves after merging."""
    mosaic, out_transform = rio_merge(
        srcs,
        nodata=np.nan,
        dtype="float32",
        precision=7,
        resampling=Resampling.nearest,
    )
    # emulate union across sources (elementwise maximum, ignoring NaNs)
    mosaic = np.nanmax(mosaic, axis=0, keepdims=True).astype("float32")
    return mosaic, out_transform

def _try_merge_basic_sentinel(srcs, sentinel=-9999.0):
    """Very old rasterio: nodata cannot be NaN. Use sentinel then convert and nanmax."""
    mosaic, out_transform = rio_merge(
        srcs,
        nodata=sentinel,
        dtype="float32",
        precision=7,
        resampling=Resampling.nearest,
    )
    # convert sentinel to NaN
    mosaic = mosaic.astype("float32", copy=False)
    mosaic[mosaic == sentinel] = np.nan
    mosaic = np.nanmax(mosaic, axis=0, keepdims=True).astype("float32")
    return mosaic, out_transform

# ---------------------- Main merge ------------------------
def merge_all_binary_rasters(drive):
    # Locate folders
    cmp_id     = get_subfolder_fuzzy(drive, ROOT_FOLDER_ID, PARENT_FOLDER_NAME)
    if not cmp_id: raise RuntimeError(f"Folder '{PARENT_FOLDER_NAME}' not found under ROOT.")
    prob_id    = get_subfolder_fuzzy(drive, cmp_id, PROB_SUBFOLDER_NAME)
    if not prob_id: raise RuntimeError("Probability folder not found (tried fuzzy match).")
    binary_id  = get_subfolder_fuzzy(drive, prob_id, BINARY_SUBFOLDER)
    if not binary_id:
        binary_id = get_or_create_folder(drive, prob_id, BINARY_SUBFOLDER)

    # Find all binary GeoTIFFs (skip the mosaic itself if re-running)
    tifs = [it for it in list_files(drive, binary_id)
            if isinstance(it, dict)
            and it.get("mimeType") != "application/vnd.google-apps.folder"
            and it.get("title","").lower().endswith((".tif",".tiff"))
            and OUT_NAME.lower() not in it.get("title","").lower()]
    if not tifs:
        raise FileNotFoundError("No binary .tif files found in Binary/.")

    # Download & open datasets
    local_paths, srcs = [], []
    try:
        for it in tifs:
            p = download_to_temp(it, ".tif")
            local_paths.append(p)
            srcs.append(rasterio.open(p))

        # Try modern merge with 'method=max' → else fallback strategies
        try:
            mosaic, out_transform = _try_merge_with_max(srcs)
        except TypeError:
            # 'method' not supported
            try:
                mosaic, out_transform = _try_merge_basic(srcs)
            except Exception:
                mosaic, out_transform = _try_merge_basic_sentinel(srcs)
        except Exception:
            # Any other unexpected error → robust fallback
            try:
                mosaic, out_transform = _try_merge_basic(srcs)
            except Exception:
                mosaic, out_transform = _try_merge_basic_sentinel(srcs)

        # Build output profile from first raster
        ref = srcs[0]
        out_profile = ref.profile.copy()
        out_profile.update(
            driver="GTiff",
            height=mosaic.shape[1],
            width=mosaic.shape[2],
            transform=out_transform,
            count=1,
            dtype="float32",
            nodata=np.nan,
            compress="LZW",
            tiled=True,
            blockxsize=512,
            blockysize=512,
            BIGTIFF="IF_NEEDED"
        )

        # Write to temp, then upload
        with tempfile.NamedTemporaryFile(delete=False, suffix=".tif") as tmp_out:
            out_local = tmp_out.name

        with rasterio.open(out_local, "w", **out_profile) as dst:
            dst.write(mosaic[0], 1)

        upload_path(drive, out_local, binary_id, title=OUT_NAME)
        try: os.remove(out_local)
        except: pass

        print(f"✅ Merged {len(srcs)} rasters → Binary/{OUT_NAME}")

    finally:
        # Cleanup
        for s in srcs:
            try: s.close()
            except: pass
        for p in local_paths:
            try: os.remove(p)
            except: pass

# ---------------------- CLI ----------------------
if __name__ == "__main__":
    try:
        drive  # noqa: F821
    except NameError:
        raise RuntimeError("PyDrive2 'drive' not found. Authenticate and expose a global `drive` before running.")
    merge_all_binary_rasters(drive)


SyntaxError: invalid syntax (1380171659.py, line 20)

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Plot US_binary.tif with:
  - Top panel: country-wide map (ADM0 boundary + irrigated in dark blue).
  - Bottom panel: zoomed inset, clipping US_binary to US_small shapefile.

Files (Google Drive layout):
  CountryModelPredicted_Cropland /
    Probability|Probablity|Prob /
      Merged /
        US_binary.tif
        US_small.(shp,shx,dbf,prj,cpg)
        Comparisons /
          US.(shp,shx,dbf,prj,cpg)
          --> output PNG will be saved here as:
              US_binary_with_zoom_map.png

Requirements:
  - PyDrive2-authenticated global `drive`
  - rasterio, numpy, matplotlib, fiona, shapely, pyproj
"""

import os
import re
import tempfile
import warnings

import numpy as np
import rasterio
from rasterio.enums import Resampling
from rasterio.mask import mask

import fiona
from shapely.geometry import shape
from shapely.ops import transform as shp_transform
from shapely.ops import unary_union
from pyproj import Transformer

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

warnings.filterwarnings("ignore", category=RuntimeWarning)

# ===================== CONFIG =====================
ROOT_FOLDER_ID      = os.environ.get("ROOT_FOLDER_ID", "18pQKnMMnLramhHRZSNwUJrLqG5DXNMmS")
PARENT_FOLDER_NAME  = "CountryModelPredicted"
PROB_NAMES          = ("Probability", "Probablity", "Prob")
BINARY_SUBFOLDER    = "Merged"
COMPARE_SUBFOLDER   = "Comparisons"

CHINA_BINARY_NAME   = "India_binary.tif"
SMALL_BASE          = "India_small"
CHN_ADM0_BASE       = "gadm41_IND_0"


PLOT_DPI            = 160
# ==================================================


# ---------------------- Drive helpers ---------------------
def _dq(drive, q):
    return drive.ListFile({
        "q": q,
        "supportsAllDrives": True,
        "includeItemsFromAllDrives": True,
        "maxResults": 1000
    }).GetList()

def list_files(drive, parent_id):
    return _dq(drive, f"'{parent_id}' in parents and trashed=false")

def child_folders(drive, parent_id):
    return _dq(
        drive,
        f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder'"
    )

def get_subfolder_exact(drive, parent_id, name):
    res = _dq(
        drive,
        f"'{parent_id}' in parents and trashed=false "
        f"and mimeType='application/vnd.google-apps.folder' and title='{name}'"
    )
    return res[0]["id"] if res else None

def get_subfolder_fuzzy(drive, parent_id, desired):
    """Simple fuzzy finder for folders (handles 'Probability' typo)."""
    eid = get_subfolder_exact(drive, parent_id, desired)
    if eid:
        return eid
    key = re.sub(r"\s+", "", desired.lower())
    for f in child_folders(drive, parent_id):
        t = (f.get("title") or f.get("name") or "").lower()
        tkey = re.sub(r"\s+", "", t)
        if key in tkey or ("prob" in key and "prob" in tkey):
            return f["id"]
    return None

def download_to_temp(drive_file, local_path):
    drive_file.GetContentFile(local_path)

def upload_path(drive, local_path, parent_id, title=None):
    f = drive.CreateFile(
        {"title": title or os.path.basename(local_path),
         "parents":[{"id": parent_id}]}
    )
    f.SetContentFile(local_path)
    f.Upload()
    return f["id"]

def download_shapefile_bundle(drive, folder_id, base):
    """
    Download base.(shp,shx,dbf,prj,cpg) from a folder into a temp dir.
    Returns path to the local .shp.
    """
    exts = [".shp", ".shx", ".dbf", ".prj", ".cpg"]
    items = {}
    for it in list_files(drive, folder_id):
        title = (it.get("title") or "").lower()
        for e in exts:
            if title == (base.lower() + e):
                items[e] = it

    if ".shp" not in items:
        raise FileNotFoundError(f"Could not find {base}.shp in folder.")

    tmpdir = tempfile.mkdtemp(prefix=base + "_")
    for e, it in items.items():
        lp = os.path.join(tmpdir, base + e)
        download_to_temp(it, lp)
    return os.path.join(tmpdir, base + ".shp")


# --------------------- Main plotting logic --------------------
def plot_china_binary_with_zoom(drive):
    # 1) Navigate to folders
    root_id = ROOT_FOLDER_ID
    cmp_id = get_subfolder_fuzzy(drive, root_id, PARENT_FOLDER_NAME)
    if not cmp_id:
        raise RuntimeError(f"Folder '{PARENT_FOLDER_NAME}' not found under ROOT.")

    prob_id = None
    for nm in PROB_NAMES:
        prob_id = get_subfolder_fuzzy(drive, cmp_id, nm)
        if prob_id:
            break
    if not prob_id:
        raise RuntimeError("Probability/Probablity/Prob folder not found.")

    binary_id = get_subfolder_fuzzy(drive, prob_id, BINARY_SUBFOLDER)
    if not binary_id:
        raise RuntimeError(f"Binary folder '{BINARY_SUBFOLDER}' not found.")

    comp_id = get_subfolder_fuzzy(drive, binary_id, COMPARE_SUBFOLDER)
    if not comp_id:
        raise RuntimeError(f"Comparisons folder '{COMPARE_SUBFOLDER}' not found inside {BINARY_SUBFOLDER}/.")

    # 2) Find US_binary.tif and small-area shapefile in Merged
    china_file = None
    small_shp_present = False
    for it in list_files(drive, binary_id):
        title = (it.get("title") or "").strip().lower()
        if title == CHINA_BINARY_NAME.lower():
            china_file = it
        if title.startswith(SMALL_BASE.lower()) and title.endswith(".shp"):
            small_shp_present = True
    if not china_file:
        raise FileNotFoundError(f"{CHINA_BINARY_NAME} not found in Merged/.")
    if not small_shp_present:
        raise FileNotFoundError(f"{SMALL_BASE}.shp not found in Merged/.")

    # 3) Download US_binary.tif and shapefiles
    china_tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".tif").name
    download_to_temp(china_file, china_tmp)

    shp_small_path = download_shapefile_bundle(drive, binary_id, SMALL_BASE)
    shp_chn_path   = download_shapefile_bundle(drive, comp_id, CHN_ADM0_BASE)

    try:
        # === Read US_binary with downsampling for the main map ===
        with rasterio.open(china_tmp) as src:
            H, W = src.height, src.width
            bounds = src.bounds
            raster_crs = src.crs
            nodata = src.nodata

            # target ~10M pixels max for the main map
            max_pixels = 10_000_000
            scale_factor = max(1, int(np.sqrt((H * W) / max_pixels)))
            out_h = max(1, H // scale_factor)
            out_w = max(1, W // scale_factor)

            china_arr = src.read(
                1,
                out_shape=(out_h, out_w),
                resampling=Resampling.nearest,
                masked=True
            )

            # Extent for imshow
            extent_main = (bounds.left, bounds.right, bounds.bottom, bounds.top)

        # === Read ADM0 boundary and reproject to raster CRS ===
        chn_geoms = []
        with fiona.open(shp_chn_path) as src:
            shp_crs = src.crs
            if shp_crs and raster_crs and shp_crs != raster_crs:
                transformer = Transformer.from_crs(shp_crs, raster_crs, always_xy=True)
                project = lambda x, y, z=None: transformer.transform(x, y)
            else:
                project = None

            for feat in src:
                g = shape(feat["geometry"])
                if project is not None:
                    g = shp_transform(project, g)
                chn_geoms.append(g)

        # === Read small-area shapefile and reproject to raster CRS ===
        small_geoms = []
        with fiona.open(shp_small_path) as src:
            shp_crs = src.crs
            if shp_crs and raster_crs and shp_crs != raster_crs:
                transformer = Transformer.from_crs(shp_crs, raster_crs, always_xy=True)
                project = lambda x, y, z=None: transformer.transform(x, y)
            else:
                project = None

            for feat in src:
                g = shape(feat["geometry"])
                if project is not None:
                    g = shp_transform(project, g)
                small_geoms.append(g)

        small_union = unary_union(small_geoms)

        # === Crop US_binary to the small area for inset ===
        with rasterio.open(china_tmp) as src:
            inset_arr, inset_transform = mask(
                src,
                [small_union],
                crop=True,
                nodata=src.nodata
            )
            inset_arr = inset_arr[0]   # single band
            inset_bounds = rasterio.transform.array_bounds(
                inset_arr.shape[0], inset_arr.shape[1], inset_transform
            )
            # array_bounds returns (ymin, ymax, xmin, xmax)
            inset_extent = (inset_bounds[2], inset_bounds[3],
                            inset_bounds[0], inset_bounds[1])

        # === Set up colormap: 0=grey (non-irrigated), 1=dark blue (irrigated), nodata transparent ===
        cmap = ListedColormap([
            (0.9, 0.9, 0.9, 1.0),  # 0 non-irrigated
            (0.0, 0.0, 0.5, 1.0),  # 1 irrigated (dark blue)
        ])
        # for masked values (nodata)
        cmap.set_bad((0, 0, 0, 0))

        # === Build the figure ===
        fig, (ax_main, ax_inset) = plt.subplots(
            nrows=2, ncols=1,
            figsize=(7.5, 9.0),
            dpi=PLOT_DPI,
            gridspec_kw={"height_ratios": [3, 2]}
        )

        # ----- TOP PANEL: country-wide map -----
        ax_main.imshow(
            china_arr,
            origin="upper",
            cmap=cmap,
            vmin=0,
            vmax=1,
            extent=extent_main
        )

        # ONLY outer shapefile boundary in black (ADM0)
        for g in chn_geoms:
            if g.geom_type == "MultiPolygon":
                polys = list(g.geoms)
            else:
                polys = [g]
            for poly in polys:
                x, y = poly.exterior.xy
                ax_main.plot(x, y, color="black", linewidth=0.7)

        ax_main.set_xticks([])
        ax_main.set_yticks([])
        ax_main.set_xlabel("")
        ax_main.set_ylabel("")
        ax_main.set_title(f"Irrigated area (dark blue) from {CHINA_BINARY_NAME}", fontsize=13)

        # NOTE: no bounding box rectangle
        # NOTE: no legend

        # ----- BOTTOM PANEL: zoomed inset -----
        # Mask nodata for inset
        inset_masked = np.ma.masked_equal(inset_arr, nodata) if nodata is not None else np.ma.masked_invalid(inset_arr)

        ax_inset.imshow(
            inset_masked,
            origin="upper",
            cmap=cmap,
            vmin=0,
            vmax=1,
            extent=inset_extent
        )

        # ONLY outer shapefile boundary in black (US_small)
        for g in small_geoms:
            if g.geom_type == "MultiPolygon":
                polys = list(g.geoms)
            else:
                polys = [g]
            for poly in polys:
                x, y = poly.exterior.xy
                ax_inset.plot(x, y, color="black", linewidth=0.7)

        ax_inset.set_xticks([])
        ax_inset.set_yticks([])
        ax_inset.set_xlabel("")
        ax_inset.set_ylabel("")
        ax_inset.set_title(f"Zoomed irrigated area ({SMALL_BASE}.shp, dark blue)", fontsize=12)
        ax_inset.set_aspect("equal")

        plt.tight_layout()

        # ---- Save & upload PNG ----
        out_png = tempfile.NamedTemporaryFile(delete=False, suffix=".png").name
        fig.savefig(out_png, bbox_inches="tight", dpi=PLOT_DPI)
        plt.close(fig)

        out_title = "US_binary_with_zoom_map.png"
        upload_path(drive, out_png, comp_id, title=out_title)

        try:
            os.remove(out_png)
        except Exception:
            pass

        print(f"✅ Saved country-wide + zoom map PNG as '{out_title}' in Merged/Comparisons/")

    finally:
        try:
            os.remove(china_tmp)
        except Exception:
            pass
        # shapefile tempdirs will be cleaned by OS eventually


# ---------------------- CLI ----------------------
if __name__ == "__main__":
    try:
        drive  # noqa: F821
    except NameError:
        raise RuntimeError(
            "PyDrive2 'drive' not found. Authenticate and expose a global `drive` before running."
        )
    plot_china_binary_with_zoom(drive)


In [2]:
# List only folders in Google Drive root
folder_list = drive.ListFile({
    'q': "'root' in parents and mimeType='application/vnd.google-apps.folder' and trashed=false"
}).GetList()

for folder in folder_list:
    print(f"Folder: {folder['title']}  |  ID: {folder['id']}")


Folder: Continental Irr Maps  |  ID: 1atiKrbLV7Ifez9AKTuSqCFMsTxkOqOvc
Folder: Moldova_VIs_Env_GrowingSeason  |  ID: 183w10O5hg6pS6ou2gT7LE1N4SI1pSn5B
Folder: Cuba_VIs_Env_IrrigationSeason  |  ID: 1OrL9zVhyRh57oLfvg84ySJhB0GZSCM6m
Folder: Caribbean_VIs_Env_IrrigationSeason  |  ID: 1vsKprJVxyyjHR6zPcfYV1Rc1wkjIw2fn
Folder: GMIE  |  ID: 1lVkH_OShkP5n_9CI4bdJgwFuHwmbx57H
Folder: CLE_National  |  ID: 1pOpf-Zy5la4SKTIcrmLX6tGlxVBejKYX
Folder: AfricaEnv_WA_noPrecip_NoSoil_NoCropMask  |  ID: 1KGWrOMmhnCVvUFr7uh9y2wyjPdLkRFDd
Folder: Europe_VIs_Env2  |  ID: 1Rz21QLD1UnspShqUD7jB6Bfsgjk03Z9e
Folder: China  |  ID: 19Ee1kUar4oEabBObq7GcA-jDR7F29x2_
Folder: Canada  |  ID: 1qXhq3-3-J2YhKjXmMcCOT0exvb12LxhD
Folder: Global_Irr_Maps  |  ID: 1XKYX2EPkPmu1el4Q0FpfZMEkgsWvRuYD
Folder: Validation Dataset  |  ID: 1yHuAWZZdz5KsWxvEXPw4nxiRcKgYPtp8
Folder: SouthAmerica_VIs_Env  |  ID: 1hitOuuj27P1_Lleo4Dhl630Q-TCcS_b6
Folder: Canada_Irrigation_RF  |  ID: 1JUJ8xSgSiuccxHd80eiVYfLgwjZcPS5b
Folder: US_VIs_Env  

In [4]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Plot US_binary.tif with:
  - Top panel: country-wide map (ADM0 boundary + irrigated in dark blue).
  - Bottom panel: zoomed inset, clipping US_binary to US_small shapefile.

Files (Google Drive layout):
  CountryModelPredicted_Cropland /
    Probability|Probablity|Prob /
      Merged /
        US_binary.tif
        US_small.(shp,shx,dbf,prj,cpg)
        Comparisons /
          US.(shp,shx,dbf,prj,cpg)
          --> output PNG will be saved here as:
              US_binary_with_zoom_map.png

Requirements:
  - PyDrive2-authenticated global `drive`
  - rasterio, numpy, matplotlib, fiona, shapely, pyproj
"""

import os
import re
import tempfile
import warnings

import numpy as np
import rasterio
from rasterio.enums import Resampling
from rasterio.mask import mask

import fiona
from shapely.geometry import shape
from shapely.ops import transform as shp_transform
from shapely.ops import unary_union
from pyproj import Transformer
import os
import re
import tempfile
import warnings

import numpy as np
import rasterio
from rasterio.enums import Resampling
from rasterio.mask import mask  # kept, though not used now
import rasterio.windows as rwindows

import fiona
from shapely.geometry import shape
from shapely.ops import transform as shp_transform
from shapely.ops import unary_union
from pyproj import Transformer

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

warnings.filterwarnings("ignore", category=RuntimeWarning)

# ===================== CONFIG =====================
ROOT_FOLDER_ID      = os.environ.get("ROOT_FOLDER_ID", "1ReAmJ809FmU-8jLpyyO6VEdmeqDmbaB7")
PARENT_FOLDER_NAME  = "CountryModelPredicted"
PROB_NAMES          = ("Probability", "Probablity", "Prob")
BINARY_SUBFOLDER    = "Merged"
COMPARE_SUBFOLDER   = "Comparisons"

CHINA_BINARY_NAME   = "Egypt_Binary.tif"
SMALL_BASE          = "Egypt_smalls"
CHN_ADM0_BASE       = "gadm41_EGY_0"


PLOT_DPI            = 160
# ==================================================
# ---------------------- Drive helpers ---------------------
def _dq(drive, q):
    return drive.ListFile({
        "q": q,
        "supportsAllDrives": True,
        "includeItemsFromAllDrives": True,
        "maxResults": 1000
    }).GetList()

def list_files(drive, parent_id):
    return _dq(drive, f"'{parent_id}' in parents and trashed=false")

def child_folders(drive, parent_id):
    return _dq(
        drive,
        f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder'"
    )

def get_subfolder_exact(drive, parent_id, name):
    res = _dq(
        drive,
        f"'{parent_id}' in parents and trashed=false "
        f"and mimeType='application/vnd.google-apps.folder' and title='{name}'"
    )
    return res[0]["id"] if res else None

def get_subfolder_fuzzy(drive, parent_id, desired):
    """Simple fuzzy finder for folders (handles 'Probability' typo)."""
    eid = get_subfolder_exact(drive, parent_id, desired)
    if eid:
        return eid
    key = re.sub(r"\s+", "", desired.lower())
    for f in child_folders(drive, parent_id):
        t = (f.get("title") or f.get("name") or "").lower()
        tkey = re.sub(r"\s+", "", t)
        if key in tkey or ("prob" in key and "prob" in tkey):
            return f["id"]
    return None

def download_to_temp(drive_file, local_path):
    drive_file.GetContentFile(local_path)

def upload_path(drive, local_path, parent_id, title=None):
    f = drive.CreateFile(
        {"title": title or os.path.basename(local_path),
         "parents":[{"id": parent_id}]}
    )
    f.SetContentFile(local_path)
    f.Upload()
    return f["id"]

def download_shapefile_bundle(drive, folder_id, base):
    """
    Download base.(shp,shx,dbf,prj,cpg) from a folder into a temp dir.
    Returns path to the local .shp.
    """
    exts = [".shp", ".shx", ".dbf", ".prj", ".cpg"]
    items = {}
    for it in list_files(drive, folder_id):
        title = (it.get("title") or "").lower()
        for e in exts:
            if title == (base.lower() + e):
                items[e] = it

    if ".shp" not in items:
        raise FileNotFoundError(f"Could not find {base}.shp in folder.")

    tmpdir = tempfile.mkdtemp(prefix=base + "_")
    for e, it in items.items():
        lp = os.path.join(tmpdir, base + e)
        download_to_temp(it, lp)
    return os.path.join(tmpdir, base + ".shp")


# --------------------- Main plotting logic --------------------
def plot_china_binary_with_zoom(drive):
    # 1) Navigate to folders
    root_id = ROOT_FOLDER_ID
    cmp_id = get_subfolder_fuzzy(drive, root_id, PARENT_FOLDER_NAME)
    if not cmp_id:
        raise RuntimeError(f"Folder '{PARENT_FOLDER_NAME}' not found under ROOT.")

    prob_id = None
    for nm in PROB_NAMES:
        prob_id = get_subfolder_fuzzy(drive, cmp_id, nm)
        if prob_id:
            break
    if not prob_id:
        raise RuntimeError("Probability/Probablity/Prob folder not found.")

    binary_id = get_subfolder_fuzzy(drive, prob_id, BINARY_SUBFOLDER)
    if not binary_id:
        raise RuntimeError(f"Binary folder '{BINARY_SUBFOLDER}' not found.")

    comp_id = get_subfolder_fuzzy(drive, binary_id, COMPARE_SUBFOLDER)
    if not comp_id:
        raise RuntimeError(f"Comparisons folder '{COMPARE_SUBFOLDER}' not found inside {BINARY_SUBFOLDER}/.")

    # 2) Find US_binary.tif and small-area shapefile in Merged
    china_file = None
    small_shp_present = False
    for it in list_files(drive, binary_id):
        title = (it.get("title") or "").strip().lower()
        if title == CHINA_BINARY_NAME.lower():
            china_file = it
        if title.startswith(SMALL_BASE.lower()) and title.endswith(".shp"):
            small_shp_present = True
    if not china_file:
        raise FileNotFoundError(f"{CHINA_BINARY_NAME} not found in Merged/.")
    if not small_shp_present:
        raise FileNotFoundError(f"{SMALL_BASE}.shp not found in Merged/.")

    # 3) Download US_binary.tif and shapefiles
    china_tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".tif").name
    download_to_temp(china_file, china_tmp)

    shp_small_path = download_shapefile_bundle(drive, binary_id, SMALL_BASE)
    shp_chn_path   = download_shapefile_bundle(drive, comp_id, CHN_ADM0_BASE)

    try:
        # === Read US_binary with downsampling for the main map ===
        with rasterio.open(china_tmp) as src:
            H, W = src.height, src.width
            bounds = src.bounds
            raster_crs = src.crs
            nodata = src.nodata

            # target ~10M pixels max for the main map
            max_pixels = 10_000_000
            scale_factor = max(1, int(np.sqrt((H * W) / max_pixels)))
            out_h = max(1, H // scale_factor)
            out_w = max(1, W // scale_factor)

            china_arr = src.read(
                1,
                out_shape=(out_h, out_w),
                resampling=Resampling.nearest,
                masked=True
            )

            # Extent for imshow
            extent_main = (bounds.left, bounds.right, bounds.bottom, bounds.top)

        # === Read ADM0 boundary and reproject to raster CRS ===
        chn_geoms = []
        with fiona.open(shp_chn_path) as src:
            shp_crs = src.crs
            if shp_crs and raster_crs and shp_crs != raster_crs:
                transformer = Transformer.from_crs(shp_crs, raster_crs, always_xy=True)
                project = lambda x, y, z=None: transformer.transform(x, y)
            else:
                project = None

            for feat in src:
                g = shape(feat["geometry"])
                if project is not None:
                    g = shp_transform(project, g)
                chn_geoms.append(g)

        # === Read small-area shapefile and reproject to raster CRS ===
        small_geoms = []
        with fiona.open(shp_small_path) as src:
            shp_crs = src.crs
            if shp_crs and raster_crs and shp_crs != raster_crs:
                transformer = Transformer.from_crs(shp_crs, raster_crs, always_xy=True)
                project = lambda x, y, z=None: transformer.transform(x, y)
            else:
                project = None

            for feat in src:
                g = shape(feat["geometry"])
                if project is not None:
                    g = shp_transform(project, g)
                small_geoms.append(g)

        small_union = unary_union(small_geoms)

        # === Crop US_binary to the RECTANGULAR BBOX of the small area for inset ===
        with rasterio.open(china_tmp) as src:
            bbox = small_union.bounds  # (minx, miny, maxx, maxy)
            win = rwindows.from_bounds(*bbox, transform=src.transform)
            inset_arr = src.read(1, window=win, masked=True)
            inset_transform = src.window_transform(win)

            inset_bounds = rasterio.transform.array_bounds(
                inset_arr.shape[0], inset_arr.shape[1], inset_transform
            )
            # array_bounds returns (ymin, ymax, xmin, xmax)
            inset_extent = (inset_bounds[2], inset_bounds[3],
                            inset_bounds[0], inset_bounds[1])

        # === Set up colormap: 0=grey (non-irrigated), 1=dark blue (irrigated), nodata transparent ===
        cmap = ListedColormap([
            (0.9, 0.9, 0.9, 1.0),  # 0 non-irrigated
            (0.0, 0.0, 0.5, 1.0),  # 1 irrigated (dark blue)
        ])
        # for masked values (nodata)
        cmap.set_bad((0, 0, 0, 0))

        # === Build the figure ===
        fig, (ax_main, ax_inset) = plt.subplots(
            nrows=2, ncols=1,
            figsize=(7.5, 9.0),
            dpi=PLOT_DPI,
            gridspec_kw={"height_ratios": [3, 2]}
        )

        # ----- TOP PANEL: country-wide map -----
        ax_main.imshow(
            china_arr,
            origin="upper",
            cmap=cmap,
            vmin=0,
            vmax=1,
            extent=extent_main
        )

        # ONLY outer shapefile boundary in black (ADM0)
        for g in chn_geoms:
            if g.geom_type == "MultiPolygon":
                polys = list(g.geoms)
            else:
                polys = [g]
            for poly in polys:
                x, y = poly.exterior.xy
                ax_main.plot(x, y, color="black", linewidth=0.7)

        # Remove axis box but keep titles
        for spine in ax_main.spines.values():
            spine.set_visible(False)

        ax_main.set_xticks([])
        ax_main.set_yticks([])
        ax_main.set_xlabel("")
        ax_main.set_ylabel("")
        ax_main.set_title(f"Irrigated area (dark blue) from {CHINA_BINARY_NAME}", fontsize=13)

        # ----- BOTTOM PANEL: zoomed inset -----
        inset_masked = inset_arr  # already masked=True above

        ax_inset.imshow(
            inset_masked,
            origin="upper",
            cmap=cmap,
            vmin=0,
            vmax=1,
            extent=inset_extent
        )

        # ONLY outer shapefile boundary in black (US_small)
        for g in small_geoms:
            if g.geom_type == "MultiPolygon":
                polys = list(g.geoms)
            else:
                polys = [g]
            for poly in polys:
                x, y = poly.exterior.xy
                ax_inset.plot(x, y, color="black", linewidth=0.7)

        # Remove axis box but keep titles
        for spine in ax_inset.spines.values():
            spine.set_visible(False)

        ax_inset.set_xticks([])
        ax_inset.set_yticks([])
        ax_inset.set_xlabel("")
        ax_inset.set_ylabel("")
        ax_inset.set_title(f"Zoomed irrigated area ({SMALL_BASE}.shp, dark blue)", fontsize=12)
        ax_inset.set_aspect("equal")

        plt.tight_layout()

        # ---- Save & upload PNG ----
        out_png = tempfile.NamedTemporaryFile(delete=False, suffix=".png").name
        fig.savefig(out_png, bbox_inches="tight", dpi=PLOT_DPI)
        plt.close(fig)

        out_title = "US_binary_with_zoom_map.png"
        upload_path(drive, out_png, comp_id, title=out_title)

        try:
            os.remove(out_png)
        except Exception:
            pass

        print(f"✅ Saved country-wide + zoom map PNG as '{out_title}' in Merged/Comparisons/")

    finally:
        try:
            os.remove(china_tmp)
        except Exception:
            pass
        # shapefile tempdirs will be cleaned by OS eventually


# ---------------------- CLI ----------------------
if __name__ == "__main__":
    try:
        drive  # noqa: F821
    except NameError:
        raise RuntimeError(
            "PyDrive2 'drive' not found. Authenticate and expose a global `drive` before running."
        )
    plot_china_binary_with_zoom(drive)

✅ Saved country-wide + zoom map PNG as 'US_binary_with_zoom_map.png' in Merged/Comparisons/


In [14]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Per-admin AEI-constrained binarization using ONLY the admin shapefile.

For each country probability TIFF:
  - Use unit_code polygons from AEI_2020_2_withAEI.* (AEI in HECTARES).
  - For each unit_code, pick pixels INSIDE (centroid-in) from highest probability
    downward until its AEI quota is met (quota = floor(AEI_ha*10_000 / 900) pixels).
  - Preserve NaNs from the probability input.
  - Write ONE binary 0/1 TIFF per country to:
      Drive → CountryModelPredicted/Probability/Binary/
  - Also write a CSV summary per country with thresholds and counts.

No AEI base raster is used.
"""

import os, re, math, csv, tempfile, warnings, unicodedata
from collections import defaultdict, Counter

import numpy as np
import rasterio
from rasterio.windows import Window
from rasterio.features import rasterize

import fiona
from shapely.geometry import shape, box
from shapely.ops import transform as shp_transform
from shapely.strtree import STRtree
from pyproj import Transformer

# ------------------------- CONFIG -------------------------
ROOT_FOLDER_ID        = os.environ.get("ROOT_FOLDER_ID", "1gtVp-x-dIgcuNPcoaxHfmZs0ctoG71PS")
PARENT_FOLDER_NAME    = "CountryModelPredicted"
PROB_SUBFOLDER_NAME   = "Probability"      # fuzzy-matched (also matches "Porbability")
NATIONAL_AEI_FOLDER   = "National AEI"     # shapefile lives here

# Probability & binning
PIXEL_AREA_M2         = 30.0 * 30.0        # 900 m² per 30 m pixel
TILE                  = 1024               # reduce if memory tight
SCALE                 = 1000               # probability bins: 0..1 → 0..1000

# Shapefile (AEI in HECTARES)
ADMIN_SHP_BASE        = "Copy of AEI_2020_2_with_AEI"  # base name (no extension)
ADMIN_CODE_COL        = "unit_code"
ADMIN_AEI_COLS        = ["AEI_2020", "AEI2020", "AEI"]  # values in HECTARES
ADMIN_CNTRY_COLS      = ["name_cntr", "name_cntr1", "name_admin", "ST_NM"]

# Deterministic tie-breaking
RNG_SEED              = int(os.environ.get("AEI_RNG_SEED", "0"))
_rng = np.random.default_rng(RNG_SEED)
# ----------------------------------------------------------

warnings.filterwarnings("ignore", category=RuntimeWarning)

# ---------------------- Drive helpers ---------------------
def _dq(drive, q):
    return drive.ListFile({
        "q": q,
        "supportsAllDrives": True,
        "includeItemsFromAllDrives": True,
        "maxResults": 1000
    }).GetList()

def list_files(drive, parent_id):
    return _dq(drive, f"'{parent_id}' in parents and trashed=false")

def child_folders(drive, parent_id):
    return _dq(drive, f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder'")

def get_subfolder_exact(drive, parent_id, name):
    res = _dq(drive, f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder' and title='{name}'")
    return res[0]["id"] if res else None

def get_subfolder_fuzzy(drive, parent_id, desired):
    eid = get_subfolder_exact(drive, parent_id, desired)
    if eid: return eid
    key = re.sub(r"\s+", "", desired.lower())
    for f in child_folders(drive, parent_id):
        tkey = re.sub(r"\s+", "", f.get("title","").lower())
        if key in tkey or ("prob" in key and "prob" in tkey):
            return f["id"]
    return None

def _resolve_field(props_sample: dict, candidates, required=False):
    """Fuzzy, case-insensitive resolver for attribute fields."""
    def canon(s): return re.sub(r"[^a-z0-9]+", "", str(s).lower())
    keys = list(props_sample.keys())
    norm_map = {canon(k): k for k in keys}
    for want in candidates:
        w = canon(want)
        if w in norm_map:
            return norm_map[w]
    want_roots = {canon(want) for want in candidates}
    for k in keys:
        ck = canon(k)
        if any(root in ck for root in want_roots):
            return k
    if required:
        raise RuntimeError(f"Required attribute not found. Looked for: {candidates}. Available: {keys}")
    return None

def download_to_temp(drive_file, suffix):
    p = tempfile.NamedTemporaryFile(delete=False, suffix=suffix).name
    drive_file.GetContentFile(p)
    return p

def get_or_create_folder(drive, parent_id, name):
    res = _dq(drive, f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder' and title='{name}'")
    if res: return res[0]["id"]
    nf = drive.CreateFile({"title": name, "parents":[{"id": parent_id}], "mimeType":"application/vnd.google-apps.folder"})
    nf.Upload()
    return nf["id"]

def upload_path(drive, local_path, parent_id, title=None):
    f = drive.CreateFile({"title": title or os.path.basename(local_path), "parents":[{"id": parent_id}]})
    f.SetContentFile(local_path)
    f.Upload()
    return f["id"]

# ---------------------- Name helpers ----------------------
def _norm(s): return re.sub(r"[^a-z0-9]+", "", str(s).lower())
def _canon(s):
    s = unicodedata.normalize("NFKD", str(s)).encode("ascii","ignore").decode()
    return re.sub(r"\s+","", s.lower())

def _extract_country_from_fname(fname):
    # Albania_RF_probability_percent.tif → Albania
    fn = re.sub(r"\s*\(.*\)\.tif(f)?$", ".tif", fname, flags=re.IGNORECASE)
    m = re.match(r"(.+?)_RF_probability_percent", fn, flags=re.IGNORECASE)
    if m: return m.group(1)
    return re.sub(r"\.tif(f)?$", "", fn, flags=re.IGNORECASE)

# ---------------------- Shapefile helpers -----------------
def _download_shapefile_bundle(drive, folder_id, base):
    """
    Download AEI_2020_2_withAEI.* into a single temp directory, ensuring
    all sidecars share the SAME basename so GDAL/Fiona can see attributes.
    Returns a dict of local paths keyed by extension ('.shp', '.dbf', etc).
    """
    exts = [".shp", ".shx", ".dbf", ".prj", ".cpg"]
    items = {}
    for it in list_files(drive, folder_id):
        t = it.get("title", "")
        for e in exts:
            if t.lower() == (base.lower() + e):
                items[e] = it

    if ".shp" not in items or ".dbf" not in items:
        raise FileNotFoundError(f"Missing pieces of {base} shapefile (need at least .shp and .dbf). Found: {sorted(items.keys())}")

    tmpdir = tempfile.mkdtemp(prefix="aei_admin_")
    out = {}
    for e, it in items.items():
        local_path = os.path.join(tmpdir, base + e)  # SAME BASENAME!
        it.GetContentFile(local_path)
        out[e] = local_path

    print(f"[AEI] Shapefile bundle at {tmpdir} → found {sorted(out.keys())}")
    return out


def _read_admins_for_raster(shp_path, raster_crs, raster_bounds):
    """
    Load admin polygons, reproject to raster CRS (no AEI filtering).
    - Reads field names from schema (not the first feature).
    - Honors .cpg encoding when present.
    - Skips null/malformed geometries defensively.
    """
    feats, attrs = [], []
    shp_dir = os.path.dirname(shp_path)
    base = os.path.splitext(os.path.basename(shp_path))[0]
    cpg_path = os.path.join(shp_dir, base + ".cpg")

    # Determine DBF encoding
    encoding = None
    if os.path.exists(cpg_path):
        try:
            with open(cpg_path, "r", encoding="ascii", errors="ignore") as f:
                enc_line = f.read().strip()
                if enc_line:
                    encoding = enc_line
        except Exception:
            pass

    def _open_fiona(enc):
        return fiona.open(shp_path, encoding=enc) if enc else fiona.open(shp_path)

    shp_crs_final = None
    with fiona.Env(SHAPE_RESTORE_SHX='YES'):
        # Try cpg encoding → utf-8 → latin1
        tried = [encoding, "utf-8", "latin1"]
        last_err = None
        for enc in tried:
            try:
                with _open_fiona(enc) as src:
                    shp_crs_local = src.crs_wkt or src.crs
                    props_schema = (src.schema or {}).get("properties", {})
                    field_names = list(props_schema.keys())
                    if not field_names:
                        raise RuntimeError("No attribute fields in schema (DBF not visible).")

                    # Resolve keys against SCHEMA (not a sample feature)
                    dummy_props = {k: None for k in field_names}
                    code_key  = _resolve_field(dummy_props, [ADMIN_CODE_COL], required=True)
                    aei_key   = _resolve_field(dummy_props, ADMIN_AEI_COLS, required=True)  # AEI in hectares
                    cntry_key = _resolve_field(dummy_props, ADMIN_CNTRY_COLS, required=False)

                    # Transform raster bounds to shapefile CRS for coarse prefilter
                    if shp_crs_local:
                        rb_to_shp = Transformer.from_crs(raster_crs, shp_crs_local, always_xy=True)
                        rb_shp = shp_transform(lambda x, y: rb_to_shp.transform(x, y), box(*raster_bounds))
                    else:
                        rb_shp = box(*raster_bounds)

                    # Iterate features
                    for rec in src:
                        gj = rec.get("geometry")
                        if gj is None:
                            continue  # null geometry → skip (defensive)
                        try:
                            g = shape(gj)
                        except Exception:
                            continue
                        if g.is_empty:
                            continue
                        if not g.intersects(rb_shp):
                            continue

                        props = rec.get("properties") or {}
                        try:
                            uc = int(props[code_key])
                            aei_ha = float(props[aei_key])  # hectares
                        except Exception:
                            continue

                        feats.append(g)
                        attrs.append({
                            "unit_code": uc,
                            "aei_ha": aei_ha,
                            "country": str(props.get(cntry_key, "")).strip() if cntry_key else ""
                        })
                # success → keep the CRS we used
                shp_crs_final = shp_crs_local
                break
            except Exception as e:
                last_err = e
                continue

        if last_err and not feats:
            raise RuntimeError(f"Failed to read attributes from shapefile. Tried encodings {tried}. Last error: {last_err}")

    if not feats:
        return [], [], None, {}

    # Reproject to raster CRS for rasterize
    shp_crs = shp_crs_final
    if not shp_crs:
        print("⚠️  Shapefile has no CRS (.prj missing). Assuming raster CRS.")
        shp_crs = raster_crs

    transformer = Transformer.from_crs(shp_crs, raster_crs, always_xy=True)
    geoms_ras = [shp_transform(lambda x, y: transformer.transform(x, y), g) for g in feats]

    tree_ras = STRtree(geoms_ras)
    # IMPORTANT: map by WKB (value identity), not id(...)
    g2i_wkb = {g.wkb: i for i, g in enumerate(geoms_ras)}

    return geoms_ras, attrs, tree_ras, g2i_wkb


# ---- STRtree helper: get candidate indices robustly (Shapely 2 or fallback) ----
def _tree_candidate_indices(tree_ras, tile_poly, geoms_ras, g2i_wkb):
    """
    Return list of indices of geoms that intersect tile_poly.
    Prefer Shapely 2's predicate indices; otherwise map WKBs.
    """
    # Fast path: Shapely 2 can return integer indices with predicate
    try:
        idx = tree_ras.query(tile_poly, predicate="intersects")
        if isinstance(idx, np.ndarray) and np.issubdtype(idx.dtype, np.integer):
            return idx.tolist()
    except TypeError:
        # Older shapely: predicate argument not supported
        pass

    # Fallback: geometry array → map to indices by WKB, then precise intersects
    cand = tree_ras.query(tile_poly)
    if isinstance(cand, np.ndarray):
        cand = cand.tolist()
    out = []
    for g in cand:
        i = g2i_wkb.get(g.wkb, None)
        if i is None:
            # last resort: linear search (rare)
            try:
                i = next(j for j, gg in enumerate(geoms_ras) if gg.equals(g))
            except StopIteration:
                continue
        if geoms_ras[i].intersects(tile_poly):
            out.append(i)
    return out


# ---------------------- Raster helpers -------------------
def _iter_tiles(H, W, tile=TILE):
    for r0 in range(0, H, tile):
        for c0 in range(0, W, tile):
            h = min(tile, H - r0)
            w = min(tile, W - c0)
            yield Window(c0, r0, w, h)

def _read_prob_tile(src, W):
    arr = src.read(1, window=W, out_dtype="float32", masked=True).filled(np.nan)
    finite = np.isfinite(arr)
    if finite.any() and float(np.nanmax(arr[finite])) > 1.5:
        arr[finite] /= 100.0
    np.clip(arr, 0.0, 1.0, out=arr, where=finite)
    return arr

def _tile_bounds(window, transform):
    left, top = transform * (window.col_off, window.row_off)
    right, bottom = transform * (window.col_off + window.width, window.row_off + window.height)
    x0, x1 = sorted([left, right])
    y0, y1 = sorted([bottom, top])
    return (x0, y0, x1, y1)

# ---------------------- Core algorithm -------------------
def aei_binarize_per_admin_from_shapefile(drive):
    """
    Main entry: uses ONLY the admin shapefile with AEI in hectares
    to allocate per-admin pixel quotas and write one binary per country.
    """
    # Locate folders
    cmp_id  = get_subfolder_fuzzy(drive, ROOT_FOLDER_ID, PARENT_FOLDER_NAME)
    if not cmp_id: raise RuntimeError(f"Folder '{PARENT_FOLDER_NAME}' not found under ROOT.")
    prob_id = get_subfolder_fuzzy(drive, cmp_id, PROB_SUBFOLDER_NAME)
    if not prob_id: raise RuntimeError("Probability folder not found (tried fuzzy match).")
    binary_id = get_or_create_folder(drive, prob_id, "Binary")

    aei_folder_id = get_subfolder_fuzzy(drive, ROOT_FOLDER_ID, NATIONAL_AEI_FOLDER)
    if not aei_folder_id: raise RuntimeError("National AEI folder not found at ROOT.")

    # Download admin shapefile bundle
    shp_paths = _download_shapefile_bundle(drive, aei_folder_id, ADMIN_SHP_BASE)
    shp_path  = shp_paths[".shp"]

    # List probability TIFFs
    files = [it for it in list_files(drive, prob_id)
             if isinstance(it, dict)
             and it.get("mimeType") != "application/vnd.google-apps.folder"
             and it.get("title","").lower().endswith((".tif",".tiff"))]
    if not files:
        raise FileNotFoundError("No probability TIFFs in Probability folder.")

    for it in files:
        title   = it.get("title","")
        country = _extract_country_from_fname(title)
        print(f"\n=== {country} (per-admin from shapefile; AEI in hectares) ===")

        rtmp = download_to_temp(it, ".tif")
        with rasterio.open(rtmp) as src:
            H, W = src.height, src.width
            ras_crs = src.crs
            rb = src.bounds
            ras_bounds = (rb.left, rb.bottom, rb.right, rb.top)

            # Read & subset admins, reproject to raster CRS
            geoms_ras, attrs, tree_ras, g2i_wkb = _read_admins_for_raster(shp_path, ras_crs, ras_bounds)
            if not geoms_ras:
                print("  ⚠️  No admin polygons intersect this raster; skipping.")
                continue

            # Optional filter by country name if present (keeps all if missing)
            want = _canon(country)
            keep = [i for i,a in enumerate(attrs) if (not a["country"]) or _canon(a["country"]) == want]
            if keep and len(keep) < len(attrs):
                geoms_ras = [geoms_ras[i] for i in keep]
                attrs     = [attrs[i] for i in keep]
                tree_ras  = STRtree(geoms_ras)
                g2i_wkb   = {g.wkb: i for i, g in enumerate(geoms_ras)}

            # Targets per admin (AEI in HECTARES → m² → pixels); include zeros
            K_map = {}
            for a in attrs:
                aei_m2 = a["aei_ha"] * 10_000.0
                K_map[a["unit_code"]] = int(math.floor(aei_m2 / PIXEL_AREA_M2))  # may be 0

            # PASS 1: per-admin histograms of probability bins (centroid-in)
            hists = defaultdict(Counter)
            for w in _iter_tiles(H, W, TILE):
                prob = _read_prob_tile(src, w)
                valid = np.isfinite(prob)
                if not valid.any():
                    continue

                tile_t = rasterio.windows.transform(w, src.transform)
                tb = _tile_bounds(w, src.transform)
                tile_poly = box(*tb)

                idxs = _tree_candidate_indices(tree_ras, tile_poly, geoms_ras, g2i_wkb)
                if len(idxs) == 0:
                    continue

                shapes = [(geoms_ras[i], attrs[i]["unit_code"]) for i in idxs]
                labels = rasterize(
                    shapes=shapes,
                    out_shape=prob.shape,
                    transform=tile_t,
                    fill=0, dtype="int64",
                    all_touched=False  # centroid-in
                )
                m = valid & (labels != 0)
                if not m.any():
                    continue

                p_int = np.zeros(prob.shape, dtype=np.int32)
                p_int_valid = np.rint(prob[m] * SCALE).astype(np.int32)
                p_int[m] = p_int_valid

                uc = labels[m].ravel()
                pi = p_int[m].ravel()
                for u in np.unique(uc):
                    sel = (uc == u)
                    bc = np.bincount(pi[sel], minlength=SCALE+1)
                    nz = np.nonzero(bc)[0]
                    for b, v in zip(nz, bc[nz]):
                        hists[u][int(b)] += int(v)

            # thresholds per admin (quota 0 => thr=-1, no selection)
            thr_map, need_eq_map = {}, {}
            for u, K in K_map.items():
                total = sum(hists[u].values())
                if K <= 0 or total == 0:
                    thr_map[u] = -1
                    need_eq_map[u] = 0
                    continue
                K = min(K, total)
                cum = 0; gt = 0
                for b in range(SCALE, -1, -1):
                    cnt = int(hists[u].get(b, 0))
                    if cum + cnt >= K:
                        thr_map[u] = b
                        need_eq_map[u] = K - gt
                        break
                    cum += cnt; gt += cnt
            need_eq_left = dict(need_eq_map)

            # PASS 2: write binary output (NaN preserved; default 0; set 1s per admin)
            out_profile = src.profile.copy()
            out_profile.update(
                driver="GTiff",
                height=H, width=W,
                transform=src.transform,
                count=1, dtype="float32", nodata=np.nan,
                compress="LZW", tiled=True, blockxsize=512, blockysize=512,
                BIGTIFF="IF_NEEDED"
            )
            with tempfile.NamedTemporaryFile(delete=False, suffix=".tif") as tmp_out:
                out_local = tmp_out.name

            with rasterio.open(out_local, "w", **out_profile) as dst:
                for w in _iter_tiles(H, W, TILE):
                    prob = _read_prob_tile(src, w)
                    out_tile = np.full(prob.shape, np.nan, dtype=np.float32)

                    valid = np.isfinite(prob)
                    if not valid.any():
                        dst.write(out_tile, 1, window=w)
                        continue

                    tile_t = rasterio.windows.transform(w, src.transform)
                    tb = _tile_bounds(w, src.transform)
                    tile_poly = box(*tb)

                    idxs = _tree_candidate_indices(tree_ras, tile_poly, geoms_ras, g2i_wkb)
                    if len(idxs) == 0:
                        dst.write(out_tile, 1, window=w)
                        continue

                    shapes = [(geoms_ras[i], attrs[i]["unit_code"]) for i in idxs]
                    labels = rasterize(
                        shapes=shapes,
                        out_shape=prob.shape,
                        transform=tile_t,
                        fill=0, dtype="int64",
                        all_touched=False
                    )

                    out_tile[valid] = 0.0  # default: valid-but-not-selected = 0
                    m_all = valid & (labels != 0)
                    if m_all.any():
                        p_int = np.zeros(prob.shape, dtype=np.int32)
                        p_int_valid = np.rint(prob[m_all] * SCALE).astype(np.int32)
                        p_int[m_all] = p_int_valid

                        present = np.unique(labels[m_all])
                        present = [u for u in present if u != 0]  # all units allowed
                        for u in present:
                            u_m = m_all & (labels == u)
                            if not u_m.any():
                                continue
                            t = thr_map.get(u, -1)
                            if t < 0:
                                continue  # quota 0 or no pixels -> stays 0
                            gt_m = u_m & (p_int > t)
                            out_tile[gt_m] = 1.0
                            need = need_eq_left.get(u, 0)
                            if need > 0:
                                eq_m = u_m & (p_int == t) & (out_tile != 1.0)
                                if eq_m.any():
                                    idx = np.flatnonzero(eq_m.ravel())
                                    _rng.shuffle(idx)
                                    take = min(need, idx.size)
                                    sel = idx[:take]
                                    rr, cc = np.unravel_index(sel, eq_m.shape)
                                    out_tile[rr, cc] = 1.0
                                    need_eq_left[u] = need - int(take)

                    dst.write(out_tile, 1, window=w)

            out_name = f"{country}_AEI_binary_0_1.tif"
            upload_path(drive, out_local, binary_id, title=out_name)
            try: os.remove(out_local)
            except: pass

            # CSV summary per admin (includes AEI==0 units)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode="w", newline="") as tmpcsv:
                wcsv = csv.writer(tmpcsv)
                wcsv.writerow(["unit_code","aei_ha","target_pixels","thr_bin","scale","selected_pixels"])
                for a in attrs:
                    u = a["unit_code"]
                    aei_ha = a["aei_ha"]
                    K = int(math.floor((aei_ha * 10_000.0) / PIXEL_AREA_M2))
                    th = int(thr_map.get(u, -1))
                    gt = sum(v for b, v in (hists[u].items() if u in hists else []) if b > th)
                    ties_taken = (need_eq_map.get(u, 0) - need_eq_left.get(u, 0))
                    sel = int(gt + max(0, ties_taken))
                    wcsv.writerow([u, aei_ha, int(K), th, SCALE, sel])
                csv_path = tmpcsv.name
            upload_path(drive, csv_path, binary_id, title=f"{country}_AEI_admin_summary.csv")
            try: os.remove(csv_path)
            except: pass

            print(f"   ✓ Wrote Binary/{out_name}")

        try: os.remove(rtmp)
        except: pass

    print("\n✅ Done (per-admin from shapefile; AEI in hectares; NaNs preserved).")

# ---------------------- CLI ----------------------
if __name__ == "__main__":
    try:
        drive  # noqa: F821
    except NameError:
        raise RuntimeError("PyDrive2 'drive' not found. Authenticate and expose a global `drive` before running.")
    aei_binarize_per_admin_from_shapefile(drive)


[AEI] Shapefile bundle at /tmp/aei_admin_b9psri1j → found ['.cpg', '.dbf', '.prj', '.shp', '.shx']

=== New_Zealand (per-admin from shapefile; AEI in hectares) ===
   ✓ Wrote Binary/New_Zealand_AEI_binary_0_1.tif

✅ Done (per-admin from shapefile; AEI in hectares; NaNs preserved).


Found 1 binary rasters in Binary/

=== Smoothing New_Zealand_AEI_binary_0_1.tif → New_Zealand_AEI_binary_0_1_MAJ_k9_p50.tif ===
   ✓ Wrote Binary_MAJ/New_Zealand_AEI_binary_0_1_MAJ_k9_p50.tif

✅ Done smoothing all AEI binary maps (majority filter).


In [16]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Merge all per-country AEI binary rasters in Drive → CountryModelPredicted/Probability/Binary/
into a single global binary mosaic (union = max), preserving NaNs.

Output: Merged_AEI_binary_0_1.tif in the same Binary folder.
"""

import os, re, tempfile, warnings
import numpy as np
import rasterio
from rasterio.merge import merge as rio_merge
from rasterio.enums import Resampling

# ------------------------- CONFIG -------------------------
ROOT_FOLDER_ID      = os.environ.get("ROOT_FOLDER_ID", "1gtVp-x-dIgcuNPcoaxHfmZs0ctoG71PS")
PARENT_FOLDER_NAME  = "CountryModelPredicted"
PROB_SUBFOLDER_NAME = "Probability"   # fuzzy-matched (also matches misspellings)
BINARY_SUBFOLDER    = "Binary_MAJ"
OUT_NAME            = "Oceania_binary.tif"
# ----------------------------------------------------------

warnings.filterwarnings("ignore", category=RuntimeWarning)

# ---------------------- Drive helpers ---------------------
def _dq(drive, q):
    return drive.ListFile({
        "q": q,
        "supportsAllDrives": True,
        "includeItemsFromAllDrives": True,
        "maxResults": 1000
    }).GetList()

def list_files(drive, parent_id):
    return _dq(drive, f"'{parent_id}' in parents and trashed=false")

def child_folders(drive, parent_id):
    return _dq(drive, f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder'")

def get_subfolder_exact(drive, parent_id, name):
    res = _dq(drive, f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder' and title='{name}'")
    return res[0]["id"] if res else None

def get_subfolder_fuzzy(drive, parent_id, desired):
    eid = get_subfolder_exact(drive, parent_id, desired)
    if eid: return eid
    key = re.sub(r"\s+", "", desired.lower())
    for f in child_folders(drive, parent_id):
        tkey = re.sub(r"\s+", "", f.get("title","").lower())
        if key in tkey or ("prob" in key and "prob" in tkey):
            return f["id"]
    return None

def get_or_create_folder(drive, parent_id, name):
    res = _dq(drive, f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder' and title='{name}'")
    if res: return res[0]["id"]
    nf = drive.CreateFile({"title": name, "parents":[{"id": parent_id}], "mimeType":"application/vnd.google-apps.folder"})
    nf.Upload()
    return nf["id"]

def download_to_temp(drive_file, suffix):
    p = tempfile.NamedTemporaryFile(delete=False, suffix=suffix).name
    drive_file.GetContentFile(p)
    return p

def upload_path(drive, local_path, parent_id, title=None):
    f = drive.CreateFile({"title": title or os.path.basename(local_path), "parents":[{"id": parent_id}]})
    f.SetContentFile(local_path)
    f.Upload()
    return f["id"]

# ---------------------- Merge helpers ---------------------
def _try_merge_with_max(srcs):
    """Preferred path (newer rasterio): method='max', nodata=np.nan."""
    return rio_merge(
        srcs,
        nodata=np.nan,
        dtype="float32",
        precision=7,
        resampling=Resampling.nearest,
        method="max",
    )

def _try_merge_basic(srcs):
    """Older rasterio: no 'method'. We'll nanmax ourselves after merging."""
    mosaic, out_transform = rio_merge(
        srcs,
        nodata=np.nan,
        dtype="float32",
        precision=7,
        resampling=Resampling.nearest,
    )
    # emulate union across sources (elementwise maximum, ignoring NaNs)
    mosaic = np.nanmax(mosaic, axis=0, keepdims=True).astype("float32")
    return mosaic, out_transform

def _try_merge_basic_sentinel(srcs, sentinel=-9999.0):
    """Very old rasterio: nodata cannot be NaN. Use sentinel then convert and nanmax."""
    mosaic, out_transform = rio_merge(
        srcs,
        nodata=sentinel,
        dtype="float32",
        precision=7,
        resampling=Resampling.nearest,
    )
    # convert sentinel to NaN
    mosaic = mosaic.astype("float32", copy=False)
    mosaic[mosaic == sentinel] = np.nan
    mosaic = np.nanmax(mosaic, axis=0, keepdims=True).astype("float32")
    return mosaic, out_transform

# ---------------------- Main merge ------------------------
def merge_all_binary_rasters(drive):
    # Locate folders
    cmp_id     = get_subfolder_fuzzy(drive, ROOT_FOLDER_ID, PARENT_FOLDER_NAME)
    if not cmp_id: raise RuntimeError(f"Folder '{PARENT_FOLDER_NAME}' not found under ROOT.")
    prob_id    = get_subfolder_fuzzy(drive, cmp_id, PROB_SUBFOLDER_NAME)
    if not prob_id: raise RuntimeError("Probability folder not found (tried fuzzy match).")
    binary_id  = get_subfolder_fuzzy(drive, prob_id, BINARY_SUBFOLDER)
    if not binary_id:
        binary_id = get_or_create_folder(drive, prob_id, BINARY_SUBFOLDER)

    # Find all binary GeoTIFFs (skip the mosaic itself if re-running)
    tifs = [it for it in list_files(drive, binary_id)
            if isinstance(it, dict)
            and it.get("mimeType") != "application/vnd.google-apps.folder"
            and it.get("title","").lower().endswith((".tif",".tiff"))
            and OUT_NAME.lower() not in it.get("title","").lower()]
    if not tifs:
        raise FileNotFoundError("No binary .tif files found in Binary/.")

    # Download & open datasets
    local_paths, srcs = [], []
    try:
        for it in tifs:
            p = download_to_temp(it, ".tif")
            local_paths.append(p)
            srcs.append(rasterio.open(p))

        # Try modern merge with 'method=max' → else fallback strategies
        try:
            mosaic, out_transform = _try_merge_with_max(srcs)
        except TypeError:
            # 'method' not supported
            try:
                mosaic, out_transform = _try_merge_basic(srcs)
            except Exception:
                mosaic, out_transform = _try_merge_basic_sentinel(srcs)
        except Exception:
            # Any other unexpected error → robust fallback
            try:
                mosaic, out_transform = _try_merge_basic(srcs)
            except Exception:
                mosaic, out_transform = _try_merge_basic_sentinel(srcs)

        # Build output profile from first raster
        ref = srcs[0]
        out_profile = ref.profile.copy()
        out_profile.update(
            driver="GTiff",
            height=mosaic.shape[1],
            width=mosaic.shape[2],
            transform=out_transform,
            count=1,
            dtype="float32",
            nodata=np.nan,
            compress="LZW",
            tiled=True,
            blockxsize=512,
            blockysize=512,
            BIGTIFF="IF_NEEDED"
        )

        # Write to temp, then upload
        with tempfile.NamedTemporaryFile(delete=False, suffix=".tif") as tmp_out:
            out_local = tmp_out.name

        with rasterio.open(out_local, "w", **out_profile) as dst:
            dst.write(mosaic[0], 1)

        upload_path(drive, out_local, binary_id, title=OUT_NAME)
        try: os.remove(out_local)
        except: pass

        print(f"✅ Merged {len(srcs)} rasters → Binary/{OUT_NAME}")

    finally:
        # Cleanup
        for s in srcs:
            try: s.close()
            except: pass
        for p in local_paths:
            try: os.remove(p)
            except: pass

# ---------------------- CLI ----------------------
if __name__ == "__main__":
    try:
        drive  # noqa: F821
    except NameError:
        raise RuntimeError("PyDrive2 'drive' not found. Authenticate and expose a global `drive` before running.")
    merge_all_binary_rasters(drive)




✅ Merged 9 rasters → Binary/Oceania_binary.tif


RuntimeError: PyDrive2 'drive' not found. Authenticate and expose a global `drive` before running.

In [7]:
import os
# match your Slurm allocation ("64")
os.environ["OMP_NUM_THREADS"] = "64"        # OpenMP (NumPy/SciPy, some raster ops)
os.environ["MKL_NUM_THREADS"] = "64"        # MKL-backed NumPy / scikit-learn
os.environ["NUMEXPR_NUM_THREADS"] = "64"    # if you use numexpr
os.environ["GDAL_NUM_THREADS"] = "64"       # for internal GDAL threaded ops
os.environ["RASTERIO_NUM_THREADS"] = "64"   # rasterio’s thread pool


In [4]:
# -*- coding: utf-8 -*-
"""
Regional RF irrigation model with:
- Farm-aware split via DBSCAN (no leakage)
- AEZ categorical (one-hot)
- OUT-OF-FOLD probability calibration (isotonic by default, Platt optional)
- Regional OOF metrics + artifacts
- Per-country OOF metrics & thresholds saved
- Country predictions on ~30m WGS84 using calibrated probabilities

Now updated to:
  • Read AEZ per-point from Model Training/US_GTPS_per_point.(parquet|csv) when available
  • Prefer per-state AEZ tiles: By Country/<State>/<State>_AEZ.tif (fallback: global AEZ)

Outputs:
  CountryModelPredicted/<Country>_RF_probability_percent.tif
  CountryModelPredicted/<Country>_RF_binary_0_1_cropland.tif
  CountryModelPredicted/<Country>_RF_predictors_count.tif

  Model Training/Regional Models/<Region>/
    - metrics.json
    - confusion.png
    - feature_importance.csv
    - per_country_test_metrics.csv
    - model.joblib

  Model Training/Country Models/<Country>/
    - threshold.json
    - test_metrics.json
    - test_confusion.png
"""
import os, re, json, math, tempfile, joblib, time, random
from collections import deque

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score, average_precision_score, roc_auc_score,
    precision_recall_fscore_support, confusion_matrix, roc_curve,
    precision_recall_curve
)
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.cluster import DBSCAN
from sklearn.linear_model import LogisticRegression
from sklearn.isotonic import IsotonicRegression
import pyproj
from sklearn.base import clone
import rasterio
from rasterio.enums import Resampling as RioResampling
from rasterio.warp import reproject, transform as rio_transform, transform_bounds as rio_transform_bounds
from rasterio.transform import Affine
from rasterio.windows import Window
from rasterio.crs import CRS

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

try:
    from googleapiclient.errors import HttpError
except Exception:
    HttpError = Exception
try:
    from pydrive2.files import ApiRequestError
except Exception:
    ApiRequestError = Exception

os.environ.setdefault("SHAPE_RESTORE_SHX", "YES")

# ------------------------ CONFIG ------------------------
# ------------------------ CONFIG ------------------------
ROOT_FOLDER_ID = os.environ.get("ROOT_FOLDER_ID", "19Ee1kUar4oEabBObq7GcA-jDR7F29x2_")
BY_COUNTRY_NAME = "By Country"
CLE_FOLDER_ID = "1BQqWjXTEfTjWcf_SICM4h3ICeX9jAI65"
MODEL_FOLDER = "Model Training"
REGIONAL_MODELS_FOLDER = "Regional Models"  # used per country now: region_name == country
COUNTRY_MODELS_FOLDER = "Country Models"
OUTPUT_FOLDER = "CountryModelPredicted_Cropland"

# River Network (continental fallback for distance-to-river)
RIVER_NET_FOLDER_ID = "1ReAmJ809FmU-8jLpyyO6VEdmeqDmbaB7"
RIVER_DIST_PREFIX = "Dist_"
RIVER_DIST_SUFFIX = "_river.tif"

# Country column auto-detection; we copy into this name
COUNTRY_COL = os.environ.get("COUNTRY_COL", "country_hint")

# DBSCAN farm grouping (meters)
DBSCAN_EPS_M = float(os.environ.get("DBSCAN_EPS_M", 45))
DBSCAN_MIN_SAMPLES = int(os.environ.get("DBSCAN_MIN_SAMPLES", 1))

# AEZ (global raster)
AEZ_FILE_ID   = os.environ.get("AEZ_FILE_ID", "1te3nKn8vyt2AECmk8NM_xBgpEFXp84RX")
AEZ_FILE_NAME = os.environ.get("AEZ_FILE_NAME", "AEZ_2020s.tif")
AEZ_COL       = "AEZ"

# Local fallbacks
LOCAL_BASE_DIR = os.path.join("./", "China")
LOCAL_BY_COUNTRY_DIR = os.path.join(LOCAL_BASE_DIR, "By Country")
LOCAL_CLE_DIR = os.path.join(LOCAL_BASE_DIR, "CLE")
LOCAL_MODEL_DIR = os.path.join(LOCAL_BASE_DIR, MODEL_FOLDER)
LOCAL_OUTPUT_DIR = os.path.join(LOCAL_BASE_DIR, OUTPUT_FOLDER)

# GTPS/AEZ cache made earlier
GTPS_CACHE_PARQUET = "ChinaGTPS_per_point.parquet"
GTPS_CACHE_CSV     = "ChinaGTPS_per_point.csv"

# Drive retry knobs
DRIVE_MAX_RETRIES = int(os.environ.get("DRIVE_MAX_RETRIES", 6))
DRIVE_RETRY_BASE = float(os.environ.get("DRIVE_RETRY_BASE", 0.8))

POINTS_BASENAME = "China"  # pooled training points

ALL_PREDICTORS = [
    "NDVI_mean","NDWI_mean","GI_mean",
    "NDVI_max","NDWI_max","GI_max",
    "NDVI_min","NDWI_min","GI_min",
    "elevation","slope","ET","PET","dist_to_river",
    AEZ_COL,
]
TARGET = "irrigated"
LONCOL = "longitude"
LATCOL = "latitude"


# RF + data
RF_TREES = int(os.environ.get("RF_TREES", 400))
RANDOM_SEED = 42
TEST_FRACTION = 0.30

BORROW_MAX = 2000

# ---- Calibration & thresholds ----
CALIBRATION_METHOD = os.environ.get("CALIBRATION_METHOD", "isotonic").lower()  # "isotonic" | "platt"
THRESHOLD_STRATEGY = os.environ.get("THRESHOLD_STRATEGY", "auto").lower()      # "auto"|"j"|"f1"|"fixed"|"precision_at"|"recall_at"
FIXED_THRESHOLD = float(os.environ.get("FIXED_THRESHOLD", 0.50))
PRECISION_TARGET = os.environ.get("PRECISION_TARGET")
RECALL_TARGET = os.environ.get("RECALL_TARGET")
PRECISION_TARGET = None if PRECISION_TARGET in (None, "", "None") else float(PRECISION_TARGET)
RECALL_TARGET    = None if RECALL_TARGET    in (None, "", "None") else float(RECALL_TARGET)

# Per-country threshold mode: "global" | "country" | "hybrid"
THRESHOLD_MODE = os.environ.get("THRESHOLD_MODE", "hybrid").lower()
MIN_POS_NEG_FOR_COUNTRY = int(os.environ.get("MIN_POS_NEG_FOR_COUNTRY", 30))
MIN_TOTAL_FOR_COUNTRY   = int(os.environ.get("MIN_TOTAL_FOR_COUNTRY", 100))

# Prediction
TILE_SIZE = 1024
RESAMPLING = RioResampling.bilinear
PER_BAND_RESAMPLING = {AEZ_COL: RioResampling.nearest}
STRICT_ALL_BANDS = False
MIN_PREDICTORS = 9

# ------------------ UTILS / HELPERS ---------------------
def _meters_per_degree(lat_deg: float):
    m_per_deg_lat = 111_132.0
    m_per_deg_lon = 111_320.0 * math.cos(math.radians(lat_deg))
    return m_per_deg_lat, m_per_deg_lon

def _target_wgs84_grid_from_bounds(west, south, east, north, pixel_m=30.0):
    mid_lat = 0.5 * (south + north)
    m_lat, m_lon = _meters_per_degree(mid_lat)
    dy = pixel_m / m_lat; dx = pixel_m / m_lon
    width = int(math.ceil((east - west) / dx)); height = int(math.ceil((north - south) / dy))
    transform = Affine(dx, 0, west, 0, -dy, north)
    return height, width, transform, CRS.from_epsg(4326)

def _iter_tiles(H, W, tile):
    for r0 in range(0, H, tile):
        for c0 in range(0, W, tile):
            r1 = min(r0 + tile, H); c1 = min(c0 + tile, W)
            yield Window.from_slices((r0, r1), (c0, c1))

def _tile_count(H, W, tile):
    return ((H + tile - 1) // tile) * ((W + tile - 1) // tile)

def _compute_basic_metrics(y_true, y_prob, thr):
    y_pred = (y_prob >= thr).astype(int)
    acc = accuracy_score(y_true, y_pred)
    pr, rc, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
    try: roc = roc_auc_score(y_true, y_prob)
    except Exception: roc = float("nan")
    try: prauc = average_precision_score(y_true, y_prob)
    except Exception: prauc = float("nan")
    cm = confusion_matrix(y_true, y_pred, labels=[0,1])
    return {
        "accuracy": float(acc),
        "precision": float(pr),
        "recall": float(rc),
        "f1": float(f1),
        "roc_auc": float(roc),
        "pr_auc": float(prauc),
        "tn": int(cm[0,0]), "fp": int(cm[0,1]),
        "fn": int(cm[1,0]), "tp": int(cm[1,1]),
    }, cm

def _save_confusion_png(cm, title, path):
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(cm, cmap="Blues")
    ax.set_xticks([0,1]); ax.set_yticks([0,1])
    ax.set_xlabel("Pred"); ax.set_ylabel("True")
    ax.set_title(title)
    for (r,c),v in np.ndenumerate(cm):
        ax.text(c, r, f"{v}", ha="center", va="center")
    fig.savefig(path, dpi=200, bbox_inches="tight")
    plt.close(fig)

def assign_farm_ids(df: pd.DataFrame, lon_col=LONCOL, lat_col=LATCOL) -> pd.Series:
    if df.empty:
        return pd.Series([], dtype="int64")
    proj = pyproj.Transformer.from_crs(4326, 3857, always_xy=True)
    x, y = proj.transform(df[lon_col].values, df[lat_col].values)
    coords = np.column_stack([x, y])
    db = DBSCAN(eps=DBSCAN_EPS_M, min_samples=DBSCAN_MIN_SAMPLES, metric="euclidean", algorithm="ball_tree")
    labels = db.fit_predict(coords)
    if (labels == -1).any():
        max_lab = labels[labels >= 0].max() if np.any(labels >= 0) else -1
        noise_idx = np.where(labels == -1)[0]
        labels[noise_idx] = np.arange(max_lab + 1, max_lab + 1 + len(noise_idx))
    return pd.Series(labels.astype("int64"), index=df.index, name="farm_id")

def _pick_country_candidate(cols):
    cand = ["country_hint","country","Country","COUNTRY","admin","ADMIN","state","STATE","STATE_NAME","State","province","Province"]
    low = {c.lower(): c for c in cols}
    for c in cand:
        if c.lower() in low:
            return low[c.lower()]
    return None

# ---------------------- DRIVE HELPERS -------------------
def _drive_retry(callable_fn, *args, **kwargs):
    last_err = None
    for i in range(DRIVE_MAX_RETRIES):
        try:
            return callable_fn(*args, **kwargs)
        except (HttpError, ApiRequestError) as e:
            last_err = e
            code = getattr(getattr(e, "resp", None), "status", None)
            if code is None:
                msg = str(e).lower()
                transient = any(k in msg for k in ["internal error","backenderror","rate limit","timeout"])
            else:
                transient = 500 <= int(code) < 600 or int(code) in (403, 429)
            if not transient or i == DRIVE_MAX_RETRIES - 1:
                break
            sleep_s = DRIVE_RETRY_BASE * (2 ** i) + random.random() * 0.2
            print(f"[Drive Retry] attempt {i+1}/{DRIVE_MAX_RETRIES} after error {e}; sleeping {sleep_s:.2f}s")
            time.sleep(sleep_s)
        except Exception as e:
            last_err = e
            break
    raise last_err

def _drive_query(drive, q):
    def _run():
        return drive.ListFile({"q": q, "supportsAllDrives": True, "includeItemsFromAllDrives": True, "maxResults": 1000}).GetList()
    return _drive_retry(_run)

def get_subfolder(drive, parent_id, name):
    if drive is None:
        p = os.path.join(LOCAL_BASE_DIR, name)
        return p if os.path.isdir(p) else None
    q = f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder' and title='{name}'"
    r = _drive_query(drive, q)
    return r[0]["id"] if r else None

def get_or_create_folder(drive, parent_id, name):
    if drive is None:
        base = LOCAL_BASE_DIR if not os.path.isabs(parent_id) else parent_id
        p = os.path.join(base, name) if os.path.isdir(base) else os.path.join(LOCAL_BASE_DIR, name)
        os.makedirs(p, exist_ok=True)
        return p
    q = f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder' and title='{name}'"
    res = _drive_query(drive, q)
    if res:
        return res[0]["id"]
    def _create():
        f = drive.CreateFile({"title": name, "parents": [{"id": parent_id}], "mimeType": "application/vnd.google-apps.folder"})
        f.Upload(); return f["id"]
    return _drive_retry(_create)

def list_files(drive, parent_id):
    if drive is None:
        return sorted([os.path.join(parent_id, p) for p in os.listdir(parent_id)]) if os.path.isdir(parent_id) else []
    q = f"'{parent_id}' in parents and trashed=false"
    try: return _drive_query(drive, q)
    except Exception as e:
        print(f"[Drive Warning] list_files failed for parent {parent_id}: {e}")
        return []

def download_to_temp(drive_file, dst_path):
    def _dl():
        drive_file.GetContentFile(dst_path); return dst_path
    return _drive_retry(_dl)

def upload_path(drive, local_path, parent_id, title=None):
    if drive is None:
        os.makedirs(parent_id, exist_ok=True)
        import shutil
        dst = os.path.join(parent_id, title or os.path.basename(local_path))
        shutil.copy2(local_path, dst); return dst
    def _up():
        f = drive.CreateFile({"title": title or os.path.basename(local_path), "parents": [{"id": parent_id}]})
        f.SetContentFile(local_path); f.Upload(); return f["id"]
    return _drive_retry(_up)

def _drive_walk(drive, start_id, max_depth=4):
    q = deque([(start_id, 0)])
    while q:
        fid, d = q.popleft()
        items = list_files(drive, fid)
        yield fid, items
        if d >= max_depth: continue
        for it in items:
            if isinstance(it, dict) and it.get("mimeType") == "application/vnd.google-apps.folder":
                q.append((it["id"], d + 1))

# ---------------------- AEZ SOURCES ----------------------
def open_aez_path(drive):
    """Global AEZ.tif (fallback if per-country tile missing)."""
    if drive is None:
        p = os.path.join(LOCAL_BASE_DIR, "AEZ", AEZ_FILE_NAME)
        return p if os.path.exists(p) else None

    def _resolve_shortcut(file_obj):
        if file_obj.get("mimeType") == "application/vnd.google-apps.shortcut":
            tgt = file_obj.get("shortcutDetails", {}).get("targetId")
            if tgt:
                g = drive.CreateFile({"id": tgt})
                g.FetchMetadata(fields="title,mimeType,shortcutDetails"); return g
        return file_obj

    folder_id = None
    if AEZ_FILE_ID:
        f = drive.CreateFile({"id": AEZ_FILE_ID})
        f.FetchMetadata(fields="title,mimeType,shortcutDetails"); f = _resolve_shortcut(f)
        mime = f.get("mimeType")
        if mime == "application/vnd.google-apps.folder":
            folder_id = f["id"]
        elif mime and mime.startswith("application/vnd.google-apps."):
            raise RuntimeError(f"AEZ_FILE_ID points to a Google Doc ({mime}), not a TIFF.")
        else:
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".tif").name
            f.GetContentFile(tmp); return tmp

    folder_id = folder_id or os.environ.get("AEZ_FOLDER_ID", None)
    if folder_id:
        q = f"'{folder_id}' in parents and trashed=false and title='{AEZ_FILE_NAME}'"
        cand = _drive_query(drive, q)
        if not cand:
            q_any = f"'{folder_id}' in parents and trashed=false and title contains '.tif'"
            cand = _drive_query(drive, q_any)
            cand = [c for c in cand if c.get("title","") == AEZ_FILE_NAME] or cand
        if not cand:
            raise RuntimeError(f"Could not find {AEZ_FILE_NAME} in folder id {folder_id}.")
        f = drive.CreateFile({"id": cand[0]["id"]})
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".tif").name
        f.GetContentFile(tmp); return tmp

    p = os.path.join(LOCAL_BASE_DIR, "AEZ", AEZ_FILE_NAME)
    return p if os.path.exists(p) else None

# ---------------------- POINTS / CACHE -------------------
def _find_points_vector_drive(drive):
    try: mt = get_subfolder(drive, ROOT_FOLDER_ID, MODEL_FOLDER)
    except Exception: mt = None
    if mt:
        items = list_files(drive, mt); parts=[]; has_shp=False
        for it in items:
            if not isinstance(it, dict): continue
            title = it.get("title",""); low = title.lower()
            if not low.startswith("US."): continue
            if low.endswith(".shp"): has_shp=True
            if re.search(r"\.(shp|dbf|shx|prj|cpg|qpj)$", low): parts.append(it)
        if has_shp and parts:
            tdir = tempfile.mkdtemp()
            for p in parts: download_to_temp(p, os.path.join(tdir, p.get("title","")))
            shp_path = os.path.join(tdir, "US.shp")
            if os.path.exists(shp_path):
                print("Using vector points (shp): Model Training/US.shp"); return shp_path
    roots = [mt or ROOT_FOLDER_ID, ROOT_FOLDER_ID]
    for root in roots:
        for fid, items in _drive_walk(drive, root, max_depth=4):
            stems={}
            for it in items:
                if not isinstance(it, dict): continue
                title = it.get("title",""); m = re.match(rf"(.+)\.(shp|dbf|shx|prj|cpg|qpj)$", title, re.IGNORECASE)
                if not m: continue
                stem = m.group(1)
                if stem.lower() == POINTS_BASENAME.lower():
                    stems.setdefault(stem,[]).append(it)
            for stem, parts in stems.items():
                if any(p.get("title","").lower().endswith(".shp") for p in parts):
                    tdir = tempfile.mkdtemp()
                    for p in parts: download_to_temp(p, os.path.join(tdir, p.get("title","")))
                    shp_path = os.path.join(tdir, f"{os.path.basename(stem)}.shp")
                    if os.path.exists(shp_path):
                        print(f"Using vector points (shp): {os.path.basename(stem)}.shp"); return shp_path
            for it in items:
                if not isinstance(it, dict): continue
                title = it.get("title",""); low = title.lower()
                if it.get("mimeType") == "application/vnd.google-apps.folder": continue
                if low.endswith(".gpkg") and low == f"{POINTS_BASENAME.lower()}.gpkg":
                    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".gpkg").name
                    download_to_temp(it, tmp); print(f"Using vector points (gpkg): {title}"); return tmp
                if low.endswith(".zip") and low == f"{POINTS_BASENAME.lower()}.zip":
                    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".zip").name
                    download_to_temp(it, tmp); print(f"Using vector points (zip): {title}"); return tmp
                if low.endswith(".csv") and low == f"{POINTS_BASENAME.lower()}.csv":
                    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv").name
                    download_to_temp(it, tmp); print(f"Using vector points (csv): {title}"); return tmp
    return None

def _find_points_vector_local():
    search_roots = [os.path.join(LOCAL_BASE_DIR, MODEL_FOLDER), LOCAL_BASE_DIR]
    for root in search_roots:
        if not os.path.isdir(root): continue
        for nm in os.listdir(root):
            low = nm.lower()
            if low == f"{POINTS_BASENAME.lower()}.csv":
                print(f"Using vector points (csv): {nm}"); return os.path.join(root, nm)
            if POINTS_BASENAME.lower() in low and low.endswith(".gpkg"):
                print(f"Using vector points (gpkg): {nm}"); return os.path.join(root, nm)
            if POINTS_BASENAME.lower() in low and low.endswith(".zip"):
                print(f"Using vector points (zip): {nm}"); return os.path.join(root, nm)
    for root in search_roots:
        if not os.path.isdir(root): continue
        for nm in os.listdir(root):
            if nm.lower().endswith(".shp") and POINTS_BASENAME.lower() in nm.lower():
                print(f"Using vector points (shp): {nm}"); return os.path.join(root, nm)
    return None

def _ci_lookup(cols, name):
    name_l = name.lower()
    for c in cols:
        if c.lower() == name_l: return c
    return None

def _round6(x):
    return np.round(pd.to_numeric(x, errors="coerce"), 6)

def _load_gtps_cache_df(drive):
    """Load US_GTPS_per_point.(parquet|csv) from Model Training if present."""
    def _from_drive():
        mt = get_or_create_folder(drive, ROOT_FOLDER_ID, MODEL_FOLDER)
        # prefer parquet
        res = _drive_query(drive, f"'{mt}' in parents and trashed=false and title='{GTPS_CACHE_PARQUET}'")
        if res:
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".parquet").name
            download_to_temp(res[0], tmp)
            try:
                return pd.read_parquet(tmp)
            except Exception:
                pass
        res = _drive_query(drive, f"'{mt}' in parents and trashed=false and title='{GTPS_CACHE_CSV}'")
        if res:
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv").name
            download_to_temp(res[0], tmp)
            return pd.read_csv(tmp)
        return None

    def _from_local():
        p1 = os.path.join(LOCAL_MODEL_DIR, GTPS_CACHE_PARQUET)
        p2 = os.path.join(LOCAL_MODEL_DIR, GTPS_CACHE_CSV)
        if os.path.exists(p1):
            try:
                return pd.read_parquet(p1)
            except Exception:
                pass
        if os.path.exists(p2):
            return pd.read_csv(p2)
        return None

    try:
        return _from_drive() if drive is not None else _from_local()
    except Exception as e:
        print(f"   ! AEZ cache read error: {e}")
        return None

def load_points_df(drive):
    import geopandas as gpd, fiona
    path = _find_points_vector_drive(drive) if drive is not None else _find_points_vector_local()
    if path is None:
        raise RuntimeError("Could not locate US.(shp/zip/gpkg/csv) in 'Model Training' tree.")
    if path.lower().endswith(".csv"):
        df = pd.read_csv(path)
    else:
        with fiona.Env(SHAPE_RESTORE_SHX="YES"):
            gdf = gpd.read_file(path)
        if gdf.crs is not None and gdf.crs.to_epsg() != 4326:
            gdf = gdf.to_crs(4326)
        df = pd.DataFrame(gdf.drop(columns="geometry", errors="ignore"))

    lon_col = _ci_lookup(df.columns, LONCOL); lat_col = _ci_lookup(df.columns, LATCOL)
    if lon_col is None or lat_col is None:
        raise RuntimeError("Points must contain longitude and latitude columns.")
    if lon_col != LONCOL: df[LONCOL] = pd.to_numeric(df[lon_col], errors="coerce")
    if lat_col != LATCOL: df[LATCOL] = pd.to_numeric(df[lat_col], errors="coerce")

    tgt_col = _ci_lookup(df.columns, TARGET)
    if tgt_col is None: raise RuntimeError(f"Points file missing column: {TARGET}")
    if tgt_col != TARGET: df[TARGET] = pd.to_numeric(df[tgt_col], errors="coerce")

    # pre-create rounders for cache join
    df["_lonr"] = _round6(df[LONCOL]); df["_latr"] = _round6(df[LATCOL])

    # copy a country-like column
    if COUNTRY_COL not in df.columns:
        cand = _pick_country_candidate(df.columns)
        df[COUNTRY_COL] = df[cand] if cand else "Unknown"

    for c in [TARGET, LONCOL, LATCOL]:
        if c in df.columns: df[c] = pd.to_numeric(df[c], errors="coerce")
    df = df.dropna(subset=[TARGET, LONCOL, LATCOL]).reset_index(drop=True)
    df[TARGET] = df[TARGET].astype(int)
    df[COUNTRY_COL] = df[COUNTRY_COL].astype(str)

    # Bring AEZ from cache if available
    if AEZ_COL not in df.columns or df[AEZ_COL].isna().all():
        cache = _load_gtps_cache_df(drive)
        if cache is not None:
            # normalize cache cols
            cc = {c.lower(): c for c in cache.columns}
            lonc = cc.get(LONCOL.lower(), None) or cc.get("x") or cc.get("lon") or list(cache.columns)[1]
            latc = cc.get(LATCOL.lower(), None) or cc.get("y") or cc.get("lat") or list(cache.columns)[2]
            aezc = cc.get(AEZ_COL.lower(), None) or cc.get("gtps") or cc.get("aez")
            if aezc is not None:
                cache["_lonr"] = _round6(cache[lonc]); cache["_latr"] = _round6(cache[latc])
                df = df.merge(cache[["_lonr","_latr",aezc]].rename(columns={aezc:AEZ_COL}),
                              on=["_lonr","_latr"], how="left")
                if AEZ_COL in df.columns:
                    df[AEZ_COL] = pd.to_numeric(df[AEZ_COL], errors="coerce")
                    df[AEZ_COL] = np.rint(df[AEZ_COL]).astype("float64")
                    print(f"   • AEZ filled from cache for {np.isfinite(df[AEZ_COL]).sum()}/{len(df)} points")
    return df

# ---------------------- RASTERS -------------------------
def open_country_var_path(drive, country, variable):
    """Any By Country/<Country>/<Country>_<variable>.tif (case-sensitive on names passed in)."""
    if drive is None:
        p = os.path.join(LOCAL_BY_COUNTRY_DIR, country, f"{country}_{variable}.tif")
        return p if os.path.exists(p) else None
    byc = get_subfolder(drive, ROOT_FOLDER_ID, BY_COUNTRY_NAME)
    if not byc: return None
    q = f"'{byc}' in parents and trashed=false and title='{country}' and mimeType='application/vnd.google-apps.folder'"
    res = _drive_query(drive, q)
    if not res: return None
    cid = res[0]["id"]
    q2 = f"'{cid}' in parents and trashed=false and title='{country}_{variable}.tif'"
    files = _drive_query(drive, q2)
    if files:
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".tif").name
        return download_to_temp(files[0], tmp)
    for it in list_files(drive, cid):
        if isinstance(it, dict) and re.match(rf"^{re.escape(country)}_{re.escape(variable)}.*\.tif$", it.get("title",""), re.IGNORECASE):
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".tif").name
            return download_to_temp(it, tmp)
    return None

def open_CroplandNE(drive):
    if drive is None:
        p = os.path.join(LOCAL_CLE_DIR, "CroplandSE.tif")
        return p if os.path.exists(p) else None
    items = list_files(drive, CLE_FOLDER_ID)
    for it in items:
        if not isinstance(it, dict): continue
        if it.get("title","").lower().startswith("CroplandSE") and it.get("title","").lower().endswith(".tif"):
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".tif").name
            return download_to_temp(it, tmp)
    return None

def _clip_bounds_from_reference(raster_path):
    with rasterio.open(raster_path) as src:
        b = rio_transform_bounds(src.crs, CRS.from_epsg(4326),
                                 src.bounds.left, src.bounds.bottom, src.bounds.right, src.bounds.top, densify_pts=8)
    return b

def _pick_reference_raster_for_bounds(drive, country):
    # Prefer any available country raster (including AEZ tile) to set the grid
    for v in [AEZ_COL, "CLE"] + ALL_PREDICTORS:
        p = open_country_var_path(drive, country, v) if v != AEZ_COL else open_country_var_path(drive, country, AEZ_COL)
        if v == AEZ_COL and p is None:
            p = open_aez_path(drive)  # fallback
        if p is not None: return p
    raise RuntimeError(f"No reference raster found for bounds in '{country}'.")

def country_grid_and_mask(drive, country):
    ref_path = _pick_reference_raster_for_bounds(drive, country)
    west, south, east, north = _clip_bounds_from_reference(ref_path)
    H, W, transform, crs = _target_wgs84_grid_from_bounds(west, south, east, north, pixel_m=30.0)
    country_cropland = open_country_var_path(drive, country, "CLE")
    if country_cropland is not None:
        cropland = clip_cle_to_grid(drive, H, W, transform, crs, source_path=country_cropland)
    else:
        cropland = np.ones((H, W), dtype=bool)
    return (H, W, transform, crs, cropland, (west, south, east, north))

def clip_cle_to_grid(drive, out_h, out_w, out_transform, out_crs, source_path=None):
    cle = source_path or open_CroplandNE(drive)
    if cle is None: return np.ones((out_h, out_w), dtype=bool)
    arr = np.full((out_h, out_w), -9999.0, dtype=np.float32)
    with rasterio.open(cle) as src:
        reproject(source=rasterio.band(src, 1), destination=arr,
                  src_transform=src.transform, src_crs=src.crs,
                  dst_transform=out_transform, dst_crs=out_crs,
                  dst_nodata=-9999.0, resampling=RioResampling.nearest)
        nd = src.nodata
    if nd is not None: arr = np.where(np.isclose(arr, nd), np.nan, arr)
    arr = np.where(np.isclose(arr, -9999.0), np.nan, arr)
    return np.isfinite(arr) & (arr > 0.5)

# ---------------- THRESHOLDS -----------------------------
def _thr_by_J(y_true, p):
    y = np.asarray(y_true).astype(int)
    if np.unique(y).size < 2: return FIXED_THRESHOLD
    fpr, tpr, th = roc_curve(y, p)
    if len(th) < 2: return FIXED_THRESHOLD
    j = tpr - fpr
    return float(th[int(np.argmax(j))])

def _thr_by_F1(y_true, p):
    y = np.asarray(y_true).astype(int)
    ths = np.linspace(0.01, 0.99, 99); best_t, best_f1 = 0.5, -1
    for t in ths:
        yhat = (p >= t).astype(int)
        _, _, f1, _ = precision_recall_fscore_support(y, yhat, average="binary", zero_division=0)
        if f1 > best_f1: best_f1, best_t = f1, t
    return float(best_t)

def _thr_precision_at(y_true, p, target):
    if target is None: return _thr_by_J(y_true, p)
    prec, rec, thr = precision_recall_curve(y_true, p)
    cand = thr[prec[1:] >= target]
    if len(cand) > 0:
        return float(np.min(cand))  # smallest threshold achieving target precision
    return _thr_by_F1(y_true, p)

def _thr_recall_at(y_true, p, target):
    if target is None: return _thr_by_J(y_true, p)
    prec, rec, thr = precision_recall_curve(y_true, p)
    cand = thr[rec[1:] >= target]
    if len(cand) > 0:
        return float(np.max(cand))  # largest threshold while keeping recall >= target
    return _thr_by_F1(y_true, p)

def _pick_threshold(y_true, p):
    s = THRESHOLD_STRATEGY.lower()
    if s == "fixed":       return float(FIXED_THRESHOLD)
    if s == "f1":          return _thr_by_F1(y_true, p)
    if s == "j":           return _thr_by_J(y_true, p)
    if s == "precision_at":return _thr_precision_at(y_true, p, PRECISION_TARGET)
    if s == "recall_at":   return _thr_recall_at(y_true, p, RECALL_TARGET)
    if PRECISION_TARGET is not None:
        return _thr_precision_at(y_true, p, PRECISION_TARGET)
    if RECALL_TARGET is not None:
        return _thr_recall_at(y_true, p, RECALL_TARGET)
    return _thr_by_J(y_true, p)

# ---------------- Calibration helpers -------------------
def _fit_calibrator(raw_pos_probs, y, method="isotonic"):
    raw_pos_probs = np.asarray(raw_pos_probs, dtype=float).ravel()
    y = np.asarray(y, dtype=int).ravel()
    if method == "platt":
        lr = LogisticRegression(solver="lbfgs", max_iter=1000)
        lr.fit(raw_pos_probs.reshape(-1,1), y)
        return {"method": "platt", "model": lr}
    ir = IsotonicRegression(y_min=0.0, y_max=1.0, out_of_bounds="clip")
    ir.fit(raw_pos_probs, y)
    return {"method": "isotonic", "model": ir}

def _apply_calibrator(calib, raw_pos_probs):
    raw_pos_probs = np.asarray(raw_pos_probs, dtype=float).ravel()
    if calib is None:
        return raw_pos_probs
    if calib.get("method") == "platt":
        return calib["model"].predict_proba(raw_pos_probs.reshape(-1,1))[:,1]
    return calib["model"].predict(raw_pos_probs)

# ------------- POINT SAMPLING FROM RASTERS --------------
def _sample_raster_at_lonlat(src, lon_arr, lat_arr):
    xs = np.asarray(lon_arr, dtype=float); ys = np.asarray(lat_arr, dtype=float)
    if src.crs and (src.crs.to_epsg() != 4326):
        tx, ty = rio_transform(CRS.from_epsg(4326), src.crs, xs.tolist(), ys.tolist())
        xs = np.asarray(tx, dtype=float); ys = np.asarray(ty, dtype=float)
    out = np.full(xs.shape[0], np.nan, dtype=np.float32)
    xmin, ymin, xmax, ymax = src.bounds.left, src.bounds.bottom, src.bounds.right, src.bounds.top
    inside = (xs >= xmin) & (xs <= xmax) & (ys >= ymin) & (ys <= ymax)
    if not np.any(inside): return out
    idx = np.where(inside)[0]
    coords = list(zip(xs[idx], ys[idx]))
    vals = np.array([v[0] for v in src.sample(coords)], dtype=np.float32)
    if src.nodata is not None and np.isfinite(src.nodata):
        vals = np.where(np.isclose(vals, np.float32(src.nodata)), np.float32(np.nan), vals)
    out[idx] = vals; return out

# ---- TRAINING FILL (num predictors + AEZ) ---------------
def _fill_training_predictors_from_country_rasters(drive, df_all, countries):
    df = df_all.copy()

    # AEZ: if still missing after cache merge, try per-country tiles, else global
    if AEZ_COL not in df.columns or df[AEZ_COL].isna().any():
        miss = df.index if AEZ_COL not in df.columns else df.index[df[AEZ_COL].isna()]
        if len(miss) > 0:
            aez_filled = 0
            # try per-country tiles first
            for c in countries:
                try:
                    ref = open_country_var_path(drive, c, AEZ_COL) or open_aez_path(drive)
                    if not ref: continue
                    w, s, e, n = _clip_bounds_from_reference(ref)
                    idx = df.loc[miss].index[
                        df.loc[miss, LONCOL].between(min(w,e), max(w,e)) &
                        df.loc[miss, LATCOL].between(min(s,n), max(s,n))
                    ]
                    if len(idx) == 0: continue
                    with rasterio.open(ref) as aez_src:
                        vals = _sample_raster_at_lonlat(aez_src, df.loc[idx, LONCOL].values, df.loc[idx, LATCOL].values)
                    df.loc[idx, AEZ_COL] = np.rint(pd.to_numeric(vals, errors="coerce")).astype("float64")
                    aez_filled += int(np.isfinite(df.loc[idx, AEZ_COL]).sum())
                except Exception as ex:
                    print(f"   ! AEZ fill skip for {c}: {ex}")
            if aez_filled > 0:
                print(f"   • AEZ sampled from tiles/global for {aez_filled} points")

    # Per-country numeric predictors from rasters
    for c in countries:
        try:
            ref = open_country_var_path(drive, c, "CLE") or open_country_var_path(drive, c, "NDVI_mean")
            if not ref: continue
            w, s, e, n = _clip_bounds_from_reference(ref)
        except Exception as ex:
            print(f"   ! Skip raster fill for {c}: {ex}"); continue
        idx = df.index[df[LONCOL].between(min(w,e), max(w,e)) & df[LATCOL].between(min(s,n), max(s,n))]
        if len(idx) == 0: continue
        raster_cache = {}; filled_any = False
        for p in ALL_PREDICTORS:
            if p == AEZ_COL: continue
            pth = open_country_var_path(drive, c, p)
            if pth is None: continue
            if p not in raster_cache: raster_cache[p] = rasterio.open(pth)
            vals = _sample_raster_at_lonlat(raster_cache[p], df.loc[idx, LONCOL].values, df.loc[idx, LATCOL].values)
            df.loc[idx, p] = pd.to_numeric(vals, errors="coerce").astype("float64")
            filled_any = True
        if filled_any:
            print(f"   • Training fill from rasters for {c}: {len(idx)} points")
    return df

def list_countries_auto(drive):
    names = set()
    if drive is None:
        if os.path.isdir(LOCAL_BY_COUNTRY_DIR):
            for d in os.listdir(LOCAL_BY_COUNTRY_DIR):
                if os.path.isdir(os.path.join(LOCAL_BY_COUNTRY_DIR, d)): names.add(d)
    else:
        byc = get_subfolder(drive, ROOT_FOLDER_ID, BY_COUNTRY_NAME)
        if byc:
            subs = _drive_query(drive, f"'{byc}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder'")
            for it in subs: names.add(it["title"])
    return sorted(names)

# ----------------- TRAIN REGIONAL MODEL -----------------
def _borrow_neighbors(df_all, base_cols, max_points):
    if COUNTRY_COL not in df_all.columns:
        df_all[COUNTRY_COL] = "Unknown"
    keep_cols = list(dict.fromkeys(base_cols))
    pooled = df_all[keep_cols].dropna(subset=[TARGET, LONCOL, LATCOL]).copy()
    if len(pooled) > max_points:
        pooled = pooled.sample(max_points, random_state=RANDOM_SEED)
    if pooled[TARGET].nunique() < 2:
        other = 1 - int(pooled[TARGET].iloc[0])
        extra = df_all[df_all[TARGET] == other]
        if not extra.empty:
            take = min(1000, len(extra))
            pooled = pd.concat([pooled, extra.sample(take, random_state=RANDOM_SEED)[keep_cols]], ignore_index=True)
    if COUNTRY_COL not in pooled.columns:
        pooled[COUNTRY_COL] = "Unknown"
    pooled[COUNTRY_COL] = pooled[COUNTRY_COL].astype(str)
    return pooled.reset_index(drop=True)

def train_region_model(drive, region_name, df_all, countries):
    print(f"=== Train regional model: {region_name} ===")

    # Fill predictors (uses AEZ cache + per-country AEZ tiles)
    df_all = _fill_training_predictors_from_country_rasters(drive, df_all, countries)

    use_preds = [p for p in ALL_PREDICTORS if p in df_all.columns]
    base_cols = [TARGET, LONCOL, LATCOL, COUNTRY_COL] + use_preds
    df = _borrow_neighbors(df_all, base_cols, BORROW_MAX)

    if len(df) < 200 or df[TARGET].nunique() < 2:
        raise RuntimeError(f"Insufficient pooled points for region training (n={len(df)}, classes={df[TARGET].nunique()}).")

    num_features = [f for f in use_preds if f != AEZ_COL]
    cat_features = [AEZ_COL] if AEZ_COL in use_preds else []

    pre = ColumnTransformer(
        transformers=[
            ("num", make_pipeline(SimpleImputer(strategy="median")), num_features),
            ("cat", make_pipeline(SimpleImputer(strategy="most_frequent"),
                                  OneHotEncoder(handle_unknown="ignore", sparse_output=False)), cat_features),
        ],
        remainder="drop",
        verbose_feature_names_out=False,
    )

    rf = RandomForestClassifier(
        n_estimators=RF_TREES, max_features="sqrt", class_weight="balanced_subsample",
        n_jobs=-1, random_state=RANDOM_SEED, bootstrap=True,
    )
    pipe = Pipeline([("pre", pre), ("rf", rf)])

    # ---------- Group-aware OOF to fit calibrator ----------
    df["farm_id"] = assign_farm_ids(df, lon_col=LONCOL, lat_col=LATCOL)
    n_splits = 5 if 0.19 <= TEST_FRACTION <= 0.21 else max(3, int(round(1.0 / max(TEST_FRACTION, 1e-3))))
    sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_SEED)

    X_all = df[use_preds].copy()
    y_all = df[TARGET].astype(int).values
    g_all = df["farm_id"].values
    countries_all = df[COUNTRY_COL].astype(str).values
    lon_all = df[LONCOL].values
    lat_all = df[LATCOL].values

    # enforce dtypes
    for col in [f for f in use_preds if f != AEZ_COL]:
        X_all[col] = pd.to_numeric(X_all[col], errors="coerce").astype("float64")
    if AEZ_COL in use_preds:
        X_all[AEZ_COL] = pd.to_numeric(X_all[AEZ_COL], errors="coerce")
        X_all[AEZ_COL] = np.rint(X_all[AEZ_COL]).astype("float64")

    oof_raw = np.full(len(X_all), np.nan, dtype=float)
    classes_ = None
    for tr_idx, te_idx in sgkf.split(X_all, y_all, groups=g_all):
        tr_X = X_all.iloc[tr_idx]; tr_y = y_all[tr_idx]
        te_X = X_all.iloc[te_idx]
        pipe_fold = clone(pipe)
        pipe_fold.fit(tr_X, tr_y.astype(int))
        if classes_ is None:
            classes_ = pipe_fold.named_steps["rf"].classes_
        pos_idx_fold = int(np.where(classes_ == 1)[0][0])
        oof_raw[te_idx] = pipe_fold.predict_proba(te_X)[:, pos_idx_fold]

    # fit final model on all data
    pipe.fit(X_all, y_all.astype(int))
    classes_ = pipe.named_steps["rf"].classes_
    pos_idx = int(np.where(classes_ == 1)[0][0])

    # calibrate on OOF
    calib = _fit_calibrator(oof_raw, y_all, CALIBRATION_METHOD)
    oof_cal = _apply_calibrator(calib, oof_raw)

    # ---------- GLOBAL threshold from OOF calibrated ----------
    global_thr = _pick_threshold(y_all, oof_cal)

    # ---------- Regional OOF metrics & confusion ----------
    reg_metrics, reg_cm = _compute_basic_metrics(y_all, oof_cal, global_thr)
    reg_metrics.update({
        "region": region_name,
        "n_oof": int(len(y_all)),
        "threshold_strategy": THRESHOLD_STRATEGY,
        "global_threshold": float(global_thr),
        "calibration": CALIBRATION_METHOD,
    })

    # feature importances from final model
    out_names = list(pipe.named_steps["pre"].get_feature_names_out())
    rf_imp = pipe.named_steps["rf"].feature_importances_
    fi_df = pd.DataFrame({"feature": out_names, "importance": rf_imp}).sort_values("importance", ascending=False)

    # ---------- Prep folders ----------
    if drive is None:
        reg_base = os.path.join(LOCAL_MODEL_DIR, REGIONAL_MODELS_FOLDER, region_name)
        os.makedirs(reg_base, exist_ok=True)
        c_base_root = os.path.join(LOCAL_MODEL_DIR, COUNTRY_MODELS_FOLDER)
        os.makedirs(c_base_root, exist_ok=True)
    else:
        mt = get_or_create_folder(drive, ROOT_FOLDER_ID, MODEL_FOLDER)
        reg = get_or_create_folder(drive, mt, REGIONAL_MODELS_FOLDER)
        reg_base = get_or_create_folder(drive, reg, region_name)
        c_base_root = get_or_create_folder(drive, mt, COUNTRY_MODELS_FOLDER)

    # ---------- Ensure country tags (bbox fallback) ----------
    assigned = countries_all.copy()
    unknown_mask = (assigned == "Unknown") | (assigned == "") | pd.isna(assigned)
    if unknown_mask.any():
        for ctry in countries:
            try:
                ref = open_country_var_path(drive, ctry, "CLE") or open_country_var_path(drive, ctry, AEZ_COL) or open_aez_path(drive)
                if not ref: continue
                w, s, e, n = _clip_bounds_from_reference(ref)
                inside = unknown_mask & (lon_all >= min(w, e)) & (lon_all <= max(w, e)) \
                                   & (lat_all >= min(s, n)) & (lat_all <= max(s, n))
                assigned[inside] = ctry
            except Exception:
                pass
        countries_all = assigned

    # ---------- Per-country thresholds & OOF metrics ----------
    oof_df = pd.DataFrame({
        "country": countries_all.astype(str),
        "y": y_all.astype(int),
        "prob_cal": oof_cal.astype(float),
    })

    per_rows = []
    per_thr = {}

    for ctry in sorted(set(countries)):
        sub = oof_df[oof_df["country"] == ctry]
        n_total = int(len(sub))
        n_pos = int((sub["y"] == 1).sum())
        n_neg = n_total - n_pos

        if (n_total < MIN_TOTAL_FOR_COUNTRY) or (n_pos < MIN_POS_NEG_FOR_COUNTRY) or (n_neg < MIN_POS_NEG_FOR_COUNTRY):
            thr_c = float(global_thr); note = "fallback_global"
        else:
            thr_c = _pick_threshold(sub["y"].values, sub["prob_cal"].values); note = "oof_calibrated"

        per_thr[ctry] = float(thr_c)
        m_c, cm_c = _compute_basic_metrics(sub["y"].values, sub["prob_cal"].values, thr_c)
        m_c.update({
            "region": region_name, "country": ctry, "n_oof": n_total,
            "n_pos": n_pos, "n_neg": n_neg, "threshold_used": float(thr_c),
            "note": note
        })
        per_rows.append(m_c)

        # save country artifacts (metrics + confusion + threshold.json)
        cdir = get_or_create_folder(drive, c_base_root, ctry)
        with tempfile.TemporaryDirectory() as tdir:
            with open(os.path.join(tdir, "test_metrics.json"), "w") as f:
                json.dump(m_c, f, indent=2)
            upload_path(drive, os.path.join(tdir, "test_metrics.json"), cdir, "test_metrics.json")

            png = os.path.join(tdir, "test_confusion.png")
            _save_confusion_png(cm_c, f"{ctry} — OOF Confusion (thr={thr_c:.3f})", png)
            upload_path(drive, png, cdir, "test_confusion.png")

            with open(os.path.join(tdir, "threshold.json"), "w") as f:
                json.dump({
                    "threshold_strategy": THRESHOLD_STRATEGY,
                    "mode": THRESHOLD_MODE,
                    "calibration": CALIBRATION_METHOD,
                    "threshold": float(thr_c),
                    "n_total": n_total, "n_pos": n_pos, "n_neg": n_neg,
                    "note": note
                }, f, indent=2)
            upload_path(drive, os.path.join(tdir, "threshold.json"), cdir, "threshold.json")

    # regional CSV summary for per-country
    if len(per_rows) > 0:
        per_df = pd.DataFrame(per_rows).sort_values("country")
    else:
        per_df = pd.DataFrame(columns=["country","n_oof","precision","recall","f1","roc_auc","pr_auc",
                                       "tn","fp","fn","tp","threshold_used","note"])

    # ---------- Save regional artifacts ----------
    with tempfile.TemporaryDirectory() as tdir:
        with open(os.path.join(tdir, "metrics.json"), "w") as f:
            json.dump(reg_metrics, f, indent=2)
        upload_path(drive, os.path.join(tdir, "metrics.json"), reg_base, "metrics.json")

        fi_df.to_csv(os.path.join(tdir, "feature_importance.csv"), index=False)
        upload_path(drive, os.path.join(tdir, "feature_importance.csv"), reg_base, "feature_importance.csv")

        per_df.to_csv(os.path.join(tdir, "per_country_test_metrics.csv"), index=False)
        upload_path(drive, os.path.join(tdir, "per_country_test_metrics.csv"), reg_base, "per_country_test_metrics.csv")

        png = os.path.join(tdir, "confusion.png")
        _save_confusion_png(reg_cm, f"{region_name} — OOF Confusion ({THRESHOLD_STRATEGY}, thr={global_thr:.3f})", png)
        upload_path(drive, png, reg_base, "confusion.png")

        joblib.dump(
            {
                "pipeline": pipe,
                "predictors": use_preds,
                "positive_class_index": int(pos_idx),
                "calibration": calib,
                "global_threshold": float(global_thr),
                "threshold_strategy": THRESHOLD_STRATEGY,
                "per_country_thresholds": per_thr,
                "num_features": [f for f in use_preds if f != AEZ_COL],
                "cat_features": [AEZ_COL] if AEZ_COL in use_preds else [],
            },
            os.path.join(tdir, "model.joblib"),
        )
        upload_path(drive, os.path.join(tdir, "model.joblib"), reg_base, "model.joblib")

    model = {
        "pipe": pipe,
        "use_preds": use_preds,
        "pos_idx": int(pos_idx),
        "calib": calib,
        "global_thr": float(global_thr),
        "per_thr": per_thr,
        "reg_base": reg_base
    }
    return model


# ----------------- PREDICT ONE COUNTRY ------------------
def predict_country(drive, country, model, thr, grid_pack):
    out_dir = get_or_create_folder(drive, ROOT_FOLDER_ID, OUTPUT_FOLDER) if drive else LOCAL_OUTPUT_DIR
    os.makedirs(out_dir, exist_ok=True)

    pipe   = model["pipe"]
    pos_idx= model["pos_idx"]
    calib  = model["calib"]

    H, W, transform, crs, cropland = grid_pack
    raster_cache = {}

    # AEZ: prefer per-country tile, fallback to global AEZ
    aez_src_path = open_country_var_path(drive, country, AEZ_COL) or open_aez_path(drive)
    if AEZ_COL in model["use_preds"] and aez_src_path is not None:
        raster_cache[AEZ_COL] = rasterio.open(aez_src_path)

    for v in model["use_preds"]:
        if v == AEZ_COL: continue
        p = open_country_var_path(drive, country, v)
        if p is not None: raster_cache[v] = rasterio.open(p)

    base_profile = {
        "driver": "GTiff","height": H,"width": W,"count": 1,"crs": crs,"transform": transform,
        "compress": "LZW","tiled": True,"blockxsize": 512,"blockysize": 512,
    }
    prob_profile = {**base_profile, "dtype": "float32", "nodata": -9999.0}
    bin_profile  = {**base_profile, "dtype": "uint8", "nodata": 255}
    cnt_profile  = {**base_profile, "dtype": "uint8", "nodata": 0}

    prob_name = f"{country}_RF_probability_percent.tif"
    bin_name  = f"{country}_RF_binary_0_1_cropland.tif"
    cnt_name  = f"{country}_RF_predictors_count.tif"

    eligible_total = 0; positive_total = 0

    with tempfile.TemporaryDirectory() as tdir:
        prob_tmp = os.path.join(tdir, prob_name)
        bin_tmp  = os.path.join(tdir, bin_name)
        cnt_tmp  = os.path.join(tdir, cnt_name)
        with rasterio.open(prob_tmp, "w", **prob_profile) as dst_prob, \
             rasterio.open(bin_tmp,  "w", **bin_profile)  as dst_bin,  \
             rasterio.open(cnt_tmp,  "w", **cnt_profile) as dst_cnt:

            for win in tqdm(_iter_tiles(H, W, TILE_SIZE), total=_tile_count(H, W, TILE_SIZE), desc=f"Predict {country}"):
                r0, r1 = int(win.row_off), int(win.row_off + win.height)
                c0, c1 = int(win.col_off), int(win.col_off + win.width)
                tile_h, tile_w = (r1 - r0), (c1 - c0)
                tile_transform = transform * Affine.translation(c0, r0)

                crop_tile = cropland[r0:r1, c0:c1]
                if not crop_tile.any():
                    dst_prob.write(np.full((1, tile_h, tile_w), prob_profile["nodata"], np.float32), window=win)
                    dst_bin.write (np.full((1, tile_h, tile_w), bin_profile["nodata"], np.uint8), window=win)
                    dst_cnt.write (np.zeros((1, tile_h, tile_w), np.uint8), window=win)
                    continue

                stack = np.full((len(model["use_preds"]), tile_h, tile_w), np.nan, dtype=np.float32)
                valid_count = np.zeros((tile_h, tile_w), dtype=np.uint8)

                for j, v in enumerate(model["use_preds"]):
                    src = raster_cache.get(v)
                    if src is None:
                        p = open_country_var_path(drive, country, v) if v != AEZ_COL else (open_country_var_path(drive, country, AEZ_COL) or open_aez_path(drive))
                        if p: src = raster_cache[v] = rasterio.open(p)
                        else: continue
                    dst_arr = np.full((tile_h, tile_w), np.nan, dtype=np.float32)
                    dst_nd  = -9999.0
                    resamp  = PER_BAND_RESAMPLING.get(v, RESAMPLING)
                    try:
                        reproject(source=rasterio.band(src, 1), destination=dst_arr,
                                  src_transform=src.transform, src_crs=src.crs,
                                  dst_transform=tile_transform, dst_crs=crs,
                                  dst_nodata=dst_nd, resampling=resamp)
                    except Exception as e:
                        print(f"      ! reprojection failed for {v}: {e}"); continue
                    if src.nodata is not None:
                        dst_arr = np.where(np.isclose(dst_arr, src.nodata), np.nan, dst_arr)
                    dst_arr = np.where(np.isclose(dst_arr, dst_nd), np.nan, dst_arr)
                    stack[j] = dst_arr; valid_count += np.isfinite(dst_arr)

                req = (len(model["use_preds"]) if STRICT_ALL_BANDS else min(MIN_PREDICTORS, len(model["use_preds"])))
                good = crop_tile & (valid_count >= req)
                prob_tile = np.full((tile_h, tile_w), np.nan, dtype=np.float32)

                if good.any():
                    X = stack.reshape(len(model["use_preds"]), -1).T
                    good_flat = good.reshape(-1)
                    X_good = X[good_flat].astype("float64")
                    X_good_df = pd.DataFrame(X_good, columns=model["use_preds"])

                    # enforce dtypes
                    num_features = [f for f in model["use_preds"] if f != AEZ_COL]
                    for col in num_features:
                        X_good_df[col] = pd.to_numeric(X_good_df[col], errors="coerce").astype("float64")
                    if AEZ_COL in X_good_df.columns:
                        X_good_df[AEZ_COL] = pd.to_numeric(X_good_df[AEZ_COL], errors="coerce")
                        X_good_df[AEZ_COL] = np.rint(X_good_df[AEZ_COL]).astype("float64")

                    prob_raw = pipe.predict_proba(X_good_df)[:, pos_idx].astype(np.float32)
                    prob_vals = _apply_calibrator(calib, prob_raw).astype(np.float32)

                    bin_vals  = (prob_vals >= thr).astype(np.uint8)
                    prob_tile.reshape(-1)[good_flat] = prob_vals

                    mask_arr = np.isfinite(prob_tile)
                    prob_out = np.full((1, tile_h, tile_w), prob_profile["nodata"], np.float32)
                    prob_out[0][mask_arr] = (prob_tile[mask_arr] * 100.0).astype(np.float32)
                    dst_prob.write(prob_out, window=win)

                    bin_out = np.full((1, tile_h, tile_w), bin_profile["nodata"], np.uint8)
                    bin_out[0].reshape(-1)[good_flat] = bin_vals
                    dst_bin.write(bin_out, window=win)

                    eligible_total += int(mask_arr.sum()); positive_total += int(bin_vals.sum())
                else:
                    dst_prob.write(np.full((1, tile_h, tile_w), prob_profile["nodata"], np.float32), window=win)
                    dst_bin.write (np.full((1, tile_h, tile_w), bin_profile["nodata"], np.uint8), window=win)

                dst_cnt.write(valid_count[np.newaxis], window=win)

        upload_path(drive, prob_tmp, out_dir, prob_name)
        upload_path(drive, bin_tmp,  out_dir, bin_name)
        upload_path(drive, cnt_tmp,  out_dir, cnt_name)

    frac = (positive_total / max(1, eligible_total)) * 100.0
    print(f"   • predicted irrigated fraction inside cropland: {frac:.3f}% ({positive_total:,}/{eligible_total:,})")

# ----------------------------- RUNNER -------------------
def run_region_then_countries(use_drive=True, countries=None, region_name="US"):
    if use_drive:
        try:
            drive  # noqa: F821
            _drive = drive
            try: _drive.auth.service.http.timeout = 120
            except Exception: pass
        except NameError:
            raise RuntimeError("PyDrive 'drive' not found. Authenticate and expose 'drive', or set use_drive=False.")
    else:
        _drive = None
        os.makedirs(LOCAL_OUTPUT_DIR, exist_ok=True)

    df_all = load_points_df(_drive)
    for c in [TARGET, LONCOL, LATCOL] + [p for p in ALL_PREDICTORS if p in df_all.columns and p != AEZ_COL]:
        if c in df_all.columns:
            df_all[c] = pd.to_numeric(df_all[c], errors="coerce")
    df_all = df_all.dropna(subset=[TARGET, LONCOL, LATCOL]).reset_index(drop=True)
    df_all[TARGET] = df_all[TARGET].astype(int)

    if not countries:
        countries = [
            "China_Heilongjiang",
            # "Sudan","Chad", "Niger"
            # "Florida","Georgia","Arkansas","Kentucky","Louisiana",
            #          "Mississippi","South Carolina","North Carolina"
 
        ]

    model = train_region_model(_drive, region_name, df_all, countries)

    # Predict
    print(f"=== PREDICTION per-country (threshold mode: {THRESHOLD_MODE}) ===")
    for c in countries:
        try:
            H, W, transform, crs, cropland, _ = country_grid_and_mask(_drive, c)
            thr_c = model["per_thr"].get(c, model["global_thr"]) if THRESHOLD_MODE in ("country","hybrid") else model["global_thr"]
            predict_country(_drive, c, model, thr_c, (H, W, transform, crs, cropland))
            print(f"✓ {c}: maps → {OUTPUT_FOLDER}, model → {MODEL_FOLDER}/{REGIONAL_MODELS_FOLDER}/{region_name}, thr={thr_c:.3f}")
        except Exception as e:
            print(f"❌ {c}: {e}")

    print("✅ Done.")

if __name__ == "__main__":
    run_region_then_countries(use_drive=True, countries=None, region_name="Algeria")


Using vector points (shp): China.shp
=== Train regional model: Algeria ===
   • AEZ sampled from tiles/global for 8465 points
   • Training fill from rasters for China_Heilongjiang: 8465 points
=== PREDICTION per-country (threshold mode: hybrid) ===


Predict China_Heilongjiang:   0%|          | 0/1221 [00:00<?, ?it/s]

   • predicted irrigated fraction inside cropland: 30.787% (61,209,471/198,819,062)
✓ China_Heilongjiang: maps → CountryModelPredicted_Cropland, model → Model Training/Regional Models/Algeria, thr=0.619
✅ Done.
