In [5]:
import torch

from mambapy.pscan import pscan

In [6]:
# Problem sizes. A and X below are built as (B, L, D, N) tensors and chunked
# along dim 1, so L is the scanned (sequence) length; B, D, N are presumably
# batch size, model dim and state size — TODO confirm against mambapy.
B = 1
L = 256
D = 58
N = 6

# Sequence positions where the chunked run splits the scan into three pieces:
# [0, cut1), [cut1, cut2) and [cut2, L).
cut1, cut2 = 74, 245

In [3]:
# full pscan: reference run that scans the whole length-L sequence in one call.

# Fixed seed so the chunked cell can regenerate exactly the same A and X.
torch.manual_seed(123456)
A = torch.randn((B, L, D, N), requires_grad=True)
X = torch.randn((B, L, D, N), requires_grad=True)

hs = pscan(A, X)

# Reduce the hidden states to a scalar and backprop, filling A.grad and X.grad.
J = hs.sum()
J.backward()

In [4]:
# chunked pscan: split the sequence axis at cut1 and cut2 and scan the three
# pieces separately; each later piece starts from the last hidden state of
# the piece before it.

# Re-seed with the full-scan seed: A_chunks / X_chunks get the same values as
# A / X but are independent leaf tensors, so their .grad can be compared.
torch.manual_seed(123456)
A_chunks = torch.randn((B, L, D, N), requires_grad=True)
X_chunks = torch.randn((B, L, D, N), requires_grad=True)

hs_1 = pscan(A_chunks[:, 0:cut1], X_chunks[:, 0:cut1])
hs_2 = pscan(A_chunks[:, cut1:cut2], X_chunks[:, cut1:cut2], hs_1[:, -1])
hs_3 = pscan(A_chunks[:, cut2:L], X_chunks[:, cut2:L], hs_2[:, -1])

# Stitch the per-chunk hidden states back together along the sequence axis.
hs_chunks = torch.cat((hs_1, hs_2, hs_3), dim=1)

J_chunks = hs_chunks.sum()
J_chunks.backward()

In [5]:
# checks: the chunked run must reproduce the full run's hidden states and the
# gradients w.r.t. both inputs. Each result is printed (as before) and also
# asserted, so a mismatch fails "Run All" instead of silently printing False.
hs_match = torch.allclose(hs, hs_chunks, rtol=0.001)
A_grad_match = torch.allclose(A.grad, A_chunks.grad, rtol=0.01)
X_grad_match = torch.allclose(X.grad, X_chunks.grad, rtol=0.01)

print(hs_match)
print(A_grad_match)
print(X_grad_match)

assert hs_match, "chunked hidden states differ from the full pscan"
assert A_grad_match, "chunked A.grad differs from the full pscan"
assert X_grad_match, "chunked X.grad differs from the full pscan"

True
True
True


In [1]:
import torch

from mambapy.pscan import pscan

In [2]:
# Problem sizes for the (B, L, D, N) tensors below; L is the scanned
# (sequence) axis that the chunked run splits up. B, D, N are presumably
# batch size, model dim and state size — TODO confirm against mambapy.
B = 1
L = 256
D = 58
N = 6

In [3]:
# full pscan: reference run that scans the whole length-L sequence in one call.

# Fixed seed so the chunked cell can regenerate exactly the same A and X.
torch.manual_seed(123456)
A = torch.randn((B, L, D, N), requires_grad=True)
X = torch.randn((B, L, D, N), requires_grad=True)

hs = pscan(A, X)

# Reduce the hidden states to a scalar and backprop, filling A.grad and X.grad.
J = hs.sum()
J.backward()

In [6]:
# chunked pscan: scan the sequence in consecutive chunks of `chunk_size` steps
# (plus a shorter tail chunk when L is not a multiple of chunk_size), passing
# the last hidden state of each chunk as the initial state of the next one.

# Same seed as the full-scan cell, so A_chunks / X_chunks hold the same values
# as A / X but are separate leaf tensors with their own .grad.
torch.manual_seed(123456)
A_chunks = torch.randn(B, L, D, N, requires_grad=True)
X_chunks = torch.randn(B, L, D, N, requires_grad=True)

chunk_size = 63  # author's note: best is a power of 2 minus 1
num_chunks = L // chunk_size
remainder = L % chunk_size

print(f"number of chunks: {num_chunks} and remainder: {remainder}")

last_hidden = None
hs_chunk_list = []

# Stepping by chunk_size visits every full chunk and then the remainder tail
# (slicing past L just truncates), so all chunks go through the same
# `last_hidden is None` guard below. Previously the tail was handled
# separately and passed last_hidden unconditionally, which sent None as the
# 3rd argument whenever L < chunk_size.
for start_idx in range(0, L, chunk_size):
    end_idx = start_idx + chunk_size
    A_chunk = A_chunks[:, start_idx:end_idx]
    X_chunk = X_chunks[:, start_idx:end_idx]

    # Only pass an initial hidden state when one exists: when pscan is called
    # with 3 args, its backward has to send back 3 tensors, so the first chunk
    # must use the 2-argument form rather than passing None.
    if last_hidden is None:
        hs_chunk = pscan(A_chunk, X_chunk)
    else:
        hs_chunk = pscan(A_chunk, X_chunk, last_hidden)
    last_hidden = hs_chunk[:, -1]

    hs_chunk_list.append(hs_chunk)

# Stitch the per-chunk hidden states back together along the sequence axis.
hs_chunks = torch.cat(hs_chunk_list, dim=1)

J_chunks = hs_chunks.sum()
J_chunks.backward()

number of chunks: 4 and remainder: 4


In [7]:
# checks: the chunked run must reproduce the full run's hidden states and the
# gradients w.r.t. both inputs. Each result is printed (as before) and also
# asserted, so a mismatch fails "Run All" instead of silently printing False.
hs_match = torch.allclose(hs, hs_chunks, atol=0.00001)
A_grad_match = torch.allclose(A.grad, A_chunks.grad, rtol=0.01)
X_grad_match = torch.allclose(X.grad, X_chunks.grad, rtol=0.01)

print(hs_match)
print(A_grad_match)
print(X_grad_match)

assert hs_match, "chunked hidden states differ from the full pscan"
assert A_grad_match, "chunked A.grad differs from the full pscan"
assert X_grad_match, "chunked X.grad differs from the full pscan"

True
True
True
