diff --git a/project/ProjectPlugin.scala b/project/ProjectPlugin.scala index ca7998c..24f424e 100644 --- a/project/ProjectPlugin.scala +++ b/project/ProjectPlugin.scala @@ -27,6 +27,7 @@ object ProjectPlugin extends AutoPlugin { val scalaTest: String = "3.1.0" val scalatestScalacheck: String = "3.1.0.1" val scalacheckShapeless: String = "1.2.3" + val protocJar: String = "3.11.1" } } @@ -67,7 +68,8 @@ object ProjectPlugin extends AutoPlugin { "com.beachape" %% "enumeratum" % V.enumeratum, "org.scalatest" %% "scalatest" % V.scalaTest % Test, "org.scalatestplus" %% "scalacheck-1-14" % V.scalatestScalacheck % Test, - "com.github.alexarchambault" %% "scalacheck-shapeless_1.14" % V.scalacheckShapeless % Test + "com.github.alexarchambault" %% "scalacheck-shapeless_1.14" % V.scalacheckShapeless % Test, + "com.github.os72" % "protoc-jar" % V.protocJar % Test ), orgScriptTaskListSetting := List( (clean in Global).asRunnableItemFull, diff --git a/src/test/resources/proto/ProtocComparisonSpec.proto b/src/test/resources/proto/ProtocComparisonSpec.proto new file mode 100644 index 0000000..bc232d2 --- /dev/null +++ b/src/test/resources/proto/ProtocComparisonSpec.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +message MessageOne { + double dbl = 1; + bool boolean = 2; +} + +message MessageTwo { + int64 long = 1; + MessageOne messageOne = 2; + MessageOne messageOneOption = 3; + oneof coproduct { + MessageOne a = 4; + MessageTwo b = 5; // recursion! + int32 c = 6; + } +} + +message MessageThree { + int32 int = 2; + repeated int32 packedInts = 3; + repeated int32 unpackedInts = 4 [packed=false]; + int32 intOption = 5; + map stringStringMap = 6; + string _string = 7; + bytes _bytes = 8; + MessageOne messageOne = 9; + MessageOne messageOneOption = 10; + MessageTwo messageTwo = 11; + map intMessageTwoMap = 12; +} diff --git a/src/test/scala/pbdirect/ProtocComparisonSpec.scala b/src/test/scala/pbdirect/ProtocComparisonSpec.scala new file mode 100644 index 0000000..13aaed0 --- /dev/null +++ b/src/test/scala/pbdirect/ProtocComparisonSpec.scala @@ -0,0 +1,189 @@ +package pbdirect + +import org.scalatest.flatspec._ +import org.scalatestplus.scalacheck.Checkers +import org.scalacheck.ScalacheckShapeless._ +import org.scalacheck.Prop._ +import com.github.os72.protocjar._ +import scala.sys.process._ +import java.io._ +import shapeless._ + +class ProtocComparisonSpec extends AnyFlatSpec with Checkers { + import ProtocComparisonSpec._ + + implicit override val generatorDrivenConfig = + PropertyCheckConfiguration(minSuccessful = 500) + + val protoc: File = Protoc.extractProtoc(ProtocVersion.PROTOC_VERSION, true) + val workingDir: File = new File(".") + val protoFile: File = new File("src/test/resources/proto/ProtocComparisonSpec.proto") + val protocCommand = + s"${protoc.getAbsolutePath} --proto_path=${workingDir.getAbsolutePath} --encode=MessageThree $protoFile" + + "pbdirect" should "write the same bytes as protoc does" in check { + forAllNoShrink { (message: MessageThree) => + val pbdirectOutputBytes = message.toPB.toList + + val textFormattedMessage = TextFormatEncoding.messageThree(message) + val in = new ByteArrayInputStream(textFormattedMessage.getBytes) + val out = new ByteArrayOutputStream() + protocCommand.#<(in).#>(out).! + val protocOutputBytes = out.toByteArray.toList + + val label = + s"""|_bytes = ${message._bytes.toList} + | + |text formatted message = + |$textFormattedMessage + | + |binary output of pbdirect = + |$pbdirectOutputBytes + | + |binary output of protoc = + |$protocOutputBytes""".stripMargin + + label |: pbdirectOutputBytes == protocOutputBytes + } + } +} + +object ProtocComparisonSpec { + + case class MessageOne( + dbl: Double, + boolean: Boolean + ) + + case class MessageTwo( + long: Long, + messageOne: MessageOne, + messageOneOption: Option[MessageOne], + @pbIndex(4, 5, 6) coproduct: Option[MessageOne :+: MessageTwo :+: Int :+: CNil] + ) + + case class MessageThree( + @pbIndex(2) int: Int, + @pbIndex(3) packedInts: List[Int], + @pbIndex(4) @pbUnpacked unpackedInts: List[Int], + @pbIndex(5) intOption: Option[Int], + @pbIndex(6) stringStringMap: Map[String, String], + @pbIndex(7) _string: String, + @pbIndex(8) _bytes: Array[Byte], + @pbIndex(9) messageOne: MessageOne, + @pbIndex(10) messageOneOption: Option[MessageOne], + @pbIndex(11) messageTwo: MessageTwo, + @pbIndex(12) intMessageTwoMap: Map[Int, MessageTwo] + ) + + object TextFormatEncoding { + + /* + * Note: + * we could define some kind of Encoder[A] type class + * to encode a value of type A as protobuf text format. + * But then we'd be basically reinventing the whole + * library but for text instead of binary, and then using + * that to test the behaviour of the library. Feels a bit + * weird. + * + * So instead we keep things simple: write out the encoder + * by hand and only support the exact types used in the test. + */ + + def indent(string: String): String = + string.split('\n').map(line => s" $line").mkString("\n") + + def messageOne(m: MessageOne): String = + s"""|dbl: ${m.dbl} + |boolean: ${m.boolean}""".stripMargin + + def embeddedMessageOne(m: MessageOne): String = + s"""|{ + |${indent(messageOne(m))} + |}""".stripMargin + + def option[A](fieldName: String, opt: Option[A])(f: A => String): String = + opt.fold("")(a => s"$fieldName: ${f(a)}") + + def coproduct(cop: MessageOne :+: MessageTwo :+: Int :+: CNil): String = { + cop match { + case Inl(messageOne) => s"a: ${embeddedMessageOne(messageOne)}" + case Inr(Inl(messageTwo)) => s"b: ${embeddedMessageTwo(messageTwo)}" + case Inr(Inr(Inl(int))) => s"c: $int" + case _ => "" + } + } + + def messageTwo(m: MessageTwo): String = + s"""|long: ${m.long} + |messageOne: ${embeddedMessageOne(m.messageOne)} + |${option("messageOneOption", m.messageOneOption)(embeddedMessageOne)} + |${m.coproduct.fold("")(coproduct)} + |""".stripMargin + + def embeddedMessageTwo(m: MessageTwo): String = + s"""|{ + |${indent(messageTwo(m))} + |}""".stripMargin + + def bytes(xs: Array[Byte]): String = + xs.map(b => s"\\${(b.toInt & 0xFF).toOctalString}").mkString + + def string(x: String): String = { + val escaped = x + .replaceAllLiterally("""\""", """\\""") // escape backslashes + .replaceAllLiterally("\n", "\\n") // escape newlines and other control characters + .replaceAllLiterally("\r", "\\r") + .replaceAllLiterally("\b", "\\b") + .replaceAllLiterally("\f", "\\f") + .replaceAllLiterally("\t", "\\t") + .replaceAllLiterally("\u0000", "\\0") + .replaceAllLiterally(""""""", """\"""") // esape double quotes + s""""$escaped"""" // wrap in double quotes + } + + def stringStringMap(map: Map[String, String]): String = { + map + .map { + case (key, value) => + s"""|stringStringMap: { + | key: ${string(key)} + | value: ${string(value)} + |}""".stripMargin + } + .mkString("\n") + } + + def intMessageTwoMap(map: Map[Int, MessageTwo]): String = { + map + .map { + case (key, value) => + s"""|intMessageTwoMap: { + | key: ${key} + | value: { + |${indent(indent(messageTwo(value)))} + | } + |}""".stripMargin + } + .mkString("\n") + } + + def messageThree(m: MessageThree): String = { + s"""|int: ${m.int} + |packedInts: [${m.packedInts.mkString(", ")}] + |unpackedInts: [${m.unpackedInts.mkString(", ")}] + |${option("intOption", m.intOption)(_.toString)} + |${stringStringMap(m.stringStringMap)} + |_string: ${string(m._string)} + |_bytes: "${bytes(m._bytes)}" + |messageOne: ${embeddedMessageOne(m.messageOne)} + |${option("messageOneOption", m.messageOneOption)(embeddedMessageOne)} + |messageTwo: ${embeddedMessageTwo(m.messageTwo)} + |${intMessageTwoMap(m.intMessageTwoMap)} + |""".stripMargin + } + + } + +}