In [1]:
import torch

from pscan_unfolded import pscan
from pscan_unfolded_rev import pscan as pscan_w_rev

In [2]:
# Fall back to CPU so the notebook still runs on machines without a GPU
# (the outputs recorded in this notebook were produced on CUDA).
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# Sizes of the four tensor dimensions; the cells below build tensors of shape (B, L, D, N).
B = 45
L = 64
D = 32
N = 16

In [4]:
# Seeded random inputs, drawn on CPU in float64 and then moved to `device`.
# Layout here is (B, L, D, N) — presumably what pscan_unfolded expects; TODO confirm
# (the in-place pscan defined further down works on (B, D, L, N)).
torch.manual_seed(1)
A = torch.randn((B, L, D, N), dtype=torch.float64).to(device).requires_grad_()
X = torch.randn((B, L, D, N), dtype=torch.float64).to(device).requires_grad_()

In [5]:
# Forward + backward through pscan to populate A.grad and X.grad.
# `pscan` must be the autograd-capable one imported from pscan_unfolded here:
# the in-place pscan defined in a later cell returns None.
Y = pscan(A, X)
J = torch.sum(Y)
J.backward()

In [6]:
# Same seed, shapes and dtype as A/X above, so Ab/Xb are identical copies used to
# exercise the second implementation (pscan_w_rev).
torch.manual_seed(1)
Ab = torch.randn((B, L, D, N), dtype=torch.float64).to(device).requires_grad_()
Xb = torch.randn((B, L, D, N), dtype=torch.float64).to(device).requires_grad_()

In [7]:
# Forward + backward through pscan_w_rev (imported from pscan_unfolded_rev)
# to populate Ab.grad and Xb.grad.
Y_w_rev = pscan_w_rev(Ab, Xb)
J_w_rev = torch.sum(Y_w_rev)
J_w_rev.backward()

In [8]:
# Forward outputs of both implementations should agree (default tolerances).
Y.allclose(Y_w_rev)

True

In [9]:
# Gradients w.r.t. X from both implementations should agree (rtol = 1e-3).
X.grad.allclose(Xb.grad, rtol=1e-3)

True

In [10]:
# Gradients w.r.t. A from both implementations should agree (rtol = 1e-3).
A.grad.allclose(Ab.grad, rtol=1e-3)

True

In [23]:
# Tiny sizes for the shift/pad experiments below
# (these overwrite the sizes used for the pscan comparison above).
B = 1
L = 8
D = 1
N = 1

In [24]:
# All-ones A with layout (B, D, L, N); replaces the random A used above.
A = torch.ones((B, D, L, N), device=device)

In [32]:
# Two ways to shift A left by one position along dim 2 and put a zero slot at the end.
# TODO: doesn't torch.cat already make a copy? If so, no extra .clone() is needed — confirm with timeit.
# NOTE: this reuses the name B (previously the batch size) for a tensor.
B = torch.cat((A[:, :, 1:], torch.zeros_like(A[:, :, -1:])), dim=2)
C = torch.nn.functional.pad(A[:, :, 1:], pad=(0, 0, 0, 1))

In [33]:
# Display the original tensor and both shifted versions side by side.
(A, B, C)

(tensor([[[[1.],
           [1.],
           [1.],
           [1.],
           [1.],
           [1.],
           [1.],
           [1.]]]], device='cuda:0'),
 tensor([[[[1.],
           [1.],
           [1.],
           [1.],
           [1.],
           [1.],
           [1.],
           [0.]]]], device='cuda:0'),
 tensor([[[[1.],
           [1.],
           [1.],
           [1.],
           [1.],
           [1.],
           [1.],
           [0.]]]], device='cuda:0'))

In [29]:
# NOTE(review): the recorded output (False) predates the cell that built B and C
# (In [29] vs In [32]); the B and C displayed above are equal, so a fresh run
# should print True.
B.allclose(C)

False

In [7]:
# Zero slot appended at the END of dim 2 — the same left shift as C above;
# (0, 0, 1, 0) would prepend it instead. The recorded shape below comes from an
# earlier, larger A.
torch.nn.functional.pad(A[:, :, 1:], (0, 0, 0, 1)).shape

torch.Size([45, 32, 64, 16])

In [9]:
# Shift-left variant that repeats the last slice (instead of a zero slot); only the shape is inspected.
torch.cat((A[:, :, 1:], A[:, :, -1:]), dim=2).shape

torch.Size([45, 32, 64, 16])

In [40]:
%timeit torch.cat([A[:, :, 1:], A[:, :, [-1]]], dim=2)

10.5 ms ± 218 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [41]:
%timeit torch.nn.functional.pad(A[:, :, 1:], (0, 0, 1, 0))

8.92 ms ± 54.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [21]:
%timeit torch.cat([A[:, :, 1:], A[:, :, [-1]]], dim=2)

509 µs ± 34.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [22]:
%timeit A.clone()

406 µs ± 2.37 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [15]:
import math

def pscan(A, X):
    """In-place parallel scan of H[t] = A[t] * H[t-1] + X[t], with H[-1] = 0 (so H[0] = X[0]).

    NOTE(review): this definition shadows the ``pscan`` imported from
    ``pscan_unfolded`` in the first cell; once this cell has run, later calls to
    ``pscan`` use this version, which returns None and works in place.

    Args:
        A: (B, D, L, N) tensor of multiplicative coefficients. Modified in place
            (left holding partial products of the coefficients).
        X: (B, D, L, N) tensor of additive inputs. Overwritten in place with H.

    The scan is an up sweep followed by a down sweep (the last two up-sweep
    levels and the first two down-sweep levels are unrolled), i.e. about
    2*log2(L) sequential steps instead of L.

    Returns:
        None; the result is written into X.

    Raises:
        ValueError: if L is not a positive power of two.
    """
    B, D, L, _ = A.size()
    # The pairwise views below split each level into (T//2, 2), which only works if
    # every level halves evenly. Any other L used to fall through to an early
    # `return` and leave X silently unscanned or only partially scanned.
    if L <= 0 or 2 ** int(math.log2(L)) != L:
        raise ValueError(f"pscan requires the scan length L to be a power of two, got L={L}")
    num_steps = int(math.log2(L))

    # up sweep (last 2 steps unfolded): combine adjacent pairs into the right element
    Aa = A
    Xa = X
    for _ in range(num_steps-2):
        T = Xa.size(2)
        Aa = Aa.view(B, D, T//2, 2, -1)
        Xa = Xa.view(B, D, T//2, 2, -1)

        Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
        Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])

        Aa = Aa[:, :, :, 1]
        Xa = Xa[:, :, :, 1]

    # we have only 4, 2 or 1 nodes left
    if Xa.size(2) == 4:
        Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
        Aa[:, :, 1].mul_(Aa[:, :, 0])

        Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1])))
        # Aa[:, :, 3] *= Aa[:, :, 2] is deliberately skipped: no later step reads it.
    elif Xa.size(2) == 2:
        Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
        # Aa[:, :, 1] *= Aa[:, :, 0] is skipped: we return right away.
        return
    else:
        # L == 1: H[0] = X[0], nothing to do.
        return

    # down sweep (first 2 steps unfolded): propagate completed prefixes to the
    # elements the up sweep left partial
    Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
    Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
    Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
    Aa[:, :, 2].mul_(Aa[:, :, 1])

    for k in range(num_steps-3, -1, -1):
        Aa = A[:, :, 2**k-1:L:2**k]
        Xa = X[:, :, 2**k-1:L:2**k]

        T = Xa.size(2)
        Aa = Aa.view(B, D, T//2, 2, -1)
        Xa = Xa.view(B, D, T//2, 2, -1)

        # the left element of each pair (except the first) receives the finished
        # prefix ending at the right element of the previous pair
        Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
        Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])

def pscan_rev(A, X):
    """In-place reverse parallel scan of H[t] = A[t] * H[t+1] + X[t], with H[L] = 0 (so H[L-1] = X[L-1]).

    Mirror image of ``pscan``: pairs are combined into their LEFT element, so
    results accumulate from the end of the sequence toward the start.

    Args:
        A: (B, D, L, N) tensor of multiplicative coefficients. Modified in place
            (left holding partial products of the coefficients).
        X: (B, D, L, N) tensor of additive inputs. Overwritten in place with H.

    Returns:
        None; the result is written into X.

    Raises:
        ValueError: if L is not a positive power of two.
    """
    B, D, L, _ = A.size()
    # The pairwise views below need every level to halve evenly; any other L used
    # to fall through to an early `return` and leave X silently (partially) unscanned.
    if L <= 0 or 2 ** int(math.log2(L)) != L:
        raise ValueError(f"pscan_rev requires the scan length L to be a power of two, got L={L}")
    num_steps = int(math.log2(L))

    # up sweep (last 2 steps unfolded): combine adjacent pairs into the left element
    Aa = A
    Xa = X
    for _ in range(num_steps-2):
        T = Xa.size(2)
        Aa = Aa.view(B, D, T//2, 2, -1)
        Xa = Xa.view(B, D, T//2, 2, -1)

        Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1]))
        Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])

        Aa = Aa[:, :, :, 0]
        Xa = Xa[:, :, :, 0]

    # we have only 4, 2 or 1 nodes left
    if Xa.size(2) == 4:
        Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
        Aa[:, :, 2].mul_(Aa[:, :, 3])

        Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1] + Aa[:, :, 1].mul(Xa[:, :, 2])))
        # Aa[:, :, 0] *= Aa[:, :, 1] is deliberately skipped: no later step reads it.
    elif Xa.size(2) == 2:
        Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1]))
        # Aa[:, :, 0] *= Aa[:, :, 1] is skipped: we return right away.
        return
    else:
        # L == 1: H[0] = X[0], nothing to do.
        return

    # down sweep (first 2 steps unfolded): propagate completed suffixes to the
    # elements the up sweep left partial
    Aa = A[:, :, 0:L:2**(num_steps-2)]
    Xa = X[:, :, 0:L:2**(num_steps-2)]
    Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
    Aa[:, :, 1].mul_(Aa[:, :, 2])

    for k in range(num_steps-3, -1, -1):
        Aa = A[:, :, 0:L:2**k]
        Xa = X[:, :, 0:L:2**k]

        T = Xa.size(2)
        Aa = Aa.view(B, D, T//2, 2, -1)
        Xa = Xa.view(B, D, T//2, 2, -1)

        # the right element of each pair (except the last) receives the finished
        # suffix starting at the left element of the next pair
        Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0]))
        Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])

In [25]:
# Hand-checkable example for the reverse scan: A = [1, 1, 2, 1], X = [1, 1, 1, 1],
# laid out as (B, D, L, N) = (1, 1, 4, 1). Note: this replaces the earlier Ab/Xb.
Ab = torch.tensor([1.0, 1.0, 2.0, 1.0], dtype=torch.float32).reshape(1, 1, 4, 1).to(device)
Xb = torch.ones(4, dtype=torch.float32).reshape(1, 1, 4, 1).to(device)

In [26]:
# Scans Xb in place (and overwrites Ab with partial products); re-running this cell
# without re-running the previous one scans already-scanned data.
pscan_rev(A=Ab, X=Xb)

In [27]:
# Reverse scan of A = [1, 1, 2, 1], X = [1, 1, 1, 1] via H[t] = A[t] * H[t+1] + X[t]
# gives H = [5, 4, 3, 1].
Xb

tensor([[[[5.],
          [4.],
          [3.],
          [1.]]]], device='cuda:0')

In [25]:
# NOTE(review): the output recorded below looks stale. It shows an 8-element Xb
# from an earlier execution (its In [25] counter duplicates the cell that builds the
# 4-element Ab/Xb above). A top-to-bottom re-run should display the 4-element result.
Xb

tensor([[[[ 9.],
          [12.],
          [11.],
          [ 5.],
          [ 4.],
          [ 3.],
          [ 2.],
          [ 1.]]]], device='cuda:0')