In [None]:
import torch
import selective_scan_cuda

def run_forward(h0, u, delta, A, Bmat, C, Dvec=None, z=None, delta_bias=None, delta_softplus=False):
    """Call the ``selective_scan_cuda.fwd`` binding and hand back its result untouched.

    Every argument is forwarded positionally, in this order:
    ``h0, u, delta, A, Bmat, C, Dvec, z, delta_bias, delta_softplus``.

    Returns:
        Whatever ``selective_scan_cuda.fwd`` returns. According to the note
        that accompanied the original wrapper this is
        ``[out, x, (out_z if has_z), last_state]``.
    """
    kernel_args = (h0, u, delta, A, Bmat, C, Dvec, z, delta_bias, delta_softplus)
    return selective_scan_cuda.fwd(*kernel_args)

def run_backward(h0, u, delta, A, Bmat, C, Dvec,
                 z, delta_bias, dout, d_last_state=None,
                 x=None, out=None, dz=None,
                 delta_softplus=False, recompute_out_z=False):
    """Call the ``selective_scan_cuda.bwd`` binding and hand back its result untouched.

    Every argument is forwarded positionally, in this order:
    ``h0, u, delta, A, Bmat, C, Dvec, z, delta_bias, dout, d_last_state,
    x, out, dz, delta_softplus, recompute_out_z``.

    Returns:
        Whatever ``selective_scan_cuda.bwd`` returns (presumably the list of
        gradients; confirm against the extension's interface).
    """
    kernel_args = (
        h0, u, delta, A, Bmat, C,
        Dvec, z, delta_bias,
        dout, d_last_state, x, out, dz,
        delta_softplus, recompute_out_z,
    )
    return selective_scan_cuda.bwd(*kernel_args)

def normalize_outs(outs):
    """Return a list mirroring ``outs`` with tensors detached and moved to CPU.

    Each ``torch.Tensor`` entry is replaced by ``entry.detach().cpu()``; every
    other entry (``None`` or any non-tensor object) becomes ``None``.
    """
    normalized = []
    for entry in outs:
        if isinstance(entry, torch.Tensor):
            normalized.append(entry.detach().cpu())
        else:
            # Anything that is not a tensor is dropped to None.
            normalized.append(None)
    return normalized

def list_equal(a, b):
    """Return True when ``a`` and ``b`` match element by element.

    The sequences must have the same length. At each position either both
    entries are ``None``, or both are tensors for which ``torch.equal`` is
    True. Any other combination makes the result False.
    """
    if len(a) != len(b):
        return False
    for left, right in zip(a, b):
        if left is None or right is None:
            # Both None -> same object, keep going; only one None -> mismatch.
            if left is not right:
                return False
            continue
        if not torch.equal(left, right):
            return False
    return True

def check_determinism(fn, *inputs, ntrials=5):
    """Run ``fn(*inputs)`` repeatedly and report whether every run matches the first.

    ``fn`` is called ``ntrials`` times. Each result is passed through
    ``normalize_outs`` and compared against the first run with ``list_equal``.
    On the first trial that differs, one diagnostic line is printed per
    differing output and ``False`` is returned. If no trial differs, a success
    line is printed and ``True`` is returned.

    Args:
        fn: Callable returning an iterable of tensors and/or ``None``.
        *inputs: Positional arguments forwarded to ``fn`` on every trial.
        ntrials: Number of times to call ``fn``. With fewer than two trials
            there is nothing to compare and the check passes trivially.

    Returns:
        bool: ``True`` if all trials matched the first one, else ``False``.

    Note:
        Equality is decided by ``torch.equal``; per the PyTorch documentation
        tensors containing NaN never compare equal, so NaN outputs are
        reported as non-deterministic even when they are bit-identical.
    """
    outs = [normalize_outs(fn(*inputs)) for _ in range(ntrials)]
    for i in range(1, ntrials):
        if list_equal(outs[0], outs[i]):
            continue

        ref, cur = outs[0], outs[i]
        if len(ref) != len(cur):
            # zip() below stops at the shorter list, so report this explicitly.
            print(f" Non-deterministic: trial {i}, number of outputs {len(cur)} != {len(ref)}")
        for j, (r, c) in enumerate(zip(ref, cur)):
            if r is None and c is None:
                continue
            if r is None or c is None:
                # torch.equal cannot take None; report instead of raising.
                print(f" Non-deterministic: trial {i}, tensor {j}, output is None in only one of the runs")
                continue
            if torch.equal(r, c):
                continue
            if r.shape != c.shape:
                # A subtraction would raise or silently broadcast here.
                print(f" Non-deterministic: trial {i}, tensor {j}, shape {tuple(c.shape)} != {tuple(r.shape)}")
                continue
            print(f" Non-deterministic: trial {i}, tensor {j}, max abs diff = {(r-c).abs().max().item():.3e}")
        return False
    print(" Deterministic across all trials")
    return True


# Driver script: build random inputs on the GPU, then check that the forward
# and backward selective-scan kernels give identical results on repeated calls.
# No RNG seed is set in the visible code, so the inputs differ from run to run.

# Problem sizes used for the tensor shapes below (presumably batch, channels,
# sequence length and state size -- confirm against the kernel's interface).
B, D, S, dstate = 2, 16, 32, 8
device, dtype = "cuda", torch.float32

# Random inputs; the shapes are exactly the ones spelled out in each call.
h0    = torch.randn(B, D, dstate, device=device, dtype=dtype)
u     = torch.randn(B, D, S,      device=device, dtype=dtype)
delta = torch.randn(B, D, S,      device=device, dtype=dtype)
A     = torch.randn(D, dstate,    device=device, dtype=dtype)
Bmat  = torch.randn(B, D, dstate, S, device=device, dtype=dtype)
C     = torch.randn(B, D, dstate, S, device=device, dtype=dtype)
Dvec  = torch.randn(D, device=device, dtype=dtype)

# Forward: run once to obtain the outputs needed for the backward check, then
# run the repeated-call determinism check. z, delta_bias and delta_softplus
# are left at run_forward's defaults (None, None, False).
fwd_out = run_forward(h0, u, delta, A, Bmat, C, Dvec)
# First, second and last entries of the forward result; the last one is used
# as last_state regardless of how many entries the result has.
out, x, last_state = fwd_out[0], fwd_out[1], fwd_out[-1]
check_determinism(run_forward, h0, u, delta, A, Bmat, C, Dvec)

# Backward: random upstream gradients shaped like the forward outputs.
dout = torch.randn_like(out)
d_last_state = torch.randn_like(last_state)

# Positional arguments follow run_backward's signature:
#   z=None, delta_bias=None, dout, d_last_state,
#   x=None, out=out, dz=None,
#   delta_softplus=False, recompute_out_z=False
# NOTE(review): the `x` returned by the forward pass is not passed here
# (None is given instead) -- confirm this is intended.
check_determinism(run_backward,
                  h0, u, delta, A, Bmat, C, Dvec,
                  None, None, dout, d_last_state,
                  None, out, None,
                  False, False)


 Deterministic across all trials
 Deterministic across all trials


True