Skip to content

Conversation

@jerrymannil
Copy link
Collaborator

@jerrymannil jerrymannil commented Aug 12, 2025

  • Use input vectorization for reduction_on_fastest_striding_dimension when dim0 >= 128

Reproducer:

import time
import torch

shapes = [
    (5079670, 128)
]

dims = [
    (1)
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    start_time = time.time()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()
    mean_time = (end_time - start_time)/50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")

Before (MI300X):
Avg time for shape (5079670, 128): 1629.99 us

After (MI300X)
Avg time for shape (5079670, 128): 1008.59 us

cherry-pick of pytorch#160466

Fixes SWDEV-546136

Cherry-picked to release/2.8 branch via #2505

* Use input vectorization for reduction_on_fastest_striding_dimension when dim0 >= 0

**Reproducer:**
```
import time
import torch

shapes = [
    (5079670, 128)
]

dims = [
    (1)
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    start_time = time.time()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()
    mean_time = (end_time - start_time)/50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")
```

**Before (MI300X):**
Avg time for shape (5079670, 128): 1629.99 us

**After (MI300X)**
Avg time for shape (5079670, 128): 1008.59 us

cherry-pick of pytorch#160466
@jerrymannil jerrymannil self-assigned this Aug 12, 2025
@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Aug 12, 2025

Jenkins build for 18528f45f38a0cb0eab9c869f1c367156a1d7122 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@pruthvistony
Copy link
Collaborator

Please cherry-pick into all required branches.

@pruthvistony pruthvistony merged commit d93829c into release/2.7 Aug 13, 2025
0 of 2 checks passed
@pruthvistony pruthvistony deleted the jerrymannil-patch-1 branch August 13, 2025 02:16
@jerrymannil
Copy link
Collaborator Author

! cherry-pick --onto release/2.8

1 similar comment
@jithunnair-amd
Copy link
Collaborator

! cherry-pick --onto release/2.8

dhonnappa-amd pushed a commit that referenced this pull request Aug 13, 2025
* Use input vectorization for reduction_on_fastest_striding_dimension
when dim0 >= 0

**Reproducer:**
```
import time
import torch

shapes = [
    (5079670, 128)
]

dims = [
    (1)
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    start_time = time.time()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()
    mean_time = (end_time - start_time)/50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")
```

**Before (MI300X):**
Avg time for shape (5079670, 128): 1629.99 us

**After (MI300X)**
Avg time for shape (5079670, 128): 1008.59 us

cherry-pick of pytorch#160466

Fixes SWDEV-546136
@dhonnappa-amd
Copy link

Created branch autogenerated/release/2.8_cherry-pick_pr-2492 and #2505

@pruthvistony
Copy link
Collaborator

! cherry-pick --onto rocm7.1_internal_testing

pruthvistony pushed a commit that referenced this pull request Aug 15, 2025
#2505)

Cherry-pick of #2492

Co-authored-by: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com>
tvukovic-amd pushed a commit that referenced this pull request Aug 20, 2025
#2505)

Cherry-pick of #2492

Co-authored-by: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants