diff --git a/core/src/main/scala/com/dwolla/security/crypto/CryptoAlg.scala b/core/src/main/scala/com/dwolla/security/crypto/CryptoAlg.scala index a63a953..4ff474a 100644 --- a/core/src/main/scala/com/dwolla/security/crypto/CryptoAlg.scala +++ b/core/src/main/scala/com/dwolla/security/crypto/CryptoAlg.scala @@ -180,18 +180,24 @@ object CryptoAlg extends CryptoAlgPlatform { // and we can't use that key ID to lookup the key val recipientKeyId = Option(pbe.getKeyID).filterNot(_ == 0) - pbe.decryptToInputStream(keylike, recipientKeyId) - .map(_.pure[Option]) - .recoverWith { - case ex: KeyRingMissingKeyException => - Logger[F] - .trace(ex)(s"could not decrypt using key ${pbe.getKeyID}") - .as(None) - case ex: KeyMismatchException => - Logger[F] - .trace(ex)(s"could not decrypt using key ${pbe.getKeyID}") - .as(None) - } + // if the recipient is identified, check if it exists in the key material we have + // if it does, or if the recipient is undefined, try to decrypt. + if (recipientKeyId.exists(DecryptToInputStream[F, A].hasKeyId(keylike, _)) || recipientKeyId.isEmpty) + pbe + .decryptToInputStream(keylike, recipientKeyId) + .map(_.pure[Option]) + .recoverWith { + case ex: KeyRingMissingKeyException => + Logger[F] + .trace(ex)(s"could not decrypt using key ${pbe.getKeyID}") + .as(None) + case ex: KeyMismatchException => + Logger[F] + .trace(ex)(s"could not decrypt using key ${pbe.getKeyID}") + .as(None) + } + else + none[InputStream].pure[F] case other => Logger[F].warn(EncryptionTypeError)(s"found wrong type of encrypted data: $other").as(None) diff --git a/core/src/main/scala/com/dwolla/security/crypto/DecryptToInputStream.scala b/core/src/main/scala/com/dwolla/security/crypto/DecryptToInputStream.scala index 592b6a8..6c7a78e 100644 --- a/core/src/main/scala/com/dwolla/security/crypto/DecryptToInputStream.scala +++ b/core/src/main/scala/com/dwolla/security/crypto/DecryptToInputStream.scala @@ -13,6 +13,8 @@ import scala.jdk.CollectionConverters._ private[crypto] sealed trait DecryptToInputStream[F[_], A] { def decryptToInputStream(input: A, maybeKeyId: Option[Long]) (pbed: PGPPublicKeyEncryptedData): F[InputStream] + + def hasKeyId(input: A, id: Long): Boolean } private[crypto] object DecryptToInputStream { @@ -60,6 +62,9 @@ private[crypto] object DecryptToInputStream { implicit def PGPSecretKeyRingCollectionInstance[F[_] : Sync]: DecryptToInputStream[F, (PGPSecretKeyRingCollection, Array[Char])] = new DecryptToInputStream[F, (PGPSecretKeyRingCollection, Array[Char])] { + override def hasKeyId(input: (PGPSecretKeyRingCollection, Array[Char]), id: Long): Boolean = + input._1.contains(id) + override def decryptToInputStream(input: (PGPSecretKeyRingCollection, Array[Char]), maybeKeyId: Option[Long]) (pbed: PGPPublicKeyEncryptedData): F[InputStream] = @@ -77,6 +82,13 @@ private[crypto] object DecryptToInputStream { implicit def PGPSecretKeyRingInstance[F[_] : Sync]: DecryptToInputStream[F, (PGPSecretKeyRing, Array[Char])] = new DecryptToInputStream[F, (PGPSecretKeyRing, Array[Char])] { + override def hasKeyId(input: (PGPSecretKeyRing, Array[Char]), id: Long): Boolean = + input + ._1 + .getSecretKeys + .asScala + .exists(_.getKeyID == id) + override def decryptToInputStream(input: (PGPSecretKeyRing, Array[Char]), maybeKeyId: Option[Long]) (pbed: PGPPublicKeyEncryptedData): F[InputStream] = { @@ -90,6 +102,9 @@ private[crypto] object DecryptToInputStream { implicit def PGPPrivateKeyInstance[F[_] : Sync]: DecryptToInputStream[F, PGPPrivateKey] = new DecryptToInputStream[F, PGPPrivateKey] { + override def hasKeyId(input: PGPPrivateKey, id: Long): Boolean = + input.getKeyID == id + override def decryptToInputStream(input: PGPPrivateKey, maybeKeyId: Option[Long]) (pbed: PGPPublicKeyEncryptedData): F[InputStream] =