
Fix torch.compile on nn.module instead of on LightningModule #587

Merged: 3 commits into ashleve:main on Sep 1, 2023

Conversation

tesfaldet (Contributor) commented Jul 8, 2023

What does this PR do?

Related to this issue: Lightning-AI/pytorch-lightning#17177.
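
Judging by the title and the files changed in the Codecov report below (src/train.py and src/models/mnist_module.py), the change moves the torch.compile call off the whole LightningModule and onto the inner nn.Module. A minimal sketch of the resulting shape, assuming the template's conventions (the setup hook, the net attribute, and a compile hyperparameter are illustrative, not a verbatim copy of the PR):

import torch
from lightning import LightningModule


class MNISTLitModule(LightningModule):
    ...

    def setup(self, stage: str) -> None:
        # Compile only the inner nn.Module, not the LightningModule itself,
        # so Lightning hooks and torchmetrics bookkeeping stay in eager mode
        # and out of dynamo's traced graph.
        # `self.net` and `self.hparams.compile` follow the template's
        # conventions; treat both names as assumptions.
        if self.hparams.compile and stage == "fit":
            self.net = torch.compile(self.net)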

Before submitting

  • Did you make sure the title is self-explanatory and the description concisely explains the PR?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you test your PR locally with the pytest command?
  • Did you run the pre-commit hooks with the pre-commit run -a command?

Did you have fun?

Make sure you had fun coding 🙃

tesfaldet (Contributor, Author):

Tests can be fixed by merging PR #585.

requirements.txt (review thread: outdated, resolved)
codecov-commenter:

Codecov Report

Patch coverage: 75.00% and project coverage change: +0.28% 🎉

Comparison is base (8055898) 83.75% compared to head (eb7c731) 84.03%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #587      +/-   ##
==========================================
+ Coverage   83.75%   84.03%   +0.28%     
==========================================
  Files          11       11              
  Lines         357      357              
==========================================
+ Hits          299      300       +1     
+ Misses         58       57       -1     
Files Changed               Coverage           Δ
src/train.py                96.00% <ø>         (+3.54%) ⬆️
src/models/mnist_module.py  96.96% <75.00%>    (-1.45%) ⬇️


@@ -21,6 +21,7 @@ channels:
# compatibility is usually guaranteed

dependencies:
- python=3.10
ashleve (Owner): Is Python 3.10 required for torch.compile?

tesfaldet (Contributor, Author): Yeah, unfortunately.

Another commenter: Are you sure? I see here that only Python 3.8 is required.

tesfaldet (Contributor, Author): I forget exactly which error I got, so I'll try to recreate it to show that 3.10 is required, for posterity's sake.
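
As an aside, one way to sidestep the version question is to gate compilation on the interpreter at runtime rather than pinning Python in the environment file. A minimal sketch, assuming torch 2.0.x dynamo's documented CPython support of roughly 3.8 through 3.10 (verify the exact bounds against the installed torch release):

import sys

import torch


def maybe_compile(module: torch.nn.Module) -> torch.nn.Module:
    # Only compile on interpreter versions dynamo supports (assumed range
    # for torch 2.0.x: CPython 3.8-3.10); otherwise fall back to eager.
    if (3, 8) <= sys.version_info[:2] <= (3, 10):
        return torch.compile(module)
    return module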

ashleve (Owner) commented Aug 24, 2023:

Thanks for the PR. You can set it to ready if you want me to merge.

tesfaldet marked this pull request as ready for review on Aug 28, 2023, 14:03.
ashleve changed the title from "torch.compile on nn.module instead of on lightningmodule" to "Fix torch.compile on nn.module instead of on LightningModule" on Sep 1, 2023.
ashleve merged commit 2654bad into ashleve:main on Sep 1, 2023. 13 checks passed.
ashleve added the "bug (Something isn't working)" label on Sep 1, 2023.
ResearchDaniel:

Should this fix be reverted, now that it has been fixed in PyTorch? See Lightning-AI/pytorch-lightning#17177 (comment). A global switch is more convenient.

tesfaldet (Contributor, Author):

> Should this fix be reverted, now that it has been fixed in PyTorch? See Lightning-AI/pytorch-lightning#17177 (comment). A global switch is more convenient.

I noticed this recently as well and I'm currently testing it for myself. I'm not sure which version of PyTorch fixes this problem, as I don't see a version mentioned in the comment (Lightning-AI/pytorch-lightning#17177 (comment)), but I'll try to find it.

tesfaldet (Contributor, Author):

Alright, so I created a new conda environment with the following packages installed:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
annotated-types           0.5.0              pyhd8ed1ab_0    conda-forge
antlr-python-runtime      4.9.3              pyhd8ed1ab_1    conda-forge
anyio                     4.0.0              pyhd8ed1ab_0    conda-forge
arrow                     1.2.3              pyhd8ed1ab_0    conda-forge
attrs                     23.1.0             pyh71513ae_1    conda-forge
backoff                   2.2.1              pyhd8ed1ab_0    conda-forge
backports                 1.0                pyhd8ed1ab_3    conda-forge
backports.functools_lru_cache 1.6.5              pyhd8ed1ab_0    conda-forge
beautifulsoup4            4.12.2             pyha770c72_0    conda-forge
blas                      1.0                         mkl    conda-forge
blessed                   1.19.1             pyhe4f9e05_2    conda-forge
brotli-python             1.1.0            py38h17151c0_0    conda-forge
bzip2                     1.0.8                h7f98852_4    conda-forge
ca-certificates           2023.7.22            hbcca054_0    conda-forge
cachecontrol              0.13.1             pyhd8ed1ab_0    conda-forge
cachecontrol-with-filecache 0.13.1             pyhd8ed1ab_0    conda-forge
certifi                   2023.7.22          pyhd8ed1ab_0    conda-forge
cffi                      1.15.1           py38h4a40e3a_3    conda-forge
cfgv                      3.3.1              pyhd8ed1ab_0    conda-forge
charset-normalizer        3.2.0              pyhd8ed1ab_0    conda-forge
cleo                      2.0.1              pyhd8ed1ab_0    conda-forge
click                     8.1.7           unix_pyh707e725_0    conda-forge
cloudpickle               2.2.1                    pypi_0    pypi
colorama                  0.4.6              pyhd8ed1ab_0    conda-forge
colorlog                  6.7.0                    pypi_0    pypi
crashtest                 0.4.1              pyhd8ed1ab_0    conda-forge
croniter                  1.4.1              pyhd8ed1ab_0    conda-forge
cryptography              41.0.3           py38hcdda232_0    conda-forge
cuda-cudart               11.8.89                       0    nvidia
cuda-cupti                11.8.87                       0    nvidia
cuda-libraries            11.8.0                        0    nvidia
cuda-nvrtc                11.8.89                       0    nvidia
cuda-nvtx                 11.8.86                       0    nvidia
cuda-runtime              11.8.0                        0    nvidia
dateutils                 0.6.12                     py_0    conda-forge
dbus                      1.13.6               h5008d03_3    conda-forge
deepdiff                  6.5.0              pyhd8ed1ab_0    conda-forge
distlib                   0.3.7              pyhd8ed1ab_0    conda-forge
dulwich                   0.21.6           py38h01eb140_0    conda-forge
exceptiongroup            1.1.3              pyhd8ed1ab_0    conda-forge
expat                     2.5.0                hcb278e6_1    conda-forge
fastapi                   0.103.0            pyhd8ed1ab_0    conda-forge
ffmpeg                    4.3                  hf484d3e_0    pytorch
filelock                  3.12.4             pyhd8ed1ab_0    conda-forge
freetype                  2.12.1               h267a509_2    conda-forge
fsspec                    2023.9.1           pyh1a96a4e_0    conda-forge
gettext                   0.21.1               h27087fc_0    conda-forge
gmp                       6.2.1                h58526e2_0    conda-forge
gmpy2                     2.1.2            py38h793c122_1    conda-forge
gnutls                    3.6.13               h85f3911_1    conda-forge
h11                       0.14.0             pyhd8ed1ab_0    conda-forge
hydra-colorlog            1.2.0                    pypi_0    pypi
hydra-core                1.3.2              pyhd8ed1ab_0    conda-forge
hydra-submitit-launcher   1.2.0                    pypi_0    pypi
icu                       73.2                 h59595ed_0    conda-forge
identify                  2.5.29             pyhd8ed1ab_0    conda-forge
idna                      3.4                pyhd8ed1ab_0    conda-forge
importlib-metadata        6.8.0              pyha770c72_0    conda-forge
importlib_metadata        6.8.0                hd8ed1ab_0    conda-forge
importlib_resources       6.0.1              pyhd8ed1ab_0    conda-forge
iniconfig                 2.0.0              pyhd8ed1ab_0    conda-forge
inquirer                  3.1.3              pyhd8ed1ab_0    conda-forge
itsdangerous              2.1.2              pyhd8ed1ab_0    conda-forge
jaraco.classes            3.3.0              pyhd8ed1ab_0    conda-forge
jeepney                   0.8.0              pyhd8ed1ab_0    conda-forge
jinja2                    3.1.2              pyhd8ed1ab_1    conda-forge
jpeg                      9e                   h0b41bf4_3    conda-forge
jsonschema                4.17.3             pyhd8ed1ab_0    conda-forge
keyring                   24.2.0           py38h578d9bd_0    conda-forge
lame                      3.100             h166bdaf_1003    conda-forge
lcms2                     2.15                 hfd0df8a_0    conda-forge
ld_impl_linux-64          2.40                 h41732ed_0    conda-forge
lerc                      4.0.0                h27087fc_0    conda-forge
libblas                   3.9.0            16_linux64_mkl    conda-forge
libcblas                  3.9.0            16_linux64_mkl    conda-forge
libcublas                 11.11.3.6                     0    nvidia
libcufft                  10.9.0.58                     0    nvidia
libcufile                 1.7.2.10                      0    nvidia
libcurand                 10.3.3.141                    0    nvidia
libcusolver               11.4.1.48                     0    nvidia
libcusparse               11.7.5.86                     0    nvidia
libdeflate                1.17                 h0b41bf4_0    conda-forge
libexpat                  2.5.0                hcb278e6_1    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-ng                 13.2.0               h807b86a_1    conda-forge
libglib                   2.78.0               hebfc3b9_0    conda-forge
libhwloc                  2.9.2           default_h554bfaf_1009    conda-forge
libiconv                  1.17                 h166bdaf_0    conda-forge
liblapack                 3.9.0            16_linux64_mkl    conda-forge
libnpp                    11.8.0.86                     0    nvidia
libnsl                    2.0.0                h7f98852_0    conda-forge
libnvjpeg                 11.9.0.86                     0    nvidia
libpng                    1.6.39               h753d276_0    conda-forge
libsqlite                 3.43.0               h2797004_0    conda-forge
libstdcxx-ng              13.2.0               h7e041cc_1    conda-forge
libtiff                   4.5.0                h6adf6a1_2    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libwebp-base              1.3.2                hd590300_0    conda-forge
libxcb                    1.13              h7f98852_1004    conda-forge
libxml2                   2.11.5               h232c23b_1    conda-forge
libzlib                   1.2.13               hd590300_5    conda-forge
lightning                 2.0.7              pyhd8ed1ab_0    conda-forge
lightning-cloud           0.5.38             pyhd8ed1ab_0    conda-forge
lightning-utilities       0.9.0              pyhd8ed1ab_0    conda-forge
llvm-openmp               16.0.6               h4dfa4b3_0    conda-forge
markdown-it-py            2.2.0              pyhd8ed1ab_0    conda-forge
markupsafe                2.1.3            py38h01eb140_0    conda-forge
mdurl                     0.1.0              pyhd8ed1ab_0    conda-forge
mkl                       2022.2.1         h84fe81f_16997    conda-forge
more-itertools            10.1.0             pyhd8ed1ab_0    conda-forge
mpc                       1.3.1                hfe3b2da_0    conda-forge
mpfr                      4.2.0                hb012696_0    conda-forge
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
msgpack-python            1.0.5            py38hfbd4bf9_0    conda-forge
ncurses                   6.4                  hcb278e6_0    conda-forge
nettle                    3.6                  he412f7d_0    conda-forge
networkx                  3.1                pyhd8ed1ab_0    conda-forge
nodeenv                   1.8.0              pyhd8ed1ab_0    conda-forge
numpy                     1.24.4           py38h59b608b_0    conda-forge
omegaconf                 2.3.0              pyhd8ed1ab_0    conda-forge
openh264                  2.1.1                h780b84a_0    conda-forge
openjpeg                  2.5.0                hfec8fc6_2    conda-forge
openssl                   3.1.2                hd590300_0    conda-forge
ordered-set               4.1.0              pyhd8ed1ab_0    conda-forge
orjson                    3.9.6            py38h0488081_0    conda-forge
packaging                 23.1               pyhd8ed1ab_0    conda-forge
pcre2                     10.40                hc3806b6_0    conda-forge
pexpect                   4.8.0              pyh1a96a4e_2    conda-forge
pillow                    9.4.0            py38hde6dc18_1    conda-forge
pip                       23.2.1             pyhd8ed1ab_0    conda-forge
pkginfo                   1.9.6              pyhd8ed1ab_0    conda-forge
pkgutil-resolve-name      1.3.10             pyhd8ed1ab_1    conda-forge
platformdirs              3.10.0             pyhd8ed1ab_0    conda-forge
pluggy                    1.3.0              pyhd8ed1ab_0    conda-forge
poetry                    1.6.1           linux_pyha804496_0    conda-forge
poetry-core               1.7.0              pyhd8ed1ab_0    conda-forge
poetry-plugin-export      1.5.0              pyhd8ed1ab_0    conda-forge
pre-commit                3.3.3              pyha770c72_0    conda-forge
psutil                    5.9.5            py38h1de0b5d_0    conda-forge
pthread-stubs             0.4               h36c2ea0_1001    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pycparser                 2.21               pyhd8ed1ab_0    conda-forge
pydantic                  2.1.1              pyhd8ed1ab_0    conda-forge
pydantic-core             2.4.0            py38h0cc4f7c_0    conda-forge
pygments                  2.16.1             pyhd8ed1ab_0    conda-forge
pyjwt                     2.8.0              pyhd8ed1ab_0    conda-forge
pyopenssl                 23.2.0             pyhd8ed1ab_1    conda-forge
pyproject_hooks           1.0.0              pyhd8ed1ab_0    conda-forge
pyrsistent                0.19.3           py38h1de0b5d_0    conda-forge
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
pytest                    7.2.2              pyhd8ed1ab_0    conda-forge
python                    3.8.17          he550d4f_0_cpython    conda-forge
python-build              0.10.0             pyhd8ed1ab_1    conda-forge
python-dateutil           2.8.2              pyhd8ed1ab_0    conda-forge
python-dotenv             1.0.0                    pypi_0    pypi
python-editor             1.0.4                      py_0    conda-forge
python-installer          0.7.0              pyhd8ed1ab_0    conda-forge
python-multipart          0.0.6              pyhd8ed1ab_0    conda-forge
python_abi                3.8                      3_cp38    conda-forge
pytorch                   2.0.1           py3.8_cuda11.8_cudnn8.7.0_0    pytorch
pytorch-cuda              11.8                 h7e8668a_5    pytorch
pytorch-lightning         2.0.9              pyhd8ed1ab_0    conda-forge
pytorch-mutex             1.0                        cuda    pytorch
pytz                      2023.3.post1       pyhd8ed1ab_0    conda-forge
pyyaml                    6.0.1            py38h01eb140_0    conda-forge
rapidfuzz                 2.15.1           py38h8dc9893_0    conda-forge
readchar                  4.0.5              pyhd8ed1ab_0    conda-forge
readline                  8.2                  h8228510_1    conda-forge
requests                  2.31.0             pyhd8ed1ab_0    conda-forge
requests-toolbelt         1.0.0              pyhd8ed1ab_0    conda-forge
rich                      13.3.5             pyhd8ed1ab_0    conda-forge
rootutils                 1.0.7                    pypi_0    pypi
secretstorage             3.3.3            py38h578d9bd_1    conda-forge
setuptools                68.2.2             pyhd8ed1ab_0    conda-forge
shellingham               1.5.3              pyhd8ed1ab_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
sniffio                   1.3.0              pyhd8ed1ab_0    conda-forge
soupsieve                 2.5                pyhd8ed1ab_1    conda-forge
starlette                 0.27.0             pyhd8ed1ab_0    conda-forge
starsessions              1.3.0              pyhd8ed1ab_0    conda-forge
submitit                  1.4.5                    pypi_0    pypi
sympy                     1.12            pypyh9d50eac_103    conda-forge
tbb                       2021.10.0            h00ab1b0_0    conda-forge
tk                        8.6.12               h27826a3_0    conda-forge
tomli                     2.0.1              pyhd8ed1ab_0    conda-forge
tomlkit                   0.12.1             pyha770c72_0    conda-forge
torchmetrics              1.1.2              pyhd8ed1ab_0    conda-forge
torchtriton               2.0.0                      py38    pytorch
torchvision               0.15.2               py38_cu118    pytorch
tqdm                      4.66.1             pyhd8ed1ab_0    conda-forge
traitlets                 5.10.0             pyhd8ed1ab_0    conda-forge
trove-classifiers         2023.8.7           pyhd8ed1ab_0    conda-forge
typing-extensions         4.8.0                hd8ed1ab_0    conda-forge
typing_extensions         4.8.0              pyha770c72_0    conda-forge
ukkonen                   1.0.1            py38h43d8883_3    conda-forge
urllib3                   2.0.4              pyhd8ed1ab_0    conda-forge
uvicorn                   0.23.2           py38h578d9bd_0    conda-forge
virtualenv                20.24.4            pyhd8ed1ab_0    conda-forge
wcwidth                   0.2.6              pyhd8ed1ab_0    conda-forge
websocket-client          1.6.3              pyhd8ed1ab_0    conda-forge
websockets                11.0.3           py38h01eb140_0    conda-forge
wheel                     0.41.2             pyhd8ed1ab_0    conda-forge
xorg-libxau               1.0.11               hd590300_0    conda-forge
xorg-libxdmcp             1.1.3                h7f98852_0    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
yaml                      0.2.5                h7f98852_2    conda-forge
zipp                      3.16.2             pyhd8ed1ab_0    conda-forge
zlib                      1.2.13               hd590300_5    conda-forge
zstd                      1.5.5                hfc55251_0    conda-forge

This was from the following environment.yaml file:

# reasons you might want to use `environment.yaml` instead of `requirements.txt`:
# - pip installs packages in a loop, without ensuring dependencies across all packages
#   are fulfilled simultaneously, but conda achieves proper dependency control across
#   all packages
# - conda allows for installing packages without requiring certain compilers or
#   libraries to be available in the system, since it installs precompiled binaries

name: lightning-hydra-template-py38

channels:
  - pytorch
  - nvidia
  - conda-forge
  - defaults

# it is strongly recommended to specify versions of packages installed through conda
# to avoid situation when version-unspecified packages install their latest major
# versions which can sometimes break things

# current approach below keeps the dependencies in the same major versions across all
# users, but allows for different minor and patch versions of packages where backwards
# compatibility is usually guaranteed

dependencies:
  - pytorch=2.0.1
  - pytorch-cuda=11.8
  - torchvision=0.15
  - lightning=2.0.7
  - torchmetrics=1.1.2
  - hydra-core=1.3
  - rich=13.3
  - pre-commit=3.3 # dev
  - pytest=7.2 # dev
  - pydantic=2.1 # dev, required by pytest
  - python=3.8

  # --------- loggers --------- #
  # - wandb
  # - neptune-client
  # - mlflow
  # - comet-ml
  # - aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550

  - pip>=23
  - pip:
      - hydra-colorlog==1.2.0
      - rootutils==1.0.7

Then I ran `python src/train.py experiment=example trainer=gpu compile=True` and got this error:

[2023-09-18 08:58:05,339][src.utils.utils][INFO] - [rank: 0] Enforcing tags! <cfg.extras.enforce_tags=True>
[2023-09-18 08:58:05,348][src.utils.utils][INFO] - [rank: 0] Printing config tree with Rich! <cfg.extras.print_config=True>
[2023-09-18 08:58:05,458][__main__][INFO] - [rank: 0] Instantiating datamodule <src.data.mnist_datamodule.MNISTDataModule>
[2023-09-18 08:58:05,471][__main__][INFO] - [rank: 0] Instantiating model <src.models.mnist_module.MNISTLitModule>
[2023-09-18 08:58:05,591][__main__][INFO] - [rank: 0] Instantiating callbacks...
[2023-09-18 08:58:05,591][src.utils.instantiators][INFO] - [rank: 0] Instantiating callback <lightning.pytorch.callbacks.ModelCheckpoint>
[2023-09-18 08:58:05,597][src.utils.instantiators][INFO] - [rank: 0] Instantiating callback <lightning.pytorch.callbacks.EarlyStopping>
[2023-09-18 08:58:05,599][src.utils.instantiators][INFO] - [rank: 0] Instantiating callback <lightning.pytorch.callbacks.RichModelSummary>
[2023-09-18 08:58:05,600][src.utils.instantiators][INFO] - [rank: 0] Instantiating callback <lightning.pytorch.callbacks.RichProgressBar>
[2023-09-18 08:58:05,602][__main__][INFO] - [rank: 0] Instantiating loggers...
[2023-09-18 08:58:05,602][__main__][INFO] - [rank: 0] Instantiating trainer <lightning.pytorch.trainer.Trainer>
[2023-09-18 08:58:05,844][__main__][INFO] - [rank: 0] Compiling model!
[2023-09-18 08:58:06,123][__main__][INFO] - [rank: 0] Starting training!
[2023-09-18 08:58:26,486][torch._functorch.aot_autograd][WARNING] - Failed to collect metadata on function, produced code may be suboptimal.  Known situations this can occur are inference mode only compilation involving resize_ or prims (!schema.hasAnyAliasInfo() INTERNAL ASSERT FAILED); if your situation looks different please file a bug to PyTorch.
Traceback (most recent call last):
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1676, in aot_wrapper_dedupe
    fw_metadata, _out = run_functionalized_fw_and_collect_metadata(
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 607, in inner
    flat_f_outs = f(*flat_f_args)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2793, in functional_call
    out = Interpreter(mod).run(*args[params_len:], **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/fx/interpreter.py", line 177, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/fx/interpreter.py", line 249, in call_function
    return target(*args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_inductor/overrides.py", line 38, in __torch_function__
    return func(*args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 987, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1162, in dispatch
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 410, in local_scalar_dense
    raise DataDependentOutputException(func)
torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default

While executing %setitem : [#users=0] = call_function[target=operator.setitem](args = (%add, %eq, 1), kwargs = {})
Original traceback:
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/utilities/compute.py", line 54, in _safe_divide
    denom[denom == 0.0] = 1
 |   File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/functional/classification/accuracy.py", line 83, in _accuracy_reduce
    return _safe_divide(tp, tp + fn)
 |   File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/classification/accuracy.py", line 253, in compute
    return _accuracy_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)

[2023-09-18 08:58:26,558][src.utils.utils][ERROR] - [rank: 0] 
Traceback (most recent call last):
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 670, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/__init__.py", line 1390, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx
    return aot_autograd(
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2822, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2515, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1804, in aot_wrapper_dedupe
    compiled_fn = compiler_fn(wrapped_flat_fn, deduped_flat_args, aot_config)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1280, in aot_dispatch_base
    _fw_metadata, _out = run_functionalized_fw_and_collect_metadata(
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 607, in inner
    flat_f_outs = f(*flat_f_args)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1802, in wrapped_flat_fn
    return flat_fn(*add_dupe_args(args))
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2793, in functional_call
    out = Interpreter(mod).run(*args[params_len:], **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/fx/interpreter.py", line 177, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/fx/interpreter.py", line 249, in call_function
    return target(*args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_inductor/overrides.py", line 38, in __torch_function__
    return func(*args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 987, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1162, in dispatch
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 410, in local_scalar_dense
    raise DataDependentOutputException(func)
torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default

While executing %setitem : [#users=0] = call_function[target=operator.setitem](args = (%add, %eq, 1), kwargs = {})
Original traceback:
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/utilities/compute.py", line 54, in _safe_divide
    denom[denom == 0.0] = 1
 |   File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/functional/classification/accuracy.py", line 83, in _accuracy_reduce
    return _safe_divide(tp, tp + fn)
 |   File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/classification/accuracy.py", line 253, in compute
    return _accuracy_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/private/home/mattie/Projects/lightning-hydra-template/src/utils/utils.py", line 70, in wrap
    metric_dict, object_dict = task_func(cfg=cfg)
  File "src/train.py", line 91, in train
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 532, in fit
    call._call_and_handle_interrupt(
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 571, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 980, in _run
    results = self._run_stage()
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 1021, in _run_stage
    self._run_sanity_check()
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 1050, in _run_sanity_check
    val_loop.run()
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/lightning/pytorch/loops/utilities.py", line 181, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 115, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 376, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_kwargs.values())
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 293, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/lightning/pytorch/strategies/strategy.py", line 393, in validation_step
    return self.model.validation_step(*args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/private/home/mattie/Projects/lightning-hydra-template/src/models/mnist_module.py", line 144, in validation_step
    loss, preds, targets = self.model_step(batch)
  File "/private/home/mattie/Projects/lightning-hydra-template/src/models/mnist_module.py", line 147, in <graph break in validation_step>
    self.val_loss(loss)
  File "/private/home/mattie/Projects/lightning-hydra-template/src/models/mnist_module.py", line 148, in <graph break in validation_step>
    self.val_acc(preds, targets)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/metric.py", line 298, in forward
    self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/metric.py", line 357, in _forward_reduce_state_update
    self.reset()
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/metric.py", line 360, in <graph break in _forward_reduce_state_update>
    self._to_sync = self.dist_sync_on_step
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/metric.py", line 361, in <graph break in _forward_reduce_state_update>
    self._should_unsync = False
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/metric.py", line 363, in <graph break in _forward_reduce_state_update>
    self.compute_on_cpu = False
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/metric.py", line 364, in <graph break in _forward_reduce_state_update>
    self._enable_grad = True  # allow grads for batch computation
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/metric.py", line 367, in <graph break in _forward_reduce_state_update>
    self.update(*args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/metric.py", line 368, in <graph break in _forward_reduce_state_update>
    batch_val = self.compute()
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/metric.py", line 602, in wrapped_func
    with self.sync_context(
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/metric.py", line 607, in <graph break in wrapped_func>
    value = _squeeze_if_scalar(compute(*args, **kwargs))
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 517, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised DataDependentOutputException: aten._local_scalar_dense.default

While executing %setitem : [#users=0] = call_function[target=operator.setitem](args = (%add, %eq, 1), kwargs = {})
Original traceback:
  File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/utilities/compute.py", line 54, in _safe_divide
    denom[denom == 0.0] = 1
 |   File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/functional/classification/accuracy.py", line 83, in _accuracy_reduce
    return _safe_divide(tp, tp + fn)
 |   File "/private/home/mattie/miniconda3/envs/lightning-hydra-template-py38/lib/python3.8/site-packages/torchmetrics/classification/accuracy.py", line 253, in compute
    return _accuracy_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average)


Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

[2023-09-18 08:58:26,597][src.utils.utils][INFO] - [rank: 0] Output dir: /private/home/mattie/Projects/lightning-hydra-template/logs/train/runs/2023-09-18_08-58-05

Finally, this is my environment info (using https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py):

Collecting environment information...
PyTorch version: 2.0.1
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.17 | packaged by conda-forge | (default, Jun 16 2023, 07:06:00)  [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-124-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla V100-SXM2-16GB
Nvidia driver version: 470.141.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   46 bits physical, 48 bits virtual
CPU(s):                          80
On-line CPU(s) list:             0-79
Thread(s) per core:              2
Core(s) per socket:              20
Socket(s):                       2
NUMA node(s):                    2
Vendor ID:                       GenuineIntel
CPU family:                      6
Model:                           79
Model name:                      Intel(R) Xeon(R) CPU E5-2698 v4 @ 2.20GHz
Stepping:                        1
CPU MHz:                         2854.379
CPU max MHz:                     3600.0000
CPU min MHz:                     1200.0000
BogoMIPS:                        4390.03
Virtualization:                  VT-x
L1d cache:                       1.3 MiB
L1i cache:                       1.3 MiB
L2 cache:                        10 MiB
L3 cache:                        100 MiB
NUMA node0 CPU(s):               0-19,40-59
NUMA node1 CPU(s):               20-39,60-79
Vulnerability Itlb multihit:     KVM: Mitigation: Split huge pages
Vulnerability L1tf:              Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds:               Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown:          Mitigation; PTI
Vulnerability Mmio stale data:   Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Mitigation; Clear CPU buffers; SMT vulnerable
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single pti intel_ppin ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a rdseed adx smap intel_pt xsaveopt cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts md_clear flush_l1d

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] pytorch-lightning==2.0.9
[pip3] torch==2.0.1
[pip3] torchmetrics==1.1.2
[pip3] torchvision==0.15.2
[pip3] triton==2.0.0
[conda] blas                      1.0                         mkl    conda-forge
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] libblas                   3.9.0            16_linux64_mkl    conda-forge
[conda] libcblas                  3.9.0            16_linux64_mkl    conda-forge
[conda] liblapack                 3.9.0            16_linux64_mkl    conda-forge
[conda] mkl                       2022.2.1         h84fe81f_16997    conda-forge
[conda] numpy                     1.24.4           py38h59b608b_0    conda-forge
[conda] pytorch                   2.0.1           py3.8_cuda11.8_cudnn8.7.0_0    pytorch
[conda] pytorch-cuda              11.8                 h7e8668a_5    pytorch
[conda] pytorch-lightning         2.0.9              pyhd8ed1ab_0    conda-forge
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchmetrics              1.1.2              pyhd8ed1ab_0    conda-forge
[conda] torchtriton               2.0.0                      py38    pytorch
[conda] torchvision               0.15.2               py38_cu118    pytorch

This is with the setup where torch.compile is applied to the LightningModule (as it was before this PR) rather than to the underlying nn.Module. For starters, I don't seem to be hitting the Python version issue, since I'm testing on Python 3.8. However, there is a different failure here that I haven't seen before, possibly because I'm now on the latest versions of Lightning and PyTorch, which were presumed to have fixed this issue. If others can test this, I'd be curious whether you can get compilation working on the LightningModule.
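
For context on the failure itself: the traceback bottoms out in torchmetrics' _safe_divide, where the masked assignment denom[denom == 0.0] = 1 is data dependent and cannot be evaluated on dynamo's fake tensors, hence the DataDependentOutputException. A minimal sketch of the failing pattern next to a trace-friendly torch.where rewrite (the rewrite is illustrative, not torchmetrics' actual fix):

import torch


def safe_divide_masked(num: torch.Tensor, denom: torch.Tensor) -> torch.Tensor:
    # The pattern the traceback points at: boolean-mask assignment, which
    # the compiler cannot resolve without knowing concrete tensor values.
    denom = denom.clone()  # avoid mutating the caller's tensor in this sketch
    denom[denom == 0.0] = 1
    return num / denom


def safe_divide_traceable(num: torch.Tensor, denom: torch.Tensor) -> torch.Tensor:
    # Shape-static rewrite: torch.where avoids the masked assignment
    # entirely, so the same computation traces cleanly under torch.compile.
    return num / torch.where(denom == 0.0, torch.ones_like(denom), denom)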
