In [6]:
# Python program to parse LaTeX formulas and produce Python/Prover9 expressions

# by Peter Jipsen 2023-4-6 distributed under LGPL 3 or later.
# Terms are read using Vaughn Pratt's top-down parsing algorithm.

# Modified by Jared Amaral, Jose Arellano, Nathan Nguyen, Alex Wunderli in May 2023 for usage in their Algorithm Analysis course project. 

# List of symbols handled by the parser (at this point)
# =====================================================
# \And \approx \backslash \bb \bigcap \bigcup \bot \cap \cc \cdot  
# \circ \Con \cup \equiv \exists \forall \ge \implies \in \le \ln \m 
# \mathbb \mathbf \mathcal \mid \Mod \models \ne \neg \ngeq \nleq \Not 
# \nvdash \oplus \Or \Pre \setminus \sim \subset \subseteq \supset \supseteq 
# \times \to \top \vdash \vee \vert \wedge + * / ^ _ ! = < > ( ) [ ] \{ \} | | $

# Greek letters and most other LaTeX symbols can be used as variable names.
# A LaTeX symbol named \abc... is translated to the Python variable _abc...

#!pip install provers
#from provers import *
import math, itertools, re, sys, subprocess
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'latex2sympy2'])
from sympy import *
x, y, z, t, i, n, m = symbols('x y z t i n m') # x, y, z, t = symbols('x y z t') # init_session()
from latex2sympy2 import *
from IPython.display import *

# The macros below are used to simplify the input that needs to be typed.
macros=r"""
\renewcommand{\And}{\ \text{and}\ }
\newcommand{\Or}{\ \text{or}\ }
\newcommand{\Not}{\text{not}\ }
\newcommand{\m}{\mathbf}
\newcommand{\bb}{\mathbb}
\newcommand{\cc}{\mathcal}
\newcommand{\s}{\text}
\newcommand{\bsl}{\backslash}
\newcommand{\sm}{{\sim}}
\newcommand{\tup}[1]{(#1)}
\newcommand{\Mod}{\text{Mod}}
\newcommand{\Con}{\text{Con}}
\newcommand{\Pre}{\text{Pre}}
"""
display(Markdown("$"+macros+"$"))
RunningInCOLAB = 'google.colab' in str(get_ipython())
if not RunningInCOLAB: macros=""

# Setting pi and e constants
from IPython.display import *
import math, itertools, re
_pi = sympy.pi
_e = math.e

# Integration Function
def integrate2(a, b):
    return str(integrate(a, b) + "+ C")

# Postfic check function
def is_postfix(t):
    return hasattr(t,'leftd') and len(t.a)==1

def w(t,i): # decide when to add parentheses during printing of terms
    subt = t.a[i] if len(t.a)>i else "index out of range"
    return str(subt) if subt.lbp < t.lbp or subt.a==[] or \
        (subt.sy==t.sy and subt.lbp==t.lbp) or \
        (not hasattr(subt,'leftd') or not hasattr(t,'leftd')) or \
        (is_postfix(subt) and is_postfix(t)) else "("+str(subt)+")"

# Similar to w function but modified for calculus functions
def w2(t,i):
  subt = t.a[i] if len(t.a)>i else "index out of range"
  return str(subt) if subt.lbp < t.lbp or subt.a==[] \
        or (not hasattr(subt,'leftd') and subt.lbp==1200) \
        else "("+str(subt)+")"

def w3(t,i): # always add parentheses
  subt = t.a[i] if len(t.a)>i else "index out of range"
  return "("+str(subt)+")"

# Creates arbitrary constant C (for integration) 
def letter(c): return 'a'<=c<='z' or 'A'<=c<='Z'
def alpha_numeric(c): return 'a'<=c<='z' or 'A'<=c<='Z' or '0'<=c<='9'

# Base Symbol Class
class symbol_base(object):
    a = []
    def __repr__(self): 
        if   len(self.a) == 0: return self.sy.replace("\\","_").replace("{","").replace("}","")
        elif len(self.a) == 2:
         return w(self,0)+" "+self.sy+" "+w(self,1)
        else:
         return self.sy+"("+",".join([w(self,j) for j in range(len(self.a))])+")"

# Symbol Function
def symbol(id, bp=1200): # identifier, binding power; LOWER binds stronger
    if id in symbol_table:
        s = symbol_table[id]    # look symbol up in table
        s.lbp = min(bp, s.lbp)  # update left binding power
    else:
        class s(symbol_base):   # create class for this symbol
            pass
        s.sy = id
        s.lbp = bp
        s.nulld = lambda self: self
        symbol_table[id] = s
    return s

def advance(id=None):
    global token
    if id and token.sy != id:
        raise SyntaxError("Expected "+id+" got "+token.sy)
    token = next()

def nulld(self): # null denotation
    expr = expression()
    advance(")")
    return expr

def nulldbr(self): # null denotation
    expr = expression()
    advance("}")
    return expr

#prefix2:
  # \frac d{dx}(\sin x)
  # ("\frac", "d","dx",("\sin","x"))

# Prefix2 is utilized for differentiation
def prefix2(id, bp=0): # parse n-ary prefix operations
  global token
  def nulld(self): # null denotation
    self.a = [expression(bp), expression(bp)]
    if self.a[0].sy=="d" and self.a[1].sy[0]=="d": self.a.append(expression(bp))
    return self
  s = symbol(id, bp)
  s.nulld = nulld
  return s

# Prefix3 is utilized for integration, limits, and summation
def prefix3(id, bp=0, nargs=1): # parse prefix operator \int, \lim, \sum
  global token
  def nulld(self): # null denotation
    global token
    #print('token.sy',token.sy,'self.sy',self.sy)
    self.a = []
    if token.sy=="_":
      advance("_")
      self.a += [expression(300)]
      if token.sy=="^":
        advance("^")
        self.a += [expression(300)]
    self.a = ([expression(bp)] if nargs==1 else [expression(bp), expression(bp)])+self.a
    return self
  s = symbol(id, bp)
  s.nulld = nulld
  return s

# General prefix function for general mathematics functions
def prefix(id, bp=0): # parse n-ary prefix operations
    global token
    def nulld(self): # null denotation
        global token
        if token.sy not in ["(","[","{"] and self.sy not in ["\\forall","\\exists"]:
            #print('token.sy',token.sy,'self.sy',self.sy)
            self.a = [] if token.sy in [",",")","}",":","=","!="] else [expression(bp)]
            if self.sy=="|": advance("|")
            return self
        else:
            closedelim = ")" if token.sy=="(" else "]" if token.sy=="[" else "}"
            token = next()
            self.a = []
            if token.sy != ")":
                while 1:
                    self.a.append(expression())
                    if token.sy != ",":
                        break
                    advance(",")
            advance(closedelim)
            if closedelim=="}" and token.sy=="(": #make \cmd{v}(...) same as \cmd c(...)
              prefix(self.a[0].sy)
              token = next()
              self.a[0].a = []
              if token.sy != ")":
                while 1:
                    self.a[0].a.append(expression())
                    if token.sy != ",":
                        break
                    advance(",")
              advance(")")
            return self
    s = symbol(id, bp)
    s.nulld = nulld
    return s

# Determines infix
def infix(id, bp, right=False):
    def leftd(self, left): # left denotation
        self.a = [left]
        self.a.append(expression(bp+(1 if right else 0)))
        return self
    s = symbol(id, bp)
    s.leftd = leftd
    return s

# Determines whether expression is pre or infix
def preorinfix(id, bp, right=True): # used for minus
    def leftd(self, left): # left denotation
        self.a = [left]
        self.a.append(expression(bp+(1 if right else 0)))
        return self
    def nulld(self): # null denotation
        global token
        self.a = [expression(bp)]
        return self
    s = symbol(id, bp)
    s.leftd = leftd
    s.nulld = nulld
    return s

def plist(id, bp=0): #parse a parenthesized comma-separated list
    global token
    def nulld(self): # null denotation
        global token
        self.a = []
        if token.sy not in ["]","\\}"]:
            while True:
                self.a.append(expression())
                if token.sy != ",": break
                advance(",")
        advance()
        return self
    s = symbol(id, bp)
    s.nulld = nulld
    return s

# Postfix is utilized for postfix expressions
def postfix(id, bp):
    def leftd(self,left): # left denotation
        self.a = [left]
        return self
    s = symbol(id, bp)
    s.leftd = leftd
    return s

# Symbol table dictionary
symbol_table = {}

# The parsing rules  below decode a string of tokens into an abstract syntax tree with methods .sy 
# for symbol (a string) and .a for arguments.

# Intializes table of mathematical symbols, utilizes lamba calculus
def init_symbol_table():
    global symbol_table
    symbol_table = {}
    symbol("(").nulld = nulld
    symbol(")")
    symbol("{").nulld = nulldbr
    symbol("}")
    prefix("|").__repr__ = lambda x: "len("+str(x.a[0])+")" #interferes with p|q from Prover9
    symbol("]")
    symbol("\\}")
    symbol(",")
    postfix("!",300).__repr__ =       lambda x: "math.factorial("+str(x.a[0])+")"
    postfix("f",300).__repr__ =       lambda x: "f"+w3(x,0)
    postfix("'",300).__repr__ =       lambda x: str(x.a[0])+"'"
    prefix("\\ln",310).__repr__ =     lambda x: "math.log("+str(x.a[0])+")"
    prefix("\\sin",310).__repr__ =    lambda x: "sin("+str(x.a[0])+")"  # use math.sin if sympy is not loaded
    infix(":", 450).__repr__ =        lambda x: str(x.a[0])+": "+w3(x,1) # for f:A\to B
    infix("^", 300).__repr__ =        lambda x: "converse("+str(x.a[0])+")"\
      if len(x.a)>1 and str(x.a[1].sy)=='\\smallsmile' else "O("+str(x.a[0])+")"\
      if P9 and len(x.a)>0 and str(x.a[1])=="-1" else w2(x,0)+"\\wedge "+w2(x,1)\
      if P9 else w2(x,0)+"**"+w2(x,1)                                       # power
    infix("_", 300).__repr__ =        lambda x: str(x.a[0])+"["+w(x,1)+"]"  # sub
    infix(";", 303).__repr__ =        lambda x: "relcomposition("+w(x,0)+","+w(x,1)+")" # relation composition
    infix("\\circ", 303).__repr__ =   lambda x: "relcomposition("+w(x,1)+","+w(x,0)+")" # function composition
    infix("*", 311).__repr__ =        lambda x: w2(x,0)+"\\cdot "+w2(x,1)   # times
    infix("\\cdot", 311).__repr__ =   lambda x: w2(x,0)+"*"+w2(x,1)         # times
    infix("/", 312).__repr__ =        lambda x: w2(x,0)+"/"+w2(x,1)         # over
    infix("+", 313).__repr__ =        lambda x: w2(x,0)+" + "+w2(x,1)       # plus
    preorinfix("-",313).__repr__ =    lambda x: "-"+w(x,0) if len(x.a)==1 else str(x.a[0])+" - "+w(x,1) #negative or minus
    
    # Psuedocode w/ {algpseudocodex}
    prefix("\\algb",310).__repr__ = lambda x: (x.a[0].sy)
    prefix("\\alge",310).__repr__ = lambda x: ""
    
    prefix("\\If",310).__repr__ = lambda x: "if " + str(x.a[0]) + ":"
    prefix("\\State",350).__repr__ = lambda x: "\t" + str(x.a[0])
    prefix("\\Output",310).__repr__ = lambda x: "print(" + str(x.a[0]) + ")"
    prefix("\\Return",310).__repr__ = lambda x: "return " + str(x.a[0])
    prefix("\\While",310).__repr__ = lambda x: "while " + str(x.a[0]) + ":"
    infix("\\gets",345).__repr__ = lambda x: w2(x,0) + " = "+ w2(x,1)
    

    infix("=", 405).__repr__ =        lambda x: w(x,0)+"=="+w(x,1)          # assignment or identity
    infix("==", 405).__repr__ =       lambda x: w(x,0)+" = "+w(x,1)         # assignment or identity
    infix("\\ne", 405).__repr__ =     lambda x: w(x,0)+" != "+w(x,1)        # nonequality
    infix("!=", 405).__repr__ =       lambda x: w(x,0)+"\\ne "+w(x,1)       # nonequality
    infix("\\le", 405).__repr__ =     lambda x: w2(x,0)+" <= "+str(x.a[1])  # less or equal
    infix("<=", 405).__repr__ =       lambda x: w2(x,0)+"\\le "+str(x.a[1]) # less or equal in Python
    infix("\\ge", 405).__repr__ =     lambda x: w2(x,0)+">="+str(x.a[1])    # greater or equal
    infix("<", 405).__repr__ =        lambda x: w2(x,0)+" < "+str(x.a[1])   # less than
    infix(">", 405).__repr__ =        lambda x: w2(x,0)+" > "+str(x.a[1])   # greater than
    prefix("\\Not",450).__repr__=     lambda x: "not "+w3(x,0)              # logical negation
    infix("\\Or", 503).__repr__=      lambda x: w(x,0)+(" or ")+w(x,1)      # disjunction
    infix("\\And", 503).__repr__=     lambda x: w(x,0)+(" and ")+w(x,1)     # conjunction
    postfix("?", 600).__repr__ =      lambda x: str(x.a[0])+"?"             # calculate value and show it
    
    
    
    symbol("(end)")

init_symbol_table()

# tokenize(st):
  # \frac{d}{dx}

# Determines tokens from an expression
def tokenize(st):
    i = 0
    # loop the length of the string
    while i<len(st):
        print("tokenize(" + st + ")")
        tok = st[i]
        j = i+1
        # \lim
        if j<len(st) and (st[j]=="{" or st[j]=="}") and tok=='\\':
          print("if 1")
          j += 1
          tok = st[i:j]
          symbol(tok)
        elif letter(tok) or tok=='\\': #read consecutive letters or digits
            print("if 2")
            while j<len(st) and letter(st[j]): j+=1
            tok = st[i:j]
            if tok=="\\" and j<len(st) and st[j]==" ": j+=1
            if tok=="\\text": j = st.find("}",j)+1 if st[j]=="{" else j #extend token to include {...} part
            if tok=="\\s": j = st.find("}",j)+1 if st[j]=="{" else j
            if tok=="\\mathcal": j = st.find("}",j)+1 if st[j]=="{" else j
            if tok=="\\cc": j = st.find("}",j)+1 if st[j]=="{" else j
            if tok=="\\tup": j += 1 if st[j]=="{" else j
            tok = st[i:j]
            symbol(tok)
            if j<len(st) and st[j]=='(': prefix(tok, 1200) #promote tok to function
        elif "0"<=tok<="9": #read (decimal) number in scientific notation
            print("if 3")
            while j<len(st) and ('0'<=st[j]<='9' or st[j]=='.'):# in ['.','e','E','-']):
                j+=1
            tok = st[i:j]
            symbol(tok)
        elif tok =="-" and st[j]=="-": pass
        elif tok not in " '(,)[]{}\\|\n": #read operator string
            while j<len(st) and not alpha_numeric(st[j]) and \
                  st[j] not in " '(,)[]{}\\\n": j+=1
            tok = st[i:j]
            if tok not in symbol_table: symbol(tok)
        i = j
        if tok not in [' ','\\newline','\\ ','\\quad','\\qquad','\n']: #skip these tokens
            symb = symbol_table[tok]
            if not symb: #symb = symbol(tok)
                raise SyntaxError("Unknown operator")
#            print tok, 'ST', symbol_table.keys()
            yield symb()
    symb = symbol_table["(end)"]
    yield symb()

def expression(rbp=1200): # read an expression from token stream
    global token
    t = token
    try:
      token = next()
    except:
      token = ttt
    left = t.nulld()
    while rbp > token.lbp:
        t = token
        token = next()
        left = t.leftd(left)
    return left

# parse(str):
  # \sin{}

# Parser of expressions
def parse(str):
    print("parse(" + str + ")")
    global token, next
    next = tokenize(str).__next__
    token = next()
    print("token:")
    print(token)
    return expression()

ttt=parse(".")

def ast(t):
    if len(t.a)==0: return '"'+t.sy+'"'
    return '("'+t.sy+'",'+", ".join(ast(s) for s in t.a)+")"

# Convert (a subset of) LaTeX input to valid Python(sympy) code
# Display LaTeX with calculated answers inserted
# Return LaTeX and/or Python code as a string

#nextmath(st, index):
  # checks if the string is enclosed in '$' or '$$'

  # st - string input from user
  # index - st starting index

def nextmath(st,i): #find next j,k>=i such that st[j:k] is inline or display math
  # find first occurence of '$' starting from  i
  j = st.find("$",i)
  # if '$' is not found 
  if j==-1: 
    return (-1,0,False)

  print("first '$' at: " + str(j))
  k = st.find("$",j+1)
  print("second '$' at: " + str(k))
  # check if the math string is just "$$"
  if st[j+1]=="$":
    # set k equal to the starting index of "$$"
    k = st.find("$$",j+2)
    # j = index after the double "$$"
    # k = starting index of "$$"
    # d = True (found "$$")
    return (j+2,k,True)
  else:
    # j = char after first '$'
    # k = index at second '$'
    # d = false
    return (j+1,st.find("$",j+1),False)


# convert st (a LaTeX string) to Python/Prover9 code and evaluate it
# creates syntax tree and decides the hierarchy of functions to use first
# process(st, info, nocolor):
  # st - string input
def process(st, info=False, nocolor=False):
  # use latex2sympy2 parser if the user uses the ls() function
  if st[:3]=="ls(": # use latex2sympy2 parser
    return ("" if nocolor else "\color{green}")+macros+st[3:-1]+("" if nocolor else "\color{blue}")+" = "+latex2latex(st[3:-1])
  # how does this work?
  t=parse(st)
  if info:
    print("Abstract syntax tree:", ast(t))
    print("Expression:", t)
  if t.sy!="?": # check if t is not asking to be evaluated
    if t.sy!="=": # check if t is an assignment
      if t.sy=="show": # check if t is a show command
        try:
          exec(str(t),globals())
        except:
          if info: print("no result")
          return macros+st
        return ("" if nocolor else "\color{green}")+macros+st
      return macros+st
    ss = str(t).replace("==","=",1)
    try:
      exec(ss,globals())
    except:
      if info: print("no result")
      return macros+st
    return ("" if nocolor else "\color{green}")+macros+st
  tt = t.a[0]
  st = st.replace("?","")
  if tt.sy=="=":
    ss = str(tt).replace("==","=",1)
    try:
      exec(ss,globals())
    except:
      if info: print("no result")
      return macros+st
    return ("" if nocolor else "\color{green}")+macros+st+("" if nocolor else "\color{deepskyblue}")+" = "+pyla(eval(str(tt.a[0])))
  try:
    val=eval(str(tt))
    if info: print("Value:", val)
    ltx = val if str(tt)[:5] in ["latex","addpl"] else pyla(val)
  except:
    return ("" if nocolor else "\color{green}")+macros+st
  return ("" if nocolor else "\color{green}")+macros+st+("" if nocolor else "\color{deepskyblue}")+" = "+ltx

# l(st, info, output, nocolor)
  # st - string input from user
  # info - 
  # output -
  # nocolor - 


  # Main function to translate valid LaTeX/Markdown string st
def p(st, info=False, output=False, nocolor=False):
  # assuming this is used to get r"""
  global macros
  st = re.sub("\n%.*?\n","\n",st) #remove LaTeX comments
  st = re.sub("%.*?\n","\n",st) #remove LaTeX comments
  # look for '$' in the string and update indices (j,k)
  (j,k,d) = nextmath(st,0)
  # out = the first '$'
  out = st[0:j]
  # while there are two '$'
  while j!=-1 and k!=-1:
    # process the math equation in latex
    print("out1: " + out)
    out += process(st[j:k],info,nocolor)
    print("out2: " + st[0:j])
    p = k
    (j,k,d) = nextmath(st,k+(2 if d else 1))
    out += st[p:j] if j!=-1 else st[p:]
    print("out3: " + st[0:j])
  display(Markdown(out))
  if output: print(out)

prvrs="Model" in dir() # check if provers module is loaded

$
\renewcommand{\And}{\ \text{and}\ }
\newcommand{\Or}{\ \text{or}\ }
\newcommand{\Not}{\text{not}\ }
\newcommand{\m}{\mathbf}
\newcommand{\bb}{\mathbb}
\newcommand{\cc}{\mathcal}
\newcommand{\s}{\text}
\newcommand{\bsl}{\backslash}
\newcommand{\sm}{{\sim}}
\newcommand{\tup}[1]{(#1)}
\newcommand{\Mod}{\text{Mod}}
\newcommand{\Con}{\text{Con}}
\newcommand{\Pre}{\text{Pre}}
$

parse(.)
next: 
<method-wrapper '__next__' of generator object at 0x0000024399CA69D0>
tokenize(.)
token:
.


In [7]:
# IF statement
p(r"""$1+2
\cdot 3$
""", 1)

first '$' at: 0
second '$' at: 12
out1: $
parse(1+2
\cdot 3)
next: 
<method-wrapper '__next__' of generator object at 0x0000024399CA6880>
tokenize(1+2
\cdot 3)
if 3
token:
1
tokenize(1+2
\cdot 3)
tokenize(1+2
\cdot 3)
if 3
tokenize(1+2
\cdot 3)
tokenize(1+2
\cdot 3)
if 2
tokenize(1+2
\cdot 3)
tokenize(1+2
\cdot 3)
if 3
Abstract syntax tree: ("+","1", ("\cdot","2", "3"))
Expression: 1 + 2*3
out2: $
out3: $1+2
\cdot 3$


$1+2
\cdot 3$
