In [95]:
"""
    simple_matrix_multiplication(A, B)

Multiply `A` (n×m) by `B` (m×p) with the schoolbook triple loop, which performs
n·m·p scalar multiplications (O(n^3) for square inputs).

# Arguments
- `A::Matrix{T}`: left operand, n×m.
- `B::Matrix{T}`: right operand, m×p, with the same floating-point element type `T`.

# Returns
A newly allocated n×p `Matrix{T}` holding the product.

# Throws
- `DimensionMismatch`: if the column count of `A` differs from the row count of `B`.
"""
function simple_matrix_multiplication(A::Matrix{T}, B::Matrix{T}) where T<:AbstractFloat
    n, m = size(A)
    m2, p = size(B)
    # Without this check a taller B would be silently truncated and a shorter one
    # would surface as an unrelated BoundsError from inside the loop.
    m == m2 || throw(DimensionMismatch(
        "A has dimensions ($n, $m) but B has dimensions ($m2, $p)"))

    C = zeros(T, n, p)

    # Julia arrays are column-major, so the row index runs innermost. For every
    # fixed (i, j) the terms are still accumulated in the order k = 1:m, so the
    # floating-point result is the same as with the i-j-k ordering.
    for j in 1:p
        for k in 1:m
            b = B[k, j]
            for i in 1:n
                C[i, j] += A[i, k] * b
            end
        end
    end

    return C
end


simple_matrix_multiplication (generic function with 2 methods)

In [96]:
"""
    strassen(A, B)

Multiply `A` by `B` with Strassen's recursive algorithm, which forms 7 half-size
products per level instead of 8 (O(n^log2(7)), roughly O(n^2.81), for n = 2^k).

The recursion only splits when both operands are square, of the same even order
n > 2. Every other input (order at most 2, odd order, rectangular or differently
sized operands) is multiplied directly with `A * B`, so the function returns the
correct product for any compatible pair of matrices.

# Arguments
- `A::Matrix{T}`, `B::Matrix{T}`: operands with the same floating-point element type.

# Returns
A new `Matrix{T}` equal to `A * B` up to floating-point rounding.

# Throws
- `DimensionMismatch`: raised by `A * B` when the shapes are incompatible.
"""
function strassen(A::Matrix{T}, B::Matrix{T}) where T<:AbstractFloat
    n = size(A, 1)

    # The 2×2 block partition below is only valid for square operands of the same
    # even order. An odd order would give quadrants of different sizes, and a
    # rectangular operand would be silently cut down to its top-left n×n corner.
    if n <= 2 || isodd(n) || size(A) != (n, n) || size(B) != (n, n)
        return A * B
    end

    m = div(n, 2)

    # Partition the matrices into four m×m quadrants (each slice is a copy).
    A11, A12 = A[1:m, 1:m], A[1:m, m+1:n]
    A21, A22 = A[m+1:n, 1:m], A[m+1:n, m+1:n]
    B11, B12 = B[1:m, 1:m], B[1:m, m+1:n]
    B21, B22 = B[m+1:n, 1:m], B[m+1:n, m+1:n]

    # Strassen's 7 products.
    M1 = strassen(A11 + A22, B11 + B22)
    M2 = strassen(A21 + A22, B11)
    M3 = strassen(A11, B12 - B22)
    M4 = strassen(A22, B21 - B11)
    M5 = strassen(A11 + A12, B22)
    M6 = strassen(A21 - A11, B11 + B12)
    M7 = strassen(A12 - A22, B21 + B22)

    # Combine the products into the quadrants of the result.
    C11 = M1 + M4 - M5 + M7
    C12 = M3 + M5
    C21 = M2 + M4
    C22 = M1 - M2 + M3 + M6

    return [C11 C12; C21 C22]
end


strassen (generic function with 2 methods)

In [97]:
"""
    hybrid_strassen(A, B; threshold=64)

Multiply `A` by `B` with Strassen's recursion, switching to the built-in `A * B`
once a block's order drops to `threshold` or below.

Blocks that cannot be split into four equal quadrants (odd order, order below 2,
rectangular or differently sized operands) are also multiplied directly, so the
function returns the correct product for any compatible pair of matrices and for
any `threshold` value.

# Arguments
- `A::Matrix{T}`, `B::Matrix{T}`: operands with the same floating-point element type.

# Keywords
- `threshold::Int = 64`: largest block order handled by direct multiplication.

# Returns
A new `Matrix{T}` equal to `A * B` up to floating-point rounding.

# Throws
- `DimensionMismatch`: raised by `A * B` when the shapes are incompatible.
"""
function hybrid_strassen(A::Matrix{T}, B::Matrix{T}; threshold::Int = 64) where T<:AbstractFloat
    n = size(A, 1)

    # Direct multiplication for small blocks and for every block the 2×2 partition
    # cannot handle. `n < 2` keeps a threshold below 1 from splitting a 1×1 or
    # empty block.
    if n <= threshold || n < 2 || isodd(n) || size(A) != (n, n) || size(B) != (n, n)
        return A * B
    end

    m = div(n, 2)

    # Four m×m quadrants of each operand (each slice is a copy).
    A11, A12 = A[1:m, 1:m], A[1:m, m+1:n]
    A21, A22 = A[m+1:n, 1:m], A[m+1:n, m+1:n]
    B11, B12 = B[1:m, 1:m], B[1:m, m+1:n]
    B21, B22 = B[m+1:n, 1:m], B[m+1:n, m+1:n]

    # Strassen's 7 products, each computed recursively with the same threshold.
    M1 = hybrid_strassen(A11 + A22, B11 + B22, threshold=threshold)
    M2 = hybrid_strassen(A21 + A22, B11, threshold=threshold)
    M3 = hybrid_strassen(A11, B12 - B22, threshold=threshold)
    M4 = hybrid_strassen(A22, B21 - B11, threshold=threshold)
    M5 = hybrid_strassen(A11 + A12, B22, threshold=threshold)
    M6 = hybrid_strassen(A21 - A11, B11 + B12, threshold=threshold)
    M7 = hybrid_strassen(A12 - A22, B21 + B22, threshold=threshold)

    # Quadrants of the result.
    C11 = M1 + M4 - M5 + M7
    C12 = M3 + M5
    C21 = M2 + M4
    C22 = M1 - M2 + M3 + M6

    return [C11 C12; C21 C22]
end


hybrid_strassen (generic function with 2 methods)

In [98]:
using CUDA

"""
    hybrid_strassen_gpu(A, B; threshold=64)

Multiply `A` by `B` with Strassen's recursion, computing every base-case product
on the GPU.

The operands are host matrices: partitioning and the quadrant sums/differences
that feed the recursion operate on them directly. A block is uploaded with
`CuArray` and multiplied on the device once its order is at most `threshold`, or
when it cannot be split into four equal quadrants (odd order, order below 2,
rectangular or differently sized operands). The partial products are combined
with `+`, `-`, `hcat` and `vcat` on the arrays returned by those device
multiplications.

# Arguments
- `A::Matrix{T}`, `B::Matrix{T}`: host operands with the same floating-point element type.

# Keywords
- `threshold::Int = 64`: largest block order handled by a direct GPU multiplication.

# Returns
The product as the array type produced by `CuArray(A) * CuArray(B)` and by
concatenating such arrays; it is not converted back to a host `Matrix`.
"""
function hybrid_strassen_gpu(A::Matrix{T}, B::Matrix{T}; threshold::Int = 64) where T<:AbstractFloat
    n = size(A, 1)

    # Base case: direct multiplication on the GPU. Besides small blocks this also
    # covers every block the 2×2 partition cannot handle, which previously either
    # failed inside the quadrant additions or was silently truncated.
    if n <= threshold || n < 2 || isodd(n) || size(A) != (n, n) || size(B) != (n, n)
        A_gpu = CuArray(A)
        B_gpu = CuArray(B)
        return A_gpu * B_gpu
    end

    m = div(n, 2)

    # Partition the host matrices into four m×m quadrants (each slice is a copy).
    A11, A12 = A[1:m, 1:m], A[1:m, m+1:n]
    A21, A22 = A[m+1:n, 1:m], A[m+1:n, m+1:n]
    B11, B12 = B[1:m, 1:m], B[1:m, m+1:n]
    B21, B22 = B[m+1:n, 1:m], B[m+1:n, m+1:n]

    # Recursively calculate Strassen's 7 products.
    M1 = hybrid_strassen_gpu(A11 + A22, B11 + B22, threshold=threshold)
    M2 = hybrid_strassen_gpu(A21 + A22, B11, threshold=threshold)
    M3 = hybrid_strassen_gpu(A11, B12 - B22, threshold=threshold)
    M4 = hybrid_strassen_gpu(A22, B21 - B11, threshold=threshold)
    M5 = hybrid_strassen_gpu(A11 + A12, B22, threshold=threshold)
    M6 = hybrid_strassen_gpu(A21 - A11, B11 + B12, threshold=threshold)
    M7 = hybrid_strassen_gpu(A12 - A22, B21 + B22, threshold=threshold)

    # Quadrants of the result.
    C11 = M1 + M4 - M5 + M7
    C12 = M3 + M5
    C21 = M2 + M4
    C22 = M1 - M2 + M3 + M6

    # Assemble the full matrix from its quadrants.
    return vcat(hcat(C11, C12), hcat(C21, C22))
end

hybrid_strassen_gpu (generic function with 2 methods)

In [99]:
"""
    pad_matrix(M1, M2)

Zero-pad `M1` and `M2` to a common n×n shape, where n is the smallest power of
two that is at least as large as every dimension of both inputs (and at least 1).
Square power-of-two operands are the shape on which Strassen's recursion can
halve all the way down.

# Arguments
- `M1::Matrix{T}`, `M2::Matrix{T}`: matrices with the same floating-point element type.

# Returns
A tuple `(padded_M1, padded_M2)` of new n×n matrices. Each input occupies the
top-left corner of its padded copy and all remaining entries are zero.
"""
function pad_matrix(M1::Matrix{T}, M2::Matrix{T}) where T<:AbstractFloat
    # All four dimensions take part, so a tall M2 can no longer overflow the padded
    # buffer. `nextpow` uses integer arithmetic (no log2 rounding), and the lower
    # bound of 1 keeps empty inputs from reaching nextpow(2, 0).
    n = nextpow(2, max(1, size(M1, 1), size(M1, 2), size(M2, 1), size(M2, 2)))

    padded_M1 = zeros(T, n, n)
    padded_M1[axes(M1, 1), axes(M1, 2)] = M1

    padded_M2 = zeros(T, n, n)
    padded_M2[axes(M2, 1), axes(M2, 2)] = M2

    return padded_M1, padded_M2
end

pad_matrix (generic function with 3 methods)

In [100]:
# Example operands for the multiplication runs: A is n×m and B is m×k, both filled
# with random Float32 values.
n, m, k = 128, 128, 128
A = rand(Float32, n, m);
B = rand(Float32, m, k);

In [103]:
# Time the triple-loop product; `@time` passes the value of the call through to C.
C = @time simple_matrix_multiplication(A, B);

  0.002400 seconds (2 allocations: 64.047 KiB)


In [104]:
# Pad both operands to a common power-of-two square, time the pure Strassen
# recursion on the padded copies, then crop the padding off the product.
padded_A, padded_B = pad_matrix(A, B)
C_padded = @time strassen(padded_A, padded_B)
C = C_padded[1:size(A, 1), 1:size(B, 2)]

  0.044818 seconds (627.46 k allocations: 57.024 MiB, 6.85% gc time)


In [108]:
# Same pad / multiply / crop sequence, timed with the hybrid variant at its
# default threshold.
padded_A, padded_B = pad_matrix(A, B)
C_padded = @time hybrid_strassen(padded_A, padded_B)
C = C_padded[1:size(A, 1), 1:size(B, 2)]

  0.000341 seconds (34 allocations: 580.047 KiB)


In [106]:
# GPU operations are queued asynchronously, so the call is wrapped in `CUDA.@sync`
# to make `@time` wait for the device to finish instead of stopping once the work
# has merely been launched. Uses padded_A / padded_B from an earlier padding cell.
C_padded = @time CUDA.@sync hybrid_strassen_gpu(padded_A, padded_B)
C = C_padded[1:size(A, 1), 1:size(B, 2)]

  0.610302 seconds (44.36 k allocations: 2.725 MiB, 9.51% compilation time)


In [109]:
# Baseline for comparison: the built-in matrix product on the unpadded operands.
@time A * B

  0.000294 seconds (2 allocations: 64.047 KiB)
