In [87]:
import numpy as np
import time

# Problem size: a power of two, so every recursive split yields equal halves.
# A fixed seed makes the operands (and thus every printed C[0, 0]) reproducible
# under Restart & Run All.
SEED = 0
n = 2 ** 5
rng = np.random.default_rng(SEED)
A = rng.standard_normal((n, n))
B = rng.standard_normal((n, n))


In [100]:
# O(n^3)
def multiply(n, A, B):
    """Multiply two square matrices by naive divide and conquer.

    Each operand is split into four quadrants and the product is assembled
    from 8 half-size recursive products:
    T(n) = 8 T(n/2) + O(n^2) = O(n^3).

    Parameters
    ----------
    n : int
        Side length of the square operands. It is not checked against
        ``A.shape`` / ``B.shape``; callers must pass matching ``n x n`` arrays.
    A, B : np.ndarray
        ``n x n`` 2-D arrays.

    Returns
    -------
    np.ndarray
        The ``n x n`` product ``A @ B``. Any ``n >= 0`` is supported; odd
        sizes are handled by zero padding.
    """
    # base case O(1); for n == 0 this yields the empty (0, 0) product
    if n <= 1:
        return A * B

    # Odd n would give unequal quadrants (and the 1x1 base case would then
    # broadcast instead of multiply). Pad with one zero row/column:
    # [[A, 0], [0, 0]] @ [[B, 0], [0, 0]] == [[A @ B, 0], [0, 0]].
    if n % 2:
        A_pad = np.pad(A, ((0, 1), (0, 1)))
        B_pad = np.pad(B, ((0, 1), (0, 1)))
        return multiply(n + 1, A_pad, B_pad)[:n, :n]

    # recursive case: divide into h x h quadrants
    h = n // 2
    A11, A12 = A[:h, :h], A[:h, h:]
    A21, A22 = A[h:, :h], A[h:, h:]
    B11, B12 = B[:h, :h], B[:h, h:]
    B21, B22 = B[h:, :h], B[h:, h:]

    # conquer 8*T(n/2), then add/combine O(n^2): C_ij = A_i1 B_1j + A_i2 B_2j
    C11 = multiply(h, A11, B11) + multiply(h, A12, B21)
    C12 = multiply(h, A11, B12) + multiply(h, A12, B22)
    C21 = multiply(h, A21, B11) + multiply(h, A22, B21)
    C22 = multiply(h, A21, B12) + multiply(h, A22, B22)
    return np.block([[C11, C12], [C21, C22]])


# O(n^log2(7)) ~ O(n^2.807)
def strassen(n, A, B):
    """Multiply two square matrices with Strassen's algorithm.

    Each level forms 7 (instead of 8) half-size products from sums and
    differences of the quadrants:
    T(n) = 7 T(n/2) + O(n^2) = O(n^log2 7) ~ O(n^2.807).

    Parameters
    ----------
    n : int
        Side length of the square operands. It is not checked against
        ``A.shape`` / ``B.shape``; callers must pass matching ``n x n`` arrays.
    A, B : np.ndarray
        ``n x n`` 2-D arrays.

    Returns
    -------
    np.ndarray
        The ``n x n`` product ``A @ B``, up to floating-point rounding.
        Any ``n >= 0`` is supported; odd sizes are handled by zero padding.
    """
    # base case; for n == 0 this yields the empty (0, 0) product
    if n <= 1:
        return A * B

    # Odd n would give unequal quadrants. Pad with one zero row/column:
    # [[A, 0], [0, 0]] @ [[B, 0], [0, 0]] == [[A @ B, 0], [0, 0]].
    if n % 2:
        A_pad = np.pad(A, ((0, 1), (0, 1)))
        B_pad = np.pad(B, ((0, 1), (0, 1)))
        return strassen(n + 1, A_pad, B_pad)[:n, :n]

    # divide into h x h quadrants
    h = n // 2
    A11, A12 = A[:h, :h], A[:h, h:]
    A21, A22 = A[h:, :h], A[h:, h:]
    B11, B12 = B[:h, :h], B[:h, h:]
    B21, B22 = B[h:, :h], B[h:, h:]

    # Conquer with 7*T(n/2). Recurse into strassen itself, not the 8-way
    # multiply, so that every level (not just the top one) saves a product.
    # intuition: ac + ad + bc + bd = (a+b)(c+d) -- one product of sums
    # stands in for several products.
    M1 = strassen(h, A11 + A22, B11 + B22)
    M2 = strassen(h, A21 + A22, B11)
    M3 = strassen(h, A11, B12 - B22)
    M4 = strassen(h, A22, B21 - B11)
    M5 = strassen(h, A11 + A12, B22)
    M6 = strassen(h, A21 - A11, B11 + B12)
    M7 = strassen(h, A12 - A22, B21 + B22)

    # add/combine O(n^2)
    C11 = M1 + M4 - M5 + M7
    C12 = M3 + M5
    C21 = M2 + M4
    C22 = M1 - M2 + M3 + M6
    return np.block([[C11, C12], [C21, C22]])

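# Comparison

Each cell below computes $C = AB$ for the same random `A` and `B`, prints `C[0, 0]` as a quick cross-check, and then prints the elapsed wall-clock time in seconds.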


## Brute Force


In [101]:
tik = time.perf_counter()

# Textbook triple loop: C[i, j] = sum_k A[i, k] * B[k, j].
# Accumulate in a local scalar rather than re-indexing C on every k step;
# np.zeros already initialises C, so no per-cell reset is needed.
C = np.zeros((n, n))
for i in range(n):
    for j in range(n):
        acc = 0.0
        for k in range(n):
            acc += A[i, k] * B[k, j]
        C[i, j] = acc
print(C[0, 0])

tok = time.perf_counter()
print(tok - tik)

# Validate the whole product (outside the timed region), not just C[0, 0].
assert np.allclose(C, A @ B), "brute force disagrees with A @ B"


1.5390452627490956
0.018070875000375963


## Divide & Conquer


In [102]:
tik = time.perf_counter()

C_dc = multiply(n, A, B)
print(C_dc[0, 0])

tok = time.perf_counter()
print(tok - tik)

# Check the full product against NumPy (outside the timed region);
# comparing only [0, 0] would miss errors in the other quadrants.
assert np.allclose(C_dc, A @ B), "divide & conquer disagrees with A @ B"


1.5390452627490956
0.0792071659998328


## Strassen's algorithm


In [103]:
tik = time.perf_counter()

C_strassen = strassen(n, A, B)
print(C_strassen[0, 0])

tok = time.perf_counter()
print(tok - tik)

# Check the full product against NumPy (outside the timed region);
# comparing only [0, 0] would miss errors in the other quadrants.
assert np.allclose(C_strassen, A @ B), "Strassen disagrees with A @ B"


1.539045262749096
0.07124979100080964


## NumPy (`np.dot`)


In [104]:
# Reference timing: NumPy's built-in dot product on the same operands.
start = time.perf_counter()
product = np.dot(A, B)
print(product[0, 0])
elapsed = time.perf_counter() - start
print(elapsed)


1.5390452627490956
0.0018986249997396953


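![Asymptotic complexity of matrix multiplication algorithms](res/MatrixMultComplexity.jpeg)

Note: at $n = 32$ both recursive versions ran slower than the plain triple loop. This is presumably because every recursion level pays Python call and temporary-array overhead. The 8-way divide and conquer is still $O(n^3)$, and Strassen's lower exponent ($\log_2 7 \approx 2.807$) only pays off for much larger $n$. NumPy is fastest by about an order of magnitude.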