In [1]:
import numpy
import math
import matplotlib.pyplot
import time

In [2]:
def mx_xpanshn(M1,M2):
    """Zero-pad two matrix operands to a common power-of-two square size.

    Parameters
    ----------
    M1 : numpy.ndarray
        Left operand, shape (m, k).
    M2 : numpy.ndarray
        Right operand, shape (k, p). An AssertionError is raised if its row
        count differs from M1's column count.

    Returns
    -------
    tuple of numpy.ndarray
        ``(nM1, nM2)``: float64 ``d x d`` arrays.  Each holds its operand in
        the top-left corner and zeros elsewhere.  ``d`` is
        ``2**ceil(log2(largest dimension of M1 and M2))``.
    """
    assert M1.shape[1] == M2.shape[0]

    # Side of the padded square: smallest power of two covering every dimension.
    largest = max(*M1.shape, *M2.shape)
    side = numpy.int32(2 ** numpy.ceil(numpy.log2(largest)))

    padded = []
    for M in (M1, M2):
        square = numpy.zeros((side, side))
        square[:M.shape[0], :M.shape[1]] = M
        padded.append(square)

    return tuple(padded)

In [3]:
def Strassen(A,B):
    """Multiply two matrices with Strassen's recursive algorithm.

    Unless all four dimensions are equal and a power of two, both operands
    are first zero-padded to a common power-of-two square by ``mx_xpanshn``,
    which asserts that the inner dimensions agree.  The padding is cropped
    from the result.  Each level forms 7 half-size products (P1..P7) instead
    of the 8 used by the schoolbook block method, recursing down to 1x1 blocks.

    Parameters
    ----------
    A : numpy.ndarray
        Left operand, shape (m, k).
    B : numpy.ndarray
        Right operand, shape (k, p).

    Returns
    -------
    numpy.ndarray
        The product ``A @ B``, shape (m, p).
    """
    dims = [A.shape[0], A.shape[1], B.shape[0], B.shape[1]]

    # Quadrant splitting is only valid for equal, power-of-two square operands.
    # Note the ``== 1``: a bare ``len(set(dims))`` is always truthy.  Without it,
    # e.g. (4, 3) @ (3, 5) went through unpadded and the mismatched quadrants
    # produced a broadcasting error or a wrong product.
    side = dims[0]
    is_pow2_square = len(set(dims)) == 1 and side > 0 and (side & (side - 1)) == 0
    if not is_pow2_square:
        A,B = mx_xpanshn(A,B)

    N = A.shape[0]

    # Base case.  N == 0 only arises for a 0x0 @ 0x0 product; stopping at
    # N <= 1 (not just N == 1) prevents infinite recursion on it.  Cropping
    # keeps the output shape right whatever the padding was.
    if N <= 1:
        return (A*B)[:dims[0], :dims[3]]

    n = N//2

    # Quadrants of both operands (views, no copies).
    A11, A12 = A[:n, :n], A[:n, n:]
    A21, A22 = A[n:, :n], A[n:, n:]
    B11, B12 = B[:n, :n], B[:n, n:]
    B21, B22 = B[n:, :n], B[n:, n:]

    # Strassen's seven half-size products.
    P1 = Strassen(A11 + A22, B11 + B22)
    P2 = Strassen(A21 + A22, B11)
    P3 = Strassen(A11, B12 - B22)
    P4 = Strassen(A22, B21 - B11)
    P5 = Strassen(A11 + A12, B22)
    P6 = Strassen(A21 - A11, B11 + B12)
    P7 = Strassen(A12 - A22, B21 + B22)

    # Recombine the products into the four quadrants of the result.
    C = numpy.zeros((N,N))
    C[:n,:n] = P1 + P4 - P5 + P7
    C[:n,n:] = P3 + P5
    C[n:,:n] = P2 + P4
    C[n:,n:] = P1 - P2 + P3 + P6

    # Drop the zero padding added by mx_xpanshn (a no-op when none was added).
    return C[:dims[0], :dims[3]]


In [4]:
def multipl_std(A,B):
    """Multiply two matrices with the standard (schoolbook) recursive block method.

    Unless all four dimensions are equal and a power of two, both operands
    are first zero-padded to a common power-of-two square by ``mx_xpanshn``,
    which asserts that the inner dimensions agree.  The padding is cropped
    from the result.  Each level computes all 8 half-size block products,
    recursing down to 1x1 blocks.

    Parameters
    ----------
    A : numpy.ndarray
        Left operand, shape (m, k).
    B : numpy.ndarray
        Right operand, shape (k, p).

    Returns
    -------
    numpy.ndarray
        The product ``A @ B``, shape (m, p).
    """
    dims = [A.shape[0], A.shape[1], B.shape[0], B.shape[1]]

    # Quadrant splitting is only valid for equal, power-of-two square operands.
    # Note the ``== 1``: a bare ``len(set(dims))`` is always truthy.  Without it,
    # e.g. (4, 3) @ (3, 5) went through unpadded and the mismatched quadrants
    # produced a broadcasting error or a wrong product.
    side = dims[0]
    is_pow2_square = len(set(dims)) == 1 and side > 0 and (side & (side - 1)) == 0
    if not is_pow2_square:
        A,B = mx_xpanshn(A,B)

    N = A.shape[0]

    # Base case.  N == 0 only arises for a 0x0 @ 0x0 product; stopping at
    # N <= 1 (not just N == 1) prevents infinite recursion on it.  Cropping
    # keeps the output shape right whatever the padding was.
    if N <= 1:
        return (A*B)[:dims[0], :dims[3]]

    n = N//2

    # Quadrants of both operands (views, no copies).
    A11, A12 = A[:n, :n], A[:n, n:]
    A21, A22 = A[n:, :n], A[n:, n:]
    B11, B12 = B[:n, :n], B[:n, n:]
    B21, B22 = B[n:, :n], B[n:, n:]

    # C_ij = A_i1 @ B_1j + A_i2 @ B_2j  (eight recursive products).
    C = numpy.zeros((N,N))
    C[:n,:n] = multipl_std(A11, B11) + multipl_std(A12, B21)
    C[:n,n:] = multipl_std(A11, B12) + multipl_std(A12, B22)
    C[n:,:n] = multipl_std(A21, B11) + multipl_std(A22, B21)
    C[n:,n:] = multipl_std(A21, B12) + multipl_std(A22, B22)

    # Drop the zero padding added by mx_xpanshn (a no-op when none was added).
    return C[:dims[0], :dims[3]]

In [5]:
# Seed the global NumPy generator so Restart & Run All reproduces the same
# demo matrices (this cell and the next one both draw from it).
numpy.random.seed(0)
# Demo left operand: 3x2 with uniform [0, 1) entries.
A = numpy.random.rand(3,2)

In [6]:
# Demo right operand: 2x3 with uniform [0, 1) entries (same generator as numpy.random.rand).
B = numpy.random.random_sample((2, 3))

In [7]:
# Reference result from NumPy's built-in matrix product.
numpy.matmul(A, B)

array([[1.34963848, 0.12465513, 0.39732044],
       [0.76595969, 0.05972673, 0.39785614],
       [0.87135294, 0.06830355, 0.44698952]])

In [8]:
# Check Strassen against NumPy's product explicitly instead of comparing printed arrays by eye.
C_strassen = Strassen(A, B)
numpy.testing.assert_allclose(C_strassen, A @ B)
C_strassen

array([[1.34963848, 0.12465513, 0.39732044],
       [0.76595969, 0.05972673, 0.39785614],
       [0.87135294, 0.06830355, 0.44698952]])

In [9]:
# Check the standard recursive method against NumPy's product explicitly.
C_std = multipl_std(A, B)
numpy.testing.assert_allclose(C_std, A @ B)
C_std

array([[1.34963848, 0.12465513, 0.39732044],
       [0.76595969, 0.05972673, 0.39785614],
       [0.87135294, 0.06830355, 0.44698952]])

In [None]:
# Largest matrix side to benchmark; sides 2, 4, ..., N are timed.
# (The loop used to run i over range(1, N), i.e. 2**i up to 2**1023, which
# could never complete and overflowed the int64 x-axis.)
# Both algorithms recurse to 1x1 blocks in pure Python, so each doubling of
# the side costs roughly 7x (Strassen) / 8x (standard) more time; expect very
# long runs at the largest sizes and lower N for a quick check.
N = 1024
sizes = 2 ** numpy.arange(1, int(math.log2(N)) + 1)

Strassen_timeS = []
multipl_std_timeS = []


def _timed(func, X, Y):
    """Return ``(func(X, Y), elapsed_seconds)``, timed with a monotonic high-resolution clock."""
    start = time.perf_counter()
    result = func(X, Y)
    return result, time.perf_counter() - start


process_start = time.perf_counter()

for size in sizes:
    randA = numpy.random.rand(size, size)
    randB = numpy.random.rand(size, size)

    res_strassen, strassen_time = _timed(Strassen, randA, randB)
    res_std, multipl_std_time = _timed(multipl_std, randA, randB)

    # Timings of a wrong product are meaningless, so check both against NumPy.
    # rtol is loose because Strassen's extra additions/subtractions accumulate
    # more rounding error than the standard method.
    expected = randA @ randB
    numpy.testing.assert_allclose(res_strassen, expected, rtol=1e-6)
    numpy.testing.assert_allclose(res_std, expected, rtol=1e-6)

    Strassen_timeS.append(strassen_time)
    multipl_std_timeS.append(multipl_std_time)

total_process_time = time.perf_counter() - process_start

fig, ax = matplotlib.pyplot.subplots()
ax.plot(sizes, Strassen_timeS, marker='o')
ax.plot(sizes, multipl_std_timeS, marker='o')
# Sizes double each step and run time grows polynomially, so log-log axes
# show the two growth rates as roughly straight lines.
ax.set_xscale('log')
ax.set_yscale('log')
ax.set(xlabel='matrix side N', ylabel='time, s',
       title='Recursive matrix multiplication: Strassen vs standard')
ax.legend(['Strassen','standard mltpl'])
matplotlib.pyplot.show()

# Total wall-clock time of the benchmark loop, split into hours / minutes / seconds.
tp_mnts_all, tp_scnds = divmod(total_process_time, 60)
tp_hrs, tp_mnts = divmod(tp_mnts_all, 60)
print("Расчеты заняли всего лишь", int(tp_hrs), "часов", int(tp_mnts), "минут", int(tp_scnds), "секунд!" )
    
