@jerrymannil jerrymannil commented Aug 13, 2025

  • Use input vectorization for reduction_on_fastest_striding_dimension when dim0 >= 128
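For readers unfamiliar with the term, here is a minimal pure-Python sketch of the idea behind input vectorization when reducing along the contiguous (fastest-striding) dimension. It is a conceptual model only, not the kernel code touched by this PR: the function names and the `vec_width=4` default are assumptions made for illustration. Each "thread" reads a fixed-width chunk of contiguous elements in one step, keeps one partial accumulator per lane, and combines the lanes at the end.

```python
def row_sum_scalar(row):
    """Reference reduction: one element per step."""
    acc = 0.0
    for value in row:
        acc += value
    return acc


def row_sum_vectorized(row, vec_width=4):
    """Model of a vectorized reduction over one contiguous row.

    vec_width is a hypothetical vector width chosen for illustration.
    """
    # One partial accumulator per vector lane.
    lanes = [0.0] * vec_width
    # Largest prefix of the row that is a whole number of vectors.
    n_vec = (len(row) // vec_width) * vec_width
    for base in range(0, n_vec, vec_width):
        chunk = row[base:base + vec_width]  # models a single wide load
        for lane in range(vec_width):
            lanes[lane] += chunk[lane]
    # Combine the lane accumulators, then handle any leftover tail elements.
    total = sum(lanes)
    for value in row[n_vec:]:
        total += value
    return total
```

Note that the lane-wise accumulation changes the order of floating-point additions, so for general float inputs the result can differ from the scalar loop by rounding error.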

Reproducer:

import time
import torch

# (rows, cols) shapes to benchmark
shapes = [
    (5079670, 128)
]

# Reduction dimension for each shape (1 = the contiguous, fastest-striding dim)
dims = [
    1
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    # Warm-up iterations
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    # Timed iterations; synchronize so all queued kernels are included
    start_time = time.perf_counter()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    mean_time = (end_time - start_time) / 50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")

Before (MI300X):
Avg time for shape (5079670, 128): 1629.99 us

After (MI300X):
Avg time for shape (5079670, 128): 1008.59 us
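
To put the two timings in context, the arithmetic below derives the speedup and a rough effective read bandwidth from the numbers above. This is a back-of-the-envelope sketch: it assumes 4 bytes per float32 element and counts only the input reads (the output write is under 1% of the traffic for this shape). The helper name is hypothetical.

```python
ROWS, COLS = 5079670, 128   # shape from the reproducer
BYTES_PER_ELEM = 4          # float32
BEFORE_US = 1629.99         # measured average before the change
AFTER_US = 1008.59          # measured average after the change


def effective_bandwidth_tbps(time_us):
    """Input bytes read divided by elapsed time, in TB/s (1 TB = 1e12 bytes)."""
    bytes_read = ROWS * COLS * BYTES_PER_ELEM
    return bytes_read / (time_us * 1e-6) / 1e12


speedup = BEFORE_US / AFTER_US
print(f"Speedup: {speedup:.2f}x")
print(f"Before: {effective_bandwidth_tbps(BEFORE_US):.2f} TB/s")
print(f"After:  {effective_bandwidth_tbps(AFTER_US):.2f} TB/s")
```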

Cherry-pick of pytorch#160466

Fixes SWDEV-546136

@jerrymannil jerrymannil self-assigned this Aug 13, 2025

rocm-repo-management-api bot commented Aug 13, 2025

Jenkins build for e627009c65ceca991c6d645b6e0b9b4450d32209 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@jerrymannil (Collaborator, Author) commented:

Already merged in #2505.

@jerrymannil jerrymannil deleted the jerrymannil-patch-1 branch August 22, 2025 16:23