From 5d91e602cea708ca40ed2dc2648f459ba6404cdd Mon Sep 17 00:00:00 2001 From: Nathaniel McCallum Date: Sat, 26 Feb 2022 10:31:27 -0500 Subject: [PATCH] feat: add new Newtype derive There are a number of der types that could benefit from wrapping in a newtype. For example, `RelativeDistinguishedName` could benefit from an `impl Display` to convert it to a string representation. But because it is just `Vec` we can't really do this. Having a systematic way to derive newtypes in these cases is thus beneficial. Signed-off-by: Nathaniel McCallum --- der/derive/src/lib.rs | 14 +++++ der/derive/src/newtype.rs | 118 ++++++++++++++++++++++++++++++++++++++ der/tests/derive.rs | 35 +++++++++++ 3 files changed, 167 insertions(+) create mode 100644 der/derive/src/newtype.rs diff --git a/der/derive/src/lib.rs b/der/derive/src/lib.rs index d91871039..bf0684c7d 100644 --- a/der/derive/src/lib.rs +++ b/der/derive/src/lib.rs @@ -120,6 +120,7 @@ mod asn1_type; mod attributes; mod choice; mod enumerated; +mod newtype; mod sequence; mod tag; mod value_ord; @@ -129,6 +130,7 @@ use crate::{ attributes::{FieldAttrs, TypeAttrs, ATTR_NAME}, choice::DeriveChoice, enumerated::DeriveEnumerated, + newtype::DeriveNewtype, sequence::DeriveSequence, tag::{Tag, TagMode, TagNumber}, value_ord::DeriveValueOrd, @@ -269,3 +271,15 @@ pub fn derive_value_ord(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); DeriveValueOrd::new(input).to_tokens().into() } + +/// Wraps a der type in a newtype. +/// +/// The newtype receives implementations of `der::FixedTag`, +/// `der::DecodeValue`, `der::EncodeValue`, `Deref`, `DerefMut`, and +/// bi-directional `From`. +#[proc_macro_derive(Newtype)] +#[proc_macro_error] +pub fn derive_newtype(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + DeriveNewtype::new(input).to_tokens().into() +} diff --git a/der/derive/src/newtype.rs b/der/derive/src/newtype.rs new file mode 100644 index 000000000..e29df132c --- /dev/null +++ b/der/derive/src/newtype.rs @@ -0,0 +1,118 @@ +//! Support for deriving newtypes. + +use proc_macro2::TokenStream; +use proc_macro_error::abort; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::{Data, DeriveInput, Fields, FieldsUnnamed, Ident, LifetimeDef, Type}; + +trait PunctuatedExt { + fn only(&self) -> Option<&T>; +} + +impl PunctuatedExt for Punctuated { + fn only(&self) -> Option<&T> { + let mut iter = self.iter(); + + let first = iter.next(); + if let Some(..) = iter.next() { + return None; + } + + first + } +} + +pub(crate) struct DeriveNewtype { + ident: Ident, + ltime: Vec, + ftype: Type, +} + +impl DeriveNewtype { + pub fn new(input: DeriveInput) -> Self { + if let Data::Struct(data) = &input.data { + if let Fields::Unnamed(FieldsUnnamed { unnamed, .. }) = &data.fields { + if let Some(field) = unnamed.only() { + return Self { + ident: input.ident.clone(), + ltime: input.generics.lifetimes().cloned().collect(), + ftype: field.ty.clone(), + }; + } + } + } + + abort!(input, "only derivable on a newtype"); + } + + pub fn to_tokens(&self) -> TokenStream { + let ident = &self.ident; + let ftype = &self.ftype; + let ltime = &self.ltime; + + let (limpl, ltype, param) = match self.ltime.len() { + 0 => (quote! { impl }, quote! { #ident }, quote! { '_ }), + _ => ( + quote! { impl<#(#ltime)*> }, + quote! { #ident<#(#ltime)*> }, + quote! { #(#ltime)* }, + ), + }; + + quote! { + #limpl From<#ftype> for #ltype { + #[inline] + fn from(value: #ftype) -> Self { + Self(value) + } + } + + #limpl From<#ltype> for #ftype { + #[inline] + fn from(value: #ltype) -> Self { + value.0 + } + } + + #limpl ::core::ops::Deref for #ltype { + type Target = #ftype; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + #limpl ::core::ops::DerefMut for #ltype { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + + #limpl ::der::FixedTag for #ltype { + const TAG: ::der::Tag = <#ftype as ::der::FixedTag>::TAG; + } + + #limpl ::der::DecodeValue<#param> for #ltype { + fn decode_value( + decoder: &mut ::der::Decoder<#param>, + header: ::der::Header, + ) -> ::der::Result { + Ok(Self(<#ftype as ::der::DecodeValue>::decode_value(decoder, header)?)) + } + } + + #limpl ::der::EncodeValue for #ltype { + fn encode_value(&self, encoder: &mut ::der::Encoder<'_>) -> ::der::Result<()> { + self.0.encode_value(encoder) + } + + fn value_len(&self) -> ::der::Result<::der::Length> { + self.0.value_len() + } + } + } + } +} diff --git a/der/tests/derive.rs b/der/tests/derive.rs index 2ee70830a..513ed1c26 100644 --- a/der/tests/derive.rs +++ b/der/tests/derive.rs @@ -455,3 +455,38 @@ mod sequence { ); } } + +mod newtype { + use der::{asn1::BitString, Decodable, Encodable}; + use der_derive::Newtype; + + #[derive(Newtype)] + struct Lifetime<'a>(BitString<'a>); + + #[derive(Newtype)] + struct NoLifetime(bool); + + #[test] + fn decode() { + let bs = BitString::from_bytes(&[0, 1, 2, 3]).unwrap(); + let en = bs.to_vec().unwrap(); + let lt = Lifetime::from_der(&en).unwrap(); + assert_eq!(bs, lt.into()); + + let en = true.to_vec().unwrap(); + let lt = NoLifetime::from_der(&en).unwrap(); + assert!(bool::from(lt)); + } + + #[test] + fn encode() { + let bs = BitString::from_bytes(&[0, 1, 2, 3]).unwrap(); + let en = bs.to_vec().unwrap(); + let lt = Lifetime::from(bs).to_vec().unwrap(); + assert_eq!(en, lt); + + let en = true.to_vec().unwrap(); + let lt = NoLifetime::from(true).to_vec().unwrap(); + assert_eq!(en, lt); + } +}