In [1]:
import os
import yaml
import pandas as pd

In [2]:
__dataset__ = [
    'ETTh1',
    'ETTh2',
    'ETTm1',
    'ETTm2',
    'ECL',
    'Weather',
    'Traffic',
    'Exchange',
    'national_illness'
]

In [3]:
def parse_dataset(configs, key=None):
    """Infer the single dataset an experiment run was trained/evaluated on.

    Each candidate name is searched for as a substring of the ``data``,
    ``data_path`` and ``model_id`` config entries. Every match is mapped to one
    canonical name (``national_illness`` -> ``ILI``), so all three fields agree
    even when they spell the dataset differently.

    Args:
        configs: Parsed ``args.yaml`` mapping; must contain the keys
            ``data``, ``data_path`` and ``model_id``.
        key: Optional iterable of candidate dataset names. Defaults to the
            module-level ``__dataset__`` list.

    Returns:
        The canonical dataset name, e.g. ``'ETTh1'`` or ``'ILI'``.

    Raises:
        KeyError: If one of the three config entries is missing.
        ValueError: If no candidate, or more than one distinct candidate,
            is found in the configs.
    """
    candidates = __dataset__ if key is None else key
    # Single canonical spelling per dataset; previously each field normalized
    # differently ('ILI' / 'national' / 'national_illness'), causing false
    # "multiple datasets" errors.
    aliases = {'national_illness': 'ILI'}

    # str() so a null/numeric YAML value cannot make the `in` test raise.
    fields = [str(configs[name]) for name in ('data', 'data_path', 'model_id')]

    found = {
        aliases.get(ds, ds)
        for ds in candidates
        if any(ds in field for field in fields)
    }

    if len(found) > 1:
        raise ValueError(f"Multiple datasets found in configs: {sorted(found)}")
    if len(found) == 1:
        return found.pop()
    raise ValueError(f"Dataset not found in configs: {configs}")

In [4]:
root_path = '../results'

# metrics[dataset][(model, pred_len, seq_len)] -> dict of per-run lists with keys
# 'mse', 'mae', 'rmse', 'mape', 'num_exp' (one list entry per run folder).
metrics = {}

# sorted(): os.listdir order is arbitrary; a fixed order makes the per-run
# lists (and the float sums taken from them later) reproducible.
for folder in sorted(os.listdir(root_path)):
    run_dir = os.path.join(root_path, folder)
    if not os.path.isdir(run_dir):
        continue
    args_file = os.path.join(run_dir, 'args.yaml')
    assert os.path.exists(args_file), f"File {args_file} does not exist."

    # The text after the last 'Exp' in the folder name (underscores removed)
    # must parse as an integer run id.
    exp = int(folder.split('Exp')[-1].replace('_', ''))

    with open(args_file, 'r') as f:
        configs = yaml.safe_load(f)

    # Only the first row of metrics.csv is used.
    metrics_df = pd.read_csv(os.path.join(run_dir, 'metrics.csv'))

    dataset = parse_dataset(configs, __dataset__)
    key = (configs['model'], configs['pred_len'], configs['seq_len'])

    # Key insertion order here defines the column order of the Excel sheets.
    runs = metrics.setdefault(dataset, {}).setdefault(
        key, {'mse': [], 'mae': [], 'rmse': [], 'mape': [], 'num_exp': []}
    )
    for metric_name in ('mse', 'mae', 'rmse', 'mape'):
        runs[metric_name].append(metrics_df[metric_name].values[0])
    runs['num_exp'].append(exp)

reduce_keys = [
    'mse',
    'mae',
    'rmse',
    'mape',
    'num_exp',
]

# Collapse each per-run list (in place) into a display string: the mean with
# 3 decimals for the error metrics, and the plain run count for 'num_exp'
# (previously rendered as e.g. '3.000').
for per_config in metrics.values():
    for runs in per_config.values():
        for reduce_key in reduce_keys:
            values = runs[reduce_key]
            if reduce_key == 'num_exp':
                runs[reduce_key] = str(len(values))
            else:
                runs[reduce_key] = f'{sum(values) / len(values):.3f}'

save_path = '01.metrics.xlsx'

# One sheet per dataset, one row per (model, pred_len, seq_len) config.
# Each sheet is built and written exactly once (the old inner per-config loop
# rebuilt and rewrote the identical sheet once per config).
with pd.ExcelWriter(save_path) as writer:
    for dataset in sorted(metrics):
        summary = pd.DataFrame.from_dict(metrics[dataset], orient='index').reset_index()
        # The tuple keys become the first three columns. The rest are assumed to
        # follow the per-config dict insertion order, which should match
        # reduce_keys — confirm if either is reordered.
        summary.columns = ['model', 'pred_len', 'seq_len'] + [f'{reduce_key}-str' for reduce_key in reduce_keys]
        summary = summary.sort_values(by=['model', 'pred_len', 'seq_len']).reset_index(drop=True)
        summary.to_excel(writer, sheet_name=f'{dataset}', index=False)

print(f"Metrics saved to {save_path}")

Metrics saved to 01.metrics.xlsx
