How to handle multi-label multiclass? #26
I think you can hack the script backprop.py to generate the target vector you are looking for:
and change that to something like:
Then you also have to change the parent functions that call this function so that you can pass in your own target vector, but this should be easy. That way you can set an arbitrary target vector, which can also contain multiple "1" entries. This might achieve what you are looking for. I am not deep enough into the math behind saliency maps, though. I hope @MisaOgura can comment on this.
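The hack described above boils down to calling `backward` on the model output with a target vector of your choosing, instead of a one-hot vector for the predicted class. Below is a minimal PyTorch sketch of that idea, not flashtorch's actual code: `gradients_for_target` is a hypothetical helper, and `nn.Linear(8, 4)` stands in for a real network. Because `output.backward(gradient=target)` computes a vector-Jacobian product, a multi-hot target yields the sum of the per-class gradients.

```python
import torch
import torch.nn as nn


def gradients_for_target(model: nn.Module, input_: torch.Tensor,
                         target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: gradient of sum(target * output) w.r.t. input_.

    `target` must have the same shape as the model output; it may be
    one-hot (a single class) or multi-hot (several "1" entries).
    """
    # Work on a fresh leaf tensor so the gradient lands on it
    input_ = input_.clone().detach().requires_grad_(True)
    model.zero_grad()
    output = model(input_)
    if output.shape != target.shape:
        raise ValueError(f"target shape {tuple(target.shape)} does not match "
                         f"output shape {tuple(output.shape)}")
    # Vector-Jacobian product: each output unit is weighted by its target entry
    output.backward(gradient=target)
    return input_.grad.detach()


torch.manual_seed(0)
# Stand-in network (assumption): 8 input features, 4 output units
model = nn.Linear(8, 4)
x = torch.randn(1, 8)
# Multi-hot target: request classes 1 and 3 at the same time
target = torch.tensor([[0.0, 1.0, 0.0, 1.0]])
grads = gradients_for_target(model, x, target)
print(grads.shape)  # torch.Size([1, 8])
```

For a linear stand-in the result is just the sum of the weight rows for the selected classes, which makes the multi-hot behaviour easy to check before trying it on a real model.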
Hi @rrags, apologies for the late reply, and thanks @dnns92 for jumping in. The current behaviour is that if … Allowing users more control over which class to visualise is something I have been wanting to revisit; please expect an update on this. @dnns92, the suggested modification would make sense when …
Okay, thanks. This is pretty much what I was doing; I just wanted to make sure I was on the right track and to get your insight.
My network performs binary classification to detect 11 different classes. For example, it predicts whether or not there are apples, oranges, pineapples, and pears in the image I give it. So the output is a binary vector of length 11.
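In a multi-label setup like this, each output is an independent yes/no score, so several classes can be positive at once, whereas a single "top class" (the argmax) picks exactly one index. A minimal stdlib sketch of the difference, assuming the network emits raw logits that are passed through a sigmoid and thresholded at 0.5 (both assumptions); the four labels and the logit values are made up for illustration:

```python
import math

# Hypothetical label order for the 4-class example (the real network has 11)
LABELS = ["apple", "orange", "pineapple", "pear"]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def decode_multilabel(logits, labels, threshold=0.5):
    """Return every label whose sigmoid probability reaches the threshold."""
    probs = [sigmoid(z) for z in logits]
    return [name for name, p in zip(labels, probs) if p >= threshold]


def top_class(logits):
    """Single-label style prediction: index of the largest logit."""
    return max(range(len(logits)), key=lambda i: logits[i])


logits = [2.1, -1.3, 0.8, 3.0]  # made-up raw outputs for one image
print(decode_multilabel(logits, LABELS))  # ['apple', 'pineapple', 'pear']
print(top_class(logits))                  # 3
```

Three labels are positive here, but an argmax-based "top class" would only ever report index 3, which is why a single predicted-class index is a poor fit for multi-label outputs.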
Can I use this project without modifying it?
I have checked, and top_label is always 9, so whatever I pass as target_label is ignored. I get this warning:
The predicted class index 9 does notequal the target class index 3. Calculatingthe gradient w.r.t. the predicted class. 'the gradient w.r.t. the predicted class.'
So in my example, if the network outputs a binary vector of length 4 corresponding to whether or not apple, orange, pineapple, and pear are in the image, how can I make it so that when I set target = 3 the code will show the gradient corresponding to the task of detecting pineapples?
Also, I am using a modified ResNet-18 (transfer-learned).
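One way to sidestep the predicted-class fallback for this question is to backpropagate only the output unit you care about. The sketch below is a hedged illustration, not part of flashtorch: `class_saliency` is a hypothetical helper, the small `nn.Sequential` model stands in for the fine-tuned ResNet-18, and taking the maximum absolute gradient over colour channels is the common vanilla-saliency convention, assumed here.

```python
import torch
import torch.nn as nn


def class_saliency(model: nn.Module, image: torch.Tensor, class_idx: int) -> torch.Tensor:
    """Hypothetical helper: saliency map for one output unit of a multi-label model.

    Backpropagates only output[0, class_idx], regardless of which class
    scores highest, and returns max |gradient| over colour channels.
    """
    model.eval()  # matters for a real ResNet (batch norm); harmless here
    image = image.clone().detach().requires_grad_(True)
    model.zero_grad()
    output = model(image)  # shape (1, num_labels): one logit per label
    # One-hot target: only the chosen label contributes to the gradient
    target = torch.zeros_like(output)
    target[0, class_idx] = 1.0
    output.backward(gradient=target)
    # (1, C, H, W) -> (H, W): strongest absolute gradient across channels
    return image.grad.detach().abs().amax(dim=1)[0]


torch.manual_seed(0)
# Stand-in (assumption) for a fine-tuned ResNet-18 with a 4-unit multi-label head
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 4),
)
image = torch.randn(1, 3, 32, 32)
saliency = class_saliency(model, image, class_idx=2)  # 2 = "pineapple" if 0-based
print(saliency.shape)  # torch.Size([32, 32])
```

If the labels are zero-indexed in the order apple, orange, pineapple, pear, then pineapple is index 2 and index 3 is pear. Since the helper only assumes that `model(image)` returns one logit per label, it should work the same way with a ResNet-18 whose final layer has been replaced by an 11-unit head.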