In [5]:
import math
import torch

from pscan_unfolded import pscan

In [6]:
# Problem sizes used to build the random test tensors below.
# NOTE(review): the tensors here are created with dimension order (B, L, D, N),
# while pscan/pscan_rev defined later in this notebook unpack A.size() as
# (B, D, L, _) — confirm which layout is intended before passing these in.
B, L, D, N = 1, 16, 1, 1

# Random inputs, created first and then moved to the GPU with .to("cuda");
# this cell therefore needs a CUDA device to run.
A = torch.randn(B, L, D, N).to("cuda")
X = torch.randn(B, L, D, N).to("cuda")

In [67]:
def pscan(A, X):
    """Forward parallel scan, computed in place on ``X``.

    After the call, ``X[:, :, t]`` holds ``H[t]`` where
    ``H[t] = A[:, :, t] * H[t-1] + X[:, :, t]`` and ``H[-1] = 0``, using the
    values ``A`` and ``X`` had on entry. The work is organised as an up sweep
    followed by a down sweep over a binary tree, so the number of sequential
    steps is about ``2 * log2(L)`` instead of ``L``.

    Args:
        A: 4-D tensor whose size is unpacked as ``(B, D, L, N)``; the
            multiplicative coefficients. It is used as scratch space and is
            overwritten with intermediate products, so clone it first if the
            original values are still needed.
        X: tensor expected to share the ``(B, D, L, N)`` layout of ``A``; it is
            overwritten with the scan result. Both tensors are reshaped with
            ``view``, so they must allow splitting dimension 2 without a copy.

    Returns:
        None. The result is written into ``X``.

    Raises:
        ValueError: if ``L`` (``A.size(2)``) is not a power of two. Without
            this check such lengths either returned silently with ``X`` only
            partially scanned or failed inside ``view``.
    """
    B, D, L, _ = A.size()
    if L < 1 or L & (L - 1):
        raise ValueError(f"pscan requires a power-of-two sequence length, got L={L}")
    num_steps = int(math.log2(L))

    # up sweep (last 2 steps unfolded)
    Aa = A
    Xa = X
    for _ in range(num_steps-2):
        T = Xa.size(2)
        # group the remaining nodes in pairs: index 0 = left, index 1 = right
        Aa = Aa.view(B, D, T//2, 2, -1)
        Xa = Xa.view(B, D, T//2, 2, -1)

        # fold each left node into its right neighbour
        Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
        Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])

        # only the right nodes take part in the next level
        Aa = Aa[:, :, :, 1]
        Xa = Xa[:, :, :, 1]

    # we have only 4, 2 or 1 nodes left
    if Xa.size(2) == 4:
        Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
        Aa[:, :, 1].mul_(Aa[:, :, 0])

        # two tree levels at once for the last node
        Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1])))
        # A's last node is deliberately not updated: the down sweep never reads it.
    elif Xa.size(2) == 2:
        Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
        # nothing reads A after this point, so its product is not computed
        return
    else:
        # L == 1: X already is the scan result
        return

    # down sweep (first 2 steps unfolded)
    Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
    Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
    Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
    Aa[:, :, 2].mul_(Aa[:, :, 1])

    for k in range(num_steps-3, -1, -1):
        # nodes of tree level k: every 2**k-th position, starting at 2**k - 1
        Aa = A[:, :, 2**k-1:L:2**k]
        Xa = X[:, :, 2**k-1:L:2**k]

        T = Xa.size(2)
        Aa = Aa.view(B, D, T//2, 2, -1)
        Xa = Xa.view(B, D, T//2, 2, -1)

        # each left node (except the first) receives the finished value of the
        # right node of the previous pair
        Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
        Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])

def pscan_rev(A, X):
    """Reverse (right-to-left) parallel scan, computed in place on ``X``.

    After the call, ``X[:, :, t]`` holds ``H[t]`` where
    ``H[t] = A[:, :, t] * H[t+1] + X[:, :, t]`` and ``H[L] = 0``, using the
    values ``A`` and ``X`` had on entry. It mirrors the forward scan: the same
    up sweep / down sweep over a binary tree, with left and right swapped.

    Args:
        A: 4-D tensor whose size is unpacked as ``(B, D, L, N)``; the
            multiplicative coefficients. It is used as scratch space and is
            overwritten with intermediate products, so clone it first if the
            original values are still needed.
        X: tensor expected to share the ``(B, D, L, N)`` layout of ``A``; it is
            overwritten with the scan result. Both tensors are reshaped with
            ``view``, so they must allow splitting dimension 2 without a copy.

    Returns:
        None. The result is written into ``X``.

    Raises:
        ValueError: if ``L`` (``A.size(2)``) is not a power of two. Without
            this check such lengths either returned silently with ``X`` only
            partially scanned or failed inside ``view``.
    """
    B, D, L, _ = A.size()
    if L < 1 or L & (L - 1):
        raise ValueError(f"pscan_rev requires a power-of-two sequence length, got L={L}")
    num_steps = int(math.log2(L))

    # up sweep (last 2 steps unfolded)
    Aa = A
    Xa = X
    for _ in range(num_steps-2):
        T = Xa.size(2)
        # group the remaining nodes in pairs: index 0 = left, index 1 = right
        Aa = Aa.view(B, D, T//2, 2, -1)
        Xa = Xa.view(B, D, T//2, 2, -1)

        # fold each right node into its left neighbour
        Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1]))
        Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])

        # only the left nodes take part in the next level
        Aa = Aa[:, :, :, 0]
        Xa = Xa[:, :, :, 0]

    # we have only 4, 2 or 1 nodes left
    if Xa.size(2) == 4:
        Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
        Aa[:, :, 2].mul_(Aa[:, :, 3])

        # Two tree levels at once for the first node:
        #   H0 = X0 + A0 * (X1 + A1 * H2), with H2 already stored in Xa[:, :, 2].
        # The inner coefficient must be node 1's; node 2's coefficient was just
        # overwritten with A2*A3 and would give a wrong result for any
        # non-trivial A.
        Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1] + Aa[:, :, 1].mul(Xa[:, :, 2])))
        # A's first node is deliberately not updated: the down sweep never reads it.
    elif Xa.size(2) == 2:
        Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1]))
        # nothing reads A after this point, so its product is not computed
        return
    else:
        # L == 1: X already is the scan result
        return

    # down sweep (first 2 steps unfolded)
    Aa = A[:, :, 0:L:2**(num_steps-2)]
    Xa = X[:, :, 0:L:2**(num_steps-2)]
    Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
    Aa[:, :, 1].mul_(Aa[:, :, 2])

    for k in range(num_steps-3, -1, -1):
        # nodes of tree level k: every 2**k-th position, starting at 0
        Aa = A[:, :, 0:L:2**k]
        Xa = X[:, :, 0:L:2**k]

        T = Xa.size(2)
        Aa = Aa.view(B, D, T//2, 2, -1)
        Xa = Xa.view(B, D, T//2, 2, -1)

        # each right node (except the last) receives the finished value of the
        # left node of the next pair
        Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0]))
        Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])

In [71]:
# Small integer inputs of shape (1, 1, 8, 1) for checking the scans by hand.
# With every coefficient equal to 1 the scan reduces to a running sum of X,
# so this input cannot reveal mistakes in how the coefficients are combined.
A = torch.ones(1, 1, 8, 1, dtype=torch.long)
X = torch.arange(1, 9).view(1, 1, 8, 1)

In [72]:
# Run the reverse scan. It returns nothing: X is overwritten with the result
# and A is modified in place as well, so re-running this cell without first
# re-running the cell that builds A and X gives a different answer.
pscan_rev(A, X)

In [73]:
# Display X after the in-place reverse scan; with A all ones the expected
# values are the suffix sums of 1..8 (36, 35, 33, ..., 8).
X

tensor([[[[36],
          [35],
          [33],
          [30],
          [26],
          [21],
          [15],
          [ 8]]]])

In [None]:
from 