In [1]:
class SymbolObjects:
    def __add__(self, other):
        if isinstance(other,SymbolObjects):
            return Add(self,other)
        else:
            return Add(self,Symbol(other))
    
    def __mul__(self,other):
        if isinstance(other,SymbolObjects):
            return Mult(self,other)
        else:
            return Mult(self,Symbol(other))

In [2]:
class Symbol(SymbolObjects):
    def __init__(self,s):
        self.value = s
        
        if isinstance(s,str):
            self.type = "symbol"
        elif isinstance(s,int):
            self.type = "int"
        elif isinstance(s,float):
            self.type = "float"
        else:
            raise Exception("Unexpected symbol")
        
    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return self.__str__()
    
    def derivative(self,dx):
        if self.type != "symbol":
            return Symbol(0)
        
        if dx == self:
            return Symbol(1)
        else:
            return Symbol(0)
    
    def __eq__(self, other):
        if not isinstance(other,SymbolObjects):
            return self == Symbol(other)
        
        if self.type != other.type:
            return False
        
        return self.value == other.value
    
    def simplify(self):
        return self

In [3]:
def symbols(s):
    arr = s.split()
    return [Symbol(a) for a in arr]

def symbolize(xs):
    if isinstance(xs,SymbolObjects):
        return xs
    
    if not isinstance(xs,list):
        return Symbol(xs)
    
    for i in range(len(xs)):
        if not isinstance(xs[i],SymbolObjects):
            xs[i] = Symbol(xs[i])
    return xs

def simplify(eq):
    eq = eq.simplify()
    s0 = str(eq)
    
    while True:
        eq = eq.simplify()
        s1 = str(eq)
        if s0 == s1:
            break
        else:
            s0 = s1
    return eq

In [66]:
class Add(SymbolObjects):
    def __init__(self,*symbols):
        if len(symbols) < 2:
            raise Exception("Too few arugments")
        
        symbols = list(symbols)
        
        temp = []
        for s in symbols:
            s = symbolize(s)
            if isinstance(s,Add):
                temp += s.symbols
            else:
                temp.append(s)
        
        self.symbols = []
        n=0
        for s in temp:
            if isinstance(s,Symbol) and s.type == "int":
                n += s.value
            else:
                self.symbols.append(s)
        self.symbols = [Symbol(n)]+sorted(self.symbols, key=lambda s: str(s))
        
    def __repr__(self):
        return self.__str__()
    
    def __str__(self):
        ret = str(self.symbols[0])
        
        for s in self.symbols[1:]:
            ret = "%s+%s" % (ret,str(s))
        
        return "(" + ret + ")"
    
    def derivative(self,dx):
        ret = ""
        d_symbols = [s.derivative(dx) for s in self.symbols]
        return Add(*d_symbols)
    
    def simplify(self):
        # simplify all leaf element
        symbols = [ s.simplify() for s in self.symbols ]
        
        # exclude 0
        symbols = [ s for s in symbols if not s == 0]
        
        n = len(symbols)
        if n==0:
            return Symbol(0)
        elif n==1:
            return symbols[0]
        else:
            return Add(*symbols)

In [74]:
class Mult(SymbolObjects):
    def __init__(self,*symbols):
        if len(symbols) < 2:
            raise Exception("Too few arugments")
        
        symbols = list(symbols)
        temp = []
        for s in symbols:
            s = symbolize(s)
            if isinstance(s,Mult):
                temp += s.symbols
            else:
                temp.append(s)
        
        self.symbols = []
        n=1
        for s in temp:
            if isinstance(s,Symbol) and s.type == "int":
                n *= s.value
            else:
                self.symbols.append(s)
        self.symbols = [Symbol(n)]+sorted(self.symbols, key=lambda s: str(s))
        
    def __repr__(self):
        return self.__str__()
    
    def __str__(self):
        ret = str(self.symbols[0])
        
        for s in self.symbols[1:]:
            ret = "%s*%s" % (ret,str(s))
        
        return ret
    
    def derivative(self,dx):
        diffs = []
        d_symbols = [s.derivative(dx) for s in self.symbols]
        
        for i in range(len(self.symbols)):
            rest = self.symbols[:i]+self.symbols[i+1:]
            # (abc)'=a'bc+ab'c+abc'
            diffs.append(Mult(d_symbols[i],*rest))
            
        return Add(*diffs)
    
    def simplify(self):
        # simplify all leaf element
        symbols = [ s.simplify() for s in self.symbols ]
        
        # search 0 
        n = len([s for s in symbols if s == 0])
        if n != 0:
            return Symbol(0)
        
        # exclude 1
        symbols = [ s for s in symbols if not s == 1]
        
        n = len(symbols)
        if n==0:
            return Symbol(1)
        elif n==1:
            return symbols[0]
        else:
            return Mult(*symbols)

In [75]:
class Cos(SymbolObjects):
    def __init__(self,symbol):
        
        self.symbol = symbolize(symbol)
        
    def __repr__(self):
        return self.__str__()
    
    def __str__(self):
        return "cos(%s)" % str(self.symbol)
    
    def derivative(self,dx):
        diff = self.symbol.derivative(dx)
        
        return Mult(-1,Sin(self.symbol),diff)
    
    def simplify(self):
        symbol = self.symbol.simplify()
        
        if symbol == 0:
            return Symbol(1)
        else:
            return Cos(symbol)

In [76]:
class Sin(SymbolObjects):
    def __init__(self,symbol):
        
        self.symbol = symbolize(symbol)
        
    def __repr__(self):
        return self.__str__()
    
    def __str__(self):
        return "sin(%s)" % str(self.symbol)
    
    def derivative(self,dx):
        diff = self.symbol.derivative(dx)
        
        return Mult(Cos(self.symbol),diff)
    
    def simplify(self):
        symbol = self.symbol.simplify()
        
        if symbol == 0:
            return Symbol(0)
        else:
            return Sin(symbol)

In [77]:
class Exp(SymbolObjects):
    def __init__(self,symbol):
        
        self.symbol = symbolize(symbol)
        
    def __repr__(self):
        return self.__str__()
    
    def __str__(self):
        return "exp(%s)" % str(self.symbol)
    
    def derivative(self,dx):
        diff = self.symbol.derivative(dx)
        
        return Mult(Exp(self.symbol),diff)
    
    def simplify(self):
        symbol = self.symbol.simplify()
        
        if symbol == 0:
            return Symbol(1)
        else:
            return Exp(symbol)

In [78]:
def get_jacobi(fs,xs):
    ret = []

    for f in fs:
        temp = []
        for x in xs:
            df = f.derivative(x)
            df = simplify(df)
            temp.append(df)
        ret.append(temp)
    return ret

x,y,z = symbols("x y z")
get_jacobi([x+y,x*y,x*x*x+y*y],[x,y])

[[1, 1], [y, x], [(0+1*x*x+1*x*x+1*x*x), (0+y+y)]]

In [79]:
x,y,z = symbols("x y z")
y = Cos(x)*Sin(y)+Sin(x)+x+Exp(x)
y

(0+1*cos(x)*sin(y)+exp(x)+sin(x)+x)

In [80]:
simplify(y.derivative(x))

(1+-1*sin(x)*sin(y)+cos(x)+exp(x))

In [81]:
dy = y.derivative(x)

In [82]:
dy

(1+-1*sin(x)*sin(y)+0*cos(x)*cos(y)+0*cos(x)*sin(y)+1*cos(x)+1*exp(x))

In [83]:
dy.simplify()

(1+-1*sin(x)*sin(y)+cos(x)+exp(x))

In [84]:
x,y,z = symbols("x y z")
(x+y)*z*y*5*5*5

125*(0+x+y)*y*z

In [None]:
Symbol(0)==0

In [None]:
a = Mult(1,0)

In [None]:
a.simplify()

In [None]:
simplify(Mult(Cos(x),0,Sin(x)))