diff --git a/Cargo.lock b/Cargo.lock index 1dd7ae66..e070a87d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -861,6 +861,7 @@ dependencies = [ "digest", "hex-literal", "pem-rfc7468", + "ssh-derive", ] [[package]] diff --git a/ssh-derive/src/attributes.rs b/ssh-derive/src/attributes.rs new file mode 100644 index 00000000..99d42cb1 --- /dev/null +++ b/ssh-derive/src/attributes.rs @@ -0,0 +1,97 @@ +use proc_macro2::TokenStream; +use quote::quote; + +pub(crate) struct ContainerAttributes { + pub(crate) length_prefixed: bool, + pub(crate) discriminant_type: Option, +} + +impl TryFrom<&syn::DeriveInput> for ContainerAttributes { + type Error = syn::Error; + + fn try_from(input: &syn::DeriveInput) -> Result { + let mut length_prefixed = false; + let mut discriminant_type = None; + for attr in &input.attrs { + if attr.path().is_ident("ssh") { + attr.parse_nested_meta(|meta| { + // #[ssh(length_prefixed)] + if meta.path.is_ident("length_prefixed") { + length_prefixed = true; + } else { + return Err(syn::Error::new_spanned(meta.path, "unknown attribute")); + } + Ok(()) + })?; + } else if attr.path().is_ident("repr") { + attr.parse_nested_meta(|meta| { + // #[repr(u8)] or similar + // https://doc.rust-lang.org/reference/type-layout.html#primitive-representations + if meta.path.is_ident("u8") { + discriminant_type = Some(quote! {u8}); + } else if meta.path.is_ident("u16") { + discriminant_type = Some(quote! {u16}); + } else if meta.path.is_ident("u32") { + discriminant_type = Some(quote! {u32}); + } else if meta.path.is_ident("u64") { + discriminant_type = Some(quote! {u64}); + } else if meta.path.is_ident("u128") { + discriminant_type = Some(quote! {u128}); + } else if meta.path.is_ident("usize") { + discriminant_type = Some(quote! {usize}); + } else if meta.path.is_ident("i8") { + discriminant_type = Some(quote! {i8}); + } else if meta.path.is_ident("i16") { + discriminant_type = Some(quote! {i16}); + } else if meta.path.is_ident("i32") { + discriminant_type = Some(quote! {i32}); + } else if meta.path.is_ident("i64") { + discriminant_type = Some(quote! {i64}); + } else if meta.path.is_ident("i128") { + discriminant_type = Some(quote! {i128}); + } else if meta.path.is_ident("isize") { + discriminant_type = Some(quote! {isize}); + } else { + return Err(syn::Error::new_spanned( + meta.path, + "unsupported repr for deriving Encode/Decode, must be a primitive integer type", + )); + } + Ok(()) + })?; + } + } + + Ok(Self { + length_prefixed, + discriminant_type, + }) + } +} + +pub(crate) struct FieldAttributes { + pub(crate) length_prefixed: bool, +} + +impl TryFrom<&syn::Field> for FieldAttributes { + type Error = syn::Error; + + fn try_from(field: &syn::Field) -> Result { + let mut length_prefixed = false; + for attr in &field.attrs { + if attr.path().is_ident("ssh") { + attr.parse_nested_meta(|meta| { + // #[ssh(length_prefixed)] + if meta.path.is_ident("length_prefixed") { + length_prefixed = true; + } else { + return Err(syn::Error::new_spanned(meta.path, "unknown attribute")); + } + Ok(()) + })?; + } + } + + Ok(Self { length_prefixed }) + } +} diff --git a/ssh-derive/src/decode.rs b/ssh-derive/src/decode.rs index c6c10dc9..c72153e2 100644 --- a/ssh-derive/src/decode.rs +++ b/ssh-derive/src/decode.rs @@ -1,92 +1,291 @@ //! Support for deriving the `Decode` trait on structs. -use crate::FieldIr; use proc_macro2::TokenStream; use quote::quote; -use syn::{DeriveInput, Generics, Ident}; +use syn::{spanned::Spanned, DataEnum, DataStruct, DeriveInput}; -/// Derive the `Decode` trait for a struct -pub(crate) struct DeriveDecode { - /// Name of the struct. - ident: Ident, +use crate::attributes::{ContainerAttributes, FieldAttributes}; - /// Generics of the struct. - generics: Generics, - - /// Fields of the struct. - fields: Vec, +pub(crate) fn try_derive_decode(input: DeriveInput) -> syn::Result { + match input.data { + syn::Data::Struct(ref data) => try_derive_decode_for_struct(&input, data), + syn::Data::Enum(ref data) => try_derive_decode_for_enum(&input, data), + syn::Data::Union(_) => abort!(input.ident, "can't derive `Decode` on union types",), + } } -impl DeriveDecode { - /// Parse [`DeriveInput`]. - pub fn new(input: DeriveInput) -> syn::Result { - let data = match input.data { - syn::Data::Struct(data) => data, - _ => abort!( - input.ident, - "can't derive `Decode` on this type: only `struct` types are allowed", - ), - }; +fn try_derive_decode_for_struct( + input: &DeriveInput, + DataStruct { fields, .. }: &DataStruct, +) -> syn::Result { + let container_attributes = ContainerAttributes::try_from(input)?; + let struct_name = &input.ident; + let body = derive_for_fields(fields, quote! { Self })?; + let body = maybe_length_prefixed_result(container_attributes.length_prefixed, &body); + let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl(); - let fields = FieldIr::from_fields(data.fields)?; + Ok(quote! { + #[automatically_derived] + impl #impl_generics ::ssh_encoding::Decode for #struct_name #type_generics #where_clause { + type Error = ::ssh_encoding::Error; - Ok(Self { - ident: input.ident, - generics: input.generics.clone(), - fields, - }) - } + fn decode(reader: &mut impl ::ssh_encoding::Reader) -> ::core::result::Result { + #body + } + } + }) +} - /// Lower the derived output into a [`TokenStream`]. - pub fn to_tokens(&self) -> TokenStream { - let ident = &self.ident; - let (_, generics, where_clause) = self.generics.split_for_impl(); +fn try_derive_decode_for_enum( + input: &DeriveInput, + DataEnum { variants, .. }: &DataEnum, +) -> syn::Result { + let container_attributes = ContainerAttributes::try_from(input)?; + let enum_name = &input.ident; + let discriminant_type = container_attributes + .discriminant_type + .clone() + .ok_or_else(|| { + syn::Error::new( + input.ident.span(), + "enums must have a repr attribute to derive `Decode`", + ) + })?; + let variant_arms = variants + .iter() + .map(|variant| { + let variant_name = &variant.ident; + let discriminant = variant + .discriminant + .as_ref() + .map(|(_, variant)| variant) + .ok_or_else(|| { + syn::Error::new( + variant.span(), + "enum variants must have an explicit discriminant to derive `Decode`", + ) + })?; + let body = derive_for_fields(&variant.fields, quote! { #enum_name::#variant_name })?; + Ok(quote! { #discriminant => { #body } }) + }) + .collect::>>()?; + let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl(); - let mut lowerer = FieldLowerer::new(); - for field in &self.fields { - lowerer.add_field(field); + let body = quote! { + let discriminant = <#discriminant_type as ::ssh_encoding::Decode>::decode(reader)?; + match discriminant { + #(#variant_arms,)* + _ => return Err(::ssh_encoding::Error::InvalidDiscriminant(discriminant.into()).into()), } - let body = lowerer.into_tokens(); + }; + let body = maybe_length_prefixed_result(container_attributes.length_prefixed, &body); - quote! { - #[automatically_derived] - impl #generics ::ssh_encoding::Decode for #ident #generics #where_clause { - type Error = ::ssh_encoding::Error; + Ok(quote! { + #[automatically_derived] + impl #impl_generics ::ssh_encoding::Decode for #enum_name #type_generics #where_clause { + type Error = ::ssh_encoding::Error; - fn decode(reader: &mut impl ::ssh_encoding::Reader) -> Result { - Ok(Self { - #(#body),* - }) - } + fn decode(reader: &mut impl ::ssh_encoding::Reader) -> ::core::result::Result { + #body } } - } + }) } -/// AST lowerer for field decoders. -struct FieldLowerer { - /// Decoder-in-progress. - body: Vec, +/// Generate decoding code for the given fields. +/// +/// This will also handle length-prefixed containers if it is marked as such. +fn derive_for_fields( + fields: &syn::Fields, + output_type_or_variant: TokenStream, +) -> syn::Result { + let mut field_decoders = Vec::with_capacity(fields.len()); + + for field in fields { + let attrs = FieldAttributes::try_from(field)?; + let ty = &field.ty; + field_decoders.push( + if attrs.length_prefixed { + quote! { reader.read_prefixed(|reader| <#ty as ::ssh_encoding::Decode>::decode(reader))? } + } else { + quote! { <#ty as ::ssh_encoding::Decode>::decode(reader)? } + } + ); + } + + let body = match fields { + syn::Fields::Unit => output_type_or_variant, + syn::Fields::Named(named) => { + let named = named + .named + .iter() + .map(|field| field.ident.as_ref().expect("named fields are named")); + quote! { #output_type_or_variant { #(#named: #field_decoders),* } } + } + syn::Fields::Unnamed(_) => { + quote! { #output_type_or_variant ( #(#field_decoders),* ) } + } + }; + + Ok(body) } -impl FieldLowerer { - /// Create a new field decoder lowerer. - fn new() -> Self { - Self { - body: Vec::default(), +fn maybe_length_prefixed_result(length_prefix: bool, body: &TokenStream) -> TokenStream { + if length_prefix { + quote! { + reader.read_prefixed(|reader| { + Ok::<_, ::ssh_encoding::Error>({#body}) + }) } + } else { + quote! { Ok({#body}) } } +} - /// Add a field to the lowerer. - fn add_field(&mut self, field: &FieldIr) { - let ident = field.ident.clone(); - let ty = field.ty.clone(); - let field = quote! { #ident: <#ty as ::ssh_encoding::Decode>::decode(reader)? }; - self.body.push(field); +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + use super::*; + use quote::quote; + + macro_rules! assert_eq_tokens { + ($left:expr, $right:expr) => { + assert_eq!($left.to_string(), $right.to_string()); + }; } - /// Return the resulting tokens. - fn into_tokens(self) -> Vec { - self.body + #[test] + fn test_maybe_length_prefixed() { + let actual = maybe_length_prefixed_result(true, "e! { () }); + let expected = quote! { + reader.read_prefixed(|reader| { + Ok::<_, ::ssh_encoding::Error>({()}) + }) + }; + assert_eq_tokens!(actual, expected); + + let actual = maybe_length_prefixed_result(false, "e! { () }); + let expected = quote! { Ok({()}) }; + assert_eq_tokens!(actual, expected); + } + + #[test] + fn test_derive_for_fields_named() { + let fields: syn::FieldsNamed = syn::parse_quote! ({ + a: u32, + b: String, + #[ssh(length_prefixed)] + c: bool + }); + let actual = derive_for_fields(&syn::Fields::Named(fields), quote! { Self }).unwrap(); + let expected = quote! { + Self { + a: ::decode(reader)?, + b: ::decode(reader)?, + c: reader.read_prefixed(|reader| ::decode(reader))? + } + }; + assert_eq_tokens!(actual, expected); + } + + #[test] + fn test_derive_for_fields_unnamed() { + let fields: syn::FieldsUnnamed = syn::parse_quote!(( + u32, + #[ssh(length_prefixed)] + String, + bool + )); + let actual = derive_for_fields(&syn::Fields::Unnamed(fields), quote! { Self }).unwrap(); + let expected = quote! { + Self ( + ::decode(reader)?, + reader.read_prefixed(|reader| ::decode(reader))?, + ::decode(reader)? + ) + }; + assert_eq_tokens!(actual, expected); + } + + #[test] + fn test_derive_for_fields_unit() { + let actual = derive_for_fields(&syn::Fields::Unit, quote! { Self }).unwrap(); + let expected = quote! { Self }; + assert_eq_tokens!(actual, expected); + } + + #[test] + fn test_derive_for_fields_bad_attribute() { + let fields: syn::FieldsNamed = syn::parse_quote! ({ + #[ssh(not_a_valid_attribute)] + a: u32, + }); + let actual = derive_for_fields(&syn::Fields::Named(fields), quote! { Self }); + assert!(actual.is_err()); + assert!(actual + .unwrap_err() + .to_string() + .contains("unknown attribute")); + } + + #[test] + fn test_try_derive_decode_for_struct() { + let input = syn::parse_quote! { + struct Foo { + #[ssh(length_prefixed)] + a: u32, + b: String, + } + }; + let actual = try_derive_decode(input).unwrap(); + let expected = quote! { + #[automatically_derived] + impl ::ssh_encoding::Decode for Foo { + type Error = ::ssh_encoding::Error; + + fn decode(reader: &mut impl ::ssh_encoding::Reader) -> ::core::result::Result { + Ok({ + Self { + a: reader.read_prefixed(|reader| ::decode(reader))?, + b: ::decode(reader)? + } + }) + } + } + }; + assert_eq!(actual.to_string(), expected.to_string()); + } + + #[test] + fn test_try_derive_decode_for_enum() { + let input = syn::parse_quote! { + #[ssh(length_prefixed)] + #[repr(u8)] + enum Foo { + A = 0, + B = 1, + } + }; + let actual = try_derive_decode(input).unwrap(); + let expected = quote! { + #[automatically_derived] + impl ::ssh_encoding::Decode for Foo { + type Error = ::ssh_encoding::Error; + + fn decode(reader: &mut impl ::ssh_encoding::Reader) -> ::core::result::Result { + reader.read_prefixed(|reader| { + Ok::<_, ::ssh_encoding::Error>({ + let discriminant = ::decode(reader)?; + match discriminant { + 0 => { Foo :: A }, + 1 => { Foo :: B }, + _ => return Err(::ssh_encoding::Error::InvalidDiscriminant(discriminant.into()).into()), + } + }) + }) + } + } + }; + assert_eq!(actual.to_string(), expected.to_string()); } } diff --git a/ssh-derive/src/encode.rs b/ssh-derive/src/encode.rs index 6c618eb7..a7bd23ab 100644 --- a/ssh-derive/src/encode.rs +++ b/ssh-derive/src/encode.rs @@ -1,105 +1,500 @@ //! Support for deriving the `Encode` trait on structs. -use crate::FieldIr; use proc_macro2::TokenStream; -use quote::quote; -use syn::{DeriveInput, Generics, Ident}; +use quote::{quote, ToTokens}; +use syn::{spanned::Spanned, DataEnum, DataStruct, DeriveInput}; -/// Derive the `Encode` trait for a struct -pub(crate) struct DeriveEncode { - /// Name of the struct. - ident: Ident, +use crate::attributes::{ContainerAttributes, FieldAttributes}; - /// Generics of the struct. - generics: Generics, - - /// Fields of the struct. - fields: Vec, +pub(crate) fn try_derive_encode(input: DeriveInput) -> syn::Result { + match input.data { + syn::Data::Struct(ref data) => try_derive_encode_for_struct(&input, data), + syn::Data::Enum(ref data) => try_derive_encode_for_enum(&input, data), + syn::Data::Union(_) => abort!(input.ident, "can't derive `Encode` on union types",), + } } -impl DeriveEncode { - /// Parse [`DeriveInput`]. - pub fn new(input: DeriveInput) -> syn::Result { - let data = match input.data { - syn::Data::Struct(data) => data, - _ => abort!( - input.ident, - "can't derive `Encode` on this type: only `struct` types are allowed", - ), - }; +fn try_derive_encode_for_struct( + input: &DeriveInput, + DataStruct { fields, .. }: &DataStruct, +) -> syn::Result { + let container_attributes = ContainerAttributes::try_from(input)?; + let names = fields_variables(fields, true); + let (field_lengths, field_encoders) = derive_for_fields(fields, names)?; + let (length_prefix_len, length_prefix_encoder) = + maybe_length_prefix(container_attributes.length_prefixed); + let struct_name = &input.ident; + let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl(); - let fields = FieldIr::from_fields(data.fields)?; - - Ok(Self { - ident: input.ident, - generics: input.generics.clone(), - fields, - }) - } - - /// Lower the derived output into a [`TokenStream`]. - pub fn to_tokens(&self) -> TokenStream { - let ident = &self.ident; - let (_, generics, where_clause) = self.generics.split_for_impl(); + Ok(quote! { + #[automatically_derived] + impl #impl_generics ::ssh_encoding::Encode for #struct_name #type_generics #where_clause { + fn encoded_len(&self) -> ::ssh_encoding::Result { + use ::ssh_encoding::CheckedSum; + [ + #length_prefix_len + #(#field_lengths),* + ].checked_sum() + } - let mut lowerer = FieldLowerer::new(); - for field in &self.fields { - lowerer.add_field(field); + fn encode(&self, writer: &mut impl ::ssh_encoding::Writer) -> ::ssh_encoding::Result<()> { + #length_prefix_encoder + #(#field_encoders)* + Ok(()) + } } - let (encoded_len_body, encode_body) = lowerer.into_tokens(); + }) +} - quote! { - #[automatically_derived] - impl #generics ::ssh_encoding::Encode for #ident #generics #where_clause { - fn encoded_len(&self) -> ::ssh_encoding::Result { - use ::ssh_encoding::CheckedSum; +fn try_derive_encode_for_enum( + input: &DeriveInput, + DataEnum { variants, .. }: &DataEnum, +) -> syn::Result { + let enum_name = &input.ident; + let container_attributes = ContainerAttributes::try_from(input)?; + let (length_arms, encode_arms) = + derive_for_variants(&container_attributes, variants.iter(), enum_name)?; + let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl(); - [ - #(#encoded_len_body),* - ] - .checked_sum() + Ok(quote! { + #[automatically_derived] + impl #impl_generics ::ssh_encoding::Encode for #enum_name #type_generics #where_clause { + fn encoded_len(&self) -> ::ssh_encoding::Result { + use ::ssh_encoding::CheckedSum; + match self { + #(#length_arms)* } - - fn encode(&self, writer: &mut impl ::ssh_encoding::Writer) -> ::ssh_encoding::Result<()> { - #(#encode_body)* - Ok(()) + } + fn encode(&self, writer: &mut impl ::ssh_encoding::Writer) -> ::ssh_encoding::Result<()> { + match self { + #(#encode_arms)* } + Ok(()) } } + }) +} + +/// Generate encoding code for the given fields, bound to the given names. +/// +/// This will also handle length-prefixing the container if it is marked as such. +fn derive_for_fields( + fields: &syn::Fields, + names: Vec, +) -> syn::Result<(Vec, Vec)> { + let mut lengths = Vec::new(); + let mut encoders = Vec::new(); + for (field, name) in fields.iter().zip(names) { + let attrs = FieldAttributes::try_from(field)?; + if attrs.length_prefixed { + lengths.push(quote! { ::ssh_encoding::Encode::encoded_len_prefixed(#name)? }); + encoders.push(quote! { ::ssh_encoding::Encode::encode_prefixed(#name, writer)?; }); + } else { + lengths.push(quote! { ::ssh_encoding::Encode::encoded_len(#name)? }); + encoders.push(quote! { ::ssh_encoding::Encode::encode(#name, writer)?; }); + } + } + + Ok((lengths, encoders)) +} + +fn derive_for_variants<'a>( + container_attributes: &ContainerAttributes, + variants: impl Iterator, + enum_name: &'a syn::Ident, +) -> syn::Result<(Vec, Vec)> { + let mut length_arms = Vec::new(); + let mut encode_arms = Vec::new(); + for variant in variants { + let variant_name = &variant.ident; + let names = fields_variables(&variant.fields, false); + let match_variant = match &variant.fields { + syn::Fields::Unit => quote! {}, + syn::Fields::Named(_) => quote! { {#(#names),*} }, + syn::Fields::Unnamed(_) => quote! { (#(#names),*) }, + }; + + let discriminant_type = + container_attributes + .discriminant_type + .clone() + .ok_or_else(|| { + syn::Error::new( + variant.span(), + "enum must have a repr attribute to derive `Encode`", + ) + })?; + let discriminant = variant + .discriminant + .as_ref() + .map(|(_, variant)| variant) + .ok_or_else(|| { + syn::Error::new( + variant.span(), + "enum variants must have an explicit discriminant to derive `Encode`", + ) + })?; + let (field_lengths, field_encoders) = derive_for_fields(&variant.fields, names)?; + let (length_prefix_len, length_prefix_encoder) = + maybe_length_prefix(container_attributes.length_prefixed); + length_arms.push(quote! { + #enum_name::#variant_name #match_variant => { + [ + #length_prefix_len + ::core::mem::size_of::<#discriminant_type>(), + #(#field_lengths),* + ].checked_sum() + } + }); + encode_arms.push(quote! { + #enum_name::#variant_name #match_variant => { + #length_prefix_encoder + ::ssh_encoding::Encode::encode(&(#discriminant as #discriminant_type), writer)?; + #(#field_encoders)* + } + }); + } + + Ok((length_arms, encode_arms)) +} + +/// Generate length prefixing code or empty token streams if not needed. +fn maybe_length_prefix(length_prefix: bool) -> (TokenStream, TokenStream) { + if length_prefix { + ( + quote! { ::ssh_encoding::Encode::encoded_len(&0usize)?, }, + quote! {{ + let len = ::ssh_encoding::Encode::encoded_len(self)? - ::ssh_encoding::Encode::encoded_len(&0usize)?; + ::ssh_encoding::Encode::encode(&len, writer)?; + }}, + ) + } else { + (quote! {}, quote! {}) } } -/// AST lowerer for field decoders. -struct FieldLowerer { - /// Encoded length calculation in progress. - encoded_len_body: Vec, +/// Generate a list of field variables for a struct or enum variant. +/// +/// If `use_self` is true, the fields are accessed using `self.` (for struct fields). +/// Otherwise, the fields are accessed directly (for enum variants and match expressions). +fn fields_variables(fields: &syn::Fields, use_self: bool) -> Vec { + match &fields { + syn::Fields::Unit => Vec::new(), + syn::Fields::Named(field_names) => field_names + .named + .iter() + .map(|field| { + ( + field + .ident + .as_ref() + .expect("named fields are named") + .to_token_stream(), + matches!(field.ty, syn::Type::Reference(_)), + ) + }) + .map(|(name, is_ref)| match (use_self, is_ref) { + (true, true) => quote! { self.#name }, // Avoid double referencing. + (true, false) => quote! { &self.#name }, // Reference the field. + (false, _) => name, // Not via self, so variable should already be a reference. + }) + .collect(), - /// Encoder-in-progress. - encode_body: Vec, + syn::Fields::Unnamed(field_types) => field_types + .unnamed + .iter() + .enumerate() + .map(|(i, field)| { + if use_self { + let index = syn::Index::from(i); + if let syn::Type::Reference(_) = field.ty { + quote! { self.#index } + } else { + quote! { &self.#index } + } + } else { + syn::Ident::new(&format!("field_{i}"), fields.span()).to_token_stream() + } + }) + .collect(), + } } -impl FieldLowerer { - /// Create a new field decoder lowerer. - fn new() -> Self { - Self { - encoded_len_body: Vec::default(), - encode_body: Vec::default(), - } +#[cfg(test)] +mod tests { + #![allow(clippy::unwrap_used)] + use super::*; + use proc_macro2::Span; + use quote::quote; + + macro_rules! assert_eq_tokens { + ($left:expr, $right:expr) => { + assert_eq!($left.to_string(), $right.to_string()); + }; } - /// Add a field to the lowerer. - fn add_field(&mut self, field: &FieldIr) { - let ident = field.ident.clone(); + #[test] + fn test_field_variables_unit() { + let fields = syn::Fields::Unit; + assert!(fields_variables(&fields, true).is_empty()); + assert!(fields_variables(&fields, false).is_empty()); + } - let field_length = quote! { ::ssh_encoding::Encode::encoded_len(&self.#ident)? }; - self.encoded_len_body.push(field_length); + #[test] + fn test_field_variables_named() { + let fields = syn::Fields::Named(syn::parse_quote! {{ a: u8, b: &u8 }}); + let names = fields_variables(&fields, true); + assert_eq_tokens!(names[0], quote! { &self.a }); + assert_eq_tokens!(names[1], quote! { self.b }); + let names = fields_variables(&fields, false); + assert_eq_tokens!(names[0], quote! { a }); + assert_eq_tokens!(names[1], quote! { b }); + } - let field_encoder = quote! { ::ssh_encoding::Encode::encode(&self.#ident, writer)?; }; - self.encode_body.push(field_encoder); + #[test] + fn test_field_variables_unnamed() { + let fields = syn::Fields::Unnamed(syn::parse_quote! { (u8, &u8) }); + let names = fields_variables(&fields, true); + assert_eq_tokens!(names[0], quote! { &self.0 }); + assert_eq_tokens!(names[1], quote! { self.1 }); + let names = fields_variables(&fields, false); + assert_eq_tokens!(names[0], quote! { field_0 }); + assert_eq_tokens!(names[1], quote! { field_1 }); } - /// Return the resulting tokens. - fn into_tokens(self) -> (Vec, Vec) { - (self.encoded_len_body, self.encode_body) + #[test] + fn test_maybe_length_prefix() { + let (len, encoder) = maybe_length_prefix(true); + assert_eq_tokens!( + len, + quote! { ::ssh_encoding::Encode::encoded_len(&0usize)?, } + ); + assert_eq_tokens!( + encoder, + quote! {{ + let len = ::ssh_encoding::Encode::encoded_len(self)? - ::ssh_encoding::Encode::encoded_len(&0usize)?; + ::ssh_encoding::Encode::encode(&len, writer)?; + }} + ); + + let (len, encoder) = maybe_length_prefix(false); + assert_eq_tokens!(len, quote! {}); + assert_eq_tokens!(encoder, quote! {}); + } + + #[test] + fn test_derive_for_fields() { + let fields = + syn::Fields::Named(syn::parse_quote! {{ a: u8, #[ssh(length_prefixed)] b: &u8 }}); + let names = fields_variables(&fields, true); + let (lengths, encoders) = derive_for_fields(&fields, names).unwrap(); + assert_eq_tokens!( + lengths[0], + quote! { ::ssh_encoding::Encode::encoded_len(&self.a)? } + ); + assert_eq_tokens!( + encoders[0], + quote! { ::ssh_encoding::Encode::encode(&self.a, writer)?; } + ); + assert_eq_tokens!( + lengths[1], + quote! { ::ssh_encoding::Encode::encoded_len_prefixed(self.b)? } + ); + assert_eq_tokens!( + encoders[1], + quote! { ::ssh_encoding::Encode::encode_prefixed(self.b, writer)?; } + ); + } + + #[test] + fn test_derive_for_fields_bad_attribute() { + let fields = syn::Fields::Named(syn::parse_quote! {{ #[ssh(not_an_attribute)] a: u8 }}); + let names = fields_variables(&fields, true); + let err = derive_for_fields(&fields, names).unwrap_err(); + assert_eq!(err.to_string(), "unknown attribute"); + } + + #[test] + fn test_derive_for_variants_no_repr() { + let variant: syn::Variant = syn::parse_quote! { Bar }; + let enum_name = syn::Ident::new("Foo", variant.span()); + let container_attributes = ContainerAttributes { + discriminant_type: None, + length_prefixed: false, + }; + let err = derive_for_variants(&container_attributes, std::iter::once(&variant), &enum_name) + .unwrap_err(); + assert_eq!( + err.to_string(), + "enum must have a repr attribute to derive `Encode`" + ); + } + + #[test] + fn test_derive_for_variants_no_explicit_discriminant() { + let variant: syn::Variant = syn::parse_quote! { Bar }; // Variant without ` = 123` discriminant. + let enum_name = syn::Ident::new("Foo", variant.span()); + let container_attributes = ContainerAttributes { + discriminant_type: Some(quote! { u8 }), + length_prefixed: false, + }; + let err = derive_for_variants(&container_attributes, std::iter::once(&variant), &enum_name) + .unwrap_err(); + assert_eq!( + err.to_string(), + "enum variants must have an explicit discriminant to derive `Encode`" + ); + } + + #[test] + fn test_derive_for_variants() { + let variants: [syn::Variant; 2] = [ + syn::parse_quote! { Foo(u8, u8) = 1 }, + syn::parse_quote! { Bar { a: u8, #[ssh(length_prefixed)] b: &u8 } = 2 }, + ]; + let enum_name = syn::Ident::new("Enum", Span::call_site()); + let container_attributes = ContainerAttributes { + discriminant_type: Some(quote! { u8 }), + length_prefixed: false, + }; + let (length_arms, encode_arms) = + derive_for_variants(&container_attributes, variants.iter(), &enum_name).unwrap(); + assert_eq_tokens!( + length_arms[0], + quote! { + Enum::Foo (field_0, field_1) => { + [ + ::core::mem::size_of::(), + ::ssh_encoding::Encode::encoded_len(field_0)?, + ::ssh_encoding::Encode::encoded_len(field_1)? + ].checked_sum() + } + } + ); + assert_eq_tokens!( + encode_arms[0], + quote! { + Enum::Foo(field_0, field_1) => { + ::ssh_encoding::Encode::encode(&(1 as u8), writer)?; + ::ssh_encoding::Encode::encode(field_0, writer)?; + ::ssh_encoding::Encode::encode(field_1, writer)?; + } + } + ); + assert_eq_tokens!( + length_arms[1], + quote! { + Enum::Bar {a, b} => { + [ + ::core::mem::size_of::(), + ::ssh_encoding::Encode::encoded_len(a)?, + ::ssh_encoding::Encode::encoded_len_prefixed(b)? + ].checked_sum() + } + } + ); + assert_eq_tokens!( + encode_arms[1], + quote! { + Enum::Bar {a, b} => { + ::ssh_encoding::Encode::encode(&(2 as u8), writer)?; + ::ssh_encoding::Encode::encode(a, writer)?; + ::ssh_encoding::Encode::encode_prefixed(b, writer)?; + } + } + ); + } + + #[test] + fn test_derive_for_struct() { + let input: DeriveInput = syn::parse_quote! { + #[ssh(length_prefixed)] + struct Foo { + a: u8, + #[ssh(length_prefixed)] + b: &u8, + } + }; + let output = try_derive_encode(input).unwrap(); + assert_eq_tokens!( + output, + quote! { + #[automatically_derived] + impl ::ssh_encoding::Encode for Foo { + fn encoded_len(&self) -> ::ssh_encoding::Result { + use ::ssh_encoding::CheckedSum; + [ + ::ssh_encoding::Encode::encoded_len(&0usize)?, + ::ssh_encoding::Encode::encoded_len(&self.a)?, + ::ssh_encoding::Encode::encoded_len_prefixed(self.b)? + ].checked_sum() + } + + fn encode(&self, writer: &mut impl ::ssh_encoding::Writer) -> ::ssh_encoding::Result<()> { + { + let len = ::ssh_encoding::Encode::encoded_len(self)? - ::ssh_encoding::Encode::encoded_len(&0usize)?; + ::ssh_encoding::Encode::encode(&len, writer)?; + } + ::ssh_encoding::Encode::encode(&self.a, writer)?; + ::ssh_encoding::Encode::encode_prefixed(self.b, writer)?; + Ok(()) + } + } + } + ); + } + + #[test] + fn test_derive_for_enum() { + let input: DeriveInput = syn::parse_quote! { + #[repr(u8)] + enum Enum { + Foo(u8, u8) = 1, + Bar { a: u8, #[ssh(length_prefixed)] b: &u8 } = 2, + } + }; + let output = try_derive_encode(input).unwrap(); + assert_eq_tokens!( + output, + quote! { + #[automatically_derived] + impl ::ssh_encoding::Encode for Enum { + fn encoded_len(&self) -> ::ssh_encoding::Result { + use ::ssh_encoding::CheckedSum; + match self { + Enum::Foo (field_0, field_1) => { + [ + ::core::mem::size_of::(), + ::ssh_encoding::Encode::encoded_len(field_0)?, + ::ssh_encoding::Encode::encoded_len(field_1)? + ].checked_sum() + } + Enum::Bar {a, b} => { + [ + ::core::mem::size_of::(), + ::ssh_encoding::Encode::encoded_len(a)?, + ::ssh_encoding::Encode::encoded_len_prefixed(b)? + ].checked_sum() + } + } + } + fn encode(&self, writer: &mut impl ::ssh_encoding::Writer) -> ::ssh_encoding::Result<()> { + match self { + Enum::Foo(field_0, field_1) => { + ::ssh_encoding::Encode::encode(&(1 as u8), writer)?; + ::ssh_encoding::Encode::encode(field_0, writer)?; + ::ssh_encoding::Encode::encode(field_1, writer)?; + } + Enum::Bar {a, b} => { + ::ssh_encoding::Encode::encode(&(2 as u8), writer)?; + ::ssh_encoding::Encode::encode(a, writer)?; + ::ssh_encoding::Encode::encode_prefixed(b, writer)?; + } + } + Ok(()) + } + } + } + ); } } diff --git a/ssh-derive/src/field_ir.rs b/ssh-derive/src/field_ir.rs deleted file mode 100644 index 19e2fd1d..00000000 --- a/ssh-derive/src/field_ir.rs +++ /dev/null @@ -1,31 +0,0 @@ -use syn::{Field, Fields, Ident, Type}; - -/// Intermediate representation for a struct field. -pub(crate) struct FieldIr { - /// Field name. - pub ident: Ident, - - /// Field type. - pub ty: Type, -} - -impl FieldIr { - pub fn from_fields(fields: Fields) -> syn::Result> { - fields.iter().map(FieldIr::new).collect() - } - - /// Create a new [`FieldIr`] from the input [`Field`]. - pub fn new(field: &Field) -> syn::Result { - let ident = field.ident.as_ref().cloned().ok_or_else(|| { - syn::Error::new_spanned( - field, - "no name on struct field i.e. tuple structs unsupported", - ) - })?; - - Ok(Self { - ident, - ty: field.ty.clone(), - }) - } -} diff --git a/ssh-derive/src/lib.rs b/ssh-derive/src/lib.rs index caf4a2f8..1b8ca3a3 100644 --- a/ssh-derive/src/lib.rs +++ b/ssh-derive/src/lib.rs @@ -4,7 +4,7 @@ //! Custom derive support for the [`ssh-encoding`] crate. //! //! Note that this crate shouldn't be used directly, but instead accessed -//! by using the `derive` feature of the `der` crate, which re-exports this crate's +//! by using the `derive` feature of the [`ssh-encoding`] crate, which re-exports this crate's //! macros from the toplevel. //! //! [`ssh-encoding`]: ../ssh-encoding @@ -24,11 +24,10 @@ macro_rules! abort { }; } +mod attributes; mod decode; mod encode; -mod field_ir; -use crate::{decode::DeriveDecode, encode::DeriveEncode, field_ir::FieldIr}; use proc_macro::TokenStream; use syn::{parse_macro_input, DeriveInput}; @@ -38,8 +37,8 @@ use syn::{parse_macro_input, DeriveInput}; #[proc_macro_derive(Decode, attributes(ssh))] pub fn derive_decode(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); - match DeriveDecode::new(input) { - Ok(t) => t.to_tokens().into(), + match decode::try_derive_decode(input) { + Ok(t) => t.into(), Err(e) => e.to_compile_error().into(), } } @@ -50,8 +49,8 @@ pub fn derive_decode(input: TokenStream) -> TokenStream { #[proc_macro_derive(Encode, attributes(ssh))] pub fn derive_encode(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); - match DeriveEncode::new(input) { - Ok(t) => t.to_tokens().into(), + match encode::try_derive_encode(input) { + Ok(t) => t.into(), Err(e) => e.to_compile_error().into(), } } diff --git a/ssh-encoding/Cargo.toml b/ssh-encoding/Cargo.toml index 856a08b5..4fa53f35 100644 --- a/ssh-encoding/Cargo.toml +++ b/ssh-encoding/Cargo.toml @@ -20,6 +20,7 @@ base64ct = { version = "1.7", optional = true } bytes = { version = "1", optional = true, default-features = false } digest = { version = "=0.11.0-pre.9", optional = true, default-features = false } pem-rfc7468 = { version = "1.0.0-rc.2", optional = true } +ssh-derive = { version = "0.0.1-alpha", optional = true, path = "../ssh-derive" } [dev-dependencies] hex-literal = "0.4.1" @@ -30,6 +31,7 @@ alloc = ["base64ct?/alloc", "pem-rfc7468?/alloc"] base64 = ["dep:base64ct"] bytes = ["alloc", "dep:bytes"] pem = ["base64", "dep:pem-rfc7468"] +derive = ["ssh-derive"] [package.metadata.docs.rs] all-features = true diff --git a/ssh-encoding/src/derive.rs b/ssh-encoding/src/derive.rs new file mode 100644 index 00000000..78aa3196 --- /dev/null +++ b/ssh-encoding/src/derive.rs @@ -0,0 +1,65 @@ +//! # Deriving [`Encode`] and [`Decode`] +//! +//! The traits [`Encode`] and [`Decode`] can be derived for any struct or enum where all its fields +//! implement [`Encode`] and [`Decode`]. +//! +//! [`Encode`]: [crate::Encode] +//! [`Decode`]: [crate::Decode] +//! ## Example +//! +//! Here is an example of how you could define a handful of the SSH message types. +#![cfg_attr(feature = "alloc", doc = "```")] +#![cfg_attr(not(feature = "alloc"), doc = "```ignore")] +//! use ssh_encoding::{Decode, Encode}; +//! +//! #[derive(Debug, PartialEq, Encode, Decode)] +//! #[repr(u8)] +//! enum Message { +//! Disconnect { +//! reason_code: u32, +//! description: String, +//! language_tag: String, +//! } = 1, +//! EcdhInit { +//! client_public_key: Vec, +//! } = 30, +//! EcdhReply { +//! host_key: HostKey, +//! server_public_key: Vec, +//! #[ssh(length_prefixed)] +//! host_signature: HostSignature, +//! } = 31, +//! } +//! +//! #[derive(Debug, PartialEq, Encode, Decode)] +//! #[ssh(length_prefixed)] +//! struct HostKey { +//! key_type: String, +//! ecdsa_curve_identifier: String, +//! ecdsa_public_key: Vec, +//! } +//! +//! #[derive(Debug, PartialEq, Encode, Decode)] +//! struct HostSignature { +//! signature_type: String, +//! signature: Vec, +//! } +//! +//! let message = Message::EcdhReply { +//! host_key: HostKey { +//! key_type: "ecdsa-sha2-nistp256".into(), +//! ecdsa_curve_identifier: "nistp256".into(), +//! ecdsa_public_key: vec![0x01, 0x02, 0x03], +//! }, +//! server_public_key: vec![0x04, 0x05, 0x06], +//! host_signature: HostSignature { +//! signature_type: "ecdsa-sha2-nistp256".into(), +//! signature: vec![0x07, 0x08, 0x09], +//! }, +//! }; +//! +//! let encoded = message.encode_vec().unwrap(); +//! assert_eq!(&encoded[..13], &[31, 0, 0, 0, 42, 0, 0, 0, 19, 101, 99, 100, 115]); +//! let decoded = Message::decode(&mut &encoded[..]).unwrap(); +//! assert_eq!(message, decoded); +//! ``` diff --git a/ssh-encoding/src/error.rs b/ssh-encoding/src/error.rs index 2e3bb25e..6fa400b8 100644 --- a/ssh-encoding/src/error.rs +++ b/ssh-encoding/src/error.rs @@ -35,6 +35,9 @@ pub enum Error { /// Number of bytes of remaining data at end of message. remaining: usize, }, + + /// Invalid discriminant value in message. + InvalidDiscriminant(u128), } impl core::error::Error for Error { @@ -65,6 +68,9 @@ impl fmt::Display for Error { f, "unexpected trailing data at end of message ({remaining} bytes)", ), + Error::InvalidDiscriminant(discriminant) => { + write!(f, "invalid discriminant value: {discriminant}") + } } } } diff --git a/ssh-encoding/src/lib.rs b/ssh-encoding/src/lib.rs index b8b153db..54450ccd 100644 --- a/ssh-encoding/src/lib.rs +++ b/ssh-encoding/src/lib.rs @@ -62,3 +62,8 @@ pub use digest; #[cfg(feature = "pem")] pub use crate::pem::{DecodePem, EncodePem}; + +#[cfg(feature = "derive")] +pub use ssh_derive::{Decode, Encode}; +#[cfg(feature = "derive")] +pub mod derive; diff --git a/ssh-encoding/tests/derive.rs b/ssh-encoding/tests/derive.rs new file mode 100644 index 00000000..cdc1d1d3 --- /dev/null +++ b/ssh-encoding/tests/derive.rs @@ -0,0 +1,182 @@ +//! Tests for the derive implementations for the `Decode` and `Encode` traits. +#![cfg(all(feature = "derive", feature = "alloc"))] + +use ssh_encoding::{Decode, Encode, Error}; + +#[derive(Debug, PartialEq, Decode, Encode)] +struct MostTypes +where + T: Encode + Decode, +{ + a: u8, + b: u32, + c: u64, + d: usize, + e: bool, + f: [u8; 7], + g: String, + h: Vec, + i: T, +} + +// Only `Encode` is derived for references, as `Decode` isn't implemented for them. +#[derive(Debug, PartialEq, Encode)] +struct Reference<'a>(&'a [u8]); + +#[derive(Debug, PartialEq, Decode, Encode)] +#[ssh(length_prefixed)] +struct LengthPrefixed { + #[ssh(length_prefixed)] + a: u32, + b: String, +} + +#[derive(Debug, PartialEq, Encode, Decode)] +#[repr(u8)] +#[ssh(length_prefixed)] +enum ComplexEnum { + Bar = 1, + Baz { + a: u32, + #[ssh(length_prefixed)] + b: u8, + } = 2, + Fiz(u32, #[ssh(length_prefixed)] u8) = 3, +} + +#[derive(Debug, PartialEq, Encode, Decode)] +#[repr(u32)] +enum SimpleEnum { + A = 1, + B = 2, +} + +#[derive(Debug, PartialEq, Encode, Decode)] +#[repr(u8)] +enum ModerateEnum { + A = 1, + B { a: String } = 2, +} + +#[derive(Debug, PartialEq, Encode, Decode)] +struct Empty; + +#[test] +fn derive_encode_decode_roundtrip_most_types() { + #[rustfmt::skip] + let data = [ + 42, + 0xDE, 0xAD, 0xBE, 0xEF, + 0xCA, 0xFE, 0xBA, 0xBE, 0xFA, 0xCE, 0xFE, 0xED, + 0x00, 0x00, 0xAB, 0xCD, + 0x01, + b'e', b'x', b'a', b'm', b'p', b'l', b'e', + 0x00, 0x00, 0x00, 0x05, b'h', b'e', b'l', b'l', b'o', + 0x00, 0x00, 0x00, 0x05, b'w', b'o', b'r', b'l', b'd', + 0x20, + ]; + let expected = MostTypes { + a: 42, + b: 0xDEAD_BEEF, + c: 0xCAFE_BABE_FACE_FEED, + d: 0xABCD, + e: true, + f: *b"example", + g: "hello".to_string(), + h: b"world".to_vec(), + i: 0x20u8, + }; + assert_eq!(&data, expected.encode_vec().unwrap().as_slice()); + let most_types = MostTypes::::decode(&mut &data[..]).unwrap(); + assert_eq!(most_types, expected); +} + +#[test] +fn derive_encode_reference() { + let data = b"\x00\x00\x00\x07example"; + let expected = Reference(&data[4..]); + assert_eq!(data, expected.encode_vec().unwrap().as_slice()); +} + +#[test] +fn derive_encode_decode_roundtrip_length_prefixed() { + #[rustfmt::skip] + let data = [ + 0x00, 0x00, 0x00, 0x11, + 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x2A, + 0x00, 0x00, 0x00, 0x05, b'h', b'e', b'l', b'l', b'o', + ]; + let expected = LengthPrefixed { + a: 42, + b: "hello".to_string(), + }; + assert_eq!(&data, expected.encode_vec().unwrap().as_slice()); + let length_prefixed = LengthPrefixed::decode(&mut &data[..]).unwrap(); + assert_eq!(length_prefixed, expected); +} + +#[test] +fn derive_encode_decode_empty() { + let data = [0u8; 0]; + let expected = Empty; + assert_eq!(data, expected.encode_vec().unwrap().as_slice()); + let actual = Empty::decode(&mut &data[..]).unwrap(); + assert_eq!(actual, expected); +} + +#[test] +fn derive_encode_decode_enum_unit() { + #[rustfmt::skip] + let data = [ + 0, 0, 0, 1, // Length prefix of entire enum. + 1, // Discriminant for Foo::Bar. + ]; + let expected = ComplexEnum::Bar; + assert_eq!(data, expected.encode_vec().unwrap().as_slice()); + let actual = ComplexEnum::decode(&mut &data[..]).unwrap(); + assert_eq!(actual, expected); +} + +#[test] +fn derive_encode_decode_enum_struct() { + #[rustfmt::skip] + let data = [ + 0, 0, 0, 10, // Length prefix of entire enum. + 2, // Discriminant for Foo::Baz. + 0, 0, 0, 1, // Value of Foo::Baz::a. + 0, 0, 0, 1, // Length prefix of Foo::Baz::b. + 2 // Value of Foo::Baz::b. + ]; + let expected = ComplexEnum::Baz { a: 1, b: 2 }; + assert_eq!(data, expected.encode_vec().unwrap().as_slice()); + let actual = ComplexEnum::decode(&mut &data[..]).unwrap(); + assert_eq!(actual, expected); +} + +#[test] +fn derive_encode_decode_enum_tuple() { + #[rustfmt::skip] + let data = [ + 0, 0, 0, 10, // Length prefix of entire enum. + 3, // Discriminant for Foo::Fiz. + 0, 0, 0, 1, // Value of Foo::Fiz::0. + 0, 0, 0, 1, // Length prefix of Foo::Fiz::1. + 2 // Value of Foo::Fiz::1. + ]; + let expected = ComplexEnum::Fiz(1, 2); + assert_eq!(data, expected.encode_vec().unwrap().as_slice()); + let actual = ComplexEnum::decode(&mut &data[..]).unwrap(); + assert_eq!(actual, expected); +} + +#[test] +fn derive_encode_decode_enum_no_prefix_u32_repr() { + #[rustfmt::skip] + let data = [ + 0, 0, 0, 1, // Discriminant for Bar::A. + ]; + let expected = SimpleEnum::A; + assert_eq!(data, expected.encode_vec().unwrap().as_slice()); + let actual = SimpleEnum::decode(&mut &data[..]).unwrap(); + assert_eq!(actual, expected); +}