In [None]:
import xarray as xr
import numpy as np
import os
import warnings
from scipy.stats import norm

# Suppress every warning category for the whole session.
warnings.filterwarnings("ignore")

# ================= CONFIGURATION =================
# Common parent directory of the input and output folders below.
_DATA_ROOT = r'/content/drive/MyDrive/Colab Notebooks/Climate Data Processing/Data'

# Input folders and the NetCDF variable read from each of them.
RAIN_FOLDER = f'{_DATA_ROOT}/rainfall'
PET_FOLDER = f'{_DATA_ROOT}/et_short'
RAIN_VAR, PET_VAR = 'monthly_rain', 'et_short_crop'

# Glob applied inside each input folder.
FILE_PATTERN = '*.nc'

# Accumulation windows (in monthly time steps) for which SPEI is computed.
TIMESCALE_LIST = [1, 3, 6, 12]

# First and last calendar year (inclusive) written to disk.
START_YEAR_OUTPUT, END_YEAR_OUTPUT = 1991, 2023

# Destination folder for the yearly SPEI_<year>.nc files.
OUTPUT_FOLDER = f'{_DATA_ROOT}/SPEI'
# ================================================


# ---------- L-MOMENT FUNCTIONS ----------
def lmoments(data):
    """Return the sample L-moments ``(l1, l2, t3)`` of *data*.

    ``l1`` is the L-location (the mean), ``l2`` the L-scale and ``t3`` the
    L-skewness ``l3 / l2``.  They are derived from the first three
    probability-weighted moments of the ascending-sorted sample.

    Returns ``None`` when fewer than 30 values are supplied or when the
    L-scale is exactly zero (which would make ``t3`` undefined).
    """
    ordered = np.sort(data)
    n = len(ordered)
    if n < 30:
        return None

    # Zero-based rank of every order statistic; shared by both weight vectors.
    rank = np.arange(n)
    first_weight = rank / (n - 1)
    second_weight = (rank * (rank - 1)) / ((n - 1) * (n - 2))

    # Probability-weighted moments b0, b1, b2.
    b0 = np.mean(ordered)
    b1 = np.sum(first_weight * ordered) / n
    b2 = np.sum(second_weight * ordered) / n

    # Linear combinations of the PWMs give the L-moments.
    l_location = b0
    l_scale = 2 * b1 - b0
    l_third = 6 * b2 - 6 * b1 + b0

    if l_scale == 0:
        return None

    return l_location, l_scale, l_third / l_scale


def loglogistic_lmom_params(l1, l2, t3):
    """Fit a three-parameter log-logistic distribution from L-moments.

    The distribution is parameterised as

        F(x) = 1 / (1 + (scale / (x - loc)) ** c)

    whose quantile function is ``loc + scale * (F / (1 - F)) ** (1 / c)``.
    Integrating that quantile function gives, with
    ``g = (pi / c) / sin(pi / c)``::

        l1 = loc + scale * g
        l2 = scale * g / c
        t3 = 1 / c

    and inverting those relations yields the estimates returned here.

    Args:
        l1: L-location (sample mean).
        l2: L-scale.
        t3: L-skewness ``l3 / l2``; must satisfy ``0 < |t3| < 1``.

    Returns:
        Tuple ``(c, loc, scale)``.  For negative ``t3`` both ``c`` and
        ``scale`` are negative and ``loc`` is an upper bound; the CDF formula
        above remains valid and describes the mirrored distribution.

    Raises:
        ValueError: if ``t3`` is not finite, is zero, or has magnitude >= 1,
            because no log-logistic distribution has such an L-skewness.
    """
    if not np.isfinite(t3) or t3 == 0 or abs(t3) >= 1:
        raise ValueError(
            f"L-skewness must satisfy 0 < |t3| < 1 for a log-logistic fit, got {t3!r}"
        )

    c = 1.0 / t3
    # g equals Gamma(1 + 1/c) * Gamma(1 - 1/c) by the reflection formula.
    g = (np.pi / c) / np.sin(np.pi / c)
    scale = l2 * c / g
    # Same as l1 - scale * g, written without the round trip through g.
    loc = l1 - l2 * c
    return c, loc, scale


def compute_spei_ll(data):
    """Convert a 1-D series to standard-normal scores via a log-logistic fit.

    NaNs in *data* are ignored for the fit and stay NaN in the result.  The
    remaining values are fitted with ``lmoments`` and
    ``loglogistic_lmom_params``, mapped through the fitted CDF
    ``1 / (1 + (scale / (x - loc)) ** c)``, clipped to ``[1e-6, 1 - 1e-6]``
    and passed to the inverse standard-normal CDF.

    Args:
        data: 1-D array-like of floats, possibly containing NaN.

    Returns:
        ``float32`` array with the shape of *data*.  It is all-NaN when fewer
        than 30 valid values are available or when the fit is degenerate.
    """
    data = np.asarray(data)
    out = np.full(data.shape, np.nan, dtype=np.float32)

    mask = ~np.isnan(data)
    # Fit in double precision even when the input is stored as float32.
    clean = data[mask].astype(np.float64)
    if clean.size < 30:
        return out

    lm = lmoments(clean)
    if lm is None:
        return out

    try:
        c, loc, scale = loglogistic_lmom_params(*lm)
    except (ArithmeticError, ValueError):
        # Degenerate L-moments: no usable distribution for this series.
        return out
    if not np.all(np.isfinite([c, loc, scale])) or c == 0 or scale == 0:
        return out

    # Standardised distance from the location bound; the fitted support is
    # z > 0.  (scale / (x - loc)) ** c is rewritten as z ** -c.
    z = (clean - loc) / scale
    inside = z > 0

    # Outside the support a negative base raised to a fractional power is
    # NaN, so use the limit of the CDF as z -> 0+: 0 when c > 0, 1 when c < 0.
    cdf = np.full(clean.shape, 0.0 if c > 0 else 1.0)
    with np.errstate(over="ignore", divide="ignore"):
        cdf[inside] = 1.0 / (1.0 + z[inside] ** (-c))

    # Keep probabilities away from 0 and 1 so norm.ppf stays finite.
    cdf = np.clip(cdf, 1e-6, 1 - 1e-6)
    out[mask] = norm.ppf(cdf)
    return out


# ---------- MAIN PIPELINE ----------
def _open_inputs(folder):
    """Lazily open every file matching FILE_PATTERN in *folder* as one dataset."""
    return xr.open_mfdataset(
        os.path.join(folder, FILE_PATTERN),
        combine='by_coords',
        parallel=True,
        chunks={'time': -1, 'lat': 200, 'lon': 200},
    )


def _spei_by_calendar_month(series):
    """Apply compute_spei_ll along time, separately for each calendar month."""

    def _standardise(group):
        return xr.apply_ufunc(
            compute_spei_ll,
            # dask="parallelized" needs the core dimension in a single chunk.
            group.chunk({"time": -1}),
            input_core_dims=[["time"]],
            output_core_dims=[["time"]],
            vectorize=True,
            dask="parallelized",
            output_dtypes=[np.float32],
        )

    return series.groupby("time.month").map(_standardise).sortby("time")


def main():
    """Compute SPEI for every scale in TIMESCALE_LIST and write yearly files.

    Steps:
      1. Open the rainfall and PET folders lazily and align them on their
         shared coordinates.
      2. Sum both variables per month-start period and take rain minus PET.
      3. For each scale, form the rolling sum over that many time steps and
         standardise it per calendar month with ``compute_spei_ll``.
      4. Keep START_YEAR_OUTPUT..END_YEAR_OUTPUT and write one
         ``SPEI_<year>.nc`` per year into OUTPUT_FOLDER.
    """
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)

    print("1. Loading datasets...")
    # Keep the source files open until the lazy results are written, then
    # release the file handles even if something fails on the way.
    with _open_inputs(RAIN_FOLDER) as raw_rain, _open_inputs(PET_FOLDER) as raw_pet:
        ds_rain, ds_pet = xr.align(
            raw_rain.sortby("time"), raw_pet.sortby("time"), join='inner'
        )

        print("2. Monthly water balance...")
        # min_count=1 keeps periods with no valid value as NaN; the default
        # would turn them into 0.0 and feed fake zeros into the fit.
        rain = ds_rain[RAIN_VAR].resample(time="1MS").sum(min_count=1)
        pet = ds_pet[PET_VAR].resample(time="1MS").sum(min_count=1)

        diff = (rain - pet).astype("float32").chunk({'time': -1})

        ds_out = xr.Dataset(coords=diff.coords)

        for scale in TIMESCALE_LIST:
            print(f"3. Computing SPEI-{scale}...")
            rolling = diff.rolling(time=scale, min_periods=scale).sum()
            ds_out[f"spei_{scale}"] = _spei_by_calendar_month(rolling)

        print("4. Saving yearly NetCDF files...")
        ds_final = ds_out.sel(
            time=slice(f"{START_YEAR_OUTPUT}-01-01", f"{END_YEAR_OUTPUT}-12-31")
        )

        years = np.unique(ds_final.time.dt.year)
        if years.size == 0:
            print(
                f"   No data between {START_YEAR_OUTPUT} and {END_YEAR_OUTPUT}; "
                "nothing written."
            )
            return

        yearly = [ds_final.sel(time=str(year)) for year in years]
        paths = [os.path.join(OUTPUT_FOLDER, f"SPEI_{year}.nc") for year in years]
        # Every year depends on the same full-record fit.  Writing all files
        # in one dask computation evaluates that shared graph once instead of
        # once per output year.
        xr.save_mfdataset(yearly, paths)
        for year in years:
            print(f"   Saved SPEI_{year}.nc")

    print("✅ SPEI computation finished successfully.")


# Run the pipeline only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()
