In [None]:
import os
import re
import numpy as np
import pandas as pd

# ======================== Configuration ========================

# Input: per-experiment evaluation summary (generated by eval_all_experiments.ipynb).
INPUT_CSV = "eval_summary_all_experiments.csv"

# Output: the final submission workbook.
OUTPUT_XLSX = "测试表格输出.xlsx"

# Experiment family (the CSV "family" column) -> benchmark name used in the submission.
FAMILY_TO_BENCHMARK = dict(
    SKIPPD="SKIPPD",
    CSG="StateGrid",
    GEFCOM="GEFCom2014",
)

# Order in which models appear in the final table.
MODEL_ORDER = "DLinear PatchTST iTransformer TCN Transformer GBDT TimesFM".split()

# Order in which benchmarks appear in the final table.
BENCHMARK_ORDER = "SKIPPD StateGrid GEFCom2014".split()

print("配置已加载，准备读取 CSV:", os.path.abspath(INPUT_CSV))

配置已加载，准备读取 CSV: /home/huangyx/workspace/solar-energy/eval_summary_all_experiments.csv


In [8]:
def compute_station_id(row):
    """
    Infer the integer station id ("电站") of one evaluation record.

    The id is derived from ``family``, ``dataset_site_base``, ``site_for_csv``
    and ``model``:

      - SKIPPD: single-site benchmark, always 1.
      - CSG:    ``dataset_site_base`` looks like ``CSGS3`` / ``CSGS3_S`` /
                ``CSGS3_MS``; the digits following the leading ``CSGS`` of the
                part before the first underscore are the id (3 here).
      - GEFCOM:
          * non-TimesFM (multi-channel) models: ``site_for_csv`` looks like
            ``GEFCOM_TASK15_ch3``; the digits after the final ``_ch`` are the id.
          * TimesFM (single-channel special case): ``dataset_site_base`` looks
            like ``GEFCOM_TASK151`` / ``152`` / ``153``; the last digit is the id.

    Parameters
    ----------
    row : mapping-like (e.g. a pandas row Series)
        Must support ``row["family"]`` and ``row["dataset_site_base"]``;
        ``site_for_csv`` and ``model`` are read with ``row.get``.

    Returns
    -------
    int
        The station id.

    Raises
    ------
    ValueError
        If the family is unknown or the id cannot be parsed. Missing or
        non-string names (e.g. NaN read from the CSV) also raise ValueError.
    """
    family = row["family"]
    base = row["dataset_site_base"]
    site_for_csv = row.get("site_for_csv", "")
    model = row.get("model", "")

    # SKIPPD: a single site, always station 1.
    if family == "SKIPPD":
        return 1

    # CSG: CSGS1 / CSGS1_S / CSGS1_MS
    if family == "CSG":
        # Non-string names (NaN) are reported as ValueError like every other
        # parse failure instead of surfacing as an AttributeError.
        core = base.split("_")[0] if isinstance(base, str) else ""  # CSGS3_S -> CSGS3
        if not core.startswith("CSGS"):
            raise ValueError(f"CSG 家族站点名不符合预期: {base}")
        # Strip only the leading prefix; str.replace would remove every "CSGS".
        num_str = core[len("CSGS"):]
        if not re.fullmatch(r"[0-9]+", num_str):
            raise ValueError(f"CSG 家族站点编号无法解析: {base}")
        return int(num_str)

    # GEFCOM: TimesFM and the other models encode the station differently.
    if family == "GEFCOM":
        # 1) Non-TimesFM multi-channel models: take the digits after the final "_ch".
        #    The anchored ASCII-digit match rejects "_ch-1" / "_ch 3", which int() would accept.
        if model != "TimesFM":
            m = re.search(r"_ch([0-9]+)\Z", site_for_csv) if isinstance(site_for_csv, str) else None
            if m:
                return int(m.group(1))
            raise ValueError(f"GEFCOM(非 TimesFM) 家族站点编号无法解析: dataset_site_base={base}, site_for_csv={site_for_csv}")

        # 2) TimesFM special case: dataset_site_base like 'GEFCOM_TASK151' / '152' / '153';
        #    the last digit is the station id.
        m = re.match(r"^GEFCOM_TASK(\d+)(\d)$", base) if isinstance(base, str) else None
        if m:
            return int(m.group(2))

        raise ValueError(f"GEFCOM(TimesFM) 家族站点编号无法解析: dataset_site_base={base}, site_for_csv={site_for_csv}")

    # Any other family is unsupported.
    raise ValueError(f"未知 family，无法解析电站编号: {family}")


def csg_name_priority(dataset_site_base):
    """
    Rank the naming variants of a CSG site; a lower rank is preferred.

      - ``CSGSx``    (no underscore)               -> 0
      - ``CSGSx_S``                                 -> 1
      - ``CSGSx_MS``                                -> 2
      - anything else after the first underscore    -> 99

    Non-CSG families are handled by the caller (which can simply use 0).
    """
    if "_" not in dataset_site_base:
        return 0
    _, _, tail = dataset_site_base.partition("_")
    return {"S": 1, "MS": 2}.get(tail, 99)


def pick_best_row(group):
    """
    Choose the single "best" record among rows sharing
    (model, family, station_id, track_type).

    Selection rules:
      - CSG family: prefer the naming variant CSGSx over CSGSx_S over CSGSx_MS
        (via ``csg_name_priority``); other families treat all names equally.
      - Among rows with a non-NaN ``acc_rmse``, take the best name priority,
        then the largest ``acc_rmse``.
      - If every ``acc_rmse`` is NaN, apply the same rule to ``acc_mae``.
      - If both metrics are NaN everywhere, fall back to name priority alone.

    The returned row carries an extra ``name_priority`` column and keeps its
    original index label as ``.name``.
    """
    is_csg = group["family"].iloc[0] == "CSG"
    ranked = group.assign(
        name_priority=group["dataset_site_base"].apply(csg_name_priority) if is_csg else 0
    )

    # Try the metrics in order of preference; the first one with data decides.
    for metric in ("acc_rmse", "acc_mae"):
        candidates = ranked[ranked[metric].notna()]
        if not candidates.empty:
            ordered = candidates.sort_values(
                by=["name_priority", metric],
                ascending=[True, False],
            )
            return ordered.iloc[0]

    # No usable metric at all: only the naming preference is left.
    return ranked.sort_values(by=["name_priority"]).iloc[0]


print("辅助函数已定义。")

辅助函数已定义。


In [9]:
# 1. Load the evaluation summary CSV.
csv_path = os.path.abspath(INPUT_CSV)
if not os.path.exists(csv_path):
    raise FileNotFoundError(f"找不到输入 CSV 文件: {csv_path}")

df = pd.read_csv(csv_path)
print("CSV 已读取，行数:", len(df))

# Fail fast if any column the pipeline depends on is absent.
required_cols = [
    "model", "family", "dataset_site_base",
    "track_type", "site_for_csv",
    "acc_mae", "acc_rmse",
]
missing = [c for c in required_cols if c not in df.columns]
if missing:
    raise ValueError(f"CSV 中缺少必要列: {missing}")

# 1.1 Normalise model names (e.g. tcn -> TCN) so they line up with MODEL_ORDER.
model_rename_map = {
    "tcn": "TCN",
}
df["model"] = df["model"].replace(model_rename_map)

# 2. Derive the integer station id of every record.
df["station_id"] = df.apply(compute_station_id, axis=1)

# 3. Keep exactly one "best" record per (model, family, station_id, track_type).
#    Iterate the groups explicitly instead of using groupby(...).apply(...).
#    pick_best_row reads the grouping column "family", and pandas has deprecated
#    passing grouping columns into apply; they are excluded in newer versions.
#    pick_best_row returns a row whose .name is its original index label, so
#    selecting those labels from df (RangeIndex from read_csv) keeps the dtypes.
group_cols = ["model", "family", "station_id", "track_type"]
best_labels = [pick_best_row(group).name for _, group in df.groupby(group_cols)]
best_rows = df.loc[best_labels].reset_index(drop=True)

print("筛选最佳记录后行数:", len(best_rows))

# 4. Keep only the columns the final table needs.
base_cols = ["model", "family", "station_id", "track_type", "acc_mae", "acc_rmse"]
df_best = best_rows[base_cols].copy()

# 5. Pivot each metric so every track_type (one-step / 4h / 72h) becomes a column.
pivot_mae = df_best.pivot_table(
    index=["model", "family", "station_id"],
    columns="track_type",
    values="acc_mae",
    aggfunc="first",
)
pivot_rmse = df_best.pivot_table(
    index=["model", "family", "station_id"],
    columns="track_type",
    values="acc_rmse",
    aggfunc="first",
)

# 6. Restrict each pivot to the known track types, then rename them to the target Excel headers.
#    reindex adds any track type that was never run as an all-NaN column.
#    It also drops unexpected track types, which would otherwise appear in both
#    pivots under the same name and make the join below fail with overlapping columns.
mae_col_map = {
    "one-step": "短临功率 1-MAE/Cap",
    "4h": "超短期功率 1-MAE/Cap",
    "72h": "短期功率 1-MAE/Cap",
}
pivot_mae = pivot_mae.reindex(columns=list(mae_col_map)).rename(columns=mae_col_map)

rmse_col_map = {
    "one-step": "短临功率 1-RMSE/Cap",
    "4h": "超短期功率 1-RMSE/Cap",
    "72h": "短期功率 1-RMSE/Cap",
}
pivot_rmse = pivot_rmse.reindex(columns=list(rmse_col_map)).rename(columns=rmse_col_map)

# 7. Put MAE and RMSE side by side.
final_df = pivot_mae.join(pivot_rmse, how="outer")

# 8. Turn the index back into columns (model, family, station_id, ...).
#    Also drop the leftover "track_type" column-axis name, which otherwise
#    shows up as a header label when the table is displayed.
final_df = final_df.reset_index().rename_axis(columns=None)

# 9. Map family -> benchmark name.
final_df["基准"] = final_df["family"].map(FAMILY_TO_BENCHMARK)

# 10. Expose model / station under the submission column names.
final_df["模型"] = final_df["model"]
final_df["电站"] = final_df["station_id"].astype(int)

# 11. Final column order.
desired_cols = [
    "模型",
    "基准",
    "电站",
    "短临功率 1-MAE/Cap",
    "短临功率 1-RMSE/Cap",
    "超短期功率 1-MAE/Cap",
    "超短期功率 1-RMSE/Cap",
    "短期功率 1-MAE/Cap",
    "短期功率 1-RMSE/Cap",
]

# Defensive: guarantee every expected column exists (NaN when a track was never run).
for col in desired_cols:
    if col not in final_df.columns:
        final_df[col] = np.nan

# Take an explicit copy so the in-place scaling below does not write into a slice.
final_df = final_df[desired_cols].copy()

# 12. Scale all accuracy columns by 100 (percentage values); headers stay the same.
acc_cols = [
    "短临功率 1-MAE/Cap",
    "短临功率 1-RMSE/Cap",
    "超短期功率 1-MAE/Cap",
    "超短期功率 1-RMSE/Cap",
    "短期功率 1-MAE/Cap",
    "短期功率 1-RMSE/Cap",
]
final_df[acc_cols] = final_df[acc_cols] * 100.0

# 13. Sort by model / benchmark using the configured orders.
#     pd.Categorical turns any value outside `categories` into NaN, which would
#     silently blank the model/benchmark name in the export. So unlisted values
#     are appended after the configured ones instead of being lost.
model_extra = sorted(set(final_df["模型"].dropna()) - set(MODEL_ORDER), key=str)
if model_extra:
    print("警告: 以下模型不在 MODEL_ORDER 中，已追加到排序末尾:", model_extra)
benchmark_extra = sorted(set(final_df["基准"].dropna()) - set(BENCHMARK_ORDER), key=str)
if benchmark_extra:
    print("警告: 以下基准不在 BENCHMARK_ORDER 中，已追加到排序末尾:", benchmark_extra)

final_df["模型"] = pd.Categorical(
    final_df["模型"], categories=list(MODEL_ORDER) + model_extra, ordered=True
)
final_df["基准"] = pd.Categorical(
    final_df["基准"], categories=list(BENCHMARK_ORDER) + benchmark_extra, ordered=True
)

final_df = final_df.sort_values(by=["模型", "基准", "电站"]).reset_index(drop=True)

# 14. Export to Excel and show the table (display() is provided by IPython/Jupyter).
xlsx_path = os.path.abspath(OUTPUT_XLSX)
final_df.to_excel(xlsx_path, index=False)
print("最终提交表格已生成:", xlsx_path)
display(final_df)

CSV 已读取，行数: 129
筛选最佳记录后行数: 121
最终提交表格已生成: /home/huangyx/workspace/solar-energy/final_submission.xlsx


  best_rows = df.groupby(group_cols, as_index=False).apply(pick_best_row)


track_type,模型,基准,电站,短临功率 1-MAE/Cap,短临功率 1-RMSE/Cap,超短期功率 1-MAE/Cap,超短期功率 1-RMSE/Cap,短期功率 1-MAE/Cap,短期功率 1-RMSE/Cap
0,PatchTST,GEFCom2014,1,96.06972,93.213531,93.23263,88.689584,92.78531,87.371379
1,PatchTST,GEFCom2014,2,95.51072,91.902919,93.06673,88.476174,92.46599,87.017045
2,PatchTST,GEFCom2014,3,95.984507,93.241313,93.36565,89.017938,92.78394,87.819098
3,TCN,SKIPPD,1,97.94623,94.595328,94.89479,90.54477,92.38199,88.413175
4,TCN,StateGrid,1,98.38449,96.198114,93.934065,89.785091,93.21617,88.920217
5,TCN,StateGrid,2,98.56702,96.850208,95.6128,91.838405,94.3749,90.796547
6,TCN,StateGrid,3,97.2266,93.934005,93.14077,88.050254,91.87381,87.042337
7,TCN,StateGrid,4,98.350435,96.04368,95.047617,90.593769,92.442226,86.81115
8,TCN,StateGrid,5,98.10843,95.445628,94.571704,89.838113,92.45552,87.34469
9,TCN,StateGrid,6,98.049194,95.729085,94.698507,90.409311,92.25295,87.559427
