
# Strassen's Algorithm for Matrix Multiplication

## Naive Matrix Multiplication

First, recall that the algorithmic complexity of multiplying two $n\times n$ matrices is $O(n^3)$ (using the straightforward algorithm).

More specifically, each entry in the matrix product is the dot product of a $1\times n$ row and an $n \times 1$ column, which requires $n$ multiplications and $n-1$ additions, so $2n-1$ operations. There are $n^2$ entries in the product, so the precise total is $$n^2(2n-1) = 2n^3-n^2$$ operations. Only the highest-order term matters, and constant factors don't matter either, so matrix multiplication is $O(n^3)$: doubling the size of the matrices makes them take roughly $8\times$ as long to multiply.
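
To make that count concrete, here is a minimal sketch (plain Python, no numpy) of the straightforward algorithm. The `naive_matmul_count` helper is hypothetical, made up for illustration; it tallies every scalar multiplication and addition as it goes, and the tally should match $2n^3-n^2$.

```python
def naive_matmul_count(X, Y):
    """Multiply two n x n matrices (lists of lists) the straightforward way.

    Returns (product, ops), where ops counts every scalar multiplication
    and addition performed.
    """
    n = len(X)
    ops = 0
    Z = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            total = X[i][0] * Y[0][j]        # first term: 1 multiplication
            ops += 1
            for k in range(1, n):
                total += X[i][k] * Y[k][j]   # 1 multiplication + 1 addition
                ops += 2
            Z[i][j] = total
    return Z, ops


for n in (1, 2, 4, 8):
    X = [[1] * n for _ in range(n)]
    _, ops = naive_matmul_count(X, X)
    print(n, ops, 2 * n**3 - n**2)   # the two counts agree
```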

Note also that *addition* of two $n\times n$ matrices requires exactly $n^2$ individual additions, so it is $O(n^2)$, which is *way* faster (compare for $n=4$: 16 operations to add vs. 112 to multiply, and the gap only grows with $n$!)
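
A quick sketch of how fast that gap grows, using the two exact counts above ($n^2$ to add, $2n^3-n^2$ to multiply); the helper names are made up for illustration. Their ratio is exactly $2n-1$, so the gap grows linearly with $n$.

```python
def add_ops(n):
    """Scalar additions needed to add two n x n matrices."""
    return n * n

def mul_ops(n):
    """Scalar operations needed to multiply two n x n matrices the straightforward way."""
    return 2 * n**3 - n**2

for n in (2, 4, 8, 16, 32):
    ratio = mul_ops(n) // add_ops(n)   # exact: (2n^3 - n^2) / n^2 = 2n - 1
    print(f"n={n:>2}: add {add_ops(n):>5}  multiply {mul_ops(n):>6}  ratio {ratio}")
```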



## Creating and multiplying matrices with numpy

To investigate matrix multiplication, we turn to numpy, which is built for working with arrays and matrices:

In [1]:
import numpy as np

In [None]:
# we've seen this before, it just creates a range of equally-spaced numbers
np.arange(1,17)   # the range is half-open: [1,17) includes 1 and excludes 17. Python ranges basically always work this way

In [None]:
# reshape turns that 16-element 1-D array into a 4x4 matrix
np.arange(1,17).reshape(4,4)

In [None]:
# let's create two of these and hold onto them in variables named P1 and P2
P1 = np.arange(1,17).reshape(4,4)
P1

In [None]:
P2 = np.arange(17,33).reshape(4,4)
P2

In [None]:
# Unfortunately the * operator does element-wise multiplication, not regular matrix multiplication
P1 * P2 # element-wise products: not the matrix product we want

In [None]:
# numpy's special operator for matrix multiplication is @
product = P1 @ P2
product
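
As a sanity check on the row-times-column description at the top, this sketch recomputes every entry of `product` as the dot product of one row of `P1` with one column of `P2` (the `check` array exists only for illustration):

```python
import numpy as np

P1 = np.arange(1, 17).reshape(4, 4)
P2 = np.arange(17, 33).reshape(4, 4)
product = P1 @ P2

# Entry (i, j) is row i of P1 dotted with column j of P2.
check = np.zeros((4, 4), dtype=P1.dtype)
for i in range(4):
    for j in range(4):
        check[i, j] = np.dot(P1[i, :], P2[:, j])

print(check[1, 2], product[1, 2])          # 670 670
print(np.array_equal(check, product))      # True
```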

## Matrix Multiplication as Blocks

Almost the simplest example of matrix multiplication is $2\times 2$. Whatever the entries in the matrices are, their product is:
$$\begin{bmatrix}A&B\\C&D\end{bmatrix}\begin{bmatrix}E&F\\G&H\end{bmatrix} = \begin{bmatrix}AE+BG&AF+BH\\CE+DG&CF+DH\end{bmatrix}$$
It turns out that the same equation also holds for matrices larger than $2\times 2$, if you let $A$ through $H$ be not just individual numerical/scalar elements, but submatrices (blocks) of compatible sizes, for example the four $n\times n$ blocks of a $2n\times 2n$ matrix.

This can be proved in general, but we can check that it holds here by splitting these $4\times 4$ matrices into $2\times 2$ blocks.

In [None]:
# numpy has indexing that enables extraction of submatrices.
# Note again, a range like i:e is [inclusive,exclusive)
A = P1[0:2, 0:2]   # rows 0 and 1; and columns 0 and 1, of P1
A

In [None]:
B = P1[0:2, 2:4]  # rows 0 and 1; and columns 2 and 3, of P1
B

In [None]:
C = P1[2:4, 0:2] # rows 2 and 3; and columns 0 and 1, of P1
C

In [None]:
D = P1[2:4, 2:4] # rows 2 and 3; and columns 2 and 3, of P1
D

In [None]:
# similarly block up P2 into E, F, G, H
E = P2[0:2, 0:2]
F = P2[0:2, 2:4]
G = P2[2:4, 0:2]
H = P2[2:4, 2:4]
H

Now we can assemble all four blocks of the product matrix. Recall:
$$\begin{bmatrix}A&B\\C&D\end{bmatrix}\begin{bmatrix}E&F\\G&H\end{bmatrix} = \begin{bmatrix}AE+BG&AF+BH\\CE+DG&CF+DH\end{bmatrix}$$

In [None]:
UL = A@E + B@G # UL stands for Upper Left
UL

You can already see that this matches the upper-left block of `product`, as it should. Let's do the other three.

In [None]:
UR = A@F + B@H
UR

In [None]:
LL = C@E + D@G
LL

In [None]:
LR = C@F + D@H
LR

We could just eyeball all of those and verify each block, but let's go ahead and assemble them into a full $4\times 4$ matrix:

In [None]:
upper_left_right = np.hstack( (UL, UR) )  # hstack is 'horizontal stack', we smash two matrices together left and right
upper_left_right

In [None]:
lower_left_right = np.hstack( (LL, LR) )
lower_left_right

In [None]:
block_product = np.vstack( (upper_left_right, lower_left_right) ) # vstack is 'vertical stack'
block_product

In [None]:
product # compare
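
The hstack/vstack steps above can be folded into a single call with `np.block`, which assembles a matrix from a nested list of blocks. Here is a minimal sketch of a reusable `block_multiply` (a hypothetical helper, assuming square inputs of the same even size), checked against `@` on both our example and some random integer matrices:

```python
import numpy as np

def block_multiply(X, Y):
    """Multiply two 2k x 2k matrices with the 2x2 block formula.

    Assumes X and Y are square, the same size, and that size is even.
    """
    k = X.shape[0] // 2
    A, B, C, D = X[:k, :k], X[:k, k:], X[k:, :k], X[k:, k:]
    E, F, G, H = Y[:k, :k], Y[:k, k:], Y[k:, :k], Y[k:, k:]
    return np.block([[A @ E + B @ G, A @ F + B @ H],
                     [C @ E + D @ G, C @ F + D @ H]])

P1 = np.arange(1, 17).reshape(4, 4)
P2 = np.arange(17, 33).reshape(4, 4)
print(np.array_equal(block_multiply(P1, P2), P1 @ P2))   # True

rng = np.random.default_rng(0)
X = rng.integers(-9, 10, size=(6, 6))
Y = rng.integers(-9, 10, size=(6, 6))
print(np.array_equal(block_multiply(X, Y), X @ Y))       # True
```

Using integer entries keeps the comparison exact, so `np.array_equal` is safe here.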

### Good news
We've designed a divide & conquer algorithm!

### Bad news
It turns out to be exactly the same scalar multiplications and additions as the regular way, just reshuffled, so there are no savings.

### Hand-wavy analysis
Look at the cells above for computing `UL,UR,LL,LR`: they involve **8 multiplications of half-size matrices** (plus four additions of half-size matrices). So a $2n\times 2n$ matrix multiplication is as much work as 8 $n\times n$ multiplications. $2\times$ the input size $\rightarrow 8\times$ the work: that's precisely what $O(n^3)$ means!
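
To see the $8\times$ pattern directly, here is a sketch that applies the block formula recursively (assuming $n$ is a power of 2) all the way down to $1\times 1$ blocks, counting the scalar multiplications at the bottom. The helper name is made up; the count comes out to exactly $n^3$.

```python
import numpy as np

def recursive_block_multiply(X, Y):
    """Divide-and-conquer multiply of two n x n matrices, n a power of 2.

    Returns (product, number of scalar multiplications performed).
    """
    n = X.shape[0]
    if n == 1:
        return X * Y, 1                 # 1x1 blocks: a single scalar product
    k = n // 2
    A, B, C, D = X[:k, :k], X[:k, k:], X[k:, :k], X[k:, k:]
    E, F, G, H = Y[:k, :k], Y[:k, k:], Y[k:, :k], Y[k:, k:]
    count = 0

    def mult(P, Q):
        nonlocal count
        Z, c = recursive_block_multiply(P, Q)
        count += c
        return Z

    # 8 half-size multiplications, exactly as in UL, UR, LL, LR
    Z = np.block([[mult(A, E) + mult(B, G), mult(A, F) + mult(B, H)],
                  [mult(C, E) + mult(D, G), mult(C, F) + mult(D, H)]])
    return Z, count

for n in (1, 2, 4, 8, 16):
    X = np.arange(n * n).reshape(n, n)
    Z, mults = recursive_block_multiply(X, X)
    assert np.array_equal(Z, X @ X)
    print(n, mults)                     # mults == n**3
```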



## The Strassen Algorithm
Instead of the intermediate products `UL,UR,LL,LR`, which slide right into place in the output, Strassen's algorithm takes a less obvious route. For the same matrix product
$$\begin{bmatrix}A&B\\C&D\end{bmatrix}\begin{bmatrix}E&F\\G&H\end{bmatrix} = \begin{bmatrix}AE+BG&AF+BH\\CE+DG&CF+DH\end{bmatrix}$$
compute these seven matrices:

* $M1 = (A+D)(E+H)$
* $M2 = (C+D)E$
* $M3 = A(F-H)$
* $M4 = D(G-E)$
* $M5 = (A+B)H$
* $M6 = (C-A)(E+F)$
* $M7 = (B-D)(G+H)$

Then, using those, compute these four:

1. $M1 + M4 - M5 + M7$
1. $M3 + M5$
1. $M2 + M4$
1. $M1 - M2 + M3 + M6$

## **Exercises**

1. Each student: take one of 1-4 above and, on scratch paper, substitute the formulas for $M1$ through $M7$ and simplify.

1. If $A$ through $H$ are all $n\times n$, what is the total number of
  1. $n\times n$ matrix multiplications required to compute 1-4
  1. $n\times n$ matrix additions/subtractions required to compute 1-4?


## Strassen Example with numpy

We can use numpy to do all that for our $4\times 4$ example:

In [None]:
M1 = (A+D) @ (E+H)
M2 = (C+D) @ E
M3 = A @ (F-H)
M4 = D @ (G-E)
M5 = (A+B) @ H
M6 = (C-A) @ (E+F)
M7 = (B-D) @ (G+H)

In [None]:
M1 + M4 - M5 + M7

In [None]:
M3 + M5

In [None]:
M2 + M4

In [None]:
M1 - M2 + M3 + M6

In [None]:
strassen_product = np.vstack( (np.hstack(( M1 + M4 - M5 + M7,   M3 + M5           )),
                               np.hstack(( M2 + M4,             M1 - M2 + M3 + M6 ))))
strassen_product

In [None]:
product - strassen_product   # all zeros means the two methods agree
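
The same seven products work for any even-sized square matrices, not just our $4\times 4$ example. Here is a minimal sketch of one Strassen step as a hypothetical `strassen_step` helper (it still uses `@` for the seven half-size products), tested against `@` on random integer matrices:

```python
import numpy as np

def strassen_step(X, Y):
    """One level of Strassen on two square matrices of the same even size."""
    k = X.shape[0] // 2
    A, B, C, D = X[:k, :k], X[:k, k:], X[k:, :k], X[k:, k:]
    E, F, G, H = Y[:k, :k], Y[:k, k:], Y[k:, :k], Y[k:, k:]
    M1 = (A + D) @ (E + H)
    M2 = (C + D) @ E
    M3 = A @ (F - H)
    M4 = D @ (G - E)
    M5 = (A + B) @ H
    M6 = (C - A) @ (E + F)
    M7 = (B - D) @ (G + H)
    return np.block([[M1 + M4 - M5 + M7, M3 + M5],
                     [M2 + M4,           M1 - M2 + M3 + M6]])

rng = np.random.default_rng(1)
for size in (2, 4, 8, 10):
    X = rng.integers(-9, 10, size=(size, size))
    Y = rng.integers(-9, 10, size=(size, size))
    print(size, np.array_equal(strassen_step(X, Y), X @ Y))   # True each time
```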

## Exercise Answers





1. 1-4 simplify to the same things as UL, UR, LL, LR, respectively
1.
  1. 7 multiplications of $n\times n$ matrices
  1. 18 additions/subtractions of $n\times n$ matrices
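
Where the 18 comes from: count the $+$ and $-$ signs. Each of $M1$ through $M7$ needs one or two block additions/subtractions before its multiplication, and combining them into blocks 1-4 needs eight more:

$$\underbrace{2 + 1 + 1 + 1 + 1 + 2 + 2}_{M1,\ldots,M7} \;+\; \underbrace{3 + 1 + 1 + 3}_{\text{blocks 1-4}} \;=\; 10 + 8 \;=\; 18$$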

## Strassen Analysis

We've got three options for multiplication here:
1. Naive: $n\times n$ takes exactly $2n^3 - n^2$ operations.
2. Block: $2n\times 2n$ requires **8** $n\times n$ multiplications and **4** $n\times n$ additions.
  * This works out to $8(2n^3-n^2) + 4n^2 = 16n^3 - 4n^2 = 2(2n)^3 - (2n)^2$ operations, exactly the naive count for $2n\times 2n$.
3. Strassen: $2n\times 2n$ requires **7** $n\times n$ multiplications and **18** $n\times n$ additions.


The key insight here is that, **because matrix multiplication is so much slower than addition, it's worth it to do 14 extra additions to save one multiplication.**
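
One way to see when that trade pays off: compare a $2n\times 2n$ product done by blocks (8 naive half-size products plus 4 additions) against a single Strassen step (7 naive half-size products plus 18 additions). The helper names are invented for this sketch. Saving one $n\times n$ product ($2n^3-n^2$ operations) beats spending $14n^2$ extra operations once $2n-1 > 14$, that is, for half-size $n \ge 8$.

```python
def naive_ops(n):
    """Scalar operations for the straightforward n x n product."""
    return 2 * n**3 - n**2

def block_ops(n):
    """A 2n x 2n product as 8 naive n x n products plus 4 n x n additions."""
    return 8 * naive_ops(n) + 4 * n**2

def strassen_step_ops(n):
    """A 2n x 2n product as 7 naive n x n products plus 18 n x n additions."""
    return 7 * naive_ops(n) + 18 * n**2

for n in (1, 2, 4, 7, 8, 16, 64):
    print(f"2n={2*n:>3}: block {block_ops(n):>9,}  one Strassen step {strassen_step_ops(n):>9,}")
```

Note this sketch applies Strassen only once, at the top level; the analysis below recurses all the way down.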

Here's how the math works out. Let $s(n)$ be the number of operations required for Strassen multiplication of two $n\times n$ matrices, with $s(1)=1$ (a single scalar multiplication). Then
$$s(n) = 7s(n/2) + 18(n/2)^2$$
With a little hand-waving, we can ignore the $18(n/2)^2$ term, because the additions are negligible compared to the multiplications: they change the constant factor but not the exponent. With that simplification,
$$ s(n) \approx 7s(n/2) \approx 7(7s(n/4)) \approx \ldots$$
After $\log_2 n$ expansions, $n$ has been halved all the way down to 1, leaving a product of $\log_2 n$ sevens (times $s(1)=1$), i.e.
$$s(n) \approx 7^{\log_2 n}$$
Using the logarithm identity $a^{\log_2 b} = b^{\log_2 a}$, we can restate that equivalently as
$$s(n) \approx n^{\log_2 7} \approx n^{2.8}$$
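
A quick numerical check of those steps, as a sketch: the exponent $\log_2 7$, the identity $7^{\log_2 n} = n^{\log_2 7}$, and the claim that keeping the $18(n/2)^2$ term changes only the constant factor. Here `s` is the exact recurrence with $s(1)=1$, the same one used in the table in the exercise.

```python
import math

print(math.log2(7))                            # about 2.807

n = 1024
print(7 ** math.log2(n), n ** math.log2(7))    # equal up to floating-point rounding

def s(n):
    """Exact Strassen operation count: s(n) = 7 s(n/2) + 18 (n/2)^2, s(1) = 1."""
    if n == 1:
        return 1
    return 7 * s(n // 2) + 18 * (n // 2) ** 2

# The ratio s(n) / n^(log2 7) levels off at a constant (it approaches 7),
# so the additions affect the constant factor but not the exponent.
for k in (5, 10, 15, 20):
    n = 2 ** k
    print(n, s(n) / n ** math.log2(7))
```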


## Exercise: Strassen vs Naive/Block

So the fact that Strassen is $O(n^{2.8})$ means that it beats the normal $O(n^3)$ matrix multiplication *in the long run, when $n$ is large enough*. We want to find that break-even point: when is $n$ large enough that Strassen's algorithm would be faster? Below that size, we should just use regular multiplication.

Here is the beginning of a table. The exercise is to take this to a spreadsheet and continue it to find the break-even point.

| $n$ | Naive: $2n^3-n^2$ | Strassen: $s(n) = 7s(n/2)+18(n/2)^2$ |
|-----:|-----------------------:|-------------------------:|
| 1 | 1 | 1 |
| 2 | 12 | $7(1)+18(1)^2 = 25$ |
| 4 | 112 | $7(25)+18(2)^2 = 247$ |
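
If you'd like to check your spreadsheet, here is a short Python sketch that continues the same table for powers of two and stops at the first $n$ where Strassen's count drops below the naive count. It uses exactly the recurrence in the table, with Strassen recursing all the way down to $1\times 1$; the function names are made up. Running it reveals the answer, so you may want to try the spreadsheet first.

```python
def naive_ops(n):
    """Straightforward multiplication: 2n^3 - n^2 operations."""
    return 2 * n**3 - n**2

def strassen_ops(n):
    """s(n) = 7 s(n/2) + 18 (n/2)^2 with s(1) = 1; n must be a power of 2."""
    if n == 1:
        return 1
    half = n // 2
    return 7 * strassen_ops(half) + 18 * half**2

print(f"{'n':>5} | {'naive':>15} | {'Strassen':>15}")
n = 1
while True:
    naive, strassen = naive_ops(n), strassen_ops(n)
    print(f"{n:>5} | {naive:>15,} | {strassen:>15,}")
    if strassen < naive:        # first size where Strassen is cheaper
        break
    n *= 2
```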
