# In [4]:
import multiprocessing
# Select the "spawn" start method for multiprocessing. force=True makes the
# call override a start method that has already been set in this interpreter.
# NOTE(review): nothing in the visible code starts worker processes directly;
# confirm which library call this setting is meant for.
multiprocessing.set_start_method("spawn", force=True)

# In [None]:
import os
import pandas as pd
import numpy as np
import joblib
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_squared_error, mean_absolute_error
import matplotlib.pyplot as plt
# Where the cleaned per-stock CSVs are read from, and the root for all output.
input_dir = "SARIMA_data"
output_dir = "SARIMA_results"

# Fitted models and test plots each get their own sub-folder of output_dir.
model_dir, plot_dir = (
    os.path.join(output_dir, subdir) for subdir in ("models", "plots")
)
os.makedirs(plot_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

# One metrics dict per successfully processed stock; filled by sarima_forecast.
performance_summary = []
def strip_model(model_fit):
    """Blank out bulky attributes of a fitted results object, in place.

    Each listed attribute is overwritten with ``None`` so that the object
    takes less space when it is serialized.

    NOTE(review): a stripped object has lost its data and filter results, so
    it likely cannot be used for further prediction -- confirm before relying
    on a stripped model after reloading it.

    Args:
        model_fit: Fitted results object (in this file, the value returned by
            ``SARIMAX(...).fit()``). It is modified in place.

    Returns:
        The same ``model_fit`` object, so the function can be used as
        ``model_fit = strip_model(model_fit)``.
    """
    # Dotted paths are resolved relative to model_fit; the last component is
    # the attribute that gets set to None.
    nested_attrs = [
        "model.data", "data", "filter_results", "model.endog",
        "model._data", "model.exog", "model._index", "model.orig_endog",
        "model.orig_exog", "model._init_keys", "model._init_kwds",
        "model._cache", "model.results", "model.mle_retvals",
    ]
    for path in nested_attrs:
        *parents, leaf = path.split(".")
        try:
            target = model_fit
            for name in parents:
                target = getattr(target, name)
            setattr(target, leaf, None)
        except Exception:
            # Best effort: a missing parent or an attribute that cannot be
            # assigned is simply skipped.
            continue

    # Remove large arrays stored directly on the results object.
    for attr in ["resid", "fittedvalues", "predict_results", "normalized_cov_params"]:
        if hasattr(model_fit, attr):
            setattr(model_fit, attr, None)

    # `_results` and `_cache` on model_fit are deliberately left untouched.
    return model_fit

def sarima_forecast(file, forecast_days=60):
    """Fit a SARIMA model to one stock's closing prices and record its accuracy.

    Reads ``file`` from ``input_dir``, holds out the last 10 rows as a test
    set, and fits ``SARIMAX`` with order (1, 1, 1) and seasonal order
    (1, 1, 1, 7) on the remaining rows. On success it:

    * saves the fitted model to ``model_dir`` with joblib,
    * saves a plot of actual vs. predicted test values to ``plot_dir``,
    * appends a metrics row to the module-level ``performance_summary``,
    * prints a one-line summary.

    Any exception raised while fitting, scoring, saving or plotting is caught
    and reported with a "Failed for ..." message so that one bad file does not
    stop a batch run. Errors while reading the CSV are not caught.

    Args:
        file: Name of a CSV file inside ``input_dir``. The "_cleaned.csv"
            part of the name is removed to obtain the stock label. The file
            must contain ``Date`` and ``Close`` columns.
        forecast_days: Number of steps beyond the test window to forecast.

    Returns:
        None.
    """
    stock = file.replace("_cleaned.csv", "")
    df = pd.read_csv(os.path.join(input_dir, file))
    df['Date'] = pd.to_datetime(df['Date'])
    df.set_index('Date', inplace=True)

    # The last `test_size` observations are held out for evaluation.
    test_size = 10
    series = df['Close'].values
    train = series[:-test_size]
    test = series[-test_size:]

    order = (1, 1, 1)
    seasonal_order = (1, 1, 1, 7)

    try:
        model = SARIMAX(train, order=order, seasonal_order=seasonal_order)
        model_fit = model.fit(disp=False)

        # Predictions: the model only saw `train`, so every position from
        # n_train onwards is out of sample.
        n_train = len(train)
        n_seen = n_train + len(test)
        train_preds = model_fit.predict(start=0, end=n_train - 1)
        test_preds = model_fit.predict(start=n_train, end=n_seen - 1)
        # NOTE(review): the forward forecast is computed but not yet saved,
        # plotted or returned -- decide whether it should be persisted.
        future_preds = model_fit.predict(start=n_seen, end=n_seen + forecast_days - 1)

        # Metrics
        train_rmse = np.sqrt(mean_squared_error(train, train_preds))
        test_rmse = np.sqrt(mean_squared_error(test, test_preds))
        train_mae = mean_absolute_error(train, train_preds)
        test_mae = mean_absolute_error(test, test_preds)

        # Save model. Stripping via strip_model is currently disabled, so the
        # full results object is written.
        model_path = os.path.join(model_dir, f"{stock}_sarima_model.pkl")
        joblib.dump(model_fit, model_path, compress=3)
        size_mb = os.path.getsize(model_path) / (1024 * 1024)

        # Plot test predictions against the held-out dates.
        test_index = df.index[-test_size:]
        fig = plt.figure(figsize=(10, 5))
        try:
            plt.plot(test_index, test, label='Actual')
            plt.plot(test_index, test_preds, label='Predicted')
            plt.title(f"{stock} - SARIMA Test Forecast")
            plt.xlabel("Date")
            plt.ylabel("Price")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(plot_dir, f"{stock}_test_plot.png"))
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)

        performance_summary.append({
            "Stock": stock,
            "Train_RMSE": round(train_rmse, 4),
            "Test_RMSE": round(test_rmse, 4),
            "Train_MAE": round(train_mae, 4),
            "Test_MAE": round(test_mae, 4),
            "Model_Size_MB": round(size_mb, 2)
        })
        print(f"{stock},Train RMSE={train_rmse:.2f}, Test RMSE={test_rmse:.2f} , Model Size={size_mb:.2f} MB")
    except Exception as e:
        # Batch boundary: report the failure and let the caller move on.
        print(f"Failed for {stock}: {e}")
# Run for all files. os.listdir returns names in arbitrary order, so iterate
# in sorted order to keep the summary rows reproducible between runs.
for file in sorted(os.listdir(input_dir)):
    if file.endswith("_cleaned.csv"):
        sarima_forecast(file)
# Save performance summary (one row per stock that was processed successfully).
summary_path = os.path.join(output_dir, "sarima_performance_summary.csv")
pd.DataFrame(performance_summary).to_csv(summary_path, index=False)
print(f"Performance summary saved to {summary_path}")
