Skip to content

Commit

Permalink
support encrypting with multiple recipients
Browse files Browse the repository at this point in the history
note: this is not binary compatible with previous versions of the
library due to the interface changes required to support multiple keys.
  • Loading branch information
bpholt authored and CJSmith-0141 committed Dec 22, 2023
1 parent b21f6af commit cbc14c8
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 57 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ ThisBuild / tlJdkRelease := Option(8)
ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec.temurin("17"))
ThisBuild / githubWorkflowScalaVersions := Seq("2.13", "2.12")
ThisBuild / tlCiReleaseBranches := Seq("main")
ThisBuild / tlBaseVersion := "0.4"
ThisBuild / tlBaseVersion := "0.5"
ThisBuild / tlSonatypeUseLegacyHost := true
ThisBuild / mergifyStewardConfig ~= {
_.map(_.copy(mergeMinors = true, author = "dwolla-oss-scala-steward[bot]"))
Expand Down
81 changes: 57 additions & 24 deletions core/src/main/scala/com/dwolla/security/crypto/CryptoAlg.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.dwolla.security.crypto

import cats.data.NonEmptyList
import cats.effect._
import cats.effect.syntax.all._
import cats.syntax.all._
Expand All @@ -19,12 +20,16 @@ import java.io._

trait CryptoAlg[F[_]] {
def encrypt(key: PGPPublicKey,
chunkSize: ChunkSize = defaultChunkSize,
fileName: Option[String] = None,
encryption: Encryption = Aes256,
compression: Compression = Zip,
packetFormat: PgpLiteralDataPacketFormat = Binary,
): Pipe[F, Byte, Byte]
moreKeys: PGPPublicKey*): Pipe[F, Byte, Byte] =
encrypt(NonEmptyList.of(key, moreKeys: _*), EncryptionConfig())

def encrypt(config: EncryptionConfig,
key: PGPPublicKey,
moreKeys: PGPPublicKey*): Pipe[F, Byte, Byte] =
encrypt(NonEmptyList.of(key, moreKeys: _*), config)

def encrypt(keys: NonEmptyList[PGPPublicKey],
config: EncryptionConfig): Pipe[F, Byte, Byte]

def decrypt(key: PGPPrivateKey,
chunkSize: ChunkSize,
Expand Down Expand Up @@ -68,8 +73,11 @@ trait CryptoAlg[F[_]] {
}

object CryptoAlg extends CryptoAlgPlatform {
private def addKey[F[_] : Sync](pgpEncryptedDataGenerator: PGPEncryptedDataGenerator, key: PGPPublicKey): F[Unit] =
Sync[F].blocking(pgpEncryptedDataGenerator.addMethod(new JcePublicKeyKeyEncryptionMethodGenerator(key)))
private def addKeys[F[_] : Sync](pgpEncryptedDataGenerator: PGPEncryptedDataGenerator,
keys: NonEmptyList[PGPPublicKey]): F[Unit] =
keys.traverse_ { key =>
Sync[F].delay(pgpEncryptedDataGenerator.addMethod(new JcePublicKeyKeyEncryptionMethodGenerator(key)))
}

private type PgpEncryptionPipelineComponents = (PGPEncryptedDataGenerator, PGPCompressedDataGenerator, PGPLiteralDataGenerator)

Expand All @@ -96,7 +104,7 @@ object CryptoAlg extends CryptoAlgPlatform {
*
* Plaintext -> "Literal Data" Packetizer -> Compressor -> Encryptor -> OutputStream provided by caller
*/
private[crypto] def encryptingOutputStream[F[_] : Sync](key: PGPPublicKey,
private[crypto] def encryptingOutputStream[F[_] : Sync](keys: NonEmptyList[PGPPublicKey],
chunkSize: ChunkSize,
fileName: Option[String],
encryption: Encryption,
Expand All @@ -105,7 +113,7 @@ object CryptoAlg extends CryptoAlgPlatform {
outputStreamIntoWhichToWriteEncryptedBytes: OutputStream): Resource[F, OutputStream] =
pgpGenerators[F](encryption, compression)
.evalTap { case (pgpEncryptedDataGenerator, _, _) =>
addKey[F](pgpEncryptedDataGenerator, key)
addKeys[F](pgpEncryptedDataGenerator, keys)
}
.evalMap { case (pgpEncryptedDataGenerator, pgpCompressedDataGenerator, pgpLiteralDataGenerator) =>
for {
Expand Down Expand Up @@ -135,20 +143,15 @@ object CryptoAlg extends CryptoAlgPlatform {
private val fingerprintCalculator = new JcaKeyFingerprintCalculator
private val closeStreamsAfterUse = false

override def encrypt(key: PGPPublicKey,
chunkSize: ChunkSize,
fileName: Option[String] = None,
encryption: Encryption = Aes256,
compression: Compression = Zip,
packetFormat: PgpLiteralDataPacketFormat = Binary,
): Pipe[F, Byte, Byte] =
override def encrypt(keys: NonEmptyList[PGPPublicKey], config: EncryptionConfig): Pipe[F, Byte, Byte] =
_.through { bytes =>
readOutputStream(chunkSize.value) { outputStreamToRead =>
Stream
.resource(encryptingOutputStream[F](key, chunkSize, fileName, encryption, compression, packetFormat, outputStreamToRead))
.flatMap(wos => bytes.chunkN(chunkSize.value).flatMap(Stream.chunk).through(writeOutputStream(wos.pure[F], closeStreamsAfterUse)))
.compile
.drain
readOutputStream(config.chunkSize.value) { outputStreamToRead =>
Logger[F].trace(s"${List.fill(keys.length)("🔑").mkString("")} encrypting input with ${keys.length} recipients") >>
Stream
.resource(encryptingOutputStream[F](keys, config.chunkSize, config.fileName, config.encryption, config.compression, config.packetFormat, outputStreamToRead))
.flatMap(wos => bytes.chunkN(config.chunkSize.value).flatMap(Stream.chunk).through(writeOutputStream(wos.pure[F], closeStreamsAfterUse)))
.compile
.drain
}
}

Expand Down Expand Up @@ -181,7 +184,13 @@ object CryptoAlg extends CryptoAlgPlatform {
.map(_.pure[Option])
.recoverWith {
case ex: KeyRingMissingKeyException =>
Logger[F].trace(ex)(s"could not decrypt using key ${pbe.getKeyID}").as(None)
Logger[F]
.trace(ex)(s"could not decrypt using key ${pbe.getKeyID}")
.as(None)
case ex: KeyMismatchException =>
Logger[F]
.trace(ex)(s"could not decrypt using key ${pbe.getKeyID}")
.as(None)
}

case other =>
Expand Down Expand Up @@ -261,3 +270,27 @@ trait CryptoAlgPlatform {
@deprecated("use the variant with LoggerFactory instead", "0.4.0")
private[crypto] def apply[F[_]](ev1: Async[F], ev2: Logger[F]): Resource[F, CryptoAlg[F]] = CryptoAlg.withLogger[F](ev1, ev2)
}

class EncryptionConfig private(val chunkSize: ChunkSize,
val fileName: Option[String],
val encryption: Encryption,
val compression: Compression,
val packetFormat: PgpLiteralDataPacketFormat,
) {
private def copy(chunkSize: ChunkSize = this.chunkSize,
fileName: Option[String] = this.fileName,
encryption: Encryption = this.encryption,
compression: Compression = this.compression,
packetFormat: PgpLiteralDataPacketFormat = this.packetFormat): EncryptionConfig =
new EncryptionConfig(chunkSize, fileName, encryption, compression, packetFormat)

def withChunkSize(chunkSize: ChunkSize): EncryptionConfig = copy(chunkSize = chunkSize)
def withFileName(fileName: Option[String]): EncryptionConfig = copy(fileName = fileName)
def withEncryption(encryption: Encryption): EncryptionConfig = copy(encryption = encryption)
def withCompression(compression: Compression): EncryptionConfig = copy(compression = compression)
def withPacketFormat(packetFormat: PgpLiteralDataPacketFormat): EncryptionConfig = copy(packetFormat = packetFormat)
}

object EncryptionConfig {
def apply(): EncryptionConfig = new EncryptionConfig(defaultChunkSize, None, Aes256, Zip, Binary)
}
2 changes: 2 additions & 0 deletions project/BouncyCastlePlugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ object BouncyCastlePlugin extends AutoPlugin {
libraryDependencies ++= {
Seq(
"org.typelevel" %% "log4cats-noop" % "2.6.0" % Test,
"org.typelevel" %% "log4cats-slf4j" % "2.6.0" % Test,
"ch.qos.logback" % "logback-classic" % "1.4.7" % Test,
"org.scalameta" %% "munit" % "0.7.29" % Test,
"org.typelevel" %% "scalacheck-effect" % "1.0.4" % Test,
"org.typelevel" %% "scalacheck-effect-munit" % "1.0.4" % Test,
Expand Down
12 changes: 12 additions & 0 deletions tests/src/test/resources/logback-test.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<configuration>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern>%date | logger=%logger | '%msg'%n</pattern>
</encoder>
</appender>

<root level="ERROR">
<appender-ref ref="STDOUT"/>
</root>
</configuration>
83 changes: 57 additions & 26 deletions tests/src/test/scala/com/dwolla/security/crypto/CryptoAlgSpec.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.dwolla.security.crypto

import cats.data._
import cats.effect._
import cats.syntax.all._
import com.dwolla.testutils._
Expand All @@ -11,9 +12,9 @@ import org.scalacheck.Arbitrary._
import org.scalacheck._
import org.scalacheck.effect.PropF.{forAllF, forAllNoShrinkF}
import org.scalacheck.util.Pretty
import com.eed3si9n.expecty.Expecty.{assert => Assert}
import org.typelevel.log4cats._
import org.typelevel.log4cats.noop.NoOpLogger
import com.eed3si9n.expecty.Expecty.{ assert => Assert }
import org.typelevel.log4cats.slf4j.Slf4jFactory

import java.io.ByteArrayOutputStream
import scala.concurrent.duration._
Expand All @@ -24,45 +25,75 @@ class CryptoAlgSpec
with PgpArbitraries
with CryptoArbitraries {

private implicit val noOpLogger: LoggerFactory[IO] = new LoggerFactory[IO] {
override def getLoggerFromName(name: String): SelfAwareStructuredLogger[IO] = NoOpLogger[IO]
override def fromName(name: String): IO[SelfAwareStructuredLogger[IO]] = NoOpLogger[IO].pure[IO]
}
private implicit val loggerFactory: LoggerFactory[IO] = Slf4jFactory.create[IO]

private val resource: Fixture[CryptoAlg[IO]] = ResourceSuiteLocalFixture("CryptoAlg[IO]", CryptoAlg[IO])
override def munitFixtures = List(resource)
override def munitFixtures: Seq[Fixture[_]] = List(resource)

override protected def scalaCheckTestParameters: Test.Parameters =
Test.Parameters.default
.withMinSuccessfulTests(1)
.withMinSuccessfulTests(2)

override val munitTimeout: Duration = 2.minutes

test("CryptoAlg should round trip the plaintext using a key pair") {
private def genNelResource[F[_], A](implicit A: Arbitrary[Resource[F, A]]): Gen[Resource[F, NonEmptyList[A]]] =
for {
extraCount <- Gen.chooseNum(0, 10)
a <- A.arbitrary
extras <- Gen.listOfN(extraCount, A.arbitrary)
} yield {
for {
aa <- a
ee <- extras.sequence
} yield NonEmptyList.of(aa, ee: _*)
}

private implicit def arbNelResource[F[_], A](implicit A: Arbitrary[Resource[F, A]]): Arbitrary[Resource[F, NonEmptyList[A]]] = Arbitrary(genNelResource[F, A])

test("CryptoAlg should round trip the plaintext using one or more key pairs") {
val crypto = resource()
implicit val arbKeyPair: Arbitrary[Resource[IO, PGPKeyPair]] = arbWeakKeyPair[IO]

forAllF { (keyPairR: Resource[IO, PGPKeyPair],
forAllF { (keyPairsR: Resource[IO, NonEmptyList[PGPKeyPair]],
bytesG: Stream[Pure, Byte],
encryptionChunkSize: ChunkSize,
decryptionChunkSize: ChunkSize) =>
val materializedBytes: List[Byte] = bytesG.compile.toList
val bytes = Stream.emits(materializedBytes)
val testResource =
for {
keyPair <- keyPairR
roundTrip <- bytes
.through(crypto.encrypt(keyPair.getPublicKey, encryptionChunkSize))
.through(crypto.armor(encryptionChunkSize))
.through(crypto.decrypt(keyPair.getPrivateKey, decryptionChunkSize))
.compile
.resource
.toList
} yield roundTrip

testResource.use { roundTrip => IO {
assertEquals(roundTrip, materializedBytes)
}}
implicit0(logger: Logger[IO]) <- LoggerFactory[IO].create(LoggerName("CryptoAlgSpec.round trip")).toResource
_ <- Logger[IO].trace("starting").toResource
keyPairs <- keyPairsR
_ <- Logger[IO].trace("key pairs generated").toResource
allRecipients = keyPairs.map(_.getPublicKey)
privateKeys = keyPairs.map(_.getPrivateKey)
_ <- Logger[IO].trace(s"encrypting with keys ${allRecipients.map(_.getKeyID)}").toResource
encryptedBytes <-
bytes
.through(crypto.encrypt(allRecipients, EncryptionConfig().withChunkSize(encryptionChunkSize)))
.through(crypto.armor(encryptionChunkSize))
.compile
.resource
.toVector

decryptedBytes <-
privateKeys.traverse { privateKey =>
Logger[IO].trace(s"decrypting with key id ${privateKey.getKeyID}").toResource >>
Stream.emits(encryptedBytes)
.through(crypto.decrypt(privateKey, decryptionChunkSize))
.compile
.resource
.toList
}
_ <- Logger[IO].trace("done with round trips").toResource
} yield decryptedBytes

testResource.use {
_.traverse_ { roundTrippedBytes => IO {
assertEquals(roundTrippedBytes, materializedBytes)
}}
}
}
}

Expand All @@ -87,7 +118,7 @@ class CryptoAlgSpec
for {
keyPair <- keyPairR
chunkSizes <- bytes
.through(crypto.encrypt(keyPair.getPublicKey, encryptionChunkSize))
.through(crypto.encrypt(EncryptionConfig().withChunkSize(encryptionChunkSize), keyPair.getPublicKey))
.through(crypto.armor(encryptionChunkSize))
.chunks
.map(_.size)
Expand Down Expand Up @@ -154,7 +185,7 @@ class CryptoAlgSpec
collection <- collectionR
pub <- keysIn[IO](collection).map(_.getPublicKey).find(_.isEncryptionKey).compile.resource.lastOrError
roundTrip <- Stream.emits(materializedBytes)
.through(crypto.encrypt(pub, encryptionChunkSize))
.through(crypto.encrypt(EncryptionConfig().withChunkSize(encryptionChunkSize), pub))
.through(crypto.armor(encryptionChunkSize))
.through(crypto.decrypt(collection, passphrase, decryptionChunkSize))
.compile
Expand Down Expand Up @@ -185,7 +216,7 @@ class CryptoAlgSpec
kp <- keyPairR
ring <- Resource.eval(pgpKeyRingGenerator[IO](keyRingId, kp, passphrase)).map(_.generateSecretKeyRing())
roundTrip <- Stream.emits(materializedBytes)
.through(crypto.encrypt(kp.getPublicKey, encryptionChunkSize))
.through(crypto.encrypt(EncryptionConfig().withChunkSize(encryptionChunkSize), kp.getPublicKey))
.through(crypto.armor(encryptionChunkSize))
.through(crypto.decrypt(ring, passphrase, decryptionChunkSize))
.compile
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ import org.bouncycastle.openpgp.operator.jcajce.{JcaPGPContentSignerBuilder, Jca
import org.scalacheck.Arbitrary
import org.scalacheck.Arbitrary.arbitrary
import org.scalacheck.effect.PropF.forAllF
import org.typelevel.log4cats._
import org.typelevel.log4cats.noop.NoOpLogger
import com.eed3si9n.expecty.Expecty.{assert => Assert}
import org.typelevel.log4cats.LoggerFactory
import org.typelevel.log4cats.slf4j.Slf4jFactory

import scala.jdk.CollectionConverters._

Expand All @@ -25,10 +25,8 @@ class PGPKeyAlgSpec
with ScalaCheckEffectSuite
with PgpArbitraries
with CryptoArbitraries {
private implicit val L: LoggerFactory[IO] = new LoggerFactory[IO] {
override def getLoggerFromName(name: String): SelfAwareStructuredLogger[IO] = NoOpLogger[IO]
override def fromName(name: String): IO[SelfAwareStructuredLogger[IO]] = NoOpLogger[IO].pure[IO]
}

private implicit val loggerFactory: LoggerFactory[IO] = Slf4jFactory.create[IO]

test("PGPKeyAlg should load a PGPPublicKey from armored public key") {
val key =
Expand Down

0 comments on commit cbc14c8

Please sign in to comment.