Skip to content

Commit

Permalink
THRIFT-5337 Go set fields write improvement
Browse files Browse the repository at this point in the history
Client: go

There is a duplicate elements check for set in writeFields* function,
and it compares elements using reflect.DeepEqual which is expensive.

It's much faster that generates a *Equals* function for set elements and
call it in duplicate elements check, especially for nested struct
element.

Closes #2307.
  • Loading branch information
simon0-o authored and fishy committed Feb 4, 2021
1 parent 93d2099 commit 4aaef75
Show file tree
Hide file tree
Showing 4 changed files with 549 additions and 9 deletions.
199 changes: 192 additions & 7 deletions compiler/cpp/src/thrift/generate/t_go_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ class t_go_generator : public t_generator {
const string& tstruct_name,
bool is_result = false,
bool uses_countsetfields = false);
void generate_go_struct_equals(std::ostream& out, t_struct* tstruct, const string& tstruct_name);
void generate_go_function_helpers(t_function* tfunction);
void get_publicized_name_and_def_value(t_field* tfield,
string* OUT_pub_name,
Expand Down Expand Up @@ -229,6 +230,12 @@ class t_go_generator : public t_generator {

void generate_serialize_list_element(std::ostream& out, t_list* tlist, std::string iter);

void generate_go_equals(std::ostream& out, t_type* ttype, string tgt, string src);

void generate_go_equals_struct(std::ostream& out, t_type* ttype, string tgt, string src);

void generate_go_equals_container(std::ostream& out, t_type* ttype, string tgt, string src);

void generate_go_docstring(std::ostream& out, t_struct* tstruct);

void generate_go_docstring(std::ostream& out, t_function* tfunction);
Expand Down Expand Up @@ -307,6 +314,7 @@ class t_go_generator : public t_generator {
std::set<std::string> package_identifiers_set_;
std::string read_method_name_;
std::string write_method_name_;
std::string equals_method_name_;

std::set<std::string> commonInitialisms;

Expand Down Expand Up @@ -724,6 +732,7 @@ void t_go_generator::init_generator() {
read_method_name_ = "Read";
write_method_name_ = "Write";
}
equals_method_name_ = "Equals";

while (true) {
// TODO: Do better error checking here.
Expand Down Expand Up @@ -912,7 +921,6 @@ string t_go_generator::go_imports_begin(bool consts) {
std::vector<string> system_packages;
system_packages.push_back("bytes");
system_packages.push_back("context");
system_packages.push_back("reflect");
// If not writing constants, and there are enums, need extra imports.
if (!consts && get_program()->get_enums().size() > 0) {
system_packages.push_back("database/sql/driver");
Expand All @@ -937,7 +945,6 @@ string t_go_generator::go_imports_end() {
"var _ = thrift.ZERO\n"
"var _ = fmt.Printf\n"
"var _ = context.Background\n"
"var _ = reflect.DeepEqual\n"
"var _ = time.Now\n"
"var _ = bytes.Equal\n\n");
}
Expand Down Expand Up @@ -1482,6 +1489,9 @@ void t_go_generator::generate_go_struct_definition(ostream& out,
generate_isset_helpers(out, tstruct, tstruct_name, is_result);
generate_go_struct_reader(out, tstruct, tstruct_name, is_result);
generate_go_struct_writer(out, tstruct, tstruct_name, is_result, num_setable > 0);
if (!is_result && !is_args) {
generate_go_struct_equals(out, tstruct, tstruct_name);
}

out << indent() << "func (p *" << tstruct_name << ") String() string {" << endl;
out << indent() << " if p == nil {" << endl;
Expand Down Expand Up @@ -1851,6 +1861,61 @@ void t_go_generator::generate_go_struct_writer(ostream& out,
}
}

void t_go_generator::generate_go_struct_equals(ostream& out,
t_struct* tstruct,
const string& tstruct_name) {
string name(tstruct->get_name());
const vector<t_field*>& fields = tstruct->get_sorted_members();
vector<t_field*>::const_iterator f_iter;
indent(out) << "func (p *" << tstruct_name << ") " << equals_method_name_ << "(other *"
<< tstruct_name << ") bool {" << endl;
indent_up();

string field_name;
string publicize_field_name;
out << indent() << "if p == other {" << endl;
indent_up();
out << indent() << "return true" << endl;
indent_down();
out << indent() << "} else if p == nil || other == nil {" << endl;
indent_up();
out << indent() << "return false" << endl;
indent_down();
out << indent() << "}" << endl;

for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
field_name = (*f_iter)->get_name();
t_type* field_type = (*f_iter)->get_type();
publicize_field_name = publicize(field_name);
string goType = type_to_go_type_with_opt(field_type, is_pointer_field(*f_iter));

string tgt = "p." + publicize_field_name;
string src = "other." + publicize_field_name;
t_type* ttype = field_type->get_true_type();
// Compare field contents
if (is_pointer_field(*f_iter)
&& (ttype->is_base_type() || ttype->is_enum() || ttype->is_container())) {
string tgtv = "(*" + tgt + ")";
string srcv = "(*" + src + ")";
out << indent() << "if " << tgt << " != " << src << " {" << endl;
indent_up();
out << indent() << "if " << tgt << " == nil || " << src << " == nil {" << endl;
indent_up();
out << indent() << "return false" << endl;
indent_down();
out << indent() << "}" << endl;
generate_go_equals(out, field_type, tgtv, srcv);
indent_down();
out << indent() << "}" << endl;
} else {
generate_go_equals(out, field_type, tgt, src);
}
}
out << indent() << "return true" << endl;
indent_down();
out << indent() << "}" << endl << endl;
}

/**
* Generates a thrift service.
*
Expand Down Expand Up @@ -3389,15 +3454,30 @@ void t_go_generator::generate_serialize_container(ostream& out,
} else if (ttype->is_set()) {
t_set* tset = (t_set*)ttype;
out << indent() << "for i := 0; i<len(" << prefix << "); i++ {" << endl;
out << indent() << " for j := i+1; j<len(" << prefix << "); j++ {" << endl;
indent_up();
out << indent() << "for j := i+1; j<len(" << prefix << "); j++ {" << endl;
indent_up();
string wrapped_prefix = prefix;
if (pointer_field) {
wrapped_prefix = "(" + prefix + ")";
}
out << indent() << " if reflect.DeepEqual(" << wrapped_prefix << "[i]," << wrapped_prefix << "[j]) { " << endl;
out << indent() << " return thrift.PrependError(\"\", fmt.Errorf(\"%T error writing set field: slice is not unique\", " << wrapped_prefix << "[i]))" << endl;
out << indent() << " }" << endl;
out << indent() << " }" << endl;
string goType = type_to_go_type(tset->get_elem_type());
out << indent() << "if func(tgt, src " << goType << ") bool {" << endl;
indent_up();
generate_go_equals(out, tset->get_elem_type(), "tgt", "src");
out << indent() << "return true" << endl;
indent_down();
out << indent() << "}(" << wrapped_prefix << "[i], " << wrapped_prefix << "[j]) {" << endl;
indent_up();
out << indent()
<< "return thrift.PrependError(\"\", fmt.Errorf(\"%T error writing set field: slice is not "
"unique\", "
<< wrapped_prefix << "))" << endl;
indent_down();
out << indent() << "}" << endl;
indent_down();
out << indent() << "}" << endl;
indent_down();
out << indent() << "}" << endl;
out << indent() << "for _, v := range " << prefix << " {" << endl;
indent_up();
Expand Down Expand Up @@ -3463,6 +3543,111 @@ void t_go_generator::generate_serialize_list_element(ostream& out, t_list* tlist
generate_serialize_field(out, &efield, prefix);
}

/**
* Compares any type
*/
void t_go_generator::generate_go_equals(ostream& out, t_type* ori_type, string tgt, string src) {

t_type* ttype = get_true_type(ori_type);
// Do nothing for void types
if (ttype->is_void()) {
throw "compiler error: cannot generate equals for void type: " + tgt;
}

if (ttype->is_struct() || ttype->is_xception()) {
generate_go_equals_struct(out, ttype, tgt, src);
} else if (ttype->is_container()) {
generate_go_equals_container(out, ttype, tgt, src);
} else if (ttype->is_base_type() || ttype->is_enum()) {
out << indent() << "if ";
if (ttype->is_base_type()) {
t_base_type::t_base tbase = ((t_base_type*)ttype)->get_base();
switch (tbase) {
case t_base_type::TYPE_VOID:
throw "compiler error: cannot equals void: " + tgt;
break;

case t_base_type::TYPE_STRING:
if (ttype->is_binary()) {
out << "bytes.Compare(" << tgt << ", " << src << ") != 0";
} else {
out << tgt << " != " << src;
}
break;

case t_base_type::TYPE_BOOL:
case t_base_type::TYPE_I8:
case t_base_type::TYPE_I16:
case t_base_type::TYPE_I32:
case t_base_type::TYPE_I64:
case t_base_type::TYPE_DOUBLE:
out << tgt << " != " << src;
break;

default:
throw "compiler error: no Go name for base type " + t_base_type::t_base_name(tbase);
}
} else if (ttype->is_enum()) {
out << tgt << " != " << src;
}

out << " { return false }" << endl;
} else {
throw "compiler error: Invalid type in generate_go_equals '" + ttype->get_name() + "' for '"
+ tgt + "'";
}
}

/**
* Compares the members of a struct
*/
void t_go_generator::generate_go_equals_struct(ostream& out,
t_type* ttype,
string tgt,
string src) {
(void)ttype;
out << indent() << "if !" << tgt << "." << equals_method_name_ << "(" << src
<< ") { return false }" << endl;
}

/**
* Compares any container type
*/
void t_go_generator::generate_go_equals_container(ostream& out,
t_type* ttype,
string tgt,
string src) {
out << indent() << "if len(" << tgt << ") != len(" << src << ") { return false }" << endl;
if (ttype->is_map()) {
t_map* tmap = (t_map*)ttype;
out << indent() << "for k, _tgt := range " << tgt << " {" << endl;
indent_up();
string element_source = tmp("_src");
out << indent() << element_source << " := " << src << "[k]" << endl;
generate_go_equals(out, tmap->get_val_type(), "_tgt", element_source);
indent_down();
indent(out) << "}" << endl;
} else if (ttype->is_list() || ttype->is_set()) {
t_type* elem;
if (ttype->is_list()) {
t_list* temp = (t_list*)ttype;
elem = temp->get_elem_type();
} else {
t_set* temp = (t_set*)ttype;
elem = temp->get_elem_type();
}
out << indent() << "for i, _tgt := range " << tgt << " {" << endl;
indent_up();
string element_source = tmp("_src");
out << indent() << element_source << " := " << src << "[i]" << endl;
generate_go_equals(out, elem, "_tgt", element_source);
indent_down();
indent(out) << "}" << endl;
} else {
throw "INVALID TYPE IN generate_go_equals_container '" + ttype->get_name();
}
}

/**
* Generates the docstring for a given struct.
*/
Expand Down
109 changes: 109 additions & 0 deletions lib/go/test/EqualsTest.thrift
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
typedef i8 mybyte
typedef string mystr
typedef binary mybin

enum EnumFoo {
e1
e2
}

struct BasicEqualsFoo {
1: bool BoolFoo,
2: optional bool OptBoolFoo,
3: i8 I8Foo,
4: optional i8 OptI8Foo,
5: i16 I16Foo,
6: optional i16 OptI16Foo,
7: i32 I32Foo,
8: optional i32 OptI32Foo,
9: i64 I64Foo,
10: optional i64 OptI64Foo,
11: double DoubleFoo,
12: optional double OptDoubleFoo,
13: string StrFoo,
14: optional string OptStrFoo,
15: binary BinFoo,
16: optional binary OptBinFoo,
17: EnumFoo EnumFoo,
18: optional EnumFoo OptEnumFoo,
19: mybyte MyByteFoo,
20: optional mybyte OptMyByteFoo,
21: mystr MyStrFoo,
22: optional mystr OptMyStrFoo,
23: mybin MyBinFoo,
24: optional mybin OptMyBinFoo,
}

struct StructEqualsFoo {
1: BasicEqualsFoo StructFoo,
2: optional BasicEqualsFoo OptStructFoo,
}

struct ListEqualsFoo {
1: list<i64> I64ListFoo,
2: optional list<i64> OptI64ListFoo,
3: list<string> StrListFoo,
4: optional list<string> OptStrListFoo,
5: list<binary> BinListFoo,
6: optional list<binary> OptBinListFoo,
7: list<BasicEqualsFoo> StructListFoo,
8: optional list<BasicEqualsFoo> OptStructListFoo,
9: list<list<i64>> I64ListListFoo,
10: optional list<list<i64>> OptI64ListListFoo,
11: list<set<i64>> I64SetListFoo,
12: optional list<set<i64>> OptI64SetListFoo,
13: list<map<i64, string>> I64StringMapListFoo,
14: optional list<map<i64, string>> OptI64StringMapListFoo,
15: list<mybyte> MyByteListFoo,
16: optional list<mybyte> OptMyByteListFoo,
17: list<mystr> MyStrListFoo,
18: optional list<mystr> OptMyStrListFoo,
19: list<mybin> MyBinListFoo,
20: optional list<mybin> OptMyBinListFoo,
}

struct SetEqualsFoo {
1: set<i64> I64SetFoo,
2: optional set<i64> OptI64SetFoo,
3: set<string> StrSetFoo,
4: optional set<string> OptStrSetFoo,
5: set<binary> BinSetFoo,
6: optional set<binary> OptBinSetFoo,
7: set<BasicEqualsFoo> StructSetFoo,
8: optional set<BasicEqualsFoo> OptStructSetFoo,
9: set<list<i64>> I64ListSetFoo,
10: optional set<list<i64>> OptI64ListSetFoo,
11: set<set<i64>> I64SetSetFoo,
12: optional set<set<i64>> OptI64SetSetFoo,
13: set<map<i64, string>> I64StringMapSetFoo,
14: optional set<map<i64, string>> OptI64StringMapSetFoo,
15: set<mybyte> MyByteSetFoo,
16: optional set<mybyte> OptMyByteSetFoo,
17: set<mystr> MyStrSetFoo,
18: optional set<mystr> OptMyStrSetFoo,
19: set<mybin> MyBinSetFoo,
20: optional set<mybin> OptMyBinSetFoo,
}

struct MapEqualsFoo {
1: map<i64, string> I64StrMapFoo,
2: optional map<i64, string> OptI64StrMapFoo,
3: map<string, i64> StrI64MapFoo,
4: optional map<string, i64> OptStrI64MapFoo,
5: map<BasicEqualsFoo, binary> StructBinMapFoo,
6: optional map<BasicEqualsFoo, binary> OptStructBinMapFoo,
7: map<binary, BasicEqualsFoo> BinStructMapFoo,
8: optional map<binary, BasicEqualsFoo> OptBinStructMapFoo,
9: map<i64, list<i64>> I64I64ListMapFoo,
10: optional map<i64, list<i64>> OptI64I64ListMapFoo,
11: map<i64, set<i64>> I64I64SetMapFoo,
12: optional map<i64, set<i64>> OptI64I64SetMapFoo,
13: map<i64, map<i64, string>> I64I64StringMapMapFoo,
14: optional map<i64, map<i64, string>> OptI64I64StringMapMapFoo,
15: map<mystr, mybin> MyStrMyBinMapFoo,
16: optional map<mystr, mybin> OptMyStrMyBinMapFoo,
17: map<i64, mybyte> Int64MyByteMapFoo,
18: optional map<i64, mybyte> OptInt64MyByteMapFoo,
19: map<mybyte, i64> MyByteInt64MapFoo,
20: optional map<mybyte, i64> OptMyByteInt64MapFoo,
}
Loading

0 comments on commit 4aaef75

Please sign in to comment.