# In [None]:
import ee
import geopandas as gpd
import pandas as pd
import numpy as np

# =================================================
# CONFIG
# =================================================
YEAR = 2021
START_DATE = f"{YEAR}-01-01"
# Earth Engine's filterDate() treats the end date as EXCLUSIVE, so the end
# bound is the first day AFTER the period; "YYYY-12-31" would drop 31 Dec.
END_DATE = f"{YEAR + 1}-01-01"

# Reference period for calculating the 90th percentile (typically 30 years).
# Using 1991-2020 as baseline (adjust based on data availability); the end
# bound is exclusive for the same reason as END_DATE above.
BASELINE_START = "1991-01-01"
BASELINE_END = "2021-01-01"

ASSAM_GEOJSON = "assam_district_2024-11.geojson"

# Attribute columns expected in the district GeoJSON.
DISTRICT_FIELD = "dtname"
OBJECTID_FIELD = "object_id"

OUT_CSV = f"assam_district_heat_temp_summary_{YEAR}.csv"

# =================================================
# INITIALISE EARTH ENGINE
# =================================================
ee.Initialize(project="ee-stephensmathew07")

# =================================================
# LOAD DISTRICTS
# =================================================
# Read the district polygons, then reproject them to EPSG:4326 (lon/lat);
# shapely_to_ee below passes these coordinates to Earth Engine unchanged.
gdf = gpd.read_file(ASSAM_GEOJSON)
gdf = gdf.to_crs(epsg=4326)

# Fail fast when the attribute columns used later are absent.
missing = [col for col in (DISTRICT_FIELD, OBJECTID_FIELD) if col not in gdf.columns]
if missing:
    raise ValueError(f"❌ Missing required columns: {missing}")

# =================================================
# SHAPELY → EE GEOMETRY
# =================================================
def shapely_to_ee(geometry):
    if geometry.geom_type == "Polygon":
        return ee.Geometry.Polygon(list(geometry.exterior.coords))
    elif geometry.geom_type == "MultiPolygon":
        return ee.Geometry.MultiPolygon(
            [list(p.exterior.coords) for p in geometry.geoms]
        )
    else:
        return None

# =================================================
# CALCULATE 90TH PERCENTILE THRESHOLD (IMD METHOD)
# =================================================
# NOTE(review): confirm that a plain 90th percentile of daily Tmax matches
# the IMD heatwave definition this header refers to.
print("Calculating 90th percentile thresholds for each district...")

# Baseline-period daily maximum 2 m temperature from ERA5-Land daily
# aggregates (still in Kelvin here; converted to °C below).
era5_baseline = ee.ImageCollection("ECMWF/ERA5_LAND/DAILY_AGGR")
era5_baseline = era5_baseline.filterDate(BASELINE_START, BASELINE_END)
era5_baseline = era5_baseline.select("temperature_2m_max")

def kelvin_to_celsius(img):
    """Return ``img`` with every band shifted from Kelvin to degrees Celsius.

    273.15 is subtracted from all bands, and all of the source image's
    properties are copied onto the result.
    """
    celsius = img.subtract(273.15)
    return celsius.copyProperties(img, img.propertyNames())

era5_baseline = era5_baseline.map(kelvin_to_celsius)

# Per-pixel 90th percentile of daily Tmax (°C) across the whole baseline.
# The image is identical for every district, so build it once here; each
# reduceRegion below only averages it over one district polygon.
# NOTE(review): all days of all seasons are pooled into one distribution
# (no calendar-day window) - confirm this is the intended threshold.
baseline_p90_img = era5_baseline.reduce(ee.Reducer.percentile([90]))

# District name (upper-cased) -> threshold in °C. A repeated district name
# overwrites the earlier entry.
district_thresholds = {}

for _, row in gdf.iterrows():
    ee_geom = shapely_to_ee(row.geometry)
    if ee_geom is None:
        continue

    district_name = str(row[DISTRICT_FIELD]).upper()

    try:
        percentile_90 = baseline_p90_img.reduceRegion(
            reducer=ee.Reducer.mean(),
            geometry=ee_geom,
            scale=10000,
            maxPixels=1e9,
        ).getInfo()
    except Exception as e:
        # Best effort: a failing district gets the default instead of
        # aborting the whole run.
        print(f"⚠️ Error calculating threshold for {district_name}: {e}")
        district_thresholds[district_name] = 35.0
        continue

    # The percentile reducer names its output "<input band>_p90".
    threshold = percentile_90.get("temperature_2m_max_p90")

    if threshold is None:
        print(f"⚠️ Could not calculate threshold for {district_name}, using default 35°C")
        district_thresholds[district_name] = 35.0
    else:
        district_thresholds[district_name] = round(threshold, 2)
        print(f"  {district_name}: {threshold:.2f}°C")

# =================================================
# ERA5-LAND COLLECTION FOR TARGET YEAR
# =================================================
# Daily aggregate images from START_DATE up to (not including) END_DATE,
# since filterDate's end bound is exclusive.
era5 = ee.ImageCollection("ECMWF/ERA5_LAND/DAILY_AGGR").filterDate(
    START_DATE, END_DATE
)

def add_temperature_bands(img):
    """Append daily Tmax/Tmin in degrees Celsius to an ERA5-Land image.

    The new bands are ``tmax_c`` and ``tmin_c``: ``temperature_2m_max`` and
    ``temperature_2m_min`` minus 273.15. All original bands are kept.
    """
    def _to_celsius(band, new_name):
        return img.select(band).subtract(273.15).rename(new_name)

    return img.addBands([
        _to_celsius("temperature_2m_max", "tmax_c"),
        _to_celsius("temperature_2m_min", "tmin_c"),
    ])

era5 = era5.map(add_temperature_bands)

# =================================================
# AGGREGATE TEMPERATURE IMAGES
# =================================================
# Per-pixel summaries of the daily °C bands over the target year:
#   tmax_mean / tmin_mean - average daily Tmax / Tmin
#   tmax_abs  / tmin_abs  - highest daily Tmax / lowest daily Tmin
tmax_mean_img, tmax_abs_img = (
    era5.select("tmax_c").mean().rename("tmax_mean"),
    era5.select("tmax_c").max().rename("tmax_abs"),
)
tmin_mean_img, tmin_abs_img = (
    era5.select("tmin_c").mean().rename("tmin_mean"),
    era5.select("tmin_c").min().rename("tmin_abs"),
)

# Stack into one 4-band image so a single reduction yields every statistic.
temp_img = ee.Image.cat(
    [tmax_mean_img, tmin_mean_img, tmax_abs_img, tmin_abs_img]
)

# =================================================
# ZONAL STATS (DISTRICT-SPECIFIC HEATWAVE CALCULATION)
# =================================================
def _round_or_none(value, ndigits=None):
    """Return ``round(value, ndigits)``, or ``None`` when ``value`` is ``None``.

    A reduced band with no valid pixels comes back as ``None`` (or, presumably,
    with no key at all). Passing that through leaves an empty CSV cell
    instead of a fabricated 0 °C / 0 days.
    """
    return None if value is None else round(value, ndigits)


def _heatwave_days_image(collection, threshold):
    """Return a per-pixel count of images whose ``tmax_c`` is >= ``threshold``.

    The result has a single band named ``heatwave_day``.
    NOTE(review): every day at/above the threshold counts on its own; no run
    of consecutive hot days is required - confirm this is the intended
    heatwave definition.
    """
    return (
        collection
        .map(lambda img: img.select("tmax_c").gte(threshold).rename("heatwave_day"))
        .sum()
    )


print(f"\nCalculating statistics for year {YEAR}...")
results = []

for _, row in gdf.iterrows():
    ee_geom = shapely_to_ee(row.geometry)
    if ee_geom is None:
        continue

    district_name = str(row[DISTRICT_FIELD]).upper()
    # Fall back to 35 °C when no baseline threshold was recorded for this name.
    threshold = district_thresholds.get(district_name, 35.0)

    try:
        # Temperature statistics and the heat-day count are reduced together,
        # so each district costs a single Earth Engine request.
        stats = (
            temp_img
            .addBands(_heatwave_days_image(era5, threshold))
            .reduceRegion(
                reducer=ee.Reducer.mean(),
                geometry=ee_geom,
                scale=10000,
                maxPixels=1e9,
            )
            .getInfo()
        )
    except Exception as e:
        # Best effort: report the failure and continue with the next district.
        print(f"⚠️ Skipped {district_name}: {e}")
        continue

    heatwave_days = _round_or_none(stats.get("heatwave_day"))

    results.append({
        "OBJECT_ID": row[OBJECTID_FIELD],
        "DISTRICT": district_name,
        "THRESHOLD_90P_C": threshold,
        "TMAX_MEAN_C": _round_or_none(stats.get("tmax_mean"), 2),
        "TMIN_MEAN_C": _round_or_none(stats.get("tmin_mean"), 2),
        "TMAX_ABS_C": _round_or_none(stats.get("tmax_abs"), 2),
        "TMIN_ABS_C": _round_or_none(stats.get("tmin_abs"), 2),
        "HEATWAVE_DAYS": heatwave_days,
    })

    days_label = "n/a" if heatwave_days is None else heatwave_days
    print(f"✓ {district_name}: {days_label} heatwave days (threshold: {threshold}°C)")

# =================================================
# SAVE CSV
# =================================================
df = pd.DataFrame(results)

if df.empty:
    raise RuntimeError("❌ No climate statistics computed.")

# Alphabetical by district; the pandas row index is not written to the file.
df.sort_values("DISTRICT", inplace=True)
df.to_csv(OUT_CSV, index=False)

print(
    "\n".join([
        f"\n✅ Temperature + heatwave summary saved: {OUT_CSV}",
        "\nSummary:",
        f"  Districts processed: {len(df)}",
        f"  Total heatwave days (all districts): {df['HEATWAVE_DAYS'].sum():.0f}",
        f"  Average threshold (90th percentile): {df['THRESHOLD_90P_C'].mean():.2f}°C",
    ])
)