# ---------------------------------------------------------------------------
# NOTE(review): SOURCE is a unified git diff whose newlines were collapsed
# onto four physical lines; the patch hunks cannot be reconstructed
# byte-exactly (blank context lines were lost, so the declared hunk line
# counts cannot be honored).  This edit re-emits the code the patch ADDS,
# properly formatted, with the concrete fixes called out inline.
# Part 1 belongs to python/tvm/relax/training/loss.py; Part 2 to
# tests/python/relax/test_training_loss.py.
# ---------------------------------------------------------------------------

# --- Part 1: python/tvm/relax/training/loss.py -----------------------------
# The patch extends the module's operator imports with `reshape` and `argmax`:
#     from ..op import abs, sum, mean, subtract, multiply, reshape, argmax


class CategoricalCrossEntropyLoss(Loss):
    r"""CategoricalCrossEntropyLoss.

    It is a combination of a converting one-hot target vector to a label,
    a log_softmax computation and a nll_loss.

    Parameters
    ----------
    reduction : Literal["mean", "sum", "none"]
        The reduction method to apply to output. Can be "mean", "sum" or "none".

        none : no reduction will be applied,
        mean : the sum of the output will be divided by the batch_size,
        sum : the output will be summed.

    ignore_index : int
        Specifies a target value that is ignored and does not contribute to
        the input gradient.
    """

    # Class index skipped by nll_loss when it is non-negative.
    ignore_index: int

    def __init__(
        self,
        reduction: Literal["mean", "sum", "none"] = "mean",
        ignore_index: int = -100,
    ) -> None:
        # The loss consumes a single backbone output (the predictions tensor).
        super().__init__("categorical_cross_entropy_loss", 1, reduction)
        self.ignore_index = ignore_index

    def __call__(
        self,
        predictions: Union[Var, StructInfo],
        targets: Union[Var, StructInfo],
        weights: Optional[Union[Var, StructInfo]] = None,
    ) -> Function:
        """Get the relax function of CategoricalCrossEntropyLoss. If the
        parameters are struct info, it will create corresponding variables.

        Parameters
        ----------
        predictions : Union[Var, StructInfo]
            The predictions of the model in the calculation of loss.

        targets : Union[Var, StructInfo]
            The ground truth in the calculation of loss.

        weights : Optional[Union[Var, StructInfo]]
            A manual rescaling weight given to each class. It has to be a
            Tensor of size C.

        Returns
        -------
        The relax function of CategoricalCrossEntropyLoss with the loss name
        as its global symbol.
        """
        # Targets must hold integer class indicators (one-hot rows).
        # FIX: `not "int" in x` -> `"int" not in x` (PEP 8 / E713).
        # NOTE(review): assumes `targets` exposes `.dtype` (TensorStructInfo
        # does) — confirm the same holds for a plain Var argument.
        if "int" not in targets.dtype:
            raise TypeError(
                f"Dtype of targets expected to be int/uint. \
                However, the dtype of targets is {targets.dtype}"
            )

        bb = BlockBuilder()

        predictions = _create_param_var(predictions, "predictions")
        targets = _create_param_var(targets, "targets")

        arg_list = [predictions, targets]
        # FIX: explicit None test — relax exprs should not be truth-tested.
        if weights is not None:
            weights = _create_param_var(weights, "weights")
            arg_list.append(weights)

        # When ignore_index >= 0, the nll_loss op handles the ignored class,
        # so the one-hot targets are first collapsed to label indices with
        # argmax.  Otherwise the loss is the simple product of -log_softmax
        # and the one-hot targets.  (FIX: typo "simpe" -> "simple".)
        with bb.function(self._loss_name, arg_list):
            with bb.dataflow():
                logits = bb.emit(log_softmax(predictions))
                if self.ignore_index >= 0:
                    targets = bb.emit(
                        reshape(argmax(targets, axis=1), shape=(targets.struct_info.shape[0],))
                    )
                    loss = bb.emit_output(
                        nll_loss(logits, targets, weights, self._reduction, self.ignore_index)
                    )
                else:
                    lv = bb.emit(-logits * targets.astype("float32"))
                    if weights is not None:  # FIX: explicit None test
                        lv = bb.emit(lv * weights)
                    loss = bb.emit_output(self._with_reduction(lv))
            bb.emit_func_output(loss)

        return bb.get()[self._loss_name]


# --- Part 2: tests/python/relax/test_training_loss.py ----------------------


def test_categorical_cross_entropy_loss():
    N = 3
    C = 5
    predictions = relax.TensorStructInfo((N, C), "float32")
    targets = relax.TensorStructInfo((N, C), "int64")
    weights = relax.TensorStructInfo((C,), "float32")
    categorical_cross_entropy_loss = relax.training.loss.CategoricalCrossEntropyLoss(
        reduction="sum"
    )

    @R.function
    def expected(
        predictions: R.Tensor((3, 5), "float32"),
        targets: R.Tensor((3, 5), "int64"),
        weights: R.Tensor((5,), "float32"),
    ) -> R.Tensor((), "float32"):
        with R.dataflow():
            lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1)
            # FIX: the elementwise product of two (3, 5) tensors is (3, 5);
            # the original diff annotated this binding as a scalar R.Tensor(()).
            lv: R.Tensor((3, 5), "float32") = -lv * targets.astype("float32")
            gv: R.Tensor((), "float32") = R.sum(lv * weights)
            R.output(gv)
        return gv

    assert_structural_equal(categorical_cross_entropy_loss(predictions, targets, weights), expected)


def test_categorical_cross_entropy_loss_without_weights():
    N = 3
    C = 5
    predictions = relax.TensorStructInfo((N, C), "float32")
    targets = relax.TensorStructInfo((N, C), "int64")
    categorical_cross_entropy_loss = relax.training.loss.CategoricalCrossEntropyLoss()

    @R.function
    def expected(
        predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), "int64")
    ) -> R.Tensor((), "float32"):
        with R.dataflow():
            lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1)
            gv: R.Tensor((), "float32") = R.mean(-lv * targets.astype("float32"))
            R.output(gv)
        return gv

    assert_structural_equal(categorical_cross_entropy_loss(predictions, targets), expected)


def test_categorical_cross_entropy_loss_with_ignore_index():
    N = 3
    C = 5
    predictions = relax.TensorStructInfo((N, C), "float32")
    targets = relax.TensorStructInfo((N, C), "int64")
    weights = relax.TensorStructInfo((C,), "float32")
    categorical_cross_entropy_loss = relax.training.loss.CategoricalCrossEntropyLoss(
        reduction="sum", ignore_index=1
    )

    @R.function
    def expected(
        predictions: R.Tensor((3, 5), "float32"),
        targets: R.Tensor((3, 5), "int64"),
        weights: R.Tensor((5,), "float32"),
    ) -> R.Tensor((), "float32"):
        with R.dataflow():
            lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1)
            # NOTE(review): the original calls relax.op.* directly inside a
            # TVMScript body; kept verbatim — confirm the parser accepts it.
            targets = relax.op.reshape(
                relax.op.argmax(targets, axis=1), shape=(targets.struct_info.shape[0],)
            )
            gv: R.Tensor((), "float32") = R.nn.nll_loss(
                lv, targets, weights, reduction="sum", ignore_index=1
            )
            R.output(gv)
        return gv

    assert_structural_equal(categorical_cross_entropy_loss(predictions, targets, weights), expected)


if __name__ == "__main__":
    tvm.testing.main()