In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pmdarima import auto_arima
from statsmodels.tsa.statespace.sarimax import SARIMAX
from statsmodels.tsa.stattools import adfuller
from sklearn.metrics import mean_squared_error
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf

In [2]:
data = pd.read_csv("../data/HMIS_DATA_CORRECTED_17_21/mh_dist17_21_with_IDs_date_correction.csv")
data = data[(data['indicator_type'] == 'Total [(A+B) or (C+D)]')]
data['date'] = pd.to_datetime(data['date'])
data = data.set_index('date')
data.index = pd.DatetimeIndex(data.index)

In [3]:
def sarima_forecast_district(series, district_name, seasonal=True,
                            max_p=2, max_q=2, max_P=1, max_Q=1, m=12):
    """
    Robust SARIMA forecasting with error handling and stability features
    """
    # Create directory structure
    os.makedirs('SARIMA', exist_ok=True)
    
    # 0. Data Validation
    if len(series) < 24:  # Minimum 2 years for monthly data
        print(f"Insufficient data ({len(series)} points) for {district_name}")
        return None

    # 1. Enhanced Stationarity Analysis
    adf_result = adfuller(series.dropna())
    is_stationary = adf_result[1] < 0.05
    
    # Diagnostic plots
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
    plot_acf(series.dropna(), ax=ax1)
    plot_pacf(series.dropna(), ax=ax2, method='ywm')
    plt.savefig(os.path.join('SARIMA', f'ACF_PACF_{district_name}.png'))
    plt.close()

    # 2. Train-Test Split (80:20)
    train_size = int(len(series) * 0.8)
    train, test = series.iloc[:train_size], series.iloc[train_size:]

    # 3. Robust Auto-SARIMA Modeling
    try:
        model = auto_arima(
            train,
            start_p=0, d='auto',  # Automatic differencing
            start_q=0, D='auto',  # Auto seasonal differencing
            max_p=max_p, max_q=max_q,
            max_P=max_P, max_Q=max_Q,
            m=m,
            seasonal=seasonal,
            information_criterion='bic',
            trace=True,  # Enable tracing
            error_action='warn',
            suppress_warnings=False,
            stepwise=True,
            with_intercept=True
        )
    except Exception as e:
        print(f"Auto ARIMA failed for {district_name}: {str(e)}")
        return None

    # 4. Model Fitting with Stability Features
    try:
        sarima_model = SARIMAX(
            train,
            order=model.order,
            seasonal_order=model.seasonal_order,
            enforce_stationarity=True,
            enforce_invertibility=True,
            initialization='approximate_diffuse'
        )
        model_fit = sarima_model.fit(disp=True, maxiter=200)
    except LinAlgError:
        print(f"Retrying with simplified model for {district_name}")
        # Fallback configuration
        model = auto_arima(
            train,
            d=1, D=1,
            max_p=1, max_q=1,
            max_P=0, max_Q=0,
            seasonal=seasonal,
            stepwise=True
        )
        sarima_model = SARIMAX(
            train,
            order=model.order,
            seasonal_order=model.seasonal_order
        )
        model_fit = sarima_model.fit(disp=False)

    # 5. Forecasting with Confidence Intervals
    forecast_result = model_fit.get_forecast(steps=len(test))
    forecast = forecast_result.predicted_mean
    conf_int = forecast_result.conf_int()

    # 6. Create forecast DataFrame
    forecast_df = pd.DataFrame({
        'district': district_name,
        'date': test.index,
        'actual': test.values,
        'forecast': forecast.values,
        'lower_ci': conf_int.iloc[:, 0],
        'upper_ci': conf_int.iloc[:, 1]
    })

    # 7. Save forecasts to CSV (append mode)
    forecast_csv_path = os.path.join('SARIMA', 'sarima_forecasts.csv')
    write_header = not os.path.exists(forecast_csv_path)
    forecast_df.to_csv(forecast_csv_path, 
                      mode='a', 
                      header=write_header, 
                      index=False)

    # 8. Calculate and save metrics
    rmse = np.sqrt(mean_squared_error(test, forecast))
    metrics_df = pd.DataFrame([{
        'district': district_name,
        'rmse': rmse,
        'best_order': str(model.order),
        'best_seasonal_order': str(model.seasonal_order),
        'stationary': is_stationary,
        'adf_pvalue': adf_result[1],
        'n_obs': len(series)
    }])

    metrics_csv_path = os.path.join('SARIMA', 'sarima_results.csv')
    write_header_metrics = not os.path.exists(metrics_csv_path)
    metrics_df.to_csv(metrics_csv_path, 
                     mode='a', 
                     header=write_header_metrics, 
                     index=False)

    # 9. Save plot with enhanced visualization
    plt.figure(figsize=(14, 7))
    plt.plot(train, label='Training Data', color='#1f77b4')
    plt.plot(test, label='Actual Values', color='#2ca02c', linewidth=2)
    plt.plot(forecast, label='Forecast', color='#ff7f0e', linestyle='--')
    plt.fill_between(forecast.index, 
                    conf_int.iloc[:, 0],
                    conf_int.iloc[:, 1],
                    color='#ff7f0e', alpha=0.15)
    plt.title(f'SARIMA{model.order}x{model.seasonal_order} Forecast for {district_name}\nRMSE: {rmse:.2f} | ADF p-value: {adf_result[1]:.3f}')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join('SARIMA', f'SARIMA_forecast_{district_name}.png'), dpi=300)
    plt.close()

    return {
        'district': district_name,
        'forecast_df': forecast_df,
        'metrics_df': metrics_df,
        'model_summary': model_fit.summary()
    }


In [4]:
districts = data['district'].unique()
rmse_values = []

def run_for_each_district():
    results = {}
    
    for district in districts:
        district_data = data[data['district'] == district]
        # ts = district_data["I8"].asfreq('MS')
        ts = district_data["I40"].astype(float).asfreq('MS')

        
        results = sarima_forecast_district(ts, district)
    
        # Show results
        # print("=== Metrics ===")
        # print(results['metrics_df'])
        # print("\n=== Forecast Data ===")
        # print(results['forecast_df'].head())
    
    return results
run_for_each_district()

Auto ARIMA failed for AHMEDNAGAR: '<' not supported between instances of 'str' and 'int'
Auto ARIMA failed for AKOLA: '<' not supported between instances of 'str' and 'int'
Auto ARIMA failed for AMRAVATI: '<' not supported between instances of 'str' and 'int'
Auto ARIMA failed for AURANGABAD: '<' not supported between instances of 'str' and 'int'
Auto ARIMA failed for BEED: '<' not supported between instances of 'str' and 'int'
Auto ARIMA failed for BHANDARA: '<' not supported between instances of 'str' and 'int'
Auto ARIMA failed for BULDHANA: '<' not supported between instances of 'str' and 'int'
Auto ARIMA failed for CHANDRAPUR: '<' not supported between instances of 'str' and 'int'
Auto ARIMA failed for DHULE: '<' not supported between instances of 'str' and 'int'
Auto ARIMA failed for GADCHIROLI: '<' not supported between instances of 'str' and 'int'
Auto ARIMA failed for GONDIA: '<' not supported between instances of 'str' and 'int'
Auto ARIMA failed for HINGOLI: '<' not supporte