Merge pull request #334 from WenjieDu/(refactor)flexible_transformer_layers

Make the self-attention operator replaceable in Transformer
WenjieDu committed Apr 2, 2024
2 parents 46e2434 + 7e2c6f8 commit f7425e7
Showing 8 changed files with 91 additions and 49 deletions.
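
For context, a minimal sketch (not part of the diff) of the calling convention this commit introduces: the attention operator is now injected into MultiHeadAttention instead of the layer hard-coding scaled dot-product attention and only exposing its dropout rate. The dimensions and batch shape below are illustrative.

import torch

from pypots.nn.modules.transformer import MultiHeadAttention, ScaledDotProductAttention

n_heads, d_model, d_k, d_v = 4, 64, 16, 16
attn_dropout = 0.1

# before this commit: MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
# after this commit: the operator is passed in explicitly and can be swapped out
mha = MultiHeadAttention(
    n_heads,
    d_model,
    d_k,
    d_v,
    ScaledDotProductAttention(d_k**0.5, attn_dropout),
)

x = torch.randn(8, 24, d_model)  # [batch_size, n_steps, d_model], illustrative shape
output, attn_weights = mha(x, x, x, attn_mask=None)
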
24 changes: 20 additions & 4 deletions pypots/imputation/crossformer/modules/submodules.py
@@ -9,7 +9,7 @@
import torch.nn as nn
from einops import rearrange, repeat

from ....nn.modules.transformer import MultiHeadAttention
from ....nn.modules.transformer import ScaledDotProductAttention, MultiHeadAttention


class TwoStageAttentionLayer(nn.Module):
@@ -33,10 +33,26 @@ def __init__(
super().__init__()
d_ff = 4 * d_model if d_ff is None else d_ff
self.time_attention = MultiHeadAttention(
n_heads, d_model, d_k, d_v, attn_dropout
n_heads,
d_model,
d_k,
d_v,
ScaledDotProductAttention(d_k**0.5, attn_dropout),
)
self.dim_sender = MultiHeadAttention(
n_heads,
d_model,
d_k,
d_v,
ScaledDotProductAttention(d_k**0.5, attn_dropout),
)
self.dim_receiver = MultiHeadAttention(
n_heads,
d_model,
d_k,
d_v,
ScaledDotProductAttention(d_k**0.5, attn_dropout),
)
self.dim_sender = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
self.dim_receiver = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
self.router = nn.Parameter(torch.randn(seg_num, factor, d_model))

self.dropout = nn.Dropout(dropout)
3 changes: 2 additions & 1 deletion pypots/imputation/patchtst/modules/core.py
@@ -8,6 +8,7 @@
import torch.nn as nn

from .submodules import PatchEmbedding, FlattenHead
from ....nn.modules.transformer.attention import ScaledDotProductAttention
from ....nn.modules.transformer.auto_encoder import EncoderLayer
from ....utils.metrics import calc_mse

@@ -49,8 +50,8 @@ def __init__(
n_heads,
d_k,
d_v,
ScaledDotProductAttention(d_k**0.5, attn_dropout),
dropout,
attn_dropout,
)
for _ in range(n_layers)
]
5 changes: 3 additions & 2 deletions pypots/imputation/saits/modules/core.py
@@ -20,6 +20,7 @@
import torch.nn.functional as F

from ....nn.modules.transformer import EncoderLayer, PositionalEncoding
from ....nn.modules.transformer.attention import ScaledDotProductAttention
from ....utils.metrics import calc_mae


@@ -59,8 +60,8 @@ def __init__(
n_heads,
d_k,
d_v,
ScaledDotProductAttention(d_k**0.5, attn_dropout),
dropout,
attn_dropout,
)
for _ in range(n_layers)
]
@@ -73,8 +74,8 @@ def __init__(
n_heads,
d_k,
d_v,
ScaledDotProductAttention(d_k**0.5, attn_dropout),
dropout,
attn_dropout,
)
for _ in range(n_layers)
]
3 changes: 2 additions & 1 deletion pypots/imputation/transformer/modules/core.py
@@ -19,6 +19,7 @@
import torch.nn as nn

from ....nn.modules.transformer import EncoderLayer, PositionalEncoding
from ....nn.modules.transformer.attention import ScaledDotProductAttention
from ....utils.metrics import calc_mae


@@ -52,8 +53,8 @@ def __init__(
n_heads,
d_k,
d_v,
ScaledDotProductAttention(d_k**0.5, attn_dropout),
dropout,
attn_dropout,
)
for _ in range(n_layers)
]
14 changes: 0 additions & 14 deletions pypots/modules/__init__.py

This file was deleted.

43 changes: 31 additions & 12 deletions pypots/nn/modules/transformer/attention.py
@@ -16,9 +16,30 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import abstractmethod


class ScaledDotProductAttention(nn.Module):
class AttentionOperator(nn.Module):
"""
The abstract class for all attention layers.
"""

def __init__(self):
super().__init__()

@abstractmethod
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError


class ScaledDotProductAttention(AttentionOperator):
"""Scaled dot-product attention.
Parameters
@@ -44,15 +65,18 @@ def forward(
k: torch.Tensor,
v: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward processing of the scaled dot-product attention.
Parameters
----------
q:
Query tensor.
k:
Key tensor.
v:
Value tensor.
@@ -106,11 +130,8 @@ class MultiHeadAttention(nn.Module):
d_v:
The dimension of the value tensor.
attn_dropout:
The dropout rate for the attention map.
attn_temperature:
The temperature for scaling. Default is None, which means d_k**0.5 will be applied.
attention_operator:
The attention operator, e.g. the self-attention proposed in Transformer.
"""

@@ -120,13 +141,10 @@ def __init__(
d_model: int,
d_k: int,
d_v: int,
attn_dropout: float,
attn_temperature: float = None,
attention_operator: AttentionOperator,
):
super().__init__()

attn_temperature = d_k**0.5 if attn_temperature is None else attn_temperature

self.n_heads = n_heads
self.d_k = d_k
self.d_v = d_v
@@ -135,7 +153,7 @@ def __init__(
self.w_ks = nn.Linear(d_model, n_heads * d_k, bias=False)
self.w_vs = nn.Linear(d_model, n_heads * d_v, bias=False)

self.attention = ScaledDotProductAttention(attn_temperature, attn_dropout)
self.attention_operator = attention_operator
self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)

def forward(
@@ -144,6 +162,7 @@ def forward(
k: torch.Tensor,
v: torch.Tensor,
attn_mask: Optional[torch.Tensor],
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward processing of the multi-head attention module.
@@ -189,7 +208,7 @@ def forward(
# broadcasting on the head axis
attn_mask = attn_mask.unsqueeze(1)

v, attn_weights = self.attention(q, k, v, attn_mask)
v, attn_weights = self.attention_operator(q, k, v, attn_mask, **kwargs)

# transpose back -> [batch_size, n_steps, n_heads, d_v]
# then merge the last two dimensions to combine all the heads -> [batch_size, n_steps, n_heads*d_v]
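
The new AttentionOperator base class above fixes the interface an injected operator must implement. As an illustration, a hypothetical subclass is sketched below (CosineAttention is not part of this PR or of PyPOTS). Judging from the transpose-back comment in the hunk above, the operator receives q/k/v already split per head, roughly [batch_size, n_heads, n_steps, d_k] (or d_v for v), and must return the attended values together with the attention map.

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from pypots.nn.modules.transformer.attention import AttentionOperator


class CosineAttention(AttentionOperator):
    """A toy operator that scores queries against keys with cosine similarity."""

    def __init__(self, attn_dropout: float = 0.1, tau: float = 0.1):
        super().__init__()
        self.tau = tau  # temperature for sharpening the softmax
        self.dropout = nn.Dropout(attn_dropout)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # cosine similarity between every query position and every key position
        scores = F.cosine_similarity(q.unsqueeze(-2), k.unsqueeze(-3), dim=-1) / self.tau
        if attn_mask is not None:
            # assumed mask convention: positions where attn_mask == 0 are masked out
            scores = scores.masked_fill(attn_mask == 0, -1e9)
        attn = self.dropout(F.softmax(scores, dim=-1))
        output = torch.matmul(attn, v)
        return output, attn


# It can then be dropped into any layer that now takes an operator, e.g.
# MultiHeadAttention(n_heads, d_model, d_k, d_v, CosineAttention(attn_dropout=0.1))
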
6 changes: 4 additions & 2 deletions pypots/nn/modules/transformer/auto_encoder.py
@@ -10,6 +10,7 @@
import torch
import torch.nn as nn

from .attention import ScaledDotProductAttention
from .embedding import PositionalEncoding
from .layers import EncoderLayer, DecoderLayer

@@ -78,8 +79,8 @@ def __init__(
n_heads,
d_k,
d_v,
ScaledDotProductAttention(d_k**0.5, attn_dropout),
dropout,
attn_dropout,
)
for _ in range(n_layers)
]
@@ -190,8 +191,9 @@ def __init__(
n_heads,
d_k,
d_v,
ScaledDotProductAttention(d_k**0.5, attn_dropout),
ScaledDotProductAttention(d_k**0.5, attn_dropout),
dropout,
attn_dropout,
)
for _ in range(n_layers)
]
42 changes: 29 additions & 13 deletions pypots/nn/modules/transformer/layers.py
@@ -11,7 +11,7 @@
import torch.nn as nn
import torch.nn.functional as F

from .attention import MultiHeadAttention
from .attention import MultiHeadAttention, AttentionOperator


class PositionWiseFeedForward(nn.Module):
@@ -85,11 +85,12 @@ class EncoderLayer(nn.Module):
d_v:
The dimension of the value tensor.
slf_attn_opt:
The attention operator for the self multi-head attention module in the encoder layer.
dropout:
The dropout rate.
attn_dropout:
The dropout rate for the attention map.
"""

def __init__(
Expand All @@ -99,11 +100,11 @@ def __init__(
n_heads: int,
d_k: int,
d_v: int,
slf_attn_opt: AttentionOperator,
dropout: float = 0.1,
attn_dropout: float = 0.1,
):
super().__init__()
self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, slf_attn_opt)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.pos_ffn = PositionWiseFeedForward(d_model, d_ffn, dropout)
@@ -112,6 +113,7 @@ def forward(
self,
enc_input: torch.Tensor,
src_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward processing of the encoder layer.
@@ -137,6 +139,7 @@ def forward(
enc_input,
enc_input,
attn_mask=src_mask,
**kwargs,
)

# apply dropout and residual connection
@@ -170,12 +173,15 @@ class DecoderLayer(nn.Module):
d_v:
The dimension of the value tensor.
slf_attn_opt:
The attention operator for the self multi-head attention module in the decoder layer.
enc_attn_opt:
The attention operator for the encoding multi-head attention module in the decoder layer.
dropout:
The dropout rate.
attn_dropout:
The dropout rate for the attention map.
"""

def __init__(
Expand All @@ -185,12 +191,13 @@ def __init__(
n_heads: int,
d_k: int,
d_v: int,
slf_attn_opt: AttentionOperator,
enc_attn_opt: AttentionOperator,
dropout: float = 0.1,
attn_dropout: float = 0.1,
):
super().__init__()
self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
self.enc_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout)
self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, slf_attn_opt)
self.enc_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, enc_attn_opt)
self.pos_ffn = PositionWiseFeedForward(d_model, d_ffn, dropout)

def forward(
Expand All @@ -199,6 +206,7 @@ def forward(
enc_output: torch.Tensor,
slf_attn_mask: Optional[torch.Tensor] = None,
dec_enc_attn_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Forward processing of the decoder layer.
@@ -231,10 +239,18 @@ def forward(
"""
dec_output, dec_slf_attn = self.slf_attn(
dec_input, dec_input, dec_input, attn_mask=slf_attn_mask
dec_input,
dec_input,
dec_input,
attn_mask=slf_attn_mask,
**kwargs,
)
dec_output, dec_enc_attn = self.enc_attn(
dec_output, enc_output, enc_output, attn_mask=dec_enc_attn_mask
dec_output,
enc_output,
enc_output,
attn_mask=dec_enc_attn_mask,
**kwargs,
)
dec_output = self.pos_ffn(dec_output)
return dec_output, dec_slf_attn, dec_enc_attn
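
To make the refactored layer signatures concrete, a hedged sketch of building an EncoderLayer and a DecoderLayer directly follows. The leading d_model and d_ffn arguments are collapsed in the hunks above, so their presence and order are assumptions based on the visible call sites; all shapes and values are illustrative.

import torch

from pypots.nn.modules.transformer.attention import ScaledDotProductAttention
from pypots.nn.modules.transformer.layers import EncoderLayer, DecoderLayer

d_model, d_ffn, n_heads, d_k, d_v = 64, 128, 4, 16, 16
dropout, attn_dropout = 0.1, 0.1

enc_layer = EncoderLayer(
    d_model,  # assumed leading argument (collapsed in the diff)
    d_ffn,    # assumed leading argument (collapsed in the diff)
    n_heads,
    d_k,
    d_v,
    ScaledDotProductAttention(d_k**0.5, attn_dropout),  # slf_attn_opt
    dropout,
)

dec_layer = DecoderLayer(
    d_model,
    d_ffn,
    n_heads,
    d_k,
    d_v,
    ScaledDotProductAttention(d_k**0.5, attn_dropout),  # slf_attn_opt
    ScaledDotProductAttention(d_k**0.5, attn_dropout),  # enc_attn_opt
    dropout,
)

enc_in = torch.randn(8, 24, d_model)  # [batch_size, n_steps, d_model]
enc_out, enc_slf_attn = enc_layer(enc_in)

dec_in = torch.randn(8, 24, d_model)
dec_out, dec_slf_attn, dec_enc_attn = dec_layer(dec_in, enc_out)
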
