@dhonnappa-amd

Cherry-pick of #2492

* Use input vectorization for reduction_on_fastest_striding_dimension
when dim0 >= 0

**Reproducer:**
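```python
import time
import torch

# Input shapes to benchmark
shapes = [
    (5079670, 128),
]

# Reduction dimension for each shape (same index as `shapes`)
dims = [
    1,
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    # Warm-up iterations so one-time setup cost is not timed
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    # Timed iterations; kernel launches are asynchronous, so
    # synchronize before reading the end time
    start_time = time.time()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()
    mean_time = (end_time - start_time)/50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")
```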
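
For context, the reproducer sums over dim 1 of a freshly allocated tensor, which is the dimension with the smallest stride. The sketch below is a pure-Python illustration of that layout, assuming a row-major contiguous tensor as `torch.randn` returns; `contiguous_strides` is a hypothetical helper written for this note, not a PyTorch internal.

```python
def contiguous_strides(shape):
    """Element strides of a row-major (C-contiguous) tensor with this shape."""
    strides = []
    step = 1
    # Walk from the innermost dimension outwards; each stride is the
    # product of the sizes of all dimensions to its right.
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


shape = (5079670, 128)   # shape used in the reproducer
reduce_dim = 1           # dimension passed to torch.sum in the reproducer

strides = contiguous_strides(shape)
# The fastest-striding dimension is the one with the smallest stride.
fastest_dim = min(range(len(shape)), key=lambda d: strides[d])

print("strides:", strides)
print("reduction is on the fastest-striding dimension:", reduce_dim == fastest_dim)
```

Each output element is therefore the sum of 128 adjacent floats in memory, which is presumably the access pattern that vectorized input loads target.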

**Before (MI300X):**
Avg time for shape (5079670, 128): 1629.99 us

**After (MI300X):**
Avg time for shape (5079670, 128): 1008.59 us
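
To put the two timings in perspective, the sketch below derives the speedup and an approximate input read rate from the numbers above. It assumes the kernel reads every float32 input element exactly once and ignores the much smaller output write, so the rates are rough estimates rather than measured bandwidth; `input_bytes` and `effective_read_rate` are hypothetical helpers.

```python
BYTES_PER_FLOAT32 = 4


def input_bytes(shape, itemsize=BYTES_PER_FLOAT32):
    """Total bytes occupied by a dense tensor of this shape."""
    count = 1
    for size in shape:
        count *= size
    return count * itemsize


def effective_read_rate(shape, time_us):
    """Approximate bytes per second, assuming the input is read exactly once."""
    return input_bytes(shape) / (time_us * 1e-6)


shape = (5079670, 128)
before_us = 1629.99   # "Before (MI300X)" average from this PR
after_us = 1008.59    # "After (MI300X)" average from this PR

speedup = before_us / after_us
time_saved = 1.0 - after_us / before_us

print(f"speedup: {speedup:.2f}x ({time_saved:.1%} less time per call)")
print(f"approx. read rate before: {effective_read_rate(shape, before_us) / 1e12:.2f} TB/s")
print(f"approx. read rate after:  {effective_read_rate(shape, after_us) / 1e12:.2f} TB/s")
```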

Cherry-pick of pytorch#160466

Fixes SWDEV-546136
@rocm-repo-management-api

rocm-repo-management-api bot commented Aug 13, 2025

Jenkins build for b31274f527974938f03d4f4b5e9375c3345154e1 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@pruthvistony pruthvistony marked this pull request as ready for review August 15, 2025 15:27
@pruthvistony pruthvistony merged commit 0def0b8 into release/2.8 Aug 15, 2025
0 of 2 checks passed
@pruthvistony pruthvistony deleted the autogenerated/release/2.8_cherry-pick_pr-2492 branch August 15, 2025 15:27
tvukovic-amd pushed a commit that referenced this pull request Aug 20, 2025
#2505)

Cherry-pick of #2492

Co-authored-by: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com>