In [1]:
import utils as u
import plotly.graph_objects as go

In [2]:
import numpy as np
from sympy import Matrix

In [3]:
A = np.random.randint(0, 10, (8, 8))
B = np.random.randint(0, 10, (8, 8))

In [4]:
# Display both input matrices (as a tuple) for reference.
A, B

(array([[8, 6, 4, 6, 2, 9, 8, 0],
        [0, 4, 7, 4, 7, 1, 4, 1],
        [5, 4, 8, 8, 6, 0, 1, 0],
        [4, 3, 0, 5, 6, 6, 0, 1],
        [5, 9, 3, 0, 6, 6, 4, 1],
        [7, 4, 4, 0, 3, 8, 9, 0],
        [2, 9, 2, 4, 5, 8, 5, 4],
        [0, 2, 5, 4, 7, 1, 7, 8]]),
 array([[1, 2, 3, 8, 4, 9, 2, 6],
        [9, 8, 1, 0, 3, 9, 2, 1],
        [8, 5, 4, 5, 7, 9, 2, 1],
        [7, 7, 2, 7, 9, 8, 7, 1],
        [0, 9, 2, 9, 9, 2, 2, 9],
        [6, 9, 4, 0, 6, 9, 5, 0],
        [8, 0, 7, 7, 0, 9, 7, 4],
        [0, 6, 4, 3, 3, 2, 7, 2]]))

In [5]:
def naive_matrix_mult(A: np.ndarray, B: np.ndarray, use_loops: bool = False) -> np.ndarray:
    """Multiply two matrices using the standard O(n^3) definition.

    By default the product is delegated to ``np.dot``, which is BLAS-backed and
    therefore not a literal triple loop. Pass ``use_loops=True`` to run the
    textbook i/j/k loop in pure Python instead, e.g. for a like-for-like
    comparison against a pure-Python Strassen implementation.

    Args:
        A: Left operand; must be 2-D, shape (m, k), on the loop path.
        B: Right operand; must be 2-D, shape (k, n), on the loop path.
        use_loops: Run the explicit triple loop instead of ``np.dot``.

    Returns:
        The product ``A @ B``; shape (m, n) for 2-D inputs.

    Raises:
        ValueError: If the inner dimensions do not match. ``np.dot`` raises
            ValueError too on the default path.
    """
    if not use_loops:
        return np.dot(A, B)

    rows_a, cols_a = A.shape
    rows_b, cols_b = B.shape
    if cols_a != rows_b:
        raise ValueError(f"shapes {A.shape} and {B.shape} not aligned: {cols_a} != {rows_b}")

    # Use the promoted dtype so e.g. int @ float is not truncated into an int buffer.
    result = np.zeros((rows_a, cols_b), dtype=np.result_type(A, B))
    for i in range(rows_a):
        for j in range(cols_b):
            acc = 0
            for k in range(cols_a):
                acc += A[i, k] * B[k, j]
            result[i, j] = acc
    return result


@u.timed
def naive_matrix_mult_timed(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Call ``naive_matrix_mult`` under the ``u.timed`` decorator.

    The cell output below shows the decorator printing the elapsed time; the
    product itself is returned unchanged.
    """
    product = naive_matrix_mult(A, B)
    return product

In [6]:
# Run the timed naive product on the sample matrices and display the result.
naive_matrix_mult_timed(A, B)

Elapsed time: 16000.0000 nanoseconds


array([[254, 225, 154, 200, 204, 367, 183, 114],
       [158, 173,  90, 157, 169, 192, 104,  96],
       [169, 192,  86, 197, 214, 238, 109, 108],
       [102, 181,  65, 124, 163, 171,  98,  88],
       [178, 211, 104, 140, 161, 257, 111, 114],
       [195, 165, 142, 166, 143, 294, 139, 113],
       [215, 255, 124, 146, 190, 284, 167, 100],
       [148, 189, 129, 189, 170, 197, 166, 118]])

In [14]:
import numpy as np


def strassen_matrix_mult(A: np.ndarray, B: np.ndarray, leaf_size: int = 1) -> np.ndarray:
    """Multiply two matrices with Strassen's algorithm.

    Strassen replaces the 8 block products of naive divide-and-conquer with 7
    (M1..M7 below), giving O(n^log2(7)) ~ O(n^2.81) scalar multiplications.
    The recursion needs square operands whose side is a power of two. Inputs
    of any compatible 2-D shape are therefore zero-padded up to the next power
    of two, and the result is cropped back to ``(A.shape[0], B.shape[1])``.

    Args:
        A: Left operand, shape (m, k).
        B: Right operand, shape (k, n).
        leaf_size: Sub-problems whose side is ``<= leaf_size`` are multiplied
            directly with ``@`` instead of recursing. The default of 1 keeps
            the fully recursive behaviour; larger values are much faster in
            pure Python.

    Returns:
        The (m, n) product ``A @ B``.

    Raises:
        ValueError: If either operand is not 2-D, the inner dimensions
            differ, or ``leaf_size < 1``.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    if A.ndim != 2 or B.ndim != 2:
        raise ValueError(f"expected 2-D operands, got shapes {A.shape} and {B.shape}")
    if A.shape[1] != B.shape[0]:
        raise ValueError(
            f"shapes {A.shape} and {B.shape} not aligned: {A.shape[1]} != {B.shape[0]}"
        )
    if leaf_size < 1:
        raise ValueError(f"leaf_size must be >= 1, got {leaf_size}")

    rows, inner = A.shape
    cols = B.shape[1]
    dtype = np.result_type(A, B)
    if 0 in (rows, inner, cols):
        # Without this guard a 0-sized input would recurse forever (mid == 0).
        # An empty inner dimension gives an all-zero product.
        return np.zeros((rows, cols), dtype=dtype)

    side = max(rows, inner, cols)
    padded = 1 << (side - 1).bit_length()  # smallest power of two >= side
    if A.shape == B.shape == (padded, padded):
        return _strassen_square(A, B, leaf_size)

    # Zero padding does not change the top-left (rows, cols) block of the product.
    A_pad = np.zeros((padded, padded), dtype=dtype)
    B_pad = np.zeros((padded, padded), dtype=dtype)
    A_pad[:rows, :inner] = A
    B_pad[:inner, :cols] = B
    return _strassen_square(A_pad, B_pad, leaf_size)[:rows, :cols]


def _strassen_square(A: np.ndarray, B: np.ndarray, leaf_size: int) -> np.ndarray:
    """Strassen recursion for square operands whose side is a power of two."""
    n = A.shape[0]
    if n <= leaf_size:
        return A @ B

    mid = n // 2
    A11, A12 = A[:mid, :mid], A[:mid, mid:]
    A21, A22 = A[mid:, :mid], A[mid:, mid:]
    B11, B12 = B[:mid, :mid], B[:mid, mid:]
    B21, B22 = B[mid:, :mid], B[mid:, mid:]

    # Strassen's seven products.
    M1 = _strassen_square(A11 + A22, B11 + B22, leaf_size)
    M2 = _strassen_square(A21 + A22, B11, leaf_size)
    M3 = _strassen_square(A11, B12 - B22, leaf_size)
    M4 = _strassen_square(A22, B21 - B11, leaf_size)
    M5 = _strassen_square(A11 + A12, B22, leaf_size)
    M6 = _strassen_square(A21 - A11, B11 + B12, leaf_size)
    M7 = _strassen_square(A12 - A22, B21 + B22, leaf_size)

    C11 = M1 + M4 - M5 + M7
    C12 = M3 + M5
    C21 = M2 + M4
    C22 = M1 - M2 + M3 + M6

    return np.block([[C11, C12], [C21, C22]])


@u.timed
def strassen_matrix_mult_timed(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Call ``strassen_matrix_mult`` under the ``u.timed`` decorator.

    The cell output below shows the decorator printing the elapsed time; the
    product itself is returned unchanged.
    """
    product = strassen_matrix_mult(A, B)
    return product

In [15]:
# Time the Strassen product, then check it against NumPy's reference product
# so a wrong result fails loudly instead of relying on a visual comparison.
strassen_product = strassen_matrix_mult_timed(A, B)
np.testing.assert_array_equal(strassen_product, A @ B)
strassen_product

Elapsed time: 5934000.0000 nanoseconds


array([[254, 225, 154, 200, 204, 367, 183, 114],
       [158, 173,  90, 157, 169, 192, 104,  96],
       [169, 192,  86, 197, 214, 238, 109, 108],
       [102, 181,  65, 124, 163, 171,  98,  88],
       [178, 211, 104, 140, 161, 257, 111, 114],
       [195, 165, 142, 166, 143, 294, 139, 113],
       [215, 255, 124, 146, 190, 284, 167, 100],
       [148, 189, 129, 189, 170, 197, 166, 118]])

In [20]:
# Benchmark configuration: how many matrix pairs to generate and the bounds
# handed to the generator (presumably the matrix-size range — confirm in utils).
pairs_count, min_int, max_int = 8, 1, 128

pairs = u.generate_sorted_square_matrix_pairs(pairs_count, min_int, max_int)

# Sample both algorithms on the same pairs (naive first, then Strassen).
naive_times, strassed_times = [
    u.sample_matrix_mult_algorithm(algorithm, pairs)
    for algorithm in (naive_matrix_mult, strassen_matrix_mult)
]

In [22]:
# Plot per-call elapsed time for both algorithms on the same matrix pairs.
# The call labels index the samples, so they belong on the x-axis; the y-axis
# keeps its numeric time scale instead of being overwritten by tick labels.
call_labels = [f"Call {i + 1}" for i in range(len(naive_times))]

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=call_labels,
        y=naive_times,
        mode="lines+markers",
        name="Naive Multiplication",
        line=dict(color="blue"),
    )
)
fig.add_trace(
    go.Scatter(
        x=call_labels,
        y=strassed_times,
        mode="lines+markers",
        name="Strassen Multiplication",
        line=dict(color="red"),
    )
)
fig.update_layout(
    title="Elapsed Time Comparison",
    xaxis_title="Call",
    yaxis_title="Elapsed Time (nanoseconds)",
    showlegend=True,
)
fig.show()