In [None]:
from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive

import os

from pydrive2.auth import RefreshError

# Local path to the OAuth client-secret JSON for the Google Drive API.
CLIENT_JSON = "client_secret_Final.json"
# Where the user's OAuth credentials are cached between runs.
TOKEN_JSON = "token.json"

gauth = GoogleAuth()
gauth.LoadClientConfigFile(CLIENT_JSON)

# Reuse cached credentials when present, so the interactive command-line login
# is only needed on the first run or when the cached token cannot be refreshed.
if os.path.exists(TOKEN_JSON):
    gauth.LoadCredentialsFile(TOKEN_JSON)

if gauth.credentials is None:
    gauth.CommandLineAuth()
elif gauth.access_token_expired:
    try:
        gauth.Refresh()
    except RefreshError:
        # Cached token cannot be refreshed: drop it and log in interactively.
        gauth.credentials = None
        gauth.CommandLineAuth()
else:
    gauth.Authorize()

# Save (possibly refreshed) credentials for the next run.
gauth.SaveCredentialsFile(TOKEN_JSON)

drive = GoogleDrive(gauth)
print("✅ Google Drive connected")

✅ Google Drive connected


In [None]:
"""
Plot China_binary.tif with:
  - Top panel: country-wide map (gadm41_CHN_0 ADM0 boundary + irrigated pixels in dark blue).
  - Bottom panel: zoomed inset showing China_binary.tif cropped to the rectangular
    bounding box of the China_small_area_shapefile polygons (outlines drawn in black).

Files (Google Drive layout):
  CountryModelPredicted_Cropland /
    Probability|Probablity|Prob /
      Merged /
        China_binary.tif
        China_small_area_shapefile.(shp,shx,dbf,prj,cpg)
        Comparisons /
          gadm41_CHN_0.(shp,shx,dbf,prj,cpg)
          --> output PNG will be saved here as:
              US_binary_with_zoom_map.png  (name is hard-coded in the script)

Requirements:
  - PyDrive2-authenticated global `drive`
  - rasterio, numpy, matplotlib, fiona, shapely, pyproj
"""

import os
import re
import tempfile
import warnings

import numpy as np
import rasterio
from rasterio.enums import Resampling
from rasterio.mask import mask

import fiona
from shapely.geometry import shape
from shapely.ops import transform as shp_transform
from shapely.ops import unary_union
from pyproj import Transformer
import os
import re
import tempfile
import warnings

import numpy as np
import rasterio
from rasterio.enums import Resampling
from rasterio.mask import mask  # kept, though not used now
import rasterio.windows as rwindows

import fiona
from shapely.geometry import shape
from shapely.ops import transform as shp_transform
from shapely.ops import unary_union
from pyproj import Transformer

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Suppress every RuntimeWarning for the rest of the session (process-wide filter).
warnings.filterwarnings("ignore", category=RuntimeWarning)

# ===================== CONFIG =====================
# Drive id of the top-level folder; overridable through the ROOT_FOLDER_ID env var.
ROOT_FOLDER_ID      = os.environ.get("ROOT_FOLDER_ID", "19Ee1kUar4oEabBObq7GcA-jDR7F29x2_")
# Folder chain searched under ROOT:
#   PARENT_FOLDER_NAME / <first match of PROB_NAMES> / BINARY_SUBFOLDER / COMPARE_SUBFOLDER
PARENT_FOLDER_NAME  = "CountryModelPredicted_Cropland"
PROB_NAMES          = ("Probability", "Probablity", "Prob")  # tried in order; "Probablity" matches a misspelt folder
BINARY_SUBFOLDER    = "Merged"
COMPARE_SUBFOLDER   = "Comparisons"

# Input names, compared case-insensitively against Drive titles.
CHINA_BINARY_NAME   = "China_binary.tif"            # raster in BINARY_SUBFOLDER
SMALL_BASE          = "China_small_area_shapefile"  # shapefile basename in BINARY_SUBFOLDER (inset area)
CHN_ADM0_BASE       = "gadm41_CHN_0"                # shapefile basename in COMPARE_SUBFOLDER (country outline)


PLOT_DPI            = 160  # resolution used for the figure and for savefig
# ---------------------- Drive helpers ---------------------
def _dq(drive, q):
    return drive.ListFile({
        "q": q,
        "supportsAllDrives": True,
        "includeItemsFromAllDrives": True,
        "maxResults": 1000
    }).GetList()

def list_files(drive, parent_id):
    """Return every non-trashed item (files and folders) directly inside *parent_id*."""
    query = f"'{parent_id}' in parents and trashed=false"
    return _dq(drive, query)

def child_folders(drive, parent_id):
    """Return the non-trashed sub-folders directly inside *parent_id*."""
    folder_mime = "application/vnd.google-apps.folder"
    query = f"'{parent_id}' in parents and trashed=false and mimeType='{folder_mime}'"
    return _dq(drive, query)

def get_subfolder_exact(drive, parent_id, name):
    """
    Return the id of the sub-folder of *parent_id* titled exactly *name*, or None.

    Backslashes and single quotes in *name* are escaped so that a title such as
    "Farmer's data" cannot break the Drive query literal. If several folders
    share the title, the first one returned by Drive is used.
    """
    safe_name = name.replace("\\", "\\\\").replace("'", "\\'")
    res = _dq(
        drive,
        f"'{parent_id}' in parents and trashed=false "
        f"and mimeType='application/vnd.google-apps.folder' and title='{safe_name}'"
    )
    return res[0]["id"] if res else None

def get_subfolder_fuzzy(drive, parent_id, desired):
    """
    Find a sub-folder of *parent_id* matching *desired*, tolerating naming drift.

    Preference order:
      1. an exact title match (get_subfolder_exact);
      2. the first child folder whose lower-cased, whitespace-free title contains
         the lower-cased, whitespace-free *desired*, or, when *desired* contains
         "prob", any child folder whose title also contains "prob" (covers the
         'Probability' / 'Probablity' / 'Prob' spellings).
    Returns the folder id, or None when nothing matches.
    """
    exact = get_subfolder_exact(drive, parent_id, desired)
    if exact:
        return exact

    def _squash(text):
        return re.sub(r"\s+", "", text.lower())

    wanted = _squash(desired)
    for folder in child_folders(drive, parent_id):
        title = _squash(folder.get("title") or folder.get("name") or "")
        if wanted in title or ("prob" in wanted and "prob" in title):
            return folder["id"]
    return None

def download_to_temp(drive_file, local_path):
    """Download the content of PyDrive2 file *drive_file* to *local_path* (returns None)."""
    fetch = drive_file.GetContentFile
    fetch(local_path)

def upload_path(drive, local_path, parent_id, title=None):
    """
    Upload *local_path* into Drive folder *parent_id* and return the new file id.

    The Drive title defaults to the file's basename when *title* is falsy.
    """
    metadata = {
        "title": title if title else os.path.basename(local_path),
        "parents": [{"id": parent_id}],
    }
    remote = drive.CreateFile(metadata)
    remote.SetContentFile(local_path)
    remote.Upload()
    return remote["id"]

def download_shapefile_bundle(drive, folder_id, base):
    """
    Fetch the pieces of shapefile *base* from Drive folder *folder_id*.

    Looks for titles equal (case-insensitively) to *base* + .shp/.shx/.dbf/.prj/.cpg
    and downloads each one found into a fresh temp dir as exactly ``base + ext``,
    so all pieces share one basename. Only the .shp is mandatory here.

    Returns the local path of the .shp. The temp dir is not removed here.
    Raises FileNotFoundError when ``base.shp`` is not in the folder.
    """
    ext_by_title = {base.lower() + ext: ext for ext in (".shp", ".shx", ".dbf", ".prj", ".cpg")}
    found = {}
    for item in list_files(drive, folder_id):
        ext = ext_by_title.get((item.get("title") or "").lower())
        if ext is not None:
            found[ext] = item

    if ".shp" not in found:
        raise FileNotFoundError(f"Could not find {base}.shp in folder.")

    tmpdir = tempfile.mkdtemp(prefix=base + "_")
    for ext, item in found.items():
        download_to_temp(item, os.path.join(tmpdir, base + ext))
    return os.path.join(tmpdir, base + ".shp")


# --------------------- Main plotting logic --------------------
def _read_geoms_in_crs(shp_path, target_crs):
    """
    Read every feature of *shp_path* as a shapely geometry expressed in *target_crs*.

    Geometries are reprojected only when both the shapefile CRS and *target_crs*
    are known and differ; otherwise they are returned as stored. Features with a
    null geometry are skipped, and Z values are dropped during reprojection.
    """
    geoms = []
    with fiona.open(shp_path) as src:
        shp_crs = src.crs
        project = None
        if shp_crs and target_crs and shp_crs != target_crs:
            transformer = Transformer.from_crs(shp_crs, target_crs, always_xy=True)

            def project(x, y, z=None):
                return transformer.transform(x, y)

        for feat in src:
            raw = feat["geometry"]
            if raw is None:
                continue
            g = shape(raw)
            if project is not None:
                g = shp_transform(project, g)
            geoms.append(g)
    return geoms


def _plot_exteriors(ax, geoms, color="black", linewidth=0.7):
    """Draw the exterior ring of every Polygon / MultiPolygon part in *geoms* on *ax*.

    Holes are intentionally not drawn; empty or non-polygon parts are skipped.
    """
    for g in geoms:
        parts = g.geoms if g.geom_type == "MultiPolygon" else [g]
        for poly in parts:
            if poly.geom_type != "Polygon" or poly.is_empty:
                continue
            x, y = poly.exterior.xy
            ax.plot(x, y, color=color, linewidth=linewidth)


def _style_map_axes(ax, title, fontsize):
    """Hide the frame, ticks and axis labels of *ax*, keeping only *title*."""
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_title(title, fontsize=fontsize)


def plot_china_binary_with_zoom(drive):
    """
    Render CHINA_BINARY_NAME as a country map plus a zoomed inset and upload the PNG.

    Drive layout: ROOT_FOLDER_ID / PARENT_FOLDER_NAME / <first of PROB_NAMES> /
    BINARY_SUBFOLDER (raster + SMALL_BASE shapefile) / COMPARE_SUBFOLDER
    (CHN_ADM0_BASE shapefile; PNG destination).

    Top panel: raster read with nearest resampling at ~10M pixels max, plus the
    ADM0 outline. Bottom panel: full-resolution raster cropped to the bounding
    box of the SMALL_BASE polygons, plus their outlines. Value 0 is grey, 1 is
    dark blue and masked (nodata) pixels are transparent.

    Raises:
        RuntimeError: a required Drive folder is missing.
        FileNotFoundError: the raster or a shapefile is missing.
        ValueError: the small-area shapefile is empty or does not overlap the raster.
    """
    import shutil  # only needed to remove the shapefile temp dirs

    # 1) Navigate to folders
    cmp_id = get_subfolder_fuzzy(drive, ROOT_FOLDER_ID, PARENT_FOLDER_NAME)
    if not cmp_id:
        raise RuntimeError(f"Folder '{PARENT_FOLDER_NAME}' not found under ROOT.")

    prob_id = None
    for nm in PROB_NAMES:
        prob_id = get_subfolder_fuzzy(drive, cmp_id, nm)
        if prob_id:
            break
    if not prob_id:
        raise RuntimeError("Probability/Probablity/Prob folder not found.")

    binary_id = get_subfolder_fuzzy(drive, prob_id, BINARY_SUBFOLDER)
    if not binary_id:
        raise RuntimeError(f"Binary folder '{BINARY_SUBFOLDER}' not found.")

    comp_id = get_subfolder_fuzzy(drive, binary_id, COMPARE_SUBFOLDER)
    if not comp_id:
        raise RuntimeError(f"Comparisons folder '{COMPARE_SUBFOLDER}' not found inside {BINARY_SUBFOLDER}/.")

    # 2) Make sure the raster and the small-area shapefile exist before downloading
    china_file = None
    small_shp_present = False
    for it in list_files(drive, binary_id):
        title = (it.get("title") or "").strip().lower()
        if title == CHINA_BINARY_NAME.lower():
            china_file = it
        if title.startswith(SMALL_BASE.lower()) and title.endswith(".shp"):
            small_shp_present = True
    if not china_file:
        raise FileNotFoundError(f"{CHINA_BINARY_NAME} not found in Merged/.")
    if not small_shp_present:
        raise FileNotFoundError(f"{SMALL_BASE}.shp not found in Merged/.")

    china_tmp = None
    out_png = None
    shp_dirs = []  # temp dirs created by download_shapefile_bundle
    fig = None
    try:
        # 3) Download inputs (inside try so partial downloads are cleaned up)
        fd, china_tmp = tempfile.mkstemp(suffix=".tif")
        os.close(fd)
        download_to_temp(china_file, china_tmp)

        shp_small_path = download_shapefile_bundle(drive, binary_id, SMALL_BASE)
        shp_dirs.append(os.path.dirname(shp_small_path))
        shp_chn_path = download_shapefile_bundle(drive, comp_id, CHN_ADM0_BASE)
        shp_dirs.append(os.path.dirname(shp_chn_path))

        # === Main map: read downsampled to at most ~10M pixels ===
        with rasterio.open(china_tmp) as src:
            H, W = src.height, src.width
            bounds = src.bounds
            raster_crs = src.crs

            max_pixels = 10_000_000
            scale_factor = max(1, int(np.sqrt((H * W) / max_pixels)))
            china_arr = src.read(
                1,
                out_shape=(max(1, H // scale_factor), max(1, W // scale_factor)),
                resampling=Resampling.nearest,
                masked=True,
            )
            extent_main = (bounds.left, bounds.right, bounds.bottom, bounds.top)

        # === Vector layers in the raster CRS ===
        chn_geoms = _read_geoms_in_crs(shp_chn_path, raster_crs)
        small_geoms = _read_geoms_in_crs(shp_small_path, raster_crs)
        small_union = unary_union(small_geoms)
        if small_union.is_empty:
            raise ValueError(f"{SMALL_BASE}.shp contains no geometry.")

        # === Inset: full-resolution crop to the small area's bounding box ===
        with rasterio.open(china_tmp) as src:
            full = rwindows.Window(0, 0, src.width, src.height)
            try:
                # Clip to the raster: reading outside it without boundless=True fails.
                win = rwindows.from_bounds(*small_union.bounds, transform=src.transform).intersection(full)
            except rasterio.errors.WindowError as err:
                raise ValueError(f"{SMALL_BASE}.shp does not overlap {CHINA_BINARY_NAME}.") from err
            inset_arr = src.read(1, window=win, masked=True)
            inset_transform = src.window_transform(win)

        # array_bounds returns (west, south, east, north); imshow wants (left, right, bottom, top).
        west, south, east, north = rasterio.transform.array_bounds(
            inset_arr.shape[0], inset_arr.shape[1], inset_transform
        )
        inset_extent = (west, east, south, north)

        # === Colormap: 0 = grey (non-irrigated), 1 = dark blue (irrigated), masked = transparent ===
        cmap = ListedColormap([
            (0.9, 0.9, 0.9, 1.0),
            (0.0, 0.0, 0.5, 1.0),
        ])
        cmap.set_bad((0, 0, 0, 0))

        fig, (ax_main, ax_inset) = plt.subplots(
            nrows=2, ncols=1,
            figsize=(7.5, 9.0),
            dpi=PLOT_DPI,
            gridspec_kw={"height_ratios": [3, 2]},
        )

        # ----- Top panel: country-wide map + ADM0 outline -----
        ax_main.imshow(china_arr, origin="upper", cmap=cmap, vmin=0, vmax=1, extent=extent_main)
        _plot_exteriors(ax_main, chn_geoms)
        _style_map_axes(ax_main, f"Irrigated area (dark blue) from {CHINA_BINARY_NAME}", 13)

        # ----- Bottom panel: zoomed inset + small-area outlines -----
        ax_inset.imshow(inset_arr, origin="upper", cmap=cmap, vmin=0, vmax=1, extent=inset_extent)
        _plot_exteriors(ax_inset, small_geoms)
        _style_map_axes(ax_inset, f"Zoomed irrigated area ({SMALL_BASE}.shp, dark blue)", 12)
        ax_inset.set_aspect("equal")

        fig.tight_layout()

        # ---- Save & upload PNG ----
        fd, out_png = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        fig.savefig(out_png, bbox_inches="tight", dpi=PLOT_DPI)

        # NOTE(review): output name kept as-is for downstream users, although the
        # plotted raster is CHINA_BINARY_NAME.
        out_title = "US_binary_with_zoom_map.png"
        upload_path(drive, out_png, comp_id, title=out_title)

        print(f"✅ Saved country-wide + zoom map PNG as '{out_title}' in Merged/Comparisons/")

    finally:
        if fig is not None:
            plt.close(fig)
        for path in (china_tmp, out_png):
            if path:
                try:
                    os.remove(path)
                except OSError:
                    pass
        for tmpdir in shp_dirs:
            shutil.rmtree(tmpdir, ignore_errors=True)


# ---------------------- CLI ----------------------
if __name__ == "__main__":
    # Fail fast with a clear message when the authentication cell has not been run.
    if "drive" not in globals():
        raise RuntimeError(
            "PyDrive2 'drive' not found. Authenticate and expose a global `drive` before running."
        )
    plot_china_binary_with_zoom(drive)  # noqa: F821

✅ Saved country-wide + zoom map PNG as 'US_binary_with_zoom_map.png' in Merged/Comparisons/


In [None]:
"""
Nationwide comparison: US_binary.tif (GMIA-NEXT) vs LANID

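- Reprojects/resamples LANID onto the GMIA-NEXT national binary grid so that both rasters share the same dimensions and projection.
- Restricts the comparison to pixels where US_binary has valid data.
- Treats any value > 0 as irrigated and computes (percentages are relative to the
  union of irrigated pixels, so % Agreement equals 100 * IoU):
    * IoU (intersection over union)
    * % Agreement (both irrigated)
    * % GMIA-NEXT only (US irrigated, LANID not irrigated or nodata)
    * % LANID only (LANID irrigated, US not irrigated)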
- Saves:
    * CSV summary
    * Bar chart PNG
    * Classified GeoTIFF (0=nodata, 1=Agreement, 2=GMIA-only, 3=LANID-only)
      in the GMIA-NEXT grid / projection.

Uploads outputs to:
    CountryModelPredicted_Cropland / (Probability|Probablity|Prob) /
      <BINARY_SUBFOLDER> / Comparisons/

Requirements:
  - PyDrive2-authenticated global `drive`
  - rasterio, numpy, matplotlib
"""

import os, re, csv, tempfile, warnings
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from rasterio.enums import Resampling
from rasterio.vrt import WarpedVRT

# Suppress every RuntimeWarning for the rest of the session (process-wide filter).
warnings.filterwarnings("ignore", category=RuntimeWarning)

# ------------------------- CONFIG -------------------------
# Drive id of the top-level folder; overridable through the ROOT_FOLDER_ID env var.
ROOT_FOLDER_ID      = os.environ.get("ROOT_FOLDER_ID", "18pQKnMMnLramhHRZSNwUJrLqG5DXNMmS")
PARENT_FOLDER_NAME  = "CountryModelPredicted"
PROB_NAMES          = ("Probability", "Probablity", "Prob")  # tried in order; includes a misspelt variant

# Folder inside Probability/... that holds both the nationwide US binary and the LANID raster
# (e.g. "Binary_MAJ").
BINARY_SUBFOLDER    = "Merged"

# Name of the nationwide GMIA-NEXT raster inside BINARY_SUBFOLDER (compared case-insensitively).
NATIONAL_NAME       = "US_Binary.tif"  # country-wide mosaicked binary map

# Substring (case-insensitive) identifying the LANID .tif/.tiff that GMIA-NEXT is
# validated against; overridable through the LANID_HINT env var.
LANID_HINT          = os.environ.get("LANID_HINT", "2014_2015")
COMPARE_SUBFOLDER   = "Comparisons"  # created under BINARY_SUBFOLDER if missing
PLOT_DPI            = 160

# Bar-chart colours; the diff GeoTIFF colormap in compare_us_to_lanid hard-codes the same RGB values.
COLOR_AGREE = "#2ca02c"   # green (Agreement)
COLOR_GMIA  = "#E69F00"   # orange  (GMIA-NEXT only)
COLOR_LANID = "#8B0000"   # dark red (LANID only)
# ----------------------------------------------------------

# ---------------------- Drive helpers ---------------------
def _dq(drive, q):
    return drive.ListFile({
        "q": q,
        "supportsAllDrives": True,
        "includeItemsFromAllDrives": True,
        "maxResults": 1000
    }).GetList()

def list_files(drive, parent_id):
    """List all non-trashed children (files and folders) of Drive folder *parent_id*."""
    query = f"'{parent_id}' in parents and trashed=false"
    return _dq(drive, query)

def child_folders(drive, parent_id):
    """List the non-trashed sub-folders of Drive folder *parent_id*."""
    mime = "application/vnd.google-apps.folder"
    return _dq(drive, f"'{parent_id}' in parents and trashed=false and mimeType='{mime}'")

def get_subfolder_exact(drive, parent_id, name):
    """
    Return the id of the sub-folder of *parent_id* titled exactly *name*, or None.

    Backslashes and single quotes in *name* are escaped so that the title cannot
    break the Drive query literal. With duplicates, the first hit wins.
    """
    safe_name = name.replace("\\", "\\\\").replace("'", "\\'")
    res = _dq(drive, f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder' and title='{safe_name}'")
    return res[0]["id"] if res else None

def get_subfolder_fuzzy(drive, parent_id, desired):
    """
    Locate a sub-folder of *parent_id* by exact title, falling back to a loose match.

    Fallback: the first child folder whose lower-cased, whitespace-stripped title
    contains the equally normalised *desired*; when *desired* mentions "prob",
    any child folder mentioning "prob" also qualifies. Returns the id or None.
    """
    hit = get_subfolder_exact(drive, parent_id, desired)
    if hit:
        return hit

    key = re.sub(r"\s+", "", desired.lower())
    key_is_prob = "prob" in key
    for candidate in child_folders(drive, parent_id):
        raw_title = candidate.get("title") or candidate.get("name") or ""
        norm = re.sub(r"\s+", "", raw_title.lower())
        if key in norm:
            return candidate["id"]
        if key_is_prob and "prob" in norm:
            return candidate["id"]
    return None

def get_or_create_folder(drive, parent_id, name):
    """
    Return the id of sub-folder *name* under *parent_id*, creating it when absent.

    The lookup escapes backslashes and single quotes in *name* so the Drive
    query stays well-formed; the folder is created with the unescaped *name*.
    """
    safe_name = name.replace("\\", "\\\\").replace("'", "\\'")
    res = _dq(drive, f"'{parent_id}' in parents and trashed=false and mimeType='application/vnd.google-apps.folder' and title='{safe_name}'")
    if res:
        return res[0]["id"]
    nf = drive.CreateFile({"title": name, "parents": [{"id": parent_id}], "mimeType": "application/vnd.google-apps.folder"})
    nf.Upload()
    return nf["id"]

def download_to_temp(drive_file, suffix):
    """
    Download PyDrive2 file *drive_file* to a new temp file ending in *suffix*.

    Returns the local path; the caller is responsible for deleting it. If the
    download fails, the temp file is removed before the error propagates.
    """
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)  # GetContentFile reopens the path itself; do not keep a handle open
    try:
        drive_file.GetContentFile(path)
    except BaseException:
        try:
            os.remove(path)
        except OSError:
            pass
        raise
    return path

def upload_path(drive, local_path, parent_id, title=None):
    """Upload *local_path* into folder *parent_id*; return the Drive id of the new file.

    Falls back to the file's basename as title when *title* is falsy.
    """
    name = title or os.path.basename(local_path)
    handle = drive.CreateFile({"title": name, "parents": [{"id": parent_id}]})
    handle.SetContentFile(local_path)
    handle.Upload()
    return handle["id"]

# ---------------------- Raster helpers --------------------
def find_lanid(drive, folder_id, hint="lanid"):
    """
    Pick a GeoTIFF in *folder_id* whose title contains *hint* (case-insensitive).

    Folders and files not ending in .tif/.tiff are ignored. Returns
    ``(file_resource, file_id)`` for the first match in Drive's listing order,
    or ``(None, None)`` when nothing matches.
    """
    for item in list_files(drive, folder_id):
        if item.get("mimeType") == "application/vnd.google-apps.folder":
            continue
        title = (item.get("title", "") or "").lower()
        if not title.endswith((".tif", ".tiff")):
            continue
        if hint.lower() in title:
            return item, item["id"]
    return None, None

def find_us_binary(drive, folder_id, national_name=NATIONAL_NAME):
    """
    Return the non-folder item in *folder_id* titled *national_name*, or None.

    Titles are compared after stripping surrounding whitespace, ignoring case.
    """
    non_folders = (
        it for it in list_files(drive, folder_id)
        if it.get("mimeType") != "application/vnd.google-apps.folder"
    )
    return next(
        (
            it for it in non_folders
            if (it.get("title", "") or "").strip().lower() == national_name.lower()
        ),
        None,
    )

def warp_to_ref(src_path, ref_ds, resampling=Resampling.nearest):
    """
    Warp band 1 of *src_path* onto the grid/CRS of open dataset *ref_ds*.

    Returns a float32 array shaped like *ref_ds* with NaN wherever the warped
    read is masked. The source's own nodata value is honoured (WarpedVRT's
    default), rather than forcing NaN. A forced NaN nodata is outside the
    range of integer dtypes and would discard the source's real nodata
    value (e.g. 0 or 255), letting it count as data.
    """
    with rasterio.open(src_path) as src, WarpedVRT(
        src,
        crs=ref_ds.crs,
        transform=ref_ds.transform,
        height=ref_ds.height,
        width=ref_ds.width,
        resampling=resampling,
    ) as vrt:
        arr = vrt.read(1, masked=True)
    # Cast before filling so NaN is representable even for integer sources.
    return np.ma.asarray(arr).astype("float32").filled(np.nan)

# ---------------------- Nationwide comparison -------------
def compare_us_to_lanid(drive):
    """
    Compare the national GMIA-NEXT binary (NATIONAL_NAME) with a LANID raster.

    Steps:
      1. Locate ROOT_FOLDER_ID / PARENT_FOLDER_NAME / <first of PROB_NAMES> /
         BINARY_SUBFOLDER and get-or-create its COMPARE_SUBFOLDER.
      2. Download NATIONAL_NAME and the first .tif/.tiff in BINARY_SUBFOLDER
         whose title contains LANID_HINT.
      3. Warp LANID onto the NATIONAL_NAME grid (nearest neighbour). Inside the
         area where the national raster is valid, classify pixels (> 0 means
         irrigated) as agreement, GMIA-NEXT only, or LANID only.
      4. Upload to COMPARE_SUBFOLDER:
         - a class GeoTIFF (0 background, 1 agree, 2 GMIA-only, 3 LANID-only);
         - a one-row CSV with IoU and the class shares of the union;
         - a bar chart PNG.

    Raises:
        RuntimeError: a required folder is missing, or there are no irrigated
            pixels in either raster inside the national raster's valid area.
        FileNotFoundError: one of the two rasters is missing.
    """
    # ---- Locate folders ----
    cmp_id = get_subfolder_fuzzy(drive, ROOT_FOLDER_ID, PARENT_FOLDER_NAME)
    if not cmp_id:
        raise RuntimeError(f"Folder '{PARENT_FOLDER_NAME}' not found under ROOT.")

    prob_id = None
    for name in PROB_NAMES:
        prob_id = get_subfolder_fuzzy(drive, cmp_id, name)
        if prob_id:
            break
    if not prob_id:
        raise RuntimeError("Probability/Probablity/Prob folder not found.")

    binary_id = get_subfolder_fuzzy(drive, prob_id, BINARY_SUBFOLDER)
    if not binary_id:
        raise RuntimeError(f"Binary folder '{BINARY_SUBFOLDER}' not found.")

    out_id = get_or_create_folder(drive, binary_id, COMPARE_SUBFOLDER)

    # ---- Locate rasters ----
    us_file = find_us_binary(drive, binary_id, NATIONAL_NAME)
    if not us_file:
        raise FileNotFoundError(f"National GMIA-NEXT raster '{NATIONAL_NAME}' not found in {BINARY_SUBFOLDER}/.")

    lanid_file, lanid_id = find_lanid(drive, binary_id, hint=LANID_HINT)
    if not lanid_id:
        raise FileNotFoundError(
            f"LANID raster not found in {BINARY_SUBFOLDER}/ "
            f"(expected a .tif/.tiff whose filename contains '{LANID_HINT}')."
        )

    stem = re.sub(r"\.tif(f)?$", "", NATIONAL_NAME, flags=re.I)
    temp_paths = []  # every local temp file; all removed in the finally block

    def _new_temp(suffix):
        """Create an empty temp file, register it for cleanup and return its path."""
        fd, path = tempfile.mkstemp(suffix=suffix)
        os.close(fd)
        temp_paths.append(path)
        return path

    try:
        # Downloads happen inside try so a failed second download cannot leak the first.
        us_local = download_to_temp(us_file, ".tif")
        temp_paths.append(us_local)
        lanid_local = download_to_temp(lanid_file, ".tif")
        temp_paths.append(lanid_local)

        # --- Use US_binary as the reference grid/projection ---
        with rasterio.open(us_local) as ref:
            us = ref.read(1, masked=True).astype("float32")
            us = us.filled(np.nan) if np.ma.isMaskedArray(us) else us
            lanid = warp_to_ref(lanid_local, ref, resampling=Resampling.nearest)

            # Comparison domain: pixels where US has valid data.
            in_extent = np.isfinite(us)
            us_pos = (us > 0) & in_extent
            lanid_pos = (lanid > 0) & np.isfinite(lanid) & in_extent
            del us, lanid  # release the float rasters before building the class map

            agree_mask = us_pos & lanid_pos
            gmia_only_m = us_pos & ~lanid_pos
            lanid_only_m = ~us_pos & lanid_pos

            agree = int(np.count_nonzero(agree_mask))
            gmia_only = int(np.count_nonzero(gmia_only_m))
            lanid_only = int(np.count_nonzero(lanid_only_m))

            union = agree + gmia_only + lanid_only
            if union == 0:
                raise RuntimeError("No positive pixels in the union of US_binary and LANID within US extent.")

            iou = agree / union
            agree_pct = 100.0 * agree / union
            gmia_pct = 100.0 * gmia_only / union
            lanid_pct = 100.0 * lanid_only / union

            # ---------- Diff GeoTIFF (1=agree, 2=GMIA-only, 3=LANID-only) ----------
            class_map = np.zeros(ref.shape, dtype=np.uint8)  # 0 = nodata/background
            class_map[agree_mask] = 1
            class_map[gmia_only_m] = 2
            class_map[lanid_only_m] = 3

            profile = ref.profile.copy()
            profile.update(
                driver="GTiff",
                dtype="uint8",
                count=1,
                nodata=0,
                compress="deflate",
                predictor=2,
                tiled=True,
                blockxsize=256,
                blockysize=256,
                bigtiff="IF_SAFER",
            )

            diff_tmp = _new_temp(".tif")
            with rasterio.open(diff_tmp, "w", **profile) as dst:
                dst.write(class_map, 1)
                # RGBA colormap matching the bar colours.
                dst.write_colormap(1, {
                    0: (0, 0, 0, 0),         # transparent nodata
                    1: (44, 160, 44, 255),   # #2ca02c agreement
                    2: (230, 159, 0, 255),   # #E69F00 GMIA-only
                    3: (139, 0, 0, 255),     # #8B0000 LANID-only
                })
        upload_path(drive, diff_tmp, out_id, title=stem + "_vs_LANID_diff.tif")

        # ---------- CSV ----------
        csv_path = _new_temp(".csv")
        with open(csv_path, "w", newline="") as fh:
            w = csv.writer(fh)
            w.writerow(["raster_title", "iou", "agreement_pct", "gmia_next_only_pct", "lanid_only_pct"])
            w.writerow([NATIONAL_NAME, f"{iou:.6f}", f"{agree_pct:.2f}", f"{gmia_pct:.2f}", f"{lanid_pct:.2f}"])
        upload_path(drive, csv_path, out_id, title="US_vs_LANID_summary.csv")

        # ---------- Bar chart PNG ----------
        fig, ax = plt.subplots(figsize=(6.6, 5.6), dpi=PLOT_DPI)
        try:
            vals = [agree_pct, gmia_pct, lanid_pct]
            bars = ax.bar(
                ["Agreement", "GMIA-NEXT Only", "LANID Only"],
                vals,
                color=[COLOR_AGREE, COLOR_GMIA, COLOR_LANID],
            )
            ax.set_ylabel("Percentage (%)")
            ax.set_ylim(0, 100)
            ax.set_title(f"US_binary vs LANID — IoU = {iou:.3f}")
            for b, v in zip(bars, vals):
                ax.text(b.get_x() + b.get_width() / 2., v + 1.2, f"{v:.0f}%", ha="center", va="bottom", fontsize=9)
            fig.tight_layout()

            plot_path = _new_temp(".png")
            fig.savefig(plot_path, bbox_inches="tight")
        finally:
            plt.close(fig)
        upload_path(drive, plot_path, out_id, title=stem + "_vs_LANID.png")

        print(f"✅ {NATIONAL_NAME}  IoU={iou:.3f}  Agreement={agree_pct:.1f}%  GMIA-NEXT only={gmia_pct:.1f}%  LANID only={lanid_pct:.1f}%")
        print("🎉 PNG + CSV + DIFF TIF uploaded to", BINARY_SUBFOLDER, "/Comparisons/")

    finally:
        for p in temp_paths:
            try:
                os.remove(p)
            except OSError:
                pass

# ---------------------- CLI ----------------------
if __name__ == "__main__":
    # Require the PyDrive2 client created by the authentication cell.
    if "drive" not in globals():
        raise RuntimeError("PyDrive2 'drive' not found. Authenticate and expose a global `drive` before running.")
    compare_us_to_lanid(drive)  # noqa: F821


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Plot the difference map between China GMIA-NEXT and another existing dataset (e.g., CIrrMap), overlaid with the national boundary.

Uses *existing* files in Comparisons/:
  - China_binary_30m_vs_CIrrMap_diff.tif   (0=nodata, 1=agree, 2=GMIA-only, 3=CIrrMap-only)
  - gadm41_CHN_0.*                         (national boundary shapefile)

Output:
  - China_binary_30m_vs_CIrrMap_diff_map.png in the same Comparisons/ folder.
"""

import os
import re
import tempfile
import warnings

import numpy as np
import rasterio
from rasterio.enums import Resampling

import fiona
from shapely.geometry import shape
from shapely.ops import transform as shp_transform
from pyproj import Transformer

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

# Suppress every RuntimeWarning for the rest of the session (process-wide filter).
warnings.filterwarnings("ignore", category=RuntimeWarning)

# ===================== CONFIG =====================
# Drive id of the top-level folder; overridable through the ROOT_FOLDER_ID env var.
ROOT_FOLDER_ID      = os.environ.get("ROOT_FOLDER_ID", "18pQKnMMnLramhHRZSNwUJrLqG5DXNMmS")
PARENT_FOLDER_NAME  = "CountryModelPredicted"
PROB_NAMES          = ("Probability", "Probablity", "Prob")  # tried in order; includes a misspelt variant
BINARY_SUBFOLDER    = "Merged"
COMPARE_SUBFOLDER   = "Comparisons"

DIFF_NAME           = "China_binary_30m_vs_CIrrMap_diff.tif"  # existing diff raster in COMPARE_SUBFOLDER
SHAPE_BASE          = "gadm41_CHN_0"          # Country boundary shapefile basename (in COMPARE_SUBFOLDER)
PLOT_DPI            = 150                     # keep it modest

# Colors aligned with the diff classes listed in the module docstring
# (1=agree, 2=GMIA-only, 3=CIrrMap-only).
COLOR_AGREE = "#1f77b4"   # 1  (blue)
COLOR_GMIA  = "#E69F00"   # 2  (orange)
COLOR_LANID = "#8B0000"   # 3  (dark red); name kept from the LANID script — presumably the CIrrMap-only class here, confirm
# ==================================================


# ---------------------- Drive helpers ---------------------
def _dq(drive, q):
    return drive.ListFile({
        "q": q,
        "supportsAllDrives": True,
        "includeItemsFromAllDrives": True,
        "maxResults": 1000
    }).GetList()

def list_files(drive, parent_id):
    """Return all non-trashed items (files and sub-folders) inside *parent_id*."""
    return _dq(drive, "'" + f"{parent_id}" + "' in parents and trashed=false")

def child_folders(drive, parent_id):
    """Return only the non-trashed sub-folders inside *parent_id*."""
    clauses = [
        f"'{parent_id}' in parents",
        "trashed=false",
        "mimeType='application/vnd.google-apps.folder'",
    ]
    return _dq(drive, " and ".join(clauses))

def get_subfolder_exact(drive, parent_id, name):
    """
    Return the id of the sub-folder of *parent_id* titled exactly *name*, or None.

    Backslashes and single quotes in *name* are escaped so the Drive query
    literal stays well-formed. With duplicate titles, the first hit wins.
    """
    safe_name = name.replace("\\", "\\\\").replace("'", "\\'")
    res = _dq(
        drive,
        f"'{parent_id}' in parents and trashed=false "
        f"and mimeType='application/vnd.google-apps.folder' and title='{safe_name}'"
    )
    return res[0]["id"] if res else None

def get_subfolder_fuzzy(drive, parent_id, desired):
    """
    Resolve a sub-folder of *parent_id* for *desired*, exact title first.

    Otherwise returns the first child folder whose normalised title (lower-case,
    no whitespace) contains the normalised *desired*, or that contains "prob"
    when *desired* does too (handles the 'Probability' typo). None if no match.
    """
    found = get_subfolder_exact(drive, parent_id, desired)
    if found:
        return found

    target = re.sub(r"\s+", "", desired.lower())
    for sub in child_folders(drive, parent_id):
        label = re.sub(r"\s+", "", (sub.get("title") or sub.get("name") or "").lower())
        matches = target in label or ("prob" in target and "prob" in label)
        if matches:
            return sub["id"]
    return None

def download_to_temp(drive_file, local_path):
    """Save the content of *drive_file* (a PyDrive2 file) at *local_path*; returns None."""
    drive_file.GetContentFile(local_path)
    return None

def upload_path(drive, local_path, parent_id, title=None):
    """Create a Drive file in *parent_id* from *local_path* and return its id.

    Uses *title* when truthy, otherwise the local file's basename.
    """
    if not title:
        title = os.path.basename(local_path)
    new_file = drive.CreateFile({"title": title, "parents": [{"id": parent_id}]})
    new_file.SetContentFile(local_path)
    new_file.Upload()
    return new_file["id"]


# ------------------ Shapefile bundle downloader ---------
def download_shapefile_bundle(drive, folder_id, base):
    """
    Download the files of shapefile *base* from *folder_id* into a temp dir.

    Titles are matched case-insensitively against *base* + .shp/.shx/.dbf/.prj/.cpg.
    Each hit is saved as exactly ``base + ext``, so the pieces share one
    basename. Returns the local .shp path; the temp dir is left in place.

    Raises FileNotFoundError when ``base.shp`` is not present.
    """
    suffixes = (".shp", ".shx", ".dbf", ".prj", ".cpg")
    lookup = {base.lower() + sfx: sfx for sfx in suffixes}
    matched = {}
    for entry in list_files(drive, folder_id):
        sfx = lookup.get((entry.get("title") or "").lower())
        if sfx:
            matched[sfx] = entry

    if ".shp" not in matched:
        raise FileNotFoundError(f"Could not find {base}.shp in Comparisons folder.")

    workdir = tempfile.mkdtemp(prefix="chn_boundary_")
    for sfx, entry in matched.items():
        download_to_temp(entry, os.path.join(workdir, base + sfx))
    return os.path.join(workdir, base + ".shp")


# --------------------- Plotting logic --------------------
def plot_diff_with_boundary(drive):
    """Map the agreement/disagreement diff raster with a boundary and upload the PNG.

    Folder path: ROOT_FOLDER_ID / PARENT_FOLDER_NAME / <first PROB_NAMES match> /
    BINARY_SUBFOLDER / COMPARE_SUBFOLDER. From that Comparisons folder the
    raster DIFF_NAME and the SHAPE_BASE shapefile bundle are downloaded. The
    map is uploaded back to the same folder as
    ``<DIFF_NAME without .tif/.tiff>_map.png``.

    Raster classes: 0 transparent (nodata), 1 agreement, 2 GMIA-only,
    3 CIrrMap-only.

    Args:
        drive: authenticated PyDrive2 ``GoogleDrive``.

    Raises:
        RuntimeError: a folder on the navigation path is missing.
        FileNotFoundError: DIFF_NAME or SHAPE_BASE.shp is missing.
    """
    import shutil
    from matplotlib.patches import Patch

    # 1) Navigate to the Comparisons folder
    root_id = ROOT_FOLDER_ID
    cmp_id = get_subfolder_fuzzy(drive, root_id, PARENT_FOLDER_NAME)
    if not cmp_id:
        raise RuntimeError(f"Folder '{PARENT_FOLDER_NAME}' not found under ROOT.")

    prob_id = None
    for nm in PROB_NAMES:
        prob_id = get_subfolder_fuzzy(drive, cmp_id, nm)
        if prob_id:
            break
    if not prob_id:
        raise RuntimeError("Probability/Probablity/Prob folder not found.")

    binary_id = get_subfolder_fuzzy(drive, prob_id, BINARY_SUBFOLDER)
    if not binary_id:
        raise RuntimeError(f"Binary folder '{BINARY_SUBFOLDER}' not found.")

    comp_id = get_subfolder_fuzzy(drive, binary_id, COMPARE_SUBFOLDER)
    if not comp_id:
        raise RuntimeError(f"Comparisons folder '{COMPARE_SUBFOLDER}' not found inside {BINARY_SUBFOLDER}/.")

    # 2) Locate the diff raster by exact (case-insensitive) title
    diff_file = None
    for it in list_files(drive, comp_id):
        if it.get("mimeType") == "application/vnd.google-apps.folder":
            continue
        if (it.get("title") or "").strip().lower() == DIFF_NAME.lower():
            diff_file = it
            break
    if diff_file is None:
        raise FileNotFoundError(f"Diff raster '{DIFF_NAME}' not found in Comparisons/.")

    diff_tmp = None
    shp_dir = None
    out_png = None
    fig = None
    try:
        # 3) Downloads happen inside the try so a failure in any later step
        #    (including the shapefile download) still cleans up what was fetched.
        fd, diff_tmp = tempfile.mkstemp(suffix=".tif")
        os.close(fd)
        download_to_temp(diff_file, diff_tmp)

        shp_path = download_shapefile_bundle(drive, comp_id, SHAPE_BASE)
        shp_dir = os.path.dirname(shp_path)

        # ---- Read diff raster, decimated to roughly <= 8 million pixels ----
        with rasterio.open(diff_tmp) as ds:
            H, W = ds.height, ds.width
            bounds = ds.bounds
            max_pixels = 8_000_000
            scale_factor = max(1, int(np.sqrt((H * W) / max_pixels)))
            diff_arr = ds.read(
                1,
                out_shape=(max(1, H // scale_factor), max(1, W // scale_factor)),
                resampling=Resampling.nearest,  # nearest keeps class codes intact
            ).astype("int16")
            raster_crs = ds.crs
            # The decimated array still covers the full raster footprint.
            extent = (bounds.left, bounds.right, bounds.bottom, bounds.top)

        # ---- Read boundary, reprojected into the raster CRS ----
        geoms = []
        with fiona.open(shp_path) as src:
            shp_crs = src.crs
            project = None
            if shp_crs and raster_crs and shp_crs != raster_crs:
                transformer = Transformer.from_crs(shp_crs, raster_crs, always_xy=True)

                def project(x, y, z=None):
                    # shapely may pass z; a 2-D outline is all that is drawn
                    return transformer.transform(x, y)

            for feat in src:
                if feat["geometry"] is None:
                    continue  # null geometry: nothing to draw
                g = shape(feat["geometry"])
                if project is not None:
                    g = shp_transform(project, g)
                geoms.append(g)

        # ---- Colours for classes 0-3; the legend reuses them so they cannot drift ----
        class_colors = [
            (0, 0, 0, 0),                      # 0 nodata transparent
            (31/255, 119/255, 180/255, 1.0),   # 1 Agreement (blue)
            (230/255, 159/255, 0/255, 1.0),    # 2 GMIA-only
            (139/255, 0/255, 0/255, 1.0),      # 3 CIrrMap-only
        ]
        cmap = ListedColormap(class_colors)

        fig, ax = plt.subplots(figsize=(9, 8), dpi=PLOT_DPI)
        ax.imshow(
            diff_arr,
            origin="upper",
            cmap=cmap,
            vmin=0,
            vmax=3,
            extent=extent
        )

        # Boundary overlay (only polygonal parts have rings to outline)
        for g in geoms:
            polys = list(g.geoms) if g.geom_type == "MultiPolygon" else [g]
            for poly in polys:
                if poly.geom_type != "Polygon":
                    continue
                x, y = poly.exterior.xy
                ax.plot(x, y, color="black", linewidth=0.6)
                for ring in poly.interiors:
                    xi, yi = ring.xy
                    ax.plot(xi, yi, color="black", linewidth=0.4)

        # No title, no ticks, no frame
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel("")
        ax.set_ylabel("")
        ax.set_aspect("equal")
        ax.set_axis_off()

        # Horizontal legend below the map
        legend_patches = [
            Patch(facecolor=class_colors[1], edgecolor="none", label="Agreement (both 1)"),
            Patch(facecolor=class_colors[2], edgecolor="none", label="GMIA-NEXT only"),
            Patch(facecolor=class_colors[3], edgecolor="none", label="CIrrMap only"),
        ]
        ax.legend(
            handles=legend_patches,
            loc="upper center",
            bbox_to_anchor=(0.5, -0.06),
            ncol=3,
            frameon=False,
            fontsize=10
        )

        fig.tight_layout()

        # ---- Save & upload PNG ----
        fd, out_png = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        fig.savefig(out_png, bbox_inches="tight", dpi=PLOT_DPI)
        plt.close(fig)
        fig = None

        out_title = re.sub(r"\.tif(f)?$", "", DIFF_NAME, flags=re.I) + "_map.png"
        upload_path(drive, out_png, comp_id, title=out_title)

        print(f"âœ… Saved comparison map PNG as '{out_title}' in Comparisons/")

    finally:
        if fig is not None:
            plt.close(fig)
        for path in (diff_tmp, out_png):
            if path:
                try:
                    os.remove(path)
                except OSError:
                    pass
        if shp_dir:
            shutil.rmtree(shp_dir, ignore_errors=True)


# ---------------------- CLI ----------------------
if __name__ == "__main__":
    # `drive` is created by the authentication cell at the top of the notebook;
    # fail with a clear message (without the NameError traceback) if it is missing.
    try:
        drive  # noqa: F821
    except NameError:
        raise RuntimeError(
            "PyDrive2 'drive' not found. Authenticate and expose a global `drive` before running."
        ) from None
    plot_diff_with_boundary(drive)


✅ Saved comparison map PNG as 'Copy of Uzbekistan_AEI_binary_0_1_MAJ_k9_p50_vs_LANID_diff_map.png' in Comparisons/


In [None]:
"""
Plot a country binary irrigation raster (CHINA_BINARY_NAME, currently
China_binary.tif) with:
  - Top panel: country-wide map (ADM0 boundary + irrigated in dark blue).
  - Bottom panel: zoomed inset, clipping the binary raster to the
    SMALL_BASE shapefile.

Files (Google Drive layout; names come from the CONFIG section):
  CountryModelPredicted_Cropland /
    Probability|Probablity|Prob /
      Merged /
        China_binary.tif                                   (CHINA_BINARY_NAME)
        China_small_area_shapefile.(shp,shx,dbf,prj,cpg)   (SMALL_BASE)
        Comparisons /
          gadm41_CHN_0.(shp,shx,dbf,prj,cpg)               (CHN_ADM0_BASE)
          --> the output PNG is uploaded into this folder

Requirements:
  - PyDrive2-authenticated global `drive`
  - rasterio, numpy, matplotlib, fiona, shapely, pyproj
"""

import os
import re
import tempfile
import warnings

import numpy as np
import rasterio
from rasterio.enums import Resampling
from rasterio.mask import mask

import fiona
from shapely.geometry import shape, box
from shapely.ops import transform as shp_transform
from shapely.ops import unary_union
from pyproj import Transformer

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

# Silence RuntimeWarning globally for the rest of the session.
warnings.filterwarnings("ignore", category=RuntimeWarning)

# ===================== CONFIG =====================
# Drive folder that contains PARENT_FOLDER_NAME; can be overridden via the environment.
ROOT_FOLDER_ID = os.environ.get("ROOT_FOLDER_ID", "19Ee1kUar4oEabBObq7GcA-jDR7F29x2_")

# Folder chain below the root: PARENT / <one of PROB_NAMES> / BINARY / COMPARE
PARENT_FOLDER_NAME = "CountryModelPredicted_Cropland"
PROB_NAMES = ("Probability", "Probablity", "Prob")   # spelling variants, tried in order
BINARY_SUBFOLDER = "Merged"
COMPARE_SUBFOLDER = "Comparisons"

# The binary raster and the small-area shapefile are read from BINARY_SUBFOLDER;
# the ADM0 boundary shapefile is read from COMPARE_SUBFOLDER.
CHINA_BINARY_NAME = "China_binary.tif"
SMALL_BASE = "China_small_area_shapefile"
CHN_ADM0_BASE = "gadm41_CHN_0"

PLOT_DPI = 160
# ==================================================


# ---------------------- Drive helpers ---------------------
def _dq(drive, q):
    return drive.ListFile({
        "q": q,
        "supportsAllDrives": True,
        "includeItemsFromAllDrives": True,
        "maxResults": 1000
    }).GetList()

def list_files(drive, parent_id):
    """Return every non-trashed child item (files and folders) of *parent_id*."""
    query = f"'{parent_id}' in parents and trashed=false"
    return _dq(drive, query)

def child_folders(drive, parent_id):
    """Return the non-trashed child folders (not files) of *parent_id*."""
    query = " and ".join([
        f"'{parent_id}' in parents",
        "trashed=false",
        "mimeType='application/vnd.google-apps.folder'",
    ])
    return _dq(drive, query)

def get_subfolder_exact(drive, parent_id, name):
    """Return the id of a non-trashed folder titled exactly *name* under *parent_id*.

    Args:
        drive: authenticated PyDrive2 ``GoogleDrive``.
        parent_id: Drive id of the parent folder.
        name: exact folder title to look for.

    Returns:
        The id of the first matching folder, or ``None`` when there is none.
    """
    # The title is embedded in a single-quoted Drive query literal; escape
    # backslashes and single quotes so names like "Farmer's plots" neither
    # break nor alter the query.
    safe_name = name.replace("\\", "\\\\").replace("'", "\\'")
    res = _dq(
        drive,
        f"'{parent_id}' in parents and trashed=false "
        f"and mimeType='application/vnd.google-apps.folder' and title='{safe_name}'"
    )
    return res[0]["id"] if res else None

def get_subfolder_fuzzy(drive, parent_id, desired):
    """Find a child folder of *parent_id*, tolerating small naming differences.

    An exact title match is tried first. Otherwise child folders are compared
    case- and whitespace-insensitively: a child matches when its normalised
    title contains the normalised *desired* name, or, when *desired* contains
    "prob", when the child's title also contains "prob" (this covers the
    "Probability"/"Probablity"/"Prob" spelling variants).

    Returns:
        The first matching folder id, or ``None``.
    """
    exact = get_subfolder_exact(drive, parent_id, desired)
    if exact:
        return exact

    def _normalise(text):
        return re.sub(r"\s+", "", text.lower())

    wanted = _normalise(desired)
    wants_prob = "prob" in wanted
    for folder in child_folders(drive, parent_id):
        candidate = _normalise(folder.get("title") or folder.get("name") or "")
        if wanted in candidate or (wants_prob and "prob" in candidate):
            return folder["id"]
    return None

def download_to_temp(drive_file, local_path):
    """Save the content of the PyDrive2 file object *drive_file* to *local_path*."""
    save = drive_file.GetContentFile
    save(local_path)

def upload_path(drive, local_path, parent_id, title=None):
    """Upload the local file *local_path* into Drive folder *parent_id*.

    Args:
        drive: authenticated PyDrive2 ``GoogleDrive``.
        local_path: file to upload.
        parent_id: id of the destination Drive folder.
        title: Drive title; defaults to the basename of *local_path*.

    Returns:
        The id of the newly created Drive file.
    """
    metadata = {
        "title": title if title else os.path.basename(local_path),
        "parents": [{"id": parent_id}],
    }
    new_file = drive.CreateFile(metadata)
    new_file.SetContentFile(local_path)
    new_file.Upload()
    return new_file["id"]

def download_shapefile_bundle(drive, folder_id, base):
    """Fetch the ``base`` shapefile components from Drive folder *folder_id*.

    Files titled ``base`` + .shp/.shx/.dbf/.prj/.cpg (case-insensitive) are
    written into a new temporary directory as ``<base><ext>``.

    Returns:
        Local path of the downloaded ``.shp``.

    Raises:
        FileNotFoundError: no ``base.shp`` exists in the folder.
    """
    suffixes = (".shp", ".shx", ".dbf", ".prj", ".cpg")
    ext_by_title = {base.lower() + ext: ext for ext in suffixes}

    found = {}
    for item in list_files(drive, folder_id):
        ext = ext_by_title.get((item.get("title") or "").lower())
        if ext is not None:
            found[ext] = item

    if ".shp" not in found:
        raise FileNotFoundError(f"Could not find {base}.shp in folder.")

    target_dir = tempfile.mkdtemp(prefix=base + "_")
    for ext, item in found.items():
        download_to_temp(item, os.path.join(target_dir, base + ext))
    return os.path.join(target_dir, base + ".shp")


# --------------------- Main plotting logic --------------------
def _read_geoms_in_crs(shp_path, target_crs):
    """Return all non-null feature geometries of *shp_path*, reprojected to *target_crs*.

    Reprojection is skipped when either CRS is missing or they compare equal.
    """
    geoms = []
    with fiona.open(shp_path) as src:
        src_crs = src.crs
        project = None
        if src_crs and target_crs and src_crs != target_crs:
            transformer = Transformer.from_crs(src_crs, target_crs, always_xy=True)

            def project(x, y, z=None):
                # shapely may pass z; 2-D outlines/masks are all that is needed
                return transformer.transform(x, y)

        for feat in src:
            if feat["geometry"] is None:
                continue
            g = shape(feat["geometry"])
            if project is not None:
                g = shp_transform(project, g)
            geoms.append(g)
    return geoms


def _draw_polygon_outlines(ax, geoms, exterior_lw, interior_lw):
    """Plot the exterior and interior rings of every (Multi)Polygon in *geoms* on *ax*."""
    for g in geoms:
        polys = list(g.geoms) if g.geom_type == "MultiPolygon" else [g]
        for poly in polys:
            if poly.geom_type != "Polygon":
                continue  # lines/points have no rings to outline
            x, y = poly.exterior.xy
            ax.plot(x, y, color="black", linewidth=exterior_lw)
            for ring in poly.interiors:
                xi, yi = ring.xy
                ax.plot(xi, yi, color="black", linewidth=interior_lw)


def plot_china_binary_with_zoom(drive):
    """Plot CHINA_BINARY_NAME country-wide plus a zoomed inset and upload the PNG.

    Inputs (Drive): CHINA_BINARY_NAME and the SMALL_BASE shapefile from
    BINARY_SUBFOLDER, and the CHN_ADM0_BASE boundary from COMPARE_SUBFOLDER.

    Top panel: the whole raster, decimated to roughly <= 10 M pixels, with the
    ADM0 boundary. Bottom panel: the raster cropped to the union of the
    SMALL_BASE polygons, with nodata and pixels outside the polygons
    transparent. Class colours: 0 light grey (non-irrigated), 1 dark blue
    (irrigated).

    The PNG is uploaded to COMPARE_SUBFOLDER as
    ``<CHINA_BINARY_NAME without .tif/.tiff>_with_zoom_map.png``.

    Args:
        drive: authenticated PyDrive2 ``GoogleDrive``.

    Raises:
        RuntimeError: a folder on the navigation path is missing.
        FileNotFoundError: the raster or a shapefile is missing.
        ValueError: the small-area shapefile contains no geometries.
    """
    import shutil

    # 1) Navigate ROOT / PARENT / <Prob folder> / Merged / Comparisons
    root_id = ROOT_FOLDER_ID
    cmp_id = get_subfolder_fuzzy(drive, root_id, PARENT_FOLDER_NAME)
    if not cmp_id:
        raise RuntimeError(f"Folder '{PARENT_FOLDER_NAME}' not found under ROOT.")

    prob_id = None
    for nm in PROB_NAMES:
        prob_id = get_subfolder_fuzzy(drive, cmp_id, nm)
        if prob_id:
            break
    if not prob_id:
        raise RuntimeError("Probability/Probablity/Prob folder not found.")

    binary_id = get_subfolder_fuzzy(drive, prob_id, BINARY_SUBFOLDER)
    if not binary_id:
        raise RuntimeError(f"Binary folder '{BINARY_SUBFOLDER}' not found.")

    comp_id = get_subfolder_fuzzy(drive, binary_id, COMPARE_SUBFOLDER)
    if not comp_id:
        raise RuntimeError(f"Comparisons folder '{COMPARE_SUBFOLDER}' not found inside {BINARY_SUBFOLDER}/.")

    # 2) Find the binary raster and confirm the small-area shapefile is in Merged/
    china_file = None
    small_shp_present = False
    for it in list_files(drive, binary_id):
        title = (it.get("title") or "").strip().lower()
        if title == CHINA_BINARY_NAME.lower():
            china_file = it
        if title.startswith(SMALL_BASE.lower()) and title.endswith(".shp"):
            small_shp_present = True
    if not china_file:
        raise FileNotFoundError(f"{CHINA_BINARY_NAME} not found in Merged/.")
    if not small_shp_present:
        raise FileNotFoundError(f"{SMALL_BASE}.shp not found in Merged/.")

    china_tmp = None
    out_png = None
    fig = None
    tmp_dirs = []
    try:
        # 3) Downloads are inside the try so a later failure still cleans up.
        fd, china_tmp = tempfile.mkstemp(suffix=".tif")
        os.close(fd)
        download_to_temp(china_file, china_tmp)

        shp_small_path = download_shapefile_bundle(drive, binary_id, SMALL_BASE)
        tmp_dirs.append(os.path.dirname(shp_small_path))
        shp_chn_path = download_shapefile_bundle(drive, comp_id, CHN_ADM0_BASE)
        tmp_dirs.append(os.path.dirname(shp_chn_path))

        with rasterio.open(china_tmp) as src:
            raster_crs = src.crs
            H, W = src.height, src.width
            bounds = src.bounds

            # === Main map: decimated read; masked=True hides nodata ===
            max_pixels = 10_000_000
            scale_factor = max(1, int(np.sqrt((H * W) / max_pixels)))
            china_arr = src.read(
                1,
                out_shape=(max(1, H // scale_factor), max(1, W // scale_factor)),
                resampling=Resampling.nearest,
                masked=True,
            )
            extent_main = (bounds.left, bounds.right, bounds.bottom, bounds.top)

            # === Boundaries, reprojected into the raster CRS ===
            chn_geoms = _read_geoms_in_crs(shp_chn_path, raster_crs)
            small_geoms = _read_geoms_in_crs(shp_small_path, raster_crs)
            if not small_geoms:
                raise ValueError(f"{SMALL_BASE}.shp contains no geometries.")
            small_union = unary_union(small_geoms)

            # === Inset: crop to the small area. filled=False returns a masked
            # array in which both nodata and pixels outside the polygons are
            # masked, so they render transparent rather than as class 0. ===
            inset_stack, inset_transform = mask(
                src, [small_union], crop=True, filled=False
            )
            inset_masked = inset_stack[0]  # single band
            # array_bounds returns (west, south, east, north)
            west, south, east, north = rasterio.transform.array_bounds(
                inset_masked.shape[0], inset_masked.shape[1], inset_transform
            )
            inset_extent = (west, east, south, north)

        # === Colormap: 0 grey, 1 dark blue, masked (nodata) transparent ===
        non_irrigated = (0.9, 0.9, 0.9, 1.0)
        irrigated = (0.0, 0.0, 0.5, 1.0)
        cmap = ListedColormap([non_irrigated, irrigated])
        cmap.set_bad((0, 0, 0, 0))

        fig, (ax_main, ax_inset) = plt.subplots(
            nrows=2, ncols=1,
            figsize=(7.5, 9.0),
            dpi=PLOT_DPI,
            gridspec_kw={"height_ratios": [3, 2]}
        )

        # ----- TOP PANEL: country-wide map -----
        ax_main.imshow(
            china_arr,
            origin="upper",
            cmap=cmap,
            vmin=0,
            vmax=1,
            extent=extent_main
        )
        _draw_polygon_outlines(ax_main, chn_geoms, exterior_lw=0.7, interior_lw=0.4)

        ax_main.set_xticks([])
        ax_main.set_yticks([])
        ax_main.set_xlabel("")
        ax_main.set_ylabel("")
        ax_main.set_title(f"Irrigated area (dark blue) from {CHINA_BINARY_NAME}", fontsize=13)

        ax_main.legend(
            handles=[
                Patch(facecolor=irrigated, edgecolor="none", label="Irrigated"),
                Patch(facecolor=non_irrigated, edgecolor="none", label="Non-irrigated"),
            ],
            loc="lower right",
            frameon=True,
            fontsize=9
        )

        # ----- BOTTOM PANEL: zoomed inset -----
        ax_inset.imshow(
            inset_masked,
            origin="upper",
            cmap=cmap,
            vmin=0,
            vmax=1,
            extent=inset_extent
        )
        _draw_polygon_outlines(ax_inset, small_geoms, exterior_lw=0.5, interior_lw=0.4)

        ax_inset.set_xticks([])
        ax_inset.set_yticks([])
        ax_inset.set_xlabel("")
        ax_inset.set_ylabel("")
        ax_inset.set_title(f"Zoomed irrigated area ({SMALL_BASE}.shp, dark blue)", fontsize=12)
        ax_inset.set_aspect("equal")

        fig.tight_layout()

        # ---- Save & upload PNG ----
        fd, out_png = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        fig.savefig(out_png, bbox_inches="tight", dpi=PLOT_DPI)
        plt.close(fig)
        fig = None

        # Name the output after the input raster so different countries do not collide.
        out_title = re.sub(r"\.tif(f)?$", "", CHINA_BINARY_NAME, flags=re.I) + "_with_zoom_map.png"
        upload_path(drive, out_png, comp_id, title=out_title)

        print(f"âœ… Saved country-wide + zoom map PNG as '{out_title}' in Merged/Comparisons/")

    finally:
        if fig is not None:
            plt.close(fig)
        for path in (china_tmp, out_png):
            if path:
                try:
                    os.remove(path)
                except OSError:
                    pass
        for d in tmp_dirs:
            shutil.rmtree(d, ignore_errors=True)


# ---------------------- CLI ----------------------
if __name__ == "__main__":
    # `drive` is created by the authentication cell at the top of the notebook;
    # fail with a clear message (without the NameError traceback) if it is missing.
    try:
        drive  # noqa: F821
    except NameError:
        raise RuntimeError(
            "PyDrive2 'drive' not found. Authenticate and expose a global `drive` before running."
        ) from None
    plot_china_binary_with_zoom(drive)


✅ Saved country-wide + zoom map PNG as 'US_binary_with_zoom_map.png' in Merged/Comparisons/


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# ---------------- Data ----------------
# Percentage of each agreement class per region (each column sums to 100).
countries = [
    "Italy", "Greece", "Spain", "France",
    "India", "China", "Brazil", "US",
    "Oceania", "Uzbekistan",
]
agreement  = [89.8, 96.8, 72.8, 53.5, 29.1, 54,   49, 53, 66, 59.4]
gmia_only  = [0,    0,    0,    0,    60.9, 21,  0, 26, 0,  39.7]
other_only = [10.2, 3.2,  27.2, 46.5, 10,   25,  51, 21, 34, 0.9]

# ---------------- Styling ----------------
plt.rcParams.update({"font.family": "Tw Cen MT", "font.size": 24})

COLOR_AGREE = "#1f4e79"   # dark blue
COLOR_GMIA  = "#E69F00"   # orange
COLOR_OTHER = "#8B0000"   # dark red

x = np.arange(len(countries))
bar_width = 0.5

fig, ax = plt.subplots(figsize=(18, 9), dpi=160)

# Stacked bars: each layer starts on top of the running total of the layers below.
stack_bottom = np.zeros(len(countries))
stacked = []
for values, label, color in (
    (agreement, "Agreement", COLOR_AGREE),
    (gmia_only, "GMIA-NEXT Only", COLOR_GMIA),
    (other_only, "Other Dataset Only", COLOR_OTHER),
):
    stacked.append(
        ax.bar(x, values, width=bar_width, bottom=stack_bottom, label=label, color=color)
    )
    stack_bottom = stack_bottom + np.asarray(values, dtype=float)
bars_agree, bars_gmia, bars_other = stacked

# Axes formatting
ax.set_xticks(x)
ax.set_xticklabels(countries, rotation=35, ha="right")
ax.set_ylabel("Percentage (%)")
ax.set_ylim(0, 100)
ax.tick_params(axis="both", labelsize=22)

# One-row legend centred below the axes
ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.12), ncol=3, frameon=False)

# Keep the bottom 8% of the figure free for the legend
fig.tight_layout(rect=[0, 0.08, 1, 1])

fig.savefig(r"E:\Recovered_15_57_11\My Passport(F)\GII\Validation\bars.png")
plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# ---------------- Data (only requested countries) ----------------
# region: (agreement, GMIA-NEXT only, other dataset only), in %
class_shares = {
    "Italy":      (89.8, 0,    10.2),
    "Greece":     (96.8, 0,    3.2),
    "Spain":      (72.8, 0,    27.2),
    "France":     (53.5, 0,    46.5),
    "Uzbekistan": (59.4, 39.7, 0.9),
}
countries = list(class_shares)
agreement, gmia_only, other_only = (list(col) for col in zip(*class_shares.values()))

# ---------------- Styling ----------------
plt.rcParams.update({"font.family": "Tw Cen MT", "font.size": 24})   # global font

COLOR_AGREE = "#1f4e79"   # dark blue
COLOR_GMIA  = "#E69F00"   # orange
COLOR_OTHER = "#8B0000"   # dark red

x = np.arange(len(countries))
bar_width = 0.5

fig, ax = plt.subplots(figsize=(18, 9), dpi=160)

# Stacked bars: each layer starts on top of the running total of the layers below.
stack_bottom = np.zeros(len(countries))
stacked = []
for values, label, color in (
    (agreement, "Agreement", COLOR_AGREE),
    (gmia_only, "GMIA-NEXT Only", COLOR_GMIA),
    (other_only, "Other Dataset Only", COLOR_OTHER),
):
    stacked.append(
        ax.bar(x, values, width=bar_width, bottom=stack_bottom, label=label, color=color)
    )
    stack_bottom = stack_bottom + np.asarray(values, dtype=float)
bars_agree, bars_gmia, bars_other = stacked

# Axes formatting
ax.set_xticks(x)
ax.set_xticklabels(countries, rotation=35, ha="right")
ax.set_ylabel("Percentage (%)")
ax.set_ylim(0, 100)
ax.tick_params(axis="both", labelsize=22)

# One-row legend centred below the axes
ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.12), ncol=3, frameon=False)

# Keep the bottom 8% of the figure free for the legend
fig.tight_layout(rect=[0, 0.08, 1, 1])

fig.savefig(r"E:\Recovered_15_57_11\My Passport(F)\GII\Validation\bars_europe_uzbekistan.png")
plt.show()


In [None]:
import os
import numpy as np
import geopandas as gpd
import rasterio
from rasterio.warp import reproject, Resampling
import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm, ListedColormap

# -------------------------------------------------------------------
# Input locations
# -------------------------------------------------------------------
base_dir = r"G:\My Drive\Global_Irr_Maps\Global Maps"

# Irrigated area; its grid is the target grid for everything else
ref_name = "GMIA-NEXT.tif"
ref_path = os.path.join(base_dir, ref_name)

# Harvested-area grid (resampled onto the GMIA-NEXT grid before use)
harvest_path = r"E:\MIRCA-OS\My Passport(F)\PhD_Udel\MICRA_2015\DOwnscaling\AH Grid\New folder\AH\15\Ara_15.tif"

# Continent outlines drawn over the map
continents_shp = r"E:\Recovered_15_57_11\My Passport(F)\COntinents onlt\Wiothout Rusia Bond\Continenets.shp"

# -------------------------------------------------------------------
# Helper: clean array (nodata, inf -> NaN; negatives optionally -> NaN)
# -------------------------------------------------------------------
def clean_array(arr, nodata=None, *, negative_to_nan=False):
    """Return a float32 copy of *arr* with invalid cells set to NaN.

    Cells equal to *nodata* and non-finite cells (inf, -inf, NaN) become NaN.
    Negative values are kept unless *negative_to_nan* is True.

    Args:
        arr: array-like of numbers; it is not modified.
        nodata: sentinel value to treat as missing, or ``None``.
        negative_to_nan: also set values < 0 to NaN (default False, which
            matches the historical behaviour).

    Returns:
        New ``np.float32`` array.
    """
    out = np.array(arr, dtype=np.float32)  # always a copy
    if nodata is not None:
        out[out == nodata] = np.nan
    out[~np.isfinite(out)] = np.nan
    if negative_to_nan:
        out[out < 0] = np.nan
    return out

# -------------------------------------------------------------------
# Helper: reproject one raster band to match a reference raster
# -------------------------------------------------------------------
def reproject_to_match(src_path, ref_ds, resampling=Resampling.bilinear):
    """Resample band 1 of *src_path* onto the grid of the open dataset *ref_ds*.

    Args:
        src_path: path of the raster to resample.
        ref_ds: open rasterio dataset whose height/width/transform/crs define
            the target grid.
        resampling: rasterio ``Resampling`` method (default bilinear, as before).

    Returns:
        float32 array of shape (ref_ds.height, ref_ds.width). Cells that receive
        no valid source data (outside coverage or source nodata) are NaN.

    NOTE(review): bilinear interpolation does not conserve per-cell totals. If
    the source stores an area per cell at a different resolution than the
    reference grid, a conservative approach (e.g. Resampling.sum/average plus
    scaling) may be more appropriate - confirm.
    """
    with rasterio.open(src_path) as src:
        src_data = src.read(1).astype(np.float32)
        dst_data = np.empty((ref_ds.height, ref_ds.width), dtype=np.float32)
        reproject(
            source=src_data,
            destination=dst_data,
            src_transform=src.transform,
            src_crs=src.crs,
            dst_transform=ref_ds.transform,
            dst_crs=ref_ds.crs,
            resampling=resampling,
            src_nodata=src.nodata,
            dst_nodata=np.nan,
        )
    return dst_data

# -------------------------------------------------------------------
# Load GMIA-NEXT (reference grid) and the harvested-area grid resampled onto it
# -------------------------------------------------------------------
with rasterio.open(ref_path) as ref_ds:
    irr = clean_array(ref_ds.read(1).astype(np.float32), nodata=ref_ds.nodata)
    # reproject_to_match writes NaN where there is no source data, so no nodata value here
    harv = clean_array(reproject_to_match(harvest_path, ref_ds), nodata=None)

    # Georeferencing still needed after the dataset is closed
    raster_bounds = ref_ds.bounds
    ras_crs = ref_ds.crs

extent = [raster_bounds.left, raster_bounds.right, raster_bounds.bottom, raster_bounds.top]

# -------------------------------------------------------------------
# Irrigated area as a percentage of harvested area.
# Cells with zero irrigated OR zero harvested area (or NaN) stay NaN -> no colour.
# -------------------------------------------------------------------
valid = (irr > 0) & (harv > 0)
proportion = np.full_like(irr, np.nan, dtype=np.float32)
proportion[valid] = irr[valid] / harv[valid]
prop_pct = np.clip(proportion * 100.0, 0, 200)  # cap at the top class edge (200 %)

# -------------------------------------------------------------------
# Continent outlines in the raster CRS
# -------------------------------------------------------------------
continents = gpd.read_file(continents_shp)
if continents.crs != ras_crs:
    continents = continents.to_crs(ras_crs)

# -------------------------------------------------------------------
# Classes (%): 0.01-1, 1-3, 3-10, 10-30, 30-200; gray -> light blue -> dark blue
# -------------------------------------------------------------------
colors = [
    '#d9d9d9',  # lowest class: light gray
    '#deebf7',  # very light blue
    '#9ecae1',  # light/medium blue
    '#4292c6',  # medium blue
    '#08519c',  # highest class: dark blue
]
bounds = [0.01, 1, 3, 10, 30, 200]   # class breaks (these are not the raster bounds)
cmap = ListedColormap(colors)
cmap.set_under('white')   # below 0.01 %
cmap.set_bad('white')     # NaN: no irrigation / no harvest / outside data
norm = BoundaryNorm(bounds, cmap.N)

prop_masked = np.ma.masked_invalid(prop_pct)

# -------------------------------------------------------------------
# Single map with continent outlines
# -------------------------------------------------------------------
fig, ax = plt.subplots(1, 1, figsize=(15, 8), subplot_kw={'aspect': 'equal'})
img = ax.imshow(prop_masked, cmap=cmap, norm=norm, extent=extent, origin='upper')
continents.boundary.plot(ax=ax, linewidth=1, color='black', alpha=0.7)
ax.set_axis_off()

# -------------------------------------------------------------------
# Horizontal colorbar in its own axes: [left, bottom, width, height] (figure fraction)
# -------------------------------------------------------------------
cbar_ax = fig.add_axes([0.15, 0.05, 0.7, 0.05])
cbar = fig.colorbar(img, cax=cbar_ax, orientation='horizontal', extend='min')
cbar.set_label('Irrigated area(%)', fontsize=18, fontname='Tw Cen MT')
cbar.ax.tick_params(labelsize=14)
cbar.set_ticks(bounds)
# NOTE(review): the labels do not match the breaks they sit on: 0.01 is shown
# as "0", and the 200 % top edge is shown as "100" although values are clipped
# at 200 %. Confirm the "100" label is intended.
cbar.set_ticklabels(["0", "1", "3", "10", "30", "100"])
cbar.outline.set_linewidth(1.5)

fig.subplots_adjust(left=0.01, right=0.99, top=0.93, bottom=0.1)

# -------------------------------------------------------------------
# Save and show
# -------------------------------------------------------------------
out_path = r"G:\My Drive\Global_Irr_Maps\Global Maps\GMIA_NEXT_irrigated_share_of_harvested.png"
fig.savefig(out_path, dpi=300, bbox_inches='tight')
plt.show()


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import rasterio
from rasterio.warp import reproject, Resampling
import geopandas as gpd
import matplotlib.patches as mpatches
from matplotlib import font_manager

# =========================
# Inputs: reference grid, comparison raster, continent outlines
# =========================

base_dir = r"G:\My Drive\Global_Irr_Maps\Global Maps"

# Reference dataset; its grid is the common comparison grid
ref_name = "GMIA-NEXT.tif"
ref_path = os.path.join(base_dir, ref_name)

# Dataset compared against the reference
other_name = "AEI_2020.tif"
other_path = os.path.join(base_dir, other_name)

# Continents shapefile (reprojected to the reference CRS in main)
continents_shp = r"E:\Recovered_15_57_11\My Passport(F)\COntinents onlt\Wiothout Rusia Bond\Continenets.shp"

# =========================
# Colours and fonts
# =========================

# Colour-blind-friendly palette for the three comparison classes
dark_blue = '#0072B2'  # agreement
orange    = '#E69F00'  # GMIA-NEXT only
dark_red  = '#8B0000'  # AEI_2020 only

# Use 'Tw Cen MT' if matplotlib can find it; otherwise fall back to sans-serif.
_available_fonts = {f.name for f in font_manager.fontManager.ttflist}
title_font = 'Tw Cen MT' if 'Tw Cen MT' in _available_fonts else 'sans-serif'
if title_font == 'sans-serif':
    print('The "Tw Cen MT" font is not available. Using default sans-serif.')

# =========================
# Helper functions
# =========================

def read_raster(file_path):
    """Read band 1 of *file_path* as float64 plus its georeferencing.

    Returns:
        Tuple ``(data, transform, crs, nodata, meta)``; ``meta`` is a copy of
        the dataset's metadata dict.
    """
    with rasterio.open(file_path) as ds:
        meta = ds.meta.copy()
        return ds.read(1).astype(float), ds.transform, ds.crs, ds.nodata, meta

def clean_reference(data, nodata):
    """Mark invalid GMIA-NEXT cells as NaN, in place, and return the same array.

    A cell becomes NaN when it equals *nodata* (outside GMIA coverage) or is
    outside [-1e6, 1e10]. Zeros are kept: they mean "no irrigation" inside the
    coverage.

    Args:
        data: float array; modified in place.
        nodata: nodata sentinel, or ``None``.
    """
    invalid = (data < -1e6) | (data > 1e10)
    if nodata is not None:
        invalid |= data == nodata
    data[invalid] = np.nan
    return data

def reproject_to_reference(other_path, ref_meta):
    """Resample another raster (e.g. AEI_2020) onto the GMIA-NEXT grid.

    Nearest-neighbour resampling is used so cell values are not blended.
    Destination cells without valid source data (outside coverage, or source
    nodata) are 0, meaning "no irrigation". Passing the source nodata value to
    ``reproject`` keeps sentinels such as 255/65535 from being carried over as
    real values, which would otherwise be counted as irrigated (>= 1).

    Args:
        other_path: path of the raster to resample (band 1 is used).
        ref_meta: rasterio ``meta`` dict of the reference raster; its
            'height', 'width', 'transform' and 'crs' define the target grid.

    Returns:
        float64 array of shape (ref_meta['height'], ref_meta['width']), with
        negative and > 1e10 values replaced by 0.
    """
    with rasterio.open(other_path) as src:
        dst_data = np.zeros((ref_meta['height'], ref_meta['width']), dtype=float)
        reproject(
            source=src.read(1).astype(float),
            destination=dst_data,
            src_transform=src.transform,
            src_crs=src.crs,
            src_nodata=src.nodata,
            dst_transform=ref_meta['transform'],
            dst_crs=ref_meta['crs'],
            resampling=Resampling.nearest,
            dst_nodata=0
        )
    # Clean weird values: negative or huge -> 0
    dst_data[(dst_data < 0) | (dst_data > 1e10)] = 0
    return dst_data

def get_extent(transform, width, height):
    """Return the imshow extent ``(left, right, bottom, top)`` of a raster.

    *transform* is indexed like an affine ``(a, b, c, d, e, f, ...)``: a is the
    pixel width, c the x origin, e the pixel height (negative for north-up)
    and f the y origin. Rotation terms are ignored.
    """
    pixel_w, x_origin = transform[0], transform[2]
    pixel_h, y_origin = transform[4], transform[5]
    return (x_origin, x_origin + pixel_w * width, y_origin + pixel_h * height, y_origin)

def compute_masks(ref_data, other_data):
    """Classify pixels into agreement / disagreement masks.

    Only pixels where *ref_data* is not NaN (GMIA coverage) are considered.

    Returns:
        ``(agree, gmia_only, other_only, valid)`` boolean arrays, where
        agree      = ref >= 1 and other >= 1,
        gmia_only  = ref >= 1 and other == 0,
        other_only = ref == 0 and other >= 1,
        valid      = ref is not NaN.
        Values strictly between 0 and 1 fall into none of the three classes.
    """
    valid = ~np.isnan(ref_data)

    ref_on, ref_off = ref_data >= 1, ref_data == 0
    other_on, other_off = other_data >= 1, other_data == 0

    agree = valid & ref_on & other_on
    gmia_only = valid & ref_on & other_off
    other_only = valid & ref_off & other_on
    return agree, gmia_only, other_only, valid

def calculate_percentages(agree, gmia_only, other_only):
    """
    Express each category as a percentage of all classified pixels.

    The denominator is the combined pixel count of the three categories
    (pixels where both datasets are absent are not counted). Returns
    ``(agree_pct, gmia_only_pct, other_only_pct)``, or three zeros when no
    pixel falls into any category.
    """
    counts = [np.sum(category) for category in (agree, gmia_only, other_only)]
    total = counts[0] + counts[1] + counts[2]
    if total == 0:
        return 0.0, 0.0, 0.0
    agree_pct, gmia_only_pct, other_only_pct = (c / total * 100.0 for c in counts)
    return agree_pct, gmia_only_pct, other_only_pct

# =========================
# Main
# =========================

def main():
    """
    Compare GMIA-NEXT (reference) with another irrigation map and plot it.

    Uses the module-level configuration (``ref_path``, ``other_path``,
    ``continents_shp``, ``base_dir``, the three category colours and
    ``title_font``):

    1. read and clean the reference raster,
    2. reproject the comparison raster onto the reference grid,
    3. classify pixels (agreement / reference-only / comparison-only),
    4. print the category shares, draw a map + bar chart and save a PNG.
    """
    from matplotlib.colors import to_rgb

    # Display name of the comparison dataset, derived from its file name so
    # prints and labels follow ``other_path`` instead of being hard-coded.
    other_label = os.path.splitext(os.path.basename(other_path))[0]

    # ---- Load reference GMIA-NEXT ----
    print("Loading reference:", ref_path)
    ref_data, ref_transform, ref_crs, ref_nodata, ref_meta = read_raster(ref_path)
    ref_data = clean_reference(ref_data, ref_nodata)

    # Continents projected to the reference CRS for the boundary overlay
    continents = gpd.read_file(continents_shp).to_crs(ref_crs)

    # ---- Reproject comparison raster to the reference grid ----
    print(f"Loading and reprojecting {other_label}:", other_path)
    other_data = reproject_to_reference(other_path, ref_meta)

    # ---- Compute masks and percentages ----
    agree, gmia_only, other_only, _valid = compute_masks(ref_data, other_data)
    agree_pct, gmia_only_pct, other_only_pct = calculate_percentages(
        agree, gmia_only, other_only
    )

    print(f"\nGlobal comparison GMIA-NEXT vs {other_label}")
    print(f"Agreement (both ≥1): {agree_pct:.2f}%")
    print(f"GMIA-NEXT only (≥1 vs 0): {gmia_only_pct:.2f}%")
    print(f"{other_label} only (0 vs ≥1): {other_only_pct:.2f}%")

    # ---- Figure: map + bar chart ----
    fig, (ax_map, ax_bar) = plt.subplots(
        1, 2,
        figsize=(18, 6),
        dpi=100,
        gridspec_kw={'width_ratios': [3, 1]},
    )

    # RGB composite built from the same named colours as the bar chart and
    # legend. Pixels in no category (including NaN reference = no coverage)
    # keep the white background. float32 halves memory on a global grid.
    h, w = ref_data.shape
    combined_rgb = np.ones((h, w, 3), dtype=np.float32)
    combined_rgb[agree] = to_rgb(dark_blue)
    combined_rgb[gmia_only] = to_rgb(orange)
    combined_rgb[other_only] = to_rgb(dark_red)

    extent = get_extent(ref_transform, w, h)
    ax_map.imshow(combined_rgb, extent=extent)
    continents.boundary.plot(ax=ax_map, linewidth=0.5, edgecolor='black')

    ax_map.set_title(f"GMIA-NEXT vs {other_label}", fontsize=18, fontname=title_font)
    ax_map.set_xticks([])
    ax_map.set_yticks([])
    ax_map.set_frame_on(False)

    # ---- Bar chart ----
    labels = ['Agreement', 'GMIA-NEXT only', f'{other_label} only']
    percentages = [agree_pct, gmia_only_pct, other_only_pct]
    colors = [dark_blue, orange, dark_red]

    ax_bar.bar(labels, percentages, color=colors)
    ax_bar.set_ylabel('Percentage (%)', fontsize=14, fontname=title_font)
    ax_bar.set_ylim(0, 100)

    # Pin tick positions before relabelling: set_*ticklabels on an automatic
    # locator warns, and the labels can drift if the limits change later.
    ax_bar.set_xticks(range(len(labels)))
    ax_bar.set_xticklabels(labels, rotation=20, ha='right',
                           fontsize=12, fontname=title_font)
    yticks = ax_bar.get_yticks()
    ax_bar.set_yticks(yticks)
    ax_bar.set_yticklabels([f'{int(t)}%' for t in yticks],
                           fontsize=12, fontname=title_font)
    ax_bar.grid(axis='y', linestyle=':', linewidth=0.5, alpha=0.7)

    # ---- Legend (below figure) ----
    legend_elements = [
        mpatches.Patch(color=dark_blue, label='Agreement (both ≥1)'),
        mpatches.Patch(color=orange, label='GMIA-NEXT only (≥1 vs 0)'),
        mpatches.Patch(color=dark_red, label=f'{other_label} only (0 vs ≥1)'),
    ]
    fig.legend(
        handles=legend_elements,
        loc='lower center',
        ncol=3,
        fontsize=14,
        frameon=False,
        bbox_to_anchor=(0.5, 0.02),
    )

    plt.tight_layout(rect=[0, 0.05, 1, 1])

    # Save next to the input rasters, then display.
    out_png = os.path.join(base_dir, "GMIA_NEXT_vs_AEI_2020.png")
    plt.savefig(out_png, bbox_inches='tight', dpi=300)
    plt.show()


if __name__ == "__main__":
    main()


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import rasterio
from rasterio.warp import reproject, Resampling
import geopandas as gpd
import matplotlib.patches as mpatches
from matplotlib import font_manager

# =========================
# Paths and dataset names
# =========================

base_dir = r"G:\My Drive\Global_Irr_Maps\Global Maps"

# Reference dataset first, then the dataset compared against it; both are
# looked up inside base_dir.
ref_name = "GMIA-NEXT.tif"
other_name = "AEI_2020.tif"
ref_path = os.path.join(base_dir, ref_name)
other_path = os.path.join(base_dir, other_name)

# Continents shapefile, drawn as a boundary overlay on the map
continents_shp = r"E:\Recovered_15_57_11\My Passport(F)\COntinents onlt\Wiothout Rusia Bond\Continenets.shp"

# =========================
# Colors and fonts
# =========================

# Colorblind-friendly colors, one per comparison category
dark_blue = '#0072B2'  # agreement
orange = '#E69F00'     # reference (GMIA-NEXT) only
dark_red = '#8B0000'   # comparison (AEI_2020) only

# Prefer 'Tw Cen MT' for titles/labels; fall back when matplotlib lacks it.
if any(entry.name == 'Tw Cen MT' for entry in font_manager.fontManager.ttflist):
    title_font = 'Tw Cen MT'
else:
    print('The "Tw Cen MT" font is not available. Using default sans-serif.')
    title_font = 'sans-serif'

# =========================
# Helper functions
# =========================

def read_raster(file_path):
    """
    Open a raster and return its first band together with georeferencing.

    Returns
    -------
    tuple
        ``(data, transform, crs, nodata, meta)`` where ``data`` is band 1
        converted to float and ``meta`` is a copy of the dataset metadata.
    """
    with rasterio.open(file_path) as dataset:
        band = dataset.read(1).astype(float)
        profile = dataset.meta.copy()
        georef = (dataset.transform, dataset.crs, dataset.nodata)
    return (band, *georef, profile)

def clean_reference(data, nodata):
    """
    Clean the reference raster (GMIA-NEXT):
      - nodata -> NaN (outside reference coverage)
      - extreme values (< -1e6 or > 1e10) -> NaN
      - 0 is kept: it is a valid absence within coverage

    Floating-point input is cleaned in place and returned. Any other dtype
    is first converted to float64, because NaN cannot be stored in an
    integer array (the assignment would raise ValueError).
    """
    data = np.asarray(data)
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(float)
    if nodata is not None:
        data[data == nodata] = np.nan
    data[(data < -1e6) | (data > 1e10)] = np.nan
    return data

def reproject_to_reference(other_path, ref_meta):
    """
    Reproject another raster (e.g. AEI_2020) onto the reference grid.

    Band 1 of ``other_path`` is nearest-neighbour resampled onto the grid
    described by ``ref_meta`` (``height``, ``width``, ``transform``, ``crs``).
    Destination pixels that receive no valid source value -- outside the
    source footprint, or covered only by the source's declared nodata --
    stay 0, i.e. they are treated as "no irrigation".

    Parameters
    ----------
    other_path : str
        Path to the raster to reproject.
    ref_meta : dict
        Metadata of the reference raster (as returned in ``src.meta``).

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(ref_meta['height'], ref_meta['width'])``.
    """
    with rasterio.open(other_path) as src:
        dst_data = np.zeros((ref_meta['height'], ref_meta['width']), dtype=float)
        reproject(
            source=src.read(1).astype(float),
            destination=dst_data,
            src_transform=src.transform,
            src_crs=src.crs,
            # Without this, the source's nodata sentinel (e.g. 255 / 65535)
            # is resampled as if it were a real irrigated-area value.
            src_nodata=src.nodata,
            dst_transform=ref_meta['transform'],
            dst_crs=ref_meta['crs'],
            resampling=Resampling.nearest,
            dst_nodata=0,
        )
    # Clean remaining implausible values: negative or huge -> 0
    dst_data[(dst_data < 0) | (dst_data > 1e10)] = 0
    return dst_data

def get_extent(transform, width, height):
    """
    Compute the ``imshow`` extent of a north-up raster.

    ``transform`` is indexed like an affine tuple ``(a, b, c, d, e, f, ...)``:
    ``a`` = pixel width, ``c`` = x origin, ``e`` = pixel height (negative for
    north-up rasters), ``f`` = y origin. Rotation terms are ignored.

    Returns ``(left, right, bottom, top)``.
    """
    x_origin, y_origin = transform[2], transform[5]
    x_step, y_step = transform[0], transform[4]
    return (
        x_origin,
        x_origin + x_step * width,
        y_origin + y_step * height,
        y_origin,
    )

def compute_masks(ref_data, other_data):
    """
    Classify pixels into agreement / disagreement categories.

    A pixel is *valid* when ``ref_data`` is not NaN (inside the reference
    coverage). Among valid pixels a dataset is "present" when its value is
    >= 1 and "absent" when it is exactly 0; values strictly between 0 and 1
    (and NaN in ``other_data``) belong to neither group.

    Returns
    -------
    tuple of bool arrays
        ``(agree, gmia_only, other_only, valid)`` with
        agree      = ref present & other present,
        gmia_only  = ref present & other absent,
        other_only = ref absent  & other present.
    """
    valid = np.logical_not(np.isnan(ref_data))

    def _split(arr):
        # (present, absent) for one dataset, restricted to valid pixels.
        return (arr >= 1) & valid, (arr == 0) & valid

    ref_yes, ref_no = _split(ref_data)
    other_yes, other_no = _split(other_data)

    return (
        ref_yes & other_yes,
        ref_yes & other_no,
        ref_no & other_yes,
        valid,
    )

def calculate_percentages(agree, gmia_only, other_only):
    """
    Express each category as a percentage of all classified pixels.

    The denominator is the combined pixel count of the three categories
    (pixels where both datasets are absent are not counted). Returns
    ``(agree_pct, gmia_only_pct, other_only_pct)``, or three zeros when no
    pixel falls into any category.
    """
    counts = [np.sum(category) for category in (agree, gmia_only, other_only)]
    total = counts[0] + counts[1] + counts[2]
    if total == 0:
        return 0.0, 0.0, 0.0
    agree_pct, gmia_only_pct, other_only_pct = (c / total * 100.0 for c in counts)
    return agree_pct, gmia_only_pct, other_only_pct

# =========================
# Main
# =========================

def main():
    """
    Compare GMIA-NEXT (reference) with another irrigation map and plot it.

    Uses the module-level configuration (``ref_path``, ``other_path``,
    ``continents_shp``, ``base_dir``, the three category colours and
    ``title_font``):

    1. read and clean the reference raster,
    2. reproject the comparison raster onto the reference grid,
    3. classify pixels (agreement / reference-only / comparison-only),
    4. print the category shares, draw a map + bar chart and save a PNG.
    """
    from matplotlib.colors import to_rgb

    # Display name of the comparison dataset, derived from its file name so
    # prints and labels follow ``other_path`` instead of being hard-coded.
    other_label = os.path.splitext(os.path.basename(other_path))[0]

    # ---- Load reference GMIA-NEXT ----
    print("Loading reference:", ref_path)
    ref_data, ref_transform, ref_crs, ref_nodata, ref_meta = read_raster(ref_path)
    ref_data = clean_reference(ref_data, ref_nodata)

    # Continents projected to the reference CRS for the boundary overlay
    continents = gpd.read_file(continents_shp).to_crs(ref_crs)

    # ---- Reproject comparison raster to the reference grid ----
    print(f"Loading and reprojecting {other_label}:", other_path)
    other_data = reproject_to_reference(other_path, ref_meta)

    # ---- Compute masks and percentages ----
    agree, gmia_only, other_only, _valid = compute_masks(ref_data, other_data)
    agree_pct, gmia_only_pct, other_only_pct = calculate_percentages(
        agree, gmia_only, other_only
    )

    print(f"\nGlobal comparison GMIA-NEXT vs {other_label}")
    print(f"Agreement (both ≥1): {agree_pct:.2f}%")
    print(f"GMIA-NEXT only (≥1 vs 0): {gmia_only_pct:.2f}%")
    print(f"{other_label} only (0 vs ≥1): {other_only_pct:.2f}%")

    # ---- Figure: map + bar chart ----
    fig, (ax_map, ax_bar) = plt.subplots(
        1, 2,
        figsize=(18, 6),
        dpi=100,
        gridspec_kw={'width_ratios': [3, 1]},
    )

    # RGB composite built from the same named colours as the bar chart and
    # legend. Pixels in no category (including NaN reference = no coverage)
    # keep the white background. float32 halves memory on a global grid.
    h, w = ref_data.shape
    combined_rgb = np.ones((h, w, 3), dtype=np.float32)
    combined_rgb[agree] = to_rgb(dark_blue)
    combined_rgb[gmia_only] = to_rgb(orange)
    combined_rgb[other_only] = to_rgb(dark_red)

    extent = get_extent(ref_transform, w, h)
    ax_map.imshow(combined_rgb, extent=extent)
    continents.boundary.plot(ax=ax_map, linewidth=0.5, edgecolor='black')

    ax_map.set_title(f"GMIA-NEXT vs {other_label}", fontsize=18, fontname=title_font)
    ax_map.set_xticks([])
    ax_map.set_yticks([])
    ax_map.set_frame_on(False)

    # ---- Bar chart ----
    labels = ['Agreement', 'GMIA-NEXT only', f'{other_label} only']
    percentages = [agree_pct, gmia_only_pct, other_only_pct]
    colors = [dark_blue, orange, dark_red]

    ax_bar.bar(labels, percentages, color=colors)
    ax_bar.set_ylabel('Percentage (%)', fontsize=14, fontname=title_font)
    ax_bar.set_ylim(0, 100)

    # Pin tick positions before relabelling: set_*ticklabels on an automatic
    # locator warns, and the labels can drift if the limits change later.
    ax_bar.set_xticks(range(len(labels)))
    ax_bar.set_xticklabels(labels, rotation=20, ha='right',
                           fontsize=12, fontname=title_font)
    yticks = ax_bar.get_yticks()
    ax_bar.set_yticks(yticks)
    ax_bar.set_yticklabels([f'{int(t)}%' for t in yticks],
                           fontsize=12, fontname=title_font)
    ax_bar.grid(axis='y', linestyle=':', linewidth=0.5, alpha=0.7)

    # ---- Legend (below figure) ----
    legend_elements = [
        mpatches.Patch(color=dark_blue, label='Agreement (both ≥1)'),
        mpatches.Patch(color=orange, label='GMIA-NEXT only (≥1 vs 0)'),
        mpatches.Patch(color=dark_red, label=f'{other_label} only (0 vs ≥1)'),
    ]
    fig.legend(
        handles=legend_elements,
        loc='lower center',
        ncol=3,
        fontsize=14,
        frameon=False,
        bbox_to_anchor=(0.5, 0.02),
    )

    plt.tight_layout(rect=[0, 0.05, 1, 1])

    # Save next to the input rasters, then display.
    out_png = os.path.join(base_dir, "GMIA_NEXT_vs_AEI_2020.png")
    plt.savefig(out_png, bbox_inches='tight', dpi=300)
    plt.show()


if __name__ == "__main__":
    main()


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import rasterio
from rasterio.warp import reproject, Resampling
import geopandas as gpd
import matplotlib.patches as mpatches
from matplotlib import font_manager

# =========================
# Paths and dataset names
# =========================

base_dir = r"G:\My Drive\Global_Irr_Maps\Global Maps"

# Reference dataset first, then the dataset compared against it; both are
# looked up inside base_dir.
ref_name = "GMIA-NEXT.tif"
other_name = "Mejer_2020.tif"
ref_path = os.path.join(base_dir, ref_name)
other_path = os.path.join(base_dir, other_name)

# Continents shapefile, drawn as a boundary overlay on the map
continents_shp = r"E:\Recovered_15_57_11\My Passport(F)\COntinents onlt\Wiothout Rusia Bond\Continenets.shp"

# =========================
# Colors and fonts
# =========================

# Colorblind-friendly colors, one per comparison category
dark_blue = '#0072B2'  # agreement
orange = '#E69F00'     # reference (GMIA-NEXT) only
dark_red = '#8B0000'   # comparison (Mejer) only

# Prefer 'Tw Cen MT' for titles/labels; fall back when matplotlib lacks it.
if any(entry.name == 'Tw Cen MT' for entry in font_manager.fontManager.ttflist):
    title_font = 'Tw Cen MT'
else:
    print('The "Tw Cen MT" font is not available. Using default sans-serif.')
    title_font = 'sans-serif'

# =========================
# Helper functions
# =========================

def read_raster(file_path):
    """
    Open a raster and return its first band together with georeferencing.

    Returns
    -------
    tuple
        ``(data, transform, crs, nodata, meta)`` where ``data`` is band 1
        converted to float and ``meta`` is a copy of the dataset metadata.
    """
    with rasterio.open(file_path) as dataset:
        band = dataset.read(1).astype(float)
        profile = dataset.meta.copy()
        georef = (dataset.transform, dataset.crs, dataset.nodata)
    return (band, *georef, profile)

def clean_reference(data, nodata):
    """
    Clean the reference raster (GMIA-NEXT):
      - nodata -> NaN (outside reference coverage)
      - extreme values (< -1e6 or > 1e10) -> NaN
      - 0 is kept: it is a valid absence within coverage

    Floating-point input is cleaned in place and returned. Any other dtype
    is first converted to float64, because NaN cannot be stored in an
    integer array (the assignment would raise ValueError).
    """
    data = np.asarray(data)
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(float)
    if nodata is not None:
        data[data == nodata] = np.nan
    data[(data < -1e6) | (data > 1e10)] = np.nan
    return data

def reproject_to_reference(other_path, ref_meta):
    """
    Reproject another raster (e.g. Mejer) onto the reference grid.

    Band 1 of ``other_path`` is nearest-neighbour resampled onto the grid
    described by ``ref_meta`` (``height``, ``width``, ``transform``, ``crs``).
    Destination pixels that receive no valid source value -- outside the
    source footprint, or covered only by the source's declared nodata --
    stay 0, i.e. they are treated as "no irrigation".

    Parameters
    ----------
    other_path : str
        Path to the raster to reproject.
    ref_meta : dict
        Metadata of the reference raster (as returned in ``src.meta``).

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(ref_meta['height'], ref_meta['width'])``.
    """
    with rasterio.open(other_path) as src:
        dst_data = np.zeros((ref_meta['height'], ref_meta['width']), dtype=float)
        reproject(
            source=src.read(1).astype(float),
            destination=dst_data,
            src_transform=src.transform,
            src_crs=src.crs,
            # Without this, the source's nodata sentinel (e.g. 255 / 65535)
            # is resampled as if it were a real irrigated-area value.
            src_nodata=src.nodata,
            dst_transform=ref_meta['transform'],
            dst_crs=ref_meta['crs'],
            resampling=Resampling.nearest,
            dst_nodata=0,
        )
    # Clean remaining implausible values: negative or huge -> 0
    dst_data[(dst_data < 0) | (dst_data > 1e10)] = 0
    return dst_data

def get_extent(transform, width, height):
    """
    Compute the ``imshow`` extent of a north-up raster.

    ``transform`` is indexed like an affine tuple ``(a, b, c, d, e, f, ...)``:
    ``a`` = pixel width, ``c`` = x origin, ``e`` = pixel height (negative for
    north-up rasters), ``f`` = y origin. Rotation terms are ignored.

    Returns ``(left, right, bottom, top)``.
    """
    x_origin, y_origin = transform[2], transform[5]
    x_step, y_step = transform[0], transform[4]
    return (
        x_origin,
        x_origin + x_step * width,
        y_origin + y_step * height,
        y_origin,
    )

def compute_masks(ref_data, other_data):
    """
    Classify pixels into agreement / disagreement categories.

    A pixel is *valid* when ``ref_data`` is not NaN (inside the reference
    coverage). Among valid pixels a dataset is "present" when its value is
    >= 1 and "absent" when it is exactly 0; values strictly between 0 and 1
    (and NaN in ``other_data``) belong to neither group.

    Returns
    -------
    tuple of bool arrays
        ``(agree, gmia_only, other_only, valid)`` with
        agree      = ref present & other present,
        gmia_only  = ref present & other absent,
        other_only = ref absent  & other present.
    """
    valid = np.logical_not(np.isnan(ref_data))

    def _split(arr):
        # (present, absent) for one dataset, restricted to valid pixels.
        return (arr >= 1) & valid, (arr == 0) & valid

    ref_yes, ref_no = _split(ref_data)
    other_yes, other_no = _split(other_data)

    return (
        ref_yes & other_yes,
        ref_yes & other_no,
        ref_no & other_yes,
        valid,
    )

def calculate_percentages(agree, gmia_only, other_only):
    """
    Express each category as a percentage of all classified pixels.

    The denominator is the combined pixel count of the three categories
    (pixels where both datasets are absent are not counted). Returns
    ``(agree_pct, gmia_only_pct, other_only_pct)``, or three zeros when no
    pixel falls into any category.
    """
    counts = [np.sum(category) for category in (agree, gmia_only, other_only)]
    total = counts[0] + counts[1] + counts[2]
    if total == 0:
        return 0.0, 0.0, 0.0
    agree_pct, gmia_only_pct, other_only_pct = (c / total * 100.0 for c in counts)
    return agree_pct, gmia_only_pct, other_only_pct

# =========================
# Main
# =========================

def main():
    """
    Compare GMIA-NEXT (reference) with another irrigation map and plot it.

    Uses the module-level configuration (``ref_path``, ``other_path``,
    ``continents_shp``, ``base_dir``, the three category colours and
    ``title_font``):

    1. read and clean the reference raster,
    2. reproject the comparison raster onto the reference grid,
    3. classify pixels (agreement / reference-only / comparison-only),
    4. print the category shares, draw a map + bar chart and save a PNG.

    The comparison dataset's name in prints, bar labels and legend is taken
    from ``other_path``. Previously the legend and prints said "AEI_2020"
    although this cell compares a different dataset.
    """
    from matplotlib.colors import to_rgb

    # Display name of the comparison dataset, derived from its file name.
    other_label = os.path.splitext(os.path.basename(other_path))[0]

    # ---- Load reference GMIA-NEXT ----
    print("Loading reference:", ref_path)
    ref_data, ref_transform, ref_crs, ref_nodata, ref_meta = read_raster(ref_path)
    ref_data = clean_reference(ref_data, ref_nodata)

    # Continents projected to the reference CRS for the boundary overlay
    continents = gpd.read_file(continents_shp).to_crs(ref_crs)

    # ---- Reproject comparison raster to the reference grid ----
    print(f"Loading and reprojecting {other_label}:", other_path)
    other_data = reproject_to_reference(other_path, ref_meta)

    # ---- Compute masks and percentages ----
    agree, gmia_only, other_only, _valid = compute_masks(ref_data, other_data)
    agree_pct, gmia_only_pct, other_only_pct = calculate_percentages(
        agree, gmia_only, other_only
    )

    print(f"\nGlobal comparison GMIA-NEXT vs {other_label}")
    print(f"Agreement (both ≥1): {agree_pct:.2f}%")
    print(f"GMIA-NEXT only (≥1 vs 0): {gmia_only_pct:.2f}%")
    print(f"{other_label} only (0 vs ≥1): {other_only_pct:.2f}%")

    # ---- Figure: map + bar chart ----
    fig, (ax_map, ax_bar) = plt.subplots(
        1, 2,
        figsize=(18, 6),
        dpi=100,
        gridspec_kw={'width_ratios': [3, 1]},
    )

    # RGB composite built from the same named colours as the bar chart and
    # legend. Pixels in no category (including NaN reference = no coverage)
    # keep the white background. float32 halves memory on a global grid.
    h, w = ref_data.shape
    combined_rgb = np.ones((h, w, 3), dtype=np.float32)
    combined_rgb[agree] = to_rgb(dark_blue)
    combined_rgb[gmia_only] = to_rgb(orange)
    combined_rgb[other_only] = to_rgb(dark_red)

    extent = get_extent(ref_transform, w, h)
    ax_map.imshow(combined_rgb, extent=extent)
    continents.boundary.plot(ax=ax_map, linewidth=0.5, edgecolor='black')

    # Map title intentionally omitted for this figure.
    ax_map.set_xticks([])
    ax_map.set_yticks([])
    ax_map.set_frame_on(False)

    # ---- Bar chart ----
    labels = ['Agreement', 'GMIA-NEXT only', f'{other_label} only']
    percentages = [agree_pct, gmia_only_pct, other_only_pct]
    colors = [dark_blue, orange, dark_red]

    ax_bar.bar(labels, percentages, color=colors)
    ax_bar.set_ylabel('Percentage (%)', fontsize=14, fontname=title_font)
    ax_bar.set_ylim(0, 100)

    # Pin tick positions before relabelling: set_*ticklabels on an automatic
    # locator warns, and the labels can drift if the limits change later.
    ax_bar.set_xticks(range(len(labels)))
    ax_bar.set_xticklabels(labels, rotation=20, ha='right',
                           fontsize=12, fontname=title_font)
    yticks = ax_bar.get_yticks()
    ax_bar.set_yticks(yticks)
    ax_bar.set_yticklabels([f'{int(t)}%' for t in yticks],
                           fontsize=12, fontname=title_font)
    ax_bar.grid(axis='y', linestyle=':', linewidth=0.5, alpha=0.7)

    # ---- Legend (below figure) ----
    legend_elements = [
        mpatches.Patch(color=dark_blue, label='Agreement (both ≥1)'),
        mpatches.Patch(color=orange, label='GMIA-NEXT only (≥1 vs 0)'),
        mpatches.Patch(color=dark_red, label=f'{other_label} only (0 vs ≥1)'),
    ]
    fig.legend(
        handles=legend_elements,
        loc='lower center',
        ncol=3,
        fontsize=14,
        frameon=False,
        bbox_to_anchor=(0.5, 0.02),
    )

    plt.tight_layout(rect=[0, 0.05, 1, 1])

    # Save next to the input rasters, then display.
    # NOTE(review): "Mietr" looks like a typo for the dataset name; kept
    # unchanged so existing outputs are not orphaned -- confirm and rename.
    out_png = os.path.join(base_dir, "GMIA_NEXT_vs_Mietr.png")
    plt.savefig(out_png, bbox_inches='tight', dpi=300)
    plt.show()


if __name__ == "__main__":
    main()


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import rasterio
from rasterio.warp import reproject, Resampling
from scipy.stats import linregress, pearsonr

# ------------------------------------------------------------
# Global Matplotlib style – force Tw Cen MT everywhere
# ------------------------------------------------------------
plt.rcParams.update({
    'font.family': 'Tw Cen MT',
    'font.sans-serif': ['Tw Cen MT'],
    'axes.unicode_minus': False,  # plain "-" instead of the Unicode minus
})
font_size = 24  # text size for ticks, titles, colorbars and annotations
nan_value = 0   # raster value treated as missing (sentinel shared with other scripts)

# ------------------------------------------------------------
# Helper functions
# ------------------------------------------------------------
def clean_raster_data(raster_data, nan_value=None):
    """
    Return a float32 copy of ``raster_data`` with invalid cells set to NaN.

    A cell is invalid when it equals ``nan_value`` (if one is given), is not
    finite (+/-inf or NaN), or is negative. The input array is not modified.
    """
    cleaned = raster_data.astype(np.float32)

    invalid = ~np.isfinite(cleaned) | (cleaned < 0)
    if nan_value is not None:
        invalid |= cleaned == nan_value

    cleaned[invalid] = np.nan
    return cleaned


def reproject_to_match(src_path, ref_ds):
    """
    Resample band 1 of ``src_path`` onto the grid of the open reference
    dataset ``ref_ds`` (GMIA-NEXT) using bilinear resampling.

    Returns a float32 array of shape ``(ref_ds.height, ref_ds.width)``.
    Destination cells not covered by valid source data (including the
    source's declared nodata, when it has one) are NaN.

    NOTE(review): bilinear interpolation of a per-pixel area (ha) does not
    conserve totals when pixel sizes differ -- confirm this is intended.
    """
    with rasterio.open(src_path) as src:
        # Pre-filled with the destination nodata (NaN); rasterio would also
        # initialise it that way, this just makes the intent explicit.
        warped = np.full((ref_ds.height, ref_ds.width), np.nan, dtype=np.float32)
        reproject(
            source=src.read(1).astype(np.float32),
            destination=warped,
            src_transform=src.transform,
            src_crs=src.crs,
            dst_transform=ref_ds.transform,
            dst_crs=ref_ds.crs,
            resampling=Resampling.bilinear,
            src_nodata=src.nodata,
            dst_nodata=np.nan,
        )
    return warped


def plot_scatter_and_regression(ax, x, y, ds_name, n_valid, common_min, common_max):
    """
    Draw one density panel comparing a dataset against GMIA-NEXT.

    Same style as the regional MIRCA hexbin plots: inferno hexbin with
    log-scaled counts and its own colorbar, a black 1:1 line, a red
    least-squares regression line of ``y`` on ``x`` and an ``r`` / ``n``
    annotation in the upper-left corner.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Panel to draw into.
    x : numpy.ndarray
        Comparison dataset values (1-D, at least two points).
    y : numpy.ndarray
        GMIA-NEXT values (our dataset), same length as ``x``.
    ds_name : str
        Panel title.
    n_valid : int
        Number of points, printed in the annotation.
    common_min, common_max : float
        Shared limits for both axes and for the hexbin extent.

    Notes
    -----
    ``scipy.stats.linregress`` raises ``ValueError`` when every ``x`` is
    identical, so the regression line is skipped in that case. Otherwise one
    degenerate dataset would abort the whole multi-panel figure. If the
    limits collapse to a single value they are widened by 1 so the hexbin
    grid and the axis limits stay non-degenerate.
    """
    if common_max <= common_min:
        common_max = common_min + 1.0

    # Pearson correlation (scipy returns nan, with a warning, for constant input)
    r, _ = pearsonr(x, y)

    # Hexbin with inferno + log(count), like the regional code
    hb = ax.hexbin(
        x, y,
        gridsize=30,
        cmap='inferno',
        bins='log',
        extent=(common_min, common_max, common_min, common_max),
    )

    # Per-subplot colorbar
    cbar = ax.figure.colorbar(hb, ax=ax)
    cbar.ax.tick_params(labelsize=font_size)
    cbar.set_label('Count', size=font_size)

    # 1:1 line
    ax.plot(
        [common_min, common_max],
        [common_min, common_max],
        color='black', linestyle='-', linewidth=2,
    )

    # Regression line (only defined when x has some spread)
    if np.ptp(x) > 0:
        slope, intercept = linregress(x, y)[:2]
        ax.plot(
            [common_min, common_max],
            [intercept + slope * common_min,
             intercept + slope * common_max],
            color='red', linewidth=2,
        )

    # Annotation box, 5% in from the left and 10% down from the top
    span = common_max - common_min
    ax.text(
        common_min + span * 0.05,
        common_max - span * 0.1,
        f'r = {r:.2f}\nn = {n_valid:,}',
        color='black', ha='left', va='top',
        backgroundcolor='white', fontsize=font_size,
    )

    ax.set_title(ds_name, fontsize=font_size)
    ax.set_xlim(common_min, common_max)
    ax.set_ylim(common_min, common_max)

# ------------------------------------------------------------
# Paths & datasets
# ------------------------------------------------------------
base_dir = r"G:\My Drive\Global_Irr_Maps\Global Maps"

ref_name = "GMIA-NEXT.tif"
ref_path = os.path.join(base_dir, ref_name)

# three comparison rasters: panel title -> file name inside base_dir
datasets = {
    "AEI_2020": "AEI_2020.tif",
    "Meier": "Meier_2018.tif",
    "Nagaraj": "Nagaraj_2018.tif",
}


def _read_on_ref_grid(comp_path, ref_ds):
    """Return band 1 of ``comp_path`` as float32 on ``ref_ds``'s grid, nodata -> NaN."""
    with rasterio.open(comp_path) as ds:
        same_grid = (ds.crs == ref_ds.crs
                     and ds.transform == ref_ds.transform
                     and ds.width == ref_ds.width
                     and ds.height == ref_ds.height)
        if not same_grid:
            # reproject_to_match already maps the source nodata to NaN
            return reproject_to_match(comp_path, ref_ds)
        band = ds.read(1).astype(np.float32)
        # Mirror the reprojection path: a declared nodata sentinel (e.g.
        # 255 / 65535) must not be plotted as an irrigated-area value.
        if ds.nodata is not None:
            band[band == ds.nodata] = np.nan
        return band


# ------------------------------------------------------------
# Load GMIA-NEXT and build mask
# ------------------------------------------------------------
with rasterio.open(ref_path) as ref_ds:
    ref_raw = ref_ds.read(1)
    ref_data = clean_raster_data(ref_raw, nan_value=nan_value)

    # GMIA footprint: finite and > 0 (excludes Russia / non-covered areas,
    # and also GMIA cells with zero irrigated area)
    gmia_mask = np.isfinite(ref_data) & (ref_data > 0)

    # 1 × 3 figure
    fig, axs = plt.subplots(1, 3, figsize=(27, 9))
    axs = axs.ravel()

    for ax, (ds_name, fname) in zip(axs, datasets.items()):
        comp_path = os.path.join(base_dir, fname)
        if not os.path.exists(comp_path):
            print(f"Warning: {comp_path} not found, skipping.")
            continue

        other_raw = _read_on_ref_grid(comp_path, ref_ds)
        other_data = clean_raster_data(other_raw, nan_value=nan_value)

        # ----------------------------------------------------
        # x = comparison dataset, y = GMIA-NEXT, restricted to the GMIA
        # footprint and to cells where the comparison value is finite.
        # Boolean indexing avoids two full-size copies of the global grid.
        # The local name ``keep`` (not ``mask``) avoids shadowing the
        # rasterio ``mask`` function imported earlier in the notebook.
        # ----------------------------------------------------
        keep = gmia_mask & np.isfinite(other_data)
        x = other_data[keep]
        y = ref_data[keep]

        n_valid = x.size
        print(f'{ds_name}: Number of valid data points = {n_valid:,}')

        if n_valid < 2:
            ax.text(
                0.5, 0.5, 'Not enough data',
                transform=ax.transAxes,
                ha='center', va='center',
                fontsize=font_size,
            )
            ax.set_title(ds_name, fontsize=font_size)
            continue

        # common axis limits for this pair (x and y are already finite)
        common_min = min(x.min(), y.min())
        common_max = max(x.max(), y.max())

        plot_scatter_and_regression(
            ax, x, y, ds_name,
            n_valid, common_min, common_max,
        )

    # Shared axis labels (our dataset = y-axis)
    fig.text(
        0.5, 0.02,
        'Comparison dataset irrigated area (ha)',
        ha='center', fontsize=30,
    )
    fig.text(
        0.02, 0.5,
        'GMIA-NEXT irrigated area (ha)',
        va='center', rotation='vertical', fontsize=30,
    )

    plt.tight_layout(rect=[0.06, 0.06, 1, 0.96])

    out_path = os.path.join(base_dir, "GMIA_NEXT_vs_AEI_Meier_Nagaraj_inferno_GMIA_mask.png")
    plt.savefig(out_path, bbox_inches='tight', dpi=300)
    plt.show()


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import rasterio
from rasterio.warp import reproject, Resampling
import geopandas as gpd
import matplotlib.patches as mpatches
from matplotlib import font_manager

# =========================
# Paths and dataset names
# =========================

base_dir = r"G:\My Drive\Global_Irr_Maps\Global Maps"

# Reference dataset first, then the dataset compared against it; both are
# looked up inside base_dir.
ref_name = "GLAD.tif"
other_name = "AEI_2020.tif"
ref_path = os.path.join(base_dir, ref_name)
other_path = os.path.join(base_dir, other_name)

# Continents shapefile, drawn as a boundary overlay on the map
continents_shp = r"E:\Recovered_15_57_11\My Passport(F)\COntinents onlt\Wiothout Rusia Bond\Continenets.shp"

# =========================
# Colors and fonts
# =========================

# Colorblind-friendly colors, one per comparison category
dark_blue = '#0072B2'  # agreement
orange = '#E69F00'     # reference (GLAD) only
dark_red = '#8B0000'   # comparison (AEI_2020) only

# Prefer 'Tw Cen MT' for titles/labels; fall back when matplotlib lacks it.
if any(entry.name == 'Tw Cen MT' for entry in font_manager.fontManager.ttflist):
    title_font = 'Tw Cen MT'
else:
    print('The "Tw Cen MT" font is not available. Using default sans-serif.')
    title_font = 'sans-serif'

# =========================
# Helper functions
# =========================

def read_raster(file_path):
    """
    Open a raster and return its first band together with georeferencing.

    Returns
    -------
    tuple
        ``(data, transform, crs, nodata, meta)`` where ``data`` is band 1
        converted to float and ``meta`` is a copy of the dataset metadata.
    """
    with rasterio.open(file_path) as dataset:
        band = dataset.read(1).astype(float)
        profile = dataset.meta.copy()
        georef = (dataset.transform, dataset.crs, dataset.nodata)
    return (band, *georef, profile)

def clean_reference(data, nodata):
    """
    Clean the reference raster (GLAD in this comparison):
      - nodata -> NaN (outside reference coverage)
      - extreme values (< -1e6 or > 1e10) -> NaN
      - 0 is kept: it is a valid absence within coverage

    Floating-point input is cleaned in place and returned. Any other dtype
    is first converted to float64, because NaN cannot be stored in an
    integer array (the assignment would raise ValueError).
    """
    data = np.asarray(data)
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(float)
    if nodata is not None:
        data[data == nodata] = np.nan
    data[(data < -1e6) | (data > 1e10)] = np.nan
    return data

def reproject_to_reference(other_path, ref_meta):
    """
    Reproject another raster (e.g. AEI_2020) onto the reference grid.

    Band 1 of ``other_path`` is nearest-neighbour resampled onto the grid
    described by ``ref_meta`` (``height``, ``width``, ``transform``, ``crs``).
    Destination pixels that receive no valid source value -- outside the
    source footprint, or covered only by the source's declared nodata --
    stay 0, i.e. they are treated as "no irrigation".

    Parameters
    ----------
    other_path : str
        Path to the raster to reproject.
    ref_meta : dict
        Metadata of the reference raster (as returned in ``src.meta``).

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(ref_meta['height'], ref_meta['width'])``.
    """
    with rasterio.open(other_path) as src:
        dst_data = np.zeros((ref_meta['height'], ref_meta['width']), dtype=float)
        reproject(
            source=src.read(1).astype(float),
            destination=dst_data,
            src_transform=src.transform,
            src_crs=src.crs,
            # Without this, the source's nodata sentinel (e.g. 255 / 65535)
            # is resampled as if it were a real irrigated-area value.
            src_nodata=src.nodata,
            dst_transform=ref_meta['transform'],
            dst_crs=ref_meta['crs'],
            resampling=Resampling.nearest,
            dst_nodata=0,
        )
    # Clean remaining implausible values: negative or huge -> 0
    dst_data[(dst_data < 0) | (dst_data > 1e10)] = 0
    return dst_data

def get_extent(transform, width, height):
    """
    Compute the ``imshow`` extent of a north-up raster.

    ``transform`` is indexed like an affine tuple ``(a, b, c, d, e, f, ...)``:
    ``a`` = pixel width, ``c`` = x origin, ``e`` = pixel height (negative for
    north-up rasters), ``f`` = y origin. Rotation terms are ignored.

    Returns ``(left, right, bottom, top)``.
    """
    x_origin, y_origin = transform[2], transform[5]
    x_step, y_step = transform[0], transform[4]
    return (
        x_origin,
        x_origin + x_step * width,
        y_origin + y_step * height,
        y_origin,
    )

def compute_masks(ref_data, other_data):
    """
    Classify pixels into agreement / disagreement categories.

    A pixel is *valid* when ``ref_data`` is not NaN (inside the reference
    coverage). Among valid pixels a dataset is "present" when its value is
    >= 1 and "absent" when it is exactly 0; values strictly between 0 and 1
    (and NaN in ``other_data``) belong to neither group.

    Returns
    -------
    tuple of bool arrays
        ``(agree, gmia_only, other_only, valid)`` with
        agree      = ref present & other present,
        gmia_only  = ref present & other absent,
        other_only = ref absent  & other present.
    """
    valid = np.logical_not(np.isnan(ref_data))

    def _split(arr):
        # (present, absent) for one dataset, restricted to valid pixels.
        return (arr >= 1) & valid, (arr == 0) & valid

    ref_yes, ref_no = _split(ref_data)
    other_yes, other_no = _split(other_data)

    return (
        ref_yes & other_yes,
        ref_yes & other_no,
        ref_no & other_yes,
        valid,
    )

def calculate_percentages(agree, gmia_only, other_only):
    """
    Express each category as a percentage of all classified pixels.

    The denominator is the combined pixel count of the three categories
    (pixels where both datasets are absent are not counted). Returns
    ``(agree_pct, gmia_only_pct, other_only_pct)``, or three zeros when no
    pixel falls into any category.
    """
    counts = [np.sum(category) for category in (agree, gmia_only, other_only)]
    total = counts[0] + counts[1] + counts[2]
    if total == 0:
        return 0.0, 0.0, 0.0
    agree_pct, gmia_only_pct, other_only_pct = (c / total * 100.0 for c in counts)
    return agree_pct, gmia_only_pct, other_only_pct

# =========================
# Main
# =========================

def main():
    """
    Compare the reference raster with AEI_2020 and plot the result.

    Steps:
      1. Load the reference raster (``ref_path``) via ``read_raster`` and clean
         it with ``clean_reference``.
      2. Load the continent outlines (``continents_shp``) and project them to
         the reference CRS.
      3. Reproject AEI_2020 (``other_path``) onto the reference grid.
      4. Classify pixels with ``compute_masks`` and print the category
         percentages from ``calculate_percentages``.
      5. Draw an agreement map next to a bar chart of the percentages, save it
         as ``GLAD_vs_AEI_2020.png`` in ``base_dir``, then show it.

    Relies on module-level names defined elsewhere in the file: ``ref_path``,
    ``other_path``, ``continents_shp``, ``base_dir``, ``title_font``, the
    colours ``dark_blue``, ``orange``, ``dark_red``, and ``gpd``/``plt``.
    """
    # Local import: converts the shared colour constants into RGB triples so
    # the map uses exactly the same colours as the bar chart.
    from matplotlib.colors import to_rgb

    # NOTE(review): console messages call the reference "GMIA-NEXT" while the
    # plot labels and the output file call it "GLAD" -- confirm which is right.

    # ---- Load reference ----
    print("Loading reference:", ref_path)
    ref_data, ref_transform, ref_crs, ref_nodata, ref_meta = read_raster(ref_path)
    ref_data = clean_reference(ref_data, ref_nodata)

    # Load continents and project to the reference CRS so they overlay the map.
    continents = gpd.read_file(continents_shp).to_crs(ref_crs)

    # ---- Reproject AEI_2020 onto the reference grid ----
    print("Loading and reprojecting AEI_2020:", other_path)
    other_data = reproject_to_reference(other_path, ref_meta)

    # ---- Compute masks and percentages ----
    agree, gmia_only, other_only, _ = compute_masks(ref_data, other_data)
    agree_pct, gmia_only_pct, other_only_pct = calculate_percentages(
        agree, gmia_only, other_only
    )

    print("\nGlobal comparison GMIA-NEXT vs AEI_2020")
    print(f"Agreement (both ≥1): {agree_pct:.2f}%")
    print(f"GMIA-NEXT only (≥1 vs 0): {gmia_only_pct:.2f}%")
    print(f"AEI_2020 only (0 vs ≥1): {other_only_pct:.2f}%")

    # One entry per category. The same colours are used on the map and the
    # bars, so the bar chart doubles as the map legend.
    labels = ['Agreement', 'GLAD only', 'AEI_2020 only']
    percentages = [agree_pct, gmia_only_pct, other_only_pct]
    colors = [dark_blue, orange, dark_red]
    category_masks = [agree, gmia_only, other_only]

    fig, (ax_map, ax_bar) = plt.subplots(
        1, 2,
        figsize=(18, 6),
        dpi=100,
        gridspec_kw={'width_ratios': [3, 1]},
    )

    # ---- Map ----
    h, w = ref_data.shape
    # White background. float32 halves memory versus float64 on large grids
    # and imshow still accepts it as RGB in [0, 1].
    combined_rgb = np.ones((h, w, 3), dtype=np.float32)
    for category_mask, color in zip(category_masks, colors):
        combined_rgb[category_mask] = to_rgb(color)
    # Pixels in no category (including reference NaN / no coverage) stay white.

    extent = get_extent(ref_transform, w, h)
    ax_map.imshow(combined_rgb, extent=extent)
    continents.boundary.plot(ax=ax_map, linewidth=0.5, edgecolor='black')
    ax_map.set_xticks([])
    ax_map.set_yticks([])
    ax_map.set_frame_on(False)

    # ---- Bar chart ----
    ax_bar.bar(labels, percentages, color=colors)
    ax_bar.set_ylabel('Percentage (%)', fontsize=14, fontname=title_font)
    ax_bar.set_ylim(0, 100)

    # Pin tick positions before relabelling them: setting tick labels on
    # unpinned ticks triggers a matplotlib warning and the labels can drift
    # away from their ticks.
    ax_bar.set_xticks(range(len(labels)))
    ax_bar.set_xticklabels(labels, rotation=20, ha='right',
                           fontsize=12, fontname=title_font)
    yticks = [t for t in ax_bar.get_yticks() if 0 <= t <= 100]
    ax_bar.set_yticks(yticks)
    ax_bar.set_yticklabels([f'{int(t)}%' for t in yticks],
                           fontsize=12, fontname=title_font)
    ax_bar.grid(axis='y', linestyle=':', linewidth=0.5, alpha=0.7)

    fig.tight_layout()

    # NOTE(review): a stray comment here referenced
    # "G:\My Drive\Global_Irr_Maps\Global Maps" -- presumably where base_dir
    # points; confirm.
    out_png = os.path.join(base_dir, "GLAD_vs_AEI_2020.png")
    fig.savefig(out_png, bbox_inches='tight', dpi=300)
    plt.show()


if __name__ == "__main__":
    main()
