In [None]:
!pip -q install earthengine-api

import os
import ee
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split, StratifiedKFold, RandomizedSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

# Config
# Cloud project passed to ee.Initialize() and exported as EARTHENGINE_PROJECT.
PROJECT_ID = "asean-envoirmental-project"
# Path of the CSV written by sample_all_features().
OUT_CSV    = "thailand_rice_env_s1_s2_full.csv"

# Sampling mode used by sample_all_features():
#   True  -> stratifiedSample on the "rice" band, POINTS_PER_CLASS_PER_REGION
#            points for each of the classes 0 and 1 in every region;
#   False -> NUM_POINTS_PER_REGION random points in every region.
STRATIFY_BY_LABEL = False
POINTS_PER_CLASS_PER_REGION = 10
NUM_POINTS_PER_REGION       = 200

# Earth Engine image asset that provides the rice labels (band "b1" is read).
SUPERVISED_RICE_ASSET_ID = "projects/asean-envoirmental-project/assets/MSEAsia2019"
# Year used for the label image's predictors and for the Sentinel-1 stack.
LABEL_YEAR = 2019

# Sentinel-2 image collection queried by get_s2_metrics().
S2_COLLECTION = "COPERNICUS/S2_SR_HARMONIZED"

# Earth Engine init
os.environ["EARTHENGINE_PROJECT"] = PROJECT_ID
try:
    ee.Initialize(project=PROJECT_ID)
except Exception:
    # First attempt failed (presumably missing credentials): run the notebook
    # authentication flow once and retry. A second failure propagates.
    ee.Authenticate(auth_mode="notebook")
    ee.Initialize(project=PROJECT_ID)

print("EE project:", PROJECT_ID)

# Study area
# Country outline: the FAO GAUL level-0 feature(s) whose ADM0_NAME is
# "Thailand". Used to clip the rice label image.
thailand = ee.FeatureCollection("FAO/GAUL/2015/level0").filter(
    ee.Filter.eq("ADM0_NAME", "Thailand")
)

# Sampling regions, keyed by display name. Each entry is a circle: a centre
# ("lat"/"lon") plus a "radius" that region_geom() hands to
# ee.Geometry.buffer() (presumably metres, the buffer default - confirm).
regions = {
    "Central":   {"lat": 14.5, "lon": 100.5, "radius": 50000},
    "Northeast": {"lat": 16.5, "lon": 103.0, "radius": 50000},
    "North":     {"lat": 18.5, "lon":  99.0, "radius": 50000},
    "South":     {"lat":  8.0, "lon":  99.5, "radius": 50000},
}

def region_geom(region_info):
    """Return the buffered-point geometry for one sampling region.

    Args:
        region_info: Mapping with ``"lon"``, ``"lat"`` and ``"radius"`` keys,
            shaped like the values of the ``regions`` dict.

    Returns:
        An ``ee.Geometry``: the point at (lon, lat) buffered by ``radius``.
    """
    lon_lat = [region_info["lon"], region_info["lat"]]
    return ee.Geometry.Point(lon_lat).buffer(region_info["radius"])

# Labels from 2019 MSEAsia supervised rice mask
def load_supervised_rice_image():
    """Build the binary rice label image from the supervised rice asset.

    Band ``b1`` of ``SUPERVISED_RICE_ASSET_ID`` is clipped to the Thailand
    outline, pixels equal to 255 become 1 (everything else 0), and the result
    is smoothed with a neighbourhood mode over a square kernel of radius 60
    metres.

    NOTE(review): the status message mentions a "20 m grid", but nothing in
    this function reprojects the image; the 20 m scale is only applied by the
    callers that sample it.

    Returns:
        A single-band ``ee.Image`` named ``"rice"``.
    """
    print("\nLoading supervised rice (MSE Asia 2019)…")

    label_band = ee.Image(SUPERVISED_RICE_ASSET_ID).select("b1")
    clipped = label_band.clip(thailand.geometry())
    is_rice = clipped.eq(255).rename("rice")

    # Majority vote inside the window removes isolated label pixels.
    window = ee.Kernel.square(60, "meters", True)
    smoothed = is_rice.reduceNeighborhood(ee.Reducer.mode(), window)

    print("OK: binary rice (0/1), 20 m grid")
    return smoothed.rename("rice")

# Label image built once at module load; sample_all_features() reads this
# global when drawing stratified samples and when labelling random points.
RICE_IMG = load_supervised_rice_image()

# ENV predictors
def get_evi_metrics(point, year):
    """Summarise MODIS MOD13Q1 EVI around a point for May-November of ``year``.

    Each image is multiplied by 0.0001 and values outside [-0.2, 1.0] are
    masked. Per-pixel mean, max and min over the period are then averaged
    inside a 500 m buffer around the point.

    Args:
        point: ``ee.Feature`` (anything exposing ``.geometry()``).
        year: Year used to build the date filter.

    Returns:
        ``(mean_evi, max_evi, amplitude)`` as floats, where amplitude is
        ``max - min``. Individual entries are ``None`` when unavailable; all
        three are ``None`` when the collection is empty, the mean is missing,
        or the Earth Engine request fails (best-effort: errors are swallowed
        so one bad point does not abort sampling).
    """
    try:
        col = (
            ee.ImageCollection("MODIS/061/MOD13Q1")
            .filterDate(f"{year}-05-01", f"{year}-11-30")
            .select("EVI")
        )
        if col.size().getInfo() == 0:
            return None, None, None

        def scale_and_mask(img):
            scaled = img.multiply(0.0001)
            return scaled.updateMask(
                scaled.gte(-0.2).And(scaled.lte(1.0))
            ).copyProperties(img, img.propertyNames())

        col_scaled = col.map(scale_and_mask)
        combo = col_scaled.mean().addBands(
            [col_scaled.max(), col_scaled.min()]
        ).rename(["mean_evi", "max_evi", "min_evi"])
        geom = point.geometry().buffer(500)

        stats = combo.reduceRegion(
            reducer=ee.Reducer.mean(),
            geometry=geom,
            scale=250,
            maxPixels=1e9,
            bestEffort=True,
            tileScale=4,
        )

        # stats.get(...) yields lazy server-side objects that are never None
        # (and always truthy) on the client, so nulls can only be detected
        # after evaluation. Fetch the whole dictionary in one round trip.
        values = stats.getInfo() or {}
        mean_evi = values.get("mean_evi")
        if mean_evi is None:
            return None, None, None

        max_evi = values.get("max_evi")
        min_evi = values.get("min_evi")
        max_evi = float(max_evi) if max_evi is not None else None
        min_evi = float(min_evi) if min_evi is not None else None

        amplitude = None
        if max_evi is not None and min_evi is not None:
            amplitude = max_evi - min_evi
        return float(mean_evi), max_evi, amplitude
    except Exception:
        return None, None, None

def get_lst_metrics(point, year):
    """Summarise MODIS MOD11A2 daytime LST around a point for May-November.

    ``LST_Day_1km`` is multiplied by 0.02, values outside 200-400 are masked
    and 273.15 is subtracted (Kelvin to Celsius). Per-pixel mean and max over
    the period are then averaged inside a 1000 m buffer around the point.

    Args:
        point: ``ee.Feature`` (anything exposing ``.geometry()``).
        year: Year used to build the date filter.

    Returns:
        ``(mean_temp_c, max_temp_c)`` as floats. ``max_temp_c`` may be
        ``None`` on its own; both are ``None`` when the collection is empty,
        the mean is missing, or the request fails (best-effort).
    """
    try:
        col = (
            ee.ImageCollection("MODIS/061/MOD11A2")
            .filterDate(f"{year}-05-01", f"{year}-11-30")
            .select("LST_Day_1km")
        )
        if col.size().getInfo() == 0:
            return None, None

        def scale_and_mask(img):
            kelvin  = img.multiply(0.02)
            masked  = kelvin.updateMask(kelvin.gte(200).And(kelvin.lte(400)))
            celsius = masked.subtract(273.15)
            return celsius.copyProperties(img, img.propertyNames())

        col_c = col.map(scale_and_mask)
        combo = col_c.mean().addBands(col_c.max()).rename(["mean_temp", "max_temp"])
        geom = point.geometry().buffer(1000)

        stats = combo.reduceRegion(
            reducer=ee.Reducer.mean(),
            geometry=geom,
            scale=1000,
            maxPixels=1e9,
            bestEffort=True,
            tileScale=4,
        )

        # Client-side handles returned by stats.get(...) are never None, so
        # evaluate the dictionary once and test the real values instead.
        values = stats.getInfo() or {}
        mean_temp = values.get("mean_temp")
        if mean_temp is None:
            return None, None

        max_temp = values.get("max_temp")
        return float(mean_temp), (float(max_temp) if max_temp is not None else None)
    except Exception:
        return None, None

def get_precipitation_metrics(point, year):
    """Sum CHIRPS daily precipitation around a point for ``year``.

    Three totals are computed per pixel - the whole year, months 5-11 and
    months 7-9 - and averaged inside a 5000 m buffer around the point.

    Args:
        point: ``ee.Feature`` (anything exposing ``.geometry()``).
        year: Calendar year (int or int-like).

    Returns:
        ``(annual, growing_season, peak)`` as floats. The last two may be
        ``None`` individually; all three are ``None`` when the collection is
        empty, the annual total is missing, or the request fails
        (best-effort, so the caller can fall back to another source).
    """
    try:
        # filterDate treats the end as exclusive; ending on 1 January of the
        # next year keeps 31 December in the annual total.
        chirps = (
            ee.ImageCollection("UCSB-CHG/CHIRPS/DAILY")
            .filterDate(f"{year}-01-01", f"{int(year) + 1}-01-01")
            .select("precipitation")
        )
        if chirps.size().getInfo() == 0:
            return None, None, None

        annual_total = chirps.sum()
        gs_total     = chirps.filter(
            ee.Filter.calendarRange(5, 11, "month")
        ).sum()
        peak_total   = chirps.filter(
            ee.Filter.calendarRange(7, 9, "month")
        ).sum()
        combo = annual_total.addBands(
            [gs_total, peak_total]
        ).rename(["annual", "gs", "peak"])
        geom = point.geometry().buffer(5000)

        stats = combo.reduceRegion(
            reducer=ee.Reducer.mean(),
            geometry=geom,
            scale=5000,
            maxPixels=1e9,
            bestEffort=True,
            tileScale=4,
        )

        # One evaluation for all three totals; the lazy objects returned by
        # stats.get(...) cannot be None-checked on the client.
        values = stats.getInfo() or {}
        annual = values.get("annual")
        if annual is None:
            return None, None, None

        growing = values.get("gs")
        peak = values.get("peak")
        return (
            float(annual),
            float(growing) if growing is not None else None,
            float(peak) if peak is not None else None,
        )
    except Exception:
        return None, None, None

def get_precipitation_metrics_era5(point, year):
    """Sum ERA5 daily ``total_precipitation`` around a point for ``year``.

    Same outputs as ``get_precipitation_metrics`` but from
    ``ECMWF/ERA5/DAILY``. Each image is multiplied by 1000 before summing
    (presumably metres to millimetres - confirm against the dataset docs).
    Totals for the whole year, months 5-11 and months 7-9 are averaged
    inside a 27000 m buffer around the point.

    NOTE(review): this collection appears to stop in mid-2020 in the public
    catalog - confirm before using it for later years.

    Args:
        point: ``ee.Feature`` (anything exposing ``.geometry()``).
        year: Calendar year (int or int-like).

    Returns:
        ``(annual, growing_season, peak)`` as floats. The last two may be
        ``None`` individually; all three are ``None`` when the collection is
        empty, the annual total is missing, or the request fails.
    """
    try:
        # End date is exclusive, so use 1 January of the following year to
        # include 31 December.
        era5 = (
            ee.ImageCollection("ECMWF/ERA5/DAILY")
            .filterDate(f"{year}-01-01", f"{int(year) + 1}-01-01")
            .select("total_precipitation")
        )
        if era5.size().getInfo() == 0:
            return None, None, None

        era5_mm = era5.map(
            lambda img: img.multiply(1000).copyProperties(
                img, img.propertyNames()
            )
        )
        annual_total = era5_mm.sum()
        gs_total     = era5_mm.filter(
            ee.Filter.calendarRange(5, 11, "month")
        ).sum()
        peak_total   = era5_mm.filter(
            ee.Filter.calendarRange(7, 9, "month")
        ).sum()
        combo = annual_total.addBands(
            [gs_total, peak_total]
        ).rename(["annual", "gs", "peak"])
        geom = point.geometry().buffer(27000)

        stats = combo.reduceRegion(
            reducer=ee.Reducer.mean(),
            geometry=geom,
            scale=27830,
            maxPixels=1e9,
            bestEffort=True,
            tileScale=4,
        )

        # Evaluate once: lazy stats.get(...) handles are never None or falsy
        # on the client, so nulls are only visible after getInfo().
        values = stats.getInfo() or {}
        annual = values.get("annual")
        if annual is None:
            return None, None, None

        growing = values.get("gs")
        peak = values.get("peak")
        return (
            float(annual),
            float(growing) if growing is not None else None,
            float(peak) if peak is not None else None,
        )
    except Exception:
        return None, None, None

def get_water_occurrence(point):
    """Return mean and max JRC surface-water occurrence near a point.

    Samples the ``occurrence`` band of ``JRC/GSW1_4/GlobalSurfaceWater`` at
    30 m scale inside a 1500 m buffer, requesting about 9 pixels, and
    summarises the sampled values.

    Args:
        point: ``ee.Feature`` (anything exposing ``.geometry()``).

    Returns:
        ``(mean_occurrence, max_occurrence)`` as floats. ``(0.0, 0.0)`` is
        returned when no pixel was sampled and also when the request fails,
        so a zero cannot be distinguished from a failed lookup.
    """
    try:
        water = ee.Image("JRC/GSW1_4/GlobalSurfaceWater").select("occurrence")
        sample = water.sample(
            region=point.geometry().buffer(1500),
            scale=30,
            numPixels=9,
        )

        # Pull the sampled values in a single request and summarise locally;
        # the aggregate_* results are lazy objects whose truthiness says
        # nothing about the underlying value.
        fetched = sample.aggregate_array("occurrence").getInfo() or []
        occurrences = [float(v) for v in fetched if v is not None]
        if not occurrences:
            return 0.0, 0.0

        return sum(occurrences) / len(occurrences), max(occurrences)
    except Exception:
        return 0.0, 0.0

def get_terrain_data(point):
    """Return SRTM elevation and slope for one pixel near a point.

    One pixel is sampled at 30 m scale from inside a 90 m buffer around the
    point, so the value comes from the neighbourhood rather than necessarily
    from the exact location.

    Args:
        point: ``ee.Feature`` (anything exposing ``.geometry()``).

    Returns:
        ``(elevation, slope)`` as floats; either may be ``None``. Both are
        ``None`` when nothing was sampled or the request fails.
    """
    try:
        srtm  = ee.Image("USGS/SRTMGL1_003").select("elevation")
        slope = ee.Terrain.slope(srtm)
        sample = srtm.addBands(slope).sample(
            region=point.geometry().buffer(90),
            scale=30,
            numPixels=1,
        )

        # Fetch the sampled feature once. An empty sample evaluates to a null
        # feature; a server error for that case lands in the except below
        # with the same result.
        feature = sample.first().getInfo()
        if not feature:
            return None, None

        props = feature.get("properties") or {}
        elevation = props.get("elevation")
        slope_deg = props.get("slope")
        # Explicit None tests: an elevation or slope of 0 is a valid value.
        return (
            float(elevation) if elevation is not None else None,
            float(slope_deg) if slope_deg is not None else None,
        )
    except Exception:
        return None, None

def get_soil_properties(point):
    """Return OpenLandMap soil organic carbon and clay values near a point.

    Band ``b0`` of the organic-carbon and clay images is stacked (the second
    ``b0`` is exposed as ``b0_1``) and one pixel is sampled at 250 m scale
    inside a 250 m buffer around the point.

    NOTE(review): the organic-carbon value is returned as stored and the clay
    value is divided by 10. The public catalog appears to list a scale factor
    of 5 for the carbon band and percent units for clay - confirm both
    conversions against the dataset documentation.

    Args:
        point: ``ee.Feature`` (anything exposing ``.geometry()``).

    Returns:
        ``(soc, clay)`` as floats; either may be ``None``. Both are ``None``
        when nothing was sampled or the request fails.
    """
    try:
        soc  = ee.Image(
            "OpenLandMap/SOL/SOL_ORGANIC-CARBON_USDA-6A1C_M/v02"
        ).select("b0")
        clay = ee.Image(
            "OpenLandMap/SOL/SOL_CLAY-WFRACTION_USDA-3A1A1A_M/v02"
        ).select("b0")
        sample = soc.addBands(clay).sample(
            region=point.geometry().buffer(250),
            scale=250,
            numPixels=1,
        )

        # One request for the whole feature; lazy feat.get(...) handles are
        # always truthy on the client and cannot signal a missing value.
        feature = sample.first().getInfo()
        if not feature:
            return None, None

        props = feature.get("properties") or {}
        soc_raw = props.get("b0")
        clay_raw = props.get("b0_1")

        soc_gkg = float(soc_raw) if soc_raw is not None else None
        clay_pct = float(clay_raw) / 10.0 if clay_raw is not None else None
        return soc_gkg, clay_pct
    except Exception:
        return None, None

# Sentinel-1 predictors (fixed: global composite, no log10)
def build_s1_stack(year):
    """Build the annual mean Sentinel-1 VV/VH composite for ``year``.

    Uses ``COPERNICUS/S1_GRD`` restricted to IW mode, descending passes,
    10 m resolution, product type GRD and scenes carrying both VV and VH.
    The collection is not spatially filtered, so the composite is global and
    evaluated lazily wherever it is later reduced.

    Args:
        year: Calendar year (int or int-like).

    Returns:
        ``ee.Image`` with bands ``s1_vv_mean`` and ``s1_vh_mean`` (plain mean
        of the stored band values; no further transform is applied).
    """
    # filterDate excludes its end date, so stop at 1 January of the next year
    # to keep acquisitions from 31 December.
    s1 = (
        ee.ImageCollection("COPERNICUS/S1_GRD")
        .filterDate(f"{year}-01-01", f"{int(year) + 1}-01-01")
        .filter(ee.Filter.eq("instrumentMode", "IW"))
        .filter(ee.Filter.eq("orbitProperties_pass", "DESCENDING"))
        .filter(ee.Filter.eq("resolution_meters", 10))
        .filter(ee.Filter.eq("productType", "GRD"))
        .filter(ee.Filter.listContains("transmitterReceiverPolarisation", "VV"))
        .filter(ee.Filter.listContains("transmitterReceiverPolarisation", "VH"))
    )

    vv = s1.select("VV").mean().rename("s1_vv_mean")
    vh = s1.select("VH").mean().rename("s1_vh_mean")
    return vv.addBands(vh)

# Composite built once at module load for LABEL_YEAR and read as a global by
# get_s1_metrics(). The message hard-codes "2019" while the call uses
# LABEL_YEAR, so the text goes stale if LABEL_YEAR changes.
print("\nBuilding Sentinel-1 stack for 2019…")
S1_STACK = build_s1_stack(LABEL_YEAR)
print("S1 stack ready.")

def get_s1_metrics(point, year):
    """Return mean Sentinel-1 VV and VH around a point for ``year``.

    The annual composite is averaged at 10 m scale inside a 60 m buffer
    around the point.

    Args:
        point: ``ee.Feature`` (anything exposing ``.geometry()``).
        year: Calendar year. The prebuilt ``S1_STACK`` is reused when this
            equals ``LABEL_YEAR``; any other year gets its own composite
            instead of silently reading the ``LABEL_YEAR`` one.

    Returns:
        ``(vv_mean, vh_mean)`` as floats, or ``(None, None)`` when either
        value is missing or the request fails (best-effort).
    """
    try:
        if int(year) == int(LABEL_YEAR):
            stack = S1_STACK
        else:
            stack = build_s1_stack(year)

        geom = point.geometry().buffer(60)
        stats = stack.reduceRegion(
            reducer=ee.Reducer.mean(),
            geometry=geom,
            scale=10,
            maxPixels=1e7,
            bestEffort=True,
        )

        # Evaluate the dictionary once; stats.get(...) returns lazy handles
        # that are never None on the client.
        values = stats.getInfo() or {}
        vv = values.get("s1_vv_mean")
        vh = values.get("s1_vh_mean")
        if vv is None or vh is None:
            return None, None

        return float(vv), float(vh)
    except Exception:
        return None, None

# Sentinel-2 predictors
def mask_s2_clouds(img):
    """Mask pixels whose QA60 band has bit 10 or bit 11 set.

    The code treats bit 10 as the cloud flag and bit 11 as the cirrus flag.

    Args:
        img: Sentinel-2 ``ee.Image`` containing a ``QA60`` band.

    Returns:
        The image with flagged pixels masked, carrying over all of the
        original image's properties.
    """
    qa_band = img.select("QA60")
    is_cloud = qa_band.bitwiseAnd(1 << 10).neq(0)
    is_cirrus = qa_band.bitwiseAnd(1 << 11).neq(0)

    # Keep only pixels where neither flag is raised.
    clear = is_cloud.Or(is_cirrus).Not()
    cleaned = img.updateMask(clear)
    return cleaned.copyProperties(img, img.propertyNames())

def add_s2_indices(img):
    """Append ``NDVI`` and ``RE_INDEX`` bands to a Sentinel-2 image.

    Bands B4, B8 and B8A are multiplied by 0.0001 first. Both indices have
    the form ``(a - B4) / (a + B4 + 1e-6)``, with ``a`` = B8 for ``NDVI`` and
    ``a`` = B8A for ``RE_INDEX``; the small constant avoids division by zero.

    Args:
        img: ``ee.Image`` with bands ``B4``, ``B8`` and ``B8A``.

    Returns:
        The input image with the two index bands added.
    """
    def scaled_band(band_name):
        return img.select(band_name).multiply(0.0001)

    def contrast_with_b4(band, out_name):
        difference = band.subtract(b4)
        total = band.add(b4).add(1e-6)
        return difference.divide(total).rename(out_name)

    b4 = scaled_band("B4")
    b8 = scaled_band("B8")
    b8a = scaled_band("B8A")

    index_bands = [
        contrast_with_b4(b8, "NDVI"),
        contrast_with_b4(b8a, "RE_INDEX"),
    ]
    return img.addBands(index_bands)

def get_s2_metrics(point, year):
    """Summarise Sentinel-2 NDVI / RE_INDEX around a point for May-November.

    Scenes from ``S2_COLLECTION`` that intersect the point and have
    ``CLOUDY_PIXEL_PERCENTAGE`` below 40 are cloud-masked and given index
    bands. Per-pixel mean, max and min of both indices (plus NDVI amplitude,
    max - min) are averaged at 20 m scale inside a 60 m buffer.

    Args:
        point: ``ee.Feature`` (anything exposing ``.geometry()``).
        year: Year used to build the date filter.

    Returns:
        Dict with keys ``NDVI_mean``, ``NDVI_max``, ``NDVI_min``,
        ``RE_mean``, ``RE_max``, ``RE_min`` and ``NDVI_amp`` (float or
        ``None`` each), or ``None`` when no scene matches or the request
        fails (best-effort).
    """
    metric_names = (
        "NDVI_mean",
        "NDVI_max",
        "NDVI_min",
        "RE_mean",
        "RE_max",
        "RE_min",
        "NDVI_amp",
    )
    try:
        # filterDate's end is exclusive: stop at 1 December so scenes from
        # 30 November stay inside the May-November window.
        start_date = f"{year}-05-01"
        end_date   = f"{year}-12-01"

        col = (
            ee.ImageCollection(S2_COLLECTION)
            .filterDate(start_date, end_date)
            .filterBounds(point.geometry())
            .filter(ee.Filter.lt("CLOUDY_PIXEL_PERCENTAGE", 40))
            .map(mask_s2_clouds)
            .map(add_s2_indices)
            .select(["NDVI", "RE_INDEX"])
        )

        if col.size().getInfo() == 0:
            return None

        # Each reduce() emits its bands in input order (NDVI, RE_INDEX), so
        # the rename list is grouped by reducer: mean, then max, then min.
        composite = (
            col.reduce(ee.Reducer.mean())
            .addBands(col.reduce(ee.Reducer.max()))
            .addBands(col.reduce(ee.Reducer.min()))
            .rename(
                [
                    "NDVI_mean",
                    "RE_mean",
                    "NDVI_max",
                    "RE_max",
                    "NDVI_min",
                    "RE_min",
                ]
            )
        )

        ndvi_amp = composite.select("NDVI_max").subtract(
            composite.select("NDVI_min")
        ).rename("NDVI_amp")
        composite = composite.addBands(ndvi_amp)

        geom = point.geometry().buffer(60)
        stats = composite.reduceRegion(
            reducer=ee.Reducer.mean(),
            geometry=geom,
            scale=20,
            maxPixels=1e7,
            bestEffort=True,
        )

        vals = stats.getInfo()
        if not vals:
            return None

        return {
            name: (float(vals[name]) if vals.get(name) is not None else None)
            for name in metric_names
        }
    except Exception:
        return None

# Collect one point
def collect_data_for_point(point, region_name, year, point_id, is_rice_val):
    """Assemble one table row (identifiers, label and all predictors).

    Calls every ``get_*`` predictor helper for the point. CHIRPS rainfall is
    tried first and ERA5 is used only when CHIRPS returns no annual total.

    Args:
        point: ``ee.Feature`` with a point geometry.
        region_name: Region label stored in the row.
        year: Year passed to the time-dependent helpers.
        point_id: Identifier stored in the row.
        is_rice_val: Label for the point; stored as ``int``.

    Returns:
        Dict mapping column name to value. Predictors that could not be
        computed are ``None``; the Sentinel-2 columns are always present so
        every row shares the same keys. Returns ``None`` (after printing the
        reason) when the row cannot be built at all, e.g. the coordinates
        cannot be fetched.
    """
    s2_columns = (
        "NDVI_mean",
        "NDVI_max",
        "NDVI_min",
        "RE_mean",
        "RE_max",
        "RE_min",
        "NDVI_amp",
    )
    try:
        # One round trip for both coordinates (order is [lon, lat]).
        lon, lat = point.geometry().coordinates().getInfo()
        row = {
            "point_id": point_id,
            "region": region_name,
            "year": int(year),
            "latitude": float(lat),
            "longitude": float(lon),
            "is_rice": int(is_rice_val),
        }

        mean_evi, max_evi, evi_amp = get_evi_metrics(point, year)
        row["mean_evi_growing_season"] = mean_evi
        row["max_evi_growing_season"]  = max_evi
        row["evi_amplitude"]           = evi_amp

        mean_temp, max_temp = get_lst_metrics(point, year)
        row["mean_temperature_c"] = mean_temp
        row["max_temperature_c"]  = max_temp

        annual, gs, peak = get_precipitation_metrics(point, year)
        if annual is None:
            annual, gs, peak = get_precipitation_metrics_era5(point, year)
        row["annual_rainfall_mm"]         = annual
        row["growing_season_rainfall_mm"] = gs
        row["peak_growth_rainfall_mm"]    = peak

        water_mean, water_max = get_water_occurrence(point)
        row["water_occurrence_mean_pct"] = water_mean
        row["water_occurrence_max_pct"]  = water_max

        elevation, slope = get_terrain_data(point)
        row["elevation_m"]   = elevation
        row["slope_degrees"] = slope

        soc, clay = get_soil_properties(point)
        row["soil_organic_carbon_g_kg"] = soc
        row["clay_content_pct"]         = clay

        s1_vv, s1_vh = get_s1_metrics(point, year)
        row["s1_vv_mean"] = s1_vv
        row["s1_vh_mean"] = s1_vh

        # Pre-fill the Sentinel-2 columns so a point without usable scenes
        # still yields a row with the full, consistently ordered schema.
        s2_vals = get_s2_metrics(point, year) or {}
        for column in s2_columns:
            row[column] = s2_vals.get(column)

        return row
    except Exception as exc:
        # Best-effort: drop this point, but say why instead of hiding it.
        print(f"    point {point_id} skipped: {exc}")
        return None

# Sampling
def _fetch_rice_label(feature, from_property):
    """Return the 0/1 rice label for one feature, or None if it is unavailable."""
    try:
        if from_property:
            # Stratified samples already carry the class band as a property.
            raw = feature.get("rice")
        else:
            raw = RICE_IMG.sample(
                region=feature.geometry(),
                scale=20,
                numPixels=1,
                geometries=False,
            ).first().get("rice")
        # The objects above are lazy; a missing label only shows up here,
        # either as None or as a server error.
        label = ee.Number(raw).getInfo()
    except Exception:
        return None
    return None if label is None else int(label)


def _collect_region_rows(features, n_pts, region_name, year, labels_from_property):
    """Label and extract predictors for the first n_pts features of an ee.List."""
    rows = []
    for i in range(n_pts):
        feature = ee.Feature(features.get(i))
        label = _fetch_rice_label(feature, labels_from_property)
        if label is not None:
            pid = f"{region_name}_{year}_{i}"
            row = collect_data_for_point(feature, region_name, year, pid, label)
            if row is not None:
                rows.append(row)
        # Report progress for every attempted point, including skipped ones.
        if (i + 1) % 50 == 0:
            print(f"    processed {i + 1}/{n_pts}")
    return rows


def _print_sampling_summary(df):
    """Print row counts, class balance, column list and head of the table."""
    print("\nSUMMARY (2019)")
    print(f"rows: {len(df)} | regions: {df['region'].nunique()}")
    print(
        f"rice: {int(df['is_rice'].sum())} ({df['is_rice'].mean() * 100:.1f}%) | "
        f"non-rice: {len(df) - int(df['is_rice'].sum())}"
    )

    print(f"\ncols ({len(df.columns)}):")
    print(list(df.columns))

    print("\nsample (head):")
    print(df.head())


def sample_all_features():
    """Sample labelled points in every region and save the feature table.

    For each entry of ``regions`` the candidate points come either from a
    stratified sample of ``RICE_IMG`` (``STRATIFY_BY_LABEL`` true) or from
    random points that are then labelled by sampling ``RICE_IMG``. Each
    labelled point is expanded into a row by ``collect_data_for_point``.
    Points whose label or row cannot be obtained are skipped.

    Side effects:
        Writes the table to ``OUT_CSV`` and prints progress and a summary.

    Returns:
        ``pandas.DataFrame`` with one row per successfully processed point
        (possibly empty).
    """
    print("\nTHAILAND RICE DATA (2019)")
    print("Labels: MSE Asia 2019 | Predictors: ENV + S1 + S2")
    print(f"Year: {LABEL_YEAR}")
    print(f"Regions: {list(regions.keys())}")
    if STRATIFY_BY_LABEL:
        print(f"Sampling: STRATIFIED, ~{2 * POINTS_PER_CLASS_PER_REGION}/region")
    else:
        print(f"Sampling: RANDOM {NUM_POINTS_PER_REGION}/region")

    all_rows = []
    total_attempted = 0
    total_successful = 0
    year = LABEL_YEAR

    for region_name, region_info in regions.items():
        print(f"\n== {region_name} ==")
        geom = region_geom(region_info)

        if STRATIFY_BY_LABEL:
            try:
                candidates = RICE_IMG.stratifiedSample(
                    numPoints=0,
                    classBand="rice",
                    region=geom,
                    scale=20,
                    geometries=True,
                    classValues=[0, 1],
                    classPoints=[
                        POINTS_PER_CLASS_PER_REGION,
                        POINTS_PER_CLASS_PER_REGION,
                    ],
                    seed=42,
                )
                # stratifiedSample is lazy; the request only runs (and can
                # only fail) when the size is evaluated, so keep it in the try.
                n_pts = candidates.size().getInfo()
            except Exception as e:
                print("  stratifiedSample failed:", e)
                candidates, n_pts = None, 0

            if not n_pts:
                print("  no stratified samples, skipping")
                continue
            print(f"  candidates: {n_pts}")
        else:
            candidates = ee.FeatureCollection.randomPoints(
                region=geom,
                points=NUM_POINTS_PER_REGION,
                seed=42,
            )
            n_pts = NUM_POINTS_PER_REGION

        region_rows = _collect_region_rows(
            candidates.toList(candidates.size()),
            n_pts,
            region_name,
            year,
            labels_from_property=STRATIFY_BY_LABEL,
        )
        all_rows.extend(region_rows)
        total_attempted += n_pts
        total_successful += len(region_rows)
        print(f"  done: {len(region_rows)}/{n_pts}")

    df = pd.DataFrame(all_rows)
    df.to_csv(OUT_CSV, index=False)

    print("\nDONE SAMPLING.")
    print(
        f"attempted: {total_attempted} | success: {total_successful} | "
        f"rate: {(total_successful / total_attempted * 100) if total_attempted else 0:.1f}%"
    )
    print(f"saved: {OUT_CSV}")

    if len(df) > 0:
        _print_sampling_summary(df)

    return df

# Random Forest model
def train_random_forest(df, n_iter=40, random_state=42):
    """Tune, fit and evaluate a random-forest rice classifier.

    Every numeric column except ``is_rice`` is a candidate predictor; for a
    table produced by ``sample_all_features`` that includes ``latitude`` and
    ``longitude``. Rows are split 80/20 (stratified on the label). Missing
    values are filled with medians computed on the training split only, and
    near-constant columns are dropped based on the training split, so no
    statistic of the held-out rows influences preprocessing. Hyperparameters
    are chosen by randomized search (5-fold stratified CV, ROC AUC) and the
    refit model is scored on the held-out split.

    Args:
        df: Table containing an ``is_rice`` column plus numeric predictors.
        n_iter: Number of parameter settings tried by the randomized search.
        random_state: Seed for the split, the CV folds, the search and the
            forest.

    Returns:
        ``(best_rf, feature_cols)``: the refit best estimator and the list of
        predictor column names it was trained on, in order.

    Raises:
        ValueError: If fewer than two samples are available for either class,
            which makes the stratified split and AUC scoring impossible.
    """
    print("\n=== RANDOM FOREST MODELING (ENV + S1 + S2) ===")

    df_model = df.dropna(subset=["is_rice"]).reset_index(drop=True)
    y = df_model["is_rice"].astype(int)

    # Fail early with a clear message instead of deep inside the split or
    # after a full hyperparameter search.
    class_counts = y.value_counts()
    if len(class_counts) < 2 or class_counts.min() < 2:
        raise ValueError(
            "train_random_forest needs at least two samples of each class; "
            f"got class counts {class_counts.to_dict()}"
        )

    numeric_cols = df_model.select_dtypes(include=[np.number]).columns.tolist()
    candidate_cols = [c for c in numeric_cols if c != "is_rice"]

    X_tr, X_te, y_tr, y_te = train_test_split(
        df_model[candidate_cols],
        y,
        test_size=0.2,
        random_state=random_state,
        stratify=y,
    )

    # Impute with training medians only; reusing them for the test split
    # keeps held-out data from leaking into preprocessing.
    train_medians = X_tr.median()
    X_tr = X_tr.fillna(train_medians)
    X_te = X_te.fillna(train_medians)

    # Drop near-constant columns. Columns that are entirely missing in the
    # training split have a NaN std and fail the comparison, so they go too.
    stds = X_tr.std()
    feature_cols = [c for c in candidate_cols if stds[c] > 1e-6]
    X_tr = X_tr[feature_cols]
    X_te = X_te[feature_cols]

    print("Number of features:", len(feature_cols))
    print("First 15 features:", feature_cols[:15])
    print("Train shape:", X_tr.shape, "Test shape:", X_te.shape)

    rf_base = RandomForestClassifier(
        n_estimators=400,
        max_depth=None,
        min_samples_leaf=2,
        n_jobs=-1,
        random_state=random_state,
        class_weight="balanced_subsample",
    )

    # Values here override the matching rf_base settings during the search.
    param_dist = {
        "max_depth":         [None, 12, 20, 30],
        "min_samples_split": [2, 4, 6],
        "min_samples_leaf":  [1, 2, 4],
        "max_features":      ["sqrt", 0.5, 1.0],
        "bootstrap":         [True, False],
    }

    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)

    random_search = RandomizedSearchCV(
        estimator=rf_base,
        param_distributions=param_dist,
        n_iter=n_iter,
        scoring="roc_auc",
        n_jobs=-1,
        cv=cv,
        verbose=2,
        random_state=random_state,
        refit=True,
    )

    print("\n=== Hyperparameter search (AUC) ===")
    random_search.fit(X_tr, y_tr)

    print("\nBest CV AUC:", random_search.best_score_)
    print("Best params:")
    for k, v in random_search.best_params_.items():
        print(f"  {k}: {v}")

    best_rf = random_search.best_estimator_

    y_pred  = best_rf.predict(X_te)
    # Column 1 is the probability of the positive class (label 1).
    y_proba = best_rf.predict_proba(X_te)[:, 1]

    print("\n=== Test set performance ===")
    print(classification_report(y_te, y_pred, digits=3))
    print("Confusion matrix:")
    print(confusion_matrix(y_te, y_pred))

    auc_test = roc_auc_score(y_te, y_proba)
    print(f"\nROC–AUC (test): {auc_test:.3f}")

    importances = pd.Series(best_rf.feature_importances_, index=feature_cols)
    importances = importances.sort_values(ascending=False)

    print("\nTop 20 most important features:")
    print(importances.head(20).round(4))

    return best_rf, feature_cols

# Main
if __name__ == "__main__":
    # Step 1: sample every region; this also writes OUT_CSV.
    df_all = sample_all_features()

    # Step 2: offer the CSV as a browser download when running in Colab.
    # Outside Colab the import fails and the file simply stays on disk.
    try:
        from google.colab import files
        files.download(OUT_CSV)
        print(f"\nDownloaded {OUT_CSV}")
    except Exception:
        print(f"\nFile saved: {OUT_CSV}")

    # Step 3: train only when at least one point was collected.
    if len(df_all) > 0:
        best_rf, feature_cols = train_random_forest(df_all)
