## Strassen homework analisys

In [2]:
from matrix import *

In [5]:
def strassen_matrix_mult(A, B):
    ''' Multiply two matrices by using Strassen's algorithm

    Parameters
    ----------
    A: Matrix
        The first matrix to be multiplied
    B: Matrix
        The second matrix to be multiplied

    Returns
    -------
    Matrix
        The row-column multiplication of the matrices passed as parameters

    Raises
    ------
    ValueError
        If the number of columns of `A` is different from the number of
        rows of `B`
    '''

    if(A.num_of_cols != B.num_of_rows):
        raise ValueError("The two matrices can't be multiplied")

    # Base case
    min_size=4 
    if max(A.num_of_rows, B.num_of_cols, A.num_of_cols) < min_size:
        return gauss_matrix_mult(A, B)

    # Rectangular case
    if(A.num_of_rows != A.num_of_cols):
        N = A.num_of_cols
        C = zero_matrix(A.num_of_rows, B.num_of_cols)

        for split_A in range(0, A.num_of_rows, N):
            for split_B in range(0, B.num_of_cols, N):
                temp = strassen_matrix_mult(A.submatrix(split_A, N, 0, N), B.submatrix(0, N, split_B, N))
                C.assign_submatrix(split_A, split_B, temp)
        
        # In case A.num_of_rows % N != 0 or B.num_of_cols % N != 0
        # I pad the rest of the matrix to square it
        print(rows_padded, cols_padded, A_padded, B_padded, temp, "\n", sep="\n\n")
        if(A.num_of_rows % n != 0 or B.num_of_cols % N != 0):
            rows_padded = A.num_of_rows - split_A
            cols_padded = B.num_of_cols - split_B
            A_padded = pad_matrix(A.submatrix(split_A, rows_padded, 0, N), N - rows_padded, 0)
            B_padded = pad_matrix(B.submatrix(0, N, split_B, cols_padded), N - cols_padded, 0)
            temp = strassen_matrix_mult(A_padded, B_padded)
            print(rows_padded, cols_padded, A_padded, B_padded, temp, "\n", sep="\n\n")

            C.assign_submatrix(split_A, split_B, temp.submatrix(0, rows_padded, 0, cols_padded))

        return C

    
    # Uneven case, at this point the matrix is square
    padded = 0
    if(A.num_of_cols % 2 != 0):
        A = pad_matrix(A, 1, 1)
        B = pad_matrix(B, 1, 1)
        padded = 1
    
    A11, A12, A21, A22 = get_matrix_quadrants(A)
    B11, B12, B21, B22 = get_matrix_quadrants(B)
    
    S1 = B12 - B22
    S2 = A11 + A12
    S3 = A21 + A22
    S4 = B21 - B11
    S5 = A11 + A22
    S6 = B11 + B22
    S7 = A12 - A22
    S8 = B21 + B22
    S9 = A11 - A21
    S10 = B11 + B12
    
    P1 = strassen_matrix_mult(A11, S1)
    P2 = strassen_matrix_mult(S2, B22)
    P3 = strassen_matrix_mult(S3, B11)
    P4 = strassen_matrix_mult(A22, S4)
    P5 = strassen_matrix_mult(S5, S6)
    P6 = strassen_matrix_mult(S7, S8)
    P7 = strassen_matrix_mult(S9, S10)
    
    C11 = P5 + P4 - P2 + P6
    C12 = P1 + P2 
    C21 = P3 + P4
    C22 = P5 + P1 - P3 - P7
    
    
    C = [[0 for j in range(B.num_of_cols)] for i in range(A.num_of_rows)]
    C = Matrix(C, clone_matrix=False)
    C.assign_submatrix(0, 0, C11)
    C.assign_submatrix(0, C.num_of_cols//2, C12)
    C.assign_submatrix(C.num_of_rows//2, 0, C21)
    C.assign_submatrix(C.num_of_rows//2, C.num_of_cols//2, C22)
    
    return C.submatrix(0, C.num_of_rows - padded, 0, C.num_of_cols - padded)

In [10]:
n = 5
m = 15
o = 12
A = Matrix([[i + j for j in range(0,n)] for i in range(0,m)])
B = Matrix([[i - j for j in range(0,o)] for i in range(0,n)])
gauss_matrix_mult(A, B) - strassen_matrix_mult(A, B)

ValueError: The two matrices can't be multiplied

In [2]:
A[:][0:2]

NameError: name 'A' is not defined

In [17]:
list(range(1,10, 2))

[1, 3, 5, 7, 9]