Skip to content

Commit

Permalink
Fix custom functions for optional number types (#315)
Browse files Browse the repository at this point in the history
* Dynamically build the list of number types

Round-tripping these through `quote!().to_token_stream().to_string()`
lets us ensure that the resulting strings always match.

Signed-off-by: Johannes Löthberg <johannes.loethberg@elokon.com>

* Don't destructure number types by reference

Signed-off-by: Johannes Löthberg <johannes.loethberg@elokon.com>

* Add test for custom fns on number types not taking a reference

Signed-off-by: Johannes Löthberg <johannes.loethberg@elokon.com>

---------

Signed-off-by: Johannes Löthberg <johannes.loethberg@elokon.com>
  • Loading branch information
kyrias authored and Keats committed Apr 5, 2024
1 parent dc8cf02 commit 3d202fe
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 51 deletions.
1 change: 1 addition & 0 deletions validator_derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ quote = "1"
proc-macro2 = "1"
proc-macro-error = "1"
darling = { version = "0.20", features = ["suggestions"] }
once_cell = "1.18.0"
8 changes: 5 additions & 3 deletions validator_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ impl ToTokens for ValidateField {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
let field_name = self.ident.clone().unwrap();
let field_name_str = self.ident.clone().unwrap().to_string();
let (actual_field, wrapper_closure) = self.if_let_option_wrapper(&field_name);

let type_name = self.ty.to_token_stream().to_string();
let is_number = NUMBER_TYPES.contains(&type_name);

let (actual_field, wrapper_closure) = self.if_let_option_wrapper(&field_name, is_number);

// Length validation
let length = if let Some(length) = self.length.clone() {
Expand Down Expand Up @@ -167,9 +171,7 @@ impl ToTokens for ValidateField {
// Custom validation
let mut custom = quote!();
// We try to be smart when passing arguments
let type_name = self.ty.to_token_stream().to_string();
let is_cow = type_name.contains("Cow <");
let is_number = NUMBER_TYPES.contains(&type_name.as_str());
let custom_actual_field = if is_cow {
quote!(#actual_field.as_ref())
} else if is_number || type_name.starts_with("&") {
Expand Down
80 changes: 33 additions & 47 deletions validator_derive/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use once_cell::sync::Lazy;

use darling::util::Override;
use darling::{FromField, FromMeta};

Expand All @@ -9,50 +11,32 @@ use syn::{Expr, Field, Ident, Path};
use crate::utils::get_attr;

static OPTIONS_TYPE: [&str; 3] = ["Option|", "std|option|Option|", "core|option|Option|"];
pub(crate) static NUMBER_TYPES: [&str; 42] = [
"usize",
"u8",
"u16",
"u32",
"u64",
"u128",
"isize",
"i8",
"i16",
"i32",
"i64",
"i128",
"f32",
"f64",
"Option<usize>",
"Option<u8>",
"Option<u16>",
"Option<u32>",
"Option<u64>",
"Option<u128>",
"Option<isize>",
"Option<i8>",
"Option<i16>",
"Option<i32>",
"Option<i64>",
"Option<i128>",
"Option<f32>",
"Option<f64>",
"Option<Option<usize>>",
"Option<Option<u8>>",
"Option<Option<u16>>",
"Option<Option<u32>>",
"Option<Option<u64>>",
"Option<Option<u128>>",
"Option<Option<isize>>",
"Option<Option<i8>>",
"Option<Option<i16>>",
"Option<Option<i32>>",
"Option<Option<i64>>",
"Option<Option<i128>>",
"Option<Option<f32>>",
"Option<Option<f64>>",
];

pub(crate) static NUMBER_TYPES: Lazy<Vec<String>> = Lazy::new(|| {
let number_types = [
quote!(usize),
quote!(u8),
quote!(u16),
quote!(u32),
quote!(u64),
quote!(u128),
quote!(isize),
quote!(i8),
quote!(i16),
quote!(i32),
quote!(i64),
quote!(i128),
quote!(f32),
quote!(f64),
];
let mut tys = Vec::with_capacity(number_types.len() * 3);
for ty in number_types {
tys.push(ty.to_string());
tys.push(quote!(Option<#ty>).to_string());
tys.push(quote!(Option<Option<#ty> >).to_string());
}
tys
});

// This struct holds all the validation information on a field
// The "ident" and "ty" fields are populated by `darling`
Expand Down Expand Up @@ -194,21 +178,23 @@ impl ValidateField {
pub fn if_let_option_wrapper(
&self,
field_name: &Ident,
is_number_type: bool,
) -> (proc_macro2::TokenStream, Box<dyn Fn(proc_macro2::TokenStream) -> proc_macro2::TokenStream>)
{
let number_options = self.number_options();
let field_name = field_name.clone();
let actual_field =
if number_options > 0 { quote!(#field_name) } else { quote!(self.#field_name) };
let option_val = quote!(ref #field_name);
let binding_pattern =
if is_number_type { quote!(#field_name) } else { quote!(ref #field_name) };

match number_options {
0 => (actual_field.clone(), Box::new(move |tokens| tokens)),
1 => (
actual_field.clone(),
Box::new(move |tokens| {
quote!(
if let Some(#option_val) = self.#field_name {
if let Some(#binding_pattern) = self.#field_name {
#tokens
}
)
Expand All @@ -218,7 +204,7 @@ impl ValidateField {
actual_field.clone(),
Box::new(move |tokens| {
quote!(
if let Some(Some(#option_val)) = self.#field_name {
if let Some(Some(#binding_pattern)) = self.#field_name {
#tokens
}
)
Expand Down
41 changes: 40 additions & 1 deletion validator_derive_tests/tests/custom.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use validator::{Validate, ValidationError};
use std::collections::HashMap;

use validator::{Validate, ValidationError, ValidationErrors, ValidationErrorsKind};

fn valid_custom_fn(_: &String) -> Result<(), ValidationError> {
Ok(())
Expand Down Expand Up @@ -119,3 +121,40 @@ fn can_nest_custom_validations() {
let t = TestStruct { a: A { val: "invalid value".to_string() } };
assert!(t.validate().is_err());
}

#[test]
fn custom_fn_on_optional_types_work() {
fn number_type_custom_fn(val: i16) -> Result<(), ValidationError> {
if val == 0 {
Ok(())
} else {
Err(ValidationError::new("custom"))
}
}

#[derive(Validate)]
struct TestStruct {
#[validate(custom(function = number_type_custom_fn))]
plain: i16,
#[validate(custom(function = number_type_custom_fn))]
option: Option<i16>,
#[validate(custom(function = number_type_custom_fn))]
option_option: Option<Option<i16>>,
}

let t = TestStruct { plain: 0, option: Some(0), option_option: Some(Some(0)) };
assert!(t.validate().is_ok());

let t = TestStruct { plain: 1, option: Some(1), option_option: Some(Some(1)) };
let mut error = ValidationError::new("custom");
error.add_param("value".into(), &1);
let error_kind = ValidationErrorsKind::Field(vec![{ error }]);
assert_eq!(
t.validate(),
Err(ValidationErrors(HashMap::from_iter([
("plain", error_kind.clone()),
("option", error_kind.clone()),
("option_option", error_kind),
])))
);
}

0 comments on commit 3d202fe

Please sign in to comment.