Support importing both nvfuser and nvfuser_direct modules #4722

Merged: rdspring1 merged 1 commit into main from import_both_bindings on Jul 10, 2025

rdspring1 (Collaborator) commented Jul 3, 2025

This PR modifies nvfuser and nvfuser_direct extensions to allow both of them to be imported in the same script.

  • Change assertion to warning
  • Add py::module_local() to DataType enum that is common between both extensions.

The DataType is now local to the individual extension rather than the global namespace.
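
The first bullet can be pictured with a small, self-contained sketch. The helper name `warn_if_already_imported` and the stand-in module names `pkg_a` and `pkg_b` are hypothetical; the real check lives in each package's `__init__.py` and is quoted in the reviewer guide further down. The only point here is that `warnings.warn` lets the second import proceed, where an `assert` would abort it.

```python
import sys
import types
import warnings


def warn_if_already_imported(this_name: str, other_name: str) -> None:
    """Warn (instead of asserting) when the sibling package is already loaded."""
    if other_name in sys.modules:
        warnings.warn(
            f"Be careful! You've imported {this_name} when the "
            f"{other_name} module is already imported.",
            UserWarning,
        )


# Simulate "pkg_a" (hypothetical) having been imported earlier in this process.
sys.modules["pkg_a"] = types.ModuleType("pkg_a")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # In a real package this call would sit at the top of pkg_b/__init__.py.
    warn_if_already_imported("pkg_b", "pkg_a")

messages = [str(w.message) for w in caught]
print(messages)

# Remove the stand-in module so the sketch leaves no trace.
del sys.modules["pkg_a"]
```

A script that imports both packages on purpose could presumably silence the message with `warnings.filterwarnings("ignore", message="Be careful!", category=UserWarning)` before the second import; that is a standard-library option, not something this PR prescribes.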

PR Stack:

  • #4722 <<< This PR.
  • #4676
  • #4662

rdspring1 (Collaborator, Author) commented

!test

rdspring1 requested a review from jjsjann123 on July 3, 2025 19:11
rdspring1 added the labels "Python API" (Issues related to the Python API) and "Direct Bindings" (Python extension with direct mapping to NvFuser CPP objects) on Jul 3, 2025
github-actions bot commented Jul 3, 2025

Review updated until commit abafc0a

Description

  • Change import conflict assertions to warnings

  • Make DataType, ParallelType, and CommunicatorBackend enums module-local

  • Update tests to handle warnings instead of assertions


Changes walkthrough 📝

Relevant files

Enhancement

  • enum.cpp (python/python_direct/enum.cpp, +4/-3): Make enums module-local. Added py::module_local() to DataType, ParallelType, and CommunicatorBackend enums.
  • python_bindings.cpp (python/python_frontend/python_bindings.cpp, +4/-3): Make enums module-local. Added py::module_local() to DataType, ParallelType, and CommunicatorBackend enums.
  • __init__.py (python/nvfuser/__init__.py, +6/-3): Change import conflict handling. Changed import conflict assertion to a warning.
  • __init__.py (python/nvfuser_direct/__init__.py, +6/-3): Change import conflict handling. Changed import conflict assertion to a warning.

Tests

  • test_python_frontend.py (tests/python/test_python_frontend.py, +11/-7): Update import conflict test. Updated test to handle warnings instead of assertions.
  • test_import.py (tests/python_direct/test_import.py, +11/-7): Update import conflict test. Updated test to handle warnings instead of assertions.

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Warning Message

    The warning message suggests caution but does not prevent the import. Consider if this is the desired behavior or if an exception should be raised instead.

    if "nvfuser_direct" in sys.modules:
        warnings.warn(
            "Be careful! You've imported nvfuser when the nvfuser_direct module is already imported.",
            UserWarning,
        )
    Warning Message

    The warning message suggests caution but does not prevent the import. Consider if this is the desired behavior or if an exception should be raised instead.

    if "nvfuser" in sys.modules:
        warnings.warn(
            "Be careful! You've imported nvfuser_direct when the nvfuser module is already imported.",
            UserWarning,
        )
    Test Case

    The test case checks for a warning when importing nvfuser_direct after nvfuser. Ensure that the warning message is consistent with the one in nvfuser_direct/__init__.py.

    def test_import_conflict_nvfuser_then_direct(self):
        import warnings
    
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
    
            import nvfuser  # noqa: F401
            import nvfuser_direct  # noqa: F401
    
            assert len(w) == 1
            assert issubclass(w[-1].category, UserWarning)
            assert (
                "Be careful! You've imported nvfuser_direct when the nvfuser module is already imported."
                in str(w[-1].message)
            )
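
One caveat for tests of this shape, offered as a general Python observation rather than something stated in the PR: a module body runs only on the first import in a process, so an import-time warning is emitted only if the module is not yet in `sys.modules`. The sketch below uses a throwaway module (`demo_guard_mod`, a made-up name) written to a temporary directory to show this.

```python
import importlib
import pathlib
import sys
import tempfile
import warnings

with tempfile.TemporaryDirectory() as tmp:
    # A stand-in module whose body warns every time it is executed.
    pathlib.Path(tmp, "demo_guard_mod.py").write_text(
        "import warnings\n"
        "warnings.warn('module body executed', UserWarning)\n"
    )
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()  # make sure the new file is discoverable
    try:
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            importlib.import_module("demo_guard_mod")  # executes the module body
            importlib.import_module("demo_guard_mod")  # served from sys.modules
        # Count only the warning raised by the stand-in module itself.
        warning_count = sum(
            "module body executed" in str(w.message) for w in caught
        )
    finally:
        sys.path.remove(tmp)
        sys.modules.pop("demo_guard_mod", None)

print(warning_count)
```

If a test suite has already imported either package earlier in the same process, a test that expects exactly one warning may see none. Running such a check in a fresh interpreter, for example through `subprocess`, is one way to avoid that dependency on import order.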

    jjsjann123 (Collaborator) left a comment

    LGTM.

    OOC, I remember @wujingyue mentioning some singleton issues when we import both libs. Do we have a repro/issue on that?

    rdspring1 (Collaborator, Author) commented

    > OOC, I remember @wujingyue mentioning some singleton issues when we import both libs. Do we have a repro/issue on that?

    The singleton issue was caused by cleaning up the communicator. I'm not sure what kind of errors might occur with other singletons.

    I'm fine with cherry-picking this PR for the inference demo and keeping the two modules separate.

    jjsjann123 (Collaborator) commented

    > > OOC, I remember @wujingyue mentioning some singleton issues when we import both libs. Do we have a repro/issue on that?
    >
    > The singleton issue was caused by cleaning up the communicator. I'm not sure what kind of errors might occur with other singletons.
    >
    > I'm fine with cherry-picking this PR for the inference demo and keeping the two modules separate.

    I'm not suggesting we hold off on merging. If CI is passing and the remaining issues are ones we can work around, we should merge it as-is. It's always easier to keep things in main.

    wujingyue (Collaborator) commented Jul 8, 2025

    > Do we have a repro/issue on that?

    Yes. It's

    $ cat repro.py
    import nvfuser
    import nvfuser_direct
    nvfuser_direct.multidevice.Communicator.instance()
    
    $ mpirun -np 1 -output-filename /tmp/repro python repro.py
    /opt/pytorch/nvfuser/python/nvfuser_direct/__init__.py:9: UserWarning: Be careful! You've imported nvfuser_direct when the nvfuser module is already imported.
      warnings.warn(
    terminate called after throwing an instance of 'c10d::SocketError'
      what():  The server socket has failed to listen on any local network address. port: 29542, useIpv6: false, code: -98, name: EADDRINUSE, message: address already in use
    Exception raised from makeWithPort at /opt/pytorch/pytorch/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp:308 (most recent call first):
    frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x88 (0x73450fea9568 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
    frame #1: <unknown function> + 0x59a423e (0x73456bd4823e in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
    frame #2: <unknown function> + 0x59b98cf (0x73456bd5d8cf in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
    frame #3: <unknown function> + 0x10eae4e (0x73456748ee4e in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
    frame #4: <unknown function> + 0x59b4ab8 (0x73456bd58ab8 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
    frame #5: <unknown function> + 0x599fb77 (0x73456bd43b77 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
    frame #6: c10d::TCPStore::TCPStore(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10d::TCPStoreOptions const&) + 0x189 (0x73456bd45a59 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so)
    frame #7: <unknown function> + 0x8c109f (0x734405c2f09f in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
    frame #8: <unknown function> + 0x8c1819 (0x734405c2f819 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
    frame #9: nvfuser::python_frontend::cleanup() + 0x58 (0x734405593558 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
    frame #10: python() [0x54c870]
    frame #11: python() [0x5761fe]
    frame #12: python() [0x575f4c]
    frame #13: python() [0x6a90dc]
    frame #14: python() [0x6b1b92]
    <omitting python frames>
    frame #18: <unknown function> + 0x2a1ca (0x7345800451ca in /usr/lib/x86_64-linux-gnu/libc.so.6)
    frame #19: __libc_start_main + 0x8b (0x73458004528b in /usr/lib/x86_64-linux-gnu/libc.so.6)
    
    --------------------------------------------------------------------------
    Primary job  terminated normally, but 1 process returned
    a non-zero exit code. Per user-direction, the job has been aborted.
    --------------------------------------------------------------------------
    --------------------------------------------------------------------------
    mpirun noticed that process rank 0 with PID 0 on node swagtron exited on signal 6 (Aborted).
    --------------------------------------------------------------------------
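
A rough reading of the trace above, and it is an interpretation rather than something confirmed in this thread: frames 6 to 9 show `nvfuser::python_frontend::cleanup()` inside the `nvfuser` extension constructing a `c10d::TCPStore`, which then fails with `EADDRINUSE`. That would be consistent with each extension holding its own copy of the communicator singleton, with both copies trying to listen on the same port. The Python sketch below reproduces only that general pattern with plain sockets; the class names are hypothetical and nothing here uses nvFuser or PyTorch.

```python
import errno
import socket


class ListenerSingleton:
    """Per-class singleton that owns a listening TCP socket."""

    _instance = None

    def __init__(self, port: int) -> None:
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.sock.bind(("127.0.0.1", port))
            self.sock.listen()
        except OSError:
            self.sock.close()
            raise

    @classmethod
    def instance(cls, port: int) -> "ListenerSingleton":
        # Each subclass keeps its own _instance, so "singleton" here means
        # one per class, not one per process.
        if cls._instance is None:
            cls._instance = cls(port)
        return cls._instance


class ExtensionA(ListenerSingleton):  # hypothetical stand-in for one extension
    _instance = None


class ExtensionB(ListenerSingleton):  # hypothetical stand-in for the other
    _instance = None


first = ExtensionA.instance(0)  # port 0 asks the OS for any free port
port = first.sock.getsockname()[1]

try:
    ExtensionB.instance(port)  # independent singleton, same port
    second_errno = None
except OSError as exc:
    second_errno = exc.errno

print(second_errno == errno.EADDRINUSE)
first.sock.close()
```

The second singleton cannot see the first one's state, so it attempts its own bind and the operating system rejects it. If the real failure follows the same shape, sharing a single communicator between the two extensions, or skipping cleanup in the extension that never created one, would be the kind of fix to look for.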
    

    rdspring1 force-pushed the import_both_bindings branch from 7b18ab6 to 2702f9c on July 9, 2025 16:33
    rdspring1 force-pushed the import_both_bindings branch from 2702f9c to abafc0a on July 9, 2025 16:34

    rdspring1 (Collaborator, Author) commented

    !test

    rdspring1 merged commit 9410b2e into main on Jul 10, 2025 (46 of 50 checks passed)
    rdspring1 deleted the import_both_bindings branch on July 10, 2025 01:03
    rdspring1 added a commit that referenced this pull request Jul 10, 2025
    This PR adds `cutlass_nvfp4_scaled_mm` to the `nvfuser_direct` python
    bindings, which support nvfp4 gemm.
    
    PR Stack:
    - #4722
    - #4676 **<<< This PR.**
    - #4662
    wujingyue added a commit that referenced this pull request Jul 25, 2025
    for two enhancements:
    1. #4722
    2. #4837
    wujingyue mentioned this pull request Jul 25, 2025
    wujingyue added a commit that referenced this pull request Jul 25, 2025
    jjsjann123 added a commit that referenced this pull request Aug 20, 2025
    …4662)
    
    The APIs for FP8 and NVFP4 are different in SGLang.
    
    Example:
    ```python
    >>> import nvfuser_direct
    >>> from nvfuser_direct import nvf_cutlass
    >>> help(nvf_cutlass.nvfp4_blockwise_scaled_grouped_mm)
    # nvfp4_blockwise_scaled_grouped_mm(Tensor! output, Tensor a, Tensor b, Tensor a_blockscale,
    # Tensor b_blockscale, Tensor alphas, Tensor ab_strides, Tensor c_strides, Tensor
    # problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()
    ```
    
    PR Stack:
    - #4722
    - #4676
    - #4662  **<<< This PR.**
    
    TODOs: 
    - [ ] Create unit test.
    
    Co-authored-by: jjsjann123 <jiej@nvidia.com>