# Task04. End-to-End Learning for Downstream Tasks

# 1. Why End-to-End Learning Matters

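In Task04 we used PyPOTS to model a custom dataset in two stages. Upstream, the missing values in the dataset were imputed; downstream, a model for the target task was trained on the imputed data. We also stressed that missingness is itself a property of the dataset. The pattern of missing values may carry extra information about the state of the subject being measured (in clinical data, for example, a test may be ordered more often for sicker patients). A two-stage pipeline can lose this information: the imputed values are inferred by a model from the distribution of the observed data, so once the gaps are filled, the downstream algorithm no longer knows where the original data were missing and cannot learn from that pattern. End-to-end learning instead uses a single model that takes the data, missing values included, as input and learns the target task directly. We often cannot tell in advance whether a two-stage or an end-to-end approach will work better on a given time-series dataset, because the final result depends on many factors, including but not limited to the model's own capacity and hyperparameter tuning. Even so, the end-to-end approach has more potential, because it can exploit the missingness pattern that a two-stage pipeline throws away.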
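A tiny example shows how the missingness pattern gets lost. The minimal NumPy sketch below uses readings from two hypothetical patients. It uses plain linear interpolation as a stand-in for an imputation model; this is an assumption, since a learned imputer such as BRITS would fill the values differently. One series has a skipped reading and the other is fully observed. After imputation the two series are identical, and only the observation mask still tells them apart.

```python
import numpy as np


def linear_impute(x):
    """Fill NaNs in a 1-D series by linear interpolation over the observed points."""
    t = np.arange(len(x))
    observed = ~np.isnan(x)
    return np.interp(t, t[observed], x[observed])


a = np.array([1.0, np.nan, 3.0])  # hypothetical patient A: one reading skipped
b = np.array([1.0, 2.0, 3.0])     # hypothetical patient B: fully observed

a_imp, b_imp = linear_impute(a), linear_impute(b)
print(np.array_equal(a_imp, b_imp))  # True: indistinguishable after imputation
print((~np.isnan(a)).astype(int), (~np.isnan(b)).astype(int))  # [1 0 1] [1 1 1]
```

A downstream model that only sees `a_imp` and `b_imp` has no way to use the fact that patient A skipped a reading. An end-to-end model that also receives the mask keeps that signal.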

# 2. Classifying PhysioNet2012 Directly with BRITS

### 2.1 Loading the Data

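```python
from benchpots.datasets import preprocess_physionet2012

# Load PhysioNet2012 (set-a). In the validation and test sets, about 10% of the
# observed values are additionally masked out at random points
# (pattern="point", rate=0.1) to serve as imputation ground truth.
physionet2012_dataset = preprocess_physionet2012(
    subset="set-a",
    pattern="point",
    rate=0.1,
)

# PyPOTS models take dicts with "X" (samples, NaN = missing) and,
# for classification, "y" (the labels).
dataset_for_training = {
    "X": physionet2012_dataset["train_X"],
    "y": physionet2012_dataset["train_y"],
}

dataset_for_validating = {
    "X": physionet2012_dataset["val_X"],
    "y": physionet2012_dataset["val_y"],
}

dataset_for_testing = {
    "X": physionet2012_dataset["test_X"],
    "y": physionet2012_dataset["test_y"],
}
```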

```
2025-05-10 23:27:30 [INFO]: You're using dataset physionet_2012, please cite it properly in your work. You can find its reference information at the below link: 
https://github.com/WenjieDu/TSDB/tree/main/dataset_profiles/physionet_2012
2025-05-10 23:27:30 [INFO]: Dataset physionet_2012 has already been downloaded. Processing directly...
2025-05-10 23:27:30 [INFO]: Dataset physionet_2012 has already been cached. Loading from cache directly...
2025-05-10 23:27:30 [INFO]: Loaded successfully!
2025-05-10 23:27:33 [INFO]: 23355 values masked out in the val set as ground truth, take 10.11% of the original observed values
2025-05-10 23:27:33 [INFO]: 28874 values masked out in the test set as ground truth, take 10.06% of the original observed values
2025-05-10 23:27:33 [INFO]: Total sample number: 3997
2025-05-10 23:27:33 [INFO]: Training set size: 2557 (63.97%)
2025-05-10 23:27:33 [INFO]: Validation set size: 640 (16.01%)
2025-05-10 23:27:33 [INFO]: Test set size: 800 (20.02%)
...
```
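The `pattern="point", rate=0.1` arguments are what produce the "values masked out ... as ground truth" lines in the log. For a rough picture of that step, here is a minimal NumPy sketch of point-wise random (MCAR-style) masking. It assumes each observed value is hidden independently with probability `rate`. `mask_observed_points` is a hypothetical helper, and benchpots' own implementation may differ in its details.

```python
import numpy as np


def mask_observed_points(X, rate, seed=0):
    """Hypothetical helper: hide a random fraction `rate` of the observed values.

    X: array of shape (n_samples, n_steps, n_features), NaN = originally missing.
    Returns the masked copy and a boolean array marking the hidden positions,
    whose original values can then serve as imputation ground truth.
    """
    rng = np.random.default_rng(seed)
    observed = ~np.isnan(X)
    # each observed entry is hidden independently with probability `rate`
    hidden = observed & (rng.random(X.shape) < rate)
    X_masked = X.copy()
    X_masked[hidden] = np.nan
    return X_masked, hidden


# Synthetic stand-in data with about 20% of the values originally missing
rng = np.random.default_rng(42)
X = rng.normal(size=(100, 48, 10))
X[rng.random(X.shape) < 0.2] = np.nan

X_masked, hidden = mask_observed_points(X, rate=0.1)
print(f"{hidden.sum() / (~np.isnan(X)).sum():.2%} of the observed values were masked out")
```

Because the masking is random, the realized fraction lands close to, but not exactly at, 10%. The same is true of the 10.11% and 10.06% reported in the log above.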

### 2.2 Training the Model with PyPOTS

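```python
from pypots.classification import BRITS

# BRITS takes the incomplete time series directly: it imputes internally
# while learning the classification task, all within one model.
brits = BRITS(
    n_steps=physionet2012_dataset["n_steps"],        # time steps per sample
    n_features=physionet2012_dataset["n_features"],  # variables per time step
    n_classes=physionet2012_dataset["n_classes"],    # number of target classes
    rnn_hidden_size=128,  # hidden size of the recurrent units
    epochs=20,            # maximum number of training epochs
    patience=5,           # early stopping after 5 epochs without validation improvement
)

# The validation set drives early stopping and model selection.
brits.fit(dataset_for_training, dataset_for_validating)
```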

```
2025-05-10 23:27:35 [INFO]: No given device, using default device: cpu
2025-05-10 23:27:35 [INFO]: Using customized CrossEntropy as the training loss function.
2025-05-10 23:27:35 [INFO]: Using customized CrossEntropy as the validation metric function.
2025-05-10 23:27:35 [INFO]: BRITS initialized with the given hyperparameters, the number of trainable parameters: 239,860
2025-05-10 23:27:55 [INFO]: Epoch 001 - training loss (CrossEntropy): 1.6684, validation CrossEntropy: 0.3586
2025-05-10 23:28:10 [INFO]: Epoch 002 - training loss (CrossEntropy): 1.3455, validation CrossEntropy: 0.3242
2025-05-10 23:28:25 [INFO]: Epoch 003 - training loss (CrossEntropy): 1.2476, validation CrossEntropy: 0.3078
2025-05-10 23:28:41 [INFO]: Epoch 004 - training loss (CrossEntropy): 1.1837, validation CrossEntropy: 0.2944
2025-05-10 23:28:55 [INFO]: Epoch 005 - training loss (CrossEntropy): 1.1343, validation CrossEntropy: 0.2714
2025-05-10 23:29:08 [INFO]: Epoch 006 - training loss (CrossEntropy): 1.0961, validation CrossEntropy: 0.2644
2025-05-10 23:29:21 [INFO]: Epoch 007 - training loss (CrossEntropy): 1.0771, validation CrossEntropy: 0.2475
2025-05-10 23:29:34 [INFO]: Epoch 008 - training loss (CrossEntropy): 1.0568, validation CrossEntropy: 0.2329
2025-05-10 23:29:46 [INFO]: Epoch 009 - training loss (CrossEntropy): 1.0158, validation CrossEntropy: 0.2196
...
```
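To make "end-to-end" more concrete, below is a heavily simplified NumPy sketch of one direction of a RITS-style pass. BRITS runs two such passes, forward and backward, and combines them. The sketch follows the idea from the BRITS paper. It decays the hidden state according to how long each variable has gone unobserved and estimates the current values from the hidden state. It keeps the observed values and fills only the missing ones, then feeds the filled values together with the mask into the recurrent update. The assumptions are loud: a plain tanh RNN cell instead of the paper's LSTM, random untrained weights, no bias terms, no feature-regression component, and unit spacing between time steps. `rits_forward` is a hypothetical name, not a PyPOTS function.

```python
import numpy as np


def rits_forward(X, hidden_size=8, seed=0):
    """Simplified single-direction RITS-style pass (illustration only).

    X: (n_steps, n_features) with np.nan for missing values.
    Weights are random here; in the real model they are learned.
    """
    rng = np.random.default_rng(seed)
    n_steps, n_features = X.shape
    M = (~np.isnan(X)).astype(float)  # observation mask: 1 = observed
    X0 = np.nan_to_num(X)             # zeros at missing positions (masked out below)

    # Time since each variable was last observed, assuming unit step spacing
    delta = np.zeros_like(X0)
    for t in range(1, n_steps):
        delta[t] = 1.0 + (1.0 - M[t - 1]) * delta[t - 1]

    W_gamma = rng.normal(scale=0.1, size=(n_features, hidden_size))
    W_x = rng.normal(scale=0.1, size=(hidden_size, n_features))
    W_in = rng.normal(scale=0.1, size=(2 * n_features, hidden_size))
    W_h = rng.normal(scale=0.1, size=(hidden_size, hidden_size))

    h = np.zeros(hidden_size)
    X_filled = np.zeros_like(X0)
    for t in range(n_steps):
        gamma = np.exp(-np.maximum(0.0, delta[t] @ W_gamma))  # decay factor in (0, 1]
        h = h * gamma                                          # decay the hidden state
        x_hat = h @ W_x                                        # estimate x_t from history
        x_c = M[t] * X0[t] + (1.0 - M[t]) * x_hat              # keep observed, fill missing
        X_filled[t] = x_c
        # the mask is part of the input, so the missingness pattern reaches the RNN
        h = np.tanh(np.concatenate([x_c, M[t]]) @ W_in + h @ W_h)
    return X_filled, h


X = np.array([[0.5, np.nan, 1.0],
              [np.nan, np.nan, 0.2],
              [0.7, 0.1, np.nan],
              [0.4, 0.3, 0.9]])
X_filled, h_final = rits_forward(X)
print(np.isnan(X_filled).any())  # False
```

In the classification setting described in the BRITS paper, the final hidden state feeds a classifier. The classification loss and the imputation error on observed values are then optimized together, so the model learns from the raw incomplete data and its mask rather than from values filled in beforehand.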

### 2.3 Evaluating Classification Performance

```python
from pypots.nn.functional.classification import calc_binary_classification_metrics

# predict() returns a dict; "classification" holds the predicted class probabilities
brits_results = brits.predict(dataset_for_testing)
brits_prediction = brits_results["classification"]

# Compare the predicted probabilities with the true test labels
classification_metrics = calc_binary_classification_metrics(
    brits_prediction, dataset_for_testing["y"]
)
print(f"BRITS ROC-AUC on the test set: {classification_metrics['roc_auc']:.4f}")
print(f"BRITS PR-AUC on the test set: {classification_metrics['pr_auc']:.4f}")
```

```
BRITS ROC-AUC on the test set: 0.5823
BRITS PR-AUC on the test set: 0.4272
```
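ROC-AUC has a simple reading: it is the probability that a randomly chosen positive sample gets a higher score than a randomly chosen negative one, with ties counting half. A random classifier therefore scores about 0.5. The minimal NumPy sketch below implements this pairwise definition. `roc_auc_pairwise` is a hypothetical helper for illustration, not the PyPOTS function. It assumes the positive class is label 1.

```python
import numpy as np


def roc_auc_pairwise(scores, labels):
    """ROC-AUC as the fraction of (positive, negative) pairs ranked correctly."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels)
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    diff = pos[:, None] - neg[None, :]  # every positive compared with every negative
    wins = (diff > 0).sum() + 0.5 * (diff == 0).sum()
    return wins / (len(pos) * len(neg))


labels = np.array([0, 0, 1, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8])
print(roc_auc_pairwise(scores, labels))  # 0.75
```

With the PyPOTS output, one would presumably pass the positive-class column (for example `brits_prediction[:, 1]`, assuming column 1 holds the probability of label 1) together with `dataset_for_testing["y"]`. The pairwise comparison uses O(n_pos × n_neg) memory, which is fine for an 800-sample test set. A rank-based formula would scale better.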



# 3. Further Reading

### Cao, W., Wang, D., Li, J., Zhou, H., Li, L., & Li, Y. (2018). [BRITS: Bidirectional Recurrent Imputation for Time Series](https://arxiv.org/abs/1805.10572). *NeurIPS 2018*.
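#### Why we recommend it: this paper is essential reading in time-series imputation. It builds on the GRU-D and M-RNN models and achieves clearly better results. It was published at NeurIPS 2018, a top-tier AI conference, and as of May 2025 had 800+ citations on Google Scholar.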