# Discovering Faster Matrix Multiplication Algorithms with Human Intelligence

**Author: [Shu Hu](https://shu-hu.com/)**

**The Australian National University and QuantEcon**

This notebook implements the naive algorithm and the Strassen algorithm for matrix multiplication in pure Python, and compares them with their counterparts in ``numpy`` and ``jax``.

In [1]:
import numpy as np

In [2]:
!nvidia-smi

Mon Oct 10 21:28:42 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               

In [3]:
def generate_rsmatrix(dim):
    """Return a dim x dim matrix with entries drawn uniformly from [0, 1)."""
    return np.random.rand(dim, dim)  # already an ndarray, no conversion needed
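Because ``generate_rsmatrix`` draws fresh random numbers on every run, the matrices (and the printed results) change between sessions. A minimal sketch of a reproducible variant, assuming NumPy's ``Generator`` API (``np.random.default_rng``) is acceptable here; the name ``generate_rsmatrix_seeded`` and its ``seed`` parameter are hypothetical:

```python
import numpy as np


def generate_rsmatrix_seeded(dim, seed=None):
    """Hypothetical seeded variant: dim x dim matrix, entries uniform in [0, 1)."""
    rng = np.random.default_rng(seed)  # independent generator, no global state
    return rng.random((dim, dim))


A = generate_rsmatrix_seeded(3, seed=42)
B = generate_rsmatrix_seeded(3, seed=42)
print(np.array_equal(A, B))  # True: same seed, same matrix
```

Using a local generator rather than the global ``np.random`` state keeps each call independent of whatever else the notebook has drawn before.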

In [4]:
dim = 3
A1 = generate_rsmatrix(dim)
B1 = generate_rsmatrix(dim)

In [5]:
dim = 200
A2 = generate_rsmatrix(dim)
B2 = generate_rsmatrix(dim)

In [6]:
print(A1, B1)

[[0.94304355 0.74828382 0.15600759]
 [0.73302349 0.66488995 0.38824657]
 [0.5050888  0.55085124 0.81258057]] [[0.29845303 0.23299462 0.33082265]
 [0.07107561 0.02339411 0.00838108]
 [0.45377198 0.20436025 0.23852793]]


In [7]:
print(A2, B2)

[[0.42509898 0.70400463 0.14636649 ... 0.13232048 0.84970027 0.5123502 ]
 [0.60551446 0.42744885 0.01077057 ... 0.6308503  0.30225012 0.60414681]
 [0.06128461 0.37243065 0.34296435 ... 0.33040971 0.15621893 0.09334944]
 ...
 [0.26604658 0.9251207  0.56628462 ... 0.50867772 0.34978768 0.2376213 ]
 [0.76600087 0.01168033 0.33750169 ... 0.4956604  0.66285637 0.59201705]
 [0.70226882 0.25692166 0.17358614 ... 0.71665409 0.78756364 0.13157192]] [[0.56514509 0.23026495 0.08825653 ... 0.98439212 0.27856671 0.81965351]
 [0.94601523 0.91496916 0.02642946 ... 0.98010941 0.46572889 0.75238011]
 [0.43058591 0.49814698 0.0097374  ... 0.01625059 0.68881644 0.75592984]
 ...
 [0.84075232 0.53151042 0.69723371 ... 0.71089911 0.88919768 0.51878313]
 [0.08123905 0.91102651 0.93937396 ... 0.96981325 0.3734705  0.73492271]
 [0.80174272 0.2713443  0.391313   ... 0.23493563 0.8082847  0.66502398]]


# 1 Naive Algorithm

In [8]:
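def naive_algorithm(A, B):
    """
        Implementation of the naive algorithm for square n x n matrices:
        C[i][j] = sum over k of A[i][k] * B[k][j], i.e. n^3 scalar multiplications.
    """
    n = len(A)
    C = [[0 for j in range(n)] for i in range(n)]  # n x n result, filled with zeros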
    for i in range(n):
        for j in range(n):
            for k in range(n):
                C[i][j] += A[i][k] * B[k][j]
    return C
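The loop order in ``naive_algorithm`` can be rearranged without changing the arithmetic. A minimal sketch of the i-k-j order, which looks up ``A[i][k]`` and the row ``B[k]`` once per ``k`` instead of once per innermost iteration; ``naive_ikj`` is a hypothetical name. Each ``C[i][j]`` still receives the same additions in the same order, so the results match ``naive_algorithm`` exactly, and the cost is still n³ multiplications; the rearrangement typically only trims interpreter overhead in CPython.

```python
def naive_ikj(A, B):
    """Naive O(n^3) product of square list-of-lists matrices, i-k-j loop order."""
    n = len(A)
    C = [[0] * n for _ in range(n)]  # zero-initialized n x n result
    for i in range(n):
        Ci = C[i]                    # row of the result being accumulated
        for k in range(n):
            a_ik = A[i][k]           # hoisted out of the innermost loop
            Bk = B[k]
            for j in range(n):
                Ci[j] += a_ik * Bk[j]
    return C


print(naive_ikj([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```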

In [10]:
%%time
print('Matrix multiplication result: ')
print(naive_algorithm(A1, B1))

Matrix multiplication result: 
[[0.4054308089003012, 0.269111256362146, 0.3554637607158879], [0.44220595688351744, 0.2656872036915108, 0.34068092231245484], [0.558623668073736, 0.29662881698859933, 0.3655347099392852]]
CPU times: user 273 µs, sys: 0 ns, total: 273 µs
Wall time: 225 µs


In [11]:
%%time
print('Matrix multiplication result: ')
print(naive_algorithm(A2, B2))

Matrix multiplication result: 
[[52.330615415498656, 50.90114128597263, 52.391654712595006, 51.979026044324875, 55.22883141811139, 49.26242137697702, 53.310802965651476, 52.063941255030194, 51.07881288398266, 53.60394444675146, 53.52716502237283, 58.08462979907358, 54.24869456966449, 54.176837573840174, 54.409290369335864, 54.357009807927895, 52.7488755087161, 54.48566508784936, 53.9474427527761, 54.796539559446764, 50.60996173031342, 53.87728545983299, 53.452016397715894, 56.04029275734743, 53.08166492481037, 52.2486933061791, 53.98193062269189, 57.08721291821807, 50.81564708908022, 54.767749279983754, 51.37337085685377, 54.33968170989236, 51.186510358007546, 48.66398119776428, 52.930380368924006, 53.34168935380297, 54.05808046084696, 50.99104601978581, 50.60049204711659, 53.61413647093008, 51.835230805926166, 48.91280437336336, 51.598606664295865, 54.11441608767946, 51.45680181985973, 55.55343134116727, 56.99340445364107, 52.65987574351712, 48.23111388286944, 53.138729815029386, 52.9

# 2 Strassen Algorithm

In [12]:
def add(A, B):
    """
        Implementation of matrix addition (element-wise A + B).
    """
    n = len(A)
    C = [[0 for j in range(0, n)] for i in range(0, n)]
    for i in range(0, n):
        for j in range(0, n):
            C[i][j] = A[i][j] + B[i][j]
    return C

In [13]:
def subtract(A, B):
    """
        Implementation of matrix subtraction (element-wise A - B).
    """
    n = len(A)
    C = [[0 for j in range(0, n)] for i in range(0, n)]
    for i in range(0, n):
        for j in range(0, n):
            C[i][j] = A[i][j] - B[i][j]
    return C
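The two helpers above can be written more compactly with ``zip``, which avoids index arithmetic and pre-allocating ``C``. A sketch, assuming equally shaped inputs (``zip`` silently truncates on a shape mismatch, so this version performs no size check); the names ``add_zip`` and ``subtract_zip`` are hypothetical:

```python
def add_zip(A, B):
    """Element-wise sum of two equally shaped list-of-lists matrices."""
    return [[a + b for a, b in zip(row_a, row_b)] for row_a, row_b in zip(A, B)]


def subtract_zip(A, B):
    """Element-wise difference A - B of two equally shaped list-of-lists matrices."""
    return [[a - b for a, b in zip(row_a, row_b)] for row_a, row_b in zip(A, B)]


print(add_zip([[1, 2], [3, 4]], [[10, 20], [30, 40]]))       # [[11, 22], [33, 44]]
print(subtract_zip([[1, 2], [3, 4]], [[10, 20], [30, 40]]))  # [[-9, -18], [-27, -36]]
```

Unlike ``add`` and ``subtract``, these versions also accept rectangular matrices, since they never read ``len(A)`` as the column count.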

In [14]:
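def strassen(A, B):
    """
        Implementation of the Strassen algorithm for square n x n matrices.
        Assumes n stays even under repeated halving until it drops to N or below
        (e.g. a power of two); otherwise the split below loses the last row and column.
    """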
    n = len(A)
    N = 8  # cutoff: matrices of size <= N use the naive algorithm
    if n <= N:
        return naive_algorithm(A, B)
    else:
        # initializing the new sub-matrices
        newSize = n // 2
        a11 = [[0 for j in range(0, newSize)] for i in range(0, newSize)]
        a12 = [[0 for j in range(0, newSize)] for i in range(0, newSize)]
        a21 = [[0 for j in range(0, newSize)] for i in range(0, newSize)]
        a22 = [[0 for j in range(0, newSize)] for i in range(0, newSize)]

        b11 = [[0 for j in range(0, newSize)] for i in range(0, newSize)]
        b12 = [[0 for j in range(0, newSize)] for i in range(0, newSize)]
        b21 = [[0 for j in range(0, newSize)] for i in range(0, newSize)]
        b22 = [[0 for j in range(0, newSize)] for i in range(0, newSize)]

        # dividing each matrix into 4 sub-matrices:
        for i in range(0, newSize):
            for j in range(0, newSize):
                a11[i][j] = A[i][j]  # top left
                a12[i][j] = A[i][j + newSize]  # top right
                a21[i][j] = A[i + newSize][j]  # bottom left
                a22[i][j] = A[i + newSize][j + newSize]  # bottom right

                b11[i][j] = B[i][j]  # top left
                b12[i][j] = B[i][j + newSize]  # top right
                b21[i][j] = B[i + newSize][j]  # bottom left
                b22[i][j] = B[i + newSize][j + newSize]  # bottom right

        # Calculating p1 to p7:
        aResult = add(a11, a22)
        bResult = add(b11, b22)
        p1 = strassen(aResult, bResult)  # p1 = (a11+a22) * (b11+b22)

        aResult = add(a21, a22)  # a21 + a22
        p2 = strassen(aResult, b11)  # p2 = (a21+a22) * (b11)

        bResult = subtract(b12, b22)  # b12 - b22
        p3 = strassen(a11, bResult)  # p3 = (a11) * (b12 - b22)

        bResult = subtract(b21, b11)  # b21 - b11
        p4 = strassen(a22, bResult)  # p4 = (a22) * (b21 - b11)

        aResult = add(a11, a12)  # a11 + a12
        p5 = strassen(aResult, b22)  # p5 = (a11+a12) * (b22)

        aResult = subtract(a21, a11)  # a21 - a11
        bResult = add(b11, b12)  # b11 + b12
        p6 = strassen(aResult, bResult)  # p6 = (a21-a11) * (b11+b12)

        aResult = subtract(a12, a22)  # a12 - a22
        bResult = add(b21, b22)  # b21 + b22
        p7 = strassen(aResult, bResult)  # p7 = (a12-a22) * (b21+b22)

        # calculating c11, c12, c21 and c22:
        c12 = add(p3, p5)  # c12 = p3 + p5
        c21 = add(p2, p4)  # c21 = p2 + p4

        aResult = add(p1, p4)  # p1 + p4
        bResult = add(aResult, p7)  # p1 + p4 + p7
        c11 = subtract(bResult, p5)  # c11 = p1 + p4 - p5 + p7

        aResult = add(p1, p3)  # p1 + p3
        bResult = add(aResult, p6)  # p1 + p3 + p6
        c22 = subtract(bResult, p2)  # c22 = p1 + p3 - p2 + p6

        # assembling the four quadrants into a single matrix:
        C = [[0 for j in range(0, n)] for i in range(0, n)]
        for i in range(0, newSize):
            for j in range(0, newSize):
                C[i][j] = c11[i][j]
                C[i][j + newSize] = c12[i][j]
                C[i + newSize][j] = c21[i][j]
                C[i + newSize][j + newSize] = c22[i][j]
        return C
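Why this helps: multiplying 2 × 2 block matrices blockwise in the ordinary way needs 8 half-size products, while ``strassen`` needs only 7 (``p1`` to ``p7``) plus 18 block additions and subtractions, which cost Θ(n²) per level. Solving the two recurrences with the standard master-theorem argument gives:

$$
T_{\text{naive}}(n) = 8\,T_{\text{naive}}\!\left(\tfrac{n}{2}\right) + \Theta(n^2) = \Theta(n^3),
\qquad
T_{\text{Strassen}}(n) = 7\,T_{\text{Strassen}}\!\left(\tfrac{n}{2}\right) + \Theta(n^2) = \Theta\!\left(n^{\log_2 7}\right) \approx \Theta\!\left(n^{2.807}\right).
$$

Counting scalar multiplications in ``strassen`` for $n = 2^k \ge 8$ with its cutoff $N = 8 = 2^3$ (each 8 × 8 leaf runs ``naive_algorithm``, i.e. $8^3 = 512$ multiplications):

$$
M_{\text{Strassen}}(2^k) = 7^{\,k-3}\cdot 8^3,
\qquad
M_{\text{naive}}(2^k) = \left(2^k\right)^3,
\qquad
M_{\text{Strassen}}(256) = 7^5 \cdot 512 = 8{,}605{,}184 \;<\; 256^3 = 16{,}777{,}216.
$$

The saving in multiplications is paid for with extra additions and copying, which is why the recursion hands small blocks back to the ordinary algorithm below a cutoff such as ``N``.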

In [15]:
%%time
print('Matrix multiplication result: ')
print(strassen(A1, B1))

Matrix multiplication result: 
[[0.4054308089003012, 0.269111256362146, 0.3554637607158879], [0.44220595688351744, 0.2656872036915108, 0.34068092231245484], [0.558623668073736, 0.29662881698859933, 0.3655347099392852]]
CPU times: user 221 µs, sys: 14 µs, total: 235 µs
Wall time: 196 µs


In [16]:
%%time
print('Matrix multiplication result: ')
print(strassen(A2, B2))

Matrix multiplication result: 
[[50.367126600143855, 48.48566831591093, 49.81136988739035, 49.30134955468503, 52.90249094907726, 46.91097114735212, 50.35245873938994, 49.68323031101523, 48.44614824342442, 51.054001119513174, 50.87358682158698, 54.82162180225407, 52.65947123092849, 51.09677724813379, 51.28943332093763, 51.62766838614193, 49.22207568216751, 52.17153648240012, 50.38971048753444, 51.73087704918916, 49.35574644599217, 51.738685137847995, 50.15560344173418, 51.81559019808172, 0, 49.329708455568934, 52.002633871593275, 54.611565224149, 49.97651729902658, 51.922150680913035, 48.94560640474534, 51.35336398395013, 48.58380267083384, 46.27596531255258, 50.09360206571146, 50.33932682202766, 51.41937010961023, 48.68484283075792, 47.53737882973214, 51.33305114438549, 48.81060738401423, 45.95479050623963, 49.49227574486318, 51.66353175120475, 48.52875831875411, 53.32323228050055, 54.53684619392929, 49.86994034275049, 46.49786812006381, 0, 50.55962725439088, 55.42237236349991, 46.2895
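Comparing this output with the naive one shows that the Strassen result for the 200 × 200 matrices is wrong: its first entry is 50.367 instead of 52.331, and the 25th and 50th entries of the first row are exactly 0. The cause is the integer halving of ``n`` when splitting into quadrants: the recursion goes 200 → 100 → 50 → 25, and at n = 25 (odd and above the cutoff ``N = 8``) the last row and column are silently dropped. The 3 × 3 case is unaffected because it never leaves the naive base case. A minimal sketch of one common remedy follows; it swaps in NumPy array slicing for the list-based helpers and zero-pads the inputs to the next power of two. The names ``strassen_np`` and ``strassen_padded`` and the ``leaf`` cutoff of 64 are hypothetical choices:

```python
import numpy as np


def strassen_np(A, B, leaf=64):
    """Strassen product of two m x m NumPy arrays, m a power of two.

    Blocks of size <= leaf are multiplied with NumPy's own matmul (@).
    """
    m = A.shape[0]
    if m <= leaf:
        return A @ B
    h = m // 2  # exact split, since m is a power of two
    a11, a12, a21, a22 = A[:h, :h], A[:h, h:], A[h:, :h], A[h:, h:]
    b11, b12, b21, b22 = B[:h, :h], B[:h, h:], B[h:, :h], B[h:, h:]
    # The same seven products p1..p7 as in the list-based strassen above
    p1 = strassen_np(a11 + a22, b11 + b22, leaf)
    p2 = strassen_np(a21 + a22, b11, leaf)
    p3 = strassen_np(a11, b12 - b22, leaf)
    p4 = strassen_np(a22, b21 - b11, leaf)
    p5 = strassen_np(a11 + a12, b22, leaf)
    p6 = strassen_np(a21 - a11, b11 + b12, leaf)
    p7 = strassen_np(a12 - a22, b21 + b22, leaf)
    C = np.empty((m, m), dtype=np.result_type(A, B))
    C[:h, :h] = p1 + p4 - p5 + p7
    C[:h, h:] = p3 + p5
    C[h:, :h] = p2 + p4
    C[h:, h:] = p1 - p2 + p3 + p6
    return C


def strassen_padded(A, B, leaf=64):
    """Multiply n x n inputs by zero-padding them to the next power of two."""
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    n = A.shape[0]
    m = 1 << (n - 1).bit_length()  # smallest power of two >= n (e.g. 200 -> 256)
    Ap = np.zeros((m, m))
    Ap[:n, :n] = A
    Bp = np.zeros((m, m))
    Bp[:n, :n] = B
    return strassen_np(Ap, Bp, leaf)[:n, :n]  # crop the padding away


rng = np.random.default_rng(0)
X = rng.random((200, 200))
Y = rng.random((200, 200))
print(np.allclose(strassen_padded(X, Y), X @ Y))  # True
```

Zero-padding is safe because the product of the padded matrices has the true product in its top-left n × n block and zeros elsewhere. Alternatively, keeping the list-based ``strassen`` above, changing its base-case test to ``if n <= N or n % 2 == 1:`` sends odd sizes to ``naive_algorithm``, which also avoids the lost rows and columns.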

# 3 What's More?

In [17]:
dim = 10_000
A3 = generate_rsmatrix(dim)
B3 = generate_rsmatrix(dim)

In [18]:
%%time
print('Matrix multiplication result: ')
print(naive_algorithm(A3, B3))

Matrix multiplication result: 


KeyboardInterrupt: ignored

In [19]:
%%time
print('Matrix multiplication result: ')
print(strassen(A3, B3))

Matrix multiplication result: 


KeyboardInterrupt: ignored
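Both runs above had to be interrupted. A back-of-the-envelope scaling estimate suggests why: going from n = 200 to n = 10,000 multiplies the work by (10,000 / 200)³ = 125,000 for the naive algorithm and by roughly 50^(log₂ 7) ≈ 5.9 × 10⁴ for Strassen. This is only a sketch; it ignores constant factors, memory effects and the list-copying overhead of the implementations above:

```python
import math

# Sizes used in this notebook: 200 x 200 (finished) and 10_000 x 10_000 (interrupted).
n_small, n_large = 200, 10_000
ratio = n_large / n_small  # 50.0

naive_mults = n_large ** 3               # scalar multiplications done by naive_algorithm
naive_factor = ratio ** 3                # Theta(n^3) scaling
strassen_factor = ratio ** math.log2(7)  # Theta(n^log2(7)) scaling, about n^2.807

print(f"naive multiplications at n = {n_large:,}: {naive_mults:,}")
print(f"naive work relative to n = {n_small}: {naive_factor:,.0f}x")
print(f"Strassen work relative to n = {n_small}: {strassen_factor:,.0f}x")
```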

If you are dealing with matrices wider than about 128 columns, I would suggest two things:
- forget about the vanilla pure-Python algorithms above, and
- nestle in the arms of optimized, parallel libraries such as ``Google JAX`` or ``Dask``.

## Faster Matrix Multiplication with NumPy and Google JAX

In [21]:
import jax.numpy as jnp

In [24]:
%%time
print('Matrix multiplication result: ')
print(np.dot(A3, B3))

Matrix multiplication result: 
[[2532.64421385 2483.42277652 2497.1002876  ... 2506.9190752
  2477.77876239 2513.41459959]
 [2508.58291853 2481.70982135 2487.6436791  ... 2489.03995028
  2475.40773258 2477.66680271]
 [2531.77969175 2487.85584773 2514.92096546 ... 2522.00472228
  2491.5362792  2511.49777235]
 ...
 [2507.60401414 2492.88211722 2500.31797113 ... 2509.37242068
  2489.20126204 2496.85991196]
 [2505.39090349 2491.10026972 2502.93401935 ... 2512.64002254
  2478.9989733  2486.16066312]
 [2482.13076291 2462.00431789 2481.14207082 ... 2489.1846613
  2468.77219491 2462.59481489]]
CPU times: user 1min 4s, sys: 738 ms, total: 1min 4s
Wall time: 16.5 s


In [25]:
%%time
print('Matrix multiplication result: ')
print(jnp.dot(A3, B3))

Matrix multiplication result: 
[[2532.6467 2483.42   2497.096  ... 2506.925  2477.781  2513.4048]
 [2508.5845 2481.705  2487.646  ... 2489.0454 2475.409  2477.6675]
 [2531.789  2487.8481 2514.9167 ... 2521.9993 2491.5369 2511.5051]
 ...
 [2507.599  2492.8782 2500.3176 ... 2509.3735 2489.2021 2496.8577]
 [2505.3906 2491.0981 2502.9302 ... 2512.641  2479.0068 2486.1619]
 [2482.1333 2462.0093 2481.1394 ... 2489.1853 2468.7717 2462.5955]]
CPU times: user 1.02 s, sys: 511 ms, total: 1.53 s
Wall time: 1.11 s
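Two caveats when reading these timings. First, JAX uses 32-bit floats by default (64-bit mode, ``jax_enable_x64``, is off), so ``jnp.dot`` computes in lower precision than the float64 ``np.dot``; this is why the two results agree only to about five or six significant figures (2532.64421385 versus 2532.6467). Second, ``jnp.dot(A3, B3)`` receives NumPy arrays, so its 1.11 s likely also covers converting and copying both 10,000 × 10,000 inputs to the GPU and compiling the operation on first use. A minimal sketch of a more controlled measurement, assuming a smaller n = 1,000 so it runs quickly on any backend; ``A_dev``, ``B_dev`` and ``matmul`` are illustrative names. It moves the data once with ``jax.device_put``, compiles with ``jax.jit`` before timing, and calls ``block_until_ready()`` because JAX dispatches work asynchronously:

```python
import time

import numpy as np
import jax
import jax.numpy as jnp

rng = np.random.default_rng(0)
n = 1_000  # much smaller than the notebook's 10_000, chosen so the sketch runs quickly
A = rng.random((n, n))  # float64 on the host
B = rng.random((n, n))

# Copy the inputs to the default device once, outside the timed region.
# With 64-bit mode disabled (JAX's default), float64 inputs are converted to float32.
A_dev = jax.device_put(A)
B_dev = jax.device_put(B)

matmul = jax.jit(jnp.dot)
matmul(A_dev, B_dev).block_until_ready()  # warm-up call: triggers compilation

start = time.perf_counter()
C_dev = matmul(A_dev, B_dev).block_until_ready()  # wait for the asynchronous result
elapsed = time.perf_counter() - start

C_ref = A @ B  # float64 reference computed by NumPy
rel_err = np.max(np.abs(np.asarray(C_dev) - C_ref)) / np.max(np.abs(C_ref))
print(A_dev.dtype, f"{elapsed:.4f} s", f"max relative error {rel_err:.1e}")
```

For a same-precision comparison on the NumPy side, one could also time ``A.astype(np.float32) @ B.astype(np.float32)``.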
