Skip to content

Commit

Permalink
Merge pull request #10 from Erotemic/dev/0.1.1
Browse files Browse the repository at this point in the history
Dev/0.1.1
  • Loading branch information
Erotemic committed Dec 1, 2018
2 parents 5d73842 + 7c054f0 commit 6d8c935
Show file tree
Hide file tree
Showing 27 changed files with 1,248 additions and 248 deletions.
401 changes: 401 additions & 0 deletions examples/ggr2_matching.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion netharn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.1.0'
__version__ = '0.1.1'
"""
mkinit netharn --noattrs --dry
"""
Expand Down
25 changes: 19 additions & 6 deletions netharn/criterions/focal.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,19 @@ def _backwards_compat_reduction_kw(size_average, reduce, reduction):
raise Exception(
'Must specify both size_average and reduce in '
'torch < 0.4.1, or specify neither and use reduction')
return size_average, reduce
else:
if not size_average and not reduce:
reduction = 'none'
elif size_average and not reduce:
reduction = 'none'
elif size_average and reduce:
reduction = 'elementwise_mean'
elif not size_average and reduce:
reduction = 'sum'
else:
raise ValueError(
'Impossible combination of size_average and reduce')
return size_average, reduce, reduction


class FocalLoss(torch.nn.modules.loss._WeightedLoss):
Expand Down Expand Up @@ -189,14 +201,15 @@ class FocalLoss(torch.nn.modules.loss._WeightedLoss):
def __init__(self, focus=2, weight=None, size_average=None, reduce=None,
reduction='elementwise_mean', ignore_index=-100):

size_average, reduce, reduction = _backwards_compat_reduction_kw(
size_average, reduce, reduction)
if _HAS_REDUCTION:
super(FocalLoss, self).__init__(weight=weight, reduce=reduce,
size_average=size_average,
super(FocalLoss, self).__init__(weight=weight,
# reduce=reduce,
# size_average=size_average,
reduction=reduction)
else:
size_average, reduce = _backwards_compat_reduction_kw(
size_average, reduce, reduction)
super(FocalLoss, self).__init__(weight=weight,
super(FocalLoss, self).__init__(weight=weight, reduce=reduce,
size_average=size_average)
self.size_average = size_average # fix for travis?
self.reduce = reduce
Expand Down

0 comments on commit 6d8c935

Please sign in to comment.