In [76]:
import sympy as sym
sym.init_printing(use_unicode=True)

In [84]:
a, b = sym.symbols('a b')

A = sym.Matrix([[a,2,a,a+b,a-b],
               [a,2,a,a,a-b],
               [3,3,-b,3,-b],
               [a+1,3,a+1,a+1,a-b+1]])

In [223]:
def scale_row(A: sym.Matrix, 
              idx: int, 
              scalar: float, 
              verbosity: int = 0) -> sym.Matrix:
    A = sym.Matrix(A)
    A[idx, :] = scalar * A[idx, :]
    A = sym.simplify(A)
    if verbosity >= 1:
        print(f"R_{idx+1} <- ({scalar})R_{idx+1}")
    if verbosity >= 2:
        sym.pprint(A)
        print('\n')
    return A

scale_row(A, 2, 1/a)

⎡  a    2    a    a + b    a - b  ⎤
⎢                                 ⎥
⎢  a    2    a      a      a - b  ⎥
⎢                                 ⎥
⎢  3    3   -b      3       -b    ⎥
⎢  ─    ─   ───     ─       ───   ⎥
⎢  a    a    a      a        a    ⎥
⎢                                 ⎥
⎣a + 1  3  a + 1  a + 1  a - b + 1⎦

In [224]:
def swap_row(A: sym.Matrix, 
             idx_1: int, 
             idx_2: int, 
             verbosity: int = 0) -> sym.Matrix:
    A = sym.Matrix(A)
    A[idx_1, :], A[idx_2, :] = A[idx_2, :], A[idx_1, :]
    
    if verbosity >= 1:
        print(f"R_{idx_1+1} <-> R_{idx_2+1}")
    if verbosity >= 2:
        sym.pprint(A)
        print('\n')
    return A

swap_row(A, 0, 2)

⎡  3    3   -b      3       -b    ⎤
⎢                                 ⎥
⎢  a    2    a      a      a - b  ⎥
⎢                                 ⎥
⎢  a    2    a    a + b    a - b  ⎥
⎢                                 ⎥
⎣a + 1  3  a + 1  a + 1  a - b + 1⎦

In [225]:
def reduce_row(A: sym.Matrix, 
               idx_1: int,
               scalar: float, 
               idx_2: int, 
               verbosity: int = 0) -> sym.Matrix:
    A = sym.Matrix(A)
    A[idx_1, :] = A[idx_1, :] - scalar * A[idx_2, :]
    A = sym.simplify(A)

    if verbosity >= 1:
        print(f"R_{idx_1+1} <- R_{idx_1+1} - ({scalar})R_{idx_2+1}")
    if verbosity >= 2:
        sym.pprint(A)
        print('\n')
    return A

reduce_row(A, 3, -1, 0)

⎡   a     2     a        a + b         a - b    ⎤
⎢                                               ⎥
⎢   a     2     a          a           a - b    ⎥
⎢                                               ⎥
⎢   3     3    -b          3            -b      ⎥
⎢                                               ⎥
⎣2⋅a + 1  5  2⋅a + 1  2⋅a + b + 1  2⋅a - 2⋅b + 1⎦

In [57]:
def is_zero(expr, symbolic: bool = True) -> bool:
    # if not symbolic:
    #     # symbols are assumed to be non-zero
    #     return expr.is_zero
    if not isinstance(expr, sym.Expr):
        return expr == 0
    sol = sym.solve(sym.Eq(expr, 0), expr.free_symbols)
    return len(sol) != 0

is_zero(1)

False

In [185]:
def get_pivot_row(A: sym.Matrix, 
                  col_idx: int, 
                  follow_GE: bool = False) -> int:
    m, _ = sym.shape(A)

    # Attempt to pick a pivot column that is a non-zero constant that do
    # not contain any symbols so that it is easier to reduce other rows
    if not follow_GE:
        for row_idx in range(col_idx, m):
            term = A[row_idx, col_idx]
            if term != 0:
                if not isinstance(term, sym.Expr):
                    return row_idx
                elif len(term.free_symbols) == 0:
                    return row_idx

    # Attempt to pick the first non-zero row if all rows contain symbols
    for row_idx in range(col_idx, m):
        term = A[row_idx, col_idx]
        if term != 0:
            return row_idx

    # if entire col is 0, ie no pivot_rows found, return -1
    return -1

# print(A)
B = sym.Matrix([[0, b, 0],
                [b/3 + 1, 0, 0]])
get_pivot_row(B, 0)

1

In [234]:
def ref(A: sym.Matrix, 
        verbosity: int = 0,
        max_tries: int = 2,
        follow_GE: bool = False,
        matrices: int = 1):
    # follow_GE can be set to True to follow Gaussian Elimination strictly
    # LU decomposition
    # matrices is the number of matrices return. 
    # 1. Upper 
    # 2. Perm @ Lower Upper 
    # 3. Perm Lower Upper
    U = A.copy()
    m, n = sym.shape(U)

    I = sym.eye(m)
    L = sym.eye(m)
    P = sym.eye(m)

    # Loop over each column
    for col_idx in range(min(m, n)):
        # Find the first non-zero row in the current column
        pivot_row = get_pivot_row(U, col_idx, follow_GE)
        print(f"{pivot_row=}")

        # if not allow_swap:
        #     if col_idx != pivot_row:
        #         continue

        if pivot_row == -1:
            # If no non-zero pivot is found, continue to the next column
            continue
        
        # Swap the current row with the pivot row if necessary
        if pivot_row != col_idx:
            U = swap_row(U, col_idx, pivot_row, verbosity=verbosity)
            P_elem = swap_row(I.copy(), col_idx, pivot_row)
            P = P @ P_elem
            L = P_elem @ L @ P_elem
        
        # Eliminate the current column in rest of the rows below
        for row_idx in range(col_idx+1, m):
            # reduce the row_idx iteratively via partial fractions to
            # prevent division by a possible 0 term
            tries = 0
            print(row_idx, col_idx)
            print(U[row_idx, col_idx])
            print(U[row_idx, col_idx] == 0)
            while U[row_idx, col_idx] != 0:
                tries += 1
                if tries > max_tries:
                    print(f"ERROR: Max tries exceeded to reduce row {row_idx} with row {col_idx}")
                    break
                scalar = sym.simplify(U[row_idx, col_idx]/U[col_idx, col_idx])
                decomp = sym.apart(scalar) # partial fractions
                # simplify scalar is needed to check if apart has effect
                if scalar != decomp:
                    terms = decomp.args
                else:
                    # there is only 1 term (could be integer or proper fraction)
                    terms = [decomp]
                for term in terms:
                    _, d = sym.fraction(term)
                    # ensure denominator is non-zero so that reduction is valid
                    if not is_zero(d):
                        U = reduce_row(U, row_idx, term, col_idx, verbosity=verbosity)
                        # L = sym.Matrix(L)
                        # L[row_idx, col_idx] = -term
                        # L = reduce_row(L, row_idx, -term, col_idx)
                        elem = reduce_row(I.copy(), row_idx, -term, col_idx)
                        L = L @ elem

                # Cases where pivot row contains symbols such that scalar is a 
                # fraction with symbolic denominator. 
                # To reduce further, can only scale row_idx accordingly
                if U[row_idx, col_idx] != 0:
                    scalar = sym.simplify(U[col_idx, col_idx]/U[row_idx, col_idx])
                    n, d = sym.fraction(scalar)
                    # to scale by n, n can be symbols but not 0
                    # to scale by 1/d, d cannot be 0 both numerically or symbolically
                    if (n != 0) and (not is_zero(d)):
                        U = scale_row(U, row_idx, scalar, verbosity=verbosity)
                        elem = scale_row(I.copy(), row_idx, 1/scalar)
                        L = L @ elem


    if matrices == 1:
        return U
    elif matrices == 2:
        return P @ L, U
    elif matrices == 3:
        return P, L, U
    else:
        return U


# A = sym.Matrix([[a,2,a,a+b,a-b],
#                [a,2,a,a,a-b],
#                [3,3,-b,3,-b],
#                [a+1,3,a+1,a+1,a-b+1]])
# P, L, U = ref(A, allow_swap=False, matrices=3)   

ref(test, verbosity=2, matrices=3)             

pivot_row=0
1 0
0
True
2 0
0
True
3 0
0
True
pivot_row=-1
pivot_row=2
3 2
0
True
pivot_row=3


⎛                            ⎡3  3    -b     3   -b  ⎤⎞
⎜                            ⎢                       ⎥⎟
⎜⎡1  0  0  0⎤  ⎡1  0  0  0⎤  ⎢      2⋅b             b⎥⎟
⎜⎢          ⎥  ⎢          ⎥  ⎢0  0  ─── + 2  0  2 - ─⎥⎟
⎜⎢0  1  0  0⎥  ⎢0  1  0  0⎥  ⎢       3              3⎥⎟
⎜⎢          ⎥, ⎢          ⎥, ⎢                       ⎥⎟
⎜⎢0  0  1  0⎥  ⎢0  0  1  0⎥  ⎢       b          b    ⎥⎟
⎜⎢          ⎥  ⎢          ⎥  ⎢0  0   ─ + 1   0  ─ + 1⎥⎟
⎜⎣0  0  0  1⎦  ⎣0  0  0  1⎦  ⎢       3          3    ⎥⎟
⎜                            ⎢                       ⎥⎟
⎝                            ⎣0  0     0     b    0  ⎦⎠

In [231]:
test[1:, 2:]

⎡2⋅b             b⎤
⎢─── + 2  0  2 - ─⎥
⎢ 3              3⎥
⎢                 ⎥
⎢ b          b    ⎥
⎢ ─ + 1   0  ─ + 1⎥
⎢ 3          3    ⎥
⎢                 ⎥
⎣   0     b    0  ⎦

In [232]:
ref(test[1:, 2:])

pivot_row=0
1 0
pivot_row=2
pivot_row=2


⎡2⋅b             b⎤
⎢─── + 2  0  2 - ─⎥
⎢ 3              3⎥
⎢                 ⎥
⎢   0     b    0  ⎥
⎢                 ⎥
⎢              b  ⎥
⎢   0     0    ─  ⎥
⎣              2  ⎦

In [230]:
test

⎡3  3    -b     3   -b  ⎤
⎢                       ⎥
⎢      2⋅b             b⎥
⎢0  0  ─── + 2  0  2 - ─⎥
⎢       3              3⎥
⎢                       ⎥
⎢       b          b    ⎥
⎢0  0   ─ + 1   0  ─ + 1⎥
⎢       3          3    ⎥
⎢                       ⎥
⎣0  0     0     b    0  ⎦

In [204]:
sym.simplify(P @ L @ U)

⎡  a    2    a    a + b    a - b  ⎤
⎢                                 ⎥
⎢  a    2    a      a      a - b  ⎥
⎢                                 ⎥
⎢  3    3   -b      3       -b    ⎥
⎢                                 ⎥
⎣a + 1  3  a + 1  a + 1  a - b + 1⎦

In [199]:
L1, U1, perm = sym.MatrixBase.LUdecomposition(A)
sym.pprint(L1)
sym.pprint(U1)
P_invs = sym.eye(A.rows).permuteFwd(perm)
sym.pprint(P_invs**-1)

⎡  1    0  0  0⎤
⎢              ⎥
⎢  a           ⎥
⎢  ─    1  0  0⎥
⎢  3           ⎥
⎢              ⎥
⎢a   1         ⎥
⎢─ + ─  1  1  0⎥
⎢3   3         ⎥
⎢              ⎥
⎢  a           ⎥
⎢  ─    1  0  1⎥
⎣  3           ⎦
⎡3    3             -b            3           -b          ⎤
⎢                                                         ⎥
⎢                 a⋅b                     a⋅b             ⎥
⎢0  2 - a         ─── + a         0       ─── + a - b     ⎥
⎢                  3                       3              ⎥
⎢                                                         ⎥
⎢            a⋅b     ⎛a   1⎞           a⋅b     ⎛a   1⎞    ⎥
⎢0    0    - ─── + b⋅⎜─ + ─⎟ + 1  0  - ─── + b⋅⎜─ + ─⎟ + 1⎥
⎢             3      ⎝3   3⎠            3      ⎝3   3⎠    ⎥
⎢                                                         ⎥
⎣0    0              0            b            0          ⎦
⎡0  0  0  1⎤
⎢          ⎥
⎢0  1  0  0⎥
⎢          ⎥
⎢1  0  0  0⎥
⎢          ⎥
⎣0  0  1  0⎦


In [213]:
sym.simplify(U1)

⎡3    3       -b      3      -b     ⎤
⎢                                   ⎥
⎢          a⋅(b + 3)     a⋅b        ⎥
⎢0  2 - a  ─────────  0  ─── + a - b⎥
⎢              3          3         ⎥
⎢                                   ⎥
⎢            b              b       ⎥
⎢0    0      ─ + 1    0     ─ + 1   ⎥
⎢            3              3       ⎥
⎢                                   ⎥
⎣0    0        0      b       0     ⎦

In [215]:
test = U1.subs(a, 2)

⎛                            ⎡3  3    -b     3   -b  ⎤⎞
⎜                            ⎢                       ⎥⎟
⎜⎡1  0  0  0⎤  ⎡1  0  0  0⎤  ⎢      2⋅b             b⎥⎟
⎜⎢          ⎥  ⎢          ⎥  ⎢0  0  ─── + 2  0  2 - ─⎥⎟
⎜⎢0  1  0  0⎥  ⎢0  1  0  0⎥  ⎢       3              3⎥⎟
⎜⎢          ⎥, ⎢          ⎥, ⎢                       ⎥⎟
⎜⎢0  0  1  0⎥  ⎢0  0  1  0⎥  ⎢       b          b    ⎥⎟
⎜⎢          ⎥  ⎢          ⎥  ⎢0  0   ─ + 1   0  ─ + 1⎥⎟
⎜⎣0  0  0  1⎦  ⎣0  0  0  1⎦  ⎢       3          3    ⎥⎟
⎜                            ⎢                       ⎥⎟
⎝                            ⎣0  0     0     b    0  ⎦⎠

In [169]:
def ref(A: sym.Matrix, 
        verbosity: int = 0,
        max_tries: int = 2,
        allow_swap: bool = True,
        result: int = 1):
    # allow_swap can be set to False to disable swapping when computing
    # LU decomposition
    # result is the number of matrices return. 
    # 1. Upper 
    # 2. Perm @ Lower Upper 
    # 3. Perm Lower Upper
    U = A.copy()
    m, n = sym.shape(U)

    L = sym.eye(m)
    P = sym.eye(m)

    # Loop over each column
    for col_idx in range(min(m, n)):
        # Find the first non-zero row in the current column
        pivot_row = get_pivot_row(U, col_idx)

        if not allow_swap:
            if col_idx != pivot_row:
                continue

        if pivot_row == -1:
            # If no non-zero pivot is found, continue to the next column
            continue
        
        # Swap the current row with the pivot row if necessary
        if pivot_row != col_idx:
            U = swap_row(U, col_idx, pivot_row, verbosity=verbosity)
            if verbosity == 1:
                E = swap_row(E, col_idx, pivot_row)
            elif verbosity >= 2:
                P = swap_row(P, col_idx, pivot_row)
                elem = sym.eye(m)
                E = swap_row(elem, col_idx, pivot_row) * E
        
        # Eliminate the current column in rest of the rows below
        for row_idx in range(col_idx+1, m):
            # reduce the row_idx iteratively via partial fractions to
            # prevent division by a possible 0 term
            tries = 0
            while U[row_idx, col_idx] != 0:
                tries += 1
                if tries > max_tries:
                    print(f"ERROR: Max tries exceeded to reduce row {row_idx} with row {col_idx}")
                    break
                scalar = sym.simplify(U[row_idx, col_idx]/U[col_idx, col_idx])
                decomp = sym.apart(scalar) # partial fractions
                # simplify scalar is needed to check if apart has effect
                if scalar != decomp:
                    terms = decomp.args
                else:
                    # there is only 1 term (could be integer or proper fraction)
                    terms = [decomp]
                for term in terms:
                    _, d = sym.fraction(term)
                    # ensure denominator is non-zero so that reduction is valid
                    if not is_zero(d):
                        U = reduce_row(U, row_idx, term, col_idx, verbosity=verbosity)
                        if verbosity >= 1:
                            E = reduce_row(E, row_idx, term, col_idx)

                # Cases where pivot row contains symbols such that scalar is a 
                # fraction with symbolic denominator. 
                # To reduce further, can only scale row_idx accordingly
                if U[row_idx, col_idx] != 0:
                    scalar = sym.simplify(U[col_idx, col_idx]/U[row_idx, col_idx])
                    n, d = sym.fraction(scalar)
                    # to scale by n, n can be symbols but not 0
                    # to scale by 1/d, d cannot be 0 both numerically or symbolically
                    if (n != 0) and (not is_zero(d)):
                        U = scale_row(U, row_idx, scalar, verbosity=verbosity)
                        if verbosity >= 1:
                            E = scale_row(E, row_idx, scalar, 0)                            

    if result == 1:
        return U
    elif result == 2:
        return P @ L, U
    elif verbosity == 3:
        return P, L, U
    else:
        return U


A = sym.Matrix([[a,2,a,a+b,a-b],
               [a,2,a,a,a-b],
               [3,3,-b,3,-b],
               [a+1,3,a+1,a+1,a-b+1]])
P, E, A = ref(A, 2)                

R_1 <-> R_3
⎡  3    3   -b      3       -b    ⎤
⎢                                 ⎥
⎢  a    2    a      a      a - b  ⎥
⎢                                 ⎥
⎢  a    2    a    a + b    a - b  ⎥
⎢                                 ⎥
⎣a + 1  3  a + 1  a + 1  a - b + 1⎦


R_2 -> R_2 - (a/3)R_1
⎡  3      3       -b        3        -b     ⎤
⎢                                           ⎥
⎢              a⋅(b + 3)         a⋅b        ⎥
⎢  0    2 - a  ─────────    0    ─── + a - b⎥
⎢                  3              3         ⎥
⎢                                           ⎥
⎢  a      2        a      a + b     a - b   ⎥
⎢                                           ⎥
⎣a + 1    3      a + 1    a + 1   a - b + 1 ⎦


R_3 -> R_3 - (a/3)R_1
⎡  3      3       -b        3        -b     ⎤
⎢                                           ⎥
⎢              a⋅(b + 3)         a⋅b        ⎥
⎢  0    2 - a  ─────────    0    ─── + a - b⎥
⎢                  3              3         ⎥
⎢                                           

NotImplementedError: multivariate partial fraction decomposition

In [166]:
E

⎡  a           ⎤
⎢  ─    1  0  1⎥
⎢  3           ⎥
⎢              ⎥
⎢  a           ⎥
⎢  ─    1  0  0⎥
⎢  3           ⎥
⎢              ⎥
⎢  1    0  0  0⎥
⎢              ⎥
⎢a   1         ⎥
⎢─ + ─  1  1  0⎥
⎣3   3         ⎦

In [155]:
P**-1

⎡0  0  0  1⎤
⎢          ⎥
⎢0  1  0  0⎥
⎢          ⎥
⎢1  0  0  0⎥
⎢          ⎥
⎣0  0  1  0⎦

In [167]:
sym.simplify(P * E * A)

⎡  3    3   -b      3       -b    ⎤
⎢                                 ⎥
⎢  a    2    a      a      a - b  ⎥
⎢                                 ⎥
⎢a + 1  3  a + 1  a + 1  a - b + 1⎥
⎢                                 ⎥
⎣  a    2    a    a + b    a - b  ⎦