In [2]:
import pandas as pd


In [5]:
df_T = pd.read_csv('./data/weather/heat.csv')
df_BSU = pd.read_csv('./data/processed_bsu/combined_peak_non_peak.csv')
df_v = pd.read_csv('./data/variables/variables_constant.csv')
df_panel = df_T.merge(df_BSU, on='ID', how='inner').merge(df_v, on='ID', how='inner')
df_panel

Unnamed: 0,ID,4_peak_mean,4_non_mean,7_peak_mean,7_non_mean,non_peak_4_wd,peak_4_wd,non_peak_4_we,peak_4_we,non_peak_7_wd,...,road density,building density,BVI,SVI,GVI,VNMI,VHI,metro distance,slope,AQI
0,28,30.929552,57.779292,68.511911,72.466418,0.0,0.0,0.0,1.0,0.6,...,0.001337,0.031246,0.114503,0.145339,0.068400,0.000000,0.000000,1218.738329,3.551787,26.978649
1,35,30.713772,57.765379,71.498045,73.009976,1.4,4.8,0.0,5.5,0.8,...,0.000222,0.028689,0.054315,0.147267,0.006003,0.000219,0.000017,943.489631,4.723129,26.950956
2,42,29.737355,51.439238,64.771885,62.969746,0.0,0.4,1.0,0.0,0.2,...,0.000000,0.041215,0.032632,0.160840,0.207985,0.000005,0.000404,746.278631,4.112320,26.907364
3,43,30.907864,57.699741,70.466082,72.489102,3.8,9.2,3.5,8.5,0.6,...,0.002563,0.031364,0.020857,0.146140,0.071911,0.000649,0.000000,667.004729,4.558258,27.033136
4,47,29.441757,57.303955,71.048916,72.194805,3.6,6.8,1.0,15.5,3.2,...,0.001088,0.005179,0.043475,0.150263,0.130592,0.003517,0.000000,1500.737554,6.091369,26.362383
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2530,4319,31.206336,56.303270,70.739568,70.230578,5.6,15.4,15.5,12.0,2.2,...,0.000295,0.077847,0.103160,0.160762,0.081311,0.008137,0.000986,769.204882,10.702406,18.472618
2531,4340,30.296126,51.703903,66.256385,64.046575,7.6,19.2,8.5,14.5,4.8,...,0.002583,0.000000,0.159383,0.140309,0.021395,0.002903,0.000033,1009.162450,9.568112,17.354141
2532,4341,31.601119,57.326675,72.179266,71.878995,26.4,52.6,33.5,48.5,12.4,...,0.000140,0.267792,0.112571,0.155573,0.016338,0.007124,0.000774,1112.331460,5.309022,17.826118
2533,4342,31.573660,57.469844,72.251265,72.409717,23.8,23.4,19.5,18.0,4.0,...,0.001296,0.145298,0.254741,0.163659,0.037731,0.007594,0.004545,1062.287726,5.738050,18.303034


In [23]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score
from xgboost import XGBRegressor
import shap
import matplotlib.pyplot as plt
import os
# 设置全局字体和大小
plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams["font.size"] = 8
date = 'non_peak_4_wd'
# 定义模型配置
config = {
    'y': date,
    'x': ['4_non_mean','bus distance', 'POI diversity', 'NDVI', 'road density',
          'building density', 'BVI', 'SVI', 'GVI', 'VNMI', 'VHI',
          'metro distance', 'slope', 'AQI']
}

print(f"\nTraining model6...")

# 定义特征和目标
X = df_panel[config['x']]
y = df_panel[config['y']]

# 标准化特征
scaler = StandardScaler()
X_standardized = scaler.fit_transform(X)

# 将标准化后的数据转换回 DataFrame，保留原始列名
X_standardized = pd.DataFrame(X_standardized, columns=X.columns)

# 训练 XGBoost 模型
model = XGBRegressor(n_estimators=20, random_state=42)
model.fit(X_standardized, y)

# 预测
y_pred = model.predict(X_standardized)

# 计算性能指标
mse = mean_squared_error(y, y_pred)
r2 = r2_score(y, y_pred)

print(f'Mean Squared Error (MSE): {mse:.4f}')
print(f'R^2 Score: {r2:.4f}')

# 计算 SHAP 值
explainer = shap.Explainer(model)
shap_values = explainer(X_standardized)

# 计算 SHAP 重要性
shap_values_array = shap_values.values
shap_sum = np.abs(shap_values_array).mean(axis=0)

importance_df = pd.DataFrame(list(zip(config['x'], shap_sum)), columns=['Variable', 'SHAP Importance'])
importance_df = importance_df.sort_values(by='SHAP Importance', ascending=False)
importance_df.to_csv(f'./data/XAI_output/{date}.csv')
# 输出重要性排名
print("\nFeature Importance:")
print(importance_df)


# shap.summary_plot(shap_values, X_standardized, plot_type="bar", show=False)
# plt.gcf().set_size_inches(6, 10)  # 设置图像大小
# plt.gca().set_facecolor('white')
# plt.title('Global features importance in spring', fontsize=16, fontname='Times New Roman')
# plt.xticks(fontsize=18, fontname='Times New Roman')  # 设置x轴刻度字体大小
# plt.yticks(fontsize=18, fontname='Times New Roman')  # 设置y轴刻度字体大小
# plt.savefig(os.path.join('../resub_fig/XAI/shap_summary_bar_7nwd.png'), bbox_inches='tight', dpi=300)



# shap.summary_plot(shap_values, X_standardized, show=False)
# plt.gcf().set_size_inches(6, 10)
# plt.gca().set_facecolor('white')
# plt.grid(True, color='gray', linestyle='--', linewidth=0.5)
# plt.title('Global features importance in spring', fontsize=16, fontname='Times New Roman')
# plt.xticks(fontsize=18, fontname='Times New Roman')  # 设置x轴刻度字体大小
# plt.yticks(fontsize=18, fontname='Times New Roman')  # 设置y轴刻度字体大小
# plt.savefig(os.path.join('../resub_fig/XAI/shap_summary_7nwd.png'), bbox_inches='tight', dpi=300)




Training model6...
Mean Squared Error (MSE): 3196.3293
R^2 Score: 0.8467

Feature Importance:
            Variable  SHAP Importance
2      POI diversity        37.429283
3               NDVI        27.544397
11    metro distance        17.559328
1       bus distance        17.458853
0         4_non_mean        10.954561
12             slope         9.796349
13               AQI         8.834731
6                BVI         7.076349
5   building density         6.475719
4       road density         5.437492
7                SVI         3.935411
10               VHI         3.761254
9               VNMI         3.424200
8                GVI         2.678649
