From 6e63ee92ef1671f38db78c6f27d358e0dfe372c4 Mon Sep 17 00:00:00 2001 From: Andrew Mitchell Date: Mon, 14 Aug 2023 12:15:58 +0100 Subject: [PATCH] Add test for maad settings parsing --- soundscapy/analysis/_AnalysisSettings.py | 9 ++++++++- test/test__AnalysisSettings.py | 12 ++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/soundscapy/analysis/_AnalysisSettings.py b/soundscapy/analysis/_AnalysisSettings.py index 9bcae66..a18643d 100644 --- a/soundscapy/analysis/_AnalysisSettings.py +++ b/soundscapy/analysis/_AnalysisSettings.py @@ -66,7 +66,9 @@ def from_yaml(cls, filename: Union[Path, str], run_stats=True, force_run_all=Fal AnalysisSettings object """ with open(filename, "r") as f: - return cls(yaml.load(f, Loader=yaml.Loader), run_stats, force_run_all, filename) + return cls( + yaml.load(f, Loader=yaml.Loader), run_stats, force_run_all, filename + ) @classmethod def default(cls, run_stats=True, force_run_all=False): @@ -132,6 +134,11 @@ def parse_maad_all_alpha_indices(self, metric: str): channel: tuple or list of str, or str channel(s) to run the metric on """ + assert metric in [ + "all_temporal_alpha_indices", + "all_spectral_alpha_indices", + ], "metric must be all_temporal_alpha_indices or all_spectral_alpha_indices." + lib_settings = self["scikit-maad"].copy() run = lib_settings[metric]["run"] or self.force_run_all channel = lib_settings[metric]["channel"].copy() diff --git a/test/test__AnalysisSettings.py b/test/test__AnalysisSettings.py index 3aeaa6e..6d4f527 100644 --- a/test/test__AnalysisSettings.py +++ b/test/test__AnalysisSettings.py @@ -120,5 +120,17 @@ def test_to_yaml(example_settings): assert saved_data == example_settings +def test_parse_maad_all_alpha_indices(): + settings = AnalysisSettings.default() + maad_settings = settings.parse_maad_all_alpha_indices("all_temporal_alpha_indices") + assert len(maad_settings) == 2 + with pytest.raises(AssertionError) as excinfo: + obj = settings.parse_maad_all_alpha_indices("missing_key") + assert ( + "metric must be all_temporal_alpha_indices or all_spectral_alpha_indices." + in str(excinfo.value) + ) + + if __name__ == "__main__": pytest.main()