# In [4]:
import numpy as np

def _strassen_square(A, B):
    """Multiply two square arrays of equal power-of-two side with Strassen.

    Callers must guarantee that ``A`` and ``B`` are 2-D NumPy arrays with the
    same shape ``(n, n)`` where ``n`` is a power of two; no checks are done
    here so the recursion stays cheap.
    """
    n = A.shape[0]

    # Base case: a 1x1 block is just a scalar product.
    if n == 1:
        return A * B

    mid = n // 2

    # Split both operands into four equally sized quadrants.
    A11, A12 = A[:mid, :mid], A[:mid, mid:]
    A21, A22 = A[mid:, :mid], A[mid:, mid:]

    B11, B12 = B[:mid, :mid], B[:mid, mid:]
    B21, B22 = B[mid:, :mid], B[mid:, mid:]

    # The seven Strassen products (instead of the eight of the naive scheme).
    P = _strassen_square(A11 + A22, B11 + B22)
    Q = _strassen_square(A21 + A22, B11)
    R = _strassen_square(A11, B12 - B22)
    S = _strassen_square(A22, B21 - B11)
    T = _strassen_square(A11 + A12, B22)
    U = _strassen_square(A21 - A11, B11 + B12)
    V = _strassen_square(A12 - A22, B21 + B22)

    # Combine the products into the quadrants of the result.
    C11 = P + S - T + V
    C12 = R + T
    C21 = Q + S
    C22 = P + R - Q + U

    return np.vstack([
        np.hstack([C11, C12]),
        np.hstack([C21, C22]),
    ])


def strassen(A, B):
    """Return the matrix product ``A @ B`` computed with Strassen's algorithm.

    Parameters
    ----------
    A : array_like
        Left operand, a 2-D matrix of shape ``(m, k)``.
    B : array_like
        Right operand, a 2-D matrix of shape ``(k, p)``.

    Returns
    -------
    numpy.ndarray
        The ``(m, p)`` product.

    Raises
    ------
    ValueError
        If either operand is not 2-D, or the inner dimensions differ.

    Notes
    -----
    Strassen's recursion needs square operands whose side is a power of two.
    Inputs of any other shape are zero-padded up to the next power-of-two
    square, multiplied, and the padding is cropped from the result. Zero
    padding does not change the product of the original entries.
    """
    A = np.asarray(A)
    B = np.asarray(B)

    if A.ndim != 2 or B.ndim != 2:
        raise ValueError(
            f"strassen expects 2-D matrices, got {A.ndim}-D and {B.ndim}-D inputs"
        )
    if A.shape[1] != B.shape[0]:
        raise ValueError(
            f"shapes {A.shape} and {B.shape} are not aligned for multiplication"
        )

    rows, inner = A.shape
    cols = B.shape[1]
    dtype = np.result_type(A, B)

    # An empty dimension means an empty (or all-zero) product; returning here
    # also avoids recursing forever on a 0-sized block.
    if rows == 0 or inner == 0 or cols == 0:
        return np.zeros((rows, cols), dtype=dtype)

    # Smallest power of two that is >= every dimension involved.
    size = 1 << (max(rows, inner, cols) - 1).bit_length()

    # Fast path: already square with a power-of-two side, no padding needed.
    if A.shape == (size, size) and B.shape == (size, size):
        return _strassen_square(A, B)

    padded_a = np.zeros((size, size), dtype=dtype)
    padded_b = np.zeros((size, size), dtype=dtype)
    padded_a[:rows, :inner] = A
    padded_b[:inner, :cols] = B

    # Crop the rows/columns that only exist because of the padding.
    return _strassen_square(padded_a, padded_b)[:rows, :cols]


# -------- MAIN --------
# Small 2x2 worked example: multiply A by B with strassen() and print the
# resulting matrix.
A = [
    [1, 2],
    [3, 4],
]

B = [
    [5, 6],
    [7, 8],
]

# The header is printed first, then the product matrix on the following lines.
print("Result:")
print(strassen(A, B))


# Output:
# Result:
# [[19 22]
#  [43 50]]
