diff --git a/CHANGELOG.md b/CHANGELOG.md index ef1bf75..0b86747 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ ## 0.3.0 (unreleased) +- Added support for MSTL - Added `Stl` struct - Added `into_parts` method to `StlResult` - Changed `StlResult` methods to return slices diff --git a/README.md b/README.md index cd7a2f4..72faeb7 100644 --- a/README.md +++ b/README.md @@ -51,9 +51,20 @@ Get robustness weights res.weights(); ``` +## Multiple Seasonality + +Specify multiple periods [unreleased] + +```rust +use stlrs::Mstl; + +let periods = [6, 10]; +let res = Mstl::fit(&series, &periods).unwrap(); +``` + ## Parameters -Set parameters +Set STL parameters ```rust stlrs::params() @@ -71,6 +82,14 @@ stlrs::params() .robust(false) // if robustness iterations are to be used ``` +Set MSTL parameters [unreleased] + +```rust +Mstl::params() + .iterations(2) // number of iterations + .stl_params(Stl::params()) // STL params +``` + ## Strength Get the seasonal strength @@ -92,6 +111,7 @@ This library was ported from the [Fortran implementation](https://www.netlib.org ## References - [STL: A Seasonal-Trend Decomposition Procedure Based on Loess](https://www.scb.se/contentassets/ca21efb41fee47d293bbee5bf7be7fb3/stl-a-seasonal-trend-decomposition-procedure-based-on-loess.pdf) +- [MSTL: A Seasonal-Trend Decomposition Algorithm for Time Series with Multiple Seasonal Patterns](https://arxiv.org/pdf/2107.13462.pdf) - [Measuring strength of trend and seasonality](https://otexts.com/fpp2/seasonal-strength.html) ## History diff --git a/src/lib.rs b/src/lib.rs index 43bbc70..0744167 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,15 +3,21 @@ //! [View the docs](https://github.com/ankane/stl-rust) mod error; -mod params; -mod result; +mod mstl; +mod mstl_params; +mod mstl_result; mod stl; mod stl_impl; +mod stl_params; +mod stl_result; pub use error::Error; -pub use params::StlParams; -pub use result::StlResult; +pub use mstl::Mstl; +pub use mstl_params::MstlParams; +pub use mstl_result::MstlResult; pub use stl::Stl; +pub use stl_params::StlParams; +pub use stl_result::StlResult; pub fn params() -> StlParams { StlParams::new() diff --git a/src/mstl.rs b/src/mstl.rs new file mode 100644 index 0000000..6fa3a6c --- /dev/null +++ b/src/mstl.rs @@ -0,0 +1,83 @@ +use super::{Error, MstlParams, MstlResult}; + +pub struct Mstl; + +impl Mstl { + pub fn fit(series: &[f32], periods: &[usize]) -> Result { + MstlParams::new().fit(series, periods) + } + + pub fn params() -> MstlParams { + MstlParams::new() + } +} + +#[cfg(test)] +mod tests { + use crate::{Error, Mstl}; + + fn assert_in_delta(exp: f32, act: f32) { + assert!((exp - act).abs() < 0.001); + } + + fn assert_elements_in_delta(exp: &[f32], act: &[f32]) { + assert_eq!(exp.len(), act.len()); + for i in 0..exp.len() { + assert_in_delta(exp[i], act[i]); + } + } + + fn generate_series() -> Vec { + return vec![ + 5.0, 9.0, 2.0, 9.0, 0.0, 6.0, 3.0, 8.0, 5.0, 8.0, + 7.0, 8.0, 8.0, 0.0, 2.0, 5.0, 0.0, 5.0, 6.0, 7.0, + 3.0, 6.0, 1.0, 4.0, 4.0, 4.0, 3.0, 7.0, 5.0, 8.0 + ]; + } + + #[test] + fn test_works() { + let result = Mstl::fit(&generate_series(), &[6, 10]).unwrap(); + assert_elements_in_delta(&[0.28318232, 0.70529824, -1.980384, 2.1643379, -2.3356874], &result.seasonal()[0][..5]); + assert_elements_in_delta(&[1.4130436, 1.6048906, 0.050958008, -1.8706754, -1.7704514], &result.seasonal()[1][..5]); + assert_elements_in_delta(&[5.139485, 5.223691, 5.3078976, 5.387292, 5.4666862], &result.trend()[..5]); + assert_elements_in_delta(&[-1.835711, 1.4661198, -1.3784716, 3.319045, -1.3605475], &result.remainder()[..5]); + } + + #[test] + fn test_unsorted_periods() { + let result = Mstl::fit(&generate_series(), &[10, 6]).unwrap(); + assert_elements_in_delta(&[1.4130436, 1.6048906, 0.050958008, -1.8706754, -1.7704514], &result.seasonal()[0][..5]); + assert_elements_in_delta(&[0.28318232, 0.70529824, -1.980384, 2.1643379, -2.3356874], &result.seasonal()[1][..5]); + assert_elements_in_delta(&[5.139485, 5.223691, 5.3078976, 5.387292, 5.4666862], &result.trend()[..5]); + assert_elements_in_delta(&[-1.835711, 1.4661198, -1.3784716, 3.319045, -1.3605475], &result.remainder()[..5]); + } + + #[test] + fn test_empty_periods() { + let periods: Vec = Vec::new(); + let result = Mstl::fit(&generate_series(), &periods); + assert_eq!( + result.unwrap_err(), + Error::Parameter("periods must not be empty".to_string()) + ); + } + + #[test] + fn test_period_one() { + let result = Mstl::fit(&generate_series(), &[1]); + assert_eq!( + result.unwrap_err(), + Error::Parameter("periods must be at least 2".to_string()) + ); + } + + #[test] + fn test_too_few_periods() { + let result = Mstl::fit(&generate_series(), &[16]); + assert_eq!( + result.unwrap_err(), + Error::Series("series has less than two periods".to_string()) + ); + } +} diff --git a/src/mstl_params.rs b/src/mstl_params.rs new file mode 100644 index 0000000..6d57a15 --- /dev/null +++ b/src/mstl_params.rs @@ -0,0 +1,117 @@ +// Bandara, K., Hyndman, R. J., & Bergmeir, C. (2021). +// MSTL: A Seasonal-Trend Decomposition Algorithm for Time Series with Multiple Seasonal Patterns. +// arXiv:2107.13462 [stat.AP]. https://doi.org/10.48550/arXiv.2107.13462 + +use super::{Error, MstlResult, StlParams}; + +pub struct MstlParams { + iterate: usize, + stl_params: StlParams, +} + +impl MstlParams { + pub fn new() -> Self { + Self { + iterate: 2, + stl_params: StlParams::new(), + } + } + + pub fn iterations(&mut self, iterate: usize) -> &mut Self { + self.iterate = iterate; + self + } + + pub fn stl_params(&mut self, stl_params: StlParams) -> &mut Self { + self.stl_params = stl_params; + self + } + + pub fn fit(&self, series: &[f32], periods: &[usize]) -> Result { + let x = series; + let seas_ids = periods; + let k = x.len(); + + // return error to be consistent with stl + // and ensure seasonal is always same length as seas_ids + if seas_ids.iter().any(|&v| v < 2) { + return Err(Error::Parameter("periods must be at least 2".to_string())); + } + + // return error to be consistent with stl + // and ensure seasonal is always same length as seas_ids + for np in seas_ids { + if k < np * 2 { + return Err(Error::Series("series has less than two periods".to_string())); + } + } + + // keep track of indices instead of sorting seas_ids + // so order is preserved with seasonality + let mut indices: Vec = (0..seas_ids.len()).collect(); + indices.sort_by_key(|&i| &seas_ids[i]); + + let mut iterate = self.iterate; + if seas_ids.len() == 1 { + iterate = 1; + } + + let mut seasonality = Vec::with_capacity(seas_ids.len()); + let mut trend = Vec::new(); + + // TODO add lambda param + let mut deseas = x.to_vec(); + + if !seas_ids.is_empty() { + for _ in 0..seas_ids.len() { + seasonality.push(Vec::new()); + } + + for j in 0..iterate { + for (i, &idx) in indices.iter().enumerate() { + let np = seas_ids[idx]; + + if j > 0 { + for (d, s) in deseas.iter_mut().zip(&seasonality[idx]) { + *d += s; + } + } + + // TODO add seasonal_lengths param + let fit = if self.stl_params.ns.is_some() { + self.stl_params.fit(&deseas, np)? + } else { + let seasonal_length = 7 + 4 * (i + 1); + self.stl_params.clone().seasonal_length(seasonal_length).fit(&deseas, np)? + }; + + (seasonality[idx], trend, _, _) = fit.into_parts(); + + for (d, s) in deseas.iter_mut().zip(&seasonality[idx]) { + *d -= s; + } + } + } + } else { + // TODO use Friedman's Super Smoother for trend + return Err(Error::Parameter("periods must not be empty".to_string())); + } + + let mut remainder = Vec::with_capacity(k); + for i in 0..k { + remainder.push(deseas[i] - trend[i]); + } + + Ok(MstlResult { + seasonal: seasonality, + trend, + remainder + }) + } +} + +impl Default for MstlParams { + fn default() -> Self { + Self::new() + } +} diff --git a/src/mstl_result.rs b/src/mstl_result.rs new file mode 100644 index 0000000..58849dc --- /dev/null +++ b/src/mstl_result.rs @@ -0,0 +1,30 @@ +use super::stl_result::strength; + +#[derive(Clone, Debug)] +pub struct MstlResult { + pub(crate) seasonal: Vec>, + pub(crate) trend: Vec, + pub(crate) remainder: Vec, +} + +impl MstlResult { + pub fn seasonal(&self) -> &[Vec] { + &self.seasonal[..] + } + + pub fn trend(&self) -> &[f32] { + &self.trend + } + + pub fn remainder(&self) -> &[f32] { + &self.remainder + } + + pub fn seasonal_strength(&self) -> Vec { + self.seasonal().iter().map(|s| strength(s, self.remainder())).collect() + } + + pub fn trend_strength(&self) -> f32 { + strength(self.trend(), self.remainder()) + } +} diff --git a/src/params.rs b/src/stl_params.rs similarity index 99% rename from src/params.rs rename to src/stl_params.rs index 5563e3c..62ac4f8 100644 --- a/src/params.rs +++ b/src/stl_params.rs @@ -3,7 +3,7 @@ use super::stl_impl::stl; #[derive(Clone, Debug)] pub struct StlParams { - ns: Option, + pub(crate) ns: Option, nt: Option, nl: Option, isdeg: i32, diff --git a/src/result.rs b/src/stl_result.rs similarity index 94% rename from src/result.rs rename to src/stl_result.rs index 055e303..fb7be56 100644 --- a/src/result.rs +++ b/src/stl_result.rs @@ -11,7 +11,7 @@ fn var(series: &[f32]) -> f32 { series.iter().map(|v| (v - mean).powf(2.0)).sum::() / (series.len() as f32 - 1.0) } -fn strength(component: &[f32], remainder: &[f32]) -> f32 { +pub(crate) fn strength(component: &[f32], remainder: &[f32]) -> f32 { let sr = component.iter().zip(remainder).map(|(a, b)| a + b).collect::>(); (1.0 - var(remainder) / var(&sr)).max(0.0) }