# In [None]:
import os, json, logging, gc
import numpy as np
import pandas as pd
import rasterio
from rasterio.warp import reproject, Resampling
from rasterio.features import rasterize
from rasterio.transform import xy

import geopandas as gpd
from shapely.ops import unary_union

import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Only ERROR-and-above records from the "rasterio._err" logger are emitted.
logging.getLogger("rasterio._err").setLevel(logging.ERROR)

# ===========================
# ABSOLUTE PATHS
# ===========================
# Predictors (1970–2019) on/near 0.1°
# Each predictor file is read one band per year via map_year_to_band(year, EDGAR_BASE_YEAR).
EDGAR_NOX_PATH  = r"D:\Sekh Mohinuddin\PhD Work\NOx_NH3_Conference\2nd Objectives\EDGAR\EDGAR_Nox_1970_2019.tif"
EDGAR_NH3_PATH  = r"D:\Sekh Mohinuddin\PhD Work\NOx_NH3_Conference\2nd Objectives\EDGAR\EDGAR_NH3_1970_2019.tif"
PPT_TIF         = r"D:\Sekh Mohinuddin\PhD Work\NOx_NH3_Conference\2nd Objectives\ppt\TerraClimate_annual_ppt_1970_2019_masked.tif"

# Observation rasters (USA, Europe, China)
# USA: one file per year; "{year}" is filled in by obs_for_year().
USA_DIR = r"D:\Sekh Mohinuddin\PhD Work\NOx_NH3_Conference\1st Objectives\Global N deposition\USA\Deposition\WetDIN_Total_asN"
USA_PATTERN = os.path.join(USA_DIR, "WetDIN_{year}_asN.tif")  # 1990..2019

# Europe: two per-year file series; obs_for_year() uses pattern A for 1990..2004
# and pattern B for every other year.
EU_DIR_A = r"D:\Sekh Mohinuddin\PhD Work\NOx_NH3_Conference\1st Objectives\Global N deposition\Europe"
EU_PATTERN_A = os.path.join(EU_DIR_A, "WDEP_Total_{year}_flipped.tif")   # 1990..2004
EU_DIR_B = r"D:\Sekh Mohinuddin\PhD Work\NOx_NH3_Conference\1st Objectives\Global N deposition\Europe\Wet N"
EU_PATTERN_B = os.path.join(EU_DIR_B, "EMEP_WDEP_{year}.tif")            # 2005..2019

# China: three multi-year block files (the same raster is reused for every year
# inside its block by obs_for_year()), then one file per year for 2011..2019.
CN_DIR_BLOCKS = r"D:\Sekh Mohinuddin\PhD Work\NOx_NH3_Conference\1st Objectives\Global N deposition\Inorganic Nitrogen deposition  database1.0\NITROGEN1_1996-2015\1996-2015"
CN_BLOCK_96_00 = os.path.join(CN_DIR_BLOCKS, "DIN_1996_2000.tif")
CN_BLOCK_01_05 = os.path.join(CN_DIR_BLOCKS, "DIN_2001_2005.tif")
CN_BLOCK_06_10 = os.path.join(CN_DIR_BLOCKS, "DIN_2006_2010.tif")
CN_DIR_2011_2019 = r"D:\Sekh Mohinuddin\PhD Work\NOx_NH3_Conference\1st Objectives\Global N deposition\Inorganic Nitrogen deposition  database1.0\Inorganic Nitrogen deposition  database1.0\Data File\Wet N"
CN_PERYEAR_PATTERN = os.path.join(CN_DIR_2011_2019, "{year}.tif")        # 2011..2019

# Region boundaries (USA, Europe). China mask from obs coverage.
USA_BOUNDARY_SHP = r"D:\Sekh Mohinuddin\PhD Work\NOx_NH3_Conference\1st Objectives\Global N deposition\USA\USA_Mainland.shp"
EU_BOUNDARY_SHP  = r"D:\Sekh Mohinuddin\PhD Work\NOx_NH3_Conference\1st Objectives\Global N deposition\Europe\Europe\Com_Europe_merged.shp"

# ===========================
# OUTPUTS
# ===========================
# All output folders live under OUTPUT_DIR and are created at import time if missing.
OUTPUT_DIR   = r"D:\Sekh Mohinuddin\PhD Work\NOx_NH3_Conference\2nd Objectives\Outputs-5"
PLOTS_DIR    = os.path.join(OUTPUT_DIR, "Plots_SingleModel")     # gate plots and CSVs
MODELS_DIR   = os.path.join(OUTPUT_DIR, "Models_SingleModel")    # saved models and scaler JSON
FINAL_DIR    = os.path.join(OUTPUT_DIR, "SingleModel_Global_from_USA_EU_CHN")  # per-year prediction GeoTIFFs
for d in [OUTPUT_DIR, PLOTS_DIR, MODELS_DIR, FINAL_DIR]:
    os.makedirs(d, exist_ok=True)

# ===========================
# CONSTANTS
# ===========================
# Inclusive year range used for training, gating and prediction.
START_YEAR = 1990
END_YEAR   = 2019
# Year that map_year_to_band() treats as band 1 of the multi-band predictor files.
EDGAR_BASE_YEAR = 1970
# Value written to output rasters in place of NaN, and declared as their nodata.
NODATA = np.float32(-9999.0)

# Sampling caps (tune for compute)
MAX_SAMPLES_PER_YEAR_PER_REGION = 150_000   # for training set build
HELDOUT_FRAC = 0.2                           # for gating per-year R2 inside each region

# ===========================
# Helpers
# ===========================
def map_year_to_band(year, base):
    """Return the 1-based band index for ``year`` when band 1 holds ``base``."""
    years_after_base = year - base
    return int(years_after_base + 1)

def _clean(a, nodata):
    aa = a.astype(np.float64, copy=False)
    if nodata is not None and np.isfinite(nodata): aa[aa == nodata] = np.nan
    aa[(aa <= -1e19) | (aa >= 1e19)] = np.nan
    return aa

def read_band(fp, band_idx):
    """Read one band of a raster file and return it cleaned, with its georeferencing.

    Returns a tuple ``(array, transform, crs, profile)`` where ``array`` is the
    band passed through ``_clean`` using the dataset's nodata value.

    Raises
    ------
    ValueError
        If ``band_idx`` is not between 1 and the dataset's band count.
    """
    with rasterio.open(fp) as dataset:
        in_range = 1 <= band_idx <= dataset.count
        if not in_range:
            raise ValueError(f"{fp}: band {band_idx}/{dataset.count}")
        raw = dataset.read(band_idx)
        cleaned = _clean(raw, dataset.nodata)
        georef = (dataset.transform, dataset.crs, dataset.profile)
    return (cleaned,) + georef

def resample_to(arr, s_tx, s_crs, d_shape, d_tx, d_crs, method="bilinear"):
    """Reproject ``arr`` from its source grid onto a destination grid.

    ``method`` may be "nearest", "bilinear" or "cubic"; any other value is
    treated as "bilinear". NaN is used as nodata on both sides, and the
    destination starts out filled with NaN. Values with magnitude of at
    least 1e19 in the result are replaced by NaN.

    Returns a new float64 array of shape ``d_shape``.
    """
    if method == "nearest":
        kernel = Resampling.nearest
    elif method == "cubic":
        kernel = Resampling.cubic
    else:
        # "bilinear" and every unrecognised name end up here.
        kernel = Resampling.bilinear

    dest = np.full(d_shape, np.nan, dtype=np.float64)
    reproject(
        source=arr.astype(np.float64, copy=False),
        destination=dest,
        src_transform=s_tx,
        src_crs=s_crs,
        dst_transform=d_tx,
        dst_crs=d_crs,
        src_nodata=np.nan,
        dst_nodata=np.nan,
        resampling=kernel,
        num_threads=2,
    )
    too_large = (dest <= -1e19) | (dest >= 1e19)
    dest[too_large] = np.nan
    return dest

def dissolve_to_single(gdf):
    """Merge every non-empty geometry of ``gdf`` into a one-row GeoDataFrame.

    ``None`` and empty geometries are skipped; the CRS of ``gdf`` is kept.
    """
    usable = []
    for shape in gdf.geometry:
        if shape is None or shape.is_empty:
            continue
        usable.append(shape)
    merged = unary_union(usable)
    return gpd.GeoDataFrame(geometry=[merged], crs=gdf.crs)

def rasterize_gdf(gdf, tx, crs, out_shape):
    """Burn the union of ``gdf``'s geometries into a boolean grid.

    The frame is reprojected to ``crs`` first when its CRS differs. Cells
    rasterized from the merged geometry are True, all others False. An
    all-False grid is returned when there is no usable geometry.
    """
    if gdf.crs != crs:
        gdf = gdf.to_crs(crs)

    usable = [shape for shape in gdf.geometry if shape is not None and not shape.is_empty]
    footprint = unary_union(usable)
    if footprint is None or footprint.is_empty:
        return np.zeros(out_shape, dtype=bool)

    burned = rasterize(
        [(footprint, 1)],
        out_shape=out_shape,
        transform=tx,
        fill=0,
        dtype='uint8',
    )
    return burned.astype(bool)

# ===========================
# Build 0.1° template from EDGAR NOx 1990
# ===========================
# The 1990 band of the EDGAR NOx file defines the template grid (shape,
# transform, CRS) that every other raster is aligned to.
# read_band() opens and closes the file itself, so no outer dataset handle is needed.
b1990 = map_year_to_band(1990, EDGAR_BASE_YEAR)
_tmpl, T_TX, T_CRS, T_PROF = read_band(EDGAR_NOX_PATH, b1990)
T_H, T_W = _tmpl.shape
# Output profile: same georeferencing as the template, but a single float32
# band with NODATA declared as the nodata value.
TEMPLATE_PROFILE = T_PROF.copy()
TEMPLATE_PROFILE.update(count=1, dtype=rasterio.float32, nodata=NODATA)
print(f"[template] 0.1° grid: H={T_H}, W={T_W}, CRS={T_CRS}")

# ===========================
# Predictors (to template)
# ===========================
def predictors_for_year(year):
    """Load the three predictor grids for ``year`` on the template grid.

    The band for ``year`` is read from the NOx, NH3 and precipitation files
    (all indexed from EDGAR_BASE_YEAR). Any grid whose shape, transform or
    CRS differs from the template is resampled onto it. NOx and NH3 are then
    clipped to [0, 1e7] and precipitation to [0, 1e4].

    Returns the tuple ``(nox, nh3, ppt)``.
    """
    band = map_year_to_band(year, EDGAR_BASE_YEAR)
    grid_shape = (T_H, T_W)

    loaded = [read_band(path, band)[:3] for path in (EDGAR_NOX_PATH, EDGAR_NH3_PATH, PPT_TIF)]

    aligned = []
    for data, tx, crs in loaded:
        off_template = (data.shape != grid_shape) or (tx != T_TX) or (crs != T_CRS)
        if off_template:
            data = resample_to(data, tx, crs, grid_shape, T_TX, T_CRS)
        aligned.append(data)

    # domain-safe clips (upper bound per predictor)
    upper_bounds = (1e7, 1e7, 1e4)
    nox, nh3, ppt = (np.clip(grid, 0, hi) for grid, hi in zip(aligned, upper_bounds))
    return nox, nh3, ppt

# ===========================
# Observation rasters (to template) + China mask builder
# ===========================
def obs_for_year(year):
    """Return list of (obs_array_on_template, region_tag).

    Regions are looked up in the order USA, EUROPE, CHINA. A region appears
    in the result only when its file for ``year`` exists on disk. Each array
    is aligned to the template grid when needed and clipped to [0, 1e3].
    """
    grid_shape = (T_H, T_W)

    def _load_on_template(path):
        # Band 1 of the file, resampled to the template if misaligned, then clipped.
        data, tx, crs, _ = read_band(path, 1)
        if (data.shape != grid_shape) or (tx != T_TX) or (crs != T_CRS):
            data = resample_to(data, tx, crs, grid_shape, T_TX, T_CRS)
        return np.clip(data, 0, 1e3)

    found = []

    # USA: one file per year.
    usa_file = USA_PATTERN.format(year=year)
    if os.path.exists(usa_file):
        found.append((_load_on_template(usa_file), "USA"))

    # Europe: pattern A for 1990..2004, pattern B for every other year.
    eu_pattern = EU_PATTERN_A if 1990 <= year <= 2004 else EU_PATTERN_B
    eu_file = eu_pattern.format(year=year)
    if os.path.exists(eu_file):
        found.append((_load_on_template(eu_file), "EUROPE"))

    # China: a shared block file for 1996..2010, per-year files for 2011..2019.
    if 1996 <= year <= 2000:
        cn_file = CN_BLOCK_96_00
    elif 2001 <= year <= 2005:
        cn_file = CN_BLOCK_01_05
    elif 2006 <= year <= 2010:
        cn_file = CN_BLOCK_06_10
    elif 2011 <= year <= 2019:
        cn_file = CN_PERYEAR_PATTERN.format(year=year)
    else:
        cn_file = None
    if cn_file is not None and os.path.exists(cn_file):
        found.append((_load_on_template(cn_file), "CHINA"))

    return found

def china_mask_from_obs():
    """Build the China region mask from observation coverage.

    A template cell is True when the CHINA observation raster has a finite
    value there in at least one year from 1996 through 2019.
    """
    coverage = np.zeros((T_H, T_W), dtype=bool)
    for yr in range(1996, 2020):
        for layer, region in obs_for_year(yr):
            if region != "CHINA":
                continue
            coverage |= np.isfinite(layer)
    return coverage

# ===========================
# Build region masks (USA/EU from polygons; China from obs coverage)
# ===========================
# Boundary polygons, each dissolved to a single geometry.
usa_gdf = dissolve_to_single(gpd.read_file(USA_BOUNDARY_SHP))
eu_gdf  = dissolve_to_single(gpd.read_file(EU_BOUNDARY_SHP))

# Rasterisation only needs the template transform/CRS/shape already held in
# memory, so no raster file has to be opened here.
USA_MASK = rasterize_gdf(usa_gdf, T_TX, T_CRS, (T_H, T_W))
EU_MASK  = rasterize_gdf(eu_gdf,  T_TX, T_CRS, (T_H, T_W))
# China has no boundary polygon; its mask is the union of observed cells.
CHN_MASK = china_mask_from_obs()
# Keys match the region tags returned by obs_for_year().
REGION_MASKS = {"USA": USA_MASK, "EUROPE": EU_MASK, "CHINA": CHN_MASK}
print(f"[masks] USA={USA_MASK.sum()} px | EU={EU_MASK.sum()} px | CHN={CHN_MASK.sum()} px")

# ===========================
# Build training set from observed pixels in three regions
# ===========================
def build_training_from_regions():
    """Assemble the pooled training table from observed pixels in all regions.

    For each year in START_YEAR..END_YEAR that has at least one observation
    raster, and for each region returned by ``obs_for_year``, up to
    MAX_SAMPLES_PER_YEAR_PER_REGION pixels are drawn at random without
    replacement from the cells that lie inside the region mask and have
    finite observation and predictor values.

    Returns
    -------
    X : numpy.ndarray, float32, shape (n, 6)
        Columns: nox, nh3, ppt, lat, lon, scaled year ``(year - 2000) / 50``.
    y : numpy.ndarray, float32, shape (n,)
        Observed value at each sampled pixel.
    yt : numpy.ndarray, int16, shape (n,)
        Calendar year of each sample.

    Raises
    ------
    RuntimeError
        If no sample could be collected for any year or region.
    """
    Xs, ys, year_tags = [], [], []
    n_collected = 0
    for year in range(START_YEAR, END_YEAR + 1):
        obs_list = obs_for_year(year)
        if not obs_list:
            continue
        # predictors (aligned to template)
        nox, nh3, ppt = predictors_for_year(year)
        yrn = (year - 2000.0) / 50.0
        # The predictor validity mask is the same for every region in this year.
        predictors_ok = np.isfinite(nox) & np.isfinite(nh3) & np.isfinite(ppt)

        for arr, tag in obs_list:
            mask = REGION_MASKS.get(tag)
            if mask is None:
                continue
            r, c = np.nonzero(np.isfinite(arr) & predictors_ok & mask)
            if r.size == 0:
                continue
            take = min(MAX_SAMPLES_PER_YEAR_PER_REGION, r.size)
            sel = np.random.choice(r.size, take, replace=False)
            rr, cc = r[sel], c[sel]

            # Cell-centre coordinates for all sampled pixels in a single call;
            # xy() accepts sequences of rows/cols, avoiding one call per pixel.
            x_ctr, y_ctr = xy(T_TX, rr, cc, offset='center')
            lon = np.asarray(x_ctr, dtype=float)
            lat = np.asarray(y_ctr, dtype=float)

            feats = np.column_stack(
                [nox[rr, cc], nh3[rr, cc], ppt[rr, cc], lat, lon, np.full(take, yrn)]
            )
            Xs.append(feats.astype(np.float32))
            ys.append(arr[rr, cc].astype(np.float32))
            year_tags.append(np.full(take, year, dtype=np.int16))
            n_collected += take
        print(f"[train-build] {year}: collected {n_collected} samples so far")

    if not Xs:
        raise RuntimeError("No training samples found; check observation raster paths.")
    X = np.vstack(Xs)
    y = np.concatenate(ys)
    yt = np.concatenate(year_tags)
    print(f"[train-build] Total samples: {X.shape[0]:,}")
    return X, y, yt

def build_mlp(n_in):
    """Build and compile the regression MLP for ``n_in`` input features.

    Four ReLU hidden layers (384, 192, 96, 48 units) feed a single linear
    output. The first two hidden layers are followed by batch normalisation,
    and the first three by dropout (0.35, 0.25, 0.15). The model is compiled
    with Adam at learning rate 1e-3, MSE loss, and MAE/MSE metrics.
    """
    # (units, batch-norm after?, dropout rate or None)
    hidden_spec = [
        (384, True, 0.35),
        (192, True, 0.25),
        (96, False, 0.15),
        (48, False, None),
    ]

    stack = []
    for position, (units, with_bn, drop_rate) in enumerate(hidden_spec):
        # Only the first layer declares the input shape.
        extra = {"input_shape": (n_in,)} if position == 0 else {}
        stack.append(layers.Dense(units, activation='relu', **extra))
        if with_bn:
            stack.append(layers.BatchNormalization())
        if drop_rate is not None:
            stack.append(layers.Dropout(drop_rate))
    stack.append(layers.Dense(1))

    net = keras.Sequential(stack)
    net.compile(
        optimizer=keras.optimizers.Adam(1e-3),
        loss='mse',
        metrics=['mae', 'mse'],
    )
    return net

# ===========================
# Train single model + Gate (per-region R² by year)
# ===========================
def train_single_and_gate():
    """Train the pooled model, then score it per region and year (the "gate").

    Steps:
      1. Build the training table, standardise features and target, fit the
         MLP with early stopping, and save it as ``single_best.h5`` (best
         checkpoint) and ``single_final.h5`` in MODELS_DIR.
      2. For every year and region with observations, draw a random pixel
         subset (HELDOUT_FRAC of the valid pixels, at least 400) and compute
         R² between observations and predictions. Region-years with fewer
         than 400 valid pixels are skipped.
      3. Write a per-region R² plot and CSV plus ``gate_summary.csv`` to
         PLOTS_DIR, and the scaler parameters to ``single_scalers.json``.

    The gate passes only if *every* region in REGION_MASKS has at least one
    evaluated year and a finite mean R² of at least 0.85. A region with no
    evaluable year counts as a failure rather than being ignored.

    Returns
    -------
    tuple
        ``(model, sx, sy, all_ok)``: the fitted Keras model, the feature and
        target ``StandardScaler`` objects, and the gate verdict.
    """
    min_mean_r2 = 0.85       # gate threshold on each region's mean R²
    min_gate_pixels = 400    # minimum valid pixels to evaluate a region-year

    X, y, _ = build_training_from_regions()
    sx, sy = StandardScaler(), StandardScaler()
    Xs, ys_ = sx.fit_transform(X), sy.fit_transform(y.reshape(-1, 1))

    model = build_mlp(X.shape[1])
    cbs = [
        keras.callbacks.EarlyStopping(patience=15, restore_best_weights=True, verbose=1),
        keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=8, verbose=1),
        keras.callbacks.ModelCheckpoint(os.path.join(MODELS_DIR, "single_best.h5"), save_best_only=True, verbose=1),
    ]
    print("[single-train] fitting...")
    model.fit(Xs, ys_, epochs=200, batch_size=1024, validation_split=0.2, verbose=1, callbacks=cbs)
    model.save(os.path.join(MODELS_DIR, "single_final.h5"))

    # Gate: for each region+year where obs exists, evaluate R² on a random pixel subset.
    # NOTE(review): these pixels are drawn independently of the training draw in
    # build_training_from_regions(), so they can overlap the training samples;
    # this is not a guaranteed held-out set.
    # Year is the outer loop so observations and predictors are read once per
    # year instead of once per region per year.
    r2_by_region = {name: [] for name in REGION_MASKS}
    for year in range(START_YEAR, END_YEAR + 1):
        obs_by_tag = {}
        for arr, tag in obs_for_year(year):
            if tag in r2_by_region:
                obs_by_tag.setdefault(tag, arr)  # first raster per region wins
        if not obs_by_tag:
            continue

        nox, nh3, ppt = predictors_for_year(year)
        predictors_ok = np.isfinite(nox) & np.isfinite(nh3) & np.isfinite(ppt)
        yrn = (year - 2000.0) / 50.0

        for region_name, region_mask in REGION_MASKS.items():
            obs_arr = obs_by_tag.get(region_name)
            if obs_arr is None:
                continue
            r, c = np.nonzero(np.isfinite(obs_arr) & predictors_ok & region_mask)
            if r.size < min_gate_pixels:
                continue
            take = max(min_gate_pixels, int(HELDOUT_FRAC * r.size))
            sel = np.random.choice(r.size, take, replace=False)
            rr, cc = r[sel], c[sel]

            # Cell-centre coordinates for all selected pixels in one call.
            x_ctr, y_ctr = xy(T_TX, rr, cc, offset='center')
            lon = np.asarray(x_ctr, dtype=float)
            lat = np.asarray(y_ctr, dtype=float)
            Xte = np.column_stack(
                [nox[rr, cc], nh3[rr, cc], ppt[rr, cc], lat, lon, np.full(take, yrn)]
            )

            pred = sy.inverse_transform(
                model.predict(sx.transform(Xte), verbose=0, batch_size=4096)
            ).ravel()
            r2_by_region[region_name].append((year, r2_score(obs_arr[rr, cc], pred)))

    gate_records = []
    for region_name, r2_per_year in r2_by_region.items():
        if not r2_per_year:
            continue
        yrs, vals = zip(*r2_per_year)
        mean_r2 = float(np.nanmean(vals))
        gate_records.append((region_name, mean_r2))
        # quick plot
        plt.figure(figsize=(8, 3))
        plt.plot(yrs, vals, marker='o'); plt.grid(True)
        plt.ylim(0, 1.0)
        plt.title(f"Gate R² by year — {region_name}")
        plt.xlabel("Year"); plt.ylabel("R² (obs vs single-model)")
        plt.tight_layout()
        plt.savefig(os.path.join(PLOTS_DIR, f"gate_r2_{region_name}.png"), dpi=200)
        plt.close()
        pd.DataFrame(r2_per_year, columns=["year", "r2"]).to_csv(
            os.path.join(PLOTS_DIR, f"gate_r2_{region_name}.csv"), index=False)

    df_gate = pd.DataFrame(gate_records, columns=["region", "mean_r2"])
    df_gate.to_csv(os.path.join(PLOTS_DIR, "gate_summary.csv"), index=False)
    print(df_gate)

    # A region that was never evaluated must not pass by omission.
    unevaluated = [name for name in REGION_MASKS if not r2_by_region[name]]
    if unevaluated:
        print(f"[gate] no evaluable region-year for: {', '.join(unevaluated)} (counted as gate failure)")
    passed = {
        name for name, mean_r2 in gate_records
        if np.isfinite(mean_r2) and mean_r2 >= min_mean_r2
    }
    all_ok = all(name in passed for name in REGION_MASKS)

    # persist scalers
    with open(os.path.join(MODELS_DIR, "single_scalers.json"), "w") as f:
        json.dump({"sx_mean": sx.mean_.tolist(), "sx_scale": sx.scale_.tolist(),
                   "sy_mean": float(sy.mean_[0]), "sy_scale": float(sy.scale_[0])}, f)
    return model, sx, sy, all_ok

# ===========================
# Predict global seamless series
# ===========================
def predict_single_year(model, sx, sy, year, block=500):
    """Predict the full template grid for one year.

    The grid is processed in ``block`` x ``block`` tiles. Within each tile,
    only cells where all three predictors are finite are predicted; all other
    cells stay NaN. Negative predictions are raised to 0.

    Parameters
    ----------
    model : fitted Keras model, as returned by ``train_single_and_gate``.
    sx, sy : fitted feature / target scalers used during training.
    year : int
        Year whose predictor bands are used.
    block : int
        Tile edge length in pixels.

    Returns
    -------
    numpy.ndarray
        float32 array of shape (T_H, T_W); NaN where no prediction was made.
    """
    nox, nh3, ppt = predictors_for_year(year)
    out = np.full((T_H, T_W), np.nan, np.float32)
    yrn = (year - 2000.0) / 50.0
    for r0 in range(0, T_H, block):
        r1 = min(T_H, r0 + block)
        for c0 in range(0, T_W, block):
            c1 = min(T_W, c0 + block)
            nb, hb, pb = nox[r0:r1, c0:c1], nh3[r0:r1, c0:c1], ppt[r0:r1, c0:c1]
            v = np.isfinite(nb) & np.isfinite(hb) & np.isfinite(pb)
            if not v.any():
                continue
            rr, cc = np.nonzero(v)
            # Cell-centre coordinates for the whole tile in one call
            # (rows/cols shifted from tile-local to full-grid indices).
            x_ctr, y_ctr = xy(T_TX, rr + r0, cc + c0, offset='center')
            lon = np.asarray(x_ctr, dtype=float)
            lat = np.asarray(y_ctr, dtype=float)
            feats = np.column_stack(
                [nb[rr, cc], hb[rr, cc], pb[rr, cc], lat, lon, np.full(rr.size, yrn)]
            )
            pred = sy.inverse_transform(
                model.predict(sx.transform(feats), verbose=0, batch_size=4096)
            ).ravel()
            # Boolean-mask order matches np.nonzero order (row-major).
            out[r0:r1, c0:c1][v] = pred.astype(np.float32)
    # physical clamp: a one-sided floor at 0. Clipping against nanmax(out)
    # would have lower bound > upper bound when every prediction is negative,
    # leaving negative values in place. np.maximum keeps NaN cells as NaN.
    return np.maximum(out, np.float32(0))

def write_single_series(model, sx, sy, start=START_YEAR, end=END_YEAR):
    """Predict each year from ``start`` to ``end`` (inclusive) and write one GeoTIFF per year.

    Files are written to FINAL_DIR using TEMPLATE_PROFILE. Non-finite cells
    are stored as NODATA.
    """
    profile = TEMPLATE_PROFILE.copy()
    for year in range(start, end + 1):
        print(f"[predict] {year}")
        grid = predict_single_year(model, sx, sy, year)
        has_value = np.isfinite(grid)
        band = np.where(has_value, grid, NODATA).astype(np.float32)
        target = os.path.join(FINAL_DIR, f"SingleModel_Global_0p1_{year}.tif")
        with rasterio.open(target, "w", **profile) as sink:
            sink.write(band, 1)

# ===========================
# MAIN
# ===========================
def main():
    """Run the pipeline: train the model, apply the gate, write yearly outputs.

    No prediction rasters are written when the gate fails.
    """
    print("=== Single-model (USA+EU+CHN only) → Global prediction (1990–2019) ===")
    # NOTE(review): TensorFlow is imported at module import time, before this
    # assignment runs, so this setting is presumably too late to affect TF's
    # native log level — confirm, and move it above the import if needed.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            for g in gpus:
                tf.config.experimental.set_memory_growth(g, True)
        except RuntimeError as err:
            # Best effort: keep going with default GPU memory behaviour, but
            # report the failure instead of hiding it.
            print(f"[gpu] could not enable memory growth: {err}")

    model, sx, sy, gate_ok = train_single_and_gate()
    if not gate_ok:
        print("[STOP] Gate failed — mean R² < 0.85 in at least one region. Not writing outputs.")
        return

    print("\n=== Gate passed in USA, EUROPE, CHINA (mean R² ≥ 0.85). Writing seamless global GTiffs ===")
    write_single_series(model, sx, sy, START_YEAR, END_YEAR)

    print("\nDONE")
    print(f"Seamless single-model outputs: {FINAL_DIR}")

if __name__ == "__main__":
    main()