In [4]:
# 📊 导入库
import json
import os
import matplotlib.pyplot as plt

# 📁 读取所有 metrics/*.json 文件
metric_files = {
    'Linear Regression': 'metrics/lr_val.json',
    'Decision Tree': 'metrics/dt_val.json',
    'CatBoost': 'metrics/catboost_val.json',
    'XGBoost': 'metrics/xgboost_val.json',
    'MLP': 'metrics/mlp_val.json'
}

results = {}

for model_name, path in metric_files.items():
    if os.path.exists(path):
        with open(path) as f:
            data = json.load(f)
            results[model_name] = data['mse']
    else:
        print(f"⚠️ 找不到 {path}")

# 📈 可视化模型验证集 MSE
plt.figure(figsize=(10, 6))
sorted_results = dict(sorted(results.items(), key=lambda x: x[1]))
plt.barh(list(sorted_results.keys()), list(sorted_results.values()), color='skyblue')
plt.xlabel('Validation MSE (Mean Squared Error)')
plt.title('模型验证性能比较 (MSE 越低越好)')
for i, (model, mse) in enumerate(sorted_results.items()):
    plt.text(mse, i, f"{mse:.2f}", va='center', ha='left', fontsize=10)
plt.grid(axis='x')
plt.show()


KeyError: 'mse'