In [1]:
def initMatrix(n):
    """Return a new n x n matrix (a list of row lists) filled with zeros.

    Every row is a separate list object, so writing to one row never
    affects another. A non-positive n yields an empty list.
    """
    rows = []
    for _ in range(n):
        rows.append([0] * n)
    return rows


def add(M1, M2, n):
    """Return a new n x n matrix holding the element-wise sum M1 + M2.

    Only the top-left n x n entries of each operand are read.
    """
    return [[M1[r][c] + M2[r][c] for c in range(n)] for r in range(n)]


def subtract(M1, M2, n):
    """Return a new n x n matrix holding the element-wise difference M1 - M2.

    Only the top-left n x n entries of each operand are read.
    """
    diff = []
    for r in range(n):
        row1, row2 = M1[r], M2[r]
        diff.append([row1[c] - row2[c] for c in range(n)])
    return diff


In [15]:
def strassen(A, B, n):
    """Multiply two n x n matrices using Strassen's algorithm.

    Args:
        A, B: matrices given as sequences of rows; only the top-left
            n x n part of each is read.
        n: size of the square matrices. Any non-negative size is accepted.
            When n is not a power of two, the inputs are zero-padded to the
            next power of two and the product is cropped back to n x n.

    Returns:
        A new n x n list of lists holding A * B ([] when n == 0).
    """
    if n == 0:
        return []

    # Base case: a 1 x 1 product is a single scalar multiplication.
    if n == 1:
        C = initMatrix(1)
        C[0][0] = A[0][0] * B[0][0]
        return C

    # The quadrant split below needs an even size at every recursion level.
    # For an n that is not a power of two, halving would silently drop the
    # last row/column, so zero-pad up to the next power of two and crop.
    if n & (n - 1):
        size = 1 << (n - 1).bit_length()
        A_pad = initMatrix(size)
        B_pad = initMatrix(size)
        for i in range(n):
            for j in range(n):
                A_pad[i][j] = A[i][j]
                B_pad[i][j] = B[i][j]
        C_pad = strassen(A_pad, B_pad, size)
        return [row[:n] for row in C_pad[:n]]

    k = n // 2  # quadrant size (for a 2 x 2 matrix each quadrant is 1 x 1)

    # Split A and B into four k x k quadrants.
    A11 = [row[:k] for row in A[:k]]
    A12 = [row[k:n] for row in A[:k]]
    A21 = [row[:k] for row in A[k:n]]
    A22 = [row[k:n] for row in A[k:n]]

    B11 = [row[:k] for row in B[:k]]
    B12 = [row[k:n] for row in B[:k]]
    B21 = [row[:k] for row in B[k:n]]
    B22 = [row[k:n] for row in B[k:n]]

    # Strassen's seven recursive products (instead of the naive eight).
    P1 = strassen(A11, subtract(B12, B22, k), k)
    P2 = strassen(add(A11, A12, k), B22, k)
    P3 = strassen(add(A21, A22, k), B11, k)
    P4 = strassen(A22, subtract(B21, B11, k), k)
    P5 = strassen(add(A11, A22, k), add(B11, B22, k), k)
    P6 = strassen(subtract(A12, A22, k), add(B21, B22, k), k)
    P7 = strassen(subtract(A11, A21, k), add(B11, B12, k), k)

    # Combine the products into the quadrants of C = A * B.
    C11 = subtract(add(add(P5, P4, k), P6, k), P2, k)
    C12 = add(P1, P2, k)
    C21 = add(P3, P4, k)
    C22 = subtract(subtract(add(P5, P1, k), P3, k), P7, k)

    # Stitch the quadrants back into one n x n matrix. The quadrants come
    # from add/subtract, so they are plain lists and `+` concatenates rows.
    top = [left + right for left, right in zip(C11, C12)]
    bottom = [left + right for left, right in zip(C21, C22)]
    return top + bottom


In [16]:

# Sanity check on a small 2 x 2 example.
A = [[1, 3], [7, 5]]
B = [[6, 8], [4, 2]]

result = strassen(A, B, 2)
print(result)

# Expected product, e.g. C[0][0] = 1*6 + 3*4 = 18. Checking it explicitly
# (rather than leaving it as a bare string literal) actually verifies the
# result instead of just echoing the text as cell output.
assert result == [[18, 14], [62, 66]]

# Multiplying two 128 x 128 matrices took roughly 8 seconds in the
# author's run.


[[18, 14], [62, 66]]


'\n[[18, 14],\n [62, 66]]\n'

In [29]:
import numpy as np

# Runtime comparison: pure-Python Strassen vs. NumPy's matrix product.
from timeit import default_timer

N = 2 ** 5  # matrix size
rng = np.random.default_rng(0)  # fixed seed so runs are reproducible
# Entries in [1, 1000), same range as np.random.randint(1, 1000).
A = rng.integers(1, 1000, size=(N, N)).tolist()
B = rng.integers(1, 1000, size=(N, N)).tolist()

start = default_timer()
result1 = strassen(A, B, N)
print(f"{default_timer() - start:.5f}s")

start = default_timer()
result2 = np.dot(A, B)
print(f"{default_timer() - start:.5f}s")

# A timing comparison is only meaningful if both results agree.
assert np.array_equal(result1, result2)

0.13163s
0.00041s
