In [1]:
import numpy as np
import common
from common import reset_counters, check_mat, randomize_matrix
from multiplication import strassen
from functions import recursive_inverse

# Import counters for easy access
counter_add = common.counter_add
counter_mul = common.counter_mul
counter_sub = common.counter_sub

## Recursive Matrix Inversion Test


In [2]:
A = np.array([[2.0, 1.0], [1.0, 2.0]])
reset_counters()

A_inv = recursive_inverse(A, strassen)
A_inv_numpy = np.linalg.inv(A)

print("Original matrix A:")
print(A)
print("\nRecursive inverse:")
print(A_inv)
print("\nNumPy inverse:")
print(A_inv_numpy)
print("\nVerification (A * A_inv should be identity):")
print(strassen(A, A_inv))
print("\nCounters:")
print(f"Additions: {common.counter_add}, Multiplications: {common.counter_mul}, Subtractions: {common.counter_sub}")


Original matrix A:
[[2. 1.]
 [1. 2.]]

Recursive inverse:
[[ 0.66666667 -0.33333333]
 [-0.33333333  0.66666667]]

NumPy inverse:
[[ 0.66666667 -0.33333333]
 [-0.33333333  0.66666667]]

Verification (A * A_inv should be identity):
[[1. 0.]
 [0. 1.]]

Counters:
Additions: 22, Multiplications: 28, Subtractions: 0


In [3]:
# Test with a 4x4 matrix
A = np.array([[4.0, 1.0, 2.0, 0.5],
              [1.0, 3.0, 0.5, 1.0],
              [2.0, 0.5, 5.0, 1.0],
              [0.5, 1.0, 1.0, 2.0]])

reset_counters()

A_inv = recursive_inverse(A, strassen)
A_inv_numpy = np.linalg.inv(A)

print("Original matrix A:")
print(A)
print("\nRecursive inverse:")
print(A_inv)
print("\nNumPy inverse:")
print(A_inv_numpy)
print("\nDifference (should be close to zero):")
print(np.abs(A_inv - A_inv_numpy))
print("\nVerification (A * A_inv should be identity):")
identity_check = strassen(A, A_inv)
print(identity_check)
print("\nIs close to identity?", np.allclose(identity_check, np.eye(4), atol=1e-6))
print("\nCounters:")
print(f"Additions: {common.counter_add}, Multiplications: {common.counter_mul}, Subtractions: {common.counter_sub}")


Original matrix A:
[[4.  1.  2.  0.5]
 [1.  3.  0.5 1. ]
 [2.  0.5 5.  1. ]
 [0.5 1.  1.  2. ]]

Recursive inverse:
[[ 0.33676333 -0.101029   -0.13096352  0.03180543]
 [-0.101029    0.4303087   0.03928906 -0.20954163]
 [-0.13096352  0.03928906  0.27315248 -0.12347989]
 [ 0.03180543 -0.20954163 -0.12347989  0.6585594 ]]

NumPy inverse:
[[ 0.33676333 -0.101029   -0.13096352  0.03180543]
 [-0.101029    0.4303087   0.03928906 -0.20954163]
 [-0.13096352  0.03928906  0.27315248 -0.12347989]
 [ 0.03180543 -0.20954163 -0.12347989  0.6585594 ]]

Difference (should be close to zero):
[[5.55111512e-17 1.38777878e-17 2.77555756e-17 0.00000000e+00]
 [0.00000000e+00 5.55111512e-17 0.00000000e+00 0.00000000e+00]
 [2.77555756e-17 0.00000000e+00 0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 1.38777878e-17 0.00000000e+00]]

Verification (A * A_inv should be identity):
[[ 1.00000000e+00  2.77555756e-17  1.66533454e-16  0.00000000e+00]
 [-2.77555756e-17  1.00000000e+00  2.77555756e-17  0.

## Verify recursive inversion for various matrix sizes


In [4]:
GREEN = '\033[92m'
RED = '\033[91m'
ENDCOLOR = '\033[0m'

test_count = 50
correct_count = 0

for n in range(1, test_count + 1):
    print(f"Verifying n = {n}")
    
    # Generate a random invertible matrix
    # Use a matrix with larger diagonal values to ensure invertibility
    A = np.random.rand(n, n) * 0.5 + np.eye(n) * 2.0
    
    reset_counters()
    
    try:
        A_inv = recursive_inverse(A, strassen)
        A_inv_numpy = np.linalg.inv(A)
        
        # Check if the result is close to NumPy's inverse
        if check_mat(A_inv, A_inv_numpy, tol=1e-5):
            # Also verify that A * A_inv is close to identity
            identity_check = strassen(A, A_inv)
            if check_mat(identity_check, np.eye(n), tol=1e-4):
                correct_count += 1
            else:
                print(f"  Warning: A * A_inv not close to identity for n={n}")
        else:
            print(f"  Warning: Result differs from NumPy for n={n}")
    except Exception as e:
        print(f"  Error for n={n}: {e}")

print(f'\n{correct_count}/{test_count} correct')
if correct_count == test_count:
    print(f'{GREEN}Success!{ENDCOLOR}')
else:
    print(f'{RED}Test failed!{ENDCOLOR}')


Verifying n = 1
Verifying n = 2
Verifying n = 3
Verifying n = 4
Verifying n = 5
Verifying n = 6
Verifying n = 7
Verifying n = 8
Verifying n = 9
Verifying n = 10
Verifying n = 11
Verifying n = 12
Verifying n = 13
Verifying n = 14
Verifying n = 15
Verifying n = 16
Verifying n = 17
Verifying n = 18
Verifying n = 19
Verifying n = 20
Verifying n = 21
Verifying n = 22
Verifying n = 23
Verifying n = 24
Verifying n = 25
Verifying n = 26
Verifying n = 27
Verifying n = 28
Verifying n = 29
Verifying n = 30
Verifying n = 31
Verifying n = 32
Verifying n = 33
Verifying n = 34
Verifying n = 35
Verifying n = 36
Verifying n = 37
Verifying n = 38
Verifying n = 39
Verifying n = 40
Verifying n = 41
Verifying n = 42
Verifying n = 43
Verifying n = 44
Verifying n = 45
Verifying n = 46
Verifying n = 47
Verifying n = 48
Verifying n = 49
Verifying n = 50

50/50 correct
[92mSuccess![0m
