## Strassen's Algorithm for Matrix Multiplication
This implementation currently supports only the product of two $n \times n$ matrices where $n$ is a power of 2.

In [1]:
import numpy as np

In [28]:
# Two identical 8x8 test matrices holding 0..63 in row-major order.
m1 = np.arange(8 * 8).reshape((8, 8))
m2 = m1.copy()
print(m1, "\n", m2)

[[ 0  1  2  3  4  5  6  7]
 [ 8  9 10 11 12 13 14 15]
 [16 17 18 19 20 21 22 23]
 [24 25 26 27 28 29 30 31]
 [32 33 34 35 36 37 38 39]
 [40 41 42 43 44 45 46 47]
 [48 49 50 51 52 53 54 55]
 [56 57 58 59 60 61 62 63]] 
 [[ 0  1  2  3  4  5  6  7]
 [ 8  9 10 11 12 13 14 15]
 [16 17 18 19 20 21 22 23]
 [24 25 26 27 28 29 30 31]
 [32 33 34 35 36 37 38 39]
 [40 41 42 43 44 45 46 47]
 [48 49 50 51 52 53 54 55]
 [56 57 58 59 60 61 62 63]]


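We want to compute the matrix product with as few multiplications as possible. If each matrix is split into four blocks, the naive block method needs 8 block multiplications, while Strassen's algorithm needs only 7. Applied recursively, this gives $O(n^{\log_2 7}) \approx O(n^{2.81})$ instead of $O(n^3)$.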

In [29]:
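# Block multiplication with each matrix split into four equal quadrants:
# [[A B],   .   [[E F],   =   [[AE+BG  AF+BH]
#  [C D]]        [G H]]        [CE+DG  CF+DH]]
# Computed naively this takes 8 block products; Strassen's algorithm needs only 7.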

In [30]:
def matrix_splitter(Mat: np.ndarray):
    """Split a 2-D array into its four quadrants.

    Rows are split at ``Mat.shape[0] // 2`` and columns at
    ``Mat.shape[1] // 2``. For an n x n matrix with even n this yields four
    (n/2) x (n/2) blocks.

    Args:
        Mat: 2-D array to split.

    Returns:
        Tuple ``(top_left, top_right, bottom_left, bottom_right)``, i.e. the
        blocks ``A, B, C, D`` of ``[[A, B], [C, D]]``. They are basic slices
        of ``Mat`` (views, not copies).
    """
    # Each axis gets its own midpoint so non-square input is split correctly.
    row_mid = Mat.shape[0] // 2
    col_mid = Mat.shape[1] // 2

    s1 = Mat[:row_mid, :col_mid]
    s2 = Mat[:row_mid, col_mid:]
    s3 = Mat[row_mid:, :col_mid]
    s4 = Mat[row_mid:, col_mid:]
    return s1, s2, s3, s4

In [31]:
def _strassen_recursive(X: np.ndarray, Y: np.ndarray, leaf: int) -> np.ndarray:
    """Strassen recursion on pre-validated, equal-shape, square power-of-2 inputs."""
    n = X.shape[0]
    if n <= leaf:
        # Small enough: multiply directly (for n == 1 this is one scalar product).
        return X @ Y

    A, B, C, D = matrix_splitter(X)
    E, F, G, H = matrix_splitter(Y)

    # Strassen's seven block products, each computed recursively.
    p1 = _strassen_recursive(A, F - H, leaf)
    p2 = _strassen_recursive(A + B, H, leaf)
    p3 = _strassen_recursive(C + D, E, leaf)
    p4 = _strassen_recursive(D, G - E, leaf)
    p5 = _strassen_recursive(A + D, E + H, leaf)
    p6 = _strassen_recursive(B - D, G + H, leaf)
    p7 = _strassen_recursive(A - C, E + F, leaf)

    c1 = p4 + p5 + p6 - p2  # top-left:     AE + BG
    c2 = p1 + p2            # top-right:    AF + BH
    c3 = p3 + p4            # bottom-left:  CE + DG
    c4 = p1 + p5 - p3 - p7  # bottom-right: CF + DH
    return np.block([[c1, c2], [c3, c4]])


def strasson(Mat1: np.ndarray, Mat2: np.ndarray, leaf_size: int = 1) -> np.ndarray:
    """Multiply two n x n matrices with Strassen's algorithm.

    Each matrix is split into quadrants ``[[A, B], [C, D]]`` and
    ``[[E, F], [G, H]]``. The product is assembled from seven recursive
    block products instead of the eight the naive block method needs.

    Args:
        Mat1: left operand, a square 2-D array of side n.
        Mat2: right operand, with the same shape as ``Mat1``.
        leaf_size: blocks of side ``<= leaf_size`` are multiplied directly
            with ``@`` instead of recursing further. The default of 1 recurses
            all the way down to scalars; values below 1 are treated as 1.

    Returns:
        The matrix product ``Mat1 @ Mat2``.

    Raises:
        ValueError: if either input is not a square 2-D array, if the shapes
            differ, or if n is not a power of 2.
    """
    # `or`, not `and`: either operand being non-square is an error.
    if (Mat1.ndim != 2 or Mat2.ndim != 2
            or Mat1.shape[0] != Mat1.shape[1]
            or Mat2.shape[0] != Mat2.shape[1]):
        raise ValueError("Shapes should be n*n")
    if Mat1.shape != Mat2.shape:
        raise ValueError("Both matrices should have the same n*n shape")
    n = Mat1.shape[0]
    # n is a power of two (including 2**0 == 1) iff it is positive and has one set bit.
    if n <= 0 or n & (n - 1):
        raise ValueError("shapes should be a power of 2")

    return _strassen_recursive(Mat1, Mat2, max(leaf_size, 1))



In [32]:
strasson(Mat1=m1, Mat2=m2)  # product via the Strassen implementation above

array([[ 1120,  1148,  1176,  1204,  1232,  1260,  1288,  1316],
       [ 2912,  3004,  3096,  3188,  3280,  3372,  3464,  3556],
       [ 4704,  4860,  5016,  5172,  5328,  5484,  5640,  5796],
       [ 6496,  6716,  6936,  7156,  7376,  7596,  7816,  8036],
       [ 8288,  8572,  8856,  9140,  9424,  9708,  9992, 10276],
       [10080, 10428, 10776, 11124, 11472, 11820, 12168, 12516],
       [11872, 12284, 12696, 13108, 13520, 13932, 14344, 14756],
       [13664, 14140, 14616, 15092, 15568, 16044, 16520, 16996]])

In [33]:
m1 @ m2  # reference result: NumPy's built-in matrix product

array([[ 1120,  1148,  1176,  1204,  1232,  1260,  1288,  1316],
       [ 2912,  3004,  3096,  3188,  3280,  3372,  3464,  3556],
       [ 4704,  4860,  5016,  5172,  5328,  5484,  5640,  5796],
       [ 6496,  6716,  6936,  7156,  7376,  7596,  7816,  8036],
       [ 8288,  8572,  8856,  9140,  9424,  9708,  9992, 10276],
       [10080, 10428, 10776, 11124, 11472, 11820, 12168, 12516],
       [11872, 12284, 12696, 13108, 13520, 13932, 14344, 14756],
       [13664, 14140, 14616, 15092, 15568, 16044, 16520, 16996]])

In [34]:
# array_equal also checks shapes; `(a == b).all()` would broadcast and could
# report True for arrays of different shapes.
np.array_equal(strasson(m1, m2), np.dot(m1, m2))

True

In [35]:
# 6 is not a power of 2, so strasson is expected to reject these inputs.
# Distinct names keep the 8x8 m1/m2 above intact if earlier cells are re-run.
m1_6x6 = np.arange(36).reshape(6, 6)
m2_6x6 = np.arange(36).reshape(6, 6)
print(m1_6x6, "\n", m2_6x6)
try:
    strasson(m1_6x6, m2_6x6)
except ValueError as err:
    # The rejection is the point of this cell, so show it instead of aborting Run All.
    print(f"ValueError: {err}")

[[ 0  1  2  3  4  5]
 [ 6  7  8  9 10 11]
 [12 13 14 15 16 17]
 [18 19 20 21 22 23]
 [24 25 26 27 28 29]
 [30 31 32 33 34 35]] 
 [[ 0  1  2  3  4  5]
 [ 6  7  8  9 10 11]
 [12 13 14 15 16 17]
 [18 19 20 21 22 23]
 [24 25 26 27 28 29]
 [30 31 32 33 34 35]]


ValueError: shapes should be a power of 2