Skip to content

Commit

Permalink
Add binary_cross_entropy in functional.frontends.torch (#2310)
Browse files Browse the repository at this point in the history
* add frontend.torch.loss_functions and BCE

* add test_loss_functions and edit bce

* revert nn.loss_functions to loss_functions

* Edit formatting and edit test code

* Edit formatting

* Update test_loss_functions.py

Update fn_name in helpers.test_frontend_function to fn_tree

* Delete statistical.py

* Update test_loss_functions.py

* Revert "Delete statistical.py"

* Update loss_functions and test code

* Update loss_functions formatting

* Update test exclude_min and max

* Update reviewed code and test code

Co-authored-by: jiahanxie353 <765130715@qq.com>
  • Loading branch information
whitepurple and jiahanxie353 committed Aug 22, 2022
1 parent 202e97f commit f648e9e
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 30 deletions.
55 changes: 55 additions & 0 deletions ivy/functional/frontends/torch/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,41 @@
import ivy


def _get_reduction_func(reduction):
    """Return the reduction callable for a torch-style ``reduction`` string.

    ``'none'`` maps to identity, ``'mean'`` to ``ivy.mean`` and ``'sum'``
    to ``ivy.sum``; any other value raises ``ValueError``.
    """
    reductions = {
        'none': lambda x: x,
        'mean': ivy.mean,
        'sum': ivy.sum,
    }
    if reduction not in reductions:
        raise ValueError("{} is not a valid value for reduction".format(reduction))
    return reductions[reduction]


def _legacy_get_string(size_average, reduce):
if size_average is None:
size_average = True
if reduce is None:
reduce = True
if size_average and reduce:
ret = 'mean'
elif reduce:
ret = 'sum'
else:
ret = 'none'
return ret


def _get_reduction(reduction,
                   size_average=None,
                   reduce=None):
    """Resolve the effective reduction callable.

    The legacy ``size_average``/``reduce`` flags take precedence over
    ``reduction`` whenever either of them is explicitly supplied.
    """
    if size_average is None and reduce is None:
        return _get_reduction_func(reduction)
    return _get_reduction_func(_legacy_get_string(size_average, reduce))


def cross_entropy(
input,
target,
Expand All @@ -16,3 +51,23 @@ def cross_entropy(


cross_entropy.unsupported_dtypes = ("float16",)


def binary_cross_entropy(
    input,
    target,
    weight=None,
    size_average=None,
    reduce=None,
    reduction='mean'
):
    """Torch frontend for binary cross entropy.

    Delegates to ``ivy.binary_cross_entropy`` (note ivy's ``(true, pred)``
    argument order), optionally rescales the per-element loss by
    ``weight``, then applies the reduction resolved from ``reduction`` and
    the legacy ``size_average``/``reduce`` flags.
    """
    reduce_fn = _get_reduction(reduction, size_average, reduce)
    # epsilon=0.0: no clipping is applied, matching torch semantics
    loss = ivy.binary_cross_entropy(target, input, epsilon=0.0)
    if weight is not None:
        loss = ivy.multiply(weight, loss)
    return reduce_fn(loss)


binary_cross_entropy.unsupported_dtypes = ('float16', 'float64')
152 changes: 122 additions & 30 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,34 @@
set(ivy_np.valid_float_dtypes).intersection(
set(ivy_torch.valid_float_dtypes)
)
),
min_value=0,
max_value=1,
allow_inf=False,
min_num_dims=2,
max_num_dims=2,
min_dim_size=1,
),
min_value=0,
max_value=1,
allow_inf=False,
min_num_dims=2,
max_num_dims=2,
min_dim_size=1,
),
dtype_and_target=helpers.dtype_and_values(
available_dtypes=tuple(
set(ivy_np.valid_float_dtypes).intersection(
set(ivy_torch.valid_float_dtypes)
)
),
min_value=1.0013580322265625e-05,
max_value=1,
allow_inf=False,
exclude_min=True,
exclude_max=True,
min_num_dims=1,
max_num_dims=1,
min_dim_size=2,
),
as_variable=helpers.list_of_length(x=st.booleans(), length=2),
),
min_value=1.0013580322265625e-05,
max_value=1,
allow_inf=False,
exclude_min=True,
exclude_max=True,
min_num_dims=1,
max_num_dims=1,
min_dim_size=2,
),
as_variable=helpers.list_of_length(x=st.booleans(), length=2),
num_positional_args=helpers.num_positional_args(
fn_name="ivy.functional.frontends.torch.cross_entropy"
),
native_array=helpers.list_of_length(x=st.booleans(), length=2),
),
native_array=helpers.list_of_length(x=st.booleans(), length=2),
)
def test_torch_cross_entropy(
dtype_and_input,
Expand All @@ -54,14 +54,106 @@ def test_torch_cross_entropy(
inputs_dtype, input = dtype_and_input
target_dtype, target = dtype_and_target
helpers.test_frontend_function(
input_dtypes=[inputs_dtype, target_dtype],
as_variable_flags=as_variable,
with_out=False,
num_positional_args=num_positional_args,
native_array_flags=native_array,
fw=fw,
frontend="torch",
fn_tree="nn.functional.cross_entropy",
input=np.asarray(input, dtype=inputs_dtype),
target=np.asarray(target, dtype=target_dtype),
input_dtypes=[inputs_dtype, target_dtype],
as_variable_flags=as_variable,
with_out=False,
num_positional_args=num_positional_args,
native_array_flags=native_array,
fw=fw,
frontend="torch",
fn_tree="nn.functional.cross_entropy",
input=np.asarray(input, dtype=inputs_dtype),
target=np.asarray(target, dtype=target_dtype),
)


# binary_cross_entropy
# Hypothesis-driven test for the torch frontend's binary_cross_entropy:
# draws targets in (0, 1), predictions in (~1e-5, 1) and a per-element
# weight vector, all 1-D with at least 2 elements, restricted to float
# dtypes valid in both numpy and torch backends.
@given(
    dtype_and_true=helpers.dtype_and_values(
        available_dtypes=tuple(
            set(ivy_np.valid_float_dtypes).intersection(
                set(ivy_torch.valid_float_dtypes)
            )
        ),
        min_value=0.0,
        max_value=1.0,
        large_value_safety_factor=1.0,
        small_value_safety_factor=1.0,
        allow_inf=False,
        exclude_min=True,
        exclude_max=True,
        min_num_dims=1,
        max_num_dims=1,
        min_dim_size=2,
    ),
    dtype_and_pred=helpers.dtype_and_values(
        available_dtypes=tuple(
            set(ivy_np.valid_float_dtypes).intersection(
                set(ivy_torch.valid_float_dtypes)
            )
        ),
        # lower bound kept away from 0 so log(pred) stays finite
        min_value=1.0013580322265625e-05,
        max_value=1.0,
        large_value_safety_factor=1.0,
        small_value_safety_factor=1.0,
        allow_inf=False,
        exclude_min=True,
        exclude_max=True,
        min_num_dims=1,
        max_num_dims=1,
        min_dim_size=2,
    ),
    dtype_and_weight=helpers.dtype_and_values(
        available_dtypes=tuple(
            set(ivy_np.valid_float_dtypes).intersection(
                set(ivy_torch.valid_float_dtypes)
            )
        ),
        min_value=1.0013580322265625e-05,
        max_value=1.0,
        allow_inf=False,
        min_num_dims=1,
        max_num_dims=1,
        min_dim_size=2,
    ),
    size_average=st.booleans(),
    reduce=st.booleans(),
    reduction=st.sampled_from(["mean", "none", "sum", None]),
    # three flag lists: one flag per tensor argument (pred, true, weight)
    as_variable=helpers.list_of_length(x=st.booleans(), length=3),
    num_positional_args=helpers.num_positional_args(
        fn_name="ivy.functional.frontends.torch.binary_cross_entropy"
    ),
    native_array=helpers.list_of_length(x=st.booleans(), length=3),
)
def test_binary_cross_entropy(
    dtype_and_true,
    dtype_and_pred,
    dtype_and_weight,
    size_average,
    reduce,
    reduction,
    as_variable,
    num_positional_args,
    native_array,
    fw,
):
    """Check the torch frontend binary_cross_entropy against the native
    torch function via helpers.test_frontend_function."""
    pred_dtype, pred = dtype_and_pred
    true_dtype, true = dtype_and_true
    weight_dtype, weight = dtype_and_weight

    helpers.test_frontend_function(
        input_dtypes=[pred_dtype, true_dtype, weight_dtype],
        as_variable_flags=as_variable,
        with_out=False,
        num_positional_args=num_positional_args,
        native_array_flags=native_array,
        fw=fw,
        frontend="torch",
        fn_tree="nn.functional.binary_cross_entropy",
        input=np.asarray(pred, dtype=pred_dtype),
        target=np.asarray(true, dtype=true_dtype),
        weight=np.asarray(weight, dtype=weight_dtype),
        size_average=size_average,
        reduce=reduce,
        reduction=reduction,
    )

0 comments on commit f648e9e

Please sign in to comment.