In [None]:
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap
from scipy.optimize import root_scalar
import sympy

In [None]:
x = sympy.Symbol('x')
func1_solutions = sympy.solve(func1_expr, x)
func1_solutions = [abs(sol.evalf()) for sol in func1_solutions]
func1_solutions.sort()
func1_solution = func1_solutions[0]

In [24]:
solutions

[2.44948974278318, 2.44948974278318]

## Define Functions

In [44]:
def func1(x):
   return jnp.square(x) - 6
func1_expr = x**2 - 6
func1_solutions = sympy.solve(func1_expr, x)
func1_solutions = [sol.evalf() for sol in func1_solutions if sol.is_real]
func1_solutions.sort()
func1_solution = func1_solutions[0]

def func2(x):
   return jnp.square(x) - (2*x) - 4
func2_expr = x**2 - 2*x - 4
func2_solutions = sympy.solve(func2_expr, x)
func2_solutions = [sol.evalf() for sol in func2_solutions if sol.is_real]
func2_solutions.sort()
func2_solution = func2_solutions[0]

def func3(x):
   return -jnp.square(x) - (2*x)
func3_expr = -x**2 - 2*x
func3_solutions = sympy.solve(func3_expr, x)
func3_solutions = [sol.evalf() for sol in func3_solutions if sol.is_real]
func3_solutions.sort()
func3_solution = func3_solutions.pop()

def func4(x):
   return jnp.power(x, 3) + 2 * jnp.power(x, 2) + (4*x) + 3
func4_expr = x**3 + 2*x**2 + 4*x + 3
func4_solutions = sympy.solve(func4_expr, x)
func4_solutions = [sol.evalf() for sol in func4_solutions if sol.is_real]
func4_solutions.sort()
func4_solution = func4_solutions[0]

def func5(x):
   return -jnp.power(x, 3) - jnp.power(x, 2) +  (5 * x)
func5_expr = -x**3 - x**2 + 5*x
func5_solutions = sympy.solve(func5_expr, x)
func5_solutions = [sol.evalf() for sol in func5_solutions if sol.is_real]
func5_solutions.sort()
func5_solution = func5_solutions[0]

def func6(x):
   return jnp.cos(x)
func6_expr = sympy.cos(x)
func6_solutions = sympy.solve(func6_expr, x)
func6_solutions = [sol.evalf() for sol in func6_solutions if sol.is_real]
func6_solutions.sort()
func6_solution = func6_solutions[0]

def func7(x):
   return -4 * jnp.sin(x) + jnp.cos(x)
func7_expr = -4 * sympy.sin(x) + sympy.cos(x)
func7_solutions = sympy.solve(func7_expr, x)
func7_solutions = [sol.evalf() for sol in func7_solutions if sol.is_real]
func7_solutions.sort()
func7_solution = func7_solutions[0]

def func8(x):
   return (2*x) + 1
func8_expr = 2*x + 1
func8_solutions = sympy.solve(func8_expr, x)
func8_solutions = [sol.evalf() for sol in func8_solutions if sol.is_real]
func8_solutions.sort()
func8_solution = func8_solutions[0]

functions = [func1, func2, func3, func4, func5, func6, func7, func8]
solutions = [func1_solution, func2_solution, func3_solution, func4_solution, func5_solution, func6_solution, func7_solution, func8_solution]

In [51]:
for i in range(len(functions)):
   print(functions[i](float(solutions[i])))

4.7683716e-07
0.0
-0.0
0.0
9.536743e-07
-4.371139e-08
0.0
0.0


In [52]:
def newtons_method(func, solution, x0, tol=1e-6, max_iter=100):
   x = x0
   for i in range(max_iter):
       x = x - func(x) / grad(func)(x)
       if abs(func(x)) < tol:
           return x
   return x

In [63]:
np.isclose(float(newtons_method(func1, func1_solution, 1.0)), float(solutions[0]))

False

In [64]:
float(solutions[0])

-2.449489742783178

In [None]:
new