
Conversation


@Priya2698 Priya2698 commented Mar 20, 2025

This PR decouples upcoming changes to the multi-device-related presegmentation passes.

Since the presegmentation passes will be updated serially, this PR works around test failures in the interim. For example, some tests manually set loop domains as required for communication. The propagateShardings preseg pass updates the loop domain of any TensorView that does not have a device mesh, and because the reorderShardedAxis preseg pass is not yet fully functional for DID loop split, a partially updated pass can produce errors.

This PR eases rolling out the changes incrementally.

Eventually, once the updates to all passes are complete, we should not require setting meshes explicitly.
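
To make the interim workaround concrete, here is a minimal sketch in the style of the C++ multidevice tests. The fusion and the tensor names (tv0, tv1) are invented for illustration; only setDeviceMesh, DeviceMesh, and ParallelType::DIDx mirror APIs referenced elsewhere in this PR:

    // Hypothetical test excerpt (not from this PR's diff): give every
    // TensorView a mesh so propagateShardings finds no mesh-less TVs whose
    // manually set loop domains it would rewrite.
    Fusion fusion;
    FusionGuard fg(&fusion);

    TensorView* tv0 = makeContigTensor(2);
    fusion.addInput(tv0);
    TensorView* tv1 = set(tv0); // a copy op standing in for a communication
    fusion.addOutput(tv1);

    const DeviceMesh mesh({0, 1, 2, 3});
    tv0->setDeviceMesh(mesh);
    tv1->setDeviceMesh(mesh); // interim workaround: mesh set explicitly

    // Shard the outermost axis across devices as usual.
    tv0->axis(0)->parallelize(ParallelType::DIDx);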

SetSelector selector(
    std::unordered_set<TensorView*>(output_tvs.begin(), output_tvs.end()));
MaxLogicalDomainInfoSpanningTree(input_tv, &selector).traverse(&propagator);
scheduler_utils::parallelizeAllLike(
@Priya2698 Priya2698 (Collaborator, Author) commented on this diff:

This does not set the device mesh.


github-actions bot commented Mar 20, 2025

Review updated until commit 3a7baef

Description

  • Added missing device meshes in tests

  • Updated test_multidevice_sharding.cpp to include tv1 in device mesh setup

  • Replaced scheduler_utils::parallelizeAllLike with shardAllLike in propagateShardings

  • Enhanced test_multidevice.py to include self.out in device mesh setup


Changes walkthrough 📝

Relevant files

Tests

test_multidevice_sharding.cpp: Update device mesh and parallelization in tests
(tests/cpp/test_multidevice_sharding.cpp, +2/-4)

  • Added tv1->setDeviceMesh(mesh);
  • Replaced scheduler_utils::parallelizeAllLike with shardAllLike

test_multidevice.py: Enhance multidevice tests with additional tensor setups
(tests/python/test_multidevice.py, +6/-5)

  • Included self.out in device mesh setup
  • Added self.seed and self.offset to device mesh setup
  • Corrected variable assignments in sdpfa_fwd and sdpfa_bwd

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Code Duplication

The function shardAllLike is called in place of scheduler_utils::parallelizeAllLike. Ensure that this change does not introduce unintended behavior or performance issues.

shardAllLike(input_tv, output_tvs);
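
For reviewers weighing that focus area, the sketch below captures the assumed difference between the two helpers. This is a paraphrase rather than nvFuser's actual implementation; the key point, consistent with the inline comment earlier in this thread, is that shardAllLike also assigns the reference's device mesh before propagating device-parallel types:

    // Hedged sketch of the assumed behavior; the name shardAllLikeSketch is
    // hypothetical and does not appear in the codebase.
    void shardAllLikeSketch(TensorView* ref, std::vector<TensorView*> tvs) {
      // The step parallelizeAllLike alone omits: copy the reference's mesh.
      for (TensorView* tv : tvs) {
        tv->setDeviceMesh(ref->getDeviceMesh());
      }
      // Then propagate only device-parallel (DID) types from the reference.
      scheduler_utils::parallelizeAllLike(ref, tvs, {ParallelType::DIDx});
    }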
Variable Naming

The variables seed and offset are renamed to self.seed and self.offset in the definition method. Ensure that these changes do not affect the functionality and that all references to these variables are updated accordingly.

self.attn, self.log_sumexp, self.seed, self.offset = self.ops.sdpfa_fwd(
    self.q, self.k, self.v, dropout_p, is_causal, scale=None
)

self.q_grad, self.k_grad, self.v_grad = self.ops.sdpfa_bwd(

Missing Non-Sharded TVs

The list non_sharded_tvs is introduced in the multidevice_schedule method. Verify that this list includes all non-sharded tensor views and that their inclusion is necessary for the correct execution of the test.

non_sharded_tvs = [self.seed, self.offset]

for t in input_tvs + output_tvs + non_sharded_tvs:
    self.sched._set_device_mesh(t, mesh)

@Priya2698 Priya2698 requested a review from wujingyue March 20, 2025 21:34

@Priya2698 Priya2698 (Collaborator, Author) commented:

!test

@wujingyue wujingyue (Collaborator) left a comment:

I no longer understand the reason behind this PR.

shardAllLike(ref_input, outputs_without_mesh);
and
tv->setDeviceMesh(output_with_mesh->getDeviceMesh());
didn't kick in for some reason?

@Priya2698 Priya2698 (Collaborator, Author) commented Apr 4, 2025

@wujingyue I prefer merging this to decouple updates to the presegmentation passes. Wdyt?

@wujingyue wujingyue (Collaborator) commented:

Sure. I think we talked about this PR offline. Can you describe the motivation in the PR description before I approve? IIUC, this tries to work around upcoming problems in sharding propagation, and the added meshes won't be necessary eventually.

@Priya2698 Priya2698 (Collaborator, Author) commented:

!test

@wujingyue wujingyue (Collaborator) commented:

> Since, the presegmentation passes will be updated serially, this to workaround around those issues in the interim.

Can you fix the grammar here?

@Priya2698 Priya2698 merged commit 7affa36 into main Apr 4, 2025
51 checks passed
@Priya2698 Priya2698 deleted the pm/fix_tests branch April 4, 2025 18:57