Cannot use compiled model together with the ddp strategy #18798

Closed
quancs opened this issue Oct 14, 2023 · 2 comments
Labels
bug (Something isn't working) · needs triage (Waiting to be triaged by maintainers) · ver: 2.0.x

Comments

quancs (Member) commented Oct 14, 2023

Bug description

A compiled model cannot be used together with the ddp strategy; the error disappears when the strategy is changed to auto.

backend='compile_fn' raised:
AttributeError: 'MultiheadAttention' object has no attribute 'requires_grad'
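
Condensed from the full repro below (same model and datamodule, only the strategy argument differs); this is just a restatement of the observation, not extra code from the report:

trainer = Trainer(strategy='ddp', devices="0,")   # raises BackendCompilerFailed
trainer = Trainer(strategy='auto', devices="0,")  # trains normally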

What version are you seeing the problem on?

v2.0

How to reproduce the bug

CMD:

python code.py

code (code.py):

from typing import Any, Dict, List, Tuple
import pytorch_lightning as pl
import torch
import torch.nn as nn
from jsonargparse import lazy_instance
from packaging.version import Version
from torch import Tensor
from torch.nn import MultiheadAttention
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset


def default_collate_func(batches: List[Tuple[Tensor, Tensor, Dict[str, Any]]]) -> List[Any]:
    mini_batch = []
    for x in zip(*batches):
        if isinstance(x[0], Tensor):
            x = torch.stack(x)
        mini_batch.append(x)
    return mini_batch


class SmsWsjPlusDataset(Dataset):

    def __init__(self,) -> None:
        super().__init__()

    def __getitem__(self, index: int):
        return torch.randn((129, 251, 12)), torch.randn((2, 129, 251)), {'index': index, 'sample_rate': 8000}

    def __len__(self):
        return 100000


class SmsWsjPlusDataModule(LightningDataModule):

    def __init__(self,):
        super().__init__()

    def construct_dataloader(self):
        ds = SmsWsjPlusDataset()
        return DataLoader(ds, batch_size=1, collate_fn=default_collate_func, num_workers=2)

    def train_dataloader(self) -> DataLoader:
        return self.construct_dataloader()

    def val_dataloader(self) -> DataLoader:
        return self.construct_dataloader()


class SpatialNet(nn.Module):

    def __init__(self,):
        super().__init__()
        self.encoder = nn.Conv1d(in_channels=12, out_channels=96, kernel_size=5, stride=1, padding="same")

        layers = []
        for l in range(2):
            layer = nn.ModuleList([nn.LayerNorm(96), MultiheadAttention(embed_dim=96, num_heads=4, batch_first=True)])
            layers.append(layer)
        self.layers = nn.ModuleList(layers)
        self.decoder = nn.Linear(in_features=96, out_features=4)

    def forward(self, x: Tensor, return_attn_score: bool = False) -> Tensor:
        # x: [Batch, Freq, Time, Feature]
        B, F, T, H0 = x.shape
        x = self.encoder(x.reshape(B * F, T, H0).permute(0, 2, 1)).permute(0, 2, 1)
        H = x.shape[2]

        attns = [] if return_attn_score else None
        x = x.reshape(B, F, T, H)
        for m in self.layers:
            x = x.reshape(B * F, T, H)
            x = m[0](x)
            x, attn = m[1].forward(x, x, x)
            x = x.reshape(B, F, T, H)
            if return_attn_score:
                attns.append(attn)

        y = self.decoder(x)
        if return_attn_score:
            return y.contiguous(), attns
        else:
            return y.contiguous()


class TrainModule(pl.LightningModule):

    def __init__(self, arch: nn.Module = lazy_instance(SpatialNet), compile: bool = False):
        super().__init__()

        if compile != False:
            assert Version(torch.__version__) >= Version('2.0.0'), f'compile only works for torch>=2.0: current version: {torch.__version__}'
            self.arch = torch.compile(arch)
        else:
            self.arch = arch

    def forward(self, x: Tensor) -> Tensor:
        B, F, T, H = x.shape
        out = self.arch(x)
        out = torch.view_as_complex(out.float().reshape(B, F, T, -1, 2))  # [B,F,T,Spk]
        out = out.permute(0, 3, 1, 2)  # [B,Spk,F,T]
        return out

    def training_step(self, batch, batch_idx):
        x, ys, paras = batch
        yr_hat = self.forward(x)
        loss = (ys - yr_hat).abs().mean()
        self.log('train/loss', loss, batch_size=ys.shape[0], prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, ys, paras = batch
        yr_hat = self.forward(x)
        loss = (ys - yr_hat).abs().mean()
        self.log('val/loss', loss, sync_dist=True, batch_size=ys.shape[0])

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


if __name__ == '__main__':
    arch = SpatialNet()
    model = TrainModule(arch=arch, compile=True)
    datamodule = SmsWsjPlusDataModule()

    from pytorch_lightning import Trainer
    trainer = Trainer(strategy='ddp', devices="0,") # strategy='auto' is OK
    trainer.fit(model=model, datamodule=datamodule)

Error messages and logs

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/pytorch_lightning/loops/utilities.py:72: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name | Type            | Params
-----------------------------------------
0 | arch | OptimizedModule | 81.1 K
-----------------------------------------
81.1 K    Trainable params
0         Non-trainable params
81.1 K    Total params
0.324     Total estimated model params size (MB)
Sanity Checking DataLoader 0:   0%| | 0/2 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/mnt/home/quancs/projects/NBSS/SharedTrainer2.py", line 128, in <module>
    trainer.fit(model=model, datamodule=datamodule)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 545, in fit
    call._call_and_handle_interrupt(
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 102, in launch
    return function(*args, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 581, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 990, in _run
    results = self._run_stage()
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1034, in _run_stage
    self._run_sanity_check()
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1063, in _run_sanity_check
    val_loop.run()
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/pytorch_lightning/loops/utilities.py", line 181, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 134, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 391, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 402, in validation_step
    return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 628, in __call__
    wrapper_output = wrapper_module(*args, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1519, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1355, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 621, in wrapped_forward
    out = method(*_args, **_kwargs)
  File "/mnt/home/quancs/projects/NBSS/SharedTrainer2.py", line 113, in validation_step
    yr_hat = self.forward(x)
  File "/mnt/home/quancs/projects/NBSS/SharedTrainer2.py", line 99, in forward
    out = self.arch(x)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 487, in catch_errors
    return hijacked_callback(frame, cache_entry, hooks, frame_state)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 641, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
    tracer.run()
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
    super().run()
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2162, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 833, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 957, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1024, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1009, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/_dynamo/backends/distributed.py", line 268, in compile_fn
    if maybe_param.requires_grad and not self._ignore_parameter(
  File "/mnt/home/quancs/miniconda3/envs/torch211/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1695, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
AttributeError: 'MultiheadAttention' object has no attribute 'requires_grad'

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
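
The traceback ends inside torch._dynamo's DDP backend (torch/_dynamo/backends/distributed.py, compile_fn), which is why the crash only appears with the ddp strategy. A possible mitigation, sketched here as an assumption rather than something suggested in this thread, is to disable that DDP graph-splitting optimizer before calling fit(), at the cost of losing DDP-aware bucketing of the compiled graph:

import torch._dynamo
# Disable Dynamo's DDPOptimizer (the backend named in the traceback) so the compiled
# graph is not split per DDP bucket; check that this flag exists on your torch version.
torch._dynamo.config.optimize_ddp = False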

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - NVIDIA A100-SXM4-40GB
    - available: True
    - version: 12.1
  • Lightning:
    - lightning-utilities: 0.9.0
    - pytorch-lightning: 2.1.0
    - torch: 2.1.0
    - torchaudio: 2.1.0
    - torcheval: 0.0.6
    - torchmetrics: 1.2.0
    - torchtnt: 0.2.1
    - torchvision: 0.16.0
  • Packages:
    - absl-py: 2.0.0
    - aiohttp: 3.8.6
    - aiosignal: 1.3.1
    - antlr4-python3-runtime: 4.9.3
    - async-timeout: 4.0.3
    - attrs: 23.1.0
    - brotlipy: 0.7.0
    - cachetools: 5.3.1
    - certifi: 2023.7.22
    - cffi: 1.15.1
    - charset-normalizer: 2.0.4
    - colorama: 0.4.6
    - cryptography: 41.0.3
    - dlp-mpi: 0.0.3
    - docopt: 0.6.2
    - docstring-parser: 0.15
    - filelock: 3.9.0
    - frozenlist: 1.4.0
    - fsspec: 2023.9.2
    - future: 0.18.3
    - gitdb: 4.0.10
    - gitpython: 3.1.30
    - gmpy2: 2.1.2
    - google-auth: 2.23.3
    - google-auth-oauthlib: 1.0.0
    - grpcio: 1.59.0
    - idna: 3.4
    - importlib-metadata: 6.8.0
    - importlib-resources: 6.1.0
    - jinja2: 3.1.2
    - jsonargparse: 4.25.0
    - jsonpickle: 1.5.2
    - lazy-dataset: 0.0.14
    - lightning-utilities: 0.9.0
    - markdown: 3.5
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.1
    - mdurl: 0.1.2
    - mir-eval: 0.7
    - mkl-fft: 1.3.1
    - mkl-random: 1.2.2
    - mkl-service: 2.4.0
    - mpmath: 1.3.0
    - multidict: 6.0.4
    - munch: 2.5.0
    - mypy: 1.6.0
    - mypy-extensions: 1.0.0
    - networkx: 3.1
    - numpy: 1.24.3
    - oauthlib: 3.2.2
    - omegaconf: 2.3.0
    - packaging: 23.2
    - pandas: 2.1.1
    - pesq: 0.0.4
    - pillow: 10.0.1
    - pip: 23.2.1
    - platformdirs: 3.11.0
    - protobuf: 4.24.4
    - psutil: 5.9.5
    - py-cpuinfo: 9.0.0
    - pyasn1: 0.5.0
    - pyasn1-modules: 0.3.0
    - pycparser: 2.21
    - pygments: 2.16.1
    - pyopenssl: 23.2.0
    - pyre-extensions: 0.0.30
    - pysocks: 1.7.1
    - pystoi: 0.3.3
    - python-dateutil: 2.8.2
    - pytorch-lightning: 2.1.0
    - pytz: 2023.3.post1
    - pyyaml: 6.0
    - requests: 2.31.0
    - requests-oauthlib: 1.3.1
    - rich: 13.6.0
    - rsa: 4.9
    - sacred: 0.8.2
    - scipy: 1.11.3
    - setuptools: 68.0.0
    - sh: 1.14.3
    - six: 1.16.0
    - smmap: 5.0.0
    - soundfile: 0.12.1
    - sympy: 1.11.1
    - tabulate: 0.9.0
    - tensorboard: 2.14.1
    - tensorboard-data-server: 0.7.1
    - tomli: 2.0.1
    - torch: 2.1.0
    - torchaudio: 2.1.0
    - torcheval: 0.0.6
    - torchmetrics: 1.2.0
    - torchtnt: 0.2.1
    - torchvision: 0.16.0
    - tqdm: 4.66.1
    - triton: 2.1.0
    - typeshed-client: 2.4.0
    - typing-extensions: 4.7.1
    - typing-inspect: 0.9.0
    - tzdata: 2023.3
    - urllib3: 1.26.16
    - werkzeug: 3.0.0
    - wheel: 0.41.2
    - wrapt: 1.14.1
    - yapf: 0.40.2
    - yarl: 1.9.2
    - zipp: 3.17.0
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.9.0
    - release: 3.10.0-1160.el7.x86_64
    - version: #1 SMP Mon Oct 19 16:18:59 UTC 2020

More info

No response

carmocca (Member) commented:

I don't see anything wrong with your code. This probably needs to be fixed in PyTorch. Since you already opened an issue there, I'll close this.

carmocca closed this as not planned on Nov 21, 2023
Mohamed-Dhouib commented:

I got the same problem with PyTorch 2.3.0. Plain DDP apparently works well with a compiled model, so I guess something may need to be fixed in the PyTorch Lightning code.
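
For reference, a minimal sketch (not from this thread) of the plain-PyTorch pattern that comment refers to: a torch.compile'd module wrapped in DDP. The file name ddp_compile.py and the gloo backend are assumptions so the sketch also runs without GPUs; launch it with torchrun.

# ddp_compile.py -- hypothetical file name; run with: torchrun --nproc_per_node=1 ddp_compile.py
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group(backend="gloo")  # gloo so this also works on CPU-only machines
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
    compiled = torch.compile(model)   # compile first ...
    ddp_model = DDP(compiled)         # ... then wrap in DDP, mirroring what Lightning does
    opt = torch.optim.SGD(ddp_model.parameters(), lr=0.1)
    x = torch.randn(4, 8)
    loss = ddp_model(x).sum()
    loss.backward()                   # gradients are all-reduced across ranks here
    opt.step()
    dist.destroy_process_group()

if __name__ == "__main__":
    main()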
