Skip to content

Commit

Permalink
Implement #[zeroize(bound = "T: MyTrait")] (#663)
Browse files Browse the repository at this point in the history
  • Loading branch information
daxpedda authored Nov 11, 2021
1 parent e9a6b98 commit 63c8e60
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 3 deletions.
184 changes: 181 additions & 3 deletions zeroize/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@

use proc_macro2::TokenStream;
use quote::quote;
use syn::{Attribute, Meta, NestedMeta};
use synstructure::{decl_derive, BindStyle, BindingInfo, VariantInfo};
use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
token::Comma,
Attribute, Lit, Meta, NestedMeta, Result, WherePredicate,
};
use synstructure::{decl_derive, AddBounds, BindStyle, BindingInfo, VariantInfo};

decl_derive!(
[Zeroize, attributes(zeroize)] =>
Expand All @@ -18,6 +23,8 @@ decl_derive!(
///
/// On the item level:
/// - `#[zeroize(drop)]`: call `zeroize()` when this item is dropped
/// - `#[zeroize(bound = "T: MyTrait")]`: this replaces any trait bounds
/// inferred by zeroize-derive
///
/// On the field level:
/// - `#[zeroize(skip)]`: skips this field or variant when calling `zeroize()`
Expand All @@ -28,9 +35,17 @@ decl_derive!(
const ZEROIZE_ATTR: &str = "zeroize";

/// Custom derive for `Zeroize`
fn derive_zeroize(s: synstructure::Structure<'_>) -> TokenStream {
fn derive_zeroize(mut s: synstructure::Structure<'_>) -> TokenStream {
let attributes = ZeroizeAttrs::parse(&s);

if let Some(bounds) = attributes.bound {
s.add_bounds(AddBounds::None);

for bound in bounds.0 {
s.add_where_predicate(bound);
}
}

// NOTE: These are split into named functions to simplify testing with
// synstructure's `test_derive!` macro.
if attributes.drop {
Expand All @@ -45,6 +60,17 @@ fn derive_zeroize(s: synstructure::Structure<'_>) -> TokenStream {
struct ZeroizeAttrs {
/// Derive a `Drop` impl which calls zeroize on this type
drop: bool,
/// Custom bounds as defined by the user
bound: Option<Bounds>,
}

/// Parsing helper for custom bounds
struct Bounds(Punctuated<WherePredicate, Comma>);

impl Parse for Bounds {
fn parse(input: ParseStream<'_>) -> Result<Self> {
Ok(Self(Punctuated::parse_terminated(input)?))
}
}

impl ZeroizeAttrs {
Expand Down Expand Up @@ -134,6 +160,49 @@ impl ZeroizeAttrs {
};

self.drop = true;
} else if meta.path().is_ident("bound") {
assert!(self.bound.is_none(), "duplicate #[zeroize] bound flags");

match (variant, binding) {
(_variant, Some(_binding)) => {
// structs don't have a variant prefix, and only structs have bindings outside of a variant
let item_kind = match variant.and_then(|variant| variant.prefix) {
Some(_) => "enum",
None => "struct",
};
panic!(
concat!(
"The #[zeroize(bound)] attribute is not allowed on {} fields. ",
"Use it on the containing {} instead.",
),
item_kind, item_kind,
)
}
(Some(_variant), None) => panic!(concat!(
"The #[zeroize(bound)] attribute is not allowed on enum variants. ",
"Use it on the containing enum instead.",
)),
(None, None) => {
if let Meta::NameValue(meta_name_value) = meta {
if let Lit::Str(lit) = &meta_name_value.lit {
if lit.value().is_empty() {
self.bound = Some(Bounds(Punctuated::new()));
} else {
self.bound = Some(lit.parse().unwrap_or_else(|e| {
panic!("error parsing bounds: {:?} ({})", lit, e)
}));
}

return;
}
}

panic!(concat!(
"The #[zeroize(bound)] attribute expects a name-value syntax with a string literal value.",
"E.g. #[zeroize(bound = \"T: MyTrait\")]."
))
}
}
} else if meta.path().is_ident("skip") {
if variant.is_none() && binding.is_none() {
panic!(concat!(
Expand Down Expand Up @@ -341,6 +410,35 @@ mod tests {
}
}

#[test]
fn zeroize_with_bound() {
test_derive! {
derive_zeroize {
#[zeroize(bound = "T: MyTrait")]
struct Z<T>(T);
}
expands to {
#[allow(non_upper_case_globals)]
#[doc(hidden)]
const _DERIVE_zeroize_Zeroize_FOR_Z: () = {
extern crate zeroize;
impl<T> zeroize::Zeroize for Z<T>
where T: MyTrait
{
fn zeroize(&mut self) {
match self {
Z(ref mut __binding_0,) => {
{ __binding_0.zeroize(); }
}
}
}
}
};
}
no_build // tests the code compiles are in the `zeroize` crate
}
}

#[test]
fn zeroize_on_struct() {
parse_zeroize_test(stringify!(
Expand Down Expand Up @@ -537,6 +635,86 @@ mod tests {
));
}

#[test]
#[should_panic(expected = "duplicate #[zeroize] bound flags")]
fn zeroize_duplicate_bound() {
parse_zeroize_test(stringify!(
#[zeroize(bound = "T: MyTrait")]
#[zeroize(bound = "")]
struct Z<T>(T);
));
}

#[test]
#[should_panic(expected = "duplicate #[zeroize] bound flags")]
fn zeroize_duplicate_bound_list() {
parse_zeroize_test(stringify!(
#[zeroize(bound = "T: MyTrait", bound = "")]
struct Z<T>(T);
));
}

#[test]
#[should_panic(
expected = "The #[zeroize(bound)] attribute is not allowed on struct fields. Use it on the containing struct instead."
)]
fn zeroize_bound_struct() {
parse_zeroize_test(stringify!(
struct Z<T> {
#[zeroize(bound = "T: MyTrait")]
a: T,
}
));
}

#[test]
#[should_panic(
expected = "The #[zeroize(bound)] attribute is not allowed on enum variants. Use it on the containing enum instead."
)]
fn zeroize_bound_enum() {
parse_zeroize_test(stringify!(
enum Z<T> {
#[zeroize(bound = "T: MyTrait")]
A(T),
}
));
}

#[test]
#[should_panic(
expected = "The #[zeroize(bound)] attribute is not allowed on enum fields. Use it on the containing enum instead."
)]
fn zeroize_bound_enum_variant_field() {
parse_zeroize_test(stringify!(
enum Z<T> {
A {
#[zeroize(bound = "T: MyTrait")]
a: T,
},
}
));
}

#[test]
#[should_panic(
expected = "The #[zeroize(bound)] attribute expects a name-value syntax with a string literal value.E.g. #[zeroize(bound = \"T: MyTrait\")]."
)]
fn zeroize_bound_no_value() {
parse_zeroize_test(stringify!(
#[zeroize(bound)]
struct Z<T>(T);
));
}

#[test]
#[should_panic(expected = "error parsing bounds: LitStr { token: \"T\" } (expected `:`)")]
fn zeroize_bound_no_where_predicate() {
parse_zeroize_test(stringify!(
#[zeroize(bound = "T")]
struct Z<T>(T);
));
}

fn parse_zeroize_test(unparsed: &str) -> TokenStream {
derive_zeroize(Structure::new(
&parse_str(unparsed).expect("Failed to parse test input"),
Expand Down
2 changes: 2 additions & 0 deletions zeroize/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
//!
//! On the item level:
//! - `#[zeroize(drop)]`: call `zeroize()` when this item is dropped
//! - `#[zeroize(bound = "T: MyTrait")]`: this replaces any trait bounds
//! inferred by zeroize
//!
//! On the field level:
//! - `#[zeroize(skip)]`: skips this field or variant when calling `zeroize()`
Expand Down
17 changes: 17 additions & 0 deletions zeroize/tests/zeroize_derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,4 +208,21 @@ mod custom_derive_tests {
!boolean
));
}

#[test]
fn derive_bound() {
trait T: Zeroize {}

impl T for u8 {}

#[derive(Zeroize)]
#[zeroize(bound = "X: T")]
struct Z<X>(X);

let mut value = Z(5_u8);

value.zeroize();

assert_eq!(value.0, 0);
}
}

0 comments on commit 63c8e60

Please sign in to comment.