# strassen's algorithm implementation 

### 2x2 matrix

In [207]:
A = [[1,2],[3,4]] # matrix A
B = [[5,6],[7,8]] # matrix B
M = [[0,0],[0,0]] # matrix M

additions and 7 multiplications (instead of conventional 8 muls)

In [208]:
p1 = (A[0][0]+A[1][1]) * (B[0][0]+B[1][1])
p2 = (A[1][0]+A[1][1]) * B[0][0]
p3 = A[0][0] * (B[0][1]-B[1][1])
p4 = A[1][0] * (B[1][0]-B[0][0])
p5 = (A[0][0]+A[0][1]) * B[1][1]
p6 = (A[1][0]-A[0][0]) * (B[0][0]+B[0][1])
p7 = (A[0][1]-A[1][1]) * (B[1][0]+B[1][1])

final additions

In [209]:
M[0][0] = p1 + p4 - p5 + p7
M[0][1] = p3 + p5
M[1][0] = p2 + p4
M[1][1] = p1 + p3 - p2 + p6

In [210]:
M


[[17, 22], [41, 50]]

verification

## 4x4 matrix

this has 7 2x2 matmuls resulting in 49 muls

2x2 strassen takes 10 + 8 = 18 additions and 7 muls

so lets first create a function (module in rtl) for 2x2 matmul

In [211]:
def matmul_2x2 (A, B):
    """
    Multiplies two 2x2 matrices A and B using the Strassen algorithm.
    """
    M = [[0, 0], [0, 0]]  # Resultant matrix
    
    p1 = (A[0][0] + A[1][1]) * (B[0][0] + B[1][1])
    p2 = (A[1][0] + A[1][1]) * B[0][0]
    p3 = A[0][0] * (B[0][1] - B[1][1])
    p4 = A[1][1] * (B[1][0] - B[0][0])
    p5 = (A[0][0] + A[0][1]) * B[1][1]
    p6 = (A[1][0] - A[0][0]) * (B[0][0] + B[0][1])
    p7 = (A[0][1] - A[1][1]) * (B[1][0] + B[1][1])

    M[0][0] = p1 + p4 - p5 + p7
    M[0][1] = p3 + p5
    M[1][0] = p2 + p4
    M[1][1] = p1 + p3 - p2 + p6

    return M

In [212]:
# mattrix addition for 2x2 matrices
def matadd_2x2(A, B):
    """
    Adds two 2x2 matrices A and B.
    """
    return [[A[0][0] + B[0][0], A[0][1] + B[0][1]],
            [A[1][0] + B[1][0], A[1][1] + B[1][1]]]

# matrix subtraction for 2x2 matrices
def matsub_2x2(A, B):
    """
    Subtracts matrix B from matrix A.
    """
    return [[A[0][0] - B[0][0], A[0][1] - B[0][1]],
            [A[1][0] - B[1][0], A[1][1] - B[1][1]]]

then lets do this for 4x4 matrices

In [213]:
# define 2 4x4 matrices
A = [[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]] # matrix A
B = [[17,18,19,20],[21,22,23,24],[25,26,27,28],[29,30,31,32]] # matrix B
M = [[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]] # matrix M 

# decompose A and B into 4 2x2 matrices
A11 = [[A[0][0], A[0][1]], [A[1][0], A[1][1]]]
A12 = [[A[0][2], A[0][3]], [A[1][2], A[1][3]]]
A21 = [[A[2][0], A[2][1]], [A[3][0], A[3][1]]]
A22 = [[A[2][2], A[2][3]], [A[3][2], A[3][3]]]

B11 = [[B[0][0], B[0][1]], [B[1][0], B[1][1]]]
B12 = [[B[0][2], B[0][3]], [B[1][2], B[1][3]]]
B21 = [[B[2][0], B[2][1]], [B[3][0], B[3][1]]]
B22 = [[B[2][2], B[2][3]], [B[3][2], B[3][3]]]

In [214]:
matadd_2x2(A11, A22)

[[12, 14], [20, 22]]

In [215]:
p1 = matmul_2x2 ( matadd_2x2(A11, A22) , matadd_2x2(B11, B22) )
p2 = matmul_2x2 ( matadd_2x2(A21, A22) , B11 )
p3 = matmul_2x2 ( A11 , matsub_2x2(B12, B22) )
p4 = matmul_2x2 ( A22 , matsub_2x2(B21, B11) )
p5 = matmul_2x2 ( matadd_2x2(A11, A12) , B22 )
p6 = matmul_2x2 ( matsub_2x2(A21, A11) , matadd_2x2(B11, B12) )
p7 = matmul_2x2 ( matsub_2x2(A12, A22) , matadd_2x2(B21, B22) )

C11 = matadd_2x2(matadd_2x2(p1, p4), matsub_2x2(p7, p5))
C12 = matadd_2x2(p3, p5)
C21 = matadd_2x2(p2, p4)
C22 = matadd_2x2(matsub_2x2(p1, p2), matadd_2x2(p3, p6))


In [216]:
C12

[[270, 280], [670, 696]]

In [217]:
# Combine C blocks into full 4x4 matrix
C = [
    C11[0] + C12[0],
    C11[1] + C12[1],
    C21[0] + C22[0],
    C21[1] + C22[1],
    ]
C

[[250, 260, 270, 280],
 [618, 644, 670, 696],
 [986, 1028, 1070, 1112],
 [1354, 1412, 1470, 1528]]

In [218]:
# conventional matrix multiplication
M0 = [[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]] # matrix M
for i in range(4):
    for j in range(4):
        M0[i][j] = 0
        for k in range(4):
            M0[i][j] += A[i][k] * B[k][j]

# print the result
for i in range(4):
    for j in range(4):
        print(M0[i][j], end=' ')
    print()

250 260 270 280 
618 644 670 696 
986 1028 1070 1112 
1354 1412 1470 1528 


## improved strassen 

https://epubs.siam.org/doi/pdf/10.1137/22M1502719 

we reduce the number of additions to 12 additions(+/-)

In [None]:
def matmul_2x2_imp (A, B):
    """
    Multiplies two 2x2 matrices A and B using the Strassen algorithm.
    """
    M = [[0, 0], [0, 0]]  # Resultant matrix

    t1 = A[1][0] + A[1][1]
    t2 = A[1][1] - A[0][1]
    t3 = A[1][1] - A[0][0]
    t4 = B[1][1] - B[0][0]
    t5 = B[1][0] + B[1][1]
    t6 = B[1][1] - B[0][1]
    
    m1 = A[0][0] * B[0][0]
    m2 = A[0][1] * B[1][0]
    m3 = A[1][0] * t4
    m4 = A[1][1] * B[1][1]
    m5 = t1 * t5
    m6 = t2 * t6
    m7 = t3 * B[0][1]

    M[0][0] = m1 + m2
    M[0][1] = m5 - m7
    M[1][0] = m3 + m6
    M[1][1] = m5 + m6 - m2 - m4

    return M

In [220]:
G = [[1,2],[3,4]] # matrix G
H = [[1,2],[3,4]] # matrix H

In [221]:
matmul_2x2_imp(G, H)

[[7, 43], [13, 31]]

In [223]:
matmul_2x2(G, H)

[[7, 10], [15, 22]]

In [222]:
# conventional matrix multiplication
M1 = [[0,0],[0,0]] # matrix M
for i in range(2):
    for j in range(2):
        M1[i][j] = 0
        for k in range(2):
            M1[i][j] +=G[i][k] * H[k][j]

# print the result
for i in range(2):
    for j in range(2):
        print(M1[i][j], end=' ')
    print()

7 10 
15 22 
