This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Compute attributions w.r.t the predicted logit, not the predicted loss #4882

Open
wants to merge 1 commit into
base: main

Conversation

sarahwie

Compute gradient attributions with respect to the predicted class's logit, rather than the loss, to avoid making the gradient depend on how close the loss is to 0: when the loss is near 0, the gradient is near 0 as well.

See for justification:
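The vanishing-gradient issue can be seen in a toy example (a sketch with an illustrative two-class linear model, not the PR's code): when the prediction is confident and correct, the cross-entropy loss is already near 0, so its gradient with respect to the input vanishes, while the gradient of the predicted logit stays informative.

```python
import torch
import torch.nn.functional as F

# Toy 2-class linear model; the prediction for class 0 is very confident.
x = torch.tensor([[3.0, -3.0]], requires_grad=True)
W = torch.tensor([[4.0, -4.0], [-4.0, 4.0]])
logits = x @ W                                  # tensor([[24., -24.]])
label = torch.tensor([0])

# Gradient of the loss: the loss is ~0, so the gradient is ~0 too.
loss = F.cross_entropy(logits, label)
loss.backward(retain_graph=True)
loss_grad = x.grad.clone()

# Gradient of the predicted class's logit: non-zero and informative.
x.grad = None
predicted_logit = logits[0, int(torch.argmax(logits))]
predicted_logit.backward()
logit_grad = x.grad.clone()
```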

@sarahwie
Author

Realized this breaks Hotflip, since that relies on the loss. Also, I am not sure how the input-reduction method is intended to be calculated (the paper just says "outputs"), but it will change with this too.

@schmmd schmmd changed the base branch from master to main December 23, 2020 18:47
```diff
@@ -113,13 +113,13 @@ def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict
             self._model.forward(**dataset_tensor_dict)  # type: ignore
         )

-        loss = outputs["loss"]
+        predicted_logit = outputs["logits"].squeeze(0)[int(torch.argmax(outputs["probs"]))]
```
Contributor

The trouble with doing it this way is that it hard-codes assumptions about the model's outputs which may not be true. The test failure you're getting is because of this. This method has to be generic enough to work for any model. This is ok when we query the loss key, because that key is already required by the Trainer. Nothing else is guaranteed to be in the output, so we can't hard-code anything else.

Maybe a better way of accomplishing what you want is to allow the caller to specify the output key, with a default value of "loss". Then it would be the model's responsibility to make sure that the value under that key is a single number on which we can call .backward(). E.g., you could imagine adding a target_logit key in your model class, and then using that key when calling get_gradients().
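The suggestion above might look roughly like this (a minimal sketch; `backward_from_key` and the `target_logit` entry are hypothetical names for illustration, not AllenNLP's actual API):

```python
from typing import Any, Dict

import torch


def backward_from_key(outputs: Dict[str, Any], key: str = "loss") -> None:
    # Backpropagate from a caller-chosen output key, defaulting to "loss".
    # The model is responsible for ensuring outputs[key] is a scalar tensor;
    # e.g. a model's forward() could add a "target_logit" entry, and the
    # caller would then pass key="target_logit".
    value = outputs[key]
    if value.dim() != 0:
        raise ValueError(f"outputs[{key!r}] must be a scalar to call .backward()")
    value.backward()


x = torch.tensor([2.0, -1.0], requires_grad=True)
outputs = {"loss": (x ** 2).sum(), "target_logit": x.max()}
backward_from_key(outputs, key="target_logit")  # gradient of the max logit
```

Keeping "loss" as the default preserves the current behavior for every existing model, so only models that opt in need to populate the extra key.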

We could get by with less model modification if we add a second flag that says whether to take an argmax of the values in that key, but that gets a bit messy, because then you're always getting gradients of the model's prediction, completely ignoring whatever label was given in the input instance. This breaks a lot of assumptions in other methods in the code (which I think is what you were referring to when you said this breaks hotflip), so I don't really like this option.

Author

Thanks for the feedback! I agree that using a key is straightforward. I'll refactor.

@dirkgr
Member

dirkgr commented Jan 14, 2021

Is this still an active project? Can we help in any way?
