diff --git a/Cargo.toml b/Cargo.toml index 4a9568e7..89fd4556 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,6 @@ signature = { version = "2.2.0", features = ["std"] } # For PEM decoding pem = { version = "3", optional = true } -simple_asn1 = { version = "0.6", optional = true } # "aws_lc_rs" feature aws-lc-rs = { version = "1.15.0", optional = true } @@ -43,6 +42,7 @@ p384 = { version = "0.13.0", optional = true, features = ["ecdsa"] } rand = { version = "0.8.5", optional = true, features = ["std"], default-features = false } rsa = { version = "0.9.6", optional = true } sha2 = { version = "0.10.7", optional = true, features = ["oid"] } +zeroize = { version = "1.8.2", features = ["derive"] } [target.'cfg(target_arch = "wasm32")'.dependencies] js-sys = "0.3" @@ -65,7 +65,7 @@ criterion = { version = "0.8", default-features = false } [features] default = ["use_pem"] -use_pem = ["pem", "simple_asn1"] +use_pem = ["pem"] rust_crypto = ["ed25519-dalek", "hmac", "p256", "p384", "rand", "rsa", "sha2"] aws_lc_rs = ["aws-lc-rs"] diff --git a/src/pem/decoder.rs b/src/pem/decoder.rs index 94ff8821..3e28aa55 100644 --- a/src/pem/decoder.rs +++ b/src/pem/decoder.rs @@ -1,3 +1,5 @@ +use zeroize::{Zeroize, ZeroizeOnDrop}; + use crate::errors::{ErrorKind, Result}; /// Supported PEM files for EC and RSA Public and Private Keys @@ -40,11 +42,12 @@ enum Classification { /// Documentation about these formats is at /// PKCS#1: https://tools.ietf.org/html/rfc8017 /// PKCS#8: https://tools.ietf.org/html/rfc5958 -#[derive(Debug)] +#[derive(Debug, ZeroizeOnDrop, Zeroize)] pub(crate) struct PemEncodedKey { content: Vec, - asn1: Vec, + #[zeroize(skip)] pem_type: PemType, + #[zeroize(skip)] standard: Standard, } @@ -53,22 +56,15 @@ impl PemEncodedKey { pub fn new(input: &[u8]) -> Result { match pem::parse(input) { Ok(content) => { - let asn1_content = match simple_asn1::from_der(content.contents()) { - Ok(asn1) => asn1, - Err(_) => return Err(ErrorKind::InvalidKeyFormat.into()), - }; - match content.tag() { // This handles a PKCS#1 RSA Private key "RSA PRIVATE KEY" => Ok(PemEncodedKey { content: content.into_contents(), - asn1: asn1_content, pem_type: PemType::RsaPrivate, standard: Standard::Pkcs1, }), "RSA PUBLIC KEY" => Ok(PemEncodedKey { content: content.into_contents(), - asn1: asn1_content, pem_type: PemType::RsaPublic, standard: Standard::Pkcs1, }), @@ -79,41 +75,22 @@ impl PemEncodedKey { // This handles PKCS#8 certificates and public & private keys tag @ "PRIVATE KEY" | tag @ "PUBLIC KEY" | tag @ "CERTIFICATE" => { - match classify_pem(&asn1_content) { - Some(c) => { - let is_private = tag == "PRIVATE KEY"; - let pem_type = match c { - Classification::Ec => { - if is_private { - PemType::EcPrivate - } else { - PemType::EcPublic - } - } - Classification::Ed => { - if is_private { - PemType::EdPrivate - } else { - PemType::EdPublic - } - } - Classification::Rsa => { - if is_private { - PemType::RsaPrivate - } else { - PemType::RsaPublic - } - } - }; - Ok(PemEncodedKey { - content: content.into_contents(), - asn1: asn1_content, - pem_type, - standard: Standard::Pkcs8, - }) - } - None => Err(ErrorKind::InvalidKeyFormat.into()), - } + let is_private = tag == "PRIVATE KEY"; + let pem_type = match classify_der(content.contents()) + .ok_or(ErrorKind::InvalidKeyFormat)? + { + Classification::Ec if is_private => PemType::EcPrivate, + Classification::Ec => PemType::EcPublic, + Classification::Ed if is_private => PemType::EdPrivate, + Classification::Ed => PemType::EdPublic, + Classification::Rsa if is_private => PemType::RsaPrivate, + Classification::Rsa => PemType::RsaPublic, + }; + Ok(PemEncodedKey { + content: content.into_contents(), + pem_type, + standard: Standard::Pkcs8, + }) } // Unknown/unsupported type @@ -140,7 +117,8 @@ impl PemEncodedKey { match self.standard { Standard::Pkcs1 => Err(ErrorKind::InvalidKeyFormat.into()), Standard::Pkcs8 => match self.pem_type { - PemType::EcPublic => extract_first_bitstring(&self.asn1), + PemType::EcPublic => extract_first_bitstring_der(&self.content) + .ok_or_else(|| ErrorKind::InvalidKeyFormat.into()), _ => Err(ErrorKind::InvalidKeyFormat.into()), }, } @@ -162,7 +140,8 @@ impl PemEncodedKey { match self.standard { Standard::Pkcs1 => Err(ErrorKind::InvalidKeyFormat.into()), Standard::Pkcs8 => match self.pem_type { - PemType::EdPublic => extract_first_bitstring(&self.asn1), + PemType::EdPublic => extract_first_bitstring_der(&self.content) + .ok_or_else(|| ErrorKind::InvalidKeyFormat.into()), _ => Err(ErrorKind::InvalidKeyFormat.into()), }, } @@ -173,68 +152,211 @@ impl PemEncodedKey { match self.standard { Standard::Pkcs1 => Ok(self.content.as_slice()), Standard::Pkcs8 => match self.pem_type { - PemType::RsaPrivate => extract_first_bitstring(&self.asn1), - PemType::RsaPublic => extract_first_bitstring(&self.asn1), + PemType::RsaPrivate | PemType::RsaPublic => { + extract_first_bitstring_der(&self.content) + .ok_or_else(|| ErrorKind::InvalidKeyFormat.into()) + } _ => Err(ErrorKind::InvalidKeyFormat.into()), }, } } } +const TAG_BIT_STRING: u8 = 0x03; +const TAG_OCTET_STRING: u8 = 0x04; +const TAG_OID: u8 = 0x06; +const TAG_SEQUENCE: u8 = 0x30; + // This really just finds and returns the first bitstring or octet string // Which is the x coordinate for EC public keys // And the DER contents of an RSA key // Though PKCS#11 keys shouldn't have anything else. // It will get confusing with certificates. -fn extract_first_bitstring(asn1: &[simple_asn1::ASN1Block]) -> Result<&[u8]> { - for asn1_entry in asn1.iter() { - match asn1_entry { - simple_asn1::ASN1Block::Sequence(_, entries) => { - if let Ok(result) = extract_first_bitstring(entries) { - return Ok(result); +fn extract_first_bitstring_der(bytes: &[u8]) -> Option<&[u8]> { + let mut stack = vec![bytes]; + + // Depth-first search in the DER tree for the first bitstring or octet string + while let Some(bytes) = stack.pop() { + let Some((tag, value, rest)) = read_tlv(bytes) else { + continue; // Skip invalid TLV + }; + + if !rest.is_empty() { + stack.push(rest); + } + + match tag { + // See [ITU-T X.690] §8.6 for the encoding of a bit string. + // + // [ITU-T X.690]: https://www.itu.int/rec/T-REC-X.690 + TAG_BIT_STRING => { + if value.is_empty() { + return None; // Missing the length of the unused bits in the last byte + } else if value[0] != 0 { + // The content wrapped in a bit string is byte aligned + // + // See + // * DER-encoded SEQUENCE for RSA keys (https://www.rfc-editor.org/rfc/rfc8017) + // * EC point (https://www.rfc-editor.org/rfc/rfc5480#section-2.2) + // * raw 32-byte key for Ed25519 (https://www.rfc-editor.org/rfc/rfc8032) + return None; } + return Some(&value[1..]); } - simple_asn1::ASN1Block::BitString(_, _, value) => { - return Ok(value.as_ref()); - } - simple_asn1::ASN1Block::OctetString(_, value) => { - return Ok(value.as_ref()); + TAG_OCTET_STRING => return Some(value), + TAG_SEQUENCE => { + stack.push(value); } - _ => (), + _ => {} } } - Err(ErrorKind::InvalidEcdsaKey.into()) + None } /// Find whether this is EC, RSA, or Ed -fn classify_pem(asn1: &[simple_asn1::ASN1Block]) -> Option { - // These should be constant but the macro requires - // #![feature(const_vec_new)] - let ec_public_key_oid = simple_asn1::oid!(1, 2, 840, 10_045, 2, 1); - let rsa_public_key_oid = simple_asn1::oid!(1, 2, 840, 113_549, 1, 1, 1); - let ed25519_oid = simple_asn1::oid!(1, 3, 101, 112); - - for asn1_entry in asn1.iter() { - match asn1_entry { - simple_asn1::ASN1Block::Sequence(_, entries) => { - if let Some(classification) = classify_pem(entries) { - return Some(classification); - } - } - simple_asn1::ASN1Block::ObjectIdentifier(_, oid) => { - if oid == ec_public_key_oid { - return Some(Classification::Ec); - } - if oid == rsa_public_key_oid { - return Some(Classification::Rsa); - } - if oid == ed25519_oid { - return Some(Classification::Ed); - } +fn classify_der(bytes: &[u8]) -> Option { + const EC_PUBLIC_KEY_OID: &[u8] = &[0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01]; // 1.2.840.10045.2.1 + const RSA_PUBLIC_KEY_OID: &[u8] = &[0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01]; // 1.2.840.113549.1.1.1 + const ED25519_OID: &[u8] = &[0x2B, 0x65, 0x70]; // 1.3.101.112 + + let mut stack = vec![bytes]; + + // Depth-first search in the DER tree for one of the above OIDs + while let Some(bytes) = stack.pop() { + let Some((tag, value, rest)) = read_tlv(bytes) else { + continue; // Skip invalid TLV + }; + + if !rest.is_empty() { + stack.push(rest); + } + + if tag == TAG_OID { + match value { + EC_PUBLIC_KEY_OID => return Some(Classification::Ec), + RSA_PUBLIC_KEY_OID => return Some(Classification::Rsa), + ED25519_OID => return Some(Classification::Ed), + _ => {} } - _ => {} + } else if tag == TAG_SEQUENCE { + stack.push(value); } } + None } + +/// Returns `Some((tag, value, rest))` or `None` if the TLV is invalid. +/// +/// See or [ITU-T X.690] §8.1 for the BER/DER TLV encoding. +/// +/// [ITU-T X.690]: https://www.itu.int/rec/T-REC-X.690 +fn read_tlv(mut bytes: &[u8]) -> Option<(u8, &[u8], &[u8])> { + if bytes.is_empty() { + return None; + } + + // See for tag encoding + let first = bytes[0]; + bytes = &bytes[1..]; + let tag = if first & 0x1f == 0x1f { + // Long form multi-byte tag + // Skip subsequent tag bytes (high bit = continuation) + while !bytes.is_empty() && bytes[0] & 0x80 != 0 { + bytes = &bytes[1..]; // Skip bytes as long as the MSB is set + } + if bytes.is_empty() { + return None; + } + bytes = &bytes[1..]; // final tag byte (high bit clear) + 0xFF // Sentinel value for any long-form tag + } else { + // Short form single-byte tag + first + }; + + // See for length encoding + let len = bytes[0]; + bytes = &bytes[1..]; + let len = if len < 0x80 { + // 0-127: short form + len as usize + } else { + // 128-255: long form => number of bytes without high bit set + let len_len = (len & 0x7f) as usize; + if len_len == 0 { + return None; // Indefinite length; forbidden in DER + } else if size_of::() < len_len { + return None; // Too long; prevents usize overflow + } else if bytes.len() < len_len { + return None; // Not enough bytes + } + let (len_bytes, rest) = bytes.split_at(len_len); + bytes = rest; + // Big-endian base 256 encoding + len_bytes.iter().fold(0, |acc, &x| acc * 256 + x as usize) + }; + + if bytes.len() < len { + return None; // Not enough bytes + } + + let (value, rest) = bytes.split_at(len); + Some((tag, value, rest)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn classify_ec_key() { + let pem = pem::parse(include_bytes!("../../tests/ecdsa/public_ecdsa_key.pem")).unwrap(); + assert_eq!(classify_der(pem.contents()), Some(Classification::Ec)); + } + + #[test] + fn classify_rsa_key() { + let pem = pem::parse(include_bytes!("../../tests/rsa/public_rsa_key_pkcs8.pem")).unwrap(); + assert_eq!(classify_der(pem.contents()), Some(Classification::Rsa)); + } + + #[test] + fn classify_ed25519_key() { + let pem = pem::parse(include_bytes!("../../tests/eddsa/public_ed25519_key.pem")).unwrap(); + assert_eq!(classify_der(pem.contents()), Some(Classification::Ed)); + } + + #[test] + fn ec_public_key_extraction() { + let key = + PemEncodedKey::new(include_bytes!("../../tests/ecdsa/public_ecdsa_key.pem")).unwrap(); + let bytes = key.as_ec_public_key().unwrap(); + assert_eq!(bytes[0], 0x04); // uncompressed point + assert_eq!(bytes.len(), 65); // 1 + 32 + 32 for P-256 + } + + #[test] + fn ed_public_key_extraction() { + let key = + PemEncodedKey::new(include_bytes!("../../tests/eddsa/public_ed25519_key.pem")).unwrap(); + let bytes = key.as_ed_public_key().unwrap(); + assert_eq!(bytes.len(), 32); + } + + #[test] + fn rsa_pkcs8_key_extraction() { + let key = + PemEncodedKey::new(include_bytes!("../../tests/rsa/public_rsa_key_pkcs8.pem")).unwrap(); + let bytes = key.as_rsa_key().unwrap(); + assert_eq!(bytes[0], 0x30); // SEQUENCE + } + #[test] + fn rsa_pkcs1_key() { + let key = PemEncodedKey::new(include_bytes!("../../tests/rsa/private_rsa_key_pkcs1.pem")) + .unwrap(); + let bytes = key.as_rsa_key().unwrap(); + assert_eq!(bytes[0], 0x30); // SEQUENCE + } +}