In [99]:
import random
import time
from functools import reduce

In [100]:
def get_matrix(m, n, type='random', low=-50, high=50):
    """Build an m x n matrix represented as a list of row lists.

    Parameters
    ----------
    m, n : int
        Number of rows and columns.
    type : str
        'random' fills every entry with a uniformly drawn integer from the
        closed interval [low, high]; 'zero' fills every entry with 0.
    low, high : int
        Inclusive bounds for 'random' entries (ignored for 'zero').

    Returns
    -------
    list[list[int]]
        A new matrix; every row is a distinct list object.

    Raises
    ------
    ValueError
        If `type` is neither 'random' nor 'zero' (previously this silently
        returned an empty list, which only failed later, far from the cause).
    """
    if type == 'random':
        # random.randint includes BOTH endpoints, so the bound is used as-is.
        # The old call randint(-50, 51) yielded the asymmetric range [-50, 51].
        return [[random.randint(low, high) for _ in range(n)] for _ in range(m)]
    if type == 'zero':
        # [0] * n is safe here: ints are immutable and each row is rebuilt.
        return [[0] * n for _ in range(m)]
    raise ValueError("unknown matrix type %r (expected 'random' or 'zero')" % (type,))


In [101]:
def matrix_dot(A, B):
    """Return the product A @ B using the schoolbook triple loop.

    A is m x p and B is p x n, both given as lists of row lists. The shape
    is read from len(A), len(A[0]) and len(B[0]). The result is a new
    m x n list of lists.
    """
    rows, inner, cols = len(A), len(A[0]), len(B[0])
    product = []
    for i in range(rows):
        a_row = A[i]
        product.append(
            [sum(a_row[k] * B[k][j] for k in range(inner)) for j in range(cols)]
        )
    return product


In [102]:
def matrix_split(A):
    """Split an n x n square matrix A into its four quadrants.

    With half = n // 2:
      A11 = rows [0, half),  cols [0, half)
      A12 = rows [0, half),  cols [half, n)
      A21 = rows [half, n),  cols [0, half)
      A22 = rows [half, n),  cols [half, n)
    For odd n the bottom/right quadrants receive the extra row/column.
    Every quadrant row is a fresh slice, so A itself is not aliased.

    Returns the tuple (A11, A12, A21, A22).
    """
    half = len(A) // 2
    top, bottom = A[:half], A[half:]
    A11 = [row[:half] for row in top]
    A12 = [row[half:] for row in top]
    A21 = [row[:half] for row in bottom]
    A22 = [row[half:] for row in bottom]
    return (A11, A12, A21, A22)


In [103]:
def matrix_join(A11, A12, A21, A22):
    """Reassemble four quadrant blocks into one matrix (inverse of matrix_split).

    Row r of the upper half is A11[r] + A12[r] (list concatenation), and
    row r of the lower half is A21[r] + A22[r]. The number of rows in each
    half is taken from A11 and A21 respectively.
    """
    upper = [A11[r] + A12[r] for r in range(len(A11))]
    lower = [A21[r] + A22[r] for r in range(len(A21))]
    return upper + lower



In [104]:
def matrix_add(A, B):
    """Return the element-wise sum A + B as a new list of lists.

    The output shape comes from A (len(A) rows, len(A[0]) columns); B is
    indexed at the same positions and must be at least that large.
    """
    width = len(A[0])
    total = []
    for r, a_row in enumerate(A):
        total.append([a_row[c] + B[r][c] for c in range(width)])
    return total

In [105]:
def matrix_reduce(A, B):
    """Return the element-wise difference A - B as a new list of lists.

    Despite its name, this performs subtraction, not a fold/reduce.
    The output shape comes from A (len(A) rows, len(A[0]) columns); B is
    indexed at the same positions and must be at least that large.
    """
    height, width = len(A), len(A[0])
    return [
        [A[row][col] - B[row][col] for col in range(width)]
        for row in range(height)
    ]



In [106]:
def Strassen(A, B, leaf_size=2):
    """Multiply two matrices with Strassen's divide-and-conquer algorithm.

    Each level replaces the 8 half-size block products of the schoolbook
    method with 7 (M1..M7 below).
    Reference: https://www.cnblogs.com/hdk1993/p/4552534.html

    Parameters
    ----------
    A, B : list[list]
        Operands as lists of row lists. The Strassen recursion is used only
        while both are n x n with even n. Otherwise the product is computed
        directly with matrix_dot, so any shapes matrix_dot accepts work here.
    leaf_size : int, optional
        Sub-problems with n <= leaf_size are multiplied with matrix_dot.
        The default of 2 reproduces the original pure-Strassen recursion.
        Larger values cut the number of recursive calls and temporary
        matrices (likely faster in pure Python; benchmark to confirm).

    Returns
    -------
    list[list]
        The product A @ B as a new list of lists ([] for an empty A).
    """
    n = len(A)
    if n == 0:
        return []
    square_pair = len(A[0]) == n and len(B) == n and len(B[0]) == n
    # Fall back to the classic product:
    #  - at the leaves;
    #  - for odd n, where matrix_split yields quadrants of different sizes
    #    that cannot be added (the old code crashed for n == 1 or any odd n);
    #  - for non-square operands, since matrix_split assumes a square input.
    if n <= leaf_size or n % 2 == 1 or not square_pair:
        return matrix_dot(A, B)

    A11, A12, A21, A22 = matrix_split(A)
    B11, B12, B21, B22 = matrix_split(B)
    # Seven half-size products (matrix_reduce is element-wise subtraction).
    M1 = Strassen(matrix_add(A11, A22), matrix_add(B11, B22), leaf_size)
    M2 = Strassen(matrix_add(A21, A22), B11, leaf_size)
    M3 = Strassen(A11, matrix_reduce(B12, B22), leaf_size)
    M4 = Strassen(A22, matrix_reduce(B21, B11), leaf_size)
    M5 = Strassen(matrix_add(A11, A12), B22, leaf_size)
    M6 = Strassen(matrix_reduce(A21, A11), matrix_add(B11, B12), leaf_size)
    M7 = Strassen(matrix_reduce(A12, A22), matrix_add(B21, B22), leaf_size)
    # Recombine: C11 = M1 + M4 - M5 + M7,  C12 = M3 + M5,
    #            C21 = M2 + M4,            C22 = M1 - M2 + M3 + M6
    C11 = matrix_add(matrix_add(M1, M4), matrix_reduce(M7, M5))
    C12 = matrix_add(M3, M5)
    C21 = matrix_add(M2, M4)
    C22 = matrix_add(matrix_reduce(M1, M2), matrix_add(M3, M6))
    return matrix_join(C11, C12, C21, C22)

In [107]:
# Problem size and random operands. Seeding the global RNG makes the inputs
# (and therefore the timings and the correctness check below) reproducible
# across Restart Kernel -> Run All.
n = 256
random.seed(0)
A = get_matrix(n, n)
B = get_matrix(n, n)

In [108]:
#print(A)
#print(B)
#print(matrix_dot(A,B))

In [109]:
# Time the schoolbook product against Strassen on the same inputs, then check
# that they agree. perf_counter is a monotonic, high-resolution clock intended
# for measuring intervals; time.time() follows the wall clock, which can jump.
t0 = time.perf_counter()
dot_1 = matrix_dot(A, B)
print("matrix_dot: %.3f s" % (time.perf_counter() - t0))
t0 = time.perf_counter()
dot_2 = Strassen(A, B)
print("Strassen:   %.3f s" % (time.perf_counter() - t0))
# Fail loudly on a mismatch instead of silently printing nothing.
assert dot_1 == dot_2, "Strassen result does not match matrix_dot"
print("Correct")


3.7520499229431152
16.087949514389038
Correct
