
Commit 745a201
Fix bug of CategoricalFocalLoss #20
jackguagua committed Jul 22, 2020
1 parent 6d54926 commit 745a201
Showing 2 changed files with 5 additions and 1 deletion.
2 changes: 1 addition & 1 deletion deeptables/models/layers.py
@@ -986,7 +986,7 @@ def __init__(self, gamma=2., alpha=.25, reduction=losses.Reduction.AUTO, name='f
      gamma {float} -- (default: {2.0})
      alpha {float} -- (default: {4.0})
      """
-     super(BinaryFocalLoss, self).__init__(reduction=reduction, name=name)
+     super(CategoricalFocalLoss, self).__init__(reduction=reduction, name=name)
      self.gamma = float(gamma)
      self.alpha = float(alpha)

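The one-line change above fixes a copy-paste error: CategoricalFocalLoss.__init__ was calling super() with BinaryFocalLoss as its first argument. A minimal sketch of why that fails is below, assuming both classes subclass tf.keras.losses.Loss (as the reduction/name arguments suggest); the class bodies are illustrative, not the deeptables implementation.

# Minimal sketch of the failure mode this commit fixes (illustrative only).
from tensorflow.keras import losses


class BinaryFocalLoss(losses.Loss):
    def __init__(self, gamma=2., alpha=.25,
                 reduction=losses.Reduction.AUTO, name='focal_loss'):
        super(BinaryFocalLoss, self).__init__(reduction=reduction, name=name)
        self.gamma = float(gamma)
        self.alpha = float(alpha)


class CategoricalFocalLoss(losses.Loss):
    def __init__(self, gamma=2., alpha=.25,
                 reduction=losses.Reduction.AUTO, name='focal_loss'):
        # Before the fix this line read `super(BinaryFocalLoss, self).__init__(...)`.
        # Because a CategoricalFocalLoss instance is not a BinaryFocalLoss, Python
        # raises "TypeError: super(type, obj): obj must be an instance or subtype
        # of type", so the loss could never be constructed.
        super(CategoricalFocalLoss, self).__init__(reduction=reduction, name=name)
        self.gamma = float(gamma)
        self.alpha = float(alpha)


loss = CategoricalFocalLoss()  # constructs cleanly after the fix

In Python 3 the zero-argument form super().__init__(...) avoids this class of copy-paste bug entirely.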
4 changes: 4 additions & 0 deletions deeptables/preprocessing/transformer.py
@@ -7,6 +7,10 @@
  from sklearn.utils import column_or_1d
  from ..utils import dt_logging, consts
 
+ from sklearn.pipeline import Pipeline
+ from sklearn.base import BaseEstimator,TransformerMixin
+ from sklearn.pipeline import FeatureUnion
+
  logger = dt_logging.get_logger()


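The new imports bring in the standard scikit-learn extension points for custom preprocessing. As a rough illustration of how they fit together (the transformer and names below are hypothetical, not part of deeptables), a class deriving from BaseEstimator and TransformerMixin only needs fit and transform, and can then be composed with Pipeline and FeatureUnion:

# Hypothetical example of the imported sklearn building blocks in use.
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.preprocessing import StandardScaler


class Log1pTransformer(BaseEstimator, TransformerMixin):
    """Hypothetical transformer: applies log1p to non-negative numeric features."""

    def fit(self, X, y=None):
        return self          # stateless, nothing to learn

    def transform(self, X):
        return np.log1p(X)


# FeatureUnion concatenates the outputs of its branches column-wise;
# Pipeline chains steps sequentially.
union = FeatureUnion([
    ('log', Log1pTransformer()),
    ('scaled', StandardScaler()),
])
pipe = Pipeline([('features', union)])

X = np.abs(np.random.RandomState(0).randn(8, 3))
Xt = pipe.fit_transform(X)   # shape (8, 6): 3 log1p columns + 3 scaled columns

TransformerMixin supplies fit_transform for free once fit and transform are defined, which is what makes such classes drop into Pipeline and FeatureUnion directly.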
