In [None]:
from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive

# Local path to the OAuth client-secrets JSON file.
CLIENT_JSON = "client_secret_Final.json"   # Google Drive API client config
gauth = GoogleAuth()
gauth.LoadClientConfigFile(CLIENT_JSON)

# Run the command-line authentication flow.
gauth.CommandLineAuth()

# Save the resulting credentials to token.json.
gauth.SaveCredentialsFile('token.json')

# 'drive' is the client object the later cells look up by name.
drive = GoogleDrive(gauth)
print("‚úÖ Google Drive connected")

In [None]:
"""
The default GEE tif files are mutiple tiles per country
Thsi code help to mosaic multiple tif rasters for a given country to have a single file per each varaiable for a country
"""
import os
import re
import shutil
import tempfile
import unicodedata
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import rasterio
from rasterio.io import MemoryFile
from rasterio.shutil import copy as rio_copy
from rasterio.warp import reproject, Resampling as WarpResampling
from tqdm.auto import tqdm

# -------------------------------------------------------------------
# Google Drive setup
# -------------------------------------------------------------------
# NOTE: 'drive' must be an authenticated PyDrive2 client in your session.
ROOT_FOLDER_ID = "18pQKnMMnLramhHRZSNwUJrLqG5DXNMmS"  # folder id from Google Drive

# Variable names recognised in tile file names; one mosaic is produced per
# (province, variable) pair.
VARIABLES = [
    "NDVI_mean", "NDVI_max", "NDVI_min",
    "NDWI_mean", "NDWI_max", "NDWI_min",
    "GI_mean", "GI_max", "GI_min",
    "elevation", "slope",
    "ET", "PET",

]

# Provinces to process. parse_title() splits each entry on '_' and matches the
# pieces against the leading tokens of a tile file name.
PROVINCES = [
    "Philippines",
    # ...
]

# Optional synonyms; normalize_country() compares the keys case-insensitively
# and substitutes the mapped value.
COUNTRY_SYNONYMS = {
    "Philippines": "philippines",
}

# -------------------------------------------------------------------
# Helpers
# -------------------------------------------------------------------
# Lower-cased variable name -> canonical spelling from VARIABLES.
VARIABLE_TOKENS = {v.lower(): v for v in VARIABLES}  # canonicalize


def _squash_spaces(s: str) -> str:
    return re.sub(r"\s+", " ", s).strip()


def normalize_country(name: str) -> str:
    """Return a display form of *name*.

    Underscores become spaces and whitespace runs are squashed. If the result
    matches a COUNTRY_SYNONYMS key (compared case-insensitively), the mapped
    value is returned instead. A falsy *name* yields "".
    """
    if not name:
        return ""
    cleaned = _squash_spaces(name.strip().replace("_", " "))
    synonyms_by_lower = {key.lower(): value for key, value in COUNTRY_SYNONYMS.items()}
    return synonyms_by_lower.get(cleaned.lower(), cleaned)


def _safe_basename(s: str) -> str:
    """ASCII slug for filenames."""
    norm = unicodedata.normalize("NFKD", s)
    ascii_only = norm.encode("ascii", "ignore").decode("ascii")
    slug = re.sub(r"[^A-Za-z0-9]+", "_", ascii_only).strip("_")
    return slug or "untitled"


def _drive_q_escape(s: str) -> str:
    return s.replace("\\", "\\\\").replace("'", "\\'")


# -------------------------------------------------------------------
# Drive helpers
# -------------------------------------------------------------------
def get_or_create_folder(drive, parent_id, name):
    """Return the id of the folder titled *name* directly under *parent_id*.

    The first listed match is reused; when the listing is empty a new folder
    is created and its id returned.
    """
    folder_mime = 'application/vnd.google-apps.folder'
    escaped = _drive_q_escape(name)
    query = (
        f"'{parent_id}' in parents and trashed=false and "
        f"mimeType='{folder_mime}' and title='{escaped}'"
    )
    matches = drive.ListFile({'q': query}).GetList()
    if not matches:
        folder = drive.CreateFile({
            'title': name,
            'parents': [{'id': parent_id}],
            'mimeType': folder_mime,
        })
        folder.Upload()
        return folder['id']
    return matches[0]['id']


def list_all_tifs_recursive(drive, root_id):
    """Walk the folder tree below *root_id* and collect every non-trashed file
    whose title ends in .tif or .tiff (case-insensitive).

    Returns the Drive file objects in the order they were encountered.
    """
    found = []
    pending = [root_id]
    while pending:
        current = pending.pop()
        listing = drive.ListFile({'q': f"'{current}' in parents and trashed=false"}).GetList()
        for entry in listing:
            if entry.get('mimeType', '') == 'application/vnd.google-apps.folder':
                # Sub-folder: queue it for a later visit.
                pending.append(entry['id'])
                continue
            name = entry.get('title', '')
            if name and name.lower().endswith(('.tif', '.tiff')):
                found.append(entry)
    return found


def download_many(drive, files, dst_dir, max_workers=8):
    """Fetch every Drive file in *files* into *dst_dir* using a thread pool.

    Each file is saved under its Drive title. Returns the local paths in
    completion order (not input order). *drive* is accepted but not used.
    """

    def _fetch(item):
        target = os.path.join(dst_dir, item['title'])
        item.GetContentFile(target)
        return target

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(_fetch, item) for item in files]
        progress = tqdm(as_completed(pending), total=len(pending), desc="Downloading")
        # .result() re-raises any exception raised inside a worker.
        return [finished.result() for finished in progress]


def upload_tif(drive, local_path, parent_id, title):
    """Upload *local_path* into Drive folder *parent_id* as *title* and return
    the id of the created file."""
    metadata = {'title': title, 'parents': [{'id': parent_id}]}
    remote = drive.CreateFile(metadata)
    remote.SetContentFile(local_path)
    remote.Upload()
    return remote['id']


# -------------------------------------------------------------------
# Parsing
# -------------------------------------------------------------------
def _filename_core(title: str) -> str:
    """
    Take a Drive title and return the 'logical' base name:

    e.g.
    'New_York_GI_max-0000000000-0000000000.tif' -> 'New_York_GI_max'
    """
    base = os.path.splitext(os.path.basename(title))[0]
    # Drop anything after the first '-' (tile ID, date, etc.)
    if '-' in base:
        base = base.split('-', 1)[0]
    return base


def _split_tokens(path_title: str):
    """
    Return the non-empty '_'-separated tokens of the title's core name.

    'New_York_GI_max-0000000000-0000000000.tif'
      -> core 'New_York_GI_max'
      -> ['New', 'York', 'GI', 'max']
    """
    pieces = _filename_core(path_title).split('_')
    # filter(None, ...) drops the empty strings produced by doubled underscores.
    return list(filter(None, pieces))


def parse_title(title: str):
    """
    Extract (province, canonical_variable) from a tile file name.

    The province is the PROVINCES entry with the most '_' tokens that matches
    the leading tokens of the name (case-insensitive). The variable is the
    longest VARIABLE_TOKENS match of 1-3 tokens starting within the four
    token positions that follow the province. Returns (None, None) when the
    title is not a .tif/.tiff or either part cannot be identified.
    """
    no_match = (None, None)
    if not title.lower().endswith(('.tif', '.tiff')):
        return no_match

    tokens = [tok.lower() for tok in _split_tokens(title)]
    if not tokens:
        return no_match
    n_tokens = len(tokens)

    # Longest province prefix wins; on equal length the earlier entry is kept.
    country, n_country = None, 0
    for province in PROVINCES:
        prefix = [part.lower() for part in province.split('_')]
        size = len(prefix)
        if size > n_tokens or size <= n_country:
            continue
        if tokens[:size] == prefix:
            country, n_country = province, size
    if not country:
        return no_match

    # Scan a small window after the province for the longest variable name.
    chosen, chosen_span = None, -1
    for start in range(n_country, min(n_country + 4, n_tokens)):
        for stop in range(start, min(start + 3, n_tokens)):
            key = '_'.join(tokens[start:stop + 1])
            span = stop - start
            if key in VARIABLE_TOKENS and span > chosen_span:
                chosen, chosen_span = VARIABLE_TOKENS[key], span
    if chosen is None:
        return no_match

    return country, chosen


# -------------------------------------------------------------------
# Size helpers / dtype
# -------------------------------------------------------------------
def _bytes_per_pixel(dtype):
    return {
        "uint8": 1, "int8": 1,
        "uint16": 2, "int16": 2,
        "uint32": 4, "int32": 4,
        "float32": 4, "float64": 8
    }.get(dtype, 4)


def _should_bigtiff(width, height, count, dtype):
    """Return True when the estimated raw raster size (plus a 5% margin)
    reaches 4 GiB."""
    four_gib = 4 * 1024 ** 3
    estimated_bytes = width * height * count * _bytes_per_pixel(dtype) * 1.05
    return estimated_bytes >= four_gib


# -------------------------------------------------------------------
# Grid/reference helpers
# -------------------------------------------------------------------
def _pick_reference(datasets):
    # Choose the most frequent (xres, yres); break ties by largest pixel area
    reslist = [(ds.transform.a, -ds.transform.e) for ds in datasets]
    counts = Counter(reslist)
    best_res = max(counts.items(), key=lambda kv: (kv[1], kv[0][0] * kv[0][1]))[0]
    for ds in datasets:
        if (ds.transform.a, -ds.transform.e) == best_res:
            return ds
    return datasets[0]


# -------------------------------------------------------------------
# COG writer
# -------------------------------------------------------------------
def write_cog_from_array(mosaic, profile, cog_path):
    """Write *mosaic* to *cog_path* using the COG driver.

    The array is first written to an in-memory tiled GTiff built from
    *profile*, which is then copied to disk with the COG driver.
    BIGTIFF is forced to "YES" when _should_bigtiff() says so, otherwise
    "IF_SAFER" is used.
    """
    needs_bigtiff = _should_bigtiff(
        width=profile["width"],
        height=profile["height"],
        count=profile["count"],
        dtype=profile["dtype"],
    )
    bigtiff_flag = "YES" if needs_bigtiff else "IF_SAFER"

    # Profile for the intermediate in-memory GeoTIFF.
    staging = profile.copy()
    staging["driver"] = "GTiff"
    staging["tiled"] = True
    staging["blockxsize"] = profile.get("blockxsize", 512)
    staging["blockysize"] = profile.get("blockysize", 512)
    staging["compress"] = profile.get("compress", "LZW")
    staging["predictor"] = profile.get("predictor", 3)
    staging["interleave"] = "band"
    staging["BIGTIFF"] = bigtiff_flag

    # Options handed to the COG driver on copy.
    cog_options = {
        "COMPRESS": "LZW",
        "PREDICTOR": staging["predictor"],
        "BLOCKSIZE": 512,
        "OVERVIEW_LEVELS": "2,4,8,16",
        "OVERVIEW_RESAMPLING": "AVERAGE",
        "NUM_THREADS": "ALL_CPUS",
        "BIGTIFF": bigtiff_flag,
        "RESAMPLING": "NEAREST",
        "DST_NODATA": profile.get("nodata", None),
    }

    with MemoryFile() as buffer:
        with buffer.open(**staging) as writer:
            writer.write(mosaic)
        with buffer.open() as reader:
            rio_copy(reader, cog_path, driver="COG", **cog_options)


# -------------------------------------------------------------------
# Mosaic core
# -------------------------------------------------------------------
def _reducer_for_variable(var: str):
    """Pick a per-pixel reducer."""
    vl = var.lower()
    if vl.endswith("_min"):
        return np.fmin
    # default to fmax for max/mean/masks, etc.
    return np.fmax


def mosaic_files_to_array(datasets, reducer, default_nodata=-9999.0):
    """
    Mosaic single-band rasters to float32 using NaN as working nodata.
    Handles mixed CRS and pixel sizes (aligned to common grid).

    Args:
        datasets: open rasterio datasets; band 1 of each is read.
        reducer: element-wise binary function used to merge a tile into the
            accumulator (the caller passes np.fmin or np.fmax).
        default_nodata: value written where no tile contributed a value.

    Returns:
        (mosaic, dst_transform, common_crs, nodata) where mosaic has shape
        (1, height, width) and dtype float32.

    Raises:
        RuntimeError: if no dataset has a CRS.
    """
    # Common CRS: take the first non-None
    common_crs = next((ds.crs for ds in datasets if ds.crs is not None), None)
    if common_crs is None:
        raise RuntimeError(
            "None of the input tiles has a CRS defined; cannot reproject to a common grid."
        )

    # The reference dataset only supplies the output pixel size.
    ref = _pick_reference(datasets)
    ref_transform = ref.transform
    ref_res = (ref_transform.a, -ref_transform.e)

    # Union bounds in common CRS
    from rasterio.warp import transform_bounds
    minx = miny = float("inf")
    maxx = maxy = float("-inf")
    for ds in datasets:
        b = ds.bounds
        # Datasets without a CRS keep their bounds as-is, i.e. they are
        # treated as if already expressed in common_crs.
        if ds.crs is not None and ds.crs != common_crs:
            b = transform_bounds(ds.crs, common_crs, *b, densify_pts=21)
        minx, miny = min(minx, b[0]), min(miny, b[1])
        maxx, maxy = max(maxx, b[2]), max(maxy, b[3])

    # Destination grid at the reference resolution
    from rasterio.transform import from_origin
    dst_transform = from_origin(minx, maxy, ref_res[0], ref_res[1])
    dst_w = int(np.ceil((maxx - minx) / ref_res[0]))
    dst_h = int(np.ceil((maxy - miny) / ref_res[1]))

    # Accumulator starts as all-NaN (= no data anywhere yet).
    acc = np.full((dst_h, dst_w), np.nan, dtype=np.float32)

    for ds in datasets:
        # One full-size scratch array is allocated per tile.
        temp = np.full((dst_h, dst_w), np.nan, dtype=np.float32)
        src_nodata = ds.nodata
        reproject(
            source=rasterio.band(ds, 1),
            destination=temp,
            src_transform=ds.transform,
            src_crs=(ds.crs if ds.crs is not None else common_crs),
            dst_transform=dst_transform,
            dst_crs=common_crs,
            src_nodata=src_nodata,
            dst_nodata=np.nan,
            resampling=WarpResampling.nearest,
        )
        # While nothing has been accumulated yet, take the tile as-is.
        if np.isnan(acc).all():
            acc = temp
        else:
            acc = reducer(acc, temp)  # NaN-aware fmin/fmax

    # Replace the NaN working nodata with default_nodata and add a band axis.
    mosaic = np.where(np.isnan(acc), default_nodata, acc).astype(np.float32)[None, ...]
    return mosaic, dst_transform, common_crs, float(default_nodata)


# -------------------------------------------------------------------
# Country + variable mosaicking and upload
# -------------------------------------------------------------------
def mosaic_country_variable_to_drive(drive, parent_id, country, variable, by_country_id):
    """
    Find all tiles for (country, variable), mosaic, write COG, upload.

    Args:
        drive: authenticated PyDrive2 client.
        parent_id: Drive folder id searched recursively for tiles.
        country: entry of PROVINCES to match against parsed tile titles.
        variable: variable name; compared case-insensitively.
        by_country_id: Drive folder id under which the per-country output
            folder is looked up or created.

    Prints a notice and returns without uploading when no tile matches.

    Raises:
        RuntimeError: if a downloaded tile has more than one band.
    """
    tifs = list_all_tifs_recursive(drive, parent_id)

    v_norm = variable.lower()
    selected = []
    for it in tifs:
        ctry, var = parse_title(it.get('title', ''))
        if ctry is None:
            continue
        if ctry == country and var.lower() == v_norm:
            selected.append(it)

    if not selected:
        print(f"‚ö†Ô∏è No files for {country} / {variable}")
        return

    print(f"\n‚ñ∂Ô∏è {country} / {variable} | {len(selected)} file(s)")

    tmpdir = tempfile.mkdtemp(
        prefix=f"{_safe_basename(country)}_{_safe_basename(variable)}_"
    )
    datasets = []
    try:
        # Download tiles
        local_paths = download_many(drive, selected, tmpdir, max_workers=8)

        # Open datasets (require single band)
        for p in local_paths:
            ds = rasterio.open(p)
            # Register the handle before validating it so that the `finally`
            # block closes it even when the band check below raises.
            datasets.append(ds)
            if ds.count != 1:
                raise RuntimeError(
                    f"Only single-band rasters supported. "
                    f"{os.path.basename(p)} has {ds.count} bands."
                )

        reducer = _reducer_for_variable(variable)

        # Mosaic to float32
        mosaic, out_transform, out_crs, out_nodata = mosaic_files_to_array(
            datasets, reducer=reducer
        )

        profile = {
            "dtype": "float32",
            "height": mosaic.shape[1],
            "width": mosaic.shape[2],
            "count": 1,
            "transform": out_transform,
            "crs": out_crs,
            "nodata": out_nodata,
            "blockxsize": 512,
            "blockysize": 512,
            "compress": "LZW",
            "predictor": 3,
        }

        # Output folder + filename
        display_country = normalize_country(country)
        country_folder_id = get_or_create_folder(drive, by_country_id, display_country)

        safe_country = _safe_basename(display_country)
        safe_variable = _safe_basename(variable)
        out_name = f"{safe_country}_{safe_variable}.tif"
        cog_local = os.path.join(tmpdir, out_name)

        print("   ‚Ä¢ writing COG ‚Ä¶")
        write_cog_from_array(mosaic, profile, cog_local)

        print(f"   ‚Ä¢ uploading to Drive as {out_name} ‚Ä¶")
        out_id = upload_tif(drive, cog_local, country_folder_id, out_name)
        print(f"‚úÖ Uploaded: {out_name} (file id: {out_id})")

    finally:
        # Best-effort cleanup: a failing close() must not mask the real error
        # or prevent the temp directory from being removed.
        for ds in datasets:
            try:
                ds.close()
            except Exception:
                pass
        shutil.rmtree(tmpdir, ignore_errors=True)


# -------------------------------------------------------------------
# Batch runner
# -------------------------------------------------------------------
def run_all_provinces_variables(drive):
    """Mosaic and upload every (province, variable) combination.

    Output folders are created under "By Country" first. A failure for one
    combination is printed and does not stop the remaining ones.
    """
    by_country_id = get_or_create_folder(drive, ROOT_FOLDER_ID, "By Country")

    # Ensure country folders exist
    for province in PROVINCES:
        get_or_create_folder(drive, by_country_id, normalize_country(province))

    jobs = [(province, variable) for province in PROVINCES for variable in VARIABLES]
    for province, variable in jobs:
        try:
            mosaic_country_variable_to_drive(
                drive, ROOT_FOLDER_ID, province, variable, by_country_id
            )
        except Exception as err:
            print(f"‚ùå {province} / {variable}: {err}")


# -------------------------------------------------------------------
# Execute
# -------------------------------------------------------------------
# Entry point: relies on a 'drive' client already defined in the session.
if __name__ == "__main__":
    run_all_provinces_variables(drive)
    print("\nüéâ Done.")


In [None]:
"""
This script precomputes Agro-Ecological Zone (AEZ) data by:
1) generating per-state/country AEZ raster tiles aligned to existing state grids, and
2) extracting the corresponding AEZ pixel values for each ground-truth training point.
The AEZ value is added as a new column in the CSV, alongside the existing predictor values,
for each corresponding GTPS.
"""

import os, re, json, math, tempfile, random, time
import numpy as np
import pandas as pd
from collections import deque
import rasterio
from rasterio.enums import Resampling as RioResampling
from rasterio.warp import reproject, transform as rio_transform, transform_bounds as rio_transform_bounds
from rasterio.transform import Affine
from rasterio.crs import CRS

try:
    from googleapiclient.errors import HttpError
except Exception:
    HttpError = Exception
try:
    from pydrive2.files import ApiRequestError
except Exception:
    ApiRequestError = Exception

# ---------- reuse your config ----------
# Values read through os.environ.get can be overridden by setting the
# corresponding environment variable before this cell runs.
ROOT_FOLDER_ID = os.environ.get("ROOT_FOLDER_ID", "1hqMIyDYEFKnpS8KLxC4bqHmF_9dHXImG")
BY_COUNTRY_NAME = "By Country"
MODEL_FOLDER = "Model Training"
# Local directory layout used when running without Drive (drive is None).
LOCAL_BASE_DIR = os.path.join("./", "US")
LOCAL_BY_COUNTRY_DIR = os.path.join(LOCAL_BASE_DIR, "By Country")
LOCAL_MODEL_DIR = os.path.join(LOCAL_BASE_DIR, MODEL_FOLDER)

# Drive id and file name of the AEZ raster. open_aez_path() also accepts a
# folder id here and then searches it for AEZ_FILE_NAME.
AEZ_FILE_ID   = os.environ.get("AEZ_FILE_ID", "1te3nKn8vyt2AECmk8NM_xBgpEFXp84RX")
AEZ_FILE_NAME = os.environ.get("AEZ_FILE_NAME", "AEZ_2020s.tif")

# Base name of the points file and the column names used in it.
POINTS_BASENAME = "US"
LONCOL = "longitude"; LATCOL = "latitude"
AEZ_COL = "AEZ"

# Retry policy for Drive calls: number of attempts and base delay in seconds
# for the exponential back-off in _drive_retry().
DRIVE_MAX_RETRIES = int(os.environ.get("DRIVE_MAX_RETRIES", 6))
DRIVE_RETRY_BASE  = float(os.environ.get("DRIVE_RETRY_BASE", 0.8))

# ---------- Drive helpers (same behavior as your main file) ----------
def _drive_retry(callable_fn, *args, **kwargs):
    """Call ``callable_fn(*args, **kwargs)``, retrying transient Drive errors.

    An HttpError/ApiRequestError counts as transient when its HTTP status is
    5xx, 403 or 429, or (when no status is available) when its message
    mentions an internal/backend error, a rate limit or a timeout. Transient
    errors are retried with exponential back-off plus jitter; any other
    exception, or the last failed attempt, is re-raised unchanged.

    At least one attempt is always made, even if DRIVE_MAX_RETRIES is
    configured as zero or negative.
    """
    attempts = max(1, DRIVE_MAX_RETRIES)
    for i in range(attempts):
        try:
            return callable_fn(*args, **kwargs)
        except (HttpError, ApiRequestError) as e:
            code = getattr(getattr(e, "resp", None), "status", None)
            if code is None:
                msg = str(e).lower()
                transient = any(
                    k in msg
                    for k in ("internal error", "backenderror", "rate limit", "timeout")
                )
            else:
                transient = 500 <= int(code) < 600 or int(code) in (403, 429)
            if not transient or i == attempts - 1:
                raise
            time.sleep(DRIVE_RETRY_BASE * (2 ** i) + random.random() * 0.2)

def _drive_query(drive, q):
    """Run Drive list query *q* (all drives included, up to 1000 results per
    request) through the retry wrapper and return the resulting list."""
    params = {
        "q": q,
        "supportsAllDrives": True,
        "includeItemsFromAllDrives": True,
        "maxResults": 1000,
    }
    # A fresh dict is handed to ListFile on every attempt.
    return _drive_retry(lambda: drive.ListFile(dict(params)).GetList())

def get_subfolder(drive, parent_id, name):
    """Return the sub-folder *name* of *parent_id*, or None if it is missing.

    With ``drive is None`` the lookup is local: the path
    ``LOCAL_BASE_DIR/name`` is returned when it is a directory (*parent_id*
    is not used in that mode). Otherwise the id of the first matching Drive
    folder is returned.
    """
    if drive is None:
        p = os.path.join(LOCAL_BASE_DIR, name)
        return p if os.path.isdir(p) else None
    # Escape backslashes and quotes so names such as "Cote d'Ivoire" do not
    # break the single-quoted query value.
    safe = name.replace("\\", "\\\\").replace("'", "\\'")
    q = (
        f"'{parent_id}' in parents and trashed=false and "
        f"mimeType='application/vnd.google-apps.folder' and title='{safe}'"
    )
    r = _drive_query(drive, q)
    return r[0]["id"] if r else None

def get_or_create_folder(drive, parent_id, name):
    """Return the folder *name* under *parent_id*, creating it when absent.

    With ``drive is None`` a local directory is created (under *parent_id*
    when that is an existing absolute directory, otherwise under
    LOCAL_BASE_DIR) and its path returned. Otherwise the Drive folder id is
    returned; creation goes through the retry wrapper.
    """
    if drive is None:
        base = LOCAL_BASE_DIR if not os.path.isabs(parent_id) else parent_id
        p = os.path.join(base, name) if os.path.isdir(base) else os.path.join(LOCAL_BASE_DIR, name)
        os.makedirs(p, exist_ok=True)
        return p
    # Escape backslashes and quotes so the name cannot break the query.
    safe = name.replace("\\", "\\\\").replace("'", "\\'")
    q = (
        f"'{parent_id}' in parents and trashed=false and "
        f"mimeType='application/vnd.google-apps.folder' and title='{safe}'"
    )
    res = _drive_query(drive, q)
    if res:
        return res[0]["id"]

    def _create():
        f = drive.CreateFile({
            "title": name,
            "parents": [{"id": parent_id}],
            "mimeType": "application/vnd.google-apps.folder",
        })
        f.Upload()
        return f["id"]

    return _drive_retry(_create)

def list_files(drive, parent_id):
    """List the children of *parent_id*.

    Local mode (``drive is None``): sorted full paths of the directory
    entries, or [] when *parent_id* is not a directory. Drive mode: the
    non-trashed children; any error during the query yields [].
    """
    if drive is not None:
        q = f"'{parent_id}' in parents and trashed=false"
        try:
            return _drive_query(drive, q)
        except Exception:
            return []
    if not os.path.isdir(parent_id):
        return []
    return sorted(os.path.join(parent_id, entry) for entry in os.listdir(parent_id))

def download_to_temp(drive_file, dst_path):
    """Download *drive_file* to *dst_path* through the retry wrapper and
    return *dst_path*."""

    def _fetch():
        drive_file.GetContentFile(dst_path)
        return dst_path

    return _drive_retry(_fetch)

def upload_path(drive, local_path, parent_id, title=None):
    """Store *local_path* under *parent_id* as *title* (default: its basename).

    Local mode (``drive is None``): *parent_id* is a directory that is created
    if needed; the file is copied there and the destination path returned.
    Drive mode: the file is uploaded with retries and its Drive id returned.
    """
    remote_name = title or os.path.basename(local_path)

    if drive is not None:
        def _push():
            handle = drive.CreateFile({"title": remote_name, "parents": [{"id": parent_id}]})
            handle.SetContentFile(local_path)
            handle.Upload()
            return handle["id"]

        return _drive_retry(_push)

    import shutil
    os.makedirs(parent_id, exist_ok=True)
    target = os.path.join(parent_id, remote_name)
    shutil.copy2(local_path, target)
    return target

def _drive_walk(drive, start_id, max_depth=4):
    """Breadth-first walk of the folder tree rooted at *start_id*.

    Yields ``(folder_id, items)`` for each visited folder. Sub-folders are
    only queued while the current folder's depth is below *max_depth*.
    """
    frontier = deque([(start_id, 0)])
    while frontier:
        folder_id, depth = frontier.popleft()
        entries = list_files(drive, folder_id)
        yield folder_id, entries
        if depth < max_depth:
            frontier.extend(
                (entry["id"], depth + 1)
                for entry in entries
                if isinstance(entry, dict)
                and entry.get("mimeType") == "application/vnd.google-apps.folder"
            )

# ---------- AEZ and points IO ----------
def open_aez_path(drive):
    """Return a local path to the AEZ raster, or None if it cannot be found.

    Local mode (``drive is None``): ``LOCAL_BASE_DIR/AEZ/AEZ_FILE_NAME`` if it
    exists. Drive mode: AEZ_FILE_ID is resolved (following a shortcut); a
    plain file is downloaded to a temp file, a folder is searched for
    AEZ_FILE_NAME (falling back to any title containing '.tif'). When
    AEZ_FILE_ID is empty the AEZ_FOLDER_ID environment variable is used as
    the folder, and as a last resort the local path is tried.

    The caller owns (and should eventually delete) a returned temp file.

    Raises:
        RuntimeError: if AEZ_FILE_ID points to a Google-native document, or
            the folder contains no candidate file.
    """
    local_path = os.path.join(LOCAL_BASE_DIR, "AEZ", AEZ_FILE_NAME)
    if drive is None:
        return local_path if os.path.exists(local_path) else None

    def _resolve_shortcut(file_obj):
        # Follow a Drive shortcut to its target; other files pass through.
        if file_obj.get("mimeType") == "application/vnd.google-apps.shortcut":
            tgt = file_obj.get("shortcutDetails", {}).get("targetId")
            if tgt:
                g = drive.CreateFile({"id": tgt})
                g.FetchMetadata(fields="title,mimeType,shortcutDetails")
                return g
        return file_obj

    def _download_to_tmp(file_obj):
        # mkstemp + close avoids leaving an open handle on the temp file
        # while it is being written by the download.
        fd, tmp = tempfile.mkstemp(suffix=".tif")
        os.close(fd)
        try:
            file_obj.GetContentFile(tmp)
        except Exception:
            # Do not leave an empty/partial temp file behind on failure.
            try:
                os.remove(tmp)
            except OSError:
                pass
            raise
        return tmp

    folder_id = None
    if AEZ_FILE_ID:
        f = drive.CreateFile({"id": AEZ_FILE_ID})
        f.FetchMetadata(fields="title,mimeType,shortcutDetails")
        f = _resolve_shortcut(f)
        mime = f.get("mimeType")
        if mime == "application/vnd.google-apps.folder":
            folder_id = f["id"]
        elif mime and mime.startswith("application/vnd.google-apps."):
            raise RuntimeError(f"AEZ_FILE_ID points to a Google Doc ({mime}), not a TIFF.")
        else:
            return _download_to_tmp(f)

    folder_id = folder_id or os.environ.get("AEZ_FOLDER_ID", None)
    if folder_id:
        q = f"'{folder_id}' in parents and trashed=false and title='{AEZ_FILE_NAME}'"
        cand = _drive_query(drive, q)
        if not cand:
            # Fall back to any TIFF-looking title, preferring an exact match.
            q_any = f"'{folder_id}' in parents and trashed=false and title contains '.tif'"
            cand = _drive_query(drive, q_any)
            cand = [c for c in cand if c.get("title", "") == AEZ_FILE_NAME] or cand
        if not cand:
            raise RuntimeError(f"Could not find {AEZ_FILE_NAME} in folder id {folder_id}.")
        return _download_to_tmp(drive.CreateFile({"id": cand[0]["id"]}))

    return local_path if os.path.exists(local_path) else None

def _find_points_vector_drive(drive):
    """Locate and download the points file named POINTS_BASENAME from Drive.

    The "Model Training" folder (or the root when it is missing) is walked
    first, then the root. In each visited folder a shapefile set is preferred
    (all sidecar files are downloaded into one temp directory); otherwise the
    first item named <basename>.gpkg, .zip or .csv is downloaded. Names are
    compared case-insensitively. Returns a local path, or None.
    """
    try:
        model_root = get_subfolder(drive, ROOT_FOLDER_ID, MODEL_FOLDER)
    except Exception:
        model_root = None

    wanted = POINTS_BASENAME.lower()
    sidecar_re = re.compile(rf"(.+)\.(shp|dbf|shx|prj|cpg|qpj)$", re.IGNORECASE)

    for root in (model_root or ROOT_FOLDER_ID, ROOT_FOLDER_ID):
        for _folder_id, entries in _drive_walk(drive, root, max_depth=4):
            records = [entry for entry in entries if isinstance(entry, dict)]

            # 1) Shapefile: group the sidecar files that share the wanted stem.
            by_stem = {}
            for rec in records:
                hit = sidecar_re.match(rec.get("title", ""))
                if hit and hit.group(1).lower() == wanted:
                    by_stem.setdefault(hit.group(1), []).append(rec)

            for stem, parts in by_stem.items():
                has_shp = any(part.get("title", "").lower().endswith(".shp") for part in parts)
                if not has_shp:
                    continue
                workdir = tempfile.mkdtemp()
                for part in parts:
                    download_to_temp(part, os.path.join(workdir, part.get("title", "")))
                shp_path = os.path.join(workdir, f"{os.path.basename(stem)}.shp")
                if os.path.exists(shp_path):
                    print(f"Using vector points (shp): {os.path.basename(stem)}.shp")
                    return shp_path

            # 2) Single-file formats, taken in listing order.
            for rec in records:
                if rec.get("mimeType") == "application/vnd.google-apps.folder":
                    continue
                title = rec.get("title", "")
                low = title.lower()
                for ext in ("gpkg", "zip", "csv"):
                    if low == f"{wanted}.{ext}":
                        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=f".{ext}").name
                        download_to_temp(rec, tmp)
                        print(f"Using vector points ({ext}): {title}")
                        return tmp
    return None

def _find_points_vector_local():
    """Search LOCAL_MODEL_DIR, then LOCAL_BASE_DIR, for the points file.

    The first directory entry that is exactly <basename>.csv, or that
    contains the basename and ends in .gpkg, .zip or .shp, is returned as a
    full path (case-insensitive). Returns None when nothing matches.
    """
    base = POINTS_BASENAME.lower()
    for folder in (LOCAL_MODEL_DIR, LOCAL_BASE_DIR):
        if not os.path.isdir(folder):
            continue
        for entry in os.listdir(folder):
            low = entry.lower()
            exact_csv = low == f"{base}.csv"
            loose_vector = base in low and low.endswith((".gpkg", ".zip", ".shp"))
            if exact_csv or loose_vector:
                return os.path.join(folder, entry)
    return None

def load_points_raw(drive):
    """Load the points table as a pandas DataFrame.

    The file is located on Drive or locally depending on *drive*. CSV files
    are read directly; other formats are read with geopandas, reprojected to
    EPSG:4326 when they carry a different CRS, and stripped of the geometry
    column. If the longitude/latitude columns are missing, columns that match
    case-insensitively are renamed to LONCOL/LATCOL.

    Raises:
        RuntimeError: if no points file can be located.
    """
    import geopandas as gpd
    import fiona

    if drive is not None:
        path = _find_points_vector_drive(drive)
    else:
        path = _find_points_vector_local()
    if path is None:
        raise RuntimeError("Could not locate US.(shp/zip/gpkg/csv) in 'Model Training' tree.")

    if path.lower().endswith(".csv"):
        df = pd.read_csv(path)
    else:
        with fiona.Env(SHAPE_RESTORE_SHX="YES"):
            gdf = gpd.read_file(path)
        if gdf.crs is not None and gdf.crs.to_epsg() != 4326:
            gdf = gdf.to_crs(4326)
        df = pd.DataFrame(gdf.drop(columns="geometry", errors="ignore"))

    if LONCOL not in df.columns or LATCOL not in df.columns:
        # best-effort case-insensitive pickup
        def _match(columns, wanted):
            return next((col for col in columns if col.lower() == wanted.lower()), wanted)

        renames = {_match(df.columns, LONCOL): LONCOL, _match(df.columns, LATCOL): LATCOL}
        df.rename(columns=renames, inplace=True)
    return df

def _sample_raster_at_lonlat(src, lon_arr, lat_arr):
    xs = np.asarray(lon_arr, dtype=float); ys = np.asarray(lat_arr, dtype=float)
    if src.crs and (src.crs.to_epsg() != 4326):
        tx, ty = rio_transform(CRS.from_epsg(4326), src.crs, xs.tolist(), ys.tolist())
        xs = np.asarray(tx, dtype=float); ys = np.asarray(ty, dtype=float)
    out = np.full(xs.shape[0], np.nan, dtype=np.float32)
    xmin, ymin, xmax, ymax = src.bounds.left, src.bounds.bottom, src.bounds.right, src.bounds.top
    inside = (xs >= xmin) & (xs <= xmax) & (ys >= ymin) & (ys <= ymax)
    if not np.any(inside): return out
    idx = np.where(inside)[0]
    coords = list(zip(xs[idx], ys[idx]))
    vals = np.array([v[0] for v in src.sample(coords)], dtype=np.float32)
    if src.nodata is not None and np.isfinite(src.nodata):
        vals = np.where(np.isclose(vals, np.float32(src.nodata)), np.float32(np.nan), vals)
    out[idx] = vals; return out

# ---------- Per-state AEZ tiles ----------
def _pick_reference_raster_for_bounds(drive, state):
    """Return a local path to one raster of *state* to take bounds from.

    A ``.tif`` whose name contains "_CLE" or "NDVI" is preferred; otherwise
    any ``.tif`` is used. Local mode returns a path inside
    ``LOCAL_BY_COUNTRY_DIR/state``. Drive mode downloads the chosen file to a
    temp file, which the caller should delete when done. Returns None when
    the state folder or a usable raster is missing.
    """

    def _preferred(title):
        # "NDVI_mean" is covered by the plain "NDVI" check.
        return title.endswith(".tif") and ("_CLE" in title or "NDVI" in title)

    # use any state raster present (CLE or NDVI_mean etc.)
    if drive is None:
        pdir = os.path.join(LOCAL_BY_COUNTRY_DIR, state)
        if not os.path.isdir(pdir):
            return None
        names = os.listdir(pdir)  # list once so both passes see the same entries
        chosen = next((nm for nm in names if _preferred(nm)), None)
        if chosen is None:
            # fallback to any tif
            chosen = next((nm for nm in names if nm.endswith(".tif")), None)
        return os.path.join(pdir, chosen) if chosen is not None else None

    byc = get_subfolder(drive, ROOT_FOLDER_ID, BY_COUNTRY_NAME)
    if not byc:
        return None
    # Escape backslashes and quotes so the state name cannot break the query.
    safe_state = state.replace("\\", "\\\\").replace("'", "\\'")
    q = (
        f"'{byc}' in parents and trashed=false and title='{safe_state}' "
        f"and mimeType='application/vnd.google-apps.folder'"
    )
    res = _drive_query(drive, q)
    if not res:
        return None

    files = [it for it in list_files(drive, res[0]["id"]) if isinstance(it, dict)]
    pick = next((it for it in files if _preferred(it.get("title", ""))), None)
    if pick is None:
        pick = next((it for it in files if it.get("title", "").endswith(".tif")), None)
    if pick is None:
        return None

    # mkstemp + close: no handle is left open on the file being downloaded.
    fd, tmp = tempfile.mkstemp(suffix=".tif")
    os.close(fd)
    return download_to_temp(pick, tmp)

def _target_wgs84_grid_from_bounds(west, south, east, north, pixel_m=30.0):
    """Build an EPSG:4326 grid covering the bounds at about *pixel_m* metres.

    The degree step is derived from fixed metres-per-degree constants, with
    the longitude step scaled by the cosine of the mid latitude. Returns
    ``(height, width, transform, crs)`` with a north-up transform anchored at
    the (west, north) corner.
    """
    m_per_deg_lat = 111_132.0
    centre_lat = 0.5 * (south + north)
    m_per_deg_lon = 111_320.0 * math.cos(math.radians(centre_lat))

    step_x = pixel_m / m_per_deg_lon
    step_y = pixel_m / m_per_deg_lat

    n_cols = int(math.ceil((east - west) / step_x))
    n_rows = int(math.ceil((north - south) / step_y))

    grid = Affine(step_x, 0, west, 0, -step_y, north)
    return n_rows, n_cols, grid, CRS.from_epsg(4326)

def _clip_bounds_from_reference(raster_path):
    """Return the bounds of *raster_path* transformed to EPSG:4326 as
    (west, south, east, north)."""
    wgs84 = CRS.from_epsg(4326)
    with rasterio.open(raster_path) as src:
        box = src.bounds
        return rio_transform_bounds(
            src.crs, wgs84,
            box.left, box.bottom, box.right, box.top,
            densify_pts=8,
        )

def write_state_aez_tif(drive, state):
    """Create ``<state>_AEZ.tif`` covering the state's reference raster.

    The AEZ raster is reprojected (nearest neighbour) onto an EPSG:4326 grid
    of roughly 30 m pixels spanning the bounds of one existing raster of
    *state*, written as float32 with nodata -9999, and stored in
    ``By Country/<state>/`` (locally or on Drive). A state without a
    reference raster is skipped with a printed notice.

    Raises:
        RuntimeError: if the AEZ raster cannot be found.
    """
    aez_src_path = open_aez_path(drive)
    if not aez_src_path:
        raise RuntimeError("AEZ raster not found.")
    ref = _pick_reference_raster_for_bounds(drive, state)
    if not ref:
        print(f"   ! skip {state}: no reference raster to define grid")
        return
    try:
        w, s, e, n = _clip_bounds_from_reference(ref)
    finally:
        # In Drive mode the reference raster is a temp download that is only
        # needed for its bounds; in local mode it is a real data file and
        # must be kept.
        if drive is not None:
            try:
                os.remove(ref)
            except OSError:
                pass
    H, W, transform, crs = _target_wgs84_grid_from_bounds(w, s, e, n, pixel_m=30.0)

    base_profile = {
        "driver": "GTiff", "height": H, "width": W, "count": 1, "crs": crs, "transform": transform,
        "compress": "LZW", "tiled": True, "blockxsize": 512, "blockysize": 512,
        "dtype": "float32", "nodata": -9999.0
    }
    with tempfile.TemporaryDirectory() as tdir:
        out_tmp = os.path.join(tdir, f"{state}_AEZ.tif")
        with rasterio.open(aez_src_path) as src, rasterio.open(out_tmp, "w", **base_profile) as dst:
            dst_arr = np.full((H, W), -9999.0, dtype=np.float32)
            reproject(source=rasterio.band(src, 1), destination=dst_arr,
                      src_transform=src.transform, src_crs=src.crs,
                      dst_transform=transform, dst_crs=crs,
                      dst_nodata=-9999.0, resampling=RioResampling.nearest)
            dst.write(dst_arr, 1)

        # upload/write to By Country/<State>/<State>_AEZ.tif
        if drive is None:
            out_dir = os.path.join(LOCAL_BY_COUNTRY_DIR, state)
            os.makedirs(out_dir, exist_ok=True)
        else:
            # get_or_create_folder looks the folder up before creating it.
            byc = get_or_create_folder(drive, ROOT_FOLDER_ID, BY_COUNTRY_NAME)
            out_dir = get_or_create_folder(drive, byc, state)
        # Must happen inside the `with`: out_tmp disappears with tdir.
        upload_path(drive, out_tmp, out_dir, f"{state}_AEZ.tif")
    print(f"   ‚Ä¢ wrote {state}_AEZ.tif")

def list_states(drive):
    """Return the sorted, de-duplicated names of the state folders found in
    "By Country" (local sub-directories, or Drive sub-folders)."""
    if drive is None:
        if not os.path.isdir(LOCAL_BY_COUNTRY_DIR):
            return []
        local_names = {
            entry
            for entry in os.listdir(LOCAL_BY_COUNTRY_DIR)
            if os.path.isdir(os.path.join(LOCAL_BY_COUNTRY_DIR, entry))
        }
        return sorted(local_names)

    byc = get_subfolder(drive, ROOT_FOLDER_ID, BY_COUNTRY_NAME)
    if not byc:
        return []
    folders = _drive_query(
        drive,
        f"'{byc}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder'",
    )
    return sorted({folder["title"] for folder in folders})

# ---------- Points cache ----------
def augment_points_with_aez(drive):
    """Sample the AEZ raster at every point and store the result as a cache.

    The points table gains an AEZ column (rounded to the nearest integer,
    stored as float64) and two helper columns with longitude/latitude rounded
    to 6 decimals. The coordinate, helper and AEZ columns are written as
    US_with_AEZ.parquet and US_with_AEZ.csv into the "Model Training" folder.

    Raises:
        RuntimeError: if the AEZ raster cannot be found.
    """
    df = load_points_raw(drive)
    already_tagged = AEZ_COL in df.columns and df[AEZ_COL].notna().any()
    if already_tagged:
        print("   ‚Ä¢ points already have AEZ; writing cache anyway")

    aez_path = open_aez_path(drive)
    if not aez_path:
        raise RuntimeError("AEZ raster not found.")
    with rasterio.open(aez_path) as src:
        sampled = _sample_raster_at_lonlat(src, df[LONCOL].values, df[LATCOL].values)
    df[AEZ_COL] = np.rint(pd.to_numeric(sampled, errors="coerce")).astype("float64")

    # stable rounding keys for merge
    df["_lonr"] = np.round(df[LONCOL].astype(float), 6)
    df["_latr"] = np.round(df[LATCOL].astype(float), 6)

    # write cache
    out_dir = (
        LOCAL_MODEL_DIR
        if drive is None
        else get_or_create_folder(drive, ROOT_FOLDER_ID, MODEL_FOLDER)
    )
    export_cols = [LONCOL, LATCOL, "_lonr", "_latr", AEZ_COL]
    with tempfile.TemporaryDirectory() as tdir:
        parquet_tmp = os.path.join(tdir, "US_with_AEZ.parquet")
        csv_tmp = os.path.join(tdir, "US_with_AEZ.csv")
        df[export_cols].to_parquet(parquet_tmp, index=False)
        df[export_cols].to_csv(csv_tmp, index=False)
        for tmp_file, remote_name in (
            (parquet_tmp, "US_with_AEZ.parquet"),
            (csv_tmp, "US_with_AEZ.csv"),
        ):
            upload_path(drive, tmp_file, out_dir, remote_name)
    print("   ‚Ä¢ wrote Model Training/US_with_AEZ.parquet (+ .csv)")

# ---------- Runner ----------
def run(use_drive=True, states=None):
    """Build the per-state AEZ tiles, then the points AEZ cache.

    Args:
        use_drive: when True, the session-level ``drive`` client is used;
            when False, everything runs against the local directory layout.
        states: state names to process; when empty, all states found in
            "By Country" are used.

    A failure for one state is printed and does not stop the others.

    Raises:
        RuntimeError: if *use_drive* is True and no ``drive`` object exists.
    """
    if not use_drive:
        _drive = None
        os.makedirs(LOCAL_BY_COUNTRY_DIR, exist_ok=True)
        os.makedirs(LOCAL_MODEL_DIR, exist_ok=True)
    else:
        try:
            _drive = drive  # noqa
        except NameError:
            raise RuntimeError("PyDrive 'drive' not found. Authenticate and expose 'drive', or set use_drive=False.")
        # Best effort: lengthen the HTTP timeout if the client exposes it.
        try:
            _drive.auth.service.http.timeout = 120
        except Exception:
            pass

    # A) per-state AEZ tiles
    if not states:
        states = list_states(_drive)
    print("=== Build per-state AEZ tiles ===")
    for state_name in states:
        try:
            write_state_aez_tif(_drive, state_name)
        except Exception as err:
            print(f"   ! {state_name}: {err}")

    # B) points AEZ cache
    print("=== Build points AEZ cache ===")
    augment_points_with_aez(_drive)
    print("‚úÖ done.")

# Entry point: process every state found in "By Country" using Drive.
if __name__ == "__main__":
    run(use_drive=True, states=None)
