diff --git a/spec-trait-impl/crates/spec-trait-bin/src/main.rs b/spec-trait-impl/crates/spec-trait-bin/src/main.rs index 5294b91..7b85874 100644 --- a/spec-trait-impl/crates/spec-trait-bin/src/main.rs +++ b/spec-trait-impl/crates/spec-trait-bin/src/main.rs @@ -118,6 +118,20 @@ impl Foo for ZST { } } +#[when(T: Fn(i32) -> i32)] +impl Foo for ZST { + fn foo(&self, _x: T) { + println!("Foo impl ZST where T implements Fn(i32) -> i32"); + } +} + +#[when(T: for<'a> Fn(&'a i32) -> i32)] +impl Foo for ZST { + fn foo(&self, _x: T) { + println!("Foo impl ZST where T implements for<'a> Fn(&'a i32) -> i32"); + } +} + // ZST - Foo2 impl Foo2 for ZST { @@ -133,6 +147,13 @@ impl Foo2 for ZST { } } +#[when(U = Vec)] +impl Foo2 for ZST { + fn foo(&self, _x: T, _y: U) { + println!("Foo2 for ZST where U is Vec"); + } +} + // ZST - Foo3 #[when(T = String)] @@ -295,11 +316,14 @@ fn main() { spec! { zst.foo(1i32); ZST; [i32]; i32: Bar } // -> "Foo impl ZST where T implements Bar" spec! { zst.foo(1i64); ZST; [i64]; i64: Bar + FooBar } // -> "Foo impl ZST where T implements Bar and FooBar" spec! { zst.foo(|x: &u8| x); ZST; [fn(&u8) -> &u8] } // -> "Foo impl ZST where T is a function pointer from &u8 to &u8" + spec! { zst.foo(|x: i32| x); ZST; [T]; T: Fn(i32) -> i32 } // -> "Foo impl ZST where T implements Fn(i32) -> i32" + spec! { zst.foo(|x: &i32| *x); ZST; [T]; T: for<'a> Fn(&'a i32) -> i32 } // -> "Foo impl ZST where T implements for<'a> Fn(&'a i32) -> i32" spec! { zst.foo(1i8); ZST; [i8] } // -> "Default Foo for ZST" println!(); // ZST - Foo2 spec! { zst.foo(1u8, 2u8); ZST; [u8, u8]; u8 = MyType } // -> "Foo2 for ZST where T is MyType" + spec! { zst.foo(2u8, vec![2u8]); ZST; [u8, Vec] } // "Foo2 for ZST where U is Vec" spec! { zst.foo(1i32, 1i32); ZST; [i32, i32] } // -> "Default Foo2 for ZST" println!(); diff --git a/spec-trait-impl/crates/spec-trait-macro/src/spec.rs b/spec-trait-impl/crates/spec-trait-macro/src/spec.rs index 6067b94..437f0ca 100644 --- a/spec-trait-impl/crates/spec-trait-macro/src/spec.rs +++ b/spec-trait-impl/crates/spec-trait-macro/src/spec.rs @@ -11,7 +11,8 @@ use spec_trait_utils::impls::ImplBody; use spec_trait_utils::parsing::{get_generics_lifetimes, get_generics_types}; use spec_trait_utils::traits::TraitBody; use spec_trait_utils::types::{ - assign_lifetimes, get_concrete_type, type_assignable, type_assignable_generic_constraints, + assign_lifetimes, get_concrete_type, trait_assignable, type_assignable, + type_assignable_generic_constraints, }; use std::cmp::Ordering; @@ -206,9 +207,9 @@ fn satisfies_condition( let violates_constraints = // generic parameter is not present in the function parameters or the trait does not match - generic_var.is_none_or(|v| traits.iter().any(|t| !v.traits.contains(t))) || + generic_var.is_none_or(|v| traits.iter().any(|t| !v.traits.iter().any(|concrete_trait| trait_assignable(concrete_trait, t)))) || // generic parameter is forbidden to be implement one of the traits - constraint.not_traits.iter().any(|t| traits.contains(t)) || + constraint.not_traits.iter().any(|t| traits.iter().any(|tr| trait_assignable(tr, t))) || // generic parameter is already assigned to a type that does not implement one of the traits constraint.type_.as_ref().is_some_and(|ty| { let declared_type_var = var.vars @@ -216,7 +217,7 @@ fn satisfies_condition( .find(|v| type_assignable(&v.concrete_type, ty, &var.generics, &var.aliases) ); - declared_type_var.is_none_or(|v| traits.iter().any(|tr| !v.traits.contains(tr))) + declared_type_var.is_none_or(|v| traits.iter().any(|tr| !v.traits.iter().any(|concrete_trait| trait_assignable(concrete_trait, tr)))) }); constraint.generics = var.generics.clone(); diff --git a/spec-trait-impl/crates/spec-trait-utils/src/conditions.rs b/spec-trait-impl/crates/spec-trait-utils/src/conditions.rs index da65180..9afcfab 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/conditions.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/conditions.rs @@ -52,7 +52,7 @@ impl Display for WhenCondition { impl Hash for WhenCondition { fn hash(&self, state: &mut H) { - self.to_string().hash(state); + self.to_string().replace(" ", "").hash(state); } } diff --git a/spec-trait-impl/crates/spec-trait-utils/src/parsing.rs b/spec-trait-impl/crates/spec-trait-utils/src/parsing.rs index 263e4c1..e0ef675 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/parsing.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/parsing.rs @@ -3,8 +3,8 @@ use crate::specialize::{add_generic_type, collect_generics_lifetimes, collect_ge use quote::ToTokens; use syn::parse::ParseStream; use syn::{ - Error, GenericParam, Generics, Ident, Lifetime, PredicateLifetime, PredicateType, Token, Type, - TypeParam, WherePredicate, + BoundLifetimes, Error, GenericParam, Generics, Lifetime, PredicateLifetime, PredicateType, + Token, Type, TypeParam, TypeParamBound, WherePredicate, }; pub trait ParseTypeOrLifetimeOrTrait { @@ -25,7 +25,7 @@ pub fn parse_type_or_lifetime_or_trait, U>( if input.peek(Token![=]) { parse_type::(ident, input) } else if input.peek(Token![:]) { - parse_trait::(ident, input) + parse_trait_or_lifetime::(ident, input) } else { Err(Error::new( input.span(), @@ -34,6 +34,7 @@ pub fn parse_type_or_lifetime_or_trait, U>( } } +/// parse type from `T = Type` fn parse_type, U>( ident: &str, input: ParseStream, @@ -43,7 +44,8 @@ fn parse_type, U>( Ok(T::from_type(ident.to_string(), to_string(&type_))) } -fn parse_trait, U>( +/// parse trait(s) and lifetime from `T: Trait1 + Trait2 + 'a` +fn parse_trait_or_lifetime, U>( ident: &str, input: ParseStream, ) -> Result { @@ -62,7 +64,7 @@ fn parse_trait, U>( } lifetime = Some(input.parse::()?.to_string()); } else { - traits.push(input.parse::()?.to_string()); + traits.push(parse_trait(input)?); } if input.peek(Token![+]) { @@ -80,6 +82,25 @@ fn parse_trait, U>( Ok(T::from_trait(ident.to_string(), traits, lifetime)) } +/// parse a single trait bound, possibly with `for<'a>` lifetimes +fn parse_trait(input: ParseStream) -> Result { + if input.peek(Token![for]) { + // `for<'a>` + let bound_lifetimes: BoundLifetimes = input.parse()?; + // Trait<'a> or Fn(&'a u8) -> &'a u8 + let bound: TypeParamBound = input.parse()?; + + Ok(format!( + "{} {}", + to_string(&bound_lifetimes), + to_string(&bound) + )) + } else { + let bound: TypeParamBound = input.parse()?; + Ok(to_string(&bound)) + } +} + /** adds the generics in the where clause in the params @@ -200,7 +221,7 @@ mod tests { use super::*; use quote::quote; use syn::parse::Parse; - use syn::parse2; + use syn::{Ident, parse2}; #[derive(Debug, PartialEq)] enum MockTypeOrTrait { @@ -260,6 +281,36 @@ mod tests { ); } + #[test] + fn parse_trait_lifetimes() { + let input = quote! { MyType: for<'a> Clone<'a> }; + let result: MockTypeOrTrait = parse2(input).unwrap(); + + assert_eq!( + result, + MockTypeOrTrait::Trait( + "MyType".to_string(), + vec!["for < 'a > Clone < 'a >".to_string()], + None + ) + ); + } + + #[test] + fn parse_trait_fn() { + let input = quote! { MyType: for<'a> Fn(&'a u8) -> &'a u8 }; + let result: MockTypeOrTrait = parse2(input).unwrap(); + + assert_eq!( + result, + MockTypeOrTrait::Trait( + "MyType".to_string(), + vec!["for < 'a > Fn (& 'a u8) -> & 'a u8".to_string()], + None + ) + ); + } + #[test] fn parse_trait_multiple() { let input = quote! { MyType: Clone + Debug }; diff --git a/spec-trait-impl/crates/spec-trait-utils/src/traits.rs b/spec-trait-impl/crates/spec-trait-utils/src/traits.rs index 07e57b4..b40eb30 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/traits.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/traits.rs @@ -11,7 +11,7 @@ use crate::specialize::{ Specializable, TypeReplacer, add_generic_lifetime, add_generic_type, apply_type_condition, get_assignable_conditions, get_used_generics, handle_generics, remove_generic, }; -use crate::types::get_unique_generic_name; +use crate::types::{get_unique_generic_name, replace_type}; use proc_macro2::TokenStream; use quote::quote; use serde::{Deserialize, Serialize}; @@ -191,8 +191,27 @@ impl TraitBody { WhenCondition::Type(impl_generic, type_) => { let mut generics = str_to_generics(&self.generics); - - apply_type_condition(self, &mut generics, impl_generics, impl_generic, type_); + let mut actual_type = str_to_type_name(type_); + + // replace all generics in type with new generics names coming from replace_generics_names + // e.g. `Vec` -> `Vec<__G_0__>` + handle_generics(&to_string(impl_generics), |generic| { + if let Some(corresponding_generic) = + self.get_corresponding_generic(impl_generics, generic) + { + let type_ = str_to_type_name(&corresponding_generic); + replace_type(&mut actual_type, generic, &type_); + } + }); + + // apply the type condition + apply_type_condition( + self, + &mut generics, + impl_generics, + impl_generic, + &to_string(&actual_type), + ); self.generics = to_string(&generics); } diff --git a/spec-trait-impl/crates/spec-trait-utils/src/types.rs b/spec-trait-impl/crates/spec-trait-utils/src/types.rs index 44c5f02..52e01a5 100644 --- a/spec-trait-impl/crates/spec-trait-utils/src/types.rs +++ b/spec-trait-impl/crates/spec-trait-utils/src/types.rs @@ -91,6 +91,7 @@ pub fn type_assignable_generic_constraints( } } +/// check if concrete_type can be assigned to declared_or_concrete_type pub fn type_assignable( concrete_type: &str, declared_or_concrete_type: &str, @@ -101,6 +102,11 @@ pub fn type_assignable( .is_some() } +/// check if concrete_trait can be assigned to declared_or_concrete_trait +pub fn trait_assignable(concrete_trait: &str, declared_or_concrete_trait: &str) -> bool { + concrete_trait.replace(" ", "") == declared_or_concrete_trait.replace(" ", "") +} + /// check if concrete_type can be assigned to declared_type fn can_assign( concrete_type: &Type,