# Model Training and Evaluation: SARIMA & LightGBM Ensemble

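This notebook fits a SARIMA model to each economic indicator. It feeds the resulting SARIMA forecasts as features to LightGBM classifiers, which predict whether the 1-, 3-, and 6-month-ahead recession probability is at least 0.5. It includes model evaluation (AUC, accuracy, classification report). Plotting has been removed so the notebook can run headless.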

In [5]:
!pip install lightgbm




[notice] A new release of pip is available: 24.0 -> 25.2
[notice] To update, run: C:\Users\dulak\AppData\Local\Microsoft\WindowsApps\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\python.exe -m pip install --upgrade pip


In [None]:
# 1. Import Libraries
import pandas as pd
import numpy as np
# import matplotlib.pyplot as plt  # Removed for headless model training
# import seaborn as sns  # Removed for headless model training
from statsmodels.tsa.statespace.sarimax import SARIMAX
import lightgbm as lgb
from sklearn.metrics import roc_auc_score, accuracy_score, classification_report

# 2. Load Feature-Engineered Data
df = pd.read_csv('../data/processed/feature_engineered_economic_indicators.csv', index_col=0, parse_dates=True)
print(f"Loaded feature-engineered data: {df.shape}")
print("\nAvailable columns:")
print(df.columns.tolist())

# Set target columns explicitly for this dataset
target_1m = '1_month_recession_probability'
target_3m = '3_month_recession_probability'
target_6m = '6_month_recession_probability'

targets = [target_1m, target_3m, target_6m]
# Fail here, with a clear message, rather than with a KeyError deep in step 4.
missing_targets = [t for t in targets if t not in df.columns]
if missing_targets:
    raise KeyError(f"Target columns not found in data: {missing_targets}")

# Features: every numeric column except the targets AND any column whose name
# starts with a target name. If feature engineering produced lags or rolling
# windows of a target (e.g. '<target>_rollmean3'), a window ending at row t
# contains the very value being predicted, which would leak the label.
# NOTE(review): the full column list is truncated in the saved output, so it is
# unconfirmed whether such columns exist; for plain targets this is unchanged.
indicator_cols = [
    col for col in df.select_dtypes(include=[np.number]).columns
    if not any(col.startswith(t) for t in targets)
]

# 3. SARIMA Forecasting for Each Indicator
# Every row gets a SARIMA one-step-ahead prediction for each indicator, so
# LightGBM has a usable feature at every date. Filling only the final rows with
# forecasts would leave all earlier rows NaN: forward-fill cannot fill leading
# NaNs, so the training split would carry no information.
#
# To keep the test period out of model fitting, SARIMA parameters are estimated
# on the first 80% of rows. This is the same chronological cut that step 5 makes
# with int(len(X) * 0.8), since X shares df's index. The model, with those
# parameters held fixed, is then run over the full series with `filter`.
SARIMA_ORDER = (1, 1, 1)
SARIMA_SEASONAL_ORDER = (0, 1, 1, 12)
sarima_train_end = int(len(df) * 0.8)
sarima_spec = dict(
    order=SARIMA_ORDER,
    seasonal_order=SARIMA_SEASONAL_ORDER,
    enforce_stationarity=False,
    enforce_invertibility=False,
)

sarima_features = {}
for col in indicator_cols:
    print(f"Fitting SARIMA for {col}...")
    series = df[col].to_numpy(dtype=float)
    fit_results = SARIMAX(series[:sarima_train_end], **sarima_spec).fit(disp=False)
    # Kalman-filter the full series with the training-only parameters;
    # fittedvalues[t] is the one-step-ahead prediction made from data before t.
    full_results = SARIMAX(series, **sarima_spec).filter(fit_results.params)
    sarima_features[col + '_sarima'] = np.asarray(full_results.fittedvalues)
sarima_forecasts = pd.DataFrame(sarima_features, index=df.index)
print("SARIMA forecasts complete.")

# 4. Prepare LightGBM Data (using SARIMA forecasts as features)
X = sarima_forecasts.copy()

# Binarize targets: 1 if the target value >= RECESSION_THRESHOLD, else 0.
RECESSION_THRESHOLD = 0.5
target_values = df[[target_1m, target_3m, target_6m]].loc[X.index]
# NOTE(review): a 0.5 cut-off only means "probability >= 50%" if the targets are
# raw probabilities in [0, 1]. The df.head() output saved in this notebook shows
# negative target values, which suggests they were standardized upstream.
# Confirm this, and if so threshold the unscaled probabilities instead.
if ((target_values < 0) | (target_values > 1)).any().any():
    print("WARNING: target columns contain values outside [0, 1]; "
          f"the >= {RECESSION_THRESHOLD} threshold may not correspond to a probability.")
y_1m = (target_values[target_1m] >= RECESSION_THRESHOLD).astype(int)
y_3m = (target_values[target_3m] >= RECESSION_THRESHOLD).astype(int)
y_6m = (target_values[target_6m] >= RECESSION_THRESHOLD).astype(int)

# 5. Train/Test Split
# Chronological 80/20 cut by position: earliest rows train, latest rows test.
split = int(len(X) * 0.8)
X_train = X.iloc[:split]
X_test = X.iloc[split:]
y_train_1m, y_train_3m, y_train_6m = (y.iloc[:split] for y in (y_1m, y_3m, y_6m))
y_test_1m, y_test_3m, y_test_6m = (y.iloc[split:] for y in (y_1m, y_3m, y_6m))

# 6. LightGBM Model Training and Evaluation
def train_eval_lgb(X_train, y_train, X_test, y_test, horizon,
                   val_fraction=0.2, num_boost_round=100,
                   early_stopping_rounds=10, threshold=0.5, params=None):
    """Train a LightGBM binary classifier and report its test-set performance.

    The number of boosting rounds is chosen by early stopping on a validation
    tail taken from the end of the *training* data (the last ``val_fraction``
    of its rows). The model is then refit on all of ``X_train`` with that many
    rounds. The test set is used only for the final evaluation, so the reported
    AUC and accuracy are not tuned on it.

    Args:
        X_train, y_train: Training features and binary labels, assumed to be
            in chronological order (the validation tail is the last rows).
        X_test, y_test: Held-out features and binary labels.
        horizon: Text used in the progress message (e.g. '1 month').
        val_fraction: Fraction of training rows, taken from the end, used for
            early stopping. If it yields zero rows or all rows, early stopping
            is skipped and ``num_boost_round`` rounds are trained.
        num_boost_round: Maximum number of boosting rounds.
        early_stopping_rounds: Patience passed to ``lgb.early_stopping``.
        threshold: Predicted probabilities strictly above this become class 1.
        params: LightGBM parameters; defaults to a binary objective with AUC.

    Returns:
        The fitted ``lgb.Booster``.
    """
    print(f"\nTraining LightGBM for {horizon} ahead...")
    if params is None:
        params = {'objective': 'binary', 'metric': 'auc', 'verbosity': -1}

    n_val = int(len(X_train) * val_fraction)
    best_rounds = num_boost_round
    if 0 < n_val < len(X_train):
        fit_set = lgb.Dataset(X_train.iloc[:-n_val], y_train.iloc[:-n_val])
        val_set = lgb.Dataset(X_train.iloc[-n_val:], y_train.iloc[-n_val:], reference=fit_set)
        probe = lgb.train(
            params,
            fit_set,
            valid_sets=[fit_set, val_set],
            num_boost_round=num_boost_round,
            callbacks=[lgb.early_stopping(early_stopping_rounds)]
        )
        # Fall back to the full round budget if best_iteration was not set (0).
        best_rounds = probe.best_iteration or num_boost_round

    # Refit on all training rows so the most recent training data is used.
    gbm = lgb.train(params, lgb.Dataset(X_train, y_train), num_boost_round=best_rounds)

    y_pred = gbm.predict(X_test)
    y_pred_label = (y_pred > threshold).astype(int)
    acc = accuracy_score(y_test, y_pred_label)
    # roc_auc_score raises ValueError when y_test contains only one class.
    if len(np.unique(y_test)) < 2:
        print(f"AUC: undefined (test labels contain a single class), Accuracy: {acc:.3f}")
    else:
        auc = roc_auc_score(y_test, y_pred)
        print(f"AUC: {auc:.3f}, Accuracy: {acc:.3f}")
    print(classification_report(y_test, y_pred_label, zero_division=0))
    # Plotting removed for headless training
    return gbm

# Fit and evaluate one classifier per forecast horizon.
model_1m = train_eval_lgb(
    X_train, y_train_1m, X_test, y_test_1m, horizon='1 month'
)
model_3m = train_eval_lgb(
    X_train, y_train_3m, X_test, y_test_3m, horizon='3 months'
)
model_6m = train_eval_lgb(
    X_train, y_train_6m, X_test, y_test_6m, horizon='6 months'
)

print("\nAll models trained and evaluated.")

ModuleNotFoundError: No module named 'matplotlib.backends.registry'

In [None]:
# --- Train models using only the 11 dashboard indicators ---
# This cell is a template. To use it:
#   1. Run the diagnostic cell in the feature engineering notebook to list the
#      columns available in the 11-indicator dataset.
#   2. Edit `indicator_11` below so it names the EXACT columns to train on.
#      Use the base column name for the most recent value when it exists;
#      otherwise use lagged/rolling columns (e.g. 'DTB1YR_lag1',
#      'CPIAUCSL_rollmean3').
#   3. Uncomment the training code and run the cell.

import joblib

# Load the 11-indicator dataset; column 0 becomes the index, parsed as dates.
X_11 = pd.read_csv(
    '../data/processed/feature_engineered_economic_indicators_11.csv',
    parse_dates=True,
    index_col=0,
)

# NOTE(review): 'PCU3312103312100_lag1' appears twice in the list below, so it
# holds only 10 distinct features. Replace one occurrence with the intended
# 11th indicator before uncommenting.
# indicator_11 = [
#     'DTB1YR_lag1',
#     'DTB3_lag1',
#     'DTB6_lag1',
#     'IRLTLT01USM156N_lag1',
#     'CPIAUCSL_lag1',
#     'INDPRO_lag1',
#     'PCU3312103312100_lag1',
#     'UNRATE_lag1',
#     'SPASTT01USM661N_lag1',
#     'UMCSENT_lag1',
#     'PCU3312103312100_lag1'
# ]

# target_cols = ['1_month_recession_probability', '3_month_recession_probability', '6_month_recession_probability']

# X = X_11[indicator_11]
# y_1m = (X_11['1_month_recession_probability'] >= 0.5).astype(int)
# y_3m = (X_11['3_month_recession_probability'] >= 0.5).astype(int)
# y_6m = (X_11['6_month_recession_probability'] >= 0.5).astype(int)

# split = int(len(X)*0.8)
# X_train, X_test = X.iloc[:split], X.iloc[split:]
# y_train_1m, y_test_1m = y_1m.iloc[:split], y_1m.iloc[split:]
# y_train_3m, y_test_3m = y_3m.iloc[:split], y_3m.iloc[split:]
# y_train_6m, y_test_6m = y_6m.iloc[:split], y_6m.iloc[split:]

# params = {'objective': 'binary', 'metric': 'auc', 'verbosity': -1}
# model_1m = lgb.train(params, lgb.Dataset(X_train, y_train_1m), num_boost_round=100)
# model_3m = lgb.train(params, lgb.Dataset(X_train, y_train_3m), num_boost_round=100)
# model_6m = lgb.train(params, lgb.Dataset(X_train, y_train_6m), num_boost_round=100)

# joblib.dump(model_1m, '../models/recession_model_1m.pkl')
# joblib.dump(model_3m, '../models/recession_model_3m.pkl')
# joblib.dump(model_6m, '../models/recession_model_6m.pkl')
# print("Saved new 1m, 3m, 6m models for dashboard with 11 features.")

In [None]:
# The target columns are set explicitly (target_1m / target_3m / target_6m) in
# the training cell. Confirm here that they exist in df; if they do not, set
# them manually below.
missing_targets = [t for t in (target_1m, target_3m, target_6m) if t not in df.columns]
if missing_targets:
    print(f"Target columns not found in df: {missing_targets}")
print("\nFirst few rows of the DataFrame:")
display(df.head())
# NOTE(review): in this cell's saved output the *_recession_probability columns
# include negative values (e.g. -0.770217), so they look standardized rather
# than probabilities in [0, 1]. The >= 0.5 binarization in the training cell may
# therefore not mean "probability >= 50%"; confirm upstream.
# Example: Uncomment and set these if needed:
# target_1m = 'your_column_name_for_1m'
# target_3m = 'your_column_name_for_3m'
# target_6m = 'your_column_name_for_6m'


First few rows of the DataFrame:


Unnamed: 0_level_0,recession_probability,1_month_recession_probability,3_month_recession_probability,6_month_recession_probability,1_year_rate,3_months_rate,6_months_rate,CPI,INDPRO,10_year_rate,...,OECD_CLI_index_rollmean6,OECD_CLI_index_rollstd6,OECD_CLI_index_rollmean12,OECD_CLI_index_rollstd12,CSI_index_rollmean3,CSI_index_rollstd3,CSI_index_rollmean6,CSI_index_rollstd6,CSI_index_rollmean12,CSI_index_rollstd12
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1968-10-01,-0.647085,-0.770217,-0.402839,1.180406,0.519028,0.589387,0.552407,-1.713034,-1.687336,-0.024419,...,0.715973,0.212248,0.491081,0.362463,0.463302,0.0,0.463302,0.0,0.649274,0.222117
1968-11-01,-0.647085,-0.266544,-0.837295,0.279778,0.587763,0.628222,0.622804,-1.711668,-1.665596,0.019845,...,0.822794,0.246756,0.612845,0.3108,0.441695,0.037424,0.452498,0.026463,0.608376,0.219073
1968-12-01,-0.288238,-0.434435,-0.620067,0.05462,0.760602,0.823446,0.798529,-1.708936,-1.66022,0.14157,...,0.928762,0.235642,0.712184,0.279231,0.420088,0.037424,0.441695,0.033473,0.567477,0.207367
1969-01-01,-0.407854,-0.770217,1.11776,0.730092,0.796667,0.894571,0.884942,-1.70757,-1.64989,0.145259,...,1.0145,0.179614,0.773675,0.280165,0.599116,0.347509,0.531209,0.232031,0.608376,0.240642
1969-02-01,-0.647085,-0.602326,0.248847,2.306192,0.851053,0.887239,0.891503,-1.706204,-1.638894,0.200589,...,1.060056,0.110203,0.812518,0.274637,0.79975,0.347509,0.620722,0.29551,0.616092,0.252318
