diff --git a/pypots/imputation/crossformer/modules/submodules.py b/pypots/imputation/crossformer/modules/submodules.py
index 0df19b81..2a67a227 100644
--- a/pypots/imputation/crossformer/modules/submodules.py
+++ b/pypots/imputation/crossformer/modules/submodules.py
@@ -9,7 +9,7 @@
 import torch.nn as nn
 from einops import rearrange, repeat
 
-from ....nn.modules.transformer import MultiHeadAttention
+from ....nn.modules.transformer import ScaledDotProductAttention, MultiHeadAttention
 
 
 class TwoStageAttentionLayer(nn.Module):
@@ -33,10 +33,26 @@ def __init__(
         super().__init__()
         d_ff = 4 * d_model if d_ff is None else d_ff
         self.time_attention = MultiHeadAttention(
-            n_heads, d_model, d_k, d_v, attn_dropout
+            n_heads,
+            d_model,
+            d_k,
+            d_v,
+            ScaledDotProductAttention(d_k**0.5, attn_dropout),
+        )
+        self.dim_sender = MultiHeadAttention(
+            n_heads,
+            d_model,
+            d_k,
+            d_v,
+            ScaledDotProductAttention(d_k**0.5, attn_dropout),
+        )
+        self.dim_receiver = MultiHeadAttention(
+            n_heads,
+            d_model,
+            d_k,
+            d_v,
+            ScaledDotProductAttention(d_k**0.5, attn_dropout),
         )
-        self.dim_sender = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
-        self.dim_receiver = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
         self.router = nn.Parameter(torch.randn(seg_num, factor, d_model))
 
         self.dropout = nn.Dropout(dropout)
diff --git a/pypots/imputation/patchtst/modules/core.py b/pypots/imputation/patchtst/modules/core.py
index 1ba4206d..9013a802 100644
--- a/pypots/imputation/patchtst/modules/core.py
+++ b/pypots/imputation/patchtst/modules/core.py
@@ -8,6 +8,7 @@
 import torch.nn as nn
 
 from .submodules import PatchEmbedding, FlattenHead
+from ....nn.modules.transformer.attention import ScaledDotProductAttention
 from ....nn.modules.transformer.auto_encoder import EncoderLayer
 from ....utils.metrics import calc_mse
 
@@ -49,8 +50,8 @@ def __init__(
                     n_heads,
                     d_k,
                     d_v,
+                    ScaledDotProductAttention(d_k**0.5, attn_dropout),
                     dropout,
-                    attn_dropout,
                 )
                 for _ in range(n_layers)
             ]
diff --git a/pypots/imputation/saits/modules/core.py b/pypots/imputation/saits/modules/core.py
index b0a4f1c3..4976c594 100644
--- a/pypots/imputation/saits/modules/core.py
+++ b/pypots/imputation/saits/modules/core.py
@@ -20,6 +20,7 @@
 import torch.nn.functional as F
 
 from ....nn.modules.transformer import EncoderLayer, PositionalEncoding
+from ....nn.modules.transformer.attention import ScaledDotProductAttention
 from ....utils.metrics import calc_mae
 
 
@@ -59,8 +60,8 @@ def __init__(
                     n_heads,
                     d_k,
                     d_v,
+                    ScaledDotProductAttention(d_k**0.5, attn_dropout),
                     dropout,
-                    attn_dropout,
                 )
                 for _ in range(n_layers)
             ]
@@ -73,8 +74,8 @@ def __init__(
                     n_heads,
                     d_k,
                     d_v,
+                    ScaledDotProductAttention(d_k**0.5, attn_dropout),
                     dropout,
-                    attn_dropout,
                 )
                 for _ in range(n_layers)
             ]
diff --git a/pypots/imputation/transformer/modules/core.py b/pypots/imputation/transformer/modules/core.py
index 066b7790..f4cfb841 100644
--- a/pypots/imputation/transformer/modules/core.py
+++ b/pypots/imputation/transformer/modules/core.py
@@ -19,6 +19,7 @@
 import torch.nn as nn
 
 from ....nn.modules.transformer import EncoderLayer, PositionalEncoding
+from ....nn.modules.transformer.attention import ScaledDotProductAttention
 from ....utils.metrics import calc_mae
 
 
@@ -52,8 +53,8 @@ def __init__(
                     n_heads,
                     d_k,
                     d_v,
+                    ScaledDotProductAttention(d_k**0.5, attn_dropout),
                     dropout,
-                    attn_dropout,
                 )
                 for _ in range(n_layers)
             ]
diff --git a/pypots/modules/__init__.py b/pypots/modules/__init__.py
deleted file mode 100644
index 638464fe..00000000
--- a/pypots/modules/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-"""
-Everything used to be in this package has been moved to pypots.nn.modules.
-This package is kept for backward compatibility and will be removed in the future.
-"""
-
-# Created by Wenjie Du
-# License: BSD-3-Clause
-
-from ..utils.logging import logger
-
-logger.warning(
-    "🚨 pypots.modules package has been moved to pypots.nn.modules. "
-    "Please import everything from pypots.nn.modules instead."
-)
diff --git a/pypots/nn/modules/transformer/attention.py b/pypots/nn/modules/transformer/attention.py
index 89684473..1c23efd8 100644
--- a/pypots/nn/modules/transformer/attention.py
+++ b/pypots/nn/modules/transformer/attention.py
@@ -16,9 +16,30 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from abc import abstractmethod
 
 
-class ScaledDotProductAttention(nn.Module):
+class AttentionOperator(nn.Module):
+    """
+    The abstract class for all attention layers.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    @abstractmethod
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError
+
+
+class ScaledDotProductAttention(AttentionOperator):
     """Scaled dot-product attention.
 
     Parameters
@@ -44,6 +65,7 @@ def forward(
         k: torch.Tensor,
         v: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward processing of the scaled dot-product attention.
 
@@ -51,8 +73,10 @@ def forward(
         ----------
         q:
             Query tensor.
+
         k:
             Key tensor.
+
         v:
             Value tensor.
 
@@ -106,11 +130,8 @@ class MultiHeadAttention(nn.Module):
     d_v:
         The dimension of the value tensor.
 
-    attn_dropout:
-        The dropout rate for the attention map.
-
-    attn_temperature:
-        The temperature for scaling. Default is None, which means d_k**0.5 will be applied.
+    attention_operator:
+        The attention operator, e.g. the self-attention proposed in Transformer.
 
     """
 
@@ -120,13 +141,10 @@ def __init__(
         d_model: int,
         d_k: int,
         d_v: int,
-        attn_dropout: float,
-        attn_temperature: float = None,
+        attention_operator: AttentionOperator,
     ):
         super().__init__()
 
-        attn_temperature = d_k**0.5 if attn_temperature is None else attn_temperature
-
         self.n_heads = n_heads
         self.d_k = d_k
         self.d_v = d_v
@@ -135,7 +153,7 @@ def __init__(
         self.w_ks = nn.Linear(d_model, n_heads * d_k, bias=False)
         self.w_vs = nn.Linear(d_model, n_heads * d_v, bias=False)
 
-        self.attention = ScaledDotProductAttention(attn_temperature, attn_dropout)
+        self.attention_operator = attention_operator
         self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
 
     def forward(
@@ -144,6 +162,7 @@ def forward(
         k: torch.Tensor,
         v: torch.Tensor,
         attn_mask: Optional[torch.Tensor],
+        **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward processing of the multi-head attention module.
 
@@ -189,7 +208,7 @@ def forward(
             # broadcasting on the head axis
             attn_mask = attn_mask.unsqueeze(1)
 
-        v, attn_weights = self.attention(q, k, v, attn_mask)
+        v, attn_weights = self.attention_operator(q, k, v, attn_mask, **kwargs)
 
         # transpose back -> [batch_size, n_steps, n_heads, d_v]
         # then merge the last two dimensions to combine all the heads -> [batch_size, n_steps, n_heads*d_v]
diff --git a/pypots/nn/modules/transformer/auto_encoder.py b/pypots/nn/modules/transformer/auto_encoder.py
index 76761ce3..6aa6e1f2 100644
--- a/pypots/nn/modules/transformer/auto_encoder.py
+++ b/pypots/nn/modules/transformer/auto_encoder.py
@@ -10,6 +10,7 @@
 import torch
 import torch.nn as nn
 
+from .attention import ScaledDotProductAttention
 from .embedding import PositionalEncoding
 from .layers import EncoderLayer, DecoderLayer
 
@@ -78,8 +79,8 @@ def __init__(
                     n_heads,
                     d_k,
                     d_v,
+                    ScaledDotProductAttention(d_k**0.5, attn_dropout),
                     dropout,
-                    attn_dropout,
                 )
                 for _ in range(n_layers)
             ]
@@ -190,8 +191,9 @@ def __init__(
                     n_heads,
                     d_k,
                     d_v,
+                    ScaledDotProductAttention(d_k**0.5, attn_dropout),
+                    ScaledDotProductAttention(d_k**0.5, attn_dropout),
                     dropout,
-                    attn_dropout,
                 )
                 for _ in range(n_layers)
             ]
diff --git a/pypots/nn/modules/transformer/layers.py b/pypots/nn/modules/transformer/layers.py
index a5a558cc..e66b4b32 100644
--- a/pypots/nn/modules/transformer/layers.py
+++ b/pypots/nn/modules/transformer/layers.py
@@ -11,7 +11,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .attention import MultiHeadAttention
+from .attention import MultiHeadAttention, AttentionOperator
 
 
 class PositionWiseFeedForward(nn.Module):
@@ -85,11 +85,12 @@ class EncoderLayer(nn.Module):
     d_v:
         The dimension of the value tensor.
 
+    slf_attn_opt:
+        The attention operator for the self multi-head attention module in the encoder layer.
+
     dropout:
         The dropout rate.
 
-    attn_dropout:
-        The dropout rate for the attention map.
     """
 
     def __init__(
@@ -99,11 +100,11 @@ def __init__(
         n_heads: int,
         d_k: int,
         d_v: int,
+        slf_attn_opt: AttentionOperator,
         dropout: float = 0.1,
-        attn_dropout: float = 0.1,
     ):
         super().__init__()
-        self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
+        self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, slf_attn_opt)
         self.dropout = nn.Dropout(dropout)
         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
         self.pos_ffn = PositionWiseFeedForward(d_model, d_ffn, dropout)
@@ -112,6 +113,7 @@ def forward(
         self,
         enc_input: torch.Tensor,
         src_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward processing of the encoder layer.
 
@@ -137,6 +139,7 @@ def forward(
             enc_input,
             enc_input,
             attn_mask=src_mask,
+            **kwargs,
         )
 
         # apply dropout and residual connection
@@ -170,12 +173,15 @@ class DecoderLayer(nn.Module):
     d_v:
         The dimension of the value tensor.
 
+    slf_attn_opt:
+        The attention operator for the self multi-head attention module in the decoder layer.
+
+    enc_attn_opt:
+        The attention operator for the encoding multi-head attention module in the decoder layer.
+
     dropout:
         The dropout rate.
 
-    attn_dropout:
-        The dropout rate for the attention map.
-
     """
 
     def __init__(
@@ -185,12 +191,13 @@ def __init__(
         n_heads: int,
         d_k: int,
         d_v: int,
+        slf_attn_opt: AttentionOperator,
+        enc_attn_opt: AttentionOperator,
         dropout: float = 0.1,
-        attn_dropout: float = 0.1,
     ):
         super().__init__()
-        self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
-        self.enc_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
+        self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, slf_attn_opt)
+        self.enc_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, enc_attn_opt)
         self.pos_ffn = PositionWiseFeedForward(d_model, d_ffn, dropout)
 
     def forward(
@@ -199,6 +206,7 @@ def forward(
         enc_output: torch.Tensor,
         slf_attn_mask: Optional[torch.Tensor] = None,
         dec_enc_attn_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Forward processing of the decoder layer.
 
@@ -231,10 +239,18 @@ def forward(
         """
 
         dec_output, dec_slf_attn = self.slf_attn(
-            dec_input, dec_input, dec_input, attn_mask=slf_attn_mask
+            dec_input,
+            dec_input,
+            dec_input,
+            attn_mask=slf_attn_mask,
+            **kwargs,
        )
         dec_output, dec_enc_attn = self.enc_attn(
-            dec_output, enc_output, enc_output, attn_mask=dec_enc_attn_mask
+            dec_output,
+            enc_output,
+            enc_output,
+            attn_mask=dec_enc_attn_mask,
+            **kwargs,
         )
         dec_output = self.pos_ffn(dec_output)
         return dec_output, dec_slf_attn, dec_enc_attn