Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions spec-trait-impl/crates/spec-trait-bin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,20 @@ impl<T> Foo<T> for ZST {
}
}

#[when(T: Fn(i32) -> i32)]
impl<T> Foo<T> for ZST {
fn foo(&self, _x: T) {
println!("Foo impl ZST where T implements Fn(i32) -> i32");
}
}

#[when(T: for<'a> Fn(&'a i32) -> i32)]
impl<T> Foo<T> for ZST {
fn foo(&self, _x: T) {
println!("Foo impl ZST where T implements for<'a> Fn(&'a i32) -> i32");
}
}

// ZST - Foo2

impl<T, U> Foo2<T, U> for ZST {
Expand All @@ -133,6 +147,13 @@ impl<T, U> Foo2<T, U> for ZST {
}
}

#[when(U = Vec<T>)]
impl<T, U> Foo2<T, U> for ZST {
fn foo(&self, _x: T, _y: U) {
println!("Foo2 for ZST where U is Vec<T>");
}
}

// ZST - Foo3

#[when(T = String)]
Expand Down Expand Up @@ -295,11 +316,14 @@ fn main() {
spec! { zst.foo(1i32); ZST; [i32]; i32: Bar } // -> "Foo impl ZST where T implements Bar"
spec! { zst.foo(1i64); ZST; [i64]; i64: Bar + FooBar } // -> "Foo impl ZST where T implements Bar and FooBar"
spec! { zst.foo(|x: &u8| x); ZST; [fn(&u8) -> &u8] } // -> "Foo impl ZST where T is a function pointer from &u8 to &u8"
spec! { zst.foo(|x: i32| x); ZST; [T]; T: Fn(i32) -> i32 } // -> "Foo impl ZST where T implements Fn(i32) -> i32"
spec! { zst.foo(|x: &i32| *x); ZST; [T]; T: for<'a> Fn(&'a i32) -> i32 } // -> "Foo impl ZST where T implements for<'a> Fn(&'a i32) -> i32"
spec! { zst.foo(1i8); ZST; [i8] } // -> "Default Foo for ZST"
println!();

// ZST - Foo2
spec! { zst.foo(1u8, 2u8); ZST; [u8, u8]; u8 = MyType } // -> "Foo2 for ZST where T is MyType"
spec! { zst.foo(2u8, vec![2u8]); ZST; [u8, Vec<u8>] } // "Foo2 for ZST where U is Vec<T>"
spec! { zst.foo(1i32, 1i32); ZST; [i32, i32] } // -> "Default Foo2 for ZST"
println!();

Expand Down
9 changes: 5 additions & 4 deletions spec-trait-impl/crates/spec-trait-macro/src/spec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ use spec_trait_utils::impls::ImplBody;
use spec_trait_utils::parsing::{get_generics_lifetimes, get_generics_types};
use spec_trait_utils::traits::TraitBody;
use spec_trait_utils::types::{
assign_lifetimes, get_concrete_type, type_assignable, type_assignable_generic_constraints,
assign_lifetimes, get_concrete_type, trait_assignable, type_assignable,
type_assignable_generic_constraints,
};
use std::cmp::Ordering;

Expand Down Expand Up @@ -206,17 +207,17 @@ fn satisfies_condition(

let violates_constraints =
// generic parameter is not present in the function parameters or the trait does not match
generic_var.is_none_or(|v| traits.iter().any(|t| !v.traits.contains(t))) ||
generic_var.is_none_or(|v| traits.iter().any(|t| !v.traits.iter().any(|concrete_trait| trait_assignable(concrete_trait, t)))) ||
// generic parameter is forbidden to be implement one of the traits
constraint.not_traits.iter().any(|t| traits.contains(t)) ||
constraint.not_traits.iter().any(|t| traits.iter().any(|tr| trait_assignable(tr, t))) ||
// generic parameter is already assigned to a type that does not implement one of the traits
constraint.type_.as_ref().is_some_and(|ty| {
let declared_type_var = var.vars
.iter()
.find(|v|
type_assignable(&v.concrete_type, ty, &var.generics, &var.aliases)
);
declared_type_var.is_none_or(|v| traits.iter().any(|tr| !v.traits.contains(tr)))
declared_type_var.is_none_or(|v| traits.iter().any(|tr| !v.traits.iter().any(|concrete_trait| trait_assignable(concrete_trait, tr))))
});

constraint.generics = var.generics.clone();
Expand Down
2 changes: 1 addition & 1 deletion spec-trait-impl/crates/spec-trait-utils/src/conditions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl Display for WhenCondition {

impl Hash for WhenCondition {
fn hash<H: Hasher>(&self, state: &mut H) {
self.to_string().hash(state);
self.to_string().replace(" ", "").hash(state);
}
}

Expand Down
63 changes: 57 additions & 6 deletions spec-trait-impl/crates/spec-trait-utils/src/parsing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use crate::specialize::{add_generic_type, collect_generics_lifetimes, collect_ge
use quote::ToTokens;
use syn::parse::ParseStream;
use syn::{
Error, GenericParam, Generics, Ident, Lifetime, PredicateLifetime, PredicateType, Token, Type,
TypeParam, WherePredicate,
BoundLifetimes, Error, GenericParam, Generics, Lifetime, PredicateLifetime, PredicateType,
Token, Type, TypeParam, TypeParamBound, WherePredicate,
};

pub trait ParseTypeOrLifetimeOrTrait<T> {
Expand All @@ -25,7 +25,7 @@ pub fn parse_type_or_lifetime_or_trait<T: ParseTypeOrLifetimeOrTrait<U>, U>(
if input.peek(Token![=]) {
parse_type::<T, U>(ident, input)
} else if input.peek(Token![:]) {
parse_trait::<T, U>(ident, input)
parse_trait_or_lifetime::<T, U>(ident, input)
} else {
Err(Error::new(
input.span(),
Expand All @@ -34,6 +34,7 @@ pub fn parse_type_or_lifetime_or_trait<T: ParseTypeOrLifetimeOrTrait<U>, U>(
}
}

/// parse type from `T = Type`
fn parse_type<T: ParseTypeOrLifetimeOrTrait<U>, U>(
ident: &str,
input: ParseStream,
Expand All @@ -43,7 +44,8 @@ fn parse_type<T: ParseTypeOrLifetimeOrTrait<U>, U>(
Ok(T::from_type(ident.to_string(), to_string(&type_)))
}

fn parse_trait<T: ParseTypeOrLifetimeOrTrait<U>, U>(
/// parse trait(s) and lifetime from `T: Trait1 + Trait2 + 'a`
fn parse_trait_or_lifetime<T: ParseTypeOrLifetimeOrTrait<U>, U>(
ident: &str,
input: ParseStream,
) -> Result<U, Error> {
Expand All @@ -62,7 +64,7 @@ fn parse_trait<T: ParseTypeOrLifetimeOrTrait<U>, U>(
}
lifetime = Some(input.parse::<Lifetime>()?.to_string());
} else {
traits.push(input.parse::<Ident>()?.to_string());
traits.push(parse_trait(input)?);
}

if input.peek(Token![+]) {
Expand All @@ -80,6 +82,25 @@ fn parse_trait<T: ParseTypeOrLifetimeOrTrait<U>, U>(
Ok(T::from_trait(ident.to_string(), traits, lifetime))
}

/// parse a single trait bound, possibly with `for<'a>` lifetimes
fn parse_trait(input: ParseStream) -> Result<String, Error> {
if input.peek(Token![for]) {
// `for<'a>`
let bound_lifetimes: BoundLifetimes = input.parse()?;
// Trait<'a> or Fn(&'a u8) -> &'a u8
let bound: TypeParamBound = input.parse()?;

Ok(format!(
"{} {}",
to_string(&bound_lifetimes),
to_string(&bound)
))
} else {
let bound: TypeParamBound = input.parse()?;
Ok(to_string(&bound))
}
}

/**
adds the generics in the where clause in the params

Expand Down Expand Up @@ -200,7 +221,7 @@ mod tests {
use super::*;
use quote::quote;
use syn::parse::Parse;
use syn::parse2;
use syn::{Ident, parse2};

#[derive(Debug, PartialEq)]
enum MockTypeOrTrait {
Expand Down Expand Up @@ -260,6 +281,36 @@ mod tests {
);
}

#[test]
fn parse_trait_lifetimes() {
let input = quote! { MyType: for<'a> Clone<'a> };
let result: MockTypeOrTrait = parse2(input).unwrap();

assert_eq!(
result,
MockTypeOrTrait::Trait(
"MyType".to_string(),
vec!["for < 'a > Clone < 'a >".to_string()],
None
)
);
}

#[test]
fn parse_trait_fn() {
let input = quote! { MyType: for<'a> Fn(&'a u8) -> &'a u8 };
let result: MockTypeOrTrait = parse2(input).unwrap();

assert_eq!(
result,
MockTypeOrTrait::Trait(
"MyType".to_string(),
vec!["for < 'a > Fn (& 'a u8) -> & 'a u8".to_string()],
None
)
);
}

#[test]
fn parse_trait_multiple() {
let input = quote! { MyType: Clone + Debug };
Expand Down
25 changes: 22 additions & 3 deletions spec-trait-impl/crates/spec-trait-utils/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::specialize::{
Specializable, TypeReplacer, add_generic_lifetime, add_generic_type, apply_type_condition,
get_assignable_conditions, get_used_generics, handle_generics, remove_generic,
};
use crate::types::get_unique_generic_name;
use crate::types::{get_unique_generic_name, replace_type};
use proc_macro2::TokenStream;
use quote::quote;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -191,8 +191,27 @@ impl TraitBody {

WhenCondition::Type(impl_generic, type_) => {
let mut generics = str_to_generics(&self.generics);

apply_type_condition(self, &mut generics, impl_generics, impl_generic, type_);
let mut actual_type = str_to_type_name(type_);

// replace all generics in type with new generics names coming from replace_generics_names
// e.g. `Vec<T>` -> `Vec<__G_0__>`
handle_generics(&to_string(impl_generics), |generic| {
if let Some(corresponding_generic) =
self.get_corresponding_generic(impl_generics, generic)
{
let type_ = str_to_type_name(&corresponding_generic);
replace_type(&mut actual_type, generic, &type_);
}
});

// apply the type condition
apply_type_condition(
self,
&mut generics,
impl_generics,
impl_generic,
&to_string(&actual_type),
);

self.generics = to_string(&generics);
}
Expand Down
6 changes: 6 additions & 0 deletions spec-trait-impl/crates/spec-trait-utils/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ pub fn type_assignable_generic_constraints(
}
}

/// check if concrete_type can be assigned to declared_or_concrete_type
pub fn type_assignable(
concrete_type: &str,
declared_or_concrete_type: &str,
Expand All @@ -101,6 +102,11 @@ pub fn type_assignable(
.is_some()
}

/// check if concrete_trait can be assigned to declared_or_concrete_trait
pub fn trait_assignable(concrete_trait: &str, declared_or_concrete_trait: &str) -> bool {
concrete_trait.replace(" ", "") == declared_or_concrete_trait.replace(" ", "")
}

/// check if concrete_type can be assigned to declared_type
fn can_assign(
concrete_type: &Type,
Expand Down