diff --git a/tools/hrw4u/Makefile b/tools/hrw4u/Makefile index 24714279b7e..d370bf4a009 100644 --- a/tools/hrw4u/Makefile +++ b/tools/hrw4u/Makefile @@ -32,6 +32,9 @@ SCRIPT_U4WRH=scripts/u4wrh SCRIPT_LSP=scripts/hrw4u-lsp SCRIPT_KG=scripts/hrw4u-kg +# scripts/hrw4u-ast is dev-only — not packaged or shipped. Run it from the +# source tree via `uv run scripts/hrw4u-ast` or `python scripts/hrw4u-ast`. + # Shared source files (will go in hrw4u package) SHARED_FILES=src/common.py \ src/debugging.py \ @@ -56,7 +59,7 @@ SRC_FILES_HRW4U=src/visitor.py \ src/sandbox.py \ src/kg_visitor.py \ src/ast_nodes.py \ - src/ast_visitor.py + src/ast_builder.py ALL_HRW4U_FILES=$(SHARED_FILES) $(UTILS_FILES) $(SRC_FILES_HRW4U) @@ -191,7 +194,8 @@ coverage: gen coverage-open: coverage uv run python -m webbrowser "file://$(shell pwd)/htmlcov/index.html" -# Build standalone binaries (optional) +# Build standalone binaries (optional). hrw4u-ast is intentionally +# excluded — it's a dev-only inspection tool, not a shipped artifact. build: gen uv run pyinstaller --onedir --name hrw4u --strip $(SCRIPT_HRW4U) uv run pyinstaller --onedir --name u4wrh --strip $(SCRIPT_U4WRH) diff --git a/tools/hrw4u/grammar/hrw4u.g4 b/tools/hrw4u/grammar/hrw4u.g4 index 335a3817b30..ba728e7c638 100644 --- a/tools/hrw4u/grammar/hrw4u.g4 +++ b/tools/hrw4u/grammar/hrw4u.g4 @@ -88,6 +88,7 @@ AND : '&&'; OR : '||'; TILDE : '~'; NOT_TILDE : '!~'; +BANG : '!'; COLON : ':'; COMMA : ','; SEMICOLON : ';'; @@ -217,7 +218,7 @@ term ; factor - : '!' factor + : BANG factor | LPAREN expression RPAREN | functionCall | comparison @@ -230,9 +231,9 @@ comparison : comparable (EQUALS | NEQ | GT | LT) value modifier? | comparable (TILDE | NOT_TILDE) regex modifier? | comparable IN set modifier? - | comparable '!' IN set modifier? + | comparable BANG IN set modifier? | comparable IN iprange - | comparable '!' 
IN iprange + | comparable BANG IN iprange ; modifier diff --git a/tools/hrw4u/scripts/hrw4u-ast b/tools/hrw4u/scripts/hrw4u-ast new file mode 100755 index 00000000000..ab17579d919 --- /dev/null +++ b/tools/hrw4u/scripts/hrw4u-ast @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""hrw4u-ast - Inspect the HRW4U AST and the stages around it (CST, ...).""" + +from __future__ import annotations + +import argparse +import pprint +import sys +from typing import Any, Callable + +from antlr4 import CommonTokenStream, InputStream + +from hrw4u.ast_builder import ASTBuilder +from hrw4u.hrw4uLexer import hrw4uLexer +from hrw4u.hrw4uParser import hrw4uParser + + +def emit_cst(tree: Any, parser: hrw4uParser) -> None: + print(tree.toStringTree(recog=parser)) + + +def emit_ast(tree: Any, _parser: hrw4uParser) -> None: + ast = ASTBuilder().visit(tree) + pprint.pp(ast) + + +# Stage registry. Adding a new stage (resolved, validated, ...) is a one-line +# addition here plus its emit_* function above. Each emitter takes the parse +# tree and the parser (some stages need the parser for token/rule names). 
+STAGES: dict[str, Callable[[Any, hrw4uParser], None]] = { + "cst": emit_cst, + "ast": emit_ast, +} + +DEFAULT_STAGE = "ast" + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Inspect the HRW4U AST and surrounding stages.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog="Stages:\n cst ANTLR concrete syntax tree (raw parse tree)\n ast dataclass AST built by ASTBuilder (default)\n") + parser.add_argument( + "input_file", nargs="?", type=argparse.FileType("r"), default=sys.stdin, help="HRW4U source file (default: stdin)") + parser.add_argument( + "--stage", choices=sorted(STAGES.keys()), default=DEFAULT_STAGE, help=f"Which stage to emit (default: {DEFAULT_STAGE})") + args = parser.parse_args() + + content = args.input_file.read() + if args.input_file is not sys.stdin: + args.input_file.close() + + token_stream = CommonTokenStream(hrw4uLexer(InputStream(content))) + antlr_parser = hrw4uParser(token_stream) + tree = antlr_parser.program() + + if antlr_parser.getNumberOfSyntaxErrors() > 0: + print("Parse failed: syntax errors above.", file=sys.stderr) + return 1 + + STAGES[args.stage](tree, antlr_parser) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tools/hrw4u/src/ast_visitor.py b/tools/hrw4u/src/ast_builder.py similarity index 61% rename from tools/hrw4u/src/ast_visitor.py rename to tools/hrw4u/src/ast_builder.py index 4a66ec0a710..a91945e792e 100644 --- a/tools/hrw4u/src/ast_visitor.py +++ b/tools/hrw4u/src/ast_builder.py @@ -18,10 +18,10 @@ from __future__ import annotations from hrw4u.hrw4uVisitor import hrw4uVisitor -from hrw4u.ast_nodes import * +from hrw4u import ast_nodes as nodes -class ASTVisitor(hrw4uVisitor): +class ASTBuilder(hrw4uVisitor): """ANTLR visitor that walks an HRW4U parse tree and produces an AST for HRW4U.""" # Only visitProgram is overridden from the ANTLR visitor interface; @@ -29,7 +29,7 @@ class ASTVisitor(hrw4uVisitor): # method has an explicit return type and full 
control over how # child results are assembled into parent AST nodes. - def visitProgram(self, ctx) -> HRW4UAST: + def visitProgram(self, ctx) -> nodes.HRW4UAST: items = [] for item in ctx.programItem(): if item.useDirective() is not None: @@ -42,34 +42,34 @@ def visitProgram(self, ctx) -> HRW4UAST: pass else: raise ValueError(f"Unhandled programItem alternative at line {item.start.line}") - return HRW4UAST(body=tuple(items)) + return nodes.HRW4UAST(body=tuple(items)) - def _visit_use_directive(self, ctx) -> UseDirective: - return UseDirective(spec=ctx.QUALIFIED_IDENT().getText(), line=ctx.start.line) + def _visit_use_directive(self, ctx) -> nodes.UseDirective: + return nodes.UseDirective(spec=ctx.QUALIFIED_IDENT().getText(), line=ctx.start.line) - def _visit_procedure_decl(self, ctx) -> ProcedureDecl: + def _visit_procedure_decl(self, ctx) -> nodes.ProcedureDecl: name = ctx.QUALIFIED_IDENT().getText() params = () if ctx.paramList(): params = tuple(self._visit_proc_param(p) for p in ctx.paramList().param()) body = tuple(self._visit_body(ctx.block().blockItem())) - return ProcedureDecl(name=name, params=params, body=body, line=ctx.start.line) + return nodes.ProcedureDecl(name=name, params=params, body=body, line=ctx.start.line) - def _visit_proc_param(self, ctx) -> ProcParam: + def _visit_proc_param(self, ctx) -> nodes.ProcParam: name = ctx.IDENT().getText() default = self._extract_value(ctx.value()) if ctx.value() else None - return ProcParam(name=name, default=default, line=ctx.start.line) + return nodes.ProcParam(name=name, default=default, line=ctx.start.line) - def _visit_section(self, ctx) -> VarSection | Section: + def _visit_section(self, ctx) -> nodes.VarSection | nodes.Section: if ctx.varSection() is not None: - return self._visit_var_section(ctx.varSection(), "txn") + return self._visit_var_section(ctx.varSection(), nodes.VarSectionKind.TXN) if ctx.sessionVarSection() is not None: - return self._visit_var_section(ctx.sessionVarSection(), "session") + 
return self._visit_var_section(ctx.sessionVarSection(), nodes.VarSectionKind.SESSION) name = ctx.name.text body = self._visit_body(ctx.sectionBody()) - return Section(type=name, body=tuple(body), line=ctx.start.line) + return nodes.Section(type=name, body=tuple(body), line=ctx.start.line) - def _visit_var_section(self, ctx, scope) -> VarSection: + def _visit_var_section(self, ctx, scope: nodes.VarSectionKind) -> nodes.VarSection: decls = [] for var_item in ctx.variables().variablesItem(): if var_item.variableDecl() is not None: @@ -78,13 +78,13 @@ def _visit_var_section(self, ctx, scope) -> VarSection: pass else: raise ValueError(f"Unhandled variablesItem alternative at line {var_item.start.line}") - return VarSection(scope=scope, declarations=tuple(decls), line=ctx.start.line) + return nodes.VarSection(scope=scope, declarations=tuple(decls), line=ctx.start.line) - def _visit_var_decl(self, ctx) -> VarDecl: - return VarDecl( + def _visit_var_decl(self, ctx) -> nodes.VarDecl: + return nodes.VarDecl( name=ctx.name.text, type_name=ctx.typeName.text, slot=int(ctx.slot.text) if ctx.slot else None, line=ctx.start.line) - def _visit_body(self, items) -> list[BodyNode]: + def _visit_body(self, items) -> list[nodes.BodyNode]: """Shared helper for sectionBody and blockItem lists.""" result = [] for item in items: @@ -98,51 +98,51 @@ def _visit_body(self, items) -> list[BodyNode]: raise ValueError(f"Unhandled body item alternative at line {item.start.line}") return result - def _visit_statement(self, ctx) -> BodyNode: + def _visit_statement(self, ctx) -> nodes.BodyNode: line = ctx.start.line if ctx.BREAK(): - return Break(line=line) + return nodes.Break(line=line) if ctx.functionCall(): return self._visit_function_call(ctx.functionCall()) if ctx.EQUAL(): - target = Target.from_dotted(ctx.lhs.text) + target = nodes.Target.from_dotted(ctx.lhs.text) value = self._extract_value(ctx.value()) - return Assignment(target=target, operator="=", value=value, line=line) + return 
nodes.Assignment(target=target, operator=nodes.AssignOp.ASSIGN, value=value, line=line) if ctx.PLUSEQUAL(): - target = Target.from_dotted(ctx.lhs.text) + target = nodes.Target.from_dotted(ctx.lhs.text) value = self._extract_value(ctx.value()) - return Assignment(target=target, operator="+=", value=value, line=line) + return nodes.Assignment(target=target, operator=nodes.AssignOp.PLUS_ASSIGN, value=value, line=line) if ctx.op: - return FunctionCall(name=ctx.op.text, args=(), line=line) + return nodes.FunctionCall(name=ctx.op.text, args=(), line=line) raise ValueError(f"Unhandled statement alternative at line {line}") - def _visit_function_call(self, ctx) -> FunctionCall: + def _visit_function_call(self, ctx) -> nodes.FunctionCall: name = ctx.funcName.text args = () if ctx.argumentList(): args = tuple(self._extract_value(v) for v in ctx.argumentList().value()) - return FunctionCall(name=name, args=args, line=ctx.start.line) + return nodes.FunctionCall(name=name, args=args, line=ctx.start.line) - def _extract_value(self, ctx) -> ValueExpr: + def _extract_value(self, ctx) -> nodes.ValueExpr: if ctx.number is not None: return int(ctx.number.text) if ctx.str_ is not None: - return LiteralStringValue(raw=ctx.str_.text[1:-1]) + return nodes.LiteralStringValue(text=ctx.str_.text[1:-1]) if ctx.TRUE(): return True if ctx.FALSE(): return False if ctx.ident is not None: - return IdentValue(raw=ctx.ident.text) + return nodes.IdentValue(raw=ctx.ident.text) if ctx.ip(): - return IPValue(raw=ctx.ip().getText()) + return nodes.IPValue(raw=ctx.ip().getText()) if ctx.iprange(): - return tuple(IPValue(raw=ip.getText()) for ip in ctx.iprange().ip()) + return tuple(nodes.IPValue(raw=ip.getText()) for ip in ctx.iprange().ip()) if ctx.paramRef(): - return ParamRef(raw=ctx.paramRef().IDENT().getText()) + return nodes.ParamRef(name=ctx.paramRef().IDENT().getText()) raise ValueError(f"Unhandled value alternative at line {ctx.start.line}") - def _visit_conditional(self, ctx) -> IfBlock: + def 
_visit_conditional(self, ctx) -> nodes.IfBlock: if_stmt = ctx.ifStatement() condition = self._visit_condition(if_stmt.condition()) block = if_stmt.block() @@ -153,7 +153,7 @@ def _visit_conditional(self, ctx) -> IfBlock: elif_cond = self._visit_condition(elif_ctx.condition()) elif_block = elif_ctx.block() elif_body = tuple(self._visit_body(elif_block.blockItem())) if elif_block else () - elif_branches.append(ElifBranch(condition=elif_cond, body=elif_body, line=elif_ctx.start.line)) + elif_branches.append(nodes.ElifBranch(condition=elif_cond, body=elif_body, line=elif_ctx.start.line)) else_body = () if ctx.elseClause(): @@ -161,28 +161,29 @@ def _visit_conditional(self, ctx) -> IfBlock: if else_block: else_body = tuple(self._visit_body(else_block.blockItem())) - return IfBlock(condition=condition, body=body, elif_branches=tuple(elif_branches), else_body=else_body, line=ctx.start.line) + return nodes.IfBlock( + condition=condition, body=body, elif_branches=tuple(elif_branches), else_body=else_body, line=ctx.start.line) - def _visit_condition(self, ctx) -> ConditionExpr: + def _visit_condition(self, ctx) -> nodes.ConditionExpr: return self._visit_expression(ctx.expression()) - def _visit_expression(self, ctx) -> ConditionExpr: + def _visit_expression(self, ctx) -> nodes.ConditionExpr: if ctx.OR(): left = self._visit_expression(ctx.expression()) right = self._visit_term(ctx.term()) - return LogicalOp(operator="||", left=left, right=right, line=ctx.start.line) + return nodes.LogicalOp(operator=nodes.BoolOp.OR, left=left, right=right, line=ctx.start.line) return self._visit_term(ctx.term()) - def _visit_term(self, ctx) -> ConditionExpr: + def _visit_term(self, ctx) -> nodes.ConditionExpr: if ctx.AND(): left = self._visit_term(ctx.term()) right = self._visit_factor(ctx.factor()) - return LogicalOp(operator="&&", left=left, right=right, line=ctx.start.line) + return nodes.LogicalOp(operator=nodes.BoolOp.AND, left=left, right=right, line=ctx.start.line) return 
self._visit_factor(ctx.factor()) - def _visit_factor(self, ctx) -> ConditionExpr: - if ctx.getChildCount() == 2 and ctx.getChild(0).getText() == "!": - return NotOp(operand=self._visit_factor(ctx.factor()), line=ctx.start.line) + def _visit_factor(self, ctx) -> nodes.ConditionExpr: + if ctx.BANG(): + return nodes.NotOp(operand=self._visit_factor(ctx.factor()), line=ctx.start.line) if ctx.LPAREN(): return self._visit_expression(ctx.expression()) if ctx.functionCall(): @@ -190,18 +191,18 @@ def _visit_factor(self, ctx) -> ConditionExpr: if ctx.comparison(): return self._visit_comparison(ctx.comparison()) if ctx.ident is not None: - return IdentCondition(name=ctx.ident.text, line=ctx.start.line) + return nodes.IdentCondition(name=ctx.ident.text, line=ctx.start.line) if ctx.TRUE(): - return BoolLiteral(value=True, line=ctx.start.line) + return nodes.BoolLiteral(value=True, line=ctx.start.line) if ctx.FALSE(): - return BoolLiteral(value=False, line=ctx.start.line) + return nodes.BoolLiteral(value=False, line=ctx.start.line) raise ValueError(f"Unhandled factor alternative at line {ctx.start.line}") - def _visit_comparison(self, ctx) -> Comparison: + def _visit_comparison(self, ctx) -> nodes.Comparison: line = ctx.start.line comp = ctx.comparable() if comp.ident is not None: - left = IdentValue(raw=comp.ident.text) + left = nodes.IdentValue(raw=comp.ident.text) else: left = self._visit_function_call(comp.functionCall()) @@ -209,36 +210,34 @@ def _visit_comparison(self, ctx) -> Comparison: right = self._extract_comparison_rhs(ctx, operator) modifiers = self._extract_modifiers(ctx) - return Comparison(left=left, operator=operator, right=right, modifiers=modifiers, line=line) + return nodes.Comparison(left=left, operator=operator, right=right, modifiers=modifiers, line=line) - def _detect_comparison_operator(self, ctx) -> str: + def _detect_comparison_operator(self, ctx) -> nodes.CmpOp: if ctx.EQUALS(): - return "==" + return nodes.CmpOp.EQ if ctx.NEQ(): - return "!=" + 
return nodes.CmpOp.NEQ if ctx.GT(): - return ">" + return nodes.CmpOp.GT if ctx.LT(): - return "<" + return nodes.CmpOp.LT if ctx.TILDE(): - return "~" + return nodes.CmpOp.MATCH if ctx.NOT_TILDE(): - return "!~" + return nodes.CmpOp.NOT_MATCH if ctx.IN(): - for child in ctx.children: - if hasattr(child, "getText") and child.getText() == "!": - return "!in" - return "in" + return nodes.CmpOp.NOT_IN if ctx.BANG() else nodes.CmpOp.IN raise ValueError(f"Unhandled comparison operator at line {ctx.start.line}") - def _extract_comparison_rhs(self, ctx, operator) -> ValueExpr | RegexValue | tuple[ValueExpr, ...]: - if operator in ("~", "!~"): - return RegexValue(raw=ctx.regex().getText()[1:-1]) - if operator in ("in", "!in"): + def _extract_comparison_rhs(self, ctx, + operator: nodes.CmpOp) -> nodes.ValueExpr | nodes.RegexValue | tuple[nodes.ValueExpr, ...]: + if operator in (nodes.CmpOp.MATCH, nodes.CmpOp.NOT_MATCH): + return nodes.RegexValue(pattern=ctx.regex().getText()[1:-1]) + if operator in (nodes.CmpOp.IN, nodes.CmpOp.NOT_IN): if ctx.set_(): return tuple(self._extract_value(v) for v in ctx.set_().value()) if ctx.iprange(): - return tuple(IPValue(raw=ip.getText()) for ip in ctx.iprange().ip()) + return tuple(nodes.IPValue(raw=ip.getText()) for ip in ctx.iprange().ip()) if ctx.value(): return self._extract_value(ctx.value()) raise ValueError(f"Unhandled comparison RHS at line {ctx.start.line}") diff --git a/tools/hrw4u/src/ast_nodes.py b/tools/hrw4u/src/ast_nodes.py index acf5bacccb3..c86c234202f 100644 --- a/tools/hrw4u/src/ast_nodes.py +++ b/tools/hrw4u/src/ast_nodes.py @@ -18,9 +18,13 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Union +from enum import Enum, auto __all__ = [ + "AssignOp", + "CmpOp", + "BoolOp", + "VarSectionKind", "LiteralStringValue", "IdentValue", "IPValue", @@ -51,10 +55,47 @@ "TopLevelNode", ] +# Enum.__str__ yields "AssignOp.ASSIGN" while the default Enum.__repr__ yields +# "<AssignOp.ASSIGN: 1>"; we alias __repr__ 
to __str__ on every operator enum +# so that pprint output (used by hrw4u-ast) is concise and readable. + + +class AssignOp(Enum): + ASSIGN = auto() + PLUS_ASSIGN = auto() + __repr__ = Enum.__str__ + + +class CmpOp(Enum): + EQ = auto() + NEQ = auto() + GT = auto() + LT = auto() + MATCH = auto() + NOT_MATCH = auto() + IN = auto() + NOT_IN = auto() + __repr__ = Enum.__str__ + + +class BoolOp(Enum): + AND = auto() + OR = auto() + __repr__ = Enum.__str__ + + +class VarSectionKind(Enum): + TXN = auto() + SESSION = auto() + __repr__ = Enum.__str__ + @dataclass(frozen=True, kw_only=True) class LiteralStringValue: - raw: str + # The string body with surrounding quotes stripped. Escape sequences + # (e.g. '\n', '\"') are preserved as written; consumers needing the + # decoded value must do their own decoding. + text: str @dataclass(frozen=True, kw_only=True) @@ -69,15 +110,18 @@ class IPValue: @dataclass(frozen=True, kw_only=True) class ParamRef: - raw: str + # Parameter name without the '$' sigil (source `$tag` -> name='tag'). + name: str @dataclass(frozen=True, kw_only=True) class RegexValue: - raw: str + # The regex body with surrounding '/' delimiters stripped. Backslash + # escapes inside the pattern are preserved as written. + pattern: str -ValueExpr = Union[LiteralStringValue, IdentValue, IPValue, ParamRef, int, bool, tuple[IPValue, ...]] +ValueExpr = LiteralStringValue | IdentValue | IPValue | ParamRef | int | bool | tuple[IPValue, ...] @dataclass(frozen=True, kw_only=True) @@ -87,6 +131,8 @@ class Node: @dataclass(frozen=True) class Target: + # Value class, not an AST node — destructured from an Assignment's IDENT + # lhs, so source position lives on the enclosing Assignment. 
namespace: str | None field: str @@ -104,7 +150,7 @@ def from_dotted(name: str) -> Target: @dataclass(frozen=True, kw_only=True) class Assignment(Node): target: Target - operator: str # "=" or "+=" + operator: AssignOp value: ValueExpr @@ -122,14 +168,14 @@ class Break(Node): @dataclass(frozen=True, kw_only=True) class Comparison(Node): left: IdentValue | FunctionCall - operator: str # "==", "!=", ">", "<", "~", "!~", "in", "!in" + operator: CmpOp right: ValueExpr | RegexValue | tuple[ValueExpr, ...] modifiers: tuple[str, ...] @dataclass(frozen=True, kw_only=True) class LogicalOp(Node): - operator: str # "&&" or "||" + operator: BoolOp left: ConditionExpr right: ConditionExpr @@ -184,7 +230,7 @@ class VarDecl(Node): @dataclass(frozen=True, kw_only=True) class VarSection(Node): - scope: str + scope: VarSectionKind declarations: tuple[VarDecl, ...] @@ -206,6 +252,6 @@ class HRW4UAST: # Type aliases: must follow all class definitions (evaluated at runtime). -ConditionExpr = Union[Comparison, LogicalOp, NotOp, BoolLiteral, IdentCondition, FunctionCall] -BodyNode = Union[Assignment, FunctionCall, IfBlock, Break] -TopLevelNode = Union[UseDirective, VarSection, ProcedureDecl, Section] +ConditionExpr = Comparison | LogicalOp | NotOp | BoolLiteral | IdentCondition | FunctionCall +BodyNode = Assignment | FunctionCall | IfBlock | Break +TopLevelNode = UseDirective | VarSection | ProcedureDecl | Section diff --git a/tools/hrw4u/src/visitor.py b/tools/hrw4u/src/visitor.py index f3fb38c3a39..b379ea501ec 100644 --- a/tools/hrw4u/src/visitor.py +++ b/tools/hrw4u/src/visitor.py @@ -1012,9 +1012,9 @@ def visitComparison(self, ctx, *, last: bool = False) -> None: return operator = ctx.getChild(1) - # Detect negation: '!=' and '!~' are single tokens (NEQ, NOT_TILDE), - # but '!in' is two separate tokens ('!' + IN). 
- if operator.getText() == '!': + # Detect negation: '!=' and '!~' are single tokens (NEQ, NOT_TILDE); + # 'in' is single, but '!in' is BANG followed by IN as separate tokens. + if ctx.BANG(): negate = True else: negate = operator.symbol.type in (hrw4uParser.NEQ, hrw4uParser.NOT_TILDE) @@ -1136,7 +1136,7 @@ def emit_term(self, ctx, *, last: bool = False) -> None: def emit_factor(self, ctx, *, last: bool = False) -> None: with self.debug_context("emit_factor"), self.trap(ctx): match ctx: - case _ if ctx.getChildCount() == 2 and ctx.getChild(0).getText() == "!": + case _ if ctx.BANG(): self._dbg("`NOT' detected") child = ctx.getChild(1) if child.LPAREN(): diff --git a/tools/hrw4u/tests/test_ast_visitor.py b/tools/hrw4u/tests/test_ast_builder.py similarity index 59% rename from tools/hrw4u/tests/test_ast_visitor.py rename to tools/hrw4u/tests/test_ast_builder.py index ec919d1f060..178bafa902a 100644 --- a/tools/hrw4u/tests/test_ast_visitor.py +++ b/tools/hrw4u/tests/test_ast_builder.py @@ -15,14 +15,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from hrw4u.ast_nodes import * +from hrw4u import ast_nodes as nodes from utils import parse_input_text -from hrw4u.ast_visitor import ASTVisitor +from hrw4u.ast_builder import ASTBuilder -def _build(source: str) -> HRW4UAST: +def _build(source: str) -> nodes.HRW4UAST: _, tree = parse_input_text(source) - return ASTVisitor().visit(tree) + return ASTBuilder().visit(tree) class TestAssignments: @@ -30,17 +30,29 @@ class TestAssignments: def test_simple_assignment(self): ast = _build('REMAP {\n inbound.req.X-Foo = "test";\n}') a = ast.body[0].body[0] - assert isinstance(a, Assignment) - assert a.target == Target.from_dotted("inbound.req.X-Foo") - assert a.operator == "=" - assert a.value == LiteralStringValue(raw="test") + assert isinstance(a, nodes.Assignment) + assert a.target == nodes.Target.from_dotted("inbound.req.X-Foo") + assert a.operator == nodes.AssignOp.ASSIGN + assert a.value == nodes.LiteralStringValue(text="test") def test_bool_value(self): ast = _build('SEND_RESPONSE {\n http.cntl.TXN_DEBUG = true;\n}') a = ast.body[0].body[0] - assert isinstance(a, Assignment) + assert isinstance(a, nodes.Assignment) assert a.value is True + def test_false_value(self): + ast = _build('SEND_RESPONSE {\n http.cntl.TXN_DEBUG = false;\n}') + a = ast.body[0].body[0] + assert isinstance(a, nodes.Assignment) + assert a.value is False + + def test_empty_string_value(self): + ast = _build('REMAP {\n inbound.req.X-Foo = "";\n}') + a = ast.body[0].body[0] + assert isinstance(a, nodes.Assignment) + assert a.value == nodes.LiteralStringValue(text="") + def test_int_value(self): ast = _build('REMAP {\n http.cntl.INTERCEPT_RETRY = 1;\n}') a = ast.body[0].body[0] @@ -49,20 +61,40 @@ def test_int_value(self): def test_plus_equals(self): ast = _build('REMAP {\n inbound.req.X-Foo += "extra";\n}') a = ast.body[0].body[0] - assert a.operator == "+=" + assert a.operator == nodes.AssignOp.PLUS_ASSIGN def test_ip_value(self): ast = _build('REMAP {\n inbound.req.X-IP = 10.0.0.1;\n}') a = 
ast.body[0].body[0] - assert isinstance(a, Assignment) - assert a.value == IPValue(raw="10.0.0.1") + assert isinstance(a, nodes.Assignment) + assert a.value == nodes.IPValue(raw="10.0.0.1") def test_param_ref_value(self): src = 'procedure local::stamp($tag) {\n inbound.req.X-Stamp = $tag;\n}\nREMAP {\n set-debug();\n}' ast = _build(src) a = ast.body[0].body[0] - assert isinstance(a, Assignment) - assert a.value == ParamRef(raw="tag") + assert isinstance(a, nodes.Assignment) + assert a.value == nodes.ParamRef(name="tag") + + def test_ident_value(self): + src = 'VARS {\n flag: bool;\n}\nREMAP {\n inbound.req.X-Flag = flag;\n}' + ast = _build(src) + a = ast.body[1].body[0] + assert isinstance(a, nodes.Assignment) + assert a.value == nodes.IdentValue(raw="flag") + + +class TestTarget: + + def test_from_dotted(self): + t = nodes.Target.from_dotted("inbound.req.X-Foo") + assert t.namespace == "inbound.req" + assert t.field == "X-Foo" + + def test_from_dotted_no_namespace(self): + t = nodes.Target.from_dotted("flag") + assert t.namespace is None + assert t.field == "flag" class TestFunctionCalls: @@ -70,7 +102,7 @@ class TestFunctionCalls: def test_no_args(self): ast = _build('REMAP {\n set-debug();\n}') fc = ast.body[0].body[0] - assert isinstance(fc, FunctionCall) + assert isinstance(fc, nodes.FunctionCall) assert fc.name == "set-debug" assert fc.args == () @@ -78,19 +110,35 @@ def test_with_args(self): ast = _build('REMAP {\n set-header("X-Foo", "bar");\n}') fc = ast.body[0].body[0] assert fc.name == "set-header" - assert fc.args == (LiteralStringValue(raw="X-Foo"), LiteralStringValue(raw="bar")) + assert fc.args == (nodes.LiteralStringValue(text="X-Foo"), nodes.LiteralStringValue(text="bar")) + + def test_qualified_name(self): + src = 'use test::helper\nREMAP {\n test::helper("tag");\n}' + ast = _build(src) + fc = ast.body[1].body[0] + assert isinstance(fc, nodes.FunctionCall) + assert fc.name == "test::helper" + assert fc.args == 
(nodes.LiteralStringValue(text="tag"),) + + def test_param_ref_arg(self): + src = 'procedure local::stamp($tag) {\n set-header("X-Stamp", $tag);\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + fc = ast.body[0].body[0] + assert isinstance(fc, nodes.FunctionCall) + assert fc.name == "set-header" + assert fc.args == (nodes.LiteralStringValue(text="X-Stamp"), nodes.ParamRef(name="tag")) def test_standalone_operator(self): ast = _build('REMAP {\n skip-remap;\n}') fc = ast.body[0].body[0] - assert isinstance(fc, FunctionCall) + assert isinstance(fc, nodes.FunctionCall) assert fc.name == "skip-remap" assert fc.args == () def test_break(self): ast = _build('REMAP {\n if true {\n break;\n }\n}') body = ast.body[0].body[0].body - assert isinstance(body[0], Break) + assert isinstance(body[0], nodes.Break) class TestSections: @@ -108,13 +156,13 @@ def test_comments_in_block_skipped(self): def test_section_type(self): ast = _build('REMAP {\n set-debug();\n}') s = ast.body[0] - assert isinstance(s, Section) + assert isinstance(s, nodes.Section) assert s.type == "REMAP" def test_multiple_sections(self): src = 'REMAP {\n set-debug();\n}\nSEND_RESPONSE {\n set-debug();\n}' ast = _build(src) - sections = [i for i in ast.body if isinstance(i, Section)] + sections = [i for i in ast.body if isinstance(i, nodes.Section)] assert len(sections) == 2 assert sections[0].type == "REMAP" assert sections[1].type == "SEND_RESPONSE" @@ -124,16 +172,38 @@ def test_use_directive(self): ast = _build(src) assert len(ast.body) == 2 u = ast.body[0] - assert isinstance(u, UseDirective) + assert isinstance(u, nodes.UseDirective) assert u.spec == "test::add-debug-header" + def test_multiple_use_directives(self): + src = ('use test::add-debug-header\n' + 'use test::stamp-request\n' + 'REMAP {\n test::add-debug-header("tag");\n}') + ast = _build(src) + directives = [i for i in ast.body if isinstance(i, nodes.UseDirective)] + assert len(directives) == 2 + assert directives[0].spec == 
"test::add-debug-header" + assert directives[1].spec == "test::stamp-request" + + def test_top_level_comments_skipped(self): + src = ( + '# leading comment\n' + 'use test::helper\n' + '# between use and section\n' + 'REMAP {\n set-debug();\n}\n' + '# trailing comment\n') + ast = _build(src) + assert len(ast.body) == 2 + assert isinstance(ast.body[0], nodes.UseDirective) + assert isinstance(ast.body[1], nodes.Section) + def test_item_ordering(self): src = 'VARS {\n x: bool;\n}\nREMAP {\n set-debug();\n}\nSEND_RESPONSE {\n set-debug();\n}' ast = _build(src) assert len(ast.body) == 3 - assert isinstance(ast.body[0], VarSection) - assert isinstance(ast.body[1], Section) - assert isinstance(ast.body[2], Section) + assert isinstance(ast.body[0], nodes.VarSection) + assert isinstance(ast.body[1], nodes.Section) + assert isinstance(ast.body[2], nodes.Section) class TestVarSections: @@ -142,15 +212,15 @@ def test_comments_in_var_section_skipped(self): src = 'VARS {\n # comment\n x: bool;\n # another\n y: int;\n}\nREMAP {\n set-debug();\n}' ast = _build(src) vs = ast.body[0] - assert isinstance(vs, VarSection) + assert isinstance(vs, nodes.VarSection) assert len(vs.declarations) == 2 def test_txn_scope(self): src = 'VARS {\n flag: bool;\n}\nREMAP {\n set-debug();\n}' ast = _build(src) vs = ast.body[0] - assert isinstance(vs, VarSection) - assert vs.scope == "txn" + assert isinstance(vs, nodes.VarSection) + assert vs.scope == nodes.VarSectionKind.TXN assert len(vs.declarations) == 1 assert vs.declarations[0].name == "flag" assert vs.declarations[0].type_name == "bool" @@ -160,27 +230,39 @@ def test_session_scope(self): src = 'SESSION_VARS {\n counter: int;\n}\nREMAP {\n set-debug();\n}' ast = _build(src) vs = ast.body[0] - assert isinstance(vs, VarSection) - assert vs.scope == "session" + assert isinstance(vs, nodes.VarSection) + assert vs.scope == nodes.VarSectionKind.SESSION assert vs.declarations[0].name == "counter" def test_slot(self): src = 'VARS {\n x: int 
@3;\n}\nREMAP {\n set-debug();\n}' ast = _build(src) vs = ast.body[0] - assert isinstance(vs, VarSection) + assert isinstance(vs, nodes.VarSection) assert vs.declarations[0].slot == 3 def test_multiple_declarations(self): src = 'VARS {\n a: bool;\n b: int;\n c: string;\n}\nREMAP {\n set-debug();\n}' ast = _build(src) vs = ast.body[0] - assert isinstance(vs, VarSection) + assert isinstance(vs, nodes.VarSection) assert len(vs.declarations) == 3 assert vs.declarations[0].name == "a" assert vs.declarations[1].name == "b" assert vs.declarations[2].name == "c" + def test_txn_and_session_in_same_program(self): + src = ('VARS {\n flag: bool;\n}\n' + 'SESSION_VARS {\n counter: int;\n}\n' + 'REMAP {\n set-debug();\n}') + ast = _build(src) + var_sections = [i for i in ast.body if isinstance(i, nodes.VarSection)] + assert len(var_sections) == 2 + assert var_sections[0].scope == nodes.VarSectionKind.TXN + assert var_sections[0].declarations[0].name == "flag" + assert var_sections[1].scope == nodes.VarSectionKind.SESSION + assert var_sections[1].declarations[0].name == "counter" + class TestProcedures: @@ -188,7 +270,7 @@ def test_basic_decl(self): src = 'procedure local::stamp($tag) {\n inbound.req.X-Stamp = "$tag";\n}\nREMAP {\n set-debug();\n}' ast = _build(src) pd = ast.body[0] - assert isinstance(pd, ProcedureDecl) + assert isinstance(pd, nodes.ProcedureDecl) assert pd.name == "local::stamp" assert len(pd.params) == 1 assert pd.params[0].name == "tag" @@ -198,19 +280,33 @@ def test_default_param(self): src = 'procedure local::cache($ttl=300) {\n set-debug();\n}\nREMAP {\n set-debug();\n}' ast = _build(src) pd = ast.body[0] - assert isinstance(pd, ProcedureDecl) + assert isinstance(pd, nodes.ProcedureDecl) assert pd.params[0].name == "ttl" assert pd.params[0].default == 300 + def test_multiple_params(self): + src = ('procedure local::tag($key, $value="x", $count=1) {\n set-debug();\n}\n' + 'REMAP {\n set-debug();\n}') + ast = _build(src) + pd = ast.body[0] + assert 
isinstance(pd, nodes.ProcedureDecl) + assert len(pd.params) == 3 + assert pd.params[0].name == "key" + assert pd.params[0].default is None + assert pd.params[1].name == "value" + assert pd.params[1].default == nodes.LiteralStringValue(text="x") + assert pd.params[2].name == "count" + assert pd.params[2].default == 1 + def test_body(self): src = ('procedure local::multi() {\n inbound.req.X = "a";\n' ' set-debug();\n}\nREMAP {\n set-debug();\n}') ast = _build(src) pd = ast.body[0] - assert isinstance(pd, ProcedureDecl) + assert isinstance(pd, nodes.ProcedureDecl) assert len(pd.body) == 2 - assert isinstance(pd.body[0], Assignment) - assert isinstance(pd.body[1], FunctionCall) + assert isinstance(pd.body[0], nodes.Assignment) + assert isinstance(pd.body[1], nodes.FunctionCall) class TestConditionExpressions: @@ -221,116 +317,127 @@ def _first_condition(self, source: str): def test_equality_comparison(self): cond = self._first_condition('REMAP {\n if inbound.req.X-Foo == "bar" {\n set-debug();\n }\n}') - assert isinstance(cond, Comparison) - assert cond.left == IdentValue(raw="inbound.req.X-Foo") - assert cond.operator == "==" - assert cond.right == LiteralStringValue(raw="bar") + assert isinstance(cond, nodes.Comparison) + assert cond.left == nodes.IdentValue(raw="inbound.req.X-Foo") + assert cond.operator == nodes.CmpOp.EQ + assert cond.right == nodes.LiteralStringValue(text="bar") assert cond.modifiers == () def test_regex_comparison(self): cond = self._first_condition('REMAP {\n if inbound.url.path ~ /\\.php$/ {\n set-debug();\n }\n}') - assert isinstance(cond, Comparison) - assert cond.operator == "~" - assert isinstance(cond.right, RegexValue) + assert isinstance(cond, nodes.Comparison) + assert cond.operator == nodes.CmpOp.MATCH + assert isinstance(cond.right, nodes.RegexValue) def test_in_set(self): cond = self._first_condition('REMAP {\n if inbound.url.path in ["a", "b"] {\n set-debug();\n }\n}') - assert isinstance(cond, Comparison) - assert cond.operator == 
"in" - assert cond.right == (LiteralStringValue(raw="a"), LiteralStringValue(raw="b")) + assert isinstance(cond, nodes.Comparison) + assert cond.operator == nodes.CmpOp.IN + assert cond.right == (nodes.LiteralStringValue(text="a"), nodes.LiteralStringValue(text="b")) def test_not_in_set(self): cond = self._first_condition('REMAP {\n if inbound.url.path !in ["a"] {\n set-debug();\n }\n}') - assert isinstance(cond, Comparison) - assert cond.operator == "!in" + assert isinstance(cond, nodes.Comparison) + assert cond.operator == nodes.CmpOp.NOT_IN def test_in_iprange(self): cond = self._first_condition('REMAP {\n if inbound.ip in {10.0.0.0/8} {\n set-debug();\n }\n}') - assert isinstance(cond, Comparison) - assert cond.operator == "in" - assert cond.right == (IPValue(raw="10.0.0.0/8"),) + assert isinstance(cond, nodes.Comparison) + assert cond.operator == nodes.CmpOp.IN + assert cond.right == (nodes.IPValue(raw="10.0.0.0/8"),) + + def test_not_in_iprange(self): + cond = self._first_condition('REMAP {\n if inbound.ip !in {10.0.0.0/8} {\n set-debug();\n }\n}') + assert isinstance(cond, nodes.Comparison) + assert cond.operator == nodes.CmpOp.NOT_IN + assert cond.right == (nodes.IPValue(raw="10.0.0.0/8"),) def test_modifiers(self): cond = self._first_condition('REMAP {\n if inbound.req.X-Foo == "bar" with NOCASE {\n set-debug();\n }\n}') - assert isinstance(cond, Comparison) + assert isinstance(cond, nodes.Comparison) assert cond.modifiers == ("NOCASE",) def test_modifiers_preserve_source_casing(self): cond = self._first_condition('REMAP {\n if inbound.req.X-Foo == "bar" with nocase,Pre {\n set-debug();\n }\n}') - assert isinstance(cond, Comparison) + assert isinstance(cond, nodes.Comparison) assert cond.modifiers == ("nocase", "Pre") def test_function_call_comparable(self): cond = self._first_condition('REMAP {\n if url(true) ~ /pat/ {\n set-debug();\n }\n}') - assert isinstance(cond, Comparison) - assert isinstance(cond.left, FunctionCall) + assert isinstance(cond, 
nodes.Comparison) + assert isinstance(cond.left, nodes.FunctionCall) assert cond.left.name == "url" assert cond.left.args == (True,) def test_bool_literal_true(self): cond = self._first_condition('REMAP {\n if true {\n set-debug();\n }\n}') - assert isinstance(cond, BoolLiteral) + assert isinstance(cond, nodes.BoolLiteral) assert cond.value is True + def test_bool_literal_false(self): + cond = self._first_condition('REMAP {\n if false {\n set-debug();\n }\n}') + assert isinstance(cond, nodes.BoolLiteral) + assert cond.value is False + def test_ident_condition(self): cond = self._first_condition('REMAP {\n if inbound.resp.All-Cache {\n set-debug();\n }\n}') - assert isinstance(cond, IdentCondition) + assert isinstance(cond, nodes.IdentCondition) assert cond.name == "inbound.resp.All-Cache" def test_not_condition(self): cond = self._first_condition('REMAP {\n if !inbound.resp.All-Cache {\n set-debug();\n }\n}') - assert isinstance(cond, NotOp) - assert isinstance(cond.operand, IdentCondition) + assert isinstance(cond, nodes.NotOp) + assert isinstance(cond.operand, nodes.IdentCondition) def test_and_condition(self): cond = self._first_condition( 'REMAP {\n if inbound.req.X-A == "a" && inbound.req.X-B == "b" {\n set-debug();\n }\n}') - assert isinstance(cond, LogicalOp) - assert cond.operator == "&&" - assert isinstance(cond.left, Comparison) - assert isinstance(cond.right, Comparison) + assert isinstance(cond, nodes.LogicalOp) + assert cond.operator == nodes.BoolOp.AND + assert isinstance(cond.left, nodes.Comparison) + assert isinstance(cond.right, nodes.Comparison) def test_or_condition(self): cond = self._first_condition( 'REMAP {\n if inbound.req.X-A == "a" || inbound.req.X-B == "b" {\n set-debug();\n }\n}') - assert isinstance(cond, LogicalOp) - assert cond.operator == "||" + assert isinstance(cond, nodes.LogicalOp) + assert cond.operator == nodes.BoolOp.OR def test_function_call_in_condition(self): cond = self._first_condition('REMAP {\n if access("/tmp/bar") {\n 
set-debug();\n }\n}') - assert isinstance(cond, FunctionCall) + assert isinstance(cond, nodes.FunctionCall) assert cond.name == "access" - assert cond.args == (LiteralStringValue(raw="/tmp/bar"),) + assert cond.args == (nodes.LiteralStringValue(text="/tmp/bar"),) def test_not_tilde_comparison(self): cond = self._first_condition('REMAP {\n if inbound.url.path !~ /\\.jpg$/ {\n set-debug();\n }\n}') - assert isinstance(cond, Comparison) - assert cond.operator == "!~" - assert isinstance(cond.right, RegexValue) + assert isinstance(cond, nodes.Comparison) + assert cond.operator == nodes.CmpOp.NOT_MATCH + assert isinstance(cond.right, nodes.RegexValue) def test_greater_than_comparison(self): cond = self._first_condition('REMAP {\n if inbound.req.Content-Length > 1000 {\n set-debug();\n }\n}') - assert isinstance(cond, Comparison) - assert cond.operator == ">" + assert isinstance(cond, nodes.Comparison) + assert cond.operator == nodes.CmpOp.GT assert cond.right == 1000 def test_less_than_comparison(self): cond = self._first_condition('REMAP {\n if inbound.req.Content-Length < 500 {\n set-debug();\n }\n}') - assert isinstance(cond, Comparison) - assert cond.operator == "<" + assert isinstance(cond, nodes.Comparison) + assert cond.operator == nodes.CmpOp.LT assert cond.right == 500 def test_neq_comparison(self): cond = self._first_condition('REMAP {\n if inbound.req.X-Foo != "bar" {\n set-debug();\n }\n}') - assert isinstance(cond, Comparison) - assert cond.operator == "!=" - assert cond.right == LiteralStringValue(raw="bar") + assert isinstance(cond, nodes.Comparison) + assert cond.operator == nodes.CmpOp.NEQ + assert cond.right == nodes.LiteralStringValue(text="bar") def test_parenthesized_condition(self): cond = self._first_condition('REMAP {\n if (inbound.req.X-Foo == "bar") {\n set-debug();\n }\n}') - assert isinstance(cond, Comparison) - assert cond.operator == "==" - assert cond.right == LiteralStringValue(raw="bar") + assert isinstance(cond, nodes.Comparison) + 
assert cond.operator == nodes.CmpOp.EQ + assert cond.right == nodes.LiteralStringValue(text="bar") def test_and_binds_tighter_than_or(self): # a || b && c should parse as a || (b && c) @@ -338,14 +445,14 @@ def test_and_binds_tighter_than_or(self): 'REMAP {\n' ' if inbound.req.X-A == "a" || inbound.req.X-B == "b" && inbound.req.X-C == "c" {\n' ' set-debug();\n }\n}') - assert isinstance(cond, LogicalOp) - assert cond.operator == "||" - assert isinstance(cond.left, Comparison) - assert cond.left.left == IdentValue(raw="inbound.req.X-A") - assert isinstance(cond.right, LogicalOp) - assert cond.right.operator == "&&" - assert cond.right.left.left == IdentValue(raw="inbound.req.X-B") - assert cond.right.right.left == IdentValue(raw="inbound.req.X-C") + assert isinstance(cond, nodes.LogicalOp) + assert cond.operator == nodes.BoolOp.OR + assert isinstance(cond.left, nodes.Comparison) + assert cond.left.left == nodes.IdentValue(raw="inbound.req.X-A") + assert isinstance(cond.right, nodes.LogicalOp) + assert cond.right.operator == nodes.BoolOp.AND + assert cond.right.left.left == nodes.IdentValue(raw="inbound.req.X-B") + assert cond.right.right.left == nodes.IdentValue(raw="inbound.req.X-C") def test_not_with_and(self): # !ident && comparison should parse as (!ident) && comparison @@ -353,13 +460,13 @@ def test_not_with_and(self): 'REMAP {\n' ' if !inbound.resp.All-Cache && inbound.req.X-B == "b" {\n' ' set-debug();\n }\n}') - assert isinstance(cond, LogicalOp) - assert cond.operator == "&&" - assert isinstance(cond.left, NotOp) - assert isinstance(cond.left.operand, IdentCondition) + assert isinstance(cond, nodes.LogicalOp) + assert cond.operator == nodes.BoolOp.AND + assert isinstance(cond.left, nodes.NotOp) + assert isinstance(cond.left.operand, nodes.IdentCondition) assert cond.left.operand.name == "inbound.resp.All-Cache" - assert isinstance(cond.right, Comparison) - assert cond.right.left == IdentValue(raw="inbound.req.X-B") + assert isinstance(cond.right, 
nodes.Comparison) + assert cond.right.left == nodes.IdentValue(raw="inbound.req.X-B") def test_not_comparison_with_or(self): # !(a == "x") || b == "y" should parse as (!(a == "x")) || (b == "y") @@ -367,26 +474,26 @@ def test_not_comparison_with_or(self): 'REMAP {\n' ' if !(inbound.req.X-A == "x") || inbound.req.X-B == "y" {\n' ' set-debug();\n }\n}') - assert isinstance(cond, LogicalOp) - assert cond.operator == "||" - assert isinstance(cond.left, NotOp) - assert isinstance(cond.left.operand, Comparison) - assert cond.left.operand.left == IdentValue(raw="inbound.req.X-A") - assert cond.left.operand.right == LiteralStringValue(raw="x") - assert isinstance(cond.right, Comparison) - assert cond.right.left == IdentValue(raw="inbound.req.X-B") + assert isinstance(cond, nodes.LogicalOp) + assert cond.operator == nodes.BoolOp.OR + assert isinstance(cond.left, nodes.NotOp) + assert isinstance(cond.left.operand, nodes.Comparison) + assert cond.left.operand.left == nodes.IdentValue(raw="inbound.req.X-A") + assert cond.left.operand.right == nodes.LiteralStringValue(text="x") + assert isinstance(cond.right, nodes.Comparison) + assert cond.right.left == nodes.IdentValue(raw="inbound.req.X-B") def test_double_negation(self): cond = self._first_condition('REMAP {\n if !!inbound.resp.All-Cache {\n set-debug();\n }\n}') - assert isinstance(cond, NotOp) - assert isinstance(cond.operand, NotOp) - assert isinstance(cond.operand.operand, IdentCondition) + assert isinstance(cond, nodes.NotOp) + assert isinstance(cond.operand, nodes.NotOp) + assert isinstance(cond.operand.operand, nodes.IdentCondition) assert cond.operand.operand.name == "inbound.resp.All-Cache" def test_not_bool_literal(self): cond = self._first_condition('REMAP {\n if !false {\n set-debug();\n }\n}') - assert isinstance(cond, NotOp) - assert isinstance(cond.operand, BoolLiteral) + assert isinstance(cond, nodes.NotOp) + assert isinstance(cond.operand, nodes.BoolLiteral) assert cond.operand.value is False def 
test_parens_override_precedence(self): @@ -395,14 +502,14 @@ def test_parens_override_precedence(self): 'REMAP {\n' ' if (inbound.req.X-A == "a" || inbound.req.X-B == "b") && inbound.req.X-C == "c" {\n' ' set-debug();\n }\n}') - assert isinstance(cond, LogicalOp) - assert cond.operator == "&&" - assert isinstance(cond.left, LogicalOp) - assert cond.left.operator == "||" - assert cond.left.left.left == IdentValue(raw="inbound.req.X-A") - assert cond.left.right.left == IdentValue(raw="inbound.req.X-B") - assert isinstance(cond.right, Comparison) - assert cond.right.left == IdentValue(raw="inbound.req.X-C") + assert isinstance(cond, nodes.LogicalOp) + assert cond.operator == nodes.BoolOp.AND + assert isinstance(cond.left, nodes.LogicalOp) + assert cond.left.operator == nodes.BoolOp.OR + assert cond.left.left.left == nodes.IdentValue(raw="inbound.req.X-A") + assert cond.left.right.left == nodes.IdentValue(raw="inbound.req.X-B") + assert isinstance(cond.right, nodes.Comparison) + assert cond.right.left == nodes.IdentValue(raw="inbound.req.X-C") def test_nested_parens_with_not(self): # !(a == "x" || b == "y") && c == "z" @@ -410,13 +517,13 @@ def test_nested_parens_with_not(self): 'REMAP {\n' ' if !(inbound.req.X-A == "x" || inbound.req.X-B == "y") && inbound.req.X-C == "z" {\n' ' set-debug();\n }\n}') - assert isinstance(cond, LogicalOp) - assert cond.operator == "&&" - assert isinstance(cond.left, NotOp) - assert isinstance(cond.left.operand, LogicalOp) - assert cond.left.operand.operator == "||" - assert isinstance(cond.right, Comparison) - assert cond.right.left == IdentValue(raw="inbound.req.X-C") + assert isinstance(cond, nodes.LogicalOp) + assert cond.operator == nodes.BoolOp.AND + assert isinstance(cond.left, nodes.NotOp) + assert isinstance(cond.left.operand, nodes.LogicalOp) + assert cond.left.operand.operator == nodes.BoolOp.OR + assert isinstance(cond.right, nodes.Comparison) + assert cond.right.left == nodes.IdentValue(raw="inbound.req.X-C") class 
TestIfBlocks: @@ -424,7 +531,7 @@ class TestIfBlocks: def test_simple_if(self): ast = _build('REMAP {\n if true {\n inbound.req.X = "y";\n }\n}') ib = ast.body[0].body[0] - assert isinstance(ib, IfBlock) + assert isinstance(ib, nodes.IfBlock) assert len(ib.body) == 1 assert ib.elif_branches == () assert ib.else_body == () @@ -443,9 +550,9 @@ def test_if_elif_else(self): ' inbound.resp.X = "other";\n }\n}') ast = _build(src) ib = ast.body[0].body[0] - assert isinstance(ib, IfBlock) + assert isinstance(ib, nodes.IfBlock) assert len(ib.elif_branches) == 1 - assert isinstance(ib.elif_branches[0], ElifBranch) + assert isinstance(ib.elif_branches[0], nodes.ElifBranch) assert len(ib.elif_branches[0].body) == 1 assert len(ib.else_body) == 1 @@ -465,9 +572,9 @@ def test_nested_if(self): ' if inbound.req.Y == "b" {\n set-debug();\n }\n }\n}') ast = _build(src) outer = ast.body[0].body[0] - assert isinstance(outer, IfBlock) + assert isinstance(outer, nodes.IfBlock) inner = outer.body[0] - assert isinstance(inner, IfBlock) + assert isinstance(inner, nodes.IfBlock) def test_mixed_body(self): src = ( @@ -477,9 +584,22 @@ def test_mixed_body(self): ast = _build(src) body = ast.body[0].body assert len(body) == 3 - assert isinstance(body[0], Assignment) - assert isinstance(body[1], IfBlock) - assert isinstance(body[2], Assignment) + assert isinstance(body[0], nodes.Assignment) + assert isinstance(body[1], nodes.IfBlock) + assert isinstance(body[2], nodes.Assignment) + + def test_empty_blocks(self): + # Grammar permits LBRACE blockItem* RBRACE — i.e. empty if/elif/else bodies. 
+ src = ('REMAP {\n' + ' if inbound.req.X-A == "a" {\n } elif inbound.req.X-B == "b" {\n' + ' } else {\n }\n}') + ast = _build(src) + ib = ast.body[0].body[0] + assert isinstance(ib, nodes.IfBlock) + assert ib.body == () + assert len(ib.elif_branches) == 1 + assert ib.elif_branches[0].body == () + assert ib.else_body == () class TestLineNumbers: @@ -522,97 +642,97 @@ def setup_method(self): def test_use_directive(self): u = self.ast.body[0] - assert isinstance(u, UseDirective) + assert isinstance(u, nodes.UseDirective) assert u.line == 1 def test_var_section(self): vs = self.ast.body[1] - assert isinstance(vs, VarSection) + assert isinstance(vs, nodes.VarSection) assert vs.line == 2 def test_var_decl(self): vd = self.ast.body[1].declarations[0] - assert isinstance(vd, VarDecl) + assert isinstance(vd, nodes.VarDecl) assert vd.line == 3 def test_procedure_decl(self): pd = self.ast.body[2] - assert isinstance(pd, ProcedureDecl) + assert isinstance(pd, nodes.ProcedureDecl) assert pd.line == 5 def test_proc_param(self): pp = self.ast.body[2].params[0] - assert isinstance(pp, ProcParam) + assert isinstance(pp, nodes.ProcParam) assert pp.line == 5 def test_procedure_body_assignment(self): a = self.ast.body[2].body[0] - assert isinstance(a, Assignment) + assert isinstance(a, nodes.Assignment) assert a.line == 6 def test_section(self): s = self.ast.body[3] - assert isinstance(s, Section) + assert isinstance(s, nodes.Section) assert s.line == 8 def test_assignment(self): a = self.ast.body[3].body[0] - assert isinstance(a, Assignment) + assert isinstance(a, nodes.Assignment) assert a.line == 9 def test_function_call(self): fc = self.ast.body[3].body[1] - assert isinstance(fc, FunctionCall) + assert isinstance(fc, nodes.FunctionCall) assert fc.line == 10 def test_standalone_operator(self): fc = self.ast.body[3].body[2] - assert isinstance(fc, FunctionCall) + assert isinstance(fc, nodes.FunctionCall) assert fc.line == 11 def test_if_block(self): ib = self.ast.body[3].body[3] - 
assert isinstance(ib, IfBlock) + assert isinstance(ib, nodes.IfBlock) assert ib.line == 12 def test_comparison_in_condition(self): cond = self.ast.body[3].body[3].condition - assert isinstance(cond, Comparison) + assert isinstance(cond, nodes.Comparison) assert cond.line == 12 def test_break(self): brk = self.ast.body[3].body[3].body[0] - assert isinstance(brk, Break) + assert isinstance(brk, nodes.Break) assert brk.line == 13 def test_elif_branch(self): eb = self.ast.body[3].body[3].elif_branches[0] - assert isinstance(eb, ElifBranch) + assert isinstance(eb, nodes.ElifBranch) assert eb.line == 14 def test_elif_condition(self): cond = self.ast.body[3].body[3].elif_branches[0].condition - assert isinstance(cond, Comparison) + assert isinstance(cond, nodes.Comparison) assert cond.line == 14 def test_logical_op(self): cond = self.ast.body[3].body[4].condition - assert isinstance(cond, LogicalOp) + assert isinstance(cond, nodes.LogicalOp) assert cond.line == 19 def test_not_op(self): cond = self.ast.body[3].body[5].condition - assert isinstance(cond, NotOp) + assert isinstance(cond, nodes.NotOp) assert cond.line == 22 def test_bool_literal(self): cond = self.ast.body[3].body[6].condition - assert isinstance(cond, BoolLiteral) + assert isinstance(cond, nodes.BoolLiteral) assert cond.line == 25 def test_ident_condition(self): cond = self.ast.body[3].body[7].condition - assert isinstance(cond, IdentCondition) + assert isinstance(cond, nodes.IdentCondition) assert cond.line == 28 @@ -648,18 +768,18 @@ def test_nested_ifs_from_test_data(self): } }''' ast = _build(src) - sections = [i for i in ast.body if isinstance(i, Section)] + sections = [i for i in ast.body if isinstance(i, nodes.Section)] assert len(sections) == 1 s = sections[0] assert s.type == "REMAP" # Top-level if block outer = s.body[0] - assert isinstance(outer, IfBlock) + assert isinstance(outer, nodes.IfBlock) # Body: assignment + nested if - assert isinstance(outer.body[0], Assignment) - assert 
isinstance(outer.body[1], IfBlock) + assert isinstance(outer.body[0], nodes.Assignment) + assert isinstance(outer.body[1], nodes.IfBlock) middle = outer.body[1] # Middle if has elif and else @@ -668,14 +788,14 @@ def test_nested_ifs_from_test_data(self): # Deepest nested if (3 levels) inner = middle.body[1] - assert isinstance(inner, IfBlock) - assert isinstance(inner.condition, LogicalOp) - assert inner.condition.operator == "||" + assert isinstance(inner, nodes.IfBlock) + assert isinstance(inner.condition, nodes.LogicalOp) + assert inner.condition.operator == nodes.BoolOp.OR # Outer elif has modifiers assert len(outer.elif_branches) == 1 elif_cond = outer.elif_branches[0].condition - assert isinstance(elif_cond, Comparison) + assert isinstance(elif_cond, nodes.Comparison) assert elif_cond.modifiers == ("NOCASE", "PRE") def test_http_cntl_booleans(self): @@ -698,8 +818,8 @@ def test_ip_range_condition(self): }''' ast = _build(src) cond = ast.body[0].body[0].condition - assert isinstance(cond, Comparison) - assert cond.operator == "in" + assert isinstance(cond, nodes.Comparison) + assert cond.operator == nodes.CmpOp.IN assert len(cond.right) == 2 def test_set_membership_with_modifier(self): @@ -711,9 +831,10 @@ def test_set_membership_with_modifier(self): }''' ast = _build(src) cond = ast.body[0].body[0].condition - assert isinstance(cond, Comparison) - assert cond.operator == "in" - assert cond.right == (LiteralStringValue(raw="php"), LiteralStringValue(raw="php3"), LiteralStringValue(raw="php4")) + assert isinstance(cond, nodes.Comparison) + assert cond.operator == nodes.CmpOp.IN + assert cond.right == ( + nodes.LiteralStringValue(text="php"), nodes.LiteralStringValue(text="php3"), nodes.LiteralStringValue(text="php4")) assert cond.modifiers == ("EXT",) def test_debug_pattern_for_lint_rules(self): @@ -727,14 +848,14 @@ def test_debug_pattern_for_lint_rules(self): body = ast.body[0].body # set-debug() function call - assert isinstance(body[0], FunctionCall) + 
assert isinstance(body[0], nodes.FunctionCall) assert body[0].name == "set-debug" # TXN_DEBUG assignment with True - assert isinstance(body[1], Assignment) - assert body[1].target == Target.from_dotted("http.cntl.TXN_DEBUG") + assert isinstance(body[1], nodes.Assignment) + assert body[1].target == nodes.Target.from_dotted("http.cntl.TXN_DEBUG") assert body[1].value is True # Regular assignment (not flagged) - assert isinstance(body[2], Assignment) + assert isinstance(body[2], nodes.Assignment) assert body[2].target.namespace == "inbound.req" diff --git a/tools/hrw4u/tests/test_hrw4u_ast.py b/tools/hrw4u/tests/test_hrw4u_ast.py new file mode 100644 index 00000000000..726576ab744 --- /dev/null +++ b/tools/hrw4u/tests/test_hrw4u_ast.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import subprocess +import sys +from pathlib import Path + +SAMPLE = 'REMAP {\n inbound.req.X-Foo = "bar";\n set-debug();\n}\n' + + +def run_hrw4u_ast(args: list[str], stdin: str | None = None) -> subprocess.CompletedProcess: + script = Path("scripts/hrw4u-ast").resolve() + cmd = [sys.executable, str(script)] + args + return subprocess.run(cmd, capture_output=True, text=True, input=stdin) + + +def test_default_stage_emits_ast() -> None: + result = run_hrw4u_ast([], stdin=SAMPLE) + assert result.returncode == 0, result.stderr + assert "HRW4UAST" in result.stdout + assert "Section" in result.stdout + assert "set-debug" in result.stdout + + +def test_explicit_ast_stage() -> None: + result = run_hrw4u_ast(["--stage", "ast"], stdin=SAMPLE) + assert result.returncode == 0, result.stderr + assert "HRW4UAST" in result.stdout + + +def test_cst_stage() -> None: + result = run_hrw4u_ast(["--stage", "cst"], stdin=SAMPLE) + assert result.returncode == 0, result.stderr + # toStringTree produces parenthesized rule names; "program" is the start rule. + assert "program" in result.stdout + assert "HRW4UAST" not in result.stdout + + +def test_unknown_stage_errors() -> None: + result = run_hrw4u_ast(["--stage", "bogus"], stdin=SAMPLE) + assert result.returncode != 0 + assert "invalid choice" in result.stderr + + +def test_syntax_error_returns_nonzero() -> None: + result = run_hrw4u_ast([], stdin="REMAP { this is not valid ;") + assert result.returncode == 1 + assert "Parse failed" in result.stderr + + +def test_reads_from_file_argument(tmp_path: Path) -> None: + src = tmp_path / "sample.hrw4u" + src.write_text(SAMPLE) + result = run_hrw4u_ast([str(src)]) + assert result.returncode == 0, result.stderr + assert "HRW4UAST" in result.stdout