Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
264 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
use super::{Error, MstlParams, MstlResult}; | ||
|
||
pub struct Mstl; | ||
|
||
impl Mstl { | ||
pub fn fit(series: &[f32], periods: &[usize]) -> Result<MstlResult, Error> { | ||
MstlParams::new().fit(series, periods) | ||
} | ||
|
||
pub fn params() -> MstlParams { | ||
MstlParams::new() | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use crate::{Error, Mstl}; | ||
|
||
fn assert_in_delta(exp: f32, act: f32) { | ||
assert!((exp - act).abs() < 0.001); | ||
} | ||
|
||
fn assert_elements_in_delta(exp: &[f32], act: &[f32]) { | ||
assert_eq!(exp.len(), act.len()); | ||
for i in 0..exp.len() { | ||
assert_in_delta(exp[i], act[i]); | ||
} | ||
} | ||
|
||
fn generate_series() -> Vec<f32> { | ||
return vec![ | ||
5.0, 9.0, 2.0, 9.0, 0.0, 6.0, 3.0, 8.0, 5.0, 8.0, | ||
7.0, 8.0, 8.0, 0.0, 2.0, 5.0, 0.0, 5.0, 6.0, 7.0, | ||
3.0, 6.0, 1.0, 4.0, 4.0, 4.0, 3.0, 7.0, 5.0, 8.0 | ||
]; | ||
} | ||
|
||
#[test] | ||
fn test_works() { | ||
let result = Mstl::fit(&generate_series(), &[6, 10]).unwrap(); | ||
assert_elements_in_delta(&[0.28318232, 0.70529824, -1.980384, 2.1643379, -2.3356874], &result.seasonal()[0][..5]); | ||
assert_elements_in_delta(&[1.4130436, 1.6048906, 0.050958008, -1.8706754, -1.7704514], &result.seasonal()[1][..5]); | ||
assert_elements_in_delta(&[5.139485, 5.223691, 5.3078976, 5.387292, 5.4666862], &result.trend()[..5]); | ||
assert_elements_in_delta(&[-1.835711, 1.4661198, -1.3784716, 3.319045, -1.3605475], &result.remainder()[..5]); | ||
} | ||
|
||
#[test] | ||
fn test_unsorted_periods() { | ||
let result = Mstl::fit(&generate_series(), &[10, 6]).unwrap(); | ||
assert_elements_in_delta(&[1.4130436, 1.6048906, 0.050958008, -1.8706754, -1.7704514], &result.seasonal()[0][..5]); | ||
assert_elements_in_delta(&[0.28318232, 0.70529824, -1.980384, 2.1643379, -2.3356874], &result.seasonal()[1][..5]); | ||
assert_elements_in_delta(&[5.139485, 5.223691, 5.3078976, 5.387292, 5.4666862], &result.trend()[..5]); | ||
assert_elements_in_delta(&[-1.835711, 1.4661198, -1.3784716, 3.319045, -1.3605475], &result.remainder()[..5]); | ||
} | ||
|
||
#[test] | ||
fn test_empty_periods() { | ||
let periods: Vec<usize> = Vec::new(); | ||
let result = Mstl::fit(&generate_series(), &periods); | ||
assert_eq!( | ||
result.unwrap_err(), | ||
Error::Parameter("periods must not be empty".to_string()) | ||
); | ||
} | ||
|
||
#[test] | ||
fn test_period_one() { | ||
let result = Mstl::fit(&generate_series(), &[1]); | ||
assert_eq!( | ||
result.unwrap_err(), | ||
Error::Parameter("periods must be at least 2".to_string()) | ||
); | ||
} | ||
|
||
#[test] | ||
fn test_too_few_periods() { | ||
let result = Mstl::fit(&generate_series(), &[16]); | ||
assert_eq!( | ||
result.unwrap_err(), | ||
Error::Series("series has less than two periods".to_string()) | ||
); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
// Bandara, K., Hyndman, R. J., & Bergmeir, C. (2021). | ||
// MSTL: A Seasonal-Trend Decomposition Algorithm for Time Series with Multiple Seasonal Patterns. | ||
// arXiv:2107.13462 [stat.AP]. https://doi.org/10.48550/arXiv.2107.13462 | ||
|
||
use super::{Error, MstlResult, StlParams}; | ||
|
||
pub struct MstlParams { | ||
iterate: usize, | ||
stl_params: StlParams, | ||
} | ||
|
||
impl MstlParams { | ||
pub fn new() -> Self { | ||
Self { | ||
iterate: 2, | ||
stl_params: StlParams::new(), | ||
} | ||
} | ||
|
||
pub fn iterations(&mut self, iterate: usize) -> &mut Self { | ||
self.iterate = iterate; | ||
self | ||
} | ||
|
||
pub fn stl_params(&mut self, stl_params: StlParams) -> &mut Self { | ||
self.stl_params = stl_params; | ||
self | ||
} | ||
|
||
pub fn fit(&self, series: &[f32], periods: &[usize]) -> Result<MstlResult, Error> { | ||
let x = series; | ||
let seas_ids = periods; | ||
let k = x.len(); | ||
|
||
// return error to be consistent with stl | ||
// and ensure seasonal is always same length as seas_ids | ||
if seas_ids.iter().any(|&v| v < 2) { | ||
return Err(Error::Parameter("periods must be at least 2".to_string())); | ||
} | ||
|
||
// return error to be consistent with stl | ||
// and ensure seasonal is always same length as seas_ids | ||
for np in seas_ids { | ||
if k < np * 2 { | ||
return Err(Error::Series("series has less than two periods".to_string())); | ||
} | ||
} | ||
|
||
// keep track of indices instead of sorting seas_ids | ||
// so order is preserved with seasonality | ||
let mut indices: Vec<usize> = (0..seas_ids.len()).collect(); | ||
indices.sort_by_key(|&i| &seas_ids[i]); | ||
|
||
let mut iterate = self.iterate; | ||
if seas_ids.len() == 1 { | ||
iterate = 1; | ||
} | ||
|
||
let mut seasonality = Vec::with_capacity(seas_ids.len()); | ||
let mut trend = Vec::new(); | ||
|
||
// TODO add lambda param | ||
let mut deseas = x.to_vec(); | ||
|
||
if !seas_ids.is_empty() { | ||
for _ in 0..seas_ids.len() { | ||
seasonality.push(Vec::new()); | ||
} | ||
|
||
for j in 0..iterate { | ||
for (i, &idx) in indices.iter().enumerate() { | ||
let np = seas_ids[idx]; | ||
|
||
if j > 0 { | ||
for (d, s) in deseas.iter_mut().zip(&seasonality[idx]) { | ||
*d += s; | ||
} | ||
} | ||
|
||
// TODO add seasonal_lengths param | ||
let fit = if self.stl_params.ns.is_some() { | ||
self.stl_params.fit(&deseas, np)? | ||
} else { | ||
let seasonal_length = 7 + 4 * (i + 1); | ||
self.stl_params.clone().seasonal_length(seasonal_length).fit(&deseas, np)? | ||
}; | ||
|
||
(seasonality[idx], trend, _, _) = fit.into_parts(); | ||
|
||
for (d, s) in deseas.iter_mut().zip(&seasonality[idx]) { | ||
*d -= s; | ||
} | ||
} | ||
} | ||
} else { | ||
// TODO use Friedman's Super Smoother for trend | ||
return Err(Error::Parameter("periods must not be empty".to_string())); | ||
} | ||
|
||
let mut remainder = Vec::with_capacity(k); | ||
for i in 0..k { | ||
remainder.push(deseas[i] - trend[i]); | ||
} | ||
|
||
Ok(MstlResult { | ||
seasonal: seasonality, | ||
trend, | ||
remainder | ||
}) | ||
} | ||
} | ||
|
||
impl Default for MstlParams { | ||
fn default() -> Self { | ||
Self::new() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
use super::stl_result::strength; | ||
|
||
#[derive(Clone, Debug)] | ||
pub struct MstlResult { | ||
pub(crate) seasonal: Vec<Vec<f32>>, | ||
pub(crate) trend: Vec<f32>, | ||
pub(crate) remainder: Vec<f32>, | ||
} | ||
|
||
impl MstlResult { | ||
pub fn seasonal(&self) -> &[Vec<f32>] { | ||
&self.seasonal[..] | ||
} | ||
|
||
pub fn trend(&self) -> &[f32] { | ||
&self.trend | ||
} | ||
|
||
pub fn remainder(&self) -> &[f32] { | ||
&self.remainder | ||
} | ||
|
||
pub fn seasonal_strength(&self) -> Vec<f32> { | ||
self.seasonal().iter().map(|s| strength(s, self.remainder())).collect() | ||
} | ||
|
||
pub fn trend_strength(&self) -> f32 { | ||
strength(self.trend(), self.remainder()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters