Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default_with strum macro #254

Merged
merged 3 commits into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions strum_macros/src/helpers/inner_variant_props.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use super::metadata::{InnerVariantExt, InnerVariantMeta};
use syn::{Field, LitStr};

pub trait HasInnerVariantProperties {
fn get_variant_inner_properties(&self) -> syn::Result<StrumInnerVariantProperties>;
}

#[derive(Clone, Eq, PartialEq, Debug, Default)]
pub struct StrumInnerVariantProperties {
pub default_with: Option<LitStr>,
// ident: Option<Ident>,
}

impl HasInnerVariantProperties for Field {
fn get_variant_inner_properties(&self) -> syn::Result<StrumInnerVariantProperties> {
let mut output = StrumInnerVariantProperties { default_with: None };

for meta in self.get_named_metadata()? {
match meta {
InnerVariantMeta::DefaultWith { kw: _, value } => {
output.default_with = Some(value);
}
}
}

Ok(output)
}
}
57 changes: 55 additions & 2 deletions strum_macros/src/helpers/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use syn::{
parse2, parse_str,
punctuated::Punctuated,
spanned::Spanned,
Attribute, DeriveInput, Ident, Lit, LitBool, LitStr, Meta, MetaNameValue, Path, Token, Variant,
Visibility,
Attribute, DeriveInput, Field, Ident, Lit, LitBool, LitStr, Meta, MetaNameValue, Path, Token,
Variant, Visibility,
};

use super::case_style::CaseStyle;
Expand All @@ -31,6 +31,7 @@ pub mod kw {
custom_keyword!(to_string);
custom_keyword!(disabled);
custom_keyword!(default);
custom_keyword!(default_with);
custom_keyword!(props);
custom_keyword!(ascii_case_insensitive);
}
Expand Down Expand Up @@ -178,6 +179,10 @@ pub enum VariantMeta {
},
Disabled(kw::disabled),
Default(kw::default),
DefaultWith {
kw: kw::default_with,
value: LitStr,
},
AsciiCaseInsensitive {
kw: kw::ascii_case_insensitive,
value: bool,
Expand Down Expand Up @@ -215,6 +220,11 @@ impl Parse for VariantMeta {
Ok(VariantMeta::Disabled(input.parse()?))
} else if lookahead.peek(kw::default) {
Ok(VariantMeta::Default(input.parse()?))
} else if lookahead.peek(kw::default_with) {
let kw = input.parse()?;
let _: Token![=] = input.parse()?;
let value = input.parse()?;
Ok(VariantMeta::DefaultWith { kw, value })
} else if lookahead.peek(kw::ascii_case_insensitive) {
let kw = input.parse()?;
let value = if input.peek(Token![=]) {
Expand Down Expand Up @@ -266,6 +276,7 @@ impl Spanned for VariantMeta {
VariantMeta::ToString { kw, .. } => kw.span,
VariantMeta::Disabled(kw) => kw.span,
VariantMeta::Default(kw) => kw.span,
VariantMeta::DefaultWith { kw, .. } => kw.span,
VariantMeta::AsciiCaseInsensitive { kw, .. } => kw.span,
VariantMeta::Props { kw, .. } => kw.span,
}
Expand Down Expand Up @@ -307,3 +318,45 @@ fn get_metadata_inner<'a, T: Parse + Spanned>(
Ok(vec)
})
}

#[derive(Debug)]
pub enum InnerVariantMeta {
DefaultWith { kw: kw::default_with, value: LitStr },
}

impl Spanned for InnerVariantMeta {
fn span(&self) -> Span {
match self {
InnerVariantMeta::DefaultWith { kw, .. } => kw.span(),
}
}
}

impl Parse for InnerVariantMeta {
fn parse(input: ParseStream) -> syn::Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(kw::default_with) {
let kw = input.parse()?;
let _: Token![=] = input.parse()?;
let value = input.parse()?;
Ok(InnerVariantMeta::DefaultWith { kw, value })
} else {
Err(lookahead.error())
}
}
}

pub trait InnerVariantExt {
/// Get all the metadata associated with an enum variant inner.
fn get_named_metadata(&self) -> syn::Result<Vec<InnerVariantMeta>>;
}

impl InnerVariantExt for Field {
fn get_named_metadata(&self) -> syn::Result<Vec<InnerVariantMeta>> {
let result = get_metadata_inner("strum", &self.attrs)?;
self.attrs
.iter()
.filter(|attr| attr.path.is_ident("default_with"))
.try_fold(result, |vec, _attr| Ok(vec))
}
}
2 changes: 2 additions & 0 deletions strum_macros/src/helpers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
pub use self::case_style::CaseStyleHelpers;
pub use self::inner_variant_props::HasInnerVariantProperties;
pub use self::type_props::HasTypeProperties;
pub use self::variant_props::HasStrumVariantProperties;

pub mod case_style;
pub mod inner_variant_props;
mod metadata;
pub mod type_props;
pub mod variant_props;
Expand Down
12 changes: 11 additions & 1 deletion strum_macros/src/helpers/variant_props.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub trait HasStrumVariantProperties {
pub struct StrumVariantProperties {
pub disabled: Option<kw::disabled>,
pub default: Option<kw::default>,
pub default_with: Option<LitStr>,
pub ascii_case_insensitive: Option<bool>,
pub message: Option<LitStr>,
pub detailed_message: Option<LitStr>,
Expand Down Expand Up @@ -62,9 +63,10 @@ impl HasStrumVariantProperties for Variant {

let mut message_kw = None;
let mut detailed_message_kw = None;
let mut to_string_kw = None;
let mut disabled_kw = None;
let mut default_kw = None;
let mut default_with_kw = None;
let mut to_string_kw = None;
let mut ascii_case_insensitive_kw = None;
for meta in self.get_metadata()? {
match meta {
Expand Down Expand Up @@ -114,6 +116,14 @@ impl HasStrumVariantProperties for Variant {
default_kw = Some(kw);
output.default = Some(kw);
}
VariantMeta::DefaultWith { kw, value } => {
if let Some(fst_kw) = default_with_kw {
return Err(occurrence_error(fst_kw, kw, "default_with"));
}

default_with_kw = Some(kw);
output.default_with = Some(value);
}
VariantMeta::AsciiCaseInsensitive { kw, value } => {
if let Some(fst_kw) = ascii_case_insensitive_kw {
return Err(occurrence_error(fst_kw, kw, "ascii_case_insensitive"));
Expand Down
19 changes: 11 additions & 8 deletions strum_macros/src/macros/enum_messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,17 @@ pub fn enum_message_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
if !documentation.is_empty() {
let params = params.clone();
// Strip a single leading space from each documentation line.
let documentation: Vec<LitStr> = documentation.iter().map(|lit_str| {
let line = lit_str.value();
if line.starts_with(' ') {
LitStr::new(&line.as_str()[1..], lit_str.span())
} else {
lit_str.clone()
}
}).collect();
let documentation: Vec<LitStr> = documentation
Peternator7 marked this conversation as resolved.
Show resolved Hide resolved
.iter()
.map(|lit_str| {
let line = lit_str.value();
if line.starts_with(' ') {
LitStr::new(&line.as_str()[1..], lit_str.span())
} else {
lit_str.clone()
}
})
.collect();
if documentation.len() == 1 {
let text = &documentation[0];
documentation_arms
Expand Down
42 changes: 30 additions & 12 deletions strum_macros/src/macros/strings/from_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use quote::quote;
use syn::{Data, DeriveInput, Fields};

use crate::helpers::{
non_enum_error, occurrence_error, HasStrumVariantProperties, HasTypeProperties,
non_enum_error, occurrence_error, HasInnerVariantProperties, HasStrumVariantProperties,
HasTypeProperties,
};

pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
Expand Down Expand Up @@ -45,7 +46,6 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
))
}
}

default_kw = Some(kw);
default = quote! {
::core::result::Result::Ok(#name::#ident(s.into()))
Expand All @@ -56,16 +56,34 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
let params = match &variant.fields {
Fields::Unit => quote! {},
Fields::Unnamed(fields) => {
let defaults =
::core::iter::repeat(quote!(Default::default())).take(fields.unnamed.len());
quote! { (#(#defaults),*) }
if let Some(ref value) = variant_properties.default_with {
Copy link
Contributor Author

@ericmcbride ericmcbride Feb 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure if there was a new-type variant, if we wanted to be able to default each unnamed field. It'll be easy enough to add this feature/tests

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's not too difficult, let's try to match the conventions that serde uses according to this doc:

https://serde.rs/variant-attrs.html#deserialize_with

#[serde(deserialize_with = "path")]

Deserialize this variant using a function that is different from its implementation of Deserialize. The given function must be callable as fn<'de, D>(D) -> Result<FIELDS, D::Error> where D: Deserializer<'de>, although it may also be generic over the elements of FIELDS. Variants used with deserialize_with are not required be able to derive Deserialize.

FIELDS is a tuple of all fields of the variant. A unit variant will have () as its FIELDS type.

I'm not too worried about the generic bounds parts, however, if we can return a tuple, that's probably the clearest way to do it. Thoughts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me see what I can do!

let func = proc_macro2::Ident::new(&value.value(), value.span());
let defaults = vec![quote! { #func() }];
quote! { (#(#defaults),*) }
} else {
let defaults =
::core::iter::repeat(quote!(Default::default())).take(fields.unnamed.len());
quote! { (#(#defaults),*) }
}
}
Fields::Named(fields) => {
let fields = fields
.named
.iter()
.map(|field| field.ident.as_ref().unwrap());
quote! { {#(#fields: Default::default()),*} }
let mut defaults = vec![];
for field in &fields.named {
let meta = field.get_variant_inner_properties()?;
let field = field.ident.as_ref().unwrap();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wasn't sure about the unwrap here. I don't like unwraps in general, and I could return an ok_or_else Syn error here. I thought it'd just be weird if an ident on a field wasn't a thing. Thoughts here?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really mind unwrapping/expecting here. Maybe just use expect and tell people to file an issue against strum if they see this message


if let Some(default_with) = meta.default_with {
let func =
proc_macro2::Ident::new(&default_with.value(), default_with.span());
defaults.push(quote! {
#field: #func()
});
} else {
defaults.push(quote! { #field: Default::default() });
}
}

quote! { {#(#defaults),*} }
}
};

Expand All @@ -79,7 +97,7 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
phf_exact_match_arms.push(quote! { #serialization => #name::#ident #params, });

if is_ascii_case_insensitive {
// Store the lowercase and UPPERCASE variants in the phf map to capture
// Store the lowercase and UPPERCASE variants in the phf map to capture
let ser_string = serialization.value();

let lower =
Expand Down Expand Up @@ -113,6 +131,7 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
}
}
};

let standard_match_body = if standard_match_arms.is_empty() {
default
} else {
Expand All @@ -134,7 +153,6 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
}
}
};

let try_from_str = try_from_str(
name,
&impl_generics,
Expand Down
52 changes: 52 additions & 0 deletions strum_tests/tests/from_str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ enum Color {
Purple,
#[strum(serialize = "blk", serialize = "Black", ascii_case_insensitive)]
Black,
Pink {
#[strum(default_with = "test_default")]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a test to make sure it works properly with a full module path

test_no_default: NoDefault,

#[strum(default_with = "string_test")]
string_test: String,
},
#[strum(default_with = "to_white")]
White(String),
}

#[rustversion::since(1.34)]
Expand Down Expand Up @@ -175,3 +184,46 @@ fn case_insensitive_enum_case_insensitive() {
assert_from_str(CaseInsensitiveEnum::CaseInsensitive, "CaseInsensitive");
assert_from_str(CaseInsensitiveEnum::CaseInsensitive, "caseinsensitive");
}

#[derive(Eq, PartialEq, Debug)]
struct NoDefault(String);

fn test_default() -> NoDefault {
NoDefault(String::from("test"))
}

fn to_white() -> String {
String::from("white-test")
}

fn string_test() -> String {
String::from("This is a string test")
}

#[test]
fn color_default_with() {
match Color::from_str("Pink").unwrap() {
Color::Pink {
test_no_default,
string_test,
} => {
assert_eq!(test_no_default, test_default());
assert_eq!(string_test, String::from("This is a string test"));
}
other => {
panic!("Failed to get correct enum value {:?}", other);
}
}
}

#[test]
fn color_default_with_white() {
match Color::from_str("White").unwrap() {
Color::White(inner) => {
assert_eq!(inner, String::from("white-test"));
}
other => {
panic!("Failed t o get correct enum value {:?}", other);
}
}
}