In [1]:
import threading
import numpy as np

from tree import TreeNode

In [2]:
def multiply_naive(A, B):
    """Multiply two matrices directly; this is the recursion's base case.

    No further block splitting is done here: the product is delegated to the
    ``@`` operator (``numpy.matmul`` for ndarray operands).
    """
    product = A @ B
    return product

def pad_matrix(matrix):
    """Return a copy of ``matrix`` with one trailing zero row and one trailing zero column.

    The result has shape ``(rows + 1, cols + 1)`` and the same dtype as the
    input; the original values occupy the top-left ``rows x cols`` corner.
    It is used to make an odd-sized square matrix even before quadrant
    splitting.

    Raises:
        ValueError: if ``matrix`` is not 2-D.
    """
    if matrix.ndim != 2:
        raise ValueError(f"pad_matrix expects a 2-D matrix, got ndim={matrix.ndim}")
    # Read both dimensions: using shape[0] for the columns too would let
    # NumPy silently broadcast an n x 1 input across the whole padded block.
    rows, cols = matrix.shape
    padded = np.zeros((rows + 1, cols + 1), dtype=matrix.dtype)
    padded[:rows, :cols] = matrix
    return padded

def split_matrix(matrix):
    """Split a 2-D matrix into its four quadrants.

    Rows are split at ``rows // 2`` and columns at ``cols // 2``, so for odd
    sizes the bottom/right quadrants receive the extra row/column. The
    quadrants are basic-slice views of ``matrix``, not copies.

    Returns:
        ``(top_left, top_right, bottom_left, bottom_right)``.
    """
    # Each axis gets its own midpoint. Splitting columns at the row midpoint
    # would give wrong quadrants for non-square input.
    row_mid = matrix.shape[0] // 2
    col_mid = matrix.shape[1] // 2
    top, bottom = matrix[:row_mid], matrix[row_mid:]
    return top[:, :col_mid], top[:, col_mid:], bottom[:, :col_mid], bottom[:, col_mid:]

def combine_matrices(C11, C12, C21, C22):
    """Reassemble four quadrants into one matrix (inverse of ``split_matrix``).

    Each row pair (C11|C12 and C21|C22) is joined horizontally, and the two
    resulting strips are stacked vertically.
    """
    strips = [np.hstack(pair) for pair in ((C11, C12), (C21, C22))]
    return np.vstack(strips)

In [3]:
def distributed_multiplication(A, B, max_block_size):
    """Compute ``A @ B`` by recursive 2x2 block decomposition, one thread per output quadrant.

    A call on an ``n``-row input either multiplies directly with
    ``multiply_naive`` (when ``n <= max_block_size``) or:

    1. pads ``A`` and ``B`` with one zero row/column if ``n`` is odd,
    2. splits both operands into quadrants,
    3. starts four threads, one per output quadrant (C11, C12, C21, C22).
       Each thread computes its quadrant as the sum of two recursive block
       products (e.g. ``C11 = A11 @ B11 + A12 @ B21``),
    4. reassembles the quadrants and trims any padding.

    Every non-base call starts four threads of its own, so the total number
    of threads created grows exponentially with recursion depth.

    Parameters
    ----------
    A, B : numpy.ndarray
        Operands. ``n`` is read from ``A.shape[0]``. The recursive path
        assumes both are square with the same size.
    max_block_size : int
        Largest size multiplied directly. Must be >= 1 whenever the input is
        larger than it; otherwise the recursion never reaches the base case.

    Returns
    -------
    tuple
        ``(C, node)``: ``C`` is the product and ``node`` is the ``TreeNode``
        for this call. Only the sub-call computing ``A11 @ B11`` is attached
        as a child, so the tree records a single root-to-leaf recursion path.

    Raises
    ------
    ValueError
        If ``max_block_size < 1`` and the input is larger than it.
    Exception
        Any exception raised inside a worker thread is re-raised here.
    """
    n = A.shape[0]
    current_node = TreeNode(size=n, base_case=False, padded=False)

    if n <= max_block_size:
        current_node.base_case = True
        return multiply_naive(A, B), current_node

    # With max_block_size < 1 a 1x1 block would be padded to 2x2 and split
    # back into 1x1 blocks forever.
    if max_block_size < 1:
        raise ValueError(f"max_block_size must be >= 1, got {max_block_size!r}")

    if n % 2 != 0:
        # Odd size: add a zero row/column so the quadrants are equal-sized.
        current_node.padded = True
        A = pad_matrix(A)
        B = pad_matrix(B)

    A11, A12, A21, A22 = split_matrix(A)
    B11, B12, B21, B22 = split_matrix(B)

    # Each key is written by exactly one worker thread. Both dicts are read
    # only after every thread has been joined.
    results = {}
    errors = {}

    def compute(block_name, A_B_blocks, node):
        """Worker: store the sum of two recursive block products under ``block_name``."""
        try:
            (A_first, B_first), (A_second, B_second) = A_B_blocks
            C1, res_node = distributed_multiplication(A_first, B_first, max_block_size)
            C2, _ = distributed_multiplication(A_second, B_second, max_block_size)
            if block_name == 'C11':
                # Only one child is recorded: the tree traces a single path.
                node.add_child(res_node)
            results[block_name] = C1 + C2
        except Exception as exc:
            # Hand the error to the caller. Otherwise the thread would just
            # die and the caller would later fail with a misleading KeyError.
            errors[block_name] = exc

    thread_args = [
        ('C11', [(A11, B11), (A12, B21)]),
        ('C12', [(A11, B12), (A12, B22)]),
        ('C21', [(A21, B11), (A22, B21)]),
        ('C22', [(A21, B12), (A22, B22)]),
    ]

    threads = []
    for block_name, A_B_blocks in thread_args:
        t = threading.Thread(target=compute, args=(block_name, A_B_blocks, current_node), name=block_name)
        threads.append(t)
        t.start()

    for t in threads:
        t.join()

    # Re-raise the first failure in quadrant order, with its original traceback.
    for block_name, _ in thread_args:
        if block_name in errors:
            raise errors[block_name]

    C = combine_matrices(results['C11'], results['C12'], results['C21'], results['C22'])

    # Trim the padding row/column added above, if any.
    if current_node.padded:
        C = C[:n, :n]

    return C, current_node

In [4]:
# Demo configuration: side length n of the n x n input matrices, and the
# largest block size that is multiplied directly instead of split further.
MATRIX_SIZE, MAX_BLOCK_SIZE = 100, 4

In [5]:
# Seeded generator so the demo matrices (and any failure they trigger) are
# reproducible under Restart & Run All. integers(0, 10) draws from 0..9.
rng = np.random.default_rng(seed=0)
A = rng.integers(0, 10, size=(MATRIX_SIZE, MATRIX_SIZE))
B = rng.integers(0, 10, size=(MATRIX_SIZE, MATRIX_SIZE))

In [6]:
# Run the threaded block multiplication and show the recorded recursion path.
C, root_node = distributed_multiplication(A, B, max_block_size=MAX_BLOCK_SIZE)
root_node.display()

100x100
└── 50x50
    └── (25+1)x(25+1)
        └── (13+1)x(13+1)
            └── (7+1)x(7+1)
                └── 4x4 (base case)


In [7]:
# Cross-check the threaded result against NumPy's direct product.
expected = A @ B
# Unlike a bare `assert`, this reports the mismatching entries on failure
# and is not stripped when Python runs with -O.
np.testing.assert_array_equal(C, expected)
is_equal = np.array_equal(C, expected)
print(is_equal)

True


In [8]:
# Exhaustive small-size sweep: every matrix size 1..20 against every block
# size 1..10. The names are suffixed with _test so that the A, B, C and
# `expected` from the cells above are not overwritten (re-running those
# cells afterwards would otherwise check the wrong matrices).
sweep_rng = np.random.default_rng(seed=1)
max_block_sizes = range(1, 11)
for matrix_size in range(1, 21):
    A_test = sweep_rng.integers(0, 10, size=(matrix_size, matrix_size))
    B_test = sweep_rng.integers(0, 10, size=(matrix_size, matrix_size))
    expected_test = A_test @ B_test  # independent of the block size; compute once
    for max_block_size in max_block_sizes:
        C_test, _ = distributed_multiplication(A_test, B_test, max_block_size)
        np.testing.assert_array_equal(
            C_test, expected_test, err_msg=f"ms={matrix_size}, mbs={max_block_size}"
        )
    print(f"Test (ms={matrix_size}, mbs={max_block_sizes[0]}..{max_block_sizes[-1]}): PASSED.")

Test (ms=1, mbs=1..10): PASSED.
Test (ms=2, mbs=1..10): PASSED.
Test (ms=3, mbs=1..10): PASSED.
Test (ms=4, mbs=1..10): PASSED.
Test (ms=5, mbs=1..10): PASSED.
Test (ms=6, mbs=1..10): PASSED.
Test (ms=7, mbs=1..10): PASSED.
Test (ms=8, mbs=1..10): PASSED.
Test (ms=9, mbs=1..10): PASSED.
Test (ms=10, mbs=1..10): PASSED.
Test (ms=11, mbs=1..10): PASSED.
Test (ms=12, mbs=1..10): PASSED.
Test (ms=13, mbs=1..10): PASSED.
Test (ms=14, mbs=1..10): PASSED.
Test (ms=15, mbs=1..10): PASSED.
Test (ms=16, mbs=1..10): PASSED.
Test (ms=17, mbs=1..10): PASSED.
Test (ms=18, mbs=1..10): PASSED.
Test (ms=19, mbs=1..10): PASSED.
Test (ms=20, mbs=1..10): PASSED.
