diff --git a/devtools/decodemsg.c b/devtools/decodemsg.c index 4c332d716978..ada69aa7cc57 100644 --- a/devtools/decodemsg.c +++ b/devtools/decodemsg.c @@ -12,10 +12,14 @@ int main(int argc, char *argv[]) { const u8 *m; bool onion = false; + char *tlv_name = NULL; setup_locale(); opt_register_noarg("--onion", opt_set_bool, &onion, "Decode an error message instead of a peer message"); + opt_register_arg("--tlv", opt_set_charp, opt_show_charp, + &tlv_name, + "Deocde a TLV of this type instead of a peer message"); opt_register_noarg("--help|-h", opt_usage_and_exit, "[]" "Decode a lightning spec wire message from hex, or a series of messages from stdin", @@ -32,7 +36,12 @@ int main(int argc, char *argv[]) errx(1, "'%s' is not valid hex", argv[1]); if (onion) - printonion_type_message(m); + if (tlv_name) + printonion_type_tlv_message(tlv_name, m); + else + printonion_type_message(m); + else if (tlv_name) + printwire_type_tlv_message(tlv_name, m); else printwire_type_message(m); } else { @@ -54,7 +63,12 @@ int main(int argc, char *argv[]) } m = tal_dup_arr(f, u8, f + off, be16_to_cpu(len), 0); if (onion) - printonion_type_message(m); + if (tlv_name) + printonion_type_tlv_message(tlv_name, m); + else + printonion_type_message(m); + else if (tlv_name) + printwire_type_tlv_message(tlv_name, m); else printwire_type_message(m); off += be16_to_cpu(len); diff --git a/tools/generate-wire.py b/tools/generate-wire.py index 040b22082268..986ac32af998 100755 --- a/tools/generate-wire.py +++ b/tools/generate-wire.py @@ -293,20 +293,23 @@ def _guess_type(message, fieldname, base_size): printwire_header_templ = """void printwire_{name}(const char *fieldname, const u8 *cursor); """ -printwire_impl_templ = """void printwire_{name}(const char *fieldname, const u8 *cursor) -{{ -\tsize_t plen = tal_count(cursor); + +printwire_toplevel_tmpl = """\tsize_t plen = tal_count(cursor); \tif (fromwire_u16(&cursor, &plen) != {enum.name}) {{ \t\tprintf("WRONG TYPE?!\\n"); \t\treturn; -\t}} +\t}}""" -{subcalls} +printwire_impl_templ = """{is_internal}void printwire_{name}(const char *fieldname, const u8 *{cursor_ptr}cursor{tlv_args}) +{{ +{toplevel_msg_setup}{subcalls}{lencheck} +}} +""" +printwire_lencheck = """ \tif (plen != 0) \t\tprintf("EXTRA: %s\\n", tal_hexstr(NULL, cursor, plen)); -}} """ @@ -707,43 +710,50 @@ def print_towire(self, is_header, tlv_name): subcalls=str(subcalls), ) - def add_truncate_check(self, subcalls): + def add_truncate_check(self, subcalls, ref): # Report if truncated, otherwise print. - subcalls.append('if (!cursor) {\n' - 'printf("**TRUNCATED**\\n");\n' - 'return;\n' - '}') + call = 'if (!{}cursor) {{\nprintf("**TRUNCATED**\\n");\nreturn;\n}}'.format(ref) + subcalls.append(call) - def print_printwire_array(self, subcalls, basetype, f, num_elems): + def print_printwire_array(self, subcalls, basetype, f, num_elems, ref): + truncate_check_ref = '' if ref else '*' if f.has_array_helper(): - subcalls.append('printwire_{}_array(tal_fmt(NULL, "%s.{}", fieldname), &cursor, &plen, {});' - .format(basetype, f.name, num_elems)) + subcalls.append('printwire_{}_array(tal_fmt(NULL, "%s.{}", fieldname), {}cursor, {}plen, {});' + .format(basetype, f.name, ref, ref, num_elems)) else: subcalls.append('printf("[");') subcalls.append('for (size_t i = 0; i < {}; i++) {{' .format(num_elems)) subcalls.append('{} v;'.format(f.fieldtype.name)) if f.fieldtype.is_assignable(): - subcalls.append('v = fromwire_{}(&cursor, plen);' - .format(f.fieldtype.name, basetype)) + subcalls.append('v = fromwire_{}({}cursor, {}plen);' + .format(f.fieldtype.name, basetype, ref, ref)) else: # We don't handle this yet! assert(basetype not in varlen_structs) - subcalls.append('fromwire_{}(&cursor, &plen, &v);' - .format(basetype)) + subcalls.append('fromwire_{}({}cursor, {}plen, &v);' + .format(basetype, ref, ref)) - self.add_truncate_check(subcalls) + self.add_truncate_check(subcalls, truncate_check_ref) subcalls.append('printwire_{}(tal_fmt(NULL, "%s.{}", fieldname), &v);' .format(basetype, f.name)) subcalls.append('}') subcalls.append('printf("]");') - def print_printwire(self, is_header): + def print_printwire(self, is_header, is_tlv=False): template = printwire_header_templ if is_header else printwire_impl_templ fields = ['\t{} {};\n'.format(f.fieldtype.name, f.name) for f in self.fields if f.is_len_var] + tlv_args = '' if not is_tlv else ', size_t *plen' + ref = '&' if not is_tlv else '' + truncate_check_ref = '' if not is_tlv else '*' + + toplevel_msg_setup = '' + if not is_tlv: + toplevel_msg_setup = printwire_toplevel_tmpl.format(enum=self.enum) + subcalls = CCode() for f in self.fields: basetype = f.fieldtype.base() @@ -752,50 +762,59 @@ def print_printwire(self, is_header): subcalls.append('/*{} */'.format(c)) if f.is_len_var: - subcalls.append('{} {} = fromwire_{}(&cursor, &plen);' - .format(f.fieldtype.name, f.name, basetype)) - self.add_truncate_check(subcalls) + if f.fieldtype.is_var_int(): + subcalls.append('{} {} = fromwire_{}({}cursor, {}plen);' + .format(basetype, f.name, 'var_int', ref, ref)) + else: + subcalls.append('{} {} = fromwire_{}({}cursor, {}plen);' + .format(f.fieldtype.name, f.name, basetype, ref, ref)) + self.add_truncate_check(subcalls, truncate_check_ref) continue subcalls.append('printf("{}=");'.format(f.name)) if f.is_padding(): - subcalls.append('printwire_pad(tal_fmt(NULL, "%s.{}", fieldname), &cursor, &plen, {});' - .format(f.name, f.num_elems)) - self.add_truncate_check(subcalls) + subcalls.append('printwire_pad(tal_fmt(NULL, "%s.{}", fieldname), {}cursor, {}plen, {});' + .format(f.name, ref, ref, f.num_elems)) + self.add_truncate_check(subcalls, truncate_check_ref) elif f.is_array(): - self.print_printwire_array(subcalls, basetype, f, f.num_elems) - self.add_truncate_check(subcalls) + self.print_printwire_array(subcalls, basetype, f, f.num_elems, ref) + self.add_truncate_check(subcalls, truncate_check_ref) elif f.is_variable_size(): - self.print_printwire_array(subcalls, basetype, f, f.lenvar) - self.add_truncate_check(subcalls) + self.print_printwire_array(subcalls, basetype, f, f.lenvar, ref) + self.add_truncate_check(subcalls, truncate_check_ref) else: if f.optional: - subcalls.append("if (fromwire_bool(&cursor, &plen)) {") + subcalls.append("if (fromwire_bool({}cursor, {}plen)) {".format(ref, ref)) if f.is_assignable(): - subcalls.append('{} {} = fromwire_{}(&cursor, &plen);' - .format(f.fieldtype.name, f.name, basetype)) + subcalls.append('{} {} = fromwire_{}({}cursor, {}plen);' + .format(f.fieldtype.name, f.name, basetype, ref, ref)) else: # Don't handle these yet. assert(basetype not in varlen_structs) subcalls.append('{} {};'. format(f.fieldtype.name, f.name)) - subcalls.append('fromwire_{}(&cursor, &plen, &{});' - .format(basetype, f.name)) + subcalls.append('fromwire_{}({}cursor, {}plen, &{});' + .format(basetype, ref, ref, f.name)) - self.add_truncate_check(subcalls) + self.add_truncate_check(subcalls, truncate_check_ref) subcalls.append('printwire_{}(tal_fmt(NULL, "%s.{}", fieldname), &{});' .format(basetype, f.name, f.name)) if f.optional: subcalls.append("} else {") - self.add_truncate_check(subcalls) + self.add_truncate_check(subcalls, truncate_check_ref) subcalls.append("}") + len_check = '' if is_tlv else printwire_lencheck return template.format( + tlv_args=tlv_args, name=self.name, fields=''.join(fields), - enum=self.enum, - subcalls=str(subcalls) + toplevel_msg_setup=toplevel_msg_setup, + subcalls=str(subcalls), + lencheck=len_check, + cursor_ptr=('' if not is_tlv else '*'), + is_internal=('' if not is_tlv else 'static ') ) def print_struct(self): @@ -906,6 +925,28 @@ def print_struct(self): \t\t\tbreak; """ +print_tlv_template = """static void printwire_{tlv_name}(const char *fieldname, const u8 *cursor) +{{ +\tu8 msg_type; +\tu64 msg_size; +\tsize_t plen = tal_count(cursor); + +\twhile (cursor) {{ +\t\tmsg_type = fromwire_u8(&cursor, &plen); +\t\tmsg_size = fromwire_var_int(&cursor, &plen); +\t\tif (!cursor) +\t\t\tbreak; +\t\tswitch ((enum {tlv_name}_type)msg_type) {{ +\t\t\t{printcases} +\t\t\tdefault: +\t\t\t\tprintf("WARNING:No message matching type %d\\n", msg_type); +\t\t}} +\t}} +\tif (plen != 0) +\t\tprintf("EXTRA: %s\\n", tal_hexstr(NULL, cursor, plen)); +}} +""" + def build_tlv_fromwires(tlv_fields): fromwires = [] @@ -970,6 +1011,32 @@ def find_message(messages, name): return None +def print_tlv_printwire(tlv_name, messages): + printcases = '' + for m in messages: + printcases += 'case {enum.name}: printf("{enum.name} (size %"PRIu64"):\\n", msg_size); printwire_{name}("{name}", &cursor, &plen); break;'.format( + enum=m.enum, name=m.name, tlv_name=tlv_name) + return print_tlv_template.format( + tlv_name=tlv_name, + printcases=printcases) + + +def print_tlv_printwires(tlv_fields): + decls = [] + switches = '' + for name, messages in tlv_fields.items(): + # Print each of the message parsers + decls += [m.print_printwire(options.header, is_tlv=True) for m in messages] + + # Print the TLV body parser + decls.append(print_tlv_printwire(name, messages)) + + # Print the 'master' print_tlv_messages cases + switches += tlv_switch_template.format(tlv_name=name) + decls.append(print_master_tlv_template.format(tlv_switches=switches)) + return decls + + def find_message_with_option(messages, optional_messages, name, option): fullname = name + "_" + option.replace('-', '_') @@ -1152,6 +1219,12 @@ def build_tlv_structs(tlv_fields): {func_decls} """ +print_tlv_message_printwire_empty = """void print{enumname}_tlv_message(const char *tlv_name, const u8 *msg) +{{ +\tprintf("~~ No TLV definition found for %s ~~\\n", tlv_name); +}} +""" + print_header_template = """/* This file was generated by generate-wire.py */ /* Do not modify this file! Modify the _csv file it was generated from. */ #ifndef LIGHTNING_{idem} @@ -1162,6 +1235,8 @@ def build_tlv_structs(tlv_fields): void print{enumname}_message(const u8 *msg); +void print{enumname}_tlv_message(const char *tlv_name, const u8 *msg); + {func_decls} #endif /* LIGHTNING_{idem} */ """ @@ -1172,6 +1247,7 @@ def build_tlv_structs(tlv_fields): #include #include #include +#include #include void print{enumname}_message(const u8 *msg) @@ -1186,6 +1262,21 @@ def build_tlv_structs(tlv_fields): {func_decls} """ +print_master_tlv_template = """ +void print_tlv_message(const char *tlv_name, const u8 *msg) +{{ +\t{tlv_switches} +\tprintf("ERR: Unknown TLV message type: %s\n", tlv_name); +}} +""" + +tlv_switch_template = """ +\tif (strcmp(tlv_name, "{tlv_name}") == 0) {{ +\t\tprintwire_{tlv_name}("{tlv_name}", msg); +\t\treturn; +\t}} +""" + idem = re.sub(r'[^A-Z]+', '_', options.headerfilename.upper()) if options.printwire: if options.header: @@ -1207,7 +1298,12 @@ def build_tlv_structs(tlv_fields): printcases = ['case {enum.name}: printf("{enum.name}:\\n"); printwire_{name}("{name}", msg); return;'.format(enum=m.enum, name=m.name) for m in toplevel_messages] if options.printwire: - decls = [m.print_printwire(options.header) for m in messages + messages_with_option] + decls = [m.print_printwire(options.header) for m in toplevel_messages + messages_with_option] + if not options.header: + if len(tlv_fields): + decls += print_tlv_printwires(tlv_fields) + else: + decls += [print_tlv_message_printwire_empty.format(enumname=options.enumname)] else: towire_decls = [] fromwire_decls = []