# TODO
1. dual denominator expansion
1. variable substitution

In [140]:
import copy

In [1]:
class Expression: #Base Expression Class
    """Base class for symbolic expressions.

    Arithmetic operators are overloaded so that combining an Expression with
    another operand (an Expression or a plain value) builds a
    ``BinaryExpression`` node instead of computing a value. The reflected
    variants (``__radd__`` and friends) put the foreign operand on the left,
    preserving source order. ``**`` is recorded with the operator string "^".
    """

    def _combine(self, operator, other, reflected=False):
        # Reflected operators: the other operand appeared on the left in the
        # source expression, so it becomes the left child.
        if reflected:
            return BinaryExpression(other, operator, self)
        return BinaryExpression(self, operator, other)

    # Addition
    def __add__(self, other):
        return self._combine("+", other)

    def __radd__(self, other):
        return self._combine("+", other, reflected=True)

    # Subtraction
    def __sub__(self, other):
        return self._combine("-", other)

    def __rsub__(self, other):
        return self._combine("-", other, reflected=True)

    # Multiplication
    def __mul__(self, other):
        return self._combine("*", other)

    def __rmul__(self, other):
        return self._combine("*", other, reflected=True)

    # Division
    def __truediv__(self, other):
        return self._combine("/", other)

    def __rtruediv__(self, other):
        return self._combine("/", other, reflected=True)

    # Exponents
    def __pow__(self, other):
        return self._combine("^", other)

    def __rpow__(self, other):
        return self._combine("^", other, reflected=True)

In [141]:
class BinaryExpression(Expression): #Binary Expression Class
    """A node ``left <operator> right`` in a symbolic expression tree.

    ``operator`` is one of the strings produced by ``Expression``'s operator
    overloads: "+", "-", "*", "/" or "^". ``left`` and ``right`` are nested
    ``BinaryExpression`` nodes or leaves (``Symbol`` instances or plain
    numbers).

    Most rewriting methods (``expandExponents``, ``expandFracMul``,
    ``expand``, ``replacePattern``, ``matchAndReplace``, ``applyRule``)
    mutate the node in place and also return it, so both ``e.expand()`` and
    ``e = e.expand()`` work. ``symbolSubstitution`` instead returns a new tree.
    """
    left = None
    operator = None
    right = None

    def __init__(self, left, operator, right):
        self.left = left
        self.operator = operator
        self.right = right

    def __hash__(self):
        # Structural hash over the rendered tree. No __eq__ is defined, so
        # equality is still identity; the hash changes if the tree is mutated.
        return hash(self.generateTreeString())

    def containsAdditionOrSubtraction(self): #Test if expression or children contain addition or subtraction
        """Return True if this node or any descendant node uses "+" or "-"."""
        if self.operator in ["+", "-"]:
            return True
        return any(
            isinstance(child, BinaryExpression) and child.containsAdditionOrSubtraction()
            for child in (self.left, self.right)
        )

    def generateTreeString(self):
        """Render the tree: operator on the first line, children below, each
        indented two spaces per level; leaves are rendered with ``str()``."""
        indentChar = "  "
        if isinstance(self.left, BinaryExpression):
            leftString = indentChar + self.left.generateTreeString().replace("\n", "\n" + indentChar)
        else:
            leftString = indentChar + str(self.left)
        if isinstance(self.right, BinaryExpression):
            rightString = indentChar + self.right.generateTreeString().replace("\n", "\n" + indentChar)
        else:
            rightString = indentChar + str(self.right)
        return self.operator + "\n" + leftString + "\n" + rightString

    def symbolSubstitution(self, substitutionDict): #NOT IN PLACE
        """Return a new tree with every Symbol leaf found in ``substitutionDict``
        replaced by its mapped value; this tree is left unchanged.

        Raises:
            TypeError: if ``substitutionDict`` is not a dict.
        """
        if not isinstance(substitutionDict, dict):
            raise TypeError("parameter substitutionDict must be a dictionary!")
        return BinaryExpression(
            self._substituteChild(self.left, substitutionDict),
            self.operator,
            self._substituteChild(self.right, substitutionDict),
        )

    @staticmethod
    def _substituteChild(child, substitutionDict):
        """Substitute one child: recurse into nodes, map Symbol leaves."""
        if isinstance(child, BinaryExpression):
            return child.symbolSubstitution(substitutionDict)
        if isinstance(child, Symbol) and child in substitutionDict:
            return substitutionDict[child]
        return child

    def replacePattern(self, replace, capGroups=None, capModifiers=None):
        """Overwrite this node, in place, with a filled-in copy of ``replace``.

        Args:
            replace: template tree (a BinaryExpression); Symbols named after
                capture-group numbers (``Symbol("0")``, ...) are placeholders.
            capGroups: {group number: captured value} to substitute.
            capModifiers: optional {group number: function}; the function is
                applied to that group's captured value before substitution.

        Returns:
            self.
        """
        if capGroups is None: capGroups = {}
        if capModifiers is None: capModifiers = {}
        substitutionDict = {}
        for groupNum, captured in capGroups.items():
            if groupNum in capModifiers:
                captured = capModifiers[groupNum](captured)
            substitutionDict[Symbol(str(groupNum))] = captured
        # Deep-copy the template so repeated use of one rule never shares nodes.
        newExpr = copy.deepcopy(replace).symbolSubstitution(substitutionDict)
        self.__dict__ = newExpr.__dict__
        return self

    def matchAndGather(self, match, capFilters=None): #Returns (ResultBool, CaptureGroups)
        """Structurally match this tree against the pattern tree ``match``.

        Pattern leaves that are Symbols with integer names (``Symbol("1")``)
        are capture groups and match any child. Other Symbols must equal the
        corresponding child. Any other pattern leaf (e.g. a number) must
        compare equal to it. ``capFilters`` maps group number -> predicate,
        at every depth of the pattern; a capture whose predicate returns a
        falsy value fails the match.

        Returns:
            ``(True, {group number: captured value})`` on a match, otherwise
            ``(False, {})``.
        """
        if capFilters is None: capFilters = {}
        if not isinstance(self, match.__class__) or self.operator != match.operator:
            return (False, {})
        captures = {}
        for selfChild, matchChild in ((self.left, match.left), (self.right, match.right)):
            matched, childCaptures = self._matchChild(selfChild, matchChild, capFilters)
            if not matched:
                return (False, {})
            captures.update(childCaptures)
        return (True, captures)

    @staticmethod
    def _matchChild(selfChild, matchChild, capFilters):
        """Match one child against the corresponding pattern child."""
        if isinstance(matchChild, Symbol):
            try:
                capGroupNum = int(matchChild.name)
            except ValueError:
                # Non-numeric symbol name: it must match literally.
                return (bool(matchChild == selfChild), {})
            if capGroupNum in capFilters and not capFilters[capGroupNum](selfChild):
                return (False, {})
            return (True, {capGroupNum: selfChild})
        if isinstance(matchChild, BinaryExpression):
            if not isinstance(selfChild, BinaryExpression):
                return (False, {})
            # Forward the filters so nested capture groups are filtered too.
            return selfChild.matchAndGather(matchChild, capFilters)
        # Literal pattern leaf such as a number.
        return (bool(matchChild == selfChild), {})

    def matchAndReplace(self, match, replace, capFilters=None, capModifiers=None):
        """If this tree matches ``match``, overwrite it in place with
        ``replace`` filled from the captures. Returns self either way."""
        matched, captures = self.matchAndGather(match, capFilters)
        if matched:
            self.replacePattern(replace, captures, capModifiers)
        return self

    def applyRule(self, rule):
        """Apply a Rule (match, replace, capFilters, capModifiers) in place; returns self."""
        return self.matchAndReplace(rule.match, rule.replace, capFilters=rule.capFilters, capModifiers=rule.capModifiers)

    def expandExponents(self):
        """Rewrite integer powers as products / reciprocals, in place.

        Children are processed first. For ``base ^ n`` with an int ``n``:
          * n > 2:  ``base * (base ^ (n-1))``, expanded recursively;
          * n == 2: ``base * base``;
          * n == 1: the node becomes (a copy of the attributes of) ``base``
            when ``base`` is an Expression; a numeric base is left as
            ``base ^ 1`` because a node object cannot turn into a number;
          * n <= -1: ``(1 / base) ^ -n``, expanded again.
        Extra factors are deep copies of ``base`` so later in-place rewrites
        of one factor cannot leak into another. Non-int exponents and n == 0
        are left untouched.

        Returns:
            self (its class may have changed in the n == 1 case).
        """
        if isinstance(self.left, BinaryExpression):
            self.left = self.left.expandExponents()
        if isinstance(self.right, BinaryExpression):
            self.right = self.right.expandExponents()
        if self.operator == "^" and isinstance(self.right, int):
            if self.right > 2:
                self.operator = "*"
                self.right = BinaryExpression(copy.deepcopy(self.left), "^", self.right - 1).expandExponents()
            elif self.right == 2:
                self.operator = "*"
                self.right = copy.deepcopy(self.left)
            elif self.right == 1:
                left = self.left
                if isinstance(left, Expression):
                    self.__class__ = left.__class__
                    # Copy instead of sharing, so the two objects stay independent.
                    self.__dict__ = dict(left.__dict__)
            elif self.right <= -1:
                self.left = BinaryExpression(1, "/", self.left)
                self.right = -self.right
                return self.expandExponents()
        return self

    def expandFracMul(self):
        """Combine products/quotients of two fractions, in place.

        ``(a/b) * (c/d)`` becomes ``(a*c) / (b*d)`` and ``(a/b) / (c/d)``
        becomes ``(a/b) * (d/c)``, which is then combined again. Children
        are processed first. The divisor's reciprocal is built as a new node,
        so a divisor subtree shared elsewhere is never mutated.

        Returns:
            self.
        """
        if isinstance(self.left, BinaryExpression):
            self.left = self.left.expandFracMul()
        if isinstance(self.right, BinaryExpression):
            self.right = self.right.expandFracMul()
        bothFractions = (
            isinstance(self.left, BinaryExpression)
            and isinstance(self.right, BinaryExpression)
            and self.left.operator == "/"
            and self.right.operator == "/"
        )
        if not bothFractions:
            return self
        if self.operator == "*":
            oldLeft, oldRight = self.left, self.right
            self.operator = "/"
            self.left = BinaryExpression(oldLeft.left, "*", oldRight.left)
            self.right = BinaryExpression(oldLeft.right, "*", oldRight.right)
        elif self.operator == "/":
            self.operator = "*"
            self.right = BinaryExpression(self.right.right, "/", self.right.left)
            return self.expandFracMul()
        return self

    def expand(self):
        """Fully expand the tree in place and return it.

        Powers are expanded first, then fraction products. Then "*" is
        distributed over "+"/"-" on either side, and "/" over a "+"/"-"
        numerator. Fraction products are combined once more at the end.
        """
        self.expandExponents()
        # ``x ^ 1`` at the root turns this object into its base (e.g. a
        # Symbol), which has none of the methods used below.
        if not isinstance(self, BinaryExpression):
            return self
        self.expandFracMul()
        if self.containsAdditionOrSubtraction():
            if isinstance(self.left, BinaryExpression):
                self.left = self.left.expand()
            if isinstance(self.right, BinaryExpression):
                self.right = self.right.expand()
            if self.operator in ["*", "/"] and (isinstance(self.left, BinaryExpression) or isinstance(self.right, BinaryExpression)):
                binexpChild = None
                distributeOperator = self.operator
                if isinstance(self.left, BinaryExpression) and self.left.operator in ["+", "-"]:
                    binexpChild, otherChild = self.left, self.right
                # Only "*" distributes over a sum on the right: a/(b+c) stays.
                if isinstance(self.right, BinaryExpression) and self.operator == "*" and self.right.operator in ["+", "-"]:
                    binexpChild, otherChild = self.right, self.left
                if binexpChild is None: #Neither child was binary addition or subtraction
                    return self
                self.operator = binexpChild.operator
                self.left = BinaryExpression(binexpChild.left, distributeOperator, otherChild).expand()
                # Independent copy: both new terms are rewritten in place, and a
                # shared otherChild would let the first rewrite corrupt the second.
                self.right = BinaryExpression(binexpChild.right, distributeOperator, copy.deepcopy(otherChild)).expand()
        self.expandFracMul()
        return self

In [57]:
class Rule:
    """A rewrite rule, consumed by ``BinaryExpression.applyRule``.

    Attributes:
        match: pattern tree; Symbols with integer names are capture groups.
        replace: template tree written over a matching expression.
        capFilters: optional {group number: predicate} restricting captures.
        capModifiers: optional {group number: function} applied to a
            captured value before it is substituted into ``replace``.
    """

    def __init__(self, match, replace, capFilters=None, capModifiers=None):
        self.match, self.replace = match, replace
        self.capFilters, self.capModifiers = capFilters, capModifiers

In [3]:
class Term(BinaryExpression): #Term Class (Binary Expression tree containing no + or - which can only be + or - with other expressions)
    """Stub for a term: an expression tree without "+" or "-" inside.

    Nothing is implemented yet. The constructor sets no attributes,
    ``reorderTerms`` does nothing, and ``expand`` returns the term unchanged.
    """

    def __init__(self):
        """Create an empty term (stub; stores nothing)."""

    # Open question: remove the ability to multiply, divide, power?

    def reorderTerms(self):
        """Not implemented yet; does nothing and returns None."""

    def expand(self):
        """Return this term unchanged."""
        return self

In [4]:
class Symbol(Expression): #Symbol Class
    """A named leaf in an expression tree, e.g. ``Symbol("x")``.

    Two symbols are equal, and hash equally, exactly when their names are
    equal, so symbols work as dictionary keys.
    """
    name = None

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name

    def __eq__(self, other):
        # Let the other operand decide for non-Symbols. If it can't compare
        # either, Python falls back to identity, which is False here.
        if not isinstance(other, Symbol):
            return NotImplemented
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def identityHash(self):
        """Return an integer that uniquely identifies this symbol's name.

        Each character contributes a fixed-width, 7-digit decimal field that
        holds ``ord(c) + 1``. Code points are at most 1114111, so every value
        fits in 7 digits. Fixed-width fields make the encoding injective:
        variable-width ords let e.g. "\\x01\\x02" and "\\x0c" collide. The
        ``+ 1`` keeps every field non-zero, so leading NUL characters are
        not lost. The empty name maps to 0.

        Raises:
            ValueError: if the name is longer than 32 characters.
        """
        if len(self.name) > 32:
            raise ValueError("name length > 32 not supported!")
        digits = "".join(str(ord(c) + 1).zfill(7) for c in self.name)
        return int(digits) if digits else 0

In [5]:
class Equality: #Equality class
    """Placeholder class with no behavior yet (presumably for equations
    between expressions — TODO confirm intended design)."""

In [73]:
# One single-letter Symbol per global name, in the same order as the names.
x, y, z, a, b = (Symbol(letter) for letter in "xyzab")

In [162]:
# Rule rewriting  base ^ -k  as  1 / (base ^ k)  for negative int exponents.
# Symbols named "0" and "1" are capture groups: 0 = base, 1 = exponent.
m = Symbol("0") ** Symbol("1")            # pattern:  0 ^ 1
r = 1 / (Symbol("0") ** Symbol("1"))      # template: 1 / (0 ^ 1)
cf = {
    1: (lambda i: i<0 if isinstance(i, int) else False),  # accept only negative ints
}
cm = {
    1: (lambda i: i * -1),  # flip the captured exponent's sign
}
negativeExponentRule = Rule(match=m, replace=r, capFilters=cf, capModifiers=cm)

In [176]:
# Build the tree  x ^ -1  (Expression.__pow__ returns BinaryExpression(x, "^", -1)).
e = Symbol("x") ** -1

In [177]:
# Show e's tree before the rule is applied.
print(e.generateTreeString())

^
  x
  -1


In [178]:
# Rewrites e in place if it matches the rule's pattern; the call also returns e.
e.applyRule(negativeExponentRule)

<__main__.BinaryExpression at 0x7f9738373cf8>

In [179]:
# Show e's tree again after applyRule.
print(e.generateTreeString())

/
  1
  ^
    x
    1


In [101]:
# Match pattern  0 ^ 1 : symbol "0" captures the base, "1" the exponent.
m = Symbol("0") ** Symbol("1")

In [102]:
# Replacement template  1 / (0 ^ 1) , filled in from the captured groups.
r = 1 / (Symbol("0") ** Symbol("1"))

In [66]:
# Manual form of the negative-exponent rewrite: group 1 must be a negative
# int and is negated before substitution.
# NOTE(review): if e has already been rewritten to 1 / (x ^ 1), its root
# operator is "/" rather than "^", so nothing matches and e is unchanged.
e.matchAndReplace(m, r, capFilters={1: lambda i: i<0 if isinstance(i, int) else False}, capModifiers={1: lambda i: i * -1})
print(e.generateTreeString())

/
  1
  ^
    x
    1


In [39]:
# Show the replacement template r.
print(r.generateTreeString())

/
  1
  ^
    0
    1


In [40]:
# Match only: returns (matched, {group number: captured value}) without modifying e.
result = e.matchAndGather(m, capFilters={1: lambda i: i<0 if isinstance(i, int) else False})

In [41]:
# Show the (matched, captures) tuple.
print(result)

(True, {0: <__main__.Symbol object at 0x7f63d81a9828>, 1: -1})


In [43]:
# Overwrite e in place with r, filling in the captures and negating group 1.
e.replacePattern(r, capGroups=result[1], capModifiers={1: lambda i: i * -1})

<__main__.BinaryExpression at 0x7f63d81a92e8>

In [44]:
# Show e after the manual replacement.
print(e.generateTreeString())

/
  1
  ^
    x
    1


In [237]:
def f(i):
    """Capture filter: True only for a negative int; any other value gives False."""
    if not isinstance(i, int):
        return False
    return i < 0

In [238]:
# A non-int argument is rejected by the filter (returns False).
f("hello")

False

In [62]:
# Show the operator of e's right child. The original `print(e.right.)` is a
# SyntaxError; the recorded output "^" shows the operator was intended.
print(e.right.operator)

^
