In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# Load the raw order data; the postal-code column has too many missing values, so drop it.
supermarket_df = pd.read_csv("完整数据.csv", encoding="Windows-1252").drop(columns="Postal Code")

# Convert both date columns from text to datetime values.
for date_col in ("Order Date", "Ship Date"):
    supermarket_df[date_col] = pd.to_datetime(supermarket_df[date_col], format="mixed")
supermarket_df

In [None]:
# Install statsmodels, which the following cells import for SARIMAX and seasonal_decompose.
%pip install statsmodels

## 训练前预处理

In [None]:
from statsmodels.tsa.statespace.sarimax import SARIMAX
from statsmodels.tsa.seasonal import seasonal_decompose
import itertools
import numpy as np

# Aggregate sales by month: index the orders by their order date, then sum
# "Sales" within each month-start ("MS") bucket.
by_order_date = supermarket_df.set_index(pd.DatetimeIndex(supermarket_df["Order Date"]))
month_start = pd.Grouper(freq="MS")
sales_total = by_order_date.groupby(month_start)["Sales"].sum()

# Set the month-start frequency on the index explicitly.
sales_total.index.freq = "MS"
sales_total

## 检查加法模型/乘法模型

In [None]:
from statsmodels.tsa.stattools import acf
from statsmodels.tsa.seasonal import seasonal_decompose

# Decompose the monthly sales series under both an additive and a
# multiplicative model.
add_results = seasonal_decompose(sales_total, model="additive")
# "multiplicative" is the documented spelling of this option.
mul_results = seasonal_decompose(sales_total, model="multiplicative")
# Trailing semicolon: stops the returned Figure from being rendered a second
# time as the cell's output value.
add_results.plot();


In [None]:
# Plot the multiplicative decomposition; the trailing semicolon stops the
# returned Figure from being rendered a second time as the cell's output value.
mul_results.plot();

In [None]:
# Score each decomposition by the sum of squared autocorrelations of its
# residuals (NaN residuals are removed first). The displayed tuple ends with
# whether the additive score is the smaller of the two.
add_resid, mul_resid = (res.resid.dropna() for res in (add_results, mul_results))
add_acf_total = (acf(add_resid) ** 2).sum()
mul_acf_total = (acf(mul_resid) ** 2).sum()
add_acf_total, mul_acf_total, add_acf_total < mul_acf_total

## 随机网格搜索

In [None]:
from statsmodels.tsa.statespace.sarimax import SARIMAX
from tqdm import tqdm
import random
import warnings
warnings.filterwarnings('ignore')

# Grid-search space: each non-seasonal (p, d, q) and seasonal (P, D, Q) order
# takes the values 0 through 3.
params = {order_name: [0, 1, 2, 3] for order_name in ("p", "d", "q", "P", "D", "Q")}

# Empty frame that collects one row of fit metrics per parameter combination.
metric_columns = ["Parameters", "MSE", "MAE", "AIC", "HQIC", "Log-Likelihood"]
result_matrix = pd.DataFrame(columns=metric_columns)

def grid_search(params, n_samples=200, seed=None):
  """Fit SARIMAX on a random subset of the order grid and record the metrics.

  Each sampled combination is fitted on the global ``sales_total`` series with
  a seasonal period of 12 and ``trend="ct"``. One row per combination is
  appended to the global ``result_matrix``; a combination whose fit raises an
  exception is recorded with NaN metrics.

  Args:
    params: Mapping with the keys "p", "d", "q", "P", "D" and "Q", each
      holding the list of candidate values for that order.
    n_samples: Maximum number of combinations to fit. Clamped to the size of
      the full grid so that a small grid does not make the sampling fail.
    seed: Optional seed for a private random generator. When None, the global
      ``random`` module state is used, so a prior ``random.seed(...)`` call
      still makes the sample reproducible.

  Returns:
    None. Results are written to ``result_matrix`` in place.
  """
  keys, values = zip(*params.items())
  param_combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]
  rng = random if seed is None else random.Random(seed)
  sample_size = min(n_samples, len(param_combinations))
  random_param_combinations = rng.sample(param_combinations, k=sample_size)
  for param in tqdm(random_param_combinations):
    # Look the orders up by name rather than relying on dict insertion order.
    p, d, q = param["p"], param["d"], param["q"]
    P, D, Q = param["P"], param["D"], param["Q"]
    label = f"({p}, {d}, {q})x({P}, {D}, {Q}, 12)"
    try:
      model = SARIMAX(sales_total, order=(p, d, q), seasonal_order=(P, D, Q, 12), trend="ct")
      # disp=False silences the optimizer's per-fit convergence printout.
      results = model.fit(disp=False)
      # The "MSE" column receives the mean squared error itself, not its square root.
      metrics = [results.mse, results.mae, results.aic, results.hqic, results.llf]
    except Exception:
      # Catch Exception only, so KeyboardInterrupt can still stop the search.
      metrics = [np.nan] * 5
    result_matrix.loc[len(result_matrix.index)] = [label, *metrics]


# Seed the global RNG so the randomly sampled subset of the grid, and hence
# the results table, is the same on every run.
random.seed(42)
grid_search(params)

# Show the fitted combinations ordered by ascending AIC.
result_matrix.sort_values(by="AIC")

In [None]:
# Show the search results ordered by ascending AIC.
result_matrix.sort_values("AIC")

In [None]:
from statsmodels.tsa.statespace.sarimax import SARIMAX

import altair as alt

# Fixed orders for this fit: non-seasonal (p, d, q) and seasonal (P, D, Q).
p, d, q = 2, 1, 0
P, D, Q = 2, 1, 0

model = SARIMAX(
  sales_total,
  order=(p, d, q),
  seasonal_order=(P, D, Q, 12),
  trend="ct",
)
results = model.fit()
# Apply the fitted parameters to the same series without re-estimating them.
results_restricted = results.apply(sales_total, refit=False)
pred = results_restricted.get_prediction(start="2011-01-01", end="2025-01-01")

# Place observed sales and the predicted mean side by side; reset_index turns
# the date index into a regular column that the chart reads as "index".
observed_and_predicted = [sales_total, pred.predicted_mean]
data = pd.concat(observed_and_predicted, axis=1).reset_index()

# Two line layers over a shared time axis: observed sales and the prediction.
base = alt.Chart(data).encode(x="index:T")
observed_line = base.mark_line(color="skyblue").encode(y="Sales:Q")
predicted_line = base.mark_line(color="red").encode(y="predicted_mean:Q")
alt.layer(observed_line, predicted_line).properties(width=900).show()

print(results.aic)