Added NewClassesCrossEntropy criterion and automatic criterion plugin #1514
Changes from 3 commits
@@ -9,12 +9,17 @@
 from avalanche.models import MultiTaskModule, avalanche_forward


-def cross_entropy_with_oh_targets(outputs, targets, eps=1e-5):
+def cross_entropy_with_oh_targets(outputs, targets, eps=1e-5, reduction="mean"):
     """Calculates cross-entropy with temperature scaling,
     targets can also be soft targets but they must sum to 1"""
     outputs = torch.nn.functional.softmax(outputs, dim=1)
     ce = -(targets * outputs.log()).sum(1)
-    ce = ce.mean()
+    if reduction == "mean":
+        ce = ce.mean()
+    elif reduction == "none":
+        return ce
+    else:
+        raise NotImplementedError("reduction must be mean or none")
     return ce

Review comments on the changed signature:

Reviewer: Can you use the PyTorch cross-entropy? It is more numerically stable than this implementation.

Author: The problem is that I need to pass one-hot targets; that makes the masking much easier and more natural to implement. The cross-entropy from PyTorch does not allow that, I believe.
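The patched helper can be exercised standalone. The sketch below reproduces the function as it appears in the diff and shows how `reduction="none"` yields per-sample losses while `reduction="mean"` yields their average (`eps` is kept in the signature as in the original, although the body shown does not use it):

```python
import torch


def cross_entropy_with_oh_targets(outputs, targets, eps=1e-5, reduction="mean"):
    """Cross-entropy with one-hot (or soft) targets that sum to 1."""
    outputs = torch.nn.functional.softmax(outputs, dim=1)
    # Per-sample cross-entropy against the target distribution.
    ce = -(targets * outputs.log()).sum(1)
    if reduction == "mean":
        ce = ce.mean()
    elif reduction == "none":
        return ce
    else:
        raise NotImplementedError("reduction must be mean or none")
    return ce


logits = torch.tensor([[2.0, 0.5], [0.1, 3.0]])
oh_targets = torch.tensor([[1.0, 0.0], [0.0, 1.0]])

per_sample = cross_entropy_with_oh_targets(logits, oh_targets, reduction="none")
mean_loss = cross_entropy_with_oh_targets(logits, oh_targets, reduction="mean")
```

With `reduction="none"` the result has one entry per sample, which is what makes per-unit masking in the new criterion straightforward; any other reduction string raises `NotImplementedError`.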
Reviewer: This is incorrect. You are selecting only "seen" units, which excludes future units (e.g. DER regularizes them).

Reviewer: I think it's better to have both "seen" and "all" options (not necessarily for all losses).
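To illustrate the distinction the reviewers are discussing, here is a hedged sketch (the `masked_logits` helper and its names are hypothetical, not Avalanche's actual API): a "seen" mode suppresses logits of classes not yet encountered, while an "all" mode leaves every unit active, so that methods like DER can still regularize future units.

```python
import torch


def masked_logits(logits, active_units, mode="seen"):
    # Hypothetical helper: "seen" keeps only the units observed so far;
    # "all" keeps every unit, so future units still receive gradient
    # (e.g. for DER-style regularization).
    if mode == "all":
        return logits
    mask = torch.full_like(logits, float("-inf"))
    mask[:, active_units] = 0.0
    return logits + mask


logits = torch.randn(4, 10)   # 10 output units in total
seen = [0, 1, 2, 3]           # units seen in the tasks so far

seen_only = masked_logits(logits, seen, mode="seen")
all_units = masked_logits(logits, seen, mode="all")
```

In "seen" mode the unseen columns become `-inf`, so a subsequent softmax assigns them zero probability; in "all" mode the logits pass through untouched.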