In [1]:
import os
import glob
import re
import pandas as pd
from pandas.api.types import CategoricalDtype
import numpy as np
from datetime import datetime

In [2]:
def get_log(folder_path):
    """Return the most recently modified ``*.log`` file inside ``folder_path``.

    Recency is judged by ``os.path.getmtime``. Returns None when the folder
    contains no ``.log`` files (this includes a folder that does not exist,
    since ``glob`` then yields nothing).
    """
    candidates = glob.glob(os.path.join(folder_path, '*.log'))
    if len(candidates) == 0:
        return None
    newest = max(candidates, key=os.path.getmtime)
    return newest

def read_log(log_path):
    """Return the payload of the last line containing ``" - Average:"``.

    The first 31 characters of that line are dropped (presumably the
    timestamp/level prefix written by the logger -- TODO confirm against the
    log format), as is the trailing newline.

    Returns None when no such line exists.
    """
    with open(log_path) as f:
        lines = f.readlines()
    # Scan from the end so the most recent summary wins.
    for line in reversed(lines):
        if " - Average:" in line:
            # Strip only the line terminator. The old ``[:-1]`` slice removed
            # the last character unconditionally, which truncated the final
            # metric value when the file's last line had no trailing newline.
            return line[31:].rstrip("\n")
    return None

def seconds_between(t1: str, t2: str) -> int:
    """Return the absolute number of whole seconds separating two timestamps.

    Both arguments must be formatted as ``YYYY-MM-DD HH:MM:SS``; any other
    string makes ``datetime.strptime`` raise ``ValueError``. The result is
    symmetric in its arguments.
    """
    layout = "%Y-%m-%d %H:%M:%S"
    start, end = (datetime.strptime(stamp, layout) for stamp in (t1, t2))
    elapsed = end - start
    return int(abs(elapsed.total_seconds()))

def get_time(log_path):
    """Return seconds between the last ``" - Data shape:"`` line and the first
    ``" - Average:"`` line of ``log_path``.

    Timestamps are the first 19 characters of each line and are compared with
    ``seconds_between``. Returns None if no ``" - Average:"`` line exists.
    ``seconds_between`` raises ValueError if the first ``" - Average:"`` line
    comes before any ``" - Data shape:"`` line, because the start stamp is
    then still ``''``.
    """
    start = ''
    with open(log_path) as f:
        for line in f.readlines():
            stamp = line[:19]
            if " - Data shape:" in line:
                start = stamp
            elif " - Average:" in line:
                return seconds_between(start, stamp)

def get_avg_time(log_path):
    """Return the ``get_time``-style duration divided by the last seen epoch
    number, formatted as a string with two decimals.

    The epoch number is parsed from the most recent ``" - Epoch: "`` line
    before the ``" - Average:"`` line (default 1 if none). It is taken as the
    last whitespace token before the first comma; this assumes the timestamp
    prefix contains no comma -- TODO confirm log format.

    Returns None if no ``" - Average:"`` line exists.
    """
    n_epochs = 1
    start = ''
    with open(log_path) as f:
        for line in f.readlines():
            if " - Epoch: " in line:
                n_epochs = int(line.split(',')[0].split()[-1])
            if " - Data shape:" in line:
                start = line[:19]
            elif " - Average:" in line:
                elapsed = seconds_between(start, line[:19])
                return "%.2f" % (elapsed / n_epochs)
            

def get_parameter(log_path):
    """Return the parameter count reported in ``log_path``.

    Finds the first line containing ``"The number of parameters:"`` and
    returns its last whitespace-separated token as a *string*. Returns the
    integer 0 when no such line exists; downstream cells normalise both
    cases with ``pd.to_numeric``.
    """
    marker = "The number of parameters:"
    with open(log_path) as f:
        lines = f.readlines()
    hit = next((ln for ln in lines if marker in ln), None)
    return 0 if hit is None else hit.split()[-1]

In [3]:
# Panhandle dataset: models to collect (list order is also the sort order)
# and the metric names kept from each log's " - Average:" line.
names = [
    "HL", "STGCN", "GWNET", "ASTGCN", "AGCRN",
    "STGODE", "STTN", "DCRNN", "DSTAGNN", "LSTM",
]
metrics = ["MAE", "RMSE", "MAPE", "MPIW", "WINK", "COV"]
datasets = ["panhandle"]
# Ordered categorical so sorting by Model follows the order of `names`.
cat_type = CategoricalDtype(names, ordered=True)

In [4]:
# ARIMA / SARIMA runs on the three sz OD datasets, MSE and MAE only.
names = ["ARIMA", "SARIMA"]
metrics = ["MSE", "MAE"]
datasets = ["sz_taxi_od", "sz_bike_od", "sz_subway_od"]
# Ordered categorical so sorting by Model follows the order of `names`.
cat_type = CategoricalDtype(names, ordered=True)

In [5]:
# Full OD model line-up on the three sz OD datasets, MSE and MAE only.
names = [
    "HA_OD", "HL_OD", "ARIMA", "SARIMA", "LSTM_OD", "GMEL",
    "GWNET_OD", "STGCN_OD", "HMDLF", "MPGCN_OD", "STZINB", "STTN",
    "AGCRN_OD", "ASTGCN_OD", "STGODE_OD", "ODMixer",
]
metrics = ["MSE", "MAE"]
datasets = ["sz_taxi_od", "sz_bike_od", "sz_subway_od"]
# Ordered categorical so sorting by Model follows the order of `names`.
cat_type = CategoricalDtype(names, ordered=True)

In [8]:
# Full OD model line-up on the three nyc OD datasets, MSE and MAE only.
names = [
    "HA_OD", "HL_OD", "ARIMA", "SARIMA", "LSTM_OD", "GMEL",
    "GWNET_OD", "STGCN_OD", "HMDLF", "MPGCN_OD", "STZINB", "STTN",
    "AGCRN_OD", "ASTGCN_OD", "STGODE_OD", "ODMixer",
]
metrics = ["MSE", "MAE"]
datasets = ["nyc_taxi_od", "nyc_bike_od", "nyc_subway_od"]
# Ordered categorical so sorting by Model follows the order of `names`.
cat_type = CategoricalDtype(names, ordered=True)

In [7]:
# Three OD models on the two sz subway-paired datasets, MSE and MAE only.
names = ["GWNET_OD", "STZINB", "AGCRN_OD"]
metrics = ["MSE", "MAE"]
datasets = ["sz_subway_bike_od", "sz_subway_taxi_od"]
# Ordered categorical so sorting by Model follows the order of `names`.
cat_type = CategoricalDtype(names, ordered=True)

In [None]:
# Three OD models on the two nyc subway-paired datasets, MSE and MAE only.
names = ["GWNET_OD", "STZINB", "AGCRN_OD"]
metrics = ["MSE", "MAE"]
# Both entries use the "nyc_" prefix; the second previously read
# "sz_subway_taxi_od", a copy-paste leftover from the sz cell.
datasets = ["nyc_subway_bike_od", "nyc_subway_taxi_od"]
# Ordered categorical so sorting by Model follows the order of `names`.
cat_type = CategoricalDtype(names, ordered=True)

In [None]:
# Panhandle dataset: the baseline line-up plus TrustEnergy (list order is
# also the sort order), and the metric names kept from each log.
names = [
    "HL", "STGCN", "GWNET", "ASTGCN", "AGCRN", "STGODE",
    "STTN", "DCRNN", "DSTAGNN", "LSTM", "TrustEnergy",
]
metrics = ["MAE", "RMSE", "MAPE", "MPIW", "WINK", "COV"]
datasets = ["panhandle"]
# Ordered categorical so sorting by Model follows the order of `names`.
cat_type = CategoricalDtype(names, ordered=True)

In [8]:
path = "/home/dy23a.fsu/st/result/cross"  # results root; the last path cell executed wins

In [6]:
path = "/home/dy23a.fsu/st/result/sz"  # results root; the last path cell executed wins

In [6]:
path = "/blue/gtyson.fsu/dy23a.fsu/result/result/"  # results root; the last path cell executed wins

In [9]:
path = "/home/dy23a.fsu/st/result/nyc"  # results root; the last path cell executed wins

In [4]:
path = "/home/dy23a.fsu/st/result/ph"  # results root; the last path cell executed wins

In [5]:
# Build one row per (model, dataset) pair from the newest log found under
# <path>/<model>/<dataset>/:
#   - the metrics listed in `metrics`, parsed from the last " - Average:" line;
#   - the elapsed time;
#   - the parameter count.
rows = []
for name in names:
    for dataset in datasets:
        log = get_log(f"{path}/{name}/{dataset}")
        if log is None:
            continue  # no log for this combination
        res = read_log(log)
        if res is None:
            # No " - Average:" summary (presumably an unfinished run): skip it
            # before re-reading the file for timing and parameter count.
            continue
        parsed = dict(re.findall(r'(\w+): ([\-\d\.]+)', res))
        row = {'Dataset': dataset, 'Model': name}
        # Keep only the requested metrics, in the order they appear in the log.
        row.update({key: val for key, val in parsed.items() if key in metrics})
        row.update({'time': get_time(log), 'param': get_parameter(log)})
        rows.append(row)

# pd.DataFrame([]) has no 'Model' column, so an empty result would raise a
# KeyError below; start from an explicit empty frame instead.
df = pd.DataFrame(rows) if rows else pd.DataFrame(columns=['Dataset', 'Model'])
# Ordered categorical, so sorting by Model follows the order of `names`.
df['Model'] = df['Model'].astype(cat_type)
df_sorted = df.sort_values(by=['Dataset', 'Model'])
df_sorted


     Dataset   Model    MAE    MAPE    RMSE  time    param
0  panhandle      HL  0.571  60.085   4.049    91        8
1  panhandle   STGCN  1.108  69.067  11.179    94   277684
2  panhandle   GWNET  0.533  60.515   3.335   221   311252
3  panhandle  ASTGCN  1.267  63.317  11.095   137  3524898
4  panhandle   AGCRN  0.632  60.391   4.227   382   765980
5  panhandle    STTN  0.547  59.155   3.479   443   122660
6  panhandle   DCRNN  0.580  56.809   3.939    23    25220
7  panhandle    LSTM  0.613  66.287   3.960   343    92417


In [8]:
# Wide MSE/MAE table: one row per model, one "<dataset>_<metric>" column each.
# Shorten dataset ids to bike/taxi/subway. The anchored regex strips either
# the "sz_" or the "nyc_" prefix plus the "_od" suffix, so the same cell
# works for both cities without commenting lines in and out.
df_sorted["Dataset"] = df_sorted["Dataset"].str.replace(r"^(?:sz|nyc)_|_od$", "", regex=True)
df_sorted[["MSE", "MAE"]] = df_sorted[["MSE", "MAE"]].apply(pd.to_numeric)
df_pivot = df_sorted.pivot_table(
    index="Model",
    columns="Dataset",
    values=["MSE", "MAE"],
    observed=False,
)
# Flatten the (metric, dataset) column MultiIndex to "<dataset>_<metric>".
df_pivot.columns = [f"{ds}_{metric}" for metric, ds in df_pivot.columns]
df_pivot = df_pivot.reset_index()
col_order = ["Model", "taxi_MSE", "taxi_MAE", "bike_MSE", "bike_MAE", "subway_MSE", "subway_MAE"]
df_pivot = df_pivot[col_order]
df_pivot

        Model  taxi_MSE  taxi_MAE  bike_MSE  bike_MAE  subway_MSE  subway_MAE
0       HA_OD     1.041     0.135     7.176     0.106      20.451       0.217
1       HL_OD     0.390     0.114     2.733     0.166       9.329       0.247
2       ARIMA     0.536     0.099     3.199     0.071      10.522       0.134
3      SARIMA     1.059     0.112     5.958     0.092      10.920       0.124
4     LSTM_OD     1.005     0.179     6.172     0.176      10.124       0.208
5        GMEL     0.775     0.202     6.336     0.279      10.590       0.259
6    GWNET_OD     0.237     0.097     0.819     0.135       1.345       0.078
7    STGCN_OD     0.296     0.110     2.736     0.187       5.882       0.194
8       HMDLF     1.442     0.225     6.561     0.217      10.897       0.245
9    MPGCN_OD     0.409     0.130     2.360     0.138       8.641       0.251
10     STZINB     0.259     0.120     0.885     0.129       2.141       0.151
11   AGCRN_OD     0.264     0.113     0.972     0.164       1.24

In [9]:
# Wide cost table: elapsed time per dataset and the parameter count per model.
# Shorten dataset ids to bike/taxi/subway. The anchored regex strips either
# the "sz_" or the "nyc_" prefix plus the "_od" suffix, so the same cell
# works for both cities.
df_sorted["Dataset"] = df_sorted["Dataset"].str.replace(r"^(?:sz|nyc)_|_od$", "", regex=True)
df_sorted[["time", "param"]] = df_sorted[["time", "param"]].apply(pd.to_numeric)
df_pivot = df_sorted.pivot_table(
    index="Model",
    columns="Dataset",
    values=["time", "param"],
    observed=False,
)
# Flatten the (metric, dataset) column MultiIndex to "<dataset>_<metric>".
df_pivot.columns = [f"{ds}_{metric}" for metric, ds in df_pivot.columns]
df_pivot = df_pivot.reset_index()
# The parameter count is reported for the taxi dataset only.
col_order = ["Model", "taxi_param", "taxi_time", "bike_time", "subway_time"]
df_pivot = df_pivot[col_order]
df_pivot

        Model  taxi_param  taxi_time  bike_time  subway_time
0       HA_OD         0.0        6.0        6.0          5.0
1       HL_OD         7.0     4445.0     4327.0       2764.0
2       ARIMA         0.0    17533.0    20763.0      23037.0
3      SARIMA         0.0     9563.0    10819.0      11731.0
4     LSTM_OD     21473.0     5294.0     1512.0       1959.0
5        GMEL    330241.0     7129.0     3272.0       1677.0
6    GWNET_OD    230599.0     1368.0      514.0       1413.0
7    STGCN_OD    564843.0     1266.0      620.0        550.0
8       HMDLF   2118772.0     1514.0     1676.0       1203.0
9    MPGCN_OD      1154.0   121081.0   114939.0     120955.0
10     STZINB    213724.0     1940.0     1239.0       3540.0
11   AGCRN_OD    829345.0     1621.0     1892.0       2847.0
12  ASTGCN_OD   1598610.0     1040.0      742.0        890.0
13    ODMixer  23243795.0      926.0      249.0        831.0
