This post mainly draws on  
- `Efficient Parallelization of a Ubiquitous Sequential Computation`  
- `Were RNNs All We Needed?`  
- [Parallel scan blog post](https://moreality.net/posts/54261/)  

It also serves as an implementation and exploration of the algorithm.

# Algorithm Overview

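The classic parallel scan (Blelloch) targets the prefix-sum problem below: compute the output from the input with O(n) total work in O(log n) parallel steps.  
This parallelizes cumulative addition over a sequence.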
```shell
# input:
[b0, b1, b2, b3, ..., bn]

# output:
[b0, (b0+b1), (b0+b1+b2), ..., (b0+b1+b2+...+bn)]
```

The problem above can be described by the following recurrence  
$$
x_{t} = x_{t-1} + b_{t}\ \ \ \ \ (x_0 = b_0)
$$
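As a point of reference, the recurrence can be evaluated sequentially in a few lines. This is only an illustrative baseline (the name `prefix_sum` and the sample values are made up here), not the parallel algorithm itself.

```python
from itertools import accumulate


def prefix_sum(b):
    """Inclusive prefix sum: x_0 = b_0, x_t = x_{t-1} + b_t."""
    xs = []
    running = 0
    for value in b:
        running += value      # x_t = x_{t-1} + b_t
        xs.append(running)
    return xs


b = [3, 1, 4, 1, 5]
print(prefix_sum(b))  # [3, 4, 8, 9, 14]

# The standard library computes the same thing.
assert prefix_sum(b) == list(accumulate(b))
```

Each step depends on the previous one, which is exactly the dependency chain that a parallel scan breaks up.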

Now consider the more general recurrence
$$
x_t = a_t \cdot x_{t-1} + b_t \ \ \ \ (x_0 = b_0)
$$
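The first paper listed above, `Efficient Parallelization of a Ubiquitous Sequential Computation`, proposes rewriting this form in terms of basic $\Sigma$ (summation) operations, so that parallel scan can be applied.
That is,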
```shell
#input
[(,b0),(a1,b1),(a2,b2),(a3,b3)...,(an,bn)]

#output
[x0,x1,x2,x3...,xn]
```
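This is the problem of parallelizing a sequence of linear updates; the rest of this section shows how to convert it into the problem of parallelizing prefix sums.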
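For concreteness, here is a sequential evaluation that uses the same input layout as the listing above. It is a hedged baseline sketch: the function name `linear_recurrence`, the use of `None` for the missing $a_0$, and the sample numbers are assumptions made for illustration.

```python
def linear_recurrence(pairs):
    """Evaluate x_t = a_t * x_{t-1} + b_t with x_0 = b_0.

    pairs is [(None, b0), (a1, b1), ..., (an, bn)]; returns [x0, x1, ..., xn].
    """
    _, b0 = pairs[0]          # a_0 is unused, x_0 = b_0
    xs = [b0]
    for a_t, b_t in pairs[1:]:
        xs.append(a_t * xs[-1] + b_t)
    return xs


pairs = [(None, 2), (3, 1), (2, 5), (1, 4)]
print(linear_recurrence(pairs))  # [2, 7, 19, 23]
```

Setting every $a_t = 1$ recovers the plain prefix sum, which is why the prefix-sum case is the natural target to reduce to.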

The paper's result is
$$
\log x_t = a_t^* + \log(x_0 + b_t^*)
$$
where
$$
a_t^* = \sum_{i=1}^{t} \log a_i
$$
$$
b_t^* = \sum_{i=1}^{t} e^{\log b_i - a_i^*}
$$
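A short sketch of where this result comes from, assuming every $a_i > 0$ and writing $A_t = \prod_{i=1}^{t} a_i$ (the symbol $A_t$ is introduced here only for illustration). Unrolling the recurrence gives
$$
x_t = A_t \left( x_0 + \sum_{i=1}^{t} \frac{b_i}{A_i} \right)
$$
which can be checked by induction: substituting the expression for $x_{t-1}$ into $a_t \cdot x_{t-1} + b_t$ and using $a_t A_{t-1} = A_t$ reproduces the same form for $x_t$. Taking logarithms, with $\log A_t = a_t^*$ and $b_i / A_i = e^{\log b_i - a_i^*}$, yields
$$
\log x_t = a_t^* + \log\left( x_0 + \sum_{i=1}^{t} e^{\log b_i - a_i^*} \right)
$$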
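So to obtain the sequence $x_t$, we only need the sequences $a_t^*$ and $b_t^*$.  
$a_t^*$ is a prefix sum, so parallel scan applies to it directly.  
Once $a_t^*$ has been computed, $b_t^*$ also becomes a prefix sum, and the problem is solved.  
The figure below shows the order of computation.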

![image.png](attachment:98c60276-42d3-43be-a17d-d72e4f5db140.png)
Padding needs attention here; the code appendix of `Were RNNs All We Needed?` can serve as a reference.
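The log-space formulation can be sketched in plain Python as follows. This is a minimal sketch, not the papers' code: the function name `log_space_recurrence` and the sample values are made up for illustration, it assumes every $a_t$, $b_t$ and $x_0$ is strictly positive so that the real logarithms exist, and `itertools.accumulate` stands in for the parallel scan (each `accumulate` call is one prefix sum). The leading `0.0` and `x0` entries play the role of the padding for $t = 0$.

```python
import math
from itertools import accumulate


def log_space_recurrence(a, b, x0):
    """Compute [x0, x1, ..., xn] for x_t = a_t * x_{t-1} + b_t via two prefix sums.

    a and b hold the coefficients for t = 1..n; all inputs are assumed > 0.
    """
    # a_star[t] = sum_{i<=t} log a_i, padded with 0.0 for t = 0.
    a_star = [0.0] + list(accumulate(math.log(a_t) for a_t in a))
    # Terms of the second prefix sum: x0 for t = 0, then exp(log b_t - a_star[t]).
    terms = [x0] + [math.exp(math.log(b_t) - s) for b_t, s in zip(b, a_star[1:])]
    # shifted[t] = x0 + b_star[t]
    shifted = list(accumulate(terms))
    # log x_t = a_star[t] + log(x0 + b_star[t])
    return [math.exp(s + math.log(p)) for s, p in zip(a_star, shifted)]


a = [0.5, 2.0, 1.5]
b = [1.0, 3.0, 0.25]
xs = log_space_recurrence(a, b, x0=2.0)

# Evaluating the recurrence by hand gives x1 = 2, x2 = 7, x3 = 10.75.
expected = [2.0, 2.0, 7.0, 10.75]
assert all(math.isclose(x, e, rel_tol=1e-9) for x, e in zip(xs, expected))
print(xs)
```

The results are compared with `math.isclose` rather than `==` because the round trip through `log` and `exp` introduces floating-point error. In a tensor library the second prefix sum would typically use a log-cumsum-exp primitive (for example `torch.logcumsumexp` in PyTorch) instead of exponentiating each term first.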

# Algorithm Verification

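To make sure our understanding is correct, we **(1)** first solve the problem below with the conventional sequential method, **(2)** then implement the algorithm from the figure above on top of parallel scan, and **(3)** compare the results.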
```shell
#input
[(,b0),(a1,b1),(a2,b2),(a3,b3)...,(an,bn)]

#output
[x0,x1,x2,x3...,xn]
```
First, an implementation of the parallel scan function:

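```python
import random


def blelloch_scan(arr):
    """Exclusive prefix sum via the Blelloch up-sweep/down-sweep scan (simulated sequentially)."""
    arr = list(arr)
    n = len(arr)
    if n == 0:
        return arr

    # The tree-based scan assumes a power-of-two length, so pad with zeros.
    size = 1
    while size < n:
        size *= 2
    arr += [0] * (size - n)

    print("Initial array:", arr)

    # Up-sweep (reduce): accumulate partial sums up the tree.
    step = 1
    print("\nUpsweep Phase:")
    while step < size:
        for i in range(0, size, 2 * step):
            arr[i + 2 * step - 1] += arr[i + step - 1]
        print(f"After step size {step}, array: {arr}")
        step *= 2

    # Clear the root: 0 is the identity element for the exclusive scan.
    arr[-1] = 0
    print("\nSetting root to 0, array:", arr)

    # Down-sweep: push prefix sums back down the tree.
    step = size // 2
    print("\nDownsweep Phase:")
    while step > 0:
        for i in range(0, size, 2 * step):
            temp = arr[i + step - 1]
            arr[i + step - 1] = arr[i + 2 * step - 1]
            arr[i + 2 * step - 1] += temp
        print(f"After step size {step}, array: {arr}")
        step //= 2

    return arr[:n]  # drop the padding


input_array = [random.randint(0, 10) for _ in range(16)]
arr = blelloch_scan(input_array)
# The scan is exclusive; adding each input element gives the inclusive prefix sum.
prefix_sum_array = [input_array[i] + arr[i] for i in range(len(input_array))]

print(f"\nFinal PrefixSum array: {prefix_sum_array}")
```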

```
Initial array: [1, 9, 2, 7, 7, 6, 10, 4, 5, 1, 4, 3, 4, 7, 3, 9]

Upsweep Phase:
After step size 1, array: [1, 10, 2, 9, 7, 13, 10, 14, 5, 6, 4, 7, 4, 11, 3, 12]
After step size 2, array: [1, 10, 2, 19, 7, 13, 10, 27, 5, 6, 4, 13, 4, 11, 3, 23]
After step size 4, array: [1, 10, 2, 19, 7, 13, 10, 46, 5, 6, 4, 13, 4, 11, 3, 36]
After step size 8, array: [1, 10, 2, 19, 7, 13, 10, 46, 5, 6, 4, 13, 4, 11, 3, 82]

Setting root to 0, array: [1, 10, 2, 19, 7, 13, 10, 46, 5, 6, 4, 13, 4, 11, 3, 0]

Downsweep Phase:
After step size 8, array: [1, 10, 2, 19, 7, 13, 10, 0, 5, 6, 4, 13, 4, 11, 3, 46]
After step size 4, array: [1, 10, 2, 0, 7, 13, 10, 19, 5, 6, 4, 46, 4, 11, 3, 59]
After step size 2, array: [1, 0, 2, 10, 7, 19, 10, 32, 5, 46, 4, 52, 4, 59, 3, 70]
After step size 1, array: [0, 1, 10, 12, 19, 26, 32, 42, 46, 51, 52, 56, 59, 63, 70, 73]

Final PrefixSum array: [1, 10, 12, 19, 26, 32, 42, 46, 51, 52, 56, 59, 63, 70, 73, 82]
```
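The run above can be cross-checked independently of the scan code. The sketch below copies the input and the final down-sweep array from that output and confirms that adding the input to the exclusive scan gives the inclusive prefix sum; the helper name `exclusive_to_inclusive` is made up for illustration.

```python
from itertools import accumulate


def exclusive_to_inclusive(values, exclusive):
    """Turn an exclusive prefix sum into an inclusive one by adding each input element."""
    return [v + e for v, e in zip(values, exclusive)]


# Input and final down-sweep array copied from the run shown above.
values = [1, 9, 2, 7, 7, 6, 10, 4, 5, 1, 4, 3, 4, 7, 3, 9]
exclusive = [0, 1, 10, 12, 19, 26, 32, 42, 46, 51, 52, 56, 59, 63, 70, 73]

inclusive = exclusive_to_inclusive(values, exclusive)

# Reference result from the standard library's sequential prefix sum.
assert inclusive == list(accumulate(values))
print(inclusive)
# [1, 10, 12, 19, 26, 32, 42, 46, 51, 52, 56, 59, 63, 70, 73, 82]
```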
