In [1]:
# =============================
# 1. Imports
# =============================
import os
import json
import joblib
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from pmdarima import auto_arima
import warnings

warnings.filterwarnings("ignore")

# =============================
# 2. Config
# =============================
# Machine-specific absolute locations (Windows drive E:) for inputs and outputs.
DATA_DIR = r"E:\Codes\Projects\ML\air_quality\data\cleaned_city_data"  # cleaned per-city CSVs (input)
PLOT_DIR = r"E:\Codes\Projects\ML\air_quality\visuals\plots"           # forecast plots (output)
MODEL_DIR = r"E:\Codes\Projects\ML\air_quality\models"                 # pickled models + summaries (output)

# Create the output directories up front; existing directories are left alone.
for _output_dir in (PLOT_DIR, MODEL_DIR):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir

# Cities to model, one CSV per city in DATA_DIR (grouped by initial letter).
CITIES = [
    'Ahmedabad', 'Aizawl', 'Amaravati', 'Amritsar',
    'Bengaluru', 'Bhopal', 'Brajrajnagar',
    'Chandigarh', 'Chennai', 'Coimbatore',
    'Delhi',
    'Ernakulam',
    'Gurugram', 'Guwahati',
    'Hyderabad',
    'Jaipur', 'Jorapokhar',
    'Kochi', 'Kolkata',
    'Lucknow',
    'Mumbai',
    'Patna',
    'Shillong',
    'Talcher', 'Thiruvananthapuram',
    'Visakhapatnam',
]

# =============================
# 3. Helper Functions
# =============================
def load_city_data(city, min_rows=30, data_dir=None):
    """Load the cleaned per-city CSV and return it ready for AQI modelling.

    Reads ``<data_dir>/<city with spaces replaced by '_'>.csv`` with the
    ``Date`` column parsed as datetimes and used as the index. Coerces the
    ``AQI`` column to numeric and drops rows whose AQI is missing or
    non-numeric. Sorts rows chronologically so that positional train/test
    splits downstream are time-ordered.

    Args:
        city: City name (as listed in ``CITIES``).
        min_rows: Minimum number of valid AQI rows required (default 30).
        data_dir: Directory holding the CSVs; defaults to ``DATA_DIR``.

    Returns:
        The cleaned, date-sorted DataFrame. Returns ``None`` after printing a
        warning when the file is missing or unreadable, lacks a ``Date`` or
        ``AQI`` column, or has fewer than ``min_rows`` valid AQI values.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    file_path = os.path.join(data_dir, f"{city.replace(' ', '_')}.csv")
    if not os.path.exists(file_path):
        print(f"⚠️ No file for {city}")
        return None
    try:
        df = pd.read_csv(file_path, parse_dates=['Date'], index_col='Date')
    except ValueError as e:
        # Missing 'Date' column, empty file, or malformed CSV (pandas'
        # EmptyDataError/ParserError subclass ValueError). Skip this city,
        # matching the warn-and-return-None contract above.
        print(f"⚠️ Could not read data for {city}: {e}")
        return None
    if 'AQI' not in df.columns:
        print(f"⚠️ No AQI column for {city}")
        return None
    # Non-numeric AQI entries become NaN and are dropped with the missing ones.
    df['AQI'] = pd.to_numeric(df['AQI'], errors='coerce')
    df = df.dropna(subset=['AQI'])
    # Callers split train/test by row position, which is only a valid
    # time-series split when rows are in date order. A stable sort keeps
    # the file order of any duplicate dates.
    df = df.sort_index(kind='mergesort')
    if len(df) < min_rows:
        print(f"⚠️ Not enough data for {city}")
        return None
    return df

def train_and_save_model(city, df, train_frac=0.8):
    """Fit Auto-ARIMA on the head of ``df['AQI']``, score it on the tail, and persist it.

    The series is split by position. The first ``int(len(df) * train_frac)``
    rows train a non-seasonal ``auto_arima`` model. The remaining rows form the
    hold-out set, which is forecast in one multi-step ``predict`` call and
    scored by RMSE.

    Side effects (spaces in ``city`` are replaced by ``_`` in file names):
        * ``<MODEL_DIR>/<city>_AutoARIMA.pkl`` -- the fitted model (joblib).
        * ``<PLOT_DIR>/<city>_plot.png`` -- training / actual / predicted plot.

    Args:
        city: City name, used for file names and the plot title.
        df: DataFrame with a numeric ``AQI`` column, indexed by date; assumed
            to be sorted chronologically.
        train_frac: Fraction of rows used for training (default 0.8).

    Returns:
        ``(rmse, order, model_filename)``: hold-out RMSE, the selected ARIMA
        ``(p, d, q)`` order, and the model file's base name.

    Raises:
        ValueError: if the split leaves the training or hold-out set empty.
    """
    train_size = int(len(df) * train_frac)
    if train_size < 1 or train_size >= len(df):
        raise ValueError(
            f"train/test split of {len(df)} rows at train_frac={train_frac} "
            "leaves an empty training or test set"
        )
    train, test = df.iloc[:train_size], df.iloc[train_size:]

    model = auto_arima(
        train['AQI'], start_p=1, start_q=1,
        max_p=5, max_q=5, seasonal=False, d=None,
        trace=False, error_action='ignore', suppress_warnings=True, stepwise=True
    )

    # Multi-step forecast over the whole hold-out horizon (no re-fitting).
    forecast = np.asarray(model.predict(n_periods=len(test)))
    rmse = np.sqrt(mean_squared_error(test['AQI'], forecast))

    safe_name = city.replace(' ', '_')

    # NOTE(review): the persisted model has only seen the training split, so
    # forecasts made from it start at the end of training, not at the latest
    # date in the data. Confirm this is intended for downstream use.
    model_filename = f"{safe_name}_AutoARIMA.pkl"
    joblib.dump(model, os.path.join(MODEL_DIR, model_filename))

    fig, ax = plt.subplots(figsize=(14, 6))
    try:
        ax.plot(train.index, train['AQI'], label='Training')
        ax.plot(test.index, test['AQI'], label='Actual', color='blue')
        ax.plot(test.index, forecast, label='Predicted', color='red')
        ax.set_title(f"{city} AQI Forecast (Order: {model.order}, RMSE: {rmse:.2f})")
        ax.set_xlabel("Date")
        ax.set_ylabel("AQI")
        ax.legend()
        fig.tight_layout()
        fig.savefig(os.path.join(PLOT_DIR, f"{safe_name}_plot.png"))
    finally:
        # Release the figure even if plotting/saving fails, so figures do not
        # accumulate across the per-city loop.
        plt.close(fig)

    return rmse, model.order, model_filename

# =============================
# 4. Main Loop
# =============================
# Per-city rows for the RMSE comparison table, and the city -> model metadata
# mapping written to city_model_map.json.
results = []
city_model_map = {}

for city in CITIES:
    print(f"\n=== Processing {city} ===")

    # Loading happens inside its own try: an unreadable or malformed CSV must
    # skip this city rather than abort the loop before the summary is written.
    try:
        df = load_city_data(city)
    except Exception as e:
        print(f"❌ Error loading {city}: {e}")
        continue
    if df is None:
        # load_city_data already printed why this city was skipped.
        continue

    try:
        rmse, order, filename = train_and_save_model(city, df)
        results.append({"City": city, "RMSE": rmse, "Order": order})
        city_model_map[city] = {
            "model_file": filename,
            "order": order,
            "rmse": rmse
        }
        print(f"✅ RMSE: {rmse:.2f} | Order: {order}")
    except Exception as e:
        # Top-level boundary: report the failure and continue with the next city.
        print(f"❌ Error training {city}: {e}")

# =============================
# 5. Save Summary Files
# =============================
# Explicit columns keep an "RMSE" column to sort on (and a header in the CSV)
# even when no city trained successfully and `results` is empty. Without them,
# sort_values(by="RMSE") raises KeyError.
results_df = pd.DataFrame(results, columns=["City", "RMSE", "Order"]).sort_values(by="RMSE")
results_df.to_csv(os.path.join(MODEL_DIR, "AutoARIMA_RMSE_Comparison.csv"), index=False)

# Tuples such as the ARIMA order are serialized as JSON lists.
with open(os.path.join(MODEL_DIR, "city_model_map.json"), "w", encoding="utf-8") as f:
    json.dump(city_model_map, f, indent=4)

print("\n=== Model Performance Summary ===")
if results_df.empty:
    print("⚠️ No city model was trained successfully.")
print(results_df)
print(f"\n💾 Models saved in: {MODEL_DIR}")
print("📄 City-to-model mapping saved as: city_model_map.json")



=== Processing Ahmedabad ===
✅ RMSE: 262.79 | Order: (1, 1, 2)

=== Processing Aizawl ===
✅ RMSE: 5.44 | Order: (0, 1, 0)

=== Processing Amaravati ===
✅ RMSE: 68.80 | Order: (2, 1, 1)

=== Processing Amritsar ===
✅ RMSE: 121.77 | Order: (2, 1, 1)

=== Processing Bengaluru ===
✅ RMSE: 34.41 | Order: (1, 1, 1)

=== Processing Bhopal ===
✅ RMSE: 33.21 | Order: (2, 1, 1)

=== Processing Brajrajnagar ===
✅ RMSE: 61.54 | Order: (5, 1, 1)

=== Processing Chandigarh ===
✅ RMSE: 26.41 | Order: (4, 1, 1)

=== Processing Chennai ===
✅ RMSE: 36.97 | Order: (1, 1, 1)

=== Processing Coimbatore ===
✅ RMSE: 35.83 | Order: (4, 0, 2)

=== Processing Delhi ===
✅ RMSE: 114.07 | Order: (1, 1, 2)

=== Processing Ernakulam ===
✅ RMSE: 8.23 | Order: (0, 1, 0)

=== Processing Gurugram ===
✅ RMSE: 130.26 | Order: (3, 1, 3)

=== Processing Guwahati ===
✅ RMSE: 142.90 | Order: (3, 1, 1)

=== Processing Hyderabad ===
✅ RMSE: 46.64 | Order: (0, 1, 3)

=== Processing Jaipur ===
✅ RMSE: 51.94 | Order: (2, 1, 1)

=