In [28]:
from lark import Lark, Transformer, v_args
from pathlib import Path
import json

In [29]:
# Loading the grammar file
GRAMMAR_PATH = Path("dsl/grammar.lark")

with open(GRAMMAR_PATH, "r") as f:
    grammar_text = f.read()

dsl_parser = Lark(grammar_text, start="start", parser="lalr")

In [30]:
# Base class for all AST nodes
class ASTNode:
    def to_dict(self):
        raise NotImplementedError

In [31]:
class Series(ASTNode):
    def __init__(self, value):
        self.type = "series"
        self.value = value

    def to_dict(self):
        return {"type": "series", "value": self.value}


In [32]:
class Indicator(ASTNode):
    def __init__(self, text):
        self.type = "indicator"
        name, args = text.split("(", 1)
        self.name = name.lower()
        args = args.rstrip(")")
        self.params = [a.strip() for a in args.split(",")]

    def to_dict(self):
        return {
            "type": "indicator",
            "name": self.name,
            "params": self.params
        }


In [33]:
class BinaryOp(ASTNode):
    def __init__(self, left, op, right):
        self.type = "binary_op"
        self.left = left
        self.right = right
        self.op = op

    def to_dict(self):
        return {
            "type": self.type,
            "left": self.left.to_dict(),
            "op": self.op,
            "right": self.right.to_dict()
        }


In [34]:
class CrossOp(ASTNode):
    def __init__(self, left, direction, right):
        self.type = "cross"
        self.left = left
        self.right = right
        self.dir = direction

    def to_dict(self):
        return {
            "type": "cross",
            "dir": self.dir,
            "left": self.left.to_dict(),
            "right": self.right.to_dict()
        }


In [38]:
class DSLTransformer(Transformer):

    # NEW: handle the root node
    def start(self, items):
        return items   # <-- this fixes the TypeError

    def entry_section(self, items):
        return ("entry", items[0])

    def exit_section(self, items):
        return ("exit", items[0])

    def comparison(self, items):
        left, op, right = items
        return BinaryOp(left, op.value, right)

    def and_op(self, items):
        return {"type": "and", "left": items[0], "right": items[1]}

    def or_op(self, items):
        return {"type": "or", "left": items[0], "right": items[1]}

    def cross_above(self, items):
        return CrossOp(items[0], "above", items[1])

    def cross_below(self, items):
        return CrossOp(items[0], "below", items[1])

    def indicator(self, items):
        return Indicator(str(items[0]))

    def ident(self, items):
        return Series(str(items[0]))

    def number(self, items):
        return Series(str(items[0]))  # treat number as literal node

    def yest_high(self, _):
        return Series("yesterday_high")


In [39]:
def parse_dsl(text: str):
    try:
        tree = dsl_parser.parse(text)
        ast_items = DSLTransformer().transform(tree)

        final = {"entry": [], "exit": []}

        for section, rule in ast_items:
            final[section].append(
                rule.to_dict() if hasattr(rule, "to_dict") else rule
            )

        return final

    except Exception as e:
        print("DSL Parsing Error:", e)
        raise e


In [40]:
dsl_text = """ENTRY: close > SMA(close,20) AND volume > 1000000
EXIT: RSI(close,14) < 30
""".strip()


ast = parse_dsl(dsl_text)
print(ast)


{'entry': [{'type': 'and', 'left': <__main__.BinaryOp object at 0x0000023E53FE77C0>, 'right': <__main__.BinaryOp object at 0x0000023E53FE7FD0>}], 'exit': [{'type': 'binary_op', 'left': {'type': 'indicator', 'name': 'rsi', 'params': ['close', '14']}, 'op': '<', 'right': {'type': 'series', 'value': '30'}}]}
