In [1]:
from gluonts.evaluation.backtest import make_evaluation_predictions
from gluonts.dataset.multivariate_grouper import MultivariateGrouper
import numpy as np
import pandas as pd
from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset


# Short aliases accepted by create_dataset_csv, mapped to GluonTS repository
# dataset names.
# Disabled aliases: 'fina' (m3_other), 'exc' (exchange_rate_nips),
# 'fre' (fred_md).
_DATASET_ALIASES = {
    'sol': 'solar_nips',
    'elec': 'electricity_nips',
    'traf': 'traffic_nips',
    'cup': 'kdd_cup_2018_without_missing',
    'taxi': 'taxi_30min',
    'wiki': 'wiki-rolling_nips',
}

# Upper bound on the number of series kept by the multivariate groupers.
_MAX_TARGET_DIM = 2000

# Some 'cup' test series have this length; they are zero-padded by
# _CUP_PAD_LEN values before grouping.
_CUP_SHORT_LEN = 10898
_CUP_PAD_LEN = 8


def _pad_short_cup_series(entries):
    """Return ``entries`` as a list, zero-padding targets of length _CUP_SHORT_LEN.

    Entries are modified in place (their ``'target'`` is replaced) and then
    collected into the returned list; entries of any other length are passed
    through untouched.
    """
    padded = []
    for entry in entries:
        if len(entry['target']) == _CUP_SHORT_LEN:
            # Append zeros at the end (alternative would be truncating the
            # longer series instead).
            entry['target'] = np.concatenate(
                (entry['target'], np.zeros(_CUP_PAD_LEN))
            )
        padded.append(entry)
    return padded


def create_dataset_csv(dataset_simple: str) -> tuple:
    """Export one GluonTS repository dataset as a single wide CSV file.

    The train and test splits are grouped into multivariate series, the first
    grouped train entry and the last grouped test entry are transposed to
    (time, series) layout, and the tail of the test matrix is appended below
    the train matrix. The result is written to ``./<dataset_simple>.csv`` in
    the current working directory. Progress and shape information is printed
    to stdout.

    Args:
        dataset_simple: Short dataset alias; one of the keys of
            ``_DATASET_ALIASES`` ('sol', 'elec', 'traf', 'cup', 'taxi',
            'wiki').

    Returns:
        A 7-tuple ``(train_start, train_end, valid_start, valid_end,
        test_start, test_end, prediction_length)``. The first six values are
        integer row offsets derived from the length of the combined table,
        with ``train_end == valid_start`` and ``valid_end == test_start``.

    Raises:
        KeyError: If ``dataset_simple`` is not a known alias.
    """
    print('--------------------------------------------------')
    print(f'create {dataset_simple} dataset csv')

    try:
        dataset_name = _DATASET_ALIASES[dataset_simple]
    except KeyError:
        # Same exception type as a plain dict lookup, but with a message that
        # tells the caller which aliases are valid.
        raise KeyError(
            f'unknown dataset alias {dataset_simple!r}; '
            f'expected one of {sorted(_DATASET_ALIASES)}'
        ) from None

    dataset = get_dataset(dataset_name, regenerate=False)
    metadata = dataset.metadata
    print("metadata", metadata)

    # Cardinality is read through int() because the metadata may carry it as
    # a string (e.g. cardinality='137').
    max_target_dim = min(
        _MAX_TARGET_DIM, int(metadata.feat_static_cat[0].cardinality)
    )
    train_grouper = MultivariateGrouper(max_target_dim=max_target_dim)
    test_grouper = MultivariateGrouper(
        num_test_dates=int(len(dataset.test) / len(dataset.train)),
        max_target_dim=max_target_dim,
    )

    print("prepare the dataset")
    print(f'len(train_data): {len(dataset.train)}, len(test_data): {len(dataset.test)}')

    # Group the univariate series into multivariate entries.
    train_data = list(train_grouper(dataset.train))
    if dataset_simple == 'cup':
        # The 'cup' test series do not all have the same length, so the short
        # ones are padded before grouping.
        test_entries = _pad_short_cup_series(dataset.test)
    else:
        test_entries = dataset.test
    test_data = list(test_grouper(test_entries))

    # Merge the train and test data.
    print(f'train_data.shape: {train_data[0]["target"].shape}')
    print(f'test_data.shape: {test_data[-1]["target"].shape}')
    train_data_T = np.array(train_data[0]['target']).T
    test_data_T = np.array(test_data[-1]['target']).T
    print(f'train_data_T.shape: {train_data_T.shape}')
    print(f'test_data_T.shape: {test_data_T.shape}')

    print(f'train_data_T[-1][:10]: {train_data_T[-1][:10]}')
    print(f'test_data_T[-1][:10]: {test_data_T[-1][:10]}')

    prediction_length = metadata.prediction_length
    test_len = len(test_data) * prediction_length
    is_taxi = dataset_simple == 'taxi'
    if is_taxi:
        # No overlap with train: train starts at 2015-01-01 00:00:00 while
        # test starts at 2016-01-01 00:00:00, so one extra prediction window
        # is taken from the test matrix.
        tail_len = test_len + prediction_length
    else:
        # The test matrix overlaps the train matrix; keep only the test tail.
        tail_len = test_len
    test_data_T_unic = test_data_T[-tail_len:]

    # Debug output: last train row next to the test row just before the tail.
    print((train_data_T[-1][-10:]))
    print((test_data_T[-test_len - 1][-10:]))

    data_all = np.concatenate((train_data_T, test_data_T_unic), axis=0)
    print(f'data_all.shape: {data_all.shape}')

    # Generate the dataframe.
    print("metadata", metadata)
    freq = metadata.freq

    # The time index always starts at 2012-01-01 00:00:00, spaced by the
    # dataset frequency, with one timestamp per row of data_all.
    start = pd.Timestamp("2012-01-01 00:00:00")
    index = pd.date_range(start=start, freq=freq, periods=len(data_all))
    # Columns are the integer series positions 0..n_series-1.
    df = pd.DataFrame(data_all, index=index, columns=range(data_all.shape[1]))
    df.index.name = 'date'
    print(f'df.shape: {df.shape}')

    valid_len = min(7 * prediction_length, test_len)
    train_len = len(df) - test_len - valid_len
    if is_taxi:
        # The test part had one extra window added above; remove it here.
        train_len -= prediction_length

    print("train_len", train_len)
    print("valid_len", valid_len)
    print("test_len", test_len)
    print("prediction_length", prediction_length)

    train_start = 0
    train_end = train_start + train_len
    valid_start = train_end
    valid_end = valid_start + valid_len
    test_start = valid_end
    test_end = test_start + test_len

    # NOTE(review): index=False means the 'date' index built above is not
    # written to the file; only the value columns are saved. Confirm this is
    # what downstream readers expect.
    df.to_csv(f'./{dataset_simple}.csv', index=False)
    print(f'./{dataset_simple}.csv saved')
    return train_start, train_end, valid_start, valid_end, test_start, test_end, prediction_length

# Column-oriented table of split boundaries: one list per column, one element
# appended per processed dataset.
dataset_split_index = {
    'dataset': [],
    'train_start': [],
    'train_end': [],
    'valid_start': [],
    'valid_end': [],
    'test_start': [],
    'test_end': [],
    'prediction_length': [],
}
# Aliases to export. Shorten this list (e.g. ['taxi', 'wiki'], ['cup'] or
# ['taxi']) to regenerate only some of the datasets.
transformed_dataset = ['sol', 'elec', 'traf', 'cup', 'taxi', 'wiki']
for dataset in transformed_dataset:
    (
        train_start,
        train_end,
        valid_start,
        valid_end,
        test_start,
        test_end,
        prediction_length,
    ) = create_dataset_csv(dataset)

    # Collect this dataset's values under the same keys as the table columns.
    row = {
        'dataset': dataset,
        'train_start': train_start,
        'train_end': train_end,
        'valid_start': valid_start,
        'valid_end': valid_end,
        'test_start': test_start,
        'test_end': test_end,
        'prediction_length': prediction_length,
    }
    for column, value in row.items():
        dataset_split_index[column].append(value)

    # The summary CSV is rewritten after every dataset, so the file always
    # reflects the datasets processed so far.
    dataset_split_index_df = pd.DataFrame(dataset_split_index)
    dataset_split_index_df.to_csv('./dataset_split_index.csv', index=True)




--------------------------------------------------
create sol dataset csv
metadata freq='H' target=None feat_static_cat=[CategoricalFeatureInfo(name='feat_static_cat_0', cardinality='137')] feat_static_real=[] feat_dynamic_real=[] feat_dynamic_cat=[] prediction_length=24
prepare the dataset
len(train_data): 137, len(test_data): 959
train_data.shape: (137, 7009)
test_data.shape: (137, 7177)
train_data_T.shape: (7009, 137)
test_data_T.shape: (7177, 137)
train_data_T[-1][:10]: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
test_data_T[-1][:10]: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
data_all.shape: (7177, 137)
metadata freq='H' target=None feat_static_cat=[CategoricalFeatureInfo(name='feat_static_cat_0', cardinality='137')] feat_static_real=[] feat_dynamic_real=[] feat_dynamic_cat=[] prediction_length=24
df.shape: (7177, 137)
train_len 6841
valid_len 168
test_len 168
prediction_length 24
./sol.csv saved
--------------------------------------------

Bad pipe message: %s [b'\xbfT\x97\xf6\xe1\xcaZ\x97\xf7x\xaf\x9f\x07P\xdc\x06\xde\xb5 ']
Bad pipe message: %s [b'\x19Y5\xda\x1a\xe1\x0cN\x16\xae\xe2\xc2\xa1\x1c\xb0\xab\xd9\x91 gw\xc9\x90\xc7^tR\x15\x9c\xa3\xa9`/Nt\x00\x0f<\xeb\x1b\xe8\xe6\x0b\x13|q\xe1GCc\xc4\x00\x08']
Bad pipe message: %s [b'\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00']
Bad pipe message: %s [b"~\xb3\x0c\xf9\xd2\xd4\xd5\x93=\xa8\xcc\x0bwy\xe4\x19\xca\n\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0#\xc0'\x00g\x00@\xc0\n\xc0\x14\x009\x008\xc0\t\xc0\x13\x003\x002\x00\x9d\xc0\xa1\xc0\x9d\xc0Q\x00\x9c\xc0\xa0\xc0\x9c\xc0P\x00=\x00<\x005\x00/\x00\x9a\x00\x99\xc0\x07\xc0\x11\x00\x96\x00\x05\x00\xff\x01\x00\x00j\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1"]
Bad pipe message: %s [b'{\xd8>\xe8f\xf1\xd4\xb1=,\xf1w\xd7D\xeb\xc8~P\x00\x00>