In [37]:
import numpy as np
import matplotlib.pyplot as plt

from typing import Callable
from numpy import ndarray as Matrix

In [50]:
def multiply_classic(A: Matrix, B: Matrix):
    """
    Multiplies `A` times `B` with a classic algorithm, where 
    `A` is an `m x n` matrix and `B` is an `n x l` matrix
    """

    m, n, l = A.shape[0], A.shape[1], B.shape[1]
    multiply = np.empty((m, l))
    sum = 0
    for i in range(m):                      # rows in multiply
        for j in range(l):                  # columns in multiply
            for k in range(n):              # columns in A and rows in B
                sum += A[i, k] * B[k, j]
            multiply[i, j] = sum
            sum = 0

    return multiply

In [None]:
def multiply_strassen(A: Matrix, B: Matrix):
    pass

In [51]:
def test_multiplication(multiplication_algorithm: Callable, A: Matrix | None = None, B: Matrix | None = None):

    A = np.reshape(np.arange(1, 4 * 3 + 1), (4, 3)) if not A else A
    B = np.reshape(np.arange(1, 3 * 5 + 1), (3, 5)) if not B else B
    AB = multiplication_algorithm(A, B)
    AB_numpy = A @ B

    print(f"Matrix A:\n{A}")
    print(f"\nMatrix B:\n{B}")
    print(f"\nMatrix AxB:\n{AB}")
    print(f"\nMatrix AxB (numpy):\n{AB_numpy}")

    assert np.array_equal(AB, AB_numpy), "Results are inconsistent! Check your algorithm!"

    print("\nTest Passed!")

In [52]:
test_multiplication(multiply_classic)


Matrix A:
[[ 1  2  3]
 [ 4  5  6]
 [ 7  8  9]
 [10 11 12]]

Matrix B:
[[ 1  2  3  4  5]
 [ 6  7  8  9 10]
 [11 12 13 14 15]]

Matrix AxB:
[[ 46.  52.  58.  64.  70.]
 [100. 115. 130. 145. 160.]
 [154. 178. 202. 226. 250.]
 [208. 241. 274. 307. 340.]]

Matrix AxB (numpy):
[[ 46  52  58  64  70]
 [100 115 130 145 160]
 [154 178 202 226 250]
 [208 241 274 307 340]]

Test Passed!
