In [1]:
import os
import glob
import re
import pandas as pd
from pandas.api.types import CategoricalDtype
import numpy as np
from datetime import datetime

In [6]:
def get_log(folder_path):
    """获取文件夹中最新的日志文件"""
    log_files = glob.glob(os.path.join(folder_path, '*.log'))
    return max(log_files, key=os.path.getmtime) if log_files else None

def read_log(log_path):
    """从日志中读取 Average 行的指标"""
    with open(log_path) as f:
        for line in reversed(f.readlines()):
            if " - Average:" in line:
                return line[31:-1]
    return None

def seconds_between(t1: str, t2: str) -> int:
    """计算两个时间字符串之间的秒数差"""
    fmt = "%Y-%m-%d %H:%M:%S"
    return int(abs((datetime.strptime(t2, fmt) - datetime.strptime(t1, fmt)).total_seconds()))

def get_time(log_path):
    """获取训练总时间（从 Data shape 到 Average）"""
    with open(log_path) as f:
        t1 = None
        for line in f:
            if " - Data shape:" in line:
                t1 = line[:19]
            elif " - Average:" in line and t1:
                return seconds_between(t1, line[:19])
    return None

def get_avg_time(log_path):
    """获取每个 epoch 的平均训练时间"""
    with open(log_path) as f:
        epoch, t1 = 1, None
        for line in f:
            if " - Epoch: " in line:
                epoch = int(line.split(',')[0].split()[-1])
            elif " - Data shape:" in line:
                t1 = line[:19]
            elif " - Average:" in line and t1:
                return f"{seconds_between(t1, line[:19]) / epoch:.2f}"
    return None

def get_parameter(log_path):
    """获取模型参数量"""
    with open(log_path) as f:
        for line in f:
            if "Parameters:" in line:
                return line.split()[-1]
    return 0

def collect_results(names, datasets, metrics, path):
    """收集所有模型在所有数据集上的结果"""
    rows = []
    for name in names:
        for dataset in datasets:
            log = get_log(f"{path}/{name}/{dataset}")
            if not log:
                continue
            res = read_log(log)
            if res is None:
                continue
            
            row = {'Dataset': dataset, 'Model': name}
            m = {k: v for k, v in re.findall(r'(\w+): ([\-\d\.]+)', res) if k in metrics}
            row.update(m)
            row.update({'time': get_time(log), 'param': get_parameter(log)})
            rows.append(row)
    return rows

In [None]:
names=["HL","LSTM","Transformer","Mamba","UMamba","STGCN","GWNET","ASTGCN","AGCRN","STTN","DGCRN","DCRNN"]
metrics=["MAE","RMSE","MAPE"]
datasets=["panhandle"]
cat_type = CategoricalDtype(categories=names, ordered=True)

In [4]:
names=["ARIMA","SARIMA"]
metrics=["MSE","MAE"]
datasets=["sz_taxi_od","sz_bike_od","sz_subway_od",] #
cat_type = CategoricalDtype(categories=names, ordered=True)

In [5]:
names=["HA_OD","HL_OD","ARIMA","SARIMA","LSTM_OD","GMEL","GWNET_OD","STGCN_OD","HMDLF","MPGCN_OD","STZINB","STTN","AGCRN_OD","ASTGCN_OD","STGODE_OD","ODMixer"]
metrics=["MSE","MAE"]
datasets=["sz_taxi_od","sz_bike_od","sz_subway_od",] #
cat_type = CategoricalDtype(categories=names, ordered=True)

In [8]:
names=["HA_OD","HL_OD","ARIMA","SARIMA","LSTM_OD","GMEL","GWNET_OD","STGCN_OD","HMDLF","MPGCN_OD","STZINB","STTN","AGCRN_OD","ASTGCN_OD","STGODE_OD","ODMixer"]
metrics=["MSE","MAE"]
datasets=["nyc_taxi_od","nyc_bike_od","nyc_subway_od",] #
cat_type = CategoricalDtype(categories=names, ordered=True)

In [7]:
names=["GWNET_OD","STZINB","AGCRN_OD"]
metrics=["MSE","MAE"]
datasets=["sz_subway_bike_od","sz_subway_taxi_od"] #
cat_type = CategoricalDtype(categories=names, ordered=True)

In [None]:
names=["GWNET_OD","STZINB","AGCRN_OD"]
metrics=["MSE","MAE"]
datasets=["nyc_subway_bike_od","sz_subway_taxi_od"] #
cat_type = CategoricalDtype(categories=names, ordered=True)

In [None]:
names=["HL","STGCN","GWNET","ASTGCN","AGCRN","STGODE","STTN","DCRNN","DSTAGNN","LSTM","TrustEnergy"]
metrics=["MAE","RMSE","MAPE","MPIW","WINK","COV"]
datasets=["panhandle"]
cat_type = CategoricalDtype(categories=names, ordered=True)

In [4]:
path="/home/dy23a.fsu/st/result"

In [6]:
path="/home/dy23a.fsu/st/result/sz"

In [9]:
path="/home/dy23a.fsu/st/result/nyc"

In [4]:
path="/home/dy23a.fsu/st/result/ph"

In [7]:
# 收集结果
rows = collect_results(names, datasets, metrics, path)
df = pd.DataFrame(rows)
df['Model'] = df['Model'].astype(cat_type)
df_sorted = df.sort_values(by=['Dataset', 'Model'])
print(df_sorted)

     Dataset        Model    MAE     MAPE    RMSE   time    param
0  panhandle           HL  0.593   63.875   4.230     61       24
1  panhandle         LSTM  0.676   68.764   4.585    224    93443
2  panhandle  Transformer  1.090   72.484  11.119    569    70444
3  panhandle        STGCN  1.185   78.364  11.331    203   277684
4  panhandle        GWNET  1.455   99.943  11.839    217   315356
5  panhandle       ASTGCN  1.717  106.162  11.863    102  3528490
6  panhandle        AGCRN  2.290  134.737  11.969    143   766500
7  panhandle         STTN  1.232   85.889  10.817    706   126764
8  panhandle        DGCRN  0.566   59.276   3.644   1574   307156
9  panhandle        DCRNN  0.683   63.398   4.784  44808    25220


In [None]:
def pivot_metrics(df, prefix="sz_", metrics_cols=["MSE", "MAE"]):
    """将结果表格透视为按数据集分列的格式"""
    df = df.copy()
    df["Dataset"] = df["Dataset"].str.replace(prefix, "").str.replace("_od", "")
    df[metrics_cols] = df[metrics_cols].apply(pd.to_numeric)
    
    df_pivot = df.pivot_table(index="Model", columns="Dataset", values=metrics_cols, observed=False)
    df_pivot.columns = [f"{ds}_{metric}" for metric, ds in df_pivot.columns]
    return df_pivot.reset_index()

# 透视指标结果
df_metrics = pivot_metrics(df_sorted, prefix="sz_", metrics_cols=["MSE", "MAE"])
col_order = ["Model", "taxi_MSE", "taxi_MAE", "bike_MSE", "bike_MAE", "subway_MSE", "subway_MAE"]
df_metrics = df_metrics[[c for c in col_order if c in df_metrics.columns]]
print(df_metrics)

KeyError: "['MSE'] not in index"

In [None]:
# 透视时间和参数结果
df_time = df_sorted.copy()
df_time["Dataset"] = df_time["Dataset"].str.replace("sz_", "").str.replace("_od", "")
df_time[["time", "param"]] = df_time[["time", "param"]].apply(pd.to_numeric)

df_time_pivot = df_time.pivot_table(index="Model", columns="Dataset", values=["time", "param"], observed=False)
df_time_pivot.columns = [f"{ds}_{metric}" for metric, ds in df_time_pivot.columns]
df_time_pivot = df_time_pivot.reset_index()

col_order = ["Model", "taxi_param", "taxi_time", "bike_time", "subway_time"]
df_time_pivot = df_time_pivot[[c for c in col_order if c in df_time_pivot.columns]]
print(df_time_pivot)

        Model  taxi_param  taxi_time  bike_time  subway_time
0       HA_OD         0.0        6.0        6.0          5.0
1       HL_OD         7.0     4445.0     4327.0       2764.0
2       ARIMA         0.0    17533.0    20763.0      23037.0
3      SARIMA         0.0     9563.0    10819.0      11731.0
4     LSTM_OD     21473.0     5294.0     1512.0       1959.0
5        GMEL    330241.0     7129.0     3272.0       1677.0
6    GWNET_OD    230599.0     1368.0      514.0       1413.0
7    STGCN_OD    564843.0     1266.0      620.0        550.0
8       HMDLF   2118772.0     1514.0     1676.0       1203.0
9    MPGCN_OD      1154.0   121081.0   114939.0     120955.0
10     STZINB    213724.0     1940.0     1239.0       3540.0
11   AGCRN_OD    829345.0     1621.0     1892.0       2847.0
12  ASTGCN_OD   1598610.0     1040.0      742.0        890.0
13    ODMixer  23243795.0      926.0      249.0        831.0
