In [None]:
# %pip install pandas matplotlib seaborn scikit-learn statsmodels
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from statsmodels.tsa.seasonal import seasonal_decompose
from sklearn.model_selection import train_test_split
import statsmodels.api as sm


In [None]:
# Load the raw load readings, indexed by their timestamp column.
dff = pd.read_csv('data/load.csv', index_col='datetime', parse_dates=['datetime'])

# Ensure the index is a tz-aware (UTC) datetime index
dff.index = pd.to_datetime(dff.index, utc=True)

# Later cells (diff(periods=12), the time-ordered train/test split, SARIMAX)
# assume rows are in chronological order; CSV row order is not guaranteed.
dff = dff.sort_index()

In [None]:
# Keep only readings strictly between LOAD_MIN_MW and LOAD_MAX_MW (NaN readings
# fail both comparisons and are dropped as well).
LOAD_MIN_MW, LOAD_MAX_MW = 2000, 10000
in_range = (dff['load_MW'] > LOAD_MIN_MW) & (dff['load_MW'] < LOAD_MAX_MW)
# One mask built on `dff` itself, plus an explicit copy, so later column
# assignments on `df` modify an independent frame instead of a slice of `dff`
# (avoids SettingWithCopyWarning).
df = dff.loc[in_range].copy()

In [None]:
# Peek at the first five filtered rows.
df.head(n=5)

In [None]:
# Lag-12 differencing: each value becomes load_MW[t] - load_MW[t-12]; the
# first 12 rows become NaN.
# .assign returns a new frame rather than writing into a possible slice of
# `dff` (which would trigger SettingWithCopyWarning).
# NOTE: this replaces the raw levels in `df` and is NOT idempotent: re-running
# this cell differences the already-differenced series again. Re-run from the
# filtering cell instead.
df = df.assign(load_MW=df['load_MW'].diff(periods=12))
df.info()

In [None]:
# Line plot of the (differenced) load series against its timestamps.
fig, ax = plt.subplots(figsize=(12, 6))
sns.lineplot(x=df.index, y=df['load_MW'], ax=ax)
ax.set_title('Data preview')
ax.set_xlabel('Date')
ax.set_ylabel('Load (MW)')
plt.show()

In [None]:
# Remove null values. At this point the NaNs in `load_MW` are the leading rows
# produced by diff(periods=12); NaN readings were already excluded by the range
# filter. Drop them rather than back-filling: bfill would copy the 13th
# differenced value into the first 12 rows, inventing data that then feeds
# the decomposition.
dfn = df.dropna(subset=['load_MW'])
dfn.head()

In [None]:
# Additive decomposition with a 12-step period (same lag as the differencing above).
result = seasonal_decompose(dfn['load_MW'], model='additive', period=12)

# Drop any NaNs (e.g. at the series edges) from each component before plotting.
trend, seasonal, residual = (
    component.dropna()
    for component in (result.trend, result.seasonal, result.resid)
)

# One stacked panel per series: input, trend, seasonal, residual.
panels = [
    (dfn['load_MW'], 'Original Series'),
    (trend, 'Trend'),
    (seasonal, 'Seasonal'),
    (residual, 'Residuals'),
]
fig, axes = plt.subplots(len(panels), 1, figsize=(6, 6))
for ax, (series, label) in zip(axes, panels):
    ax.plot(series, label=label)
    ax.legend()

plt.tight_layout()
plt.show()


In [None]:
# Chronological hold-out: the first 80% of rows train the model, the last 20%
# are kept for testing. shuffle=False is essential for time series. The default
# shuffles rows, which scrambles their order, leaks future observations into
# training, and means a forecast starting at len(train) does not continue the
# training data.
train, test = train_test_split(df, test_size=0.2, shuffle=False)

In [None]:
# Seasonal ARIMA: non-seasonal (p, d, q) and seasonal (P, D, Q, s) orders.
# NOTE(review): s=365 does not match the 12-step lag used for differencing and
# decomposition above, and d=1 / D=1 difference a series that was already
# lag-12 differenced. Confirm the data frequency and the intended seasonality.
# A seasonal period this long is also likely to make fitting slow.
ORDER = (1, 1, 1)
SEASONAL_ORDER = (1, 1, 1, 365)

# Fit on the load column explicitly: SARIMAX needs a univariate endog, and
# `train` is a DataFrame that may carry other columns from the CSV.
model = sm.tsa.statespace.SARIMAX(train['load_MW'], order=ORDER, seasonal_order=SEASONAL_ORDER)

# Fit the model with the conjugate-gradient optimizer
results = model.fit(method='cg')

In [None]:
# Show the fitted SARIMA model's summary as the cell's rich output.
sarima_summary = results.summary()
sarima_summary

In [None]:
# Forecast horizon in steps of the series' sampling interval (these are days
# only if the data is daily).
forecast_horizon = 30

# Out-of-sample forecast for the `forecast_horizon` steps right after the end of
# `train`. This is equivalent to
# predict(start=len(train), end=len(train) + forecast_horizon - 1).
# No `typ` argument is needed: that option belongs to the legacy
# statsmodels.tsa.arima_model.ARIMA API, not to SARIMAX results. Forecasts are
# on the scale of the series the model was fit to.
forecast = results.forecast(steps=forecast_horizon)

In [None]:
# Training data, held-out test data and the forecast on one time axis.
fig, ax = plt.subplots(figsize=(12, 6))
sns.lineplot(x=train.index, y=train['load_MW'], label='Training Data', ax=ax)
# Plot the held-out actuals so the forecast can be judged against them.
sns.lineplot(x=test.index, y=test['load_MW'], label='Test Data', ax=ax)
# NOTE(review): if the training index carries no frequency, statsmodels may
# return the forecast with an integer index that will not line up with the
# date axis. Verify forecast.index is datetime-like.
sns.lineplot(x=forecast.index, y=forecast, label='Forecast', ax=ax)
ax.set_title('SARIMA Forecast')
ax.set_xlabel('Date')
ax.set_ylabel('Load (MW)')
ax.legend()
plt.show()