In [4]:
import numpy as np
from copy import copy

In [519]:
def naive_mult(a, b):
    
    n_rows = a.shape[0]
    n_cols = b.shape[1]
    res = np.zeros((n_rows, n_cols))
    
    for i in range(n_rows):
        for j in range(n_cols):
            for k in range( a.shape[1] ):
                res[i, j] += a[i, k] * b[k, j]
    
    return res    

class Split:
    def init(self, i=0, j=0, step=0):
        self.i = i
        self.j = j
        self.step = step
    
    def isscalar(self):
        return self.step == 1
        
    def get_start(self):
        return self.i, self.j

    def to_str(self):
        return f"{self.i, self.j}, {self.step}"
    
def split_matrix(split):
    step = split.step // 2
    
    top_left = Split()
    top_left.i = split.i
    top_left.j = split.j
    top_left.step = step

    top_right = Split()
    top_right.i = split.i
    top_right.j = split.j + step
    top_right.step = step
    
    bottom_left = Split()
    bottom_left.i = split.i + step
    bottom_left.j = split.j
    bottom_left.step = step
    
    bottom_right = Split()
    bottom_right.i = split.i + step
    bottom_right.j = split.j + step
    bottom_right.step = step
    
    return np.array([[top_left, top_right], [bottom_left, bottom_right]])
              

def recursive_mult(M_1, M_2, A_idx=None, B_idx=None):
    n = min(M_1.shape[0], M_2.shape[0])
    
    if A_idx is None:
        A_idx = Split()
        A_idx.init()
        A_idx.step = n
    
    if B_idx is None:
        B_idx = Split()
        B_idx.init()
        B_idx.step = n
        
    def rMult(A_idx, B_idx):
        if A_idx.isscalar():
            C = np.array([[M_1[A_idx.get_start()] * M_2[B_idx.get_start()]]])
        else:
            new_A_split = split_matrix(A_idx)
            new_B_split = split_matrix(B_idx)

            C_tl = rMult(new_A_split[0, 0], new_B_split[0, 0]) + rMult(new_A_split[0, 1], new_B_split[1, 0])
            C_tr = rMult(new_A_split[0, 0], new_B_split[0, 1]) + rMult(new_A_split[0, 1], new_B_split[1, 1])
            C_bl = rMult(new_A_split[1, 0], new_B_split[0, 0]) + rMult(new_A_split[1, 1], new_B_split[1, 0])
            C_br = rMult(new_A_split[1, 0], new_B_split[0, 1]) + rMult(new_A_split[1, 1], new_B_split[1, 1])

            C = build_matrix([C_tl, C_tr, C_bl, C_br])
        return C

    return rMult(A_idx, B_idx)

   
def strassen(A, B):
    
    def build_matrix(submatrices):
        C_tl, C_tr, C_bl, C_br = submatrices
        n = C_tl.shape[0]
        new_n = n*2

        new_C = np.zeros( (new_n, new_n) )
        for i in range(n):
            for j in range(n):
                new_C[i, j] = C_tl[i, j]
                new_C[i, j + n] = C_tr[i, j]
                new_C[i + n, j] = C_bl[i, j]
                new_C[i + n, j + n] = C_br[i, j]

        return new_C

    def add(M, A_split, B_split):
        step = A_split.step

        res = np.zeros( (step, step) )
        for i in range(step):
            for j in range(step):
                res[i, j] = M[A_split.i + i, A_split.j + j] + M[B_split.i + i, B_split.j + j]

        return res

    def substract(M, A_split, B_split):
        step = A_split.step

        res = np.zeros( (step, step) )
        for i in range(step):
            for j in range(step):
                res[i, j] = M[A_split.i + i, A_split.j + j] - M[B_split.i + i, B_split.j + j]

        return res
    
    n = A.shape[0]

    A_s = Split()
    A_s.init()
    A_s.step = n
    
    B_s = Split()
    B_s.init()
    B_s.step = n
    
    A_split = split_matrix(A_s)
    B_split = split_matrix(B_s)
    
    S_1 = substract(copy(B), B_split[0, 1], B_split[1, 1])
    
    S_2 = add(copy(A), A_split[0, 0], A_split[0, 1])
    
    S_3 = add(copy(A), A_split[1, 0], A_split[1, 1])
    
    S_4 = substract(copy(B), B_split[1, 0], B_split[0, 0])
    
    S_5 = add(copy(A), A_split[0, 0], A_split[1, 1])
    
    S_6 = add(copy(B), B_split[0, 0], B_split[1, 1])
    
    S_7 = substract(copy(A), A_split[0, 1], A_split[1, 1])
    
    S_8 = add(copy(B), B_split[1, 0], B_split[1, 1])
    
    S_9 = substract(copy(A), A_split[0, 0], A_split[1, 0])
    
    S_10 = add(copy(B), B_split[0, 0], B_split[0, 1])
       
    """
    P_1 = rMult(A_split[0, 0], B_split[0, 1]) - rMult(A_split[0, 0], B_split[1, 1])
    
    P_2 = rMult(A_split[0, 0], B_split[1, 1]) + rMult(A_split[0, 1], B_split[1, 1])
    
    P_3 = rMult(A_split[1, 0], B_split[0, 0]) + rMult(A_split[1, 1], B_split[0, 0])
    
    P_4 = rMult(A_split[1, 1], B_split[1, 0]) - rMult(A_split[1, 1], B_split[0, 0])
    
    P_5 = rMult(A_split[0, 0], B_split[0, 0]) + rMult(A_split[0, 0], B_split[1, 1])
    P_5 += rMult(A_split[1, 1], B_split[0, 0]) + rMult(A_split[1, 1], B_split[1, 1])
    
    P_6 = rMult(A_split[0, 1], B_split[1, 0]) + rMult(A_split[0, 1], B_split[1, 1])
    P_6 += - rMult(A_split[1, 1], B_split[1, 0]) - rMult(A_split[1, 1], B_split[1, 1])
    
    P_7 = rMult(A_split[0, 0], B_split[0, 0]) + rMult(A_split[0, 0], B_split[0, 1])
    P_7 += - rMult(A_split[1, 0], B_split[0, 0]) - rMult(A_split[1, 0], B_split[0, 1])
    """
    
    P_1 = recursive_mult(copy(A), copy(S_1), A_split[0, 0], None)
    P_2 = recursive_mult(copy(S_2), copy(B), None, B_split[1, 1])
    P_3 = recursive_mult(copy(S_3), copy(B), None, B_split[0, 0])
    P_4 = recursive_mult(copy(A), copy(S_4), A_split[1, 1], None)
    P_5 = recursive_mult(copy(S_5), copy(S_6))
    P_6 = recursive_mult(copy(S_7), copy(S_8))
    P_7 = recursive_mult(copy(S_9), copy(S_10))
    
    C_11 = P_5 + P_4 - P_2 + P_6
    C_12 = P_1 + P_2
    C_21 = P_3 + P_4
    C_22 = P_5 + P_1 - P_3 - P_7
    
    C = build_matrix([C_11, C_12, C_21, C_22])
    
    return C
    
    
    
def is_equal(a, b):
    return (abs(strassen_res - naive_res) <= 10e-9).all()

In [539]:
n = 32
A = np.random.rand(n, n)
B = np.random.rand(n, n)

In [540]:
%timeit -n50 strassen_res = strassen(A, B)

207 ms ± 8.03 ms per loop (mean ± std. dev. of 7 runs, 50 loops each)


In [541]:
%timeit -n50 naive_res = naive_mult(A, B)

26.6 ms ± 1.15 ms per loop (mean ± std. dev. of 7 runs, 50 loops each)


In [542]:
is_equal(strassen_res, naive_res)

True

In [513]:
np.dot(A, B)

array([[ 7.19642441,  6.27014167,  7.54262825, ...,  7.86324081,
         8.00589952,  6.62987813],
       [ 7.34780014,  7.26551811,  8.64727421, ...,  8.54186025,
         8.3370981 ,  7.63101564],
       [ 8.56574757,  6.85524574,  9.3563719 , ..., 10.70825769,
         9.07624093,  9.01393827],
       ...,
       [ 6.05373222,  5.92076864,  8.21013581, ...,  8.02610382,
         8.09167631,  6.5425994 ],
       [ 8.48677804,  6.78618704,  8.90877955, ...,  8.73110749,
         9.35552531,  7.4763038 ],
       [ 8.94410642,  7.60340689, 10.72899545, ..., 10.29667625,
         9.89970721,  8.60268255]])

In [466]:
strassen_res

array([[0.30254901, 0.15440135],
       [0.02415313, 0.43583908]])

In [456]:
naive_res

array([[2.19384937, 1.97086571, 2.73209922, 2.18208137, 1.65298491,
        1.20043977, 2.18852903, 1.11316963],
       [2.55968606, 2.70497631, 3.67091188, 2.36424498, 1.77392901,
        1.27334191, 2.95036702, 1.70997761],
       [2.26232552, 2.21526165, 2.95195076, 1.95662313, 1.62944951,
        1.06848511, 2.38248813, 1.60199174],
       [2.7701198 , 1.91912645, 3.54831418, 2.45451417, 1.51524542,
        1.59882328, 2.59032631, 1.55689174],
       [2.21539921, 2.25208861, 2.87042758, 1.81788671, 1.37780386,
        1.02636753, 2.31137927, 1.07602928],
       [2.07715732, 2.12956696, 2.64493299, 1.99140294, 2.0046203 ,
        1.18740461, 2.0299109 , 0.83685712],
       [2.47098833, 2.70430212, 3.11104555, 2.47489075, 2.78336997,
        1.62185074, 2.38472387, 1.13834093],
       [2.12103701, 1.80790953, 2.8944624 , 2.0271973 , 1.81075743,
        1.12974145, 2.01683433, 0.92491895]])

In [457]:
split_test = Split()
split_test.init(0, 0, 4)

test_m = np.random.rand(4, 4)

recursive_mult(test_m, A, None, split_test)

array([[0.99685352, 0.80368283, 0.28854261, 0.34159672],
       [1.968495  , 1.85628709, 0.67799034, 1.14701835],
       [0.99444611, 0.92129612, 0.22638741, 0.32183264],
       [1.6988771 , 1.56024799, 0.62092772, 0.95621988]])

In [458]:
np.dot( test_m, A[:4, :4] )

array([[0.99685352, 0.80368283, 0.28854261, 0.34159672],
       [1.968495  , 1.85628709, 0.67799034, 1.14701835],
       [0.99444611, 0.92129612, 0.22638741, 0.32183264],
       [1.6988771 , 1.56024799, 0.62092772, 0.95621988]])

In [459]:
A

array([[0.75120517, 0.90705693, 0.07810624, 0.19401543, 0.94099356,
        0.51963258, 0.47000509, 0.2579761 ],
       [0.47276097, 0.81688578, 0.12805137, 0.59224093, 0.98760485,
        0.63180237, 0.98090944, 0.39481237],
       [0.40355869, 0.37124198, 0.31184075, 0.39820999, 0.28018527,
        0.40662815, 0.94806892, 0.7709125 ],
       [0.95983786, 0.20603088, 0.37780736, 0.19832483, 0.97972776,
        0.84955826, 0.79433534, 0.32149026],
       [0.60331829, 0.68588193, 0.37728137, 0.86797153, 0.46768063,
        0.27069545, 0.68452116, 0.06509385],
       [0.49653628, 0.65873772, 0.72850809, 0.52224599, 0.5618743 ,
        0.48881368, 0.10579398, 0.39024129],
       [0.42204114, 0.99261446, 0.89545409, 0.30131875, 0.62804943,
        0.96105577, 0.04507958, 0.58250447],
       [0.42327399, 0.05101325, 0.96406732, 0.47604071, 0.86294975,
        0.45675055, 0.1607555 , 0.49931809]])

In [None]:
strassen([2, 2])