Skip to content

JJASMINE22/Sequence-Transformer-for-Long-term-sequence-forecasting

Repository files navigation

Sequence Transformer based on Gconv-MultiHeadAttention --Pytorch


目录

  1. 所需环境 Environment
  2. 注意力结构 Attention Structure
  3. 模型结构 Model Structure
  4. 注意事项 Cautions
  5. 文件下载 Download
  6. 训练步骤 How2train
  7. 预测效果 predict
  8. 参考资料 Reference

所需环境

  1. Python3.7
  2. PyTorch>=1.10.1+cu113
  3. numpy==1.19.5
  4. pandas==1.2.4
  5. pyod==0.9.8
  6. matplotlib==3.2.2
  7. CUDA 11.0+

注意力结构

image

模型结构

Encoder
由全连接层、一维分组卷积多头注意力机制组成
image

Decoder
由全连接层、一维分组卷积多头注意力机制组成
image

Sequence Transformer
合并Encoder-Decoder,拼接全连接层
image

注意事项

  1. 时序数据推理,删除了标准Transformer的位置掩码、位置编码、前馈层等机制
  2. 使用一个正态分布变量替代起始序列特征
  3. 将Linear MultiHeadAttention替换为GConv MultiHeadAttention
  4. 训练时,并行推理解码序列;预测时,贯续推理解码序列
  5. 提出特殊的边界序列填充方法,克服卷积操作引发的差异性,保证训练、预测阶段的运算机制相同
  6. 保留三角掩码,防止特征泄露
  7. 加入权重正则化操作,防止过拟合

文件下载

链接:https://pan.baidu.com/s/13T1Qs4NZL8NS4yoxCi-Qyw 提取码:sets 下载解压后放置于config.py中设置的路径即可。

训练步骤

运行train.py即可开始训练。

预测效果

sequence_1
image

sequence_2
image

sequence_3
image

参考资料

https://arxiv.org/pdf/1706.03762.pdf

About

Replace Linear MultiHeadAttention mechanism with GConv

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Languages