In [1]:
import re, copy

In [2]:
class TokenType:

  def __init__(self, t_type: str, regex: str):
    self.t_type = t_type
    self.regex = regex

class Token:
    def __init__(self):
       pass

    def __init__(self, t_type: TokenType, text: str, pos: int):
      self.t_type = t_type
      self.text = text
      self.pos = pos

    def __str__(self):
      return f'Tt: {self.t_type.t_type} Pos: {self.pos} Text: {self.text}'

    def IsMatch(self, *types):
        return self.t_type.t_type in types

In [3]:
tokenTypeList = [
    TokenType('FUNC', 'fn'),
    TokenType('RET', '=>'),
    TokenType('IDENT', '[a-zA-Z_][a-zA-Z0-9_]*'),
    TokenType('ASSIGN', '='),
    TokenType('REAL', '[0-9]+\.[0-9]+'),
    TokenType('INT', '[0-9]+'),
    TokenType('ADD', '\+'),
    TokenType('SUB', '-'),
    TokenType('MUL', '\*'),
    TokenType('DIV', '/'),
    TokenType('MOD', '%'),
    TokenType('LPAR', '\('),
    TokenType('RPAR', '\)'),
    TokenType('SPACE', '[ \\n\\t\\r]+'),
]
EndToken = Token(TokenType('END', '^[^.]$'), 'END', -1)

In [4]:
class Lexer:

  def __init__(self, code: str):
    self.code = code
    self.tokens = []
    self.pos = 0

  def run(self):
    while(self.nextToken()):
      pass
    return self

  def display(self):
    return list(map(str, filter(lambda x: x.t_type.t_type != 'SPACE', self.tokens)))

  def spaced_display(self):
    return list(map(str, self.tokens))
  
  def filtered(self):
    return list(filter(lambda x: x.t_type.t_type != 'SPACE', self.tokens))

  def nextToken(self) -> bool:
    if self.pos >= len(self.code):
      return False
    for t_type in tokenTypeList:
      regexp = re.compile(t_type.regex)
      captures = regexp.match(self.code[self.pos:])
      if captures != None:
        self.tokens.append(Token(t_type, captures.group(0), self.pos))
        self.pos += len(captures.group(0))
        return True
    raise Exception(f'Error occured on position {self.pos}. Text: ...{self.code[self.pos: self.pos + 10]}...')



In [5]:
Lexer('x_1 = (2 / (2 + 3.33) * 4) - -6').run().display()

['Tt: IDENT Pos: 0 Text: x_1',
 'Tt: ASSIGN Pos: 4 Text: =',
 'Tt: LPAR Pos: 6 Text: (',
 'Tt: INT Pos: 7 Text: 2',
 'Tt: DIV Pos: 9 Text: /',
 'Tt: LPAR Pos: 11 Text: (',
 'Tt: INT Pos: 12 Text: 2',
 'Tt: ADD Pos: 14 Text: +',
 'Tt: REAL Pos: 16 Text: 3.33',
 'Tt: RPAR Pos: 20 Text: )',
 'Tt: MUL Pos: 22 Text: *',
 'Tt: INT Pos: 24 Text: 4',
 'Tt: RPAR Pos: 25 Text: )',
 'Tt: SUB Pos: 27 Text: -',
 'Tt: SUB Pos: 29 Text: -',
 'Tt: INT Pos: 30 Text: 6']

In [6]:
class AstNode:

  def __init__(self, token: Token = None, Child1 = None, Child2 = None):
    self.token = token
    self.childs = []
    if Child1 != None:
      self.childs.append(Child1)
    if Child2 != None:
      self.childs.append(Child2)
  def display(self, indent = 0):
    print('   ' * indent, str(self.token))
    for child in self.childs:
      child.display(indent + 1)

class AstFuncNode(AstNode):

  def __init__(self, token: Token = None, Child1 = None, name: Token = None, params = None):
    super().__init__(token, Child1)
    self.fname = name
    self.params = params
    self.param_names = [p.token.text for p in self.params]

  def get_name(self):
    return self.fname.text
  
  def display(self, indent = 0):
    print('   ' * indent, f'Func: {self.fname.text}')
    print('     ' * indent,'Params: ')
    for node in self.params:
      node.display(indent + 1)
    print('     ' * indent,'Body: ')
    super().display(indent + 1)


In [7]:
# func -> "fn" fn-name IDENT* "=>" result
# fn-name -> IDENT

# func-call -> IDENT result* ### (certain number of times) ###

# NUMBER -> <num>
# IDENT -> <ident>
# assign -> IDENT "=" result
# group -> "(" result ")" | NUMBER | IDENT | func-call
# exp -> -group | group
# mult -> exp ( ( "*" | "/" | "%" ) exp )*
# add -> mult ( ( "+" | "-" ) mult )*
# result -> add | assign
# state -> result | func

In [8]:
class Parser:

  def __init__(self, code: str, existed_func = None):
    self.tokens = Lexer(code).run().filtered()
    self.end = len(self.tokens)
    self.pos = 0
    self.funcs = []
    self.fnames = []
    if existed_func  != None:
      self.funcs = existed_func
      self.fnames = list(map(AstFuncNode.get_name, self.funcs))
  
  def nextToken(self) -> Token:
    if self.pos == self.end:
      return EndToken
    self.pos += 1 # maybe not a good decision
    return self.tokens[self.pos - 1]

  def curToken(self) -> Token:
    if self.pos == self.end:
      return EndToken
    return self.tokens[self.pos]
  
  def prevToken(self) -> Token:
    if self.pos - 1 >= self.end:
      return EndToken
    return self.tokens[self.pos - 1]

  # IDENT -> <ident>
  def Ident(self) -> AstNode:
    ident = self.nextToken()
    if not ident.IsMatch('IDENT'):
      raise Exception(f'Not an identifier: {ident.text}')
    else:
      return AstNode(token = ident)

  # NUMBER -> <num>
  def Number(self) -> AstNode:
    number = self.nextToken()
    if not number.IsMatch('REAL', 'INT'):
      raise Exception(f'Not a number: {number.text}')
    else:
      return AstNode(token = number)

  # func-call -> IDENT result* ### (certain number of times) ###
  def Func_call(self) -> AstFuncNode:
    ident = self.nextToken().text
    func_node = copy.deepcopy(self.funcs[self.fnames.index(ident)])
    params = [self.Result() for i in range(len(func_node.params))]
    func_node.params = params
    return func_node

  #group -> "(" result ")" | NUMBER | IDENT | func-call 
  def Group(self) -> AstNode:
    if self.curToken().IsMatch('LPAR'):
      lpar = self.nextToken()
      result = self.Result()
      rpar = self.nextToken()
      if rpar.IsMatch('RPAR'):
        return result
      else:
        raise Exception(f'No RPAR found. Tokens: {list(map(str,self.tokens[self.pos - 2:self.pos + 1]))}')
    elif self.curToken().IsMatch('IDENT'):
      if self.curToken().text in self.fnames:
        return self.Func_call()
      return self.Ident()
    else:
      return self.Number()

  # exp -> "-"group | group
  def Exp(self) -> AstNode:
    if self.curToken().IsMatch('SUB'):
      cur_pos = self.nextToken().pos
      if cur_pos == self.curToken().pos - 1:
        return AstNode(token = self.prevToken(), Child1 = self.Group())
      else:
        raise Exception(f'Bad unar operation: {list(map(str, self.tokens[self.pos - 1:self.pos + 1]))}')
    else:
      return self.Group()

  # mult -> exp ( ( "*" | "/" | "%" ) exp )*
  def Mult(self) -> AstNode:
    result = self.Exp()
    while self.curToken().IsMatch('MUL', 'DIV', 'MOD'):
      result = AstNode(token = self.nextToken(), Child1 = result, Child2 = self.Exp())
    return result

  # add -> mult ( ( "+" | "-" ) mult )*
  def Add(self) -> AstNode:
    result = self.Mult()
    while self.curToken().IsMatch('ADD', 'SUB'):
      result = AstNode(token = self.nextToken(), Child1 = result, Child2 = self.Mult())
    return result

  # assign -> IDENT "=" result
  def Assign(self) -> AstNode:
    ident = self.Ident()
    assign = self.nextToken()
    if not assign.IsMatch('ASSIGN'):
      raise Exception('Bad assignment')
    value = self.Result()
    return AstNode(token = assign, Child1 = ident, Child2 = value)

  # func -> "fn" fn-name IDENT* "=>" result
  def Func(self) -> AstFuncNode:
    fn_token = self.nextToken()
    fn_name = self.Func_name().token
    params = []
    while self.pos != self.end and not self.curToken().IsMatch('RET'):
      params.append(self.Ident())
    if not self.nextToken().IsMatch('RET'):
      raise Exception(f'Empty function body, func: {fn_name}')
    fn_body = self.Result()
    res = AstFuncNode(token = fn_token, name = fn_name, params = params, Child1 = fn_body)
    if len(res.param_names) != len(set(res.param_names)):
      raise Exception('Function\'s declaration includes duplicate variable names!')
    return res
  
  # fn-name -> IDENT
  def Func_name(self) -> AstNode:
    return self.Ident()

  # result -> add | assign
  def Result(self) -> AstNode:
      if len(self.tokens[self.pos:]) > 1 and self.tokens[self.pos + 1].IsMatch('ASSIGN'):
        return self.Assign()
      else:
        return self.Add()
  
  # state -> result | func
  def State(self) -> AstNode:
    if self.curToken().IsMatch('FUNC'):
      return self.Func()
    else:
      return self.Result()
    
  def parse(self) -> AstNode:
    result = self.State()
    if self.pos != self.end:
      raise Exception(f'All tokens must be used. Unused part: {list(map(str, self.tokens[self.pos:]))}')
    else:
      return result


In [9]:
func = Parser('fn avg x y => (x + y) / 2').parse()
ast = Parser('1 + avg avg 4 2 3', [func]).parse()
func.display()

 Func: avg
 Params: 
    Tt: IDENT Pos: 7 Text: x
    Tt: IDENT Pos: 9 Text: y
 Body: 
    Tt: FUNC Pos: 0 Text: fn
       Tt: DIV Pos: 22 Text: /
          Tt: ADD Pos: 17 Text: +
             Tt: IDENT Pos: 15 Text: x
             Tt: IDENT Pos: 19 Text: y
          Tt: INT Pos: 24 Text: 2


In [10]:
class Evaluate:
  def __init__(self, root: AstNode, var_dict = None, funcs = None):
    self.root = root
    if var_dict == None:
      var_dict = dict()
    if funcs == None:
      funcs = dict() 
    self.vars = var_dict
    self.funcs = funcs

  def eval(self, node: AstNode = None) -> float:
    if node == None:
      node = self.root
    t_type = node.token.t_type.t_type
    if t_type == 'ADD':
        return self.eval(node.childs[0]) + self.eval(node.childs[1])
    elif t_type == 'ASSIGN':
        if node.childs[0].token.IsMatch('IDENT'):
          self.vars[node.childs[0].token.text] = self.eval(node.childs[1])
          return self.vars[node.childs[0].token.text]
        else:
          raise Exception(f'ERROR: Invalid assignment. \"{node.childs[0].token.text}\" is not a variable.')
    elif t_type == 'SUB':
        if len(node.childs) == 1:
          return -self.eval(node.childs[0])
        else:
          return self.eval(node.childs[0]) - self.eval(node.childs[1])
    elif t_type == 'MUL':
        return self.eval(node.childs[0]) * self.eval(node.childs[1])
    elif t_type == 'DIV':
        return self.eval(node.childs[0]) / self.eval(node.childs[1])
    elif t_type == 'MOD':
        return self.eval(node.childs[0]) % self.eval(node.childs[1])
    elif t_type == 'INT':
        return int(node.token.text)
    elif t_type == 'REAL':
        return float(node.token.text)
    elif t_type == 'IDENT':
        if node.token.text in self.vars:
          return float(self.vars[node.token.text])
        else:
          raise Exception(f'ERROR: Invalid identifier. No variable with name \"{node.token.text}\" was found.')
    elif t_type == 'FUNC':
        var_dict = dict()
        for i in range(len(node.params)):
            var_dict[node.param_names[i]] = self.eval(node.params[i])
        res = Evaluate(node.childs[0], var_dict, self.funcs).eval()
        return res
    else:
        raise Exception(f'UNIMPLEMENTED: {node.token.t_type.t_type}')

In [11]:
ev = Evaluate(ast, {}, [func])
ev.eval(), ev.vars

(4.0, {})

In [12]:
class Interpreter:
    def __init__(self):
        self.vars = {}
        self.functions = {}

    def check_undefined(self, node: AstFuncNode, func_list):
      var_dict = dict()
      for val in node.param_names:
        var_dict[val] = 0
      Evaluate(node.childs[0], var_dict, func_list).eval()
      pass

    def input(self, expression):
        parser = Parser(expression, list(self.functions.values()))
        if len(parser.tokens) == 0:
            return ""
        ast = parser.parse()
        if expression.startswith('fn'):
          # hack: Calling a function to determine undefined variables.
          self.check_undefined(ast, list(self.functions.values()))
          
          inter = set(self.vars) & set(self.functions | {ast.fname.text: ast})
          if inter:
            raise Exception(f'Defined twice: {inter}')
          self.functions[ast.fname.text] = ast
          return ""

        previous_state = self.vars.copy()
        res = Evaluate(ast, self.vars, list(self.functions.values())).eval()

        inter = set(self.vars) & set(self.functions)
        if inter:
          self.vars = previous_state
          raise Exception(f'Defined twice: {inter}')
        
        return res

In [13]:
interpreter = Interpreter();

# Basic arithmetic
interpreter.input("1 + 1")#, 2)
interpreter.input("2 - 1")#, 1)
interpreter.input("2 * 3")#, 6)
interpreter.input("8 / 4")#, 2)
interpreter.input("7 % 4")#, 3)

# Variables
interpreter.input("x = 1")#, 1)
interpreter.input("x")#, 1)
interpreter.input("x + 3")#, 4)
try:
  interpreter.input("y")
except Exception:
  print(Exception)

interpreter.input('fn avg x y => (x + y) / 2')
interpreter.input('fn avg x y => x + y')
interpreter.input('avg 4 2')


<class 'Exception'>


6.0

In [14]:
interpreter = Interpreter();

interpreter.input('x = 0')
interpreter.input('fn f => 1')

''

In [15]:
interpreter.input("fn x => 0")

Exception: ignored

In [16]:
interpreter.input("f = 5")

Exception: ignored

In [17]:
interpreter.input('f')

1