In [17]:
import numpy as np
import matplotlib.pyplot as plt
import os
import pandas as pd
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import re

In [18]:
# Experiment switches. All four values are interpolated into the `setting`
# string that names the TSGym result files; presumably they have to match the
# flags of the run being analysed (TODO confirm against the training script).
arg_component_balance=False
arg_add_transformer=True
# Besides appearing in `setting`, this flag also selects which dataset list is
# used when `file_list` is built.
arg_add_LLM_TSFM=False
add_all_periods=False

In [19]:
# sota performance
def search_sota_performance(dataset, pred_lens=[96], metric='mse',
                            path_old='./results',
                            path='./results_long_term_forecasting/results'):
    """Collect baseline scores for ``dataset`` from saved ``metrics.npy`` files.

    Run directories are discovered in two places:

    * directly under ``path_old`` -- kept when the directory name contains both
      ``dataset`` and ``pl{pred_len}``;
    * under ``path/dataset`` -- kept when the name contains ``pl{pred_len}``.
      A missing ``path/dataset`` directory simply contributes no runs.

    Parameters
    ----------
    dataset : str
        Dataset name used for the substring match and as sub-directory of ``path``.
    pred_lens : list of int
        Prediction lengths to scan. Every length is loaded, but only the scores
        for ``pred_lens[0]`` are returned. The default list is never mutated.
    metric : str
        ``'mae'`` reads element 0 of ``metrics.npy``; any other value reads
        element 1 (presumably MSE -- confirm against the code writing the file).
        The string is also used as the name of the returned column.
    path_old, path : str
        Roots of the two result layouts described above.

    Returns
    -------
    pandas.DataFrame
        A single column named ``metric``, sorted ascending. The index is the
        model name parsed from the run directory name: underscore-separated
        field 1 when the name contains ``'LTF'`` or ``'STF'``, field 6 otherwise.
        An empty frame (with that column) is returned when no run matches.

    Raises
    ------
    FileNotFoundError
        If ``path_old`` does not exist, or a matching run has no ``metrics.npy``
        in either location.
    IndexError
        If ``pred_lens`` is empty or a run name has too few ``_`` fields.
    """
    result_dict = {}
    for pred_len in pred_lens:
        tag = f'pl{pred_len}'
        model_list_old = [name for name in os.listdir(path_old)
                          if tag in name and dataset in name]

        try:
            model_list_new = [name for name in os.listdir(os.path.join(path, dataset))
                              if tag in name]
        except OSError:
            # The new layout has no directory for this dataset: nothing to add.
            model_list_new = []

        scores = {}
        for model in model_list_old + model_list_new:
            # NOTE: allow_pickle=True can execute code embedded in the file;
            # only point this function at result files you produced yourself.
            try:
                result = np.load(os.path.join(path_old, model, 'metrics.npy'), allow_pickle=True)
            except OSError:
                # Not stored in the old layout -> fall back to the new one.
                # Only "file is not there" style errors trigger the fallback;
                # a corrupt file is reported instead of being silently skipped.
                result = np.load(os.path.join(path, dataset, model, 'metrics.npy'), allow_pickle=True)
            scores[model] = result[0] if metric == 'mae' else result[1]
        result_dict[pred_len] = scores

    selected = result_dict[pred_lens[0]]
    if not selected:
        # from_dict() would yield a frame without columns and the column
        # assignment below would fail with a length-mismatch error.
        return pd.DataFrame(columns=[metric])

    df = pd.DataFrame.from_dict(selected, orient='index')
    df.columns = [metric]
    df = df.sort_values(by=metric)
    df.index = [_.split('_')[1] if 'LTF' in _ or 'STF' in _ else _.split('_')[6] for _ in df.index]

    return df

In [20]:
metric = 'mse'
result_path = './meta/results'

# Datasets selected by the LLM/TSFM switch (presumably the ones that have
# TSGym result files for that configuration -- TODO confirm).
if arg_add_LLM_TSFM:
    datasets_available = ['ETTh1', 'ETTh2', 'Exchange', 'ili']
else:
    datasets_available = ['ETTh1', 'ETTh2', 'ETTm1', 'ETTm2', 'ili', 'weather', 'ECL', 'Exchange', 'traffic']

pred_len_1, pred_len_2 = 96, 24
# Suffix shared by every TSGym result file of this configuration.
setting = f'component_balance_{arg_component_balance}-add_transformer_{arg_add_transformer}-add_LLM_TSFM_{arg_add_LLM_TSFM}-all_periods_{add_all_periods}'
file_list = [f'{dataset}-{setting}_{pred_len_1}_{pred_len_2}.npz' for dataset in datasets_available]

print(len(file_list))


# Display order of the datasets and the matching names shown in the table.
# Both lists are filtered by the selection above so that they stay aligned and
# never reference a dataset that was excluded by `arg_add_LLM_TSFM`.
datasets_ordered = ['ETTm1', 'ETTm2', 'ETTh1', 'ETTh2', 'ECL', 'traffic', 'weather', 'Exchange', 'ili']
datasets_vis_ordered = ['ETTm1', 'ETTm2', 'ETTh1', 'ETTh2', 'ECL', 'Traffic', 'Weather', 'Exchange', 'ILI']
datasets = [name for name in datasets_ordered if name in datasets_available]
datasets_vis = [shown for name, shown in zip(datasets_ordered, datasets_vis_ordered)
                if name in datasets_available]
# Row order of the comparison table; 'TSGym' is the method under study.
baselines = ['TSGym', 'DUET', 'TimeMixer', 'TSMixer', 'MICN', 'TimesNet', 'PatchTST', 'DLinear', 'Crossformer', 'Pyraformer', 'Autoformer', 'SegRNN']

9


In [21]:
def df_generator(metric):
    """Build the TSGym-vs-baselines score tables and mean-rank table for ``metric``.

    Relies on the notebook globals ``datasets``, ``datasets_vis``, ``baselines``,
    ``setting``, ``result_path`` and on ``search_sota_performance``.

    For each horizon pair (96/24, 192/36, 336/48, 720/60) and each dataset:
    the baseline scores are looked up (second length of the pair for
    'ili', 'covid-19' and 'fred-md', first length otherwise), the TSGym score at
    the stored ``best_epoch`` is read from the dataset's ``.npz`` file, and the
    models are ranked by ascending score.

    Parameters
    ----------
    metric : str
        ``'mse'`` reads ``top1_perf_epoch`` from the TSGym file; any other value
        reads ``top1_perf_epoch_mae``. Also forwarded to
        ``search_sota_performance``.

    Returns
    -------
    dfs_dict : dict
        Keyed by the first length of each pair; each value is a DataFrame with
        one row per entry of ``datasets_vis`` and one column per entry of
        ``baselines``, rounded to 4 decimals.
    dfs_rank : pandas.DataFrame
        Mean rank over datasets, one row per horizon (labelled '96', '192',
        '336', '720') and one column per entry of ``baselines``, rounded to
        2 decimals.

    Raises
    ------
    KeyError
        If a model listed in ``baselines`` has no score for some dataset.
    """
    horizons_long = [96, 192, 336, 720]
    horizons_short = [24, 36, 48, 60]
    short_horizon_datasets = ['ili', 'covid-19', 'fred-md']
    perf_key = 'top1_perf_epoch' if metric == 'mse' else 'top1_perf_epoch_mae'

    dfs_rank = []
    dfs_dict = {}
    for pred_len_1, pred_len_2 in zip(horizons_long, horizons_short):
        per_dataset = []
        for dataset in datasets:
            file = f'{dataset}-{setting}_{pred_len_1}_{pred_len_2}.npz'

            pred_len = pred_len_2 if dataset in short_horizon_datasets else pred_len_1
            df = search_sota_performance(dataset, pred_lens=[pred_len], metric=metric)

            # Read the TSGym score; the context manager closes the .npz file
            # handle, which np.load otherwise keeps open.
            with np.load(os.path.join(result_path, file), allow_pickle=True) as perf_epoch:
                top1_perf_epoch = perf_epoch[perf_key]
                best_epoch = perf_epoch['best_epoch'].item()

            df.loc['TSGym'] = top1_perf_epoch[best_epoch]
            df = df.dropna()
            df = df.loc[baselines]
            df = df.sort_values(by=metric)
            df = df.reset_index()
            df.columns = ['model', dataset]
            # TODO: some models appear more than once with differing results;
            # after the ascending sort, keep='first' retains the best one.
            df = df.drop_duplicates(subset='model', keep='first')
            df = df.set_index('model')
            per_dataset.append(df)

        # Rank only the models that have a score on every dataset.
        model_names = set.intersection(*map(set, [_.index.tolist() for _ in per_dataset]))
        per_dataset = [
            df[[name in model_names for name in df.index]].sort_values(by=df.columns[0])
            for df in per_dataset
        ]

        # 1-based position of each model in every sorted per-dataset table.
        # 'TSGym' is normally part of `baselines`; setdefault only adds it when
        # it is not, so its rank is recorded exactly once per dataset.
        ranks = {k: [] for k in baselines}
        ranks.setdefault('TSGym', [])
        for df in per_dataset:
            for name in ranks:
                position = np.where(df.index == name)[0]
                if len(position) > 0:
                    ranks[name].append((position + 1).item())

        table = pd.concat(per_dataset, axis=1)
        table = table.round(4)
        table.index = table.index.str.replace('TemporalFusionTransformer', 'TFT')
        table = table.loc[baselines, datasets]
        table.columns = datasets_vis
        dfs_dict[pred_len_1] = table.T

        df_rank = pd.Series({k: np.mean(v) for k, v in ranks.items() if len(v) > 0})
        df_rank = df_rank.sort_values()
        dfs_rank.append(df_rank)

    dfs_rank = pd.concat(dfs_rank, axis=1)
    dfs_rank.columns = [str(_) for _ in horizons_long]
    dfs_rank.index = dfs_rank.index.str.replace('TemporalFusionTransformer', 'TFT')
    dfs_rank = dfs_rank.round(2)
    dfs_rank = dfs_rank.T[baselines]

    return dfs_dict, dfs_rank

In [22]:
# Per metric: average the score tables over the four horizons and keep the
# mean-rank table returned alongside them.
dfs_dict = {}
dfs_rank_dict = {}
for metric in ('mse', 'mae'):
    dfs, dfs_rank = df_generator(metric)

    # Stack the per-horizon tables (rows are datasets) and average the rows
    # that share a dataset name.
    stacked = pd.concat(list(dfs.values()))
    dfs = stacked.groupby(level=0).mean().round(3)
    # groupby sorts the row labels; put them back into display order.
    dfs = dfs.loc[datasets_vis]

    dfs_dict[metric] = dfs
    dfs_rank_dict[metric] = dfs_rank

In [23]:
# Put the per-metric tables side by side; the dict keys ('mse', 'mae') become
# the outer column level, in insertion order.
df = pd.concat(dfs_dict, axis=1)
# Swap the two column levels so that the model name is the outer level.
df = df.swaplevel(axis=1)

# Reorder the columns: models in `baselines` order, each with mse then mae.
df = df.reindex(columns=pd.MultiIndex.from_product([baselines, ['mse', 'mae']]))

# Create the output directory first so the export does not fail on a fresh checkout.
output_path = './meta/results_paper/TSGym-vs-SOTA.xlsx'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
df.to_excel(output_path)

In [24]:
# Show the combined table: one row per dataset, and for every model in
# `baselines` an ('mse', 'mae') pair of columns.
df

Unnamed: 0_level_0,TSGym,TSGym,DUET,DUET,TimeMixer,TimeMixer,TSMixer,TSMixer,MICN,MICN,...,DLinear,DLinear,Crossformer,Crossformer,Pyraformer,Pyraformer,Autoformer,Autoformer,SegRNN,SegRNN
Unnamed: 0_level_1,mse,mae,mse,mae,mse,mae,mse,mae,mse,mae,...,mse,mae,mse,mae,mse,mae,mse,mae,mse,mae
ETTm1,0.357,0.383,0.407,0.409,0.384,0.399,0.527,0.512,0.402,0.429,...,0.404,0.407,0.501,0.501,0.695,0.593,0.532,0.496,0.388,0.404
ETTm2,0.261,0.319,0.296,0.338,0.277,0.325,1.03,0.75,0.342,0.391,...,0.349,0.399,1.487,0.789,1.565,0.876,0.33,0.368,0.273,0.322
ETTh1,0.426,0.44,0.433,0.437,0.448,0.438,0.615,0.579,0.589,0.537,...,0.465,0.461,0.544,0.52,0.814,0.692,0.492,0.485,0.422,0.429
ETTh2,0.358,0.4,0.38,0.403,0.383,0.406,2.16,1.22,0.585,0.53,...,0.566,0.52,1.552,0.908,3.776,1.557,0.446,0.46,0.374,0.405
ECL,0.17,0.265,0.179,0.262,0.185,0.273,0.229,0.337,0.186,0.297,...,0.225,0.319,0.193,0.289,0.295,0.387,0.234,0.34,0.216,0.302
Traffic,0.435,0.313,0.797,0.427,0.496,0.313,0.599,0.403,0.544,0.32,...,0.673,0.419,1.458,0.782,0.697,0.391,0.637,0.397,0.807,0.411
Weather,0.229,0.268,0.252,0.277,0.244,0.274,0.242,0.301,0.264,0.316,...,0.265,0.317,0.253,0.312,0.284,0.349,0.339,0.379,0.251,0.298
Exchange,0.41,0.431,0.322,0.384,0.359,0.402,0.487,0.546,0.346,0.422,...,0.346,0.414,0.904,0.695,1.183,0.855,0.506,0.5,0.408,0.423
ILI,2.233,1.015,2.64,1.018,4.502,1.557,5.617,1.68,2.938,1.178,...,4.367,1.54,4.311,1.396,4.691,1.442,3.156,1.207,4.305,1.397
