In [None]:
from pathlib import Path

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as st

In [None]:
# Load the precipitation / evapotranspiration record indexed by date
# (presumably monthly, given the "window in months" convention below — confirm).
# Forward slashes keep the path portable (they work on Windows too) and avoid
# the invalid "\c" / "\P" escape sequences of a backslash path literal.
# Sorting the index matters: the rolling SPEI sums assume chronological order.
data = pd.read_csv(
    "data/csv/Prec Evapo 1993-2023.csv",
    parse_dates=["Date"],
    index_col="Date",
).sort_index()
assert {"precipitation", "evapotranspiration"}.issubset(data.columns), (
    "input CSV is missing the 'precipitation' and/or 'evapotranspiration' column"
)

In [None]:
def calculate_spei(precipitation, pet, window, min_periods=None):
    """
    Calculate the Standardized Precipitation-Evapotranspiration Index (SPEI).

    The climatic water balance D = P - PET is accumulated over a rolling
    window, a Pearson type III distribution is fitted to the accumulated
    values by the method of moments, and the fitted CDF is mapped onto the
    standard normal distribution (z-scores).

    Parameters:
    - precipitation: pandas Series, precipitation data.
    - pet: pandas Series, potential evapotranspiration data (assumed to share
      the units and index of ``precipitation``).
    - window: int, rolling window size in months.
    - min_periods: int or None, minimum number of observations required for
      a rolling sum. Defaults to ``window`` so that incomplete leading windows
      are NaN rather than partial sums on a different scale.

    Returns:
    - spei: pandas Series of float SPEI values indexed like ``precipitation``.
      Values are NaN where the rolling sum is undefined, or everywhere when the
      distribution cannot be fitted (fewer than 3 sums, or no variation).

    Notes:
    - Pearson III is used instead of a gamma distribution because the water
      balance is routinely negative, and the gamma CDF is 0 on negative
      values (which maps to SPEI = -inf).
    - A single distribution is fitted to the whole record. The usual SPEI
      procedure fits each calendar month separately, and the original
      formulation uses a log-logistic distribution. TODO: consider both.
    """
    if min_periods is None:
        min_periods = window

    # Climatic water balance; unlike precipitation alone it can be negative.
    balance = precipitation - pet
    accumulated = balance.rolling(window=window, min_periods=min_periods).sum()

    sums = accumulated.to_numpy(dtype=float)
    has_sum = ~np.isnan(sums)
    values = sums[has_sum]
    result = np.full(sums.shape, np.nan)

    # Three moments need at least 3 points and a non-zero spread.
    if values.size >= 3 and np.ptp(values) > 0:
        # scipy's pearson3(skew, loc, scale) has mean=loc and std=scale, so the
        # method-of-moments fit is simply the sample moments.
        skew = st.skew(values, bias=False)
        cdf = st.pearson3.cdf(
            values, skew, loc=values.mean(), scale=values.std(ddof=1)
        )
        # Keep probabilities strictly inside (0, 1) so norm.ppf stays finite.
        eps = 1e-6
        result[has_sum] = st.norm.ppf(np.clip(cdf, eps, 1.0 - eps))

    spei = pd.Series(result, index=accumulated.index)
    if not spei.index.equals(precipitation.index):
        spei = spei.reindex(precipitation.index)
    return spei

In [None]:
# Accumulation windows (in months); each one adds an "spei_<window>" column.
times = [3, 6, 9, 12]
for window in times:
    data[f"spei_{window}"] = calculate_spei(
        precipitation=data["precipitation"],
        pet=data["evapotranspiration"],
        window=window,
    )

In [None]:
# One bar panel per accumulation window, stacked vertically on a shared date
# axis. Bars are blue where SPEI > 0 and red otherwise (NaN compares False,
# but NaN-height bars are not drawn).
fig, axes = plt.subplots(
    nrows=len(times),
    ncols=1,
    figsize=(15, 2.5 * len(times)),
    sharex=True,  # only the bottom panel gets x tick labels
    squeeze=False,  # keep an array even when `times` has a single window
)
axes = axes[:, 0]
fig.subplots_adjust(hspace=0.15)

for ax, window in zip(axes, times):
    spei = data[f"spei_{window}"]
    ax.bar(
        data.index,
        spei,
        width=25,  # in days on a date axis
        align="center",
        color=np.where(spei > 0, "b", "r"),
        label=f"SPEI {window}",
    )
    ax.axhline(y=0, color="k")
    ax.legend(loc="upper right")
    ax.set_yticks(range(-3, 4))
    ax.set_ylabel("SPEI", fontsize=12)

# Shared x-axes share one tick locator/formatter, so configuring the bottom
# axis applies to every panel.
axes[-1].xaxis.set_major_locator(mdates.YearLocator(2))
axes[-1].xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
axes[-1].set_xlabel("Year", fontsize=12)
fig.suptitle(
    "SPEI by accumulation window (months): " + ", ".join(map(str, times))
)

In [None]:
# Save through the Figure object. By default the inline backend closes figures
# at the end of the cell that drew them, so plt.savefig here would act on a
# fresh, empty figure and write a blank image.
Path("images").mkdir(parents=True, exist_ok=True)
fig.savefig("images/spei_plot.png", dpi=300, bbox_inches="tight")