# Data Overview

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import seaborn as sns

## I. Define Helper Functions

In [None]:
def upload_dataset(from_drive = False):
  """Fetch the Kaggle electricity-load-forecasting dataset into the working directory.

  Relies on IPython shell magics (the ``!`` lines), so it only runs inside a
  notebook. When ``kaggle.json`` is absent from the working directory the key
  is obtained (upload prompt, or a copy from Google Drive when ``from_drive``
  is true), installed under ``~/.kaggle``, and the dataset archive is
  downloaded and unzipped. When ``kaggle.json`` is already present nothing is
  downloaded and a message is printed.

  NOTE(review): ``files`` is not imported anywhere in the visible cells --
  presumably ``google.colab.files``; confirm and import it before use.
  NOTE(review): the guard tests for the API key file, not for the dataset
  itself, although the message in the ``else`` branch talks about the dataset.
  """
  if not os.path.exists("kaggle.json"):
    if not from_drive:
      # Interactive path: ask the user to upload the API key file.
      print("Upload Kaggle API Key")
      files.upload()
      print("Downloading dataset...")
    else:
      # Drive path: copy the key from the mounted Drive folder.
      !cp /content/drive/MyDrive/kaggle.json /content/
    # Install the key where the kaggle CLI looks for it, readable by the owner only.
    !mkdir -p ~/.kaggle
    !cp kaggle.json ~/.kaggle/
    !chmod 600 ~/.kaggle/kaggle.json
    # Download the archive and unpack it in place.
    !kaggle datasets download -d saurabhshahane/electricity-load-forecasting
    !unzip electricity-load-forecasting.zip
  else:
    print("Dataset already exists")

#We start with basic statistics for both numeric and categorical data
def unistats(dataframe,sorted="Missing"):
    """Build a one-row-per-column summary table for ``dataframe``.

    Every column gets Count, Missing, Unique and Dtype. Numeric columns also
    get Mode, Mean, Min, 25%, Median, 75%, Max, Std, Skew and Kurt; for
    non-numeric columns those ten cells hold "-".

    Parameters:
        dataframe: the DataFrame to summarise.
        sorted: name of the summary column used as the secondary sort key
            (the primary key is always "Dtype").

    Returns:
        A DataFrame indexed by the column names of ``dataframe``.

    Side effects:
        Raises pandas' ``display.max_rows`` and ``display.max_columns``
        options to 100 so the whole table is shown.
    """
    pd.set_option("display.max_rows",100)
    pd.set_option("display.max_columns",100)
    output_df = pd.DataFrame(columns = ["Count","Missing","Unique", "Dtype", "Mode", "Mean", "Min", "25%", "Median", "75%", "Max", "Std", "Skew", "Kurt"])

    for col in dataframe:
        column = dataframe[col]
        basics = [column.count(), column.isnull().sum(), column.nunique(), column.dtype]
        if pd.api.types.is_numeric_dtype(column):
            # mode() drops missing values, so it is empty for an all-missing
            # column; fall back to NaN instead of indexing an empty array.
            modes = column.mode()
            mode_value = modes.values[0] if len(modes) else np.nan
            output_df.loc[col] = basics + [
                mode_value, column.mean(), column.min(), column.quantile(0.25),
                column.median(), column.quantile(0.75), column.max(),
                column.std(), column.skew(), column.kurt(),
            ]
        else:
            output_df.loc[col] = basics + ["-"] * 10

    return output_df.sort_values(by = ["Dtype",sorted])

def scatter(dataframe, target, feature):
    """
    Plot ``target`` against ``feature`` and annotate the figure with fit statistics.

    Use it with numeric columns. An OLS model ``target ~ feature`` is fitted
    and its residuals are passed to the Breusch-Pagan test; slope, intercept,
    r and p come from ``scipy.stats.linregress``. The summary text (regression
    equation, r squared, significance at the 0.05 level, skew of both columns
    and the Breusch-Pagan results) is written at the right edge of a seaborn
    joint regression plot.

    Parameters:
        dataframe: DataFrame holding both columns.
        target: name of the dependent column (y axis).
        feature: name of the independent column (x axis).

    Returns:
        None. The result is the figure drawn by ``sns.jointplot``.
    """
    from statsmodels.formula.api import ols
    from statsmodels.stats.diagnostic import het_breuschpagan
    from scipy import stats

    sns.set_style(style="white")

    model = ols(formula= f"{target}~{feature}", data = dataframe).fit()

    lm, p1, f, p2 = het_breuschpagan(model.resid,model.model.exog)
    m, b, r, p, _ = stats.linregress(dataframe[feature], dataframe[target])

    feature_name = str(dataframe[feature].name)
    target_name = str(dataframe[target].name)
    feature_skew = dataframe[feature].skew()
    target_skew = dataframe[target].skew()
    r_squared = r**2

    def skew_direction(value):
        return "negatively" if value < 0 else "positively"

    lines = [
        # Show the intercept with an explicit sign so the equation reads correctly.
        f"y = {round(m, 2)}x {'-' if b < 0 else '+'} {abs(round(b, 2))}",
        f"r_2 = {round(r_squared, 4)}",
        # Round after scaling to a percentage to avoid float noise such as 12.340000000000002.
        f"{round(r_squared * 100, 2)}% of variance is explained",
        f"p = {round(p, 5)}",
        "Significant" if p < 0.05 else "Not Significant",
        f"{feature_name} skew = {round(feature_skew, 2)}",
        f"{feature_name} is {skew_direction(feature_skew)} skewed",
        f"{target_name} skew = {round(target_skew, 2)}",
        f"{target_name} is {skew_direction(target_skew)} skewed",
        f"{feature_name} Breushpagan Test = LM stat: {round(lm, 4)} p value: {round(p1, 4)}"
        f" F stat: {round(f, 4)} p value: {round(p2, 4)}",
        # A small LM p-value rejects equal residual variance (homoscedasticity).
        "Variance of residuals are not distributed equally" if p1 < 0.05
        else "Variance of residuals are distributed equally",
    ]
    string = "\n".join(lines) + "\n"

    ax = sns.jointplot(x = feature, y = target, kind = "reg", data = dataframe)
    # Figure coordinates: x=1 anchors the text at the right edge of the joint-plot figure.
    ax.fig.text( 1, 0.1, string, fontsize = 12, transform = ax.fig.transFigure)

def plot_predictions(test,predicted):
    """Draw actual and predicted demand as two lines on one chart and show it.

    ``test`` is drawn in red with the legend label 'Real Demand' and
    ``predicted`` in green with the label 'Predicted Demand'.
    """
    axes = plt.gca()
    line_specs = (
        (test, 'red', 'Real Demand'),
        (predicted, 'green', 'Predicted Demand'),
    )
    for values, colour, legend_label in line_specs:
        axes.plot(values, color=colour, label=legend_label)
    axes.set_title('Demand Prediction')
    axes.set_xlabel('Time')
    axes.set_ylabel('Prediction')
    axes.legend()
    plt.show()


def return_rmse(test,predicted):
    """Print the root mean squared error between ``test`` and ``predicted``.

    Both arguments are converted to float arrays of the same shape. The error
    is computed with NumPy because neither ``math`` nor
    ``mean_squared_error`` is imported in the cells visible above this
    function, so the previous implementation raised NameError when called.

    Returns:
        None. The value is only printed.
    """
    import math  # local import: not imported by the visible import cell

    squared_errors = (np.asarray(test, dtype=float) - np.asarray(predicted, dtype=float)) ** 2
    rmse = math.sqrt(np.mean(squared_errors))
    print("The root mean squared error is {}.".format(rmse))

def hist_and_boxplot(dataframe, label):
    """Show a histogram and a box plot of one column side by side.

    Parameters:
        dataframe: DataFrame that contains the column.
        label: name of the column to plot on the x axis of both panels.

    Nothing is returned; the 6x3 inch figure is displayed with ``plt.show``.
    """
    _, (hist_axis, box_axis) = plt.subplots(1, 2, figsize=(6, 3))
    sns.histplot(data=dataframe, x=label, ax=hist_axis)
    sns.boxplot(data=dataframe, x=label, ax=box_axis)
    plt.show()

## II. Reading in Data and Preprocessing

In [None]:
file_path = "continuous dataset.csv"
df = pd.read_csv(file_path)
df.head()

In [None]:
unistats(df)

### Overall Summary:
1. Holiday and school variables show binary distributions indicating the presence or absence of holidays and school days.
2. nat_demand shows significant variability and a near-normal distribution.
3. Weather-related variables (temperature, humidity, precipitation, wind speed) exhibit various degrees of skewness and kurtosis, indicating different levels of variability and distribution shapes.
4. datetime is a unique identifier for each record.

In [None]:
# Parse the timestamp column, then derive one calendar feature per
# attribute of the pandas .dt accessor.
df["datetime"] = pd.to_datetime(df["datetime"])
timestamp_parts = df["datetime"].dt
for part_name in ("month", "day", "hour", "dayofweek", "dayofyear"):
    df[part_name] = getattr(timestamp_parts, part_name)
df.head(3)

In [None]:
df["datetime"].min(), df["datetime"].max()

In [None]:
for col in df.select_dtypes(include = "number").columns:
  hist_and_boxplot(df,col)

In [None]:
hist_and_boxplot(df,"nat_demand")

In [None]:
def plot_df(df, x, y, title="", xlabel='Date', ylabel='Value', dpi=100):
    """Plot ``y`` against ``x`` as a red line on a 16x5 inch figure and show it.

    ``df`` is accepted for call-site compatibility but is not read; the data
    comes entirely from ``x`` and ``y``. ``title``, ``xlabel`` and ``ylabel``
    are applied to the axes and ``dpi`` sets the figure resolution.
    """
    figure = plt.figure(figsize=(16,5), dpi=dpi)
    axes = figure.gca()
    axes.plot(x, y, color='tab:red')
    axes.set(title=title, xlabel=xlabel, ylabel=ylabel)
    plt.show()

plot_df(df, x=df["datetime"], y=df['nat_demand'], title='Time Series')

Due to the outbreak of the novel coronavirus at the end of 2019, which severely affected electricity consumption, only the data before the end of 2019 is being focused on.

In [None]:
df = df[df["datetime"] < "2019-12-31"]

In [None]:
plot_df(df, x=df["datetime"], y=df['nat_demand'], title='Time Series')

In [None]:
from statsmodels.tsa.seasonal import seasonal_decompose

decomp_add = seasonal_decompose(df["nat_demand"], period = 24*30)
decomp_add.plot();

In [None]:
decomp_mul = seasonal_decompose(df["nat_demand"], period = 24*30, model = "multiplicative")
decomp_mul.plot();

In [None]:
from statsmodels.tsa.stattools import adfuller
print('Results of Dickey-Fuller Test:')
dftest = adfuller(df["nat_demand"])
dfoutput = pd.Series(dftest[0:4], index=['Test Statistic','p-value','#Lags Used','Number of Observations Used'])
print(dfoutput)

In [None]:
# Part A: distribution of the explanatory variables

# A1. Kernel density estimate of each numeric weather variable
num_cols = ['T2M_toc','QV2M_toc','TQL_toc','W2M_toc',
            'T2M_san','QV2M_san','TQL_san','W2M_san',
            'T2M_dav','QV2M_dav','TQL_dav','W2M_dav']

# Twelve variables fit a 3x4 grid exactly (same layout as the box plots in
# A2); a 6x4 grid would leave the lower half of the figure empty.
plt.figure(figsize=(15,10))
for i,col in enumerate(num_cols):
    plt.subplot(3,4,i+1)
    sns.kdeplot(data=df[col], fill=True) # fill replaces the older shade argument
    plt.title(f'{col} Distribution')
plt.tight_layout()
plt.show()

# A2. Box plot of each numeric variable
plt.figure(figsize=(15,10))
for i,col in enumerate(num_cols):
    plt.subplot(3,4,i+1)
    sns.boxplot(x=df[col])
    plt.title(f'{col} Boxplot')
plt.tight_layout()
plt.show()

# A3. Distribution of the discrete calendar variables
cat_cols = ['month','day', 'hour']
# Three panels in a single row; a 2x3 grid would leave the second row empty.
plt.figure(figsize=(15, 4))
for i, col in enumerate(cat_cols):
    plt.subplot(1, 3, i + 1)
    if col in ['month', 'hour']:
        sns.countplot(data=df, x=col, hue=col, palette='viridis', legend=False)
    else:
        # 'day' is drawn as a plain bar chart of counts ordered by day number
        df[col].value_counts().sort_index().plot(kind='bar', color='teal')
    plt.title(f'{col} Distribution')
plt.tight_layout()
plt.show()

# Part B: correlation analysis
# B1. Pearson correlation of every numeric variable with nat_demand
corr_matrix = df[num_cols + ['nat_demand']].corr()
demand_corr = corr_matrix[['nat_demand']].sort_values('nat_demand',ascending=False)
plt.figure(figsize=(12,10))
heatmap = sns.heatmap(demand_corr,
                      annot=True,fmt=".2f",vmin=-1,vmax=1,
                      cmap='coolwarm',linewidths=.5)
heatmap.set_title('Correlation with nat_demand',pad=12)
plt.show()

# B2. Demand against hour / day of month / month, one figure per time unit
demand_views = (
    ('hour', 'Hourly Demand Pattern with Confidence Interval', 'Hour of Day'),
    ('day', 'Dayly Demand Pattern with Confidence Interval', 'Day of Month'),
    ('month', 'Monthly Demand Pattern with Confidence Interval', 'Month'),
)
for time_col, plot_title, x_label in demand_views:
    plt.figure(figsize=(12,6))
    ax = sns.lineplot(data=df,x=time_col,y='nat_demand',
                      err_style="band", # shaded band around the mean line
                      color='darkorange',
                      linewidth=3,
                      marker='o')
    ax.set_title(plot_title,pad=10)
    ax.set_xlabel(x_label,labelpad=10)
    ax.set_ylabel('Energy Demand',labelpad=10)
    plt.grid(True,alpha=0.3)
    plt.show()

# B3. Effect of holidays on demand
import scipy.stats as stats

# Demand on normal days (holiday == 0) and on holidays (holiday == 1)
holiday_data = [df[df['holiday']==0]['nat_demand'],
                df[df['holiday']==1]['nat_demand']]

plt.figure(figsize=(8,5))
box = sns.boxplot(x='holiday', y='nat_demand', data=df,
                 hue='holiday', palette=['skyblue','salmon'], legend=False)

# Two-sample t-test between the two groups; the p-value goes into the title.
# (A stray '#' inside the variable name previously commented out the rest of
# this call and made the cell a syntax error.)
test_result = stats.ttest_ind(*holiday_data)
box.set_title(f'Holiday Impact (p-value={test_result.pvalue:.4f})')

box.set_xticks([0, 1])
box.set_xticklabels(['Normal Day', 'Holiday'])

plt.show()

# B4. Monthly trend with a smoothed overlay
import statsmodels.api as sm
plt.figure(figsize=(12,6))
monthly_avg = df.groupby('month')['nat_demand'].mean().reset_index()
sns.lineplot(x='month',y='nat_demand',data=monthly_avg,
             estimator=None,lw=3,
             marker='s',markersize=10,
             label='Monthly Average')
# LOWESS smoothing of the twelve monthly means; column 0 of the result holds
# the x values and column 1 the smoothed y values.
lowess_smoothed = sm.nonparametric.lowess(
    monthly_avg['nat_demand'], monthly_avg['month'], frac=0.33)
plt.plot(lowess_smoothed[:,0], lowess_smoothed[:,1],
         'r--',lw=3,
         label='Smoothed Trend')
# Redraw the legend after the second line is added so that
# 'Smoothed Trend' is listed next to 'Monthly Average'.
plt.legend()
plt.show()

# Pattern Identification

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import seaborn as sns
from statsmodels.tsa.stattools import acf
from statsmodels.graphics.tsaplots import plot_acf
import matplotlib.pyplot as plt

## I. start with basic statistics for both numeric and categorical data

In [None]:
def unistats(dataframe,sorted="Missing"):
    """Build a one-row-per-column summary table for ``dataframe``.

    Every column gets Count, Missing, Unique and Dtype. Numeric columns also
    get Mode, Mean, Min, 25%, Median, 75%, Max, Std, Skew and Kurt; for
    non-numeric columns those ten cells hold "-".

    Parameters:
        dataframe: the DataFrame to summarise.
        sorted: name of the summary column used as the secondary sort key
            (the primary key is always "Dtype").

    Returns:
        A DataFrame indexed by the column names of ``dataframe``.

    Side effects:
        Raises pandas' ``display.max_rows`` and ``display.max_columns``
        options to 100 so the whole table is shown.
    """
    pd.set_option("display.max_rows",100)
    pd.set_option("display.max_columns",100)
    output_df = pd.DataFrame(columns = ["Count","Missing","Unique", "Dtype", "Mode", "Mean", "Min", "25%", "Median", "75%", "Max", "Std", "Skew", "Kurt"])

    for col in dataframe:
        column = dataframe[col]
        basics = [column.count(), column.isnull().sum(), column.nunique(), column.dtype]
        if pd.api.types.is_numeric_dtype(column):
            # mode() drops missing values, so it is empty for an all-missing
            # column; fall back to NaN instead of indexing an empty array.
            modes = column.mode()
            mode_value = modes.values[0] if len(modes) else np.nan
            output_df.loc[col] = basics + [
                mode_value, column.mean(), column.min(), column.quantile(0.25),
                column.median(), column.quantile(0.75), column.max(),
                column.std(), column.skew(), column.kurt(),
            ]
        else:
            output_df.loc[col] = basics + ["-"] * 10

    return output_df.sort_values(by = ["Dtype",sorted])

def scatter(dataframe, target, feature):
    """
    Plot ``target`` against ``feature`` and annotate the figure with fit statistics.

    Use it with numeric columns. An OLS model ``target ~ feature`` is fitted
    and its residuals are passed to the Breusch-Pagan test; slope, intercept,
    r and p come from ``scipy.stats.linregress``. The summary text (regression
    equation, r squared, significance at the 0.05 level, skew of both columns
    and the Breusch-Pagan results) is written at the right edge of a seaborn
    joint regression plot.

    Parameters:
        dataframe: DataFrame holding both columns.
        target: name of the dependent column (y axis).
        feature: name of the independent column (x axis).

    Returns:
        None. The result is the figure drawn by ``sns.jointplot``.
    """
    from statsmodels.formula.api import ols
    from statsmodels.stats.diagnostic import het_breuschpagan
    from scipy import stats

    sns.set_style(style="white")

    model = ols(formula= f"{target}~{feature}", data = dataframe).fit()

    lm, p1, f, p2 = het_breuschpagan(model.resid,model.model.exog)
    m, b, r, p, _ = stats.linregress(dataframe[feature], dataframe[target])

    feature_name = str(dataframe[feature].name)
    target_name = str(dataframe[target].name)
    feature_skew = dataframe[feature].skew()
    target_skew = dataframe[target].skew()
    r_squared = r**2

    def skew_direction(value):
        return "negatively" if value < 0 else "positively"

    lines = [
        # Show the intercept with an explicit sign so the equation reads correctly.
        f"y = {round(m, 2)}x {'-' if b < 0 else '+'} {abs(round(b, 2))}",
        f"r_2 = {round(r_squared, 4)}",
        # Round after scaling to a percentage to avoid float noise such as 12.340000000000002.
        f"{round(r_squared * 100, 2)}% of variance is explained",
        f"p = {round(p, 5)}",
        "Significant" if p < 0.05 else "Not Significant",
        f"{feature_name} skew = {round(feature_skew, 2)}",
        f"{feature_name} is {skew_direction(feature_skew)} skewed",
        f"{target_name} skew = {round(target_skew, 2)}",
        f"{target_name} is {skew_direction(target_skew)} skewed",
        f"{feature_name} Breushpagan Test = LM stat: {round(lm, 4)} p value: {round(p1, 4)}"
        f" F stat: {round(f, 4)} p value: {round(p2, 4)}",
        # A small LM p-value rejects equal residual variance (homoscedasticity).
        "Variance of residuals are not distributed equally" if p1 < 0.05
        else "Variance of residuals are distributed equally",
    ]
    string = "\n".join(lines) + "\n"

    ax = sns.jointplot(x = feature, y = target, kind = "reg", data = dataframe)
    # Figure coordinates: x=1 anchors the text at the right edge of the joint-plot figure.
    ax.fig.text( 1, 0.1, string, fontsize = 12, transform = ax.fig.transFigure)

def plot_predictions(test,predicted):
    """Draw actual and predicted demand as two lines on one chart and show it.

    ``test`` is drawn in red with the legend label 'Real Demand' and
    ``predicted`` in green with the label 'Predicted Demand'.
    """
    axes = plt.gca()
    line_specs = (
        (test, 'red', 'Real Demand'),
        (predicted, 'green', 'Predicted Demand'),
    )
    for values, colour, legend_label in line_specs:
        axes.plot(values, color=colour, label=legend_label)
    axes.set_title('Demand Prediction')
    axes.set_xlabel('Time')
    axes.set_ylabel('Prediction')
    axes.legend()
    plt.show()


def return_rmse(test,predicted):
    """Print the root mean squared error between ``test`` and ``predicted``.

    Both arguments are converted to float arrays of the same shape. The error
    is computed with NumPy because neither ``math`` nor
    ``mean_squared_error`` is imported in the import cell of this notebook
    section, so the previous implementation raised NameError when called.

    Returns:
        None. The value is only printed.
    """
    import math  # local import: not imported by the section's import cell

    squared_errors = (np.asarray(test, dtype=float) - np.asarray(predicted, dtype=float)) ** 2
    rmse = math.sqrt(np.mean(squared_errors))
    print("The root mean squared error is {}.".format(rmse))

def hist_and_boxplot(dataframe, label):
    """Show a histogram and a box plot of one column side by side.

    Parameters:
        dataframe: DataFrame that contains the column.
        label: name of the column to plot on the x axis of both panels.

    Nothing is returned; the 6x3 inch figure is displayed with ``plt.show``.
    """
    _, (hist_axis, box_axis) = plt.subplots(1, 2, figsize=(6, 3))
    sns.histplot(data=dataframe, x=label, ax=hist_axis)
    sns.boxplot(data=dataframe, x=label, ax=box_axis)
    plt.show()

In [None]:
df = pd.read_csv("continuous dataset_en.csv")

In [None]:
df.head()

In [None]:
unistats(df)

* Let's create some datetime features.
* Electricity demand is highly correlated with the day and the hour.
* These features will help us interpret the data and gain insights.

In [None]:
# Parse the timestamp column, then derive one calendar feature per
# attribute of the pandas .dt accessor.
df["datetime"] = pd.to_datetime(df["datetime"])
timestamp_parts = df["datetime"].dt
for part_name in ("month", "day", "hour", "dayofweek", "dayofyear"):
    df[part_name] = getattr(timestamp_parts, part_name)

## II. lets look at the date range

In [None]:
df["datetime"].min(), df["datetime"].max()

## III. lets look at the histograms of numeric columns to better understand the distributions

In [None]:
for col in df.select_dtypes(include = "number").columns:
  hist_and_boxplot(df,col)

## IV. Lets dive into nat demand column

In [None]:
hist_and_boxplot(df,"nat_demand")

### Observe some outliers and want to look at them

In [None]:
df[df["nat_demand"] <500]

## V. Lets visualize the time series

In [None]:
def plot_df(df, x, y, title="", xlabel='Date', ylabel='Value', dpi=100):
    """Plot ``y`` against ``x`` as a red line on a 16x5 inch figure and show it.

    ``df`` is accepted for call-site compatibility but is not read; the data
    comes entirely from ``x`` and ``y``. ``title``, ``xlabel`` and ``ylabel``
    are applied to the axes and ``dpi`` sets the figure resolution.
    """
    figure = plt.figure(figsize=(16,5), dpi=dpi)
    axes = figure.gca()
    axes.plot(x, y, color='tab:red')
    axes.set(title=title, xlabel=xlabel, ylabel=ylabel)
    plt.show()

plot_df(df, x=df["datetime"], y=df['nat_demand'], title='Time Series')

* Here we can see that the data runs from 2015 into 2020, the year of the pandemic.
* Since the pandemic strongly affected electricity usage, I will only look at the data up to the start of the pandemic, in late 2019.
* For consistency, I will replace the outlier values found above (nat_demand below 500) with the mean.

In [None]:
# Keep only the observations before 2019-12-31. The .copy() makes df an
# independent frame, so the .loc assignment below does not write into a
# slice of the original data (avoids pandas' SettingWithCopyWarning).
df = df[df["datetime"] < "2019-12-31"].copy()
# Replace the low-demand outliers (below 500) with the column mean; the mean
# is taken before the replacement, so it still includes those outliers.
low_demand = df["nat_demand"] < 500
df.loc[low_demand, "nat_demand"] = df["nat_demand"].mean()
plot_df(df, x=df["datetime"], y=df['nat_demand'], title='Time Series')

* Since the data is hourly, it is hard to spot trends and seasonality directly. Hence, let's decompose the series and try to see these effects.

In [None]:
from statsmodels.tsa.seasonal import seasonal_decompose

decomp_add = seasonal_decompose(df["nat_demand"], period = 24*365)
decomp_add.plot();
decomp_mul = seasonal_decompose(df["nat_demand"], period = 24*365, model = "multiplicative")
decomp_mul.plot();

## VI. Draw the seasonal proportions for the first three years

In [None]:
plt.figure(figsize=(12, 6))

# Seasonal component of the additive decomposition; use decomp_mul.seasonal
# here instead to inspect the multiplicative decomposition.
seasonal = decomp_add.seasonal

# Plot three consecutive 8760-hour segments. The x values are the real
# timestamps of the rows rather than hand-built date ranges: hard-coded
# start dates drift by one day after the 2016 leap day.
hours_per_year = 8760
year_labels = ("Year 1 (2015)", "Year 2 (2016)", "Year 3 (2017)")
for year_number, year_label in enumerate(year_labels):
    start = year_number * hours_per_year
    stop = start + hours_per_year
    plt.plot(df["datetime"].iloc[start:stop], seasonal.iloc[start:stop],
             label=year_label, alpha=0.7)

plt.title("Seasonal Component (First 3 Years)")
plt.xlabel("Date")
plt.ylabel("Seasonal (MW)")
plt.grid(True)
plt.legend()
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

* In the seasonal chart, fluctuations repeat in cycles of 8,760 hours (approximately 4.57 cycles within 40,000 hours).
* Within each 8,760 hour segment, the distribution of peaks and troughs is similar (for example, peaks may correspond to winter and summer, and troughs to spring and autumn).

## VII. Check stationarity. The null hypothesis of the ADF test is that nat_demand has a unit root (is non-stationary); since the p-value is very small we reject it and treat the series as stationary.

In [None]:
from statsmodels.tsa.stattools import adfuller
print('Results of Dickey-Fuller Test:')
dftest = adfuller(df["nat_demand"])
dfoutput = pd.Series(dftest[0:4], index=['Test Statistic','p-value','#Lags Used','Number of Observations Used'])
print(dfoutput)

## VIII. Check stationarity (repeated). The null hypothesis of the ADF test is that nat_demand has a unit root (is non-stationary); since the p-value is very small we reject it and treat the series as stationary.

In [None]:
from statsmodels.tsa.stattools import adfuller
print('Results of Dickey-Fuller Test:')
dftest = adfuller(df["nat_demand"])
dfoutput = pd.Series(dftest[0:4], index=['Test Statistic','p-value','#Lags Used','Number of Observations Used'])
print(dfoutput)

## IX. Lets also look by season

In [None]:
energy_con_winter = df[df["month"].isin([12,1,2])]
energy_con_summer = df[df["month"].isin([6,7,8])]

# Draw the two periods on separate axes with a shared y scale; plotting both
# on the same axes stacks the boxes on top of each other and hides one set.
fig, (winter_ax, summer_ax) = plt.subplots(1, 2, figsize=(16, 5), sharey=True)
sns.boxplot(data = energy_con_winter, x = "hour", y = "nat_demand", ax = winter_ax)
winter_ax.set_title("December - February")
sns.boxplot(data = energy_con_summer, x = "hour", y = "nat_demand", ax = summer_ax)
summer_ax.set_title("June - August")
plt.tight_layout()
plt.show()

* High electricity consumption in December-February might be due to the hot weather during the dry season, which increases the demand for air conditioning. Furthermore, the dry season is the peak tourist season, and economic activity increases.
* In June, July and August (the rainy season), electricity consumption is lower. This might be due to frequent rainfall and slightly lower (though still high) temperatures, which reduce the demand for air conditioning. The rainy season may also reduce outdoor activity and thus lower electricity consumption.

## X. Create a season column to retain only December, January, February and June, July, August

In [None]:
# Keep only the six months that make up the two periods being compared.
# .copy() gives an independent frame, so adding the "season" column below
# does not write into a slice of df (avoids pandas' SettingWithCopyWarning).
df_subset = df[df["month"].isin([12, 1, 2, 6, 7, 8])].copy()
df_subset["season"] = df_subset["month"].map({
    12: "Dry Season (Dec-Feb)", 1: "Dry Season (Dec-Feb)", 2: "Dry Season (Dec-Feb)",
    6: "Rainy Season (Jun-Aug)", 7: "Rainy Season (Jun-Aug)", 8: "Rainy Season (Jun-Aug)"
})

# One box per hour and season, with the seasons side by side within each hour.
plt.figure(figsize=(12, 6))
sns.boxplot(x="hour", y="nat_demand", hue="season", data=df_subset)
plt.title("Electricity Demand by Hour: Dry Season (Dec-Feb) vs Rainy Season (Jun-Aug) in Panama")
plt.xlabel("Hour")
plt.ylabel("nat_demand (MW)")
plt.show()

* The legend shows which boxplot colour belongs to the dry season (Dec-Feb) and which to the rainy season (Jun-Aug).
* As we can see, usage in the dry season (Dec-Feb) is higher than in the rainy season (Jun-Aug).
* Hence we conclude there is seasonality.

### Daily

In [None]:
sns.boxplot(data = df, x = "dayofweek", y = "nat_demand")

* We can see that weekdays are higher in the usage of electricity.

### hourly

In [None]:
sns.boxplot(data = df, x = "hour", y = "nat_demand")

* The hourly effect is easy to see: demand is low at night, peaks around 11-12 at midday, and then decreases again until night.

In [None]:
plt.scatter(df["T2M_toc"], df["nat_demand"])
plt.xlabel("Temperature (T2M_toc)")
plt.ylabel("Demand (MW)")
plt.title("Demand vs Temperature")
plt.show()

In [None]:
df.boxplot(column="nat_demand", by="holiday")
plt.title("Demand by Holiday")
plt.show()

In [None]:
decomp_year = seasonal_decompose(df["nat_demand"], period=24*365, model="multiplicative")
decomp_year.plot()

## XI. ACF test

In [None]:
# Reload the demand column indexed by timestamp and compute its
# autocorrelation for lags 0..96 together with 95% confidence intervals.
data_acf = pd.read_csv("continuous dataset_en.csv", parse_dates=['datetime'], index_col='datetime')['nat_demand']
lags = 96
acf_values, confint = acf(data_acf, nlags=lags, alpha=0.05)

# Shift the interval so that it is centred on zero instead of on each ACF value.
centred_band = confint - acf_values[:, None]
lower_conf = centred_band[:, 0]
upper_conf = centred_band[:, 1]

# Stem plot of the ACF with the zero-centred band shaded in red
lag_axis = range(lags + 1)
plt.figure(figsize=(13, 6))
plt.stem(lag_axis, acf_values, markerfmt='bo', basefmt='k-')
plt.fill_between(lag_axis, lower_conf, upper_conf, color='red', alpha=0.5)
plt.title('Autocorrelation Function (ACF) with Confidence Intervals')
plt.xlabel('Lag')
plt.ylabel('Autocorrelation')
plt.grid(True)
plt.show()

# Winters Method

In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error
from datetime import datetime
import warnings
import os
import itertools
import matplotlib.pyplot as plt

warnings.filterwarnings('ignore')

In [None]:
# 加载和预处理数据
def load_and_preprocess_data(filepath):
    """Load the demand CSV and return the daily mean of ``nat_demand``.

    The file must contain the columns ``datetime`` and ``nat_demand``.
    ``datetime`` is parsed to timestamps and used as the index (resampling
    needs a DatetimeIndex), gaps in ``nat_demand`` are filled by linear
    interpolation followed by backward and forward fill, and the series is
    resampled to one mean value per calendar day.

    Parameters:
        filepath: path of the CSV file.

    Returns:
        A pandas Series of daily mean demand indexed by date.

    Raises:
        FileNotFoundError: if ``filepath`` does not exist.
        ValueError: if reading or preprocessing fails for any reason,
            including missing columns; the original error is chained.
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"文件路径不存在：{filepath}")

    try:
        # The input is a CSV file
        data = pd.read_csv(filepath)

        # Both columns used below must be present
        if 'datetime' not in data.columns or 'nat_demand' not in data.columns:
            raise KeyError("数据缺少必要的列：'datetime' 或 'nat_demand'")

        # Parse the timestamps before indexing: resample() below only works
        # on a DatetimeIndex, not on an index of strings.
        data['datetime'] = pd.to_datetime(data['datetime'])
        series = data.set_index('datetime')['nat_demand']

        # Fill gaps: interpolate inside the series, then cover the edges.
        # bfill()/ffill() replace the deprecated fillna(method=...).
        series = series.interpolate(method='linear').bfill().ffill()

        # Resample to daily data (mean of each day)
        return series.resample('D').mean()

    except Exception as e:
        raise ValueError(f"数据加载和预处理失败：{e}") from e

In [None]:
# Holt-Winters 模型实现
class HoltWinters:
    """Holt-Winters smoothing with additive trend and multiplicative seasonality.

    The seasonal factors are initialised as ``value / level`` over the first
    season, so the model is multiplicative in the seasonal term and expects
    strictly positive data (a zero level or seasonal factor leads to division
    by zero).

    Parameters:
        series: observations in time order (list, array or pandas Series).
        seasonal_periods: number of observations per season.
        alpha: smoothing factor of the level.
        beta: smoothing factor of the trend.
        gamma: smoothing factor of the seasonal factors.

    Raises:
        ValueError: if ``series`` is shorter than one season.
    """

    def __init__(self, series, seasonal_periods=365, alpha=0.2, beta=0.1, gamma=0.3):
        if len(series) < seasonal_periods:
            raise ValueError(f"时间序列长度不足，至少需要 {seasonal_periods} 个数据点。")

        self.series = series
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.seasonal_periods = seasonal_periods
        self.level = None
        self.trend = None
        self.seasonal = None
        # Holds the most recent result of predict()
        self.forecast = None

    def initialize(self):
        """Set the starting level, trend and seasonal factors from the first seasons."""
        # Work on a plain float array so indexing is positional whatever
        # index the input Series carries.
        values = np.asarray(self.series, dtype=float)
        period = self.seasonal_periods
        first_period = values[:period]
        second_period = values[period:2 * period]

        self.level = np.mean(first_period)
        # Average per-step change between the first two seasons; zero when
        # there is no data beyond the first season.
        self.trend = (np.mean(second_period) - np.mean(first_period)) / period \
            if len(second_period) > 0 else 0
        self.seasonal = first_period / self.level

    def fit(self):
        """Run the smoothing recursions over the series.

        Updates ``level``, ``trend`` and ``seasonal`` in place.

        Returns:
            numpy array of one-step-ahead fitted values, one per observation.
        """
        self.initialize()
        values = np.asarray(self.series, dtype=float)
        period = self.seasonal_periods
        seasonals = self.seasonal.copy()
        fitted = []

        for i, observed in enumerate(values):
            old_level = self.level
            old_trend = self.trend
            season = seasonals[i % period]

            # Prediction for this step made before seeing the observation
            fitted.append((old_level + old_trend) * season)

            self.level = self.alpha * (observed / season) \
                + (1 - self.alpha) * (old_level + old_trend)
            self.trend = self.beta * (self.level - old_level) \
                + (1 - self.beta) * old_trend
            seasonals[i % period] = self.gamma * (observed / self.level) \
                + (1 - self.gamma) * season

        self.seasonal = seasonals
        return np.array(fitted)

    def predict(self, steps):
        """Forecast ``steps`` observations past the end of the fitted series.

        Returns:
            numpy array of length ``steps``; it is also stored in ``forecast``.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.level is None:
            raise RuntimeError("fit() must be called before predict()")

        period = self.seasonal_periods
        n_observed = len(self.series)
        self.forecast = np.array([
            (self.level + h * self.trend) * self.seasonal[(n_observed + h - 1) % period]
            for h in range(1, steps + 1)
        ])
        return self.forecast

In [None]:
# 计算准确性指标
def calculate_metrics(actual, predicted):
    """Compute MAE, RMSE and MAPE between ``actual`` and ``predicted``.

    Both inputs are converted to float arrays first, so values are compared
    position by position: a pandas Series is not re-aligned on its index and
    plain lists are accepted.

    Parameters:
        actual: observed values.
        predicted: forecast values, same length as ``actual``.

    Returns:
        dict with the keys 'MAE', 'RMSE' and 'MAPE' (MAPE in percent).
        MAPE divides by ``actual``, so zeros in ``actual`` give inf or nan.
    """
    actual_values = np.asarray(actual, dtype=float)
    predicted_values = np.asarray(predicted, dtype=float)

    mae = mean_absolute_error(actual_values, predicted_values)
    rmse = np.sqrt(mean_squared_error(actual_values, predicted_values))
    mape = np.mean(np.abs((actual_values - predicted_values) / actual_values)) * 100

    return {
        'MAE': mae,
        'RMSE': rmse,
        'MAPE': mape
    }

In [None]:
# 网格搜索优化参数
def grid_search_holt_winters(train, test, seasonal_periods=365):
    """Try every (alpha, beta, gamma) combination and keep the one with the lowest MAE.

    For each combination a HoltWinters model is fitted on ``train`` and asked
    to predict ``len(test)`` steps; the forecast is scored against ``test``
    with mean absolute error. A combination that raises is reported with
    ``print`` and skipped.

    Returns:
        (best_params, best_forecast, best_mae). ``best_params`` is a dict
        with the keys 'alpha', 'beta' and 'gamma'; it and ``best_forecast``
        stay None (and ``best_mae`` stays inf) when no combination succeeds.
    """
    # Search grid: alpha 0.1..0.9 and beta/gamma 0.1..0.3, all in steps of 0.2
    alpha_range = np.arange(0.1, 1.0, 0.2)
    beta_range = np.arange(0.1, 0.5, 0.2)
    gamma_range = np.arange(0.1, 0.5, 0.2)
    candidates = itertools.product(alpha_range, beta_range, gamma_range)

    best_mae, best_params, best_forecast = float('inf'), None, None

    for alpha, beta, gamma in candidates:
        try:
            # Fit on the training data, forecast the test horizon and score it
            model = HoltWinters(train, seasonal_periods=seasonal_periods, alpha=alpha, beta=beta, gamma=gamma)
            model.fit()
            forecast = model.predict(len(test))
            mae = mean_absolute_error(test, forecast)
        except Exception as e:
            print(f"参数组合 (alpha={alpha}, beta={beta}, gamma={gamma}) 失败: {e}")
            continue
        else:
            # Keep this candidate only when it strictly improves on the best so far
            if mae < best_mae:
                best_mae = mae
                best_params = {'alpha': alpha, 'beta': beta, 'gamma': gamma}
                best_forecast = forecast

    return best_params, best_forecast, best_mae

In [None]:
# 主函数
def main():
    """Run the Holt-Winters experiment end to end.

    Loads the daily demand series, holds out the last 30 days as the test
    set, grid-searches the smoothing parameters, refits with the best
    parameters, prints MAE/RMSE/MAPE, writes the actual and predicted values
    to 'holt_winters_forecast_results.csv' and saves a comparison plot to
    'holt_winters_forecast.png'. Any failure is caught and printed.

    NOTE(review): the parameters are selected on the same 30 days that are
    then reported as test metrics, so the metrics are optimistic.
    """
    # Input file path
    filepath = 'continuous dataset_en.csv'

    try:
        series = load_and_preprocess_data(filepath)

        train = series[:-30]
        test = series[-30:]

        # Grid search over the smoothing parameters
        print("开始网格搜索以优化参数...")
        best_params, forecast, best_mae = grid_search_holt_winters(train, test, seasonal_periods=365)
        if best_params is None:
            # Every combination failed; report that instead of failing later
            # with an unrelated "NoneType is not subscriptable" error.
            raise RuntimeError("网格搜索未找到可用的参数组合")
        print(f"最佳参数: alpha={best_params['alpha']}, beta={best_params['beta']}, gamma={best_params['gamma']}")
        print(f"最佳 MAE: {best_mae:.2f}")

        # Refit with the best parameters and forecast from that model, so the
        # reported numbers come from the model that is actually kept.
        model = HoltWinters(train, seasonal_periods=365,
                           alpha=best_params['alpha'],
                           beta=best_params['beta'],
                           gamma=best_params['gamma'])
        model.fit()
        forecast = model.predict(len(test))

        metrics = calculate_metrics(test, forecast)

        print("Holt-Winters 预测结果：")
        print("\n准确性指标：")
        for metric, value in metrics.items():
            print(f"{metric}: {value:.2f}")

        results = pd.DataFrame({
            'Actual': test,
            'Predicted': forecast
        }, index=test.index)

        results.to_csv('holt_winters_forecast_results.csv')

        plt.figure(figsize=(12, 6))
        plt.plot(results.index, results['Actual'], label='Actual')
        plt.plot(results.index, results['Predicted'], label='Predicted')
        plt.title('Winters Method Results(Test-Last 30 days)')
        plt.xlabel('Time')
        plt.ylabel('Demand')
        plt.legend()
        plt.grid(True)
        plt.savefig('holt_winters_forecast.png')
        plt.close()

    except Exception as e:
        print(f"运行失败：{e}")


if __name__ == "__main__":
    main()

# ARIMA method

In [None]:
import pandas as pd  
import numpy as np  
import matplotlib.pyplot as plt  
import seaborn as sns  
from statsmodels.tsa.arima.model import ARIMA  
from statsmodels.tsa.stattools import adfuller  
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf  
from sklearn.metrics import mean_squared_error, mean_absolute_error  
from sklearn.model_selection import train_test_split  
import warnings  
warnings.filterwarnings('ignore')  
import matplotlib.dates as mdates  
from statsmodels.tsa.seasonal import seasonal_decompose  
import statsmodels.api as sm  
import os  
from datetime import datetime, timedelta  
  
plt.rcParams['font.sans-serif'] = ['SimHei']   
plt.rcParams['axes.unicode_minus'] = False    

# 1. Load and preprocess the data
print("1. 数据加载与预处理")
df = pd.read_csv("continuous dataset_en.csv", parse_dates=['datetime'])
df.set_index('datetime', inplace=True)

# Report the shape and the number of missing values per column
print(f"原始数据形状: {df.shape}")
print(f"缺失值数量:\n{df.isna().sum()}")

# Fill missing values, if any
if df.isna().sum().sum() > 0:
    print("处理缺失值...")
    # Forward fill carries the last known value forward in time;
    # ffill() replaces the deprecated fillna(method='ffill').
    df = df.ffill()

In [None]:
# 2. 基本数据探索  
print("\n2. 数据探索")  
print(f"数据范围: {df.index.min()} 至 {df.index.max()}")  
print(f"数据描述性统计:\n{df['nat_demand'].describe()}")  

In [None]:
# 3. 可视化原始数据  
plt.figure(figsize=(15, 6))  
plt.plot(df.index, df['nat_demand'])  
plt.title('电力需求时间序列图 (nat_demand)')  
plt.xlabel('日期')  
plt.ylabel('需求量')  
plt.grid(True)  
plt.tight_layout()  
plt.savefig('电力需求时间序列图.png', dpi=300)  

In [None]:
# 4. Time-series decomposition
print("\n3. 时间序列分解")
# Additive decomposition with a 24-observation period (one day of hourly
# data) into observed, trend, seasonal and residual components.
decomposition = seasonal_decompose(df['nat_demand'], model='additive', period=24)
fig, axes = plt.subplots(4, 1, figsize=(15, 12))
# One panel per component: (data, panel title, y-axis label)
panels = (
    (decomposition.observed, '原始数据', '需求量'),
    (decomposition.trend, '趋势', '趋势'),
    (decomposition.seasonal, '季节性', '季节性'),
    (decomposition.resid, '残差', '残差'),
)
for axis, (component, panel_title, y_label) in zip(axes, panels):
    component.plot(ax=axis)
    axis.set_title(panel_title)
    axis.set_ylabel(y_label)
plt.tight_layout()
plt.savefig('时间序列分解.png', dpi=300)

In [None]:
# 5. 平稳性检验  
print("\n4. 平稳性检验 (ADF测试)")  
def adf_test(series, alpha=0.05):
    """Run an Augmented Dickey-Fuller test on *series* and print the result.

    Missing values are dropped before the test. The test statistic, the
    p-value and the critical values are printed, followed by a verdict.

    Args:
        series: pandas Series to test.
        alpha: Significance level compared against the p-value. The default
            of 0.05 is the threshold that was previously hard-coded.

    Returns:
        True when the p-value is <= ``alpha`` (the unit-root null hypothesis
        is rejected, i.e. the series is treated as stationary), else False.
    """
    result = adfuller(series.dropna())
    # Positions used here: 0 = test statistic, 1 = p-value, 4 = critical values.
    adf_statistic, p_value, critical_values = result[0], result[1], result[4]

    print(f'ADF统计量: {adf_statistic}')
    print(f'p值: {p_value}')
    print('临界值:')
    for key, value in critical_values.items():
        print(f'\t{key}: {value}')

    stationary = p_value <= alpha
    if stationary:
        print("=> 序列是平稳的 (拒绝原假设)")
    else:
        print("=> 序列不平稳 (未拒绝原假设)")
    return stationary

# Run the stationarity test on the raw demand series; the boolean result
# decides whether the differencing step below is executed.
is_stationary = adf_test(df['nat_demand'])  


In [None]:
# 6. Difference the series when the raw data is not stationary.
if is_stationary:
    # Already stationary: keep the raw series under the "diff1" column name so
    # the following steps can read the same column either way.
    print("数据已经是平稳的，无需差分处理")
    df['nat_demand_diff1'] = df['nat_demand']
else:
    print("\n5. 进行差分处理")
    # First-order difference.
    df['nat_demand_diff1'] = df['nat_demand'].diff()

    # Plot the differenced series, skipping the leading NaN.
    diff1_fig, diff1_ax = plt.subplots(figsize=(15, 6))
    diff1_ax.plot(df.index[1:], df['nat_demand_diff1'].iloc[1:])
    diff1_ax.set_title('电力需求一阶差分')
    diff1_ax.set_xlabel('日期')
    diff1_ax.set_ylabel('差分值')
    diff1_ax.grid(True)
    diff1_fig.tight_layout()
    diff1_fig.savefig('一阶差分.png', dpi=300)

    # Re-test stationarity on the differenced series.
    print("\n一阶差分后的平稳性检验:")
    is_stationary_diff1 = adf_test(df['nat_demand_diff1'])

    # Still not stationary: try a seasonal difference with a lag of 24.
    if not is_stationary_diff1:
        print("\n进行季节性差分 (24小时)...")
        df['nat_demand_seasonal_diff'] = df['nat_demand'].diff(24)

        # Skip the first 24 rows, which are NaN after the lag-24 difference.
        seasonal_fig, seasonal_ax = plt.subplots(figsize=(15, 6))
        seasonal_ax.plot(df.index[24:], df['nat_demand_seasonal_diff'].iloc[24:])
        seasonal_ax.set_title('电力需求季节性差分 (周期=24小时)')
        seasonal_ax.set_xlabel('日期')
        seasonal_ax.set_ylabel('季节性差分值')
        seasonal_ax.grid(True)
        seasonal_fig.tight_layout()
        seasonal_fig.savefig('季节性差分.png', dpi=300)

        print("\n季节性差分后的平稳性检验:")
        is_stationary_seasonal = adf_test(df['nat_demand_seasonal_diff'])

In [None]:
# 7. Choose ARIMA orders: ACF and PACF plots of the (differenced) series.
print("\n6. 确定ARIMA模型参数")
diff_data = df['nat_demand_diff1'].dropna()

fig, axes = plt.subplots(2, 1, figsize=(15, 10))
acf_ax, pacf_ax = axes

# Show 48 lags in each plot.
plot_acf(diff_data, lags=48, ax=acf_ax)
acf_ax.set_title('自相关函数 (ACF)')
plot_pacf(diff_data, lags=48, ax=pacf_ax)
pacf_ax.set_title('偏自相关函数 (PACF)')

plt.tight_layout()
plt.savefig('ACF_PACF图.png', dpi=300)

In [None]:
# 8. Chronological train/test split.
print("\n7. 拆分训练集和测试集")
# The first 80% of the rows form the training set, the remaining 20% the test set.
train_size = int(len(df) * 0.8)
train_data, test_data = df.iloc[:train_size], df.iloc[train_size:]

print(f"训练集大小: {train_data.shape[0]}")
print(f"测试集大小: {test_data.shape[0]}")

In [None]:
# 9. Fit the ARIMA model.
print("\n8. 拟合ARIMA模型")

# Model orders (p, d, q); the same names are reused by the SARIMAX step.
p, d, q = (2, 1, 2)

# Build the model on the training demand series, fit it and print the summary.
model = ARIMA(train_data['nat_demand'], order=(p, d, q))
model_fit = model.fit()
arima_summary = model_fit.summary()
print(arima_summary)

In [None]:
# 10. Model diagnostics on the ARIMA residuals.
import scipy.stats as stats

print("\n9. 模型诊断")
residuals = model_fit.resid
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
(resid_series_ax, resid_hist_ax), (resid_acf_ax, resid_qq_ax) = axes

# Top-left: residuals over time.
resid_series_ax.plot(residuals)
resid_series_ax.set_title('残差时间序列')
resid_series_ax.set_xlabel('时间')
resid_series_ax.set_ylabel('残差')

# Top-right: residual histogram.
resid_hist_ax.hist(residuals, bins=30)
resid_hist_ax.set_title('残差直方图')
resid_hist_ax.set_xlabel('残差')
resid_hist_ax.set_ylabel('频率')

# Bottom-left: residual autocorrelation up to lag 40.
plot_acf(residuals, lags=40, ax=resid_acf_ax)
resid_acf_ax.set_title('残差自相关函数')

# Bottom-right: Q-Q plot against the normal distribution.
stats.probplot(residuals, dist="norm", plot=resid_qq_ax)
resid_qq_ax.set_title('残差Q-Q图')

plt.tight_layout()
plt.savefig('模型诊断.png', dpi=300)

In [None]:
# 11. Forecast over the test horizon.
print("\n10. 进行预测")
# One forecast step per test-set row.
predictions = model_fit.forecast(steps=len(test_data))

# Build the comparison frame positionally on the test index. Aligning the two
# Series by label would yield NaN-filled rows whenever the forecast comes back
# with an index that differs from the test index (e.g. an integer index when
# the date index carries no usable frequency).
pred_df = pd.DataFrame(
    {
        'Actual': test_data['nat_demand'].to_numpy(),
        'Predicted': np.asarray(predictions),
    },
    index=test_data.index,
)

# Plot actual vs. predicted demand and save the figure.
plt.figure(figsize=(15, 6))
plt.plot(test_data.index, pred_df['Actual'], label='实际值')
plt.plot(test_data.index, pred_df['Predicted'], color='red', label='预测值')
plt.title('ARIMA模型预测结果')
plt.xlabel('日期')
plt.ylabel('电力需求')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('预测结果.png', dpi=300)

In [None]:
# 12. Evaluate the ARIMA forecast on the test set.
print("\n11. 评估模型性能")
# MSE and MAE come from scikit-learn; RMSE is the square root of the MSE.
arima_actual = test_data['nat_demand']
mse = mean_squared_error(arima_actual, predictions)
mae = mean_absolute_error(arima_actual, predictions)
rmse = np.sqrt(mse)

print(f"均方误差 (MSE): {mse:.4f}")
print(f"均方根误差 (RMSE): {rmse:.4f}")
print(f"平均绝对误差 (MAE): {mae:.4f}")

In [None]:
# 13. Influence of the exogenous variables (optional).
print("\n12. 分析自变量的影响")
# Pairwise correlations between the demand and the twelve weather columns
# (four variables with the suffixes toc / san / dav).
correlation_columns = [
    'nat_demand',
    'T2M_toc', 'QV2M_toc', 'TQL_toc', 'W2M_toc',
    'T2M_san', 'QV2M_san', 'TQL_san', 'W2M_san',
    'T2M_dav', 'QV2M_dav', 'TQL_dav', 'W2M_dav',
]
correlation_matrix = df[correlation_columns].corr()

plt.figure(figsize=(12, 10))
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt='.2f')
plt.title('自变量与电力需求的相关性热力图')
plt.tight_layout()
plt.savefig('相关性热力图.png', dpi=300)

In [None]:
# 14. SARIMAX model (with exogenous variables).
print("\n13. 使用SARIMAX模型考虑外部变量")
# Exogenous regressors; an example selection, to be adjusted according to the
# correlation analysis.
selected_features = ['T2M_toc', 'QV2M_toc', 'T2M_san', 'W2M_dav']

# Exogenous matrices for the training and test periods.
exog_train, exog_test = train_data[selected_features], test_data[selected_features]

# Fit SARIMAX with the ARIMA orders (p, d, q) and the seasonal orders
# (P, D, Q, s) = (1, 1, 1, 24), i.e. a 24-observation season, as an example.
sarimax_model = sm.tsa.SARIMAX(
    train_data['nat_demand'],
    exog=exog_train,
    order=(p, d, q),
    seasonal_order=(1, 1, 1, 24),
    enforce_stationarity=False,
    enforce_invertibility=False,
)
sarimax_results = sarimax_model.fit()
print(sarimax_results.summary())

# Forecast the test horizon, supplying the test-period regressors.
sarimax_predictions = sarimax_results.forecast(steps=len(test_data), exog=exog_test)

# Evaluation metrics of the SARIMAX forecast.
sarimax_actual = test_data['nat_demand']
sarimax_mse = mean_squared_error(sarimax_actual, sarimax_predictions)
sarimax_mae = mean_absolute_error(sarimax_actual, sarimax_predictions)
sarimax_rmse = np.sqrt(sarimax_mse)

# Plot the actual demand against both forecasts and save the figure.
compare_fig, compare_ax = plt.subplots(figsize=(15, 6))
compare_ax.plot(test_data.index, test_data['nat_demand'], label='实际值')
compare_ax.plot(test_data.index, predictions, color='red', label='ARIMA预测')
compare_ax.plot(test_data.index, sarimax_predictions, color='green', label='SARIMAX预测')
compare_ax.set_title('ARIMA vs SARIMAX 预测结果比较')
compare_ax.set_xlabel('日期')
compare_ax.set_ylabel('电力需求')
compare_ax.legend()
compare_ax.grid(True)
compare_fig.tight_layout()
compare_fig.savefig('ARIMA与SARIMAX比较.png', dpi=300)

In [None]:
# 15. Summary: report whichever model has the lower test RMSE.
sarimax_is_better = sarimax_rmse < rmse
if not sarimax_is_better:
    print("\n最优模型: ARIMA")
    print(f"最优参数: ARIMA({p},{d},{q})")
    print(f"最优模型RMSE: {rmse:.4f}")
else:
    print("\n最优模型: SARIMAX")
    print(f"最优参数: ARIMA({p},{d},{q})x季节性参数(1,1,1,24)")
    print(f"最优模型RMSE: {sarimax_rmse:.4f}")
    print(f"显著影响电力需求的外部变量: {selected_features}")

print("\n分析完成!")

# LSTM method

In [None]:
# pip install tensorflow --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple

In [None]:
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import matplotlib.pyplot as plt
import statsmodels.api as sm
from statsmodels.graphics.api import qqplot
from statsmodels.graphics.tsaplots import plot_acf

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.callbacks import EarlyStopping

# 1. Load and preprocess the data
file_path = 'continuous dataset_en.csv'
try:
    df = pd.read_csv(file_path)
except FileNotFoundError:
    print(f"错误：找不到文件 {file_path}。请确保文件路径正确。")
    # Re-raise instead of continuing: the statements below would otherwise run
    # on a stale `df` left over from an earlier cell (or hit a NameError).
    # Raising stops the cell without exiting the notebook kernel.
    raise

df_ts = df.copy()
df_ts["datetime"] = pd.to_datetime(df_ts["datetime"])
df_ts = df_ts.set_index("datetime")  # use the timestamp as the index

# --- Keep only the data from 2015-01-01 through 2019-12-31 ---
print(f"原始数据范围: {df_ts.index.min()} 到 {df_ts.index.max()}")
# .copy() makes the filtered frame independent, so the column assignments
# below do not write into a slice of the unfiltered frame.
df_ts = df_ts.loc['2015-01-01':'2019-12-31'].copy()
print(f"筛选后数据范围: {df_ts.index.min()} 到 {df_ts.index.max()}")

# Fill missing values: forward-fill the target, column median for every other
# numeric column.
df_ts["nat_demand"] = df_ts["nat_demand"].ffill()
numerical_cols = df_ts.select_dtypes(include=np.number).columns.difference(['nat_demand'])
df_ts[numerical_cols] = df_ts[numerical_cols].fillna(df_ts[numerical_cols].median())

# --- 2. Feature engineering ---
# Calendar features read from the DatetimeIndex.
for calendar_attr in ('hour', 'dayofweek', 'quarter', 'month', 'year', 'dayofyear'):
    df_ts[calendar_attr] = getattr(df_ts.index, calendar_attr)
df_ts['weekofyear'] = df_ts.index.isocalendar().week.astype(int)

# Sine/cosine encoding of the cyclical calendar features: (column, period).
cyclical_periods = (
    ('hour', 24),
    ('dayofweek', 7),
    ('month', 12),
    ('dayofyear', 365),
    ('weekofyear', 52),
)
for cyclical_col, cyclical_period in cyclical_periods:
    cyclical_angle = 2 * np.pi * df_ts[cyclical_col] / cyclical_period
    df_ts[f'{cyclical_col}_sin'] = np.sin(cyclical_angle)
    df_ts[f'{cyclical_col}_cos'] = np.cos(cyclical_angle)

# Lagged demand: 1, 2, 3, 24, 48 and 168 rows back.
for lag in (1, 2, 3, 24, 48, 24 * 7):
    df_ts[f'nat_demand_lag_{lag}'] = df_ts['nat_demand'].shift(lag)

# Rolling mean of the demand over 24-row and 168-row windows.
for window in (24, 24 * 7):
    df_ts[f'nat_demand_rolling_mean_{window}'] = df_ts['nat_demand'].rolling(window=window).mean()

# Holiday feature: one-hot encode the column and replace it with the dummies.
df_ts['holiday'] = df_ts['holiday'].fillna('None')
holiday_dummies = pd.get_dummies(df_ts['holiday'], prefix='holiday', dummy_na=False)
df_ts = pd.concat([df_ts.drop(columns='holiday'), holiday_dummies], axis=1)
# The Holiday_ID and school columns are kept as they are.

In [None]:
# --- 3. Prepare the data for the LSTM (time-based split) ---

# Drop the rows left with NaN by the lag and rolling-window features.
df_lstm = df_ts.dropna().copy()
print(f"去除NaN后数据范围: {df_lstm.index.min()} 到 {df_lstm.index.max()}")


# --- Time-based train/test split ---
test_days = 30
# The test set starts (inclusive) at midnight, test_days - 1 days before the
# last timestamp.
split_date = (df_lstm.index.max() - pd.Timedelta(days=test_days - 1)).normalize()

before_split = df_lstm.index < split_date
from_split_on = df_lstm.index >= split_date
train_df = df_lstm.loc[before_split]
test_df = df_lstm.loc[from_split_on]

print(f"训练集范围: {train_df.index.min()} 到 {train_df.index.max()}")
print(f"测试集范围: {test_df.index.min()} 到 {test_df.index.max()}")
print(f"训练集大小: {len(train_df)}, 测试集大小: {len(test_df)}")


# Target column and feature columns (every column except the target).
target_col = 'nat_demand'
feature_cols = [col for col in df_lstm.columns if col != target_col]

# Min-max scaling: both scalers are fitted on the training set only and then
# applied to the training and the test set.
scaler_X = MinMaxScaler()
scaler_y = MinMaxScaler()

X_train_scaled = scaler_X.fit_transform(train_df[feature_cols])
X_test_scaled = scaler_X.transform(test_df[feature_cols])

# The target is passed as a single-column 2-D array, as the scaler requires.
y_train_scaled = scaler_y.fit_transform(train_df[[target_col]].to_numpy())
y_test_scaled = scaler_y.transform(test_df[[target_col]].to_numpy())


# 创建时间序列数据所需的序列 (sequences)
def create_sequences(X, y, time_steps=1):
    """Build sliding-window samples for sequence models.

    Each sample is a window of ``time_steps`` consecutive rows of ``X``; its
    target is the ``y`` entry immediately after the window.

    Args:
        X: Indexable sequence of feature rows.
        y: Indexable sequence of targets, indexed like ``X``.
        time_steps: Number of rows per window.

    Returns:
        A tuple ``(windows, targets)`` of NumPy arrays holding
        ``len(X) - time_steps`` samples each (empty arrays when that number
        is not positive).
    """
    n_samples = len(X) - time_steps
    windows = [X[start:start + time_steps] for start in range(n_samples)]
    targets = [y[start + time_steps] for start in range(n_samples)]
    return np.array(windows), np.array(targets)

# Window length: each sample uses 24 consecutive rows (e.g. the past 24 hours)
# to predict the next row's demand.
TIME_STEPS = 24

# Build the windows for the training and the test set. y_test_seq holds the
# scaled targets and is kept for optional comparison; the evaluation itself
# uses the original-scale values.
X_train_seq, y_train_seq = create_sequences(X_train_scaled, y_train_scaled, time_steps=TIME_STEPS)
X_test_seq, y_test_seq = create_sequences(X_test_scaled, y_test_scaled, time_steps=TIME_STEPS)

print(f"Prepared data for LSTM: X_train_seq shape {X_train_seq.shape}, y_train_seq shape {y_train_seq.shape}")
print(f"Prepared data for LSTM: X_test_seq shape {X_test_seq.shape}, y_test_seq shape {y_test_seq.shape}")

In [None]:
# --- 4. Build and train the LSTM model ---
print("Building and training LSTM model...")

# Two stacked 50-unit LSTM layers, each followed by 20% dropout, and a single
# output unit.
window_length, n_features = X_train_seq.shape[1], X_train_seq.shape[2]
model = Sequential([
    LSTM(units=50, return_sequences=True, input_shape=(window_length, n_features)),
    Dropout(0.2),
    LSTM(units=50),
    Dropout(0.2),
    Dense(units=1),
])
model.compile(optimizer='adam', loss='mse')

# Stop after 5 epochs without validation-loss improvement and restore the
# best weights.
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

history = model.fit(
    X_train_seq,
    y_train_seq,
    epochs=70,  # upper bound on the number of epochs; adjustable
    batch_size=32,
    validation_split=0.2,  # validation set taken from the training sequences
    callbacks=[early_stopping],
    verbose=1,
)

print("LSTM model training finished.")


In [None]:
# --- 5. Model prediction (test set) ---
# Predict on the test windows and map the result back to the original scale.
y_pred_scaled = model.predict(X_test_seq)
y_pred = scaler_y.inverse_transform(y_pred_scaled).flatten()

# The first TIME_STEPS test rows only serve as input for the first window, so
# prediction k corresponds to test row TIME_STEPS + k. Align positionally:
# a label-based lookup would return extra rows if a timestamp were duplicated.
n_actual_available = max(len(test_df) - TIME_STEPS, 0)
if n_actual_available < len(y_pred):
    # Fewer actual values than predictions: compare as many as are available.
    print("警告：测试集的实际值数量不足以完全匹配预测值。")

min_len = min(n_actual_available, len(y_pred))
y_pred = y_pred[:min_len]
test_indices_actual = test_df.index[TIME_STEPS:TIME_STEPS + min_len]
y_test_actual = test_df[target_col].to_numpy()[TIME_STEPS:TIME_STEPS + min_len]

print(f"实际值数量: {len(y_test_actual)}, 预测值数量: {len(y_pred)}")


In [None]:
# --- 6. Model evaluation (test set) ---
# Imported here because the import cell of this section only brings in
# mean_squared_error, mean_absolute_error and r2_score.
from sklearn.metrics import mean_absolute_percentage_error

if len(y_test_actual) > 0 and len(y_pred) > 0:  # make sure there is data to evaluate
    # Take the square root explicitly: the `squared=False` argument of
    # mean_squared_error has been removed from recent scikit-learn releases.
    rmse = np.sqrt(mean_squared_error(y_test_actual, y_pred))
    mae = mean_absolute_error(y_test_actual, y_pred)
    r2 = r2_score(y_test_actual, y_pred)
    mape = mean_absolute_percentage_error(y_test_actual, y_pred)

    print(f"\nLSTM 模型在测试集 ({test_days}天) 上的性能指标：")
    print(f"RMSE: {rmse:.3f}")
    print(f"MAE: {mae:.3f}")
    print(f"R²: {r2:.3f}")
else:
    # Keep `mape` defined so that printing it later does not raise NameError.
    mape = float('nan')
    print("\n无法进行模型评估，因为实际值或预测值为空。")


In [None]:
# --- 8. Plot the final predictions (test set) ---
if len(y_test_actual) == 0 or len(y_pred) == 0:
    # Nothing to draw when either array is empty.
    print("\n无法可视化预测结果，因为实际值或预测值为空。")
else:
    plt.figure(figsize=(14, 7))
    for plotted_values, plot_kwargs in (
        (y_test_actual, {'label': 'Actual'}),
        (y_pred, {'label': 'LSTM Predict', 'color': 'red', 'alpha': 0.8}),
    ):
        plt.plot(test_indices_actual, plotted_values, **plot_kwargs)
    plt.title(f'LSTM Results (Test - Last{test_days}days)')
    plt.xlabel('Time')
    plt.ylabel('Demand')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

In [None]:
# Print the value stored in `mape` by the model-evaluation step.
print(mape)