Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 93 additions & 1 deletion python/tvm/relax/training/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ..block_builder import BlockBuilder
from ..expr import Expr, Var, Function, StructInfo

from ..op import abs, sum, mean, subtract, multiply
from ..op import abs, sum, mean, subtract, multiply, reshape, argmax
from ..op.nn import log_softmax, nll_loss


Expand Down Expand Up @@ -290,3 +290,95 @@ def __call__(
bb.emit_func_output(loss)

return bb.get()[self._loss_name]


class CategoricalCrossEntropyLoss(Loss):
    r"""CategoricalCrossEntropyLoss.
    It is a combination of a converting one-hot target vector to a label,
    a log_softmax computation and a nll_loss.

    Parameters
    ----------
    reduction : Literal["mean", "sum", "none"]
        The reduction method to apply to output. Can be "mean", "sum" or "none".

        none : no reduction will be applied,
        mean : the sum of the output will be divided by the batch_size,
        sum : the output will be summed.

    ignore_index : int
        Specifies a target value that is ignored and does not contribute to the input gradient.
    """

    # Target label skipped by the loss; only honored when >= 0 (nll_loss path below).
    ignore_index: int

    def __init__(
        self,
        reduction: Literal["mean", "sum", "none"] = "mean",
        ignore_index: int = -100,
    ) -> None:
        super().__init__("categorical_cross_entropy_loss", 1, reduction)
        self.ignore_index = ignore_index

    def __call__(
        self,
        predictions: Union[Var, StructInfo],
        targets: Union[Var, StructInfo],
        weights: Optional[Union[Var, StructInfo]] = None,
    ) -> Function:
        """Get the relax function of CategoricalCrossEntropyLoss. If the parameters are
        struct info, it will create corresponding variables.

        Parameters
        ----------
        predictions : Union[Var, StructInfo]
            The predictions of the model in the calculation of loss.

        targets : Union[Var, StructInfo]
            The ground truth in the calculation of loss.

        weights : Optional[Union[Var, StructInfo]]
            a manual rescaling weight given to each class. It has to be a Tensor of size C.

        Returns
        -------
        The relax function of CategoricalCrossEntropyLoss with the loss name as its global symbol.

        Raises
        ------
        TypeError
            If the dtype of targets is not an integer type.
        """

        # One-hot targets must be integer typed; reject other dtypes early with a clear error.
        # NOTE(review): assumes `targets` exposes `.dtype` (true for TensorStructInfo) —
        # confirm this also holds for Var inputs.
        if "int" not in targets.dtype:
            raise TypeError(
                "Dtype of targets expected to be int/uint. "
                f"However, the dtype of targets is {targets.dtype}"
            )

        bb = BlockBuilder()

        predictions = _create_param_var(predictions, "predictions")
        targets = _create_param_var(targets, "targets")

        arg_list = [predictions, targets]
        if weights is not None:
            weights = _create_param_var(weights, "weights")
            arg_list.append(weights)

        # In the case of ignore_index >= 0,
        # the nll_loss function is used to handle the ignore index.
        # In other cases where ignore_index is not needed, just use the simple product.
        with bb.function(self._loss_name, arg_list):
            with bb.dataflow():
                logits = bb.emit(log_softmax(predictions))
                if self.ignore_index >= 0:
                    # Collapse the one-hot targets to class labels so nll_loss can
                    # compare each label against ignore_index.
                    targets = bb.emit(
                        reshape(argmax(targets, axis=1), shape=(targets.struct_info.shape[0],))
                    )
                    loss = bb.emit_output(
                        nll_loss(logits, targets, weights, self._reduction, self.ignore_index)
                    )
                else:
                    # Cross entropy via elementwise product with the one-hot mask.
                    lv = bb.emit(-logits * targets.astype("float32"))
                    if weights is not None:
                        lv = bb.emit(lv * weights)
                    loss = bb.emit_output(self._with_reduction(lv))
            bb.emit_func_output(loss)

        return bb.get()[self._loss_name]
76 changes: 76 additions & 0 deletions tests/python/relax/test_training_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,5 +208,81 @@ def expected(
assert_structural_equal(After["forward_loss"], expected)


def test_categorical_cross_entropy_loss():
    """CategoricalCrossEntropyLoss with class weights and "sum" reduction
    (default ignore_index=-100, so the elementwise-product path is taken)."""
    N = 3  # batch size
    C = 5  # number of classes
    predictions = relax.TensorStructInfo((N, C), "float32")
    targets = relax.TensorStructInfo((N, C), "int64")
    weights = relax.TensorStructInfo((C,), "float32")
    categorical_cross_entropy_loss = relax.training.loss.CategoricalCrossEntropyLoss(
        reduction="sum"
    )

    @R.function
    def expected(
        predictions: R.Tensor((3, 5), "float32"),
        targets: R.Tensor((3, 5), "int64"),
        weights: R.Tensor((5,), "float32"),
    ) -> R.Tensor((), "float32"):
        with R.dataflow():
            lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1)
            # NOTE(review): the () shape annotation below looks inconsistent with the
            # (3, 5) elementwise product it binds — confirm the parser accepts/normalizes it.
            lv: R.Tensor((), "float32") = -lv * targets.astype("float32")
            gv: R.Tensor((), "float32") = R.sum(lv * weights)
            R.output(gv)
        return gv

    # Structural equality: generated function must match the hand-written IR exactly.
    assert_structural_equal(categorical_cross_entropy_loss(predictions, targets, weights), expected)


def test_categorical_cross_entropy_loss_without_weights():
    """CategoricalCrossEntropyLoss without class weights, default "mean" reduction."""
    N = 3  # batch size
    C = 5  # number of classes
    predictions = relax.TensorStructInfo((N, C), "float32")
    targets = relax.TensorStructInfo((N, C), "int64")
    categorical_cross_entropy_loss = relax.training.loss.CategoricalCrossEntropyLoss()

    @R.function
    def expected(
        predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), "int64")
    ) -> R.Tensor((), "float32"):
        with R.dataflow():
            lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1)
            gv: R.Tensor((), "float32") = R.mean(-lv * targets.astype("float32"))
            R.output(gv)
        return gv

    # Structural equality: generated function must match the hand-written IR exactly.
    assert_structural_equal(categorical_cross_entropy_loss(predictions, targets), expected)


def test_categorical_cross_entropy_loss_with_ignore_index():
    """CategoricalCrossEntropyLoss with ignore_index >= 0: one-hot targets are
    collapsed to labels (argmax + reshape) and routed through nll_loss."""
    N = 3  # batch size
    C = 5  # number of classes
    predictions = relax.TensorStructInfo((N, C), "float32")
    targets = relax.TensorStructInfo((N, C), "int64")
    weights = relax.TensorStructInfo((C,), "float32")
    categorical_cross_entropy_loss = relax.training.loss.CategoricalCrossEntropyLoss(
        reduction="sum", ignore_index=1
    )

    @R.function
    def expected(
        predictions: R.Tensor((3, 5), "float32"),
        targets: R.Tensor((3, 5), "int64"),
        weights: R.Tensor((5,), "float32"),
    ) -> R.Tensor((), "float32"):
        with R.dataflow():
            lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1)
            # NOTE(review): plain relax.op.* calls inside a TVMScript body (instead of
            # the R.* namespace) — confirm the parser handles this form as intended.
            targets = relax.op.reshape(
                relax.op.argmax(targets, axis=1), shape=(targets.struct_info.shape[0],)
            )
            gv: R.Tensor((), "float32") = R.nn.nll_loss(
                lv, targets, weights, reduction="sum", ignore_index=1
            )
            R.output(gv)
        return gv

    # Structural equality: generated function must match the hand-written IR exactly.
    assert_structural_equal(categorical_cross_entropy_loss(predictions, targets, weights), expected)


if __name__ == "__main__":
    # Run every test in this file through TVM's pytest entry point.
    tvm.testing.main()