In [1]:
import copy
from numbers import Number

In [2]:
class Expression():
    """Base class for expressions.

    Overloads the arithmetic operators so that combining an Expression with
    another operand builds a BinaryExpression node. The reflected (__r*__)
    variants keep the operand order when the Expression is the right operand.
    """

    # Addition: self + other / other + self
    def __add__(self, other):
        return BinaryExpression(left=self, operator="+", right=other)

    def __radd__(self, other):
        return BinaryExpression(left=other, operator="+", right=self)

    # Subtraction: self - other / other - self
    def __sub__(self, other):
        return BinaryExpression(left=self, operator="-", right=other)

    def __rsub__(self, other):
        return BinaryExpression(left=other, operator="-", right=self)

    # Multiplication: self * other / other * self
    def __mul__(self, other):
        return BinaryExpression(left=self, operator="*", right=other)

    def __rmul__(self, other):
        return BinaryExpression(left=other, operator="*", right=self)

    # Division: self / other / other / self
    def __truediv__(self, other):
        return BinaryExpression(left=self, operator="/", right=other)

    def __rtruediv__(self, other):
        return BinaryExpression(left=other, operator="/", right=self)

    # Exponents: self ** other / other ** self (stored with the "^" operator)
    def __pow__(self, other):
        return BinaryExpression(left=self, operator="^", right=other)

    def __rpow__(self, other):
        return BinaryExpression(left=other, operator="^", right=self)

In [3]:
class BinaryExpression(Expression):
    """Class for representing binary expressions such as addition and exponentiation.

    Each instance is a tree node holding a left operand, an operator string
    and a right operand."""
    # Class-level defaults; __init__ overwrites all three on every instance.
    left = None
    operator = None
    right = None
    
    def __init__(self, left, operator, right):
        """Stores the left operand, the operator and the right operand on the instance."""
        self.left = left
        self.operator = operator
        self.right = right
        
    def generateTreeString(self):
        """Builds the multi-line tree string of this BinaryExpression.

        The operator goes on the first line, followed by the left and then the
        right child, each indented by two spaces. Nested BinaryExpressions are
        rendered recursively with every one of their lines indented as well;
        any other child is rendered with str()."""
        pad = "  "
        pieces = [self.operator]
        for child in (self.left, self.right):
            if isinstance(child, BinaryExpression):
                #Indent every line of the nested tree, not just the first
                childText = child.generateTreeString().replace("\n", "\n" + pad)
            else:
                childText = str(child)
            pieces.append(pad + childText)
        return "\n".join(pieces)
    
    def __hash__(self):
        """Hashes the tree string, so expressions with the same tree string share a hash."""
        treeString = self.generateTreeString()
        return hash(treeString)
    
    def __eq__(self, other):
        """Tests the equality of two BinaryExpressions by comparing their tree strings.

        The tree strings themselves are compared rather than their hashes, so a
        hash collision can never make two different expressions equal. Operands
        that are not BinaryExpressions (including unhashable ones and plain
        strings that merely look like a tree string) yield NotImplemented, which
        lets Python try the reflected comparison and otherwise treat the two as
        unequal. Equal expressions still share a hash, since __hash__ hashes the
        same tree string."""
        if not isinstance(other, BinaryExpression):
            return NotImplemented
        return self.generateTreeString() == other.generateTreeString()
    
    def __repr__(self):
        """Returns the tree string prefixed with a newline, so the tree starts on its own line."""
        treeString = self.generateTreeString()
        return f"\n{treeString}"
    
    def symbolSubstitution(self, substitutionDict): #NOT IN PLACE
        """Returns a new BinaryExpression with Symbols replaced according to substitutionDict.

        Nested BinaryExpressions are processed recursively; a Symbol child is
        replaced when it is a key of the dictionary; any other child is kept as is.
        Raises TypeError if substitutionDict is not a dict.
        This operation is not in place."""
        if not isinstance(substitutionDict, dict):
            raise TypeError("parameter substitutionDict must be a dictionary!")

        def substituted(child):
            #Recurse into sub-expressions
            if isinstance(child, BinaryExpression):
                return child.symbolSubstitution(substitutionDict)
            #Swap a Symbol for its mapped object when one is registered
            if isinstance(child, Symbol) and child in substitutionDict.keys():
                return substitutionDict[child]
            return child

        replacedLeft = substituted(self.left)
        replacedRight = substituted(self.right)
        return BinaryExpression(replacedLeft, copy.deepcopy(self.operator), replacedRight)
    
    def generateReplacement(self, replace, capGroups=None, capModifiers=None): #NOT IN PLACE
        """Fills a copy of the replace pattern with the captured groups.

        Each capture group number N is substituted for Symbol(str(N)). When a
        modifier is registered for a group, the modifier's result is substituted
        instead of the raw capture. With no capture groups the copied pattern is
        returned untouched.
        This operation is not in place."""
        capGroups = {} if capGroups is None else capGroups
        capModifiers = {} if capModifiers is None else capModifiers
        template = copy.deepcopy(replace)
        substitutions = {}
        for groupNum, captured in capGroups.items():
            if groupNum in capModifiers.keys():
                captured = capModifiers[groupNum](captured)
            substitutions[Symbol(str(groupNum))] = captured
        if not substitutions:
            return template
        return template.symbolSubstitution(substitutions)
    
    def matchAndGather(self, match, capFilters=None): #Returns (ResultBool, CaptureGroups)
        """Attempts to match a given pattern against the BinaryExpression object. Applies capture filters if supplied.
        Can also take an array of patterns to match, in which case the first pattern that matches wins.
        Behavior is undefined if patterns in the array have different capture groups.

        A Symbol whose name parses as an integer is a capture group; any other Symbol
        must equal the corresponding part of self. capFilters maps capture group
        numbers to predicates that must accept the captured object.

        Raises ValueError if the pattern (or one of its children) is of an unsupported type.
        Returns (resultBool, capGroupsDict)"""
        if capFilters is None:
            capFilters = {}
        if hasattr(match, "__iter__"):
            for pattern in match:
                result = self.matchAndGather(match=pattern, capFilters=capFilters)
                if result[0]:
                    return result
            #None of the match patterns matched
            return (False, {})
        if isinstance(self, match.__class__): #Self is correct class
            if self.operator != match.operator: #Self has the wrong operator
                return (False, {})
            #Both children are always evaluated (left first), so filters run and
            #malformed patterns raise even when the other child does not match
            childPairs = ((self.left, match.left), (self.right, match.right))
            results = [self._matchChild(selfChild, matchChild, capFilters)
                       for selfChild, matchChild in childPairs]
            if not (results[0][0] and results[1][0]):
                return (False, {})
            #Both children matched, merge their capture groups
            capGroups = {}
            for _, captured in results:
                capGroups.update(captured)
            return (True, capGroups)
        if isinstance(match, Symbol):
            return self._matchSymbol(match, self, capFilters)
        raise ValueError("unsupported type for match object")

    @staticmethod
    def _matchSymbol(symbol, target, capFilters):
        """Matches one pattern Symbol against target and returns (resultBool, capGroupsDict)."""
        try: #The symbol name is a number, denoting a capture group
            capGroupNum = int(symbol.name)
        except ValueError: #The symbol name is not a number, the symbols must be equal
            return (True, {}) if symbol == target else (False, {})
        #The filter runs outside the try so that a ValueError raised by a filter
        #propagates instead of being mistaken for a non-numeric symbol name
        if capGroupNum in capFilters and not capFilters[capGroupNum](target):
            return (False, {})
        return (True, {capGroupNum: target})

    @staticmethod
    def _matchChild(selfChild, matchChild, capFilters):
        """Matches one child of the pattern against the corresponding child of the expression."""
        if isinstance(matchChild, Symbol): #Match child is Symbol
            return BinaryExpression._matchSymbol(matchChild, selfChild, capFilters)
        if isinstance(matchChild, BinaryExpression): #Match child is BinExp
            if isinstance(selfChild, BinaryExpression):
                return selfChild.matchAndGather(matchChild, capFilters)
            return (False, {})
        raise ValueError("Unexpected class in match object: " + str(matchChild.__class__))

    def matchAndGenReplace(self, match, replace, capFilters=None, capModifiers=None):
        """Matches the pattern against self and, on success, builds the replacement
        from the captured groups. self is not modified.
        Returns (resultBool, replacementBinaryExpressionOrNone)"""
        matched, captures = self.matchAndGather(match, capFilters=capFilters)
        if not matched:
            return (matched, None)
        built = self.generateReplacement(replace, capGroups=captures, capModifiers=capModifiers)
        return (matched, built)
    
    def applyRule(self, rule):
        """Runs matchAndGenReplace with the match, replace, capFilters and
        capModifiers stored on the given rule.
        Returns (resultBool, replacementBinaryExpressionOrNone)."""
        ruleArgs = {
            "match": rule.match,
            "replace": rule.replace,
            "capFilters": rule.capFilters,
            "capModifiers": rule.capModifiers,
        }
        return self.matchAndGenReplace(**ruleArgs)
    
    def matchRule(self, matchRule):
        """Runs matchAndGather with the pattern and capture filters stored on matchRule.
        Returns (resultBool, capGroupsDict)."""
        pattern, filters = matchRule.match, matchRule.capFilters
        return self.matchAndGather(pattern, filters)
        
    def applyRuleset(self, ruleset):
        """Rewrites the expression with the rules of ruleset, depth first.

        The children are rewritten first, then the rules are tried in list order
        on the rebuilt node. As soon as one rule applies, its result is rewritten
        again from scratch (if it is still a BinaryExpression). When no rule
        applies the rebuilt node is returned."""
        def rewritten(node):
            #Only BinaryExpressions can be rewritten; anything else is final
            if isinstance(node, BinaryExpression):
                return node.applyRuleset(ruleset)
            return node

        #Apply to children (dfs)
        leftResult = rewritten(self.left)
        rightResult = rewritten(self.right)
        rebuilt = BinaryExpression(leftResult, self.operator, rightResult)
        #Iterate rules
        for rule in ruleset.ruleList:
            applied, changed = rebuilt.applyRule(rule)
            if not applied:
                continue
            #The rule changed the node, start over on the changed expression
            return rewritten(changed)
        return rebuilt

In [4]:
class Rule():
    """Bundles a match pattern, a replace pattern and the optional capture
    filters and capture modifiers so they can be referenced and applied together."""

    def __init__(self, match, replace, capFilters=None, capModifiers=None):
        """Stores the four rule components unchanged on the instance."""
        self.match, self.replace = match, replace
        self.capFilters, self.capModifiers = capFilters, capModifiers

In [5]:
class MatchRule():
    """Bundles a match pattern with its optional capture filters so they can be
    referenced and applied together."""

    def __init__(self, match, capFilters=None):
        """Stores the pattern and the capture filters unchanged on the instance."""
        self.match, self.capFilters = match, capFilters

In [6]:
class Ruleset():
    """Holds the list of rules that BinaryExpression.applyRuleset tries in order."""

    def __init__(self, ruleList):
        """Keeps a reference to ruleList; the list itself is not copied."""
        self.ruleList = ruleList

In [7]:
class Symbol(Expression):
    """Class for symbolic representation objects. Objects have only a name field.

    Two Symbols are equal when their names are equal, and they hash by name,
    so Symbols can be used as dictionary keys."""
    #Class-level default; __init__ overwrites it on every instance.
    name = None

    def __init__(self, name):
        """Stores the name of the Symbol."""
        self.name = name

    def __str__(self):
        """Returns the name of the Symbol as a string."""
        return self.name

    def __eq__(self, other):
        """Tests equality via type and then name if other is also a Symbol."""
        if isinstance(other, Symbol):
            if self.name == other.name:
                return True
        return False

    def __hash__(self):
        """Returns a hash based on the name of the Symbol."""
        return hash(self.name)

    def __repr__(self):
        """Returns the name of the Symbol as a string."""
        return self.name

    def symbolSubstitution(self, substitutionDict):
        """Returns a copy of the object from the substitution dict if self is a key.

        When self is not a key the Symbol itself is returned, so the result is
        always an expression (previously this case implicitly returned None)."""
        if self in substitutionDict.keys():
            return copy.deepcopy(substitutionDict[self])
        return self

    def generateTreeString(self):
        """Returns the name of the Symbol as a string."""
        return self.name

In [8]:
class Rules():
    """Namespace holding the rewrite rules, match rules and comparison helpers
    used to simplify expressions, together with the default ruleset.

    In the patterns, a Symbol whose name is a number is a capture group."""
    
    #No Division Rule: a / b -> a * b^-1
    noDivisionRule = Rule(
        match = Symbol("0") / Symbol("1"),
        replace = Symbol("0") * (Symbol("1") ** -1)
    )
    
    #Exponent Power Rule: (a^b)^c -> a^(b * c)
    expPowerRule = Rule(
        match = (Symbol("0") ** Symbol("1")) ** Symbol("2"),
        replace = Symbol("0") ** (Symbol("1") * Symbol("2"))
    )
    
    #No Subtraction: a - b -> a + (-1 * b)
    noSubtractionRule = Rule(
        match = Symbol("0") - Symbol("1"),
        replace = Symbol("0") + (-1 * Symbol("1"))
    )
    
    #Multiplication Expansion: (a + b) * c and c * (a + b) -> (a * c) + (b * c)
    multiplicationExpandRule = Rule(
        match = [
            (Symbol("0") + Symbol("1")) * Symbol("2"),
            Symbol("2") * (Symbol("0") + Symbol("1"))
        ],
        replace = (Symbol("0") * Symbol("2")) + (Symbol("1") * Symbol("2"))
    )
    
    #Multiplication expand rule for when both have -1 exponent (denominator multiplication)
    #The filters on groups 3 and 4 require both exponents to equal -1
    multNegExpExpandRule = Rule(
        match = [
            ((Symbol("0") + Symbol("1")) ** Symbol("3")) * (Symbol("2") ** Symbol("4")),
            (Symbol("2") ** Symbol("3")) * ((Symbol("0") + Symbol("1")) ** Symbol("4"))
        ],
        replace = ((Symbol("0") * Symbol("2")) + (Symbol("1") * Symbol("2"))) ** -1,
        capFilters = {3: lambda e: e == -1, 4: lambda e: e == -1}
    )
    
    #Distribute exponents over multiplication: (a * b)^c -> a^c * b^c
    expMultDistributeRule = Rule(
        match = (Symbol("0") * Symbol("1")) ** Symbol("2"),
        replace = (Symbol("0") ** Symbol("2")) * (Symbol("1") ** Symbol("2"))
    )
    
    #Distribute positive constant exponents over addition
    #For a Number n > 1: (a + b)^n -> (a + b) * (a + b)^(n - 1)
    expAddExpandRule = Rule(
        match = (Symbol("0") + Symbol("1")) ** Symbol("2"),
        replace = (Symbol("0") + Symbol("1")) * (Symbol("0") + Symbol("1")) ** Symbol("2"),
        capFilters = {2: lambda e: e > 1 if isinstance(e, Number) else False},
        capModifiers = {2: lambda e: e - 1}
    )
    
    #Distribute negative constant exponents over addition
    #For a Number n < -1: (a + b)^n -> (a + b)^-1 * (a + b)^(n + 1)
    negExpAddExpandRule = Rule(
        match = (Symbol("0") + Symbol("1")) ** Symbol("2"),
        replace = (Symbol("0") + Symbol("1")) ** -1 * (Symbol("0") + Symbol("1")) ** Symbol("2"),
        capFilters = {2: lambda e: e < -1 if isinstance(e, Number) else False},
        capModifiers = {2: lambda e: e + 1}
    )
    
    #Right Heavy Addition Rule: (a + b) + c -> a + (b + c)
    rightHeavyAdditionRule = Rule(
        match = (Symbol("0") + Symbol("1")) + Symbol("2"),
        replace = Symbol("0") + (Symbol("1") + Symbol("2"))
    )
    
    #Right Heavy Multiplication Rule: (a * b) * c -> a * (b * c)
    rightHeavyMultRule = Rule(
        match = (Symbol("0") * Symbol("1")) * Symbol("2"),
        replace = Symbol("0") * (Symbol("1") * Symbol("2"))
    )
    
    @staticmethod
    def multCompareSortRule(first, second):
        """Compares two factors to determine how they should be ordered.

        Numbers sort first, powers ("^") are compared by their base, and
        everything else is ordered by hash. Declared as a staticmethod so that
        it can be called through the class as well as through an instance.
        Returns True if they should be swapped and False otherwise."""
        if isinstance(first, Number):
            return False #we want numbers first
        elif isinstance(second, Number):
            return True #we know first is not a number, and if second is a number then they need to switch
        elif isinstance(first, BinaryExpression) and first.operator == "^": #TODO does it matter if the base is a constant?
            return Rules.multCompareSortRule(first.left, second) #we want to sort by bases, not exponents
        elif isinstance(second, BinaryExpression) and second.operator == "^": #TODO ditto
            return Rules.multCompareSortRule(first, second.left) #we want to sort by bases, not exponents
        else:
            #NOTE(review): str hashes are randomized per interpreter run unless
            #PYTHONHASHSEED is fixed, so this order can differ between sessions
            return hash(second) < hash(first) #now sort by hashes
        
    @staticmethod
    def addCompareSortRule(first, second):
        """Compares two addends to determine how they should be ordered.

        Numbers sort first, products with a Number on the left are compared by
        their right operand (the term without its coefficient), and everything
        else is ordered by hash. Declared as a staticmethod so that it can be
        called through the class as well as through an instance.
        Returns True if they should be swapped and False Otherwise."""
        if isinstance(first, Number):
            return False #we want numbers first
        elif isinstance(second, Number):
            return True #we know first is not a number, and if second is a number then they need to switch
        elif isinstance(first, BinaryExpression) and first.operator == "*" and isinstance(first.left, Number):
            return Rules.addCompareSortRule(first.right, second) #we want to sort by term excluding coefficients
        elif isinstance(second, BinaryExpression) and second.operator == "*" and isinstance(second.left, Number):
            return Rules.addCompareSortRule(first, second.right) #we want to sort by term excluding coefficients
        else:
            #NOTE(review): str hashes are randomized per interpreter run unless
            #PYTHONHASHSEED is fixed, so this order can differ between sessions
            return hash(second) < hash(first) #now sort by hashes
    
    #Matches multiplication with a product as the right child (chain multiplication)
    multChainMatchRule = MatchRule(
        match = Symbol("0") * (Symbol("1") * Symbol("2"))
    )

    #Matches multiplication without a product as the right child (non-chain multiplication)
    multMatchRule = MatchRule(
        match = Symbol("0") * Symbol("1"),
        capFilters = {1: lambda e: e.operator != "*" if isinstance(e, BinaryExpression) else True}
    )

    #Rule for sorting multiplication in a chain (where the right child is also a product)
    #When the first two factors are out of order: a * (b * c) -> b * (a * c)
    hashMultChainSortRule = Rule(
        match = Symbol("0"),
        replace = Symbol("0"),
        capFilters = {0: lambda e: Rules.multCompareSortRule(e.left, e.right.left)
                      if e.matchRule(Rules.multChainMatchRule)[0] else False},
        capModifiers = {0: lambda e: BinaryExpression(e.right.left, "*",
                                                  BinaryExpression(e.left, "*", e.right.right))}
    )
    
    #Rule for sorting multiplication not in a chain (where the right child is not a product)
    #When the two factors are out of order: a * b -> b * a
    hashMultSortRule = Rule(
        match = Symbol("0"),
        replace = Symbol("0"),
        capFilters = {0: lambda e: Rules.multCompareSortRule(e.left, e.right)
                      if e.matchRule(Rules.multMatchRule)[0] else False},
        capModifiers = {0: lambda e: BinaryExpression(e.right, "*", e.left)}
    )
    
    #Matches addition with a sum as the right child (chain addition)
    addChainMatchRule = MatchRule(
        match = Symbol("0") + (Symbol("1") + Symbol("2"))
    )
    
    #Matches addition without a sum as the right child (non-chain addition)
    addMatchRule = MatchRule(
        match = Symbol("0") + Symbol("1"),
        capFilters = {1: lambda e: e.operator != "+" if isinstance(e, BinaryExpression) else True}
    )
    
    #Rule for sorting addition in a chain (where the right child is also a sum)
    #When the first two addends are out of order: a + (b + c) -> b + (a + c)
    hashAddChainSortRule = Rule(
        match = Symbol("0"),
        replace = Symbol("0"),
        capFilters = {0: lambda e: Rules.addCompareSortRule(e.left, e.right.left) 
                      if e.matchRule(Rules.addChainMatchRule)[0] else False},
        capModifiers = {0: lambda e: BinaryExpression(e.right.left, "+",
                                                     BinaryExpression(e.left, "+", e.right.right))}
    )
    
    #Rule for sorting addition not in a chain (where the right child is not a sum)
    #When the two addends are out of order: a + b -> b + a
    hashAddSortRule = Rule(
        match = Symbol("0"),
        replace = Symbol("0"),
        capFilters = {0: lambda e: Rules.addCompareSortRule(e.left, e.right) 
                      if e.matchRule(Rules.addMatchRule)[0] else False},
        capModifiers = {0: lambda e: BinaryExpression(e.right, "+", e.left)}
    )
    
    #Left Mult by Constant Match Rule: a product whose left operand is a Number
    leftConstMult = MatchRule(
        match = Symbol("0") * Symbol("1"),
        capFilters = {0: lambda e: isinstance(e, Number)}
    )
    
    #Constant Multiply Evaluable Match Rule
    #Matches n * m and n * (m * rest) where n and m are Numbers
    multEvaluableMatchRule = MatchRule(
        match = Symbol("0"),
        capFilters = {0: lambda e: isinstance(e, BinaryExpression) and e.matchRule(Rules.leftConstMult)[0] \
                      and (isinstance(e.right, Number) or (e.right.matchRule(Rules.leftConstMult)[0] \
                           if isinstance(e.right, BinaryExpression) else False))}
    )

    #Constant Multiplication Evaluation Rule
    #n * m is computed directly; n * (m * rest) becomes (n * m) * rest
    constMultEvalRule = Rule(
        match = Symbol("0"),
        replace = Symbol("0"),
        capFilters = {0: lambda e: isinstance(e, BinaryExpression) and e.matchRule(Rules.multEvaluableMatchRule)[0]},
        capModifiers = {0: lambda e: e.left * e.right if isinstance(e.right, Number) \
                        else (e.left * e.right.left) * e.right.right}
    )
    
    #Left Constant Add Match Rule: a sum whose left operand is a Number
    leftConstAdd = MatchRule(
        match = Symbol("0") + Symbol("1"),
        capFilters = {0: lambda e: isinstance(e, Number)}
    )
    
    #Constant Add Evaluable Match Rule
    #Matches n + m and n + (m + rest) where n and m are Numbers
    addEvaluableMatchRule = MatchRule(
        match = Symbol("0"),
        capFilters = {0: lambda e: isinstance(e, BinaryExpression) and e.matchRule(Rules.leftConstAdd)[0] \
                      and (isinstance(e.right, Number) or (e.right.matchRule(Rules.leftConstAdd)[0] \
                           if isinstance(e.right, BinaryExpression) else False))}
    )
    
    #Constant Add Evaluation Rule
    #n + m is computed directly; n + (m + rest) becomes (n + m) + rest
    constAddEvalRule = Rule(
        match = Symbol("0"),
        replace = Symbol("0"),
        capFilters = {0: lambda e: isinstance(e, BinaryExpression) and e.matchRule(Rules.addEvaluableMatchRule)[0]},
        capModifiers = {0: lambda e: e.left + e.right if isinstance(e.right, Number) \
                        else (e.left + e.right.left) + e.right.right}
    )
    
    #Constant Exponent Evaluable Match Rule: a power whose base and exponent are both Numbers
    expEvaluableMatchRule = MatchRule(
        match = Symbol("0") ** Symbol("1"),
        capFilters = {0: lambda e: isinstance(e, Number), 1: lambda e: isinstance(e, Number)}
    )
    
    #Constant Exponent Evaluation Rule: computes n ^ m for Numbers n and m
    constExpEvalRule = Rule(
        match = Symbol("0"),
        replace = Symbol("0"),
        capFilters = {0: lambda e: e.matchRule(Rules.expEvaluableMatchRule)[0] if isinstance(e, BinaryExpression) else False},
        capModifiers = {0: lambda e: e.left ** e.right}
    )
    
    @staticmethod
    def likeTermsRule(first, second):
        """Determines if two terms are like. Returns True if they are and False Otherwise.

        Numeric coefficients on the left of a product are stripped from either
        term before the remaining terms are compared with ==. Declared as a
        staticmethod so that it can be called through the class as well as
        through an instance."""
        if isinstance(first, BinaryExpression) and first.matchRule(Rules.leftConstMult)[0]:
            return Rules.likeTermsRule(first.right, second) #ignore the coefficient of first
        elif isinstance(second, BinaryExpression) and second.matchRule(Rules.leftConstMult)[0]:
            return Rules.likeTermsRule(first, second.right) #ignore the coefficient of second
        else:
            return first == second
        
    @staticmethod
    def combineLikeTerms(first, second): #Assumes that first and second are like terms
        """Combines two like terms and returns the result. Unchecked Precondition: the terms are like.

        A term without a numeric coefficient on the left counts as having
        coefficient 1. Declared as a staticmethod so that it can be called
        through the class as well as through an instance."""
        if isinstance(first, BinaryExpression) and first.matchRule(Rules.leftConstMult)[0]:
            if isinstance(second, BinaryExpression) and second.matchRule(Rules.leftConstMult)[0]: #first and second have coefficient
                return (first.left + second.left) * first.right
            else: #first has coefficient, second does not
                return Rules.combineLikeTerms(second, first) #swap so the case below handles it
        elif isinstance(second, BinaryExpression) and second.matchRule(Rules.leftConstMult)[0]: #first does not, second does
            return (1 + second.left) * first
        else: #neither has a coefficient
            return 2 * first
                
    #Rule for combining like terms in a non-chain sum: like terms a + b are merged
    #The filter reads the result flag ([0]) of matchRule, like every other rule does
    addCombineRule = Rule(
        match = Symbol("0"),
        replace = Symbol("0"),
        capFilters = {0: lambda e: Rules.likeTermsRule(e.left, e.right) if isinstance(e, BinaryExpression)
                      and e.matchRule(Rules.addMatchRule)[0] else False},
        capModifiers = {0: lambda e: Rules.combineLikeTerms(e.left, e.right)}
    )
    
    #Rule for combining like terms in chain addition (where the right child is a sum)
    #For like terms a and b: a + (b + c) -> combined(a, b) + c
    addChainCombineRule = Rule(
        match = Symbol("0"),
        replace = Symbol("0"),
        capFilters = {0: lambda e: Rules.likeTermsRule(e.left, e.right.left) if isinstance(e, BinaryExpression)
                      and e.matchRule(Rules.addChainMatchRule)[0] else False},
        capModifiers = {0: lambda e: Rules.combineLikeTerms(e.left, e.right.left) + e.right.right}
    )
    
    @staticmethod
    def likeFactorsRule(first, second):
        """Determines if two factors are like. Returns True if they are and False otherwise.

        Powers ("^") are compared by their base; a power whose base is itself a
        power is never like anything. Once the exponents are stripped, factors
        that are BinaryExpressions or Numbers are never like; the rest are
        compared with ==. Declared as a staticmethod so that it can be called
        through the class as well as through an instance."""
        if isinstance(first, BinaryExpression) and first.operator == "^":
            if isinstance(first.left, BinaryExpression) and first.left.operator == "^":
                return False
            else:
                return Rules.likeFactorsRule(first.left, second) #compare by the base of first
        elif isinstance(second, BinaryExpression) and second.operator == "^":
            if isinstance(second.left, BinaryExpression) and second.left.operator == "^":
                return False
            else:
                return Rules.likeFactorsRule(first, second.left) #compare by the base of second
        else:
            if isinstance(first, BinaryExpression) or isinstance(second, BinaryExpression):
                return False
            elif isinstance(first, Number) or isinstance(second, Number):
                return False
            else:
                return first == second
        
    @staticmethod
    def combineLikeFactors(first, second): #Assumes that first and second are like factors
        """Combines two like factors and returns the result. Unchecked Precondition: the factors are like.

        A factor that is not a power ("^") counts as having exponent 1.
        Declared as a staticmethod so that it can be called through the class
        as well as through an instance."""
        if isinstance(first, BinaryExpression) and first.operator == "^":
            if isinstance(second, BinaryExpression) and second.operator == "^": #first and second have exponents
                return first.left ** (first.right + second.right)
            else: #first has exponent, second does not
                return Rules.combineLikeFactors(second, first) #swap so the case below handles it
        elif isinstance(second, BinaryExpression) and second.operator == "^": #first does not, second does
            return first ** (second.right + 1)
        else: #neither has an exponent
            return first ** 2
    
    #Combine exponents over multiplication
    #Merges the two like factors of a non-chain product into a single power
    expCombineRule = Rule(
        match = Symbol("0"),
        replace = Symbol("0"),
        capFilters = {0: lambda e: Rules.likeFactorsRule(e.left, e.right)
                      if (e.matchRule(Rules.multMatchRule)[0] if isinstance(e, BinaryExpression) else False) 
                      else False},
        capModifiers = {0: lambda e: Rules.combineLikeFactors(e.left, e.right)}
    )
    
    #Combine exponents in multiplication chain
    #For like factors a and b: a * (b * c) -> c * combined(a, b)
    expMultChainCombineRule = Rule(
        match = Symbol("0"),
        replace = Symbol("0"),
        capFilters = {0: lambda e: Rules.likeFactorsRule(e.left, e.right.left) 
                      if (e.matchRule(Rules.multChainMatchRule)[0] if isinstance(e, BinaryExpression) else False)
                      else False},
        capModifiers = {0: lambda e: e.right.right * Rules.combineLikeFactors(e.left, e.right.left)}
    )
    
    #Matches addition in an exponent, where the base is a BinaryExpression or a Number
    expAddMatchRule = MatchRule(
        match = Symbol("0") ** (Symbol("1") + Symbol("2")),
        capFilters = {0: lambda e: isinstance(e, BinaryExpression) or isinstance(e, Number)}
    )
    
    #Splitting Exponent Addition
    #When one of the exponent's addends is a Number: a^(m + n) -> a^m * a^n
    expSplitRule = Rule(
        match = Symbol("0"),
        replace = Symbol("0"),
        capFilters = {0: lambda e: isinstance(e.right.left, Number) or isinstance(e.right.right, Number) 
                      if (e.matchRule(Rules.expAddMatchRule)[0] if isinstance(e, BinaryExpression) else False) else False},
        capModifiers = {0: lambda e: (e.left ** e.right.left) * (e.left ** e.right.right)}
    )
    
    #Multiplicative Inverse Rule
    #A product whose left operand equals 0 is replaced by 0
    multInverseRule = Rule(
        match = Symbol("0") * Symbol("1"),
        replace = Symbol("0"),
        capFilters = {0: lambda e: e == 0},
        capModifiers = {0: lambda e: 0}
    )
    
    #Multiplicative Identity Rule: 1 * a -> a
    multIdentityRule = Rule(
        match = Symbol("0") * Symbol("1"),
        replace = Symbol("1"),
        capFilters = {0: lambda e: e == 1}
    )
    
    #Additive Identity Rule: 0 + a -> a
    addIdentityRule = Rule(
        match = Symbol("0") + Symbol("1"),
        replace = Symbol("1"),
        capFilters = {0: lambda e: e == 0}
    )
    
    #Exponential Identity Rule: a^1 -> a
    expIdentityRule = Rule(
        match = Symbol("0") ** Symbol("1"),
        replace = Symbol("0"),
        capFilters = {1: lambda e: e == 1}
    )
    
    #Exponential Inverse Rule: a^0 -> 1
    expInverseRule = Rule(
        match = Symbol("0") ** Symbol("1"),
        replace = Symbol("0"),
        capFilters = {1: lambda e: e == 0},
        capModifiers = {0: lambda e: 1}
    )
    
    #Default Ruleset
    #applyRuleset tries the rules in this order and restarts after the first one that applies
    defaultRuleset = Ruleset([
        noSubtractionRule,
        noDivisionRule,
        constMultEvalRule,
        multiplicationExpandRule,
        rightHeavyMultRule,
        hashMultSortRule,
        hashMultChainSortRule,
        rightHeavyAdditionRule,
        constAddEvalRule,
        hashAddSortRule,
        hashAddChainSortRule,
        addCombineRule,
        addChainCombineRule,
        multInverseRule,
        multIdentityRule,
        addIdentityRule,
        expIdentityRule,
        expInverseRule,
        expPowerRule,
        expMultDistributeRule,
        constExpEvalRule,
        multNegExpExpandRule,
        expAddExpandRule,
        negExpAddExpandRule,
        expCombineRule,
        expMultChainCombineRule,
        expSplitRule
    ])

In [9]:
class Equality():
    """Equality class. Currently an empty placeholder without any behavior."""

In [10]:
# Symbols used by the example cells
x, y, z, a, b = (Symbol(name) for name in ("x", "y", "z", "a", "b"))

In [13]:
# Simplify a difference of two polynomial products with the default ruleset and display the tree
e = (6 * x ** 2 * (3 * x + 4) ** 2 - 5 * (x - 3 * x ** 2) ** 2).applyRuleset(Rules.defaultRuleset)
e


+
  *
    9
    ^
      x
      4
  +
    *
      174
      ^
        x
        3
    *
      91
      ^
        x
        2

In [508]:
# Simplify a power of a sum whose exponent is itself a sum containing a Number
e = ((x + 5) ** (y + 2)).applyRuleset(Rules.defaultRuleset)
e


+
  *
    25
    ^
      +
        5
        x
      y
  +
    *
      10
      *
        x
        ^
          +
            5
            x
          y
    *
      ^
        x
        2
      ^
        +
          5
          x
        y

In [483]:
# Two powers of the same base x are compared by their bases
Rules.likeFactorsRule(x ** 5, x ** 2)

True

In [495]:
# Apply expCombineRule to a product of powers with different bases
e = (y ** 5) * (x ** 2)
e.applyRule(Rules.expCombineRule)

(False, None)

In [433]:
# Simplify the product of a sum and a difference
e = ((x + 5) * (x - 5)).applyRuleset(Rules.defaultRuleset)
e


+
  -25
  *
    x
    x

In [434]:
# Simplify a power whose base and exponent are both plain numbers
e = BinaryExpression(5, "^", 3).applyRuleset(Rules.defaultRuleset)
e

125

In [436]:
# Simplify a product of two sums that are both raised to the power -1
e = ((x + 5) ** -1 * (x - 5) ** -1).applyRuleset(Rules.defaultRuleset)
e


^
  +
    -25
    *
      x
      x
  -1

In [386]:
# Simplify a sum of x and a multiple of x
e = (x + x * 5).applyRuleset(Rules.defaultRuleset)
e


*
  6
  x

In [388]:
# Simplify an expression containing a zero exponent and a multiplication by zero
e = (x ** 0 * y ** (1 + 0 * z)).applyRuleset(Rules.defaultRuleset)
e

y

In [443]:
# Scratch check: 3 leaves no remainder when divided by 1
3 % 1 == 0

True

In [497]:
# Simplify a sum raised to a positive integer power
e = ((x + 5) ** 3).applyRuleset(Rules.defaultRuleset)
e


+
  125
  +
    ^
      x
      3
    +
      *
        15
        ^
          x
          2
      *
        75
        x

In [471]:
# Simplify a sum raised to a negative integer power
e = ((x + 5) ** -3).applyRuleset(Rules.defaultRuleset)
e


^
  +
    125
    +
      *
        15
        *
          x
          x
      +
        *
          x
          *
            x
            x
        *
          75
          x
  -1

In [299]:
# Combine a term that has a numeric coefficient with the same term without one
Rules.combineLikeTerms(7 * x ** 2, x ** 2)


*
  8
  ^
    x
    2

In [244]:
# A non-number addend followed by a Number: the rule asks for a swap, since Numbers sort first
Rules.addCompareSortRule(3 * x, 10)

True

In [245]:
# Check that a sum whose right child is not a sum matches addMatchRule
e = (3 * x) + 10
e.matchRule(Rules.addMatchRule)[0]

True

In [246]:
# Apply the right-heavy addition rule to a left-nested sum
e = x + y + z
e.applyRule(Rules.rightHeavyAdditionRule)

(True, 
 +
   x
   +
     y
     z)

In [255]:
# Print a long left-nested sum, then simplify it with the default ruleset and display the result
e = 5 * x + 7 + y + 3 * x + 10 + a + b
print(e)
e = e.applyRuleset(Rules.defaultRuleset)
e


+
  +
    +
      +
        +
          +
            *
              5
              x
            7
          y
        *
          3
          x
      10
    a
  b



+
  17
  +
    a
    +
      y
      +
        b
        +
          *
            5
            x
          *
            3
            x

In [192]:
# Simplify a chain sum that starts with two numbers
e = (7 + (3 + x)).applyRuleset(Rules.defaultRuleset)
e

+
  10
  x

In [182]:
# Simplify a nested product that mixes numbers, symbols and powers
e = ((a * 5) * ((a ** 2 * y) * (7 * a ** 5))).applyRuleset(Rules.defaultRuleset)
e

*
  35
  *
    a
    *
      ^
        a
        2
      *
        ^
          a
          5
        y

In [184]:
# Simplify a product containing a sum, plus an extra addend, and print the tree
e = (a * 5 * (x + 5) * 10 + b).applyRuleset(Rules.defaultRuleset)
print(e.generateTreeString())

+
  *
    50
    *
      a
      x
  +
    *
      250
      a
    b


In [106]:
# Simplify a number multiplied by a nested sum and print the tree
e = (3 * ((y + 5) + z)).applyRuleset(Rules.defaultRuleset)
print(e.generateTreeString())

+
  +
    *
      3
      y
    15
  *
    3
    z


In [107]:
# Simplify a symbol multiplied by a sum and print the tree
e = (a * (x + y)).applyRuleset(Rules.defaultRuleset)
print(e.generateTreeString())

+
  *
    x
    a
  *
    y
    a


In [108]:
# Constant on the right of a product: the printed tree is 3 * z (constant moved to the left).
e = z * 3
e = e.applyRuleset(Rules.defaultRuleset)
print(e.generateTreeString())

*
  3
  z


In [114]:
# Expansion with the constant multiplier on the left and then on the right.
# First print: the raw, unsimplified tree of 2 * (3 * x + 3 * y * 4).
e = 2 * (3 * x + 3 * y * 4)
print(e.generateTreeString())
print("\n")
# Second print: the same expression after the default ruleset (6 * x + 24 * y in the recorded output).
e = e.applyRuleset(Rules.defaultRuleset)
print(e.generateTreeString())
print("\n")
# Third print: multiplier on the right; the recorded output is the same 6 * x + 24 * y tree.
e = (3 * x + 3 * y * 4) * 2
e = e.applyRuleset(Rules.defaultRuleset)
print(e.generateTreeString())

*
  2
  +
    *
      3
      x
    *
      *
        3
        y
      4


+
  *
    6
    x
  *
    24
    y


+
  *
    6
    x
  *
    24
    y


In [10]:
import unittest

In [40]:
class TestRules(unittest.TestCase):
    """Unit tests that exercise individual rewrite rules from Rules, one rule per test."""

    # Symbols shared by every test in this class.
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")

    # The test methods below are annotated with comments rather than docstrings:
    # unittest shows a test's docstring in its verbose report line.

    @staticmethod
    def _rewrite(expression, rule):
        # applyRule yields an indexable result; element 1 is the rewritten expression.
        return expression.applyRule(rule)[1]

    @staticmethod
    def _applies(expression, rule):
        # Element 0 of applyRule's result is the flag the tests compare against False.
        return expression.applyRule(rule)[0]

    def testNoDivisionRule(self):
        x, y = self.x, self.y
        rule = Rules.noDivisionRule

        # symbol / symbol
        self.assertEqual(self._rewrite(x / y, rule), x * y ** -1)

        # constant / symbol
        self.assertEqual(self._rewrite(2 / x, rule), 2 * x ** -1)

        # symbol / constant
        self.assertEqual(self._rewrite(x / 3, rule), x * BinaryExpression(3, "^", -1))

        # constant / constant
        quotient = BinaryExpression(3, "/", 2)
        self.assertEqual(self._rewrite(quotient, rule), 3 * BinaryExpression(2, "^", -1))

        # A product contains no division, so the rule must report that it did not apply.
        self.assertEqual(self._applies(x * y, rule), False)

    def testNoSubtractionRule(self):
        x, y = self.x, self.y
        rule = Rules.noSubtractionRule

        # symbol - symbol
        self.assertEqual(self._rewrite(x - y, rule), x + -1 * y)

        # constant - symbol
        self.assertEqual(self._rewrite(2 - x, rule), 2 + -1 * x)

        # symbol - constant
        self.assertEqual(self._rewrite(x - 3, rule), x + BinaryExpression(-1, "*", 3))

        # constant - constant
        difference = BinaryExpression(3, "-", 2)
        self.assertEqual(self._rewrite(difference, rule), 3 + BinaryExpression(-1, "*", 2))

        # A sum contains no subtraction, so the rule must report that it did not apply.
        self.assertEqual(self._applies(x + y, rule), False)

    def testNoConstRightMultRule(self):
        x, y = self.x, self.y
        rule = Rules.noConstRightMultRule

        # Integer and float constants on the right are both moved to the left.
        self.assertEqual(self._rewrite(x * 5, rule), 5 * x)
        self.assertEqual(self._rewrite(x * 5.0, rule), 5.0 * x)

        # Constant already on the left, or no constant at all: rule does not apply.
        self.assertEqual(self._applies(2 * x, rule), False)
        self.assertEqual(self._applies(x * y, rule), False)

    def testConstMultEvalRule(self):
        x = self.x
        rule = Rules.constMultEvalRule

        # A single constant-times-symbol product has nothing to evaluate.
        self.assertEqual(self._applies(2 * x, rule), False)

        # constant * constant collapses to a plain number.
        self.assertEqual(self._rewrite(BinaryExpression(2, "*", 5), rule), 10)

        # A left constant is merged with the left constant of the nested product.
        self.assertEqual(self._rewrite(2 * (3 * x), rule), 6 * x)

    def testMultiplicationExpandRule(self):
        x, y, z = self.x, self.y, self.z
        rule = Rules.multiplicationExpandRule

        # Integer multiplier on the left of a sum.
        self.assertEqual(self._rewrite(2 * (x + y), rule), x * 2 + y * 2)

        # Float multiplier on the left of a sum.
        self.assertEqual(self._rewrite(5.5 * (x + y), rule), x * 5.5 + y * 5.5)

        # Symbolic multiplier on the right of a sum.
        self.assertEqual(self._rewrite((x + y) * z, rule), x * z + y * z)

        
        
class TestRuleset(unittest.TestCase):
    """End-to-end checks of Rules.defaultRuleset on small expressions."""

    # Symbols shared by every test in this class.
    x, y, z = Symbol("x"), Symbol("y"), Symbol("z")

    @staticmethod
    def _simplify(expression):
        # Run the whole default ruleset over the expression and hand back the result.
        return expression.applyRuleset(Rules.defaultRuleset)

    def testCase0(self):
        # Numeric factors on both sides of the symbol are expected to merge into 3.0 * x.
        simplified = self._simplify(2 * self.x * 1.5)
        self.assertEqual(simplified, 3.0 * self.x)

    def testCase1(self):
        # A product of two symbols is expected to come back unchanged.
        simplified = self._simplify(self.x * self.y)
        self.assertEqual(simplified, self.x * self.y)

    def testCase2(self):
        # A constant is expected to be distributed over the sum: 10 * x + 50.
        simplified = self._simplify(10 * (self.x + 5))
        self.assertEqual(simplified, 10 * self.x + 50)
        
        
        

# Run the test cases defined above from inside this session.
#   argv=['']    - supplies only a placeholder program name, so unittest does not
#                  try to interpret the host process's own sys.argv.
#   verbosity=2  - report one line per test (see the recorded output below).
#   exit=False   - do not call sys.exit() after the run finishes.
unittest.main(argv=[''], verbosity=2, exit=False)
# NOTE(review): the trailing `pass` presumably keeps the object returned by
# unittest.main from being echoed as the cell's result — confirm.
pass

testConstMultEvalRule (__main__.TestRules) ... ok
testMultiplicationExpandRule (__main__.TestRules) ... ok
testNoConstRightMultRule (__main__.TestRules) ... ok
testNoDivisionRule (__main__.TestRules) ... ok
testNoSubtractionRule (__main__.TestRules) ... ok
testCase0 (__main__.TestRuleset) ... ok
testCase1 (__main__.TestRuleset) ... ok
testCase2 (__main__.TestRuleset) ... ok

----------------------------------------------------------------------
Ran 8 tests in 0.008s

OK
