Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrated Gradients Notebook - np.ndarray kwargs but only tf.Tensor supported #528

Merged
merged 5 commits into from Nov 16, 2021

Conversation

RobertSamoilescu
Copy link
Collaborator

This PR fixes issue #527.
The newer versions of transformers do not support np.ndarray for optional arguments (i.e., attention_mask).
The error is fixed by casting np.darray to tf.Tensor before passing to explain or forward methods.

# the values of the kwargs have to be `tf.Tensor`. 
# see transformers issue #14404: https://github.com/huggingface/transformers/issues/14404
kwargs = {k: tf.constant(v) for k,v in z_test_sample.items() if k == 'attention_mask'}

In addition, I included import matplotlib.cm as it is required for matplotlib >= 3.4.2.

Also used a smaller validation dataset as when using the full testing dataset I ran into GPU memory issues.

    # using the entire testing dataset might result in memory issues when running on GPU
    model_out.fit(train_embbedings, y_train, 
                  validation_data=(test_embbedings[:100], y_test[:100]),
                  epochs=epochs, 
                  batch_size=batch_size,
                  callbacks=[cp_callback],
                  verbose=1)

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@codecov
Copy link

codecov bot commented Nov 15, 2021

Codecov Report

Merging #528 (cc9522c) into master (6100135) will not change coverage.
The diff coverage is n/a.

Impacted file tree graph

@@           Coverage Diff           @@
##           master     #528   +/-   ##
=======================================
  Coverage   82.34%   82.34%           
=======================================
  Files          76       76           
  Lines       10334    10334           
=======================================
  Hits         8510     8510           
  Misses       1824     1824           

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants