Commit

Merge pull request #321 from WenjieDu/dev
Removing deprecated save_model and load_model, adding the imputation model Autoformer
WenjieDu committed Mar 27, 2024
2 parents a0470b2 + bfbfcec commit bf53667
Showing 27 changed files with 1,087 additions and 99 deletions.
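
Migration note for callers: with `save_model` and `load_model` removed, call sites move to `save()` and `load()`, which the removed wrappers already delegated to. The snippet below is a minimal sketch of that change, not code from this commit. `ToyModel` and `migrate_demo` are hypothetical stand-ins that only mirror the `save(saving_path, overwrite)` and `load(path)` signatures visible in the `pypots/base.py` diff; the pickle-based storage is an assumption made purely for illustration.

```python
import os
import pickle
import tempfile


class ToyModel:
    """Hypothetical stand-in for a model class; this is not a PyPOTS class."""

    def __init__(self, weights):
        self.weights = weights

    def save(self, saving_path: str, overwrite: bool = False) -> None:
        # Refuse to clobber an existing file unless explicitly allowed.
        if os.path.exists(saving_path) and not overwrite:
            raise FileExistsError(saving_path)
        with open(saving_path, "wb") as f:
            pickle.dump(self.weights, f)

    def load(self, path: str) -> None:
        # Replace the current weights with the ones stored on disk.
        with open(path, "rb") as f:
            self.weights = pickle.load(f)


def migrate_demo() -> list:
    model = ToyModel([0.1, 0.2])
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model.pypots")
        # Before this commit: model.save_model(path, overwrite=True)
        model.save(path, overwrite=True)
        restored = ToyModel([])
        # Before this commit: restored.load_model(path)
        restored.load(path)
    return restored.weights
```

Since the old methods only logged a deprecation warning and forwarded their arguments, renaming the calls should be the whole migration, assuming the arguments are passed the same way.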
18 changes: 10 additions & 8 deletions README.md
@@ -132,20 +132,20 @@ You can install PyPOTS as shown below:
# via pip
pip install pypots # the first time installation
pip install pypots --upgrade # update pypots to the latest version
# install from the latest source code with the latest features but may be not officially released yet
pip install https://github.com/WenjieDu/PyPOTS/archive/main.zip

# via conda
conda install -c conda-forge pypots # the first time installation
conda update -c conda-forge pypots # update pypots to the latest version
````

Alternatively, you can install from the latest source code, which has the latest features but may not be officially released yet:
> pip install https://github.com/WenjieDu/PyPOTS/archive/main.zip
```


## ❖ Usage
Besides [BrewPOTS](https://github.com/WenjieDu/BrewPOTS), you can also find a simple and quick-start tutorial notebook
on Google Colab <a href="https://colab.research.google.com/drive/1HEFjylEy05-r47jRy0H9jiS_WhD0UWmQ"><img src="https://img.shields.io/badge/GoogleColab-PyPOTS_Tutorials-F9AB00?logo=googlecolab&logoColor=white" alt="Colab tutorials" align="center"/></a>.
If you have further questions, please refer to PyPOTS documentation [docs.pypots.com](https://docs.pypots.com).
Besides [BrewPOTS](https://github.com/WenjieDu/BrewPOTS), you can also find a simple and quick-start tutorial notebook on Google Colab
<a href="https://colab.research.google.com/drive/1HEFjylEy05-r47jRy0H9jiS_WhD0UWmQ">
<img src="https://img.shields.io/badge/GoogleColab-PyPOTS_Tutorials-F9AB00?logo=googlecolab&logoColor=white" alt="Colab tutorials" align="center"/>
</a>. If you have further questions, please refer to PyPOTS documentation [docs.pypots.com](https://docs.pypots.com).
You can also [raise an issue](https://github.com/WenjieDu/PyPOTS/issues) or [ask in our community](#-community).

Below is a usage example of imputing missing values in time series with PyPOTS; click it to view.
@@ -198,6 +198,7 @@ This functionality is implemented with the [Microsoft NNI](https://github.com/mi
| Neural Net | SAITS | Self-Attention-based Imputation for Time Series [^1] | 2023 |
| Neural Net | Transformer | Attention is All you Need [^2];<br>Self-Attention-based Imputation for Time Series [^1];<br><sub>Note: proposed in [^2], and re-implemented as an imputation model in [^1].</sub> | 2017 |
| Neural Net | TimesNet | Temporal 2D-Variation Modeling for General Time Series Analysis [^14] | 2023 |
| Neural Net | Autoformer | Decomposition transformers with auto-correlation for long-term series forecasting [^15] | 2021 |
| Neural Net | CSDI | Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation [^12] | 2021 |
| Neural Net | US-GAN | Unsupervised GAN for Multivariate Time Series Imputation [^10] | 2021 |
| Neural Net | GP-VAE | Gaussian Process Variational Autoencoder [^11] | 2020 |
@@ -249,7 +250,7 @@ url={https://arxiv.org/abs/2305.18811},
doi={10.48550/arXiv.2305.18811},
}
```
or
> Wenjie Du. (2023).
> PyPOTS: a Python toolbox for data mining on Partially-Observed Time Series.
> arXiv, abs/2305.18811. https://arxiv.org/abs/2305.18811
@@ -318,6 +319,7 @@ PyPOTS community is open, transparent, and surely friendly. Let's work together
[^12]: Tashiro, Y., Song, J., Song, Y., & Ermon, S. (2021). [CSDI: Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation](https://proceedings.neurips.cc/paper/2021/hash/cfe8504bda37b575c70ee1a8276f3486-Abstract.html). *NeurIPS 2021*.
[^13]: Rubin, D. B. (1976). [Inference and missing data](https://academic.oup.com/biomet/article-abstract/63/3/581/270932). *Biometrika*.
[^14]: Wu, H., Hu, T., Liu, Y., Zhou, H., Wang, J., & Long, M. (2023). [TimesNet: Temporal 2d-variation modeling for general time series analysis](https://openreview.net/forum?id=ju_Uqw384Oq). *ICLR 2023*.
[^15]: Wu, H., Xu, J., Wang, J., & Long, M. (2021). [Autoformer: Decomposition transformers with auto-correlation for long-term series forecasting](https://proceedings.neurips.cc/paper/2021/hash/bcc0d400288793e8bdcd7c19a8ac0c2b-Abstract.html). *NeurIPS 2021*.


<details>
4 changes: 3 additions & 1 deletion docs/index.rst
@@ -204,7 +204,7 @@ please cite it as below and 🌟star `PyPOTS repository <https://github.com/Wenj
.. code-block:: bibtex
:linenos:
@article{du2023PyPOTS,
@article{du2023pypots,
title={{PyPOTS: a Python toolbox for data mining on Partially-Observed Time Series}},
author={Wenjie Du},
year={2023},
@@ -215,6 +215,8 @@ please cite it as below and 🌟star `PyPOTS repository <https://github.com/Wenj
doi={10.48550/arXiv.2305.18811},
}
or

..
Wenjie Du. (2023).
15 changes: 5 additions & 10 deletions docs/install.rst
@@ -10,18 +10,13 @@ It is recommended to use **pip** or **conda** for PyPOTS installation as shown b
# via pip
pip install pypots # the first time installation
pip install pypots --upgrade # update pypots to the latest version
.. code-block:: bash
# install from the latest source code with the latest features but may be not officially released yet
pip install https://github.com/WenjieDu/PyPOTS/archive/main.zip
# via conda
conda install -c conda-forge pypots # the first time installation
conda update -c conda-forge pypots # update pypots to the latest version
Alternatively, you can install from the latest source code, which may not be officially released yet:

.. code-block:: bash
pip install https://github.com/WenjieDu/PyPOTS/archive/main.zip
Required Dependencies
"""""""""""""""""""""
@@ -35,7 +30,7 @@ Required Dependencies
* scikit-learn
* torch >=1.10.0
* tsdb >=0.2
* pygrinder >=0.2
* pygrinder >=0.4


Optional Dependencies
@@ -61,8 +56,8 @@ Please use Python v3.8 or above if possible also for the security of your develo

Because of pytorch_sparse, please refer to https://github.com/rusty1s/pytorch_sparse/issues/207#issuecomment-1065549338.

* **Why we need TSDB and PyGrinder >=0.2?**
Since v0.2, all libraries in PyPOTS ecosystem switch their licenses from GPL-v3-only to BSD-3-Clause, which has less constraints for users.
* **Why do we need TSDB >=0.2 and PyGrinder >=0.4?**
Since v0.2, all libraries in the PyPOTS Ecosystem have switched their licenses from GPL-v3-only to BSD-3-Clause, which imposes fewer constraints on users.
Please refer to the discussion in issue `PyPOTS#227 <https://github.com/WenjieDu/PyPOTS/issues/227>`_ for details.

Acceleration
2 changes: 1 addition & 1 deletion docs/milestones.rst
@@ -15,7 +15,7 @@ please cite it as below and 🌟star `PyPOTS repository <https://github.com/Wenj
.. code-block:: bibtex
:linenos:
@article{du2023PyPOTS,
@article{du2023pypots,
title={{PyPOTS: A Python Toolbox for Data Mining on Partially-Observed Time Series}},
author={Wenjie Du},
year={2023},
9 changes: 9 additions & 0 deletions docs/pypots.imputation.rst
@@ -28,6 +28,15 @@ pypots.imputation.timesnet
:show-inheritance:
:inherited-members:

pypots.imputation.autoformer
------------------------------

.. automodule:: pypots.imputation.autoformer
:members:
:undoc-members:
:show-inheritance:
:inherited-members:

pypots.imputation.csdi
------------------------------

19 changes: 15 additions & 4 deletions docs/references.bib
@@ -468,8 +468,19 @@ @inproceedings{wu2023timesnet
}

@inproceedings{liu2022nonstationary,
title={Non-stationary Transformers: Exploring the Stationarity in Time Series Forecasting},
author={Liu, Yong and Wu, Haixu and Wang, Jianmin and Long, Mingsheng},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
title={Non-stationary Transformers: Exploring the Stationarity in Time Series Forecasting},
author={Liu, Yong and Wu, Haixu and Wang, Jianmin and Long, Mingsheng},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}

@inproceedings{wu2021autoformer,
author = {Wu, Haixu and Xu, Jiehui and Wang, Jianmin and Long, Mingsheng},
booktitle = {Advances in Neural Information Processing Systems},
pages = {22419--22430},
publisher = {Curran Associates, Inc.},
title = {Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting},
url = {https://proceedings.neurips.cc/paper_files/paper/2021/file/bcc0d400288793e8bdcd7c19a8ac0c2b-Paper.pdf},
volume = {34},
year = {2021}
}
50 changes: 0 additions & 50 deletions pypots/base.py
@@ -332,56 +332,6 @@ def load(self, path: str) -> None:
            raise e
        logger.info(f"Model loaded successfully from {path}")

    def save_model(
        self,
        saving_path: str,
        overwrite: bool = False,
    ) -> None:
        """Save the model with current parameters to a disk file.

        A ``.pypots`` extension will be appended to the filename if it does not already have one.
        Please note that such an extension is not necessary, but to indicate the saved model is from PyPOTS framework
        so people can distinguish.

        Parameters
        ----------
        saving_path :
            The given path to save the model. The directory will be created if it does not exist.

        overwrite :
            Whether to overwrite the model file if the path already exists.

        Warnings
        --------
        The method save_model is deprecated. Please use `save()` instead.

        """
        logger.warning(
            "🚨DeprecationWarning: The method save_model is deprecated. Please use `save()` instead."
        )
        self.save(saving_path, overwrite)

    def load_model(self, path: str) -> None:
        """Load the saved model from a disk file.

        Parameters
        ----------
        path :
            The local path to a disk file saving the trained model.

        Notes
        -----
        If the training environment and the deploying/test environment use the same type of device (GPU/CPU),
        you can load the model directly with torch.load(model_path).

        Warnings
        --------
        The method load_model is deprecated. Please use `load()` instead.

        """
        logger.warning(
            "🚨DeprecationWarning: The method load_model is deprecated. Please use `load()` instead."
        )
        self.load(path)

    @abstractmethod
    def fit(
        self,
4 changes: 2 additions & 2 deletions pypots/classification/base.py
@@ -312,7 +312,7 @@ def _train_model(

mean_val_loss = np.mean(epoch_val_loss_collector)

# save validating loss logs into the tensorboard file for every epoch if in need
# save validation loss logs into the tensorboard file for every epoch if in need
if self.summary_writer is not None:
val_loss_dict = {
"classification_loss": mean_val_loss,
@@ -322,7 +322,7 @@ def _train_model(
logger.info(
f"Epoch {epoch:03d} - "
f"training loss: {mean_train_loss:.4f}, "
f"validating loss: {mean_val_loss:.4f}"
f"validation loss: {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
2 changes: 1 addition & 1 deletion pypots/clustering/base.py
@@ -321,7 +321,7 @@ def _train_model(
logger.info(
f"Epoch {epoch:03d} - "
f"training loss: {mean_train_loss:.4f}, "
f"validating loss: {mean_val_loss:.4f}"
f"validation loss: {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
4 changes: 2 additions & 2 deletions pypots/clustering/crli/model.py
@@ -276,7 +276,7 @@ def _train_model(
results["generation_loss"].sum().item()
)
mean_val_G_loss = np.mean(epoch_val_loss_G_collector)
# save validating loss logs into the tensorboard file for every epoch if in need
# save validation loss logs into the tensorboard file for every epoch if in need
if self.summary_writer is not None:
val_loss_dict = {
"generation_loss": mean_val_G_loss,
@@ -286,7 +286,7 @@ def _train_model(
f"Epoch {epoch:03d} - "
f"generator training loss: {mean_epoch_train_G_loss:.4f}, "
f"discriminator training loss: {mean_epoch_train_D_loss:.4f}, "
f"generator validating loss: {mean_val_G_loss:.4f}"
f"generator validation loss: {mean_val_G_loss:.4f}"
)
mean_loss = mean_val_G_loss
else:
4 changes: 2 additions & 2 deletions pypots/clustering/vader/model.py
@@ -293,7 +293,7 @@ def _train_model(

mean_val_loss = np.mean(epoch_val_loss_collector)

# save validating loss logs into the tensorboard file for every epoch if in need
# save validation loss logs into the tensorboard file for every epoch if in need
if self.summary_writer is not None:
val_loss_dict = {
"loss": mean_val_loss,
@@ -303,7 +303,7 @@ def _train_model(
logger.info(
f"Epoch {epoch:03d} - "
f"training loss: {mean_train_loss:.4f}, "
f"validating loss: {mean_val_loss:.4f}"
f"validation loss: {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
4 changes: 2 additions & 2 deletions pypots/forecasting/base.py
@@ -306,7 +306,7 @@ def _train_model(

mean_val_loss = np.mean(forecasting_loss_collector)

# save validating loss logs into the tensorboard file for every epoch if in need
# save validation loss logs into the tensorboard file for every epoch if in need
if self.summary_writer is not None:
val_loss_dict = {
"forecasting_loss": mean_val_loss,
@@ -316,7 +316,7 @@ def _train_model(
logger.info(
f"Epoch {epoch:03d} - "
f"training loss: {mean_train_loss:.4f}, "
f"validating loss: {mean_val_loss:.4f}"
f"validation loss: {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
4 changes: 3 additions & 1 deletion pypots/imputation/__init__.py
@@ -11,8 +11,9 @@
from .gpvae import GPVAE
from .mrnn import MRNN
from .saits import SAITS
from .timesnet import TimesNet
from .transformer import Transformer
from .timesnet import TimesNet
from .autoformer import Autoformer
from .usgan import USGAN

# naive imputation methods
@@ -25,6 +26,7 @@
"SAITS",
"Transformer",
"TimesNet",
"Autoformer",
"BRITS",
"MRNN",
"GPVAE",
17 changes: 17 additions & 0 deletions pypots/imputation/autoformer/__init__.py
@@ -0,0 +1,17 @@
"""
The package of the partially-observed time-series imputation model Autoformer.
Refer to the paper "Wu, H., Xu, J., Wang, J., & Long, M. (2021).
Autoformer: Decomposition transformers with auto-correlation for long-term series forecasting. NeurIPS 2021.".
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from .model import Autoformer

__all__ = [
    "Autoformer",
]
24 changes: 24 additions & 0 deletions pypots/imputation/autoformer/data.py
@@ -0,0 +1,24 @@
"""
Dataset class for Autoformer.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForAutoformer(DatasetForSAITS):
    """Autoformer uses the same data strategy as SAITS and needs MIT (masked imputation task) for training."""

    def __init__(
        self,
        data: Union[dict, str],
        return_X_ori: bool,
        return_labels: bool,
        file_type: str = "h5py",
        rate: float = 0.2,
    ):
        super().__init__(data, return_X_ori, return_labels, file_type, rate)