Skip to content

Commit

Permalink
Merge pull request #2 from SandishKumarHN/redo-array
Browse files Browse the repository at this point in the history
unit tests for protobuf repeated message
  • Loading branch information
SandishKumarHN committed Aug 31, 2022
2 parents 4c9bd74 + 7d60f9e commit 85884f7
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 0 deletions.
19 changes: 19 additions & 0 deletions connector/proto/src/test/resources/protobuf/repeated_message.desc
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

�
Bconnector/proto/src/test/resources/protobuf/repeated_message.protoorg.apache.spark.sql.proto"�
BasicMessage
id (Rid!
string_value ( R stringValue
int32_value (R
int32Value
int64_value (R
int64Value!
double_value (R doubleValue
float_value (R
floatValue

bool_value (R boolValue
bytes_value ( R
bytesValue"`
RepeatedMessageM
basic_message ( 2(.org.apache.spark.sql.proto.BasicMessageR basicMessageBBRepeatedMessageProtosbproto3
Expand Down
22 changes: 22 additions & 0 deletions connector/proto/src/test/resources/protobuf/repeated_message.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// To compile and create test class:
// protoc --java_out=connector/proto/src/test/resources/protobuf/ connector/proto/src/test/resources/protobuf/repeated_message.proto
// protoc --descriptor_set_out=connector/proto/src/test/resources/protobuf/repeated_message.desc --java_out=connector/proto/src/test/resources/protobuf/org/apache/spark/sql/proto/ connector/proto/src/test/resources/protobuf/repeated_message.proto

syntax = "proto3";

package org.apache.spark.sql.proto;
option java_outer_classname = "RepeatedMessageProtos";

message BasicMessage {
int64 id = 1;
string string_value = 2;
int32 int32_value = 3;
int64 int64_value = 4;
double double_value = 5;
float float_value = 6;
bool bool_value = 7;
bytes bytes_value = 8;
}
message RepeatedMessage {
repeated BasicMessage basic_message = 1;
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,20 @@
*/
package org.apache.spark.sql.proto

import com.google.protobuf.ByteString
import org.apache.spark.SparkException
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.proto.SimpleMessageRepeatedProtos.SimpleMessageRepeated
import org.apache.spark.sql.proto.RepeatedMessageProtos.{BasicMessage, RepeatedMessage}
import org.apache.spark.sql.proto.SimpleMessageEnumProtos.{BasicEnum, SimpleMessageEnum}
import org.apache.spark.sql.proto.SimpleMessageMapProtos.SimpleMessageMap
import org.apache.spark.sql.proto.MessageMultipleMessage.{IncludedExample, MultipleExample, OtherExample}
import org.apache.spark.sql.functions.{lit, struct}
import org.apache.spark.sql.proto.SimpleMessageEnumProtos.SimpleMessageEnum.NestedEnum

import java.util

class ProtoFunctionsSuite extends QueryTest with SharedSparkSession with Serializable {
import testImplicits._

Expand Down Expand Up @@ -108,6 +112,62 @@ class ProtoFunctionsSuite extends QueryTest with SharedSparkSession with Seriali
checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
}

test("roundtrip in from_proto and to_proto - Repeated Message Once") {
val messagePath = testFile("protobuf/repeated_message.desc").replace("file:/", "/")
val basicMessage = BasicMessage.newBuilder()
.setId(1111L)
.setStringValue("value")
.setInt32Value(12345)
.setInt64Value(0x90000000000L)
.setDoubleValue(10000000000.0D)
.setBoolValue(true)
.setBytesValue(ByteString.copyFromUtf8("ProtobufDeserializer"))
.build()
val repeatedMessage = RepeatedMessage.newBuilder()
.addBasicMessage(basicMessage)
.build()

val df = Seq(repeatedMessage.toByteArray).toDF("value")
val fromProtoDF = df.select(functions.from_proto($"value", messagePath, "RepeatedMessage").as("value_from"))
val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from", messagePath, "RepeatedMessage").as("value_to"))
val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", messagePath, "RepeatedMessage").as("value_to_from"))
checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
}

test("roundtrip in from_proto and to_proto - Repeated Message Twice") {
val messagePath = testFile("protobuf/repeated_message.desc").replace("file:/", "/")
val baseArray = new util.ArrayList[BasicMessage]();
val basicMessage1 = BasicMessage.newBuilder()
.setId(1111L)
.setStringValue("value1")
.setInt32Value(12345)
.setInt64Value(0x20000000000L)
.setDoubleValue(10000000000.0D)
.setBoolValue(true)
.setBytesValue(ByteString.copyFromUtf8("ProtobufDeserializer1"))
.build()
val basicMessage2 = BasicMessage.newBuilder()
.setId(2222L)
.setStringValue("value2")
.setInt32Value(54321)
.setInt64Value(0x20000000000L)
.setDoubleValue(20000000000.0D)
.setBoolValue(false)
.setBytesValue(ByteString.copyFromUtf8("ProtobufDeserializer2"))
.build()
baseArray.add(basicMessage1)
baseArray.add(basicMessage2)
val repeatedMessage = RepeatedMessage.newBuilder()
.addAllBasicMessage(baseArray)
.build()

val df = Seq(repeatedMessage.toByteArray).toDF("value")
val fromProtoDF = df.select(functions.from_proto($"value", messagePath, "RepeatedMessage").as("value_from"))
val toProtoDF = fromProtoDF.select(functions.to_proto($"value_from", messagePath, "RepeatedMessage").as("value_to"))
val toFromProtoDF = toProtoDF.select(functions.from_proto($"value_to", messagePath, "RepeatedMessage").as("value_to_from"))
checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
}

test("roundtrip in from_proto and to_proto - Map") {
val messagePath = testFile("protobuf/message_with_map.desc").replace("file:/", "/")
val repeatedMessage = SimpleMessageMap.newBuilder()
Expand Down

0 comments on commit 85884f7

Please sign in to comment.