Make wrapIndexOnce check async, avoid DtoH sync on index_put_ #125952

Closed
wants to merge 2 commits

Conversation

@ezyang ezyang (Contributor, Author) commented May 10, 2024

[ghstack-poisoned]

pytorch-bot bot commented May 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125952

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures

As of commit bb0289b with merge base 96a5698:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ezyang added a commit that referenced this pull request May 10, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 20ad93d07612a3326ecaebaa3a7b2a5b45c15f54
Pull Request resolved: #125952
@ezyang ezyang requested review from lezcano, ngimel and eqy May 10, 2024 20:06
@ezyang ezyang added the release notes: python_frontend and topic: bug fixes labels May 10, 2024
@lezcano lezcano (Collaborator) left a comment

We already do that within canDispatchToMaskedFill:

static std::tuple<bool, Tensor> canDispatchToMaskedFill(const Tensor& self, const torch::List<c10::optional<at::Tensor>>& indices,
    const Tensor& value) {
  if (!(value.numel() == 1 && value.device().is_cpu())) {
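For illustration, here is a minimal Python sketch (mine, not from the PR) of the two paths that check distinguishes; it assumes a CUDA device is available:

import torch

x = torch.zeros(8, device="cuda")
mask = torch.tensor([True, False] * 4, device="cuda")

# A single boolean mask plus a one-element CPU value can take the masked_fill
# fast path, which never needs to read device data back on the host.
x[mask] = 1.0

# With a CUDA value tensor the check quoted above rejects the fast path, and
# the general index_put_ path runs instead.
x.index_put_((mask,), torch.tensor(2.0, device="cuda"))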

@ezyang ezyang (Contributor, Author) commented May 10, 2024

OK, that's useful, because the internal user tested this and also said it did not fix it.

@ezyang ezyang (Contributor, Author) commented May 10, 2024

I'm out of time right now, but here is the repro:

import logging
from dataclasses import dataclass
from datetime import datetime

import torch
import torch._inductor.config as inductor_config


logger = logging.getLogger(__name__)
TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"


@dataclass
class BenchmarkConfig:
    batch_size: int = 256
    enable_bf16: bool = True
    enable_pt2: bool = True
    device = "cuda:0"
    d_in = 2048


class SimpleModel(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, dim, bias=False)
        self.ts_encoding_params_dict = torch.nn.Parameter(
            torch.empty(
                [
                    2000,
                    dim,
                ]
            ).uniform_(-0.01, 0.01)
        )
        self.linear2 = torch.nn.Linear(dim, dim, bias=False)

    def forward(
        self,
        x,
        num_object,
        user_event_ts_buckets,
    ):
        emb = self.linear1(x)
        # user_event_ts_encoding = self.ts_encoding_params_dict[user_event_ts_buckets, :]
        user_event_ts_encoding = self.ts_encoding_params_dict.index_select(
            0, user_event_ts_buckets
        )
        emb = emb + user_event_ts_encoding
        res = self.linear2(emb)
        return res


def create_model_input(benchmark_config: BenchmarkConfig):
    batch_size = benchmark_config.batch_size
    d_in = benchmark_config.d_in
    device = benchmark_config.device

    dtype = torch.bfloat16 if benchmark_config.enable_bf16 else torch.float32

    x = torch.rand(
        batch_size * 1000, d_in, dtype=dtype, device=device
    ).requires_grad_()  # assuming seq_len_per_example is max_length // 2
    num_object = torch.tensor(
        [1000] * batch_size,
        dtype=torch.int,
        device=device,
    )

    user_event_ts_buckets = torch.randint(
        0,
        2000,
        (1000 * batch_size,),
        dtype=torch.int,
        device=device,
    )

    return (
        x,
        num_object,
        user_event_ts_buckets,
    )


def run_first_model_once(model, input):
    pred = model(*input)
    pred[0].sum().backward()


def single_run_benchmark():
    benchmark_config = BenchmarkConfig()
    model_input = create_model_input(benchmark_config)
    model = SimpleModel(benchmark_config.d_in)

    if benchmark_config.enable_bf16:
        model = model.to(dtype=torch.bfloat16)

    if benchmark_config.enable_pt2:
        inductor_config.decompose_mem_bound_mm = True
        inductor_config.trace.enabled = True
        model = torch.compile(model)
        model = model.to(benchmark_config.device)
        print("Start compiling model.")
        run_first_model_once(model, model_input)
    else:
        model = model.to(benchmark_config.device)

    # trace
    with torch.profiler.profile(with_flops=True) as profiler:
        for _ in range(5):
            run_first_model_once(model, model_input)

    trace_file_prefix = "{}".format(
        datetime.now().strftime(TIME_FORMAT_STR),
    )

    return


def main() -> None:
    single_run_benchmark()
    print("done")


if __name__ == "__main__":
    main()  # pragma: no cover
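As a rough sketch of how the sync would show up (my addition, not part of the original repro; it reuses the model, model_input, and run_first_model_once defined by the script above), one can profile a single iteration and look for blocking device-to-host traffic:

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
) as prof:
    run_first_model_once(model, model_input)

# A DtoH sync typically shows up as a Memcpy DtoH entry paired with a
# cudaStreamSynchronize / cudaMemcpy call in the profile.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=20))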

lezcano added a commit that referenced this pull request May 10, 2024
The previous fix was not general enough.

Fixes #125952

[ghstack-poisoned]
lezcano added a commit that referenced this pull request May 10, 2024
The previous fix was not general enough.

Fixes #125952

ghstack-source-id: d1a956c9c514d10ffa99b4589ea8db0c5b74b46d
Pull Request resolved: #125973
@lezcano lezcano (Collaborator) commented May 10, 2024

I put up a fix, but I was not able to test whether it works (my triton version is acting up with the repro). Mind checking if it fixes the issue?

@ezyang ezyang closed this May 11, 2024
@ezyang ezyang reopened this May 11, 2024
@ezyang ezyang (Contributor, Author) commented May 11, 2024

OK, I got a better stacktrace here:

#12 at::_ops::item::call(at::Tensor const&) from ??:0
#13 long at::Tensor::item<long>() const from ??:0
#14 at::native::computeLinearIndex(at::Tensor const&, c10::ArrayRef<at::Tensor>, bool) [clone .isra.0] from tmpxft_00061ce6_00000000-6_Indexing.cudafe1.cpp:0
#15 at::native::makeLinearIndex(at::Tensor, c10::IListRef<at::OptionalTensorRef>, bool) [clone .constprop.0] from tmpxft_00061ce6_00000000-6_Indexing.cudafe1.cpp:0
#16 at::native::(anonymous namespace)::index_put_with_sort_kernel(at::Tensor&, c10::List<std::optional<at::Tensor> > const&, at::Tensor const&, bool, bool) from ??:0
#17 at::native::_index_put_impl_(at::Tensor&, c10::List<std::optional<at::Tensor> > const&, at::Tensor const&, bool, bool) from ??:0

@ezyang ezyang (Contributor, Author) commented May 11, 2024

It's this one:

static Tensor wrapIndexOnce(const Tensor & index, int64_t dim, int64_t dim_size, bool check_range=true) {
  // we don't need to check range in backward - if there were out of bounds indices forward should already have errored out
  if (index.numel() != 0 && check_range) {
    TORCH_INTERNAL_ASSERT(0);
    auto max_idx = index.max().item<int64_t>();
    auto min_idx = index.min().item<int64_t>();
    if (max_idx >= dim_size) {
      TORCH_CHECK_INDEX(false, "index ", max_idx, " is out of bounds for dimension ", dim, " with size ", dim_size);
    }
    if (min_idx < -dim_size) {
      TORCH_CHECK_INDEX(false, "index ", min_idx, " is out of bounds for dimension ", dim, " with size ", dim_size);
    }
  }
  return index.remainder(dim_size);
}

and the reason it doesn't sync in eager is because, dun dun dun, eager uses a special API to skip the range check.
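For context, here is a small Python sketch (my illustration in terms of the public torch API, not the PR's C++ change) of why the range check syncs and what an async check looks like:

import torch

idx = torch.randint(0, 2000, (1000,), device="cuda")

# index.max().item<int64_t>() in the snippet above reduces on the GPU and then
# copies the scalar back to the host, blocking the CPU until the GPU catches up.
host_max = idx.max().item()

# An asynchronous alternative keeps the check on the device: the comparison
# result stays on the GPU and a kernel-side assert fires only if it is false.
torch._assert_async(idx.max() < 2000)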

[ghstack-poisoned]
ezyang added a commit that referenced this pull request May 11, 2024
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 0fe1f51c815d8d135172d7f6e15bdb22e5339d48
Pull Request resolved: #125952
@ezyang ezyang changed the title from "Do not use masked_fill if it would incur DtoH sync" to "Make wrapIndexOnce check async, avoid DtoH sync on index_put_" May 11, 2024
@ezyang ezyang added the ciflow/trunk label May 11, 2024
@ezyang ezyang (Contributor, Author) commented May 11, 2024

Updated with a fix that is actually confirmed to work.

@ezyang ezyang requested a review from lezcano May 12, 2024 12:47
@lezcano lezcano (Collaborator) left a comment

Sounds good if it fixes the problem. Could you add a regression test?

Comment on lines +338 to +339
at::_assert_async(index.max() < dim_size);
at::_assert_async(index.min() >= -dim_size);
Collaborator commented:

I didn't know we had these working in the end. How does compile interpret these?

Contributor Author commented:

This is the CUDA kernel, so compile doesn't really interact with it in a nontrivial way; you just end up hitting the assert at runtime.

@ezyang ezyang (Contributor, Author) commented May 13, 2024

I don't really see how to add a regression test. To check whether we're doing a DtoH sync, we need some way of detecting that such a sync has happened, but there's no facility for programmatically determining this.

@ezyang ezyang (Contributor, Author) commented May 13, 2024

@pytorchbot merge -i

@lezcano lezcano (Collaborator) commented May 13, 2024

I figured one way would be to try to CUDA-graph the relevant code and check that the capture succeeds.
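A rough sketch of that idea (my own illustration, not a test from this PR; the shapes and the accumulate flag are arbitrary). CUDA graph capture errors out if the captured region synchronizes with the host, so a successful capture implies index_put_ no longer triggers a DtoH sync:

import torch

dst = torch.zeros(100, device="cuda")
idx = torch.randint(0, 100, (32,), device="cuda")
src = torch.ones(32, device="cuda")

# Warm up on a side stream, as required before CUDA graph capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    dst.index_put_((idx,), src, accumulate=True)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Capture raises an error if index_put_ performs a device-to-host sync.
    dst.index_put_((idx,), src, accumulate=True)
g.replay()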

@pytorchmergebot pytorchmergebot (Collaborator)

Merge started

Your change will be merged while ignoring the following 3 checks: Lint / Test collect_env (with_torch), Lint / Test collect_env (without_torch), Lint / Test collect_env (older_python_version)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

tinglvv pushed a commit to tinglvv/pytorch that referenced this pull request May 14, 2024
Labels
ci-td-distributed, ciflow/trunk, Merged, release notes: python_frontend, topic: bug fixes