In [6]:
import sympy
from sympy import Symbol, series, sqrt, O
from sympy.solvers.solveset import linsolve

def pade_approx_sqrt(n, m):
    """
    Return a Sympy rational expression R(x) = P(x)/Q(x) that forms
    the (n,m)-order Padé approximant to sqrt(x) around x=0.

    :param n: Degree of numerator polynomial P(x)
    :param m: Degree of denominator polynomial Q(x)
    :return: A Sympy expression for the Padé approximant (P/Q).
    """

    x = Symbol('x', positive=True)
    
    # 1) Get power series for sqrt(x) up to x^(n+m).
    f_series = series(sqrt(x), (x, 0, n + m + 1)).removeO()
    # This is e.g. a0 + a1*x + ... + a_{n+m}*x^{n+m} + O(x^{n+m+1})

    # 2) Extract the coefficients from f_series (optional for debugging).
    #    In principle, we only need f_series if we want to compare coefficients
    #    or do something else. The main series usage is in expr = (sqrt(x)*Q - P).
    
    # 3) Define symbolic coefficients for P(x) and Q(x).
    #    Let P(x) = p0 + p1*x + ... + p_n*x^n
    #        Q(x) = 1 + q1*x + ... + q_m*x^m  (q0 = 1 to fix scale)

    p_syms = sympy.symbols('p0 p1 p2 p3 p4 p5 p6 p7 p8 p9', real=True)[: n+1]
    q_syms = sympy.symbols('q0 q1 q2 q3 q4 q5 q6 q7 q8 q9', real=True)[: m+1]

    # Force q0 = 1 for uniqueness:
    # We'll skip q_syms[0] as a free variable. We'll treat it as 1 instead.
    # So we have n+1 + m unknowns in total.

    # Build the polynomials
    P = sum(p_syms[i]*x**i for i in range(n+1))
    Q = 1 + sum(q_syms[i]*x**(i) for i in range(1, m+1))

    # 4) Condition:  sqrt(x)*Q(x) - P(x) = O(x^{n+m+1}).
    #    We'll expand that and enforce that each coefficient of x^k for k=0..(n+m)
    #    is zero. This yields (n+m+1) linear equations.
    expr = (sqrt(x)*Q - P).expand().series(x, 0, n+m+1).removeO()

    # Now break it down by powers of x
    expr_terms = expr.expand().as_ordered_terms()

    # We'll store power->coefficient for expr
    expr_coeffs = {}
    for term in expr_terms:
        # term is something like c*x^k
        mono = term.expand()
        if mono.is_Number:
            k = 0
            c = mono
        else:
            x_pow = mono.as_powers_dict().get(x, 0)
            c = mono / (x**x_pow)
            k = x_pow
        expr_coeffs[k] = expr_coeffs.get(k, 0) + c

    # Build equations: each coefficient = 0 for k in [0..(n+m)]
    all_eqs = []
    for k in range(n+m+1):
        ck = expr_coeffs.get(k, 0)
        all_eqs.append(sympy.Eq(ck, 0))

    # 5) Solve the linear system for p_syms + q_syms[1..m].
    unknowns = list(p_syms) + list(q_syms[1:])  # exclude q0

    # -- FILTER OUT TRIVIAL OR CONTRADICTORY EQUATIONS --
    real_eqs = []
    for eq in all_eqs:
        eq_simpl = sympy.simplify(eq)
        if eq_simpl == True:
            # trivially 0=0, skip it
            continue
        elif eq_simpl == False:
            # contradiction => no solution
            raise ValueError("Equation simplified to False => inconsistent system.")
        else:
            real_eqs.append(eq_simpl)

    sol = linsolve(real_eqs, *unknowns)
    if not sol or len(sol) == 0:
        raise ValueError("No solution found or infinite solutions. Possibly too large n, m.")
    (solution,) = sol  # unpack from FiniteSet

    # 6) Substitute solution back into P, Q
    #    The first (n+1) in solution correspond to p_syms,
    #    then the next m correspond to q_syms[1..m].
    subs_map = {}
    for i in range(n+1):
        subs_map[p_syms[i]] = solution[i]
    for j in range(m):
        subs_map[q_syms[j+1]] = solution[n+1 + j]

    P_final = P.subs(subs_map)
    Q_final = Q.subs(subs_map)

    # 7) Return the ratio P_final/Q_final, simplified
    return sympy.simplify(P_final/Q_final)


# -----------------------
# EXAMPLE USAGE:
# Let's get a (2,2) approximation for sqrt(x).
if __name__ == "__main__":
    x = sympy.Symbol('x', positive=True)
    approx = pade_approx_sqrt(5, 5)
    print("Pade approximation (2,2) for sqrt(x):")
    print(approx)
    
    # Evaluate at x=1, should be near 1
    val1 = approx.subs(x, 1).evalf()
    print("Approx at x=1:", val1)

    # Evaluate at x=0.25, should be near 0.5
    val2 = approx.subs(x, 0.25).evalf()
    print("Approx at x=0.25:", val2)


Pade approximation (2,2) for sqrt(x):
0
Approx at x=1: 0
Approx at x=0.25: 0
