Skip to content

Commit

Permalink
Added support for MSTL
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Sep 22, 2023
1 parent e484ed0 commit 33ed566
Show file tree
Hide file tree
Showing 8 changed files with 264 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -1,5 +1,6 @@
## 0.3.0 (unreleased)

- Added support for MSTL
- Added `Stl` struct
- Added `into_parts` method to `StlResult`
- Changed `StlResult` methods to return slices
Expand Down
22 changes: 21 additions & 1 deletion README.md
Expand Up @@ -51,9 +51,20 @@ Get robustness weights
res.weights();
```

## Multiple Seasonality

Specify multiple periods [unreleased]

```rust
use stlrs::Mstl;

let periods = [6, 10];
let res = Mstl::fit(&series, &periods).unwrap();
```

## Parameters

Set parameters
Set STL parameters

```rust
stlrs::params()
Expand All @@ -71,6 +82,14 @@ stlrs::params()
.robust(false) // if robustness iterations are to be used
```

Set MSTL parameters [unreleased]

```rust
Mstl::params()
.iterations(2) // number of iterations
.stl_params(Stl::params()) // STL params
```

## Strength

Get the seasonal strength
Expand All @@ -92,6 +111,7 @@ This library was ported from the [Fortran implementation](https://www.netlib.org
## References

- [STL: A Seasonal-Trend Decomposition Procedure Based on Loess](https://www.scb.se/contentassets/ca21efb41fee47d293bbee5bf7be7fb3/stl-a-seasonal-trend-decomposition-procedure-based-on-loess.pdf)
- [MSTL: A Seasonal-Trend Decomposition Algorithm for Time Series with Multiple Seasonal Patterns](https://arxiv.org/pdf/2107.13462.pdf)
- [Measuring strength of trend and seasonality](https://otexts.com/fpp2/seasonal-strength.html)

## History
Expand Down
14 changes: 10 additions & 4 deletions src/lib.rs
Expand Up @@ -3,15 +3,21 @@
//! [View the docs](https://github.com/ankane/stl-rust)

mod error;
mod params;
mod result;
mod mstl;
mod mstl_params;
mod mstl_result;
mod stl;
mod stl_impl;
mod stl_params;
mod stl_result;

pub use error::Error;
pub use params::StlParams;
pub use result::StlResult;
pub use mstl::Mstl;
pub use mstl_params::MstlParams;
pub use mstl_result::MstlResult;
pub use stl::Stl;
pub use stl_params::StlParams;
pub use stl_result::StlResult;

pub fn params() -> StlParams {
StlParams::new()
Expand Down
83 changes: 83 additions & 0 deletions src/mstl.rs
@@ -0,0 +1,83 @@
use super::{Error, MstlParams, MstlResult};

pub struct Mstl;

impl Mstl {
pub fn fit(series: &[f32], periods: &[usize]) -> Result<MstlResult, Error> {
MstlParams::new().fit(series, periods)
}

pub fn params() -> MstlParams {
MstlParams::new()
}
}

#[cfg(test)]
mod tests {
use crate::{Error, Mstl};

fn assert_in_delta(exp: f32, act: f32) {
assert!((exp - act).abs() < 0.001);
}

fn assert_elements_in_delta(exp: &[f32], act: &[f32]) {
assert_eq!(exp.len(), act.len());
for i in 0..exp.len() {
assert_in_delta(exp[i], act[i]);
}
}

fn generate_series() -> Vec<f32> {
return vec![
5.0, 9.0, 2.0, 9.0, 0.0, 6.0, 3.0, 8.0, 5.0, 8.0,
7.0, 8.0, 8.0, 0.0, 2.0, 5.0, 0.0, 5.0, 6.0, 7.0,
3.0, 6.0, 1.0, 4.0, 4.0, 4.0, 3.0, 7.0, 5.0, 8.0
];
}

#[test]
fn test_works() {
let result = Mstl::fit(&generate_series(), &[6, 10]).unwrap();
assert_elements_in_delta(&[0.28318232, 0.70529824, -1.980384, 2.1643379, -2.3356874], &result.seasonal()[0][..5]);
assert_elements_in_delta(&[1.4130436, 1.6048906, 0.050958008, -1.8706754, -1.7704514], &result.seasonal()[1][..5]);
assert_elements_in_delta(&[5.139485, 5.223691, 5.3078976, 5.387292, 5.4666862], &result.trend()[..5]);
assert_elements_in_delta(&[-1.835711, 1.4661198, -1.3784716, 3.319045, -1.3605475], &result.remainder()[..5]);
}

#[test]
fn test_unsorted_periods() {
let result = Mstl::fit(&generate_series(), &[10, 6]).unwrap();
assert_elements_in_delta(&[1.4130436, 1.6048906, 0.050958008, -1.8706754, -1.7704514], &result.seasonal()[0][..5]);
assert_elements_in_delta(&[0.28318232, 0.70529824, -1.980384, 2.1643379, -2.3356874], &result.seasonal()[1][..5]);
assert_elements_in_delta(&[5.139485, 5.223691, 5.3078976, 5.387292, 5.4666862], &result.trend()[..5]);
assert_elements_in_delta(&[-1.835711, 1.4661198, -1.3784716, 3.319045, -1.3605475], &result.remainder()[..5]);
}

#[test]
fn test_empty_periods() {
let periods: Vec<usize> = Vec::new();
let result = Mstl::fit(&generate_series(), &periods);
assert_eq!(
result.unwrap_err(),
Error::Parameter("periods must not be empty".to_string())
);
}

#[test]
fn test_period_one() {
let result = Mstl::fit(&generate_series(), &[1]);
assert_eq!(
result.unwrap_err(),
Error::Parameter("periods must be at least 2".to_string())
);
}

#[test]
fn test_too_few_periods() {
let result = Mstl::fit(&generate_series(), &[16]);
assert_eq!(
result.unwrap_err(),
Error::Series("series has less than two periods".to_string())
);
}
}
117 changes: 117 additions & 0 deletions src/mstl_params.rs
@@ -0,0 +1,117 @@
// Bandara, K., Hyndman, R. J., & Bergmeir, C. (2021).
// MSTL: A Seasonal-Trend Decomposition Algorithm for Time Series with Multiple Seasonal Patterns.
// arXiv:2107.13462 [stat.AP]. https://doi.org/10.48550/arXiv.2107.13462

use super::{Error, MstlResult, StlParams};

pub struct MstlParams {
iterate: usize,
stl_params: StlParams,
}

impl MstlParams {
pub fn new() -> Self {
Self {
iterate: 2,
stl_params: StlParams::new(),
}
}

pub fn iterations(&mut self, iterate: usize) -> &mut Self {
self.iterate = iterate;
self
}

pub fn stl_params(&mut self, stl_params: StlParams) -> &mut Self {
self.stl_params = stl_params;
self
}

pub fn fit(&self, series: &[f32], periods: &[usize]) -> Result<MstlResult, Error> {
let x = series;
let seas_ids = periods;
let k = x.len();

// return error to be consistent with stl
// and ensure seasonal is always same length as seas_ids
if seas_ids.iter().any(|&v| v < 2) {
return Err(Error::Parameter("periods must be at least 2".to_string()));
}

// return error to be consistent with stl
// and ensure seasonal is always same length as seas_ids
for np in seas_ids {
if k < np * 2 {
return Err(Error::Series("series has less than two periods".to_string()));
}
}

// keep track of indices instead of sorting seas_ids
// so order is preserved with seasonality
let mut indices: Vec<usize> = (0..seas_ids.len()).collect();
indices.sort_by_key(|&i| &seas_ids[i]);

let mut iterate = self.iterate;
if seas_ids.len() == 1 {
iterate = 1;
}

let mut seasonality = Vec::with_capacity(seas_ids.len());
let mut trend = Vec::new();

// TODO add lambda param
let mut deseas = x.to_vec();

if !seas_ids.is_empty() {
for _ in 0..seas_ids.len() {
seasonality.push(Vec::new());
}

for j in 0..iterate {
for (i, &idx) in indices.iter().enumerate() {
let np = seas_ids[idx];

if j > 0 {
for (d, s) in deseas.iter_mut().zip(&seasonality[idx]) {
*d += s;
}
}

// TODO add seasonal_lengths param
let fit = if self.stl_params.ns.is_some() {
self.stl_params.fit(&deseas, np)?
} else {
let seasonal_length = 7 + 4 * (i + 1);
self.stl_params.clone().seasonal_length(seasonal_length).fit(&deseas, np)?
};

(seasonality[idx], trend, _, _) = fit.into_parts();

for (d, s) in deseas.iter_mut().zip(&seasonality[idx]) {
*d -= s;
}
}
}
} else {
// TODO use Friedman's Super Smoother for trend
return Err(Error::Parameter("periods must not be empty".to_string()));
}

let mut remainder = Vec::with_capacity(k);
for i in 0..k {
remainder.push(deseas[i] - trend[i]);
}

Ok(MstlResult {
seasonal: seasonality,
trend,
remainder
})
}
}

impl Default for MstlParams {
fn default() -> Self {
Self::new()
}
}
30 changes: 30 additions & 0 deletions src/mstl_result.rs
@@ -0,0 +1,30 @@
use super::stl_result::strength;

#[derive(Clone, Debug)]
pub struct MstlResult {
pub(crate) seasonal: Vec<Vec<f32>>,
pub(crate) trend: Vec<f32>,
pub(crate) remainder: Vec<f32>,
}

impl MstlResult {
pub fn seasonal(&self) -> &[Vec<f32>] {
&self.seasonal[..]
}

pub fn trend(&self) -> &[f32] {
&self.trend
}

pub fn remainder(&self) -> &[f32] {
&self.remainder
}

pub fn seasonal_strength(&self) -> Vec<f32> {
self.seasonal().iter().map(|s| strength(s, self.remainder())).collect()
}

pub fn trend_strength(&self) -> f32 {
strength(self.trend(), self.remainder())
}
}
2 changes: 1 addition & 1 deletion src/params.rs → src/stl_params.rs
Expand Up @@ -3,7 +3,7 @@ use super::stl_impl::stl;

#[derive(Clone, Debug)]
pub struct StlParams {
ns: Option<usize>,
pub(crate) ns: Option<usize>,
nt: Option<usize>,
nl: Option<usize>,
isdeg: i32,
Expand Down
2 changes: 1 addition & 1 deletion src/result.rs → src/stl_result.rs
Expand Up @@ -11,7 +11,7 @@ fn var(series: &[f32]) -> f32 {
series.iter().map(|v| (v - mean).powf(2.0)).sum::<f32>() / (series.len() as f32 - 1.0)
}

fn strength(component: &[f32], remainder: &[f32]) -> f32 {
pub(crate) fn strength(component: &[f32], remainder: &[f32]) -> f32 {
let sr = component.iter().zip(remainder).map(|(a, b)| a + b).collect::<Vec<f32>>();
(1.0 - var(remainder) / var(&sr)).max(0.0)
}
Expand Down

0 comments on commit 33ed566

Please sign in to comment.