diff --git a/focal_loss.py b/focal_loss.py index 28b9e3e..b0a4175 100644 --- a/focal_loss.py +++ b/focal_loss.py @@ -51,6 +51,7 @@ def __init__(self, def __repr__(self): arg_keys = ['alpha', 'gamma', 'ignore_index', 'reduction'] arg_vals = [self.__dict__[k] for k in arg_keys] + arg_vals = [f'\'{v}\'' if isinstance(v, str) else v for v in arg_vals] arg_strs = [f'{k}={v}' for k, v in zip(arg_keys, arg_vals)] arg_str = ', '.join(arg_strs) return f'{type(self).__name__}({arg_str})'