In [13]:
from lark import Lark, Tree
from lexer import Lexer as Lexer_

In [551]:
grammar = """
?start: statements

statements: statement+
statement: print | declare | exception_handling | return | control | expression_statement

print: PRINT_KEYWORD ROUND_OPEN print_args ROUND_CLOSE END_OF_STATEMENT
print_args: expression (COMMA expression)* | 

expression_statement: expression END_OF_STATEMENT 
                    | assignment END_OF_STATEMENT

expression: expression OPERATOR expression -> binary_op 
 | expression COMPARATOR expression -> binary_comp
 | unary_expression 
 | function_call 
 | IDENTIFIER index 
 | ROUND_OPEN expression ROUND_CLOSE 
 | literal 
 | IDENTIFIER COMPOUND_OPERATOR expression 
 | IDENTIFIER DOT_OPERATOR IDENTIFIER expression

unary_expression: UNARY_OPERATOR IDENTIFIER 
| IDENTIFIER UNARY_OPERATOR 
| NOT_OPERATOR IDENTIFIER 
| NOT_OPERATOR ROUND_OPEN expression ROUND_CLOSE

assignment: IDENTIFIER ASSIGNMENT_OPERATOR expression 

index: index SQUARE_OPEN expression SQUARE_CLOSE | 

control: function | if_else | while | do_while | for_loop | break_continue
function: FUNCTION_DECLARATION IDENTIFIER ROUND_OPEN parameters ROUND_CLOSE block
if_else: IF_ELIF ROUND_OPEN expression ROUND_CLOSE block else_temp 
else_temp: ELSE_KEYWORD block | 
while: WHILE_KEYWORD ROUND_OPEN expression ROUND_CLOSE block
do_while: DO_KEYWORD block WHILE_KEYWORD ROUND_OPEN expression ROUND_CLOSE
for_loop: FOR_KEYWORD ROUND_OPEN dec_control_flow END_OF_STATEMENT expression END_OF_STATEMENT for_update ROUND_CLOSE block
for_update: expression | assignment
break_continue: BREAK_CONTINUE END_OF_STATEMENT

dec_control_flow: VARIABLE_DECLARATION IDENTIFIER ASSIGNMENT_OPERATOR expression

declare: tuple_declaration | list_declaration | arr_declaration | exception_declaration | variable_declaration 
tuple_declaration: TUPLE_DECLARATION IDENTIFIER ASSIGNMENT_OPERATOR matrix END_OF_STATEMENT
list_declaration: LIST_DECLARATION IDENTIFIER ASSIGNMENT_OPERATOR matrix END_OF_STATEMENT
arr_declaration: ARR_DECLARATION IDENTIFIER ASSIGNMENT_OPERATOR matrix END_OF_STATEMENT
exception_declaration: EXCEPTION_TYPE IDENTIFIER ASSIGNMENT_OPERATOR IDENTIFIER END_OF_STATEMENT
variable_declaration: VARIABLE_DECLARATION IDENTIFIER variable_declaration_temp ASSIGNMENT_OPERATOR expression variable_declaration_expression_temp END_OF_STATEMENT
variable_declaration_temp: COMMA IDENTIFIER variable_declaration_temp | 
variable_declaration_expression_temp: COMMA expression variable_declaration_expression_temp | 

matrix: matrix_temp | list_content
matrix_temp: SQUARE_OPEN matrix matrix_temp_comma SQUARE_CLOSE | 
matrix_temp_comma: COMMA matrix matrix_temp_comma |
list_content: SQUARE_OPEN expression list_content_temp SQUARE_CLOSE | SQUARE_OPEN SQUARE_CLOSE
list_content_temp: COMMA expression list_content_temp |

exception_handling: try_catch_finally | throw
try_catch_finally: TRY_KEYWORD block CATCH_KEYWORD ROUND_OPEN EXCEPTION_TYPE IDENTIFIER ROUND_CLOSE block FINALLY_KEYWORD block
throw: THROW_KEYWORD EXCEPTION_TYPE ROUND_OPEN print_args ROUND_CLOSE END_OF_STATEMENT

block: CURLY_OPEN statements CURLY_CLOSE | CURLY_OPEN CURLY_CLOSE

function_call: IDENTIFIER ROUND_OPEN argument_temp ROUND_CLOSE -> function_call 
| IDENTIFIER DOT_OPERATOR IDENTIFIER ROUND_OPEN argument_temp ROUND_CLOSE -> inbuilt_function_call

return: RETURN_KEYWORD expression? END_OF_STATEMENT

literal: INTEGER_CONSTANT | DECIMAL_CONSTANT | STRING_LITERAL | BOOLEAN_VALUE | NULL_KEYWORD

argument_temp: expression (COMMA expression)* |

parameters: parameter parameters_temp |
parameter: VARIABLE_DECLARATION IDENTIFIER | LIST_DECLARATION IDENTIFIER | ARR_DECLARATION IDENTIFIER | TUPLE_DECLARATION IDENTIFIER
parameters_temp: COMMA parameter parameters_temp | 

%declare STRING_LITERAL BOOLEAN_VALUE COMMA FUNCTION_DECLARATION BREAK_CONTINUE IF_ELIF ELSE_KEYWORD WHILE_KEYWORD DO_KEYWORD FOR_KEYWORD PRINT_KEYWORD RETURN_KEYWORD VARIABLE_DECLARATION LIST_DECLARATION ARR_DECLARATION TUPLE_DECLARATION EXCEPTION_TYPE NULL_KEYWORD TRY_KEYWORD CATCH_KEYWORD FINALLY_KEYWORD THROW_KEYWORD KEYWORD NOT_OPERATOR ASSIGNMENT_OPERATOR OPERATOR COMPOUND_OPERATOR UNARY_OPERATOR COMPARATOR DOT_OPERATOR PUNCTUATION END_OF_STATEMENT ROUND_OPEN ROUND_CLOSE CURLY_OPEN CURLY_CLOSE SQUARE_OPEN SQUARE_CLOSE DECIMAL_CONSTANT INTEGER_CONSTANT IDENTIFIER QUOTATION ERROR
%import common.WS
%ignore WS
"""

In [552]:
from lark.lexer import Lexer, Token

class MyLexer(Lexer):
    def __init__(self, lexer_conf):
        pass

    def lex(self, data):
        lexer = Lexer_(source_code=data)
        lexer.tokenize()
        tokens = lexer.get_tokens()
        for type, value in tokens:
            yield Token(type, value)

In [553]:
parser = Lark(grammar, start='start', lexer=MyLexer, parser='lalr')

input_string = """
print(2,3);
"""

def visualize_tree(tree, depth=0):
    if isinstance(tree, Tree):
        print("  " * depth + "+-" + str(tree.data))
        for child in tree.children[:-1]:
            print("  " * (depth + 1) + "|")
            visualize_tree(child, depth + 1)
        if tree.children:
            print("  " * (depth + 1) + "|")
            visualize_tree(tree.children[-1], depth + 1)
    else:
        print("  " * depth + "+-" + str(tree))

try:
    tree = parser.parse(input_string)
    visualize_tree(tree)
    print("Parsing successful.")
except Exception as e:
    print("Parsing failed:", e)

['print', '(', '2', ',', '3', ')', ';']
+-statements
  |
  +-statement
    |
    +-print
      |
      +-print
      |
      +-(
      |
      +-print_args
        |
        +-expression
          |
          +-literal
            |
            +-2
        |
        +-,
        |
        +-expression
          |
          +-literal
            |
            +-3
      |
      +-)
      |
      +-;
Parsing successful.


In [343]:
import logging
from typing import List
logging.basicConfig()
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

class ASTNode:
    """Abstract base class for abstract sequence of sequence of sums"""
    def __init__(self):
        """This is an abstract class and should not be instantiated"""
        this_class = self.__class__.__name__
        if this_class == "ASTNode":
            raise NotImplementedError("ASTNode is an abstract class and should not be instantiated")
        else:
            raise NotImplementedError(f"{this_class} is missing a constructor method")

In [344]:
class Statement(ASTNode):
    pass

class Statements(ASTNode):
    def __init__(self):
        self.statements: List[Statement] = []
    
    def append(self, statement: Statement):
        self.statements.append(statement)

    def __str__(self) -> str:
        el_strs = ", ".join(str(e) for e in self.statements)
        return f"[{el_strs}]"

    def __repr__(self):
        return f"statements({repr(self.statements)})"

In [345]:
from dataclasses import dataclass
import lark

In [571]:
@dataclass
class Declare(Statement):
    pass

@dataclass
class TupleDeclaration(Declare):
    def __init__(self, ):
        self.name: str
        self.value: List[Expression]

@dataclass
class ListDeclaration(Declare):
    pass

@dataclass
class ArrDeclaration(Declare):
    pass

@dataclass
class ExceptionDeclaration(Declare):
    pass

@dataclass
class VariableDeclaration(Declare):
    pass

@dataclass
class ExceptionHandling(Statement):
    pass

@dataclass
class TryCatchFinally(ExceptionHandling):
    pass

@dataclass
class Throw(ExceptionHandling):
    pass

@dataclass
class Return(Statement):
    pass    

@dataclass
class Control(Statement):
    pass

@dataclass
class Function(Control):
    pass

@dataclass
class IfElse(Control):
    pass

@dataclass
class While(Control):
    pass

@dataclass
class DoWhile(Control):
    pass

@dataclass
class ForLoop(Control):
    pass

@dataclass
class BreakContinue(Control):
    pass

class ExpressionStatement(Statement):
    def __init__(self, expression):
        self.expression = expression

class Expression(ASTNode):
    def __init__(self):
        self.children = []

    def add_child(self, child):
        self.children.append(child)

class Assignment(ASTNode):
    def __init__(self, identifier, expression):
        self.identifier = identifier
        self.expression = expression

    def __str__(self) -> str:
        return f"{self.identifier} = {self.expression}"
    
    def __repr__(self):
        return f"assignment({repr(self.identifier)}, {repr(self.expression)})"

class BinaryOp(Expression):
    def __init__(self, operator, left: Expression, right: Expression):
        super().__init__()
        self.operator = operator
        self.add_child(left)
        self.add_child(right)

    def __str__(self) -> str:
        return f"{self.children[0]} {self.operator} {self.children[1]}"
    
    def __repr__(self):
        return f"binary_op({repr(self.operator)}, {repr(self.children[0])}, {repr(self.children[1])})"

class UnaryExpression(Expression):
    def __init__(self, operator, operand):
        super().__init__()
        self.operator = operator
        self.add_child(operand)

    def __str__(self) -> str:
        return f"{self.operator}{self.children[0]}"
    
    def __repr__(self):
        return f"unary_expression({repr(self.operator)}, {repr(self.children[0])})"

class FunctionCall(Expression):
    def __init__(self, name, arguments):
        super().__init__()
        self.name = name
        self.arguments = arguments

    def __str__(self) -> str:
        return f"{self.name}({self.arguments})"
    
    def __repr__(self):
        return f"function_call({repr(self.name)}, {repr(self.arguments)})"
    
class InbuiltFunctionCall(Expression):
    def __init__(self, base, name, arguments):
        super().__init__()
        self.base = base
        self.name = name
        self.arguments = arguments

    def __str__(self) -> str:
        return f"{self.base}.{self.name}({self.arguments})"
    
    def __repr__(self):
        return f"inbuilt_function_call({repr(self.base)}, {repr(self.name)}, {repr(self.arguments)})"

class FunctionArgumentList(FunctionCall):
    def __init__(self):
        self.arguments = []

    def add_argument(self, argument):
        self.arguments.append(argument)

    def __str__(self) -> str:
        args = ", ".join(str(arg) for arg in self.arguments)
        return f"({args})"
    
    def __repr__(self):
        return f"function_argument_list({repr(self.arguments)})"

class FunctionArgument(FunctionArgumentList):
    def __init__(self, expression):
        self.expression = expression

    def __str__(self) -> str:
        return str(self.expression)
    
    def __repr__(self):
        return f"function_argument({repr(self.expression)})"

class Identifier(Expression):
    def __init__(self, name):
        super().__init__()
        self.name = name

    def __str__(self) -> str:
        return self.name
    
    def __repr__(self):
        return f"identifier({repr(self.name)})"

class Index(Expression):
    def __init__(self, base, index):
        super().__init__()
        self.base = base
        self.index = index

    def __str__(self) -> str:
        return f"{self.base}[{self.index}]"
    
    def __repr__(self):
        return f"index({repr(self.base)}, {repr(self.index)})"

class Literal(Expression):
    def __init__(self, value):
        super().__init__()
        self.value = value.value

    def __str__(self) -> str:
        return str(self.value)
    
    def __repr__(self):
        return f"literal({repr(self.value)})"
    
class Operator(Expression):
    def __init__(self, operator):
        super().__init__()
        self.operator = operator

    def __str__(self) -> str:
        return self.operator
    
    def __repr__(self):
        return f"operator({repr(self.operator)})"
    
class UnaryOperator(Expression):
    def __init__(self, operator):
        super().__init__()
        self.operator = operator

    def __str__(self) -> str:
        return self.operator
    
    def __repr__(self):
        return f"unary_operator({repr(self.operator)})"
    
class Comparator(Expression):
    def __init__(self, comparator):
        super().__init__()
        self.comparator = comparator

    def __str__(self) -> str:
        return self.comparator
    
    def __repr__(self):
        return f"comparator({repr(self.comparator)})"
    
class PrintStatement(ASTNode):
    def __init__(self, arguments):
        self.arguments = arguments

    def __str__(self) -> str:
        return f"print({self.arguments})"
    
    def __repr__(self):
        return f"print_statement({repr(self.arguments)})"

class PrintArgumentList(ASTNode):
    def __init__(self):
        self.arguments = []

    def add_argument(self, argument):
        self.arguments.append(argument)

    def __str__(self) -> str:
        args = ", ".join(str(arg) for arg in self.arguments)
        return f"({args})"
    
    def __repr__(self):
        return f"print_argument_list({repr(self.arguments)})"

class PrintArgument(ASTNode):
    def __init__(self, expression):
        self.expression = expression

    def __str__(self) -> str:
        return str(self.expression)
    
    def __repr__(self):
        return f"print_argument({repr(self.expression)})"
    

In [572]:
import lark

import logging
logging.basicConfig()
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

In [573]:
class Transformer(lark.Transformer):

    def __init__(self):
        self.temp_args = []
        self.args_list = PrintArgumentList()
        self.args_list_function = FunctionArgumentList()

    def print(self, children):
        log.debug(f"Processing 'print' with {children}")
        return (children[2])

    def print_args(self, children):
        log.debug(f"Processing 'print_args' with {children}")
        return children[0]

    def print_args_temp(self, children):
        log.debug(f"Processing 'print_args_temp' with {children}")
        if len(children) > 1:
            self.temp_args.append(children[1])
        if len(children) > 2:
            return children[2]
        else:
            return []

    def expression(self, children):
        log.debug(f"Processing 'expression' with {children}")
        return children

    def statement(self, children):
        log.debug(f"Processing 'statement' with {children}")
        return children[0]

    def statements(self, children):
        log.debug(f"Processing 'statements' with {children}")
        stmts = Statements()
        for child in children:
            stmts.append(child)
        return stmts
    
    def tuple_declaration(self, args):
        log.debug(f"Processing Tuple Declaration with {args}")
        return TupleDeclaration()
    
    def list_declaration(self, args):
        log.debug(f"Processing List Declaration with {args}")
        return ListDeclaration()
    
    def arr_declaration(self, args):
        log.debug(f"Processing Arr Declaration with {args}")
        return ArrDeclaration()
    
    def exception_declaration(self, args):
        log.debug(f"Processing Exception Declaration with {args}")
        return ExceptionDeclaration()

    def variable_declaration(self, args):
        log.debug(f"Processing Variable Declaration with {args}")
        return VariableDeclaration()
    
    def exception_handling(self, args):
        log.debug(f"Processing Exception Handling with {args}")
        return ExceptionHandling()
    
    def try_catch_finally(self, args):
        log.debug(f"Processing Try Catch Finally with {args}")
        return TryCatchFinally()
    
    def throw(self, args):
        log.debug(f"Processing Throw with {args}")
        return Throw()
    
    def return_(self, args):
        log.debug(f"Processing Return with {args}")
        return Return()
    
    def function(self, args):
        log.debug(f"Processing Function with {args}")
        return Function()
    
    def if_else(self, args):
        log.debug(f"Processing If Else with {args}")
        return IfElse()
    
    def while_(self, args):
        log.debug(f"Processing While with {args}")
        return While()
    
    def do_while(self, args):
        log.debug(f"Processing Do While with {args}")
        return DoWhile()
    
    def for_loop(self, args):
        log.debug(f"Processing For Loop with {args}")
        return ForLoop()
    
    def break_continue(self, args):
        log.debug(f"Processing Break Continue with {args}")
        return BreakContinue()
    
    def expression_statement(self, args):
        log.debug(f'expression_statement - {args}')
        return args[0]
    
    def expression(self, children):
        log.debug(f'expression - {children}')
        if hasattr(children[0], 'value') and children[0].value == '(' and hasattr(children[-1], 'value') and children[-1].value == ')':
            return self.expression(children[1:-1])
        if len(children) == 1:
            return children[0]
        elif len(children) == 3:  # binary operation
            operator = children[1]
            return BinaryOp(operator, children[0], children[2])
        elif len(children) == 2:  # unary expression
            operator = children[0]
            return UnaryExpression(operator, children[1])
        elif len(children) == 4:  # function call
            function_name = children[0]
            arguments = children[2]
            return FunctionCall(function_name, arguments)
        elif len(children) == 5:  # identifier with index
            identifier = children[0]
            index = children[2]
            return Index(identifier, index)

    def IDENTIFIER(self, token):
        log.debug(f'IDENTIFIER - {token}')
        return Identifier(token.value)

    def literal(self, children):
        log.debug(f'literal - {children}')
        return Literal(children[0])

    def index(self, children):
        log.debug(f'index - {children}')
        return Index(children[0], children[1])
    
    def assignment(self, children):
        return Assignment(children[0], children[2])
    
    def binary_op(self, children):
        log.debug(f'binary_op - {children}')
        return BinaryOp(children[1], children[0], children[2])
    
    def unary_expression(self, children):
        if len(children) == 2:
            return UnaryExpression(children[0], children[1])
        else:
            return UnaryExpression(children[0], children[2])
        
    def function_call(self, children):
        log.debug(f'function_call - {children}')
        self.args_list_function = FunctionArgumentList()
        return FunctionCall(children[0], children[2])
    
    def inbuilt_function_call(self, children):
        log.debug(f'inbuilt_function_call - {children}')
        self.args_list_function = FunctionArgumentList()
        return InbuiltFunctionCall(children[0], children[2], children[4])
    
    def argument_temp(self, children):
        log.debug(f'argument_temp - {children}')
        if children:
            for i in range(0, len(children), 2):
                self.args_list_function.add_argument(children[i])
        return self.args_list_function
    
    def OPERATOR(self, token):
        return Operator(token.value)
    
    def COMPARATOR(self, token):
        return Comparator(token.value)
    
    def UNARY_OPERATOR(self, token):
        return UnaryOperator(token.value)

    def print(self, children):
        log.debug(f'print - {children}')
        self.args_list = PrintArgumentList()
        return PrintStatement(children[2])

    def print_args(self, children):
        log.debug(f'print_args - {children}')
        if children:
            for i in range(0, len(children), 2):
                self.args_list.add_argument(children[i])
        return self.args_list

In [574]:
tree = parser.parse("""
a = l.slice(1, 2);
""")

['a', '=', 'l', '.', 'slice', '(', '1', ',', '2', ')', ';']


In [575]:
visualize_tree(tree)

+-statements
  |
  +-statement
    |
    +-expression_statement
      |
      +-assignment
        |
        +-a
        |
        +-=
        |
        +-expression
          |
          +-inbuilt_function_call
            |
            +-l
            |
            +-.
            |
            +-slice
            |
            +-(
            |
            +-argument_temp
              |
              +-expression
                |
                +-literal
                  |
                  +-1
              |
              +-,
              |
              +-expression
                |
                +-literal
                  |
                  +-2
            |
            +-)
      |
      +-;


In [576]:
transformer = Transformer()
ast = transformer.transform(tree)

DEBUG:__main__:IDENTIFIER - a
DEBUG:__main__:IDENTIFIER - l
DEBUG:__main__:IDENTIFIER - slice
DEBUG:__main__:literal - [Token('INTEGER_CONSTANT', '1')]
DEBUG:__main__:expression - [literal('1')]
DEBUG:__main__:literal - [Token('INTEGER_CONSTANT', '2')]
DEBUG:__main__:expression - [literal('2')]
DEBUG:__main__:argument_temp - [literal('1'), Token('COMMA', ','), literal('2')]
DEBUG:__main__:inbuilt_function_call - [identifier('l'), Token('DOT_OPERATOR', '.'), identifier('slice'), Token('ROUND_OPEN', '('), function_argument_list([literal('1'), literal('2')]), Token('ROUND_CLOSE', ')')]
DEBUG:__main__:expression - [inbuilt_function_call(identifier('l'), identifier('slice'), function_argument_list([literal('1'), literal('2')]))]
DEBUG:__main__:expression_statement - [assignment(identifier('a'), inbuilt_function_call(identifier('l'), identifier('slice'), function_argument_list([literal('1'), literal('2')]))), Token('END_OF_STATEMENT', ';')]
DEBUG:__main__:Processing 'statement' with [assignm

In [577]:
tree

Tree(Token('RULE', 'statements'), [Tree(Token('RULE', 'statement'), [Tree(Token('RULE', 'expression_statement'), [Tree(Token('RULE', 'assignment'), [Token('IDENTIFIER', 'a'), Token('ASSIGNMENT_OPERATOR', '='), Tree(Token('RULE', 'expression'), [Tree('inbuilt_function_call', [Token('IDENTIFIER', 'l'), Token('DOT_OPERATOR', '.'), Token('IDENTIFIER', 'slice'), Token('ROUND_OPEN', '('), Tree(Token('RULE', 'argument_temp'), [Tree(Token('RULE', 'expression'), [Tree(Token('RULE', 'literal'), [Token('INTEGER_CONSTANT', '1')])]), Token('COMMA', ','), Tree(Token('RULE', 'expression'), [Tree(Token('RULE', 'literal'), [Token('INTEGER_CONSTANT', '2')])])]), Token('ROUND_CLOSE', ')')])])]), Token('END_OF_STATEMENT', ';')])])])

In [578]:
ast

statements([assignment(identifier('a'), inbuilt_function_call(identifier('l'), identifier('slice'), function_argument_list([literal('1'), literal('2')])))])

In [579]:
visualize_tree(ast)

+-[a = l.slice((1, 2))]


In [508]:
repr(ast)

"statements([print_statement(print_argument_list([binary_op(operator('+'), literal('2'), literal('3'))]))])"

In [374]:
visualize_tree(tree)

+-start
  |
  +-statements
    |
    +-statement
      |
      +-print
        |
        +-print
        |
        +-(
        |
        +-print_args
          |
          +-expression
            |
            +-literal
              |
              +-3
          |
          +-print_args_temp
        |
        +-)
        |
        +-;
    |
    +-statement
      |
      +-print
        |
        +-print
        |
        +-(
        |
        +-print_args
          |
          +-expression
            |
            +-literal
              |
              +-3
          |
          +-print_args_temp
            |
            +-,
            |
            +-expression
              |
              +-literal
                |
                +-4
            |
            +-print_args_temp
        |
        +-)
        |
        +-;


In [230]:
tree = parser.parse("""
var a = 3;
""")

['var', 'a', '=', '3', ';']


In [231]:
ast = transformer.transform(tree)
visualize_tree(ast)

DEBUG:__main__:Processing Variable Declaration with [Token('VARIABLE_DECLARATION', 'var'), Token('IDENTIFIER', 'a'), Tree(Token('RULE', 'variable_declaration_temp'), []), Token('ASSIGNMENT_OPERATOR', '='), Tree(Token('RULE', 'expression'), [Tree(Token('RULE', 'literal'), [literal(3)])]), Tree(Token('RULE', 'variable_declaration_expression_temp'), []), Token('END_OF_STATEMENT', ';')]


+-start
  |
  +-statements
    |
    +-statement
      |
      +-declare
        |
        +-VariableDeclaration()
