In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.tsa.statespace.sarimax import SARIMAX
from statsmodels.tsa.holtwinters import ExponentialSmoothing
from statsmodels.tsa.stattools import adfuller
from statsmodels.tsa.api import VAR
import pmdarima as pm
from prophet import Prophet
from statsmodels.graphics.tsaplots import plot_acf
from statsmodels.tsa.seasonal import seasonal_decompose
from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler
from xgboost import XGBRegressor
from sklearn.model_selection import GridSearchCV, TimeSeriesSplit
import os
import logging
from datetime import datetime
import warnings
from tqdm import tqdm
import seaborn as sns
from joblib import Parallel, delayed
import time
import matplotlib.dates as mdates

# NOTE(review): per tqdm's documentation a monitor interval of 0 should turn off
# its background monitor thread — confirm against the installed tqdm version.
tqdm.monitor_interval = 0

# Output directory for the log file, plots and CSV results; created at import time.
img_dir = 'classical_model_results'
os.makedirs(img_dir, exist_ok=True)

# All log records (INFO and above) go to a text file inside the output directory.
logging.basicConfig(
    filename=f'{img_dir}/classical_models_log.txt',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Silences every Python warning for the whole process, including library deprecations.
warnings.filterwarnings("ignore")

# Central settings read by the functions below.
CONFIG = {
    'forecast_horizon': 12,        # number of periods forecast; also the size of the hold-out set in main()
    'seasonal_periods': 12,        # seasonal period passed to the seasonal models and the decomposition
    'min_data_length': 24,         # create_features() rejects frames shorter than this
    'img_dir': img_dir,            # directory the plotting helpers save into
    'results_file': f'{img_dir}/classical_model_results.csv',  # metrics table written by main()
    'n_jobs': -1,                  # passed to joblib.Parallel in main()
    'outlier_threshold': 3,        # default absolute z-score cut-off in detect_outliers()
    'max_diff': 3,                 # highest differencing order tried by check_stationarity()
    'lag_features': list(range(1, 13)),  # lags (1..12) used to build shifted columns
    'rolling_windows': [3, 6, 12],       # window sizes for rolling mean / std columns
    'correlation_threshold': 0.15,       # features with |corr to target| below this are dropped
}

def detect_outliers(series, threshold=CONFIG['outlier_threshold']):
    """Return a copy of ``series`` with z-score outliers replaced by the mean.

    A point is flagged when the absolute value of ``(x - mean) / std`` is
    strictly greater than ``threshold``; mean and std are taken over the whole
    input. The input itself is not modified. The number of flagged points is
    written to the module logger.
    """
    center = series.mean()
    is_outlier = np.abs((series - center) / series.std()) > threshold
    cleaned = series.copy()
    # The replacement value is the mean of the original series, outliers included.
    cleaned[is_outlier] = center
    logger.info(f"Detected {is_outlier.sum()} outliers in series")
    return cleaned

def validate_input_data(df, required_columns):
    """Validate ``df`` and fill gaps in ``required_columns`` in place.

    Checks, in order: all required columns exist, the index has no duplicates,
    the index is monotonically increasing, and every required column is
    numeric. Missing values in the required columns are then forward-filled,
    with any remaining gaps (e.g. leading NaN) filled by the column mean; the
    result is written back into ``df``. Finally the required columns must not
    contain infinite values.

    Args:
        df: DataFrame to validate; modified in place when values are filled.
        required_columns: list of column names that must be present and numeric.

    Returns:
        None.

    Raises:
        ValueError: if any of the checks above fails.
    """
    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        # Report only the columns that are actually absent.
        raise ValueError(f"Missing required columns: {missing}")
    if df.index.duplicated().any():
        raise ValueError("Index contains duplicates!")
    if not df.index.is_monotonic_increasing:
        raise ValueError("Index is not monotonically increasing!")
    # Check dtypes before filling: computing a mean on a non-numeric column
    # would otherwise fail with an unrelated error before this check is reached.
    # Booleans are rejected, matching the previous np.number-based check.
    for dtype in df[required_columns].dtypes:
        if not pd.api.types.is_numeric_dtype(dtype) or pd.api.types.is_bool_dtype(dtype):
            raise ValueError("Some columns are not numeric!")
    null_counts = df[required_columns].isnull().sum()
    if null_counts.any():
        logger.warning(f"Missing values in data: {null_counts.to_dict()}")
        # .ffill() replaces the deprecated fillna(method='ffill').
        df[required_columns] = df[required_columns].ffill().fillna(df[required_columns].mean())
    # Also trips when a column is still NaN after filling (i.e. it was entirely NaN).
    if df[required_columns].replace([np.inf, -np.inf], np.nan).isnull().sum().any():
        raise ValueError("Data contains infinite values!")

def check_stationarity(series, name):
    """Pick a differencing order for ``series`` using repeated ADF tests.

    NaN and infinite values are removed first. Starting at d=0, the ADF test is
    run and the series is differenced once more each round until the p-value
    falls below 0.05, ``CONFIG['max_diff']`` is reached, or fewer than two
    observations remain. ``name`` is only used to label log messages.

    Returns:
        The chosen differencing order (0 when the cleaned series is too short).
    """
    max_diff = CONFIG['max_diff']
    working = series.dropna().replace([np.inf, -np.inf], np.nan).dropna()
    if len(working) < 2:
        logger.error(f"{name}: Data too short after cleaning!")
        return 0
    for d in range(max_diff + 1):
        statistic, p_value = adfuller(working)[:2]
        logger.info(f"ADF Test for {name} (d={d}): Statistic={statistic:.4f}, p-value={p_value:.4f}")
        if p_value < 0.05:
            logger.info(f"{name} stationary at differencing order d={d}")
            return d
        if d == max_diff:
            # Give up at the configured limit and use it as the order.
            logger.warning(f"{name} not stationary after {max_diff} differencing. Using d={d}.")
            return d
        working = working.diff().dropna()
        if len(working) < 2:
            # Not enough points left to test the next order; keep the current one.
            logger.warning(f"{name}: Data too short after differencing {d+1}!")
            return d
    # Only reached when max_diff is negative and the loop never runs.
    return 0

def create_features(
    data: pd.DataFrame,
    target: str,
    lags: list = CONFIG['lag_features'],
    rolling_windows: list = CONFIG['rolling_windows'],
    seasonal_features: bool = True,
    fill_method: str = 'ffill'
) -> pd.DataFrame:
    """Build lag, rolling and calendar features for ``target`` and ``cpi_yoy``.

    For both ``target`` and the exogenous column ``cpi_yoy`` the function adds
    shifted copies (one per entry in ``lags``) and rolling mean / std columns
    (one pair per entry in ``rolling_windows``). With ``seasonal_features`` it
    also adds month dummy columns plus sine / cosine encodings of the month; a
    ``quarter`` column is always added. Remaining NaN are filled with
    ``fill_method`` and then the column mean, and finally every feature whose
    absolute correlation with ``target`` is below
    ``CONFIG['correlation_threshold']`` is dropped.

    Args:
        data: frame with a sorted-or-sortable, duplicate-free DatetimeIndex
            containing ``target`` and ``cpi_yoy``. Not modified.
        target: name of the column to forecast.
        lags: shift distances for the lag columns.
        rolling_windows: window sizes for the rolling statistics.
        seasonal_features: whether to add month dummies and sin/cos encodings.
        fill_method: ``'ffill'``/``'pad'`` or ``'bfill'``/``'backfill'``.

    Returns:
        A new DataFrame with the original and engineered columns.

    Raises:
        ValueError: on too-short data, a non-datetime or duplicated index,
            missing columns, an unsupported ``fill_method``, an all-NaN column,
            or NaN that survive filling.
    """
    logger.info(f"Creating features for {target}, data shape: {data.shape}")
    if len(data) < CONFIG['min_data_length']:
        raise ValueError(f"Data too short: {len(data)} rows")
    if not isinstance(data.index, pd.DatetimeIndex):
        raise ValueError("DataFrame index must be DatetimeIndex")
    if data.index.duplicated().any():
        raise ValueError("Index contains duplicates")
    if not data.index.is_monotonic_increasing:
        data = data.sort_index()
    exog_var = 'cpi_yoy'
    required_cols = [target, exog_var]
    if not all(col in data.columns for col in required_cols):
        raise ValueError(f"Missing columns: {required_cols}")
    df = data.copy()
    if df[required_cols].isna().any().any() or np.isinf(df[required_cols]).any().any():
        logger.warning(f"Data contains NaN or Inf in {required_cols}, filling...")
        # fillna() leaves +/-inf untouched, so convert them to NaN first;
        # otherwise the "filling" above would silently keep infinite values.
        base = df[required_cols].replace([np.inf, -np.inf], np.nan)
        df[required_cols] = base.ffill().fillna(base.mean())
    # NOTE(review): rolling() windows end at the current row, so the
    # `{target}_roll_*` columns include the current target value — confirm this
    # is intended before using them as regressors.
    for col in required_cols:
        for lag in lags:
            df[f'{col}_lag_{lag}'] = df[col].shift(lag)
        for window in rolling_windows:
            df[f'{col}_roll_mean_{window}'] = df[col].rolling(window=window).mean()
            df[f'{col}_roll_std_{window}'] = df[col].rolling(window=window).std()
    if seasonal_features:
        df['month'] = df.index.month
        df = pd.get_dummies(df, columns=['month'], prefix='month')
        period = CONFIG['seasonal_periods']
        df['month_sin'] = np.sin(2 * np.pi * df.index.month / period)
        df['month_cos'] = np.cos(2 * np.pi * df.index.month / period)
        logger.info("Added seasonal features")
    df['quarter'] = df.index.quarter
    for col in df.columns:
        if df[col].isna().all():
            raise ValueError(f"Column {col} contains only NaN")
        if df[col].isnull().any():
            # Explicit ffill()/bfill() replace the deprecated fillna(method=...).
            if fill_method in ('ffill', 'pad'):
                filled = df[col].ffill()
            elif fill_method in ('bfill', 'backfill'):
                filled = df[col].bfill()
            else:
                raise ValueError(f"Unsupported fill_method: {fill_method!r}")
            # Gaps the directional fill cannot reach (e.g. the first rows of a
            # lag column under ffill) fall back to the column mean.
            df[col] = filled.fillna(df[col].mean())
            logger.info(f"Filled NaN in {col} with {fill_method} and mean")
    if df.isnull().any().any():
        raise ValueError(f"Data still contains NaN: {df.isnull().sum().to_dict()}")
    correlations = df.corr()[target].drop(required_cols, errors='ignore')
    low_corr_cols = correlations[abs(correlations) < CONFIG['correlation_threshold']].index
    # Test emptiness by length; Index.any() evaluates the truthiness of the labels.
    if len(low_corr_cols) > 0:
        logger.info(f"Dropping low-correlation features: {list(low_corr_cols)}")
        df = df.drop(columns=low_corr_cols)
    logger.info(f"Features after processing: {list(df.columns)}")
    return df

def calculate_metrics(actual, predicted):
    """Compute forecast accuracy metrics on the finite pairs of two sequences.

    Both inputs are converted to float arrays and compared position by
    position (any pandas index is ignored). Pairs where either value is NaN or
    infinite are dropped before the metrics are computed.

    Args:
        actual: observed values.
        predicted: forecast values, same length as ``actual``.

    Returns:
        Tuple ``(rmse, mae, mape, smape, norm_mape, directional_acc)``:
        - ``mape`` is in percent and NaN when any actual value is ~0;
        - ``smape`` is in percent; a pair where both values are 0 counts as
          zero error instead of producing NaN;
        - ``norm_mape`` is ``mape`` divided by the mean absolute actual value;
        - ``directional_acc`` is the percentage of consecutive steps where
          actual and predicted move in the same direction (NaN with fewer than
          two points).
        All six are NaN when no valid pair remains.
    """
    actual = np.array(actual, dtype=float)
    predicted = np.array(predicted, dtype=float)
    valid_mask = np.isfinite(actual) & np.isfinite(predicted)
    actual = actual[valid_mask]
    predicted = predicted[valid_mask]
    if len(actual) == 0:
        logger.warning("No valid data for metrics calculation!")
        return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
    rmse = np.sqrt(mean_squared_error(actual, predicted))
    mae = mean_absolute_error(actual, predicted)
    mape = mean_absolute_percentage_error(actual, predicted) * 100 if np.all(np.abs(actual) > 1e-8) else np.nan
    # sMAPE: guard the 0/0 case (actual == predicted == 0), which previously
    # turned the whole metric into NaN; such a pair is a perfect prediction.
    denominator = np.abs(actual) + np.abs(predicted)
    ratios = np.divide(
        2 * np.abs(predicted - actual),
        denominator,
        out=np.zeros_like(denominator),
        where=denominator > 0,
    )
    smape = 100 * np.mean(ratios)
    norm_mape = mape / np.mean(np.abs(actual)) if not np.isnan(mape) else np.nan
    directional_acc = np.mean((np.diff(actual) * np.diff(predicted)) > 0) * 100 if len(actual) > 1 else np.nan
    return rmse, mae, mape, smape, norm_mape, directional_acc

def plot_decomposition(series, period, filename):
    """Save a four-panel additive seasonal decomposition plot of ``series``.

    The panels show the original series and the trend, seasonal and residual
    components returned by ``seasonal_decompose``. The image is written to
    ``CONFIG['img_dir']/filename``; a failure while saving is logged, not raised.
    """
    parts = seasonal_decompose(series, period=period, model='additive')
    panels = (
        (411, series, 'Original'),
        (412, parts.trend, 'Trend'),
        (413, parts.seasonal, 'Seasonal'),
        (414, parts.resid, 'Residual'),
    )
    plt.figure(figsize=(12, 8))
    for position, values, label in panels:
        plt.subplot(position)
        plt.plot(series.index, values, label=label)
        plt.legend(loc='upper left')
    plt.tight_layout()
    try:
        plt.savefig(os.path.join(CONFIG['img_dir'], filename), dpi=300)
        logger.info(f"Saved decomposition plot: {filename}")
    except Exception as e:
        logger.error(f"Error saving decomposition plot: {str(e)}")
    plt.close()

def plot_forecast(historical, test, forecast, forecast_index, title, ylabel, filename, confidence_intervals=None):
    """Plot history, held-out actuals and one forecast, then save the figure.

    ``confidence_intervals``, when truthy, is indexed as ``(lower, upper)`` and
    drawn as a shaded band labelled '95% CI'. The image is written to
    ``CONFIG['img_dir']/filename``; a failure while saving is logged, not raised.
    """
    plt.figure(figsize=(12, 6))
    line_specs = (
        (historical.index, historical, {'label': 'Historical', 'color': 'blue'}),
        (test.index, test, {'label': 'Actual (Test)', 'color': 'green'}),
        (forecast_index, forecast,
         {'label': 'Forecast', 'color': 'orange', 'linestyle': '--', 'linewidth': 2}),
    )
    for x_values, y_values, style in line_specs:
        plt.plot(x_values, y_values, **style)
    if confidence_intervals:
        lower_band = confidence_intervals[0]
        upper_band = confidence_intervals[1]
        plt.fill_between(forecast_index, lower_band, upper_band, color='orange', alpha=0.2, label='95% CI')
    plt.title(title)
    plt.xlabel('Time')
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)
    # Year-month tick labels, one major tick every third month.
    plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
    plt.gca().xaxis.set_major_locator(mdates.MonthLocator(interval=3))
    plt.xticks(rotation=45)
    plt.tight_layout()
    try:
        plt.savefig(os.path.join(CONFIG['img_dir'], filename), dpi=300)
        logger.info(f"Saved plot: {filename}")
    except Exception as e:
        logger.error(f"Error saving plot: {str(e)}")
    plt.close()

def plot_comparison_forecasts(historical, test, forecasts, forecast_index, title, ylabel, filename, metrics=None):
    """Overlay several model forecasts on history and actuals, then save.

    ``forecasts`` maps model name to forecast values; ``metrics`` maps model
    name to a dict holding at least ``'RMSE'``. A model is drawn only when its
    forecast is not None and its RMSE is available and not NaN, so with
    ``metrics=None`` no forecast line is drawn. The image is written to
    ``CONFIG['img_dir']/filename``; a failure while saving is logged, not raised.
    """
    plt.figure(figsize=(14, 8))
    plt.plot(historical.index, historical, label='Historical', color='blue')
    plt.plot(test.index, test, label='Actual (Test)', color='green')
    # One colour per entry, assigned before filtering so skipped models keep their slot.
    palette = sns.color_palette("husl", len(forecasts))
    for (model_name, forecast), color in zip(forecasts.items(), palette):
        if metrics:
            rmse = metrics.get(model_name, {}).get('RMSE', np.nan)
        else:
            rmse = np.nan
        if forecast is not None and not pd.isna(rmse):
            plt.plot(forecast_index, forecast, label=f'Forecast {model_name} (RMSE: {rmse:.4f})', linestyle='--', color=color)
    plt.title(title)
    plt.xlabel('Time')
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)
    plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
    plt.gca().xaxis.set_major_locator(mdates.MonthLocator(interval=3))
    plt.xticks(rotation=45)
    plt.tight_layout()
    try:
        plt.savefig(os.path.join(CONFIG['img_dir'], filename), dpi=300)
        logger.info(f"Saved comparison plot: {filename}")
    except Exception as e:
        logger.error(f"Error saving comparison plot: {str(e)}")
    plt.close()

def plot_metrics_bar(metrics_df, filename):
    """Save a 2x3 grid of bar charts comparing models on each metric.

    Args:
        metrics_df: frame with a ``'Model'`` column and one column per metric
            (``RMSE``, ``MAE``, ``MAPE``, ``sMAPE``, ``NormMAPE``, ``DirAcc``).
        filename: image file name inside ``CONFIG['img_dir']``.

    A failure while saving is logged, not raised. The figure is closed even
    when drawing raises (for example because a metric column is missing).
    """
    fig = plt.figure(figsize=(12, 8))
    try:
        metrics = ['RMSE', 'MAE', 'MAPE', 'sMAPE', 'NormMAPE', 'DirAcc']
        for i, metric in enumerate(metrics, 1):
            plt.subplot(2, 3, i)
            sns.barplot(x='Model', y=metric, data=metrics_df)
            plt.title(f'{metric} Comparison')
            plt.xticks(rotation=45)
        # Lay out once after all panels exist instead of once per panel.
        plt.tight_layout()
        try:
            plt.savefig(os.path.join(CONFIG['img_dir'], filename), dpi=300)
            logger.info(f"Saved metrics bar plot: {filename}")
        except Exception as e:
            logger.error(f"Error saving metrics bar plot: {str(e)}")
    finally:
        plt.close(fig)

def plot_residual_acf(residuals, title, filename):
    """Save an autocorrelation plot of model residuals.

    Nothing is drawn when ``residuals`` is None or has fewer than two values.
    At most 20 lags are shown (fewer for short inputs). The image is written to
    ``CONFIG['img_dir']/filename``; plotting or saving errors are logged, not
    raised.
    """
    if residuals is None or len(residuals) < 2:
        logger.warning(f"Skipping ACF: Insufficient residual data - {title}")
        return
    max_lags = min(20, len(residuals) - 1)
    if max_lags < 1:
        return
    # Create the axes ourselves and hand them to plot_acf: without ``ax`` it
    # opens a second figure, leaving this one unused and never closed.
    fig, ax = plt.subplots(figsize=(5, 3))
    try:
        plot_acf(residuals, lags=max_lags, title=title, ax=ax)
        fig.tight_layout()
        fig.savefig(os.path.join(CONFIG['img_dir'], filename), dpi=300)
        logger.info(f"Saved ACF plot: {filename}")
    except Exception as e:
        logger.error(f"Error saving ACF plot: {str(e)}")
    finally:
        plt.close(fig)

def run_exponential_smoothing(train, test, forecast_index, seasonal_periods=CONFIG['seasonal_periods']):
    """Fit additive-trend, additive-seasonal exponential smoothing and score it.

    Forecasts ``CONFIG['forecast_horizon']`` steps, re-indexes them onto
    ``forecast_index`` and builds a band of +/- 1.96 times the standard
    deviation of the in-sample residuals around the forecast.

    Returns:
        ``(forecast, residuals, rmse, mae, mape, smape, norm_mape, dir_acc,
        (ci_lower, ci_upper))``. On any error the error is logged and the
        result is ``None`` / NaN placeholders in the same positions.
    """
    started = time.time()
    try:
        spec = ExponentialSmoothing(train, trend='add', seasonal='add', seasonal_periods=seasonal_periods)
        fitted = spec.fit(optimized=True)
        raw_forecast = fitted.forecast(CONFIG['forecast_horizon'])
        residuals = train - fitted.fittedvalues
        # Keep the values but place them on the caller's forecast index.
        forecast = pd.Series(raw_forecast.values, index=forecast_index)
        half_width = 1.96 * np.std(residuals)
        band = (forecast - half_width, forecast + half_width)
        rmse, mae, mape, smape, norm_mape, dir_acc = calculate_metrics(test, forecast)
        logger.info(f"Exponential Smoothing: RMSE={rmse:.4f}, Time={time.time() - started:.2f}s")
    except Exception as e:
        logger.error(f"Error Exponential Smoothing: {str(e)}")
        return None, None, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, None
    return forecast, residuals, rmse, mae, mape, smape, norm_mape, dir_acc, band

def run_arima(train, test, forecast_index):
    """Fit a non-seasonal ARIMA via ``pm.auto_arima`` and score its forecast.

    The differencing order comes from ``check_stationarity``. If the automatic
    search raises, a fixed ARIMA(1,1,1) is fitted instead.

    Args:
        train: training series.
        test: hold-out series used for the metrics.
        forecast_index: index to attach to the forecast values.

    Returns:
        ``(forecast, residuals, rmse, mae, mape, smape, norm_mape, dir_acc,
        None)``. On any error the error is logged and the result is ``None`` /
        NaN placeholders in the same positions.
    """
    start_time = time.time()
    try:
        d = check_stationarity(train, "ARIMA")
        try:
            model = pm.auto_arima(train, start_p=0, start_q=0, max_p=3, max_q=3, d=d, max_d=2,
                                  seasonal=False, stepwise=True, trace=False, error_action='ignore',
                                  suppress_warnings=True, information_criterion='aic', maxiter=50)
        except Exception:
            # Not a bare ``except``: KeyboardInterrupt / SystemExit must not
            # trigger the fallback fit.
            logger.warning("auto_arima failed, using ARIMA(1,1,1)")
            model = pm.ARIMA(order=(1,1,1), suppress_warnings=True).fit(train)
        forecast = model.predict(n_periods=CONFIG['forecast_horizon'])
        residuals = train - model.predict_in_sample()
        # Take the raw values: if predict() returns a Series carrying its own
        # index, pd.Series(series, index=...) would reindex it and yield NaN
        # wherever the labels differ from forecast_index.
        forecast = pd.Series(np.asarray(forecast), index=forecast_index)
        rmse, mae, mape, smape, norm_mape, dir_acc = calculate_metrics(test, forecast)
        logger.info(f"ARIMA (order={model.order}): RMSE={rmse:.4f}, Time={time.time() - start_time:.2f}s")
        return forecast, residuals, rmse, mae, mape, smape, norm_mape, dir_acc, None
    except Exception as e:
        logger.error(f"Error ARIMA: {str(e)}")
        return None, None, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, None

def run_sarima(train, test, forecast_index, seasonal_periods=CONFIG['seasonal_periods']):
    """Fit a seasonal ARIMA via ``pm.auto_arima`` and score its forecast.

    The differencing order comes from ``check_stationarity``. If the automatic
    search raises, a fixed (1,1,1)x(1,1,1,``seasonal_periods``) model is fitted
    instead.

    Args:
        train: training series.
        test: hold-out series used for the metrics.
        forecast_index: index to attach to the forecast values.
        seasonal_periods: seasonal period ``m`` of the model.

    Returns:
        ``(forecast, residuals, rmse, mae, mape, smape, norm_mape, dir_acc,
        None)``. On any error the error is logged and the result is ``None`` /
        NaN placeholders in the same positions.
    """
    start_time = time.time()
    try:
        d = check_stationarity(train, "SARIMA")
        try:
            model = pm.auto_arima(train, start_p=0, start_q=0, max_p=3, max_q=3, d=d, max_d=2,
                                  seasonal=True, m=seasonal_periods, start_P=0, start_Q=0, max_P=2, max_Q=2, max_D=1,
                                  stepwise=True, trace=False, error_action='ignore', suppress_warnings=True,
                                  information_criterion='aic', maxiter=50)
        except Exception:
            # Not a bare ``except``: KeyboardInterrupt / SystemExit must not
            # trigger the fallback fit.
            logger.warning("auto_arima failed, using SARIMA(1,1,1)(1,1,1,12)")
            model = pm.ARIMA(order=(1,1,1), seasonal_order=(1,1,1,seasonal_periods), suppress_warnings=True).fit(train)
        forecast = model.predict(n_periods=CONFIG['forecast_horizon'])
        residuals = train - model.predict_in_sample()
        # Take the raw values: if predict() returns a Series carrying its own
        # index, pd.Series(series, index=...) would reindex it and yield NaN
        # wherever the labels differ from forecast_index.
        forecast = pd.Series(np.asarray(forecast), index=forecast_index)
        rmse, mae, mape, smape, norm_mape, dir_acc = calculate_metrics(test, forecast)
        logger.info(f"SARIMA (order={model.order}, seasonal_order={model.seasonal_order}): RMSE={rmse:.4f}, Time={time.time() - start_time:.2f}s")
        return forecast, residuals, rmse, mae, mape, smape, norm_mape, dir_acc, None
    except Exception as e:
        logger.error(f"Error SARIMA: {str(e)}")
        return None, None, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, None

def run_sarimax(train, test, forecast_index, exog_train, exog_test):
    """Fit a seasonal ARIMA with exogenous regressors and score its forecast.

    NaN in the regressors are forward-filled, then filled with the column
    mean. The differencing order comes from ``check_stationarity``. If the
    automatic search raises, a fixed (1,1,1)x(1,0,1,m) model is fitted, with
    ``m = CONFIG['seasonal_periods']``.

    Args:
        train: training series.
        test: hold-out series used for the metrics.
        forecast_index: index to attach to the forecast values.
        exog_train: regressors aligned with ``train`` (same length).
        exog_test: regressors for the forecast window; must have
            ``CONFIG['forecast_horizon']`` rows.

    Returns:
        ``(forecast, residuals, rmse, mae, mape, smape, norm_mape, dir_acc,
        None)``. On any error (including a regressor size mismatch) the error
        is logged and the result is ``None`` / NaN placeholders in the same
        positions.
    """
    start_time = time.time()
    try:
        if exog_train.isna().any().any() or exog_test.isna().any().any():
            logger.warning("Exogenous data contains NaN, filling...")
            # .ffill() replaces the deprecated fillna(method='ffill').
            exog_train = exog_train.ffill().fillna(exog_train.mean())
            exog_test = exog_test.ffill().fillna(exog_test.mean())
        if len(exog_train) != len(train) or len(exog_test) != CONFIG['forecast_horizon']:
            raise ValueError(f"Exogenous data size mismatch: exog_train={len(exog_train)}, train={len(train)}, exog_test={len(exog_test)}")
        d = check_stationarity(train, "SARIMAX")
        # NOTE(review): newer pmdarima releases are believed to name this
        # argument ``X`` instead of ``exogenous`` — confirm against the
        # installed version before upgrading.
        try:
            model = pm.auto_arima(train, exogenous=exog_train, start_p=0, start_q=0, max_p=3, max_q=3, d=d, max_d=2,
                                  seasonal=True, m=CONFIG['seasonal_periods'], start_P=0, start_Q=0, max_P=1, max_Q=1, max_D=1,
                                  stepwise=True, trace=False, error_action='ignore', suppress_warnings=True,
                                  information_criterion='aic', maxiter=50)
        except Exception:
            # Not a bare ``except``: KeyboardInterrupt / SystemExit must not
            # trigger the fallback fit.
            logger.warning("auto_arima failed, using SARIMAX(1,1,1)(1,0,1,12)")
            model = pm.ARIMA(order=(1,1,1), seasonal_order=(1,0,1,CONFIG['seasonal_periods']), suppress_warnings=True).fit(train, exogenous=exog_train)
        forecast = model.predict(n_periods=CONFIG['forecast_horizon'], exogenous=exog_test)
        residuals = train - model.predict_in_sample(exogenous=exog_train)
        # Take the raw values: if predict() returns a Series carrying its own
        # index, pd.Series(series, index=...) would reindex it and yield NaN
        # wherever the labels differ from forecast_index.
        forecast = pd.Series(np.asarray(forecast), index=forecast_index)
        rmse, mae, mape, smape, norm_mape, dir_acc = calculate_metrics(test, forecast)
        logger.info(f"SARIMAX (order={model.order}, seasonal_order={model.seasonal_order}): RMSE={rmse:.4f}, Time={time.time() - start_time:.2f}s")
        return forecast, residuals, rmse, mae, mape, smape, norm_mape, dir_acc, None
    except Exception as e:
        logger.error(f"Error SARIMAX: {str(e)}")
        return None, None, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, None

def run_prophet(train, test, forecast_index):
    """Fit a Prophet model with yearly seasonality and score its forecast.

    Args:
        train: training series; its index supplies Prophet's ``ds`` column.
        test: hold-out series used for the metrics.
        forecast_index: dates to predict; also the index of the outputs.

    Returns:
        ``(forecast, residuals, rmse, mae, mape, smape, norm_mape, dir_acc,
        (ci_lower, ci_upper))`` where the band comes from Prophet's
        ``yhat_lower`` / ``yhat_upper``. On any error the error is logged and
        the result is ``None`` / NaN placeholders in the same positions.
    """
    start_time = time.time()
    try:
        df_train = pd.DataFrame({'ds': train.index, 'y': train.values})
        model = Prophet(yearly_seasonality=True, weekly_seasonality=False, daily_seasonality=False,
                        changepoint_prior_scale=0.05, seasonality_prior_scale=10.0).fit(df_train)
        future = pd.DataFrame({'ds': forecast_index})
        forecast = model.predict(future)
        forecast_series = pd.Series(forecast['yhat'].values, index=forecast_index)
        # Subtract by position: the prediction frame has a 0..n-1 index while
        # ``train`` is indexed by date, so Series - Series would align on
        # labels and produce only NaN residuals.
        in_sample = model.predict(df_train)['yhat'].to_numpy()
        residuals = train - in_sample
        ci_lower = pd.Series(forecast['yhat_lower'].values, index=forecast_index)
        ci_upper = pd.Series(forecast['yhat_upper'].values, index=forecast_index)
        rmse, mae, mape, smape, norm_mape, dir_acc = calculate_metrics(test, forecast_series)
        logger.info(f"Prophet: RMSE={rmse:.4f}, Time={time.time() - start_time:.2f}s")
        return forecast_series, residuals, rmse, mae, mape, smape, norm_mape, dir_acc, (ci_lower, ci_upper)
    except Exception as e:
        logger.error(f"Error Prophet: {str(e)}")
        return None, None, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, None


def run_model_for_target(target, train, test, forecast_index, model_name, model_func, params):
    """Run one model on ``target``, plot its results and collect its metrics.

    For ``model_name == 'SARIMAX'`` the function also assembles the regressor
    matrices from the feature columns present in ``train``:
    - ``cpi_yoy`` for the forecast window is taken from ``test``;
    - calendar columns (``month_*``, ``quarter``) are taken from ``test`` for
      the forecast dates, since they depend only on the date;
    - every other feature (lags, rolling statistics) is held at its last
      training value.

    Args:
        target: name of the column being forecast.
        train: training frame containing ``target`` and the feature columns.
        test: hold-out frame with the same columns.
        forecast_index: index of the forecast window.
        model_name: label used in logs, file names and the result dict.
        model_func: one of the ``run_*`` functions; must return the 9-tuple
            ``(forecast, residuals, rmse, mae, mape, smape, norm_mape,
            dir_acc, ci)``.
        params: extra keyword arguments forwarded to ``model_func``.

    Returns:
        A dict with the metrics, forecast, residuals and interval, or ``None``
        when the model produced no forecast / RMSE or an error occurred (the
        error is logged).
    """
    logger.info(f"Running {model_name} for {target}")
    start_time = time.time()
    try:
        if model_name == 'SARIMAX':
            exog_var = 'cpi_yoy'
            calendar_cols = [col for col in train.columns
                             if col.startswith('month_') or col == 'quarter']
            # Same column order as before: cpi_yoy, target features,
            # cpi_yoy features, calendar features.
            feature_cols = [exog_var]
            for base in (target, exog_var):
                feature_cols += [f'{base}_lag_{lag}' for lag in CONFIG['lag_features']]
                feature_cols += [f'{base}_roll_mean_{w}' for w in CONFIG['rolling_windows']]
                feature_cols += [f'{base}_roll_std_{w}' for w in CONFIG['rolling_windows']]
            feature_cols += calendar_cols
            # Some features may have been dropped upstream; keep what exists.
            feature_cols = [col for col in feature_cols if col in train.columns]
            exog_train = train[feature_cols]
            exog_test = pd.DataFrame(index=forecast_index)
            exog_test[exog_var] = test[exog_var].reindex(forecast_index).fillna(train[exog_var].iloc[-1]).fillna(train[exog_var].mean())
            calendar_set = set(calendar_cols)
            for col in feature_cols:
                if col == exog_var:
                    continue
                last_value = train[col].iloc[-1]
                if col in calendar_set and col in test.columns:
                    # Calendar features are known for future dates. Freezing
                    # them at the last training row would label every forecast
                    # month as the final training month.
                    known = test[col].reindex(forecast_index)
                    exog_test[col] = known.where(known.notna(), last_value).astype(train[col].dtype)
                else:
                    # Lags / rolling stats of unseen data: carry the last value.
                    exog_test[col] = last_value
            if exog_train.isna().any().any() or exog_test.isna().any().any():
                logger.warning(f"Exogenous data still contains NaN after filling")
                exog_train = exog_train.fillna(exog_train.mean())
                exog_test = exog_test.fillna(exog_test.mean())
            forecast, residuals, rmse, mae, mape, smape, norm_mape, dir_acc, ci = model_func(
                train[target], test[target], forecast_index, exog_train, exog_test, **params
            )
        else:
            forecast, residuals, rmse, mae, mape, smape, norm_mape, dir_acc, ci = model_func(
                train[target], test[target], forecast_index, **params
            )
        if forecast is None or pd.isna(rmse):
            logger.error(f"{model_name} for {target} failed to produce valid forecast or RMSE")
            return None
        # Show only the last 36 training points next to the forecast.
        plot_forecast(train[target][-36:], test[target], forecast, forecast_index,
                      f'{model_name} Forecast for {target}', target, f'{target}_{model_name}_forecast.png', ci)
        if residuals is not None:
            plot_residual_acf(residuals.dropna(), f'ACF of Residuals - {model_name} ({target})',
                              f'{target}_{model_name}_acf.png')
        logger.info(f"Completed {model_name} for {target} in {time.time() - start_time:.2f}s")
        return {
            'Target': target,
            'Model': model_name,
            'RMSE': rmse,
            'MAE': mae,
            'MAPE': mape,
            'sMAPE': smape,
            'NormMAPE': norm_mape,
            'DirAcc': dir_acc,
            'Forecast': forecast,
            'Residuals': residuals,
            'CI': ci
        }
    except Exception as e:
        logger.error(f"Error running {model_name} for {target}: {str(e)}")
        return None

def main():
    """Run the full ``cpi_mom`` forecasting comparison end to end.

    Steps: load ``data/analyzed_time_series.csv`` indexed by its ``time``
    column, validate and de-outlier ``cpi_mom`` / ``cpi_yoy``, build features,
    hold out the last ``CONFIG['forecast_horizon']`` rows, run all models in
    parallel, then write the comparison plot, the metrics table
    (``CONFIG['results_file']``), the metrics bar chart and a CSV with every
    model's forecast.

    Raises:
        Exception: any error is logged and re-raised.
    """
    try:
        data = pd.read_csv('data/analyzed_time_series.csv')
        data['time'] = pd.to_datetime(data['time'])
        data.set_index('time', inplace=True)
        required_columns = ['cpi_mom', 'cpi_yoy']
        validate_input_data(data, required_columns)
        for col in required_columns:
            data[col] = detect_outliers(data[col])
        data_features = create_features(data, 'cpi_mom')
        train_size = len(data) - CONFIG['forecast_horizon']
        train, test = data_features[:train_size], data_features[train_size:]
        # NOTE(review): assumes the data are month-start stamped so that this
        # range matches test.index — confirm against the input file.
        forecast_index = pd.date_range(start=test.index[0], periods=CONFIG['forecast_horizon'], freq='MS')
        plot_decomposition(data['cpi_mom'], period=CONFIG['seasonal_periods'], filename='cpi_mom_decomposition.png')
        models = {
            'ARIMA': (run_arima, {}),
            'Exponential Smoothing': (run_exponential_smoothing, {}),
            'Prophet': (run_prophet, {}),
            'SARIMA': (run_sarima, {}),
            'SARIMAX': (run_sarimax, {})
        }
        results = []
        forecasts_mom = {}
        metrics_mom = {}
        logger.info("Running models for cpi_mom")
        tasks = [delayed(run_model_for_target)('cpi_mom', train, test, forecast_index, model_name, model_func, params)
                 for model_name, (model_func, params) in models.items()]
        model_results = Parallel(n_jobs=CONFIG['n_jobs'], verbose=1)(tasks)
        metric_keys = ('Target', 'Model', 'RMSE', 'MAE', 'MAPE', 'sMAPE', 'NormMAPE', 'DirAcc')
        for result in model_results:
            if result is not None:
                # Keep only the scalar fields for the results table.
                results.append({key: result[key] for key in metric_keys})
                forecasts_mom[result['Model']] = result['Forecast']
                metrics_mom[result['Model']] = {'RMSE': result['RMSE']}
            else:
                logger.warning(f"Result for a cpi_mom model is None, skipping!")
        # NOTE: an ensemble step (calling `run_ensemble` with forecasts_mom and
        # metrics_mom) used to sit here as commented-out code; `run_ensemble`
        # is not defined in the visible part of this file, so it stays disabled.
        if forecasts_mom:
            plot_comparison_forecasts(train['cpi_mom'][-36:], test['cpi_mom'], forecasts_mom, forecast_index,
                                     'Comparison of Forecasts for cpi_mom', 'cpi_mom', 'cpi_mom_model_comparison.png',
                                     metrics=metrics_mom)
        results_df = pd.DataFrame(results)
        print(results_df)
        results_df.to_csv(CONFIG['results_file'], index=False)
        logger.info(f"Results saved to {CONFIG['results_file']}")
        if not results_df.empty:
            plot_metrics_bar(results_df, 'cpi_mom_metrics_comparison.png')
        if forecasts_mom:
            combined_forecast = pd.DataFrame({'Date': forecast_index})
            for model_name, forecast in forecasts_mom.items():
                # Assign by position: each forecast is indexed by date while
                # this frame has a 0..n-1 index, so assigning the Series
                # directly would align on labels and write only NaN.
                combined_forecast[f'{model_name}_cpi_mom'] = np.asarray(forecast)
            combined_forecast.to_csv(f'{img_dir}/combined_forecast_cpi_mom.csv', index=False)
            logger.info(f"Combined forecasts saved to {img_dir}/combined_forecast_cpi_mom.csv")
    except Exception as e:
        logger.error(f"Main program error: {str(e)}")
        raise

# Run the pipeline only when this file is executed as the main module.
if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 16 concurrent workers.
[Parallel(n_jobs=-1)]: Done 5 out of 5 | elapsed:   31.7s finished


    Target                  Model      RMSE       MAE      MAPE     sMAPE  \
0  cpi_mom                  ARIMA  0.342721  0.274343  0.273675  0.273471   
1  cpi_mom  Exponential Smoothing  0.228608  0.191115  0.190825  0.190594   
2  cpi_mom                Prophet  0.241135  0.173548  0.173083  0.172993   
3  cpi_mom                 SARIMA  0.327964  0.253911  0.253186  0.253048   
4  cpi_mom                SARIMAX  0.314712  0.239798  0.239112  0.238990   

   NormMAPE     DirAcc  
0  0.002730  63.636364  
1  0.001904  36.363636  
2  0.001727  72.727273  
3  0.002526  36.363636  
4  0.002385  54.545455  
