Skip to content

Commit

Permalink
feat: add Informer as an imputation model;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Apr 1, 2024
1 parent 708f9fe commit 8708853
Show file tree
Hide file tree
Showing 11 changed files with 830 additions and 53 deletions.
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .etsformer import ETSformer
from .fedformer import FEDformer
from .crossformer import Crossformer
from .informer import Informer
from .autoformer import Autoformer
from .dlinear import DLinear
from .patchtst import PatchTST
Expand All @@ -36,6 +37,7 @@
"TimesNet",
"PatchTST",
"DLinear",
"Informer",
"Autoformer",
"BRITS",
"MRNN",
Expand Down
4 changes: 2 additions & 2 deletions pypots/imputation/autoformer/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from .submodules import (
SeasonalLayerNorm,
AutoformerEncoderLayer,
AutoformerEncoder,
AutoCorrelation,
AutoCorrelationLayer,
)
from ...informer.modules.submodules import InformerEncoder
from ....nn.modules.transformer.embedding import DataEmbedding
from ....utils.metrics import calc_mse

Expand Down Expand Up @@ -43,7 +43,7 @@ def __init__(
dropout=dropout,
with_pos=False,
)
self.encoder = AutoformerEncoder(
self.encoder = InformerEncoder(
[
AutoformerEncoderLayer(
AutoCorrelationLayer(
Expand Down
49 changes: 0 additions & 49 deletions pypots/imputation/autoformer/modules/submodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,35 +285,6 @@ def forward(self, x, attn_mask=None):
return res, attn


class AutoformerEncoder(nn.Module):
def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
super().__init__()
self.attn_layers = nn.ModuleList(attn_layers)
self.conv_layers = (
nn.ModuleList(conv_layers) if conv_layers is not None else None
)
self.norm = norm_layer

def forward(self, x, attn_mask=None):
attns = []
if self.conv_layers is not None:
for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
x, attn = attn_layer(x, attn_mask=attn_mask)
x = conv_layer(x)
attns.append(attn)
x, attn = self.attn_layers[-1](x)
attns.append(attn)
else:
for attn_layer in self.attn_layers:
x, attn = attn_layer(x, attn_mask=attn_mask)
attns.append(attn)

if self.norm is not None:
x = self.norm(x)

return x, attns


class AutoformerDecoderLayer(nn.Module):
"""
Autoformer decoder layer with the progressive decomposition architecture
Expand Down Expand Up @@ -372,23 +343,3 @@ def forward(self, x, cross, x_mask=None, cross_mask=None):
1, 2
)
return x, residual_trend


class AutoformerDecoder(nn.Module):
def __init__(self, layers, norm_layer=None, projection=None):
super().__init__()
self.layers = nn.ModuleList(layers)
self.norm = norm_layer
self.projection = projection

def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
for layer in self.layers:
x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
trend = trend + residual_trend

if self.norm is not None:
x = self.norm(x)

if self.projection is not None:
x = self.projection(x)
return x, trend
4 changes: 2 additions & 2 deletions pypots/imputation/fedformer/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@

from .submodules import MultiWaveletTransform, FourierBlock
from ...autoformer.modules.submodules import (
AutoformerEncoder,
AutoformerEncoderLayer,
AutoCorrelationLayer,
SeasonalLayerNorm,
)
from ...informer.modules.submodules import InformerEncoder
from ....nn.modules.transformer.embedding import DataEmbedding
from ....utils.metrics import calc_mse

Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(
f"Unsupported version: {version}. Please choose from ['Wavelets', 'Fourier']."
)

self.encoder = AutoformerEncoder(
self.encoder = InformerEncoder(
[
AutoformerEncoderLayer(
AutoCorrelationLayer(
Expand Down
17 changes: 17 additions & 0 deletions pypots/imputation/informer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
The package of the partially-observed time-series imputation model Informer.
Refer to the paper "Wu, H., Xu, J., Wang, J., & Long, M. (2021).
Informer: Decomposition transformers with auto-correlation for long-term series forecasting. NeurIPS 2021.".
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from .model import Informer

__all__ = [
"Informer",
]
24 changes: 24 additions & 0 deletions pypots/imputation/informer/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Dataset class for Informer.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForInformer(DatasetForSAITS):
"""Actually Informer uses the same data strategy as SAITS, needs MIT for training."""

def __init__(
self,
data: Union[dict, str],
return_X_ori: bool,
return_labels: bool,
file_type: str = "h5py",
rate: float = 0.2,
):
super().__init__(data, return_X_ori, return_labels, file_type, rate)
Loading

0 comments on commit 8708853

Please sign in to comment.