In [65]:
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap
from scipy.optimize import root_scalar
import sympy

In [66]:
# Symbolic variable shared by every *_expr below (used for exact root finding).
x = sympy.symbols('x')

## Define Functions

In [67]:
def _real_solutions(expr, var=x, imag_tol=1e-12):
    """Return the real roots of ``expr`` (solved for ``var``) as a sorted list.

    Centralises the solve / filter / evalf / sort sequence shared by every
    test function below. Each kept root is the numerically evaluated sympy value.

    ``sympy.solve`` may return complex roots; those are discarded. ``is_real``
    is three-valued (True / False / None): when sympy cannot decide (None), the
    root is kept if the imaginary part of its numeric value is below
    ``imag_tol`` (keeping only the real part), rather than being dropped silently.
    """
    real_roots = []
    for sol in sympy.solve(expr, var):
        if sol.is_real:
            real_roots.append(sol.evalf())
        elif sol.is_real is None:
            re_part, im_part = sol.evalf().as_real_imag()
            if abs(float(im_part)) < imag_tol:
                real_roots.append(re_part)
    return sorted(real_roots)


def func1(x):
    """x**2 - 6; real roots are +/- sqrt(6)."""
    return jnp.square(x) - 6
func1_expr = x**2 - 6
func1_solutions = _real_solutions(func1_expr)


def func2(x):
    """x**2 - 2x - 4; real roots are 1 +/- sqrt(5)."""
    return jnp.square(x) - (2*x) - 4
func2_expr = x**2 - 2*x - 4
func2_solutions = _real_solutions(func2_expr)


def func3(x):
    """-x**2 - 2x = -x(x + 2); real roots are -2 and 0."""
    return -jnp.square(x) - (2*x)
func3_expr = -x**2 - 2*x
func3_solutions = _real_solutions(func3_expr)


def func4(x):
    """x**3 + 2x**2 + 4x + 3 = (x + 1)(x**2 + x + 3); only real root is -1."""
    return jnp.power(x, 3) + 2 * jnp.power(x, 2) + (4*x) + 3
func4_expr = x**3 + 2*x**2 + 4*x + 3
func4_solutions = _real_solutions(func4_expr)


def func5(x):
    """-x**3 - x**2 + 5x = -x(x**2 + x - 5); real roots 0 and (-1 +/- sqrt(21))/2."""
    return -jnp.power(x, 3) - jnp.power(x, 2) + (5 * x)
func5_expr = -x**3 - x**2 + 5*x
func5_solutions = _real_solutions(func5_expr)


def func6(x):
    """cos(x).

    NOTE: sympy.solve returns only finitely many roots of this periodic
    function, so a Newton run that lands on any other root is not counted as
    converged by ``converges``.
    """
    return jnp.cos(x)
func6_expr = sympy.cos(x)
func6_solutions = _real_solutions(func6_expr)


def func7(x):
    """-4 sin(x) + cos(x); zero where tan(x) = 1/4 (periodic, like func6)."""
    return -4 * jnp.sin(x) + jnp.cos(x)
func7_expr = -4 * sympy.sin(x) + sympy.cos(x)
func7_solutions = _real_solutions(func7_expr)


def func8(x):
    """2x + 1; root is -1/2."""
    return (2*x) + 1
func8_expr = 2*x + 1
func8_solutions = _real_solutions(func8_expr)


# Parallel lists: solutions[i] holds the known real roots of functions[i].
functions = [func1, func2, func3, func4, func5, func6, func7, func8]
solutions = [func1_solutions, func2_solutions, func3_solutions, func4_solutions,
             func5_solutions, func6_solutions, func7_solutions, func8_solutions]

In [69]:
# Sanity check: each function evaluated at each of its known roots should be ~0
# (up to floating-point rounding).
for idx, func in enumerate(functions):
    for root in solutions[idx]:
        print(func(float(root)))

4.7683716e-07
4.7683716e-07
0.0
4.7683716e-07
0.0
-0.0
0.0
9.536743e-07
0.0
-9.536743e-07
-4.371139e-08
1.1924881e-08
0.0
0.0


In [71]:
def newtons_method(func, x0, tol=1e-6, max_iter=100):
    """Find a root of ``func`` with Newton's method, using ``jax.grad`` for f'.

    Parameters
    ----------
    func : callable
        Scalar function of one scalar argument, differentiable by ``jax.grad``.
    x0 : float or int
        Starting point. Converted to ``float`` first because ``jax.grad``
        rejects integer inputs.
    tol : float
        Stop as soon as ``|func(x)| < tol`` after a step.
    max_iter : int
        Maximum number of Newton steps.

    Returns
    -------
    The final iterate. It is not guaranteed to be a root. It may be the last
    value reached when ``max_iter`` is exhausted. It may be the point where
    the derivative is exactly zero, since no Newton step is defined there.
    It may also be a non-finite value (inf/nan) if the iteration diverged.
    """
    dfunc = grad(func)  # build the derivative once instead of every iteration
    x = float(x0)
    fx = func(x)
    for _ in range(max_iter):
        slope = dfunc(x)
        if slope == 0:
            # f'(x) == 0: the update would divide by zero and yield inf/nan.
            return x
        x = x - fx / slope
        if not jnp.isfinite(x):
            # inf/nan can never recover, so skip the remaining iterations.
            return x
        fx = func(x)  # reused by the next step; avoids evaluating func twice
        if abs(fx) < tol:
            return x
    return x

In [82]:
def converges(func, solutions, x0, rtol=1e-5, atol=1e-8):
    """Report whether Newton's method started at ``x0`` lands on a known root.

    Parameters
    ----------
    func : callable
        Scalar function handed to ``newtons_method``.
    solutions : iterable
        Known roots (anything ``float()`` accepts, e.g. sympy Floats).
    x0 : float
        Starting point for Newton's method.
    rtol, atol : float
        Tolerances forwarded to ``np.isclose``. The defaults equal numpy's
        own, so existing calls behave as before. A root at exactly 0 is only
        matched within ``atol``.

    Returns
    -------
    bool
        True if the Newton result is close to any entry of ``solutions``.
        False if it is not, if ``solutions`` is empty, or if Newton returned nan.
    """
    # Newton's method is deterministic, so run it once rather than per root.
    found = float(newtons_method(func, x0))
    targets = np.array([float(sol) for sol in solutions], dtype=float)
    return bool(np.any(np.isclose(found, targets, rtol=rtol, atol=atol)))

In [87]:
# Does Newton's method started at x0 = 0.10 reach one of each function's known roots?
for position, func in enumerate(functions):
    verdict = converges(func, solutions[position], .10)
    print(f"For function {position + 1}: " + str(verdict))

For function 1: True
For function 2: True
For function 3: True
For function 4: True
For function 5: True
For function 6: False
For function 7: True
For function 8: True


In [90]:
def radius_of_convergence_bisection(func, solutions, max_x, max_iter=100, tol=1e-6,
                                    converges_fn=None):
    """Bisect between ``solutions[0]`` and ``max_x`` for a convergence boundary.

    Finds a point where the outcome of ``converges`` (does Newton's method
    started there reach *any* of ``solutions``?) flips.

    Despite the name, the return value is a position on the x-axis, not a
    distance from the root. Newton basins need not be intervals, so this is
    *a* boundary in the searched range, not necessarily the nearest one.

    Parameters
    ----------
    func : callable
        Function whose roots are ``solutions``.
    solutions : sequence
        Known roots. The search starts at ``solutions[0]``.
    max_x : float
        Other end of the search interval (may lie on either side of the start).
    max_iter : int
        Maximum number of bisection steps.
    tol : float
        Stop once the bracketing interval is narrower than this.
    converges_fn : callable, optional
        Predicate ``(func, solutions, x0) -> bool``. Defaults to ``converges``.

    Returns
    -------
    float or None
        The last point with the same convergence status as ``solutions[0]``.
        None if ``solutions`` is empty or both ends share the same status.
    """
    if len(solutions) == 0:
        return None
    check = converges if converges_fn is None else converges_fn

    # Invariant: `near` always has `near_status`; `far` always has the opposite.
    near = float(solutions[0])
    far = max_x
    near_status = check(func, solutions, near)
    far_status = check(func, solutions, far)
    if near_status == far_status:
        return None  # no sign change to bracket

    for _ in range(max_iter):
        mid = (near + far) / 2
        # Compare against the start's status, so the search also works when the
        # start point does not converge but max_x does.
        if check(func, solutions, mid) == near_status:
            near = mid
        else:
            far = mid
        if abs(near - far) < tol:
            break
    return near
    

In [110]:
# Search for a convergence boundary of cos(x), from its first known root toward each far limit.
for search_limit in (1000.0, -1000.0):
    print(radius_of_convergence_bisection(func6, func6_solutions, search_limit))

5.884843938768045
-0.39061930075447276


In [109]:
# Probe Newton convergence for cos(x) from a single start point, x0 = -1.10.
converges(func6, func6_solutions, x0=-1.10)

False