In [1]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tbats import TBATS as TBATSModel
from sklearn.metrics import mean_squared_error




In [2]:
data = pd.read_csv("../data/HMIS_DATA_CORRECTED_17_21/mh_dist17_21_with_IDs_date_correction.csv")
data = data[(data['indicator_type'] == 'Total [(A+B) or (C+D)]')]
data['date'] = pd.to_datetime(data['date'])
data = data.set_index('date')
data.index = pd.DatetimeIndex(data.index)

In [7]:
def tbats_forecast_district(series, district_name, seasonal_periods=None, 
                           use_box_cox=None, use_arma_errors=True, n_jobs=1):
    """
    TBATS forecasting with data storage and plot saving
    """
    # Create directory structure
    os.makedirs('TBATS', exist_ok=True)
    
    # 1. Train-Test Split (80:20)
    train_size = int(len(series) * 0.8)
    train, test = series.iloc[:train_size], series.iloc[train_size:]

    # 2. TBATS Model Configuration
    estimator = TBATSModel(
        seasonal_periods=seasonal_periods,
        use_box_cox=use_box_cox,
        use_arma_errors=use_arma_errors,
        n_jobs=n_jobs
    )
    
    # 3. Model Fitting
    fitted_model = estimator.fit(train.values)
    
    # 4. Forecasting
    forecast = fitted_model.forecast(steps=len(test))
    
    # 5. Create forecast DataFrame
    forecast_df = pd.DataFrame({
        'district': district_name,
        'date': test.index,
        'actual': test.values,
        'forecast': forecast,
        'model_params': str(fitted_model.params)
    })

    # 6. Save forecasts to CSV
    forecast_csv_path = os.path.join('TBATS', 'tbats_forecasts.csv')
    write_header = not os.path.exists(forecast_csv_path)
    forecast_df.to_csv(forecast_csv_path, 
                      mode='a', 
                      header=write_header, 
                      index=False)

    # 7. Calculate and save metrics
    rmse = np.sqrt(mean_squared_error(test, forecast))
    metrics_df = pd.DataFrame([{
        'district': district_name,
        'rmse': rmse,
        'seasonal_periods': str(seasonal_periods),
        'box_cox_used': fitted_model.params.components.use_box_cox,
        # 'arma_used': fitted_model.params.components.use_arma
        'arma_errors_used': fitted_model.params.components.use_arma_errors

    }])

    metrics_csv_path = os.path.join('TBATS', 'tbats_results.csv')
    write_header_metrics = not os.path.exists(metrics_csv_path)
    metrics_df.to_csv(metrics_csv_path, 
                     mode='a', 
                     header=write_header_metrics, 
                     index=False)

    # 8. Save plot
    plt.figure(figsize=(12, 6))
    plt.plot(train, label='Training Data')
    plt.plot(test, label='Actual Values', color='navy')
    plt.plot(test.index, forecast, label='TBATS Forecast', color='darkorange')
    plt.title(f'TBATS Forecast for {district_name}\nRMSE: {rmse:.2f}')
    plt.legend()
    plt.savefig(os.path.join('TBATS', f'TBATS_forecast_{district_name}.png'))
    plt.close()

    return {
        'district': district_name,
        'forecast_df': forecast_df,
        'metrics_df': metrics_df,
        'model_summary': fitted_model.summary()
    }

In [8]:
districts = data['district'].unique()
rmse_values = []

def run_for_each_district():
    results = {}
    
    for district in districts:
        district_data = data[data['district'] == district]
        ts = district_data["I1"].asfreq('MS')
        
        results = tbats_forecast_district(ts, district)
    
        # Show results
        print("=== Metrics ===")
        print(results['metrics_df'])
        print("\n=== Forecast Data ===")
        print(results['forecast_df'].head())
    
    return results
run_for_each_district()

=== Metrics ===
     district        rmse seasonal_periods  box_cox_used  arma_errors_used
0  AHMEDNAGAR  557.321765             None         False             False

=== Forecast Data ===
     district       date  actual    forecast  \
0  AHMEDNAGAR 2020-06-01    6875  6955.24443   
1  AHMEDNAGAR 2020-07-01    6478  6955.24443   
2  AHMEDNAGAR 2020-08-01    5975  6955.24443   
3  AHMEDNAGAR 2020-09-01    6550  6955.24443   
4  AHMEDNAGAR 2020-10-01    6471  6955.24443   

                                        model_params  
0  <tbats.tbats.ModelParams.ModelParams object at...  
1  <tbats.tbats.ModelParams.ModelParams object at...  
2  <tbats.tbats.ModelParams.ModelParams object at...  
3  <tbats.tbats.ModelParams.ModelParams object at...  
4  <tbats.tbats.ModelParams.ModelParams object at...  
=== Metrics ===
  district        rmse seasonal_periods  box_cox_used  arma_errors_used
0    AKOLA  256.933008             None         False             False

=== Forecast Data ===
  distric

{'district': 'MUMBAI SUBURBAN',
 'forecast_df':           district       date  actual      forecast  \
 0  MUMBAI SUBURBAN 2021-01-01   13242  15034.672362   
 1  MUMBAI SUBURBAN 2021-02-01   13444  15871.799575   
 2  MUMBAI SUBURBAN 2021-03-01   13294  16708.926787   
 
                                         model_params  
 0  <tbats.tbats.ModelParams.ModelParams object at...  
 1  <tbats.tbats.ModelParams.ModelParams object at...  
 2  <tbats.tbats.ModelParams.ModelParams object at...  ,
 'metrics_df':           district         rmse seasonal_periods  box_cox_used  \
 0  MUMBAI SUBURBAN  2631.198075             None         False   
 
    arma_errors_used  
 0             False  ,
 'model_summary': 'Use Box-Cox: False\nUse trend: True\nUse damped trend: False\nSeasonal periods: []\nSeasonal harmonics []\nARMA errors (p, q): (0, 0)\nSmoothing (Alpha): -0.020100\nTrend (Beta): 0.019655\nDamping Parameter (Phi): 1.000000\nSeasonal Parameters (Gamma): []\nAR coefficients []\nMA coeffi