# TODO
1. Create Rules
2. Create ordering system

In [1]:
import copy
from numbers import Number

In [2]:
class Expression(): #Base Expression Class
    """Base class for symbolic expressions.

    The arithmetic operators are overloaded so that combining an Expression
    with anything builds a BinaryExpression node instead of computing a
    value. The reflected (``__r*__``) variants keep the written operand
    order, so ``2 - x`` becomes ``BinaryExpression(2, "-", x)``.
    Exponentiation (``**``) is recorded with the operator string ``"^"``.
    """

    # Forward operators: self is the left operand (``x + 1``).
    def __add__(self, other):
        return BinaryExpression(left=self, operator="+", right=other)

    def __sub__(self, other):
        return BinaryExpression(left=self, operator="-", right=other)

    def __mul__(self, other):
        return BinaryExpression(left=self, operator="*", right=other)

    def __truediv__(self, other):
        return BinaryExpression(left=self, operator="/", right=other)

    def __pow__(self, other):
        return BinaryExpression(left=self, operator="^", right=other)

    # Reflected operators: Python calls these when the left operand is not
    # an Expression (``1 + x``), so self is the right operand.
    def __radd__(self, other):
        return BinaryExpression(left=other, operator="+", right=self)

    def __rsub__(self, other):
        return BinaryExpression(left=other, operator="-", right=self)

    def __rmul__(self, other):
        return BinaryExpression(left=other, operator="*", right=self)

    def __rtruediv__(self, other):
        return BinaryExpression(left=other, operator="/", right=self)

    def __rpow__(self, other):
        return BinaryExpression(left=other, operator="^", right=self)

In [3]:
class BinaryExpression(Expression): #Binary Expression Class
    """A node of an expression tree: ``left <operator> right``.

    ``left`` and ``right`` may be nested BinaryExpressions, Symbols, or plain
    numbers; ``operator`` is the string recorded by the Expression operator
    overloads ("+", "-", "*", "/", "^").

    Equality is structural: two BinaryExpressions are equal when their
    operators are equal and their children compare equal. The hash is built
    from the same (left, operator, right) triple, so equal trees hash equally.

    Rule patterns are ordinary trees in which a Symbol whose name parses as
    an integer ("0", "1", ...) acts as a capture group (see matchAndGather).
    """
    left = None
    operator = None
    right = None
    
    def __init__(self, left, operator, right):
        self.left = left
        self.operator = operator
        self.right = right
        
    def __hash__(self):
        # Consistent with __eq__: structurally equal trees have equal triples.
        return hash((self.left, self.operator, self.right))
    
    def __eq__(self, other):
        # Structural comparison. Comparing hashes of the printed tree made an
        # expression "equal" to the string of its own tree and could not tell
        # Symbol("3") apart from the number 3.
        if not isinstance(other, BinaryExpression):
            return NotImplemented
        return (self.operator == other.operator
                and self.left == other.left
                and self.right == other.right)
        
    def containsAdditionOrSubtraction(self): #Test if expression or children contain addition or subtraction
        """Return True if this node or any nested BinaryExpression uses "+" or "-"."""
        if self.operator in ("+", "-"):
            return True
        return any(
            child.containsAdditionOrSubtraction()
            for child in (self.left, self.right)
            if isinstance(child, BinaryExpression)
        )
        
    def generateTreeString(self):
        """Render the tree as text: the operator, then each child indented by two spaces."""
        indentChar = "  "

        def indentChild(child):
            # Nested trees are multi-line, so every one of their lines is indented.
            if isinstance(child, BinaryExpression):
                return indentChar + child.generateTreeString().replace("\n", "\n" + indentChar)
            return indentChar + str(child)

        return self.operator + "\n" + indentChild(self.left) + "\n" + indentChild(self.right)
    
    def __repr__(self):
        return self.generateTreeString()
    
    def symbolSubstitution(self, substitutionDict): #NOT IN PLACE
        """Return a new tree with Symbol leaves replaced according to substitutionDict.

        Nested BinaryExpressions are rebuilt recursively. A Symbol leaf that is
        a key of the dict is replaced by the mapped value (not copied). Every
        other leaf is kept as-is.

        Raises:
            TypeError: if substitutionDict is not a dict.
        """
        if not isinstance(substitutionDict, dict):
            raise TypeError("parameter substitutionDict must be a dictionary!")
        return BinaryExpression(
            self._substituteChild(self.left, substitutionDict),
            self.operator,
            self._substituteChild(self.right, substitutionDict),
        )

    @staticmethod
    def _substituteChild(child, substitutionDict):
        """Apply symbolSubstitution's replacement rule to one child node."""
        if isinstance(child, BinaryExpression):
            return child.symbolSubstitution(substitutionDict)
        if isinstance(child, Symbol) and child in substitutionDict:
            return substitutionDict[child]
        return child
    
    #Generate replacement BinExp from pattern, capGroups, and capModifiers
    #Output is invariant with self
    def generateReplacement(self, replace, capGroups=None, capModifiers=None): #NOT IN PLACE
        """Build a replacement tree from pattern ``replace`` and captured groups.

        For each captured group N, the Symbol named str(N) in a deep copy of
        ``replace`` is substituted with the captured value. If capModifiers
        has an entry for N, the value is first passed through that function.
        The result does not depend on self.
        """
        if capGroups is None: capGroups = {}
        if capModifiers is None: capModifiers = {}
        newExpr = copy.deepcopy(replace)
        substitutionDict = {}
        for groupNum, captured in capGroups.items():
            if groupNum in capModifiers:
                captured = capModifiers[groupNum](captured)
            substitutionDict[Symbol(str(groupNum))] = captured
        if substitutionDict:
            newExpr = newExpr.symbolSubstitution(substitutionDict)
        return newExpr

    @staticmethod
    def _captureGroupNumber(symbol):
        """Return the capture-group number named by a pattern Symbol, or None for an ordinary symbol."""
        try:
            return int(symbol.name)
        except ValueError:
            return None

    @staticmethod
    def _matchSymbol(pattern, target, capFilters):
        """Match a pattern Symbol against target; return (ResultBool, CaptureGroups)."""
        groupNum = BinaryExpression._captureGroupNumber(pattern)
        if groupNum is None:
            # Ordinary symbol: it must equal the target literally.
            return (bool(pattern == target), {})
        # Filters are called outside any try block, so their own errors propagate.
        if groupNum in capFilters and not capFilters[groupNum](target):
            return (False, {})
        return (True, {groupNum: target})

    @staticmethod
    def _matchChild(selfChild, matchChild, capFilters):
        """Match one child of a pattern node against the corresponding child of self."""
        if isinstance(matchChild, Symbol):
            return BinaryExpression._matchSymbol(matchChild, selfChild, capFilters)
        if isinstance(matchChild, BinaryExpression):
            if isinstance(selfChild, BinaryExpression):
                return selfChild.matchAndGather(matchChild, capFilters)
            return (False, {})
        raise ValueError("Unexpected class in match object: " + str(matchChild.__class__))

    @staticmethod
    def _mergeCaptures(into, new):
        """Merge capture dict ``new`` into ``into``.

        Return False if a group is already bound to a different value.
        """
        for groupNum, value in new.items():
            if groupNum in into and into[groupNum] != value:
                return False
            into[groupNum] = value
        return True

    #Match the given pattern against self, return (result, CaptureGroups)
    #Can match from an array of patterns
    #Behavior is undefined if different patterns have different capture groups
    def matchAndGather(self, match, capFilters=None): #Returns (ResultBool, CaptureGroups)
        """Match pattern ``match`` against self.

        Args:
            match: a pattern tree, or an iterable of alternative patterns tried
                in order; the first one that matches wins.
            capFilters: optional {groupNum: predicate}. A capture succeeds only
                if the predicate is truthy for the captured value.

        Returns:
            (ResultBool, CaptureGroups). CaptureGroups maps group numbers to
            captured subtrees/values and is {} when the match fails. A group
            that appears more than once in a pattern must capture equal values.

        Raises:
            ValueError: if a pattern node is neither a BinaryExpression nor a
                Symbol.
        """
        if capFilters is None:
            capFilters = {}
        if hasattr(match, "__iter__"):
            for pattern in match:
                result = self.matchAndGather(match=pattern, capFilters=capFilters)
                if result[0]:
                    return result
            #None of the match patterns matched
            return (False, {})
        if isinstance(self, match.__class__): #Pattern node is a BinaryExpression
            if self.operator != match.operator:
                return (False, {})
            captures = {}
            for selfChild, matchChild in ((self.left, match.left), (self.right, match.right)):
                matched, childCaptures = self._matchChild(selfChild, matchChild, capFilters)
                # Fail on the first side that does not match, or when a repeated
                # capture group would be bound to two different values.
                if not matched or not self._mergeCaptures(captures, childCaptures):
                    return (False, {})
            return (True, captures)
        if isinstance(match, Symbol): #Whole-node pattern such as Symbol("0")
            return self._matchSymbol(match, self, capFilters)
        raise ValueError("unsupported type for match object")

    #Tries to match pattern against self, generates replacement if match succeeds
    #Returns (result, replacement)
    def matchAndGenReplace(self, match, replace, capFilters=None, capModifiers=None):
        """Return (ResultBool, replacement); replacement is None when the match fails."""
        result, capGroups = self.matchAndGather(match, capFilters=capFilters)
        replacement = None
        if result:
            replacement = self.generateReplacement(replace, capGroups=capGroups, capModifiers=capModifiers)
        return (result, replacement)
    
    #Applies the rule to self, returns (result, replacement)
    #Note that replacement is None (not self) when the rule does not match
    def applyRule(self, rule):
        """Apply a single Rule to this node only (children are not visited)."""
        return self.matchAndGenReplace(
            match = rule.match,
            replace = rule.replace,
            capFilters = rule.capFilters,
            capModifiers = rule.capModifiers
        )
        
    
    #Applies the ruleset to self via a dfs algorithm
    #Recursively re-calls on self when a rule causes a change
    def applyRuleset(self, ruleset):
        """Simplify this tree with ``ruleset`` and return the result.

        Children are simplified first (depth-first). The rules are then tried
        in order on the rebuilt node. On the first match, the replacement is
        simplified again from scratch. The result may be a non-BinaryExpression
        (e.g. a number) if a rule collapses the node.

        A rule whose replacement always matches itself again (e.g. match ==
        replace) makes this recurse until Python's recursion limit.
        """
        #Apply to children (dfs)
        newLeft = self.left.applyRuleset(ruleset) if isinstance(self.left, BinaryExpression) else self.left
        newRight = self.right.applyRuleset(ruleset) if isinstance(self.right, BinaryExpression) else self.right
        newSelf = BinaryExpression(newLeft, self.operator, newRight)
        #Iterate rules
        for rule in ruleset.ruleList:
            result, changedSelf = newSelf.applyRule(rule)
            if result:
                #newSelf was changed by rule, re-call on newSelf with change
                return changedSelf.applyRuleset(ruleset) if isinstance(changedSelf, BinaryExpression) else changedSelf
        #No changes made to newSelf, return newSelf
        return newSelf
    
    def matchRule(self, matchRule):
        """Match a MatchRule (pattern plus capFilters) against self; return (ResultBool, CaptureGroups)."""
        return self.matchAndGather(matchRule.match, matchRule.capFilters)

In [4]:
class Rule():
    """A rewrite rule: expressions matching ``match`` are rewritten to ``replace``.

    Args:
        match: a pattern tree, or an iterable of alternative patterns.
        replace: the replacement pattern; capture-group Symbols in it are
            substituted with the captured values.
        capFilters: optional {groupNum: predicate} restricting what each
            capture group may bind to.
        capModifiers: optional {groupNum: function} applied to a captured
            value before it is substituted into ``replace``.
    """

    def __init__(self, match, replace, capFilters=None, capModifiers=None):
        self.match, self.replace = match, replace
        self.capFilters, self.capModifiers = capFilters, capModifiers

In [5]:
class Ruleset():
    """An ordered collection of Rules; applyRuleset tries them first-to-last."""

    def __init__(self, ruleList):
        # Kept by reference, not copied: later edits to the list are visible here.
        self.ruleList = ruleList

In [6]:
class MatchRule():
    """A match-only pattern (no replacement), used as a predicate via BinaryExpression.matchRule.

    Args:
        match: the pattern tree to test against.
        capFilters: optional {groupNum: predicate} restricting captures.
    """

    def __init__(self, match, capFilters=None):
        self.match, self.capFilters = match, capFilters

In [7]:
class Term(BinaryExpression): #Term Class
    """Placeholder for a Binary Expression tree that contains no "+" or "-".

    Terms are meant to be combined with other expressions only through "+"
    or "-". Not implemented yet: the constructor stores nothing, so ``left``,
    ``operator`` and ``right`` keep the class-level defaults (all None)
    inherited from BinaryExpression.
    """
    #Remove ability to multiply, divide, power?

    def __init__(self):
        """Create an empty Term; BinaryExpression.__init__ is deliberately not called."""

    def reorderTerms(self):
        """Not implemented yet: currently does nothing and returns None."""

    def expand(self):
        """Return this term unchanged (no expansion is performed yet)."""
        return self

In [8]:
class Symbol(Expression): #Symbol Class
    """A named leaf of an expression tree, e.g. ``Symbol("x")``.

    Two Symbols are equal, and hash equally, when their names are equal.
    In rule patterns, a Symbol whose name parses as an integer ("0", "1",
    ...) is treated by the matcher as a capture group.
    """
    name = None
    
    def __init__(self, name):
        self.name = name
        
    def __str__(self):
        return self.name

    def __repr__(self):
        # Mirrors BinaryExpression, whose repr is its tree string. The default
        # object repr made tuples such as applyRule's result unreadable.
        return self.generateTreeString()
    
    def __eq__(self, other):
        if isinstance(other, Symbol):
            if self.name == other.name:
                return True
        return False
    
    def __hash__(self):
        return hash(self.name)
    
    def symbolSubstitution(self, substitutionDict): #NOT IN PLACE
        """Return a deep copy of this symbol's substitute, or the symbol itself if it has none.

        Previously the method fell off the end and returned None when the
        symbol was not a key of substitutionDict.
        """
        if self in substitutionDict:
            return copy.deepcopy(substitutionDict[self])
        return self
        
    def generateTreeString(self):
        return self.name
        
    def identityHash(self):
        """Return an integer that uniquely identifies this symbol's name.

        Each character's code point is written as a fixed-width 7-digit
        decimal; the largest code point, 1114111, has 7 digits. A leading "1"
        preserves the zero padding of the first character. Fixed width makes
        the mapping injective. The old variable-width encoding let "a" (97)
        collide with chr(9) + chr(7) ("9" + "7").

        Raises:
            ValueError: if the name is longer than 32 characters.
        """
        if len(self.name) > 32:
            raise ValueError("name length > 32 not supported!")
        digits = "".join(str(ord(c)).zfill(7) for c in self.name)
        return int("1" + digits)

In [9]:
class Rules():
    """Namespace holding the simplifier's rewrite rules.

    Rule objects rewrite a matched pattern into a replacement. MatchRule
    objects are patterns used only as predicates inside capFilters. In every
    pattern, Symbol("0"), Symbol("1"), ... are capture groups.

    Lambdas refer to sibling rules as ``Rules.<name>``. Names bound in a
    class body are not visible inside functions defined there, and the lookup
    only happens when the lambda runs, after the class exists.
    """

    @staticmethod
    def _applyRuleOrKeep(expr, rule):
        """Return expr rewritten by rule, or expr itself when the rule does not match."""
        result, replacement = expr.applyRule(rule)
        return replacement if result else expr

    #Negative Exponents (disabled; kept for reference)
    #negativeExponentRule = Rule(
    #    match = Symbol("0") ** Symbol("1"),
    #    replace = 1 / (Symbol("0") ** Symbol("1")),
    #    capFilters = {1: lambda i: i<0 if isinstance(i, int) else False},
    #    capModifiers = {1: lambda i: i * -1}
    #)
    
    #No Division: a / b -> a * b^-1
    noDivisionRule = Rule(
        match = Symbol("0") / Symbol("1"),
        replace = Symbol("0") * (Symbol("1") ** -1)
    )
    
    #Identity Exponents: a^1 -> a (any numeric 1, e.g. 1 or 1.0)
    identityExponentRule = Rule(
        match = Symbol("0") ** Symbol("1"),
        replace = Symbol("0"),
        capFilters = {1: lambda i: isinstance(i, Number) and i == 1}
    )
    
    #Exponent Power Rule: (a^b)^c -> a^(b*c)
    powerRule = Rule(
        match = (Symbol("0") ** Symbol("1")) ** Symbol("2"),
        replace = Symbol("0") ** (Symbol("1") * Symbol("2"))
    )
    
    #Zero Exponents: a^0 -> 1 (the base capture is replaced by the constant 1)
    zeroExponentRule = Rule(
        match = Symbol("0") ** Symbol("1"),
        replace = Symbol("0"),
        capFilters = {1: lambda i: isinstance(i, Number) and i == 0},
        capModifiers = {0: lambda i: 1}
    )
    
    #Multiplicative Identity: a*1 -> a and 1*a -> a
    identityMultiplicationRule = Rule(
        match = [
            Symbol("0") * Symbol("1"),
            Symbol("1") * Symbol("0")
        ],
        replace = Symbol("0"),
        capFilters = {1: lambda i: isinstance(i, Number) and i == 1}
    )
    
    #No Subtraction (prefer addition): a - b -> a + (-1 * b)
    noSubtractionRule = Rule(
        match = Symbol("0") - Symbol("1"),
        replace = Symbol("0") + (-1 * Symbol("1"))
    )
    
    #Multiplication Expansion: (a + b) * c and c * (a + b) -> a*c + b*c
    multiplicationExpandRule = Rule(
        match = [
            (Symbol("0") + Symbol("1")) * Symbol("2"),
            Symbol("2") * (Symbol("0") + Symbol("1"))
        ],
        replace = (Symbol("0") * Symbol("2")) + (Symbol("1") * Symbol("2"))
    )
    
    #Fix Exponents First: expand a product over a sum inside an exponent.
    #NOTE: match == replace, so this rule matches its own output again; do not
    #put it in a Ruleset given to applyRuleset (it would never terminate).
    exponentSimplifyRule = Rule(
        match = Symbol("0") ** Symbol("1"),
        replace = Symbol("0") ** Symbol("1"),
        capFilters = {1: lambda e: isinstance(e, BinaryExpression)},
        #applyRule returns (result, replacement): substitute only the
        #replacement, keeping the exponent unchanged when the rule does not match.
        capModifiers= {1: lambda e: Rules._applyRuleOrKeep(e, Rules.multiplicationExpandRule)}
    )
    
    #Right Heavy Addition Rule: (a + b) + c -> a + (b + c)
    rightHeavyAdditionRule = Rule(
        match = (Symbol("0") + Symbol("1")) + Symbol("2"),
        replace = Symbol("0") + (Symbol("1") + Symbol("2"))
    )
    
    #Left Constant Add Match Rule: n + b with n a number
    leftConstAddMatchRule = MatchRule(
        match = Symbol("0") + Symbol("1"),
        capFilters = {0: lambda e: isinstance(e, Number)}
    )
    
    #Constant Add Evaluable Match Rule: n + m or n + (m + b) with n, m numbers.
    #matchRule returns a (result, captures) tuple, so test element [0]; a bare
    #tuple is always truthy.
    addEvaluableMatchRule = MatchRule(
        match = Symbol("0"),
        capFilters = {0: lambda e: isinstance(e, BinaryExpression)
                      and e.matchRule(Rules.leftConstAddMatchRule)[0]
                      and (isinstance(e.right, Number)
                           or (isinstance(e.right, BinaryExpression)
                               and e.right.matchRule(Rules.leftConstAddMatchRule)[0]))}
    )
    
    #Constant Add Evaluation Rule: n + m -> (n+m); n + (m + b) -> (n+m) + b
    constAddEvalRule = Rule(
        match = Symbol("0"),
        replace = Symbol("0"),
        capFilters = {0: lambda e: isinstance(e, BinaryExpression) and e.matchRule(Rules.addEvaluableMatchRule)[0]},
        capModifiers = {0: lambda e: BinaryExpression(e.left + e.right.left, "+", e.right.right) \
                        if isinstance(e.right, BinaryExpression) else e.left + e.right}
    )
    
    #Earlier drafts (disabled; superseded by leftConstMult/multEvaluableMatchRule below):
    #leftConstMultOrDivMatchRule = MatchRule(
    #    match = [Symbol("0") * Symbol("1"), Symbol("0") / Symbol("1")],
    #    capFilters = {0: lambda e: isinstance(e, Number)}
    #)
    #multEvaluableMatchRule = MatchRule(
    #    match = Symbol("0"),
    #    capFilters = {0: lambda e: e.matchRule(Rules.leftConstMultOrDivMatchRule) and e.operator == "*" \
    #                 and (e.right.matchRule(Rules.leftConstMultOrDivMatchRule) if isinstance(e.right, BinaryExpression) \
    #                     else isinstance(e.right, Number)) \
    #                 if isinstance(e, BinaryExpression) else False}
    #)
    
    #No Right Mult By Constant Rule: a * n -> n * a (a not a number, n a number)
    noConstRightMultRule = Rule(
        match = Symbol("0") * Symbol("1"),
        replace = Symbol("1") * Symbol("0"),
        capFilters = {0: lambda e: not isinstance(e, Number),
                      1: lambda e: isinstance(e, Number)},
    )
    
    #Left Mult by Constant Match Rule: n * b with n a number
    leftConstMult = MatchRule(
        match = Symbol("0") * Symbol("1"),
        capFilters = {0: lambda e: isinstance(e, Number)}
    )
    
    #Constant Multiply Evaluable Match Rule: n * m or n * (m * b) with n, m numbers
    multEvaluableMatchRule = MatchRule(
        match = Symbol("0"),
        capFilters = {0: lambda e: isinstance(e, BinaryExpression) and e.matchRule(Rules.leftConstMult)[0] \
                      and (isinstance(e.right, Number) or (e.right.matchRule(Rules.leftConstMult)[0] \
                           if isinstance(e.right, BinaryExpression) else False))}
    )

    #Constant Multiply Evaluation Rule: n * m -> (n*m); n * (m * b) -> (n*m) * b
    constMultEvalRule = Rule(
        match = Symbol("0"),
        replace = Symbol("0"),
        capFilters = {0: lambda e: isinstance(e, BinaryExpression) and e.matchRule(Rules.multEvaluableMatchRule)[0]},
        capModifiers = {0: lambda e: e.left * e.right if isinstance(e.right, Number) \
                        else (e.left * e.right.left) * e.right.right}
    )
    
    #Default Ruleset (tried in this order by applyRuleset)
    defaultRuleset = Ruleset([
        noSubtractionRule,
        noDivisionRule,
        noConstRightMultRule,
        constMultEvalRule,
        multiplicationExpandRule
    ])

In [10]:
class Equality(): #Equality class
    """Placeholder for an equality between expressions (presumably lhs = rhs); not implemented yet."""

In [104]:
# Single-letter symbols used by the demo cells below.
x, y, z, a, b = map(Symbol, "xyzab")

In [105]:
# Demo: constMultEvalRule folds the nested constants in 3 * (5 * x) into 15 * x.
e = 3 * (5 * x)
print(e.generateTreeString())
print("\n")
e = e.applyRule(Rules.constMultEvalRule)[1]  # element [0] is the match flag
print(e.generateTreeString())

*
  3
  *
    5
    x


*
  15
  x


In [106]:
# Demo: the default ruleset distributes the 3 and folds 3 * 5 into 15.
e = (3 * ((y + 5) + z)).applyRuleset(Rules.defaultRuleset)
print(e.generateTreeString())

+
  +
    *
      3
      y
    15
  *
    3
    z


In [107]:
# Demo: a * (x + y) is expanded; the factor a ends up on the right of each product.
e = (a * (x + y)).applyRuleset(Rules.defaultRuleset)
print(e.generateTreeString())

+
  *
    x
    a
  *
    y
    a


In [108]:
# Demo: a numeric factor on the right of a product is moved to the left.
e = (z * 3).applyRuleset(Rules.defaultRuleset)
print(e.generateTreeString())

*
  3
  z


In [114]:
# Demo: the product simplifies to the same tree whichever side the 2 is on.
e = 2 * (3 * x + 3 * y * 4)
print(e.generateTreeString())
print("\n")
e = e.applyRuleset(Rules.defaultRuleset)
print(e.generateTreeString())
print("\n")
e = ((3 * x + 3 * y * 4) * 2).applyRuleset(Rules.defaultRuleset)
print(e.generateTreeString())

*
  2
  +
    *
      3
      x
    *
      *
        3
        y
      4


+
  *
    6
    x
  *
    24
    y


+
  *
    6
    x
  *
    24
    y


In [84]:
# Prints (ResultBool, replacement); the replacement is None when the rule does not match.
print(e.applyRule(Rules.constMultEvalRule))

(False, None)


In [26]:
# Fold the constants in 6 * (5 * x); prints (ResultBool, replacement).
# NOTE(review): the recorded NameError output looks like this cell (In [26]) ran
# before x was defined (In [104]); re-run it to refresh the output.
e = 6 * (5 * x)
print(e.applyRule(Rules.constMultEvalRule))

NameError: name 'x' is not defined

In [10]:
import unittest

In [24]:
class TestRules(unittest.TestCase):
    """Testing class for Rules"""
    x = Symbol("x")
    y = Symbol("y")
    z = Symbol("z")

    # NOTE: test methods deliberately have no docstrings; with verbosity=2,
    # unittest would print them and change the runner's output.

    def _assertRewrites(self, rule, cases):
        # Each (original, expected) pair: applying rule must produce expected.
        for index, (original, expected) in enumerate(cases):
            with self.subTest(case=index):
                self.assertEqual(original.applyRule(rule)[1], expected)

    def _assertNoMatch(self, rule, expr):
        # The rule must report that it does not apply to expr.
        self.assertEqual(expr.applyRule(rule)[0], False)

    def testNoDivisionRule(self):
        x, y = self.x, self.y
        self._assertRewrites(Rules.noDivisionRule, [
            (x / y, x * y ** -1),
            (2 / x, 2 * x ** -1),
            (x / 3, x * BinaryExpression(3, "^", -1)),
            (BinaryExpression(3, "/", 2), 3 * BinaryExpression(2, "^", -1)),
        ])
        self._assertNoMatch(Rules.noDivisionRule, x * y)

    def testNoSubtractionRule(self):
        x, y = self.x, self.y
        self._assertRewrites(Rules.noSubtractionRule, [
            (x - y, x + -1 * y),
            (2 - x, 2 + -1 * x),
            (x - 3, x + BinaryExpression(-1, "*", 3)),
            (BinaryExpression(3, "-", 2), 3 + BinaryExpression(-1, "*", 2)),
        ])
        self._assertNoMatch(Rules.noSubtractionRule, x + y)

    def testNoConstRightMultRule(self):
        x, y = self.x, self.y
        self._assertRewrites(Rules.noConstRightMultRule, [
            (x * 5, 5 * x),
            (x * 5.0, 5.0 * x),
        ])
        self._assertNoMatch(Rules.noConstRightMultRule, 2 * x)
        self._assertNoMatch(Rules.noConstRightMultRule, x * y)

    def testConstMultEvalRule(self):
        # Only a constant on the LEFT of a product is evaluated.
        x = self.x
        self._assertNoMatch(Rules.constMultEvalRule, 2 * x)
        self._assertRewrites(Rules.constMultEvalRule, [
            (BinaryExpression(2, "*", 5), 10),
            (2 * (3 * x), 6 * x),
        ])

    def testMultiplicationExpandRule(self):
        x, y, z = self.x, self.y, self.z
        self._assertRewrites(Rules.multiplicationExpandRule, [
            (2 * (x + y), x * 2 + y * 2),
            (5.5 * (x + y), x * 5.5 + y * 5.5),
            ((x + y) * z, x * z + y * z),
        ])

        
        
class TestRuleset(unittest.TestCase):
    "Testing case for the default ruleset"
    x = Symbol("x")
    y = Symbol("y")
    z = Symbol("z")

    # NOTE: no method docstrings, so verbosity=2 output is unchanged.

    def _simplify(self, expr):
        # Run the default ruleset over expr and return the result.
        return expr.applyRuleset(Rules.defaultRuleset)

    def testCase0(self):
        # Constants on both sides of x are folded: (2 * x) * 1.5 -> 3.0 * x
        self.assertEqual(self._simplify(2 * self.x * 1.5), 3.0 * self.x)

    def testCase1(self):
        # A product of two symbols has nothing to simplify.
        self.assertEqual(self._simplify(self.x * self.y), self.x * self.y)

    def testCase2(self):
        # Distribution, then constant folding: 10 * (x + 5) -> 10 * x + 50
        self.assertEqual(self._simplify(10 * (self.x + 5)), 10 * self.x + 50)
        
        
        

# argv=[''] keeps unittest from parsing the notebook kernel's command-line arguments,
# and exit=False stops it from calling sys.exit() so the kernel keeps running.
unittest.main(argv=[''], verbosity=2, exit=False)
pass  # presumably so the cell ends on a statement and the returned TestProgram is not echoed

testConstMultEvalRule (__main__.TestRules) ... ok
testMultiplicationExpandRule (__main__.TestRules) ... ok
testNoConstRightMultRule (__main__.TestRules) ... ok
testNoDivisionRule (__main__.TestRules) ... ok
testNoSubtractionRule (__main__.TestRules) ... ok
testCase0 (__main__.TestRuleset) ... ok
testCase1 (__main__.TestRuleset) ... ok
testCase2 (__main__.TestRuleset) ... ok

----------------------------------------------------------------------
Ran 8 tests in 0.010s

OK
