In [1]:
import os
import sys

sys.path.append(os.path.join(os.getcwd(), 'test'))

In [2]:
import torch
import pandas as pd 
import matplotlib.pyplot as plt

In [3]:
import argparse


def parser_args(args_list):

    parser = argparse.ArgumentParser(description='ITransformer')

    parser.add_argument('--seq_len', type=int, required=False, default=288,
                        help='input the sequence of length')
    parser.add_argument('--pred_len', type=int, required=False, default=96,
                        help='output the sequence of length')
    parser.add_argument('--d_model', type=int, required=False, default=128, 
                        help='the dimension of model')
    parser.add_argument('--n_layers', type=int, required=False, default=4,
                        help='the number of layers')
    parser.add_argument('--factor', type=int, required=False, default=6,
                        help='the number of features')
    parser.add_argument('--n_heads', type=int, required=False, default=4,
                        help='the number of heads')
    parser.add_argument('--activation', type=str, default='gelu',
                        help='activation')
    parser.add_argument('--dropout', type=float, default=0.1,
                        help='dropout')
    parser.add_argument('--d_ff', type=int, default=2048,
                        help='dimension of fcn')
    parser.add_argument('--output_attention', action='store_true',
                        help='whether to output attention in encoder')
    parser.add_argument('--use_norm', type=int, default=False,
                        help='use norm and denorm')
    parser.add_argument('--embed', type=str, default='timeF',
                        help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--lr', type=str, default=0.0001,
                        help='learning rate')
    parser.add_argument('--freq', type=str, default='t',
                        help='time frequency')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')

    parser.add_argument('--model_path', type=str, 
                        default='./result/result/iTransformer_model_version_l1.pth')

    if args_list:
        return parser.parse_args(args_list)
    else:
        return parser.parse_args()


In [4]:
from torch.utils.data import DataLoader
from sklearn.metrics import mean_absolute_error
from iTransformer_test import Model_test, test_data_process, dataset

In [5]:
args_list = [
    '--seq_len', '288',
    '--pred_len', '96',
    '--d_model', '128',
    '--n_layers', '4',
    '--factor', '6',
    '--n_heads', '4',
    '--activation', 'gelu',
    '--dropout', '0.1',
    '--d_ff', '2048',
    '--output_attention',
    '--use_norm', False,
    '--embed', 'timeF',
    '--lr', '0.0001',
    '--freq', 't',
    '--device', 'cuda' if torch.cuda.is_available() else 'cpu',
    '--model_path', './result/iTransformer_model_version_l1.pth'
]

In [6]:
# 数据
d_process = test_data_process(path='./src/data/sub_df_test.xlsx', window_length=3, num_features=6)
x, y = d_process.do()
data_set = dataset(x=x, y=y)

test_loader = DataLoader(dataset=data_set, num_workers=4, shuffle=False, batch_size=1)
history_train = pd.read_excel('./src/data/sub_df_test.xlsx')

df = d_process.data['Date']
df['timestamp'] = pd.to_datetime(d_process.data['Date'])

# 提取日期部分，并获取唯一值
unique_dates = df['timestamp'].dt.date.unique()
# 格式化日期为 'YYYY-MM-DD'
formatted_dates = [pd.Timestamp(date).strftime('%Y-%m-%d') for date in unique_dates]

# 打印结果
print(formatted_dates)

['2024-06-15', '2024-06-16', '2024-06-17', '2024-06-18', '2024-06-19', '2024-06-20', '2024-06-21', '2024-06-22', '2024-06-23', '2024-06-24', '2024-06-25', '2024-06-26', '2024-06-27', '2024-06-28', '2024-06-29', '2024-06-30', '2024-07-01', '2024-07-02', '2024-07-03', '2024-07-04', '2024-07-05', '2024-07-06', '2024-07-07', '2024-07-08', '2024-07-09', '2024-07-10', '2024-07-11', '2024-07-12', '2024-07-13', '2024-07-14', '2024-07-15', '2024-07-16']


In [7]:
# 模型参数
configs = parser_args(args_list=args_list)
predict_model = Model_test(configs=configs)

# 测试时间
window_len = 3
test_time = formatted_dates 

# 执行
MAE = 0
for idx, (x, y) in enumerate(test_loader):

    predict_max = history_train.iloc[96*(idx+window_len-1):96*(idx+window_len), -1].max()
    predict_min = history_train.iloc[96*(idx+window_len-1):96*(idx+window_len), -1].min()
    predict_result = predict_model.predict(x_enc=x, x_mark_enc=None, x_dec=None,x_mark_dec=None)

    print(predict_result)

    predict_result = predict_result * (predict_max - predict_min) + predict_min
    predict_mae = mean_absolute_error(predict_result[0].tolist(), y[0].tolist())
                                      
    MAE = MAE + predict_mae

    # 可视化预测和真实结果
    plt.figure(figsize=(4, 3))
    plt.title('Date:' + test_time[idx] + ' \n MAE:' + str(predict_mae))
    plt.xlabel('Time')
    plt.ylabel('Price')
    plt.plot(range(0, 96), predict_result[0].tolist(), label='Predict')
    plt.plot(range(0, 96), y[0].tolist(), label='Real')
    plt.legend()
    plt.show()


print('测试集的MAE', MAE/len(test_loader))


tensor([[0.5472, 0.5570, 0.5334, 0.5206, 0.5466, 0.4926, 0.4747, 0.4999, 0.4844,
         0.4897, 0.4954, 0.4726, 0.5103, 0.4774, 0.4747, 0.4757, 0.4910, 0.4795,
         0.5108, 0.5082, 0.4994, 0.5378, 0.5590, 0.5807, 0.5881, 0.6428, 0.6534,
         0.6369, 0.6489, 0.6075, 0.6216, 0.5559, 0.6024, 0.5493, 0.5440, 0.5015,
         0.4532, 0.3627, 0.3215, 0.2800, 0.2384, 0.1489, 0.1075, 0.1002, 0.1002,
         0.0910, 0.0546, 0.0049, 0.0023, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0094, 0.0222, 0.0480, 0.0896, 0.1242, 0.1413, 0.2102, 0.2682,
         0.3540, 0.4373, 0.5320, 0.5837, 0.5857, 0.6071, 0.7380, 0.8017, 0.8312,
         0.8672, 0.8836, 0.8470, 0.9014, 0.8596, 0.8634, 0.8634, 0.8419, 0.8162,
         0.7967, 0.7641, 0.7741, 0.7655, 0.7338, 0.7227, 0.6847, 0.7181, 0.6566,
         0.5926, 0.5223, 0.5370, 0.4940, 0.4556, 0.4549]],
       grad_fn=<SqueezeBackward1>)
tensor([[0.5472, 0.5569, 0.5334, 0.5207, 0.5466, 0.4926, 0.4747, 0.4998, 0.4844,
         0.4897