In [108]:
import numpy as np

from typing import Callable
from numpy import ndarray as Matrix

In [109]:
def multiply_classic(A: Matrix, B: Matrix):
    """
    Multiplies `A` times `B` with a classic (triple-loop) algorithm, where
    `A` is an `m x n` matrix and `B` is an `n x l` matrix.

    Returns a new `m x l` matrix. The result dtype is the promotion of both
    input dtypes with `float64`, so real inputs yield a `float64` result and
    complex inputs keep their imaginary part.

    Raises `ValueError` if either argument is not 2-dimensional or if the
    number of columns in `A` differs from the number of rows in `B`.
    """

    if A.ndim != 2 or B.ndim != 2:
        raise ValueError("Both operands must be 2-dimensional matrices!")
    if A.shape[1] != B.shape[0]:
        # Without this check a `B` with extra rows would be silently truncated.
        raise ValueError(
            f"Shapes {A.shape} and {B.shape} are not aligned for multiplication!"
        )

    m, n, l = A.shape[0], A.shape[1], B.shape[1]
    product = np.empty((m, l), dtype=np.result_type(A.dtype, B.dtype, np.float64))

    for i in range(m):                      # rows in product
        for j in range(l):                  # columns in product
            acc = 0                         # dot product of row i of A and column j of B
            for k in range(n):              # columns in A and rows in B
                acc += A[i, k] * B[k, j]
            product[i, j] = acc

    return product

In [115]:
def multiply_strassen(A: Matrix, B: Matrix):
    """
    Multiplies `A` times `B` with a Strassen algorithm.

    The recursion itself works on square `n x n` matrices where `n` is a
    power of 2. Inputs of that form are used as they are. Any other pair of
    compatible 2-dimensional matrices (`m x n` times `n x l`) is zero-padded
    to the next power-of-2 square size, multiplied, and the `m x l` top-left
    corner of the padded product is returned.

    Raises `ValueError` if either argument is not 2-dimensional or if the
    number of columns in `A` differs from the number of rows in `B`.
    """

    def strassen(A: Matrix, B: Matrix, n: int):
        """Multiplies two `n x n` matrices, `n` being a power of 2."""

        if n == 1:
            return A * B

        m = n // 2

        # Split both operands into four `m x m` quadrants.
        A11, A12 = A[:m, :m], A[:m, m:]
        A21, A22 = A[m:, :m], A[m:, m:]

        B11, B12 = B[:m, :m], B[:m, m:]
        B21, B22 = B[m:, :m], B[m:, m:]

        # Seven recursive products instead of the eight a naive split needs.
        P1 = strassen(A11 + A22, B11 + B22, m)
        P2 = strassen(A21 + A22, B11, m)
        P3 = strassen(A11, B12 - B22, m)
        P4 = strassen(A22, B21 - B11, m)
        P5 = strassen(A11 + A12, B22, m)
        P6 = strassen(A21 - A11, B11 + B12, m)
        P7 = strassen(A12 - A22, B21 + B22, m)

        # Reassemble the four quadrants of the product.
        return np.block([
            [P1 + P4 - P5 + P7, P3 + P5],
            [P2 + P4, P1 - P2 + P3 + P6],
        ])

    if A.ndim != 2 or B.ndim != 2:
        raise ValueError("Both operands must be 2-dimensional matrices!")
    if A.shape[1] != B.shape[0]:
        raise ValueError(
            f"Shapes {A.shape} and {B.shape} are not aligned for multiplication!"
        )

    rows, inner, cols = A.shape[0], A.shape[1], B.shape[1]

    # An empty dimension would make the recursion never reach its base case.
    if rows == 0 or inner == 0 or cols == 0:
        return np.zeros((rows, cols), dtype=np.result_type(A.dtype, B.dtype))

    # Smallest power of 2 that fits every dimension.
    size = 1 << (max(rows, inner, cols) - 1).bit_length()

    if A.shape == (size, size) and B.shape == (size, size):
        return strassen(A, B, size)

    # Zero padding leaves the top-left `rows x cols` corner of the product intact.
    A_padded = np.zeros((size, size), dtype=A.dtype)
    A_padded[:rows, :inner] = A
    B_padded = np.zeros((size, size), dtype=B.dtype)
    B_padded[:inner, :cols] = B

    return strassen(A_padded, B_padded, size)[:rows, :cols]


In [128]:
def test_multiplication(multiplication_algorithm: Callable, A: Matrix | None = None, B: Matrix | None = None):

    print(f"=== Testing {multiplication_algorithm.__name__} algorithm ===")

    A = np.reshape(np.arange(1, 4 * 4 + 1), (4, 4)) if not A else A
    B = np.reshape(np.arange(17, 4 * 4 + 17), (4, 4)) if not B else B

    assert A.shape == B.shape and A.shape[0] == A.shape[1], "Matrices are not square!"
    assert np.isclose(np.log2(A.shape[0]) % 1, 0), "Matrices size is not a power of 2!"
    
    AB = multiplication_algorithm(A, B)
    AB_numpy = A @ B

    print(f"\nMatrix A:\n{A}")
    print(f"\nMatrix B:\n{B}")
    print(f"\nMatrix AxB:\n{AB}")
    print(f"\nMatrix AxB (numpy):\n{AB_numpy}")

    assert np.array_equal(AB, AB_numpy), "Results are inconsistent! Check your algorithm!"

    print("\nTest Passed!")

In [129]:
# Check the triple-loop implementation against numpy's `@` on the default 4 x 4 matrices.
test_multiplication(multiply_classic)

=== Testing multiply_classic algorithm ===

Matrix A:
[[ 1  2  3  4]
 [ 5  6  7  8]
 [ 9 10 11 12]
 [13 14 15 16]]

Matrix B:
[[17 18 19 20]
 [21 22 23 24]
 [25 26 27 28]
 [29 30 31 32]]

Matrix AxB:
[[ 250.  260.  270.  280.]
 [ 618.  644.  670.  696.]
 [ 986. 1028. 1070. 1112.]
 [1354. 1412. 1470. 1528.]]

Matrix AxB (numpy):
[[ 250  260  270  280]
 [ 618  644  670  696]
 [ 986 1028 1070 1112]
 [1354 1412 1470 1528]]

Test Passed!


In [130]:
# Check the Strassen implementation against numpy's `@` on the default 4 x 4 matrices.
test_multiplication(multiply_strassen)

=== Testing multiply_strassen algorithm ===

Matrix A:
[[ 1  2  3  4]
 [ 5  6  7  8]
 [ 9 10 11 12]
 [13 14 15 16]]

Matrix B:
[[17 18 19 20]
 [21 22 23 24]
 [25 26 27 28]
 [29 30 31 32]]

Matrix AxB:
[[ 250  260  270  280]
 [ 618  644  670  696]
 [ 986 1028 1070 1112]
 [1354 1412 1470 1528]]

Matrix AxB (numpy):
[[ 250  260  270  280]
 [ 618  644  670  696]
 [ 986 1028 1070 1112]
 [1354 1412 1470 1528]]

Test Passed!
