In [None]:
import glob
import os
import re
from datetime import datetime

import ee
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rioxarray
import xarray as xr
from climate_indices import compute, indices
from rasterstats import zonal_stats

In [None]:
# ==========================================
# PART 1: SETUP & CONFIGURATION
# ==========================================

# Connect to Earth Engine; if initialization raises, run the authentication
# flow once and try again.
try:
    ee.Initialize()
except Exception:
    ee.Authenticate()
    ee.Initialize()

# Historical record used for SPI calibration.
START_DATE_HISTORICAL = '1981-01-01'
END_DATE_HISTORICAL = '2025-01-01'

# Window that is plotted and exported to the CSV.
ANALYSIS_START_DATE = '2022-05-01'
ANALYSIS_END_DATE = '2024-11-01'

# Locations for downloaded rasters, figures, and the tabular output.
DATA_DIR = 'data/chirps_tif'
PLOT_DIR_SPI1 = 'figures/spi_1'
PLOT_DIR_SPI3 = 'figures/spi_3'
OUTPUT_CSV = 'data/north_kivu_spi.csv'

# Create every output folder up front so later cells can write without checks.
for output_dir in (DATA_DIR, PLOT_DIR_SPI1, PLOT_DIR_SPI3, 'data'):
    os.makedirs(output_dir, exist_ok=True)

# Earth Engine asset holding the polygons used for clipping and zonal stats.
ASSET_PATH = 'projects/ee-zuruyuyu/assets/airesanteNK'
north_kivu_fc = ee.FeatureCollection(ASSET_PATH)

In [None]:
# ==========================================
# PART 2: DATA DOWNLOAD (EARTH ENGINE)
# ==========================================

def download_chirps_data():
    """Build monthly CHIRPS precipitation totals and download them as GeoTIFFs.

    Daily CHIRPS images between ``START_DATE_HISTORICAL`` (inclusive) and
    ``END_DATE_HISTORICAL`` (exclusive) are summed per calendar month, clipped
    to ``north_kivu_fc`` and written to ``DATA_DIR`` through geemap.

    The year range and the expected number of files are derived from the two
    date constants, so changing the constants keeps the aggregation and the
    "already downloaded" check in sync. The download is skipped only when
    ``DATA_DIR`` already holds at least one file per expected month.

    Assumes both date constants fall on 1 January, i.e. the record consists of
    whole calendar years.
    """
    # Imported at call time, as in the rest of this notebook.
    import geemap

    print("Preparing CHIRPS monthly data...")

    first_year = pd.Timestamp(START_DATE_HISTORICAL).year
    # filterDate treats the end date as exclusive, so the last covered year is
    # the one containing the day before END_DATE_HISTORICAL.
    last_year = (pd.Timestamp(END_DATE_HISTORICAL) - pd.Timedelta(days=1)).year
    expected_files = (last_year - first_year + 1) * 12

    chirps = (
        ee.ImageCollection("UCSB-CHG/CHIRPS/DAILY")
        .filterDate(START_DATE_HISTORICAL, END_DATE_HISTORICAL)
        .select('precipitation')
    )

    def monthly_sum(imageCollection):
        """Return an ImageCollection with one summed image per (year, month)."""
        months = ee.List.sequence(1, 12)

        def by_month(year):
            def sum_month(month):
                start = ee.Date.fromYMD(year, month, 1)
                end = start.advance(1, 'month')
                monthly_image = imageCollection.filterDate(start, end).sum()
                return (
                    monthly_image
                    .set('year', year)
                    .set('month', month)
                    .set('system:time_start', start.millis())
                )
            return months.map(sum_month)

        years = ee.List.sequence(first_year, last_year)
        return ee.ImageCollection.fromImages(years.map(by_month).flatten())

    chirps_monthly = monthly_sum(chirps).map(lambda img: img.clip(north_kivu_fc))

    existing_files = len(glob.glob(f'{DATA_DIR}/*.tif'))
    if existing_files < expected_files:
        print("Downloading GeoTIFFs... (This avoids Drive exports for immediate use)")
        geemap.download_ee_image_collection(
            chirps_monthly,
            out_dir=DATA_DIR,
            scale=5000,
            region=north_kivu_fc.geometry(),
            crs='EPSG:4326'
        )
    else:
        print("Files appear to be already downloaded.")

In [None]:
# ==========================================
# PART 3: LOCAL PROCESSING (SPI & STATS)
# ==========================================

def load_raster_stack():
    """Load the GeoTIFFs in ``DATA_DIR`` into one ``(time, y, x)`` precipitation cube.

    Files are ordered with a natural (number-aware) sort so that names such as
    ``2.tif`` come before ``10.tif``; a plain lexicographic sort would scramble
    the chronological order for numerically named files. Each raster is reduced
    to its first band so the stacked result is 3-D. The time coordinate is
    generated as consecutive month starts beginning at
    ``START_DATE_HISTORICAL``; it is not read from the files.

    Returns:
        xarray.DataArray named ``'precip'`` with ``-9999`` replaced by NaN.

    Raises:
        ValueError: if no raster could be opened.
    """
    print("Loading Raster Stack...")

    def _natural_key(path):
        # Split the file name into alternating text / digit runs; digit runs
        # compare as integers so "10" sorts after "9".
        name = os.path.basename(path)
        return [int(tok) if tok.isdigit() else tok.lower()
                for tok in re.split(r'(\d+)', name)]

    tiff_files = sorted(glob.glob(f'{DATA_DIR}/*.tif'), key=_natural_key)
    datasets = []
    skipped = []

    for tiff_file in tiff_files:
        try:
            ds = rioxarray.open_rasterio(tiff_file)
        except Exception as e:
            print(f"Skipping {tiff_file}: {e}")
            skipped.append(tiff_file)
            continue
        # open_rasterio yields (band, y, x); keep only the first band so that
        # the concatenated cube is (time, y, x).
        if 'band' in ds.dims:
            ds = ds.isel(band=0, drop=True)
        datasets.append(ds)

    if not datasets:
        raise ValueError("No datasets loaded.")

    if skipped:
        # The time axis below is positional, so a missing month shifts every
        # later date; make that visible instead of failing silently.
        print(f"WARNING: {len(skipped)} file(s) skipped; "
              "the generated monthly time axis may be misaligned.")

    monthly_precip = xr.concat(datasets, dim='time')
    monthly_precip.name = 'precip'

    dates = pd.date_range(start=START_DATE_HISTORICAL, periods=len(datasets), freq='MS')
    monthly_precip = monthly_precip.assign_coords(time=dates)

    # -9999 is treated as the missing-data marker.
    monthly_precip = monthly_precip.where(monthly_precip != -9999, np.nan)

    return monthly_precip

def calculate_spi(da, scale, distribution, data_start_year, calibration_year_final):
    """Compute SPI independently for every pixel of a monthly precipitation cube.

    Args:
        da: xarray.DataArray indexed as ``[time, row, col]``. A singleton
            ``band`` dimension, if present, is dropped first.
        scale: forwarded to ``indices.spi`` as ``scale``.
        distribution: forwarded to ``indices.spi`` as ``distribution``.
        data_start_year: forwarded as both ``data_start_year`` and
            ``calibration_year_initial``.
        calibration_year_final: forwarded as ``calibration_year_final``.

    Returns:
        xarray.DataArray named ``spi_<scale>`` (float32) with the same coords
        and dims as the 3-D input. Pixels with fewer than 30 non-NaN values,
        and pixels for which ``indices.spi`` raised, are left as NaN.

    Raises:
        ValueError: if the input is not 3-D after dropping a singleton band.
    """
    print(f"Calculating SPI-{scale}...")

    # Rasters opened with rioxarray carry a leading band axis; with it in place
    # the [:, i, j] indexing below would address the wrong axes.
    if 'band' in da.dims and da.sizes['band'] == 1:
        da = da.squeeze('band', drop=True)
    if da.ndim != 3:
        raise ValueError(
            f"calculate_spi expects a 3-D (time, y, x) array, got dims {da.dims}"
        )

    min_valid_values = 30  # minimum non-NaN samples required to attempt a fit
    values = da.values
    spi_values = np.full(values.shape, np.nan, dtype=np.float32)

    failed_pixels = 0
    first_error = None

    for i, j in np.ndindex(values.shape[1], values.shape[2]):
        series = values[:, i, j]

        if np.count_nonzero(~np.isnan(series)) < min_valid_values:
            continue

        try:
            spi_values[:, i, j] = indices.spi(
                values=series,
                scale=scale,
                distribution=distribution,
                data_start_year=data_start_year,
                calibration_year_initial=data_start_year,
                calibration_year_final=calibration_year_final,
                periodicity=compute.Periodicity.monthly
            )
        except Exception as exc:
            # Best effort: keep going, but remember that something went wrong.
            failed_pixels += 1
            if first_error is None:
                first_error = exc

    if failed_pixels:
        print(f"WARNING: SPI-{scale} failed for {failed_pixels} pixel(s); "
              f"first error: {first_error!r}")

    return xr.DataArray(
        spi_values,
        coords=da.coords,
        dims=da.dims,
        name=f'spi_{scale}'
    )

def plot_maps(da, output_folder, title_prefix, start_date, end_date):
    """Save one PNG map per time step of ``da`` within ``[start_date, end_date]``.

    Each figure uses the 'RdBu' colormap clamped to [-3, 3], is titled
    ``"<title_prefix> - <Month Year>"`` and is written to ``output_folder`` as
    ``spi_<YYYY-MM-DD>.png``.
    """
    print(f"Generating plots for {title_prefix}...")
    window = da.sel(time=slice(start_date, end_date))

    for stamp in window.time.values:
        when = pd.to_datetime(stamp)
        file_name = f"spi_{when.strftime('%Y-%m-%d')}.png"
        heading = f"{title_prefix} - {when.strftime('%B %Y')}"

        fig = plt.figure(figsize=(10, 8))
        window.sel(time=stamp).plot(cmap='RdBu', vmin=-3, vmax=3)
        plt.title(heading)
        fig.tight_layout()
        fig.savefig(os.path.join(output_folder, file_name))
        # Close explicitly so figures do not accumulate across the loop.
        plt.close(fig)

def run_zonal_stats_logic(spi1_da, spi3_da, polygons):
    """Average SPI-1 and SPI-3 over each polygon for every month in the analysis window.

    Args:
        spi1_da: SPI-1 DataArray with a ``time`` coordinate.
        spi3_da: SPI-3 DataArray with a ``time`` coordinate; the affine
            transform of ``spi1_da`` is used for both, so the two are assumed
            to share a grid.
        polygons: GeoDataFrame with a ``name`` column; it is reprojected to the
            raster CRS before the statistics are computed.

    Returns:
        pandas.DataFrame with one row per (polygon, month) holding the mean
        SPI values and 0/1 flags for values below -1.5 (drought) or above 1.5
        (wet). A missing mean is stored as NaN and yields 0 for both flags.
    """
    print("Computing Zonal Statistics...")
    results = []

    def _as_2d(arr):
        # A raster that still carries a leading singleton band axis is reduced
        # to the 2-D grid zonal_stats expects.
        if arr.ndim == 3 and arr.shape[0] == 1:
            return arr[0]
        return arr

    def _mean_or_nan(stat):
        value = stat['mean']
        return np.nan if value is None else value

    common_times = np.intersect1d(spi1_da.time.values, spi3_da.time.values)
    target_times = common_times[(common_times >= pd.to_datetime(ANALYSIS_START_DATE)) &
                                (common_times <= pd.to_datetime(ANALYSIS_END_DATE))]

    # Attach a CRS on local copies rather than mutating the caller's arrays.
    if spi1_da.rio.crs is None:
        spi1_da = spi1_da.rio.write_crs("EPSG:4326")
    if spi3_da.rio.crs is None:
        spi3_da = spi3_da.rio.write_crs("EPSG:4326")
    polygons = polygons.to_crs(spi1_da.rio.crs)
    transform = spi1_da.rio.transform()

    for t in target_times:
        date_str = pd.to_datetime(t).strftime('%Y/%m')

        arr_spi1 = _as_2d(spi1_da.sel(time=t).values)
        arr_spi3 = _as_2d(spi3_da.sel(time=t).values)

        zs_1 = zonal_stats(polygons, arr_spi1, affine=transform, stats=['mean'], nodata=np.nan)
        zs_3 = zonal_stats(polygons, arr_spi3, affine=transform, stats=['mean'], nodata=np.nan)

        # zonal_stats returns one entry per feature in order, so pair by
        # position; DataFrame index labels are not guaranteed to be 0..n-1.
        for name, stat_1, stat_3 in zip(polygons['name'], zs_1, zs_3):
            val_1 = _mean_or_nan(stat_1)
            val_3 = _mean_or_nan(stat_3)

            results.append({
                'catchment area name': name,
                'month and year': date_str,
                'SPI-1 value': val_1,
                'drought-1': 1 if val_1 < -1.5 else 0,
                'wet-1': 1 if val_1 > 1.5 else 0,
                'SPI-3 value': val_3,
                'drought-3': 1 if val_3 < -1.5 else 0,
                'wet-3': 1 if val_3 > 1.5 else 0
            })

    return pd.DataFrame(results)

In [None]:
# ==========================================
# MAIN EXECUTION
# ==========================================

if __name__ == "__main__":
    # End-to-end run: download rasters, compute SPI-1 / SPI-3, plot the
    # analysis window, and export per-polygon statistics to CSV.
    download_chirps_data()

    precip_da = load_raster_stack()

    spi1 = calculate_spi(precip_da, 1, indices.Distribution.gamma, 1981, 2024)
    plot_maps(spi1, PLOT_DIR_SPI1, "SPI 1-Month", ANALYSIS_START_DATE, ANALYSIS_END_DATE)

    # indices.spi accumulates precipitation over `scale` months itself, so it
    # must receive the raw monthly series. Feeding it a pre-rolled 3-month sum
    # would accumulate twice, and dropping the leading NaN months would shift
    # the series relative to data_start_year.
    spi3 = calculate_spi(precip_da, 3, indices.Distribution.gamma, 1981, 2024)
    plot_maps(spi3, PLOT_DIR_SPI3, "SPI 3-Month", ANALYSIS_START_DATE, ANALYSIS_END_DATE)

    # Prefer the polygons from Earth Engine; fall back to a local copy and say
    # why, so a failing asset fetch is not hidden.
    try:
        import geemap
        gdf = geemap.ee_to_gdf(north_kivu_fc)
    except Exception as exc:
        print(f"Could not load polygons from Earth Engine ({exc!r}); "
              "falling back to data/airesanteNK.geojson")
        gdf = gpd.read_file('data/airesanteNK.geojson')

    final_df = run_zonal_stats_logic(spi1, spi3, gdf)

    final_df.to_csv(OUTPUT_CSV, index=False)
    print(f"Processing complete. Saved to {OUTPUT_CSV}")