
torch.fx cannot find module <built-in method matmul> when using apex.amp #1359

Closed
juncgu opened this issue Apr 21, 2022 · 3 comments
Labels: bug (Something isn't working)


juncgu commented Apr 21, 2022

Describe the Bug
I ran into the following issue when using torch.fx and apex.amp together.
If a model contains torch.matmul operations and uses apex.amp (opt_level=O1) for mixed-precision optimization, fx fails to recompile the model graph and reports the following error:

File "/xxxxxxxxxxxxx/dist-packages/torch/fx/node.py", line 40, in _find_module_of_method
    raise RuntimeError(f'cannot find module for {orig_method}, {name}')
RuntimeError: cannot find module for <built-in method matmul of type object at 0x7f6c5fd9fe40>
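
A plausible explanation for the failure (an assumption on my part; the thread does not confirm it): amp's O1 patching rebinds torch.matmul to a Python wrapper that casts arguments and then calls the original built-in. fx records the original built-in as the call target, but its fallback lookup for built-ins (_find_module_of_method in torch/fx/node.py) tries to re-locate that exact object under namespaces such as torch; after patching, torch.matmul is no longer that object, so the lookup raises. A minimal illustration of the identity check that fails (the patch below is illustrative, not apex's actual code):

import torch

orig_matmul = torch.matmul  # the built-in that fx records as the call target

# What an amp-style patch effectively does: rebind the name to a wrapper
torch.matmul = lambda a, b: orig_matmul(a.half(), b.half())

# fx's fallback lookup checks object identity, which now fails
assert getattr(torch, "matmul", None) is not orig_matmul

torch.matmul = orig_matmul  # restore the original binding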

Minimal Steps/Code to Reproduce the Bug

The following toy module, which uses torch.matmul, reproduces the error.

import torch
from torch import fx

class ToyMod(torch.nn.Module):
    def __init__(self, in_features=768, out_features=768):
        super().__init__()
        self.linear = torch.nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, X, other):
        # torch.matmul is the call that fx later fails to resolve
        result = torch.matmul(self.linear(X), other)
        return result

def test():
    from apex import amp

    mod = ToyMod().cuda()
    optimizer = torch.optim.SGD(mod.parameters(), lr=1e-3)

    # Allow amp to perform casts as required by the opt_level
    model, optimizer = amp.initialize(mod, optimizer, opt_level="O1")

    # Fails with: RuntimeError: cannot find module for
    # <built-in method matmul of type object at 0x...>
    graph: fx.Graph = fx.Tracer().trace(model)
    for node in graph.nodes:
        # inspect or transform nodes here
        pass
    graph.lint()
    return fx.GraphModule(model, graph)

if __name__ == "__main__":
    test()

Expected Behavior

torch.fx should trace the amp-initialized model and build the GraphModule without raising an error.
Environment

PyTorch version: 1.10.0
CUDA used to build PyTorch: 11.3

Python version: 3.7.3
Is CUDA available: True
CUDA runtime version: 11.3.109

[pip3] numpy==1.21.5
[pip3] torch==1.10.0
[pip3] torchaudio==0.10.0+cu113
[pip3] torchvision==0.11.1+cu113

apex: Version 0.1

crcrpar (Collaborator) commented Apr 23, 2022

Does the script work if you use torch.cuda.amp?

juncgu (Author) commented Apr 23, 2022

> Does the script work if you use torch.cuda.amp?

@crcrpar
Yes. The script works with torch.cuda.amp.
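
For reference, a minimal sketch of the working combination under native amp (assuming the ToyMod from the report above; the shapes and variable names here are illustrative): torch.cuda.amp casts at runtime inside an autocast region instead of rebinding torch functions, so fx traces the unpatched torch.matmul.

import torch
from torch import fx

mod = ToyMod().cuda()
optimizer = torch.optim.SGD(mod.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

# Tracing succeeds: torch.matmul is still the original built-in
graph: fx.Graph = fx.Tracer().trace(mod)
graph.lint()
gm = fx.GraphModule(mod, graph)

X = torch.randn(4, 768, device="cuda")
other = torch.randn(768, 768, device="cuda")

optimizer.zero_grad()
with torch.cuda.amp.autocast():  # casts happen at dispatch time, not by patching
    loss = gm(X, other).sum()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()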

ptrblck (Contributor) commented Apr 25, 2022

Thanks for verifying @juncgu!
apex.amp is deprecated in favor of the native implementation via torch.cuda.amp (see #818).

Closing

ptrblck closed this as completed Apr 25, 2022