### Understanding the PyTorch parallel scan, with RNNs/Mamba in mind

This file gives detailed background on, and an explanation of, the `pscan.py` file. The goal is to implement a parallel scan in PyTorch.

In [2]:
import torch
import math

First of all, what is a <b>scan</b>?

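A scan takes as input an array `X` and a binary operator (for example, addition). It produces an output array `Y` of the same length, where each `Y[t]` combines all the input elements up to and including `X[t]`: `Y[t] = X[0] op X[1] op ... op X[t]`. This is quite general: the operator can be a sum, a product, a maximum, and so on. For the parallel version, we will need the operator to be associative.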
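To make the definition concrete, here is a minimal sketch of a sequential inclusive scan that takes the operator as a parameter. The function name `sequential_scan` is hypothetical, chosen for illustration; it is not part of `pscan.py`.

```python
import torch

def sequential_scan(X: torch.Tensor, op) -> torch.Tensor:
    """Inclusive scan: Y[t] = op(...op(op(X[0], X[1]), X[2])..., X[t])."""
    Y = torch.empty_like(X)
    acc = X[0]                      # running accumulator
    Y[0] = acc
    for t in range(1, X.size(0)):
        acc = op(acc, X[t])         # fold the next element into the accumulator
        Y[t] = acc
    return Y

X = torch.tensor([3, 1, 4, 1, 5])
print(sequential_scan(X, torch.add).tolist())      # [3, 4, 8, 9, 14]
print(sequential_scan(X, torch.mul).tolist())      # [3, 3, 12, 12, 60]
print(sequential_scan(X, torch.maximum).tolist())  # [3, 3, 4, 4, 5]
```

These three results match the PyTorch built-ins `torch.cumsum`, `torch.cumprod` and `torch.cummax` (the last one also returns the indices of the running maxima).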

A simple and well-known example of a scan is the <i>cumulative sum of an array</i>:

In [6]:
X = torch.tensor([1, 2, 3, 4])

torch.cumsum(X, dim=0)

tensor([ 1,  3,  6, 10])

For the rest of this document, we will denote `L` as the length of our input array `X`.

The most basic way to implement a scan is to use a simple for loop:

In [19]:
Y = torch.zeros_like(X)

cumulative_sum = 0
for t in range(X.size(0)):
    cumulative_sum += X[t]
    Y[t] = cumulative_sum

Y

tensor([ 1,  3,  6, 10])

Quite simple for now, right?
To set up our notation, we will stay with this example for a bit.

Here, we use an accumulator, `cumulative_sum`, which we update as we go through the input array `X`.

An equivalent way to write the above code is:

In [20]:
Y = torch.zeros_like(X)

Y[0] = X[0]
for t in range(1, X.size(0)):
    Y[t] = Y[t-1] + X[t]

Y

tensor([ 1,  3,  6, 10])

While not explicitly present, the accumulator is still here, in `Y`. It is propagated with the recurrence relation `Y[t] = Y[t-1] + X[t]`.

We can visualize what happens with a simple diagram, which should remind you a bit of RNNs:

<p align="center">
    <img src="assets/cumsum_rnns.jpg" alt="cumulative sum" width="1000" height="300" alt="python mamba"/>
</p>

In some sense, `Y` plays the role of the hidden state, while `X` plays the role of the input: as we process the input, we keep and update a running hidden state.
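The analogy with RNNs can be made a little more concrete. The sketch below uses a hypothetical `linear_recurrence` function, not code taken from `pscan.py`. It sequentially computes a scalar linear recurrence `h[t] = a[t] * h[t-1] + x[t]`, which we assume here as a simplified stand-in for the hidden-state update of linear RNNs such as Mamba. The real models work on batched, multi-dimensional tensors. Setting every `a[t]` to 1 recovers the cumulative sum.

```python
import torch

def linear_recurrence(a: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Compute h[t] = a[t] * h[t-1] + x[t] sequentially, assuming h[-1] = 0."""
    h = torch.zeros_like(x)
    h[0] = x[0]                          # a[0] is unused since h[-1] = 0
    for t in range(1, x.size(0)):
        h[t] = a[t] * h[t - 1] + x[t]    # the "hidden state" update
    return h

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(linear_recurrence(torch.ones(4), x).tolist())                     # [1.0, 3.0, 6.0, 10.0]
print(linear_recurrence(torch.full((4,), 0.5), torch.ones(4)).tolist())  # [1.0, 1.5, 1.75, 1.875]
```

Like the cumulative sum, this loop needs `L` sequential steps, because each `h[t]` depends on `h[t-1]`.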

This sequential loop needs `L` steps to compute the whole output `Y`. These steps cannot run at the same time, because each one needs the result of the previous one.

Now, is it possible to <b>parallelize</b> this scan operation?

This is exactly what the <b>parallel scan</b> does.

Let's stay with the simple example of our cumulative sum. In fact, let's simplify it even more: let's say we just want to compute the sum of our input array `X`.
Again, we could write a for loop that adds up the elements one by one. But can't we parallelize this computation?

Yes, and it is best visualized with this simple tree:

<p align="center">
    <img src="assets/reduction_tree.jpg" alt="cumulative sum" width="1000" height="600" alt="python mamba"/>
</p>

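If we assume that the length of our array is a power of 2, then we just have to group the elements two by two, add them, and repeat until we are left with one element: our result. If `L` is `2**d`, then we need `d = math.log2(L)` sequential steps to compute the sum of the array. We still perform `L - 1` additions in total, just like the for loop. However, the additions within one level of the tree are independent of each other, so they can run in parallel. Going from `L` sequential steps to `math.log2(L)` is a major speedup: for `L = 1024`, that is 10 steps instead of 1024.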
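The level and addition counts can be checked with a small out-of-place sketch. The helper name `tree_reduce` is hypothetical, used only for illustration. Each iteration of the `while` loop is one level of the tree, and all the additions inside one iteration could run in parallel.

```python
import torch

def tree_reduce(X: torch.Tensor):
    """Sum X level by level; returns (total, number of levels, number of additions)."""
    L = X.size(0)
    assert L > 0 and L & (L - 1) == 0, "this sketch assumes L is a power of 2"
    levels, additions = 0, 0
    while X.size(0) > 1:
        X = X[0::2] + X[1::2]   # every addition within a level is independent
        levels += 1
        additions += X.size(0)
    return X[0].item(), levels, additions

print(tree_reduce(torch.arange(1, 9)))   # (36, 3, 7)
print(tree_reduce(torch.ones(1024)))     # (1024.0, 10, 1023)
```

For `L = 8`, this gives 3 levels and `4 + 2 + 1 = 7` additions. This version allocates a new tensor at every level, which the in-place approach avoids.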

How can we implement this in Python?

In [110]:
X = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) # input array
L = X.size(0)

We can <i>group the elements by two</i> using:

In [111]:
Xa = X.view(L//2, 2)
Xa

tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])

We now have pairs of elements. We can do:

In [112]:
Xa[:, 0]

tensor([1, 3, 5, 7])

and

In [113]:
Xa[:, 1]

tensor([2, 4, 6, 8])

to access the first and the second element of each pair. To compute the first step of the tree, we simply need to sum these two arrays:

In [114]:
Xa[:, 0] + Xa[:, 1]

tensor([ 3,  7, 11, 15])

Yay! We have just completed our first step! Now, we just need to repeat what we've just done.

Note that we could do `Xa = Xa[:, 0] + Xa[:, 1]` and then simply repeat the previous step. But:
- this would allocate extra memory to store the result of the first step (`Xa` is currently a view that shares its data with `X`);
- we will reuse some of these intermediate values later, for the full scan operation.

Hence, we will update `X` <b>in-place</b>, by doing:

In [115]:
Xa[:, 1] += Xa[:, 0]
Xa[:, 1]

tensor([ 3,  7, 11, 15])
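The memory claims above can be checked with a short, self-contained sketch. It shows three things: `Xa` shares its storage with `X`, the out-of-place sum allocates a new tensor, and the in-place update writes directly into `X`. Comparing `data_ptr()` values is just one convenient way to check this.

```python
import torch

X = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
Xa = X.view(X.size(0) // 2, 2)
print(Xa.data_ptr() == X.data_ptr())   # True: Xa is a view, no copy

S = Xa[:, 0] + Xa[:, 1]                # out-of-place: a new tensor is allocated
print(S.data_ptr() == X.data_ptr())    # False

Xa[:, 1] += Xa[:, 0]                   # in-place: writes straight into X
print(X.tolist())                      # [1, 3, 3, 7, 5, 11, 7, 15]
```

Only the odd positions of `X` have changed. Each one now holds the sum of its pair, while the even positions still hold the original left elements.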

This is what we want. Now, we will repeat the step we have just done, but on `Xa[:, 1]` rather than on `X`.

In [116]:
Xa = Xa[:, 1]
Xa = Xa.view(Xa.size(0)//2, 2)
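One detail may be surprising here: `Xa[:, 1]` is not contiguous in memory (it picks every other element of `X`), yet `view` accepts it. As far as the `view` rules go, splitting one strided dimension into two is compatible with the existing strides, so no copy is needed. The sketch below rebuilds the state after the first step and inspects the strides. The names `col` and `Xb` are ours.

```python
import torch

X = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
Xa = X.view(4, 2)
Xa[:, 1] += Xa[:, 0]

col = Xa[:, 1]                               # elements X[1], X[3], X[5], X[7]
print(col.stride(), col.is_contiguous())     # (2,) False
Xb = col.view(2, 2)                          # still a view into X's storage
print(Xb.stride(), Xb.storage_offset())      # (4, 2) 1
print(Xb.tolist())                           # [[3, 7], [11, 15]]
```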

Again, `Xa` is split into two groups:

In [117]:
Xa

tensor([[ 3,  7],
        [11, 15]])

We sum these two groups, and put the result in the second column of `Xa`:

In [118]:
Xa[:, 1] += Xa[:, 0]
Xa[:, 1]

tensor([10, 26])

We have two elements left! This means only one more step to go (because `math.log2(2) = 1`, of course!):

In [119]:
Xa = Xa[:, 1]
Xa = Xa.view(Xa.size(0)//2, 2)

Xa[:, 1] += Xa[:, 0]
Xa[:, 1]

tensor([36])

This is the result we want, since `1 + 2 + ... + 8 = 36`!
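The three steps above follow the same pattern, so they can be folded into a loop. Below is a minimal sketch, assuming a 1-D tensor whose length is a power of 2. The function name `tree_sum_inplace` is hypothetical; it is not the function used in `pscan.py`. Besides the total, it leaves the partial sums of every tree level in `X`. This in-place pass resembles the up-sweep (reduce) phase of Blelloch's work-efficient scan.

```python
import torch

def tree_sum_inplace(X: torch.Tensor) -> torch.Tensor:
    """Reduce X in place with a pairwise tree; the total ends up in X[-1]."""
    L = X.size(0)
    assert L > 0 and L & (L - 1) == 0, "this sketch assumes L is a power of 2"
    Xa = X
    for _ in range(L.bit_length() - 1):      # log2(L) levels
        Xa = Xa.view(Xa.size(0) // 2, 2)     # group the current elements by two
        Xa[:, 1] += Xa[:, 0]                 # sum each pair into its second element
        Xa = Xa[:, 1]                        # keep only the partial sums (a strided view)
    return X[-1]

X = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
print(tree_sum_inplace(X).item())   # 36
print(X.tolist())                   # [1, 3, 3, 10, 5, 11, 7, 36]
```

The intermediate values left in `X` (3, 10, 11, ...) are the partial sums we chose to keep instead of allocating new tensors.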