Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend #[derive(TransparentWrapper)] #147

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,11 @@ pub fn derive_maybe_pod(
/// - The struct must contain the `Wrapped` type
///
/// If the struct only contains a single field, the `Wrapped` type will
/// automatically be determined if there is more then one field in the struct,
/// you need to specify the `Wrapped` type using `#[transparent(T)]`
/// automatically be determined. If there is more then one field in the struct,
/// you need to specify the `Wrapped` type using `#[transparent(T)]`. Due to
/// technical limitations, the type in the `#[transparent(Type)]` needs to be
/// the exact same token sequence as the corresponding type in the struct
/// definition.
///
/// ## Example
///
Expand Down Expand Up @@ -252,6 +255,12 @@ fn derive_marker_trait_inner<Trait: Derivable>(
quote!()
};

let where_clause = if Trait::requires_where_clause() {
where_clause
} else {
None
};

Ok(quote! {
#asserts

Expand Down
70 changes: 50 additions & 20 deletions derive/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ pub trait Derivable {
fn trait_impl(_input: &DeriveInput) -> Result<(TokenStream, TokenStream)> {
Ok((quote!(), quote!()))
}
fn requires_where_clause() -> bool {
true
}
}

pub struct Pod;
Expand Down Expand Up @@ -234,19 +237,27 @@ impl Derivable for CheckedBitPattern {

pub struct TransparentWrapper;

struct WrappedType {
wrapped_type: syn::Type,
/// Was the type given with a #[transparent(Type)] attribute.
explicit: bool,
}

impl TransparentWrapper {
fn get_wrapper_type(
fn get_wrapped_type(
attributes: &[Attribute], fields: &Fields,
) -> Option<TokenStream> {
let transparent_param = get_simple_attr(attributes, "transparent");
transparent_param.map(|ident| ident.to_token_stream()).or_else(|| {
) -> Option<WrappedType> {
let transparent_param = get_type_from_simple_attr(attributes, "transparent").map(|wrapped_type| WrappedType {
wrapped_type, explicit: true,
});
transparent_param.or_else(|| {
let mut types = get_field_types(&fields);
let first_type = types.next();
if let Some(_) = types.next() {
// can't guess param type if there is more than one field
return None;
} else {
first_type.map(|ty| ty.to_token_stream())
first_type.cloned().map(|wrapped_type| WrappedType { wrapped_type, explicit: false })
}
})
}
Expand All @@ -256,7 +267,7 @@ impl Derivable for TransparentWrapper {
fn ident(input: &DeriveInput) -> Result<syn::Path> {
let fields = get_struct_fields(input)?;

let ty = match Self::get_wrapper_type(&input.attrs, &fields) {
let WrappedType { wrapped_type: ty, .. } = match Self::get_wrapped_type(&input.attrs, &fields) {
Some(ty) => ty,
None => bail!(
"\
Expand All @@ -271,15 +282,23 @@ impl Derivable for TransparentWrapper {

fn asserts(input: &DeriveInput) -> Result<TokenStream> {
let fields = get_struct_fields(input)?;
let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) {
Some(wrapped_type) => wrapped_type.to_string(),
let (wrapped_type, explicit) = match Self::get_wrapped_type(&input.attrs, &fields) {
Some(WrappedType { wrapped_type, explicit }) => (wrapped_type.to_token_stream().to_string(), explicit),
None => unreachable!(), /* other code will already reject this derive */
};
dbg!(&wrapped_type);
let mut wrapped_fields = fields
.iter()
.filter(|field| field.ty.to_token_stream().to_string() == wrapped_type);
.filter(|field| dbg!(field.ty.to_token_stream().to_string()) == wrapped_type);
if let None = wrapped_fields.next() {
bail!("TransparentWrapper must have one field of the wrapped type");
if explicit {
bail!("TransparentWrapper must have one field of the wrapped type. \
The type given in `#[transparent(Type)]` must match tokenwise\
with the type in the struct definition, not just be the same type");
} else {
bail!("TransparentWrapper must have one field of the wrapped type");
}

};
if let Some(_) = wrapped_fields.next() {
bail!("TransparentWrapper can only have one field of the wrapped type")
Expand All @@ -300,6 +319,10 @@ impl Derivable for TransparentWrapper {
}
}
}

fn requires_where_clause() -> bool {
false
}
}

pub struct Contiguous;
Expand Down Expand Up @@ -534,24 +557,31 @@ fn generate_fields_are_trait(
})
}

fn get_ident_from_stream(tokens: TokenStream) -> Option<Ident> {
match tokens.into_iter().next() {
Some(TokenTree::Group(group)) => get_ident_from_stream(group.stream()),
Some(TokenTree::Ident(ident)) => Some(ident),
_ => None,
fn get_wrapped_type_from_stream(tokens: TokenStream) -> Option<syn::Type> {
let mut tokens = tokens.into_iter().peekable();
match tokens.peek() {
Some(TokenTree::Group(group)) => {
let res = get_wrapped_type_from_stream(group.stream());
tokens.next();
match tokens.next() {
// If there were more tokens, the input was invalid
Some(_) => None,
None => res,
}
},
_ => syn::parse2(tokens.collect()).ok(),
}
}

/// get a simple #[foo(bar)] attribute, returning "bar"
fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<Ident> {
/// get a simple `#[foo(bar)]` attribute, returning `bar`
fn get_type_from_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<syn::Type> {
for attr in attributes {
if let (AttrStyle::Outer, Some(outer_ident), Some(inner_ident)) = (
if let (AttrStyle::Outer, Some(outer_ident)) = (
&attr.style,
attr.path.get_ident(),
get_ident_from_stream(attr.tokens.clone()),
) {
if outer_ident.to_string() == attr_name {
return Some(inner_ident);
return get_wrapped_type_from_stream(attr.tokens.clone());
}
}
}
Expand Down
27 changes: 27 additions & 0 deletions derive/tests/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,33 @@ struct AnyBitPatternTest {
#[repr(transparent)]
struct NewtypeWrapperTest<T>(T);

#[derive(
Debug, Clone, PartialEq, Eq, TransparentWrapper,
)]
#[repr(transparent)]
struct AlgebraicNewtypeWrapperTest<T>(Vec<T>);

#[test]
fn algebraic_newtype_corect() {
let x: Vec<u32> = vec![1, 2, 3, 4];
let y: AlgebraicNewtypeWrapperTest<u32> = AlgebraicNewtypeWrapperTest::wrap(x.clone());
assert_eq!(y.0, x);
}

#[derive(
Debug, Clone, PartialEq, Eq, TransparentWrapper,
)]
#[repr(transparent)]
#[transparent(Vec<T>)]
struct AlgebraicNewtypeWrapperTestWithFields<T, U>(Vec<T>, PhantomData<U>);

#[test]
fn algebraic_newtype_fields_corect() {
let x: Vec<u32> = vec![1, 2, 3, 4];
let y: AlgebraicNewtypeWrapperTestWithFields<u32, f32> = AlgebraicNewtypeWrapperTestWithFields::wrap(x.clone());
assert_eq!(y.0, x);
}

#[test]
fn fails_cast_contiguous() {
let can_cast = CheckedBitPatternEnumWithValues::is_valid_bit_pattern(&5);
Expand Down