In [1]:
# Bootstrap: execute the local set_jupyter.py and pull its globals into this kernel's namespace.
# Judging by its printed output it reports the working directory and switches matplotlib to a
# Chinese-capable font. NOTE(review): it may also be where `plt` comes from — TODO confirm.
%run set_jupyter.py
%matplotlib inline 

当前工作目录：D:\dlquant
matplotlib显示字体已设置为中文。


In [2]:
import torch
from torch.nn.modules.loss import MSELoss, CrossEntropyLoss
from torch.nn import MSELoss, BCEWithLogitsLoss 

from pytorch_lightning import Trainer, loggers as pl_loggers  
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, Callback
 
from darts.utils.callbacks import TFMProgressBar
from darts.models import TFTModel
from darts.metrics import mape, mse, mae
from sklearn.metrics import mean_squared_error, precision_score 

import optuna  
from pathlib import Path

# 自定义
from config import TIMESERIES_LENGTH # 测试和验证数据长度设置
from data_precessing.timeseries import prepare_timeseries_data  # 获取训练数据、验证数据和测试数据
from utils.model import MAPELoss, LossLogger
from config import TIMESERIES_LENGTH
from models.params import get_pl_trainer_kwargs, early_stopper, progress_bar

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Training helpers.
# NOTE(review): none of these three objects is handed to TFTModel in any visible cell (the model
# uses get_pl_trainer_kwargs() with no arguments), and `progress_bar` / `early_stopper` shadow the
# same-named imports from models.params. Either wire them into the trainer kwargs or drop them.
loss_logger = LossLogger()
progress_bar = TFMProgressBar(enable_sanity_check_bar=False, enable_validation_bar=False)
early_stopper = EarlyStopping(monitor="val_loss", patience=10, min_delta=1e-6, mode="min")

# Experiment identity; both values are passed to TFTModel for logging / checkpointing.
model_name = "TFTModel"
work_dir = (Path("logs") / f"{model_name}_logs").resolve()
work_dir

WindowsPath('D:/dlquant/logs/TFTModel_logs')

In [4]:
# Build the datasets in "training" mode. Later cells subscript the result with the keys
# "train", "val", "test", "past_covariates" and "future_covariates".
data = prepare_timeseries_data("training")

***** xtdata连接成功 *****
服务信息: {'tag': 'sp3', 'version': '1.0'}
服务地址: 127.0.0.1:58610
数据路径: C:\e_trader\bin.x64/../userdata_mini/datadir
设置xtdata.enable_hello = False可隐藏此消息



[32m2024-09-14 21:13:14 | INFO     | download_xt_data:36 - 成功下载股票数据：510050.SH[0m
[32m2024-09-14 21:13:15 | INFO     | download_xt_data:36 - 成功下载股票数据：510300.SH[0m
[32m2024-09-14 21:13:15 | INFO     | download_xt_data:36 - 成功下载股票数据：510500.SH[0m
[32m2024-09-14 21:13:15 | INFO     | download_xt_data:36 - 成功下载股票数据：511260.SH[0m
[32m2024-09-14 21:13:15 | INFO     | download_xt_data:36 - 成功下载股票数据：511010.SH[0m
[32m2024-09-14 21:13:15 | INFO     | download_xt_data:36 - 成功下载股票数据：512010.SH[0m
[32m2024-09-14 21:13:15 | INFO     | download_xt_data:36 - 成功下载股票数据：512040.SH[0m
[32m2024-09-14 21:13:16 | INFO     | download_xt_data:36 - 成功下载股票数据：512690.SH[0m
[32m2024-09-14 21:13:16 | INFO     | download_xt_data:36 - 成功下载股票数据：512290.SH[0m
[32m2024-09-14 21:13:16 | INFO     | download_xt_data:36 - 成功下载股票数据：513050.SH[0m
[32m2024-09-14 21:13:16 | INFO     | download_xt_data:36 - 成功下载股票数据：513100.SH[0m
[32m2024-09-14 21:13:16 | INFO     | download_xt_data:36 - 成功下载股票数据：513500.SH[0m
[32

In [5]:
# Temporal Fusion Transformer used as a binary classifier over a 5-step horizon.
# loss_fn=BCEWithLogitsLoss means the network is trained to emit raw *logits*; darts returns them
# unchanged from predict(), so downstream code must apply a sigmoid (or threshold at 0) before
# treating outputs as probabilities.
RANDOM_STATE = 42  # fixes weight-initialisation randomness so re-runs are reproducible

model = TFTModel(
    # --- architecture ---
    input_chunk_length=20,   # lookback window (time steps fed to the encoder)
    output_chunk_length=5,   # steps produced per forward pass
    output_chunk_shift=0,
    hidden_size=16,
    lstm_layers=1,
    num_attention_heads=4,
    full_attention=False,
    feed_forward="GatedResidualNetwork",
    dropout=0.1,
    hidden_continuous_size=8,
    categorical_embedding_sizes=None,
    add_relative_index=False,  # future covariates are supplied explicitly to fit()
    norm_type="LayerNorm",
    use_static_covariates=True,
    # --- training ---
    loss_fn=BCEWithLogitsLoss(),
    pl_trainer_kwargs=get_pl_trainer_kwargs(),
    random_state=RANDOM_STATE,
    # --- checkpointing / logging ---
    work_dir=work_dir,
    save_checkpoints=True,   # required for load_from_checkpoint() later
    force_reset=True,        # discard checkpoints from previous runs with the same model_name
    model_name=model_name,
)

In [None]:
# Fit on the training split and evaluate on the validation split each epoch.
# The same covariate series serve both splits — presumably they cover the whole date range
# (TODO confirm in prepare_timeseries_data).
past_covs = data["past_covariates"]
future_covs = data["future_covariates"]
model.fit(
    series=data["train"],
    val_series=data["val"],
    past_covariates=past_covs,
    val_past_covariates=past_covs,
    future_covariates=future_covs,
    val_future_covariates=future_covs,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

   | Name                              | Type                             | Params | Mode 
------------------------------------------------------------------------------------------------
0  | criterion                         | BCEWithLogitsLoss                | 0      | train
1  | train_criterion                   | BCEWithLogitsLoss                | 0      | train
2  | val_criterion                     | BCEWithLogitsLoss                | 0      | train
3  | train_metrics                     | MetricCollection                 | 0      | train
4  | val_metrics                       | MetricCollection                 | 0      | train
5  | input_embeddings                  | _MultiEmbedding                  | 0      | train
6  | static_covariates_vsn             | _VariableSelectionNetwork        | 0      | train
7  | encoder_vsn                       | _VariableSelectionNetw

Epoch 3:  48%|████▊     | 28/58 [01:36<01:43,  3.45s/it, train_loss=0.486, val_loss=0.666]

In [None]:
# Reload the checkpoint written during fit(). load_from_checkpoint is a classmethod, so call it on
# the class rather than on the instance that is about to be replaced.
model = TFTModel.load_from_checkpoint(model_name=model_name, work_dir=work_dir)

In [None]:
# Evaluate on the last `pred_steps` points of the test split: forecast them from the history that
# precedes them, then score the binary predictions with precision (pooled over all stocks).
pred_steps = TIMESERIES_LENGTH["test_length"]
pred_input = data["test"][:-pred_steps]

# The model was fit with past and future covariates; when an explicit `series` is passed to
# predict(), darts needs the matching covariates again or it raises.
pred_series = model.predict(
    n=pred_steps,
    series=pred_input,
    past_covariates=data["past_covariates"],
    future_covariates=data["future_covariates"],
)

# Ground truth: the held-out tail, flattened over (time, component) into 1-D 0/1 labels.
true_labels = data["test"][-pred_steps:].values().astype(int).flatten()

# loss_fn=BCEWithLogitsLoss => predictions are logits, not probabilities.
# sigmoid(logit) > 0.5  <=>  logit > 0, so the decision threshold on logits is 0 (not 0.5).
LOGIT_THRESHOLD = 0.0
binary_predictions = (pred_series.values() > LOGIT_THRESHOLD).astype(int).flatten()

# zero_division=0: return 0.0 (without a warning) if the model predicts no positives at all.
precision = precision_score(true_labels, binary_predictions, zero_division=0)
precision

In [None]:
# Visual check: actual 0/1 labels vs. predicted probabilities for the first few stocks.
import matplotlib.pyplot as plt  # explicit, rather than relying on set_jupyter.py to define plt
import numpy as np

N_PLOT_STOCKS = 3

actual_tail = data["test"][-pred_steps:]
# Predictions are logits (BCEWithLogitsLoss); map them to probabilities so they share the
# 0..1 scale of the labels.
pred_probs = pred_series.map(lambda logits: 1.0 / (1.0 + np.exp(-logits)))

for stock in data["test"].columns[:N_PLOT_STOCKS]:
    # Index by component name to get a univariate TimeSeries: its .plot() draws a line, whereas
    # xarray's .plot() on the 2-D (time, sample) array from data_array().sel(...) draws a heatmap.
    actual_tail[stock].plot(label=f"{stock}_实际数据")
    pred_probs[stock].plot(label=f"{stock}_预测结果")
    plt.legend()
    plt.show()