Skip to content

Commit

Permalink
Merge pull request #22 from Aniket21mathur/messagefield
Browse files Browse the repository at this point in the history
Add support for Message fields
  • Loading branch information
Aniket21mathur committed Aug 12, 2020
2 parents af0afd9 + f7082b8 commit fc3c3ed
Show file tree
Hide file tree
Showing 15 changed files with 350 additions and 3 deletions.
57 changes: 57 additions & 0 deletions src/ProtobufProtocolSupport.chpl
Expand Up @@ -260,6 +260,26 @@ module ProtobufProtocolSupport {
return val;
}

proc messageAppendBase(val, ch:writingChannel) throws {
var initialOffset = ch.offset();
ch.mark();
val._writeToOutputFile(ch);
var currentOffset = ch.offset();
ch.revert();
unsignedVarintAppend((currentOffset-initialOffset):uint, ch);
val._writeToOutputFile(ch);
}

proc messageConsumeBase(ch:readingChannel, ref messageObj, memWriter:writingChannel,
memReader:readingChannel) throws {
var s: bytes;
var (payloadLength, _) = unsignedVarintConsume(ch);
ch.readbytes(s, payloadLength:int);
memWriter.write(s);
memWriter.close();
messageObj._parseFromInputFile(memReader);
}

proc writeToOutputFileHelper(ref message, ch) throws {
ch.lock();
defer { ch.unlock(); }
Expand Down Expand Up @@ -456,6 +476,22 @@ module ProtobufProtocolSupport {
return enumConsumeBase(ch);
}

proc messageAppend(val, fieldNumber: int, ch:writingChannel) throws {
tagAppend(fieldNumber, lengthDelimited, ch);
messageAppendBase(val, ch);
}

proc messageConsume(ch:readingChannel, type messageType) throws {
var tmpMem = openmem();
var memWriter = tmpMem.writer(kind=iokind.little, locking=false);
var memReader = tmpMem.reader(kind=iokind.little, locking=false);

var tmpObj: messageType;
messageConsumeBase(ch, tmpObj, memWriter, memReader);
tmpMem.close();
return tmpObj;
}

proc consumeUnknownField(fieldNumber: int, wireType: int, ch: readingChannel): bytes throws {
/*
Opening a file, and generating a writing channel to give as an argument to the
Expand Down Expand Up @@ -905,6 +941,27 @@ module ProtobufProtocolSupport {
}
return returnList;
}

proc messageRepeatedAppend(valList, fieldNumber: int, ch: writingChannel) throws {
if valList.isEmpty() then return;
for val in valList {
tagAppend(fieldNumber, lengthDelimited, ch);
messageAppendBase(val, ch);
}
}

proc messageRepeatedConsume(ch: readingChannel, type messageType) throws {
var returnList: list(messageType);
var tmpMem = openmem();
var memWriter = tmpMem.writer(kind=iokind.little, locking=false);
var memReader = tmpMem.reader(kind=iokind.little, locking=false);

var tmpObj: messageType;
messageConsumeBase(ch, tmpObj, memWriter, memReader);
returnList.append(tmpObj);
tmpMem.close();
return returnList;
}

}

Expand Down
4 changes: 3 additions & 1 deletion src/plugin/Makefile.am
Expand Up @@ -11,7 +11,9 @@ protoc_gen_chpl_SOURCES = \
field_base.cpp\
enum.cpp\
enum_field.cpp\
repeated_enum_field.cpp
repeated_enum_field.cpp\
message_field.cpp\
repeated_message_field.cpp
protoc_gen_chpl_CPPFLAGS = \
-I.
protoc_gen_chpl_LDADD = -lprotobuf -lprotoc
4 changes: 4 additions & 0 deletions src/plugin/field_base.cpp
Expand Up @@ -83,6 +83,8 @@ namespace chapel {
return "int(64)";
case FieldDescriptor::TYPE_ENUM:
return descriptor->enum_type()->name();
case FieldDescriptor::TYPE_MESSAGE:
return descriptor->message_type()->name();
default:
GOOGLE_LOG(FATAL)<< "Unknown field type.";
return "";
Expand Down Expand Up @@ -123,6 +125,8 @@ namespace chapel {
return "sfixed64";
case FieldDescriptor::TYPE_ENUM:
return "enum";
case FieldDescriptor::TYPE_MESSAGE:
return "message";
default:
GOOGLE_LOG(FATAL)<< "Unknown field type.";
return "";
Expand Down
8 changes: 8 additions & 0 deletions src/plugin/helpers.cpp
Expand Up @@ -27,6 +27,8 @@
#include <repeated_primitive_field.h>
#include <enum_field.h>
#include <repeated_enum_field.h>
#include <message_field.h>
#include <repeated_message_field.h>
#include <field_base.h>

namespace chapel {
Expand Down Expand Up @@ -96,6 +98,12 @@ namespace chapel {

FieldGeneratorBase* CreateFieldGenerator(const FieldDescriptor* descriptor) {
switch (descriptor->type()) {
case FieldDescriptor::TYPE_MESSAGE:
if (descriptor->is_repeated()) {
return new RepeatedMessageFieldGenerator(descriptor);
} else {
return new MessageFieldGenerator(descriptor);
}
case FieldDescriptor::TYPE_ENUM:
if (descriptor->is_repeated()) {
return new RepeatedEnumFieldGenerator(descriptor);
Expand Down
20 changes: 18 additions & 2 deletions src/plugin/message.cpp
Expand Up @@ -99,7 +99,15 @@ namespace chapel {
printer->Indent();

for (int i = 0; i < descriptor_->field_count(); i++) {
if(vars[i]["proto_field_type"] == "enum") {
if(vars[i]["proto_field_type"] == "message") {
if(vars[i]["is_repeated"] == "0") {
printer->Print(vars[i],
"$proto_field_type$Append($field_name$, $field_number$, binCh);\n");
} else {
printer->Print(vars[i],
"$proto_field_type$RepeatedAppend($field_name$, $field_number$, binCh);\n");
}
} else if(vars[i]["proto_field_type"] == "enum") {
if(vars[i]["is_repeated"] == "0") {
printer->Print(vars[i],
"$proto_field_type$Append($field_name$:uint(64), $field_number$, binCh);\n");
Expand Down Expand Up @@ -141,7 +149,15 @@ namespace chapel {
for (int i = 0; i < descriptor_->field_count(); i++) {
printer->Print(vars[i],
"when $field_number$ {\n");
if(vars[i]["proto_field_type"] == "enum") {
if(vars[i]["proto_field_type"] == "message") {
if(vars[i]["is_repeated"] == "0") {
printer->Print(vars[i],
" $field_name$ = $proto_field_type$Consume(binCh, $type_name$);\n");
} else {
printer->Print(vars[i],
" $field_name$.extend($proto_field_type$RepeatedConsume(binCh, $type_name$));\n");
}
} else if(vars[i]["proto_field_type"] == "enum") {
if(vars[i]["is_repeated"] == "0") {
printer->Print(vars[i],
" $field_name$ = $proto_field_type$Consume(binCh):$type_name$;\n");
Expand Down
39 changes: 39 additions & 0 deletions src/plugin/message_field.cpp
@@ -0,0 +1,39 @@
/*
* Copyright 2020 Hewlett Packard Enterprise Development LP
* Copyright 2004-2019 Cray Inc.
* Other additional copyright holders may be indicated within.
*
* The entirety of this work is licensed under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
*
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <message_field.h>

namespace chapel {

MessageFieldGenerator::MessageFieldGenerator(const FieldDescriptor* descriptor)
: FieldGeneratorBase(descriptor) {
}

MessageFieldGenerator::~MessageFieldGenerator() {
}

void MessageFieldGenerator::GenerateMembers(Printer* printer) {
printer->Print(
variables_,
"var $name$: $type_name$;\n"
);
}

} // namespace chapel
44 changes: 44 additions & 0 deletions src/plugin/message_field.h
@@ -0,0 +1,44 @@
/*
* Copyright 2020 Hewlett Packard Enterprise Development LP
* Copyright 2004-2019 Cray Inc.
* Other additional copyright holders may be indicated within.
*
* The entirety of this work is licensed under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
*
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PB_MESSAGE_FIELD_HH
#define PB_MESSAGE_FIELD_HH

#include <google/protobuf/io/printer.h>
#include <google/protobuf/descriptor.h>

#include <field_base.h>

namespace chapel {

using namespace google::protobuf;
using namespace google::protobuf::io;

class MessageFieldGenerator : public FieldGeneratorBase {
public:
MessageFieldGenerator(const FieldDescriptor* descriptor);
~MessageFieldGenerator();

void GenerateMembers(Printer* printer);
};

} // namespace chapel

#endif /* PB_MESSAGE_FIELD_HH */
1 change: 1 addition & 0 deletions src/plugin/reflection_class.cpp
Expand Up @@ -61,6 +61,7 @@ namespace chapel {
for (int i = 0; i < file_->message_type_count(); i++) {
MessageGenerator messageGenerator(file_->message_type(i));
messageGenerator.Generate(printer);
printer->Print("\n");
}
}

Expand Down
38 changes: 38 additions & 0 deletions src/plugin/repeated_message_field.cpp
@@ -0,0 +1,38 @@
/*
* Copyright 2020 Hewlett Packard Enterprise Development LP
* Copyright 2004-2019 Cray Inc.
* Other additional copyright holders may be indicated within.
*
* The entirety of this work is licensed under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
*
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <repeated_message_field.h>

namespace chapel {

RepeatedMessageFieldGenerator::RepeatedMessageFieldGenerator(
const FieldDescriptor* descriptor)
: FieldGeneratorBase(descriptor) {
}

RepeatedMessageFieldGenerator::~RepeatedMessageFieldGenerator() {

}

void RepeatedMessageFieldGenerator::GenerateMembers(Printer* printer) {
printer->Print(variables_, "var $name$: list($type_name$);\n");
}

} // namespace chapel
44 changes: 44 additions & 0 deletions src/plugin/repeated_message_field.h
@@ -0,0 +1,44 @@
/*
* Copyright 2020 Hewlett Packard Enterprise Development LP
* Copyright 2004-2019 Cray Inc.
* Other additional copyright holders may be indicated within.
*
* The entirety of this work is licensed under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
*
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PB_REPEATED_MESSAGE_FIELD_HH
#define PB_REPEATED_MESSAGE_FIELD_HH

#include <google/protobuf/io/printer.h>
#include <google/protobuf/descriptor.h>

#include <field_base.h>

namespace chapel {

using namespace google::protobuf;
using namespace google::protobuf::io;

class RepeatedMessageFieldGenerator : public FieldGeneratorBase {
public:
RepeatedMessageFieldGenerator(const FieldDescriptor* descriptor);
~RepeatedMessageFieldGenerator();

void GenerateMembers(Printer* printer);
};

} // namespace chapel

#endif /* PB_REPEATED_MESSAGE_FIELD_HH */
16 changes: 16 additions & 0 deletions test/endToEnd/messagefield/protoFile/messagefield.proto
@@ -0,0 +1,16 @@
syntax = "proto3";

message messageA {
messageB a = 1;
repeated messageC f = 2;
}

message messageB {
int64 b = 1;
string c = 2;
}

message messageC {
int32 d = 1;
bool e = 2;
}
16 changes: 16 additions & 0 deletions test/endToEnd/messagefield/read.chpl
@@ -0,0 +1,16 @@
use messagefield;
use IO;
use List;

var messageObj = new messageA();
var file = open("out", iomode.r);
var readingChannel = file.reader();

messageObj.parseFromInputFile(readingChannel);

writeln(messageObj.a.b == 150);
writeln(messageObj.a.c == "String with spaces");
writeln(messageObj.f[0].d == 26);
writeln(messageObj.f[0].e == true);
writeln(messageObj.f[1].d == 36);
writeln(messageObj.f[1].e == false);
22 changes: 22 additions & 0 deletions test/endToEnd/messagefield/read.py
@@ -0,0 +1,22 @@
import messagefield_pb2

messageObj = messagefield_pb2.messageA()

file = open("out", "rb")
messageObj.ParseFromString(file.read())
file.close()

if messageObj.a.b != 150 or messageObj.a.c != "String with spaces":
print("false")
else:
print("true")

if messageObj.f[0].d != 26 or messageObj.f[0].e != True:
print("false")
else:
print("true")

if messageObj.f[1].d != 36 or messageObj.f[1].e != False:
print("false")
else:
print("true")

0 comments on commit fc3c3ed

Please sign in to comment.