
fix exclusive cumsum calculation#109

Merged
simonguozirui merged 2 commits into main from fix-exclusive-cumsum
Dec 24, 2025

Conversation

@bkal01
Copy link
Collaborator

@bkal01 bkal01 commented Dec 22, 2025

as @AKatydid pointed out on #72, the current exclusive cumsum computation is incorrect.

Now we apply torch.cumsum to x (without the last element along the dimension) and shift the result right by one by prepending zeros.

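The approach described above can be sketched as follows. This is an illustrative reconstruction based on the PR description, not necessarily the repository's exact code; the function name `exclusive_cumsum` is assumed.

```python
import torch

def exclusive_cumsum(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Cumsum everything except the last element along `dim`, then
    # prepend zeros so position i holds the sum of elements strictly
    # before i (the exclusive prefix sum).
    zeros = torch.zeros_like(x.narrow(dim, 0, 1))
    body = torch.cumsum(x.narrow(dim, 0, x.size(dim) - 1), dim=dim)
    return torch.cat([zeros, body], dim=dim)
```

For example, `exclusive_cumsum(torch.tensor([1., 2., 3., 4.]), 0)` yields `[0., 1., 3., 6.]`: each entry is the sum of all earlier entries.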
@simonguozirui
Copy link
Collaborator

Double-checked the math and wrote a short script to verify it against a roll-based PyTorch reference implementation (not efficient, but easy to check). The results are equivalent on both CPU and GPU.

import torch

def exclusive_cumsum_ref(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Inclusive cumsum, then rotate one step to the right along `dim`.
    c = torch.cumsum(x, dim=dim)
    y = torch.roll(c, shifts=1, dims=dim)

    # roll wraps the final total around to the first slice; zero it out
    # so each position sums only the elements strictly before it.
    idx = [slice(None)] * x.ndim
    idx[dim] = 0
    y[tuple(idx)] = 0
    return y
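A quick self-check (illustrative, not from the repository) that the prepend-zeros formulation from the PR description agrees with the roll-based reference. Both functions are restated here so the snippet runs standalone:

```python
import torch

def exclusive_cumsum_ref(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Roll-based reference: inclusive cumsum shifted right, first slice zeroed.
    c = torch.cumsum(x, dim=dim)
    y = torch.roll(c, shifts=1, dims=dim)
    idx = [slice(None)] * x.ndim
    idx[dim] = 0
    y[tuple(idx)] = 0
    return y

def exclusive_cumsum_fixed(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Prepend-zeros formulation from the PR description (name is assumed).
    zeros = torch.zeros_like(x.narrow(dim, 0, 1))
    body = torch.cumsum(x.narrow(dim, 0, x.size(dim) - 1), dim=dim)
    return torch.cat([zeros, body], dim=dim)

torch.manual_seed(0)
x = torch.randn(3, 4, 5)
for dim in range(x.ndim):
    assert torch.allclose(exclusive_cumsum_ref(x, dim),
                          exclusive_cumsum_fixed(x, dim))
```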

Thanks @bkal01 for the fix, and @AKatydid for pointing it out on #72. Also started a changelog to document ongoing problem updates!

@simonguozirui simonguozirui merged commit 6bab08b into main Dec 24, 2025
ethanboneh pushed a commit that referenced this pull request Jan 6, 2026
* fix exclusive cumsum calculation

as @AKatydid pointed out on #72, the current exclusive cumsum computation is incorrect.

Now we apply torch.cumsum to x (without the last element along the dimension) and shift the result right by one by prepending zeros.

* check and add changelog

---------

Co-authored-by: Simon Guo <simonguo@stanford.edu>
julian-reed pushed a commit that referenced this pull request Mar 23, 2026
* fix exclusive cumsum calculation

as @AKatydid pointed out on #72, the current exclusive cumsum computation is incorrect.

Now we apply torch.cumsum to x (without the last element along the dimension) and shift the result right by one by prepending zeros.

* check and add changelog

---------

Co-authored-by: Simon Guo <simonguo@stanford.edu>