diff --git a/speedy_antlr_tool/__about__.py b/speedy_antlr_tool/__about__.py index 9c73af2..f708a9b 100644 --- a/speedy_antlr_tool/__about__.py +++ b/speedy_antlr_tool/__about__.py @@ -1 +1 @@ -__version__ = "1.3.1" +__version__ = "1.3.2" diff --git a/speedy_antlr_tool/extractor.py b/speedy_antlr_tool/extractor.py index 6b2660c..a89d0f9 100644 --- a/speedy_antlr_tool/extractor.py +++ b/speedy_antlr_tool/extractor.py @@ -53,7 +53,7 @@ def get_rule_labels(context_cls:antlr4.ParserRuleContext) -> List[str]: labels = [] for line in lines: m = re.match(r'self\.(\w+)\s*=\s*None', line.strip()) - if m: + if m and not m.group(1).startswith("_"): labels.append(m.group(1)) return labels diff --git a/speedy_antlr_tool/templates/sa_X_cpp_parser.cpp b/speedy_antlr_tool/templates/sa_X_cpp_parser.cpp index 20e8752..b80205e 100644 --- a/speedy_antlr_tool/templates/sa_X_cpp_parser.cpp +++ b/speedy_antlr_tool/templates/sa_X_cpp_parser.cpp @@ -7,6 +7,7 @@ #include #include +#include #include "antlr4-runtime.h" #include "{{grammar_name}}Lexer.h" diff --git a/speedy_antlr_tool/templates/speedy_antlr.cpp b/speedy_antlr_tool/templates/speedy_antlr.cpp index 67dbdd6..12f0cf7 100644 --- a/speedy_antlr_tool/templates/speedy_antlr.cpp +++ b/speedy_antlr_tool/templates/speedy_antlr.cpp @@ -4,6 +4,7 @@ */ #include "speedy_antlr.h" +#include using namespace speedy_antlr; @@ -155,7 +156,7 @@ PyObject* Translator::convert_ctx( Py_INCREF(py_label_candidate); // Get start/stop - if(!start){ + if(!start || start==Py_None){ start = py_token; Py_INCREF(start); } @@ -172,18 +173,18 @@ PyObject* Translator::convert_ctx( } catch(PythonException &e) { Py_XDECREF(py_ctx); Py_XDECREF(py_children); + throw; } PyObject_SetAttrString(py_child, "parentCtx", py_ctx); py_label_candidate = py_child; Py_INCREF(py_label_candidate); // Get start/stop - if(i == 0) { + if(!start || start==Py_None) { start = PyObject_GetAttrString(py_child, "start"); } - if(i == ctx->children.size() - 1){ - stop = PyObject_GetAttrString(py_child, "stop"); - } + PyObject *tmp_stop = PyObject_GetAttrString(py_child, "stop"); + if (tmp_stop && tmp_stop!=Py_None) stop = tmp_stop; } else { PyErr_SetString(PyExc_RuntimeError, "Unknown child type"); throw PythonException(); diff --git a/speedy_antlr_tool/validate.py b/speedy_antlr_tool/validate.py index c4e6619..ee968d9 100644 --- a/speedy_antlr_tool/validate.py +++ b/speedy_antlr_tool/validate.py @@ -1,7 +1,7 @@ import re import inspect -from antlr4 import ParserRuleContext +from antlr4 import InputStream, ParserRuleContext from antlr4.tree.Tree import TerminalNodeImpl from antlr4.Token import Token, CommonToken @@ -14,25 +14,34 @@ def validate_top_ctx(py_ctx:ParserRuleContext, cpp_ctx:ParserRuleContext): def validate_ctx(py_ctx:ParserRuleContext, cpp_ctx:ParserRuleContext): assert type(py_ctx) == type(cpp_ctx) - assert len(py_ctx.children) == len(cpp_ctx.children) + pc = list(py_ctx.getChildren()) + cc = list(cpp_ctx.getChildren()) + assert len(pc) == len(cc) # Validate children - for i in range(len(py_ctx.children)): - if isinstance(py_ctx.children[i], TerminalNodeImpl): - validate_tnode(py_ctx.children[i], cpp_ctx.children[i]) - elif isinstance(py_ctx.children[i], ParserRuleContext): - validate_ctx(py_ctx.children[i], cpp_ctx.children[i]) + for i in range(len(pc)): + if isinstance(pc[i], TerminalNodeImpl): + validate_tnode(pc[i], cc[i]) + elif isinstance(pc[i], ParserRuleContext): + validate_ctx(pc[i], cc[i]) else: raise RuntimeError - assert py_ctx.children[i].parentCtx is py_ctx - assert cpp_ctx.children[i].parentCtx is cpp_ctx + assert pc[i].parentCtx is py_ctx + assert cc[i].parentCtx is cpp_ctx # Validate start/stop markers - validate_common_token(py_ctx.start, cpp_ctx.start) - validate_common_token(py_ctx.stop, cpp_ctx.stop) + if (py_ctx.start is not None + and py_ctx.stop is not None + and py_ctx.start.type != Token.EOF + and py_ctx.stop.type != Token.EOF + and py_ctx.start.tokenIndex <= py_ctx.stop.tokenIndex): + validate_common_token(py_ctx.start, cpp_ctx.start) + validate_common_token(py_ctx.stop, cpp_ctx.stop) # Validate labels for label in get_rule_labels(py_ctx): + if label.startswith("_"): + continue py_label = getattr(py_ctx, label) cpp_label = getattr(cpp_ctx, label) assert type(py_label) == type(cpp_label) @@ -64,4 +73,4 @@ def validate_common_token(py_tok:CommonToken, cpp_tok:CommonToken): assert py_tok.line == cpp_tok.line assert py_tok.column == cpp_tok.column assert py_tok.text == cpp_tok.text - assert py_tok.getInputStream() is cpp_tok.getInputStream() + assert isinstance(cpp_tok.getInputStream(), InputStream)