In [None]:
import os

from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive

# Path to the OAuth client-secrets JSON for the Google Drive API.
CLIENT_JSON = "client_secret_Final.json"
# Where the OAuth token is cached between runs.
TOKEN_FILE = "token.json"

gauth = GoogleAuth()
gauth.LoadClientConfigFile(CLIENT_JSON)

# Reuse the cached token when there is one, so a kernel restart does not
# force another interactive command-line login.
if os.path.exists(TOKEN_FILE):
    try:
        gauth.LoadCredentialsFile(TOKEN_FILE)
    except Exception as exc:
        print(f"[Drive Auth] could not load {TOKEN_FILE}: {exc}")
        gauth.credentials = None

if gauth.credentials is None:
    # No usable cached token: interactive command-line auth.
    gauth.CommandLineAuth()
elif gauth.access_token_expired:
    try:
        gauth.Refresh()
    except Exception as exc:
        # e.g. the cached token carries no refresh token
        print(f"[Drive Auth] token refresh failed ({exc}); re-authenticating")
        gauth.CommandLineAuth()
else:
    gauth.Authorize()

# Save (or refresh) the cached token for the next run.
gauth.SaveCredentialsFile(TOKEN_FILE)

drive = GoogleDrive(gauth)
print("✅ Google Drive connected")

✅ Google Drive connected


In [None]:
import os

# Size every native thread pool to the Slurm CPU allocation for maximum
# efficiency. SLURM_CPUS_PER_TASK is used when Slurm exports it; otherwise the
# previous fixed value of "64" is kept.
N_THREADS = os.environ.get("SLURM_CPUS_PER_TASK") or "64"
for _thread_var in (
    "OMP_NUM_THREADS",
    "MKL_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
    "GDAL_NUM_THREADS",
    "RASTERIO_NUM_THREADS",
):
    os.environ[_thread_var] = N_THREADS


In [None]:
# Show every non-trashed folder that sits directly under the Drive root.
folder_list = drive.ListFile(
    {
        'q': (
            "'root' in parents"
            " and mimeType='application/vnd.google-apps.folder'"
            " and trashed=false"
        )
    }
).GetList()

for folder in folder_list:
    title, folder_id = folder['title'], folder['id']
    print(f"Folder: {title}  |  ID: {folder_id}")


In [None]:
# prep_aez_cache_and_tiles.py
# Pre-compute AEZ ("GTPS") for:
#   (A) per-state tiles:   By Country/<State>/<State>_AEZ.tif
#   (B) training points:   Model Training/US_with_AEZ.parquet (+ .csv)
#
# Run:  python prep_aez_cache_and_tiles.py
#       (with PyDrive 'drive' in scope if using Drive; otherwise local)

import os, re, json, math, tempfile, random, time
import numpy as np
import pandas as pd
from collections import deque
import rasterio
from rasterio.enums import Resampling as RioResampling
from rasterio.warp import reproject, transform as rio_transform, transform_bounds as rio_transform_bounds
from rasterio.transform import Affine
from rasterio.crs import CRS

try:
    from googleapiclient.errors import HttpError
except Exception:
    HttpError = Exception
try:
    from pydrive2.files import ApiRequestError
except Exception:
    ApiRequestError = Exception

# ---------- reuse your config ----------
# Drive folder id of the project root (env override: ROOT_FOLDER_ID).
ROOT_FOLDER_ID = os.environ.get("ROOT_FOLDER_ID", "1hqMIyDYEFKnpS8KLxC4bqHmF_9dHXImG")
# Folder names that are looked up under the root.
BY_COUNTRY_NAME = "By Country"
MODEL_FOLDER = "Model Training"
# Local mirror of the same layout, used when the helpers receive drive=None.
LOCAL_BASE_DIR = os.path.join("./", "US")
LOCAL_BY_COUNTRY_DIR = os.path.join(LOCAL_BASE_DIR, "By Country")
LOCAL_MODEL_DIR = os.path.join(LOCAL_BASE_DIR, MODEL_FOLDER)

# Drive id and file name of the AEZ raster (both overridable via the environment).
# open_aez_path() accepts the id of either the file itself or a folder containing it.
AEZ_FILE_ID   = os.environ.get("AEZ_FILE_ID", "1te3nKn8vyt2AECmk8NM_xBgpEFXp84RX")
AEZ_FILE_NAME = os.environ.get("AEZ_FILE_NAME", "AEZ_2020s.tif")

# Base name (without extension) of the training-points file, and the column
# names used for coordinates and for the sampled AEZ value.
POINTS_BASENAME = "US"
LONCOL = "longitude"; LATCOL = "latitude"
AEZ_COL = "AEZ"

# Retry knobs for _drive_retry: attempt count and base delay (seconds) of the
# exponential back-off.
DRIVE_MAX_RETRIES = int(os.environ.get("DRIVE_MAX_RETRIES", 6))
DRIVE_RETRY_BASE  = float(os.environ.get("DRIVE_RETRY_BASE", 0.8))

# ---------- Drive helpers (same behavior as your main file) ----------
def _drive_retry(callable_fn, *args, **kwargs):
    """Call ``callable_fn(*args, **kwargs)``, retrying transient Drive errors.

    ``HttpError`` / ``ApiRequestError`` are treated as transient when the HTTP
    status (``e.resp.status``) is 5xx, 403 or 429, or, when no status is
    available, when the message mentions an internal/backend error, a rate
    limit or a timeout. Transient errors are retried with exponential back-off
    (``DRIVE_RETRY_BASE * 2**attempt`` plus up to 0.2 s of jitter).

    Returns:
        Whatever ``callable_fn`` returns.

    Raises:
        The last exception seen: immediately for non-transient or non-Drive
        errors, or after the final attempt.
    """
    # Always make at least one attempt: a non-positive DRIVE_MAX_RETRIES would
    # otherwise skip the loop and end in ``raise None`` (a TypeError).
    attempts = max(1, DRIVE_MAX_RETRIES)
    last_err = None
    for i in range(attempts):
        try:
            return callable_fn(*args, **kwargs)
        except (HttpError, ApiRequestError) as e:
            last_err = e
            code = getattr(getattr(e, "resp", None), "status", None)
            if code is None:
                msg = str(e).lower()
                transient = any(k in msg for k in ("internal error", "backenderror", "rate limit", "timeout"))
            else:
                transient = 500 <= int(code) < 600 or int(code) in (403, 429)
            if not transient or i == attempts - 1:
                break
            time.sleep(DRIVE_RETRY_BASE * (2 ** i) + random.random() * 0.2)
        except Exception as e:
            # Anything that is not a Drive API error is never retried.
            last_err = e
            break
    raise last_err

def _drive_query(drive, q):
    """Run Drive search query ``q`` via ``_drive_retry`` and return the ``GetList()`` result."""
    params = {"q": q, "supportsAllDrives": True, "includeItemsFromAllDrives": True, "maxResults": 1000}
    # A fresh copy of the parameters is handed to ListFile on every attempt.
    return _drive_retry(lambda: drive.ListFile(dict(params)).GetList())

def get_subfolder(drive, parent_id, name):
    """Return the child folder called ``name``, or None when it does not exist.

    Local mode (``drive is None``): returns ``LOCAL_BASE_DIR/name`` when that
    directory exists; ``parent_id`` is not consulted.
    Drive mode: returns the id of the first non-trashed folder titled ``name``
    directly under ``parent_id``.
    """
    if drive is None:
        p = os.path.join(LOCAL_BASE_DIR, name)
        return p if os.path.isdir(p) else None
    # Drive query literals are single-quoted: escape backslashes and quotes so
    # a name such as "Cote d'Ivoire" cannot break (or alter) the query.
    safe_name = name.replace("\\", "\\\\").replace("'", "\\'")
    q = (
        f"'{parent_id}' in parents and trashed=false "
        f"and mimeType='application/vnd.google-apps.folder' and title='{safe_name}'"
    )
    r = _drive_query(drive, q)
    return r[0]["id"] if r else None

def get_or_create_folder(drive, parent_id, name):
    """Return the folder ``name`` under ``parent_id``, creating it when missing.

    Local mode (``drive is None``): creates and returns a directory path.
    Drive mode: returns the id of the existing folder, or of a newly created one.
    """
    if drive is None:
        # NOTE(review): a *relative* parent_id is replaced by LOCAL_BASE_DIR here,
        # so only absolute parents are honoured. Confirm this is intended.
        base = LOCAL_BASE_DIR if not os.path.isabs(parent_id) else parent_id
        p = os.path.join(base, name) if os.path.isdir(base) else os.path.join(LOCAL_BASE_DIR, name)
        os.makedirs(p, exist_ok=True)
        return p
    # Escape backslashes and single quotes: the title is embedded in a
    # single-quoted Drive query literal.
    safe_name = name.replace("\\", "\\\\").replace("'", "\\'")
    q = (
        f"'{parent_id}' in parents and trashed=false "
        f"and mimeType='application/vnd.google-apps.folder' and title='{safe_name}'"
    )
    res = _drive_query(drive, q)
    if res:
        return res[0]["id"]

    def _create():
        # The real (unescaped) name is used for the folder title.
        f = drive.CreateFile({"title": name, "parents": [{"id": parent_id}], "mimeType": "application/vnd.google-apps.folder"})
        f.Upload()
        return f["id"]

    return _drive_retry(_create)

def list_files(drive, parent_id):
    """List the direct children of ``parent_id``.

    Local mode (``drive is None``): sorted list of child paths, or ``[]`` when
    ``parent_id`` is not a directory.
    Drive mode: list of non-trashed Drive file objects. A failed listing is
    reported with a warning and yields ``[]`` (best effort), so callers that
    walk a tree keep going.
    """
    if drive is None:
        if not os.path.isdir(parent_id):
            return []
        return sorted(os.path.join(parent_id, p) for p in os.listdir(parent_id))
    q = f"'{parent_id}' in parents and trashed=false"
    try:
        return _drive_query(drive, q)
    except Exception as e:
        # Still best effort, but no longer silent: an empty listing caused by
        # an API failure must be distinguishable from an empty folder.
        print(f"[Drive Warning] list_files failed for parent {parent_id}: {e}")
        return []

def download_to_temp(drive_file, dst_path):
    """Download the content of ``drive_file`` to ``dst_path`` (with retries) and return ``dst_path``."""
    def _fetch_content():
        drive_file.GetContentFile(dst_path)
        return dst_path

    saved_to = _drive_retry(_fetch_content)
    return saved_to

def upload_path(drive, local_path, parent_id, title=None):
    """Store ``local_path`` under ``parent_id`` as ``title`` (default: its base name).

    Local mode (``drive is None``): ``parent_id`` is a directory, created on
    demand; the file is copied with ``shutil.copy2`` and the destination path
    is returned.
    Drive mode: a new Drive file is created under folder ``parent_id`` and its
    id is returned.
    """
    if drive is not None:
        def _push():
            remote = drive.CreateFile(
                {"title": title or os.path.basename(local_path), "parents": [{"id": parent_id}]}
            )
            remote.SetContentFile(local_path)
            remote.Upload()
            return remote["id"]

        return _drive_retry(_push)

    os.makedirs(parent_id, exist_ok=True)
    import shutil
    destination = os.path.join(parent_id, title or os.path.basename(local_path))
    shutil.copy2(local_path, destination)
    return destination

def _drive_walk(drive, start_id, max_depth=4):
    """Breadth-first walk yielding ``(folder_id, children)`` for each folder.

    Starts at ``start_id`` (depth 0) and queues the sub-folders of every folder
    whose depth is below ``max_depth``. Only dict-like children with the Drive
    folder MIME type are descended into.
    """
    folder_mime = "application/vnd.google-apps.folder"
    pending = deque([(start_id, 0)])
    while pending:
        folder_id, depth = pending.popleft()
        children = list_files(drive, folder_id)
        yield folder_id, children
        if depth < max_depth:
            pending.extend(
                (child["id"], depth + 1)
                for child in children
                if isinstance(child, dict) and child.get("mimeType") == folder_mime
            )

# ---------- AEZ and points IO ----------
def open_aez_path(drive):
    """Return a local path to the global AEZ GeoTIFF, or None when unavailable.

    Local mode (``drive is None``): ``LOCAL_BASE_DIR/AEZ/AEZ_FILE_NAME`` if it exists.

    Drive mode:
      1. ``AEZ_FILE_ID`` (shortcuts are resolved) may be the raster itself, which
         is downloaded, or a folder, which is searched.
      2. Otherwise the folder from env ``AEZ_FOLDER_ID`` is searched for
         ``AEZ_FILE_NAME``, falling back to any title containing ``.tif``.
      3. With no folder to search, the local path is tried.

    A downloaded raster is remembered for the current configuration, so
    repeated calls (one per state in ``write_state_aez_tif``) reuse a single
    temp file instead of downloading the raster again each time.

    Raises:
        RuntimeError: ``AEZ_FILE_ID`` is a Google-native document, or the
            folder holds no matching raster.
    """
    local_path = os.path.join(LOCAL_BASE_DIR, "AEZ", AEZ_FILE_NAME)
    if drive is None:
        return local_path if os.path.exists(local_path) else None

    # Download cache, keyed by everything that selects the raster.
    cache = open_aez_path.__dict__.setdefault("_downloads", {})
    cache_key = (AEZ_FILE_ID, os.environ.get("AEZ_FOLDER_ID", None), AEZ_FILE_NAME)
    cached = cache.get(cache_key)
    if cached and os.path.exists(cached):
        return cached

    def _resolve_shortcut(file_obj):
        # Follow a Drive shortcut to its target; anything else is returned as is.
        if file_obj.get("mimeType") == "application/vnd.google-apps.shortcut":
            tgt = file_obj.get("shortcutDetails", {}).get("targetId")
            if tgt:
                g = drive.CreateFile({"id": tgt})
                g.FetchMetadata(fields="title,mimeType,shortcutDetails")
                return g
        return file_obj

    def _download(file_obj):
        # Download through the shared retry helper; do not leave a partial
        # temp file behind when the download ultimately fails.
        handle = tempfile.NamedTemporaryFile(delete=False, suffix=".tif")
        handle.close()
        try:
            path = download_to_temp(file_obj, handle.name)
        except Exception:
            try:
                os.remove(handle.name)
            except OSError:
                pass
            raise
        cache[cache_key] = path
        return path

    folder_id = None
    if AEZ_FILE_ID:
        f = drive.CreateFile({"id": AEZ_FILE_ID})
        f.FetchMetadata(fields="title,mimeType,shortcutDetails")
        f = _resolve_shortcut(f)
        mime = f.get("mimeType")
        if mime == "application/vnd.google-apps.folder":
            folder_id = f["id"]
        elif mime and mime.startswith("application/vnd.google-apps."):
            raise RuntimeError(f"AEZ_FILE_ID points to a Google Doc ({mime}), not a TIFF.")
        else:
            return _download(f)

    folder_id = folder_id or os.environ.get("AEZ_FOLDER_ID", None)
    if folder_id:
        # The file name sits inside a single-quoted query literal.
        safe_name = AEZ_FILE_NAME.replace("\\", "\\\\").replace("'", "\\'")
        q = f"'{folder_id}' in parents and trashed=false and title='{safe_name}'"
        cand = _drive_query(drive, q)
        if not cand:
            q_any = f"'{folder_id}' in parents and trashed=false and title contains '.tif'"
            cand = _drive_query(drive, q_any)
            cand = [c for c in cand if c.get("title", "") == AEZ_FILE_NAME] or cand
        if not cand:
            raise RuntimeError(f"Could not find {AEZ_FILE_NAME} in folder id {folder_id}.")
        return _download(drive.CreateFile({"id": cand[0]["id"]}))

    return local_path if os.path.exists(local_path) else None

def _find_points_vector_drive(drive):
    """Find the training-points file on Drive, download it, and return a local path.

    Walks (up to 4 levels deep) the ``MODEL_FOLDER`` sub-folder of the root
    when it can be resolved, then the root itself. Within each visited folder
    the search order is:

      1. a shapefile whose stem equals ``POINTS_BASENAME`` (case-insensitive):
         every matching sidecar (shp/dbf/shx/prj/cpg/qpj) is downloaded into a
         fresh temp directory and the ``.shp`` path is returned;
      2. per file, in listing order: ``<POINTS_BASENAME>.gpkg``, then ``.zip``,
         then ``.csv`` (case-insensitive exact title), downloaded to a temp file.

    Returns None when nothing matches. The temp files and directories created
    here are not removed by this function.
    """
    # Best effort: failing to resolve the model folder just means the walk
    # starts from the root instead.
    try: mt = get_subfolder(drive, ROOT_FOLDER_ID, MODEL_FOLDER)
    except Exception: mt = None
    # When mt is None both entries are the root, so the root is walked twice.
    roots = [mt or ROOT_FOLDER_ID, ROOT_FOLDER_ID]
    for root in roots:
        for fid, items in _drive_walk(drive, root, max_depth=4):
            # Pass 1: group shapefile parts by stem (only stems matching POINTS_BASENAME).
            stems={}
            for it in items:
                if not isinstance(it, dict): continue
                title = it.get("title",""); m = re.match(rf"(.+)\.(shp|dbf|shx|prj|cpg|qpj)$", title, re.IGNORECASE)
                if not m: continue
                stem = m.group(1)
                if stem.lower() == POINTS_BASENAME.lower():
                    stems.setdefault(stem,[]).append(it)
            for stem, parts in stems.items():
                # Only usable when the group actually contains the .shp itself.
                if any(p.get("title","").lower().endswith(".shp") for p in parts):
                    tdir = tempfile.mkdtemp()
                    for p in parts: download_to_temp(p, os.path.join(tdir, p.get("title","")))
                    # NOTE(review): the expected path uses a lower-case ".shp"; a part
                    # titled e.g. "US.SHP" would not be found on a case-sensitive filesystem.
                    shp_path = os.path.join(tdir, f"{os.path.basename(stem)}.shp")
                    if os.path.exists(shp_path):
                        print(f"Using vector points (shp): {os.path.basename(stem)}.shp"); return shp_path
            # Pass 2: single-file formats with an exact (case-insensitive) title match.
            for it in items:
                if not isinstance(it, dict): continue
                title = it.get("title",""); low = title.lower()
                if it.get("mimeType") == "application/vnd.google-apps.folder": continue
                if low.endswith(".gpkg") and low == f"{POINTS_BASENAME.lower()}.gpkg":
                    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".gpkg").name
                    download_to_temp(it, tmp); print(f"Using vector points (gpkg): {title}"); return tmp
                if low.endswith(".zip") and low == f"{POINTS_BASENAME.lower()}.zip":
                    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".zip").name
                    download_to_temp(it, tmp); print(f"Using vector points (zip): {title}"); return tmp
                if low.endswith(".csv") and low == f"{POINTS_BASENAME.lower()}.csv":
                    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv").name
                    download_to_temp(it, tmp); print(f"Using vector points (csv): {title}"); return tmp
    return None

def _find_points_vector_local():
    """Return the first local training-points file, or None.

    Looks in ``LOCAL_MODEL_DIR`` and then ``LOCAL_BASE_DIR`` (top level only).
    A CSV must be named exactly ``<POINTS_BASENAME>.csv``; ``.gpkg``, ``.zip``
    and ``.shp`` files only need to contain the base name. All comparisons are
    case-insensitive, and the first match in ``os.listdir`` order wins.
    """
    base = POINTS_BASENAME.lower()
    for folder in (os.path.join(LOCAL_MODEL_DIR), LOCAL_BASE_DIR):
        if os.path.isdir(folder):
            for entry in os.listdir(folder):
                lowered = entry.lower()
                exact_csv = lowered == f"{base}.csv"
                loose_vector = base in lowered and lowered.endswith((".gpkg", ".zip", ".shp"))
                if exact_csv or loose_vector:
                    return os.path.join(folder, entry)
    return None

def load_points_raw(drive):
    """Load the training points as a plain DataFrame with lon/lat columns.

    The file is located on Drive (``drive`` given) or locally (``drive is
    None``). CSVs are read with pandas. Any other format is read with
    geopandas, reprojected to EPSG:4326 when it declares a different CRS, and
    stripped of its geometry column.

    The coordinate columns are then normalised to ``LONCOL`` / ``LATCOL``:
    first by a case-insensitive column match, and, for vector files made only
    of point geometries, from the geometry itself when a column is still missing.

    Raises:
        RuntimeError: no points file could be located.
    """
    path = _find_points_vector_drive(drive) if drive is not None else _find_points_vector_local()
    if path is None:
        raise RuntimeError(f"Could not locate {POINTS_BASENAME}.(shp/zip/gpkg/csv) in 'Model Training' tree.")

    geom = None
    if path.lower().endswith(".csv"):
        df = pd.read_csv(path)
    else:
        # Imported lazily so that a CSV-only setup does not need the geo stack.
        import geopandas as gpd, fiona
        with fiona.Env(SHAPE_RESTORE_SHX="YES"):
            gdf = gpd.read_file(path)
        if gdf.crs is not None and gdf.crs.to_epsg() != 4326:
            gdf = gdf.to_crs(4326)
        try:
            geom = gdf.geometry
        except AttributeError:
            geom = None  # no active geometry column
        df = pd.DataFrame(gdf.drop(columns="geometry", errors="ignore"))

    if LONCOL not in df.columns or LATCOL not in df.columns:
        # Best-effort case-insensitive pickup; a no-op when nothing matches.
        def _ci(cols, name):
            for c in cols:
                if c.lower() == name.lower():
                    return c
            return name
        df = df.rename(columns={_ci(df.columns, LONCOL): LONCOL, _ci(df.columns, LATCOL): LATCOL})

    if geom is not None and (LONCOL not in df.columns or LATCOL not in df.columns):
        # Fall back to the geometry itself (x = longitude, y = latitude). A
        # layer without a declared CRS is assumed to already be in lon/lat.
        if geom.geom_type.eq("Point").all():
            if LONCOL not in df.columns:
                df[LONCOL] = geom.x.to_numpy()
            if LATCOL not in df.columns:
                df[LATCOL] = geom.y.to_numpy()
    return df

def _sample_raster_at_lonlat(src, lon_arr, lat_arr):
    xs = np.asarray(lon_arr, dtype=float); ys = np.asarray(lat_arr, dtype=float)
    if src.crs and (src.crs.to_epsg() != 4326):
        tx, ty = rio_transform(CRS.from_epsg(4326), src.crs, xs.tolist(), ys.tolist())
        xs = np.asarray(tx, dtype=float); ys = np.asarray(ty, dtype=float)
    out = np.full(xs.shape[0], np.nan, dtype=np.float32)
    xmin, ymin, xmax, ymax = src.bounds.left, src.bounds.bottom, src.bounds.right, src.bounds.top
    inside = (xs >= xmin) & (xs <= xmax) & (ys >= ymin) & (ys <= ymax)
    if not np.any(inside): return out
    idx = np.where(inside)[0]
    coords = list(zip(xs[idx], ys[idx]))
    vals = np.array([v[0] for v in src.sample(coords)], dtype=np.float32)
    if src.nodata is not None and np.isfinite(src.nodata):
        vals = np.where(np.isclose(vals, np.float32(src.nodata)), np.float32(np.nan), vals)
    out[idx] = vals; return out

# ---------- Per-state AEZ tiles ----------
def _pick_reference_raster_for_bounds(drive, state):
    """Return a local path to one raster of ``state`` that can define its extent.

    A ``.tif`` whose name contains ``_CLE`` or ``NDVI`` is preferred; otherwise
    the first ``.tif`` found is used.

    Local mode (``drive is None``): looks in ``LOCAL_BY_COUNTRY_DIR/<state>``
    and returns the file path itself.
    Drive mode: looks in ``By Country/<state>`` and downloads the pick to a
    temp file, which the caller owns.

    Returns None when the folder or a suitable raster is missing.
    """
    def _is_preferred(fname):
        # "NDVI" also covers names containing "NDVI_mean".
        return fname.endswith(".tif") and ("_CLE" in fname or "NDVI" in fname)

    if drive is None:
        pdir = os.path.join(LOCAL_BY_COUNTRY_DIR, state)
        if not os.path.isdir(pdir):
            return None
        names = os.listdir(pdir)
        for nm in names:
            if _is_preferred(nm):
                return os.path.join(pdir, nm)
        # fallback to any tif
        for nm in names:
            if nm.endswith(".tif"):
                return os.path.join(pdir, nm)
        return None

    byc = get_subfolder(drive, ROOT_FOLDER_ID, BY_COUNTRY_NAME)
    if not byc:
        return None
    # The state name sits inside a single-quoted query literal: escape it.
    safe_state = state.replace("\\", "\\\\").replace("'", "\\'")
    q = (
        f"'{byc}' in parents and trashed=false and title='{safe_state}' "
        f"and mimeType='application/vnd.google-apps.folder'"
    )
    res = _drive_query(drive, q)
    if not res:
        return None
    files = [it for it in list_files(drive, res[0]["id"]) if isinstance(it, dict)]

    pick = next((it for it in files if _is_preferred(it.get("title", ""))), None)
    if not pick:
        pick = next((it for it in files if it.get("title", "").endswith(".tif")), None)
    if not pick:
        return None

    handle = tempfile.NamedTemporaryFile(delete=False, suffix=".tif")
    handle.close()  # only the name is needed; the download writes the content
    return download_to_temp(pick, handle.name)

def _target_wgs84_grid_from_bounds(west, south, east, north, pixel_m=30.0):
    """Build a north-up EPSG:4326 grid of roughly ``pixel_m``-metre pixels.

    Degrees per pixel come from fixed metres-per-degree constants; the
    longitude constant is scaled by the cosine of the mid latitude.

    Returns:
        ``(height, width, transform, crs)`` covering the given bounds.
    """
    center_lat = 0.5 * (south + north)
    lat_meters_per_deg = 111_132.0
    lon_meters_per_deg = 111_320.0 * math.cos(math.radians(center_lat))
    step_y = pixel_m / lat_meters_per_deg
    step_x = pixel_m / lon_meters_per_deg
    n_cols = int(math.ceil((east - west) / step_x))
    n_rows = int(math.ceil((north - south) / step_y))
    grid_transform = Affine(step_x, 0, west, 0, -step_y, north)
    return n_rows, n_cols, grid_transform, CRS.from_epsg(4326)

def _clip_bounds_from_reference(raster_path):
    """Return the bounds of ``raster_path`` as ``(west, south, east, north)`` in EPSG:4326."""
    wgs84 = CRS.from_epsg(4326)
    with rasterio.open(raster_path) as ref:
        extent = ref.bounds
        bounds_wgs84 = rio_transform_bounds(
            ref.crs, wgs84,
            extent.left, extent.bottom, extent.right, extent.top,
            densify_pts=8,
        )
    return bounds_wgs84

def write_state_aez_tif(drive, state):
    """Reproject the global AEZ raster onto a ~30 m WGS84 grid for one state.

    The grid extent comes from a reference raster of the state. The result is
    written as ``By Country/<state>/<state>_AEZ.tif``: a float32, LZW-compressed
    tiled GeoTIFF with nodata -9999, resampled with nearest neighbour. It is
    saved locally (``drive is None``) or uploaded to Drive.

    A state without a reference raster is skipped with a message.

    Raises:
        RuntimeError: the AEZ raster cannot be found.
    """
    aez_src_path = open_aez_path(drive)
    if not aez_src_path:
        raise RuntimeError("AEZ raster not found.")
    ref = _pick_reference_raster_for_bounds(drive, state)
    if not ref:
        print(f"   ! skip {state}: no reference raster to define grid")
        return
    try:
        w, s, e, n = _clip_bounds_from_reference(ref)
    finally:
        # In Drive mode the reference raster is a throw-away temp download and
        # only its bounds are needed. In local mode it is the user's own file
        # and must be left alone.
        if drive is not None:
            try:
                os.remove(ref)
            except OSError:
                pass
    H, W, transform, crs = _target_wgs84_grid_from_bounds(w, s, e, n, pixel_m=30.0)

    nodata = -9999.0
    base_profile = {
        "driver": "GTiff", "height": H, "width": W, "count": 1, "crs": crs, "transform": transform,
        "compress": "LZW", "tiled": True, "blockxsize": 512, "blockysize": 512,
        "dtype": "float32", "nodata": nodata,
    }
    out_name = f"{state}_AEZ.tif"
    with tempfile.TemporaryDirectory() as tdir:
        out_tmp = os.path.join(tdir, out_name)
        with rasterio.open(aez_src_path) as src, rasterio.open(out_tmp, "w", **base_profile) as dst:
            # The whole (H, W) destination grid is held in memory at once.
            dst_arr = np.full((H, W), nodata, dtype=np.float32)
            reproject(source=rasterio.band(src, 1), destination=dst_arr,
                      src_transform=src.transform, src_crs=src.crs,
                      dst_transform=transform, dst_crs=crs,
                      dst_nodata=nodata, resampling=RioResampling.nearest)
            dst.write(dst_arr, 1)

        # upload/write to By Country/<State>/<State>_AEZ.tif (must happen while
        # the temp directory still exists)
        if drive is None:
            out_dir = os.path.join(LOCAL_BY_COUNTRY_DIR, state)
            os.makedirs(out_dir, exist_ok=True)
        else:
            # get_or_create_folder already looks the folder up before creating it.
            byc = get_or_create_folder(drive, ROOT_FOLDER_ID, BY_COUNTRY_NAME)
            out_dir = get_or_create_folder(drive, byc, state)
        upload_path(drive, out_tmp, out_dir, out_name)
    print(f"   • wrote {state}_AEZ.tif")

def list_states(drive):
    """Return the sorted, de-duplicated names of the state folders under "By Country".

    Local mode (``drive is None``): sub-directories of ``LOCAL_BY_COUNTRY_DIR``.
    Drive mode: titles of the non-trashed folders inside the "By Country" folder.
    An absent "By Country" folder yields an empty list.
    """
    if drive is None:
        if not os.path.isdir(LOCAL_BY_COUNTRY_DIR):
            return []
        local_names = {
            entry
            for entry in os.listdir(LOCAL_BY_COUNTRY_DIR)
            if os.path.isdir(os.path.join(LOCAL_BY_COUNTRY_DIR, entry))
        }
        return sorted(local_names)

    by_country_id = get_subfolder(drive, ROOT_FOLDER_ID, BY_COUNTRY_NAME)
    if not by_country_id:
        return []
    folders = _drive_query(drive, f"'{by_country_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder'")
    return sorted({folder["title"] for folder in folders})

# ---------- Points cache ----------
def augment_points_with_aez(drive):
    """Sample the AEZ raster at every training point and write the cache files.

    Adds ``AEZ_COL`` (raster value rounded to the nearest integer, stored as
    float64) plus ``_lonr`` / ``_latr`` (coordinates rounded to 6 decimals,
    used as stable merge keys). The lon/lat/key/AEZ columns are saved as
    ``US_with_AEZ.parquet`` and ``US_with_AEZ.csv`` in the model folder,
    locally or on Drive.

    Raises:
        RuntimeError: the AEZ raster cannot be found.
    """
    df = load_points_raw(drive)
    has_existing_aez = AEZ_COL in df.columns and df[AEZ_COL].notna().any()
    if has_existing_aez:
        print("   • points already have AEZ; writing cache anyway")

    aez_path = open_aez_path(drive)
    if not aez_path:
        raise RuntimeError("AEZ raster not found.")
    with rasterio.open(aez_path) as src:
        sampled = _sample_raster_at_lonlat(src, df[LONCOL].values, df[LATCOL].values)
    df[AEZ_COL] = np.rint(pd.to_numeric(sampled, errors="coerce")).astype("float64")

    # stable rounding keys for merge
    for key_col, coord_col in (("_lonr", LONCOL), ("_latr", LATCOL)):
        df[key_col] = np.round(df[coord_col].astype(float), 6)

    # write cache
    out_dir = LOCAL_MODEL_DIR if drive is None else get_or_create_folder(drive, ROOT_FOLDER_ID, MODEL_FOLDER)
    export_cols = [LONCOL, LATCOL, "_lonr", "_latr", AEZ_COL]
    with tempfile.TemporaryDirectory() as tdir:
        parquet_tmp = os.path.join(tdir, "US_with_AEZ.parquet")
        csv_tmp = os.path.join(tdir, "US_with_AEZ.csv")
        df[export_cols].to_parquet(parquet_tmp, index=False)
        df[export_cols].to_csv(csv_tmp, index=False)
        upload_path(drive, parquet_tmp, out_dir, "US_with_AEZ.parquet")
        upload_path(drive, csv_tmp, out_dir, "US_with_AEZ.csv")
    print("   • wrote Model Training/US_with_AEZ.parquet (+ .csv)")

# ---------- Runner ----------
def run(use_drive=True, states=None):
    """Build the per-state AEZ tiles, then the per-point AEZ cache.

    Args:
        use_drive: when true, use the notebook-level ``drive`` client;
            otherwise work on the local directory tree, creating it if needed.
        states: state names to process; all discovered states when empty or None.

    Raises:
        RuntimeError: ``use_drive`` is true but no ``drive`` name is defined.

    A failure on one state is printed and does not stop the loop; a failure
    while building the points cache propagates.
    """
    if not use_drive:
        _drive = None
        for local_dir in (LOCAL_BY_COUNTRY_DIR, LOCAL_MODEL_DIR):
            os.makedirs(local_dir, exist_ok=True)
    else:
        try:
            _drive = drive  # noqa
        except NameError:
            raise RuntimeError("PyDrive 'drive' not found. Authenticate and expose 'drive', or set use_drive=False.")
        # Best effort: lengthen the HTTP timeout when the client exposes it.
        try:
            _drive.auth.service.http.timeout = 120
        except Exception:
            pass

    # A) per-state AEZ tiles
    if not states:
        states = list_states(_drive)
    print("=== Build per-state AEZ tiles ===")
    for state_name in states:
        try:
            write_state_aez_tif(_drive, state_name)
        except Exception as err:
            print(f"   ! {state_name}: {err}")

    # B) points AEZ cache
    print("=== Build points AEZ cache ===")
    augment_points_with_aez(_drive)
    print("✅ done.")

# Entry point. In a notebook kernel __name__ is "__main__" as well, so the
# whole pipeline starts as soon as this cell is executed.
if __name__ == "__main__":
    run(use_drive=True, states=None)


In [None]:
"""
Regional RF irrigation model with:
- Farm-aware split via DBSCAN (no leakage)
- AEZ categorical (one-hot)
- OUT-OF-FOLD probability calibration (isotonic by default, Platt optional)
- Regional OOF metrics + artifacts
- Per-country OOF metrics & thresholds saved
- Country predictions on ~30m WGS84 using calibrated probabilities

  • Read AEZ per-point from Model Training/Africa_GTPS_per_point.(parquet|csv) when available
  • Prefer per-state AEZ tiles: By Country/<State>/<State>_AEZ.tif (fallback: global AEZ)

Outputs:
  CountryModelPredicted_Cropland/<Country>_RF_probability_percent.tif
  CountryModelPredicted_Cropland/<Country>_RF_binary_0_1_cropland.tif
  CountryModelPredicted_Cropland/<Country>_RF_predictors_count.tif

  Model Training/Regional Models/<Region>/
    - metrics.json
    - confusion.png
    - feature_importance.csv
    - per_country_test_metrics.csv
    - model.joblib

  Model Training/Country Models/<Country>/
    - threshold.json
    - test_metrics.json
    - test_confusion.png
"""
import os, re, json, math, tempfile, joblib, time, random
from collections import deque

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score, average_precision_score, roc_auc_score,
    precision_recall_fscore_support, confusion_matrix, roc_curve,
    precision_recall_curve
)
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.cluster import DBSCAN
from sklearn.linear_model import LogisticRegression
from sklearn.isotonic import IsotonicRegression
import pyproj
from sklearn.base import clone
import rasterio
from rasterio.enums import Resampling as RioResampling
from rasterio.warp import reproject, transform as rio_transform, transform_bounds as rio_transform_bounds
from rasterio.transform import Affine
from rasterio.windows import Window
from rasterio.crs import CRS

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

try:
    from googleapiclient.errors import HttpError
except Exception:
    HttpError = Exception
try:
    from pydrive2.files import ApiRequestError
except Exception:
    ApiRequestError = Exception

# Let GDAL rebuild a missing .shx index when reading shapefiles (only set if
# the variable is not already defined).
os.environ.setdefault("SHAPE_RESTORE_SHX", "YES")

# ------------------------ CONFIG ------------------------
# NOTE(review): several names below (ROOT_FOLDER_ID, LOCAL_BASE_DIR,
# POINTS_BASENAME, ...) rebind constants that the AEZ-prep cell defined with
# US values; after this cell runs, helpers from that cell see these values.
ROOT_FOLDER_ID = os.environ.get("ROOT_FOLDER_ID", "1ReAmJ809FmU-8jLpyyO6VEdmeqDmbaB7")
BY_COUNTRY_NAME = "By Country"
CLE_FOLDER_ID = "1pOpf-Zy5la4SKTIcrmLX6tGlxVBejKYX"
MODEL_FOLDER = "Model Training"
REGIONAL_MODELS_FOLDER = "Regional Models"  # used per country now: region_name == country
COUNTRY_MODELS_FOLDER = "Country Models"
OUTPUT_FOLDER = "CountryModelPredicted_Cropland"

# River Network (continental fallback for distance-to-river)
# NOTE(review): this id is identical to the ROOT_FOLDER_ID default above. Confirm it is intended.
RIVER_NET_FOLDER_ID = "1ReAmJ809FmU-8jLpyyO6VEdmeqDmbaB7"
RIVER_DIST_PREFIX = "Dist_"
RIVER_DIST_SUFFIX = "_river.tif"

# Country column auto-detection; we copy into this name
COUNTRY_COL = os.environ.get("COUNTRY_COL", "country_hint")

# DBSCAN farm grouping (meters)
DBSCAN_EPS_M = float(os.environ.get("DBSCAN_EPS_M", 45))
DBSCAN_MIN_SAMPLES = int(os.environ.get("DBSCAN_MIN_SAMPLES", 1))

# AEZ (global raster)
AEZ_FILE_ID   = os.environ.get("AEZ_FILE_ID", "1te3nKn8vyt2AECmk8NM_xBgpEFXp84RX")
AEZ_FILE_NAME = os.environ.get("AEZ_FILE_NAME", "AEZ_2020s.tif")
AEZ_COL       = "AEZ"

# Local fallbacks
LOCAL_BASE_DIR = os.path.join("./", "Africa_VIs_Env")
LOCAL_BY_COUNTRY_DIR = os.path.join(LOCAL_BASE_DIR, "By Country")
LOCAL_CLE_DIR = os.path.join(LOCAL_BASE_DIR, "CLE")
LOCAL_MODEL_DIR = os.path.join(LOCAL_BASE_DIR, MODEL_FOLDER)
LOCAL_OUTPUT_DIR = os.path.join(LOCAL_BASE_DIR, OUTPUT_FOLDER)

# GTPS/AEZ cache made earlier
GTPS_CACHE_PARQUET = "Africa_GTPS_per_point.parquet"
GTPS_CACHE_CSV     = "Africa_GTPS_per_point.csv"

# Drive retry knobs
DRIVE_MAX_RETRIES = int(os.environ.get("DRIVE_MAX_RETRIES", 6))
DRIVE_RETRY_BASE = float(os.environ.get("DRIVE_RETRY_BASE", 0.8))

POINTS_BASENAME = "Africa"  # pooled training points

# Predictor names; AEZ_COL is the categorical one (see the notebook docstring).
ALL_PREDICTORS = [
    "NDVI_mean","NDWI_mean","GI_mean",
    "NDVI_max","NDWI_max","GI_max",
    "NDVI_min","NDWI_min","GI_min",
    "elevation","slope",
    # "ET","PET",
    "dist_to_river",
    AEZ_COL,
]
TARGET = "irrigated"
LONCOL = "longitude"
LATCOL = "latitude"


# RF + data
RF_TREES = int(os.environ.get("RF_TREES", 400))
RANDOM_SEED = 42
TEST_FRACTION = 0.30

BORROW_MAX = 2000

# ---- Calibration & thresholds ----
CALIBRATION_METHOD = os.environ.get("CALIBRATION_METHOD", "isotonic").lower()  # "isotonic" | "platt"
THRESHOLD_STRATEGY = os.environ.get("THRESHOLD_STRATEGY", "auto").lower()      # "auto"|"j"|"f1"|"fixed"|"precision_at"|"recall_at"
FIXED_THRESHOLD = float(os.environ.get("FIXED_THRESHOLD", 0.50))
# Optional targets: unset, empty or the literal string "None" all mean "no target".
PRECISION_TARGET = os.environ.get("PRECISION_TARGET")
RECALL_TARGET = os.environ.get("RECALL_TARGET")
PRECISION_TARGET = None if PRECISION_TARGET in (None, "", "None") else float(PRECISION_TARGET)
RECALL_TARGET    = None if RECALL_TARGET    in (None, "", "None") else float(RECALL_TARGET)

# Per-country threshold mode: "global" | "country" | "hybrid"
THRESHOLD_MODE = os.environ.get("THRESHOLD_MODE", "hybrid").lower()
MIN_POS_NEG_FOR_COUNTRY = int(os.environ.get("MIN_POS_NEG_FOR_COUNTRY", 30))
MIN_TOTAL_FOR_COUNTRY   = int(os.environ.get("MIN_TOTAL_FOR_COUNTRY", 100))

# Prediction
TILE_SIZE = 1024
RESAMPLING = RioResampling.bilinear
# Per-band overrides of RESAMPLING: the AEZ band uses nearest neighbour.
PER_BAND_RESAMPLING = {AEZ_COL: RioResampling.nearest}
STRICT_ALL_BANDS = False
MIN_PREDICTORS = 9

# ------------------ UTILS / HELPERS ---------------------
def _meters_per_degree(lat_deg: float):
    m_per_deg_lat = 111_132.0
    m_per_deg_lon = 111_320.0 * math.cos(math.radians(lat_deg))
    return m_per_deg_lat, m_per_deg_lon

def _target_wgs84_grid_from_bounds(west, south, east, north, pixel_m=30.0):
    """Build a north-up EPSG:4326 grid of roughly ``pixel_m``-metre pixels.

    The pixel size in degrees is derived from ``_meters_per_degree`` evaluated
    at the mid latitude of the bounds.

    Returns:
        ``(height, width, transform, crs)`` covering the given bounds.
    """
    lat_scale, lon_scale = _meters_per_degree(0.5 * (south + north))
    step_y = pixel_m / lat_scale
    step_x = pixel_m / lon_scale
    n_cols = int(math.ceil((east - west) / step_x))
    n_rows = int(math.ceil((north - south) / step_y))
    grid_transform = Affine(step_x, 0, west, 0, -step_y, north)
    return n_rows, n_cols, grid_transform, CRS.from_epsg(4326)

def _iter_tiles(H, W, tile):
    """Yield row-major ``Window`` objects of at most ``tile`` x ``tile`` pixels covering an ``H`` x ``W`` grid."""
    for row_start in range(0, H, tile):
        row_stop = min(row_start + tile, H)
        for col_start in range(0, W, tile):
            col_stop = min(col_start + tile, W)
            yield Window.from_slices((row_start, row_stop), (col_start, col_stop))

def _tile_count(H, W, tile):
    return ((H + tile - 1) // tile) * ((W + tile - 1) // tile)

def _compute_basic_metrics(y_true, y_prob, thr):
    """Score probabilities against labels at decision threshold ``thr``.

    Predictions are ``y_prob >= thr``. ROC-AUC and PR-AUC fall back to NaN
    when scikit-learn raises while computing them.

    Returns:
        ``(metrics, cm)``: a dict of accuracy, precision, recall, f1, roc_auc,
        pr_auc and the tn/fp/fn/tp counts, plus the 2x2 confusion matrix
        (labels ordered ``[0, 1]``).
    """
    def _score_or_nan(scorer):
        try:
            return scorer(y_true, y_prob)
        except Exception:
            return float("nan")

    y_pred = (y_prob >= thr).astype(int)
    accuracy = accuracy_score(y_true, y_pred)
    precision, recall, f1, _support = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )
    roc = _score_or_nan(roc_auc_score)
    prauc = _score_or_nan(average_precision_score)
    cm = confusion_matrix(y_true, y_pred, labels=[0,1])

    metrics = {
        key: float(value)
        for key, value in (
            ("accuracy", accuracy),
            ("precision", precision),
            ("recall", recall),
            ("f1", f1),
            ("roc_auc", roc),
            ("pr_auc", prauc),
        )
    }
    for key, (row, col) in (("tn", (0, 0)), ("fp", (0, 1)), ("fn", (1, 0)), ("tp", (1, 1))):
        metrics[key] = int(cm[row, col])
    return metrics, cm

def _save_confusion_png(cm, title, path):
    """Draw confusion matrix ``cm`` as an annotated heat-map and save it to ``path`` (200 dpi)."""
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(cm, cmap="Blues")
    ticks = [0,1]
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xlabel("Pred")
    ax.set_ylabel("True")
    ax.set_title(title)
    # Write each count at the centre of its cell (x = column, y = row).
    for (row, col), count in np.ndenumerate(cm):
        ax.text(col, row, f"{count}", ha="center", va="center")
    fig.savefig(path, dpi=200, bbox_inches="tight")
    plt.close(fig)  # release the figure

def assign_farm_ids(df: pd.DataFrame, lon_col=LONCOL, lat_col=LATCOL) -> pd.Series:
    """Group nearby points into "farms" with DBSCAN and return one id per row.

    Coordinates are projected from EPSG:4326 to EPSG:3857 and clustered with
    ``eps=DBSCAN_EPS_M`` and ``min_samples=DBSCAN_MIN_SAMPLES``. Points that
    DBSCAN marks as noise (-1) each receive their own fresh id.

    Returns:
        int64 Series named ``"farm_id"`` aligned with ``df.index``; an empty
        Series of the same name and dtype for an empty frame.
    """
    if df.empty:
        # Same name and dtype as the non-empty result, so that callers can
        # treat both cases alike.
        return pd.Series([], dtype="int64", name="farm_id")
    # NOTE(review): EPSG:3857 distances are stretched by about 1/cos(latitude),
    # so eps only approximates metres on the ground away from the equator.
    proj = pyproj.Transformer.from_crs(4326, 3857, always_xy=True)
    x, y = proj.transform(df[lon_col].values, df[lat_col].values)
    coords = np.column_stack([x, y])
    db = DBSCAN(eps=DBSCAN_EPS_M, min_samples=DBSCAN_MIN_SAMPLES, metric="euclidean", algorithm="ball_tree")
    labels = db.fit_predict(coords)
    if (labels == -1).any():
        # Turn every noise point into its own singleton group.
        max_lab = labels[labels >= 0].max() if np.any(labels >= 0) else -1
        noise_idx = np.where(labels == -1)[0]
        labels[noise_idx] = np.arange(max_lab + 1, max_lab + 1 + len(noise_idx))
    return pd.Series(labels.astype("int64"), index=df.index, name="farm_id")

def _pick_country_candidate(cols):
    cand = ["country_hint","country","Country","COUNTRY","admin","ADMIN","state","STATE","STATE_NAME","State","province","Province"]
    low = {c.lower(): c for c in cols}
    for c in cand:
        if c.lower() in low:
            return low[c.lower()]
    return None

# ---------------------- DRIVE HELPERS -------------------
def _drive_retry(callable_fn, *args, **kwargs):
    """Call ``callable_fn(*args, **kwargs)``, retrying transient Drive errors.

    ``HttpError`` / ``ApiRequestError`` are treated as transient when the HTTP
    status (``e.resp.status``) is 5xx, 403 or 429, or, when no status is
    available, when the message mentions an internal/backend error, a rate
    limit or a timeout. Each retry is logged and preceded by an exponential
    back-off (``DRIVE_RETRY_BASE * 2**attempt`` plus up to 0.2 s of jitter).

    Returns:
        Whatever ``callable_fn`` returns.

    Raises:
        The last exception seen: immediately for non-transient or non-Drive
        errors, or after the final attempt.
    """
    # Always make at least one attempt: a non-positive DRIVE_MAX_RETRIES would
    # otherwise skip the loop and end in ``raise None`` (a TypeError).
    attempts = max(1, DRIVE_MAX_RETRIES)
    last_err = None
    for i in range(attempts):
        try:
            return callable_fn(*args, **kwargs)
        except (HttpError, ApiRequestError) as e:
            last_err = e
            code = getattr(getattr(e, "resp", None), "status", None)
            if code is None:
                msg = str(e).lower()
                transient = any(k in msg for k in ("internal error", "backenderror", "rate limit", "timeout"))
            else:
                transient = 500 <= int(code) < 600 or int(code) in (403, 429)
            if not transient or i == attempts - 1:
                break
            sleep_s = DRIVE_RETRY_BASE * (2 ** i) + random.random() * 0.2
            print(f"[Drive Retry] attempt {i+1}/{attempts} after error {e}; sleeping {sleep_s:.2f}s")
            time.sleep(sleep_s)
        except Exception as e:
            # Anything that is not a Drive API error is never retried.
            last_err = e
            break
    raise last_err

def _drive_query(drive, q):
    """Run Drive search query ``q`` via ``_drive_retry`` and return the ``GetList()`` result."""
    params = {"q": q, "supportsAllDrives": True, "includeItemsFromAllDrives": True, "maxResults": 1000}
    # A fresh copy of the parameters is handed to ListFile on every attempt.
    return _drive_retry(lambda: drive.ListFile(dict(params)).GetList())

def get_subfolder(drive, parent_id, name):
    """Return the child folder called ``name``, or None when it does not exist.

    Local mode (``drive is None``): returns ``LOCAL_BASE_DIR/name`` when that
    directory exists; ``parent_id`` is not consulted.
    Drive mode: returns the id of the first non-trashed folder titled ``name``
    directly under ``parent_id``.
    """
    if drive is None:
        p = os.path.join(LOCAL_BASE_DIR, name)
        return p if os.path.isdir(p) else None
    # Drive query literals are single-quoted: escape backslashes and quotes so
    # a country name such as "Cote d'Ivoire" cannot break (or alter) the query.
    safe_name = name.replace("\\", "\\\\").replace("'", "\\'")
    q = (
        f"'{parent_id}' in parents and trashed=false "
        f"and mimeType='application/vnd.google-apps.folder' and title='{safe_name}'"
    )
    r = _drive_query(drive, q)
    return r[0]["id"] if r else None

def get_or_create_folder(drive, parent_id, name):
    """Return the folder ``name`` under ``parent_id``, creating it when missing.

    Local mode (``drive is None``): creates and returns a directory path.
    Drive mode: returns the id of the existing folder, or of a newly created one.
    """
    if drive is None:
        # NOTE(review): a *relative* parent_id is replaced by LOCAL_BASE_DIR here,
        # so only absolute parents are honoured. Confirm this is intended.
        base = LOCAL_BASE_DIR if not os.path.isabs(parent_id) else parent_id
        p = os.path.join(base, name) if os.path.isdir(base) else os.path.join(LOCAL_BASE_DIR, name)
        os.makedirs(p, exist_ok=True)
        return p
    # Escape backslashes and single quotes: the title is embedded in a
    # single-quoted Drive query literal.
    safe_name = name.replace("\\", "\\\\").replace("'", "\\'")
    q = (
        f"'{parent_id}' in parents and trashed=false "
        f"and mimeType='application/vnd.google-apps.folder' and title='{safe_name}'"
    )
    res = _drive_query(drive, q)
    if res:
        return res[0]["id"]

    def _create():
        # The real (unescaped) name is used for the folder title.
        f = drive.CreateFile({"title": name, "parents": [{"id": parent_id}], "mimeType": "application/vnd.google-apps.folder"})
        f.Upload()
        return f["id"]

    return _drive_retry(_create)

def list_files(drive, parent_id):
    """List the direct children of ``parent_id``.

    Local mode (``drive is None``): sorted list of child paths, or ``[]`` when
    ``parent_id`` is not a directory.
    Drive mode: list of non-trashed Drive file objects; a failed listing
    prints a warning and yields ``[]``.
    """
    if drive is not None:
        try:
            return _drive_query(drive, f"'{parent_id}' in parents and trashed=false")
        except Exception as e:
            print(f"[Drive Warning] list_files failed for parent {parent_id}: {e}")
            return []
    if not os.path.isdir(parent_id):
        return []
    return sorted(os.path.join(parent_id, child) for child in os.listdir(parent_id))

def download_to_temp(drive_file, dst_path):
    """Download ``drive_file`` to ``dst_path`` via the retry wrapper and return ``dst_path``."""
    def _fetch():
        drive_file.GetContentFile(dst_path)
        return dst_path

    return _drive_retry(_fetch)

def upload_path(drive, local_path, parent_id, title=None):
    """Store ``local_path`` under ``parent_id`` as ``title`` (default: the file's basename).

    Drive mode uploads a new file through the retry wrapper and returns its id.
    Local mode treats ``parent_id`` as a directory (created if needed), copies
    the file there and returns the destination path.
    """
    target_name = title or os.path.basename(local_path)

    if drive is not None:
        def _push():
            remote = drive.CreateFile({"title": target_name, "parents": [{"id": parent_id}]})
            remote.SetContentFile(local_path)
            remote.Upload()
            return remote["id"]

        return _drive_retry(_push)

    import shutil
    os.makedirs(parent_id, exist_ok=True)
    destination = os.path.join(parent_id, target_name)
    shutil.copy2(local_path, destination)
    return destination

def _drive_walk(drive, start_id, max_depth=4):
    q = deque([(start_id, 0)])
    while q:
        fid, d = q.popleft()
        items = list_files(drive, fid)
        yield fid, items
        if d >= max_depth: continue
        for it in items:
            if isinstance(it, dict) and it.get("mimeType") == "application/vnd.google-apps.folder":
                q.append((it["id"], d + 1))

# ---------------------- AEZ SOURCES ----------------------
def open_aez_path(drive):
    """Global AEZ.tif (fallback if per-country tile missing).

    Local mode (``drive is None``) returns LOCAL_BASE_DIR/AEZ/AEZ_FILE_NAME if
    it exists, else None.

    Drive mode resolves in this order:
      1. AEZ_FILE_ID, when set. Shortcuts are followed. A folder id is
         searched as in step 2; any other Google-native document raises;
         anything else is downloaded as the raster.
      2. A folder (from step 1, or the AEZ_FOLDER_ID environment variable):
         first a file titled AEZ_FILE_NAME, otherwise any file whose title
         contains '.tif'.
      3. The local path described above.

    Returns:
        Path to a local .tif (a temp file when downloaded), or None.

    Raises:
        RuntimeError: AEZ_FILE_ID points to a Google-native document, or no
            candidate raster exists in the resolved folder.
    """
    local_path = os.path.join(LOCAL_BASE_DIR, "AEZ", AEZ_FILE_NAME)
    if drive is None:
        return local_path if os.path.exists(local_path) else None

    def _with_metadata(file_id):
        # Metadata fetches go through the retry wrapper like every other
        # Drive call in this file, so one transient error is not fatal.
        g = drive.CreateFile({"id": file_id})
        _drive_retry(lambda: g.FetchMetadata(fields="title,mimeType,shortcutDetails"))
        return g

    def _resolve_shortcut(file_obj):
        if file_obj.get("mimeType") == "application/vnd.google-apps.shortcut":
            tgt = file_obj.get("shortcutDetails", {}).get("targetId")
            if tgt:
                return _with_metadata(tgt)
        return file_obj

    def _download_tif(file_obj):
        # delete=False: the caller receives the path and opens it later.
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".tif").name
        return download_to_temp(file_obj, tmp)

    folder_id = None
    if AEZ_FILE_ID:
        f = _resolve_shortcut(_with_metadata(AEZ_FILE_ID))
        mime = f.get("mimeType")
        if mime == "application/vnd.google-apps.folder":
            folder_id = f["id"]
        elif mime and mime.startswith("application/vnd.google-apps."):
            raise RuntimeError(f"AEZ_FILE_ID points to a Google Doc ({mime}), not a TIFF.")
        else:
            return _download_tif(f)

    folder_id = folder_id or os.environ.get("AEZ_FOLDER_ID", None)
    if folder_id:
        q = f"'{folder_id}' in parents and trashed=false and title='{AEZ_FILE_NAME}'"
        cand = _drive_query(drive, q)
        if not cand:
            q_any = f"'{folder_id}' in parents and trashed=false and title contains '.tif'"
            cand = _drive_query(drive, q_any)
            # Prefer an exact title match; otherwise accept any .tif found.
            cand = [c for c in cand if c.get("title", "") == AEZ_FILE_NAME] or cand
        if not cand:
            raise RuntimeError(f"Could not find {AEZ_FILE_NAME} in folder id {folder_id}.")
        return _download_tif(drive.CreateFile({"id": cand[0]["id"]}))

    return local_path if os.path.exists(local_path) else None

# ---------------------- POINTS / CACHE -------------------
def _find_points_vector_drive(drive):
    try: mt = get_subfolder(drive, ROOT_FOLDER_ID, MODEL_FOLDER)
    except Exception: mt = None
    if mt:
        items = list_files(drive, mt); parts=[]; has_shp=False
        for it in items:
            if not isinstance(it, dict): continue
            title = it.get("title",""); low = title.lower()
            if not low.startswith("US."): continue
            if low.endswith(".shp"): has_shp=True
            if re.search(r"\.(shp|dbf|shx|prj|cpg|qpj)$", low): parts.append(it)
        if has_shp and parts:
            tdir = tempfile.mkdtemp()
            for p in parts: download_to_temp(p, os.path.join(tdir, p.get("title","")))
            shp_path = os.path.join(tdir, "US.shp")
            if os.path.exists(shp_path):
                print("Using vector points (shp): Model Training/US.shp"); return shp_path
    roots = [mt or ROOT_FOLDER_ID, ROOT_FOLDER_ID]
    for root in roots:
        for fid, items in _drive_walk(drive, root, max_depth=4):
            stems={}
            for it in items:
                if not isinstance(it, dict): continue
                title = it.get("title",""); m = re.match(rf"(.+)\.(shp|dbf|shx|prj|cpg|qpj)$", title, re.IGNORECASE)
                if not m: continue
                stem = m.group(1)
                if stem.lower() == POINTS_BASENAME.lower():
                    stems.setdefault(stem,[]).append(it)
            for stem, parts in stems.items():
                if any(p.get("title","").lower().endswith(".shp") for p in parts):
                    tdir = tempfile.mkdtemp()
                    for p in parts: download_to_temp(p, os.path.join(tdir, p.get("title","")))
                    shp_path = os.path.join(tdir, f"{os.path.basename(stem)}.shp")
                    if os.path.exists(shp_path):
                        print(f"Using vector points (shp): {os.path.basename(stem)}.shp"); return shp_path
            for it in items:
                if not isinstance(it, dict): continue
                title = it.get("title",""); low = title.lower()
                if it.get("mimeType") == "application/vnd.google-apps.folder": continue
                if low.endswith(".gpkg") and low == f"{POINTS_BASENAME.lower()}.gpkg":
                    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".gpkg").name
                    download_to_temp(it, tmp); print(f"Using vector points (gpkg): {title}"); return tmp
                if low.endswith(".zip") and low == f"{POINTS_BASENAME.lower()}.zip":
                    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".zip").name
                    download_to_temp(it, tmp); print(f"Using vector points (zip): {title}"); return tmp
                if low.endswith(".csv") and low == f"{POINTS_BASENAME.lower()}.csv":
                    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv").name
                    download_to_temp(it, tmp); print(f"Using vector points (csv): {title}"); return tmp
    return None

def _find_points_vector_local():
    search_roots = [os.path.join(LOCAL_BASE_DIR, MODEL_FOLDER), LOCAL_BASE_DIR]
    for root in search_roots:
        if not os.path.isdir(root): continue
        for nm in os.listdir(root):
            low = nm.lower()
            if low == f"{POINTS_BASENAME.lower()}.csv":
                print(f"Using vector points (csv): {nm}"); return os.path.join(root, nm)
            if POINTS_BASENAME.lower() in low and low.endswith(".gpkg"):
                print(f"Using vector points (gpkg): {nm}"); return os.path.join(root, nm)
            if POINTS_BASENAME.lower() in low and low.endswith(".zip"):
                print(f"Using vector points (zip): {nm}"); return os.path.join(root, nm)
    for root in search_roots:
        if not os.path.isdir(root): continue
        for nm in os.listdir(root):
            if nm.lower().endswith(".shp") and POINTS_BASENAME.lower() in nm.lower():
                print(f"Using vector points (shp): {nm}"); return os.path.join(root, nm)
    return None

def _ci_lookup(cols, name):
    name_l = name.lower()
    for c in cols:
        if c.lower() == name_l: return c
    return None

def _round6(x):
    return np.round(pd.to_numeric(x, errors="coerce"), 6)

def _load_gtps_cache_df(drive):
    """Load US_GTPS_per_point.(parquet|csv) from Model Training if present."""
    def _from_drive():
        mt = get_or_create_folder(drive, ROOT_FOLDER_ID, MODEL_FOLDER)
        # prefer parquet
        res = _drive_query(drive, f"'{mt}' in parents and trashed=false and title='{GTPS_CACHE_PARQUET}'")
        if res:
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".parquet").name
            download_to_temp(res[0], tmp)
            try:
                return pd.read_parquet(tmp)
            except Exception:
                pass
        res = _drive_query(drive, f"'{mt}' in parents and trashed=false and title='{GTPS_CACHE_CSV}'")
        if res:
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv").name
            download_to_temp(res[0], tmp)
            return pd.read_csv(tmp)
        return None

    def _from_local():
        p1 = os.path.join(LOCAL_MODEL_DIR, GTPS_CACHE_PARQUET)
        p2 = os.path.join(LOCAL_MODEL_DIR, GTPS_CACHE_CSV)
        if os.path.exists(p1):
            try:
                return pd.read_parquet(p1)
            except Exception:
                pass
        if os.path.exists(p2):
            return pd.read_csv(p2)
        return None

    try:
        return _from_drive() if drive is not None else _from_local()
    except Exception as e:
        print(f"   ! AEZ cache read error: {e}")
        return None

def load_points_df(drive):
    """Load the training points table and normalise its key columns.

    Steps:
      * locate the points file (Drive or local) and read it (CSV via pandas,
        anything else via geopandas, reprojected to EPSG:4326 when needed);
      * make sure LONCOL, LATCOL and TARGET exist under their canonical names
        (case-insensitive lookup) and are numeric; rows missing any of them
        are dropped; TARGET is cast to int;
      * make sure COUNTRY_COL exists (copied from a country-like column, else
        "Unknown") and is a string column;
      * add ``_lonr`` / ``_latr`` (coordinates rounded to 6 decimals) and, if
        AEZ_COL is absent or entirely NaN, join it from the cached per-point
        AEZ table on those rounded coordinates.

    Returns:
        pandas.DataFrame with a fresh 0..n-1 index.

    Raises:
        RuntimeError: the points file, the coordinate columns or the TARGET
            column cannot be found.
    """
    path = _find_points_vector_drive(drive) if drive is not None else _find_points_vector_local()
    if path is None:
        raise RuntimeError("Could not locate US.(shp/zip/gpkg/csv) in 'Model Training' tree.")
    if path.lower().endswith(".csv"):
        df = pd.read_csv(path)
    else:
        # Imported lazily so CSV-only runs do not require the geospatial stack.
        import geopandas as gpd, fiona
        with fiona.Env(SHAPE_RESTORE_SHX="YES"):
            gdf = gpd.read_file(path)
        if gdf.crs is not None and gdf.crs.to_epsg() != 4326:
            gdf = gdf.to_crs(4326)
        df = pd.DataFrame(gdf.drop(columns="geometry", errors="ignore"))

    lon_col = _ci_lookup(df.columns, LONCOL)
    lat_col = _ci_lookup(df.columns, LATCOL)
    if lon_col is None or lat_col is None:
        raise RuntimeError("Points must contain longitude and latitude columns.")
    if lon_col != LONCOL:
        df[LONCOL] = pd.to_numeric(df[lon_col], errors="coerce")
    if lat_col != LATCOL:
        df[LATCOL] = pd.to_numeric(df[lat_col], errors="coerce")

    tgt_col = _ci_lookup(df.columns, TARGET)
    if tgt_col is None:
        raise RuntimeError(f"Points file missing column: {TARGET}")
    if tgt_col != TARGET:
        df[TARGET] = pd.to_numeric(df[tgt_col], errors="coerce")

    # pre-create rounders for cache join
    df["_lonr"] = _round6(df[LONCOL])
    df["_latr"] = _round6(df[LATCOL])

    # copy a country-like column
    if COUNTRY_COL not in df.columns:
        cand = _pick_country_candidate(df.columns)
        df[COUNTRY_COL] = df[cand] if cand else "Unknown"

    for c in [TARGET, LONCOL, LATCOL]:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce")
    df = df.dropna(subset=[TARGET, LONCOL, LATCOL]).reset_index(drop=True)
    df[TARGET] = df[TARGET].astype(int)
    df[COUNTRY_COL] = df[COUNTRY_COL].astype(str)

    # Bring AEZ from cache if available
    if AEZ_COL not in df.columns or df[AEZ_COL].isna().all():
        cache = _load_gtps_cache_df(drive)
        if cache is not None:
            # normalize cache cols
            cc = {c.lower(): c for c in cache.columns}
            lonc = cc.get(LONCOL.lower(), None) or cc.get("x") or cc.get("lon") or list(cache.columns)[1]
            latc = cc.get(LATCOL.lower(), None) or cc.get("y") or cc.get("lat") or list(cache.columns)[2]
            aezc = cc.get(AEZ_COL.lower(), None) or cc.get("gtps") or cc.get("aez")
            if aezc is not None:
                cache["_lonr"] = _round6(cache[lonc])
                cache["_latr"] = _round6(cache[latc])
                # One cache row per rounded coordinate: duplicate keys would
                # otherwise multiply training rows in the left join.
                lookup = (
                    cache[["_lonr", "_latr", aezc]]
                    .rename(columns={aezc: AEZ_COL})
                    .drop_duplicates(subset=["_lonr", "_latr"])
                )
                # Drop an existing all-NaN AEZ column first; merging with it in
                # place yields AEZ_x / AEZ_y suffixes and loses the values.
                df = df.drop(columns=[AEZ_COL], errors="ignore").merge(
                    lookup, on=["_lonr", "_latr"], how="left"
                )
                df[AEZ_COL] = pd.to_numeric(df[AEZ_COL], errors="coerce")
                df[AEZ_COL] = np.rint(df[AEZ_COL]).astype("float64")
                print(f"   • AEZ filled from cache for {np.isfinite(df[AEZ_COL]).sum()}/{len(df)} points")
    return df

# ---------------------- RASTERS -------------------------
def open_country_var_path(drive, country, variable):
    """Any By Country/<Country>/<Country>_<variable>.tif (case-sensitive on names passed in).

    Local mode returns the path under LOCAL_BY_COUNTRY_DIR if the file exists.
    Drive mode looks for the exact title first, then for any title matching
    ``<country>_<variable>*.tif`` (case-insensitive), and downloads the hit to
    a temp ``.tif``.

    Returns:
        Local file path, or None when the folder or raster cannot be found.
    """
    if drive is None:
        p = os.path.join(LOCAL_BY_COUNTRY_DIR, country, f"{country}_{variable}.tif")
        return p if os.path.exists(p) else None

    def _esc(text):
        # Values are embedded in single-quoted Drive query literals, so
        # backslashes and quotes must be escaped (e.g. "Cote d'Ivoire").
        return str(text).replace("\\", "\\\\").replace("'", "\\'")

    byc = get_subfolder(drive, ROOT_FOLDER_ID, BY_COUNTRY_NAME)
    if not byc:
        return None
    q = (
        f"'{byc}' in parents and trashed=false and title='{_esc(country)}' "
        f"and mimeType='application/vnd.google-apps.folder'"
    )
    res = _drive_query(drive, q)
    if not res:
        return None
    cid = res[0]["id"]
    q2 = f"'{cid}' in parents and trashed=false and title='{_esc(country)}_{_esc(variable)}.tif'"
    files = _drive_query(drive, q2)
    if files:
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".tif").name
        return download_to_temp(files[0], tmp)
    # Fallback: tolerate suffixes between the variable name and ".tif".
    pattern = rf"^{re.escape(country)}_{re.escape(variable)}.*\.tif$"
    for it in list_files(drive, cid):
        if isinstance(it, dict) and re.match(pattern, it.get("title", ""), re.IGNORECASE):
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".tif").name
            return download_to_temp(it, tmp)
    return None

def open_CroplandNE(drive):
    """Return a local path to the global ``Cropland_NE*.tif`` raster, or None.

    Local mode checks LOCAL_CLE_DIR/Cropland_NE.tif. Drive mode scans
    CLE_FOLDER_ID for the first file whose title starts with "Cropland_NE" and
    ends with ".tif" (case-insensitive) and downloads it to a temp file.
    """
    if drive is None:
        p = os.path.join(LOCAL_CLE_DIR, "Cropland_NE.tif")
        return p if os.path.exists(p) else None
    for it in list_files(drive, CLE_FOLDER_ID):
        if not isinstance(it, dict):
            continue
        low = it.get("title", "").lower()
        # Both sides lowercase: comparing the lowercased title with the
        # mixed-case prefix "Cropland_NE" could never match.
        if low.startswith("cropland_ne") and low.endswith(".tif"):
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".tif").name
            return download_to_temp(it, tmp)
    return None

def _clip_bounds_from_reference(raster_path):
    with rasterio.open(raster_path) as src:
        b = rio_transform_bounds(src.crs, CRS.from_epsg(4326),
                                 src.bounds.left, src.bounds.bottom, src.bounds.right, src.bounds.top, densify_pts=8)
    return b

def _pick_reference_raster_for_bounds(drive, country):
    # Prefer any available per-country raster (including AEZ tile) to set the grid.
    # IMPORTANT: do NOT fall back to global AEZ here, to avoid predicting outside the country.
    for v in ["CLE"] + ALL_PREDICTORS + [AEZ_COL]:
        if v == AEZ_COL:
            p = open_country_var_path(drive, country, AEZ_COL)
        else:
            p = open_country_var_path(drive, country, v)
        if p is not None:
            return p
    raise RuntimeError(f"No per-country reference raster found for bounds in '{country}'.")

def country_grid_and_mask(drive, country):
    """Build the prediction grid and cropland mask for ``country``.

    The extent comes from a per-country reference raster; the grid is created
    by ``_target_wgs84_grid_from_bounds`` with ``pixel_m=30.0``. The cropland
    mask uses the per-country "CLE" raster when present, otherwise the global
    cropland raster.

    Returns:
        ``(H, W, transform, crs, cropland, (west, south, east, north))``.
    """
    reference = _pick_reference_raster_for_bounds(drive, country)
    west, south, east, north = _clip_bounds_from_reference(reference)
    H, W, transform, crs = _target_wgs84_grid_from_bounds(west, south, east, north, pixel_m=30.0)

    # Prefer per-country cropland, but fall back to global CroplandNE if needed.
    per_country_cle = open_country_var_path(drive, country, "CLE")
    if per_country_cle is None:
        print(f"   ! No per-country CLE for '{country}'; using global CroplandNE.tif as cropland mask.")
    # source_path=None makes clip_cle_to_grid resolve the global raster itself.
    cropland = clip_cle_to_grid(drive, H, W, transform, crs, source_path=per_country_cle)

    extent = (west, south, east, north)
    return (H, W, transform, crs, cropland, extent)

def clip_cle_to_grid(drive, out_h, out_w, out_transform, out_crs, source_path=None):
    """Resample a cropland raster onto the target grid and return a boolean mask.

    ``source_path`` is used when given; otherwise the global cropland raster
    is resolved via ``open_CroplandNE``. If no raster is available at all the
    mask is all True. Band 1 is reprojected with nearest-neighbour resampling;
    a cell is True when its value is valid (not the source nodata, not the
    -9999 fill) and greater than 0.5.
    """
    cle_path = source_path if source_path else open_CroplandNE(drive)
    if cle_path is None:
        return np.ones((out_h, out_w), dtype=bool)

    fill_value = -9999.0
    grid = np.full((out_h, out_w), fill_value, dtype=np.float32)
    with rasterio.open(cle_path) as src:
        source_nodata = src.nodata
        reproject(
            source=rasterio.band(src, 1),
            destination=grid,
            src_transform=src.transform,
            src_crs=src.crs,
            dst_transform=out_transform,
            dst_crs=out_crs,
            dst_nodata=fill_value,
            resampling=RioResampling.nearest,
        )

    # Cells equal to either nodata marker are treated as missing.
    missing = np.isclose(grid, fill_value)
    if source_nodata is not None:
        missing = missing | np.isclose(grid, source_nodata)
    return (~missing) & np.isfinite(grid) & (grid > 0.5)

# ---------------- THRESHOLDS -----------------------------
def _thr_by_J(y_true, p):
    y = np.asarray(y_true).astype(int)
    if np.unique(y).size < 2: return FIXED_THRESHOLD
    fpr, tpr, th = roc_curve(y, p)
    if len(th) < 2: return FIXED_THRESHOLD
    j = tpr - fpr
    return float(th[int(np.argmax(j))])

def _thr_by_F1(y_true, p):
    y = np.asarray(y_true).astype(int)
    ths = np.linspace(0.01, 0.99, 99); best_t, best_f1 = 0.5, -1
    for t in ths:
        yhat = (p >= t).astype(int)
        _, _, f1, _ = precision_recall_fscore_support(y, yhat, average="binary", zero_division=0)
        if f1 > best_f1: best_f1, best_t = f1, t
    return float(best_t)

def _thr_precision_at(y_true, p, target):
    if target is None: return _thr_by_J(y_true, p)
    prec, rec, thr = precision_recall_curve(y_true, p)
    cand = thr[prec[1:] >= target]
    if len(cand) > 0:
        return float(np.min(cand))  
    return _thr_by_F1(y_true, p)

def _thr_recall_at(y_true, p, target):
    if target is None: return _thr_by_J(y_true, p)
    prec, rec, thr = precision_recall_curve(y_true, p)
    cand = thr[rec[1:] >= target]
    if len(cand) > 0:
        return float(np.max(cand))  
    return _thr_by_F1(y_true, p)

def _pick_threshold(y_true, p):
    s = THRESHOLD_STRATEGY.lower()
    if s == "fixed":       return float(FIXED_THRESHOLD)
    if s == "f1":          return _thr_by_F1(y_true, p)
    if s == "j":           return _thr_by_J(y_true, p)
    if s == "precision_at":return _thr_precision_at(y_true, p, PRECISION_TARGET)
    if s == "recall_at":   return _thr_recall_at(y_true, p, RECALL_TARGET)
    if PRECISION_TARGET is not None:
        return _thr_precision_at(y_true, p, PRECISION_TARGET)
    if RECALL_TARGET is not None:
        return _thr_recall_at(y_true, p, RECALL_TARGET)
    return _thr_by_J(y_true, p)

# ---------------- Calibration helpers -------------------
def _fit_calibrator(raw_pos_probs, y, method="isotonic"):
    raw_pos_probs = np.asarray(raw_pos_probs, dtype=float).ravel()
    y = np.asarray(y, dtype=int).ravel()
    if method == "platt":
        lr = LogisticRegression(solver="lbfgs", max_iter=1000)
        lr.fit(raw_pos_probs.reshape(-1,1), y)
        return {"method": "platt", "model": lr}
    ir = IsotonicRegression(y_min=0.0, y_max=1.0, out_of_bounds="clip")
    ir.fit(raw_pos_probs, y)
    return {"method": "isotonic", "model": ir}

def _apply_calibrator(calib, raw_pos_probs):
    raw_pos_probs = np.asarray(raw_pos_probs, dtype=float).ravel()
    if calib is None:
        return raw_pos_probs
    if calib.get("method") == "platt":
        return calib["model"].predict_proba(raw_pos_probs.reshape(-1,1))[:,1]
    return calib["model"].predict(raw_pos_probs)

# ------------- POINT SAMPLING FROM RASTERS --------------
def _sample_raster_at_lonlat(src, lon_arr, lat_arr):
    xs = np.asarray(lon_arr, dtype=float); ys = np.asarray(lat_arr, dtype=float)
    if src.crs and (src.crs.to_epsg() != 4326):
        tx, ty = rio_transform(CRS.from_epsg(4326), src.crs, xs.tolist(), ys.tolist())
        xs = np.asarray(tx, dtype=float); ys = np.asarray(ty, dtype=float)
    out = np.full(xs.shape[0], np.nan, dtype=np.float32)
    xmin, ymin, xmax, ymax = src.bounds.left, src.bounds.bottom, src.bounds.right, src.bounds.top
    inside = (xs >= xmin) & (xs <= xmax) & (ys >= ymin) & (ys <= ymax)
    if not np.any(inside): return out
    idx = np.where(inside)[0]
    coords = list(zip(xs[idx], ys[idx]))
    vals = np.array([v[0] for v in src.sample(coords)], dtype=np.float32)
    if src.nodata is not None and np.isfinite(src.nodata):
        vals = np.where(np.isclose(vals, np.float32(src.nodata)), np.float32(np.nan), vals)
    out[idx] = vals; return out

# ---- TRAINING FILL (num predictors + AEZ) ---------------
def _fill_training_predictors_from_country_rasters(drive, df_all, countries):
    """Return a copy of ``df_all`` with predictor columns sampled from rasters.

    AEZ (AEZ_COL): points still missing a value are sampled from each
    country's AEZ tile; if some country has no tile, the remaining gaps are
    sampled once from the global AEZ raster. Values already present are never
    overwritten.

    Numeric predictors: for every country with a "CLE" or "NDVI_mean" raster,
    all points inside that raster's lon/lat bounding box get every available
    predictor in ALL_PREDICTORS (except AEZ_COL) re-sampled from the
    per-country raster.

    Errors while handling one country's AEZ tile or reference raster are
    printed and that country is skipped.
    """
    df = df_all.copy()

    def _in_bounds(frame, bounds):
        w, s, e, n = bounds
        return (frame[LONCOL].between(min(w, e), max(w, e)) &
                frame[LATCOL].between(min(s, n), max(s, n)))

    def _aez_missing():
        # Recomputed on demand so points filled earlier are not sampled again.
        if AEZ_COL not in df.columns:
            return df.index
        return df.index[df[AEZ_COL].isna()]

    def _fill_aez_from(ref):
        """Sample still-missing AEZ values inside ``ref``'s bounds; return how many became finite."""
        miss = _aez_missing()
        if len(miss) == 0:
            return 0
        candidates = df.loc[miss]
        idx = candidates.index[_in_bounds(candidates, _clip_bounds_from_reference(ref))]
        if len(idx) == 0:
            return 0
        with rasterio.open(ref) as aez_src:
            vals = _sample_raster_at_lonlat(aez_src, df.loc[idx, LONCOL].values, df.loc[idx, LATCOL].values)
        df.loc[idx, AEZ_COL] = np.rint(pd.to_numeric(vals, errors="coerce")).astype("float64")
        return int(np.isfinite(df.loc[idx, AEZ_COL]).sum())

    # AEZ: if still missing after cache merge, try per-country tiles, else global
    if len(_aez_missing()) > 0:
        aez_filled = 0
        need_global = False
        for c in countries:
            try:
                ref = open_country_var_path(drive, c, AEZ_COL)
                if not ref:
                    need_global = True
                    continue
                aez_filled += _fill_aez_from(ref)
            except Exception as ex:
                print(f"   ! AEZ fill skip for {c}: {ex}")
        # The global raster is resolved (and possibly downloaded) at most once,
        # and only fills what the tiles left open.
        if need_global and len(_aez_missing()) > 0:
            try:
                ref = open_aez_path(drive)
                if ref:
                    aez_filled += _fill_aez_from(ref)
            except Exception as ex:
                print(f"   ! AEZ fill skip for global AEZ: {ex}")
        if aez_filled > 0:
            print(f"   • AEZ sampled from tiles/global for {aez_filled} points")

    # Per-country numeric predictors from rasters
    for c in countries:
        try:
            ref = open_country_var_path(drive, c, "CLE") or open_country_var_path(drive, c, "NDVI_mean")
            if not ref:
                continue
            bounds = _clip_bounds_from_reference(ref)
        except Exception as ex:
            print(f"   ! Skip raster fill for {c}: {ex}")
            continue
        idx = df.index[_in_bounds(df, bounds)]
        if len(idx) == 0:
            continue
        lons = df.loc[idx, LONCOL].values
        lats = df.loc[idx, LATCOL].values
        filled_any = False
        for p in ALL_PREDICTORS:
            if p == AEZ_COL:
                continue
            pth = open_country_var_path(drive, c, p)
            if pth is None:
                continue
            # Context manager: each dataset handle is closed after sampling.
            with rasterio.open(pth) as src:
                vals = _sample_raster_at_lonlat(src, lons, lats)
            df.loc[idx, p] = pd.to_numeric(vals, errors="coerce").astype("float64")
            filled_any = True
        if filled_any:
            print(f"   • Training fill from rasters for {c}: {len(idx)} points")
    return df

def list_countries_auto(drive):
    """Return the sorted, de-duplicated names of the per-country folders.

    Local mode lists the sub-directories of LOCAL_BY_COUNTRY_DIR; Drive mode
    lists the folder titles inside the BY_COUNTRY_NAME folder under
    ROOT_FOLDER_ID. An empty list is returned when the container is missing.
    """
    if drive is None:
        if not os.path.isdir(LOCAL_BY_COUNTRY_DIR):
            return []
        local_names = {
            entry
            for entry in os.listdir(LOCAL_BY_COUNTRY_DIR)
            if os.path.isdir(os.path.join(LOCAL_BY_COUNTRY_DIR, entry))
        }
        return sorted(local_names)

    byc = get_subfolder(drive, ROOT_FOLDER_ID, BY_COUNTRY_NAME)
    if not byc:
        return []
    folders = _drive_query(
        drive,
        f"'{byc}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder'",
    )
    return sorted({folder["title"] for folder in folders})

# ----------------- TRAIN REGIONAL MODEL -----------------
def _borrow_neighbors(df_all, base_cols, max_points):
    if COUNTRY_COL not in df_all.columns:
        df_all[COUNTRY_COL] = "Unknown"
    keep_cols = list(dict.fromkeys(base_cols))
    pooled = df_all[keep_cols].dropna(subset=[TARGET, LONCOL, LATCOL]).copy()
    if len(pooled) > max_points:
        pooled = pooled.sample(max_points, random_state=RANDOM_SEED)
    if pooled[TARGET].nunique() < 2:
        other = 1 - int(pooled[TARGET].iloc[0])
        extra = df_all[df_all[TARGET] == other]
        if not extra.empty:
            take = min(1000, len(extra))
            pooled = pd.concat([pooled, extra.sample(take, random_state=RANDOM_SEED)[keep_cols]], ignore_index=True)
    if COUNTRY_COL not in pooled.columns:
        pooled[COUNTRY_COL] = "Unknown"
    pooled[COUNTRY_COL] = pooled[COUNTRY_COL].astype(str)
    return pooled.reset_index(drop=True)

def train_region_model(drive, region_name, df_all, countries):
    """Train, calibrate and persist the regional RandomForest model.

    Pipeline, in order:
      1. Fill predictor columns from per-country rasters and pool the rows
         via ``_borrow_neighbors`` (capped at BORROW_MAX).
      2. Build a preprocessing + RandomForest pipeline (median imputation for
         numeric predictors; most-frequent imputation + one-hot for AEZ_COL).
      3. Produce out-of-fold (OOF) raw probabilities with StratifiedGroupKFold,
         grouped by ``assign_farm_ids``.
      4. Refit the pipeline on all rows, fit the calibrator on the OOF
         probabilities and derive the global threshold from the calibrated OOF.
      5. Derive a per-country threshold (falling back to the global one when
         the country has too few OOF samples) and save per-country artifacts
         (test_metrics.json, test_confusion.png, threshold.json).
      6. Save regional artifacts (metrics.json, feature_importance.csv,
         per_country_test_metrics.csv, confusion.png, model.joblib).

    Args:
        drive: Drive handle, or None to read/write on the local filesystem.
        region_name: Name used for logging and for the regional output folder.
        df_all: Points table containing TARGET, LONCOL, LATCOL, COUNTRY_COL and
            any predictor columns already available.
        countries: Iterable of country names belonging to the region.

    Returns:
        dict with keys ``pipe``, ``use_preds``, ``pos_idx``, ``calib``,
        ``global_thr``, ``per_thr`` and ``reg_base``.

    Raises:
        RuntimeError: fewer than 200 pooled rows, or fewer than two classes.
    """
    print(f"=== Train regional model: {region_name} ===")

    # Fill predictors (uses AEZ cache + per-country AEZ tiles)
    df_all = _fill_training_predictors_from_country_rasters(drive, df_all, countries)

    # Only predictors that actually exist as columns are used.
    use_preds = [p for p in ALL_PREDICTORS if p in df_all.columns]
    base_cols = [TARGET, LONCOL, LATCOL, COUNTRY_COL] + use_preds
    df = _borrow_neighbors(df_all, base_cols, BORROW_MAX)

    if len(df) < 200 or df[TARGET].nunique() < 2:
        raise RuntimeError(f"Insufficient pooled points for region training (n={len(df)}, classes={df[TARGET].nunique()}).")

    # AEZ_COL is the only categorical feature; everything else is numeric.
    num_features = [f for f in use_preds if f != AEZ_COL]
    cat_features = [AEZ_COL] if AEZ_COL in use_preds else []

    pre = ColumnTransformer(
        transformers=[
            ("num", make_pipeline(SimpleImputer(strategy="median")), num_features),
            ("cat", make_pipeline(SimpleImputer(strategy="most_frequent"),
                                  OneHotEncoder(handle_unknown="ignore", sparse_output=False)), cat_features),
        ],
        remainder="drop",
        verbose_feature_names_out=False,
    )

    rf = RandomForestClassifier(
        n_estimators=RF_TREES, max_features="sqrt", class_weight="balanced_subsample",
        n_jobs=-1, random_state=RANDOM_SEED, bootstrap=True,
    )
    pipe = Pipeline([("pre", pre), ("rf", rf)])

    # ---------- Group-aware OOF to fit calibrator ----------
    df["farm_id"] = assign_farm_ids(df, lon_col=LONCOL, lat_col=LATCOL)
    # 5 folds when TEST_FRACTION is ~0.2; otherwise roughly 1/TEST_FRACTION, never below 3.
    n_splits = 5 if 0.19 <= TEST_FRACTION <= 0.21 else max(3, int(round(1.0 / max(TEST_FRACTION, 1e-3))))
    sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_SEED)

    X_all = df[use_preds].copy()
    y_all = df[TARGET].astype(int).values
    g_all = df["farm_id"].values
    countries_all = df[COUNTRY_COL].astype(str).values
    lon_all = df[LONCOL].values
    lat_all = df[LATCOL].values

    # enforce dtypes
    for col in [f for f in use_preds if f != AEZ_COL]:
        X_all[col] = pd.to_numeric(X_all[col], errors="coerce").astype("float64")
    if AEZ_COL in use_preds:
        X_all[AEZ_COL] = pd.to_numeric(X_all[AEZ_COL], errors="coerce")
        X_all[AEZ_COL] = np.rint(X_all[AEZ_COL]).astype("float64")

    # OOF predictions: every row is scored by a model that did not see it.
    oof_raw = np.full(len(X_all), np.nan, dtype=float)
    classes_ = None
    for tr_idx, te_idx in sgkf.split(X_all, y_all, groups=g_all):
        tr_X = X_all.iloc[tr_idx]; tr_y = y_all[tr_idx]
        te_X = X_all.iloc[te_idx]
        pipe_fold = clone(pipe)
        pipe_fold.fit(tr_X, tr_y.astype(int))
        # NOTE(review): classes_ is captured from the first fold only and reused for
        # later folds; this assumes every fold's training split contains class 1.
        if classes_ is None:
            classes_ = pipe_fold.named_steps["rf"].classes_
        pos_idx_fold = int(np.where(classes_ == 1)[0][0])
        oof_raw[te_idx] = pipe_fold.predict_proba(te_X)[:, pos_idx_fold]

    # fit final model on all data
    pipe.fit(X_all, y_all.astype(int))
    classes_ = pipe.named_steps["rf"].classes_
    pos_idx = int(np.where(classes_ == 1)[0][0])

    # calibrate on OOF
    calib = _fit_calibrator(oof_raw, y_all, CALIBRATION_METHOD)
    oof_cal = _apply_calibrator(calib, oof_raw)

    # ---------- GLOBAL threshold from OOF calibrated ----------
    global_thr = _pick_threshold(y_all, oof_cal)

    # ---------- Regional OOF metrics & confusion ----------
    reg_metrics, reg_cm = _compute_basic_metrics(y_all, oof_cal, global_thr)
    reg_metrics.update({
        "region": region_name,
        "n_oof": int(len(y_all)),
        "threshold_strategy": THRESHOLD_STRATEGY,
        "global_threshold": float(global_thr),
        "calibration": CALIBRATION_METHOD,
    })

    # feature importances from final model
    out_names = list(pipe.named_steps["pre"].get_feature_names_out())
    rf_imp = pipe.named_steps["rf"].feature_importances_
    fi_df = pd.DataFrame({"feature": out_names, "importance": rf_imp}).sort_values("importance", ascending=False)

    # ---------- Prep folders ----------
    # reg_base / c_base_root are local paths in local mode and Drive folder ids otherwise.
    if drive is None:
        reg_base = os.path.join(LOCAL_MODEL_DIR, REGIONAL_MODELS_FOLDER, region_name)
        os.makedirs(reg_base, exist_ok=True)
        c_base_root = os.path.join(LOCAL_MODEL_DIR, COUNTRY_MODELS_FOLDER)
        os.makedirs(c_base_root, exist_ok=True)
    else:
        mt = get_or_create_folder(drive, ROOT_FOLDER_ID, MODEL_FOLDER)
        reg = get_or_create_folder(drive, mt, REGIONAL_MODELS_FOLDER)
        reg_base = get_or_create_folder(drive, reg, region_name)
        c_base_root = get_or_create_folder(drive, mt, COUNTRY_MODELS_FOLDER)

    # ---------- Ensure country tags (bbox fallback) ----------
    # Rows without a usable country tag are assigned by raster bounding box.
    # NOTE(review): unknown_mask is not updated inside the loop, so where bounding
    # boxes overlap the last matching country in `countries` wins.
    assigned = countries_all.copy()
    unknown_mask = (assigned == "Unknown") | (assigned == "") | pd.isna(assigned)
    if unknown_mask.any():
        for ctry in countries:
            try:
                ref = open_country_var_path(drive, ctry, "CLE") or open_country_var_path(drive, ctry, AEZ_COL) or open_aez_path(drive)
                if not ref: continue
                w, s, e, n = _clip_bounds_from_reference(ref)
                inside = unknown_mask & (lon_all >= min(w, e)) & (lon_all <= max(w, e)) \
                                   & (lat_all >= min(s, n)) & (lat_all <= max(s, n))
                assigned[inside] = ctry
            except Exception:
                # Best effort: a country whose reference raster cannot be read is skipped silently.
                pass
        countries_all = assigned

    # ---------- Per-country thresholds & OOF metrics ----------
    oof_df = pd.DataFrame({
        "country": countries_all.astype(str),
        "y": y_all.astype(int),
        "prob_cal": oof_cal.astype(float),
    })

    per_rows = []
    per_thr = {}

    for ctry in sorted(set(countries)):
        sub = oof_df[oof_df["country"] == ctry]
        n_total = int(len(sub))
        n_pos = int((sub["y"] == 1).sum())
        n_neg = n_total - n_pos

        # Too few samples (overall or per class): reuse the regional threshold.
        if (n_total < MIN_TOTAL_FOR_COUNTRY) or (n_pos < MIN_POS_NEG_FOR_COUNTRY) or (n_neg < MIN_POS_NEG_FOR_COUNTRY):
            thr_c = float(global_thr); note = "fallback_global"
        else:
            thr_c = _pick_threshold(sub["y"].values, sub["prob_cal"].values); note = "oof_calibrated"

        per_thr[ctry] = float(thr_c)
        m_c, cm_c = _compute_basic_metrics(sub["y"].values, sub["prob_cal"].values, thr_c)
        m_c.update({
            "region": region_name, "country": ctry, "n_oof": n_total,
            "n_pos": n_pos, "n_neg": n_neg, "threshold_used": float(thr_c),
            "note": note
        })
        per_rows.append(m_c)

        # save country artifacts (metrics + confusion + threshold.json)
        cdir = get_or_create_folder(drive, c_base_root, ctry)
        with tempfile.TemporaryDirectory() as tdir:
            with open(os.path.join(tdir, "test_metrics.json"), "w") as f:
                json.dump(m_c, f, indent=2)
            upload_path(drive, os.path.join(tdir, "test_metrics.json"), cdir, "test_metrics.json")

            png = os.path.join(tdir, "test_confusion.png")
            _save_confusion_png(cm_c, f"{ctry} — OOF Confusion (thr={thr_c:.3f})", png)
            upload_path(drive, png, cdir, "test_confusion.png")

            with open(os.path.join(tdir, "threshold.json"), "w") as f:
                json.dump({
                    "threshold_strategy": THRESHOLD_STRATEGY,
                    "mode": THRESHOLD_MODE,
                    "calibration": CALIBRATION_METHOD,
                    "threshold": float(thr_c),
                    "n_total": n_total, "n_pos": n_pos, "n_neg": n_neg,
                    "note": note
                }, f, indent=2)
            upload_path(drive, os.path.join(tdir, "threshold.json"), cdir, "threshold.json")

    # regional CSV summary for per-country
    if len(per_rows) > 0:
        per_df = pd.DataFrame(per_rows).sort_values("country")
    else:
        per_df = pd.DataFrame(columns=["country","n_oof","precision","recall","f1","roc_auc","pr_auc",
                                       "tn","fp","fn","tp","threshold_used","note"])

    # ---------- Save regional artifacts ----------
    with tempfile.TemporaryDirectory() as tdir:
        with open(os.path.join(tdir, "metrics.json"), "w") as f:
            json.dump(reg_metrics, f, indent=2)
        upload_path(drive, os.path.join(tdir, "metrics.json"), reg_base, "metrics.json")

        fi_df.to_csv(os.path.join(tdir, "feature_importance.csv"), index=False)
        upload_path(drive, os.path.join(tdir, "feature_importance.csv"), reg_base, "feature_importance.csv")

        per_df.to_csv(os.path.join(tdir, "per_country_test_metrics.csv"), index=False)
        upload_path(drive, os.path.join(tdir, "per_country_test_metrics.csv"), reg_base, "per_country_test_metrics.csv")

        png = os.path.join(tdir, "confusion.png")
        _save_confusion_png(reg_cm, f"{region_name} — OOF Confusion ({THRESHOLD_STRATEGY}, thr={global_thr:.3f})", png)
        upload_path(drive, png, reg_base, "confusion.png")

        # Persisted bundle: the fitted pipeline plus everything needed to score and threshold new data.
        joblib.dump(
            {
                "pipeline": pipe,
                "predictors": use_preds,
                "positive_class_index": int(pos_idx),
                "calibration": calib,
                "global_threshold": float(global_thr),
                "threshold_strategy": THRESHOLD_STRATEGY,
                "per_country_thresholds": per_thr,
                "num_features": [f for f in use_preds if f != AEZ_COL],
                "cat_features": [AEZ_COL] if AEZ_COL in use_preds else [],
            },
            os.path.join(tdir, "model.joblib"),
        )
        upload_path(drive, os.path.join(tdir, "model.joblib"), reg_base, "model.joblib")

    # In-memory handle returned to the caller (note: key names differ from the saved bundle).
    model = {
        "pipe": pipe,
        "use_preds": use_preds,
        "pos_idx": int(pos_idx),
        "calib": calib,
        "global_thr": float(global_thr),
        "per_thr": per_thr,
        "reg_base": reg_base
    }
    return model


# ----------------- PREDICT ONE COUNTRY ------------------
def predict_country(drive, country, model, thr, grid_pack):
    """Score one country tile by tile and upload three single-band GeoTIFFs.

    The country grid is walked in windows produced by ``_iter_tiles``. For each
    window every predictor raster is reprojected onto the window grid; pixels
    that are inside the cropland mask and have enough finite predictors are
    scored with the regional pipeline. Three rasters are written and then handed
    to ``upload_path``:

    * ``<country>_RF_probability_percent.tif`` - output of ``_apply_calibrator``
      multiplied by 100 (float32, nodata -9999).
    * ``<country>_RF_binary_0_1_cropland.tif`` - 1 where the probability is
      ``>= thr``, else 0 (uint8, nodata 255).
    * ``<country>_RF_predictors_count.tif`` - number of finite predictor values
      per pixel (uint8, nodata 0).

    Args:
        drive: PyDrive handle, or a falsy value to use ``LOCAL_OUTPUT_DIR``.
        country: Country name used to locate predictor rasters and name outputs.
        model: Dict providing ``"pipe"``, ``"pos_idx"``, ``"calib"`` and
            ``"use_preds"`` (the ordered predictor names).
        thr: Probability threshold applied to produce the binary map.
        grid_pack: ``(H, W, transform, crs, cropland)`` describing the target
            grid. ``cropland`` is assumed to be a boolean (H, W) mask - TODO
            confirm against ``country_grid_and_mask``.

    Returns:
        None. A one-line summary of the predicted positive fraction is printed.
    """
    if drive:
        out_dir = get_or_create_folder(drive, ROOT_FOLDER_ID, OUTPUT_FOLDER)
    else:
        out_dir = LOCAL_OUTPUT_DIR
        # Only the local backend needs a directory on disk; with Drive, out_dir
        # is whatever get_or_create_folder returned (presumably a folder id),
        # and creating a local directory with that name would just leave clutter.
        os.makedirs(out_dir, exist_ok=True)

    pipe = model["pipe"]
    pos_idx = model["pos_idx"]
    calib = model["calib"]
    predictors = model["use_preds"]
    n_pred = len(predictors)

    # Loop invariants: numeric predictor names and the minimum number of finite
    # predictors a pixel needs before it is scored.
    num_features = [f for f in predictors if f != AEZ_COL]
    req = n_pred if STRICT_ALL_BANDS else min(MIN_PREDICTORS, n_pred)

    H, W, transform, crs, cropland = grid_pack

    base_profile = {
        "driver": "GTiff", "height": H, "width": W, "count": 1, "crs": crs, "transform": transform,
        "compress": "LZW", "tiled": True, "blockxsize": 512, "blockysize": 512,
    }
    prob_profile = {**base_profile, "dtype": "float32", "nodata": -9999.0}
    bin_profile = {**base_profile, "dtype": "uint8", "nodata": 255}
    cnt_profile = {**base_profile, "dtype": "uint8", "nodata": 0}

    prob_name = f"{country}_RF_probability_percent.tif"
    bin_name = f"{country}_RF_binary_0_1_cropland.tif"
    cnt_name = f"{country}_RF_predictors_count.tif"

    eligible_total = 0
    positive_total = 0

    # Open raster handles keyed by predictor name. Everything that lands in this
    # dict is closed in the ``finally`` below, even if prediction fails midway.
    raster_cache = {}
    try:
        # AEZ: prefer per-country tile, fallback to global AEZ
        aez_src_path = open_country_var_path(drive, country, AEZ_COL) or open_aez_path(drive)
        if AEZ_COL in predictors and aez_src_path is not None:
            raster_cache[AEZ_COL] = rasterio.open(aez_src_path)

        for v in predictors:
            if v == AEZ_COL:
                continue
            p = open_country_var_path(drive, country, v)
            if p is not None:
                raster_cache[v] = rasterio.open(p)

        with tempfile.TemporaryDirectory() as tdir:
            prob_tmp = os.path.join(tdir, prob_name)
            bin_tmp = os.path.join(tdir, bin_name)
            cnt_tmp = os.path.join(tdir, cnt_name)
            with rasterio.open(prob_tmp, "w", **prob_profile) as dst_prob, \
                 rasterio.open(bin_tmp, "w", **bin_profile) as dst_bin, \
                 rasterio.open(cnt_tmp, "w", **cnt_profile) as dst_cnt:

                for win in tqdm(_iter_tiles(H, W, TILE_SIZE), total=_tile_count(H, W, TILE_SIZE), desc=f"Predict {country}"):
                    r0, r1 = int(win.row_off), int(win.row_off + win.height)
                    c0, c1 = int(win.col_off), int(win.col_off + win.width)
                    tile_h, tile_w = (r1 - r0), (c1 - c0)
                    # Affine of this window: the full-grid transform shifted to the window origin.
                    tile_transform = transform * Affine.translation(c0, r0)

                    crop_tile = cropland[r0:r1, c0:c1]
                    if not crop_tile.any():
                        # No cropland in this window: emit nodata and skip all reprojection work.
                        dst_prob.write(np.full((1, tile_h, tile_w), prob_profile["nodata"], np.float32), window=win)
                        dst_bin.write(np.full((1, tile_h, tile_w), bin_profile["nodata"], np.uint8), window=win)
                        dst_cnt.write(np.zeros((1, tile_h, tile_w), np.uint8), window=win)
                        continue

                    stack = np.full((n_pred, tile_h, tile_w), np.nan, dtype=np.float32)
                    valid_count = np.zeros((tile_h, tile_w), dtype=np.uint8)

                    for j, v in enumerate(predictors):
                        src = raster_cache.get(v)
                        if src is None:
                            # Not opened up front: try again lazily; leave the band as NaN if still unavailable.
                            p = open_country_var_path(drive, country, v) if v != AEZ_COL else (open_country_var_path(drive, country, AEZ_COL) or open_aez_path(drive))
                            if p:
                                src = raster_cache[v] = rasterio.open(p)
                            else:
                                continue
                        dst_arr = np.full((tile_h, tile_w), np.nan, dtype=np.float32)
                        dst_nd = -9999.0
                        resamp = PER_BAND_RESAMPLING.get(v, RESAMPLING)
                        try:
                            reproject(source=rasterio.band(src, 1), destination=dst_arr,
                                      src_transform=src.transform, src_crs=src.crs,
                                      dst_transform=tile_transform, dst_crs=crs,
                                      dst_nodata=dst_nd, resampling=resamp)
                        except Exception as e:
                            print(f"      ! reprojection failed for {v}: {e}")
                            continue
                        # Map both the source nodata value and the reprojection fill value to NaN.
                        if src.nodata is not None:
                            dst_arr = np.where(np.isclose(dst_arr, src.nodata), np.nan, dst_arr)
                        dst_arr = np.where(np.isclose(dst_arr, dst_nd), np.nan, dst_arr)
                        stack[j] = dst_arr
                        valid_count += np.isfinite(dst_arr)

                    good = crop_tile & (valid_count >= req)
                    prob_tile = np.full((tile_h, tile_w), np.nan, dtype=np.float32)

                    if good.any():
                        # One row per pixel, one column per predictor, in model order.
                        X = stack.reshape(n_pred, -1).T
                        good_flat = good.reshape(-1)
                        X_good = X[good_flat].astype("float64")
                        X_good_df = pd.DataFrame(X_good, columns=predictors)

                        # enforce dtypes
                        for col in num_features:
                            X_good_df[col] = pd.to_numeric(X_good_df[col], errors="coerce").astype("float64")
                        if AEZ_COL in X_good_df.columns:
                            # AEZ is a class code: round resampled values back to whole numbers.
                            X_good_df[AEZ_COL] = pd.to_numeric(X_good_df[AEZ_COL], errors="coerce")
                            X_good_df[AEZ_COL] = np.rint(X_good_df[AEZ_COL]).astype("float64")

                        prob_raw = pipe.predict_proba(X_good_df)[:, pos_idx].astype(np.float32)
                        prob_vals = _apply_calibrator(calib, prob_raw).astype(np.float32)

                        bin_vals = (prob_vals >= thr).astype(np.uint8)
                        prob_tile.reshape(-1)[good_flat] = prob_vals

                        mask_arr = np.isfinite(prob_tile)
                        prob_out = np.full((1, tile_h, tile_w), prob_profile["nodata"], np.float32)
                        prob_out[0][mask_arr] = (prob_tile[mask_arr] * 100.0).astype(np.float32)
                        dst_prob.write(prob_out, window=win)

                        bin_out = np.full((1, tile_h, tile_w), bin_profile["nodata"], np.uint8)
                        bin_out[0].reshape(-1)[good_flat] = bin_vals
                        dst_bin.write(bin_out, window=win)

                        eligible_total += int(mask_arr.sum())
                        positive_total += int(bin_vals.sum())
                    else:
                        dst_prob.write(np.full((1, tile_h, tile_w), prob_profile["nodata"], np.float32), window=win)
                        dst_bin.write(np.full((1, tile_h, tile_w), bin_profile["nodata"], np.uint8), window=win)

                    dst_cnt.write(valid_count[np.newaxis], window=win)

            # Upload after the writers are closed but while the temp directory still exists.
            upload_path(drive, prob_tmp, out_dir, prob_name)
            upload_path(drive, bin_tmp, out_dir, bin_name)
            upload_path(drive, cnt_tmp, out_dir, cnt_name)
    finally:
        # Release every predictor raster handle, including lazily opened ones.
        for name, src in raster_cache.items():
            try:
                src.close()
            except Exception as e:
                print(f"      ! could not close raster for {name}: {e}")

    frac = (positive_total / max(1, eligible_total)) * 100.0
    print(f"   • predicted irrigated fraction inside cropland: {frac:.3f}% ({positive_total:,}/{eligible_total:,})")

# ----------------------------- RUNNER -------------------
def run_region_then_countries(use_drive=True, countries=None, region_name="US"):
    """Train one regional model, then write prediction maps country by country.

    Args:
        use_drive: When true, use the notebook-level PyDrive handle ``drive``;
            otherwise work locally and make sure ``LOCAL_OUTPUT_DIR`` exists.
        countries: Country names to predict. A falsy value selects the built-in
            default list.
        region_name: Name passed to ``train_region_model`` and shown in the log.

    Raises:
        RuntimeError: If ``use_drive`` is true but no global ``drive`` exists.
    """
    # ---- storage backend ----
    if not use_drive:
        _drive = None
        os.makedirs(LOCAL_OUTPUT_DIR, exist_ok=True)
    else:
        try:
            _drive = drive  # noqa: F821
        except NameError:
            raise RuntimeError("PyDrive 'drive' not found. Authenticate and expose 'drive', or set use_drive=False.")
        # Best effort: lengthen the HTTP timeout if the handle exposes one.
        try:
            _drive.auth.service.http.timeout = 120
        except Exception:
            pass

    # ---- training points: coerce numeric columns, drop unusable rows ----
    points = load_points_df(_drive)
    numeric_cols = [TARGET, LONCOL, LATCOL]
    numeric_cols += [p for p in ALL_PREDICTORS if p in points.columns and p != AEZ_COL]
    for col in numeric_cols:
        if col not in points.columns:
            continue
        points[col] = pd.to_numeric(points[col], errors="coerce")
    points = points.dropna(subset=[TARGET, LONCOL, LATCOL]).reset_index(drop=True)
    points[TARGET] = points[TARGET].astype(int)

    # ---- default country list (Niger, Cote_d_Ivoire and Mali are currently disabled) ----
    if not countries:
        countries = [
            "Sudan",
            "Chad",
            "Central_African_Republic",
            "Cameroon",
            "Ethiopia",
        ]

    model = train_region_model(_drive, region_name, points, countries)

    # ---- per-country prediction; one failing country does not stop the others ----
    print(f"=== PREDICTION per-country (threshold mode: {THRESHOLD_MODE}) ===")
    for name in countries:
        try:
            height, width, grid_transform, grid_crs, crop_mask, _ = country_grid_and_mask(_drive, name)
            if THRESHOLD_MODE in ("country", "hybrid"):
                # Country-specific threshold when available, else the global one.
                thr_c = model["per_thr"].get(name, model["global_thr"])
            else:
                thr_c = model["global_thr"]
            grid_pack = (height, width, grid_transform, grid_crs, crop_mask)
            predict_country(_drive, name, model, thr_c, grid_pack)
            print(f"✓ {name}: maps → {OUTPUT_FOLDER}, model → {MODEL_FOLDER}/{REGIONAL_MODELS_FOLDER}/{region_name}, thr={thr_c:.3f}")
        except Exception as e:
            print(f"❌ {name}: {e}")

    print("✅ Done.")

# Entry point: train the "Africa" regional model and predict the default
# country list (countries=None) through the Drive backend.
# NOTE(review): in a notebook kernel __name__ is "__main__", so executing this
# cell starts the full train + predict run immediately.
if __name__ == "__main__":
    run_region_then_countries(use_drive=True, countries=None, region_name="Africa")


Using vector points (shp): Africa.shp
   • AEZ filled from cache for 265/15126 points
=== Train regional model: Africa ===
   • AEZ sampled from tiles/global for 14861 points
