
RuntimeError when Registering Hooks on Tensor in mm_albef_dataset #8

Open
othmane42 opened this issue Apr 18, 2024 · 2 comments

Comments

@othmane42

First of all, thank you for sharing the code base of this interesting work!

I encountered an error while running the command python mm-shap_albef_dataset.py 3 "refcoco" "yes".
This is the error message I received:
RuntimeError: cannot register a hook on a tensor that doesn't require gradient
Here is the full traceback:

  0%|                                                                                             | 0/3 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "mm-shap_albef_dataset.py", line 307, in <module>
    shap_values = explainer(X)
  File "/mnt/c/Users/Documents/work/MM-SHAP/shap/explainers/_permutation.py", line 62, in __call__
    batch_size=batch_size, outputs=outputs, silent=silent
  File "/mnt/c/Users/Documents/work/MM-SHAP/shap/explainers/_permutation.py", line 76, in __call__
    outputs=outputs, silent=silent
  File "/mnt/c/Users/Documents/work/MM-SHAP/shap/explainers/_explainer.py", line 260, in __call__
    batch_size=batch_size, outputs=outputs, silent=silent, **kwargs
  File "/mnt/c/Users/Documents/work/MM-SHAP/shap/explainers/_permutation.py", line 134, in explain_row
    outputs = fm(masks, zero_index=0, batch_size=batch_size)
  File "/mnt/c/Users/Documents/work/MM-SHAP/shap/utils/_masked_model.py", line 65, in __call__
    return self._full_masking_call(full_masks, zero_index=zero_index, batch_size=batch_size)
  File "/mnt/c/Users/Documents/work/MM-SHAP/shap/utils/_masked_model.py", line 141, in _full_masking_call
    outputs = self.model(*joined_masked_inputs)
  File "/mnt/c/Users/Documents/work/MM-SHAP/shap/models/_model.py", line 21, in __call__
    return np.array(self.inner_model(*args))
  File "mm-shap_albef_dataset.py", line 192, in get_model_prediction
    masked_text_inputs.to("cuda"))
  File "/home/anaconda3/envs/shap/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "mm-shap_albef_dataset.py", line 100, in forward
    return_dict=True,
  File "/home/anaconda3/envs/shap/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/c/Users/Documents/work/phd_work/MM-SHAP/ALBEF/models/xbert.py", line 1067, in forward
    mode=mode,
  File "/home/anaconda3/envs/shap/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/c/Users/Documents/work/phd_work/MM-SHAP/ALBEF/models/xbert.py", line 601, in forward
    output_attentions,
  File "/home/anaconda3/envs/shap/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/c/Users/Documents/work/phd_work/MM-SHAP/ALBEF/models/xbert.py", line 504, in forward
    output_attentions=output_attentions,
  File "/home/anaconda3/envs/shap/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/c/Users/Documents/work/phd_work/MM-SHAP/ALBEF/models/xbert.py", line 407, in forward
    output_attentions,
  File "/home/anaconda3/envs/shap/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/c/Users/Documents/work/phd_work/MM-SHAP/ALBEF/models/xbert.py", line 329, in forward
    attention_probs.register_hook(self.save_attn_gradients)
  File "/home/anaconda3/envs/shap/lib/python3.6/site-packages/torch/_tensor.py", line 289, in register_hook
    raise RuntimeError("cannot register a hook on a tensor that "
RuntimeError: cannot register a hook on a tensor that doesn't require gradient
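The error can be reproduced in isolation. The sketch below assumes the masked forward passes run with autograd disabled, e.g. under torch.no_grad() (the script's body is not shown here, so this is an assumption); model parameters frozen with requires_grad=False would have the same effect. In both cases the attention probabilities carry no autograd graph, and Tensor.register_hook refuses to attach a hook. The helper try_register_hook is hypothetical, for illustration only.

```python
import torch

def try_register_hook(t):
    """Try to attach a no-op gradient hook; return "ok" or the error message."""
    try:
        t.register_hook(lambda grad: grad)
        return "ok"
    except RuntimeError as err:
        return str(err)

weights = torch.randn(4, 4, requires_grad=True)  # stands in for trainable model parameters
inputs = torch.randn(2, 4)

# Forward pass without autograd (assumed to mirror the masked SHAP evaluations).
with torch.no_grad():
    probs_no_grad = torch.softmax(inputs @ weights, dim=-1)
print(probs_no_grad.requires_grad)       # False: no graph was recorded
print(try_register_hook(probs_no_grad))  # the RuntimeError message from the traceback

# Same computation with autograd enabled: the hook can be attached.
probs_tracked = torch.softmax(inputs @ weights, dim=-1)
print(probs_tracked.requires_grad)       # True
print(try_register_hook(probs_tracked))  # ok
```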

I resolved the issue by setting save_attention = False at line 215 of mm_albef_dataset.py:

 model.text_encoder.base_model.base_model.encoder.layer[block_num].crossattention.self.save_attention = False
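If the attention maps are still wanted, an alternative to disabling the flag is to guard the hook so it is only registered when autograd is tracking the tensor. The sketch below uses a hypothetical ToyCrossAttention module that imitates the pattern visible in the traceback (a save_attention flag and a save_attn_gradients hook on attention_probs); it is not ALBEF's xbert.py. The hypothetical set_save_attention helper flips the flag on every submodule that defines it, instead of a single block_num, assuming save_attention is the attribute name the real modules use. Whether the saved gradients are needed at all is the open question below.

```python
import torch
import torch.nn as nn

class ToyCrossAttention(nn.Module):
    """Hypothetical stand-in for an attention block that records attention maps and gradients."""

    def __init__(self, dim):
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.save_attention = True
        self.attention_map = None
        self.attn_gradients = None

    def save_attn_gradients(self, grad):
        self.attn_gradients = grad

    def forward(self, x, context):
        scores = self.q(x) @ self.k(context).transpose(-1, -2)
        attention_probs = scores.softmax(dim=-1)
        if self.save_attention:
            self.attention_map = attention_probs
            # Guard: a hook can only be attached when autograd tracks this tensor.
            if attention_probs.requires_grad:
                attention_probs.register_hook(self.save_attn_gradients)
        return attention_probs @ context

def set_save_attention(model, enabled):
    """Set the (assumed) save_attention flag on every submodule that has it; return the count."""
    count = 0
    for module in model.modules():
        if hasattr(module, "save_attention"):
            module.save_attention = enabled
            count += 1
    return count

block = ToyCrossAttention(4)
x, ctx = torch.randn(2, 4), torch.randn(3, 4)
with torch.no_grad():
    out = block(x, ctx)  # no RuntimeError: the hook is skipped when requires_grad is False
```

The guard keeps the attention map available in no-grad passes, while gradient capture still works whenever the forward pass runs with autograd enabled.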

My question is: is it necessary to keep registering the attention gradients in order to accurately calculate the textual and visual contributions?
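For context on the question: a permutation-based Shapley estimator only needs the model's outputs on masked inputs, which matches the call path in the traceback (explain_row → fm(masks, ...) → self.model(...) → np.array(...)). The toy below computes exact Shapley values by averaging each feature's marginal contribution over every feature ordering, entirely under torch.no_grad(). It illustrates the general technique only. It is not shap's PermutationExplainer (which samples orderings and uses a masker) and not MM-SHAP's code, and whether the repository uses the saved attention gradients for anything else is not established here. The function and toy model names are hypothetical.

```python
import itertools
import torch

def shapley_by_permutation(f, x, baseline):
    """Exact Shapley values: average each feature's marginal contribution over all orderings."""
    n = x.shape[0]
    phi = torch.zeros(n)
    orderings = list(itertools.permutations(range(n)))
    with torch.no_grad():  # forward evaluations only; no autograd graph or hooks needed
        for order in orderings:
            current = baseline.clone()  # start from the fully masked input
            prev = f(current)
            for i in order:
                current[i] = x[i]       # unmask feature i
                out = f(current)
                phi[i] += out - prev    # marginal contribution of feature i
                prev = out
    return phi / len(orderings)

# Hypothetical toy model: linear terms plus an interaction between features 0 and 1.
w = torch.tensor([1.0, -2.0, 0.5])

def toy_model(z):
    return (w * z).sum() + 0.1 * z[0] * z[1]

x = torch.tensor([1.0, 2.0, 3.0])
baseline = torch.zeros(3)
phi = shapley_by_permutation(toy_model, x, baseline)
print(phi)  # approximately [1.1, -3.9, 1.5]; the 0.2 interaction is split evenly
print(phi.sum().item(), (toy_model(x) - toy_model(baseline)).item())  # efficiency: equal sums
```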

@ChengYuChuan

Sharing your virtual environment and GPU setup might help diagnose the problem.

@othmane42
Author

othmane42 commented Apr 23, 2024

Thank you for the suggestion, @ChengYuChuan. Here are the GPU and virtual environments I used:

GPU environment:

 nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Fri_Nov__3_17:16:49_PDT_2023
Cuda compilation tools, release 12.3, V12.3.103
Build cuda_12.3.r12.3/compiler.33492891_0

GPU: NVIDIA GeForce RTX 3080 Laptop GPU

Virtual environment:

(I'm using the conda environment provided in the repo, to ensure reproducibility.)

# packages in environment at /home/anaconda3/envs/shap:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main    anaconda
_openmp_mutex             4.5                       1_gnu    anaconda
_py-xgboost-mutex         2.0                       cpu_0    anaconda
abseil-cpp                20210324.2           h9c3ff4c_0    conda-forge
aiohttp                   3.7.4.post0      py36h8f6f2f9_0    conda-forge
argon2-cffi               20.1.0           py36h27cfd23_1    anaconda
arrow-cpp                 3.0.0            py36h6b21186_4    anaconda
async-timeout             3.0.1                   py_1000    conda-forge
async_generator           1.10             py36h28b3542_0    anaconda
attrs                     21.2.0             pyhd8ed1ab_0    conda-forge
autopep8                  1.5.7              pyhd3eb1b0_0    anaconda
aws-c-common              0.4.57               he6710b0_1    anaconda
aws-c-event-stream        0.1.6                h2531618_5    anaconda
aws-checksums             0.1.9                he6710b0_0    anaconda
aws-sdk-cpp               1.8.185              hce553d0_0    anaconda
backports                 1.0                        py_2    anaconda
backports.functools_lru_cache 1.6.4              pyhd8ed1ab_0    conda-forge
blas                      1.0                         mkl    anaconda
bleach                    4.0.0              pyhd3eb1b0_0    anaconda
boost-cpp                 1.69.0            h11c811c_1000    conda-forge
brotli                    1.0.9                h7f98852_5    conda-forge
brotli-bin                1.0.9                h7f98852_5    conda-forge
brotlipy                  0.7.0           py36h27cfd23_1003    anaconda
bzip2                     1.0.8                h7b6447c_0    anaconda
c-ares                    1.17.1               h27cfd23_0    anaconda
ca-certificates           2020.10.14                    0    anaconda
certifi                   2020.6.20                py36_0    anaconda
cffi                      1.14.6           py36h400218f_0    anaconda
chardet                   4.0.0            py36h5fab9bb_1    conda-forge
charset-normalizer        2.0.4              pyhd3eb1b0_0    anaconda
click                     7.1.2              pyh9f0ad1d_0    conda-forge
cloudpickle               2.0.0              pyhd3eb1b0_0    anaconda
configparser              5.2.0              pyhd8ed1ab_0    conda-forge
cryptography              3.4.7            py36hd23ed53_0    anaconda
cudatoolkit               11.1.74              h6bb024c_0    nvidia
cycler                    0.10.0                   py36_0    anaconda
cytoolz                   0.11.0           py36h7b6447c_0    anaconda
dask-core                 2021.3.0           pyhd3eb1b0_0    anaconda
dataclasses               0.8                pyh4f3eec9_6    anaconda
datasets                  1.12.1             pyhd8ed1ab_1    conda-forge
dbus                      1.13.18              hb2f20db_0    anaconda
decorator                 5.1.0              pyhd8ed1ab_0    conda-forge
defusedxml                0.7.1              pyhd3eb1b0_0    anaconda
dill                      0.3.4              pyhd8ed1ab_0    conda-forge
docker-pycreds            0.4.0                      py_0    anaconda
double-conversion         3.1.5                h9c3ff4c_2    conda-forge
entrypoints               0.3             pyhd8ed1ab_1003    conda-forge
expat                     2.4.1                h2531618_2    anaconda
ffmpeg                    4.2.2                h20bf706_0    anaconda
filelock                  3.0.12             pyhd3eb1b0_1    anaconda
fontconfig                2.13.1               h6c09931_0    anaconda
freetype                  2.10.4               h5ab3b9f_0    anaconda
fsspec                    2021.10.0          pyhd8ed1ab_0    conda-forge
gflags                    2.2.2             he1b5a44_1004    conda-forge
gitdb                     4.0.9              pyhd8ed1ab_0    conda-forge
gitpython                 3.1.11                     py_0    conda-forge
glib                      2.69.1               h5202010_0    anaconda
glog                      0.5.0                h48cff8f_0    conda-forge
gmp                       6.2.1                h2531618_2    anaconda
gnutls                    3.6.15               he1e5248_0    anaconda
grpc-cpp                  1.39.0               hae934f6_5    anaconda
gst-plugins-base          1.14.0               h8213a91_2    anaconda
gstreamer                 1.14.0               h28cd5cc_2    anaconda
hdf5                      1.10.2               hba1933b_1    anaconda
huggingface_hub           0.0.17                     py_0    huggingface
icu                       58.2                 he6710b0_3    anaconda
idna                      3.2                pyhd3eb1b0_0    anaconda
idna_ssl                  1.1.0           py36h9f0ad1d_1001    conda-forge
imagehash                 4.2.1              pyhd3eb1b0_0    anaconda
imageio                   2.9.0              pyhd3eb1b0_0    anaconda
importlib-metadata        4.8.1            py36h06a4308_0    anaconda
importlib_metadata        4.8.1                hd3eb1b0_0    anaconda
intel-openmp              2021.3.0          h06a4308_3350    anaconda
ipykernel                 5.5.5            py36hcb3619a_0    conda-forge
ipython                   5.8.0                    py36_1    conda-forge
ipython_genutils          0.2.0                      py_1    conda-forge
ipywidgets                7.6.5              pyhd3eb1b0_1    anaconda
jinja2                    3.0.1              pyhd3eb1b0_0    anaconda
joblib                    1.0.1              pyhd3eb1b0_0    anaconda
jpeg                      9b                   h024ee3a_2
jsonschema                3.2.0              pyhd3eb1b0_2    anaconda
jupyter_client            7.0.6              pyhd8ed1ab_0    conda-forge
jupyter_core              4.8.1            py36h5fab9bb_0    conda-forge
jupyterlab_pygments       0.1.2                      py_0    anaconda
jupyterlab_widgets        1.0.0              pyhd3eb1b0_1    anaconda
kiwisolver                1.3.1            py36h2531618_0    anaconda
krb5                      1.19.2               hcc1bbae_0    conda-forge
lame                      3.100                h7b6447c_0    anaconda
lcms2                     2.12                 h3be6417_0    anaconda
ld_impl_linux-64          2.35.1               h7274673_9    anaconda
libboost                  1.73.0              h3ff78a5_11
libbrotlicommon           1.0.9                h7f98852_5    conda-forge
libbrotlidec              1.0.9                h7f98852_5    conda-forge
libbrotlienc              1.0.9                h7f98852_5    conda-forge
libcurl                   7.78.0               h0b77cf5_0    anaconda
libedit                   3.1.20191231         he28a2e2_2    conda-forge
libev                     4.33                 h516909a_1    conda-forge
libevent                  2.1.10               hcdb4288_3    conda-forge
libffi                    3.3                  he6710b0_2    anaconda
libgcc-ng                 9.3.0               h5101ec6_17    anaconda
libgfortran-ng            7.5.0               ha8ba4b0_17    anaconda
libgfortran4              7.5.0               ha8ba4b0_17    anaconda
libgomp                   9.3.0               h5101ec6_17    anaconda
libidn2                   2.3.2                h7f8727e_0    anaconda
libllvm10                 10.0.1               hbcb73fb_5    anaconda
libnghttp2                1.43.0               h812cca2_0    conda-forge
libopus                   1.3.1                h7b6447c_0    anaconda
libpng                    1.6.37               hbc83047_0    anaconda
libprotobuf               3.17.2               h4ff587b_1    anaconda
libsodium                 1.0.18               h36c2ea0_1    conda-forge
libssh2                   1.9.0                h1ba5d50_1    anaconda
libstdcxx-ng              9.3.0               hd4cf53a_17    anaconda
libtasn1                  4.16.0               h27cfd23_0    anaconda
libthrift                 0.14.2               he6d91bd_1    conda-forge
libtiff                   4.2.0                h85742a9_0
libunistring              0.9.10               h27cfd23_0    anaconda
libuuid                   1.0.3                h1bed415_2    anaconda
libuv                     1.40.0               h7b6447c_0    anaconda
libvpx                    1.7.0                h439df22_0    anaconda
libwebp-base              1.2.0                h27cfd23_0    anaconda
libxcb                    1.14                 h7b6447c_0    anaconda
libxgboost                1.3.3                h2531618_0    anaconda
libxml2                   2.9.12               h03d6c58_0    anaconda
llvmlite                  0.36.0           py36h612dafd_4    anaconda
lz4-c                     1.9.3                h295c915_1    anaconda
markupsafe                2.0.1            py36h27cfd23_0    anaconda
matplotlib                3.3.4            py36h06a4308_0    anaconda
matplotlib-base           3.3.4            py36h62a2d02_0    anaconda
mistune                   0.8.4            py36h7b6447c_0    anaconda
mkl                       2020.2                      256    anaconda
mkl-service               2.3.0            py36he8ac12f_0
mkl_fft                   1.3.0            py36h54f3939_0
mkl_random                1.1.1            py36h0573a6f_0    anaconda
multidict                 5.1.0            py36h27cfd23_2    anaconda
multiprocess              0.70.12.2        py36h8f6f2f9_0    conda-forge
nbclient                  0.5.3              pyhd3eb1b0_0    anaconda
nbconvert                 6.0.7                    py36_0    anaconda
nbformat                  5.1.3              pyhd3eb1b0_0    anaconda
ncurses                   6.2                  he6710b0_1    anaconda
nest-asyncio              1.5.1              pyhd8ed1ab_0    conda-forge
nettle                    3.7.3                hbbd107a_1    anaconda
networkx                  2.5                        py_0    anaconda
ninja                     1.10.2               hff7bd54_1    anaconda
notebook                  6.3.0            py36h06a4308_0    anaconda
numba                     0.53.1           py36ha9443f7_0    anaconda
numpy                     1.19.2           py36h54aff64_0
numpy-base                1.19.2           py36hfa32c7d_0
olefile                   0.46                     py36_0    anaconda
opencv                    3.4.1            py36h6fd60c2_1    anaconda
opencv-python             4.5.3.56                 pypi_0    pypi
openh264                  2.1.0                hd408876_0    anaconda
openjpeg                  2.4.0                h3ad879b_0    anaconda
openssl                   1.1.1n               h7f8727e_0    anaconda
orc                       1.6.9                ha97a36c_3    anaconda
packaging                 21.0               pyhd3eb1b0_0    anaconda
pandas                    1.1.5            py36ha9443f7_0    anaconda
pandoc                    2.12                 h06a4308_0    anaconda
pandocfilters             1.4.3            py36h06a4308_1    anaconda
pathtools                 0.1.2                      py_1    anaconda
pcre                      8.45                 h295c915_0    anaconda
pexpect                   4.8.0              pyh9f0ad1d_2    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    8.3.1            py36h2c7a002_0    anaconda
pip                       21.2.2           py36h06a4308_0    anaconda
prometheus_client         0.11.0             pyhd3eb1b0_0    anaconda
promise                   2.3              py36h5fab9bb_4    conda-forge
prompt_toolkit            1.0.15                     py_1    conda-forge
protobuf                  3.17.2           py36h295c915_0    anaconda
psutil                    5.8.0            py36h27cfd23_1    anaconda
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
py-xgboost                1.3.3            py36h06a4308_0    anaconda
pyarrow                   3.0.0            py36he0739d4_3    anaconda
pycodestyle               2.7.0              pyhd3eb1b0_0    anaconda
pycparser                 2.20                       py_2    anaconda
pygments                  2.10.0             pyhd8ed1ab_0    conda-forge
pyopenssl                 20.0.1             pyhd3eb1b0_1    anaconda
pyparsing                 2.4.7              pyhd3eb1b0_0    anaconda
pyqt                      5.9.2            py36h05f1152_2    anaconda
pyrsistent                0.17.3           py36h7b6447c_0    anaconda
pysocks                   1.7.1            py36h06a4308_0    anaconda
python                    3.6.13               h12debd9_1    anaconda
python-dateutil           2.8.2              pyhd3eb1b0_0    anaconda
python-wget               3.2                        py_0    conda-forge
python-xxhash             2.0.2            py36h8f6f2f9_0    conda-forge
python_abi                3.6                     1_cp36m    huggingface
pytorch                   1.9.1           py3.6_cuda11.1_cudnn8.0.5_0    pytorch
pytz                      2021.1             pyhd3eb1b0_0    anaconda
pywavelets                1.1.1            py36h7b6447c_2    anaconda
pyyaml                    5.4.1            py36h27cfd23_1    anaconda
pyzmq                     19.0.2           py36h9947dbf_2    conda-forge
qt                        5.9.7                h5867ecd_1
re2                       2021.08.01           h9c3ff4c_0    conda-forge
readline                  8.1                  h27cfd23_0    anaconda
regex                     2021.8.3         py36h7f8727e_0    anaconda
requests                  2.26.0             pyhd3eb1b0_0    anaconda
ruamel_yaml               0.15.87          py36h7b6447c_1    anaconda
sacremoses                master                     py_0    huggingface
scikit-image              0.17.2           py36hdf5156a_0    anaconda
scikit-learn              0.24.2           py36ha9443f7_0    anaconda
scipy                     1.5.2            py36h0b6359f_0
send2trash                1.8.0              pyhd3eb1b0_1    anaconda
sentry-sdk                1.5.4              pyhd8ed1ab_0    conda-forge
setuptools                58.0.4           py36h06a4308_0    anaconda
shortuuid                 1.0.1                      py_0    conda-forge
simplegeneric             0.8.1                      py_1    conda-forge
sip                       4.19.8           py36hf484d3e_0    anaconda
six                       1.16.0             pyhd3eb1b0_0    anaconda
slicer                    0.0.7              pyhd8ed1ab_0    conda-forge
smmap                     3.0.5              pyh44b312d_0    conda-forge
snappy                    1.1.8                he1b5a44_3    conda-forge
sqlite                    3.36.0               hc218d9a_0    anaconda
subprocess32              3.5.4                      py_1    anaconda
tbb                       2020.3               hfd86e86_0    anaconda
termcolor                 1.1.0                      py_2    conda-forge
terminado                 0.9.4            py36h06a4308_0    anaconda
testpath                  0.5.0              pyhd3eb1b0_0    anaconda
threadpoolctl             2.2.0              pyh0d69192_0    anaconda
tifffile                  2020.10.1        py36hdd07704_2    anaconda
timm                      0.5.4                    pypi_0    pypi
tk                        8.6.11               h1ccaba5_0    anaconda
tokenizers                0.10.3                   py36_0    huggingface
toml                      0.10.2             pyhd3eb1b0_0    anaconda
toolz                     0.11.2             pyhd3eb1b0_0    anaconda
torchaudio                0.9.1                      py36    pytorch
torchvision               0.10.1               py36_cu111    pytorch
tornado                   6.1              py36h27cfd23_0    anaconda
tqdm                      4.62.2             pyhd3eb1b0_1    anaconda
traitlets                 4.3.3              pyhd8ed1ab_2    conda-forge
transformers              4.11.1                     py_0    huggingface
typing-extensions         3.10.0.2             hd3eb1b0_0    anaconda
typing_extensions         3.10.0.2           pyh06a4308_0    anaconda
uriparser                 0.9.3                he1b5a44_1    conda-forge
urllib3                   1.26.6             pyhd3eb1b0_1    anaconda
utf8proc                  2.6.1                h27cfd23_0    anaconda
wandb                     0.12.10            pyhd8ed1ab_0    conda-forge
wcwidth                   0.2.5              pyh9f0ad1d_2    conda-forge
webencodings              0.5.1                    py36_1    anaconda
wheel                     0.37.0             pyhd3eb1b0_1    anaconda
widgetsnbextension        3.5.1                    py36_0    anaconda
x264                      1!157.20191217       h7b6447c_0    anaconda
xgboost                   1.3.3            py36h06a4308_0    anaconda
xxhash                    0.8.0                h7f98852_3    conda-forge
xz                        5.2.5                h7b6447c_0    anaconda
yaml                      0.2.5                h7b6447c_0    anaconda
yarl                      1.6.3            py36h8f6f2f9_2    conda-forge
yaspin                    2.1.0              pyhd8ed1ab_0    conda-forge
zeromq                    4.3.4                h9c3ff4c_0    conda-forge
zipp                      3.5.0              pyhd3eb1b0_0    anaconda
zlib                      1.2.11               h7b6447c_3    anaconda
zstd                      1.4.9                haebb681_0    anaconda
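For reporting, the versions that matter to PyTorch can also be printed from Python. Note that nvcc reports the system CUDA toolkit (12.3 above), whereas the pytorch package in the list was built against CUDA 11.1 (py3.6_cuda11.1_cudnn8.0.5_0); torch.version.cuda reports the latter. PyTorch also ships python -m torch.utils.collect_env for a fuller report. A minimal sketch, with env_report as a hypothetical helper name:

```python
import platform
import torch

def env_report():
    """Collect the Python, PyTorch, and CUDA details relevant to this kind of issue."""
    info = {
        "python": platform.python_version(),
        "torch": torch.__version__,
        "torch_cuda": torch.version.cuda,  # CUDA version PyTorch was built with (None on CPU builds)
        "cuda_available": torch.cuda.is_available(),
    }
    if info["cuda_available"]:
        info["device"] = torch.cuda.get_device_name(0)
    return info

for key, value in env_report().items():
    print(f"{key}: {value}")
```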
