In [23]:
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap
import sympy

In [25]:
x = sympy.Symbol('x')

In [45]:
def newtons_method(func, x0, tol=1e-6, max_iter=100):
   x = x0
   for i in range(max_iter):
       x = x - func(x) / grad(func)(x) 
       if jnp.abs(func(x)) < tol:  # changed to using jax absolute value to allow for compatibility with other functions
           return x
   return x

In [29]:
def converges(func, solution, x0):# this checks the convergence given a specific root
  return np.isclose(float(newtons_method(func, x0)), float(solution))


def radius_of_convergence_bisection(func, solutions, max_x=100., max_iter=100):
    radii_of_convergence = []
    for i, solution in enumerate(solutions):
        # Check positive side first:
        # Set min and max bounds for the radius of convergence
        min_radius = float(solution)
        if len(solutions) > (i + 1): #if there are more roots, set the max radius to the next root
            max_radius = float(solutions[i + 1])
        else:
            max_radius = max_x

        # Check if bounds converge
        maxRadius_convergence = converges(func, solution, max_radius) # Boolean status of max convergence
        minRadius_convergence = converges(func, solution, min_radius) # Boolean status of min convergence

        if maxRadius_convergence: # Should have additional checks here
            upper_radius = None
        elif not minRadius_convergence:
            upper_radius = 0.0
        else:
            midpoint = float((solution + max_radius) / 2)
            for j in range(max_iter):  # bisection method for finding the radius of convergence
                midpoint_convergence = converges(func, solution, midpoint)
                if midpoint_convergence:
                    min_radius = midpoint
                    midpoint = ((midpoint + max_radius) / 2)
                else:
                    max_radius = midpoint
                    midpoint = ((midpoint + min_radius) / 2)

                if abs((midpoint - solution)) <= 1e-6 or abs((midpoint - max_radius)) <= 1e-6:
                    break

            upper_radius = np.round(midpoint, 5)

        # Perform the same procedure below the min radius
        min_radius = float(solution)
        max_radius = float(solutions[i - 1]) if i > 0 else -max_x


        maxRadius_convergence = converges(func, solution, max_radius)#the convergence of the maxRadius
        minRadius_convergence = converges(func, solution, min_radius)#the convergence of the minRadius

        if maxRadius_convergence:#if the max radius converges then we can conclude that the max radius is infinty
            lower_radius = None
        elif not minRadius_convergence:
            lower_radius = 0.0
        else:
            midpoint = float((solution + max_radius) / 2)
            for i in range(max_iter):#bisection method for finding the radius of convergence
                midpoint_convergence = converges(func, solution, midpoint)
                if i == max_iter - 1:#keep this
                        if midpoint_convergence == True:
                            break
                if midpoint_convergence == True:
                        min_radius = midpoint
                        midpoint = float((midpoint + max_radius) / 2)
                if midpoint_convergence == False:
                        max_radius = midpoint
                        midpoint = float((midpoint + min_radius) / 2)

                if abs(float(midpoint - solution)) <= 1e-6 or abs(float(midpoint - max_radius)) <= 1e-6:
                        break

            lower_radius = np.round(midpoint,5)
        radii_of_convergence.append((lower_radius, upper_radius))
    return radii_of_convergence

In [31]:
def func1(x):
   return jnp.square(x) - 6
func1_expr = x**2 - 6
func1_solutions = sympy.solve(func1_expr, x)
func1_solutions = [sol.evalf() for sol in func1_solutions if sol.is_real]
func1_solutions.sort()

def func2(x):
   return jnp.square(x) - (2*x) - 4
func2_expr = x**2 - 2*x - 4
func2_solutions = sympy.solve(func2_expr, x)
func2_solutions = [sol.evalf() for sol in func2_solutions if sol.is_real]
func2_solutions.sort()

def func3(x):
   return -jnp.square(x) - (2*x)
func3_expr = -x**2 - 2*x
func3_solutions = sympy.solve(func3_expr, x)
func3_solutions = [sol.evalf() for sol in func3_solutions if sol.is_real]
func3_solutions.sort()

def func4(x):
   return jnp.power(x, 3) + 2 * jnp.power(x, 2) + (4*x) + 3
func4_expr = x**3 + 2*x**2 + 4*x + 3
func4_solutions = sympy.solve(func4_expr, x)
func4_solutions = [sol.evalf() for sol in func4_solutions if sol.is_real]
func4_solutions.sort()

def func5(x):
   return -jnp.power(x, 3) - jnp.power(x, 2) +  (5 * x)
func5_expr = -x**3 - x**2 + 5*x
func5_solutions = sympy.solve(func5_expr, x)
func5_solutions = [sol.evalf() for sol in func5_solutions if sol.is_real]
func5_solutions.sort()

def func6(x):
   return jnp.cos(x)
func6_expr = sympy.cos(x)
func6_solutions = sympy.solve(func6_expr, x)
func6_solutions = [sol.evalf() for sol in func6_solutions if sol.is_real]
func6_solutions.sort()

def func7(x):
   return -4 * jnp.sin(x) + jnp.cos(x)
func7_expr = -4 * sympy.sin(x) + sympy.cos(x)
func7_solutions = sympy.solve(func7_expr, x)
func7_solutions = [sol.evalf() for sol in func7_solutions if sol.is_real]
func7_solutions.sort()

def func8(x):
   return (2*x) + 1
func8_expr = 2*x + 1
func8_solutions = sympy.solve(func8_expr, x)
func8_solutions = [sol.evalf() for sol in func8_solutions if sol.is_real]
func8_solutions.sort()

functions = [func1, func2, func3, func4, func5, func6, func7, func8]
solutions = [func1_solutions, func2_solutions, func3_solutions, func4_solutions, func5_solutions, func6_solutions, func7_solutions, func8_solutions]

In [33]:
for func, solutions in zip(functions, solutions):
    print(radius_of_convergence_bisection(func, solutions))

[(None, -0.0), (0.0, None)]
[(None, 1.0), (1.0, None)]
[(None, -1.0), (-0.168, None)]
[(None, None)]
[(None, -1.66667), (-0.67568, 0.81272), (1.0, None)]
[(0.40523, 2.73636), (3.54683, 5.87795)]
[(-0.92058, 1.41054)]
[(None, None)]


In [49]:
# randomly generating linear combos of polynomials using bases x, x^2, x^3, x^4, sin, cos
# then apply our convergence functions to them, to output roots and data of how they converge with newtons

def generate_random_function():
    """Generates a random function as a linear combination of x, x^2, x^3, x^4, sin(x), and cos(x)."""
    bases = [x, x**2, x**3, x**4]
    coefficients = np.random.uniform(-10, 10, len(bases))  # uniform distribution from -10 to 10, can extend to floats if needed
    random_expr = sum(c * b for c, b in zip(coefficients, bases)) # random coefficient times basis summed together to create function
    
    # making jax function
    func = lambda x_val: jnp.array(sp.lambdify(x, random_expr, 'numpy')(x_val)) 
    return func, random_expr

def analyze_function():
    """Generates a random function, finds its roots, and analyzes convergence."""
    func, expr = generate_random_function()
    print(f"Generated function: {expr}")
    
    # Find real roots
    roots = sympy.solve(expr, x)
    real_roots = [r.evalf() for r in roots if r.is_real]
    real_roots.sort()
    print(f"Real roots: {real_roots}")
    
    # Apply radius of convergence analysis
    if real_roots:
        convergence_data = radius_of_convergence_bisection(func, real_roots)
        print(f"Radius of convergence data: {convergence_data}")
    else:
        print("No real roots found, skipping convergence analysis.")
    
    return expr, real_roots

In [None]:
for i in range(10):
    print(f"\n=== Analysis for function {i+1} ===")
    analyze_function()



=== Analysis for function 1 ===
Generated function: -2.44398449674751*x**4 + 5.64955219817449*x**3 - 1.74097143601477*x**2 - 0.681638135747484*x
Real roots: [0]
Radius of convergence data: [(-0.10455, 0.06374)]

=== Analysis for function 2 ===
Generated function: -4.1957978913853*x**4 + 5.5644780946362*x**3 - 5.5407189865863*x**2 - 8.47519120877924*x
Real roots: [-0.721749606427392, 0]
Radius of convergence data: [(None, -0.40906), (-0.40902, None)]

=== Analysis for function 3 ===
Generated function: 5.76444137659266*x**4 - 9.02868632734616*x**3 - 0.33394035079176*x**2 + 7.32347246809201*x
Real roots: [-0.752757627567775, 0]
Radius of convergence data: [(None, -0.44712), (-0.44711, 6.63999)]

=== Analysis for function 4 ===
Generated function: 5.77528644420029*x**4 + 5.32747637166093*x**3 + 2.9373775416835*x**2 - 4.4873759319108*x
Real roots: [0, 0.570862011404217]
Radius of convergence data: [(None, 0.32908), (0.32908, None)]

=== Analysis for function 5 ===
Generated function: -8.7

The function is really slow, so I want to @jit it for e, but I need to overhaul everything using jax operations to do so. Also was having some issues with using trig functions. 

I also want to clean up the presentation of this data further, maybe throw it in a table instead. 