In [3]:
import pandas as pd
import numpy as np
import os
import re


In [4]:
def get_sota_df(result_dict, model_name_list, datasets):
    new_mse_df = {}
    new_mae_df = {}
    for pred_len, result in result_dict.items():
        df = pd.DataFrame(result)
        if df.empty:
            continue
        df = df.set_index('setting')

        # 遍历model_list，如果模型在mse_df的索引中，就将对应的MSE值放入model_dataset_mse中
        model_dataset_mse = pd.DataFrame(index=model_name_list, columns=datasets)
        model_dataset_mae = pd.DataFrame(index=model_name_list, columns=datasets)
        for setting, row in df.iterrows():
            model_name = row['model_name']
            dataset = row['dataset']
            mse = row['mse']
            mae = row['mae']
            if model_name in model_name_list and dataset in datasets:
                try:
                    if model_dataset_mse.loc[model_name, dataset].isna():
                        model_dataset_mse.loc[model_name, dataset] = mse
                        model_dataset_mae.loc[model_name, dataset] = mae
                    else:
                        # mae的值放哪个取决于mse哪个更小
                        if mse < model_dataset_mse.loc[model_name, dataset]:
                            model_dataset_mse.loc[model_name, dataset] = mse
                            model_dataset_mae.loc[model_name, dataset] = mae
                except:
                    model_dataset_mse.loc[model_name, dataset] = mse
                    model_dataset_mae.loc[model_name, dataset] = mae
            else:
                if model_name != 'MambaSimple' and dataset != 'fred-md' and dataset != 'covid-19':
                    print(model_name, setting, row)
        # print(pred_len)
        # display(model_dataset_mse)
        new_mse_df[pred_len] = model_dataset_mse
        new_mae_df[pred_len] = model_dataset_mae
    return new_mse_df, new_mae_df

In [5]:
# 查看目前跑了哪些模型和数据集

# 获取results中，model,dataset对应mse的df表格数据矩阵
datasets = ['ETTh1', 'ETTh2', 'ETTm1', 'ETTm2', 'ili', 'weather', 'ECL', 'Exchange', 'traffic']
# datasets = ['ETTh1', 'ETTh2', 'ili', 'Exchange', 'ETTm1', 'ETTm2', 'weather', 'ECL', 'traffic']
model_name_list =  ['Mamba', 'Autoformer', 'PatchTST', 'DLinear', 'LightTS',  
                    'MICN', 'Koopa', 'FEDformer', 'Reformer', 'SegRNN',
                        'ETSformer','TSMixer',
                        'TimeXer', 'iTransformer', 'Informer', 'Transformer', 'FreTS', 'SCINet', 'PAttn', 'Nonstationary','Pyraformer', 'Crossformer',
                        'TimeMixer', 'FiLM', 'TemporalFusionTransformer','TiDE' , 'TimesNet','DUET']# ,'MambaSimple'
path = '/data/nishome/user1/minqi/TSGym/results_long_term_forecasting/results'
result_dict = {}
for dataset in datasets:
    # dataset = 'ETTh1'
    pred_lens_long = [96, 192, 336, 720]
    pred_lens_short = [24, 36, 48, 60]

    for i, pred_len in enumerate(pred_lens_long):
        model_list = os.listdir(os.path.join(path, dataset))
        if dataset not in ['ili', 'fred-md', 'covid-19']:
            model_list = [_ for _ in model_list if f'pl{pred_len}' in _ ]
        else:
            model_list = [_ for _ in model_list if f'pl{pred_lens_short[i]}' in _ ]
        
        for model in model_list:
            result = np.load(os.path.join(path, dataset, model, 'metrics.npy'), allow_pickle=True)
            if pred_len not in result_dict:
                result_dict[pred_len] = []
            result_dict[pred_len].append({'setting' : model,
                                          'dataset': dataset,
                                          'model_name': model.split('_')[1],
                                          'mse': result[1],
                                         'mae': result[0]})
new_mse_df, new_mae_df = get_sota_df(result_dict, model_name_list, datasets)          

In [6]:
# 旧版命名方式，获取results中，model,dataset对应mse的df表格数据矩阵

path = '/data/nishome/user1/minqi/TSGym/results'
model_dataset_mse = pd.DataFrame(index=model_name_list)
result_dict = {}

for i, pred_len in enumerate(pred_lens_long):
    model_list = os.listdir(path)
    model_list = [_ for _ in model_list if f'pl{pred_len}' in _ ] + [_ for _ in model_list if f'pl{pred_lens_short[i]}' in _ ]
    
    for model in model_list:
        result = np.load(os.path.join(path, model, 'metrics.npy'), allow_pickle=True)
        if pred_len not in result_dict:
            result_dict[pred_len] = []
        result_dict[pred_len].append({'setting' : model,
                                        'dataset': model.split('_')[3],
                                        'model_name': model.split('_')[6],
                                        'mse': result[1],
                                        'mae': result[0],})
old_mse_df, old_mae_df = get_sota_df(result_dict, model_name_list, datasets)          


In [13]:
sota_baseline_list = ['TimeMixer','TSMixer','MICN','TimesNet','PatchTST','Crossformer','SegRNN','Pyraformer','Autoformer']
mean_combined_df = pd.DataFrame()
for pred_len, msedf1 in new_mse_df.items():
    msedf2 = old_mse_df[pred_len]
    maedf1 = new_mae_df[pred_len]
    maedf2 = old_mae_df[pred_len]
    # 处理空值（填充为无穷大以便比较）
    filled_msedf1 = msedf1.fillna(np.inf)
    filled_msedf2 = msedf2.fillna(np.inf)
    
    # 生成布尔掩码：标记哪些位置 msedf1 更小或等于
    mask = (filled_msedf1 <= filled_msedf2)
    # mask1 = (maedf1.fillna(np.inf) <= maedf2.fillna(np.inf))
    # display(mask)
    # display(mask1)
    # assert mask.equals(mask1)
    
    # 合并 MSE：根据掩码选择最小值（优先保留 msedf1）
    combined_min_mse = pd.DataFrame(
        np.where(mask, filled_msedf1, filled_msedf2),
        index=msedf1.index,
        columns=msedf1.columns
    ).replace(np.inf, np.nan)  # 恢复空值
    
    # 合并 MAE：直接根据 MSE 的掩码选择对应位置的 MAE 值
    combined_mae = pd.DataFrame(
        np.where(mask, maedf1, maedf2),
        index=maedf1.index,
        columns=maedf1.columns
    )
    
    # 重命名列名以区分 MSE 和 MAE
    mse_renamed = combined_min_mse.T.rename(columns=lambda x: f"{x}_MSE")
    mae_renamed = combined_mae.T.rename(columns=lambda x: f"{x}_MAE")

    combined_df = pd.concat([mse_renamed, mae_renamed], axis=1)
    # 按模型排序列（可选）
    ordered_columns = []
    for baseline in model_name_list:
    # for baseline in sota_baseline_list:
        ordered_columns.extend([f"{baseline}_MSE", f"{baseline}_MAE"])
    combined_df = combined_df[ordered_columns]
    
    print(pred_len)
    # display(df1)
    # display(df2)
    # print('mse')
    display(combined_min_mse)
    # print('mae')
    # display(combined_mae.loc[sota_baseline_list].T)
    # display(combined_df)
    if mean_combined_df.empty:
        mean_combined_df = combined_df
    else:
        mean_combined_df += combined_df
    # display(mean_combined_df)
    # combined_min_mse.to_excel(f'/data/nishome/user1/minqi/TSGym/trained_sota_{pred_len}.xlsx')
display((mean_combined_df/4).T)
# (mean_combined_df/4).to_excel(f'/data/nishome/user1/minqi/TSGym/mse_sota.xlsx')

192


Unnamed: 0,ETTh1,ETTh2,ETTm1,ETTm2,ili,weather,ECL,Exchange,traffic
Mamba,0.566329,0.457616,0.462869,0.28557,4.249277,0.259068,0.208106,0.306234,0.644102
Autoformer,0.498675,0.433085,0.539125,0.286258,3.098362,0.29681,0.229572,0.290704,0.637168
PatchTST,0.430196,0.378164,0.367827,0.246258,2.399238,0.219723,0.187888,0.192349,0.472138
DLinear,0.445632,0.482442,0.38156,0.288065,4.401989,0.237513,0.210212,0.185506,0.646321
LightTS,0.499411,0.514732,0.403987,0.326096,6.928264,0.214607,0.227033,0.297822,0.637042
MICN,0.515214,0.496378,0.368462,0.269373,2.627363,0.237577,0.173823,0.186831,0.536964
Koopa,0.430503,0.353081,0.345994,0.235075,2.107995,0.19875,0.200836,0.181211,0.567465
FEDformer,0.413864,0.42614,0.425485,0.266113,3.075809,0.29603,0.211628,0.287009,0.606018
Reformer,0.928336,2.545182,0.913632,1.530561,4.081366,0.405499,0.338028,1.566597,0.692696
SegRNN,0.416803,0.370824,0.370063,0.235982,4.380446,0.212149,0.198839,0.18416,0.78154


336


Unnamed: 0,ETTh1,ETTh2,ETTm1,ETTm2,ili,weather,ECL,Exchange,traffic
Mamba,0.518236,0.47085,0.541008,0.367896,2.96717,0.322646,0.198934,0.677297,0.648502
Autoformer,0.503305,0.470436,0.558628,0.342759,2.976865,0.382884,0.247442,0.469892,0.630689
PatchTST,0.471465,0.426993,0.402256,0.309764,2.08455,0.275848,0.203932,0.321728,0.504555
DLinear,0.496432,0.595955,0.413027,0.357913,4.079771,0.281376,0.223083,0.332853,0.653716
LightTS,0.551687,0.666429,0.445469,0.495429,7.18404,0.263368,0.24812,0.482614,0.659934
MICN,0.619621,0.628298,0.423005,0.400819,2.738723,0.27996,0.188078,0.309903,0.546745
Koopa,0.347367,0.244254,0.38651,0.28816,1.960803,0.244921,0.229891,0.044236,0.598684
FEDformer,0.455934,0.455448,0.445165,0.324244,3.022861,0.332809,0.221088,0.460981,0.633603
Reformer,0.949929,2.565673,1.03291,2.113864,4.361001,0.58544,0.348888,1.948923,0.690233
SegRNN,0.451191,0.420508,0.397975,0.294335,4.678205,0.271825,0.217529,0.341544,0.809765


720


Unnamed: 0,ETTh1,ETTh2,ETTm1,ETTm2,ili,weather,ECL,Exchange,traffic
Mamba,0.604278,0.576683,0.637129,0.574507,3.595927,0.390244,0.239907,1.73614,0.720632
Autoformer,0.531472,0.483811,0.523791,0.434019,2.89311,0.436788,0.244337,1.12659,0.66619
PatchTST,0.534888,0.439634,0.460956,0.412885,1.98894,0.351461,0.245772,0.923453,0.538459
DLinear,0.521082,0.841869,0.477598,0.556414,4.221986,0.346673,0.257823,0.773232,0.693719
LightTS,0.621841,0.95648,0.542322,0.680612,7.127642,0.330278,0.281387,1.030808,0.716807
MICN,0.808267,0.83656,0.495261,0.511051,2.975357,0.345312,0.213443,0.793259,0.573777
Koopa,0.586677,0.475733,0.427356,0.353568,1.97609,0.307103,0.259031,1.624465,0.679473
FEDformer,0.520478,0.481572,0.502541,0.421318,3.101076,0.414975,0.272141,1.173514,0.632994
Reformer,1.153885,2.984038,1.149968,3.013526,4.491176,0.53173,0.313552,1.838381,0.693562
SegRNN,0.448591,0.422789,0.453474,0.387259,4.596816,0.356438,0.260572,1.014516,0.859476


96


Unnamed: 0,ETTh1,ETTh2,ETTm1,ETTm2,ili,weather,ECL,Exchange,traffic
Mamba,0.488289,0.356409,0.362367,0.197037,4.103781,0.192429,0.190015,0.1363,0.702669
Autoformer,0.434532,0.396502,0.506288,0.258274,3.657251,0.240484,0.213805,0.137937,0.613174
PatchTST,0.390603,0.295004,0.327204,0.181935,2.168191,0.175675,0.196937,0.086641,0.471588
DLinear,0.395788,0.343666,0.345455,0.194423,4.765968,0.19522,0.210297,0.093927,0.696605
LightTS,0.448935,0.393748,0.360316,0.225445,7.072687,0.170235,0.214226,0.1331,0.611749
MICN,0.411357,0.379799,0.320593,0.185391,3.408669,0.191264,0.170311,0.092951,0.51674
Koopa,0.403704,0.308469,,,2.209361,0.167305,,0.091252,
FEDformer,0.401848,0.343666,0.378744,0.192951,3.153684,0.216076,,0.15816,0.586253
Reformer,0.85922,1.854121,0.89724,0.766816,3.909399,0.377578,,1.092328,0.700269
SegRNN,0.370867,0.282564,0.330613,0.172851,3.564504,0.165622,0.18816,0.091455,0.776648


Unnamed: 0,ETTh1,ETTh2,ETTm1,ETTm2,ili,weather,ECL,Exchange,traffic
Mamba_MSE,0.544283,0.465389,0.500843,0.356253,3.729039,0.291097,0.20924,0.713993,0.678976
Mamba_MAE,0.50361,0.447625,0.465595,0.370427,1.335222,0.315284,0.312102,0.561838,0.380339
Autoformer_MSE,0.491996,0.445958,0.531958,0.330328,3.156397,0.339242,0.233789,0.506281,0.636805
Autoformer_MAE,0.485448,0.46048,0.495872,0.367878,1.207956,0.379242,0.340127,0.500252,0.398083
PatchTST_MSE,0.456788,0.384949,0.389561,0.28771,2.16023,0.255676,0.208632,0.381043,0.496685
PatchTST_MAE,0.453029,0.409276,0.403743,0.333946,0.901482,0.278579,0.298512,0.412237,0.320605
DLinear_MSE,0.464734,0.565983,0.40441,0.349204,4.367429,0.265196,0.225354,0.346379,0.67259
DLinear_MAE,0.460744,0.520421,0.407232,0.398921,1.540458,0.31666,0.318775,0.414157,0.418584
LightTS_MSE,0.530469,0.632847,0.438023,0.431896,7.078158,0.244622,0.242691,0.486086,0.656383
LightTS_MAE,0.504848,0.55131,0.444865,0.447721,1.975164,0.295435,0.343611,0.493447,0.427872
