<a href="https://colab.research.google.com/github/LiFeLeSS5858/KA1/blob/main/%D0%9A%D0%906.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from IPython.display import display, Math
import matplotlib
import matplotlib.pyplot as plt

In [2]:
def add(x, y):
  """Build a binary addition node ``["+", x, y]``."""
  node = ["+", x, y]
  return node

def mul(x, y):
  """Build a binary multiplication node ``["*", x, y]``."""
  node = ["*", x, y]
  return node

def div(x, y):
  """Build a binary division node ``["/", x, y]`` (numerator x, denominator y)."""
  node = ['/', x, y]
  return node

def sub(x, y):
  """Build a binary subtraction node ``["-", x, y]`` (x minus y)."""
  node = ['-', x, y]
  return node

def pow(x, y):
  """Build a power node ``["^", x, y]`` (x raised to y).

  NOTE: this shadows the built-in ``pow`` for the rest of the notebook;
  later cells (differentiate, expand, integrate) call ``pow`` expecting
  this symbolic node constructor, so the name must stay.
  """
  node = ['^', x, y]
  return node

def neg(x):
  """Build a unary-minus node ``["--", x]``.

  The tag "--" keeps unary negation distinct from binary subtraction "-".
  """
  node = ['--', x]
  return node

In [3]:
def fsqrt(x):
  """Build a square-root node ``["sqrt", x]`` on a raw tree (no Expression wrapping)."""
  node = ["sqrt", x]
  return node

def fsin(x):
  """Build a sine node ``["sin", x]`` on a raw tree (no Expression wrapping)."""
  node = ["sin", x]
  return node

def fcos(x):
  """Build a cosine node ``["cos", x]`` on a raw tree (no Expression wrapping)."""
  node = ["cos", x]
  return node

def fln(x):
  """Build a natural-logarithm node ``["ln", x]`` on a raw tree (no Expression wrapping)."""
  node = ["ln", x]
  return node

In [4]:
def check(x):
  """Coerce x into an Expression.

  An Expression is returned as-is (same object); anything else (number,
  symbol name, raw list tree) is wrapped in a new Expression.
  """
  return x if isinstance(x, Expression) else Expression(x)

class Expression:
  """Operator-overloading wrapper around a raw expression tree.

  ``self.f`` holds the tree: a number, a symbol name (str), or a list
  ``[op, *args]`` built by the node constructors (``add``, ``mul``, ...).
  Python arithmetic operators build new trees; the remaining methods
  delegate to the module-level tree functions (``substitute``,
  ``evaluate``, ``expr2latex``, ``differentiate``, ``simplify``,
  ``expand``, ``integrate``).
  """

  def __init__(self, f):
    self.f = f

  # --- binary operators: self <op> x (x is coerced with check) ---
  def __add__(self, x):
    x = check(x)
    return Expression(add(self.f, x.f))

  def __sub__(self, x):
    x = check(x)
    return Expression(sub(self.f, x.f))

  def __mul__(self, x):
    x = check(x)
    return Expression(mul(self.f, x.f))

  def __truediv__(self, x):
    x = check(x)
    return Expression(div(self.f, x.f))

  def __pow__(self, x):
    x = check(x)
    return Expression(pow(self.f, x.f))

  # --- reflected operators: x <op> self, e.g. 1 + x or 2 ** x ---
  def __radd__(self, x):
    x = check(x)
    return Expression(add(x.f, self.f))

  def __rsub__(self, x):
    x = check(x)
    return Expression(sub(x.f, self.f))

  def __rmul__(self, x):
    x = check(x)
    return Expression(mul(x.f, self.f))

  def __rtruediv__(self, x):
    x = check(x)
    return Expression(div(x.f, self.f))

  def __rpow__(self, x):
    x = check(x)
    return Expression(pow(x.f, self.f))

  # --- unary operators ---
  def __neg__(self):
    return Expression(neg(self.f))

  def __pos__(self):
    # Unary plus is the identity; without this method ``+x`` raised TypeError.
    return Expression(self.f)

  # --- symbolic operations ---
  def subs(self, var, val):
    """Return a new Expression with every occurrence of var replaced by val."""
    return Expression(substitute(self.f, var, val))

  def evalf(self):
    """Numerically evaluate the tree (delegates to ``evaluate``)."""
    return evaluate(self.f)

  def diff(self, var):
    """Derivative with respect to the symbol var."""
    var = check(var).f
    return Expression(differentiate(self.f, var))

  def simp(self):
    """Return a simplified copy (delegates to ``simplify``)."""
    return Expression(simplify(self.f))

  def expn(self):
    """Return an expanded copy (delegates to ``expand``)."""
    return Expression(expand(self.f))

  def intg(self, var):
    """Antiderivative with respect to the symbol var."""
    var = check(var).f
    return Expression(integrate(self.f, var))

  # --- display ---
  def __str__(self):
    return expr2latex(self.f)

  def __repr__(self):
    return "Expression(" + repr(self.f) + ")"

  def _repr_latex_(self):
    # IPython rich-display hook: lets an Expression left as the last line of
    # a cell render as math. Returning None falls back to the text repr.
    latex = expr2latex(self.f)
    if latex is None:
      return None
    return "$" + latex + "$"

  def show(self):
    """Render the expression as LaTeX math in the notebook output."""
    display(Math(expr2latex(self.f)))
  

In [5]:
def sqrt(x):
  """Symbolic square root of x (number, symbol name or Expression) as an Expression."""
  return Expression(fsqrt(check(x).f))

def sin(x):
  """Symbolic sine of x (number, symbol name or Expression) as an Expression."""
  tree = check(x).f
  return Expression(fsin(tree))

def cos(x):
  """Symbolic cosine of x (number, symbol name or Expression) as an Expression."""
  tree = check(x).f
  return Expression(fcos(tree))

def ln(x):
  """Symbolic natural logarithm of x (number, symbol name or Expression) as an Expression."""
  tree = check(x).f
  return Expression(fln(tree))


In [6]:
def substitute(f, var, val):
  """Return a copy of tree f with every sub-tree equal to var replaced by val.

  var and val may be Expressions or raw trees; both are unwrapped with
  ``check(...).f``. Matching is structural equality (``==``), so var may be
  a symbol name or a whole sub-tree such as ``["^", "x", 2]``. The operator
  tag of each node is copied unchanged; only the operands are visited.
  """
  target = check(var).f
  replacement = check(val).f
  if f == target:
    return replacement
  if not isinstance(f, list):
    return f
  return [f[0]] + [substitute(arg, target, replacement) for arg in f[1:]]


In [7]:
import math

def evaluate(f):
  """Numerically evaluate expression tree f.

  Numbers evaluate to themselves; list nodes are evaluated recursively.
  Supported operators: "+", "-", "*", "/", "^", "--" (unary minus),
  "sqrt", "sin", "cos", "ln" (natural log, via ``math.log``).

  Returns None when the tree cannot be reduced to a number: a free symbol
  (str) is still present, or an operator without a numeric rule (e.g. an
  unevaluated "int" node) appears anywhere. Math errors such as division by
  zero or sqrt/ln of a non-positive number propagate as the usual Python
  exceptions (ZeroDivisionError / ValueError).
  """
  if isinstance(f, (int, float)):
    return f
  if not isinstance(f, list):
    # Free symbol or unknown leaf: no numeric value.
    return None

  op = f[0]
  args = [evaluate(a) for a in f[1:]]
  # Previously e.g. ["+", "x", 1] raised TypeError (None + 1); now the
  # "not a number" result propagates consistently.
  if any(a is None for a in args):
    return None

  binary = {
    "+": lambda a, b: a + b,
    "-": lambda a, b: a - b,
    "*": lambda a, b: a * b,
    "/": lambda a, b: a / b,
    "^": lambda a, b: a ** b,
  }
  unary = {
    "--": lambda a: -a,
    "sqrt": math.sqrt,
    "sin": math.sin,
    "cos": math.cos,
    "ln": math.log,
  }
  if op in binary and len(args) == 2:
    return binary[op](*args)
  if op in unary and len(args) == 1:
    return unary[op](*args)
  return None
  

In [8]:
def plot(expr, var, a, b, n):
  """Plot the Expression expr as a function of var on [a, b].

  The interval is split into n equal sub-intervals and expr is evaluated at
  the n + 1 grid points a, a + dx, ..., b via ``expr.subs(var, x).evalf()``.
  Points whose evaluation is None (not reducible to a number) are plotted
  as NaN, i.e. left as gaps in the curve.

  Raises:
    ValueError: if n is not a positive integer.
  """
  if not isinstance(n, int) or n < 1:
    # Guards the (b - a) / n division and the grid construction.
    raise ValueError("plot: n must be a positive integer number of subintervals")
  dx = (b - a) / n
  X = [a + dx * i for i in range(n + 1)]
  Y = []
  for xv in X:
    yv = expr.subs(var, xv).evalf()
    Y.append(float("nan") if yv is None else yv)

  fig, ax = plt.subplots(figsize=(12, 9))
  ax.plot(X, Y, color="red", lw=5)
  ax.set_xlabel(str(var))
  ax.grid(True)
  plt.show()


In [9]:
def symbols(vars):
  """Create one symbolic Expression per name in the string ``vars``.

  Names may be separated by whitespace and/or commas ("x y" or "x, y").
  Always returns a tuple, so the result can be unpacked (``x, y = ...``,
  ``x, = ...``), indexed, or iterated more than once, unlike the one-shot
  iterator a bare ``map`` would give.
  """
  names = vars.replace(",", " ").split()
  return tuple(Expression(name) for name in names)

x, y = symbols("x y")

In [10]:
def expr2latex(x):
  """Render an expression tree as a LaTeX string.

  Leaves: numbers are rendered with ``str``, strings (symbols) verbatim.
  Nodes are lists ``[op, *args]`` with the tags produced by the node
  constructors ("+", "-", "*", "/", "^", "--", "sqrt", "sin", "cos", "ln")
  plus "int" for an unevaluated integral ``["int", integrand, var]``.
  Operands are wrapped in ``\\left( ... \\right)`` wherever leaving them
  bare would change the meaning of the rendered formula, e.g.
  ``a - (b + c)`` or ``(a - b) * c``. Unknown operators are rendered as
  ``\\operatorname{op}(args)`` rather than returning None.
  """
  def wrap(s):
    # Scalable parentheses around an already-rendered operand.
    return "\\left(" + s + "\\right)"

  def is_node(e, ops):
    return isinstance(e, list) and e[0] in ops

  def is_negative_number(e):
    return isinstance(e, (int, float)) and e < 0

  if isinstance(x, (int, float)):
    return str(x)
  if isinstance(x, str):
    return x
  if not isinstance(x, list):
    return str(x)

  op = x[0]
  if op == "+":
    return expr2latex(x[1]) + "+" + expr2latex(x[2])
  if op == "-":
    right = expr2latex(x[2])
    # a - (b + c), a - (b - c) and a - (-b) need parentheses to keep their meaning.
    if is_node(x[2], ("+", "-", "--")) or is_negative_number(x[2]):
      right = wrap(right)
    return expr2latex(x[1]) + "-" + right
  if op == "*":
    left = expr2latex(x[1])
    if is_node(x[1], ("+", "-")):
      left = wrap(left)
    right = expr2latex(x[2])
    if is_node(x[2], ("+", "-", "--")) or is_negative_number(x[2]):
      right = wrap(right)
    return left + "\\cdot " + right
  if op == "/":
    return "\\dfrac{" + expr2latex(x[1]) + "}{" + expr2latex(x[2]) + "}"
  if op == "^":
    left = expr2latex(x[1])
    if isinstance(x[1], list) or is_negative_number(x[1]):
      left = wrap(left)
    return "{" + left + "}^{" + expr2latex(x[2]) + "}"
  if op == "--":
    a = expr2latex(x[1])
    if isinstance(x[1], list) or is_negative_number(x[1]):
      a = wrap(a)
    return "-{" + a + "}"
  if op == "sqrt":
    return "\\sqrt{" + expr2latex(x[1]) + "}"
  functions = {"sin": "\\sin", "cos": "\\cos", "ln": "\\ln"}
  if op in functions:
    a = expr2latex(x[1])
    if isinstance(x[1], list):
      a = wrap(a)
    return functions[op] + "{" + a + "}"
  if op == "int":
    a = expr2latex(x[1])
    if is_node(x[1], ("+", "-")):
      a = wrap(a)
    # The integration variable is rendered too, so non-str variables work.
    return "\\int{" + a + "}d" + expr2latex(x[2])

  args = ", ".join(expr2latex(arg) for arg in x[1:])
  return "\\operatorname{" + str(op) + "}" + wrap(args)


In [11]:
def differentiate(f, x):
  """Return the derivative of expression tree f with respect to the symbol x.

  The result is an unsimplified tree built with the node constructors;
  pass it through ``simplify`` to tidy it up. Numbers (int or float) and
  symbols other than x have derivative 0. An unevaluated integral
  ``["int", g, x]`` taken with respect to the same x differentiates back
  to g.

  Raises:
    ValueError: for a leaf type or operator with no differentiation rule
      (previously such input silently produced None inside the tree).
  """
  if isinstance(f, (int, float)):
    return 0
  if f == x:
    return 1
  if isinstance(f, str):
    return 0
  if not isinstance(f, list):
    raise ValueError("differentiate: unsupported leaf " + repr(f))

  op = f[0]
  if op == "+":
    return add(differentiate(f[1], x), differentiate(f[2], x))
  if op == "-":
    return sub(differentiate(f[1], x), differentiate(f[2], x))
  if op == "*":
    # Product rule: (uv)' = u'v + uv'
    u, v = f[1], f[2]
    return add(mul(differentiate(u, x), v), mul(u, differentiate(v, x)))
  if op == "/":
    # Quotient rule: (u/v)' = (u'v - uv') / v^2
    u, v = f[1], f[2]
    u1, v1 = differentiate(u, x), differentiate(v, x)
    return div(sub(mul(u1, v), mul(u, v1)), pow(v, 2))
  if op == "^":
    # General rule: (u^v)' = u^v * v' * ln(u) + v * u^(v-1) * u'
    u, v = f[1], f[2]
    u1, v1 = differentiate(u, x), differentiate(v, x)
    s1 = mul(pow(u, v), mul(v1, fln(u)))
    s2 = mul(v, mul(pow(u, sub(v, 1)), u1))
    return add(s1, s2)
  if op == "--":
    return neg(differentiate(f[1], x))
  if op == "ln":
    # (ln u)' = u' / u
    return div(differentiate(f[1], x), f[1])
  if op == "sin":
    # (sin u)' = cos(u) * u'
    return mul(fcos(f[1]), differentiate(f[1], x))
  if op == "cos":
    # (cos u)' = -sin(u) * u'
    return mul(neg(fsin(f[1])), differentiate(f[1], x))
  if op == "sqrt":
    # (sqrt u)' = u' / (2 sqrt u)
    return div(differentiate(f[1], x), mul(2, fsqrt(f[1])))
  if op == "int" and f[2] == x:
    # Fundamental theorem of calculus: d/dx of an antiderivative in x.
    return f[1]
  raise ValueError("differentiate: unsupported operator " + repr(op))
      

In [12]:
def simplify(f):
  """Return a locally simplified copy of expression tree f.

  Children are simplified first (bottom-up), then one set of rules is
  applied to the node itself:
    +    : fold numbers; 0 + a = a + 0 = a; a + a = 2*a; a + (-b) = a - b
    -    : fold numbers; 0 - a = -a; a - 0 = a; a - a = 0; a - (-b) = a + b
    *    : fold numbers; 0 annihilates; 1 is the identity; a*a = a^2
    /    : a/1 = a; 0/b = 0; exactly divisible integer quotients are folded
    ^    : int^positive-int is folded; a^1 = a; a^0 = 1; 1^b = 1
    --   : negate numbers; -(-a) = a
    sqrt : perfect squares of non-negative ints are folded
  Leaves (numbers, symbols) are returned unchanged.
  """
  if not isinstance(f, list):
    return f
  r = [f[0]] + [simplify(arg) for arg in f[1:]]
  op = r[0]

  def is_num(v):
    return isinstance(v, (int, float))

  if op == "+":
    a, b = r[1], r[2]
    if is_num(a) and is_num(b):
      return a + b
    if a == 0:
      return b
    if b == 0:
      return a
    if a == b:
      return mul(2, a)
    if isinstance(b, list) and b[0] == "--":  # a + (-c) -> a - c
      return simplify(sub(a, b[1]))
  elif op == "-":
    a, b = r[1], r[2]
    if is_num(a) and is_num(b):
      return a - b
    if a == 0:
      return simplify(neg(b))
    if b == 0:
      return a
    if a == b:
      return 0
    if isinstance(b, list) and b[0] == "--":  # a - (-c) -> a + c
      return simplify(add(a, b[1]))
  elif op == "*":
    a, b = r[1], r[2]
    if is_num(a) and is_num(b):
      return a * b
    if a == 0 or b == 0:
      return 0
    if a == 1:
      return b
    if b == 1:
      return a
    if a == b:
      return pow(a, 2)
  elif op == "/":
    a, b = r[1], r[2]
    if b == 1:
      return a
    if a == 0 and b != 0:
      return 0
    # Fold only exact integer quotients so results stay exact (1/2 is kept).
    if isinstance(a, int) and isinstance(b, int) and b != 0 and a % b == 0:
      return a // b
  elif op == "^":
    a, b = r[1], r[2]
    if isinstance(a, int) and isinstance(b, int) and b > 0:
      return a ** b
    if b == 1:
      return a
    if b == 0 or a == 1:
      return 1
  elif op == "--":
    a = r[1]
    if is_num(a):
      return -a
    if isinstance(a, list) and a[0] == "--":
      return a[1]
  elif op == "sqrt":
    a = r[1]
    # math.sqrt raises ValueError on negatives, so only fold a >= 0.
    if isinstance(a, int) and a >= 0:
      root = int(math.sqrt(a))
      if root * root == a:
        return root
  return r

In [13]:
def expand(f):
  """Return a copy of expression tree f with products and powers of sums expanded.

  Children are expanded first (bottom-up), then:
    (u ± v) * w -> u*w ± v*w          w * (u ± v) -> w*u ± w*v
    (u + v)^2   -> u^2 + 2uv + v^2    (u - v)^2   -> u^2 - 2uv + v^2
    (u ± v)^n   -> expand((u ± v)^(n-1) * (u ± v))   for integer n > 2
  Here "±" means a binary "+" or "-" node; unary "--" is left alone. The
  result is not simplified; follow with ``simplify`` to fold constants.
  """
  if not isinstance(f, list):
    return f
  r = [f[0]] + [expand(arg) for arg in f[1:]]
  op = r[0]

  def is_sum(e):
    return isinstance(e, list) and e[0] in ("+", "-")

  def join(sign, a, b):
    # Rebuild a sum or a difference with the same sign as the source node.
    return add(a, b) if sign == "+" else sub(a, b)

  if op == "*":
    left, right = r[1], r[2]
    if is_sum(left):
      return join(left[0], expand(mul(left[1], right)), expand(mul(left[2], right)))
    if is_sum(right):
      return join(right[0], expand(mul(left, right[1])), expand(mul(left, right[2])))
  if op == "^" and is_sum(r[1]):
    base, n = r[1], r[2]
    u, v = base[1], base[2]
    if n == 2:
      uu = expand(pow(u, 2))
      cross = expand(mul(2, mul(u, v)))
      vv = expand(pow(v, 2))
      if base[0] == "+":
        return add(uu, add(cross, vv))
      return add(sub(uu, cross), vv)
    if isinstance(n, int) and n > 2:
      # Peel one factor off and distribute; recursion terminates as n shrinks.
      return expand(mul(expand(pow(base, n - 1)), base))
  return r


In [14]:
def integral(f, x):
  """Build an unevaluated integral node ``["int", f, x]`` meaning ∫ f dx."""
  node = ["int"]
  node.extend((f, x))
  return node

def contains(f, x):
  """Return True if x occurs anywhere in tree f, including as f itself.

  Matching uses ``==``, so x may be a symbol, a whole sub-tree, or an
  operator tag: list nodes are scanned element by element, tag included,
  which is how ``contains(f, "sin")`` detects a sine node.
  """
  if f == x:
    return True
  if isinstance(f, list):
    return any(contains(item, x) for item in f)
  return False



In [15]:
def _integrate_substituted(h, u, x, t):
  """Integrate h dx through the substitution t = u; return None if that fails.

  Replaces every occurrence of the sub-tree u in h by the fresh symbol t,
  integrates the result in t, and substitutes u back for t. The caller must
  already have checked that the other factor of the integrand is (a
  constant multiple of) du/dx.

  Returns None when x survives the substitution (h is not a function of u
  alone) or when the inner integral stays unevaluated.
  """
  g = substitute(h, u, t)
  if contains(g, x):
    return None
  G = integrate(g, t)
  if contains(G, "int"):
    return None
  return substitute(G, t, u)


def _integrate_product(p, q, x):
  """Try u-substitution patterns on the product p * q (both factors depend on x).

  Each pattern recognises one factor as (a constant multiple of) the
  derivative of an inner function u and integrates the other factor as a
  function of u. Returns None when no pattern applies.
  """
  # x * h(x^2): du = 2x dx
  if p == x:
    F = _integrate_substituted(q, pow(x, 2), x, x + "t")
    if F is not None:
      return mul(div(1, 2), F)
  # h(sin x) * cos x: du = cos(x) dx
  if q == fcos(x):
    F = _integrate_substituted(p, fsin(x), x, x + "sin(t)")
    if F is not None:
      return F
  # h(cos x) * sin x: du = -sin(x) dx
  if q == fsin(x):
    F = _integrate_substituted(p, fcos(x), x, x + "cos(t)")
    if F is not None:
      return neg(F)
  # (1/x) * h(ln x): du = dx / x
  if p == div(1, x):
    F = _integrate_substituted(q, fln(x), x, x + "ln(t)")
    if F is not None:
      return F
  # (1/sqrt x) * h(sqrt x): du = dx / (2 sqrt x)
  if p == div(1, fsqrt(x)):
    F = _integrate_substituted(q, fsqrt(x), x, x + "sqrt(t)")
    if F is not None:
      return mul(2, F)
  # a^x * h(a^x) with an x-free base a: du = ln(a) a^x dx
  if isinstance(p, list) and p[0] == "^" and p[2] == x and not contains(p[1], x):
    F = _integrate_substituted(q, p, x, x + "a^x")
    if F is not None:
      return mul(div(1, fln(p[1])), F)
  # x^(n-1) * h(x^n) for an integer n != 0: du = n x^(n-1) dx
  inner = q[1] if isinstance(q, list) and len(q) > 1 else None
  if (isinstance(inner, list) and inner[0] == "^" and inner[1] == x
      and isinstance(inner[2], int) and inner[2] != 0
      and p == pow(x, inner[2] - 1)):
    F = _integrate_substituted(q, inner, x, x + "t^n")
    if F is not None:
      return mul(div(1, inner[2]), F)
  return None


def integrate(f, x):
  """Return an antiderivative of expression tree f with respect to the symbol x.

  Rules, tried in order:
    * f free of x                     -> f * x
    * x, cos x, sin x, sqrt x, x^c    -> table integrals (x^-1 -> ln x)
    * a^x with x-free a               -> a^x / ln(a)
    * sums, differences, unary minus  -> term by term
    * constant factor or constant denominator is pulled out; a/x -> a ln x
    * selected u-substitution patterns for products (``_integrate_product``)
  Anything not matched is returned as the unevaluated node
  ``["int", f, x]``. No constant of integration is added.
  """
  if not contains(f, x):
    return mul(f, x)
  if f == x:
    return mul(div(1, 2), pow(x, 2))
  if not isinstance(f, list):
    return integral(f, x)

  op = f[0]
  if op == "cos" and f[1] == x:
    return fsin(x)
  if op == "sin" and f[1] == x:
    return neg(fcos(x))
  if op == "sqrt" and f[1] == x:
    return mul(div(2, 3), mul(x, fsqrt(x)))
  if op == "^" and f[1] == x and not contains(f[2], x):
    if f[2] == -1:
      return fln(x)
    return mul(div(1, add(f[2], 1)), pow(x, add(f[2], 1)))
  if op == "^" and f[2] == x and not contains(f[1], x):
    return mul(div(1, fln(f[1])), f)
  if op in ("+", "-"):
    return [op, integrate(f[1], x), integrate(f[2], x)]
  if op == "--":
    return neg(integrate(f[1], x))
  if op == "*" and not contains(f[1], x):
    return mul(f[1], integrate(f[2], x))
  if op == "*" and not contains(f[2], x):
    return mul(f[2], integrate(f[1], x))
  if op == "/" and not contains(f[2], x):
    return mul(div(1, f[2]), integrate(f[1], x))
  if op == "/" and f[2] == x and not contains(f[1], x):
    return mul(f[1], fln(x))
  if op == "*":
    # Previously these heuristics indexed into operands without checking
    # their shape (IndexError on e.g. x*cos(x)) and did not verify the
    # du/dx factor (wrong results on e.g. (1/x)*ln(x)^2).
    found = _integrate_product(f[1], f[2], x)
    if found is not None:
      return found
  return integral(f, x)

In [36]:
# Expand and simplify x + 1, integrate in x, then simplify the antiderivative.
z = x + 1
z.show()
z.expn().simp().intg(x).simp().show()

<IPython.core.display.Math object>

<IPython.core.display.Math object>

In [17]:
# contains() on the raw tree: sin(x^2) has no "+" node, so this prints False.
z = sin(x ** 2)
z.show()
print(contains(z.f, "+"))

<IPython.core.display.Math object>

False


In [33]:
# Integrate in y. Wrapping 2 in Expression makes "/" build a symbolic 2/3
# node instead of performing Python float division.
z = Expression(2) / 3 + y ** 2
z.show()
z.intg(y).show()

<IPython.core.display.Math object>

<IPython.core.display.Math object>

In [34]:
# Term-by-term integral of x + 1 (sum rule, no expand/simplify step).
z = x + 1
z.show()
z.intg(x).show()

<IPython.core.display.Math object>

<IPython.core.display.Math object>

In [20]:
# Integral of a bare symbol with respect to itself.
z = y
z.show()
z.intg(y).show()

<IPython.core.display.Math object>

<IPython.core.display.Math object>

In [21]:
# Integral of sqrt(x), handled by the dedicated sqrt rule.
z = sqrt(x)
z.show()
z.intg(x).show()

<IPython.core.display.Math object>

<IPython.core.display.Math object>

In [22]:
# x^(-1) is special-cased to ln(x) by the power rule.
z = x ** (-1)
z.show()
z.intg(x).simp().show()

<IPython.core.display.Math object>

<IPython.core.display.Math object>

In [23]:
# Exponential with a constant base; 2 ** x builds the node via Expression.__rpow__.
z = 2 ** x
z.show()
z.intg(x).show()

<IPython.core.display.Math object>

<IPython.core.display.Math object>

In [24]:
# Mixed sum of constant, linear, trigonometric and sqrt terms, integrated term by term.
z = 1 - x + sin(x) + cos(x) - sqrt(x)
z.show()
z.intg(x).simp().show()

<IPython.core.display.Math object>

<IPython.core.display.Math object>

In [25]:
# integrate() has no rule for ln, so these terms stay as unevaluated integrals.
z = 2 * ln(x) + 3 * ln(x + 1)
z.show()
z.intg(x).show()

<IPython.core.display.Math object>

<IPython.core.display.Math object>

In [26]:
# Division by a constant: the x-free denominator is pulled out of the integral.
z = cos(x) / 3
z.show()
z.intg(x).show()

<IPython.core.display.Math object>

<IPython.core.display.Math object>

In [27]:
# Expand (x + 1)^2 into a polynomial first, then integrate term by term.
z = (x + 1) ** 2
z.show()
z = z.expn().simp()
z.show()
z.intg(x).simp().show()

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

In [28]:
# Expand and simplify (cos x + 1)^2 before integrating; the cos(x)^2 term
# matches no rule and remains an unevaluated integral.
z = (cos(x) + 1) ** 2
z.show()
z.expn().simp().intg(x).simp().show()

<IPython.core.display.Math object>

<IPython.core.display.Math object>

In [29]:
# x * h(x^2) pattern: integrated by substituting t = x^2.
z = x * (cos(x ** 2) + (x**2) ** 3)
z.show()
z.intg(x).show()

<IPython.core.display.Math object>

<IPython.core.display.Math object>

In [30]:
# Nested x * h(x^2) substitution: after t = x^2 the integrand matches the pattern again.
z = x * (x ** 2 * cos((x ** 2) ** 2))
z.show()
z.intg(x).show()

<IPython.core.display.Math object>

<IPython.core.display.Math object>

In [31]:
# Combination: a constant times an x * h(x^2) term, a power over a constant,
# and a constant multiple of sqrt(x).
z = 2 * (x * sin(x ** 2)) + x ** 3 / 4 - 5 * sqrt(x)
z.show()
z.intg(x).simp().show()

<IPython.core.display.Math object>

<IPython.core.display.Math object>