diff --git a/crates/snapbox/src/data/format.rs b/crates/snapbox/src/data/format.rs index 7ab71014..40464676 100644 --- a/crates/snapbox/src/data/format.rs +++ b/crates/snapbox/src/data/format.rs @@ -42,18 +42,66 @@ impl From<&std::path::Path> for DataFormat { .file_name() .and_then(|e| e.to_str()) .unwrap_or_default(); - let (file_stem, mut ext) = file_name.split_once('.').unwrap_or((file_name, "")); - if file_stem.is_empty() { - (_, ext) = file_stem.split_once('.').unwrap_or((file_name, "")); + let mut ext = file_name.strip_prefix('.').unwrap_or(file_name); + while let Some((_, new_ext)) = ext.split_once('.') { + ext = new_ext; + match ext { + #[cfg(feature = "json")] + "json" => { + return DataFormat::Json; + } + #[cfg(feature = "json")] + "jsonl" => { + return DataFormat::JsonLines; + } + #[cfg(feature = "term-svg")] + "term.svg" => { + return Self::TermSvg; + } + _ => {} + } } - match ext { - #[cfg(feature = "json")] - "json" => DataFormat::Json, - #[cfg(feature = "json")] - "jsonl" => DataFormat::JsonLines, - #[cfg(feature = "term-svg")] - "term.svg" => Self::TermSvg, - _ => DataFormat::Text, + DataFormat::Text + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn combos() { + #[cfg(feature = "json")] + let json = DataFormat::Json; + #[cfg(not(feature = "json"))] + let json = DataFormat::Text; + #[cfg(feature = "json")] + let jsonl = DataFormat::JsonLines; + #[cfg(not(feature = "json"))] + let jsonl = DataFormat::Text; + #[cfg(feature = "term-svg")] + let term_svg = DataFormat::TermSvg; + #[cfg(not(feature = "term-svg"))] + let term_svg = DataFormat::Text; + let cases = [ + ("foo", DataFormat::Text), + (".foo", DataFormat::Text), + ("foo.txt", DataFormat::Text), + (".foo.txt", DataFormat::Text), + ("foo.stdout.txt", DataFormat::Text), + ("foo.json", json), + ("foo.stdout.json", json), + (".foo.json", json), + ("foo.jsonl", jsonl), + ("foo.stdout.jsonl", jsonl), + (".foo.jsonl", jsonl), + ("foo.term.svg", term_svg), + ("foo.stdout.term.svg", term_svg), + (".foo.term.svg", term_svg), + ]; + for (input, output) in cases { + let input = std::path::Path::new(input); + assert_eq!(DataFormat::from(input), output, "for `{}`", input.display()); } } }