In [None]:
import xarray as xr
import numpy as np
from scipy import stats
import os
import glob
import warnings

# Suppress only numeric RuntimeWarnings (divide-by-zero, log-of-zero, invalid
# value) that can arise during Gamma fitting. A blanket 'ignore' would also
# hide unrelated warnings from xarray / dask / netCDF I/O.
warnings.filterwarnings('ignore', category=RuntimeWarning)

# ================= CONFIGURATION =================
# Update these variables to match your data
INPUT_FOLDER = 'mon_rainfall'      # Folder containing the input .nc files
OUTPUT_FOLDER = 'drought_indices'  # Created if missing; receives SPI_<year>.nc files
VAR_NAME = 'monthly_rain'          # Name of the precipitation variable inside the .nc files
FILE_PATTERN = '*.nc'              # Glob pattern, relative to INPUT_FOLDER (e.g. 1990.nc, 1991.nc)
TIMESCALE_LIST = [1, 3, 6, 12]     # SPI accumulation periods, in months
START_YEAR_OUTPUT = 1991           # First year written to disk (inclusive)
END_YEAR_OUTPUT = 2023             # Last year written to disk (inclusive)
# =================================================

def compute_spi_gamma(data_array):
    """
    Compute the Standardized Precipitation Index (SPI) for one series.

    Intended for a 1D numpy array holding one pixel's values for one
    calendar month across all years (e.g. every January). That is how
    ``main`` calls it, through ``groupby('time.month')`` + ``apply_ufunc``.

    Zeros are handled with the mixed distribution
        H(x) = q + (1 - q) * G(x)
    where q is the fraction of zero values and G is a Gamma CDF (location
    fixed at 0) fitted to the strictly positive values. H(x) is then mapped
    to a standard-normal quantile.

    Parameters
    ----------
    data_array : numpy.ndarray
        1D series of accumulated precipitation; NaN marks missing values.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``data_array``. Positions that are
        NaN in the input stay NaN. The whole result is NaN when SPI cannot be
        estimated:
        - fewer than 3 valid values;
        - fewer than 2 distinct positive values (e.g. an all-zero record);
        - a failed Gamma fit.
        Finite results are bounded to about +/-4.26 by the probability clip.
    """
    # Keep probabilities away from 0 and 1, where norm.ppf returns -inf / +inf.
    # The clip is symmetric so neither tail can produce an infinite SPI.
    eps = 1e-5

    values = np.asarray(data_array, dtype=float)
    valid = ~np.isnan(values)
    result = np.full(values.shape, np.nan)

    clean_data = values[valid]
    if clean_data.size < 3:
        # Not enough data to fit
        return result

    # Gamma is undefined at 0, so zeros enter only through their probability q.
    non_zeros = clean_data[clean_data > 0]
    q = np.count_nonzero(clean_data == 0) / clean_data.size

    # With no positive values (q == 1), the only available answer would be
    # norm.ppf(1) == +inf, i.e. "extremely wet" for a completely dry record.
    # Identical positive values carry no spread, so no Gamma shape can be
    # estimated either. In both cases SPI is undefined: leave NaN.
    if np.unique(non_zeros).size < 2:
        return result

    # Fit Gamma to the positive values: alpha (shape), loc (fixed 0), beta (scale).
    try:
        alpha, loc, beta = stats.gamma.fit(non_zeros, floc=0)
    except (ValueError, RuntimeError):
        # Fit failures must not abort the whole (dask-parallel) computation;
        # mark this pixel/month as undefined instead.
        return result
    if not (np.isfinite(alpha) and np.isfinite(beta)):
        return result

    # Combined cumulative probability, including the point mass at zero.
    y_gamma = stats.gamma.cdf(clean_data, alpha, loc=loc, scale=beta)
    prob = np.clip(q + (1 - q) * y_gamma, eps, 1 - eps)

    # Convert probability to a Z-score (SPI), keeping NaN where input was NaN.
    result[valid] = stats.norm.ppf(prob)
    return result

def _spi_for_month_group(group):
    """Apply compute_spi_gamma along 'time' for every pixel of one calendar-month group."""
    return xr.apply_ufunc(
        compute_spi_gamma,
        group,
        input_core_dims=[['time']],
        output_core_dims=[['time']],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[float],
        # After groupby, a group's time axis can be spread over several chunks;
        # allow_rechunk lets dask merge them into the single chunk that a core
        # dimension requires.
        dask_gufunc_kwargs={'allow_rechunk': True},
    )


def _compute_spi_dataset(precip_monthly):
    """Return a lazy Dataset with one ``spi_<scale>`` variable per TIMESCALE_LIST entry."""
    ds_spi = xr.Dataset()
    ds_spi.coords.update(precip_monthly.coords)

    for scale in TIMESCALE_LIST:
        print(f"3. Processing SPI-{scale}...")

        # A. Trailing `scale`-month sum; the first scale-1 months are NaN.
        rolling_sum = precip_monthly.rolling(time=scale, center=False, min_periods=scale).sum()

        # B. Fit each calendar month separately (all Januaries, all Februaries, ...),
        #    then restore chronological order.
        spi_da = rolling_sum.groupby('time.month').map(_spi_for_month_group)
        ds_spi[f'spi_{scale}'] = spi_da.sortby('time')

    return ds_spi


def _write_yearly_files(ds_spi):
    """Write the configured output years of ``ds_spi`` to OUTPUT_FOLDER, one file per year."""
    ds_final = ds_spi.sel(time=slice(f'{START_YEAR_OUTPUT}-01-01', f'{END_YEAR_OUTPUT}-12-31'))

    datasets, paths = [], []
    for year in np.unique(ds_final['time.year']):
        out_filename = f"SPI_{year}.nc"
        print(f"  - Queued {out_filename}")
        datasets.append(ds_final.sel(time=str(year)))
        paths.append(os.path.join(OUTPUT_FOLDER, out_filename))

    if not datasets:
        print(f"  No data between {START_YEAR_OUTPUT} and {END_YEAR_OUTPUT}; nothing written.")
        return

    # Write every year in ONE dask computation. Each month's SPI depends on the
    # whole record (it is fitted across all years). Calling to_netcdf once per
    # year would therefore redo every Gamma fit for every output file.
    print(f"  Writing {len(paths)} files...")
    xr.save_mfdataset(datasets, paths)


def main():
    """
    Run the SPI pipeline end to end.

    1. Load all files matching FILE_PATTERN from INPUT_FOLDER.
    2. Compute SPI for every scale in TIMESCALE_LIST.
    3. Write SPI_<year>.nc files for START_YEAR_OUTPUT..END_YEAR_OUTPUT into
       OUTPUT_FOLDER, which is created if missing.

    Raises
    ------
    ValueError
        If VAR_NAME is not a variable of the loaded dataset.
    """
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)

    print("1. Loading Data...")
    # The context manager closes the underlying files. All lazy computation,
    # including the writes, therefore has to happen inside this block.
    with xr.open_mfdataset(os.path.join(INPUT_FOLDER, FILE_PATTERN),
                           combine='by_coords',
                           parallel=True,
                           chunks={'time': -1, 'lat': 50, 'lon': 50}) as raw:
        ds = raw.sortby('time')

        if VAR_NAME not in ds:
            raise ValueError(f"Variable '{VAR_NAME}' not found. Found: {list(ds.data_vars)}")

        print("2. Resampling to Monthly Sums...")
        # skipna=False: any missing value inside a month makes that month NaN.
        precip_monthly = ds[VAR_NAME].resample(time='1MS').sum(skipna=False)

        # apply_ufunc with 'time' as a core dimension needs the full time
        # history in a single chunk.
        precip_monthly = precip_monthly.chunk({'time': -1})

        ds_spi = _compute_spi_dataset(precip_monthly)

        print("4. Saving Year-wise Files...")
        _write_yearly_files(ds_spi)

    print(f"Done! Files saved in '{OUTPUT_FOLDER}' folder.")

# Entry point: run the pipeline only when executed directly, not when imported.
if __name__ == "__main__":
    main()

1. Loading Data...
2. Resampling to Monthly Sums...
3. Processing SPI-1...
3. Processing SPI-3...
3. Processing SPI-6...
3. Processing SPI-12...
4. Saving Year-wise Files...
  - Writing SPI_1991.nc...
  - Writing SPI_1992.nc...
  - Writing SPI_1993.nc...
  - Writing SPI_1994.nc...
  - Writing SPI_1995.nc...
  - Writing SPI_1996.nc...
  - Writing SPI_1997.nc...
  - Writing SPI_1998.nc...
  - Writing SPI_1999.nc...
  - Writing SPI_2000.nc...
  - Writing SPI_2001.nc...
  - Writing SPI_2002.nc...
  - Writing SPI_2003.nc...
  - Writing SPI_2004.nc...
  - Writing SPI_2005.nc...
  - Writing SPI_2006.nc...
  - Writing SPI_2007.nc...
  - Writing SPI_2008.nc...
  - Writing SPI_2009.nc...
  - Writing SPI_2010.nc...
  - Writing SPI_2011.nc...
  - Writing SPI_2012.nc...
  - Writing SPI_2013.nc...
  - Writing SPI_2014.nc...
  - Writing SPI_2015.nc...
  - Writing SPI_2016.nc...
  - Writing SPI_2017.nc...
  - Writing SPI_2018.nc...
  - Writing SPI_2019.nc...
  - Writing SPI_2020.nc...
  - Writing SPI_